|
19 | 19 | RunContextWrapper, |
20 | 20 | RunHooks, |
21 | 21 | Runner, |
| 22 | + RunResult, |
| 23 | + RunResultStreaming, |
22 | 24 | Session, |
23 | 25 | SessionSettings, |
24 | 26 | ToolApprovalItem, |
@@ -392,6 +394,75 @@ async def extractor(result) -> str: |
392 | 394 | assert output == "custom output" |
393 | 395 |
|
394 | 396 |
|
| 397 | +@pytest.mark.asyncio |
| 398 | +async def test_agent_as_tool_extractor_can_access_tool_context( |
| 399 | + monkeypatch: pytest.MonkeyPatch, |
| 400 | +) -> None: |
| 401 | + agent = Agent(name="nested_agent") |
| 402 | + |
| 403 | + real_run_result = RunResult( |
| 404 | + input="hello", |
| 405 | + new_items=[], |
| 406 | + raw_responses=[], |
| 407 | + final_output="done", |
| 408 | + input_guardrail_results=[], |
| 409 | + output_guardrail_results=[], |
| 410 | + tool_input_guardrail_results=[], |
| 411 | + tool_output_guardrail_results=[], |
| 412 | + context_wrapper=ToolContext( |
| 413 | + context=None, |
| 414 | + tool_name="nested_tool", |
| 415 | + tool_call_id="call_abc_123", |
| 416 | + tool_arguments="{}", |
| 417 | + ), |
| 418 | + _last_agent=agent, |
| 419 | + ) |
| 420 | + |
| 421 | + async def fake_run( |
| 422 | + cls, |
| 423 | + starting_agent, |
| 424 | + input, |
| 425 | + *, |
| 426 | + context, |
| 427 | + max_turns, |
| 428 | + hooks, |
| 429 | + run_config, |
| 430 | + previous_response_id, |
| 431 | + conversation_id, |
| 432 | + session, |
| 433 | + ): |
| 434 | + del cls, starting_agent, input, context, max_turns, hooks, run_config |
| 435 | + del previous_response_id, conversation_id, session |
| 436 | + return real_run_result |
| 437 | + |
| 438 | + monkeypatch.setattr(Runner, "run", classmethod(fake_run)) |
| 439 | + |
| 440 | + received_call_id: str | None = None |
| 441 | + |
| 442 | + async def extractor(result: RunResult | RunResultStreaming) -> str: |
| 443 | + nonlocal received_call_id |
| 444 | + assert result.tool_context is not None |
| 445 | + received_call_id = result.tool_context.tool_call_id |
| 446 | + return "extracted" |
| 447 | + |
| 448 | + tool = agent.as_tool( |
| 449 | + tool_name="nested_tool", |
| 450 | + tool_description="A nested agent tool", |
| 451 | + custom_output_extractor=extractor, |
| 452 | + ) |
| 453 | + |
| 454 | + parent_tool_context = ToolContext( |
| 455 | + context=None, |
| 456 | + tool_name="nested_tool", |
| 457 | + tool_call_id="call_abc_123", |
| 458 | + tool_arguments='{"input": "hello"}', |
| 459 | + ) |
| 460 | + output = await tool.on_invoke_tool(parent_tool_context, '{"input": "hello"}') |
| 461 | + |
| 462 | + assert output == "extracted" |
| 463 | + assert received_call_id == "call_abc_123" |
| 464 | + |
| 465 | + |
395 | 466 | @pytest.mark.asyncio |
396 | 467 | async def test_agent_as_tool_inherits_parent_run_config_when_not_set( |
397 | 468 | monkeypatch: pytest.MonkeyPatch, |
@@ -1254,18 +1325,29 @@ async def test_agent_as_tool_streaming_works_with_custom_extractor( |
1254 | 1325 | ) -> None: |
1255 | 1326 | agent = Agent(name="streamer") |
1256 | 1327 | stream_events = [RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"}))] |
1257 | | - stream_events = [RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"}))] |
1258 | | - |
1259 | | - class DummyStreamingResult: |
1260 | | - def __init__(self) -> None: |
1261 | | - self.final_output = "raw output" |
1262 | | - self.current_agent = agent |
1263 | | - |
1264 | | - async def stream_events(self): |
1265 | | - for ev in stream_events: |
1266 | | - yield ev |
1267 | | - |
1268 | | - streamed_instance = DummyStreamingResult() |
| 1328 | + streamed_instance = RunResultStreaming( |
| 1329 | + input="stream please", |
| 1330 | + new_items=[], |
| 1331 | + raw_responses=[], |
| 1332 | + final_output="raw output", |
| 1333 | + input_guardrail_results=[], |
| 1334 | + output_guardrail_results=[], |
| 1335 | + tool_input_guardrail_results=[], |
| 1336 | + tool_output_guardrail_results=[], |
| 1337 | + context_wrapper=ToolContext( |
| 1338 | + context=None, |
| 1339 | + tool_name="stream_tool", |
| 1340 | + tool_call_id="call-abc", |
| 1341 | + tool_arguments='{"input": "stream please"}', |
| 1342 | + ), |
| 1343 | + current_agent=agent, |
| 1344 | + current_turn=0, |
| 1345 | + max_turns=1, |
| 1346 | + _current_agent_output_schema=None, |
| 1347 | + trace=None, |
| 1348 | + ) |
| 1349 | + streamed_instance._event_queue.put_nowait(stream_events[0]) |
| 1350 | + streamed_instance.is_complete = True |
1269 | 1351 |
|
1270 | 1352 | def fake_run_streamed( |
1271 | 1353 | cls, |
@@ -1329,6 +1411,76 @@ async def on_stream(payload: AgentToolStreamEvent) -> None: |
1329 | 1411 | assert callbacks == stream_events |
1330 | 1412 |
|
1331 | 1413 |
|
| 1414 | +@pytest.mark.asyncio |
| 1415 | +async def test_agent_as_tool_streaming_extractor_can_access_tool_context( |
| 1416 | + monkeypatch: pytest.MonkeyPatch, |
| 1417 | +) -> None: |
| 1418 | + agent = Agent(name="streaming_tool_context_agent") |
| 1419 | + stream_event = RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"})) |
| 1420 | + streamed_instance = RunResultStreaming( |
| 1421 | + input="go", |
| 1422 | + new_items=[], |
| 1423 | + raw_responses=[], |
| 1424 | + final_output="raw output", |
| 1425 | + input_guardrail_results=[], |
| 1426 | + output_guardrail_results=[], |
| 1427 | + tool_input_guardrail_results=[], |
| 1428 | + tool_output_guardrail_results=[], |
| 1429 | + context_wrapper=ToolContext( |
| 1430 | + context=None, |
| 1431 | + tool_name="stream_tool", |
| 1432 | + tool_call_id="call-stream-123", |
| 1433 | + tool_arguments='{"input": "go"}', |
| 1434 | + ), |
| 1435 | + current_agent=agent, |
| 1436 | + current_turn=0, |
| 1437 | + max_turns=1, |
| 1438 | + _current_agent_output_schema=None, |
| 1439 | + trace=None, |
| 1440 | + ) |
| 1441 | + streamed_instance._event_queue.put_nowait(stream_event) |
| 1442 | + streamed_instance.is_complete = True |
| 1443 | + |
| 1444 | + def fake_run_streamed(cls, /, starting_agent, input, **kwargs) -> RunResultStreaming: |
| 1445 | + del cls, starting_agent, input, kwargs |
| 1446 | + return streamed_instance |
| 1447 | + |
| 1448 | + async def unexpected_run(*args: Any, **kwargs: Any) -> None: |
| 1449 | + raise AssertionError("Runner.run should not be called when on_stream is provided.") |
| 1450 | + |
| 1451 | + monkeypatch.setattr(Runner, "run_streamed", classmethod(fake_run_streamed)) |
| 1452 | + monkeypatch.setattr(Runner, "run", classmethod(unexpected_run)) |
| 1453 | + |
| 1454 | + received_call_id: str | None = None |
| 1455 | + |
| 1456 | + async def extractor(result: RunResult | RunResultStreaming) -> str: |
| 1457 | + nonlocal received_call_id |
| 1458 | + assert result.tool_context is not None |
| 1459 | + received_call_id = result.tool_context.tool_call_id |
| 1460 | + return "custom value" |
| 1461 | + |
| 1462 | + async def on_stream(payload: AgentToolStreamEvent) -> None: |
| 1463 | + del payload |
| 1464 | + |
| 1465 | + tool = agent.as_tool( |
| 1466 | + tool_name="stream_tool", |
| 1467 | + tool_description="Streams events", |
| 1468 | + custom_output_extractor=extractor, |
| 1469 | + on_stream=on_stream, |
| 1470 | + ) |
| 1471 | + |
| 1472 | + tool_context = ToolContext( |
| 1473 | + context=None, |
| 1474 | + tool_name="stream_tool", |
| 1475 | + tool_call_id="call-stream-123", |
| 1476 | + tool_arguments='{"input": "go"}', |
| 1477 | + ) |
| 1478 | + output = await tool.on_invoke_tool(tool_context, '{"input": "go"}') |
| 1479 | + |
| 1480 | + assert output == "custom value" |
| 1481 | + assert received_call_id == "call-stream-123" |
| 1482 | + |
| 1483 | + |
1332 | 1484 | @pytest.mark.asyncio |
1333 | 1485 | async def test_agent_as_tool_streaming_accepts_sync_handler( |
1334 | 1486 | monkeypatch: pytest.MonkeyPatch, |
|
0 commit comments