Skip to content

Implement RFC9728 - Support WWW-Authenticate header by MCP client #1071

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 87 additions & 89 deletions src/mcp/client/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import base64
import hashlib
import logging
import re
import secrets
import string
import time
Expand Down Expand Up @@ -203,10 +204,39 @@ def __init__(
)
self._initialized = False

async def _discover_protected_resource(self) -> httpx.Request:
"""Build discovery request for protected resource metadata."""
auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
url = urljoin(auth_base_url, "/.well-known/oauth-protected-resource")
def _extract_resource_metadata_from_www_auth(self, init_response: httpx.Response) -> str | None:
"""
Extract protected resource metadata URL from WWW-Authenticate header as per RFC9728.

Returns:
Resource metadata URL if found in WWW-Authenticate header, None otherwise
"""
if not init_response or init_response.status_code != 401:
return None

www_auth_header = init_response.headers.get("WWW-Authenticate")
if not www_auth_header:
return None

# Pattern matches: resource_metadata="url" or resource_metadata=url (unquoted)
pattern = r'resource_metadata=(?:"([^"]+)"|([^\s,]+))'
match = re.search(pattern, www_auth_header)

if match:
# Return quoted value if present, otherwise unquoted value
return match.group(1) or match.group(2)

return None

async def _discover_protected_resource(self, init_response: httpx.Response) -> httpx.Request:
# RFC9728: Try to extract resource_metadata URL from WWW-Authenticate header of the initial response
url = self._extract_resource_metadata_from_www_auth(init_response)

if not url:
# Fallback to well-known discovery
auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
url = urljoin(auth_base_url, "/.well-known/oauth-protected-resource")

return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION})

async def _handle_protected_resource_response(self, response: httpx.Response) -> None:
Expand Down Expand Up @@ -489,92 +519,60 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
# Capture protocol version from request headers
self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION)

# Perform OAuth flow if not authenticated
if not self.context.is_token_valid():
try:
# OAuth flow must be inline due to generator constraints
# Step 1: Discover protected resource metadata (spec revision 2025-06-18)
discovery_request = await self._discover_protected_resource()
discovery_response = yield discovery_request
await self._handle_protected_resource_response(discovery_response)

# Step 2: Discover OAuth metadata (with fallback for legacy servers)
oauth_request = await self._discover_oauth_metadata()
oauth_response = yield oauth_request
handled = await self._handle_oauth_metadata_response(oauth_response, is_fallback=False)

# If path-aware discovery failed with 404, try fallback to root
if not handled:
fallback_request = await self._discover_oauth_metadata_fallback()
fallback_response = yield fallback_request
await self._handle_oauth_metadata_response(fallback_response, is_fallback=True)

# Step 3: Register client if needed
registration_request = await self._register_client()
if registration_request:
registration_response = yield registration_request
await self._handle_registration_response(registration_response)

# Step 4: Perform authorization
auth_code, code_verifier = await self._perform_authorization()

# Step 5: Exchange authorization code for tokens
token_request = await self._exchange_token(auth_code, code_verifier)
token_response = yield token_request
await self._handle_token_response(token_response)
except Exception as e:
logger.error(f"OAuth flow error: {e}")
raise

# Add authorization header and make request
self._add_auth_header(request)
if self.context.is_token_valid():
self._add_auth_header(request)

response = yield request

# Handle 401 responses
if response.status_code == 401 and self.context.can_refresh_token():
# Try to refresh token
refresh_request = await self._refresh_token()
refresh_response = yield refresh_request
if response.status_code == 401:
if self.context.can_refresh_token():
# Try to refresh token
refresh_request = await self._refresh_token()
refresh_response = yield refresh_request

if await self._handle_refresh_response(refresh_response):
# Retry original request with new token
self._add_auth_header(request)
yield request
if not await self._handle_refresh_response(refresh_response):
# Refresh failed, need full re-authentication
self._initialized = False
else:
# Refresh failed, need full re-authentication
self._initialized = False

# OAuth flow must be inline due to generator constraints
# Step 1: Discover protected resource metadata (spec revision 2025-06-18)
discovery_request = await self._discover_protected_resource()
discovery_response = yield discovery_request
await self._handle_protected_resource_response(discovery_response)

# Step 2: Discover OAuth metadata (with fallback for legacy servers)
oauth_request = await self._discover_oauth_metadata()
oauth_response = yield oauth_request
handled = await self._handle_oauth_metadata_response(oauth_response, is_fallback=False)

# If path-aware discovery failed with 404, try fallback to root
if not handled:
fallback_request = await self._discover_oauth_metadata_fallback()
fallback_response = yield fallback_request
await self._handle_oauth_metadata_response(fallback_response, is_fallback=True)

# Step 3: Register client if needed
registration_request = await self._register_client()
if registration_request:
registration_response = yield registration_request
await self._handle_registration_response(registration_response)

# Step 4: Perform authorization
auth_code, code_verifier = await self._perform_authorization()

# Step 5: Exchange authorization code for tokens
token_request = await self._exchange_token(auth_code, code_verifier)
token_response = yield token_request
await self._handle_token_response(token_response)

# Retry with new tokens
self._add_auth_header(request)
yield request
self.context.clear_tokens()

# If we don't have valid tokens after refresh, perform OAuth flow
if not self.context.is_token_valid():
try:
# OAuth flow must be inline due to generator constraints
# Step 1: Discover protected resource metadata (RFC9728 with WWW-Authenticate support)
discovery_request = await self._discover_protected_resource(response)
discovery_response = yield discovery_request
await self._handle_protected_resource_response(discovery_response)

# Step 2: Discover OAuth metadata (with fallback for legacy servers)
oauth_request = await self._discover_oauth_metadata()
oauth_response = yield oauth_request
handled = await self._handle_oauth_metadata_response(oauth_response, is_fallback=False)

# If path-aware discovery failed with 404, try fallback to root
if not handled:
fallback_request = await self._discover_oauth_metadata_fallback()
fallback_response = yield fallback_request
await self._handle_oauth_metadata_response(fallback_response, is_fallback=True)

# Step 3: Register client if needed
registration_request = await self._register_client()
if registration_request:
registration_response = yield registration_request
await self._handle_registration_response(registration_response)

# Step 4: Perform authorization
auth_code, code_verifier = await self._perform_authorization()

# Step 5: Exchange authorization code for tokens
token_request = await self._exchange_token(auth_code, code_verifier)
token_response = yield token_request
await self._handle_token_response(token_response)
except Exception as e:
logger.error(f"OAuth flow error: {e}")
raise

# Retry with new tokens
self._add_auth_header(request)
yield request
146 changes: 143 additions & 3 deletions tests/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,14 +196,43 @@ class TestOAuthFlow:
"""Test OAuth flow methods."""

@pytest.mark.anyio
async def test_discover_protected_resource_request(self, oauth_provider):
"""Test protected resource discovery request building."""
request = await oauth_provider._discover_protected_resource()
async def test_discover_protected_resource_request(self, client_metadata, mock_storage):
"""Test protected resource discovery request building maintains backward compatibility."""

async def redirect_handler(url: str) -> None:
pass

async def callback_handler() -> tuple[str, str | None]:
return "test_auth_code", "test_state"

provider = OAuthClientProvider(
server_url="https://api.example.com",
client_metadata=client_metadata,
storage=mock_storage,
redirect_handler=redirect_handler,
callback_handler=callback_handler,
)

# Test without WWW-Authenticate (fallback)
init_response = httpx.Response(
status_code=401, headers={}, request=httpx.Request("GET", "https://request-api.example.com")
)

request = await provider._discover_protected_resource(init_response)
assert request.method == "GET"
assert str(request.url) == "https://api.example.com/.well-known/oauth-protected-resource"
assert "mcp-protocol-version" in request.headers

# Test with WWW-Authenticate header
init_response.headers["WWW-Authenticate"] = (
'Bearer resource_metadata="https://prm.example.com/.well-known/oauth-protected-resource/path"'
)

request = await provider._discover_protected_resource(init_response)
assert request.method == "GET"
assert str(request.url) == "https://prm.example.com/.well-known/oauth-protected-resource/path"
assert "mcp-protocol-version" in request.headers

@pytest.mark.anyio
async def test_discover_oauth_metadata_request(self, oauth_provider):
"""Test OAuth metadata discovery request building."""
Expand Down Expand Up @@ -544,3 +573,114 @@ async def test_auth_flow_with_valid_tokens(self, oauth_provider, mock_storage, v
await auth_flow.asend(response)
except StopAsyncIteration:
pass # Expected


class TestProtectedResourceWWWAuthenticate:
"""Test RFC9728 WWW-Authenticate header parsing functionality for protected resource."""

@pytest.mark.parametrize(
"www_auth_header,expected_url",
[
# Quoted URL
(
'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"',
"https://api.example.com/.well-known/oauth-protected-resource",
),
# Unquoted URL
(
"Bearer resource_metadata=https://api.example.com/.well-known/oauth-protected-resource",
"https://api.example.com/.well-known/oauth-protected-resource",
),
# Complex header with multiple parameters
(
'Bearer realm="api", resource_metadata="https://api.example.com/.well-known/oauth-protected-resource", '
'error="insufficient_scope"',
"https://api.example.com/.well-known/oauth-protected-resource",
),
# Different URL format
('Bearer resource_metadata="https://custom.domain.com/metadata"', "https://custom.domain.com/metadata"),
# With path and query params
(
'Bearer resource_metadata="https://api.example.com/auth/metadata?version=1"',
"https://api.example.com/auth/metadata?version=1",
),
],
)
def test_extract_resource_metadata_from_www_auth_valid_cases(
self, client_metadata, mock_storage, www_auth_header, expected_url
):
"""Test extraction of resource_metadata URL from various valid WWW-Authenticate headers."""

async def redirect_handler(url: str) -> None:
pass

async def callback_handler() -> tuple[str, str | None]:
return "test_auth_code", "test_state"

provider = OAuthClientProvider(
server_url="https://api.example.com/v1/mcp",
client_metadata=client_metadata,
storage=mock_storage,
redirect_handler=redirect_handler,
callback_handler=callback_handler,
)

init_response = httpx.Response(
status_code=401,
headers={"WWW-Authenticate": www_auth_header},
request=httpx.Request("GET", "https://api.example.com/test"),
)

result = provider._extract_resource_metadata_from_www_auth(init_response)
assert result == expected_url

@pytest.mark.parametrize(
"status_code,www_auth_header,description",
[
# No header
(401, None, "no WWW-Authenticate header"),
# Empty header
(401, "", "empty WWW-Authenticate header"),
# Header without resource_metadata
(401, 'Bearer realm="api", error="insufficient_scope"', "no resource_metadata parameter"),
# Malformed header
(401, "Bearer resource_metadata=", "malformed resource_metadata parameter"),
# Non-401 status code
(
200,
'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"',
"200 OK response",
),
(
500,
'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"',
"500 error response",
),
],
)
def test_extract_resource_metadata_from_www_auth_invalid_cases(
self, client_metadata, mock_storage, status_code, www_auth_header, description
):
"""Test extraction returns None for invalid cases."""

async def redirect_handler(url: str) -> None:
pass

async def callback_handler() -> tuple[str, str | None]:
return "test_auth_code", "test_state"

provider = OAuthClientProvider(
server_url="https://api.example.com/v1/mcp",
client_metadata=client_metadata,
storage=mock_storage,
redirect_handler=redirect_handler,
callback_handler=callback_handler,
)

headers = {"WWW-Authenticate": www_auth_header} if www_auth_header is not None else {}
init_response = httpx.Response(
status_code=status_code, headers=headers, request=httpx.Request("GET", "https://api.example.com/test")
)

result = provider._extract_resource_metadata_from_www_auth(init_response)
assert result is None, f"Should return None for {description}"
Loading