Skip to content

Commit 7351bf6

Browse files
authored
Finish reasoning summaries before text deltas (#2609)
1 parent a70a002 commit 7351bf6

File tree

2 files changed

+193
-51
lines changed

2 files changed

+193
-51
lines changed

src/agents/models/chatcmpl_stream_handler.py

Lines changed: 99 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from collections.abc import AsyncIterator
3+
from collections.abc import AsyncIterator, Iterator
44
from dataclasses import dataclass, field
55
from typing import Any
66

@@ -60,6 +60,8 @@ class StreamingState:
6060
text_content_index_and_output: tuple[int, ResponseOutputText] | None = None
6161
refusal_content_index_and_output: tuple[int, ResponseOutputRefusal] | None = None
6262
reasoning_content_index_and_output: tuple[int, ResponseReasoningItem] | None = None
63+
active_reasoning_summary_index: int | None = None
64+
reasoning_item_done: bool = False
6365
function_calls: dict[int, ResponseFunctionToolCall] = field(default_factory=dict)
6466
# Fields for real-time function call streaming
6567
function_call_streaming: dict[int, bool] = field(default_factory=dict)
@@ -82,6 +84,67 @@ def get_and_increment(self) -> int:
8284

8385

8486
class ChatCmplStreamHandler:
87+
@classmethod
def _finish_reasoning_summary_part(
    cls,
    state: StreamingState,
    sequence_number: SequenceNumber,
) -> Iterator[TResponseStreamEvent]:
    """Close the currently active reasoning summary part, if there is one.

    Yields a single ``response.reasoning_summary_part.done`` event carrying the
    text accumulated for the active summary index, then clears
    ``state.active_reasoning_summary_index``. Nothing is yielded when no
    reasoning item is tracked, when no summary part is active, or when the
    active index does not point at an existing summary entry (in that last
    case the active index is still cleared).

    Args:
        state: Streaming state holding the reasoning item and active index.
        sequence_number: Counter used to stamp the emitted event.
    """
    tracked = state.reasoning_content_index_and_output
    active_index = state.active_reasoning_summary_index
    if not tracked or active_index is None:
        return

    summaries = tracked[1].summary
    part_exists = bool(summaries) and active_index < len(summaries)
    if part_exists:
        finished_part = DoneEventPart(
            text=summaries[active_index].text,
            type="summary_text",
        )
        yield ResponseReasoningSummaryPartDoneEvent(
            item_id=FAKE_RESPONSES_ID,
            output_index=0,
            summary_index=active_index,
            part=finished_part,
            type="response.reasoning_summary_part.done",
            sequence_number=sequence_number.get_and_increment(),
        )

    # Cleared only after the done event (if any) has been handed to the
    # consumer, so the part is no longer considered open from here on.
    state.active_reasoning_summary_index = None
117+
118+
@classmethod
def _finish_reasoning_item(
    cls,
    state: StreamingState,
    sequence_number: SequenceNumber,
) -> Iterator[TResponseStreamEvent]:
    """Emit the closing events for the reasoning output item exactly once.

    If the item has summary parts, any still-active part is closed via
    ``_finish_reasoning_summary_part``. Otherwise, if it carries reasoning
    text content, a ``response.reasoning_text.done`` event is emitted for the
    first content entry. In every case a ``response.output_item.done`` event
    for the reasoning item follows, and ``state.reasoning_item_done`` is set
    so later calls yield nothing.

    Args:
        state: Streaming state holding the reasoning item and its done flag.
        sequence_number: Counter used to stamp each emitted event.
    """
    if not state.reasoning_content_index_and_output or state.reasoning_item_done:
        return

    reasoning_item = state.reasoning_content_index_and_output[1]
    if reasoning_item.summary:
        yield from cls._finish_reasoning_summary_part(state, sequence_number)
    elif reasoning_item.content:
        # Truthiness check (not just ``is not None``): an empty content list
        # has no first entry, and indexing it would raise IndexError and
        # abort the stream before the item-level done event is sent.
        yield ResponseReasoningTextDoneEvent(
            item_id=FAKE_RESPONSES_ID,
            output_index=0,
            content_index=0,
            text=reasoning_item.content[0].text,
            type="response.reasoning_text.done",
            sequence_number=sequence_number.get_and_increment(),
        )

    yield ResponseOutputItemDoneEvent(
        item=reasoning_item,
        output_index=0,
        type="response.output_item.done",
        sequence_number=sequence_number.get_and_increment(),
    )
    # Marks the item as finished so a repeated call is a no-op.
    state.reasoning_item_done = True
147+
85148
@classmethod
86149
async def handle_stream(
87150
cls,
@@ -149,7 +212,7 @@ async def handle_stream(
149212
if reasoning_content and not state.reasoning_content_index_and_output:
150213
reasoning_item = ResponseReasoningItem(
151214
id=FAKE_RESPONSES_ID,
152-
summary=[Summary(text="", type="summary_text")],
215+
summary=[],
153216
type="reasoning",
154217
)
155218
if state.provider_data:
@@ -162,36 +225,37 @@ async def handle_stream(
162225
sequence_number=sequence_number.get_and_increment(),
163226
)
164227

165-
yield ResponseReasoningSummaryPartAddedEvent(
166-
item_id=FAKE_RESPONSES_ID,
167-
output_index=0,
168-
summary_index=0,
169-
part=AddedEventPart(text="", type="summary_text"),
170-
type="response.reasoning_summary_part.added",
171-
sequence_number=sequence_number.get_and_increment(),
172-
)
173-
174228
if reasoning_content and state.reasoning_content_index_and_output:
175-
# Ensure summary list has at least one element
176-
if not state.reasoning_content_index_and_output[1].summary:
177-
state.reasoning_content_index_and_output[1].summary = [
178-
Summary(text="", type="summary_text")
179-
]
229+
reasoning_item = state.reasoning_content_index_and_output[1]
230+
if state.active_reasoning_summary_index is None:
231+
summary_index = len(reasoning_item.summary)
232+
reasoning_item.summary.append(Summary(text="", type="summary_text"))
233+
state.active_reasoning_summary_index = summary_index
234+
235+
yield ResponseReasoningSummaryPartAddedEvent(
236+
item_id=FAKE_RESPONSES_ID,
237+
output_index=0,
238+
summary_index=summary_index,
239+
part=AddedEventPart(text="", type="summary_text"),
240+
type="response.reasoning_summary_part.added",
241+
sequence_number=sequence_number.get_and_increment(),
242+
)
243+
244+
summary_index = state.active_reasoning_summary_index
180245

181246
yield ResponseReasoningSummaryTextDeltaEvent(
182247
delta=reasoning_content,
183248
item_id=FAKE_RESPONSES_ID,
184249
output_index=0,
185-
summary_index=0,
250+
summary_index=summary_index,
186251
type="response.reasoning_summary_text.delta",
187252
sequence_number=sequence_number.get_and_increment(),
188253
)
189254

190-
# Create a new summary with updated text
191-
current_content = state.reasoning_content_index_and_output[1].summary[0]
255+
current_content = reasoning_item.summary[summary_index]
192256
updated_text = current_content.text + reasoning_content
193257
new_content = Summary(text=updated_text, type="summary_text")
194-
state.reasoning_content_index_and_output[1].summary[0] = new_content
258+
reasoning_item.summary[summary_index] = new_content
195259

196260
# Handle reasoning content from 3rd party platforms
197261
if hasattr(delta, "reasoning"):
@@ -233,6 +297,19 @@ async def handle_stream(
233297
new_text_content = Content(text=updated_text, type="reasoning_text")
234298
state.reasoning_content_index_and_output[1].content[0] = new_text_content
235299

300+
if (
301+
state.reasoning_content_index_and_output
302+
and state.active_reasoning_summary_index is not None
303+
and not (hasattr(delta, "reasoning_content") and delta.reasoning_content)
304+
and (
305+
delta.content is not None
306+
or (hasattr(delta, "refusal") and delta.refusal)
307+
or bool(delta.tool_calls)
308+
)
309+
):
310+
for event in cls._finish_reasoning_summary_part(state, sequence_number):
311+
yield event
312+
236313
# Handle regular content
237314
if delta.content is not None:
238315
if not state.text_content_index_and_output:
@@ -513,37 +590,8 @@ async def handle_stream(
513590
sequence_number=sequence_number.get_and_increment(),
514591
)
515592

516-
if state.reasoning_content_index_and_output:
517-
if (
518-
state.reasoning_content_index_and_output[1].summary
519-
and len(state.reasoning_content_index_and_output[1].summary) > 0
520-
):
521-
yield ResponseReasoningSummaryPartDoneEvent(
522-
item_id=FAKE_RESPONSES_ID,
523-
output_index=0,
524-
summary_index=0,
525-
part=DoneEventPart(
526-
text=state.reasoning_content_index_and_output[1].summary[0].text,
527-
type="summary_text",
528-
),
529-
type="response.reasoning_summary_part.done",
530-
sequence_number=sequence_number.get_and_increment(),
531-
)
532-
elif state.reasoning_content_index_and_output[1].content is not None:
533-
yield ResponseReasoningTextDoneEvent(
534-
item_id=FAKE_RESPONSES_ID,
535-
output_index=0,
536-
content_index=0,
537-
text=state.reasoning_content_index_and_output[1].content[0].text,
538-
type="response.reasoning_text.done",
539-
sequence_number=sequence_number.get_and_increment(),
540-
)
541-
yield ResponseOutputItemDoneEvent(
542-
item=state.reasoning_content_index_and_output[1],
543-
output_index=0,
544-
type="response.output_item.done",
545-
sequence_number=sequence_number.get_and_increment(),
546-
)
593+
for event in cls._finish_reasoning_item(state, sequence_number):
594+
yield event
547595

548596
function_call_starting_index = 0
549597
if state.reasoning_content_index_and_output:

tests/test_reasoning_content.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,18 @@ async def patched_fetch_response(self, *args, **kwargs):
142142
assert reasoning_delta_events[0].delta == "Let me think"
143143
assert reasoning_delta_events[1].delta == " about this"
144144

145+
reasoning_done_index = next(
146+
index
147+
for index, event in enumerate(output_events)
148+
if event.type == "response.reasoning_summary_part.done"
149+
)
150+
first_text_delta_index = next(
151+
index
152+
for index, event in enumerate(output_events)
153+
if event.type == "response.output_text.delta"
154+
)
155+
assert reasoning_done_index < first_text_delta_index
156+
145157
# verify regular content events were emitted
146158
content_delta_events = [e for e in output_events if e.type == "response.output_text.delta"]
147159
assert len(content_delta_events) == 2
@@ -163,6 +175,88 @@ async def patched_fetch_response(self, *args, **kwargs):
163175
assert response_event.response.output[1].content[0].text == "The answer is 42"
164176

165177

178+
@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
async def test_stream_response_keeps_reasoning_item_open_across_interleaved_text(
    monkeypatch,
) -> None:
    """Interleaved reasoning/text chunks yield two summary parts on one reasoning item.

    The first summary part must be closed before the first text delta, the
    second part's delta must arrive after that text delta, and the reasoning
    item's done event must come after the second reasoning delta.
    """
    # Reasoning and visible text alternate chunk by chunk.
    stream_chunks = [
        create_chunk(create_reasoning_delta("Let me think")),
        create_chunk(create_content_delta("The answer")),
        create_chunk(create_reasoning_delta(" more carefully")),
        create_chunk(create_content_delta(" is 42"), include_usage=True),
    ]

    async def patched_fetch_response(self, *args, **kwargs):
        return (
            Response(
                id="resp-id",
                created_at=0,
                model="fake-model",
                object="response",
                output=[],
                tool_choice="none",
                tools=[],
                parallel_tool_calls=False,
            ),
            create_fake_stream(stream_chunks),
        )

    monkeypatch.setattr(OpenAIChatCompletionsModel, "_fetch_response", patched_fetch_response)
    model = OpenAIProvider(use_responses=False).get_model("gpt-4")
    events = [
        event
        async for event in model.stream_response(
            system_instructions=None,
            input="",
            model_settings=ModelSettings(),
            tools=[],
            output_schema=None,
            handoffs=[],
            tracing=ModelTracing.DISABLED,
            previous_response_id=None,
            conversation_id=None,
            prompt=None,
        )
    ]

    def first_position(matches) -> int:
        # Position of the first streamed event satisfying the predicate.
        return next(pos for pos, ev in enumerate(events) if matches(ev))

    part_added = [ev for ev in events if ev.type == "response.reasoning_summary_part.added"]
    assert [ev.summary_index for ev in part_added] == [0, 1]

    part_done = [ev for ev in events if ev.type == "response.reasoning_summary_part.done"]
    assert [ev.summary_index for ev in part_done] == [0, 1]

    first_done_pos = events.index(part_done[0])
    first_text_pos = first_position(lambda ev: ev.type == "response.output_text.delta")
    second_reasoning_pos = first_position(
        lambda ev: ev.type == "response.reasoning_summary_text.delta" and ev.summary_index == 1
    )
    item_done_pos = first_position(
        lambda ev: ev.type == "response.output_item.done" and ev.item.type == "reasoning"
    )

    # Ordering: part 0 closes -> text starts -> part 1 streams -> item closes.
    assert first_done_pos < first_text_pos
    assert second_reasoning_pos > first_text_pos
    assert item_done_pos > second_reasoning_pos

    completed = events[-1]
    assert completed.type == "response.completed"
    reasoning_output = completed.response.output[0]
    assert isinstance(reasoning_output, ResponseReasoningItem)
    assert [part.text for part in reasoning_output.summary] == [
        "Let me think",
        " more carefully",
    ]
258+
259+
166260
@pytest.mark.allow_call_model_methods
167261
@pytest.mark.asyncio
168262
async def test_get_response_with_reasoning_content(monkeypatch) -> None:

0 commit comments

Comments
 (0)