diff --git a/bellows/ezsp/__init__.py b/bellows/ezsp/__init__.py index c7aefeb9..30d1e0a7 100644 --- a/bellows/ezsp/__init__.py +++ b/bellows/ezsp/__init__.py @@ -642,12 +642,21 @@ async def get_supported_firmware_features( self, ) -> custom_commands.FirmwareFeatures: """Get supported firmware extensions.""" + req = custom_commands.CustomCommand( + command_id=custom_commands.CustomCommandId.CMD_GET_SUPPORTED_FEATURES_REQ, + payload=custom_commands.GetSupportedFeaturesReq().serialize(), + ) + try: - status, rsp_data = await self.customFrame( - bytes([custom_commands.CustomCommand.CMD_GET_SUPPORTED_FEATURES]), - ) + status, data = await self.customFrame(req.serialize()) except InvalidCommandError: return custom_commands.FirmwareFeatures(0) - features, _ = custom_commands.FirmwareFeatures.deserialize(rsp_data) - return features + rsp_cmd, _ = custom_commands.CustomCommand.deserialize(data) + assert ( + rsp_cmd.command_id + == custom_commands.CustomCommandId.CMD_GET_SUPPORTED_FEATURES_RSP + ) + + rsp, _ = custom_commands.GetSupportedFeaturesRsp.deserialize(rsp_cmd.payload) + return rsp.features diff --git a/bellows/ezsp/custom_commands.py b/bellows/ezsp/custom_commands.py index e03b0767..e5edb6e4 100644 --- a/bellows/ezsp/custom_commands.py +++ b/bellows/ezsp/custom_commands.py @@ -1,12 +1,48 @@ """Custom EZSP commands.""" +from __future__ import annotations import zigpy.types as t -class CustomCommand(t.enum8): - CMD_GET_SUPPORTED_FEATURES = 0x00 +class Bytes(bytes): + def serialize(self) -> Bytes: + return self + + @classmethod + def deserialize(cls, data: bytes) -> tuple[Bytes, bytes]: + return cls(data), b"" + + def __repr__(self) -> str: + # Reading byte sequences like \x200\x21 is extremely annoying + # compared to \x20\x30\x21 + escaped = "".join(f"\\x{b:02X}" for b in self) + + return f"b'{escaped}'" + + __str__ = __repr__ + + +class CustomCommandId(t.enum16): + CMD_GET_PROTOCOL_VERSION_REQ = 0x0000 + CMD_GET_PROTOCOL_VERSION_RSP = 0x8000 + + CMD_GET_SUPPORTED_FEATURES_REQ = 0x0001 + CMD_GET_SUPPORTED_FEATURES_RSP = 0x8001 + + +class CustomCommand(t.Struct): + command_id: CustomCommandId + payload: Bytes class FirmwareFeatures(t.bitmap32): # The firmware passes through all group traffic, regardless of group membership MEMBER_OF_ALL_GROUPS = 0b00000000_00000000_00000000_00000001 + + +class GetSupportedFeaturesReq(t.Struct): + pass + + +class GetSupportedFeaturesRsp(t.Struct): + features: FirmwareFeatures