diff --git a/src/agents/run_internal/session_persistence.py b/src/agents/run_internal/session_persistence.py
index 25874ad345..8197bae594 100644
--- a/src/agents/run_internal/session_persistence.py
+++ b/src/agents/run_internal/session_persistence.py
@@ -569,12 +569,20 @@ def _sanitize_openai_conversation_item(item: TResponseInputItem) -> TResponseInp
     """Remove provider-specific fields before fingerprinting or persistence."""
     if isinstance(item, dict):
         clean_item = cast(dict[str, Any], strip_internal_input_item_metadata(item))
-        clean_item.pop("id", None)
+        if _should_strip_openai_conversation_item_id(clean_item):
+            clean_item.pop("id", None)
         clean_item.pop("provider_data", None)
         return cast(TResponseInputItem, clean_item)
     return item
 
 
+def _should_strip_openai_conversation_item_id(item: dict[str, Any]) -> bool:
+    """Return True when the Conversations API does not require an item's ``id`` field."""
+    # Built-in tool calls and assistant-authored items rely on their ids for replay. The only
+    # shapes we know are safe to drop are user messages and function_call_output items.
+    return item.get("role") == "user" or item.get("type") == "function_call_output"
+
+
 def _fingerprint_or_repr(item: TResponseInputItem, *, ignore_ids_for_matching: bool) -> str:
     """Fingerprint an item or fall back to repr when unavailable."""
     return fingerprint_input_item(item, ignore_ids_for_matching=ignore_ids_for_matching) or repr(
diff --git a/tests/test_agent_runner.py b/tests/test_agent_runner.py
index 45cdab7711..a3caf9f70b 100644
--- a/tests/test_agent_runner.py
+++ b/tests/test_agent_runner.py
@@ -2577,6 +2577,72 @@ async def test_save_result_to_session_sanitizes_original_input_items() -> None:
     assert "title" not in saved_tool_call
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "payload",
+    [
+        cast(
+            dict[str, Any],
+            {"type": "file_search_call", "id": "fs_123", "queries": ["customer profile"]},
+        ),
+        cast(
+            dict[str, Any],
+            {
+                "type": "web_search_call",
+                "id": "ws_123",
+                "action": {"type": "search", "query": "customer profile"},
+            },
+        ),
+        cast(
+            dict[str, Any],
+            {"type": "code_interpreter_call", "id": "ci_123", "status": "completed"},
+        ),
+    ],
+)
+async def test_save_result_to_session_preserves_required_built_in_tool_call_ids(
+    payload: dict[str, Any],
+) -> None:
+    class DummyOpenAIConversationsSession(OpenAIConversationsSession):
+        def __init__(self) -> None:
+            self.saved_items: list[TResponseInputItem] = []
+
+        async def _get_session_id(self) -> str:
+            return "conv_test"
+
+        async def add_items(self, items: list[TResponseInputItem]) -> None:
+            self.saved_items.extend(items)
+
+        async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
+            return []
+
+        async def pop_item(self) -> TResponseInputItem | None:
+            return None
+
+        async def clear_session(self) -> None:
+            return None
+
+    session = DummyOpenAIConversationsSession()
+    agent = Agent(name="agent", model=FakeModel())
+    run_state: RunState[Any] = RunState(
+        context=RunContextWrapper(context={}),
+        original_input="input",
+        starting_agent=agent,
+        max_turns=1,
+    )
+
+    saved_count = await save_result_to_session(
+        session,
+        [],
+        cast(list[RunItem], [_DummyRunItem(payload)]),
+        run_state,
+    )
+
+    assert saved_count == 1
+    assert run_state._current_turn_persisted_item_count == 1
+    assert len(session.saved_items) == 1
+    assert cast(dict[str, Any], session.saved_items[0])["id"] == payload["id"]
+
+
 @pytest.mark.asyncio
 async def test_prepare_input_with_session_strips_internal_tool_call_metadata() -> None:
     tool_call = cast(
diff --git a/tests/test_run_internal_items.py b/tests/test_run_internal_items.py
index e7daafa577..28b7dda53c 100644
--- a/tests/test_run_internal_items.py
+++ b/tests/test_run_internal_items.py
@@ -23,7 +23,7 @@ from agents.models.fake_id import FAKE_RESPONSES_ID
 
 from agents.result import RunResult
 from agents.run_context import RunContextWrapper
-from agents.run_internal import items as run_items
+from agents.run_internal import items as run_items, session_persistence
 
 
 def test_drop_orphan_function_calls_preserves_non_mapping_entries() -> None:
@@ -533,6 +533,69 @@ def test_fingerprint_input_item_ignores_internal_tool_call_metadata() -> None:
     )
 
 
+def test_sanitize_openai_conversation_item_preserves_required_tool_call_ids() -> None:
+    file_search_call = cast(
+        TResponseInputItem,
+        {
+            "type": "file_search_call",
+            "id": "fs_123",
+            "queries": ["customer profile"],
+            "status": "completed",
+        },
+    )
+    web_search_call = cast(
+        TResponseInputItem,
+        {
+            "type": "web_search_call",
+            "id": "ws_123",
+            "action": {"type": "search", "query": "customer profile"},
+            "status": "completed",
+        },
+    )
+    code_interpreter_call = cast(
+        TResponseInputItem,
+        {
+            "type": "code_interpreter_call",
+            "id": "ci_123",
+            "status": "completed",
+        },
+    )
+    user_message = cast(
+        TResponseInputItem,
+        {"role": "user", "id": "user_123", "content": "hello"},
+    )
+    function_call_output = cast(
+        TResponseInputItem,
+        {"type": "function_call_output", "id": "out_123", "call_id": "call_123", "output": "ok"},
+    )
+
+    assert (
+        cast(
+            dict[str, Any], session_persistence._sanitize_openai_conversation_item(file_search_call)
+        )["id"]
+        == "fs_123"
+    )
+    assert (
+        cast(
+            dict[str, Any], session_persistence._sanitize_openai_conversation_item(web_search_call)
+        )["id"]
+        == "ws_123"
+    )
+    assert (
+        cast(
+            dict[str, Any],
+            session_persistence._sanitize_openai_conversation_item(code_interpreter_call),
+        )["id"]
+        == "ci_123"
+    )
+    assert "id" not in cast(
+        dict[str, Any], session_persistence._sanitize_openai_conversation_item(user_message)
+    )
+    assert "id" not in cast(
+        dict[str, Any], session_persistence._sanitize_openai_conversation_item(function_call_output)
+    )
+
+
 def test_run_result_to_input_list_preserves_tool_search_items() -> None:
     agent = Agent(name="A")
     result = RunResult(