diff --git a/.mypy.ini b/.mypy.ini index f356fd6c..aa4e3878 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -2,13 +2,13 @@ files = aiohttp_devtools, tests check_untyped_defs = True follow_imports_for_stubs = True -#disallow_any_decorated = True +disallow_any_decorated = True disallow_any_generics = True -#disallow_incomplete_defs = True +disallow_incomplete_defs = True disallow_subclassing_any = True -#disallow_untyped_calls = True +disallow_untyped_calls = True disallow_untyped_decorators = True -#disallow_untyped_defs = True +disallow_untyped_defs = True implicit_reexport = False no_implicit_optional = True show_error_codes = True diff --git a/aiohttp_devtools/cli.py b/aiohttp_devtools/cli.py index 637991cf..a87fee09 100644 --- a/aiohttp_devtools/cli.py +++ b/aiohttp_devtools/cli.py @@ -1,5 +1,6 @@ import sys import traceback +from typing import Any import click from aiohttp.web import run_app @@ -18,7 +19,7 @@ @click.group() @click.version_option(__version__, "-V", "--version", prog_name="aiohttp-devtools") -def cli(): +def cli() -> None: pass @@ -32,7 +33,7 @@ def cli(): @click.option('--livereload/--no-livereload', envvar='AIO_LIVERELOAD', default=True, help=livereload_help) @click.option('-p', '--port', default=8000, type=int) @click.option('-v', '--verbose', is_flag=True, help=verbose_help) -def serve(path, livereload, port, verbose): +def serve(path: str, livereload: bool, port: int, verbose: bool) -> None: """ Serve static files from a directory. """ @@ -68,7 +69,7 @@ def serve(path, livereload, port, verbose): @click.option('--aux-port', envvar='AIO_AUX_PORT', type=click.INT, help=aux_port_help) @click.option('-v', '--verbose', is_flag=True, help=verbose_help) @click.argument('project_args', nargs=-1) -def runserver(**config): +def runserver(**config: Any) -> None: """ Run a development server for an aiohttp apps. diff --git a/aiohttp_devtools/logs.py b/aiohttp_devtools/logs.py index e16663cc..b02df1f2 100644 --- a/aiohttp_devtools/logs.py +++ b/aiohttp_devtools/logs.py @@ -3,11 +3,18 @@ import logging.config import platform import re +import sys import traceback from io import StringIO +from logging import LogRecord from types import TracebackType from typing import Dict, Optional, Tuple, Type, Union +if sys.version_info < (3, 8): + from typing_extensions import Literal +else: + from typing import Literal + import pygments from devtools import pformat from devtools.ansi import isatty, sformat @@ -33,11 +40,11 @@ class DefaultFormatter(logging.Formatter): - def __init__(self, fmt=None, datefmt=None, style='%'): + def __init__(self, fmt: Optional[str] = None, datefmt: Optional[str] = None, style: Literal["%", "{", "$"] = "%"): super().__init__(fmt, datefmt, style) self.stream_is_tty = False - def format(self, record): + def format(self, record: LogRecord) -> str: msg = super().format(record) if not self.stream_is_tty: return msg @@ -45,20 +52,19 @@ def format(self, record): log_color = LOG_FORMATS.get(record.levelno, sformat.red) if m: time = sformat(m.groups()[0], sformat.magenta) - return time + sformat(msg[m.end():], log_color) + return time + sformat(msg[m.end():], log_color) # type: ignore[no-any-return] - return sformat(msg, log_color) + return sformat(msg, log_color) # type: ignore[no-any-return] class AccessFormatter(logging.Formatter): - """ - Used to log aiohttp_access and aiohttp_server - """ - def __init__(self, fmt=None, datefmt=None, style='%'): + """Used to log aiohttp_access and aiohttp_server.""" + + def __init__(self, fmt: Optional[str] = None, datefmt: Optional[str] = None, style: Literal["%", "{", "$"] = "%"): super().__init__(fmt, datefmt, style) self.stream_is_tty = False - def formatMessage(self, record): + def formatMessage(self, record: LogRecord) -> str: msg = super().formatMessage(record) if msg[0] != '{': return msg @@ -174,6 +180,6 @@ def log_config(verbose: bool) -> Dict[str, object]: } -def setup_logging(verbose): +def setup_logging(verbose: bool) -> None: config = log_config(verbose) logging.config.dictConfig(config) diff --git a/aiohttp_devtools/runserver/config.py b/aiohttp_devtools/runserver/config.py index 8454e266..ce7c836e 100644 --- a/aiohttp_devtools/runserver/config.py +++ b/aiohttp_devtools/runserver/config.py @@ -1,16 +1,17 @@ import asyncio -import inspect import re import sys from importlib import import_module from pathlib import Path -from typing import Optional +from typing import Awaitable, Callable, Optional, Union from aiohttp import web from ..exceptions import AiohttpDevConfigError as AdevConfigError from ..logs import rs_dft_logger as logger +AppFactory = Union[web.Application, Callable[[], web.Application], Callable[[], Awaitable[web.Application]]] + STD_FILE_NAMES = [ re.compile(r'main\.py'), re.compile(r'app\.py'), @@ -56,9 +57,12 @@ def __init__(self, *, self.settings_found = False self.py_file = self._resolve_path(str(self.app_path), 'is_file', 'app-path') - self.python_path = self._resolve_path(python_path, 'is_dir', 'python-path') or self.root_path + if python_path: + self.python_path = self._resolve_path(python_path, "is_dir", "python-path") + else: + self.python_path = self.root_path - self.static_path = self._resolve_path(static_path, 'is_dir', 'static-path') + self.static_path = self._resolve_path(static_path, "is_dir", "static-path") if static_path else None self.static_url = static_url self.livereload = livereload self.app_factory_name = app_factory_name @@ -70,7 +74,7 @@ def __init__(self, *, @property def static_path_str(self) -> Optional[str]: - return self.static_path and str(self.static_path) + return str(self.static_path) if self.static_path else None def _find_app_path(self, app_path: str) -> Path: # for backwards compatibility try this first @@ -94,10 +98,7 @@ def _find_app_path(self, app_path: str) -> Path: raise AdevConfigError('unable to find a recognised default file ("app.py" or "main.py") ' 'in the directory "%s"' % app_path) - def _resolve_path(self, _path: Optional[str], check: str, arg_name: str): - if _path is None: - return - + def _resolve_path(self, _path: str, check: str, arg_name: str) -> Path: if _path.startswith('/'): path = Path(_path) error_msg = '{arg_name} "{path}" is not a valid path' @@ -119,11 +120,11 @@ def _resolve_path(self, _path: Optional[str], check: str, arg_name: str): raise AdevConfigError('{} is not a directory'.format(path)) return path - def import_app_factory(self): - """ - Import attribute/class from from a python module. Raise AdevConfigError if the import failed. + def import_app_factory(self) -> AppFactory: + """Import and return attribute/class from a python module. - :return: (attribute, Path object for directory of file) + Raises: + AdevConfigError - If the import failed. """ rel_py_file = self.py_file.relative_to(self.python_path) module_path = '.'.join(rel_py_file.with_suffix('').parts) @@ -145,36 +146,39 @@ def import_app_factory(self): try: attr = getattr(module, self.app_factory_name) - except AttributeError as e: - raise AdevConfigError('Module "{s.py_file.name}" ' - 'does not define a "{s.app_factory_name}" attribute/class'.format(s=self)) from e + except AttributeError: + raise AdevConfigError("Module '{}' does not define a '{}' attribute/class".format( + self.py_file.name, self.app_factory_name)) + + if not isinstance(attr, web.Application) and not callable(attr): + raise AdevConfigError("'{}.{}' is not an Application or callable".format( + self.py_file.name, self.app_factory_name)) + + if callable(attr): + required_args = attr.__code__.co_argcount - len(attr.__defaults__) + if required_args > 0: + raise AdevConfigError("'{}.{}' should not have required arguments.".format( + self.py_file.name, self.app_factory_name)) self.watch_path = self.watch_path or Path(module.__file__ or ".").parent - return attr + return attr # type: ignore[no-any-return] - async def load_app(self, app_factory): + async def load_app(self, app_factory: AppFactory) -> web.Application: if isinstance(app_factory, web.Application): - app = app_factory - else: - # app_factory should be a proper factory with signature (loop): -> Application - signature = inspect.signature(app_factory) - if 'loop' in signature.parameters: - loop = asyncio.get_event_loop() - app = app_factory(loop=loop) - else: - # loop argument missing, assume no arguments - app = app_factory() + return app_factory + + app = app_factory() - if asyncio.iscoroutine(app): - app = await app + if asyncio.iscoroutine(app): + app = await app - if not isinstance(app, web.Application): - raise AdevConfigError('app factory "{.app_factory_name}" returned "{.__class__.__name__}" not an ' - 'aiohttp.web.Application'.format(self, app)) + if not isinstance(app, web.Application): + raise AdevConfigError("app factory '{}' returned '{}' not an aiohttp.web.Application".format( + self.app_factory_name, app.__class__.__name__)) return app - def __str__(self): + def __str__(self) -> str: fields = ('py_file', 'static_path', 'static_url', 'livereload', 'app_factory_name', 'host', 'main_port', 'aux_port') return 'Config:\n' + '\n'.join(' {0}: {1!r}'.format(f, getattr(self, f)) for f in fields) diff --git a/aiohttp_devtools/runserver/log_handlers.py b/aiohttp_devtools/runserver/log_handlers.py index 6b239377..09652d64 100644 --- a/aiohttp_devtools/runserver/log_handlers.py +++ b/aiohttp_devtools/runserver/log_handlers.py @@ -1,8 +1,9 @@ import json import warnings from datetime import datetime, timedelta -from typing import cast +from typing import Dict, Optional, Union, cast +from aiohttp import web from aiohttp.abc import AbstractAccessLogger dbtb = '/_debugtoolbar/' @@ -12,13 +13,13 @@ class _AccessLogger(AbstractAccessLogger): prefix: str - def get_msg(self, request, response, time): + def get_msg(self, request: web.BaseRequest, response: web.StreamResponse, time: float) -> Optional[str]: raise NotImplementedError() - def extra(self, request, response, time): + def extra(self, request: web.BaseRequest, response: web.StreamResponse, time: float) -> Optional[Dict[str, object]]: pass - def log(self, request, response, time): + def log(self, request: web.BaseRequest, response: web.StreamResponse, time: float) -> None: msg = self.get_msg(request, response, time) if not msg: return @@ -39,7 +40,7 @@ def log(self, request, response, time): class AccessLogger(_AccessLogger): prefix = '●' - def get_msg(self, request, response, time): + def get_msg(self, request: web.BaseRequest, response: web.StreamResponse, time: float) -> str: return '{method} {path} {code} {size} {ms:0.0f}ms'.format( method=request.method, path=request.path_qs, @@ -48,35 +49,40 @@ def get_msg(self, request, response, time): ms=time * 1000, ) - def extra(self, request, response, time: float): - if response.status > 310: - request_body = request._read_bytes - details = dict( - request_duration_ms=round(time * 1000, 3), - request_headers=dict(request.headers), - request_body=parse_body(request_body, 'request body'), - request_size=fmt_size(0 if request_body is None else len(request_body)), - response_headers=dict(response.headers), - response_body=parse_body(response.text, "response body"), - ) - return dict(details=details) + def extra(self, request: web.BaseRequest, response: web.StreamResponse, time: float) -> Optional[Dict[str, object]]: + if response.status <= 310: + return None + + request_body = request._read_bytes + body_text = response.text if isinstance(response, web.Response) else None + details = { + "request_duration_ms": round(time * 1000, 3), + "request_headers": dict(request.headers), + "request_body": parse_body(request_body, "request body"), + "request_size": fmt_size(0 if request_body is None else len(request_body)), + "response_headers": dict(response.headers), + "response_body": parse_body(body_text, "response body"), + } + return {"details": details} class AuxAccessLogger(_AccessLogger): prefix = '◆' - def get_msg(self, request, response, time): + def get_msg(self, request: web.BaseRequest, response: web.StreamResponse, time: float) -> Optional[str]: # don't log livereload - if request.path not in {'/livereload', '/livereload.js'}: - return '{method} {path} {code} {size}'.format( - method=request.method, - path=request.path_qs, - code=response.status, - size=fmt_size(response.body_length), - ) + if request.path in {"/livereload", "/livereload.js"}: + return None + + return "{method} {path} {code} {size}".format( + method=request.method, + path=request.path_qs, + code=response.status, + size=fmt_size(response.body_length), + ) -def fmt_size(num): +def fmt_size(num: int) -> str: if not num: return '' if num < 1024: @@ -85,15 +91,16 @@ def fmt_size(num): return '{:0.1f}KB'.format(num / 1024) -def parse_body(v, name): - if isinstance(v, (str, bytes)): - try: - return json.loads(v) - except UnicodeDecodeError: - v = cast(bytes, v) # UnicodeDecodeError only occurs on bytes. - warnings.warn('UnicodeDecodeError parsing ' + name, UserWarning) - # bytes which cause UnicodeDecodeError can cause problems later on - return v.decode(errors='ignore') - except (ValueError, TypeError): - pass - return v +def parse_body(v: Union[str, bytes, None], name: str) -> object: + if v is None: + return v + + try: + return json.loads(v) + except UnicodeDecodeError: + v = cast(bytes, v) # UnicodeDecodeError only occurs on bytes. + warnings.warn("UnicodeDecodeError parsing " + name, UserWarning) + # bytes which cause UnicodeDecodeError can cause problems later on + return v.decode(errors="ignore") + except (ValueError, TypeError): + return v diff --git a/aiohttp_devtools/runserver/main.py b/aiohttp_devtools/runserver/main.py index cd1ecbf3..e5a9375b 100644 --- a/aiohttp_devtools/runserver/main.py +++ b/aiohttp_devtools/runserver/main.py @@ -1,6 +1,11 @@ import asyncio import os +import sys from multiprocessing import set_start_method +from typing import Any, Type + +from aiohttp.abc import AbstractAccessLogger +from aiohttp.web import Application from ..logs import rs_dft_logger as logger from .config import Config @@ -8,8 +13,21 @@ from .serve import HOST, check_port_open, create_auxiliary_app from .watch import AppTask, LiveReloadTask +if sys.version_info < (3, 8): + from typing_extensions import TypedDict +else: + from typing import TypedDict + + +class RunServer(TypedDict): + app: Application + host: str + port: int + shutdown_timeout: float + access_log_class: Type[AbstractAccessLogger] + -def runserver(**config_kwargs): +def runserver(**config_kwargs: Any) -> RunServer: """ Prepare app ready to run development server. @@ -49,7 +67,7 @@ def runserver(**config_kwargs): "shutdown_timeout": 0.01, "access_log_class": AuxAccessLogger} -def serve_static(*, static_path: str, livereload: bool = True, port: int = 8000): +def serve_static(*, static_path: str, livereload: bool = True, port: int = 8000) -> RunServer: logger.debug('Config: path="%s", livereload=%s, port=%s', static_path, livereload, port) app = create_auxiliary_app(static_path=static_path, livereload=livereload) diff --git a/aiohttp_devtools/runserver/serve.py b/aiohttp_devtools/runserver/serve.py index 7baea4ad..37dcb7c0 100644 --- a/aiohttp_devtools/runserver/serve.py +++ b/aiohttp_devtools/runserver/serve.py @@ -5,11 +5,11 @@ import sys from errno import EADDRINUSE from pathlib import Path -from typing import Optional +from typing import Any, Iterator, Optional from aiohttp import WSMsgType, web from aiohttp.hdrs import LAST_MODIFIED, CONTENT_LENGTH -from aiohttp.web import FileResponse, Response +from aiohttp.typedefs import Handler from aiohttp.web_exceptions import HTTPNotFound, HTTPNotModified from aiohttp.web_urldispatcher import StaticResource from yarl import URL @@ -18,7 +18,7 @@ from ..logs import rs_aux_logger as aux_logger from ..logs import rs_dft_logger as dft_logger from ..logs import setup_logging -from .config import Config +from .config import AppFactory, Config from .log_handlers import AccessLogger from .utils import MutableValue @@ -27,19 +27,19 @@ HOST = '0.0.0.0' -def _set_static_url(app, url): +def _set_static_url(app: web.Application, url: str) -> None: app["static_root_url"] = MutableValue(url) for subapp in app._subapps: _set_static_url(subapp, url) -def _change_static_url(app, url): +def _change_static_url(app: web.Application, url: str) -> None: app["static_root_url"].change(url) for subapp in app._subapps: _change_static_url(subapp, url) -def modify_main_app(app, config: Config): +def modify_main_app(app: web.Application, config: Config) -> None: """ Modify the app we're serving to make development easier, eg. * modify responses to add the livereload snippet @@ -48,28 +48,30 @@ def modify_main_app(app, config: Config): app._debug = True dft_logger.debug('livereload enabled: %s', '✓' if config.livereload else '✖') - def get_host(request): + def get_host(request: web.Request) -> str: if config.infer_host: return request.headers.get('host', 'localhost').split(':', 1)[0] else: return config.host if config.livereload: - async def on_prepare(request, response): - if (not request.path.startswith('/_debugtoolbar') and - 'text/html' in response.content_type and - getattr(response, 'body', False)): - lr_snippet = LIVE_RELOAD_HOST_SNIPPET.format(get_host(request), config.aux_port) - dft_logger.debug('appending live reload snippet "%s" to body', lr_snippet) - response.body += lr_snippet.encode() - response.headers[CONTENT_LENGTH] = str(len(response.body)) + async def on_prepare(request: web.Request, response: web.StreamResponse) -> None: + if (not isinstance(response, web.Response) + or not isinstance(response.body, bytes) # No support for Payload + or request.path.startswith("/_debugtoolbar") + or "text/html" not in response.content_type): + return + lr_snippet = LIVE_RELOAD_HOST_SNIPPET.format(get_host(request), config.aux_port) + dft_logger.debug("appending live reload snippet '%s' to body", lr_snippet) + response.body += lr_snippet.encode() + response.headers[CONTENT_LENGTH] = str(len(response.body)) app.on_response_prepare.append(on_prepare) static_path = config.static_url.strip('/') if config.infer_host and config.static_path is not None: # we set the app key even in middleware to make the switch to production easier and for backwards compat. @web.middleware - async def static_middleware(request, handler): + async def static_middleware(request: web.Request, handler: Handler) -> web.StreamResponse: static_url = 'http://{}:{}/{}'.format(get_host(request), config.aux_port, static_path) dft_logger.debug('setting app static_root_url to "%s"', static_url) _change_static_url(request.app, static_url) @@ -103,7 +105,7 @@ async def check_port_open(port: int, delay: float = 1) -> None: @contextlib.contextmanager -def set_tty(tty_path): # pragma: no cover +def set_tty(tty_path: Optional[str]) -> Iterator[None]: try: if not tty_path: # to match OSError from open @@ -116,7 +118,7 @@ def set_tty(tty_path): # pragma: no cover yield -def serve_main_app(config: Config, tty_path: Optional[str]): +def serve_main_app(config: Config, tty_path: Optional[str]) -> None: with set_tty(tty_path): setup_logging(config.verbose) app_factory = config.import_app_factory() @@ -132,7 +134,7 @@ def serve_main_app(config: Config, tty_path: Optional[str]): loop.run_until_complete(runner.cleanup()) -async def create_main_app(config: Config, app_factory): +async def create_main_app(config: Config, app_factory: AppFactory) -> web.AppRunner: app = await config.load_app(app_factory) modify_main_app(app, config) @@ -140,7 +142,7 @@ async def create_main_app(config: Config, app_factory): return web.AppRunner(app, access_log_class=AccessLogger) -async def start_main_app(runner: web.AppRunner, port): +async def start_main_app(runner: web.AppRunner, port: int) -> None: await runner.setup() site = web.TCPSite(runner, host=HOST, port=port, shutdown_timeout=0.1) await site.start() @@ -149,7 +151,7 @@ async def start_main_app(runner: web.AppRunner, port): WS = 'websockets' -async def src_reload(app, path: Optional[str] = None): +async def src_reload(app: web.Application, path: Optional[str] = None) -> int: """ prompt each connected browser to reload by sending websocket message. @@ -193,13 +195,13 @@ async def src_reload(app, path: Optional[str] = None): return reloads -async def cleanup_aux_app(app): +async def cleanup_aux_app(app: web.Application) -> None: aux_logger.debug('closing %d websockets...', len(app[WS])) await asyncio.gather(*(ws.close() for ws, _ in app[WS])) def create_auxiliary_app( - *, static_path: Optional[str], static_url="/", livereload=True) -> web.Application: + *, static_path: Optional[str], static_url: str = "/", livereload: bool = True) -> web.Application: app = web.Application() app[WS] = set() app.update( @@ -228,7 +230,7 @@ def create_auxiliary_app( return app -async def livereload_js(request): +async def livereload_js(request: web.Request) -> web.Response: if request.if_modified_since: raise HTTPNotModified() @@ -239,7 +241,7 @@ async def livereload_js(request): WS_TYPE_LOOKUP = {k.value: v for v, k in WSMsgType.__members__.items()} -async def websocket_handler(request): +async def websocket_handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse(timeout=0.01) url = None await ws.prepare(request) @@ -288,12 +290,12 @@ async def websocket_handler(request): class CustomStaticResource(StaticResource): - def __init__(self, *args, add_tail_snippet=False, **kwargs): + def __init__(self, *args: Any, add_tail_snippet: bool = False, **kwargs: Any): self._add_tail_snippet = add_tail_snippet super().__init__(*args, **kwargs) self._show_index = True - def modify_request(self, request): + def modify_request(self, request: web.Request) -> None: """ Apply common path conventions eg. / > /index.html, /foobar > /foobar.html """ @@ -318,8 +320,8 @@ def modify_request(self, request): # path is not not relative to self._directory pass - def _insert_footer(self, response): - if not isinstance(response, FileResponse) or not self._add_tail_snippet: + def _insert_footer(self, response: web.StreamResponse) -> web.StreamResponse: + if not isinstance(response, web.FileResponse) or not self._add_tail_snippet: return response filepath = response._path @@ -330,12 +332,12 @@ def _insert_footer(self, response): with filepath.open('rb') as f: body = f.read() + LIVE_RELOAD_LOCAL_SNIPPET - resp = Response(body=body, content_type='text/html') + resp = web.Response(body=body, content_type="text/html") # Mypy bug: https://github.com/python/mypy/issues/11892 resp.last_modified = filepath.stat().st_mtime # type: ignore[assignment] return resp - async def _handle(self, request): + async def _handle(self, request: web.Request) -> web.StreamResponse: self.modify_request(request) try: response = await super()._handle(request) diff --git a/aiohttp_devtools/runserver/utils.py b/aiohttp_devtools/runserver/utils.py index b309ac06..c03671b8 100644 --- a/aiohttp_devtools/runserver/utils.py +++ b/aiohttp_devtools/runserver/utils.py @@ -1,34 +1,36 @@ +from typing import Any, Generic, Optional, TypeVar +_T = TypeVar("_T") -class MutableValue: - """ - Used to avoid warnings (and in future errors) from aiohttp when the app context is modified. - """ - __slots__ = 'value', - def __init__(self, value=None): +class MutableValue(Generic[_T]): + """Used to avoid errors from aiohttp when the app context is modified.""" + + __slots__ = ("value",) + + def __init__(self, value: Optional[_T] = None): self.value = value - def change(self, new_value): + def change(self, new_value: _T) -> None: self.value = new_value - def __len__(self): - return len(self.value) + def __len__(self) -> int: + return len(self.value) # type: ignore[arg-type] - def __repr__(self): + def __repr__(self) -> str: return repr(self.value) - def __str__(self): + def __str__(self) -> str: return str(self.value) - def __bool__(self): + def __bool__(self) -> bool: return bool(self.value) - def __eq__(self, other): + def __eq__(self, other: object) -> "MutableValue[bool]": # type: ignore[override] return MutableValue(self.value == other) - def __add__(self, other): - return self.value + other + def __add__(self, other: _T) -> _T: + return self.value + other # type: ignore[no-any-return, operator] - def __getattr__(self, item): + def __getattr__(self, item: str) -> Any: return getattr(self.value, item) diff --git a/aiohttp_devtools/runserver/watch.py b/aiohttp_devtools/runserver/watch.py index 5303db42..8e7d7ae2 100644 --- a/aiohttp_devtools/runserver/watch.py +++ b/aiohttp_devtools/runserver/watch.py @@ -4,7 +4,7 @@ import sys from multiprocessing import Process from pathlib import Path -from typing import AsyncIterator, Optional, Union +from typing import AsyncIterator, Iterable, Optional, Tuple, Union from aiohttp import ClientSession, web from watchgod import awatch @@ -16,10 +16,10 @@ class WatchTask: - def __init__(self, path: Union[Path, str, None]): - self._app: Optional[web.Application] = None - self._task: Optional[asyncio.Task[None]] = None - assert path + _app: web.Application + _task: "asyncio.Task[None]" + + def __init__(self, path: Union[Path, str]): self._path = path async def start(self, app: web.Application) -> None: @@ -28,10 +28,10 @@ async def start(self, app: web.Application) -> None: self._awatch = awatch(self._path, stop_event=self.stopper) self._task = asyncio.create_task(self._run()) - async def _run(self): + async def _run(self) -> None: raise NotImplementedError() - async def close(self, *args): + async def close(self, *args: object) -> None: if self._task: self.stopper.set() async with self._awatch.lock: @@ -53,9 +53,10 @@ def __init__(self, config: Config): self._reloads = 0 self._session: Optional[ClientSession] = None self._runner = None + assert self._config.watch_path super().__init__(self._config.watch_path) - async def _run(self, live_checks=20): + async def _run(self, live_checks: int = 20) -> None: assert self._app is not None self._session = ClientSession() @@ -64,7 +65,7 @@ async def _run(self, live_checks=20): static_path = str(self._app['static_path']) - def is_static(changes): + def is_static(changes: Iterable[Tuple[object, str]]) -> bool: return all(str(c[1]).startswith(static_path) for c in changes) async for changes in self._awatch: @@ -85,7 +86,7 @@ def is_static(changes): await self._session.close() raise AiohttpDevException('error running dev server') - async def _src_reload_when_live(self, checks=20) -> None: + async def _src_reload_when_live(self, checks: int = 20) -> None: assert self._app is not None and self._session is not None if self._app[WS]: @@ -103,7 +104,7 @@ async def _src_reload_when_live(self, checks=20) -> None: await src_reload(self._app) return - def _start_dev_server(self): + def _start_dev_server(self) -> None: act = 'Start' if self._reloads == 0 else 'Restart' logger.info('%sing dev server at http://%s:%s ●', act, self._config.host, self._config.main_port) @@ -119,7 +120,7 @@ def _start_dev_server(self): self._process = Process(target=serve_main_app, args=(self._config, tty_path)) self._process.start() - def _stop_dev_server(self): + def _stop_dev_server(self) -> None: if self._process.is_alive(): logger.debug('stopping server process...') if self._process.pid: @@ -134,7 +135,7 @@ def _stop_dev_server(self): else: logger.warning('server process already dead, exit code: %s', self._process.exitcode) - async def close(self, *args): + async def close(self, *args: object) -> None: self.stopper.set() self._stop_dev_server() if self._session is None: @@ -143,7 +144,7 @@ async def close(self, *args): class LiveReloadTask(WatchTask): - async def _run(self): + async def _run(self) -> None: async for changes in self._awatch: if len(changes) > 1: await src_reload(self._app) diff --git a/setup.py b/setup.py index 35f4c753..5b611b13 100644 --- a/setup.py +++ b/setup.py @@ -68,6 +68,7 @@ 'devtools>=0.5', 'Pygments>=2.2.0', 'watchgod>=0.2', + 'typing_extensions >= 3.7.4; python_version<"3.8"' ], python_requires='>=3.7', ) diff --git a/tests/test_runserver_config.py b/tests/test_runserver_config.py index ae679bbb..c0dd3a54 100644 --- a/tests/test_runserver_config.py +++ b/tests/test_runserver_config.py @@ -18,7 +18,7 @@ async def test_create_app_wrong_name(tmpworkdir, loop): config = Config(app_path='app.py', app_factory_name='missing') with pytest.raises(AiohttpDevConfigError) as excinfo: config.import_app_factory() - assert excinfo.value.args[0] == 'Module "app.py" does not define a "missing" attribute/class' + assert excinfo.value.args[0] == "Module 'app.py' does not define a 'missing' attribute/class" @pytest.mark.boxed diff --git a/tests/test_runserver_logs.py b/tests/test_runserver_logs.py index 7492f7ea..2447e000 100644 --- a/tests/test_runserver_logs.py +++ b/tests/test_runserver_logs.py @@ -4,6 +4,7 @@ import sys from unittest.mock import MagicMock +from aiohttp import web import pytest from aiohttp_devtools.logs import AccessFormatter, DefaultFormatter @@ -96,12 +97,12 @@ def test_extra(): info = MagicMock() logger_type = type("Logger", (), {"info": info}) logger = AccessLogger(logger_type(), "") - request = MagicMock() + request = MagicMock(spec=web.Request) request.method = 'GET' request.headers = {'Foo': 'Bar'} request.path_qs = '/foobar?v=1' request._read_bytes = b'testing' - response = MagicMock() + response = MagicMock(spec=web.Response) response.status = 500 response.body_length = 100 response.headers = {'Foo': 'Spam'} diff --git a/tests/test_runserver_main.py b/tests/test_runserver_main.py index 31ef05fd..12fa367b 100644 --- a/tests/test_runserver_main.py +++ b/tests/test_runserver_main.py @@ -119,7 +119,7 @@ async def test_run_app_aiohttp_client(tmpworkdir, aiohttp_client): mktree(tmpworkdir, SIMPLE_APP) config = Config(app_path='app.py') app_factory = config.import_app_factory() - app = app_factory() + app = await config.load_app(app_factory) modify_main_app(app, config) assert isinstance(app, aiohttp.web.Application) cli = await aiohttp_client(app) diff --git a/tests/test_runserver_serve.py b/tests/test_runserver_serve.py index b368e484..28b46e09 100644 --- a/tests/test_runserver_serve.py +++ b/tests/test_runserver_serve.py @@ -6,7 +6,7 @@ from unittest.mock import MagicMock import pytest -from aiohttp.web_app import Application +from aiohttp.web import Application, Request, Response from pytest_toolbox import mktree from aiohttp_devtools.exceptions import AiohttpDevException @@ -147,7 +147,7 @@ def test_modify_main_app_all_off(tmpworkdir): app = DummyApplication() subapp = DummyApplication() app.add_subapp("/sub/", subapp) - modify_main_app(app, config) + modify_main_app(app, config) # type: ignore[arg-type] assert len(app.on_response_prepare) == 0 assert len(app.middlewares) == 0 assert app['static_root_url'] == 'http://foobar.com:8001/static' @@ -161,7 +161,7 @@ def test_modify_main_app_all_on(tmpworkdir): app = DummyApplication() subapp = DummyApplication() app.add_subapp("/sub/", subapp) - modify_main_app(app, config) + modify_main_app(app, config) # type: ignore[arg-type] assert len(app.on_response_prepare) == 1 assert len(app.middlewares) == 1 assert app['static_root_url'] == 'http://localhost:8001/static' @@ -173,11 +173,11 @@ async def test_modify_main_app_on_prepare(tmpworkdir): mktree(tmpworkdir, SIMPLE_APP) config = Config(app_path='app.py', host='foobar.com') app = DummyApplication() - modify_main_app(app, config) + modify_main_app(app, config) # type: ignore[arg-type] on_prepare = app.on_response_prepare[0] - request = MagicMock() + request = MagicMock(spec=Request) request.path = '/' - response = MagicMock() + response = MagicMock(spec=Response) response.body = b'