Skip to content
This repository has been archived by the owner on Mar 24, 2024. It is now read-only.

Commit

Permalink
Merge pull request #259 from quantmind/master
Browse files Browse the repository at this point in the history
2.1.2
  • Loading branch information
lsbardel authored Jan 9, 2021
2 parents 9512a63 + b9200c0 commit 4b811f5
Show file tree
Hide file tree
Showing 8 changed files with 155 additions and 33 deletions.
77 changes: 77 additions & 0 deletions docs/websocket.rst
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,80 @@ Messages take the form:
...
}
}
Backend
========

The websocket backend is implemented by subclassing the :class:`.SocketsManager` and implement the methods required by your application.
This example implements a very simple backend for testing the websocket module in unittests.


.. code-block:: python
import asyncio
from aiohttp import web
from openapi.ws.manager import SocketsManager
class LocalBroker(SocketsManager):
"""A local broker for testing"""
def __init__(self):
self.binds = set()
self.messages: asyncio.Queue = asyncio.Queue()
self.worker = None
self._stop = False
@classmethod
def for_app(cls, app: web.Application) -> "LocalBroker":
broker = cls()
app.on_startup.append(broker.start)
app.on_shutdown.append(broker.close)
return broker
async def start(self, *arg):
if not self.worker:
self.worker = asyncio.ensure_future(self._work())
async def publish(self, channel: str, event: str, body: Any):
"""simulate network latency"""
if channel.lower() != channel:
raise CannotPublish
payload = dict(event=event, data=self.get_data(body))
asyncio.get_event_loop().call_later(
0.01, self.messages.put_nowait, (channel, payload)
)
async def subscribe(self, channel: str) -> None:
""" force channel names to be lowercase"""
if channel.lower() != channel:
raise CannotSubscribe
async def close(self, *arg):
self._stop = True
await self.close_sockets()
if self.worker:
self.messages.put_nowait((None, None))
await self.worker
self.worker = None
async def _work(self):
while True:
channel, body = await self.messages.get()
if self._stop:
break
await self.channels(channel, body)
def get_data(self, data: Any) -> Any:
if data == "error":
return self.raise_error
elif data == "runtime_error":
return self.raise_runtime
return data
def raise_error(self):
raise ValueError
def raise_runtime(self):
raise RuntimeError
2 changes: 1 addition & 1 deletion openapi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""Minimal OpenAPI asynchronous server application"""

__version__ = "2.1.1"
__version__ = "2.1.2"
15 changes: 6 additions & 9 deletions openapi/ws/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,22 @@ class Event:
callbacks: Set[CallbackType] = field(default_factory=set)


@dataclass
class Channel:
"""A websocket channel"""

def __init__(self, name: str) -> None:
self.name: str = name
self._events: Dict[str, Event] = {}
name: str
_events: Dict[str, Event] = field(default_factory=dict)

@property
def events(self):
"""List of event names this channel is registered with"""
return tuple((e.name for e in self.events.values()))

def __repr__(self):
return self.name
return tuple((e.name for e in self._events.values()))

def __len__(self):
def __len__(self) -> int:
return len(self._events)

def __contains__(self, pattern):
def __contains__(self, pattern: str) -> bool:
return pattern in self._events

def __iter__(self):
Expand Down
34 changes: 18 additions & 16 deletions openapi/ws/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .channel import CallbackType, Channel
from .errors import CannotSubscribe

if TYPE_CHECKING:
if TYPE_CHECKING: # pragma: no cover
from .manager import SocketsManager


Expand All @@ -14,35 +14,38 @@ class Channels:

def __init__(self, sockets: "SocketsManager") -> None:
self.sockets: "SocketsManager" = sockets
self.channels: Dict[str, Channel] = {}
self._channels: Dict[str, Channel] = {}

@property
def registered(self) -> Tuple[str, ...]:
"""Registered channels"""
return tuple(self.channels)
return tuple(self._channels)

def __len__(self) -> int:
return len(self.channels)
return len(self._channels)

def __contains__(self, channel_name: str) -> bool:
return channel_name in self.channels
return channel_name in self._channels

def __iter__(self) -> Iterator[Channel]:
return iter(self.channels.values())
return iter(self._channels.values())

def clear(self) -> None:
self.channels.clear()
self._channels.clear()

def get(self, channel_name: str) -> Optional[Channel]:
return self._channels.get(channel_name)

def info(self) -> Dict:
return {channel.name: channel.info() for channel in self}

async def __call__(self, channel_name: str, message: Dict) -> None:
"""Channel callback"""
channel = self.channels.get(channel_name)
channel = self.get(channel_name)
if channel:
closed = await channel(message)
for websocket in closed:
for channel_name, channel in tuple(self.channels.items()):
for channel_name, channel in tuple(self._channels.items()):
channel.remove_callback(websocket)
await self._maybe_remove_channel(channel)

Expand All @@ -55,18 +58,17 @@ async def register(
:param event_name: name of the event in the channel or a pattern
:param callback: the callback to invoke when the `event` on `channel` occurs
"""
channel_name = channel_name.lower()
channel = self.channels.get(channel_name)
channel = self.get(channel_name)
if channel is None:
try:
await self.sockets.subscribe(channel_name)
except CannotSubscribe:
raise ValidationErrors(dict(channel="Invalid channel"))
else:
channel = Channel(channel_name)
self.channels[channel_name] = channel
self._channels[channel_name] = channel
event = channel.register(event_name, callback)
await self.sockets.subscribe_to_event(channel.name, event)
await self.sockets.subscribe_to_event(channel.name, event.name)
return channel

async def unregister(
Expand All @@ -75,7 +77,7 @@ async def unregister(
"""Safely unregister a callback from the list of event
callbacks for channel_name
"""
channel = self.channels.get(channel_name.lower())
channel = self.get(channel_name)
if channel is None:
raise ValidationErrors(dict(channel="Invalid channel"))
channel.unregister(event, callback)
Expand All @@ -84,12 +86,12 @@ async def unregister(
async def _maybe_remove_channel(self, channel: Channel) -> Channel:
if not channel:
await self.sockets.unsubscribe(channel.name)
self.channels.pop(channel.name)
self._channels.pop(channel.name)
return channel

def get_subscribed(self, callback: CallbackType) -> Dict[str, List[str]]:
subscribed = {}
for channel in self.channels.values():
for channel in self:
events = channel.get_subscribed(callback)
if events:
subscribed[channel.name] = events
Expand Down
20 changes: 16 additions & 4 deletions openapi/ws/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import Any, Callable, Dict, Set

from ..utils import cached_property
from .channel import Event
from .channels import CannotSubscribe, Channels
from .errors import CannotPublish

Expand All @@ -29,7 +28,7 @@ def sockets(self) -> Set[Websocket]:

@cached_property
def channels(self) -> Channels:
"""Set of connected :class:`.Websocket`"""
"""Pub/sub :class:`.Channels` currently active on the running pod"""
return Channels(self)

def add(self, ws: Websocket) -> None:
Expand All @@ -55,6 +54,10 @@ async def publish(
) -> None: # pragma: no cover
"""Publish an event to a channel
:property channel: the channel to publish to
:property event: the event in the channel
:property body: the body of the event to broadcast in the channel
This method should raise :class:`.CannotPublish` if not possible to publish
"""
raise CannotPublish
Expand All @@ -66,8 +69,17 @@ async def subscribe(self, channel: str) -> None: # pragma: no cover
"""
raise CannotSubscribe

async def subscribe_to_event(self, channel: str, event: Event) -> None:
"""Callback when a subscribe to event is done"""
async def subscribe_to_event(self, channel: str, event: str) -> None:
"""Callback when a subscription to an event is done
:property channel: the channel to publish to
:property event: the event in the channel
You can use this callback to perform any backend subscriptions to
third-party streaming services if required.
By default it does nothing.
"""

async def unsubscribe(self, channel: str) -> None:
"""Unsubscribe from a channel"""
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[flake8]
exclude = __pycache__,.eggs,venv,build,dist,docs,dev
max-line-length = 88
ignore = A001,A002,A003,C815,C812,W503,E203
ignore = A001,A002,A003,B902,C815,C812,W503,E203

[isort]
line_length=88
Expand Down
18 changes: 16 additions & 2 deletions tests/ws/test_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
from async_timeout import timeout

from openapi.ws import Channels
from openapi.ws.utils import redis_to_py_pattern
from tests.example.ws import LocalBroker

Expand All @@ -18,14 +19,14 @@ async def channels():
await broker.close()


async def test_channels_properties(channels):
async def test_channels_properties(channels: Channels):
assert channels.sockets
await channels.register("foo", "*", lambda c, e, d: d)
assert len(channels) == 1
assert "foo" in channels


async def test_channels_wildcard(channels):
async def test_channels_wildcard(channels: Channels):
future = asyncio.Future()

def fire(channel, event, data):
Expand Down Expand Up @@ -64,6 +65,19 @@ def test_redis_to_py_pattern():
assert not_match(c, "hollo")


async def test_channel(channels: Channels):
assert channels.sockets
await channels.register("test", "foo", lambda c, e, d: d)
assert channels.registered == ("test",)
channel = channels.get("test")
assert channel.name == "test"
assert channel.events == ("foo",)
assert "foo$" in channel
events = list(channel)
assert len(events) == 1
assert await channel({}) == ()


def match(c, text):
return c.match(text).group() == text

Expand Down
20 changes: 20 additions & 0 deletions tests/ws/test_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,15 @@ async def test_rpc_subscribe(cli):
response = data["response"]
assert response["connections"] == 1
assert len(response["channels"]) == 2
#
# invalid channel
await ws.send_json(
dict(id="abc", method="subscribe", payload=dict(channel="Test"))
)
msg = await ws.receive()
data = msg.json()
assert data["error"]["message"] == "Invalid RPC parameters"
assert data["error"]["errors"]["channel"] == "Invalid channel"


async def test_rpc_unsubscribe(cli):
Expand Down Expand Up @@ -256,3 +265,14 @@ async def test_badjson(cli):
assert data["id"] == "abc"
assert data["error"]
assert data["error"]["message"] == "JSON object expected"


async def test_rpc_unsubscribe_error(cli):
async with cli.ws_connect("/stream") as ws:
await ws.send_json(
dict(id="xyz", method="unsubscribe", payload=dict(channel="whazaaa"))
)
msg = await ws.receive()
data = msg.json()
assert data["error"]["message"] == "Invalid RPC parameters"
assert data["error"]["errors"]["channel"] == "Invalid channel"

0 comments on commit 4b811f5

Please sign in to comment.