Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ssl support #712

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
Open
2 changes: 2 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ or ``main.py``) or to a specific file. The ``--app-factory`` option can be used
from the app path file, if not supplied some default method names are tried
(namely `app`, `app_factory`, `get_app` and `create_app`, which can be
variables, functions, or coroutines).
The ``--ssl-context-factory`` option can be used to define method from the app path file, which returns ssl.SSLContext
for ssl support.

All ``runserver`` arguments can be set via environment variables.

Expand Down
2 changes: 2 additions & 0 deletions aiohttp_devtools/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def serve(path: str, livereload: bool, bind_address: str, port: int, verbose: bo
'or just an instance of aiohttp.Application. env variable AIO_APP_FACTORY')
port_help = 'Port to serve app from, default 8000. env variable: AIO_PORT'
aux_port_help = 'Port to serve auxiliary app (reload and static) on, default port + 1. env variable: AIO_AUX_PORT'
ssl_context_factory_help = 'name of the ssl context factory to create ssl.SSLContext with'


# defaults are all None here so default settings are defined in one place: DEV_DICT validation
Expand All @@ -83,6 +84,7 @@ def serve(path: str, livereload: bool, bind_address: str, port: int, verbose: bo
@click.option('-v', '--verbose', is_flag=True, help=verbose_help)
@click.option("--browser-cache/--no-browser-cache", envvar="AIO_BROWSER_CACHE", default=None,
help=browser_cache_help)
@click.option('--ssl-context-factory', 'ssl_context_factory_name', default=None, help=ssl_context_factory_help)
@click.argument('project_args', nargs=-1)
def runserver(**config: Any) -> None:
"""
Expand Down
54 changes: 46 additions & 8 deletions aiohttp_devtools/runserver/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
import sys
from importlib import import_module
from pathlib import Path
from typing import Awaitable, Callable, Optional, Union
from typing import Awaitable, Callable, Optional, Union, Literal
from types import ModuleType

from aiohttp import web
from ssl import SSLContext

import __main__
from ..exceptions import AiohttpDevConfigError as AdevConfigError
Expand All @@ -26,6 +28,8 @@
'create_app',
]

DEFAULT_PORT = 8000

INFER_HOST = '<inference>'


Expand All @@ -43,9 +47,10 @@ def __init__(self, *,
app_factory_name: Optional[str] = None,
host: str = INFER_HOST,
bind_address: str = "localhost",
main_port: int = 8000,
main_port: Optional[int] = None,
aux_port: Optional[int] = None,
browser_cache: bool = False):
browser_cache: bool = False,
ssl_context_factory_name: Optional[str] = None):
if root_path:
self.root_path = Path(root_path).resolve()
logger.debug('Root path specified: %s', self.root_path)
Expand Down Expand Up @@ -83,11 +88,20 @@ def __init__(self, *,
self.host = bind_address

self.bind_address = bind_address
if main_port is None:
main_port = DEFAULT_PORT if ssl_context_factory_name is None else DEFAULT_PORT + 443
self.main_port = main_port
self.aux_port = aux_port or (main_port + 1)
if aux_port is None:
aux_port = main_port + 1 if ssl_context_factory_name is None else DEFAULT_PORT + 1
self.aux_port = aux_port
self.browser_cache = browser_cache
self.ssl_context_factory_name = ssl_context_factory_name
logger.debug('config loaded:\n%s', self)

@property
def protocol(self) -> Literal["http", "https"]:
return "http" if self.ssl_context_factory_name is None else "https"

@property
def static_path_str(self) -> Optional[str]:
return str(self.static_path) if self.static_path else None
Expand Down Expand Up @@ -136,15 +150,14 @@ def _resolve_path(self, _path: str, check: str, arg_name: str) -> Path:
raise AdevConfigError('{} is not a directory'.format(path))
return path

def import_app_factory(self) -> AppFactory:
"""Import and return attribute/class from a python module.
def import_module(self) -> ModuleType:
"""Import and return python module.

Raises:
AdevConfigError - If the import failed.
"""
rel_py_file = self.py_file.relative_to(self.python_path)
module_path = '.'.join(rel_py_file.with_suffix('').parts)

sys.path.insert(0, str(self.python_path))
module = import_module(module_path)
# Rewrite the package name, so it will appear the same as running the app.
Expand All @@ -153,6 +166,16 @@ def import_app_factory(self) -> AppFactory:

logger.debug('successfully loaded "%s" from "%s"', module_path, self.python_path)

self.watch_path = self.watch_path or Path(module.__file__ or ".").parent
return module

def get_app_factory(self, module: ModuleType) -> AppFactory:
"""Return attribute/class from a python module.

Raises:
AdevConfigError - If the import failed.
"""

if self.app_factory_name is None:
try:
self.app_factory_name = next(an for an in APP_FACTORY_NAMES if hasattr(module, an))
Expand All @@ -179,9 +202,24 @@ def import_app_factory(self) -> AppFactory:
raise AdevConfigError("'{}.{}' should not have required arguments.".format(
self.py_file.name, self.app_factory_name))

self.watch_path = self.watch_path or Path(module.__file__ or ".").parent
return attr # type: ignore[no-any-return]

def get_ssl_context(self, module: ModuleType) -> Union[SSLContext, None]:
if self.ssl_context_factory_name is None:
return None
else:
try:
attr = getattr(module, self.ssl_context_factory_name)
except AttributeError:
raise AdevConfigError("Module '{}' does not define a '{}' attribute/class".format(
self.py_file.name, self.ssl_context_factory_name))
ssl_context = attr()
if isinstance(ssl_context, SSLContext):
return ssl_context
else:
raise AdevConfigError("ssl-context-factory '{}' in module '{}' didn't return valid SSLContext".format(
self.ssl_context_factory_name, self.py_file.name))

async def load_app(self, app_factory: AppFactory) -> web.Application:
if isinstance(app_factory, web.Application):
return app_factory
Expand Down
11 changes: 6 additions & 5 deletions aiohttp_devtools/runserver/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import os
from multiprocessing import set_start_method
from typing import Any, Type, TypedDict
from typing import Any, Type, TypedDict, Union

from aiohttp.abc import AbstractAccessLogger
from aiohttp.web import Application
Expand All @@ -11,6 +11,7 @@
from .log_handlers import AuxAccessLogger
from .serve import check_port_open, create_auxiliary_app
from .watch import AppTask, LiveReloadTask
from ssl import SSLContext


class RunServer(TypedDict):
Expand All @@ -19,6 +20,7 @@ class RunServer(TypedDict):
port: int
shutdown_timeout: float
access_log_class: Type[AbstractAccessLogger]
ssl_context: Union[SSLContext, None]


def runserver(**config_kwargs: Any) -> RunServer:
Expand All @@ -29,9 +31,8 @@ def runserver(**config_kwargs: Any) -> RunServer:
"""
# force a full reload in sub processes so they load an updated version of code, this must be called only once
set_start_method('spawn')

config = Config(**config_kwargs)
config.import_app_factory()
config.import_module()

asyncio.run(check_port_open(config.main_port, host=config.bind_address))

Expand All @@ -57,7 +58,7 @@ def runserver(**config_kwargs: Any) -> RunServer:
logger.info('serving static files from ./%s/ at %s%s', rel_path, url, config.static_url)

return {"app": aux_app, "host": config.bind_address, "port": config.aux_port,
"shutdown_timeout": 0.01, "access_log_class": AuxAccessLogger}
"shutdown_timeout": 0.01, "access_log_class": AuxAccessLogger, "ssl_context": None}


def serve_static(*, static_path: str, livereload: bool = True, bind_address: str = "localhost", port: int = 8000,
Expand All @@ -75,4 +76,4 @@ def serve_static(*, static_path: str, livereload: bool = True, bind_address: str
livereload_status = 'ON' if livereload else 'OFF'
logger.info('Serving "%s" at http://%s:%d, livereload %s', static_path, bind_address, port, livereload_status)
return {"app": app, "host": bind_address, "port": port,
"shutdown_timeout": 0.01, "access_log_class": AuxAccessLogger}
"shutdown_timeout": 0.01, "access_log_class": AuxAccessLogger, "ssl_context": None}
19 changes: 12 additions & 7 deletions aiohttp_devtools/runserver/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import warnings
from errno import EADDRINUSE
from pathlib import Path
from typing import Any, Iterator, List, NoReturn, Optional, Set, Tuple
from typing import Any, Iterator, List, NoReturn, Optional, Set, Tuple, Union

from aiohttp import WSMsgType, web
from aiohttp.hdrs import LAST_MODIFIED, CONTENT_LENGTH
Expand All @@ -25,6 +25,8 @@
from .log_handlers import AccessLogger
from .utils import MutableValue

from ssl import SSLContext

try:
from aiohttp_jinja2 import static_root_key
except ImportError:
Expand Down Expand Up @@ -120,7 +122,8 @@ def shutdown() -> NoReturn:

path = config.path_prefix + "/shutdown"
app.router.add_route("GET", path, do_shutdown, name="_devtools.shutdown")
dft_logger.debug("Created shutdown endpoint at http://{}:{}{}".format(config.host, config.main_port, path))
dft_logger.debug("Created shutdown endpoint at {}://{}:{}{}".format(
config.protocol, config.host, config.main_port, path))

if config.static_path is not None:
static_url = 'http://{}:{}/{}'.format(config.host, config.aux_port, static_path)
Expand Down Expand Up @@ -164,12 +167,14 @@ def set_tty(tty_path: Optional[str]) -> Iterator[None]:
def serve_main_app(config: Config, tty_path: Optional[str]) -> None:
with set_tty(tty_path):
setup_logging(config.verbose)
app_factory = config.import_app_factory()
module = config.import_module()
app_factory = config.get_app_factory(module)
ssl_context = config.get_ssl_context(module)
if sys.version_info >= (3, 11):
with asyncio.Runner() as runner:
app_runner = runner.run(create_main_app(config, app_factory))
try:
runner.run(start_main_app(app_runner, config.bind_address, config.main_port))
runner.run(start_main_app(app_runner, config.bind_address, config.main_port, ssl_context))
runner.get_loop().run_forever()
except KeyboardInterrupt:
pass
Expand All @@ -180,7 +185,7 @@ def serve_main_app(config: Config, tty_path: Optional[str]) -> None:
loop = asyncio.new_event_loop()
runner = loop.run_until_complete(create_main_app(config, app_factory))
try:
loop.run_until_complete(start_main_app(runner, config.bind_address, config.main_port))
loop.run_until_complete(start_main_app(runner, config.bind_address, config.main_port, ssl_context))
loop.run_forever()
except KeyboardInterrupt: # pragma: no cover
pass
Expand All @@ -197,9 +202,9 @@ async def create_main_app(config: Config, app_factory: AppFactory) -> web.AppRun
return web.AppRunner(app, access_log_class=AccessLogger, shutdown_timeout=0.1)


async def start_main_app(runner: web.AppRunner, host: str, port: int) -> None:
async def start_main_app(runner: web.AppRunner, host: str, port: int, ssl_context: Union[SSLContext, None]) -> None:
await runner.setup()
site = web.TCPSite(runner, host=host, port=port)
site = web.TCPSite(runner, host=host, port=port, ssl_context=ssl_context)
await site.start()


Expand Down
14 changes: 9 additions & 5 deletions aiohttp_devtools/runserver/watch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ..logs import rs_dft_logger as logger
from .config import Config
from .serve import LAST_RELOAD, STATIC_PATH, WS, serve_main_app, src_reload
import ssl


class WatchTask:
Expand Down Expand Up @@ -55,7 +56,9 @@ def __init__(self, config: Config):
self._reloads = 0
self._session: Optional[ClientSession] = None
self._runner = None
self.ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) if config.protocol == 'https' else None
assert self._config.watch_path

super().__init__(self._config.watch_path)

async def _run(self, live_checks: int = 150) -> None:
Expand Down Expand Up @@ -107,12 +110,12 @@ async def _src_reload_when_live(self, checks: int) -> None:
assert self._app is not None and self._session is not None

if self._app[WS]:
url = "http://{0.host}:{0.main_port}/?_checking_alive=1".format(self._config)
url = "{0.protocol}://{0.host}:{0.main_port}/?_checking_alive=1".format(self._config)
logger.debug('checking app at "%s" is running before prompting reload...', url)
for i in range(checks):
await asyncio.sleep(0.1)
try:
async with self._session.get(url):
async with self._session.get(url, ssl=self.ssl_context):
pass
except OSError as e:
logger.debug('try %d | OSError %d app not running', i, e.errno)
Expand All @@ -123,7 +126,8 @@ async def _src_reload_when_live(self, checks: int) -> None:

def _start_dev_server(self) -> None:
act = 'Start' if self._reloads == 0 else 'Restart'
logger.info('%sing dev server at http://%s:%s ●', act, self._config.host, self._config.main_port)
logger.info('%sing dev server at %s://%s:%s ●',
act, self._config.protocol, self._config.host, self._config.main_port)

try:
tty_path = os.ttyname(sys.stdin.fileno())
Expand All @@ -141,12 +145,12 @@ async def _stop_dev_server(self) -> None:
if self._process.is_alive():
logger.debug('stopping server process...')
if self._config.shutdown_by_url: # Workaround for signals not working on Windows
url = "http://{0.host}:{0.main_port}{0.path_prefix}/shutdown".format(self._config)
url = "{0.protocol}://{0.host}:{0.main_port}{0.path_prefix}/shutdown".format(self._config)
logger.debug("Attempting to stop process via shutdown endpoint {}".format(url))
try:
with suppress(ClientConnectionError):
async with ClientSession() as session:
async with session.get(url):
async with session.get(url, ssl=self.ssl_context):
pass
except (ConnectionError, ClientError, asyncio.TimeoutError) as ex:
if self._process.is_alive():
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ pytest-mock==3.14.0
pytest-sugar==1.0.0
pytest-timeout==2.2.0
pytest-toolbox==0.4
pytest-datafiles==3.0.0
watchfiles==1.0.4
30 changes: 30 additions & 0 deletions tests/test_certs/server.crt
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
-----BEGIN CERTIFICATE-----
MIIFOjCCAyICFCvF0YymuYiohdstQyrHzWdMOa5gMA0GCSqGSIb3DQEBCwUAMFox
CzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRl
cm5ldCBXaWRnaXRzIFB0eSBMdGQxEzARBgNVBAMMClRlc3RSb290Q0EwHhcNMjUw
MTI2MTIyMzIyWhcNMzUwMTI0MTIyMzIyWjBZMQswCQYDVQQGEwJBVTETMBEGA1UE
CAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50ZXJuZXQgV2lkZ2l0cyBQdHkgTHRk
MRIwEAYDVQQDDAlsb2NhbGhvc3QwggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAwggIK
AoICAQDI5oxS4wjneH21uhBj3bnkVQCTntieysO28zMKJdA8M/LVI6NtX9zJGIiR
Oum00ZmN1ESNIgXXscyeeQuFaK7CNc6JFVMOXoUBukWHhdA3IXotAoS0+6Nt8rc2
joyQzembHuA2BQxHhF8gXXKhW6hk0vAjBjpYLGusxuVOgbvBKzL1VXNblSVGaBUm
xUZ9oZnGJw0HeBphDicGjJEokJMDe70vs9wlZdPDxDy/8iyFf+dPtnfCR1v2wzcT
vxI0lRqcf8n5k2cGAZKsE268/PNKbTyR5J7xqRe9hMhnEdCvxkWLhwQcwTKU1a2H
zXii3zZBh0MkcosZ8PmG/JtTSQRKFFGBa7aFh5oVuw//Kdm8qSEqrEeTXB9Us1eB
OS+kFTb/630kEuvLOc1gB3KcLw43AWLc5u5jzxyEcI6yc6wRxQTcxhEIfbj9tLEe
H9aw4nIsSa7lcOZXVboF5i1XrOC+KeAUPAxRjqttjlxAToZtIOtVPIhnjh1iVAP7
g+Y6iGlc1t1jWN2IJrnlf7NyX98Uf5pr98O2NwyCcz0rpZPxHdLNE6/Wk2EugQJ9
fNTEDn9rYW1iw1VMETZ/A53kCOIvse/KxS6aoWq4iPtfgzS3928DN7fZ5wJ5rYuv
pHBzzsFqkY+Oy341s91LIq6ZImTaIWd22KjU1hqdu+2MlCWPTQIDAQABMA0GCSqG
SIb3DQEBCwUAA4ICAQBEpPAoiWFX6st160wz/wJLqlgrr53iQwGyP/CttTE/LNHa
g+bVeJ14fsnwk47+DFxJbWuo3YipVEaIXXqdI2BUgNZLUrfBNGIdq4G0K1KcNeQf
O+Qql5he8LV9TKHj9N6efoZbQFWXixhkJwzb08XVEfWwUt4rDbFWLfEKLMpucRGw
1E5hB/92HuM9yB7ao5sXsMNddvlS4wVLThIw5pr/170nB3uHQXTVnAif0301SMk/
i4wD7wevC9gz+40zbyC2HSsKhS2s+Jjey0/nSack8l5dISMp1XCJweG51Vb8F9Ml
5JZWdlw9J6cbJhw++oOBktLMCmnTiTP67aYlJhgrQeyPQXcg2uYDNlsK5nPqFxWZ
qjdvB6FMI9wS7LJylI9wJHDcG18+U8LrYnDIlotN2OJE7RVP/fYsAHCSswBwl0kH
3y1xIthILUSL1vCUXIZcI6hYSkxGlxTSd4KQoioVHr4/uavIIJxtf1dfkGRvVAEu
vo92OIiXpZ6Rhf4WUXaV6/kB9Jwkj0lJZ2RUw/CSVS8v3g4m4d09TsrhxKKmZxz7
m+vsVopyBewXHHXKcmzI1OO5hL1wyjSx1TIAlr9MW3o3umvSgh3hM3Ildcmni8Xt
tUa95NKvjruz8UJ8gE50TZgsJI4ywT1QSyC55bfv74kKzHysv6zAj19y2CE0Hg==
-----END CERTIFICATE-----
Loading