diff --git a/custom_components/nws_alerts/sensor.py b/custom_components/nws_alerts/sensor.py index 6612d6f..303bfac 100644 --- a/custom_components/nws_alerts/sensor.py +++ b/custom_components/nws_alerts/sensor.py @@ -1,8 +1,9 @@ -import requests +import aiohttp import logging import voluptuous as vol from datetime import timedelta -from homeassistant.const import CONF_NAME, ATTR_ATTRIBUTION +from homeassistant.core import callback +from homeassistant.const import CONF_NAME, ATTR_ATTRIBUTION, EVENT_HOMEASSISTANT_START from homeassistant.helpers.entity import Entity from homeassistant.util import Throttle from homeassistant.components.sensor import PLATFORM_SCHEMA @@ -61,7 +62,6 @@ def __init__(self, name, zone_id): self._display_desc = None self._spoken_desc = None self._zone_id = zone_id.replace(' ', '') - self.update() @property def unique_id(self): @@ -98,19 +98,33 @@ def device_state_attributes(self): return attrs + async def async_added_to_hass(self): + """Register callbacks.""" + _LOGGER.debug("Registering: %s...", self.entity_id) + + @callback + def sensor_startup(event): + """Update sensor on startup.""" + + self.async_schedule_update_ha_state(True) + + self.hass.bus.async_listen_once( + EVENT_HOMEASSISTANT_START, sensor_startup + ) + @Throttle(MIN_TIME_BETWEEN_UPDATES) - def update(self): + async def async_update(self): """Fetch new state data for the sensor. This is the only method that should fetch new data for Home Assistant. """ - values = self.get_state() + values = await self.async_get_state() self._state = values['state'] self._event = values['event'] self._event_id = values['event_id'] self._display_desc = values['display_desc'] self._spoken_desc = values['spoken_desc'] - def get_state(self): + async def async_get_state(self): values = {'state': 0, 'event': None, 'event_id': None, @@ -122,19 +136,24 @@ def get_state(self): 'Accept': 'application/ld+json' } + data = None url = '%s/alerts/active/count' % API_ENDPOINT - r = requests.get(url, headers=headers) - _LOGGER.debug("getting state for %s from %s" % (self._zone_id, url)) - if r.status_code == 200: - if 'zones' in r.json(): + async with aiohttp.ClientSession() as session: + async with session.get(url, headers=headers) as r: + _LOGGER.debug("getting state for %s from %s" % (self._zone_id, url)) + if r.status == 200: + data = await r.json() + + if data is not None: + if 'zones' in data: for zone in self._zone_id.split(','): - if zone in r.json()['zones']: - values = self.get_alerts() + if zone in data['zones']: + values = await self.async_get_alerts() break return values - def get_alerts(self): + async def async_get_alerts(self): values = {'state': 0, 'event': None, 'event_id': None, @@ -145,16 +164,21 @@ def get_alerts(self): headers = {'User-Agent': USER_AGENT, 'Accept': 'application/geo+json' } + data = None url = '%s/alerts/active?zone=%s' % (API_ENDPOINT, self._zone_id) - r = requests.get(url, headers=headers) - _LOGGER.debug("getting alert for %s from %s" % (self._zone_id, url)) - if r.status_code == 200: + async with aiohttp.ClientSession() as session: + async with session.get(url, headers=headers) as r: + _LOGGER.debug("getting alert for %s from %s" % (self._zone_id, url)) + if r.status == 200: + data = await r.json() + + if data is not None: events = [] headlines = [] event_id = '' display_desc = '' spoken_desc = '' - features = r.json()['features'] + features = data['features'] for alert in features: event = alert['properties']['event'] if 'NWSheadline' in alert['properties']['parameters']: @@ -208,7 +232,7 @@ def get_alerts(self): values['display_desc'] = display_desc values['spoken_desc'] = spoken_desc - if r.status_code != 200: + if data is None: values['state'] = "Unknown" values['event'] = "Unknown" values['event_id'] = "Unknown"