diff --git a/python/delta_sharing/_internal_auth.py b/python/delta_sharing/_internal_auth.py new file mode 100644 index 000000000..55a15fb06 --- /dev/null +++ b/python/delta_sharing/_internal_auth.py @@ -0,0 +1,233 @@ +# +# Copyright (C) 2021 The Delta Lake Project Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from abc import ABC, abstractmethod +from datetime import datetime +from typing import Optional +import requests +import base64 +import json +import threading +import requests.sessions +import time +from typing import Dict + +from delta_sharing.protocol import ( + DeltaSharingProfile, +) + +# This module contains internal implementation classes. +# These classes are not part of the public API and should not be used directly by users. +# Internal classes may change or be removed at any time without notice. + + +class AuthConfig: + def __init__(self, token_exchange_max_retries=5, + token_exchange_max_retry_duration_in_seconds=60, + token_renewal_threshold_in_seconds=600): + self.token_exchange_max_retries = token_exchange_max_retries + self.token_exchange_max_retry_duration_in_seconds = ( + token_exchange_max_retry_duration_in_seconds) + self.token_renewal_threshold_in_seconds = token_renewal_threshold_in_seconds + + +class AuthCredentialProvider(ABC): + @abstractmethod + def add_auth_header(self, session: requests.Session) -> None: + pass + + def is_expired(self) -> bool: + return False + + @abstractmethod + def get_expiration_time(self) -> Optional[str]: + return None + + +class BearerTokenAuthProvider(AuthCredentialProvider): + def __init__(self, bearer_token: str, expiration_time: Optional[str]): + self.bearer_token = bearer_token + self.expiration_time = expiration_time + + def add_auth_header(self, session: requests.Session) -> None: + session.headers.update( + { + "Authorization": f"Bearer {self.bearer_token}", + } + ) + + def is_expired(self) -> bool: + if self.expiration_time is None: + return False + try: + expiration_time_as_timestamp = datetime.fromisoformat(self.expiration_time) + return expiration_time_as_timestamp < datetime.now() + except ValueError: + return False + + def get_expiration_time(self) -> Optional[str]: + return self.expiration_time + + +class BasicAuthProvider(AuthCredentialProvider): + def __init__(self, endpoint: str, username: str, password: str): + self.username = username + self.password = password + self.endpoint = endpoint + + def add_auth_header(self, session: requests.Session) -> None: + session.auth = (self.username, self.password) + session.post(self.endpoint, data={"grant_type": "client_credentials"},) + + def is_expired(self) -> bool: + return False + + def get_expiration_time(self) -> Optional[str]: + return None + + +class OAuthClientCredentials: + def __init__(self, access_token: str, expires_in: int, creation_timestamp: int): + self.access_token = access_token + self.expires_in = expires_in + self.creation_timestamp = creation_timestamp + + +class OAuthClient: + def __init__(self, + token_endpoint: str, + client_id: str, + client_secret: str, + scope: Optional[str] = None): + self.token_endpoint = token_endpoint + self.client_id = client_id + self.client_secret = client_secret + self.scope = scope + + def client_credentials(self) -> OAuthClientCredentials: + credentials = base64.b64encode( + f"{self.client_id}:{self.client_secret}".encode('utf-8')).decode('utf-8') + headers = { + 'accept': 'application/json', + 'authorization': f'Basic {credentials}', + 'content-type': 'application/x-www-form-urlencoded' + } + body = f"grant_type=client_credentials{f'&scope={self.scope}' if self.scope else ''}" + response = requests.post(self.token_endpoint, headers=headers, data=body) + response.raise_for_status() + return self.parse_oauth_token_response(response.text) + + def parse_oauth_token_response(self, response: str) -> OAuthClientCredentials: + if not response: + raise RuntimeError("Empty response from OAuth token endpoint") + json_node = json.loads(response) + if 'access_token' not in json_node or not isinstance(json_node['access_token'], str): + raise RuntimeError("Missing 'access_token' field in OAuth token response") + if 'expires_in' not in json_node or not isinstance(json_node['expires_in'], int): + raise RuntimeError("Missing 'expires_in' field in OAuth token response") + return OAuthClientCredentials( + json_node['access_token'], + json_node['expires_in'], + int(datetime.now().timestamp()) + ) + + +class OAuthClientCredentialsAuthProvider(AuthCredentialProvider): + def __init__(self, oauth_client: OAuthClient, auth_config: AuthConfig = AuthConfig()): + self.auth_config = auth_config + self.oauth_client = oauth_client + self.current_token: Optional[OAuthClientCredentials] = None + self.lock = threading.RLock() + + def add_auth_header(self,session: requests.Session) -> None: + token = self.maybe_refresh_token() + with self.lock: + session.headers.update( + { + "Authorization": f"Bearer {token.access_token}", + } + ) + + def maybe_refresh_token(self) -> OAuthClientCredentials: + with self.lock: + if self.current_token and not self.needs_refresh(self.current_token): + return self.current_token + new_token = self.oauth_client.client_credentials() + self.current_token = new_token + return new_token + + def needs_refresh(self, token: OAuthClientCredentials) -> bool: + now = int(time.time()) + expiration_time = token.creation_timestamp + token.expires_in + return expiration_time - now < self.auth_config.token_renewal_threshold_in_seconds + + def get_expiration_time(self) -> Optional[str]: + return None + + +class AuthCredentialProviderFactory: + __oauth_auth_provider_cache : Dict[ + DeltaSharingProfile, + OAuthClientCredentialsAuthProvider] = {} + + @staticmethod + def create_auth_credential_provider(profile: DeltaSharingProfile): + if profile.share_credentials_version == 2: + if profile.type == "oauth_client_credentials": + return AuthCredentialProviderFactory.__oauth_client_credentials(profile) + elif profile.type == "basic": + return AuthCredentialProviderFactory.__auth_basic(profile) + elif (profile.share_credentials_version == 1 and + (profile.type is None or profile.type == "bearer_token")): + return AuthCredentialProviderFactory.__auth_bearer_token(profile) + + # any other scenario is unsupported + raise RuntimeError(f"unsupported profile.type: {profile.type}" + f" profile.share_credentials_version" + f" {profile.share_credentials_version}") + + @staticmethod + def __oauth_client_credentials(profile): + # Once a clientId/clientSecret is exchanged for an accessToken, + # the accessToken can be reused until it expires. + # The Python client re-creates DeltaSharingClient for different requests. + # To ensure the OAuth access_token is reused, + # we keep a mapping from profile -> OAuthClientCredentialsAuthProvider. + # This prevents re-initializing OAuthClientCredentialsAuthProvider for the same profile, + # ensuring the access_token can be reused. + if profile in AuthCredentialProviderFactory.__oauth_auth_provider_cache: + return AuthCredentialProviderFactory.__oauth_auth_provider_cache[profile] + + oauth_client = OAuthClient( + token_endpoint=profile.token_endpoint, + client_id=profile.client_id, + client_secret=profile.client_secret, + scope=profile.scope + ) + provider = OAuthClientCredentialsAuthProvider( + oauth_client=oauth_client, + auth_config=AuthConfig() + ) + AuthCredentialProviderFactory.__oauth_auth_provider_cache[profile] = provider + return provider + + @staticmethod + def __auth_bearer_token(profile): + return BearerTokenAuthProvider(profile.bearer_token, profile.expiration_time) + + @staticmethod + def __auth_basic(profile): + return BasicAuthProvider(profile.endpoint, profile.username, profile.password) diff --git a/python/delta_sharing/protocol.py b/python/delta_sharing/protocol.py index ea1566786..d8873c156 100644 --- a/python/delta_sharing/protocol.py +++ b/python/delta_sharing/protocol.py @@ -35,6 +35,7 @@ class DeltaSharingProfile: client_secret: Optional[str] = None username: Optional[str] = None password: Optional[str] = None + scope: Optional[str] = None def __post_init__(self): if self.share_credentials_version > DeltaSharingProfile.CURRENT: @@ -77,7 +78,7 @@ def from_json(json) -> "DeltaSharingProfile": ) elif share_credentials_version == 2: type = json["type"] - if type == "persistent_oauth2.0": + if type == "oauth_client_credentials": token_endpoint = json["tokenEndpoint"] if token_endpoint is not None and token_endpoint.endswith("/"): token_endpoint = token_endpoint[:-1] @@ -88,6 +89,7 @@ def from_json(json) -> "DeltaSharingProfile": token_endpoint=token_endpoint, client_id=json["clientId"], client_secret=json["clientSecret"], + scope=json.get("scope"), ) elif type == "bearer_token": return DeltaSharingProfile( @@ -107,7 +109,7 @@ def from_json(json) -> "DeltaSharingProfile": ) else: raise ValueError( - "The current release does not supports {type} type. " + f"The current release does not supports {type} type. " "Please check type.") else: raise ValueError( diff --git a/python/delta_sharing/rest_client.py b/python/delta_sharing/rest_client.py index 73fe82573..e1103239a 100644 --- a/python/delta_sharing/rest_client.py +++ b/python/delta_sharing/rest_client.py @@ -22,7 +22,6 @@ import time import logging import pprint -from datetime import datetime import requests from requests.exceptions import HTTPError, ConnectionError @@ -39,6 +38,8 @@ Table, ) +from delta_sharing._internal_auth import AuthCredentialProviderFactory + @dataclass(frozen=True) class ListSharesResponse: @@ -151,65 +152,20 @@ def __init__(self, profile: DeltaSharingProfile, num_retries=10): self._profile = profile self._num_retries = num_retries self._sleeper = lambda sleep_ms: time.sleep(sleep_ms / 1000) - self.auth_session(profile) - - def auth_session(self, profile): - self._session = requests.Session() - self.__auth_broker(profile) - if urlparse(profile.endpoint).hostname == "localhost": - self._session.verify = False - - def __auth_broker(self, profile): - if profile.share_credentials_version == 2: - if profile.type == "persistent_oauth2.0": - self.__auth_persistent_oauth2(profile) - elif profile.type == "bearer_token": - self.__auth_bearer_token(profile) - elif profile.type == "basic": - self.__auth_basic(profile) - else: - self.__auth_bearer_token(profile) - else: - self.__auth_bearer_token(profile) - - def __auth_bearer_token(self, profile): - self._session.headers.update( - { - "Authorization": f"Bearer {profile.bearer_token}", - "User-Agent": DataSharingRestClient.USER_AGENT, - } - ) - - def __auth_persistent_oauth2(self, profile): - headers = {"Content-Type": "application/x-www-form-urlencoded", - "Accept": "application/json"} - - response = requests.post(profile.token_endpoint, - data={"grant_type": "client_credentials"}, - headers=headers, - auth=(profile.client_id, - profile.client_secret),) - - bearer_token = "{}".format(response.json()["access_token"]) + self.__auth_session(profile) self._session.headers.update( { - "Authorization": f"Bearer {bearer_token}", "User-Agent": DataSharingRestClient.USER_AGENT, } ) - def __auth_basic(self, profile): - self._session.auth = (profile.username, profile.password) - - response = self._session.post(profile.endpoint, - data={"grant_type": "client_credentials"},) - - self._session.headers.update( - { - "User-Agent": DataSharingRestClient.USER_AGENT, - } - ) + def __auth_session(self, profile): + self._session = requests.Session() + self._auth_credential_provider = ( + AuthCredentialProviderFactory.create_auth_credential_provider(profile)) + if urlparse(profile.endpoint).hostname == "localhost": + self._session.verify = False def set_sharing_capabilities_header(self): delta_sharing_capabilities = ( @@ -502,6 +458,7 @@ def _request_internal( **kwargs, ): assert target.startswith("/"), "Targets should start with '/'" + self._auth_credential_provider.add_auth_header(self._session) response = request(f"{self._profile.endpoint}{target}", **kwargs) try: response.raise_for_status() @@ -541,10 +498,7 @@ def _should_retry(self, error): def _error_on_expired_token(self, error): if isinstance(error, HTTPError) and error.response.status_code == 401: try: - expiration_time = datetime.strptime( - self._profile.expiration_time, "%Y-%m-%dT%H:%M:%S.%fZ" - ) - return datetime.now() > expiration_time + self._auth_credential_provider.is_expired() except Exception: return False else: diff --git a/python/delta_sharing/tests/test_auth.py b/python/delta_sharing/tests/test_auth.py new file mode 100644 index 000000000..81dbd1ba9 --- /dev/null +++ b/python/delta_sharing/tests/test_auth.py @@ -0,0 +1,294 @@ +# +# Copyright (C) 2021 The Delta Lake Project Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from unittest.mock import MagicMock +from datetime import datetime, timedelta +from delta_sharing._internal_auth import (OAuthClient, + BasicAuthProvider, + AuthCredentialProviderFactory, + OAuthClientCredentialsAuthProvider, + OAuthClientCredentials) +from requests import Session +import requests +from delta_sharing._internal_auth import BearerTokenAuthProvider +from delta_sharing.protocol import DeltaSharingProfile + + +def test_bearer_token_auth_provider_initialization(): + token = "test-token" + expiration_time = "2021-11-12T00:12:29.0Z" + provider = BearerTokenAuthProvider(token, expiration_time) + assert provider.bearer_token == token + assert provider.expiration_time == expiration_time + + +def test_bearer_token_auth_provider_add_auth_header(): + token = "test-token" + provider = BearerTokenAuthProvider(token, None) + session = requests.Session() + provider.add_auth_header(session) + assert session.headers["Authorization"] == f"Bearer {token}" + + +def test_bearer_token_auth_provider_is_expired(): + expired_token = "expired-token" + expiration_time = (datetime.now() - timedelta(days=1)).isoformat() + provider = BearerTokenAuthProvider(expired_token, expiration_time) + assert provider.is_expired() + + valid_token = "valid-token" + expiration_time = (datetime.now() + timedelta(days=1)).isoformat() + provider = BearerTokenAuthProvider(valid_token, expiration_time) + assert not provider.is_expired() + + +def test_bearer_token_auth_provider_get_expiration_time(): + token = "test-token" + expiration_time = "2021-11-12T00:12:29.0Z" + provider = BearerTokenAuthProvider(token, expiration_time) + assert provider.get_expiration_time() == expiration_time + + provider = BearerTokenAuthProvider(token, None) + assert provider.get_expiration_time() is None + + +def test_oauth_client_credentials_auth_provider_exchange_token(): + oauth_client = MagicMock(spec=OAuthClient) + profile = MagicMock() + profile.token_endpoint = "http://example.com/token" + profile.client_id = "client-id" + profile.client_secret = "client-secret" + profile.scope = None + + provider = OAuthClientCredentialsAuthProvider(oauth_client) + mock_session = MagicMock(spec=Session) + mock_session.headers = MagicMock() + + token = OAuthClientCredentials("access-token", 3600, int(datetime.now().timestamp())) + oauth_client.client_credentials.return_value = token + + provider.add_auth_header(mock_session) + + mock_session.headers.update.assert_called_once_with( + {"Authorization": f"Bearer {token.access_token}"}) + oauth_client.client_credentials.assert_called_once() + + +def test_oauth_client_credentials_auth_provider_reuse_token(): + oauth_client = MagicMock(spec=OAuthClient) + profile = MagicMock() + profile.token_endpoint = "http://example.com/token" + profile.client_id = "client-id" + profile.client_secret = "client-secret" + profile.scope = None + + provider = OAuthClientCredentialsAuthProvider(oauth_client) + mock_session = MagicMock(spec=Session) + mock_session.headers = MagicMock() + + valid_token = OAuthClientCredentials( + "valid-token", 3600, int(datetime.now().timestamp())) + provider.current_token = valid_token + + provider.add_auth_header(mock_session) + + mock_session.headers.update.assert_called_once_with( + {"Authorization": f"Bearer {valid_token.access_token}"}) + oauth_client.client_credentials.assert_not_called() + + +def test_oauth_client_credentials_auth_provider_refresh_token(): + oauth_client = MagicMock(spec=OAuthClient) + profile = MagicMock() + profile.token_endpoint = "http://example.com/token" + profile.client_id = "client-id" + profile.client_secret = "client-secret" + profile.scope = None + + provider = OAuthClientCredentialsAuthProvider(oauth_client) + mock_session = MagicMock(spec=Session) + mock_session.headers = MagicMock() + + expired_token = OAuthClientCredentials( + "expired-token", 1, int(datetime.now().timestamp()) - 3600) + new_token = OAuthClientCredentials( + "new-token", 3600, int(datetime.now().timestamp())) + provider.current_token = expired_token + oauth_client.client_credentials.return_value = new_token + + provider.add_auth_header(mock_session) + + mock_session.headers.update.assert_called_once_with( + {"Authorization": f"Bearer {new_token.access_token}"}) + oauth_client.client_credentials.assert_called_once() + + +def test_oauth_client_credentials_auth_provider_needs_refresh(): + oauth_client = MagicMock(spec=OAuthClient) + profile = MagicMock() + profile.token_endpoint = "http://example.com/token" + profile.client_id = "client-id" + profile.client_secret = "client-secret" + profile.scope = None + + provider = OAuthClientCredentialsAuthProvider(oauth_client) + + expired_token = OAuthClientCredentials( + "expired-token", 1, int(datetime.now().timestamp()) - 3600) + assert provider.needs_refresh(expired_token) + + token_expiring_soon = OAuthClientCredentials( + "expiring-soon-token", 600 - 5, int(datetime.now().timestamp())) + assert provider.needs_refresh(token_expiring_soon) + + valid_token = OAuthClientCredentials( + "valid-token", 600 + 10, int(datetime.now().timestamp())) + assert not provider.needs_refresh(valid_token) + + +def test_oauth_client_credentials_auth_provider_is_expired(): + oauth_client = MagicMock(spec=OAuthClient) + profile = MagicMock() + profile.token_endpoint = "http://example.com/token" + profile.client_id = "client-id" + profile.client_secret = "client-secret" + profile.scope = None + + provider = OAuthClientCredentialsAuthProvider(oauth_client) + assert not provider.is_expired() + + +def test_oauth_client_credentials_auth_provider_get_expiration_time(): + oauth_client = MagicMock(spec=OAuthClient) + profile = MagicMock() + profile.token_endpoint = "http://example.com/token" + profile.client_id = "client-id" + profile.client_secret = "client-secret" + profile.scope = None + + provider = OAuthClientCredentialsAuthProvider(oauth_client) + assert provider.get_expiration_time() is None + + +def test_basic_auth_provider_initialization(): + provider = BasicAuthProvider("https://localhost", "username", "password") + assert provider.username == "username" + assert provider.password == "password" + + +def test_basic_auth_provider_add_auth_header(): + provider = BasicAuthProvider("https://localhost", "username", "password") + session = MagicMock(spec=requests.Session) + session.headers = MagicMock() + session.auth = MagicMock() + provider.add_auth_header(session) + session.post("https://localhost/delta-sharing/", data={"grant_type": "client_credentials"}) + assert session.auth == ("username", "password") + + +def test_basic_auth_provider_is_expired(): + provider = BasicAuthProvider("https://localhost", "username", "password") + assert not provider.is_expired() + + +def test_basic_auth_provider_get_expiration_time(): + provider = BasicAuthProvider("https://localhost", "username", "password") + assert provider.get_expiration_time() is None + + +def test_factory_creation(): + profile_basic = DeltaSharingProfile( + share_credentials_version=2, + type="basic", + endpoint="https://localhost/delta-sharing/", + username="username", + password="password" + ) + provider = AuthCredentialProviderFactory.create_auth_credential_provider(profile_basic) + assert isinstance(provider, BasicAuthProvider) + + profile_bearer = DeltaSharingProfile( + share_credentials_version=1, + type="bearer_token", + endpoint="https://localhost/delta-sharing/", + bearer_token="token", + expiration_time=(datetime.now() + timedelta(hours=1)).isoformat() + ) + provider = AuthCredentialProviderFactory.create_auth_credential_provider(profile_bearer) + assert isinstance(provider, BearerTokenAuthProvider) + + profile_oauth = DeltaSharingProfile( + share_credentials_version=2, + type="oauth_client_credentials", + endpoint="https://localhost/delta-sharing/", + token_endpoint="https://localhost/token", + client_id="clientId", + client_secret="clientSecret" + ) + provider = AuthCredentialProviderFactory.create_auth_credential_provider(profile_oauth) + assert isinstance(provider, OAuthClientCredentialsAuthProvider) + + +def test_oauth_auth_provider_reused(): + profile_oauth1 = DeltaSharingProfile( + share_credentials_version=2, + type="oauth_client_credentials", + endpoint="https://localhost/delta-sharing/", + token_endpoint="https://localhost/token", + client_id="clientId", + client_secret="clientSecret" + ) + provider1 = AuthCredentialProviderFactory.create_auth_credential_provider(profile_oauth1) + assert isinstance(provider1, OAuthClientCredentialsAuthProvider) + + profile_oauth2 = DeltaSharingProfile( + share_credentials_version=2, + type="oauth_client_credentials", + endpoint="https://localhost/delta-sharing/", + token_endpoint="https://localhost/token", + client_id="clientId", + client_secret="clientSecret" + ) + + provider2 = AuthCredentialProviderFactory.create_auth_credential_provider(profile_oauth2) + + assert provider1 == provider2 + + +def test_oauth_auth_provider_with_different_profiles(): + profile_oauth1 = DeltaSharingProfile( + share_credentials_version=2, + type="oauth_client_credentials", + endpoint="https://localhost/delta-sharing/", + token_endpoint="https://localhost/1/token", + client_id="clientId", + client_secret="clientSecret" + ) + provider1 = AuthCredentialProviderFactory.create_auth_credential_provider(profile_oauth1) + assert isinstance(provider1, OAuthClientCredentialsAuthProvider) + + profile_oauth2 = DeltaSharingProfile( + share_credentials_version=2, + type="oauth_client_credentials", + endpoint="https://localhost/delta-sharing/", + token_endpoint="https://localhost/2/token", + client_id="clientId", + client_secret="clientSecret" + ) + + provider2 = AuthCredentialProviderFactory.create_auth_credential_provider(profile_oauth2) + + assert provider1 != provider2 diff --git a/python/delta_sharing/tests/test_oauth_client.py b/python/delta_sharing/tests/test_oauth_client.py new file mode 100644 index 000000000..bf2316cca --- /dev/null +++ b/python/delta_sharing/tests/test_oauth_client.py @@ -0,0 +1,81 @@ +# +# Copyright (C) 2021 The Delta Lake Project Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest +import requests +from requests.models import Response +from unittest.mock import patch +from datetime import datetime +from delta_sharing._internal_auth import OAuthClient + + +class MockServer: + def __init__(self): + self.url = "http://localhost:1080/token" + self.responses = [] + + def add_response(self, status_code, json_data): + response = Response() + response.status_code = status_code + response._content = json_data.encode('utf-8') + self.responses.append(response) + + def get_response(self): + return self.responses.pop(0) + + +@pytest.fixture +def mock_server(): + server = MockServer() + yield server + + +def test_oauth_client_should_parse_token_response_correctly(mock_server): + mock_server.add_response( + 200, + '{"access_token": "test-access-token", "expires_in": 3600, "token_type": "bearer"}') + + with patch('requests.post') as mock_post: + mock_post.side_effect = lambda *args, **kwargs: mock_server.get_response() + oauth_client = OAuthClient( + token_endpoint=mock_server.url, + client_id="client-id", + client_secret="client-secret" + ) + + start = datetime.now().timestamp() + token = oauth_client.client_credentials() + end = datetime.now().timestamp() + + assert token.access_token == "test-access-token" + assert token.expires_in == 3600 + assert int(start) <= token.creation_timestamp + assert token.creation_timestamp <= int(end) + + +def test_oauth_client_should_handle_401_unauthorized_response(mock_server): + mock_server.add_response(401, 'Unauthorized') + + with patch('requests.post') as mock_post: + mock_post.side_effect = lambda *args, **kwargs: mock_server.get_response() + oauth_client = OAuthClient( + token_endpoint=mock_server.url, + client_id="client-id", + client_secret="client-secret" + ) + try: + oauth_client.client_credentials() + except requests.HTTPError as e: + assert e.response.status_code == 401 diff --git a/python/delta_sharing/tests/test_profile_oauth2.json b/python/delta_sharing/tests/test_profile_oauth2.json index 2242aa57f..3e253832a 100644 --- a/python/delta_sharing/tests/test_profile_oauth2.json +++ b/python/delta_sharing/tests/test_profile_oauth2.json @@ -1,6 +1,6 @@ { "shareCredentialsVersion": 2, - "type": "persistent_oauth2.0", + "type": "oauth_client_credentials", "endpoint": "https://localhost/delta-sharing/", "tokenEndpoint": "tokenEndpoint", "clientId": "clientId", diff --git a/python/delta_sharing/tests/test_protocol.py b/python/delta_sharing/tests/test_protocol.py index 1e980b5a2..3301eb390 100644 --- a/python/delta_sharing/tests/test_protocol.py +++ b/python/delta_sharing/tests/test_protocol.py @@ -186,11 +186,11 @@ def test_share_profile_bearer(tmp_path): DeltaSharingProfile.read_from_file(io.StringIO(json)) -def test_share_profile_oauth2(tmp_path): +def oauth_client_credentials(tmp_path): json = """ { "shareCredentialsVersion": 2, - "type": "persistent_oauth2.0", + "type": "oauth_client_credentials", "endpoint": "https://localhost/delta-sharing/", "tokenEndpoint": "tokenEndpoint", "clientId": "clientId", @@ -202,7 +202,7 @@ def test_share_profile_oauth2(tmp_path): "https://localhost/delta-sharing", None, None, - "persistent_oauth2.0", + "oauth_client_credentials", "tokenEndpoint", "clientId", "clientSecret") @@ -212,7 +212,7 @@ def test_share_profile_oauth2(tmp_path): "https://localhost/delta-sharing", None, None, - "persistent_oauth2.0", + "oauth_client_credentials", "tokenEndpoint", "clientId", "clientSecret") @@ -226,7 +226,7 @@ def test_share_profile_oauth2(tmp_path): "https://localhost/delta-sharing", None, None, - "persistent_oauth2.0", + "oauth_client_credentials", "tokenEndpoint", "clientId", "clientSecret") @@ -236,7 +236,7 @@ def test_share_profile_oauth2(tmp_path): "https://localhost/delta-sharing", None, None, - "persistent_oauth2.0", + "oauth_client_credentials", "tokenEndpoint", "clientId", "clientSecret") @@ -246,7 +246,7 @@ def test_share_profile_oauth2(tmp_path): "https://localhost/delta-sharing", None, None, - "persistent_oauth2.0", + "oauth_client_credentials", "tokenEndpoint", "clientId", "clientSecret") @@ -256,7 +256,7 @@ def test_share_profile_oauth2(tmp_path): "https://localhost/delta-sharing", None, None, - "persistent_oauth2.0", + "oauth_client_credentials", "tokenEndpoint", "clientId", "clientSecret") @@ -264,7 +264,7 @@ def test_share_profile_oauth2(tmp_path): json = """ { "shareCredentialsVersion": 100, - "type": "persistent_oauth2.0", + "type": "oauth_client_credentials", "endpoint": "https://localhost/delta-sharing/", "tokenEndpoint": "tokenEndpoint", "clientId": "clientId",