Skip to content

Commit

Permalink
Merge pull request #13 from Avivsalem/staging
Browse files Browse the repository at this point in the history
Staging
  • Loading branch information
Avivsalem authored Aug 11, 2023
2 parents cc8f7a2 + 0ade6c0 commit fedf3e7
Show file tree
Hide file tree
Showing 12 changed files with 409 additions and 147 deletions.
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.1.0
0.2.0
36 changes: 23 additions & 13 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,31 +98,41 @@ def process_somemodel(__root__: SomeModel):

### Special Param Types

There are three special types which you can annotate the arguments for the callback with.
There are special types which you can annotate the arguments for the callback with.

* ```InputDeviceName``` - arguments annotated with this type will receive the name of the input device the message came
from. useful for registering the same callback for several input devices
* ```Message``` - arguments annotated with this type will receive the raw message which came from the device
* ```MessageBundle``` - arguments annotated with this type will receive the complete MessageBundle (with device headers)

* ```MethodValidator``` - argument of this type, will receive an object that can help validate return values for other methods
Notice that arguments annotated with these types MUST NOT have default values (Since they always have values).

```python
from pydantic import BaseModel

from fastmessage import FastMessage, InputDeviceName

from fastmessage import FastMessage, InputDeviceName, MethodValidator
from messageflux.iodevices.base import Message, MessageBundle

fm = FastMessage()


@fm.map(input_device='some_queue')
def do_something(i: InputDeviceName, m: Message, mb: MessageBundle, x: int):
# i will be 'some_queue'
# m will be the message that arrived
# mb will be the MessageBundle that arrived
# x will be the serialized value of the message
pass # do something
# i will be 'some_queue'
# m will be the message that arrived
# mb will be the MessageBundle that arrived
# x will be the serialized value of the message
pass # do something


@fm.map()
def func1(mv: MethodValidator):
yield mv.validate_and_return(func2, x=3, y="hello") # this will succeed
yield mv.validate_and_return(func2, x=4) # this will raise MethodValidationError because y param is required but missing

@fm.map()
def func2(x:int, y:str):
pass
```

### Returning Multiple Results
Expand Down Expand Up @@ -157,18 +167,18 @@ def do_something_generator(x: int): # this method does the same as the previous

You can make the function return a result to a different output device then the one in the decorator

You do this by using the ```FastMessageOutput``` class, and giving it the value to send, and the output device name to
You do this by using the ```CustomOutput``` class, and giving it the value to send, and the output device name to
send to

```python
from fastmessage import FastMessage, FastMessageOutput
from fastmessage import FastMessage, CustomOutput

fm = FastMessage()


@fm.map(input_device='some_queue', output_device='default_output_device')
def do_something(x: int):
return FastMessageOutput(value=1,
output_device='other_output_device') # this will send the value 1 to 'other_output_device' instead of the default
return CustomOutput(value=1,
output_device='other_output_device') # this will send the value 1 to 'other_output_device' instead of the default
```

6 changes: 4 additions & 2 deletions fastmessage/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .callable_wrapper import (
FastMessageOutput,
from .common import (
CustomOutput,
InputDeviceName,
MultipleReturnValues,
)
Expand All @@ -10,6 +10,8 @@
DuplicateCallbackException,
UnnamedCallableException,
NotAllowedParamKindException,
MethodValidationError

)
from .fastmessage_handler import FastMessage
from .method_validator import MethodValidator
179 changes: 113 additions & 66 deletions fastmessage/callable_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,18 @@
import json
from asyncio import AbstractEventLoop
from dataclasses import dataclass
from typing import Optional, Dict, Any, Union, Iterable, Generator, AsyncGenerator, TYPE_CHECKING
from enum import Enum, auto
from typing import Optional, Dict, Any, Union, Iterable, Generator, AsyncGenerator, TYPE_CHECKING, Callable, Type

import itertools
from pydantic import BaseModel, parse_raw_as, create_model, Extra
from pydantic import BaseModel, create_model, Extra
from pydantic.config import get_config
from pydantic.typing import get_all_type_hints

from fastmessage.common import _CALLABLE_TYPE, _get_callable_name, _logger
from fastmessage.common import CustomOutput, InputDeviceName, MultipleReturnValues
from fastmessage.common import _CALLABLE_TYPE, get_callable_name, _logger
from fastmessage.exceptions import NotAllowedParamKindException, SpecialDefaultValueException
from fastmessage.method_validator import MethodValidator
from messageflux import InputDevice
from messageflux.iodevices.base.common import MessageBundle, Message
from messageflux.pipeline_service import PipelineResult
Expand All @@ -19,33 +22,24 @@
from fastmessage.fastmessage_handler import FastMessage


class _CallableType(Enum):
SYNC = auto()
ASYNC = auto()
ASYNC_GENERATOR = auto()


@dataclass
class _ParamInfo:
annotation: Any
default: Any


class InputDeviceName(str):
"""
a place holder class for input_device name
"""
pass


class MultipleReturnValues(list):
"""
a value that indicates that multiple output values should be returned
"""
pass


@dataclass
class FastMessageOutput:
"""
a result that contains the output device name to send the value to
"""
output_device: str
value: Any
class _CallableAnalysis:
params: Dict[str, _ParamInfo]
special_params: Dict[str, _ParamInfo]
has_kwargs: bool
callable_type: _CallableType


class CallableWrapper:
Expand All @@ -55,58 +49,105 @@ class CallableWrapper:

def __init__(self, *,
fastmessage_handler: 'FastMessage',
callback: _CALLABLE_TYPE,
input_device: str,
output_device: Optional[str] = None):
self._callback = callback
wrapped_callable: _CALLABLE_TYPE,
input_device_name: str,
output_device_name: Optional[str] = None):
self._fastmessage_handler = fastmessage_handler
self._input_device = input_device
self._output_device = output_device
self._special_params: Dict[str, _ParamInfo] = dict()
self._params: Dict[str, _ParamInfo] = dict()
self._is_async = inspect.iscoroutinefunction(callback)
self._is_async_gen = inspect.isasyncgenfunction(callback)

type_hints = get_all_type_hints(self._callback)
extra = Extra.ignore
for param_name, param in inspect.signature(self._callback).parameters.items():
self._callable = wrapped_callable
self._input_device_name = input_device_name
self._output_device_name = output_device_name

self._callable_analysis = self._analyze_callable(self._callable)
self._model: Type[BaseModel] = self._create_model(model_name=self._get_model_name(),
callable_analysis=self._callable_analysis)

@staticmethod
def _analyze_callable(wrapped_callable: _CALLABLE_TYPE) -> _CallableAnalysis:
params = dict()
special_params = dict()
type_hints = get_all_type_hints(wrapped_callable)
has_kwargs = False
for param_name, param in inspect.signature(wrapped_callable).parameters.items():
if param.kind in (param.POSITIONAL_ONLY, param.VAR_POSITIONAL):
raise NotAllowedParamKindException(
f"param '{param_name}' is of '{param.kind}' kind. this is now allowed")

if param.kind == param.VAR_KEYWORD: # there's **kwargs param
extra = Extra.allow
has_kwargs = True
continue

annotation = Any if param.annotation is param.empty else type_hints[param_name]
default = ... if param.default is param.empty else param.default

param_info = _ParamInfo(annotation=annotation, default=default)

if param_info.annotation in (MessageBundle, Message, InputDeviceName,
Optional[MessageBundle], Optional[Message], Optional[InputDeviceName]):
if param_info.annotation in (MessageBundle, Optional[MessageBundle],
Message, Optional[Message],
InputDeviceName, Optional[InputDeviceName],
MethodValidator, Optional[MethodValidator]):
if param_info.default is not ...:
raise SpecialDefaultValueException(
f"param '{param_name}' is of special type '{param.annotation.__name__}' "
f"but has a default value")
self._special_params[param_name] = param_info
special_params[param_name] = param_info

else:
self._params[param_name] = param_info

self._model = None
if self._params:
model_name = self._get_model_name()
model_params: Dict[str, Any] = {}
for param_name, param_info in self._params.items():
model_params[param_name] = (param_info.annotation, param_info.default)
self._model = create_model(model_name,
__config__=get_config(dict(extra=extra)),
**model_params)
params[param_name] = param_info

callable_type = _CallableType.SYNC
if inspect.iscoroutinefunction(wrapped_callable):
callable_type = _CallableType.ASYNC
elif inspect.isasyncgenfunction(wrapped_callable):
callable_type = _CallableType.ASYNC_GENERATOR

return _CallableAnalysis(params=params,
special_params=special_params,
has_kwargs=has_kwargs,
callable_type=callable_type)

@staticmethod
def _create_model(model_name: str, callable_analysis: _CallableAnalysis) -> Type[BaseModel]:
model_params: Dict[str, Any] = {}
for param_name, param_info in callable_analysis.params.items():
model_params[param_name] = (param_info.annotation, param_info.default)
extra = Extra.allow if callable_analysis.has_kwargs else Extra.ignore
model = create_model(model_name,
__config__=get_config(dict(extra=extra)),
**model_params)

return model

@property
def model(self) -> Type[BaseModel]:
"""
the model that was created for this callable (if it has input params)
"""
return self._model

@property
def callable(self) -> Callable:
"""
the original callable that was passed to this wrapper
"""
return self._callable

@property
def input_device_name(self) -> str:
"""
the input device name for this callable
"""
return self._input_device_name

@property
def output_device_name(self) -> Optional[str]:
"""
the default output device name for this callable
"""
return self._output_device_name

def _get_model_name(self) -> str:
callable_name = _get_callable_name(self._callback)
return f"model_{callable_name}_{self._input_device}"
callable_name = get_callable_name(self._callable)
return f"model_{callable_name}_{self._input_device_name}"

@staticmethod
def _iter_over_async(async_generator: AsyncGenerator, loop: AbstractEventLoop):
Expand All @@ -128,29 +169,35 @@ async def get_next():
def __call__(self,
input_device: InputDevice,
message_bundle: MessageBundle) -> Optional[Union[PipelineResult, Iterable[PipelineResult]]]:

kwargs: Dict[str, Any] = {}
for param_name, param_info in self._special_params.items():
for param_name, param_info in self._callable_analysis.special_params.items():
if param_info.annotation is InputDeviceName:
kwargs[param_name] = input_device.name
elif param_info.annotation is MessageBundle:
kwargs[param_name] = message_bundle
elif param_info.annotation is Message:
kwargs[param_name] = message_bundle.message
elif param_info.annotation is MethodValidator:
kwargs[param_name] = MethodValidator(self._fastmessage_handler)

model: BaseModel = self._model.parse_raw(message_bundle.message.bytes)
kwargs.update(dict(model))

if self._callable_analysis.callable_type == _CallableType.ASYNC:
callback_return = self._fastmessage_handler.event_loop.run_until_complete(self._callable(**kwargs))

elif self._callable_analysis.callable_type == _CallableType.ASYNC_GENERATOR:
callback_return = self._iter_over_async(self._callable(**kwargs), self._fastmessage_handler.event_loop)

if self._model:
model = parse_raw_as(self._model, message_bundle.message.bytes)
kwargs.update(dict(model))
if self._is_async:
callback_return = self._fastmessage_handler.event_loop.run_until_complete(self._callback(**kwargs))
elif self._is_async_gen:
callback_return = self._iter_over_async(self._callback(**kwargs), self._fastmessage_handler.event_loop)
else:
callback_return = self._callback(**kwargs)
callback_return = self._callable(**kwargs)

if callback_return is None:
return None

return self._get_pipeline_results(value=callback_return,
default_output_device=self._output_device)
default_output_device=self._output_device_name)

def _get_pipeline_results(self,
value: Any,
Expand All @@ -161,7 +208,7 @@ def _get_pipeline_results(self,
default_output_device),
value))

elif isinstance(value, FastMessageOutput):
elif isinstance(value, CustomOutput):
return self._get_pipeline_results(value=value.value,
default_output_device=value.output_device)
else:
Expand All @@ -174,7 +221,7 @@ def _get_pipeline_results(self,

def _get_single_pipeline_result(self, value: Any, output_device: Optional[str]) -> Optional[PipelineResult]:
if output_device is None:
_logger.warning(f"callback for input device '{self._input_device}' returned value, "
_logger.warning(f"callback for input device '{self._input_device_name}' returned value, "
f"but is not mapped to output device")
return None

Expand Down
Loading

0 comments on commit fedf3e7

Please sign in to comment.