Skip to content

Commit f1cafcd

Browse files
author
Aditya Singh
committed
feat(lifecycle): add TurnControl return value to on_turn_start hooks
Address seratch's review feedback on #2911: hooks that only observe cannot affect agent loop orchestration. This commit adds a TurnControl return type ('continue' | 'stop') so on_turn_start can now halt the run before the LLM is called for that turn. Changes: - lifecycle.py: on_turn_start now returns Union[TurnControl, None] (None and 'continue' are equivalent; 'stop' halts the loop) - run.py (non-streaming path): checks return value; raises MaxTurnsExceeded with descriptive message on 'stop' - run_internal/run_loop.py (streaming path): checks return value; signals QueueCompleteSentinel on 'stop' - __init__.py: exports TurnControl, RunHooksBase, AgentHooksBase - tests: 4 new test cases covering stop-on-turn-N, stop-on-turn-1, explicit 'continue', and agent-level stop The MaxTurnsExceeded raise on 'stop' keeps behaviour consistent with the existing max_turns limit: callers can catch and inspect .run_data if needed.
1 parent ff99d90 commit f1cafcd

File tree

5 files changed

+212
-10
lines changed

5 files changed

+212
-10
lines changed

src/agents/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565
ToolCallOutputItem,
6666
TResponseInputItem,
6767
)
68-
from .lifecycle import AgentHooks, RunHooks
68+
from .lifecycle import AgentHooks, AgentHooksBase, RunHooks, RunHooksBase, TurnControl
6969
from .memory import (
7070
OpenAIConversationsSession,
7171
OpenAIResponsesCompactionArgs,
@@ -361,7 +361,10 @@ def enable_verbose_stdout_logging():
361361
"ReasoningItem",
362362
"ItemHelpers",
363363
"RunHooks",
364+
"RunHooksBase",
364365
"AgentHooks",
366+
"AgentHooksBase",
367+
"TurnControl",
365368
"Session",
366369
"SessionABC",
367370
"SessionSettings",

src/agents/lifecycle.py

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from typing import Any, Generic, Optional
1+
from __future__ import annotations
2+
3+
from typing import Any, Generic, Literal, Optional, Union
24

35
from typing_extensions import TypeVar
46

@@ -9,10 +11,26 @@
911

1012
TAgent = TypeVar("TAgent", bound=AgentBase, default=AgentBase)
1113

14+
TurnControl = Literal["continue", "stop"]
15+
"""Return value for :meth:`RunHooksBase.on_turn_start` / :meth:`AgentHooksBase.on_turn_start`.
16+
17+
* ``"continue"`` (default / ``None``) – proceed with the turn as normal.
18+
* ``"stop"`` – abort the run gracefully after this hook returns, exactly as if
19+
``max_turns`` had been reached. The model is **not** called for this turn and
20+
:meth:`on_turn_end` is **not** fired.
21+
"""
22+
1223

1324
class RunHooksBase(Generic[TContext, TAgent]):
1425
"""A class that receives callbacks on various lifecycle events in an agent run. Subclass and
1526
override the methods you need.
27+
28+
Turn-lifecycle hooks
29+
--------------------
30+
:meth:`on_turn_start` and :meth:`on_turn_end` fire once per iteration of the
31+
agent loop. :meth:`on_turn_start` may return ``"stop"`` to halt the run
32+
gracefully before the LLM is called for that turn (useful for implementing
33+
custom turn-budget logic, external kill-switches, etc.).
1634
"""
1735

1836
async def on_llm_start(
@@ -91,15 +109,24 @@ async def on_turn_start(
91109
context: RunContextWrapper[TContext],
92110
agent: TAgent,
93111
turn_number: int,
94-
) -> None:
112+
) -> Union[TurnControl, None]:
95113
"""Called at the start of each agent turn, before the LLM is invoked.
96114
115+
Returning ``"stop"`` (or raising :class:`StopAgentRun`) will halt the run
116+
gracefully — the model is **not** called for this turn and
117+
:meth:`on_turn_end` is **not** fired. Returning ``None`` or ``"continue"``
118+
proceeds normally.
119+
97120
Args:
98121
context: The run context wrapper.
99122
agent: The current agent.
100-
turn_number: The 1-indexed turn number (increments each time through the agent loop).
123+
turn_number: The 1-indexed turn number (increments each time through the
124+
agent loop).
125+
126+
Returns:
127+
``None`` / ``"continue"`` to proceed, or ``"stop"`` to halt the run.
101128
"""
102-
pass
129+
return None
103130

104131
async def on_turn_end(
105132
self,
@@ -122,6 +149,12 @@ class AgentHooksBase(Generic[TContext, TAgent]):
122149
set this on `agent.hooks` to receive events for that specific agent.
123150
124151
Subclass and override the methods you need.
152+
153+
Turn-lifecycle hooks
154+
--------------------
155+
:meth:`on_turn_start` and :meth:`on_turn_end` fire once per iteration of the
156+
agent loop. :meth:`on_turn_start` may return ``"stop"`` to halt the run
157+
gracefully before the LLM is called for that turn.
125158
"""
126159

127160
async def on_start(self, context: AgentHookContext[TContext], agent: TAgent) -> None:
@@ -183,15 +216,22 @@ async def on_turn_start(
183216
context: RunContextWrapper[TContext],
184217
agent: TAgent,
185218
turn_number: int,
186-
) -> None:
219+
) -> Union[TurnControl, None]:
187220
"""Called at the start of each agent turn, before the LLM is invoked.
188221
222+
Returning ``"stop"`` halts the run gracefully before the model is called.
223+
Returning ``None`` or ``"continue"`` proceeds normally.
224+
189225
Args:
190226
context: The run context wrapper.
191227
agent: The current agent.
192-
turn_number: The 1-indexed turn number (increments each time through the agent loop).
228+
turn_number: The 1-indexed turn number (increments each time through the
229+
agent loop).
230+
231+
Returns:
232+
``None`` / ``"continue"`` to proceed, or ``"stop"`` to halt the run.
193233
"""
194-
pass
234+
return None
195235

196236
async def on_turn_end(
197237
self,

src/agents/run.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -968,7 +968,7 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult:
968968

969969
logger.debug("Running agent %s (turn %s)", current_agent.name, current_turn)
970970

971-
await asyncio.gather(
971+
run_hook_control, agent_hook_control = await asyncio.gather(
972972
hooks.on_turn_start(context_wrapper, current_agent, current_turn),
973973
(
974974
current_agent.hooks.on_turn_start(
@@ -978,6 +978,14 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult:
978978
else _coro.noop_coroutine()
979979
),
980980
)
981+
if run_hook_control == "stop" or agent_hook_control == "stop":
982+
logger.debug(
983+
"Turn %s: on_turn_start hook requested stop; halting run.",
984+
current_turn,
985+
)
986+
raise MaxTurnsExceeded(
987+
f"Run halted by on_turn_start hook at turn {current_turn}"
988+
)
981989

982990
if session_persistence_enabled:
983991
try:

src/agents/run_internal/run_loop.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -820,7 +820,7 @@ async def _save_stream_items_without_count(
820820
streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
821821
break
822822

823-
await asyncio.gather(
823+
run_hook_control, agent_hook_control = await asyncio.gather(
824824
hooks.on_turn_start(context_wrapper, current_agent, current_turn),
825825
(
826826
current_agent.hooks.on_turn_start(
@@ -830,6 +830,18 @@ async def _save_stream_items_without_count(
830830
else _coro.noop_coroutine()
831831
),
832832
)
833+
if run_hook_control == "stop" or agent_hook_control == "stop":
834+
logger.debug(
835+
"Turn %s: on_turn_start hook requested stop; halting run.",
836+
current_turn,
837+
)
838+
streamed_result._max_turns_handled = True
839+
streamed_result.current_turn = current_turn
840+
if run_state is not None:
841+
run_state._current_turn = current_turn
842+
run_state._current_step = None
843+
streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
844+
break
833845

834846
if current_turn == 1:
835847
all_input_guardrails = starting_agent.input_guardrails + (

tests/test_turn_lifecycle_hooks.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,3 +199,142 @@ async def test_on_turn_hooks_with_streaming() -> None:
199199

200200
assert hooks.turn_starts == [1]
201201
assert hooks.turn_ends == [1]
202+
203+
204+
# ---------------------------------------------------------------------------
205+
# TurnControl tests: on_turn_start returning "stop" halts the loop
206+
# ---------------------------------------------------------------------------
207+
208+
class StopAfterTurnRunHooks(RunHooks):
209+
"""Stops the run when on_turn_start is called for a turn > stop_after."""
210+
211+
def __init__(self, stop_after: int = 1) -> None:
212+
self.stop_after = stop_after
213+
self.turn_starts: list[int] = []
214+
self.turn_ends: list[int] = []
215+
216+
async def on_turn_start(
217+
self,
218+
context: RunContextWrapper[TContext],
219+
agent: Any,
220+
turn_number: int,
221+
) -> Optional[str]:
222+
self.turn_starts.append(turn_number)
223+
if turn_number > self.stop_after:
224+
return "stop"
225+
return None
226+
227+
async def on_turn_end(
228+
self,
229+
context: RunContextWrapper[TContext],
230+
agent: Any,
231+
turn_number: int,
232+
) -> None:
233+
self.turn_ends.append(turn_number)
234+
235+
236+
class StopAfterTurnAgentHooks(AgentHooks):
237+
"""Agent-level hooks that return 'stop' after a configurable turn."""
238+
239+
def __init__(self, stop_after: int = 1) -> None:
240+
self.stop_after = stop_after
241+
self.turn_starts: list[int] = []
242+
243+
async def on_turn_start(
244+
self,
245+
context: RunContextWrapper[TContext],
246+
agent: Any,
247+
turn_number: int,
248+
) -> Optional[str]:
249+
self.turn_starts.append(turn_number)
250+
if turn_number > self.stop_after:
251+
return "stop"
252+
return None
253+
254+
255+
@pytest.mark.asyncio
256+
async def test_run_hook_stop_halts_loop() -> None:
257+
"""Returning 'stop' from RunHooks.on_turn_start raises MaxTurnsExceeded before the LLM is called.
258+
259+
Turn 1: hook returns None → LLM executes, returns a tool call.
260+
Turn 2: hook returns "stop" → MaxTurnsExceeded is raised before the LLM is called.
261+
"""
262+
from agents.exceptions import MaxTurnsExceeded
263+
264+
tool = get_function_tool("my_tool", "tool_result")
265+
model = FakeModel()
266+
model.add_multiple_turn_outputs(
267+
[
268+
[get_function_tool_call("my_tool", "{}")], # turn 1: tool call
269+
[get_text_message("turn2")], # turn 2: would be final — never reached
270+
]
271+
)
272+
273+
hooks = StopAfterTurnRunHooks(stop_after=1)
274+
agent = Agent(name="A", model=model, tools=[tool])
275+
276+
with pytest.raises(MaxTurnsExceeded, match="halted by on_turn_start hook"):
277+
await Runner.run(agent, input="hi", hooks=hooks, max_turns=10)
278+
279+
# on_turn_start fires for turn 1 (None → continue) AND turn 2 (returns "stop")
280+
assert hooks.turn_starts == [1, 2]
281+
# on_turn_end fires for turn 1 (completed), NOT turn 2 (never ran)
282+
assert hooks.turn_ends == [1]
283+
284+
285+
@pytest.mark.asyncio
286+
async def test_agent_hook_stop_halts_loop() -> None:
287+
"""Returning 'stop' from AgentHooksBase.on_turn_start also raises MaxTurnsExceeded."""
288+
from agents.exceptions import MaxTurnsExceeded
289+
290+
tool = get_function_tool("my_tool", "tool_result")
291+
model = FakeModel()
292+
model.add_multiple_turn_outputs(
293+
[
294+
[get_function_tool_call("my_tool", "{}")],
295+
[get_text_message("turn2")],
296+
]
297+
)
298+
299+
agent_hooks = StopAfterTurnAgentHooks(stop_after=1)
300+
agent = Agent(name="A", model=model, tools=[tool], hooks=agent_hooks)
301+
302+
with pytest.raises(MaxTurnsExceeded, match="halted by on_turn_start hook"):
303+
await Runner.run(agent, input="hi", max_turns=10)
304+
305+
assert agent_hooks.turn_starts == [1, 2]
306+
307+
308+
@pytest.mark.asyncio
309+
async def test_continue_return_value_is_valid() -> None:
310+
"""Returning the literal 'continue' from on_turn_start is treated as proceed."""
311+
model = FakeModel()
312+
model.set_next_output([get_text_message("hello")])
313+
314+
class ExplicitContinueHooks(RunHooks):
315+
async def on_turn_start(self, context: Any, agent: Any, turn_number: int) -> str:
316+
return "continue"
317+
318+
hooks = ExplicitContinueHooks()
319+
agent = Agent(name="A", model=model)
320+
result = await Runner.run(agent, input="hi", hooks=hooks)
321+
assert result.final_output == "hello"
322+
323+
324+
@pytest.mark.asyncio
325+
async def test_stop_on_first_turn_raises_max_turns() -> None:
326+
"""If on_turn_start returns 'stop' on turn 1, MaxTurnsExceeded is raised immediately."""
327+
from agents.exceptions import MaxTurnsExceeded
328+
329+
model = FakeModel()
330+
model.set_next_output([get_text_message("should not appear")])
331+
332+
class ImmediateStopHooks(RunHooks):
333+
async def on_turn_start(self, context: Any, agent: Any, turn_number: int) -> str:
334+
return "stop"
335+
336+
hooks = ImmediateStopHooks()
337+
agent = Agent(name="A", model=model)
338+
339+
with pytest.raises(MaxTurnsExceeded, match="halted by on_turn_start hook"):
340+
await Runner.run(agent, input="hi", hooks=hooks)

0 commit comments

Comments
 (0)