forked from openai/openai-agents-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtool_context.py
More file actions
183 lines (156 loc) · 6.51 KB
/
tool_context.py
File metadata and controls
183 lines (156 loc) · 6.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
from __future__ import annotations
from dataclasses import dataclass, field, fields
from typing import TYPE_CHECKING, Any, cast
from openai.types.responses import ResponseFunctionToolCall
from ._tool_identity import get_tool_call_namespace, tool_trace_name
from .agent_tool_state import get_agent_tool_state_scope, set_agent_tool_state_scope
from .run_context import RunContextWrapper, TContext
from .usage import Usage
if TYPE_CHECKING:
from .agent import AgentBase
from .items import TResponseInputItem
from .run_config import RunConfig
from .run_context import _ApprovalRecord
def _assert_must_pass_tool_call_id() -> str:
raise ValueError("tool_call_id must be passed to ToolContext")
def _assert_must_pass_tool_name() -> str:
raise ValueError("tool_name must be passed to ToolContext")
def _assert_must_pass_tool_arguments() -> str:
raise ValueError("tool_arguments must be passed to ToolContext")
# Unique sentinel used as the default for constructor arguments so that
# "argument omitted" can be detected with an identity check (`is _MISSING`).
_MISSING = object()
@dataclass
class ToolContext(RunContextWrapper[TContext]):
    """Run context handed to a tool for a single tool call.

    Extends ``RunContextWrapper`` with the identity of the call (name, call ID,
    raw arguments, optional namespace) and, when available, the active agent
    and run config.
    """

    tool_name: str = field(default_factory=_assert_must_pass_tool_name)
    """The name of the tool being invoked."""

    tool_call_id: str = field(default_factory=_assert_must_pass_tool_call_id)
    """The ID of the tool call."""

    tool_arguments: str = field(default_factory=_assert_must_pass_tool_arguments)
    """The raw arguments string of the tool call."""

    tool_call: ResponseFunctionToolCall | None = None
    """The tool call object associated with this invocation."""

    tool_namespace: str | None = None
    """The Responses API namespace for this tool call, when present."""

    agent: AgentBase[Any] | None = None
    """The active agent for this tool call, when available."""

    run_config: RunConfig | None = None
    """The active run config for this tool call, when available."""

    def __init__(
        self,
        context: TContext,
        usage: Usage | object = _MISSING,
        tool_name: str | object = _MISSING,
        tool_call_id: str | object = _MISSING,
        tool_arguments: str | object = _MISSING,
        tool_call: ResponseFunctionToolCall | None = None,
        *,
        tool_namespace: str | None = None,
        agent: AgentBase[Any] | None = None,
        run_config: RunConfig | None = None,
        turn_input: list[TResponseInputItem] | None = None,
        _approvals: dict[str, _ApprovalRecord] | None = None,
        tool_input: Any | None = None,
        _generated_items: list[Any] | None = None,
    ) -> None:
        """Preserve the v0.7 positional constructor while accepting new context fields.

        Args:
            context: The user context object forwarded to the base wrapper.
            usage: Usage tracker; a fresh ``Usage()`` is created when omitted.
            tool_name: Name of the tool. Required.
            tool_call_id: ID of the tool call. Required.
            tool_arguments: Raw arguments string. Required.
            tool_call: The originating tool call object, if any.
            tool_namespace: Explicit namespace. When it is not a string, the
                namespace is derived from ``tool_call`` via
                ``get_tool_call_namespace``.
            agent: The active agent, if known.
            run_config: The active run config, if known.
            turn_input: Forwarded to the base wrapper as a new list.
            _approvals: Forwarded to the base wrapper as-is (not copied); an
                empty dict is used when ``None``.
            tool_input: Forwarded to the base wrapper unchanged.
            _generated_items: Forwarded to the base wrapper as a new list.

        Raises:
            ValueError: If ``tool_name``, ``tool_arguments`` or ``tool_call_id``
                is omitted (checked in that order, after the base initializer
                has run).
        """
        if usage is _MISSING:
            usage = Usage()

        super().__init__(
            context=context,
            usage=cast(Usage, usage),
            turn_input=list(turn_input) if turn_input else [],
            _approvals=_approvals if _approvals is not None else {},
            tool_input=tool_input,
            _generated_items=list(_generated_items) if _generated_items else [],
        )

        # Required values are validated one at a time, each immediately before
        # it is stored, so the first missing one (name, arguments, call id)
        # determines the error.
        if tool_name is _MISSING:
            tool_name = _assert_must_pass_tool_name()
        self.tool_name = cast(str, tool_name)

        if tool_arguments is _MISSING:
            tool_arguments = _assert_must_pass_tool_arguments()
        self.tool_arguments = cast(str, tool_arguments)

        if tool_call_id is _MISSING:
            tool_call_id = _assert_must_pass_tool_call_id()
        self.tool_call_id = cast(str, tool_call_id)

        self.tool_call = tool_call

        # An explicit string namespace wins; otherwise derive it from the call.
        if isinstance(tool_namespace, str):
            self.tool_namespace = tool_namespace
        else:
            self.tool_namespace = get_tool_call_namespace(tool_call)

        self.agent = agent
        self.run_config = run_config

    @property
    def conversation_history(self) -> list[Any]:
        """The items generated so far in the current agent run.

        This is a snapshot of the conversation history at the time the tool was
        invoked, allowing tools to inspect prior messages, tool calls, and other
        items produced during the run. A new list is returned each time, so
        mutating it does not affect the run state.
        """
        snapshot = list(self._generated_items)
        return snapshot

    @property
    def qualified_tool_name(self) -> str:
        """Return the tool name qualified by namespace when available.

        Falls back to the plain ``tool_name`` when ``tool_trace_name`` yields a
        falsy result.
        """
        qualified = tool_trace_name(self.tool_name, self.tool_namespace)
        return qualified if qualified else self.tool_name

    @classmethod
    def from_agent_context(
        cls,
        context: RunContextWrapper[TContext],
        tool_call_id: str,
        tool_call: ResponseFunctionToolCall | None = None,
        agent: AgentBase[Any] | None = None,
        *,
        tool_name: str | None = None,
        tool_arguments: str | None = None,
        tool_namespace: str | None = None,
        run_config: RunConfig | None = None,
    ) -> ToolContext:
        """Create a ToolContext from a RunContextWrapper.

        Args:
            context: Source wrapper; every ``init=True`` field of
                ``RunContextWrapper`` is read from it and passed to the new
                instance.
            tool_call_id: ID of the tool call.
            tool_call: The originating tool call; supplies the name, arguments
                and namespace when those are not given explicitly.
            agent: Active agent. When ``None`` and ``context`` is itself a
                ``ToolContext``, its ``agent`` is reused.
            tool_name: Explicit tool name; overrides ``tool_call.name``.
            tool_arguments: Explicit raw arguments; overrides
                ``tool_call.arguments``.
            tool_namespace: Explicit namespace; otherwise a string
                ``namespace`` attribute on ``tool_call`` is used, if present.
            run_config: Active run config. When ``None`` and ``context`` is a
                ``ToolContext``, its ``run_config`` is reused.

        Returns:
            The new instance, after ``set_agent_tool_state_scope`` has been
            called on it with the result of
            ``get_agent_tool_state_scope(context)``.

        Raises:
            ValueError: If neither an explicit value nor ``tool_call`` provides
                the tool name or the tool arguments.
        """
        # Collect the RunContextWrapper init=True field values from the source.
        inherited: dict[str, Any] = {}
        for spec in fields(RunContextWrapper):
            if spec.init:
                inherited[spec.name] = getattr(context, spec.name)

        if tool_name is not None:
            resolved_name = tool_name
        elif tool_call is not None:
            resolved_name = tool_call.name
        else:
            resolved_name = _assert_must_pass_tool_name()

        if tool_arguments is not None:
            resolved_arguments = tool_arguments
        elif tool_call is not None:
            resolved_arguments = tool_call.arguments
        else:
            resolved_arguments = _assert_must_pass_tool_arguments()

        # Fall back to the parent tool context for values the caller left out.
        parent = context if isinstance(context, ToolContext) else None

        resolved_agent = agent
        if resolved_agent is None and parent is not None:
            resolved_agent = parent.agent

        resolved_run_config = run_config
        if resolved_run_config is None and parent is not None:
            resolved_run_config = parent.run_config

        resolved_namespace: str | None
        if isinstance(tool_namespace, str):
            resolved_namespace = tool_namespace
        else:
            call_namespace = (
                getattr(tool_call, "namespace", None) if tool_call is not None else None
            )
            resolved_namespace = call_namespace if isinstance(call_namespace, str) else None

        new_context = cls(
            tool_name=resolved_name,
            tool_call_id=tool_call_id,
            tool_arguments=resolved_arguments,
            tool_call=tool_call,
            tool_namespace=resolved_namespace,
            agent=resolved_agent,
            run_config=resolved_run_config,
            **inherited,
        )

        scope = get_agent_tool_state_scope(context)
        set_agent_tool_state_scope(new_context, scope)
        return new_context