Skip to content

Commit

Permalink
Force token refresh when subscriber sub info fetched while not started (
Browse files Browse the repository at this point in the history
  • Loading branch information
balloob authored Sep 6, 2021
1 parent 39f91bb commit 9c55adc
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 10 deletions.
7 changes: 0 additions & 7 deletions hass_nabucasa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,13 +190,6 @@ def run_executor(self, callback: Callable, *args) -> asyncio.Future:
"""
return self.client.loop.run_in_executor(None, callback, *args)

async def fetch_subscription_info(self):
"""Fetch subscription info."""
await self.auth.async_check_token()
return await self.websession.get(
self.subscription_info_url, headers={"authorization": self.id_token}
)

async def login(self, email: str, password: str) -> None:
"""Log a user in."""
async with async_timeout.timeout(30):
Expand Down
27 changes: 25 additions & 2 deletions hass_nabucasa/cloud_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@
_LOGGER = logging.getLogger(__name__)


def _do_log_response(resp):
"""Log the response."""
meth = _LOGGER.debug if resp.status < 400 else _LOGGER.warning
meth("Fetched %s (%s)", resp.url, resp.status)


def _check_token(func):
"""Decorate a function to verify valid token."""

Expand All @@ -26,8 +32,7 @@ def _log_response(func):
async def log_response(*args):
"""Log response if it's bad."""
resp = await func(*args)
meth = _LOGGER.debug if resp.status < 400 else _LOGGER.warning
meth("Fetched %s (%s)", resp.url, resp.status)
_do_log_response(resp)
return resp

return log_response
Expand Down Expand Up @@ -107,3 +112,21 @@ async def async_google_actions_request_sync(cloud):
f"{cloud.google_actions_report_state_url}/request_sync",
headers={AUTHORIZATION: f"Bearer {cloud.id_token}"},
)


@_check_token
async def async_subscription_info(cloud):
"""Fetch subscription info."""
resp = await cloud.websession.get(
cloud.subscription_info_url, headers={"authorization": cloud.id_token}
)
_do_log_response(resp)
resp.raise_for_status()
data = await resp.json()

# If subscription info indicates we are subscribed, force a refresh of the token
if data.get("provider") and not cloud.started:
_LOGGER.debug("Found disconnected account with valid subscription, connecting")
await cloud.auth.async_renew_access_token()

return data
48 changes: 47 additions & 1 deletion tests/test_cloud_api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Test cloud API."""

from unittest.mock import patch, AsyncMock

from hass_nabucasa import cloud_api

Expand Down Expand Up @@ -108,3 +108,49 @@ async def test_voice_connection_details(auth_cloud_mock, aioclient_mock):

await cloud_api.async_voice_connection_details(auth_cloud_mock)
assert len(aioclient_mock.mock_calls) == 1


async def test_subscription_info(auth_cloud_mock, aioclient_mock):
"""Test fetching subscription info."""
aioclient_mock.get(
"https://example.com/payments/subscription_info",
json={
"success": True,
"provider": None,
},
)
auth_cloud_mock.id_token = "mock-id-token"
auth_cloud_mock.subscription_info_url = (
"https://example.com/payments/subscription_info"
)

with patch.object(
auth_cloud_mock.auth, "async_renew_access_token", AsyncMock()
) as mock_renew:
data = await cloud_api.async_subscription_info(auth_cloud_mock)
assert len(aioclient_mock.mock_calls) == 1
assert data == {
"success": True,
"provider": None,
}

auth_cloud_mock.started = False
aioclient_mock.clear_requests()
aioclient_mock.get(
"https://example.com/payments/subscription_info",
json={
"success": True,
"provider": "mock-provider",
},
)
with patch.object(
auth_cloud_mock.auth, "async_renew_access_token", AsyncMock()
) as mock_renew:
data = await cloud_api.async_subscription_info(auth_cloud_mock)

assert len(aioclient_mock.mock_calls) == 1
assert data == {
"success": True,
"provider": "mock-provider",
}
assert len(mock_renew.mock_calls) == 1

0 comments on commit 9c55adc

Please sign in to comment.