diff --git a/examples/servers/simple-auth/README.md b/examples/servers/simple-auth/README.md index 5021dc3a9..2c21143c8 100644 --- a/examples/servers/simple-auth/README.md +++ b/examples/servers/simple-auth/README.md @@ -47,6 +47,10 @@ cd examples/servers/simple-auth # Start Resource Server on port 8001, connected to Authorization Server uv run mcp-simple-auth-rs --port=8001 --auth-server=http://localhost:9000 --transport=streamable-http + +# With RFC 8707 strict resource validation (recommended for production) +uv run mcp-simple-auth-rs --port=8001 --auth-server=http://localhost:9000 --transport=streamable-http --oauth-strict + ``` diff --git a/examples/servers/simple-auth/mcp_simple_auth/auth_server.py b/examples/servers/simple-auth/mcp_simple_auth/auth_server.py index 892bd8541..2594f81d6 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/auth_server.py +++ b/examples/servers/simple-auth/mcp_simple_auth/auth_server.py @@ -121,6 +121,7 @@ async def introspect_handler(request: Request) -> Response: "exp": access_token.expires_at, "iat": int(time.time()), "token_type": "Bearer", + "aud": access_token.resource, # RFC 8707 audience claim } ) diff --git a/examples/servers/simple-auth/mcp_simple_auth/github_oauth_provider.py b/examples/servers/simple-auth/mcp_simple_auth/github_oauth_provider.py index bb45ae6c5..c64db96b7 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/github_oauth_provider.py +++ b/examples/servers/simple-auth/mcp_simple_auth/github_oauth_provider.py @@ -65,7 +65,7 @@ def __init__(self, settings: GitHubOAuthSettings, github_callback_url: str): self.clients: dict[str, OAuthClientInformationFull] = {} self.auth_codes: dict[str, AuthorizationCode] = {} self.tokens: dict[str, AccessToken] = {} - self.state_mapping: dict[str, dict[str, str]] = {} + self.state_mapping: dict[str, dict[str, str | None]] = {} # Maps MCP tokens to GitHub tokens self.token_mapping: dict[str, str] = {} @@ -87,6 +87,7 @@ async def authorize(self, client: OAuthClientInformationFull, params: Authorizat "code_challenge": params.code_challenge, "redirect_uri_provided_explicitly": str(params.redirect_uri_provided_explicitly), "client_id": client.client_id, + "resource": params.resource, # RFC 8707 } # Build GitHub authorization URL @@ -110,6 +111,12 @@ async def handle_github_callback(self, code: str, state: str) -> str: code_challenge = state_data["code_challenge"] redirect_uri_provided_explicitly = state_data["redirect_uri_provided_explicitly"] == "True" client_id = state_data["client_id"] + resource = state_data.get("resource") # RFC 8707 + + # These are required values from our own state mapping + assert redirect_uri is not None + assert code_challenge is not None + assert client_id is not None # Exchange code for token with GitHub async with create_mcp_http_client() as client: @@ -144,6 +151,7 @@ async def handle_github_callback(self, code: str, state: str) -> str: expires_at=time.time() + 300, scopes=[self.settings.mcp_scope], code_challenge=code_challenge, + resource=resource, # RFC 8707 ) self.auth_codes[new_code] = auth_code @@ -180,6 +188,7 @@ async def exchange_authorization_code( client_id=client.client_id, scopes=authorization_code.scopes, expires_at=int(time.time()) + 3600, + resource=authorization_code.resource, # RFC 8707 ) # Find GitHub token for this client diff --git a/examples/servers/simple-auth/mcp_simple_auth/server.py b/examples/servers/simple-auth/mcp_simple_auth/server.py index 47174bcaf..898ee7837 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/server.py +++ b/examples/servers/simple-auth/mcp_simple_auth/server.py @@ -43,6 +43,9 @@ class ResourceServerSettings(BaseSettings): # MCP settings mcp_scope: str = "user" + # RFC 8707 resource validation + oauth_strict: bool = False + def __init__(self, **data): """Initialize settings with values from environment variables.""" super().__init__(**data) @@ -57,8 +60,12 @@ def create_resource_server(settings: ResourceServerSettings) -> FastMCP: 2. Validates tokens via Authorization Server introspection 3. Serves MCP tools and resources """ - # Create token verifier for introspection - token_verifier = IntrospectionTokenVerifier(settings.auth_server_introspection_endpoint) + # Create token verifier for introspection with RFC 8707 resource validation + token_verifier = IntrospectionTokenVerifier( + introspection_endpoint=settings.auth_server_introspection_endpoint, + server_url=str(settings.server_url), + validate_resource=settings.oauth_strict, # Only validate when --oauth-strict is set + ) # Create FastMCP server as a Resource Server app = FastMCP( @@ -144,7 +151,12 @@ async def get_user_info() -> dict[str, Any]: type=click.Choice(["sse", "streamable-http"]), help="Transport protocol to use ('sse' or 'streamable-http')", ) -def main(port: int, auth_server: str, transport: Literal["sse", "streamable-http"]) -> int: +@click.option( + "--oauth-strict", + is_flag=True, + help="Enable RFC 8707 resource validation", +) +def main(port: int, auth_server: str, transport: Literal["sse", "streamable-http"], oauth_strict: bool) -> int: """ Run the MCP Resource Server. @@ -171,6 +183,7 @@ def main(port: int, auth_server: str, transport: Literal["sse", "streamable-http auth_server_url=auth_server_url, auth_server_introspection_endpoint=f"{auth_server}/introspect", auth_server_github_user_endpoint=f"{auth_server}/github/user", + oauth_strict=oauth_strict, ) except ValueError as e: logger.error(f"Configuration error: {e}") diff --git a/examples/servers/simple-auth/mcp_simple_auth/token_verifier.py b/examples/servers/simple-auth/mcp_simple_auth/token_verifier.py index ba71322fa..de3140238 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/token_verifier.py +++ b/examples/servers/simple-auth/mcp_simple_auth/token_verifier.py @@ -3,6 +3,7 @@ import logging from mcp.server.auth.provider import AccessToken, TokenVerifier +from mcp.shared.auth_utils import check_resource_allowed, resource_url_from_server_url logger = logging.getLogger(__name__) @@ -18,8 +19,16 @@ class IntrospectionTokenVerifier(TokenVerifier): - Comprehensive configuration options """ - def __init__(self, introspection_endpoint: str): + def __init__( + self, + introspection_endpoint: str, + server_url: str, + validate_resource: bool = False, + ): self.introspection_endpoint = introspection_endpoint + self.server_url = server_url + self.validate_resource = validate_resource + self.resource_url = resource_url_from_server_url(server_url) async def verify_token(self, token: str) -> AccessToken | None: """Verify token via introspection endpoint.""" @@ -54,12 +63,43 @@ async def verify_token(self, token: str) -> AccessToken | None: if not data.get("active", False): return None + # RFC 8707 resource validation (only when --oauth-strict is set) + if self.validate_resource and not self._validate_resource(data): + logger.warning(f"Token resource validation failed. Expected: {self.resource_url}") + return None + return AccessToken( token=token, client_id=data.get("client_id", "unknown"), scopes=data.get("scope", "").split() if data.get("scope") else [], expires_at=data.get("exp"), + resource=data.get("aud"), # Include resource in token ) except Exception as e: logger.warning(f"Token introspection failed: {e}") return None + + def _validate_resource(self, token_data: dict) -> bool: + """Validate token was issued for this resource server.""" + if not self.server_url or not self.resource_url: + return False # Fail if strict validation requested but URLs missing + + # Check 'aud' claim first (standard JWT audience) + aud = token_data.get("aud") + if isinstance(aud, list): + for audience in aud: + if self._is_valid_resource(audience): + return True + return False + elif aud: + return self._is_valid_resource(aud) + + # No resource binding - invalid per RFC 8707 + return False + + def _is_valid_resource(self, resource: str) -> bool: + """Check if resource matches this server using hierarchical matching.""" + if not self.resource_url: + return False + + return check_resource_allowed(requested_resource=self.resource_url, configured_resource=resource) diff --git a/src/mcp/client/auth.py b/src/mcp/client/auth.py index 50ce74aa4..c174385ea 100644 --- a/src/mcp/client/auth.py +++ b/src/mcp/client/auth.py @@ -27,6 +27,7 @@ OAuthToken, ProtectedResourceMetadata, ) +from mcp.shared.auth_utils import check_resource_allowed, resource_url_from_server_url from mcp.types import LATEST_PROTOCOL_VERSION logger = logging.getLogger(__name__) @@ -134,6 +135,21 @@ def clear_tokens(self) -> None: self.current_tokens = None self.token_expiry_time = None + def get_resource_url(self) -> str: + """Get resource URL for RFC 8707. + + Uses PRM resource if it's a valid parent, otherwise uses canonical server URL. + """ + resource = resource_url_from_server_url(self.server_url) + + # If PRM provides a resource that's a valid parent, use it + if self.protected_resource_metadata and self.protected_resource_metadata.resource: + prm_resource = str(self.protected_resource_metadata.resource) + if check_resource_allowed(requested_resource=resource, configured_resource=prm_resource): + resource = prm_resource + + return resource + class OAuthClientProvider(httpx.Auth): """ @@ -256,6 +272,7 @@ async def _perform_authorization(self) -> tuple[str, str]: "state": state, "code_challenge": pkce_params.code_challenge, "code_challenge_method": "S256", + "resource": self.context.get_resource_url(), # RFC 8707 } if self.context.client_metadata.scope: @@ -293,6 +310,7 @@ async def _exchange_token(self, auth_code: str, code_verifier: str) -> httpx.Req "redirect_uri": str(self.context.client_metadata.redirect_uris[0]), "client_id": self.context.client_info.client_id, "code_verifier": code_verifier, + "resource": self.context.get_resource_url(), # RFC 8707 } if self.context.client_info.client_secret: @@ -343,6 +361,7 @@ async def _refresh_token(self) -> httpx.Request: "grant_type": "refresh_token", "refresh_token": self.context.current_tokens.refresh_token, "client_id": self.context.client_info.client_id, + "resource": self.context.get_resource_url(), # RFC 8707 } if self.context.client_info.client_secret: diff --git a/src/mcp/server/auth/handlers/authorize.py b/src/mcp/server/auth/handlers/authorize.py index 8d5e2622f..3ce4c34bc 100644 --- a/src/mcp/server/auth/handlers/authorize.py +++ b/src/mcp/server/auth/handlers/authorize.py @@ -35,6 +35,10 @@ class AuthorizationRequest(BaseModel): None, description="Optional scope; if specified, should be " "a space-separated list of scope strings", ) + resource: str | None = Field( + None, + description="RFC 8707 resource indicator - the MCP server this token will be used with", + ) class AuthorizationErrorResponse(BaseModel): @@ -197,6 +201,7 @@ async def error_response( code_challenge=auth_request.code_challenge, redirect_uri=redirect_uri, redirect_uri_provided_explicitly=auth_request.redirect_uri is not None, + resource=auth_request.resource, # RFC 8707 ) try: diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py index 450ee406c..552417169 100644 --- a/src/mcp/server/auth/handlers/token.py +++ b/src/mcp/server/auth/handlers/token.py @@ -24,6 +24,8 @@ class AuthorizationCodeRequest(BaseModel): client_secret: str | None = None # See https://datatracker.ietf.org/doc/html/rfc7636#section-4.5 code_verifier: str = Field(..., description="PKCE code verifier") + # RFC 8707 resource indicator + resource: str | None = Field(None, description="Resource indicator for the token") class RefreshTokenRequest(BaseModel): @@ -34,6 +36,8 @@ class RefreshTokenRequest(BaseModel): client_id: str # we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1 client_secret: str | None = None + # RFC 8707 resource indicator + resource: str | None = Field(None, description="Resource indicator for the token") class TokenRequest( diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index acdd55bc2..b84db89a2 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -13,6 +13,7 @@ class AuthorizationParams(BaseModel): code_challenge: str redirect_uri: AnyUrl redirect_uri_provided_explicitly: bool + resource: str | None = None # RFC 8707 resource indicator class AuthorizationCode(BaseModel): @@ -23,6 +24,7 @@ class AuthorizationCode(BaseModel): code_challenge: str redirect_uri: AnyUrl redirect_uri_provided_explicitly: bool + resource: str | None = None # RFC 8707 resource indicator class RefreshToken(BaseModel): @@ -37,6 +39,7 @@ class AccessToken(BaseModel): client_id: str scopes: list[str] expires_at: int | None = None + resource: str | None = None # RFC 8707 resource indicator RegistrationErrorCode = Literal[ diff --git a/src/mcp/shared/auth_utils.py b/src/mcp/shared/auth_utils.py new file mode 100644 index 000000000..6d6300c9c --- /dev/null +++ b/src/mcp/shared/auth_utils.py @@ -0,0 +1,69 @@ +"""Utilities for OAuth 2.0 Resource Indicators (RFC 8707).""" + +from urllib.parse import urlparse, urlsplit, urlunsplit + +from pydantic import AnyUrl, HttpUrl + + +def resource_url_from_server_url(url: str | HttpUrl | AnyUrl) -> str: + """Convert server URL to canonical resource URL per RFC 8707. + + RFC 8707 section 2 states that resource URIs "MUST NOT include a fragment component". + Returns absolute URI with lowercase scheme/host for canonical form. + + Args: + url: Server URL to convert + + Returns: + Canonical resource URL string + """ + # Convert to string if needed + url_str = str(url) + + # Parse the URL and remove fragment, create canonical form + parsed = urlsplit(url_str) + canonical = urlunsplit(parsed._replace(scheme=parsed.scheme.lower(), netloc=parsed.netloc.lower(), fragment="")) + + return canonical + + +def check_resource_allowed(requested_resource: str, configured_resource: str) -> bool: + """Check if a requested resource URL matches a configured resource URL. + + A requested resource matches if it has the same scheme, domain, port, + and its path starts with the configured resource's path. This allows + hierarchical matching where a token for a parent resource can be used + for child resources. + + Args: + requested_resource: The resource URL being requested + configured_resource: The resource URL that has been configured + + Returns: + True if the requested resource matches the configured resource + """ + # Parse both URLs + requested = urlparse(requested_resource) + configured = urlparse(configured_resource) + + # Compare scheme, host, and port (origin) + if requested.scheme.lower() != configured.scheme.lower() or requested.netloc.lower() != configured.netloc.lower(): + return False + + # Handle cases like requested=/foo and configured=/foo/ + requested_path = requested.path + configured_path = configured.path + + # If requested path is shorter, it cannot be a child + if len(requested_path) < len(configured_path): + return False + + # Check if the requested path starts with the configured path + # Ensure both paths end with / for proper comparison + # This ensures that paths like "/api123" don't incorrectly match "/api" + if not requested_path.endswith("/"): + requested_path += "/" + if not configured_path.endswith("/"): + configured_path += "/" + + return requested_path.startswith(configured_path) diff --git a/tests/shared/test_auth_utils.py b/tests/shared/test_auth_utils.py new file mode 100644 index 000000000..5b12dc677 --- /dev/null +++ b/tests/shared/test_auth_utils.py @@ -0,0 +1,112 @@ +"""Tests for OAuth 2.0 Resource Indicators utilities.""" + +from mcp.shared.auth_utils import check_resource_allowed, resource_url_from_server_url + + +class TestResourceUrlFromServerUrl: + """Tests for resource_url_from_server_url function.""" + + def test_removes_fragment(self): + """Fragment should be removed per RFC 8707.""" + assert resource_url_from_server_url("https://example.com/path#fragment") == "https://example.com/path" + assert resource_url_from_server_url("https://example.com/#fragment") == "https://example.com/" + + def test_preserves_path(self): + """Path should be preserved.""" + assert ( + resource_url_from_server_url("https://example.com/path/to/resource") + == "https://example.com/path/to/resource" + ) + assert resource_url_from_server_url("https://example.com/") == "https://example.com/" + assert resource_url_from_server_url("https://example.com") == "https://example.com" + + def test_preserves_query(self): + """Query parameters should be preserved.""" + assert resource_url_from_server_url("https://example.com/path?foo=bar") == "https://example.com/path?foo=bar" + assert resource_url_from_server_url("https://example.com/?key=value") == "https://example.com/?key=value" + + def test_preserves_port(self): + """Non-default ports should be preserved.""" + assert resource_url_from_server_url("https://example.com:8443/path") == "https://example.com:8443/path" + assert resource_url_from_server_url("http://example.com:8080/") == "http://example.com:8080/" + + def test_lowercase_scheme_and_host(self): + """Scheme and host should be lowercase for canonical form.""" + assert resource_url_from_server_url("HTTPS://EXAMPLE.COM/path") == "https://example.com/path" + assert resource_url_from_server_url("Http://Example.Com:8080/") == "http://example.com:8080/" + + def test_handles_pydantic_urls(self): + """Should handle Pydantic URL types.""" + from pydantic import HttpUrl + + url = HttpUrl("https://example.com/path") + assert resource_url_from_server_url(url) == "https://example.com/path" + + +class TestCheckResourceAllowed: + """Tests for check_resource_allowed function.""" + + def test_identical_urls(self): + """Identical URLs should match.""" + assert check_resource_allowed("https://example.com/path", "https://example.com/path") is True + assert check_resource_allowed("https://example.com/", "https://example.com/") is True + assert check_resource_allowed("https://example.com", "https://example.com") is True + + def test_different_schemes(self): + """Different schemes should not match.""" + assert check_resource_allowed("https://example.com/path", "http://example.com/path") is False + assert check_resource_allowed("http://example.com/", "https://example.com/") is False + + def test_different_domains(self): + """Different domains should not match.""" + assert check_resource_allowed("https://example.com/path", "https://example.org/path") is False + assert check_resource_allowed("https://sub.example.com/", "https://example.com/") is False + + def test_different_ports(self): + """Different ports should not match.""" + assert check_resource_allowed("https://example.com:8443/path", "https://example.com/path") is False + assert check_resource_allowed("https://example.com:8080/", "https://example.com:8443/") is False + + def test_hierarchical_matching(self): + """Child paths should match parent paths.""" + # Parent resource allows child resources + assert check_resource_allowed("https://example.com/api/v1/users", "https://example.com/api") is True + assert check_resource_allowed("https://example.com/api/v1", "https://example.com/api") is True + assert check_resource_allowed("https://example.com/mcp/server", "https://example.com/mcp") is True + + # Exact match + assert check_resource_allowed("https://example.com/api", "https://example.com/api") is True + + # Parent cannot use child's token + assert check_resource_allowed("https://example.com/api", "https://example.com/api/v1") is False + assert check_resource_allowed("https://example.com/", "https://example.com/api") is False + + def test_path_boundary_matching(self): + """Path matching should respect boundaries.""" + # Should not match partial path segments + assert check_resource_allowed("https://example.com/apiextra", "https://example.com/api") is False + assert check_resource_allowed("https://example.com/api123", "https://example.com/api") is False + + # Should match with trailing slash + assert check_resource_allowed("https://example.com/api/", "https://example.com/api") is True + assert check_resource_allowed("https://example.com/api/v1", "https://example.com/api/") is True + + def test_trailing_slash_handling(self): + """Trailing slashes should be handled correctly.""" + # With and without trailing slashes + assert check_resource_allowed("https://example.com/api/", "https://example.com/api") is True + assert check_resource_allowed("https://example.com/api", "https://example.com/api/") is False + assert check_resource_allowed("https://example.com/api/v1", "https://example.com/api") is True + assert check_resource_allowed("https://example.com/api/v1", "https://example.com/api/") is True + + def test_case_insensitive_origin(self): + """Origin comparison should be case-insensitive.""" + assert check_resource_allowed("https://EXAMPLE.COM/path", "https://example.com/path") is True + assert check_resource_allowed("HTTPS://example.com/path", "https://example.com/path") is True + assert check_resource_allowed("https://Example.Com:8080/api", "https://example.com:8080/api") is True + + def test_empty_paths(self): + """Empty paths should be handled correctly.""" + assert check_resource_allowed("https://example.com", "https://example.com") is True + assert check_resource_allowed("https://example.com/", "https://example.com") is True + assert check_resource_allowed("https://example.com/api", "https://example.com") is True