Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/agents/run_internal/session_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,12 +569,20 @@ def _sanitize_openai_conversation_item(item: TResponseInputItem) -> TResponseInp
"""Remove provider-specific fields before fingerprinting or persistence."""
if isinstance(item, dict):
clean_item = cast(dict[str, Any], strip_internal_input_item_metadata(item))
clean_item.pop("id", None)
if _should_strip_openai_conversation_item_id(clean_item):
clean_item.pop("id", None)
clean_item.pop("provider_data", None)
return cast(TResponseInputItem, clean_item)
return item


def _should_strip_openai_conversation_item_id(item: dict[str, Any]) -> bool:
"""Return True when the Conversations API does not require an item's ``id`` field."""
# Built-in tool calls and assistant-authored items rely on their ids for replay. The only
# shapes we know are safe to drop are user messages and function_call_output items.
return item.get("role") == "user" or item.get("type") == "function_call_output"


def _fingerprint_or_repr(item: TResponseInputItem, *, ignore_ids_for_matching: bool) -> str:
"""Fingerprint an item or fall back to repr when unavailable."""
return fingerprint_input_item(item, ignore_ids_for_matching=ignore_ids_for_matching) or repr(
Expand Down
66 changes: 66 additions & 0 deletions tests/test_agent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2577,6 +2577,72 @@ async def test_save_result_to_session_sanitizes_original_input_items() -> None:
assert "title" not in saved_tool_call


@pytest.mark.asyncio
@pytest.mark.parametrize(
"payload",
[
cast(
dict[str, Any],
{"type": "file_search_call", "id": "fs_123", "queries": ["customer profile"]},
),
cast(
dict[str, Any],
{
"type": "web_search_call",
"id": "ws_123",
"action": {"type": "search", "query": "customer profile"},
},
),
cast(
dict[str, Any],
{"type": "code_interpreter_call", "id": "ci_123", "status": "completed"},
),
],
)
async def test_save_result_to_session_preserves_required_built_in_tool_call_ids(
payload: dict[str, Any],
) -> None:
class DummyOpenAIConversationsSession(OpenAIConversationsSession):
def __init__(self) -> None:
self.saved_items: list[TResponseInputItem] = []

async def _get_session_id(self) -> str:
return "conv_test"

async def add_items(self, items: list[TResponseInputItem]) -> None:
self.saved_items.extend(items)

async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
return []

async def pop_item(self) -> TResponseInputItem | None:
return None

async def clear_session(self) -> None:
return None

session = DummyOpenAIConversationsSession()
agent = Agent(name="agent", model=FakeModel())
run_state: RunState[Any] = RunState(
context=RunContextWrapper(context={}),
original_input="input",
starting_agent=agent,
max_turns=1,
)

saved_count = await save_result_to_session(
session,
[],
cast(list[RunItem], [_DummyRunItem(payload)]),
run_state,
)

assert saved_count == 1
assert run_state._current_turn_persisted_item_count == 1
assert len(session.saved_items) == 1
assert cast(dict[str, Any], session.saved_items[0])["id"] == payload["id"]


@pytest.mark.asyncio
async def test_prepare_input_with_session_strips_internal_tool_call_metadata() -> None:
tool_call = cast(
Expand Down
65 changes: 64 additions & 1 deletion tests/test_run_internal_items.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from agents.models.fake_id import FAKE_RESPONSES_ID
from agents.result import RunResult
from agents.run_context import RunContextWrapper
from agents.run_internal import items as run_items
from agents.run_internal import items as run_items, session_persistence as session_persistence


def test_drop_orphan_function_calls_preserves_non_mapping_entries() -> None:
Expand Down Expand Up @@ -533,6 +533,69 @@ def test_fingerprint_input_item_ignores_internal_tool_call_metadata() -> None:
)


def test_sanitize_openai_conversation_item_preserves_required_tool_call_ids() -> None:
file_search_call = cast(
TResponseInputItem,
{
"type": "file_search_call",
"id": "fs_123",
"queries": ["customer profile"],
"status": "completed",
},
)
web_search_call = cast(
TResponseInputItem,
{
"type": "web_search_call",
"id": "ws_123",
"action": {"type": "search", "query": "customer profile"},
"status": "completed",
},
)
code_interpreter_call = cast(
TResponseInputItem,
{
"type": "code_interpreter_call",
"id": "ci_123",
"status": "completed",
},
)
user_message = cast(
TResponseInputItem,
{"role": "user", "id": "user_123", "content": "hello"},
)
function_call_output = cast(
TResponseInputItem,
{"type": "function_call_output", "id": "out_123", "call_id": "call_123", "output": "ok"},
)

assert (
cast(
dict[str, Any], session_persistence._sanitize_openai_conversation_item(file_search_call)
)["id"]
== "fs_123"
)
assert (
cast(
dict[str, Any], session_persistence._sanitize_openai_conversation_item(web_search_call)
)["id"]
== "ws_123"
)
assert (
cast(
dict[str, Any],
session_persistence._sanitize_openai_conversation_item(code_interpreter_call),
)["id"]
== "ci_123"
)
assert "id" not in cast(
dict[str, Any], session_persistence._sanitize_openai_conversation_item(user_message)
)
assert "id" not in cast(
dict[str, Any], session_persistence._sanitize_openai_conversation_item(function_call_output)
)


def test_run_result_to_input_list_preserves_tool_search_items() -> None:
agent = Agent(name="A")
result = RunResult(
Expand Down