Skip to content

Commit 038c84e

Browse files
Add WebSocket custom options to OpenAIRealtimeWebSocketModel (#2264)
Co-authored-by: Kazuhiro Sera <seratch@openai.com>
1 parent 5ce91f2 commit 038c84e

File tree

2 files changed

+568
-8
lines changed

2 files changed

+568
-8
lines changed

src/agents/realtime/openai_realtime.py

Lines changed: 55 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@
7979
)
8080
from openai.types.responses.response_prompt import ResponsePrompt
8181
from pydantic import Field, TypeAdapter
82-
from typing_extensions import TypeAlias, assert_never
82+
from typing_extensions import NotRequired, TypeAlias, TypedDict, assert_never
8383
from websockets.asyncio.client import ClientConnection
8484

8585
from agents.handoffs import Handoff
@@ -251,10 +251,25 @@ async def _build_model_settings_from_agent(
251251
# during import on 3.9. We instead inline the union in annotations below.
252252

253253

254+
class TransportConfig(TypedDict):
255+
"""Low-level network transport configuration."""
256+
257+
ping_interval: NotRequired[float | None]
258+
"""Time in seconds between keepalive pings sent by the client.
259+
Default is usually 20.0. Set to None to disable."""
260+
261+
ping_timeout: NotRequired[float | None]
262+
"""Time in seconds to wait for a pong response before disconnecting.
263+
Set to None to disable ping timeout and keep an open connection (ignore network lag)."""
264+
265+
handshake_timeout: NotRequired[float]
266+
"""Time in seconds to wait for the connection handshake to complete."""
267+
268+
254269
class OpenAIRealtimeWebSocketModel(RealtimeModel):
255270
"""A model that uses OpenAI's WebSocket API."""
256271

257-
def __init__(self) -> None:
272+
def __init__(self, *, transport_config: TransportConfig | None = None) -> None:
258273
self.model = "gpt-realtime" # Default model
259274
self._websocket: ClientConnection | None = None
260275
self._websocket_task: asyncio.Task[None] | None = None
@@ -267,6 +282,7 @@ def __init__(self) -> None:
267282
self._created_session: OpenAISessionCreateRequest | None = None
268283
self._server_event_type_adapter = get_server_event_type_adapter()
269284
self._call_id: str | None = None
285+
self._transport_config: TransportConfig | None = transport_config
270286

271287
async def connect(self, options: RealtimeModelConfig) -> None:
272288
"""Establish a connection to the model and keep it alive."""
@@ -312,15 +328,47 @@ async def connect(self, options: RealtimeModelConfig) -> None:
312328
raise UserError("API key is required but was not provided.")
313329

314330
headers.update({"Authorization": f"Bearer {api_key}"})
315-
self._websocket = await websockets.connect(
316-
url,
317-
user_agent_header=_USER_AGENT,
318-
additional_headers=headers,
319-
max_size=None, # Allow any size of message
331+
332+
self._websocket = await self._create_websocket_connection(
333+
url=url,
334+
headers=headers,
335+
transport_config=self._transport_config,
320336
)
321337
self._websocket_task = asyncio.create_task(self._listen_for_messages())
322338
await self._update_session_config(model_settings)
323339

340+
async def _create_websocket_connection(
341+
self,
342+
url: str,
343+
headers: dict[str, str],
344+
transport_config: TransportConfig | None = None,
345+
) -> ClientConnection:
346+
"""Create a WebSocket connection with the given configuration.
347+
348+
Args:
349+
url: The WebSocket URL to connect to.
350+
headers: HTTP headers to include in the connection request.
351+
transport_config: Optional low-level transport configuration.
352+
353+
Returns:
354+
A connected WebSocket client connection.
355+
"""
356+
connect_kwargs: dict[str, Any] = {
357+
"user_agent_header": _USER_AGENT,
358+
"additional_headers": headers,
359+
"max_size": None, # Allow any size of message
360+
}
361+
362+
if transport_config:
363+
if "ping_interval" in transport_config:
364+
connect_kwargs["ping_interval"] = transport_config["ping_interval"]
365+
if "ping_timeout" in transport_config:
366+
connect_kwargs["ping_timeout"] = transport_config["ping_timeout"]
367+
if "handshake_timeout" in transport_config:
368+
connect_kwargs["open_timeout"] = transport_config["handshake_timeout"]
369+
370+
return await websockets.connect(url, **connect_kwargs)
371+
324372
async def _send_tracing_config(
325373
self, tracing_config: RealtimeModelTracingConfig | Literal["auto"] | None
326374
) -> None:

0 commit comments

Comments
 (0)