From ba2bcea2d075f6844b5ab3b895f2b42541a6e12a Mon Sep 17 00:00:00 2001 From: avivs Date: Fri, 11 Aug 2023 14:39:11 +0300 Subject: [PATCH] big refactor in modules added MethodValidator changed FastMessageOutput to CustomOutput --- VERSION | 2 +- docs/index.md | 36 ++-- fastmessage/__init__.py | 6 +- fastmessage/callable_wrapper.py | 179 +++++++++++------- fastmessage/common.py | 37 +++- fastmessage/exceptions.py | 4 + fastmessage/fastmessage_handler.py | 12 +- fastmessage/method_validator.py | 55 ++++++ tests/async_test.py | 54 ++++++ tests/common.py | 15 ++ ...essage_handler_test.py => handler_test.py} | 72 ++----- tests/method_validation_test.py | 84 ++++++++ 12 files changed, 409 insertions(+), 147 deletions(-) create mode 100644 fastmessage/method_validator.py create mode 100644 tests/async_test.py create mode 100644 tests/common.py rename tests/{fastmessage_handler_test.py => handler_test.py} (82%) create mode 100644 tests/method_validation_test.py diff --git a/VERSION b/VERSION index 6c6aa7c..341cf11 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.1.0 \ No newline at end of file +0.2.0 \ No newline at end of file diff --git a/docs/index.md b/docs/index.md index 965c9da..a187e66 100644 --- a/docs/index.md +++ b/docs/index.md @@ -98,19 +98,19 @@ def process_somemodel(__root__: SomeModel): ### Special Param Types -There are three special types which you can annotate the arguments for the callback with. +There are special types which you can annotate the arguments for the callback with. * ```InputDeviceName``` - arguments annotated with this type will receive the name of the input device the message came from. useful for registering the same callback for several input devices * ```Message``` - arguments annotated with this type will receive the raw message which came from the device * ```MessageBundle``` - arguments annotated with this type will receive the complete MessageBundle (with device headers) - +* ```MethodValidator``` - argument of this type, will receive an object that can help validate return values for other methods Notice that arguments annotated with these types MUST NOT have default values (Since they always have values). ```python -from pydantic import BaseModel -from fastmessage import FastMessage, InputDeviceName + +from fastmessage import FastMessage, InputDeviceName, MethodValidator from messageflux.iodevices.base import Message, MessageBundle fm = FastMessage() @@ -118,11 +118,21 @@ fm = FastMessage() @fm.map(input_device='some_queue') def do_something(i: InputDeviceName, m: Message, mb: MessageBundle, x: int): - # i will be 'some_queue' - # m will be the message that arrived - # mb will be the MessageBundle that arrived - # x will be the serialized value of the message - pass # do something + # i will be 'some_queue' + # m will be the message that arrived + # mb will be the MessageBundle that arrived + # x will be the serialized value of the message + pass # do something + + +@fm.map() +def func1(mv: MethodValidator): + yield mv.validate_and_return(func2, x=3, y="hello") # this will succeed + yield mv.validate_and_return(func2, x=4) # this will raise MethodValidationError because y param is required but missing + +@fm.map() +def func2(x:int, y:str): + pass ``` ### Returning Multiple Results @@ -157,18 +167,18 @@ def do_something_generator(x: int): # this method does the same as the previous You can make the function return a result to a different output device then the one in the decorator -You do this by using the ```FastMessageOutput``` class, and giving it the value to send, and the output device name to +You do this by using the ```CustomOutput``` class, and giving it the value to send, and the output device name to send to ```python -from fastmessage import FastMessage, FastMessageOutput +from fastmessage import FastMessage, CustomOutput fm = FastMessage() @fm.map(input_device='some_queue', output_device='default_output_device') def do_something(x: int): - return FastMessageOutput(value=1, - output_device='other_output_device') # this will send the value 1 to 'other_output_device' instead of the default + return CustomOutput(value=1, + output_device='other_output_device') # this will send the value 1 to 'other_output_device' instead of the default ``` diff --git a/fastmessage/__init__.py b/fastmessage/__init__.py index 8ec6469..7bbf23b 100644 --- a/fastmessage/__init__.py +++ b/fastmessage/__init__.py @@ -1,5 +1,5 @@ -from .callable_wrapper import ( - FastMessageOutput, +from .common import ( + CustomOutput, InputDeviceName, MultipleReturnValues, ) @@ -10,6 +10,8 @@ DuplicateCallbackException, UnnamedCallableException, NotAllowedParamKindException, + MethodValidationError ) from .fastmessage_handler import FastMessage +from .method_validator import MethodValidator diff --git a/fastmessage/callable_wrapper.py b/fastmessage/callable_wrapper.py index 6f7c1cf..c629433 100644 --- a/fastmessage/callable_wrapper.py +++ b/fastmessage/callable_wrapper.py @@ -2,15 +2,18 @@ import json from asyncio import AbstractEventLoop from dataclasses import dataclass -from typing import Optional, Dict, Any, Union, Iterable, Generator, AsyncGenerator, TYPE_CHECKING +from enum import Enum, auto +from typing import Optional, Dict, Any, Union, Iterable, Generator, AsyncGenerator, TYPE_CHECKING, Callable, Type import itertools -from pydantic import BaseModel, parse_raw_as, create_model, Extra +from pydantic import BaseModel, create_model, Extra from pydantic.config import get_config from pydantic.typing import get_all_type_hints -from fastmessage.common import _CALLABLE_TYPE, _get_callable_name, _logger +from fastmessage.common import CustomOutput, InputDeviceName, MultipleReturnValues +from fastmessage.common import _CALLABLE_TYPE, get_callable_name, _logger from fastmessage.exceptions import NotAllowedParamKindException, SpecialDefaultValueException +from fastmessage.method_validator import MethodValidator from messageflux import InputDevice from messageflux.iodevices.base.common import MessageBundle, Message from messageflux.pipeline_service import PipelineResult @@ -19,33 +22,24 @@ from fastmessage.fastmessage_handler import FastMessage +class _CallableType(Enum): + SYNC = auto() + ASYNC = auto() + ASYNC_GENERATOR = auto() + + @dataclass class _ParamInfo: annotation: Any default: Any -class InputDeviceName(str): - """ - a place holder class for input_device name - """ - pass - - -class MultipleReturnValues(list): - """ - a value that indicates that multiple output values should be returned - """ - pass - - @dataclass -class FastMessageOutput: - """ - a result that contains the output device name to send the value to - """ - output_device: str - value: Any +class _CallableAnalysis: + params: Dict[str, _ParamInfo] + special_params: Dict[str, _ParamInfo] + has_kwargs: bool + callable_type: _CallableType class CallableWrapper: @@ -55,27 +49,31 @@ class CallableWrapper: def __init__(self, *, fastmessage_handler: 'FastMessage', - callback: _CALLABLE_TYPE, - input_device: str, - output_device: Optional[str] = None): - self._callback = callback + wrapped_callable: _CALLABLE_TYPE, + input_device_name: str, + output_device_name: Optional[str] = None): self._fastmessage_handler = fastmessage_handler - self._input_device = input_device - self._output_device = output_device - self._special_params: Dict[str, _ParamInfo] = dict() - self._params: Dict[str, _ParamInfo] = dict() - self._is_async = inspect.iscoroutinefunction(callback) - self._is_async_gen = inspect.isasyncgenfunction(callback) - - type_hints = get_all_type_hints(self._callback) - extra = Extra.ignore - for param_name, param in inspect.signature(self._callback).parameters.items(): + self._callable = wrapped_callable + self._input_device_name = input_device_name + self._output_device_name = output_device_name + + self._callable_analysis = self._analyze_callable(self._callable) + self._model: Type[BaseModel] = self._create_model(model_name=self._get_model_name(), + callable_analysis=self._callable_analysis) + + @staticmethod + def _analyze_callable(wrapped_callable: _CALLABLE_TYPE) -> _CallableAnalysis: + params = dict() + special_params = dict() + type_hints = get_all_type_hints(wrapped_callable) + has_kwargs = False + for param_name, param in inspect.signature(wrapped_callable).parameters.items(): if param.kind in (param.POSITIONAL_ONLY, param.VAR_POSITIONAL): raise NotAllowedParamKindException( f"param '{param_name}' is of '{param.kind}' kind. this is now allowed") if param.kind == param.VAR_KEYWORD: # there's **kwargs param - extra = Extra.allow + has_kwargs = True continue annotation = Any if param.annotation is param.empty else type_hints[param_name] @@ -83,30 +81,73 @@ def __init__(self, *, param_info = _ParamInfo(annotation=annotation, default=default) - if param_info.annotation in (MessageBundle, Message, InputDeviceName, - Optional[MessageBundle], Optional[Message], Optional[InputDeviceName]): + if param_info.annotation in (MessageBundle, Optional[MessageBundle], + Message, Optional[Message], + InputDeviceName, Optional[InputDeviceName], + MethodValidator, Optional[MethodValidator]): if param_info.default is not ...: raise SpecialDefaultValueException( f"param '{param_name}' is of special type '{param.annotation.__name__}' " f"but has a default value") - self._special_params[param_name] = param_info + special_params[param_name] = param_info else: - self._params[param_name] = param_info - - self._model = None - if self._params: - model_name = self._get_model_name() - model_params: Dict[str, Any] = {} - for param_name, param_info in self._params.items(): - model_params[param_name] = (param_info.annotation, param_info.default) - self._model = create_model(model_name, - __config__=get_config(dict(extra=extra)), - **model_params) + params[param_name] = param_info + + callable_type = _CallableType.SYNC + if inspect.iscoroutinefunction(wrapped_callable): + callable_type = _CallableType.ASYNC + elif inspect.isasyncgenfunction(wrapped_callable): + callable_type = _CallableType.ASYNC_GENERATOR + + return _CallableAnalysis(params=params, + special_params=special_params, + has_kwargs=has_kwargs, + callable_type=callable_type) + + @staticmethod + def _create_model(model_name: str, callable_analysis: _CallableAnalysis) -> Type[BaseModel]: + model_params: Dict[str, Any] = {} + for param_name, param_info in callable_analysis.params.items(): + model_params[param_name] = (param_info.annotation, param_info.default) + extra = Extra.allow if callable_analysis.has_kwargs else Extra.ignore + model = create_model(model_name, + __config__=get_config(dict(extra=extra)), + **model_params) + + return model + + @property + def model(self) -> Type[BaseModel]: + """ + the model that was created for this callable (if it has input params) + """ + return self._model + + @property + def callable(self) -> Callable: + """ + the original callable that was passed to this wrapper + """ + return self._callable + + @property + def input_device_name(self) -> str: + """ + the input device name for this callable + """ + return self._input_device_name + + @property + def output_device_name(self) -> Optional[str]: + """ + the default output device name for this callable + """ + return self._output_device_name def _get_model_name(self) -> str: - callable_name = _get_callable_name(self._callback) - return f"model_{callable_name}_{self._input_device}" + callable_name = get_callable_name(self._callable) + return f"model_{callable_name}_{self._input_device_name}" @staticmethod def _iter_over_async(async_generator: AsyncGenerator, loop: AbstractEventLoop): @@ -128,29 +169,35 @@ async def get_next(): def __call__(self, input_device: InputDevice, message_bundle: MessageBundle) -> Optional[Union[PipelineResult, Iterable[PipelineResult]]]: + kwargs: Dict[str, Any] = {} - for param_name, param_info in self._special_params.items(): + for param_name, param_info in self._callable_analysis.special_params.items(): if param_info.annotation is InputDeviceName: kwargs[param_name] = input_device.name elif param_info.annotation is MessageBundle: kwargs[param_name] = message_bundle elif param_info.annotation is Message: kwargs[param_name] = message_bundle.message + elif param_info.annotation is MethodValidator: + kwargs[param_name] = MethodValidator(self._fastmessage_handler) + + model: BaseModel = self._model.parse_raw(message_bundle.message.bytes) + kwargs.update(dict(model)) + + if self._callable_analysis.callable_type == _CallableType.ASYNC: + callback_return = self._fastmessage_handler.event_loop.run_until_complete(self._callable(**kwargs)) + + elif self._callable_analysis.callable_type == _CallableType.ASYNC_GENERATOR: + callback_return = self._iter_over_async(self._callable(**kwargs), self._fastmessage_handler.event_loop) - if self._model: - model = parse_raw_as(self._model, message_bundle.message.bytes) - kwargs.update(dict(model)) - if self._is_async: - callback_return = self._fastmessage_handler.event_loop.run_until_complete(self._callback(**kwargs)) - elif self._is_async_gen: - callback_return = self._iter_over_async(self._callback(**kwargs), self._fastmessage_handler.event_loop) else: - callback_return = self._callback(**kwargs) + callback_return = self._callable(**kwargs) + if callback_return is None: return None return self._get_pipeline_results(value=callback_return, - default_output_device=self._output_device) + default_output_device=self._output_device_name) def _get_pipeline_results(self, value: Any, @@ -161,7 +208,7 @@ def _get_pipeline_results(self, default_output_device), value)) - elif isinstance(value, FastMessageOutput): + elif isinstance(value, CustomOutput): return self._get_pipeline_results(value=value.value, default_output_device=value.output_device) else: @@ -174,7 +221,7 @@ def _get_pipeline_results(self, def _get_single_pipeline_result(self, value: Any, output_device: Optional[str]) -> Optional[PipelineResult]: if output_device is None: - _logger.warning(f"callback for input device '{self._input_device}' returned value, " + _logger.warning(f"callback for input device '{self._input_device_name}' returned value, " f"but is not mapped to output device") return None diff --git a/fastmessage/common.py b/fastmessage/common.py index 6ef4ba7..a8c488f 100644 --- a/fastmessage/common.py +++ b/fastmessage/common.py @@ -1,4 +1,5 @@ import logging +from dataclasses import dataclass from typing import Callable, Any, TypeVar from fastmessage.exceptions import UnnamedCallableException @@ -8,8 +9,38 @@ _CALLABLE_TYPE = TypeVar('_CALLABLE_TYPE', bound=Callable[..., Any]) -def _get_callable_name(callback: _CALLABLE_TYPE) -> str: +def get_callable_name(named_callable: _CALLABLE_TYPE) -> str: + """ + tries to return a callable name + + :param named_callable: the callable to get the name for + + :return: the name of the callable, or raises UnnamedCallableException + """ try: - return getattr(callback, "__name__") + return getattr(named_callable, "__name__") except AttributeError as ex: - raise UnnamedCallableException(f"Callable {repr(callback)} doesn't have a name") from ex + raise UnnamedCallableException(f"Callable {repr(named_callable)} doesn't have a name") from ex + + +class InputDeviceName(str): + """ + a place holder class for input_device name + """ + pass + + +class MultipleReturnValues(list): + """ + a value that indicates that multiple output values should be returned + """ + pass + + +@dataclass +class CustomOutput: + """ + a result that contains the output device name to send the value to + """ + output_device: str + value: Any diff --git a/fastmessage/exceptions.py b/fastmessage/exceptions.py index 7f7f85f..4b4c096 100644 --- a/fastmessage/exceptions.py +++ b/fastmessage/exceptions.py @@ -2,6 +2,10 @@ class FastMessageException(Exception): pass +class MethodValidationError(Exception): + pass + + class DuplicateCallbackException(FastMessageException): pass diff --git a/fastmessage/fastmessage_handler.py b/fastmessage/fastmessage_handler.py index 8dc7c4c..bf31115 100644 --- a/fastmessage/fastmessage_handler.py +++ b/fastmessage/fastmessage_handler.py @@ -5,7 +5,7 @@ from pydantic import ValidationError from fastmessage.callable_wrapper import CallableWrapper -from fastmessage.common import _CALLABLE_TYPE, _get_callable_name +from fastmessage.common import _CALLABLE_TYPE, get_callable_name from fastmessage.exceptions import DuplicateCallbackException, MissingCallbackException from messageflux import InputDevice from messageflux.iodevices.base import InputDeviceManager, OutputDeviceManager @@ -36,6 +36,7 @@ def __init__(self, self._default_output_device = default_output_device self._validation_error_handler = validation_error_handler self._wrappers: Dict[str, CallableWrapper] = {} + self._callable_to_input_device: Dict[Callable, str] = {} self._event_loop_cache: Optional[AbstractEventLoop] = None @property @@ -80,7 +81,7 @@ def register_callback(self, if callback returns None, no routing will be made even if 'output_device' is not None """ if input_device is _DEFAULT: - input_device = _get_callable_name(callback) + input_device = get_callable_name(callback) if input_device in self._wrappers: raise DuplicateCallbackException(f"Can't register more than one callback on device '{input_device}'") @@ -88,10 +89,11 @@ def register_callback(self, if output_device is _DEFAULT: output_device = self._default_output_device + self._callable_to_input_device[callback] = input_device self._wrappers[input_device] = CallableWrapper(fastmessage_handler=self, - callback=callback, - input_device=input_device, - output_device=output_device) + wrapped_callable=callback, + input_device_name=input_device, + output_device_name=output_device) def map(self, input_device: str = _DEFAULT, diff --git a/fastmessage/method_validator.py b/fastmessage/method_validator.py new file mode 100644 index 0000000..aa80649 --- /dev/null +++ b/fastmessage/method_validator.py @@ -0,0 +1,55 @@ +from typing import Union, Callable, TYPE_CHECKING, Type + +from pydantic import BaseModel, ValidationError + +from fastmessage.common import CustomOutput +from fastmessage.exceptions import MissingCallbackException, MethodValidationError + +if TYPE_CHECKING: + from fastmessage.fastmessage_handler import FastMessage + from fastmessage.callable_wrapper import CallableWrapper + + +class MethodValidator: + """ + a class used to validate results for sending to another method BEFORE sending them + """ + + def __init__(self, fastmessage_handler: 'FastMessage'): + self._fastmessage_handler = fastmessage_handler + + def _get_callable_wrapper(self, method: Union[str, Callable]) -> 'CallableWrapper': + input_device = method + try: + if callable(method): + input_device = self._fastmessage_handler._callable_to_input_device[method] + + assert isinstance(input_device, str) + return self._fastmessage_handler._wrappers[input_device] + except KeyError: + raise MissingCallbackException(f'callback {input_device} is not registered') + + def validate_and_return(self, method: Union[str, Callable], **kwargs) -> CustomOutput: + """ + validates the arguments for the method, and returns it with the right output device + + :param method: the method or input device name to send the arguments to + :param kwargs: the arguments to the method + + :return: a CustomOutput object, with the right details + """ + callable_wrapper = self._get_callable_wrapper(method) + try: + return CustomOutput(output_device=callable_wrapper.input_device_name, + value=callable_wrapper.model(**kwargs)) + except ValidationError as ex: + raise MethodValidationError(str(ex)) from ex + + def get_model(self, method: Union[str, Callable]) -> Type[BaseModel]: + """ + return the input model for the method + + :param method: the method (or input device name) to get the model to + :return: the BaseModel type for that method + """ + return self._get_callable_wrapper(method).model diff --git a/tests/async_test.py b/tests/async_test.py new file mode 100644 index 0000000..de56c84 --- /dev/null +++ b/tests/async_test.py @@ -0,0 +1,54 @@ +import json +import uuid +from typing import List + +from fastmessage import FastMessage, InputDeviceName +from messageflux.iodevices.base.common import MessageBundle, Message +from tests.common import FakeInputDevice + + +def test_sanity_async(): + default_output_device = str(uuid.uuid4()).replace('-', '') + fm: FastMessage = FastMessage(default_output_device=default_output_device) + + @fm.map() + async def do_something1(x: int, y: str, z: List[int] = None): + return dict(y=f'x={x}, y={y}, z={z}') + + result = fm.handle_message(FakeInputDevice('do_something1'), + MessageBundle(Message(b'{"x": 1, "y": "a", "F":3}'))) + assert result is not None + result = result[0] + assert result.output_device_name == default_output_device + json_result = json.loads(result.message_bundle.message.bytes.decode()) + assert json_result['y'] == 'x=1, y=a, z=None' + + result = fm.handle_message(FakeInputDevice('do_something1'), + MessageBundle(Message(b'{"x": 1, "y": "a", "z":[1,2]}'))) + assert result is not None + result = result[0] + assert result.output_device_name == default_output_device + json_result = json.loads(result.message_bundle.message.bytes.decode()) + assert json_result['y'] == 'x=1, y=a, z=[1, 2]' + + +def test_return_async_generator(): + default_output_device = str(uuid.uuid4()).replace('-', '') + fm: FastMessage = FastMessage(default_output_device=default_output_device) + + @fm.map(input_device='input1') + async def do_something1(m: Message, b: MessageBundle, d: InputDeviceName, y: int): + yield 1 + yield 2 + yield 3 + + result = fm.handle_message(FakeInputDevice('input1'), + MessageBundle(message=Message(data=b'{"y": 10}', + headers={'test': 'mtest'}), + device_headers={'test': 'btest'})) + assert result is not None + result = list(result) + assert len(result) == 3 + assert result[0].message_bundle.message.bytes == b'1' + assert result[1].message_bundle.message.bytes == b'2' + assert result[2].message_bundle.message.bytes == b'3' diff --git a/tests/common.py b/tests/common.py new file mode 100644 index 0000000..32918cc --- /dev/null +++ b/tests/common.py @@ -0,0 +1,15 @@ +import threading +from typing import Optional + +from messageflux import InputDevice, ReadResult + + +class FakeInputDevice(InputDevice): + def _read_message(self, + cancellation_token: threading.Event, + timeout: Optional[float] = None, + with_transaction: bool = True) -> Optional['ReadResult']: + return None + + def __init__(self, name: str): + super().__init__(None, name) diff --git a/tests/fastmessage_handler_test.py b/tests/handler_test.py similarity index 82% rename from tests/fastmessage_handler_test.py rename to tests/handler_test.py index ab7d007..11d6258 100644 --- a/tests/fastmessage_handler_test.py +++ b/tests/handler_test.py @@ -1,28 +1,16 @@ import json -import threading import uuid -from typing import Optional, List +from typing import List import pytest from pydantic import BaseModel, ValidationError from fastmessage import FastMessage, MissingCallbackException, DuplicateCallbackException, \ - InputDeviceName, SpecialDefaultValueException, NotAllowedParamKindException, MultipleReturnValues, FastMessageOutput -from messageflux import ReadResult + InputDeviceName, SpecialDefaultValueException, NotAllowedParamKindException, MultipleReturnValues, CustomOutput from messageflux.iodevices.base import InputDevice from messageflux.iodevices.base.common import MessageBundle, Message from messageflux.pipeline_service import PipelineResult - - -class FakeInputDevice(InputDevice): - def _read_message(self, - cancellation_token: threading.Event, - timeout: Optional[float] = None, - with_transaction: bool = True) -> Optional['ReadResult']: - return None - - def __init__(self, name: str): - super().__init__(None, name) +from tests.common import FakeInputDevice class SomeModel(BaseModel): @@ -58,29 +46,21 @@ def do_something1(x: SomeModel, y: str, z: List[int] = None): assert json_result['y'] == 'x=1, y=a, z=[1, 2]' -def test_sanity_async(): +def test_empty_method(): default_output_device = str(uuid.uuid4()).replace('-', '') fm: FastMessage = FastMessage(default_output_device=default_output_device) @fm.map() - async def do_something1(x: SomeModel, y: str, z: List[int] = None): - return SomeOtherModel(y=f'x={x.x}, y={y}, z={z}') - - result = fm.handle_message(FakeInputDevice('do_something1'), - MessageBundle(Message(b'{"x": {"x":1}, "y": "a", "F":3}'))) - assert result is not None - result = result[0] - assert result.output_device_name == default_output_device - json_result = json.loads(result.message_bundle.message.bytes.decode()) - assert json_result['y'] == 'x=1, y=a, z=None' + def do_something(): + return SomeOtherModel(y='x=1, y=2, z=3') - result = fm.handle_message(FakeInputDevice('do_something1'), - MessageBundle(Message(b'{"x": {"x":1}, "y": "a", "z":[1,2]}'))) + result = fm.handle_message(FakeInputDevice('do_something'), + MessageBundle(Message(b'{}'))) assert result is not None result = result[0] assert result.output_device_name == default_output_device json_result = json.loads(result.message_bundle.message.bytes.decode()) - assert json_result['y'] == 'x=1, y=a, z=[1, 2]' + assert json_result['y'] == 'x=1, y=2, z=3' def test_root_model(): @@ -247,9 +227,9 @@ def test_custom_output_device_result(): @fm.map(input_device='input1') def do_something1(m: Message, b: MessageBundle, d: InputDeviceName, y: int): return MultipleReturnValues([ - FastMessageOutput(value=1, output_device='test1'), - FastMessageOutput(value=2, output_device='test2'), - FastMessageOutput(value=3, output_device='test3'), + CustomOutput(value=1, output_device='test1'), + CustomOutput(value=2, output_device='test2'), + CustomOutput(value=3, output_device='test3'), 4 ]) @@ -276,9 +256,9 @@ def test_no_output_device(): @fm.map(input_device='input1') def do_something1(m: Message, b: MessageBundle, d: InputDeviceName, y: int): return MultipleReturnValues([ - FastMessageOutput(value=1, output_device='test1'), - FastMessageOutput(value=2, output_device='test2'), - FastMessageOutput(value=3, output_device='test3'), + CustomOutput(value=1, output_device='test1'), + CustomOutput(value=2, output_device='test2'), + CustomOutput(value=3, output_device='test3'), 4 ]) @@ -319,28 +299,6 @@ def do_something1(m: Message, b: MessageBundle, d: InputDeviceName, y: int): assert result[2].message_bundle.message.bytes == b'3' -def test_return_async_generator(): - default_output_device = str(uuid.uuid4()).replace('-', '') - fm: FastMessage = FastMessage(default_output_device=default_output_device) - - @fm.map(input_device='input1') - async def do_something1(m: Message, b: MessageBundle, d: InputDeviceName, y: int): - yield 1 - yield 2 - yield 3 - - result = fm.handle_message(FakeInputDevice('input1'), - MessageBundle(message=Message(data=b'{"y": 10}', - headers={'test': 'mtest'}), - device_headers={'test': 'btest'})) - assert result is not None - result = list(result) - assert len(result) == 3 - assert result[0].message_bundle.message.bytes == b'1' - assert result[1].message_bundle.message.bytes == b'2' - assert result[2].message_bundle.message.bytes == b'3' - - def test_return_complex(): default_output_device = str(uuid.uuid4()).replace('-', '') fm: FastMessage = FastMessage(default_output_device=default_output_device) diff --git a/tests/method_validation_test.py b/tests/method_validation_test.py new file mode 100644 index 0000000..30904bc --- /dev/null +++ b/tests/method_validation_test.py @@ -0,0 +1,84 @@ +import pytest + +from fastmessage import FastMessage, MissingCallbackException +from fastmessage.exceptions import MethodValidationError +from fastmessage.method_validator import MethodValidator +from messageflux.iodevices.base.common import MessageBundle, Message +from tests.common import FakeInputDevice + + +def test_by_method(): + fm: FastMessage = FastMessage() + + @fm.map() + def func_input(method_validator: MethodValidator): + return method_validator.validate_and_return(func_output, x=3, y="hello") + + @fm.map(input_device='func2_device', output_device='output') + def func_output(x: int, y: str): + return f"Success: x={x}, y={y}" + + result = fm.handle_message(FakeInputDevice('func_input'), MessageBundle(message=Message(data=b'{"y": 10}'))) + assert result is not None + result = list(result) + assert len(result) == 1 + assert result[0].output_device_name == "func2_device" + + result = fm.handle_message(FakeInputDevice(result[0].output_device_name), result[0].message_bundle) + assert result is not None + result = list(result) + assert len(result) == 1 + assert result[0].message_bundle.message.bytes == b'"Success: x=3, y=hello"' + + +def test_by_input_device_name(): + fm: FastMessage = FastMessage() + + @fm.map() + def func_input(method_validator: MethodValidator): + return method_validator.validate_and_return("func2_device", x=3, y="hello") + + @fm.map(input_device='func2_device', output_device='output') + def func_output(x: int, y: str): + return f"Success: x={x}, y={y}" + + result = fm.handle_message(FakeInputDevice('func_input'), MessageBundle(message=Message(data=b'{"y": 10}'))) + assert result is not None + result = list(result) + assert len(result) == 1 + assert result[0].output_device_name == "func2_device" + + result = fm.handle_message(FakeInputDevice(result[0].output_device_name), result[0].message_bundle) + assert result is not None + result = list(result) + assert len(result) == 1 + assert result[0].message_bundle.message.bytes == b'"Success: x=3, y=hello"' + + +def test_validation_error(): + fm: FastMessage = FastMessage() + + @fm.map() + def func_input(method_validator: MethodValidator): + return method_validator.validate_and_return("func2_device", x=3) + + @fm.map(input_device='func2_device', output_device='output') + def func_output(x: int, y: str): + return f"Success: x={x}, y={y}" + + with pytest.raises(MethodValidationError): + _ = fm.handle_message(FakeInputDevice('func_input'), MessageBundle(message=Message(data=b'{"y": 10}'))) + + +def test_missing_callback(): + fm: FastMessage = FastMessage() + + @fm.map() + def func_input(method_validator: MethodValidator): + return method_validator.validate_and_return(func_output, x=3) + + def func_output(x: int, y: str): + return f"Success: x={x}, y={y}" + + with pytest.raises(MissingCallbackException): + _ = fm.handle_message(FakeInputDevice('func_input'), MessageBundle(message=Message(data=b'{"y": 10}')))