diff --git a/custom_components/opnsense/pyopnsense/__init__.py b/custom_components/opnsense/pyopnsense/__init__.py index 5f253fa..7116793 100644 --- a/custom_components/opnsense/pyopnsense/__init__.py +++ b/custom_components/opnsense/pyopnsense/__init__.py @@ -19,9 +19,6 @@ import aiohttp import awesomeversion -from dateutil.parser import UnknownTimezoneWarning, parse - -from .const import AMBIGUOUS_TZINFOS # value to set as the socket timeout DEFAULT_TIMEOUT = 60 @@ -69,8 +66,10 @@ async def inner(self, *args, **kwargs): return inner -def wireguard_is_connected(past_time: datetime) -> bool: +def wireguard_is_connected(past_time: datetime | None) -> bool: """Return if Wireguard client is still connected.""" + if not past_time: + return False return datetime.now().astimezone() - past_time <= timedelta(minutes=3) @@ -133,6 +132,16 @@ def dict_get(data: MutableMapping[str, Any], path: str, default=None): return result +def timestamp_to_datetime(timestamp: int | None) -> datetime | None: + """Convert a timestamp to a timezone-aware datetime.""" + if timestamp is None: + return None + return datetime.fromtimestamp( + int(timestamp), + tz=timezone(datetime.now().astimezone().utcoffset() or timedelta()), + ) + + class VoucherServerError(Exception): """Error from Voucher Server.""" @@ -283,9 +292,7 @@ async def _exec_php(self, script) -> MutableMapping[str, Any]: @_log_errors async def get_host_firmware_version(self) -> None | str: """Return the OPNsense Firmware version.""" - firmware_info = await self._get("/api/core/firmware/status") - if not isinstance(firmware_info, MutableMapping): - return None + firmware_info = await self._safe_dict_get("/api/core/firmware/status") firmware: str | None = firmware_info.get("product", {}).get("product_version") if not firmware or not awesomeversion.AwesomeVersion(firmware).valid: old = firmware @@ -303,17 +310,15 @@ async def get_host_firmware_version(self) -> None | str: async def is_plugin_installed(self) -> bool: """Retun whether OPNsense plugin is installed or not.""" - firmware_info = await self._get("/api/core/firmware/info") - if not isinstance(firmware_info, MutableMapping) or not isinstance( - firmware_info.get("package"), list - ): + firmware_info = await self._safe_dict_get("/api/core/firmware/info") + if not isinstance(firmware_info.get("package"), list): return False for pkg in firmware_info.get("package", []): if pkg.get("name") == "os-homeassistant-maxit": return True return False - async def _get_from_stream(self, path: str) -> MutableMapping[str, Any] | list | None: + async def _get_from_stream(self, path: str) -> MutableMapping[str, Any]: self._rest_api_query_count += 1 url: str = f"{self._url}{path}" _LOGGER.debug("[get_from_stream] url: %s", url) @@ -346,12 +351,14 @@ async def _get_from_stream(self, path: str) -> MutableMapping[str, Any] | list | message_count += 1 if message_count == 2: response_str: str = line[len("data:") :].strip() - response_json: MutableMapping[str, Any] | list = json.loads( - response_str - ) + response_json = json.loads(response_str) # _LOGGER.debug(f"[get_from_stream] response_json ({type(response_json).__name__}): {response_json}") - return response_json # Exit after processing the second message + return ( + response_json + if isinstance(response_json, MutableMapping) + else {} + ) # Exit after processing the second message else: if response.status == 403: stack = inspect.stack() @@ -388,7 +395,7 @@ async def _get_from_stream(self, path: str) -> MutableMapping[str, Any] | list | if self._initial: raise - return None + return {} async def _get(self, path: str) -> MutableMapping[str, Any] | list | None: # /api////[/[/...]] @@ -441,7 +448,19 @@ async def _get(self, path: str) -> MutableMapping[str, Any] | list | None: return None - async def _post(self, path: str, payload=None) -> MutableMapping[str, Any] | list | None: + async def _safe_dict_get(self, path: str) -> MutableMapping[str, Any]: + """Fetch data from the given path, ensuring the result is a dict.""" + result = await self._get(path=path) + return result if isinstance(result, MutableMapping) else {} + + async def _safe_list_get(self, path: str) -> list: + """Fetch data from the given path, ensuring the result is a list.""" + result = await self._get(path=path) + return result if isinstance(result, list) else [] + + async def _post( + self, path: str, payload: MutableMapping[str, Any] | None = None + ) -> MutableMapping[str, Any] | list | None: # /api////[/[/...]] self._rest_api_query_count += 1 url: str = f"{self._url}{path}" @@ -494,6 +513,20 @@ async def _post(self, path: str, payload=None) -> MutableMapping[str, Any] | lis return None + async def _safe_dict_post( + self, path: str, payload: MutableMapping[str, Any] | None = None + ) -> MutableMapping[str, Any]: + """Fetch data from the given path, ensuring the result is a dict.""" + result = await self._post(path=path, payload=payload) + return result if isinstance(result, MutableMapping) else {} + + async def _safe_list_post( + self, path: str, payload: MutableMapping[str, Any] | None = None + ) -> list: + """Fetch data from the given path, ensuring the result is a list.""" + result = await self._post(path=path, payload=payload) + return result if isinstance(result, list) else [] + @_log_errors async def _filter_configure(self) -> None: script: str = r""" @@ -506,10 +539,7 @@ async def _filter_configure(self) -> None: @_log_errors async def get_device_unique_id(self) -> str | None: """Get the OPNsense Unique ID.""" - instances = await self._get("/api/interfaces/overview/export") - if not isinstance(instances, list): - return None - + instances = await self._safe_list_get("/api/interfaces/overview/export") mac_addresses = [ d.get("macaddr_hw") for d in instances if d.get("is_physical") and "macaddr_hw" in d ] @@ -538,9 +568,8 @@ async def get_system_info(self) -> MutableMapping[str, Any]: except awesomeversion.exceptions.AwesomeVersionCompareException: pass system_info: MutableMapping[str, Any] = {} - response = await self._get("/api/diagnostics/system/systemInformation") - if isinstance(response, MutableMapping): - system_info["name"] = response.get("name", None) + response = await self._safe_dict_get("/api/diagnostics/system/systemInformation") + system_info["name"] = response.get("name", None) return system_info @_log_errors @@ -561,96 +590,46 @@ async def _get_system_info_legacy(self) -> MutableMapping[str, Any]: return response @_log_errors - async def get_firmware_update_info(self): + async def get_firmware_update_info(self) -> MutableMapping[str, Any]: """Get the details of available firmware updates.""" - refresh_triggered = False - refresh_interval = 2 * 60 * 60 # 2 hours - - status = None - upgradestatus = None - - # GET /api/core/firmware/status - status = await self._get("/api/core/firmware/status") - # print(status) + status = await self._safe_dict_get("/api/core/firmware/status") # if error or too old trigger check (only if check is not already in progress) # {'status_msg': 'Firmware status check was aborted internally. Please try again.', 'status': 'error'} # error could be because data has not been refreshed at all OR an upgrade is currently in progress if ( - not isinstance(status, MutableMapping) - or status.get("status", None) == "error" + status.get("status", None) == "error" or "last_check" not in status or not isinstance(dict_get(status, "product.product_check"), dict) or not dict_get(status, "product.product_check") ): await self._post("/api/core/firmware/check") - refresh_triggered = True - elif "last_check" in status: - # "last_check": "Wed Dec 22 16:56:20 UTC 2021" - # "last_check": "Mon Jan 16 00:08:28 CET 2023" - # "last_check": "Sun Jan 15 22:05:55 UTC 2023" - # format = "%a %b %d %H:%M:%S %Z %Y" - try: - last_check: datetime = parse(status.get("last_check", 0), tzinfos=AMBIGUOUS_TZINFOS) - if last_check.tzinfo is None: - last_check = last_check.replace( - tzinfo=timezone(datetime.now().astimezone().utcoffset() or timedelta()) - ) - - last_check_timestamp: float = last_check.timestamp() - - except (ValueError, TypeError, UnknownTimezoneWarning): - last_check_timestamp = 0 - - stale: bool = ( - datetime.now().astimezone().timestamp() - last_check_timestamp - ) > refresh_interval - if stale: - upgradestatus = await self._get("/api/core/firmware/upgradestatus") - # print(upgradestatus) - if isinstance(upgradestatus, MutableMapping): - # status = running (package refresh in progress OR upgrade in progress) - # status = done (refresh/upgrade done) - if upgradestatus.get("status", None) == "done": - # tigger repo update - # should this be /api/core/firmware/upgrade - # check = await self._post("/api/core/firmware/check") - # print(check) - refresh_triggered = True - else: - # print("upgrade already running") - pass - - wait_for_refresh = False - if refresh_triggered and wait_for_refresh: - # print("refresh triggered, waiting for it to finish") - pass return status @_log_errors - async def upgrade_firmware(self, type="update"): + async def upgrade_firmware(self, type="update") -> MutableMapping[str, Any] | None: """Trigger a firmware upgrade.""" # minor updates of the same opnsense version if type == "update": # can watch the progress on the 'Updates' tab in the UI - return await self._post("/api/core/firmware/update") + return await self._safe_dict_post("/api/core/firmware/update") # major updates to a new opnsense version if type == "upgrade": # can watch the progress on the 'Updates' tab in the UI - return await self._post("/api/core/firmware/upgrade") + return await self._safe_dict_post("/api/core/firmware/upgrade") return None @_log_errors - async def upgrade_status(self): + async def upgrade_status(self) -> MutableMapping[str, Any]: """Return the status of the firmware upgrade.""" - return await self._post("/api/core/firmware/upgradestatus") + return await self._safe_dict_post("/api/core/firmware/upgradestatus") @_log_errors - async def firmware_changelog(self, version): + async def firmware_changelog(self, version) -> MutableMapping[str, Any]: """Return the changelog for the firmware upgrade.""" - return await self._post("/api/core/firmware/changelog/" + version) + return await self._safe_dict_post(f"/api/core/firmware/changelog/{version}") @_log_errors async def get_config(self) -> MutableMapping[str, Any]: @@ -774,15 +753,13 @@ async def disable_nat_outbound_rule_by_created_time(self, created_time) -> None: await self._filter_configure() @_log_errors - async def get_arp_table(self, resolve_hostnames=False) -> list: + async def get_arp_table(self, resolve_hostnames: bool = False) -> list: """Return the active ARP table.""" # [{'hostname': '?', 'ip-address': '', 'mac-address': '', 'interface': 'em0', 'expires': 1199, 'type': 'ethernet'}, ...] request_body: MutableMapping[str, Any] = {"resolve": "yes"} - arp_table_info = await self._post( + arp_table_info = await self._safe_dict_post( "/api/diagnostics/interface/search_arp", payload=request_body ) - if not isinstance(arp_table_info, MutableMapping): - return [] # _LOGGER.debug(f"[get_arp_table] arp_table_info: {arp_table_info}") arp_table: list = arp_table_info.get("rows", []) # _LOGGER.debug(f"[get_arp_table] arp_table: {arp_table}") @@ -791,10 +768,7 @@ async def get_arp_table(self, resolve_hostnames=False) -> list: @_log_errors async def get_services(self) -> list: """Get the list of OPNsense services.""" - response = await self._get("/api/core/service/search") - if not isinstance(response, MutableMapping): - _LOGGER.error("Invalid data returned from get_services") - return [] + response = await self._safe_dict_get("/api/core/service/search") # _LOGGER.debug(f"[get_services] response: {response}") services: list = response.get("rows", []) for service in services: @@ -819,9 +793,9 @@ async def _manage_service(self, action: str, service: str) -> bool: if not service: return False api_addr: str = f"/api/core/service/{action}/{service}" - response = await self._post(api_addr) + response = await self._safe_dict_post(api_addr) _LOGGER.debug("[%s_service] service: %s, response: %s", action, service, response) - return isinstance(response, MutableMapping) and response.get("result", "failed") == "ok" + return response.get("result", "failed") == "ok" @_log_errors async def start_service(self, service: str) -> bool: @@ -889,9 +863,7 @@ async def get_dhcp_leases(self) -> MutableMapping[str, Any]: async def _get_kea_interfaces(self) -> MutableMapping[str, Any]: """Return interfaces setup for Kea.""" - response = await self._get("/api/kea/dhcpv4/get") - if not isinstance(response, MutableMapping): - return {} + response = await self._safe_dict_get("/api/kea/dhcpv4/get") lease_interfaces: MutableMapping[str, Any] = {} general: MutableMapping[str, Any] = response.get("dhcpv4", {}).get("general", {}) if general.get("enabled", "0") != "1": @@ -906,15 +878,11 @@ async def _get_kea_interfaces(self) -> MutableMapping[str, Any]: async def _get_kea_dhcpv4_leases(self) -> list: """Return IPv4 DHCP Leases by Kea.""" - response = await self._get("/api/kea/leases4/search") - if not isinstance(response, MutableMapping) or not isinstance( - response.get("rows", None), list - ): + response = await self._safe_dict_get("/api/kea/leases4/search") + if not isinstance(response.get("rows", None), list): return [] - res_resp = await self._get("/api/kea/dhcpv4/searchReservation") - if not isinstance(res_resp, MutableMapping) or not isinstance( - res_resp.get("rows", None), list - ): + res_resp = await self._safe_dict_get("/api/kea/dhcpv4/searchReservation") + if not isinstance(res_resp.get("rows", None), list): res_info = [] else: res_info = res_resp.get("rows", []) @@ -954,9 +922,8 @@ async def _get_kea_dhcpv4_leases(self) -> list: lease["type"] = "dynamic" lease["mac"] = lease_info.get("hwaddr", None) if OPNsenseClient._try_to_int(lease_info.get("expire", None)): - lease["expires"] = datetime.fromtimestamp( - OPNsenseClient._try_to_int(lease_info.get("expire", None)) or 0, - tz=timezone(datetime.now().astimezone().utcoffset() or timedelta()), + lease["expires"] = timestamp_to_datetime( + OPNsenseClient._try_to_int(lease_info.get("expire", None)) or 0 ) if lease["expires"] < datetime.now().astimezone(): continue @@ -968,9 +935,7 @@ async def _get_kea_dhcpv4_leases(self) -> list: async def _get_isc_dhcpv4_leases(self) -> list: """Return IPv4 DHCP Leases by ISC.""" - response = await self._get("/api/dhcpv4/leases/searchLease") - if not isinstance(response, MutableMapping): - return [] + response = await self._safe_dict_get("/api/dhcpv4/leases/searchLease") leases_info: list = response.get("rows", []) if not isinstance(leases_info, list): return [] @@ -1011,9 +976,7 @@ async def _get_isc_dhcpv4_leases(self) -> list: async def _get_isc_dhcpv6_leases(self) -> list: """Return IPv6 DHCP Leases by ISC.""" - response = await self._get("/api/dhcpv6/leases/searchLease") - if not isinstance(response, MutableMapping): - return [] + response = await self._safe_dict_get("/api/dhcpv6/leases/searchLease") leases_info: list = response.get("rows", []) if not isinstance(leases_info, list): return [] @@ -1055,29 +1018,22 @@ async def _get_isc_dhcpv6_leases(self) -> list: @_log_errors async def get_carp_status(self) -> bool: """Return the Carp status.""" - response = await self._get("/api/diagnostics/interface/get_vip_status") - if not isinstance(response, MutableMapping): - _LOGGER.error("Invalid data returned from get_carp_status") - return False + response = await self._safe_dict_get("/api/diagnostics/interface/get_vip_status") # _LOGGER.debug(f"[get_carp_status] response: {response}") return response.get("carp", {}).get("allow", "0") == "1" @_log_errors async def get_carp_interfaces(self) -> list: """Return the interfaces used by Carp.""" - vip_settings_raw = await self._get("/api/interfaces/vip_settings/get") - if not isinstance(vip_settings_raw, MutableMapping) or not isinstance( - vip_settings_raw.get("rows", None), list - ): + vip_settings_raw = await self._safe_dict_get("/api/interfaces/vip_settings/get") + if not isinstance(vip_settings_raw.get("rows", None), list): vip_settings: list = [] else: vip_settings = vip_settings_raw.get("rows", []) # _LOGGER.debug(f"[get_carp_interfaces] vip_settings: {vip_settings}") - vip_status_raw = await self._get("/api/diagnostics/interface/get_vip_status") - if not isinstance(vip_status_raw, MutableMapping) or not isinstance( - vip_status_raw.get("rows", None), list - ): + vip_status_raw = await self._safe_dict_get("/api/diagnostics/interface/get_vip_status") + if not isinstance(vip_status_raw.get("rows", None), list): vip_status: list = [] else: vip_status = vip_status_raw.get("rows", []) @@ -1101,18 +1057,18 @@ async def get_carp_interfaces(self) -> list: @_log_errors async def system_reboot(self) -> bool: """Reboot OPNsense.""" - response = await self._post("/api/core/system/reboot") + response = await self._safe_dict_post("/api/core/system/reboot") _LOGGER.debug("[system_reboot] response: %s", response) - if isinstance(response, MutableMapping) and response.get("status", "") == "ok": + if response.get("status", "") == "ok": return True return False @_log_errors async def system_halt(self) -> None: """Shutdown OPNsense.""" - response = await self._post("/api/core/system/halt") + response = await self._safe_dict_post("/api/core/system/halt") _LOGGER.debug("[system_halt] response: %s", response) - if isinstance(response, MutableMapping) and response.get("status", "") == "ok": + if response.get("status", "") == "ok": return return @@ -1121,9 +1077,9 @@ async def send_wol(self, interface, mac) -> bool: """Send a wake on lan packet to the specified MAC address.""" payload: MutableMapping[str, Any] = {"wake": {"interface": interface, "mac": mac}} _LOGGER.debug("[send_wol] payload: %s", payload) - response = await self._post("/api/wol/wol/set", payload) + response = await self._safe_dict_post("/api/wol/wol/set", payload) _LOGGER.debug("[send_wol] response: %s", response) - if isinstance(response, MutableMapping) and response.get("status", "") == "ok": + if response.get("status", "") == "ok": return True return False @@ -1174,9 +1130,9 @@ async def get_telemetry(self) -> MutableMapping[str, Any]: @_log_errors async def get_interfaces(self) -> MutableMapping[str, Any]: """Return all OPNsense interfaces.""" - interface_info = await self._get("/api/interfaces/overview/export") + interface_info = await self._safe_list_get("/api/interfaces/overview/export") # _LOGGER.debug(f"[get_interfaces] interface_info: {interface_info}") - if not isinstance(interface_info, list) or not len(interface_info) > 0: + if not len(interface_info) > 0: return {} interfaces: MutableMapping[str, Any] = {} for ifinfo in interface_info: @@ -1233,10 +1189,8 @@ async def get_interfaces(self) -> MutableMapping[str, Any]: @_log_errors async def _get_telemetry_mbuf(self) -> MutableMapping[str, Any]: - mbuf_info = await self._post("/api/diagnostics/system/system_mbuf") + mbuf_info = await self._safe_dict_post("/api/diagnostics/system/system_mbuf") # _LOGGER.debug(f"[get_telemetry_mbuf] mbuf_info: {mbuf_info}") - if not isinstance(mbuf_info, MutableMapping): - return {} mbuf: MutableMapping[str, Any] = {} mbuf["used"] = OPNsenseClient._try_to_int( mbuf_info.get("mbuf-statistics", {}).get("mbuf-current", None) @@ -1256,10 +1210,8 @@ async def _get_telemetry_mbuf(self) -> MutableMapping[str, Any]: @_log_errors async def _get_telemetry_pfstate(self) -> MutableMapping[str, Any]: - pfstate_info = await self._post("/api/diagnostics/firewall/pf_states") + pfstate_info = await self._safe_dict_post("/api/diagnostics/firewall/pf_states") # _LOGGER.debug(f"[get_telemetry_pfstate] pfstate_info: {pfstate_info}") - if not isinstance(pfstate_info, MutableMapping): - return {} pfstate: MutableMapping[str, Any] = {} pfstate["used"] = OPNsenseClient._try_to_int(pfstate_info.get("current", None)) pfstate["total"] = OPNsenseClient._try_to_int(pfstate_info.get("limit", None)) @@ -1275,10 +1227,8 @@ async def _get_telemetry_pfstate(self) -> MutableMapping[str, Any]: @_log_errors async def _get_telemetry_memory(self) -> MutableMapping[str, Any]: - memory_info = await self._post("/api/diagnostics/system/systemResources") + memory_info = await self._safe_dict_post("/api/diagnostics/system/systemResources") # _LOGGER.debug(f"[get_telemetry_memory] memory_info: {memory_info}") - if not isinstance(memory_info, MutableMapping): - return {} memory: MutableMapping[str, Any] = {} memory["physmem"] = OPNsenseClient._try_to_int( memory_info.get("memory", {}).get("total", None) @@ -1291,10 +1241,9 @@ async def _get_telemetry_memory(self) -> MutableMapping[str, Any]: and memory["physmem"] > 0 else None ) - swap_info = await self._post("/api/diagnostics/system/system_swap") + swap_info = await self._safe_dict_post("/api/diagnostics/system/system_swap") if ( - not isinstance(swap_info, MutableMapping) - or not isinstance(swap_info.get("swap", None), list) + not isinstance(swap_info.get("swap", None), list) or not len(swap_info.get("swap", [])) > 0 or not isinstance(swap_info.get("swap", [])[0], MutableMapping) ): @@ -1316,10 +1265,8 @@ async def _get_telemetry_memory(self) -> MutableMapping[str, Any]: @_log_errors async def _get_telemetry_system(self) -> MutableMapping[str, Any]: - time_info = await self._post("/api/diagnostics/system/systemTime") + time_info = await self._safe_dict_post("/api/diagnostics/system/systemTime") # _LOGGER.debug(f"[get_telemetry_system] time_info: {time_info}") - if not isinstance(time_info, MutableMapping): - return {} system: MutableMapping[str, Any] = {} pattern = re.compile(r"^(?:(\d+)\s+days?,\s+)?(\d{2}):(\d{2}):(\d{2})$") match = pattern.match(time_info.get("uptime", "")) @@ -1355,9 +1302,9 @@ async def _get_telemetry_system(self) -> MutableMapping[str, Any]: @_log_errors async def _get_telemetry_cpu(self) -> MutableMapping[str, Any]: - cputype_info = await self._post("/api/diagnostics/cpu_usage/getCPUType") + cputype_info = await self._safe_list_post("/api/diagnostics/cpu_usage/getCPUType") # _LOGGER.debug(f"[get_telemetry_cpu] cputype_info: {cputype_info}") - if not isinstance(cputype_info, list) or not len(cputype_info) > 0: + if not len(cputype_info) > 0: return {} cpu: MutableMapping[str, Any] = {} cores_match = re.search(r"\((\d+) cores", cputype_info[0]) @@ -1366,8 +1313,6 @@ async def _get_telemetry_cpu(self) -> MutableMapping[str, Any]: cpustream_info = await self._get_from_stream("/api/diagnostics/cpu_usage/stream") # {"total":29,"user":2,"nice":0,"sys":27,"intr":0,"idle":70} # _LOGGER.debug(f"[get_telemetry_cpu] cpustream_info: {cpustream_info}") - if not isinstance(cpustream_info, MutableMapping): - return cpu cpu["usage_total"] = OPNsenseClient._try_to_int(cpustream_info.get("total", None)) cpu["usage_user"] = OPNsenseClient._try_to_int(cpustream_info.get("user", None)) cpu["usage_nice"] = OPNsenseClient._try_to_int(cpustream_info.get("nice", None)) @@ -1379,9 +1324,7 @@ async def _get_telemetry_cpu(self) -> MutableMapping[str, Any]: @_log_errors async def _get_telemetry_filesystems(self) -> list: - filesystems_info = await self._post("/api/diagnostics/system/systemDisk") - if not isinstance(filesystems_info, MutableMapping): - return [] + filesystems_info = await self._safe_dict_post("/api/diagnostics/system/systemDisk") # _LOGGER.debug(f"[get_telemetry_filesystems] filesystems_info: {filesystems_info}") filesystems: list = filesystems_info.get("devices", []) # _LOGGER.debug(f"[get_telemetry_filesystems] filesystems: {filesystems}") @@ -1393,162 +1336,161 @@ async def get_openvpn(self) -> MutableMapping[str, Any]: # https://docs.opnsense.org/development/api/core/openvpn.html # https://github.com/opnsense/core/blob/master/src/opnsense/www/js/widgets/OpenVPNClients.js # https://github.com/opnsense/core/blob/master/src/opnsense/www/js/widgets/OpenVPNServers.js - - sessions_info = await self._get("/api/openvpn/service/searchSessions") - - routes_info = await self._get("/api/openvpn/service/searchRoutes") - - providers_info = await self._get("/api/openvpn/export/providers") - - instances_info = await self._get("/api/openvpn/instances/search") + openvpn: MutableMapping[str, Any] = {"servers": {}, "clients": {}} + + # Fetch data + sessions_info = await self._safe_dict_get("/api/openvpn/service/searchSessions") + routes_info = await self._safe_dict_get("/api/openvpn/service/searchRoutes") + providers_info = await self._safe_dict_get("/api/openvpn/export/providers") + instances_info = await self._safe_dict_get("/api/openvpn/instances/search") + + await OPNsenseClient._process_openvpn_instances(instances_info, openvpn) + await OPNsenseClient._process_openvpn_providers(providers_info, openvpn) + await OPNsenseClient._process_openvpn_sessions(sessions_info, openvpn) + await OPNsenseClient._process_openvpn_routes(routes_info, openvpn) # _LOGGER.debug(f"[get_openvpn] sessions_info: {sessions_info}") # _LOGGER.debug(f"[get_openvpn] routes_info: {routes_info}") # _LOGGER.debug(f"[get_openvpn] providers_info: {providers_info}") # _LOGGER.debug(f"[get_openvpn] instances_info: {instances_info}") - if not isinstance(sessions_info, MutableMapping): - sessions_info = {} - if not isinstance(routes_info, MutableMapping): - routes_info = {} - if not isinstance(providers_info, MutableMapping): - providers_info = {} - if not isinstance(instances_info, MutableMapping): - instances_info = {} - - openvpn: MutableMapping[str, Any] = {} - openvpn["servers"] = {} - openvpn["clients"] = {} - - # Servers + + await self._fetch_openvpn_server_details(openvpn) + + _LOGGER.debug("[get_openvpn] openvpn: %s", openvpn) + return openvpn + + @staticmethod + async def _process_openvpn_instances( + instances_info: MutableMapping[str, Any], openvpn: MutableMapping[str, Any] + ) -> None: + """Process OpenVPN instances into servers and clients.""" for instance in instances_info.get("rows", []): - if ( - not isinstance(instance, MutableMapping) - or instance.get("role", "").lower() != "server" - ): + if not isinstance(instance, MutableMapping): continue - if instance.get("uuid", None) and instance.get("uuid", None) not in openvpn["servers"]: - openvpn["servers"][instance.get("uuid")] = { - "uuid": instance.get("uuid"), + role = instance.get("role", "").lower() + uuid = instance.get("uuid") + if role == "server": + await OPNsenseClient._add_openvpn_server(instance, openvpn) + elif role == "client" and uuid: + openvpn["clients"][uuid] = { "name": instance.get("description"), - "enabled": bool(instance.get("enabled", "0") == "1"), - "dev_type": instance.get("dev_type", None), - "clients": [], + "uuid": uuid, + "enabled": instance.get("enabled") == "1", } + @staticmethod + async def _add_openvpn_server( + instance: MutableMapping[str, Any], openvpn: MutableMapping[str, Any] + ) -> None: + """Add a server to the OpenVPN structure.""" + uuid = instance.get("uuid") + if not uuid: + return + if uuid not in openvpn["servers"]: + openvpn["servers"][uuid] = { + "uuid": uuid, + "name": instance.get("description"), + "enabled": instance.get("enabled") == "1", + "dev_type": instance.get("dev_type"), + "clients": [], + } + + @staticmethod + async def _process_openvpn_providers( + providers_info: MutableMapping[str, Any], openvpn: MutableMapping[str, Any] + ) -> None: + """Process OpenVPN providers.""" for uuid, vpn_info in providers_info.items(): if not uuid or not isinstance(vpn_info, MutableMapping): continue - if uuid not in openvpn["servers"]: - openvpn["servers"][uuid] = { - "uuid": uuid, - "name": vpn_info.get("name"), - "clients": [], - } - if vpn_info.get("hostname", None) and vpn_info.get("local_port", None): - openvpn["servers"][uuid]["endpoint"] = ( - f"{vpn_info.get('hostname')}:{vpn_info.get('local_port')}" - ) + server = openvpn["servers"].setdefault(uuid, {"uuid": uuid, "clients": []}) + server.update({"name": vpn_info.get("name")}) + if vpn_info.get("hostname") and vpn_info.get("local_port"): + server["endpoint"] = f"{vpn_info['hostname']}:{vpn_info['local_port']}" + @staticmethod + async def _process_openvpn_sessions( + sessions_info: MutableMapping[str, Any], openvpn: MutableMapping[str, Any] + ) -> None: + """Process OpenVPN sessions.""" for session in sessions_info.get("rows", []): - if session.get("type", None) != "server": + if session.get("type") != "server": continue - server_id = str(session["id"]).split("_", maxsplit=1)[0] + server_id = str(session["id"]).split("_", 1)[0] + server = openvpn["servers"].setdefault(server_id, {"uuid": server_id, "clients": []}) + server["name"] = session.get("description", "") + await OPNsenseClient._update_openvpn_server_status(server, session) - if server_id not in openvpn["servers"]: - openvpn["servers"][server_id] = { - "uuid": server_id, - "clients": [], + @staticmethod + async def _update_openvpn_server_status( + server: MutableMapping[str, Any], session: MutableMapping[str, Any] + ) -> None: + """Update server status based on session data.""" + status = session.get("status") + if not session.get("is_client", False): + server["status"] = ( + "disabled" + if not server.get("enabled", True) + else "up" + if status in {"connected", "ok"} + else "failed" + if status == "failed" + else status or "down" + ) + else: + server.update( + { + "status": "up", + "latest_handshake": timestamp_to_datetime( + session.get("connected_since__time_t_") + ), + "total_bytes_recv": OPNsenseClient._try_to_int( + session.get("bytes_received", 0), 0 + ), + "total_bytes_sent": OPNsenseClient._try_to_int(session.get("bytes_sent", 0), 0), } - openvpn["servers"][server_id]["name"] = session.get("description", "") - - if not session.get("is_client", False): - if openvpn["servers"][server_id].get("enabled", True) is False: - openvpn["servers"][server_id].update({"status": "disabled"}) - elif session.get("status", None) in {"connected", "ok"}: - openvpn["servers"][server_id].update({"status": "up"}) - elif session.get("status", None) == "failed": - openvpn["servers"][server_id].update({"status": "failed"}) - elif isinstance(session.get("status", None), str): - openvpn["servers"][server_id].update({"status": session.get("status")}) - else: - openvpn["servers"][server_id].update({"status": "down"}) - else: - openvpn["servers"][server_id].update( - { - "status": "up", - "latest_handshake": datetime.fromtimestamp( - int(session.get("connected_since__time_t_")), - tz=timezone(datetime.now().astimezone().utcoffset() or timedelta()), - ), - "total_bytes_recv": OPNsenseClient._try_to_int( - session.get("bytes_received", 0), 0 - ), - "total_bytes_sent": OPNsenseClient._try_to_int( - session.get("bytes_sent", 0), 0 - ), - } - ) + ) + @staticmethod + async def _process_openvpn_routes( + routes_info: MutableMapping[str, Any], openvpn: MutableMapping[str, Any] + ) -> None: + """Process OpenVPN routes.""" for route in routes_info.get("rows", []): - if ( - not isinstance(route, MutableMapping) - or route.get("id", None) is None - or route.get("id") not in openvpn.get("servers", {}) - ): + server_id = route.get("id") + if not isinstance(route, MutableMapping) or server_id not in openvpn["servers"]: continue - openvpn["servers"][route.get("id")]["clients"].append( + openvpn["servers"][server_id]["clients"].append( { - "name": route.get("common_name", None), - "endpoint": route.get("real_address", None), + "name": route.get("common_name"), + "endpoint": route.get("real_address"), "tunnel_addresses": [route.get("virtual_address")], - "latest_handshake": datetime.fromtimestamp( - int(route.get("last_ref__time_t_", 0)), - tz=timezone(datetime.now().astimezone().utcoffset() or timedelta()), - ), + "latest_handshake": timestamp_to_datetime(route.get("last_ref__time_t_", 0)), } ) + async def _fetch_openvpn_server_details(self, openvpn: MutableMapping[str, Any]) -> None: + """Fetch detailed server information.""" for uuid, server in openvpn["servers"].items(): - if "total_bytes_sent" not in server: - server["total_bytes_sent"] = 0 - if "total_bytes_recv" not in server: - server["total_bytes_recv"] = 0 + server.setdefault("total_bytes_sent", 0) + server.setdefault("total_bytes_recv", 0) server["connected_clients"] = len(server.get("clients", [])) - details_info = await self._get(f"/api/openvpn/instances/get/{uuid}") - if isinstance(details_info, MutableMapping) and isinstance( - details_info.get("instance", None), MutableMapping - ): - details = details_info.get("instance", {}) - if details.get("server", None): - server["tunnel_addresses"] = [details.get("server")] - server["dns_servers"] = [] - for dns in details.get("dns_servers", {}).values(): - if dns.get("selected", 0) == 1 and dns.get("value", None): - server["dns_servers"].append(dns.get("value")) - - # Clients - for instance in instances_info.get("rows", []): - if ( - not isinstance(instance, MutableMapping) - or instance.get("role", "").lower() != "client" - ): - continue - if instance.get("uuid", None): - openvpn["clients"][instance.get("uuid")] = { - "name": instance.get("description", None), - "uuid": instance.get("uuid", None), - "enabled": bool(instance.get("enabled", "0") == "1"), - } - - _LOGGER.debug("[get_openvpn] openvpn: %s", openvpn) - return openvpn + details_info = await self._safe_dict_get(f"/api/openvpn/instances/get/{uuid}") + details = ( + details_info.get("instance", {}) if isinstance(details_info, MutableMapping) else {} + ) + if details.get("server"): + server["tunnel_addresses"] = [details["server"]] + server["dns_servers"] = [ + dns["value"] + for dns in details.get("dns_servers", {}).values() + if dns.get("selected") == 1 and dns.get("value") + ] @_log_errors async def get_gateways(self) -> MutableMapping[str, Any]: """Return OPNsense Gateway details.""" - gateways_info = await self._get("/api/routes/gateway/status") + gateways_info = await self._safe_dict_get("/api/routes/gateway/status") # _LOGGER.debug(f"[get_gateways] gateways_info: {gateways_info}") - if not isinstance(gateways_info, MutableMapping): - return {} gateways: MutableMapping[str, Any] = {} for gw_info in gateways_info.get("items", []): if isinstance(gw_info, MutableMapping) and "name" in gw_info: @@ -1568,9 +1510,9 @@ async def _get_telemetry_temps(self) -> MutableMapping[str, Any]: return {} except awesomeversion.exceptions.AwesomeVersionCompareException: pass - temps_info = await self._get("/api/diagnostics/system/systemTemperature") + temps_info = await self._safe_list_get("/api/diagnostics/system/systemTemperature") # _LOGGER.debug(f"[get_telemetry_temps] temps_info: {temps_info}") - if not isinstance(temps_info, list) or not len(temps_info) > 0: + if not len(temps_info) > 0: return {} temps: MutableMapping[str, Any] = {} for i, temp_info in enumerate(temps_info): @@ -1664,11 +1606,8 @@ async def _get_telemetry_legacy(self) -> MutableMapping[str, Any]: @_log_errors async def get_notices(self) -> MutableMapping[str, Any]: """Get active OPNsense notices.""" - notices_info = await self._get("/api/core/system/status") + notices_info = await self._safe_dict_get("/api/core/system/status") # _LOGGER.debug(f"[get_notices] notices_info: {notices_info}") - - if not isinstance(notices_info, MutableMapping): - return {} pending_notices_present = False pending_notices: list = [] for key, notice in notices_info.items(): @@ -1679,10 +1618,7 @@ async def get_notices(self) -> MutableMapping[str, Any]: "notice": notice.get("message", None), "id": key, "created_at": ( - datetime.fromtimestamp( - int(notice.get("timestamp", 0)), - tz=timezone(datetime.now().astimezone().utcoffset() or timedelta()), - ) + timestamp_to_datetime(int(notice.get("timestamp", 0))) if notice.get("timestamp", None) else None ), @@ -1702,26 +1638,22 @@ async def close_notice(self, id) -> bool: # id = "all" to close all notices success = True if id.lower() == "all": - notices = await self._get("/api/core/system/status") + notices = await self._safe_dict_get("/api/core/system/status") # _LOGGER.debug(f"[close_notice] notices: {notices}") - - if not isinstance(notices, MutableMapping): - return False for key, notice in notices.items(): if "statusCode" in notice: - dismiss = await self._post( + dismiss = await self._safe_dict_post( "/api/core/system/dismissStatus", payload={"subject": key} ) # _LOGGER.debug(f"[close_notice] id: {key}, dismiss: {dismiss}") - if ( - not isinstance(dismiss, MutableMapping) - or dismiss.get("status", "failed") != "ok" - ): + if dismiss.get("status", "failed") != "ok": success = False else: - dismiss = await self._post("/api/core/system/dismissStatus", payload={"subject": id}) + dismiss = await self._safe_dict_post( + "/api/core/system/dismissStatus", payload={"subject": id} + ) _LOGGER.debug("[close_notice] id: %s, dismiss: %s", id, dismiss) - if not isinstance(dismiss, MutableMapping) or dismiss.get("status", "failed") != "ok": + if dismiss.get("status", "failed") != "ok": success = False _LOGGER.debug("[close_notice] success: %s", success) return success @@ -1729,10 +1661,7 @@ async def close_notice(self, id) -> bool: @_log_errors async def get_unbound_blocklist(self) -> MutableMapping[str, Any]: """Return the Unbound Blocklist details.""" - response = await self._get("/api/unbound/settings/get") - if not isinstance(response, MutableMapping): - _LOGGER.error("Invalid data returned from get_unbound_blocklist") - return {} + response = await self._safe_dict_get("/api/unbound/settings/get") # _LOGGER.debug(f"[get_unbound_blocklist] response: {response}") dnsbl_settings = response.get("unbound", {}).get("dnsbl", {}) # _LOGGER.debug(f"[get_unbound_blocklist] dnsbl_settings: {dnsbl_settings}") @@ -1798,241 +1727,242 @@ async def disable_unbound_blocklist(self) -> bool: @_log_errors async def get_wireguard(self) -> MutableMapping[str, Any]: - """Get the details of the wireguard services.""" - summary_raw = await self._get("/api/wireguard/service/show") - clients_raw = await self._get("/api/wireguard/client/get") - servers_raw = await self._get("/api/wireguard/server/get") - if ( - not isinstance(summary_raw, MutableMapping) - or not isinstance(clients_raw, MutableMapping) - or not isinstance(servers_raw, MutableMapping) - ): - return {} - summary = summary_raw.get("rows", []) - client_summ = clients_raw.get("client", {}).get("clients", {}).get("client", {}) - server_summ = servers_raw.get("server", {}).get("servers", {}).get("server", {}) + """Get the details of the WireGuard services.""" + data_sources = { + "summary_raw": "/api/wireguard/service/show", + "clients_raw": "/api/wireguard/client/get", + "servers_raw": "/api/wireguard/server/get", + } + data = {key: await self._safe_dict_get(path) for key, path in data_sources.items()} + + summary = data["summary_raw"].get("rows", []) + client_summ = data["clients_raw"].get("client", {}).get("clients", {}).get("client", {}) + server_summ = data["servers_raw"].get("server", {}).get("servers", {}).get("server", {}) + if ( not isinstance(summary, list) or not isinstance(client_summ, MutableMapping) or not isinstance(server_summ, MutableMapping) ): return {} - servers: MutableMapping[str, Any] = {} - clients: MutableMapping[str, Any] = {} - for uid, srv in server_summ.items(): - if not isinstance(srv, MutableMapping): - continue - server: MutableMapping[str, Any] = {} - for attr in ("name", "pubkey", "endpoint", "peer_dns"): - if srv.get(attr, None): - if attr == "peer_dns": - server["dns_servers"] = [srv.get(attr)] - else: - server[attr] = srv.get(attr) - server["uuid"] = uid - server["enabled"] = bool(srv.get("enabled", "") == "1") - server["interface"] = f"wg{srv.get('instance', '')}" - server["tunnel_addresses"] = [] - for addr in srv.get("tunneladdress", {}).values(): - if addr.get("selected", 0) == 1 and addr.get("value", None): - server["tunnel_addresses"].append(addr.get("value")) - server["clients"] = [] - for peer_id, peer in srv.get("peers", {}).items(): - if peer.get("selected", 0) == 1 and peer.get("value", None): - server["clients"].append( - { - "name": peer.get("value"), - "uuid": peer_id, - "connected": False, - } - ) - server["connected_clients"] = 0 - server["total_bytes_recv"] = 0 - server["total_bytes_sent"] = 0 - servers[uid] = server + servers = { + uid: await OPNsenseClient._process_wireguard_server(uid, srv) + for uid, srv in server_summ.items() + if isinstance(srv, MutableMapping) + } + clients = { + uid: await OPNsenseClient._process_wireguard_client(uid, clnt, servers) + for uid, clnt in client_summ.items() + if isinstance(clnt, MutableMapping) + } - for uid, clnt in client_summ.items(): - if not isinstance(clnt, MutableMapping): - continue - client: MutableMapping[str, Any] = {} - for attr in ("name", "pubkey"): - if clnt.get(attr, None): - client[attr] = clnt.get(attr, None) - client["uuid"] = uid - client["enabled"] = bool(clnt.get("enabled", "0") == "1") - client["tunnel_addresses"] = [] - for addr in clnt.get("tunneladdress", {}).values(): - if addr.get("selected", 0) == 1 and addr.get("value", None): - client["tunnel_addresses"].append(addr.get("value")) - client["servers"] = [] - for srv_id, srv in clnt.get("servers", {}).items(): - if srv.get("selected", 0) == 1 and srv.get("value", None): - if servers.get(srv_id, None): - add_srv: MutableMapping[str, Any] = { - "name": servers[srv_id]["name"], - "uuid": srv_id, - "connected": False, - } - for attr in ("pubkey", "interface", "tunnel_addresses"): - if servers.get(srv_id, {}).get(attr, None): - add_srv[attr] = servers[srv_id][attr] - client["servers"].append(add_srv) - else: - client["servers"].append( - { - "name": srv.get("value"), - "uuid": srv_id, - "connected": False, - } - ) - for server in servers.values(): - if isinstance(server, MutableMapping) and isinstance( - server.get("clients", None), list - ): - match_cl: MutableMapping[str, Any] = {} - for cl in server.get("clients", {}): - if isinstance(cl, MutableMapping) and cl.get("uuid", None) == uid: - match_cl = cl - break - if match_cl: - for attr in ("name", "enabled", "pubkey", "tunnel_addresses"): - if client.get(attr, None): - match_cl[attr] = client.get(attr) - client["connected_servers"] = 0 - client["total_bytes_recv"] = 0 - client["total_bytes_sent"] = 0 - clients[uid] = client + await OPNsenseClient._update_wireguard_status(summary, servers, clients) + + wireguard = {"servers": servers, "clients": clients} + _LOGGER.debug("[get_wireguard] wireguard: %s", wireguard) + return wireguard + + @staticmethod + async def _process_wireguard_server( + uid: str, srv: MutableMapping[str, Any] + ) -> MutableMapping[str, Any]: + """Process a single WireGuard server entry.""" + return { + "uuid": uid, + "name": srv.get("name"), + "pubkey": srv.get("pubkey"), + "enabled": srv.get("enabled", "") == "1", + "interface": f"wg{srv.get('instance', '')}", + "dns_servers": [srv.get("peer_dns")] if srv.get("peer_dns") else [], + "tunnel_addresses": [ + addr.get("value") + for addr in srv.get("tunneladdress", {}).values() + if addr.get("selected") == 1 and addr.get("value") + ], + "clients": [ + { + "name": peer.get("value"), + "uuid": peer_id, + "connected": False, + } + for peer_id, peer in srv.get("peers", {}).items() + if peer.get("selected") == 1 and peer.get("value") + ], + "connected_clients": 0, + "total_bytes_recv": 0, + "total_bytes_sent": 0, + } + @staticmethod + async def _process_wireguard_client( + uid: str, clnt: MutableMapping[str, Any], servers: MutableMapping[str, Any] + ) -> MutableMapping[str, Any]: + """Process a single WireGuard client entry.""" + return { + "uuid": uid, + "name": clnt.get("name"), + "pubkey": clnt.get("pubkey"), + "enabled": clnt.get("enabled", "") == "1", + "tunnel_addresses": [ + addr.get("value") + for addr in clnt.get("tunneladdress", {}).values() + if addr.get("selected") == 1 and addr.get("value") + ], + "servers": [ + await OPNsenseClient._link_wireguard_client_to_server(srv_id, servers, srv) + for srv_id, srv in clnt.get("servers", {}).items() + if srv.get("selected") == 1 and srv.get("value") + ], + "connected_servers": 0, + "total_bytes_recv": 0, + "total_bytes_sent": 0, + } + + @staticmethod + async def _link_wireguard_client_to_server( + srv_id: str, servers: MutableMapping[str, Any], srv: MutableMapping[str, Any] + ) -> MutableMapping[str, Any]: + """Link a WireGuard client to its corresponding server.""" + if srv_id in servers: + server = servers[srv_id] + return { + "name": server.get("name"), + "uuid": srv_id, + "connected": False, + "pubkey": server.get("pubkey"), + "interface": server.get("interface"), + "tunnel_addresses": server.get("tunnel_addresses"), + } + return { + "name": srv.get("value"), + "uuid": srv_id, + "connected": False, + } + + @staticmethod + async def _update_wireguard_status( + summary: list[MutableMapping[str, Any]], + servers: MutableMapping[str, Any], + clients: MutableMapping[str, Any], + ) -> None: + """Update WireGuard server and client statuses based on the summary.""" for entry in summary: - if isinstance(entry, MutableMapping) and entry.get("type", "") == "interface": + if entry.get("type") == "interface": for server in servers.values(): - if ( - isinstance(server, MutableMapping) - and server.get("pubkey", "") == entry.get("public-key", "-") - and entry.get("status", None) - ): + if server.get("pubkey") == entry.get("public-key"): server["status"] = entry.get("status") - elif isinstance(entry, MutableMapping) and entry.get("type", "") == "peer": - for client in clients.values(): - if ( - isinstance(client, MutableMapping) - and client.get("pubkey", "") == entry.get("public-key", "-") - and isinstance(client.get("servers", None), list) - ): - client["connected_servers"] = 0 - for srv in client.get("servers", []): - if isinstance(srv, MutableMapping) and srv.get( - "interface", "" - ) == entry.get("if", "-"): - if ( - entry.get("endpoint", None) - and entry.get("endpoint", None) != "(none)" - ): - srv["endpoint"] = entry.get("endpoint") - if entry.get("transfer-rx", None): - srv["bytes_recv"] = entry.get("transfer-rx") - client["total_bytes_recv"] = int( - client.get("total_bytes_recv", 0) - ) + int(entry.get("transfer-rx", 0)) - if entry.get("transfer-tx", None): - srv["bytes_sent"] = entry.get("transfer-tx") - client["total_bytes_sent"] = int( - client.get("total_bytes_sent", 0) - ) + int(entry.get("transfer-tx", 0)) - if entry.get("latest-handshake", None): - srv["latest_handshake"] = datetime.fromtimestamp( - int(entry.get("latest-handshake", 0)), - tz=timezone( - datetime.now().astimezone().utcoffset() or timedelta() - ), - ) - srv["connected"] = wireguard_is_connected( - srv.get("latest_handshake", datetime.min) - ) - if srv["connected"]: - client["connected_servers"] += 1 - if client.get("latest_handshake", None) is None or client.get( - "latest_handshake" - ) < srv.get("latest_handshake", 0): - client["latest_handshake"] = srv.get("latest_handshake") - else: - srv["connected"] = False + elif entry.get("type") == "peer": + await OPNsenseClient._update_wireguard_peer_status(entry, servers, clients) - for server in servers.values(): - if ( - isinstance(server, MutableMapping) - and server.get("interface", "") == entry.get("if", "-") - and isinstance(server.get("clients", None), list) - ): - for clnt in server.get("clients", []): - if isinstance(clnt, MutableMapping) and clnt.get( - "pubkey", "" - ) == entry.get("public-key", "-"): - if ( - entry.get("endpoint", None) - and entry.get("endpoint", None) != "(none)" - ): - clnt["endpoint"] = entry.get("endpoint") - if entry.get("transfer-rx", None): - clnt["bytes_recv"] = entry.get("transfer-rx") - server["total_bytes_recv"] = int( - server.get("total_bytes_recv", 0) - ) + int(entry.get("transfer-rx", 0)) - if entry.get("transfer-tx", None): - clnt["bytes_sent"] = entry.get("transfer-tx") - server["total_bytes_sent"] = int( - server.get("total_bytes_sent", 0) - ) + int(entry.get("transfer-tx", 0)) - if entry.get("latest-handshake", None): - clnt["latest_handshake"] = datetime.fromtimestamp( - int(entry.get("latest-handshake", 0)), - tz=timezone( - datetime.now().astimezone().utcoffset() or timedelta() - ), - ) - clnt["connected"] = wireguard_is_connected( - clnt.get("latest_handshake", datetime.min) - ) - if clnt["connected"]: - server["connected_clients"] += 1 - if server.get("latest_handshake", None) is None or server.get( - "latest_handshake" - ) < clnt.get("latest_handshake", 0): - server["latest_handshake"] = clnt.get("latest_handshake") - else: - clnt["connected"] = False - - wireguard: MutableMapping[str, Any] = {"servers": servers, "clients": clients} - _LOGGER.debug("[get_wireguard] wireguard: %s", wireguard) - return wireguard + @staticmethod + async def _update_wireguard_peer_status( + entry: MutableMapping[str, Any], + servers: MutableMapping[str, Any], + clients: MutableMapping[str, Any], + ) -> None: + """Update the WireGuard peer status for clients and servers.""" + pubkey = entry.get("public-key", "-") + interface = entry.get("if", "-") + endpoint = entry.get("endpoint", None) + transfer_rx = int(entry.get("transfer-rx", 0)) + transfer_tx = int(entry.get("transfer-tx", 0)) + latest_handshake = int(entry.get("latest-handshake", 0)) + handshake_time = timestamp_to_datetime(latest_handshake) + is_connected = wireguard_is_connected(handshake_time) + + # Update servers + for server in servers.values(): + if server.get("interface") == interface: + for client in server.get("clients", []): + if client.get("pubkey") == pubkey: + await OPNsenseClient._update_wireguard_peer_details( + peer=client, + server_or_client=server, + endpoint=endpoint, + transfer_rx=transfer_rx, + transfer_tx=transfer_tx, + handshake_time=handshake_time, + is_connected=is_connected, + connection_counter_key="connected_clients", + ) + + # Update clients + for client in clients.values(): + if client.get("pubkey") == pubkey: + for server in client.get("servers", []): + if server.get("interface") == interface: + await OPNsenseClient._update_wireguard_peer_details( + peer=server, + server_or_client=client, + endpoint=endpoint, + transfer_rx=transfer_rx, + transfer_tx=transfer_tx, + handshake_time=handshake_time, + is_connected=is_connected, + connection_counter_key="connected_servers", + ) + + @staticmethod + async def _update_wireguard_peer_details( + peer: MutableMapping[str, Any], + server_or_client: MutableMapping[str, Any], + endpoint: str, + transfer_rx: int, + transfer_tx: int, + handshake_time: datetime | None, + is_connected: bool, + connection_counter_key: str, + ) -> None: + """Update details of WireGuard peers.""" + if endpoint and endpoint != "(none)": + peer["endpoint"] = endpoint + peer["bytes_recv"] = transfer_rx + peer["bytes_sent"] = transfer_tx + peer["latest_handshake"] = handshake_time + peer["connected"] = is_connected + + # Update the parent (server or client) stats + server_or_client["total_bytes_recv"] = ( + server_or_client.get("total_bytes_recv", 0) + transfer_rx + ) + server_or_client["total_bytes_sent"] = ( + server_or_client.get("total_bytes_sent", 0) + transfer_tx + ) + + if is_connected: + server_or_client[connection_counter_key] = ( + server_or_client.get(connection_counter_key, 0) + 1 + ) + # Update the latest handshake time if it's newer + if ( + server_or_client.get("latest_handshake") is None + or server_or_client["latest_handshake"] < handshake_time + ): + server_or_client["latest_handshake"] = handshake_time async def toggle_vpn_instance(self, vpn_type: str, clients_servers: str, uuid: str) -> bool: """Toggle the specified VPN instance on or off.""" if vpn_type == "openvpn": - success = await self._post(f"/api/openvpn/instances/toggle/{uuid}") - if not isinstance(success, MutableMapping) or not success.get("changed", False): + success = await self._safe_dict_post(f"/api/openvpn/instances/toggle/{uuid}") + if not success.get("changed", False): return False - reconfigure = await self._post("/api/openvpn/service/reconfigure") - if isinstance(reconfigure, MutableMapping): - return reconfigure.get("result", "") == "ok" - elif vpn_type == "wireguard": + reconfigure = await self._safe_dict_post("/api/openvpn/service/reconfigure") + return reconfigure.get("result", "") == "ok" + if vpn_type == "wireguard": if clients_servers == "clients": - success = await self._post(f"/api/wireguard/client/toggleClient/{uuid}") + success = await self._safe_dict_post(f"/api/wireguard/client/toggleClient/{uuid}") elif clients_servers == "servers": - success = await self._post(f"/api/wireguard/server/toggleServer/{uuid}") - if not isinstance(success, MutableMapping) or not success.get("changed", False): + success = await self._safe_dict_post(f"/api/wireguard/server/toggleServer/{uuid}") + if not success.get("changed", False): return False - reconfigure = await self._post("/api/wireguard/service/reconfigure") - if isinstance(reconfigure, MutableMapping): - return reconfigure.get("result", "") == "ok" + reconfigure = await self._safe_dict_post("/api/wireguard/service/reconfigure") + return reconfigure.get("result", "") == "ok" return False async def reload_interface(self, if_name: str) -> bool: """Reload the specified interface.""" - reload = await self._post(f"/api/interfaces/overview/reloadInterface/{if_name}") - if not isinstance(reload, MutableMapping): - return False + reload = await self._safe_dict_post(f"/api/interfaces/overview/reloadInterface/{if_name}") return reload.get("message", "").startswith("OK") async def get_certificates(self) -> MutableMapping[str, Any]: @@ -2045,10 +1975,8 @@ async def get_certificates(self) -> MutableMapping[str, Any]: return {} except awesomeversion.exceptions.AwesomeVersionCompareException: pass - certs_raw = await self._get("/api/trust/cert/search") - if not isinstance(certs_raw, MutableMapping) or not isinstance( - certs_raw.get("rows", None), list - ): + certs_raw = await self._safe_dict_get("/api/trust/cert/search") + if not isinstance(certs_raw.get("rows", None), list): return {} certs: MutableMapping[str, Any] = {} for cert in certs_raw.get("rows", None): @@ -2058,13 +1986,11 @@ async def get_certificates(self) -> MutableMapping[str, Any]: "issuer": cert.get("caref", None), "purpose": cert.get("rfc3280_purpose", None), "in_use": bool(cert.get("in_use", "0") == "1"), - "valid_from": datetime.fromtimestamp( - OPNsenseClient._try_to_int(cert.get("valid_from", None)) or 0, - tz=timezone(datetime.now().astimezone().utcoffset() or timedelta()), + "valid_from": timestamp_to_datetime( + OPNsenseClient._try_to_int(cert.get("valid_from", None)) or 0 ), - "valid_to": datetime.fromtimestamp( - OPNsenseClient._try_to_int(cert.get("valid_to", None)) or 0, - tz=timezone(datetime.now().astimezone().utcoffset() or timedelta()), + "valid_to": timestamp_to_datetime( + OPNsenseClient._try_to_int(cert.get("valid_to", None)) or 0 ), } _LOGGER.debug("[get_certificates] certs: %s", certs) @@ -2075,9 +2001,7 @@ async def generate_vouchers(self, data: MutableMapping[str, Any]) -> list: if data.get("voucher_server", None): server = data.get("voucher_server") else: - servers = await self._get("/api/captiveportal/voucher/listProviders") - if not isinstance(servers, list): - raise VoucherServerError(f"Error getting list of voucher servers: {servers}") + servers = await self._safe_list_get("/api/captiveportal/voucher/listProviders") if len(servers) == 0: raise VoucherServerError("No voucher servers exist") if len(servers) != 1: @@ -2090,12 +2014,10 @@ async def generate_vouchers(self, data: MutableMapping[str, Any]) -> list: payload.pop("voucher_server", None) voucher_url: str = f"/api/captiveportal/voucher/generateVouchers/{server_slug}/" _LOGGER.debug("[generate_vouchers] url: %s, payload: %s", voucher_url, payload) - vouchers = await self._post( + vouchers = await self._safe_list_post( voucher_url, payload=payload, ) - if not isinstance(vouchers, list): - raise VoucherServerError(f"Error returned requesting vouchers: {vouchers}") ordered_keys: list = [ "username", "password", @@ -2111,9 +2033,8 @@ async def generate_vouchers(self, data: MutableMapping[str, Any]) -> list: voucher["validity_str"] = human_friendly_duration(voucher.get("validity")) if voucher.get("expirytime", None): voucher["expiry_timestamp"] = voucher.get("expirytime") - voucher["expirytime"] = datetime.fromtimestamp( - OPNsenseClient._try_to_int(voucher.get("expirytime")) or 0, - tz=timezone(datetime.now().astimezone().utcoffset() or timedelta()), + voucher["expirytime"] = timestamp_to_datetime( + OPNsenseClient._try_to_int(voucher.get("expirytime")) or 0 ) rearranged_voucher: MutableMapping[str, Any] = { @@ -2128,23 +2049,19 @@ async def generate_vouchers(self, data: MutableMapping[str, Any]) -> list: async def kill_states(self, ip_addr) -> MutableMapping[str, Any]: """Kill the active states of the IP address.""" payload: MutableMapping[str, Any] = {"filter": ip_addr} - response = await self._post( + response = await self._safe_dict_post( "/api/diagnostics/firewall/kill_states/", payload=payload, ) _LOGGER.debug("[kill_states] ip_addr: %s, response: %s", ip_addr, response) - if not isinstance(response, MutableMapping): - return {"success": False, "dropped_states": 0} return { - "success": bool(response.get("result", None) == "ok"), + "success": bool(response.get("result", "") == "ok"), "dropped_states": response.get("dropped_states", 0), } async def toggle_alias(self, alias, toggle_on_off) -> bool: """Toggle alias on and off.""" - alias_list_resp = await self._get("/api/firewall/alias/searchItem") - if not isinstance(alias_list_resp, MutableMapping): - return False + alias_list_resp = await self._safe_dict_get("/api/firewall/alias/searchItem") alias_list: list = alias_list_resp.get("rows", []) if not isinstance(alias_list, list): return False @@ -2163,7 +2080,7 @@ async def toggle_alias(self, alias, toggle_on_off) -> bool: url = f"{url}/1" elif toggle_on_off == "off": url = f"{url}/0" - response = await self._post( + response = await self._safe_dict_post( url, payload=payload, ) @@ -2175,22 +2092,15 @@ async def toggle_alias(self, alias, toggle_on_off) -> bool: url, response, ) - if ( - not isinstance(response, MutableMapping) - or "result" not in response - or response.get("result") == "failed" - ): + if response.get("result") == "failed": return False - set_resp = await self._post("/api/firewall/alias/set") - if not isinstance(set_resp, MutableMapping) or set_resp.get("result") != "saved": + set_resp = await self._safe_dict_post("/api/firewall/alias/set") + if set_resp.get("result") != "saved": return False - reconfigure_resp = await self._post("/api/firewall/alias/reconfigure") - if ( - not isinstance(reconfigure_resp, MutableMapping) - or reconfigure_resp.get("status") != "ok" - ): + reconfigure_resp = await self._safe_dict_post("/api/firewall/alias/reconfigure") + if reconfigure_resp.get("status") != "ok": return False return True diff --git a/custom_components/opnsense/pyopnsense/const.py b/custom_components/opnsense/pyopnsense/const.py deleted file mode 100644 index b5f8e9c..0000000 --- a/custom_components/opnsense/pyopnsense/const.py +++ /dev/null @@ -1,24 +0,0 @@ -"""Constants for pyopnsense.""" - -from collections.abc import MutableMapping -from typing import Any - -from dateutil.tz import gettz - -AMBIGUOUS_TZINFOS: MutableMapping[str, Any] = { - "ACST": gettz("Australia/Darwin"), # Australian Central Standard Time - "ACT": gettz("America/Rio_Branco"), # Acre Time (Brazil) - "AEST": gettz("Australia/Sydney"), # Australian Eastern Standard Time - "AST": gettz("America/Halifax"), # Atlantic Standard Time (Caribbean/Canada) - "AWST": gettz("Australia/Perth"), # Australian Western Standard Time - "BST": gettz("Europe/London"), # British Summer Time - "CET": gettz("Europe/Paris"), # Central European Time - "CST": gettz("America/Chicago"), # Central Standard Time (North America) - "EET": gettz("Europe/Bucharest"), # Eastern European Time - "EST": gettz("America/New_York"), # Eastern Standard Time (North America) - "HST": gettz("Pacific/Honolulu"), # Hawaii-Aleutian Standard Time - "IST": gettz("Asia/Kolkata"), # Indian Standard Time - "MST": gettz("America/Denver"), # Mountain Standard Time (North America) - "NZST": gettz("Pacific/Auckland"), # New Zealand Standard Time - "PST": gettz("America/Los_Angeles"), # Pacific Standard Time (North America) -} diff --git a/pyproject.toml b/pyproject.toml index 0a565ea..7b803ea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -156,7 +156,7 @@ select = [ ] ignore = [ - "C901", # Temporarily + # "C901", # McCabe cyclomatic complexity "D202", # No blank lines allowed after function docstring "D203", # 1 blank line required before class docstring "D213", # Multi-line docstring summary should start at the second line