Skip to content

Commit c473b0c

Browse files
authored
fix: narrow Agent.as_tool return type to FunctionTool (#2473)
1 parent bc9dbd7 commit c473b0c

3 files changed

Lines changed: 82 additions & 142 deletions

File tree

src/agents/agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,7 @@ def as_tool(
485485
parameters: type[Any] | None = None,
486486
input_builder: StructuredToolInputBuilder | None = None,
487487
include_input_schema: bool = False,
488-
) -> Tool:
488+
) -> FunctionTool:
489489
"""Transform this agent into a tool, callable by other agents.
490490
491491
This is different from handoffs in two ways:

tests/test_agent_as_tool.py

Lines changed: 76 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -421,12 +421,9 @@ async def fake_run(
421421

422422
monkeypatch.setattr(Runner, "run", classmethod(fake_run))
423423

424-
tool = cast(
425-
FunctionTool,
426-
agent.as_tool(
427-
tool_name="inherits_config_tool",
428-
tool_description="inherit config",
429-
),
424+
tool = agent.as_tool(
425+
tool_name="inherits_config_tool",
426+
tool_description="inherit config",
430427
)
431428
tool_context = ToolContext(
432429
context=None,
@@ -475,13 +472,10 @@ async def fake_run(
475472

476473
monkeypatch.setattr(Runner, "run", classmethod(fake_run))
477474

478-
tool = cast(
479-
FunctionTool,
480-
agent.as_tool(
481-
tool_name="override_config_tool",
482-
tool_description="override config",
483-
run_config=explicit_run_config,
484-
),
475+
tool = agent.as_tool(
476+
tool_name="override_config_tool",
477+
tool_description="override config",
478+
run_config=explicit_run_config,
485479
)
486480
tool_context = ToolContext(
487481
context=None,
@@ -529,12 +523,9 @@ async def fake_run(
529523

530524
monkeypatch.setattr(Runner, "run", classmethod(fake_run))
531525

532-
tool = cast(
533-
FunctionTool,
534-
agent.as_tool(
535-
tool_name="trace_config_tool",
536-
tool_description="inherits trace config",
537-
),
526+
tool = agent.as_tool(
527+
tool_name="trace_config_tool",
528+
tool_description="inherits trace config",
538529
)
539530
tool_context = ToolContext(
540531
context=None,
@@ -561,13 +552,10 @@ class TranslationInput(BaseModel):
561552
target: str
562553

563554
agent = Agent(name="translator")
564-
tool = cast(
565-
FunctionTool,
566-
agent.as_tool(
567-
tool_name="translate",
568-
tool_description="Translate text",
569-
parameters=TranslationInput,
570-
),
555+
tool = agent.as_tool(
556+
tool_name="translate",
557+
tool_description="Translate text",
558+
parameters=TranslationInput,
571559
)
572560

573561
captured: dict[str, Any] = {}
@@ -626,12 +614,9 @@ async def test_agent_as_tool_clears_stale_tool_input_for_plain_tools(
626614
"""Non-structured agent tools should not inherit stale tool input."""
627615

628616
agent = Agent(name="plain_agent")
629-
tool = cast(
630-
FunctionTool,
631-
agent.as_tool(
632-
tool_name="plain_tool",
633-
tool_description="Plain tool",
634-
),
617+
tool = agent.as_tool(
618+
tool_name="plain_tool",
619+
tool_description="Plain tool",
635620
)
636621

637622
run_context = RunContextWrapper({"locale": "en-US"})
@@ -685,13 +670,10 @@ class TranslationInput(BaseModel):
685670
target: str = Field(description="Target language")
686671

687672
agent = Agent(name="summary_agent")
688-
tool = cast(
689-
FunctionTool,
690-
agent.as_tool(
691-
tool_name="summarize_schema",
692-
tool_description="Summary tool",
693-
parameters=TranslationInput,
694-
),
673+
tool = agent.as_tool(
674+
tool_name="summarize_schema",
675+
tool_description="Summary tool",
676+
parameters=TranslationInput,
695677
)
696678

697679
captured: dict[str, Any] = {}
@@ -756,14 +738,11 @@ async def builder(options: StructuredToolInputBuilderOptions):
756738
builder_calls.append(options)
757739
return custom_items
758740

759-
tool = cast(
760-
FunctionTool,
761-
agent.as_tool(
762-
tool_name="builder_tool",
763-
tool_description="Builder tool",
764-
parameters=TranslationInput,
765-
input_builder=builder,
766-
),
741+
tool = agent.as_tool(
742+
tool_name="builder_tool",
743+
tool_description="Builder tool",
744+
parameters=TranslationInput,
745+
input_builder=builder,
767746
)
768747

769748
class DummyResult:
@@ -813,13 +792,10 @@ async def test_agent_as_tool_rejects_invalid_builder_output() -> None:
813792
def builder(_options):
814793
return 123
815794

816-
tool = cast(
817-
FunctionTool,
818-
agent.as_tool(
819-
tool_name="invalid_builder_tool",
820-
tool_description="Invalid builder tool",
821-
input_builder=builder,
822-
),
795+
tool = agent.as_tool(
796+
tool_name="invalid_builder_tool",
797+
tool_description="Invalid builder tool",
798+
input_builder=builder,
823799
)
824800

825801
tool_context = ToolContext(
@@ -844,14 +820,11 @@ class TranslationInput(BaseModel):
844820
target: str = Field(description="Target language")
845821

846822
agent = Agent(name="schema_agent")
847-
tool = cast(
848-
FunctionTool,
849-
agent.as_tool(
850-
tool_name="schema_tool",
851-
tool_description="Schema tool",
852-
parameters=TranslationInput,
853-
include_input_schema=True,
854-
),
823+
tool = agent.as_tool(
824+
tool_name="schema_tool",
825+
tool_description="Schema tool",
826+
parameters=TranslationInput,
827+
include_input_schema=True,
855828
)
856829

857830
captured: dict[str, Any] = {}
@@ -903,13 +876,10 @@ async def test_agent_as_tool_ignores_input_schema_without_parameters(
903876
"""include_input_schema should be ignored when no parameters are provided."""
904877

905878
agent = Agent(name="default_schema_agent")
906-
tool = cast(
907-
FunctionTool,
908-
agent.as_tool(
909-
tool_name="default_schema_tool",
910-
tool_description="Default schema tool",
911-
include_input_schema=True,
912-
),
879+
tool = agent.as_tool(
880+
tool_name="default_schema_tool",
881+
tool_description="Default schema tool",
882+
include_input_schema=True,
913883
)
914884

915885
captured: dict[str, Any] = {}
@@ -1017,14 +987,11 @@ async def extractor(result: Any) -> str:
1017987
assert result is resumed_result
1018988
return "from_resume"
1019989

1020-
tool = cast(
1021-
FunctionTool,
1022-
agent.as_tool(
1023-
tool_name="outer_tool",
1024-
tool_description="Outer agent tool",
1025-
custom_output_extractor=extractor,
1026-
is_enabled=True,
1027-
),
990+
tool = agent.as_tool(
991+
tool_name="outer_tool",
992+
tool_description="Outer agent tool",
993+
custom_output_extractor=extractor,
994+
is_enabled=True,
1028995
)
1029996

1030997
output = await tool.on_invoke_tool(tool_context, tool_call.arguments)
@@ -1102,13 +1069,10 @@ async def on_stream(payload: AgentToolStreamEvent) -> None:
11021069
type="function_call",
11031070
)
11041071

1105-
tool = cast(
1106-
FunctionTool,
1107-
agent.as_tool(
1108-
tool_name="stream_tool",
1109-
tool_description="Streams events",
1110-
on_stream=on_stream,
1111-
),
1072+
tool = agent.as_tool(
1073+
tool_name="stream_tool",
1074+
tool_description="Streams events",
1075+
on_stream=on_stream,
11121076
)
11131077

11141078
tool_context = ToolContext(
@@ -1179,13 +1143,10 @@ def fake_run_streamed(
11791143
async def on_stream(payload: AgentToolStreamEvent) -> None:
11801144
seen_agents.append(payload["agent"])
11811145

1182-
tool = cast(
1183-
FunctionTool,
1184-
first_agent.as_tool(
1185-
tool_name="delegate_tool",
1186-
tool_description="Streams handoff events",
1187-
on_stream=on_stream,
1188-
),
1146+
tool = first_agent.as_tool(
1147+
tool_name="delegate_tool",
1148+
tool_description="Streams handoff events",
1149+
on_stream=on_stream,
11891150
)
11901151

11911152
tool_call = ResponseFunctionToolCall(
@@ -1269,14 +1230,11 @@ async def on_stream(payload: AgentToolStreamEvent) -> None:
12691230
type="function_call",
12701231
)
12711232

1272-
tool = cast(
1273-
FunctionTool,
1274-
agent.as_tool(
1275-
tool_name="stream_tool",
1276-
tool_description="Streams events",
1277-
custom_output_extractor=extractor,
1278-
on_stream=on_stream,
1279-
),
1233+
tool = agent.as_tool(
1234+
tool_name="stream_tool",
1235+
tool_description="Streams events",
1236+
custom_output_extractor=extractor,
1237+
on_stream=on_stream,
12801238
)
12811239

12821240
tool_context = ToolContext(
@@ -1329,13 +1287,10 @@ def sync_handler(event: AgentToolStreamEvent) -> None:
13291287
type="function_call",
13301288
)
13311289

1332-
tool = cast(
1333-
FunctionTool,
1334-
agent.as_tool(
1335-
tool_name="sync_tool",
1336-
tool_description="Uses sync handler",
1337-
on_stream=sync_handler,
1338-
),
1290+
tool = agent.as_tool(
1291+
tool_name="sync_tool",
1292+
tool_description="Uses sync handler",
1293+
on_stream=sync_handler,
13391294
)
13401295
tool_context = ToolContext(
13411296
context=None,
@@ -1402,13 +1357,10 @@ async def on_stream(payload: AgentToolStreamEvent) -> None:
14021357
type="function_call",
14031358
)
14041359

1405-
tool = cast(
1406-
FunctionTool,
1407-
agent.as_tool(
1408-
tool_name="nonblocking_tool",
1409-
tool_description="Uses non-blocking streaming handler",
1410-
on_stream=on_stream,
1411-
),
1360+
tool = agent.as_tool(
1361+
tool_name="nonblocking_tool",
1362+
tool_description="Uses non-blocking streaming handler",
1363+
on_stream=on_stream,
14121364
)
14131365
tool_context = ToolContext(
14141366
context=None,
@@ -1468,13 +1420,10 @@ def bad_handler(event: AgentToolStreamEvent) -> None:
14681420
type="function_call",
14691421
)
14701422

1471-
tool = cast(
1472-
FunctionTool,
1473-
agent.as_tool(
1474-
tool_name="error_tool",
1475-
tool_description="Handler throws",
1476-
on_stream=bad_handler,
1477-
),
1423+
tool = agent.as_tool(
1424+
tool_name="error_tool",
1425+
tool_description="Handler throws",
1426+
on_stream=bad_handler,
14781427
)
14791428
tool_context = ToolContext(
14801429
context=None,
@@ -1525,12 +1474,9 @@ async def fake_run(
15251474
classmethod(lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("no stream"))),
15261475
)
15271476

1528-
tool = cast(
1529-
FunctionTool,
1530-
agent.as_tool(
1531-
tool_name="nostream_tool",
1532-
tool_description="No streaming path",
1533-
),
1477+
tool = agent.as_tool(
1478+
tool_name="nostream_tool",
1479+
tool_description="No streaming path",
15341480
)
15351481
tool_context = ToolContext(
15361482
context=None,
@@ -1581,13 +1527,10 @@ async def on_stream(event: AgentToolStreamEvent) -> None:
15811527
type="function_call",
15821528
)
15831529

1584-
tool = cast(
1585-
FunctionTool,
1586-
agent.as_tool(
1587-
tool_name="direct_stream_tool",
1588-
tool_description="Direct invocation",
1589-
on_stream=on_stream,
1590-
),
1530+
tool = agent.as_tool(
1531+
tool_name="direct_stream_tool",
1532+
tool_description="Direct invocation",
1533+
on_stream=on_stream,
15911534
)
15921535
tool_context = ToolContext(
15931536
context=None,

tests/test_example_workflows.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
)
2727
from agents.agent import ToolsToFinalOutputResult
2828
from agents.items import TResponseInputItem
29-
from agents.tool import FunctionTool, FunctionToolResult, function_tool
29+
from agents.tool import FunctionToolResult, function_tool
3030

3131
from .fake_model import FakeModel
3232
from .test_responses import (
@@ -444,13 +444,10 @@ async def test_agent_as_tool_streaming_example_collects_events() -> None:
444444
async def on_stream(event: AgentToolStreamEvent) -> None:
445445
received.append(event)
446446

447-
billing_tool = cast(
448-
FunctionTool,
449-
billing_agent.as_tool(
450-
tool_name="billing_agent",
451-
tool_description="Answer billing questions",
452-
on_stream=on_stream,
453-
),
447+
billing_tool = billing_agent.as_tool(
448+
tool_name="billing_agent",
449+
tool_description="Answer billing questions",
450+
on_stream=on_stream,
454451
)
455452

456453
async def fake_invoke(ctx, input: str) -> str:

0 commit comments

Comments
 (0)