Skip to content

Commit f29a658

Browse files
authored
fix(memory): honor custom table names in AdvancedSQLiteSession (#2694)
1 parent 8f3b104 commit f29a658

File tree

2 files changed

+73
-23
lines changed

2 files changed

+73
-23
lines changed

src/agents/extensions/memory/advanced_sqlite_session.py

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def _init_structure_tables(self):
5959
conn = self._get_connection()
6060

6161
# Message structure with branch support
62-
conn.execute("""
62+
conn.execute(f"""
6363
CREATE TABLE IF NOT EXISTS message_structure (
6464
id INTEGER PRIMARY KEY AUTOINCREMENT,
6565
session_id TEXT NOT NULL,
@@ -71,13 +71,15 @@ def _init_structure_tables(self):
7171
branch_turn_number INTEGER,
7272
tool_name TEXT,
7373
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
74-
FOREIGN KEY (session_id) REFERENCES agent_sessions(session_id) ON DELETE CASCADE,
75-
FOREIGN KEY (message_id) REFERENCES agent_messages(id) ON DELETE CASCADE
74+
FOREIGN KEY (session_id)
75+
REFERENCES {self.sessions_table}(session_id) ON DELETE CASCADE,
76+
FOREIGN KEY (message_id)
77+
REFERENCES {self.messages_table}(id) ON DELETE CASCADE
7678
)
7779
""")
7880

7981
# Turn-level usage tracking with branch support and full JSON details
80-
conn.execute("""
82+
conn.execute(f"""
8183
CREATE TABLE IF NOT EXISTS turn_usage (
8284
id INTEGER PRIMARY KEY AUTOINCREMENT,
8385
session_id TEXT NOT NULL,
@@ -90,7 +92,8 @@ def _init_structure_tables(self):
9092
input_tokens_details JSON,
9193
output_tokens_details JSON,
9294
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
93-
FOREIGN KEY (session_id) REFERENCES agent_sessions(session_id) ON DELETE CASCADE,
95+
FOREIGN KEY (session_id)
96+
REFERENCES {self.sessions_table}(session_id) ON DELETE CASCADE,
9497
UNIQUE(session_id, branch_id, user_turn_number)
9598
)
9699
""")
@@ -160,9 +163,9 @@ def _get_all_items_sync():
160163
with closing(conn.cursor()) as cursor:
161164
if session_limit is None:
162165
cursor.execute(
163-
"""
166+
f"""
164167
SELECT m.message_data
165-
FROM agent_messages m
168+
FROM {self.messages_table} m
166169
JOIN message_structure s ON m.id = s.message_id
167170
WHERE m.session_id = ? AND s.branch_id = ?
168171
ORDER BY s.sequence_number ASC
@@ -171,9 +174,9 @@ def _get_all_items_sync():
171174
)
172175
else:
173176
cursor.execute(
174-
"""
177+
f"""
175178
SELECT m.message_data
176-
FROM agent_messages m
179+
FROM {self.messages_table} m
177180
JOIN message_structure s ON m.id = s.message_id
178181
WHERE m.session_id = ? AND s.branch_id = ?
179182
ORDER BY s.sequence_number DESC
@@ -206,9 +209,9 @@ def _get_items_sync():
206209
# Get message IDs in correct order for this branch
207210
if session_limit is None:
208211
cursor.execute(
209-
"""
212+
f"""
210213
SELECT m.message_data
211-
FROM agent_messages m
214+
FROM {self.messages_table} m
212215
JOIN message_structure s ON m.id = s.message_id
213216
WHERE m.session_id = ? AND s.branch_id = ?
214217
ORDER BY s.sequence_number ASC
@@ -217,9 +220,9 @@ def _get_items_sync():
217220
)
218221
else:
219222
cursor.execute(
220-
"""
223+
f"""
221224
SELECT m.message_data
222-
FROM agent_messages m
225+
FROM {self.messages_table} m
223226
JOIN message_structure s ON m.id = s.message_id
224227
WHERE m.session_id = ? AND s.branch_id = ?
225228
ORDER BY s.sequence_number DESC
@@ -439,7 +442,7 @@ def _add_structure_sync():
439442
# Don't re-raise - structure metadata is supplementary
440443

441444
async def _cleanup_orphaned_messages(self) -> int:
442-
"""Remove messages that exist in agent_messages but not in message_structure.
445+
"""Remove messages that exist in the configured message table but not in message_structure.
443446
444447
This can happen if _add_structure_metadata fails after super().add_items() succeeds.
445448
Used for maintaining data consistency.
@@ -453,9 +456,9 @@ def _cleanup_sync():
453456
with closing(conn.cursor()) as cursor:
454457
# Find messages without structure metadata
455458
cursor.execute(
456-
"""
459+
f"""
457460
SELECT am.id
458-
FROM agent_messages am
461+
FROM {self.messages_table} am
459462
LEFT JOIN message_structure ms ON am.id = ms.message_id
460463
WHERE am.session_id = ? AND ms.message_id IS NULL
461464
""",
@@ -468,7 +471,8 @@ def _cleanup_sync():
468471
# Delete orphaned messages
469472
placeholders = ",".join("?" * len(orphaned_ids))
470473
cursor.execute(
471-
f"DELETE FROM agent_messages WHERE id IN ({placeholders})", orphaned_ids
474+
f"DELETE FROM {self.messages_table} WHERE id IN ({placeholders})",
475+
orphaned_ids,
472476
)
473477

474478
deleted_count = cursor.rowcount
@@ -587,10 +591,10 @@ def _validate_turn():
587591
conn = self._get_connection()
588592
with closing(conn.cursor()) as cursor:
589593
cursor.execute(
590-
"""
594+
f"""
591595
SELECT am.message_data
592596
FROM message_structure ms
593-
JOIN agent_messages am ON ms.message_id = am.id
597+
JOIN {self.messages_table} am ON ms.message_id = am.id
594598
WHERE ms.session_id = ? AND ms.branch_id = ?
595599
AND ms.branch_turn_number = ? AND ms.message_type = 'user'
596600
""",
@@ -920,13 +924,13 @@ def _get_turns_sync():
920924
conn = self._get_connection()
921925
with closing(conn.cursor()) as cursor:
922926
cursor.execute(
923-
"""
927+
f"""
924928
SELECT
925929
ms.branch_turn_number,
926930
am.message_data,
927931
ms.created_at
928932
FROM message_structure ms
929-
JOIN agent_messages am ON ms.message_id = am.id
933+
JOIN {self.messages_table} am ON ms.message_id = am.id
930934
WHERE ms.session_id = ? AND ms.branch_id = ?
931935
AND ms.message_type = 'user'
932936
ORDER BY ms.branch_turn_number
@@ -975,13 +979,13 @@ def _search_sync():
975979
conn = self._get_connection()
976980
with closing(conn.cursor()) as cursor:
977981
cursor.execute(
978-
"""
982+
f"""
979983
SELECT
980984
ms.branch_turn_number,
981985
am.message_data,
982986
ms.created_at
983987
FROM message_structure ms
984-
JOIN agent_messages am ON ms.message_id = am.id
988+
JOIN {self.messages_table} am ON ms.message_id = am.id
985989
WHERE ms.session_id = ? AND ms.branch_id = ?
986990
AND ms.message_type = 'user'
987991
AND am.message_data LIKE ?

tests/extensions/memory/test_advanced_sqlite_session.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,52 @@ async def test_advanced_session_basic_functionality(agent: Agent):
9999
session.close()
100100

101101

102+
async def test_advanced_session_respects_custom_table_names():
103+
"""AdvancedSQLiteSession should consistently use configured table names."""
104+
session = AdvancedSQLiteSession(
105+
session_id="advanced_custom_tables",
106+
create_tables=True,
107+
sessions_table="custom_agent_sessions",
108+
messages_table="custom_agent_messages",
109+
)
110+
111+
items: list[TResponseInputItem] = [
112+
{"role": "user", "content": "Hello"},
113+
{"role": "assistant", "content": "Hi there!"},
114+
{"role": "user", "content": "Let's do some math"},
115+
{"role": "assistant", "content": "Sure"},
116+
]
117+
await session.add_items(items)
118+
119+
assert await session.get_items() == items
120+
121+
conversation_turns = await session.get_conversation_turns()
122+
assert [turn["turn"] for turn in conversation_turns] == [1, 2]
123+
124+
matching_turns = await session.find_turns_by_content("math")
125+
assert [turn["turn"] for turn in matching_turns] == [2]
126+
127+
conn = session._get_connection()
128+
structure_foreign_keys = {
129+
row[2] for row in conn.execute("PRAGMA foreign_key_list(message_structure)").fetchall()
130+
}
131+
usage_foreign_keys = {
132+
row[2] for row in conn.execute("PRAGMA foreign_key_list(turn_usage)").fetchall()
133+
}
134+
assert structure_foreign_keys == {
135+
session.messages_table,
136+
session.sessions_table,
137+
}
138+
assert usage_foreign_keys == {session.sessions_table}
139+
140+
branch_name = await session.create_branch_from_turn(2, "custom_branch")
141+
assert branch_name == "custom_branch"
142+
assert await session.get_items() == items[:2]
143+
assert await session.get_items(branch_id="main") == items
144+
145+
session.close()
146+
147+
102148
async def test_message_structure_tracking(agent: Agent):
103149
"""Test that message structure is properly tracked."""
104150
session_id = "structure_test"

0 commit comments

Comments
 (0)