Skip to content

Commit

Permalink
Provide a FileCredentialsProvider for assume role (#296)
Browse files Browse the repository at this point in the history
  • Loading branch information
yanghua authored Nov 25, 2024
1 parent 38a236c commit e390679
Show file tree
Hide file tree
Showing 3 changed files with 376 additions and 1 deletion.
184 changes: 184 additions & 0 deletions tosfs/certification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
# ByteDance Volcengine EMR, Copyright 2024.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""It contains everything about certification via a file-based provider."""

import threading
from datetime import datetime
from typing import Optional
from xml.etree import ElementTree

from tos.credential import Credentials, CredentialsProvider

from tosfs.core import logger
from tosfs.exceptions import TosfsCertificationError

CERTIFICATION_REFRESH_INTERVAL_MINUTES = 60
CERTIFICATION_MAX_VALID_PERIOD_HOURS = 12


class FileCredentialsProvider(CredentialsProvider):
"""The class provides the credentials from a file.
The file should be in the format of:
<configuration>
<property>
<name>fs.tos.access-key-id</name>
<value>access_key</value>
</property>
<property>
<name>fs.tos.secret-access-key</name>
<value>secret_key</value>
</property>
<property>
<name>fs.tos.session-token</name>
<value>session_token</value>
</property>
</configuration>
It can only receive a file path which exists in the local file system.
Note :
1. The default refresh interval is 60 minutes.
2. The maximum valid period for provisional certification is 12 hours.
This provider will cache the credentials and refresh them every 60 minutes.
And note that, it only reads the credentials from the file and refreshes itself,
to guarantee the credentials are always up-to-date,
we need a service update it internally.
Examples
--------
>>> from tosfs import TosFileSystem
>>> tosfs = TosFileSystem(
>>> endpoint="tos-cn-beijing.volcengine.com",
>>> regions="cn-bejing",
>>> credentials_provider=FileCredentialsProvider("dummy_path"))
>>> tosfs.ls("tos://bucket/path")
"""

def __init__(
self,
file_path: str,
refresh_interval_min: int = CERTIFICATION_REFRESH_INTERVAL_MINUTES,
) -> None:
"""Initialize the FileCredentialsProvider."""
self.file_path = file_path
self.refresh_interval = refresh_interval_min
if self.refresh_interval <= 0:
logger.warning(
f"Invalid refresh interval {self.refresh_interval}, "
f"set to default value: 60 minutes"
)
self.refresh_interval = CERTIFICATION_REFRESH_INTERVAL_MINUTES
if self.refresh_interval > CERTIFICATION_MAX_VALID_PERIOD_HOURS * 60:
logger.warning(
f"Invalid refresh interval {self.refresh_interval}, "
f"set to maximum value: 60 minutes"
)
self.refresh_interval = CERTIFICATION_REFRESH_INTERVAL_MINUTES
self.prev_refresh_time: Optional[datetime] = None
self.credentials = None
self._lock = threading.Lock()

def get_credentials(self) -> Credentials:
"""Get the credentials from the file.
Returns
-------
Credentials: The credentials object.
Raises
------
TosfsCertificationError: If the credentials cannot be retrieved.
"""
res = self._try_get_credentials()
if res is not None:
return res
with self._lock:
try:
res = self._try_get_credentials()
if res is not None:
return res

with open(self.file_path, "r") as f:
logger.debug(
f"Trying to refresh the credentials from file: "
f"{self.file_path}"
)
tree = ElementTree.parse(f) # noqa S314
root = tree.getroot()

access_key_element = root.find(
".//property[name='fs.tos.access-key-id']/value"
)
secret_key_element = root.find(
".//property[name='fs.tos.secret-access-key']/value"
)
session_token_element = root.find(
".//property[name='fs.tos.session-token']/value"
)

access_key = (
access_key_element.text
if access_key_element is not None
else None
)
secret_key = (
secret_key_element.text
if secret_key_element is not None
else None
)
session_token = (
session_token_element.text
if session_token_element is not None
else None
)

if None in (
access_key,
secret_key,
session_token,
):
raise TosfsCertificationError(
"Missing required credential elements in the file"
)

self.prev_refresh_time = datetime.now()
self.credentials = Credentials(
access_key, secret_key, session_token
)
logger.debug(
f"Successfully refreshed the credentials from file: "
f"{self.file_path}"
)

return self.credentials
except Exception as e:
raise TosfsCertificationError("Get certification error: ") from e

def _try_get_credentials(self) -> Optional[Credentials]:
if self.prev_refresh_time is None or self.credentials is None:
return None
if (
datetime.now() - self.prev_refresh_time
).total_seconds() / 60 > CERTIFICATION_REFRESH_INTERVAL_MINUTES:
logger.debug(
f"Credentials are expired, "
f"will try to refresh the credentials from file: "
f"{self.file_path}"
)
return None
return self.credentials
11 changes: 10 additions & 1 deletion tosfs/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""iT contains exceptions definition for the tosfs package."""

"""It contains exceptions definition for the tosfs package."""


class TosfsError(Exception):
Expand All @@ -20,3 +21,11 @@ class TosfsError(Exception):
def __init__(self, message: str):
"""Initialize the base class for all exceptions in the tosfs package."""
super().__init__(message)


class TosfsCertificationError(TosfsError):
"""Exception class for certification related exception."""

def __init__(self, message: str):
"""Initialize the exception class for certification related exception."""
super().__init__(message)
182 changes: 182 additions & 0 deletions tosfs/tests/test_certification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
# ByteDance Volcengine EMR, Copyright 2024.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from datetime import datetime, timedelta
from unittest.mock import mock_open, patch

import pytest

from tosfs.certification import FileCredentialsProvider
from tosfs.exceptions import TosfsCertificationError


@patch(
"builtins.open",
new_callable=mock_open,
read_data="""
<configuration>
<property>
<name>fs.tos.access-key-id</name>
<value>access_key</value>
</property>
<property>
<name>fs.tos.secret-access-key</name>
<value>secret_key</value>
</property>
<property>
<name>fs.tos.session-token</name>
<value>session_token</value>
</property>
</configuration>
""",
)
def test_get_credentials(mock_file):
provider = FileCredentialsProvider("dummy_path")
credentials = provider.get_credentials()
assert credentials.access_key_id == "access_key"
assert credentials.access_key_secret == "secret_key" # noqa S105
assert credentials.security_token == "session_token" # noqa S105


@patch(
"builtins.open",
new_callable=mock_open,
read_data="""
<configuration>
<property>
<name>fs.tos.access-key-id</name>
<value>access_key</value>
</property>
<property>
<name>fs.tos.secret-access-key</name>
<value>secret_key</value>
</property>
<property>
<name>fs.tos.session-token</name>
<value>session_token</value>
</property>
</configuration>
""",
)
def test_get_credentials_with_refresh(mock_file):
provider = FileCredentialsProvider("dummy_path")
provider.prev_refresh_time = datetime.now() - timedelta(minutes=61)
credentials = provider.get_credentials()
assert credentials.access_key_id == "access_key"
assert credentials.access_key_secret == "secret_key" # noqa S105
assert credentials.security_token == "session_token" # noqa S105


@patch(
"builtins.open",
new_callable=mock_open,
read_data="""
<configuration>
<property>
<name>fs.tos.access-key-id</name>
<value>access_key</value>
</property>
<property>
<name>fs.tos.secret-access-key</name>
<value>secret_key</value>
</property>
<property>
<name>fs.tos.session-token</name>
<value>session_token</value>
</property>
</configuration>
""",
)
def test_get_credentials_error(mock_file):
provider = FileCredentialsProvider("dummy_path")
with patch("xml.etree.ElementTree.parse", side_effect=Exception("Parse error")):
with pytest.raises(TosfsCertificationError):
provider.get_credentials()


@patch(
"builtins.open",
new_callable=mock_open,
read_data="""
<configuration>
<property>
<key>fs.tos.access-key-id</name>
<value>access_key</value>
</property>
<property>
<key>fs.tos.secret-access-key</name>
<value>secret_key</value>
</property>
<property>
<key>fs.tos.session-token</name>
<value>session_token</value>
</property>
</configuration>
""",
)
def test_wrong_file_format_error(mock_file):
provider = FileCredentialsProvider("dummy_path")
provider.prev_refresh_time = datetime.now() - timedelta(minutes=61)
with pytest.raises(TosfsCertificationError):
provider.get_credentials()


@patch(
"builtins.open",
new_callable=mock_open,
read_data="""
<configuration>
<property>
<name>fs.tos.access-key-id</name>
<value>access_key</value>
</property>
<property>
<name>fs.tos.secret-access-key</name>
<value>secret_key</value>
</property>
<property>
<name>fs.tos.session-token</name>
<value>session_token</value>
</property>
</configuration>
""",
)
def test_no_refresh_within_interval(mock_file):
provider = FileCredentialsProvider("dummy_path")
provider.get_credentials()

provider.prev_refresh_time = datetime.now() - timedelta(minutes=30)

mock_file().read_data = """
<configuration>
<property>
<name>fs.tos.access-key-id</name>
<value>new_access_key</value>
</property>
<property>
<name>fs.tos.secret-access-key</name>
<value>new_secret_key</value>
</property>
<property>
<name>fs.tos.session-token</name>
<value>new_session_token</value>
</property>
</configuration>
"""

new_credentials = provider.get_credentials()

assert new_credentials.access_key_id == "access_key"
assert new_credentials.access_key_secret == "secret_key" # noqa S105
assert new_credentials.security_token == "session_token" # noqa S105

0 comments on commit e390679

Please sign in to comment.