|
6 | 6 |
|
7 | 7 | import json |
8 | 8 |
|
| 9 | +import anyio |
9 | 10 | import httpx |
10 | 11 | import mcp_types as types |
11 | 12 | import pytest |
12 | 13 | from mcp_types import RootsListChangedNotification |
13 | 14 | from starlette.applications import Starlette |
14 | 15 | from starlette.requests import Request |
15 | | -from starlette.responses import JSONResponse, Response |
| 16 | +from starlette.responses import JSONResponse, Response, StreamingResponse |
16 | 17 | from starlette.routing import Route |
17 | 18 |
|
18 | 19 | from mcp import ClientSession, MCPError |
@@ -72,6 +73,24 @@ async def handle_mcp_request(request: Request) -> Response: |
72 | 73 | return Starlette(debug=True, routes=[Route("/mcp", handle_mcp_request, methods=["POST"])]) |
73 | 74 |
|
74 | 75 |
|
| 76 | +def _create_empty_sse_response_app() -> Starlette: |
| 77 | + """Create a server that closes a POST SSE response without a JSON-RPC reply.""" |
| 78 | + |
| 79 | + async def handle_mcp_request(request: Request) -> Response: |
| 80 | + body = await request.body() |
| 81 | + data = json.loads(body) |
| 82 | + |
| 83 | + if data.get("method") == "initialize": |
| 84 | + return _init_json_response(data) |
| 85 | + |
| 86 | + if "id" not in data: |
| 87 | + return Response(status_code=202) |
| 88 | + |
| 89 | + return StreamingResponse(iter(()), media_type="text/event-stream") |
| 90 | + |
| 91 | + return Starlette(debug=True, routes=[Route("/mcp", handle_mcp_request, methods=["POST"])]) |
| 92 | + |
| 93 | + |
75 | 94 | async def test_non_compliant_notification_response() -> None: |
76 | 95 | """Verify the client ignores unexpected responses to notifications. |
77 | 96 |
|
@@ -117,6 +136,20 @@ async def test_unexpected_content_type_sends_jsonrpc_error() -> None: |
117 | 136 | await session.list_tools() |
118 | 137 |
|
119 | 138 |
|
| 139 | +async def test_empty_post_sse_response_unblocks_pending_tool_call() -> None: |
| 140 | + """An SSE response that closes before a JSON-RPC reply raises instead of hanging.""" |
| 141 | + async with httpx.AsyncClient(transport=httpx.ASGITransport(app=_create_empty_sse_response_app())) as client: |
| 142 | + async with streamable_http_client("http://localhost/mcp", http_client=client) as (read_stream, write_stream): |
| 143 | + async with ClientSession(read_stream, write_stream) as session: # pragma: no branch |
| 144 | + await session.initialize() |
| 145 | + |
| 146 | + with pytest.raises(MCPError) as exc_info: |
| 147 | + with anyio.fail_after(1): |
| 148 | + await session.call_tool("greet", {}) |
| 149 | + |
| 150 | + assert exc_info.value.error.code == types.CONNECTION_CLOSED |
| 151 | + |
| 152 | + |
120 | 153 | def _create_http_error_app(error_status: int, *, error_on_notifications: bool = False) -> Starlette: |
121 | 154 | """Create a server that returns an HTTP error for non-init requests.""" |
122 | 155 |
|
|
0 commit comments