|
| 1 | +import asyncio |
| 2 | + |
1 | 3 | import pytest |
| 4 | +from mcp.types import Tool as MCPTool |
2 | 5 |
|
3 | | -from agents import Agent, Runner |
| 6 | +from agents import Agent, RunContextWrapper, Runner |
4 | 7 |
|
5 | 8 | from ..fake_model import FakeModel |
6 | 9 | from ..test_responses import get_function_tool_call, get_text_message |
@@ -122,3 +125,96 @@ async def test_mcp_require_approval_mapping_allows_policy_keyword_tool_names(): |
122 | 125 |
|
123 | 126 | second = await Runner.run(agent, "call never") |
124 | 127 | assert not second.interruptions, "tool named 'never' should not require approval" |
| 128 | + |
| 129 | + |
| 130 | +@pytest.mark.asyncio |
| 131 | +async def test_mcp_require_approval_callable_can_allow_and_block_by_tool_name(): |
| 132 | + """Callable policies should decide approval dynamically for each MCP tool.""" |
| 133 | + |
| 134 | + seen: list[str] = [] |
| 135 | + |
| 136 | + def require_approval( |
| 137 | + _run_context: RunContextWrapper[object | None], |
| 138 | + _agent: Agent, |
| 139 | + tool: MCPTool, |
| 140 | + ) -> bool: |
| 141 | + seen.append(tool.name) |
| 142 | + return tool.name == "guarded" |
| 143 | + |
| 144 | + server = FakeMCPServer(require_approval=require_approval) |
| 145 | + server.add_tool("guarded", {"type": "object", "properties": {}}) |
| 146 | + server.add_tool("safe", {"type": "object", "properties": {}}) |
| 147 | + |
| 148 | + model = FakeModel() |
| 149 | + agent = Agent(name="TestAgent", model=model, mcp_servers=[server]) |
| 150 | + |
| 151 | + queue_function_call_and_text( |
| 152 | + model, |
| 153 | + get_function_tool_call("guarded", "{}"), |
| 154 | + followup=[get_text_message("guarded done")], |
| 155 | + ) |
| 156 | + first = await Runner.run(agent, "call guarded") |
| 157 | + assert first.interruptions, "guarded should require approval via callable policy" |
| 158 | + assert first.interruptions[0].tool_name == "guarded" |
| 159 | + |
| 160 | + resumed = await resume_after_first_approval(agent, first, always_approve=True) |
| 161 | + assert resumed.final_output == "guarded done" |
| 162 | + |
| 163 | + queue_function_call_and_text( |
| 164 | + model, |
| 165 | + get_function_tool_call("safe", "{}"), |
| 166 | + followup=[get_text_message("safe done")], |
| 167 | + ) |
| 168 | + second = await Runner.run(agent, "call safe") |
| 169 | + assert not second.interruptions, "safe should bypass approval via callable policy" |
| 170 | + assert second.final_output == "safe done" |
| 171 | + |
| 172 | + assert seen == ["guarded", "guarded", "safe"] |
| 173 | + |
| 174 | + |
| 175 | +@pytest.mark.asyncio |
| 176 | +async def test_mcp_require_approval_async_callable_uses_run_context(): |
| 177 | + """Async callable policies should receive the run context and be awaited.""" |
| 178 | + |
| 179 | + seen_contexts: list[object | None] = [] |
| 180 | + |
| 181 | + async def require_approval( |
| 182 | + run_context: RunContextWrapper[dict[str, bool] | None], |
| 183 | + _agent: Agent, |
| 184 | + _tool, |
| 185 | + ) -> bool: |
| 186 | + seen_contexts.append(run_context.context) |
| 187 | + await asyncio.sleep(0) |
| 188 | + return bool(run_context.context and run_context.context.get("needs_approval")) |
| 189 | + |
| 190 | + server = FakeMCPServer(require_approval=require_approval) |
| 191 | + server.add_tool("conditional", {"type": "object", "properties": {}}) |
| 192 | + |
| 193 | + model = FakeModel() |
| 194 | + agent = Agent(name="TestAgent", model=model, mcp_servers=[server]) |
| 195 | + |
| 196 | + queue_function_call_and_text( |
| 197 | + model, |
| 198 | + get_function_tool_call("conditional", "{}"), |
| 199 | + followup=[get_text_message("approved path")], |
| 200 | + ) |
| 201 | + first = await Runner.run(agent, "call conditional", context={"needs_approval": True}) |
| 202 | + assert first.interruptions, "run context should be able to trigger approval" |
| 203 | + |
| 204 | + resumed = await resume_after_first_approval(agent, first, always_approve=True) |
| 205 | + assert resumed.final_output == "approved path" |
| 206 | + |
| 207 | + queue_function_call_and_text( |
| 208 | + model, |
| 209 | + get_function_tool_call("conditional", "{}"), |
| 210 | + followup=[get_text_message("no approval path")], |
| 211 | + ) |
| 212 | + second = await Runner.run(agent, "call conditional", context={"needs_approval": False}) |
| 213 | + assert not second.interruptions, "run context should be able to skip approval" |
| 214 | + assert second.final_output == "no approval path" |
| 215 | + |
| 216 | + assert seen_contexts == [ |
| 217 | + {"needs_approval": True}, |
| 218 | + {"needs_approval": True}, |
| 219 | + {"needs_approval": False}, |
| 220 | + ] |
0 commit comments