Skip to content

Commit

Permalink
[BACKPORT] Add serialize support for complex type (#519)
Browse files Browse the repository at this point in the history
* Add serialize support for complex type (#488)

* sync more serialize part from master
  • Loading branch information
Xuye (Chris) Qin authored and wjsi committed Jul 6, 2019
1 parent bc1765e commit 6f9b2c2
Show file tree
Hide file tree
Showing 9 changed files with 638 additions and 150 deletions.
20 changes: 17 additions & 3 deletions mars/serialize/core.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ cpdef enum PrimitiveType:
float64 = 12
bytes = 13
unicode = 14
complex64 = 24
complex128 = 25


cpdef enum ExtendType:
Expand Down Expand Up @@ -75,14 +77,18 @@ cdef class ValueType:
pass


cdef class SelfReferenceOverwritten(Exception):
pass


cdef class Field:
cdef object tag
cdef object default_val
cdef str _tag_name
cdef object _type
cdef object _model_cls

cdef public bint weak_ref
cdef public object model
cdef public str attr
cdef public object on_serialize
cdef public object on_deserialize
Expand Down Expand Up @@ -148,6 +154,14 @@ cdef class Float64Field(Field):
pass


cdef class Complex64Field(Field):
pass


cdef class Complex128Field(Field):
pass


cdef class StringField(Field):
pass

Expand Down Expand Up @@ -181,7 +195,7 @@ cdef class DataTypeField(Field):


cdef class ListField(Field):
cdef object _nest_ref
cdef public object _nest_ref


cdef class TupleField(Field):
Expand All @@ -193,7 +207,7 @@ cdef class DictField(Field):


cdef class ReferenceField(Field):
cdef object _model
cdef public object _model


cdef class OneOfField(Field):
Expand Down
99 changes: 90 additions & 9 deletions mars/serialize/core.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import importlib
import inspect
import copy
from collections import Iterable

from ..compat import six, OrderedDict
Expand Down Expand Up @@ -100,6 +101,8 @@ cdef class ValueType:
bytes = PrimitiveType.bytes
unicode = PrimitiveType.unicode
string = PrimitiveType.unicode if PY3 else PrimitiveType.bytes
complex64 = PrimitiveType.complex64
complex128 = PrimitiveType.complex128
slice = ExtendType.slice
arr = ExtendType.arr
dtype = ExtendType.dtype
Expand All @@ -119,7 +122,7 @@ cdef class ValueType:

cdef class Field:
def __init__(self, tag, default=None, bint weak_ref=False, on_serialize=None, on_deserialize=None):
self.model = None
self._model_cls = None
self.attr = None

self.tag = tag
Expand All @@ -131,6 +134,14 @@ cdef class Field:
self.on_serialize = on_serialize
self.on_deserialize = on_deserialize

@property
def model(self):
return self._model_cls

@model.setter
def model(self, model_cls):
self._model_cls = model_cls

cpdef str tag_name(self, Provider provider):
if self._tag_name is None:
return self.tag(provider)
Expand Down Expand Up @@ -267,6 +278,22 @@ cdef class Float64Field(Field):
self._type = ValueType.float64


cdef class Complex64Field(Field):
def __init__(self, tag, default=None, bint weak_ref=False, on_serialize=None, on_deserialize=None):
super(Complex64Field, self).__init__(
tag, default=default, weak_ref=weak_ref,
on_serialize=on_serialize, on_deserialize=on_deserialize)
self._type = ValueType.complex64


cdef class Complex128Field(Field):
def __init__(self, tag, default=None, bint weak_ref=False, on_serialize=None, on_deserialize=None):
super(Complex128Field, self).__init__(
tag, default=default, weak_ref=weak_ref,
on_serialize=on_serialize, on_deserialize=on_deserialize)
self._type = ValueType.complex128


cdef class StringField(Field):
def __init__(self, tag, default=None, bint weak_ref=False, on_serialize=None, on_deserialize=None):
super(StringField, self).__init__(
Expand Down Expand Up @@ -354,6 +381,18 @@ cdef class ListField(Field):
else:
self._type = ValueType.list(tp)

@property
def model(self):
return self._model_cls

@model.setter
def model(self, new_model_cls):
if getattr(self, '_nest_ref', None) is not None and \
self._nest_ref.model == 'self' and self._model_cls is not None and \
new_model_cls is not None:
raise SelfReferenceOverwritten('self reference is overwritten')
self._model_cls = new_model_cls

@property
def type(self):
if self._type is None:
Expand Down Expand Up @@ -397,6 +436,17 @@ cdef class ReferenceField(Field):
self._type = None
self._model = model

@property
def model(self):
return self._model_cls

@model.setter
def model(self, new_model_cls):
if getattr(self, '_model', None) == 'self' and \
self._model_cls is not None and new_model_cls is not None:
raise SelfReferenceOverwritten('self reference is overwritten')
self._model_cls = new_model_cls

@property
def type(self):
if not self._type:
Expand Down Expand Up @@ -432,6 +482,42 @@ cdef class OneOfField(Field):
self._type = ValueType.oneof(*[f.type for f in self.fields])
return self._type

@property
def attrs(self):
return [f.attr for f in self.fields]


cdef inline set_model(dict fields, cls):
cdef str slot
cdef bint modified

for slot, field in fields.items():
if not isinstance(field, OneOfField):
try:
field.model = cls
except SelfReferenceOverwritten:
field = copy.copy(field)
# reset old model after copy
field.model = None
field.model = cls
cls._FIELDS[slot] = field
else:
one_field_fields = []
modified = False
for f in field.fields:
try:
f.model = cls
except SelfReferenceOverwritten:
f = copy.copy(f)
# reset old model after copy
f.model = None
f.model = cls
modified = True
f.attr = field.attr
one_field_fields.append(f)
if modified:
field.fields = one_field_fields


class SerializableMetaclass(type):
def __new__(mcs, str name, tuple bases, dict kv):
Expand Down Expand Up @@ -477,13 +563,7 @@ class SerializableMetaclass(type):
kv['__slots__'] = tuple(slots)

cls = type.__new__(mcs, name, bases, kv)
for field in fields.values():
if not isinstance(field, OneOfField):
field.model = cls
else:
for f in field.fields:
f.model = cls
f.attr = field.attr
set_model(fields, cls)
return cls


Expand All @@ -495,7 +575,8 @@ class Serializable(six.with_metaclass(SerializableMetaclass, Base)):
if provider.type == ProviderType.json:
return dict

raise TypeError('Unknown provider type: {0}'.format(provider.type.value))
raise TypeError('Unknown provider type `{0}` for class `{1}`'.format(
ProviderType(provider.type).name, cls.__name__))

def serialize(self, Provider provider, obj=None):
return provider.serialize_model(self, obj=obj)
Expand Down
39 changes: 36 additions & 3 deletions mars/serialize/jsonserializer.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ cdef dict EXTEND_TYPE_TO_NAME = {
ValueType.key: 'key',
ValueType.datetime64: 'datetime64',
ValueType.timedelta64: 'timedelta64',
ValueType.complex64: 'complex64',
ValueType.complex128: 'complex128',
}


Expand Down Expand Up @@ -260,6 +262,18 @@ cdef class JsonSerializeProvider(Provider):
cdef inline _deserialize_timedelta64(self, obj, list callbacks):
return self._deserialize_datetime64_timedelta64(obj, callbacks)

cdef inline dict _serialize_complex(self, value, tp):
return {
'type': _get_name(tp),
'value': (value.real, value.imag)
}

cdef inline _deserialize_complex(self, obj, list callbacks):
cdef list v

v = obj['value']
return complex(*v)

cdef inline object _serialize_typed_value(self, value, tp, bint weak_ref=False):
if type(tp) not in (List, Tuple, Dict) and weak_ref:
# not iterable, and is weak ref
Expand All @@ -276,6 +290,8 @@ cdef class JsonSerializeProvider(Provider):
return value
elif type(tp) is Identity:
return self._serialize_typed_value(value, tp.type, weak_ref=weak_ref)
elif tp in {ValueType.complex64, ValueType.complex128}:
return self._serialize_complex(value, tp)
elif tp is ValueType.slice:
return self._serialize_slice(value)
elif tp is ValueType.arr:
Expand Down Expand Up @@ -319,6 +335,8 @@ cdef class JsonSerializeProvider(Provider):
return value
elif isinstance(value, float):
return value
elif isinstance(value, complex):
return self._serialize_complex(value, ValueType.complex128)
elif isinstance(value, slice):
return self._serialize_slice(value)
elif isinstance(value, np.ndarray):
Expand All @@ -337,6 +355,8 @@ cdef class JsonSerializeProvider(Provider):
return self._serialize_datetime64(value)
elif isinstance(value, np.timedelta64):
return self._serialize_timedelta64(value)
elif isinstance(value, np.number):
return self._serialize_untyped_value(value.item())
else:
raise TypeError('Unknown type to serialize: {0}'.format(type(value)))

Expand Down Expand Up @@ -364,8 +384,12 @@ cdef class JsonSerializeProvider(Provider):
field_val = getattr(model_instance, field.attr)
if field.weak_ref:
field_val = field_val()
value = self._on_serial(field, field_val)
value.serialize(self, new_obj)
if field_val is not None:
if not isinstance(field_val, field.type.model):
raise TypeError('Does not match type for reference field {0}: '
'expect {1}, got {2}'.format(tag, field.type.model, type(field_val)))
value = self._on_serial(field, field_val)
value.serialize(self, new_obj)
elif isinstance(field, OneOfField):
has_val = False
field_val = getattr(model_instance, field.attr, None)
Expand All @@ -384,6 +408,9 @@ cdef class JsonSerializeProvider(Provider):
new_obj = obj[tag] = dict()
value.serialize(self, new_obj)
return
if not has_val and value is not None:
raise ValueError('Value {0} cannot match any type for OneOfField `{1}`'.format(
value, field.tag_name(self)))
elif isinstance(field, ListField) and type(field.type.type) == Reference:
tag = field.tag_name(self)
value = self._on_serial(field, getattr(model_instance, field.attr, None))
Expand All @@ -394,7 +421,11 @@ cdef class JsonSerializeProvider(Provider):
if field.weak_ref:
val = val()
if val is not None:
new_obj.append(val.serialize(self, dict()))
if isinstance(val, field.type.type.model):
new_obj.append(val.serialize(self, dict()))
else:
raise TypeError('Does not match type for reference in list field {0}: '
'expect {1}, got {2}'.format(tag, field.type.type.model, type(val)))
else:
new_obj.append(None)
else:
Expand All @@ -417,6 +448,8 @@ cdef class JsonSerializeProvider(Provider):

if tp is ValueType.bytes:
return ref(base64.b64decode(obj['value']))
elif tp in {ValueType.complex64, ValueType.complex128}:
return ref(self._deserialize_complex(obj, callbacks))
elif tp is ValueType.slice:
return ref(self._deserialize_slice(obj, callbacks))
elif tp is ValueType.arr:
Expand Down
Loading

0 comments on commit 6f9b2c2

Please sign in to comment.