@@ -115,19 +115,41 @@ async def insert_many(
115115 doc ["_id" ] = FakeObjectId ()
116116 self ._docs [id (doc ["_id" ])] = dict (doc )
117117
118- async def update_one (
118+ async def find_one_and_update (
119119 self ,
120120 query : dict [str , Any ],
121121 update : dict [str , Any ],
122122 upsert : bool = False ,
123- ) -> None :
123+ return_document : bool = False ,
124+ ) -> dict [str , Any ] | None :
124125 for doc in self ._docs .values ():
125126 if self ._matches (doc , query ):
126- return # Exists — $setOnInsert is a no-op on existing docs.
127+ # Apply $inc fields.
128+ for field , delta in update .get ("$inc" , {}).items ():
129+ doc [field ] = doc .get (field , 0 ) + delta
130+ return dict (doc ) if return_document else None
127131 if upsert :
128132 new_doc : dict [str , Any ] = {"_id" : FakeObjectId ()}
129133 new_doc .update (update .get ("$setOnInsert" , {}))
134+ for field , delta in update .get ("$inc" , {}).items ():
135+ new_doc [field ] = new_doc .get (field , 0 ) + delta
130136 self ._docs [id (new_doc ["_id" ])] = new_doc
137+ return dict (new_doc ) if return_document else None
138+ return None
139+
140+ async def update_one (
141+ self ,
142+ query : dict [str , Any ],
143+ update : dict [str , Any ],
144+ upsert : bool = False ,
145+ ) -> None :
146+ for doc in self ._docs .values ():
147+ if self ._matches (doc , query ):
148+ return # Exists — $setOnInsert is a no-op on existing docs.
149+ if upsert :
150+ new_doc2 : dict [str , Any ] = {"_id" : FakeObjectId ()}
151+ new_doc2 .update (update .get ("$setOnInsert" , {}))
152+ self ._docs [id (new_doc2 ["_id" ])] = new_doc2
131153
132154 async def delete_many (self , query : dict [str , Any ]) -> None :
133155 to_remove = [k for k , d in self ._docs .items () if self ._matches (d , query )]
@@ -235,8 +257,7 @@ def _make_fake_pymongo_modules() -> None:
235257def _make_session (session_id : str = "test-session" , ** kwargs : Any ) -> MongoDBSession :
236258 """Create a MongoDBSession backed by a FakeAsyncMongoClient."""
237259 client = FakeAsyncMongoClient ()
238- MongoDBSession ._initialized_keys .clear ()
239- MongoDBSession ._init_locks .clear ()
260+ MongoDBSession ._init_state .clear ()
240261 return MongoDBSession (
241262 session_id ,
242263 client = client , # type: ignore[arg-type]
@@ -353,8 +374,7 @@ async def test_get_items_limit_exceeds_count(session: MongoDBSession) -> None:
353374
354375async def test_session_settings_limit_used_as_default () -> None :
355376 """session_settings.limit is applied when no explicit limit is given."""
356- MongoDBSession ._initialized_keys .clear ()
357- MongoDBSession ._init_locks .clear ()
377+ MongoDBSession ._init_state .clear ()
358378 s = MongoDBSession (
359379 "ls-test" ,
360380 client = FakeAsyncMongoClient (), # type: ignore[arg-type]
@@ -371,8 +391,7 @@ async def test_session_settings_limit_used_as_default() -> None:
371391
372392async def test_explicit_limit_overrides_session_settings () -> None :
373393 """An explicit limit passed to get_items must override session_settings.limit."""
374- MongoDBSession ._initialized_keys .clear ()
375- MongoDBSession ._init_locks .clear ()
394+ MongoDBSession ._init_state .clear ()
376395 s = MongoDBSession (
377396 "override-test" ,
378397 client = FakeAsyncMongoClient (), # type: ignore[arg-type]
@@ -394,8 +413,7 @@ async def test_explicit_limit_overrides_session_settings() -> None:
394413
395414async def test_sessions_are_isolated () -> None :
396415 """Two sessions with different IDs must not share data."""
397- MongoDBSession ._initialized_keys .clear ()
398- MongoDBSession ._init_locks .clear ()
416+ MongoDBSession ._init_state .clear ()
399417 client = FakeAsyncMongoClient ()
400418 s1 = MongoDBSession ("alice" , client = client , database = "agents_test" ) # type: ignore[arg-type]
401419 s2 = MongoDBSession ("bob" , client = client , database = "agents_test" ) # type: ignore[arg-type]
@@ -409,8 +427,7 @@ async def test_sessions_are_isolated() -> None:
409427
410428async def test_clear_does_not_affect_other_sessions () -> None :
411429 """Clearing one session must leave sibling sessions untouched."""
412- MongoDBSession ._initialized_keys .clear ()
413- MongoDBSession ._init_locks .clear ()
430+ MongoDBSession ._init_state .clear ()
414431 client = FakeAsyncMongoClient ()
415432 s1 = MongoDBSession ("s1" , client = client , database = "agents_test" ) # type: ignore[arg-type]
416433 s2 = MongoDBSession ("s2" , client = client , database = "agents_test" ) # type: ignore[arg-type]
@@ -529,8 +546,7 @@ async def counting(*args: Any, **kwargs: Any) -> str:
529546
530547async def test_different_clients_each_run_index_init () -> None :
531548 """Each distinct AsyncMongoClient gets its own index-creation pass."""
532- MongoDBSession ._initialized_keys .clear ()
533- MongoDBSession ._init_locks .clear ()
549+ MongoDBSession ._init_state .clear ()
534550
535551 client_a = FakeAsyncMongoClient ()
536552 client_b = FakeAsyncMongoClient ()
@@ -585,8 +601,7 @@ async def _fail(*args: Any, **kwargs: Any) -> dict[str, Any]:
585601
586602async def test_close_external_client_not_closed () -> None :
587603 """close() must NOT close a client that was injected externally."""
588- MongoDBSession ._initialized_keys .clear ()
589- MongoDBSession ._init_locks .clear ()
604+ MongoDBSession ._init_state .clear ()
590605 client = FakeAsyncMongoClient ()
591606 s = MongoDBSession ("x" , client = client , database = "agents_test" ) # type: ignore[arg-type]
592607 assert s ._owns_client is False
@@ -597,8 +612,7 @@ async def test_close_external_client_not_closed() -> None:
597612
598613async def test_close_owned_client_is_closed () -> None :
599614 """close() must close a client created by from_uri."""
600- MongoDBSession ._initialized_keys .clear ()
601- MongoDBSession ._init_locks .clear ()
615+ MongoDBSession ._init_state .clear ()
602616 fake_client = FakeAsyncMongoClient ()
603617 with patch (
604618 "agents.extensions.memory.mongodb_session.AsyncMongoClient" ,
@@ -636,8 +650,7 @@ async def test_runner_integration(agent: Agent) -> None:
636650
637651async def test_runner_session_isolation (agent : Agent ) -> None :
638652 """Two independent sessions must not bleed history into each other."""
639- MongoDBSession ._initialized_keys .clear ()
640- MongoDBSession ._init_locks .clear ()
653+ MongoDBSession ._init_state .clear ()
641654 client = FakeAsyncMongoClient ()
642655 s1 = MongoDBSession ("user-a" , client = client , database = "agents_test" ) # type: ignore[arg-type]
643656 s2 = MongoDBSession ("user-b" , client = client , database = "agents_test" ) # type: ignore[arg-type]
@@ -659,8 +672,7 @@ async def test_runner_with_session_settings_limit(agent: Agent) -> None:
659672 """RunConfig.session_settings.limit must cap the history sent to the model."""
660673 from agents import RunConfig
661674
662- MongoDBSession ._initialized_keys .clear ()
663- MongoDBSession ._init_locks .clear ()
675+ MongoDBSession ._init_state .clear ()
664676 session = MongoDBSession (
665677 "limit-test" ,
666678 client = FakeAsyncMongoClient (), # type: ignore[arg-type]
@@ -694,8 +706,7 @@ async def test_runner_with_session_settings_limit(agent: Agent) -> None:
694706
695707async def test_injected_client_receives_append_metadata () -> None :
696708 """Append_metadata is called on a caller-supplied client."""
697- MongoDBSession ._initialized_keys .clear ()
698- MongoDBSession ._init_locks .clear ()
709+ MongoDBSession ._init_state .clear ()
699710 client = FakeAsyncMongoClient ()
700711
701712 MongoDBSession ("meta-test" , client = client , database = "agents_test" ) # type: ignore[arg-type]
@@ -707,8 +718,7 @@ async def test_injected_client_receives_append_metadata() -> None:
707718
708719async def test_from_uri_passes_driver_info_to_constructor () -> None :
709720 """driver=_DRIVER_INFO is forwarded to AsyncMongoClient via from_uri."""
710- MongoDBSession ._initialized_keys .clear ()
711- MongoDBSession ._init_locks .clear ()
721+ MongoDBSession ._init_state .clear ()
712722
713723 captured_kwargs : dict [str , Any ] = {}
714724
@@ -728,8 +738,7 @@ def _fake_client(uri: str, **kwargs: Any) -> FakeAsyncMongoClient:
728738
729739async def test_caller_supplied_driver_info_is_not_overwritten () -> None :
730740 """A caller-supplied driver kwarg must not be silently replaced."""
731- MongoDBSession ._initialized_keys .clear ()
732- MongoDBSession ._init_locks .clear ()
741+ MongoDBSession ._init_state .clear ()
733742
734743 captured_kwargs : dict [str , Any ] = {}
735744 custom_info = FakeDriverInfo (name = "MyApp" )
0 commit comments