Skip to content

Commit 054d391

Browse files
alexbeviclaude
andcommitted
fix: resolve mypy type: ignore mismatches in MongoDBSession tests
With pymongo now installed in dev deps, mypy resolves the real async types so several ignore codes were wrong or unused: - Make FakeAsyncMongoClient.close() async (MongoDBSession awaits it) - Change method-assignment ignores from [attr-defined] to [method-assign] - Add [assignment] to admin.command monkey-patch (overloaded function) - Remove ignores on _docs access and create_index reads that mypy now handles via the Any type parameter on AsyncCollection[Any] Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 436289f commit 054d391

4 files changed

Lines changed: 25 additions & 29 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ dev = [
9191
"grpcio>=1.60.0",
9292
"testcontainers==4.12.0", # pinned to 4.12.0 because 4.13.0 has a warning bug in wait_for_logs, see https://github.com/testcontainers/testcontainers-python/issues/874
9393
"pyright==1.1.408",
94+
"pymongo>=4.13",
9495
]
9596

9697
[tool.uv.workspace]
@@ -155,10 +156,6 @@ ignore_missing_imports = true
155156
module = ["runloop_api_client", "runloop_api_client.*"]
156157
ignore_missing_imports = true
157158

158-
[[tool.mypy.overrides]]
159-
module = ["pymongo", "pymongo.*"]
160-
ignore_missing_imports = true
161-
162159
[[tool.mypy.overrides]]
163160
module = ["blaxel", "blaxel.*"]
164161
ignore_missing_imports = true

src/agents/extensions/memory/mongodb_session.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def __init__(
8585
self,
8686
session_id: str,
8787
*,
88-
client: AsyncMongoClient,
88+
client: AsyncMongoClient[Any],
8989
database: str = "agents",
9090
sessions_collection: str = "agent_sessions",
9191
messages_collection: str = "agent_messages",
@@ -117,8 +117,8 @@ def __init__(
117117
client.append_metadata(_DRIVER_INFO)
118118

119119
db = client[database]
120-
self._sessions: AsyncCollection = db[sessions_collection]
121-
self._messages: AsyncCollection = db[messages_collection]
120+
self._sessions: AsyncCollection[Any] = db[sessions_collection]
121+
self._messages: AsyncCollection[Any] = db[messages_collection]
122122

123123
self._init_key = (id(client), database, sessions_collection, messages_collection)
124124

@@ -157,7 +157,7 @@ def from_uri(
157157
"""
158158
client_kwargs = client_kwargs or {}
159159
client_kwargs.setdefault("driver", _DRIVER_INFO)
160-
client: AsyncMongoClient = AsyncMongoClient(uri, **client_kwargs)
160+
client: AsyncMongoClient[Any] = AsyncMongoClient(uri, **client_kwargs)
161161
session = cls(
162162
session_id,
163163
client=client,
@@ -325,7 +325,7 @@ async def close(self) -> None:
325325
caller is responsible for managing its lifecycle.
326326
"""
327327
if self._owns_client:
328-
self._client.close()
328+
await self._client.close()
329329

330330
async def ping(self) -> bool:
331331
"""Test MongoDB connectivity.

tests/extensions/memory/test_mongodb_session.py

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -191,14 +191,11 @@ def append_metadata(self, driver_info: FakeDriverInfo) -> None:
191191
"""Record append_metadata calls for test assertions."""
192192
self._metadata_calls.append(driver_info)
193193

194-
def close(self) -> None:
195-
"""Synchronous close — matches PyMongo's AsyncMongoClient.close() signature."""
194+
async def close(self) -> None:
195+
"""Async close — matches PyMongo's AsyncMongoClient.close() signature."""
196196
self._closed = True
197197
self.admin._closed = True
198198

199-
async def aclose(self) -> None:
200-
self.close()
201-
202199

203200
# ---------------------------------------------------------------------------
204201
# Inject fake pymongo into sys.modules before importing the module under test
@@ -470,7 +467,7 @@ async def test_corrupted_document_is_skipped(session: MongoDBSession) -> None:
470467
"session_id": session.session_id,
471468
"message_data": "not valid json {{{",
472469
}
473-
session._messages._docs[id(bad_doc["_id"])] = bad_doc # type: ignore[attr-defined]
470+
session._messages._docs[id(bad_doc["_id"])] = bad_doc
474471

475472
items = await session.get_items()
476473
assert len(items) == 1
@@ -482,7 +479,7 @@ async def test_missing_message_data_field_is_skipped(session: MongoDBSession) ->
482479
await session.add_items([{"role": "user", "content": "valid"}])
483480

484481
bad_doc = {"_id": FakeObjectId(), "session_id": session.session_id}
485-
session._messages._docs[id(bad_doc["_id"])] = bad_doc # type: ignore[attr-defined]
482+
session._messages._docs[id(bad_doc["_id"])] = bad_doc
486483

487484
items = await session.get_items()
488485
assert len(items) == 1
@@ -494,7 +491,7 @@ async def test_non_string_message_data_is_skipped(session: MongoDBSession) -> No
494491

495492
# Inject a document where message_data is an integer — json.loads raises TypeError.
496493
bad_doc = {"_id": FakeObjectId(), "session_id": session.session_id, "message_data": 42}
497-
session._messages._docs[id(bad_doc["_id"])] = bad_doc # type: ignore[attr-defined]
494+
session._messages._docs[id(bad_doc["_id"])] = bad_doc
498495

499496
items = await session.get_items()
500497
assert len(items) == 1
@@ -509,25 +506,25 @@ async def test_non_string_message_data_is_skipped(session: MongoDBSession) -> No
509506
async def test_index_creation_runs_only_once(session: MongoDBSession) -> None:
510507
"""_ensure_indexes must call create_index only on the very first call."""
511508
call_count = 0
512-
original_messages = session._messages.create_index # type: ignore[attr-defined]
513-
original_sessions = session._sessions.create_index # type: ignore[attr-defined]
509+
original_messages = session._messages.create_index
510+
original_sessions = session._sessions.create_index
514511

515512
async def counting(*args: Any, **kwargs: Any) -> str:
516513
nonlocal call_count
517514
call_count += 1
518515
return "fake_index"
519516

520-
session._messages.create_index = counting # type: ignore[attr-defined]
521-
session._sessions.create_index = counting # type: ignore[attr-defined]
517+
session._messages.create_index = counting # type: ignore[method-assign]
518+
session._sessions.create_index = counting # type: ignore[method-assign]
522519

523520
await session._ensure_indexes()
524521
await session._ensure_indexes() # Second call must be a no-op.
525522

526523
# Exactly one call per collection (sessions + messages).
527524
assert call_count == 2
528525

529-
session._messages.create_index = original_messages # type: ignore[attr-defined]
530-
session._sessions.create_index = original_sessions # type: ignore[attr-defined]
526+
session._messages.create_index = original_messages # type: ignore[method-assign]
527+
session._sessions.create_index = original_sessions # type: ignore[method-assign]
531528

532529

533530
async def test_different_clients_each_run_index_init() -> None:
@@ -551,10 +548,10 @@ async def counting_b(*args: Any, **kwargs: Any) -> str:
551548
s_a = MongoDBSession("x", client=client_a, database="agents_test") # type: ignore[arg-type]
552549
s_b = MongoDBSession("x", client=client_b, database="agents_test") # type: ignore[arg-type]
553550

554-
s_a._messages.create_index = counting_a # type: ignore[attr-defined]
555-
s_a._sessions.create_index = counting_a # type: ignore[attr-defined]
556-
s_b._messages.create_index = counting_b # type: ignore[attr-defined]
557-
s_b._sessions.create_index = counting_b # type: ignore[attr-defined]
551+
s_a._messages.create_index = counting_a # type: ignore[method-assign]
552+
s_a._sessions.create_index = counting_a # type: ignore[method-assign]
553+
s_b._messages.create_index = counting_b # type: ignore[method-assign]
554+
s_b._sessions.create_index = counting_b # type: ignore[method-assign]
558555

559556
await s_a._ensure_indexes()
560557
await s_b._ensure_indexes()
@@ -581,9 +578,9 @@ async def test_ping_failure(session: MongoDBSession) -> None:
581578
async def _fail(*args: Any, **kwargs: Any) -> dict[str, Any]:
582579
raise ConnectionError("unreachable")
583580

584-
session._client.admin.command = _fail # type: ignore[attr-defined]
581+
session._client.admin.command = _fail # type: ignore[method-assign, assignment]
585582
assert await session.ping() is False
586-
session._client.admin.command = original # type: ignore[attr-defined]
583+
session._client.admin.command = original # type: ignore[method-assign]
587584

588585

589586
async def test_close_external_client_not_closed() -> None:

uv.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)