Skip to content

Commit

Permalink
Improve typing (#401)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dreamsorcerer authored Jan 7, 2022
1 parent b696dac commit e57cdc8
Show file tree
Hide file tree
Showing 14 changed files with 205 additions and 162 deletions.
8 changes: 4 additions & 4 deletions .mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
files = aiohttp_devtools, tests
check_untyped_defs = True
follow_imports_for_stubs = True
#disallow_any_decorated = True
disallow_any_decorated = True
disallow_any_generics = True
#disallow_incomplete_defs = True
disallow_incomplete_defs = True
disallow_subclassing_any = True
#disallow_untyped_calls = True
disallow_untyped_calls = True
disallow_untyped_decorators = True
#disallow_untyped_defs = True
disallow_untyped_defs = True
implicit_reexport = False
no_implicit_optional = True
show_error_codes = True
Expand Down
7 changes: 4 additions & 3 deletions aiohttp_devtools/cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys
import traceback
from typing import Any

import click
from aiohttp.web import run_app
Expand All @@ -18,7 +19,7 @@

@click.group()
@click.version_option(__version__, "-V", "--version", prog_name="aiohttp-devtools")
def cli():
def cli() -> None:
pass


Expand All @@ -32,7 +33,7 @@ def cli():
@click.option('--livereload/--no-livereload', envvar='AIO_LIVERELOAD', default=True, help=livereload_help)
@click.option('-p', '--port', default=8000, type=int)
@click.option('-v', '--verbose', is_flag=True, help=verbose_help)
def serve(path, livereload, port, verbose):
def serve(path: str, livereload: bool, port: int, verbose: bool) -> None:
"""
Serve static files from a directory.
"""
Expand Down Expand Up @@ -68,7 +69,7 @@ def serve(path, livereload, port, verbose):
@click.option('--aux-port', envvar='AIO_AUX_PORT', type=click.INT, help=aux_port_help)
@click.option('-v', '--verbose', is_flag=True, help=verbose_help)
@click.argument('project_args', nargs=-1)
def runserver(**config):
def runserver(**config: Any) -> None:
"""
Run a development server for an aiohttp apps.
Expand Down
26 changes: 16 additions & 10 deletions aiohttp_devtools/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,18 @@
import logging.config
import platform
import re
import sys
import traceback
from io import StringIO
from logging import LogRecord
from types import TracebackType
from typing import Dict, Optional, Tuple, Type, Union

if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
from typing import Literal

import pygments
from devtools import pformat
from devtools.ansi import isatty, sformat
Expand All @@ -33,32 +40,31 @@


class DefaultFormatter(logging.Formatter):
def __init__(self, fmt=None, datefmt=None, style='%'):
def __init__(self, fmt: Optional[str] = None, datefmt: Optional[str] = None, style: Literal["%", "{", "$"] = "%"):
super().__init__(fmt, datefmt, style)
self.stream_is_tty = False

def format(self, record):
def format(self, record: LogRecord) -> str:
msg = super().format(record)
if not self.stream_is_tty:
return msg
m = split_log.match(msg)
log_color = LOG_FORMATS.get(record.levelno, sformat.red)
if m:
time = sformat(m.groups()[0], sformat.magenta)
return time + sformat(msg[m.end():], log_color)
return time + sformat(msg[m.end():], log_color) # type: ignore[no-any-return]

return sformat(msg, log_color)
return sformat(msg, log_color) # type: ignore[no-any-return]


class AccessFormatter(logging.Formatter):
"""
Used to log aiohttp_access and aiohttp_server
"""
def __init__(self, fmt=None, datefmt=None, style='%'):
"""Used to log aiohttp_access and aiohttp_server."""

def __init__(self, fmt: Optional[str] = None, datefmt: Optional[str] = None, style: Literal["%", "{", "$"] = "%"):
super().__init__(fmt, datefmt, style)
self.stream_is_tty = False

def formatMessage(self, record):
def formatMessage(self, record: LogRecord) -> str:
msg = super().formatMessage(record)
if msg[0] != '{':
return msg
Expand Down Expand Up @@ -174,6 +180,6 @@ def log_config(verbose: bool) -> Dict[str, object]:
}


def setup_logging(verbose):
def setup_logging(verbose: bool) -> None:
config = log_config(verbose)
logging.config.dictConfig(config)
72 changes: 38 additions & 34 deletions aiohttp_devtools/runserver/config.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
import asyncio
import inspect
import re
import sys
from importlib import import_module
from pathlib import Path
from typing import Optional
from typing import Awaitable, Callable, Optional, Union

from aiohttp import web

from ..exceptions import AiohttpDevConfigError as AdevConfigError
from ..logs import rs_dft_logger as logger

AppFactory = Union[web.Application, Callable[[], web.Application], Callable[[], Awaitable[web.Application]]]

STD_FILE_NAMES = [
re.compile(r'main\.py'),
re.compile(r'app\.py'),
Expand Down Expand Up @@ -56,9 +57,12 @@ def __init__(self, *,
self.settings_found = False

self.py_file = self._resolve_path(str(self.app_path), 'is_file', 'app-path')
self.python_path = self._resolve_path(python_path, 'is_dir', 'python-path') or self.root_path
if python_path:
self.python_path = self._resolve_path(python_path, "is_dir", "python-path")
else:
self.python_path = self.root_path

self.static_path = self._resolve_path(static_path, 'is_dir', 'static-path')
self.static_path = self._resolve_path(static_path, "is_dir", "static-path") if static_path else None
self.static_url = static_url
self.livereload = livereload
self.app_factory_name = app_factory_name
Expand All @@ -70,7 +74,7 @@ def __init__(self, *,

@property
def static_path_str(self) -> Optional[str]:
return self.static_path and str(self.static_path)
return str(self.static_path) if self.static_path else None

def _find_app_path(self, app_path: str) -> Path:
# for backwards compatibility try this first
Expand All @@ -94,10 +98,7 @@ def _find_app_path(self, app_path: str) -> Path:
raise AdevConfigError('unable to find a recognised default file ("app.py" or "main.py") '
'in the directory "%s"' % app_path)

def _resolve_path(self, _path: Optional[str], check: str, arg_name: str):
if _path is None:
return

def _resolve_path(self, _path: str, check: str, arg_name: str) -> Path:
if _path.startswith('/'):
path = Path(_path)
error_msg = '{arg_name} "{path}" is not a valid path'
Expand All @@ -119,11 +120,11 @@ def _resolve_path(self, _path: Optional[str], check: str, arg_name: str):
raise AdevConfigError('{} is not a directory'.format(path))
return path

def import_app_factory(self):
"""
Import attribute/class from from a python module. Raise AdevConfigError if the import failed.
def import_app_factory(self) -> AppFactory:
"""Import and return attribute/class from a python module.
:return: (attribute, Path object for directory of file)
Raises:
AdevConfigError - If the import failed.
"""
rel_py_file = self.py_file.relative_to(self.python_path)
module_path = '.'.join(rel_py_file.with_suffix('').parts)
Expand All @@ -145,36 +146,39 @@ def import_app_factory(self):

try:
attr = getattr(module, self.app_factory_name)
except AttributeError as e:
raise AdevConfigError('Module "{s.py_file.name}" '
'does not define a "{s.app_factory_name}" attribute/class'.format(s=self)) from e
except AttributeError:
raise AdevConfigError("Module '{}' does not define a '{}' attribute/class".format(
self.py_file.name, self.app_factory_name))

if not isinstance(attr, web.Application) and not callable(attr):
raise AdevConfigError("'{}.{}' is not an Application or callable".format(
self.py_file.name, self.app_factory_name))

if callable(attr):
required_args = attr.__code__.co_argcount - len(attr.__defaults__)
if required_args > 0:
raise AdevConfigError("'{}.{}' should not have required arguments.".format(
self.py_file.name, self.app_factory_name))

self.watch_path = self.watch_path or Path(module.__file__ or ".").parent
return attr
return attr # type: ignore[no-any-return]

async def load_app(self, app_factory):
async def load_app(self, app_factory: AppFactory) -> web.Application:
if isinstance(app_factory, web.Application):
app = app_factory
else:
# app_factory should be a proper factory with signature (loop): -> Application
signature = inspect.signature(app_factory)
if 'loop' in signature.parameters:
loop = asyncio.get_event_loop()
app = app_factory(loop=loop)
else:
# loop argument missing, assume no arguments
app = app_factory()
return app_factory

app = app_factory()

if asyncio.iscoroutine(app):
app = await app
if asyncio.iscoroutine(app):
app = await app

if not isinstance(app, web.Application):
raise AdevConfigError('app factory "{.app_factory_name}" returned "{.__class__.__name__}" not an '
'aiohttp.web.Application'.format(self, app))
if not isinstance(app, web.Application):
raise AdevConfigError("app factory '{}' returned '{}' not an aiohttp.web.Application".format(
self.app_factory_name, app.__class__.__name__))

return app

def __str__(self):
def __str__(self) -> str:
fields = ('py_file', 'static_path', 'static_url', 'livereload',
'app_factory_name', 'host', 'main_port', 'aux_port')
return 'Config:\n' + '\n'.join(' {0}: {1!r}'.format(f, getattr(self, f)) for f in fields)
83 changes: 45 additions & 38 deletions aiohttp_devtools/runserver/log_handlers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import json
import warnings
from datetime import datetime, timedelta
from typing import cast
from typing import Dict, Optional, Union, cast

from aiohttp import web
from aiohttp.abc import AbstractAccessLogger

dbtb = '/_debugtoolbar/'
Expand All @@ -12,13 +13,13 @@
class _AccessLogger(AbstractAccessLogger):
prefix: str

def get_msg(self, request, response, time):
def get_msg(self, request: web.BaseRequest, response: web.StreamResponse, time: float) -> Optional[str]:
raise NotImplementedError()

def extra(self, request, response, time):
def extra(self, request: web.BaseRequest, response: web.StreamResponse, time: float) -> Optional[Dict[str, object]]:
pass

def log(self, request, response, time):
def log(self, request: web.BaseRequest, response: web.StreamResponse, time: float) -> None:
msg = self.get_msg(request, response, time)
if not msg:
return
Expand All @@ -39,7 +40,7 @@ def log(self, request, response, time):
class AccessLogger(_AccessLogger):
prefix = '●'

def get_msg(self, request, response, time):
def get_msg(self, request: web.BaseRequest, response: web.StreamResponse, time: float) -> str:
return '{method} {path} {code} {size} {ms:0.0f}ms'.format(
method=request.method,
path=request.path_qs,
Expand All @@ -48,35 +49,40 @@ def get_msg(self, request, response, time):
ms=time * 1000,
)

def extra(self, request, response, time: float):
if response.status > 310:
request_body = request._read_bytes
details = dict(
request_duration_ms=round(time * 1000, 3),
request_headers=dict(request.headers),
request_body=parse_body(request_body, 'request body'),
request_size=fmt_size(0 if request_body is None else len(request_body)),
response_headers=dict(response.headers),
response_body=parse_body(response.text, "response body"),
)
return dict(details=details)
def extra(self, request: web.BaseRequest, response: web.StreamResponse, time: float) -> Optional[Dict[str, object]]:
if response.status <= 310:
return None

request_body = request._read_bytes
body_text = response.text if isinstance(response, web.Response) else None
details = {
"request_duration_ms": round(time * 1000, 3),
"request_headers": dict(request.headers),
"request_body": parse_body(request_body, "request body"),
"request_size": fmt_size(0 if request_body is None else len(request_body)),
"response_headers": dict(response.headers),
"response_body": parse_body(body_text, "response body"),
}
return {"details": details}


class AuxAccessLogger(_AccessLogger):
prefix = '◆'

def get_msg(self, request, response, time):
def get_msg(self, request: web.BaseRequest, response: web.StreamResponse, time: float) -> Optional[str]:
# don't log livereload
if request.path not in {'/livereload', '/livereload.js'}:
return '{method} {path} {code} {size}'.format(
method=request.method,
path=request.path_qs,
code=response.status,
size=fmt_size(response.body_length),
)
if request.path in {"/livereload", "/livereload.js"}:
return None

return "{method} {path} {code} {size}".format(
method=request.method,
path=request.path_qs,
code=response.status,
size=fmt_size(response.body_length),
)


def fmt_size(num):
def fmt_size(num: int) -> str:
if not num:
return ''
if num < 1024:
Expand All @@ -85,15 +91,16 @@ def fmt_size(num):
return '{:0.1f}KB'.format(num / 1024)


def parse_body(v, name):
if isinstance(v, (str, bytes)):
try:
return json.loads(v)
except UnicodeDecodeError:
v = cast(bytes, v) # UnicodeDecodeError only occurs on bytes.
warnings.warn('UnicodeDecodeError parsing ' + name, UserWarning)
# bytes which cause UnicodeDecodeError can cause problems later on
return v.decode(errors='ignore')
except (ValueError, TypeError):
pass
return v
def parse_body(v: Union[str, bytes, None], name: str) -> object:
if v is None:
return v

try:
return json.loads(v)
except UnicodeDecodeError:
v = cast(bytes, v) # UnicodeDecodeError only occurs on bytes.
warnings.warn("UnicodeDecodeError parsing " + name, UserWarning)
# bytes which cause UnicodeDecodeError can cause problems later on
return v.decode(errors="ignore")
except (ValueError, TypeError):
return v
Loading

0 comments on commit e57cdc8

Please sign in to comment.