diff --git a/examples/servers/simple-auth-remote/README.md b/examples/servers/simple-auth-remote/README.md new file mode 100644 index 000000000..9906c4d36 --- /dev/null +++ b/examples/servers/simple-auth-remote/README.md @@ -0,0 +1,91 @@ +# Simple MCP Server with GitHub OAuth Authentication + +This is a simple example of an MCP server with GitHub OAuth authentication. It demonstrates the essential components needed for OAuth integration with just a single tool. + +This is just an example of a server that uses auth, an official GitHub mcp server is [here](https://github.com/github/github-mcp-server) + +## Overview + +This simple demo to show to set up a server with: +- GitHub OAuth2 authorization flow +- Single tool: `get_user_profile` to retrieve GitHub user information + + +## Prerequisites + +1. Create a GitHub OAuth App: + - Go to GitHub Settings > Developer settings > OAuth Apps > New OAuth App + - Application name: Any name (e.g., "Simple MCP Auth Demo") + - Homepage URL: `http://localhost:8000` + - Authorization callback URL: `http://localhost:8000/github/callback` + - Click "Register application" + - Note down your Client ID and Client Secret + +## Required Environment Variables + +You MUST set these environment variables before running the server: + +```bash +export MCP_GITHUB_GITHUB_CLIENT_ID="your_client_id_here" +export MCP_GITHUB_GITHUB_CLIENT_SECRET="your_client_secret_here" +``` + +The server will not start without these environment variables properly set. + + +## Running the Server + +```bash +# Set environment variables first (see above) + +# Run the server +uv run mcp-simple-auth +``` + +The server will start on `http://localhost:8000`. + +### Transport Options + +This server supports multiple transport protocols that can run on the same port: + +#### SSE (Server-Sent Events) - Default +```bash +uv run mcp-simple-auth +# or explicitly: +uv run mcp-simple-auth --transport sse +``` + +SSE transport provides endpoint: +- `/sse` + +#### Streamable HTTP +```bash +uv run mcp-simple-auth --transport streamable-http +``` + +Streamable HTTP transport provides endpoint: +- `/mcp` + + +This ensures backward compatibility without needing multiple server instances. When using SSE transport (`--transport sse`), only the `/sse` endpoint is available. + +## Available Tool + +### get_user_profile + +The only tool in this simple example. Returns the authenticated user's GitHub profile information. + +**Required scope**: `user` + +**Returns**: GitHub user profile data including username, email, bio, etc. + + +## Troubleshooting + +If the server fails to start, check: +1. Environment variables `MCP_GITHUB_GITHUB_CLIENT_ID` and `MCP_GITHUB_GITHUB_CLIENT_SECRET` are set +2. The GitHub OAuth app callback URL matches `http://localhost:8000/github/callback` +3. No other service is using port 8000 +4. The transport specified is valid (`sse` or `streamable-http`) + +You can use [Inspector](https://github.com/modelcontextprotocol/inspector) to test Auth \ No newline at end of file diff --git a/examples/servers/simple-auth-remote/mcp_simple_remote_auth/__init__.py b/examples/servers/simple-auth-remote/mcp_simple_remote_auth/__init__.py new file mode 100644 index 000000000..3e12b3183 --- /dev/null +++ b/examples/servers/simple-auth-remote/mcp_simple_remote_auth/__init__.py @@ -0,0 +1 @@ +"""Simple MCP server with GitHub OAuth authentication.""" diff --git a/examples/servers/simple-auth-remote/mcp_simple_remote_auth/__main__.py b/examples/servers/simple-auth-remote/mcp_simple_remote_auth/__main__.py new file mode 100644 index 000000000..41d1e0f34 --- /dev/null +++ b/examples/servers/simple-auth-remote/mcp_simple_remote_auth/__main__.py @@ -0,0 +1,7 @@ +"""Main entry point for simple MCP server with GitHub OAuth authentication.""" + +import sys + +from mcp_simple_remote_auth.server import main + +sys.exit(main()) # type: ignore[call-arg] diff --git a/examples/servers/simple-auth-remote/mcp_simple_remote_auth/server.py b/examples/servers/simple-auth-remote/mcp_simple_remote_auth/server.py new file mode 100644 index 000000000..53ab41a9e --- /dev/null +++ b/examples/servers/simple-auth-remote/mcp_simple_remote_auth/server.py @@ -0,0 +1,210 @@ +"""Simple MCP Server with GitHub OAuth Authentication.""" + +import logging +from typing import Any, Literal + +import click +import jwt +import requests +from pydantic import AnyHttpUrl +from pydantic_settings import BaseSettings, SettingsConfigDict + +from mcp.server.auth.provider import ( + AccessToken, + TokenValidator, +) +from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions +from mcp.server.fastmcp.server import FastMCP +from mcp.shared.auth import ProtectedResourceMetadata + +logger = logging.getLogger(__name__) + + +class TokenValidatorJWT(TokenValidator[AccessToken]): + def __init__(self, resource_metadata: ProtectedResourceMetadata): + self._resource_metadata = resource_metadata + + async def validate_token(self, token: str) -> AccessToken | None: + try: + return await self.decode_token(token) + except Exception as e: + logger.error(f"Token validation failed: {e}") + return None + + async def _get_jwks_uri(self, auth_server: str) -> str: + """Get the JWKS URI from the OIDC or OAuth well-known configuration. + + Args: + auth_server: The base URL of the authorization server + + Returns: + The JWKS URI + + Raises: + ValueError: If the JWKS URI cannot be found in either OIDC or OAuth + well-known configurations + requests.RequestException: If there's an error fetching the configuration + """ + well_known_paths = [ + "/.well-known/openid-configuration", # OIDC well-known + "/.well-known/oauth-authorization-server", # OAuth well-known + ] + + last_error = None + + for path in well_known_paths: + try: + config_url = f"https://{auth_server}{path}" + response = requests.get( + config_url, + timeout=10, # Add timeout to prevent hanging + headers={"Accept": "application/json"}, + ) + response.raise_for_status() # Raise an exception for bad status codes + config = response.json() + + # Try to get JWKS URI from the configuration + jwks_uri = config.get("jwks_uri") + if jwks_uri: + return jwks_uri + + except requests.RequestException as e: + last_error = e + logger.debug(f"Failed to fetch {path}: {e}") + continue + + # If we get here, we couldn't find a valid JWKS URI + error_msg = "Could not find jwks_uri in OIDC or OAuth well-known configurations" + logger.error(f"{error_msg}. Last error: {last_error}") + raise ValueError(error_msg) + + async def decode_token(self, token: str) -> AccessToken | None: + try: + auth_server = self._resource_metadata.authorization_servers[0] + jwks_uri = await self._get_jwks_uri(auth_server) + jwks_client = jwt.PyJWKClient(jwks_uri) + signing_key = jwks_client.get_signing_key_from_jwt(token) + + # Rest of your decode_token method remains the same + payload = jwt.decode( + token, + key=signing_key.key, + algorithms=["RS256"], + audience=self._resource_metadata.resource, + issuer=f"https://{auth_server}", + options={ + "verify_signature": True, + "verify_aud": True, + "verify_iss": True, + "verify_exp": True, + "verify_nbf": True, + "verify_iat": True, + }, + ) + + return AccessToken( + token=token, + client_id=payload["client_id"], + scopes=payload["scope"].split(" "), + expires_at=payload["exp"], + ) + except Exception as e: + logger.error(f"Token validation failed: {e}") + return None + + +class ServerSettings(BaseSettings): + """Settings for the simple GitHub MCP server.""" + + model_config = SettingsConfigDict(env_prefix="MCP_GITHUB_") + + # Server settings + host: str = "localhost" + port: int = 8000 + server_url: AnyHttpUrl = AnyHttpUrl("http://localhost:8000") + mcp_scope: str = "user" + + def __init__(self, **data): + """Initialize settings with values from environment variables. + + Note: github_client_id and github_client_secret are required but can be + loaded automatically from environment variables (MCP_GITHUB_GITHUB_CLIENT_ID + and MCP_GITHUB_GITHUB_CLIENT_SECRET) and don't need to be passed explicitly. + """ + super().__init__(**data) + + +def create_simple_mcp_server(settings: ServerSettings) -> FastMCP: + """Create a simple FastMCP server with GitHub OAuth.""" + + auth_settings = AuthSettings( + issuer_url=settings.server_url, + client_registration_options=ClientRegistrationOptions( + enabled=True, + valid_scopes=[settings.mcp_scope], + default_scopes=[settings.mcp_scope], + ), + required_scopes=[settings.mcp_scope], + ) + + app = FastMCP( + name="Simple GitHub MCP Server", + instructions="A simple MCP server with GitHub OAuth authentication", + host=settings.host, + port=settings.port, + debug=True, + auth=auth_settings, + token_validator=TokenValidatorJWT( + ProtectedResourceMetadata( + resource="asdasd", + authorization_servers=["https://auth.devramp.ai"], + scopes_supported=["user"], + ) + ), + protected_resource_metadata={ + "resource": "asdasd", + "authorization_servers": ["https://auth.devramp.ai"], + "scopes_supported": ["user"], + }, + ) + + @app.tool() + async def get_user_profile() -> dict[str, Any]: + """Get the authenticated user's GitHub profile information. + + This is the only tool in our simple example. It requires the 'user' scope. + """ + return {"user": "asdasd"} + + return app + + +@click.command() +@click.option("--port", default=8000, help="Port to listen on") +@click.option("--host", default="localhost", help="Host to bind to") +@click.option( + "--transport", + default="streamable-http", + type=click.Choice(["sse", "streamable-http"]), + help="Transport protocol to use ('sse' or 'streamable-http')", +) +def main(port: int, host: str, transport: Literal["sse", "streamable-http"]) -> int: + """Run the simple GitHub MCP server.""" + logging.basicConfig(level=logging.INFO) + + try: + # No hardcoded credentials - all from environment variables + settings = ServerSettings(host=host, port=port) + except ValueError as e: + logger.error( + "Failed to load settings. Make sure environment variables are set:" + ) + logger.error(" MCP_GITHUB_GITHUB_CLIENT_ID=") + logger.error(" MCP_GITHUB_GITHUB_CLIENT_SECRET=") + logger.error(f"Error: {e}") + return 1 + + mcp_server = create_simple_mcp_server(settings) + logger.info(f"Starting server with {transport} transport") + mcp_server.run(transport=transport) + return 0 diff --git a/examples/servers/simple-auth-remote/pyproject.toml b/examples/servers/simple-auth-remote/pyproject.toml new file mode 100644 index 000000000..82d3902b5 --- /dev/null +++ b/examples/servers/simple-auth-remote/pyproject.toml @@ -0,0 +1,31 @@ +[project] +name = "mcp-simple-remote-auth" +version = "0.1.0" +description = "A simple MCP server demonstrating OAuth authentication" +readme = "README.md" +requires-python = ">=3.10" +authors = [{ name = "Anthropic, PBC." }] +license = { text = "MIT" } +dependencies = [ + "anyio>=4.5", + "click>=8.1.0", + "httpx>=0.27", + "mcp", + "pydantic>=2.0", + "pydantic-settings>=2.5.2", + "sse-starlette>=1.6.1", + "uvicorn>=0.23.1; sys_platform != 'emscripten'", +] + +[project.scripts] +mcp-simple-remote-auth = "mcp_simple_remote_auth.server:main" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["mcp_simple_remote_auth"] + +[tool.uv] +dev-dependencies = ["pyright>=1.1.391", "pytest>=8.3.4", "ruff>=0.8.5"] \ No newline at end of file diff --git a/src/mcp/server/auth/middleware/bearer_auth.py b/src/mcp/server/auth/middleware/bearer_auth.py index 30b5e2ba6..177f1bc20 100644 --- a/src/mcp/server/auth/middleware/bearer_auth.py +++ b/src/mcp/server/auth/middleware/bearer_auth.py @@ -10,7 +10,11 @@ from starlette.requests import HTTPConnection from starlette.types import Receive, Scope, Send -from mcp.server.auth.provider import AccessToken, OAuthAuthorizationServerProvider +from mcp.server.auth.provider import ( + AccessToken, + OAuthAuthorizationServerProvider, + TokenValidator, +) class AuthenticatedUser(SimpleUser): @@ -21,6 +25,42 @@ def __init__(self, auth_info: AccessToken): self.access_token = auth_info self.scopes = auth_info.scopes +class JWTBearerTokenAuthBackend(AuthenticationBackend): + """ + Authentication backend that validates Bearer tokens. + """ + + def __init__( + self, + provider: TokenValidator[AccessToken], + ): + self.provider = provider + + async def authenticate(self, conn: HTTPConnection): + auth_header = next( + ( + conn.headers.get(key) + for key in conn.headers + if key.lower() == "authorization" + ), + None, + ) + if not auth_header or not auth_header.lower().startswith("bearer "): + return None + + token = auth_header[7:] # Remove "Bearer " prefix + + # Validate the token with the provider + auth_info = await self.provider.validate_token(token) + + + if not auth_info: + return None + + if auth_info.expires_at and auth_info.expires_at < int(time.time()): + return None + + return AuthCredentials(auth_info.scopes), AuthenticatedUser(auth_info) class BearerAuthBackend(AuthenticationBackend): """ diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index be1ac1dbc..153d3e15a 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -96,6 +96,13 @@ class TokenError(Exception): AccessTokenT = TypeVar("AccessTokenT", bound=AccessToken) + +class TokenValidator(BaseModel, Generic[AccessTokenT]): + async def validate_token(self, token: str) -> AccessTokenT | None: + ... + + + class OAuthAuthorizationServerProvider( Protocol, Generic[AuthorizationCodeT, RefreshTokenT, AccessTokenT] ): diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 3282baae6..4074a3986 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -5,10 +5,7 @@ import inspect import re from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence -from contextlib import ( - AbstractAsyncContextManager, - asynccontextmanager, -) +from contextlib import AbstractAsyncContextManager, asynccontextmanager from itertools import chain from typing import Any, Generic, Literal @@ -18,6 +15,7 @@ from pydantic.networks import AnyUrl from pydantic_settings import BaseSettings, SettingsConfigDict from starlette.applications import Starlette +from starlette.exceptions import HTTPException from starlette.middleware import Middleware from starlette.middleware.authentication import AuthenticationMiddleware from starlette.requests import Request @@ -28,16 +26,19 @@ from mcp.server.auth.middleware.auth_context import AuthContextMiddleware from mcp.server.auth.middleware.bearer_auth import ( BearerAuthBackend, + JWTBearerTokenAuthBackend, RequireAuthMiddleware, ) -from mcp.server.auth.provider import OAuthAuthorizationServerProvider -from mcp.server.auth.settings import ( - AuthSettings, +from mcp.server.auth.provider import ( + AccessToken, + OAuthAuthorizationServerProvider, + TokenValidator, ) +from mcp.server.auth.settings import AuthSettings from mcp.server.fastmcp.exceptions import ResourceError from mcp.server.fastmcp.prompts import Prompt, PromptManager from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceManager -from mcp.server.fastmcp.tools import Tool, ToolManager +from mcp.server.fastmcp.tools import ToolManager from mcp.server.fastmcp.utilities.logging import configure_logging, get_logger from mcp.server.fastmcp.utilities.types import Image from mcp.server.lowlevel.helper_types import ReadResourceContents @@ -49,6 +50,7 @@ from mcp.server.stdio import stdio_server from mcp.server.streamable_http import EventStore from mcp.server.streamable_http_manager import StreamableHTTPSessionManager +from mcp.shared.auth import ProtectedResourceMetadata from mcp.shared.context import LifespanContextT, RequestContext from mcp.types import ( AnyFunction, @@ -87,7 +89,7 @@ class Settings(BaseSettings, Generic[LifespanResultT]): log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO" # HTTP settings - host: str = "127.0.0.1" + host: str = "0.0.0.0" port: int = 8000 mount_path: str = "/" # Mount path (e.g. "/github", defaults to root path) sse_path: str = "/sse" @@ -138,14 +140,23 @@ def __init__( self, name: str | None = None, instructions: str | None = None, - auth_server_provider: OAuthAuthorizationServerProvider[Any, Any, Any] - | None = None, + auth_server_details: dict[str, Any] | None = None, + auth_server_provider: ( + OAuthAuthorizationServerProvider[Any, Any, Any] | None + ) = None, + protected_resource_metadata: dict[str, Any] | None = None, event_store: EventStore | None = None, - *, - tools: list[Tool] | None = None, + token_validator: TokenValidator[AccessToken] | None = None, **settings: Any, ): self.settings = Settings(**settings) + self._auth_server_details = auth_server_details + self._protected_resource_metadata = None + if protected_resource_metadata: + self._protected_resource_metadata = ProtectedResourceMetadata( + **protected_resource_metadata + ) + self._token_validator = token_validator self._mcp_server = MCPServer( name=name or "FastMCP", @@ -157,7 +168,7 @@ def __init__( ), ) self._tool_manager = ToolManager( - tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools + warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools ) self._resource_manager = ResourceManager( warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources @@ -165,7 +176,10 @@ def __init__( self._prompt_manager = PromptManager( warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts ) - if (self.settings.auth is not None) != (auth_server_provider is not None): + # don't do this check if protected_resource_metadata is not None + if (self.settings.auth is not None) != ( + auth_server_provider is not None + ) and self._protected_resource_metadata is None: # TODO: after we support separate authorization servers (see # https://github.com/modelcontextprotocol/modelcontextprotocol/pull/284) # we should validate that if auth is enabled, we have either an @@ -214,6 +228,17 @@ def session_manager(self) -> StreamableHTTPSessionManager: ) return self._session_manager + async def _serve_protected_resource_metadata(self, request: Request) -> Response: + """Serve the OAuth protected resource metadata.""" + if not self._protected_resource_metadata: + raise HTTPException( + status_code=404, detail="Protected resource metadata not configured" + ) + return Response( + self._protected_resource_metadata.model_dump_json(), + media_type="application/json", + ) + def run( self, transport: Literal["stdio", "sse", "streamable-http"] = "stdio", @@ -689,6 +714,7 @@ async def handle_sse(scope: Scope, receive: Receive, send: Send): # Create routes routes: list[Route | Mount] = [] + middleware: list[Middleware] = [] required_scopes = [] @@ -789,6 +815,38 @@ async def handle_streamable_http( routes: list[Route | Mount] = [] middleware: list[Middleware] = [] required_scopes = [] + if self._protected_resource_metadata and self._token_validator: + # only add the well-known route if the protected resource metadata is + # configured + routes.append( + Route( + "/.well-known/oauth-protected-resource", + self._serve_protected_resource_metadata, + methods=["GET"], + ) + ) + # by default assuming that this would be a JWT Bearer Token; + # Make this also optional somehow; may be as part of the protected resource + # metadata, take a class for validating the token + middleware = [ + Middleware( + AuthenticationMiddleware, + backend=JWTBearerTokenAuthBackend( + provider=self._token_validator, + ), + ), + Middleware(AuthContextMiddleware), + ] + # fetch required scopes from protected resource metadata + required_scopes = self._protected_resource_metadata.required_scopes or [] + + # wrap the streamable http handler with require auth middleware + routes.append( + Mount( + self.settings.streamable_http_path, + app=RequireAuthMiddleware(handle_streamable_http, required_scopes), + ) + ) # Add auth endpoints if auth provider is configured if self._auth_server_provider: diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index 22f8a971d..715aecf13 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -3,6 +3,24 @@ from pydantic import AnyHttpUrl, BaseModel, Field +class ProtectedResourceMetadata(BaseModel): + # create a pydantic model with required params as resource, authorization_servers + resource: str + authorization_servers: list[str] + jwks_uri: AnyHttpUrl | None = None + scopes_supported: list[str] | None = None + bearer_methods_supported: list[str] | None = None + resource_name: str | None = None + resource_signing_alg_values_supported: list[str] | None = None + resource_documentation: AnyHttpUrl | None = None + resource_policy_uri: AnyHttpUrl | None = None + resource_tos_uri: AnyHttpUrl | None = None + authorization_details_types_supported: list[str] | None = None + dpop_signing_alg_values_supported: list[str] | None = None + dpop_bound_access_tokens_required: bool | None = None + required_scopes: list[str] | None = None + + class OAuthToken(BaseModel): """ See https://datatracker.ietf.org/doc/html/rfc6749#section-5.1 diff --git a/uv.lock b/uv.lock index 88869fa50..9bcb51e07 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,4 @@ version = 1 -revision = 1 requires-python = ">=3.10" [options] @@ -10,6 +9,7 @@ members = [ "mcp", "mcp-simple-auth", "mcp-simple-prompt", + "mcp-simple-remote-auth", "mcp-simple-resource", "mcp-simple-streamablehttp", "mcp-simple-streamablehttp-stateless", @@ -549,7 +549,6 @@ requires-dist = [ { name = "uvicorn", marker = "sys_platform != 'emscripten'", specifier = ">=0.23.1" }, { name = "websockets", marker = "extra == 'ws'", specifier = ">=15.0.1" }, ] -provides-extras = ["cli", "rich", "ws"] [package.metadata.requires-dev] dev = [ @@ -643,6 +642,47 @@ dev = [ { name = "ruff", specifier = ">=0.6.9" }, ] +[[package]] +name = "mcp-simple-remote-auth" +version = "0.1.0" +source = { editable = "examples/servers/simple-auth-remote" } +dependencies = [ + { name = "anyio" }, + { name = "click" }, + { name = "httpx" }, + { name = "mcp" }, + { name = "pydantic" }, + { name = "pydantic-settings" }, + { name = "sse-starlette" }, + { name = "uvicorn", marker = "sys_platform != 'emscripten'" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pyright" }, + { name = "pytest" }, + { name = "ruff" }, +] + +[package.metadata] +requires-dist = [ + { name = "anyio", specifier = ">=4.5" }, + { name = "click", specifier = ">=8.1.0" }, + { name = "httpx", specifier = ">=0.27" }, + { name = "mcp", editable = "." }, + { name = "pydantic", specifier = ">=2.0" }, + { name = "pydantic-settings", specifier = ">=2.5.2" }, + { name = "sse-starlette", specifier = ">=1.6.1" }, + { name = "uvicorn", marker = "sys_platform != 'emscripten'", specifier = ">=0.23.1" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pyright", specifier = ">=1.1.391" }, + { name = "pytest", specifier = ">=8.3.4" }, + { name = "ruff", specifier = ">=0.8.5" }, +] + [[package]] name = "mcp-simple-resource" version = "0.1.0"