diff --git a/content_builder.go b/content_builder.go index 312290b..7e5b005 100644 --- a/content_builder.go +++ b/content_builder.go @@ -24,8 +24,22 @@ func (b *contentBuilder) build() (string, error) { return "", errors.New("contentBuilder.build(): protoFile is nil") } - if len(b.file.GetMessageType()) != 1 { - return "", errors.New(b.file.GetName() + ": only one top-level type may be defined in a file (see https://cloud.google.com/pubsub/docs/schemas#schema_types). use nested types or imports (see https://developers.google.com/protocol-buffers/docs/proto)") + messageTopLevel := []*descriptorpb.DescriptorProto{} + messageTypes := b.file.GetMessageType() + + // If the file contains multiple messages + if len(messageTypes) > 1 { + // Then, if there's no top-level message specified using the parameter + if b.messageTopLevel == "" { + return "", errors.New(b.file.GetName() + ": contains multiple top-level messages and no top-level message is identified, use --pubsub-schema_opt=top-level-message={message} to specify one top-level message.") + } + + // Because there are multiple top-level messages, lookup the one defined by the top-level-message parameter + for _, messageType := range messageTypes { + if *messageType.Name == b.messageTopLevel { + messageTopLevel = append(messageTopLevel, messageType) + } + } } fmt.Fprintf(b.output, "// Code generated by protoc-gen-pubsub-schema. DO NOT EDIT.\n") @@ -34,7 +48,7 @@ func (b *contentBuilder) build() (string, error) { fmt.Fprintf(b.output, "// source: %s\n\n", b.file.GetName()) fmt.Fprintf(b.output, "syntax = \"%s\";\n\n", b.schemaSyntax) fmt.Fprintf(b.output, "package %s;\n", b.file.GetPackage()) - b.buildMessages(0, b.file.GetMessageType()) + b.buildMessages(0, messageTopLevel) b.buildEnums(0, b.file.GetEnumType()) return b.output.String(), nil } diff --git a/response_builder.go b/response_builder.go index 19ee1aa..e2768c0 100644 --- a/response_builder.go +++ b/response_builder.go @@ -2,16 +2,33 @@ package main import ( "errors" + "fmt" + "regexp" "strings" "google.golang.org/protobuf/types/descriptorpb" "google.golang.org/protobuf/types/pluginpb" ) +const ( + // Per Buf's definition of identifier + // https://protobuf.com/docs/language-spec#identifiers-and-keywords + identifier string = `[A-Za-z_]([A-Za-z_]|[0-9])*` + + // Parse parameter of the form top-level-message={message} + keyTopLevelMessage string = "top-level-message" +) + +var ( + // RegEx to extract Message name from top-level-message parameter + regexTopLevelMessage string = fmt.Sprintf("%s=(%s)", keyTopLevelMessage, identifier) +) + type responseBuilder struct { request *pluginpb.CodeGeneratorRequest schemaSyntax string messageEncoding string + messageTopLevel string protoFiles map[string]*descriptorpb.FileDescriptorProto messageTypes map[string]*descriptorpb.DescriptorProto enums map[string]*descriptorpb.EnumDescriptorProto @@ -22,10 +39,12 @@ func newResponseBuilder(request *pluginpb.CodeGeneratorRequest) (*responseBuilde if request == nil { return nil, errors.New("newResponseBuilder(request *pluginpb.CodeGeneratorRequest): request is nil") } + builder := &responseBuilder{ request, getSyntax(request), getEncoding(request), + getTopLevelMessage(request), make(map[string]*descriptorpb.FileDescriptorProto), make(map[string]*descriptorpb.DescriptorProto), make(map[string]*descriptorpb.EnumDescriptorProto), @@ -50,6 +69,27 @@ func getEncoding(request *pluginpb.CodeGeneratorRequest) string { return "binary" } +func getTopLevelMessage(request *pluginpb.CodeGeneratorRequest) string { + parameter := request.GetParameter() + + // For consistency with getSyntax and getEncoding, check whether parameter contains key + if strings.Contains(parameter, fmt.Sprintf("%s=", keyTopLevelMessage)) { + // If the parameter contains the key, use a more expensive regex to extract the value + re := regexp.MustCompile(regexTopLevelMessage) + messages := re.FindAllStringSubmatch(parameter, -1) + + // Expect single occurrence (not top-level-message=Foo,top-level-message=Bar) + if len(messages) == 1 { + // Don't return the entire substring ([0]), i.e. top-level-message=message + // Only return the value ([1]) i.e. message + return messages[0][1] + } + } + + // Otherwise unable to determine top-level messsage name + return "" +} + func (b *responseBuilder) initProtoFiles() { for _, protoFile := range b.request.GetProtoFile() { b.protoFiles[protoFile.GetName()] = protoFile diff --git a/response_builder_test.go b/response_builder_test.go new file mode 100644 index 0000000..521aa35 --- /dev/null +++ b/response_builder_test.go @@ -0,0 +1,51 @@ +package main + +import ( + "fmt" + "testing" + + "google.golang.org/protobuf/types/pluginpb" +) + +func Test_getTopLevelMessage(t *testing.T) { + tests := []struct { + name string + parameter string + want string + }{ + { + name: "empty", + parameter: "", + want: "", + }, + { + // Return value on valid value + name: "single valid value", + parameter: fmt.Sprintf("blah,blah,%s=Foo,blah,blah", keyTopLevelMessage), + want: "Foo", + }, + { + // Return empty string on invalid value + name: "single invalid value", + parameter: "blah,blah,root-message=9,blah,blah", + want: "", + }, + { + // Return empty string for multiple values + name: "multiple values", + parameter: fmt.Sprintf("blah,blah,%s=Foo,blah,%s=Bar,blah", keyTopLevelMessage, keyTopLevelMessage), + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + request := &pluginpb.CodeGeneratorRequest{ + Parameter: &tt.parameter, + } + if got := getTopLevelMessage(request); got != tt.want { + t.Errorf("getTopLevelMessage() = %v, want %v", got, tt.want) + } + }) + } +}