2323 drop_orphan_function_calls ,
2424 fingerprint_input_item ,
2525 normalize_input_items_for_api ,
26+ prepare_model_input_items ,
2627 run_item_to_input_item ,
2728)
2829
@@ -153,8 +154,7 @@ def hydrate_from_state(
153154
154155 normalized_input = original_input
155156 if isinstance (original_input , list ):
156- normalized = normalize_input_items_for_api (original_input )
157- normalized_input = drop_orphan_function_calls (normalized )
157+ normalized_input = prepare_model_input_items (original_input )
158158
159159 for item in ItemHelpers .input_to_new_input_list (normalized_input ):
160160 if item is None :
@@ -404,13 +404,17 @@ def prepare_input(
404404 generated_items : list [RunItem ],
405405 ) -> list [TResponseInputItem ]:
406406 """Assemble the next model input while skipping duplicates and approvals."""
407- input_items : list [TResponseInputItem ] = []
407+ prepared_initial_items : list [TResponseInputItem ] = []
408+ prepared_generated_items : list [TResponseInputItem ] = []
409+ generated_item_sources : dict [int , TResponseInputItem ] = {}
408410
409411 if not self .sent_initial_input :
410412 initial_items = ItemHelpers .input_to_new_input_list (original_input )
411- input_items .extend (initial_items )
412- for item in initial_items :
413- self ._register_prepared_item_source (item )
413+ prepared_initial_items = normalize_input_items_for_api (initial_items )
414+ for prepared_item , source_item in zip (
415+ prepared_initial_items , initial_items , strict = False
416+ ):
417+ self ._register_prepared_item_source (prepared_item , source_item )
414418 filtered_initials = []
415419 for item in initial_items :
416420 if item is None or isinstance (item , (str , bytes )):
@@ -419,9 +423,11 @@ def prepare_input(
419423 self .remaining_initial_input = filtered_initials or None
420424 self .sent_initial_input = True
421425 elif self .remaining_initial_input :
422- input_items .extend (self .remaining_initial_input )
423- for item in self .remaining_initial_input :
424- self ._register_prepared_item_source (item )
426+ prepared_initial_items = normalize_input_items_for_api (self .remaining_initial_input )
427+ for prepared_item , source_item in zip (
428+ prepared_initial_items , self .remaining_initial_input , strict = False
429+ ):
430+ self ._register_prepared_item_source (prepared_item , source_item )
425431
426432 for item in generated_items : # type: ignore[assignment]
427433 run_item : RunItem = cast (RunItem , item )
@@ -474,13 +480,23 @@ def prepare_input(
474480 ):
475481 continue
476482
477- input_items .append (converted_input_item )
478- self ._register_prepared_item_source (
479- converted_input_item ,
480- cast (TResponseInputItem , raw_item ),
481- )
483+ prepared_generated_items .append (converted_input_item )
484+ generated_item_sources [id (converted_input_item )] = cast (TResponseInputItem , raw_item )
482485
483- return input_items
486+ normalized_generated_items = normalize_input_items_for_api (prepared_generated_items )
487+ normalized_generated_sources = {
488+ id (normalized_item ): generated_item_sources [id (source_item )]
489+ for normalized_item , source_item in zip (
490+ normalized_generated_items , prepared_generated_items , strict = False
491+ )
492+ }
493+ filtered_generated_items = drop_orphan_function_calls (normalized_generated_items )
494+ for item in filtered_generated_items :
495+ prepared_source_item = normalized_generated_sources .get (id (item ))
496+ if prepared_source_item is not None :
497+ self ._register_prepared_item_source (item , prepared_source_item )
498+
499+ return prepared_initial_items + filtered_generated_items
484500
485501 def _register_prepared_item_source (
486502 self , prepared_item : TResponseInputItem , source_item : TResponseInputItem | None = None
0 commit comments