|
5 | 5 | import inspect |
6 | 6 | import json |
7 | 7 | import weakref |
8 | | -from collections.abc import AsyncIterator, Awaitable, Mapping |
| 8 | +from collections.abc import AsyncIterator, Awaitable, Callable, Mapping |
9 | 9 | from contextvars import ContextVar |
10 | 10 | from dataclasses import asdict, dataclass, is_dataclass |
11 | 11 | from enum import Enum |
12 | | -from typing import TYPE_CHECKING, Any, Literal, Union, cast, overload |
| 12 | +from typing import TYPE_CHECKING, Any, Literal, cast, overload |
13 | 13 |
|
14 | 14 | import httpx |
15 | 15 | from openai import AsyncOpenAI, NotGiven, Omit, omit |
@@ -121,6 +121,106 @@ class _WebsocketRequestTimeouts: |
121 | 121 | recv: float | None |
122 | 122 |
|
123 | 123 |
|
| 124 | +class _ResponseStreamWithRequestId: |
| 125 | + """Wrap an SDK event stream and retain the originating request ID.""" |
| 126 | + |
| 127 | + _TERMINAL_EVENT_TYPES = { |
| 128 | + "response.completed", |
| 129 | + "response.failed", |
| 130 | + "response.incomplete", |
| 131 | + "response.error", |
| 132 | + } |
| 133 | + |
| 134 | + def __init__( |
| 135 | + self, |
| 136 | + stream: AsyncIterator[ResponseStreamEvent], |
| 137 | + *, |
| 138 | + request_id: str | None, |
| 139 | + cleanup: Callable[[], Awaitable[object]], |
| 140 | + ) -> None: |
| 141 | + self._stream = stream |
| 142 | + self.request_id = request_id |
| 143 | + self._cleanup = cleanup |
| 144 | + self._closed = False |
| 145 | + self._stream_close_complete = False |
| 146 | + self._cleanup_complete = False |
| 147 | + self._yielded_terminal_event = False |
| 148 | + |
| 149 | + def __aiter__(self) -> _ResponseStreamWithRequestId: |
| 150 | + return self |
| 151 | + |
| 152 | + async def __anext__(self) -> ResponseStreamEvent: |
| 153 | + if self._closed: |
| 154 | + raise StopAsyncIteration |
| 155 | + |
| 156 | + try: |
| 157 | + event = await self._stream.__anext__() |
| 158 | + except StopAsyncIteration: |
| 159 | + self._closed = True |
| 160 | + await self._cleanup_after_exhaustion() |
| 161 | + raise |
| 162 | + |
| 163 | + self._attach_request_id(event) |
| 164 | + event_type = getattr(event, "type", None) |
| 165 | + if event_type in self._TERMINAL_EVENT_TYPES: |
| 166 | + self._yielded_terminal_event = True |
| 167 | + return event |
| 168 | + |
| 169 | + async def aclose(self) -> None: |
| 170 | + self._closed = True |
| 171 | + try: |
| 172 | + await self._close_stream_once() |
| 173 | + finally: |
| 174 | + await self._cleanup_once() |
| 175 | + |
| 176 | + async def close(self) -> None: |
| 177 | + await self.aclose() |
| 178 | + |
| 179 | + def _attach_request_id(self, event: ResponseStreamEvent) -> None: |
| 180 | + if self.request_id is None: |
| 181 | + return |
| 182 | + |
| 183 | + response = getattr(event, "response", None) |
| 184 | + if response is None: |
| 185 | + return |
| 186 | + |
| 187 | + try: |
| 188 | + response._request_id = self.request_id |
| 189 | + except Exception: |
| 190 | + return |
| 191 | + |
| 192 | + async def _cleanup_once(self) -> None: |
| 193 | + if self._cleanup_complete: |
| 194 | + return |
| 195 | + self._cleanup_complete = True |
| 196 | + await self._cleanup() |
| 197 | + |
| 198 | + async def _cleanup_after_exhaustion(self) -> None: |
| 199 | + try: |
| 200 | + await self._cleanup_once() |
| 201 | + except Exception as exc: |
| 202 | + if self._yielded_terminal_event: |
| 203 | + logger.debug(f"Ignoring stream cleanup error after terminal event: {exc}") |
| 204 | + return |
| 205 | + raise |
| 206 | + |
| 207 | + async def _close_stream_once(self) -> None: |
| 208 | + if self._stream_close_complete: |
| 209 | + return |
| 210 | + self._stream_close_complete = True |
| 211 | + |
| 212 | + aclose = getattr(self._stream, "aclose", None) |
| 213 | + if callable(aclose): |
| 214 | + await aclose() |
| 215 | + return |
| 216 | + |
| 217 | + close = getattr(self._stream, "close", None) |
| 218 | + if callable(close): |
| 219 | + close_result = close() |
| 220 | + if inspect.isawaitable(close_result): |
| 221 | + await close_result |
| 222 | + |
| 223 | + |
124 | 224 | class ResponsesWebSocketError(RuntimeError): |
125 | 225 | """Error raised for websocket transport error frames.""" |
126 | 226 |
|
@@ -269,6 +369,7 @@ async def get_response( |
269 | 369 | output=response.output, |
270 | 370 | usage=usage, |
271 | 371 | response_id=response.id, |
| 372 | + request_id=getattr(response, "_request_id", None), |
272 | 373 | ) |
273 | 374 |
|
274 | 375 | async def stream_response( |
@@ -400,21 +501,46 @@ async def _fetch_response( |
400 | 501 | stream: Literal[True] | Literal[False] = False, |
401 | 502 | prompt: ResponsePromptParam | None = None, |
402 | 503 | ) -> Response | AsyncIterator[ResponseStreamEvent]: |
403 | | - response = await self._client.responses.create( |
404 | | - **self._build_response_create_kwargs( |
405 | | - system_instructions=system_instructions, |
406 | | - input=input, |
407 | | - model_settings=model_settings, |
408 | | - tools=tools, |
409 | | - output_schema=output_schema, |
410 | | - handoffs=handoffs, |
411 | | - previous_response_id=previous_response_id, |
412 | | - conversation_id=conversation_id, |
413 | | - stream=stream, |
414 | | - prompt=prompt, |
415 | | - ) |
| 504 | + create_kwargs = self._build_response_create_kwargs( |
| 505 | + system_instructions=system_instructions, |
| 506 | + input=input, |
| 507 | + model_settings=model_settings, |
| 508 | + tools=tools, |
| 509 | + output_schema=output_schema, |
| 510 | + handoffs=handoffs, |
| 511 | + previous_response_id=previous_response_id, |
| 512 | + conversation_id=conversation_id, |
| 513 | + stream=stream, |
| 514 | + prompt=prompt, |
| 515 | + ) |
| 516 | + |
| 517 | + if not stream: |
| 518 | + response = await self._client.responses.create(**create_kwargs) |
| 519 | + return cast(Response, response) |
| 520 | + |
| 521 | + streaming_response = getattr(self._client.responses, "with_streaming_response", None) |
| 522 | + stream_create = getattr(streaming_response, "create", None) |
| 523 | + if not callable(stream_create): |
| 524 | + # Some tests and custom clients only implement `responses.create()`. Fall back to the |
| 525 | + # older path in that case and simply omit request IDs for streamed calls. |
| 526 | + response = await self._client.responses.create(**create_kwargs) |
| 527 | + return cast(AsyncIterator[ResponseStreamEvent], response) |
| 528 | + |
| 529 | + # Keep the raw API response open while callers consume the SSE stream so we can expose |
| 530 | + # its request ID on terminal response payloads before cleanup closes the transport. |
| 531 | + api_response_cm = stream_create(**create_kwargs) |
| 532 | + api_response = await api_response_cm.__aenter__() |
| 533 | + try: |
| 534 | + stream_response = await api_response.parse() |
| 535 | + except BaseException as exc: |
| 536 | + await api_response_cm.__aexit__(type(exc), exc, exc.__traceback__) |
| 537 | + raise |
| 538 | + |
| 539 | + return _ResponseStreamWithRequestId( |
| 540 | + cast(AsyncIterator[ResponseStreamEvent], stream_response), |
| 541 | + request_id=getattr(api_response, "request_id", None), |
| 542 | + cleanup=lambda: api_response_cm.__aexit__(None, None, None), |
416 | 543 | ) |
417 | | - return cast(Union[Response, AsyncIterator[ResponseStreamEvent]], response) |
418 | 544 |
|
419 | 545 | def _build_response_create_kwargs( |
420 | 546 | self, |
@@ -601,7 +727,8 @@ class OpenAIResponsesWSModel(OpenAIResponsesModel): |
601 | 727 |
|
602 | 728 | The websocket transport currently sends `response.create` frames and always streams events. |
603 | 729 | `get_response()` is implemented by consuming the streamed events until a terminal response |
604 | | - event is received. |
| 730 | + event is received. Successful websocket responses do not currently expose a request ID, so |
| 731 | + `ModelResponse.request_id` remains `None` on this transport. |
605 | 732 | """ |
606 | 733 |
|
607 | 734 | def __init__( |
@@ -785,6 +912,9 @@ async def _iter_websocket_response_events( |
785 | 912 | received_any_event = True |
786 | 913 | raise ResponsesWebSocketError(payload) |
787 | 914 |
|
| 915 | + # Successful websocket frames currently expose no per-request ID. |
| 916 | + # Unlike the HTTP transport, the websocket upgrade response does not |
| 917 | + # include `x-request-id`, and success events carry no equivalent field. |
788 | 918 | event = _construct_response_stream_event_from_payload(payload) |
789 | 919 | received_any_event = True |
790 | 920 | is_terminal_event = event_type in { |
|
0 commit comments