Skip to content

Commit c97c958

Browse files
authored
fix: restore replay compatibility for run context approvals and guardrail execution (#2415)
1 parent 4b63ed7 commit c97c958

2 files changed

Lines changed: 30 additions & 61 deletions

File tree

src/agents/run.py

Lines changed: 22 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,6 @@
8686
NextStepHandoff,
8787
NextStepInterruption,
8888
NextStepRunAgain,
89-
SingleStepResult,
9089
)
9190
from .run_internal.session_persistence import (
9291
persist_session_items_for_guardrail_trip,
@@ -975,21 +974,6 @@ async def run(
975974
raise
976975

977976
parallel_results: list[InputGuardrailResult] = []
978-
parallel_guardrail_task: asyncio.Task[list[InputGuardrailResult]] | None = (
979-
None
980-
)
981-
model_task: asyncio.Task[SingleStepResult] | None = None
982-
983-
if parallel_guardrails:
984-
parallel_guardrail_task = asyncio.create_task(
985-
run_input_guardrails(
986-
starting_agent,
987-
parallel_guardrails,
988-
copy_input_items(prepared_input),
989-
context_wrapper,
990-
)
991-
)
992-
993977
model_task = asyncio.create_task(
994978
run_single_turn(
995979
agent=current_agent,
@@ -1011,46 +995,29 @@ async def run(
1011995
)
1012996
)
1013997

1014-
if parallel_guardrail_task:
1015-
done, pending = await asyncio.wait(
1016-
{parallel_guardrail_task, model_task},
1017-
return_when=asyncio.FIRST_COMPLETED,
1018-
)
1019-
1020-
if parallel_guardrail_task in done:
1021-
try:
1022-
parallel_results = parallel_guardrail_task.result()
1023-
except InputGuardrailTripwireTriggered:
1024-
model_task.cancel()
1025-
await asyncio.gather(model_task, return_exceptions=True)
1026-
session_input_items_for_persistence = (
1027-
await persist_session_items_for_guardrail_trip(
1028-
session,
1029-
server_conversation_tracker,
1030-
session_input_items_for_persistence,
1031-
original_user_input,
1032-
run_state,
1033-
store=store_setting,
1034-
)
1035-
)
1036-
raise
1037-
turn_result = await model_task
1038-
else:
1039-
turn_result = await model_task
1040-
try:
1041-
parallel_results = await parallel_guardrail_task
1042-
except InputGuardrailTripwireTriggered:
1043-
session_input_items_for_persistence = (
1044-
await persist_session_items_for_guardrail_trip(
1045-
session,
1046-
server_conversation_tracker,
1047-
session_input_items_for_persistence,
1048-
original_user_input,
1049-
run_state,
1050-
store=store_setting,
1051-
)
998+
if parallel_guardrails:
999+
try:
1000+
parallel_results, turn_result = await asyncio.gather(
1001+
run_input_guardrails(
1002+
starting_agent,
1003+
parallel_guardrails,
1004+
copy_input_items(prepared_input),
1005+
context_wrapper,
1006+
),
1007+
model_task,
1008+
)
1009+
except InputGuardrailTripwireTriggered:
1010+
session_input_items_for_persistence = (
1011+
await persist_session_items_for_guardrail_trip(
1012+
session,
1013+
server_conversation_tracker,
1014+
session_input_items_for_persistence,
1015+
original_user_input,
1016+
run_state,
1017+
store=store_setting,
10521018
)
1053-
raise
1019+
)
1020+
raise
10541021
else:
10551022
turn_result = await model_task
10561023

src/agents/run_context.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,23 +9,25 @@
99

1010
if TYPE_CHECKING:
1111
from .items import ToolApprovalItem, TResponseInputItem
12+
else:
13+
# Keep runtime annotations resolvable for TypeAdapter users (e.g., Temporal's
14+
# Pydantic data converter) without importing items.py and introducing cycles.
15+
ToolApprovalItem = Any
16+
TResponseInputItem = Any
1217

1318
TContext = TypeVar("TContext", default=Any)
1419

1520

21+
@dataclass(eq=False)
1622
class _ApprovalRecord:
1723
"""Tracks approval/rejection state for a tool.
1824
1925
``approved`` and ``rejected`` are either booleans (permanent allow/deny)
2026
or lists of call IDs when approval is scoped to specific tool calls.
2127
"""
2228

23-
approved: bool | list[str]
24-
rejected: bool | list[str]
25-
26-
def __init__(self):
27-
self.approved = []
28-
self.rejected = []
29+
approved: bool | list[str] = field(default_factory=list)
30+
rejected: bool | list[str] = field(default_factory=list)
2931

3032

3133
@dataclass(eq=False)

0 commit comments

Comments
 (0)