@@ -509,6 +509,104 @@ async def slow_parallel_check(
509509 assert model .first_turn_args is not None , "Model should have been called in parallel mode"
510510
511511
512+ @pytest .mark .asyncio
513+ async def test_parallel_guardrail_trip_cancels_model_task ():
514+ model_started = asyncio .Event ()
515+ model_cancelled = asyncio .Event ()
516+ model_finished = asyncio .Event ()
517+
518+ @input_guardrail (run_in_parallel = True )
519+ async def tripwire_after_model_starts (
520+ ctx : RunContextWrapper [Any ], agent : Agent [Any ], input : str | list [TResponseInputItem ]
521+ ) -> GuardrailFunctionOutput :
522+ await asyncio .wait_for (model_started .wait (), timeout = 1 )
523+ return GuardrailFunctionOutput (
524+ output_info = "parallel_tripwire" ,
525+ tripwire_triggered = True ,
526+ )
527+
528+ model = FakeModel ()
529+ original_get_response = model .get_response
530+
531+ async def slow_get_response (* args , ** kwargs ):
532+ model_started .set ()
533+ try :
534+ await asyncio .sleep (0.2 )
535+ return await original_get_response (* args , ** kwargs )
536+ except asyncio .CancelledError :
537+ model_cancelled .set ()
538+ raise
539+ finally :
540+ model_finished .set ()
541+
542+ agent = Agent (
543+ name = "parallel_tripwire_agent" ,
544+ instructions = "Reply with 'hello'" ,
545+ input_guardrails = [tripwire_after_model_starts ],
546+ model = model ,
547+ )
548+ model .set_next_output ([get_text_message ("should_not_finish" )])
549+
550+ with patch .object (model , "get_response" , side_effect = slow_get_response ):
551+ with pytest .raises (InputGuardrailTripwireTriggered ):
552+ await Runner .run (agent , "trigger guardrail" )
553+
554+ await asyncio .wait_for (model_finished .wait (), timeout = 1 )
555+ assert model_started .is_set () is True
556+ assert model_cancelled .is_set () is True
557+
558+
559+ @pytest .mark .asyncio
560+ async def test_parallel_guardrail_trip_compat_mode_does_not_cancel_model_task ():
561+ model_started = asyncio .Event ()
562+ model_cancelled = asyncio .Event ()
563+ model_finished = asyncio .Event ()
564+
565+ @input_guardrail (run_in_parallel = True )
566+ async def tripwire_after_model_starts (
567+ ctx : RunContextWrapper [Any ], agent : Agent [Any ], input : str | list [TResponseInputItem ]
568+ ) -> GuardrailFunctionOutput :
569+ await asyncio .wait_for (model_started .wait (), timeout = 1 )
570+ return GuardrailFunctionOutput (
571+ output_info = "parallel_tripwire" ,
572+ tripwire_triggered = True ,
573+ )
574+
575+ model = FakeModel ()
576+ original_get_response = model .get_response
577+
578+ async def slow_get_response (* args , ** kwargs ):
579+ model_started .set ()
580+ try :
581+ await asyncio .sleep (0.2 )
582+ return await original_get_response (* args , ** kwargs )
583+ except asyncio .CancelledError :
584+ model_cancelled .set ()
585+ raise
586+ finally :
587+ model_finished .set ()
588+
589+ agent = Agent (
590+ name = "parallel_tripwire_agent" ,
591+ instructions = "Reply with 'hello'" ,
592+ input_guardrails = [tripwire_after_model_starts ],
593+ model = model ,
594+ )
595+ model .set_next_output ([get_text_message ("should_finish_without_cancel" )])
596+
597+ with patch .object (model , "get_response" , side_effect = slow_get_response ):
598+ with patch (
599+ "agents.run.should_cancel_parallel_model_task_on_input_guardrail_trip" ,
600+ return_value = False ,
601+ ):
602+ with pytest .raises (InputGuardrailTripwireTriggered ):
603+ await Runner .run (agent , "trigger guardrail" )
604+
605+ await asyncio .wait_for (model_finished .wait (), timeout = 1 )
606+ assert model_started .is_set () is True
607+ assert model_cancelled .is_set () is False
608+
609+
512610@pytest .mark .asyncio
513611async def test_parallel_guardrail_may_not_prevent_tool_execution_streaming ():
514612 tool_was_executed = False
0 commit comments