From 7a298cfbbbd3dbac2b2346ab10f4d0499596f8fb Mon Sep 17 00:00:00 2001 From: Luca Sbardella Date: Tue, 25 Sep 2018 18:54:02 +0100 Subject: [PATCH 1/4] channels --- openapi/cli.py | 9 --------- openapi/ws/channel.py | 11 +++++++++-- openapi/ws/channels.py | 19 +++++++++++++++---- openapi/ws/pubsub.py | 2 +- 4 files changed, 25 insertions(+), 16 deletions(-) diff --git a/openapi/cli.py b/openapi/cli.py index b46a403..639ea89 100644 --- a/openapi/cli.py +++ b/openapi/cli.py @@ -4,7 +4,6 @@ from aiohttp import web import click -import dotenv import uvloop from .utils import get_debug_flag, getLogger @@ -79,16 +78,8 @@ def list_commands(self, ctx): def main(self, *args, **kwargs): os.environ['OPENAPI_RUN_FROM_CLI'] = 'true' - self.load_dotenv() return super().main(*args, **kwargs) - def load_dotenv(self, path=None): - if path is not None: - return dotenv.load_dotenv(path) - path = dotenv.find_dotenv('.env', usecwd=True) - if path: - dotenv.load_dotenv(path) - def get_server_version(self, ctx, param, value): if not value or ctx.resilient_parsing: return diff --git a/openapi/ws/channel.py b/openapi/ws/channel.py index 1bca39c..1d21eb1 100644 --- a/openapi/ws/channel.py +++ b/openapi/ws/channel.py @@ -2,7 +2,7 @@ import enum import asyncio import logging -from typing import List +from typing import Set from functools import wraps from collections import OrderedDict from dataclasses import dataclass @@ -29,7 +29,7 @@ class Event: name: str pattern: str regex: object - callbacks: List + callbacks: Set def safe_execution(method): @@ -124,6 +124,13 @@ def register(self, event, callback): return entry + def get_subscribed(self, handler): + events = [] + for event in self.callbacks.values(): + if handler in event.callbacks: + events.append(event.name) + return events + def unregister(self, event, callback): pattern = self.channels.event_pattern(event) entry = self.callbacks.get(pattern) diff --git a/openapi/ws/channels.py b/openapi/ws/channels.py index 0c49cea..456c2b3 100644 --- a/openapi/ws/channels.py +++ b/openapi/ws/channels.py @@ -1,6 +1,6 @@ import asyncio from collections import OrderedDict -from typing import Union, Dict, Iterator, Callable +from typing import Dict, Iterator, Callable from .channel import Channel, StatusType, logger from .utils import redis_to_py_pattern @@ -37,7 +37,7 @@ def __init__( status_channel or DEFAULT_CHANNEL) self.status = self.statusType.initialised if broker: - broker.set_channels(self) + broker.on_connection_lost(self.connection_lost) @property def registered(self): @@ -58,7 +58,7 @@ def __contains__(self, name) -> bool: def __iter__(self) -> Iterator: return iter(self.channels.values()) - async def __call__(self, channel_name: str, message: Union[str, Dict]): + async def __call__(self, channel_name: str, message: Dict): if channel_name.startswith(self.namespace): name = channel_name[len(self.namespace):] channel = self.channels.get(name) @@ -148,12 +148,23 @@ def event_pattern(self, event): """ return redis_to_py_pattern(event or '*') + def get_subscribed(self, handler): + subscribed = {} + for channel in self.channels.values(): + events = channel.get_subscribed(handler) + if events: + subscribed[channel.name] = events + return subscribed + # INTERNALS + def connection_lost(self): + self.status = StatusType.disconnected + async def _subscribe(self, channel_name): """Subscribe to the remote server """ - await self.broker.subscribe(self.prefixed(channel_name)) + await self.broker.subscribe(self.prefixed(channel_name), handler=self) async def _unsubscribe(self, channel_name): pass diff --git a/openapi/ws/pubsub.py b/openapi/ws/pubsub.py index b47dd39..e6af636 100644 --- a/openapi/ws/pubsub.py +++ b/openapi/ws/pubsub.py @@ -52,7 +52,7 @@ async def ws_rpc_subscribe(self, payload): """ await self.channels.register( payload['channel'], payload.get('event'), self.new_message) - return dict(subscribed=self.channels.registered) + return dict(subscribed=self.channels.get_subscribed(self.new_message)) @ws_rpc(body_schema=SubscribeSchema) async def ws_rpc_unsubscribe(self, payload): From b28508b524222febba85e706f83758b2b9654f8d Mon Sep 17 00:00:00 2001 From: Luca Sbardella Date: Thu, 27 Sep 2018 14:42:18 +0100 Subject: [PATCH 2/4] regression --- openapi/data/db.py | 2 +- openapi/data/fields.py | 11 +++++++++-- openapi/ws/broker.py | 22 ++++++++++++---------- openapi/ws/channel.py | 1 + openapi/ws/pubsub.py | 2 +- tests/test_channels.py | 1 - tests/test_ws.py | 11 +++++++---- 7 files changed, 31 insertions(+), 19 deletions(-) diff --git a/openapi/data/db.py b/openapi/data/db.py index eefb1e0..86f901f 100644 --- a/openapi/data/db.py +++ b/openapi/data/db.py @@ -82,7 +82,7 @@ def dt_ti(col, required): data_field = col.info.get('data_field', fields.date_time_field) return ( datetime, - data_field(**info(col, required)) + data_field(timezone=col.type.timezone, **info(col, required)) ) diff --git a/openapi/data/fields.py b/openapi/data/fields.py index 441a0c0..9aaaa5e 100644 --- a/openapi/data/fields.py +++ b/openapi/data/fields.py @@ -109,8 +109,8 @@ def date_field(**kw): return data_field(**kw) -def date_time_field(**kw): - kw.setdefault('validator', DateTimeValidator()) +def date_time_field(timezone=False, **kw): + kw.setdefault('validator', DateTimeValidator(timezone=timezone)) return data_field(**kw) @@ -274,6 +274,9 @@ def __call__(self, field, value, data=None): class DateTimeValidator(Validator): + def __init__(self, timezone=False): + self.timezone = timezone + def dump(self, value): if isinstance(value, datetime): return value.isoformat() @@ -289,6 +292,10 @@ def __call__(self, field, value, data=None): raise ValidationError( field.name, '%s not valid format' % value ) + if self.timezone and not value.tzinfo: + raise ValidationError( + field.name, 'Timezone infoirmation required' + ) return value diff --git a/openapi/ws/broker.py b/openapi/ws/broker.py index e87f78a..f878b47 100644 --- a/openapi/ws/broker.py +++ b/openapi/ws/broker.py @@ -1,16 +1,11 @@ import abc import asyncio -from typing import Dict +from typing import Dict, Callable class Broker(abc.ABC): """Abstract class for pubsub brokers """ - channels = None - - def set_channels(self, channels) -> None: - self.channels = channels - async def start(self) -> None: """ Start broker @@ -28,7 +23,7 @@ async def publish(self, channel: str, body: Dict) -> None: pass @abc.abstractmethod - async def subscribe(self, channel: str) -> None: + async def subscribe(self, channel: str, handler: Callable=None) -> None: """Bind the broker to a channel/exchange """ pass @@ -39,6 +34,9 @@ async def unsubscribe(self, channel: str) -> None: """ pass + def on_connection_lost(self, lost): + pass + class LocalBroker(Broker): @@ -47,6 +45,7 @@ def __init__(self): self.messages = None self.worker = None self._stop = False + self._handlers = set() async def start(self): if not self.worker: @@ -58,8 +57,10 @@ async def publish(self, channel, body): 0.01, self.messages.put_nowait, (channel, body) ) - async def subscribe(self, key): + async def subscribe(self, key: str, handler: Callable=None) -> None: self.binds.add(key) + if handler: + self._handlers.add(handler) async def unsubscribe(self, key): self.binds.discard(key) @@ -76,5 +77,6 @@ async def _work(self): key, body = await self.messages.get() if self._stop: break - if self.channels and key in self.binds: - await self.channels(key, body) + if key in self.binds: + for handler in self._handlers: + await handler(key, body) diff --git a/openapi/ws/channel.py b/openapi/ws/channel.py index 1d21eb1..0cc87ca 100644 --- a/openapi/ws/channel.py +++ b/openapi/ws/channel.py @@ -111,6 +111,7 @@ async def disconnect(self): def register(self, event, callback): """Register a ``callback`` for ``event`` """ + event = event or '*' pattern = self.channels.event_pattern(event) entry = self.callbacks.get(pattern) if not entry: diff --git a/openapi/ws/pubsub.py b/openapi/ws/pubsub.py index e6af636..9ef64c2 100644 --- a/openapi/ws/pubsub.py +++ b/openapi/ws/pubsub.py @@ -60,7 +60,7 @@ async def ws_rpc_unsubscribe(self, payload): """ await self.channels.unregister( payload['channel'], payload.get('event'), self.new_message) - return dict(subscribed=self.channels.registered) + return dict(subscribed=self.channels.get_subscribed(self.new_message)) async def new_message(self, channel, match, data): """A new message has arrived from channels diff --git a/tests/test_channels.py b/tests/test_channels.py index 76d6324..c8d7c26 100644 --- a/tests/test_channels.py +++ b/tests/test_channels.py @@ -17,7 +17,6 @@ async def channels(): async def test_channels_properties(channels): assert channels.broker - assert channels.broker.channels == channels assert len(channels.status_channel) == 0 assert channels.status == channels.statusType.initialised assert len(channels) == 1 diff --git a/tests/test_ws.py b/tests/test_ws.py index 79bbdd8..b6a0d22 100644 --- a/tests/test_ws.py +++ b/tests/test_ws.py @@ -110,7 +110,7 @@ async def test_rpc_subscribe(cli): assert msg.type == aiohttp.WSMsgType.TEXT data = msg.json() assert data['id'] == 'abc' - assert data['response'] == dict(subscribed=['server', 'test']) + assert data['response'] == dict(subscribed={'test': ['*']}) await ws.send_json( dict( id='abcd', method='subscribe', @@ -121,7 +121,8 @@ async def test_rpc_subscribe(cli): assert msg.type == aiohttp.WSMsgType.TEXT data = msg.json() assert data['id'] == 'abcd' - assert data['response'] == dict(subscribed=['server', 'test', 'foo']) + assert data['response'] == dict( + subscribed={'test': ['*'], 'foo': ['*']}) async def test_rpc_unsubscribe(cli): @@ -141,7 +142,9 @@ async def test_rpc_unsubscribe(cli): ) msg = await ws.receive() data = msg.json() - assert data['response'] == dict(subscribed=['server', 'test', 'foo']) + assert data['response'] == dict( + subscribed={'test': ['*'], 'foo': ['*']} + ) await ws.send_json( dict( id='xyz', method='unsubscribe', @@ -150,7 +153,7 @@ async def test_rpc_unsubscribe(cli): ) msg = await ws.receive() data = msg.json() - assert data['response'] == dict(subscribed=['server', 'foo']) + assert data['response'] == dict(subscribed={'foo': ['*']}) async def test_rpc_pubsub(cli): From b85665731e3ae346380599c6976e22c121dadcb2 Mon Sep 17 00:00:00 2001 From: Luca Sbardella Date: Thu, 27 Sep 2018 15:34:36 +0100 Subject: [PATCH 3/4] typo --- openapi/data/fields.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openapi/data/fields.py b/openapi/data/fields.py index 9aaaa5e..e1198b9 100644 --- a/openapi/data/fields.py +++ b/openapi/data/fields.py @@ -294,7 +294,7 @@ def __call__(self, field, value, data=None): ) if self.timezone and not value.tzinfo: raise ValidationError( - field.name, 'Timezone infoirmation required' + field.name, 'Timezone information required' ) return value From 4027b83db1de620c73245ce5803367e2305ae8c4 Mon Sep 17 00:00:00 2001 From: Luca Sbardella Date: Fri, 28 Sep 2018 07:34:39 +0100 Subject: [PATCH 4/4] Update __init__.py --- openapi/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openapi/__init__.py b/openapi/__init__.py index d897658..a53f894 100644 --- a/openapi/__init__.py +++ b/openapi/__init__.py @@ -1,4 +1,4 @@ """Minimal OpenAPI asynchronous server application """ -__version__ = '0.8.8' +__version__ = '0.8.9'