Skip to content

Commit

Permalink
Update / correct typing
Browse files Browse the repository at this point in the history
  • Loading branch information
nocarryr committed Dec 22, 2023
1 parent 553576c commit ecff768
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 53 deletions.
11 changes: 6 additions & 5 deletions src/tslumd/common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations
import enum
from typing import Tuple, Iterable
from typing import Tuple, Iterator

__all__ = ('TallyColor', 'TallyType', 'TallyState', 'MessageType', 'TallyKey')

Expand Down Expand Up @@ -32,7 +33,7 @@ class TallyColor(enum.IntFlag):
AMBER = RED | GREEN #: Amber

@staticmethod
def from_str(s: str) -> 'TallyColor':
def from_str(s: str) -> TallyColor:
"""Return the member matching the given name (case-insensitive)
>>> TallyColor.from_str('RED')
Expand Down Expand Up @@ -112,7 +113,7 @@ def is_iterable(self) -> bool:
return self.name is None

@classmethod
def all(cls) -> Iterable['TallyType']:
def all(cls) -> Iterator[TallyType]:
"""Iterate over all members, excluding :attr:`no_tally` and :attr:`all_tally`
.. versionadded:: 0.0.4
Expand All @@ -122,7 +123,7 @@ def all(cls) -> Iterable['TallyType']:
yield ttype

@staticmethod
def from_str(s: str) -> 'TallyType':
def from_str(s: str) -> TallyType:
"""Create an instance from a string of member name(s)
The string can be a single member or multiple member names separated by
Expand Down Expand Up @@ -169,7 +170,7 @@ def to_str(self) -> str:
return '|'.join((obj.name for obj in self))
return self.name

def __iter__(self):
def __iter__(self) -> Iterator[TallyType]:
for ttype in self.all():
if ttype in self:
yield ttype
Expand Down
28 changes: 15 additions & 13 deletions src/tslumd/messages.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations
import asyncio
import dataclasses
from dataclasses import dataclass, field
import enum
import struct
from typing import List, Tuple, Dict, Iterable, Optional
from typing import Tuple, Iterator, Any, cast

from tslumd import MessageType, TallyColor, Tally

Expand Down Expand Up @@ -111,7 +112,7 @@ def __post_init__(self):
raise ValueError('Control message cannot contain text')

@classmethod
def broadcast(cls, **kwargs) -> 'Display':
def broadcast(cls, **kwargs) -> Display:
"""Create a :attr:`broadcast <is_broadcast>` display
(with :attr:`index` set to ``0xffff``)
Expand All @@ -123,7 +124,7 @@ def broadcast(cls, **kwargs) -> 'Display':
return cls(**kwargs)

@classmethod
def from_dmsg(cls, flags: Flags, dmsg: bytes) -> Tuple['Display', bytes]:
def from_dmsg(cls, flags: Flags, dmsg: bytes) -> Tuple[Display, bytes]:
"""Construct an instance from a ``DMSG`` portion of received message.
Any remaining message data after the relevant ``DMSG`` is returned along
Expand All @@ -132,9 +133,10 @@ def from_dmsg(cls, flags: Flags, dmsg: bytes) -> Tuple['Display', bytes]:
if len(dmsg) < 4:
raise DmsgParseError('Invalid dmsg length', dmsg)
hdr = struct.unpack('<2H', dmsg[:4])
hdr = cast(Tuple[int, int], hdr)
dmsg = dmsg[4:]
ctrl = hdr[1]
kw = dict(
kw: dict[str, Any] = dict(
index=hdr[0],
rh_tally=TallyColor(ctrl & 0b11),
txt_tally=TallyColor(ctrl >> 2 & 0b11),
Expand Down Expand Up @@ -168,7 +170,7 @@ def from_dmsg(cls, flags: Flags, dmsg: bytes) -> Tuple['Display', bytes]:
return cls(**kw), dmsg

@staticmethod
def _unpack_control_data(data: bytes) -> bytes:
def _unpack_control_data(data: bytes) -> Tuple[bytes, bytes]:
"""Unpack control data (if control bit 15 is set)
Arguments:
Expand Down Expand Up @@ -239,13 +241,13 @@ def to_dmsg(self, flags: Flags) -> bytes:
data.extend(txt_bytes)
return data

def to_dict(self) -> Dict:
def to_dict(self) -> dict:
d = dataclasses.asdict(self)
del d['is_broadcast']
return d

@classmethod
def from_tally(cls, tally: Tally, msg_type: MessageType = MessageType.display) -> 'Display':
def from_tally(cls, tally: Tally, msg_type: MessageType = MessageType.display) -> Display:
"""Create a :class:`Display` from the given :class:`~.Tally`
.. versionadded:: 0.0.2
Expand Down Expand Up @@ -290,9 +292,9 @@ class Message:
"""A single UMDv5 message packet
"""
version: int = 0 #: Protocol minor version
flags: int = Flags.NO_FLAGS #: The message :class:`Flags` field
flags: Flags = Flags.NO_FLAGS #: The message :class:`Flags` field
screen: int = 0 #: Screen index from 0 to 65534 (``0xFFFE``)
displays: List[Display] = field(default_factory=list)
displays: list[Display] = field(default_factory=list)
"""A list of :class:`Display` instances"""

scontrol: bytes = b''
Expand Down Expand Up @@ -338,7 +340,7 @@ def __post_init__(self):
self.type = MessageType.display

@classmethod
def broadcast(cls, **kwargs) -> 'Message':
def broadcast(cls, **kwargs) -> Message:
"""Create a :attr:`broadcast <is_broadcast>` message
(with :attr:`screen` set to ``0xffff``)
Expand All @@ -350,7 +352,7 @@ def broadcast(cls, **kwargs) -> 'Message':
return cls(**kwargs)

@classmethod
def parse(cls, msg: bytes) -> Tuple['Message', bytes]:
def parse(cls, msg: bytes) -> Tuple[Message, bytes]:
"""Parse incoming message data to create a :class:`Message` instance.
Any remaining message data after parsing is returned along with the instance.
Expand Down Expand Up @@ -382,7 +384,7 @@ def parse(cls, msg: bytes) -> Tuple['Message', bytes]:
obj.displays.append(disp)
return obj, remaining

def build_message(self, ignore_packet_length: Optional[bool] = False) -> bytes:
def build_message(self, ignore_packet_length: bool = False) -> bytes:
"""Build a message packet from data in this instance
Arguments:
Expand Down Expand Up @@ -414,7 +416,7 @@ def build_message(self, ignore_packet_length: Optional[bool] = False) -> bytes:
raise MessageLengthError()
return data

def build_messages(self, ignore_packet_length: Optional[bool] = False) -> Iterable[bytes]:
def build_messages(self, ignore_packet_length: bool = False) -> Iterator[bytes]:
"""Build message packet(s) from data in this instance as an iterator
The specified maximum packet length of 2048 is respected and if
Expand Down
7 changes: 4 additions & 3 deletions src/tslumd/receiver.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations
try:
from loguru import logger
except ImportError: # pragma: no cover
import logging
logger = logging.getLogger(__name__)
import asyncio
from typing import Dict, Tuple, Set, Optional
from typing import Tuple

from pydispatch import Dispatcher, Property, DictProperty, ListProperty

Expand Down Expand Up @@ -70,7 +71,7 @@ class UmdReceiver(Dispatcher):
DEFAULT_HOST: str = '0.0.0.0' #: The default host address to listen on
DEFAULT_PORT: int = 65000 #: The default host port to listen on

screens: Dict[int, Screen]
screens: dict[int, Screen]
"""Mapping of :class:`~.Screen` objects by :attr:`~.Screen.index`
.. versionadded:: 0.0.3
Expand All @@ -82,7 +83,7 @@ class UmdReceiver(Dispatcher):
.. versionadded:: 0.0.3
"""

tallies: Dict[TallyKey, Tally]
tallies: dict[TallyKey, Tally]
"""Mapping of :class:`~.Tally` objects by their :attr:`~.Tally.id`
.. versionchanged:: 0.0.3
Expand Down
27 changes: 15 additions & 12 deletions src/tslumd/sender.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import annotations
try:
from loguru import logger
except ImportError: # pragma: no cover
Expand All @@ -6,7 +7,7 @@
import asyncio
import socket
import argparse
from typing import Dict, Tuple, Set, Optional, Sequence, Iterable
from typing import Tuple, Iterable

from pydispatch import Dispatcher, Property, DictProperty, ListProperty

Expand Down Expand Up @@ -49,13 +50,13 @@ class UmdSender(Dispatcher):
The ``all_off_on_close`` parameter was added
"""

screens: Dict[int, Screen]
screens: dict[int, Screen]
"""Mapping of :class:`~.Screen` objects by :attr:`~.Screen.index`
.. versionadded:: 0.0.3
"""

tallies: Dict[TallyKey, Tally]
tallies: dict[TallyKey, Tally]
"""Mapping of :class:`~.Tally` objects by their :attr:`~.Tally.id`
Note:
Expand Down Expand Up @@ -83,7 +84,7 @@ class UmdSender(Dispatcher):
"""Interval to send tally messages, regardless of state changes
"""

clients: Set[Client]
clients: set[Client]
"""Set of :data:`clients <Client>` to send messages to
"""

Expand All @@ -95,8 +96,8 @@ class UmdSender(Dispatcher):
"""

def __init__(self,
clients: Optional[Iterable[Client]] = None,
all_off_on_close: Optional[bool] = False):
clients: Iterable[Client]|None = None,
all_off_on_close: bool = False):
self.clients = set()
if clients is not None:
for client in clients:
Expand All @@ -110,7 +111,7 @@ def __init__(self,
assert screen.is_broadcast
self.screens[screen.index] = screen
self._bind_screen(screen)
self.update_queue = asyncio.PriorityQueue()
self.update_queue: asyncio.PriorityQueue[TallyKey|Tuple[int, bool]]|None = None
self.update_task = None
self.tx_task = None
self.connected_evt = asyncio.Event()
Expand Down Expand Up @@ -338,7 +339,7 @@ async def send_broadcast_tally(self, screen_index: int, **kwargs):
oth_tally.update_from_display(disp)
self._bind_screen(screen)

async def on_tally_updated(self, tally: Tally, props_changed: Set[str], **kwargs):
async def on_tally_updated(self, tally: Tally, props_changed: set[str], **kwargs):
if self.running:
if set(props_changed) == set(['control']):
return
Expand All @@ -349,6 +350,7 @@ async def on_tally_control(self, tally: Tally, data: bytes, **kwargs):
if self.running:
async with self._tx_lock:
disp = Display.from_tally(tally, msg_type=MessageType.control)
assert tally.screen is not None
msg = self._build_message(
screen=tally.screen.index,
displays=[disp],
Expand Down Expand Up @@ -389,8 +391,9 @@ async def get_queue_item(timeout):
if item is False:
self.update_queue.task_done()
break
elif item is None and not self._tx_lock.locked():
await self.send_full_update()
elif item is None:
if not self._tx_lock.locked():
await self.send_full_update()
else:
screen_index, _ = item
ids = set([item])
Expand Down Expand Up @@ -476,13 +479,13 @@ def __init__(self,
)

def __call__(self, parser, namespace, values, option_string=None):
addr, port = values.split(':')
addr, port = values.split(':') # type: ignore
values = (addr, int(port))
items = getattr(namespace, self.dest, None)
if items == [('127.0.0.1', 65000)]:
items = []
else:
items = argparse._copy_items(items)
items = argparse._copy_items(items) # type: ignore
items.append(values)
setattr(namespace, self.dest, items)

Expand Down
Loading

0 comments on commit ecff768

Please sign in to comment.