Skip to content

Commit 2531843

Browse files
fix(mcp): validate required params before call_tool (#2453)
1 parent 9e812de commit 2531843

2 files changed

Lines changed: 121 additions & 0 deletions

File tree

src/agents/mcp/server.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,7 @@ async def call_tool(
600600
assert session is not None
601601

602602
try:
603+
self._validate_required_parameters(tool_name=tool_name, arguments=arguments)
603604
if meta is None:
604605
return await self._run_with_retries(lambda: session.call_tool(tool_name, arguments))
605606
return await self._run_with_retries(
@@ -617,6 +618,40 @@ async def call_tool(
617618
f"The server may have disconnected."
618619
) from e
619620

621+
def _validate_required_parameters(
622+
self, tool_name: str, arguments: dict[str, Any] | None
623+
) -> None:
624+
"""Validate required tool parameters from cached MCP tool schemas before invocation."""
625+
if self._tools_list is None:
626+
return
627+
628+
tool = next((item for item in self._tools_list if item.name == tool_name), None)
629+
if tool is None or not isinstance(tool.inputSchema, dict):
630+
return
631+
632+
raw_required = tool.inputSchema.get("required")
633+
if not isinstance(raw_required, list) or not raw_required:
634+
return
635+
636+
if arguments is None:
637+
arguments_to_validate: dict[str, Any] = {}
638+
elif isinstance(arguments, dict):
639+
arguments_to_validate = arguments
640+
else:
641+
raise UserError(
642+
f"Failed to call tool '{tool_name}' on MCP server '{self.name}': "
643+
"arguments must be an object."
644+
)
645+
646+
required_names = [name for name in raw_required if isinstance(name, str)]
647+
missing = [name for name in required_names if name not in arguments_to_validate]
648+
if missing:
649+
missing_text = ", ".join(sorted(missing))
650+
raise UserError(
651+
f"Failed to call tool '{tool_name}' on MCP server '{self.name}': "
652+
f"missing required parameters: {missing_text}"
653+
)
654+
620655
async def list_prompts(
621656
self,
622657
) -> ListPromptsResult:

tests/mcp/test_client_session_retries.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from mcp import ClientSession, Tool as MCPTool
55
from mcp.types import CallToolResult, ListToolsResult
66

7+
from agents.exceptions import UserError
78
from agents.mcp.server import _MCPServerWithClientSession
89

910

@@ -62,3 +63,88 @@ async def test_list_tools_unlimited_retries():
6263
assert len(tools) == 1
6364
assert tools[0].name == "tool"
6465
assert session.list_tools_attempts == 4
66+
67+
68+
@pytest.mark.asyncio
69+
async def test_call_tool_validates_required_parameters_before_remote_call():
70+
session = DummySession()
71+
server = DummyServer(session=session, retries=0)
72+
server._tools_list = [ # noqa: SLF001
73+
MCPTool(
74+
name="tool",
75+
inputSchema={
76+
"type": "object",
77+
"properties": {"param_a": {"type": "string"}},
78+
"required": ["param_a"],
79+
},
80+
)
81+
]
82+
83+
with pytest.raises(UserError, match="missing required parameters: param_a"):
84+
await server.call_tool("tool", {})
85+
86+
assert session.call_tool_attempts == 0
87+
88+
89+
@pytest.mark.asyncio
90+
async def test_call_tool_with_required_parameters_still_calls_remote_tool():
91+
session = DummySession()
92+
server = DummyServer(session=session, retries=0)
93+
server._tools_list = [ # noqa: SLF001
94+
MCPTool(
95+
name="tool",
96+
inputSchema={
97+
"type": "object",
98+
"properties": {"param_a": {"type": "string"}},
99+
"required": ["param_a"],
100+
},
101+
)
102+
]
103+
104+
result = await server.call_tool("tool", {"param_a": "value"})
105+
assert isinstance(result, CallToolResult)
106+
assert session.call_tool_attempts == 1
107+
108+
109+
@pytest.mark.asyncio
110+
async def test_call_tool_skips_validation_when_tool_is_missing_from_cache():
111+
session = DummySession()
112+
server = DummyServer(session=session, retries=0)
113+
server._tools_list = [MCPTool(name="different_tool", inputSchema={"required": ["param_a"]})] # noqa: SLF001
114+
115+
await server.call_tool("tool", {})
116+
assert session.call_tool_attempts == 1
117+
118+
119+
@pytest.mark.asyncio
120+
async def test_call_tool_skips_validation_when_required_list_is_absent():
121+
session = DummySession()
122+
server = DummyServer(session=session, retries=0)
123+
server._tools_list = [MCPTool(name="tool", inputSchema={"type": "object"})] # noqa: SLF001
124+
125+
await server.call_tool("tool", None)
126+
assert session.call_tool_attempts == 1
127+
128+
129+
@pytest.mark.asyncio
130+
async def test_call_tool_validates_required_parameters_when_arguments_is_none():
131+
session = DummySession()
132+
server = DummyServer(session=session, retries=0)
133+
server._tools_list = [MCPTool(name="tool", inputSchema={"required": ["param_a"]})] # noqa: SLF001
134+
135+
with pytest.raises(UserError, match="missing required parameters: param_a"):
136+
await server.call_tool("tool", None)
137+
138+
assert session.call_tool_attempts == 0
139+
140+
141+
@pytest.mark.asyncio
142+
async def test_call_tool_rejects_non_object_arguments_before_remote_call():
143+
session = DummySession()
144+
server = DummyServer(session=session, retries=0)
145+
server._tools_list = [MCPTool(name="tool", inputSchema={"required": ["param_a"]})] # noqa: SLF001
146+
147+
with pytest.raises(UserError, match="arguments must be an object"):
148+
await server.call_tool("tool", cast(dict[str, object] | None, ["bad"]))
149+
150+
assert session.call_tool_attempts == 0

0 commit comments

Comments
 (0)