Skip to content

Commit 3307f96

Browse files
committed
feat: implement tool filter interface for MCP servers
1 parent f553f20 commit 3307f96

File tree

5 files changed

+194
-52
lines changed

5 files changed

+194
-52
lines changed

src/agents/mcp/__init__.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,14 @@
1111
except ImportError:
1212
pass
1313

14-
from .util import MCPUtil
14+
from .util import (
15+
MCPUtil,
16+
ToolFilter,
17+
ToolFilterCallable,
18+
ToolFilterContext,
19+
ToolFilterStatic,
20+
create_static_tool_filter,
21+
)
1522

1623
__all__ = [
1724
"MCPServer",
@@ -22,4 +29,9 @@
2229
"MCPServerStreamableHttp",
2330
"MCPServerStreamableHttpParams",
2431
"MCPUtil",
32+
"ToolFilter",
33+
"ToolFilterCallable",
34+
"ToolFilterContext",
35+
"ToolFilterStatic",
36+
"create_static_tool_filter",
2537
]

src/agents/mcp/server.py

Lines changed: 48 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from ..exceptions import UserError
1919
from ..logger import logger
20+
from .util import ToolFilter, ToolFilterStatic
2021

2122

2223
class MCPServer(abc.ABC):
@@ -61,8 +62,7 @@ def __init__(
6162
self,
6263
cache_tools_list: bool,
6364
client_session_timeout_seconds: float | None,
64-
allowed_tools: list[str] | None = None,
65-
excluded_tools: list[str] | None = None,
65+
tool_filter: ToolFilter = None,
6666
):
6767
"""
6868
Args:
@@ -74,10 +74,7 @@ def __init__(
7474
(by avoiding a round-trip to the server every time).
7575
7676
client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
77-
allowed_tools: Optional list of tool names to allow (whitelist).
78-
If set, only these tools will be available.
79-
excluded_tools: Optional list of tool names to exclude (blacklist).
80-
If set, these tools will be filtered out.
77+
tool_filter: The tool filter to use for filtering tools.
8178
"""
8279
self.session: ClientSession | None = None
8380
self.exit_stack: AsyncExitStack = AsyncExitStack()
@@ -91,8 +88,39 @@ def __init__(
9188
self._cache_dirty = True
9289
self._tools_list: list[MCPTool] | None = None
9390

94-
self.allowed_tools = allowed_tools
95-
self.excluded_tools = excluded_tools
91+
self.tool_filter = tool_filter
92+
93+
def _apply_tool_filter(self, tools: list[MCPTool]) -> list[MCPTool]:
94+
"""Apply the tool filter to the list of tools."""
95+
if self.tool_filter is None:
96+
return tools
97+
98+
# Handle static tool filter
99+
if isinstance(self.tool_filter, dict):
100+
static_filter: ToolFilterStatic = self.tool_filter
101+
filtered_tools = tools
102+
103+
# Apply allowed_tool_names filter (whitelist)
104+
if "allowed_tool_names" in static_filter:
105+
allowed_names = static_filter["allowed_tool_names"]
106+
filtered_tools = [t for t in filtered_tools if t.name in allowed_names]
107+
108+
# Apply blocked_tool_names filter (blacklist)
109+
if "blocked_tool_names" in static_filter:
110+
blocked_names = static_filter["blocked_tool_names"]
111+
filtered_tools = [t for t in filtered_tools if t.name not in blocked_names]
112+
113+
return filtered_tools
114+
115+
# Handle callable tool filter
116+
# For now, we can't support callable filters because we don't have access to
117+
# run context and agent in the current list_tools signature.
118+
# This could be enhanced in the future by modifying the call chain.
119+
else:
120+
raise NotImplementedError(
121+
"Callable tool filters are not yet supported. Please use ToolFilterStatic "
122+
"with 'allowed_tool_names' and/or 'blocked_tool_names' for now."
123+
)
96124

97125
@abc.abstractmethod
98126
def create_streams(
@@ -159,12 +187,10 @@ async def list_tools(self) -> list[MCPTool]:
159187
self._tools_list = (await self.session.list_tools()).tools
160188
tools = self._tools_list
161189

162-
# Filter tools based on allowed and excluded tools
190+
# Filter tools based on tool_filter
163191
filtered_tools = tools
164-
if self.allowed_tools is not None:
165-
filtered_tools = [t for t in filtered_tools if t.name in self.allowed_tools]
166-
if self.excluded_tools is not None:
167-
filtered_tools = [t for t in filtered_tools if t.name not in self.excluded_tools]
192+
if self.tool_filter is not None:
193+
filtered_tools = self._apply_tool_filter(filtered_tools)
168194
return filtered_tools
169195

170196
async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult:
@@ -226,8 +252,7 @@ def __init__(
226252
cache_tools_list: bool = False,
227253
name: str | None = None,
228254
client_session_timeout_seconds: float | None = 5,
229-
allowed_tools: list[str] | None = None,
230-
excluded_tools: list[str] | None = None,
255+
tool_filter: ToolFilter = None,
231256
):
232257
"""Create a new MCP server based on the stdio transport.
233258
@@ -245,14 +270,12 @@ def __init__(
245270
name: A readable name for the server. If not provided, we'll create one from the
246271
command.
247272
client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
248-
allowed_tools: Optional list of tool names to allow (whitelist).
249-
excluded_tools: Optional list of tool names to exclude (blacklist).
273+
tool_filter: The tool filter to use for filtering tools.
250274
"""
251275
super().__init__(
252276
cache_tools_list,
253277
client_session_timeout_seconds,
254-
allowed_tools,
255-
excluded_tools,
278+
tool_filter,
256279
)
257280

258281
self.params = StdioServerParameters(
@@ -312,8 +335,7 @@ def __init__(
312335
cache_tools_list: bool = False,
313336
name: str | None = None,
314337
client_session_timeout_seconds: float | None = 5,
315-
allowed_tools: list[str] | None = None,
316-
excluded_tools: list[str] | None = None,
338+
tool_filter: ToolFilter = None,
317339
):
318340
"""Create a new MCP server based on the HTTP with SSE transport.
319341
@@ -333,14 +355,12 @@ def __init__(
333355
URL.
334356
335357
client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
336-
allowed_tools: Optional list of tool names to allow (whitelist).
337-
excluded_tools: Optional list of tool names to exclude (blacklist).
358+
tool_filter: The tool filter to use for filtering tools.
338359
"""
339360
super().__init__(
340361
cache_tools_list,
341362
client_session_timeout_seconds,
342-
allowed_tools,
343-
excluded_tools,
363+
tool_filter,
344364
)
345365

346366
self.params = params
@@ -400,8 +420,7 @@ def __init__(
400420
cache_tools_list: bool = False,
401421
name: str | None = None,
402422
client_session_timeout_seconds: float | None = 5,
403-
allowed_tools: list[str] | None = None,
404-
excluded_tools: list[str] | None = None,
423+
tool_filter: ToolFilter = None,
405424
):
406425
"""Create a new MCP server based on the Streamable HTTP transport.
407426
@@ -422,14 +441,12 @@ def __init__(
422441
URL.
423442
424443
client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
425-
allowed_tools: Optional list of tool names to allow (whitelist).
426-
excluded_tools: Optional list of tool names to exclude (blacklist).
444+
tool_filter: The tool filter to use for filtering tools.
427445
"""
428446
super().__init__(
429447
cache_tools_list,
430448
client_session_timeout_seconds,
431-
allowed_tools,
432-
excluded_tools,
449+
tool_filter,
433450
)
434451

435452
self.params = params

src/agents/mcp/util.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import functools
22
import json
3-
from typing import TYPE_CHECKING, Any
3+
from dataclasses import dataclass
4+
from typing import TYPE_CHECKING, Any, Callable, Union
5+
from typing_extensions import NotRequired, TypedDict
46

57
from agents.strict_schema import ensure_strict_json_schema
68

@@ -10,13 +12,82 @@
1012
from ..run_context import RunContextWrapper
1113
from ..tool import FunctionTool, Tool
1214
from ..tracing import FunctionSpanData, get_current_span, mcp_tools_span
15+
from ..util._types import MaybeAwaitable
1316

1417
if TYPE_CHECKING:
1518
from mcp.types import Tool as MCPTool
1619

20+
from ..agent import Agent
1721
from .server import MCPServer
1822

1923

24+
@dataclass
25+
class ToolFilterContext:
26+
"""Context information available to tool filter functions."""
27+
28+
run_context: RunContextWrapper[Any]
29+
"""The current run context."""
30+
31+
agent: "Agent[Any]"
32+
"""The agent that is requesting the tool list."""
33+
34+
server_name: str
35+
"""The name of the MCP server."""
36+
37+
38+
ToolFilterCallable = Callable[["ToolFilterContext", "MCPTool"], MaybeAwaitable[bool]]
39+
"""A function that determines whether a tool should be available.
40+
41+
Args:
42+
context: The context information including run context, agent, and server name.
43+
tool: The MCP tool to filter.
44+
45+
Returns:
46+
Whether the tool should be available (True) or filtered out (False).
47+
"""
48+
49+
50+
class ToolFilterStatic(TypedDict):
51+
"""Static tool filter configuration using allowlists and blocklists."""
52+
53+
allowed_tool_names: NotRequired[list[str]]
54+
"""Optional list of tool names to allow (whitelist). If set, only these tools will be available."""
55+
56+
blocked_tool_names: NotRequired[list[str]]
57+
"""Optional list of tool names to exclude (blacklist). If set, these tools will be filtered out."""
58+
59+
60+
ToolFilter = Union[ToolFilterCallable, ToolFilterStatic, None]
61+
"""A tool filter that can be either a function, static configuration, or None (no filtering)."""
62+
63+
64+
def create_static_tool_filter(
65+
allowed_tool_names: list[str] | None = None,
66+
blocked_tool_names: list[str] | None = None,
67+
) -> ToolFilterStatic | None:
68+
"""Create a static tool filter from allowlist and blocklist parameters.
69+
70+
This is a convenience function for creating a ToolFilterStatic.
71+
72+
Args:
73+
allowed_tool_names: Optional list of tool names to allow (whitelist).
74+
blocked_tool_names: Optional list of tool names to exclude (blacklist).
75+
76+
Returns:
77+
A ToolFilterStatic if any filtering is specified, None otherwise.
78+
"""
79+
if allowed_tool_names is None and blocked_tool_names is None:
80+
return None
81+
82+
filter_dict: ToolFilterStatic = {}
83+
if allowed_tool_names is not None:
84+
filter_dict["allowed_tool_names"] = allowed_tool_names
85+
if blocked_tool_names is not None:
86+
filter_dict["blocked_tool_names"] = blocked_tool_names
87+
88+
return filter_dict
89+
90+
2091
class MCPUtil:
2192
"""Set of utilities for interop between MCP and Agents SDK tools."""
2293

src/agents/tool.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
if TYPE_CHECKING:
2929
from .agent import Agent
30+
from mcp.types import Tool as MCPTool
3031

3132
ToolParams = ParamSpec("ToolParams")
3233

0 commit comments

Comments
 (0)