Skip to content

Commit da1ad76

Browse files
authored
fix: preserve plain run context for run-context tool wrappers (#2634)
1 parent 5fd8720 commit da1ad76

File tree

2 files changed

+275
-2
lines changed

2 files changed

+275
-2
lines changed

src/agents/tool.py

Lines changed: 120 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import ast
34
import asyncio
45
import copy
56
import dataclasses
@@ -9,8 +10,10 @@
910
import weakref
1011
from collections.abc import Awaitable, Mapping
1112
from dataclasses import dataclass, field
13+
from types import UnionType
1214
from typing import (
1315
TYPE_CHECKING,
16+
Annotated,
1417
Any,
1518
Callable,
1619
Generic,
@@ -19,6 +22,9 @@
1922
TypeVar,
2023
Union,
2124
cast,
25+
get_args,
26+
get_origin,
27+
get_type_hints,
2228
overload,
2329
)
2430

@@ -1373,19 +1379,131 @@ async def maybe_invoke_function_tool_failure_error_function(
13731379
return result
13741380

13751381

1382+
def _annotation_expr_name(expr: ast.expr) -> str | None:
1383+
"""Return the unqualified type name for a string annotation expression node."""
1384+
if isinstance(expr, ast.Name):
1385+
return expr.id
1386+
if isinstance(expr, ast.Attribute):
1387+
return expr.attr
1388+
return None
1389+
1390+
1391+
def _string_annotation_mentions_context_type(annotation: str, *, type_name: str) -> bool:
1392+
"""Return True when a string annotation structurally references the given context type."""
1393+
try:
1394+
expression = ast.parse(annotation, mode="eval").body
1395+
except SyntaxError:
1396+
return False
1397+
1398+
return _annotation_expr_mentions_context_type(expression, type_name=type_name)
1399+
1400+
1401+
def _annotation_expr_mentions_context_type(expr: ast.expr, *, type_name: str) -> bool:
1402+
"""Return True when an annotation expression structurally references the given context type."""
1403+
if isinstance(expr, ast.Constant) and isinstance(expr.value, str):
1404+
return _string_annotation_mentions_context_type(expr.value, type_name=type_name)
1405+
1406+
if _annotation_expr_name(expr) == type_name:
1407+
return True
1408+
1409+
if isinstance(expr, ast.BinOp) and isinstance(expr.op, ast.BitOr):
1410+
return _annotation_expr_mentions_context_type(
1411+
expr.left, type_name=type_name
1412+
) or _annotation_expr_mentions_context_type(expr.right, type_name=type_name)
1413+
1414+
if isinstance(expr, ast.Subscript):
1415+
wrapper_name = _annotation_expr_name(expr.value)
1416+
args = expr.slice.elts if isinstance(expr.slice, ast.Tuple) else (expr.slice,)
1417+
1418+
if wrapper_name == "Annotated":
1419+
return bool(args) and _annotation_expr_mentions_context_type(
1420+
args[0], type_name=type_name
1421+
)
1422+
1423+
if wrapper_name in {"Optional", "Union"}:
1424+
return any(
1425+
_annotation_expr_mentions_context_type(arg, type_name=type_name) for arg in args
1426+
)
1427+
1428+
return _annotation_expr_mentions_context_type(expr.value, type_name=type_name)
1429+
1430+
return False
1431+
1432+
1433+
def _annotation_mentions_context_type(annotation: Any, *, context_type: type[Any]) -> bool:
1434+
"""Return True when an annotation structurally references the given context type."""
1435+
if annotation is inspect.Signature.empty:
1436+
return False
1437+
1438+
if isinstance(annotation, str):
1439+
return _string_annotation_mentions_context_type(annotation, type_name=context_type.__name__)
1440+
1441+
origin = get_origin(annotation)
1442+
1443+
if annotation is context_type or origin is context_type:
1444+
return True
1445+
1446+
if origin is Annotated:
1447+
args = get_args(annotation)
1448+
return bool(args) and _annotation_mentions_context_type(args[0], context_type=context_type)
1449+
1450+
if origin in (Union, UnionType):
1451+
return any(
1452+
_annotation_mentions_context_type(arg, context_type=context_type)
1453+
for arg in get_args(annotation)
1454+
)
1455+
1456+
return False
1457+
1458+
1459+
def _get_function_tool_invoke_context(
1460+
function_tool: FunctionTool,
1461+
context: ToolContext[Any],
1462+
) -> ToolContext[Any] | RunContextWrapper[Any]:
1463+
"""Choose the runtime context object to pass into a function tool wrapper.
1464+
1465+
Third-party wrappers may declare a narrower `RunContextWrapper` contract and then serialize
1466+
that object downstream. In those cases, passing the richer `ToolContext` can leak runtime-only
1467+
metadata such as agents or run config into incompatible serializers. When the wrapper
1468+
explicitly declares `RunContextWrapper`, preserve only the base context state.
1469+
"""
1470+
try:
1471+
parameters = tuple(inspect.signature(function_tool.on_invoke_tool).parameters.values())
1472+
except (TypeError, ValueError):
1473+
return context
1474+
1475+
if not parameters:
1476+
return context
1477+
1478+
context_annotation = parameters[0].annotation
1479+
try:
1480+
resolved_annotations = get_type_hints(function_tool.on_invoke_tool, include_extras=True)
1481+
except Exception:
1482+
pass
1483+
else:
1484+
context_annotation = resolved_annotations.get(parameters[0].name, context_annotation)
1485+
1486+
if _annotation_mentions_context_type(context_annotation, context_type=ToolContext):
1487+
return context
1488+
if _annotation_mentions_context_type(context_annotation, context_type=RunContextWrapper):
1489+
return context._fork_with_tool_input(context.tool_input)
1490+
return context
1491+
1492+
13761493
async def invoke_function_tool(
13771494
*,
13781495
function_tool: FunctionTool,
13791496
context: ToolContext[Any],
13801497
arguments: str,
13811498
) -> Any:
13821499
"""Invoke a function tool, enforcing timeout configuration when provided."""
1500+
invoke_context = _get_function_tool_invoke_context(function_tool, context)
13831501
timeout_seconds = function_tool.timeout_seconds
13841502
if timeout_seconds is None:
1385-
return await function_tool.on_invoke_tool(context, arguments)
1503+
return await function_tool.on_invoke_tool(cast(Any, invoke_context), arguments)
13861504

13871505
tool_task: asyncio.Future[Any] = asyncio.ensure_future(
1388-
function_tool.on_invoke_tool(context, arguments)
1506+
function_tool.on_invoke_tool(cast(Any, invoke_context), arguments)
13891507
)
13901508
try:
13911509
return await asyncio.wait_for(tool_task, timeout=timeout_seconds)

tests/test_tool_context.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
1+
from typing import Annotated, Any, cast
2+
13
import pytest
24
from openai.types.responses import ResponseFunctionToolCall
35

46
from agents import Agent
57
from agents.run_config import RunConfig
68
from agents.run_context import RunContextWrapper
9+
from agents.tool import FunctionTool, invoke_function_tool
710
from agents.tool_context import ToolContext
11+
from agents.usage import Usage
812
from tests.utils.hitl import make_context_wrapper
913

1014

@@ -198,3 +202,154 @@ def test_tool_context_from_agent_context_prefers_explicit_run_config() -> None:
198202
)
199203

200204
assert tool_ctx.run_config is explicit_run_config
205+
206+
207+
@pytest.mark.asyncio
208+
async def test_invoke_function_tool_passes_plain_run_context_when_requested() -> None:
209+
captured_context: RunContextWrapper[str] | None = None
210+
211+
async def on_invoke_tool(ctx: RunContextWrapper[str], _input: str) -> str:
212+
nonlocal captured_context
213+
captured_context = ctx
214+
return ctx.context
215+
216+
function_tool = FunctionTool(
217+
name="plain_context_tool",
218+
description="test",
219+
params_json_schema={"type": "object", "properties": {}},
220+
on_invoke_tool=on_invoke_tool,
221+
)
222+
tool_context = ToolContext(
223+
context="Stormy",
224+
usage=Usage(),
225+
tool_name="plain_context_tool",
226+
tool_call_id="call-1",
227+
tool_arguments="{}",
228+
agent=Agent(name="agent"),
229+
run_config=RunConfig(model="gpt-4.1-mini"),
230+
tool_input={"city": "Tokyo"},
231+
)
232+
233+
result = await invoke_function_tool(
234+
function_tool=function_tool,
235+
context=tool_context,
236+
arguments="{}",
237+
)
238+
239+
assert result == "Stormy"
240+
assert captured_context is not None
241+
assert not isinstance(captured_context, ToolContext)
242+
assert captured_context.context == "Stormy"
243+
assert captured_context.usage is tool_context.usage
244+
assert captured_context.tool_input == {"city": "Tokyo"}
245+
246+
247+
@pytest.mark.asyncio
248+
async def test_invoke_function_tool_preserves_tool_context_when_requested() -> None:
249+
captured_context: ToolContext[str] | None = None
250+
251+
async def on_invoke_tool(ctx: ToolContext[str], _input: str) -> str:
252+
nonlocal captured_context
253+
captured_context = ctx
254+
return ctx.tool_name
255+
256+
function_tool = FunctionTool(
257+
name="tool_context_tool",
258+
description="test",
259+
params_json_schema={"type": "object", "properties": {}},
260+
on_invoke_tool=on_invoke_tool,
261+
)
262+
tool_context = ToolContext(
263+
context="Stormy",
264+
usage=Usage(),
265+
tool_name="tool_context_tool",
266+
tool_call_id="call-2",
267+
tool_arguments="{}",
268+
agent=Agent(name="agent"),
269+
run_config=RunConfig(model="gpt-4.1-mini"),
270+
)
271+
272+
result = await invoke_function_tool(
273+
function_tool=function_tool,
274+
context=tool_context,
275+
arguments="{}",
276+
)
277+
278+
assert result == "tool_context_tool"
279+
assert captured_context is tool_context
280+
281+
282+
@pytest.mark.asyncio
283+
async def test_invoke_function_tool_ignores_context_name_substrings_in_string_annotations() -> None:
284+
captured_context: object | None = None
285+
286+
class MyRunContextWrapper:
287+
pass
288+
289+
async def on_invoke_tool(ctx: "MyRunContextWrapper", _input: str) -> str:
290+
nonlocal captured_context
291+
captured_context = ctx
292+
return "ok"
293+
294+
function_tool = FunctionTool(
295+
name="substring_context_tool",
296+
description="test",
297+
params_json_schema={"type": "object", "properties": {}},
298+
on_invoke_tool=cast(Any, on_invoke_tool),
299+
)
300+
tool_context = ToolContext(
301+
context="Stormy",
302+
usage=Usage(),
303+
tool_name="substring_context_tool",
304+
tool_call_id="call-3",
305+
tool_arguments="{}",
306+
)
307+
308+
result = await invoke_function_tool(
309+
function_tool=function_tool,
310+
context=tool_context,
311+
arguments="{}",
312+
)
313+
314+
assert result == "ok"
315+
assert captured_context is tool_context
316+
317+
318+
@pytest.mark.asyncio
319+
async def test_invoke_function_tool_ignores_annotated_string_metadata_when_matching_context() -> (
320+
None
321+
):
322+
captured_context: ToolContext[str] | RunContextWrapper[str] | None = None
323+
324+
async def on_invoke_tool(
325+
ctx: Annotated[RunContextWrapper[str], "ToolContext note"], _input: str
326+
) -> str:
327+
nonlocal captured_context
328+
captured_context = ctx
329+
return ctx.context
330+
331+
function_tool = FunctionTool(
332+
name="annotated_string_context_tool",
333+
description="test",
334+
params_json_schema={"type": "object", "properties": {}},
335+
on_invoke_tool=on_invoke_tool,
336+
)
337+
tool_context = ToolContext(
338+
context="Stormy",
339+
usage=Usage(),
340+
tool_name="annotated_string_context_tool",
341+
tool_call_id="call-4",
342+
tool_arguments="{}",
343+
tool_input={"city": "Tokyo"},
344+
)
345+
346+
result = await invoke_function_tool(
347+
function_tool=function_tool,
348+
context=tool_context,
349+
arguments="{}",
350+
)
351+
352+
assert result == "Stormy"
353+
assert captured_context is not None
354+
assert not isinstance(captured_context, ToolContext)
355+
assert captured_context.tool_input == {"city": "Tokyo"}

0 commit comments

Comments
 (0)