1+ import sys
12from unittest .mock import Mock
23
34import graphviz # type: ignore
1213)
1314from agents .handoffs import Handoff
1415
16+ if sys .version_info >= (3 , 10 ):
17+ from .mcp .helpers import FakeMCPServer
18+
1519
1620@pytest .fixture
1721def mock_agent ():
@@ -27,6 +31,10 @@ def mock_agent():
2731 agent .name = "Agent1"
2832 agent .tools = [tool1 , tool2 ]
2933 agent .handoffs = [handoff1 ]
34+ agent .mcp_servers = []
35+
36+ if sys .version_info >= (3 , 10 ):
37+ agent .mcp_servers = [FakeMCPServer (server_name = "MCPServer1" )]
3038
3139 return agent
3240
@@ -62,6 +70,7 @@ def test_get_main_graph(mock_agent):
6270 '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, '
6371 "fillcolor=lightyellow, width=1.5, height=0.8];" in result
6472 )
73+ _assert_mcp_nodes (result )
6574
6675
6776def test_get_all_nodes (mock_agent ):
@@ -90,6 +99,7 @@ def test_get_all_nodes(mock_agent):
9099 '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, '
91100 "fillcolor=lightyellow, width=1.5, height=0.8];" in result
92101 )
102+ _assert_mcp_nodes (result )
93103
94104
95105def test_get_all_edges (mock_agent ):
@@ -101,6 +111,7 @@ def test_get_all_edges(mock_agent):
101111 assert '"Agent1" -> "Tool2" [style=dotted, penwidth=1.5];' in result
102112 assert '"Tool2" -> "Agent1" [style=dotted, penwidth=1.5];' in result
103113 assert '"Agent1" -> "Handoff1";' in result
114+ _assert_mcp_edges (result )
104115
105116
106117def test_draw_graph (mock_agent ):
@@ -134,6 +145,25 @@ def test_draw_graph(mock_agent):
134145 '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, '
135146 "fillcolor=lightyellow, width=1.5, height=0.8];" in graph .source
136147 )
148+ _assert_mcp_nodes (graph .source )
149+
150+
151+ def _assert_mcp_nodes (source : str ):
152+ if sys .version_info < (3 , 10 ):
153+ assert "MCPServer1" not in source
154+ return
155+ assert (
156+ '"MCPServer1" [label="MCPServer1", shape=box, style=filled, '
157+ "fillcolor=lightgrey, width=1, height=0.5];" in source
158+ )
159+
160+
161+ def _assert_mcp_edges (source : str ):
162+ if sys .version_info < (3 , 10 ):
163+ assert "MCPServer1" not in source
164+ return
165+ assert '"Agent1" -> "MCPServer1" [style=dashed, penwidth=1.5];' in source
166+ assert '"MCPServer1" -> "Agent1" [style=dashed, penwidth=1.5];' in source
137167
138168
139169def test_cycle_detection ():
0 commit comments