diff --git a/msgio_test.go b/msgio_test.go index 4167e02..f26dd86 100644 --- a/msgio_test.go +++ b/msgio_test.go @@ -6,10 +6,15 @@ import ( "fmt" "io" "math/rand" - str "strings" + "strings" "sync" "testing" "time" + + "github.com/libp2p/go-msgio/pbio/pb" + "github.com/libp2p/go-msgio/protoio" + "github.com/multiformats/go-varint" + "google.golang.org/protobuf/proto" ) func randBuf(r *rand.Rand, size int) []byte { @@ -79,7 +84,7 @@ func TestMultiError(t *testing.T) { } twoErrors := multiErr([]error{errors.New("one"), errors.New("two")}) - if eStr := twoErrors.Error(); !str.Contains(eStr, "one") && !str.Contains(eStr, "two") { + if eStr := twoErrors.Error(); !strings.Contains(eStr, "one") && !strings.Contains(eStr, "two") { t.Fatal("Expected error messages not included") } } @@ -328,3 +333,45 @@ func SubtestReadShortBuffer(t *testing.T, writer WriteCloser, reader ReadCloser) t.Fatal("Expected short buffer error") } } + +func TestHandleProtoGeneratedByGoogleProtobufInProtoio(t *testing.T) { + record := &pb.TestRecord{ + Uint32: 42, + Uint64: 84, + Bytes: []byte("test bytes"), + String_: "test string", + Int32: -42, + Int64: -84, + } + + recordBytes, err := proto.Marshal(record) + if err != nil { + t.Fatal(err) + } + + for _, tc := range []string{"read", "write"} { + t.Run(tc, func(t *testing.T) { + var buf bytes.Buffer + readRecord := &pb.TestRecord{} + switch tc { + case "read": + buf.Write(varint.ToUvarint(uint64(len(recordBytes)))) + buf.Write(recordBytes) + + reader := protoio.NewDelimitedReader(&buf, 1024) + defer reader.Close() + err = reader.ReadMsg(readRecord) + case "write": + writer := protoio.NewDelimitedWriter(&buf) + err = writer.WriteMsg(record) + } + if err == nil { + t.Fatal("expected error") + } + expectedError := "google Protobuf message passed into a GoGo Protobuf" + if !strings.Contains(err.Error(), expectedError) { + t.Fatalf("expected error to contain '%s'", expectedError) + } + }) + } +} diff --git a/protoio/isgoog.go b/protoio/isgoog.go new file mode 100644 index 0000000..9247613 --- /dev/null +++ b/protoio/isgoog.go @@ -0,0 +1,15 @@ +package protoio + +import ( + "github.com/gogo/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" +) + +// isGoogleProtobufMsg checks if the given proto.Message was +// generated by the official Google protobuf compiler +func isGoogleProtobufMsg(msg proto.Message) bool { + _, ok := msg.(interface { + ProtoReflect() protoreflect.Message + }) + return ok +} diff --git a/protoio/uvarint_reader.go b/protoio/uvarint_reader.go index 6722cd3..f2d4731 100644 --- a/protoio/uvarint_reader.go +++ b/protoio/uvarint_reader.go @@ -34,6 +34,7 @@ package protoio import ( "bufio" + "errors" "fmt" "io" "os" @@ -82,7 +83,23 @@ func (ur *uvarintReader) ReadMsg(msg proto.Message) (err error) { if _, err := io.ReadFull(ur.r, buf); err != nil { return err } - return proto.Unmarshal(buf, msg) + + // Hoist up gogo's proto.Unmarshal logic so we can also check if this is a google protobuf message + msg.Reset() + if u, ok := msg.(interface { + XXX_Unmarshal([]byte) error + }); ok { + return u.XXX_Unmarshal(buf) + } else if u, ok := msg.(interface { + Unmarshal([]byte) error + }); ok { + return u.Unmarshal(buf) + } else if isGoogleProtobufMsg(msg) { + return errors.New("google Protobuf message passed into a GoGo Protobuf reader. Use github.com/libp2p/go-msgio/pbio instead of github.com/gogo/protobuf/proto") + } + + // Fallback to GoGo's proto.Unmarshal around this buffer + return proto.NewBuffer(buf).Unmarshal(msg) } func (ur *uvarintReader) Close() error { diff --git a/protoio/uvarint_writer.go b/protoio/uvarint_writer.go index e311075..ece42db 100644 --- a/protoio/uvarint_writer.go +++ b/protoio/uvarint_writer.go @@ -33,6 +33,7 @@ package protoio import ( + "errors" "fmt" "io" "os" @@ -80,6 +81,10 @@ func (uw *uvarintWriter) WriteMsg(msg proto.Message) (err error) { } } + if isGoogleProtobufMsg(msg) { + return errors.New("google Protobuf message passed into a GoGo Protobuf writer. Use github.com/libp2p/go-msgio/pbio instead of github.com/gogo/protobuf/proto") + } + // fallback data, err = proto.Marshal(msg) if err != nil {