Skip to content

Commit bb14973

Browse files
authored
feat: #2247 add RunResult tool_context accessor for agent tools (#2575)
This pull request adds a typed `RunResult.tool_context` convenience accessor for nested agent-tool runs so `custom_output_extractor` implementations can inspect `ToolContext` metadata without a cast. The change keeps the existing `RunResult` and `RunResultStreaming` constructor shapes intact by adding the accessor on the shared `RunResultBase`, which keeps the release risk low while improving mypy ergonomics for agent-as-tool integrations. It also updates the agent-as-tool tests to cover both non-streaming and streaming extractor access to `tool_context`, and documents the new extractor pattern in `docs/tools.md`. Behavior is unchanged for normal runs: `tool_context` is `None` unless the run was started through `Agent.as_tool(...)`. Resolves #2247
1 parent 43b63dc commit bb14973

3 files changed

Lines changed: 244 additions & 13 deletions

File tree

src/agents/result.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import weakref
77
from collections.abc import AsyncIterator
88
from dataclasses import InitVar, dataclass, field
9-
from typing import Any, Literal, TypeVar, cast
9+
from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast
1010

1111
from pydantic import GetCoreSchemaHandler
1212
from pydantic_core import core_schema
@@ -45,6 +45,9 @@
4545
pretty_print_run_result_streaming,
4646
)
4747

48+
if TYPE_CHECKING:
49+
from .tool_context import ToolContext
50+
4851
T = TypeVar("T")
4952

5053

@@ -204,6 +207,15 @@ def to_input_list(self) -> list[TResponseInputItem]:
204207

205208
return original_items + new_items
206209

210+
@property
211+
def tool_context(self) -> ToolContext[Any] | None:
212+
"""The tool context for runs started via ``Agent.as_tool()``, if available."""
213+
from .tool_context import ToolContext
214+
215+
if isinstance(self.context_wrapper, ToolContext):
216+
return self.context_wrapper
217+
return None
218+
207219
@property
208220
def last_response_id(self) -> str | None:
209221
"""Convenience method to get the response ID of the last model response."""

tests/test_agent_as_tool.py

Lines changed: 164 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
RunContextWrapper,
2020
RunHooks,
2121
Runner,
22+
RunResult,
23+
RunResultStreaming,
2224
Session,
2325
SessionSettings,
2426
ToolApprovalItem,
@@ -392,6 +394,75 @@ async def extractor(result) -> str:
392394
assert output == "custom output"
393395

394396

397+
@pytest.mark.asyncio
398+
async def test_agent_as_tool_extractor_can_access_tool_context(
399+
monkeypatch: pytest.MonkeyPatch,
400+
) -> None:
401+
agent = Agent(name="nested_agent")
402+
403+
real_run_result = RunResult(
404+
input="hello",
405+
new_items=[],
406+
raw_responses=[],
407+
final_output="done",
408+
input_guardrail_results=[],
409+
output_guardrail_results=[],
410+
tool_input_guardrail_results=[],
411+
tool_output_guardrail_results=[],
412+
context_wrapper=ToolContext(
413+
context=None,
414+
tool_name="nested_tool",
415+
tool_call_id="call_abc_123",
416+
tool_arguments="{}",
417+
),
418+
_last_agent=agent,
419+
)
420+
421+
async def fake_run(
422+
cls,
423+
starting_agent,
424+
input,
425+
*,
426+
context,
427+
max_turns,
428+
hooks,
429+
run_config,
430+
previous_response_id,
431+
conversation_id,
432+
session,
433+
):
434+
del cls, starting_agent, input, context, max_turns, hooks, run_config
435+
del previous_response_id, conversation_id, session
436+
return real_run_result
437+
438+
monkeypatch.setattr(Runner, "run", classmethod(fake_run))
439+
440+
received_call_id: str | None = None
441+
442+
async def extractor(result: RunResult | RunResultStreaming) -> str:
443+
nonlocal received_call_id
444+
assert result.tool_context is not None
445+
received_call_id = result.tool_context.tool_call_id
446+
return "extracted"
447+
448+
tool = agent.as_tool(
449+
tool_name="nested_tool",
450+
tool_description="A nested agent tool",
451+
custom_output_extractor=extractor,
452+
)
453+
454+
parent_tool_context = ToolContext(
455+
context=None,
456+
tool_name="nested_tool",
457+
tool_call_id="call_abc_123",
458+
tool_arguments='{"input": "hello"}',
459+
)
460+
output = await tool.on_invoke_tool(parent_tool_context, '{"input": "hello"}')
461+
462+
assert output == "extracted"
463+
assert received_call_id == "call_abc_123"
464+
465+
395466
@pytest.mark.asyncio
396467
async def test_agent_as_tool_inherits_parent_run_config_when_not_set(
397468
monkeypatch: pytest.MonkeyPatch,
@@ -1254,18 +1325,29 @@ async def test_agent_as_tool_streaming_works_with_custom_extractor(
12541325
) -> None:
12551326
agent = Agent(name="streamer")
12561327
stream_events = [RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"}))]
1257-
stream_events = [RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"}))]
1258-
1259-
class DummyStreamingResult:
1260-
def __init__(self) -> None:
1261-
self.final_output = "raw output"
1262-
self.current_agent = agent
1263-
1264-
async def stream_events(self):
1265-
for ev in stream_events:
1266-
yield ev
1267-
1268-
streamed_instance = DummyStreamingResult()
1328+
streamed_instance = RunResultStreaming(
1329+
input="stream please",
1330+
new_items=[],
1331+
raw_responses=[],
1332+
final_output="raw output",
1333+
input_guardrail_results=[],
1334+
output_guardrail_results=[],
1335+
tool_input_guardrail_results=[],
1336+
tool_output_guardrail_results=[],
1337+
context_wrapper=ToolContext(
1338+
context=None,
1339+
tool_name="stream_tool",
1340+
tool_call_id="call-abc",
1341+
tool_arguments='{"input": "stream please"}',
1342+
),
1343+
current_agent=agent,
1344+
current_turn=0,
1345+
max_turns=1,
1346+
_current_agent_output_schema=None,
1347+
trace=None,
1348+
)
1349+
streamed_instance._event_queue.put_nowait(stream_events[0])
1350+
streamed_instance.is_complete = True
12691351

12701352
def fake_run_streamed(
12711353
cls,
@@ -1329,6 +1411,76 @@ async def on_stream(payload: AgentToolStreamEvent) -> None:
13291411
assert callbacks == stream_events
13301412

13311413

1414+
@pytest.mark.asyncio
1415+
async def test_agent_as_tool_streaming_extractor_can_access_tool_context(
1416+
monkeypatch: pytest.MonkeyPatch,
1417+
) -> None:
1418+
agent = Agent(name="streaming_tool_context_agent")
1419+
stream_event = RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"}))
1420+
streamed_instance = RunResultStreaming(
1421+
input="go",
1422+
new_items=[],
1423+
raw_responses=[],
1424+
final_output="raw output",
1425+
input_guardrail_results=[],
1426+
output_guardrail_results=[],
1427+
tool_input_guardrail_results=[],
1428+
tool_output_guardrail_results=[],
1429+
context_wrapper=ToolContext(
1430+
context=None,
1431+
tool_name="stream_tool",
1432+
tool_call_id="call-stream-123",
1433+
tool_arguments='{"input": "go"}',
1434+
),
1435+
current_agent=agent,
1436+
current_turn=0,
1437+
max_turns=1,
1438+
_current_agent_output_schema=None,
1439+
trace=None,
1440+
)
1441+
streamed_instance._event_queue.put_nowait(stream_event)
1442+
streamed_instance.is_complete = True
1443+
1444+
def fake_run_streamed(cls, /, starting_agent, input, **kwargs) -> RunResultStreaming:
1445+
del cls, starting_agent, input, kwargs
1446+
return streamed_instance
1447+
1448+
async def unexpected_run(*args: Any, **kwargs: Any) -> None:
1449+
raise AssertionError("Runner.run should not be called when on_stream is provided.")
1450+
1451+
monkeypatch.setattr(Runner, "run_streamed", classmethod(fake_run_streamed))
1452+
monkeypatch.setattr(Runner, "run", classmethod(unexpected_run))
1453+
1454+
received_call_id: str | None = None
1455+
1456+
async def extractor(result: RunResult | RunResultStreaming) -> str:
1457+
nonlocal received_call_id
1458+
assert result.tool_context is not None
1459+
received_call_id = result.tool_context.tool_call_id
1460+
return "custom value"
1461+
1462+
async def on_stream(payload: AgentToolStreamEvent) -> None:
1463+
del payload
1464+
1465+
tool = agent.as_tool(
1466+
tool_name="stream_tool",
1467+
tool_description="Streams events",
1468+
custom_output_extractor=extractor,
1469+
on_stream=on_stream,
1470+
)
1471+
1472+
tool_context = ToolContext(
1473+
context=None,
1474+
tool_name="stream_tool",
1475+
tool_call_id="call-stream-123",
1476+
tool_arguments='{"input": "go"}',
1477+
)
1478+
output = await tool.on_invoke_tool(tool_context, '{"input": "go"}')
1479+
1480+
assert output == "custom value"
1481+
assert received_call_id == "call-stream-123"
1482+
1483+
13321484
@pytest.mark.asyncio
13331485
async def test_agent_as_tool_streaming_accepts_sync_handler(
13341486
monkeypatch: pytest.MonkeyPatch,

tests/test_result_cast.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,3 +257,70 @@ def test_run_result_streaming_release_agents_releases_current_agent() -> None:
257257
assert agent_ref() is None
258258
with pytest.raises(AgentsException):
259259
_ = streaming_result.last_agent
260+
261+
262+
def test_run_result_tool_context_returns_tool_context() -> None:
263+
from agents.tool_context import ToolContext
264+
265+
tool_ctx = ToolContext(
266+
context=None,
267+
tool_name="my_tool",
268+
tool_call_id="call_xyz",
269+
tool_arguments="{}",
270+
)
271+
result = RunResult(
272+
input="test",
273+
new_items=[],
274+
raw_responses=[],
275+
final_output="ok",
276+
input_guardrail_results=[],
277+
output_guardrail_results=[],
278+
tool_input_guardrail_results=[],
279+
tool_output_guardrail_results=[],
280+
_last_agent=Agent(name="test"),
281+
context_wrapper=tool_ctx,
282+
interruptions=[],
283+
)
284+
285+
assert result.tool_context is tool_ctx
286+
assert result.tool_context is not None
287+
assert result.tool_context.tool_call_id == "call_xyz"
288+
289+
290+
def test_run_result_tool_context_returns_none_for_plain_context() -> None:
291+
result = create_run_result("ok")
292+
293+
assert result.tool_context is None
294+
295+
296+
def test_run_result_streaming_tool_context_returns_tool_context() -> None:
297+
from agents.tool_context import ToolContext
298+
299+
agent = Agent(name="streaming-tool-agent")
300+
tool_ctx = ToolContext(
301+
context=None,
302+
tool_name="stream_tool",
303+
tool_call_id="call_stream",
304+
tool_arguments='{"input":"stream"}',
305+
)
306+
result = RunResultStreaming(
307+
input="stream",
308+
new_items=[],
309+
raw_responses=[],
310+
final_output="done",
311+
input_guardrail_results=[],
312+
output_guardrail_results=[],
313+
tool_input_guardrail_results=[],
314+
tool_output_guardrail_results=[],
315+
context_wrapper=tool_ctx,
316+
current_agent=agent,
317+
current_turn=0,
318+
max_turns=1,
319+
_current_agent_output_schema=None,
320+
trace=None,
321+
interruptions=[],
322+
)
323+
324+
assert result.tool_context is tool_ctx
325+
assert result.tool_context is not None
326+
assert result.tool_context.tool_call_id == "call_stream"

0 commit comments

Comments
 (0)