From 1be4eedad10aa4ff8379cd2607b3f00f76f35de3 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Wed, 14 Aug 2024 09:42:26 -0700 Subject: [PATCH] Python OAuth client_credentials support (#563) The current OAuth implementation in the Python client does not reuse access tokens. Instead, it exchanges the client-id and client-secret for a new access token with every request. This behavior increases the load on the tokenEndpoint and introduces unnecessary latency in the request processing. This PR updates the OAuth implementation to oauth reuse access tokens, aligning the Python client with the OAuth behavior of the Scala client, as detailed in this [PR for the Scala client](https://github.com/delta-io/delta-sharing/pull/553). For additional context on the usage of auth_provider and its functionality, please refer to the description in the original PR for the Spark client. --- python/delta_sharing/_internal_auth.py | 233 ++++++++++++++ python/delta_sharing/protocol.py | 6 +- python/delta_sharing/rest_client.py | 68 +--- python/delta_sharing/tests/test_auth.py | 294 ++++++++++++++++++ .../delta_sharing/tests/test_oauth_client.py | 81 +++++ .../tests/test_profile_oauth2.json | 2 +- python/delta_sharing/tests/test_protocol.py | 18 +- 7 files changed, 633 insertions(+), 69 deletions(-) create mode 100644 python/delta_sharing/_internal_auth.py create mode 100644 python/delta_sharing/tests/test_auth.py create mode 100644 python/delta_sharing/tests/test_oauth_client.py diff --git a/python/delta_sharing/_internal_auth.py b/python/delta_sharing/_internal_auth.py new file mode 100644 index 000000000..55a15fb06 --- /dev/null +++ b/python/delta_sharing/_internal_auth.py @@ -0,0 +1,233 @@ +# +# Copyright (C) 2021 The Delta Lake Project Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from abc import ABC, abstractmethod +from datetime import datetime +from typing import Optional +import requests +import base64 +import json +import threading +import requests.sessions +import time +from typing import Dict + +from delta_sharing.protocol import ( + DeltaSharingProfile, +) + +# This module contains internal implementation classes. +# These classes are not part of the public API and should not be used directly by users. +# Internal classes may change or be removed at any time without notice. + + +class AuthConfig: + def __init__(self, token_exchange_max_retries=5, + token_exchange_max_retry_duration_in_seconds=60, + token_renewal_threshold_in_seconds=600): + self.token_exchange_max_retries = token_exchange_max_retries + self.token_exchange_max_retry_duration_in_seconds = ( + token_exchange_max_retry_duration_in_seconds) + self.token_renewal_threshold_in_seconds = token_renewal_threshold_in_seconds + + +class AuthCredentialProvider(ABC): + @abstractmethod + def add_auth_header(self, session: requests.Session) -> None: + pass + + def is_expired(self) -> bool: + return False + + @abstractmethod + def get_expiration_time(self) -> Optional[str]: + return None + + +class BearerTokenAuthProvider(AuthCredentialProvider): + def __init__(self, bearer_token: str, expiration_time: Optional[str]): + self.bearer_token = bearer_token + self.expiration_time = expiration_time + + def add_auth_header(self, session: requests.Session) -> None: + session.headers.update( + { + "Authorization": f"Bearer {self.bearer_token}", + } + ) + + def is_expired(self) -> bool: + if self.expiration_time is None: + return False + try: + expiration_time_as_timestamp = datetime.fromisoformat(self.expiration_time) + return expiration_time_as_timestamp < datetime.now() + except ValueError: + return False + + def get_expiration_time(self) -> Optional[str]: + return self.expiration_time + + +class BasicAuthProvider(AuthCredentialProvider): + def __init__(self, endpoint: str, username: str, password: str): + self.username = username + self.password = password + self.endpoint = endpoint + + def add_auth_header(self, session: requests.Session) -> None: + session.auth = (self.username, self.password) + session.post(self.endpoint, data={"grant_type": "client_credentials"},) + + def is_expired(self) -> bool: + return False + + def get_expiration_time(self) -> Optional[str]: + return None + + +class OAuthClientCredentials: + def __init__(self, access_token: str, expires_in: int, creation_timestamp: int): + self.access_token = access_token + self.expires_in = expires_in + self.creation_timestamp = creation_timestamp + + +class OAuthClient: + def __init__(self, + token_endpoint: str, + client_id: str, + client_secret: str, + scope: Optional[str] = None): + self.token_endpoint = token_endpoint + self.client_id = client_id + self.client_secret = client_secret + self.scope = scope + + def client_credentials(self) -> OAuthClientCredentials: + credentials = base64.b64encode( + f"{self.client_id}:{self.client_secret}".encode('utf-8')).decode('utf-8') + headers = { + 'accept': 'application/json', + 'authorization': f'Basic {credentials}', + 'content-type': 'application/x-www-form-urlencoded' + } + body = f"grant_type=client_credentials{f'&scope={self.scope}' if self.scope else ''}" + response = requests.post(self.token_endpoint, headers=headers, data=body) + response.raise_for_status() + return self.parse_oauth_token_response(response.text) + + def parse_oauth_token_response(self, response: str) -> OAuthClientCredentials: + if not response: + raise RuntimeError("Empty response from OAuth token endpoint") + json_node = json.loads(response) + if 'access_token' not in json_node or not isinstance(json_node['access_token'], str): + raise RuntimeError("Missing 'access_token' field in OAuth token response") + if 'expires_in' not in json_node or not isinstance(json_node['expires_in'], int): + raise RuntimeError("Missing 'expires_in' field in OAuth token response") + return OAuthClientCredentials( + json_node['access_token'], + json_node['expires_in'], + int(datetime.now().timestamp()) + ) + + +class OAuthClientCredentialsAuthProvider(AuthCredentialProvider): + def __init__(self, oauth_client: OAuthClient, auth_config: AuthConfig = AuthConfig()): + self.auth_config = auth_config + self.oauth_client = oauth_client + self.current_token: Optional[OAuthClientCredentials] = None + self.lock = threading.RLock() + + def add_auth_header(self,session: requests.Session) -> None: + token = self.maybe_refresh_token() + with self.lock: + session.headers.update( + { + "Authorization": f"Bearer {token.access_token}", + } + ) + + def maybe_refresh_token(self) -> OAuthClientCredentials: + with self.lock: + if self.current_token and not self.needs_refresh(self.current_token): + return self.current_token + new_token = self.oauth_client.client_credentials() + self.current_token = new_token + return new_token + + def needs_refresh(self, token: OAuthClientCredentials) -> bool: + now = int(time.time()) + expiration_time = token.creation_timestamp + token.expires_in + return expiration_time - now < self.auth_config.token_renewal_threshold_in_seconds + + def get_expiration_time(self) -> Optional[str]: + return None + + +class AuthCredentialProviderFactory: + __oauth_auth_provider_cache : Dict[ + DeltaSharingProfile, + OAuthClientCredentialsAuthProvider] = {} + + @staticmethod + def create_auth_credential_provider(profile: DeltaSharingProfile): + if profile.share_credentials_version == 2: + if profile.type == "oauth_client_credentials": + return AuthCredentialProviderFactory.__oauth_client_credentials(profile) + elif profile.type == "basic": + return AuthCredentialProviderFactory.__auth_basic(profile) + elif (profile.share_credentials_version == 1 and + (profile.type is None or profile.type == "bearer_token")): + return AuthCredentialProviderFactory.__auth_bearer_token(profile) + + # any other scenario is unsupported + raise RuntimeError(f"unsupported profile.type: {profile.type}" + f" profile.share_credentials_version" + f" {profile.share_credentials_version}") + + @staticmethod + def __oauth_client_credentials(profile): + # Once a clientId/clientSecret is exchanged for an accessToken, + # the accessToken can be reused until it expires. + # The Python client re-creates DeltaSharingClient for different requests. + # To ensure the OAuth access_token is reused, + # we keep a mapping from profile -> OAuthClientCredentialsAuthProvider. + # This prevents re-initializing OAuthClientCredentialsAuthProvider for the same profile, + # ensuring the access_token can be reused. + if profile in AuthCredentialProviderFactory.__oauth_auth_provider_cache: + return AuthCredentialProviderFactory.__oauth_auth_provider_cache[profile] + + oauth_client = OAuthClient( + token_endpoint=profile.token_endpoint, + client_id=profile.client_id, + client_secret=profile.client_secret, + scope=profile.scope + ) + provider = OAuthClientCredentialsAuthProvider( + oauth_client=oauth_client, + auth_config=AuthConfig() + ) + AuthCredentialProviderFactory.__oauth_auth_provider_cache[profile] = provider + return provider + + @staticmethod + def __auth_bearer_token(profile): + return BearerTokenAuthProvider(profile.bearer_token, profile.expiration_time) + + @staticmethod + def __auth_basic(profile): + return BasicAuthProvider(profile.endpoint, profile.username, profile.password) diff --git a/python/delta_sharing/protocol.py b/python/delta_sharing/protocol.py index ea1566786..d8873c156 100644 --- a/python/delta_sharing/protocol.py +++ b/python/delta_sharing/protocol.py @@ -35,6 +35,7 @@ class DeltaSharingProfile: client_secret: Optional[str] = None username: Optional[str] = None password: Optional[str] = None + scope: Optional[str] = None def __post_init__(self): if self.share_credentials_version > DeltaSharingProfile.CURRENT: @@ -77,7 +78,7 @@ def from_json(json) -> "DeltaSharingProfile": ) elif share_credentials_version == 2: type = json["type"] - if type == "persistent_oauth2.0": + if type == "oauth_client_credentials": token_endpoint = json["tokenEndpoint"] if token_endpoint is not None and token_endpoint.endswith("/"): token_endpoint = token_endpoint[:-1] @@ -88,6 +89,7 @@ def from_json(json) -> "DeltaSharingProfile": token_endpoint=token_endpoint, client_id=json["clientId"], client_secret=json["clientSecret"], + scope=json.get("scope"), ) elif type == "bearer_token": return DeltaSharingProfile( @@ -107,7 +109,7 @@ def from_json(json) -> "DeltaSharingProfile": ) else: raise ValueError( - "The current release does not supports {type} type. " + f"The current release does not supports {type} type. " "Please check type.") else: raise ValueError( diff --git a/python/delta_sharing/rest_client.py b/python/delta_sharing/rest_client.py index 73fe82573..e1103239a 100644 --- a/python/delta_sharing/rest_client.py +++ b/python/delta_sharing/rest_client.py @@ -22,7 +22,6 @@ import time import logging import pprint -from datetime import datetime import requests from requests.exceptions import HTTPError, ConnectionError @@ -39,6 +38,8 @@ Table, ) +from delta_sharing._internal_auth import AuthCredentialProviderFactory + @dataclass(frozen=True) class ListSharesResponse: @@ -151,65 +152,20 @@ def __init__(self, profile: DeltaSharingProfile, num_retries=10): self._profile = profile self._num_retries = num_retries self._sleeper = lambda sleep_ms: time.sleep(sleep_ms / 1000) - self.auth_session(profile) - - def auth_session(self, profile): - self._session = requests.Session() - self.__auth_broker(profile) - if urlparse(profile.endpoint).hostname == "localhost": - self._session.verify = False - - def __auth_broker(self, profile): - if profile.share_credentials_version == 2: - if profile.type == "persistent_oauth2.0": - self.__auth_persistent_oauth2(profile) - elif profile.type == "bearer_token": - self.__auth_bearer_token(profile) - elif profile.type == "basic": - self.__auth_basic(profile) - else: - self.__auth_bearer_token(profile) - else: - self.__auth_bearer_token(profile) - - def __auth_bearer_token(self, profile): - self._session.headers.update( - { - "Authorization": f"Bearer {profile.bearer_token}", - "User-Agent": DataSharingRestClient.USER_AGENT, - } - ) - - def __auth_persistent_oauth2(self, profile): - headers = {"Content-Type": "application/x-www-form-urlencoded", - "Accept": "application/json"} - - response = requests.post(profile.token_endpoint, - data={"grant_type": "client_credentials"}, - headers=headers, - auth=(profile.client_id, - profile.client_secret),) - - bearer_token = "{}".format(response.json()["access_token"]) + self.__auth_session(profile) self._session.headers.update( { - "Authorization": f"Bearer {bearer_token}", "User-Agent": DataSharingRestClient.USER_AGENT, } ) - def __auth_basic(self, profile): - self._session.auth = (profile.username, profile.password) - - response = self._session.post(profile.endpoint, - data={"grant_type": "client_credentials"},) - - self._session.headers.update( - { - "User-Agent": DataSharingRestClient.USER_AGENT, - } - ) + def __auth_session(self, profile): + self._session = requests.Session() + self._auth_credential_provider = ( + AuthCredentialProviderFactory.create_auth_credential_provider(profile)) + if urlparse(profile.endpoint).hostname == "localhost": + self._session.verify = False def set_sharing_capabilities_header(self): delta_sharing_capabilities = ( @@ -502,6 +458,7 @@ def _request_internal( **kwargs, ): assert target.startswith("/"), "Targets should start with '/'" + self._auth_credential_provider.add_auth_header(self._session) response = request(f"{self._profile.endpoint}{target}", **kwargs) try: response.raise_for_status() @@ -541,10 +498,7 @@ def _should_retry(self, error): def _error_on_expired_token(self, error): if isinstance(error, HTTPError) and error.response.status_code == 401: try: - expiration_time = datetime.strptime( - self._profile.expiration_time, "%Y-%m-%dT%H:%M:%S.%fZ" - ) - return datetime.now() > expiration_time + self._auth_credential_provider.is_expired() except Exception: return False else: diff --git a/python/delta_sharing/tests/test_auth.py b/python/delta_sharing/tests/test_auth.py new file mode 100644 index 000000000..81dbd1ba9 --- /dev/null +++ b/python/delta_sharing/tests/test_auth.py @@ -0,0 +1,294 @@ +# +# Copyright (C) 2021 The Delta Lake Project Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from unittest.mock import MagicMock +from datetime import datetime, timedelta +from delta_sharing._internal_auth import (OAuthClient, + BasicAuthProvider, + AuthCredentialProviderFactory, + OAuthClientCredentialsAuthProvider, + OAuthClientCredentials) +from requests import Session +import requests +from delta_sharing._internal_auth import BearerTokenAuthProvider +from delta_sharing.protocol import DeltaSharingProfile + + +def test_bearer_token_auth_provider_initialization(): + token = "test-token" + expiration_time = "2021-11-12T00:12:29.0Z" + provider = BearerTokenAuthProvider(token, expiration_time) + assert provider.bearer_token == token + assert provider.expiration_time == expiration_time + + +def test_bearer_token_auth_provider_add_auth_header(): + token = "test-token" + provider = BearerTokenAuthProvider(token, None) + session = requests.Session() + provider.add_auth_header(session) + assert session.headers["Authorization"] == f"Bearer {token}" + + +def test_bearer_token_auth_provider_is_expired(): + expired_token = "expired-token" + expiration_time = (datetime.now() - timedelta(days=1)).isoformat() + provider = BearerTokenAuthProvider(expired_token, expiration_time) + assert provider.is_expired() + + valid_token = "valid-token" + expiration_time = (datetime.now() + timedelta(days=1)).isoformat() + provider = BearerTokenAuthProvider(valid_token, expiration_time) + assert not provider.is_expired() + + +def test_bearer_token_auth_provider_get_expiration_time(): + token = "test-token" + expiration_time = "2021-11-12T00:12:29.0Z" + provider = BearerTokenAuthProvider(token, expiration_time) + assert provider.get_expiration_time() == expiration_time + + provider = BearerTokenAuthProvider(token, None) + assert provider.get_expiration_time() is None + + +def test_oauth_client_credentials_auth_provider_exchange_token(): + oauth_client = MagicMock(spec=OAuthClient) + profile = MagicMock() + profile.token_endpoint = "http://example.com/token" + profile.client_id = "client-id" + profile.client_secret = "client-secret" + profile.scope = None + + provider = OAuthClientCredentialsAuthProvider(oauth_client) + mock_session = MagicMock(spec=Session) + mock_session.headers = MagicMock() + + token = OAuthClientCredentials("access-token", 3600, int(datetime.now().timestamp())) + oauth_client.client_credentials.return_value = token + + provider.add_auth_header(mock_session) + + mock_session.headers.update.assert_called_once_with( + {"Authorization": f"Bearer {token.access_token}"}) + oauth_client.client_credentials.assert_called_once() + + +def test_oauth_client_credentials_auth_provider_reuse_token(): + oauth_client = MagicMock(spec=OAuthClient) + profile = MagicMock() + profile.token_endpoint = "http://example.com/token" + profile.client_id = "client-id" + profile.client_secret = "client-secret" + profile.scope = None + + provider = OAuthClientCredentialsAuthProvider(oauth_client) + mock_session = MagicMock(spec=Session) + mock_session.headers = MagicMock() + + valid_token = OAuthClientCredentials( + "valid-token", 3600, int(datetime.now().timestamp())) + provider.current_token = valid_token + + provider.add_auth_header(mock_session) + + mock_session.headers.update.assert_called_once_with( + {"Authorization": f"Bearer {valid_token.access_token}"}) + oauth_client.client_credentials.assert_not_called() + + +def test_oauth_client_credentials_auth_provider_refresh_token(): + oauth_client = MagicMock(spec=OAuthClient) + profile = MagicMock() + profile.token_endpoint = "http://example.com/token" + profile.client_id = "client-id" + profile.client_secret = "client-secret" + profile.scope = None + + provider = OAuthClientCredentialsAuthProvider(oauth_client) + mock_session = MagicMock(spec=Session) + mock_session.headers = MagicMock() + + expired_token = OAuthClientCredentials( + "expired-token", 1, int(datetime.now().timestamp()) - 3600) + new_token = OAuthClientCredentials( + "new-token", 3600, int(datetime.now().timestamp())) + provider.current_token = expired_token + oauth_client.client_credentials.return_value = new_token + + provider.add_auth_header(mock_session) + + mock_session.headers.update.assert_called_once_with( + {"Authorization": f"Bearer {new_token.access_token}"}) + oauth_client.client_credentials.assert_called_once() + + +def test_oauth_client_credentials_auth_provider_needs_refresh(): + oauth_client = MagicMock(spec=OAuthClient) + profile = MagicMock() + profile.token_endpoint = "http://example.com/token" + profile.client_id = "client-id" + profile.client_secret = "client-secret" + profile.scope = None + + provider = OAuthClientCredentialsAuthProvider(oauth_client) + + expired_token = OAuthClientCredentials( + "expired-token", 1, int(datetime.now().timestamp()) - 3600) + assert provider.needs_refresh(expired_token) + + token_expiring_soon = OAuthClientCredentials( + "expiring-soon-token", 600 - 5, int(datetime.now().timestamp())) + assert provider.needs_refresh(token_expiring_soon) + + valid_token = OAuthClientCredentials( + "valid-token", 600 + 10, int(datetime.now().timestamp())) + assert not provider.needs_refresh(valid_token) + + +def test_oauth_client_credentials_auth_provider_is_expired(): + oauth_client = MagicMock(spec=OAuthClient) + profile = MagicMock() + profile.token_endpoint = "http://example.com/token" + profile.client_id = "client-id" + profile.client_secret = "client-secret" + profile.scope = None + + provider = OAuthClientCredentialsAuthProvider(oauth_client) + assert not provider.is_expired() + + +def test_oauth_client_credentials_auth_provider_get_expiration_time(): + oauth_client = MagicMock(spec=OAuthClient) + profile = MagicMock() + profile.token_endpoint = "http://example.com/token" + profile.client_id = "client-id" + profile.client_secret = "client-secret" + profile.scope = None + + provider = OAuthClientCredentialsAuthProvider(oauth_client) + assert provider.get_expiration_time() is None + + +def test_basic_auth_provider_initialization(): + provider = BasicAuthProvider("https://localhost", "username", "password") + assert provider.username == "username" + assert provider.password == "password" + + +def test_basic_auth_provider_add_auth_header(): + provider = BasicAuthProvider("https://localhost", "username", "password") + session = MagicMock(spec=requests.Session) + session.headers = MagicMock() + session.auth = MagicMock() + provider.add_auth_header(session) + session.post("https://localhost/delta-sharing/", data={"grant_type": "client_credentials"}) + assert session.auth == ("username", "password") + + +def test_basic_auth_provider_is_expired(): + provider = BasicAuthProvider("https://localhost", "username", "password") + assert not provider.is_expired() + + +def test_basic_auth_provider_get_expiration_time(): + provider = BasicAuthProvider("https://localhost", "username", "password") + assert provider.get_expiration_time() is None + + +def test_factory_creation(): + profile_basic = DeltaSharingProfile( + share_credentials_version=2, + type="basic", + endpoint="https://localhost/delta-sharing/", + username="username", + password="password" + ) + provider = AuthCredentialProviderFactory.create_auth_credential_provider(profile_basic) + assert isinstance(provider, BasicAuthProvider) + + profile_bearer = DeltaSharingProfile( + share_credentials_version=1, + type="bearer_token", + endpoint="https://localhost/delta-sharing/", + bearer_token="token", + expiration_time=(datetime.now() + timedelta(hours=1)).isoformat() + ) + provider = AuthCredentialProviderFactory.create_auth_credential_provider(profile_bearer) + assert isinstance(provider, BearerTokenAuthProvider) + + profile_oauth = DeltaSharingProfile( + share_credentials_version=2, + type="oauth_client_credentials", + endpoint="https://localhost/delta-sharing/", + token_endpoint="https://localhost/token", + client_id="clientId", + client_secret="clientSecret" + ) + provider = AuthCredentialProviderFactory.create_auth_credential_provider(profile_oauth) + assert isinstance(provider, OAuthClientCredentialsAuthProvider) + + +def test_oauth_auth_provider_reused(): + profile_oauth1 = DeltaSharingProfile( + share_credentials_version=2, + type="oauth_client_credentials", + endpoint="https://localhost/delta-sharing/", + token_endpoint="https://localhost/token", + client_id="clientId", + client_secret="clientSecret" + ) + provider1 = AuthCredentialProviderFactory.create_auth_credential_provider(profile_oauth1) + assert isinstance(provider1, OAuthClientCredentialsAuthProvider) + + profile_oauth2 = DeltaSharingProfile( + share_credentials_version=2, + type="oauth_client_credentials", + endpoint="https://localhost/delta-sharing/", + token_endpoint="https://localhost/token", + client_id="clientId", + client_secret="clientSecret" + ) + + provider2 = AuthCredentialProviderFactory.create_auth_credential_provider(profile_oauth2) + + assert provider1 == provider2 + + +def test_oauth_auth_provider_with_different_profiles(): + profile_oauth1 = DeltaSharingProfile( + share_credentials_version=2, + type="oauth_client_credentials", + endpoint="https://localhost/delta-sharing/", + token_endpoint="https://localhost/1/token", + client_id="clientId", + client_secret="clientSecret" + ) + provider1 = AuthCredentialProviderFactory.create_auth_credential_provider(profile_oauth1) + assert isinstance(provider1, OAuthClientCredentialsAuthProvider) + + profile_oauth2 = DeltaSharingProfile( + share_credentials_version=2, + type="oauth_client_credentials", + endpoint="https://localhost/delta-sharing/", + token_endpoint="https://localhost/2/token", + client_id="clientId", + client_secret="clientSecret" + ) + + provider2 = AuthCredentialProviderFactory.create_auth_credential_provider(profile_oauth2) + + assert provider1 != provider2 diff --git a/python/delta_sharing/tests/test_oauth_client.py b/python/delta_sharing/tests/test_oauth_client.py new file mode 100644 index 000000000..bf2316cca --- /dev/null +++ b/python/delta_sharing/tests/test_oauth_client.py @@ -0,0 +1,81 @@ +# +# Copyright (C) 2021 The Delta Lake Project Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest +import requests +from requests.models import Response +from unittest.mock import patch +from datetime import datetime +from delta_sharing._internal_auth import OAuthClient + + +class MockServer: + def __init__(self): + self.url = "http://localhost:1080/token" + self.responses = [] + + def add_response(self, status_code, json_data): + response = Response() + response.status_code = status_code + response._content = json_data.encode('utf-8') + self.responses.append(response) + + def get_response(self): + return self.responses.pop(0) + + +@pytest.fixture +def mock_server(): + server = MockServer() + yield server + + +def test_oauth_client_should_parse_token_response_correctly(mock_server): + mock_server.add_response( + 200, + '{"access_token": "test-access-token", "expires_in": 3600, "token_type": "bearer"}') + + with patch('requests.post') as mock_post: + mock_post.side_effect = lambda *args, **kwargs: mock_server.get_response() + oauth_client = OAuthClient( + token_endpoint=mock_server.url, + client_id="client-id", + client_secret="client-secret" + ) + + start = datetime.now().timestamp() + token = oauth_client.client_credentials() + end = datetime.now().timestamp() + + assert token.access_token == "test-access-token" + assert token.expires_in == 3600 + assert int(start) <= token.creation_timestamp + assert token.creation_timestamp <= int(end) + + +def test_oauth_client_should_handle_401_unauthorized_response(mock_server): + mock_server.add_response(401, 'Unauthorized') + + with patch('requests.post') as mock_post: + mock_post.side_effect = lambda *args, **kwargs: mock_server.get_response() + oauth_client = OAuthClient( + token_endpoint=mock_server.url, + client_id="client-id", + client_secret="client-secret" + ) + try: + oauth_client.client_credentials() + except requests.HTTPError as e: + assert e.response.status_code == 401 diff --git a/python/delta_sharing/tests/test_profile_oauth2.json b/python/delta_sharing/tests/test_profile_oauth2.json index 2242aa57f..3e253832a 100644 --- a/python/delta_sharing/tests/test_profile_oauth2.json +++ b/python/delta_sharing/tests/test_profile_oauth2.json @@ -1,6 +1,6 @@ { "shareCredentialsVersion": 2, - "type": "persistent_oauth2.0", + "type": "oauth_client_credentials", "endpoint": "https://localhost/delta-sharing/", "tokenEndpoint": "tokenEndpoint", "clientId": "clientId", diff --git a/python/delta_sharing/tests/test_protocol.py b/python/delta_sharing/tests/test_protocol.py index 1e980b5a2..3301eb390 100644 --- a/python/delta_sharing/tests/test_protocol.py +++ b/python/delta_sharing/tests/test_protocol.py @@ -186,11 +186,11 @@ def test_share_profile_bearer(tmp_path): DeltaSharingProfile.read_from_file(io.StringIO(json)) -def test_share_profile_oauth2(tmp_path): +def oauth_client_credentials(tmp_path): json = """ { "shareCredentialsVersion": 2, - "type": "persistent_oauth2.0", + "type": "oauth_client_credentials", "endpoint": "https://localhost/delta-sharing/", "tokenEndpoint": "tokenEndpoint", "clientId": "clientId", @@ -202,7 +202,7 @@ def test_share_profile_oauth2(tmp_path): "https://localhost/delta-sharing", None, None, - "persistent_oauth2.0", + "oauth_client_credentials", "tokenEndpoint", "clientId", "clientSecret") @@ -212,7 +212,7 @@ def test_share_profile_oauth2(tmp_path): "https://localhost/delta-sharing", None, None, - "persistent_oauth2.0", + "oauth_client_credentials", "tokenEndpoint", "clientId", "clientSecret") @@ -226,7 +226,7 @@ def test_share_profile_oauth2(tmp_path): "https://localhost/delta-sharing", None, None, - "persistent_oauth2.0", + "oauth_client_credentials", "tokenEndpoint", "clientId", "clientSecret") @@ -236,7 +236,7 @@ def test_share_profile_oauth2(tmp_path): "https://localhost/delta-sharing", None, None, - "persistent_oauth2.0", + "oauth_client_credentials", "tokenEndpoint", "clientId", "clientSecret") @@ -246,7 +246,7 @@ def test_share_profile_oauth2(tmp_path): "https://localhost/delta-sharing", None, None, - "persistent_oauth2.0", + "oauth_client_credentials", "tokenEndpoint", "clientId", "clientSecret") @@ -256,7 +256,7 @@ def test_share_profile_oauth2(tmp_path): "https://localhost/delta-sharing", None, None, - "persistent_oauth2.0", + "oauth_client_credentials", "tokenEndpoint", "clientId", "clientSecret") @@ -264,7 +264,7 @@ def test_share_profile_oauth2(tmp_path): json = """ { "shareCredentialsVersion": 100, - "type": "persistent_oauth2.0", + "type": "oauth_client_credentials", "endpoint": "https://localhost/delta-sharing/", "tokenEndpoint": "tokenEndpoint", "clientId": "clientId",