Skip to content

Commit 74fc3b6

Browse files
alexbevi and claude committed
feat: add MongoDBSession extension using pymongo async API
Implements MongoDBSession under src/agents/extensions/memory/, following the extensions directory structure established in issue #1328. Uses the native AsyncMongoClient (pymongo.asynchronous) shipped with pymongo>=4.13 rather than Motor.

- Two-collection schema: agent_sessions (metadata) + agent_messages (items)
- Chronological ordering via ObjectId; compound index on (session_id, _id)
- Idempotent one-shot index creation with per-key asyncio.Lock
- from_uri() factory with owned-client lifecycle tracking
- Supports SessionSettings.limit and explicit per-call limit overrides
- Gracefully skips corrupted/missing message_data documents
- ping() / close() lifecycle helpers; close() only touches owned clients

Adds 26 tests in tests/extensions/memory/test_mongodb_session.py using in-process fake pymongo types injected via sys.modules — no real MongoDB or pymongo installation is required to run the suite.

Install: pip install openai-agents[mongodb]

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 09ea6aa commit 74fc3b6

File tree

5 files changed

+1047
-1
lines changed

5 files changed

+1047
-1
lines changed

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ sqlalchemy = ["SQLAlchemy>=2.0", "asyncpg>=0.29.0"]
4444
encrypt = ["cryptography>=45.0, <46"]
4545
redis = ["redis>=7"]
4646
dapr = ["dapr>=1.16.0", "grpcio>=1.60.0"]
47+
mongodb = ["pymongo>=4.13"]
4748
docker = ["docker>=6.1"]
4849
blaxel = ["blaxel>=0.2.50", "aiohttp>=3.12,<4"]
4950
daytona = ["daytona>=0.155.0"]
@@ -154,6 +155,10 @@ ignore_missing_imports = true
154155
module = ["runloop_api_client", "runloop_api_client.*"]
155156
ignore_missing_imports = true
156157

158+
[[tool.mypy.overrides]]
159+
module = ["pymongo", "pymongo.*"]
160+
ignore_missing_imports = true
161+
157162
[[tool.mypy.overrides]]
158163
module = ["blaxel", "blaxel.*"]
159164
ignore_missing_imports = true

src/agents/extensions/memory/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
DaprSession,
2020
)
2121
from .encrypt_session import EncryptedSession
22+
from .mongodb_session import MongoDBSession
2223
from .redis_session import RedisSession
2324
from .sqlalchemy_session import SQLAlchemySession
2425

@@ -29,6 +30,7 @@
2930
"DAPR_CONSISTENCY_STRONG",
3031
"DaprSession",
3132
"EncryptedSession",
33+
"MongoDBSession",
3234
"RedisSession",
3335
"SQLAlchemySession",
3436
]
@@ -117,4 +119,15 @@ def __getattr__(name: str) -> Any:
117119
"Install it with: pip install openai-agents[dapr]"
118120
) from e
119121

122+
if name == "MongoDBSession":
123+
try:
124+
from .mongodb_session import MongoDBSession # noqa: F401
125+
126+
return MongoDBSession
127+
except ModuleNotFoundError as e:
128+
raise ImportError(
129+
"MongoDBSession requires the 'mongodb' extra. "
130+
"Install it with: pip install openai-agents[mongodb]"
131+
) from e
132+
120133
raise AttributeError(f"module {__name__} has no attribute {name}")
Lines changed: 321 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,321 @@
1+
"""MongoDB-powered Session backend.
2+
3+
Requires ``pymongo>=4.13``, which ships the native async API
4+
(``AsyncMongoClient``). Install it with::
5+
6+
pip install openai-agents[mongodb]
7+
8+
Usage::
9+
10+
from agents.extensions.memory import MongoDBSession
11+
12+
# Create from MongoDB URI
13+
session = MongoDBSession.from_uri(
14+
session_id="user-123",
15+
uri="mongodb://localhost:27017",
16+
database="agents",
17+
)
18+
19+
# Or pass an existing AsyncMongoClient that your application already manages
20+
from pymongo.asynchronous.mongo_client import AsyncMongoClient
21+
22+
client = AsyncMongoClient("mongodb://localhost:27017")
23+
session = MongoDBSession(
24+
session_id="user-123",
25+
client=client,
26+
database="agents",
27+
)
28+
29+
await Runner.run(agent, "Hello", session=session)
30+
"""
31+
32+
from __future__ import annotations
33+
34+
import asyncio
35+
import json
36+
from typing import Any
37+
38+
try:
39+
from pymongo.asynchronous.collection import AsyncCollection
40+
from pymongo.asynchronous.mongo_client import AsyncMongoClient
41+
except ImportError as e:
42+
raise ImportError(
43+
"MongoDBSession requires the 'pymongo' package (>=4.13). "
44+
"Install it with: pip install openai-agents[mongodb]"
45+
) from e
46+
47+
from ...items import TResponseInputItem
48+
from ...memory.session import SessionABC
49+
from ...memory.session_settings import SessionSettings, resolve_session_limit
50+
51+
52+
class MongoDBSession(SessionABC):
    """MongoDB implementation of :class:`agents.memory.session.Session`.

    Conversation items are stored as individual documents in a messages
    collection (``agent_messages`` by default). Each document holds the
    owning ``session_id`` and the JSON-serialized item under
    ``message_data``. A lightweight sessions collection (``agent_sessions``
    by default) holds one document per session; it currently records only
    the ``session_id``.

    Indexes are created once per ``(database, sessions_collection,
    messages_collection)`` combination, the first time a session protocol
    method reaches the database. Subsequent calls skip the setup entirely.

    Note:
        The index registry is keyed by database/collection *names* only, not
        by client. Two clients that use the same names share one registry
        entry, so index creation runs only for whichever is used first.
    """

    # Class-level registry so index creation only runs once per unique key.
    _initialized_keys: set[tuple[str, str, str]] = set()
    _init_locks: dict[tuple[str, str, str], asyncio.Lock] = {}
    _init_locks_guard: asyncio.Lock = asyncio.Lock()

    session_settings: SessionSettings | None = None

    def __init__(
        self,
        session_id: str,
        *,
        client: AsyncMongoClient,
        database: str = "agents",
        sessions_collection: str = "agent_sessions",
        messages_collection: str = "agent_messages",
        session_settings: SessionSettings | None = None,
    ):
        """Initialize a new MongoDBSession.

        Args:
            session_id: Unique identifier for the conversation.
            client: A pre-configured ``AsyncMongoClient`` instance. The
                session never closes a client passed in here; the caller
                keeps ownership of it.
            database: Name of the MongoDB database to use.
                Defaults to ``"agents"``.
            sessions_collection: Name of the collection that stores session
                metadata. Defaults to ``"agent_sessions"``.
            messages_collection: Name of the collection that stores individual
                conversation items. Defaults to ``"agent_messages"``.
            session_settings: Optional session configuration. When ``None`` a
                default :class:`~agents.memory.session_settings.SessionSettings`
                is used.
        """
        self.session_id = session_id
        self.session_settings = session_settings or SessionSettings()
        self._client = client
        # Flipped to True only by from_uri(), which creates the client itself.
        self._owns_client = False

        db = client[database]
        self._sessions: AsyncCollection = db[sessions_collection]
        self._messages: AsyncCollection = db[messages_collection]

        # Identifies this database/collection combination in the class-level
        # index registry.
        self._init_key = (database, sessions_collection, messages_collection)

    # ------------------------------------------------------------------
    # Convenience constructors
    # ------------------------------------------------------------------

    @classmethod
    def from_uri(
        cls,
        session_id: str,
        *,
        uri: str,
        database: str = "agents",
        client_kwargs: dict[str, Any] | None = None,
        session_settings: SessionSettings | None = None,
        **kwargs: Any,
    ) -> MongoDBSession:
        """Create a session from a MongoDB URI string.

        The session creates its own ``AsyncMongoClient`` and is marked as its
        owner, so :meth:`close` will close that client.

        Args:
            session_id: Conversation ID.
            uri: MongoDB connection URI,
                e.g. ``"mongodb://localhost:27017"`` or
                ``"mongodb+srv://user:pass@cluster.example.com"``.
            database: Name of the MongoDB database to use.
            client_kwargs: Additional keyword arguments forwarded to
                :class:`pymongo.asynchronous.mongo_client.AsyncMongoClient`.
            session_settings: Optional session configuration settings.
            **kwargs: Additional keyword arguments forwarded to the main
                constructor (e.g. ``sessions_collection``,
                ``messages_collection``).

        Returns:
            A :class:`MongoDBSession` bound to a client created for ``uri``.
        """
        client_kwargs = client_kwargs or {}
        client: AsyncMongoClient = AsyncMongoClient(uri, **client_kwargs)
        session = cls(
            session_id,
            client=client,
            database=database,
            session_settings=session_settings,
            **kwargs,
        )
        session._owns_client = True
        return session

    # ------------------------------------------------------------------
    # Index initialisation
    # ------------------------------------------------------------------

    async def _get_init_lock(self) -> asyncio.Lock:
        """Return (creating if necessary) the per-init-key asyncio Lock."""
        async with self._init_locks_guard:
            lock = self._init_locks.get(self._init_key)
            if lock is None:
                lock = asyncio.Lock()
                self._init_locks[self._init_key] = lock
            return lock

    async def _ensure_indexes(self) -> None:
        """Create required indexes the first time this key is accessed."""
        # Fast path: no lock needed once the key has been registered.
        if self._init_key in self._initialized_keys:
            return

        lock = await self._get_init_lock()
        async with lock:
            # Re-check under the lock: another coroutine may have finished
            # index creation while this one was waiting.
            if self._init_key in self._initialized_keys:
                return

            # sessions: unique index on session_id.
            await self._sessions.create_index("session_id", unique=True)

            # messages: compound index for efficient per-session retrieval and sorting.
            await self._messages.create_index([("session_id", 1), ("_id", 1)])

            # Registered only after both indexes were created, so a failure
            # above is retried on the next call.
            self._initialized_keys.add(self._init_key)

    # ------------------------------------------------------------------
    # Serialization helpers
    # ------------------------------------------------------------------

    async def _serialize_item(self, item: TResponseInputItem) -> str:
        """Serialize an item to a JSON string. Can be overridden by subclasses."""
        return json.dumps(item, separators=(",", ":"))

    async def _deserialize_item(self, raw: str) -> TResponseInputItem:
        """Deserialize a JSON string to an item. Can be overridden by subclasses."""
        return json.loads(raw)  # type: ignore[no-any-return]

    # ------------------------------------------------------------------
    # Session protocol implementation
    # ------------------------------------------------------------------

    async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
        """Retrieve the conversation history for this session.

        Args:
            limit: Maximum number of items to retrieve. When ``None``, the
                effective limit is taken from :attr:`session_settings`.
                If that is also ``None``, all items are returned.
                A resolved limit of zero or less yields an empty list.
                The returned list is always in chronological (oldest-first)
                order.

        Returns:
            List of input items representing the conversation history.
            Documents whose ``message_data`` is missing, is not a string, or
            is not valid JSON are skipped.
        """
        await self._ensure_indexes()

        session_limit = resolve_session_limit(limit, self.session_settings)

        if session_limit is not None and session_limit <= 0:
            return []

        query = {"session_id": self.session_id}

        if session_limit is None:
            cursor = self._messages.find(query).sort("_id", 1)
            docs = await cursor.to_list()
        else:
            # Fetch the latest N documents in reverse order, then reverse the
            # list to restore chronological order.
            cursor = self._messages.find(query).sort("_id", -1).limit(session_limit)
            docs = await cursor.to_list()
            docs.reverse()

        items: list[TResponseInputItem] = []
        for doc in docs:
            try:
                items.append(await self._deserialize_item(doc["message_data"]))
            except (json.JSONDecodeError, KeyError, TypeError):
                # Skip corrupted or malformed documents. TypeError covers a
                # ``message_data`` value that is present but not a string
                # (e.g. ``None`` or a sub-document), which json.loads rejects.
                continue

        return items

    async def add_items(self, items: list[TResponseInputItem]) -> None:
        """Add new items to the conversation history.

        Does nothing (and does not touch the database) when ``items`` is
        empty.

        Args:
            items: List of input items to append to the session.
        """
        if not items:
            return

        await self._ensure_indexes()

        # Upsert the session metadata document. ``$setOnInsert`` leaves an
        # existing session document untouched.
        await self._sessions.update_one(
            {"session_id": self.session_id},
            {"$setOnInsert": {"session_id": self.session_id}},
            upsert=True,
        )

        payload = [
            {
                "session_id": self.session_id,
                "message_data": await self._serialize_item(item),
            }
            for item in items
        ]

        await self._messages.insert_many(payload, ordered=True)

    async def pop_item(self) -> TResponseInputItem | None:
        """Remove and return the most recent item from the session.

        Returns:
            The most recent item if it exists, ``None`` if the session is
            empty. ``None`` is also returned when the removed document could
            not be decoded; the document has already been deleted by then.
        """
        await self._ensure_indexes()

        doc = await self._messages.find_one_and_delete(
            {"session_id": self.session_id},
            sort=[("_id", -1)],
        )

        if doc is None:
            return None

        try:
            return await self._deserialize_item(doc["message_data"])
        except (json.JSONDecodeError, KeyError, TypeError):
            # Same tolerance as get_items(): missing, non-string, or invalid
            # JSON payloads are treated as "no item" instead of raising.
            return None

    async def clear_session(self) -> None:
        """Clear all items for this session and remove its metadata document."""
        await self._ensure_indexes()
        await self._messages.delete_many({"session_id": self.session_id})
        await self._sessions.delete_one({"session_id": self.session_id})

    # ------------------------------------------------------------------
    # Lifecycle helpers
    # ------------------------------------------------------------------

    async def close(self) -> None:
        """Close the underlying MongoDB connection.

        Only closes the client if this session owns it (i.e. it was created
        via :meth:`from_uri`). If the client was injected externally the
        caller is responsible for managing its lifecycle.
        """
        if self._owns_client:
            await self._client.aclose()

    async def ping(self) -> bool:
        """Test MongoDB connectivity.

        Returns:
            ``True`` if the ``ping`` command succeeds, ``False`` if it raises
            any ``Exception``.
        """
        try:
            await self._client.admin.command("ping")
            return True
        except Exception:
            # Deliberately broad: this is a boolean health probe, so every
            # failure mode is reported as "not reachable".
            return False

0 commit comments

Comments
 (0)