Skip to content

Commit

Permalink
Merge pull request #9 from Avivsalem/staging
Browse files Browse the repository at this point in the history
Staging
  • Loading branch information
Avivsalem authored Aug 9, 2023
2 parents a564bc3 + cd9f7a4 commit a4fa326
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 15 deletions.
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.0.0
0.1.0
63 changes: 49 additions & 14 deletions fastmessage/fastmessage_handler.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import asyncio
import inspect
import json
import logging
from asyncio import AbstractEventLoop
from dataclasses import dataclass
from typing import Optional, Callable, Dict, List, Any, TypeVar, Union, Iterable, Generator
from typing import Optional, Callable, Dict, List, Any, Union, Iterable, Generator, AsyncGenerator, TypeVar

import itertools
from pydantic import BaseModel, parse_raw_as, create_model, ValidationError, Extra
Expand Down Expand Up @@ -57,15 +59,26 @@ class _ParamInfo:
_logger = logging.getLogger(__name__)


def _get_callable_name(callback: _CALLABLE_TYPE) -> str:
try:
return getattr(callback, "__name__")
except AttributeError as ex:
raise UnnamedCallableException(f"Callable {repr(callback)} doesn't have a name") from ex


class _CallbackWrapper:
def __init__(self, callback: Callable,
def __init__(self, callback: _CALLABLE_TYPE,
input_device: str,
output_device: Optional[str] = None):
self._callback = callback
self._input_device = input_device
self._output_device = output_device
self._special_params: Dict[str, _ParamInfo] = dict()
self._params: Dict[str, _ParamInfo] = dict()
self._event_loop_cache: Optional[AbstractEventLoop] = None
self._is_async = inspect.iscoroutinefunction(callback)
self._is_async_gen = inspect.isasyncgenfunction(callback)

type_hints = get_all_type_hints(self._callback)
extra = Extra.ignore
for param_name, param in inspect.signature(self._callback).parameters.items():
Expand Down Expand Up @@ -103,8 +116,32 @@ def __init__(self, callback: Callable,
__config__=get_config(dict(extra=extra)),
**model_params)

def _event_loop(self) -> AbstractEventLoop:
if self._event_loop_cache is None:
self._event_loop_cache = asyncio.new_event_loop() # TODO: when to we close the loop?

return self._event_loop_cache

def _get_model_name(self) -> str:
return f"model_{self._callback.__name__}_{self._input_device}"
callable_name = _get_callable_name(self._callback)
return f"model_{callable_name}_{self._input_device}"

@staticmethod
def _iter_over_async(async_generator: AsyncGenerator, loop: AbstractEventLoop):
ait = async_generator.__aiter__()

async def get_next():
try:
obj = await ait.__anext__()
return False, obj
except StopAsyncIteration:
return True, None

while True:
done, obj = loop.run_until_complete(get_next())
if done:
break
yield obj

def __call__(self,
input_device: InputDevice,
Expand All @@ -121,8 +158,12 @@ def __call__(self,
if self._model:
model = parse_raw_as(self._model, message_bundle.message.bytes)
kwargs.update(dict(model))

callback_return = self._callback(**kwargs)
if self._is_async:
callback_return = self._event_loop().run_until_complete(self._callback(**kwargs))
elif self._is_async_gen:
callback_return = self._iter_over_async(self._callback(**kwargs), self._event_loop())
else:
callback_return = self._callback(**kwargs)
if callback_return is None:
return None

Expand Down Expand Up @@ -175,7 +216,7 @@ def __init__(self,
Optional[Union[PipelineResult, List[PipelineResult]]]]] = None):
"""
:param default_output_device: an optional default output device to send callaback results to,
:param default_output_device: an optional default output device to send callback results to,
unless mapped otherwise
:param validation_error_handler: an optional handler that will be called on validation errors,
in order to give the user a chance to handle them gracefully
Expand All @@ -202,14 +243,8 @@ def register_validation_error_handler(self,
"""
self._validation_error_handler = handler

def _get_callable_name(self, callback: Callable) -> str:
try:
return getattr(callback, "__name__")
except AttributeError as ex:
raise UnnamedCallableException(f"Callable {repr(callback)} doesn't have a name") from ex

def register_callback(self,
callback: Callable,
callback: _CALLABLE_TYPE,
input_device: str = _DEFAULT,
output_device: Optional[str] = _DEFAULT):
"""
Expand All @@ -222,7 +257,7 @@ def register_callback(self,
if callback returns None, no routing will be made even if 'output_device' is not None
"""
if input_device is _DEFAULT:
input_device = self._get_callable_name(callback)
input_device = _get_callable_name(callback)

if input_device in self._wrappers:
raise DuplicateCallbackException(f"Can't register more than one callback on device '{input_device}'")
Expand Down
47 changes: 47 additions & 0 deletions tests/fastmessage_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,31 @@ def do_something1(x: SomeModel, y: str, z: List[int] = None):
assert json_result['y'] == 'x=1, y=a, z=[1, 2]'


def test_sanity_async():
default_output_device = str(uuid.uuid4()).replace('-', '')
fm: FastMessage = FastMessage(default_output_device=default_output_device)

@fm.map()
async def do_something1(x: SomeModel, y: str, z: List[int] = None):
return SomeOtherModel(y=f'x={x.x}, y={y}, z={z}')

result = fm.handle_message(FakeInputDevice('do_something1'),
MessageBundle(Message(b'{"x": {"x":1}, "y": "a", "F":3}')))
assert result is not None
result = result[0]
assert result.output_device_name == default_output_device
json_result = json.loads(result.message_bundle.message.bytes.decode())
assert json_result['y'] == 'x=1, y=a, z=None'

result = fm.handle_message(FakeInputDevice('do_something1'),
MessageBundle(Message(b'{"x": {"x":1}, "y": "a", "z":[1,2]}')))
assert result is not None
result = result[0]
assert result.output_device_name == default_output_device
json_result = json.loads(result.message_bundle.message.bytes.decode())
assert json_result['y'] == 'x=1, y=a, z=[1, 2]'


def test_root_model():
fm: FastMessage = FastMessage()

Expand Down Expand Up @@ -294,6 +319,28 @@ def do_something1(m: Message, b: MessageBundle, d: InputDeviceName, y: int):
assert result[2].message_bundle.message.bytes == b'3'


def test_return_async_generator():
default_output_device = str(uuid.uuid4()).replace('-', '')
fm: FastMessage = FastMessage(default_output_device=default_output_device)

@fm.map(input_device='input1')
async def do_something1(m: Message, b: MessageBundle, d: InputDeviceName, y: int):
yield 1
yield 2
yield 3

result = fm.handle_message(FakeInputDevice('input1'),
MessageBundle(message=Message(data=b'{"y": 10}',
headers={'test': 'mtest'}),
device_headers={'test': 'btest'}))
assert result is not None
result = list(result)
assert len(result) == 3
assert result[0].message_bundle.message.bytes == b'1'
assert result[1].message_bundle.message.bytes == b'2'
assert result[2].message_bundle.message.bytes == b'3'


def test_return_complex():
default_output_device = str(uuid.uuid4()).replace('-', '')
fm: FastMessage = FastMessage(default_output_device=default_output_device)
Expand Down

0 comments on commit a4fa326

Please sign in to comment.