Skip to content

Commit 5510bda

Browse files
fix(voice): #1824 handle odd-length audio buffers in StreamedAudioResult (#2456)
1 parent 1cea241 commit 5510bda

File tree

2 files changed

+99
-9
lines changed

2 files changed

+99
-9
lines changed

src/agents/voice/result.py

Lines changed: 32 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -88,7 +88,12 @@ async def _add_error(self, error: Exception):
8888
def _transform_audio_buffer(
8989
self, buffer: list[bytes], output_dtype: npt.DTypeLike
9090
) -> npt.NDArray[np.int16 | np.float32]:
91-
np_array = np.frombuffer(b"".join(buffer), dtype=np.int16)
91+
combined_buffer = b"".join(buffer)
92+
if len(combined_buffer) % 2 != 0:
93+
# np.int16 needs 2-byte alignment; pad odd-length chunks safely.
94+
combined_buffer += b"\x00"
95+
96+
np_array = np.frombuffer(combined_buffer, dtype=np.int16)
9297

9398
if output_dtype == np.int16:
9499
return np_array
@@ -118,6 +123,7 @@ async def _stream_audio(
118123
first_byte_received = False
119124
buffer: list[bytes] = []
120125
full_audio_data: list[bytes] = []
126+
pending_byte = b""
121127

122128
async for chunk in self.tts_model.run(text, self.tts_settings):
123129
if not first_byte_received:
@@ -128,15 +134,33 @@ async def _stream_audio(
128134
buffer.append(chunk)
129135
full_audio_data.append(chunk)
130136
if len(buffer) >= self._buffer_size:
131-
audio_np = self._transform_audio_buffer(buffer, self.tts_settings.dtype)
132-
if self.tts_settings.transform_data:
133-
audio_np = self.tts_settings.transform_data(audio_np)
134-
await local_queue.put(
135-
VoiceStreamEventAudio(data=audio_np)
136-
) # Use local queue
137+
combined = pending_byte + b"".join(buffer)
138+
if len(combined) % 2 != 0:
139+
pending_byte = combined[-1:]
140+
combined = combined[:-1]
141+
else:
142+
pending_byte = b""
143+
144+
if combined:
145+
audio_np = self._transform_audio_buffer(
146+
[combined], self.tts_settings.dtype
147+
)
148+
if self.tts_settings.transform_data:
149+
audio_np = self.tts_settings.transform_data(audio_np)
150+
await local_queue.put(
151+
VoiceStreamEventAudio(data=audio_np)
152+
) # Use local queue
137153
buffer = []
138154
if buffer:
139-
audio_np = self._transform_audio_buffer(buffer, self.tts_settings.dtype)
155+
combined = pending_byte + b"".join(buffer)
156+
else:
157+
combined = pending_byte
158+
159+
if combined:
160+
# Final flush: pad the remaining half sample if needed.
161+
if len(combined) % 2 != 0:
162+
combined += b"\x00"
163+
audio_np = self._transform_audio_buffer([combined], self.tts_settings.dtype)
140164
if self.tts_settings.transform_data:
141165
audio_np = self.tts_settings.transform_data(audio_np)
142166
await local_queue.put(VoiceStreamEventAudio(data=audio_np)) # Use local queue

tests/voice/test_pipeline.py

Lines changed: 67 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,18 +1,84 @@
11
from __future__ import annotations
22

3+
import asyncio
4+
35
import numpy as np
46
import numpy.typing as npt
57
import pytest
68

79
try:
8-
from agents.voice import AudioInput, TTSModelSettings, VoicePipeline, VoicePipelineConfig
10+
from agents.voice import (
11+
AudioInput,
12+
StreamedAudioResult,
13+
TTSModelSettings,
14+
VoicePipeline,
15+
VoicePipelineConfig,
16+
VoiceStreamEvent,
17+
VoiceStreamEventAudio,
18+
VoiceStreamEventLifecycle,
19+
)
920

1021
from .fake_models import FakeStreamedAudioInput, FakeSTT, FakeTTS, FakeWorkflow
1122
from .helpers import extract_events
1223
except ImportError:
1324
pass
1425

1526

27+
def test_streamed_audio_result_odd_length_buffer_int16() -> None:
28+
result = StreamedAudioResult(
29+
FakeTTS(),
30+
TTSModelSettings(dtype=np.int16),
31+
VoicePipelineConfig(),
32+
)
33+
34+
transformed = result._transform_audio_buffer([b"\x01"], np.int16)
35+
36+
assert transformed.dtype == np.int16
37+
assert transformed.tolist() == [1]
38+
39+
40+
def test_streamed_audio_result_odd_length_buffer_float32() -> None:
41+
result = StreamedAudioResult(
42+
FakeTTS(),
43+
TTSModelSettings(dtype=np.float32),
44+
VoicePipelineConfig(),
45+
)
46+
47+
transformed = result._transform_audio_buffer([b"\x01"], np.float32)
48+
49+
assert transformed.dtype == np.float32
50+
assert transformed.shape == (1, 1)
51+
assert transformed[0, 0] == pytest.approx(1 / 32767.0)
52+
53+
54+
@pytest.mark.asyncio
55+
async def test_streamed_audio_result_preserves_cross_chunk_sample_boundaries() -> None:
56+
class SplitSampleTTS(FakeTTS):
57+
async def run(self, text: str, settings: TTSModelSettings):
58+
del text, settings
59+
yield b"\x01"
60+
yield b"\x00"
61+
62+
result = StreamedAudioResult(
63+
SplitSampleTTS(),
64+
TTSModelSettings(buffer_size=1, dtype=np.int16),
65+
VoicePipelineConfig(),
66+
)
67+
local_queue: asyncio.Queue[VoiceStreamEvent | None] = asyncio.Queue()
68+
69+
await result._stream_audio("hello", local_queue, finish_turn=True)
70+
71+
audio_chunks: list[bytes] = []
72+
while True:
73+
event = await local_queue.get()
74+
if isinstance(event, VoiceStreamEventAudio) and event.data is not None:
75+
audio_chunks.append(event.data.tobytes())
76+
if isinstance(event, VoiceStreamEventLifecycle) and event.event == "turn_ended":
77+
break
78+
79+
assert audio_chunks == [np.array([1], dtype=np.int16).tobytes()]
80+
81+
1682
@pytest.mark.asyncio
1783
async def test_voicepipeline_run_single_turn() -> None:
1884
# Single turn. Should produce a single audio output, which is the TTS output for "out_1".

0 commit comments

Comments (0)