Skip to content

Commit

Permalink
Merge pull request Uninett#149 from lunkwill42/feature/ntie
Browse files Browse the repository at this point in the history
Add notification protocol
  • Loading branch information
lunkwill42 authored Jan 25, 2024
2 parents c761716 + 90a9460 commit fc580af
Show file tree
Hide file tree
Showing 10 changed files with 431 additions and 10 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,6 @@ src_paths = ["src", "tests"]
exclude_also = [
# Don't need coverage for ellipsis used for type annotations
"...",
# Don't complain about lines excluded unless type checking
"if TYPE_CHECKING:",
]
43 changes: 39 additions & 4 deletions src/zino/api/legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,17 @@
import textwrap
from functools import wraps
from pathlib import Path
from typing import Callable, List, Optional, Union
from typing import TYPE_CHECKING, Callable, List, Optional, Union

from zino import version
from zino.api import auth
from zino.api.notify import Zino1NotificationProtocol
from zino.state import ZinoState
from zino.statemodels import Event, EventState

if TYPE_CHECKING:
from zino.api.server import ZinoServer

_logger = logging.getLogger(__name__)


Expand All @@ -29,14 +33,22 @@ def requires_authentication(func: Callable) -> Callable:
class Zino1BaseServerProtocol(asyncio.Protocol):
"""Base implementation of the Zino 1 protocol, with a basic command dispatcher for subclasses to utilize."""

def __init__(self, state: Optional[ZinoState] = None, secrets_file: Optional[Union[Path, str]] = "secrets"):
def __init__(
self,
server: Optional["ZinoServer"] = None,
state: Optional[ZinoState] = None,
secrets_file: Optional[Union[Path, str]] = "secrets",
):
"""Initializes a protocol instance.
:param server: An optional instance of `ZinoServer`.
:param state: An optional reference to a running Zino state that this server should be based on. If omitted,
this protocol will create and work on an empty state object.
:param secrets_file: An optional alternative path to the file containing users and their secrets.
"""
self.server = server
self.transport: Optional[asyncio.Transport] = None
self.notification_channel: Optional[Zino1NotificationProtocol] = None
self._authenticated_as: Optional[str] = None
self._current_task: asyncio.Task = None
self._multiline_future: asyncio.Future = None
Expand All @@ -47,7 +59,7 @@ def __init__(self, state: Optional[ZinoState] = None, secrets_file: Optional[Uni
self._secrets_file = secrets_file

@property
def peer_name(self) -> str:
def peer_name(self) -> Optional[str]:
return self.transport.get_extra_info("peername") if self.transport else None

@property
Expand All @@ -65,10 +77,19 @@ def user(self, user_name: str):

def connection_made(self, transport: asyncio.Transport):
self.transport = transport
_logger.debug("New server connection from %s", self.peer_name)
_logger.info("New server connection from %s", self.peer_name)
if self.server:
self.server.active_clients.add(self)
self._authentication_challenge = auth.get_challenge()
self._respond_ok(f"{self._authentication_challenge} Hello, there")

def connection_lost(self, exc: Optional[Exception]) -> None:
_logger.info("Client disconnected: %s", self.peer_name)
if self.server:
self.server.active_clients.remove(self)
if self.notification_channel:
self.notification_channel.goodbye()

def data_received(self, data):
try:
message = data.decode().rstrip("\r\n")
Expand Down Expand Up @@ -313,6 +334,20 @@ async def do_community(self, router_name: str):
else:
self._respond_error("router unknown")

@requires_authentication
async def do_ntie(self, nonce: str):
"""Implements the NTIE command that ties together this session with a notification channel."""
try:
channel = self.server.notification_channels[nonce]
except (AttributeError, KeyError):
return self._respond_error("Could not find your notify socket")

self.notification_channel = channel
channel.tied_to = self
_logger.info("Client %s tied to notification channel %s", self.peer_name, channel.peer_name)

return self._respond_ok()


class ZinoTestProtocol(Zino1ServerProtocol):
"""Extended Zino 1 server protocol with test commands added in"""
Expand Down
118 changes: 118 additions & 0 deletions src/zino/api/notify.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""Notification channel implementation for Zino 2.0.
Notification channels are currently part of the legacy API from the Tcl-based Zino 1.0. They are a simple text-based,
line-oriented protocol. Clients are not expected to send any data to a notification channel, only receive data from
the server.
"""
import asyncio
import logging
from typing import TYPE_CHECKING, Any, Iterator, NamedTuple, Optional

from zino.api import auth
from zino.state import ZinoState
from zino.statemodels import Event, EventState

if TYPE_CHECKING:
from zino.api.legacy import Zino1ServerProtocol
from zino.api.server import ZinoServer

_logger = logging.getLogger(__name__)


class Notification(NamedTuple):
"""Represents the contents of a single notification"""

event_id: int
change_type: str
value: Any


class Zino1NotificationProtocol(asyncio.Protocol):
"""Basic implementation of the Zino 1 notification protocol"""

def __init__(self, server: Optional["ZinoServer"] = None, state: Optional[ZinoState] = None):
"""Initializes a protocol instance.
:param server: An optional instance of `ZinoServer`.
:param state: An optional reference to a running Zino state that this server should be based on. If omitted,
this protocol will create and work on an empty state object.
"""
self.server = server
self.transport: Optional[asyncio.Transport] = None
self.nonce: Optional[str] = None

self._state = state if state is not None else ZinoState()
self._tied_to: "Zino1ServerProtocol" = None

@property
def peer_name(self) -> Optional[str]:
return self.transport.get_extra_info("peername") if self.transport else None

def connection_made(self, transport: asyncio.Transport):
self.transport = transport
_logger.debug("New notification channel from %s", self.peer_name)
self.nonce = auth.get_challenge() # Challenges are also useful as nonces
if self.server:
self.server.notification_channels[self.nonce] = self
self._respond_raw(self.nonce)

def connection_lost(self, exc: Optional[Exception]) -> None:
_logger.info("Lost connection from %s: %s", self.peer_name, exc)
if self.server:
del self.server.notification_channels[self.nonce]

def goodbye(self):
"""Called by the tied server channel when that closes to gracefully close this channel too"""
self._respond_raw("Normal quit from client, closing down")
self.transport.close()

@property
def tied_to(self) -> Optional["Zino1ServerProtocol"]:
return self._tied_to

@tied_to.setter
def tied_to(self, client: "Zino1ServerProtocol") -> None:
self._tied_to = client

def notify(self, notification: Notification):
"""Sends a notification to the connected client"""
self._respond_raw(f"{notification.event_id} {notification.change_type} {notification.value}")

def _respond_raw(self, message: str):
"""Encodes and sends a response line to the connected client"""
self.transport.write(f"{message}\r\n".encode("utf-8"))

@classmethod
def build_and_send_notifications(
cls, server: "ZinoServer", new_event: Event, old_event: Optional[Event] = None
) -> None:
"""Prepares and sends notifications for all changes between old_event and new_event to all connected and tied
notification channels.
"""
notifications = list(cls.build_notifications(new_event, old_event))
tied_channels = [channel for channel in server.notification_channels.values() if channel.tied_to]
_logger.debug("Sending %s notifications to %s tied channels", len(notifications), len(tied_channels))

for notification in notifications:
for channel in tied_channels:
channel.notify(notification)

@classmethod
def build_notifications(cls, new_event: Event, old_event: Optional[Event] = None) -> Iterator[Notification]:
"""Generates a sequence of Notification objects from the changes detected between old_event and new_event.
If `old_event` is `None`, it is assumed the event is brand new, and only the state change from EMBRYONIC
matters.
"""
changed = new_event.get_changed_fields(old_event) if old_event else ["state"]

for attr in changed:
if attr == "state":
old_state = EventState.EMBRYONIC if not old_event else old_event.state
yield Notification(new_event.id, attr, f"{old_state.value} {new_event.state.value}")

elif attr in ("log", "history"):
yield Notification(new_event.id, attr, 1)

else:
yield Notification(new_event.id, "attr", attr)
48 changes: 48 additions & 0 deletions src/zino/api/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import logging
from asyncio import AbstractEventLoop
from typing import Optional

from zino.api.legacy import Zino1ServerProtocol, ZinoTestProtocol
from zino.api.notify import Zino1NotificationProtocol
from zino.state import ZinoState
from zino.statemodels import Event

_logger = logging.getLogger(__name__)


class ZinoServer:
"""Represents the two asyncio servers that work in tandem to implement the Zino 1 legacy API:
Port 8001 is the text-based command interface.
Port 8002 is the text-based notification interface.
"""

API_PORT = 8001
NOTIFY_PORT = 8002

def __init__(self, loop: AbstractEventLoop, state: ZinoState):
self._loop = loop
self.state: ZinoState = state
self.active_clients: set[Zino1ServerProtocol] = set()
self.notification_channels: dict[str, Zino1NotificationProtocol] = {}
self.notify_server = self.api_server = None

def serve(self, address: str = "127.0.0.1"):
"""Sets up the two asyncio servers to serve in tandem 'forever'"""
api_coroutine = self._loop.create_server(
lambda: ZinoTestProtocol(server=self, state=self.state), address, self.API_PORT
)
self.api_server = self._loop.run_until_complete(api_coroutine)
_logger.info("Serving API on %r", self.api_server.sockets[0].getsockname())

notify_coroutine = self._loop.create_server(
lambda: Zino1NotificationProtocol(server=self, state=self.state), address, self.NOTIFY_PORT
)
self.notify_server = self._loop.run_until_complete(notify_coroutine)
_logger.info("Serving notifications on %r", self.notify_server.sockets[0].getsockname())

self.state.events.add_event_observer(self.on_event_commit)

def on_event_commit(self, new_event: Event, old_event: Optional[Event] = None) -> None:
"""Event observer to build notifications for notification channels"""
Zino1NotificationProtocol.build_and_send_notifications(self, new_event, old_event)
4 changes: 4 additions & 0 deletions src/zino/statemodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,10 @@ def zinoify_value(value: Any) -> str:
return str(int(value.timestamp()))
return str(value)

def get_changed_fields(self, other: "Event") -> List[str]:
"""Compares this Event to another Event and returns a list of names of attributes that are different"""
return [field for field in self.model_fields if getattr(other, field) != getattr(self, field)]


class PortStateEvent(Event):
type: Literal["portstate"] = "portstate"
Expand Down
7 changes: 3 additions & 4 deletions src/zino/zino.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import tzlocal

from zino import state
from zino.api.legacy import ZinoTestProtocol
from zino.api.server import ZinoServer
from zino.config.models import DEFAULT_INTERVAL_MINUTES
from zino.scheduler import get_scheduler, load_and_schedule_polldevs
from zino.statemodels import Event
Expand Down Expand Up @@ -47,9 +47,8 @@ def init_event_loop(args: argparse.Namespace):
state.state.events.add_event_observer(reschedule_dump_state_on_commit)

loop = asyncio.get_event_loop()
server = loop.create_server(lambda: ZinoTestProtocol(state=state.state), "127.0.0.1", 8001)
server_setup_result = loop.run_until_complete(server)
_log.info("Serving on %r", server_setup_result.sockets[0].getsockname())
server = ZinoServer(loop=loop, state=state.state)
server.serve()

if args.stop_in:
_log.info("Instructed to stop in %s seconds", args.stop_in)
Expand Down
62 changes: 60 additions & 2 deletions tests/api/legacy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,26 @@
import pytest

from zino import version
from zino.api.auth import get_challenge
from zino.api.legacy import (
Zino1BaseServerProtocol,
Zino1ServerProtocol,
ZinoTestProtocol,
requires_authentication,
)
from zino.api.server import ZinoServer
from zino.config.models import PollDevice
from zino.state import ZinoState
from zino.statemodels import EventState, ReachabilityEvent


class TestZino1BaseServerProtocol:
def test_should_init_without_error(self):
assert Zino1BaseServerProtocol()

def test_when_unconnected_then_peer_name_should_be_none(self):
def test_when_not_connected_then_peer_name_should_be_none(self):
protocol = Zino1BaseServerProtocol()
assert not protocol.peer_name
assert protocol.peer_name is None

def test_when_connected_then_peer_name_should_be_available(self):
expected = "foobar"
Expand Down Expand Up @@ -178,6 +181,23 @@ async def test_when_command_raises_unhandled_exception_then_exception_should_be_
await protocol.data_received(b"RAISEERROR\r\n")
assert "ZeroDivisionError" in caplog.text

def test_when_connected_it_should_register_instance_in_server(self, event_loop):
server = ZinoServer(loop=event_loop, state=ZinoState())
protocol = Zino1BaseServerProtocol(server=server)
fake_transport = Mock()
protocol.connection_made(fake_transport)

assert protocol in server.active_clients

def test_when_disconnected_it_should_deregister_instance_from_server(self, event_loop):
server = ZinoServer(loop=event_loop, state=ZinoState())
protocol = Zino1BaseServerProtocol(server=server)
fake_transport = Mock()
protocol.connection_made(fake_transport)
protocol.connection_lost(exc=None)

assert protocol not in server.active_clients


class TestZino1ServerProtocolUserCommand:
@pytest.mark.asyncio
Expand Down Expand Up @@ -489,6 +509,44 @@ async def test_should_output_error_response_for_unknown_router(self, authenticat
assert b"500 router unknown\r\n" in output


class TestZino1ServerProtocolNtieCommand:
@pytest.mark.asyncio
async def test_when_nonce_is_bogus_it_should_respond_with_error(self, event_loop, authenticated_protocol):
server = ZinoServer(loop=event_loop, state=ZinoState())
server.notification_channels = dict() # Ensure there are none for this test
authenticated_protocol.server = server

await authenticated_protocol.data_received(b"NTIE cromulent\r\n")

output = authenticated_protocol.transport.data_buffer.getvalue().decode()
assert "\r\n500 " in output

@pytest.mark.asyncio
async def test_when_nonce_exists_it_should_respond_with_ok(self, event_loop, authenticated_protocol):
server = ZinoServer(loop=event_loop, state=ZinoState())
nonce = get_challenge()
mock_channel = Mock()
server.notification_channels[nonce] = mock_channel
authenticated_protocol.server = server

await authenticated_protocol.data_received(f"NTIE {nonce}\r\n".encode())

output = authenticated_protocol.transport.data_buffer.getvalue().decode()
assert "\r\n200 " in output

@pytest.mark.asyncio
async def test_when_nonce_exists_it_should_tie_the_corresponding_channel(self, event_loop, authenticated_protocol):
server = ZinoServer(loop=event_loop, state=ZinoState())
nonce = get_challenge()
mock_channel = Mock()
server.notification_channels[nonce] = mock_channel
authenticated_protocol.server = server

await authenticated_protocol.data_received(f"NTIE {nonce}\r\n".encode())

assert mock_channel.tied_to is authenticated_protocol


class TestZino1TestProtocol:
@pytest.mark.asyncio
async def test_when_authenticated_then_authtest_should_respond_with_ok(self):
Expand Down
Loading

0 comments on commit fc580af

Please sign in to comment.