Skip to content

Commit 56ec673

Browse files
authored
fix: #2426 persist streamed run-again tool items to session (#2433)
1 parent f923b13 commit 56ec673

File tree

2 files changed

+42
-0
lines changed

2 files changed

+42
-0
lines changed

src/agents/run_internal/run_loop.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -946,6 +946,12 @@ async def _save_stream_items_without_count(
946946
if streamed_result._state is not None:
947947
streamed_result._state._current_step = NextStepRunAgain()
948948

949+
await _save_stream_items_with_count(
950+
turn_session_items,
951+
turn_result.model_response.response_id,
952+
store_setting,
953+
)
954+
949955
if streamed_result._cancel_mode == "after_turn": # type: ignore[comparison-overlap]
950956
streamed_result.is_complete = True
951957
streamed_result._event_queue.put_nowait(QueueCompleteSentinel())

tests/test_agent_runner_streamed.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,42 @@ async def test_tool_call_runs():
151151
)
152152

153153

154+
@pytest.mark.asyncio
155+
async def test_streamed_run_again_persists_tool_items_to_session():
156+
model = FakeModel()
157+
call_id = "call-session-run-again"
158+
agent = Agent(
159+
name="test",
160+
model=model,
161+
tools=[get_function_tool("foo", "tool_result")],
162+
)
163+
session = SimpleListSession()
164+
165+
model.add_multiple_turn_outputs(
166+
[
167+
[get_function_tool_call("foo", json.dumps({"a": "b"}), call_id=call_id)],
168+
[get_text_message("done")],
169+
]
170+
)
171+
172+
result = Runner.run_streamed(agent, input="user_message", session=session)
173+
await consume_stream(result)
174+
175+
saved_items = await session.get_items()
176+
assert any(
177+
isinstance(item, dict)
178+
and item.get("type") == "function_call"
179+
and item.get("call_id") == call_id
180+
for item in saved_items
181+
)
182+
assert any(
183+
isinstance(item, dict)
184+
and item.get("type") == "function_call_output"
185+
and item.get("call_id") == call_id
186+
for item in saved_items
187+
)
188+
189+
154190
@pytest.mark.asyncio
155191
async def test_handoffs():
156192
model = FakeModel()

0 commit comments

Comments
 (0)