Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Specify one (from many) top-level messages #21

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions content_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,22 @@ func (b *contentBuilder) build() (string, error) {
return "", errors.New("contentBuilder.build(): protoFile is nil")
}

if len(b.file.GetMessageType()) != 1 {
return "", errors.New(b.file.GetName() + ": only one top-level type may be defined in a file (see https://cloud.google.com/pubsub/docs/schemas#schema_types). use nested types or imports (see https://developers.google.com/protocol-buffers/docs/proto)")
messageTopLevel := []*descriptorpb.DescriptorProto{}
messageTypes := b.file.GetMessageType()

// If the file contains multiple messages
if len(messageTypes) > 1 {
// Then, if there's no top-level message specified using the parameter
if b.messageTopLevel == "" {
return "", errors.New(b.file.GetName() + ": contains multiple top-level messages and no top-level message is identified, use --pubsub-schema_opt=top-level-message={message} to specify one top-level message.")
}

// Because there are multiple top-level messages, lookup the one defined by the top-level-message parameter
for _, messageType := range messageTypes {
if *messageType.Name == b.messageTopLevel {
messageTopLevel = append(messageTopLevel, messageType)
}
}
}

fmt.Fprintf(b.output, "// Code generated by protoc-gen-pubsub-schema. DO NOT EDIT.\n")
Expand All @@ -34,7 +48,7 @@ func (b *contentBuilder) build() (string, error) {
fmt.Fprintf(b.output, "// source: %s\n\n", b.file.GetName())
fmt.Fprintf(b.output, "syntax = \"%s\";\n\n", b.schemaSyntax)
fmt.Fprintf(b.output, "package %s;\n", b.file.GetPackage())
b.buildMessages(0, b.file.GetMessageType())
b.buildMessages(0, messageTopLevel)
b.buildEnums(0, b.file.GetEnumType())
return b.output.String(), nil
}
Expand Down
40 changes: 40 additions & 0 deletions response_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,33 @@ package main

import (
"errors"
"fmt"
"regexp"
"strings"

"google.golang.org/protobuf/types/descriptorpb"
"google.golang.org/protobuf/types/pluginpb"
)

const (
// Per Buf's definition of identifier
// https://protobuf.com/docs/language-spec#identifiers-and-keywords
identifier string = `[A-Za-z_]([A-Za-z_]|[0-9])*`

// Parse parameter of the form top-level-message={message}
keyTopLevelMessage string = "top-level-message"
)

var (
// RegEx to extract Message name from top-level-message parameter
regexTopLevelMessage string = fmt.Sprintf("%s=(%s)", keyTopLevelMessage, identifier)
)

type responseBuilder struct {
request *pluginpb.CodeGeneratorRequest
schemaSyntax string
messageEncoding string
messageTopLevel string
protoFiles map[string]*descriptorpb.FileDescriptorProto
messageTypes map[string]*descriptorpb.DescriptorProto
enums map[string]*descriptorpb.EnumDescriptorProto
Expand All @@ -22,10 +39,12 @@ func newResponseBuilder(request *pluginpb.CodeGeneratorRequest) (*responseBuilde
if request == nil {
return nil, errors.New("newResponseBuilder(request *pluginpb.CodeGeneratorRequest): request is nil")
}

builder := &responseBuilder{
request,
getSyntax(request),
getEncoding(request),
getTopLevelMessage(request),
make(map[string]*descriptorpb.FileDescriptorProto),
make(map[string]*descriptorpb.DescriptorProto),
make(map[string]*descriptorpb.EnumDescriptorProto),
Expand All @@ -50,6 +69,27 @@ func getEncoding(request *pluginpb.CodeGeneratorRequest) string {
return "binary"
}

func getTopLevelMessage(request *pluginpb.CodeGeneratorRequest) string {
parameter := request.GetParameter()

// For consistency with getSyntax and getEncoding, check whether parameter contains key
if strings.Contains(parameter, fmt.Sprintf("%s=", keyTopLevelMessage)) {
// If the parameter contains the key, use a more expensive regex to extract the value
re := regexp.MustCompile(regexTopLevelMessage)
messages := re.FindAllStringSubmatch(parameter, -1)

// Expect single occurrence (not top-level-message=Foo,top-level-message=Bar)
if len(messages) == 1 {
// Don't return the entire substring ([0]), i.e. top-level-message=message
// Only return the value ([1]) i.e. message
return messages[0][1]
}
}

// Otherwise unable to determine top-level messsage name
return ""
}

func (b *responseBuilder) initProtoFiles() {
for _, protoFile := range b.request.GetProtoFile() {
b.protoFiles[protoFile.GetName()] = protoFile
Expand Down
51 changes: 51 additions & 0 deletions response_builder_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package main

import (
"fmt"
"testing"

"google.golang.org/protobuf/types/pluginpb"
)

func Test_getTopLevelMessage(t *testing.T) {
tests := []struct {
name string
parameter string
want string
}{
{
name: "empty",
parameter: "",
want: "",
},
{
// Return value on valid value
name: "single valid value",
parameter: fmt.Sprintf("blah,blah,%s=Foo,blah,blah", keyTopLevelMessage),
want: "Foo",
},
{
// Return empty string on invalid value
name: "single invalid value",
parameter: "blah,blah,root-message=9,blah,blah",
want: "",
},
{
// Return empty string for multiple values
name: "multiple values",
parameter: fmt.Sprintf("blah,blah,%s=Foo,blah,%s=Bar,blah", keyTopLevelMessage, keyTopLevelMessage),
want: "",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
request := &pluginpb.CodeGeneratorRequest{
Parameter: &tt.parameter,
}
if got := getTopLevelMessage(request); got != tt.want {
t.Errorf("getTopLevelMessage() = %v, want %v", got, tt.want)
}
})
}
}