diff --git a/README.rst b/README.rst index 40e4bf63..c30891eb 100644 --- a/README.rst +++ b/README.rst @@ -48,6 +48,11 @@ All ``runserver`` arguments can be set via environment variables. * **livereload** will reload resources in the browser as your code changes without having to hit refresh, see `livereload`_ for more details. * **static files** are served separately from your main app (generally on ``8001`` while your app is on ``8000``) so you don't have to contaminate your application to serve static files you only need locally. +The ``--ssl-context-factory`` option can be used to define method from the app path file, which returns ssl.SSLContext +for ssl support. +If You are going to use self-signed certificate for your dev server, you should install proper rootCA certificate to your system. +Or you can use ``--ssl-rootcert`` option. If proper rootCA certificate is not installed or specified by option, livereload feature will not work. + For more options see ``adev runserver --help``. serve diff --git a/aiohttp_devtools/cli.py b/aiohttp_devtools/cli.py index 37080cb5..a22afa0e 100644 --- a/aiohttp_devtools/cli.py +++ b/aiohttp_devtools/cli.py @@ -64,6 +64,10 @@ def serve(path: str, livereload: bool, bind_address: str, port: int, verbose: bo 'or just an instance of aiohttp.Application. env variable AIO_APP_FACTORY') port_help = 'Port to serve app from, default 8000. env variable: AIO_PORT' aux_port_help = 'Port to serve auxiliary app (reload and static) on, default port + 1. env variable: AIO_AUX_PORT' +ssl_context_factory_help = ("name of the ssl context factory to create ssl.SSLContext with. " + "env variable: AIO_SSL_CONTEXT_FACTORY") +ssl_rootcert_file_help = ("path to a rootCA certificate file for self-signed cert chain (if needed). " + "env variable: AIO_SSL_ROOTCERT") # defaults are all None here so default settings are defined in one place: DEV_DICT validation @@ -83,6 +87,10 @@ def serve(path: str, livereload: bool, bind_address: str, port: int, verbose: bo @click.option('-v', '--verbose', is_flag=True, help=verbose_help) @click.option("--browser-cache/--no-browser-cache", envvar="AIO_BROWSER_CACHE", default=None, help=browser_cache_help) +@click.option("--ssl-context-factory", "ssl_context_factory_name", envvar="AIO_SSL_CONTEXT_FACTORY", default=None, + help=ssl_context_factory_help) +@click.option("--ssl-rootcert", "ssl_rootcert_file_path", envvar="AIO_SSL_ROOTCERT", default=None, + help=ssl_rootcert_file_help) @click.argument('project_args', nargs=-1) def runserver(**config: Any) -> None: """ diff --git a/aiohttp_devtools/runserver/config.py b/aiohttp_devtools/runserver/config.py index 5c65b17d..eed7c392 100644 --- a/aiohttp_devtools/runserver/config.py +++ b/aiohttp_devtools/runserver/config.py @@ -3,9 +3,11 @@ import sys from importlib import import_module from pathlib import Path -from typing import Awaitable, Callable, Optional, Union +from typing import Awaitable, Callable, Literal, Optional, Union +from types import ModuleType from aiohttp import web +from ssl import SSLContext, SSLError, create_default_context as create_default_ssl_context import __main__ from ..exceptions import AiohttpDevConfigError as AdevConfigError @@ -45,7 +47,9 @@ def __init__(self, *, bind_address: str = "localhost", main_port: int = 8000, aux_port: Optional[int] = None, - browser_cache: bool = False): + browser_cache: bool = False, + ssl_context_factory_name: Optional[str] = None, + ssl_rootcert_file_path: Optional[str] = None): if root_path: self.root_path = Path(root_path).resolve() logger.debug('Root path specified: %s', self.root_path) @@ -86,12 +90,32 @@ def __init__(self, *, self.main_port = main_port self.aux_port = aux_port or (main_port + 1) self.browser_cache = browser_cache + self.ssl_context_factory_name = ssl_context_factory_name + self.ssl_rootcert_file_path = ssl_rootcert_file_path logger.debug('config loaded:\n%s', self) + @property + def protocol(self) -> Literal["http", "https"]: + return "http" if self.ssl_context_factory_name is None else "https" + @property def static_path_str(self) -> Optional[str]: return str(self.static_path) if self.static_path else None + @property + def client_ssl_context(self) -> Union[SSLContext, None]: + client_ssl_context = None + if self.protocol == "https": + client_ssl_context = create_default_ssl_context() + if self.ssl_rootcert_file_path: + try: + client_ssl_context.load_verify_locations(self.ssl_rootcert_file_path) + except FileNotFoundError: + raise AdevConfigError("No such file or directory: {}".format(self.ssl_rootcert_file_path)) + except SSLError: + raise AdevConfigError("invalid root cert file: {}".format(self.ssl_rootcert_file_path)) + return client_ssl_context + def _find_app_path(self, app_path: str) -> Path: # for backwards compatibility try this first path = (self.root_path / app_path).resolve() @@ -136,15 +160,14 @@ def _resolve_path(self, _path: str, check: str, arg_name: str) -> Path: raise AdevConfigError('{} is not a directory'.format(path)) return path - def import_app_factory(self) -> AppFactory: - """Import and return attribute/class from a python module. + def import_module(self) -> ModuleType: + """Import and return python module. Raises: AdevConfigError - If the import failed. """ rel_py_file = self.py_file.relative_to(self.python_path) module_path = '.'.join(rel_py_file.with_suffix('').parts) - sys.path.insert(0, str(self.python_path)) module = import_module(module_path) # Rewrite the package name, so it will appear the same as running the app. @@ -153,6 +176,16 @@ def import_app_factory(self) -> AppFactory: logger.debug('successfully loaded "%s" from "%s"', module_path, self.python_path) + self.watch_path = self.watch_path or Path(module.__file__ or ".").parent + return module + + def get_app_factory(self, module: ModuleType) -> AppFactory: + """Return attribute/class from a python module. + + Raises: + AdevConfigError - If the import failed. + """ + if self.app_factory_name is None: try: self.app_factory_name = next(an for an in APP_FACTORY_NAMES if hasattr(module, an)) @@ -179,9 +212,24 @@ def import_app_factory(self) -> AppFactory: raise AdevConfigError("'{}.{}' should not have required arguments.".format( self.py_file.name, self.app_factory_name)) - self.watch_path = self.watch_path or Path(module.__file__ or ".").parent return attr # type: ignore[no-any-return] + def get_ssl_context(self, module: ModuleType) -> Union[SSLContext, None]: + if self.ssl_context_factory_name is None: + return None + else: + try: + attr = getattr(module, self.ssl_context_factory_name) + except AttributeError: + raise AdevConfigError("Module '{}' does not define a '{}' attribute/class".format( + self.py_file.name, self.ssl_context_factory_name)) + ssl_context = attr() + if isinstance(ssl_context, SSLContext): + return ssl_context + else: + raise AdevConfigError("ssl-context-factory '{}' in module '{}' didn't return valid SSLContext".format( + self.ssl_context_factory_name, self.py_file.name)) + async def load_app(self, app_factory: AppFactory) -> web.Application: if isinstance(app_factory, web.Application): return app_factory diff --git a/aiohttp_devtools/runserver/main.py b/aiohttp_devtools/runserver/main.py index 0dd35052..86d2611b 100644 --- a/aiohttp_devtools/runserver/main.py +++ b/aiohttp_devtools/runserver/main.py @@ -1,7 +1,7 @@ import asyncio import os from multiprocessing import set_start_method -from typing import Any, Type, TypedDict +from typing import Any, Type, TypedDict, Union from aiohttp.abc import AbstractAccessLogger from aiohttp.web import Application @@ -11,6 +11,7 @@ from .log_handlers import AuxAccessLogger from .serve import check_port_open, create_auxiliary_app from .watch import AppTask, LiveReloadTask +from ssl import SSLContext class RunServer(TypedDict): @@ -19,6 +20,7 @@ class RunServer(TypedDict): port: int shutdown_timeout: float access_log_class: Type[AbstractAccessLogger] + ssl_context: Union[SSLContext, None] def runserver(**config_kwargs: Any) -> RunServer: @@ -29,9 +31,8 @@ def runserver(**config_kwargs: Any) -> RunServer: """ # force a full reload in sub processes so they load an updated version of code, this must be called only once set_start_method('spawn') - config = Config(**config_kwargs) - config.import_app_factory() + config.import_module() asyncio.run(check_port_open(config.main_port, host=config.bind_address)) @@ -57,7 +58,7 @@ def runserver(**config_kwargs: Any) -> RunServer: logger.info('serving static files from ./%s/ at %s%s', rel_path, url, config.static_url) return {"app": aux_app, "host": config.bind_address, "port": config.aux_port, - "shutdown_timeout": 0.01, "access_log_class": AuxAccessLogger} + "shutdown_timeout": 0.01, "access_log_class": AuxAccessLogger, "ssl_context": None} def serve_static(*, static_path: str, livereload: bool = True, bind_address: str = "localhost", port: int = 8000, @@ -75,4 +76,4 @@ def serve_static(*, static_path: str, livereload: bool = True, bind_address: str livereload_status = 'ON' if livereload else 'OFF' logger.info('Serving "%s" at http://%s:%d, livereload %s', static_path, bind_address, port, livereload_status) return {"app": app, "host": bind_address, "port": port, - "shutdown_timeout": 0.01, "access_log_class": AuxAccessLogger} + "shutdown_timeout": 0.01, "access_log_class": AuxAccessLogger, "ssl_context": None} diff --git a/aiohttp_devtools/runserver/serve.py b/aiohttp_devtools/runserver/serve.py index 9be8e6ba..4e58715b 100644 --- a/aiohttp_devtools/runserver/serve.py +++ b/aiohttp_devtools/runserver/serve.py @@ -7,7 +7,7 @@ import warnings from errno import EADDRINUSE from pathlib import Path -from typing import Any, Iterator, List, NoReturn, Optional, Set, Tuple +from typing import Any, Iterator, List, NoReturn, Optional, Set, Tuple, Union from aiohttp import WSMsgType, web from aiohttp.hdrs import LAST_MODIFIED, CONTENT_LENGTH @@ -25,6 +25,8 @@ from .log_handlers import AccessLogger from .utils import MutableValue +from ssl import SSLContext + try: from aiohttp_jinja2 import static_root_key except ImportError: @@ -120,7 +122,8 @@ def shutdown() -> NoReturn: path = config.path_prefix + "/shutdown" app.router.add_route("GET", path, do_shutdown, name="_devtools.shutdown") - dft_logger.debug("Created shutdown endpoint at http://{}:{}{}".format(config.host, config.main_port, path)) + dft_logger.debug("Created shutdown endpoint at {}://{}:{}{}".format( + config.protocol, config.host, config.main_port, path)) if config.static_path is not None: static_url = 'http://{}:{}/{}'.format(config.host, config.aux_port, static_path) @@ -164,12 +167,14 @@ def set_tty(tty_path: Optional[str]) -> Iterator[None]: def serve_main_app(config: Config, tty_path: Optional[str]) -> None: with set_tty(tty_path): setup_logging(config.verbose) - app_factory = config.import_app_factory() + module = config.import_module() + app_factory = config.get_app_factory(module) + ssl_context = config.get_ssl_context(module) if sys.version_info >= (3, 11): with asyncio.Runner() as runner: app_runner = runner.run(create_main_app(config, app_factory)) try: - runner.run(start_main_app(app_runner, config.bind_address, config.main_port)) + runner.run(start_main_app(app_runner, config.bind_address, config.main_port, ssl_context)) runner.get_loop().run_forever() except KeyboardInterrupt: pass @@ -180,7 +185,7 @@ def serve_main_app(config: Config, tty_path: Optional[str]) -> None: loop = asyncio.new_event_loop() runner = loop.run_until_complete(create_main_app(config, app_factory)) try: - loop.run_until_complete(start_main_app(runner, config.bind_address, config.main_port)) + loop.run_until_complete(start_main_app(runner, config.bind_address, config.main_port, ssl_context)) loop.run_forever() except KeyboardInterrupt: # pragma: no cover pass @@ -197,9 +202,9 @@ async def create_main_app(config: Config, app_factory: AppFactory) -> web.AppRun return web.AppRunner(app, access_log_class=AccessLogger, shutdown_timeout=0.1) -async def start_main_app(runner: web.AppRunner, host: str, port: int) -> None: +async def start_main_app(runner: web.AppRunner, host: str, port: int, ssl_context: Union[SSLContext, None]) -> None: await runner.setup() - site = web.TCPSite(runner, host=host, port=port) + site = web.TCPSite(runner, host=host, port=port, ssl_context=ssl_context) await site.start() diff --git a/aiohttp_devtools/runserver/watch.py b/aiohttp_devtools/runserver/watch.py index 07352397..0d3624b6 100644 --- a/aiohttp_devtools/runserver/watch.py +++ b/aiohttp_devtools/runserver/watch.py @@ -16,6 +16,7 @@ from ..logs import rs_dft_logger as logger from .config import Config from .serve import LAST_RELOAD, STATIC_PATH, WS, serve_main_app, src_reload +from ssl import SSLContext class WatchTask: @@ -55,13 +56,17 @@ def __init__(self, config: Config): self._reloads = 0 self._session: Optional[ClientSession] = None self._runner = None + self._client_ssl_context: Union[None, SSLContext] = None assert self._config.watch_path + super().__init__(self._config.watch_path) async def _run(self, live_checks: int = 150) -> None: assert self._app is not None self._session = ClientSession() + self._client_ssl_context = self._config.client_ssl_context + try: self._start_dev_server() @@ -107,12 +112,12 @@ async def _src_reload_when_live(self, checks: int) -> None: assert self._app is not None and self._session is not None if self._app[WS]: - url = "http://{0.host}:{0.main_port}/?_checking_alive=1".format(self._config) + url = "{0.protocol}://{0.host}:{0.main_port}/?_checking_alive=1".format(self._config) logger.debug('checking app at "%s" is running before prompting reload...', url) for i in range(checks): await asyncio.sleep(0.1) try: - async with self._session.get(url): + async with self._session.get(url, ssl=self._client_ssl_context): pass except OSError as e: logger.debug('try %d | OSError %d app not running', i, e.errno) @@ -123,7 +128,8 @@ async def _src_reload_when_live(self, checks: int) -> None: def _start_dev_server(self) -> None: act = 'Start' if self._reloads == 0 else 'Restart' - logger.info('%sing dev server at http://%s:%s ●', act, self._config.host, self._config.main_port) + logger.info("%sing dev server at %s://%s:%s ●", + act, self._config.protocol, self._config.host, self._config.main_port) try: tty_path = os.ttyname(sys.stdin.fileno()) @@ -141,12 +147,12 @@ async def _stop_dev_server(self) -> None: if self._process.is_alive(): logger.debug('stopping server process...') if self._config.shutdown_by_url: # Workaround for signals not working on Windows - url = "http://{0.host}:{0.main_port}{0.path_prefix}/shutdown".format(self._config) + url = "{0.protocol}://{0.host}:{0.main_port}{0.path_prefix}/shutdown".format(self._config) logger.debug("Attempting to stop process via shutdown endpoint {}".format(url)) try: with suppress(ClientConnectionError): async with ClientSession() as session: - async with session.get(url): + async with session.get(url, ssl=self._client_ssl_context): pass except (ConnectionError, ClientError, asyncio.TimeoutError) as ex: if self._process.is_alive(): diff --git a/requirements.txt b/requirements.txt index 2312c7f4..326e2697 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,4 +14,5 @@ pytest-mock==3.14.0 pytest-sugar==1.0.0 pytest-timeout==2.2.0 pytest-toolbox==0.4 +pytest-datafiles==3.0.0 watchfiles==1.0.4 diff --git a/tests/conftest.py b/tests/conftest.py index 848890d1..e2b56fb3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,6 +8,12 @@ else: forked = pytest.mark.forked +if sys.platform == "linux": + linux_forked = pytest.mark.forked +else: + def linux_forked(func): + return func + SIMPLE_APP = { 'app.py': """\ from aiohttp import web diff --git a/tests/test_certs/rootCA.pem b/tests/test_certs/rootCA.pem new file mode 100644 index 00000000..eb2ecada --- /dev/null +++ b/tests/test_certs/rootCA.pem @@ -0,0 +1,32 @@ +-----BEGIN CERTIFICATE----- +MIIFlTCCA32gAwIBAgIUMqRqzVHCUfN7kz43bWrwlfmtl7kwDQYJKoZIhvcNAQEN +BQAwWjELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDETMBEGA1UEAwwKVGVzdFJvb3RDQTAe +Fw0yNTAxMjYxMjE3MDBaFw0zNTAxMjQxMjE3MDBaMFoxCzAJBgNVBAYTAkFVMRMw +EQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0 +eSBMdGQxEzARBgNVBAMMClRlc3RSb290Q0EwggIiMA0GCSqGSIb3DQEBAQUAA4IC +DwAwggIKAoICAQDFvixQRLk0R2WOnXDkdMrmittYqWfHr3ZhZtS6HvFWBSV6AWc3 +DbseUgE7uD5xdJFlId35UH7HCFeeu8y/KkOPwH9KIzSWbZNcT3UJSDtnoA/sVYtN +MuS6Uu4DNkbDRNHf1udqc+0EwPpiZ7/3FwQify0pXyq7PbkOcJyFQh2YHG/EjZ4I +mBSz8NMwYQDeVMLxhQHTXruHIef1clLSSTRCXKLLKoKw/Rzje1jrBvLLollOJxLT +UXC1Fbpuh3KMnhwWsX4F4N8iWczcPxwCGcmYJA5xjo5tstkYzShUtNmMbFu3FCS8 +Vl/h25I3Znq7VdEI+brR7ZEeJj0yp9H1Aiev6XAojqWoNC1M63HgYY7uhl3YGC6f +uwx0qgmGI32dzv5JHCpOtI8N2V5rwwtYBVws8lGmkqbUEkF5oO5V6yQHulVsdGr1 +Kn5OPGolY8QmGcCE0LmvzRZCwZU2UcVxJsDJkNwup1C7wQEWC5pePEr58j3H3z6y +d3pkxaQmzXSB4jGJRzKbth6BQF47WwcphYjMtdWZUvy860isu9CEGjxbLjweATra +5o/8MIRuRPiJI2wlnEXHYWY96vrBQ202seQzMtJAtVoQxdpfokRHY8+jKfwZ/gRR +7tXxIRGfHoOgU9I8jtLNp782o/gjVTs9UGT0I66+PzpzS+XjshdH25OktwIDAQAB +o1MwUTAdBgNVHQ4EFgQUlT7d176QebrmSVanT1sGL2TyFuIwHwYDVR0jBBgwFoAU +lT7d176QebrmSVanT1sGL2TyFuIwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0B +AQ0FAAOCAgEAr6ZFZu6WYPUVY9zxJesNmnrm3xGbQn62iU6SrG9tsi/SFkQNPcVe +0CJ/zdA89yKet2Hpo95NSz9O4Jm5gapvGk8ie9UecqzEuKSLWR7mozupaPqDfF0O +YGgnhVMJIPGXbbm52oVV6FZtTRatQHatEnUS/09w2HkA/fyXbvRFA9O6RREevhjU +jcsB/ORx4Ni162Nr8waf6/2pJIturomz8hRtVsD5m6dGQuk7R6d7KZQQ+4Td7Cru +1xOxoWNDc0BBTbkv7DjOcy3YewgANgXqSsLrjprv30InoBgHvL8303EUkge268vd +jZ9mEsXdbZAVX1exetdBcoMQG8UmkKPnyU09w9NltnR7gVqZQyPDNZKTefP505X6 +67du/bw3Try/qUbiwJoyr1hf2d7rAJQ2CHDgedz8v5UszX4FAZ/yB5gUUxczld+r +6CCNR7FRfCCNmU6WPSa6CFvlg3x7JRXIdITHMtr14bhtLSmcfmRZhpG9N8r54C4P +L5OluPzU2P2JpV8i8YX8az5mFCdPxrAzjoAN8KU9WYp1LjKkTRT0UGYaTXLcVxyx +4+AWPJgT2GLXRyAcoEFdRQDSG+8jUy+ra0iEN6jp6JN04zBhIWVoQoA6+8u3PAna +DBVn5n32PZQjfu21u+cjvR3TrA3dXwi0/DPOYAeYr2S4D2R+6EAwFAo= +-----END CERTIFICATE----- diff --git a/tests/test_certs/server.crt b/tests/test_certs/server.crt new file mode 100644 index 00000000..4bd9edbc --- /dev/null +++ b/tests/test_certs/server.crt @@ -0,0 +1,30 @@ +-----BEGIN CERTIFICATE----- +MIIFOjCCAyICFCvF0YymuYiohdstQyrHzWdMOa5gMA0GCSqGSIb3DQEBCwUAMFox +CzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRl +cm5ldCBXaWRnaXRzIFB0eSBMdGQxEzARBgNVBAMMClRlc3RSb290Q0EwHhcNMjUw +MTI2MTIyMzIyWhcNMzUwMTI0MTIyMzIyWjBZMQswCQYDVQQGEwJBVTETMBEGA1UE +CAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50ZXJuZXQgV2lkZ2l0cyBQdHkgTHRk +MRIwEAYDVQQDDAlsb2NhbGhvc3QwggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAwggIK +AoICAQDI5oxS4wjneH21uhBj3bnkVQCTntieysO28zMKJdA8M/LVI6NtX9zJGIiR +Oum00ZmN1ESNIgXXscyeeQuFaK7CNc6JFVMOXoUBukWHhdA3IXotAoS0+6Nt8rc2 +joyQzembHuA2BQxHhF8gXXKhW6hk0vAjBjpYLGusxuVOgbvBKzL1VXNblSVGaBUm +xUZ9oZnGJw0HeBphDicGjJEokJMDe70vs9wlZdPDxDy/8iyFf+dPtnfCR1v2wzcT +vxI0lRqcf8n5k2cGAZKsE268/PNKbTyR5J7xqRe9hMhnEdCvxkWLhwQcwTKU1a2H +zXii3zZBh0MkcosZ8PmG/JtTSQRKFFGBa7aFh5oVuw//Kdm8qSEqrEeTXB9Us1eB +OS+kFTb/630kEuvLOc1gB3KcLw43AWLc5u5jzxyEcI6yc6wRxQTcxhEIfbj9tLEe +H9aw4nIsSa7lcOZXVboF5i1XrOC+KeAUPAxRjqttjlxAToZtIOtVPIhnjh1iVAP7 +g+Y6iGlc1t1jWN2IJrnlf7NyX98Uf5pr98O2NwyCcz0rpZPxHdLNE6/Wk2EugQJ9 +fNTEDn9rYW1iw1VMETZ/A53kCOIvse/KxS6aoWq4iPtfgzS3928DN7fZ5wJ5rYuv +pHBzzsFqkY+Oy341s91LIq6ZImTaIWd22KjU1hqdu+2MlCWPTQIDAQABMA0GCSqG +SIb3DQEBCwUAA4ICAQBEpPAoiWFX6st160wz/wJLqlgrr53iQwGyP/CttTE/LNHa +g+bVeJ14fsnwk47+DFxJbWuo3YipVEaIXXqdI2BUgNZLUrfBNGIdq4G0K1KcNeQf +O+Qql5he8LV9TKHj9N6efoZbQFWXixhkJwzb08XVEfWwUt4rDbFWLfEKLMpucRGw +1E5hB/92HuM9yB7ao5sXsMNddvlS4wVLThIw5pr/170nB3uHQXTVnAif0301SMk/ +i4wD7wevC9gz+40zbyC2HSsKhS2s+Jjey0/nSack8l5dISMp1XCJweG51Vb8F9Ml +5JZWdlw9J6cbJhw++oOBktLMCmnTiTP67aYlJhgrQeyPQXcg2uYDNlsK5nPqFxWZ +qjdvB6FMI9wS7LJylI9wJHDcG18+U8LrYnDIlotN2OJE7RVP/fYsAHCSswBwl0kH +3y1xIthILUSL1vCUXIZcI6hYSkxGlxTSd4KQoioVHr4/uavIIJxtf1dfkGRvVAEu +vo92OIiXpZ6Rhf4WUXaV6/kB9Jwkj0lJZ2RUw/CSVS8v3g4m4d09TsrhxKKmZxz7 +m+vsVopyBewXHHXKcmzI1OO5hL1wyjSx1TIAlr9MW3o3umvSgh3hM3Ildcmni8Xt +tUa95NKvjruz8UJ8gE50TZgsJI4ywT1QSyC55bfv74kKzHysv6zAj19y2CE0Hg== +-----END CERTIFICATE----- diff --git a/tests/test_certs/server.key b/tests/test_certs/server.key new file mode 100644 index 00000000..f4255586 --- /dev/null +++ b/tests/test_certs/server.key @@ -0,0 +1,51 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIJJwIBAAKCAgEAyOaMUuMI53h9tboQY9255FUAk57YnsrDtvMzCiXQPDPy1SOj +bV/cyRiIkTrptNGZjdREjSIF17HMnnkLhWiuwjXOiRVTDl6FAbpFh4XQNyF6LQKE +tPujbfK3No6MkM3pmx7gNgUMR4RfIF1yoVuoZNLwIwY6WCxrrMblToG7wSsy9VVz +W5UlRmgVJsVGfaGZxicNB3gaYQ4nBoyRKJCTA3u9L7PcJWXTw8Q8v/IshX/nT7Z3 +wkdb9sM3E78SNJUanH/J+ZNnBgGSrBNuvPzzSm08keSe8akXvYTIZxHQr8ZFi4cE +HMEylNWth814ot82QYdDJHKLGfD5hvybU0kEShRRgWu2hYeaFbsP/ynZvKkhKqxH +k1wfVLNXgTkvpBU2/+t9JBLryznNYAdynC8ONwFi3ObuY88chHCOsnOsEcUE3MYR +CH24/bSxHh/WsOJyLEmu5XDmV1W6BeYtV6zgvingFDwMUY6rbY5cQE6GbSDrVTyI +Z44dYlQD+4PmOohpXNbdY1jdiCa55X+zcl/fFH+aa/fDtjcMgnM9K6WT8R3SzROv +1pNhLoECfXzUxA5/a2FtYsNVTBE2fwOd5AjiL7HvysUumqFquIj7X4M0t/dvAze3 +2ecCea2Lr6Rwc87BapGPjst+NbPdSyKumSJk2iFndtio1NYanbvtjJQlj00CAwEA +AQKCAgAKPqd9OpKjqyNN9xUK4q2uFR+YZ4tIXbKpS7GYnOEHkOabM9wLoc3Se2vL +bCOq0t1vvBla0RdXLnvuwOFzhikTQkcr+mhn3S4PLn6JMKuzhAOE9BHsYMCuxKfP +ImnMoJN/E43/czZzFy76qYlE7TWjHpacUp77DBjZkLL00+zNJvTMSfU+AFcMRhZ+ +CaVUlr8OucMSVG+T73LSBK0KUoUMsmytWBCr34ty+jjW2PSoQiN7jySARb9M0Buo +6B93ivr2bBXSok+ooL/oAn2tKYEGlJd4IR5x2FubkH/fsargq820FciB5uA7csIM +oM+8DoHnyYwE+cpaIk23Mn6BOsH7JffsIYJfdZWfzD3yiC/h2Ha6ohIC8zuYxU2c +kqnV4UdGlXwoFP4Mhhi+DyMQ9L+Ty9RidLbgJ27FMaLkri26REImEhyljCYutvuk +d0vB0gwJW1K0j/2/p5TQnLBkNJlDdfzhoCsb2EWU17Tx2bHiJYlbPbiyyRK9c/Wx +2/i9GfTIz09FllBHhynfqhC7zOFq9VXxHIEsYj2skSiPUi95Doheupz8U2wcsVTm +6/T++I7VxJJBeU+GF91Q9KJVrO0rFxmyolynRHHcnrzNaBVLHGdDWyiglHaI/3Hj +hP2co++b4HhzC+/4cpLPY3+sLrNSdWVCkZ1A+uCwfr5OnWrpwQKCAQEA5wXsdSRm +s2pdU0vZarqtRFq2F+7ZjAHtp6WEGrwcFwjirhCG+ko/K0PSqEci7WiL07/Zsw26 +0e+Nq30u4VTz9/DbUFzuFlR3N+/al6MrX1O5mOkrgb7s3YcBNWKmURTv5spP5MAo +uXz4Mv4KMKksO6gTEM9NDGiik6e3CBs2tyYvmCTzhY+PwLAldEQDJZQ6xAK3V7i/ +xcMumCS9z5EYcdF0DDx5m7mJGL9BF1OwWd2kAxssGDkihs2ZXvtgzxzT2xdVizZW +3ojS/c2KIrIp6J8tvLMzCIrvkK9rpy+WTNzOiIor/4Fhh35YrAaVO/APnS79kFDB +NY2H8jfY7LJtSQKCAQEA3p7qfkTVJGGO5d7mAR6ZGG/P+soEU6xW4/ybgieAZpM8 +5Phd45z89IyataR+8GjOTo+v7hGePcGFhcoZA62kKPejX72gs2bB1x3W/k169M43 +TB02lkXuHv/8wxletQAK/L0++aKQccwruK8iw6Yc8AZguoL8CXqlMaRJQiS1emfi +71sRfCSeNJmLCqOIiRkSg0xzBUmN2cP672KKu+fw4JjruJLZy6cGRfn8chKzYDuR +fKc8rS2sRL+L3ufjpWg1+lP1c6DQn2gFDquZ3e20YapRY20nyAObbRH3KvlZLQAM +BQNHMN6eW50mYA3rNMe805nCci6DdE0YSFiatAxl5QKCAQBflgfcAA+uNFgg2sU+ +b7a5DX9CL8U7NKEMOGOMXECTF04TDyuJ66ZvVESY87Xz3Mnd9wcwGoIt0pwfVFBN +U0UOVU2o1op8Gr6pGkirbQvJCW9FYVRq/oAquG07lXGTIsKQDy03THqNJLPdBVda +AuUWWdhpoBwVAkYiKcaFSB0/ckFHBiLsJBYqd7dHf8x9g/M8npMVbI+MV9GziaAv +fa1LiooldfArCn07DAb2i93vkNEHp/p6m0k51V+b+Q55I0hU4ja2vuj6cko6UQzS +hjzoztOxu8NlyXaNusckCYB6lPGvdNv3f6TG1vQBWUft4MnVE1g+mesXKVQSWCEc +7kZhAoIBAEx3i5ZJsGiptfrRYHG7/9w789Vx9KCFDueKyiOfy+Pv6TfA9AcN0nlx +nmaMFSog5dRoWIbOuGsAAQweig8QYtXLket96CgXQLfSQRnipTxXZPkZA7oEVTGC +vmCJY1WKqTt9CZeXtkPQXKg4SBmqAkCUAD+wZEAhR4LQqnU0xL1B19pdjpj0vv7U +SsUhvPFSkmBVLyD+zeGiBpyZXYwDtGKBRF6G2pawTWBV6NeKAuEoNOX7T8Uwbf7D +SJkNT81uCTRuCF5qO561jR8n5Fctogr2BLTBNqvmSUnipOK2+WGSpY5HPPnVTdGs +HhVaUpMzlHGeXAL6ZR7aqF+ZR7JWm90CggEAG0OKsZHAIGdeKDZnzu8lQqBWfGCp +9VXz9tcw7EQPn+KZ30FzxZsI5hHvxVKYHmjVDGbKChc0aq2nD8H29ONeDdpQxjSF +7dIZckzz45l3vUZco8b2V2SBgnv7XY4iTefhsbMs9y+cVCTByAXULtmOTZclgep8 +Ss0r2tX6kLpfQrCSitk+449dYXqm/pEZGw1+19LEZ8tZNhzTsJOtOddIh0FFJW14 +jClxlJvs0iSWfX3ihR/bgXIkxyXJcO/FGRxytek8ngWvkCZm8cFQMqrCNkRN4F60 +UKZPPgyCsBX5A1W9QuHdklCCARxCZ8xdkSs8ecb+QrJZvpawskUqSn8KKg== +-----END RSA PRIVATE KEY----- diff --git a/tests/test_runserver_config.py b/tests/test_runserver_config.py index 5b22747a..ad3e9a25 100644 --- a/tests/test_runserver_config.py +++ b/tests/test_runserver_config.py @@ -36,7 +36,8 @@ async def test_create_app_wrong_name(tmpworkdir): mktree(tmpworkdir, SIMPLE_APP) config = Config(app_path='app.py', app_factory_name='missing') with pytest.raises(AiohttpDevConfigError) as excinfo: - config.import_app_factory() + module = config.import_module() + config.get_app_factory(module) assert excinfo.value.args[0] == "Module 'app.py' does not define a 'missing' attribute/class" @@ -56,7 +57,8 @@ async def app_factory(): """ }) config = Config(app_path='app.py') - app = await config.load_app(config.import_app_factory()) + module = config.import_module() + app = await config.load_app(config.get_app_factory(module)) assert isinstance(app, web.Application) @@ -69,9 +71,10 @@ def app_factory(): """ }) config = Config(app_path='app.py') + module = config.import_module() with pytest.raises(AiohttpDevConfigError, match=r"'app_factory' returned 'int' not an aiohttp\.web\.Application"): - await config.load_app(config.import_app_factory()) + await config.load_app(config.get_app_factory(module)) @forked @@ -83,6 +86,67 @@ def app_factory(foo): """ }) config = Config(app_path='app.py') + module = config.import_module() with pytest.raises(AiohttpDevConfigError, match=r"'app\.py\.app_factory' should not have required arguments"): - await config.load_app(config.import_app_factory()) + await config.load_app(config.get_app_factory(module)) + + +@forked +async def test_no_ssl_context_factory(tmpworkdir): + mktree(tmpworkdir, { + "app.py": """\ +def app_factory(foo): + return web.Application() +""" + }) + config = Config(app_path="app.py", ssl_context_factory_name="get_ssl_context") + module = config.import_module() + with pytest.raises(AiohttpDevConfigError, + match="Module 'app.py' does not define a 'get_ssl_context' attribute/class"): + config.get_ssl_context(module) + + +@forked +async def test_invalid_ssl_context(tmpworkdir): + mktree(tmpworkdir, { + "app.py": """\ +def app_factory(foo): + return web.Application() + +def get_ssl_context(): + return 'invalid ssl_context' +""" + }) + config = Config(app_path="app.py", ssl_context_factory_name="get_ssl_context") + module = config.import_module() + with pytest.raises(AiohttpDevConfigError, + match="ssl-context-factory 'get_ssl_context' in module 'app.py' didn't return valid SSLContext"): + config.get_ssl_context(module) + + +async def test_rootcert_file_notfound(tmpworkdir): + mktree(tmpworkdir, { + "app.py": """\ +def app_factory(foo): + return web.Application() +""" + }) + config = Config(app_path="app.py", ssl_context_factory_name="get_ssl_context", ssl_rootcert_file_path="rootCA.pem") + with pytest.raises(AiohttpDevConfigError, + match="No such file or directory: rootCA.pem"): + config.client_ssl_context + + +async def test_invalid_rootcert_file(tmpworkdir): + mktree(tmpworkdir, { + "app.py": """\ +def app_factory(foo): + return web.Application() +""", + "rootCA.pem": "invalid X509 certificate" + }) + config = Config(app_path="app.py", ssl_context_factory_name="get_ssl_context", ssl_rootcert_file_path="rootCA.pem") + with pytest.raises(AiohttpDevConfigError, + match="invalid root cert file: rootCA.pem"): + config.client_ssl_context diff --git a/tests/test_runserver_main.py b/tests/test_runserver_main.py index 2f0d9a9e..a0440342 100644 --- a/tests/test_runserver_main.py +++ b/tests/test_runserver_main.py @@ -1,5 +1,6 @@ import asyncio import json +import ssl from unittest import mock import aiohttp @@ -15,7 +16,7 @@ WS, create_auxiliary_app, create_main_app, modify_main_app, src_reload, start_main_app) from aiohttp_devtools.runserver.watch import AppTask -from .conftest import SIMPLE_APP, forked +from .conftest import SIMPLE_APP, forked, linux_forked async def check_server_running(check_callback): @@ -145,7 +146,8 @@ async def create_app(): set_start_method("spawn") config = Config(app_path="app.py", root_path=tmpworkdir, main_port=0, app_factory_name="create_app") - config.import_app_factory() + module = config.import_module() + config.get_app_factory(module) app_task = AppTask(config) app_task._start_dev_server() @@ -162,7 +164,8 @@ async def create_app(): async def test_run_app_aiohttp_client(tmpworkdir, aiohttp_client): mktree(tmpworkdir, SIMPLE_APP) config = Config(app_path='app.py') - app_factory = config.import_app_factory() + module = config.import_module() + app_factory = config.get_app_factory(module) app = await config.load_app(app_factory) modify_main_app(app, config) assert isinstance(app, aiohttp.web.Application) @@ -178,7 +181,8 @@ async def test_run_app_aiohttp_client(tmpworkdir, aiohttp_client): async def test_run_app_browser_cache(tmpworkdir, aiohttp_client): mktree(tmpworkdir, SIMPLE_APP) config = Config(app_path="app.py", browser_cache=True) - app_factory = config.import_app_factory() + module = config.import_module() + app_factory = config.get_app_factory(module) app = await config.load_app(app_factory) modify_main_app(app, config) cli = await aiohttp_client(app) @@ -208,8 +212,9 @@ async def test_serve_main_app(tmpworkdir, mocker): loop.call_later(0.5, loop.stop) config = Config(app_path="app.py", main_port=0) - runner = await create_main_app(config, config.import_app_factory()) - await start_main_app(runner, config.bind_address, config.main_port) + module = config.import_module() + runner = await create_main_app(config, config.get_app_factory(module)) + await start_main_app(runner, config.bind_address, config.main_port, None) mock_modify_main_app.assert_called_with(mock.ANY, config) @@ -232,8 +237,9 @@ async def hello(request): mock_modify_main_app = mocker.patch('aiohttp_devtools.runserver.serve.modify_main_app') config = Config(app_path="app.py", main_port=0) - runner = await create_main_app(config, config.import_app_factory()) - await start_main_app(runner, config.bind_address, config.main_port) + module = config.import_module() + runner = await create_main_app(config, config.get_app_factory(module)) + await start_main_app(runner, config.bind_address, config.main_port, None) mock_modify_main_app.assert_called_with(mock.ANY, config) @@ -303,3 +309,91 @@ async def test_websocket_reload(aux_cli): assert reloads == 1 finally: await ws.close() + + +async def check_ssl_server_running(check_callback): + port_open = False + ssl_context = ssl.create_default_context() + ssl_context.load_verify_locations("test_certs/rootCA.pem") + + async with aiohttp.ClientSession(timeout=ClientTimeout(total=1)) as session: + for i in range(50): # pragma: no branch + try: + async with session.get("https://localhost:8000/", ssl=ssl_context): + pass + except OSError: + await asyncio.sleep(0.1) + else: + port_open = True + break + assert port_open + await check_callback(session, ssl_context) + await asyncio.sleep(.25) # TODO(aiohttp 4): Remove this hack + + +@pytest.mark.filterwarnings(r"ignore:unclosed:ResourceWarning") +@linux_forked +@pytest.mark.datafiles("tests/test_certs", keep_top_dir=True) +def test_start_runserver_ssl(datafiles, tmpworkdir, smart_caplog): + mktree(tmpworkdir, { + "app.py": """\ +from aiohttp import web +import ssl +async def hello(request): + return web.Response(text="

hello world

", content_type="text/html") + +async def has_error(request): + raise ValueError() + +def create_app(): + app = web.Application() + app.router.add_get("/", hello) + app.router.add_get("/error", has_error) + return app + +def get_ssl_context(): + ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + ssl_context.load_cert_chain("test_certs/server.crt", "test_certs/server.key") + return ssl_context + """, + "static_dir/foo.js": "var bar=1;", + }) + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + args = runserver(app_path="app.py", static_path="static_dir", + bind_address="0.0.0.0", ssl_context_factory_name="get_ssl_context") + aux_app = args["app"] + aux_port = args["port"] + runapp_host = args["host"] + assert isinstance(aux_app, aiohttp.web.Application) + assert aux_port == 8001 + assert runapp_host == "0.0.0.0" + for startup in aux_app.on_startup: + loop.run_until_complete(startup(aux_app)) + + async def check_callback(session, ssl_context): + print(session, ssl_context) + async with session.get("https://localhost:8000/", ssl=ssl_context) as r: + assert r.status == 200 + assert r.headers["content-type"].startswith("text/html") + text = await r.text() + print(text) + assert "

hello world

" in text + assert '' in text + + async with session.get("https://localhost:8000/error", ssl=ssl_context) as r: + assert r.status == 500 + assert "raise ValueError()" in (await r.text()) + + try: + loop.run_until_complete(check_ssl_server_running(check_callback)) + finally: + for shutdown in aux_app.on_shutdown: + loop.run_until_complete(shutdown(aux_app)) + loop.run_until_complete(aux_app.cleanup()) + assert ( + "adev.server.dft INFO: Starting aux server at http://localhost:8001 ◆\n" + "adev.server.dft INFO: serving static files from ./static_dir/ at http://localhost:8001/static/\n" + "adev.server.dft INFO: Starting dev server at https://localhost:8000 ●\n" + ) in smart_caplog + loop.run_until_complete(asyncio.sleep(.25)) # TODO(aiohttp 4): Remove this hack diff --git a/tests/test_runserver_watch.py b/tests/test_runserver_watch.py index b8373d68..dcfd1166 100644 --- a/tests/test_runserver_watch.py +++ b/tests/test_runserver_watch.py @@ -77,6 +77,9 @@ async def test_python_no_server(mocker): config = MagicMock() config.main_port = 8000 + config.protocol = "http" + config.client_ssl_context = None + app_task = AppTask(config) start_mock = mocker.patch.object(app_task, "_start_dev_server", autospec=True) stop_mock = mocker.patch.object(app_task, "_stop_dev_server", autospec=True) @@ -109,6 +112,8 @@ async def test_reload_server_running(aiohttp_client, mocker): config = MagicMock() config.host = "localhost" config.main_port = cli.server.port + config.protocol = "http" + config.client_ssl_context = None app_task = AppTask(config) app_task._app = app