Skip to content

Commit 85e5616

Browse files
fix(tool): resolve default failure handler at invoke time (#2460)
1 parent f1f88f8 commit 85e5616

2 files changed

Lines changed: 33 additions & 3 deletions

File tree

src/agents/tool.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
Callable,
1313
Generic,
1414
Literal,
15+
Optional,
1516
Protocol,
1617
TypeVar,
1718
Union,
@@ -765,6 +766,7 @@ def default_tool_error_function(ctx: RunContextWrapper[Any], error: Exception) -
765766

766767

767768
ToolErrorFunction = Callable[[RunContextWrapper[Any], Exception], MaybeAwaitable[str]]
769+
_UNSET_FAILURE_ERROR_FUNCTION = object()
768770

769771

770772
@overload
@@ -813,7 +815,7 @@ def function_tool(
813815
description_override: str | None = None,
814816
docstring_style: DocstringStyle | None = None,
815817
use_docstring_info: bool = True,
816-
failure_error_function: ToolErrorFunction | None = default_tool_error_function,
818+
failure_error_function: ToolErrorFunction | None | object = _UNSET_FAILURE_ERROR_FUNCTION,
817819
strict_mode: bool = True,
818820
is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True,
819821
needs_approval: bool
@@ -923,10 +925,18 @@ async def _on_invoke_tool(ctx: ToolContext[Any], input: str) -> Any:
923925
try:
924926
return await _on_invoke_tool_impl(ctx, input)
925927
except Exception as e:
926-
if failure_error_function is None:
928+
resolved_failure_error_function: ToolErrorFunction | None
929+
if failure_error_function is _UNSET_FAILURE_ERROR_FUNCTION:
930+
resolved_failure_error_function = default_tool_error_function
931+
else:
932+
resolved_failure_error_function = cast(
933+
Optional[ToolErrorFunction], failure_error_function
934+
)
935+
936+
if resolved_failure_error_function is None:
927937
raise
928938

929-
result = failure_error_function(ctx, e)
939+
result = resolved_failure_error_function(ctx, e)
930940
if inspect.isawaitable(result):
931941
return await result
932942

tests/test_function_tool.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from pydantic import BaseModel
99
from typing_extensions import TypedDict
1010

11+
import agents.tool as tool_module
1112
from agents import (
1213
Agent,
1314
AgentBase,
@@ -448,6 +449,25 @@ def boom() -> None:
448449
assert result.startswith("handled:")
449450

450451

452+
@pytest.mark.asyncio
453+
async def test_default_failure_error_function_is_resolved_at_invoke_time(
454+
monkeypatch: pytest.MonkeyPatch,
455+
) -> None:
456+
def boom(a: int) -> None:
457+
raise ValueError(f"boom:{a}")
458+
459+
tool = function_tool(boom)
460+
461+
def patched_default(_ctx: RunContextWrapper[Any], error: Exception) -> str:
462+
return f"patched:{error}"
463+
464+
monkeypatch.setattr(tool_module, "default_tool_error_function", patched_default)
465+
466+
ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments='{"a": 7}')
467+
result = await tool.on_invoke_tool(ctx, '{"a": 7}')
468+
assert result == "patched:boom:7"
469+
470+
451471
def test_function_tool_accepts_guardrail_arguments():
452472
tool = function_tool(
453473
simple_function,

0 commit comments

Comments
 (0)