diff --git a/hass_nabucasa/__init__.py b/hass_nabucasa/__init__.py index 1e4e96fbc..5be53c2fa 100644 --- a/hass_nabucasa/__init__.py +++ b/hass_nabucasa/__init__.py @@ -1,20 +1,20 @@ """Component to integrate the Home Assistant cloud.""" import asyncio +from datetime import datetime, timedelta import json import logging -from typing import Coroutine, Callable -from datetime import datetime, timedelta from pathlib import Path +from typing import Callable, Coroutine import aiohttp -from homeassistant.util import dt as dt_util from . import auth_api from .client import CloudClient from .cloudhooks import Cloudhooks +from .const import CONFIG_DIR, MODE_DEV, SERVERS, STATE_CONNECTED from .iot import CloudIoT from .remote import RemoteUI -from .const import CONFIG_DIR, MODE_DEV, SERVERS +from .utils import parse_date, utcnow, UTC _LOGGER = logging.getLogger(__name__) @@ -75,6 +75,11 @@ def is_logged_in(self) -> bool: """Get if cloud is logged in.""" return self.id_token is not None + @property + def is_connected(self) -> bool: + """Return True if we are connected.""" + return self.iot.state == STATE_CONNECTED + @property def websession(self) -> aiohttp.ClientSession: """Return websession for connections.""" @@ -83,14 +88,14 @@ def websession(self) -> aiohttp.ClientSession: @property def subscription_expired(self) -> bool: """Return a boolean if the subscription has expired.""" - return dt_util.utcnow() > self.expiration_date + timedelta(days=7) + return utcnow() > self.expiration_date + timedelta(days=7) @property def expiration_date(self) -> datetime: """Return the subscription expiration as a UTC datetime object.""" return datetime.combine( - dt_util.parse_date(self.claims["custom:sub-exp"]), datetime.min.time() - ).replace(tzinfo=dt_util.UTC) + parse_date(self.claims["custom:sub-exp"]), datetime.min.time() + ).replace(tzinfo=UTC) @property def claims(self): diff --git a/hass_nabucasa/const.py b/hass_nabucasa/const.py index 6aa773f4d..b53c9d54c 100644 --- a/hass_nabucasa/const.py +++ b/hass_nabucasa/const.py @@ -6,6 +6,10 @@ MODE_PROD = "production" MODE_DEV = "development" +STATE_CONNECTING = "connecting" +STATE_CONNECTED = "connected" +STATE_DISCONNECTED = "disconnected" + SERVERS = { "production": { "cognito_client_id": "60i2uvhvbiref2mftj7rgcrt9u", diff --git a/hass_nabucasa/iot.py b/hass_nabucasa/iot.py index 284eb0e30..39b0c3821 100644 --- a/hass_nabucasa/iot.py +++ b/hass_nabucasa/iot.py @@ -5,19 +5,21 @@ import random import uuid -from aiohttp import hdrs, client_exceptions, WSMsgType -from homeassistant.util.decorator import Registry +from aiohttp import WSMsgType, client_exceptions, hdrs from . import auth_api -from .const import MESSAGE_EXPIRATION, MESSAGE_AUTH_FAIL +from .const import ( + MESSAGE_AUTH_FAIL, + MESSAGE_EXPIRATION, + STATE_CONNECTED, + STATE_CONNECTING, + STATE_DISCONNECTED, +) +from .utils import Registry HANDLERS = Registry() _LOGGER = logging.getLogger(__name__) -STATE_CONNECTING = "connecting" -STATE_CONNECTED = "connected" -STATE_DISCONNECTED = "disconnected" - class UnknownHandler(Exception): """Exception raised when trying to handle unknown handler.""" diff --git a/hass_nabucasa/remote.py b/hass_nabucasa/remote.py index 304899a21..87a09d0a5 100644 --- a/hass_nabucasa/remote.py +++ b/hass_nabucasa/remote.py @@ -1,18 +1,20 @@ """Manage remote UI connections.""" import asyncio -from contextlib import suppress +from datetime import datetime import logging import random import ssl from typing import Optional import async_timeout -from homeassistant.util.ssl import server_context_modern +import attr +from snitun.exceptions import SniTunConnectionError from snitun.utils.aes import generate_aes_keyset from snitun.utils.aiohttp_client import SniTunClientAioHttp from . import cloud_api from .acme import AcmeClientError, AcmeHandler +from .utils import server_context_modern, utcnow, utc_from_timestamp _LOGGER = logging.getLogger(__name__) @@ -29,6 +31,16 @@ class RemoteNotConnected(RemoteError): """Raise if a request need connection and we are not ready.""" +@attr.s +class SniTunToken: + """Handle snitun token.""" + + fernet = attr.ib(type=bytes) + aes_key = attr.ib(type=bytes) + aes_iv = attr.ib(type=bytes) + valid = attr.ib(type=datetime) + + class RemoteUI: """Class to help manage remote connections.""" @@ -39,6 +51,7 @@ def __init__(self, cloud): self._snitun = None self._snitun_server = None self._reconnect_task = None + self._token = None # Register start/stop cloud.iot.register_on_connect(self.load_backend) @@ -55,8 +68,8 @@ async def _create_context(self) -> ssl.SSLContext: await self.cloud.run_executor( context.load_cert_chain, - self._acme.path_fullchain, - self._acme.path_private_key, + str(self._acme.path_fullchain), + str(self._acme.path_private_key), ) return context @@ -105,7 +118,7 @@ async def load_backend(self) -> None: self._snitun_server = data["server"] await self._snitun.start() - await self._connect_snitun() + self.cloud.run_task(self.connect()) async def close_backend(self) -> None: """Close connections and shutdown backend.""" @@ -114,8 +127,7 @@ async def close_backend(self) -> None: # Disconnect snitun if self._snitun: - with suppress(RuntimeError): - await self._snitun.stop() + await self._snitun.stop() self._snitun = None self._acme = None @@ -128,35 +140,75 @@ async def handle_connection_requests(self, caller_ip): if self._snitun.is_connected: return + await self.connect() - await self._connect_snitun() + async def _refresh_snitun_token(self): + """Handle snitun token.""" + if self._token and self._token.valid > utcnow(): + _LOGGER.debug("Don't need refresh snitun token") + return - async def _connect_snitun(self): - """Connect to snitun server.""" # Generate session token aes_key, aes_iv = generate_aes_keyset() async with async_timeout.timeout(10): resp = await cloud_api.async_remote_token(self.cloud, aes_key, aes_iv) if resp.status != 200: - _LOGGER.error("Can't register a snitun token by server") raise RemoteBackendError() data = await resp.json() - await self._snitun.connect(data["token"].encode(), aes_key, aes_iv) + self._token = SniTunToken( + data["token"].encode(), aes_key, aes_iv, utc_from_timestamp(data["valid"]) + ) + + async def connect(self): + """Connect to snitun server.""" + if not self._snitun: + _LOGGER.error("Can't handle request-connection without backend") + raise RemoteNotConnected() + + # Check if we already connected + if self._snitun.is_connected: + return + + try: + await self._refresh_snitun_token() + await self._snitun.connect( + self._token.fernet, self._token.aes_key, self._token.aes_iv + ) + except SniTunConnectionError: + _LOGGER.error("Connection problem to snitun server") + except RemoteBackendError: + _LOGGER.error("Can't refresh the snitun token") + finally: + # start retry task + if not self._reconnect_task: + self._reconnect_task = self.cloud.run_task(self._reconnect_snitun()) + + async def disconnect(self): + """Disconnect from snitun server.""" + if not self._snitun: + _LOGGER.error("Can't handle request-connection without backend") + raise RemoteNotConnected() - # start retry task + # Stop reconnect task if self._reconnect_task: + self._reconnect_task.cancel() + + # Check if we already connected + if not self._snitun.is_connected: return - self._reconnect_task = self.cloud.run_task(self._reconnect_snitun()) + await self._snitun.disconnect() async def _reconnect_snitun(self): """Reconnect after disconnect.""" try: while True: - await self._snitun.wait() - await asyncio.sleep(random.randint(1, 10)) - await self._connect_snitun() + if self._snitun.is_connected: + await self._snitun.wait() + + await asyncio.sleep(random.randint(1, 15)) + await self.connect() except asyncio.CancelledError: pass finally: diff --git a/hass_nabucasa/utils.py b/hass_nabucasa/utils.py new file mode 100644 index 000000000..da03b9427 --- /dev/null +++ b/hass_nabucasa/utils.py @@ -0,0 +1,71 @@ +"""Helper methods to handle the time in Home Assistant.""" +import datetime as dt +import ssl +from typing import Optional, Callable, TypeVar + +import pytz + +CALLABLE_T = TypeVar("CALLABLE_T", bound=Callable) # noqa pylint: disable=invalid-name +DATE_STR_FORMAT = "%Y-%m-%d" +UTC = pytz.utc + + +def utcnow() -> dt.datetime: + """Get now in UTC time.""" + return dt.datetime.now(UTC) + + +def utc_from_timestamp(timestamp: float) -> dt.datetime: + """Return a UTC time from a timestamp.""" + return UTC.localize(dt.datetime.utcfromtimestamp(timestamp)) + + +def parse_date(dt_str: str) -> Optional[dt.date]: + """Convert a date string to a date object.""" + try: + return dt.datetime.strptime(dt_str, DATE_STR_FORMAT).date() + except ValueError: # If dt_str did not match our format + return None + + +def server_context_modern() -> ssl.SSLContext: + """Return an SSL context following the Mozilla recommendations. + TLS configuration follows the best-practice guidelines specified here: + https://wiki.mozilla.org/Security/Server_Side_TLS + Modern guidelines are followed. + """ + context = ssl.SSLContext(ssl.PROTOCOL_TLS) # pylint: disable=no-member + + context.options |= ( + ssl.OP_NO_SSLv2 + | ssl.OP_NO_SSLv3 + | ssl.OP_NO_TLSv1 + | ssl.OP_NO_TLSv1_1 + | ssl.OP_CIPHER_SERVER_PREFERENCE + ) + if hasattr(ssl, "OP_NO_COMPRESSION"): + context.options |= ssl.OP_NO_COMPRESSION + + context.set_ciphers( + "ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384:" + "ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305:" + "ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256:" + "ECDHE-ECDSA-AES256-SHA384:ECDHE-RSA-AES256-SHA384:" + "ECDHE-ECDSA-AES128-SHA256:ECDHE-RSA-AES128-SHA256" + ) + + return context + + +class Registry(dict): + """Registry of items.""" + + def register(self, name: str) -> Callable[[CALLABLE_T], CALLABLE_T]: + """Return decorator to register item with a specific name.""" + + def decorator(func: CALLABLE_T) -> CALLABLE_T: + """Register decorated function.""" + self[name] = func + return func + + return decorator diff --git a/setup.cfg b/setup.cfg index a0fed8150..1a9f8213b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -14,4 +14,4 @@ use_parentheses = true [flake8] max-line-length = 88 -ignore = E501 +ignore = E501, W503 diff --git a/setup.py b/setup.py index cb3ee219b..6c4ba1d2d 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,6 @@ from setuptools import setup -VERSION = "0.1" +VERSION = "0.2" setup( name="hass-nabucasa", @@ -30,9 +30,10 @@ packages=["hass_nabucasa"], install_requires=[ "warrant==0.6.1", - "snitun==0.9", + "snitun==0.11", "acme==0.31.0", "cryptography>=2.5", "attrs>=18.2.0", + "pytz", ], ) diff --git a/tests/test_cloud_api.py b/tests/test_cloud_api.py index edac35102..058e4f088 100644 --- a/tests/test_cloud_api.py +++ b/tests/test_cloud_api.py @@ -56,14 +56,18 @@ async def test_remote_token(cloud_mock, aioclient_mock): """Test creating a cloudhook.""" aioclient_mock.post( "https://example.com/bla/snitun_token", - json={"token": "123456", "server": "rest-remote.nabu.casa"}, + json={"token": "123456", "server": "rest-remote.nabu.casa", "valid": 12345}, ) cloud_mock.id_token = "mock-id-token" cloud_mock.remote_api_url = "https://example.com/bla" resp = await cloud_api.async_remote_token(cloud_mock, b"aes", b"iv") assert len(aioclient_mock.mock_calls) == 1 - assert await resp.json() == {"token": "123456", "server": "rest-remote.nabu.casa"} + assert await resp.json() == { + "token": "123456", + "server": "rest-remote.nabu.casa", + "valid": 12345, + } assert aioclient_mock.mock_calls[0][2] == {"aes_iv": "6976", "aes_key": "616573"} diff --git a/tox.ini b/tox.ini index a0149df2b..236035dce 100644 --- a/tox.ini +++ b/tox.ini @@ -4,7 +4,6 @@ envlist = lint, tests [testenv] basepython = python3 deps = - homeassistant -r{toxinidir}/requirements_tests.txt [testenv:lint]