Skip to content

Commit 9073d5c

Browse files
authored
fix: #2664 drop orphan hosted shell calls before multi-turn replay (#2665)
1 parent cae28f0 commit 9073d5c

5 files changed

Lines changed: 314 additions & 48 deletions

File tree

src/agents/run_internal/items.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
"copy_input_items",
3434
"drop_orphan_function_calls",
3535
"ensure_input_item_format",
36+
"prepare_model_input_items",
3637
"run_item_to_input_item",
3738
"run_items_to_input_items",
3839
"normalize_input_items_for_api",
@@ -86,7 +87,11 @@ def run_items_to_input_items(
8687
return converted
8788

8889

89-
def drop_orphan_function_calls(items: list[TResponseInputItem]) -> list[TResponseInputItem]:
90+
def drop_orphan_function_calls(
91+
items: list[TResponseInputItem],
92+
*,
93+
pruning_indexes: set[int] | None = None,
94+
) -> list[TResponseInputItem]:
9095
"""
9196
Remove tool call items that do not have corresponding outputs so resumptions or retries do not
9297
replay stale tool calls.
@@ -108,6 +113,9 @@ def drop_orphan_function_calls(items: list[TResponseInputItem]) -> list[TRespons
108113
if output_type is None:
109114
filtered.append(entry)
110115
continue
116+
if pruning_indexes is not None and index not in pruning_indexes:
117+
filtered.append(entry)
118+
continue
111119
call_id = entry.get("call_id")
112120
if isinstance(call_id, str) and call_id in completed_call_ids.get(output_type, set()):
113121
filtered.append(entry)
@@ -145,6 +153,20 @@ def normalize_input_items_for_api(items: list[TResponseInputItem]) -> list[TResp
145153
return normalized
146154

147155

156+
def prepare_model_input_items(
157+
caller_items: Sequence[TResponseInputItem],
158+
generated_items: Sequence[TResponseInputItem] = (),
159+
) -> list[TResponseInputItem]:
160+
"""Normalize model input while pruning orphans only from runner-generated history."""
161+
normalized_caller_items = normalize_input_items_for_api(list(caller_items))
162+
if not generated_items:
163+
return normalized_caller_items
164+
165+
normalized_generated_items = normalize_input_items_for_api(list(generated_items))
166+
filtered_generated_items = drop_orphan_function_calls(normalized_generated_items)
167+
return normalized_caller_items + filtered_generated_items
168+
169+
148170
def normalize_resumed_input(
149171
raw_input: str | list[TResponseInputItem],
150172
) -> str | list[TResponseInputItem]:

src/agents/run_internal/oai_conversation.py

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
drop_orphan_function_calls,
2424
fingerprint_input_item,
2525
normalize_input_items_for_api,
26+
prepare_model_input_items,
2627
run_item_to_input_item,
2728
)
2829

@@ -153,8 +154,7 @@ def hydrate_from_state(
153154

154155
normalized_input = original_input
155156
if isinstance(original_input, list):
156-
normalized = normalize_input_items_for_api(original_input)
157-
normalized_input = drop_orphan_function_calls(normalized)
157+
normalized_input = prepare_model_input_items(original_input)
158158

159159
for item in ItemHelpers.input_to_new_input_list(normalized_input):
160160
if item is None:
@@ -404,13 +404,17 @@ def prepare_input(
404404
generated_items: list[RunItem],
405405
) -> list[TResponseInputItem]:
406406
"""Assemble the next model input while skipping duplicates and approvals."""
407-
input_items: list[TResponseInputItem] = []
407+
prepared_initial_items: list[TResponseInputItem] = []
408+
prepared_generated_items: list[TResponseInputItem] = []
409+
generated_item_sources: dict[int, TResponseInputItem] = {}
408410

409411
if not self.sent_initial_input:
410412
initial_items = ItemHelpers.input_to_new_input_list(original_input)
411-
input_items.extend(initial_items)
412-
for item in initial_items:
413-
self._register_prepared_item_source(item)
413+
prepared_initial_items = normalize_input_items_for_api(initial_items)
414+
for prepared_item, source_item in zip(
415+
prepared_initial_items, initial_items, strict=False
416+
):
417+
self._register_prepared_item_source(prepared_item, source_item)
414418
filtered_initials = []
415419
for item in initial_items:
416420
if item is None or isinstance(item, (str, bytes)):
@@ -419,9 +423,11 @@ def prepare_input(
419423
self.remaining_initial_input = filtered_initials or None
420424
self.sent_initial_input = True
421425
elif self.remaining_initial_input:
422-
input_items.extend(self.remaining_initial_input)
423-
for item in self.remaining_initial_input:
424-
self._register_prepared_item_source(item)
426+
prepared_initial_items = normalize_input_items_for_api(self.remaining_initial_input)
427+
for prepared_item, source_item in zip(
428+
prepared_initial_items, self.remaining_initial_input, strict=False
429+
):
430+
self._register_prepared_item_source(prepared_item, source_item)
425431

426432
for item in generated_items: # type: ignore[assignment]
427433
run_item: RunItem = cast(RunItem, item)
@@ -474,13 +480,23 @@ def prepare_input(
474480
):
475481
continue
476482

477-
input_items.append(converted_input_item)
478-
self._register_prepared_item_source(
479-
converted_input_item,
480-
cast(TResponseInputItem, raw_item),
481-
)
483+
prepared_generated_items.append(converted_input_item)
484+
generated_item_sources[id(converted_input_item)] = cast(TResponseInputItem, raw_item)
482485

483-
return input_items
486+
normalized_generated_items = normalize_input_items_for_api(prepared_generated_items)
487+
normalized_generated_sources = {
488+
id(normalized_item): generated_item_sources[id(source_item)]
489+
for normalized_item, source_item in zip(
490+
normalized_generated_items, prepared_generated_items, strict=False
491+
)
492+
}
493+
filtered_generated_items = drop_orphan_function_calls(normalized_generated_items)
494+
for item in filtered_generated_items:
495+
prepared_source_item = normalized_generated_sources.get(id(item))
496+
if prepared_source_item is not None:
497+
self._register_prepared_item_source(item, prepared_source_item)
498+
499+
return prepared_initial_items + filtered_generated_items
484500

485501
def _register_prepared_item_source(
486502
self, prepared_item: TResponseInputItem, source_item: TResponseInputItem | None = None

src/agents/run_internal/run_loop.py

Lines changed: 16 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,7 @@
6969
from ..usage import Usage
7070
from ..util import _coro, _error_tracing
7171
from .agent_runner_helpers import apply_resumed_conversation_settings
72-
from .approvals import (
73-
append_input_items_excluding_approvals,
74-
approvals_from_step,
75-
)
72+
from .approvals import approvals_from_step
7673
from .error_handlers import (
7774
build_run_error_data,
7875
create_message_output_item,
@@ -93,8 +90,9 @@
9390
copy_input_items,
9491
deduplicate_input_items_preferring_latest,
9592
ensure_input_item_format,
96-
normalize_input_items_for_api,
9793
normalize_resumed_input,
94+
prepare_model_input_items,
95+
run_items_to_input_items,
9896
)
9997
from .model_retry import (
10098
apply_retry_attempt_usage,
@@ -244,6 +242,16 @@ async def _should_persist_stream_items(
244242
return should_skip_session_save is False
245243

246244

245+
def _prepare_turn_input_items(
246+
caller_input: str | list[TResponseInputItem],
247+
generated_items: list[RunItem],
248+
reasoning_item_id_policy: ReasoningItemIdPolicy | None,
249+
) -> list[TResponseInputItem]:
250+
caller_items = ItemHelpers.input_to_new_input_list(caller_input)
251+
continuation_items = run_items_to_input_items(generated_items, reasoning_item_id_policy)
252+
return prepare_model_input_items(caller_items, continuation_items)
253+
254+
247255
def _complete_stream_interruption(
248256
streamed_result: RunResultStreaming,
249257
*,
@@ -1164,16 +1172,12 @@ def _tool_search_fingerprint(raw_item: Any) -> str:
11641172
else 0,
11651173
)
11661174
else:
1167-
input = ItemHelpers.input_to_new_input_list(streamed_result.input)
1168-
append_input_items_excluding_approvals(
1169-
input,
1175+
input = _prepare_turn_input_items(
1176+
streamed_result.input,
11701177
streamed_result._model_input_items,
11711178
reasoning_item_id_policy,
11721179
)
11731180

1174-
if isinstance(input, list):
1175-
input = normalize_input_items_for_api(input)
1176-
11771181
filtered = await maybe_filter_model_input(
11781182
agent=agent,
11791183
run_config=run_config,
@@ -1512,23 +1516,7 @@ async def run_single_turn(
15121516
if server_conversation_tracker is not None:
15131517
input = server_conversation_tracker.prepare_input(original_input, generated_items)
15141518
else:
1515-
input = ItemHelpers.input_to_new_input_list(original_input)
1516-
if isinstance(input, list):
1517-
append_input_items_excluding_approvals(
1518-
input,
1519-
generated_items,
1520-
reasoning_item_id_policy,
1521-
)
1522-
else:
1523-
input = ItemHelpers.input_to_new_input_list(input)
1524-
append_input_items_excluding_approvals(
1525-
input,
1526-
generated_items,
1527-
reasoning_item_id_policy,
1528-
)
1529-
1530-
if isinstance(input, list):
1531-
input = normalize_input_items_for_api(input)
1519+
input = _prepare_turn_input_items(original_input, generated_items, reasoning_item_id_policy)
15321520

15331521
new_response = await get_new_response(
15341522
agent,

src/agents/run_internal/session_persistence.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,17 @@ async def prepare_input_with_session(
9191
ensure_input_item_format(item) for item in ItemHelpers.input_to_new_input_list(input)
9292
]
9393

94+
prune_history_indexes: set[int] = set()
95+
9496
if session_input_callback is None or not include_history_in_prepared_input:
9597
prepared_items_raw: list[TResponseInputItem] = (
9698
converted_history + new_input_list
9799
if include_history_in_prepared_input
98100
else list(new_input_list)
99101
)
100102
appended_items = list(new_input_list)
103+
if include_history_in_prepared_input:
104+
prune_history_indexes = set(range(len(converted_history)))
101105
else:
102106
if not callable(session_input_callback):
103107
raise UserError(
@@ -121,17 +125,19 @@ async def prepare_input_with_session(
121125
new_counts = _build_frequency_map(new_items_for_callback)
122126

123127
appended: list[Any] = []
124-
for item in combined:
128+
for combined_index, item in enumerate(combined):
125129
key = _session_item_key(item)
126130
if _consume_reference(new_refs, key, item):
127131
new_counts[key] = max(new_counts.get(key, 0) - 1, 0)
128132
appended.append(item)
129133
continue
130134
if _consume_reference(history_refs, key, item):
131135
history_counts[key] = max(history_counts.get(key, 0) - 1, 0)
136+
prune_history_indexes.add(combined_index)
132137
continue
133138
if history_counts.get(key, 0) > 0:
134139
history_counts[key] = history_counts.get(key, 0) - 1
140+
prune_history_indexes.add(combined_index)
135141
continue
136142
if new_counts.get(key, 0) > 0:
137143
new_counts[key] = max(new_counts.get(key, 0) - 1, 0)
@@ -151,7 +157,10 @@ async def prepare_input_with_session(
151157
# Normalize exactly as the runtime does elsewhere so the prepared model input and the
152158
# persisted session items are derived from the same item shape and dedupe rules.
153159
prepared_as_inputs = [ensure_input_item_format(item) for item in prepared_items_raw]
154-
filtered = drop_orphan_function_calls(prepared_as_inputs)
160+
filtered = drop_orphan_function_calls(
161+
prepared_as_inputs,
162+
pruning_indexes=prune_history_indexes,
163+
)
155164
normalized = normalize_input_items_for_api(filtered)
156165
deduplicated = deduplicate_input_items_preferring_latest(normalized)
157166

0 commit comments

Comments (0)