TurnControl = Literal["continue", "stop"]
"""Return value for :meth:`RunHooksBase.on_turn_start` / :meth:`AgentHooksBase.on_turn_start`.

* ``"continue"`` (or ``None``, the default) – proceed with the turn as normal.
* ``"stop"`` – halt the run after this hook returns, exactly as if ``max_turns``
  had been reached: a non-streaming run raises :class:`MaxTurnsExceeded`, and a
  streaming run completes its event stream. The model is **not** called for this
  turn and :meth:`on_turn_end` is **not** fired.
"""
:meth:`on_turn_start` may return ``"stop"`` to halt the run + gracefully before the LLM is called for that turn (useful for implementing + custom turn-budget logic, external kill-switches, etc.). """ async def on_llm_start( @@ -86,12 +104,57 @@ async def on_tool_end( """Called immediately after a local tool is invoked.""" pass + async def on_turn_start( + self, + context: RunContextWrapper[TContext], + agent: TAgent, + turn_number: int, + ) -> Union[TurnControl, None]: + """Called at the start of each agent turn, before the LLM is invoked. + + Returning ``"stop"`` (or raising :class:`StopAgentRun`) will halt the run + gracefully — the model is **not** called for this turn and + :meth:`on_turn_end` is **not** fired. Returning ``None`` or ``"continue"`` + proceeds normally. + + Args: + context: The run context wrapper. + agent: The current agent. + turn_number: The 1-indexed turn number (increments each time through the + agent loop). + + Returns: + ``None`` / ``"continue"`` to proceed, or ``"stop"`` to halt the run. + """ + return None + + async def on_turn_end( + self, + context: RunContextWrapper[TContext], + agent: TAgent, + turn_number: int, + ) -> None: + """Called at the end of each agent turn, after all tool calls for that turn complete. + + Args: + context: The run context wrapper. + agent: The current agent. + turn_number: The 1-indexed turn number. + """ + pass + class AgentHooksBase(Generic[TContext, TAgent]): """A class that receives callbacks on various lifecycle events for a specific agent. You can set this on `agent.hooks` to receive events for that specific agent. Subclass and override the methods you need. + + Turn-lifecycle hooks + -------------------- + :meth:`on_turn_start` and :meth:`on_turn_end` fire once per iteration of the + agent loop. :meth:`on_turn_start` may return ``"stop"`` to halt the run + gracefully before the LLM is called for that turn. 
""" async def on_start(self, context: AgentHookContext[TContext], agent: TAgent) -> None: @@ -148,6 +211,43 @@ async def on_tool_end( """Called immediately after a local tool is invoked.""" pass + async def on_turn_start( + self, + context: RunContextWrapper[TContext], + agent: TAgent, + turn_number: int, + ) -> Union[TurnControl, None]: + """Called at the start of each agent turn, before the LLM is invoked. + + Returning ``"stop"`` halts the run gracefully before the model is called. + Returning ``None`` or ``"continue"`` proceeds normally. + + Args: + context: The run context wrapper. + agent: The current agent. + turn_number: The 1-indexed turn number (increments each time through the + agent loop). + + Returns: + ``None`` / ``"continue"`` to proceed, or ``"stop"`` to halt the run. + """ + return None + + async def on_turn_end( + self, + context: RunContextWrapper[TContext], + agent: TAgent, + turn_number: int, + ) -> None: + """Called at the end of each agent turn, after all tool calls for that turn complete. + + Args: + context: The run context wrapper. + agent: The current agent. + turn_number: The 1-indexed turn number. 
+ """ + pass + async def on_llm_start( self, context: RunContextWrapper[TContext], diff --git a/src/agents/run.py b/src/agents/run.py index 047d454d35..2e5bdca62a 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -111,7 +111,7 @@ from .tracing import Span, SpanError, agent_span, get_current_trace from .tracing.context import TraceCtxManager, create_trace_for_run from .tracing.span_data import AgentSpanData -from .util import _error_tracing +from .util import _coro, _error_tracing DEFAULT_AGENT_RUNNER: AgentRunner = None # type: ignore # the value is set at the end of the module @@ -968,6 +968,25 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult: logger.debug("Running agent %s (turn %s)", current_agent.name, current_turn) + run_hook_control, agent_hook_control = await asyncio.gather( + hooks.on_turn_start(context_wrapper, current_agent, current_turn), + ( + current_agent.hooks.on_turn_start( + context_wrapper, current_agent, current_turn + ) + if current_agent.hooks + else _coro.noop_coroutine() + ), + ) + if run_hook_control == "stop" or agent_hook_control == "stop": + logger.debug( + "Turn %s: on_turn_start hook requested stop; halting run.", + current_turn, + ) + raise MaxTurnsExceeded( + f"Run halted by on_turn_start hook at turn {current_turn}" + ) + if session_persistence_enabled: try: last_saved_input_snapshot_for_rewind = ( @@ -1093,6 +1112,17 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult: last_saved_input_snapshot_for_rewind = None should_run_agent_start_hooks = False + await asyncio.gather( + hooks.on_turn_end(context_wrapper, current_agent, current_turn), + ( + current_agent.hooks.on_turn_end( + context_wrapper, current_agent, current_turn + ) + if current_agent.hooks + else _coro.noop_coroutine() + ), + ) + model_responses.append(turn_result.model_response) original_input = turn_result.original_input # For model input, use new_step_items (filtered on handoffs). 
class TurnTrackingRunHooks(RunHooks):
    """Run-level hooks that record the turn numbers passed to the turn hooks."""

    def __init__(self) -> None:
        """Create empty tracking lists and an event counter."""
        self.turn_starts: list[int] = []
        self.turn_ends: list[int] = []
        self.events: dict[str, int] = defaultdict(int)

    async def on_turn_start(
        self,
        context: RunContextWrapper[TContext],
        agent: Any,
        turn_number: int,
    ) -> None:
        """Remember which turn just started."""
        self.turn_starts.append(turn_number)
        self.events["on_turn_start"] += 1

    async def on_turn_end(
        self,
        context: RunContextWrapper[TContext],
        agent: Any,
        turn_number: int,
    ) -> None:
        """Remember which turn just finished."""
        self.turn_ends.append(turn_number)
        self.events["on_turn_end"] += 1


class TurnTrackingAgentHooks(AgentHooks):
    """Agent-level hooks that record the turn numbers passed to the turn hooks."""

    def __init__(self) -> None:
        """Create empty tracking lists."""
        self.turn_starts: list[int] = []
        self.turn_ends: list[int] = []

    async def on_turn_start(
        self,
        context: RunContextWrapper[TContext],
        agent: Any,
        turn_number: int,
    ) -> None:
        """Remember which turn just started."""
        self.turn_starts.append(turn_number)

    async def on_turn_end(
        self,
        context: RunContextWrapper[TContext],
        agent: Any,
        turn_number: int,
    ) -> None:
        """Remember which turn just finished."""
        self.turn_ends.append(turn_number)


@pytest.mark.asyncio
async def test_on_turn_start_and_end_single_turn() -> None:
    """on_turn_start and on_turn_end are both called once for a single-turn run."""
    model = FakeModel()
    model.set_next_output([get_text_message("hello")])

    hooks = TurnTrackingRunHooks()
    agent = Agent(name="A", model=model)

    await Runner.run(agent, input="hi", hooks=hooks)

    assert (hooks.turn_starts, hooks.turn_ends) == ([1], [1])
@pytest.mark.asyncio
async def test_on_turn_numbers_multi_turn() -> None:
    """Turn numbers increment correctly across multiple turns."""
    model = FakeModel()
    # First pass triggers a tool call; second pass produces the final output.
    tool = get_function_tool("my_tool", "tool_result")
    model.add_multiple_turn_outputs(
        [
            [get_function_tool_call("my_tool", "{}")],
            [get_text_message("done")],
        ]
    )

    hooks = TurnTrackingRunHooks()
    agent = Agent(name="A", model=model, tools=[tool])

    await Runner.run(agent, input="hi", hooks=hooks)

    assert hooks.turn_starts == [1, 2]
    assert hooks.turn_ends == [1, 2]


@pytest.mark.asyncio
async def test_on_turn_start_fires_before_llm() -> None:
    """on_turn_start fires before the LLM call each turn."""
    call_order: list[str] = []

    class OrderTrackingHooks(RunHooks):
        """Tracks the order that turn/LLM lifecycle hooks are called."""

        async def on_turn_start(self, context: Any, agent: Any, turn_number: int) -> None:
            """Append a turn_start marker with the turn number."""
            call_order.append(f"turn_start:{turn_number}")

        async def on_llm_start(self, context: Any, agent: Any, system_prompt: Any, input_items: Any) -> None:
            """Append an llm_start marker."""
            call_order.append("llm_start")

        async def on_llm_end(self, context: Any, agent: Any, response: Any) -> None:
            """Append an llm_end marker."""
            call_order.append("llm_end")

        async def on_turn_end(self, context: Any, agent: Any, turn_number: int) -> None:
            """Append a turn_end marker with the turn number."""
            call_order.append(f"turn_end:{turn_number}")

    model = FakeModel()
    model.set_next_output([get_text_message("hello")])
    agent = Agent(name="A", model=model)

    await Runner.run(agent, input="hi", hooks=OrderTrackingHooks())

    # The four markers must appear in lifecycle order:
    # turn_start -> llm_start -> llm_end -> turn_end.
    expected_sequence = ["turn_start:1", "llm_start", "llm_end", "turn_end:1"]
    positions = [call_order.index(marker) for marker in expected_sequence]
    assert positions == sorted(positions)


@pytest.mark.asyncio
async def test_agent_level_on_turn_start_and_end() -> None:
    """Agent-level on_turn_start / on_turn_end hooks are also called."""
    model = FakeModel()
    model.set_next_output([get_text_message("hello")])

    agent_hooks = TurnTrackingAgentHooks()
    agent = Agent(name="A", model=model, hooks=agent_hooks)

    await Runner.run(agent, input="hi")

    assert agent_hooks.turn_starts == [1]
    assert agent_hooks.turn_ends == [1]


@pytest.mark.asyncio
async def test_run_and_agent_hooks_both_called() -> None:
    """Both run-level and agent-level hooks fire for the same turn."""
    model = FakeModel()
    model.set_next_output([get_text_message("hi")])

    run_hooks = TurnTrackingRunHooks()
    agent_hooks = TurnTrackingAgentHooks()
    agent = Agent(name="A", model=model, hooks=agent_hooks)

    await Runner.run(agent, input="hi", hooks=run_hooks)

    for recorder in (run_hooks, agent_hooks):
        assert recorder.turn_starts == [1]
        assert recorder.turn_ends == [1]


@pytest.mark.asyncio
async def test_on_turn_hooks_with_streaming() -> None:
    """on_turn_start and on_turn_end are called when using the streaming runner."""
    model = FakeModel()
    model.set_next_output([get_text_message("streamed")])

    hooks = TurnTrackingRunHooks()
    agent = Agent(name="A", model=model)

    result = Runner.run_streamed(agent, input="hi", hooks=hooks)
    async for _ in result.stream_events():
        pass

    assert hooks.turn_starts == [1]
    assert hooks.turn_ends == [1]
# ---------------------------------------------------------------------------
# TurnControl tests: on_turn_start returning "stop" halts the loop
# ---------------------------------------------------------------------------

class StopAfterTurnRunHooks(RunHooks):
    """Stops the run when on_turn_start is called for a turn > stop_after."""

    def __init__(self, stop_after: int = 1) -> None:
        """Record the last turn allowed to proceed and set up tracking lists."""
        self.stop_after = stop_after
        self.turn_starts: list[int] = []
        self.turn_ends: list[int] = []

    async def on_turn_start(
        self,
        context: RunContextWrapper[TContext],
        agent: Any,
        turn_number: int,
    ) -> Optional[str]:
        """Return "stop" once the turn number exceeds the configured budget."""
        self.turn_starts.append(turn_number)
        return "stop" if turn_number > self.stop_after else None

    async def on_turn_end(
        self,
        context: RunContextWrapper[TContext],
        agent: Any,
        turn_number: int,
    ) -> None:
        """Record each completed turn."""
        self.turn_ends.append(turn_number)


class StopAfterTurnAgentHooks(AgentHooks):
    """Agent-level hooks that return 'stop' after a configurable turn."""

    def __init__(self, stop_after: int = 1) -> None:
        """Record the last turn allowed to proceed and set up tracking."""
        self.stop_after = stop_after
        self.turn_starts: list[int] = []

    async def on_turn_start(
        self,
        context: RunContextWrapper[TContext],
        agent: Any,
        turn_number: int,
    ) -> Optional[str]:
        """Return "stop" once the turn number exceeds the configured budget."""
        self.turn_starts.append(turn_number)
        return "stop" if turn_number > self.stop_after else None


@pytest.mark.asyncio
async def test_run_hook_stop_halts_loop() -> None:
    """Returning 'stop' from RunHooks.on_turn_start raises MaxTurnsExceeded before the LLM is called.

    Turn 1: hook returns None → LLM executes, returns a tool call.
    Turn 2: hook returns "stop" → MaxTurnsExceeded is raised before the LLM is called.
    """
    from agents.exceptions import MaxTurnsExceeded

    tool = get_function_tool("my_tool", "tool_result")
    model = FakeModel()
    model.add_multiple_turn_outputs(
        [
            [get_function_tool_call("my_tool", "{}")],  # turn 1: tool call
            [get_text_message("turn2")],  # turn 2: would be final — never reached
        ]
    )

    hooks = StopAfterTurnRunHooks(stop_after=1)
    agent = Agent(name="A", model=model, tools=[tool])

    with pytest.raises(MaxTurnsExceeded, match="halted by on_turn_start hook"):
        await Runner.run(agent, input="hi", hooks=hooks, max_turns=10)

    # on_turn_start fires for turn 1 (None → continue) AND turn 2 (returns "stop");
    # on_turn_end fires only for turn 1 — turn 2 never ran.
    assert hooks.turn_starts == [1, 2]
    assert hooks.turn_ends == [1]


@pytest.mark.asyncio
async def test_agent_hook_stop_halts_loop() -> None:
    """Returning 'stop' from AgentHooksBase.on_turn_start also raises MaxTurnsExceeded."""
    from agents.exceptions import MaxTurnsExceeded

    tool = get_function_tool("my_tool", "tool_result")
    model = FakeModel()
    model.add_multiple_turn_outputs(
        [
            [get_function_tool_call("my_tool", "{}")],
            [get_text_message("turn2")],
        ]
    )

    agent_hooks = StopAfterTurnAgentHooks(stop_after=1)
    agent = Agent(name="A", model=model, tools=[tool], hooks=agent_hooks)

    with pytest.raises(MaxTurnsExceeded, match="halted by on_turn_start hook"):
        await Runner.run(agent, input="hi", max_turns=10)

    assert agent_hooks.turn_starts == [1, 2]


@pytest.mark.asyncio
async def test_continue_return_value_is_valid() -> None:
    """Returning the literal 'continue' from on_turn_start is treated as proceed."""
    model = FakeModel()
    model.set_next_output([get_text_message("hello")])

    class ExplicitContinueHooks(RunHooks):
        async def on_turn_start(self, context: Any, agent: Any, turn_number: int) -> str:
            return "continue"

    agent = Agent(name="A", model=model)
    result = await Runner.run(agent, input="hi", hooks=ExplicitContinueHooks())
    assert result.final_output == "hello"
@pytest.mark.asyncio
async def test_stop_on_first_turn_raises_max_turns() -> None:
    """If on_turn_start returns 'stop' on turn 1, MaxTurnsExceeded is raised immediately."""
    from agents.exceptions import MaxTurnsExceeded

    model = FakeModel()
    model.set_next_output([get_text_message("should not appear")])

    class ImmediateStopHooks(RunHooks):
        async def on_turn_start(self, context: Any, agent: Any, turn_number: int) -> str:
            return "stop"

    agent = Agent(name="A", model=model)

    with pytest.raises(MaxTurnsExceeded, match="halted by on_turn_start hook"):
        await Runner.run(agent, input="hi", hooks=ImmediateStopHooks())