Skip to content

Commit b662362

Browse files
committed
fix review comments
1 parent 0ad7092 commit b662362

File tree

2 files changed

+57
-2
lines changed

2 files changed

+57
-2
lines changed

src/agents/memory/sqlite_session.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -326,8 +326,7 @@ async def apply_history_mutations(self, args: SessionHistoryRewriteArgs) -> None
326326
return
327327

328328
def _apply_history_mutations_sync() -> None:
329-
conn = self._get_connection()
330-
with self._lock if self._is_memory_db else threading.Lock():
329+
with self._locked_connection() as conn:
331330
cursor = conn.execute(
332331
f"""
333332
SELECT message_data FROM {self.messages_table}

tests/test_session.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,19 @@
1212
from .test_responses import get_text_message
1313

1414

15+
class _RecordingLock:
16+
def __init__(self, lock):
17+
self._lock = lock
18+
self.enter_count = 0
19+
20+
def __enter__(self):
21+
self.enter_count += 1
22+
return self._lock.__enter__()
23+
24+
def __exit__(self, exc_type, exc, tb):
25+
return self._lock.__exit__(exc_type, exc, tb)
26+
27+
1528
# Helper functions for parametrized testing of different Runner methods
1629
def _run_sync_wrapper(agent, input_data, **kwargs):
1730
"""Wrapper for run_sync that properly sets up an event loop."""
@@ -567,6 +580,49 @@ async def test_sqlite_session_file_lock_is_shared_across_instances():
567580
assert lock_path not in SQLiteSession._file_locks
568581

569582

583+
@pytest.mark.asyncio
584+
async def test_sqlite_session_apply_history_mutations_uses_file_lock():
585+
"""File-backed history rewrites should reuse the session lock."""
586+
with tempfile.TemporaryDirectory() as temp_dir:
587+
db_path = Path(temp_dir) / "test_rewrite_lock.db"
588+
session = SQLiteSession("rewrite_lock_test", db_path)
589+
function_call: TResponseInputItem = {
590+
"type": "function_call",
591+
"call_id": "call-1",
592+
"id": "fc_1",
593+
"name": "test_tool",
594+
"arguments": '{"value":"before"}',
595+
}
596+
replacement: TResponseInputItem = {
597+
"type": "function_call",
598+
"call_id": "call-1",
599+
"id": "fc_1",
600+
"name": "test_tool",
601+
"arguments": '{"value":"after"}',
602+
}
603+
604+
await session.add_items([function_call])
605+
recording_lock = _RecordingLock(session._lock)
606+
session.__dict__["_lock"] = recording_lock
607+
608+
await session.apply_history_mutations(
609+
{
610+
"mutations": [
611+
{
612+
"type": "replace_function_call",
613+
"call_id": "call-1",
614+
"replacement": replacement,
615+
}
616+
]
617+
}
618+
)
619+
assert recording_lock.enter_count == 1
620+
621+
retrieved = await session.get_items()
622+
assert retrieved == [replacement]
623+
session.close()
624+
625+
570626
@pytest.mark.asyncio
571627
async def test_session_add_items_exception_propagates_in_streamed():
572628
"""Test that exceptions from session.add_items are properly propagated

0 commit comments

Comments
 (0)