Skip to content

Commit 9b2248e

Browse files
committed
feat: add MCP tool filtering support
1 parent 0eee6b8 commit 9b2248e

File tree

5 files changed

+374
-18
lines changed

5 files changed

+374
-18
lines changed

docs/mcp.md

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,57 @@ agent=Agent(
4141
)
4242
```
4343

44+
## Tool filtering
45+
46+
You can filter which tools are available to your Agent in two ways:
47+
48+
### Server-level filtering
49+
50+
Each MCP server instance can be configured with `allowed_tools` and `excluded_tools` parameters to control which tools it exposes:
51+
52+
```python
53+
# Only expose specific tools from this server
54+
server = MCPServerStdio(
55+
params={
56+
"command": "npx",
57+
"args": ["-y", "@modelcontextprotocol/server-filesystem", samples_dir],
58+
},
59+
allowed_tools=["read_file", "write_file"], # Only these tools will be available
60+
)
61+
62+
# Exclude specific tools from this server
63+
server = MCPServerStdio(
64+
params={
65+
"command": "npx",
66+
"args": ["-y", "@modelcontextprotocol/server-filesystem", samples_dir],
67+
},
68+
excluded_tools=["delete_file"], # This tool will be filtered out
69+
)
70+
```
71+
72+
### Agent-level filtering
73+
74+
You can also filter tools at the Agent level using the `mcp_config` parameter. This allows you to control which tools are available across all MCP servers:
75+
76+
```python
77+
agent = Agent(
78+
name="Assistant",
79+
instructions="Use the tools to achieve the task",
80+
mcp_servers=[server1, server2, server3],
81+
mcp_config={
82+
"allowed_tools": {
83+
"server1": ["read_file", "write_file"], # Only these tools from server1
84+
"server2": ["search"], # Only search tool from server2
85+
},
86+
"excluded_tools": {
87+
"server3": ["dangerous_tool"], # Exclude this tool from server3
88+
}
89+
}
90+
)
91+
```
92+
93+
**Filtering priority**: Server-level filtering is applied first, then Agent-level filtering. This allows for fine-grained control where servers can limit their exposed tools, and Agents can further restrict which tools they use.
94+
4495
## Caching
4596

4697
Every time an Agent runs, it calls `list_tools()` on the MCP server. This can be a latency hit, especially if the server is a remote server. To automatically cache the list of tools, you can pass `cache_tools_list=True` to [`MCPServerStdio`][agents.mcp.server.MCPServerStdio], [`MCPServerSse`][agents.mcp.server.MCPServerSse], and [`MCPServerStreamableHttp`][agents.mcp.server.MCPServerStreamableHttp]. You should only do this if you're certain the tool list will not change.

src/agents/agent.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@ class MCPConfig(TypedDict):
6262
"""If True, we will attempt to convert the MCP schemas to strict-mode schemas. This is a
6363
best-effort conversion, so some schemas may not be convertible. Defaults to False.
6464
"""
65+
allowed_tools: NotRequired[dict[str, list[str]]]
66+
"""Optional: server_name -> allowed tool names (whitelist)"""
67+
excluded_tools: NotRequired[dict[str, list[str]]]
68+
"""Optional: server_name -> excluded tool names (blacklist)"""
6569

6670

6771
@dataclass
@@ -245,7 +249,14 @@ async def get_system_prompt(self, run_context: RunContextWrapper[TContext]) -> s
245249
async def get_mcp_tools(self) -> list[Tool]:
246250
"""Fetches the available tools from the MCP servers."""
247251
convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
248-
return await MCPUtil.get_all_function_tools(self.mcp_servers, convert_schemas_to_strict)
252+
allowed_tools_map = self.mcp_config.get("allowed_tools", {})
253+
excluded_tools_map = self.mcp_config.get("excluded_tools", {})
254+
return await MCPUtil.get_all_function_tools(
255+
self.mcp_servers,
256+
convert_schemas_to_strict,
257+
allowed_tools_map,
258+
excluded_tools_map,
259+
)
249260

250261
async def get_all_tools(self, run_context: RunContextWrapper[Any]) -> list[Tool]:
251262
"""All agent tools, including MCP tools and function tools."""

src/agents/mcp/server.py

Lines changed: 59 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,13 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> C
5757
class _MCPServerWithClientSession(MCPServer, abc.ABC):
5858
"""Base class for MCP servers that use a `ClientSession` to communicate with the server."""
5959

60-
def __init__(self, cache_tools_list: bool, client_session_timeout_seconds: float | None):
60+
def __init__(
61+
self,
62+
cache_tools_list: bool,
63+
client_session_timeout_seconds: float | None,
64+
allowed_tools: list[str] | None = None,
65+
excluded_tools: list[str] | None = None,
66+
):
6167
"""
6268
Args:
6369
cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
@@ -68,6 +74,10 @@ def __init__(self, cache_tools_list: bool, client_session_timeout_seconds: float
6874
(by avoiding a round-trip to the server every time).
6975
7076
client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
77+
allowed_tools: Optional list of tool names to allow (whitelist).
78+
If set, only these tools will be available.
79+
excluded_tools: Optional list of tool names to exclude (blacklist).
80+
If set, these tools will be filtered out.
7181
"""
7282
self.session: ClientSession | None = None
7383
self.exit_stack: AsyncExitStack = AsyncExitStack()
@@ -81,6 +91,9 @@ def __init__(self, cache_tools_list: bool, client_session_timeout_seconds: float
8191
self._cache_dirty = True
8292
self._tools_list: list[MCPTool] | None = None
8393

94+
self.allowed_tools = allowed_tools
95+
self.excluded_tools = excluded_tools
96+
8497
@abc.abstractmethod
8598
def create_streams(
8699
self,
@@ -138,14 +151,21 @@ async def list_tools(self) -> list[MCPTool]:
138151

139152
# Return from cache if caching is enabled, we have tools, and the cache is not dirty
140153
if self.cache_tools_list and not self._cache_dirty and self._tools_list:
141-
return self._tools_list
142-
143-
# Reset the cache dirty to False
144-
self._cache_dirty = False
145-
146-
# Fetch the tools from the server
147-
self._tools_list = (await self.session.list_tools()).tools
148-
return self._tools_list
154+
tools = self._tools_list
155+
else:
156+
# Reset the cache dirty to False
157+
self._cache_dirty = False
158+
# Fetch the tools from the server
159+
self._tools_list = (await self.session.list_tools()).tools
160+
tools = self._tools_list
161+
162+
# Filter tools based on allowed and excluded tools
163+
filtered_tools = tools
164+
if self.allowed_tools is not None:
165+
filtered_tools = [t for t in filtered_tools if t.name in self.allowed_tools]
166+
if self.excluded_tools is not None:
167+
filtered_tools = [t for t in filtered_tools if t.name not in self.excluded_tools]
168+
return filtered_tools
149169

150170
async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult:
151171
"""Invoke a tool on the server."""
@@ -206,6 +226,8 @@ def __init__(
206226
cache_tools_list: bool = False,
207227
name: str | None = None,
208228
client_session_timeout_seconds: float | None = 5,
229+
allowed_tools: list[str] | None = None,
230+
excluded_tools: list[str] | None = None,
209231
):
210232
"""Create a new MCP server based on the stdio transport.
211233
@@ -223,8 +245,15 @@ def __init__(
223245
name: A readable name for the server. If not provided, we'll create one from the
224246
command.
225247
client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
248+
allowed_tools: Optional list of tool names to allow (whitelist).
249+
excluded_tools: Optional list of tool names to exclude (blacklist).
226250
"""
227-
super().__init__(cache_tools_list, client_session_timeout_seconds)
251+
super().__init__(
252+
cache_tools_list,
253+
client_session_timeout_seconds,
254+
allowed_tools,
255+
excluded_tools,
256+
)
228257

229258
self.params = StdioServerParameters(
230259
command=params["command"],
@@ -283,6 +312,8 @@ def __init__(
283312
cache_tools_list: bool = False,
284313
name: str | None = None,
285314
client_session_timeout_seconds: float | None = 5,
315+
allowed_tools: list[str] | None = None,
316+
excluded_tools: list[str] | None = None,
286317
):
287318
"""Create a new MCP server based on the HTTP with SSE transport.
288319
@@ -302,8 +333,15 @@ def __init__(
302333
URL.
303334
304335
client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
336+
allowed_tools: Optional list of tool names to allow (whitelist).
337+
excluded_tools: Optional list of tool names to exclude (blacklist).
305338
"""
306-
super().__init__(cache_tools_list, client_session_timeout_seconds)
339+
super().__init__(
340+
cache_tools_list,
341+
client_session_timeout_seconds,
342+
allowed_tools,
343+
excluded_tools,
344+
)
307345

308346
self.params = params
309347
self._name = name or f"sse: {self.params['url']}"
@@ -362,6 +400,8 @@ def __init__(
362400
cache_tools_list: bool = False,
363401
name: str | None = None,
364402
client_session_timeout_seconds: float | None = 5,
403+
allowed_tools: list[str] | None = None,
404+
excluded_tools: list[str] | None = None,
365405
):
366406
"""Create a new MCP server based on the Streamable HTTP transport.
367407
@@ -382,8 +422,15 @@ def __init__(
382422
URL.
383423
384424
client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
425+
allowed_tools: Optional list of tool names to allow (whitelist).
426+
excluded_tools: Optional list of tool names to exclude (blacklist).
385427
"""
386-
super().__init__(cache_tools_list, client_session_timeout_seconds)
428+
super().__init__(
429+
cache_tools_list,
430+
client_session_timeout_seconds,
431+
allowed_tools,
432+
excluded_tools,
433+
)
387434

388435
self.params = params
389436
self._name = name or f"streamable_http: {self.params['url']}"

src/agents/mcp/util.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import functools
22
import json
3-
from typing import TYPE_CHECKING, Any
3+
from typing import TYPE_CHECKING, Any, Optional
44

55
from agents.strict_schema import ensure_strict_json_schema
66

@@ -22,13 +22,23 @@ class MCPUtil:
2222

2323
@classmethod
2424
async def get_all_function_tools(
25-
cls, servers: list["MCPServer"], convert_schemas_to_strict: bool
25+
cls,
26+
servers: list["MCPServer"],
27+
convert_schemas_to_strict: bool,
28+
allowed_tools_map: Optional[dict[str, list[str]]] = None,
29+
excluded_tools_map: Optional[dict[str, list[str]]] = None,
2630
) -> list[Tool]:
2731
"""Get all function tools from a list of MCP servers."""
2832
tools = []
2933
tool_names: set[str] = set()
34+
allowed_tools_map = allowed_tools_map or {}
35+
excluded_tools_map = excluded_tools_map or {}
3036
for server in servers:
31-
server_tools = await cls.get_function_tools(server, convert_schemas_to_strict)
37+
allowed = allowed_tools_map.get(server.name)
38+
excluded = excluded_tools_map.get(server.name)
39+
server_tools = await cls.get_function_tools(
40+
server, convert_schemas_to_strict, allowed, excluded
41+
)
3242
server_tool_names = {tool.name for tool in server_tools}
3343
if len(server_tool_names & tool_names) > 0:
3444
raise UserError(
@@ -42,15 +52,29 @@ async def get_all_function_tools(
4252

4353
@classmethod
4454
async def get_function_tools(
45-
cls, server: "MCPServer", convert_schemas_to_strict: bool
55+
cls,
56+
server: "MCPServer",
57+
convert_schemas_to_strict: bool,
58+
allowed_tools: Optional[list[str]] = None,
59+
excluded_tools: Optional[list[str]] = None,
4660
) -> list[Tool]:
4761
"""Get all function tools from a single MCP server."""
4862

4963
with mcp_tools_span(server=server.name) as span:
5064
tools = await server.list_tools()
5165
span.span_data.result = [tool.name for tool in tools]
5266

53-
return [cls.to_function_tool(tool, server, convert_schemas_to_strict) for tool in tools]
67+
# Apply Agent-level filtering (additional filtering on top of server-level filtering)
68+
filtered_tools = tools
69+
if allowed_tools is not None:
70+
filtered_tools = [t for t in filtered_tools if t.name in allowed_tools]
71+
if excluded_tools is not None:
72+
filtered_tools = [t for t in filtered_tools if t.name not in excluded_tools]
73+
74+
return [
75+
cls.to_function_tool(tool, server, convert_schemas_to_strict)
76+
for tool in filtered_tools
77+
]
5478

5579
@classmethod
5680
def to_function_tool(

0 commit comments

Comments
 (0)