From 9679ab01b47e88b49ab411fa6dc3d7c47d895254 Mon Sep 17 00:00:00 2001 From: "firstof9@gmail.com" Date: Thu, 16 Mar 2023 13:00:00 -0700 Subject: [PATCH] feat: add gps location option --- custom_components/nws_alerts/__init__.py | 30 +++++-- custom_components/nws_alerts/config_flow.py | 68 ++++++++++++++-- custom_components/nws_alerts/const.py | 1 + custom_components/nws_alerts/sensor.py | 31 +++---- .../nws_alerts/translations/en.json | 34 +++++++- tests/const.py | 4 +- tests/test_config_flow.py | 81 ++++++++++++++++++- tests/test_init.py | 5 +- tests/test_sensor.py | 5 +- 9 files changed, 221 insertions(+), 38 deletions(-) diff --git a/custom_components/nws_alerts/__init__.py b/custom_components/nws_alerts/__init__.py index 7386b6f..aab0d53 100644 --- a/custom_components/nws_alerts/__init__.py +++ b/custom_components/nws_alerts/__init__.py @@ -17,6 +17,7 @@ from .const import ( API_ENDPOINT, + CONF_GPS_LOC, CONF_INTERVAL, CONF_TIMEOUT, CONF_ZONE_ID, @@ -146,14 +147,22 @@ async def update_alerts(config) -> dict: async def async_get_state(config) -> dict: """Query API for status.""" + zone_id = "" + gps_loc = "" + url = "%s/alerts/active/count" % API_ENDPOINT values = {} headers = {"User-Agent": USER_AGENT, "Accept": "application/ld+json"} data = None - url = "%s/alerts/active/count" % API_ENDPOINT - zone_id = config[CONF_ZONE_ID] + + if CONF_ZONE_ID in config: + zone_id = config[CONF_ZONE_ID] + _LOGGER.debug("getting state for %s from %s" % (zone_id, url)) + elif CONF_GPS_LOC in config: + gps_loc = config[CONF_GPS_LOC] + _LOGGER.debug("getting state for %s from %s" % (gps_loc, url)) + async with aiohttp.ClientSession() as session: async with session.get(url, headers=headers) as r: - _LOGGER.debug("getting state for %s from %s" % (zone_id, url)) if r.status == 200: data = await r.json() @@ -173,22 +182,29 @@ async def async_get_state(config) -> dict: if "zones" in data: for zone in zone_id.split(","): if zone in data["zones"]: - values = await async_get_alerts(zone_id) + values = await async_get_alerts(zone_id, gps_loc) break return values -async def async_get_alerts(zone_id: str) -> dict: +async def async_get_alerts(zone_id: str = "", gps_loc: str = "") -> dict: """Query API for Alerts.""" + url = "" values = {} headers = {"User-Agent": USER_AGENT, "Accept": "application/geo+json"} data = None - url = "%s/alerts/active?zone=%s" % (API_ENDPOINT, zone_id) + + if zone_id != "": + url = "%s/alerts/active?zone=%s" % (API_ENDPOINT, zone_id) + _LOGGER.debug("getting alert for %s from %s" % (zone_id, url)) + elif gps_loc != "": + url = '%s/alerts/active?point=%s' % (API_ENDPOINT, gps_loc) + _LOGGER.debug("getting alert for %s from %s" % (gps_loc, url)) + async with aiohttp.ClientSession() as session: async with session.get(url, headers=headers) as r: - _LOGGER.debug("getting alert for %s from %s" % (zone_id, url)) if r.status == 200: data = await r.json() diff --git a/custom_components/nws_alerts/config_flow.py b/custom_components/nws_alerts/config_flow.py index 45b4b1f..784f7fa 100644 --- a/custom_components/nws_alerts/config_flow.py +++ b/custom_components/nws_alerts/config_flow.py @@ -13,6 +13,7 @@ from .const import ( API_ENDPOINT, + CONF_GPS_LOC, CONF_INTERVAL, CONF_TIMEOUT, CONF_ZONE_ID, @@ -28,9 +29,10 @@ JSON_ID = "id" _LOGGER = logging.getLogger(__name__) +MENU_OPTIONS = ["zone", "gps_loc"] -def _get_schema(hass: Any, user_input: list, default_dict: list) -> Any: +def _get_schema_zone(hass: Any, user_input: list, default_dict: list) -> Any: """Gets a schema using the default_dict as a backup.""" if user_input is None: user_input = {} @@ -48,6 +50,23 @@ def _get_default(key): } ) +def _get_schema_gps(hass: Any, user_input: list, default_dict: list) -> Any: + """Gets a schema using the default_dict as a backup.""" + if user_input is None: + user_input = {} + + def _get_default(key): + """Gets default value for key.""" + return user_input.get(key, default_dict.get(key)) + + return vol.Schema( + { + vol.Required(CONF_GPS_LOC, default=_get_default(CONF_GPS_LOC)): str, + vol.Optional(CONF_NAME, default=_get_default(CONF_NAME)): str, + vol.Optional(CONF_INTERVAL, default=_get_default(CONF_INTERVAL)): int, + vol.Optional(CONF_TIMEOUT, default=_get_default(CONF_TIMEOUT)): int, + } + ) async def _get_zone_list(self) -> list | None: """Return list of zone by lat/lon""" @@ -100,7 +119,42 @@ def __init__(self): # return self.async_abort(reason=next(iter(errors.values()))) # return result - async def async_step_user(self, user_input={}): + async def async_step_user( + self, user_input: dict[str, Any] | None = None + ) -> FlowResult: + """Handle the flow initialized by the user.""" + return self.async_show_menu(step_id="user", menu_options=MENU_OPTIONS) + + async def async_step_gps_loc(self, user_input={}): + """Handle a flow initialized by the user.""" + lat = self.hass.config.latitude + lon = self.hass.config.longitude + self._errors = {} + self._gps_loc = f"{lat},{lon}" + + if user_input is not None: + self._data.update(user_input) + return self.async_create_entry(title=self._data[CONF_NAME], data=self._data) + return await self._show_config_gps_loc(user_input) + + async def _show_config_gps_loc(self, user_input): + """Show the configuration form to edit location data.""" + + # Defaults + defaults = { + CONF_NAME: DEFAULT_NAME, + CONF_INTERVAL: DEFAULT_INTERVAL, + CONF_TIMEOUT: DEFAULT_TIMEOUT, + CONF_GPS_LOC: self._gps_loc, + } + + return self.async_show_form( + step_id="gps_loc", + data_schema=_get_schema_gps(self.hass, user_input, defaults), + errors=self._errors, + ) + + async def async_step_zone(self, user_input={}): """Handle a flow initialized by the user.""" self._errors = {} self._zone_list = await _get_zone_list(self) @@ -108,9 +162,9 @@ async def async_step_user(self, user_input={}): if user_input is not None: self._data.update(user_input) return self.async_create_entry(title=self._data[CONF_NAME], data=self._data) - return await self._show_config_form(user_input) + return await self._show_config_zone(user_input) - async def _show_config_form(self, user_input): + async def _show_config_zone(self, user_input): """Show the configuration form to edit location data.""" # Defaults @@ -122,8 +176,8 @@ async def _show_config_form(self, user_input): } return self.async_show_form( - step_id="user", - data_schema=_get_schema(self.hass, user_input, defaults), + step_id="zone", + data_schema=_get_schema_zone(self.hass, user_input, defaults), errors=self._errors, ) @@ -154,6 +208,6 @@ async def _show_options_form(self, user_input): return self.async_show_form( step_id="init", - data_schema=_get_schema(self.hass, user_input, self._data), + data_schema=_get_schema_zone(self.hass, user_input, self._data), errors=self._errors, ) diff --git a/custom_components/nws_alerts/const.py b/custom_components/nws_alerts/const.py index 3fc7e59..6086f83 100644 --- a/custom_components/nws_alerts/const.py +++ b/custom_components/nws_alerts/const.py @@ -6,6 +6,7 @@ CONF_TIMEOUT = "timeout" CONF_INTERVAL = "interval" CONF_ZONE_ID = "zone_id" +CONF_GPS_LOC = "gps_loc" # Defaults DEFAULT_ICON = "mdi:alert" diff --git a/custom_components/nws_alerts/sensor.py b/custom_components/nws_alerts/sensor.py index bcda0a8..7b27329 100644 --- a/custom_components/nws_alerts/sensor.py +++ b/custom_components/nws_alerts/sensor.py @@ -13,6 +13,7 @@ from .const import ( ATTRIBUTION, + CONF_GPS_LOC, CONF_INTERVAL, CONF_TIMEOUT, CONF_ZONE_ID, @@ -35,7 +36,8 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend( { - vol.Required(CONF_ZONE_ID): cv.string, + vol.Optional(CONF_ZONE_ID): cv.string, + vol.Optional(CONF_GPS_LOC): cv.string, vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string, vol.Optional(CONF_INTERVAL, default=DEFAULT_INTERVAL): int, vol.Optional(CONF_TIMEOUT, default=DEFAULT_TIMEOUT): int, @@ -47,10 +49,20 @@ async def async_setup_platform(hass, config, async_add_entities, discovery_info= """Configuration from yaml""" if DOMAIN not in hass.data.keys(): hass.data.setdefault(DOMAIN, {}) - config.entry_id = slugify(f"{config.get(CONF_ZONE_ID)}") + if CONF_ZONE_ID in config: + config.entry_id = slugify(f"{config.get(CONF_ZONE_ID)}") + elif CONF_GPS_LOC in config: + config.entry_id = slugify(f"{config.get(CONF_GPS_LOC)}") + elif CONF_GPS_LOC and CONF_ZONE_ID not in config: + raise ValueError("GPS or Zone needs to be configured.") config.data = config else: - config.entry_id = slugify(f"{config.get(CONF_ZONE_ID)}") + if CONF_ZONE_ID in config: + config.entry_id = slugify(f"{config.get(CONF_ZONE_ID)}") + elif CONF_GPS_LOC in config: + config.entry_id = slugify(f"{config.get(CONF_GPS_LOC)}") + elif CONF_GPS_LOC and CONF_ZONE_ID not in config: + raise ValueError("GPS or Zone needs to be configured.") config.data = config # Setup the data coordinator @@ -84,16 +96,6 @@ def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None: self._config = entry self._name = entry.data[CONF_NAME] self._icon = DEFAULT_ICON - self._state = 0 - self._event = None - self._event_id = None - self._message_type = None - self._event_status = None - self._event_severity = None - self._event_expires = None - self._display_desc = None - self._spoken_desc = None - self._zone_id = entry.data[CONF_ZONE_ID].replace(" ", "") self.coordinator = hass.data[DOMAIN][entry.entry_id][COORDINATOR] @property @@ -120,8 +122,7 @@ def state(self): return None elif "state" in self.coordinator.data.keys(): return self.coordinator.data["state"] - else: - return None + return None @property def extra_state_attributes(self): diff --git a/custom_components/nws_alerts/translations/en.json b/custom_components/nws_alerts/translations/en.json index 1a5494f..41ec64e 100644 --- a/custom_components/nws_alerts/translations/en.json +++ b/custom_components/nws_alerts/translations/en.json @@ -2,6 +2,22 @@ "config": { "step": { "user": { + "description": "Please select your NWS lookup method.", + "menu_options": { + "zone": "Zone ID (old method)", + "gps_loc": "GPS Location (new method)" + } + }, + "gps_loc": { + "description": "Please enter your latitude and longitude, by default your coordinates from Home Assistant are used.", + "data": { + "name": "Friendly Name", + "gps_loc": "Your GPS coordinates", + "interval": "Update Interval (in minutes)", + "timeout":"Update Timeout (in seconds)" + } + }, + "zone": { "data": { "name": "Friendly Name", "zone_id": "Zone ID(s)", @@ -14,7 +30,23 @@ }, "options": { "step": { - "init": { + "user": { + "description": "Please select your NWS lookup method.", + "menu_options": { + "zone": "Zone ID (old method)", + "gps_loc": "GPS Location (new method)" + } + }, + "gps_loc": { + "description": "Please enter your latitude and longitude, by default your coordinates from Home Assistant are used.", + "data": { + "name": "Friendly Name", + "gps_loc": "Your GPS coordinates", + "interval": "Update Interval (in minutes)", + "timeout":"Update Timeout (in seconds)" + } + }, + "zone": { "data": { "name": "Friendly Name", "zone_id": "Zone ID(s)", diff --git a/tests/const.py b/tests/const.py index 4cc850e..cda761e 100644 --- a/tests/const.py +++ b/tests/const.py @@ -1,4 +1,6 @@ """Constants for tests.""" CONFIG_DATA = {"name": "NWS Alerts", "zone_id": "AZZ540,AZC013"} -CONFIG_DATA_2 = {"name": "NWS Alerts YAML", "zone_id": "AZZ540"} \ No newline at end of file +CONFIG_DATA_2 = {"name": "NWS Alerts YAML", "zone_id": "AZZ540"} +CONFIG_DATA_3 = {"name": "NWS Alerts", "gps_loc": "123,-456"} +CONFIG_DATA_BAD = {"name": "NWS Alerts" } \ No newline at end of file diff --git a/tests/test_config_flow.py b/tests/test_config_flow.py index 3dda612..f02886e 100644 --- a/tests/test_config_flow.py +++ b/tests/test_config_flow.py @@ -5,9 +5,11 @@ from homeassistant import config_entries, data_entry_flow, setup from homeassistant.const import CONF_NAME from pytest_homeassistant_custom_component.common import MockConfigEntry +from homeassistant.data_entry_flow import FlowResult, FlowResultType from custom_components.nws_alerts.const import CONF_ZONE_ID, DOMAIN +pytestmark = pytest.mark.asyncio @pytest.mark.parametrize( "input,step_id,title,data", @@ -19,7 +21,7 @@ "interval": 5, "timeout": 120, }, - "user", + "zone", "Testing Alerts", { "name": "Testing Alerts", @@ -30,7 +32,7 @@ ), ], ) -async def test_form( +async def test_form_zone( input, step_id, title, @@ -46,8 +48,7 @@ async def test_form( result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": config_entries.SOURCE_USER} ) - assert result["type"] == "form" - assert result["errors"] == {} + assert result["type"] == FlowResultType.MENU # assert result["title"] == title_1 with patch( @@ -57,6 +58,13 @@ async def test_form( "custom_components.nws_alerts.config_flow._get_zone_list", return_value=None): + result = await hass.config_entries.flow.async_configure( + result["flow_id"], {"next_step_id": "zone"} + ) + await hass.async_block_till_done() + + assert result["type"] == FlowResultType.FORM + result2 = await hass.config_entries.flow.async_configure( result["flow_id"], input ) @@ -69,6 +77,71 @@ async def test_form( assert len(mock_setup_entry.mock_calls) == 1 +@pytest.mark.parametrize( + "input,step_id,title,data", + [ + ( + { + "name": "Testing Alerts", + "gps_loc": "123,-456", + "interval": 5, + "timeout": 120, + }, + "gps_loc", + "Testing Alerts", + { + "name": "Testing Alerts", + "gps_loc": "123,-456", + "interval": 5, + "timeout": 120, + }, + ), + ], +) +async def test_form_gps( + input, + step_id, + title, + data, + hass, +): + """Test we get the form.""" + await setup.async_setup_component(hass, "persistent_notification", {}) + with patch( + "custom_components.nws_alerts.config_flow._get_zone_list", + return_value=None): + + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + assert result["type"] == FlowResultType.MENU + # assert result["title"] == title_1 + + with patch( + "custom_components.nws_alerts.async_setup_entry", + return_value=True, + ) as mock_setup_entry, patch( + "custom_components.nws_alerts.config_flow._get_zone_list", + return_value=None): + + result = await hass.config_entries.flow.async_configure( + result["flow_id"], {"next_step_id": "gps_loc"} + ) + await hass.async_block_till_done() + + assert result["type"] == FlowResultType.FORM + + result2 = await hass.config_entries.flow.async_configure( + result["flow_id"], input + ) + + assert result2["type"] == "create_entry" + assert result2["title"] == title + assert result2["data"] == data + + await hass.async_block_till_done() + assert len(mock_setup_entry.mock_calls) == 1 + # @pytest.mark.parametrize( # "user_input", # [ diff --git a/tests/test_init.py b/tests/test_init.py index 0d12096..38be362 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -9,8 +9,9 @@ from pytest_homeassistant_custom_component.common import MockConfigEntry from custom_components.nws_alerts.const import CONF_ZONE_ID, DOMAIN -from tests.const import CONFIG_DATA +from tests.const import CONFIG_DATA, CONFIG_DATA_3 +pytestmark = pytest.mark.asyncio async def test_setup_entry( hass, @@ -36,7 +37,7 @@ async def test_unload_entry(hass): entry = MockConfigEntry( domain=DOMAIN, title="NWS Alerts", - data=CONFIG_DATA, + data=CONFIG_DATA_3, ) entry.add_to_hass(hass) diff --git a/tests/test_sensor.py b/tests/test_sensor.py index 17a60b2..e1f66d3 100644 --- a/tests/test_sensor.py +++ b/tests/test_sensor.py @@ -1,10 +1,13 @@ """Test NWS Alerts Sensors""" +import pytest from pytest_homeassistant_custom_component.common import MockConfigEntry from homeassistant.util import slugify from homeassistant.helpers import entity_registry as er from custom_components.nws_alerts.const import DOMAIN -from tests.const import CONFIG_DATA, CONFIG_DATA_2 +from tests.const import CONFIG_DATA, CONFIG_DATA_2, CONFIG_DATA_BAD + +pytestmark = pytest.mark.asyncio NWS_SENSOR = "sensor.nws_alerts" NWS_SENSOR_2 = "sensor.nws_alerts_yaml"