Skip to content
This repository has been archived by the owner on Mar 24, 2024. It is now read-only.

Commit

Permalink
Merge pull request #122 from lendingblock/master
Browse files Browse the repository at this point in the history
0.8.9
  • Loading branch information
lsbardel authored Sep 28, 2018
2 parents 063698d + 4027b83 commit edfcb9f
Show file tree
Hide file tree
Showing 10 changed files with 57 additions and 36 deletions.
2 changes: 1 addition & 1 deletion openapi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Minimal OpenAPI asynchronous server application
"""

__version__ = '0.8.8'
__version__ = '0.8.9'
9 changes: 0 additions & 9 deletions openapi/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from aiohttp import web
import click
import dotenv
import uvloop

from .utils import get_debug_flag, getLogger
Expand Down Expand Up @@ -79,16 +78,8 @@ def list_commands(self, ctx):

def main(self, *args, **kwargs):
os.environ['OPENAPI_RUN_FROM_CLI'] = 'true'
self.load_dotenv()
return super().main(*args, **kwargs)

def load_dotenv(self, path=None):
if path is not None:
return dotenv.load_dotenv(path)
path = dotenv.find_dotenv('.env', usecwd=True)
if path:
dotenv.load_dotenv(path)

def get_server_version(self, ctx, param, value):
if not value or ctx.resilient_parsing:
return
Expand Down
2 changes: 1 addition & 1 deletion openapi/data/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def dt_ti(col, required):
data_field = col.info.get('data_field', fields.date_time_field)
return (
datetime,
data_field(**info(col, required))
data_field(timezone=col.type.timezone, **info(col, required))
)


Expand Down
11 changes: 9 additions & 2 deletions openapi/data/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ def date_field(**kw):
return data_field(**kw)


def date_time_field(**kw):
kw.setdefault('validator', DateTimeValidator())
def date_time_field(timezone=False, **kw):
kw.setdefault('validator', DateTimeValidator(timezone=timezone))
return data_field(**kw)


Expand Down Expand Up @@ -274,6 +274,9 @@ def __call__(self, field, value, data=None):

class DateTimeValidator(Validator):

def __init__(self, timezone=False):
self.timezone = timezone

def dump(self, value):
if isinstance(value, datetime):
return value.isoformat()
Expand All @@ -289,6 +292,10 @@ def __call__(self, field, value, data=None):
raise ValidationError(
field.name, '%s not valid format' % value
)
if self.timezone and not value.tzinfo:
raise ValidationError(
field.name, 'Timezone information required'
)
return value


Expand Down
22 changes: 12 additions & 10 deletions openapi/ws/broker.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,11 @@
import abc
import asyncio
from typing import Dict
from typing import Dict, Callable


class Broker(abc.ABC):
"""Abstract class for pubsub brokers
"""
channels = None

def set_channels(self, channels) -> None:
self.channels = channels

async def start(self) -> None:
"""
Start broker
Expand All @@ -28,7 +23,7 @@ async def publish(self, channel: str, body: Dict) -> None:
pass

@abc.abstractmethod
async def subscribe(self, channel: str) -> None:
async def subscribe(self, channel: str, handler: Callable=None) -> None:
"""Bind the broker to a channel/exchange
"""
pass
Expand All @@ -39,6 +34,9 @@ async def unsubscribe(self, channel: str) -> None:
"""
pass

def on_connection_lost(self, lost):
pass


class LocalBroker(Broker):

Expand All @@ -47,6 +45,7 @@ def __init__(self):
self.messages = None
self.worker = None
self._stop = False
self._handlers = set()

async def start(self):
if not self.worker:
Expand All @@ -58,8 +57,10 @@ async def publish(self, channel, body):
0.01, self.messages.put_nowait, (channel, body)
)

async def subscribe(self, key):
async def subscribe(self, key: str, handler: Callable=None) -> None:
self.binds.add(key)
if handler:
self._handlers.add(handler)

async def unsubscribe(self, key):
self.binds.discard(key)
Expand All @@ -76,5 +77,6 @@ async def _work(self):
key, body = await self.messages.get()
if self._stop:
break
if self.channels and key in self.binds:
await self.channels(key, body)
if key in self.binds:
for handler in self._handlers:
await handler(key, body)
12 changes: 10 additions & 2 deletions openapi/ws/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import enum
import asyncio
import logging
from typing import List
from typing import Set
from functools import wraps
from collections import OrderedDict
from dataclasses import dataclass
Expand All @@ -29,7 +29,7 @@ class Event:
name: str
pattern: str
regex: object
callbacks: List
callbacks: Set


def safe_execution(method):
Expand Down Expand Up @@ -111,6 +111,7 @@ async def disconnect(self):
def register(self, event, callback):
"""Register a ``callback`` for ``event``
"""
event = event or '*'
pattern = self.channels.event_pattern(event)
entry = self.callbacks.get(pattern)
if not entry:
Expand All @@ -124,6 +125,13 @@ def register(self, event, callback):

return entry

def get_subscribed(self, handler):
events = []
for event in self.callbacks.values():
if handler in event.callbacks:
events.append(event.name)
return events

def unregister(self, event, callback):
pattern = self.channels.event_pattern(event)
entry = self.callbacks.get(pattern)
Expand Down
19 changes: 15 additions & 4 deletions openapi/ws/channels.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
from collections import OrderedDict
from typing import Union, Dict, Iterator, Callable
from typing import Dict, Iterator, Callable

from .channel import Channel, StatusType, logger
from .utils import redis_to_py_pattern
Expand Down Expand Up @@ -37,7 +37,7 @@ def __init__(
status_channel or DEFAULT_CHANNEL)
self.status = self.statusType.initialised
if broker:
broker.set_channels(self)
broker.on_connection_lost(self.connection_lost)

@property
def registered(self):
Expand All @@ -58,7 +58,7 @@ def __contains__(self, name) -> bool:
def __iter__(self) -> Iterator:
return iter(self.channels.values())

async def __call__(self, channel_name: str, message: Union[str, Dict]):
async def __call__(self, channel_name: str, message: Dict):
if channel_name.startswith(self.namespace):
name = channel_name[len(self.namespace):]
channel = self.channels.get(name)
Expand Down Expand Up @@ -148,12 +148,23 @@ def event_pattern(self, event):
"""
return redis_to_py_pattern(event or '*')

def get_subscribed(self, handler):
subscribed = {}
for channel in self.channels.values():
events = channel.get_subscribed(handler)
if events:
subscribed[channel.name] = events
return subscribed

# INTERNALS

def connection_lost(self):
self.status = StatusType.disconnected

async def _subscribe(self, channel_name):
"""Subscribe to the remote server
"""
await self.broker.subscribe(self.prefixed(channel_name))
await self.broker.subscribe(self.prefixed(channel_name), handler=self)

async def _unsubscribe(self, channel_name):
pass
Expand Down
4 changes: 2 additions & 2 deletions openapi/ws/pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,15 @@ async def ws_rpc_subscribe(self, payload):
"""
await self.channels.register(
payload['channel'], payload.get('event'), self.new_message)
return dict(subscribed=self.channels.registered)
return dict(subscribed=self.channels.get_subscribed(self.new_message))

@ws_rpc(body_schema=SubscribeSchema)
async def ws_rpc_unsubscribe(self, payload):
"""Unsubscribe to an event on a channel
"""
await self.channels.unregister(
payload['channel'], payload.get('event'), self.new_message)
return dict(subscribed=self.channels.registered)
return dict(subscribed=self.channels.get_subscribed(self.new_message))

async def new_message(self, channel, match, data):
"""A new message has arrived from channels
Expand Down
1 change: 0 additions & 1 deletion tests/test_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ async def channels():

async def test_channels_properties(channels):
assert channels.broker
assert channels.broker.channels == channels
assert len(channels.status_channel) == 0
assert channels.status == channels.statusType.initialised
assert len(channels) == 1
Expand Down
11 changes: 7 additions & 4 deletions tests/test_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ async def test_rpc_subscribe(cli):
assert msg.type == aiohttp.WSMsgType.TEXT
data = msg.json()
assert data['id'] == 'abc'
assert data['response'] == dict(subscribed=['server', 'test'])
assert data['response'] == dict(subscribed={'test': ['*']})
await ws.send_json(
dict(
id='abcd', method='subscribe',
Expand All @@ -121,7 +121,8 @@ async def test_rpc_subscribe(cli):
assert msg.type == aiohttp.WSMsgType.TEXT
data = msg.json()
assert data['id'] == 'abcd'
assert data['response'] == dict(subscribed=['server', 'test', 'foo'])
assert data['response'] == dict(
subscribed={'test': ['*'], 'foo': ['*']})


async def test_rpc_unsubscribe(cli):
Expand All @@ -141,7 +142,9 @@ async def test_rpc_unsubscribe(cli):
)
msg = await ws.receive()
data = msg.json()
assert data['response'] == dict(subscribed=['server', 'test', 'foo'])
assert data['response'] == dict(
subscribed={'test': ['*'], 'foo': ['*']}
)
await ws.send_json(
dict(
id='xyz', method='unsubscribe',
Expand All @@ -150,7 +153,7 @@ async def test_rpc_unsubscribe(cli):
)
msg = await ws.receive()
data = msg.json()
assert data['response'] == dict(subscribed=['server', 'foo'])
assert data['response'] == dict(subscribed={'foo': ['*']})


async def test_rpc_pubsub(cli):
Expand Down

0 comments on commit edfcb9f

Please sign in to comment.