Skip to content

Commit

Permalink
Fix edge case where we wouldn't start (#274)
Browse files Browse the repository at this point in the history
Co-authored-by: Joakim Sørensen <[email protected]>
  • Loading branch information
balloob and ludeeus authored Sep 6, 2021
1 parent 74b7ece commit 78ee314
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 24 deletions.
28 changes: 19 additions & 9 deletions hass_nabucasa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def __init__(
self.id_token = None
self.access_token = None
self.refresh_token = None
self.started = None
self.iot = CloudIoT(self)
self.google_report_state = GoogleReportState(self)
self.cloudhooks = Cloudhooks(self)
Expand Down Expand Up @@ -142,16 +143,22 @@ async def update_token(
self, id_token: str, access_token: str, refresh_token: str | None = None
) -> None:
"""Update the id and access token."""
is_stopped = not self.is_logged_in or self.subscription_expired
self.id_token = id_token
self.access_token = access_token
if refresh_token is not None:
self.refresh_token = refresh_token

if is_stopped and not self.subscription_expired:
await self.run_executor(self._write_user_info)

if self.started is None:
return

if not self.started and not self.subscription_expired:
self.started = True
self.run_task(self._start())

elif not is_stopped and self.subscription_expired:
elif self.started and self.subscription_expired:
self.started = False
await self.stop()

def register_on_start(self, on_start_cb: Callable[[], Awaitable[None]]):
Expand Down Expand Up @@ -201,6 +208,7 @@ async def logout(self) -> None:
self.access_token = None
self.refresh_token = None

self.started = False
await self.stop()

# Cleanup auth data
Expand All @@ -209,7 +217,7 @@ async def logout(self) -> None:

await self.client.logout_cleanups()

def write_user_info(self) -> None:
def _write_user_info(self) -> None:
"""Write user info to a file."""
base_path = self.path()
if not base_path.exists():
Expand Down Expand Up @@ -257,19 +265,21 @@ def load_config():
info = await self.run_executor(load_config)
if info is None:
# No previous token data
self.started = False
return

self.id_token = info["id_token"]
self.access_token = info["access_token"]
self.refresh_token = info["refresh_token"]

orig_token = self.id_token

await self.auth.async_check_token()

# A refresh will trigger a start, so only start if we didn't refresh
if self.id_token == orig_token and not self.subscription_expired:
await self._start()
if self.subscription_expired:
self.started = False
return

self.started = True
await self._start()

async def _start(self):
"""Start the cloud component."""
Expand Down
2 changes: 0 additions & 2 deletions hass_nabucasa/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,6 @@ async def async_login(self, email, password):
await self.cloud.update_token(
cognito.id_token, cognito.access_token, cognito.refresh_token
)
await self.cloud.run_executor(self.cloud.write_user_info)

except ForceChangePasswordException as err:
raise PasswordChangeRequired() from err
Expand Down Expand Up @@ -198,7 +197,6 @@ async def _async_renew_access_token(self):
try:
await self.cloud.run_executor(cognito.renew_access_token)
await self.cloud.update_token(cognito.id_token, cognito.access_token)
await self.cloud.run_executor(self.cloud.write_user_info)

except ClientError as err:
raise _map_aws_exception(err) from err
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from setuptools import setup

VERSION = "0.47.1"
VERSION = "0.48.0"

setup(
name="hass-nabucasa",
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ async def update_token(id_token, access_token, refresh_token=None):
if refresh_token is not None:
cloud.refresh_token = refresh_token

cloud.update_token = update_token
cloud.update_token = MagicMock(side_effect=update_token)

yield cloud

Expand Down
19 changes: 9 additions & 10 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ async def test_login_invalid_auth(mock_cognito, mock_cloud):
with pytest.raises(auth_api.Unauthenticated):
await auth.async_login("user", "pass")

assert len(mock_cloud.write_user_info.mock_calls) == 0
assert len(mock_cloud.update_token.mock_calls) == 0


async def test_login_user_not_found(mock_cognito, mock_cloud):
Expand All @@ -41,7 +41,7 @@ async def test_login_user_not_found(mock_cognito, mock_cloud):
with pytest.raises(auth_api.UserNotFound):
await auth.async_login("user", "pass")

assert len(mock_cloud.write_user_info.mock_calls) == 0
assert len(mock_cloud.update_token.mock_calls) == 0


async def test_login_user_not_confirmed(mock_cognito, mock_cloud):
Expand All @@ -52,7 +52,7 @@ async def test_login_user_not_confirmed(mock_cognito, mock_cloud):
with pytest.raises(auth_api.UserNotConfirmed):
await auth.async_login("user", "pass")

assert len(mock_cloud.write_user_info.mock_calls) == 0
assert len(mock_cloud.update_token.mock_calls) == 0


async def test_login(mock_cognito, mock_cloud):
Expand All @@ -65,10 +65,9 @@ async def test_login(mock_cognito, mock_cloud):
await auth.async_login("user", "pass")

assert len(mock_cognito.authenticate.mock_calls) == 1
assert mock_cloud.id_token == "test_id_token"
assert mock_cloud.access_token == "test_access_token"
assert mock_cloud.refresh_token == "test_refresh_token"
assert len(mock_cloud.write_user_info.mock_calls) == 1
mock_cloud.update_token.assert_called_once_with(
"test_id_token", "test_access_token", "test_refresh_token"
)


async def test_register(mock_cognito, cloud_mock):
Expand Down Expand Up @@ -131,7 +130,7 @@ async def test_check_token_writes_new_token_on_refresh(mock_cognito, cloud_mock)
assert len(mock_cognito.check_token.mock_calls) == 1
assert cloud_mock.id_token == "new id token"
assert cloud_mock.access_token == "new access token"
assert len(cloud_mock.write_user_info.mock_calls) == 1
cloud_mock.update_token.assert_called_once_with("new id token", "new access token")


async def test_check_token_does_not_write_existing_token(mock_cognito, cloud_mock):
Expand All @@ -144,7 +143,7 @@ async def test_check_token_does_not_write_existing_token(mock_cognito, cloud_moc
assert len(mock_cognito.check_token.mock_calls) == 1
assert cloud_mock.id_token != mock_cognito.id_token
assert cloud_mock.access_token != mock_cognito.access_token
assert len(cloud_mock.write_user_info.mock_calls) == 0
assert len(cloud_mock.update_token.mock_calls) == 0


async def test_check_token_raises(mock_cognito, cloud_mock):
Expand All @@ -158,7 +157,7 @@ async def test_check_token_raises(mock_cognito, cloud_mock):
assert len(mock_cognito.check_token.mock_calls) == 2
assert cloud_mock.id_token != mock_cognito.id_token
assert cloud_mock.access_token != mock_cognito.access_token
assert len(cloud_mock.write_user_info.mock_calls) == 0
assert len(cloud_mock.update_token.mock_calls) == 0


async def test_async_setup(cloud_mock):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def test_write_user_info(cloud_client):
cl.refresh_token = "test-refresh-token"

with patch("pathlib.Path.chmod"), patch("hass_nabucasa.atomic_write") as mock_write:
cl.write_user_info()
cl._write_user_info()

mock_file = mock_write.return_value.__enter__.return_value

Expand Down

0 comments on commit 78ee314

Please sign in to comment.