Skip to content

Commit

Permalink
fix(protoio): Better error when passing in unexpected type
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoPolo committed Feb 3, 2025
1 parent 0967bbf commit 29cd5a3
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 3 deletions.
51 changes: 49 additions & 2 deletions msgio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,15 @@ import (
"fmt"
"io"
"math/rand"
str "strings"
"strings"
"sync"
"testing"
"time"

"github.com/libp2p/go-msgio/pbio/pb"
"github.com/libp2p/go-msgio/protoio"

Check failure on line 15 in msgio_test.go

View workflow job for this annotation

GitHub Actions / go-check / All

"github.com/libp2p/go-msgio/protoio" is deprecated: GoGo Protobuf is deprecated and unmaintained. (SA1019)
"github.com/multiformats/go-varint"
"google.golang.org/protobuf/proto"
)

func randBuf(r *rand.Rand, size int) []byte {
Expand Down Expand Up @@ -79,7 +84,7 @@ func TestMultiError(t *testing.T) {
}

twoErrors := multiErr([]error{errors.New("one"), errors.New("two")})
if eStr := twoErrors.Error(); !str.Contains(eStr, "one") && !str.Contains(eStr, "two") {
if eStr := twoErrors.Error(); !strings.Contains(eStr, "one") && !strings.Contains(eStr, "two") {
t.Fatal("Expected error messages not included")
}
}
Expand Down Expand Up @@ -328,3 +333,45 @@ func SubtestReadShortBuffer(t *testing.T, writer WriteCloser, reader ReadCloser)
t.Fatal("Expected short buffer error")
}
}

func TestHandleProtoGeneratedByGoogleProtobufInProtoio(t *testing.T) {
record := &pb.TestRecord{
Uint32: 42,
Uint64: 84,
Bytes: []byte("test bytes"),
String_: "test string",
Int32: -42,
Int64: -84,
}

recordBytes, err := proto.Marshal(record)
if err != nil {
t.Fatal(err)
}

for _, tc := range []string{"read", "write"} {
t.Run(tc, func(t *testing.T) {
var buf bytes.Buffer
readRecord := &pb.TestRecord{}
switch tc {
case "read":
buf.Write(varint.ToUvarint(uint64(len(recordBytes))))
buf.Write(recordBytes)

reader := protoio.NewDelimitedReader(&buf, 1024)
defer reader.Close()
err = reader.ReadMsg(readRecord)
case "write":
writer := protoio.NewDelimitedWriter(&buf)
err = writer.WriteMsg(record)
}
if err == nil {
t.Fatal("expected error")
}
expectedError := "google Protobuf message passed into a GoGo Protobuf"
if !strings.Contains(err.Error(), expectedError) {
t.Fatalf("expected error to contain '%s'", expectedError)
}
})
}
}
15 changes: 15 additions & 0 deletions protoio/isgoog.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package protoio

import (
"github.com/gogo/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
)

// isGoogleProtobufMsg checks if the given proto.Message was
// generated by the official Google protobuf compiler
func isGoogleProtobufMsg(msg proto.Message) bool {
_, ok := msg.(interface {
ProtoReflect() protoreflect.Message
})
return ok
}
19 changes: 18 additions & 1 deletion protoio/uvarint_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ package protoio

import (
"bufio"
"errors"
"fmt"
"io"
"os"
Expand Down Expand Up @@ -82,7 +83,23 @@ func (ur *uvarintReader) ReadMsg(msg proto.Message) (err error) {
if _, err := io.ReadFull(ur.r, buf); err != nil {
return err
}
return proto.Unmarshal(buf, msg)

// Hoist up gogo's proto.Unmarshal logic so we can also check if this is a google protobuf message
msg.Reset()
if u, ok := msg.(interface {
XXX_Unmarshal([]byte) error
}); ok {
return u.XXX_Unmarshal(buf)
} else if u, ok := msg.(interface {
Unmarshal([]byte) error
}); ok {
return u.Unmarshal(buf)
} else if isGoogleProtobufMsg(msg) {
return errors.New("google Protobuf message passed into a GoGo Protobuf reader. Use github.com/libp2p/go-msgio/pbio instead of github.com/gogo/protobuf/proto")
}

// Fallback to GoGo's proto.Unmarshal around this buffer
return proto.NewBuffer(buf).Unmarshal(msg)
}

func (ur *uvarintReader) Close() error {
Expand Down
5 changes: 5 additions & 0 deletions protoio/uvarint_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
package protoio

import (
"errors"
"fmt"
"io"
"os"
Expand Down Expand Up @@ -80,6 +81,10 @@ func (uw *uvarintWriter) WriteMsg(msg proto.Message) (err error) {
}
}

if isGoogleProtobufMsg(msg) {
return errors.New("google Protobuf message passed into a GoGo Protobuf writer. Use github.com/libp2p/go-msgio/pbio instead of github.com/gogo/protobuf/proto")
}

// fallback
data, err = proto.Marshal(msg)
if err != nil {
Expand Down

0 comments on commit 29cd5a3

Please sign in to comment.