Skip to content

Commit 436289f

Browse files
alexbevi and claude committed
fix: address Codex review feedback on MongoDBSession
- P1: Use client.close() (sync) instead of aclose() when closing an owned AsyncMongoClient; PyMongo does not expose aclose().
- P1: Include id(client) in _init_key so that different AsyncMongoClient instances pointing at different clusters each run their own index-creation pass rather than sharing a single guard.
- P2: Add TypeError to the except tuple in get_items() and pop_item() so non-string BSON message_data values (e.g. int/object) are silently skipped rather than aborting history retrieval.

Adds two new tests: test_non_string_message_data_is_skipped (P2) and test_different_clients_each_run_index_init (P1).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 9c2a083 commit 436289f

File tree

2 files changed

+62
-9
lines changed

2 files changed

+62
-9
lines changed

src/agents/extensions/memory/mongodb_session.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,16 @@ class MongoDBSession(SessionABC):
6767
collection. A lightweight ``sessions`` collection tracks metadata
6868
(creation time, last-updated time) for each session.
6969
70-
Indexes are created once per ``(database, sessions_collection,
70+
Indexes are created once per ``(client, database, sessions_collection,
7171
messages_collection)`` combination on the first call to any of the
7272
session protocol methods. Subsequent calls skip the setup entirely.
7373
"""
7474

7575
# Class-level registry so index creation only runs once per unique key.
76-
_initialized_keys: set[tuple[str, str, str]] = set()
77-
_init_locks: dict[tuple[str, str, str], asyncio.Lock] = {}
76+
# The key includes id(client) so that different AsyncMongoClient instances
77+
# (e.g. pointing at different clusters) each get their own init pass.
78+
_initialized_keys: set[tuple[int, str, str, str]] = set()
79+
_init_locks: dict[tuple[int, str, str, str], asyncio.Lock] = {}
7880
_init_locks_guard: asyncio.Lock = asyncio.Lock()
7981

8082
session_settings: SessionSettings | None = None
@@ -118,7 +120,7 @@ def __init__(
118120
self._sessions: AsyncCollection = db[sessions_collection]
119121
self._messages: AsyncCollection = db[messages_collection]
120122

121-
self._init_key = (database, sessions_collection, messages_collection)
123+
self._init_key = (id(client), database, sessions_collection, messages_collection)
122124

123125
# ------------------------------------------------------------------
124126
# Convenience constructors
@@ -250,8 +252,8 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
250252
for doc in docs:
251253
try:
252254
items.append(await self._deserialize_item(doc["message_data"]))
253-
except (json.JSONDecodeError, KeyError):
254-
# Skip corrupted or malformed documents.
255+
except (json.JSONDecodeError, KeyError, TypeError):
256+
# Skip corrupted or malformed documents (including non-string BSON values).
255257
continue
256258

257259
return items
@@ -302,7 +304,7 @@ async def pop_item(self) -> TResponseInputItem | None:
302304

303305
try:
304306
return await self._deserialize_item(doc["message_data"])
305-
except (json.JSONDecodeError, KeyError):
307+
except (json.JSONDecodeError, KeyError, TypeError):
306308
return None
307309

308310
async def clear_session(self) -> None:
@@ -323,7 +325,7 @@ async def close(self) -> None:
323325
caller is responsible for managing its lifecycle.
324326
"""
325327
if self._owns_client:
326-
await self._client.aclose()
328+
self._client.close()
327329

328330
async def ping(self) -> bool:
329331
"""Test MongoDB connectivity.

tests/extensions/memory/test_mongodb_session.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,10 +191,14 @@ def append_metadata(self, driver_info: FakeDriverInfo) -> None:
191191
"""Record append_metadata calls for test assertions."""
192192
self._metadata_calls.append(driver_info)
193193

194-
async def aclose(self) -> None:
194+
def close(self) -> None:
195+
"""Synchronous close — matches PyMongo's AsyncMongoClient.close() signature."""
195196
self._closed = True
196197
self.admin._closed = True
197198

199+
async def aclose(self) -> None:
200+
self.close()
201+
198202

199203
# ---------------------------------------------------------------------------
200204
# Inject fake pymongo into sys.modules before importing the module under test
@@ -484,6 +488,19 @@ async def test_missing_message_data_field_is_skipped(session: MongoDBSession) ->
484488
assert len(items) == 1
485489

486490

491+
async def test_non_string_message_data_is_skipped(session: MongoDBSession) -> None:
492+
"""Documents whose message_data is a non-string BSON type are silently skipped."""
493+
await session.add_items([{"role": "user", "content": "valid"}])
494+
495+
# Inject a document where message_data is an integer — json.loads raises TypeError.
496+
bad_doc = {"_id": FakeObjectId(), "session_id": session.session_id, "message_data": 42}
497+
session._messages._docs[id(bad_doc["_id"])] = bad_doc # type: ignore[attr-defined]
498+
499+
items = await session.get_items()
500+
assert len(items) == 1
501+
assert items[0].get("content") == "valid"
502+
503+
487504
# ---------------------------------------------------------------------------
488505
# Index initialisation (idempotency)
489506
# ---------------------------------------------------------------------------
@@ -513,6 +530,40 @@ async def counting(*args: Any, **kwargs: Any) -> str:
513530
session._sessions.create_index = original_sessions # type: ignore[attr-defined]
514531

515532

533+
async def test_different_clients_each_run_index_init() -> None:
534+
"""Each distinct AsyncMongoClient gets its own index-creation pass."""
535+
MongoDBSession._initialized_keys.clear()
536+
MongoDBSession._init_locks.clear()
537+
538+
client_a = FakeAsyncMongoClient()
539+
client_b = FakeAsyncMongoClient()
540+
541+
call_counts: dict[str, int] = {"a": 0, "b": 0}
542+
543+
async def counting_a(*args: Any, **kwargs: Any) -> str:
544+
call_counts["a"] += 1
545+
return "fake_index"
546+
547+
async def counting_b(*args: Any, **kwargs: Any) -> str:
548+
call_counts["b"] += 1
549+
return "fake_index"
550+
551+
s_a = MongoDBSession("x", client=client_a, database="agents_test") # type: ignore[arg-type]
552+
s_b = MongoDBSession("x", client=client_b, database="agents_test") # type: ignore[arg-type]
553+
554+
s_a._messages.create_index = counting_a # type: ignore[attr-defined]
555+
s_a._sessions.create_index = counting_a # type: ignore[attr-defined]
556+
s_b._messages.create_index = counting_b # type: ignore[attr-defined]
557+
s_b._sessions.create_index = counting_b # type: ignore[attr-defined]
558+
559+
await s_a._ensure_indexes()
560+
await s_b._ensure_indexes()
561+
562+
# Each client must trigger its own index creation (2 calls = sessions + messages).
563+
assert call_counts["a"] == 2
564+
assert call_counts["b"] == 2
565+
566+
516567
# ---------------------------------------------------------------------------
517568
# Connectivity and lifecycle
518569
# ---------------------------------------------------------------------------

0 commit comments

Comments
 (0)