Skip to content

Commit

Permalink
Add additional typing (#363)
Browse files Browse the repository at this point in the history
  • Loading branch information
Snuffy2 authored Jan 18, 2025
1 parent 7bf90fe commit b6edb50
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 70 deletions.
23 changes: 13 additions & 10 deletions custom_components/opnsense/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def is_valid_mac_address(mac: str) -> bool:
return bool(mac_regex.match(mac))


def is_ip_address(value) -> bool:
def is_ip_address(value: str) -> bool:
"""Check if string is a valid IP address."""
try:
ipaddress.ip_address(value)
Expand Down Expand Up @@ -393,13 +393,13 @@ async def async_step_reconfigure(
},
)

async def async_step_import(self, user_input) -> ConfigFlowResult:
async def async_step_import(self, user_input: MutableMapping[str, Any]) -> ConfigFlowResult:
"""Handle import."""
return await self.async_step_user(user_input)

@staticmethod
@callback
def async_get_options_flow(config_entry):
def async_get_options_flow(config_entry: ConfigEntry):
"""Get the options flow for this handler."""
return OPNsenseOptionsFlow(config_entry)

Expand All @@ -412,7 +412,9 @@ def __init__(self, config_entry: ConfigEntry) -> None:
self.new_options: MutableMapping[str, Any] = {}
self.config_entry = config_entry

async def async_step_init(self, user_input=None) -> ConfigFlowResult:
async def async_step_init(
self, user_input: MutableMapping[str, Any] | None = None
) -> ConfigFlowResult:
"""Handle options flow."""
if user_input is not None:
_LOGGER.debug("[options_flow init] user_input: %s", user_input)
Expand Down Expand Up @@ -447,7 +449,9 @@ async def async_step_init(self, user_input=None) -> ConfigFlowResult:

return self.async_show_form(step_id="init", data_schema=vol.Schema(base_schema))

async def async_step_device_tracker(self, user_input=None) -> ConfigFlowResult:
async def async_step_device_tracker(
self, user_input: MutableMapping[str, Any] | None = None
) -> ConfigFlowResult:
"""Handle device tracker list step."""
url = self.config_entry.data[CONF_URL].strip()
username: str = self.config_entry.data[CONF_USERNAME]
Expand Down Expand Up @@ -512,18 +516,17 @@ async def async_step_device_tracker(self, user_input=None) -> ConfigFlowResult:
if user_input:
_LOGGER.debug("[options_flow device_tracker] user_input: %s", user_input)
macs: list = []
if isinstance(user_input.get(CONF_MANUAL_DEVICES, None), str) and user_input.get(
CONF_MANUAL_DEVICES, None
):
for item in user_input.get(CONF_MANUAL_DEVICES).split(","):
manual_devices: str | None = user_input.get(CONF_MANUAL_DEVICES)
if isinstance(manual_devices, str):
for item in manual_devices.split(","):
if not isinstance(item, str) or not item:
continue
item = item.strip()
if is_valid_mac_address(item):
macs.append(item)
_LOGGER.debug("[async_step_device_tracker] Manual Devices: %s", macs)
_LOGGER.debug("[async_step_device_tracker] Devices: %s", user_input.get(CONF_DEVICES))
self.new_options[CONF_DEVICES] = user_input.get(CONF_DEVICES) + macs
self.new_options[CONF_DEVICES] = user_input.get(CONF_DEVICES, []) + macs
return self.async_create_entry(title="", data=self.new_options)


Expand Down
4 changes: 2 additions & 2 deletions custom_components/opnsense/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ async def _get_states(self, categories: list) -> MutableMapping[str, Any]:

return state

async def _async_update_data(self):
async def _async_update_data(self) -> dict[str, Any]:
"""Fetch the latest state from OPNsense."""
_LOGGER.info(
"%sUpdating Data",
Expand Down Expand Up @@ -294,7 +294,7 @@ async def _async_update_data(self):
@staticmethod
async def _calculate_speed(
prop_name: str,
elapsed_time,
elapsed_time: float,
current_parent_value: float,
previous_parent_value: float,
) -> tuple[str, int]:
Expand Down
6 changes: 3 additions & 3 deletions custom_components/opnsense/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,15 @@ def available(self) -> bool:
return self._available

@property
def opnsense_device_name(self) -> str:
def opnsense_device_name(self) -> str | None:
"""Return the OPNsense device name."""
if self.config_entry.title and len(self.config_entry.title) > 0:
return self.config_entry.title
return self._get_opnsense_state_value("system_info.name")

def _get_opnsense_state_value(self, path, default=None):
def _get_opnsense_state_value(self, path: str) -> Any | None:
state = self.coordinator.data
return dict_get(state, path, default)
return dict_get(state, path)

def _get_opnsense_client(self) -> OPNsenseClient | None:
if self.hass is None:
Expand Down
4 changes: 2 additions & 2 deletions custom_components/opnsense/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from urllib.parse import urlparse


def dict_get(data: MutableMapping[str, Any], path: str, default=None) -> Any | None:
def dict_get(data: MutableMapping[str, Any], path: str, default: Any | None = None) -> Any | None:
"""Parse the path to get the desired value out of the data."""
pathList: list = re.split(r"\.", path, flags=re.IGNORECASE)
result: MutableMapping[str, Any] = data
result: Any | None = data

for key in pathList:
if key.isnumeric():
Expand Down
56 changes: 24 additions & 32 deletions custom_components/opnsense/pyopnsense/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
_LOGGER: logging.Logger = logging.getLogger(__name__)


def _log_errors(func: Callable):
async def inner(self, *args, **kwargs):
def _log_errors(func: Callable) -> Any:
async def inner(self, *args: Any, **kwargs: Any) -> Any:
try:
return await func(self, *args, **kwargs)
except asyncio.CancelledError:
Expand All @@ -51,8 +51,8 @@ async def inner(self, *args, **kwargs):
return inner


def _xmlrpc_timeout(func: Callable):
async def inner(self, *args, **kwargs):
def _xmlrpc_timeout(func: Callable) -> Any:
async def inner(self, *args: Any, **kwargs: Any) -> Any:
response = None
# timout applies to each recv() call, not the whole request
default_timeout = socket.getdefaulttimeout()
Expand All @@ -73,7 +73,7 @@ def wireguard_is_connected(past_time: datetime | None) -> bool:
return datetime.now().astimezone() - past_time <= timedelta(minutes=3)


def human_friendly_duration(seconds) -> str:
def human_friendly_duration(seconds: int) -> str:
"""Convert the duration in seconds to human friendly."""
months, seconds = divmod(
seconds, 2419200
Expand All @@ -100,7 +100,7 @@ def human_friendly_duration(seconds) -> str:
return ", ".join(duration)


def get_ip_key(item) -> tuple:
def get_ip_key(item: MutableMapping[str, Any]) -> tuple:
"""Use to sort the DHCP Lease IPs."""
address = item.get("address", None)

Expand All @@ -116,10 +116,10 @@ def get_ip_key(item) -> tuple:
return (0 if ip_obj.version == 4 else 1, ip_obj)


def dict_get(data: MutableMapping[str, Any], path: str, default=None):
def dict_get(data: MutableMapping[str, Any], path: str, default: Any | None = None) -> Any | None:
"""Parse the path to get the desired value out of the data."""
pathList = re.split(r"\.", path, flags=re.IGNORECASE)
result = data
pathList: list = re.split(r"\.", path, flags=re.IGNORECASE)
result: Any | None = data
for key in pathList:
if key.isnumeric():
key = int(key)
Expand Down Expand Up @@ -214,22 +214,14 @@ def _get_proxy(self) -> xmlrpc.client.ServerProxy:
f"{self._xmlrpc_url}/xmlrpc.php", context=context, verbose=verbose
)

# @_xmlrpc_timeout
async def _get_config_section(self, section) -> MutableMapping[str, Any]:
config: MutableMapping[str, Any] = await self.get_config()
if config is None or not isinstance(config, MutableMapping):
_LOGGER.error("Invalid data returned from get_config_section")
return {}
return config.get(section, {})

@_xmlrpc_timeout
async def _restore_config_section(self, section_name, data):
async def _restore_config_section(self, section_name: str, data) -> None:
params = {section_name: data}
proxy_method = partial(self._get_proxy().opnsense.restore_config_section, params)
return await self._loop.run_in_executor(None, proxy_method)
await self._loop.run_in_executor(None, proxy_method)

@_xmlrpc_timeout
async def _exec_php(self, script) -> MutableMapping[str, Any]:
async def _exec_php(self, script: str) -> MutableMapping[str, Any]:
self._xmlrpc_query_count += 1
script = rf"""
ini_set('display_errors', 0);
Expand Down Expand Up @@ -608,7 +600,7 @@ async def get_firmware_update_info(self) -> MutableMapping[str, Any]:
return status

@_log_errors
async def upgrade_firmware(self, type="update") -> MutableMapping[str, Any] | None:
async def upgrade_firmware(self, type: str = "update") -> MutableMapping[str, Any] | None:
"""Trigger a firmware upgrade."""
# minor updates of the same opnsense version
if type == "update":
Expand All @@ -627,7 +619,7 @@ async def upgrade_status(self) -> MutableMapping[str, Any]:
return await self._safe_dict_post("/api/core/firmware/upgradestatus")

@_log_errors
async def firmware_changelog(self, version) -> MutableMapping[str, Any]:
async def firmware_changelog(self, version: str) -> MutableMapping[str, Any]:
"""Return the changelog for the firmware upgrade."""
return await self._safe_dict_post(f"/api/core/firmware/changelog/{version}")

Expand All @@ -650,7 +642,7 @@ async def get_config(self) -> MutableMapping[str, Any]:
return ret_data

@_log_errors
async def enable_filter_rule_by_created_time(self, created_time) -> None:
async def enable_filter_rule_by_created_time(self, created_time: str) -> None:
"""Enable a filter rule."""
config = await self.get_config()
for rule in config["filter"]["rule"]:
Expand All @@ -667,7 +659,7 @@ async def enable_filter_rule_by_created_time(self, created_time) -> None:
await self._filter_configure()

@_log_errors
async def disable_filter_rule_by_created_time(self, created_time) -> None:
async def disable_filter_rule_by_created_time(self, created_time: str) -> None:
"""Disable a filter rule."""
config: MutableMapping[str, Any] = await self.get_config()

Expand All @@ -686,7 +678,7 @@ async def disable_filter_rule_by_created_time(self, created_time) -> None:

# use created_time as a unique_id since none other exists
@_log_errors
async def enable_nat_port_forward_rule_by_created_time(self, created_time) -> None:
async def enable_nat_port_forward_rule_by_created_time(self, created_time: str) -> None:
"""Enable a NAT Port Forward rule."""
config: MutableMapping[str, Any] = await self.get_config()
for rule in config.get("nat", {}).get("rule", []):
Expand All @@ -704,7 +696,7 @@ async def enable_nat_port_forward_rule_by_created_time(self, created_time) -> No

# use created_time as a unique_id since none other exists
@_log_errors
async def disable_nat_port_forward_rule_by_created_time(self, created_time) -> None:
async def disable_nat_port_forward_rule_by_created_time(self, created_time: str) -> None:
"""Disable a NAT Port Forward rule."""
config: MutableMapping[str, Any] = await self.get_config()
for rule in config.get("nat", {}).get("rule", []):
Expand All @@ -722,7 +714,7 @@ async def disable_nat_port_forward_rule_by_created_time(self, created_time) -> N

# use created_time as a unique_id since none other exists
@_log_errors
async def enable_nat_outbound_rule_by_created_time(self, created_time) -> None:
async def enable_nat_outbound_rule_by_created_time(self, created_time: str) -> None:
"""Enable NAT Outbound rule."""
config: MutableMapping[str, Any] = await self.get_config()
for rule in config.get("nat", {}).get("outbound", {}).get("rule", []):
Expand All @@ -740,7 +732,7 @@ async def enable_nat_outbound_rule_by_created_time(self, created_time) -> None:

# use created_time as a unique_id since none other exists
@_log_errors
async def disable_nat_outbound_rule_by_created_time(self, created_time) -> None:
async def disable_nat_outbound_rule_by_created_time(self, created_time: str) -> None:
"""Disable NAT Outbound Rule."""
config: MutableMapping[str, Any] = await self.get_config()
for rule in config.get("nat", {}).get("outbound", {}).get("rule", []):
Expand Down Expand Up @@ -1073,7 +1065,7 @@ async def system_halt(self) -> None:
return

@_log_errors
async def send_wol(self, interface, mac) -> bool:
async def send_wol(self, interface: str, mac: str) -> bool:
"""Send a wake on lan packet to the specified MAC address."""
payload: MutableMapping[str, Any] = {"wake": {"interface": interface, "mac": mac}}
_LOGGER.debug("[send_wol] payload: %s", payload)
Expand Down Expand Up @@ -1632,7 +1624,7 @@ async def get_notices(self) -> MutableMapping[str, Any]:
# _LOGGER.debug(f"[get_notices] notices: {notices}")

@_log_errors
async def close_notice(self, id) -> bool:
async def close_notice(self, id: str) -> bool:
"""Close selected notices."""

# id = "all" to close all notices
Expand Down Expand Up @@ -2046,7 +2038,7 @@ async def generate_vouchers(self, data: MutableMapping[str, Any]) -> list:
_LOGGER.debug("[generate_vouchers] vouchers: %s", vouchers)
return vouchers

async def kill_states(self, ip_addr) -> MutableMapping[str, Any]:
async def kill_states(self, ip_addr: str) -> MutableMapping[str, Any]:
"""Kill the active states of the IP address."""
payload: MutableMapping[str, Any] = {"filter": ip_addr}
response = await self._safe_dict_post(
Expand All @@ -2059,7 +2051,7 @@ async def kill_states(self, ip_addr) -> MutableMapping[str, Any]:
"dropped_states": response.get("dropped_states", 0),
}

async def toggle_alias(self, alias, toggle_on_off) -> bool:
async def toggle_alias(self, alias: str, toggle_on_off: str) -> bool:
"""Toggle alias on and off."""
alias_list_resp = await self._safe_dict_get("/api/firewall/alias/searchItem")
alias_list: list = alias_list_resp.get("rows", [])
Expand Down
21 changes: 11 additions & 10 deletions custom_components/opnsense/sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ async def async_setup_entry(
async_add_entities(entities)


def slugify_filesystem_mountpoint(mountpoint) -> str:
def slugify_filesystem_mountpoint(mountpoint: str) -> str:
"""Slugify the mountpoint."""
if not mountpoint:
return ""
Expand All @@ -488,7 +488,7 @@ def slugify_filesystem_mountpoint(mountpoint) -> str:
return mountpoint.replace("/", "_").strip("_")


def normalize_filesystem_mountpoint(mountpoint) -> str:
def normalize_filesystem_mountpoint(mountpoint: str) -> str:
"""Normalize the mountpoint."""
if not mountpoint:
return ""
Expand All @@ -502,7 +502,7 @@ class OPNsenseSensor(OPNsenseEntity, SensorEntity):

def __init__(
self,
config_entry,
config_entry: ConfigEntry,
coordinator: OPNsenseDataUpdateCoordinator,
entity_description: SensorEntityDescription,
) -> None:
Expand Down Expand Up @@ -566,14 +566,15 @@ def _handle_coordinator_update(self) -> None:
if self.entity_description.key == "telemetry.cpu.usage_total":
temp_attr = self._get_opnsense_state_value("telemetry.cpu")
# _LOGGER.debug(f"[extra_state_attributes] temp_attr: {temp_attr}")
for k, v in temp_attr.items():
if k.startswith("usage_") and k != "usage_total":
self._attr_extra_state_attributes[k.replace("usage_", "")] = f"{v}%"
# _LOGGER.debug(f"[extra_state_attributes] attributes: {attributes}")
if isinstance(temp_attr, MutableMapping):
for k, v in temp_attr.items():
if k.startswith("usage_") and k != "usage_total":
self._attr_extra_state_attributes[k.replace("usage_", "")] = f"{v}%"
# _LOGGER.debug(f"[extra_state_attributes] attributes: {attributes}")
elif self.entity_description.key == "certificates":
self._attr_extra_state_attributes = self._get_opnsense_state_value(
self.entity_description.key
)
certs = self._get_opnsense_state_value(self.entity_description.key)
if isinstance(certs, MutableMapping):
self._attr_extra_state_attributes = dict(certs)

self.async_write_ha_state()

Expand Down
Loading

0 comments on commit b6edb50

Please sign in to comment.