Skip to content

Commit 9b166de

Browse files
alexbeviclaude authored and committed
fix(mongodb): address Codex P1/P2 review feedback (round 3)
- Replace asyncio.Lock process-wide guard with threading.Lock so _get_or_create_init_lock() is safe to call from any event loop - Replace id(client)-keyed dicts with WeakKeyDictionary so entries are pruned when clients are GC'd, preventing stale id reuse from skipping index creation on a new client - Add explicit seq field (monotonic per-session counter via $inc) to every message document and sort by seq instead of _id; ObjectId is only second-level accurate across processes and not reliably monotonic - Update FakeAsyncCollection in tests with find_one_and_update + $inc support; replace _initialized_keys/_init_locks clears with _init_state Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 054d391 commit 9b166de

File tree

2 files changed

+118
-56
lines changed

2 files changed

+118
-56
lines changed

src/agents/extensions/memory/mongodb_session.py

Lines changed: 80 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333

3434
import asyncio
3535
import json
36+
import threading
37+
import weakref
3638
from typing import Any
3739

3840
try:
@@ -70,14 +72,29 @@ class MongoDBSession(SessionABC):
7072
Indexes are created once per ``(client, database, sessions_collection,
7173
messages_collection)`` combination on the first call to any of the
7274
session protocol methods. Subsequent calls skip the setup entirely.
75+
76+
Each message document carries a ``seq`` field — an integer assigned by
77+
atomically incrementing a counter on the session metadata document. This
78+
guarantees a strictly monotonic insertion order that is safe across
79+
multiple writers and processes, unlike sorting by ``_id`` / ObjectId which
80+
is only second-level accurate and non-monotonic across machines.
7381
"""
7482

75-
# Class-level registry so index creation only runs once per unique key.
76-
# The key includes id(client) so that different AsyncMongoClient instances
77-
# (e.g. pointing at different clusters) each get their own init pass.
78-
_initialized_keys: set[tuple[int, str, str, str]] = set()
79-
_init_locks: dict[tuple[int, str, str, str], asyncio.Lock] = {}
80-
_init_locks_guard: asyncio.Lock = asyncio.Lock()
83+
# Class-level registry so index creation runs only once per unique
84+
# (client, database, sessions_collection, messages_collection) combination.
85+
#
86+
# WeakKeyDictionary lets the client object itself be the outer key so
87+
# that the entry is automatically pruned when the client is garbage-
88+
# collected. This prevents id(client) reuse from causing a new client to
89+
# skip index creation because it shares an address with a dead one.
90+
#
91+
# _init_guard is a plain threading.Lock so it is safe to acquire from any
92+
# event loop — an asyncio.Lock binds to the loop that first contends on
93+
# it and raises RuntimeError when a different loop tries to acquire it.
94+
_init_state: weakref.WeakKeyDictionary[
95+
Any, dict[tuple[str, str, str], asyncio.Lock | bool]
96+
] = weakref.WeakKeyDictionary()
97+
_init_guard: threading.Lock = threading.Lock()
8198

8299
session_settings: SessionSettings | None = None
83100

@@ -120,7 +137,9 @@ def __init__(
120137
self._sessions: AsyncCollection[Any] = db[sessions_collection]
121138
self._messages: AsyncCollection[Any] = db[messages_collection]
122139

123-
self._init_key = (id(client), database, sessions_collection, messages_collection)
140+
# Key within the per-client mapping — no id() needed because the client
141+
# object itself is the outer WeakKeyDictionary key.
142+
self._init_sub_key = (database, sessions_collection, messages_collection)
124143

125144
# ------------------------------------------------------------------
126145
# Convenience constructors
@@ -172,33 +191,59 @@ def from_uri(
172191
# Index initialisation
173192
# ------------------------------------------------------------------
174193

175-
async def _get_init_lock(self) -> asyncio.Lock:
176-
"""Return (creating if necessary) the per-init-key asyncio Lock."""
177-
async with self._init_locks_guard:
178-
lock = self._init_locks.get(self._init_key)
179-
if lock is None:
194+
def _get_or_create_init_lock(self) -> tuple[asyncio.Lock, bool]:
195+
"""Return (lock, already_done) for this session's (client, sub_key) pair.
196+
197+
Uses a threading.Lock for the registry mutation so it is safe to call
198+
from any event loop without risking a cross-loop RuntimeError.
199+
"""
200+
with self._init_guard:
201+
per_client = self._init_state.get(self._client)
202+
if per_client is None:
203+
per_client = {}
204+
self._init_state[self._client] = per_client
205+
206+
entry = per_client.get(self._init_sub_key)
207+
if entry is True:
208+
# Already initialised.
209+
return asyncio.Lock(), True
210+
if entry is None:
180211
lock = asyncio.Lock()
181-
self._init_locks[self._init_key] = lock
182-
return lock
212+
per_client[self._init_sub_key] = lock
213+
return lock, False
214+
# entry is an asyncio.Lock — initialisation is in progress.
215+
assert isinstance(entry, asyncio.Lock)
216+
return entry, False
217+
218+
def _mark_init_done(self) -> None:
219+
"""Record that index creation is complete for this (client, sub_key)."""
220+
with self._init_guard:
221+
per_client = self._init_state.get(self._client)
222+
if per_client is not None:
223+
per_client[self._init_sub_key] = True
183224

184225
async def _ensure_indexes(self) -> None:
185-
"""Create required indexes the first time this key is accessed."""
186-
if self._init_key in self._initialized_keys:
226+
"""Create required indexes the first time this (client, sub_key) is accessed."""
227+
lock, done = self._get_or_create_init_lock()
228+
if done:
187229
return
188230

189-
lock = await self._get_init_lock()
190231
async with lock:
191232
# Double-checked locking: another coroutine may have finished first.
192-
if self._init_key in self._initialized_keys:
233+
_, done = self._get_or_create_init_lock()
234+
if done:
193235
return
194236

195237
# sessions: unique index on session_id.
196238
await self._sessions.create_index("session_id", unique=True)
197239

198240
# messages: compound index for efficient per-session retrieval and sorting.
199-
await self._messages.create_index([("session_id", 1), ("_id", 1)])
241+
# seq provides a strictly monotonic insertion-order tie-breaker that is
242+
# reliable across multiple writers (unlike ObjectId which is only
243+
# second-level accurate).
244+
await self._messages.create_index([("session_id", 1), ("seq", 1)])
200245

201-
self._initialized_keys.add(self._init_key)
246+
self._mark_init_done()
202247

203248
# ------------------------------------------------------------------
204249
# Serialization helpers
@@ -239,12 +284,12 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
239284
query = {"session_id": self.session_id}
240285

241286
if session_limit is None:
242-
cursor = self._messages.find(query).sort("_id", 1)
287+
cursor = self._messages.find(query).sort("seq", 1)
243288
docs = await cursor.to_list()
244289
else:
245290
# Fetch the latest N documents in reverse order, then reverse the
246291
# list to restore chronological order.
247-
cursor = self._messages.find(query).sort("_id", -1).limit(session_limit)
292+
cursor = self._messages.find(query).sort("seq", -1).limit(session_limit)
248293
docs = await cursor.to_list()
249294
docs.reverse()
250295

@@ -269,19 +314,27 @@ async def add_items(self, items: list[TResponseInputItem]) -> None:
269314

270315
await self._ensure_indexes()
271316

272-
# Upsert the session metadata document.
273-
await self._sessions.update_one(
317+
# Atomically reserve a block of sequence numbers for this batch.
318+
# $inc returns the new value, so subtract len(items) to get the first
319+
# number in the block.
320+
result = await self._sessions.find_one_and_update(
274321
{"session_id": self.session_id},
275-
{"$setOnInsert": {"session_id": self.session_id}},
322+
{
323+
"$setOnInsert": {"session_id": self.session_id},
324+
"$inc": {"_seq": len(items)},
325+
},
276326
upsert=True,
327+
return_document=True,
277328
)
329+
next_seq: int = (result["_seq"] if result else len(items)) - len(items)
278330

279331
payload = [
280332
{
281333
"session_id": self.session_id,
334+
"seq": next_seq + i,
282335
"message_data": await self._serialize_item(item),
283336
}
284-
for item in items
337+
for i, item in enumerate(items)
285338
]
286339

287340
await self._messages.insert_many(payload, ordered=True)
@@ -296,7 +349,7 @@ async def pop_item(self) -> TResponseInputItem | None:
296349

297350
doc = await self._messages.find_one_and_delete(
298351
{"session_id": self.session_id},
299-
sort=[("_id", -1)],
352+
sort=[("seq", -1)],
300353
)
301354

302355
if doc is None:

tests/extensions/memory/test_mongodb_session.py

Lines changed: 38 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -115,19 +115,41 @@ async def insert_many(
115115
doc["_id"] = FakeObjectId()
116116
self._docs[id(doc["_id"])] = dict(doc)
117117

118-
async def update_one(
118+
async def find_one_and_update(
119119
self,
120120
query: dict[str, Any],
121121
update: dict[str, Any],
122122
upsert: bool = False,
123-
) -> None:
123+
return_document: bool = False,
124+
) -> dict[str, Any] | None:
124125
for doc in self._docs.values():
125126
if self._matches(doc, query):
126-
return # Exists — $setOnInsert is a no-op on existing docs.
127+
# Apply $inc fields.
128+
for field, delta in update.get("$inc", {}).items():
129+
doc[field] = doc.get(field, 0) + delta
130+
return dict(doc) if return_document else None
127131
if upsert:
128132
new_doc: dict[str, Any] = {"_id": FakeObjectId()}
129133
new_doc.update(update.get("$setOnInsert", {}))
134+
for field, delta in update.get("$inc", {}).items():
135+
new_doc[field] = new_doc.get(field, 0) + delta
130136
self._docs[id(new_doc["_id"])] = new_doc
137+
return dict(new_doc) if return_document else None
138+
return None
139+
140+
async def update_one(
141+
self,
142+
query: dict[str, Any],
143+
update: dict[str, Any],
144+
upsert: bool = False,
145+
) -> None:
146+
for doc in self._docs.values():
147+
if self._matches(doc, query):
148+
return # Exists — $setOnInsert is a no-op on existing docs.
149+
if upsert:
150+
new_doc2: dict[str, Any] = {"_id": FakeObjectId()}
151+
new_doc2.update(update.get("$setOnInsert", {}))
152+
self._docs[id(new_doc2["_id"])] = new_doc2
131153

132154
async def delete_many(self, query: dict[str, Any]) -> None:
133155
to_remove = [k for k, d in self._docs.items() if self._matches(d, query)]
@@ -235,8 +257,7 @@ def _make_fake_pymongo_modules() -> None:
235257
def _make_session(session_id: str = "test-session", **kwargs: Any) -> MongoDBSession:
236258
"""Create a MongoDBSession backed by a FakeAsyncMongoClient."""
237259
client = FakeAsyncMongoClient()
238-
MongoDBSession._initialized_keys.clear()
239-
MongoDBSession._init_locks.clear()
260+
MongoDBSession._init_state.clear()
240261
return MongoDBSession(
241262
session_id,
242263
client=client, # type: ignore[arg-type]
@@ -353,8 +374,7 @@ async def test_get_items_limit_exceeds_count(session: MongoDBSession) -> None:
353374

354375
async def test_session_settings_limit_used_as_default() -> None:
355376
"""session_settings.limit is applied when no explicit limit is given."""
356-
MongoDBSession._initialized_keys.clear()
357-
MongoDBSession._init_locks.clear()
377+
MongoDBSession._init_state.clear()
358378
s = MongoDBSession(
359379
"ls-test",
360380
client=FakeAsyncMongoClient(), # type: ignore[arg-type]
@@ -371,8 +391,7 @@ async def test_session_settings_limit_used_as_default() -> None:
371391

372392
async def test_explicit_limit_overrides_session_settings() -> None:
373393
"""An explicit limit passed to get_items must override session_settings.limit."""
374-
MongoDBSession._initialized_keys.clear()
375-
MongoDBSession._init_locks.clear()
394+
MongoDBSession._init_state.clear()
376395
s = MongoDBSession(
377396
"override-test",
378397
client=FakeAsyncMongoClient(), # type: ignore[arg-type]
@@ -394,8 +413,7 @@ async def test_explicit_limit_overrides_session_settings() -> None:
394413

395414
async def test_sessions_are_isolated() -> None:
396415
"""Two sessions with different IDs must not share data."""
397-
MongoDBSession._initialized_keys.clear()
398-
MongoDBSession._init_locks.clear()
416+
MongoDBSession._init_state.clear()
399417
client = FakeAsyncMongoClient()
400418
s1 = MongoDBSession("alice", client=client, database="agents_test") # type: ignore[arg-type]
401419
s2 = MongoDBSession("bob", client=client, database="agents_test") # type: ignore[arg-type]
@@ -409,8 +427,7 @@ async def test_sessions_are_isolated() -> None:
409427

410428
async def test_clear_does_not_affect_other_sessions() -> None:
411429
"""Clearing one session must leave sibling sessions untouched."""
412-
MongoDBSession._initialized_keys.clear()
413-
MongoDBSession._init_locks.clear()
430+
MongoDBSession._init_state.clear()
414431
client = FakeAsyncMongoClient()
415432
s1 = MongoDBSession("s1", client=client, database="agents_test") # type: ignore[arg-type]
416433
s2 = MongoDBSession("s2", client=client, database="agents_test") # type: ignore[arg-type]
@@ -529,8 +546,7 @@ async def counting(*args: Any, **kwargs: Any) -> str:
529546

530547
async def test_different_clients_each_run_index_init() -> None:
531548
"""Each distinct AsyncMongoClient gets its own index-creation pass."""
532-
MongoDBSession._initialized_keys.clear()
533-
MongoDBSession._init_locks.clear()
549+
MongoDBSession._init_state.clear()
534550

535551
client_a = FakeAsyncMongoClient()
536552
client_b = FakeAsyncMongoClient()
@@ -585,8 +601,7 @@ async def _fail(*args: Any, **kwargs: Any) -> dict[str, Any]:
585601

586602
async def test_close_external_client_not_closed() -> None:
587603
"""close() must NOT close a client that was injected externally."""
588-
MongoDBSession._initialized_keys.clear()
589-
MongoDBSession._init_locks.clear()
604+
MongoDBSession._init_state.clear()
590605
client = FakeAsyncMongoClient()
591606
s = MongoDBSession("x", client=client, database="agents_test") # type: ignore[arg-type]
592607
assert s._owns_client is False
@@ -597,8 +612,7 @@ async def test_close_external_client_not_closed() -> None:
597612

598613
async def test_close_owned_client_is_closed() -> None:
599614
"""close() must close a client created by from_uri."""
600-
MongoDBSession._initialized_keys.clear()
601-
MongoDBSession._init_locks.clear()
615+
MongoDBSession._init_state.clear()
602616
fake_client = FakeAsyncMongoClient()
603617
with patch(
604618
"agents.extensions.memory.mongodb_session.AsyncMongoClient",
@@ -636,8 +650,7 @@ async def test_runner_integration(agent: Agent) -> None:
636650

637651
async def test_runner_session_isolation(agent: Agent) -> None:
638652
"""Two independent sessions must not bleed history into each other."""
639-
MongoDBSession._initialized_keys.clear()
640-
MongoDBSession._init_locks.clear()
653+
MongoDBSession._init_state.clear()
641654
client = FakeAsyncMongoClient()
642655
s1 = MongoDBSession("user-a", client=client, database="agents_test") # type: ignore[arg-type]
643656
s2 = MongoDBSession("user-b", client=client, database="agents_test") # type: ignore[arg-type]
@@ -659,8 +672,7 @@ async def test_runner_with_session_settings_limit(agent: Agent) -> None:
659672
"""RunConfig.session_settings.limit must cap the history sent to the model."""
660673
from agents import RunConfig
661674

662-
MongoDBSession._initialized_keys.clear()
663-
MongoDBSession._init_locks.clear()
675+
MongoDBSession._init_state.clear()
664676
session = MongoDBSession(
665677
"limit-test",
666678
client=FakeAsyncMongoClient(), # type: ignore[arg-type]
@@ -694,8 +706,7 @@ async def test_runner_with_session_settings_limit(agent: Agent) -> None:
694706

695707
async def test_injected_client_receives_append_metadata() -> None:
696708
"""Append_metadata is called on a caller-supplied client."""
697-
MongoDBSession._initialized_keys.clear()
698-
MongoDBSession._init_locks.clear()
709+
MongoDBSession._init_state.clear()
699710
client = FakeAsyncMongoClient()
700711

701712
MongoDBSession("meta-test", client=client, database="agents_test") # type: ignore[arg-type]
@@ -707,8 +718,7 @@ async def test_injected_client_receives_append_metadata() -> None:
707718

708719
async def test_from_uri_passes_driver_info_to_constructor() -> None:
709720
"""driver=_DRIVER_INFO is forwarded to AsyncMongoClient via from_uri."""
710-
MongoDBSession._initialized_keys.clear()
711-
MongoDBSession._init_locks.clear()
721+
MongoDBSession._init_state.clear()
712722

713723
captured_kwargs: dict[str, Any] = {}
714724

@@ -728,8 +738,7 @@ def _fake_client(uri: str, **kwargs: Any) -> FakeAsyncMongoClient:
728738

729739
async def test_caller_supplied_driver_info_is_not_overwritten() -> None:
730740
"""A caller-supplied driver kwarg must not be silently replaced."""
731-
MongoDBSession._initialized_keys.clear()
732-
MongoDBSession._init_locks.clear()
741+
MongoDBSession._init_state.clear()
733742

734743
captured_kwargs: dict[str, Any] = {}
735744
custom_info = FakeDriverInfo(name="MyApp")

0 commit comments

Comments (0)