From cfb3f7e3ce5ff7dcc5a34df558ee34451544df1f Mon Sep 17 00:00:00 2001 From: Yuri Kunash Date: Wed, 2 Jul 2025 11:09:13 +0800 Subject: [PATCH 1/5] Added method for parsing WWW-Authenticate header --- src/mcp/client/auth.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/src/mcp/client/auth.py b/src/mcp/client/auth.py index 769e9b4c8..5361918c9 100644 --- a/src/mcp/client/auth.py +++ b/src/mcp/client/auth.py @@ -7,6 +7,7 @@ import base64 import hashlib import logging +import re import secrets import string import time @@ -33,6 +34,30 @@ logger = logging.getLogger(__name__) +def _extract_resource_metadata_from_www_auth(header_value: str) -> str | None: + """ + Parse WWW-Authenticate header to extract resource_metadata parameter. + + According to RFC9728, the header format is: + WWW-Authenticate: Bearer resource_metadata="https://example.com/.well-known/oauth-protected-resource" + + Returns the resource_metadata URL if found, None otherwise. + """ + if not header_value: + return None + + # Look for resource_metadata parameter in the header + # Pattern matches: resource_metadata="url" or resource_metadata=url (unquoted) + pattern = r'resource_metadata=(?:"([^"]+)"|([^\s,]+))' + match = re.search(pattern, header_value) + + if match: + # Return quoted value if present, otherwise unquoted value + return match.group(1) or match.group(2) + + return None + + class OAuthFlowError(Exception): """Base exception for OAuth flow errors.""" From 41b3d91e1e18fe51fbef442e11347bce55543d53 Mon Sep 17 00:00:00 2001 From: Yuri Kunash Date: Wed, 2 Jul 2025 18:14:12 +0800 Subject: [PATCH 2/5] Update async_auth_flow --- src/mcp/client/auth.py | 138 ++++++++++++++++------------------------- 1 file changed, 53 insertions(+), 85 deletions(-) diff --git a/src/mcp/client/auth.py b/src/mcp/client/auth.py index 5361918c9..9a95ad898 100644 --- a/src/mcp/client/auth.py +++ b/src/mcp/client/auth.py @@ -514,92 +514,60 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. # Capture protocol version from request headers self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION) - # Perform OAuth flow if not authenticated - if not self.context.is_token_valid(): - try: - # OAuth flow must be inline due to generator constraints - # Step 1: Discover protected resource metadata (spec revision 2025-06-18) - discovery_request = await self._discover_protected_resource() - discovery_response = yield discovery_request - await self._handle_protected_resource_response(discovery_response) - - # Step 2: Discover OAuth metadata (with fallback for legacy servers) - oauth_request = await self._discover_oauth_metadata() - oauth_response = yield oauth_request - handled = await self._handle_oauth_metadata_response(oauth_response, is_fallback=False) - - # If path-aware discovery failed with 404, try fallback to root - if not handled: - fallback_request = await self._discover_oauth_metadata_fallback() - fallback_response = yield fallback_request - await self._handle_oauth_metadata_response(fallback_response, is_fallback=True) - - # Step 3: Register client if needed - registration_request = await self._register_client() - if registration_request: - registration_response = yield registration_request - await self._handle_registration_response(registration_response) - - # Step 4: Perform authorization - auth_code, code_verifier = await self._perform_authorization() - - # Step 5: Exchange authorization code for tokens - token_request = await self._exchange_token(auth_code, code_verifier) - token_response = yield token_request - await self._handle_token_response(token_response) - except Exception as e: - logger.error(f"OAuth flow error: {e}") - raise - - # Add authorization header and make request - self._add_auth_header(request) + if self.context.is_token_valid(): + self._add_auth_header(request) + response = yield request - # Handle 401 responses - if response.status_code == 401 and self.context.can_refresh_token(): - # Try to refresh token - refresh_request = await self._refresh_token() - refresh_response = yield refresh_request + if response.status_code == 401: + if self.context.can_refresh_token(): + # Try to refresh token + refresh_request = await self._refresh_token() + refresh_response = yield refresh_request - if await self._handle_refresh_response(refresh_response): - # Retry original request with new token - self._add_auth_header(request) - yield request + if not await self._handle_refresh_response(refresh_response): + # Refresh failed, need full re-authentication + self._initialized = False else: - # Refresh failed, need full re-authentication - self._initialized = False - - # OAuth flow must be inline due to generator constraints - # Step 1: Discover protected resource metadata (spec revision 2025-06-18) - discovery_request = await self._discover_protected_resource() - discovery_response = yield discovery_request - await self._handle_protected_resource_response(discovery_response) - - # Step 2: Discover OAuth metadata (with fallback for legacy servers) - oauth_request = await self._discover_oauth_metadata() - oauth_response = yield oauth_request - handled = await self._handle_oauth_metadata_response(oauth_response, is_fallback=False) - - # If path-aware discovery failed with 404, try fallback to root - if not handled: - fallback_request = await self._discover_oauth_metadata_fallback() - fallback_response = yield fallback_request - await self._handle_oauth_metadata_response(fallback_response, is_fallback=True) - - # Step 3: Register client if needed - registration_request = await self._register_client() - if registration_request: - registration_response = yield registration_request - await self._handle_registration_response(registration_response) - - # Step 4: Perform authorization - auth_code, code_verifier = await self._perform_authorization() - - # Step 5: Exchange authorization code for tokens - token_request = await self._exchange_token(auth_code, code_verifier) - token_response = yield token_request - await self._handle_token_response(token_response) - - # Retry with new tokens - self._add_auth_header(request) - yield request + self.context.clear_tokens() + + # If we don't have valid tokens after refresh, perform OAuth flow + if not self.context.is_token_valid(): + try: + # OAuth flow must be inline due to generator constraints + # Step 1: Discover protected resource metadata (spec revision 2025-06-18) + discovery_request = await self._discover_protected_resource() + discovery_response = yield discovery_request + await self._handle_protected_resource_response(discovery_response) + + # Step 2: Discover OAuth metadata (with fallback for legacy servers) + oauth_request = await self._discover_oauth_metadata() + oauth_response = yield oauth_request + handled = await self._handle_oauth_metadata_response(oauth_response, is_fallback=False) + + # If path-aware discovery failed with 404, try fallback to root + if not handled: + fallback_request = await self._discover_oauth_metadata_fallback() + fallback_response = yield fallback_request + await self._handle_oauth_metadata_response(fallback_response, is_fallback=True) + + # Step 3: Register client if needed + registration_request = await self._register_client() + if registration_request: + registration_response = yield registration_request + await self._handle_registration_response(registration_response) + + # Step 4: Perform authorization + auth_code, code_verifier = await self._perform_authorization() + + # Step 5: Exchange authorization code for tokens + token_request = await self._exchange_token(auth_code, code_verifier) + token_response = yield token_request + await self._handle_token_response(token_response) + except Exception as e: + logger.error(f"OAuth flow error: {e}") + raise + + # Retry with new tokens + self._add_auth_header(request) + yield request From dd0902e7385841fc7bc85fac2d4b0d0c20067668 Mon Sep 17 00:00:00 2001 From: Yuri Kunash Date: Wed, 2 Jul 2025 18:43:03 +0800 Subject: [PATCH 3/5] Check for WWW-Authenticate header --- src/mcp/client/auth.py | 38 ++++++++++++++++++++++---------------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/src/mcp/client/auth.py b/src/mcp/client/auth.py index 9a95ad898..6190bd474 100644 --- a/src/mcp/client/auth.py +++ b/src/mcp/client/auth.py @@ -34,22 +34,23 @@ logger = logging.getLogger(__name__) -def _extract_resource_metadata_from_www_auth(header_value: str) -> str | None: +def _extract_resource_metadata_from_www_auth(response: httpx.Response) -> str | None: """ - Parse WWW-Authenticate header to extract resource_metadata parameter. - - According to RFC9728, the header format is: - WWW-Authenticate: Bearer resource_metadata="https://example.com/.well-known/oauth-protected-resource" - - Returns the resource_metadata URL if found, None otherwise. + Extract protected resource metadata URL from WWW-Authenticate header as per RFC9728. + + Returns: + Resource metadata URL if found in WWW-Authenticate header, None otherwise """ - if not header_value: + if not response or response.status_code != 401: + return None + + www_auth_header = response.headers.get("WWW-Authenticate") + if not www_auth_header: return None - # Look for resource_metadata parameter in the header # Pattern matches: resource_metadata="url" or resource_metadata=url (unquoted) pattern = r'resource_metadata=(?:"([^"]+)"|([^\s,]+))' - match = re.search(pattern, header_value) + match = re.search(pattern, www_auth_header) if match: # Return quoted value if present, otherwise unquoted value @@ -228,10 +229,15 @@ def __init__( ) self._initialized = False - async def _discover_protected_resource(self) -> httpx.Request: - """Build discovery request for protected resource metadata.""" - auth_base_url = self.context.get_authorization_base_url(self.context.server_url) - url = urljoin(auth_base_url, "/.well-known/oauth-protected-resource") + async def _discover_protected_resource(self, response: httpx.Response | None = None) -> httpx.Request: + # RFC9728: Try to extract resource_metadata URL from WWW-Authenticate header + url = _extract_resource_metadata_from_www_auth(response) if response else None + + if not url: + # Fallback to well-known discovery + auth_base_url = self.context.get_authorization_base_url(self.context.server_url) + url = urljoin(auth_base_url, "/.well-known/oauth-protected-resource") + return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}) async def _handle_protected_resource_response(self, response: httpx.Response) -> None: @@ -535,8 +541,8 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. if not self.context.is_token_valid(): try: # OAuth flow must be inline due to generator constraints - # Step 1: Discover protected resource metadata (spec revision 2025-06-18) - discovery_request = await self._discover_protected_resource() + # Step 1: Discover protected resource metadata (RFC9728 with WWW-Authenticate support) + discovery_request = await self._discover_protected_resource(response) discovery_response = yield discovery_request await self._handle_protected_resource_response(discovery_response) From b7e88d8edec10c6a4f9c3c1a11d62ea10c61bbd0 Mon Sep 17 00:00:00 2001 From: Yuri Kunash Date: Wed, 2 Jul 2025 21:11:34 +0800 Subject: [PATCH 4/5] Added unit-tests --- src/mcp/client/auth.py | 55 +++++++-------- tests/client/test_auth.py | 145 +++++++++++++++++++++++++++++++++++++- 2 files changed, 169 insertions(+), 31 deletions(-) diff --git a/src/mcp/client/auth.py b/src/mcp/client/auth.py index 6190bd474..baad0040a 100644 --- a/src/mcp/client/auth.py +++ b/src/mcp/client/auth.py @@ -34,31 +34,6 @@ logger = logging.getLogger(__name__) -def _extract_resource_metadata_from_www_auth(response: httpx.Response) -> str | None: - """ - Extract protected resource metadata URL from WWW-Authenticate header as per RFC9728. - - Returns: - Resource metadata URL if found in WWW-Authenticate header, None otherwise - """ - if not response or response.status_code != 401: - return None - - www_auth_header = response.headers.get("WWW-Authenticate") - if not www_auth_header: - return None - - # Pattern matches: resource_metadata="url" or resource_metadata=url (unquoted) - pattern = r'resource_metadata=(?:"([^"]+)"|([^\s,]+))' - match = re.search(pattern, www_auth_header) - - if match: - # Return quoted value if present, otherwise unquoted value - return match.group(1) or match.group(2) - - return None - - class OAuthFlowError(Exception): """Base exception for OAuth flow errors.""" @@ -229,9 +204,33 @@ def __init__( ) self._initialized = False - async def _discover_protected_resource(self, response: httpx.Response | None = None) -> httpx.Request: - # RFC9728: Try to extract resource_metadata URL from WWW-Authenticate header - url = _extract_resource_metadata_from_www_auth(response) if response else None + def _extract_resource_metadata_from_www_auth(self, init_response: httpx.Response) -> str | None: + """ + Extract protected resource metadata URL from WWW-Authenticate header as per RFC9728. + + Returns: + Resource metadata URL if found in WWW-Authenticate header, None otherwise + """ + if not init_response or init_response.status_code != 401: + return None + + www_auth_header = init_response.headers.get("WWW-Authenticate") + if not www_auth_header: + return None + + # Pattern matches: resource_metadata="url" or resource_metadata=url (unquoted) + pattern = r'resource_metadata=(?:"([^"]+)"|([^\s,]+))' + match = re.search(pattern, www_auth_header) + + if match: + # Return quoted value if present, otherwise unquoted value + return match.group(1) or match.group(2) + + return None + + async def _discover_protected_resource(self, init_response: httpx.Response | None = None) -> httpx.Request: + # RFC9728: Try to extract resource_metadata URL from WWW-Authenticate header of the initial response + url = self._extract_resource_metadata_from_www_auth(init_response) if init_response else None if not url: # Fallback to well-known discovery diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 8e6b4f54d..13699d0fb 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -196,13 +196,47 @@ class TestOAuthFlow: """Test OAuth flow methods.""" @pytest.mark.anyio - async def test_discover_protected_resource_request(self, oauth_provider): - """Test protected resource discovery request building.""" - request = await oauth_provider._discover_protected_resource() + async def test_discover_protected_resource_request(self, client_metadata, mock_storage): + """Test protected resource discovery request building maintains backward compatibility.""" + async def redirect_handler(url: str) -> None: + pass + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" + provider = OAuthClientProvider( + server_url="https://api.example.com", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + + # Test without response (backward compatibility) + request = await provider._discover_protected_resource() assert request.method == "GET" assert str(request.url) == "https://api.example.com/.well-known/oauth-protected-resource" assert "mcp-protocol-version" in request.headers + + # Test with response but no WWW-Authenticate (fallback) + init_response = httpx.Response( + status_code=401, + headers={}, + request=httpx.Request("GET", "https://request-api.example.com") + ) + + request = await provider._discover_protected_resource(init_response) + assert request.method == "GET" + assert str(request.url) == "https://api.example.com/.well-known/oauth-protected-resource" + assert "mcp-protocol-version" in request.headers + + # Test with WWW-Authenticate header + init_response.headers["WWW-Authenticate"] = 'Bearer resource_metadata="https://prm.example.com/.well-known/oauth-protected-resource/path"' + + request = await provider._discover_protected_resource(init_response) + assert request.method == "GET" + assert str(request.url) == "https://prm.example.com/.well-known/oauth-protected-resource/path" + assert "mcp-protocol-version" in request.headers @pytest.mark.anyio async def test_discover_oauth_metadata_request(self, oauth_provider): @@ -544,3 +578,108 @@ async def test_auth_flow_with_valid_tokens(self, oauth_provider, mock_storage, v await auth_flow.asend(response) except StopAsyncIteration: pass # Expected + + +class TestRFC9728WWWAuthenticate: + """Test RFC9728 WWW-Authenticate header parsing functionality.""" + + @pytest.mark.parametrize("www_auth_header,expected_url", [ + # Quoted URL + ('Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"', + "https://api.example.com/.well-known/oauth-protected-resource"), + # Unquoted URL + ("Bearer resource_metadata=https://api.example.com/.well-known/oauth-protected-resource", + "https://api.example.com/.well-known/oauth-protected-resource"), + # Complex header with multiple parameters + ('Bearer realm="api", resource_metadata="https://api.example.com/.well-known/oauth-protected-resource", error="insufficient_scope"', + "https://api.example.com/.well-known/oauth-protected-resource"), + # Different URL format + ('Bearer resource_metadata="https://custom.domain.com/metadata"', + "https://custom.domain.com/metadata"), + # With path and query params + ('Bearer resource_metadata="https://api.example.com/auth/metadata?version=1"', + "https://api.example.com/auth/metadata?version=1"), + ]) + def test_extract_resource_metadata_from_www_auth_valid_cases(self, client_metadata, mock_storage, www_auth_header, expected_url): + """Test extraction of resource_metadata URL from various valid WWW-Authenticate headers.""" + async def redirect_handler(url: str) -> None: + pass + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" + + provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + + init_response = httpx.Response( + status_code=401, + headers={"WWW-Authenticate": www_auth_header}, + request=httpx.Request("GET", "https://api.example.com/test") + ) + + result = provider._extract_resource_metadata_from_www_auth(init_response) + assert result == expected_url + + @pytest.mark.parametrize("status_code,www_auth_header,description", [ + # No header + (401, None, "no WWW-Authenticate header"), + # Empty header + (401, "", "empty WWW-Authenticate header"), + # Header without resource_metadata + (401, 'Bearer realm="api", error="insufficient_scope"', "no resource_metadata parameter"), + # Malformed header + (401, "Bearer resource_metadata=", "malformed resource_metadata parameter"), + # Non-401 status code + (200, 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"', "200 OK response"), + (500, 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"', "500 error response"), + ]) + def test_extract_resource_metadata_from_www_auth_invalid_cases(self, client_metadata, mock_storage, status_code, www_auth_header, description): + """Test extraction returns None for invalid cases.""" + async def redirect_handler(url: str) -> None: + pass + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" + + provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + + headers = {"WWW-Authenticate": www_auth_header} if www_auth_header is not None else {} + init_response = httpx.Response( + status_code=status_code, + headers=headers, + request=httpx.Request("GET", "https://api.example.com/test") + ) + + result = provider._extract_resource_metadata_from_www_auth(init_response) + assert result is None, f"Should return None for {description}" + + def test_extract_resource_metadata_from_www_auth_none_response(self, client_metadata, mock_storage): + """Test extraction with None response returns None.""" + async def redirect_handler(url: str) -> None: + pass + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" + + provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + + result = provider._extract_resource_metadata_from_www_auth(None) + assert result is None + From 2bc58dbccffd2add41537b504c47b529963dfd7a Mon Sep 17 00:00:00 2001 From: Yuri Kunash Date: Wed, 2 Jul 2025 21:23:37 +0800 Subject: [PATCH 5/5] Liniting issues fixed --- src/mcp/client/auth.py | 20 ++--- tests/client/test_auth.py | 157 +++++++++++++++++++------------------- 2 files changed, 89 insertions(+), 88 deletions(-) diff --git a/src/mcp/client/auth.py b/src/mcp/client/auth.py index baad0040a..3d6f28fa2 100644 --- a/src/mcp/client/auth.py +++ b/src/mcp/client/auth.py @@ -207,36 +207,36 @@ def __init__( def _extract_resource_metadata_from_www_auth(self, init_response: httpx.Response) -> str | None: """ Extract protected resource metadata URL from WWW-Authenticate header as per RFC9728. - + Returns: Resource metadata URL if found in WWW-Authenticate header, None otherwise """ if not init_response or init_response.status_code != 401: return None - + www_auth_header = init_response.headers.get("WWW-Authenticate") if not www_auth_header: return None - + # Pattern matches: resource_metadata="url" or resource_metadata=url (unquoted) pattern = r'resource_metadata=(?:"([^"]+)"|([^\s,]+))' match = re.search(pattern, www_auth_header) - + if match: # Return quoted value if present, otherwise unquoted value return match.group(1) or match.group(2) - + return None - async def _discover_protected_resource(self, init_response: httpx.Response | None = None) -> httpx.Request: + async def _discover_protected_resource(self, init_response: httpx.Response) -> httpx.Request: # RFC9728: Try to extract resource_metadata URL from WWW-Authenticate header of the initial response - url = self._extract_resource_metadata_from_www_auth(init_response) if init_response else None - + url = self._extract_resource_metadata_from_www_auth(init_response) + if not url: # Fallback to well-known discovery auth_base_url = self.context.get_authorization_base_url(self.context.server_url) url = urljoin(auth_base_url, "/.well-known/oauth-protected-resource") - + return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}) async def _handle_protected_resource_response(self, response: httpx.Response) -> None: @@ -521,7 +521,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. if self.context.is_token_valid(): self._add_auth_header(request) - + response = yield request if response.status_code == 401: diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 13699d0fb..a230e1209 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -198,6 +198,7 @@ class TestOAuthFlow: @pytest.mark.anyio async def test_discover_protected_resource_request(self, client_metadata, mock_storage): """Test protected resource discovery request building maintains backward compatibility.""" + async def redirect_handler(url: str) -> None: pass @@ -211,30 +212,24 @@ async def callback_handler() -> tuple[str, str | None]: redirect_handler=redirect_handler, callback_handler=callback_handler, ) - - # Test without response (backward compatibility) - request = await provider._discover_protected_resource() - assert request.method == "GET" - assert str(request.url) == "https://api.example.com/.well-known/oauth-protected-resource" - assert "mcp-protocol-version" in request.headers - - # Test with response but no WWW-Authenticate (fallback) + + # Test without WWW-Authenticate (fallback) init_response = httpx.Response( - status_code=401, - headers={}, - request=httpx.Request("GET", "https://request-api.example.com") + status_code=401, headers={}, request=httpx.Request("GET", "https://request-api.example.com") ) request = await provider._discover_protected_resource(init_response) - assert request.method == "GET" + assert request.method == "GET" assert str(request.url) == "https://api.example.com/.well-known/oauth-protected-resource" assert "mcp-protocol-version" in request.headers # Test with WWW-Authenticate header - init_response.headers["WWW-Authenticate"] = 'Bearer resource_metadata="https://prm.example.com/.well-known/oauth-protected-resource/path"' - + init_response.headers["WWW-Authenticate"] = ( + 'Bearer resource_metadata="https://prm.example.com/.well-known/oauth-protected-resource/path"' + ) + request = await provider._discover_protected_resource(init_response) - assert request.method == "GET" + assert request.method == "GET" assert str(request.url) == "https://prm.example.com/.well-known/oauth-protected-resource/path" assert "mcp-protocol-version" in request.headers @@ -580,28 +575,42 @@ async def test_auth_flow_with_valid_tokens(self, oauth_provider, mock_storage, v pass # Expected -class TestRFC9728WWWAuthenticate: - """Test RFC9728 WWW-Authenticate header parsing functionality.""" - - @pytest.mark.parametrize("www_auth_header,expected_url", [ - # Quoted URL - ('Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"', - "https://api.example.com/.well-known/oauth-protected-resource"), - # Unquoted URL - ("Bearer resource_metadata=https://api.example.com/.well-known/oauth-protected-resource", - "https://api.example.com/.well-known/oauth-protected-resource"), - # Complex header with multiple parameters - ('Bearer realm="api", resource_metadata="https://api.example.com/.well-known/oauth-protected-resource", error="insufficient_scope"', - "https://api.example.com/.well-known/oauth-protected-resource"), - # Different URL format - ('Bearer resource_metadata="https://custom.domain.com/metadata"', - "https://custom.domain.com/metadata"), - # With path and query params - ('Bearer resource_metadata="https://api.example.com/auth/metadata?version=1"', - "https://api.example.com/auth/metadata?version=1"), - ]) - def test_extract_resource_metadata_from_www_auth_valid_cases(self, client_metadata, mock_storage, www_auth_header, expected_url): +class TestProtectedResourceWWWAuthenticate: + """Test RFC9728 WWW-Authenticate header parsing functionality for protected resource.""" + + @pytest.mark.parametrize( + "www_auth_header,expected_url", + [ + # Quoted URL + ( + 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"', + "https://api.example.com/.well-known/oauth-protected-resource", + ), + # Unquoted URL + ( + "Bearer resource_metadata=https://api.example.com/.well-known/oauth-protected-resource", + "https://api.example.com/.well-known/oauth-protected-resource", + ), + # Complex header with multiple parameters + ( + 'Bearer realm="api", resource_metadata="https://api.example.com/.well-known/oauth-protected-resource", ' + 'error="insufficient_scope"', + "https://api.example.com/.well-known/oauth-protected-resource", + ), + # Different URL format + ('Bearer resource_metadata="https://custom.domain.com/metadata"', "https://custom.domain.com/metadata"), + # With path and query params + ( + 'Bearer resource_metadata="https://api.example.com/auth/metadata?version=1"', + "https://api.example.com/auth/metadata?version=1", + ), + ], + ) + def test_extract_resource_metadata_from_www_auth_valid_cases( + self, client_metadata, mock_storage, www_auth_header, expected_url + ): """Test extraction of resource_metadata URL from various valid WWW-Authenticate headers.""" + async def redirect_handler(url: str) -> None: pass @@ -615,31 +624,45 @@ async def callback_handler() -> tuple[str, str | None]: redirect_handler=redirect_handler, callback_handler=callback_handler, ) - + init_response = httpx.Response( status_code=401, headers={"WWW-Authenticate": www_auth_header}, - request=httpx.Request("GET", "https://api.example.com/test") + request=httpx.Request("GET", "https://api.example.com/test"), ) - + result = provider._extract_resource_metadata_from_www_auth(init_response) assert result == expected_url - @pytest.mark.parametrize("status_code,www_auth_header,description", [ - # No header - (401, None, "no WWW-Authenticate header"), - # Empty header - (401, "", "empty WWW-Authenticate header"), - # Header without resource_metadata - (401, 'Bearer realm="api", error="insufficient_scope"', "no resource_metadata parameter"), - # Malformed header - (401, "Bearer resource_metadata=", "malformed resource_metadata parameter"), - # Non-401 status code - (200, 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"', "200 OK response"), - (500, 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"', "500 error response"), - ]) - def test_extract_resource_metadata_from_www_auth_invalid_cases(self, client_metadata, mock_storage, status_code, www_auth_header, description): + @pytest.mark.parametrize( + "status_code,www_auth_header,description", + [ + # No header + (401, None, "no WWW-Authenticate header"), + # Empty header + (401, "", "empty WWW-Authenticate header"), + # Header without resource_metadata + (401, 'Bearer realm="api", error="insufficient_scope"', "no resource_metadata parameter"), + # Malformed header + (401, "Bearer resource_metadata=", "malformed resource_metadata parameter"), + # Non-401 status code + ( + 200, + 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"', + "200 OK response", + ), + ( + 500, + 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"', + "500 error response", + ), + ], + ) + def test_extract_resource_metadata_from_www_auth_invalid_cases( + self, client_metadata, mock_storage, status_code, www_auth_header, description + ): """Test extraction returns None for invalid cases.""" + async def redirect_handler(url: str) -> None: pass @@ -653,33 +676,11 @@ async def callback_handler() -> tuple[str, str | None]: redirect_handler=redirect_handler, callback_handler=callback_handler, ) - + headers = {"WWW-Authenticate": www_auth_header} if www_auth_header is not None else {} init_response = httpx.Response( - status_code=status_code, - headers=headers, - request=httpx.Request("GET", "https://api.example.com/test") + status_code=status_code, headers=headers, request=httpx.Request("GET", "https://api.example.com/test") ) - + result = provider._extract_resource_metadata_from_www_auth(init_response) assert result is None, f"Should return None for {description}" - - def test_extract_resource_metadata_from_www_auth_none_response(self, client_metadata, mock_storage): - """Test extraction with None response returns None.""" - async def redirect_handler(url: str) -> None: - pass - - async def callback_handler() -> tuple[str, str | None]: - return "test_auth_code", "test_state" - - provider = OAuthClientProvider( - server_url="https://api.example.com/v1/mcp", - client_metadata=client_metadata, - storage=mock_storage, - redirect_handler=redirect_handler, - callback_handler=callback_handler, - ) - - result = provider._extract_resource_metadata_from_www_auth(None) - assert result is None -