|
2 | 2 |
|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
| 5 | +import base64 |
5 | 6 | from unittest.mock import MagicMock, patch |
6 | 7 |
|
7 | 8 | import httpx |
8 | 9 | import pytest |
| 10 | +from anyio import create_memory_object_stream |
| 11 | +from mcp.shared.message import SessionMessage |
| 12 | +from mcp.types import JSONRPCMessage, JSONRPCNotification, JSONRPCRequest |
9 | 13 |
|
10 | 14 | from agents.mcp import MCPServerStreamableHttp |
| 15 | +from agents.mcp.server import ( |
| 16 | + _create_default_streamable_http_client, |
| 17 | + _InitializedNotificationTolerantStreamableHTTPTransport, |
| 18 | + _streamablehttp_client_with_transport, |
| 19 | +) |
11 | 20 |
|
12 | 21 |
|
13 | 22 | class TestMCPServerStreamableHttpClientFactory: |
@@ -247,3 +256,187 @@ def comprehensive_factory( |
247 | 256 | terminate_on_close=False, |
248 | 257 | httpx_client_factory=comprehensive_factory, |
249 | 258 | ) |
| 259 | + |
| 260 | + |
| 261 | +@pytest.mark.asyncio |
| 262 | +async def test_initialized_notification_failure_returns_synthetic_success(): |
| 263 | + async def handler(request: httpx.Request) -> httpx.Response: |
| 264 | + return httpx.Response(503, request=request) |
| 265 | + |
| 266 | + transport = _InitializedNotificationTolerantStreamableHTTPTransport("https://example.test/mcp") |
| 267 | + read_stream_writer, _ = create_memory_object_stream[SessionMessage | Exception](0) |
| 268 | + client = httpx.AsyncClient(transport=httpx.MockTransport(handler)) |
| 269 | + try: |
| 270 | + ctx = MagicMock() |
| 271 | + ctx.client = client |
| 272 | + ctx.read_stream_writer = read_stream_writer |
| 273 | + ctx.session_message = SessionMessage( |
| 274 | + JSONRPCMessage( |
| 275 | + JSONRPCNotification( |
| 276 | + jsonrpc="2.0", |
| 277 | + method="notifications/initialized", |
| 278 | + params={}, |
| 279 | + ) |
| 280 | + ) |
| 281 | + ) |
| 282 | + |
| 283 | + await transport._handle_post_request(ctx) |
| 284 | + finally: |
| 285 | + await client.aclose() |
| 286 | + await read_stream_writer.aclose() |
| 287 | + |
| 288 | + |
| 289 | +@pytest.mark.asyncio |
| 290 | +async def test_initialized_notification_transport_exception_returns_synthetic_success(): |
| 291 | + async def handler(request: httpx.Request) -> httpx.Response: |
| 292 | + raise httpx.ConnectError("boom", request=request) |
| 293 | + |
| 294 | + transport = _InitializedNotificationTolerantStreamableHTTPTransport("https://example.test/mcp") |
| 295 | + read_stream_writer, _ = create_memory_object_stream[SessionMessage | Exception](0) |
| 296 | + client = httpx.AsyncClient(transport=httpx.MockTransport(handler)) |
| 297 | + try: |
| 298 | + ctx = MagicMock() |
| 299 | + ctx.client = client |
| 300 | + ctx.read_stream_writer = read_stream_writer |
| 301 | + ctx.session_message = SessionMessage( |
| 302 | + JSONRPCMessage( |
| 303 | + JSONRPCNotification( |
| 304 | + jsonrpc="2.0", |
| 305 | + method="notifications/initialized", |
| 306 | + params={}, |
| 307 | + ) |
| 308 | + ) |
| 309 | + ) |
| 310 | + |
| 311 | + await transport._handle_post_request(ctx) |
| 312 | + finally: |
| 313 | + await client.aclose() |
| 314 | + await read_stream_writer.aclose() |
| 315 | + |
| 316 | + |
| 317 | +@pytest.mark.asyncio |
| 318 | +async def test_streamable_http_server_passes_ignore_initialized_notification_failure(): |
| 319 | + with patch("agents.mcp.server._streamablehttp_client_with_transport") as mock_client: |
| 320 | + mock_client.return_value = MagicMock() |
| 321 | + |
| 322 | + server = MCPServerStreamableHttp( |
| 323 | + params={ |
| 324 | + "url": "http://localhost:8000/mcp", |
| 325 | + "ignore_initialized_notification_failure": True, |
| 326 | + } |
| 327 | + ) |
| 328 | + |
| 329 | + server.create_streams() |
| 330 | + |
| 331 | + kwargs = mock_client.call_args.kwargs |
| 332 | + assert kwargs["url"] == "http://localhost:8000/mcp" |
| 333 | + assert kwargs["headers"] is None |
| 334 | + assert kwargs["timeout"] == 5 |
| 335 | + assert kwargs["sse_read_timeout"] == 300 |
| 336 | + assert kwargs["terminate_on_close"] is True |
| 337 | + assert ( |
| 338 | + kwargs["transport_factory"] is _InitializedNotificationTolerantStreamableHTTPTransport |
| 339 | + ) |
| 340 | + |
| 341 | + |
| 342 | +@pytest.mark.asyncio |
| 343 | +async def test_transport_preserves_non_initialized_failures(): |
| 344 | + async def handler(request: httpx.Request) -> httpx.Response: |
| 345 | + raise httpx.ConnectError("boom", request=request) |
| 346 | + |
| 347 | + transport = _InitializedNotificationTolerantStreamableHTTPTransport("https://example.test/mcp") |
| 348 | + read_stream_writer, _ = create_memory_object_stream[SessionMessage | Exception](0) |
| 349 | + client = httpx.AsyncClient(transport=httpx.MockTransport(handler)) |
| 350 | + try: |
| 351 | + ctx = MagicMock() |
| 352 | + ctx.client = client |
| 353 | + ctx.read_stream_writer = read_stream_writer |
| 354 | + ctx.session_message = SessionMessage( |
| 355 | + JSONRPCMessage( |
| 356 | + JSONRPCRequest( |
| 357 | + jsonrpc="2.0", |
| 358 | + id=1, |
| 359 | + method="tools/list", |
| 360 | + params={}, |
| 361 | + ) |
| 362 | + ) |
| 363 | + ) |
| 364 | + |
| 365 | + with pytest.raises(httpx.ConnectError): |
| 366 | + await transport._handle_post_request(ctx) |
| 367 | + finally: |
| 368 | + await client.aclose() |
| 369 | + await read_stream_writer.aclose() |
| 370 | + |
| 371 | + |
| 372 | +@pytest.mark.asyncio |
| 373 | +async def test_stream_client_preserves_custom_factory_headers_timeout_and_auth(): |
| 374 | + seen: dict[str, object] = {} |
| 375 | + |
| 376 | + class RecordingAuth(httpx.Auth): |
| 377 | + def auth_flow(self, request: httpx.Request): |
| 378 | + request.headers["Authorization"] = f"Basic {base64.b64encode(b'user:pass').decode()}" |
| 379 | + yield request |
| 380 | + |
| 381 | + async def handler(request: httpx.Request) -> httpx.Response: |
| 382 | + seen["request_headers"] = dict(request.headers) |
| 383 | + return httpx.Response(200, request=request) |
| 384 | + |
| 385 | + def base_factory( |
| 386 | + headers: dict[str, str] | None = None, |
| 387 | + timeout: httpx.Timeout | None = None, |
| 388 | + auth: httpx.Auth | None = None, |
| 389 | + ) -> httpx.AsyncClient: |
| 390 | + seen["factory_headers"] = headers |
| 391 | + seen["factory_timeout"] = timeout |
| 392 | + seen["factory_auth"] = auth |
| 393 | + return httpx.AsyncClient( |
| 394 | + headers=headers, |
| 395 | + timeout=timeout, |
| 396 | + auth=auth, |
| 397 | + transport=httpx.MockTransport(handler), |
| 398 | + ) |
| 399 | + |
| 400 | + timeout = httpx.Timeout(12.0) |
| 401 | + auth = RecordingAuth() |
| 402 | + async with _streamablehttp_client_with_transport( |
| 403 | + "https://example.test/mcp", |
| 404 | + headers={"X-Test": "value"}, |
| 405 | + timeout=12.0, |
| 406 | + sse_read_timeout=30.0, |
| 407 | + httpx_client_factory=base_factory, |
| 408 | + auth=auth, |
| 409 | + transport_factory=_InitializedNotificationTolerantStreamableHTTPTransport, |
| 410 | + ): |
| 411 | + pass |
| 412 | + |
| 413 | + assert seen["factory_headers"] == {"X-Test": "value"} |
| 414 | + seen_timeout = seen["factory_timeout"] |
| 415 | + assert isinstance(seen_timeout, httpx.Timeout) |
| 416 | + assert seen_timeout.connect == timeout.connect |
| 417 | + assert seen_timeout.read == 30.0 |
| 418 | + assert seen_timeout.write == timeout.write |
| 419 | + assert seen_timeout.pool == timeout.pool |
| 420 | + assert seen["factory_auth"] is auth |
| 421 | + |
| 422 | + |
| 423 | +@pytest.mark.asyncio |
| 424 | +async def test_default_streamable_http_client_matches_expected_defaults(): |
| 425 | + timeout = httpx.Timeout(12.0) |
| 426 | + auth = httpx.BasicAuth("user", "pass") |
| 427 | + |
| 428 | + client = _create_default_streamable_http_client( |
| 429 | + headers={"X-Test": "value"}, |
| 430 | + timeout=timeout, |
| 431 | + auth=auth, |
| 432 | + ) |
| 433 | + try: |
| 434 | + assert client.headers["X-Test"] == "value" |
| 435 | + assert client.timeout.connect == timeout.connect |
| 436 | + assert client.timeout.read == timeout.read |
| 437 | + assert client.timeout.write == timeout.write |
| 438 | + assert client.timeout.pool == timeout.pool |
| 439 | + assert client.auth is auth |
| 440 | + assert client.follow_redirects is True |
| 441 | + finally: |
| 442 | + await client.aclose() |
0 commit comments