Skip to content

Commit a7d0635

Browse files
committed
Pass context wrappers to session methods
1 parent e80d2d2 commit a7d0635

21 files changed

+617
-88
lines changed

examples/memory/file_session.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,16 @@
1010
import json
1111
from datetime import datetime
1212
from pathlib import Path
13-
from typing import Any
13+
from typing import TYPE_CHECKING, Any
1414
from uuid import uuid4
1515

1616
from agents.memory.session import Session
1717
from agents.memory.session_settings import SessionSettings
1818

19+
if TYPE_CHECKING:
20+
from agents.items import TResponseInputItem
21+
from agents.run_context import RunContextWrapper
22+
1923

2024
class FileSession(Session):
2125
"""Persist session items to a JSON file on disk."""
@@ -43,14 +47,26 @@ async def get_session_id(self) -> str:
4347
"""Return the session id, creating one if needed."""
4448
return await self._ensure_session_id()
4549

46-
async def get_items(self, limit: int | None = None) -> list[Any]:
50+
async def get_items(
51+
self,
52+
limit: int | None = None,
53+
*,
54+
wrapper: RunContextWrapper[Any] | None = None,
55+
) -> list[TResponseInputItem]:
56+
del wrapper
4757
session_id = await self._ensure_session_id()
4858
items = await self._read_items(session_id)
4959
if limit is not None and limit >= 0:
5060
return items[-limit:]
5161
return items
5262

53-
async def add_items(self, items: list[Any]) -> None:
63+
async def add_items(
64+
self,
65+
items: list[TResponseInputItem],
66+
*,
67+
wrapper: RunContextWrapper[Any] | None = None,
68+
) -> None:
69+
del wrapper
5470
if not items:
5571
return
5672
session_id = await self._ensure_session_id()
@@ -59,7 +75,12 @@ async def add_items(self, items: list[Any]) -> None:
5975
cloned = json.loads(json.dumps(items))
6076
await self._write_items(session_id, current + cloned)
6177

62-
async def pop_item(self) -> Any | None:
78+
async def pop_item(
79+
self,
80+
*,
81+
wrapper: RunContextWrapper[Any] | None = None,
82+
) -> TResponseInputItem | None:
83+
del wrapper
6384
session_id = await self._ensure_session_id()
6485
items = await self._read_items(session_id)
6586
if not items:
@@ -89,7 +110,7 @@ def _items_path(self, session_id: str) -> Path:
89110
def _state_path(self, session_id: str) -> Path:
90111
return self._dir / f"{session_id}-state.json"
91112

92-
async def _read_items(self, session_id: str) -> list[Any]:
113+
async def _read_items(self, session_id: str) -> list[TResponseInputItem]:
93114
file_path = self._items_path(session_id)
94115
try:
95116
data = await asyncio.to_thread(file_path.read_text, "utf-8")
@@ -98,7 +119,7 @@ async def _read_items(self, session_id: str) -> list[Any]:
98119
except FileNotFoundError:
99120
return []
100121

101-
async def _write_items(self, session_id: str, items: list[Any]) -> None:
122+
async def _write_items(self, session_id: str, items: list[TResponseInputItem]) -> None:
102123
file_path = self._items_path(session_id)
103124
payload = json.dumps(items, indent=2, ensure_ascii=False)
104125
await asyncio.to_thread(self._dir.mkdir, parents=True, exist_ok=True)

src/agents/extensions/memory/advanced_sqlite_session.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,12 @@ def _init_structure_tables(self):
121121

122122
conn.commit()
123123

124-
async def add_items(self, items: list[TResponseInputItem]) -> None:
124+
async def add_items(
125+
self,
126+
items: list[TResponseInputItem],
127+
*,
128+
wrapper: Any = None,
129+
) -> None:
125130
"""Add items to the session.
126131
127132
Args:
@@ -160,6 +165,8 @@ async def get_items(
160165
self,
161166
limit: int | None = None,
162167
branch_id: str | None = None,
168+
*,
169+
wrapper: Any = None,
163170
) -> list[TResponseInputItem]:
164171
"""Get items from current or specified branch.
165172

src/agents/extensions/memory/async_sqlite_session.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections.abc import AsyncIterator
66
from contextlib import asynccontextmanager
77
from pathlib import Path
8-
from typing import cast
8+
from typing import Any, cast
99

1010
import aiosqlite
1111

@@ -102,7 +102,12 @@ async def _locked_connection(self) -> AsyncIterator[aiosqlite.Connection]:
102102
conn = await self._get_connection()
103103
yield conn
104104

105-
async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
105+
async def get_items(
106+
self,
107+
limit: int | None = None,
108+
*,
109+
wrapper: Any = None,
110+
) -> list[TResponseInputItem]:
106111
"""Retrieve the conversation history for this session.
107112
108113
Args:
@@ -150,7 +155,12 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
150155

151156
return items
152157

153-
async def add_items(self, items: list[TResponseInputItem]) -> None:
158+
async def add_items(
159+
self,
160+
items: list[TResponseInputItem],
161+
*,
162+
wrapper: Any = None,
163+
) -> None:
154164
"""Add new items to the conversation history.
155165
156166
Args:
@@ -186,7 +196,7 @@ async def add_items(self, items: list[TResponseInputItem]) -> None:
186196

187197
await conn.commit()
188198

189-
async def pop_item(self) -> TResponseInputItem | None:
199+
async def pop_item(self, *, wrapper: Any = None) -> TResponseInputItem | None:
190200
"""Remove and return the most recent item from the session.
191201
192202
Returns:

src/agents/extensions/memory/dapr_session.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,12 @@ async def _handle_concurrency_conflict(self, error: Exception, attempt: int) ->
232232
# Session protocol implementation
233233
# ------------------------------------------------------------------
234234

235-
async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
235+
async def get_items(
236+
self,
237+
limit: int | None = None,
238+
*,
239+
wrapper: Any = None,
240+
) -> list[TResponseInputItem]:
236241
"""Retrieve the conversation history for this session.
237242
238243
Args:
@@ -271,7 +276,12 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
271276
continue
272277
return items
273278

274-
async def add_items(self, items: list[TResponseInputItem]) -> None:
279+
async def add_items(
280+
self,
281+
items: list[TResponseInputItem],
282+
*,
283+
wrapper: Any = None,
284+
) -> None:
275285
"""Add new items to the conversation history.
276286
277287
Args:
@@ -324,7 +334,7 @@ async def add_items(self, items: list[TResponseInputItem]) -> None:
324334
options=self._get_state_options(),
325335
)
326336

327-
async def pop_item(self) -> TResponseInputItem | None:
337+
async def pop_item(self, *, wrapper: Any = None) -> TResponseInputItem | None:
328338
"""Remove and return the most recent item from the session.
329339
330340
Returns:

src/agents/extensions/memory/encrypt_session.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
from typing_extensions import TypedDict
3838

3939
from ...items import TResponseInputItem
40-
from ...memory.session import SessionABC
40+
from ...memory.session import SessionABC, add_session_items, get_session_items, pop_session_item
4141
from ...memory.session_settings import SessionSettings
4242

4343

@@ -170,22 +170,40 @@ def _unwrap(self, item: TResponseInputItem | EncryptedEnvelope) -> TResponseInpu
170170
except (InvalidToken, KeyError):
171171
return None
172172

173-
async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
174-
encrypted_items = await self.underlying_session.get_items(limit)
173+
async def get_items(
174+
self,
175+
limit: int | None = None,
176+
*,
177+
wrapper: Any = None,
178+
) -> list[TResponseInputItem]:
179+
encrypted_items = await get_session_items(
180+
self.underlying_session,
181+
limit,
182+
wrapper=cast(Any, wrapper),
183+
)
175184
valid_items: list[TResponseInputItem] = []
176185
for enc in encrypted_items:
177186
item = self._unwrap(enc)
178187
if item is not None:
179188
valid_items.append(item)
180189
return valid_items
181190

182-
async def add_items(self, items: list[TResponseInputItem]) -> None:
191+
async def add_items(
192+
self,
193+
items: list[TResponseInputItem],
194+
*,
195+
wrapper: Any = None,
196+
) -> None:
183197
wrapped: list[EncryptedEnvelope] = [self._wrap(it) for it in items]
184-
await self.underlying_session.add_items(cast(list[TResponseInputItem], wrapped))
198+
await add_session_items(
199+
self.underlying_session,
200+
cast(list[TResponseInputItem], wrapped),
201+
wrapper=cast(Any, wrapper),
202+
)
185203

186-
async def pop_item(self) -> TResponseInputItem | None:
204+
async def pop_item(self, *, wrapper: Any = None) -> TResponseInputItem | None:
187205
while True:
188-
enc = await self.underlying_session.pop_item()
206+
enc = await pop_session_item(self.underlying_session, wrapper=cast(Any, wrapper))
189207
if not enc:
190208
return None
191209
item = self._unwrap(enc)

src/agents/extensions/memory/mongodb_session.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,12 @@ async def _deserialize_item(self, raw: str) -> TResponseInputItem:
241241
# Session protocol implementation
242242
# ------------------------------------------------------------------
243243

244-
async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
244+
async def get_items(
245+
self,
246+
limit: int | None = None,
247+
*,
248+
wrapper: Any = None,
249+
) -> list[TResponseInputItem]:
245250
"""Retrieve the conversation history for this session.
246251
247252
Args:
@@ -283,7 +288,12 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
283288

284289
return items
285290

286-
async def add_items(self, items: list[TResponseInputItem]) -> None:
291+
async def add_items(
292+
self,
293+
items: list[TResponseInputItem],
294+
*,
295+
wrapper: Any = None,
296+
) -> None:
287297
"""Add new items to the conversation history.
288298
289299
Args:
@@ -319,7 +329,7 @@ async def add_items(self, items: list[TResponseInputItem]) -> None:
319329

320330
await self._messages.insert_many(payload, ordered=True)
321331

322-
async def pop_item(self) -> TResponseInputItem | None:
332+
async def pop_item(self, *, wrapper: Any = None) -> TResponseInputItem | None:
323333
"""Remove and return the most recent item from the session.
324334
325335
Returns:

src/agents/extensions/memory/redis_session.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,12 @@ async def _set_ttl_if_configured(self, *keys: str) -> None:
140140
# Session protocol implementation
141141
# ------------------------------------------------------------------
142142

143-
async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
143+
async def get_items(
144+
self,
145+
limit: int | None = None,
146+
*,
147+
wrapper: Any = None,
148+
) -> list[TResponseInputItem]:
144149
"""Retrieve the conversation history for this session.
145150
146151
Args:
@@ -179,7 +184,12 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
179184

180185
return items
181186

182-
async def add_items(self, items: list[TResponseInputItem]) -> None:
187+
async def add_items(
188+
self,
189+
items: list[TResponseInputItem],
190+
*,
191+
wrapper: Any = None,
192+
) -> None:
183193
"""Add new items to the conversation history.
184194
185195
Args:
@@ -221,7 +231,7 @@ async def add_items(self, items: list[TResponseInputItem]) -> None:
221231
self._session_key, self._messages_key, self._counter_key
222232
)
223233

224-
async def pop_item(self) -> TResponseInputItem | None:
234+
async def pop_item(self, *, wrapper: Any = None) -> TResponseInputItem | None:
225235
"""Remove and return the most recent item from the session.
226236
227237
Returns:

src/agents/extensions/memory/sqlalchemy_session.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,12 @@ async def _ensure_tables(self) -> None:
274274
finally:
275275
self._init_lock.release()
276276

277-
async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
277+
async def get_items(
278+
self,
279+
limit: int | None = None,
280+
*,
281+
wrapper: Any = None,
282+
) -> list[TResponseInputItem]:
278283
"""Retrieve the conversation history for this session.
279284
280285
Args:
@@ -326,7 +331,12 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
326331
continue
327332
return items
328333

329-
async def add_items(self, items: list[TResponseInputItem]) -> None:
334+
async def add_items(
335+
self,
336+
items: list[TResponseInputItem],
337+
*,
338+
wrapper: Any = None,
339+
) -> None:
330340
"""Add new items to the conversation history.
331341
332342
Args:
@@ -376,7 +386,7 @@ async def _write_items() -> None:
376386

377387
await self._run_sqlite_write_with_retry(_write_items)
378388

379-
async def pop_item(self) -> TResponseInputItem | None:
389+
async def pop_item(self, *, wrapper: Any = None) -> TResponseInputItem | None:
380390
"""Remove and return the most recent item from the session.
381391
382392
Returns:

src/agents/memory/openai_conversations_session.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
from typing import Any
4+
35
from openai import AsyncOpenAI
46

57
from agents.models._openai_shared import get_default_openai_client
@@ -70,7 +72,12 @@ async def _get_session_id(self) -> str:
7072
async def _clear_session_id(self) -> None:
7173
self._session_id = None
7274

73-
async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
75+
async def get_items(
76+
self,
77+
limit: int | None = None,
78+
*,
79+
wrapper: Any = None,
80+
) -> list[TResponseInputItem]:
7481
session_id = await self._get_session_id()
7582

7683
session_limit = resolve_session_limit(limit, self.session_settings)
@@ -97,7 +104,12 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
97104

98105
return all_items # type: ignore
99106

100-
async def add_items(self, items: list[TResponseInputItem]) -> None:
107+
async def add_items(
108+
self,
109+
items: list[TResponseInputItem],
110+
*,
111+
wrapper: Any = None,
112+
) -> None:
101113
session_id = await self._get_session_id()
102114
if not items:
103115
return
@@ -107,9 +119,9 @@ async def add_items(self, items: list[TResponseInputItem]) -> None:
107119
items=items,
108120
)
109121

110-
async def pop_item(self) -> TResponseInputItem | None:
122+
async def pop_item(self, *, wrapper: Any = None) -> TResponseInputItem | None:
111123
session_id = await self._get_session_id()
112-
items = await self.get_items(limit=1)
124+
items = await self.get_items(limit=1, wrapper=wrapper)
113125
if not items:
114126
return None
115127
item_id: str = str(items[0]["id"]) # type: ignore [typeddict-item]

0 commit comments

Comments
 (0)