Skip to content

Commit

Permalink
recursive map[string]any decode fix
Browse files Browse the repository at this point in the history
  • Loading branch information
loicalleyne committed Nov 8, 2024
1 parent f7d00ad commit 8c5878c
Show file tree
Hide file tree
Showing 9 changed files with 302 additions and 6 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ go.work.sum
.env

avro
experiments
map.go
*.schema
*.pgo
10 changes: 6 additions & 4 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,16 @@ type Student struct {
ID int64
Day int32
School
Addresses []AddressType
}

func main() {
stu := Student{
Name: "StudentName",
Age: 25,
ID: 123456,
Day: 123,
Name: "StudentName",
Age: 25,
ID: 123456,
Day: 123,
Addresses: []AddressType{{Country: "Azerbijjan"}, {Country: "Zimbabwe"}},
}
sch := School{
Name: "SchoolName",
Expand Down
9 changes: 9 additions & 0 deletions debug/assert_off.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
//go:build !assert
// +build !assert

package debug

// Assert will panic with msg if cond is false.
//
// msg must be a string, func() string or fmt.Stringer.
func Assert(cond bool, msg interface{}) {}
13 changes: 13 additions & 0 deletions debug/assert_on.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
//go:build assert
// +build assert

package debug

// Assert will panic with msg if cond is false.
//
// msg must be a string, func() string or fmt.Stringer.
func Assert(cond bool, msg interface{}) {
if !cond {
panic(getStringValue(msg))
}
}
9 changes: 9 additions & 0 deletions debug/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
/*
Package debug provides APIs for conditional runtime assertions and debug logging.
# Using Assert
To enable runtime assertions, build with the assert tag. When the assert tag is omitted,
the code for the assertion will be omitted from the binary.
*/
package debug
22 changes: 22 additions & 0 deletions debug/util.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
//go:build debug || assert
// +build debug assert

package debug

import "fmt"

func getStringValue(v interface{}) string {
switch a := v.(type) {
case func() string:
return a()

case string:
return a

case fmt.Stringer:
return a.String()

default:
panic(fmt.Sprintf("unexpected type, %t", v))
}
}
227 changes: 227 additions & 0 deletions reader/encoder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
package reader

import (
"encoding"
"errors"
"fmt"
"reflect"
"strings"

"github.com/go-viper/mapstructure/v2"
)

const (
tagNameMapStructure = "mapstructure"
optionSeparator = ","
optionOmitEmpty = "omitempty"
optionSquash = "squash"
optionRemain = "remain"
optionSkip = "-"
)

var (
errNonStringEncodedKey = errors.New("non string-encoded key")
)

// tagInfo stores the mapstructure tag details.
type tagInfo struct {
name string
omitEmpty bool
squash bool
}

// An Encoder takes structured data and converts it into an
// interface following the mapstructure tags.
type Encoder struct {
config *EncoderConfig
}

// EncoderConfig is the configuration used to create a new encoder.
type EncoderConfig struct {
// EncodeHook, if set, is a way to provide custom encoding. It
// will be called before structs and primitive types.
EncodeHook mapstructure.DecodeHookFunc
}

// New returns a new encoder for the configuration.
func New(cfg *EncoderConfig) *Encoder {
return &Encoder{config: cfg}
}

// Encode takes the input and uses reflection to encode it to
// an interface based on the mapstructure spec.
func (e *Encoder) Encode(input any) (any, error) {
return e.encode(reflect.ValueOf(input))
}

// encode processes the value based on the reflect.Kind.
func (e *Encoder) encode(value reflect.Value) (any, error) {
if value.IsValid() {
switch value.Kind() {
case reflect.Interface, reflect.Ptr:
return e.encode(value.Elem())
case reflect.Map:
return e.encodeMap(value)
case reflect.Slice:
return e.encodeSlice(value)
case reflect.Struct:
return e.encodeStruct(value)
default:
return e.encodeHook(value)
}
}
return nil, nil
}

// encodeHook calls the EncodeHook in the EncoderConfig with the value passed in.
// This is called before processing structs and for primitive data types.
func (e *Encoder) encodeHook(value reflect.Value) (any, error) {
if e.config != nil && e.config.EncodeHook != nil {
out, err := mapstructure.DecodeHookExec(e.config.EncodeHook, value, value)
if err != nil {
return nil, fmt.Errorf("error running encode hook: %w", err)
}
return out, nil
}
return value.Interface(), nil
}

// encodeStruct encodes the struct by iterating over the fields, getting the
// mapstructure tagInfo for each exported field, and encoding the value.
func (e *Encoder) encodeStruct(value reflect.Value) (any, error) {
if value.Kind() != reflect.Struct {
return nil, &reflect.ValueError{
Method: "encodeStruct",
Kind: value.Kind(),
}
}
out, err := e.encodeHook(value)
if err != nil {
return nil, err
}
value = reflect.ValueOf(out)
// if the output of encodeHook is no longer a struct,
// call encode against it.
if value.Kind() != reflect.Struct {
return e.encode(value)
}
result := make(map[string]any)
for i := 0; i < value.NumField(); i++ {
field := value.Field(i)
if field.CanInterface() {
info := getTagInfo(value.Type().Field(i))
if (info.omitEmpty && field.IsZero()) || info.name == optionSkip {
continue
}
encoded, err := e.encode(field)
if err != nil {
return nil, fmt.Errorf("error encoding field %q: %w", info.name, err)
}
if info.squash {
if m, ok := encoded.(map[string]any); ok {
for k, v := range m {
result[k] = v
}
}
} else {
result[info.name] = encoded
}
}
}
return result, nil
}

// encodeSlice iterates over the slice and encodes each of the elements.
func (e *Encoder) encodeSlice(value reflect.Value) (any, error) {
if value.Kind() != reflect.Slice {
return nil, &reflect.ValueError{
Method: "encodeSlice",
Kind: value.Kind(),
}
}
result := make([]any, value.Len())
for i := 0; i < value.Len(); i++ {
var err error
if result[i], err = e.encode(value.Index(i)); err != nil {
return nil, fmt.Errorf("error encoding element in slice at index %d: %w", i, err)
}
}
return result, nil
}

// encodeMap encodes a map by encoding the key and value. Returns errNonStringEncodedKey
// if the key is not encoded into a string.
func (e *Encoder) encodeMap(value reflect.Value) (any, error) {
if value.Kind() != reflect.Map {
return nil, &reflect.ValueError{
Method: "encodeMap",
Kind: value.Kind(),
}
}
result := make(map[string]any)
iterator := value.MapRange()
for iterator.Next() {
encoded, err := e.encode(iterator.Key())
if err != nil {
return nil, fmt.Errorf("error encoding key: %w", err)
}

v := reflect.ValueOf(encoded)
var key string

switch v.Kind() {
case reflect.String:
key = v.String()
default:
return nil, fmt.Errorf("%w, key: %q, kind: %v, type: %T", errNonStringEncodedKey, iterator.Key().Interface(), iterator.Key().Kind(), encoded)
}

if _, ok := result[key]; ok {
return nil, fmt.Errorf("duplicate key %q while encoding", key)
}
if result[key], err = e.encode(iterator.Value()); err != nil {
return nil, fmt.Errorf("error encoding map value for key %q: %w", key, err)
}
}
return result, nil
}

// getTagInfo looks up the mapstructure tag and uses that if available.
// Uses the lowercase field if not found. Checks for omitempty and squash.
func getTagInfo(field reflect.StructField) *tagInfo {
info := tagInfo{}
if tag, ok := field.Tag.Lookup(tagNameMapStructure); ok {
options := strings.Split(tag, optionSeparator)
info.name = options[0]
if len(options) > 1 {
for _, option := range options[1:] {
switch option {
case optionOmitEmpty:
info.omitEmpty = true
case optionSquash, optionRemain:
info.squash = true
}
}
}
} else {
info.name = strings.ToLower(field.Name)
}
return &info
}

// TextMarshalerHookFunc returns a DecodeHookFuncValue that checks
// for the encoding.TextMarshaler interface and calls the MarshalText
// function if found.
func TextMarshalerHookFunc() mapstructure.DecodeHookFuncValue {
return func(from reflect.Value, _ reflect.Value) (any, error) {
marshaler, ok := from.Interface().(encoding.TextMarshaler)
if !ok {
return from.Interface(), nil
}
out, err := marshaler.MarshalText()
if err != nil {
return nil, err
}
return string(out), nil
}
}
6 changes: 4 additions & 2 deletions reader/input.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,12 @@ func InputMap(a any) (map[string]any, error) {
return nil, fmt.Errorf("%v : %v", ErrInvalidInput, err)
}
default:
err := mapstructure.Decode(a, &m)
ms := New(&EncoderConfig{EncodeHook: mapstructure.RecursiveStructToMapHookFunc()})
enc, err := ms.Encode(a)
if err != nil {
return nil, fmt.Errorf("%v : %v", ErrInvalidInput, err)
return nil, fmt.Errorf("Error decoding to map[string]interface{}: %v", err)
}
return enc.(map[string]any), nil
}
return m, nil
}
11 changes: 11 additions & 0 deletions types.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,24 @@ func goType2Arrow(f *fieldPos, gt any) arrow.DataType {
// the set of all complex numbers with float32 real and imaginary parts
case complex64:
// TO-DO
f.arrowType = arrow.NULL
f.err = fmt.Errorf("%v : %v", ErrUndefinedFieldType, f.namePath())
dt = arrow.BinaryTypes.Binary
// the set of all complex numbers with float64 real and imaginary parts
case complex128:
// TO-DO
f.arrowType = arrow.NULL
f.err = fmt.Errorf("%v : %v", ErrUndefinedFieldType, f.namePath())
dt = arrow.BinaryTypes.Binary
case nil:
f.arrowType = arrow.NULL
f.err = fmt.Errorf("%v : %v", ErrUndefinedFieldType, f.namePath())
dt = arrow.BinaryTypes.Binary
default:
// Catch-all for exotic unsupported types - ie. input field is a func
f.arrowType = arrow.NULL
f.err = fmt.Errorf("%v : %v", ErrUndefinedFieldType, f.namePath())
dt = arrow.BinaryTypes.Binary
}
return dt
}
Expand Down

0 comments on commit 8c5878c

Please sign in to comment.