|
| 1 | +"""Built-in call_model_input_filter that trims large tool outputs from older turns. |
| 2 | +
|
| 3 | +Agentic applications often accumulate large tool outputs (search results, code execution |
| 4 | +output, error analyses) that consume significant tokens but lose relevance as the |
| 5 | +conversation progresses. This module provides a configurable filter that surgically trims |
| 6 | +bulky tool outputs from older turns while keeping recent turns at full fidelity. |
| 7 | +
|
| 8 | +Usage:: |
| 9 | +
|
| 10 | + from agents import RunConfig |
| 11 | + from agents.extensions import ToolOutputTrimmer |
| 12 | +
|
| 13 | + config = RunConfig( |
| 14 | + call_model_input_filter=ToolOutputTrimmer( |
| 15 | + recent_turns=2, |
| 16 | + max_output_chars=500, |
| 17 | + preview_chars=200, |
| 18 | + trimmable_tools={"search", "execute_code"}, |
| 19 | + ), |
| 20 | + ) |
| 21 | +
|
| 22 | +The trimmer operates as a sliding window: the last ``recent_turns`` user messages (and |
| 23 | +all items after them) are never modified. Older tool outputs that exceed |
| 24 | +``max_output_chars`` — and optionally belong to ``trimmable_tools`` — are replaced with a |
| 25 | +compact preview. |
| 26 | +""" |
| 27 | + |
| 28 | +from __future__ import annotations |
| 29 | + |
| 30 | +import logging |
| 31 | +from dataclasses import dataclass, field |
| 32 | +from typing import TYPE_CHECKING, Any |
| 33 | + |
| 34 | +if TYPE_CHECKING: |
| 35 | + from ..run_config import CallModelData, ModelInputData |
| 36 | + |
| 37 | +logger = logging.getLogger(__name__) |
| 38 | + |
| 39 | + |
@dataclass
class ToolOutputTrimmer:
    """Configurable filter that trims large tool outputs from older conversation turns.

    Instances are callables taking the pre-model-call data and returning the model
    input, so they can be passed to ``RunConfig.call_model_input_filter``. Large
    ``function_call_output`` items located before the recent-turn boundary are replaced
    with a short preview; everything at or after the boundary is left untouched.

    Args:
        recent_turns: Number of recent user messages whose surrounding items are never
            trimmed. Must be >= 1. Defaults to 2.
        max_output_chars: Tool outputs above this character count are candidates for
            trimming. Must be >= 1. Defaults to 500.
        preview_chars: How many characters of the original output to preserve as a
            preview when trimming. Must be >= 0. Defaults to 200.
        trimmable_tools: Optional collection of tool names whose outputs can be trimmed.
            Any iterable of names is coerced to a ``frozenset``; a single bare string is
            treated as one tool name. If ``None``, all tool outputs are eligible for
            trimming. Defaults to ``None``.

    Raises:
        ValueError: If ``recent_turns`` or ``max_output_chars`` is below 1, or
            ``preview_chars`` is negative.
    """

    recent_turns: int = 2
    max_output_chars: int = 500
    preview_chars: int = 200
    trimmable_tools: frozenset[str] | None = field(default=None)

    def __post_init__(self) -> None:
        """Validate numeric settings and normalise ``trimmable_tools`` to a frozenset."""
        if self.recent_turns < 1:
            raise ValueError(f"recent_turns must be >= 1, got {self.recent_turns}")
        if self.max_output_chars < 1:
            raise ValueError(f"max_output_chars must be >= 1, got {self.max_output_chars}")
        if self.preview_chars < 0:
            raise ValueError(f"preview_chars must be >= 0, got {self.preview_chars}")

        tools: Any = self.trimmable_tools
        if tools is None or isinstance(tools, frozenset):
            return
        if isinstance(tools, str):
            # A bare string is a single tool name; frozenset("search") would otherwise
            # silently become a set of individual characters that matches nothing.
            tools = (tools,)
        # Coerce any other iterable to frozenset for immutability.
        self.trimmable_tools = frozenset(tools)

    def __call__(self, data: CallModelData[Any]) -> ModelInputData:
        """Filter callback invoked before each model call.

        Finds the boundary between old and recent items, then trims large tool outputs
        located before that boundary. The original items are never mutated: a trimmed
        item is a shallow copy with its ``output`` replaced.

        Args:
            data: Call data; only ``data.model_data`` (its ``input`` and
                ``instructions``) is read.

        Returns:
            ``data.model_data`` itself when there is nothing to trim (empty input, no
            old items, or no output qualified); otherwise a new ``ModelInputData``
            holding the trimmed item list and the original instructions.
        """
        model_data = data.model_data
        items = model_data.input

        if not items:
            return model_data

        boundary = self._find_recent_boundary(items)
        if boundary == 0:
            return model_data

        call_id_to_name = self._build_call_id_to_name(items)

        trimmed_count = 0
        chars_saved = 0
        new_items: list[Any] = []

        for index, item in enumerate(items):
            # Items at or after the boundary belong to recent turns: keep verbatim.
            result = self._trim_item(item, call_id_to_name) if index < boundary else None
            if result is None:
                new_items.append(item)
                continue
            trimmed_item, saved = result
            new_items.append(trimmed_item)
            trimmed_count += 1
            chars_saved += saved

        if trimmed_count == 0:
            # Nothing changed, so hand back the input as-is instead of rebuilding it.
            return model_data

        # Lazy %-formatting: the message is only rendered when DEBUG is enabled.
        logger.debug(
            "ToolOutputTrimmer: trimmed %d tool output(s), saved ~%d chars",
            trimmed_count,
            chars_saved,
        )

        # Imported here rather than at module top (the top-level import is guarded by
        # TYPE_CHECKING); only needed once a new object actually has to be built.
        from ..run_config import ModelInputData as _ModelInputData

        return _ModelInputData(input=new_items, instructions=model_data.instructions)

    def _trim_item(
        self, item: Any, call_id_to_name: dict[str, str]
    ) -> tuple[dict[str, Any], int] | None:
        """Return a trimmed copy of ``item`` and the characters saved, or ``None``.

        ``None`` means the item must be kept unchanged: it is not a dict of type
        ``function_call_output``, its output is not longer than ``max_output_chars``,
        its tool is not in ``trimmable_tools``, or the summary would not be shorter
        than the original output.
        """
        if not isinstance(item, dict) or item.get("type") != "function_call_output":
            return None

        output = item.get("output", "")
        # Non-string outputs are measured and previewed via their str() form.
        output_str = output if isinstance(output, str) else str(output)
        output_len = len(output_str)
        if output_len <= self.max_output_chars:
            return None

        tool_name = call_id_to_name.get(str(item.get("call_id", "")), "")
        if self.trimmable_tools is not None and tool_name not in self.trimmable_tools:
            return None

        display_name = tool_name or "unknown_tool"
        preview = output_str[: self.preview_chars]
        summary = (
            f"[Trimmed: {display_name} output — {output_len} chars → "
            f"{self.preview_chars} char preview]\n{preview}..."
        )
        # Only replace if the summary is actually shorter than the original.
        if len(summary) >= output_len:
            return None

        trimmed_item = dict(item)  # shallow copy; the caller's item stays untouched
        trimmed_item["output"] = summary
        return trimmed_item, output_len - len(summary)

    def _find_recent_boundary(self, items: list[Any]) -> int:
        """Find the index separating 'old' items from 'recent' items.

        Walks backward through ``items`` counting dict items whose ``role`` is
        ``"user"``. Returns the index of the Nth such message from the end, where
        N = ``recent_turns``. Items at or after this index are considered recent and
        will not be trimmed.

        If there are fewer than N user messages, returns 0 (nothing is old).
        """
        user_msg_count = 0
        for index in reversed(range(len(items))):
            item = items[index]
            if isinstance(item, dict) and item.get("role") == "user":
                user_msg_count += 1
                if user_msg_count >= self.recent_turns:
                    return index
        return 0

    def _build_call_id_to_name(self, items: list[Any]) -> dict[str, str]:
        """Build a mapping from function ``call_id`` to tool name.

        Only dict items of type ``function_call`` with a truthy ``call_id`` and
        ``name`` contribute. Keys are normalised with ``str()`` so they match the
        ``str(call_id)`` lookup performed when trimming outputs.
        """
        mapping: dict[str, str] = {}
        for item in items:
            if not isinstance(item, dict) or item.get("type") != "function_call":
                continue
            call_id = item.get("call_id")
            name = item.get("name")
            if call_id and name:
                mapping[str(call_id)] = name
        return mapping
0 commit comments