Skip to content

Commit 48e28a3

Browse files
authored
fix: cancel model task when parallel input guardrail trips (#2416)
1 parent c97c958 commit 48e28a3

3 files changed

Lines changed: 124 additions & 0 deletions

File tree

src/agents/run.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
resolve_resumed_context,
5353
resolve_trace_settings,
5454
save_turn_items_if_needed,
55+
should_cancel_parallel_model_task_on_input_guardrail_trip,
5556
update_run_state_for_interruption,
5657
validate_session_conversation_settings,
5758
)
@@ -1007,6 +1008,10 @@ async def run(
10071008
model_task,
10081009
)
10091010
except InputGuardrailTripwireTriggered:
1011+
if should_cancel_parallel_model_task_on_input_guardrail_trip():
1012+
if not model_task.done():
1013+
model_task.cancel()
1014+
await asyncio.gather(model_task, return_exceptions=True)
10101015
session_input_items_for_persistence = (
10111016
await persist_session_items_for_guardrail_trip(
10121017
session,

src/agents/run_internal/agent_runner_helpers.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,30 @@
4343
"resolve_processed_response",
4444
"resolve_resumed_context",
4545
"save_turn_items_if_needed",
46+
"should_cancel_parallel_model_task_on_input_guardrail_trip",
4647
"update_run_state_for_interruption",
4748
]
4849

50+
_PARALLEL_INPUT_GUARDRAIL_CANCEL_PATCH_ID = (
51+
"openai_agents.cancel_parallel_model_task_on_input_guardrail_trip.v1"
52+
)
53+
54+
55+
def should_cancel_parallel_model_task_on_input_guardrail_trip() -> bool:
56+
"""Return whether an in-flight model task should be cancelled on guardrail trip."""
57+
try:
58+
from temporalio import workflow as temporal_workflow # type: ignore[import-not-found]
59+
except Exception:
60+
return True
61+
62+
try:
63+
if not temporal_workflow.in_workflow():
64+
return True
65+
# Preserve replay compatibility for histories created before cancellation.
66+
return bool(temporal_workflow.patched(_PARALLEL_INPUT_GUARDRAIL_CANCEL_PATCH_ID))
67+
except Exception:
68+
return True
69+
4970

5071
def apply_resumed_conversation_settings(
5172
*,

tests/test_guardrails.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -509,6 +509,104 @@ async def slow_parallel_check(
509509
assert model.first_turn_args is not None, "Model should have been called in parallel mode"
510510

511511

512+
@pytest.mark.asyncio
513+
async def test_parallel_guardrail_trip_cancels_model_task():
514+
model_started = asyncio.Event()
515+
model_cancelled = asyncio.Event()
516+
model_finished = asyncio.Event()
517+
518+
@input_guardrail(run_in_parallel=True)
519+
async def tripwire_after_model_starts(
520+
ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem]
521+
) -> GuardrailFunctionOutput:
522+
await asyncio.wait_for(model_started.wait(), timeout=1)
523+
return GuardrailFunctionOutput(
524+
output_info="parallel_tripwire",
525+
tripwire_triggered=True,
526+
)
527+
528+
model = FakeModel()
529+
original_get_response = model.get_response
530+
531+
async def slow_get_response(*args, **kwargs):
532+
model_started.set()
533+
try:
534+
await asyncio.sleep(0.2)
535+
return await original_get_response(*args, **kwargs)
536+
except asyncio.CancelledError:
537+
model_cancelled.set()
538+
raise
539+
finally:
540+
model_finished.set()
541+
542+
agent = Agent(
543+
name="parallel_tripwire_agent",
544+
instructions="Reply with 'hello'",
545+
input_guardrails=[tripwire_after_model_starts],
546+
model=model,
547+
)
548+
model.set_next_output([get_text_message("should_not_finish")])
549+
550+
with patch.object(model, "get_response", side_effect=slow_get_response):
551+
with pytest.raises(InputGuardrailTripwireTriggered):
552+
await Runner.run(agent, "trigger guardrail")
553+
554+
await asyncio.wait_for(model_finished.wait(), timeout=1)
555+
assert model_started.is_set() is True
556+
assert model_cancelled.is_set() is True
557+
558+
559+
@pytest.mark.asyncio
560+
async def test_parallel_guardrail_trip_compat_mode_does_not_cancel_model_task():
561+
model_started = asyncio.Event()
562+
model_cancelled = asyncio.Event()
563+
model_finished = asyncio.Event()
564+
565+
@input_guardrail(run_in_parallel=True)
566+
async def tripwire_after_model_starts(
567+
ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem]
568+
) -> GuardrailFunctionOutput:
569+
await asyncio.wait_for(model_started.wait(), timeout=1)
570+
return GuardrailFunctionOutput(
571+
output_info="parallel_tripwire",
572+
tripwire_triggered=True,
573+
)
574+
575+
model = FakeModel()
576+
original_get_response = model.get_response
577+
578+
async def slow_get_response(*args, **kwargs):
579+
model_started.set()
580+
try:
581+
await asyncio.sleep(0.2)
582+
return await original_get_response(*args, **kwargs)
583+
except asyncio.CancelledError:
584+
model_cancelled.set()
585+
raise
586+
finally:
587+
model_finished.set()
588+
589+
agent = Agent(
590+
name="parallel_tripwire_agent",
591+
instructions="Reply with 'hello'",
592+
input_guardrails=[tripwire_after_model_starts],
593+
model=model,
594+
)
595+
model.set_next_output([get_text_message("should_finish_without_cancel")])
596+
597+
with patch.object(model, "get_response", side_effect=slow_get_response):
598+
with patch(
599+
"agents.run.should_cancel_parallel_model_task_on_input_guardrail_trip",
600+
return_value=False,
601+
):
602+
with pytest.raises(InputGuardrailTripwireTriggered):
603+
await Runner.run(agent, "trigger guardrail")
604+
605+
await asyncio.wait_for(model_finished.wait(), timeout=1)
606+
assert model_started.is_set() is True
607+
assert model_cancelled.is_set() is False
608+
609+
512610
@pytest.mark.asyncio
513611
async def test_parallel_guardrail_may_not_prevent_tool_execution_streaming():
514612
tool_was_executed = False

0 commit comments

Comments
 (0)