Skip to content

Commit

Permalink
Fix timing bug in Refreshable, optimize ChainCredentials
Browse files Browse the repository at this point in the history
Refreshable had a timing bug where refreshing the token could result in
an invalid state if invalidate() was called at the same time. This is
due to the Event Loop being able to schedule other coroutines to run
between the current value being set in _refresh() and the value being
checked in get(). By moving the assignment from _refresh() to get(), no
other coroutines can run between the setting and checking, preventing
the class from ending up in an invalid state.

ChainCredentials would needlessly check many candidates every time, even
if it had already done so previously. Since running code cannot move to
different environments, this check is unneccessary and inefficient, so
instead ChainCredentials now remembers when a provider has succeeded and
will only try that one again.
  • Loading branch information
ojii committed Oct 19, 2023
1 parent a2d320b commit f9a4f98
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 20 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,5 @@ repos:
language: system
types: [python]
pass_filenames: false
args: ['check', 'src', 'tests']
args: ['check', '--fix', 'src', 'tests']

6 changes: 6 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
Changelog
=========

23.10.1
-------

* Fixed issue with refreshable credentials not working due to timing issue introduced in 23.10.
* Improved performance of :py:class:`aiodynamo.credentials.ChainCredentials`

23.10
-----

Expand Down
40 changes: 26 additions & 14 deletions src/aiodynamo/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from enum import Enum, auto
from pathlib import Path
from typing import (
Awaitable,
Callable,
Coroutine,
Generic,
Optional,
Sequence,
Expand Down Expand Up @@ -99,17 +99,24 @@ class ChainCredentials(Credentials):
"""
Chains multiple credentials providers together, trying them
in order. Returns the first key found. Exceptions are suppressed.
Once a credentials provider returns a key, only that provider will
be used in subsequent calls.
"""

candidates: Sequence[Credentials]
_candidates: Sequence[Credentials]
_chosen: Credentials | None

def __init__(self, candidates: Sequence[Credentials]) -> None:
self.candidates = [
self._candidates = [
candidate for candidate in candidates if not candidate.is_disabled()
]
self._chosen = None

async def get_key(self, http: HttpImplementation) -> Optional[Key]:
for candidate in self.candidates:
if self._chosen is not None:
return await self._chosen.get_key(http)
for candidate in self._candidates:
try:
key = await candidate.get_key(http)
except:
Expand All @@ -119,14 +126,15 @@ async def get_key(self, http: HttpImplementation) -> Optional[Key]:
logger.debug("Candidate %r didn't find a key", candidate)
else:
logger.debug("Candidate %r found a key %r", candidate, key)
self._chosen = candidate
return key
return None

def invalidate(self) -> bool:
return any(candidate.invalidate() for candidate in self.candidates)
return any(candidate.invalidate() for candidate in self._candidates)

def is_disabled(self) -> bool:
return not self.candidates
return not self._candidates


class EnvironmentCredentials(Credentials):
Expand Down Expand Up @@ -239,8 +247,8 @@ class _Unset:
class Refreshable(Generic[T]):
name: str
should_refresh: Callable[[T], Refresh]
do_refresh: Callable[[HttpImplementation], Awaitable[T]]
_active_refresh_task: Optional[asyncio.Task[None]] = None
do_refresh: Callable[[HttpImplementation], Coroutine[None, None, T]]
_active_refresh_task: Optional[asyncio.Task[T]] = None
_current: Union[_Unset, T] = _UNSET

async def get(self, http: HttpImplementation) -> T:
Expand All @@ -250,16 +258,20 @@ async def get(self, http: HttpImplementation) -> T:
if self._active_refresh_task is None:
logger.debug("%s starting mandatory refresh", self.name)
self._active_refresh_task = task = asyncio.create_task(
self._refresh(http)
self.do_refresh(http)
)
task.add_done_callback(self._clear_active_refresh)
else:
logger.debug("%s re-using active refresh", self.name)
await self._active_refresh_task
self._current = await self._active_refresh_task
elif refresh is Refresh.soon:
if self._active_refresh_task is None:
logger.debug("%s starting early refresh", self.name)
self._active_refresh_task = asyncio.create_task(self._refresh(http))
self._active_refresh_task = task = asyncio.create_task(
self.do_refresh(http)
)
task.add_done_callback(self._clear_active_refresh)
task.add_done_callback(self._set_current)
else:
logger.debug("%s already refreshing", self.name)
assert not isinstance(self._current, _Unset)
Expand All @@ -273,10 +285,10 @@ def _check_refresh(self) -> Refresh:
return Refresh.required
return self.should_refresh(self._current)

async def _refresh(self, http: HttpImplementation) -> None:
self._current = await self.do_refresh(http)
def _set_current(self, task: asyncio.Task[T]) -> None:
self._current = task.result()

def _clear_active_refresh(self, _task: asyncio.Task[None]) -> None:
def _clear_active_refresh(self, _task: asyncio.Task[T]) -> None:
self._active_refresh_task = None


Expand Down
80 changes: 76 additions & 4 deletions tests/unit/test_credentials.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import datetime
from pathlib import Path
from textwrap import dedent
Expand All @@ -7,12 +8,13 @@
from _pytest.monkeypatch import MonkeyPatch
from aiohttp import web
from freezegun import freeze_time
from pyfakefs.fake_filesystem import FakeFilesystem # type: ignore[import]
from pyfakefs.fake_filesystem import FakeFilesystem # type: ignore[import-untyped]
from yarl import URL

from aiodynamo.credentials import (
EXPIRED_THRESHOLD,
EXPIRES_SOON_THRESHOLD,
ChainCredentials,
ContainerMetadataCredentials,
Credentials,
EnvironmentCredentials,
Expand All @@ -21,8 +23,10 @@
InstanceMetadataCredentialsV2,
Key,
Metadata,
Refresh,
Refreshable,
)
from aiodynamo.http.types import HttpImplementation
from aiodynamo.http.types import HttpImplementation, Request, RequestFailed, Response

pytestmark = [pytest.mark.usefixtures("fs")]

Expand Down Expand Up @@ -190,7 +194,7 @@ async def test_disabled(monkeypatch: MonkeyPatch) -> None:
monkeypatch.delenv("AWS_CONTAINER_CREDENTIALS_FULL_URI", raising=False)
monkeypatch.setenv("AWS_EC2_METADATA_DISABLED", "true")
creds = Credentials.auto()
assert creds.candidates == []
assert creds._candidates == []
assert creds.is_disabled()
assert EnvironmentCredentials().is_disabled()
assert InstanceMetadataCredentialsV2().is_disabled()
Expand All @@ -200,7 +204,7 @@ async def test_disabled(monkeypatch: MonkeyPatch) -> None:
assert not InstanceMetadataCredentialsV2().is_disabled()
assert not InstanceMetadataCredentialsV1().is_disabled()
creds = Credentials.auto()
assert len(creds.candidates) == 2
assert len(creds._candidates) == 2
assert not creds.is_disabled()


Expand Down Expand Up @@ -256,3 +260,71 @@ async def test_file_credentials(
assert await credentials.get_key(http) == Key(
id="custom-baz", secret="custom-hoge", token="custom-token"
)


async def null_http(request: Request) -> Response:
raise RequestFailed(Exception())


async def test_chain_credential_memory() -> None:
class BadLoader(Credentials):
def __init__(self) -> None:
self.called = 0

async def get_key(self, http: HttpImplementation) -> Optional[Key]:
self.called += 1
return None

def invalidate(self) -> bool:
return False

def is_disabled(self) -> bool:
return False

key = Key("id", "secret")

class GoodLoader(Credentials):
def __init__(self) -> None:
self.called = 0

async def get_key(self, http: HttpImplementation) -> Optional[Key]:
self.called += 1
return key

def invalidate(self) -> bool:
return False

def is_disabled(self) -> bool:
return False

bl = BadLoader()
gl = GoodLoader()
chain = ChainCredentials([bl, gl])
assert chain._chosen is None
assert await chain.get_key(null_http) is key
assert bl.called == 1
assert gl.called == 1
assert await chain.get_key(null_http) is key
assert bl.called == 1
assert gl.called == 2
assert chain._chosen is gl


async def test_refreshable_background() -> None:
ev = asyncio.Event()

async def refresher(http: HttpImplementation) -> int:
await ev.wait()
return 1

refreshable = Refreshable(
"test_refreshable_background", lambda _: Refresh.soon, refresher
)
refreshable._current = 0
assert refreshable._active_refresh_task is None
assert await refreshable.get(null_http) == 0
assert refreshable._active_refresh_task is not None
ev.set()
await refreshable._active_refresh_task
assert refreshable._active_refresh_task is None
assert refreshable._current == 1
2 changes: 1 addition & 1 deletion tests/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any, Callable, Dict

import pytest
from boto3.dynamodb.types import ( # type: ignore[import]
from boto3.dynamodb.types import ( # type: ignore[import-untyped]
DYNAMODB_CONTEXT,
TypeDeserializer,
)
Expand Down

0 comments on commit f9a4f98

Please sign in to comment.