From d3cbee43c3e0e85a73f235c8c00bfe2d6246c4f3 Mon Sep 17 00:00:00 2001 From: RomanZhukov Date: Sat, 25 Jan 2025 00:02:46 +0500 Subject: [PATCH] linter and test failures correction --- aiohttp_devtools/runserver/config.py | 11 ++++++----- aiohttp_devtools/runserver/main.py | 7 +++++-- aiohttp_devtools/runserver/serve.py | 8 ++++---- 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/aiohttp_devtools/runserver/config.py b/aiohttp_devtools/runserver/config.py index 5d433209..5bfb955e 100644 --- a/aiohttp_devtools/runserver/config.py +++ b/aiohttp_devtools/runserver/config.py @@ -4,9 +4,10 @@ from importlib import import_module from pathlib import Path from typing import Awaitable, Callable, Optional, Union, Literal +from types import ModuleType from aiohttp import web -import ssl +from ssl import SSLContext import __main__ from ..exceptions import AiohttpDevConfigError as AdevConfigError @@ -145,7 +146,7 @@ def _resolve_path(self, _path: str, check: str, arg_name: str) -> Path: raise AdevConfigError('{} is not a directory'.format(path)) return path - def import_module(self): + def import_module(self) -> ModuleType: """Import and return python module. Raises: @@ -164,7 +165,7 @@ def import_module(self): self.watch_path = self.watch_path or Path(module.__file__ or ".").parent return module - def get_app_factory(self, module) -> AppFactory: + def get_app_factory(self, module: ModuleType) -> AppFactory: """Return attribute/class from a python module. Raises: @@ -199,7 +200,7 @@ def get_app_factory(self, module) -> AppFactory: return attr # type: ignore[no-any-return] - def get_ssl_context(self, module) -> ssl.SSLContext: + def get_ssl_context(self, module: ModuleType) -> Union[SSLContext, None]: if self.ssl_context_factory_name is None: return None else: @@ -209,7 +210,7 @@ def get_ssl_context(self, module) -> ssl.SSLContext: raise AdevConfigError("Module '{}' does not define a '{}' attribute/class".format( self.py_file.name, self.ssl_context_factory_name)) ssl_context = attr() - if isinstance(ssl_context, ssl.SSLContext): + if isinstance(ssl_context, SSLContext): return ssl_context else: raise AdevConfigError("ssl-context-factory '{}' in module '{}' didn't return valid SSLContext".format( diff --git a/aiohttp_devtools/runserver/main.py b/aiohttp_devtools/runserver/main.py index a919e7ef..69fee106 100644 --- a/aiohttp_devtools/runserver/main.py +++ b/aiohttp_devtools/runserver/main.py @@ -1,7 +1,7 @@ import asyncio import os from multiprocessing import set_start_method -from typing import Any, Type, TypedDict +from typing import Any, Type, TypedDict, Union from aiohttp.abc import AbstractAccessLogger from aiohttp.web import Application @@ -11,6 +11,7 @@ from .log_handlers import AuxAccessLogger from .serve import check_port_open, create_auxiliary_app from .watch import AppTask, LiveReloadTask +from ssl import SSLContext class RunServer(TypedDict): @@ -19,6 +20,8 @@ class RunServer(TypedDict): port: int shutdown_timeout: float access_log_class: Type[AbstractAccessLogger] + ssl_context: Union[SSLContext, None] + def runserver(**config_kwargs: Any) -> RunServer: @@ -75,4 +78,4 @@ def serve_static(*, static_path: str, livereload: bool = True, bind_address: str livereload_status = 'ON' if livereload else 'OFF' logger.info('Serving "%s" at http://%s:%d, livereload %s', static_path, bind_address, port, livereload_status) return {"app": app, "host": bind_address, "port": port, - "shutdown_timeout": 0.01, "access_log_class": AuxAccessLogger} + "shutdown_timeout": 0.01, "access_log_class": AuxAccessLogger, "ssl_context": None} diff --git a/aiohttp_devtools/runserver/serve.py b/aiohttp_devtools/runserver/serve.py index f82c3d41..0b921902 100644 --- a/aiohttp_devtools/runserver/serve.py +++ b/aiohttp_devtools/runserver/serve.py @@ -7,7 +7,7 @@ import warnings from errno import EADDRINUSE from pathlib import Path -from typing import Any, Iterator, List, NoReturn, Optional, Set, Tuple +from typing import Any, Iterator, List, NoReturn, Optional, Set, Tuple, Union from aiohttp import WSMsgType, web from aiohttp.hdrs import LAST_MODIFIED, CONTENT_LENGTH @@ -25,7 +25,7 @@ from .log_handlers import AccessLogger from .utils import MutableValue -import ssl +from ssl import SSLContext try: from aiohttp_jinja2 import static_root_key @@ -173,7 +173,7 @@ def serve_main_app(config: Config, tty_path: Optional[str]) -> None: with asyncio.Runner() as runner: app_runner = runner.run(create_main_app(config, app_factory)) try: - runner.run(start_main_app(app_runner, config.bind_address, config.main_port)) + runner.run(start_main_app(app_runner, config.bind_address, config.main_port, ssl_context)) runner.get_loop().run_forever() except KeyboardInterrupt: pass @@ -201,7 +201,7 @@ async def create_main_app(config: Config, app_factory: AppFactory) -> web.AppRun return web.AppRunner(app, access_log_class=AccessLogger, shutdown_timeout=0.1) -async def start_main_app(runner: web.AppRunner, host: str, port: int, ssl_context: ssl.SSLContext) -> None: +async def start_main_app(runner: web.AppRunner, host: str, port: int, ssl_context: Union[SSLContext, None]) -> None: await runner.setup() site = web.TCPSite(runner, host=host, port=port, ssl_context=ssl_context) await site.start()