Skip to content

Commit 09ea6aa

Browse files
fix: remove_all_tools missing hosted tool types (#2885)
1 parent 3dffa4b commit 09ea6aa

File tree

2 files changed

+65
-0
lines changed

2 files changed

+65
-0
lines changed

src/agents/extensions/handoff_filters.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
MCPListToolsItem,
1414
ReasoningItem,
1515
RunItem,
16+
ToolApprovalItem,
1617
ToolCallItem,
1718
ToolCallOutputItem,
1819
ToolSearchCallItem,
@@ -63,6 +64,7 @@ def _remove_tools_from_items(items: tuple[RunItem, ...]) -> tuple[RunItem, ...]:
6364
or isinstance(item, MCPListToolsItem)
6465
or isinstance(item, MCPApprovalRequestItem)
6566
or isinstance(item, MCPApprovalResponseItem)
67+
or isinstance(item, ToolApprovalItem)
6668
):
6769
continue
6870
filtered_items.append(item)
@@ -86,6 +88,14 @@ def _remove_tool_types_from_input(
8688
"mcp_approval_request",
8789
"mcp_approval_response",
8890
"reasoning",
91+
"code_interpreter_call",
92+
"image_generation_call",
93+
"local_shell_call",
94+
"local_shell_call_output",
95+
"shell_call",
96+
"shell_call_output",
97+
"apply_patch_call",
98+
"apply_patch_call_output",
8999
]
90100

91101
filtered_items: list[TResponseInputItem] = []

tests/test_extension_filters.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
MCPListToolsItem,
2525
MessageOutputItem,
2626
ReasoningItem,
27+
ToolApprovalItem,
2728
ToolCallItem,
2829
ToolCallOutputItem,
2930
ToolSearchCallItem,
@@ -1015,3 +1016,57 @@ def test_removes_mixed_mcp_and_function_items() -> None:
10151016
assert len(filtered_data.input_history) == 2
10161017
assert len(filtered_data.pre_handoff_items) == 1
10171018
assert len(filtered_data.new_items) == 1
1019+
1020+
1021+
def _get_hosted_tool_input_item(type_name: str) -> TResponseInputItem:
1022+
return cast(TResponseInputItem, {"id": "ht1", "type": type_name})
1023+
1024+
1025+
def _get_tool_approval_run_item() -> ToolApprovalItem:
1026+
return ToolApprovalItem(
1027+
agent=fake_agent(),
1028+
raw_item={"type": "function_call", "call_id": "c1", "name": "fn", "arguments": "{}"},
1029+
tool_name="fn",
1030+
)
1031+
1032+
1033+
def test_removes_hosted_tool_types_from_input_history() -> None:
1034+
"""Hosted tool types in raw input history should be removed by remove_all_tools."""
1035+
hosted_types = [
1036+
"code_interpreter_call",
1037+
"image_generation_call",
1038+
"local_shell_call",
1039+
"local_shell_call_output",
1040+
"shell_call",
1041+
"shell_call_output",
1042+
"apply_patch_call",
1043+
"apply_patch_call_output",
1044+
]
1045+
input_items: list[TResponseInputItem] = [_get_message_input_item("Hello")]
1046+
for t in hosted_types:
1047+
input_items.append(_get_hosted_tool_input_item(t))
1048+
input_items.append(_get_message_input_item("World"))
1049+
1050+
handoff_input_data = handoff_data(input_history=tuple(input_items))
1051+
filtered_data = remove_all_tools(handoff_input_data)
1052+
assert len(filtered_data.input_history) == 2
1053+
for item in filtered_data.input_history:
1054+
assert not isinstance(item, str)
1055+
assert item.get("type") not in set(hosted_types)
1056+
1057+
1058+
def test_removes_tool_approval_from_new_items() -> None:
1059+
"""ToolApprovalItem should be removed from new_items and pre_handoff_items."""
1060+
handoff_input_data = handoff_data(
1061+
pre_handoff_items=(
1062+
_get_tool_approval_run_item(),
1063+
_get_message_output_run_item("kept"),
1064+
),
1065+
new_items=(
1066+
_get_tool_approval_run_item(),
1067+
_get_message_output_run_item("also kept"),
1068+
),
1069+
)
1070+
filtered_data = remove_all_tools(handoff_input_data)
1071+
assert len(filtered_data.pre_handoff_items) == 1
1072+
assert len(filtered_data.new_items) == 1

0 commit comments

Comments
 (0)