|
12 | 12 | Callable, |
13 | 13 | Generic, |
14 | 14 | Literal, |
| 15 | + Optional, |
15 | 16 | Protocol, |
16 | 17 | TypeVar, |
17 | 18 | Union, |
@@ -765,6 +766,7 @@ def default_tool_error_function(ctx: RunContextWrapper[Any], error: Exception) - |
765 | 766 |
|
766 | 767 |
|
767 | 768 | ToolErrorFunction = Callable[[RunContextWrapper[Any], Exception], MaybeAwaitable[str]] |
| 769 | +_UNSET_FAILURE_ERROR_FUNCTION = object() |
768 | 770 |
|
769 | 771 |
|
770 | 772 | @overload |
@@ -813,7 +815,7 @@ def function_tool( |
813 | 815 | description_override: str | None = None, |
814 | 816 | docstring_style: DocstringStyle | None = None, |
815 | 817 | use_docstring_info: bool = True, |
816 | | - failure_error_function: ToolErrorFunction | None = default_tool_error_function, |
| 818 | + failure_error_function: ToolErrorFunction | None | object = _UNSET_FAILURE_ERROR_FUNCTION, |
817 | 819 | strict_mode: bool = True, |
818 | 820 | is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True, |
819 | 821 | needs_approval: bool |
@@ -923,10 +925,18 @@ async def _on_invoke_tool(ctx: ToolContext[Any], input: str) -> Any: |
923 | 925 | try: |
924 | 926 | return await _on_invoke_tool_impl(ctx, input) |
925 | 927 | except Exception as e: |
926 | | - if failure_error_function is None: |
| 928 | + resolved_failure_error_function: ToolErrorFunction | None |
| 929 | + if failure_error_function is _UNSET_FAILURE_ERROR_FUNCTION: |
| 930 | + resolved_failure_error_function = default_tool_error_function |
| 931 | + else: |
| 932 | + resolved_failure_error_function = cast( |
| 933 | + Optional[ToolErrorFunction], failure_error_function |
| 934 | + ) |
| 935 | + |
| 936 | + if resolved_failure_error_function is None: |
927 | 937 | raise |
928 | 938 |
|
929 | | - result = failure_error_function(ctx, e) |
| 939 | + result = resolved_failure_error_function(ctx, e) |
930 | 940 | if inspect.isawaitable(result): |
931 | 941 | return await result |
932 | 942 |
|
|
0 commit comments