diff --git a/docs/usage.rst b/docs/usage.rst index b538ba2..4633d1d 100644 --- a/docs/usage.rst +++ b/docs/usage.rst @@ -60,6 +60,10 @@ In case you want to explicitly pass the credentials from Python, use :py:class:` | +.. autoclass:: aiodynamo.credentials.ProcessCredentials + +| + .. autoclass:: aiodynamo.credentials.Key :members: :undoc-members: diff --git a/src/aiodynamo/credentials.py b/src/aiodynamo/credentials.py index 2977328..f0a940a 100644 --- a/src/aiodynamo/credentials.py +++ b/src/aiodynamo/credentials.py @@ -572,6 +572,68 @@ async def fetch_metadata(self, http: HttpImplementation) -> Metadata: ) +@dataclass +class ProcessCredentialsError(Exception): + reason: str + return_code: int | None + stdout: bytes + stderr: bytes + + +@dataclass +class ProcessCredentials(MetadataCredentials): + command: list[str] + + async def fetch_metadata(self, http: HttpImplementation) -> Metadata: + process = await asyncio.create_subprocess_exec( + *self.command, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + stdin=asyncio.subprocess.PIPE, + ) + stdout, stderr = await process.communicate(b"") + if process.returncode != 0: + raise ProcessCredentialsError( + f"Process terminated with non-zero return code {process.returncode}: {try_decode(stderr)}", + process.returncode, + stdout, + stderr, + ) + try: + data = json.loads(stdout) + except json.JSONDecodeError: + raise ProcessCredentialsError( + f"Process returned non-JSON string: {try_decode(stdout)}", + process.returncode, + stdout, + stderr, + ) + if data["version"] != 1: + raise ProcessCredentialsError( + f"Process returned unsupported version: {data['version']}", + process.returncode, + stdout, + stderr, + ) + + key = Key( + data["access_key_id"], + data["secret_access_key"], + data.get("session_token", None), + ) + return Metadata(key, parse_amazon_timestamp(data["expiration"])) + + def is_disabled(self) -> bool: + return False + + +def try_decode(data: bytes) -> str: + try: + return data.decode() + except: + return repr(data) + + class TooManyRetries(Exception): pass diff --git a/tests/unit/process_credentials.py b/tests/unit/process_credentials.py new file mode 100644 index 0000000..b2a4cf6 --- /dev/null +++ b/tests/unit/process_credentials.py @@ -0,0 +1,12 @@ +import sys + + +def main() -> None: + if sys.argv[1] == "-x": + sys.exit(-1) + else: + sys.stdout.write(sys.argv[1]) + + +if __name__ == "__main__": + main() diff --git a/tests/unit/test_credentials.py b/tests/unit/test_credentials.py index 7615446..b06dc2f 100644 --- a/tests/unit/test_credentials.py +++ b/tests/unit/test_credentials.py @@ -1,8 +1,12 @@ import asyncio +import contextlib import datetime +import inspect +import json +import sys from pathlib import Path from textwrap import dedent -from typing import AsyncGenerator, Optional, Type, Union +from typing import Any, AsyncGenerator, ContextManager, Optional, Type, Union import pytest from _pytest.monkeypatch import MonkeyPatch @@ -23,13 +27,13 @@ InstanceMetadataCredentialsV2, Key, Metadata, + ProcessCredentials, + ProcessCredentialsError, Refresh, Refreshable, ) from aiodynamo.http.types import HttpImplementation, Request, RequestFailed, Response -pytestmark = [pytest.mark.usefixtures("fs")] - class InstanceMetadataServer: def __init__(self) -> None: @@ -328,3 +332,60 @@ async def refresher(http: HttpImplementation) -> int: await refreshable._active_refresh_task assert refreshable._active_refresh_task is None assert refreshable._current == 1 + + +@pytest.mark.parametrize( + "data,result", + [ + pytest.param( + json.dumps( + { + "version": 1, + "access_key_id": "foo", + "secret_access_key": "bar", + "session_token": "baz", + "expiration": ( + datetime.datetime.now() + datetime.timedelta(days=2) + ).strftime("%Y-%m-%dT%H:%M:%SZ"), + } + ), + Key("foo", "bar", "baz"), + id="good", + ), + pytest.param( + json.dumps( + { + "version": 2, + "access_key_id": "foo", + "secret_access_key": "bar", + "session_token": "baz", + "expiration": ( + datetime.datetime.now() + datetime.timedelta(days=2) + ).strftime("%Y-%m-%dT%H:%M:%SZ"), + } + ), + ProcessCredentialsError, + id="bad-version", + ), + pytest.param(json.dumps({}), KeyError, id="bad-json"), + pytest.param("this is not json", ProcessCredentialsError, id="not-json"), + pytest.param("-x", ProcessCredentialsError, id="bad-return-code"), + ], +) +async def test_process_credentials(data: str, result: Any) -> None: + creds = ProcessCredentials( + [ + sys.executable, + str(Path(__file__).parent.joinpath("process_credentials.py")), + data, + ] + ) + with magic(result): + assert await creds.get_key(null_http) == result + + +def magic(result: Any) -> ContextManager[Any]: + if inspect.isclass(result) and issubclass(result, Exception): + return pytest.raises(result) + else: + return contextlib.nullcontext()