Skip to content

Commit bc9dbd7

Browse files
authored
feat: add ToolOutputTrimmer for smart context management (#2468)
1 parent 67687cb commit bc9dbd7

3 files changed

Lines changed: 647 additions & 0 deletions

File tree

src/agents/extensions/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .tool_output_trimmer import ToolOutputTrimmer
2+
3+
__all__ = ["ToolOutputTrimmer"]
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
"""Built-in call_model_input_filter that trims large tool outputs from older turns.
2+
3+
Agentic applications often accumulate large tool outputs (search results, code execution
4+
output, error analyses) that consume significant tokens but lose relevance as the
5+
conversation progresses. This module provides a configurable filter that surgically trims
6+
bulky tool outputs from older turns while keeping recent turns at full fidelity.
7+
8+
Usage::
9+
10+
from agents import RunConfig
11+
from agents.extensions import ToolOutputTrimmer
12+
13+
config = RunConfig(
14+
call_model_input_filter=ToolOutputTrimmer(
15+
recent_turns=2,
16+
max_output_chars=500,
17+
preview_chars=200,
18+
trimmable_tools={"search", "execute_code"},
19+
),
20+
)
21+
22+
The trimmer operates as a sliding window: the last ``recent_turns`` user messages (and
23+
all items after them) are never modified. Older tool outputs that exceed
24+
``max_output_chars`` — and optionally belong to ``trimmable_tools`` — are replaced with a
25+
compact preview.
26+
"""
27+
28+
from __future__ import annotations
29+
30+
import logging
31+
from dataclasses import dataclass, field
32+
from typing import TYPE_CHECKING, Any
33+
34+
if TYPE_CHECKING:
35+
from ..run_config import CallModelData, ModelInputData
36+
37+
logger = logging.getLogger(__name__)
38+
39+
40+
@dataclass
41+
class ToolOutputTrimmer:
42+
"""Configurable filter that trims large tool outputs from older conversation turns.
43+
44+
This class implements the ``CallModelInputFilter`` protocol and can be passed directly
45+
to ``RunConfig.call_model_input_filter``. It runs immediately before each model call
46+
and replaces large tool outputs from older turns with a concise preview, reducing token
47+
usage without losing the context of what happened.
48+
49+
Args:
50+
recent_turns: Number of recent user messages whose surrounding items are never
51+
trimmed. Defaults to 2.
52+
max_output_chars: Tool outputs above this character count are candidates for
53+
trimming. Defaults to 500.
54+
preview_chars: How many characters of the original output to preserve as a
55+
preview when trimming. Defaults to 200.
56+
trimmable_tools: Optional set of tool names whose outputs can be trimmed. If
57+
``None``, all tool outputs are eligible for trimming. Defaults to ``None``.
58+
"""
59+
60+
recent_turns: int = 2
61+
max_output_chars: int = 500
62+
preview_chars: int = 200
63+
trimmable_tools: frozenset[str] | None = field(default=None)
64+
65+
def __post_init__(self) -> None:
66+
if self.recent_turns < 1:
67+
raise ValueError(f"recent_turns must be >= 1, got {self.recent_turns}")
68+
if self.max_output_chars < 1:
69+
raise ValueError(f"max_output_chars must be >= 1, got {self.max_output_chars}")
70+
if self.preview_chars < 0:
71+
raise ValueError(f"preview_chars must be >= 0, got {self.preview_chars}")
72+
# Coerce any iterable to frozenset for immutability
73+
if self.trimmable_tools is not None and not isinstance(self.trimmable_tools, frozenset):
74+
object.__setattr__(self, "trimmable_tools", frozenset(self.trimmable_tools))
75+
76+
def __call__(self, data: CallModelData[Any]) -> ModelInputData:
77+
"""Filter callback invoked before each model call.
78+
79+
Finds the boundary between old and recent items, then trims large tool outputs
80+
from old turns. Does NOT mutate the original items — creates shallow copies when
81+
needed.
82+
"""
83+
from ..run_config import ModelInputData as _ModelInputData
84+
85+
model_data = data.model_data
86+
items = model_data.input
87+
88+
if not items:
89+
return model_data
90+
91+
boundary = self._find_recent_boundary(items)
92+
if boundary == 0:
93+
return model_data
94+
95+
call_id_to_name = self._build_call_id_to_name(items)
96+
97+
trimmed_count = 0
98+
chars_saved = 0
99+
new_items: list[Any] = []
100+
101+
for i, item in enumerate(items):
102+
if (
103+
i < boundary
104+
and isinstance(item, dict)
105+
and item.get("type") == "function_call_output"
106+
):
107+
output = item.get("output", "")
108+
output_str = output if isinstance(output, str) else str(output)
109+
output_len = len(output_str)
110+
111+
call_id = str(item.get("call_id", ""))
112+
tool_name = call_id_to_name.get(call_id, "")
113+
114+
if output_len > self.max_output_chars and (
115+
self.trimmable_tools is None or tool_name in self.trimmable_tools
116+
):
117+
display_name = tool_name or "unknown_tool"
118+
preview = output_str[: self.preview_chars]
119+
summary = (
120+
f"[Trimmed: {display_name} output — {output_len} chars → "
121+
f"{self.preview_chars} char preview]\n{preview}..."
122+
)
123+
# Only replace if summary is actually shorter than the original
124+
if len(summary) < output_len:
125+
trimmed_item = dict(item)
126+
trimmed_item["output"] = summary
127+
new_items.append(trimmed_item)
128+
129+
trimmed_count += 1
130+
chars_saved += output_len - len(summary)
131+
continue
132+
133+
new_items.append(item)
134+
135+
if trimmed_count > 0:
136+
logger.debug(
137+
f"ToolOutputTrimmer: trimmed {trimmed_count} tool output(s), "
138+
f"saved ~{chars_saved} chars"
139+
)
140+
141+
return _ModelInputData(input=new_items, instructions=model_data.instructions)
142+
143+
def _find_recent_boundary(self, items: list[Any]) -> int:
144+
"""Find the index separating 'old' items from 'recent' items.
145+
146+
Walks backward through the items list counting user messages. Returns the index
147+
of the Nth user message from the end, where N = ``recent_turns``. Items at or
148+
after this index are considered recent and will not be trimmed.
149+
150+
If there are fewer than N user messages, returns 0 (nothing is old).
151+
"""
152+
user_msg_count = 0
153+
for i in range(len(items) - 1, -1, -1):
154+
item = items[i]
155+
if isinstance(item, dict) and item.get("role") == "user":
156+
user_msg_count += 1
157+
if user_msg_count >= self.recent_turns:
158+
return i
159+
return 0
160+
161+
def _build_call_id_to_name(self, items: list[Any]) -> dict[str, str]:
162+
"""Build a mapping from function call_id to tool name."""
163+
mapping: dict[str, str] = {}
164+
for item in items:
165+
if isinstance(item, dict) and item.get("type") == "function_call":
166+
call_id = item.get("call_id")
167+
name = item.get("name")
168+
if call_id and name:
169+
mapping[call_id] = name
170+
return mapping

0 commit comments

Comments
 (0)