Skip to content

Commit 696a9a8

Browse files
authored
Serialize structured realtime tool outputs as JSON (#2608)
1 parent 9ac31ab commit 696a9a8

File tree

2 files changed

+109
-5
lines changed

2 files changed

+109
-5
lines changed

src/agents/realtime/session.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from __future__ import annotations
22

33
import asyncio
4+
import dataclasses
45
import inspect
56
import json
67
from collections.abc import AsyncIterator
78
from typing import Any, cast
89

10+
from pydantic import BaseModel
911
from typing_extensions import assert_never
1012

1113
from ..agent import Agent
@@ -67,6 +69,29 @@
6769
REJECTION_MESSAGE = DEFAULT_APPROVAL_REJECTION_MESSAGE
6870

6971

72+
def _serialize_tool_output(output: Any) -> str:
    """Serialize a tool's return value into the string sent back to the model.

    Strings are passed through unchanged. Pydantic models and dataclass instances are
    first converted to plain Python data; that data (or any other value) is then encoded
    with ``json.dumps``. If any conversion or encoding step fails, the result is
    ``str()`` of the object the tool actually returned, so a custom ``__str__`` is
    honoured even when an intermediate dict had already been produced.

    Args:
        output: The value returned by the tool invocation.

    Returns:
        The JSON text for ``output`` when it can be encoded, otherwise ``str(output)``.
    """
    if isinstance(output, str):
        return output

    # ``payload`` is what gets JSON-encoded; ``output`` stays bound to the caller's
    # object so that every fallback stringifies the original value rather than an
    # intermediate dict.
    payload: Any = output
    if isinstance(output, BaseModel):
        try:
            payload = output.model_dump(mode="json")
        except Exception:
            # The JSON-mode dump may raise; a plain dump can still yield data that
            # json.dumps accepts, so try that before giving up.
            try:
                payload = output.model_dump()
            except Exception:
                return str(output)
    elif dataclasses.is_dataclass(output) and not isinstance(output, type):
        # is_dataclass() is also true for dataclass *types*; only instances are converted.
        try:
            payload = dataclasses.asdict(output)
        except Exception:
            return str(output)

    try:
        return json.dumps(payload, ensure_ascii=False)
    except (TypeError, ValueError):
        return str(output)
93+
94+
7095
class RealtimeSession(RealtimeModelListener):
7196
"""A connection to a realtime model. It streams events from the model to you, and allows you to
7297
send messages and audio to the model.
@@ -610,7 +635,9 @@ async def _handle_tool_call(
610635

611636
await self._model.send_event(
612637
RealtimeModelSendToolOutput(
613-
tool_call=event, output=str(result), start_response=True
638+
tool_call=event,
639+
output=_serialize_tool_output(result),
640+
start_response=True,
614641
)
615642
)
616643

tests/realtime/test_session.py

Lines changed: 81 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
import asyncio
2+
import dataclasses
3+
import json
4+
import threading
25
from typing import Any, cast
36
from unittest.mock import AsyncMock, Mock, PropertyMock, patch
47

58
import pytest
9+
from pydantic import BaseModel, ConfigDict
610

711
from agents.exceptions import UserError
812
from agents.guardrail import GuardrailFunctionOutput, OutputGuardrail
@@ -55,7 +59,7 @@
5559
RealtimeModelSendSessionUpdate,
5660
RealtimeModelSendUserInput,
5761
)
58-
from agents.realtime.session import REJECTION_MESSAGE, RealtimeSession
62+
from agents.realtime.session import REJECTION_MESSAGE, RealtimeSession, _serialize_tool_output
5963
from agents.tool import FunctionTool
6064
from agents.tool_context import ToolContext
6165

@@ -1364,7 +1368,7 @@ async def test_tool_call_with_custom_call_id(self, mock_model, mock_agent, mock_
13641368

13651369
@pytest.mark.asyncio
13661370
async def test_tool_result_conversion_to_string(self, mock_model, mock_agent):
1367-
"""Test that tool results are converted to strings for model output"""
1371+
"""Test that structured tool results are serialized to JSON for model output."""
13681372
# Create tool that returns non-string result
13691373
tool = _set_default_timeout_fields(Mock(spec=FunctionTool))
13701374
tool.name = "test_function"
@@ -1381,10 +1385,83 @@ async def test_tool_result_conversion_to_string(self, mock_model, mock_agent):
13811385

13821386
await session._handle_tool_call(tool_call_event)
13831387

1384-
# Verify result was converted to string
1388+
# Verify result was serialized to JSON
13851389
sent_call, sent_output, _ = mock_model.sent_tool_outputs[0]
13861390
assert isinstance(sent_output, str)
1387-
assert sent_output == "{'result': 'data', 'count': 42}"
1391+
assert sent_output == json.dumps({"result": "data", "count": 42})
1392+
1393+
@pytest.mark.asyncio
async def test_tool_result_conversion_serializes_pydantic_models(self, mock_model, mock_agent):
    """A pydantic model returned by a tool is sent to the model as a JSON object string."""

    class ToolResult(BaseModel):
        name: str
        score: int

    # Function tool stub whose invocation resolves to a pydantic instance.
    function_tool = _set_default_timeout_fields(Mock(spec=FunctionTool))
    function_tool.name = "test_function"
    function_tool.on_invoke_tool = AsyncMock(return_value=ToolResult(name="demo", score=7))
    function_tool.needs_approval = False
    mock_agent.get_all_tools.return_value = [function_tool]

    realtime_session = RealtimeSession(mock_model, mock_agent, None)
    call_event = RealtimeModelToolCallEvent(
        name="test_function", call_id="call_pydantic_conversion", arguments="{}"
    )
    await realtime_session._handle_tool_call(call_event)

    expected_json = json.dumps({"name": "demo", "score": 7})
    _recorded_call, serialized, _ = mock_model.sent_tool_outputs[0]
    assert serialized == expected_json
1418+
1419+
def test_serialize_tool_output_ignores_non_pydantic_model_dump_objects(self) -> None:
    """An object that merely exposes ``model_dump`` is stringified, never dumped."""

    class FakeModelDump:
        def __str__(self) -> str:
            return "fake-model-dump-object"

        def model_dump(self, *_args: Any, **_kwargs: Any) -> dict[str, Any]:
            # Reaching this method means the serializer duck-typed on ``model_dump``.
            raise AssertionError("non-pydantic objects should not use model_dump")

    serialized = _serialize_tool_output(FakeModelDump())
    assert serialized == "fake-model-dump-object"
1428+
1429+
def test_serialize_tool_output_falls_back_when_pydantic_json_dump_fails(self) -> None:
    """When the JSON-mode dump raises, the plain ``model_dump()`` result is encoded."""

    class FallbackModel(BaseModel):
        model_config = ConfigDict(arbitrary_types_allowed=True)

        payload: object

        def model_dump(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
            # Only the JSON-mode call fails; every other call returns a fixed dict.
            json_mode_requested = kwargs.get("mode") == "json"
            if not json_mode_requested:
                return {"payload": "ok"}
            raise ValueError("json mode failed")

    instance = FallbackModel(payload=object())
    expected_json = json.dumps({"payload": "ok"})
    assert _serialize_tool_output(instance) == expected_json
1443+
1444+
def test_serialize_tool_output_returns_string_when_pydantic_dump_fails(self) -> None:
    """A model whose every ``model_dump`` call raises is reported through ``str()``."""

    class BrokenModel(BaseModel):
        value: int

        def __str__(self) -> str:
            return "broken-model"

        def model_dump(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
            # Fails regardless of the requested mode.
            raise ValueError("dump failed")

    broken_instance = BrokenModel(value=1)
    assert _serialize_tool_output(broken_instance) == "broken-model"
1455+
1456+
def test_serialize_tool_output_returns_string_when_dataclass_asdict_fails(self) -> None:
    """A dataclass instance that cannot be converted is reported through ``str()``."""

    @dataclasses.dataclass
    class BrokenDataclass:
        lock: Any

        def __str__(self) -> str:
            return "broken-dataclass"

    # NOTE(review): the lock field is presumably what makes dataclasses.asdict()
    # raise for this instance; confirm.
    instance = BrokenDataclass(lock=threading.Lock())
    serialized = _serialize_tool_output(instance)
    assert serialized == "broken-dataclass"
13881465

13891466
@pytest.mark.asyncio
13901467
async def test_mixed_tool_types_filtering(self, mock_model, mock_agent):

0 commit comments

Comments
 (0)