17
17
18
18
from ..exceptions import UserError
19
19
from ..logger import logger
20
+ from .util import ToolFilter , ToolFilterStatic
20
21
21
22
22
23
class MCPServer (abc .ABC ):
@@ -61,8 +62,7 @@ def __init__(
61
62
self ,
62
63
cache_tools_list : bool ,
63
64
client_session_timeout_seconds : float | None ,
64
- allowed_tools : list [str ] | None = None ,
65
- excluded_tools : list [str ] | None = None ,
65
+ tool_filter : ToolFilter = None ,
66
66
):
67
67
"""
68
68
Args:
@@ -74,10 +74,7 @@ def __init__(
74
74
(by avoiding a round-trip to the server every time).
75
75
76
76
client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
77
- allowed_tools: Optional list of tool names to allow (whitelist).
78
- If set, only these tools will be available.
79
- excluded_tools: Optional list of tool names to exclude (blacklist).
80
- If set, these tools will be filtered out.
77
+ tool_filter: The tool filter to use for filtering tools.
81
78
"""
82
79
self .session : ClientSession | None = None
83
80
self .exit_stack : AsyncExitStack = AsyncExitStack ()
@@ -91,8 +88,39 @@ def __init__(
91
88
self ._cache_dirty = True
92
89
self ._tools_list : list [MCPTool ] | None = None
93
90
94
- self .allowed_tools = allowed_tools
95
- self .excluded_tools = excluded_tools
91
+ self .tool_filter = tool_filter
92
+
93
+ def _apply_tool_filter (self , tools : list [MCPTool ]) -> list [MCPTool ]:
94
+ """Apply the tool filter to the list of tools."""
95
+ if self .tool_filter is None :
96
+ return tools
97
+
98
+ # Handle static tool filter
99
+ if isinstance (self .tool_filter , dict ):
100
+ static_filter : ToolFilterStatic = self .tool_filter
101
+ filtered_tools = tools
102
+
103
+ # Apply allowed_tool_names filter (whitelist)
104
+ if "allowed_tool_names" in static_filter :
105
+ allowed_names = static_filter ["allowed_tool_names" ]
106
+ filtered_tools = [t for t in filtered_tools if t .name in allowed_names ]
107
+
108
+ # Apply blocked_tool_names filter (blacklist)
109
+ if "blocked_tool_names" in static_filter :
110
+ blocked_names = static_filter ["blocked_tool_names" ]
111
+ filtered_tools = [t for t in filtered_tools if t .name not in blocked_names ]
112
+
113
+ return filtered_tools
114
+
115
+ # Handle callable tool filter
116
+ # For now, we can't support callable filters because we don't have access to
117
+ # run context and agent in the current list_tools signature.
118
+ # This could be enhanced in the future by modifying the call chain.
119
+ else :
120
+ raise NotImplementedError (
121
+ "Callable tool filters are not yet supported. Please use ToolFilterStatic "
122
+ "with 'allowed_tool_names' and/or 'blocked_tool_names' for now."
123
+ )
96
124
97
125
@abc .abstractmethod
98
126
def create_streams (
@@ -159,12 +187,10 @@ async def list_tools(self) -> list[MCPTool]:
159
187
self ._tools_list = (await self .session .list_tools ()).tools
160
188
tools = self ._tools_list
161
189
162
- # Filter tools based on allowed and excluded tools
190
+ # Filter tools based on tool_filter
163
191
filtered_tools = tools
164
- if self .allowed_tools is not None :
165
- filtered_tools = [t for t in filtered_tools if t .name in self .allowed_tools ]
166
- if self .excluded_tools is not None :
167
- filtered_tools = [t for t in filtered_tools if t .name not in self .excluded_tools ]
192
+ if self .tool_filter is not None :
193
+ filtered_tools = self ._apply_tool_filter (filtered_tools )
168
194
return filtered_tools
169
195
170
196
async def call_tool (self , tool_name : str , arguments : dict [str , Any ] | None ) -> CallToolResult :
@@ -226,8 +252,7 @@ def __init__(
226
252
cache_tools_list : bool = False ,
227
253
name : str | None = None ,
228
254
client_session_timeout_seconds : float | None = 5 ,
229
- allowed_tools : list [str ] | None = None ,
230
- excluded_tools : list [str ] | None = None ,
255
+ tool_filter : ToolFilter = None ,
231
256
):
232
257
"""Create a new MCP server based on the stdio transport.
233
258
@@ -245,14 +270,12 @@ def __init__(
245
270
name: A readable name for the server. If not provided, we'll create one from the
246
271
command.
247
272
client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
248
- allowed_tools: Optional list of tool names to allow (whitelist).
249
- excluded_tools: Optional list of tool names to exclude (blacklist).
273
+ tool_filter: The tool filter to use for filtering tools.
250
274
"""
251
275
super ().__init__ (
252
276
cache_tools_list ,
253
277
client_session_timeout_seconds ,
254
- allowed_tools ,
255
- excluded_tools ,
278
+ tool_filter ,
256
279
)
257
280
258
281
self .params = StdioServerParameters (
@@ -312,8 +335,7 @@ def __init__(
312
335
cache_tools_list : bool = False ,
313
336
name : str | None = None ,
314
337
client_session_timeout_seconds : float | None = 5 ,
315
- allowed_tools : list [str ] | None = None ,
316
- excluded_tools : list [str ] | None = None ,
338
+ tool_filter : ToolFilter = None ,
317
339
):
318
340
"""Create a new MCP server based on the HTTP with SSE transport.
319
341
@@ -333,14 +355,12 @@ def __init__(
333
355
URL.
334
356
335
357
client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
336
- allowed_tools: Optional list of tool names to allow (whitelist).
337
- excluded_tools: Optional list of tool names to exclude (blacklist).
358
+ tool_filter: The tool filter to use for filtering tools.
338
359
"""
339
360
super ().__init__ (
340
361
cache_tools_list ,
341
362
client_session_timeout_seconds ,
342
- allowed_tools ,
343
- excluded_tools ,
363
+ tool_filter ,
344
364
)
345
365
346
366
self .params = params
@@ -400,8 +420,7 @@ def __init__(
400
420
cache_tools_list : bool = False ,
401
421
name : str | None = None ,
402
422
client_session_timeout_seconds : float | None = 5 ,
403
- allowed_tools : list [str ] | None = None ,
404
- excluded_tools : list [str ] | None = None ,
423
+ tool_filter : ToolFilter = None ,
405
424
):
406
425
"""Create a new MCP server based on the Streamable HTTP transport.
407
426
@@ -422,14 +441,12 @@ def __init__(
422
441
URL.
423
442
424
443
client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
425
- allowed_tools: Optional list of tool names to allow (whitelist).
426
- excluded_tools: Optional list of tool names to exclude (blacklist).
444
+ tool_filter: The tool filter to use for filtering tools.
427
445
"""
428
446
super ().__init__ (
429
447
cache_tools_list ,
430
448
client_session_timeout_seconds ,
431
- allowed_tools ,
432
- excluded_tools ,
449
+ tool_filter ,
433
450
)
434
451
435
452
self .params = params
0 commit comments