From 27d5dbe23b0f4045baa0b6f0343778ba0060b6d8 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 14 Apr 2026 22:33:56 +0000 Subject: [PATCH 1/3] feat(lifecycle): add on_turn_start and on_turn_end hooks to RunHooksBase (#2671) Both RunHooksBase and AgentHooksBase get two new hook methods: - on_turn_start(context, agent, turn_number): fires before each LLM call - on_turn_end(context, agent, turn_number): fires after all tool calls for the turn complete (i.e. just before the next-step decision) Turn numbers are 1-indexed and increment each time through the agent loop, regardless of handoffs. The hooks are called in both the sync and streaming code paths. Agent-level hooks on agent.hooks are also called, matching the existing on_tool_start/on_tool_end pattern. Closes #2671 --- src/agents/lifecycle.py | 60 +++++++++ src/agents/run.py | 24 +++- src/agents/run_internal/run_loop.py | 22 ++++ tests/test_turn_lifecycle_hooks.py | 189 ++++++++++++++++++++++++++++ 4 files changed, 294 insertions(+), 1 deletion(-) create mode 100644 tests/test_turn_lifecycle_hooks.py diff --git a/src/agents/lifecycle.py b/src/agents/lifecycle.py index 38744471fb..165daa590e 100644 --- a/src/agents/lifecycle.py +++ b/src/agents/lifecycle.py @@ -86,6 +86,36 @@ async def on_tool_end( """Called immediately after a local tool is invoked.""" pass + async def on_turn_start( + self, + context: RunContextWrapper[TContext], + agent: TAgent, + turn_number: int, + ) -> None: + """Called at the start of each agent turn, before the LLM is invoked. + + Args: + context: The run context wrapper. + agent: The current agent. + turn_number: The 1-indexed turn number (increments each time through the agent loop). + """ + pass + + async def on_turn_end( + self, + context: RunContextWrapper[TContext], + agent: TAgent, + turn_number: int, + ) -> None: + """Called at the end of each agent turn, after all tool calls for that turn complete. + + Args: + context: The run context wrapper. + agent: The current agent. + turn_number: The 1-indexed turn number. + """ + pass + class AgentHooksBase(Generic[TContext, TAgent]): """A class that receives callbacks on various lifecycle events for a specific agent. You can @@ -148,6 +178,36 @@ async def on_tool_end( """Called immediately after a local tool is invoked.""" pass + async def on_turn_start( + self, + context: RunContextWrapper[TContext], + agent: TAgent, + turn_number: int, + ) -> None: + """Called at the start of each agent turn, before the LLM is invoked. + + Args: + context: The run context wrapper. + agent: The current agent. + turn_number: The 1-indexed turn number (increments each time through the agent loop). + """ + pass + + async def on_turn_end( + self, + context: RunContextWrapper[TContext], + agent: TAgent, + turn_number: int, + ) -> None: + """Called at the end of each agent turn, after all tool calls for that turn complete. + + Args: + context: The run context wrapper. + agent: The current agent. + turn_number: The 1-indexed turn number. + """ + pass + async def on_llm_start( self, context: RunContextWrapper[TContext], diff --git a/src/agents/run.py b/src/agents/run.py index 047d454d35..4fc872f5bf 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -111,7 +111,7 @@ from .tracing import Span, SpanError, agent_span, get_current_trace from .tracing.context import TraceCtxManager, create_trace_for_run from .tracing.span_data import AgentSpanData -from .util import _error_tracing +from .util import _coro, _error_tracing DEFAULT_AGENT_RUNNER: AgentRunner = None # type: ignore # the value is set at the end of the module @@ -968,6 +968,17 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult: logger.debug("Running agent %s (turn %s)", current_agent.name, current_turn) + await asyncio.gather( + hooks.on_turn_start(context_wrapper, current_agent, current_turn), + ( + current_agent.hooks.on_turn_start( + context_wrapper, current_agent, current_turn + ) + if current_agent.hooks + else _coro.noop_coroutine() + ), + ) + if session_persistence_enabled: try: last_saved_input_snapshot_for_rewind = ( @@ -1093,6 +1104,17 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult: last_saved_input_snapshot_for_rewind = None should_run_agent_start_hooks = False + await asyncio.gather( + hooks.on_turn_end(context_wrapper, current_agent, current_turn), + ( + current_agent.hooks.on_turn_end( + context_wrapper, current_agent, current_turn + ) + if current_agent.hooks + else _coro.noop_coroutine() + ), + ) + model_responses.append(turn_result.model_response) original_input = turn_result.original_input # For model input, use new_step_items (filtered on handoffs). diff --git a/src/agents/run_internal/run_loop.py b/src/agents/run_internal/run_loop.py index 3d21d89fda..8f2ca0684a 100644 --- a/src/agents/run_internal/run_loop.py +++ b/src/agents/run_internal/run_loop.py @@ -820,6 +820,17 @@ async def _save_stream_items_without_count( streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) break + await asyncio.gather( + hooks.on_turn_start(context_wrapper, current_agent, current_turn), + ( + current_agent.hooks.on_turn_start( + context_wrapper, current_agent, current_turn + ) + if current_agent.hooks + else _coro.noop_coroutine() + ), + ) + if current_turn == 1: all_input_guardrails = starting_agent.input_guardrails + ( run_config.input_guardrails or [] @@ -909,6 +920,17 @@ async def _save_stream_items_without_count( tool_use_tracker ) + await asyncio.gather( + hooks.on_turn_end(context_wrapper, current_agent, current_turn), + ( + current_agent.hooks.on_turn_end( + context_wrapper, current_agent, current_turn + ) + if current_agent.hooks + else _coro.noop_coroutine() + ), + ) + streamed_result.raw_responses = streamed_result.raw_responses + [ turn_result.model_response ] diff --git a/tests/test_turn_lifecycle_hooks.py b/tests/test_turn_lifecycle_hooks.py new file mode 100644 index 0000000000..2baf1f5d03 --- /dev/null +++ b/tests/test_turn_lifecycle_hooks.py @@ -0,0 +1,189 @@ +"""Tests for on_turn_start / on_turn_end lifecycle hooks (issue #2671).""" + +from __future__ import annotations + +from collections import defaultdict +from typing import Any, Optional + +import pytest + +from agents import Agent, Runner +from agents.items import ModelResponse, TResponseInputItem +from agents.lifecycle import AgentHooks, RunHooks +from agents.run_context import RunContextWrapper, TContext +from agents.tool import FunctionTool, Tool + +from .fake_model import FakeModel +from .test_responses import get_function_tool, get_function_tool_call, get_text_message + + +class TurnTrackingRunHooks(RunHooks): + """Records turn numbers seen by on_turn_start and on_turn_end.""" + + def __init__(self) -> None: + self.turn_starts: list[int] = [] + self.turn_ends: list[int] = [] + self.events: dict[str, int] = defaultdict(int) + + async def on_turn_start( + self, + context: RunContextWrapper[TContext], + agent: Any, + turn_number: int, + ) -> None: + self.turn_starts.append(turn_number) + self.events["on_turn_start"] += 1 + + async def on_turn_end( + self, + context: RunContextWrapper[TContext], + agent: Any, + turn_number: int, + ) -> None: + self.turn_ends.append(turn_number) + self.events["on_turn_end"] += 1 + + +class TurnTrackingAgentHooks(AgentHooks): + """Records turn numbers seen on agent-level hooks.""" + + def __init__(self) -> None: + self.turn_starts: list[int] = [] + self.turn_ends: list[int] = [] + + async def on_turn_start( + self, + context: RunContextWrapper[TContext], + agent: Any, + turn_number: int, + ) -> None: + self.turn_starts.append(turn_number) + + async def on_turn_end( + self, + context: RunContextWrapper[TContext], + agent: Any, + turn_number: int, + ) -> None: + self.turn_ends.append(turn_number) + + +@pytest.mark.asyncio +async def test_on_turn_start_and_end_single_turn() -> None: + """on_turn_start and on_turn_end are both called once for a single-turn run.""" + model = FakeModel() + model.set_next_output([get_text_message("hello")]) + + hooks = TurnTrackingRunHooks() + agent = Agent(name="A", model=model) + + await Runner.run(agent, input="hi", hooks=hooks) + + assert hooks.turn_starts == [1] + assert hooks.turn_ends == [1] + + +@pytest.mark.asyncio +async def test_on_turn_numbers_multi_turn() -> None: + """Turn numbers increment correctly across multiple turns.""" + model = FakeModel() + # Turn 1: model calls a tool; turn 2: model produces final output. + tool = get_function_tool("my_tool", "tool_result") + model.add_multiple_turn_outputs([ + [get_function_tool_call("my_tool", "{}")], + [get_text_message("done")], + ]) + + hooks = TurnTrackingRunHooks() + agent = Agent(name="A", model=model, tools=[tool]) + + await Runner.run(agent, input="hi", hooks=hooks) + + assert hooks.turn_starts == [1, 2] + assert hooks.turn_ends == [1, 2] + + +@pytest.mark.asyncio +async def test_on_turn_start_fires_before_llm() -> None: + """on_turn_start fires before the LLM call each turn.""" + call_order: list[str] = [] + + class OrderTrackingHooks(RunHooks): + async def on_turn_start(self, context: Any, agent: Any, turn_number: int) -> None: + call_order.append(f"turn_start:{turn_number}") + + async def on_llm_start(self, context: Any, agent: Any, system_prompt: Any, input_items: Any) -> None: + call_order.append(f"llm_start") + + async def on_llm_end(self, context: Any, agent: Any, response: Any) -> None: + call_order.append(f"llm_end") + + async def on_turn_end(self, context: Any, agent: Any, turn_number: int) -> None: + call_order.append(f"turn_end:{turn_number}") + + model = FakeModel() + model.set_next_output([get_text_message("hello")]) + hooks = OrderTrackingHooks() + agent = Agent(name="A", model=model) + + await Runner.run(agent, input="hi", hooks=hooks) + + # turn_start must come before llm_start, llm_end before turn_end + ts_idx = call_order.index("turn_start:1") + ls_idx = call_order.index("llm_start") + le_idx = call_order.index("llm_end") + te_idx = call_order.index("turn_end:1") + + assert ts_idx < ls_idx + assert ls_idx < le_idx + assert le_idx < te_idx + + +@pytest.mark.asyncio +async def test_agent_level_on_turn_start_and_end() -> None: + """Agent-level on_turn_start / on_turn_end hooks are also called.""" + model = FakeModel() + model.set_next_output([get_text_message("hello")]) + + agent_hooks = TurnTrackingAgentHooks() + agent = Agent(name="A", model=model, hooks=agent_hooks) + + await Runner.run(agent, input="hi") + + assert agent_hooks.turn_starts == [1] + assert agent_hooks.turn_ends == [1] + + +@pytest.mark.asyncio +async def test_run_and_agent_hooks_both_called() -> None: + """Both run-level and agent-level hooks fire for the same turn.""" + model = FakeModel() + model.set_next_output([get_text_message("hi")]) + + run_hooks = TurnTrackingRunHooks() + agent_hooks = TurnTrackingAgentHooks() + agent = Agent(name="A", model=model, hooks=agent_hooks) + + await Runner.run(agent, input="hi", hooks=run_hooks) + + assert run_hooks.turn_starts == [1] + assert run_hooks.turn_ends == [1] + assert agent_hooks.turn_starts == [1] + assert agent_hooks.turn_ends == [1] + + +@pytest.mark.asyncio +async def test_on_turn_hooks_with_streaming() -> None: + """on_turn_start and on_turn_end are called when using the streaming runner.""" + model = FakeModel() + model.set_next_output([get_text_message("streamed")]) + + hooks = TurnTrackingRunHooks() + agent = Agent(name="A", model=model) + + result = Runner.run_streamed(agent, input="hi", hooks=hooks) + async for _ in result.stream_events(): + pass + + assert hooks.turn_starts == [1] + assert hooks.turn_ends == [1] From ff99d9098e3f00eca5495524d5a51a5cc14c4605 Mon Sep 17 00:00:00 2001 From: Aditya Singh Date: Tue, 14 Apr 2026 22:45:37 +0000 Subject: [PATCH 2/3] fix: address CodeRabbit review comments - Remove unnecessary f-string prefixes from on_llm_start/on_llm_end (Ruff F541) - Add missing docstrings to class methods to improve docstring coverage - Add docstring to OrderTrackingHooks inner class --- tests/test_turn_lifecycle_hooks.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/tests/test_turn_lifecycle_hooks.py b/tests/test_turn_lifecycle_hooks.py index 2baf1f5d03..b20a8b2093 100644 --- a/tests/test_turn_lifecycle_hooks.py +++ b/tests/test_turn_lifecycle_hooks.py @@ -21,6 +21,7 @@ class TurnTrackingRunHooks(RunHooks): """Records turn numbers seen by on_turn_start and on_turn_end.""" def __init__(self) -> None: + """Initialise empty tracking lists and event counters.""" self.turn_starts: list[int] = [] self.turn_ends: list[int] = [] self.events: dict[str, int] = defaultdict(int) @@ -31,6 +32,7 @@ async def on_turn_start( agent: Any, turn_number: int, ) -> None: + """Record the turn number when a turn starts.""" self.turn_starts.append(turn_number) self.events["on_turn_start"] += 1 @@ -40,6 +42,7 @@ async def on_turn_end( agent: Any, turn_number: int, ) -> None: + """Record the turn number when a turn ends.""" self.turn_ends.append(turn_number) self.events["on_turn_end"] += 1 @@ -48,6 +51,7 @@ class TurnTrackingAgentHooks(AgentHooks): """Records turn numbers seen on agent-level hooks.""" def __init__(self) -> None: + """Initialise empty tracking lists.""" self.turn_starts: list[int] = [] self.turn_ends: list[int] = [] @@ -57,6 +61,7 @@ async def on_turn_start( agent: Any, turn_number: int, ) -> None: + """Record the turn number when a turn starts.""" self.turn_starts.append(turn_number) async def on_turn_end( @@ -65,6 +70,7 @@ async def on_turn_end( agent: Any, turn_number: int, ) -> None: + """Record the turn number when a turn ends.""" self.turn_ends.append(turn_number) @@ -109,16 +115,22 @@ async def test_on_turn_start_fires_before_llm() -> None: call_order: list[str] = [] class OrderTrackingHooks(RunHooks): + """Tracks the order that turn/LLM lifecycle hooks are called.""" + async def on_turn_start(self, context: Any, agent: Any, turn_number: int) -> None: + """Append a turn_start marker with the turn number.""" call_order.append(f"turn_start:{turn_number}") async def on_llm_start(self, context: Any, agent: Any, system_prompt: Any, input_items: Any) -> None: - call_order.append(f"llm_start") + """Append an llm_start marker.""" + call_order.append("llm_start") async def on_llm_end(self, context: Any, agent: Any, response: Any) -> None: - call_order.append(f"llm_end") + """Append an llm_end marker.""" + call_order.append("llm_end") async def on_turn_end(self, context: Any, agent: Any, turn_number: int) -> None: + """Append a turn_end marker with the turn number.""" call_order.append(f"turn_end:{turn_number}") model = FakeModel() From f1cafcd9ce1beab0f90f6d5a91ac1568c45f4c18 Mon Sep 17 00:00:00 2001 From: Aditya Singh Date: Fri, 17 Apr 2026 03:11:55 +0000 Subject: [PATCH 3/3] feat(lifecycle): add TurnControl return value to on_turn_start hooks Address seratch's review feedback on #2911: hooks that only observe cannot affect agent loop orchestration. This commit adds a TurnControl return type ('continue' | 'stop') so on_turn_start can now halt the run before the LLM is called for that turn. Changes: - lifecycle.py: on_turn_start now returns Union[TurnControl, None] (None and 'continue' are equivalent; 'stop' halts the loop) - run.py (non-streaming path): checks return value; raises MaxTurnsExceeded with descriptive message on 'stop' - run_internal/run_loop.py (streaming path): checks return value; signals QueueCompleteSentinel on 'stop' - __init__.py: exports TurnControl, RunHooksBase, AgentHooksBase - tests: 4 new test cases covering stop-on-turn-N, stop-on-turn-1, explicit 'continue', and agent-level stop The MaxTurnsExceeded raise on 'stop' keeps behaviour consistent with the existing max_turns limit: callers can catch and inspect .run_data if needed. --- src/agents/__init__.py | 5 +- src/agents/lifecycle.py | 54 +++++++++-- src/agents/run.py | 10 +- src/agents/run_internal/run_loop.py | 14 ++- tests/test_turn_lifecycle_hooks.py | 139 ++++++++++++++++++++++++++++ 5 files changed, 212 insertions(+), 10 deletions(-) diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 214e814d3e..bdb040c83e 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -65,7 +65,7 @@ ToolCallOutputItem, TResponseInputItem, ) -from .lifecycle import AgentHooks, RunHooks +from .lifecycle import AgentHooks, AgentHooksBase, RunHooks, RunHooksBase, TurnControl from .memory import ( OpenAIConversationsSession, OpenAIResponsesCompactionArgs, @@ -361,7 +361,10 @@ def enable_verbose_stdout_logging(): "ReasoningItem", "ItemHelpers", "RunHooks", + "RunHooksBase", "AgentHooks", + "AgentHooksBase", + "TurnControl", "Session", "SessionABC", "SessionSettings", diff --git a/src/agents/lifecycle.py b/src/agents/lifecycle.py index 165daa590e..a79a81b087 100644 --- a/src/agents/lifecycle.py +++ b/src/agents/lifecycle.py @@ -1,4 +1,6 @@ -from typing import Any, Generic, Optional +from __future__ import annotations + +from typing import Any, Generic, Literal, Optional, Union from typing_extensions import TypeVar @@ -9,10 +11,26 @@ TAgent = TypeVar("TAgent", bound=AgentBase, default=AgentBase) +TurnControl = Literal["continue", "stop"] +"""Return value for :meth:`RunHooksBase.on_turn_start` / :meth:`AgentHooksBase.on_turn_start`. + +* ``"continue"`` (default / ``None``) – proceed with the turn as normal. +* ``"stop"`` – abort the run gracefully after this hook returns, exactly as if + ``max_turns`` had been reached. The model is **not** called for this turn and + :meth:`on_turn_end` is **not** fired. +""" + class RunHooksBase(Generic[TContext, TAgent]): """A class that receives callbacks on various lifecycle events in an agent run. Subclass and override the methods you need. + + Turn-lifecycle hooks + -------------------- + :meth:`on_turn_start` and :meth:`on_turn_end` fire once per iteration of the + agent loop. :meth:`on_turn_start` may return ``"stop"`` to halt the run + gracefully before the LLM is called for that turn (useful for implementing + custom turn-budget logic, external kill-switches, etc.). """ async def on_llm_start( @@ -91,15 +109,24 @@ async def on_turn_start( context: RunContextWrapper[TContext], agent: TAgent, turn_number: int, - ) -> None: + ) -> Union[TurnControl, None]: """Called at the start of each agent turn, before the LLM is invoked. + Returning ``"stop"`` (or raising :class:`StopAgentRun`) will halt the run + gracefully — the model is **not** called for this turn and + :meth:`on_turn_end` is **not** fired. Returning ``None`` or ``"continue"`` + proceeds normally. + Args: context: The run context wrapper. agent: The current agent. - turn_number: The 1-indexed turn number (increments each time through the agent loop). + turn_number: The 1-indexed turn number (increments each time through the + agent loop). + + Returns: + ``None`` / ``"continue"`` to proceed, or ``"stop"`` to halt the run. """ - pass + return None async def on_turn_end( self, @@ -122,6 +149,12 @@ class AgentHooksBase(Generic[TContext, TAgent]): set this on `agent.hooks` to receive events for that specific agent. Subclass and override the methods you need. + + Turn-lifecycle hooks + -------------------- + :meth:`on_turn_start` and :meth:`on_turn_end` fire once per iteration of the + agent loop. :meth:`on_turn_start` may return ``"stop"`` to halt the run + gracefully before the LLM is called for that turn. """ async def on_start(self, context: AgentHookContext[TContext], agent: TAgent) -> None: @@ -183,15 +216,22 @@ async def on_turn_start( context: RunContextWrapper[TContext], agent: TAgent, turn_number: int, - ) -> None: + ) -> Union[TurnControl, None]: """Called at the start of each agent turn, before the LLM is invoked. + Returning ``"stop"`` halts the run gracefully before the model is called. + Returning ``None`` or ``"continue"`` proceeds normally. + Args: context: The run context wrapper. agent: The current agent. - turn_number: The 1-indexed turn number (increments each time through the agent loop). + turn_number: The 1-indexed turn number (increments each time through the + agent loop). + + Returns: + ``None`` / ``"continue"`` to proceed, or ``"stop"`` to halt the run. """ - pass + return None async def on_turn_end( self, diff --git a/src/agents/run.py b/src/agents/run.py index 4fc872f5bf..2e5bdca62a 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -968,7 +968,7 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult: logger.debug("Running agent %s (turn %s)", current_agent.name, current_turn) - await asyncio.gather( + run_hook_control, agent_hook_control = await asyncio.gather( hooks.on_turn_start(context_wrapper, current_agent, current_turn), ( current_agent.hooks.on_turn_start( @@ -978,6 +978,14 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult: else _coro.noop_coroutine() ), ) + if run_hook_control == "stop" or agent_hook_control == "stop": + logger.debug( + "Turn %s: on_turn_start hook requested stop; halting run.", + current_turn, + ) + raise MaxTurnsExceeded( + f"Run halted by on_turn_start hook at turn {current_turn}" + ) if session_persistence_enabled: try: diff --git a/src/agents/run_internal/run_loop.py b/src/agents/run_internal/run_loop.py index 8f2ca0684a..c65050e1eb 100644 --- a/src/agents/run_internal/run_loop.py +++ b/src/agents/run_internal/run_loop.py @@ -820,7 +820,7 @@ async def _save_stream_items_without_count( streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) break - await asyncio.gather( + run_hook_control, agent_hook_control = await asyncio.gather( hooks.on_turn_start(context_wrapper, current_agent, current_turn), ( current_agent.hooks.on_turn_start( @@ -830,6 +830,18 @@ async def _save_stream_items_without_count( else _coro.noop_coroutine() ), ) + if run_hook_control == "stop" or agent_hook_control == "stop": + logger.debug( + "Turn %s: on_turn_start hook requested stop; halting run.", + current_turn, + ) + streamed_result._max_turns_handled = True + streamed_result.current_turn = current_turn + if run_state is not None: + run_state._current_turn = current_turn + run_state._current_step = None + streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) + break if current_turn == 1: all_input_guardrails = starting_agent.input_guardrails + ( diff --git a/tests/test_turn_lifecycle_hooks.py b/tests/test_turn_lifecycle_hooks.py index b20a8b2093..f9a3238378 100644 --- a/tests/test_turn_lifecycle_hooks.py +++ b/tests/test_turn_lifecycle_hooks.py @@ -199,3 +199,142 @@ async def test_on_turn_hooks_with_streaming() -> None: assert hooks.turn_starts == [1] assert hooks.turn_ends == [1] + + +# --------------------------------------------------------------------------- +# TurnControl tests: on_turn_start returning "stop" halts the loop +# --------------------------------------------------------------------------- + +class StopAfterTurnRunHooks(RunHooks): + """Stops the run when on_turn_start is called for a turn > stop_after.""" + + def __init__(self, stop_after: int = 1) -> None: + self.stop_after = stop_after + self.turn_starts: list[int] = [] + self.turn_ends: list[int] = [] + + async def on_turn_start( + self, + context: RunContextWrapper[TContext], + agent: Any, + turn_number: int, + ) -> Optional[str]: + self.turn_starts.append(turn_number) + if turn_number > self.stop_after: + return "stop" + return None + + async def on_turn_end( + self, + context: RunContextWrapper[TContext], + agent: Any, + turn_number: int, + ) -> None: + self.turn_ends.append(turn_number) + + +class StopAfterTurnAgentHooks(AgentHooks): + """Agent-level hooks that return 'stop' after a configurable turn.""" + + def __init__(self, stop_after: int = 1) -> None: + self.stop_after = stop_after + self.turn_starts: list[int] = [] + + async def on_turn_start( + self, + context: RunContextWrapper[TContext], + agent: Any, + turn_number: int, + ) -> Optional[str]: + self.turn_starts.append(turn_number) + if turn_number > self.stop_after: + return "stop" + return None + + +@pytest.mark.asyncio +async def test_run_hook_stop_halts_loop() -> None: + """Returning 'stop' from RunHooks.on_turn_start raises MaxTurnsExceeded before the LLM is called. + + Turn 1: hook returns None → LLM executes, returns a tool call. + Turn 2: hook returns "stop" → MaxTurnsExceeded is raised before the LLM is called. + """ + from agents.exceptions import MaxTurnsExceeded + + tool = get_function_tool("my_tool", "tool_result") + model = FakeModel() + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("my_tool", "{}")], # turn 1: tool call + [get_text_message("turn2")], # turn 2: would be final — never reached + ] + ) + + hooks = StopAfterTurnRunHooks(stop_after=1) + agent = Agent(name="A", model=model, tools=[tool]) + + with pytest.raises(MaxTurnsExceeded, match="halted by on_turn_start hook"): + await Runner.run(agent, input="hi", hooks=hooks, max_turns=10) + + # on_turn_start fires for turn 1 (None → continue) AND turn 2 (returns "stop") + assert hooks.turn_starts == [1, 2] + # on_turn_end fires for turn 1 (completed), NOT turn 2 (never ran) + assert hooks.turn_ends == [1] + + +@pytest.mark.asyncio +async def test_agent_hook_stop_halts_loop() -> None: + """Returning 'stop' from AgentHooksBase.on_turn_start also raises MaxTurnsExceeded.""" + from agents.exceptions import MaxTurnsExceeded + + tool = get_function_tool("my_tool", "tool_result") + model = FakeModel() + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("my_tool", "{}")], + [get_text_message("turn2")], + ] + ) + + agent_hooks = StopAfterTurnAgentHooks(stop_after=1) + agent = Agent(name="A", model=model, tools=[tool], hooks=agent_hooks) + + with pytest.raises(MaxTurnsExceeded, match="halted by on_turn_start hook"): + await Runner.run(agent, input="hi", max_turns=10) + + assert agent_hooks.turn_starts == [1, 2] + + +@pytest.mark.asyncio +async def test_continue_return_value_is_valid() -> None: + """Returning the literal 'continue' from on_turn_start is treated as proceed.""" + model = FakeModel() + model.set_next_output([get_text_message("hello")]) + + class ExplicitContinueHooks(RunHooks): + async def on_turn_start(self, context: Any, agent: Any, turn_number: int) -> str: + return "continue" + + hooks = ExplicitContinueHooks() + agent = Agent(name="A", model=model) + result = await Runner.run(agent, input="hi", hooks=hooks) + assert result.final_output == "hello" + + +@pytest.mark.asyncio +async def test_stop_on_first_turn_raises_max_turns() -> None: + """If on_turn_start returns 'stop' on turn 1, MaxTurnsExceeded is raised immediately.""" + from agents.exceptions import MaxTurnsExceeded + + model = FakeModel() + model.set_next_output([get_text_message("should not appear")]) + + class ImmediateStopHooks(RunHooks): + async def on_turn_start(self, context: Any, agent: Any, turn_number: int) -> str: + return "stop" + + hooks = ImmediateStopHooks() + agent = Agent(name="A", model=model) + + with pytest.raises(MaxTurnsExceeded, match="halted by on_turn_start hook"): + await Runner.run(agent, input="hi", hooks=hooks)