From 6e960547d8315631c3e81d358114278bc174d1fc Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Sun, 7 Jun 2020 15:10:03 -0400 Subject: [PATCH] Implement `flash_write` --- tests/test_tools_flash.py | 65 ++++++---- zigpy_znp/commands/ubl.py | 5 + .../tools/{flash_backup.py => flash_read.py} | 34 +----- zigpy_znp/tools/flash_write.py | 112 ++++++++++++++++++ 4 files changed, 163 insertions(+), 53 deletions(-) rename zigpy_znp/tools/{flash_backup.py => flash_read.py} (66%) create mode 100644 zigpy_znp/tools/flash_write.py diff --git a/tests/test_tools_flash.py b/tests/test_tools_flash.py index 7d724045..0618da3d 100644 --- a/tests/test_tools_flash.py +++ b/tests/test_tools_flash.py @@ -1,35 +1,30 @@ +import random import zigpy_znp.types as t import zigpy_znp.commands as c -from zigpy_znp.tools.flash_backup import main as flash_backup +from zigpy_znp.tools.flash_read import main as flash_read +from zigpy_znp.tools.flash_write import main as flash_write from test_api import pytest_mark_asyncio_timeout # noqa: F401 from test_application import znp_server # noqa: F401 from test_tools_nvram import openable_serial_znp_server # noqa: F401 -# Just random bytes -FAKE_FLASH = bytes.fromhex( - """ - a66ea64b2299ef91102c692c8739433776ac1f7967b2d7be3b532db5255dee88f49cad134ef4155375d2 - 67acecbe64637bd1df47ce1cb8b776caad7a7cd2b39892b69fbf2420176e598f689df05a3554400efb99 - 60dcedfb3416fe72b1570b6eb4aa877213afb92c7a6fc8b755e7457072a8c4d4ac9ec727b7748b267fda - 241334ab9195b4eb52cb50b396859c355dfad136e1c56b18f6599e08a7464524587a44ea0caaeb2b0a79 - 44ff74576db0c16b133f862de8ee8b6b37181a897416b40c589a645c62bbc6b2b4e993a6ee39ca1141bb - 7baeb7bb85476c7b905fa8f3f2148fe1162a218fb575eb3ed9849bc63212f7332a27f83c75e6590a25ad - 8ad3d13b212da0142bc257851afcc7c87c80c23d9f741f7159ccc89fed58ff2369523af224369df39224 - a4154dc2932958d3289d387356af931aa6e02d8216bffc3972674cf060de50c10e0705b2f80d7b54c763 - 0999d2f28f8e3b1917d89e960a1893ebdaa1695c5b2f1fc36efb144b326d4cb8119803ea327f2848b45a - a6e3e1ca93459eb848a8333826b12d87949be6cf652b1265a7c74e2b750303ee25f6296ed687393cb1a1 - 64648ae92eb2c426ea3f35770f6d64fefcd87fc9835ab39134be9a5d325cc2839a47515f15ce5b2072fe - 808a5e897a273f883751d029bec9fe89797fd2940603537770c745c17e817e495e4d8741e744b652254b - 2b776c1d313ca30a -""" -) +random.seed(12345) +FAKE_IMAGE_SIZE = 2 ** 10 +FAKE_FLASH = random.getrandbits(FAKE_IMAGE_SIZE * 8).to_bytes(FAKE_IMAGE_SIZE, "little") +random.seed() @pytest_mark_asyncio_timeout(seconds=5) -async def test_flash_backup(openable_serial_znp_server, tmp_path): # noqa: F811 +async def test_flash_backup_write( + openable_serial_znp_server, tmp_path, mocker # noqa: F811 +): + # It takes too long otherwise + mocker.patch("zigpy_znp.commands.ubl.IMAGE_SIZE", FAKE_IMAGE_SIZE) + + WRITABLE_FLASH = bytearray(len(FAKE_FLASH)) + openable_serial_znp_server.reply_to( request=c.UBL.HandshakeReq.Req(partial=True), responses=[ @@ -46,7 +41,7 @@ async def test_flash_backup(openable_serial_znp_server, tmp_path): # noqa: F811 def read_flash(req): offset = req.FlashWordAddr * 4 - data = FAKE_FLASH[offset : offset + 64] + data = WRITABLE_FLASH[offset : offset + 64] # We should not read partial blocks assert len(data) in (0, 64) @@ -60,11 +55,37 @@ def read_flash(req): Data=t.TrailingBytes(data), ) + def write_flash(req): + offset = req.FlashWordAddr * 4 + + assert len(req.Data) == 64 + + WRITABLE_FLASH[offset : offset + 64] = req.Data + assert len(WRITABLE_FLASH) == FAKE_IMAGE_SIZE + + return c.UBL.WriteRsp.Callback(Status=c.ubl.BootloaderStatus.SUCCESS) + openable_serial_znp_server.reply_to( request=c.UBL.ReadReq.Req(partial=True), responses=[read_flash] ) + openable_serial_znp_server.reply_to( + request=c.UBL.WriteReq.Req(partial=True), responses=[write_flash] + ) + + openable_serial_znp_server.reply_to( + request=c.UBL.EnableReq.Req(partial=True), + responses=[c.UBL.EnableRsp.Callback(Status=c.ubl.BootloaderStatus.SUCCESS)], + ) + + # First we write the flash + firmware_file = tmp_path / "firmware.bin" + firmware_file.write_bytes(FAKE_FLASH) + await flash_write([openable_serial_znp_server._port_path, "-i", str(firmware_file)]) + + # And then make a backup backup_file = tmp_path / "backup.bin" - await flash_backup([openable_serial_znp_server._port_path, "-o", str(backup_file)]) + await flash_read([openable_serial_znp_server._port_path, "-o", str(backup_file)]) + # They should be identical assert backup_file.read_bytes() == FAKE_FLASH diff --git a/zigpy_znp/commands/ubl.py b/zigpy_znp/commands/ubl.py index a3e83fa7..f4b1d538 100644 --- a/zigpy_znp/commands/ubl.py +++ b/zigpy_znp/commands/ubl.py @@ -4,6 +4,11 @@ import zigpy_znp.types as t +# Size of internal flash less 4 pages for boot loader, +# 6 pages for NV, & 1 page for lock bits. +IMAGE_SIZE = 0x40000 - 0x2000 - 0x3000 - 0x0800 +IMAGE_CRC_OFFSET = 0x90 + FLASH_WORD_SIZE = 4 diff --git a/zigpy_znp/tools/flash_backup.py b/zigpy_znp/tools/flash_read.py similarity index 66% rename from zigpy_znp/tools/flash_backup.py rename to zigpy_znp/tools/flash_read.py index 30ae353e..c1268edd 100644 --- a/zigpy_znp/tools/flash_backup.py +++ b/zigpy_znp/tools/flash_read.py @@ -16,31 +16,6 @@ LOGGER = logging.getLogger(__name__) -async def get_firmware_size(znp: ZNP, block_size: int) -> int: - valid_index = 0x0000 - - # Z-Stack lets you read beyond the end of the flash (???) if you go too high, - # instead of throwing an error. We need to be careful. - invalid_index = 0xFFFF // block_size - - while invalid_index - valid_index > 1: - midpoint = (valid_index + invalid_index) // 2 - - read_rsp = await znp.request_callback_rsp( - request=c.UBL.ReadReq.Req(FlashWordAddr=midpoint * block_size), - callback=c.UBL.ReadRsp.Callback(partial=True), - ) - - if read_rsp.Status == c.ubl.BootloaderStatus.SUCCESS: - valid_index = midpoint - elif read_rsp.Status == c.ubl.BootloaderStatus.FAILURE: - invalid_index = midpoint - else: - raise ValueError(f"Unexpected read response: {read_rsp}") - - return invalid_index * block_size - - async def read_firmware(radio_path: str) -> bytearray: znp = ZNP(CONFIG_SCHEMA({"device": {"path": radio_path}})) @@ -65,15 +40,12 @@ async def read_firmware(radio_path: str) -> bytearray: # All reads and writes are this size buffer_size = handshake_rsp.BufferSize - block_size = buffer_size // c.ubl.FLASH_WORD_SIZE - firmware_size = await get_firmware_size(znp, buffer_size) - - LOGGER.info("Total firmware size is %d", firmware_size) data = bytearray() - for address in range(0, firmware_size, block_size): - LOGGER.info("Progress: %0.2f%%", (100.0 * address) / firmware_size) + for offset in range(0, c.ubl.IMAGE_SIZE, buffer_size): + address = offset // c.ubl.FLASH_WORD_SIZE + LOGGER.info("Progress: %0.2f%%", (100.0 * offset) / c.ubl.IMAGE_SIZE) read_rsp = await znp.request_callback_rsp( request=c.UBL.ReadReq.Req(FlashWordAddr=address), diff --git a/zigpy_znp/tools/flash_write.py b/zigpy_znp/tools/flash_write.py new file mode 100644 index 00000000..96dc4cf7 --- /dev/null +++ b/zigpy_znp/tools/flash_write.py @@ -0,0 +1,112 @@ +import sys +import asyncio +import logging +import argparse +import coloredlogs +import async_timeout + +import zigpy_znp.types as t +import zigpy_znp.commands as c + +from zigpy_znp.api import ZNP +from zigpy_znp.config import CONFIG_SCHEMA + +coloredlogs.install(level=logging.DEBUG) +logging.getLogger("zigpy_znp").setLevel(logging.DEBUG) + +LOGGER = logging.getLogger(__name__) + + +async def write_firmware(firmware: bytes, radio_path: str): + if len(firmware) != c.ubl.IMAGE_SIZE: + raise ValueError( + f"Firmware is the wrong size." + f" Expected {c.ubl.IMAGE_SIZE}, got {len(firmware)}" + ) + + znp = ZNP(CONFIG_SCHEMA({"device": {"path": radio_path}})) + + # The bootloader handshake must be the very first command + await znp.connect(test_port=False) + + try: + async with async_timeout.timeout(5): + handshake_rsp = await znp.request_callback_rsp( + request=c.UBL.HandshakeReq.Req(), + callback=c.UBL.HandshakeRsp.Callback(partial=True), + ) + except asyncio.TimeoutError: + raise RuntimeError( + "Did not receive a bootloader handshake response!" + " Make sure your adapter has just been plugged in and" + " nothing else has had a chance to communicate with it." + ) + + if handshake_rsp.Status != c.ubl.BootloaderStatus.SUCCESS: + raise RuntimeError(f"Bad bootloader handshake response: {handshake_rsp}") + + # All reads and writes are this size + buffer_size = handshake_rsp.BufferSize + + for offset in range(0, c.ubl.IMAGE_SIZE, buffer_size): + address = offset // c.ubl.FLASH_WORD_SIZE + LOGGER.info("Write progress: %0.2f%%", (100.0 * offset) / c.ubl.IMAGE_SIZE) + + write_rsp = await znp.request_callback_rsp( + request=c.UBL.WriteReq.Req( + FlashWordAddr=address, + Data=t.TrailingBytes(firmware[offset : offset + buffer_size]), + ), + callback=c.UBL.WriteRsp.Callback(partial=True), + ) + + assert write_rsp.Status == c.ubl.BootloaderStatus.SUCCESS + + # Now we have to read it all back + # TODO: figure out how the CRC is computed! + for offset in range(0, c.ubl.IMAGE_SIZE, buffer_size): + address = offset // c.ubl.FLASH_WORD_SIZE + LOGGER.info( + "Verification progress: %0.2f%%", (100.0 * offset) / c.ubl.IMAGE_SIZE + ) + + read_rsp = await znp.request_callback_rsp( + request=c.UBL.ReadReq.Req(FlashWordAddr=address,), + callback=c.UBL.ReadRsp.Callback(partial=True), + ) + + assert read_rsp.Status == c.ubl.BootloaderStatus.SUCCESS + assert read_rsp.FlashWordAddr == address + assert read_rsp.Data == firmware[offset : offset + buffer_size] + + # This seems to cause the firmware to compute and verify the CRC + enable_rsp = await znp.request_callback_rsp( + request=c.UBL.EnableReq.Req(), callback=c.UBL.EnableRsp.Callback(partial=True), + ) + + assert enable_rsp.Status == c.ubl.BootloaderStatus.SUCCESS + + +async def main(argv): + parser = argparse.ArgumentParser(description="Write firmware to a radio") + parser.add_argument("serial", type=argparse.FileType("rb"), help="Serial port path") + parser.add_argument( + "--input", + "-i", + type=argparse.FileType("rb"), + help="Input .bin file", + required=True, + ) + + args = parser.parse_args(argv) + + # We just want to make sure it exists + args.serial.close() + + await write_firmware(args.input.read(), args.serial.name) + + LOGGER.info("Unplug your adapter to leave bootloader mode!") + + +if __name__ == "__main__": + asyncio.run(main(sys.argv[1:])) # pragma: no cover