Skip to content

Commit e94c09e

Browse files
committed
feat: implement comprehensive MCP tool filtering with static and dynamic support
1 parent 304b106 commit e94c09e

File tree

7 files changed

+459
-130
lines changed

7 files changed

+459
-130
lines changed

docs/mcp.md

Lines changed: 66 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,16 +43,24 @@ agent=Agent(
4343

4444
## Tool filtering
4545

46-
You can filter which tools are available to your Agent using server-level filtering:
46+
You can filter which tools are available to your Agent by configuring tool filters on MCP servers. The SDK supports both static and dynamic tool filtering.
47+
48+
### Static tool filtering
49+
50+
For simple allow/block lists, you can use static filtering:
4751

4852
```python
53+
from agents.mcp import create_static_tool_filter
54+
4955
# Only expose specific tools from this server
5056
server = MCPServerStdio(
5157
params={
5258
"command": "npx",
5359
"args": ["-y", "@modelcontextprotocol/server-filesystem", samples_dir],
5460
},
55-
allowed_tools=["read_file", "write_file"], # Only these tools will be available
61+
tool_filter=create_static_tool_filter(
62+
allowed_tool_names=["read_file", "write_file"]
63+
)
5664
)
5765

5866
# Exclude specific tools from this server
@@ -61,10 +69,65 @@ server = MCPServerStdio(
6169
"command": "npx",
6270
"args": ["-y", "@modelcontextprotocol/server-filesystem", samples_dir],
6371
},
64-
excluded_tools=["delete_file"], # This tool will be filtered out
72+
tool_filter=create_static_tool_filter(
73+
blocked_tool_names=["delete_file"]
74+
)
75+
)
76+
77+
```
78+
79+
**When both `allowed_tool_names` and `blocked_tool_names` are configured, the processing order is:**
80+
1. First apply `allowed_tool_names` (allowlist) - only keep the specified tools
81+
2. Then apply `blocked_tool_names` (blocklist) - exclude specified tools from the remaining tools
82+
83+
For example, if you configure `allowed_tool_names=["read_file", "write_file", "delete_file"]` and `blocked_tool_names=["delete_file"]`, only `read_file` and `write_file` tools will be available.
84+
85+
### Dynamic tool filtering
86+
87+
For more complex filtering logic, you can use dynamic filters with functions:
88+
89+
```python
90+
from agents.mcp import ToolFilterContext
91+
92+
# Simple synchronous filter
93+
def custom_filter(context: ToolFilterContext, tool) -> bool:
94+
"""Example of a custom tool filter."""
95+
# Filter logic based on tool name patterns
96+
return tool.name.startswith("allowed_prefix")
97+
98+
# Context-aware filter
99+
def context_aware_filter(context: ToolFilterContext, tool) -> bool:
100+
"""Filter tools based on context information."""
101+
# Access agent information
102+
agent_name = context.agent.name
103+
104+
# Access server information
105+
server_name = context.server_name
106+
107+
# Implement your custom filtering logic here
108+
return some_filtering_logic(agent_name, server_name, tool)
109+
110+
# Asynchronous filter
111+
async def async_filter(context: ToolFilterContext, tool) -> bool:
112+
"""Example of an asynchronous filter."""
113+
# Perform async operations if needed
114+
result = await some_async_check(context, tool)
115+
return result
116+
117+
server = MCPServerStdio(
118+
params={
119+
"command": "npx",
120+
"args": ["-y", "@modelcontextprotocol/server-filesystem", samples_dir],
121+
},
122+
tool_filter=custom_filter # or context_aware_filter or async_filter
65123
)
66124
```
67125

126+
The `ToolFilterContext` provides access to:
127+
- `run_context`: The current run context
128+
- `agent`: The agent requesting the tools
129+
- `server_name`: The name of the MCP server
130+
68131
## Caching
69132

70133
Every time an Agent runs, it calls `list_tools()` on the MCP server. This can be a latency hit, especially if the server is a remote server. To automatically cache the list of tools, you can pass `cache_tools_list=True` to [`MCPServerStdio`][agents.mcp.server.MCPServerStdio], [`MCPServerSse`][agents.mcp.server.MCPServerSse], and [`MCPServerStreamableHttp`][agents.mcp.server.MCPServerStreamableHttp]. You should only do this if you're certain the tool list will not change.

src/agents/agent.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -256,14 +256,18 @@ async def get_prompt(
256256
"""Get the prompt for the agent."""
257257
return await PromptUtil.to_model_input(self.prompt, run_context, self)
258258

259-
async def get_mcp_tools(self) -> list[Tool]:
259+
async def get_mcp_tools(
260+
self, run_context: RunContextWrapper[TContext] | None = None
261+
) -> list[Tool]:
260262
"""Fetches the available tools from the MCP servers."""
261263
convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
262-
return await MCPUtil.get_all_function_tools(self.mcp_servers, convert_schemas_to_strict)
264+
return await MCPUtil.get_all_function_tools(
265+
self.mcp_servers, convert_schemas_to_strict, run_context, self
266+
)
263267

264268
async def get_all_tools(self, run_context: RunContextWrapper[Any]) -> list[Tool]:
265269
"""All agent tools, including MCP tools and function tools."""
266-
mcp_tools = await self.get_mcp_tools()
270+
mcp_tools = await self.get_mcp_tools(run_context)
267271

268272
async def _check_tool_enabled(tool: Tool) -> bool:
269273
if not isinstance(tool, FunctionTool):

src/agents/mcp/server.py

Lines changed: 106 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22

33
import abc
44
import asyncio
5+
import inspect
56
from contextlib import AbstractAsyncContextManager, AsyncExitStack
67
from datetime import timedelta
78
from pathlib import Path
8-
from typing import Any, Literal
9+
from typing import TYPE_CHECKING, Any, Literal, cast
910

1011
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1112
from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client
@@ -17,7 +18,11 @@
1718

1819
from ..exceptions import UserError
1920
from ..logger import logger
20-
from .util import ToolFilter, ToolFilterStatic
21+
from ..run_context import RunContextWrapper
22+
from .util import ToolFilter, ToolFilterCallable, ToolFilterContext, ToolFilterStatic
23+
24+
if TYPE_CHECKING:
25+
from ..agent import Agent
2126

2227

2328
class MCPServer(abc.ABC):
@@ -45,7 +50,11 @@ async def cleanup(self):
4550
pass
4651

4752
@abc.abstractmethod
48-
async def list_tools(self) -> list[MCPTool]:
53+
async def list_tools(
54+
self,
55+
run_context: RunContextWrapper[Any] | None = None,
56+
agent: Agent[Any] | None = None,
57+
) -> list[MCPTool]:
4958
"""List the tools available on the server."""
5059
pass
5160

@@ -90,38 +99,106 @@ def __init__(
9099

91100
self.tool_filter = tool_filter
92101

93-
def _apply_tool_filter(self, tools: list[MCPTool]) -> list[MCPTool]:
102+
async def _apply_tool_filter(
103+
self,
104+
tools: list[MCPTool],
105+
run_context: RunContextWrapper[Any] | None,
106+
agent: Agent[Any] | None,
107+
) -> list[MCPTool]:
94108
"""Apply the tool filter to the list of tools."""
95109
if self.tool_filter is None:
96110
return tools
97111

98112
# Handle static tool filter
99113
if isinstance(self.tool_filter, dict):
100-
static_filter: ToolFilterStatic = self.tool_filter
101-
filtered_tools = tools
114+
return self._apply_static_tool_filter(tools, self.tool_filter)
102115

103-
# Apply allowed_tool_names filter (whitelist)
104-
if "allowed_tool_names" in static_filter:
105-
allowed_names = static_filter["allowed_tool_names"]
106-
filtered_tools = [t for t in filtered_tools if t.name in allowed_names]
116+
# Handle callable tool filter (dynamic filter)
117+
else:
118+
return await self._apply_dynamic_tool_filter(tools, run_context, agent)
107119

108-
# Apply blocked_tool_names filter (blacklist)
109-
if "blocked_tool_names" in static_filter:
110-
blocked_names = static_filter["blocked_tool_names"]
111-
filtered_tools = [t for t in filtered_tools if t.name not in blocked_names]
120+
def _apply_static_tool_filter(
121+
self,
122+
tools: list[MCPTool],
123+
static_filter: ToolFilterStatic
124+
) -> list[MCPTool]:
125+
"""Apply static tool filtering based on allowlist and blocklist."""
126+
filtered_tools = tools
112127

113-
return filtered_tools
128+
# Apply allowed_tool_names filter (whitelist)
129+
if "allowed_tool_names" in static_filter:
130+
allowed_names = static_filter["allowed_tool_names"]
131+
filtered_tools = [t for t in filtered_tools if t.name in allowed_names]
114132

115-
# Handle callable tool filter
116-
# For now, we can't support callable filters because we don't have access to
117-
# run context and agent in the current list_tools signature.
118-
# This could be enhanced in the future by modifying the call chain.
119-
else:
120-
raise NotImplementedError(
121-
"Callable tool filters are not yet supported. Please use ToolFilterStatic "
122-
"with 'allowed_tool_names' and/or 'blocked_tool_names' for now."
133+
# Apply blocked_tool_names filter (blacklist)
134+
if "blocked_tool_names" in static_filter:
135+
blocked_names = static_filter["blocked_tool_names"]
136+
filtered_tools = [t for t in filtered_tools if t.name not in blocked_names]
137+
138+
return filtered_tools
139+
140+
async def _apply_dynamic_tool_filter(
141+
self,
142+
tools: list[MCPTool],
143+
run_context: RunContextWrapper[Any] | None,
144+
agent: Agent[Any] | None,
145+
) -> list[MCPTool]:
146+
"""Apply dynamic tool filtering using a callable filter function."""
147+
148+
# Ensure we have a callable filter and cast to help mypy
149+
if not callable(self.tool_filter):
150+
raise ValueError("Tool filter must be callable for dynamic filtering")
151+
tool_filter_func = cast(ToolFilterCallable, self.tool_filter)
152+
153+
# Create filter context - it may be None if run_context or agent is None
154+
filter_context = None
155+
if run_context is not None and agent is not None:
156+
filter_context = ToolFilterContext(
157+
run_context=run_context,
158+
agent=agent,
159+
server_name=self.name,
123160
)
124161

162+
filtered_tools = []
163+
for tool in tools:
164+
try:
165+
# Try to call the filter function
166+
if filter_context is not None:
167+
# We have full context, call with context
168+
result = tool_filter_func(filter_context, tool)
169+
else:
170+
# Try to call without context first to see if it works
171+
try:
172+
# Some filters might not need context parameters at all
173+
result = tool_filter_func(None, tool)
174+
except (TypeError, AttributeError) as e:
175+
# If the filter tries to access context attributes, raise a helpful error
176+
raise UserError(
177+
"Dynamic tool filters require both run_context and agent when the "
178+
"filter function accesses context information. This typically happens "
179+
"when calling list_tools() directly without these parameters. Either "
180+
"provide both parameters or use a static tool filter instead."
181+
) from e
182+
183+
if inspect.isawaitable(result):
184+
should_include = await result
185+
else:
186+
should_include = result
187+
188+
if should_include:
189+
filtered_tools.append(tool)
190+
except UserError:
191+
# Re-raise UserError as-is (this includes our context requirement error)
192+
raise
193+
except Exception as e:
194+
logger.error(
195+
f"Error applying tool filter to tool '{tool.name}' on server '{self.name}': {e}"
196+
)
197+
# On error, exclude the tool for safety
198+
continue
199+
200+
return filtered_tools
201+
125202
@abc.abstractmethod
126203
def create_streams(
127204
self,
@@ -172,7 +249,11 @@ async def connect(self):
172249
await self.cleanup()
173250
raise
174251

175-
async def list_tools(self) -> list[MCPTool]:
252+
async def list_tools(
253+
self,
254+
run_context: RunContextWrapper[Any] | None = None,
255+
agent: Agent[Any] | None = None,
256+
) -> list[MCPTool]:
176257
"""List the tools available on the server."""
177258
if not self.session:
178259
raise UserError("Server not initialized. Make sure you call `connect()` first.")
@@ -190,7 +271,7 @@ async def list_tools(self) -> list[MCPTool]:
190271
# Filter tools based on tool_filter
191272
filtered_tools = tools
192273
if self.tool_filter is not None:
193-
filtered_tools = self._apply_tool_filter(filtered_tools)
274+
filtered_tools = await self._apply_tool_filter(filtered_tools, run_context, agent)
194275
return filtered_tools
195276

196277
async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult:

src/agents/mcp/util.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import json
33
from dataclasses import dataclass
44
from typing import TYPE_CHECKING, Any, Callable, Union
5+
56
from typing_extensions import NotRequired, TypedDict
67

78
from agents.strict_schema import ensure_strict_json_schema
@@ -35,11 +36,12 @@ class ToolFilterContext:
3536
"""The name of the MCP server."""
3637

3738

38-
ToolFilterCallable = Callable[["ToolFilterContext", "MCPTool"], MaybeAwaitable[bool]]
39+
ToolFilterCallable = Callable[["ToolFilterContext | None", "MCPTool"], MaybeAwaitable[bool]]
3940
"""A function that determines whether a tool should be available.
4041
4142
Args:
4243
context: The context information including run context, agent, and server name.
44+
Can be None if run_context or agent is not available.
4345
tool: The MCP tool to filter.
4446
4547
Returns:
@@ -51,10 +53,12 @@ class ToolFilterStatic(TypedDict):
5153
"""Static tool filter configuration using allowlists and blocklists."""
5254

5355
allowed_tool_names: NotRequired[list[str]]
54-
"""Optional list of tool names to allow (whitelist). If set, only these tools will be available."""
56+
"""Optional list of tool names to allow (whitelist).
57+
If set, only these tools will be available."""
5558

5659
blocked_tool_names: NotRequired[list[str]]
57-
"""Optional list of tool names to exclude (blacklist). If set, these tools will be filtered out."""
60+
"""Optional list of tool names to exclude (blacklist).
61+
If set, these tools will be filtered out."""
5862

5963

6064
ToolFilter = Union[ToolFilterCallable, ToolFilterStatic, None]
@@ -93,13 +97,19 @@ class MCPUtil:
9397

9498
@classmethod
9599
async def get_all_function_tools(
96-
cls, servers: list["MCPServer"], convert_schemas_to_strict: bool
100+
cls,
101+
servers: list["MCPServer"],
102+
convert_schemas_to_strict: bool,
103+
run_context: RunContextWrapper[Any] | None = None,
104+
agent: "Agent[Any] | None" = None,
97105
) -> list[Tool]:
98106
"""Get all function tools from a list of MCP servers."""
99107
tools = []
100108
tool_names: set[str] = set()
101109
for server in servers:
102-
server_tools = await cls.get_function_tools(server, convert_schemas_to_strict)
110+
server_tools = await cls.get_function_tools(
111+
server, convert_schemas_to_strict, run_context, agent
112+
)
103113
server_tool_names = {tool.name for tool in server_tools}
104114
if len(server_tool_names & tool_names) > 0:
105115
raise UserError(
@@ -113,12 +123,16 @@ async def get_all_function_tools(
113123

114124
@classmethod
115125
async def get_function_tools(
116-
cls, server: "MCPServer", convert_schemas_to_strict: bool
126+
cls,
127+
server: "MCPServer",
128+
convert_schemas_to_strict: bool,
129+
run_context: RunContextWrapper[Any] | None = None,
130+
agent: "Agent[Any] | None" = None,
117131
) -> list[Tool]:
118132
"""Get all function tools from a single MCP server."""
119133

120134
with mcp_tools_span(server=server.name) as span:
121-
tools = await server.list_tools()
135+
tools = await server.list_tools(run_context, agent)
122136
span.span_data.result = [tool.name for tool in tools]
123137

124138
return [cls.to_function_tool(tool, server, convert_schemas_to_strict) for tool in tools]

src/agents/tool.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@
2626
from .util._types import MaybeAwaitable
2727

2828
if TYPE_CHECKING:
29+
2930
from .agent import Agent
30-
from mcp.types import Tool as MCPTool
3131

3232
ToolParams = ParamSpec("ToolParams")
3333

0 commit comments

Comments
 (0)