From ad24dcc5b098da5b93588448a626beb0b6cfe549 Mon Sep 17 00:00:00 2001 From: Pascal Vizeli Date: Wed, 2 Oct 2019 20:58:36 +0200 Subject: [PATCH 1/5] Bump version 0.21 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 88996725b..b3ddb8ee4 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,6 @@ from setuptools import setup -VERSION = "0.20" +VERSION = "0.21" setup( name="hass-nabucasa", From 511285f2a00b9891394c8753cb031a739962d99c Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Wed, 2 Oct 2019 14:58:55 -0700 Subject: [PATCH 2/5] Update message sending --- hass_nabucasa/const.py | 2 +- hass_nabucasa/google_report_state.py | 46 +++++++++++++++++++++- hass_nabucasa/iot_base.py | 5 ++- tests/conftest.py | 8 +++- tests/test_google_report_state.py | 57 +++++++++++++++++----------- 5 files changed, 91 insertions(+), 27 deletions(-) diff --git a/hass_nabucasa/const.py b/hass_nabucasa/const.py index b89b37792..46819a878 100644 --- a/hass_nabucasa/const.py +++ b/hass_nabucasa/const.py @@ -21,7 +21,7 @@ "user_pool_id": "us-east-1_87ll5WOP8", "region": "us-east-1", "relayer": "wss://cloud.nabucasa.com/websocket", - "google_actions_report_state_url": "https://remotestate.nabucasa.com", + "google_actions_report_state_url": "wss://remotestate.nabucasa.com/v1", "google_actions_sync_url": ( "https://24ab3v80xd.execute-api.us-east-1." "amazonaws.com/prod/smart_home_sync" diff --git a/hass_nabucasa/google_report_state.py b/hass_nabucasa/google_report_state.py index ade80328e..65d65365e 100644 --- a/hass_nabucasa/google_report_state.py +++ b/hass_nabucasa/google_report_state.py @@ -2,12 +2,26 @@ import asyncio from asyncio.queues import Queue import logging +from typing import Dict +import uuid from . import iot_base _LOGGER = logging.getLogger(__name__) MAX_PENDING = 100 +ERR_DISCARD_CODE = "message_discarded" +ERR_DISCARD_MSG = "Message discarded because max messages reachced" + + +class ErrorResponse(Exception): + """Raised when a request receives a success=false response.""" + + def __init__(self, code: str, message: str): + """Initialize error response.""" + super().__init__(code) + self.code = code + class GoogleReportState(iot_base.BaseIoT): """Report states to Google. @@ -21,6 +35,8 @@ def __init__(self, cloud): self._connect_lock = asyncio.Lock() self._to_send = Queue(100) self._message_sender_task = None + # Local code waiting for a response + self._response_handler: Dict[str, asyncio.Future] = {} self.register_on_connect(self._async_on_connect) self.register_on_disconnect(self._async_on_disconnect) @@ -36,6 +52,8 @@ def ws_server_url(self) -> str: async def async_send_message(self, msg): """Send a message.""" + msgid = uuid.uuid4().hex + # Since connect is async, guard against send_message called twice in parallel. async with self._connect_lock: if self.state == iot_base.STATE_DISCONNECTED: @@ -44,11 +62,33 @@ async def async_send_message(self, msg): await asyncio.sleep(0) if self._to_send.full(): - self._to_send.get_nowait() - self._to_send.put_nowait(msg) + discard_msg = self._to_send.get_nowait() + self._response_handler.pop(discard_msg["msgid"]).set_exception( + ErrorResponse(ERR_DISCARD_CODE, ERR_DISCARD_MSG) + ) + + fut = self._response_handler[msgid] = asyncio.Future() + + self._to_send.put_nowait({"msgid": msgid, "payload": msg}) + + try: + return await fut + finally: + self._response_handler.pop(msgid, None) def async_handle_message(self, msg): """Handle a message.""" + response_handler = self._response_handler.get(msg["msgid"]) + + if response_handler is not None: + if "error" in msg: + response_handler.set_exception( + ErrorResponse(msg["error"], msg["message"]) + ) + else: + response_handler.set_result(msg.get("payload")) + return + self._logger.warning("Got unhandled message: %s", msg) async def _async_on_connect(self): @@ -62,8 +102,10 @@ async def _async_on_disconnect(self): async def _async_message_sender(self): """Start sending messages.""" + self._logger.debug("Message sender task activated") try: while True: await self.async_send_json_message(await self._to_send.get()) except asyncio.CancelledError: pass + self._logger.debug("Message sender task shut down") diff --git a/hass_nabucasa/iot_base.py b/hass_nabucasa/iot_base.py index 983ea03b4..329b9a12e 100644 --- a/hass_nabucasa/iot_base.py +++ b/hass_nabucasa/iot_base.py @@ -208,7 +208,10 @@ async def _handle_connection(self): if self._logger.isEnabledFor(logging.DEBUG): self._logger.debug("Received message:\n%s\n", pprint.pformat(msg)) - self.async_handle_message(msg) + try: + self.async_handle_message(msg) + except Exception: # pylint: disable=broad-except + self._logger.exception("Unexpected error handling %s", msg) except client_exceptions.WSServerHandshakeError as err: if err.status == 401: diff --git a/tests/conftest.py b/tests/conftest.py index 248887217..49a782f21 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -88,6 +88,7 @@ async def ws_server(aiohttp_client): async def create_client_to_server(handle_server_msg): """Create a websocket server.""" + logger = logging.getLogger(f"{__name__}.ws_server") async def websocket_handler(request): @@ -95,9 +96,14 @@ async def websocket_handler(request): await ws.prepare(request) async for msg in ws: + logger.debug("Received msg: %s", msg) try: - await handle_server_msg(msg) + resp = await handle_server_msg(msg) + if resp is not None: + logger.debug("Sending msg: %s", msg) + await ws.send_json(resp) except DisconnectMockServer: + logger.debug("Closing connection (via DisconnectMockServer)") await ws.close() return ws diff --git a/tests/test_google_report_state.py b/tests/test_google_report_state.py index c105bfd48..70af2170c 100644 --- a/tests/test_google_report_state.py +++ b/tests/test_google_report_state.py @@ -1,9 +1,9 @@ """Tests for Google Report State.""" import asyncio -from unittest.mock import Mock +from unittest.mock import Mock, patch from hass_nabucasa import iot_base -from hass_nabucasa.google_report_state import GoogleReportState +from hass_nabucasa.google_report_state import GoogleReportState, ErrorResponse from tests.common import mock_coro @@ -23,26 +23,28 @@ async def create_grs(loop, ws_server, server_msg_handler) -> GoogleReportState: async def test_send_messages(loop, ws_server): """Test that we connect if we are not connected.""" - msgs = [] + server_msgs = [] async def handle_server_msg(msg): """handle a server msg.""" - msgs.append(msg.json()) + incoming = msg.json() + server_msgs.append(incoming["payload"]) + return {"msgid": incoming["msgid"], "payload": incoming["payload"]["hello"]} grs = await create_grs(loop, ws_server, handle_server_msg) assert grs.state == iot_base.STATE_DISCONNECTED # Test we can handle two simultaneous messages while disconnected - await asyncio.gather( + responses = await asyncio.gather( *[grs.async_send_message({"hello": 0}), grs.async_send_message({"hello": 1})] ) assert grs.state == iot_base.STATE_CONNECTED + assert responses == [0, 1] - # One per message to handle - await asyncio.sleep(0) - await asyncio.sleep(0) - - assert sorted(msgs, key=lambda val: val["hello"]) == [{"hello": 0}, {"hello": 1}] + assert sorted(server_msgs, key=lambda val: val["hello"]) == [ + {"hello": 0}, + {"hello": 1}, + ] await grs.disconnect() assert grs.state == iot_base.STATE_DISCONNECTED @@ -51,24 +53,35 @@ async def handle_server_msg(msg): async def test_max_queue_message(loop, ws_server): """Test that we connect if we are not connected.""" - msgs = [] + server_msgs = [] async def handle_server_msg(msg): """handle a server msg.""" - msgs.append(msg.json()) + incoming = msg.json() + server_msgs.append(incoming["payload"]) + return {"msgid": incoming["msgid"], "payload": incoming["payload"]["hello"]} grs = await create_grs(loop, ws_server, handle_server_msg) - orig_connect = grs.connect - grs.connect = mock_coro - # Test we can handle sending more messages than queue fits - await asyncio.gather(*[grs.async_send_message({"hello": i}) for i in range(150)]) - - loop.create_task(orig_connect()) - - # One per message to handle - for i in range(100): + with patch.object(grs, "_async_message_sender", side_effect=mock_coro): + gather_task = asyncio.gather( + *[grs.async_send_message({"hello": i}) for i in range(150)], + return_exceptions=True + ) + # One per message + for i in range(150): + await asyncio.sleep(0) + + # Start handling messages. + await grs._async_on_connect() + + # One per message + for i in range(150): await asyncio.sleep(0) - assert len(msgs) == 100 + assert len(server_msgs) == 100 + + results = await gather_task + assert len(results) == 150 + assert sum(isinstance(result, ErrorResponse) for result in results) == 50 From 883ac5124e54d7190e4350d0052c7ee1e0ded607 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Wed, 2 Oct 2019 15:13:35 -0700 Subject: [PATCH 3/5] Test an error response --- hass_nabucasa/google_report_state.py | 1 + tests/test_google_report_state.py | 21 ++++++++++++++++++--- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/hass_nabucasa/google_report_state.py b/hass_nabucasa/google_report_state.py index 65d65365e..3af92b152 100644 --- a/hass_nabucasa/google_report_state.py +++ b/hass_nabucasa/google_report_state.py @@ -21,6 +21,7 @@ def __init__(self, code: str, message: str): """Initialize error response.""" super().__init__(code) self.code = code + self.message = message class GoogleReportState(iot_base.BaseIoT): diff --git a/tests/test_google_report_state.py b/tests/test_google_report_state.py index 70af2170c..325243a05 100644 --- a/tests/test_google_report_state.py +++ b/tests/test_google_report_state.py @@ -29,17 +29,32 @@ async def handle_server_msg(msg): """handle a server msg.""" incoming = msg.json() server_msgs.append(incoming["payload"]) - return {"msgid": incoming["msgid"], "payload": incoming["payload"]["hello"]} + + # First msg is ok + if incoming["payload"]["hello"] == 0: + return {"msgid": incoming["msgid"], "payload": "mock-response"} + + # Second msg is error + return { + "msgid": incoming["msgid"], + "error": "mock-code", + "message": "mock-message", + } grs = await create_grs(loop, ws_server, handle_server_msg) assert grs.state == iot_base.STATE_DISCONNECTED # Test we can handle two simultaneous messages while disconnected responses = await asyncio.gather( - *[grs.async_send_message({"hello": 0}), grs.async_send_message({"hello": 1})] + *[grs.async_send_message({"hello": 0}), grs.async_send_message({"hello": 1})], + return_exceptions=True ) assert grs.state == iot_base.STATE_CONNECTED - assert responses == [0, 1] + assert len(responses) == 2 + assert responses[0] == "mock-response" + assert isinstance(responses[1], ErrorResponse) + assert responses[1].code == "mock-code" + assert responses[1].message == "mock-message" assert sorted(server_msgs, key=lambda val: val["hello"]) == [ {"hello": 0}, From ba46d6d93bb9487c3ec27b4607c5f6da0d6110c1 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Wed, 2 Oct 2019 17:18:13 -0700 Subject: [PATCH 4/5] Handle some error cases --- hass_nabucasa/__init__.py | 1 + hass_nabucasa/iot_base.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/hass_nabucasa/__init__.py b/hass_nabucasa/__init__.py index 8ea12b194..1c25e9f8c 100644 --- a/hass_nabucasa/__init__.py +++ b/hass_nabucasa/__init__.py @@ -210,6 +210,7 @@ def load_config(): async def stop(self): """Stop the cloud component.""" await self.iot.disconnect() + await self.google_report_state.disconnect() def _decode_claims(self, token): # pylint: disable=no-self-use """Decode the claims in a token.""" diff --git a/hass_nabucasa/iot_base.py b/hass_nabucasa/iot_base.py index 329b9a12e..6a7f66e00 100644 --- a/hass_nabucasa/iot_base.py +++ b/hass_nabucasa/iot_base.py @@ -224,6 +224,9 @@ async def _handle_connection(self): except client_exceptions.ClientError as err: self._logger.warning("Unable to connect: %s", err) + except asyncio.CancelledError: + pass + finally: if disconnect_warn is None: self._logger.info("Connection closed") From e284d7843a86dc8881fbd48fb047f8d9d816533c Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Wed, 2 Oct 2019 20:01:37 -0700 Subject: [PATCH 5/5] Version bump to 0.22 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index b3ddb8ee4..323625c84 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,6 @@ from setuptools import setup -VERSION = "0.21" +VERSION = "0.22" setup( name="hass-nabucasa",