Skip to content

Commit 9088496

Browse files
fix: optionize initialized notification tolerance (#2765)
1 parent 687e974 commit 9088496

File tree

2 files changed

+313
-4
lines changed

2 files changed

+313
-4
lines changed

src/agents/mcp/server.py

Lines changed: 120 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44
import asyncio
55
import inspect
66
import sys
7-
from collections.abc import Awaitable
7+
from collections.abc import AsyncGenerator, Awaitable
88
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
99
from datetime import timedelta
1010
from pathlib import Path
1111
from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, Union, cast
1212

13+
import anyio
1314
import httpx
1415

1516
if sys.version_info < (3, 11):
@@ -19,7 +20,11 @@
1920
from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client
2021
from mcp.client.session import MessageHandlerFnT
2122
from mcp.client.sse import sse_client
22-
from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client
23+
from mcp.client.streamable_http import (
24+
GetSessionIdCallback,
25+
StreamableHTTPTransport,
26+
streamablehttp_client,
27+
)
2328
from mcp.shared.exceptions import McpError
2429
from mcp.shared.message import SessionMessage
2530
from mcp.types import (
@@ -71,6 +76,101 @@ class RequireApprovalObject(TypedDict, total=False):
7176
T = TypeVar("T")
7277

7378

79+
def _create_default_streamable_http_client(
80+
headers: dict[str, str] | None = None,
81+
timeout: httpx.Timeout | None = None,
82+
auth: httpx.Auth | None = None,
83+
) -> httpx.AsyncClient:
84+
kwargs: dict[str, Any] = {"follow_redirects": True}
85+
if timeout is not None:
86+
kwargs["timeout"] = timeout
87+
if headers is not None:
88+
kwargs["headers"] = headers
89+
if auth is not None:
90+
kwargs["auth"] = auth
91+
return httpx.AsyncClient(**kwargs)
92+
93+
94+
class _InitializedNotificationTolerantStreamableHTTPTransport(StreamableHTTPTransport):
95+
async def _handle_post_request(self, ctx: Any) -> None:
96+
message = ctx.session_message.message
97+
if not self._is_initialized_notification(message):
98+
await super()._handle_post_request(ctx)
99+
return
100+
101+
try:
102+
await super()._handle_post_request(ctx)
103+
except httpx.HTTPError:
104+
logger.warning(
105+
"Ignoring initialized notification HTTP failure",
106+
exc_info=True,
107+
)
108+
return
109+
110+
111+
@asynccontextmanager
112+
async def _streamablehttp_client_with_transport(
113+
url: str,
114+
*,
115+
headers: dict[str, str] | None = None,
116+
timeout: float | timedelta = 30,
117+
sse_read_timeout: float | timedelta = 60 * 5,
118+
terminate_on_close: bool = True,
119+
httpx_client_factory: HttpClientFactory = _create_default_streamable_http_client,
120+
auth: httpx.Auth | None = None,
121+
transport_factory: Callable[[str], StreamableHTTPTransport] = StreamableHTTPTransport,
122+
) -> AsyncGenerator[MCPStreamTransport, None]:
123+
timeout_seconds = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout
124+
sse_read_timeout_seconds = (
125+
sse_read_timeout.total_seconds()
126+
if isinstance(sse_read_timeout, timedelta)
127+
else sse_read_timeout
128+
)
129+
130+
client = httpx_client_factory(
131+
headers=headers,
132+
timeout=httpx.Timeout(timeout_seconds, read=sse_read_timeout_seconds),
133+
auth=auth,
134+
)
135+
transport = transport_factory(url)
136+
read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](
137+
0
138+
)
139+
write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)
140+
141+
async with client:
142+
async with anyio.create_task_group() as tg:
143+
try:
144+
logger.debug(f"Connecting to StreamableHTTP endpoint: {url}")
145+
146+
def start_get_stream() -> None:
147+
tg.start_soon(transport.handle_get_stream, client, read_stream_writer)
148+
149+
tg.start_soon(
150+
transport.post_writer,
151+
client,
152+
write_stream_reader,
153+
read_stream_writer,
154+
write_stream,
155+
start_get_stream,
156+
tg,
157+
)
158+
159+
try:
160+
yield (
161+
read_stream,
162+
write_stream,
163+
transport.get_session_id,
164+
)
165+
finally:
166+
if transport.session_id and terminate_on_close:
167+
await transport.terminate_session(client)
168+
tg.cancel_scope.cancel()
169+
finally:
170+
await read_stream_writer.aclose()
171+
await write_stream.aclose()
172+
173+
74174
class _SharedSessionRequestNeedsIsolation(Exception):
75175
"""Raised when a shared-session request should be retried on an isolated session."""
76176

@@ -1160,6 +1260,14 @@ class MCPServerStreamableHttpParams(TypedDict):
11601260
transport.
11611261
"""
11621262

1263+
ignore_initialized_notification_failure: NotRequired[bool]
1264+
"""Whether to ignore failures when sending the best-effort
1265+
``notifications/initialized`` POST.
1266+
1267+
Defaults to ``False``. When set to ``True``, initialized-notification failures are
1268+
logged and ignored so subsequent requests on the same transport can continue.
1269+
"""
1270+
11631271

11641272
class MCPServerStreamableHttp(_MCPServerWithClientSession):
11651273
"""MCP server implementation that uses the Streamable HTTP transport. See the [spec]
@@ -1250,8 +1358,16 @@ def create_streams(
12501358
"sse_read_timeout": self.params.get("sse_read_timeout", 60 * 5),
12511359
"terminate_on_close": self.params.get("terminate_on_close", True),
12521360
}
1253-
if "httpx_client_factory" in self.params:
1254-
kwargs["httpx_client_factory"] = self.params["httpx_client_factory"]
1361+
httpx_client_factory = self.params.get("httpx_client_factory")
1362+
if self.params.get("ignore_initialized_notification_failure", False):
1363+
return _streamablehttp_client_with_transport(
1364+
**kwargs,
1365+
httpx_client_factory=httpx_client_factory or _create_default_streamable_http_client,
1366+
auth=self.params.get("auth"),
1367+
transport_factory=_InitializedNotificationTolerantStreamableHTTPTransport,
1368+
)
1369+
if httpx_client_factory is not None:
1370+
kwargs["httpx_client_factory"] = httpx_client_factory
12551371
if "auth" in self.params:
12561372
kwargs["auth"] = self.params["auth"]
12571373
return streamablehttp_client(**kwargs)

tests/mcp/test_streamable_http_client_factory.py

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,21 @@
22

33
from __future__ import annotations
44

5+
import base64
56
from unittest.mock import MagicMock, patch
67

78
import httpx
89
import pytest
10+
from anyio import create_memory_object_stream
11+
from mcp.shared.message import SessionMessage
12+
from mcp.types import JSONRPCMessage, JSONRPCNotification, JSONRPCRequest
913

1014
from agents.mcp import MCPServerStreamableHttp
15+
from agents.mcp.server import (
16+
_create_default_streamable_http_client,
17+
_InitializedNotificationTolerantStreamableHTTPTransport,
18+
_streamablehttp_client_with_transport,
19+
)
1120

1221

1322
class TestMCPServerStreamableHttpClientFactory:
@@ -247,3 +256,187 @@ def comprehensive_factory(
247256
terminate_on_close=False,
248257
httpx_client_factory=comprehensive_factory,
249258
)
259+
260+
261+
@pytest.mark.asyncio
262+
async def test_initialized_notification_failure_returns_synthetic_success():
263+
async def handler(request: httpx.Request) -> httpx.Response:
264+
return httpx.Response(503, request=request)
265+
266+
transport = _InitializedNotificationTolerantStreamableHTTPTransport("https://example.test/mcp")
267+
read_stream_writer, _ = create_memory_object_stream[SessionMessage | Exception](0)
268+
client = httpx.AsyncClient(transport=httpx.MockTransport(handler))
269+
try:
270+
ctx = MagicMock()
271+
ctx.client = client
272+
ctx.read_stream_writer = read_stream_writer
273+
ctx.session_message = SessionMessage(
274+
JSONRPCMessage(
275+
JSONRPCNotification(
276+
jsonrpc="2.0",
277+
method="notifications/initialized",
278+
params={},
279+
)
280+
)
281+
)
282+
283+
await transport._handle_post_request(ctx)
284+
finally:
285+
await client.aclose()
286+
await read_stream_writer.aclose()
287+
288+
289+
@pytest.mark.asyncio
290+
async def test_initialized_notification_transport_exception_returns_synthetic_success():
291+
async def handler(request: httpx.Request) -> httpx.Response:
292+
raise httpx.ConnectError("boom", request=request)
293+
294+
transport = _InitializedNotificationTolerantStreamableHTTPTransport("https://example.test/mcp")
295+
read_stream_writer, _ = create_memory_object_stream[SessionMessage | Exception](0)
296+
client = httpx.AsyncClient(transport=httpx.MockTransport(handler))
297+
try:
298+
ctx = MagicMock()
299+
ctx.client = client
300+
ctx.read_stream_writer = read_stream_writer
301+
ctx.session_message = SessionMessage(
302+
JSONRPCMessage(
303+
JSONRPCNotification(
304+
jsonrpc="2.0",
305+
method="notifications/initialized",
306+
params={},
307+
)
308+
)
309+
)
310+
311+
await transport._handle_post_request(ctx)
312+
finally:
313+
await client.aclose()
314+
await read_stream_writer.aclose()
315+
316+
317+
@pytest.mark.asyncio
318+
async def test_streamable_http_server_passes_ignore_initialized_notification_failure():
319+
with patch("agents.mcp.server._streamablehttp_client_with_transport") as mock_client:
320+
mock_client.return_value = MagicMock()
321+
322+
server = MCPServerStreamableHttp(
323+
params={
324+
"url": "http://localhost:8000/mcp",
325+
"ignore_initialized_notification_failure": True,
326+
}
327+
)
328+
329+
server.create_streams()
330+
331+
kwargs = mock_client.call_args.kwargs
332+
assert kwargs["url"] == "http://localhost:8000/mcp"
333+
assert kwargs["headers"] is None
334+
assert kwargs["timeout"] == 5
335+
assert kwargs["sse_read_timeout"] == 300
336+
assert kwargs["terminate_on_close"] is True
337+
assert (
338+
kwargs["transport_factory"] is _InitializedNotificationTolerantStreamableHTTPTransport
339+
)
340+
341+
342+
@pytest.mark.asyncio
343+
async def test_transport_preserves_non_initialized_failures():
344+
async def handler(request: httpx.Request) -> httpx.Response:
345+
raise httpx.ConnectError("boom", request=request)
346+
347+
transport = _InitializedNotificationTolerantStreamableHTTPTransport("https://example.test/mcp")
348+
read_stream_writer, _ = create_memory_object_stream[SessionMessage | Exception](0)
349+
client = httpx.AsyncClient(transport=httpx.MockTransport(handler))
350+
try:
351+
ctx = MagicMock()
352+
ctx.client = client
353+
ctx.read_stream_writer = read_stream_writer
354+
ctx.session_message = SessionMessage(
355+
JSONRPCMessage(
356+
JSONRPCRequest(
357+
jsonrpc="2.0",
358+
id=1,
359+
method="tools/list",
360+
params={},
361+
)
362+
)
363+
)
364+
365+
with pytest.raises(httpx.ConnectError):
366+
await transport._handle_post_request(ctx)
367+
finally:
368+
await client.aclose()
369+
await read_stream_writer.aclose()
370+
371+
372+
@pytest.mark.asyncio
373+
async def test_stream_client_preserves_custom_factory_headers_timeout_and_auth():
374+
seen: dict[str, object] = {}
375+
376+
class RecordingAuth(httpx.Auth):
377+
def auth_flow(self, request: httpx.Request):
378+
request.headers["Authorization"] = f"Basic {base64.b64encode(b'user:pass').decode()}"
379+
yield request
380+
381+
async def handler(request: httpx.Request) -> httpx.Response:
382+
seen["request_headers"] = dict(request.headers)
383+
return httpx.Response(200, request=request)
384+
385+
def base_factory(
386+
headers: dict[str, str] | None = None,
387+
timeout: httpx.Timeout | None = None,
388+
auth: httpx.Auth | None = None,
389+
) -> httpx.AsyncClient:
390+
seen["factory_headers"] = headers
391+
seen["factory_timeout"] = timeout
392+
seen["factory_auth"] = auth
393+
return httpx.AsyncClient(
394+
headers=headers,
395+
timeout=timeout,
396+
auth=auth,
397+
transport=httpx.MockTransport(handler),
398+
)
399+
400+
timeout = httpx.Timeout(12.0)
401+
auth = RecordingAuth()
402+
async with _streamablehttp_client_with_transport(
403+
"https://example.test/mcp",
404+
headers={"X-Test": "value"},
405+
timeout=12.0,
406+
sse_read_timeout=30.0,
407+
httpx_client_factory=base_factory,
408+
auth=auth,
409+
transport_factory=_InitializedNotificationTolerantStreamableHTTPTransport,
410+
):
411+
pass
412+
413+
assert seen["factory_headers"] == {"X-Test": "value"}
414+
seen_timeout = seen["factory_timeout"]
415+
assert isinstance(seen_timeout, httpx.Timeout)
416+
assert seen_timeout.connect == timeout.connect
417+
assert seen_timeout.read == 30.0
418+
assert seen_timeout.write == timeout.write
419+
assert seen_timeout.pool == timeout.pool
420+
assert seen["factory_auth"] is auth
421+
422+
423+
@pytest.mark.asyncio
424+
async def test_default_streamable_http_client_matches_expected_defaults():
425+
timeout = httpx.Timeout(12.0)
426+
auth = httpx.BasicAuth("user", "pass")
427+
428+
client = _create_default_streamable_http_client(
429+
headers={"X-Test": "value"},
430+
timeout=timeout,
431+
auth=auth,
432+
)
433+
try:
434+
assert client.headers["X-Test"] == "value"
435+
assert client.timeout.connect == timeout.connect
436+
assert client.timeout.read == timeout.read
437+
assert client.timeout.write == timeout.write
438+
assert client.timeout.pool == timeout.pool
439+
assert client.auth is auth
440+
assert client.follow_redirects is True
441+
finally:
442+
await client.aclose()

0 commit comments

Comments
 (0)