Skip to content

Commit fa049a2

Browse files
authored
fix: stop streamed tool execution after known input guardrail tripwire (#2688)
1 parent 9013480 commit fa049a2

5 files changed

Lines changed: 168 additions & 3 deletions

File tree

src/agents/result.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,7 @@ class RunResultStreaming(RunResultBase):
489489
# Store the asyncio tasks that we're waiting on
490490
run_loop_task: asyncio.Task[Any] | None = field(default=None, repr=False)
491491
_input_guardrails_task: asyncio.Task[Any] | None = field(default=None, repr=False)
492+
_triggered_input_guardrail_result: InputGuardrailResult | None = field(default=None, repr=False)
492493
_output_guardrails_task: asyncio.Task[Any] | None = field(default=None, repr=False)
493494
_stored_exception: Exception | None = field(default=None, repr=False)
494495
_cancel_mode: Literal["none", "immediate", "after_turn"] = field(default="none", repr=False)

src/agents/run_internal/guardrails.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,14 @@ async def run_input_guardrails_with_queue(
7070
try:
7171
for done in asyncio.as_completed(guardrail_tasks):
7272
result = await done
73+
guardrail_results.append(result)
7374
if result.output.tripwire_triggered:
75+
streamed_result.input_guardrail_results = (
76+
streamed_result.input_guardrail_results + guardrail_results
77+
)
78+
guardrail_results = []
79+
streamed_result._triggered_input_guardrail_result = result
80+
queue.put_nowait(result)
7481
for t in guardrail_tasks:
7582
t.cancel()
7683
await asyncio.gather(*guardrail_tasks, return_exceptions=True)
@@ -86,11 +93,8 @@ async def run_input_guardrails_with_queue(
8693
else:
8794
# Early first-turn streamed guardrails can run before the agent span exists.
8895
_error_tracing.attach_error_to_current_span(span_error)
89-
queue.put_nowait(result)
90-
guardrail_results.append(result)
9196
break
9297
queue.put_nowait(result)
93-
guardrail_results.append(result)
9498
except Exception:
9599
for t in guardrail_tasks:
96100
t.cancel()

src/agents/run_internal/run_loop.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1237,6 +1237,24 @@ async def run_single_turn_streamed(
12371237
"""Run a single streamed turn and emit events as results arrive."""
12381238
public_agent = bindings.public_agent
12391239
execution_agent = bindings.execution_agent
1240+
1241+
async def raise_if_input_guardrail_tripwire_known() -> None:
    """Raise if the input guardrails are already known to have failed.

    Reads ``streamed_result`` from the enclosing scope. It never waits on a
    guardrail task that is still running: it only inspects state that is
    already available at the moment it is called.

    Raises:
        InputGuardrailTripwireTriggered: A triggered guardrail result has been
            recorded on ``streamed_result``.
        BaseException: The exception stored on the finished guardrail task,
            re-raised unchanged.
    """
    tripwire_result = streamed_result._triggered_input_guardrail_result
    if tripwire_result is not None:
        raise InputGuardrailTripwireTriggered(tripwire_result)

    task = streamed_result._input_guardrails_task
    if task is None or not task.done():
        # No guardrail task, or no verdict yet: nothing is known, so do not block.
        return

    # Task.exception() raises CancelledError when the task was cancelled. A
    # cancelled guardrail task carries neither a verdict nor an error, so treat
    # it as "nothing known" rather than leaking a cancellation into a caller
    # that was not itself cancelled.
    if task.cancelled():
        return

    guardrail_exception = task.exception()
    if guardrail_exception is not None:
        raise guardrail_exception

    # Re-read the field now that the task is known to have finished cleanly.
    tripwire_result = streamed_result._triggered_input_guardrail_result
    if tripwire_result is not None:
        raise InputGuardrailTripwireTriggered(tripwire_result)
1257+
12401258
emitted_tool_call_ids: set[str] = set()
12411259
emitted_reasoning_item_ids: set[str] = set()
12421260
emitted_tool_search_fingerprints: set[str] = set()
@@ -1590,6 +1608,7 @@ async def rewind_model_request() -> None:
15901608
tool_use_tracker=tool_use_tracker,
15911609
server_manages_conversation=server_conversation_tracker is not None,
15921610
event_queue=streamed_result._event_queue,
1611+
before_side_effects=raise_if_input_guardrail_tripwire_known,
15931612
)
15941613

15951614
items_to_filter = session_items_for_turn(single_step_result)

src/agents/run_internal/turn_resolution.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1855,6 +1855,7 @@ async def get_single_step_result_from_response(
18551855
tool_use_tracker,
18561856
server_manages_conversation: bool = False,
18571857
event_queue: asyncio.Queue[StreamEvent | QueueCompleteSentinel] | None = None,
1858+
before_side_effects: Callable[[], Awaitable[None]] | None = None,
18581859
) -> SingleStepResult:
18591860
item_agent = bindings.public_agent
18601861
processed_response = process_model_response(
@@ -1866,6 +1867,9 @@ async def get_single_step_result_from_response(
18661867
existing_items=pre_step_items,
18671868
)
18681869

1870+
if before_side_effects is not None:
1871+
await before_side_effects()
1872+
18691873
tool_use_tracker.record_processed_response(item_agent, processed_response)
18701874

18711875
if event_queue is not None and processed_response.new_items:

tests/test_guardrails.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -658,6 +658,143 @@ async def slow_parallel_check(
658658
assert model.first_turn_args is not None, "Model should have been called in parallel mode"
659659

660660

661+
@pytest.mark.asyncio
async def test_parallel_guardrail_trip_before_tool_execution_stops_streaming_turn():
    """A parallel tripwire that fires during the model call stops the streamed turn.

    The guardrail only trips after the model call has started, and the model only
    yields its tool call after the guardrail has tripped. The run must raise
    ``InputGuardrailTripwireTriggered`` and the tool must never execute.
    """
    tool_invocations: list[bool] = []
    model_entered = asyncio.Event()
    tripwire_fired = asyncio.Event()

    @function_tool
    def dangerous_tool() -> str:
        tool_invocations.append(True)
        return "tool_executed"

    @input_guardrail(run_in_parallel=True)
    async def tripwire_before_tool_execution(
        ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem]
    ) -> GuardrailFunctionOutput:
        # Hold the verdict until the model call is in flight.
        await asyncio.wait_for(model_entered.wait(), timeout=1)
        tripwire_fired.set()
        return GuardrailFunctionOutput(
            output_info="parallel_trip_before_tool_execution",
            tripwire_triggered=True,
        )

    model = FakeModel()
    real_stream_response = model.stream_response

    async def gated_stream_response(*args, **kwargs):
        # Announce the model call, then wait for the tripwire before streaming anything.
        model_entered.set()
        await asyncio.wait_for(tripwire_fired.wait(), timeout=1)
        await asyncio.sleep(SHORT_DELAY)
        async for chunk in real_stream_response(*args, **kwargs):
            yield chunk

    agent = Agent(
        name="streaming_guardrail_hardening_agent",
        instructions="Call the dangerous_tool immediately",
        tools=[dangerous_tool],
        input_guardrails=[tripwire_before_tool_execution],
        model=model,
    )
    model.set_next_output([get_function_tool_call("dangerous_tool", arguments="{}")])
    model.set_next_output([get_text_message("done")])

    with patch.object(model, "stream_response", side_effect=gated_stream_response):
        result = Runner.run_streamed(agent, "trigger guardrail")

        with pytest.raises(InputGuardrailTripwireTriggered):
            async for _ in result.stream_events():
                pass

    assert model_entered.is_set()
    assert tripwire_fired.is_set()
    assert not tool_invocations
    assert model.first_turn_args is not None, "Model should have been called in parallel mode"
715+
716+
717+
@pytest.mark.asyncio
async def test_parallel_guardrail_trip_with_slow_cancel_sibling_stops_streaming_turn():
    """A tripwire stops the streamed turn even while a sibling guardrail is slow to cancel.

    One guardrail trips once the model call has started. A second guardrail never
    completes on its own and takes a short delay to unwind when cancelled. The model
    only yields its tool call after that unwinding has begun. The run must raise with
    only the tripping guardrail's result recorded, and the tool must never execute.
    """
    tool_invocations: list[bool] = []
    model_entered = asyncio.Event()
    tripwire_fired = asyncio.Event()
    sibling_cancel_begun = asyncio.Event()
    sibling_cancel_done = asyncio.Event()

    @function_tool
    def dangerous_tool() -> str:
        tool_invocations.append(True)
        return "tool_executed"

    @input_guardrail(run_in_parallel=True)
    async def tripwire_before_tool_execution(
        ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem]
    ) -> GuardrailFunctionOutput:
        # Hold the verdict until the model call is in flight.
        await asyncio.wait_for(model_entered.wait(), timeout=1)
        tripwire_fired.set()
        return GuardrailFunctionOutput(
            output_info="parallel_trip_before_tool_execution_with_slow_cancel",
            tripwire_triggered=True,
        )

    @input_guardrail(run_in_parallel=True)
    async def slow_to_cancel_guardrail(
        ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem]
    ) -> GuardrailFunctionOutput:
        try:
            # Wait on an event nobody sets, so this only ends through cancellation.
            await asyncio.Event().wait()
        except asyncio.CancelledError:
            sibling_cancel_begun.set()
            await asyncio.sleep(SHORT_DELAY)
            sibling_cancel_done.set()
            raise
        return GuardrailFunctionOutput(
            output_info="slow_to_cancel_guardrail_completed",
            tripwire_triggered=False,
        )

    model = FakeModel()
    real_stream_response = model.stream_response

    async def gated_stream_response(*args, **kwargs):
        # Stream only after the tripwire has fired and the sibling has begun unwinding.
        model_entered.set()
        await asyncio.wait_for(tripwire_fired.wait(), timeout=1)
        await asyncio.wait_for(sibling_cancel_begun.wait(), timeout=1)
        async for chunk in real_stream_response(*args, **kwargs):
            yield chunk

    agent = Agent(
        name="streaming_guardrail_slow_cancel_agent",
        instructions="Call the dangerous_tool immediately",
        tools=[dangerous_tool],
        input_guardrails=[tripwire_before_tool_execution, slow_to_cancel_guardrail],
        model=model,
    )
    model.set_next_output([get_function_tool_call("dangerous_tool", arguments="{}")])
    model.set_next_output([get_text_message("done")])

    with patch.object(model, "stream_response", side_effect=gated_stream_response):
        result = Runner.run_streamed(agent, "trigger guardrail")

        with pytest.raises(InputGuardrailTripwireTriggered) as excinfo:
            async for _ in result.stream_events():
                pass

    raised = excinfo.value
    assert raised.run_data is not None
    recorded_infos = [
        entry.output.output_info for entry in raised.run_data.input_guardrail_results
    ]
    assert recorded_infos == ["parallel_trip_before_tool_execution_with_slow_cancel"]
    assert model_entered.is_set()
    assert tripwire_fired.is_set()
    assert sibling_cancel_begun.is_set()
    assert sibling_cancel_done.is_set()
    assert not tool_invocations
    assert model.first_turn_args is not None, "Model should have been called in parallel mode"
796+
797+
661798
@pytest.mark.asyncio
662799
async def test_blocking_guardrail_prevents_tool_execution():
663800
tool_was_executed = False

0 commit comments

Comments
 (0)