Skip to content

Commit a873952

Browse files
committed
feat: implement tool filter interface for MCP servers
1 parent f553f20 commit a873952

File tree

3 files changed

+176
-50
lines changed

3 files changed

+176
-50
lines changed

src/agents/mcp/server.py

Lines changed: 48 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from ..exceptions import UserError
1919
from ..logger import logger
20+
from ..tool import ToolFilter, ToolFilterContext, ToolFilterStatic
2021

2122

2223
class MCPServer(abc.ABC):
@@ -61,8 +62,7 @@ def __init__(
6162
self,
6263
cache_tools_list: bool,
6364
client_session_timeout_seconds: float | None,
64-
allowed_tools: list[str] | None = None,
65-
excluded_tools: list[str] | None = None,
65+
tool_filter: ToolFilter = None,
6666
):
6767
"""
6868
Args:
@@ -74,10 +74,7 @@ def __init__(
7474
(by avoiding a round-trip to the server every time).
7575
7676
client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
77-
allowed_tools: Optional list of tool names to allow (whitelist).
78-
If set, only these tools will be available.
79-
excluded_tools: Optional list of tool names to exclude (blacklist).
80-
If set, these tools will be filtered out.
77+
tool_filter: The tool filter to use for filtering tools.
8178
"""
8279
self.session: ClientSession | None = None
8380
self.exit_stack: AsyncExitStack = AsyncExitStack()
@@ -91,8 +88,39 @@ def __init__(
9188
self._cache_dirty = True
9289
self._tools_list: list[MCPTool] | None = None
9390

94-
self.allowed_tools = allowed_tools
95-
self.excluded_tools = excluded_tools
91+
self.tool_filter = tool_filter
92+
93+
def _apply_tool_filter(self, tools: list[MCPTool]) -> list[MCPTool]:
94+
"""Apply the tool filter to the list of tools."""
95+
if self.tool_filter is None:
96+
return tools
97+
98+
# Handle static tool filter
99+
if isinstance(self.tool_filter, dict):
100+
static_filter: ToolFilterStatic = self.tool_filter
101+
filtered_tools = tools
102+
103+
# Apply allowed_tool_names filter (whitelist)
104+
if "allowed_tool_names" in static_filter:
105+
allowed_names = static_filter["allowed_tool_names"]
106+
filtered_tools = [t for t in filtered_tools if t.name in allowed_names]
107+
108+
# Apply blocked_tool_names filter (blacklist)
109+
if "blocked_tool_names" in static_filter:
110+
blocked_names = static_filter["blocked_tool_names"]
111+
filtered_tools = [t for t in filtered_tools if t.name not in blocked_names]
112+
113+
return filtered_tools
114+
115+
# Handle callable tool filter
116+
# For now, we can't support callable filters because we don't have access to
117+
# run context and agent in the current list_tools signature.
118+
# This could be enhanced in the future by modifying the call chain.
119+
else:
120+
raise NotImplementedError(
121+
"Callable tool filters are not yet supported. Please use ToolFilterStatic "
122+
"with 'allowed_tool_names' and/or 'blocked_tool_names' for now."
123+
)
96124

97125
@abc.abstractmethod
98126
def create_streams(
@@ -159,12 +187,10 @@ async def list_tools(self) -> list[MCPTool]:
159187
self._tools_list = (await self.session.list_tools()).tools
160188
tools = self._tools_list
161189

162-
# Filter tools based on allowed and excluded tools
190+
# Filter tools based on tool_filter
163191
filtered_tools = tools
164-
if self.allowed_tools is not None:
165-
filtered_tools = [t for t in filtered_tools if t.name in self.allowed_tools]
166-
if self.excluded_tools is not None:
167-
filtered_tools = [t for t in filtered_tools if t.name not in self.excluded_tools]
192+
if self.tool_filter is not None:
193+
filtered_tools = self._apply_tool_filter(filtered_tools)
168194
return filtered_tools
169195

170196
async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult:
@@ -226,8 +252,7 @@ def __init__(
226252
cache_tools_list: bool = False,
227253
name: str | None = None,
228254
client_session_timeout_seconds: float | None = 5,
229-
allowed_tools: list[str] | None = None,
230-
excluded_tools: list[str] | None = None,
255+
tool_filter: ToolFilter = None,
231256
):
232257
"""Create a new MCP server based on the stdio transport.
233258
@@ -245,14 +270,12 @@ def __init__(
245270
name: A readable name for the server. If not provided, we'll create one from the
246271
command.
247272
client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
248-
allowed_tools: Optional list of tool names to allow (whitelist).
249-
excluded_tools: Optional list of tool names to exclude (blacklist).
273+
tool_filter: The tool filter to use for filtering tools.
250274
"""
251275
super().__init__(
252276
cache_tools_list,
253277
client_session_timeout_seconds,
254-
allowed_tools,
255-
excluded_tools,
278+
tool_filter,
256279
)
257280

258281
self.params = StdioServerParameters(
@@ -312,8 +335,7 @@ def __init__(
312335
cache_tools_list: bool = False,
313336
name: str | None = None,
314337
client_session_timeout_seconds: float | None = 5,
315-
allowed_tools: list[str] | None = None,
316-
excluded_tools: list[str] | None = None,
338+
tool_filter: ToolFilter = None,
317339
):
318340
"""Create a new MCP server based on the HTTP with SSE transport.
319341
@@ -333,14 +355,12 @@ def __init__(
333355
URL.
334356
335357
client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
336-
allowed_tools: Optional list of tool names to allow (whitelist).
337-
excluded_tools: Optional list of tool names to exclude (blacklist).
358+
tool_filter: The tool filter to use for filtering tools.
338359
"""
339360
super().__init__(
340361
cache_tools_list,
341362
client_session_timeout_seconds,
342-
allowed_tools,
343-
excluded_tools,
363+
tool_filter,
344364
)
345365

346366
self.params = params
@@ -400,8 +420,7 @@ def __init__(
400420
cache_tools_list: bool = False,
401421
name: str | None = None,
402422
client_session_timeout_seconds: float | None = 5,
403-
allowed_tools: list[str] | None = None,
404-
excluded_tools: list[str] | None = None,
423+
tool_filter: ToolFilter = None,
405424
):
406425
"""Create a new MCP server based on the Streamable HTTP transport.
407426
@@ -422,14 +441,12 @@ def __init__(
422441
URL.
423442
424443
client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
425-
allowed_tools: Optional list of tool names to allow (whitelist).
426-
excluded_tools: Optional list of tool names to exclude (blacklist).
444+
tool_filter: The tool filter to use for filtering tools.
427445
"""
428446
super().__init__(
429447
cache_tools_list,
430448
client_session_timeout_seconds,
431-
allowed_tools,
432-
excluded_tools,
449+
tool_filter,
433450
)
434451

435452
self.params = params

src/agents/tool.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
if TYPE_CHECKING:
2929
from .agent import Agent
30+
from mcp.types import Tool as MCPTool
3031

3132
ToolParams = ParamSpec("ToolParams")
3233

@@ -41,6 +42,73 @@
4142
]
4243

4344

45+
@dataclass
46+
class ToolFilterContext:
47+
"""Context information available to tool filter functions."""
48+
49+
run_context: RunContextWrapper[Any]
50+
"""The current run context."""
51+
52+
agent: "Agent[Any]"
53+
"""The agent that is requesting the tool list."""
54+
55+
server_name: str
56+
"""The name of the MCP server."""
57+
58+
59+
ToolFilterCallable = Callable[["ToolFilterContext", "MCPTool"], MaybeAwaitable[bool]]
60+
"""A function that determines whether a tool should be available.
61+
62+
Args:
63+
context: The context information including run context, agent, and server name.
64+
tool: The MCP tool to filter.
65+
66+
Returns:
67+
Whether the tool should be available (True) or filtered out (False).
68+
"""
69+
70+
71+
class ToolFilterStatic(TypedDict):
72+
"""Static tool filter configuration using allowlists and blocklists."""
73+
74+
allowed_tool_names: NotRequired[list[str]]
75+
"""Optional list of tool names to allow (whitelist). If set, only these tools will be available."""
76+
77+
blocked_tool_names: NotRequired[list[str]]
78+
"""Optional list of tool names to exclude (blacklist). If set, these tools will be filtered out."""
79+
80+
81+
ToolFilter = Union[ToolFilterCallable, ToolFilterStatic, None]
82+
"""A tool filter that can be either a function, static configuration, or None (no filtering)."""
83+
84+
85+
def create_static_tool_filter(
86+
allowed_tool_names: list[str] | None = None,
87+
blocked_tool_names: list[str] | None = None,
88+
) -> ToolFilterStatic | None:
89+
"""Create a static tool filter from allowlist and blocklist parameters.
90+
91+
This is a convenience function for creating a ToolFilterStatic.
92+
93+
Args:
94+
allowed_tool_names: Optional list of tool names to allow (whitelist).
95+
blocked_tool_names: Optional list of tool names to exclude (blacklist).
96+
97+
Returns:
98+
A ToolFilterStatic if any filtering is specified, None otherwise.
99+
"""
100+
if allowed_tool_names is None and blocked_tool_names is None:
101+
return None
102+
103+
filter_dict: ToolFilterStatic = {}
104+
if allowed_tool_names is not None:
105+
filter_dict["allowed_tool_names"] = allowed_tool_names
106+
if blocked_tool_names is not None:
107+
filter_dict["blocked_tool_names"] = blocked_tool_names
108+
109+
return filter_dict
110+
111+
44112
@dataclass
45113
class FunctionToolResult:
46114
tool: FunctionTool

tests/mcp/test_tool_filtering.py

Lines changed: 60 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,65 @@
11
import pytest
22

3+
from agents.tool import ToolFilterStatic
34
from .helpers import FakeMCPServer
45

56

67
class FilterableFakeMCPServer(FakeMCPServer):
78
"""Extended FakeMCPServer that supports tool filtering"""
89

9-
def __init__(self, tools=None, allowed_tools=None, excluded_tools=None, server_name=None):
10+
def __init__(self, tools=None, tool_filter=None, server_name=None):
1011
super().__init__(tools)
11-
self.allowed_tools = allowed_tools
12-
self.excluded_tools = excluded_tools
12+
self.tool_filter = tool_filter
1313
self._server_name = server_name
1414

1515
async def list_tools(self):
1616
tools = await super().list_tools()
1717

1818
# Apply filtering logic similar to _MCPServerWithClientSession
1919
filtered_tools = tools
20-
if self.allowed_tools is not None:
21-
filtered_tools = [t for t in filtered_tools if t.name in self.allowed_tools]
22-
if self.excluded_tools is not None:
23-
filtered_tools = [t for t in filtered_tools if t.name not in self.excluded_tools]
20+
if self.tool_filter is not None:
21+
filtered_tools = self._apply_tool_filter(filtered_tools)
2422
return filtered_tools
2523

24+
def _apply_tool_filter(self, tools):
25+
"""Apply the tool filter to the list of tools."""
26+
if self.tool_filter is None:
27+
return tools
28+
29+
# Handle static tool filter
30+
if isinstance(self.tool_filter, dict):
31+
static_filter: ToolFilterStatic = self.tool_filter
32+
filtered_tools = tools
33+
34+
# Apply allowed_tool_names filter (whitelist)
35+
if "allowed_tool_names" in static_filter:
36+
allowed_names = static_filter["allowed_tool_names"]
37+
filtered_tools = [t for t in filtered_tools if t.name in allowed_names]
38+
39+
# Apply blocked_tool_names filter (blacklist)
40+
if "blocked_tool_names" in static_filter:
41+
blocked_names = static_filter["blocked_tool_names"]
42+
filtered_tools = [t for t in filtered_tools if t.name not in blocked_names]
43+
44+
return filtered_tools
45+
46+
return tools
47+
2648
@property
2749
def name(self) -> str:
2850
return self._server_name or "filterable_fake_server"
2951

3052

3153
@pytest.mark.asyncio
32-
async def test_server_allowed_tools():
33-
"""Test that server-level allowed_tools filters tools correctly"""
54+
async def test_server_allowed_tool_names():
55+
"""Test that server-level allowed_tool_names filters tools correctly"""
3456
server = FilterableFakeMCPServer(server_name="test_server")
3557
server.add_tool("tool1", {})
3658
server.add_tool("tool2", {})
3759
server.add_tool("tool3", {})
3860

39-
# Set allowed_tools to only include tool1 and tool2
40-
server.allowed_tools = ["tool1", "tool2"]
61+
# Set tool_filter to only include tool1 and tool2
62+
server.tool_filter = {"allowed_tool_names": ["tool1", "tool2"]}
4163

4264
# Get tools and verify filtering
4365
tools = await server.list_tools()
@@ -46,15 +68,15 @@ async def test_server_allowed_tools():
4668

4769

4870
@pytest.mark.asyncio
49-
async def test_server_excluded_tools():
50-
"""Test that server-level excluded_tools filters tools correctly"""
71+
async def test_server_blocked_tool_names():
72+
"""Test that server-level blocked_tool_names filters tools correctly"""
5173
server = FilterableFakeMCPServer(server_name="test_server")
5274
server.add_tool("tool1", {})
5375
server.add_tool("tool2", {})
5476
server.add_tool("tool3", {})
5577

56-
# Set excluded_tools to exclude tool3
57-
server.excluded_tools = ["tool3"]
78+
# Set tool_filter to exclude tool3
79+
server.tool_filter = {"blocked_tool_names": ["tool3"]}
5880

5981
# Get tools and verify filtering
6082
tools = await server.list_tools()
@@ -64,18 +86,37 @@ async def test_server_excluded_tools():
6486

6587
@pytest.mark.asyncio
6688
async def test_server_both_filters():
67-
"""Test that server-level allowed_tools and excluded_tools work together correctly"""
89+
"""Test that server-level allowed_tool_names and blocked_tool_names work together correctly"""
6890
server = FilterableFakeMCPServer(server_name="test_server")
6991
server.add_tool("tool1", {})
7092
server.add_tool("tool2", {})
7193
server.add_tool("tool3", {})
7294
server.add_tool("tool4", {})
7395

7496
# Set both filters
75-
server.allowed_tools = ["tool1", "tool2", "tool3"]
76-
server.excluded_tools = ["tool3"]
97+
server.tool_filter = {
98+
"allowed_tool_names": ["tool1", "tool2", "tool3"],
99+
"blocked_tool_names": ["tool3"]
100+
}
77101

78-
# Get tools and verify filtering (allowed_tools applied first, then excluded_tools)
102+
# Get tools and verify filtering (allowed_tool_names applied first, then blocked_tool_names)
79103
tools = await server.list_tools()
80104
assert len(tools) == 2
81105
assert {t.name for t in tools} == {"tool1", "tool2"}
106+
107+
108+
@pytest.mark.asyncio
109+
async def test_server_no_filter():
110+
"""Test that when no filter is set, all tools are returned"""
111+
server = FilterableFakeMCPServer(server_name="test_server")
112+
server.add_tool("tool1", {})
113+
server.add_tool("tool2", {})
114+
server.add_tool("tool3", {})
115+
116+
# No filter set (None)
117+
server.tool_filter = None
118+
119+
# Get tools and verify no filtering
120+
tools = await server.list_tools()
121+
assert len(tools) == 3
122+
assert {t.name for t in tools} == {"tool1", "tool2", "tool3"}

0 commit comments

Comments
 (0)