Skip to content

Commit 335b7e6

Browse files
authored
fix: isolate nested agent-tool resume cache by RunState scope (#2501)
1 parent e8c749e commit 335b7e6

File tree

11 files changed

+263
-49
lines changed

11 files changed

+263
-49
lines changed

src/agents/agent.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@
2222
)
2323
from .agent_tool_state import (
2424
consume_agent_tool_run_result,
25+
get_agent_tool_state_scope,
2526
peek_agent_tool_run_result,
2627
record_agent_tool_run_result,
28+
set_agent_tool_state_scope,
2729
)
2830
from .exceptions import ModelBehaviorError, UserError
2931
from .guardrail import InputGuardrail, OutputGuardrail
@@ -593,6 +595,7 @@ async def _run_agent_impl(context: ToolContext, input_json: str) -> Any:
593595
resolved_run_config = run_config
594596
if resolved_run_config is None and isinstance(context, ToolContext):
595597
resolved_run_config = context.run_config
598+
tool_state_scope_id = get_agent_tool_state_scope(context)
596599
if isinstance(context, ToolContext):
597600
# Use a fresh ToolContext to avoid sharing approval state with parent runs.
598601
nested_context = ToolContext(
@@ -605,17 +608,20 @@ async def _run_agent_impl(context: ToolContext, input_json: str) -> Any:
605608
agent=context.agent,
606609
run_config=resolved_run_config,
607610
)
611+
set_agent_tool_state_scope(nested_context, tool_state_scope_id)
608612
if should_capture_tool_input:
609613
nested_context.tool_input = params_data
610614
elif isinstance(context, RunContextWrapper):
611615
if should_capture_tool_input:
612616
nested_context = RunContextWrapper(context=context.context)
617+
set_agent_tool_state_scope(nested_context, tool_state_scope_id)
613618
nested_context.tool_input = params_data
614619
else:
615620
nested_context = context.context
616621
else:
617622
if should_capture_tool_input:
618623
nested_context = RunContextWrapper(context=context)
624+
set_agent_tool_state_scope(nested_context, tool_state_scope_id)
619625
nested_context.tool_input = params_data
620626
else:
621627
nested_context = context
@@ -678,7 +684,10 @@ def _apply_nested_approvals(
678684
)
679685

680686
if isinstance(context, ToolContext) and context.tool_call is not None:
681-
pending_run_result = peek_agent_tool_run_result(context.tool_call)
687+
pending_run_result = peek_agent_tool_run_result(
688+
context.tool_call,
689+
scope_id=tool_state_scope_id,
690+
)
682691
if pending_run_result and getattr(pending_run_result, "interruptions", None):
683692
status = _nested_approvals_status(pending_run_result.interruptions)
684693
if status == "pending":
@@ -693,7 +702,10 @@ def _apply_nested_approvals(
693702
context,
694703
pending_run_result.interruptions,
695704
)
696-
consume_agent_tool_run_result(context.tool_call)
705+
consume_agent_tool_run_result(
706+
context.tool_call,
707+
scope_id=tool_state_scope_id,
708+
)
697709

698710
if run_result is None:
699711
if on_stream is not None:
@@ -780,7 +792,11 @@ async def dispatch_stream_events() -> None:
780792
interruptions = getattr(run_result, "interruptions", None)
781793
if isinstance(context, ToolContext) and context.tool_call is not None and interruptions:
782794
if should_record_run_result:
783-
record_agent_tool_run_result(context.tool_call, run_result)
795+
record_agent_tool_run_result(
796+
context.tool_call,
797+
run_result,
798+
scope_id=tool_state_scope_id,
799+
)
784800

785801
if custom_output_extractor:
786802
return await custom_output_extractor(run_result)

src/agents/agent_tool_state.py

Lines changed: 82 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,57 @@
11
from __future__ import annotations
22

33
import weakref
4-
from typing import TYPE_CHECKING
4+
from typing import TYPE_CHECKING, Any
55

66
if TYPE_CHECKING:
77
from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall
88

99
from .result import RunResult, RunResultStreaming
1010

11+
ToolCallSignature = tuple[str, str, str, str, str | None, str | None]
12+
ScopedToolCallSignature = tuple[str | None, ToolCallSignature]
13+
14+
_AGENT_TOOL_STATE_SCOPE_ATTR = "_agent_tool_state_scope_id"
15+
1116
# Ephemeral maps linking tool call objects to nested agent results within the same run.
1217
# Store by object identity, and index by a stable signature to avoid call ID collisions.
1318
_agent_tool_run_results_by_obj: dict[int, RunResult | RunResultStreaming] = {}
1419
_agent_tool_run_results_by_signature: dict[
15-
tuple[str, str, str, str, str | None, str | None],
20+
ScopedToolCallSignature,
1621
set[int],
1722
] = {}
1823
_agent_tool_run_result_signature_by_obj: dict[
1924
int,
20-
tuple[str, str, str, str, str | None, str | None],
25+
ScopedToolCallSignature,
2126
] = {}
2227
_agent_tool_call_refs_by_obj: dict[int, weakref.ReferenceType[ResponseFunctionToolCall]] = {}
2328

2429

30+
def get_agent_tool_state_scope(context: Any) -> str | None:
31+
"""Read the private agent-tool cache scope id from a context wrapper."""
32+
scope_id = getattr(context, _AGENT_TOOL_STATE_SCOPE_ATTR, None)
33+
return scope_id if isinstance(scope_id, str) else None
34+
35+
36+
def set_agent_tool_state_scope(context: Any, scope_id: str | None) -> None:
37+
"""Attach or clear the private agent-tool cache scope id on a context wrapper."""
38+
if context is None:
39+
return
40+
if scope_id is None:
41+
try:
42+
delattr(context, _AGENT_TOOL_STATE_SCOPE_ATTR)
43+
except Exception:
44+
return
45+
return
46+
try:
47+
setattr(context, _AGENT_TOOL_STATE_SCOPE_ATTR, scope_id)
48+
except Exception:
49+
return
50+
51+
2552
def _tool_call_signature(
2653
tool_call: ResponseFunctionToolCall,
27-
) -> tuple[str, str, str, str, str | None, str | None]:
54+
) -> ToolCallSignature:
2855
"""Build a stable signature for fallback lookup across tool call instances."""
2956
return (
3057
tool_call.call_id,
@@ -36,11 +63,21 @@ def _tool_call_signature(
3663
)
3764

3865

66+
def _scoped_tool_call_signature(
67+
tool_call: ResponseFunctionToolCall, *, scope_id: str | None
68+
) -> ScopedToolCallSignature:
69+
"""Build a scope-qualified signature so independently restored states do not collide."""
70+
return (scope_id, _tool_call_signature(tool_call))
71+
72+
3973
def _index_agent_tool_run_result(
40-
tool_call: ResponseFunctionToolCall, tool_call_obj_id: int
74+
tool_call: ResponseFunctionToolCall,
75+
tool_call_obj_id: int,
76+
*,
77+
scope_id: str | None,
4178
) -> None:
4279
"""Track tool call objects by signature for fallback lookup."""
43-
signature = _tool_call_signature(tool_call)
80+
signature = _scoped_tool_call_signature(tool_call, scope_id=scope_id)
4481
_agent_tool_run_result_signature_by_obj[tool_call_obj_id] = signature
4582
_agent_tool_run_results_by_signature.setdefault(signature, set()).add(tool_call_obj_id)
4683

@@ -80,26 +117,40 @@ def _on_tool_call_gc(_ref: weakref.ReferenceType[ResponseFunctionToolCall]) -> N
80117

81118

82119
def record_agent_tool_run_result(
83-
tool_call: ResponseFunctionToolCall, run_result: RunResult | RunResultStreaming
120+
tool_call: ResponseFunctionToolCall,
121+
run_result: RunResult | RunResultStreaming,
122+
*,
123+
scope_id: str | None = None,
84124
) -> None:
85125
"""Store the nested agent run result by tool call identity."""
86126
tool_call_obj_id = id(tool_call)
87127
_agent_tool_run_results_by_obj[tool_call_obj_id] = run_result
88-
_index_agent_tool_run_result(tool_call, tool_call_obj_id)
128+
_index_agent_tool_run_result(tool_call, tool_call_obj_id, scope_id=scope_id)
89129
_register_tool_call_ref(tool_call, tool_call_obj_id)
90130

91131

132+
def _tool_call_obj_matches_scope(tool_call_obj_id: int, *, scope_id: str | None) -> bool:
133+
scoped_signature = _agent_tool_run_result_signature_by_obj.get(tool_call_obj_id)
134+
if scoped_signature is None:
135+
# Fallback for unindexed entries.
136+
return scope_id is None
137+
return scoped_signature[0] == scope_id
138+
139+
92140
def consume_agent_tool_run_result(
93141
tool_call: ResponseFunctionToolCall,
142+
*,
143+
scope_id: str | None = None,
94144
) -> RunResult | RunResultStreaming | None:
95145
"""Return and drop the stored nested agent run result for the given tool call."""
96146
obj_id = id(tool_call)
97-
run_result = _agent_tool_run_results_by_obj.pop(obj_id, None)
98-
if run_result is not None:
99-
_drop_agent_tool_run_result(obj_id)
100-
return run_result
147+
if _tool_call_obj_matches_scope(obj_id, scope_id=scope_id):
148+
run_result = _agent_tool_run_results_by_obj.pop(obj_id, None)
149+
if run_result is not None:
150+
_drop_agent_tool_run_result(obj_id)
151+
return run_result
101152

102-
signature = _tool_call_signature(tool_call)
153+
signature = _scoped_tool_call_signature(tool_call, scope_id=scope_id)
103154
candidate_ids = _agent_tool_run_results_by_signature.get(signature)
104155
if not candidate_ids:
105156
return None
@@ -115,14 +166,17 @@ def consume_agent_tool_run_result(
115166

116167
def peek_agent_tool_run_result(
117168
tool_call: ResponseFunctionToolCall,
169+
*,
170+
scope_id: str | None = None,
118171
) -> RunResult | RunResultStreaming | None:
119172
"""Return the stored nested agent run result without removing it."""
120173
obj_id = id(tool_call)
121-
run_result = _agent_tool_run_results_by_obj.get(obj_id)
122-
if run_result is not None:
123-
return run_result
174+
if _tool_call_obj_matches_scope(obj_id, scope_id=scope_id):
175+
run_result = _agent_tool_run_results_by_obj.get(obj_id)
176+
if run_result is not None:
177+
return run_result
124178

125-
signature = _tool_call_signature(tool_call)
179+
signature = _scoped_tool_call_signature(tool_call, scope_id=scope_id)
126180
candidate_ids = _agent_tool_run_results_by_signature.get(signature)
127181
if not candidate_ids:
128182
return None
@@ -133,15 +187,20 @@ def peek_agent_tool_run_result(
133187
return _agent_tool_run_results_by_obj.get(candidate_id)
134188

135189

136-
def drop_agent_tool_run_result(tool_call: ResponseFunctionToolCall) -> None:
190+
def drop_agent_tool_run_result(
191+
tool_call: ResponseFunctionToolCall,
192+
*,
193+
scope_id: str | None = None,
194+
) -> None:
137195
"""Drop the stored nested agent run result, if present."""
138196
obj_id = id(tool_call)
139-
run_result = _agent_tool_run_results_by_obj.pop(obj_id, None)
140-
if run_result is not None:
141-
_drop_agent_tool_run_result(obj_id)
142-
return
197+
if _tool_call_obj_matches_scope(obj_id, scope_id=scope_id):
198+
run_result = _agent_tool_run_results_by_obj.pop(obj_id, None)
199+
if run_result is not None:
200+
_drop_agent_tool_run_result(obj_id)
201+
return
143202

144-
signature = _tool_call_signature(tool_call)
203+
signature = _scoped_tool_call_signature(tool_call, scope_id=scope_id)
145204
candidate_ids = _agent_tool_run_results_by_signature.get(signature)
146205
if not candidate_ids:
147206
return

src/agents/run.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from . import _debug
1111
from .agent import Agent
12+
from .agent_tool_state import set_agent_tool_state_scope
1213
from .exceptions import (
1314
AgentsException,
1415
InputGuardrailTripwireTriggered,
@@ -555,6 +556,7 @@ async def run(
555556
session_items = []
556557
model_responses = []
557558
context_wrapper = ensure_context_wrapper(context)
559+
set_agent_tool_state_scope(context_wrapper, None)
558560
run_state = RunState(
559561
context=context_wrapper,
560562
original_input=original_input,
@@ -1458,6 +1460,7 @@ def run_streamed(
14581460
auto_previous_response_id=auto_previous_response_id,
14591461
)
14601462
context_wrapper = ensure_context_wrapper(context)
1463+
set_agent_tool_state_scope(context_wrapper, None)
14611464
# input_for_state is the same as input_for_result here
14621465
input_for_state = input_for_result
14631466
run_state = RunState(

src/agents/run_internal/agent_runner_helpers.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Any, cast
66

77
from ..agent import Agent
8+
from ..agent_tool_state import set_agent_tool_state_scope
89
from ..exceptions import UserError
910
from ..guardrail import InputGuardrailResult
1011
from ..items import ModelResponse, RunItem, ToolApprovalItem, TResponseInputItem
@@ -141,10 +142,12 @@ def resolve_resumed_context(
141142
"""Return the context wrapper for a resumed run, overriding when provided."""
142143
if context is not None:
143144
context_wrapper = ensure_context_wrapper(context)
145+
set_agent_tool_state_scope(context_wrapper, run_state._agent_tool_state_scope_id)
144146
run_state._context = context_wrapper
145147
return context_wrapper
146148
if run_state._context is None:
147149
run_state._context = ensure_context_wrapper(context)
150+
set_agent_tool_state_scope(run_state._context, run_state._agent_tool_state_scope_id)
148151
return run_state._context
149152

150153

src/agents/run_internal/items.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,10 +194,11 @@ def function_rejection_item(
194194
tool_call: Any,
195195
*,
196196
rejection_message: str = REJECTION_MESSAGE,
197+
scope_id: str | None = None,
197198
) -> ToolCallOutputItem:
198199
"""Build a ToolCallOutputItem representing a rejected function tool call."""
199200
if isinstance(tool_call, ResponseFunctionToolCall):
200-
drop_agent_tool_run_result(tool_call)
201+
drop_agent_tool_run_result(tool_call, scope_id=scope_id)
201202
return ToolCallOutputItem(
202203
output=rejection_message,
203204
raw_item=ItemHelpers.tool_call_output_item(tool_call, rejection_message),

src/agents/run_internal/tool_execution.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,11 @@
2020
from openai.types.responses.response_output_item import McpApprovalRequest
2121

2222
from ..agent import Agent
23-
from ..agent_tool_state import consume_agent_tool_run_result, peek_agent_tool_run_result
23+
from ..agent_tool_state import (
24+
consume_agent_tool_run_result,
25+
get_agent_tool_state_scope,
26+
peek_agent_tool_run_result,
27+
)
2428
from ..editor import ApplyPatchOperation, ApplyPatchResult
2529
from ..exceptions import (
2630
AgentsException,
@@ -840,6 +844,7 @@ async def execute_function_tool_calls(
840844
"""Execute function tool calls with approvals, guardrails, and hooks."""
841845
tool_input_guardrail_results: list[ToolInputGuardrailResult] = []
842846
tool_output_guardrail_results: list[ToolOutputGuardrailResult] = []
847+
tool_state_scope_id = get_agent_tool_state_scope(context_wrapper)
843848

844849
async def run_single_tool(func_tool: FunctionTool, tool_call: ResponseFunctionToolCall) -> Any:
845850
with function_span(func_tool.name) as span_fn:
@@ -903,6 +908,7 @@ async def run_single_tool(func_tool: FunctionTool, tool_call: ResponseFunctionTo
903908
agent,
904909
tool_call,
905910
rejection_message=rejection_message,
911+
scope_id=tool_state_scope_id,
906912
),
907913
)
908914

@@ -972,7 +978,10 @@ async def run_single_tool(func_tool: FunctionTool, tool_call: ResponseFunctionTo
972978
function_tool_results = []
973979
for tool_run, result in zip(tool_runs, results):
974980
if isinstance(result, FunctionToolResult):
975-
nested_run_result = consume_agent_tool_run_result(tool_run.tool_call)
981+
nested_run_result = consume_agent_tool_run_result(
982+
tool_run.tool_call,
983+
scope_id=tool_state_scope_id,
984+
)
976985
if nested_run_result:
977986
result.agent_run_result = nested_run_result
978987
nested_interruptions_from_result: list[ToolApprovalItem] = (
@@ -985,7 +994,10 @@ async def run_single_tool(func_tool: FunctionTool, tool_call: ResponseFunctionTo
985994

986995
function_tool_results.append(result)
987996
else:
988-
nested_run_result = peek_agent_tool_run_result(tool_run.tool_call)
997+
nested_run_result = peek_agent_tool_run_result(
998+
tool_run.tool_call,
999+
scope_id=tool_state_scope_id,
1000+
)
9891001
nested_interruptions: list[ToolApprovalItem] = []
9901002
if nested_run_result:
9911003
nested_interruptions = (
@@ -994,9 +1006,15 @@ async def run_single_tool(func_tool: FunctionTool, tool_call: ResponseFunctionTo
9941006
else []
9951007
)
9961008
if nested_run_result and not nested_interruptions:
997-
nested_run_result = consume_agent_tool_run_result(tool_run.tool_call)
1009+
nested_run_result = consume_agent_tool_run_result(
1010+
tool_run.tool_call,
1011+
scope_id=tool_state_scope_id,
1012+
)
9981013
elif nested_run_result is None:
999-
nested_run_result = consume_agent_tool_run_result(tool_run.tool_call)
1014+
nested_run_result = consume_agent_tool_run_result(
1015+
tool_run.tool_call,
1016+
scope_id=tool_state_scope_id,
1017+
)
10001018
if nested_run_result:
10011019
nested_interruptions = (
10021020
nested_run_result.interruptions

0 commit comments

Comments
 (0)