Skip to content

Commit

Permalink
fixed issue of overwriting a slice
Browse files Browse the repository at this point in the history
  • Loading branch information
JAicewizard committed May 23, 2020
1 parent 3b5ab61 commit 5f89800
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 9 deletions.
6 changes: 6 additions & 0 deletions tt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,12 @@ var testCases = []testCase{
{3, 0, 0, 0, 18, 0, 2, 2, 4, 1, 104, 105, 2, 111, 111, 112, 115, 0, 0, 5, 18, 2, 69, 109, 98, 101, 100, 1, 3, 2, 1, 108, 111, 108, 2, 104, 105, 0},
{3, 0, 0, 0, 18, 0, 2, 0, 5, 18, 2, 69, 109, 98, 101, 100, 1, 3, 2, 1, 108, 111, 108, 2, 104, 105, 0, 2, 4, 1, 104, 105, 2, 111, 111, 112, 115, 0}},
},
{
name: "testSingleValueToInterface",
data: "hello",
bytes: [][]byte{
{3, 0, 5, 0, 1, 104, 101, 108, 108, 111, 0, 0}},
},
}

func testStructDecode(t *testing.T, testcase testCase) {
Expand Down
57 changes: 48 additions & 9 deletions v3.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ type V3Encoder struct {
out io.Writer
varintbuf *[binary.MaxVarintLen64 + 1]byte
sync.Mutex
typeCache map[string]map[string]int
}

var v3StreamHeader = []byte{version3, 1 << 7}
Expand All @@ -32,6 +33,7 @@ func NewV3Encoder(out io.Writer, isStream bool) *V3Encoder {
return &V3Encoder{
out: out,
varintbuf: &[binary.MaxVarintLen64 + 1]byte{},
typeCache: map[string]map[string]int{},
}
}

Expand All @@ -42,6 +44,7 @@ func Encodev3(d interface{}, out io.Writer) error {
enc := &V3Encoder{
out: out,
varintbuf: &[binary.MaxVarintLen64 + 1]byte{},
typeCache: map[string]map[string]int{},
}
//We dont have to lock/unlock since we know we are the only one witha acces
return enc.encodeValuev3(d, v3.Key{})
Expand Down Expand Up @@ -256,7 +259,15 @@ func (enc *V3Encoder) encodeValuev3(d interface{}, k v3.Key) error {
}
value.Vtype = v3.BytesT
} else if kind == reflect.Struct {
usableFields := getStructFields(val)
name := val.Type().String()
var usableFields map[string]int
if v, ok := enc.typeCache[name]; ok {
usableFields = v
} else {
usableFields = getStructFields(val)
enc.typeCache[name] = usableFields
}

value.Childrenn = uint64(len(usableFields))
value.Vtype = v3.MapT
alreadyEncoded = true
Expand Down Expand Up @@ -300,6 +311,7 @@ type V3Decoder struct {
isStream bool
didDecode bool
in v3.Reader
typeCache map[string]map[string]int
sync.Mutex
}

Expand All @@ -308,8 +320,9 @@ type V3Decoder struct {
//Initializing the decoder blocks until at least the first 2 bytes are read.
func NewV3Decoder(in v3.Reader, init bool) *V3Decoder {
dec := V3Decoder{
didInit: !init,
in: in,
didInit: !init,
in: in,
typeCache: map[string]map[string]int{},
}
if init {
dec.Init()
Expand Down Expand Up @@ -543,7 +556,7 @@ func (dec *V3Decoder) decodeValuev3(v v3.Value, e reflect.Value, yetToRead *uint
}
case v3.Int64T:
val := v3.Int64FromBytes(v.Value)
if e.Kind() != reflect.Int64 || e.Kind() != reflect.Int {
if e.Kind() != reflect.Int64 && e.Kind() != reflect.Int {
if e.Kind() != reflect.Interface || e.Type().NumMethod() != 0 {
return errors.New("TT: cannot unmarshal int64 into " + e.Kind().String() + " Go type")
}
Expand Down Expand Up @@ -647,7 +660,7 @@ func (dec *V3Decoder) decodeValuev3(v v3.Value, e reflect.Value, yetToRead *uint
if v, ok := k.([]byte); ok {
m[string(v)] = key.Interface()
} else {
m[k] = key.Interface()
m[v] = key.Interface()
}
}
e.Set(reflect.ValueOf(m))
Expand All @@ -660,9 +673,15 @@ func (dec *V3Decoder) decodeValuev3(v v3.Value, e reflect.Value, yetToRead *uint
}

var err error
value := reflect.New(e.Type().Elem()).Elem()
var value reflect.Value
key := reflect.New(e.Type().Key()).Elem()

ValueKind := e.Type().Elem().Kind()
shouldReplace := ValueKind == reflect.Array || ValueKind == reflect.Slice || ValueKind == reflect.Map

if !shouldReplace {
value = reflect.New(e.Type().Elem()).Elem()
}
for i := uint64(0); i < children; i++ {
v.FromBytes(dec.in)
*yetToRead += v.Childrenn - 1
Expand All @@ -671,16 +690,26 @@ func (dec *V3Decoder) decodeValuev3(v v3.Value, e reflect.Value, yetToRead *uint
if err != nil {
return err
}
if shouldReplace {
value = reflect.New(e.Type().Elem()).Elem()
}
err = dec.decodeValuev3(v, value, yetToRead)
if err != nil {
return err
}

e.SetMapIndex(key, value)
}
} else if e.Kind() == reflect.Struct {

children := v.Childrenn
usableFields := getStructFields(e)
name := e.Type().String()
var usableFields map[string]int
if v, ok := dec.typeCache[name]; ok {
usableFields = v
} else {
usableFields = getStructFields(e)
dec.typeCache[name] = usableFields
}

for i := uint64(0); i < children; i++ {
v.FromBytes(dec.in)
Expand Down Expand Up @@ -727,6 +756,8 @@ func (dec *V3Decoder) decodeValuev3(v v3.Value, e reflect.Value, yetToRead *uint
len := e.Len()
if len < int(children) {
e.Set(reflect.MakeSlice(e.Type(), int(children), int(children)))
} else if len > int(children) {
e.SetLen(int(children))
}
for i := 0; i < int(children); i++ {
v.FromBytes(dec.in)
Expand All @@ -745,11 +776,19 @@ func (dec *V3Decoder) decodeValuev3(v v3.Value, e reflect.Value, yetToRead *uint
//if all special cases fail we fall back to []interface{}
arr := make([]interface{}, children)
var err error
value := reflect.New(reflect.TypeOf(arr).Elem()).Elem()
var value reflect.Value
ValueKind := e.Type().Elem().Kind()
shouldReplace := ValueKind == reflect.Array || ValueKind == reflect.Slice || ValueKind == reflect.Map

if !shouldReplace {
value = reflect.New(e.Type().Elem()).Elem()
}
for i := 0; i < int(children); i++ {
v.FromBytes(dec.in)
*yetToRead += v.Childrenn - 1
if shouldReplace {
value = reflect.New(e.Type().Elem()).Elem()
}

err = dec.decodeValuev3(v, value, yetToRead)
if err != nil {
Expand Down

0 comments on commit 5f89800

Please sign in to comment.