Skip to content

Commit e8c749e

Browse files
authored
fix: #2487 persist nested agent-tool HITL state across RunState JSON round-trips (#2500)
1 parent a226a0d commit e8c749e

File tree

2 files changed

+316
-5
lines changed

2 files changed

+316
-5
lines changed

src/agents/run_state.py

Lines changed: 164 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,12 @@ def to_json(
544544
result["current_step"] = self._serialize_current_step()
545545
result["last_model_response"] = _serialize_last_model_response(model_responses)
546546
result["last_processed_response"] = (
547-
self._serialize_processed_response(self._last_processed_response)
547+
self._serialize_processed_response(
548+
self._last_processed_response,
549+
context_serializer=context_serializer,
550+
strict_context=strict_context,
551+
include_tracing_api_key=include_tracing_api_key,
552+
)
548553
if self._last_processed_response
549554
else None
550555
)
@@ -556,7 +561,12 @@ def to_json(
556561
return result
557562

558563
def _serialize_processed_response(
559-
self, processed_response: ProcessedResponse
564+
self,
565+
processed_response: ProcessedResponse,
566+
*,
567+
context_serializer: ContextSerializer | None = None,
568+
strict_context: bool = False,
569+
include_tracing_api_key: bool = False,
560570
) -> dict[str, Any]:
561571
"""Serialize a ProcessedResponse to JSON format.
562572
@@ -568,6 +578,14 @@ def _serialize_processed_response(
568578
"""
569579

570580
action_groups = _serialize_tool_action_groups(processed_response)
581+
_serialize_pending_nested_agent_tool_runs(
582+
parent_state=self,
583+
function_entries=action_groups.get("functions", []),
584+
function_runs=processed_response.functions,
585+
context_serializer=context_serializer,
586+
strict_context=strict_context,
587+
include_tracing_api_key=include_tracing_api_key,
588+
)
571589

572590
interruptions_data = [
573591
_serialize_tool_approval_interruption(interruption, include_tool_name=True)
@@ -1138,6 +1156,82 @@ def _serialize_tool_action_groups(
11381156
return serialized
11391157

11401158

1159+
def _serialize_pending_nested_agent_tool_runs(
1160+
*,
1161+
parent_state: RunState[Any, Any],
1162+
function_entries: Sequence[dict[str, Any]],
1163+
function_runs: Sequence[Any],
1164+
context_serializer: ContextSerializer | None = None,
1165+
strict_context: bool = False,
1166+
include_tracing_api_key: bool = False,
1167+
) -> None:
1168+
"""Attach serialized nested run state for pending agent-as-tool interruptions."""
1169+
if not function_entries or not function_runs:
1170+
return
1171+
1172+
from .agent_tool_state import peek_agent_tool_run_result
1173+
1174+
for entry, function_run in zip(function_entries, function_runs):
1175+
tool_call = getattr(function_run, "tool_call", None)
1176+
if not isinstance(tool_call, ResponseFunctionToolCall):
1177+
continue
1178+
1179+
pending_run_result = peek_agent_tool_run_result(tool_call)
1180+
if pending_run_result is None:
1181+
continue
1182+
1183+
interruptions = getattr(pending_run_result, "interruptions", None)
1184+
if not isinstance(interruptions, list) or not interruptions:
1185+
continue
1186+
1187+
to_state = getattr(pending_run_result, "to_state", None)
1188+
if not callable(to_state):
1189+
continue
1190+
1191+
try:
1192+
nested_state = to_state()
1193+
except Exception:
1194+
if strict_context:
1195+
raise
1196+
logger.warning(
1197+
"Failed to capture nested agent run state for tool call %s.",
1198+
tool_call.call_id,
1199+
)
1200+
continue
1201+
1202+
if not isinstance(nested_state, RunState):
1203+
continue
1204+
if nested_state is parent_state:
1205+
# Defensive guard against accidental self-referential serialization loops.
1206+
continue
1207+
1208+
try:
1209+
entry["agent_run_state"] = nested_state.to_json(
1210+
context_serializer=context_serializer,
1211+
strict_context=strict_context,
1212+
include_tracing_api_key=include_tracing_api_key,
1213+
)
1214+
except Exception:
1215+
if strict_context:
1216+
raise
1217+
logger.warning(
1218+
"Failed to serialize nested agent run state for tool call %s.",
1219+
tool_call.call_id,
1220+
)
1221+
1222+
1223+
class _SerializedAgentToolRunResult:
1224+
"""Minimal run-result wrapper used to restore nested agent-as-tool resumptions."""
1225+
1226+
def __init__(self, state: RunState[Any, Agent[Any]]) -> None:
1227+
self._state = state
1228+
self.interruptions = list(state.get_interruptions())
1229+
self.final_output = None
1230+
1231+
def to_state(self) -> RunState[Any, Agent[Any]]:
1232+
return self._state
1233+
1234+
11411235
def _serialize_guardrail_results(
11421236
results: Sequence[InputGuardrailResult | OutputGuardrailResult],
11431237
) -> list[dict[str, Any]]:
@@ -1215,11 +1309,65 @@ def _build_handoffs_map(current_agent: Agent[Any]) -> dict[str, Handoff[Any, Age
12151309
return handoffs_map
12161310

12171311

1312+
async def _restore_pending_nested_agent_tool_runs(
1313+
*,
1314+
current_agent: Agent[Any],
1315+
function_entries: Sequence[Any],
1316+
function_runs: Sequence[Any],
1317+
context_deserializer: ContextDeserializer | None = None,
1318+
strict_context: bool = False,
1319+
) -> None:
1320+
"""Rehydrate nested agent-as-tool run state into the ephemeral tool-call cache."""
1321+
if not function_entries or not function_runs:
1322+
return
1323+
1324+
from .agent_tool_state import drop_agent_tool_run_result, record_agent_tool_run_result
1325+
1326+
for entry, function_run in zip(function_entries, function_runs):
1327+
if not isinstance(entry, Mapping):
1328+
continue
1329+
nested_state_data = entry.get("agent_run_state")
1330+
if not isinstance(nested_state_data, Mapping):
1331+
continue
1332+
1333+
tool_call = getattr(function_run, "tool_call", None)
1334+
if not isinstance(tool_call, ResponseFunctionToolCall):
1335+
continue
1336+
1337+
try:
1338+
nested_state = await _build_run_state_from_json(
1339+
initial_agent=current_agent,
1340+
state_json=dict(nested_state_data),
1341+
context_deserializer=context_deserializer,
1342+
strict_context=strict_context,
1343+
)
1344+
except Exception:
1345+
if strict_context:
1346+
raise
1347+
logger.warning(
1348+
"Failed to deserialize nested agent run state for tool call %s.",
1349+
tool_call.call_id,
1350+
)
1351+
continue
1352+
1353+
pending_result = _SerializedAgentToolRunResult(nested_state)
1354+
if not pending_result.interruptions:
1355+
continue
1356+
1357+
# Replace any stale cache entry with the same signature so resumed runs do not read
1358+
# older pending interruptions after consuming this restored entry.
1359+
drop_agent_tool_run_result(tool_call)
1360+
record_agent_tool_run_result(tool_call, cast(Any, pending_result))
1361+
1362+
12181363
async def _deserialize_processed_response(
12191364
processed_response_data: dict[str, Any],
12201365
current_agent: Agent[Any],
12211366
context: RunContextWrapper[Any],
12221367
agent_map: dict[str, Agent[Any]],
1368+
*,
1369+
context_deserializer: ContextDeserializer | None = None,
1370+
strict_context: bool = False,
12231371
) -> ProcessedResponse:
12241372
"""Deserialize a ProcessedResponse from JSON data.
12251373
@@ -1403,6 +1551,14 @@ def _deserialize_action_groups() -> dict[str, list[Any]]:
14031551
shell_actions = action_groups["shell_actions"]
14041552
apply_patch_actions = action_groups["apply_patch_actions"]
14051553

1554+
await _restore_pending_nested_agent_tool_runs(
1555+
current_agent=current_agent,
1556+
function_entries=processed_response_data.get("functions", []),
1557+
function_runs=functions,
1558+
context_deserializer=context_deserializer,
1559+
strict_context=strict_context,
1560+
)
1561+
14061562
mcp_approval_requests: list[ToolRunMCPApprovalRequest] = []
14071563
for request_data in processed_response_data.get("mcp_approval_requests", []):
14081564
request_item_data = request_data.get("request_item", {})
@@ -1824,7 +1980,12 @@ async def _build_run_state_from_json(
18241980
last_processed_response_data = state_json.get("last_processed_response")
18251981
if last_processed_response_data and state._context is not None:
18261982
state._last_processed_response = await _deserialize_processed_response(
1827-
last_processed_response_data, current_agent, state._context, agent_map
1983+
last_processed_response_data,
1984+
current_agent,
1985+
state._context,
1986+
agent_map,
1987+
context_deserializer=context_deserializer,
1988+
strict_context=strict_context,
18281989
)
18291990
else:
18301991
state._last_processed_response = None

0 commit comments

Comments
 (0)