Skip to content

Commit e88bcf2

Browse files
authored
fix: #2258 add normalized to_input_list mode for filtered handoff follow-ups (#2667)
1 parent 0cf38a6 commit e88bcf2

7 files changed

Lines changed: 258 additions & 12 deletions

File tree

src/agents/result.py

Lines changed: 58 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
)
3030
from .logger import logger
3131
from .run_context import RunContextWrapper
32-
from .run_internal.items import run_item_to_input_item
32+
from .run_internal.items import run_items_to_input_items
3333
from .run_internal.run_steps import (
3434
NextStepInterruption,
3535
ProcessedResponse,
@@ -110,6 +110,40 @@ def _populate_state_from_result(
110110
return state
111111

112112

113+
ToInputListMode = Literal["preserve_all", "normalized"]
114+
115+
116+
def _input_items_for_result(
117+
result: RunResultBase,
118+
*,
119+
mode: ToInputListMode,
120+
reasoning_item_id_policy: Literal["preserve", "omit"] | None,
121+
) -> list[TResponseInputItem]:
122+
"""Return input items for the requested result view.
123+
124+
``preserve_all`` keeps the full converted history from ``new_items``. ``normalized`` returns
125+
the canonical continuation input when handoff filtering rewrote model history, otherwise it
126+
falls back to the same converted history.
127+
"""
128+
session_items = run_items_to_input_items(result.new_items, reasoning_item_id_policy)
129+
if mode == "preserve_all":
130+
return session_items
131+
if mode != "normalized":
132+
raise ValueError(f"Unsupported to_input_list mode: {mode}")
133+
if not getattr(result, "_replay_from_model_input_items", False):
134+
# Most runs never rewrite continuation history, so normalized stays identical to the
135+
# historical preserve-all view unless the runner explicitly marked a divergence.
136+
return session_items
137+
138+
model_input_items = getattr(result, "_model_input_items", None)
139+
if not isinstance(model_input_items, list):
140+
return session_items
141+
142+
# When the runner marks a divergence, generated_items already reflect the continuation input
143+
# chosen for the next local run after applying handoff/input filtering.
144+
return run_items_to_input_items(model_input_items, reasoning_item_id_policy)
145+
146+
113147
@dataclass
114148
class RunResultBase(abc.ABC):
115149
input: str | list[TResponseInputItem]
@@ -145,6 +179,12 @@ class RunResultBase(abc.ABC):
145179

146180
_trace_state: TraceState | None = field(default=None, init=False, repr=False)
147181
"""Serialized trace metadata captured during the run."""
182+
_replay_from_model_input_items: bool = field(default=False, init=False, repr=False)
183+
"""Whether replay helpers should prefer `_model_input_items` over `new_items`.
184+
185+
This is only set when the runner preserved extra session history items that should not be
186+
replayed into the next local run, such as nested handoff history or filtered handoff input.
187+
"""
148188

149189
@classmethod
150190
def __get_pydantic_core_schema__(
@@ -208,18 +248,25 @@ def final_output_as(self, cls: type[T], raise_if_incorrect_type: bool = False) -
208248

209249
return cast(T, self.final_output)
210250

211-
def to_input_list(self) -> list[TResponseInputItem]:
212-
"""Creates a new input list, merging the original input with all the new items generated."""
251+
def to_input_list(
252+
self,
253+
*,
254+
mode: ToInputListMode = "preserve_all",
255+
) -> list[TResponseInputItem]:
256+
"""Create an input-item view of this run.
257+
258+
``mode="preserve_all"`` keeps the historical behavior of converting ``new_items`` into a
259+
full plain-item history. ``mode="normalized"`` prefers the canonical continuation input
260+
when handoff filtering rewrote model history, while remaining identical for ordinary runs.
261+
"""
213262
original_items: list[TResponseInputItem] = ItemHelpers.input_to_new_input_list(self.input)
214-
new_items: list[TResponseInputItem] = []
215263
reasoning_item_id_policy = getattr(self, "_reasoning_item_id_policy", None)
216-
for item in self.new_items:
217-
converted = run_item_to_input_item(item, reasoning_item_id_policy)
218-
if converted is None:
219-
continue
220-
new_items.append(converted)
221-
222-
return original_items + new_items
264+
replay_items = _input_items_for_result(
265+
self,
266+
mode=mode,
267+
reasoning_item_id_policy=reasoning_item_id_policy,
268+
)
269+
return original_items + replay_items
223270

224271
@property
225272
def agent_tool_invocation(self) -> AgentToolInvocation | None:

src/agents/run.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -798,6 +798,11 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult:
798798
)
799799
result._current_turn = current_turn
800800
result._model_input_items = list(generated_items)
801+
# Keep normalized replay aligned with the model-facing
802+
# continuation whenever session history preserved extra items.
803+
result._replay_from_model_input_items = list(
804+
generated_items
805+
) != list(session_items)
801806
if run_state is not None:
802807
result._trace_state = run_state._trace_state
803808
if session_persistence_enabled:
@@ -932,6 +937,9 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult:
932937
)
933938
result._current_turn = max_turns
934939
result._model_input_items = list(generated_items)
940+
result._replay_from_model_input_items = list(generated_items) != list(
941+
session_items
942+
)
935943
if run_state is not None:
936944
result._trace_state = run_state._trace_state
937945
if session_persistence_enabled and include_in_history:
@@ -1200,6 +1208,9 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult:
12001208
)
12011209
result._current_turn = current_turn
12021210
result._model_input_items = list(generated_items)
1211+
result._replay_from_model_input_items = list(generated_items) != list(
1212+
session_items
1213+
)
12031214
if run_state is not None:
12041215
result._current_turn_persisted_item_count = (
12051216
run_state._current_turn_persisted_item_count
@@ -1591,6 +1602,11 @@ def run_streamed(
15911602
streamed_result._model_input_items = (
15921603
list(run_state._generated_items) if run_state is not None else []
15931604
)
1605+
streamed_result._replay_from_model_input_items = (
1606+
list(run_state._generated_items) != list(run_state._session_items)
1607+
if run_state is not None
1608+
else False
1609+
)
15941610
streamed_result._reasoning_item_id_policy = resolved_reasoning_item_id_policy
15951611
if run_state is not None:
15961612
streamed_result._trace_state = run_state._trace_state

src/agents/run_internal/agent_runner_helpers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,7 @@ def build_interruption_result(
271271
)
272272
result._current_turn = current_turn
273273
result._model_input_items = list(generated_items)
274+
result._replay_from_model_input_items = list(generated_items) != list(session_items)
274275
if run_state is not None:
275276
result._current_turn_persisted_item_count = run_state._current_turn_persisted_item_count
276277
result._trace_state = run_state._trace_state

src/agents/run_internal/run_loop.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,11 @@ def _sync_conversation_tracking_from_tracker() -> None:
483483
streamed_result._state = run_state
484484
if run_state is not None:
485485
streamed_result._model_input_items = list(run_state._generated_items)
486+
# Streamed follow-ups need the same normalized replay signal as sync runs when the
487+
# runner's continuation differs from the richer session history.
488+
streamed_result._replay_from_model_input_items = list(run_state._generated_items) != list(
489+
run_state._session_items
490+
)
486491

487492
if run_state is not None:
488493
run_state._conversation_id = conversation_id
@@ -627,6 +632,9 @@ async def _save_stream_items_without_count(
627632
)
628633
streamed_result._model_input_items = generated_items
629634
streamed_result.new_items = base_session_items + list(turn_session_items)
635+
streamed_result._replay_from_model_input_items = list(
636+
streamed_result._model_input_items
637+
) != list(streamed_result.new_items)
630638
if run_state is not None:
631639
update_run_state_after_resume(
632640
run_state,
@@ -914,6 +922,9 @@ async def _save_stream_items_without_count(
914922
)
915923
turn_session_items = session_items_for_turn(turn_result)
916924
streamed_result.new_items.extend(turn_session_items)
925+
streamed_result._replay_from_model_input_items = list(
926+
streamed_result._model_input_items
927+
) != list(streamed_result.new_items)
917928
store_setting = current_agent.model_settings.resolve(
918929
run_config.model_settings
919930
).store

tests/test_agent_runner.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1146,6 +1146,10 @@ async def test_structured_output():
11461146
"should have input: conversation summary, function call, function call result, message, "
11471147
"handoff, handoff output, preamble message, tool call, tool call result, final output"
11481148
)
1149+
assert len(result.to_input_list(mode="normalized")) == 6, (
1150+
"should have normalized replay input: conversation summary, carried-forward message, "
1151+
"preamble message, tool call, tool call result, final output"
1152+
)
11491153

11501154
assert result.last_agent == agent_1, "should have handed off to agent_1"
11511155
assert result.final_output == Foo(bar="baz"), "should have structured output"

tests/test_agent_runner_streamed.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -669,6 +669,10 @@ async def test_structured_output():
669669
"should have input: conversation summary, function call, function call result, message, "
670670
"handoff, handoff output, preamble message, tool call, tool call result, final output"
671671
)
672+
assert len(result.to_input_list(mode="normalized")) == 6, (
673+
"should have normalized replay input: conversation summary, carried-forward message, "
674+
"preamble message, tool call, tool call result, final output"
675+
)
672676

673677
assert result.last_agent == agent_1, "should have handed off to agent_1"
674678
assert result.final_output == Foo(bar="baz"), "should have structured output"
@@ -1398,6 +1402,10 @@ async def test_streaming_events():
13981402
"should have input: conversation summary, function call, function call result, message, "
13991403
"handoff, handoff output, tool call, tool call result, final output"
14001404
)
1405+
assert len(result.to_input_list(mode="normalized")) == 5, (
1406+
"should have normalized replay input: conversation summary, carried-forward message, "
1407+
"tool call, tool call result, final output"
1408+
)
14011409

14021410
assert result.last_agent == agent_1, "should have handed off to agent_1"
14031411
assert result.final_output == Foo(bar="baz"), "should have structured output"

0 commit comments

Comments
 (0)