Skip to content

Commit

Permalink
Merge pull request #94 from NabuCasa/dev
Browse files Browse the repository at this point in the history
0.22
  • Loading branch information
balloob authored Oct 3, 2019
2 parents 7ae1bdc + e284d78 commit ba97b7f
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 29 deletions.
1 change: 1 addition & 0 deletions hass_nabucasa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ def load_config():
async def stop(self):
"""Stop the cloud component."""
await self.iot.disconnect()
await self.google_report_state.disconnect()

def _decode_claims(self, token): # pylint: disable=no-self-use
"""Decode the claims in a token."""
Expand Down
2 changes: 1 addition & 1 deletion hass_nabucasa/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"user_pool_id": "us-east-1_87ll5WOP8",
"region": "us-east-1",
"relayer": "wss://cloud.nabucasa.com/websocket",
"google_actions_report_state_url": "https://remotestate.nabucasa.com",
"google_actions_report_state_url": "wss://remotestate.nabucasa.com/v1",
"google_actions_sync_url": (
"https://24ab3v80xd.execute-api.us-east-1."
"amazonaws.com/prod/smart_home_sync"
Expand Down
47 changes: 45 additions & 2 deletions hass_nabucasa/google_report_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,27 @@
import asyncio
from asyncio.queues import Queue
import logging
from typing import Dict
import uuid

from . import iot_base

_LOGGER = logging.getLogger(__name__)
MAX_PENDING = 100

ERR_DISCARD_CODE = "message_discarded"
ERR_DISCARD_MSG = "Message discarded because max messages reachced"


class ErrorResponse(Exception):
"""Raised when a request receives a success=false response."""

def __init__(self, code: str, message: str):
"""Initialize error response."""
super().__init__(code)
self.code = code
self.message = message


class GoogleReportState(iot_base.BaseIoT):
"""Report states to Google.
Expand All @@ -21,6 +36,8 @@ def __init__(self, cloud):
self._connect_lock = asyncio.Lock()
self._to_send = Queue(100)
self._message_sender_task = None
# Local code waiting for a response
self._response_handler: Dict[str, asyncio.Future] = {}
self.register_on_connect(self._async_on_connect)
self.register_on_disconnect(self._async_on_disconnect)

Expand All @@ -36,6 +53,8 @@ def ws_server_url(self) -> str:

async def async_send_message(self, msg):
"""Send a message."""
msgid = uuid.uuid4().hex

# Since connect is async, guard against send_message called twice in parallel.
async with self._connect_lock:
if self.state == iot_base.STATE_DISCONNECTED:
Expand All @@ -44,11 +63,33 @@ async def async_send_message(self, msg):
await asyncio.sleep(0)

if self._to_send.full():
self._to_send.get_nowait()
self._to_send.put_nowait(msg)
discard_msg = self._to_send.get_nowait()
self._response_handler.pop(discard_msg["msgid"]).set_exception(
ErrorResponse(ERR_DISCARD_CODE, ERR_DISCARD_MSG)
)

fut = self._response_handler[msgid] = asyncio.Future()

self._to_send.put_nowait({"msgid": msgid, "payload": msg})

try:
return await fut
finally:
self._response_handler.pop(msgid, None)

def async_handle_message(self, msg):
"""Handle a message."""
response_handler = self._response_handler.get(msg["msgid"])

if response_handler is not None:
if "error" in msg:
response_handler.set_exception(
ErrorResponse(msg["error"], msg["message"])
)
else:
response_handler.set_result(msg.get("payload"))
return

self._logger.warning("Got unhandled message: %s", msg)

async def _async_on_connect(self):
Expand All @@ -62,8 +103,10 @@ async def _async_on_disconnect(self):

async def _async_message_sender(self):
"""Start sending messages."""
self._logger.debug("Message sender task activated")
try:
while True:
await self.async_send_json_message(await self._to_send.get())
except asyncio.CancelledError:
pass
self._logger.debug("Message sender task shut down")
8 changes: 7 additions & 1 deletion hass_nabucasa/iot_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,10 @@ async def _handle_connection(self):
if self._logger.isEnabledFor(logging.DEBUG):
self._logger.debug("Received message:\n%s\n", pprint.pformat(msg))

self.async_handle_message(msg)
try:
self.async_handle_message(msg)
except Exception: # pylint: disable=broad-except
self._logger.exception("Unexpected error handling %s", msg)

except client_exceptions.WSServerHandshakeError as err:
if err.status == 401:
Expand All @@ -221,6 +224,9 @@ async def _handle_connection(self):
except client_exceptions.ClientError as err:
self._logger.warning("Unable to connect: %s", err)

except asyncio.CancelledError:
pass

finally:
if disconnect_warn is None:
self._logger.info("Connection closed")
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from setuptools import setup

VERSION = "0.20"
VERSION = "0.22"

setup(
name="hass-nabucasa",
Expand Down
8 changes: 7 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,16 +88,22 @@ async def ws_server(aiohttp_client):

async def create_client_to_server(handle_server_msg):
"""Create a websocket server."""
logger = logging.getLogger(f"{__name__}.ws_server")

async def websocket_handler(request):

ws = web.WebSocketResponse()
await ws.prepare(request)

async for msg in ws:
logger.debug("Received msg: %s", msg)
try:
await handle_server_msg(msg)
resp = await handle_server_msg(msg)
if resp is not None:
logger.debug("Sending msg: %s", msg)
await ws.send_json(resp)
except DisconnectMockServer:
logger.debug("Closing connection (via DisconnectMockServer)")
await ws.close()

return ws
Expand Down
74 changes: 51 additions & 23 deletions tests/test_google_report_state.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""Tests for Google Report State."""
import asyncio
from unittest.mock import Mock
from unittest.mock import Mock, patch

from hass_nabucasa import iot_base
from hass_nabucasa.google_report_state import GoogleReportState
from hass_nabucasa.google_report_state import GoogleReportState, ErrorResponse

from tests.common import mock_coro

Expand All @@ -23,26 +23,43 @@ async def create_grs(loop, ws_server, server_msg_handler) -> GoogleReportState:

async def test_send_messages(loop, ws_server):
"""Test that we connect if we are not connected."""
msgs = []
server_msgs = []

async def handle_server_msg(msg):
"""handle a server msg."""
msgs.append(msg.json())
incoming = msg.json()
server_msgs.append(incoming["payload"])

# First msg is ok
if incoming["payload"]["hello"] == 0:
return {"msgid": incoming["msgid"], "payload": "mock-response"}

# Second msg is error
return {
"msgid": incoming["msgid"],
"error": "mock-code",
"message": "mock-message",
}

grs = await create_grs(loop, ws_server, handle_server_msg)
assert grs.state == iot_base.STATE_DISCONNECTED

# Test we can handle two simultaneous messages while disconnected
await asyncio.gather(
*[grs.async_send_message({"hello": 0}), grs.async_send_message({"hello": 1})]
responses = await asyncio.gather(
*[grs.async_send_message({"hello": 0}), grs.async_send_message({"hello": 1})],
return_exceptions=True
)
assert grs.state == iot_base.STATE_CONNECTED
assert len(responses) == 2
assert responses[0] == "mock-response"
assert isinstance(responses[1], ErrorResponse)
assert responses[1].code == "mock-code"
assert responses[1].message == "mock-message"

# One per message to handle
await asyncio.sleep(0)
await asyncio.sleep(0)

assert sorted(msgs, key=lambda val: val["hello"]) == [{"hello": 0}, {"hello": 1}]
assert sorted(server_msgs, key=lambda val: val["hello"]) == [
{"hello": 0},
{"hello": 1},
]

await grs.disconnect()
assert grs.state == iot_base.STATE_DISCONNECTED
Expand All @@ -51,24 +68,35 @@ async def handle_server_msg(msg):

async def test_max_queue_message(loop, ws_server):
"""Test that we connect if we are not connected."""
msgs = []
server_msgs = []

async def handle_server_msg(msg):
"""handle a server msg."""
msgs.append(msg.json())
incoming = msg.json()
server_msgs.append(incoming["payload"])
return {"msgid": incoming["msgid"], "payload": incoming["payload"]["hello"]}

grs = await create_grs(loop, ws_server, handle_server_msg)

orig_connect = grs.connect
grs.connect = mock_coro

# Test we can handle sending more messages than queue fits
await asyncio.gather(*[grs.async_send_message({"hello": i}) for i in range(150)])

loop.create_task(orig_connect())

# One per message to handle
for i in range(100):
with patch.object(grs, "_async_message_sender", side_effect=mock_coro):
gather_task = asyncio.gather(
*[grs.async_send_message({"hello": i}) for i in range(150)],
return_exceptions=True
)
# One per message
for i in range(150):
await asyncio.sleep(0)

# Start handling messages.
await grs._async_on_connect()

# One per message
for i in range(150):
await asyncio.sleep(0)

assert len(msgs) == 100
assert len(server_msgs) == 100

results = await gather_task
assert len(results) == 150
assert sum(isinstance(result, ErrorResponse) for result in results) == 50

0 comments on commit ba97b7f

Please sign in to comment.