Skip to content

Commit

Permalink
Annotate almost everything
Browse files Browse the repository at this point in the history
The majority of these were added via MonkeyType from running the
tests, with some manual adjustment afterwards. This also includes
some minor refactors where doing so made the types clearer.

This does leave one typing issue unresolved as it is highlighting
a potential bug.
  • Loading branch information
PeterJCLaw committed Sep 24, 2022
1 parent fb49812 commit 6989656
Show file tree
Hide file tree
Showing 23 changed files with 307 additions and 138 deletions.
8 changes: 4 additions & 4 deletions routemaster/app.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Core App singleton that holds state for the application."""
import threading
import contextlib
from typing import Dict, Optional
from typing import Dict, Iterator, Optional

from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.engine import Engine
Expand Down Expand Up @@ -32,7 +32,7 @@ def __init__(
self.config = config
self.initialise()

def initialise(self):
def initialise(self) -> None:
"""
Initialise this instance of the app.
Expand Down Expand Up @@ -66,7 +66,7 @@ def session(self) -> Session:

return self._current_session

def set_rollback(self):
def set_rollback(self) -> None:
"""Mark the current session as needing rollback."""
if self._current_session is None:
raise RuntimeError(
Expand All @@ -77,7 +77,7 @@ def set_rollback(self):
self._needs_rollback = True

@contextlib.contextmanager
def new_session(self):
def new_session(self) -> Iterator[None]:
"""Run a single session in this scope."""
if self._current_session is not None:
raise RuntimeError("There is already a session running.")
Expand Down
2 changes: 1 addition & 1 deletion routemaster/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def post_fork():
cron_thread.stop()


def _validate_config(app: App):
def _validate_config(app: App) -> None:
try:
validate_config(app, app.config)
except ValidationError as e:
Expand Down
63 changes: 39 additions & 24 deletions routemaster/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import functools
import contextlib
import subprocess
from typing import Any, Dict
from typing import Any, Dict, Callable, Iterator, ContextManager
from unittest import mock

import pytest
Expand All @@ -19,7 +19,10 @@
from sqlalchemy import create_engine
from werkzeug.test import Client
from sqlalchemy.orm import sessionmaker
from _pytest.fixtures import SubRequest
from sqlalchemy.orm.session import Session

import routemaster.config.model
from routemaster import state_machine
from routemaster.db import Label, History, metadata
from routemaster.app import App
Expand All @@ -45,6 +48,7 @@
from routemaster.logging import BaseLogger, SplitLogger, register_loggers
from routemaster.webhooks import (
WebhookResult,
WebhookRunner,
webhook_runner_for_state_machine,
)
from routemaster.middleware import wrap_application
Expand Down Expand Up @@ -235,7 +239,8 @@ class TestApp(App):
2. We can set a flag on access to `.db` so that we needn't bother with
resetting the database if nothing has actually been changed.
"""
def __init__(self, config):

def __init__(self, config: routemaster.config.model.Config) -> None:
self.config = config
self.session_used = False
self.logger = SplitLogger(config, loggers=register_loggers(config))
Expand All @@ -249,13 +254,13 @@ def __init__(self, config):
}

@property
def session(self):
def session(self) -> Session:
"""Start if necessary and return a shared session."""
self.session_used = True
return super().session


def get_test_app(**kwargs):
def get_test_app(**kwargs) -> TestApp:
"""Instantiate an app with testing parameters."""
return TestApp(Config(
state_machines=kwargs.get('state_machines', TEST_STATE_MACHINES),
Expand All @@ -270,28 +275,28 @@ def get_test_app(**kwargs):


@pytest.fixture()
def client(custom_app=None):
def client(custom_app: None = None) -> Client:
"""Create a werkzeug test client."""
_app = get_test_app() if custom_app is None else custom_app
server.config.app = _app
server.config.app = _app # type: ignore[attr-defined]
_app.logger.init_flask(server)
return Client(wrap_application(_app, server), werkzeug.Response)


@pytest.fixture()
def app(**kwargs):
def app(**kwargs: Any) -> TestApp:
"""Create an `App` config object for testing."""
return get_test_app(**kwargs)


@pytest.fixture()
def custom_app():
def custom_app() -> Callable[..., TestApp]:
"""Return the test app generator so that we can pass in custom config."""
return get_test_app


@pytest.fixture()
def app_env():
def app_env() -> Dict[str, str]:
"""
Create a dict of environment variables.
Expand All @@ -307,15 +312,15 @@ def app_env():


@pytest.fixture(autouse=True, scope='session')
def database_creation(request):
def database_creation(request: SubRequest) -> Iterator[None]:
"""Wrap test session in creating and destroying all required tables."""
metadata.drop_all(bind=TEST_ENGINE)
metadata.create_all(bind=TEST_ENGINE)
yield


@pytest.fixture(autouse=True)
def database_clear(app):
def database_clear(app: TestApp) -> Iterator[None]:
"""Truncate all tables after each test."""
yield
if app.session_used:
Expand All @@ -328,7 +333,10 @@ def database_clear(app):


@pytest.fixture()
def create_label(app, mock_test_feed):
def create_label(
app: TestApp,
mock_test_feed: Callable[[], ContextManager[None]],
) -> Callable[[str, str, Dict[str, Any]], LabelRef]:
"""Create a label in the database."""

def _create(
Expand All @@ -348,7 +356,7 @@ def _create(


@pytest.fixture()
def delete_label(app):
def delete_label(app: TestApp) -> Callable[[str, str], None]:
"""
Mark a label in the database as deleted.
"""
Expand All @@ -364,7 +372,10 @@ def _delete(name: str, state_machine_name: str) -> None:


@pytest.fixture()
def create_deleted_label(create_label, delete_label):
def create_deleted_label(
create_label: Callable[[str, str, Dict[str, Any]], LabelRef],
delete_label: Callable[[str, str], None],
) -> Callable[[str, str], LabelRef]:
"""
Create a label in the database and then delete it.
"""
Expand All @@ -378,7 +389,10 @@ def _create_and_delete(name: str, state_machine_name: str) -> LabelRef:


@pytest.fixture()
def mock_webhook():
def mock_webhook() -> Callable[
[WebhookResult],
ContextManager[Callable[[StateMachine], WebhookRunner]],
]:
"""Mock the test config's webhook call."""
@contextlib.contextmanager
def _mock(result=WebhookResult.SUCCESS):
Expand All @@ -392,7 +406,7 @@ def _mock(result=WebhookResult.SUCCESS):


@pytest.fixture()
def mock_test_feed():
def mock_test_feed() -> Callable[[Dict[str, Any]], ContextManager[None]]:
"""Mock out the test feed."""
@contextlib.contextmanager
def _mock(data={'should_do_alternate_action': False}):
Expand All @@ -414,7 +428,7 @@ def _mock(data={'should_do_alternate_action': False}):


@pytest.fixture()
def assert_history(app):
def assert_history(app: TestApp) -> Callable:
"""Assert that the database history matches what is expected."""
def _assert(entries):
with app.new_session():
Expand All @@ -432,7 +446,7 @@ def _assert(entries):


@pytest.fixture()
def set_metadata(app):
def set_metadata(app: TestApp) -> Callable:
"""Directly set the metadata for a label in the database."""
def _inner(label, update):
with app.new_session():
Expand All @@ -449,7 +463,7 @@ def _inner(label, update):


@pytest.fixture()
def make_context(app):
def make_context(app: TestApp) -> Callable:
"""Factory for Contexts that provides sane defaults for testing."""
def _inner(**kwargs):
logger = BaseLogger(app.config)
Expand Down Expand Up @@ -491,31 +505,32 @@ def version():


@pytest.fixture()
def current_state(app):
def current_state(app: TestApp) -> Callable[[LabelRef], str]:
"""Get the current state of a label."""
def _inner(label):
def _inner(label: LabelRef) -> str:
with app.new_session():
return app.session.query(
History.new_state,
).filter_by(
label_name=label.name,
label_state_machine=label.state_machine,
).order_by(
History.id.desc(),
# TODO: use the sqlalchemy mypy plugin rather than our stubs
History.id.desc(), # type: ignore[attr-defined]
).limit(1).scalar()
return _inner


@pytest.fixture()
def unused_tcp_port():
def unused_tcp_port() -> int:
"""Returns an unused TCP port, inspired by pytest-asyncio."""
with contextlib.closing(socket.socket()) as sock:
sock.bind(('127.0.0.1', 0))
return sock.getsockname()[1]


@pytest.fixture()
def routemaster_serve_subprocess(unused_tcp_port):
def routemaster_serve_subprocess(unused_tcp_port: int) -> Callable:
"""
Fixture to spawn a routemaster server as a subprocess.
Expand Down
30 changes: 25 additions & 5 deletions routemaster/context.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,25 @@
"""Context definition for exit condition programs."""
import datetime
from typing import Any, Dict, Iterable, Optional, Sequence
from typing import (
Any,
Dict,
Tuple,
Union,
Callable,
Iterable,
Optional,
Sequence,
ContextManager,
)

import requests

from routemaster.feeds import Feed
from routemaster.utils import get_path

ResponseLogger = Callable[[requests.Response], None]
FeedLoggingContext = Callable[[str], ContextManager[ResponseLogger]]


class Context(object):
"""Execution context for exit condition programs."""
Expand All @@ -18,7 +33,7 @@ def __init__(
feeds: Dict[str, Feed],
accessed_variables: Iterable[str],
current_history_entry: Optional[Any],
feed_logging_context,
feed_logging_context: FeedLoggingContext,
) -> None:
"""Create an execution context."""
if now.tzinfo is None:
Expand Down Expand Up @@ -65,7 +80,12 @@ def _lookup_history(self, path: Sequence[str]) -> Any:
'previous_state': self.current_history_entry.old_state,
}[variable_name]

def property_handler(self, property_name, value, **kwargs):
def property_handler(
self,
property_name: Union[Tuple[str, ...]],
value: Any,
**kwargs: Any,
) -> bool:
"""Handle a property in execution."""
if property_name == ('passed',):
epoch = kwargs['since']
Expand All @@ -82,8 +102,8 @@ def _pre_warm_feeds(
self,
label: str,
accessed_variables: Iterable[str],
logging_context,
):
logging_context: FeedLoggingContext,
) -> None:
for accessed_variable in accessed_variables:
parts = accessed_variable.split('.')

Expand Down
2 changes: 1 addition & 1 deletion routemaster/cron.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def process_job(
# Bound when scheduling a specific job for a state
fn: LabelStateProcessor,
label_provider: LabelProvider,
):
) -> None:
"""Process a single instance of a single cron job."""

def _iter_labels_until_terminating(
Expand Down
4 changes: 3 additions & 1 deletion routemaster/exit_conditions/analysis.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Analysis of compiled programs."""


from typing import Any, Tuple, Union, Iterator

from routemaster.exit_conditions.operations import Operation


def find_accessed_keys(instructions):
def find_accessed_keys(instructions: Any) -> Iterator[Union[Tuple[str, str], Tuple[str], Tuple[str, str, str]]]:
"""Yield each key accessed under the program."""
for instruction, *args in instructions:
if instruction == Operation.LOOKUP:
Expand Down
5 changes: 3 additions & 2 deletions routemaster/exit_conditions/error_display.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Human-readable ParseError handling."""
from typing import Tuple


def _find_line_containing(source, index):
def _find_line_containing(source: str, index: int) -> Tuple[int, str, int]:
"""Find (line number, line, offset) triple for an index into a string."""
lines = source.splitlines()

Expand All @@ -20,7 +21,7 @@ def _find_line_containing(source, index):
raise AssertionError("index >> len(source)")


def format_parse_error_message(*, source, error):
def format_parse_error_message(*, source, error) -> str:
"""Format a parse error on some source for nicer display."""
error_line_number, error_line, error_offset = _find_line_containing(
source,
Expand Down
Loading

0 comments on commit 6989656

Please sign in to comment.