Skip to content

Commit c4c1772

Browse files
Retry transient streamable-http MCP tool failures on isolated session (#2703)
1 parent 29f422b commit c4c1772

File tree

2 files changed

+499
-6
lines changed

2 files changed

+499
-6
lines changed

src/agents/mcp/server.py

Lines changed: 190 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import inspect
66
import sys
77
from collections.abc import Awaitable
8-
from contextlib import AbstractAsyncContextManager, AsyncExitStack
8+
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
99
from datetime import timedelta
1010
from pathlib import Path
1111
from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, Union, cast
@@ -61,6 +61,14 @@ class RequireApprovalObject(TypedDict, total=False):
6161
T = TypeVar("T")
6262

6363

64+
class _SharedSessionRequestNeedsIsolation(Exception):
65+
"""Raised when a shared-session request should be retried on an isolated session."""
66+
67+
68+
class _IsolatedSessionRetryFailed(Exception):
69+
"""Raised when an isolated-session retry fails after consuming retry budget."""
70+
71+
6472
class _UnsetType:
6573
pass
6674

@@ -456,7 +464,7 @@ def invalidate_tools_cache(self):
456464
"""Invalidate the tools cache."""
457465
self._cache_dirty = True
458466

459-
def _extract_http_error_from_exception(self, e: Exception) -> Exception | None:
467+
def _extract_http_error_from_exception(self, e: BaseException) -> Exception | None:
460468
"""Extract HTTP error from exception or ExceptionGroup."""
461469
if isinstance(e, (httpx.HTTPStatusError, httpx.ConnectError, httpx.TimeoutException)):
462470
return e
@@ -1127,6 +1135,186 @@ def create_streams(
11271135
terminate_on_close=self.params.get("terminate_on_close", True),
11281136
)
11291137

1138+
@asynccontextmanager
1139+
async def _isolated_client_session(self):
1140+
async with AsyncExitStack() as exit_stack:
1141+
transport = await exit_stack.enter_async_context(self.create_streams())
1142+
read, write, *_ = transport
1143+
session = await exit_stack.enter_async_context(
1144+
ClientSession(
1145+
read,
1146+
write,
1147+
timedelta(seconds=self.client_session_timeout_seconds)
1148+
if self.client_session_timeout_seconds
1149+
else None,
1150+
message_handler=self.message_handler,
1151+
)
1152+
)
1153+
await session.initialize()
1154+
yield session
1155+
1156+
async def _call_tool_with_session(
1157+
self,
1158+
session: ClientSession,
1159+
tool_name: str,
1160+
arguments: dict[str, Any] | None,
1161+
meta: dict[str, Any] | None = None,
1162+
) -> CallToolResult:
1163+
if meta is None:
1164+
return await session.call_tool(tool_name, arguments)
1165+
return await session.call_tool(tool_name, arguments, meta=meta)
1166+
1167+
def _should_retry_in_isolated_session(self, exc: BaseException) -> bool:
1168+
if isinstance(exc, (asyncio.CancelledError, httpx.ConnectError, httpx.TimeoutException)):
1169+
return True
1170+
if isinstance(exc, httpx.HTTPStatusError):
1171+
return exc.response.status_code >= 500
1172+
if isinstance(exc, BaseExceptionGroup):
1173+
return bool(exc.exceptions) and all(
1174+
self._should_retry_in_isolated_session(inner) for inner in exc.exceptions
1175+
)
1176+
return False
1177+
1178+
async def _call_tool_with_shared_session(
1179+
self,
1180+
tool_name: str,
1181+
arguments: dict[str, Any] | None,
1182+
meta: dict[str, Any] | None = None,
1183+
*,
1184+
allow_isolated_retry: bool,
1185+
) -> CallToolResult:
1186+
session = self.session
1187+
assert session is not None
1188+
try:
1189+
return await self._maybe_serialize_request(
1190+
lambda: self._call_tool_with_session(session, tool_name, arguments, meta)
1191+
)
1192+
except BaseException as exc:
1193+
if allow_isolated_retry and self._should_retry_in_isolated_session(exc):
1194+
raise _SharedSessionRequestNeedsIsolation from exc
1195+
raise
1196+
1197+
async def _call_tool_with_isolated_retry(
1198+
self,
1199+
tool_name: str,
1200+
arguments: dict[str, Any] | None,
1201+
meta: dict[str, Any] | None = None,
1202+
*,
1203+
allow_isolated_retry: bool,
1204+
) -> tuple[CallToolResult, bool]:
1205+
request_task = asyncio.create_task(
1206+
self._call_tool_with_shared_session(
1207+
tool_name,
1208+
arguments,
1209+
meta,
1210+
allow_isolated_retry=allow_isolated_retry,
1211+
)
1212+
)
1213+
try:
1214+
return await asyncio.shield(request_task), False
1215+
except _SharedSessionRequestNeedsIsolation:
1216+
exit_stack = AsyncExitStack()
1217+
try:
1218+
session = await exit_stack.enter_async_context(self._isolated_client_session())
1219+
except asyncio.CancelledError:
1220+
await exit_stack.aclose()
1221+
raise
1222+
except BaseException as exc:
1223+
await exit_stack.aclose()
1224+
raise _IsolatedSessionRetryFailed() from exc
1225+
try:
1226+
try:
1227+
result = await self._call_tool_with_session(session, tool_name, arguments, meta)
1228+
return result, True
1229+
except asyncio.CancelledError:
1230+
raise
1231+
except BaseException as exc:
1232+
raise _IsolatedSessionRetryFailed() from exc
1233+
finally:
1234+
await exit_stack.aclose()
1235+
except asyncio.CancelledError:
1236+
if not request_task.done():
1237+
request_task.cancel()
1238+
try:
1239+
await request_task
1240+
except BaseException:
1241+
pass
1242+
raise
1243+
1244+
async def call_tool(
1245+
self,
1246+
tool_name: str,
1247+
arguments: dict[str, Any] | None,
1248+
meta: dict[str, Any] | None = None,
1249+
) -> CallToolResult:
1250+
if not self.session:
1251+
raise UserError("Server not initialized. Make sure you call `connect()` first.")
1252+
1253+
try:
1254+
self._validate_required_parameters(tool_name=tool_name, arguments=arguments)
1255+
retries_used = 0
1256+
first_attempt = True
1257+
while True:
1258+
if not first_attempt and self.max_retry_attempts != -1:
1259+
retries_used += 1
1260+
allow_isolated_retry = (
1261+
self.max_retry_attempts == -1 or retries_used < self.max_retry_attempts
1262+
)
1263+
try:
1264+
result, used_isolated_retry = await self._call_tool_with_isolated_retry(
1265+
tool_name,
1266+
arguments,
1267+
meta,
1268+
allow_isolated_retry=allow_isolated_retry,
1269+
)
1270+
if used_isolated_retry and self.max_retry_attempts != -1:
1271+
retries_used += 1
1272+
return result
1273+
except _IsolatedSessionRetryFailed as exc:
1274+
retries_used += 1
1275+
if self.max_retry_attempts != -1 and retries_used >= self.max_retry_attempts:
1276+
if exc.__cause__ is not None:
1277+
raise exc.__cause__ from exc
1278+
raise exc
1279+
backoff = self.retry_backoff_seconds_base * (2 ** (retries_used - 1))
1280+
await asyncio.sleep(backoff)
1281+
except Exception:
1282+
if self.max_retry_attempts != -1 and retries_used >= self.max_retry_attempts:
1283+
raise
1284+
backoff = self.retry_backoff_seconds_base * (2**retries_used)
1285+
await asyncio.sleep(backoff)
1286+
first_attempt = False
1287+
except httpx.HTTPStatusError as e:
1288+
status_code = e.response.status_code
1289+
raise UserError(
1290+
f"Failed to call tool '{tool_name}' on MCP server '{self.name}': "
1291+
f"HTTP error {status_code}"
1292+
) from e
1293+
except httpx.ConnectError as e:
1294+
raise UserError(
1295+
f"Failed to call tool '{tool_name}' on MCP server '{self.name}': Connection lost. "
1296+
f"The server may have disconnected."
1297+
) from e
1298+
except BaseExceptionGroup as e:
1299+
http_error = self._extract_http_error_from_exception(e)
1300+
if isinstance(http_error, httpx.HTTPStatusError):
1301+
status_code = http_error.response.status_code
1302+
raise UserError(
1303+
f"Failed to call tool '{tool_name}' on MCP server '{self.name}': "
1304+
f"HTTP error {status_code}"
1305+
) from http_error
1306+
if isinstance(http_error, httpx.ConnectError):
1307+
raise UserError(
1308+
f"Failed to call tool '{tool_name}' on MCP server '{self.name}': "
1309+
"Connection lost. The server may have disconnected."
1310+
) from http_error
1311+
if isinstance(http_error, httpx.TimeoutException):
1312+
raise UserError(
1313+
f"Failed to call tool '{tool_name}' on MCP server '{self.name}': "
1314+
"Connection timeout."
1315+
) from http_error
1316+
raise
1317+
11301318
@property
11311319
def name(self) -> str:
11321320
"""A readable name for the server."""

0 commit comments

Comments
 (0)