diff --git a/src/mcp/server/fastmcp/authorizer.py b/src/mcp/server/fastmcp/authorizer.py new file mode 100644 index 000000000..3fa4932e2 --- /dev/null +++ b/src/mcp/server/fastmcp/authorizer.py @@ -0,0 +1,138 @@ +from __future__ import annotations + +import abc +from typing import TYPE_CHECKING, Any + +from pydantic import AnyUrl +from starlette.requests import Request + +from mcp.server.session import ServerSession + +if TYPE_CHECKING: + from mcp.server.fastmcp.server import Context + + +class Authorizer: + __metaclass__ = abc.ABCMeta + + @abc.abstractmethod + def permit_get_tool(self, name: str, context: Context[ServerSession, object, Request] | None = None) -> bool: + """Check if the specified tool can be retrieved from the associated mcp server""" + return False + + @abc.abstractmethod + def permit_list_tool( + self, + name: str, + context: Context[ServerSession, object, Request] | None = None, + ) -> bool: + """Check if the specified tool can be listed from the associated mcp server""" + return False + + @abc.abstractmethod + def permit_call_tool( + self, + name: str, + arguments: dict[str, Any], + context: Context[ServerSession, object, Request] | None = None, + ) -> bool: + """Check if the specified tool can be called from the associated mcp server""" + return False + + @abc.abstractmethod + def permit_get_resource( + self, resource: AnyUrl | str, context: Context[ServerSession, object, Request] | None = None + ) -> bool: + """Check if the specified resource can be retrieved from the associated mcp server""" + return False + + @abc.abstractmethod + def permit_create_resource( + self, uri: str, params: dict[str, Any], context: Context[ServerSession, object, Request] | None = None + ) -> bool: + """Check if the specified resource can be created on the associated mcp server""" + return False + + @abc.abstractmethod + def permit_list_resource( + self, resource: AnyUrl | str, context: Context[ServerSession, object, Request] | None = None + ) -> bool: + """Check if the specified resource can be listed from the associated mcp server""" + return False + + @abc.abstractmethod + def permit_list_template( + self, resource: AnyUrl | str, context: Context[ServerSession, object, Request] | None = None + ) -> bool: + """Check if the specified template can be listed from the associated mcp server""" + return False + + @abc.abstractmethod + def permit_get_prompt(self, name: str, context: Context[ServerSession, object, Request] | None = None) -> bool: + """Check if the specified prompt can be retrieved from the associated mcp server""" + return False + + @abc.abstractmethod + def permit_list_prompt(self, name: str, context: Context[ServerSession, object, Request] | None = None) -> bool: + """Check if the specified prompt can be listed from the associated mcp server""" + return False + + @abc.abstractmethod + def permit_render_prompt( + self, + name: str, + arguments: dict[str, Any] | None = None, + context: Context[ServerSession, object, Request] | None = None, + ) -> bool: + """Check if the specified prompt can be rendered from the associated mcp server""" + return False + + +class AllowAllAuthorizer(Authorizer): + def permit_get_tool(self, name: str, context: Context[ServerSession, object, Request] | None = None) -> bool: + return True + + def permit_list_tool(self, name: str, context: Context[ServerSession, object, Request] | None = None) -> bool: + return True + + def permit_call_tool( + self, + name: str, + arguments: dict[str, Any], + context: Context[ServerSession, object, Request] | None = None, + ) -> bool: + return True + + def permit_get_resource( + self, resource: AnyUrl | str, context: Context[ServerSession, object, Request] | None = None + ) -> bool: + return True + + def permit_create_resource( + self, uri: str, params: dict[str, Any], context: Context[ServerSession, object, Request] | None = None + ) -> bool: + return True + + def permit_list_resource( + self, resource: AnyUrl | str, context: Context[ServerSession, object, Request] | None = None + ) -> bool: + return True + + def permit_list_template( + self, resource: AnyUrl | str, context: Context[ServerSession, object, Request] | None = None + ) -> bool: + return True + + def permit_get_prompt(self, name: str, context: Context[ServerSession, object, Request] | None = None) -> bool: + return True + + def permit_list_prompt(self, name: str, context: Context[ServerSession, object, Request] | None = None) -> bool: + return True + + def permit_render_prompt( + self, + name: str, + arguments: dict[str, Any] | None = None, + context: Context[ServerSession, object, Request] | None = None, + ) -> bool: + return True diff --git a/src/mcp/server/fastmcp/prompts/manager.py b/src/mcp/server/fastmcp/prompts/manager.py index 6b01d91cd..fafe493b5 100644 --- a/src/mcp/server/fastmcp/prompts/manager.py +++ b/src/mcp/server/fastmcp/prompts/manager.py @@ -1,9 +1,18 @@ """Prompt management functionality.""" -from typing import Any +from __future__ import annotations as _annotations +from typing import TYPE_CHECKING, Any + +from starlette.requests import Request + +from mcp.server.fastmcp.authorizer import AllowAllAuthorizer, Authorizer from mcp.server.fastmcp.prompts.base import Message, Prompt from mcp.server.fastmcp.utilities.logging import get_logger +from mcp.server.session import ServerSession + +if TYPE_CHECKING: + from mcp.server.fastmcp.server import Context logger = get_logger(__name__) @@ -11,17 +20,25 @@ class PromptManager: """Manages FastMCP prompts.""" - def __init__(self, warn_on_duplicate_prompts: bool = True): + def __init__( + self, + warn_on_duplicate_prompts: bool = True, + authorizer: Authorizer = AllowAllAuthorizer(), + ): self._prompts: dict[str, Prompt] = {} + self._authorizer = authorizer self.warn_on_duplicate_prompts = warn_on_duplicate_prompts - def get_prompt(self, name: str) -> Prompt | None: + def get_prompt(self, name: str, context: Context[ServerSession, object, Request] | None = None) -> Prompt | None: """Get prompt by name.""" - return self._prompts.get(name) + if self._authorizer.permit_get_prompt(name, context): + return self._prompts.get(name) + else: + return None - def list_prompts(self) -> list[Prompt]: + def list_prompts(self, context: Context[ServerSession, object, Request] | None = None) -> list[Prompt]: """List all registered prompts.""" - return list(self._prompts.values()) + return [prompt for name, prompt in self._prompts.items() if self._authorizer.permit_list_prompt(name, context)] def add_prompt( self, @@ -39,10 +56,17 @@ def add_prompt( self._prompts[prompt.name] = prompt return prompt - async def render_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> list[Message]: + async def render_prompt( + self, + name: str, + arguments: dict[str, Any] | None = None, + context: Context[ServerSession, object, Request] | None = None, + ) -> list[Message]: """Render a prompt by name with arguments.""" prompt = self.get_prompt(name) if not prompt: raise ValueError(f"Unknown prompt: {name}") - - return await prompt.render(arguments) + if self._authorizer.permit_render_prompt(name, arguments, context): + return await prompt.render(arguments) + else: + raise ValueError(f"Unknown prompt: {name}") diff --git a/src/mcp/server/fastmcp/resources/resource_manager.py b/src/mcp/server/fastmcp/resources/resource_manager.py index 35e4ec04d..ea65ce81b 100644 --- a/src/mcp/server/fastmcp/resources/resource_manager.py +++ b/src/mcp/server/fastmcp/resources/resource_manager.py @@ -1,13 +1,21 @@ """Resource manager functionality.""" +from __future__ import annotations as _annotations + from collections.abc import Callable -from typing import Any +from typing import TYPE_CHECKING, Any from pydantic import AnyUrl +from starlette.requests import Request +from mcp.server.fastmcp.authorizer import AllowAllAuthorizer, Authorizer from mcp.server.fastmcp.resources.base import Resource from mcp.server.fastmcp.resources.templates import ResourceTemplate from mcp.server.fastmcp.utilities.logging import get_logger +from mcp.server.session import ServerSession + +if TYPE_CHECKING: + from mcp.server.fastmcp.server import Context logger = get_logger(__name__) @@ -15,10 +23,15 @@ class ResourceManager: """Manages FastMCP resources.""" - def __init__(self, warn_on_duplicate_resources: bool = True): + def __init__( + self, + warn_on_duplicate_resources: bool = True, + authorizer: Authorizer = AllowAllAuthorizer(), + ): self._resources: dict[str, Resource] = {} self._templates: dict[str, ResourceTemplate] = {} self.warn_on_duplicate_resources = warn_on_duplicate_resources + self._authorizer = authorizer def add_resource(self, resource: Resource) -> Resource: """Add a resource to the manager. @@ -67,31 +80,43 @@ def add_template( self._templates[template.uri_template] = template return template - async def get_resource(self, uri: AnyUrl | str) -> Resource | None: + async def get_resource( + self, uri: AnyUrl | str, context: Context[ServerSession, object, Request] | None = None + ) -> Resource | None: """Get resource by URI, checking concrete resources first, then templates.""" uri_str = str(uri) logger.debug("Getting resource", extra={"uri": uri_str}) # First check concrete resources if resource := self._resources.get(uri_str): - return resource + if self._authorizer.permit_get_resource(uri_str, context): + return resource + else: + raise ValueError(f"Unknown resource: {uri}") # Then check templates for template in self._templates.values(): if params := template.matches(uri_str): try: - return await template.create_resource(uri_str, params) + if self._authorizer.permit_create_resource(uri_str, params): + return await template.create_resource(uri_str, params) + else: + raise ValueError(f"Unknown resource: {uri}") except Exception as e: raise ValueError(f"Error creating resource from template: {e}") raise ValueError(f"Unknown resource: {uri}") - def list_resources(self) -> list[Resource]: + def list_resources(self, context: Context[ServerSession, object, Request] | None = None) -> list[Resource]: """List all registered resources.""" logger.debug("Listing resources", extra={"count": len(self._resources)}) - return list(self._resources.values()) + return [ + resource for uri, resource in self._resources.items() if self._authorizer.permit_list_resource(uri, context) + ] - def list_templates(self) -> list[ResourceTemplate]: + def list_templates(self, context: Context[ServerSession, object, Request] | None = None) -> list[ResourceTemplate]: """List all registered templates.""" logger.debug("Listing templates", extra={"count": len(self._templates)}) - return list(self._templates.values()) + return [ + template for uri, template in self._templates.items() if self._authorizer.permit_list_template(uri, context) + ] diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 956a8aa78..1922fa454 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -32,6 +32,7 @@ from mcp.server.auth.provider import OAuthAuthorizationServerProvider, ProviderTokenVerifier, TokenVerifier from mcp.server.auth.settings import AuthSettings from mcp.server.elicitation import ElicitationResult, ElicitSchemaModelT, elicit_with_validation +from mcp.server.fastmcp.authorizer import AllowAllAuthorizer, Authorizer from mcp.server.fastmcp.exceptions import ResourceError from mcp.server.fastmcp.prompts import Prompt, PromptManager from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceManager @@ -117,6 +118,8 @@ class Settings(BaseSettings, Generic[LifespanResultT]): # Transport security settings (DNS rebinding protection) transport_security: TransportSecuritySettings | None = None + authorizer: Authorizer = AllowAllAuthorizer() + def lifespan_wrapper( app: FastMCP, @@ -149,9 +152,19 @@ def __init__( instructions=instructions, lifespan=(lifespan_wrapper(self, self.settings.lifespan) if self.settings.lifespan else default_lifespan), ) - self._tool_manager = ToolManager(tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools) - self._resource_manager = ResourceManager(warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources) - self._prompt_manager = PromptManager(warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts) + self._tool_manager = ToolManager( + tools=tools, + warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools, + authorizer=self.settings.authorizer, + ) + self._resource_manager = ResourceManager( + warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources, + authorizer=self.settings.authorizer, + ) + self._prompt_manager = PromptManager( + warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts, + authorizer=self.settings.authorizer, + ) # Validate auth configuration if self.settings.auth is not None: if auth_server_provider and token_verifier: @@ -244,7 +257,8 @@ def _setup_handlers(self) -> None: async def list_tools(self) -> list[MCPTool]: """List all available tools.""" - tools = self._tool_manager.list_tools() + context = self.get_context() + tools = self._tool_manager.list_tools(context) return [ MCPTool( name=info.name, @@ -275,8 +289,8 @@ async def call_tool(self, name: str, arguments: dict[str, Any]) -> Sequence[Cont async def list_resources(self) -> list[MCPResource]: """List all available resources.""" - - resources = self._resource_manager.list_resources() + context = self.get_context() + resources = self._resource_manager.list_resources(context) return [ MCPResource( uri=resource.uri, @@ -289,7 +303,8 @@ async def list_resources(self) -> list[MCPResource]: ] async def list_resource_templates(self) -> list[MCPResourceTemplate]: - templates = self._resource_manager.list_templates() + context = self.get_context() + templates = self._resource_manager.list_templates(context) return [ MCPResourceTemplate( uriTemplate=template.uri_template, @@ -302,8 +317,8 @@ async def list_resource_templates(self) -> list[MCPResourceTemplate]: async def read_resource(self, uri: AnyUrl | str) -> Iterable[ReadResourceContents]: """Read a resource by URI.""" - - resource = await self._resource_manager.get_resource(uri) + context = self.get_context() + resource = await self._resource_manager.get_resource(uri, context) if not resource: raise ResourceError(f"Unknown resource: {uri}") @@ -934,9 +949,9 @@ async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> lifespan=lambda app: self.session_manager.run(), ) - async def list_prompts(self) -> list[MCPPrompt]: + async def list_prompts(self, context: Context[ServerSession, object, Request] | None = None) -> list[MCPPrompt]: """List all available prompts.""" - prompts = self._prompt_manager.list_prompts() + prompts = self._prompt_manager.list_prompts(context) return [ MCPPrompt( name=prompt.name, @@ -954,10 +969,15 @@ async def list_prompts(self) -> list[MCPPrompt]: for prompt in prompts ] - async def get_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> GetPromptResult: + async def get_prompt( + self, + name: str, + arguments: dict[str, Any] | None = None, + context: Context[ServerSession, object, Request] | None = None, + ) -> GetPromptResult: """Get a prompt by name with arguments.""" try: - messages = await self._prompt_manager.render_prompt(name, arguments) + messages = await self._prompt_manager.render_prompt(name, arguments, context) return GetPromptResult(messages=pydantic_core.to_jsonable_python(messages)) except Exception as e: diff --git a/src/mcp/server/fastmcp/tools/tool_manager.py b/src/mcp/server/fastmcp/tools/tool_manager.py index bfa8b2382..95d2584eb 100644 --- a/src/mcp/server/fastmcp/tools/tool_manager.py +++ b/src/mcp/server/fastmcp/tools/tool_manager.py @@ -3,15 +3,17 @@ from collections.abc import Callable from typing import TYPE_CHECKING, Any +from starlette.requests import Request + +from mcp.server.fastmcp.authorizer import AllowAllAuthorizer, Authorizer from mcp.server.fastmcp.exceptions import ToolError from mcp.server.fastmcp.tools.base import Tool from mcp.server.fastmcp.utilities.logging import get_logger -from mcp.shared.context import LifespanContextT, RequestT +from mcp.server.session import ServerSession from mcp.types import ToolAnnotations if TYPE_CHECKING: from mcp.server.fastmcp.server import Context - from mcp.server.session import ServerSessionT logger = get_logger(__name__) @@ -24,6 +26,7 @@ def __init__( warn_on_duplicate_tools: bool = True, *, tools: list[Tool] | None = None, + authorizer: Authorizer = AllowAllAuthorizer(), ): self._tools: dict[str, Tool] = {} if tools is not None: @@ -32,15 +35,19 @@ def __init__( logger.warning(f"Tool already exists: {tool.name}") self._tools[tool.name] = tool - self.warn_on_duplicate_tools = warn_on_duplicate_tools + self.warn_on_duplicate_tools = (warn_on_duplicate_tools,) + self._authorizer = authorizer - def get_tool(self, name: str) -> Tool | None: + def get_tool(self, name: str, context: Context[ServerSession, object, Request] | None = None) -> Tool | None: """Get tool by name.""" - return self._tools.get(name) + if self._authorizer.permit_get_tool(name, context): + return self._tools.get(name) + else: + return None - def list_tools(self) -> list[Tool]: + def list_tools(self, context: Context[ServerSession, object, Request] | None = None) -> list[Tool]: """List all registered tools.""" - return list(self._tools.values()) + return [tool for name, tool in self._tools.items() if self._authorizer.permit_list_tool(name, context)] def add_tool( self, @@ -72,12 +79,12 @@ async def call_tool( self, name: str, arguments: dict[str, Any], - context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None, + context: Context[ServerSession, object, Request] | None = None, convert_result: bool = False, ) -> Any: """Call a tool by name with arguments.""" - tool = self.get_tool(name) - if not tool: + tool = self._tools.get(name) + if not tool or not self._authorizer.permit_call_tool(name, arguments, context): raise ToolError(f"Unknown tool: {name}") return await tool.run(arguments, context=context, convert_result=convert_result) diff --git a/tests/server/fastmcp/test_tool_manager.py b/tests/server/fastmcp/test_tool_manager.py index 4b2052da5..65f002fab 100644 --- a/tests/server/fastmcp/test_tool_manager.py +++ b/tests/server/fastmcp/test_tool_manager.py @@ -7,6 +7,7 @@ from pydantic import BaseModel from mcp.server.fastmcp import Context, FastMCP +from mcp.server.fastmcp.authorizer import Authorizer from mcp.server.fastmcp.exceptions import ToolError from mcp.server.fastmcp.tools import Tool, ToolManager from mcp.server.fastmcp.utilities.func_metadata import ArgModelBase, FuncMetadata @@ -173,7 +174,7 @@ def f(x: int) -> int: manager = ToolManager() manager.add_tool(f) - manager.warn_on_duplicate_tools = False + manager.warn_on_duplicate_tools = False # type: ignore with caplog.at_level(logging.WARNING): manager.add_tool(f) assert "Tool already exists: f" not in caplog.text @@ -313,6 +314,30 @@ def name_shrimp(tank: MyShrimpTank, ctx: Context) -> list[str]: ) assert result == ["rex", "gertrude"] + @pytest.mark.anyio + async def test_call_tool_not_permitted(self): + async def double(n: int) -> int: + """Double a number.""" + return n * 2 + + class TestAuthorizer(Authorizer): + allow: bool = True + + def permit_list_tool(self, name, context=None): + return self.allow + + def permit_call_tool(self, name, arguments, context=None): + return self.allow + + authorizer = TestAuthorizer() + manager = ToolManager(authorizer=authorizer) + manager.add_tool(double) + result = await manager.call_tool("double", {"n": 5}) + assert result == 10 + authorizer.allow = False + with pytest.raises(ToolError, match="Unknown tool: double"): + await manager.call_tool("double", {"n": 5}) + class TestToolSchema: @pytest.mark.anyio