Skip to content

Commit bac71c6

Browse files
authored
fix: #2729 avoid eager-task race in function tool batch executor (#2731)
1 parent: a727b1f · commit: bac71c6

2 files changed

Lines changed: 80 additions & 17 deletions

File tree

src/agents/run_internal/tool_execution.py

Lines changed: 25 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -1334,10 +1334,15 @@ async def execute(
13341334
)
13351335

13361336
def _create_tool_task(self, tool_run: ToolRunFunction, order: int) -> None:
1337+
task_state = _FunctionToolTaskState(tool_run=tool_run, order=order)
13371338
task = asyncio.create_task(
1338-
self._run_single_tool(tool_run.function_tool, tool_run.tool_call)
1339+
self._run_single_tool(
1340+
task_state=task_state,
1341+
func_tool=tool_run.function_tool,
1342+
tool_call=tool_run.tool_call,
1343+
)
13391344
)
1340-
self.task_states[task] = _FunctionToolTaskState(tool_run=tool_run, order=order)
1345+
self.task_states[task] = task_state
13411346
self.pending_tasks.add(task)
13421347

13431348
async def _drain_pending_tasks(self) -> None:
@@ -1431,13 +1436,14 @@ def _cancel_pending_tasks_for_parent_cancellation(self) -> None:
14311436

14321437
async def _run_single_tool(
14331438
self,
1439+
*,
1440+
task_state: _FunctionToolTaskState,
14341441
func_tool: FunctionTool,
14351442
tool_call: ResponseFunctionToolCall,
14361443
) -> Any:
14371444
raw_tool_call = tool_call
1438-
current_task = asyncio.current_task()
1439-
if current_task is not None:
1440-
self.task_states[current_task].in_post_invoke_phase = False
1445+
outer_task = asyncio.current_task()
1446+
task_state.in_post_invoke_phase = False
14411447

14421448
tool_call = cast(
14431449
ResponseFunctionToolCall,
@@ -1475,7 +1481,8 @@ async def _run_single_tool(
14751481
result = approval_result
14761482
else:
14771483
result = await self._execute_single_tool_body(
1478-
current_task=current_task,
1484+
outer_task=outer_task,
1485+
task_state=task_state,
14791486
func_tool=func_tool,
14801487
tool_call=tool_call,
14811488
tool_context=tool_context,
@@ -1576,7 +1583,8 @@ async def _maybe_execute_tool_approval(
15761583
async def _execute_single_tool_body(
15771584
self,
15781585
*,
1579-
current_task: asyncio.Task[Any] | None,
1586+
outer_task: asyncio.Task[Any] | None,
1587+
task_state: _FunctionToolTaskState,
15801588
func_tool: FunctionTool,
15811589
tool_call: ResponseFunctionToolCall,
15821590
tool_context: ToolContext[Any],
@@ -1602,21 +1610,22 @@ async def _execute_single_tool_body(
16021610

16031611
invoke_task = asyncio.create_task(
16041612
self._invoke_tool_and_run_post_invoke(
1605-
current_task=current_task,
1613+
outer_task=outer_task,
1614+
task_state=task_state,
16061615
func_tool=func_tool,
16071616
tool_call=tool_call,
16081617
tool_context=tool_context,
16091618
agent_hooks=agent_hooks,
16101619
)
16111620
)
1612-
if current_task is not None:
1613-
self.task_states[current_task].invoke_task = invoke_task
1614-
return await self._await_invoke_task(current_task=current_task, invoke_task=invoke_task)
1621+
task_state.invoke_task = invoke_task
1622+
return await self._await_invoke_task(outer_task=outer_task, invoke_task=invoke_task)
16151623

16161624
async def _invoke_tool_and_run_post_invoke(
16171625
self,
16181626
*,
1619-
current_task: asyncio.Task[Any] | None,
1627+
outer_task: asyncio.Task[Any] | None,
1628+
task_state: _FunctionToolTaskState,
16201629
func_tool: FunctionTool,
16211630
tool_call: ResponseFunctionToolCall,
16221631
tool_context: ToolContext[Any],
@@ -1629,7 +1638,7 @@ async def _invoke_tool_and_run_post_invoke(
16291638
arguments=tool_call.arguments,
16301639
)
16311640
except asyncio.CancelledError as e:
1632-
if not self.isolate_parallel_failures or current_task in self.teardown_cancelled_tasks:
1641+
if not self.isolate_parallel_failures or outer_task in self.teardown_cancelled_tasks:
16331642
raise
16341643

16351644
result = await maybe_invoke_function_tool_failure_error_function(
@@ -1648,8 +1657,7 @@ async def _invoke_tool_and_run_post_invoke(
16481657
)
16491658
real_result = result
16501659

1651-
if current_task is not None:
1652-
self.task_states[current_task].in_post_invoke_phase = True
1660+
task_state.in_post_invoke_phase = True
16531661

16541662
final_result = await _execute_tool_output_guardrails(
16551663
func_tool=func_tool,
@@ -1672,14 +1680,14 @@ async def _invoke_tool_and_run_post_invoke(
16721680
async def _await_invoke_task(
16731681
self,
16741682
*,
1675-
current_task: asyncio.Task[Any] | None,
1683+
outer_task: asyncio.Task[Any] | None,
16761684
invoke_task: asyncio.Task[Any],
16771685
) -> Any:
16781686
try:
16791687
return await asyncio.shield(invoke_task)
16801688
except asyncio.CancelledError as cancel_exc:
16811689
sibling_failure_cancelled = (
1682-
current_task is not None and current_task in self.teardown_cancelled_tasks
1690+
outer_task is not None and outer_task in self.teardown_cancelled_tasks
16831691
)
16841692
if not invoke_task.done():
16851693
invoke_task.cancel()

tests/test_run_step_execution.py

Lines changed: 55 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1163,6 +1163,61 @@ def _failure_handler(_ctx: RunContextWrapper[Any], error: Exception) -> str:
11631163
assert not on_tool_end_called.is_set()
11641164

11651165

1166+
@pytest.mark.asyncio
1167+
@pytest.mark.skipif(
1168+
not hasattr(asyncio, "eager_task_factory"),
1169+
reason="eager_task_factory requires Python 3.12+",
1170+
)
1171+
async def test_execute_function_tool_calls_eager_task_factory_tracks_state_safely():
1172+
async def _first_tool() -> str:
1173+
return "first"
1174+
1175+
async def _second_tool() -> str:
1176+
return "second"
1177+
1178+
first_tool = function_tool(_first_tool, name_override="first_tool")
1179+
second_tool = function_tool(_second_tool, name_override="second_tool")
1180+
tool_runs = [
1181+
ToolRunFunction(
1182+
tool_call=cast(
1183+
ResponseFunctionToolCall,
1184+
get_function_tool_call("first_tool", "{}", call_id="call-1"),
1185+
),
1186+
function_tool=first_tool,
1187+
),
1188+
ToolRunFunction(
1189+
tool_call=cast(
1190+
ResponseFunctionToolCall,
1191+
get_function_tool_call("second_tool", "{}", call_id="call-2"),
1192+
),
1193+
function_tool=second_tool,
1194+
),
1195+
]
1196+
loop = asyncio.get_running_loop()
1197+
previous_task_factory = loop.get_task_factory()
1198+
eager_task_factory = cast(Any, asyncio.eager_task_factory)
1199+
loop.set_task_factory(eager_task_factory)
1200+
1201+
try:
1202+
(
1203+
function_results,
1204+
input_guardrail_results,
1205+
output_guardrail_results,
1206+
) = await execute_function_tool_calls(
1207+
agent=Agent(name="test", tools=[first_tool, second_tool]),
1208+
tool_runs=tool_runs,
1209+
hooks=RunHooks(),
1210+
context_wrapper=RunContextWrapper(None),
1211+
config=RunConfig(),
1212+
)
1213+
finally:
1214+
loop.set_task_factory(previous_task_factory)
1215+
1216+
assert [result.output for result in function_results] == ["first", "second"]
1217+
assert input_guardrail_results == []
1218+
assert output_guardrail_results == []
1219+
1220+
11661221
@pytest.mark.asyncio
11671222
async def test_execute_function_tool_calls_collapse_trace_name_for_top_level_deferred_tools():
11681223
async def _shipping_eta(tracking_number: str) -> str:

0 commit comments

Comments (0)