diff --git a/src/dubbo/classes.py b/src/dubbo/classes.py index 8d87299..d6ba2ab 100644 --- a/src/dubbo/classes.py +++ b/src/dubbo/classes.py @@ -13,10 +13,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import abc import threading -from typing import Any, Callable, Optional, Union - +from typing import Any, Callable, Optional, Union,Type +from abc import ABC, abstractmethod from dubbo.types import DeserializingFunction, RpcType, RpcTypes, SerializingFunction __all__ = [ @@ -244,3 +245,21 @@ class ReadWriteStream(ReadStream, WriteStream, abc.ABC): """ pass + + +class Codec(ABC): + def __init__(self, model_type: Optional[Type[Any]] = None, **kwargs): + self.model_type = model_type + + @abstractmethod + def encode(self, data: Any) -> bytes: + pass + + @abstractmethod + def decode(self, data: bytes) -> Any: + pass + +class CodecHelper: + @staticmethod + def get_class(): + return Codec diff --git a/src/dubbo/client.py b/src/dubbo/client.py index 33e6264..7dc215a 100644 --- a/src/dubbo/client.py +++ b/src/dubbo/client.py @@ -13,8 +13,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ import threading -from typing import Optional +from typing import Optional, Callable, List, Type, Union, Any from dubbo.bootstrap import Dubbo from dubbo.classes import MethodDescriptor @@ -31,6 +32,7 @@ SerializingFunction, ) from dubbo.url import URL +from dubbo.codec import DubboTransportService __all__ = ["Client"] @@ -82,70 +84,179 @@ def _initialize(self): self._initialized = True - def unary( + def _create_rpc_callable( self, - method_name: str, + rpc_type: str, + interface: Optional[Callable] = None, + method_name: Optional[str] = None, + params_types: Optional[List[Type]] = None, + return_type: Optional[Type] = None, + codec: Optional[str] = None, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None, + default_method_name: str = "rpc_call", ) -> RpcCallable: - return self._callable( - MethodDescriptor( - method_name=method_name, - arg_serialization=(request_serializer, None), - return_serialization=(None, response_deserializer), - rpc_type=RpcTypes.UNARY.value, + """ + Create RPC callable with the specified type. 
+ """ + # Validate + if interface is None and method_name is None: + raise ValueError("Either 'interface' or 'method_name' must be provided") + + # Determine the actual method name to call + actual_method_name = method_name or (interface.__name__ if interface else default_method_name) + + # Build method descriptor (automatic or manual) + if interface: + method_desc = DubboTransportService.create_method_descriptor( + func=interface, + method_name=actual_method_name, + parameter_types=params_types, + return_type=return_type, + interface=interface, + ) + else: + # Manual mode fallback: use dummy function for descriptor creation + def dummy(): pass + + method_desc = DubboTransportService.create_method_descriptor( + func=dummy, + method_name=actual_method_name, + parameter_types=params_types or [], + return_type=return_type or Any, ) + + # Determine serializers if not provided + if request_serializer and response_deserializer: + final_request_serializer = request_serializer + final_response_deserializer = response_deserializer + else: + # Use DubboTransportService to generate serialization functions + final_request_serializer, final_response_deserializer = DubboTransportService.create_serialization_functions(codec, parameter_types=params_types, return_type=return_type) + print("final",codec, final_request_serializer, final_response_deserializer) + + # Create the proper MethodDescriptor for the RPC call + rpc_method_descriptor = MethodDescriptor( + method_name=actual_method_name, + arg_serialization=(final_request_serializer, None), # (serializer, deserializer) for arguments + return_serialization=(None, final_response_deserializer), # (serializer, deserializer) for return value + rpc_type=rpc_type, + ) + + # Create and return the RpcCallable + return self._callable(rpc_method_descriptor) + + def unary( + self, + interface: Optional[Callable] = None, + method_name: Optional[str] = None, + params_types: Optional[List[Type]] = None, + return_type: Optional[Type] = None, + 
codec: Optional[str] = None, + request_serializer: Optional[SerializingFunction] = None, + response_deserializer: Optional[DeserializingFunction] = None, + ) -> RpcCallable: + """ + Create unary RPC call. + + Supports both automatic mode (via interface) and manual mode (via method_name + params_types + return_type + codec). + """ + return self._create_rpc_callable( + rpc_type=RpcTypes.UNARY.value, + interface=interface, + method_name=method_name, + params_types=params_types, + return_type=return_type, + codec=codec, + request_serializer=request_serializer, + response_deserializer=response_deserializer, + default_method_name="unary", ) def client_stream( self, - method_name: str, + interface: Optional[Callable] = None, + method_name: Optional[str] = None, + params_types: Optional[List[Type]] = None, + return_type: Optional[Type] = None, + codec: Optional[str] = None, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None, ) -> RpcCallable: - return self._callable( - MethodDescriptor( - method_name=method_name, - arg_serialization=(request_serializer, None), - return_serialization=(None, response_deserializer), - rpc_type=RpcTypes.CLIENT_STREAM.value, - ) + """ + Create client streaming RPC call. + + Supports both automatic mode (via interface) and manual mode (via method_name + params_types + return_type + codec). 
+ """ + return self._create_rpc_callable( + rpc_type=RpcTypes.CLIENT_STREAM.value, + interface=interface, + method_name=method_name, + params_types=params_types, + return_type=return_type, + codec=codec, + request_serializer=request_serializer, + response_deserializer=response_deserializer, + default_method_name="client_stream", ) def server_stream( self, - method_name: str, + interface: Optional[Callable] = None, + method_name: Optional[str] = None, + params_types: Optional[List[Type]] = None, + return_type: Optional[Type] = None, + codec: Optional[str] = None, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None, ) -> RpcCallable: - return self._callable( - MethodDescriptor( - method_name=method_name, - arg_serialization=(request_serializer, None), - return_serialization=(None, response_deserializer), - rpc_type=RpcTypes.SERVER_STREAM.value, - ) + """ + Create server streaming RPC call. + + Supports both automatic mode (via interface) and manual mode (via method_name + params_types + return_type + codec). 
+ """ + return self._create_rpc_callable( + rpc_type=RpcTypes.SERVER_STREAM.value, + interface=interface, + method_name=method_name, + params_types=params_types, + return_type=return_type, + codec=codec, + request_serializer=request_serializer, + response_deserializer=response_deserializer, + default_method_name="server_stream", ) def bi_stream( self, - method_name: str, + interface: Optional[Callable] = None, + method_name: Optional[str] = None, + params_types: Optional[List[Type]] = None, + return_type: Optional[Type] = None, + codec: Optional[str] = None, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None, ) -> RpcCallable: - # create method descriptor - return self._callable( - MethodDescriptor( - method_name=method_name, - arg_serialization=(request_serializer, None), - return_serialization=(None, response_deserializer), - rpc_type=RpcTypes.BI_STREAM.value, - ) + """ + Create bidirectional streaming RPC call. + + Supports both automatic mode (via interface) and manual mode (via method_name + params_types + return_type + codec). + """ + return self._create_rpc_callable( + rpc_type=RpcTypes.BI_STREAM.value, + interface=interface, + method_name=method_name, + params_types=params_types, + return_type=return_type, + codec=codec, + request_serializer=request_serializer, + response_deserializer=response_deserializer, + default_method_name="bi_stream", ) def _callable(self, method_descriptor: MethodDescriptor) -> RpcCallable: """ - Generate a proxy for the given method + Generate a proxy for the given method. :param method_descriptor: The method descriptor. :return: The proxy. 
:rtype: RpcCallable @@ -160,4 +271,4 @@ def _callable(self, method_descriptor: MethodDescriptor) -> RpcCallable: url.attributes[common_constants.METHOD_DESCRIPTOR_KEY] = method_descriptor # create proxy - return self._callable_factory.get_callable(self._invoker, url) + return self._callable_factory.get_callable(self._invoker, url) \ No newline at end of file diff --git a/src/dubbo/codec/__init__.py b/src/dubbo/codec/__init__.py new file mode 100644 index 0000000..dfd1b56 --- /dev/null +++ b/src/dubbo/codec/__init__.py @@ -0,0 +1,19 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .dubbo_codec import DubboTransportService + +__all__ = ['DubboTransportService'] \ No newline at end of file diff --git a/src/dubbo/codec/dubbo_codec.py b/src/dubbo/codec/dubbo_codec.py new file mode 100644 index 0000000..8758684 --- /dev/null +++ b/src/dubbo/codec/dubbo_codec.py @@ -0,0 +1,160 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Type, Optional, Callable, List, Dict +from dataclasses import dataclass +import inspect + +from dubbo.classes import CodecHelper +from dubbo.codec.json_codec import JsonTransportCodec, JsonTransportEncoder, JsonTransportDecoder + +@dataclass +class ParameterDescriptor: + """Detailed information about a method parameter""" + name: str + annotation: Any + is_required: bool = True + default_value: Any = None + + +@dataclass +class MethodDescriptor: + """Complete method descriptor with all necessary information""" + function: Callable + name: str + parameters: List[ParameterDescriptor] + return_parameter: ParameterDescriptor + documentation: Optional[str] = None + + +class DubboTransportService: + """Enhanced Dubbo transport service with robust type handling""" + + @staticmethod + def create_transport_codec(transport_type: str = 'json', parameter_types: List[Type] = None, + return_type: Type = None, **codec_options): + """Create transport codec with enhanced parameter structure""" + if transport_type == 'json': + return JsonTransportCodec( + parameter_types=parameter_types, + return_type=return_type, + **codec_options + ) + else: + from dubbo.extension.extension_loader import ExtensionLoader + Codec = CodecHelper.get_class() + codec_class = ExtensionLoader().get_extension(Codec, transport_type) + return codec_class( + parameter_types=parameter_types, + return_type=return_type, 
+ **codec_options + ) + + @staticmethod + def create_encoder_decoder_pair(transport_type: str, parameter_types: List[Type] = None, + return_type: Type = None, **codec_options) -> tuple[any,any]: + """Create separate encoder and decoder instances""" + + if transport_type == 'json': + parameter_encoder = JsonTransportEncoder(parameter_types=parameter_types, **codec_options) + return_decoder = JsonTransportDecoder(target_type=return_type, **codec_options) + return parameter_encoder, return_decoder + else: + from dubbo.extension.extension_loader import ExtensionLoader + Codec = CodecHelper.get_class() + codec_class = ExtensionLoader().get_extension(Codec, transport_type) + + codec_instance = codec_class( + parameter_types=parameter_types, + return_type=return_type, + **codec_options + ) + return codec_instance.get_encoder(), codec_instance.get_decoder() + + @staticmethod + def create_serialization_functions(transport_type: str, parameter_types: List[Type] = None, + return_type: Type = None, **codec_options) -> tuple[Callable, Callable]: + """Create serializer and deserializer functions for RPC (backward compatibility)""" + + parameter_encoder, return_decoder = DubboTransportService.create_encoder_decoder_pair( + transport_type=transport_type, + parameter_types=parameter_types, + return_type=return_type, + **codec_options + ) + + def serialize_method_parameters(*args) -> bytes: + return parameter_encoder.encode(args) + + def deserialize_method_return(data: bytes): + return return_decoder.decode(data) + + return serialize_method_parameters, deserialize_method_return + + @staticmethod + def create_method_descriptor(func: Callable, method_name: str = None, + parameter_types: List[Type] = None, return_type: Type = None, + interface: Callable = None) -> MethodDescriptor: + """Create a method descriptor from function and configuration""" + + name = method_name or (interface.__name__ if interface else func.__name__) + sig = inspect.signature(interface if interface else func) + 
+ parameters = [] + resolved_parameter_types = parameter_types or [] + + for i, (param_name, param) in enumerate(sig.parameters.items()): + if param_name == 'self': + continue + + param_index = i - 1 if 'self' in sig.parameters else i + + if param_index < len(resolved_parameter_types): + param_type = resolved_parameter_types[param_index] + elif param.annotation != inspect.Parameter.empty: + param_type = param.annotation + else: + param_type = Any + + is_required = param.default == inspect.Parameter.empty + default_value = param.default if not is_required else None + + parameters.append(ParameterDescriptor( + name=param_name, + annotation=param_type, + is_required=is_required, + default_value=default_value + )) + + if return_type: + resolved_return_type = return_type + elif sig.return_annotation != inspect.Signature.empty: + resolved_return_type = sig.return_annotation + else: + resolved_return_type = Any + + return_parameter = ParameterDescriptor( + name="return_value", + annotation=resolved_return_type + ) + + return MethodDescriptor( + function=func, + name=name, + parameters=parameters, + return_parameter=return_parameter, + documentation=func.__doc__ + ) \ No newline at end of file diff --git a/src/dubbo/codec/json_codec/__init__.py b/src/dubbo/codec/json_codec/__init__.py new file mode 100644 index 0000000..f66e01d --- /dev/null +++ b/src/dubbo/codec/json_codec/__init__.py @@ -0,0 +1,19 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .json_codec_handler import JsonTransportCodec,JsonTransportDecoder,JsonTransportEncoder + +__all__ = ["JsonTransportCodec", "JsonTransportDecoder", "JsonTransportEncoder"] diff --git a/src/dubbo/codec/json_codec/json_codec_handler.py b/src/dubbo/codec/json_codec/json_codec_handler.py new file mode 100644 index 0000000..b7533c0 --- /dev/null +++ b/src/dubbo/codec/json_codec/json_codec_handler.py @@ -0,0 +1,322 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Any, Type, List, Union, Dict, TypeVar, Protocol +from datetime import datetime, date, time +from decimal import Decimal +from pathlib import Path +from uuid import UUID +import json + +from .json_type import ( + TypeProviderFactory, SerializationState, + SerializationException, DeserializationException +) + +try: + import orjson + HAS_ORJSON = True +except ImportError: + HAS_ORJSON = False + +try: + import ujson + HAS_UJSON = True +except ImportError: + HAS_UJSON = False + +try: + from pydantic import BaseModel, create_model + HAS_PYDANTIC = True +except ImportError: + HAS_PYDANTIC = False + + +class EncodingFunction(Protocol): + def __call__(self, obj: Any) -> bytes: ... + + +class DecodingFunction(Protocol): + def __call__(self, data: bytes) -> Any: ... + + +ModelT = TypeVar('ModelT', bound=BaseModel) + + +class CustomJSONEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, datetime): + return { + "__datetime__": obj.isoformat(), + "__timezone__": str(obj.tzinfo) if obj.tzinfo else None + } + elif isinstance(obj, date): + return {"__date__": obj.isoformat()} + elif isinstance(obj, time): + return {"__time__": obj.isoformat()} + elif isinstance(obj, Decimal): + return {"__decimal__": str(obj)} + elif isinstance(obj, (set, frozenset)): + return { + "__frozenset__" if isinstance(obj, frozenset) else "__set__": list(obj) + } + elif isinstance(obj, UUID): + return {"__uuid__": str(obj)} + elif isinstance(obj, Path): + return {"__path__": str(obj)} + else: + return {"__fallback_string__": str(obj), "__original_type__": type(obj).__name__} + + +class JsonTransportEncoder: + def __init__(self, parameter_types: List[Type] = None, maximum_depth: int = 100, + strict_validation: bool = True, **kwargs): + self.parameter_types = parameter_types or [] + self.maximum_depth = maximum_depth + self.strict_validation = strict_validation + self.type_registry = TypeProviderFactory.create_default_registry() + self.custom_encoder = 
CustomJSONEncoder(ensure_ascii=False, separators=(',', ':')) + self.single_parameter_mode = len(self.parameter_types) == 1 + self.multiple_parameter_mode = len(self.parameter_types) > 1 + if self.multiple_parameter_mode and HAS_PYDANTIC: + self.parameter_wrapper_model = self._create_parameter_wrapper_model() + + def _create_parameter_wrapper_model(self) -> Type[BaseModel]: + model_fields = {} + for i, param_type in enumerate(self.parameter_types): + model_fields[f"parameter_{i}"] = (param_type, ...) + return create_model('MethodParametersWrapper', **model_fields) + + def register_type_provider(self, provider) -> None: + self.type_registry.register_provider(provider) + + def encode(self, arguments: tuple) -> bytes: + try: + if not arguments: + return self._serialize_to_json_bytes([]) + + if self.single_parameter_mode: + parameter = arguments[0] + serialized_param = self._serialize_with_state(parameter) + if HAS_PYDANTIC and isinstance(parameter, BaseModel): + if hasattr(parameter, 'model_dump'): + return self._serialize_to_json_bytes(parameter.model_dump()) + return self._serialize_to_json_bytes(parameter.dict()) + elif isinstance(parameter, dict): + return self._serialize_to_json_bytes(serialized_param) + else: + return self._serialize_to_json_bytes(serialized_param) + + elif self.multiple_parameter_mode and HAS_PYDANTIC: + wrapper_data = {f"parameter_{i}": arg for i, arg in enumerate(arguments)} + wrapper_instance = self.parameter_wrapper_model(**wrapper_data) + return self._serialize_to_json_bytes(wrapper_instance.model_dump()) + + else: + serialized_args = [self._serialize_with_state(arg) for arg in arguments] + return self._serialize_to_json_bytes(serialized_args) + + except Exception as e: + raise SerializationException(f"Encoding failed: {e}") from e + + def _serialize_with_state(self, obj: Any) -> Any: + state = SerializationState(maximum_depth=self.maximum_depth) + return self._serialize_recursively(obj, state) + + def _serialize_recursively(self, obj: Any, 
state: SerializationState) -> Any: + if obj is None or isinstance(obj, (bool, int, float, str)): + return obj + if isinstance(obj, (list, tuple)): + state.validate_circular_reference(obj) + new_state = state.create_child_state(obj) + return [self._serialize_recursively(item, new_state) for item in obj] + elif isinstance(obj, dict): + state.validate_circular_reference(obj) + new_state = state.create_child_state(obj) + result = {} + for key, value in obj.items(): + if not isinstance(key, str): + if self.strict_validation: + raise SerializationException(f"Dictionary key must be string, got {type(key).__name__}") + key = str(key) + result[key] = self._serialize_recursively(value, new_state) + return result + + provider = self.type_registry.find_provider_for_object(obj) + if provider: + try: + serialized = provider.serialize_to_dict(obj, state) + return self._serialize_recursively(serialized, state) + except Exception as e: + if self.strict_validation: + raise SerializationException(f"Provider failed for {type(obj).__name__}: {e}") from e + return {"__serialization_error__": str(e), "__original_type__": type(obj).__name__} + else: + if self.strict_validation: + raise SerializationException(f"No provider for type {type(obj).__name__}") + return {"__fallback_string__": str(obj), "__original_type__": type(obj).__name__} + + def _serialize_to_json_bytes(self, obj: Any) -> bytes: + if HAS_ORJSON: + try: + return orjson.dumps(obj, default=self._orjson_default_handler) + except TypeError: + pass + if HAS_UJSON: + try: + return ujson.dumps(obj, ensure_ascii=False, default=self._ujson_default_handler).encode('utf-8') + except (TypeError, ValueError): + pass + return self.custom_encoder.encode(obj).encode('utf-8') + + def _orjson_default_handler(self, obj): + if isinstance(obj, datetime): + return { + "__datetime__": obj.isoformat(), + "__timezone__": str(obj.tzinfo) if obj.tzinfo else None + } + elif isinstance(obj, date): + return {"__date__": obj.isoformat()} + elif 
isinstance(obj, time): + return {"__time__": obj.isoformat()} + elif isinstance(obj, Decimal): + return {"__decimal__": str(obj)} + elif isinstance(obj, (set, frozenset)): + return { + "__frozenset__" if isinstance(obj, frozenset) else "__set__": list(obj) + } + elif isinstance(obj, UUID): + return {"__uuid__": str(obj)} + elif isinstance(obj, Path): + return {"__path__": str(obj)} + else: + return {"__fallback_string__": str(obj), "__original_type__": type(obj).__name__} + + def _ujson_default_handler(self, obj): + return self._orjson_default_handler(obj) + + +class JsonTransportDecoder: + def __init__(self, target_type: Union[Type, List[Type]] = None, **kwargs): + self.target_type = target_type + if isinstance(target_type, list): + self.multiple_parameter_mode = len(target_type) > 1 + self.parameter_types = target_type + if self.multiple_parameter_mode and HAS_PYDANTIC: + self.parameter_wrapper_model = self._create_parameter_wrapper_model() + else: + self.multiple_parameter_mode = False + self.parameter_types = [target_type] if target_type else [] + + def _create_parameter_wrapper_model(self) -> Type[BaseModel]: + model_fields = {} + for i, param_type in enumerate(self.parameter_types): + model_fields[f"parameter_{i}"] = (param_type, ...) 
+ return create_model('MethodParametersWrapper', **model_fields) + + def decode(self, data: bytes) -> Any: + try: + if not data: + return None + json_data = self._deserialize_from_json_bytes(data) + reconstructed_data = self._reconstruct_objects(json_data) + if not self.target_type: + return reconstructed_data + if isinstance(self.target_type, list): + if self.multiple_parameter_mode and HAS_PYDANTIC: + wrapper_instance = self.parameter_wrapper_model(**reconstructed_data) + return tuple(getattr(wrapper_instance, f"parameter_{i}") for i in range(len(self.parameter_types))) + else: + return self._decode_to_target_type(reconstructed_data, self.parameter_types[0]) + else: + return self._decode_to_target_type(reconstructed_data, self.target_type) + except Exception as e: + raise DeserializationException(f"Decoding failed: {e}") from e + + def _deserialize_from_json_bytes(self, data: bytes) -> Any: + if HAS_ORJSON: + try: + return orjson.loads(data) + except orjson.JSONDecodeError: + pass + if HAS_UJSON: + try: + return ujson.loads(data.decode('utf-8')) + except (ujson.JSONDecodeError, UnicodeDecodeError): + pass + return json.loads(data.decode('utf-8')) + + def _decode_to_target_type(self, json_data: Any, target_type: Type) -> Any: + if target_type in (str, int, float, bool, list, dict): + return target_type(json_data) + return json_data + + def _reconstruct_objects(self, data: Any) -> Any: + if not isinstance(data, dict): + if isinstance(data, list): + return [self._reconstruct_objects(item) for item in data] + return data + if "__datetime__" in data: + return datetime.fromisoformat(data["__datetime__"]) + elif "__date__" in data: + return date.fromisoformat(data["__date__"]) + elif "__time__" in data: + return time.fromisoformat(data["__time__"]) + elif "__decimal__" in data: + return Decimal(data["__decimal__"]) + elif "__set__" in data: + return set(self._reconstruct_objects(item) for item in data["__set__"]) + elif "__frozenset__" in data: + return 
frozenset(self._reconstruct_objects(item) for item in data["__frozenset__"]) + elif "__uuid__" in data: + return UUID(data["__uuid__"]) + elif "__path__" in data: + return Path(data["__path__"]) + elif "__dataclass__" in data or "__pydantic_model__" in data: + return data + else: + return {key: self._reconstruct_objects(value) for key, value in data.items()} + + +class JsonTransportCodec: + def __init__(self, parameter_types: List[Type] = None, return_type: Type = None, + maximum_depth: int = 100, strict_validation: bool = True, **kwargs): + self.parameter_types = parameter_types or [] + self.return_type = return_type + self.maximum_depth = maximum_depth + self.strict_validation = strict_validation + self._encoder = JsonTransportEncoder( + parameter_types=parameter_types, + maximum_depth=maximum_depth, + strict_validation=strict_validation, + **kwargs + ) + self._decoder = JsonTransportDecoder(target_type=return_type, **kwargs) + + def encode_parameters(self, *arguments) -> bytes: + return self._encoder.encode(arguments) + + def decode_return_value(self, data: bytes) -> Any: + return self._decoder.decode(data) + + def get_encoder(self) -> JsonTransportEncoder: + return self._encoder + + def get_decoder(self) -> JsonTransportDecoder: + return self._decoder + + def register_type_provider(self, provider) -> None: + self._encoder.register_type_provider(provider) diff --git a/src/dubbo/codec/json_codec/json_type.py b/src/dubbo/codec/json_codec/json_type.py new file mode 100644 index 0000000..3a6d840 --- /dev/null +++ b/src/dubbo/codec/json_codec/json_type.py @@ -0,0 +1,274 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import (
    Any,
    Type,
    Optional,
    List,
    Dict,
    Set,
    Protocol,
    runtime_checkable,
    Union,
)
from dataclasses import dataclass, field, fields, is_dataclass, asdict
from datetime import datetime, date, time
from decimal import Decimal
from collections import namedtuple
from pathlib import Path
from uuid import UUID
from enum import Enum
import weakref

try:
    from pydantic import BaseModel

    HAS_PYDANTIC = True
except ImportError:
    HAS_PYDANTIC = False


class SerializationException(Exception):
    """Exception raised during serialization"""

    pass


class DeserializationException(Exception):
    """Exception raised during deserialization"""

    pass


class CircularReferenceException(SerializationException):
    """Exception raised when circular references are detected"""

    pass


@dataclass(frozen=True)
class SerializationState:
    """
    Immutable bookkeeping for one serialization pass.

    Tracks the ids of objects already on the current serialization path
    (for cycle detection) and the recursion depth (to bound nesting).
    """

    # FIX: the field was annotated `Set[int]` with a `None` default and
    # patched up in __post_init__ via object.__setattr__. Use
    # field(default_factory=set) so every instance gets its own set and
    # the annotation is truthful.
    _visited_objects: Set[int] = field(default_factory=set)
    maximum_depth: int = 100
    current_depth: int = 0

    def __post_init__(self):
        # Backward compatibility: tolerate an explicit None (the old default).
        if self._visited_objects is None:
            object.__setattr__(self, "_visited_objects", set())

    def validate_circular_reference(self, obj: Any) -> None:
        """Raise if *obj* is already on the path or the depth budget is spent."""
        object_id = id(obj)
        if object_id in self._visited_objects:
            raise CircularReferenceException(
                f"Circular reference detected for {type(obj).__name__}"
            )
        if self.current_depth >= self.maximum_depth:
            raise SerializationException(
                f"Maximum serialization depth ({self.maximum_depth}) exceeded"
            )

    def create_child_state(self, obj: Any) -> "SerializationState":
        """Return a state one level deeper with *obj* marked as visited."""
        new_visited = self._visited_objects.copy()
        new_visited.add(id(obj))
        return SerializationState(
            _visited_objects=new_visited,
            maximum_depth=self.maximum_depth,
            current_depth=self.current_depth + 1,
        )


@runtime_checkable
class TypeSerializationProvider(Protocol):
    """Structural interface implemented by per-type serialization providers."""

    def can_serialize_type(self, obj: Any, obj_type: type) -> bool: ...

    def serialize_to_dict(self, obj: Any, state: SerializationState) -> Any: ...


class TypeProviderRegistry:
    """Ordered provider registry with a per-type lookup cache (first match wins)."""

    def __init__(self):
        self._type_cache: Dict[type, Optional[TypeSerializationProvider]] = {}
        self._providers: List[TypeSerializationProvider] = []
        self._weak_cache = weakref.WeakKeyDictionary()

    def register_provider(self, provider: TypeSerializationProvider) -> None:
        """Append *provider* and invalidate the lookup caches."""
        self._providers.append(provider)
        self._type_cache.clear()
        self._weak_cache.clear()

    def find_provider_for_object(self, obj: Any) -> Optional[TypeSerializationProvider]:
        """Return the first provider claiming type(obj), caching the answer (None included)."""
        obj_type = type(obj)
        if obj_type in self._type_cache:
            return self._type_cache[obj_type]
        provider = None
        for p in self._providers:
            if p.can_serialize_type(obj, obj_type):
                provider = p
                break
        self._type_cache[obj_type] = provider
        return provider


class DateTimeSerializationProvider:
    """Serializes datetime/date/time to tagged ISO-8601 dicts."""

    def can_serialize_type(self, obj: Any, obj_type: type) -> bool:
        return obj_type in (datetime, date, time)

    def serialize_to_dict(
        self, obj: Union[datetime, date, time], state: SerializationState
    ) -> Dict[str, str]:
        if isinstance(obj, datetime):
            return {
                "__datetime__": obj.isoformat(),
                "__timezone__": str(obj.tzinfo) if obj.tzinfo else None,
            }
        elif isinstance(obj, date):
            return {"__date__": obj.isoformat()}
        else:
            return {"__time__": obj.isoformat()}


class DecimalSerializationProvider:
    """Serializes Decimal losslessly as its string form."""

    def can_serialize_type(self, obj: Any, obj_type: type) -> bool:
        return obj_type is Decimal

    def serialize_to_dict(
        self, obj: Decimal, state: SerializationState
    ) -> Dict[str, str]:
        return {"__decimal__": str(obj)}


class CollectionSerializationProvider:
    """Serializes set/frozenset whose items are already JSON-safe scalars."""

    def can_serialize_type(self, obj: Any, obj_type: type) -> bool:
        return obj_type in (set, frozenset)

    def serialize_to_dict(
        self, obj: Union[set, frozenset], state: SerializationState
    ) -> Dict[str, Any]:
        safe_items = []
        for item in obj:
            if isinstance(item, (str, int, float, bool, type(None))):
                safe_items.append(item)
            else:
                raise SerializationException(
                    f"Cannot serialize {type(item).__name__} in collection. "
                    f"Collections can only contain JSON-safe types (str, int, float, bool, None)"
                )
        return {
            "__frozenset__" if isinstance(obj, frozenset) else "__set__": safe_items
        }


class DataclassSerializationProvider:
    """Serializes dataclass instances; falls back to shallow field access when asdict fails."""

    def can_serialize_type(self, obj: Any, obj_type: type) -> bool:
        return is_dataclass(obj) and not isinstance(obj, type)

    def serialize_to_dict(self, obj: Any, state: SerializationState) -> Dict[str, Any]:
        state.validate_circular_reference(obj)
        try:
            field_data = asdict(obj)
            return {
                "__dataclass__": f"{obj.__class__.__module__}.{obj.__class__.__qualname__}",
                "__field_data__": field_data,
            }
        except (TypeError, RecursionError):
            # asdict() recurses eagerly; unhashable/cyclic fields land here.
            field_data = {}
            for field_info in fields(obj):
                try:
                    field_data[field_info.name] = getattr(obj, field_info.name)
                except Exception as e:
                    raise SerializationException(
                        f"Cannot serialize field '{field_info.name}': {e}"
                    )
            return {
                "__dataclass__": f"{obj.__class__.__module__}.{obj.__class__.__qualname__}",
                "__field_data__": field_data,
            }


class NamedTupleSerializationProvider:
    """Serializes namedtuple instances via their _asdict() protocol."""

    def can_serialize_type(self, obj: Any, obj_type: type) -> bool:
        return (
            hasattr(obj_type, "_fields")
            and hasattr(obj, "_asdict")
            and callable(obj._asdict)
        )

    def serialize_to_dict(self, obj: Any, state: SerializationState) -> Dict[str, Any]:
        return {
            "__namedtuple__": f"{obj.__class__.__module__}.{obj.__class__.__qualname__}",
            "__tuple_data__": obj._asdict(),
        }


class PydanticModelSerializationProvider:
    """Serializes pydantic models (v1 .dict() or v2 .model_dump())."""

    def can_serialize_type(self, obj: Any, obj_type: type) -> bool:
        return HAS_PYDANTIC and isinstance(obj, BaseModel)

    def serialize_to_dict(
        self, obj: "BaseModel", state: SerializationState
    ) -> Dict[str, Any]:
        state.validate_circular_reference(obj)
        if hasattr(obj, "model_dump"):
            model_data = obj.model_dump()
        else:
            model_data = obj.dict()
        return {
            "__pydantic_model__": f"{obj.__class__.__module__}.{obj.__class__.__qualname__}",
            "__model_data__": model_data,
        }


class SimpleTypeSerializationProvider:
    """Serializes UUID, Path and Enum values to tagged dicts."""

    def can_serialize_type(self, obj: Any, obj_type: type) -> bool:
        return obj_type is UUID or isinstance(obj, (Path, Enum))

    def serialize_to_dict(
        self, obj: Union[UUID, Path, Enum], state: SerializationState
    ) -> Dict[str, str]:
        if isinstance(obj, UUID):
            return {"__uuid__": str(obj)}
        elif isinstance(obj, Path):
            return {"__path__": str(obj)}
        else:
            return {
                "__enum__": f"{obj.__class__.__module__}.{obj.__class__.__qualname__}",
                "__enum_value__": obj.value,
            }


class TypeProviderFactory:
    """Builds pre-populated TypeProviderRegistry instances."""

    @staticmethod
    def create_default_registry() -> TypeProviderRegistry:
        """Registry with every built-in provider registered."""
        registry = TypeProviderRegistry()
        default_providers = [
            DateTimeSerializationProvider(),
            DecimalSerializationProvider(),
            CollectionSerializationProvider(),
            DataclassSerializationProvider(),
            NamedTupleSerializationProvider(),
            PydanticModelSerializationProvider(),
            SimpleTypeSerializationProvider(),
        ]
        for provider in default_providers:
            registry.register_provider(provider)
        return registry

    @staticmethod
    def create_minimal_registry() -> TypeProviderRegistry:
        """Registry with only the cheap scalar-ish providers registered."""
        registry = TypeProviderRegistry()
        essential_providers = [
            DateTimeSerializationProvider(),
            DecimalSerializationProvider(),
            SimpleTypeSerializationProvider(),
        ]
        for provider in essential_providers:
            registry.register_provider(provider)
        return registry

# diff --git a/src/dubbo/codec/protobuf_codec/__init__.py new file mode 100644 index 0000000..4ddd635 ---
/dev/null +++ b/src/dubbo/codec/protobuf_codec/__init__.py @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .protobuf_codec_handler import ProtobufTransportCodec, ProtobufTransportEncoder, ProtobufTransportDecoder + +__all__ = [ + "ProtobufTransportCodec", "ProtobufTransportEncoder", "ProtobufTransportDecoder" +] \ No newline at end of file diff --git a/src/dubbo/codec/protobuf_codec/protobuf_codec_handler.py b/src/dubbo/codec/protobuf_codec/protobuf_codec_handler.py new file mode 100644 index 0000000..8c75062 --- /dev/null +++ b/src/dubbo/codec/protobuf_codec/protobuf_codec_handler.py @@ -0,0 +1,305 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Type, Protocol, Optional +from abc import ABC, abstractmethod +import json +from dataclasses import dataclass + +# Betterproto imports +try: + import betterproto + HAS_BETTERPROTO = True +except ImportError: + HAS_BETTERPROTO = False + +try: + from pydantic import BaseModel + HAS_PYDANTIC = True +except ImportError: + HAS_PYDANTIC = False + +# Reuse your existing JSON type system +from dubbo.codec.json_codec.json_type import ( + TypeProviderFactory, SerializationState, + SerializationException, DeserializationException +) + + +class ProtobufEncodingFunction(Protocol): + def __call__(self, obj: Any) -> bytes: ... + + +class ProtobufDecodingFunction(Protocol): + def __call__(self, data: bytes) -> Any: ... 
+ + +@dataclass +class ProtobufMethodDescriptor: + """Protobuf-specific method descriptor for single parameter""" + parameter_type: Type + return_type: Type + protobuf_message_type: Optional[Type] = None + use_json_fallback: bool = False + + +class ProtobufTypeHandler: + """Handles type conversion between Python types and Betterproto""" + + @staticmethod + def is_betterproto_message(obj_type: Type) -> bool: + """Check if type is a betterproto message class""" + if not HAS_BETTERPROTO: + return False + try: + return (hasattr(obj_type, '__dataclass_fields__') and + issubclass(obj_type, betterproto.Message)) + except (TypeError, AttributeError): + return False + + @staticmethod + def is_betterproto_message_instance(obj: Any) -> bool: + """Check if object is a betterproto message instance""" + if not HAS_BETTERPROTO: + return False + try: + return isinstance(obj, betterproto.Message) + except: + return False + + @staticmethod + def is_protobuf_compatible(obj_type: Type) -> bool: + """Check if type can be handled by protobuf""" + return (obj_type in (str, int, float, bool, bytes) or + ProtobufTypeHandler.is_betterproto_message(obj_type)) + + @staticmethod + def needs_json_fallback(parameter_type: Type) -> bool: + """Check if we need JSON fallback for this type""" + return not ProtobufTypeHandler.is_protobuf_compatible(parameter_type) + + +class ProtobufTransportEncoder: + """Protobuf encoder for single parameters using betterproto""" + + def __init__(self, parameter_type: Type = None, **kwargs): + if not HAS_BETTERPROTO: + raise ImportError("betterproto library is required for ProtobufTransportEncoder") + + self.parameter_type = parameter_type + + self.descriptor = ProtobufMethodDescriptor( + parameter_type=parameter_type, + return_type=None, + use_json_fallback=ProtobufTypeHandler.needs_json_fallback(parameter_type) if parameter_type else False + ) + + if self.descriptor.use_json_fallback: + from dubbo.codec.json_codec.json_codec_handler import JsonTransportEncoder + 
self.json_fallback_encoder = JsonTransportEncoder([parameter_type], **kwargs) + + def encode(self, parameter: Any) -> bytes: + """Encode single parameter to bytes""" + try: + if parameter is None: + return b'' + + # Handle case where parameter is a tuple (common in RPC calls) + if isinstance(parameter, tuple): + if len(parameter) == 0: + return b'' + elif len(parameter) == 1: + return self._encode_single_parameter(parameter[0]) + else: + raise SerializationException(f"Multiple parameters not supported. Got tuple with {len(parameter)} elements, expected 1.") + + return self._encode_single_parameter(parameter) + + except Exception as e: + raise SerializationException(f"Protobuf encoding failed: {e}") from e + + def _encode_single_parameter(self, parameter: Any) -> bytes: + """Encode a single parameter using betterproto""" + # If it's already a betterproto message instance, serialize it + if ProtobufTypeHandler.is_betterproto_message_instance(parameter): + return bytes(parameter) + + # If we have type info and it's a betterproto message type + if self.parameter_type and ProtobufTypeHandler.is_betterproto_message(self.parameter_type): + if isinstance(parameter, self.parameter_type): + return bytes(parameter) + elif isinstance(parameter, dict): + # Convert dict to betterproto message + try: + message = self.parameter_type().from_dict(parameter) + return bytes(message) + except Exception as e: + raise SerializationException(f"Cannot convert dict to {self.parameter_type}: {e}") + else: + raise SerializationException(f"Cannot convert {type(parameter)} to {self.parameter_type}") + + # Handle primitive types by wrapping in a simple message + if isinstance(parameter, (str, int, float, bool, bytes)): + return self._encode_primitive(parameter) + + # Use JSON fallback if configured + if self.descriptor.use_json_fallback: + json_data = self.json_fallback_encoder.encode((parameter,)) + return json_data + + raise SerializationException(f"Cannot encode {type(parameter)} as 
protobuf") + + def _encode_primitive(self, value: Any) -> bytes: + """Encode primitive values by wrapping them in a simple structure""" + # For primitives, we'll use JSON encoding wrapped in bytes + # This is a simplified approach - in a real implementation you might + # want to define a wrapper protobuf message for primitives + try: + json_str = json.dumps({"value": value, "type": type(value).__name__}) + return json_str.encode('utf-8') + except Exception as e: + raise SerializationException(f"Failed to encode primitive {value}: {e}") + + +class ProtobufTransportDecoder: + """Protobuf decoder for single parameters using betterproto""" + + def __init__(self, target_type: Type = None, **kwargs): + if not HAS_BETTERPROTO: + raise ImportError("betterproto library is required for ProtobufTransportDecoder") + + self.target_type = target_type + self.use_json_fallback = ProtobufTypeHandler.needs_json_fallback(target_type) if target_type else False + + if self.use_json_fallback: + from dubbo.codec.json_codec.json_codec_handler import JsonTransportDecoder + self.json_fallback_decoder = JsonTransportDecoder(target_type, **kwargs) + + def decode(self, data: bytes) -> Any: + """Decode bytes to single parameter""" + try: + if not data: + return None + + if not self.target_type: + return self._decode_without_type_info(data) + + return self._decode_single_parameter(data, self.target_type) + + except Exception as e: + raise DeserializationException(f"Protobuf decoding failed: {e}") from e + + def _decode_single_parameter(self, data: bytes, target_type: Type) -> Any: + """Decode single parameter using betterproto""" + if ProtobufTypeHandler.is_betterproto_message(target_type): + try: + # Use betterproto's parsing + message_instance = target_type().parse(data) + return message_instance + except Exception as e: + if self.use_json_fallback: + return self.json_fallback_decoder.decode(data) + raise DeserializationException(f"Failed to parse betterproto message: {e}") + + # Handle 
primitives + elif target_type in (str, int, float, bool, bytes): + return self._decode_primitive(data, target_type) + + # Use JSON fallback + elif self.use_json_fallback: + return self.json_fallback_decoder.decode(data) + + else: + raise DeserializationException(f"Cannot decode to {target_type} from protobuf") + + def _decode_primitive(self, data: bytes, target_type: Type) -> Any: + """Decode primitive values from their wrapped format""" + try: + json_str = data.decode('utf-8') + parsed = json.loads(json_str) + value = parsed.get("value") + + # Convert to target type if needed + if target_type == str: + return str(value) + elif target_type == int: + return int(value) + elif target_type == float: + return float(value) + elif target_type == bool: + return bool(value) + elif target_type == bytes: + return bytes(value) if isinstance(value, (list, bytes)) else str(value).encode() + else: + return value + + except Exception as e: + raise DeserializationException(f"Failed to decode primitive: {e}") + + def _decode_without_type_info(self, data: bytes) -> Any: + """Decode without type information - try JSON first""" + try: + return json.loads(data.decode('utf-8')) + except: + return data + + +class ProtobufTransportCodec: + """Main protobuf codec class for single parameters using betterproto""" + + def __init__(self, parameter_type: Type = None, return_type: Type = None, **kwargs): + if not HAS_BETTERPROTO: + raise ImportError("betterproto library is required for ProtobufTransportCodec") + + self.parameter_type = parameter_type + self.return_type = return_type + + self._encoder = ProtobufTransportEncoder( + parameter_type=parameter_type, + **kwargs + ) + self._decoder = ProtobufTransportDecoder( + target_type=return_type, + **kwargs + ) + + def encode_parameter(self, argument: Any) -> bytes: + """Encode single parameter""" + return self._encoder.encode(argument) + + def encode_parameters(self, arguments: tuple) -> bytes: + """Legacy method to handle tuple of arguments (for 
backward compatibility)""" + if not arguments: + return b'' + if len(arguments) == 1: + return self._encoder.encode(arguments[0]) + else: + raise SerializationException(f"Multiple parameters not supported. Got {len(arguments)} arguments, expected 1.") + + def decode_return_value(self, data: bytes) -> Any: + """Decode return value""" + return self._decoder.decode(data) + + def get_encoder(self) -> ProtobufTransportEncoder: + return self._encoder + + def get_decoder(self) -> ProtobufTransportDecoder: + return self._decoder + + +def create_protobuf_codec(**kwargs) -> ProtobufTransportCodec: + """Factory function to create protobuf codec""" + return ProtobufTransportCodec(**kwargs) \ No newline at end of file diff --git a/src/dubbo/extension/registries.py b/src/dubbo/extension/registries.py index cf23ae7..d188d9c 100644 --- a/src/dubbo/extension/registries.py +++ b/src/dubbo/extension/registries.py @@ -14,6 +14,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import importlib +from typing import Any +from dataclasses import dataclass + +from dubbo.classes import SingletonBase +from dubbo.extension import registries as registries_module + +# Import all the required interface classes from dataclasses import dataclass from typing import Any @@ -22,6 +30,78 @@ from dubbo.protocol import Protocol from dubbo.registry import RegistryFactory from dubbo.remoting import Transporter +from dubbo.classes import Codec + + +class ExtensionError(Exception): + """ + Extension error. + """ + + def __init__(self, message: str): + """ + Initialize the extension error. + :param message: The error message. + :type message: str + """ + super().__init__(message) + + +class ExtensionLoader(SingletonBase): + """ + Singleton class for loading extension implementations. + """ + + def __init__(self): + """ + Initialize the extension loader. + + Load all the registries from the registries module. 
+ """ + if not hasattr(self, "_initialized"): # Ensure __init__ runs only once + self._registries = {} + for name in registries_module.registries: + registry = getattr(registries_module, name) + self._registries[registry.interface] = registry.impls + self._initialized = True + + def get_extension(self, interface: Any, impl_name: str) -> Any: + """ + Get the extension implementation for the interface. + + :param interface: Interface class. + :type interface: Any + :param impl_name: Implementation name. + :type impl_name: str + :return: Extension implementation class. + :rtype: Any + :raises ExtensionError: If the interface or implementation is not found. + """ + # Get the registry for the interface + impls = self._registries.get(interface) + print("value is ", impls, interface) + if not impls: + raise ExtensionError(f"Interface '{interface.__name__}' is not supported.") + + # Get the full name of the implementation + full_name = impls.get(impl_name) + if not full_name: + raise ExtensionError(f"Implementation '{impl_name}' for interface '{interface.__name__}' is not exist.") + + try: + # Split the full name into module and class + module_name, class_name = full_name.rsplit(".", 1) + + # Load the module and get the class + module = importlib.import_module(module_name) + subclass = getattr(module, class_name) + + # Return the subclass + return subclass + except Exception as e: + raise ExtensionError( + f"Failed to load extension '{impl_name}' for interface '{interface.__name__}'. 
\nDetail: {e}" + ) @dataclass @@ -39,7 +119,7 @@ class ExtendedRegistry: impls: dict[str, Any] -# All Extension Registries +# All Extension Registries - FIXED: Added codecRegistry to the list registries = [ "registryFactoryRegistry", "loadBalanceRegistry", @@ -47,6 +127,7 @@ class ExtendedRegistry: "compressorRegistry", "decompressorRegistry", "transporterRegistry", + "codecRegistry", ] # RegistryFactory registry @@ -84,7 +165,6 @@ class ExtendedRegistry: }, ) - # Decompressor registry decompressorRegistry = ExtendedRegistry( interface=Decompressor, @@ -95,7 +175,6 @@ class ExtendedRegistry: }, ) - # Transporter registry transporterRegistry = ExtendedRegistry( interface=Transporter, @@ -103,3 +182,12 @@ class ExtendedRegistry: "aio": "dubbo.remoting.aio.aio_transporter.AioTransporter", }, ) + +# Codec Registry +codecRegistry = ExtendedRegistry( + interface=Codec, + impls={ + "json": "dubbo.codec.json_codec.JsonTransportCodec", + "protobuf": "dubbo.codec.protobuf_codec.ProtobufTransportCodec", + }, +) \ No newline at end of file diff --git a/src/dubbo/proxy/handlers.py b/src/dubbo/proxy/handlers.py index 8c89663..aa5004f 100644 --- a/src/dubbo/proxy/handlers.py +++ b/src/dubbo/proxy/handlers.py @@ -14,9 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Optional - +import inspect +from typing import Callable, Optional, List, Type, Any, get_type_hints from dubbo.classes import MethodDescriptor +from dubbo.codec import DubboTransportService from dubbo.types import ( DeserializingFunction, RpcTypes, @@ -26,9 +27,16 @@ __all__ = ["RpcMethodHandler", "RpcServiceHandler"] +class RpcMethodConfigurationError(Exception): + """ + Raised when RPC method is configured incorrectly. + """ + pass + + class RpcMethodHandler: """ - Rpc method handler + Rpc method handler that wraps metadata and serialization logic for a callable. 
""" __slots__ = ["_method_descriptor"] @@ -50,34 +58,134 @@ def method_descriptor(self) -> MethodDescriptor: """ return self._method_descriptor + @staticmethod + def get_codec(**kwargs) -> tuple: + """ + Get the serialization and deserialization functions based on codec + :param kwargs: codec settings like transport_type, parameter_types, return_type + :return: serializer and deserializer functions + :rtype: Tuple[SerializingFunction, DeserializingFunction] + """ + return DubboTransportService.create_serialization_functions(**kwargs) + + @classmethod + def _infer_types_from_method(cls, method: Callable) -> tuple: + """ + Infer method name, parameter types, and return type from a callable + :param method: the method to analyze + :type method: Callable + :return: tuple of method name, parameter types, return type + :rtype: Tuple[str, List[Type], Type] + """ + try: + type_hints = get_type_hints(method) + sig = inspect.signature(method) + method_name = method.__name__ + params = list(sig.parameters.values()) + + if params and params[0].name == "self": + raise RpcMethodConfigurationError( + f"Method '{method_name}' appears to be an unbound method with 'self' parameter. " + "RPC methods should be bound methods (e.g., instance.method) or standalone functions. 
" + "If you're registering a class method, ensure you pass a bound method: " + "RpcMethodHandler.unary(instance.method) not RpcMethodHandler.unary(Class.method)" + ) + + params_types = [type_hints.get(p.name, Any) for p in params] + return_type = type_hints.get("return", Any) + return method_name, params_types, return_type + except RpcMethodConfigurationError: + raise + except Exception: + return method.__name__, [Any], Any + + @classmethod + def _create_method_descriptor( + cls, + method: Callable, + method_name: str, + params_types: List[Type], + return_type: Type, + rpc_type: str, + codec: Optional[str] = None, + param_encoder: Optional[DeserializingFunction] = None, + return_decoder: Optional[SerializingFunction] = None, + **kwargs, + ) -> MethodDescriptor: + """ + Create a MethodDescriptor with serialization configuration + :param method: the actual function/method + :param method_name: RPC method name + :param params_types: parameter type hints + :param return_type: return type hint + :param rpc_type: type of RPC (unary, stream, etc.) + :param codec: serialization codec (json, pb, etc.) 
+ :param param_encoder: deserialization function + :param return_decoder: serialization function + :param kwargs: additional codec args + :return: MethodDescriptor instance + :rtype: MethodDescriptor + """ + if param_encoder is None or return_decoder is None: + codec_kwargs = { + "transport_type": codec or "json", + "parameter_types": params_types, + "return_type": return_type, + **kwargs, + } + serializer, deserializer = cls.get_codec(**codec_kwargs) + request_deserializer = param_encoder or deserializer + response_serializer = return_decoder or serializer + + return MethodDescriptor( + callable_method=method, + method_name=method_name or method.__name__, + arg_serialization=(None, request_deserializer), + return_serialization=(response_serializer, None), + rpc_type=rpc_type + ) + @classmethod def unary( cls, method: Callable, method_name: Optional[str] = None, + params_types: Optional[List[Type]] = None, + return_type: Optional[Type] = None, + codec: Optional[str] = None, request_deserializer: Optional[DeserializingFunction] = None, response_serializer: Optional[SerializingFunction] = None, + **kwargs, ) -> "RpcMethodHandler": """ - Create a unary method handler - :param method: the method. - :type method: Callable - :param method_name: the method name. If not provided, the method name will be used. - :type method_name: Optional[str] - :param request_deserializer: the request deserializer. - :type request_deserializer: Optional[DeserializingFunction] - :param response_serializer: the response serializer. - :type response_serializer: Optional[SerializingFunction] - :return: the unary method handler. 
+ Register a unary RPC method handler + :param method: the callable function + :param method_name: RPC method name + :param params_types: input types + :param return_type: output type + :param codec: serialization codec + :param request_deserializer: custom deserializer + :param response_serializer: custom serializer + :return: RpcMethodHandler instance :rtype: RpcMethodHandler """ + inferred_name, inferred_param_types, inferred_return_type = cls._infer_types_from_method(method) + resolved_method_name = method_name or inferred_name + resolved_param_types = params_types or inferred_param_types + resolved_return_type = return_type or inferred_return_type + codec = codec or "json" + return cls( - MethodDescriptor( - callable_method=method, - method_name=method_name or method.__name__, - arg_serialization=(None, request_deserializer), - return_serialization=(response_serializer, None), + cls._create_method_descriptor( + method=method, + method_name=resolved_method_name, + params_types=resolved_param_types, + return_type=resolved_return_type, rpc_type=RpcTypes.UNARY.value, + codec=codec, + request_deserializer=request_deserializer, + response_serializer=response_serializer, + **kwargs, ) ) @@ -86,29 +194,42 @@ def client_stream( cls, method: Callable, method_name: Optional[str] = None, + params_types: Optional[List[Type]] = None, + return_type: Optional[Type] = None, + codec: Optional[str] = None, request_deserializer: Optional[DeserializingFunction] = None, response_serializer: Optional[SerializingFunction] = None, - ): + **kwargs, + ) -> "RpcMethodHandler": """ - Create a client stream method handler - :param method: the method. - :type method: Callable - :param method_name: the method name. If not provided, the method name will be used. - :type method_name: Optional[str] - :param request_deserializer: the request deserializer. - :type request_deserializer: Optional[DeserializingFunction] - :param response_serializer: the response serializer. 
- :type response_serializer: Optional[SerializingFunction] - :return: the client stream method handler. + Register a client-streaming RPC method handler + :param method: the callable function + :param method_name: RPC method name + :param params_types: input types + :param return_type: output type + :param codec: serialization codec + :param request_deserializer: custom deserializer + :param response_serializer: custom serializer + :return: RpcMethodHandler instance :rtype: RpcMethodHandler """ + inferred_name, inferred_param_types, inferred_return_type = cls._infer_types_from_method(method) + resolved_method_name = method_name or inferred_name + resolved_param_types = params_types or inferred_param_types + resolved_return_type = return_type or inferred_return_type + resolved_codec = codec or "json" + return cls( - MethodDescriptor( - callable_method=method, - method_name=method_name or method.__name__, - arg_serialization=(None, request_deserializer), - return_serialization=(response_serializer, None), + cls._create_method_descriptor( + method=method, + method_name=resolved_method_name, + params_types=resolved_param_types, + return_type=resolved_return_type, rpc_type=RpcTypes.CLIENT_STREAM.value, + codec=resolved_codec, + request_deserializer=request_deserializer, + response_serializer=response_serializer, + **kwargs, ) ) @@ -117,29 +238,42 @@ def server_stream( cls, method: Callable, method_name: Optional[str] = None, + params_types: Optional[List[Type]] = None, + return_type: Optional[Type] = None, + codec: Optional[str] = None, request_deserializer: Optional[DeserializingFunction] = None, response_serializer: Optional[SerializingFunction] = None, - ): + **kwargs, + ) -> "RpcMethodHandler": """ - Create a server stream method handler - :param method: the method. - :type method: Callable - :param method_name: the method name. If not provided, the method name will be used. - :type method_name: Optional[str] - :param request_deserializer: the request deserializer. 
- :type request_deserializer: Optional[DeserializingFunction] - :param response_serializer: the response serializer. - :type response_serializer: Optional[SerializingFunction] - :return: the server stream method handler. + Register a server-streaming RPC method handler + :param method: the callable function + :param method_name: RPC method name + :param params_types: input types + :param return_type: output type + :param codec: serialization codec + :param request_deserializer: custom deserializer + :param response_serializer: custom serializer + :return: RpcMethodHandler instance :rtype: RpcMethodHandler """ + inferred_name, inferred_param_types, inferred_return_type = cls._infer_types_from_method(method) + resolved_method_name = method_name or inferred_name + resolved_param_types = params_types or inferred_param_types + resolved_return_type = return_type or inferred_return_type + resolved_codec = codec or "json" + return cls( - MethodDescriptor( - callable_method=method, - method_name=method_name or method.__name__, - arg_serialization=(None, request_deserializer), - return_serialization=(response_serializer, None), + cls._create_method_descriptor( + method=method, + method_name=resolved_method_name, + params_types=resolved_param_types, + return_type=resolved_return_type, rpc_type=RpcTypes.SERVER_STREAM.value, + codec=resolved_codec, + request_deserializer=request_deserializer, + response_serializer=response_serializer, + **kwargs, ) ) @@ -148,46 +282,59 @@ def bi_stream( cls, method: Callable, method_name: Optional[str] = None, + params_types: Optional[List[Type]] = None, + return_type: Optional[Type] = None, + codec: Optional[str] = None, request_deserializer: Optional[DeserializingFunction] = None, response_serializer: Optional[SerializingFunction] = None, - ): + **kwargs, + ) -> "RpcMethodHandler": """ - Create a bidi stream method handler - :param method: the method. - :type method: Callable - :param method_name: the method name. 
If not provided, the method name will be used. - :type method_name: Optional[str] - :param request_deserializer: the request deserializer. - :type request_deserializer: Optional[DeserializingFunction] - :param response_serializer: the response serializer. - :type response_serializer: Optional[SerializingFunction] - :return: the bidi stream method handler. + Register a bidirectional streaming RPC method handler + :param method: the callable function + :param method_name: RPC method name + :param params_types: input types + :param return_type: output type + :param codec: serialization codec + :param request_deserializer: custom deserializer + :param response_serializer: custom serializer + :return: RpcMethodHandler instance :rtype: RpcMethodHandler """ + inferred_name, inferred_param_types, inferred_return_type = cls._infer_types_from_method(method) + resolved_method_name = method_name or inferred_name + resolved_param_types = params_types or inferred_param_types + resolved_return_type = return_type or inferred_return_type + resolved_codec = codec or "json" + return cls( - MethodDescriptor( - callable_method=method, - method_name=method_name or method.__name__, - arg_serialization=(None, request_deserializer), - return_serialization=(response_serializer, None), + cls._create_method_descriptor( + method=method, + method_name=resolved_method_name, + params_types=resolved_param_types, + return_type=resolved_return_type, rpc_type=RpcTypes.BI_STREAM.value, + codec=resolved_codec, + request_deserializer=request_deserializer, + response_serializer=response_serializer, + **kwargs, ) ) class RpcServiceHandler: """ - Rpc service handler + Rpc service handler that maps method names to their corresponding RpcMethodHandler. 
""" __slots__ = ["_service_name", "_method_handlers"] - def __init__(self, service_name: str, method_handlers: list[RpcMethodHandler]): + def __init__(self, service_name: str, method_handlers: List[RpcMethodHandler]): """ Initialize the RpcServiceHandler :param service_name: the name of the service. :type service_name: str - :param method_handlers: the method handlers. + :param method_handlers: list of RpcMethodHandler instances :type method_handlers: List[RpcMethodHandler] """ self._service_name = service_name @@ -209,8 +356,8 @@ def service_name(self) -> str: @property def method_handlers(self) -> dict[str, RpcMethodHandler]: """ - Get the method handlers - :return: the method handlers + Get the registered RPC method handlers + :return: mapping of method names to handlers :rtype: Dict[str, RpcMethodHandler] """ return self._method_handlers diff --git a/src/dubbo/server.py b/src/dubbo/server.py index b7c4dee..52a3507 100644 --- a/src/dubbo/server.py +++ b/src/dubbo/server.py @@ -81,4 +81,4 @@ def start(self): self._protocol.export(self._url) - self._exported = True + self._exported = True \ No newline at end of file