Skip to content

Commit 4cfa66a

Browse files
author
Kcstring
committed
extensions/memory: add Google Cloud Firestore session backend
Add FirestoreSession, a production-grade session backend backed by Google Cloud Firestore. Conversation items are stored as documents in a messages sub-collection; a monotonic sequence counter on the parent session document (updated via a Firestore transaction) guarantees strict insertion order across concurrent writers. Public API mirrors the existing MongoDB/Redis backends: - FirestoreSession(session_id, *, client, ...) - FirestoreSession.from_project(session_id, *, project, ...) - get_items / add_items / pop_item / clear_session / close / ping Also adds: - firestore optional dependency group in pyproject.toml - 14 unit tests (no real Firestore server required)
1 parent da82b2c commit 4cfa66a

File tree

4 files changed

+828
-0
lines changed

4 files changed

+828
-0
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ encrypt = ["cryptography>=45.0, <46"]
4545
redis = ["redis>=7"]
4646
dapr = ["dapr>=1.16.0", "grpcio>=1.60.0"]
4747
mongodb = ["pymongo>=4.14"]
48+
firestore = ["google-cloud-firestore>=2.19"]
4849
docker = ["docker>=6.1"]
4950
blaxel = ["blaxel>=0.2.50", "aiohttp>=3.12,<4"]
5051
daytona = ["daytona>=0.155.0"]

src/agents/extensions/memory/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
DaprSession,
2020
)
2121
from .encrypt_session import EncryptedSession
22+
from .firestore_session import FirestoreSession
2223
from .mongodb_session import MongoDBSession
2324
from .redis_session import RedisSession
2425
from .sqlalchemy_session import SQLAlchemySession
@@ -30,6 +31,7 @@
3031
"DAPR_CONSISTENCY_STRONG",
3132
"DaprSession",
3233
"EncryptedSession",
34+
"FirestoreSession",
3335
"MongoDBSession",
3436
"RedisSession",
3537
"SQLAlchemySession",
@@ -130,4 +132,15 @@ def __getattr__(name: str) -> Any:
130132
"Install it with: pip install openai-agents[mongodb]"
131133
) from e
132134

135+
if name == "FirestoreSession":
136+
try:
137+
from .firestore_session import FirestoreSession # noqa: F401
138+
139+
return FirestoreSession
140+
except ModuleNotFoundError as e:
141+
raise ImportError(
142+
"FirestoreSession requires the 'firestore' extra. "
143+
"Install it with: pip install openai-agents[firestore]"
144+
) from e
145+
133146
raise AttributeError(f"module {__name__} has no attribute {name}")
Lines changed: 338 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,338 @@
1+
"""Google Cloud Firestore-powered Session backend.
2+
3+
Requires ``google-cloud-firestore>=2.19``, which ships the native async API.
4+
Install it with::
5+
6+
pip install openai-agents[firestore]
7+
8+
Usage::
9+
10+
from agents.extensions.memory import FirestoreSession
11+
12+
# Create from a Google Cloud project ID (uses Application Default Credentials)
13+
session = FirestoreSession.from_project(
14+
session_id="user-123",
15+
project="my-gcp-project",
16+
)
17+
18+
# Or pass an existing AsyncClient that your application already manages
19+
from google.cloud.firestore_v1.async_client import AsyncClient
20+
21+
client = AsyncClient(project="my-gcp-project")
22+
session = FirestoreSession(
23+
session_id="user-123",
24+
client=client,
25+
)
26+
27+
await Runner.run(agent, "Hello", session=session)
28+
"""
29+
30+
from __future__ import annotations
31+
32+
import asyncio
33+
import json
34+
from typing import Any
35+
36+
try:
37+
from google.cloud.firestore_v1.async_client import AsyncClient
38+
from google.cloud.firestore_v1.async_collection import AsyncCollectionReference
39+
from google.cloud.firestore_v1.async_document import AsyncDocumentReference
40+
except ImportError as e:
41+
raise ImportError(
42+
"FirestoreSession requires the 'google-cloud-firestore' package (>=2.19). "
43+
"Install it with: pip install openai-agents[firestore]"
44+
) from e
45+
46+
from ...items import TResponseInputItem
47+
from ...memory.session import SessionABC
48+
from ...memory.session_settings import SessionSettings, resolve_session_limit
49+
50+
51+
class FirestoreSession(SessionABC):
52+
"""Google Cloud Firestore implementation of :class:`agents.memory.session.Session`.
53+
54+
Conversation items are stored as individual documents in a ``messages``
55+
sub-collection under each session document. A parent ``sessions``
56+
collection holds lightweight metadata (creation time, last-updated time,
57+
and a monotonic sequence counter) for each session.
58+
59+
Each message document carries a ``seq`` field — an integer assigned by
60+
atomically incrementing a counter on the session metadata document via a
61+
Firestore transaction. This guarantees a strictly monotonic insertion
62+
order that is safe across multiple writers and processes.
63+
64+
Data layout in Firestore::
65+
66+
{sessions_collection}/
67+
{session_id} ← session metadata doc
68+
_seq: int ← monotonic counter
69+
created_at: int ← Unix timestamp
70+
updated_at: int ← Unix timestamp
71+
messages/ ← sub-collection
72+
{auto_id}
73+
seq: int
74+
message_data: str ← JSON-serialized TResponseInputItem
75+
"""
76+
77+
session_settings: SessionSettings | None = None
78+
79+
def __init__(
80+
self,
81+
session_id: str,
82+
*,
83+
client: AsyncClient,
84+
sessions_collection: str = "agent_sessions",
85+
session_settings: SessionSettings | None = None,
86+
):
87+
"""Initialize a new FirestoreSession.
88+
89+
Args:
90+
session_id: Unique identifier for the conversation.
91+
client: A pre-configured Firestore :class:`AsyncClient` instance.
92+
sessions_collection: Name of the top-level Firestore collection that
93+
stores session documents. Each session document contains a
94+
``messages`` sub-collection. Defaults to ``"agent_sessions"``.
95+
session_settings: Optional session configuration. When ``None`` a
96+
default :class:`~agents.memory.session_settings.SessionSettings`
97+
is used (no item limit).
98+
"""
99+
self.session_id = session_id
100+
self.session_settings = session_settings or SessionSettings()
101+
self._client = client
102+
self._owns_client = False
103+
self._lock = asyncio.Lock()
104+
105+
self._session_ref: AsyncDocumentReference = client.collection(sessions_collection).document(
106+
session_id
107+
)
108+
self._messages_ref: AsyncCollectionReference = self._session_ref.collection("messages")
109+
110+
# ------------------------------------------------------------------
111+
# Convenience constructors
112+
# ------------------------------------------------------------------
113+
114+
@classmethod
115+
def from_project(
116+
cls,
117+
session_id: str,
118+
*,
119+
project: str,
120+
database: str = "(default)",
121+
client_kwargs: dict[str, Any] | None = None,
122+
session_settings: SessionSettings | None = None,
123+
**kwargs: Any,
124+
) -> FirestoreSession:
125+
"""Create a session from a Google Cloud project ID.
126+
127+
Authentication uses `Application Default Credentials`_ (ADC). Run
128+
``gcloud auth application-default login`` locally, or rely on the
129+
service account attached to your GCP resource in production.
130+
131+
.. _Application Default Credentials:
132+
https://cloud.google.com/docs/authentication/application-default-credentials
133+
134+
Args:
135+
session_id: Conversation ID.
136+
project: Google Cloud project ID.
137+
database: Firestore database ID. Defaults to ``"(default)"``.
138+
client_kwargs: Additional keyword arguments forwarded to
139+
:class:`google.cloud.firestore_v1.async_client.AsyncClient`.
140+
session_settings: Optional session configuration settings.
141+
**kwargs: Additional keyword arguments forwarded to the main
142+
constructor (e.g. ``sessions_collection``).
143+
144+
Returns:
145+
A :class:`FirestoreSession` connected to the specified project.
146+
"""
147+
client_kwargs = client_kwargs or {}
148+
client = AsyncClient(project=project, database=database, **client_kwargs)
149+
session = cls(
150+
session_id,
151+
client=client,
152+
session_settings=session_settings,
153+
**kwargs,
154+
)
155+
session._owns_client = True
156+
return session
157+
158+
# ------------------------------------------------------------------
159+
# Serialization helpers
160+
# ------------------------------------------------------------------
161+
162+
async def _serialize_item(self, item: TResponseInputItem) -> str:
163+
"""Serialize an item to a JSON string. Can be overridden by subclasses."""
164+
return json.dumps(item, separators=(",", ":"))
165+
166+
async def _deserialize_item(self, raw: str) -> TResponseInputItem:
167+
"""Deserialize a JSON string to an item. Can be overridden by subclasses."""
168+
return json.loads(raw) # type: ignore[no-any-return]
169+
170+
# ------------------------------------------------------------------
171+
# Session protocol implementation
172+
# ------------------------------------------------------------------
173+
174+
async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
175+
"""Retrieve the conversation history for this session.
176+
177+
Args:
178+
limit: Maximum number of items to retrieve. When ``None``, the
179+
effective limit is taken from :attr:`session_settings`.
180+
If that is also ``None``, all items are returned.
181+
The returned list is always in chronological (oldest-first)
182+
order.
183+
184+
Returns:
185+
List of input items representing the conversation history.
186+
"""
187+
session_limit = resolve_session_limit(limit, self.session_settings)
188+
189+
if session_limit is not None and session_limit <= 0:
190+
return []
191+
192+
query = self._messages_ref.order_by("seq")
193+
194+
if session_limit is not None:
195+
# Firestore has no native "last N" query; fetch all and slice.
196+
# For large histories consider storing a running offset in the
197+
# session metadata document and using a range query instead.
198+
docs_stream = query.stream()
199+
all_docs = [doc async for doc in docs_stream]
200+
docs = all_docs[-session_limit:]
201+
else:
202+
docs_stream = query.stream()
203+
docs = [doc async for doc in docs_stream]
204+
205+
items: list[TResponseInputItem] = []
206+
for doc in docs:
207+
data = doc.to_dict()
208+
if data is None:
209+
continue
210+
try:
211+
items.append(await self._deserialize_item(data["message_data"]))
212+
except (json.JSONDecodeError, KeyError, TypeError):
213+
# Skip corrupted or malformed documents.
214+
continue
215+
216+
return items
217+
218+
async def add_items(self, items: list[TResponseInputItem]) -> None:
219+
"""Add new items to the conversation history.
220+
221+
Args:
222+
items: List of input items to append to the session.
223+
"""
224+
if not items:
225+
return
226+
227+
import time
228+
229+
async with self._lock:
230+
# Atomically reserve a block of sequence numbers via a transaction.
231+
@self._client.transaction() # type: ignore[arg-type]
232+
async def _txn(transaction: Any) -> int:
233+
snap = await self._session_ref.get(transaction=transaction)
234+
current_seq: int = snap.get("_seq") if snap.exists else 0 # type: ignore[union-attr]
235+
new_seq = current_seq + len(items)
236+
now = int(time.time())
237+
if snap.exists:
238+
transaction.update(
239+
self._session_ref,
240+
{"_seq": new_seq, "updated_at": now},
241+
)
242+
else:
243+
transaction.set(
244+
self._session_ref,
245+
{
246+
"_seq": new_seq,
247+
"created_at": now,
248+
"updated_at": now,
249+
},
250+
)
251+
return current_seq
252+
253+
first_seq: int = await _txn() # type: ignore[call-arg]
254+
255+
# Write message documents outside the transaction (non-atomic batch
256+
# is fine here — sequence numbers are already reserved).
257+
batch = self._client.batch()
258+
for i, item in enumerate(items):
259+
doc_ref = self._messages_ref.document()
260+
batch.set(
261+
doc_ref,
262+
{
263+
"seq": first_seq + i,
264+
"message_data": await self._serialize_item(item),
265+
},
266+
)
267+
await batch.commit()
268+
269+
async def pop_item(self) -> TResponseInputItem | None:
270+
"""Remove and return the most recent item from the session.
271+
272+
Returns:
273+
The most recent item if it exists, ``None`` if the session is empty.
274+
"""
275+
async with self._lock:
276+
# Find the document with the highest seq value.
277+
query = self._messages_ref.order_by("seq", direction="DESCENDING").limit(1)
278+
docs = [doc async for doc in query.stream()]
279+
280+
if not docs:
281+
return None
282+
283+
doc = docs[0]
284+
data = doc.to_dict()
285+
await doc.reference.delete()
286+
287+
if data is None:
288+
return None
289+
290+
try:
291+
return await self._deserialize_item(data["message_data"])
292+
except (json.JSONDecodeError, KeyError, TypeError):
293+
return None
294+
295+
async def clear_session(self) -> None:
296+
"""Clear all items for this session."""
297+
async with self._lock:
298+
# Delete all message documents in batches of 500 (Firestore limit).
299+
batch_size = 500
300+
while True:
301+
docs = [doc async for doc in self._messages_ref.limit(batch_size).stream()]
302+
if not docs:
303+
break
304+
batch = self._client.batch()
305+
for doc in docs:
306+
batch.delete(doc.reference)
307+
await batch.commit()
308+
309+
# Delete the session metadata document.
310+
await self._session_ref.delete()
311+
312+
# ------------------------------------------------------------------
313+
# Lifecycle helpers
314+
# ------------------------------------------------------------------
315+
316+
async def close(self) -> None:
317+
"""Close the underlying Firestore client.
318+
319+
Only closes the client if this session owns it (i.e. it was created
320+
via :meth:`from_project`). If the client was injected externally the
321+
caller is responsible for managing its lifecycle.
322+
"""
323+
if self._owns_client:
324+
await self._client.close()
325+
326+
async def ping(self) -> bool:
327+
"""Test Firestore connectivity.
328+
329+
Returns:
330+
``True`` if the service is reachable, ``False`` otherwise.
331+
"""
332+
try:
333+
# A lightweight read against the session document is sufficient to
334+
# verify that the client can reach the Firestore service.
335+
await self._session_ref.get()
336+
return True
337+
except Exception:
338+
return False

0 commit comments

Comments
 (0)