From 35e360e4cacd917fa959d1c3658512057e23bc92 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Tue, 6 Feb 2024 16:51:53 +0100 Subject: [PATCH] Add method to remove all cloud data (#564) * Add method to remove all cloud data * Avoid protected access * Add tests * Format code --- hass_nabucasa/__init__.py | 24 +++++++++++++++++++++++ hass_nabucasa/acme.py | 17 ++++++++++++++++ hass_nabucasa/remote.py | 6 ++++++ tests/common.py | 9 +++------ tests/conftest.py | 4 ++-- tests/test_init.py | 41 +++++++++++++++++++++++++++++++++++++++ 6 files changed, 93 insertions(+), 8 deletions(-) diff --git a/hass_nabucasa/__init__.py b/hass_nabucasa/__init__.py index 3aec35e6a..07407f0dd 100644 --- a/hass_nabucasa/__init__.py +++ b/hass_nabucasa/__init__.py @@ -5,7 +5,9 @@ from datetime import datetime, timedelta import json import logging +import os from pathlib import Path +import shutil from typing import Any, Awaitable, Callable, Generic, Literal, Mapping, TypeVar import aiohttp @@ -230,6 +232,28 @@ async def logout(self) -> None: await self.client.logout_cleanups() + async def remove_data(self) -> None: + """Remove all stored data.""" + if self.started: + raise ValueError("Cloud not stopped") + + try: + await self.remote.reset_acme() + finally: + await self.run_executor(self._remove_data) + + def _remove_data(self) -> None: + """Remove all stored data.""" + base_path = self.path() + + # Recursively remove .cloud + if os.path.isdir(base_path): + shutil.rmtree(base_path) + + # Guard against .cloud not being a directory + if base_path.exists(): + base_path.unlink() + def _write_user_info(self) -> None: """Write user info to a file.""" base_path = self.path() diff --git a/hass_nabucasa/acme.py b/hass_nabucasa/acme.py index 759f05910..eade1df29 100644 --- a/hass_nabucasa/acme.py +++ b/hass_nabucasa/acme.py @@ -410,6 +410,14 @@ def _deactivate_account(self) -> None: except errors.Error as err: raise AcmeClientError(f"Can't deactivate account: {err}") from err + def _have_any_file(self) -> bool: + return ( + self.path_registration_info.exists() + or self.path_account_key.exists() + or self.path_fullchain.exists() + or self.path_private_key.exists() + ) + def _remove_files(self) -> None: self.path_registration_info.unlink(missing_ok=True) self.path_account_key.unlink(missing_ok=True) @@ -477,6 +485,15 @@ async def issue_certificate(self) -> None: async def reset_acme(self) -> None: """Revoke and deactivate acme certificate/account.""" _LOGGER.info("Revoke and deactivate ACME user/certificate") + if ( + self._acme_client is None + and self._account_jwk is None + and self._x509 is None + and not await self.cloud.run_executor(self._have_any_file) + ): + _LOGGER.info("ACME user/certificates already cleaned up") + return + if not self._acme_client: await self.cloud.run_executor(self._create_client) diff --git a/hass_nabucasa/remote.py b/hass_nabucasa/remote.py index 79eaeb03f..92fcc1781 100644 --- a/hass_nabucasa/remote.py +++ b/hass_nabucasa/remote.py @@ -622,3 +622,9 @@ async def _should_renew_certificates(self) -> bool: self._acme.email, ) return True + + async def reset_acme(self) -> None: + """Reset the ACME client.""" + if not self._acme: + return + await self._acme.reset_acme() diff --git a/tests/common.py b/tests/common.py index c9f5a9675..d02d0c91c 100644 --- a/tests/common.py +++ b/tests/common.py @@ -3,7 +3,6 @@ import asyncio from pathlib import Path -import tempfile from typing import Any, Coroutine, Literal from unittest.mock import Mock @@ -13,7 +12,7 @@ class MockClient(CloudClient): """Interface class for Home Assistant.""" - def __init__(self, loop, websession): + def __init__(self, base_path, loop, websession): """Initialize MockClient.""" self._loop = loop self._websession = websession @@ -31,15 +30,13 @@ def __init__(self, loop, websession): self.mock_connection_info = [] self.mock_return = [] - self._base_path = None + self._base_path = base_path self.pref_should_connect = False @property - def base_path(self): + def base_path(self) -> Path: """Return path to base dir.""" - if self._base_path is None: - self._base_path = Path(tempfile.gettempdir()) return self._base_path @property diff --git a/tests/conftest.py b/tests/conftest.py index 1e578a021..40445a97a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -26,7 +26,7 @@ async def aioclient_mock(loop): @pytest.fixture -async def cloud_mock(loop, aioclient_mock): +async def cloud_mock(loop, aioclient_mock, tmp_path): """Yield a simple cloud mock.""" cloud = MagicMock(name="Mock Cloud", is_logged_in=True) @@ -37,7 +37,7 @@ def _executor(call, *args): cloud.run_executor = _executor cloud.websession = aioclient_mock.create_session(loop) - cloud.client = MockClient(loop, cloud.websession) + cloud.client = MockClient(tmp_path, loop, cloud.websession) async def update_token(id_token, access_token, refresh_token=None): cloud.id_token = id_token diff --git a/tests/test_init.py b/tests/test_init.py index 87ddd0238..bd824fdd2 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -3,9 +3,13 @@ import json from unittest.mock import AsyncMock, patch, MagicMock, Mock, PropertyMock +import pytest + import hass_nabucasa as cloud from hass_nabucasa.utils import utcnow +from .common import MockClient + def test_constructor_loads_info_from_constant(cloud_client): """Test non-dev mode loads info from SERVERS constant.""" @@ -186,6 +190,43 @@ async def test_logout_clears_info(cloud_client): assert info_file.unlink.called +async def test_remove_data(cloud_client: MockClient) -> None: + """Test removing data.""" + cloud_dir = cloud_client.base_path / ".cloud" + cloud_dir.mkdir() + open(cloud_dir / "unexpected_file", "w") + + cl = cloud.Cloud(cloud_client, cloud.MODE_DEV) + await cl.remove_data() + + assert not cloud_dir.exists() + + +async def test_remove_data_file(cloud_client: MockClient) -> None: + """Test removing data when .cloud is not a directory.""" + cloud_dir = cloud_client.base_path / ".cloud" + open(cloud_dir, "w") + + cl = cloud.Cloud(cloud_client, cloud.MODE_DEV) + await cl.remove_data() + + assert not cloud_dir.exists() + + +async def test_remove_data_started(cloud_client: MockClient) -> None: + """Test removing data when cloud is started.""" + cloud_dir = cloud_client.base_path / ".cloud" + cloud_dir.mkdir() + + cl = cloud.Cloud(cloud_client, cloud.MODE_DEV) + cl.started = True + with pytest.raises(ValueError): + await cl.remove_data() + + assert cloud_dir.exists() + cloud_dir.rmdir() + + def test_write_user_info(cloud_client): """Test writing user info works.""" cl = cloud.Cloud(cloud_client, cloud.MODE_DEV)