diff --git a/tosfs/certification.py b/tosfs/certification.py index f3af727..922bb6c 100644 --- a/tosfs/certification.py +++ b/tosfs/certification.py @@ -14,10 +14,12 @@ """It contains everything about certification via a file-based provider.""" import threading -from datetime import datetime +from datetime import datetime, timedelta from typing import Optional from xml.etree import ElementTree +import requests +from tos.consts import ECS_DATE_FORMAT from tos.credential import Credentials, CredentialsProvider from tosfs.core import logger @@ -182,3 +184,55 @@ def _try_get_credentials(self) -> Optional[Credentials]: ) return None return self.credentials + + +class UrlCredentialsProvider(CredentialsProvider): + """The class provides the credentials from an url.""" + + def __init__(self, credential_url: str): + """Initialize the UrlCredentialsProvider.""" + if not credential_url: + raise TosfsCertificationError("The credential_url param must not be empty.") + self._lock = threading.Lock() + self.expires: Optional[datetime] = None + self.credentials = None + self.credential_url = credential_url + + def get_credentials(self) -> Credentials: + """Get the credentials from the url.""" + res = self._try_get_credentials() + if res is not None: + return res + with self._lock: + try: + res = self._try_get_credentials() + if res is not None: + return res + + res = requests.get(self.credential_url, timeout=30) + res_body = res.json() + self.credentials = Credentials( + res_body.get("AccessKeyId"), + res_body.get("SecretAccessKey"), + res_body.get("SessionToken"), + ) + self.expires = datetime.strptime( + res_body.get("ExpiredTime"), ECS_DATE_FORMAT + ) + return self.credentials + except Exception as e: + if self.expires is not None and ( + datetime.now().timestamp() < self.expires.timestamp() + ): + return self.credentials + raise TosfsCertificationError("Get token failed") from e + + def _try_get_credentials(self) -> Optional[Credentials]: + if self.expires is None or self.credentials is None: + return None + if ( + datetime.now().timestamp() + > (self.expires - timedelta(minutes=10)).timestamp() + ): + return None + return self.credentials diff --git a/tosfs/exceptions.py b/tosfs/exceptions.py index f89dc7c..321e483 100644 --- a/tosfs/exceptions.py +++ b/tosfs/exceptions.py @@ -13,19 +13,27 @@ # limitations under the License. """It contains exceptions definition for the tosfs package.""" +from typing import Optional class TosfsError(Exception): """Base class for all tosfs exceptions.""" - def __init__(self, message: str): + def __init__(self, msg: str, cause: Optional[Exception] = None): """Initialize the base class for all exceptions in the tosfs package.""" - super().__init__(message) + super().__init__(msg, cause) + self.message = msg + self.cause = cause + + def __str__(self) -> str: + """Return the string representation of the exception.""" + error = {"message": self.message, "case": str(self.cause)} + return str(error) class TosfsCertificationError(TosfsError): """Exception class for certification related exception.""" - def __init__(self, message: str): + def __init__(self, message: str, cause: Optional[Exception] = None): """Initialize the exception class for certification related exception.""" - super().__init__(message) + super().__init__(message, cause)