Skip to content

Commit 3589b6b

Browse files
fix(voice): #2470 keep trace active until pipeline processing completes (#2472)
1 parent: b96b0ce · commit: 3589b6b

3 files changed

Lines changed: 126 additions & 48 deletions

File tree

src/agents/voice/pipeline.py

Lines changed: 45 additions & 48 deletions
Original file line number | Diff line number | Diff line change
@@ -84,24 +84,20 @@ async def _process_audio_input(self, audio_input: AudioInput) -> str:
8484
)
8585

8686
async def _run_single_turn(self, audio_input: AudioInput) -> StreamedAudioResult:
87-
# Since this is single turn, we can use the TraceCtxManager to manage starting/ending the
88-
# trace
89-
with TraceCtxManager(
90-
workflow_name=self.config.workflow_name or "Voice Agent",
91-
trace_id=None, # Automatically generated
92-
group_id=self.config.group_id,
93-
metadata=self.config.trace_metadata,
94-
tracing=self.config.tracing,
95-
disabled=self.config.tracing_disabled,
96-
):
97-
input_text = await self._process_audio_input(audio_input)
98-
99-
output = StreamedAudioResult(
100-
self._get_tts_model(), self.config.tts_settings, self.config
101-
)
102-
103-
async def stream_events():
87+
output = StreamedAudioResult(self._get_tts_model(), self.config.tts_settings, self.config)
88+
89+
async def stream_events():
90+
# Keep the trace scope active for the entire async processing lifecycle.
91+
with TraceCtxManager(
92+
workflow_name=self.config.workflow_name or "Voice Agent",
93+
trace_id=None, # Automatically generated
94+
group_id=self.config.group_id,
95+
metadata=self.config.trace_metadata,
96+
tracing=self.config.tracing,
97+
disabled=self.config.tracing_disabled,
98+
):
10499
try:
100+
input_text = await self._process_audio_input(audio_input)
105101
async for text_event in self.workflow.run(input_text):
106102
await output._add_text(text_event)
107103
await output._turn_done()
@@ -111,37 +107,37 @@ async def stream_events():
111107
await output._add_error(e)
112108
raise e
113109

114-
output._set_task(asyncio.create_task(stream_events()))
115-
return output
110+
output._set_task(asyncio.create_task(stream_events()))
111+
return output
116112

117113
async def _run_multi_turn(self, audio_input: StreamedAudioInput) -> StreamedAudioResult:
118-
with TraceCtxManager(
119-
workflow_name=self.config.workflow_name or "Voice Agent",
120-
trace_id=None,
121-
group_id=self.config.group_id,
122-
metadata=self.config.trace_metadata,
123-
tracing=self.config.tracing,
124-
disabled=self.config.tracing_disabled,
125-
):
126-
output = StreamedAudioResult(
127-
self._get_tts_model(), self.config.tts_settings, self.config
128-
)
129-
130-
try:
131-
async for intro_text in self.workflow.on_start():
132-
await output._add_text(intro_text)
133-
except Exception as e:
134-
logger.warning(f"on_start() failed: {e}")
135-
136-
transcription_session = await self._get_stt_model().create_session(
137-
audio_input,
138-
self.config.stt_settings,
139-
self.config.trace_include_sensitive_data,
140-
self.config.trace_include_sensitive_audio_data,
141-
)
142-
143-
async def process_turns():
114+
output = StreamedAudioResult(self._get_tts_model(), self.config.tts_settings, self.config)
115+
116+
async def process_turns():
117+
# Keep the trace scope active for the full streamed session.
118+
with TraceCtxManager(
119+
workflow_name=self.config.workflow_name or "Voice Agent",
120+
trace_id=None,
121+
group_id=self.config.group_id,
122+
metadata=self.config.trace_metadata,
123+
tracing=self.config.tracing,
124+
disabled=self.config.tracing_disabled,
125+
):
126+
transcription_session = None
144127
try:
128+
try:
129+
async for intro_text in self.workflow.on_start():
130+
await output._add_text(intro_text)
131+
except Exception as e:
132+
logger.warning(f"on_start() failed: {e}")
133+
134+
transcription_session = await self._get_stt_model().create_session(
135+
audio_input,
136+
self.config.stt_settings,
137+
self.config.trace_include_sensitive_data,
138+
self.config.trace_include_sensitive_audio_data,
139+
)
140+
145141
async for input_text in transcription_session.transcribe_turns():
146142
result = self.workflow.run(input_text)
147143
async for text_event in result:
@@ -152,8 +148,9 @@ async def process_turns():
152148
await output._add_error(e)
153149
raise e
154150
finally:
155-
await transcription_session.close()
151+
if transcription_session is not None:
152+
await transcription_session.close()
156153
await output._done()
157154

158-
output._set_task(asyncio.create_task(process_turns()))
159-
return output
155+
output._set_task(asyncio.create_task(process_turns()))
156+
return output

src/agents/voice/result.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,7 @@ def _check_errors(self):
289289

290290
async def stream(self) -> AsyncIterator[VoiceStreamEvent]:
291291
"""Stream the events and audio data as they're generated."""
292+
saw_session_end = False
292293
while True:
293294
try:
294295
event = await self._queue.get()
@@ -302,8 +303,18 @@ async def stream(self) -> AsyncIterator[VoiceStreamEvent]:
302303
break
303304
yield event
304305
if event.type == "voice_stream_event_lifecycle" and event.event == "session_ended":
306+
saw_session_end = True
305307
break
306308

309+
# On the normal completion path, let the producer task finish gracefully so any active
310+
# trace context can emit `trace_end` before we run cleanup.
311+
if (
312+
saw_session_end
313+
and self.text_generation_task is not None
314+
and not self.text_generation_task.done()
315+
):
316+
await asyncio.shield(self.text_generation_task)
317+
307318
self._check_errors()
308319
self._cleanup_tasks()
309320

tests/voice/test_pipeline.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import numpy.typing as npt
77
import pytest
88

9+
from tests.testing_processor import fetch_events
10+
911
try:
1012
from agents.voice import (
1113
AudioInput,
@@ -243,3 +245,71 @@ def _transform_data(
243245
"session_ended",
244246
]
245247
await fake_tts.verify_audio("out_1", audio_chunks[0], dtype=np.int16)
248+
249+
250+
class _BlockingWorkflow(FakeWorkflow):
251+
def __init__(self, gate: asyncio.Event):
252+
super().__init__()
253+
self._gate = gate
254+
255+
async def run(self, _: str):
256+
await self._gate.wait()
257+
yield "out_1"
258+
259+
260+
class _OnStartYieldThenFailWorkflow(FakeWorkflow):
261+
async def on_start(self):
262+
yield "intro"
263+
raise RuntimeError("boom")
264+
265+
266+
@pytest.mark.asyncio
267+
async def test_voicepipeline_trace_not_finished_before_single_turn_completes() -> None:
268+
fake_stt = FakeSTT(["first"])
269+
fake_tts = FakeTTS()
270+
gate = asyncio.Event()
271+
workflow = _BlockingWorkflow(gate)
272+
config = VoicePipelineConfig(tts_settings=TTSModelSettings(buffer_size=1))
273+
pipeline = VoicePipeline(
274+
workflow=workflow, stt_model=fake_stt, tts_model=fake_tts, config=config
275+
)
276+
277+
audio_input = AudioInput(buffer=np.zeros(2, dtype=np.int16))
278+
result = await pipeline.run(audio_input)
279+
await asyncio.sleep(0)
280+
281+
events_before_unblock = fetch_events()
282+
assert "trace_start" in events_before_unblock
283+
assert "trace_end" not in events_before_unblock
284+
285+
gate.set()
286+
await extract_events(result)
287+
assert fetch_events()[-1] == "trace_end"
288+
289+
290+
@pytest.mark.asyncio
291+
async def test_voicepipeline_trace_finishes_after_multi_turn_processing() -> None:
292+
fake_stt = FakeSTT(["first", "second"])
293+
workflow = FakeWorkflow([["out_1"], ["out_2"]])
294+
fake_tts = FakeTTS()
295+
pipeline = VoicePipeline(workflow=workflow, stt_model=fake_stt, tts_model=fake_tts)
296+
297+
streamed_audio_input = await FakeStreamedAudioInput.get(count=2)
298+
result = await pipeline.run(streamed_audio_input)
299+
await extract_events(result)
300+
assert fetch_events()[-1] == "trace_end"
301+
302+
303+
@pytest.mark.asyncio
304+
async def test_voicepipeline_multi_turn_on_start_exception_does_not_abort() -> None:
305+
fake_stt = FakeSTT(["first"])
306+
workflow = _OnStartYieldThenFailWorkflow([["out_1"]])
307+
fake_tts = FakeTTS()
308+
pipeline = VoicePipeline(workflow=workflow, stt_model=fake_stt, tts_model=fake_tts)
309+
310+
streamed_audio_input = await FakeStreamedAudioInput.get(count=1)
311+
result = await pipeline.run(streamed_audio_input)
312+
events, _ = await extract_events(result)
313+
314+
assert events[-1] == "session_ended"
315+
assert "error" not in events

0 commit comments

Comments
 (0)