Skip to content

Commit a861452

Browse files
alexbeviclaude
andcommitted
feat: add MongoDB driver handshake metadata to MongoDBSession
Implements both client metadata patterns per the add-client-metadata skill: - Pattern A (from_uri): passes driver=_DRIVER_INFO to AsyncMongoClient via client_kwargs.setdefault(), so caller-supplied driver values are preserved. - Pattern B (injected client): calls client.append_metadata(_DRIVER_INFO) in __init__ guarded by hasattr, compatible with PyMongo <4.14. _DRIVER_INFO is a module-level DriverInfo(name="openai-agents", version=...) where version is resolved at runtime via importlib.metadata. Adds three new tests covering both patterns and the no-overwrite guard. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 74fc3b6 commit a861452

File tree

2 files changed

+99
-0
lines changed

2 files changed

+99
-0
lines changed

src/agents/extensions/memory/mongodb_session.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,17 @@
3535
import json
3636
from typing import Any
3737

38+
try:
39+
from importlib.metadata import version as _get_version
40+
41+
_VERSION: str | None = _get_version("openai-agents")
42+
except Exception:
43+
_VERSION = None
44+
3845
try:
3946
from pymongo.asynchronous.collection import AsyncCollection
4047
from pymongo.asynchronous.mongo_client import AsyncMongoClient
48+
from pymongo.driver_info import DriverInfo
4149
except ImportError as e:
4250
raise ImportError(
4351
"MongoDBSession requires the 'pymongo' package (>=4.13). "
@@ -48,6 +56,9 @@
4856
from ...memory.session import SessionABC
4957
from ...memory.session_settings import SessionSettings, resolve_session_limit
5058

59+
# Identifies this library in the MongoDB handshake for server-side telemetry.
60+
_DRIVER_INFO = DriverInfo(name="openai-agents", version=_VERSION)
61+
5162

5263
class MongoDBSession(SessionABC):
5364
"""MongoDB implementation of :pyclass:`agents.memory.session.Session`.
@@ -98,6 +109,11 @@ def __init__(
98109
self._client = client
99110
self._owns_client = False
100111

112+
# Pattern B: annotate an externally-supplied client with library metadata.
113+
# append_metadata is available in PyMongo >=4.14; guard for older installs.
114+
if hasattr(client, "append_metadata"):
115+
client.append_metadata(_DRIVER_INFO)
116+
101117
db = client[database]
102118
self._sessions: AsyncCollection = db[sessions_collection]
103119
self._messages: AsyncCollection = db[messages_collection]
@@ -138,6 +154,7 @@ def from_uri(
138154
A :class:`MongoDBSession` connected to the specified MongoDB server.
139155
"""
140156
client_kwargs = client_kwargs or {}
157+
client_kwargs.setdefault("driver", _DRIVER_INFO)
141158
client: AsyncMongoClient = AsyncMongoClient(uri, **client_kwargs)
142159
session = cls(
143160
session_id,

tests/extensions/memory/test_mongodb_session.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,17 +167,30 @@ async def command(self, cmd: str) -> dict[str, Any]:
167167
return {"ok": 1}
168168

169169

170+
class FakeDriverInfo:
171+
"""Minimal stand-in for pymongo.driver_info.DriverInfo."""
172+
173+
def __init__(self, name: str, version: str | None = None) -> None:
174+
self.name = name
175+
self.version = version
176+
177+
170178
class FakeAsyncMongoClient:
171179
"""In-memory substitute for pymongo AsyncMongoClient."""
172180

173181
def __init__(self, *args: Any, **kwargs: Any) -> None:
174182
self._databases: dict[str, FakeAsyncDatabase] = defaultdict(FakeAsyncDatabase)
175183
self._closed = False
176184
self.admin = FakeAdminDatabase()
185+
self._metadata_calls: list[FakeDriverInfo] = []
177186

178187
def __getitem__(self, name: str) -> FakeAsyncDatabase:
179188
return self._databases[name]
180189

190+
def append_metadata(self, driver_info: FakeDriverInfo) -> None:
191+
"""Record append_metadata calls for test assertions."""
192+
self._metadata_calls.append(driver_info)
193+
181194
async def aclose(self) -> None:
182195
self._closed = True
183196
self.admin._closed = True
@@ -195,14 +208,17 @@ def _make_fake_pymongo_modules() -> None:
195208
async_pkg = types.ModuleType("pymongo.asynchronous")
196209
collection_mod = types.ModuleType("pymongo.asynchronous.collection")
197210
client_mod = types.ModuleType("pymongo.asynchronous.mongo_client")
211+
driver_info_mod = types.ModuleType("pymongo.driver_info")
198212

199213
collection_mod.AsyncCollection = FakeAsyncCollection # type: ignore[attr-defined]
200214
client_mod.AsyncMongoClient = FakeAsyncMongoClient # type: ignore[attr-defined]
215+
driver_info_mod.DriverInfo = FakeDriverInfo # type: ignore[attr-defined]
201216

202217
sys.modules["pymongo"] = pymongo_mod
203218
sys.modules["pymongo.asynchronous"] = async_pkg
204219
sys.modules["pymongo.asynchronous.collection"] = collection_mod
205220
sys.modules["pymongo.asynchronous.mongo_client"] = client_mod
221+
sys.modules["pymongo.driver_info"] = driver_info_mod
206222

207223

208224
_make_fake_pymongo_modules()
@@ -621,3 +637,69 @@ async def test_runner_with_session_settings_limit(agent: Agent) -> None:
621637
last_input = agent.model.last_turn_args["input"]
622638
history_items = [i for i in last_input if i.get("content") != "New question"]
623639
assert len(history_items) == 2
640+
641+
642+
# ---------------------------------------------------------------------------
643+
# Client metadata (driver handshake)
644+
# ---------------------------------------------------------------------------
645+
646+
647+
async def test_injected_client_receives_append_metadata() -> None:
648+
"""Pattern B: append_metadata is called on a caller-supplied client."""
649+
MongoDBSession._initialized_keys.clear()
650+
MongoDBSession._init_locks.clear()
651+
client = FakeAsyncMongoClient()
652+
653+
MongoDBSession("meta-test", client=client, database="agents_test") # type: ignore[arg-type]
654+
655+
assert len(client._metadata_calls) == 1
656+
info = client._metadata_calls[0]
657+
assert info.name == "openai-agents"
658+
659+
660+
async def test_from_uri_passes_driver_info_to_constructor() -> None:
661+
"""Pattern A: driver=_DRIVER_INFO is forwarded to AsyncMongoClient via from_uri."""
662+
MongoDBSession._initialized_keys.clear()
663+
MongoDBSession._init_locks.clear()
664+
665+
captured_kwargs: dict[str, Any] = {}
666+
667+
def _fake_client(uri: str, **kwargs: Any) -> FakeAsyncMongoClient:
668+
captured_kwargs.update(kwargs)
669+
return FakeAsyncMongoClient()
670+
671+
with patch(
672+
"agents.extensions.memory.mongodb_session.AsyncMongoClient",
673+
side_effect=_fake_client,
674+
):
675+
MongoDBSession.from_uri("uri-test", uri="mongodb://localhost:27017", database="t")
676+
677+
assert "driver" in captured_kwargs
678+
assert captured_kwargs["driver"].name == "openai-agents"
679+
680+
681+
async def test_caller_supplied_driver_info_is_not_overwritten() -> None:
682+
"""Pattern A: a caller-supplied driver kwarg must not be silently replaced."""
683+
MongoDBSession._initialized_keys.clear()
684+
MongoDBSession._init_locks.clear()
685+
686+
captured_kwargs: dict[str, Any] = {}
687+
custom_info = FakeDriverInfo(name="MyApp")
688+
689+
def _fake_client(uri: str, **kwargs: Any) -> FakeAsyncMongoClient:
690+
captured_kwargs.update(kwargs)
691+
return FakeAsyncMongoClient()
692+
693+
with patch(
694+
"agents.extensions.memory.mongodb_session.AsyncMongoClient",
695+
side_effect=_fake_client,
696+
):
697+
MongoDBSession.from_uri(
698+
"uri-test",
699+
uri="mongodb://localhost:27017",
700+
database="t",
701+
client_kwargs={"driver": custom_info},
702+
)
703+
704+
# The caller's value must be preserved — setdefault must not overwrite it.
705+
assert captured_kwargs["driver"] is custom_info

0 commit comments

Comments
 (0)