Skip to content

Commit 9e0f7af

Browse files
committed
feat: add approval override argument parity for RunState resumes
1 parent fb67680 commit 9e0f7af

17 files changed

Lines changed: 1592 additions & 68 deletions

src/agents/memory/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,19 @@
11
from .openai_conversations_session import OpenAIConversationsSession
22
from .openai_responses_compaction_session import OpenAIResponsesCompactionSession
33
from .session import (
4+
SERVER_MANAGED_CONVERSATION_SESSION_ATTR,
45
OpenAIResponsesCompactionArgs,
56
OpenAIResponsesCompactionAwareSession,
7+
ServerManagedConversationSession,
68
Session,
79
SessionABC,
10+
SessionHistoryMutation,
11+
SessionHistoryRewriteArgs,
12+
SessionHistoryRewriteAwareSession,
13+
apply_session_history_mutations,
814
is_openai_responses_compaction_aware_session,
15+
is_server_managed_conversation_session,
16+
is_session_history_rewrite_aware_session,
917
)
1018
from .session_settings import SessionSettings
1119
from .sqlite_session import SQLiteSession
@@ -21,5 +29,13 @@
2129
"OpenAIResponsesCompactionSession",
2230
"OpenAIResponsesCompactionArgs",
2331
"OpenAIResponsesCompactionAwareSession",
32+
"SERVER_MANAGED_CONVERSATION_SESSION_ATTR",
33+
"SessionHistoryMutation",
34+
"SessionHistoryRewriteArgs",
35+
"SessionHistoryRewriteAwareSession",
36+
"ServerManagedConversationSession",
37+
"apply_session_history_mutations",
38+
"is_server_managed_conversation_session",
2439
"is_openai_responses_compaction_aware_session",
40+
"is_session_history_rewrite_aware_session",
2541
]

src/agents/memory/openai_conversations_session.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ async def start_openai_conversations_session(openai_client: AsyncOpenAI | None =
2121

2222

2323
class OpenAIConversationsSession(SessionABC):
24+
_server_managed_conversation_session = True
2425
session_settings: SessionSettings | None = None
2526

2627
def __init__(

src/agents/memory/openai_responses_compaction_session.py

Lines changed: 132 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
OpenAIResponsesCompactionArgs,
1313
OpenAIResponsesCompactionAwareSession,
1414
SessionABC,
15+
SessionHistoryRewriteArgs,
16+
apply_session_history_mutations,
17+
is_session_history_rewrite_aware_session,
1518
)
1619

1720
if TYPE_CHECKING:
@@ -130,48 +133,87 @@ def __init__(
130133
self._session_items: list[TResponseInputItem] | None = None
131134
self._response_id: str | None = None
132135
self._deferred_response_id: str | None = None
133-
self._last_unstored_response_id: str | None = None
136+
self._last_store: bool | None = None
137+
self._has_pending_local_history_rewrite = False
138+
self._local_history_rewrite_response_id: str | None = None
139+
self._has_unacknowledged_local_session_adds = False
134140

135141
@property
136142
def client(self) -> AsyncOpenAI:
137143
if self._client is None:
138144
self._client = get_default_openai_client() or AsyncOpenAI()
139145
return self._client
140146

141-
def _resolve_compaction_mode_for_response(
147+
def _resolve_compaction_mode(
142148
self,
143149
*,
150+
requested_mode: OpenAIResponsesCompactionMode,
144151
response_id: str | None,
145152
store: bool | None,
146-
requested_mode: OpenAIResponsesCompactionMode | None,
153+
turn_has_local_adds_without_new_response_id: bool,
147154
) -> _ResolvedCompactionMode:
148-
mode = requested_mode or self.compaction_mode
155+
resolved_mode = _resolve_compaction_mode(
156+
requested_mode,
157+
response_id=response_id,
158+
store=store,
159+
)
160+
161+
if turn_has_local_adds_without_new_response_id and resolved_mode == "previous_response_id":
162+
self._has_unacknowledged_local_session_adds = False
163+
self._mark_local_history_rewrite()
164+
logger.debug(
165+
"compact: forcing input mode after local session delta without new response id"
166+
)
167+
return "input"
168+
169+
if not self._has_pending_local_history_rewrite:
170+
return resolved_mode
171+
149172
if (
150-
mode == "auto"
151-
and store is None
173+
self._local_history_rewrite_response_id is not None
152174
and response_id is not None
153-
and response_id == self._last_unstored_response_id
175+
and response_id != self._local_history_rewrite_response_id
154176
):
177+
self._has_pending_local_history_rewrite = False
178+
self._local_history_rewrite_response_id = None
179+
return resolved_mode
180+
181+
if resolved_mode == "previous_response_id":
182+
if self._local_history_rewrite_response_id is None and response_id is not None:
183+
self._local_history_rewrite_response_id = response_id
184+
logger.debug("compact: forcing input mode after local history rewrite")
155185
return "input"
156-
return _resolve_compaction_mode(mode, response_id=response_id, store=store)
186+
187+
return resolved_mode
157188

158189
async def run_compaction(self, args: OpenAIResponsesCompactionArgs | None = None) -> None:
159190
"""Run compaction using responses.compact API."""
191+
previous_response_id = self._response_id
160192
if args and args.get("response_id"):
161193
self._response_id = args["response_id"]
162194
requested_mode = args.get("compaction_mode") if args else None
163195
if args and "store" in args:
164-
store = args["store"]
165-
if store is False and self._response_id:
166-
self._last_unstored_response_id = self._response_id
167-
elif store is True and self._response_id == self._last_unstored_response_id:
168-
self._last_unstored_response_id = None
196+
store: bool | None = args["store"]
197+
self._last_store = store
169198
else:
170-
store = None
171-
resolved_mode = self._resolve_compaction_mode_for_response(
199+
store = self._last_store
200+
turn_has_local_adds_without_new_response_id = (
201+
self._has_unacknowledged_local_session_adds
202+
and (args is None or args.get("response_id") in {None, previous_response_id})
203+
)
204+
if (
205+
args
206+
and args.get("response_id") is not None
207+
and args["response_id"] != previous_response_id
208+
):
209+
self._has_unacknowledged_local_session_adds = False
210+
resolved_mode = self._resolve_compaction_mode(
172211
response_id=self._response_id,
173212
store=store,
174-
requested_mode=requested_mode,
213+
requested_mode=requested_mode or self.compaction_mode,
214+
turn_has_local_adds_without_new_response_id=(
215+
turn_has_local_adds_without_new_response_id
216+
),
175217
)
176218

177219
if resolved_mode == "previous_response_id" and not self._response_id:
@@ -199,6 +241,15 @@ async def run_compaction(self, args: OpenAIResponsesCompactionArgs | None = None
199241
)
200242
return
201243

244+
unresolved_function_calls = _find_unresolved_function_calls_without_results(session_items)
245+
if unresolved_function_calls:
246+
logger.debug(
247+
"compact: blocked unresolved function calls for %s: %s",
248+
self._response_id,
249+
unresolved_function_calls,
250+
)
251+
return
252+
202253
self._deferred_response_id = None
203254
logger.debug(
204255
f"compact: start for {self._response_id} using {self.model} (mode={resolved_mode})"
@@ -242,14 +293,37 @@ async def run_compaction(self, args: OpenAIResponsesCompactionArgs | None = None
242293
async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
243294
return await self.underlying_session.get_items(limit)
244295

296+
async def apply_history_mutations(self, args: SessionHistoryRewriteArgs) -> None:
297+
"""Rewrite persisted history and keep compaction caches aligned with the new transcript."""
298+
mutations = list(args.get("mutations", []))
299+
if not mutations:
300+
return
301+
302+
if is_session_history_rewrite_aware_session(self.underlying_session):
303+
await self.underlying_session.apply_history_mutations({"mutations": mutations})
304+
await self._refresh_caches_from_underlying_session()
305+
self._mark_local_history_rewrite()
306+
return
307+
308+
rewritten_items = apply_session_history_mutations(
309+
await self.underlying_session.get_items(),
310+
mutations,
311+
)
312+
await self.underlying_session.clear_session()
313+
if rewritten_items:
314+
await self.underlying_session.add_items(rewritten_items)
315+
self._session_items = rewritten_items
316+
self._compaction_candidate_items = select_compaction_candidate_items(rewritten_items)
317+
self._mark_local_history_rewrite()
318+
245319
async def _defer_compaction(self, response_id: str, store: bool | None = None) -> None:
246320
if self._deferred_response_id is not None:
247321
return
248322
compaction_candidate_items, session_items = await self._ensure_compaction_candidates()
249-
resolved_mode = self._resolve_compaction_mode_for_response(
323+
resolved_mode = _resolve_compaction_mode(
324+
self.compaction_mode,
250325
response_id=response_id,
251-
store=store,
252-
requested_mode=None,
326+
store=store if store is not None else self._last_store,
253327
)
254328
should_compact = self.should_trigger_compaction(
255329
{
@@ -269,7 +343,10 @@ def _clear_deferred_compaction(self) -> None:
269343
self._deferred_response_id = None
270344

271345
async def add_items(self, items: list[TResponseInputItem]) -> None:
346+
if not items:
347+
return
272348
await self.underlying_session.add_items(items)
349+
self._has_unacknowledged_local_session_adds = True
273350
if self._compaction_candidate_items is not None:
274351
new_items = _normalize_compaction_session_items(items)
275352
new_candidates = select_compaction_candidate_items(new_items)
@@ -290,6 +367,15 @@ async def clear_session(self) -> None:
290367
self._compaction_candidate_items = []
291368
self._session_items = []
292369
self._deferred_response_id = None
370+
self._has_pending_local_history_rewrite = False
371+
self._local_history_rewrite_response_id = None
372+
self._has_unacknowledged_local_session_adds = False
373+
self._last_store = None
374+
375+
async def _refresh_caches_from_underlying_session(self) -> None:
376+
history = await self.underlying_session.get_items()
377+
self._session_items = history
378+
self._compaction_candidate_items = select_compaction_candidate_items(history)
293379

294380
async def _ensure_compaction_candidates(
295381
self,
@@ -308,6 +394,10 @@ async def _ensure_compaction_candidates(
308394
)
309395
return (candidates[:], history[:])
310396

397+
def _mark_local_history_rewrite(self) -> None:
398+
self._has_pending_local_history_rewrite = True
399+
self._local_history_rewrite_response_id = self._response_id
400+
311401

312402
def _strip_orphaned_assistant_ids(
313403
items: list[TResponseInputItem],
@@ -348,6 +438,29 @@ def _normalize_compaction_session_items(
348438
_ResolvedCompactionMode = Literal["previous_response_id", "input"]
349439

350440

441+
def _find_unresolved_function_calls_without_results(items: list[TResponseInputItem]) -> list[str]:
442+
"""Return function-call ids that do not yet have matching outputs."""
443+
function_calls: dict[str, TResponseInputItem] = {}
444+
resolved_call_ids: set[str] = set()
445+
446+
for item in items:
447+
if isinstance(item, dict):
448+
item_type = item.get("type")
449+
call_id = item.get("call_id")
450+
else:
451+
item_type = getattr(item, "type", None)
452+
call_id = getattr(item, "call_id", None)
453+
454+
if not isinstance(call_id, str):
455+
continue
456+
if item_type == "function_call":
457+
function_calls[call_id] = item
458+
elif item_type == "function_call_output":
459+
resolved_call_ids.add(call_id)
460+
461+
return [call_id for call_id in function_calls if call_id not in resolved_call_ids]
462+
463+
351464
def _resolve_compaction_mode(
352465
requested_mode: OpenAIResponsesCompactionMode,
353466
*,

0 commit comments

Comments
 (0)