Skip to content

Commit c7386ec

Browse files
authored
feat: expose immutable agent tool invocation metadata on run results (ref #2575) (#2576)
1 parent 2b32271 commit c7386ec

5 files changed

Lines changed: 81 additions & 39 deletions

File tree

src/agents/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@
8686
from .prompts import DynamicPromptFunction, GenerateDynamicPromptData, Prompt
8787
from .repl import run_demo_loop
8888
from .responses_websocket_session import ResponsesWebSocketSession, responses_websocket_session
89-
from .result import RunResult, RunResultStreaming
89+
from .result import AgentToolInvocation, RunResult, RunResultStreaming
9090
from .run import (
9191
ReasoningItemIdPolicy,
9292
RunConfig,
@@ -359,6 +359,7 @@ def enable_verbose_stdout_logging():
359359
"RunErrorHandlerInput",
360360
"RunErrorHandlerResult",
361361
"RunErrorHandlers",
362+
"AgentToolInvocation",
362363
"RunResult",
363364
"RunResultStreaming",
364365
"ResponsesWebSocketSession",

src/agents/agent.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -501,7 +501,8 @@ def as_tool(
501501
tool_description: The description of the tool, which should indicate what it does and
502502
when to use it.
503503
custom_output_extractor: A function that extracts the output from the agent. If not
504-
provided, the last message from the agent will be used.
504+
provided, the last message from the agent will be used. Nested run results expose
505+
`agent_tool_invocation` metadata when this agent is invoked via `as_tool()`.
505506
is_enabled: Whether the tool is enabled. Can be a bool or a callable that takes the run
506507
context and agent and returns whether the tool is enabled. Disabled tools are hidden
507508
from the LLM at runtime.

src/agents/result.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,25 @@
4646
)
4747

4848
if TYPE_CHECKING:
49-
from .tool_context import ToolContext
49+
pass
5050

5151
T = TypeVar("T")
5252

5353

54+
@dataclass(frozen=True)
55+
class AgentToolInvocation:
56+
"""Immutable metadata about a nested agent-tool invocation."""
57+
58+
tool_name: str
59+
"""The nested tool name exposed to the model."""
60+
61+
tool_call_id: str
62+
"""The tool call ID for the nested invocation."""
63+
64+
tool_arguments: str
65+
"""The raw JSON arguments for the nested invocation."""
66+
67+
5468
def _populate_state_from_result(
5569
state: RunState[Any],
5670
result: RunResultBase,
@@ -208,13 +222,21 @@ def to_input_list(self) -> list[TResponseInputItem]:
208222
return original_items + new_items
209223

210224
@property
211-
def tool_context(self) -> ToolContext[Any] | None:
212-
"""The tool context for runs started via ``Agent.as_tool()``, if available."""
225+
def agent_tool_invocation(self) -> AgentToolInvocation | None:
226+
"""Immutable metadata for results produced by `Agent.as_tool()`.
227+
228+
Returns `None` for ordinary top-level runs.
229+
"""
213230
from .tool_context import ToolContext
214231

215-
if isinstance(self.context_wrapper, ToolContext):
216-
return self.context_wrapper
217-
return None
232+
if not isinstance(self.context_wrapper, ToolContext):
233+
return None
234+
235+
return AgentToolInvocation(
236+
tool_name=self.context_wrapper.tool_name,
237+
tool_call_id=self.context_wrapper.tool_call_id,
238+
tool_arguments=self.context_wrapper.tool_arguments,
239+
)
218240

219241
@property
220242
def last_response_id(self) -> str | None:

tests/test_agent_as_tool.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -395,12 +395,11 @@ async def extractor(result) -> str:
395395

396396

397397
@pytest.mark.asyncio
398-
async def test_agent_as_tool_extractor_can_access_tool_context(
398+
async def test_agent_as_tool_extractor_can_access_agent_tool_invocation(
399399
monkeypatch: pytest.MonkeyPatch,
400400
) -> None:
401401
agent = Agent(name="nested_agent")
402-
403-
real_run_result = RunResult(
402+
run_result = RunResult(
404403
input="hello",
405404
new_items=[],
406405
raw_responses=[],
@@ -413,7 +412,7 @@ async def test_agent_as_tool_extractor_can_access_tool_context(
413412
context=None,
414413
tool_name="nested_tool",
415414
tool_call_id="call_abc_123",
416-
tool_arguments="{}",
415+
tool_arguments='{"input": "hello"}',
417416
),
418417
_last_agent=agent,
419418
)
@@ -433,16 +432,19 @@ async def fake_run(
433432
):
434433
del cls, starting_agent, input, context, max_turns, hooks, run_config
435434
del previous_response_id, conversation_id, session
436-
return real_run_result
435+
return run_result
437436

438437
monkeypatch.setattr(Runner, "run", classmethod(fake_run))
439438

440-
received_call_id: str | None = None
439+
received_tool_call_id: str | None = None
441440

442441
async def extractor(result: RunResult | RunResultStreaming) -> str:
443-
nonlocal received_call_id
444-
assert result.tool_context is not None
445-
received_call_id = result.tool_context.tool_call_id
442+
nonlocal received_tool_call_id
443+
invocation = result.agent_tool_invocation
444+
assert invocation is not None
445+
received_tool_call_id = invocation.tool_call_id
446+
assert invocation.tool_name == "nested_tool"
447+
assert invocation.tool_arguments == '{"input": "hello"}'
446448
return "extracted"
447449

448450
tool = agent.as_tool(
@@ -460,7 +462,7 @@ async def extractor(result: RunResult | RunResultStreaming) -> str:
460462
output = await tool.on_invoke_tool(parent_tool_context, '{"input": "hello"}')
461463

462464
assert output == "extracted"
463-
assert received_call_id == "call_abc_123"
465+
assert received_tool_call_id == "call_abc_123"
464466

465467

466468
@pytest.mark.asyncio
@@ -1412,7 +1414,7 @@ async def on_stream(payload: AgentToolStreamEvent) -> None:
14121414

14131415

14141416
@pytest.mark.asyncio
1415-
async def test_agent_as_tool_streaming_extractor_can_access_tool_context(
1417+
async def test_agent_as_tool_streaming_extractor_can_access_agent_tool_invocation(
14161418
monkeypatch: pytest.MonkeyPatch,
14171419
) -> None:
14181420
agent = Agent(name="streaming_tool_context_agent")
@@ -1441,7 +1443,13 @@ async def test_agent_as_tool_streaming_extractor_can_access_tool_context(
14411443
streamed_instance._event_queue.put_nowait(stream_event)
14421444
streamed_instance.is_complete = True
14431445

1444-
def fake_run_streamed(cls, /, starting_agent, input, **kwargs) -> RunResultStreaming:
1446+
def fake_run_streamed(
1447+
cls,
1448+
/,
1449+
starting_agent,
1450+
input,
1451+
**kwargs,
1452+
) -> RunResultStreaming:
14451453
del cls, starting_agent, input, kwargs
14461454
return streamed_instance
14471455

@@ -1455,8 +1463,11 @@ async def unexpected_run(*args: Any, **kwargs: Any) -> None:
14551463

14561464
async def extractor(result: RunResult | RunResultStreaming) -> str:
14571465
nonlocal received_call_id
1458-
assert result.tool_context is not None
1459-
received_call_id = result.tool_context.tool_call_id
1466+
invocation = result.agent_tool_invocation
1467+
assert invocation is not None
1468+
received_call_id = invocation.tool_call_id
1469+
assert invocation.tool_name == "stream_tool"
1470+
assert invocation.tool_arguments == '{"input": "go"}'
14601471
return "custom value"
14611472

14621473
async def on_stream(payload: AgentToolStreamEvent) -> None:

tests/test_result_cast.py

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,23 @@
33
import dataclasses
44
import gc
55
import weakref
6-
from typing import Any
6+
from typing import Any, cast
77

88
import pytest
99
from openai.types.responses import ResponseOutputMessage, ResponseOutputText
1010
from pydantic import BaseModel, ConfigDict
1111

1212
from agents import (
1313
Agent,
14+
AgentToolInvocation,
1415
MessageOutputItem,
1516
RunContextWrapper,
1617
RunItem,
1718
RunResult,
1819
RunResultStreaming,
1920
)
2021
from agents.exceptions import AgentsException
22+
from agents.tool_context import ToolContext
2123

2224

2325
def create_run_result(
@@ -259,9 +261,13 @@ def test_run_result_streaming_release_agents_releases_current_agent() -> None:
259261
_ = streaming_result.last_agent
260262

261263

262-
def test_run_result_tool_context_returns_tool_context() -> None:
263-
from agents.tool_context import ToolContext
264+
def test_run_result_agent_tool_invocation_returns_none_for_plain_context() -> None:
265+
result = create_run_result("ok")
266+
267+
assert result.agent_tool_invocation is None
264268

269+
270+
def test_run_result_agent_tool_invocation_returns_immutable_metadata() -> None:
265271
tool_ctx = ToolContext(
266272
context=None,
267273
tool_name="my_tool",
@@ -282,20 +288,19 @@ def test_run_result_tool_context_returns_tool_context() -> None:
282288
interruptions=[],
283289
)
284290

285-
assert result.tool_context is tool_ctx
286-
assert result.tool_context is not None
287-
assert result.tool_context.tool_call_id == "call_xyz"
288-
289-
290-
def test_run_result_tool_context_returns_none_for_plain_context() -> None:
291-
result = create_run_result("ok")
292-
293-
assert result.tool_context is None
291+
assert result.agent_tool_invocation == AgentToolInvocation(
292+
tool_name="my_tool",
293+
tool_call_id="call_xyz",
294+
tool_arguments="{}",
295+
)
294296

297+
invocation = result.agent_tool_invocation
298+
assert invocation is not None
299+
with pytest.raises(dataclasses.FrozenInstanceError):
300+
cast(Any, invocation).tool_name = "other"
295301

296-
def test_run_result_streaming_tool_context_returns_tool_context() -> None:
297-
from agents.tool_context import ToolContext
298302

303+
def test_run_result_streaming_agent_tool_invocation_returns_metadata() -> None:
299304
agent = Agent(name="streaming-tool-agent")
300305
tool_ctx = ToolContext(
301306
context=None,
@@ -321,6 +326,8 @@ def test_run_result_streaming_tool_context_returns_tool_context() -> None:
321326
interruptions=[],
322327
)
323328

324-
assert result.tool_context is tool_ctx
325-
assert result.tool_context is not None
326-
assert result.tool_context.tool_call_id == "call_stream"
329+
assert result.agent_tool_invocation == AgentToolInvocation(
330+
tool_name="stream_tool",
331+
tool_call_id="call_stream",
332+
tool_arguments='{"input":"stream"}',
333+
)

0 commit comments

Comments
 (0)