Skip to content

Commit 159beb5

Browse files
authored
fix: #1121 expose model request IDs on raw responses (#2552)
1 parent f04ab55 commit 159beb5

7 files changed

Lines changed: 573 additions & 19 deletions

File tree

src/agents/items.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,9 @@ class ModelResponse:
483483
be passed to `Runner.run`.
484484
"""
485485

486+
request_id: str | None = None
487+
"""The transport request ID for this model call, if provided by the model SDK."""
488+
486489
def to_input_items(self) -> list[TResponseInputItem]:
487490
"""Convert the output into a list of input items suitable for passing to the model."""
488491
# We happen to know that the shape of the Pydantic output items are the same as the

src/agents/models/openai_responses.py

Lines changed: 147 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
import inspect
66
import json
77
import weakref
8-
from collections.abc import AsyncIterator, Awaitable, Mapping
8+
from collections.abc import AsyncIterator, Awaitable, Callable, Mapping
99
from contextvars import ContextVar
1010
from dataclasses import asdict, dataclass, is_dataclass
1111
from enum import Enum
12-
from typing import TYPE_CHECKING, Any, Literal, Union, cast, overload
12+
from typing import TYPE_CHECKING, Any, Literal, cast, overload
1313

1414
import httpx
1515
from openai import AsyncOpenAI, NotGiven, Omit, omit
@@ -121,6 +121,106 @@ class _WebsocketRequestTimeouts:
121121
recv: float | None
122122

123123

124+
class _ResponseStreamWithRequestId:
125+
"""Wrap an SDK event stream and retain the originating request ID."""
126+
127+
_TERMINAL_EVENT_TYPES = {
128+
"response.completed",
129+
"response.failed",
130+
"response.incomplete",
131+
"response.error",
132+
}
133+
134+
def __init__(
135+
self,
136+
stream: AsyncIterator[ResponseStreamEvent],
137+
*,
138+
request_id: str | None,
139+
cleanup: Callable[[], Awaitable[object]],
140+
) -> None:
141+
self._stream = stream
142+
self.request_id = request_id
143+
self._cleanup = cleanup
144+
self._closed = False
145+
self._stream_close_complete = False
146+
self._cleanup_complete = False
147+
self._yielded_terminal_event = False
148+
149+
def __aiter__(self) -> _ResponseStreamWithRequestId:
150+
return self
151+
152+
async def __anext__(self) -> ResponseStreamEvent:
153+
if self._closed:
154+
raise StopAsyncIteration
155+
156+
try:
157+
event = await self._stream.__anext__()
158+
except StopAsyncIteration:
159+
self._closed = True
160+
await self._cleanup_after_exhaustion()
161+
raise
162+
163+
self._attach_request_id(event)
164+
event_type = getattr(event, "type", None)
165+
if event_type in self._TERMINAL_EVENT_TYPES:
166+
self._yielded_terminal_event = True
167+
return event
168+
169+
async def aclose(self) -> None:
170+
self._closed = True
171+
try:
172+
await self._close_stream_once()
173+
finally:
174+
await self._cleanup_once()
175+
176+
async def close(self) -> None:
177+
await self.aclose()
178+
179+
def _attach_request_id(self, event: ResponseStreamEvent) -> None:
180+
if self.request_id is None:
181+
return
182+
183+
response = getattr(event, "response", None)
184+
if response is None:
185+
return
186+
187+
try:
188+
response._request_id = self.request_id
189+
except Exception:
190+
return
191+
192+
async def _cleanup_once(self) -> None:
193+
if self._cleanup_complete:
194+
return
195+
self._cleanup_complete = True
196+
await self._cleanup()
197+
198+
async def _cleanup_after_exhaustion(self) -> None:
199+
try:
200+
await self._cleanup_once()
201+
except Exception as exc:
202+
if self._yielded_terminal_event:
203+
logger.debug(f"Ignoring stream cleanup error after terminal event: {exc}")
204+
return
205+
raise
206+
207+
async def _close_stream_once(self) -> None:
208+
if self._stream_close_complete:
209+
return
210+
self._stream_close_complete = True
211+
212+
aclose = getattr(self._stream, "aclose", None)
213+
if callable(aclose):
214+
await aclose()
215+
return
216+
217+
close = getattr(self._stream, "close", None)
218+
if callable(close):
219+
close_result = close()
220+
if inspect.isawaitable(close_result):
221+
await close_result
222+
223+
124224
class ResponsesWebSocketError(RuntimeError):
125225
"""Error raised for websocket transport error frames."""
126226

@@ -269,6 +369,7 @@ async def get_response(
269369
output=response.output,
270370
usage=usage,
271371
response_id=response.id,
372+
request_id=getattr(response, "_request_id", None),
272373
)
273374

274375
async def stream_response(
@@ -400,21 +501,46 @@ async def _fetch_response(
400501
stream: Literal[True] | Literal[False] = False,
401502
prompt: ResponsePromptParam | None = None,
402503
) -> Response | AsyncIterator[ResponseStreamEvent]:
403-
response = await self._client.responses.create(
404-
**self._build_response_create_kwargs(
405-
system_instructions=system_instructions,
406-
input=input,
407-
model_settings=model_settings,
408-
tools=tools,
409-
output_schema=output_schema,
410-
handoffs=handoffs,
411-
previous_response_id=previous_response_id,
412-
conversation_id=conversation_id,
413-
stream=stream,
414-
prompt=prompt,
415-
)
504+
create_kwargs = self._build_response_create_kwargs(
505+
system_instructions=system_instructions,
506+
input=input,
507+
model_settings=model_settings,
508+
tools=tools,
509+
output_schema=output_schema,
510+
handoffs=handoffs,
511+
previous_response_id=previous_response_id,
512+
conversation_id=conversation_id,
513+
stream=stream,
514+
prompt=prompt,
515+
)
516+
517+
if not stream:
518+
response = await self._client.responses.create(**create_kwargs)
519+
return cast(Response, response)
520+
521+
streaming_response = getattr(self._client.responses, "with_streaming_response", None)
522+
stream_create = getattr(streaming_response, "create", None)
523+
if not callable(stream_create):
524+
# Some tests and custom clients only implement `responses.create()`. Fall back to the
525+
# older path in that case and simply omit request IDs for streamed calls.
526+
response = await self._client.responses.create(**create_kwargs)
527+
return cast(AsyncIterator[ResponseStreamEvent], response)
528+
529+
# Keep the raw API response open while callers consume the SSE stream so we can expose
530+
# its request ID on terminal response payloads before cleanup closes the transport.
531+
api_response_cm = stream_create(**create_kwargs)
532+
api_response = await api_response_cm.__aenter__()
533+
try:
534+
stream_response = await api_response.parse()
535+
except BaseException as exc:
536+
await api_response_cm.__aexit__(type(exc), exc, exc.__traceback__)
537+
raise
538+
539+
return _ResponseStreamWithRequestId(
540+
cast(AsyncIterator[ResponseStreamEvent], stream_response),
541+
request_id=getattr(api_response, "request_id", None),
542+
cleanup=lambda: api_response_cm.__aexit__(None, None, None),
416543
)
417-
return cast(Union[Response, AsyncIterator[ResponseStreamEvent]], response)
418544

419545
def _build_response_create_kwargs(
420546
self,
@@ -601,7 +727,8 @@ class OpenAIResponsesWSModel(OpenAIResponsesModel):
601727
602728
The websocket transport currently sends `response.create` frames and always streams events.
603729
`get_response()` is implemented by consuming the streamed events until a terminal response
604-
event is received.
730+
event is received. Successful websocket responses do not currently expose a request ID, so
731+
`ModelResponse.request_id` remains `None` on this transport.
605732
"""
606733

607734
def __init__(
@@ -785,6 +912,9 @@ async def _iter_websocket_response_events(
785912
received_any_event = True
786913
raise ResponsesWebSocketError(payload)
787914

915+
# Successful websocket frames currently expose no per-request ID.
916+
# Unlike the HTTP transport, the websocket upgrade response does not
917+
# include `x-request-id`, and success events carry no equivalent field.
788918
event = _construct_response_stream_event_from_payload(payload)
789919
received_any_event = True
790920
is_terminal_event = event_type in {

src/agents/run_internal/run_loop.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1227,6 +1227,7 @@ async def run_single_turn_streamed(
12271227
output=terminal_response.output,
12281228
usage=usage,
12291229
response_id=terminal_response.id,
1230+
request_id=getattr(terminal_response, "_request_id", None),
12301231
)
12311232
context_wrapper.usage.add(usage)
12321233

src/agents/run_state.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,8 @@
9999
# 2. Keep older readable versions in SUPPORTED_SCHEMA_VERSIONS for backward reads.
100100
# 3. to_json() always emits CURRENT_SCHEMA_VERSION.
101101
# 4. Forward compatibility is intentionally fail-fast (older SDKs reject newer versions).
102-
CURRENT_SCHEMA_VERSION = "1.3"
103-
SUPPORTED_SCHEMA_VERSIONS = frozenset({"1.0", "1.1", "1.2", CURRENT_SCHEMA_VERSION})
102+
CURRENT_SCHEMA_VERSION = "1.4"
103+
SUPPORTED_SCHEMA_VERSIONS = frozenset({"1.0", "1.1", "1.2", "1.3", CURRENT_SCHEMA_VERSION})
104104

105105
_FUNCTION_OUTPUT_ADAPTER: TypeAdapter[FunctionCallOutput] = TypeAdapter(FunctionCallOutput)
106106
_COMPUTER_OUTPUT_ADAPTER: TypeAdapter[ComputerCallOutput] = TypeAdapter(ComputerCallOutput)
@@ -265,6 +265,7 @@ def _serialize_model_responses(self) -> list[dict[str, Any]]:
265265
"usage": serialize_usage(resp.usage),
266266
"output": [_serialize_raw_item_value(item) for item in resp.output],
267267
"response_id": resp.response_id,
268+
"request_id": resp.request_id,
268269
}
269270
for resp in self._model_responses
270271
]
@@ -2191,12 +2192,14 @@ def _deserialize_model_responses(responses_data: list[dict[str, Any]]) -> list[M
21912192
output = output_adapter.validate_python(normalized_output)
21922193

21932194
response_id = resp_data.get("response_id")
2195+
request_id = resp_data.get("request_id")
21942196

21952197
result.append(
21962198
ModelResponse(
21972199
usage=usage,
21982200
output=output,
21992201
response_id=response_id,
2202+
request_id=request_id,
22002203
)
22012204
)
22022205

tests/test_agent_runner_streamed.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import httpx
88
import pytest
99
from openai.types.responses import (
10+
ResponseCompletedEvent,
1011
ResponseFailedEvent,
1112
ResponseFunctionToolCall,
1213
ResponseIncompleteEvent,
@@ -173,6 +174,44 @@ async def stream_response(
173174
assert result.raw_responses[0].response_id == "resp-partial"
174175

175176

177+
@pytest.mark.asyncio
178+
async def test_streamed_run_exposes_request_id_on_raw_responses() -> None:
179+
class RequestIdTerminalFakeModel(FakeModel):
180+
async def stream_response(
181+
self,
182+
system_instructions,
183+
input,
184+
model_settings,
185+
tools,
186+
output_schema,
187+
handoffs,
188+
tracing,
189+
*,
190+
previous_response_id=None,
191+
conversation_id=None,
192+
prompt=None,
193+
):
194+
response = get_response_obj(
195+
[get_text_message("partial final")], response_id="resp-partial"
196+
)
197+
response._request_id = "req_streamed_result_123"
198+
yield ResponseCompletedEvent(
199+
type="response.completed",
200+
response=response,
201+
sequence_number=0,
202+
)
203+
204+
model = RequestIdTerminalFakeModel()
205+
agent = Agent(name="test", model=model)
206+
207+
result = Runner.run_streamed(agent, input="test")
208+
async for _ in result.stream_events():
209+
pass
210+
211+
assert len(result.raw_responses) == 1
212+
assert result.raw_responses[0].request_id == "req_streamed_result_123"
213+
214+
176215
@pytest.mark.allow_call_model_methods
177216
@pytest.mark.asyncio
178217
@pytest.mark.parametrize("terminal_event_type", ["response.incomplete", "response.failed"])

0 commit comments

Comments (0)