|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +import ast |
3 | 4 | import asyncio |
4 | 5 | import copy |
5 | 6 | import dataclasses |
|
9 | 10 | import weakref |
10 | 11 | from collections.abc import Awaitable, Mapping |
11 | 12 | from dataclasses import dataclass, field |
| 13 | +from types import UnionType |
12 | 14 | from typing import ( |
13 | 15 | TYPE_CHECKING, |
| 16 | + Annotated, |
14 | 17 | Any, |
15 | 18 | Callable, |
16 | 19 | Generic, |
|
19 | 22 | TypeVar, |
20 | 23 | Union, |
21 | 24 | cast, |
| 25 | + get_args, |
| 26 | + get_origin, |
| 27 | + get_type_hints, |
22 | 28 | overload, |
23 | 29 | ) |
24 | 30 |
|
@@ -1373,19 +1379,131 @@ async def maybe_invoke_function_tool_failure_error_function( |
1373 | 1379 | return result |
1374 | 1380 |
|
1375 | 1381 |
|
| 1382 | +def _annotation_expr_name(expr: ast.expr) -> str | None: |
| 1383 | + """Return the unqualified type name for a string annotation expression node.""" |
| 1384 | + if isinstance(expr, ast.Name): |
| 1385 | + return expr.id |
| 1386 | + if isinstance(expr, ast.Attribute): |
| 1387 | + return expr.attr |
| 1388 | + return None |
| 1389 | + |
| 1390 | + |
| 1391 | +def _string_annotation_mentions_context_type(annotation: str, *, type_name: str) -> bool: |
| 1392 | + """Return True when a string annotation structurally references the given context type.""" |
| 1393 | + try: |
| 1394 | + expression = ast.parse(annotation, mode="eval").body |
| 1395 | + except SyntaxError: |
| 1396 | + return False |
| 1397 | + |
| 1398 | + return _annotation_expr_mentions_context_type(expression, type_name=type_name) |
| 1399 | + |
| 1400 | + |
| 1401 | +def _annotation_expr_mentions_context_type(expr: ast.expr, *, type_name: str) -> bool: |
| 1402 | + """Return True when an annotation expression structurally references the given context type.""" |
| 1403 | + if isinstance(expr, ast.Constant) and isinstance(expr.value, str): |
| 1404 | + return _string_annotation_mentions_context_type(expr.value, type_name=type_name) |
| 1405 | + |
| 1406 | + if _annotation_expr_name(expr) == type_name: |
| 1407 | + return True |
| 1408 | + |
| 1409 | + if isinstance(expr, ast.BinOp) and isinstance(expr.op, ast.BitOr): |
| 1410 | + return _annotation_expr_mentions_context_type( |
| 1411 | + expr.left, type_name=type_name |
| 1412 | + ) or _annotation_expr_mentions_context_type(expr.right, type_name=type_name) |
| 1413 | + |
| 1414 | + if isinstance(expr, ast.Subscript): |
| 1415 | + wrapper_name = _annotation_expr_name(expr.value) |
| 1416 | + args = expr.slice.elts if isinstance(expr.slice, ast.Tuple) else (expr.slice,) |
| 1417 | + |
| 1418 | + if wrapper_name == "Annotated": |
| 1419 | + return bool(args) and _annotation_expr_mentions_context_type( |
| 1420 | + args[0], type_name=type_name |
| 1421 | + ) |
| 1422 | + |
| 1423 | + if wrapper_name in {"Optional", "Union"}: |
| 1424 | + return any( |
| 1425 | + _annotation_expr_mentions_context_type(arg, type_name=type_name) for arg in args |
| 1426 | + ) |
| 1427 | + |
| 1428 | + return _annotation_expr_mentions_context_type(expr.value, type_name=type_name) |
| 1429 | + |
| 1430 | + return False |
| 1431 | + |
| 1432 | + |
| 1433 | +def _annotation_mentions_context_type(annotation: Any, *, context_type: type[Any]) -> bool: |
| 1434 | + """Return True when an annotation structurally references the given context type.""" |
| 1435 | + if annotation is inspect.Signature.empty: |
| 1436 | + return False |
| 1437 | + |
| 1438 | + if isinstance(annotation, str): |
| 1439 | + return _string_annotation_mentions_context_type(annotation, type_name=context_type.__name__) |
| 1440 | + |
| 1441 | + origin = get_origin(annotation) |
| 1442 | + |
| 1443 | + if annotation is context_type or origin is context_type: |
| 1444 | + return True |
| 1445 | + |
| 1446 | + if origin is Annotated: |
| 1447 | + args = get_args(annotation) |
| 1448 | + return bool(args) and _annotation_mentions_context_type(args[0], context_type=context_type) |
| 1449 | + |
| 1450 | + if origin in (Union, UnionType): |
| 1451 | + return any( |
| 1452 | + _annotation_mentions_context_type(arg, context_type=context_type) |
| 1453 | + for arg in get_args(annotation) |
| 1454 | + ) |
| 1455 | + |
| 1456 | + return False |
| 1457 | + |
| 1458 | + |
| 1459 | +def _get_function_tool_invoke_context( |
| 1460 | + function_tool: FunctionTool, |
| 1461 | + context: ToolContext[Any], |
| 1462 | +) -> ToolContext[Any] | RunContextWrapper[Any]: |
| 1463 | + """Choose the runtime context object to pass into a function tool wrapper. |
| 1464 | +
|
| 1465 | + Third-party wrappers may declare a narrower `RunContextWrapper` contract and then serialize |
| 1466 | + that object downstream. In those cases, passing the richer `ToolContext` can leak runtime-only |
| 1467 | + metadata such as agents or run config into incompatible serializers. When the wrapper |
| 1468 | + explicitly declares `RunContextWrapper`, preserve only the base context state. |
| 1469 | + """ |
| 1470 | + try: |
| 1471 | + parameters = tuple(inspect.signature(function_tool.on_invoke_tool).parameters.values()) |
| 1472 | + except (TypeError, ValueError): |
| 1473 | + return context |
| 1474 | + |
| 1475 | + if not parameters: |
| 1476 | + return context |
| 1477 | + |
| 1478 | + context_annotation = parameters[0].annotation |
| 1479 | + try: |
| 1480 | + resolved_annotations = get_type_hints(function_tool.on_invoke_tool, include_extras=True) |
| 1481 | + except Exception: |
| 1482 | + pass |
| 1483 | + else: |
| 1484 | + context_annotation = resolved_annotations.get(parameters[0].name, context_annotation) |
| 1485 | + |
| 1486 | + if _annotation_mentions_context_type(context_annotation, context_type=ToolContext): |
| 1487 | + return context |
| 1488 | + if _annotation_mentions_context_type(context_annotation, context_type=RunContextWrapper): |
| 1489 | + return context._fork_with_tool_input(context.tool_input) |
| 1490 | + return context |
| 1491 | + |
| 1492 | + |
1376 | 1493 | async def invoke_function_tool( |
1377 | 1494 | *, |
1378 | 1495 | function_tool: FunctionTool, |
1379 | 1496 | context: ToolContext[Any], |
1380 | 1497 | arguments: str, |
1381 | 1498 | ) -> Any: |
1382 | 1499 | """Invoke a function tool, enforcing timeout configuration when provided.""" |
| 1500 | + invoke_context = _get_function_tool_invoke_context(function_tool, context) |
1383 | 1501 | timeout_seconds = function_tool.timeout_seconds |
1384 | 1502 | if timeout_seconds is None: |
1385 | | - return await function_tool.on_invoke_tool(context, arguments) |
| 1503 | + return await function_tool.on_invoke_tool(cast(Any, invoke_context), arguments) |
1386 | 1504 |
|
1387 | 1505 | tool_task: asyncio.Future[Any] = asyncio.ensure_future( |
1388 | | - function_tool.on_invoke_tool(context, arguments) |
| 1506 | + function_tool.on_invoke_tool(cast(Any, invoke_context), arguments) |
1389 | 1507 | ) |
1390 | 1508 | try: |
1391 | 1509 | return await asyncio.wait_for(tool_task, timeout=timeout_seconds) |
|
0 commit comments