From 8c5878c1bba1554c2621e87ea6517784f5f0419f Mon Sep 17 00:00:00 2001 From: loicalleyne Date: Fri, 8 Nov 2024 17:49:35 -0500 Subject: [PATCH] recursive map[string]any decode fix --- .gitignore | 1 + cmd/main.go | 10 +- debug/assert_off.go | 9 ++ debug/assert_on.go | 13 +++ debug/doc.go | 9 ++ debug/util.go | 22 +++++ reader/encoder.go | 227 ++++++++++++++++++++++++++++++++++++++++++++ reader/input.go | 6 +- types.go | 11 +++ 9 files changed, 302 insertions(+), 6 deletions(-) create mode 100644 debug/assert_off.go create mode 100644 debug/assert_on.go create mode 100644 debug/doc.go create mode 100644 debug/util.go create mode 100644 reader/encoder.go diff --git a/.gitignore b/.gitignore index 2031b4f..72980f9 100644 --- a/.gitignore +++ b/.gitignore @@ -25,6 +25,7 @@ go.work.sum .env avro +experiments map.go *.schema *.pgo \ No newline at end of file diff --git a/cmd/main.go b/cmd/main.go index 57464cb..fa8b273 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -26,14 +26,16 @@ type Student struct { ID int64 Day int32 School + Addresses []AddressType } func main() { stu := Student{ - Name: "StudentName", - Age: 25, - ID: 123456, - Day: 123, + Name: "StudentName", + Age: 25, + ID: 123456, + Day: 123, + Addresses: []AddressType{{Country: "Azerbijjan"}, {Country: "Zimbabwe"}}, } sch := School{ Name: "SchoolName", diff --git a/debug/assert_off.go b/debug/assert_off.go new file mode 100644 index 0000000..7236066 --- /dev/null +++ b/debug/assert_off.go @@ -0,0 +1,9 @@ +//go:build !assert +// +build !assert + +package debug + +// Assert will panic with msg if cond is false. +// +// msg must be a string, func() string or fmt.Stringer. +func Assert(cond bool, msg interface{}) {} diff --git a/debug/assert_on.go b/debug/assert_on.go new file mode 100644 index 0000000..164ce3b --- /dev/null +++ b/debug/assert_on.go @@ -0,0 +1,13 @@ +//go:build assert +// +build assert + +package debug + +// Assert will panic with msg if cond is false. +// +// msg must be a string, func() string or fmt.Stringer. +func Assert(cond bool, msg interface{}) { + if !cond { + panic(getStringValue(msg)) + } +} diff --git a/debug/doc.go b/debug/doc.go new file mode 100644 index 0000000..9ff7166 --- /dev/null +++ b/debug/doc.go @@ -0,0 +1,9 @@ +/* +Package debug provides APIs for conditional runtime assertions and debug logging. + +# Using Assert + +To enable runtime assertions, build with the assert tag. When the assert tag is omitted, +the code for the assertion will be omitted from the binary. +*/ +package debug diff --git a/debug/util.go b/debug/util.go new file mode 100644 index 0000000..5baf29d --- /dev/null +++ b/debug/util.go @@ -0,0 +1,22 @@ +//go:build debug || assert +// +build debug assert + +package debug + +import "fmt" + +func getStringValue(v interface{}) string { + switch a := v.(type) { + case func() string: + return a() + + case string: + return a + + case fmt.Stringer: + return a.String() + + default: + panic(fmt.Sprintf("unexpected type, %t", v)) + } +} diff --git a/reader/encoder.go b/reader/encoder.go new file mode 100644 index 0000000..a759fce --- /dev/null +++ b/reader/encoder.go @@ -0,0 +1,227 @@ +package reader + +import ( + "encoding" + "errors" + "fmt" + "reflect" + "strings" + + "github.com/go-viper/mapstructure/v2" +) + +const ( + tagNameMapStructure = "mapstructure" + optionSeparator = "," + optionOmitEmpty = "omitempty" + optionSquash = "squash" + optionRemain = "remain" + optionSkip = "-" +) + +var ( + errNonStringEncodedKey = errors.New("non string-encoded key") +) + +// tagInfo stores the mapstructure tag details. +type tagInfo struct { + name string + omitEmpty bool + squash bool +} + +// An Encoder takes structured data and converts it into an +// interface following the mapstructure tags. +type Encoder struct { + config *EncoderConfig +} + +// EncoderConfig is the configuration used to create a new encoder. +type EncoderConfig struct { + // EncodeHook, if set, is a way to provide custom encoding. It + // will be called before structs and primitive types. + EncodeHook mapstructure.DecodeHookFunc +} + +// New returns a new encoder for the configuration. +func New(cfg *EncoderConfig) *Encoder { + return &Encoder{config: cfg} +} + +// Encode takes the input and uses reflection to encode it to +// an interface based on the mapstructure spec. +func (e *Encoder) Encode(input any) (any, error) { + return e.encode(reflect.ValueOf(input)) +} + +// encode processes the value based on the reflect.Kind. +func (e *Encoder) encode(value reflect.Value) (any, error) { + if value.IsValid() { + switch value.Kind() { + case reflect.Interface, reflect.Ptr: + return e.encode(value.Elem()) + case reflect.Map: + return e.encodeMap(value) + case reflect.Slice: + return e.encodeSlice(value) + case reflect.Struct: + return e.encodeStruct(value) + default: + return e.encodeHook(value) + } + } + return nil, nil +} + +// encodeHook calls the EncodeHook in the EncoderConfig with the value passed in. +// This is called before processing structs and for primitive data types. +func (e *Encoder) encodeHook(value reflect.Value) (any, error) { + if e.config != nil && e.config.EncodeHook != nil { + out, err := mapstructure.DecodeHookExec(e.config.EncodeHook, value, value) + if err != nil { + return nil, fmt.Errorf("error running encode hook: %w", err) + } + return out, nil + } + return value.Interface(), nil +} + +// encodeStruct encodes the struct by iterating over the fields, getting the +// mapstructure tagInfo for each exported field, and encoding the value. +func (e *Encoder) encodeStruct(value reflect.Value) (any, error) { + if value.Kind() != reflect.Struct { + return nil, &reflect.ValueError{ + Method: "encodeStruct", + Kind: value.Kind(), + } + } + out, err := e.encodeHook(value) + if err != nil { + return nil, err + } + value = reflect.ValueOf(out) + // if the output of encodeHook is no longer a struct, + // call encode against it. + if value.Kind() != reflect.Struct { + return e.encode(value) + } + result := make(map[string]any) + for i := 0; i < value.NumField(); i++ { + field := value.Field(i) + if field.CanInterface() { + info := getTagInfo(value.Type().Field(i)) + if (info.omitEmpty && field.IsZero()) || info.name == optionSkip { + continue + } + encoded, err := e.encode(field) + if err != nil { + return nil, fmt.Errorf("error encoding field %q: %w", info.name, err) + } + if info.squash { + if m, ok := encoded.(map[string]any); ok { + for k, v := range m { + result[k] = v + } + } + } else { + result[info.name] = encoded + } + } + } + return result, nil +} + +// encodeSlice iterates over the slice and encodes each of the elements. +func (e *Encoder) encodeSlice(value reflect.Value) (any, error) { + if value.Kind() != reflect.Slice { + return nil, &reflect.ValueError{ + Method: "encodeSlice", + Kind: value.Kind(), + } + } + result := make([]any, value.Len()) + for i := 0; i < value.Len(); i++ { + var err error + if result[i], err = e.encode(value.Index(i)); err != nil { + return nil, fmt.Errorf("error encoding element in slice at index %d: %w", i, err) + } + } + return result, nil +} + +// encodeMap encodes a map by encoding the key and value. Returns errNonStringEncodedKey +// if the key is not encoded into a string. +func (e *Encoder) encodeMap(value reflect.Value) (any, error) { + if value.Kind() != reflect.Map { + return nil, &reflect.ValueError{ + Method: "encodeMap", + Kind: value.Kind(), + } + } + result := make(map[string]any) + iterator := value.MapRange() + for iterator.Next() { + encoded, err := e.encode(iterator.Key()) + if err != nil { + return nil, fmt.Errorf("error encoding key: %w", err) + } + + v := reflect.ValueOf(encoded) + var key string + + switch v.Kind() { + case reflect.String: + key = v.String() + default: + return nil, fmt.Errorf("%w, key: %q, kind: %v, type: %T", errNonStringEncodedKey, iterator.Key().Interface(), iterator.Key().Kind(), encoded) + } + + if _, ok := result[key]; ok { + return nil, fmt.Errorf("duplicate key %q while encoding", key) + } + if result[key], err = e.encode(iterator.Value()); err != nil { + return nil, fmt.Errorf("error encoding map value for key %q: %w", key, err) + } + } + return result, nil +} + +// getTagInfo looks up the mapstructure tag and uses that if available. +// Uses the lowercase field if not found. Checks for omitempty and squash. +func getTagInfo(field reflect.StructField) *tagInfo { + info := tagInfo{} + if tag, ok := field.Tag.Lookup(tagNameMapStructure); ok { + options := strings.Split(tag, optionSeparator) + info.name = options[0] + if len(options) > 1 { + for _, option := range options[1:] { + switch option { + case optionOmitEmpty: + info.omitEmpty = true + case optionSquash, optionRemain: + info.squash = true + } + } + } + } else { + info.name = strings.ToLower(field.Name) + } + return &info +} + +// TextMarshalerHookFunc returns a DecodeHookFuncValue that checks +// for the encoding.TextMarshaler interface and calls the MarshalText +// function if found. +func TextMarshalerHookFunc() mapstructure.DecodeHookFuncValue { + return func(from reflect.Value, _ reflect.Value) (any, error) { + marshaler, ok := from.Interface().(encoding.TextMarshaler) + if !ok { + return from.Interface(), nil + } + out, err := marshaler.MarshalText() + if err != nil { + return nil, err + } + return string(out), nil + } +} diff --git a/reader/input.go b/reader/input.go index c1643b1..bbf7206 100644 --- a/reader/input.go +++ b/reader/input.go @@ -43,10 +43,12 @@ func InputMap(a any) (map[string]any, error) { return nil, fmt.Errorf("%v : %v", ErrInvalidInput, err) } default: - err := mapstructure.Decode(a, &m) + ms := New(&EncoderConfig{EncodeHook: mapstructure.RecursiveStructToMapHookFunc()}) + enc, err := ms.Encode(a) if err != nil { - return nil, fmt.Errorf("%v : %v", ErrInvalidInput, err) + return nil, fmt.Errorf("Error decoding to map[string]interface{}: %v", err) } + return enc.(map[string]any), nil } return m, nil } diff --git a/types.go b/types.go index 5df70b1..315f999 100644 --- a/types.go +++ b/types.go @@ -116,13 +116,24 @@ func goType2Arrow(f *fieldPos, gt any) arrow.DataType { // the set of all complex numbers with float32 real and imaginary parts case complex64: // TO-DO + f.arrowType = arrow.NULL + f.err = fmt.Errorf("%v : %v", ErrUndefinedFieldType, f.namePath()) + dt = arrow.BinaryTypes.Binary // the set of all complex numbers with float64 real and imaginary parts case complex128: // TO-DO + f.arrowType = arrow.NULL + f.err = fmt.Errorf("%v : %v", ErrUndefinedFieldType, f.namePath()) + dt = arrow.BinaryTypes.Binary case nil: f.arrowType = arrow.NULL f.err = fmt.Errorf("%v : %v", ErrUndefinedFieldType, f.namePath()) dt = arrow.BinaryTypes.Binary + default: + // Catch-all for exotic unsupported types - ie. input field is a func + f.arrowType = arrow.NULL + f.err = fmt.Errorf("%v : %v", ErrUndefinedFieldType, f.namePath()) + dt = arrow.BinaryTypes.Binary } return dt }