Skip to content

Commit d988475

Browse files
committed
fix streamable http early SSE close
1 parent 220d362 commit d988475

2 files changed

Lines changed: 45 additions & 1 deletion

File tree

src/mcp/client/streamable_http.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from anyio.abc import TaskGroup
1414
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse
1515
from mcp_types import (
16+
CONNECTION_CLOSED,
1617
INTERNAL_ERROR,
1718
INVALID_REQUEST,
1819
METHOD_NOT_FOUND,
@@ -381,6 +382,14 @@ async def _handle_sse_response(
381382
if last_event_id is not None: # pragma: no branch
382383
logger.info("SSE stream disconnected, reconnecting...")
383384
await self._handle_reconnection(ctx, last_event_id, retry_interval_ms)
385+
else:
386+
await self._send_connection_closed(ctx, original_request_id)
387+
388+
async def _send_connection_closed(self, ctx: RequestContext, request_id: RequestId) -> None:
389+
"""Resolve a pending POST SSE request when the stream closes before a reply."""
390+
error_data = ErrorData(code=CONNECTION_CLOSED, message="Connection closed")
391+
error_msg = SessionMessage(JSONRPCError(jsonrpc="2.0", id=request_id, error=error_data))
392+
await ctx.read_stream_writer.send(error_msg)
384393

385394
async def _handle_reconnection(
386395
self,
@@ -393,6 +402,8 @@ async def _handle_reconnection(
393402
# Bail if max retries exceeded
394403
if attempt >= MAX_RECONNECTION_ATTEMPTS: # pragma: no cover
395404
logger.debug(f"Max reconnection attempts ({MAX_RECONNECTION_ATTEMPTS}) exceeded")
405+
if isinstance(ctx.session_message.message, JSONRPCRequest):
406+
await self._send_connection_closed(ctx, ctx.session_message.message.id)
396407
return
397408

398409
# Always wait - use server value or default

tests/client/test_notification_response.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,14 @@
66

77
import json
88

9+
import anyio
910
import httpx
1011
import mcp_types as types
1112
import pytest
1213
from mcp_types import RootsListChangedNotification
1314
from starlette.applications import Starlette
1415
from starlette.requests import Request
15-
from starlette.responses import JSONResponse, Response
16+
from starlette.responses import JSONResponse, Response, StreamingResponse
1617
from starlette.routing import Route
1718

1819
from mcp import ClientSession, MCPError
@@ -72,6 +73,24 @@ async def handle_mcp_request(request: Request) -> Response:
7273
return Starlette(debug=True, routes=[Route("/mcp", handle_mcp_request, methods=["POST"])])
7374

7475

76+
def _create_empty_sse_response_app() -> Starlette:
77+
"""Create a server that closes a POST SSE response without a JSON-RPC reply."""
78+
79+
async def handle_mcp_request(request: Request) -> Response:
80+
body = await request.body()
81+
data = json.loads(body)
82+
83+
if data.get("method") == "initialize":
84+
return _init_json_response(data)
85+
86+
if "id" not in data:
87+
return Response(status_code=202)
88+
89+
return StreamingResponse(iter(()), media_type="text/event-stream")
90+
91+
return Starlette(debug=True, routes=[Route("/mcp", handle_mcp_request, methods=["POST"])])
92+
93+
7594
async def test_non_compliant_notification_response() -> None:
7695
"""Verify the client ignores unexpected responses to notifications.
7796
@@ -117,6 +136,20 @@ async def test_unexpected_content_type_sends_jsonrpc_error() -> None:
117136
await session.list_tools()
118137

119138

139+
async def test_empty_post_sse_response_unblocks_pending_tool_call() -> None:
140+
"""An SSE response that closes before a JSON-RPC reply raises instead of hanging."""
141+
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=_create_empty_sse_response_app())) as client:
142+
async with streamable_http_client("http://localhost/mcp", http_client=client) as (read_stream, write_stream):
143+
async with ClientSession(read_stream, write_stream) as session: # pragma: no branch
144+
await session.initialize()
145+
146+
with pytest.raises(MCPError) as exc_info:
147+
with anyio.fail_after(1):
148+
await session.call_tool("greet", {})
149+
150+
assert exc_info.value.error.code == types.CONNECTION_CLOSED
151+
152+
120153
def _create_http_error_app(error_status: int, *, error_on_notifications: bool = False) -> Starlette:
121154
"""Create a server that returns an HTTP error for non-init requests."""
122155

0 commit comments

Comments
 (0)