Skip to content

Commit 7fea1a5

Browse files
authored
fix: improve a flaky test for realtime module (#2787)
1 parent c52d25f commit 7fea1a5

File tree

1 file changed

+62
-73
lines changed

1 file changed

+62
-73
lines changed

tests/realtime/test_openai_realtime.py

Lines changed: 62 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1863,6 +1863,40 @@ def mock_create_task_func(coro):
18631863
assert captured_kwargs_long.get("ping_interval") == 5.0
18641864
assert captured_kwargs_long.get("ping_timeout") == 10.0
18651865

1866+
@pytest.mark.asyncio
1867+
async def test_handshake_timeout_config_is_applied(self):
1868+
"""Test that handshake_timeout is passed through as websockets open_timeout."""
1869+
captured_kwargs: dict[str, Any] = {}
1870+
1871+
async def capture_connect(*args, **kwargs):
1872+
captured_kwargs.update(kwargs)
1873+
mock_ws = AsyncMock()
1874+
mock_ws.close_code = None
1875+
return mock_ws
1876+
1877+
transport: TransportConfig = {
1878+
"handshake_timeout": 0.75,
1879+
}
1880+
model = OpenAIRealtimeWebSocketModel(transport_config=transport)
1881+
with patch("websockets.connect", side_effect=capture_connect):
1882+
with patch("asyncio.create_task") as mock_create_task:
1883+
mock_task = AsyncMock()
1884+
1885+
def mock_create_task_func(coro):
1886+
coro.close()
1887+
return mock_task
1888+
1889+
mock_create_task.side_effect = mock_create_task_func
1890+
1891+
config: RealtimeModelConfig = {
1892+
"api_key": "test-key",
1893+
"url": "ws://localhost:8080/v1/realtime",
1894+
"initial_model_settings": {"model_name": "gpt-4o-realtime-preview"},
1895+
}
1896+
await model.connect(config)
1897+
1898+
assert captured_kwargs.get("open_timeout") == 0.75
1899+
18661900
@pytest.mark.asyncio
18671901
async def test_ping_timeout_disabled_vs_enabled(self):
18681902
"""Test that ping timeout can be disabled (None) vs enabled with a value."""
@@ -1978,78 +2012,37 @@ async def test_handshake_timeout_with_delayed_server(self):
19782012
- Success: client timeout > server delay
19792013
- Failure: client timeout < server delay
19802014
"""
1981-
import base64
1982-
import hashlib
1983-
19842015
# Server handshake delay threshold (in seconds)
1985-
SERVER_HANDSHAKE_DELAY = 0.05
2016+
SERVER_HANDSHAKE_DELAY = 0.5
19862017

19872018
shutdown_event = asyncio.Event()
1988-
connections_attempted = []
1989-
1990-
async def delayed_websocket_server(reader, writer):
1991-
"""A WebSocket server that delays the handshake by a fixed amount."""
1992-
connections_attempted.append(True)
1993-
try:
1994-
# Read HTTP upgrade request
1995-
request = b""
1996-
while b"\r\n\r\n" not in request:
1997-
chunk = await asyncio.wait_for(reader.read(1024), timeout=5.0)
1998-
if not chunk:
1999-
return
2000-
request += chunk
2001-
2002-
# Extract Sec-WebSocket-Key
2003-
key = None
2004-
for line in request.decode().split("\r\n"):
2005-
if line.lower().startswith("sec-websocket-key:"):
2006-
key = line.split(":", 1)[1].strip()
2007-
break
2008-
2009-
if not key:
2010-
writer.close()
2011-
return
2012-
2013-
# Intentional delay before completing handshake
2014-
await asyncio.sleep(SERVER_HANDSHAKE_DELAY)
2015-
2016-
# Generate accept key
2017-
GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
2018-
accept = base64.b64encode(hashlib.sha1((key + GUID).encode()).digest()).decode()
2019-
2020-
# Send HTTP 101 Switching Protocols response
2021-
response = (
2022-
"HTTP/1.1 101 Switching Protocols\r\n"
2023-
"Upgrade: websocket\r\n"
2024-
"Connection: Upgrade\r\n"
2025-
f"Sec-WebSocket-Accept: {accept}\r\n"
2026-
"\r\n"
2027-
)
2028-
writer.write(response.encode())
2029-
await writer.drain()
2030-
2031-
# Keep connection open until shutdown, then send a close frame so
2032-
# the client can complete close() without waiting for a timeout.
2033-
await shutdown_event.wait()
2034-
writer.write(b"\x88\x00")
2035-
await writer.drain()
2036-
2037-
except asyncio.TimeoutError:
2038-
pass
2039-
except Exception:
2040-
pass
2041-
finally:
2042-
writer.close()
2043-
2044-
server = await asyncio.start_server(delayed_websocket_server, "127.0.0.1", 0)
2045-
port = server.sockets[0].getsockname()[1]
2046-
url = f"ws://127.0.0.1:{port}/v1/realtime"
2019+
handshake_started = asyncio.Event()
2020+
handshake_attempts = 0
2021+
2022+
async def process_request(_connection, _request):
2023+
nonlocal handshake_attempts
2024+
handshake_attempts += 1
2025+
handshake_started.set()
2026+
await asyncio.sleep(SERVER_HANDSHAKE_DELAY)
2027+
return None
2028+
2029+
async def delayed_handler(_websocket):
2030+
await shutdown_event.wait()
2031+
2032+
async with websockets.serve(
2033+
delayed_handler,
2034+
"127.0.0.1",
2035+
0,
2036+
process_request=process_request,
2037+
) as server:
2038+
sockets = list(server.sockets)
2039+
port = sockets[0].getsockname()[1]
2040+
url = f"ws://127.0.0.1:{port}/v1/realtime"
20472041

2048-
try:
20492042
# Test 1: FAILURE - Client timeout < server delay
20502043
# Client gives up before server completes handshake
20512044
transport_fail: TransportConfig = {
2052-
"handshake_timeout": 0.01,
2045+
"handshake_timeout": 0.2,
20532046
}
20542047
model_fail = OpenAIRealtimeWebSocketModel(transport_config=transport_fail)
20552048
config_fail: RealtimeModelConfig = {
@@ -2061,13 +2054,14 @@ async def delayed_websocket_server(reader, writer):
20612054
with pytest.raises((TimeoutError, asyncio.TimeoutError)):
20622055
await model_fail.connect(config_fail)
20632056

2064-
# Verify connection was attempted
2065-
assert len(connections_attempted) >= 1
2057+
# Wait briefly for the server to observe the request before asserting.
2058+
await asyncio.wait_for(handshake_started.wait(), timeout=1.0)
2059+
assert handshake_attempts >= 1
20662060

20672061
# Test 2: SUCCESS - Client timeout > server delay
20682062
# Client waits long enough for server to complete handshake
20692063
transport_success: TransportConfig = {
2070-
"handshake_timeout": 0.2,
2064+
"handshake_timeout": 1.0,
20712065
}
20722066
model_success = OpenAIRealtimeWebSocketModel(transport_config=transport_success)
20732067
config_success: RealtimeModelConfig = {
@@ -2085,11 +2079,6 @@ async def delayed_websocket_server(reader, writer):
20852079
shutdown_event.set()
20862080
await model_success.close()
20872081

2088-
finally:
2089-
shutdown_event.set()
2090-
server.close()
2091-
await server.wait_closed()
2092-
20932082
@pytest.mark.asyncio
20942083
async def test_ping_interval_comparison_fast_vs_slow(self):
20952084
"""Test that faster ping intervals detect issues sooner than slower ones."""

0 commit comments

Comments
 (0)