Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use the zigpy serial protocol #75

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 12 additions & 49 deletions universal_silabs_flasher/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,42 +118,6 @@ async def wait_for_state(self, state: str) -> None:
self._futures_for_state[state].remove(future)


class SerialProtocol(asyncio.Protocol):
"""Base class for packet-parsing serial protocol implementations."""

def __init__(self) -> None:
self._buffer = bytearray()
self._transport: serial_asyncio.SerialTransport | None = None
self._connected_event = asyncio.Event()

async def wait_until_connected(self) -> None:
"""Wait for the protocol's transport to be connected."""
await self._connected_event.wait()

def connection_made(self, transport: serial_asyncio.SerialTransport) -> None:
_LOGGER.debug("Connection made: %s", transport)

self._transport = transport
self._connected_event.set()

def send_data(self, data: bytes) -> None:
"""Sends data over the connected transport."""
assert self._transport is not None
data = bytes(data)
_LOGGER.debug("Sending data %s", data)
self._transport.write(data)

def data_received(self, data: bytes) -> None:
_LOGGER.debug("Received data %s", data)
self._buffer += data

def disconnect(self) -> None:
if self._transport is not None:
self._transport.close()
self._buffer.clear()
self._connected_event.clear()


def patch_pyserial_asyncio() -> None:
"""Patches pyserial-asyncio's `SerialTransport` to support swapping protocols."""

Expand All @@ -176,23 +140,22 @@ def set_protocol(self, protocol: asyncio.Protocol) -> None:
@contextlib.asynccontextmanager
async def connect_protocol(port, baudrate, factory):
loop = asyncio.get_running_loop()

async with async_timeout.timeout(CONNECT_TIMEOUT):
_, protocol = await zigpy.serial.create_serial_connection(
loop=loop,
protocol_factory=factory,
url=port,
baudrate=baudrate,
)
await protocol.wait_until_connected()
protocol: zigpy.serial.SerialProtocol | None = None

try:
async with async_timeout.timeout(CONNECT_TIMEOUT):
_, protocol = await zigpy.serial.create_serial_connection(
loop=loop,
protocol_factory=factory,
url=port,
baudrate=baudrate,
)
await protocol.wait_until_connected()

yield protocol
finally:
protocol.disconnect()

# Required for Windows to be able to re-connect to the same serial port
await asyncio.sleep(0)
if protocol is not None:
await protocol.disconnect()


class CommaSeparatedNumbers(click.ParamType):
Expand Down
7 changes: 5 additions & 2 deletions universal_silabs_flasher/cpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
import typing

import async_timeout
import zigpy.serial
import zigpy.types

from . import cpc_types
from .common import BufferTooShort, SerialProtocol, Version, crc16_ccitt
from .common import BufferTooShort, Version, crc16_ccitt

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -206,7 +207,7 @@ def poll_final(self) -> bool:
return bool((self.control & 0b00001000) >> 3)


class CPCProtocol(SerialProtocol):
class CPCProtocol(zigpy.serial.SerialProtocol):
"""Partial implementation of the CPC protocol."""

def __init__(self) -> None:
Expand Down Expand Up @@ -282,6 +283,8 @@ async def get_secondary_version(self) -> Version | None:
def data_received(self, data: bytes) -> None:
super().data_received(data)

self._buffer: bytearray

while self._buffer:
try:
frame, new_buffer = CPCTransportFrame.deserialize(self._buffer)
Expand Down
13 changes: 7 additions & 6 deletions universal_silabs_flasher/emberznet.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import asyncio
import contextlib

import bellows.config
import bellows.ezsp
import bellows.types
import zigpy.config

AFTER_DISCONNECT_DELAY = 0.1


@contextlib.asynccontextmanager
async def connect_ezsp(port: str, baudrate: int = 115200) -> bellows.ezsp.EZSP:
Expand Down Expand Up @@ -38,10 +35,14 @@ async def connect_ezsp(port: str, baudrate: int = 115200) -> bellows.ezsp.EZSP:
}
)

ezsp = await bellows.ezsp.EZSP.initialize(app_config)
ezsp = bellows.ezsp.EZSP(app_config[zigpy.config.CONF_DEVICE])
await ezsp.connect(use_thread=False)
await ezsp.startup_reset()

# Writing config is required here because network info can't be loaded
await ezsp.write_config(app_config[bellows.config.CONF_EZSP_CONFIG])

try:
yield ezsp
finally:
ezsp.close()
await asyncio.sleep(AFTER_DISCONNECT_DELAY)
await ezsp.disconnect()
13 changes: 5 additions & 8 deletions universal_silabs_flasher/flasher.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,9 @@
import bellows.config
import bellows.ezsp
import bellows.types
import zigpy.serial

from .common import (
PROBE_TIMEOUT,
SerialProtocol,
Version,
connect_protocol,
pad_to_multiple,
)
from .common import PROBE_TIMEOUT, Version, connect_protocol, pad_to_multiple
from .const import DEFAULT_BAUDRATES, GPIO_CONFIGS, ApplicationType, ResetTarget
from .cpc import CPCProtocol
from .emberznet import connect_ezsp
Expand Down Expand Up @@ -83,7 +78,9 @@ async def enter_bootloader_reset(self, target):

async def enter_serial_bootloader(self):
baudrate = self._baudrates[ApplicationType.GECKO_BOOTLOADER][0]
async with connect_protocol(self._device, baudrate, SerialProtocol) as sonoff:
async with connect_protocol(
self._device, baudrate, zigpy.serial.SerialProtocol
) as sonoff:
serial = sonoff._transport.serial
serial.dtr = False
serial.rts = True
Expand Down
5 changes: 3 additions & 2 deletions universal_silabs_flasher/gecko_bootloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
import typing

import async_timeout
import zigpy.serial

from .common import PROBE_TIMEOUT, SerialProtocol, StateMachine, Version
from .common import PROBE_TIMEOUT, StateMachine, Version
from .xmodemcrc import send_xmodem128_crc

_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -55,7 +56,7 @@ class GeckoBootloaderOption(bytes, enum.Enum):
EBL_INFO = b"3"


class GeckoBootloaderProtocol(SerialProtocol):
class GeckoBootloaderProtocol(zigpy.serial.SerialProtocol):
def __init__(self) -> None:
super().__init__()
self._state_machine = StateMachine(
Expand Down
7 changes: 4 additions & 3 deletions universal_silabs_flasher/spinel.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
import typing

import async_timeout
import zigpy.serial
import zigpy.types

from .common import SerialProtocol, Version, crc16_kermit
from .common import Version, crc16_kermit
from .spinel_types import CommandID, HDLCSpecial, PropertyID, ResetReason

_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -103,7 +104,7 @@ def serialize(self) -> bytes:
return self.header.serialize() + self.command_id.serialize() + self.data


class SpinelProtocol(SerialProtocol):
class SpinelProtocol(zigpy.serial.SerialProtocol):
def __init__(self) -> None:
super().__init__()
self._transaction_id: int = 1
Expand All @@ -112,7 +113,7 @@ def __init__(self) -> None:
def data_received(self, data: bytes) -> None:
super().data_received(data)

self._buffer = self._buffer.lstrip(bytes([HDLCSpecial.FLAG]))
self._buffer: bytearray = self._buffer.lstrip(bytes([HDLCSpecial.FLAG]))

if bytes([HDLCSpecial.FLAG]) not in self._buffer:
return
Expand Down
Loading