Skip to content

Commit

Permalink
Fix mypy type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
ojii committed Feb 14, 2022
1 parent 514b5e1 commit dfe1291
Show file tree
Hide file tree
Showing 13 changed files with 108 additions and 69 deletions.
3 changes: 0 additions & 3 deletions mypy.ini

This file was deleted.

8 changes: 8 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ freezegun = "^1.0"
mypy = "0.931"
pyfakefs = "^4.3.2"
isort = "^5.8.0"
types-freezegun = "^1.1.6"

[tool.pytest.ini_options]
asyncio_mode = "auto"
Expand All @@ -55,6 +56,13 @@ combine_as_imports = "1"
include_trailing_comma = "True"
known_first_party = "aiodynamo"

[tool.mypy]
strict = true
files = [
"src/",
"tests/",
]

[build-system]
requires = ["poetry>=0.12"]
build-backend = "poetry.masonry.api"
12 changes: 0 additions & 12 deletions src/aiodynamo/_mypy_hacks.py

This file was deleted.

9 changes: 5 additions & 4 deletions src/aiodynamo/http/types.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from dataclasses import dataclass
from typing import Awaitable, Dict, Optional, Union
from typing import Dict, Optional, Union

from aiodynamo._compat import Literal
from aiodynamo._mypy_hacks import FixedCallable
from aiodynamo._compat import Literal, Protocol


@dataclass(frozen=True)
Expand All @@ -24,4 +23,6 @@ class RequestFailed(Exception):
inner: Exception


HttpImplementation = FixedCallable[Request, Awaitable[Response]]
class HttpImplementation(Protocol):
async def __call__(self, request: Request) -> Response:
...
7 changes: 4 additions & 3 deletions src/aiodynamo/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
from enum import Enum
from typing import Any, Dict, List, Union

from ._compat import TypedDict
from ._mypy_hacks import FixedCallable
from ._compat import Protocol, TypedDict

Timeout = Union[float, int]
Numeric = Union[float, int, decimal.Decimal]
Expand Down Expand Up @@ -83,4 +82,6 @@ class EncodedStreamSpecification(EncodedStreamSpecificationRequired, total=False
SIMPLE_TYPES = frozenset({AttributeType.boolean, AttributeType.string})


NumericTypeConverter = FixedCallable[str, Any]
class NumericTypeConverter(Protocol):
def __call__(self, value: str) -> Any:
...
9 changes: 7 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import asyncio
from asyncio import AbstractEventLoop
from typing import AsyncGenerator, Generator

import pytest
from _pytest.fixtures import SubRequest

from aiodynamo.http.types import HttpImplementation


@pytest.fixture(params=["httpx", "aiohttp"])
async def http(request):
async def http(request: SubRequest) -> AsyncGenerator[HttpImplementation, None]:
if request.param == "httpx":
try:
import httpx
Expand All @@ -26,7 +31,7 @@ async def http(request):


@pytest.fixture(scope="session")
def session_event_loop():
def session_event_loop() -> Generator[AbstractEventLoop, None, None]:
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
15 changes: 8 additions & 7 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
import os
import sys
import uuid
from typing import AsyncGenerator, Generator, Union
from typing import AsyncGenerator, Generator, Union, cast

from _pytest.fixtures import SubRequest
from httpx import AsyncClient

from aiodynamo.http.httpx import HTTPX
Expand Down Expand Up @@ -75,7 +76,7 @@ def client(


@pytest.fixture(scope="session")
def wait_config(real_dynamo) -> RetryConfig:
def wait_config(real_dynamo: Optional[URL]) -> RetryConfig:
return (
RetryConfig.default_wait_config()
if real_dynamo
Expand All @@ -84,8 +85,8 @@ def wait_config(real_dynamo) -> RetryConfig:


@pytest.fixture(params=[True, False], scope="session")
def consistent_read(request) -> bool:
return request.param
def consistent_read(request: SubRequest) -> bool:
return cast(bool, request.param)


async def _make_table(
Expand All @@ -107,7 +108,7 @@ async def _make_table(
@pytest.fixture
async def table_factory(
client: Client, table_name_prefix: str, wait_config: RetryConfig
) -> Callable[[Optional[Throughput]], Awaitable[str]]:
) -> Callable[[Throughput], Awaitable[str]]:
async def factory(throughput: Throughput = Throughput(5, 5)) -> str:
return await _make_table(client, table_name_prefix, throughput, wait_config)

Expand All @@ -132,7 +133,7 @@ def prefilled_table(
table_name_prefix: str,
wait_config: RetryConfig,
session_event_loop: asyncio.BaseEventLoop,
):
) -> Generator[str, None, None]:
"""
Event loop is function scoped, so we can't use pytest-asyncio here.
"""
Expand All @@ -154,7 +155,7 @@ async def startup() -> str:

return name

async def shutdown(name: str):
async def shutdown(name: str) -> None:
async with AsyncClient() as session:
await Client(
HTTPX(session), Credentials.auto(), region, endpoint
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@

async def test_create_table_with_indices(
client: Client, table_name_prefix: str, wait_config: RetryConfig
):
) -> None:
name = table_name_prefix + secrets.token_hex(4)
await client.create_table(
name,
Expand Down
7 changes: 4 additions & 3 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from typing import Any, Type

import pytest

Expand All @@ -16,7 +17,7 @@
from aiodynamo.models import StaticDelayRetry


def bjson(data):
def bjson(data: Any) -> bytes:
return json.dumps(data).encode()


Expand All @@ -31,8 +32,8 @@ def bjson(data):
],
)
async def test_client_send_request_retryable_errors(
status, dynamo_error, aiodynamo_error
):
status: int, dynamo_error: str, aiodynamo_error: Type[Exception]
) -> None:
async def http(request: Request) -> Response:
return Response(
status=status,
Expand Down
36 changes: 24 additions & 12 deletions tests/unit/test_credentials.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import datetime
from pathlib import Path
from textwrap import dedent
from typing import AsyncGenerator, Optional

import pytest
from _pytest.monkeypatch import MonkeyPatch
from aiohttp import web
from freezegun import freeze_time
from pyfakefs.fake_filesystem import FakeFilesystem # type: ignore[import]
from yarl import URL

from aiodynamo.credentials import (
Expand All @@ -18,23 +21,24 @@
Key,
Metadata,
)
from aiodynamo.http.types import HttpImplementation

pytestmark = [pytest.mark.usefixtures("fs")]


class InstanceMetadataServer:
def __init__(self):
def __init__(self) -> None:
self.port = 0
self.role = None
self.metadata = None
self.role: Optional[str] = None
self.metadata: Optional[Metadata] = None
self.calls = 0

async def role_handler(self, request):
async def role_handler(self, request: web.Request) -> web.Response:
if self.role is None:
raise web.HTTPNotFound()
return web.Response(body=self.role.encode("utf-8"))

async def credentials_handler(self, request):
async def credentials_handler(self, request: web.Request) -> web.Response:
self.calls += 1
if self.role is None:
raise web.HTTPNotFound()
Expand All @@ -56,7 +60,7 @@ async def credentials_handler(self, request):


@pytest.fixture
async def instance_metadata_server():
async def instance_metadata_server() -> AsyncGenerator[InstanceMetadataServer, None]:
ims = InstanceMetadataServer()
app = web.Application()
app.add_routes(
Expand All @@ -74,12 +78,14 @@ async def instance_metadata_server():
await runner.setup()
site = web.TCPSite(runner, "127.0.0.1", 0)
await site.start()
ims.port = site._server.sockets[0].getsockname()[1]
ims.port = site._server.sockets[0].getsockname()[1] # type: ignore[union-attr]
yield ims
await runner.cleanup()


async def test_env_credentials(monkeypatch, http):
async def test_env_credentials(
monkeypatch: MonkeyPatch, http: HttpImplementation
) -> None:
monkeypatch.delenv("AWS_ACCESS_KEY_ID", raising=False)
monkeypatch.delenv("AWS_SECRET_ACCESS_KEY", raising=False)
monkeypatch.delenv("AWS_SESSION_TOKEN", raising=False)
Expand All @@ -100,7 +106,11 @@ async def test_env_credentials(monkeypatch, http):


@pytest.mark.parametrize("role", ["role", "arn:aws:iam::1234567890:role/test-role"])
async def test_ec2_instance_metdata_credentials(http, instance_metadata_server, role):
async def test_ec2_instance_metdata_credentials(
http: HttpImplementation,
instance_metadata_server: InstanceMetadataServer,
role: str,
) -> None:
imc = InstanceMetadataCredentials(
timeout=0.1,
base_url=URL("http://localhost").with_port(instance_metadata_server.port),
Expand All @@ -119,7 +129,9 @@ async def test_ec2_instance_metdata_credentials(http, instance_metadata_server,
assert await imc.get_key(http) == metadata.key


async def test_simultaneous_credentials_refresh(http, instance_metadata_server):
async def test_simultaneous_credentials_refresh(
http: HttpImplementation, instance_metadata_server: InstanceMetadataServer
) -> None:
instance_metadata_server.role = "hoge"
now = datetime.datetime(2020, 3, 12, 15, 37, 51, tzinfo=datetime.timezone.utc)
expires = now + EXPIRES_SOON_THRESHOLD - datetime.timedelta(seconds=10)
Expand Down Expand Up @@ -148,7 +160,7 @@ async def test_simultaneous_credentials_refresh(http, instance_metadata_server):
assert instance_metadata_server.calls == 1


async def test_disabled(monkeypatch):
async def test_disabled(monkeypatch: MonkeyPatch) -> None:
monkeypatch.delenv("AWS_ACCESS_KEY_ID", raising=False)
monkeypatch.delenv("AWS_SECRET_ACCESS_KEY", raising=False)
monkeypatch.delenv("AWS_SESSION_TOKEN", raising=False)
Expand All @@ -168,7 +180,7 @@ async def test_disabled(monkeypatch):
assert not creds.is_disabled()


async def test_file_credentials(fs, http):
async def test_file_credentials(fs: FakeFilesystem, http: HttpImplementation) -> None:
assert FileCredentials().is_disabled()
fs.create_file(
Path.home().joinpath(".aws", "credentials"),
Expand Down
18 changes: 10 additions & 8 deletions tests/unit/test_errors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from contextlib import asynccontextmanager
from typing import cast
from typing import Any, AsyncIterator, cast

import pytest
from aiohttp import ClientConnectionError, ClientSession
Expand All @@ -15,10 +15,10 @@
creds = StaticCredentials(Key("a", "b"))


async def test_retry_raises_underlying_error_aiohttp():
async def test_retry_raises_underlying_error_aiohttp() -> None:
class TestSession:
@asynccontextmanager
async def request(self, *args, **kwargs):
async def request(self, *args: Any, **kwargs: Any) -> AsyncIterator[None]:
raise ClientConnectionError()
yield # needed for asynccontextmanager

Expand All @@ -30,11 +30,11 @@ async def request(self, *args, **kwargs):
await client.get_item("test", {"a": "b"})


async def test_dynamo_errors_get_raised_depaginated():
async def test_dynamo_errors_get_raised_depaginated() -> None:
class TestResponse:
status = 400

async def read(self):
async def read(self) -> bytes:
return json.dumps(
{
"__type": "com.amazonaws.dynamodb.v20120810#ValidationException",
Expand All @@ -44,7 +44,9 @@ async def read(self):

class TestSession:
@asynccontextmanager
async def request(self, *args, **kwargs):
async def request(
self, *args: Any, **kwargs: Any
) -> AsyncIterator[TestResponse]:
yield TestResponse()

http = AIOHTTP(cast(ClientSession, TestSession()))
Expand All @@ -57,7 +59,7 @@ async def request(self, *args, **kwargs):


@pytest.mark.parametrize("status", [500, 503])
async def test_dynamo_retries_50x(status):
async def test_dynamo_retries_50x(status: int) -> None:
responses = iter(
[
Response(status=status, body=b""),
Expand All @@ -67,7 +69,7 @@ async def test_dynamo_retries_50x(status):
]
)

async def http(request: Request):
async def http(request: Request) -> Response:
return next(responses)

client = Client(http, creds, "test", throttle_config=StaticDelayRetry(delay=0.01))
Expand Down
Loading

0 comments on commit dfe1291

Please sign in to comment.