From 992713ee306986c9192363b14da9a41fe87886aa Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Thu, 16 Nov 2023 16:49:39 -0600 Subject: [PATCH 01/18] First draft of Wyoming satellite --- homeassistant/components/wyoming/__init__.py | 7 + .../components/wyoming/config_flow.py | 13 +- .../components/wyoming/manifest.json | 3 +- homeassistant/components/wyoming/satellite.py | 259 ++++++++++++++++++ requirements_all.txt | 2 +- requirements_test_all.txt | 2 +- 6 files changed, 282 insertions(+), 4 deletions(-) create mode 100644 homeassistant/components/wyoming/satellite.py diff --git a/homeassistant/components/wyoming/__init__.py b/homeassistant/components/wyoming/__init__.py index 33064d2109755..66c75f8f81076 100644 --- a/homeassistant/components/wyoming/__init__.py +++ b/homeassistant/components/wyoming/__init__.py @@ -9,6 +9,7 @@ from .const import ATTR_SPEAKER, DOMAIN from .data import WyomingService +from .satellite import WyomingSatellite _LOGGER = logging.getLogger(__name__) @@ -32,6 +33,12 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: service.platforms, ) + if service.info.satellite is not None: + hass.async_create_background_task( + WyomingSatellite(hass, service).run(), + f"Satellite {service.info.satellite.name}", + ) + return True diff --git a/homeassistant/components/wyoming/config_flow.py b/homeassistant/components/wyoming/config_flow.py index f6b8ed7389095..ae1ea2f513d3f 100644 --- a/homeassistant/components/wyoming/config_flow.py +++ b/homeassistant/components/wyoming/config_flow.py @@ -1,7 +1,7 @@ """Config flow for Wyoming integration.""" from __future__ import annotations -from typing import Any +from typing import TYPE_CHECKING, Any from urllib.parse import urlparse import voluptuous as vol @@ -11,6 +11,9 @@ from homeassistant.const import CONF_HOST, CONF_PORT from homeassistant.data_entry_flow import FlowResult +if TYPE_CHECKING: + from wyoming.info import Satellite + from .const import DOMAIN from .data import WyomingService @@ -59,12 +62,20 @@ async def async_step_user( # wake-word-detection wake_installed = [wake for wake in service.info.wake if wake.installed] + # satellite + satellite_installed: Satellite | None = None + + if (service.info.satellite is not None) and service.info.satellite.installed: + satellite_installed = service.info.satellite + if asr_installed: name = asr_installed[0].name elif tts_installed: name = tts_installed[0].name elif wake_installed: name = wake_installed[0].name + elif satellite_installed: + name = satellite_installed.name else: return self.async_abort(reason="no_services") diff --git a/homeassistant/components/wyoming/manifest.json b/homeassistant/components/wyoming/manifest.json index ddb5407e1cea5..f1a7faac81f90 100644 --- a/homeassistant/components/wyoming/manifest.json +++ b/homeassistant/components/wyoming/manifest.json @@ -3,7 +3,8 @@ "name": "Wyoming Protocol", "codeowners": ["@balloob", "@synesthesiam"], "config_flow": true, + "dependencies": ["assist_pipeline"], "documentation": "https://www.home-assistant.io/integrations/wyoming", "iot_class": "local_push", - "requirements": ["wyoming==1.2.0"] + "requirements": ["wyoming==1.3.0"] } diff --git a/homeassistant/components/wyoming/satellite.py b/homeassistant/components/wyoming/satellite.py new file mode 100644 index 0000000000000..bca99f30ec616 --- /dev/null +++ b/homeassistant/components/wyoming/satellite.py @@ -0,0 +1,259 @@ +"""Support for Wyoming satellite services.""" +import asyncio +from collections.abc import AsyncGenerator +import io +import logging +from typing import Final +import wave + +from wyoming.asr import Transcript +from wyoming.audio import AudioChunk, AudioChunkConverter, AudioStart, AudioStop +from wyoming.client import AsyncTcpClient +from wyoming.pipeline import PipelineStage, RunPipeline +from wyoming.satellite import RunSatellite +from wyoming.tts import Synthesize, SynthesizeVoice +from wyoming.wake import Detection + +from homeassistant.components import assist_pipeline, stt, tts +from homeassistant.core import Context, HomeAssistant + +from .data import WyomingService + +_LOGGER = logging.getLogger() + +_SAMPLES_PER_CHUNK: Final = 1024 +_RECONNECT_SECONDS: Final = 10 +_RESTART_SECONDS: Final = 3 + + +class WyomingSatellite: + """Remove voice satellite running the Wyoming protocol.""" + + def __init__(self, hass: HomeAssistant, service: WyomingService) -> None: + """Initialize satellite.""" + self.hass = hass + self.service = service + self._client: AsyncTcpClient | None = None + self._chunk_converter = AudioChunkConverter(rate=16000, width=2, channels=1) + self._is_pipeline_running = False + + async def run(self) -> None: + """Run and maintain a connection to satellite.""" + while self.hass.is_running: + try: + await self._run_once() + except Exception: # pylint: disable=broad-exception-caught + _LOGGER.exception( + "Unexpected error running satellite. Restarting in %s second(s)", + _RECONNECT_SECONDS, + ) + await asyncio.sleep(_RESTART_SECONDS) + + async def _run_once(self) -> None: + """Run pipelines until an error occurs.""" + while True: + try: + await self._connect() + break + except ConnectionError: + _LOGGER.exception( + "Failed to connect to satellite. Reconnecting in %s second(s)", + _RECONNECT_SECONDS, + ) + await asyncio.sleep(_RECONNECT_SECONDS) + + assert self._client is not None + _LOGGER.debug("Connected to satellite") + + # Tell satellite that we're ready + await self._client.write_event(RunSatellite().event()) + + # Wait until we get RunPipeline event + run_pipeline: RunPipeline | None = None + while True: + run_event = await self._client.read_event() + if run_event is None: + raise ConnectionResetError("Satellite disconnected") + + if RunPipeline.is_type(run_event.type): + run_pipeline = RunPipeline.from_event(run_event) + break + + _LOGGER.debug("Unexpected event from satellite: %s", run_event) + + assert run_pipeline is not None + _LOGGER.debug("Received run information: %s", run_pipeline) + + start_stage = _convert_stage(run_pipeline.start_stage) + end_stage = _convert_stage(run_pipeline.end_stage) + + # Default pipeline + pipeline = assist_pipeline.async_get_pipeline(self.hass) + + while True: + # We will push audio in through a queue + audio_queue: asyncio.Queue[bytes] = asyncio.Queue() + stt_stream = _stt_stream(audio_queue) + + # Start pipeline running + _LOGGER.debug( + "Starting pipeline %s from %s to %s", + pipeline.name, + start_stage, + end_stage, + ) + self._is_pipeline_running = True + _pipeline_task = asyncio.create_task( + assist_pipeline.async_pipeline_from_audio_stream( + self.hass, + context=Context(), + event_callback=self._event_callback, + stt_metadata=stt.SpeechMetadata( + language=pipeline.language, + format=stt.AudioFormats.WAV, + codec=stt.AudioCodecs.PCM, + bit_rate=stt.AudioBitRates.BITRATE_16, + sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, + channel=stt.AudioChannels.CHANNEL_MONO, + ), + stt_stream=stt_stream, + start_stage=start_stage, + end_stage=end_stage, + tts_audio_output="wav", + pipeline_id=pipeline.id, + ) + ) + + while self._is_pipeline_running: + client_event = await self._client.read_event() + if client_event is None: + raise ConnectionResetError("Satellite disconnected") + + if AudioChunk.is_type(client_event.type): + # Microphone audio + chunk = AudioChunk.from_event(client_event) + chunk = self._chunk_converter.convert(chunk) + audio_queue.put_nowait(chunk.audio) + else: + _LOGGER.debug("Unexpected event from satellite: %s", client_event) + + _LOGGER.debug("Pipeline finished") + + def _event_callback(self, event: assist_pipeline.PipelineEvent) -> None: + """Translate pipeline events into Wyoming events.""" + assert self._client is not None + + if event.type == assist_pipeline.PipelineEventType.RUN_END: + self._is_pipeline_running = False + elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_END: + # Wake word detection + detection = Detection() + if event.data: + wake_word_output = event.data["wake_word_output"] + detection.name = wake_word_output["wake_word_id"] + detection.timestamp = wake_word_output.get("timestamp") + + self.hass.add_job(self._client.write_event(detection.event())) + elif event.type == assist_pipeline.PipelineEventType.STT_END: + # STT transcript + if event.data: + stt_text = event.data["stt_output"]["text"] + self.hass.add_job( + self._client.write_event(Transcript(text=stt_text).event()) + ) + elif event.type == assist_pipeline.PipelineEventType.TTS_START: + # TTS text + if event.data: + self.hass.add_job( + self._client.write_event( + Synthesize( + text=event.data["tts_input"], + voice=SynthesizeVoice( + name=event.data["voice"], + language=event.data["language"], + ), + ).event() + ) + ) + elif event.type == assist_pipeline.PipelineEventType.TTS_END: + # TTS stream + if event.data: + media_id = event.data["tts_output"]["media_id"] + self.hass.add_job(self._stream_tts(media_id)) + + async def _connect(self) -> None: + """Connect to satellite over TCP.""" + _LOGGER.debug( + "Connecting to satellite at %s:%s", self.service.host, self.service.port + ) + self._client = AsyncTcpClient(self.service.host, self.service.port) + await self._client.connect() + + async def _stream_tts(self, media_id: str) -> None: + """Stream TTS WAV audio to satellite in chunks.""" + assert self._client is not None + + extension, data = await tts.async_get_media_source_audio(self.hass, media_id) + if extension != "wav": + raise ValueError(f"Cannot stream audio format to satellite: {extension}") + + with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file: + sample_rate = wav_file.getframerate() + sample_width = wav_file.getsampwidth() + sample_channels = wav_file.getnchannels() + _LOGGER.debug("Streaming %s TTS sample(s)", wav_file.getnframes()) + + await self._client.write_event( + AudioStart( + rate=sample_rate, + width=sample_width, + channels=sample_channels, + ).event() + ) + + # Stream audio chunks + while audio_bytes := wav_file.readframes(_SAMPLES_PER_CHUNK): + await self._client.write_event( + AudioChunk( + rate=sample_rate, + width=sample_width, + channels=sample_channels, + audio=audio_bytes, + ).event() + ) + + await self._client.write_event(AudioStop().event()) + _LOGGER.debug("TTS streaming complete") + + +# ----------------------------------------------------------------------------- + + +async def _stt_stream( + audio_queue: asyncio.Queue[bytes], +) -> AsyncGenerator[bytes, None]: + """Yield audio chunks from a queue.""" + is_first_chunk = True + while chunk := await audio_queue.get(): + if is_first_chunk: + is_first_chunk = False + _LOGGER.debug("Receiving audio from satellite") + + yield chunk + + +def _convert_stage(wyoming_stage: PipelineStage) -> assist_pipeline.PipelineStage: + """Convert Wyoming pipeline stage to Assist pipeline stage.""" + if wyoming_stage == PipelineStage.WAKE: + return assist_pipeline.PipelineStage.WAKE_WORD + + if wyoming_stage == PipelineStage.ASR: + return assist_pipeline.PipelineStage.STT + + if wyoming_stage == PipelineStage.HANDLE: + return assist_pipeline.PipelineStage.INTENT + + if wyoming_stage == PipelineStage.TTS: + return assist_pipeline.PipelineStage.TTS + + raise ValueError(f"Unknown Wyoming pipeline stage: {wyoming_stage}") diff --git a/requirements_all.txt b/requirements_all.txt index d19df8c1c9aeb..7319f43c9c9bf 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -2753,7 +2753,7 @@ wled==0.17.0 wolf-smartset==0.1.11 # homeassistant.components.wyoming -wyoming==1.2.0 +wyoming==1.3.0 # homeassistant.components.xbox xbox-webapi==2.0.11 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index 5bf8f26c87407..6db5652c55250 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -2057,7 +2057,7 @@ wled==0.17.0 wolf-smartset==0.1.11 # homeassistant.components.wyoming -wyoming==1.2.0 +wyoming==1.3.0 # homeassistant.components.xbox xbox-webapi==2.0.11 From 6e07329d19a4cf7cbe91ad983a07efe18a626a75 Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Mon, 27 Nov 2023 15:58:06 -0600 Subject: [PATCH 02/18] Set up homeassistant in tests --- tests/components/wyoming/conftest.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/components/wyoming/conftest.py b/tests/components/wyoming/conftest.py index 2c8081908f772..d211b95e92720 100644 --- a/tests/components/wyoming/conftest.py +++ b/tests/components/wyoming/conftest.py @@ -7,12 +7,19 @@ from homeassistant.components import stt from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant +from homeassistant.setup import async_setup_component from . import STT_INFO, TTS_INFO, WAKE_WORD_INFO from tests.common import MockConfigEntry +@pytest.fixture(autouse=True) +async def init_components(hass: HomeAssistant): + """Set up required components.""" + assert await async_setup_component(hass, "homeassistant", {}) + + @pytest.fixture def mock_setup_entry() -> Generator[AsyncMock, None, None]: """Override async_setup_entry.""" From ca5d347f008872fa8ddbba507d00cb0210c9cc91 Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Mon, 27 Nov 2023 15:58:19 -0600 Subject: [PATCH 03/18] Move satellite --- homeassistant/components/wyoming/satellite/__init__.py | 7 +++++++ .../components/wyoming/{ => satellite}/satellite.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) create mode 100644 homeassistant/components/wyoming/satellite/__init__.py rename homeassistant/components/wyoming/{ => satellite}/satellite.py (99%) diff --git a/homeassistant/components/wyoming/satellite/__init__.py b/homeassistant/components/wyoming/satellite/__init__.py new file mode 100644 index 0000000000000..7da9af4884c8f --- /dev/null +++ b/homeassistant/components/wyoming/satellite/__init__.py @@ -0,0 +1,7 @@ +"""Support for Wyoming satellite services.""" + +from .satellite import WyomingSatellite + +__all__ = [ + "WyomingSatellite", +] diff --git a/homeassistant/components/wyoming/satellite.py b/homeassistant/components/wyoming/satellite/satellite.py similarity index 99% rename from homeassistant/components/wyoming/satellite.py rename to homeassistant/components/wyoming/satellite/satellite.py index bca99f30ec616..95a7b2c149935 100644 --- a/homeassistant/components/wyoming/satellite.py +++ b/homeassistant/components/wyoming/satellite/satellite.py @@ -17,7 +17,7 @@ from homeassistant.components import assist_pipeline, stt, tts from homeassistant.core import Context, HomeAssistant -from .data import WyomingService +from ..data import WyomingService _LOGGER = logging.getLogger() From 3bc83927321cd130532d3b237a5c9329cfe68d0a Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Tue, 28 Nov 2023 16:29:10 -0600 Subject: [PATCH 04/18] Add devices with binary sensor and select --- homeassistant/components/wyoming/__init__.py | 40 ++++-- .../components/wyoming/binary_sensor.py | 65 +++++++++ .../components/wyoming/config_flow.py | 103 ++++++++----- homeassistant/components/wyoming/data.py | 39 ++++- homeassistant/components/wyoming/entity.py | 24 ++++ .../components/wyoming/manifest.json | 3 +- homeassistant/components/wyoming/models.py | 13 ++ .../components/wyoming/satellite/__init__.py | 3 + .../components/wyoming/satellite/devices.py | 136 ++++++++++++++++++ .../components/wyoming/satellite/satellite.py | 123 ++++++++++++---- homeassistant/components/wyoming/select.py | 55 +++++++ homeassistant/components/wyoming/strings.json | 22 ++- homeassistant/components/wyoming/stt.py | 5 +- homeassistant/components/wyoming/tts.py | 5 +- homeassistant/components/wyoming/wake_word.py | 5 +- homeassistant/generated/zeroconf.py | 5 + 16 files changed, 559 insertions(+), 87 deletions(-) create mode 100644 homeassistant/components/wyoming/binary_sensor.py create mode 100644 homeassistant/components/wyoming/entity.py create mode 100644 homeassistant/components/wyoming/models.py create mode 100644 homeassistant/components/wyoming/satellite/devices.py create mode 100644 homeassistant/components/wyoming/select.py diff --git a/homeassistant/components/wyoming/__init__.py b/homeassistant/components/wyoming/__init__.py index 66c75f8f81076..b6dd66b533652 100644 --- a/homeassistant/components/wyoming/__init__.py +++ b/homeassistant/components/wyoming/__init__.py @@ -4,18 +4,23 @@ import logging from homeassistant.config_entries import ConfigEntry +from homeassistant.const import Platform from homeassistant.core import HomeAssistant from homeassistant.exceptions import ConfigEntryNotReady -from .const import ATTR_SPEAKER, DOMAIN +from .const import DOMAIN from .data import WyomingService -from .satellite import WyomingSatellite +from .models import DomainDataItem +from .satellite import SatelliteDevices, WyomingSatellite _LOGGER = logging.getLogger(__name__) +PLATFORMS = [Platform.BINARY_SENSOR, Platform.SELECT] + __all__ = [ - "ATTR_SPEAKER", "DOMAIN", + "async_setup_entry", + "async_unload_entry", ] @@ -26,29 +31,46 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: if service is None: raise ConfigEntryNotReady("Unable to connect") - hass.data.setdefault(DOMAIN, {})[entry.entry_id] = service + satellite_devices = SatelliteDevices(hass, entry) + satellite_devices.async_setup() + + item = DomainDataItem(service=service, satellite_devices=satellite_devices) + hass.data.setdefault(DOMAIN, {})[entry.entry_id] = item await hass.config_entries.async_forward_entry_setups( entry, - service.platforms, + service.platforms + PLATFORMS, ) + entry.async_on_unload(entry.add_update_listener(update_listener)) + if service.info.satellite is not None: + satellite_device = satellite_devices.async_get_or_create(item.service) + wyoming_satellite = WyomingSatellite(hass, service, satellite_device) hass.async_create_background_task( - WyomingSatellite(hass, service).run(), - f"Satellite {service.info.satellite.name}", + wyoming_satellite.run(), f"Satellite {item.service.info.satellite.name}" ) + def stop_satellite(): + wyoming_satellite.is_running = False + + entry.async_on_unload(stop_satellite) + return True +async def update_listener(hass: HomeAssistant, entry: ConfigEntry): + """Handle options update.""" + await hass.config_entries.async_reload(entry.entry_id) + + async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Unload Wyoming.""" - service: WyomingService = hass.data[DOMAIN][entry.entry_id] + item: DomainDataItem = hass.data[DOMAIN][entry.entry_id] unload_ok = await hass.config_entries.async_unload_platforms( entry, - service.platforms, + item.service.platforms, ) if unload_ok: del hass.data[DOMAIN][entry.entry_id] diff --git a/homeassistant/components/wyoming/binary_sensor.py b/homeassistant/components/wyoming/binary_sensor.py new file mode 100644 index 0000000000000..07efc4516587d --- /dev/null +++ b/homeassistant/components/wyoming/binary_sensor.py @@ -0,0 +1,65 @@ +"""Binary sensor for Wyoming.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from homeassistant.components.binary_sensor import ( + BinarySensorEntity, + BinarySensorEntityDescription, +) +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import HomeAssistant, callback +from homeassistant.helpers.entity_platform import AddEntitiesCallback + +from .const import DOMAIN +from .entity import WyomingSatelliteEntity +from .satellite import SatelliteDevice + +if TYPE_CHECKING: + from .models import DomainDataItem + + +async def async_setup_entry( + hass: HomeAssistant, + config_entry: ConfigEntry, + async_add_entities: AddEntitiesCallback, +) -> None: + """Set up binary sensor entities.""" + domain_data: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id] + + @callback + def async_add_device(device: SatelliteDevice) -> None: + """Add device.""" + async_add_entities([WyomingSatelliteAssistInProgress(device)]) + + domain_data.satellite_devices.async_add_new_device_listener(async_add_device) + + async_add_entities( + [ + WyomingSatelliteAssistInProgress(device) + for device in domain_data.satellite_devices + ] + ) + + +class WyomingSatelliteAssistInProgress(WyomingSatelliteEntity, BinarySensorEntity): + """Entity to represent Assist is in progress for satellite.""" + + entity_description = BinarySensorEntityDescription( + key="assist_in_progress", + translation_key="assist_in_progress", + ) + _attr_is_on = False + + async def async_added_to_hass(self) -> None: + """Call when entity about to be added to hass.""" + await super().async_added_to_hass() + + self.async_on_remove(self._device.async_listen_update(self._is_active_changed)) + + @callback + def _is_active_changed(self, device: SatelliteDevice) -> None: + """Call when active state changed.""" + self._attr_is_on = self._device.is_active + self.async_write_ha_state() diff --git a/homeassistant/components/wyoming/config_flow.py b/homeassistant/components/wyoming/config_flow.py index ae1ea2f513d3f..bdf62bbb7eee6 100644 --- a/homeassistant/components/wyoming/config_flow.py +++ b/homeassistant/components/wyoming/config_flow.py @@ -1,22 +1,22 @@ """Config flow for Wyoming integration.""" from __future__ import annotations -from typing import TYPE_CHECKING, Any +import logging +from typing import Any from urllib.parse import urlparse import voluptuous as vol from homeassistant import config_entries -from homeassistant.components.hassio import HassioServiceInfo -from homeassistant.const import CONF_HOST, CONF_PORT +from homeassistant.components import hassio, zeroconf +from homeassistant.const import CONF_HOST, CONF_NAME, CONF_PORT from homeassistant.data_entry_flow import FlowResult -if TYPE_CHECKING: - from wyoming.info import Satellite - from .const import DOMAIN from .data import WyomingService +_LOGGER = logging.getLogger() + STEP_USER_DATA_SCHEMA = vol.Schema( { vol.Required(CONF_HOST): str, @@ -30,7 +30,8 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): VERSION = 1 - _hassio_discovery: HassioServiceInfo + _hassio_discovery: hassio.HassioServiceInfo + _service: WyomingService | None = None async def async_step_user( self, user_input: dict[str, Any] | None = None @@ -53,35 +54,14 @@ async def async_step_user( errors={"base": "cannot_connect"}, ) - # ASR = automated speech recognition (speech-to-text) - asr_installed = [asr for asr in service.info.asr if asr.installed] - - # TTS = text-to-speech - tts_installed = [tts for tts in service.info.tts if tts.installed] - - # wake-word-detection - wake_installed = [wake for wake in service.info.wake if wake.installed] - - # satellite - satellite_installed: Satellite | None = None - - if (service.info.satellite is not None) and service.info.satellite.installed: - satellite_installed = service.info.satellite - - if asr_installed: - name = asr_installed[0].name - elif tts_installed: - name = tts_installed[0].name - elif wake_installed: - name = wake_installed[0].name - elif satellite_installed: - name = satellite_installed.name - else: - return self.async_abort(reason="no_services") + if name := service.get_name(): + return self.async_create_entry(title=name, data=user_input) - return self.async_create_entry(title=name, data=user_input) + return self.async_abort(reason="no_services") - async def async_step_hassio(self, discovery_info: HassioServiceInfo) -> FlowResult: + async def async_step_hassio( + self, discovery_info: hassio.HassioServiceInfo + ) -> FlowResult: """Handle Supervisor add-on discovery.""" await self.async_set_unique_id(discovery_info.uuid) self._abort_if_unique_id_configured() @@ -104,11 +84,7 @@ async def async_step_hassio_confirm( if user_input is not None: uri = urlparse(self._hassio_discovery.config["uri"]) if service := await WyomingService.create(uri.hostname, uri.port): - if ( - not any(asr for asr in service.info.asr if asr.installed) - and not any(tts for tts in service.info.tts if tts.installed) - and not any(wake for wake in service.info.wake if wake.installed) - ): + if not service.has_services(): return self.async_abort(reason="no_services") return self.async_create_entry( @@ -123,3 +99,52 @@ async def async_step_hassio_confirm( description_placeholders={"addon": self._hassio_discovery.name}, errors=errors, ) + + async def async_step_zeroconf( + self, discovery_info: zeroconf.ZeroconfServiceInfo + ) -> FlowResult: + """Handle zeroconf discovery.""" + _LOGGER.debug("Discovery info: %s", discovery_info) + if discovery_info.port is None: + return self.async_abort(reason="no_port") + + service = await WyomingService.create(discovery_info.host, discovery_info.port) + if (service is None) or (not (name := service.get_name())): + return self.async_abort(reason="no_services") + + self.context[CONF_NAME] = name + self.context["title_placeholders"] = {"name": name} + + uuid = f"wyoming_{service.host}_{service.port}" + + await self.async_set_unique_id(uuid) + self._abort_if_unique_id_configured() + + self._service = service + return await self.async_step_zeroconf_confirm() + + async def async_step_zeroconf_confirm( + self, user_input: dict[str, Any] | None = None + ) -> FlowResult: + """Handle a flow initiated by zeroconf.""" + if ( + (self._service is None) + or (not self._service.has_services()) + or (not (name := self._service.get_name())) + ): + return self.async_abort(reason="no_services") + + if user_input is None: + return self.async_show_form( + step_id="zeroconf_confirm", + description_placeholders={"name": name}, + errors={}, + ) + + return self.async_create_entry( + title=name, + data={ + CONF_HOST: self._service.host, + CONF_PORT: self._service.port, + }, + ) diff --git a/homeassistant/components/wyoming/data.py b/homeassistant/components/wyoming/data.py index 64b92eb847177..ea58181a7074d 100644 --- a/homeassistant/components/wyoming/data.py +++ b/homeassistant/components/wyoming/data.py @@ -4,7 +4,7 @@ import asyncio from wyoming.client import AsyncTcpClient -from wyoming.info import Describe, Info +from wyoming.info import Describe, Info, Satellite from homeassistant.const import Platform @@ -32,6 +32,43 @@ def __init__(self, host: str, port: int, info: Info) -> None: platforms.append(Platform.WAKE_WORD) self.platforms = platforms + def has_services(self) -> bool: + """Return True if services are installed that Home Assistant can use.""" + return ( + any(asr for asr in self.info.asr if asr.installed) + or any(tts for tts in self.info.tts if tts.installed) + or any(wake for wake in self.info.wake if wake.installed) + or ((self.info.satellite is not None) and self.info.satellite.installed) + ) + + def get_name(self) -> str | None: + """Return name of first installed usable service.""" + # ASR = automated speech recognition (speech-to-text) + asr_installed = [asr for asr in self.info.asr if asr.installed] + if asr_installed: + return asr_installed[0].name + + # TTS = text-to-speech + tts_installed = [tts for tts in self.info.tts if tts.installed] + if tts_installed: + return tts_installed[0].name + + # wake-word-detection + wake_installed = [wake for wake in self.info.wake if wake.installed] + if wake_installed: + return wake_installed[0].name + + # satellite + satellite_installed: Satellite | None = None + + if (self.info.satellite is not None) and self.info.satellite.installed: + satellite_installed = self.info.satellite + + if satellite_installed: + return satellite_installed.name + + return None + @classmethod async def create(cls, host: str, port: int) -> WyomingService | None: """Create a Wyoming service.""" diff --git a/homeassistant/components/wyoming/entity.py b/homeassistant/components/wyoming/entity.py new file mode 100644 index 0000000000000..5ed890bc60e7d --- /dev/null +++ b/homeassistant/components/wyoming/entity.py @@ -0,0 +1,24 @@ +"""Wyoming entities.""" + +from __future__ import annotations + +from homeassistant.helpers import entity +from homeassistant.helpers.device_registry import DeviceInfo + +from .const import DOMAIN +from .satellite import SatelliteDevice + + +class WyomingSatelliteEntity(entity.Entity): + """Wyoming satellite entity.""" + + _attr_has_entity_name = True + _attr_should_poll = False + + def __init__(self, device: SatelliteDevice) -> None: + """Initialize entity.""" + self._device = device + self._attr_unique_id = f"{device.satellite_id}-{self.entity_description.key}" + self._attr_device_info = DeviceInfo( + identifiers={(DOMAIN, device.satellite_id)}, + ) diff --git a/homeassistant/components/wyoming/manifest.json b/homeassistant/components/wyoming/manifest.json index f1a7faac81f90..540aaa9aeac09 100644 --- a/homeassistant/components/wyoming/manifest.json +++ b/homeassistant/components/wyoming/manifest.json @@ -6,5 +6,6 @@ "dependencies": ["assist_pipeline"], "documentation": "https://www.home-assistant.io/integrations/wyoming", "iot_class": "local_push", - "requirements": ["wyoming==1.3.0"] + "requirements": ["wyoming==1.3.0"], + "zeroconf": ["_wyoming._tcp.local."] } diff --git a/homeassistant/components/wyoming/models.py b/homeassistant/components/wyoming/models.py new file mode 100644 index 0000000000000..18579c9000094 --- /dev/null +++ b/homeassistant/components/wyoming/models.py @@ -0,0 +1,13 @@ +"""Models for wyoming.""" +from dataclasses import dataclass + +from .data import WyomingService +from .satellite import SatelliteDevices + + +@dataclass +class DomainDataItem: + """Domain data item.""" + + service: WyomingService + satellite_devices: SatelliteDevices diff --git a/homeassistant/components/wyoming/satellite/__init__.py b/homeassistant/components/wyoming/satellite/__init__.py index 7da9af4884c8f..006cc9ff7f795 100644 --- a/homeassistant/components/wyoming/satellite/__init__.py +++ b/homeassistant/components/wyoming/satellite/__init__.py @@ -1,7 +1,10 @@ """Support for Wyoming satellite services.""" +from .devices import SatelliteDevice, SatelliteDevices from .satellite import WyomingSatellite __all__ = [ + "SatelliteDevice", + "SatelliteDevices", "WyomingSatellite", ] diff --git a/homeassistant/components/wyoming/satellite/devices.py b/homeassistant/components/wyoming/satellite/devices.py new file mode 100644 index 0000000000000..734854f768235 --- /dev/null +++ b/homeassistant/components/wyoming/satellite/devices.py @@ -0,0 +1,136 @@ +"""Class to manage satellite devices.""" +from __future__ import annotations + +from collections.abc import Callable, Iterator +from dataclasses import dataclass, field + +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import Event, HomeAssistant, callback +from homeassistant.helpers import area_registry as ar, device_registry as dr + +from ..const import DOMAIN +from ..data import WyomingService + + +@dataclass +class SatelliteDevice: + """Class to store device.""" + + satellite_id: str + device_id: str + is_active: bool = False + update_listeners: list[Callable[[SatelliteDevice], None]] = field( + default_factory=list + ) + + @callback + def set_is_active(self, active: bool) -> None: + """Set active state.""" + self.is_active = active + for listener in self.update_listeners: + listener(self) + + @callback + def async_pipeline_changed(self) -> None: + """Inform listeners that pipeline selection has changed.""" + for listener in self.update_listeners: + listener(self) + + @callback + def async_listen_update( + self, listener: Callable[[SatelliteDevice], None] + ) -> Callable[[], None]: + """Listen for updates.""" + self.update_listeners.append(listener) + return lambda: self.update_listeners.remove(listener) + + +class SatelliteDevices: + """Class to store devices.""" + + def __init__(self, hass: HomeAssistant, config_entry: ConfigEntry) -> None: + """Initialize satellite devices.""" + self.hass = hass + self.config_entry = config_entry + self._new_device_listeners: list[Callable[[SatelliteDevice], None]] = [] + self.devices: dict[str, SatelliteDevice] = {} + + @callback + def async_setup(self) -> None: + """Set up devices.""" + for device in dr.async_entries_for_config_entry( + dr.async_get(self.hass), self.config_entry.entry_id + ): + satellite_id = next( + (item[1] for item in device.identifiers if item[0] == DOMAIN), None + ) + if satellite_id is None: + continue + self.devices[satellite_id] = SatelliteDevice( + satellite_id=satellite_id, + device_id=device.id, + ) + + @callback + def async_device_removed(ev: Event) -> None: + """Handle device removed.""" + removed_id = ev.data["device_id"] + self.devices = { + satellite_id: satellite_device + for satellite_id, satellite_device in self.devices.items() + if satellite_device.device_id != removed_id + } + + self.config_entry.async_on_unload( + self.hass.bus.async_listen( + dr.EVENT_DEVICE_REGISTRY_UPDATED, + async_device_removed, + callback(lambda ev: ev.data.get("action") == "remove"), + ) + ) + + @callback + def async_add_new_device_listener( + self, listener: Callable[[SatelliteDevice], None] + ) -> None: + """Add a new device listener.""" + self._new_device_listeners.append(listener) + + @callback + def async_get_or_create(self, service: WyomingService) -> SatelliteDevice: + """Get or create a device.""" + dev_reg = dr.async_get(self.hass) + satellite_id = f"{service.host}_{service.port}" + satellite_device = self.devices.get(satellite_id) + + if satellite_device is not None: + return satellite_device + + satellite_info = service.info.satellite + if not satellite_info: + raise ValueError("No satellite info") + + device = dev_reg.async_get_or_create( + config_entry_id=self.config_entry.entry_id, + identifiers={(DOMAIN, satellite_id)}, + name=satellite_id, + ) + + if satellite_info.area: + # Use area hint + area_reg = ar.async_get(self.hass) + if area := area_reg.async_get_area_by_name(satellite_info.area): + dev_reg.async_update_device(device.id, area_id=area.id) + + satellite_device = self.devices[satellite_id] = SatelliteDevice( + satellite_id=satellite_id, + device_id=device.id, + ) + for listener in self._new_device_listeners: + listener(satellite_device) + + return satellite_device + + def __iter__(self) -> Iterator[SatelliteDevice]: + """Iterate over devices.""" + return iter(self.devices.values()) diff --git a/homeassistant/components/wyoming/satellite/satellite.py b/homeassistant/components/wyoming/satellite/satellite.py index 95a7b2c149935..10b2bded58513 100644 --- a/homeassistant/components/wyoming/satellite/satellite.py +++ b/homeassistant/components/wyoming/satellite/satellite.py @@ -15,9 +15,12 @@ from wyoming.wake import Detection from homeassistant.components import assist_pipeline, stt, tts +from homeassistant.components.assist_pipeline import select as pipeline_select from homeassistant.core import Context, HomeAssistant +from ..const import DOMAIN from ..data import WyomingService +from .devices import SatelliteDevice _LOGGER = logging.getLogger() @@ -29,28 +32,69 @@ class WyomingSatellite: """Remove voice satellite running the Wyoming protocol.""" - def __init__(self, hass: HomeAssistant, service: WyomingService) -> None: + def __init__( + self, hass: HomeAssistant, service: WyomingService, device: SatelliteDevice + ) -> None: """Initialize satellite.""" self.hass = hass self.service = service + self.device = device + self.is_running = True + self._client: AsyncTcpClient | None = None self._chunk_converter = AudioChunkConverter(rate=16000, width=2, channels=1) self._is_pipeline_running = False + self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue() + self._pipeline_id: str | None = None async def run(self) -> None: """Run and maintain a connection to satellite.""" - while self.hass.is_running: + _LOGGER.debug("Running satellite task") + self._pipeline_id = pipeline_select.get_chosen_pipeline( + self.hass, + DOMAIN, + self.device.satellite_id, + ) + + remove_listener = self.device.async_listen_update(self._device_updated) + + while self.is_running: try: await self._run_once() + except asyncio.CancelledError: + raise except Exception: # pylint: disable=broad-exception-caught _LOGGER.exception( "Unexpected error running satellite. Restarting in %s second(s)", _RECONNECT_SECONDS, ) await asyncio.sleep(_RESTART_SECONDS) + finally: + # Ensure sensor is off + if self.device.is_active: + self.device.set_is_active(False) + + remove_listener() + + _LOGGER.debug("Satellite task stopped") + + def _device_updated(self, device: SatelliteDevice) -> None: + pipeline_id = pipeline_select.get_chosen_pipeline( + self.hass, + DOMAIN, + self.device.satellite_id, + ) + + if self._pipeline_id != pipeline_id: + # Pipeline has changed + self._pipeline_id = pipeline_id + self._audio_queue.put_nowait(None) async def _run_once(self) -> None: """Run pipelines until an error occurs.""" + if self.device.is_active: + self.device.set_is_active(False) + while True: try: await self._connect() @@ -65,6 +109,10 @@ async def _run_once(self) -> None: assert self._client is not None _LOGGER.debug("Connected to satellite") + if not self.is_running: + # Run was cancelled + return + # Tell satellite that we're ready await self._client.write_event(RunSatellite().event()) @@ -84,16 +132,22 @@ async def _run_once(self) -> None: assert run_pipeline is not None _LOGGER.debug("Received run information: %s", run_pipeline) + if not self.is_running: + # Run was cancelled + return + start_stage = _convert_stage(run_pipeline.start_stage) end_stage = _convert_stage(run_pipeline.end_stage) - # Default pipeline - pipeline = assist_pipeline.async_get_pipeline(self.hass) - + # Each loop is a pipeline run while True: + # Use select to get pipeline each time in case it's changed + pipeline = assist_pipeline.async_get_pipeline(self.hass, self._pipeline_id) + assert pipeline is not None + # We will push audio in through a queue - audio_queue: asyncio.Queue[bytes] = asyncio.Queue() - stt_stream = _stt_stream(audio_queue) + self._audio_queue = asyncio.Queue() + stt_stream = self._stt_stream() # Start pipeline running _LOGGER.debug( @@ -120,7 +174,7 @@ async def _run_once(self) -> None: start_stage=start_stage, end_stage=end_stage, tts_audio_output="wav", - pipeline_id=pipeline.id, + pipeline_id=self._pipeline_id, ) ) @@ -133,7 +187,7 @@ async def _run_once(self) -> None: # Microphone audio chunk = AudioChunk.from_event(client_event) chunk = self._chunk_converter.convert(chunk) - audio_queue.put_nowait(chunk.audio) + self._audio_queue.put_nowait(chunk.audio) else: _LOGGER.debug("Unexpected event from satellite: %s", client_event) @@ -142,28 +196,40 @@ async def _run_once(self) -> None: def _event_callback(self, event: assist_pipeline.PipelineEvent) -> None: """Translate pipeline events into Wyoming events.""" assert self._client is not None + _LOGGER.debug(event) if event.type == assist_pipeline.PipelineEventType.RUN_END: self._is_pipeline_running = False + if self.device.is_active: + self.device.set_is_active(False) elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_END: # Wake word detection - detection = Detection() - if event.data: - wake_word_output = event.data["wake_word_output"] - detection.name = wake_word_output["wake_word_id"] - detection.timestamp = wake_word_output.get("timestamp") - - self.hass.add_job(self._client.write_event(detection.event())) + if not self.device.is_active: + self.device.set_is_active(True) + + # Inform client of wake word detection + if event.data and (wake_word_output := event.data.get("wake_word_output")): + detection = Detection( + name=wake_word_output["wake_word_id"], + timestamp=wake_word_output.get("timestamp"), + ) + self.hass.add_job(self._client.write_event(detection.event())) + elif event.type == assist_pipeline.PipelineEventType.STT_START: + # Speech-to-text + if not self.device.is_active: + self.device.set_is_active(True) elif event.type == assist_pipeline.PipelineEventType.STT_END: - # STT transcript + # Speech-to-text transcript if event.data: + # Inform client of transript stt_text = event.data["stt_output"]["text"] self.hass.add_job( self._client.write_event(Transcript(text=stt_text).event()) ) elif event.type == assist_pipeline.PipelineEventType.TTS_START: - # TTS text + # Text-to-speech text if event.data: + # Inform client of text self.hass.add_job( self._client.write_event( Synthesize( @@ -225,21 +291,18 @@ async def _stream_tts(self, media_id: str) -> None: await self._client.write_event(AudioStop().event()) _LOGGER.debug("TTS streaming complete") + async def _stt_stream(self) -> AsyncGenerator[bytes, None]: + """Yield audio chunks from a queue.""" + is_first_chunk = True + while chunk := await self._audio_queue.get(): + if is_first_chunk: + is_first_chunk = False + _LOGGER.debug("Receiving audio from satellite") -# ----------------------------------------------------------------------------- + yield chunk -async def _stt_stream( - audio_queue: asyncio.Queue[bytes], -) -> AsyncGenerator[bytes, None]: - """Yield audio chunks from a queue.""" - is_first_chunk = True - while chunk := await audio_queue.get(): - if is_first_chunk: - is_first_chunk = False - _LOGGER.debug("Receiving audio from satellite") - - yield chunk +# ----------------------------------------------------------------------------- def _convert_stage(wyoming_stage: PipelineStage) -> assist_pipeline.PipelineStage: diff --git a/homeassistant/components/wyoming/select.py b/homeassistant/components/wyoming/select.py new file mode 100644 index 0000000000000..d001c789202ce --- /dev/null +++ b/homeassistant/components/wyoming/select.py @@ -0,0 +1,55 @@ +"""Select entities for VoIP integration.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from homeassistant.components.assist_pipeline.select import AssistPipelineSelect +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import HomeAssistant, callback +from homeassistant.helpers.entity_platform import AddEntitiesCallback + +from .const import DOMAIN +from .entity import WyomingSatelliteEntity +from .satellite.devices import SatelliteDevice + +if TYPE_CHECKING: + from .models import DomainDataItem + + +async def async_setup_entry( + hass: HomeAssistant, + config_entry: ConfigEntry, + async_add_entities: AddEntitiesCallback, +) -> None: + """Set up VoIP switch entities.""" + domain_data: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id] + + @callback + def async_add_device(device: SatelliteDevice) -> None: + """Add device.""" + async_add_entities([WyomingSatellitePipelineSelect(hass, device)]) + + domain_data.satellite_devices.async_add_new_device_listener(async_add_device) + + entities: list[WyomingSatelliteEntity] = [] + for device in domain_data.satellite_devices: + entities.append(WyomingSatellitePipelineSelect(hass, device)) + + async_add_entities(entities) + + +class WyomingSatellitePipelineSelect(WyomingSatelliteEntity, AssistPipelineSelect): + """Pipeline selector for Wyoming satellites.""" + + def __init__(self, hass: HomeAssistant, device: SatelliteDevice) -> None: + """Initialize a pipeline selector.""" + self.device = device + + WyomingSatelliteEntity.__init__(self, device) + AssistPipelineSelect.__init__(self, hass, device.satellite_id) + + async def async_select_option(self, option: str) -> None: + """Select an option.""" + await super().async_select_option(option) + self.device.async_pipeline_changed() diff --git a/homeassistant/components/wyoming/strings.json b/homeassistant/components/wyoming/strings.json index 20d73d8dc1391..38629a4b86b6e 100644 --- a/homeassistant/components/wyoming/strings.json +++ b/homeassistant/components/wyoming/strings.json @@ -9,6 +9,10 @@ }, "hassio_confirm": { "description": "Do you want to configure Home Assistant to connect to the Wyoming service provided by the add-on: {addon}?" + }, + "zeroconf_confirm": { + "description": "Do you want to configure Home Assistant to connect to the Wyoming service {name}?", + "title": "Discovered Wyoming service" } }, "error": { @@ -16,7 +20,23 @@ }, "abort": { "already_configured": "[%key:common::config_flow::abort::already_configured_service%]", - "no_services": "No services found at endpoint" + "no_services": "No services found at endpoint", + "no_port": "No port for endpoint" + } + }, + "entity": { + "binary_sensor": { + "assist_in_progress": { + "name": "[%key:component::assist_pipeline::entity::binary_sensor::assist_in_progress::name%]" + } + }, + "select": { + "pipeline": { + "name": "[%key:component::assist_pipeline::entity::select::pipeline::name%]", + "state": { + "preferred": "[%key:component::assist_pipeline::entity::select::pipeline::state::preferred%]" + } + } } } } diff --git a/homeassistant/components/wyoming/stt.py b/homeassistant/components/wyoming/stt.py index e64a2f1466702..8a21ef051fced 100644 --- a/homeassistant/components/wyoming/stt.py +++ b/homeassistant/components/wyoming/stt.py @@ -14,6 +14,7 @@ from .const import DOMAIN, SAMPLE_CHANNELS, SAMPLE_RATE, SAMPLE_WIDTH from .data import WyomingService from .error import WyomingError +from .models import DomainDataItem _LOGGER = logging.getLogger(__name__) @@ -24,10 +25,10 @@ async def async_setup_entry( async_add_entities: AddEntitiesCallback, ) -> None: """Set up Wyoming speech-to-text.""" - service: WyomingService = hass.data[DOMAIN][config_entry.entry_id] + item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id] async_add_entities( [ - WyomingSttProvider(config_entry, service), + WyomingSttProvider(config_entry, item.service), ] ) diff --git a/homeassistant/components/wyoming/tts.py b/homeassistant/components/wyoming/tts.py index cde771cd33056..f024f925514ff 100644 --- a/homeassistant/components/wyoming/tts.py +++ b/homeassistant/components/wyoming/tts.py @@ -16,6 +16,7 @@ from .const import ATTR_SPEAKER, DOMAIN from .data import WyomingService from .error import WyomingError +from .models import DomainDataItem _LOGGER = logging.getLogger(__name__) @@ -26,10 +27,10 @@ async def async_setup_entry( async_add_entities: AddEntitiesCallback, ) -> None: """Set up Wyoming speech-to-text.""" - service: WyomingService = hass.data[DOMAIN][config_entry.entry_id] + item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id] async_add_entities( [ - WyomingTtsProvider(config_entry, service), + WyomingTtsProvider(config_entry, item.service), ] ) diff --git a/homeassistant/components/wyoming/wake_word.py b/homeassistant/components/wyoming/wake_word.py index fce8bbf6327c1..da05e8c9fe112 100644 --- a/homeassistant/components/wyoming/wake_word.py +++ b/homeassistant/components/wyoming/wake_word.py @@ -15,6 +15,7 @@ from .const import DOMAIN from .data import WyomingService, load_wyoming_info from .error import WyomingError +from .models import DomainDataItem _LOGGER = logging.getLogger(__name__) @@ -25,10 +26,10 @@ async def async_setup_entry( async_add_entities: AddEntitiesCallback, ) -> None: """Set up Wyoming wake-word-detection.""" - service: WyomingService = hass.data[DOMAIN][config_entry.entry_id] + item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id] async_add_entities( [ - WyomingWakeWordProvider(hass, config_entry, service), + WyomingWakeWordProvider(hass, config_entry, item.service), ] ) diff --git a/homeassistant/generated/zeroconf.py b/homeassistant/generated/zeroconf.py index e8d117d1f338c..55570078d8075 100644 --- a/homeassistant/generated/zeroconf.py +++ b/homeassistant/generated/zeroconf.py @@ -715,6 +715,11 @@ "domain": "wled", }, ], + "_wyoming._tcp.local.": [ + { + "domain": "wyoming", + }, + ], "_xbmc-jsonrpc-h._tcp.local.": [ { "domain": "kodi", From 6a8e04a1984224d70429fa52324d4ae6fdeb9907 Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Tue, 28 Nov 2023 16:58:39 -0600 Subject: [PATCH 05/18] Add more events --- .../components/wyoming/satellite/satellite.py | 38 ++++++++++++++----- 1 file changed, 29 insertions(+), 9 deletions(-) diff --git a/homeassistant/components/wyoming/satellite/satellite.py b/homeassistant/components/wyoming/satellite/satellite.py index 10b2bded58513..b4e8d615d49fb 100644 --- a/homeassistant/components/wyoming/satellite/satellite.py +++ b/homeassistant/components/wyoming/satellite/satellite.py @@ -12,7 +12,8 @@ from wyoming.pipeline import PipelineStage, RunPipeline from wyoming.satellite import RunSatellite from wyoming.tts import Synthesize, SynthesizeVoice -from wyoming.wake import Detection +from wyoming.vad import VoiceStarted, VoiceStopped +from wyoming.wake import Detect, Detection from homeassistant.components import assist_pipeline, stt, tts from homeassistant.components.assist_pipeline import select as pipeline_select @@ -202,6 +203,8 @@ def _event_callback(self, event: assist_pipeline.PipelineEvent) -> None: self._is_pipeline_running = False if self.device.is_active: self.device.set_is_active(False) + elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_START: + self.hass.add_job(self._client.write_event(Detect().event())) elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_END: # Wake word detection if not self.device.is_active: @@ -218,6 +221,20 @@ def _event_callback(self, event: assist_pipeline.PipelineEvent) -> None: # Speech-to-text if not self.device.is_active: self.device.set_is_active(True) + elif event.type == assist_pipeline.PipelineEventType.STT_VAD_START: + if event.data: + self.hass.add_job( + self._client.write_event( + VoiceStarted(timestamp=event.data["timestamp"]).event() + ) + ) + elif event.type == assist_pipeline.PipelineEventType.STT_VAD_END: + if event.data: + self.hass.add_job( + self._client.write_event( + VoiceStopped(timestamp=event.data["timestamp"]).event() + ) + ) elif event.type == assist_pipeline.PipelineEventType.STT_END: # Speech-to-text transcript if event.data: @@ -269,26 +286,29 @@ async def _stream_tts(self, media_id: str) -> None: sample_channels = wav_file.getnchannels() _LOGGER.debug("Streaming %s TTS sample(s)", wav_file.getnframes()) + timestamp = 0 await self._client.write_event( AudioStart( rate=sample_rate, width=sample_width, channels=sample_channels, + timestamp=timestamp, ).event() ) # Stream audio chunks while audio_bytes := wav_file.readframes(_SAMPLES_PER_CHUNK): - await self._client.write_event( - AudioChunk( - rate=sample_rate, - width=sample_width, - channels=sample_channels, - audio=audio_bytes, - ).event() + chunk = AudioChunk( + rate=sample_rate, + width=sample_width, + channels=sample_channels, + audio=audio_bytes, + timestamp=timestamp, ) + await self._client.write_event(chunk.event()) + timestamp += chunk.seconds - await self._client.write_event(AudioStop().event()) + await self._client.write_event(AudioStop(timestamp=timestamp).event()) _LOGGER.debug("TTS streaming complete") async def _stt_stream(self) -> AsyncGenerator[bytes, None]: From 90d2d5371c3183a1aa059252bef4db84cbc59b6c Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Wed, 29 Nov 2023 16:24:34 -0600 Subject: [PATCH 06/18] Add satellite enabled switch --- homeassistant/components/wyoming/__init__.py | 5 +- .../wyoming/{satellite => }/devices.py | 12 ++- homeassistant/components/wyoming/models.py | 2 +- .../wyoming/{satellite => }/satellite.py | 81 +++++++++++++------ .../components/wyoming/satellite/__init__.py | 10 --- homeassistant/components/wyoming/select.py | 2 +- homeassistant/components/wyoming/strings.json | 5 ++ homeassistant/components/wyoming/switch.py | 75 +++++++++++++++++ 8 files changed, 152 insertions(+), 40 deletions(-) rename homeassistant/components/wyoming/{satellite => }/devices.py (93%) rename homeassistant/components/wyoming/{satellite => }/satellite.py (85%) delete mode 100644 homeassistant/components/wyoming/satellite/__init__.py create mode 100644 homeassistant/components/wyoming/switch.py diff --git a/homeassistant/components/wyoming/__init__.py b/homeassistant/components/wyoming/__init__.py index b6dd66b533652..43283594f481e 100644 --- a/homeassistant/components/wyoming/__init__.py +++ b/homeassistant/components/wyoming/__init__.py @@ -10,12 +10,13 @@ from .const import DOMAIN from .data import WyomingService +from .devices import SatelliteDevices from .models import DomainDataItem -from .satellite import SatelliteDevices, WyomingSatellite +from .satellite import WyomingSatellite _LOGGER = logging.getLogger(__name__) -PLATFORMS = [Platform.BINARY_SENSOR, Platform.SELECT] +PLATFORMS = [Platform.BINARY_SENSOR, Platform.SELECT, Platform.SWITCH] __all__ = [ "DOMAIN", diff --git a/homeassistant/components/wyoming/satellite/devices.py b/homeassistant/components/wyoming/devices.py similarity index 93% rename from homeassistant/components/wyoming/satellite/devices.py rename to homeassistant/components/wyoming/devices.py index 734854f768235..73850ad11e3ef 100644 --- a/homeassistant/components/wyoming/satellite/devices.py +++ b/homeassistant/components/wyoming/devices.py @@ -8,8 +8,8 @@ from homeassistant.core import Event, HomeAssistant, callback from homeassistant.helpers import area_registry as ar, device_registry as dr -from ..const import DOMAIN -from ..data import WyomingService +from .const import DOMAIN +from .data import WyomingService @dataclass @@ -19,6 +19,7 @@ class SatelliteDevice: satellite_id: str device_id: str is_active: bool = False + is_enabled: bool = True update_listeners: list[Callable[[SatelliteDevice], None]] = field( default_factory=list ) @@ -30,6 +31,13 @@ def set_is_active(self, active: bool) -> None: for listener in self.update_listeners: listener(self) + @callback + def set_is_enabled(self, enabled: bool) -> None: + """Set enabled state.""" + self.is_enabled = enabled + for listener in self.update_listeners: + listener(self) + @callback def async_pipeline_changed(self) -> None: """Inform listeners that pipeline selection has changed.""" diff --git a/homeassistant/components/wyoming/models.py b/homeassistant/components/wyoming/models.py index 18579c9000094..adb113c541e7b 100644 --- a/homeassistant/components/wyoming/models.py +++ b/homeassistant/components/wyoming/models.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from .data import WyomingService -from .satellite import SatelliteDevices +from .devices import SatelliteDevices @dataclass diff --git a/homeassistant/components/wyoming/satellite/satellite.py b/homeassistant/components/wyoming/satellite.py similarity index 85% rename from homeassistant/components/wyoming/satellite/satellite.py rename to homeassistant/components/wyoming/satellite.py index b4e8d615d49fb..bd3555df78849 100644 --- a/homeassistant/components/wyoming/satellite/satellite.py +++ b/homeassistant/components/wyoming/satellite.py @@ -19,8 +19,8 @@ from homeassistant.components.assist_pipeline import select as pipeline_select from homeassistant.core import Context, HomeAssistant -from ..const import DOMAIN -from ..data import WyomingService +from .const import DOMAIN +from .data import WyomingService from .devices import SatelliteDevice _LOGGER = logging.getLogger() @@ -40,13 +40,28 @@ def __init__( self.hass = hass self.service = service self.device = device - self.is_running = True + self.is_enabled = True + self._is_running = True self._client: AsyncTcpClient | None = None self._chunk_converter = AudioChunkConverter(rate=16000, width=2, channels=1) self._is_pipeline_running = False self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue() self._pipeline_id: str | None = None + self._device_updated_event = asyncio.Event() + + @property + def is_running(self) -> bool: + """Return True if satellite is running.""" + return self._is_running + + @is_running.setter + def is_running(self, value: bool) -> None: + """Set whether satellite is running or not.""" + self._is_running = value + + # Unblock waiting for enabled + self._device_updated_event.set() async def run(self) -> None: """Run and maintain a connection to satellite.""" @@ -56,39 +71,58 @@ async def run(self) -> None: DOMAIN, self.device.satellite_id, ) - + self.is_enabled = self.device.is_enabled remove_listener = self.device.async_listen_update(self._device_updated) - while self.is_running: - try: - await self._run_once() - except asyncio.CancelledError: - raise - except Exception: # pylint: disable=broad-exception-caught - _LOGGER.exception( - "Unexpected error running satellite. Restarting in %s second(s)", - _RECONNECT_SECONDS, - ) - await asyncio.sleep(_RESTART_SECONDS) - finally: - # Ensure sensor is off - if self.device.is_active: - self.device.set_is_active(False) + try: + while self.is_running: + try: + if not self.is_enabled: + await self._device_updated_event.wait() + if not self.is_running: + # Satellite was stopped while waiting to be enabled + break + + await self._run_once() + except asyncio.CancelledError: + raise + except Exception: # pylint: disable=broad-exception-caught + _LOGGER.exception( + "Unexpected error running satellite. Restarting in %s second(s)", + _RECONNECT_SECONDS, + ) + await asyncio.sleep(_RESTART_SECONDS) + finally: + # Ensure sensor is off + if self.device.is_active: + self.device.set_is_active(False) - remove_listener() + remove_listener() _LOGGER.debug("Satellite task stopped") def _device_updated(self, device: SatelliteDevice) -> None: + """Reacts to updated device settings.""" pipeline_id = pipeline_select.get_chosen_pipeline( self.hass, DOMAIN, self.device.satellite_id, ) + stop_pipeline = False if self._pipeline_id != pipeline_id: # Pipeline has changed self._pipeline_id = pipeline_id + stop_pipeline = True + + if self.is_enabled and (not self.device.is_enabled): + stop_pipeline = True + + self.is_enabled = self.device.is_enabled + self._device_updated_event.set() + self._device_updated_event.clear() + + if stop_pipeline: self._audio_queue.put_nowait(None) async def _run_once(self) -> None: @@ -141,7 +175,7 @@ async def _run_once(self) -> None: end_stage = _convert_stage(run_pipeline.end_stage) # Each loop is a pipeline run - while True: + while self.is_running and self.is_enabled: # Use select to get pipeline each time in case it's changed pipeline = assist_pipeline.async_get_pipeline(self.hass, self._pipeline_id) assert pipeline is not None @@ -197,7 +231,6 @@ async def _run_once(self) -> None: def _event_callback(self, event: assist_pipeline.PipelineEvent) -> None: """Translate pipeline events into Wyoming events.""" assert self._client is not None - _LOGGER.debug(event) if event.type == assist_pipeline.PipelineEventType.RUN_END: self._is_pipeline_running = False @@ -260,8 +293,8 @@ def _event_callback(self, event: assist_pipeline.PipelineEvent) -> None: ) elif event.type == assist_pipeline.PipelineEventType.TTS_END: # TTS stream - if event.data: - media_id = event.data["tts_output"]["media_id"] + if event.data and (tts_output := event.data["tts_output"]): + media_id = tts_output["media_id"] self.hass.add_job(self._stream_tts(media_id)) async def _connect(self) -> None: diff --git a/homeassistant/components/wyoming/satellite/__init__.py b/homeassistant/components/wyoming/satellite/__init__.py deleted file mode 100644 index 006cc9ff7f795..0000000000000 --- a/homeassistant/components/wyoming/satellite/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -"""Support for Wyoming satellite services.""" - -from .devices import SatelliteDevice, SatelliteDevices -from .satellite import WyomingSatellite - -__all__ = [ - "SatelliteDevice", - "SatelliteDevices", - "WyomingSatellite", -] diff --git a/homeassistant/components/wyoming/select.py b/homeassistant/components/wyoming/select.py index d001c789202ce..b9ec0794a5b47 100644 --- a/homeassistant/components/wyoming/select.py +++ b/homeassistant/components/wyoming/select.py @@ -10,8 +10,8 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback from .const import DOMAIN +from .devices import SatelliteDevice from .entity import WyomingSatelliteEntity -from .satellite.devices import SatelliteDevice if TYPE_CHECKING: from .models import DomainDataItem diff --git a/homeassistant/components/wyoming/strings.json b/homeassistant/components/wyoming/strings.json index 38629a4b86b6e..f5a8a2de2ce32 100644 --- a/homeassistant/components/wyoming/strings.json +++ b/homeassistant/components/wyoming/strings.json @@ -37,6 +37,11 @@ "preferred": "[%key:component::assist_pipeline::entity::select::pipeline::state::preferred%]" } } + }, + "switch": { + "satellite_enabled": { + "name": "Satellite enabled" + } } } } diff --git a/homeassistant/components/wyoming/switch.py b/homeassistant/components/wyoming/switch.py new file mode 100644 index 0000000000000..17e663c54f27b --- /dev/null +++ b/homeassistant/components/wyoming/switch.py @@ -0,0 +1,75 @@ +"""Wyoming switch entities.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from homeassistant.components.switch import SwitchEntity, SwitchEntityDescription +from homeassistant.config_entries import ConfigEntry +from homeassistant.const import STATE_ON, EntityCategory +from homeassistant.core import HomeAssistant, callback +from homeassistant.helpers import restore_state +from homeassistant.helpers.entity_platform import AddEntitiesCallback + +from .const import DOMAIN +from .devices import SatelliteDevice +from .entity import WyomingSatelliteEntity + +if TYPE_CHECKING: + from .models import DomainDataItem + + +async def async_setup_entry( + hass: HomeAssistant, + config_entry: ConfigEntry, + async_add_entities: AddEntitiesCallback, +) -> None: + """Set up VoIP switch entities.""" + domain_data: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id] + + @callback + def async_add_device(device: SatelliteDevice) -> None: + """Add device.""" + async_add_entities([WyomingSatelliteEnabledSwitch(device)]) + + domain_data.satellite_devices.async_add_new_device_listener(async_add_device) + + async_add_entities( + [ + WyomingSatelliteEnabledSwitch(device) + for device in domain_data.satellite_devices + ] + ) + + +class WyomingSatelliteEnabledSwitch( + WyomingSatelliteEntity, restore_state.RestoreEntity, SwitchEntity +): + """Entity to represent voip is allowed.""" + + _attr_is_on = True + + entity_description = SwitchEntityDescription( + key="satellite_enabled", + translation_key="satellite_enabled", + entity_category=EntityCategory.CONFIG, + ) + + async def async_added_to_hass(self) -> None: + """Call when entity about to be added to hass.""" + await super().async_added_to_hass() + + state = await self.async_get_last_state() + self._attr_is_on = state is not None and state.state == STATE_ON + + async def async_turn_on(self, **kwargs: Any) -> None: + """Turn on.""" + self._attr_is_on = True + self._device.set_is_enabled(True) + self.async_write_ha_state() + + async def async_turn_off(self, **kwargs: Any) -> None: + """Turn off.""" + self._attr_is_on = False + self._device.set_is_enabled(False) + self.async_write_ha_state() From 6556b5b7c46df60baa0231c79260753ae6149cc4 Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Wed, 29 Nov 2023 16:25:48 -0600 Subject: [PATCH 07/18] Fix mistake --- homeassistant/components/wyoming/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/homeassistant/components/wyoming/__init__.py b/homeassistant/components/wyoming/__init__.py index 43283594f481e..88560547001b6 100644 --- a/homeassistant/components/wyoming/__init__.py +++ b/homeassistant/components/wyoming/__init__.py @@ -8,7 +8,7 @@ from homeassistant.core import HomeAssistant from homeassistant.exceptions import ConfigEntryNotReady -from .const import DOMAIN +from .const import ATTR_SPEAKER, DOMAIN from .data import WyomingService from .devices import SatelliteDevices from .models import DomainDataItem @@ -19,6 +19,7 @@ PLATFORMS = [Platform.BINARY_SENSOR, Platform.SELECT, Platform.SWITCH] __all__ = [ + "ATTR_SPEAKER", "DOMAIN", "async_setup_entry", "async_unload_entry", From 82d94d29f9f2381f1e9d62829d759944bb50e3f2 Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Thu, 30 Nov 2023 09:40:14 -0600 Subject: [PATCH 08/18] Only set up necessary platforms for satellites --- homeassistant/components/wyoming/__init__.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/homeassistant/components/wyoming/__init__.py b/homeassistant/components/wyoming/__init__.py index 88560547001b6..d21b6329b3893 100644 --- a/homeassistant/components/wyoming/__init__.py +++ b/homeassistant/components/wyoming/__init__.py @@ -16,7 +16,7 @@ _LOGGER = logging.getLogger(__name__) -PLATFORMS = [Platform.BINARY_SENSOR, Platform.SELECT, Platform.SWITCH] +SATELLITE_PLATFORMS = [Platform.BINARY_SENSOR, Platform.SELECT, Platform.SWITCH] __all__ = [ "ATTR_SPEAKER", @@ -39,14 +39,13 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: item = DomainDataItem(service=service, satellite_devices=satellite_devices) hass.data.setdefault(DOMAIN, {})[entry.entry_id] = item - await hass.config_entries.async_forward_entry_setups( - entry, - service.platforms + PLATFORMS, - ) - + await hass.config_entries.async_forward_entry_setups(entry, service.platforms) entry.async_on_unload(entry.add_update_listener(update_listener)) if service.info.satellite is not None: + # Set up satellite sensors, switches, etc. + await hass.config_entries.async_forward_entry_setups(entry, SATELLITE_PLATFORMS) + satellite_device = satellite_devices.async_get_or_create(item.service) wyoming_satellite = WyomingSatellite(hass, service, satellite_device) hass.async_create_background_task( From 65c96367af8f38e33f91616912925ffbb65eeebd Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Thu, 30 Nov 2023 14:02:57 -0600 Subject: [PATCH 09/18] Lots of fixes --- homeassistant/components/wyoming/__init__.py | 16 +++-- .../components/wyoming/config_flow.py | 27 ++++---- homeassistant/components/wyoming/devices.py | 24 ++----- homeassistant/components/wyoming/satellite.py | 64 ++++++++----------- homeassistant/components/wyoming/switch.py | 8 +-- tests/components/wyoming/__init__.py | 9 +++ tests/components/wyoming/conftest.py | 35 +++++++++- 7 files changed, 102 insertions(+), 81 deletions(-) diff --git a/homeassistant/components/wyoming/__init__.py b/homeassistant/components/wyoming/__init__.py index d21b6329b3893..cf0e0eb928362 100644 --- a/homeassistant/components/wyoming/__init__.py +++ b/homeassistant/components/wyoming/__init__.py @@ -46,16 +46,18 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: # Set up satellite sensors, switches, etc. await hass.config_entries.async_forward_entry_setups(entry, SATELLITE_PLATFORMS) - satellite_device = satellite_devices.async_get_or_create(item.service) + # Run satellite connection in a separate task + satellite_device = satellite_devices.async_get_or_create( + suggested_area=item.service.info.satellite.area + ) wyoming_satellite = WyomingSatellite(hass, service, satellite_device) - hass.async_create_background_task( - wyoming_satellite.run(), f"Satellite {item.service.info.satellite.name}" + entry.async_create_background_task( + hass, + wyoming_satellite.run(), + f"Satellite {item.service.info.satellite.name}", ) - def stop_satellite(): - wyoming_satellite.is_running = False - - entry.async_on_unload(stop_satellite) + entry.async_on_unload(wyoming_satellite.stop) return True diff --git a/homeassistant/components/wyoming/config_flow.py b/homeassistant/components/wyoming/config_flow.py index bdf62bbb7eee6..b766fc80c89a9 100644 --- a/homeassistant/components/wyoming/config_flow.py +++ b/homeassistant/components/wyoming/config_flow.py @@ -32,6 +32,7 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): _hassio_discovery: hassio.HassioServiceInfo _service: WyomingService | None = None + _name: str | None = None async def async_step_user( self, user_input: dict[str, Any] | None = None @@ -110,16 +111,20 @@ async def async_step_zeroconf( service = await WyomingService.create(discovery_info.host, discovery_info.port) if (service is None) or (not (name := service.get_name())): + # No supported services return self.async_abort(reason="no_services") - self.context[CONF_NAME] = name - self.context["title_placeholders"] = {"name": name} + self._name = name - uuid = f"wyoming_{service.host}_{service.port}" - - await self.async_set_unique_id(uuid) + # Use zeroconf name + service name as unique id. + # The satellite will use its own MAC as the zeroconf name by default. + unique_id = f"{discovery_info.name}_{self._name}" + await self.async_set_unique_id(unique_id) self._abort_if_unique_id_configured() + self.context[CONF_NAME] = self._name + self.context["title_placeholders"] = {"name": self._name} + self._service = service return await self.async_step_zeroconf_confirm() @@ -127,22 +132,18 @@ async def async_step_zeroconf_confirm( self, user_input: dict[str, Any] | None = None ) -> FlowResult: """Handle a flow initiated by zeroconf.""" - if ( - (self._service is None) - or (not self._service.has_services()) - or (not (name := self._service.get_name())) - ): - return self.async_abort(reason="no_services") + assert self._service is not None + assert self._name is not None if user_input is None: return self.async_show_form( step_id="zeroconf_confirm", - description_placeholders={"name": name}, + description_placeholders={"name": self._name}, errors={}, ) return self.async_create_entry( - title=name, + title=self._name, data={ CONF_HOST: self._service.host, CONF_PORT: self._service.port, diff --git a/homeassistant/components/wyoming/devices.py b/homeassistant/components/wyoming/devices.py index 73850ad11e3ef..9a14bf8f92c61 100644 --- a/homeassistant/components/wyoming/devices.py +++ b/homeassistant/components/wyoming/devices.py @@ -6,10 +6,9 @@ from homeassistant.config_entries import ConfigEntry from homeassistant.core import Event, HomeAssistant, callback -from homeassistant.helpers import area_registry as ar, device_registry as dr +from homeassistant.helpers import device_registry as dr from .const import DOMAIN -from .data import WyomingService @dataclass @@ -83,11 +82,7 @@ def async_setup(self) -> None: def async_device_removed(ev: Event) -> None: """Handle device removed.""" removed_id = ev.data["device_id"] - self.devices = { - satellite_id: satellite_device - for satellite_id, satellite_device in self.devices.items() - if satellite_device.device_id != removed_id - } + self.devices.pop(removed_id, None) self.config_entry.async_on_unload( self.hass.bus.async_listen( @@ -105,31 +100,22 @@ def async_add_new_device_listener( self._new_device_listeners.append(listener) @callback - def async_get_or_create(self, service: WyomingService) -> SatelliteDevice: + def async_get_or_create(self, suggested_area: str | None = None) -> SatelliteDevice: """Get or create a device.""" dev_reg = dr.async_get(self.hass) - satellite_id = f"{service.host}_{service.port}" + satellite_id = self.config_entry.entry_id satellite_device = self.devices.get(satellite_id) if satellite_device is not None: return satellite_device - satellite_info = service.info.satellite - if not satellite_info: - raise ValueError("No satellite info") - device = dev_reg.async_get_or_create( config_entry_id=self.config_entry.entry_id, identifiers={(DOMAIN, satellite_id)}, name=satellite_id, + suggested_area=suggested_area, ) - if satellite_info.area: - # Use area hint - area_reg = ar.async_get(self.hass) - if area := area_reg.async_get_area_by_name(satellite_info.area): - dev_reg.async_update_device(device.id, area_id=area.id) - satellite_device = self.devices[satellite_id] = SatelliteDevice( satellite_id=satellite_id, device_id=device.id, diff --git a/homeassistant/components/wyoming/satellite.py b/homeassistant/components/wyoming/satellite.py index bd3555df78849..e27513743419a 100644 --- a/homeassistant/components/wyoming/satellite.py +++ b/homeassistant/components/wyoming/satellite.py @@ -29,6 +29,14 @@ _RECONNECT_SECONDS: Final = 10 _RESTART_SECONDS: Final = 3 +# Wyoming stage -> Assist stage +_STAGES: dict[PipelineStage, assist_pipeline.PipelineStage] = { + PipelineStage.WAKE: assist_pipeline.PipelineStage.WAKE_WORD, + PipelineStage.ASR: assist_pipeline.PipelineStage.STT, + PipelineStage.HANDLE: assist_pipeline.PipelineStage.INTENT, + PipelineStage.TTS: assist_pipeline.PipelineStage.TTS, +} + class WyomingSatellite: """Remove voice satellite running the Wyoming protocol.""" @@ -41,7 +49,7 @@ def __init__( self.service = service self.device = device self.is_enabled = True - self._is_running = True + self.is_running = True self._client: AsyncTcpClient | None = None self._chunk_converter = AudioChunkConverter(rate=16000, width=2, channels=1) @@ -50,19 +58,6 @@ def __init__( self._pipeline_id: str | None = None self._device_updated_event = asyncio.Event() - @property - def is_running(self) -> bool: - """Return True if satellite is running.""" - return self._is_running - - @is_running.setter - def is_running(self, value: bool) -> None: - """Set whether satellite is running or not.""" - self._is_running = value - - # Unblock waiting for enabled - self._device_updated_event.set() - async def run(self) -> None: """Run and maintain a connection to satellite.""" _LOGGER.debug("Running satellite task") @@ -101,6 +96,15 @@ async def run(self) -> None: _LOGGER.debug("Satellite task stopped") + def stop(self) -> None: + """Signal satellite task to stop running.""" + self.is_running = False + + # Unblock waiting for enabled + self._device_updated_event.set() + + # ------------------------------------------------------------------------- + def _device_updated(self, device: SatelliteDevice) -> None: """Reacts to updated device settings.""" pipeline_id = pipeline_select.get_chosen_pipeline( @@ -135,7 +139,7 @@ async def _run_once(self) -> None: await self._connect() break except ConnectionError: - _LOGGER.exception( + _LOGGER.debug( "Failed to connect to satellite. Reconnecting in %s second(s)", _RECONNECT_SECONDS, ) @@ -171,8 +175,14 @@ async def _run_once(self) -> None: # Run was cancelled return - start_stage = _convert_stage(run_pipeline.start_stage) - end_stage = _convert_stage(run_pipeline.end_stage) + start_stage = _STAGES.get(run_pipeline.start_stage) + end_stage = _STAGES.get(run_pipeline.end_stage) + + if start_stage is None: + raise ValueError(f"Invalid start stage: {start_stage}") + + if end_stage is None: + raise ValueError(f"Invalid end stage: {end_stage}") # Each loop is a pipeline run while self.is_running and self.is_enabled: @@ -353,23 +363,3 @@ async def _stt_stream(self) -> AsyncGenerator[bytes, None]: _LOGGER.debug("Receiving audio from satellite") yield chunk - - -# ----------------------------------------------------------------------------- - - -def _convert_stage(wyoming_stage: PipelineStage) -> assist_pipeline.PipelineStage: - """Convert Wyoming pipeline stage to Assist pipeline stage.""" - if wyoming_stage == PipelineStage.WAKE: - return assist_pipeline.PipelineStage.WAKE_WORD - - if wyoming_stage == PipelineStage.ASR: - return assist_pipeline.PipelineStage.STT - - if wyoming_stage == PipelineStage.HANDLE: - return assist_pipeline.PipelineStage.INTENT - - if wyoming_stage == PipelineStage.TTS: - return assist_pipeline.PipelineStage.TTS - - raise ValueError(f"Unknown Wyoming pipeline stage: {wyoming_stage}") diff --git a/homeassistant/components/wyoming/switch.py b/homeassistant/components/wyoming/switch.py index 17e663c54f27b..374b953cc8235 100644 --- a/homeassistant/components/wyoming/switch.py +++ b/homeassistant/components/wyoming/switch.py @@ -45,9 +45,7 @@ def async_add_device(device: SatelliteDevice) -> None: class WyomingSatelliteEnabledSwitch( WyomingSatelliteEntity, restore_state.RestoreEntity, SwitchEntity ): - """Entity to represent voip is allowed.""" - - _attr_is_on = True + """Entity to represent if satellite is enabled.""" entity_description = SwitchEntityDescription( key="satellite_enabled", @@ -60,7 +58,9 @@ async def async_added_to_hass(self) -> None: await super().async_added_to_hass() state = await self.async_get_last_state() - self._attr_is_on = state is not None and state.state == STATE_ON + + # Default to on + self._attr_is_on = (state is None) or (state.state == STATE_ON) async def async_turn_on(self, **kwargs: Any) -> None: """Turn on.""" diff --git a/tests/components/wyoming/__init__.py b/tests/components/wyoming/__init__.py index e04ff4eda03af..8cda35d83351e 100644 --- a/tests/components/wyoming/__init__.py +++ b/tests/components/wyoming/__init__.py @@ -6,6 +6,7 @@ AsrProgram, Attribution, Info, + Satellite, TtsProgram, TtsVoice, TtsVoiceSpeaker, @@ -72,6 +73,14 @@ ) ] ) +SATELLITE_INFO = Info( + satellite=Satellite( + name="Test Satellite", + description="Test Satellite", + installed=True, + attribution=TEST_ATTR, + ) +) EMPTY_INFO = Info() diff --git a/tests/components/wyoming/conftest.py b/tests/components/wyoming/conftest.py index d211b95e92720..52ad6a709be0c 100644 --- a/tests/components/wyoming/conftest.py +++ b/tests/components/wyoming/conftest.py @@ -5,11 +5,13 @@ import pytest from homeassistant.components import stt +from homeassistant.components.wyoming import DOMAIN +from homeassistant.components.wyoming.devices import SatelliteDevices from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant from homeassistant.setup import async_setup_component -from . import STT_INFO, TTS_INFO, WAKE_WORD_INFO +from . import SATELLITE_INFO, STT_INFO, TTS_INFO, WAKE_WORD_INFO from tests.common import MockConfigEntry @@ -117,3 +119,34 @@ def metadata(hass: HomeAssistant) -> stt.SpeechMetadata: sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, channel=stt.AudioChannels.CHANNEL_MONO, ) + + +@pytest.fixture +def satellite_config_entry(hass: HomeAssistant) -> ConfigEntry: + """Create a config entry.""" + entry = MockConfigEntry( + domain="wyoming", + data={ + "host": "1.2.3.4", + "port": 1234, + }, + title="Test Satellite", + ) + entry.add_to_hass(hass) + return entry + + +@pytest.fixture +async def init_satellite(hass: HomeAssistant, satellite_config_entry: ConfigEntry): + """Initialize Wyoming satellite.""" + with patch( + "homeassistant.components.wyoming.data.load_wyoming_info", + return_value=SATELLITE_INFO, + ): + await hass.config_entries.async_setup(satellite_config_entry.entry_id) + + +@pytest.fixture +async def satellite_devices(hass: HomeAssistant, init_satellite) -> SatelliteDevices: + """Get satellite devices object from a configured instance.""" + return hass.data[DOMAIN].satellite_devices From a8b6a50578902589f1bc83475044eb12a620f1be Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Thu, 30 Nov 2023 15:34:07 -0600 Subject: [PATCH 10/18] Add tests --- homeassistant/components/wyoming/__init__.py | 7 +- homeassistant/components/wyoming/devices.py | 42 ++++++++-- tests/components/wyoming/__init__.py | 1 + tests/components/wyoming/conftest.py | 22 ++++- .../components/wyoming/test_binary_sensor.py | 34 ++++++++ tests/components/wyoming/test_devices.py | 81 +++++++++++++++++++ tests/components/wyoming/test_select.py | 23 ++++++ tests/components/wyoming/test_switch.py | 59 ++++++++++++++ 8 files changed, 258 insertions(+), 11 deletions(-) create mode 100644 tests/components/wyoming/test_binary_sensor.py create mode 100644 tests/components/wyoming/test_devices.py create mode 100644 tests/components/wyoming/test_select.py create mode 100644 tests/components/wyoming/test_switch.py diff --git a/homeassistant/components/wyoming/__init__.py b/homeassistant/components/wyoming/__init__.py index cf0e0eb928362..8301ac3207c48 100644 --- a/homeassistant/components/wyoming/__init__.py +++ b/homeassistant/components/wyoming/__init__.py @@ -42,19 +42,20 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: await hass.config_entries.async_forward_entry_setups(entry, service.platforms) entry.async_on_unload(entry.add_update_listener(update_listener)) - if service.info.satellite is not None: + if (satellite_info := service.info.satellite) is not None: # Set up satellite sensors, switches, etc. await hass.config_entries.async_forward_entry_setups(entry, SATELLITE_PLATFORMS) # Run satellite connection in a separate task satellite_device = satellite_devices.async_get_or_create( - suggested_area=item.service.info.satellite.area + name=satellite_info.name, + suggested_area=satellite_info.area, ) wyoming_satellite = WyomingSatellite(hass, service, satellite_device) entry.async_create_background_task( hass, wyoming_satellite.run(), - f"Satellite {item.service.info.satellite.name}", + f"Satellite {satellite_info.name}", ) entry.async_on_unload(wyoming_satellite.stop) diff --git a/homeassistant/components/wyoming/devices.py b/homeassistant/components/wyoming/devices.py index 9a14bf8f92c61..ce34c9cf06418 100644 --- a/homeassistant/components/wyoming/devices.py +++ b/homeassistant/components/wyoming/devices.py @@ -6,7 +6,7 @@ from homeassistant.config_entries import ConfigEntry from homeassistant.core import Event, HomeAssistant, callback -from homeassistant.helpers import device_registry as dr +from homeassistant.helpers import device_registry as dr, entity_registry as er from .const import DOMAIN @@ -51,6 +51,27 @@ def async_listen_update( self.update_listeners.append(listener) return lambda: self.update_listeners.remove(listener) + def get_assist_in_progress_entity_id(self, hass: HomeAssistant) -> str | None: + """Return entity id for assist in progress binary sensor.""" + ent_reg = er.async_get(hass) + return ent_reg.async_get_entity_id( + "binary_sensor", DOMAIN, f"{self.satellite_id}-assist_in_progress" + ) + + def get_satellite_enabled_entity_id(self, hass: HomeAssistant) -> str | None: + """Return entity id for satellite enabled switch.""" + ent_reg = er.async_get(hass) + return ent_reg.async_get_entity_id( + "switch", DOMAIN, f"{self.satellite_id}-satellite_enabled" + ) + + def get_pipeline_entity_id(self, hass: HomeAssistant) -> str | None: + """Return entity id for pipeline select.""" + ent_reg = er.async_get(hass) + return ent_reg.async_get_entity_id( + "select", DOMAIN, f"{self.satellite_id}-pipeline" + ) + class SatelliteDevices: """Class to store devices.""" @@ -60,6 +81,8 @@ def __init__(self, hass: HomeAssistant, config_entry: ConfigEntry) -> None: self.hass = hass self.config_entry = config_entry self._new_device_listeners: list[Callable[[SatelliteDevice], None]] = [] + + # satellite_id -> device self.devices: dict[str, SatelliteDevice] = {} @callback @@ -82,7 +105,11 @@ def async_setup(self) -> None: def async_device_removed(ev: Event) -> None: """Handle device removed.""" removed_id = ev.data["device_id"] - self.devices.pop(removed_id, None) + self.devices = { + satellite_id: satellite_device + for satellite_id, satellite_device in self.devices.items() + if satellite_device.device_id != removed_id + } self.config_entry.async_on_unload( self.hass.bus.async_listen( @@ -100,10 +127,15 @@ def async_add_new_device_listener( self._new_device_listeners.append(listener) @callback - def async_get_or_create(self, suggested_area: str | None = None) -> SatelliteDevice: + def async_get_or_create( + self, name: str | None = None, suggested_area: str | None = None + ) -> SatelliteDevice: """Get or create a device.""" + if not self.config_entry.unique_id: + raise ValueError("No unique id is set for config entry") + dev_reg = dr.async_get(self.hass) - satellite_id = self.config_entry.entry_id + satellite_id = self.config_entry.unique_id satellite_device = self.devices.get(satellite_id) if satellite_device is not None: @@ -112,7 +144,7 @@ def async_get_or_create(self, suggested_area: str | None = None) -> SatelliteDev device = dev_reg.async_get_or_create( config_entry_id=self.config_entry.entry_id, identifiers={(DOMAIN, satellite_id)}, - name=satellite_id, + name=name or satellite_id, suggested_area=suggested_area, ) diff --git a/tests/components/wyoming/__init__.py b/tests/components/wyoming/__init__.py index 8cda35d83351e..3e77dc780a183 100644 --- a/tests/components/wyoming/__init__.py +++ b/tests/components/wyoming/__init__.py @@ -79,6 +79,7 @@ description="Test Satellite", installed=True, attribution=TEST_ATTR, + area="Office", ) ) EMPTY_INFO = Info() diff --git a/tests/components/wyoming/conftest.py b/tests/components/wyoming/conftest.py index 52ad6a709be0c..ec1fbfebf2e18 100644 --- a/tests/components/wyoming/conftest.py +++ b/tests/components/wyoming/conftest.py @@ -6,7 +6,7 @@ from homeassistant.components import stt from homeassistant.components.wyoming import DOMAIN -from homeassistant.components.wyoming.devices import SatelliteDevices +from homeassistant.components.wyoming.devices import SatelliteDevice, SatelliteDevices from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant from homeassistant.setup import async_setup_component @@ -131,6 +131,7 @@ def satellite_config_entry(hass: HomeAssistant) -> ConfigEntry: "port": 1234, }, title="Test Satellite", + unique_id="1234_test", ) entry.add_to_hass(hass) return entry @@ -147,6 +148,21 @@ async def init_satellite(hass: HomeAssistant, satellite_config_entry: ConfigEntr @pytest.fixture -async def satellite_devices(hass: HomeAssistant, init_satellite) -> SatelliteDevices: +async def satellite_devices( + hass: HomeAssistant, init_satellite, satellite_config_entry: ConfigEntry +) -> SatelliteDevices: """Get satellite devices object from a configured instance.""" - return hass.data[DOMAIN].satellite_devices + return hass.data[DOMAIN][satellite_config_entry.entry_id].satellite_devices + + +@pytest.fixture +async def satellite_device( + hass: HomeAssistant, + satellite_devices: SatelliteDevices, + satellite_config_entry: ConfigEntry, +) -> SatelliteDevice: + """Get a satellite device fixture.""" + device = satellite_devices.async_get_or_create() + # to make sure all platforms are set up + await hass.async_block_till_done() + return device diff --git a/tests/components/wyoming/test_binary_sensor.py b/tests/components/wyoming/test_binary_sensor.py new file mode 100644 index 0000000000000..27294186a9012 --- /dev/null +++ b/tests/components/wyoming/test_binary_sensor.py @@ -0,0 +1,34 @@ +"""Test Wyoming binary sensor devices.""" +from homeassistant.components.wyoming.devices import SatelliteDevice +from homeassistant.config_entries import ConfigEntry +from homeassistant.const import STATE_OFF, STATE_ON +from homeassistant.core import HomeAssistant + + +async def test_assist_in_progress( + hass: HomeAssistant, + satellite_config_entry: ConfigEntry, + satellite_device: SatelliteDevice, +) -> None: + """Test assist in progress.""" + assist_in_progress_id = satellite_device.get_assist_in_progress_entity_id(hass) + assert assist_in_progress_id + + state = hass.states.get(assist_in_progress_id) + assert state is not None + assert state.state == STATE_OFF + assert not satellite_device.is_active + + satellite_device.set_is_active(True) + + state = hass.states.get(assist_in_progress_id) + assert state is not None + assert state.state == STATE_ON + assert satellite_device.is_active + + satellite_device.set_is_active(False) + + state = hass.states.get(assist_in_progress_id) + assert state is not None + assert state.state == STATE_OFF + assert not satellite_device.is_active diff --git a/tests/components/wyoming/test_devices.py b/tests/components/wyoming/test_devices.py new file mode 100644 index 0000000000000..6862b2d1aae8b --- /dev/null +++ b/tests/components/wyoming/test_devices.py @@ -0,0 +1,81 @@ +"""Test Wyoming devices.""" +from __future__ import annotations + +from homeassistant.components.assist_pipeline.select import OPTION_PREFERRED +from homeassistant.components.wyoming import DOMAIN +from homeassistant.components.wyoming.devices import SatelliteDevice, SatelliteDevices +from homeassistant.config_entries import ConfigEntry +from homeassistant.const import STATE_OFF, STATE_ON +from homeassistant.core import HomeAssistant +from homeassistant.helpers import device_registry as dr + + +async def test_device_registry_info( + hass: HomeAssistant, + satellite_devices: SatelliteDevices, + satellite_config_entry: ConfigEntry, + device_registry: dr.DeviceRegistry, +) -> None: + """Test info in device registry.""" + assert satellite_config_entry.unique_id + satellite_device = satellite_devices.async_get_or_create() + + device = device_registry.async_get_device( + identifiers={(DOMAIN, satellite_config_entry.unique_id)} + ) + assert device is not None + assert device.name == "Test Satellite" + assert device.suggested_area == "Office" + + # Check associated entities + assist_in_progress_id = satellite_device.get_assist_in_progress_entity_id(hass) + assert assist_in_progress_id + assist_in_progress_state = hass.states.get(assist_in_progress_id) + assert assist_in_progress_state is not None + assert assist_in_progress_state.state == STATE_OFF + + satellite_enabled_id = satellite_device.get_satellite_enabled_entity_id(hass) + assert satellite_enabled_id + satellite_enabled_state = hass.states.get(satellite_enabled_id) + assert satellite_enabled_state is not None + assert satellite_enabled_state.state == STATE_ON + + pipeline_entity_id = satellite_device.get_pipeline_entity_id(hass) + assert pipeline_entity_id + pipeline_state = hass.states.get(pipeline_entity_id) + assert pipeline_state is not None + assert pipeline_state.state == OPTION_PREFERRED + + +async def test_remove_device_registry_entry( + hass: HomeAssistant, + satellite_devices: SatelliteDevices, + satellite_device: SatelliteDevice, + device_registry: dr.DeviceRegistry, +) -> None: + """Test removing a device registry entry.""" + assert satellite_device.satellite_id in satellite_devices.devices + + # Check associated entities + assist_in_progress_id = satellite_device.get_assist_in_progress_entity_id(hass) + assert assist_in_progress_id + assert hass.states.get(assist_in_progress_id) is not None + + satellite_enabled_id = satellite_device.get_satellite_enabled_entity_id(hass) + assert satellite_enabled_id + assert hass.states.get(satellite_enabled_id) is not None + + pipeline_entity_id = satellite_device.get_pipeline_entity_id(hass) + assert pipeline_entity_id + assert hass.states.get(pipeline_entity_id) is not None + + # Remove + device_registry.async_remove_device(satellite_device.device_id) + await hass.async_block_till_done() + await hass.async_block_till_done() + + # Everything should be gone + assert hass.states.get(assist_in_progress_id) is None + assert hass.states.get(satellite_enabled_id) is None + assert hass.states.get(pipeline_entity_id) is None + assert satellite_device.satellite_id not in satellite_devices.devices diff --git a/tests/components/wyoming/test_select.py b/tests/components/wyoming/test_select.py new file mode 100644 index 0000000000000..3b12210193b2f --- /dev/null +++ b/tests/components/wyoming/test_select.py @@ -0,0 +1,23 @@ +"""Test Wyoming select.""" +from homeassistant.components.assist_pipeline.select import OPTION_PREFERRED +from homeassistant.components.wyoming.devices import SatelliteDevice +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import HomeAssistant + + +async def test_pipeline_select( + hass: HomeAssistant, + satellite_config_entry: ConfigEntry, + satellite_device: SatelliteDevice, +) -> None: + """Test pipeline select. + + Functionality is tested in assist_pipeline/test_select.py. + This test is only to ensure it is set up. + """ + pipeline_entity_id = satellite_device.get_pipeline_entity_id(hass) + assert pipeline_entity_id + + state = hass.states.get(pipeline_entity_id) + assert state is not None + assert state.state == OPTION_PREFERRED diff --git a/tests/components/wyoming/test_switch.py b/tests/components/wyoming/test_switch.py new file mode 100644 index 0000000000000..5b39db38e5c03 --- /dev/null +++ b/tests/components/wyoming/test_switch.py @@ -0,0 +1,59 @@ +"""Test Wyoming switch devices.""" +from homeassistant.components.wyoming.devices import SatelliteDevice +from homeassistant.config_entries import ConfigEntry +from homeassistant.const import STATE_OFF, STATE_ON +from homeassistant.core import HomeAssistant + + +async def test_satellite_enabled( + hass: HomeAssistant, + satellite_config_entry: ConfigEntry, + satellite_device: SatelliteDevice, +) -> None: + """Test satellite enabled.""" + satellite_enabled_id = satellite_device.get_satellite_enabled_entity_id(hass) + assert satellite_enabled_id + + state = hass.states.get(satellite_enabled_id) + assert state is not None + assert state.state == STATE_ON + assert satellite_device.is_enabled + + await hass.config_entries.async_reload(satellite_config_entry.entry_id) + + state = hass.states.get(satellite_enabled_id) + assert state is not None + assert state.state == STATE_ON + assert satellite_device.is_enabled + + await hass.services.async_call( + "switch", + "turn_off", + {"entity_id": satellite_enabled_id}, + blocking=True, + ) + + state = hass.states.get(satellite_enabled_id) + assert state is not None + assert state.state == STATE_OFF + assert not satellite_device.is_enabled + + await hass.config_entries.async_reload(satellite_config_entry.entry_id) + await hass.async_block_till_done() + + state = hass.states.get(satellite_enabled_id) + assert state is not None + assert state.state == STATE_OFF + assert not satellite_device.is_enabled + + await hass.services.async_call( + "switch", + "turn_on", + {"entity_id": satellite_enabled_id}, + blocking=True, + ) + + state = hass.states.get(satellite_enabled_id) + assert state is not None + assert state.state == STATE_ON + assert satellite_device.is_enabled From 822f6fcafac2419c89707eaf53efcce14ec8deaf Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Thu, 30 Nov 2023 16:03:19 -0600 Subject: [PATCH 11/18] Use config entry id as satellite id --- homeassistant/components/wyoming/devices.py | 7 +++---- tests/components/wyoming/conftest.py | 6 ++++-- tests/components/wyoming/test_devices.py | 5 +++-- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/homeassistant/components/wyoming/devices.py b/homeassistant/components/wyoming/devices.py index ce34c9cf06418..53d38946fd7e8 100644 --- a/homeassistant/components/wyoming/devices.py +++ b/homeassistant/components/wyoming/devices.py @@ -131,11 +131,10 @@ def async_get_or_create( self, name: str | None = None, suggested_area: str | None = None ) -> SatelliteDevice: """Get or create a device.""" - if not self.config_entry.unique_id: - raise ValueError("No unique id is set for config entry") - dev_reg = dr.async_get(self.hass) - satellite_id = self.config_entry.unique_id + + # Use config entry id since only one satellite per entry is supported + satellite_id = self.config_entry.entry_id satellite_device = self.devices.get(satellite_id) if satellite_device is not None: diff --git a/tests/components/wyoming/conftest.py b/tests/components/wyoming/conftest.py index ec1fbfebf2e18..c70c40342e377 100644 --- a/tests/components/wyoming/conftest.py +++ b/tests/components/wyoming/conftest.py @@ -131,7 +131,6 @@ def satellite_config_entry(hass: HomeAssistant) -> ConfigEntry: "port": 1234, }, title="Test Satellite", - unique_id="1234_test", ) entry.add_to_hass(hass) return entry @@ -143,7 +142,10 @@ async def init_satellite(hass: HomeAssistant, satellite_config_entry: ConfigEntr with patch( "homeassistant.components.wyoming.data.load_wyoming_info", return_value=SATELLITE_INFO, - ): + ), patch( + "homeassistant.components.wyoming.satellite.WyomingSatellite.run" + ) as _run_mock: + # _run_mock: satellite task does not actually run await hass.config_entries.async_setup(satellite_config_entry.entry_id) diff --git a/tests/components/wyoming/test_devices.py b/tests/components/wyoming/test_devices.py index 6862b2d1aae8b..58616673cd5b8 100644 --- a/tests/components/wyoming/test_devices.py +++ b/tests/components/wyoming/test_devices.py @@ -17,11 +17,12 @@ async def test_device_registry_info( device_registry: dr.DeviceRegistry, ) -> None: """Test info in device registry.""" - assert satellite_config_entry.unique_id satellite_device = satellite_devices.async_get_or_create() + # Satellite uses config entry id since only one satellite per entry is + # supported. device = device_registry.async_get_device( - identifiers={(DOMAIN, satellite_config_entry.unique_id)} + identifiers={(DOMAIN, satellite_config_entry.entry_id)} ) assert device is not None assert device.name == "Test Satellite" From b0731ea86680a436dffe17a9efb68339286b275f Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Thu, 30 Nov 2023 17:04:51 -0600 Subject: [PATCH 12/18] Initial satellite test --- tests/components/wyoming/__init__.py | 18 ++- tests/components/wyoming/test_satellite.py | 173 +++++++++++++++++++++ 2 files changed, 184 insertions(+), 7 deletions(-) create mode 100644 tests/components/wyoming/test_satellite.py diff --git a/tests/components/wyoming/__init__.py b/tests/components/wyoming/__init__.py index 3e77dc780a183..899eda7ec1a3e 100644 --- a/tests/components/wyoming/__init__.py +++ b/tests/components/wyoming/__init__.py @@ -1,6 +1,7 @@ """Tests for the Wyoming integration.""" import asyncio +from wyoming.event import Event from wyoming.info import ( AsrModel, AsrProgram, @@ -88,18 +89,21 @@ class MockAsyncTcpClient: """Mock AsyncTcpClient.""" - def __init__(self, responses) -> None: + def __init__(self, responses: list[Event]) -> None: """Initialize.""" - self.host = None - self.port = None - self.written = [] + self.host: str | None = None + self.port: int | None = None + self.written: list[Event] = [] self.responses = responses - async def write_event(self, event): + async def connect(self) -> None: + """Connect.""" + + async def write_event(self, event: Event): """Send.""" self.written.append(event) - async def read_event(self): + async def read_event(self) -> Event | None: """Receive.""" await asyncio.sleep(0) # force context switch @@ -115,7 +119,7 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc, tb): """Exit.""" - def __call__(self, host, port): + def __call__(self, host: str, port: int): """Call.""" self.host = host self.port = port diff --git a/tests/components/wyoming/test_satellite.py b/tests/components/wyoming/test_satellite.py new file mode 100644 index 0000000000000..5543656a99ec2 --- /dev/null +++ b/tests/components/wyoming/test_satellite.py @@ -0,0 +1,173 @@ +"""Test Wyoming satellite.""" +from __future__ import annotations + +import asyncio +from unittest.mock import patch + +from wyoming.audio import AudioChunk +from wyoming.event import Event +from wyoming.pipeline import PipelineStage, RunPipeline +from wyoming.satellite import RunSatellite +from wyoming.vad import VoiceStarted, VoiceStopped +from wyoming.wake import Detect, Detection + +from homeassistant.components import assist_pipeline, wyoming +from homeassistant.components.wyoming.devices import SatelliteDevices +from homeassistant.core import HomeAssistant +from homeassistant.setup import async_setup_component + +from . import SATELLITE_INFO, MockAsyncTcpClient + +from tests.common import MockConfigEntry + + +class SatelliteAsyncTcpClient(MockAsyncTcpClient): + """Satellite AsyncTcpClient.""" + + def __init__(self, responses: list[Event]) -> None: + """Initialize client.""" + super().__init__(responses) + + self.connect_event = asyncio.Event() + self.run_satellite_event = asyncio.Event() + self.detect_event = asyncio.Event() + + self.detection_event = asyncio.Event() + self.detection: Detection | None = None + + self.stt_vad_start_event = asyncio.Event() + self.voice_started: VoiceStarted | None = None + + self.stt_vad_end_event = asyncio.Event() + self.voice_stopped: VoiceStopped | None = None + + self._audio_chunk = AudioChunk( + rate=16000, width=2, channels=1, audio=b"chunk" + ).event() + + async def connect(self) -> None: + """Connect.""" + self.connect_event.set() + + async def write_event(self, event: Event): + """Send.""" + if RunSatellite.is_type(event.type): + self.run_satellite_event.set() + elif Detect.is_type(event.type): + self.detect_event.set() + elif Detection.is_type(event.type): + self.detection_event.set() + self.detection = Detection.from_event(event) + elif VoiceStarted.is_type(event.type): + self.stt_vad_start_event.set() + self.voice_started = VoiceStarted.from_event(event) + elif VoiceStopped.is_type(event.type): + self.stt_vad_end_event.set() + self.voice_stopped = VoiceStopped.from_event(event) + + async def read_event(self) -> Event | None: + """Receive.""" + event = await super().read_event() + + # Keep sending audio chunks + return event or self._audio_chunk + + +async def test_satellite(hass: HomeAssistant) -> None: + """Test running a pipeline with a satellite.""" + assert await async_setup_component(hass, assist_pipeline.DOMAIN, {}) + + events = [ + RunPipeline( + start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS + ).event(), + ] + + with patch( + "homeassistant.components.wyoming.data.load_wyoming_info", + return_value=SATELLITE_INFO, + ), patch( + "homeassistant.components.wyoming.satellite.AsyncTcpClient", + SatelliteAsyncTcpClient(events), + ) as mock_client, patch( + "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", + ) as mock_run_pipeline: + entry = MockConfigEntry( + domain="wyoming", + data={ + "host": "1.2.3.4", + "port": 1234, + }, + title="Test Satellite", + unique_id="1234_test", + ) + entry.add_to_hass(hass) + await hass.config_entries.async_setup(entry.entry_id) + await hass.async_block_till_done() + + satellite_devices: SatelliteDevices = hass.data[wyoming.DOMAIN][ + entry.entry_id + ].satellite_devices + assert entry.entry_id in satellite_devices.devices + device = satellite_devices.devices[entry.entry_id] + + async with asyncio.timeout(1): + await mock_client.connect_event.wait() + await mock_client.run_satellite_event.wait() + + mock_run_pipeline.assert_called() + event_callback = mock_run_pipeline.call_args.kwargs["event_callback"] + + # Start detecting wake word + event_callback( + assist_pipeline.PipelineEvent( + assist_pipeline.PipelineEventType.WAKE_WORD_START + ) + ) + async with asyncio.timeout(1): + await mock_client.detect_event.wait() + + assert not device.is_active + assert device.is_enabled + + # Wake word is detected + event_callback( + assist_pipeline.PipelineEvent( + assist_pipeline.PipelineEventType.WAKE_WORD_END, + {"wake_word_output": {"wake_word_id": "test_wake_word"}}, + ) + ) + async with asyncio.timeout(1): + await mock_client.detection_event.wait() + + assert mock_client.detection is not None + assert mock_client.detection.name == "test_wake_word" + + # "Assist in progress" sensor should be active now + assert device.is_active + + # Speech to text started + event_callback( + assist_pipeline.PipelineEvent(assist_pipeline.PipelineEventType.STT_START) + ) + + event_callback( + assist_pipeline.PipelineEvent( + assist_pipeline.PipelineEventType.STT_VAD_START, {"timestamp": 1234} + ) + ) + async with asyncio.timeout(1): + await mock_client.stt_vad_start_event.wait() + + assert mock_client.voice_started is not None + assert mock_client.voice_started.timestamp == 1234 + + # Pipeline finished + event_callback( + assist_pipeline.PipelineEvent(assist_pipeline.PipelineEventType.RUN_END) + ) + assert not device.is_active + + # Stop the satellite + await hass.config_entries.async_unload(entry.entry_id) + await hass.async_block_till_done() From 2ca80a13c3877224683d44b25b8e09c75d1b5977 Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Fri, 1 Dec 2023 10:44:06 -0600 Subject: [PATCH 13/18] Add satellite pipeline test --- homeassistant/components/wyoming/__init__.py | 6 +- homeassistant/components/wyoming/models.py | 2 + homeassistant/components/wyoming/satellite.py | 21 +- tests/components/wyoming/test_satellite.py | 185 +++++++++++++++--- 4 files changed, 178 insertions(+), 36 deletions(-) diff --git a/homeassistant/components/wyoming/__init__.py b/homeassistant/components/wyoming/__init__.py index 8301ac3207c48..19a487029d699 100644 --- a/homeassistant/components/wyoming/__init__.py +++ b/homeassistant/components/wyoming/__init__.py @@ -51,14 +51,14 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: name=satellite_info.name, suggested_area=satellite_info.area, ) - wyoming_satellite = WyomingSatellite(hass, service, satellite_device) + item.satellite = WyomingSatellite(hass, service, satellite_device) entry.async_create_background_task( hass, - wyoming_satellite.run(), + item.satellite.run(), f"Satellite {satellite_info.name}", ) - entry.async_on_unload(wyoming_satellite.stop) + entry.async_on_unload(item.satellite.stop) return True diff --git a/homeassistant/components/wyoming/models.py b/homeassistant/components/wyoming/models.py index adb113c541e7b..c29e28ea3f3f1 100644 --- a/homeassistant/components/wyoming/models.py +++ b/homeassistant/components/wyoming/models.py @@ -3,6 +3,7 @@ from .data import WyomingService from .devices import SatelliteDevices +from .satellite import WyomingSatellite @dataclass @@ -11,3 +12,4 @@ class DomainDataItem: service: WyomingService satellite_devices: SatelliteDevices + satellite: WyomingSatellite | None = None diff --git a/homeassistant/components/wyoming/satellite.py b/homeassistant/components/wyoming/satellite.py index e27513743419a..df894cd201385 100644 --- a/homeassistant/components/wyoming/satellite.py +++ b/homeassistant/components/wyoming/satellite.py @@ -6,7 +6,7 @@ from typing import Final import wave -from wyoming.asr import Transcript +from wyoming.asr import Transcribe, Transcript from wyoming.audio import AudioChunk, AudioChunkConverter, AudioStart, AudioStop from wyoming.client import AsyncTcpClient from wyoming.pipeline import PipelineStage, RunPipeline @@ -134,7 +134,7 @@ async def _run_once(self) -> None: if self.device.is_active: self.device.set_is_active(False) - while True: + while self.is_running and self.is_enabled: try: await self._connect() break @@ -148,8 +148,8 @@ async def _run_once(self) -> None: assert self._client is not None _LOGGER.debug("Connected to satellite") - if not self.is_running: - # Run was cancelled + if (not self.is_running) or (not self.is_enabled): + # Run was cancelled or satellite was disabled return # Tell satellite that we're ready @@ -264,7 +264,15 @@ def _event_callback(self, event: assist_pipeline.PipelineEvent) -> None: # Speech-to-text if not self.device.is_active: self.device.set_is_active(True) + + if event.data: + self.hass.add_job( + self._client.write_event( + Transcribe(language=event.data["metadata"]["language"]).event() + ) + ) elif event.type == assist_pipeline.PipelineEventType.STT_VAD_START: + # User started speaking if event.data: self.hass.add_job( self._client.write_event( @@ -272,6 +280,7 @@ def _event_callback(self, event: assist_pipeline.PipelineEvent) -> None: ) ) elif event.type == assist_pipeline.PipelineEventType.STT_VAD_END: + # User stopped speaking if event.data: self.hass.add_job( self._client.write_event( @@ -295,8 +304,8 @@ def _event_callback(self, event: assist_pipeline.PipelineEvent) -> None: Synthesize( text=event.data["tts_input"], voice=SynthesizeVoice( - name=event.data["voice"], - language=event.data["language"], + name=event.data.get("voice"), + language=event.data.get("language"), ), ).event() ) diff --git a/tests/components/wyoming/test_satellite.py b/tests/components/wyoming/test_satellite.py index 5543656a99ec2..7443bc7691c78 100644 --- a/tests/components/wyoming/test_satellite.py +++ b/tests/components/wyoming/test_satellite.py @@ -2,12 +2,16 @@ from __future__ import annotations import asyncio +import io from unittest.mock import patch +import wave -from wyoming.audio import AudioChunk +from wyoming.asr import Transcribe, Transcript +from wyoming.audio import AudioChunk, AudioStart, AudioStop from wyoming.event import Event from wyoming.pipeline import PipelineStage, RunPipeline from wyoming.satellite import RunSatellite +from wyoming.tts import Synthesize from wyoming.vad import VoiceStarted, VoiceStopped from wyoming.wake import Detect, Detection @@ -21,6 +25,41 @@ from tests.common import MockConfigEntry +async def setup_config_entry(hass: HomeAssistant) -> MockConfigEntry: + """Set up config entry for Wyoming satellite. + + This is separated from the satellite_config_entry method in conftest.py so + we can patch functions before the satellite task is run during setup. + """ + entry = MockConfigEntry( + domain="wyoming", + data={ + "host": "1.2.3.4", + "port": 1234, + }, + title="Test Satellite", + ) + entry.add_to_hass(hass) + await hass.config_entries.async_setup(entry.entry_id) + await hass.async_block_till_done() + + return entry + + +def get_test_wav() -> bytes: + """Get bytes for test WAV file.""" + with io.BytesIO() as wav_io: + with wave.open(wav_io, "wb") as wav_file: + wav_file.setframerate(22050) + wav_file.setsampwidth(2) + wav_file.setnchannels(1) + + # Single frame + wav_file.writeframes(b"123") + + return wav_io.getvalue() + + class SatelliteAsyncTcpClient(MockAsyncTcpClient): """Satellite AsyncTcpClient.""" @@ -35,13 +74,27 @@ def __init__(self, responses: list[Event]) -> None: self.detection_event = asyncio.Event() self.detection: Detection | None = None - self.stt_vad_start_event = asyncio.Event() + self.transcribe_event = asyncio.Event() + self.transcribe: Transcribe | None = None + + self.voice_started_event = asyncio.Event() self.voice_started: VoiceStarted | None = None - self.stt_vad_end_event = asyncio.Event() + self.voice_stopped_event = asyncio.Event() self.voice_stopped: VoiceStopped | None = None - self._audio_chunk = AudioChunk( + self.transcript_event = asyncio.Event() + self.transcript: Transcript | None = None + + self.synthesize_event = asyncio.Event() + self.synthesize: Synthesize | None = None + + self.tts_audio_start_event = asyncio.Event() + self.tts_audio_chunk_event = asyncio.Event() + self.tts_audio_stop_event = asyncio.Event() + self.tts_audio_chunk: AudioChunk | None = None + + self._mic_audio_chunk = AudioChunk( rate=16000, width=2, channels=1, audio=b"chunk" ).event() @@ -56,24 +109,40 @@ async def write_event(self, event: Event): elif Detect.is_type(event.type): self.detect_event.set() elif Detection.is_type(event.type): - self.detection_event.set() self.detection = Detection.from_event(event) + self.detection_event.set() + elif Transcribe.is_type(event.type): + self.transcribe = Transcribe.from_event(event) + self.transcribe_event.set() elif VoiceStarted.is_type(event.type): - self.stt_vad_start_event.set() self.voice_started = VoiceStarted.from_event(event) + self.voice_started_event.set() elif VoiceStopped.is_type(event.type): - self.stt_vad_end_event.set() self.voice_stopped = VoiceStopped.from_event(event) + self.voice_stopped_event.set() + elif Transcript.is_type(event.type): + self.transcript = Transcript.from_event(event) + self.transcript_event.set() + elif Synthesize.is_type(event.type): + self.synthesize = Synthesize.from_event(event) + self.synthesize_event.set() + elif AudioStart.is_type(event.type): + self.tts_audio_start_event.set() + elif AudioChunk.is_type(event.type): + self.tts_audio_chunk = AudioChunk.from_event(event) + self.tts_audio_chunk_event.set() + elif AudioStop.is_type(event.type): + self.tts_audio_stop_event.set() async def read_event(self) -> Event | None: """Receive.""" event = await super().read_event() - # Keep sending audio chunks - return event or self._audio_chunk + # Keep sending audio chunks instead of None + return event or self._mic_audio_chunk -async def test_satellite(hass: HomeAssistant) -> None: +async def test_satellite_pipeline(hass: HomeAssistant) -> None: """Test running a pipeline with a satellite.""" assert await async_setup_component(hass, assist_pipeline.DOMAIN, {}) @@ -91,20 +160,11 @@ async def test_satellite(hass: HomeAssistant) -> None: SatelliteAsyncTcpClient(events), ) as mock_client, patch( "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", - ) as mock_run_pipeline: - entry = MockConfigEntry( - domain="wyoming", - data={ - "host": "1.2.3.4", - "port": 1234, - }, - title="Test Satellite", - unique_id="1234_test", - ) - entry.add_to_hass(hass) - await hass.config_entries.async_setup(entry.entry_id) - await hass.async_block_till_done() - + ) as mock_run_pipeline, patch( + "homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio", + return_value=("wav", get_test_wav()), + ): + entry = await setup_config_entry(hass) satellite_devices: SatelliteDevices = hass.data[wyoming.DOMAIN][ entry.entry_id ].satellite_devices @@ -146,22 +206,93 @@ async def test_satellite(hass: HomeAssistant) -> None: # "Assist in progress" sensor should be active now assert device.is_active - # Speech to text started + # Speech-to-text started event_callback( - assist_pipeline.PipelineEvent(assist_pipeline.PipelineEventType.STT_START) + assist_pipeline.PipelineEvent( + assist_pipeline.PipelineEventType.STT_START, + {"metadata": {"language": "en"}}, + ) ) + async with asyncio.timeout(1): + await mock_client.transcribe_event.wait() + + assert mock_client.transcribe is not None + assert mock_client.transcribe.language == "en" + # User started speaking event_callback( assist_pipeline.PipelineEvent( assist_pipeline.PipelineEventType.STT_VAD_START, {"timestamp": 1234} ) ) async with asyncio.timeout(1): - await mock_client.stt_vad_start_event.wait() + await mock_client.voice_started_event.wait() assert mock_client.voice_started is not None assert mock_client.voice_started.timestamp == 1234 + # User stopped speaking + event_callback( + assist_pipeline.PipelineEvent( + assist_pipeline.PipelineEventType.STT_VAD_END, {"timestamp": 5678} + ) + ) + async with asyncio.timeout(1): + await mock_client.voice_stopped_event.wait() + + assert mock_client.voice_stopped is not None + assert mock_client.voice_stopped.timestamp == 5678 + + # Speech-to-text transcription + event_callback( + assist_pipeline.PipelineEvent( + assist_pipeline.PipelineEventType.STT_END, + {"stt_output": {"text": "test transcript"}}, + ) + ) + async with asyncio.timeout(1): + await mock_client.transcript_event.wait() + + assert mock_client.transcript is not None + assert mock_client.transcript.text == "test transcript" + + # Text-to-speech text + event_callback( + assist_pipeline.PipelineEvent( + assist_pipeline.PipelineEventType.TTS_START, + { + "tts_input": "test text to speak", + "voice": "test voice", + }, + ) + ) + async with asyncio.timeout(1): + await mock_client.synthesize_event.wait() + + assert mock_client.synthesize is not None + assert mock_client.synthesize.text == "test text to speak" + assert mock_client.synthesize.voice is not None + assert mock_client.synthesize.voice.name == "test voice" + + # Text-to-speech media + event_callback( + assist_pipeline.PipelineEvent( + assist_pipeline.PipelineEventType.TTS_END, + {"tts_output": {"media_id": "test media id"}}, + ) + ) + async with asyncio.timeout(1): + await mock_client.tts_audio_start_event.wait() + await mock_client.tts_audio_chunk_event.wait() + await mock_client.tts_audio_stop_event.wait() + + # Verify audio chunk from test WAV + assert mock_client.tts_audio_chunk is not None + assert mock_client.tts_audio_chunk.rate == 22050 + assert mock_client.tts_audio_chunk.width == 2 + assert mock_client.tts_audio_chunk.channels == 1 + assert mock_client.tts_audio_chunk.audio == b"123" + # Pipeline finished event_callback( assist_pipeline.PipelineEvent(assist_pipeline.PipelineEventType.RUN_END) From 454640d5d50cab07d39ce4e807686c0d6d2e4f9e Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Fri, 1 Dec 2023 15:22:55 -0600 Subject: [PATCH 14/18] More tests --- .../wyoming/snapshots/test_config_flow.ambr | 42 ++++++++++ tests/components/wyoming/test_config_flow.py | 81 ++++++++++++++++++- tests/components/wyoming/test_data.py | 43 +++++++++- tests/components/wyoming/test_select.py | 62 ++++++++++++++ 4 files changed, 224 insertions(+), 4 deletions(-) diff --git a/tests/components/wyoming/snapshots/test_config_flow.ambr b/tests/components/wyoming/snapshots/test_config_flow.ambr index d4220a3972424..99f411027f5e7 100644 --- a/tests/components/wyoming/snapshots/test_config_flow.ambr +++ b/tests/components/wyoming/snapshots/test_config_flow.ambr @@ -121,3 +121,45 @@ 'version': 1, }) # --- +# name: test_zeroconf_discovery + FlowResultSnapshot({ + 'context': dict({ + 'name': 'Test Satellite', + 'source': 'zeroconf', + 'title_placeholders': dict({ + 'name': 'Test Satellite', + }), + 'unique_id': 'test_zeroconf_name._wyoming._tcp.local._Test Satellite', + }), + 'data': dict({ + 'host': '127.0.0.1', + 'port': 12345, + }), + 'description': None, + 'description_placeholders': None, + 'flow_id': , + 'handler': 'wyoming', + 'options': dict({ + }), + 'result': ConfigEntrySnapshot({ + 'data': dict({ + 'host': '127.0.0.1', + 'port': 12345, + }), + 'disabled_by': None, + 'domain': 'wyoming', + 'entry_id': , + 'options': dict({ + }), + 'pref_disable_new_entities': False, + 'pref_disable_polling': False, + 'source': 'zeroconf', + 'title': 'Test Satellite', + 'unique_id': 'test_zeroconf_name._wyoming._tcp.local._Test Satellite', + 'version': 1, + }), + 'title': 'Test Satellite', + 'type': , + 'version': 1, + }) +# --- diff --git a/tests/components/wyoming/test_config_flow.py b/tests/components/wyoming/test_config_flow.py index 896d3748ebdc6..f711b56b3bc5f 100644 --- a/tests/components/wyoming/test_config_flow.py +++ b/tests/components/wyoming/test_config_flow.py @@ -1,4 +1,5 @@ """Test the Wyoming config flow.""" +from ipaddress import IPv4Address from unittest.mock import AsyncMock, patch import pytest @@ -8,10 +9,11 @@ from homeassistant import config_entries from homeassistant.components.hassio import HassioServiceInfo from homeassistant.components.wyoming.const import DOMAIN +from homeassistant.components.zeroconf import ZeroconfServiceInfo from homeassistant.core import HomeAssistant from homeassistant.data_entry_flow import FlowResultType -from . import EMPTY_INFO, STT_INFO, TTS_INFO +from . import EMPTY_INFO, SATELLITE_INFO, STT_INFO, TTS_INFO from tests.common import MockConfigEntry @@ -25,6 +27,16 @@ uuid="1234", ) +ZEROCONF_DISCOVERY = ZeroconfServiceInfo( + ip_address=IPv4Address("127.0.0.1"), + ip_addresses=[IPv4Address("127.0.0.1")], + port=12345, + hostname="localhost", + type="_wyoming._tcp.local.", + name="test_zeroconf_name._wyoming._tcp.local.", + properties={}, +) + pytestmark = pytest.mark.usefixtures("mock_setup_entry") @@ -214,3 +226,70 @@ async def test_hassio_addon_no_supported_services(hass: HomeAssistant) -> None: assert result2.get("type") == FlowResultType.ABORT assert result2.get("reason") == "no_services" + + +async def test_zeroconf_discovery( + hass: HomeAssistant, + mock_setup_entry: AsyncMock, + snapshot: SnapshotAssertion, +) -> None: + """Test config flow initiated by Supervisor.""" + with patch( + "homeassistant.components.wyoming.data.load_wyoming_info", + return_value=SATELLITE_INFO, + ): + result = await hass.config_entries.flow.async_init( + DOMAIN, + data=ZEROCONF_DISCOVERY, + context={"source": config_entries.SOURCE_ZEROCONF}, + ) + + assert result.get("type") == FlowResultType.FORM + assert result.get("step_id") == "zeroconf_confirm" + assert result.get("description_placeholders") == { + "name": SATELLITE_INFO.satellite.name + } + + result2 = await hass.config_entries.flow.async_configure(result["flow_id"], {}) + assert result2.get("type") == FlowResultType.CREATE_ENTRY + assert result2 == snapshot + + +async def test_zeroconf_discovery_no_port( + hass: HomeAssistant, + mock_setup_entry: AsyncMock, + snapshot: SnapshotAssertion, +) -> None: + """Test discovery when the zeroconf service does not have a port.""" + with patch( + "homeassistant.components.wyoming.data.load_wyoming_info", + return_value=SATELLITE_INFO, + ), patch.object(ZEROCONF_DISCOVERY, "port", None): + result = await hass.config_entries.flow.async_init( + DOMAIN, + data=ZEROCONF_DISCOVERY, + context={"source": config_entries.SOURCE_ZEROCONF}, + ) + + assert result.get("type") == FlowResultType.ABORT + assert result.get("reason") == "no_port" + + +async def test_zeroconf_discovery_no_services( + hass: HomeAssistant, + mock_setup_entry: AsyncMock, + snapshot: SnapshotAssertion, +) -> None: + """Test discovery when there are no supported services on the client.""" + with patch( + "homeassistant.components.wyoming.data.load_wyoming_info", + return_value=Info(), + ): + result = await hass.config_entries.flow.async_init( + DOMAIN, + data=ZEROCONF_DISCOVERY, + context={"source": config_entries.SOURCE_ZEROCONF}, + ) + + assert result.get("type") == FlowResultType.ABORT + assert result.get("reason") == "no_services" diff --git a/tests/components/wyoming/test_data.py b/tests/components/wyoming/test_data.py index 0cb878c39c1ed..b7de9dbfdc1a1 100644 --- a/tests/components/wyoming/test_data.py +++ b/tests/components/wyoming/test_data.py @@ -3,13 +3,15 @@ from unittest.mock import patch -from homeassistant.components.wyoming.data import load_wyoming_info +from syrupy.assertion import SnapshotAssertion + +from homeassistant.components.wyoming.data import WyomingService, load_wyoming_info from homeassistant.core import HomeAssistant -from . import STT_INFO, MockAsyncTcpClient +from . import SATELLITE_INFO, STT_INFO, TTS_INFO, WAKE_WORD_INFO, MockAsyncTcpClient -async def test_load_info(hass: HomeAssistant, snapshot) -> None: +async def test_load_info(hass: HomeAssistant, snapshot: SnapshotAssertion) -> None: """Test loading info.""" with patch( "homeassistant.components.wyoming.data.AsyncTcpClient", @@ -38,3 +40,38 @@ async def test_load_info_oserror(hass: HomeAssistant) -> None: ) assert info is None + + +async def test_service_name(hass: HomeAssistant) -> None: + """Test loading service info.""" + with patch( + "homeassistant.components.wyoming.data.AsyncTcpClient", + MockAsyncTcpClient([STT_INFO.event()]), + ): + service = await WyomingService.create("localhost", 1234) + assert service is not None + assert service.get_name() == STT_INFO.asr[0].name + + with patch( + "homeassistant.components.wyoming.data.AsyncTcpClient", + MockAsyncTcpClient([TTS_INFO.event()]), + ): + service = await WyomingService.create("localhost", 1234) + assert service is not None + assert service.get_name() == TTS_INFO.tts[0].name + + with patch( + "homeassistant.components.wyoming.data.AsyncTcpClient", + MockAsyncTcpClient([WAKE_WORD_INFO.event()]), + ): + service = await WyomingService.create("localhost", 1234) + assert service is not None + assert service.get_name() == WAKE_WORD_INFO.wake[0].name + + with patch( + "homeassistant.components.wyoming.data.AsyncTcpClient", + MockAsyncTcpClient([SATELLITE_INFO.event()]), + ): + service = await WyomingService.create("localhost", 1234) + assert service is not None + assert service.get_name() == SATELLITE_INFO.satellite.name diff --git a/tests/components/wyoming/test_select.py b/tests/components/wyoming/test_select.py index 3b12210193b2f..7d3a6e07e2bf5 100644 --- a/tests/components/wyoming/test_select.py +++ b/tests/components/wyoming/test_select.py @@ -1,8 +1,13 @@ """Test Wyoming select.""" +from unittest.mock import Mock, patch + +from homeassistant.components import assist_pipeline +from homeassistant.components.assist_pipeline.pipeline import PipelineData from homeassistant.components.assist_pipeline.select import OPTION_PREFERRED from homeassistant.components.wyoming.devices import SatelliteDevice from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant +from homeassistant.setup import async_setup_component async def test_pipeline_select( @@ -15,9 +20,66 @@ async def test_pipeline_select( Functionality is tested in assist_pipeline/test_select.py. This test is only to ensure it is set up. """ + assert await async_setup_component(hass, assist_pipeline.DOMAIN, {}) + pipeline_data: PipelineData = hass.data[assist_pipeline.DOMAIN] + + # Create second pipeline + await pipeline_data.pipeline_store.async_create_item( + { + "name": "Test 1", + "language": "en-US", + "conversation_engine": None, + "conversation_language": "en-US", + "tts_engine": None, + "tts_language": None, + "tts_voice": None, + "stt_engine": None, + "stt_language": None, + "wake_word_entity": None, + "wake_word_id": None, + } + ) + + # Preferred pipeline is the default pipeline_entity_id = satellite_device.get_pipeline_entity_id(hass) assert pipeline_entity_id state = hass.states.get(pipeline_entity_id) assert state is not None assert state.state == OPTION_PREFERRED + + # Change to second pipeline + with patch.object( + satellite_device, "async_pipeline_changed" + ) as mock_pipeline_changed: + await hass.services.async_call( + "select", + "select_option", + {"entity_id": pipeline_entity_id, "option": "Test 1"}, + blocking=True, + ) + + state = hass.states.get(pipeline_entity_id) + assert state is not None + assert state.state == "Test 1" + + # async_pipeline_changed should have been called + mock_pipeline_changed.assert_called() + + # Change back and check update listener + update_listener = Mock() + satellite_device.async_listen_update(update_listener) + + await hass.services.async_call( + "select", + "select_option", + {"entity_id": pipeline_entity_id, "option": OPTION_PREFERRED}, + blocking=True, + ) + + state = hass.states.get(pipeline_entity_id) + assert state is not None + assert state.state == OPTION_PREFERRED + + # update listener should have been called + update_listener.assert_called() From a50e9dba73d0a4c8cbbb2a21dbe5adf35db6d903 Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Fri, 1 Dec 2023 16:34:18 -0600 Subject: [PATCH 15/18] More satellite tests --- homeassistant/components/wyoming/__init__.py | 14 +- homeassistant/components/wyoming/satellite.py | 42 +++-- tests/components/wyoming/test_satellite.py | 159 +++++++++++++++++- 3 files changed, 200 insertions(+), 15 deletions(-) diff --git a/homeassistant/components/wyoming/__init__.py b/homeassistant/components/wyoming/__init__.py index 19a487029d699..deec2de21a50d 100644 --- a/homeassistant/components/wyoming/__init__.py +++ b/homeassistant/components/wyoming/__init__.py @@ -10,7 +10,7 @@ from .const import ATTR_SPEAKER, DOMAIN from .data import WyomingService -from .devices import SatelliteDevices +from .devices import SatelliteDevice, SatelliteDevices from .models import DomainDataItem from .satellite import WyomingSatellite @@ -51,7 +51,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: name=satellite_info.name, suggested_area=satellite_info.area, ) - item.satellite = WyomingSatellite(hass, service, satellite_device) + item.satellite = _make_satellite(hass, service, satellite_device) entry.async_create_background_task( hass, item.satellite.run(), @@ -63,6 +63,16 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: return True +def _make_satellite( + hass: HomeAssistant, service: WyomingService, satellite_device: SatelliteDevice +) -> WyomingSatellite: + """Create WyomingSatellite from service/device. + + Used to make testing more convenient. + """ + return WyomingSatellite(hass, service, satellite_device) + + async def update_listener(hass: HomeAssistant, entry: ConfigEntry): """Handle options update.""" await hass.config_entries.async_reload(entry.entry_id) diff --git a/homeassistant/components/wyoming/satellite.py b/homeassistant/components/wyoming/satellite.py index df894cd201385..a53489afca0f5 100644 --- a/homeassistant/components/wyoming/satellite.py +++ b/homeassistant/components/wyoming/satellite.py @@ -72,21 +72,19 @@ async def run(self) -> None: try: while self.is_running: try: + # Check if satellite has been disabled if not self.is_enabled: - await self._device_updated_event.wait() + await self.on_disabled() if not self.is_running: # Satellite was stopped while waiting to be enabled break + # Connect and run pipeline loop await self._run_once() except asyncio.CancelledError: raise except Exception: # pylint: disable=broad-exception-caught - _LOGGER.exception( - "Unexpected error running satellite. Restarting in %s second(s)", - _RECONNECT_SECONDS, - ) - await asyncio.sleep(_RESTART_SECONDS) + await self.on_restart() finally: # Ensure sensor is off if self.device.is_active: @@ -94,7 +92,7 @@ async def run(self) -> None: remove_listener() - _LOGGER.debug("Satellite task stopped") + await self.on_stopped() def stop(self) -> None: """Signal satellite task to stop running.""" @@ -103,6 +101,30 @@ def stop(self) -> None: # Unblock waiting for enabled self._device_updated_event.set() + async def on_restart(self) -> None: + """Block until pipeline loop will be restarted.""" + _LOGGER.warning( + "Unexpected error running satellite. Restarting in %s second(s)", + _RECONNECT_SECONDS, + ) + await asyncio.sleep(_RESTART_SECONDS) + + async def on_reconnect(self) -> None: + """Block until a reconnection attempt should be made.""" + _LOGGER.debug( + "Failed to connect to satellite. Reconnecting in %s second(s)", + _RECONNECT_SECONDS, + ) + await asyncio.sleep(_RECONNECT_SECONDS) + + async def on_disabled(self) -> None: + """Block until device may be enabled again.""" + await self._device_updated_event.wait() + + async def on_stopped(self) -> None: + """Run when run() has fully stopped.""" + _LOGGER.debug("Satellite task stopped") + # ------------------------------------------------------------------------- def _device_updated(self, device: SatelliteDevice) -> None: @@ -139,11 +161,7 @@ async def _run_once(self) -> None: await self._connect() break except ConnectionError: - _LOGGER.debug( - "Failed to connect to satellite. Reconnecting in %s second(s)", - _RECONNECT_SECONDS, - ) - await asyncio.sleep(_RECONNECT_SECONDS) + await self.on_reconnect() assert self._client is not None _LOGGER.debug("Connected to satellite") diff --git a/tests/components/wyoming/test_satellite.py b/tests/components/wyoming/test_satellite.py index 7443bc7691c78..ac8fca4bce676 100644 --- a/tests/components/wyoming/test_satellite.py +++ b/tests/components/wyoming/test_satellite.py @@ -16,7 +16,9 @@ from wyoming.wake import Detect, Detection from homeassistant.components import assist_pipeline, wyoming -from homeassistant.components.wyoming.devices import SatelliteDevices +from homeassistant.components.wyoming.data import WyomingService +from homeassistant.components.wyoming.devices import SatelliteDevice, SatelliteDevices +from homeassistant.components.wyoming.satellite import WyomingSatellite from homeassistant.core import HomeAssistant from homeassistant.setup import async_setup_component @@ -302,3 +304,158 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None: # Stop the satellite await hass.config_entries.async_unload(entry.entry_id) await hass.async_block_till_done() + + +async def test_satellite_disabled(hass: HomeAssistant) -> None: + """Test callback for a satellite that has been disabled.""" + on_disabled_event = asyncio.Event() + + def make_disabled_satellite( + hass: HomeAssistant, service: WyomingService, device: SatelliteDevice + ): + device.is_enabled = False + satellite = WyomingSatellite(hass, service, device) + return satellite + + async def on_disabled(self): + on_disabled_event.set() + + with patch( + "homeassistant.components.wyoming.data.load_wyoming_info", + return_value=SATELLITE_INFO, + ), patch( + "homeassistant.components.wyoming._make_satellite", make_disabled_satellite + ), patch( + "homeassistant.components.wyoming.satellite.WyomingSatellite.on_disabled", + on_disabled, + ): + await setup_config_entry(hass) + async with asyncio.timeout(1): + await on_disabled_event.wait() + + +async def test_satellite_restart(hass: HomeAssistant) -> None: + """Test pipeline loop restart after unexpected error.""" + on_restart_event = asyncio.Event() + + async def on_restart(self): + self.stop() + on_restart_event.set() + + with patch( + "homeassistant.components.wyoming.data.load_wyoming_info", + return_value=SATELLITE_INFO, + ), patch( + "homeassistant.components.wyoming.satellite.WyomingSatellite._run_once", + side_effect=RuntimeError(), + ), patch( + "homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart", + on_restart, + ): + await setup_config_entry(hass) + async with asyncio.timeout(1): + await on_restart_event.wait() + + +async def test_satellite_reconnect(hass: HomeAssistant) -> None: + """Test satellite reconnect call after connection refused.""" + on_reconnect_event = asyncio.Event() + + async def on_reconnect(self): + self.stop() + on_reconnect_event.set() + + with patch( + "homeassistant.components.wyoming.data.load_wyoming_info", + return_value=SATELLITE_INFO, + ), patch( + "homeassistant.components.wyoming.satellite.AsyncTcpClient.connect", + side_effect=ConnectionRefusedError(), + ), patch( + "homeassistant.components.wyoming.satellite.WyomingSatellite.on_reconnect", + on_reconnect, + ): + await setup_config_entry(hass) + async with asyncio.timeout(1): + await on_reconnect_event.wait() + + +async def test_satellite_disconnect_before_pipeline(hass: HomeAssistant) -> None: + """Test satellite disconnecting before pipeline run.""" + on_restart_event = asyncio.Event() + + async def on_restart(self): + self.stop() + on_restart_event.set() + + with patch( + "homeassistant.components.wyoming.data.load_wyoming_info", + return_value=SATELLITE_INFO, + ), patch( + "homeassistant.components.wyoming.satellite.AsyncTcpClient", + MockAsyncTcpClient([]), # no RunPipeline event + ), patch( + "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", + ) as mock_run_pipeline, patch( + "homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart", + on_restart, + ): + await setup_config_entry(hass) + async with asyncio.timeout(1): + await on_restart_event.wait() + + # Pipeline should never have run + mock_run_pipeline.assert_not_called() + + +async def test_satellite_disconnect_during_pipeline(hass: HomeAssistant) -> None: + """Test satellite disconnecting during pipeline run.""" + events = [ + RunPipeline( + start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS + ).event(), + ] # no audio chunks after RunPipeline + + on_restart_event = asyncio.Event() + on_stopped_event = asyncio.Event() + + async def on_restart(self): + # Pretend sensor got stuck on + self.device.is_active = True + self.stop() + on_restart_event.set() + + async def on_stopped(self): + on_stopped_event.set() + + with patch( + "homeassistant.components.wyoming.data.load_wyoming_info", + return_value=SATELLITE_INFO, + ), patch( + "homeassistant.components.wyoming.satellite.AsyncTcpClient", + MockAsyncTcpClient(events), + ), patch( + "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", + ) as mock_run_pipeline, patch( + "homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart", + on_restart, + ), patch( + "homeassistant.components.wyoming.satellite.WyomingSatellite.on_stopped", + on_stopped, + ): + entry = await setup_config_entry(hass) + satellite_devices: SatelliteDevices = hass.data[wyoming.DOMAIN][ + entry.entry_id + ].satellite_devices + assert entry.entry_id in satellite_devices.devices + device = satellite_devices.devices[entry.entry_id] + + async with asyncio.timeout(1): + await on_restart_event.wait() + await on_stopped_event.wait() + + # Pipeline should have run once + mock_run_pipeline.assert_called_once() + + # Sensor should have been turned off + assert not device.is_active From 3332114f3109af14082158f2f52cc28f5ced07c1 Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Sat, 2 Dec 2023 21:43:16 -0600 Subject: [PATCH 16/18] Only support single device per config entry --- homeassistant/components/wyoming/__init__.py | 43 ++++++--- .../components/wyoming/binary_sensor.py | 18 +--- homeassistant/components/wyoming/devices.py | 95 +------------------ homeassistant/components/wyoming/models.py | 2 - homeassistant/components/wyoming/select.py | 20 ++-- homeassistant/components/wyoming/strings.json | 3 + homeassistant/components/wyoming/switch.py | 21 +--- tests/components/wyoming/conftest.py | 19 +--- tests/components/wyoming/test_devices.py | 8 +- tests/components/wyoming/test_satellite.py | 25 +++-- 10 files changed, 66 insertions(+), 188 deletions(-) diff --git a/homeassistant/components/wyoming/__init__.py b/homeassistant/components/wyoming/__init__.py index deec2de21a50d..ff278e7a8336d 100644 --- a/homeassistant/components/wyoming/__init__.py +++ b/homeassistant/components/wyoming/__init__.py @@ -7,10 +7,11 @@ from homeassistant.const import Platform from homeassistant.core import HomeAssistant from homeassistant.exceptions import ConfigEntryNotReady +from homeassistant.helpers import device_registry as dr from .const import ATTR_SPEAKER, DOMAIN from .data import WyomingService -from .devices import SatelliteDevice, SatelliteDevices +from .devices import SatelliteDevice from .models import DomainDataItem from .satellite import WyomingSatellite @@ -33,25 +34,20 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: if service is None: raise ConfigEntryNotReady("Unable to connect") - satellite_devices = SatelliteDevices(hass, entry) - satellite_devices.async_setup() - - item = DomainDataItem(service=service, satellite_devices=satellite_devices) + item = DomainDataItem(service=service) hass.data.setdefault(DOMAIN, {})[entry.entry_id] = item await hass.config_entries.async_forward_entry_setups(entry, service.platforms) entry.async_on_unload(entry.add_update_listener(update_listener)) if (satellite_info := service.info.satellite) is not None: + # Create satellite device, etc. + item.satellite = _make_satellite(hass, entry, service) + # Set up satellite sensors, switches, etc. await hass.config_entries.async_forward_entry_setups(entry, SATELLITE_PLATFORMS) - # Run satellite connection in a separate task - satellite_device = satellite_devices.async_get_or_create( - name=satellite_info.name, - suggested_area=satellite_info.area, - ) - item.satellite = _make_satellite(hass, service, satellite_device) + # Start satellite communication entry.async_create_background_task( hass, item.satellite.run(), @@ -64,12 +60,29 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: def _make_satellite( - hass: HomeAssistant, service: WyomingService, satellite_device: SatelliteDevice + hass: HomeAssistant, config_entry: ConfigEntry, service: WyomingService ) -> WyomingSatellite: - """Create WyomingSatellite from service/device. + """Create Wyoming satellite/device from config entry and Wyoming service.""" + satellite_info = service.info.satellite + assert satellite_info is not None + + dev_reg = dr.async_get(hass) + + # Use config entry id since only one satellite per entry is supported + satellite_id = config_entry.entry_id + + device = dev_reg.async_get_or_create( + config_entry_id=config_entry.entry_id, + identifiers={(DOMAIN, satellite_id)}, + name=satellite_info.name, + suggested_area=satellite_info.area, + ) + + satellite_device = SatelliteDevice( + satellite_id=satellite_id, + device_id=device.id, + ) - Used to make testing more convenient. - """ return WyomingSatellite(hass, service, satellite_device) diff --git a/homeassistant/components/wyoming/binary_sensor.py b/homeassistant/components/wyoming/binary_sensor.py index 07efc4516587d..5a920fec35ee5 100644 --- a/homeassistant/components/wyoming/binary_sensor.py +++ b/homeassistant/components/wyoming/binary_sensor.py @@ -26,21 +26,11 @@ async def async_setup_entry( async_add_entities: AddEntitiesCallback, ) -> None: """Set up binary sensor entities.""" - domain_data: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id] + item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id] + if not item.satellite: + return - @callback - def async_add_device(device: SatelliteDevice) -> None: - """Add device.""" - async_add_entities([WyomingSatelliteAssistInProgress(device)]) - - domain_data.satellite_devices.async_add_new_device_listener(async_add_device) - - async_add_entities( - [ - WyomingSatelliteAssistInProgress(device) - for device in domain_data.satellite_devices - ] - ) + async_add_entities([WyomingSatelliteAssistInProgress(item.satellite.device)]) class WyomingSatelliteAssistInProgress(WyomingSatelliteEntity, BinarySensorEntity): diff --git a/homeassistant/components/wyoming/devices.py b/homeassistant/components/wyoming/devices.py index 53d38946fd7e8..432781b76c643 100644 --- a/homeassistant/components/wyoming/devices.py +++ b/homeassistant/components/wyoming/devices.py @@ -1,12 +1,11 @@ """Class to manage satellite devices.""" from __future__ import annotations -from collections.abc import Callable, Iterator +from collections.abc import Callable from dataclasses import dataclass, field -from homeassistant.config_entries import ConfigEntry -from homeassistant.core import Event, HomeAssistant, callback -from homeassistant.helpers import device_registry as dr, entity_registry as er +from homeassistant.core import HomeAssistant, callback +from homeassistant.helpers import entity_registry as er from .const import DOMAIN @@ -71,91 +70,3 @@ def get_pipeline_entity_id(self, hass: HomeAssistant) -> str | None: return ent_reg.async_get_entity_id( "select", DOMAIN, f"{self.satellite_id}-pipeline" ) - - -class SatelliteDevices: - """Class to store devices.""" - - def __init__(self, hass: HomeAssistant, config_entry: ConfigEntry) -> None: - """Initialize satellite devices.""" - self.hass = hass - self.config_entry = config_entry - self._new_device_listeners: list[Callable[[SatelliteDevice], None]] = [] - - # satellite_id -> device - self.devices: dict[str, SatelliteDevice] = {} - - @callback - def async_setup(self) -> None: - """Set up devices.""" - for device in dr.async_entries_for_config_entry( - dr.async_get(self.hass), self.config_entry.entry_id - ): - satellite_id = next( - (item[1] for item in device.identifiers if item[0] == DOMAIN), None - ) - if satellite_id is None: - continue - self.devices[satellite_id] = SatelliteDevice( - satellite_id=satellite_id, - device_id=device.id, - ) - - @callback - def async_device_removed(ev: Event) -> None: - """Handle device removed.""" - removed_id = ev.data["device_id"] - self.devices = { - satellite_id: satellite_device - for satellite_id, satellite_device in self.devices.items() - if satellite_device.device_id != removed_id - } - - self.config_entry.async_on_unload( - self.hass.bus.async_listen( - dr.EVENT_DEVICE_REGISTRY_UPDATED, - async_device_removed, - callback(lambda ev: ev.data.get("action") == "remove"), - ) - ) - - @callback - def async_add_new_device_listener( - self, listener: Callable[[SatelliteDevice], None] - ) -> None: - """Add a new device listener.""" - self._new_device_listeners.append(listener) - - @callback - def async_get_or_create( - self, name: str | None = None, suggested_area: str | None = None - ) -> SatelliteDevice: - """Get or create a device.""" - dev_reg = dr.async_get(self.hass) - - # Use config entry id since only one satellite per entry is supported - satellite_id = self.config_entry.entry_id - satellite_device = self.devices.get(satellite_id) - - if satellite_device is not None: - return satellite_device - - device = dev_reg.async_get_or_create( - config_entry_id=self.config_entry.entry_id, - identifiers={(DOMAIN, satellite_id)}, - name=name or satellite_id, - suggested_area=suggested_area, - ) - - satellite_device = self.devices[satellite_id] = SatelliteDevice( - satellite_id=satellite_id, - device_id=device.id, - ) - for listener in self._new_device_listeners: - listener(satellite_device) - - return satellite_device - - def __iter__(self) -> Iterator[SatelliteDevice]: - """Iterate over devices.""" - return iter(self.devices.values()) diff --git a/homeassistant/components/wyoming/models.py b/homeassistant/components/wyoming/models.py index c29e28ea3f3f1..dce45d509eb86 100644 --- a/homeassistant/components/wyoming/models.py +++ b/homeassistant/components/wyoming/models.py @@ -2,7 +2,6 @@ from dataclasses import dataclass from .data import WyomingService -from .devices import SatelliteDevices from .satellite import WyomingSatellite @@ -11,5 +10,4 @@ class DomainDataItem: """Domain data item.""" service: WyomingService - satellite_devices: SatelliteDevices satellite: WyomingSatellite | None = None diff --git a/homeassistant/components/wyoming/select.py b/homeassistant/components/wyoming/select.py index b9ec0794a5b47..c9417215b368f 100644 --- a/homeassistant/components/wyoming/select.py +++ b/homeassistant/components/wyoming/select.py @@ -6,7 +6,7 @@ from homeassistant.components.assist_pipeline.select import AssistPipelineSelect from homeassistant.config_entries import ConfigEntry -from homeassistant.core import HomeAssistant, callback +from homeassistant.core import HomeAssistant from homeassistant.helpers.entity_platform import AddEntitiesCallback from .const import DOMAIN @@ -23,20 +23,12 @@ async def async_setup_entry( async_add_entities: AddEntitiesCallback, ) -> None: """Set up VoIP switch entities.""" - domain_data: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id] + item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id] + if not item.satellite: + return - @callback - def async_add_device(device: SatelliteDevice) -> None: - """Add device.""" - async_add_entities([WyomingSatellitePipelineSelect(hass, device)]) - - domain_data.satellite_devices.async_add_new_device_listener(async_add_device) - - entities: list[WyomingSatelliteEntity] = [] - for device in domain_data.satellite_devices: - entities.append(WyomingSatellitePipelineSelect(hass, device)) - - async_add_entities(entities) + device = item.satellite.device + async_add_entities([WyomingSatellitePipelineSelect(hass, device)]) class WyomingSatellitePipelineSelect(WyomingSatelliteEntity, AssistPipelineSelect): diff --git a/homeassistant/components/wyoming/strings.json b/homeassistant/components/wyoming/strings.json index f5a8a2de2ce32..19b6a513d4ba4 100644 --- a/homeassistant/components/wyoming/strings.json +++ b/homeassistant/components/wyoming/strings.json @@ -36,6 +36,9 @@ "state": { "preferred": "[%key:component::assist_pipeline::entity::select::pipeline::state::preferred%]" } + }, + "noise_suppression": { + "name": "Noise suppression" } }, "switch": { diff --git a/homeassistant/components/wyoming/switch.py b/homeassistant/components/wyoming/switch.py index 374b953cc8235..2f2f552ec1235 100644 --- a/homeassistant/components/wyoming/switch.py +++ b/homeassistant/components/wyoming/switch.py @@ -7,12 +7,11 @@ from homeassistant.components.switch import SwitchEntity, SwitchEntityDescription from homeassistant.config_entries import ConfigEntry from homeassistant.const import STATE_ON, EntityCategory -from homeassistant.core import HomeAssistant, callback +from homeassistant.core import HomeAssistant from homeassistant.helpers import restore_state from homeassistant.helpers.entity_platform import AddEntitiesCallback from .const import DOMAIN -from .devices import SatelliteDevice from .entity import WyomingSatelliteEntity if TYPE_CHECKING: @@ -25,21 +24,11 @@ async def async_setup_entry( async_add_entities: AddEntitiesCallback, ) -> None: """Set up VoIP switch entities.""" - domain_data: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id] + item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id] + if not item.satellite: + return - @callback - def async_add_device(device: SatelliteDevice) -> None: - """Add device.""" - async_add_entities([WyomingSatelliteEnabledSwitch(device)]) - - domain_data.satellite_devices.async_add_new_device_listener(async_add_device) - - async_add_entities( - [ - WyomingSatelliteEnabledSwitch(device) - for device in domain_data.satellite_devices - ] - ) + async_add_entities([WyomingSatelliteEnabledSwitch(item.satellite.device)]) class WyomingSatelliteEnabledSwitch( diff --git a/tests/components/wyoming/conftest.py b/tests/components/wyoming/conftest.py index c70c40342e377..a30c1048eb6e9 100644 --- a/tests/components/wyoming/conftest.py +++ b/tests/components/wyoming/conftest.py @@ -6,7 +6,7 @@ from homeassistant.components import stt from homeassistant.components.wyoming import DOMAIN -from homeassistant.components.wyoming.devices import SatelliteDevice, SatelliteDevices +from homeassistant.components.wyoming.devices import SatelliteDevice from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant from homeassistant.setup import async_setup_component @@ -149,22 +149,9 @@ async def init_satellite(hass: HomeAssistant, satellite_config_entry: ConfigEntr await hass.config_entries.async_setup(satellite_config_entry.entry_id) -@pytest.fixture -async def satellite_devices( - hass: HomeAssistant, init_satellite, satellite_config_entry: ConfigEntry -) -> SatelliteDevices: - """Get satellite devices object from a configured instance.""" - return hass.data[DOMAIN][satellite_config_entry.entry_id].satellite_devices - - @pytest.fixture async def satellite_device( - hass: HomeAssistant, - satellite_devices: SatelliteDevices, - satellite_config_entry: ConfigEntry, + hass: HomeAssistant, init_satellite, satellite_config_entry: ConfigEntry ) -> SatelliteDevice: """Get a satellite device fixture.""" - device = satellite_devices.async_get_or_create() - # to make sure all platforms are set up - await hass.async_block_till_done() - return device + return hass.data[DOMAIN][satellite_config_entry.entry_id].satellite.device diff --git a/tests/components/wyoming/test_devices.py b/tests/components/wyoming/test_devices.py index 58616673cd5b8..549f76f20f1c2 100644 --- a/tests/components/wyoming/test_devices.py +++ b/tests/components/wyoming/test_devices.py @@ -3,7 +3,7 @@ from homeassistant.components.assist_pipeline.select import OPTION_PREFERRED from homeassistant.components.wyoming import DOMAIN -from homeassistant.components.wyoming.devices import SatelliteDevice, SatelliteDevices +from homeassistant.components.wyoming.devices import SatelliteDevice from homeassistant.config_entries import ConfigEntry from homeassistant.const import STATE_OFF, STATE_ON from homeassistant.core import HomeAssistant @@ -12,12 +12,11 @@ async def test_device_registry_info( hass: HomeAssistant, - satellite_devices: SatelliteDevices, + satellite_device: SatelliteDevice, satellite_config_entry: ConfigEntry, device_registry: dr.DeviceRegistry, ) -> None: """Test info in device registry.""" - satellite_device = satellite_devices.async_get_or_create() # Satellite uses config entry id since only one satellite per entry is # supported. @@ -50,12 +49,10 @@ async def test_device_registry_info( async def test_remove_device_registry_entry( hass: HomeAssistant, - satellite_devices: SatelliteDevices, satellite_device: SatelliteDevice, device_registry: dr.DeviceRegistry, ) -> None: """Test removing a device registry entry.""" - assert satellite_device.satellite_id in satellite_devices.devices # Check associated entities assist_in_progress_id = satellite_device.get_assist_in_progress_entity_id(hass) @@ -79,4 +76,3 @@ async def test_remove_device_registry_entry( assert hass.states.get(assist_in_progress_id) is None assert hass.states.get(satellite_enabled_id) is None assert hass.states.get(pipeline_entity_id) is None - assert satellite_device.satellite_id not in satellite_devices.devices diff --git a/tests/components/wyoming/test_satellite.py b/tests/components/wyoming/test_satellite.py index ac8fca4bce676..06ae337a19cd2 100644 --- a/tests/components/wyoming/test_satellite.py +++ b/tests/components/wyoming/test_satellite.py @@ -17,8 +17,8 @@ from homeassistant.components import assist_pipeline, wyoming from homeassistant.components.wyoming.data import WyomingService -from homeassistant.components.wyoming.devices import SatelliteDevice, SatelliteDevices -from homeassistant.components.wyoming.satellite import WyomingSatellite +from homeassistant.components.wyoming.devices import SatelliteDevice +from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant from homeassistant.setup import async_setup_component @@ -167,11 +167,9 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None: return_value=("wav", get_test_wav()), ): entry = await setup_config_entry(hass) - satellite_devices: SatelliteDevices = hass.data[wyoming.DOMAIN][ + device: SatelliteDevice = hass.data[wyoming.DOMAIN][ entry.entry_id - ].satellite_devices - assert entry.entry_id in satellite_devices.devices - device = satellite_devices.devices[entry.entry_id] + ].satellite.device async with asyncio.timeout(1): await mock_client.connect_event.wait() @@ -310,11 +308,14 @@ async def test_satellite_disabled(hass: HomeAssistant) -> None: """Test callback for a satellite that has been disabled.""" on_disabled_event = asyncio.Event() + original_make_satellite = wyoming._make_satellite + def make_disabled_satellite( - hass: HomeAssistant, service: WyomingService, device: SatelliteDevice + hass: HomeAssistant, config_entry: ConfigEntry, service: WyomingService ): - device.is_enabled = False - satellite = WyomingSatellite(hass, service, device) + satellite = original_make_satellite(hass, config_entry, service) + satellite.device.is_enabled = False + return satellite async def on_disabled(self): @@ -444,11 +445,9 @@ async def on_stopped(self): on_stopped, ): entry = await setup_config_entry(hass) - satellite_devices: SatelliteDevices = hass.data[wyoming.DOMAIN][ + device: SatelliteDevice = hass.data[wyoming.DOMAIN][ entry.entry_id - ].satellite_devices - assert entry.entry_id in satellite_devices.devices - device = satellite_devices.devices[entry.entry_id] + ].satellite.device async with asyncio.timeout(1): await on_restart_event.wait() From c4bf29d68542b5d4d8c54d63e5bd277f9a4435ed Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Sun, 3 Dec 2023 14:40:09 -0600 Subject: [PATCH 17/18] Address comments --- homeassistant/components/wyoming/__init__.py | 9 +- .../components/wyoming/binary_sensor.py | 10 +-- homeassistant/components/wyoming/devices.py | 51 ++++++----- homeassistant/components/wyoming/satellite.py | 84 ++++++++----------- homeassistant/components/wyoming/select.py | 10 +-- homeassistant/components/wyoming/switch.py | 9 +- tests/components/wyoming/test_select.py | 14 ++-- tests/components/wyoming/test_switch.py | 27 ------ 8 files changed, 94 insertions(+), 120 deletions(-) diff --git a/homeassistant/components/wyoming/__init__.py b/homeassistant/components/wyoming/__init__.py index ff278e7a8336d..22bd3aefe2149 100644 --- a/homeassistant/components/wyoming/__init__.py +++ b/homeassistant/components/wyoming/__init__.py @@ -95,10 +95,11 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Unload Wyoming.""" item: DomainDataItem = hass.data[DOMAIN][entry.entry_id] - unload_ok = await hass.config_entries.async_unload_platforms( - entry, - item.service.platforms, - ) + platforms = item.service.platforms + if item.satellite is not None: + platforms += SATELLITE_PLATFORMS + + unload_ok = await hass.config_entries.async_unload_platforms(entry, platforms) if unload_ok: del hass.data[DOMAIN][entry.entry_id] diff --git a/homeassistant/components/wyoming/binary_sensor.py b/homeassistant/components/wyoming/binary_sensor.py index 5a920fec35ee5..4f2c0bb170acb 100644 --- a/homeassistant/components/wyoming/binary_sensor.py +++ b/homeassistant/components/wyoming/binary_sensor.py @@ -14,7 +14,6 @@ from .const import DOMAIN from .entity import WyomingSatelliteEntity -from .satellite import SatelliteDevice if TYPE_CHECKING: from .models import DomainDataItem @@ -27,8 +26,9 @@ async def async_setup_entry( ) -> None: """Set up binary sensor entities.""" item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id] - if not item.satellite: - return + + # Setup is only forwarded for satellites + assert item.satellite is not None async_add_entities([WyomingSatelliteAssistInProgress(item.satellite.device)]) @@ -46,10 +46,10 @@ async def async_added_to_hass(self) -> None: """Call when entity about to be added to hass.""" await super().async_added_to_hass() - self.async_on_remove(self._device.async_listen_update(self._is_active_changed)) + self._device.set_is_active_listener(self._is_active_changed) @callback - def _is_active_changed(self, device: SatelliteDevice) -> None: + def _is_active_changed(self) -> None: """Call when active state changed.""" self._attr_is_on = self._device.is_active self.async_write_ha_state() diff --git a/homeassistant/components/wyoming/devices.py b/homeassistant/components/wyoming/devices.py index 432781b76c643..90dad8897078b 100644 --- a/homeassistant/components/wyoming/devices.py +++ b/homeassistant/components/wyoming/devices.py @@ -2,7 +2,7 @@ from __future__ import annotations from collections.abc import Callable -from dataclasses import dataclass, field +from dataclasses import dataclass from homeassistant.core import HomeAssistant, callback from homeassistant.helpers import entity_registry as er @@ -18,37 +18,50 @@ class SatelliteDevice: device_id: str is_active: bool = False is_enabled: bool = True - update_listeners: list[Callable[[SatelliteDevice], None]] = field( - default_factory=list - ) + pipeline_name: str | None = None + + _is_active_listener: Callable[[], None] | None = None + _is_enabled_listener: Callable[[], None] | None = None + _pipeline_listener: Callable[[], None] | None = None @callback def set_is_active(self, active: bool) -> None: """Set active state.""" - self.is_active = active - for listener in self.update_listeners: - listener(self) + if active != self.is_active: + self.is_active = active + if self._is_active_listener is not None: + self._is_active_listener() @callback def set_is_enabled(self, enabled: bool) -> None: """Set enabled state.""" - self.is_enabled = enabled - for listener in self.update_listeners: - listener(self) + if enabled != self.is_enabled: + self.is_enabled = enabled + if self._is_enabled_listener is not None: + self._is_enabled_listener() @callback - def async_pipeline_changed(self) -> None: + def set_pipeline_name(self, pipeline_name: str) -> None: """Inform listeners that pipeline selection has changed.""" - for listener in self.update_listeners: - listener(self) + if pipeline_name != self.pipeline_name: + self.pipeline_name = pipeline_name + if self._pipeline_listener is not None: + self._pipeline_listener() + + @callback + def set_is_active_listener(self, is_active_listener: Callable[[], None]) -> None: + """Listen for updates to is_active.""" + self._is_active_listener = is_active_listener + + @callback + def set_is_enabled_listener(self, is_enabled_listener: Callable[[], None]) -> None: + """Listen for updates to is_enabled.""" + self._is_enabled_listener = is_enabled_listener @callback - def async_listen_update( - self, listener: Callable[[SatelliteDevice], None] - ) -> Callable[[], None]: - """Listen for updates.""" - self.update_listeners.append(listener) - return lambda: self.update_listeners.remove(listener) + def set_pipeline_listener(self, pipeline_listener: Callable[[], None]) -> None: + """Listen for updates to pipeline.""" + self._pipeline_listener = pipeline_listener def get_assist_in_progress_entity_id(self, hass: HomeAssistant) -> str | None: """Return entity id for assist in progress binary sensor.""" diff --git a/homeassistant/components/wyoming/satellite.py b/homeassistant/components/wyoming/satellite.py index a53489afca0f5..caf65db115eea 100644 --- a/homeassistant/components/wyoming/satellite.py +++ b/homeassistant/components/wyoming/satellite.py @@ -56,24 +56,20 @@ def __init__( self._is_pipeline_running = False self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue() self._pipeline_id: str | None = None - self._device_updated_event = asyncio.Event() + self._enabled_changed_event = asyncio.Event() + + self.device.set_is_enabled_listener(self._enabled_changed) + self.device.set_pipeline_listener(self._pipeline_changed) async def run(self) -> None: """Run and maintain a connection to satellite.""" _LOGGER.debug("Running satellite task") - self._pipeline_id = pipeline_select.get_chosen_pipeline( - self.hass, - DOMAIN, - self.device.satellite_id, - ) - self.is_enabled = self.device.is_enabled - remove_listener = self.device.async_listen_update(self._device_updated) try: while self.is_running: try: # Check if satellite has been disabled - if not self.is_enabled: + if not self.device.is_enabled: await self.on_disabled() if not self.is_running: # Satellite was stopped while waiting to be enabled @@ -87,10 +83,7 @@ async def run(self) -> None: await self.on_restart() finally: # Ensure sensor is off - if self.device.is_active: - self.device.set_is_active(False) - - remove_listener() + self.device.set_is_active(False) await self.on_stopped() @@ -99,7 +92,7 @@ def stop(self) -> None: self.is_running = False # Unblock waiting for enabled - self._device_updated_event.set() + self._enabled_changed_event.set() async def on_restart(self) -> None: """Block until pipeline loop will be restarted.""" @@ -119,7 +112,7 @@ async def on_reconnect(self) -> None: async def on_disabled(self) -> None: """Block until device may be enabled again.""" - await self._device_updated_event.wait() + await self._enabled_changed_event.wait() async def on_stopped(self) -> None: """Run when run() has fully stopped.""" @@ -127,34 +120,24 @@ async def on_stopped(self) -> None: # ------------------------------------------------------------------------- - def _device_updated(self, device: SatelliteDevice) -> None: - """Reacts to updated device settings.""" - pipeline_id = pipeline_select.get_chosen_pipeline( - self.hass, - DOMAIN, - self.device.satellite_id, - ) + def _enabled_changed(self) -> None: + """Run when device enabled status changes.""" - stop_pipeline = False - if self._pipeline_id != pipeline_id: - # Pipeline has changed - self._pipeline_id = pipeline_id - stop_pipeline = True + if not self.device.is_enabled: + # Cancel any running pipeline + self._audio_queue.put_nowait(None) - if self.is_enabled and (not self.device.is_enabled): - stop_pipeline = True + self._enabled_changed_event.set() - self.is_enabled = self.device.is_enabled - self._device_updated_event.set() - self._device_updated_event.clear() + def _pipeline_changed(self) -> None: + """Run when device pipeline changes.""" - if stop_pipeline: - self._audio_queue.put_nowait(None) + # Cancel any running pipeline + self._audio_queue.put_nowait(None) async def _run_once(self) -> None: """Run pipelines until an error occurs.""" - if self.device.is_active: - self.device.set_is_active(False) + self.device.set_is_active(False) while self.is_running and self.is_enabled: try: @@ -167,7 +150,7 @@ async def _run_once(self) -> None: _LOGGER.debug("Connected to satellite") if (not self.is_running) or (not self.is_enabled): - # Run was cancelled or satellite was disabled + # Run was cancelled or satellite was disabled during connection return # Tell satellite that we're ready @@ -175,7 +158,7 @@ async def _run_once(self) -> None: # Wait until we get RunPipeline event run_pipeline: RunPipeline | None = None - while True: + while self.is_running and self.is_enabled: run_event = await self._client.read_event() if run_event is None: raise ConnectionResetError("Satellite disconnected") @@ -189,8 +172,9 @@ async def _run_once(self) -> None: assert run_pipeline is not None _LOGGER.debug("Received run information: %s", run_pipeline) - if not self.is_running: - # Run was cancelled + if (not self.is_running) or (not self.is_enabled): + # Run was cancelled or satellite was disabled while waiting for + # RunPipeline event. return start_stage = _STAGES.get(run_pipeline.start_stage) @@ -205,7 +189,12 @@ async def _run_once(self) -> None: # Each loop is a pipeline run while self.is_running and self.is_enabled: # Use select to get pipeline each time in case it's changed - pipeline = assist_pipeline.async_get_pipeline(self.hass, self._pipeline_id) + pipeline_id = pipeline_select.get_chosen_pipeline( + self.hass, + DOMAIN, + self.device.satellite_id, + ) + pipeline = assist_pipeline.async_get_pipeline(self.hass, pipeline_id) assert pipeline is not None # We will push audio in through a queue @@ -237,10 +226,11 @@ async def _run_once(self) -> None: start_stage=start_stage, end_stage=end_stage, tts_audio_output="wav", - pipeline_id=self._pipeline_id, + pipeline_id=pipeline_id, ) ) + # Run until pipeline is complete or cancelled with an empty audio chunk while self._is_pipeline_running: client_event = await self._client.read_event() if client_event is None: @@ -261,15 +251,14 @@ def _event_callback(self, event: assist_pipeline.PipelineEvent) -> None: assert self._client is not None if event.type == assist_pipeline.PipelineEventType.RUN_END: + # Pipeline run is complete self._is_pipeline_running = False - if self.device.is_active: - self.device.set_is_active(False) + self.device.set_is_active(False) elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_START: self.hass.add_job(self._client.write_event(Detect().event())) elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_END: # Wake word detection - if not self.device.is_active: - self.device.set_is_active(True) + self.device.set_is_active(True) # Inform client of wake word detection if event.data and (wake_word_output := event.data.get("wake_word_output")): @@ -280,8 +269,7 @@ def _event_callback(self, event: assist_pipeline.PipelineEvent) -> None: self.hass.add_job(self._client.write_event(detection.event())) elif event.type == assist_pipeline.PipelineEventType.STT_START: # Speech-to-text - if not self.device.is_active: - self.device.set_is_active(True) + self.device.set_is_active(True) if event.data: self.hass.add_job( diff --git a/homeassistant/components/wyoming/select.py b/homeassistant/components/wyoming/select.py index c9417215b368f..2929ae79fa022 100644 --- a/homeassistant/components/wyoming/select.py +++ b/homeassistant/components/wyoming/select.py @@ -24,11 +24,11 @@ async def async_setup_entry( ) -> None: """Set up VoIP switch entities.""" item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id] - if not item.satellite: - return - device = item.satellite.device - async_add_entities([WyomingSatellitePipelineSelect(hass, device)]) + # Setup is only forwarded for satellites + assert item.satellite is not None + + async_add_entities([WyomingSatellitePipelineSelect(hass, item.satellite.device)]) class WyomingSatellitePipelineSelect(WyomingSatelliteEntity, AssistPipelineSelect): @@ -44,4 +44,4 @@ def __init__(self, hass: HomeAssistant, device: SatelliteDevice) -> None: async def async_select_option(self, option: str) -> None: """Select an option.""" await super().async_select_option(option) - self.device.async_pipeline_changed() + self.device.set_pipeline_name(option) diff --git a/homeassistant/components/wyoming/switch.py b/homeassistant/components/wyoming/switch.py index 2f2f552ec1235..2bc4312258891 100644 --- a/homeassistant/components/wyoming/switch.py +++ b/homeassistant/components/wyoming/switch.py @@ -25,8 +25,9 @@ async def async_setup_entry( ) -> None: """Set up VoIP switch entities.""" item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id] - if not item.satellite: - return + + # Setup is only forwarded for satellites + assert item.satellite is not None async_add_entities([WyomingSatelliteEnabledSwitch(item.satellite.device)]) @@ -54,11 +55,11 @@ async def async_added_to_hass(self) -> None: async def async_turn_on(self, **kwargs: Any) -> None: """Turn on.""" self._attr_is_on = True - self._device.set_is_enabled(True) self.async_write_ha_state() + self._device.set_is_enabled(True) async def async_turn_off(self, **kwargs: Any) -> None: """Turn off.""" self._attr_is_on = False - self._device.set_is_enabled(False) self.async_write_ha_state() + self._device.set_is_enabled(False) diff --git a/tests/components/wyoming/test_select.py b/tests/components/wyoming/test_select.py index 7d3a6e07e2bf5..cab699336fb57 100644 --- a/tests/components/wyoming/test_select.py +++ b/tests/components/wyoming/test_select.py @@ -49,9 +49,7 @@ async def test_pipeline_select( assert state.state == OPTION_PREFERRED # Change to second pipeline - with patch.object( - satellite_device, "async_pipeline_changed" - ) as mock_pipeline_changed: + with patch.object(satellite_device, "set_pipeline_name") as mock_pipeline_changed: await hass.services.async_call( "select", "select_option", @@ -64,11 +62,11 @@ async def test_pipeline_select( assert state.state == "Test 1" # async_pipeline_changed should have been called - mock_pipeline_changed.assert_called() + mock_pipeline_changed.assert_called_once_with("Test 1") # Change back and check update listener - update_listener = Mock() - satellite_device.async_listen_update(update_listener) + pipeline_listener = Mock() + satellite_device.set_pipeline_listener(pipeline_listener) await hass.services.async_call( "select", @@ -81,5 +79,5 @@ async def test_pipeline_select( assert state is not None assert state.state == OPTION_PREFERRED - # update listener should have been called - update_listener.assert_called() + # listener should have been called + pipeline_listener.assert_called_once() diff --git a/tests/components/wyoming/test_switch.py b/tests/components/wyoming/test_switch.py index 5b39db38e5c03..0b05724d76198 100644 --- a/tests/components/wyoming/test_switch.py +++ b/tests/components/wyoming/test_switch.py @@ -19,13 +19,6 @@ async def test_satellite_enabled( assert state.state == STATE_ON assert satellite_device.is_enabled - await hass.config_entries.async_reload(satellite_config_entry.entry_id) - - state = hass.states.get(satellite_enabled_id) - assert state is not None - assert state.state == STATE_ON - assert satellite_device.is_enabled - await hass.services.async_call( "switch", "turn_off", @@ -37,23 +30,3 @@ async def test_satellite_enabled( assert state is not None assert state.state == STATE_OFF assert not satellite_device.is_enabled - - await hass.config_entries.async_reload(satellite_config_entry.entry_id) - await hass.async_block_till_done() - - state = hass.states.get(satellite_enabled_id) - assert state is not None - assert state.state == STATE_OFF - assert not satellite_device.is_enabled - - await hass.services.async_call( - "switch", - "turn_on", - {"entity_id": satellite_enabled_id}, - blocking=True, - ) - - state = hass.states.get(satellite_enabled_id) - assert state is not None - assert state.state == STATE_ON - assert satellite_device.is_enabled From 19fe2935d46e825cb7842fadd9cfef0239cc16ba Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Mon, 4 Dec 2023 10:58:18 -0600 Subject: [PATCH 18/18] Make a copy of platforms --- homeassistant/components/wyoming/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/homeassistant/components/wyoming/__init__.py b/homeassistant/components/wyoming/__init__.py index 22bd3aefe2149..2cc9b7050a005 100644 --- a/homeassistant/components/wyoming/__init__.py +++ b/homeassistant/components/wyoming/__init__.py @@ -95,7 +95,7 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Unload Wyoming.""" item: DomainDataItem = hass.data[DOMAIN][entry.entry_id] - platforms = item.service.platforms + platforms = list(item.service.platforms) if item.satellite is not None: platforms += SATELLITE_PLATFORMS