Skip to content

Commit

Permalink
Provide a UrlCredentialsProvider for assume role (#317)
Browse files Browse the repository at this point in the history
  • Loading branch information
yanghua authored Dec 2, 2024
1 parent 8ccf49c commit 184df95
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 5 deletions.
56 changes: 55 additions & 1 deletion tosfs/certification.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
"""It contains everything about certification via a file-based provider."""

import threading
from datetime import datetime
from datetime import datetime, timedelta
from typing import Optional
from xml.etree import ElementTree

import requests
from tos.consts import ECS_DATE_FORMAT
from tos.credential import Credentials, CredentialsProvider

from tosfs.core import logger
Expand Down Expand Up @@ -182,3 +184,55 @@ def _try_get_credentials(self) -> Optional[Credentials]:
)
return None
return self.credentials


class UrlCredentialsProvider(CredentialsProvider):
"""The class provides the credentials from an url."""

def __init__(self, credential_url: str):
"""Initialize the UrlCredentialsProvider."""
if not credential_url:
raise TosfsCertificationError("The credential_url param must not be empty.")
self._lock = threading.Lock()
self.expires: Optional[datetime] = None
self.credentials = None
self.credential_url = credential_url

def get_credentials(self) -> Credentials:
"""Get the credentials from the url."""
res = self._try_get_credentials()
if res is not None:
return res
with self._lock:
try:
res = self._try_get_credentials()
if res is not None:
return res

res = requests.get(self.credential_url, timeout=30)
res_body = res.json()
self.credentials = Credentials(
res_body.get("AccessKeyId"),
res_body.get("SecretAccessKey"),
res_body.get("SessionToken"),
)
self.expires = datetime.strptime(
res_body.get("ExpiredTime"), ECS_DATE_FORMAT
)
return self.credentials
except Exception as e:
if self.expires is not None and (
datetime.now().timestamp() < self.expires.timestamp()
):
return self.credentials
raise TosfsCertificationError("Get token failed") from e

def _try_get_credentials(self) -> Optional[Credentials]:
if self.expires is None or self.credentials is None:
return None
if (
datetime.now().timestamp()
> (self.expires - timedelta(minutes=10)).timestamp()
):
return None
return self.credentials
16 changes: 12 additions & 4 deletions tosfs/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,27 @@
# limitations under the License.

"""It contains exceptions definition for the tosfs package."""
from typing import Optional


class TosfsError(Exception):
"""Base class for all tosfs exceptions."""

def __init__(self, message: str):
def __init__(self, msg: str, cause: Optional[Exception] = None):
"""Initialize the base class for all exceptions in the tosfs package."""
super().__init__(message)
super().__init__(msg, cause)
self.message = msg
self.cause = cause

def __str__(self) -> str:
"""Return the string representation of the exception."""
error = {"message": self.message, "case": str(self.cause)}
return str(error)


class TosfsCertificationError(TosfsError):
"""Exception class for certification related exception."""

def __init__(self, message: str):
def __init__(self, message: str, cause: Optional[Exception] = None):
"""Initialize the exception class for certification related exception."""
super().__init__(message)
super().__init__(message, cause)

0 comments on commit 184df95

Please sign in to comment.