Skip to content

Commit 22c57b1

Browse files
authored
add max turns in REPL (#2431)
1 parent 01b18d0 commit 22c57b1

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

src/agents/repl.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,17 @@
77
from .agent import Agent
88
from .items import TResponseInputItem
99
from .result import RunResultBase
10-
from .run import Runner
10+
from .run import DEFAULT_MAX_TURNS, Runner
1111
from .run_context import TContext
1212
from .stream_events import AgentUpdatedStreamEvent, RawResponsesStreamEvent, RunItemStreamEvent
1313

1414

1515
async def run_demo_loop(
16-
agent: Agent[Any], *, stream: bool = True, context: TContext | None = None
16+
agent: Agent[Any],
17+
*,
18+
stream: bool = True,
19+
context: TContext | None = None,
20+
max_turns: int = DEFAULT_MAX_TURNS,
1721
) -> None:
1822
"""Run a simple REPL loop with the given agent.
1923
@@ -25,6 +29,7 @@ async def run_demo_loop(
2529
agent: The starting agent to run.
2630
stream: Whether to stream the agent output.
2731
context: Additional context information to pass to the runner.
32+
max_turns: Maximum number of turns for the runner to iterate.
2833
"""
2934

3035
current_agent = agent
@@ -44,7 +49,9 @@ async def run_demo_loop(
4449

4550
result: RunResultBase
4651
if stream:
47-
result = Runner.run_streamed(current_agent, input=input_items, context=context)
52+
result = Runner.run_streamed(
53+
current_agent, input=input_items, context=context, max_turns=max_turns
54+
)
4855
async for event in result.stream_events():
4956
if isinstance(event, RawResponsesStreamEvent):
5057
if isinstance(event.data, ResponseTextDeltaEvent):
@@ -58,7 +65,9 @@ async def run_demo_loop(
5865
print(f"\n[Agent updated: {event.new_agent.name}]", flush=True)
5966
print()
6067
else:
61-
result = await Runner.run(current_agent, input_items, context=context)
68+
result = await Runner.run(
69+
current_agent, input_items, context=context, max_turns=max_turns
70+
)
6271
if result.final_output is not None:
6372
print(result.final_output)
6473

0 commit comments

Comments
 (0)