Skip to content

Commit 30d1dcb

Browse files
alexbevi and claude committed
fix(mongodb): remove asyncio.Lock from init registry — use threading-only guard
The per-entry asyncio.Lock stored in _init_state was still loop-bound: a second event loop reusing the same AsyncMongoClient would hang or raise RuntimeError trying to acquire it. create_index is idempotent on the server side, so no async coordination is needed — concurrent first-time callers may each issue a redundant create_index round-trip, but that is harmless. Replace the asyncio.Lock with a plain bool guarded by the existing threading.Lock, removing the last asyncio.Lock from the process-wide registry entirely. Also removes the now-unused asyncio import. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 9b166de commit 30d1dcb

File tree

1 file changed

+33
-52
lines changed

1 file changed

+33
-52
lines changed

src/agents/extensions/memory/mongodb_session.py

Lines changed: 33 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131

3232
from __future__ import annotations
3333

34-
import asyncio
3534
import json
3635
import threading
3736
import weakref
@@ -83,17 +82,18 @@ class MongoDBSession(SessionABC):
8382
# Class-level registry so index creation runs only once per unique
8483
# (client, database, sessions_collection, messages_collection) combination.
8584
#
86-
# WeakKeyDictionary lets the client object itself be the outer key so
87-
# that the entry is automatically pruned when the client is garbage-
88-
# collected. This prevents id(client) reuse from causing a new client to
89-
# skip index creation because it shares an address with a dead one.
90-
#
91-
# _init_guard is a plain threading.Lock so it is safe to acquire from any
92-
# event loop — an asyncio.Lock binds to the loop that first contends on
93-
# it and raises RuntimeError when a different loop tries to acquire it.
94-
_init_state: weakref.WeakKeyDictionary[
95-
Any, dict[tuple[str, str, str], asyncio.Lock | bool]
96-
] = weakref.WeakKeyDictionary()
85+
# Design notes:
86+
# - WeakKeyDictionary keyed on the client object: entries are pruned
87+
# automatically when the client is GC'd, so id() reuse can never cause
88+
# a new client to skip index creation.
89+
# - Only a threading.Lock (never an asyncio.Lock) touches the registry.
90+
# asyncio.Lock binds to the event loop that first contends on it; reusing
91+
# one across loops raises RuntimeError. create_index is idempotent, so
92+
# we only need the threading lock to guard the boolean done flag — no
93+
# async coordination is required.
94+
_init_state: weakref.WeakKeyDictionary[Any, dict[tuple[str, str, str], bool]] = (
95+
weakref.WeakKeyDictionary()
96+
)
9797
_init_guard: threading.Lock = threading.Lock()
9898

9999
session_settings: SessionSettings | None = None
@@ -191,59 +191,40 @@ def from_uri(
191191
# Index initialisation
192192
# ------------------------------------------------------------------
193193

194-
def _get_or_create_init_lock(self) -> tuple[asyncio.Lock, bool]:
195-
"""Return (lock, already_done) for this session's (client, sub_key) pair.
196-
197-
Uses a threading.Lock for the registry mutation so it is safe to call
198-
from any event loop without risking a cross-loop RuntimeError.
199-
"""
194+
def _is_init_done(self) -> bool:
195+
"""Return True if indexes have already been created for this (client, sub_key)."""
200196
with self._init_guard:
201197
per_client = self._init_state.get(self._client)
202-
if per_client is None:
203-
per_client = {}
204-
self._init_state[self._client] = per_client
205-
206-
entry = per_client.get(self._init_sub_key)
207-
if entry is True:
208-
# Already initialised.
209-
return asyncio.Lock(), True
210-
if entry is None:
211-
lock = asyncio.Lock()
212-
per_client[self._init_sub_key] = lock
213-
return lock, False
214-
# entry is an asyncio.Lock — initialisation is in progress.
215-
assert isinstance(entry, asyncio.Lock)
216-
return entry, False
198+
return per_client is not None and per_client.get(self._init_sub_key, False)
217199

218200
def _mark_init_done(self) -> None:
219201
"""Record that index creation is complete for this (client, sub_key)."""
220202
with self._init_guard:
221203
per_client = self._init_state.get(self._client)
222-
if per_client is not None:
223-
per_client[self._init_sub_key] = True
204+
if per_client is None:
205+
per_client = {}
206+
self._init_state[self._client] = per_client
207+
per_client[self._init_sub_key] = True
224208

225209
async def _ensure_indexes(self) -> None:
226-
"""Create required indexes the first time this (client, sub_key) is accessed."""
227-
lock, done = self._get_or_create_init_lock()
228-
if done:
229-
return
210+
"""Create required indexes the first time this (client, sub_key) is accessed.
230211
231-
async with lock:
232-
# Double-checked locking: another coroutine may have finished first.
233-
_, done = self._get_or_create_init_lock()
234-
if done:
235-
return
212+
``create_index`` is idempotent on the server side, so concurrent calls
213+
from different coroutines or event loops are safe — at most a redundant
214+
round-trip is issued. The threading-lock-guarded boolean prevents that
215+
extra round-trip after the first call completes.
216+
"""
217+
if self._is_init_done():
218+
return
236219

237-
# sessions: unique index on session_id.
238-
await self._sessions.create_index("session_id", unique=True)
220+
# sessions: unique index on session_id.
221+
await self._sessions.create_index("session_id", unique=True)
239222

240-
# messages: compound index for efficient per-session retrieval and sorting.
241-
# seq provides a strictly monotonic insertion-order tie-breaker that is
242-
# reliable across multiple writers (unlike ObjectId which is only
243-
# second-level accurate).
244-
await self._messages.create_index([("session_id", 1), ("seq", 1)])
223+
# messages: compound index for efficient per-session retrieval and
224+
# sorting by the explicit seq counter.
225+
await self._messages.create_index([("session_id", 1), ("seq", 1)])
245226

246-
self._mark_init_done()
227+
self._mark_init_done()
247228

248229
# ------------------------------------------------------------------
249230
# Serialization helpers

0 commit comments

Comments
 (0)