|
5 | 5 | import inspect |
6 | 6 | import sys |
7 | 7 | from collections.abc import Awaitable |
8 | | -from contextlib import AbstractAsyncContextManager, AsyncExitStack |
| 8 | +from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager |
9 | 9 | from datetime import timedelta |
10 | 10 | from pathlib import Path |
11 | 11 | from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, Union, cast |
@@ -61,6 +61,14 @@ class RequireApprovalObject(TypedDict, total=False): |
61 | 61 | T = TypeVar("T") |
62 | 62 |
|
63 | 63 |
|
| 64 | +class _SharedSessionRequestNeedsIsolation(Exception): |
| 65 | + """Raised when a shared-session request should be retried on an isolated session.""" |
| 66 | + |
| 67 | + |
| 68 | +class _IsolatedSessionRetryFailed(Exception): |
| 69 | + """Raised when an isolated-session retry fails after consuming retry budget.""" |
| 70 | + |
| 71 | + |
64 | 72 | class _UnsetType: |
65 | 73 | pass |
66 | 74 |
|
@@ -456,7 +464,7 @@ def invalidate_tools_cache(self): |
456 | 464 | """Invalidate the tools cache.""" |
457 | 465 | self._cache_dirty = True |
458 | 466 |
|
459 | | - def _extract_http_error_from_exception(self, e: Exception) -> Exception | None: |
| 467 | + def _extract_http_error_from_exception(self, e: BaseException) -> Exception | None: |
460 | 468 | """Extract HTTP error from exception or ExceptionGroup.""" |
461 | 469 | if isinstance(e, (httpx.HTTPStatusError, httpx.ConnectError, httpx.TimeoutException)): |
462 | 470 | return e |
@@ -1127,6 +1135,186 @@ def create_streams( |
1127 | 1135 | terminate_on_close=self.params.get("terminate_on_close", True), |
1128 | 1136 | ) |
1129 | 1137 |
|
| 1138 | + @asynccontextmanager |
| 1139 | + async def _isolated_client_session(self): |
| 1140 | + async with AsyncExitStack() as exit_stack: |
| 1141 | + transport = await exit_stack.enter_async_context(self.create_streams()) |
| 1142 | + read, write, *_ = transport |
| 1143 | + session = await exit_stack.enter_async_context( |
| 1144 | + ClientSession( |
| 1145 | + read, |
| 1146 | + write, |
| 1147 | + timedelta(seconds=self.client_session_timeout_seconds) |
| 1148 | + if self.client_session_timeout_seconds |
| 1149 | + else None, |
| 1150 | + message_handler=self.message_handler, |
| 1151 | + ) |
| 1152 | + ) |
| 1153 | + await session.initialize() |
| 1154 | + yield session |
| 1155 | + |
| 1156 | + async def _call_tool_with_session( |
| 1157 | + self, |
| 1158 | + session: ClientSession, |
| 1159 | + tool_name: str, |
| 1160 | + arguments: dict[str, Any] | None, |
| 1161 | + meta: dict[str, Any] | None = None, |
| 1162 | + ) -> CallToolResult: |
| 1163 | + if meta is None: |
| 1164 | + return await session.call_tool(tool_name, arguments) |
| 1165 | + return await session.call_tool(tool_name, arguments, meta=meta) |
| 1166 | + |
| 1167 | + def _should_retry_in_isolated_session(self, exc: BaseException) -> bool: |
| 1168 | + if isinstance(exc, (asyncio.CancelledError, httpx.ConnectError, httpx.TimeoutException)): |
| 1169 | + return True |
| 1170 | + if isinstance(exc, httpx.HTTPStatusError): |
| 1171 | + return exc.response.status_code >= 500 |
| 1172 | + if isinstance(exc, BaseExceptionGroup): |
| 1173 | + return bool(exc.exceptions) and all( |
| 1174 | + self._should_retry_in_isolated_session(inner) for inner in exc.exceptions |
| 1175 | + ) |
| 1176 | + return False |
| 1177 | + |
| 1178 | + async def _call_tool_with_shared_session( |
| 1179 | + self, |
| 1180 | + tool_name: str, |
| 1181 | + arguments: dict[str, Any] | None, |
| 1182 | + meta: dict[str, Any] | None = None, |
| 1183 | + *, |
| 1184 | + allow_isolated_retry: bool, |
| 1185 | + ) -> CallToolResult: |
| 1186 | + session = self.session |
| 1187 | + assert session is not None |
| 1188 | + try: |
| 1189 | + return await self._maybe_serialize_request( |
| 1190 | + lambda: self._call_tool_with_session(session, tool_name, arguments, meta) |
| 1191 | + ) |
| 1192 | + except BaseException as exc: |
| 1193 | + if allow_isolated_retry and self._should_retry_in_isolated_session(exc): |
| 1194 | + raise _SharedSessionRequestNeedsIsolation from exc |
| 1195 | + raise |
| 1196 | + |
| 1197 | + async def _call_tool_with_isolated_retry( |
| 1198 | + self, |
| 1199 | + tool_name: str, |
| 1200 | + arguments: dict[str, Any] | None, |
| 1201 | + meta: dict[str, Any] | None = None, |
| 1202 | + *, |
| 1203 | + allow_isolated_retry: bool, |
| 1204 | + ) -> tuple[CallToolResult, bool]: |
| 1205 | + request_task = asyncio.create_task( |
| 1206 | + self._call_tool_with_shared_session( |
| 1207 | + tool_name, |
| 1208 | + arguments, |
| 1209 | + meta, |
| 1210 | + allow_isolated_retry=allow_isolated_retry, |
| 1211 | + ) |
| 1212 | + ) |
| 1213 | + try: |
| 1214 | + return await asyncio.shield(request_task), False |
| 1215 | + except _SharedSessionRequestNeedsIsolation: |
| 1216 | + exit_stack = AsyncExitStack() |
| 1217 | + try: |
| 1218 | + session = await exit_stack.enter_async_context(self._isolated_client_session()) |
| 1219 | + except asyncio.CancelledError: |
| 1220 | + await exit_stack.aclose() |
| 1221 | + raise |
| 1222 | + except BaseException as exc: |
| 1223 | + await exit_stack.aclose() |
| 1224 | + raise _IsolatedSessionRetryFailed() from exc |
| 1225 | + try: |
| 1226 | + try: |
| 1227 | + result = await self._call_tool_with_session(session, tool_name, arguments, meta) |
| 1228 | + return result, True |
| 1229 | + except asyncio.CancelledError: |
| 1230 | + raise |
| 1231 | + except BaseException as exc: |
| 1232 | + raise _IsolatedSessionRetryFailed() from exc |
| 1233 | + finally: |
| 1234 | + await exit_stack.aclose() |
| 1235 | + except asyncio.CancelledError: |
| 1236 | + if not request_task.done(): |
| 1237 | + request_task.cancel() |
| 1238 | + try: |
| 1239 | + await request_task |
| 1240 | + except BaseException: |
| 1241 | + pass |
| 1242 | + raise |
| 1243 | + |
| 1244 | + async def call_tool( |
| 1245 | + self, |
| 1246 | + tool_name: str, |
| 1247 | + arguments: dict[str, Any] | None, |
| 1248 | + meta: dict[str, Any] | None = None, |
| 1249 | + ) -> CallToolResult: |
| 1250 | + if not self.session: |
| 1251 | + raise UserError("Server not initialized. Make sure you call `connect()` first.") |
| 1252 | + |
| 1253 | + try: |
| 1254 | + self._validate_required_parameters(tool_name=tool_name, arguments=arguments) |
| 1255 | + retries_used = 0 |
| 1256 | + first_attempt = True |
| 1257 | + while True: |
| 1258 | + if not first_attempt and self.max_retry_attempts != -1: |
| 1259 | + retries_used += 1 |
| 1260 | + allow_isolated_retry = ( |
| 1261 | + self.max_retry_attempts == -1 or retries_used < self.max_retry_attempts |
| 1262 | + ) |
| 1263 | + try: |
| 1264 | + result, used_isolated_retry = await self._call_tool_with_isolated_retry( |
| 1265 | + tool_name, |
| 1266 | + arguments, |
| 1267 | + meta, |
| 1268 | + allow_isolated_retry=allow_isolated_retry, |
| 1269 | + ) |
| 1270 | + if used_isolated_retry and self.max_retry_attempts != -1: |
| 1271 | + retries_used += 1 |
| 1272 | + return result |
| 1273 | + except _IsolatedSessionRetryFailed as exc: |
| 1274 | + retries_used += 1 |
| 1275 | + if self.max_retry_attempts != -1 and retries_used >= self.max_retry_attempts: |
| 1276 | + if exc.__cause__ is not None: |
| 1277 | + raise exc.__cause__ from exc |
| 1278 | + raise exc |
| 1279 | + backoff = self.retry_backoff_seconds_base * (2 ** (retries_used - 1)) |
| 1280 | + await asyncio.sleep(backoff) |
| 1281 | + except Exception: |
| 1282 | + if self.max_retry_attempts != -1 and retries_used >= self.max_retry_attempts: |
| 1283 | + raise |
| 1284 | + backoff = self.retry_backoff_seconds_base * (2**retries_used) |
| 1285 | + await asyncio.sleep(backoff) |
| 1286 | + first_attempt = False |
| 1287 | + except httpx.HTTPStatusError as e: |
| 1288 | + status_code = e.response.status_code |
| 1289 | + raise UserError( |
| 1290 | + f"Failed to call tool '{tool_name}' on MCP server '{self.name}': " |
| 1291 | + f"HTTP error {status_code}" |
| 1292 | + ) from e |
| 1293 | + except httpx.ConnectError as e: |
| 1294 | + raise UserError( |
| 1295 | + f"Failed to call tool '{tool_name}' on MCP server '{self.name}': Connection lost. " |
| 1296 | + f"The server may have disconnected." |
| 1297 | + ) from e |
| 1298 | + except BaseExceptionGroup as e: |
| 1299 | + http_error = self._extract_http_error_from_exception(e) |
| 1300 | + if isinstance(http_error, httpx.HTTPStatusError): |
| 1301 | + status_code = http_error.response.status_code |
| 1302 | + raise UserError( |
| 1303 | + f"Failed to call tool '{tool_name}' on MCP server '{self.name}': " |
| 1304 | + f"HTTP error {status_code}" |
| 1305 | + ) from http_error |
| 1306 | + if isinstance(http_error, httpx.ConnectError): |
| 1307 | + raise UserError( |
| 1308 | + f"Failed to call tool '{tool_name}' on MCP server '{self.name}': " |
| 1309 | + "Connection lost. The server may have disconnected." |
| 1310 | + ) from http_error |
| 1311 | + if isinstance(http_error, httpx.TimeoutException): |
| 1312 | + raise UserError( |
| 1313 | + f"Failed to call tool '{tool_name}' on MCP server '{self.name}': " |
| 1314 | + "Connection timeout." |
| 1315 | + ) from http_error |
| 1316 | + raise |
| 1317 | + |
1130 | 1318 | @property |
1131 | 1319 | def name(self) -> str: |
1132 | 1320 | """A readable name for the server.""" |
|
0 commit comments