Skip to content

Commit 4b63ed7

Browse files
authored
fix: restore v0.7.0 constructor compatibility for RunResult types (#2414)
1 parent 01f9c37 commit 4b63ed7

3 files changed

Lines changed: 226 additions & 17 deletions

File tree

src/agents/result.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import copy
66
import weakref
77
from collections.abc import AsyncIterator
8-
from dataclasses import dataclass, field
8+
from dataclasses import InitVar, dataclass, field
99
from typing import Any, Literal, TypeVar, cast
1010

1111
from .agent import Agent
@@ -76,8 +76,9 @@ def _populate_state_from_result(
7676
state._previous_response_id = previous_response_id
7777
state._auto_previous_response_id = auto_previous_response_id
7878

79-
if result.interruptions:
80-
state._current_step = NextStepInterruption(interruptions=result.interruptions)
79+
interruptions = list(getattr(result, "interruptions", []))
80+
if interruptions:
81+
state._current_step = NextStepInterruption(interruptions=interruptions)
8182

8283
trace_state = getattr(result, "_trace_state", None)
8384
if trace_state is None:
@@ -120,9 +121,6 @@ class RunResultBase(abc.ABC):
120121
context_wrapper: RunContextWrapper[Any]
121122
"""The context wrapper for the agent run."""
122123

123-
interruptions: list[ToolApprovalItem]
124-
"""Pending tool approval requests (interruptions) for this run."""
125-
126124
_trace_state: TraceState | None = field(default=None, init=False, repr=False)
127125
"""Serialized trace metadata captured during the run."""
128126

@@ -228,6 +226,8 @@ class RunResult(RunResultBase):
228226
"""Whether automatic previous response tracking was enabled."""
229227
max_turns: int = 10
230228
"""The maximum number of turns allowed for this run."""
229+
interruptions: list[ToolApprovalItem] = field(default_factory=list)
230+
"""Pending tool approval requests (interruptions) for this run."""
231231

232232
def __post_init__(self) -> None:
233233
self._last_agent_ref = weakref.ref(self._last_agent)
@@ -337,8 +337,6 @@ class RunResultStreaming(RunResultBase):
337337
repr=False,
338338
default=None,
339339
)
340-
_last_processed_response: ProcessedResponse | None = field(default=None, repr=False)
341-
"""The last processed model response. This is needed for resuming from interruptions."""
342340

343341
_model_input_items: list[RunItem] = field(default_factory=list, repr=False)
344342
"""Filtered items used to build model input between streaming turns."""
@@ -356,6 +354,11 @@ class RunResultStreaming(RunResultBase):
356354
_input_guardrails_task: asyncio.Task[Any] | None = field(default=None, repr=False)
357355
_output_guardrails_task: asyncio.Task[Any] | None = field(default=None, repr=False)
358356
_stored_exception: Exception | None = field(default=None, repr=False)
357+
_cancel_mode: Literal["none", "immediate", "after_turn"] = field(default="none", repr=False)
358+
_last_processed_response: ProcessedResponse | None = field(default=None, repr=False)
359+
"""The last processed model response. This is needed for resuming from interruptions."""
360+
interruptions: list[ToolApprovalItem] = field(default_factory=list)
361+
"""Pending tool approval requests (interruptions) for this run."""
359362
_waiting_on_event_queue: bool = field(default=False, repr=False)
360363

361364
_current_turn_persisted_item_count: int = 0
@@ -369,9 +372,6 @@ class RunResultStreaming(RunResultBase):
369372
"""Original turn input before session history was merged, used for
370373
persistence (matches JS sessionInputOriginalSnapshot)."""
371374

372-
# Soft cancel state
373-
_cancel_mode: Literal["none", "immediate", "after_turn"] = field(default="none", repr=False)
374-
375375
_max_turns_handled: bool = field(default=False, repr=False)
376376

377377
_original_input: str | list[TResponseInputItem] | None = field(default=None, repr=False)
@@ -386,12 +386,16 @@ class RunResultStreaming(RunResultBase):
386386
"""Response identifier returned by the server for the last turn."""
387387
_auto_previous_response_id: bool = field(default=False, repr=False)
388388
"""Whether automatic previous response tracking was enabled."""
389+
_run_impl_task: InitVar[asyncio.Task[Any] | None] = None
389390

390-
def __post_init__(self) -> None:
391+
def __post_init__(self, _run_impl_task: asyncio.Task[Any] | None) -> None:
391392
self._current_agent_ref = weakref.ref(self.current_agent)
392393
# Store the original input at creation time (it will be set via input field)
393394
if self._original_input is None:
394395
self._original_input = self.input
396+
# Compatibility shim: accept legacy `_run_impl_task` constructor keyword.
397+
if self.run_loop_task is None and _run_impl_task is not None:
398+
self.run_loop_task = _run_impl_task
395399

396400
@property
397401
def last_agent(self) -> Agent[Any]:

src/agents/tool_context.py

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
1+
from __future__ import annotations
2+
13
from dataclasses import dataclass, field, fields
2-
from typing import Any, Optional
4+
from typing import TYPE_CHECKING, Any, cast
35

46
from openai.types.responses import ResponseFunctionToolCall
57

68
from .run_context import RunContextWrapper, TContext
9+
from .usage import Usage
10+
11+
if TYPE_CHECKING:
12+
from .items import TResponseInputItem
13+
from .run_context import _ApprovalRecord
714

815

916
def _assert_must_pass_tool_call_id() -> str:
@@ -18,6 +25,9 @@ def _assert_must_pass_tool_arguments() -> str:
1825
raise ValueError("tool_arguments must be passed to ToolContext")
1926

2027

28+
_MISSING = object()
29+
30+
2131
@dataclass
2232
class ToolContext(RunContextWrapper[TContext]):
2333
"""The context of a tool call."""
@@ -31,16 +41,53 @@ class ToolContext(RunContextWrapper[TContext]):
3141
tool_arguments: str = field(default_factory=_assert_must_pass_tool_arguments)
3242
"""The raw arguments string of the tool call."""
3343

34-
tool_call: Optional[ResponseFunctionToolCall] = None
44+
tool_call: ResponseFunctionToolCall | None = None
3545
"""The tool call object associated with this invocation."""
3646

47+
def __init__(
48+
self,
49+
context: TContext,
50+
usage: Usage | object = _MISSING,
51+
tool_name: str | object = _MISSING,
52+
tool_call_id: str | object = _MISSING,
53+
tool_arguments: str | object = _MISSING,
54+
tool_call: ResponseFunctionToolCall | None = None,
55+
*,
56+
turn_input: list[TResponseInputItem] | None = None,
57+
_approvals: dict[str, _ApprovalRecord] | None = None,
58+
tool_input: Any | None = None,
59+
) -> None:
60+
"""Preserve the v0.7 positional constructor while accepting new context fields."""
61+
resolved_usage = Usage() if usage is _MISSING else cast(Usage, usage)
62+
super().__init__(
63+
context=context,
64+
usage=resolved_usage,
65+
turn_input=list(turn_input or []),
66+
_approvals={} if _approvals is None else _approvals,
67+
tool_input=tool_input,
68+
)
69+
self.tool_name = (
70+
_assert_must_pass_tool_name() if tool_name is _MISSING else cast(str, tool_name)
71+
)
72+
self.tool_arguments = (
73+
_assert_must_pass_tool_arguments()
74+
if tool_arguments is _MISSING
75+
else cast(str, tool_arguments)
76+
)
77+
self.tool_call_id = (
78+
_assert_must_pass_tool_call_id()
79+
if tool_call_id is _MISSING
80+
else cast(str, tool_call_id)
81+
)
82+
self.tool_call = tool_call
83+
3784
@classmethod
3885
def from_agent_context(
3986
cls,
4087
context: RunContextWrapper[TContext],
4188
tool_call_id: str,
42-
tool_call: Optional[ResponseFunctionToolCall] = None,
43-
) -> "ToolContext":
89+
tool_call: ResponseFunctionToolCall | None = None,
90+
) -> ToolContext:
4491
"""
4592
Create a ToolContext from a RunContextWrapper.
4693
"""

tests/test_source_compat_constructors.py

Lines changed: 159 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,19 @@
11
from __future__ import annotations
22

3-
from typing import Any
3+
import asyncio
4+
from typing import Any, cast
45

56
from agents import (
7+
Agent,
68
AgentHookContext,
79
FunctionTool,
810
HandoffInputData,
911
ItemHelpers,
1012
MultiProvider,
1113
RunConfig,
14+
RunContextWrapper,
15+
RunResult,
16+
RunResultStreaming,
1217
ToolGuardrailFunctionOutput,
1318
ToolInputGuardrailData,
1419
ToolOutputGuardrailData,
@@ -68,3 +73,156 @@ def test_agent_hook_context_third_positional_argument_is_turn_input() -> None:
6873

6974
assert context.turn_input == turn_input
7075
assert isinstance(context._approvals, dict)
76+
77+
78+
def test_tool_context_v070_positional_constructor_still_works() -> None:
79+
usage = Usage()
80+
context = ToolContext(None, usage, "tool_name", "call_id", '{"x":1}', None)
81+
82+
assert context.usage is usage
83+
assert context.tool_name == "tool_name"
84+
assert context.tool_call_id == "call_id"
85+
assert context.tool_arguments == '{"x":1}'
86+
87+
88+
def test_run_result_v070_positional_constructor_still_works() -> None:
89+
result = RunResult(
90+
"x",
91+
[],
92+
[],
93+
"ok",
94+
[],
95+
[],
96+
[],
97+
[],
98+
RunContextWrapper(context=None),
99+
Agent(name="agent"),
100+
)
101+
assert result.final_output == "ok"
102+
assert result.interruptions == []
103+
104+
105+
def test_run_result_streaming_v070_positional_constructor_still_works() -> None:
106+
result = RunResultStreaming(
107+
"x",
108+
[],
109+
[],
110+
"ok",
111+
[],
112+
[],
113+
[],
114+
[],
115+
RunContextWrapper(context=None),
116+
Agent(name="agent"),
117+
0,
118+
1,
119+
None,
120+
None,
121+
)
122+
assert result.final_output == "ok"
123+
assert result.interruptions == []
124+
125+
126+
def test_run_result_streaming_v070_optional_positional_constructor_still_works() -> None:
127+
event_queue: asyncio.Queue[Any] = asyncio.Queue()
128+
input_guardrail_queue: asyncio.Queue[Any] = asyncio.Queue()
129+
result = RunResultStreaming(
130+
"x",
131+
[],
132+
[],
133+
"ok",
134+
[],
135+
[],
136+
[],
137+
[],
138+
RunContextWrapper(context=None),
139+
Agent(name="agent"),
140+
0,
141+
1,
142+
None,
143+
None,
144+
True,
145+
[],
146+
event_queue,
147+
input_guardrail_queue,
148+
None,
149+
)
150+
assert result.is_complete is True
151+
assert result.run_loop_task is None
152+
assert result._event_queue is event_queue
153+
assert result._input_guardrail_queue is input_guardrail_queue
154+
assert result.interruptions == []
155+
156+
157+
def test_run_result_streaming_accepts_legacy_run_impl_task_keyword() -> None:
158+
sentinel_task = cast(Any, object())
159+
result = RunResultStreaming(
160+
input="x",
161+
new_items=[],
162+
raw_responses=[],
163+
final_output="ok",
164+
input_guardrail_results=[],
165+
output_guardrail_results=[],
166+
tool_input_guardrail_results=[],
167+
tool_output_guardrail_results=[],
168+
context_wrapper=RunContextWrapper(context=None),
169+
current_agent=Agent(name="agent"),
170+
current_turn=0,
171+
max_turns=1,
172+
_current_agent_output_schema=None,
173+
trace=None,
174+
_run_impl_task=sentinel_task,
175+
)
176+
assert result.run_loop_task is sentinel_task
177+
178+
179+
def test_run_result_streaming_accepts_run_loop_task_keyword() -> None:
180+
sentinel_task = cast(Any, object())
181+
result = RunResultStreaming(
182+
input="x",
183+
new_items=[],
184+
raw_responses=[],
185+
final_output="ok",
186+
input_guardrail_results=[],
187+
output_guardrail_results=[],
188+
tool_input_guardrail_results=[],
189+
tool_output_guardrail_results=[],
190+
context_wrapper=RunContextWrapper(context=None),
191+
current_agent=Agent(name="agent"),
192+
current_turn=0,
193+
max_turns=1,
194+
_current_agent_output_schema=None,
195+
trace=None,
196+
run_loop_task=sentinel_task,
197+
)
198+
assert result.run_loop_task is sentinel_task
199+
200+
201+
def test_run_result_streaming_v070_run_impl_task_positional_binding_is_preserved() -> None:
202+
sentinel_task = cast(Any, object())
203+
event_queue: asyncio.Queue[Any] = asyncio.Queue()
204+
input_guardrail_queue: asyncio.Queue[Any] = asyncio.Queue()
205+
result = RunResultStreaming(
206+
"x",
207+
[],
208+
[],
209+
"ok",
210+
[],
211+
[],
212+
[],
213+
[],
214+
RunContextWrapper(context=None),
215+
Agent(name="agent"),
216+
0,
217+
1,
218+
None,
219+
None,
220+
False,
221+
[],
222+
event_queue,
223+
input_guardrail_queue,
224+
sentinel_task,
225+
)
226+
assert result._event_queue is event_queue
227+
assert result._input_guardrail_queue is input_guardrail_queue
228+
assert result.run_loop_task is sentinel_task

0 commit comments

Comments
 (0)