Skip to content

Commit 976bcce

Browse files
authored
feat: openai#2807 support callable approval policies for local MCP servers (openai#2818)
1 parent 051c2ea commit 976bcce

4 files changed

Lines changed: 204 additions & 6 deletions

File tree

src/agents/mcp/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
try:
22
from .manager import MCPServerManager
33
from .server import (
4+
LocalMCPApprovalCallable,
45
MCPServer,
56
MCPServerSse,
67
MCPServerSseParams,
@@ -32,6 +33,7 @@
3233
"MCPServerStreamableHttp",
3334
"MCPServerStreamableHttpParams",
3435
"MCPServerManager",
36+
"LocalMCPApprovalCallable",
3537
"MCPUtil",
3638
"MCPToolMetaContext",
3739
"MCPToolMetaResolver",

src/agents/mcp/server.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,31 @@ class RequireApprovalObject(TypedDict, total=False):
6363

6464
RequireApprovalPolicy = Literal["always", "never"]
6565
RequireApprovalMapping = dict[str, RequireApprovalPolicy]
66+
if TYPE_CHECKING:
67+
LocalMCPApprovalCallable = Callable[
68+
[RunContextWrapper[Any], "AgentBase", MCPTool],
69+
MaybeAwaitable[bool],
70+
]
71+
else:
72+
LocalMCPApprovalCallable = Callable[..., Any]
73+
6674
if TYPE_CHECKING:
6775
RequireApprovalSetting = (
68-
RequireApprovalPolicy | RequireApprovalObject | RequireApprovalMapping | bool | None
76+
RequireApprovalPolicy
77+
| RequireApprovalObject
78+
| RequireApprovalMapping
79+
| LocalMCPApprovalCallable
80+
| bool
81+
| None
6982
)
7083
else:
7184
RequireApprovalSetting = Union[ # noqa: UP007
72-
RequireApprovalPolicy, RequireApprovalObject, RequireApprovalMapping, bool, None
85+
RequireApprovalPolicy,
86+
RequireApprovalObject,
87+
RequireApprovalMapping,
88+
LocalMCPApprovalCallable,
89+
bool,
90+
None,
7391
]
7492

7593

@@ -220,8 +238,10 @@ def __init__(
220238
default will cause duplicate content. You can set this to True if you know the
221239
server will not duplicate the structured content in the `tool_result.content`.
222240
require_approval: Approval policy for tools on this server. Accepts "always"/"never",
223-
a dict of tool names to those values, a boolean, or an object with always/never
224-
tool lists (mirroring TS requireApproval). Normalized into a needs_approval policy.
241+
a dict of tool names to those values, a boolean, an object with always/never
242+
tool lists (mirroring TS requireApproval), or a sync/async callable that receives
243+
`(run_context, agent, tool)` and returns whether the tool call needs approval.
244+
Normalized into a needs_approval policy.
225245
failure_error_function: Optional function used to convert MCP tool failures into
226246
a model-visible error message. If explicitly set to None, tool errors will be
227247
raised instead of converted. If left unset, the agent-level configuration (or
@@ -408,6 +428,9 @@ def _is_tool_list_schema(value: object) -> bool:
408428
tool_mapping[str(name)] = _to_bool(value)
409429
return tool_mapping
410430

431+
if callable(require_approval):
432+
return require_approval
433+
411434
if isinstance(require_approval, bool):
412435
return require_approval
413436

@@ -418,7 +441,12 @@ def _get_needs_approval_for_tool(
418441
tool: MCPTool,
419442
agent: AgentBase | None,
420443
) -> bool | Callable[[RunContextWrapper[Any], dict[str, Any], str], Awaitable[bool]]:
421-
"""Return a FunctionTool.needs_approval value for a given MCP tool."""
444+
"""Return a FunctionTool.needs_approval value for a given MCP tool.
445+
446+
Legacy callers may omit ``agent`` when using ``MCPUtil.to_function_tool()`` directly.
447+
When approval is configured with a callable policy and no agent is available, this method
448+
returns ``True`` to preserve the historical fail-closed behavior.
449+
"""
422450

423451
policy = self._needs_approval_policy
424452

tests/mcp/test_mcp_approval.py

Lines changed: 97 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
import asyncio
2+
13
import pytest
4+
from mcp.types import Tool as MCPTool
25

3-
from agents import Agent, Runner
6+
from agents import Agent, RunContextWrapper, Runner
47

58
from ..fake_model import FakeModel
69
from ..test_responses import get_function_tool_call, get_text_message
@@ -122,3 +125,96 @@ async def test_mcp_require_approval_mapping_allows_policy_keyword_tool_names():
122125

123126
second = await Runner.run(agent, "call never")
124127
assert not second.interruptions, "tool named 'never' should not require approval"
128+
129+
130+
@pytest.mark.asyncio
131+
async def test_mcp_require_approval_callable_can_allow_and_block_by_tool_name():
132+
"""Callable policies should decide approval dynamically for each MCP tool."""
133+
134+
seen: list[str] = []
135+
136+
def require_approval(
137+
_run_context: RunContextWrapper[object | None],
138+
_agent: Agent,
139+
tool: MCPTool,
140+
) -> bool:
141+
seen.append(tool.name)
142+
return tool.name == "guarded"
143+
144+
server = FakeMCPServer(require_approval=require_approval)
145+
server.add_tool("guarded", {"type": "object", "properties": {}})
146+
server.add_tool("safe", {"type": "object", "properties": {}})
147+
148+
model = FakeModel()
149+
agent = Agent(name="TestAgent", model=model, mcp_servers=[server])
150+
151+
queue_function_call_and_text(
152+
model,
153+
get_function_tool_call("guarded", "{}"),
154+
followup=[get_text_message("guarded done")],
155+
)
156+
first = await Runner.run(agent, "call guarded")
157+
assert first.interruptions, "guarded should require approval via callable policy"
158+
assert first.interruptions[0].tool_name == "guarded"
159+
160+
resumed = await resume_after_first_approval(agent, first, always_approve=True)
161+
assert resumed.final_output == "guarded done"
162+
163+
queue_function_call_and_text(
164+
model,
165+
get_function_tool_call("safe", "{}"),
166+
followup=[get_text_message("safe done")],
167+
)
168+
second = await Runner.run(agent, "call safe")
169+
assert not second.interruptions, "safe should bypass approval via callable policy"
170+
assert second.final_output == "safe done"
171+
172+
assert seen == ["guarded", "guarded", "safe"]
173+
174+
175+
@pytest.mark.asyncio
176+
async def test_mcp_require_approval_async_callable_uses_run_context():
177+
"""Async callable policies should receive the run context and be awaited."""
178+
179+
seen_contexts: list[object | None] = []
180+
181+
async def require_approval(
182+
run_context: RunContextWrapper[dict[str, bool] | None],
183+
_agent: Agent,
184+
_tool,
185+
) -> bool:
186+
seen_contexts.append(run_context.context)
187+
await asyncio.sleep(0)
188+
return bool(run_context.context and run_context.context.get("needs_approval"))
189+
190+
server = FakeMCPServer(require_approval=require_approval)
191+
server.add_tool("conditional", {"type": "object", "properties": {}})
192+
193+
model = FakeModel()
194+
agent = Agent(name="TestAgent", model=model, mcp_servers=[server])
195+
196+
queue_function_call_and_text(
197+
model,
198+
get_function_tool_call("conditional", "{}"),
199+
followup=[get_text_message("approved path")],
200+
)
201+
first = await Runner.run(agent, "call conditional", context={"needs_approval": True})
202+
assert first.interruptions, "run context should be able to trigger approval"
203+
204+
resumed = await resume_after_first_approval(agent, first, always_approve=True)
205+
assert resumed.final_output == "approved path"
206+
207+
queue_function_call_and_text(
208+
model,
209+
get_function_tool_call("conditional", "{}"),
210+
followup=[get_text_message("no approval path")],
211+
)
212+
second = await Runner.run(agent, "call conditional", context={"needs_approval": False})
213+
assert not second.interruptions, "run context should be able to skip approval"
214+
assert second.final_output == "no approval path"
215+
216+
assert seen_contexts == [
217+
{"needs_approval": True},
218+
{"needs_approval": True},
219+
{"needs_approval": False},
220+
]

tests/mcp/test_mcp_util.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,78 @@ def require_approval(
676676
assert function_tool.needs_approval is True
677677

678678

679+
@pytest.mark.asyncio
680+
async def test_to_function_tool_callable_policy_uses_agent_and_tool():
681+
"""Callable require_approval policies should bridge into FunctionTool.needs_approval."""
682+
683+
captured: dict[str, Any] = {}
684+
685+
def require_approval(
686+
run_context: RunContextWrapper[Any],
687+
agent: Agent,
688+
tool: MCPTool,
689+
) -> bool:
690+
captured["run_context"] = run_context
691+
captured["agent"] = agent
692+
captured["tool"] = tool
693+
return tool.name == "guarded_tool"
694+
695+
server = FakeMCPServer(require_approval=require_approval)
696+
tool = MCPTool(name="guarded_tool", inputSchema={})
697+
agent = Agent(name="test-agent")
698+
699+
function_tool = MCPUtil.to_function_tool(
700+
tool,
701+
server,
702+
convert_schemas_to_strict=False,
703+
agent=agent,
704+
)
705+
706+
assert callable(function_tool.needs_approval)
707+
708+
run_context = RunContextWrapper(context={"request_id": "req_123"})
709+
needs_approval = await function_tool.needs_approval(run_context, {}, "call_123")
710+
711+
assert needs_approval is True
712+
assert captured["run_context"] is run_context
713+
assert captured["agent"] is agent
714+
assert captured["tool"].name == "guarded_tool"
715+
716+
717+
@pytest.mark.asyncio
718+
async def test_to_function_tool_async_callable_policy_is_awaited():
719+
"""Async require_approval policies should be awaited before tool execution."""
720+
721+
async def require_approval(
722+
_run_context: RunContextWrapper[Any],
723+
_agent: Agent,
724+
tool: MCPTool,
725+
) -> bool:
726+
await asyncio.sleep(0)
727+
return tool.name == "async_guarded_tool"
728+
729+
server = FakeMCPServer(require_approval=require_approval)
730+
tool = MCPTool(name="async_guarded_tool", inputSchema={})
731+
agent = Agent(name="test-agent")
732+
733+
function_tool = MCPUtil.to_function_tool(
734+
tool,
735+
server,
736+
convert_schemas_to_strict=False,
737+
agent=agent,
738+
)
739+
740+
assert callable(function_tool.needs_approval)
741+
742+
needs_approval = await function_tool.needs_approval(
743+
RunContextWrapper(context=None),
744+
{},
745+
"call_async_123",
746+
)
747+
748+
assert needs_approval is True
749+
750+
679751
@pytest.mark.asyncio
680752
async def test_mcp_tool_failure_error_function_agent_default():
681753
"""Agent-level failure_error_function should handle MCP tool failures."""

0 commit comments

Comments
 (0)