Skip to content

Commit 687e974

Browse files
fix: handle cancelled single function tools as tool failures (#2762)
1 parent b8355b2 commit 687e974

File tree

6 files changed

+308
-7
lines changed

6 files changed

+308
-7
lines changed

src/agents/agent.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -850,6 +850,26 @@ async def dispatch_stream_events() -> None:
850850
if custom_output_extractor:
851851
return await custom_output_extractor(run_result)
852852

853+
if run_result.final_output is not None and (
854+
not isinstance(run_result.final_output, str) or run_result.final_output != ""
855+
):
856+
return run_result.final_output
857+
858+
from .items import ItemHelpers, MessageOutputItem, ToolCallOutputItem
859+
860+
for item in reversed(run_result.new_items):
861+
if isinstance(item, MessageOutputItem):
862+
text_output = ItemHelpers.text_message_output(item)
863+
if text_output:
864+
return text_output
865+
866+
if (
867+
isinstance(item, ToolCallOutputItem)
868+
and isinstance(item.output, str)
869+
and item.output
870+
):
871+
return item.output
872+
853873
return run_result.final_output
854874

855875
run_agent_tool = _build_wrapped_function_tool(

src/agents/run_internal/tool_execution.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1400,9 +1400,10 @@ async def _drain_cancelled_tasks(
14001400
self,
14011401
tasks: set[asyncio.Task[Any]],
14021402
) -> tuple[_FunctionToolFailure | None, set[asyncio.Task[Any]]]:
1403-
late_failure_sources: dict[asyncio.Task[Any], _FunctionToolFailureSource] = {
1404-
task: "cancelled_teardown" for task in tasks
1405-
}
1403+
late_failure_sources: dict[asyncio.Task[Any], _FunctionToolFailureSource] = dict.fromkeys(
1404+
tasks,
1405+
"cancelled_teardown",
1406+
)
14061407
return await _drain_cancelled_function_tool_tasks(
14071408
pending_tasks=tasks,
14081409
task_states=self.task_states,
@@ -1415,9 +1416,9 @@ async def _wait_post_invoke_tasks(
14151416
self,
14161417
tasks: set[asyncio.Task[Any]],
14171418
) -> tuple[_FunctionToolFailure | None, set[asyncio.Task[Any]]]:
1418-
post_invoke_failure_sources: dict[asyncio.Task[Any], _FunctionToolFailureSource] = {
1419-
task: "post_invoke" for task in tasks
1420-
}
1419+
post_invoke_failure_sources: dict[asyncio.Task[Any], _FunctionToolFailureSource] = (
1420+
dict.fromkeys(tasks, "post_invoke")
1421+
)
14211422
return await _wait_pending_function_tool_tasks_for_timeout(
14221423
pending_tasks=tasks,
14231424
task_states=self.task_states,
@@ -1638,7 +1639,7 @@ async def _invoke_tool_and_run_post_invoke(
16381639
arguments=tool_call.arguments,
16391640
)
16401641
except asyncio.CancelledError as e:
1641-
if not self.isolate_parallel_failures or outer_task in self.teardown_cancelled_tasks:
1642+
if outer_task in self.teardown_cancelled_tasks:
16421643
raise
16431644

16441645
result = await maybe_invoke_function_tool_failure_error_function(

tests/test_agent_as_tool.py

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
Session,
3131
SessionSettings,
3232
ToolApprovalItem,
33+
ToolCallOutputItem,
3334
TResponseInputItem,
3435
Usage,
3536
tool_namespace,
@@ -407,6 +408,183 @@ async def extractor(result) -> str:
407408
assert output == "custom output"
408409

409410

411+
@pytest.mark.asyncio
412+
async def test_agent_as_tool_fallback_uses_current_run_items_only(
413+
monkeypatch: pytest.MonkeyPatch,
414+
) -> None:
415+
agent = Agent(name="summarizer")
416+
417+
message = ResponseOutputMessage(
418+
id="msg_current",
419+
role="assistant",
420+
status="completed",
421+
type="message",
422+
content=[
423+
ResponseOutputText(
424+
annotations=[],
425+
text="Current run summary",
426+
type="output_text",
427+
logprobs=[],
428+
)
429+
],
430+
)
431+
432+
class DummyResult:
433+
def __init__(self) -> None:
434+
self.final_output = ""
435+
self.new_items = [
436+
ToolCallOutputItem(
437+
agent=agent,
438+
raw_item={
439+
"call_id": "call_current",
440+
"output": "Current tool output",
441+
"type": "function_call_output",
442+
},
443+
output="Current tool output",
444+
),
445+
MessageOutputItem(agent=agent, raw_item=message),
446+
]
447+
448+
def to_input_list(self) -> list[dict[str, Any]]:
449+
return [
450+
{
451+
"call_id": "call_old",
452+
"output": "Old output from prior history",
453+
"type": "function_call_output",
454+
}
455+
]
456+
457+
run_result = DummyResult()
458+
459+
async def fake_run(
460+
cls,
461+
starting_agent,
462+
input,
463+
*,
464+
context,
465+
max_turns,
466+
hooks,
467+
run_config,
468+
previous_response_id,
469+
conversation_id,
470+
session,
471+
):
472+
del (
473+
cls,
474+
starting_agent,
475+
input,
476+
context,
477+
max_turns,
478+
hooks,
479+
run_config,
480+
previous_response_id,
481+
conversation_id,
482+
session,
483+
)
484+
return run_result
485+
486+
monkeypatch.setattr(Runner, "run", classmethod(fake_run))
487+
488+
tool = agent.as_tool(
489+
tool_name="summary_tool",
490+
tool_description="Summarize current run output",
491+
)
492+
tool_context = ToolContext(
493+
context=None,
494+
tool_name="summary_tool",
495+
tool_call_id="call_1",
496+
tool_arguments='{"input": "hello"}',
497+
)
498+
499+
output = await tool.on_invoke_tool(tool_context, '{"input": "hello"}')
500+
501+
assert output == "Current run summary"
502+
503+
504+
@pytest.mark.asyncio
505+
async def test_agent_as_tool_fallback_returns_most_recent_current_run_output(
506+
monkeypatch: pytest.MonkeyPatch,
507+
) -> None:
508+
agent = Agent(name="summarizer")
509+
510+
older_message = ResponseOutputMessage(
511+
id="msg_older",
512+
role="assistant",
513+
status="completed",
514+
type="message",
515+
content=[
516+
ResponseOutputText(
517+
annotations=[],
518+
text="Older message output",
519+
type="output_text",
520+
logprobs=[],
521+
)
522+
],
523+
)
524+
525+
class DummyResult:
526+
def __init__(self) -> None:
527+
self.final_output = ""
528+
self.new_items = [
529+
MessageOutputItem(agent=agent, raw_item=older_message),
530+
ToolCallOutputItem(
531+
agent=agent,
532+
raw_item={
533+
"call_id": "call_current",
534+
"output": "Newest tool output",
535+
"type": "function_call_output",
536+
},
537+
output="Newest tool output",
538+
),
539+
]
540+
541+
run_result = DummyResult()
542+
543+
async def fake_run(
544+
cls,
545+
starting_agent,
546+
input,
547+
*,
548+
context,
549+
max_turns,
550+
hooks,
551+
run_config,
552+
previous_response_id,
553+
conversation_id,
554+
session,
555+
):
556+
del (
557+
cls,
558+
starting_agent,
559+
input,
560+
context,
561+
max_turns,
562+
hooks,
563+
run_config,
564+
previous_response_id,
565+
conversation_id,
566+
session,
567+
)
568+
return run_result
569+
570+
monkeypatch.setattr(Runner, "run", classmethod(fake_run))
571+
572+
tool = agent.as_tool(
573+
tool_name="summary_tool",
574+
tool_description="Summarize current run output",
575+
)
576+
tool_context = ToolContext(
577+
context=None,
578+
tool_name="summary_tool",
579+
tool_call_id="call_1",
580+
tool_arguments='{"input": "hello"}',
581+
)
582+
583+
output = await tool.on_invoke_tool(tool_context, '{"input": "hello"}')
584+
585+
assert output == "Newest tool output"
586+
587+
410588
@pytest.mark.asyncio
411589
async def test_agent_as_tool_extractor_can_access_agent_tool_invocation(
412590
monkeypatch: pytest.MonkeyPatch,

tests/test_agent_runner.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -641,6 +641,45 @@ async def _cancel_tool() -> str:
641641
]
642642

643643

644+
@pytest.mark.asyncio
645+
async def test_single_tool_call_with_cancelled_tool_reaches_final_output() -> None:
646+
async def _cancel_tool() -> str:
647+
raise asyncio.CancelledError("tool-cancelled")
648+
649+
model = FakeModel()
650+
agent = Agent(
651+
name="test",
652+
model=model,
653+
tools=[function_tool(_cancel_tool, name_override="cancel_tool")],
654+
)
655+
656+
model.add_multiple_turn_outputs(
657+
[
658+
[get_function_tool_call("cancel_tool", "{}", call_id="call_cancel")],
659+
[get_text_message("final answer")],
660+
]
661+
)
662+
663+
result = await Runner.run(agent, input="user_message")
664+
665+
assert result.final_output == "final answer"
666+
assert len(result.raw_responses) == 2
667+
668+
second_turn_input = cast(list[dict[str, Any]], model.last_turn_args["input"])
669+
tool_outputs = [
670+
item for item in second_turn_input if item.get("type") == "function_call_output"
671+
]
672+
assert tool_outputs == [
673+
{
674+
"call_id": "call_cancel",
675+
"output": (
676+
"An error occurred while running the tool. Please try again. Error: tool-cancelled"
677+
),
678+
"type": "function_call_output",
679+
},
680+
]
681+
682+
644683
@pytest.mark.asyncio
645684
async def test_reasoning_item_id_policy_omits_follow_up_reasoning_ids() -> None:
646685
model = FakeModel()

tests/test_agent_runner_streamed.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,46 @@ async def _cancel_tool() -> str:
496496
]
497497

498498

499+
@pytest.mark.asyncio
500+
async def test_streamed_single_tool_call_with_cancelled_tool_reaches_final_output() -> None:
501+
async def _cancel_tool() -> str:
502+
raise asyncio.CancelledError("tool-cancelled")
503+
504+
model = FakeModel()
505+
agent = Agent(
506+
name="test",
507+
model=model,
508+
tools=[function_tool(_cancel_tool, name_override="cancel_tool")],
509+
)
510+
511+
model.add_multiple_turn_outputs(
512+
[
513+
[get_function_tool_call("cancel_tool", "{}", call_id="call_cancel")],
514+
[get_text_message("final answer")],
515+
]
516+
)
517+
518+
result = Runner.run_streamed(agent, input="user_message")
519+
await consume_stream(result)
520+
521+
assert result.final_output == "final answer"
522+
assert len(result.raw_responses) == 2
523+
524+
second_turn_input = cast(list[dict[str, Any]], model.last_turn_args["input"])
525+
tool_outputs = [
526+
item for item in second_turn_input if item.get("type") == "function_call_output"
527+
]
528+
assert tool_outputs == [
529+
{
530+
"call_id": "call_cancel",
531+
"output": (
532+
"An error occurred while running the tool. Please try again. Error: tool-cancelled"
533+
),
534+
"type": "function_call_output",
535+
},
536+
]
537+
538+
499539
@pytest.mark.asyncio
500540
async def test_streamed_reasoning_item_id_policy_omits_follow_up_reasoning_ids() -> None:
501541
model = FakeModel()

tests/test_run_step_execution.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -740,6 +740,29 @@ async def _manual_on_invoke_tool(_ctx: ToolContext[Any], _args: str) -> str:
740740
)
741741

742742

743+
@pytest.mark.asyncio
744+
async def test_single_tool_call_uses_default_failure_error_function_for_cancelled_tool():
745+
async def _cancel_tool() -> str:
746+
raise asyncio.CancelledError("tool-cancelled")
747+
748+
cancel_tool = function_tool(_cancel_tool, name_override="cancel_tool")
749+
agent = Agent(name="test", tools=[cancel_tool])
750+
response = ModelResponse(
751+
output=[get_function_tool_call("cancel_tool", "{}", call_id="1")],
752+
usage=Usage(),
753+
response_id=None,
754+
)
755+
756+
result = await get_execute_result(agent, response)
757+
758+
assert len(result.generated_items) == 2
759+
assert isinstance(result.next_step, NextStepRunAgain)
760+
assert_item_is_function_tool_call_output(
761+
result.generated_items[1],
762+
"An error occurred while running the tool. Please try again. Error: tool-cancelled",
763+
)
764+
765+
743766
@pytest.mark.asyncio
744767
async def test_multiple_tool_calls_surface_hook_failure_over_sibling_cancellation():
745768
hook_started = asyncio.Event()

0 commit comments

Comments
 (0)