Skip to content

Commit 1959dd3

Browse files
authored
feat: #2658 preserve explicit approval rejection messages across resume flows (#2660)
1 parent e0f6a28 commit 1959dd3

13 files changed

Lines changed: 614 additions & 17 deletions

examples/agent_patterns/human_in_the_loop_custom_rejection.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
This example is intentionally minimal:
44
1. A single sensitive tool requires human approval.
55
2. The first turn always issues that tool call.
6-
3. Rejection uses a custom message via ``tool_error_formatter``.
7-
4. The example prints both the formatter output and the assistant's final reply.
6+
3. ``tool_error_formatter`` defines the universal fallback message shape.
7+
4. A per-call ``rejection_message`` passed to ``state.reject(...)`` overrides that fallback.
8+
5. The example prints both the tool output and the assistant's final reply.
89
"""
910

1011
import asyncio
@@ -21,7 +22,7 @@
2122

2223

2324
async def tool_error_formatter(args: ToolErrorFormatterArgs[None]) -> str | None:
24-
"""Build a simple output message for rejected tool calls."""
25+
"""Build the universal fallback output message for rejected tool calls."""
2526
if args.kind != "approval_rejected":
2627
return None
2728
# The default message is "Tool execution was not approved."
@@ -60,6 +61,8 @@ async def main() -> None:
6061
tools=[publish_announcement],
6162
)
6263
run_config = RunConfig(tool_error_formatter=tool_error_formatter)
64+
# ``tool_error_formatter`` is the universal fallback for approval rejects.
65+
# A specific ``rejection_message`` passed to ``state.reject(...)`` below overrides it.
6366

6467
result = await Runner.run(
6568
agent,
@@ -81,7 +84,13 @@ async def main() -> None:
8184
if approved:
8285
state.approve(interruption)
8386
else:
84-
state.reject(interruption)
87+
# This per-call rejection message takes precedence over ``tool_error_formatter``.
88+
state.reject(
89+
interruption,
90+
rejection_message=(
91+
"Publish action was canceled because the reviewer denied approval."
92+
),
93+
)
8594

8695
result = await Runner.run(agent, state, run_config=run_config)
8796

src/agents/realtime/session.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from pydantic import BaseModel
1111
from typing_extensions import assert_never
1212

13+
from .._tool_identity import get_function_tool_lookup_key_for_tool
1314
from ..agent import Agent
1415
from ..exceptions import UserError
1516
from ..handoffs import Handoff
@@ -527,6 +528,14 @@ async def _send_tool_rejection(
527528

528529
async def _resolve_approval_rejection_message(self, *, tool: FunctionTool, call_id: str) -> str:
529530
"""Resolve model-visible output text for approval rejections."""
531+
explicit_message = self._context_wrapper.get_rejection_message(
532+
tool.name,
533+
call_id,
534+
tool_lookup_key=get_function_tool_lookup_key_for_tool(tool),
535+
)
536+
if explicit_message is not None:
537+
return explicit_message
538+
530539
formatter = self._run_config.get("tool_error_formatter")
531540
if formatter is None:
532541
return REJECTION_MESSAGE
@@ -574,14 +583,24 @@ async def approve_tool_call(self, call_id: str, *, always: bool = False) -> None
574583
else:
575584
await self._handle_tool_call(tool_call, agent_snapshot=agent_snapshot)
576585

577-
async def reject_tool_call(self, call_id: str, *, always: bool = False) -> None:
586+
async def reject_tool_call(
587+
self,
588+
call_id: str,
589+
*,
590+
always: bool = False,
591+
rejection_message: str | None = None,
592+
) -> None:
578593
"""Reject a pending tool call and notify the model."""
579594
pending = self._pending_tool_calls.pop(call_id, None)
580595
if pending is None:
581596
return
582597

583598
tool_call, agent_snapshot, function_tool, approval_item = pending
584-
self._context_wrapper.reject_tool(approval_item, always_reject=always)
599+
self._context_wrapper.reject_tool(
600+
approval_item,
601+
always_reject=always,
602+
rejection_message=rejection_message,
603+
)
585604
await self._send_tool_rejection(tool_call, tool=function_tool, agent=agent_snapshot)
586605

587606
async def _handle_tool_call(

src/agents/run_context.py

Lines changed: 128 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ class _ApprovalRecord:
3535

3636
approved: bool | list[str] = field(default_factory=list)
3737
rejected: bool | list[str] = field(default_factory=list)
38+
rejection_messages: dict[str, str] = field(default_factory=dict)
39+
sticky_rejection_message: str | None = None
3840

3941

4042
@dataclass(eq=False)
@@ -207,8 +209,101 @@ def _get_approval_status_for_key(self, approval_key: str, call_id: str) -> bool
207209
# Per-call approvals are scoped to the exact call ID, so other calls require a new decision.
208210
return None
209211

212+
@staticmethod
213+
def _clear_rejection_message(record: _ApprovalRecord, call_id: str | None) -> None:
214+
if call_id is None:
215+
return
216+
record.rejection_messages.pop(call_id, None)
217+
218+
@staticmethod
219+
def _get_rejection_message_for_key(record: _ApprovalRecord, call_id: str) -> str | None:
220+
if record.rejected is True:
221+
if call_id in record.rejection_messages:
222+
return record.rejection_messages[call_id]
223+
return record.sticky_rejection_message
224+
if isinstance(record.rejected, list) and call_id in record.rejected:
225+
return record.rejection_messages.get(call_id)
226+
return None
227+
228+
def get_rejection_message(
229+
self,
230+
tool_name: str,
231+
call_id: str,
232+
*,
233+
tool_namespace: str | None = None,
234+
existing_pending: ToolApprovalItem | None = None,
235+
tool_lookup_key: FunctionToolLookupKey | None = None,
236+
) -> str | None:
237+
"""Return a stored rejection message for a tool call if one exists."""
238+
candidates: list[str] = []
239+
explicit_namespace = (
240+
tool_namespace if isinstance(tool_namespace, str) and tool_namespace else None
241+
)
242+
pending_namespace = (
243+
self._resolve_tool_namespace(existing_pending) if existing_pending is not None else None
244+
)
245+
pending_key = self._resolve_approval_key(existing_pending) if existing_pending else None
246+
pending_tool_name = self._resolve_tool_name(existing_pending) if existing_pending else None
247+
pending_keys = (
248+
list(self._resolve_approval_keys(existing_pending))
249+
if existing_pending is not None
250+
else []
251+
)
252+
253+
if existing_pending and pending_key is not None:
254+
candidates.append(pending_key)
255+
explicit_keys = (
256+
list(
257+
get_function_tool_approval_keys(
258+
tool_name=tool_name,
259+
tool_namespace=explicit_namespace,
260+
tool_lookup_key=tool_lookup_key,
261+
include_legacy_deferred_key=True,
262+
)
263+
)
264+
if explicit_namespace is not None or tool_lookup_key is not None
265+
else []
266+
)
267+
for explicit_key in explicit_keys:
268+
if explicit_key not in candidates:
269+
candidates.append(explicit_key)
270+
if not explicit_keys and pending_namespace and pending_key is not None:
271+
if pending_key not in candidates:
272+
candidates.append(pending_key)
273+
if (
274+
explicit_namespace is None
275+
and tool_lookup_key is None
276+
and existing_pending is None
277+
and tool_name not in candidates
278+
):
279+
candidates.append(tool_name)
280+
if existing_pending:
281+
for pending_candidate in pending_keys:
282+
if pending_candidate not in candidates:
283+
candidates.append(pending_candidate)
284+
if (
285+
pending_namespace is None
286+
and pending_tool_name is not None
287+
and pending_tool_name not in candidates
288+
):
289+
candidates.append(pending_tool_name)
290+
291+
for candidate in candidates:
292+
approval_entry = self._approvals.get(candidate)
293+
if not approval_entry:
294+
continue
295+
message = self._get_rejection_message_for_key(approval_entry, call_id)
296+
if message is not None:
297+
return message
298+
return None
299+
210300
def _apply_approval_decision(
211-
self, approval_item: ToolApprovalItem, *, always: bool, approve: bool
301+
self,
302+
approval_item: ToolApprovalItem,
303+
*,
304+
always: bool,
305+
approve: bool,
306+
rejection_message: str | None = None,
212307
) -> None:
213308
"""Record an approval or rejection decision."""
214309
approval_keys = self._resolve_approval_keys(approval_item) or ("unknown_tool",)
@@ -223,6 +318,14 @@ def _apply_approval_decision(
223318
approval_entry.rejected = [] if approve else True
224319
if not approve:
225320
approval_entry.approved = False
321+
if rejection_message is not None and call_id is not None:
322+
approval_entry.rejection_messages[call_id] = rejection_message
323+
elif call_id is not None:
324+
self._clear_rejection_message(approval_entry, call_id)
325+
approval_entry.sticky_rejection_message = rejection_message
326+
else:
327+
approval_entry.rejection_messages.clear()
328+
approval_entry.sticky_rejection_message = None
226329
continue
227330

228331
opposite = approval_entry.rejected if approve else approval_entry.approved
@@ -232,6 +335,13 @@ def _apply_approval_decision(
232335
target = approval_entry.approved if approve else approval_entry.rejected
233336
if isinstance(target, list) and call_id not in target:
234337
target.append(call_id)
338+
if approve:
339+
self._clear_rejection_message(approval_entry, call_id)
340+
elif call_id is not None:
341+
if rejection_message is not None:
342+
approval_entry.rejection_messages[call_id] = rejection_message
343+
else:
344+
self._clear_rejection_message(approval_entry, call_id)
235345

236346
def approve_tool(self, approval_item: ToolApprovalItem, always_approve: bool = False) -> None:
237347
"""Approve a tool call, optionally for all future calls."""
@@ -241,12 +351,18 @@ def approve_tool(self, approval_item: ToolApprovalItem, always_approve: bool = F
241351
approve=True,
242352
)
243353

244-
def reject_tool(self, approval_item: ToolApprovalItem, always_reject: bool = False) -> None:
354+
def reject_tool(
355+
self,
356+
approval_item: ToolApprovalItem,
357+
always_reject: bool = False,
358+
rejection_message: str | None = None,
359+
) -> None:
245360
"""Reject a tool call, optionally for all future calls."""
246361
self._apply_approval_decision(
247362
approval_item,
248363
always=always_reject,
249364
approve=False,
365+
rejection_message=rejection_message,
250366
)
251367

252368
def get_approval_status(
@@ -326,6 +442,16 @@ def _rebuild_approvals(self, approvals: dict[str, dict[str, Any]]) -> None:
326442
record = _ApprovalRecord()
327443
record.approved = record_dict.get("approved", [])
328444
record.rejected = record_dict.get("rejected", [])
445+
rejection_messages = record_dict.get("rejection_messages", {})
446+
if isinstance(rejection_messages, dict):
447+
record.rejection_messages = {
448+
str(call_id): message
449+
for call_id, message in rejection_messages.items()
450+
if isinstance(message, str)
451+
}
452+
sticky_rejection_message = record_dict.get("sticky_rejection_message")
453+
if isinstance(sticky_rejection_message, str):
454+
record.sticky_rejection_message = sticky_rejection_message
329455
self._approvals[tool_name] = record
330456

331457
def _fork_with_tool_input(self, tool_input: Any) -> RunContextWrapper[TContext]:

src/agents/run_internal/tool_execution.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1049,8 +1049,21 @@ async def resolve_approval_rejection_message(
10491049
tool_type: Literal["function", "computer", "shell", "apply_patch"],
10501050
tool_name: str,
10511051
call_id: str,
1052+
tool_namespace: str | None = None,
1053+
tool_lookup_key: FunctionToolLookupKey | None = None,
1054+
existing_pending: ToolApprovalItem | None = None,
10521055
) -> str:
10531056
"""Resolve model-visible output text for approval rejections."""
1057+
explicit_message = context_wrapper.get_rejection_message(
1058+
tool_name,
1059+
call_id,
1060+
tool_namespace=tool_namespace,
1061+
tool_lookup_key=tool_lookup_key,
1062+
existing_pending=existing_pending,
1063+
)
1064+
if explicit_message is not None:
1065+
return explicit_message
1066+
10541067
formatter = run_config.tool_error_formatter
10551068
if formatter is None:
10561069
return REJECTION_MESSAGE
@@ -1150,6 +1163,13 @@ def process_hosted_mcp_approvals(
11501163
"approval_request_id": request_id,
11511164
"approve": approved,
11521165
}
1166+
rejection_message = context_wrapper.get_rejection_message(
1167+
tool_name=tool_name,
1168+
call_id=request_id,
1169+
existing_pending=approval_item,
1170+
)
1171+
if approved is False and rejection_message is not None:
1172+
raw_item["reason"] = rejection_message
11531173
response_item = MCPApprovalResponseItem(raw_item=raw_item, agent=agent)
11541174
append_item(response_item)
11551175
continue
@@ -1199,6 +1219,13 @@ def collect_manual_mcp_approvals(
11991219
"approval_request_id": request_id,
12001220
"approve": approval_status,
12011221
}
1222+
rejection_message = context_wrapper.get_rejection_message(
1223+
tool_name,
1224+
request_id,
1225+
existing_pending=existing_pending,
1226+
)
1227+
if approval_status is False and rejection_message is not None:
1228+
approval_response_raw["reason"] = rejection_message
12021229
approved.append(MCPApprovalResponseItem(raw_item=approval_response_raw, agent=agent))
12031230
continue
12041231

@@ -1520,6 +1547,8 @@ async def _maybe_execute_tool_approval(
15201547
tool_type="function",
15211548
tool_name=tool_trace_name(func_tool.name, tool_namespace) or func_tool.name,
15221549
call_id=tool_call.call_id,
1550+
tool_namespace=tool_namespace,
1551+
tool_lookup_key=tool_lookup_key,
15231552
)
15241553
span_fn.set_error(
15251554
SpanError(
@@ -1988,6 +2017,9 @@ async def _resolve_tool_run(
19882017
tool_type="function",
19892018
tool_name=display_tool_name,
19902019
call_id=call_id,
2020+
tool_namespace=tool_namespace,
2021+
tool_lookup_key=tool_lookup_key,
2022+
existing_pending=interruption,
19912023
)
19922024
_append_error(
19932025
message=message,

src/agents/run_internal/turn_resolution.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
build_function_tool_lookup_map,
3333
get_function_tool_lookup_key,
3434
get_function_tool_lookup_key_for_call,
35+
get_function_tool_lookup_key_for_tool,
3536
get_tool_call_namespace,
3637
get_tool_call_qualified_name,
3738
get_tool_call_trace_name,
@@ -705,12 +706,16 @@ async def _record_function_rejection(
705706
return
706707
rejection_message = REJECTION_MESSAGE
707708
if call_id:
709+
tool_namespace = get_tool_call_namespace(tool_call)
708710
rejection_message = await resolve_approval_rejection_message(
709711
context_wrapper=context_wrapper,
710712
run_config=run_config,
711713
tool_type="function",
712714
tool_name=get_tool_call_trace_name(tool_call) or function_tool.name,
713715
call_id=call_id,
716+
tool_namespace=tool_namespace,
717+
tool_lookup_key=get_function_tool_lookup_key_for_tool(function_tool),
718+
existing_pending=approval_items_by_call_id.get(call_id),
714719
)
715720
rejected_function_outputs.append(
716721
function_rejection_item(

0 commit comments

Comments
 (0)