Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use ParamSpec for wrapped signatures #508

Open
wants to merge 22 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions .github/workflows/ci-cd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@ on:
- '[0-9].[0-9]+' # matches to backport branches, e.g. 3.6
tags: [ 'v*' ]
pull_request:
branches:
- master
- '[0-9].[0-9]+'


jobs:
Expand Down
24 changes: 24 additions & 0 deletions .mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
[mypy]
files = async_lru, tests
check_untyped_defs = True
follow_imports_for_stubs = True
disallow_any_decorated = True
asvetlov marked this conversation as resolved.
Show resolved Hide resolved
disallow_any_generics = True
disallow_any_unimported = True
disallow_incomplete_defs = True
disallow_subclassing_any = True
disallow_untyped_calls = True
disallow_untyped_decorators = True
disallow_untyped_defs = True
enable_error_code = ignore-without-code, possibly-undefined, redundant-expr, redundant-self, truthy-bool, truthy-iterable, unused-awaitable
implicit_reexport = False
no_implicit_optional = True
pretty = True
show_column_numbers = True
show_error_codes = True
strict_equality = True
warn_incomplete_stub = True
warn_redundant_casts = True
warn_return_any = True
warn_unreachable = True
warn_unused_ignores = True
70 changes: 44 additions & 26 deletions async_lru/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from asyncio.coroutines import _is_coroutine # type: ignore[attr-defined]
from functools import _CacheInfo, _make_key, partial, partialmethod
from typing import (
Any,
Callable,
Coroutine,
Generic,
Expand All @@ -16,12 +15,17 @@
TypedDict,
TypeVar,
Union,
cast,
final,
overload,
)


if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec


if sys.version_info >= (3, 11):
from typing import Self
else:
Expand All @@ -35,9 +39,7 @@

_T = TypeVar("_T")
_R = TypeVar("_R")
_Coro = Coroutine[Any, Any, _R]
_CB = Callable[..., _Coro[_R]]
_CBP = Union[_CB[_R], "partial[_Coro[_R]]", "partialmethod[_Coro[_R]]"]
_P = ParamSpec("_P")


@final
Expand All @@ -61,10 +63,10 @@ def cancel(self) -> None:


@final
class _LRUCacheWrapper(Generic[_R]):
class _LRUCacheWrapper(Generic[_P, _R]):
def __init__(
self,
fn: _CB[_R],
fn: Callable[_P, Coroutine[object, object, _R]],
maxsize: Optional[int],
typed: bool,
ttl: Optional[float],
Expand Down Expand Up @@ -106,7 +108,7 @@ def __init__(
self.__misses = 0
self.__tasks: Set["asyncio.Task[_R]"] = set()

def cache_invalidate(self, /, *args: Hashable, **kwargs: Any) -> bool:
def cache_invalidate(self, /, *args: _P.args, **kwargs: _P.kwargs) -> bool:
key = _make_key(args, kwargs, self.__typed)

cache_item = self.__cache.pop(key, None)
Expand Down Expand Up @@ -188,7 +190,7 @@ def _task_done_callback(

fut.set_result(task.result())

async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R:
async def __call__(self, /, *fn_args: _P.args, **fn_kwargs: _P.kwargs) -> _R:
if self.__closed:
raise RuntimeError(f"alru_cache is closed for {self}")

Expand All @@ -207,7 +209,7 @@ async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R:

fut = loop.create_future()
coro = self.__wrapped__(*fn_args, **fn_kwargs)
task: asyncio.Task[_R] = loop.create_task(coro)
task = loop.create_task(coro)
self.__tasks.add(task)
task.add_done_callback(partial(self._task_done_callback, fut, key))

Expand All @@ -220,20 +222,30 @@ async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R:
self._cache_miss(key)
return await asyncio.shield(fut)

@overload
def __get__(self, instance: _T, owner: None) -> Self:
...

@overload
def __get__(
self, instance: _T, owner: Type[_T]
) -> "_LRUCacheWrapperInstanceMethod[_P, _R, _T]":
...

def __get__(
self, instance: _T, owner: Optional[Type[_T]]
) -> Union[Self, "_LRUCacheWrapperInstanceMethod[_R, _T]"]:
) -> Union[Self, "_LRUCacheWrapperInstanceMethod[_P, _R, _T]"]:
if owner is None:
return self
else:
return _LRUCacheWrapperInstanceMethod(self, instance)


@final
class _LRUCacheWrapperInstanceMethod(Generic[_R, _T]):
class _LRUCacheWrapperInstanceMethod(Generic[_P, _R, _T]):
def __init__(
self,
wrapper: _LRUCacheWrapper[_R],
wrapper: _LRUCacheWrapper[_P, _R],
instance: _T,
) -> None:
try:
Expand Down Expand Up @@ -267,7 +279,7 @@ def __init__(
self.__instance = instance
self.__wrapper = wrapper

def cache_invalidate(self, /, *args: Hashable, **kwargs: Any) -> bool:
def cache_invalidate(self, /, *args: _P.args, **kwargs: _P.kwargs) -> bool:
return self.__wrapper.cache_invalidate(self.__instance, *args, **kwargs)

def cache_clear(self) -> None:
Expand All @@ -284,16 +296,18 @@ def cache_info(self) -> _CacheInfo:
def cache_parameters(self) -> _CacheParameters:
return self.__wrapper.cache_parameters()

async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R:
return await self.__wrapper(self.__instance, *fn_args, **fn_kwargs)
async def __call__(self, /, *fn_args: _P.args, **fn_kwargs: _P.kwargs) -> _R:
return await self.__wrapper(self.__instance, *fn_args, **fn_kwargs) # type: ignore[arg-type]


def _make_wrapper(
maxsize: Optional[int],
typed: bool,
ttl: Optional[float] = None,
) -> Callable[[_CBP[_R]], _LRUCacheWrapper[_R]]:
def wrapper(fn: _CBP[_R]) -> _LRUCacheWrapper[_R]:
) -> Callable[[Callable[_P, Coroutine[object, object, _R]]], _LRUCacheWrapper[_P, _R]]:
def wrapper(
fn: Callable[_P, Coroutine[object, object, _R]]
) -> _LRUCacheWrapper[_P, _R]:
origin = fn

while isinstance(origin, (partial, partialmethod)):
Expand All @@ -306,7 +320,7 @@ def wrapper(fn: _CBP[_R]) -> _LRUCacheWrapper[_R]:
if hasattr(fn, "_make_unbound_method"):
fn = fn._make_unbound_method()

return _LRUCacheWrapper(cast(_CB[_R], fn), maxsize, typed, ttl)
return _LRUCacheWrapper(fn, maxsize, typed, ttl)

return wrapper

Expand All @@ -317,30 +331,34 @@ def alru_cache(
typed: bool = False,
*,
ttl: Optional[float] = None,
) -> Callable[[_CBP[_R]], _LRUCacheWrapper[_R]]:
) -> Callable[[Callable[_P, Coroutine[object, object, _R]]], _LRUCacheWrapper[_P, _R]]:
...


@overload
def alru_cache(
maxsize: _CBP[_R],
maxsize: Callable[_P, Coroutine[object, object, _R]],
/,
) -> _LRUCacheWrapper[_R]:
) -> _LRUCacheWrapper[_P, _R]:
...


def alru_cache(
maxsize: Union[Optional[int], _CBP[_R]] = 128,
maxsize: Union[Optional[int], Callable[_P, Coroutine[object, object, _R]]] = 128,
typed: bool = False,
*,
ttl: Optional[float] = None,
) -> Union[Callable[[_CBP[_R]], _LRUCacheWrapper[_R]], _LRUCacheWrapper[_R]]:
) -> Union[
Callable[[Callable[_P, Coroutine[object, object, _R]]], _LRUCacheWrapper[_P, _R]],
_LRUCacheWrapper[_P, _R],
]:
if maxsize is None or isinstance(maxsize, int):
return _make_wrapper(maxsize, typed, ttl)
else:
fn = cast(_CB[_R], maxsize)
fn = maxsize

if callable(fn) or hasattr(fn, "_make_unbound_method"):
# partialmethod is not callable() at runtime.
if callable(fn) or hasattr(fn, "_make_unbound_method"): # type: ignore[unreachable]
return _make_wrapper(128, False, None)(fn)

raise NotImplementedError(f"{fn!r} decorating is not supported")
5 changes: 0 additions & 5 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,3 @@ junit_family=xunit2
asyncio_mode=auto
timeout=15
xfail_strict = true

[mypy]
strict=True
pretty=True
packages=async_lru, tests
19 changes: 15 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,26 @@
import sys
from functools import _CacheInfo
from typing import Callable
from typing import Callable, TypeVar

import pytest

from async_lru import _R, _LRUCacheWrapper
from async_lru import _LRUCacheWrapper


if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec


_T = TypeVar("_T")
_P = ParamSpec("_P")


@pytest.fixture
def check_lru() -> Callable[..., None]:
def check_lru() -> Callable[..., None]: # type: ignore[misc]
def _check_lru(
wrapped: _LRUCacheWrapper[_R],
wrapped: _LRUCacheWrapper[_P, _T],
*,
hits: int,
misses: int,
Expand Down
15 changes: 15 additions & 0 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,21 @@ async def coro(val: int) -> int:
assert await coro_wrapped2() == 2


async def test_alru_cache_partial_typing() -> None:
"""Test that mypy produces call-arg errors correctly."""

async def coro(val: int) -> int:
return val

coro_wrapped1 = alru_cache(coro)
with pytest.raises(TypeError):
await coro_wrapped1(1, 1) # type: ignore[call-arg]

coro_wrapped2 = alru_cache(partial(coro, 2))
with pytest.raises(TypeError):
await coro_wrapped2(4) == 2 # type: ignore[call-arg]


async def test_alru_cache_await_same_result_async(
check_lru: Callable[..., None]
) -> None:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ async def coro(val: int) -> None:
reason="Memory leak is not fixed for PyPy3.9",
condition=sys.implementation.name == "pypy",
)
async def test_alru_exception_reference_cleanup(check_lru: Callable[..., None]) -> None:
async def test_alru_exception_reference_cleanup(check_lru: Callable[..., None]) -> None: # type: ignore[misc]
class CustomClass:
...

Expand Down
Loading