Skip to content

Commit 4619a6e

Browse files
UbuntuAditya Singh
authored andcommitted
feat(lifecycle): add on_turn_start and on_turn_end hooks to RunHooksBase (#2671)
Both RunHooksBase and AgentHooksBase get two new hook methods: - on_turn_start(context, agent, turn_number): fires before each LLM call - on_turn_end(context, agent, turn_number): fires after all tool calls for the turn complete (i.e. just before the next-step decision) Turn numbers are 1-indexed and increment each time through the agent loop, regardless of handoffs. The hooks are called in both the sync and streaming code paths. Agent-level hooks on agent.hooks are also called, matching the existing on_tool_start/on_tool_end pattern. Closes #2671
1 parent 5c9fb2c commit 4619a6e

File tree

4 files changed

+295
-2
lines changed

4 files changed

+295
-2
lines changed

src/agents/lifecycle.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,36 @@ async def on_tool_end(
8686
"""Called immediately after a local tool is invoked."""
8787
pass
8888

89+
async def on_turn_start(
90+
self,
91+
context: RunContextWrapper[TContext],
92+
agent: TAgent,
93+
turn_number: int,
94+
) -> None:
95+
"""Called at the start of each agent turn, before the LLM is invoked.
96+
97+
Args:
98+
context: The run context wrapper.
99+
agent: The current agent.
100+
turn_number: The 1-indexed turn number (increments each time through the agent loop).
101+
"""
102+
pass
103+
104+
async def on_turn_end(
105+
self,
106+
context: RunContextWrapper[TContext],
107+
agent: TAgent,
108+
turn_number: int,
109+
) -> None:
110+
"""Called at the end of each agent turn, after all tool calls for that turn complete.
111+
112+
Args:
113+
context: The run context wrapper.
114+
agent: The current agent.
115+
turn_number: The 1-indexed turn number.
116+
"""
117+
pass
118+
89119

90120
class AgentHooksBase(Generic[TContext, TAgent]):
91121
"""A class that receives callbacks on various lifecycle events for a specific agent. You can
@@ -148,6 +178,36 @@ async def on_tool_end(
148178
"""Called immediately after a local tool is invoked."""
149179
pass
150180

181+
async def on_turn_start(
182+
self,
183+
context: RunContextWrapper[TContext],
184+
agent: TAgent,
185+
turn_number: int,
186+
) -> None:
187+
"""Called at the start of each agent turn, before the LLM is invoked.
188+
189+
Args:
190+
context: The run context wrapper.
191+
agent: The current agent.
192+
turn_number: The 1-indexed turn number (increments each time through the agent loop).
193+
"""
194+
pass
195+
196+
async def on_turn_end(
197+
self,
198+
context: RunContextWrapper[TContext],
199+
agent: TAgent,
200+
turn_number: int,
201+
) -> None:
202+
"""Called at the end of each agent turn, after all tool calls for that turn complete.
203+
204+
Args:
205+
context: The run context wrapper.
206+
agent: The current agent.
207+
turn_number: The 1-indexed turn number.
208+
"""
209+
pass
210+
151211
async def on_llm_start(
152212
self,
153213
context: RunContextWrapper[TContext],

src/agents/run.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,8 @@
110110
from .tool_guardrails import ToolInputGuardrailResult, ToolOutputGuardrailResult
111111
from .tracing import Span, SpanError, agent_span, get_current_trace
112112
from .tracing.context import TraceCtxManager, create_trace_for_run
113-
from .tracing.span_data import AgentSpanData
114-
from .util import _error_tracing
113+
from .tracing.span_data import AgentSpanData, TaskSpanData
114+
from .util import _coro, _error_tracing
115115

116116
DEFAULT_AGENT_RUNNER: AgentRunner = None # type: ignore
117117
# the value is set at the end of the module
@@ -968,6 +968,17 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult:
968968

969969
logger.debug("Running agent %s (turn %s)", current_agent.name, current_turn)
970970

971+
await asyncio.gather(
972+
hooks.on_turn_start(context_wrapper, current_agent, current_turn),
973+
(
974+
current_agent.hooks.on_turn_start(
975+
context_wrapper, current_agent, current_turn
976+
)
977+
if current_agent.hooks
978+
else _coro.noop_coroutine()
979+
),
980+
)
981+
971982
if session_persistence_enabled:
972983
try:
973984
last_saved_input_snapshot_for_rewind = (
@@ -1093,6 +1104,17 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult:
10931104
last_saved_input_snapshot_for_rewind = None
10941105
should_run_agent_start_hooks = False
10951106

1107+
await asyncio.gather(
1108+
hooks.on_turn_end(context_wrapper, current_agent, current_turn),
1109+
(
1110+
current_agent.hooks.on_turn_end(
1111+
context_wrapper, current_agent, current_turn
1112+
)
1113+
if current_agent.hooks
1114+
else _coro.noop_coroutine()
1115+
),
1116+
)
1117+
10961118
model_responses.append(turn_result.model_response)
10971119
original_input = turn_result.original_input
10981120
# For model input, use new_step_items (filtered on handoffs).

src/agents/run_internal/run_loop.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -820,6 +820,17 @@ async def _save_stream_items_without_count(
820820
streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
821821
break
822822

823+
await asyncio.gather(
824+
hooks.on_turn_start(context_wrapper, current_agent, current_turn),
825+
(
826+
current_agent.hooks.on_turn_start(
827+
context_wrapper, current_agent, current_turn
828+
)
829+
if current_agent.hooks
830+
else _coro.noop_coroutine()
831+
),
832+
)
833+
823834
if current_turn == 1:
824835
all_input_guardrails = starting_agent.input_guardrails + (
825836
run_config.input_guardrails or []
@@ -909,6 +920,17 @@ async def _save_stream_items_without_count(
909920
tool_use_tracker
910921
)
911922

923+
await asyncio.gather(
924+
hooks.on_turn_end(context_wrapper, current_agent, current_turn),
925+
(
926+
current_agent.hooks.on_turn_end(
927+
context_wrapper, current_agent, current_turn
928+
)
929+
if current_agent.hooks
930+
else _coro.noop_coroutine()
931+
),
932+
)
933+
912934
streamed_result.raw_responses = streamed_result.raw_responses + [
913935
turn_result.model_response
914936
]

tests/test_turn_lifecycle_hooks.py

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
"""Tests for on_turn_start / on_turn_end lifecycle hooks (issue #2671)."""
2+
3+
from __future__ import annotations
4+
5+
from collections import defaultdict
6+
from typing import Any, Optional
7+
8+
import pytest
9+
10+
from agents import Agent, Runner
11+
from agents.items import ModelResponse, TResponseInputItem
12+
from agents.lifecycle import AgentHooks, RunHooks
13+
from agents.run_context import RunContextWrapper, TContext
14+
from agents.tool import FunctionTool, Tool
15+
16+
from .fake_model import FakeModel
17+
from .test_responses import get_function_tool, get_function_tool_call, get_text_message
18+
19+
20+
class TurnTrackingRunHooks(RunHooks):
21+
"""Records turn numbers seen by on_turn_start and on_turn_end."""
22+
23+
def __init__(self) -> None:
24+
self.turn_starts: list[int] = []
25+
self.turn_ends: list[int] = []
26+
self.events: dict[str, int] = defaultdict(int)
27+
28+
async def on_turn_start(
29+
self,
30+
context: RunContextWrapper[TContext],
31+
agent: Any,
32+
turn_number: int,
33+
) -> None:
34+
self.turn_starts.append(turn_number)
35+
self.events["on_turn_start"] += 1
36+
37+
async def on_turn_end(
38+
self,
39+
context: RunContextWrapper[TContext],
40+
agent: Any,
41+
turn_number: int,
42+
) -> None:
43+
self.turn_ends.append(turn_number)
44+
self.events["on_turn_end"] += 1
45+
46+
47+
class TurnTrackingAgentHooks(AgentHooks):
48+
"""Records turn numbers seen on agent-level hooks."""
49+
50+
def __init__(self) -> None:
51+
self.turn_starts: list[int] = []
52+
self.turn_ends: list[int] = []
53+
54+
async def on_turn_start(
55+
self,
56+
context: RunContextWrapper[TContext],
57+
agent: Any,
58+
turn_number: int,
59+
) -> None:
60+
self.turn_starts.append(turn_number)
61+
62+
async def on_turn_end(
63+
self,
64+
context: RunContextWrapper[TContext],
65+
agent: Any,
66+
turn_number: int,
67+
) -> None:
68+
self.turn_ends.append(turn_number)
69+
70+
71+
@pytest.mark.asyncio
72+
async def test_on_turn_start_and_end_single_turn() -> None:
73+
"""on_turn_start and on_turn_end are both called once for a single-turn run."""
74+
model = FakeModel()
75+
model.set_next_output([get_text_message("hello")])
76+
77+
hooks = TurnTrackingRunHooks()
78+
agent = Agent(name="A", model=model)
79+
80+
await Runner.run(agent, input="hi", hooks=hooks)
81+
82+
assert hooks.turn_starts == [1]
83+
assert hooks.turn_ends == [1]
84+
85+
86+
@pytest.mark.asyncio
87+
async def test_on_turn_numbers_multi_turn() -> None:
88+
"""Turn numbers increment correctly across multiple turns."""
89+
model = FakeModel()
90+
# Turn 1: model calls a tool; turn 2: model produces final output.
91+
tool = get_function_tool("my_tool", "tool_result")
92+
model.add_multiple_turn_outputs([
93+
[get_function_tool_call("my_tool", "{}")],
94+
[get_text_message("done")],
95+
])
96+
97+
hooks = TurnTrackingRunHooks()
98+
agent = Agent(name="A", model=model, tools=[tool])
99+
100+
await Runner.run(agent, input="hi", hooks=hooks)
101+
102+
assert hooks.turn_starts == [1, 2]
103+
assert hooks.turn_ends == [1, 2]
104+
105+
106+
@pytest.mark.asyncio
107+
async def test_on_turn_start_fires_before_llm() -> None:
108+
"""on_turn_start fires before the LLM call each turn."""
109+
call_order: list[str] = []
110+
111+
class OrderTrackingHooks(RunHooks):
112+
async def on_turn_start(self, context: Any, agent: Any, turn_number: int) -> None:
113+
call_order.append(f"turn_start:{turn_number}")
114+
115+
async def on_llm_start(self, context: Any, agent: Any, system_prompt: Any, input_items: Any) -> None:
116+
call_order.append(f"llm_start")
117+
118+
async def on_llm_end(self, context: Any, agent: Any, response: Any) -> None:
119+
call_order.append(f"llm_end")
120+
121+
async def on_turn_end(self, context: Any, agent: Any, turn_number: int) -> None:
122+
call_order.append(f"turn_end:{turn_number}")
123+
124+
model = FakeModel()
125+
model.set_next_output([get_text_message("hello")])
126+
hooks = OrderTrackingHooks()
127+
agent = Agent(name="A", model=model)
128+
129+
await Runner.run(agent, input="hi", hooks=hooks)
130+
131+
# turn_start must come before llm_start, llm_end before turn_end
132+
ts_idx = call_order.index("turn_start:1")
133+
ls_idx = call_order.index("llm_start")
134+
le_idx = call_order.index("llm_end")
135+
te_idx = call_order.index("turn_end:1")
136+
137+
assert ts_idx < ls_idx
138+
assert ls_idx < le_idx
139+
assert le_idx < te_idx
140+
141+
142+
@pytest.mark.asyncio
143+
async def test_agent_level_on_turn_start_and_end() -> None:
144+
"""Agent-level on_turn_start / on_turn_end hooks are also called."""
145+
model = FakeModel()
146+
model.set_next_output([get_text_message("hello")])
147+
148+
agent_hooks = TurnTrackingAgentHooks()
149+
agent = Agent(name="A", model=model, hooks=agent_hooks)
150+
151+
await Runner.run(agent, input="hi")
152+
153+
assert agent_hooks.turn_starts == [1]
154+
assert agent_hooks.turn_ends == [1]
155+
156+
157+
@pytest.mark.asyncio
158+
async def test_run_and_agent_hooks_both_called() -> None:
159+
"""Both run-level and agent-level hooks fire for the same turn."""
160+
model = FakeModel()
161+
model.set_next_output([get_text_message("hi")])
162+
163+
run_hooks = TurnTrackingRunHooks()
164+
agent_hooks = TurnTrackingAgentHooks()
165+
agent = Agent(name="A", model=model, hooks=agent_hooks)
166+
167+
await Runner.run(agent, input="hi", hooks=run_hooks)
168+
169+
assert run_hooks.turn_starts == [1]
170+
assert run_hooks.turn_ends == [1]
171+
assert agent_hooks.turn_starts == [1]
172+
assert agent_hooks.turn_ends == [1]
173+
174+
175+
@pytest.mark.asyncio
176+
async def test_on_turn_hooks_with_streaming() -> None:
177+
"""on_turn_start and on_turn_end are called when using the streaming runner."""
178+
model = FakeModel()
179+
model.set_next_output([get_text_message("streamed")])
180+
181+
hooks = TurnTrackingRunHooks()
182+
agent = Agent(name="A", model=model)
183+
184+
result = Runner.run_streamed(agent, input="hi", hooks=hooks)
185+
async for _ in result.stream_events():
186+
pass
187+
188+
assert hooks.turn_starts == [1]
189+
assert hooks.turn_ends == [1]

0 commit comments

Comments
 (0)