Skip to content

Commit 58374fe

Browse files
feat(mcp): expose session_id on MCPServerStreamableHttp (#2708)
1 parent 47ff1da commit 58374fe

2 files changed

Lines changed: 146 additions & 1 deletion

File tree

src/agents/mcp/server.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,7 @@ def __init__(
361361

362362
self.tool_filter = tool_filter
363363
self._serialize_session_requests = False
364+
self._get_session_id: GetSessionIdCallback | None = None
364365

365366
async def _maybe_serialize_request(self, func: Callable[[], Awaitable[T]]) -> T:
366367
if not self._serialize_session_requests:
@@ -515,7 +516,9 @@ async def connect(self):
515516
# streamablehttp_client returns (read, write, get_session_id)
516517
# sse_client returns (read, write)
517518

518-
read, write, *_ = transport
519+
read, write, *rest = transport
520+
# Capture the session-id callback when present (streamablehttp_client only).
521+
self._get_session_id = rest[0] if rest and callable(rest[0]) else None
519522

520523
session = await self.exit_stack.enter_async_context(
521524
ClientSession(
@@ -780,6 +783,7 @@ async def cleanup(self):
780783
logger.error(f"Error cleaning up server: {e}")
781784
finally:
782785
self.session = None
786+
self._get_session_id = None
783787

784788

785789
class MCPServerStdioParams(TypedDict):
@@ -1348,3 +1352,29 @@ async def call_tool(
13481352
def name(self) -> str:
13491353
"""A readable name for the server."""
13501354
return self._name
1355+
1356+
@property
1357+
def session_id(self) -> str | None:
1358+
"""The MCP session ID assigned by the server, or None if not yet connected
1359+
or if the server did not issue a session ID.
1360+
1361+
The session ID is stable for the lifetime of this server instance's connection.
1362+
You can persist it and pass it back via the Mcp-Session-Id request header
1363+
(params["headers"]) on a new MCPServerStreamableHttp instance to resume
1364+
the same server-side session across process restarts or stateless workers.
1365+
1366+
Example::
1367+
1368+
async with MCPServerStreamableHttp(params={"url": url}) as server:
1369+
session_id = server.session_id
1370+
1371+
# In a new worker / process:
1372+
async with MCPServerStreamableHttp(
1373+
params={"url": url, "headers": {"Mcp-Session-Id": session_id}}
1374+
) as server:
1375+
# Resumes the same server-side session.
1376+
...
1377+
"""
1378+
if self._get_session_id is None:
1379+
return None
1380+
return self._get_session_id()
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
"""Tests for MCPServerStreamableHttp.session_id property (issue #924)."""
2+
3+
from __future__ import annotations
4+
5+
from unittest.mock import AsyncMock, MagicMock, patch
6+
7+
import pytest
8+
9+
from agents.mcp import MCPServerStreamableHttp
10+
11+
12+
class TestStreamableHttpSessionId:
13+
"""Tests that the session_id property is correctly exposed."""
14+
15+
def test_session_id_is_none_before_connect(self):
16+
"""session_id should be None when the server has not been connected yet."""
17+
server = MCPServerStreamableHttp(params={"url": "http://localhost:9999/mcp"})
18+
assert server.session_id is None
19+
20+
def test_session_id_returns_none_when_callback_is_none(self):
21+
"""session_id should be None when _get_session_id callback is None."""
22+
server = MCPServerStreamableHttp(params={"url": "http://localhost:9999/mcp"})
23+
server._get_session_id = None
24+
assert server.session_id is None
25+
26+
def test_session_id_returns_callback_value(self):
27+
"""session_id should return the value from the get_session_id callback."""
28+
server = MCPServerStreamableHttp(params={"url": "http://localhost:9999/mcp"})
29+
mock_get_session_id = MagicMock(return_value="test-session-abc123")
30+
server._get_session_id = mock_get_session_id
31+
assert server.session_id == "test-session-abc123"
32+
mock_get_session_id.assert_called_once()
33+
34+
def test_session_id_returns_none_when_callback_returns_none(self):
35+
"""session_id should return None when the callback itself returns None."""
36+
server = MCPServerStreamableHttp(params={"url": "http://localhost:9999/mcp"})
37+
mock_get_session_id = MagicMock(return_value=None)
38+
server._get_session_id = mock_get_session_id
39+
assert server.session_id is None
40+
41+
def test_session_id_reflects_updated_callback_value(self):
42+
"""session_id should reflect the latest value from the callback each time."""
43+
server = MCPServerStreamableHttp(params={"url": "http://localhost:9999/mcp"})
44+
call_count = 0
45+
46+
def changing_callback() -> str | None:
47+
nonlocal call_count
48+
call_count += 1
49+
return f"session-{call_count}"
50+
51+
server._get_session_id = changing_callback
52+
assert server.session_id == "session-1"
53+
assert server.session_id == "session-2"
54+
55+
@pytest.mark.asyncio
56+
async def test_connect_captures_get_session_id_callback(self):
57+
"""connect() should capture the third element of the transport tuple as _get_session_id."""
58+
server = MCPServerStreamableHttp(params={"url": "http://localhost:9999/mcp"})
59+
60+
mock_read = AsyncMock()
61+
mock_write = AsyncMock()
62+
mock_get_session_id = MagicMock(return_value="captured-session-xyz")
63+
64+
mock_initialize_result = MagicMock()
65+
mock_session = AsyncMock()
66+
mock_session.initialize = AsyncMock(return_value=mock_initialize_result)
67+
68+
# Simulate the full 3-tuple that streamablehttp_client returns
69+
transport_tuple = (mock_read, mock_write, mock_get_session_id)
70+
71+
with patch("agents.mcp.server.ClientSession") as mock_client_session_cls:
72+
mock_client_session_cls.return_value.__aenter__ = AsyncMock(return_value=mock_session)
73+
mock_client_session_cls.return_value.__aexit__ = AsyncMock(return_value=None)
74+
75+
with patch.object(
76+
server,
77+
"create_streams",
78+
) as mock_create_streams:
79+
mock_cm = MagicMock()
80+
mock_cm.__aenter__ = AsyncMock(return_value=transport_tuple)
81+
mock_cm.__aexit__ = AsyncMock(return_value=None)
82+
mock_create_streams.return_value = mock_cm
83+
84+
with patch.object(server.exit_stack, "enter_async_context") as mock_enter:
85+
# First call returns transport, second call returns session
86+
mock_enter.side_effect = [transport_tuple, mock_session]
87+
mock_session.initialize.return_value = mock_initialize_result
88+
89+
await server.connect()
90+
91+
# After connect, _get_session_id should be the callable from the transport
92+
assert server._get_session_id is mock_get_session_id
93+
assert server.session_id == "captured-session-xyz"
94+
95+
96+
@pytest.mark.asyncio
97+
async def test_session_id_is_none_after_cleanup():
98+
"""session_id must return None after disconnect (cleanup clears _get_session_id)."""
99+
server = MCPServerStreamableHttp(params={"url": "http://localhost:8000/mcp"})
100+
101+
mock_get_session_id = MagicMock(return_value="session-to-clear")
102+
# Manually inject a session-id callback to simulate a connected state
103+
server._get_session_id = mock_get_session_id
104+
server.session = MagicMock() # pretend connected
105+
106+
assert server.session_id == "session-to-clear"
107+
108+
# Now simulate cleanup completing (exit_stack.aclose is a no-op here)
109+
with patch.object(server.exit_stack, "aclose", new_callable=AsyncMock):
110+
await server.cleanup()
111+
112+
# After cleanup both session and _get_session_id must be None
113+
assert server.session is None
114+
assert server._get_session_id is None
115+
assert server.session_id is None

0 commit comments

Comments
 (0)