@@ -167,17 +167,30 @@ async def command(self, cmd: str) -> dict[str, Any]:
167167 return {"ok" : 1 }
168168
169169
170+ class FakeDriverInfo :
171+ """Minimal stand-in for pymongo.driver_info.DriverInfo."""
172+
173+ def __init__ (self , name : str , version : str | None = None ) -> None :
174+ self .name = name
175+ self .version = version
176+
177+
170178class FakeAsyncMongoClient :
171179 """In-memory substitute for pymongo AsyncMongoClient."""
172180
173181 def __init__ (self , * args : Any , ** kwargs : Any ) -> None :
174182 self ._databases : dict [str , FakeAsyncDatabase ] = defaultdict (FakeAsyncDatabase )
175183 self ._closed = False
176184 self .admin = FakeAdminDatabase ()
185+ self ._metadata_calls : list [FakeDriverInfo ] = []
177186
178187 def __getitem__ (self , name : str ) -> FakeAsyncDatabase :
179188 return self ._databases [name ]
180189
190+ def append_metadata (self , driver_info : FakeDriverInfo ) -> None :
191+ """Record append_metadata calls for test assertions."""
192+ self ._metadata_calls .append (driver_info )
193+
181194 async def aclose (self ) -> None :
182195 self ._closed = True
183196 self .admin ._closed = True
@@ -195,14 +208,17 @@ def _make_fake_pymongo_modules() -> None:
195208 async_pkg = types .ModuleType ("pymongo.asynchronous" )
196209 collection_mod = types .ModuleType ("pymongo.asynchronous.collection" )
197210 client_mod = types .ModuleType ("pymongo.asynchronous.mongo_client" )
211+ driver_info_mod = types .ModuleType ("pymongo.driver_info" )
198212
199213 collection_mod .AsyncCollection = FakeAsyncCollection # type: ignore[attr-defined]
200214 client_mod .AsyncMongoClient = FakeAsyncMongoClient # type: ignore[attr-defined]
215+ driver_info_mod .DriverInfo = FakeDriverInfo # type: ignore[attr-defined]
201216
202217 sys .modules ["pymongo" ] = pymongo_mod
203218 sys .modules ["pymongo.asynchronous" ] = async_pkg
204219 sys .modules ["pymongo.asynchronous.collection" ] = collection_mod
205220 sys .modules ["pymongo.asynchronous.mongo_client" ] = client_mod
221+ sys .modules ["pymongo.driver_info" ] = driver_info_mod
206222
207223
208224_make_fake_pymongo_modules ()
@@ -621,3 +637,69 @@ async def test_runner_with_session_settings_limit(agent: Agent) -> None:
621637 last_input = agent .model .last_turn_args ["input" ]
622638 history_items = [i for i in last_input if i .get ("content" ) != "New question" ]
623639 assert len (history_items ) == 2
640+
641+
642+ # ---------------------------------------------------------------------------
643+ # Client metadata (driver handshake)
644+ # ---------------------------------------------------------------------------
645+
646+
647+ async def test_injected_client_receives_append_metadata () -> None :
648+ """Pattern B: append_metadata is called on a caller-supplied client."""
649+ MongoDBSession ._initialized_keys .clear ()
650+ MongoDBSession ._init_locks .clear ()
651+ client = FakeAsyncMongoClient ()
652+
653+ MongoDBSession ("meta-test" , client = client , database = "agents_test" ) # type: ignore[arg-type]
654+
655+ assert len (client ._metadata_calls ) == 1
656+ info = client ._metadata_calls [0 ]
657+ assert info .name == "openai-agents"
658+
659+
660+ async def test_from_uri_passes_driver_info_to_constructor () -> None :
661+ """Pattern A: driver=_DRIVER_INFO is forwarded to AsyncMongoClient via from_uri."""
662+ MongoDBSession ._initialized_keys .clear ()
663+ MongoDBSession ._init_locks .clear ()
664+
665+ captured_kwargs : dict [str , Any ] = {}
666+
667+ def _fake_client (uri : str , ** kwargs : Any ) -> FakeAsyncMongoClient :
668+ captured_kwargs .update (kwargs )
669+ return FakeAsyncMongoClient ()
670+
671+ with patch (
672+ "agents.extensions.memory.mongodb_session.AsyncMongoClient" ,
673+ side_effect = _fake_client ,
674+ ):
675+ MongoDBSession .from_uri ("uri-test" , uri = "mongodb://localhost:27017" , database = "t" )
676+
677+ assert "driver" in captured_kwargs
678+ assert captured_kwargs ["driver" ].name == "openai-agents"
679+
680+
681+ async def test_caller_supplied_driver_info_is_not_overwritten () -> None :
682+ """Pattern A: a caller-supplied driver kwarg must not be silently replaced."""
683+ MongoDBSession ._initialized_keys .clear ()
684+ MongoDBSession ._init_locks .clear ()
685+
686+ captured_kwargs : dict [str , Any ] = {}
687+ custom_info = FakeDriverInfo (name = "MyApp" )
688+
689+ def _fake_client (uri : str , ** kwargs : Any ) -> FakeAsyncMongoClient :
690+ captured_kwargs .update (kwargs )
691+ return FakeAsyncMongoClient ()
692+
693+ with patch (
694+ "agents.extensions.memory.mongodb_session.AsyncMongoClient" ,
695+ side_effect = _fake_client ,
696+ ):
697+ MongoDBSession .from_uri (
698+ "uri-test" ,
699+ uri = "mongodb://localhost:27017" ,
700+ database = "t" ,
701+ client_kwargs = {"driver" : custom_info },
702+ )
703+
704+ # The caller's value must be preserved — setdefault must not overwrite it.
705+ assert captured_kwargs ["driver" ] is custom_info
0 commit comments