Skip to content

Commit

Permalink
Add method to remove all cloud data (#564)
Browse files Browse the repository at this point in the history
* Add method to remove all cloud data

* Avoid protected access

* Add tests

* Format code
  • Loading branch information
emontnemery authored Feb 6, 2024
1 parent d92a04a commit 35e360e
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 8 deletions.
24 changes: 24 additions & 0 deletions hass_nabucasa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from datetime import datetime, timedelta
import json
import logging
import os
from pathlib import Path
import shutil
from typing import Any, Awaitable, Callable, Generic, Literal, Mapping, TypeVar

import aiohttp
Expand Down Expand Up @@ -230,6 +232,28 @@ async def logout(self) -> None:

await self.client.logout_cleanups()

async def remove_data(self) -> None:
"""Remove all stored data."""
if self.started:
raise ValueError("Cloud not stopped")

try:
await self.remote.reset_acme()
finally:
await self.run_executor(self._remove_data)

def _remove_data(self) -> None:
"""Remove all stored data."""
base_path = self.path()

# Recursively remove .cloud
if os.path.isdir(base_path):
shutil.rmtree(base_path)

# Guard against .cloud not being a directory
if base_path.exists():
base_path.unlink()

def _write_user_info(self) -> None:
"""Write user info to a file."""
base_path = self.path()
Expand Down
17 changes: 17 additions & 0 deletions hass_nabucasa/acme.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,14 @@ def _deactivate_account(self) -> None:
except errors.Error as err:
raise AcmeClientError(f"Can't deactivate account: {err}") from err

def _have_any_file(self) -> bool:
return (
self.path_registration_info.exists()
or self.path_account_key.exists()
or self.path_fullchain.exists()
or self.path_private_key.exists()
)

def _remove_files(self) -> None:
self.path_registration_info.unlink(missing_ok=True)
self.path_account_key.unlink(missing_ok=True)
Expand Down Expand Up @@ -477,6 +485,15 @@ async def issue_certificate(self) -> None:
async def reset_acme(self) -> None:
"""Revoke and deactivate acme certificate/account."""
_LOGGER.info("Revoke and deactivate ACME user/certificate")
if (
self._acme_client is None
and self._account_jwk is None
and self._x509 is None
and not await self.cloud.run_executor(self._have_any_file)
):
_LOGGER.info("ACME user/certificates already cleaned up")
return

if not self._acme_client:
await self.cloud.run_executor(self._create_client)

Expand Down
6 changes: 6 additions & 0 deletions hass_nabucasa/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,3 +622,9 @@ async def _should_renew_certificates(self) -> bool:
self._acme.email,
)
return True

async def reset_acme(self) -> None:
"""Reset the ACME client."""
if not self._acme:
return
await self._acme.reset_acme()
9 changes: 3 additions & 6 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import asyncio
from pathlib import Path
import tempfile
from typing import Any, Coroutine, Literal
from unittest.mock import Mock

Expand All @@ -13,7 +12,7 @@
class MockClient(CloudClient):
"""Interface class for Home Assistant."""

def __init__(self, loop, websession):
def __init__(self, base_path, loop, websession):
"""Initialize MockClient."""
self._loop = loop
self._websession = websession
Expand All @@ -31,15 +30,13 @@ def __init__(self, loop, websession):
self.mock_connection_info = []

self.mock_return = []
self._base_path = None
self._base_path = base_path

self.pref_should_connect = False

@property
def base_path(self):
def base_path(self) -> Path:
"""Return path to base dir."""
if self._base_path is None:
self._base_path = Path(tempfile.gettempdir())
return self._base_path

@property
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ async def aioclient_mock(loop):


@pytest.fixture
async def cloud_mock(loop, aioclient_mock):
async def cloud_mock(loop, aioclient_mock, tmp_path):
"""Yield a simple cloud mock."""
cloud = MagicMock(name="Mock Cloud", is_logged_in=True)

Expand All @@ -37,7 +37,7 @@ def _executor(call, *args):
cloud.run_executor = _executor

cloud.websession = aioclient_mock.create_session(loop)
cloud.client = MockClient(loop, cloud.websession)
cloud.client = MockClient(tmp_path, loop, cloud.websession)

async def update_token(id_token, access_token, refresh_token=None):
cloud.id_token = id_token
Expand Down
41 changes: 41 additions & 0 deletions tests/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@
import json
from unittest.mock import AsyncMock, patch, MagicMock, Mock, PropertyMock

import pytest

import hass_nabucasa as cloud
from hass_nabucasa.utils import utcnow

from .common import MockClient


def test_constructor_loads_info_from_constant(cloud_client):
"""Test non-dev mode loads info from SERVERS constant."""
Expand Down Expand Up @@ -186,6 +190,43 @@ async def test_logout_clears_info(cloud_client):
assert info_file.unlink.called


async def test_remove_data(cloud_client: MockClient) -> None:
"""Test removing data."""
cloud_dir = cloud_client.base_path / ".cloud"
cloud_dir.mkdir()
open(cloud_dir / "unexpected_file", "w")

cl = cloud.Cloud(cloud_client, cloud.MODE_DEV)
await cl.remove_data()

assert not cloud_dir.exists()


async def test_remove_data_file(cloud_client: MockClient) -> None:
"""Test removing data when .cloud is not a directory."""
cloud_dir = cloud_client.base_path / ".cloud"
open(cloud_dir, "w")

cl = cloud.Cloud(cloud_client, cloud.MODE_DEV)
await cl.remove_data()

assert not cloud_dir.exists()


async def test_remove_data_started(cloud_client: MockClient) -> None:
"""Test removing data when cloud is started."""
cloud_dir = cloud_client.base_path / ".cloud"
cloud_dir.mkdir()

cl = cloud.Cloud(cloud_client, cloud.MODE_DEV)
cl.started = True
with pytest.raises(ValueError):
await cl.remove_data()

assert cloud_dir.exists()
cloud_dir.rmdir()


def test_write_user_info(cloud_client):
"""Test writing user info works."""
cl = cloud.Cloud(cloud_client, cloud.MODE_DEV)
Expand Down

0 comments on commit 35e360e

Please sign in to comment.