Skip to content

Commit 1e8b70a

Browse files
committed
feat: add approval override argument parity for RunState resumes
1 parent 457d1a5 commit 1e8b70a

17 files changed

Lines changed: 1592 additions & 68 deletions

src/agents/memory/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,19 @@
11
from .openai_conversations_session import OpenAIConversationsSession
22
from .openai_responses_compaction_session import OpenAIResponsesCompactionSession
33
from .session import (
4+
SERVER_MANAGED_CONVERSATION_SESSION_ATTR,
45
OpenAIResponsesCompactionArgs,
56
OpenAIResponsesCompactionAwareSession,
7+
ServerManagedConversationSession,
68
Session,
79
SessionABC,
10+
SessionHistoryMutation,
11+
SessionHistoryRewriteArgs,
12+
SessionHistoryRewriteAwareSession,
13+
apply_session_history_mutations,
814
is_openai_responses_compaction_aware_session,
15+
is_server_managed_conversation_session,
16+
is_session_history_rewrite_aware_session,
917
)
1018
from .session_settings import SessionSettings
1119
from .sqlite_session import SQLiteSession
@@ -21,5 +29,13 @@
2129
"OpenAIResponsesCompactionSession",
2230
"OpenAIResponsesCompactionArgs",
2331
"OpenAIResponsesCompactionAwareSession",
32+
"SERVER_MANAGED_CONVERSATION_SESSION_ATTR",
33+
"SessionHistoryMutation",
34+
"SessionHistoryRewriteArgs",
35+
"SessionHistoryRewriteAwareSession",
36+
"ServerManagedConversationSession",
37+
"apply_session_history_mutations",
38+
"is_server_managed_conversation_session",
2439
"is_openai_responses_compaction_aware_session",
40+
"is_session_history_rewrite_aware_session",
2541
]

src/agents/memory/openai_conversations_session.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ async def start_openai_conversations_session(openai_client: AsyncOpenAI | None =
2121

2222

2323
class OpenAIConversationsSession(SessionABC):
24+
_server_managed_conversation_session = True
2425
session_settings: SessionSettings | None = None
2526

2627
def __init__(

src/agents/memory/openai_responses_compaction_session.py

Lines changed: 132 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
OpenAIResponsesCompactionArgs,
1212
OpenAIResponsesCompactionAwareSession,
1313
SessionABC,
14+
SessionHistoryRewriteArgs,
15+
apply_session_history_mutations,
16+
is_session_history_rewrite_aware_session,
1417
)
1518

1619
if TYPE_CHECKING:
@@ -129,48 +132,87 @@ def __init__(
129132
self._session_items: list[TResponseInputItem] | None = None
130133
self._response_id: str | None = None
131134
self._deferred_response_id: str | None = None
132-
self._last_unstored_response_id: str | None = None
135+
self._last_store: bool | None = None
136+
self._has_pending_local_history_rewrite = False
137+
self._local_history_rewrite_response_id: str | None = None
138+
self._has_unacknowledged_local_session_adds = False
133139

134140
@property
135141
def client(self) -> AsyncOpenAI:
136142
if self._client is None:
137143
self._client = get_default_openai_client() or AsyncOpenAI()
138144
return self._client
139145

140-
def _resolve_compaction_mode_for_response(
146+
def _resolve_compaction_mode(
141147
self,
142148
*,
149+
requested_mode: OpenAIResponsesCompactionMode,
143150
response_id: str | None,
144151
store: bool | None,
145-
requested_mode: OpenAIResponsesCompactionMode | None,
152+
turn_has_local_adds_without_new_response_id: bool,
146153
) -> _ResolvedCompactionMode:
147-
mode = requested_mode or self.compaction_mode
154+
resolved_mode = _resolve_compaction_mode(
155+
requested_mode,
156+
response_id=response_id,
157+
store=store,
158+
)
159+
160+
if turn_has_local_adds_without_new_response_id and resolved_mode == "previous_response_id":
161+
self._has_unacknowledged_local_session_adds = False
162+
self._mark_local_history_rewrite()
163+
logger.debug(
164+
"compact: forcing input mode after local session delta without new response id"
165+
)
166+
return "input"
167+
168+
if not self._has_pending_local_history_rewrite:
169+
return resolved_mode
170+
148171
if (
149-
mode == "auto"
150-
and store is None
172+
self._local_history_rewrite_response_id is not None
151173
and response_id is not None
152-
and response_id == self._last_unstored_response_id
174+
and response_id != self._local_history_rewrite_response_id
153175
):
176+
self._has_pending_local_history_rewrite = False
177+
self._local_history_rewrite_response_id = None
178+
return resolved_mode
179+
180+
if resolved_mode == "previous_response_id":
181+
if self._local_history_rewrite_response_id is None and response_id is not None:
182+
self._local_history_rewrite_response_id = response_id
183+
logger.debug("compact: forcing input mode after local history rewrite")
154184
return "input"
155-
return _resolve_compaction_mode(mode, response_id=response_id, store=store)
185+
186+
return resolved_mode
156187

157188
async def run_compaction(self, args: OpenAIResponsesCompactionArgs | None = None) -> None:
158189
"""Run compaction using responses.compact API."""
190+
previous_response_id = self._response_id
159191
if args and args.get("response_id"):
160192
self._response_id = args["response_id"]
161193
requested_mode = args.get("compaction_mode") if args else None
162194
if args and "store" in args:
163-
store = args["store"]
164-
if store is False and self._response_id:
165-
self._last_unstored_response_id = self._response_id
166-
elif store is True and self._response_id == self._last_unstored_response_id:
167-
self._last_unstored_response_id = None
195+
store: bool | None = args["store"]
196+
self._last_store = store
168197
else:
169-
store = None
170-
resolved_mode = self._resolve_compaction_mode_for_response(
198+
store = self._last_store
199+
turn_has_local_adds_without_new_response_id = (
200+
self._has_unacknowledged_local_session_adds
201+
and (args is None or args.get("response_id") in {None, previous_response_id})
202+
)
203+
if (
204+
args
205+
and args.get("response_id") is not None
206+
and args["response_id"] != previous_response_id
207+
):
208+
self._has_unacknowledged_local_session_adds = False
209+
resolved_mode = self._resolve_compaction_mode(
171210
response_id=self._response_id,
172211
store=store,
173-
requested_mode=requested_mode,
212+
requested_mode=requested_mode or self.compaction_mode,
213+
turn_has_local_adds_without_new_response_id=(
214+
turn_has_local_adds_without_new_response_id
215+
),
174216
)
175217

176218
if resolved_mode == "previous_response_id" and not self._response_id:
@@ -198,6 +240,15 @@ async def run_compaction(self, args: OpenAIResponsesCompactionArgs | None = None
198240
)
199241
return
200242

243+
unresolved_function_calls = _find_unresolved_function_calls_without_results(session_items)
244+
if unresolved_function_calls:
245+
logger.debug(
246+
"compact: blocked unresolved function calls for %s: %s",
247+
self._response_id,
248+
unresolved_function_calls,
249+
)
250+
return
251+
201252
self._deferred_response_id = None
202253
logger.debug(
203254
f"compact: start for {self._response_id} using {self.model} (mode={resolved_mode})"
@@ -239,14 +290,37 @@ async def run_compaction(self, args: OpenAIResponsesCompactionArgs | None = None
239290
async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
240291
return await self.underlying_session.get_items(limit)
241292

293+
async def apply_history_mutations(self, args: SessionHistoryRewriteArgs) -> None:
294+
"""Rewrite persisted history and keep compaction caches aligned with the new transcript."""
295+
mutations = list(args.get("mutations", []))
296+
if not mutations:
297+
return
298+
299+
if is_session_history_rewrite_aware_session(self.underlying_session):
300+
await self.underlying_session.apply_history_mutations({"mutations": mutations})
301+
await self._refresh_caches_from_underlying_session()
302+
self._mark_local_history_rewrite()
303+
return
304+
305+
rewritten_items = apply_session_history_mutations(
306+
await self.underlying_session.get_items(),
307+
mutations,
308+
)
309+
await self.underlying_session.clear_session()
310+
if rewritten_items:
311+
await self.underlying_session.add_items(rewritten_items)
312+
self._session_items = rewritten_items
313+
self._compaction_candidate_items = select_compaction_candidate_items(rewritten_items)
314+
self._mark_local_history_rewrite()
315+
242316
async def _defer_compaction(self, response_id: str, store: bool | None = None) -> None:
243317
if self._deferred_response_id is not None:
244318
return
245319
compaction_candidate_items, session_items = await self._ensure_compaction_candidates()
246-
resolved_mode = self._resolve_compaction_mode_for_response(
320+
resolved_mode = _resolve_compaction_mode(
321+
self.compaction_mode,
247322
response_id=response_id,
248-
store=store,
249-
requested_mode=None,
323+
store=store if store is not None else self._last_store,
250324
)
251325
should_compact = self.should_trigger_compaction(
252326
{
@@ -266,7 +340,10 @@ def _clear_deferred_compaction(self) -> None:
266340
self._deferred_response_id = None
267341

268342
async def add_items(self, items: list[TResponseInputItem]) -> None:
343+
if not items:
344+
return
269345
await self.underlying_session.add_items(items)
346+
self._has_unacknowledged_local_session_adds = True
270347
if self._compaction_candidate_items is not None:
271348
new_candidates = select_compaction_candidate_items(items)
272349
if new_candidates:
@@ -286,6 +363,15 @@ async def clear_session(self) -> None:
286363
self._compaction_candidate_items = []
287364
self._session_items = []
288365
self._deferred_response_id = None
366+
self._has_pending_local_history_rewrite = False
367+
self._local_history_rewrite_response_id = None
368+
self._has_unacknowledged_local_session_adds = False
369+
self._last_store = None
370+
371+
async def _refresh_caches_from_underlying_session(self) -> None:
372+
history = await self.underlying_session.get_items()
373+
self._session_items = history
374+
self._compaction_candidate_items = select_compaction_candidate_items(history)
289375

290376
async def _ensure_compaction_candidates(
291377
self,
@@ -304,10 +390,37 @@ async def _ensure_compaction_candidates(
304390
)
305391
return (candidates[:], history[:])
306392

393+
def _mark_local_history_rewrite(self) -> None:
394+
self._has_pending_local_history_rewrite = True
395+
self._local_history_rewrite_response_id = self._response_id
396+
307397

308398
_ResolvedCompactionMode = Literal["previous_response_id", "input"]
309399

310400

401+
def _find_unresolved_function_calls_without_results(items: list[TResponseInputItem]) -> list[str]:
402+
"""Return function-call ids that do not yet have matching outputs."""
403+
function_calls: dict[str, TResponseInputItem] = {}
404+
resolved_call_ids: set[str] = set()
405+
406+
for item in items:
407+
if isinstance(item, dict):
408+
item_type = item.get("type")
409+
call_id = item.get("call_id")
410+
else:
411+
item_type = getattr(item, "type", None)
412+
call_id = getattr(item, "call_id", None)
413+
414+
if not isinstance(call_id, str):
415+
continue
416+
if item_type == "function_call":
417+
function_calls[call_id] = item
418+
elif item_type == "function_call_output":
419+
resolved_call_ids.add(call_id)
420+
421+
return [call_id for call_id in function_calls if call_id not in resolved_call_ids]
422+
423+
311424
def _resolve_compaction_mode(
312425
requested_mode: OpenAIResponsesCompactionMode,
313426
*,

src/agents/memory/session.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import copy
34
from abc import ABC, abstractmethod
45
from typing import TYPE_CHECKING, Literal, Protocol, runtime_checkable
56

@@ -9,6 +10,8 @@
910
from ..items import TResponseInputItem
1011
from .session_settings import SessionSettings
1112

13+
SERVER_MANAGED_CONVERSATION_SESSION_ATTR = "_server_managed_conversation_session"
14+
1215

1316
@runtime_checkable
1417
class Session(Protocol):
@@ -104,6 +107,105 @@ async def clear_session(self) -> None:
104107
...
105108

106109

110+
@runtime_checkable
111+
class ServerManagedConversationSession(Session, Protocol):
112+
"""Protocol for sessions whose canonical history is managed by a remote service."""
113+
114+
_server_managed_conversation_session: Literal[True]
115+
116+
117+
def is_server_managed_conversation_session(
118+
session: Session | None,
119+
) -> TypeGuard[ServerManagedConversationSession]:
120+
"""Check whether the session advertises server-managed history semantics."""
121+
if session is None:
122+
return False
123+
try:
124+
marker = getattr(session, SERVER_MANAGED_CONVERSATION_SESSION_ATTR, False)
125+
except Exception:
126+
return False
127+
return marker is True
128+
129+
130+
class ReplaceFunctionCallSessionHistoryMutation(TypedDict):
131+
"""Replace the canonical persisted function call for a tool call."""
132+
133+
type: Literal["replace_function_call"]
134+
call_id: str
135+
replacement: TResponseInputItem
136+
137+
138+
SessionHistoryMutation = ReplaceFunctionCallSessionHistoryMutation
139+
140+
141+
class SessionHistoryRewriteArgs(TypedDict):
142+
"""Arguments for persisted-history rewrites."""
143+
144+
mutations: list[SessionHistoryMutation]
145+
146+
147+
@runtime_checkable
148+
class SessionHistoryRewriteAwareSession(Session, Protocol):
149+
"""Protocol for sessions that can rewrite previously persisted history."""
150+
151+
async def apply_history_mutations(self, args: SessionHistoryRewriteArgs) -> None:
152+
"""Apply structured history mutations to the persisted session history."""
153+
...
154+
155+
156+
def is_session_history_rewrite_aware_session(
157+
session: Session | None,
158+
) -> TypeGuard[SessionHistoryRewriteAwareSession]:
159+
"""Check whether a session supports persisted-history rewrites."""
160+
if session is None:
161+
return False
162+
try:
163+
apply_history_mutations = getattr(session, "apply_history_mutations", None)
164+
except Exception:
165+
return False
166+
return callable(apply_history_mutations)
167+
168+
169+
def apply_session_history_mutations(
170+
items: list[TResponseInputItem],
171+
mutations: list[SessionHistoryMutation],
172+
) -> list[TResponseInputItem]:
173+
"""Apply structured history mutations to a list of persisted session items."""
174+
next_items = [copy.deepcopy(item) for item in items]
175+
for mutation in mutations:
176+
if mutation["type"] == "replace_function_call":
177+
next_items = _apply_replace_function_call_mutation(next_items, mutation)
178+
return next_items
179+
180+
181+
def _apply_replace_function_call_mutation(
182+
items: list[TResponseInputItem],
183+
mutation: ReplaceFunctionCallSessionHistoryMutation,
184+
) -> list[TResponseInputItem]:
185+
"""Replace the first matching function call and drop later duplicates for the same call id."""
186+
replacement = copy.deepcopy(mutation["replacement"])
187+
next_items: list[TResponseInputItem] = []
188+
kept_replacement = False
189+
190+
for item in items:
191+
if _is_matching_function_call(item, mutation["call_id"]):
192+
if not kept_replacement:
193+
next_items.append(replacement)
194+
kept_replacement = True
195+
continue
196+
next_items.append(item)
197+
198+
return next_items
199+
200+
201+
def _is_matching_function_call(item: TResponseInputItem, call_id: str) -> bool:
202+
if isinstance(item, dict):
203+
return item.get("type") == "function_call" and item.get("call_id") == call_id
204+
item_type = getattr(item, "type", None)
205+
item_call_id = getattr(item, "call_id", None)
206+
return item_type == "function_call" and item_call_id == call_id
207+
208+
107209
class OpenAIResponsesCompactionArgs(TypedDict, total=False):
108210
"""Arguments for the run_compaction method."""
109211

0 commit comments

Comments
 (0)