Skip to content

Commit

Permalink
Fix typing in remaining files (#397)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dreamsorcerer authored Jan 5, 2022
1 parent 336b9fa commit b696dac
Show file tree
Hide file tree
Showing 11 changed files with 100 additions and 77 deletions.
9 changes: 1 addition & 8 deletions .mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,16 @@ disallow_any_unimported = True
warn_return_any = True

[mypy-tests.*]
ignore_errors = True
disallow_any_decorated = False
disallow_untyped_calls = False
disallow_untyped_defs = False

[mypy-aiohttp_devtools.logs]
ignore_errors = True

[mypy-aiohttp_devtools.runserver.*]
ignore_errors = True


[mypy-aiohttp_debugtoolbar.*]
ignore_missing_imports = True

[mypy-devtools.*]
ignore_missing_imports = True

[mypy-pygments.*]
[mypy-pytest_toolbox.*]
ignore_missing_imports = True
35 changes: 21 additions & 14 deletions aiohttp_devtools/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,17 @@
import re
import traceback
from io import StringIO
from types import TracebackType
from typing import Dict, Optional, Tuple, Type, Union

import pygments
from devtools import pformat
from devtools.ansi import isatty, sformat
from pygments.formatters import Terminal256Formatter
from pygments.lexers import Python3TracebackLexer

_Ei = Union[Tuple[Type[BaseException], BaseException, Optional[TracebackType]], Tuple[None, None, None]]

rs_dft_logger = logging.getLogger('adev.server.dft')
rs_aux_logger = logging.getLogger('adev.server.aux')

Expand All @@ -28,12 +32,6 @@
split_log = re.compile(r'^(\[.*?\])')


class HighlightStreamHandler(logging.StreamHandler):
def setFormatter(self, fmt):
self.formatter = fmt
self.formatter.stream_is_tty = isatty(self.stream) and platform.system().lower() != 'windows'


class DefaultFormatter(logging.Formatter):
def __init__(self, fmt=None, datefmt=None, style='%'):
super().__init__(fmt, datefmt, style)
Expand All @@ -48,8 +46,8 @@ def format(self, record):
if m:
time = sformat(m.groups()[0], sformat.magenta)
return time + sformat(msg[m.end():], log_color)
else:
return sformat(msg, log_color)

return sformat(msg, log_color)


class AccessFormatter(logging.Formatter):
Expand Down Expand Up @@ -80,18 +78,27 @@ def formatMessage(self, record):
msg = 'details: {}\n{}'.format(pformat(details, highlight=self.stream_is_tty), msg)
return msg

def formatException(self, ei):
def formatException(self, ei: _Ei) -> str:
sio = StringIO()
traceback.print_exception(*ei, file=sio)
traceback.print_exception(*ei, file=sio) # type: ignore[misc]
stack = sio.getvalue()
sio.close()
if self.stream_is_tty and pyg_lexer:
return pygments.highlight(stack, lexer=pyg_lexer, formatter=pyg_formatter).rstrip('\n')
else:
return stack
return pygments.highlight(stack, lexer=pyg_lexer, formatter=pyg_formatter).rstrip("\n") # type: ignore[no-any-return] # noqa

return stack


class HighlightStreamHandler(logging.StreamHandler): # type: ignore[type-arg]
def setFormatter(self, fmt: Optional[logging.Formatter]) -> None:
stream_is_tty = isatty(self.stream) and platform.system().lower() != "windows"
if isinstance(fmt, (DefaultFormatter, AccessFormatter)):
fmt.stream_is_tty = stream_is_tty

self.formatter = fmt


def log_config(verbose: bool) -> dict:
def log_config(verbose: bool) -> Dict[str, object]:
"""
Setup default config. for dictConfig.
:param verbose: level: DEBUG if True, INFO if False
Expand Down
4 changes: 2 additions & 2 deletions aiohttp_devtools/runserver/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def import_app_factory(self):
raise AdevConfigError('Module "{s.py_file.name}" '
'does not define a "{s.app_factory_name}" attribute/class'.format(s=self)) from e

self.watch_path = self.watch_path or Path(module.__file__).parent
self.watch_path = self.watch_path or Path(module.__file__ or ".").parent
return attr

async def load_app(self, app_factory):
Expand All @@ -166,7 +166,7 @@ async def load_app(self, app_factory):
app = app_factory()

if asyncio.iscoroutine(app):
app = await app # type: ignore[misc]
app = await app

if not isinstance(app, web.Application):
raise AdevConfigError('app factory "{.app_factory_name}" returned "{.__class__.__name__}" not an '
Expand Down
2 changes: 1 addition & 1 deletion aiohttp_devtools/runserver/log_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def extra(self, request, response, time: float):
request_body=parse_body(request_body, 'request body'),
request_size=fmt_size(0 if request_body is None else len(request_body)),
response_headers=dict(response.headers),
response_body=parse_body(response.text or response.body, 'response body'),
response_body=parse_body(response.text, "response body"),
)
return dict(details=details)

Expand Down
10 changes: 6 additions & 4 deletions aiohttp_devtools/runserver/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,13 @@ async def static_middleware(request, handler):
_set_static_url(app, static_url)


async def check_port_open(port: int, delay: int = 1) -> None:
async def check_port_open(port: int, delay: float = 1) -> None:
loop = asyncio.get_running_loop()
# the "s = socket.socket; s.bind" approach sometimes says a port is in use when it's not
# this approach replicates aiohttp so should always give the same answer
for i in range(5, 0, -1):
try:
server = await loop.create_server(asyncio.Protocol(), host=HOST, port=port)
server = await loop.create_server(asyncio.Protocol, host=HOST, port=port)
except OSError as e:
if e.errno != EADDRINUSE:
raise
Expand Down Expand Up @@ -198,7 +198,8 @@ async def cleanup_aux_app(app):
await asyncio.gather(*(ws.close() for ws, _ in app[WS]))


def create_auxiliary_app(*, static_path: str, static_url='/', livereload=True):
def create_auxiliary_app(
*, static_path: Optional[str], static_url="/", livereload=True) -> web.Application:
app = web.Application()
app[WS] = set()
app.update(
Expand Down Expand Up @@ -330,7 +331,8 @@ def _insert_footer(self, response):
body = f.read() + LIVE_RELOAD_LOCAL_SNIPPET

resp = Response(body=body, content_type='text/html')
resp.last_modified = filepath.stat().st_mtime
# Mypy bug: https://github.com/python/mypy/issues/11892
resp.last_modified = filepath.stat().st_mtime # type: ignore[assignment]
return resp

async def _handle(self, request):
Expand Down
27 changes: 18 additions & 9 deletions aiohttp_devtools/runserver/watch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import signal
import sys
from multiprocessing import Process
from pathlib import Path
from typing import AsyncIterator, Optional, Union

from aiohttp import ClientSession, web
from watchgod import awatch
Expand All @@ -14,17 +16,17 @@


class WatchTask:
def __init__(self, path: str):
self._app = None
self._task = None
def __init__(self, path: Union[Path, str, None]):
self._app: Optional[web.Application] = None
self._task: Optional[asyncio.Task[None]] = None
assert path
self._path = path

async def start(self, app: web.Application) -> None:
self._app = app
self.stopper = asyncio.Event()
self._awatch = awatch(self._path, stop_event=self.stopper)
self._task = asyncio.get_event_loop().create_task(self._run())
self._task = asyncio.create_task(self._run())

async def _run(self):
raise NotImplementedError()
Expand All @@ -37,7 +39,7 @@ async def close(self, *args):
self._task.result()
self._task.cancel()

async def cleanup_ctx(self, app: web.Application) -> None:
async def cleanup_ctx(self, app: web.Application) -> AsyncIterator[None]:
await self.start(app)
yield
await self.close(app)
Expand All @@ -49,11 +51,13 @@ class AppTask(WatchTask):
def __init__(self, config: Config):
self._config = config
self._reloads = 0
self._session = None
self._session: Optional[ClientSession] = None
self._runner = None
super().__init__(self._config.watch_path)

async def _run(self, live_checks=20):
assert self._app is not None

self._session = ClientSession()
try:
self._start_dev_server()
Expand Down Expand Up @@ -81,7 +85,9 @@ def is_static(changes):
await self._session.close()
raise AiohttpDevException('error running dev server')

async def _src_reload_when_live(self, checks=20):
async def _src_reload_when_live(self, checks=20) -> None:
assert self._app is not None and self._session is not None

if self._app[WS]:
url = 'http://localhost:{.main_port}/?_checking_alive=1'.format(self._config)
logger.debug('checking app at "%s" is running before prompting reload...', url)
Expand Down Expand Up @@ -116,11 +122,12 @@ def _start_dev_server(self):
def _stop_dev_server(self):
if self._process.is_alive():
logger.debug('stopping server process...')
os.kill(self._process.pid, signal.SIGINT)
if self._process.pid:
os.kill(self._process.pid, signal.SIGINT)
self._process.join(5)
if self._process.exitcode is None:
logger.warning('process has not terminated, sending SIGKILL')
os.kill(self._process.pid, signal.SIGKILL)
self._process.kill()
self._process.join(1)
else:
logger.debug('process stopped')
Expand All @@ -130,6 +137,8 @@ def _stop_dev_server(self):
async def close(self, *args):
self.stopper.set()
self._stop_dev_server()
if self._session is None:
raise RuntimeError("Object not started correctly before calling .close()")
await asyncio.gather(super().close(), self._session.close())


Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ pytest-timeout==2.0.2
pytest-toolbox==0.4
pytest-xdist==2.5.0
Sphinx==4.3.2
types-pygments==2.9.10
20 changes: 10 additions & 10 deletions tests/test_runserver_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@

def test_aiohttp_std():
info = MagicMock()
logger = type('Logger', (), {'info': info})
logger = AccessLogger(logger, None)
logger_type = type("Logger", (), {"info": info})
logger = AccessLogger(logger_type(), "")
request = MagicMock()
request.method = 'GET'
request.path_qs = '/foobar?v=1'
Expand All @@ -34,8 +34,8 @@ def test_aiohttp_std():

def test_aiohttp_debugtoolbar():
info = MagicMock()
logger = type('Logger', (), {'info': info})
logger = AccessLogger(logger, None)
logger_type = type("Logger", (), {"info": info})
logger = AccessLogger(logger_type(), "")
request = MagicMock()
request.method = 'GET'
request.path_qs = '/_debugtoolbar/whatever'
Expand All @@ -56,8 +56,8 @@ def test_aiohttp_debugtoolbar():

def test_aux_logger():
info = MagicMock()
logger = type('Logger', (), {'info': info})
logger = AuxAccessLogger(logger, None)
logger_type = type("Logger", (), {"info": info})
logger = AuxAccessLogger(logger_type(), "")
request = MagicMock()
request.method = 'GET'
request.path = '/'
Expand All @@ -79,8 +79,8 @@ def test_aux_logger():

def test_aux_logger_livereload():
info = MagicMock()
logger = type('Logger', (), {'info': info})
logger = AuxAccessLogger(logger, None)
logger_type = type("Logger", (), {"info": info})
logger = AuxAccessLogger(logger_type(), "")
request = MagicMock()
request.method = 'GET'
request.path = '/livereload.js'
Expand All @@ -94,8 +94,8 @@ def test_aux_logger_livereload():

def test_extra():
info = MagicMock()
logger = type('Logger', (), {'info': info})
logger = AccessLogger(logger, None)
logger_type = type("Logger", (), {"info": info})
logger = AccessLogger(logger_type(), "")
request = MagicMock()
request.method = 'GET'
request.headers = {'Foo': 'Bar'}
Expand Down
4 changes: 2 additions & 2 deletions tests/test_runserver_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ async def test_serve_main_app(tmpworkdir, loop, mocker):
loop.call_later(0.5, loop.stop)

config = Config(app_path='app.py')
runner = await create_main_app(config, config.import_app_factory(), loop)
runner = await create_main_app(config, config.import_app_factory())
await start_main_app(runner, config.main_port)

mock_modify_main_app.assert_called_with(mock.ANY, config)
Expand All @@ -171,7 +171,7 @@ async def hello(request):
mock_modify_main_app = mocker.patch('aiohttp_devtools.runserver.serve.modify_main_app')

config = Config(app_path='app.py')
runner = await create_main_app(config, config.import_app_factory(), loop)
runner = await create_main_app(config, config.import_app_factory())
await start_main_app(runner, config.main_port)

mock_modify_main_app.assert_called_with(mock.ANY, config)
Expand Down
5 changes: 4 additions & 1 deletion tests/test_runserver_serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pathlib
import socket
from platform import system as get_os_family
from typing import Dict
from unittest.mock import MagicMock

import pytest
Expand Down Expand Up @@ -126,7 +127,9 @@ def test_fmt_size_large(value, result):
assert fmt_size(value) == result


class DummyApplication(dict):
class DummyApplication(Dict[str, object]):
_debug = False

def __init__(self):
self.on_response_prepare = []
self.middlewares = []
Expand Down
Loading

0 comments on commit b696dac

Please sign in to comment.