Skip to content

Commit 01b18d0

Browse files
authored
fix: preserve handoff target resolution compatibility in run-state agent maps (#2423)
1 parent b39ae9c commit 01b18d0

File tree

3 files changed

+131
-7
lines changed

3 files changed

+131
-7
lines changed

src/agents/handoffs/__init__.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22

33
import inspect
44
import json
5+
import weakref
56
from collections.abc import Awaitable
6-
from dataclasses import dataclass, replace as dataclasses_replace
7+
from dataclasses import dataclass, field, replace as dataclasses_replace
78
from typing import TYPE_CHECKING, Any, Callable, Generic, cast, overload
89

910
from pydantic import TypeAdapter
@@ -148,6 +149,11 @@ class Handoff(Generic[TContext, TAgent]):
148149
context or state.
149150
"""
150151

152+
_agent_ref: weakref.ReferenceType[AgentBase[Any]] | None = field(
153+
default=None, init=False, repr=False
154+
)
155+
"""Weak reference to the target agent when constructed via `handoff()`."""
156+
151157
def get_transfer_message(self, agent: AgentBase[Any]) -> str:
152158
return json.dumps({"assistant": agent.name})
153159

@@ -300,7 +306,7 @@ async def _is_enabled(ctx: RunContextWrapper[Any], agent_base: AgentBase[Any]) -
300306
return await result
301307
return bool(result)
302308

303-
return Handoff(
309+
handoff_obj = Handoff(
304310
tool_name=tool_name,
305311
tool_description=tool_description,
306312
input_json_schema=input_json_schema,
@@ -310,6 +316,8 @@ async def _is_enabled(ctx: RunContextWrapper[Any], agent_base: AgentBase[Any]) -
310316
agent_name=agent.name,
311317
is_enabled=_is_enabled if callable(is_enabled) else is_enabled,
312318
)
319+
handoff_obj._agent_ref = weakref.ref(agent)
320+
return handoff_obj
313321

314322

315323
__all__ = [

src/agents/run_state.py

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1898,11 +1898,63 @@ def _build_agent_map(initial_agent: Agent[Any]) -> dict[str, Agent[Any]]:
18981898
agent_map[current.name] = current
18991899

19001900
# Add handoff agents to the queue
1901-
for handoff in current.handoffs:
1902-
# Handoff can be either an Agent or a Handoff object with an .agent attribute
1903-
handoff_agent = handoff if not hasattr(handoff, "agent") else handoff.agent
1904-
if handoff_agent and handoff_agent.name not in agent_map: # type: ignore[union-attr]
1905-
queue.append(handoff_agent) # type: ignore[arg-type]
1901+
for handoff_item in current.handoffs:
1902+
handoff_agent: Any | None = None
1903+
handoff_agent_name: str | None = None
1904+
1905+
if isinstance(handoff_item, Handoff):
1906+
# Some custom/mocked Handoff subclasses bypass dataclass initialization.
1907+
# Prefer agent_name, then legacy name fallback used in tests.
1908+
candidate_name = getattr(handoff_item, "agent_name", None) or getattr(
1909+
handoff_item, "name", None
1910+
)
1911+
if isinstance(candidate_name, str):
1912+
handoff_agent_name = candidate_name
1913+
if handoff_agent_name in agent_map:
1914+
continue
1915+
1916+
handoff_ref = getattr(handoff_item, "_agent_ref", None)
1917+
handoff_agent = handoff_ref() if callable(handoff_ref) else None
1918+
if handoff_agent is None:
1919+
# Backward-compatibility fallback for custom legacy handoff objects that store
1920+
# the target directly on `.agent`. New code should prefer `handoff()` objects.
1921+
legacy_agent = getattr(handoff_item, "agent", None)
1922+
if legacy_agent is not None:
1923+
handoff_agent = legacy_agent
1924+
logger.debug(
1925+
"Using legacy handoff `.agent` fallback while building agent map. "
1926+
"This compatibility path is not recommended for new code."
1927+
)
1928+
if handoff_agent_name is None:
1929+
candidate_name = getattr(handoff_agent, "name", None)
1930+
handoff_agent_name = candidate_name if isinstance(candidate_name, str) else None
1931+
if handoff_agent is None or not hasattr(handoff_agent, "handoffs"):
1932+
if handoff_agent_name:
1933+
logger.debug(
1934+
"Skipping unresolved handoff target while building agent map: %s",
1935+
handoff_agent_name,
1936+
)
1937+
continue
1938+
else:
1939+
# Backward-compatibility fallback for custom legacy handoff wrappers that expose
1940+
# the target directly on `.agent` without inheriting from `Handoff`.
1941+
legacy_agent = getattr(handoff_item, "agent", None)
1942+
if legacy_agent is not None:
1943+
handoff_agent = legacy_agent
1944+
logger.debug(
1945+
"Using legacy non-`Handoff` `.agent` fallback while building agent map."
1946+
)
1947+
else:
1948+
handoff_agent = handoff_item
1949+
candidate_name = getattr(handoff_agent, "name", None)
1950+
handoff_agent_name = candidate_name if isinstance(candidate_name, str) else None
1951+
1952+
if (
1953+
handoff_agent is not None
1954+
and handoff_agent_name
1955+
and handoff_agent_name not in agent_map
1956+
):
1957+
queue.append(cast(Any, handoff_agent))
19061958

19071959
# Include agent-as-tool instances so nested approvals can be restored.
19081960
tools = getattr(current, "tools", None)

tests/test_run_state.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -876,6 +876,70 @@ def test_build_agent_map_handles_complex_handoff_graphs(self):
876876
assert len(agent_map) == 4
877877
assert all(agent_map.get(name) is not None for name in ["A", "B", "C", "D"])
878878

879+
def test_build_agent_map_handles_handoff_objects(self):
880+
"""Test that buildAgentMap resolves handoff() objects via weak references."""
881+
agent_a = Agent(name="AgentA")
882+
agent_b = Agent(name="AgentB")
883+
agent_a.handoffs = [handoff(agent_b)]
884+
885+
agent_map = _build_agent_map(agent_a)
886+
887+
assert sorted(agent_map.keys()) == ["AgentA", "AgentB"]
888+
889+
def test_build_agent_map_supports_legacy_handoff_agent_attribute(self):
890+
"""Test that buildAgentMap keeps legacy custom handoffs with `.agent` targets working."""
891+
agent_a = Agent(name="AgentA")
892+
agent_b = Agent(name="AgentB")
893+
894+
class LegacyHandoff(Handoff):
895+
def __init__(self, target: Agent[Any]):
896+
# Legacy custom handoff shape supported only for backward compatibility.
897+
self.agent = target
898+
self.agent_name = target.name
899+
self.name = "legacy_handoff"
900+
901+
agent_a.handoffs = [LegacyHandoff(agent_b)]
902+
903+
agent_map = _build_agent_map(agent_a)
904+
905+
assert sorted(agent_map.keys()) == ["AgentA", "AgentB"]
906+
907+
def test_build_agent_map_supports_legacy_non_handoff_agent_wrapper(self):
908+
"""Test that buildAgentMap supports legacy non-Handoff wrappers with `.agent` targets."""
909+
agent_a = Agent(name="AgentA")
910+
agent_b = Agent(name="AgentB")
911+
912+
class LegacyWrapper:
913+
def __init__(self, target: Agent[Any]):
914+
self.agent = target
915+
916+
agent_a.handoffs = [LegacyWrapper(agent_b)] # type: ignore[list-item]
917+
918+
agent_map = _build_agent_map(agent_a)
919+
920+
assert sorted(agent_map.keys()) == ["AgentA", "AgentB"]
921+
922+
def test_build_agent_map_skips_unresolved_handoff_objects(self):
923+
"""Test that buildAgentMap skips custom handoffs without target agent references."""
924+
agent_a = Agent(name="AgentA")
925+
agent_b = Agent(name="AgentB")
926+
927+
async def _invoke_handoff(_ctx: RunContextWrapper[Any], _input: str) -> Agent[Any]:
928+
return agent_b
929+
930+
detached_handoff = Handoff(
931+
tool_name="transfer_to_agent_b",
932+
tool_description="Transfer to AgentB.",
933+
input_json_schema={},
934+
on_invoke_handoff=_invoke_handoff,
935+
agent_name=agent_b.name,
936+
)
937+
agent_a.handoffs = [detached_handoff]
938+
939+
agent_map = _build_agent_map(agent_a)
940+
941+
assert sorted(agent_map.keys()) == ["AgentA"]
942+
879943

880944
class TestSerializationRoundTrip:
881945
"""Test that serialization and deserialization preserve state correctly."""

0 commit comments

Comments
 (0)