Skip to content

add authorizer plugin to enable fine grained authorization checks on … #1032

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
138 changes: 138 additions & 0 deletions src/mcp/server/fastmcp/authorizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
from __future__ import annotations

import abc
from typing import TYPE_CHECKING, Any

from pydantic import AnyUrl
from starlette.requests import Request

from mcp.server.session import ServerSession

if TYPE_CHECKING:
from mcp.server.fastmcp.server import Context


class Authorizer:
__metaclass__ = abc.ABCMeta

@abc.abstractmethod
def permit_get_tool(self, name: str, context: Context[ServerSession, object, Request] | None = None) -> bool:
"""Check if the specified tool can be retrieved from the associated mcp server"""
return False

@abc.abstractmethod
def permit_list_tool(
self,
name: str,
context: Context[ServerSession, object, Request] | None = None,
) -> bool:
"""Check if the specified tool can be listed from the associated mcp server"""
return False

@abc.abstractmethod
def permit_call_tool(
self,
name: str,
arguments: dict[str, Any],
context: Context[ServerSession, object, Request] | None = None,
) -> bool:
"""Check if the specified tool can be called from the associated mcp server"""
return False

@abc.abstractmethod
def permit_get_resource(
self, resource: AnyUrl | str, context: Context[ServerSession, object, Request] | None = None
) -> bool:
"""Check if the specified resource can be retrieved from the associated mcp server"""
return False

@abc.abstractmethod
def permit_create_resource(
self, uri: str, params: dict[str, Any], context: Context[ServerSession, object, Request] | None = None
) -> bool:
"""Check if the specified resource can be created on the associated mcp server"""
return False

@abc.abstractmethod
def permit_list_resource(
self, resource: AnyUrl | str, context: Context[ServerSession, object, Request] | None = None
) -> bool:
"""Check if the specified resource can be listed from the associated mcp server"""
return False

@abc.abstractmethod
def permit_list_template(
self, resource: AnyUrl | str, context: Context[ServerSession, object, Request] | None = None
) -> bool:
"""Check if the specified template can be listed from the associated mcp server"""
return False

@abc.abstractmethod
def permit_get_prompt(self, name: str, context: Context[ServerSession, object, Request] | None = None) -> bool:
"""Check if the specified prompt can be retrieved from the associated mcp server"""
return False

@abc.abstractmethod
def permit_list_prompt(self, name: str, context: Context[ServerSession, object, Request] | None = None) -> bool:
"""Check if the specified prompt can be listed from the associated mcp server"""
return False

@abc.abstractmethod
def permit_render_prompt(
self,
name: str,
arguments: dict[str, Any] | None = None,
context: Context[ServerSession, object, Request] | None = None,
) -> bool:
"""Check if the specified prompt can be rendered from the associated mcp server"""
return False


class AllowAllAuthorizer(Authorizer):
def permit_get_tool(self, name: str, context: Context[ServerSession, object, Request] | None = None) -> bool:
return True

def permit_list_tool(self, name: str, context: Context[ServerSession, object, Request] | None = None) -> bool:
return True

def permit_call_tool(
self,
name: str,
arguments: dict[str, Any],
context: Context[ServerSession, object, Request] | None = None,
) -> bool:
return True

def permit_get_resource(
self, resource: AnyUrl | str, context: Context[ServerSession, object, Request] | None = None
) -> bool:
return True

def permit_create_resource(
self, uri: str, params: dict[str, Any], context: Context[ServerSession, object, Request] | None = None
) -> bool:
return True

def permit_list_resource(
self, resource: AnyUrl | str, context: Context[ServerSession, object, Request] | None = None
) -> bool:
return True

def permit_list_template(
self, resource: AnyUrl | str, context: Context[ServerSession, object, Request] | None = None
) -> bool:
return True

def permit_get_prompt(self, name: str, context: Context[ServerSession, object, Request] | None = None) -> bool:
return True

def permit_list_prompt(self, name: str, context: Context[ServerSession, object, Request] | None = None) -> bool:
return True

def permit_render_prompt(
self,
name: str,
arguments: dict[str, Any] | None = None,
context: Context[ServerSession, object, Request] | None = None,
) -> bool:
return True
42 changes: 33 additions & 9 deletions src/mcp/server/fastmcp/prompts/manager.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,44 @@
"""Prompt management functionality."""

from typing import Any
from __future__ import annotations as _annotations

from typing import TYPE_CHECKING, Any

from starlette.requests import Request

from mcp.server.fastmcp.authorizer import AllowAllAuthorizer, Authorizer
from mcp.server.fastmcp.prompts.base import Message, Prompt
from mcp.server.fastmcp.utilities.logging import get_logger
from mcp.server.session import ServerSession

if TYPE_CHECKING:
from mcp.server.fastmcp.server import Context

logger = get_logger(__name__)


class PromptManager:
"""Manages FastMCP prompts."""

def __init__(self, warn_on_duplicate_prompts: bool = True):
def __init__(
self,
warn_on_duplicate_prompts: bool = True,
authorizer: Authorizer = AllowAllAuthorizer(),
):
self._prompts: dict[str, Prompt] = {}
self._authorizer = authorizer
self.warn_on_duplicate_prompts = warn_on_duplicate_prompts

def get_prompt(self, name: str) -> Prompt | None:
def get_prompt(self, name: str, context: Context[ServerSession, object, Request] | None = None) -> Prompt | None:
"""Get prompt by name."""
return self._prompts.get(name)
if self._authorizer.permit_get_prompt(name, context):
return self._prompts.get(name)
else:
return None

def list_prompts(self) -> list[Prompt]:
def list_prompts(self, context: Context[ServerSession, object, Request] | None = None) -> list[Prompt]:
"""List all registered prompts."""
return list(self._prompts.values())
return [prompt for name, prompt in self._prompts.items() if self._authorizer.permit_list_prompt(name, context)]

def add_prompt(
self,
Expand All @@ -39,10 +56,17 @@ def add_prompt(
self._prompts[prompt.name] = prompt
return prompt

async def render_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> list[Message]:
async def render_prompt(
self,
name: str,
arguments: dict[str, Any] | None = None,
context: Context[ServerSession, object, Request] | None = None,
) -> list[Message]:
"""Render a prompt by name with arguments."""
prompt = self.get_prompt(name)
if not prompt:
raise ValueError(f"Unknown prompt: {name}")

return await prompt.render(arguments)
if self._authorizer.permit_render_prompt(name, arguments, context):
return await prompt.render(arguments)
else:
raise ValueError(f"Unknown prompt: {name}")
43 changes: 34 additions & 9 deletions src/mcp/server/fastmcp/resources/resource_manager.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,37 @@
"""Resource manager functionality."""

from __future__ import annotations as _annotations

from collections.abc import Callable
from typing import Any
from typing import TYPE_CHECKING, Any

from pydantic import AnyUrl
from starlette.requests import Request

from mcp.server.fastmcp.authorizer import AllowAllAuthorizer, Authorizer
from mcp.server.fastmcp.resources.base import Resource
from mcp.server.fastmcp.resources.templates import ResourceTemplate
from mcp.server.fastmcp.utilities.logging import get_logger
from mcp.server.session import ServerSession

if TYPE_CHECKING:
from mcp.server.fastmcp.server import Context

logger = get_logger(__name__)


class ResourceManager:
"""Manages FastMCP resources."""

def __init__(self, warn_on_duplicate_resources: bool = True):
def __init__(
self,
warn_on_duplicate_resources: bool = True,
authorizer: Authorizer = AllowAllAuthorizer(),
):
self._resources: dict[str, Resource] = {}
self._templates: dict[str, ResourceTemplate] = {}
self.warn_on_duplicate_resources = warn_on_duplicate_resources
self._authorizer = authorizer

def add_resource(self, resource: Resource) -> Resource:
"""Add a resource to the manager.
Expand Down Expand Up @@ -67,31 +80,43 @@ def add_template(
self._templates[template.uri_template] = template
return template

async def get_resource(self, uri: AnyUrl | str) -> Resource | None:
async def get_resource(
self, uri: AnyUrl | str, context: Context[ServerSession, object, Request] | None = None
) -> Resource | None:
"""Get resource by URI, checking concrete resources first, then templates."""
uri_str = str(uri)
logger.debug("Getting resource", extra={"uri": uri_str})

# First check concrete resources
if resource := self._resources.get(uri_str):
return resource
if self._authorizer.permit_get_resource(uri_str, context):
return resource
else:
raise ValueError(f"Unknown resource: {uri}")

# Then check templates
for template in self._templates.values():
if params := template.matches(uri_str):
try:
return await template.create_resource(uri_str, params)
if self._authorizer.permit_create_resource(uri_str, params):
return await template.create_resource(uri_str, params)
else:
raise ValueError(f"Unknown resource: {uri}")
except Exception as e:
raise ValueError(f"Error creating resource from template: {e}")

raise ValueError(f"Unknown resource: {uri}")

def list_resources(self) -> list[Resource]:
def list_resources(self, context: Context[ServerSession, object, Request] | None = None) -> list[Resource]:
"""List all registered resources."""
logger.debug("Listing resources", extra={"count": len(self._resources)})
return list(self._resources.values())
return [
resource for uri, resource in self._resources.items() if self._authorizer.permit_list_resource(uri, context)
]

def list_templates(self) -> list[ResourceTemplate]:
def list_templates(self, context: Context[ServerSession, object, Request] | None = None) -> list[ResourceTemplate]:
"""List all registered templates."""
logger.debug("Listing templates", extra={"count": len(self._templates)})
return list(self._templates.values())
return [
template for uri, template in self._templates.items() if self._authorizer.permit_list_template(uri, context)
]
Loading
Loading