diff --git a/README.md b/README.md index a63c59f..2aed90e 100644 --- a/README.md +++ b/README.md @@ -10,16 +10,17 @@ Data is are benchmarks that use Data, map benchmarks just convert the data to a r5 3600 make bench ``` -BenchmarkV3 305139 3920 ns/op 1326 B/op 33 allocs/op -BenchmarkV3Decode 374406 2980 ns/op 778 B/op 28 allocs/op -BenchmarkV3Encode 1238996 1003 ns/op 545 B/op 5 allocs/op -BenchmarkV3int64 1504563 787 ns/op 412 B/op 11 allocs/op -BenchmarkV2 446307 2431 ns/op 1645 B/op 26 allocs/op -BenchmarkV2Decode 976438 1254 ns/op 776 B/op 22 allocs/op -BenchmarkV2Encode 943730 1184 ns/op 751 B/op 4 allocs/op -BenchmarkGobMap 91776 12232 ns/op 1916 B/op 68 allocs/op -BenchmarkGobMapDecode 165802 6618 ns/op 1260 B/op 48 allocs/op -BenchmarkGobMapEncode 237160 5112 ns/op 656 B/op 20 allocs/op``` +BenchmarkV3 303894 3921 ns/op 1337 B/op 33 allocs/op +BenchmarkV3Decode 401946 2827 ns/op 778 B/op 28 allocs/op +BenchmarkV3Encode 1208140 1049 ns/op 558 B/op 5 allocs/op +BenchmarkV3int64 1657701 730 ns/op 572 B/op 11 allocs/op +BenchmarkV2 450891 2442 ns/op 1615 B/op 26 allocs/op +BenchmarkV2Decode 914095 1216 ns/op 776 B/op 22 allocs/op +BenchmarkV2Encode 962227 1174 ns/op 791 B/op 4 allocs/op +BenchmarkGobMap 99562 11806 ns/op 1916 B/op 68 allocs/op +BenchmarkGobMapDecode 183130 6511 ns/op 1260 B/op 48 allocs/op +BenchmarkGobMapEncode 235747 4983 ns/op 656 B/op 20 allocs/op +``` ## notes - all tests are run using `GOMAXPROCS=1`, this is because on zen running on multiple threads will cause horrible cache-invalidation. A single alloc/op would cause the GC to run at some point, this would kick the benching to a diferent core. The reason I decided to run using `GOMAXPROCS=1` is because this doesnt have a big impact on Intel cpus, and any real world application would be generating garbage anyways, so eleminitin the GC from running should be part of the benchmark. Another reason coul be: real world applications would so something else in between runs causing cache-invalidation anyways. diff --git a/makefile b/makefile index a2b2d89..7005622 100644 --- a/makefile +++ b/makefile @@ -3,7 +3,10 @@ alloc: env GODEBUG=allocfreetrace=1 ./tt.test -test.run=none -test.bench=BenchmarkV3$$ -test.benchtime=10ms 2>trace.log bench: - go test -bench=. -run=^\$ -benchtime=10s -benchmem + env GOMAXPROCS=1 go test -bench=. -run=^\$ -benchtime=10s -benchmem shortbench: - go test -bench=. -run=^\$ -benchtime=1s -benchmem \ No newline at end of file + env GOMAXPROCS=1 go test -bench=. -run=^\$ -benchtime=1s -benchmem + +test: + go test -short \ No newline at end of file diff --git a/tt_test.go b/tt_test.go index 25c9288..0f7dbcd 100644 --- a/tt_test.go +++ b/tt_test.go @@ -660,6 +660,7 @@ func testFuz(t *testing.T, testcase FuzzStruct) { t.Error(err) } //this only tests public fields + deep.NilMapsAreEmpty = true if diff := deep.Equal(testcase, after.Interface()); diff != nil { t.Error(testcase) t.Error(after.Interface()) @@ -667,43 +668,43 @@ func testFuz(t *testing.T, testcase FuzzStruct) { } } -//{map[-1849047553:{0.70218486 0.5657049046289432} -1704443557:{0.2711049 0.1320528724201477} -579969259:{0.41055316 0.6793484076806898} -324059402:{0.2030408 0.24218549972596393} 109147779:{0.4525726 0.32443739229246954} 644540783:{0.17598473 0.6027663378651792} 1387980123:{0.53488904 0.8896629879214714} 1466347468:{0.6818829 0.03911422469086078}] 36 18963 847877453 10861296638687174198 16814441198678319341 -120 9841 1125880010 -4082773609634810997 5510737390671974350 0.87748927 0.9334683660418484 map[]} -//{map[-1849047553:{0.70218486 0.5657049046289432} -1704443557:{0.2711049 0.1320528724201477} -579969259:{0.41055316 0.6793484076806898} -324059402:{0.2030408 0.24218549972596393} 109147779:{0.4525726 0.32443739229246954} 644540783:{0.17598473 0.6027663378651792} 1387980123:{0.53488904 0.8896629879214714} 1466347468:{0.6818829 0.03911422469086078}] 36 18963 847877453 10861296638687174198 16814441198678319341 -120 9841 1125880010 -4082773609634810997 5510737390671974350 0.87748927 0.9334683660418484 map[]} func TestEncodeDecodeFuzz(t *testing.T) { + if testing.Short() { + t.Skip("Skipping long-running test.") + } data := make([]byte, 1000000) s := FuzzStruct{} - for i := 0; i < 50000; i++ { + for i := 0; i < 10000; i++ { t.Run(fmt.Sprintf("%d", i), func(te *testing.T) { rand.Read(data) fuzz.NewFromGoFuzz(data).Fuzz(&s) - if i%1 == 0 { + if i%100 == 0 { fmt.Printf("%d\n", i) } testFuz(t, s) - if i%1 == 0 { - fmt.Printf("%d\n", i) - } - buf := bytes.NewBuffer(data[:10000]) - after := map[interface{}]interface{}{} - err := Decodev3(buf, &after) - if err != nil { - t.Error(err) - } - }) } } func TestDecodeFuzz(t *testing.T) { - data := make([]byte, 1000000) - s := FuzzStruct{} + if testing.Short() { + t.Skip("Skipping long-running test.") + } + data := make([]byte, 10000) for i := 0; i < 50000; i++ { t.Run(fmt.Sprintf("%d", i), func(te *testing.T) { rand.Read(data) - if i%10000 == 0 { + if i%500 == 0 { fmt.Printf("%d\n", i) } - testFuz(t, s) + buf := bytes.NewBuffer(data[:10000]) + after := interface{}(nil) + dec := NewV3Decoder(buf, true) + dec.SetAllocLimmit(2 << 30) //1 GiB + err := dec.Decode(&after) + if err != nil { + t.Log(err) + } }) } } diff --git a/v3.go b/v3.go index a1b5614..c289171 100644 --- a/v3.go +++ b/v3.go @@ -5,6 +5,7 @@ import ( "encoding/gob" "errors" "io" + "math" "reflect" "strconv" "sync" @@ -493,25 +494,15 @@ func (enc *V3Encoder) encodeValuev3_reflect(d reflect.Value, k v3.KeyValue) erro } } case reflect.Array: - //TODO Only check the types first and only then convert?? - i := d.Interface() - value.Childrenn = uint64(d.Len()) value.Value.Vtype = v3.ArrT alreadyEncoded = true v3.AddValue(enc.out, &value, enc.varintbuf) - switch s := i.(type) { - case []string: - for _, v := range s { - enc.encodeString(v, v3.KeyValue{}) - } - default: - //if its not a specific slice type - for i := 0; i < int(value.Childrenn); i++ { - err := enc.encodeValuev3_reflect(d.Index(i), v3.KeyValue{}) - if err != nil { - return err - } + //if its not a specific slice type + for i := 0; i < int(value.Childrenn); i++ { + err := enc.encodeValuev3_reflect(d.Index(i), v3.KeyValue{}) + if err != nil { + return err } } case reflect.Struct: @@ -568,12 +559,13 @@ func getStructFields(val reflect.Value) map[string]int { //V3Decoder is the decoder used to decode a ttv3 data stream type V3Decoder struct { - didInit bool - isStream bool - didDecode bool - in v3.Reader - typeCache map[string]map[string]int - yetToRead uint64 + didInit bool + isStream bool + didDecode bool + in v3.Reader + typeCache map[string]map[string]int + yetToRead uint64 + allocLimmit uint64 sync.Mutex } @@ -582,9 +574,10 @@ type V3Decoder struct { //Initializing the decoder blocks until at least the first 2 bytes are read. func NewV3Decoder(in v3.Reader, init bool) *V3Decoder { dec := V3Decoder{ - didInit: !init, - in: in, - typeCache: map[string]map[string]int{}, + didInit: !init, + in: in, + typeCache: map[string]map[string]int{}, + allocLimmit: math.MaxUint64, } if init { dec.Init() @@ -608,6 +601,12 @@ func (dec *V3Decoder) Init() error { return nil } +//SetAllocLimmit sets the limmit of allocations. This does not induce +//a global limmit in tt but only for individual allocations +func (dec *V3Decoder) SetAllocLimmit(limit uint64) { + dec.allocLimmit = limit +} + //Decode decodes a one ttv3 encoded value from a stream. //Note that a stream of one value is the same as one value just with //the stream bit set @@ -635,12 +634,12 @@ func (dec *V3Decoder) decode(e interface{}) error { } if e == nil { var v v3.Value - err = v.FromBytes(dec.in) + err = v.FromBytes(dec.in, dec.allocLimmit) if err != nil { return err } - clearNextValues(dec.in, v.Childrenn) - return nil + + return clearNextValues(dec.in, v.Childrenn, dec.allocLimmit) } value := reflect.ValueOf(e) @@ -657,17 +656,16 @@ func (dec *V3Decoder) decode(e interface{}) error { } } else { var v v3.Value - err = v.FromBytes(dec.in) + err = v.FromBytes(dec.in, dec.allocLimmit) if err != nil { return err } - clearNextValues(dec.in, v.Childrenn) - return nil + return clearNextValues(dec.in, v.Childrenn, dec.allocLimmit) } var v v3.Value - err = v.FromBytes(dec.in) + err = v.FromBytes(dec.in, dec.allocLimmit) if err != nil { return err } @@ -675,7 +673,10 @@ func (dec *V3Decoder) decode(e interface{}) error { dec.yetToRead = v.Childrenn err = dec.decodeValuev3(v, value) if dec.yetToRead != 0 { - clearNextValues(dec.in, dec.yetToRead) + err2 := clearNextValues(dec.in, dec.yetToRead, dec.allocLimmit) + if err2 != nil { + return err2 + } } return err @@ -711,6 +712,11 @@ func decodeBytes(data v3.KeyValue, e reflect.Value) error { } func decodeInt8(data v3.KeyValue, e reflect.Value) error { + if len(data.Value) != 1 { + buf := [1]byte{} + copy(buf[:], data.Value) + data.Value = buf[:] + } val := v3.Int8FromBytes(data.Value[0]) if e.Kind() != reflect.Int8 { if e.Kind() != reflect.Interface || e.Type().NumMethod() != 0 { @@ -724,6 +730,11 @@ func decodeInt8(data v3.KeyValue, e reflect.Value) error { } func decodeInt16(data v3.KeyValue, e reflect.Value) error { + if len(data.Value) != 2 { + buf := [2]byte{} + copy(buf[:], data.Value) + data.Value = buf[:] + } val := v3.Int16FromBytes(data.Value) if e.Kind() != reflect.Int16 { if e.Kind() != reflect.Interface || e.Type().NumMethod() != 0 { @@ -737,6 +748,11 @@ func decodeInt16(data v3.KeyValue, e reflect.Value) error { } func decodeInt32(data v3.KeyValue, e reflect.Value) error { + if len(data.Value) != 4 { + buf := [4]byte{} + copy(buf[:], data.Value) + data.Value = buf[:] + } val := v3.Int32FromBytes(data.Value) if e.Kind() != reflect.Int32 { if e.Kind() != reflect.Interface || e.Type().NumMethod() != 0 { @@ -750,6 +766,11 @@ func decodeInt32(data v3.KeyValue, e reflect.Value) error { } func decodeInt64(data v3.KeyValue, e reflect.Value) error { + if len(data.Value) != 8 { + buf := [8]byte{} + copy(buf[:], data.Value) + data.Value = buf[:] + } val := v3.Int64FromBytes(data.Value) if e.Kind() != reflect.Int64 && e.Kind() != reflect.Int { if e.Kind() != reflect.Interface || e.Type().NumMethod() != 0 { @@ -763,6 +784,11 @@ func decodeInt64(data v3.KeyValue, e reflect.Value) error { } func decodeUint8(data v3.KeyValue, e reflect.Value) error { + if len(data.Value) != 1 { + buf := [1]byte{} + copy(buf[:], data.Value) + data.Value = buf[:] + } val := v3.Uint8FromBytes(data.Value[0]) if e.Kind() != reflect.Uint8 { if e.Kind() != reflect.Interface || e.Type().NumMethod() != 0 { @@ -776,6 +802,11 @@ func decodeUint8(data v3.KeyValue, e reflect.Value) error { } func decodeUint16(data v3.KeyValue, e reflect.Value) error { + if len(data.Value) != 2 { + buf := [2]byte{} + copy(buf[:], data.Value) + data.Value = buf[:] + } val := v3.Uint16FromBytes(data.Value) if e.Kind() != reflect.Uint16 { if e.Kind() != reflect.Interface || e.Type().NumMethod() != 0 { @@ -789,6 +820,11 @@ func decodeUint16(data v3.KeyValue, e reflect.Value) error { } func decodeUint32(data v3.KeyValue, e reflect.Value) error { + if len(data.Value) != 4 { + buf := [4]byte{} + copy(buf[:], data.Value) + data.Value = buf[:] + } val := v3.Uint32FromBytes(data.Value) if e.Kind() != reflect.Uint32 { if e.Kind() != reflect.Interface || e.Type().NumMethod() != 0 { @@ -802,6 +838,11 @@ func decodeUint32(data v3.KeyValue, e reflect.Value) error { } func decodeUint64(data v3.KeyValue, e reflect.Value) error { + if len(data.Value) != 8 { + buf := [8]byte{} + copy(buf[:], data.Value) + data.Value = buf[:] + } val := v3.Uint64FromBytes(data.Value) if e.Kind() != reflect.Uint64 && e.Kind() != reflect.Uint { if e.Kind() != reflect.Interface || e.Type().NumMethod() != 0 { @@ -815,7 +856,13 @@ func decodeUint64(data v3.KeyValue, e reflect.Value) error { } func decodeBool(data v3.KeyValue, e reflect.Value) error { - val := v3.BoolFromBytes(data.Value) + val := false + if len(data.Value) != 1 { + val = v3.BoolFromBytes([]byte{0}) + } else { + val = v3.BoolFromBytes(data.Value) + } + if e.Kind() != reflect.Bool { if e.Kind() != reflect.Interface || e.Type().NumMethod() != 0 { return errors.New("TT: cannot unmarshal bytes into " + e.Kind().String() + " Go type") @@ -828,6 +875,11 @@ func decodeBool(data v3.KeyValue, e reflect.Value) error { } func decodeFloat32(data v3.KeyValue, e reflect.Value) error { + if len(data.Value) != 4 { + buf := [4]byte{} + copy(buf[:], data.Value) + data.Value = buf[:] + } val := v3.Float32FromBytes(data.Value) if e.Kind() != reflect.Float32 { if e.Kind() != reflect.Interface || e.Type().NumMethod() != 0 { @@ -841,6 +893,11 @@ func decodeFloat32(data v3.KeyValue, e reflect.Value) error { } func decodeFloat64(data v3.KeyValue, e reflect.Value) error { + if len(data.Value) != 8 { + buf := [8]byte{} + copy(buf[:], data.Value) + data.Value = buf[:] + } val := v3.Float64FromBytes(data.Value) if e.Kind() != reflect.Float64 { if e.Kind() != reflect.Interface || e.Type().NumMethod() != 0 { @@ -860,7 +917,7 @@ func decodeMap(dec *V3Decoder, v v3.Value, e reflect.Value) error { var err error key := reflect.New(reflect.TypeOf(m).Key()).Elem() for i := uint64(0); i < children; i++ { - err = v.FromBytes(dec.in) + err = v.FromBytes(dec.in, dec.allocLimmit) if err != nil { return err } @@ -901,7 +958,7 @@ func decodeMap(dec *V3Decoder, v v3.Value, e reflect.Value) error { value = reflect.New(elem).Elem() } for i := uint64(0); i < children; i++ { - err = v.FromBytes(dec.in) + err = v.FromBytes(dec.in, dec.allocLimmit) if err != nil { return err } @@ -932,7 +989,7 @@ func decodeMap(dec *V3Decoder, v v3.Value, e reflect.Value) error { } for i := uint64(0); i < children; i++ { - err := v.FromBytes(dec.in) + err := v.FromBytes(dec.in, dec.allocLimmit) if err != nil { return err } @@ -941,12 +998,20 @@ func decodeMap(dec *V3Decoder, v v3.Value, e reflect.Value) error { key := v.Key.ExportStructID() if key == "" { - clearNextValues(dec.in, v.Childrenn) + err = clearNextValues(dec.in, v.Childrenn, dec.allocLimmit) + if err != nil { + return err + } + continue } fieldIndex, ok := usableFields[key] if !ok { - clearNextValues(dec.in, v.Childrenn) + err = clearNextValues(dec.in, v.Childrenn, dec.allocLimmit) + if err != nil { + return err + } + continue } @@ -970,7 +1035,7 @@ func decodeArr(dec *V3Decoder, v v3.Value, e reflect.Value) error { return nil } for i := 0; i < int(children); i++ { - err := v.FromBytes(dec.in) + err := v.FromBytes(dec.in, dec.allocLimmit) if err != nil { return err } @@ -989,7 +1054,7 @@ func decodeArr(dec *V3Decoder, v v3.Value, e reflect.Value) error { e.SetLen(int(children)) } for i := 0; i < int(children); i++ { - err := v.FromBytes(dec.in) + err := v.FromBytes(dec.in, dec.allocLimmit) if err != nil { return err } @@ -1017,7 +1082,7 @@ func decodeArr(dec *V3Decoder, v v3.Value, e reflect.Value) error { value = reflect.New(valueElem).Elem() } for i := 0; i < int(children); i++ { - err = v.FromBytes(dec.in) + err = v.FromBytes(dec.in, dec.allocLimmit) if err != nil { return err } @@ -1039,6 +1104,9 @@ func decodeArr(dec *V3Decoder, v v3.Value, e reflect.Value) error { } func decodeKeyv3(k v3.KeyValue, e reflect.Value) error { + if int(k.Vtype) > 18 { + return errors.New("TT: cannot unmarshal invalid key type:" + strconv.Itoa(int(k.Vtype))) + } return decodersSlice[int(k.Vtype)](k, e) } @@ -1109,10 +1177,14 @@ func getFieldName(field reflect.StructField) string { return name } -func clearNextValues(buf v3.Reader, values uint64) { +func clearNextValues(buf v3.Reader, values uint64, limit uint64) error { var value v3.Value for ; values > 0; values-- { - value.FromBytes(buf) + err := value.FromBytes(buf, limit) + if err != nil { + return err + } values += value.Childrenn } + return nil } diff --git a/v3/v3.go b/v3/v3.go index eb24a49..d9d7188 100644 --- a/v3/v3.go +++ b/v3/v3.go @@ -13,7 +13,7 @@ var ( const ( corruptinputdata = "Not enough data in the datastream, imput might be corrupt." - oversizedInputData = "imput too big." + oversizedInputData = "Imput hit allocation limmit." ) const ( diff --git a/v3/value.go b/v3/value.go index 17c472e..a4247bd 100644 --- a/v3/value.go +++ b/v3/value.go @@ -4,7 +4,6 @@ import ( "encoding/binary" "errors" "io" - "math" ) type ( @@ -66,7 +65,7 @@ func (v *Value) Tobytes(out io.Writer, varintbuf *[binary.MaxVarintLen64 + 1]byt } //FromBytes reads bytes from a v3.Reader into Value -func (v *Value) FromBytes(in Reader) error { +func (v *Value) FromBytes(in Reader, limit uint64) error { vlen, err := readerReadUvarint(in) if err != nil { return errors.New(corruptinputdata) @@ -75,7 +74,7 @@ func (v *Value) FromBytes(in Reader) error { if err != nil { return errors.New(corruptinputdata) } - if vlen > math.MaxInt64-3-klen { + if 1+vlen+1+klen+1 > limit { return errors.New(oversizedInputData) } data := make([]byte, 1+vlen+1+klen+1)