Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python DeltaSharing D2O Secret-less managed identity support for Azure Compute VM #574

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 107 additions & 1 deletion python/delta_sharing/_internal_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,94 @@ def __init__(self, access_token: str, expires_in: int, creation_timestamp: int):
self.creation_timestamp = creation_timestamp


class ManagedIdentityAuthProvider(AuthCredentialProvider):
    """Authenticate requests with a token from the Azure Instance Metadata Service.

    The provider fetches an OAuth access token from the local IMDS endpoint,
    caches it, and re-fetches it once its remaining lifetime drops below
    ``auth_config.token_renewal_threshold_in_seconds``.
    """

    # IMDS request settings, kept as class attributes so subclasses and tests
    # can override them without touching the request logic.
    IMDS_TOKEN_URL = "http://169.254.169.254/metadata/identity/oauth2/token"
    IMDS_API_VERSION = "2019-08-01"
    IMDS_RESOURCE = "https://management.azure.com/"
    # Upper bound (seconds) on the IMDS call; without it requests waits forever.
    IMDS_TIMEOUT_IN_SECONDS = 10

    def __init__(self, auth_config: Optional[AuthConfig] = None):
        """Create a provider.

        Args:
            auth_config: Renewal settings. When omitted, a fresh ``AuthConfig()``
                is created per provider instead of sharing one default instance
                across every provider.
        """
        self.auth_config = auth_config if auth_config is not None else AuthConfig()
        self.current_token: Optional[OAuthClientCredentials] = None
        # Re-entrant lock guarding ``current_token`` and the header update.
        self.lock = threading.RLock()

    def get_managed_identity_token(self) -> OAuthClientCredentials:
        """Request a new access token from the IMDS endpoint.

        Returns:
            The parsed token credentials.

        Raises:
            RuntimeError: If IMDS answers with a non-200 status or the payload
                is malformed.
        """
        response = requests.get(
            self.IMDS_TOKEN_URL,
            # IMDS rejects requests that lack the ``Metadata: true`` header.
            headers={"Metadata": "true"},
            params={
                "api-version": self.IMDS_API_VERSION,
                "resource": self.IMDS_RESOURCE,
            },
            timeout=self.IMDS_TIMEOUT_IN_SECONDS,
        )

        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to obtain token: {response.status_code} - {response.text}"
            )
        return self.parse_oauth_token_response(response.text)

    def parse_oauth_token_response(self, response: str) -> OAuthClientCredentials:
        """Parse the JSON body returned by the token endpoint.

        Args:
            response: Raw response body.

        Returns:
            Credentials stamped with the current epoch time in seconds.

        Raises:
            RuntimeError: If the body is empty, ``access_token`` is missing or
                not a string, or ``expires_in`` is missing or not an integer
                (given either as a JSON number or a numeric string).
        """
        if not response:
            raise RuntimeError("Empty response from OAuth token endpoint")
        json_node = json.loads(response)
        if not isinstance(json_node, dict) or not isinstance(
            json_node.get('access_token'), str
        ):
            raise RuntimeError("Missing 'access_token' field in OAuth token response")
        if 'expires_in' not in json_node:
            raise RuntimeError("Missing 'expires_in' field in OAuth token response")

        raw_expires_in = json_node['expires_in']
        # bool is a subclass of int, so it has to be rejected explicitly.
        if isinstance(raw_expires_in, bool) or not isinstance(raw_expires_in, (int, str)):
            raise RuntimeError("Invalid 'expires_in' field in OAuth token response")
        try:
            expires_in = int(raw_expires_in)
        except ValueError as err:
            raise RuntimeError(
                "Invalid 'expires_in' field in OAuth token response"
            ) from err

        # Same clock as needs_refresh(), so the two timestamps are comparable.
        return OAuthClientCredentials(
            json_node['access_token'],
            expires_in,
            int(time.time())
        )

    def add_auth_header(self, session: requests.Session) -> None:
        """Set the ``Authorization`` bearer header on ``session``.

        The token is never printed or logged: it is a credential.
        """
        token = self.maybe_refresh_token()

        with self.lock:
            session.headers.update(
                {
                    "Authorization": f"Bearer {token.access_token}",
                }
            )

    def maybe_refresh_token(self) -> OAuthClientCredentials:
        """Return the cached token, fetching a new one when it is missing or stale."""
        with self.lock:
            if self.current_token and not self.needs_refresh(self.current_token):
                return self.current_token
            new_token = self.get_managed_identity_token()
            self.current_token = new_token
            return new_token

    def needs_refresh(self, token: OAuthClientCredentials) -> bool:
        """Return True when ``token`` expires within the renewal threshold."""
        now = int(time.time())
        expiration_time = token.creation_timestamp + token.expires_in
        return expiration_time - now < self.auth_config.token_renewal_threshold_in_seconds

    def is_expired(self) -> bool:
        """Always return False; stale tokens are replaced by maybe_refresh_token()."""
        return False

    def get_expiration_time(self) -> Optional[str]:
        """Always return None; this provider reports no expiration time."""
        return None



class OAuthClient:
def __init__(self,
token_endpoint: str,
Expand Down Expand Up @@ -186,8 +274,10 @@ class AuthCredentialProviderFactory:
@staticmethod
def create_auth_credential_provider(profile: DeltaSharingProfile):
if profile.share_credentials_version == 2:
if profile.type == "oauth_client_credentials":
if profile.type == "oauth_client_credentials" or profile.type == "oidc_client_credentials":
return AuthCredentialProviderFactory.__oauth_client_credentials(profile)
elif profile.type == "oidc_managed_identity":
return AuthCredentialProviderFactory.__oidc_managed_identity(profile)
elif profile.type == "basic":
return AuthCredentialProviderFactory.__auth_basic(profile)
elif (profile.share_credentials_version == 1 and
Expand Down Expand Up @@ -224,6 +314,22 @@ def __oauth_client_credentials(profile):
AuthCredentialProviderFactory.__oauth_auth_provider_cache[profile] = provider
return provider

@staticmethod
def __oidc_managed_identity(profile):
    """Return the managed-identity auth provider associated with ``profile``.

    Providers are memoized per profile in the factory's shared
    ``__oauth_auth_provider_cache``, so repeated calls with the same profile
    get the same ``ManagedIdentityAuthProvider`` and can reuse the access
    token it holds instead of fetching a new one each time.
    """
    providers = AuthCredentialProviderFactory.__oauth_auth_provider_cache
    if profile not in providers:
        # First request for this profile: build the provider and remember it.
        providers[profile] = ManagedIdentityAuthProvider()
    return providers[profile]

@staticmethod
def __auth_bearer_token(profile):
    """Build a ``BearerTokenAuthProvider`` from the profile's token and expiration time."""
    token = profile.bearer_token
    expiry = profile.expiration_time
    return BearerTokenAuthProvider(token, expiry)
Expand Down
6 changes: 6 additions & 0 deletions python/delta_sharing/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,12 @@ def from_json(json) -> "DeltaSharingProfile":
bearer_token=json["bearerToken"],
expiration_time=json.get("expirationTime")
)
elif type == "oidc_managed_identity":
return DeltaSharingProfile(
share_credentials_version=share_credentials_version,
type=type,
endpoint=endpoint
)
elif type == "basic":
return DeltaSharingProfile(
share_credentials_version=share_credentials_version,
Expand Down
1 change: 1 addition & 0 deletions python/delta_sharing/rest_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ def __init__(self, profile: DeltaSharingProfile, num_retries=10):
self._session.headers.update(
{
"User-Agent": DataSharingRestClient.USER_AGENT,
"Custom-Header-Recipient-Id": "7ccbb5da-b1b1-4519-ae53-190db7988199"
}
)

Expand Down
Loading