Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cleanups and refactor of protosanitizer #184

Merged
merged 6 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 2 additions & 16 deletions connection/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,20 +53,6 @@ const (
serverSock = "server.sock"
)

// identityServer is a stub CSI identity service: every RPC it exposes
// fails with codes.Unimplemented.
type identityServer struct{}

// GetPluginInfo always fails with codes.Unimplemented.
func (*identityServer) GetPluginInfo(context.Context, *csi.GetPluginInfoRequest) (*csi.GetPluginInfoResponse, error) {
	return nil, status.Error(codes.Unimplemented, "Unimplemented")
}

// Probe always fails with codes.Unimplemented.
func (*identityServer) Probe(context.Context, *csi.ProbeRequest) (*csi.ProbeResponse, error) {
	return nil, status.Error(codes.Unimplemented, "Unimplemented")
}

// GetPluginCapabilities always fails with codes.Unimplemented.
func (*identityServer) GetPluginCapabilities(context.Context, *csi.GetPluginCapabilitiesRequest) (*csi.GetPluginCapabilitiesResponse, error) {
	return nil, status.Error(codes.Unimplemented, "Unimplemented")
}

// startServer creates a gRPC server without any registered services.
// The returned address can be used to connect to it. The cleanup
// function stops it. It can be called multiple times.
Expand Down Expand Up @@ -458,7 +444,7 @@ func TestConnectMetrics(t *testing.T) {
cmmServer := metrics.NewCSIMetricsManagerForPlugin("fake.csi.driver.io")
// We have to have a real implementation of the gRPC call, otherwise the metrics
// interceptor is not called. The CSI identity service is used because it's simple.
addr, stopServer := startServer(t, tmp, &identityServer{}, nil, cmmServer)
addr, stopServer := startServer(t, tmp, &csi.UnimplementedIdentityServer{}, nil, cmmServer)
defer stopServer()

cmm := test.cmm
Expand Down Expand Up @@ -516,7 +502,7 @@ func TestConnectWithOtelGrpcInterceptorTraces(t *testing.T) {
defer os.RemoveAll(tmp)
// We have to have a real implementation of the gRPC call, otherwise the trace
// interceptor is not called. The CSI identity service is used because it's simple.
addr, stopServer := startServer(t, tmp, &identityServer{}, nil, nil)
addr, stopServer := startServer(t, tmp, &csi.UnimplementedIdentityServer{}, nil, nil)
defer stopServer()

_, ctx := ktesting.NewTestContext(t)
Expand Down
8 changes: 4 additions & 4 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@ module github.com/kubernetes-csi/csi-lib-utils
go 1.22.5

require (
github.com/container-storage-interface/spec v1.9.0
github.com/golang/protobuf v1.5.4
github.com/container-storage-interface/spec v1.10.0
github.com/stretchr/testify v1.9.0
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.46.0
go.opentelemetry.io/otel/trace v1.28.0
golang.org/x/net v0.26.0
google.golang.org/grpc v1.65.0
google.golang.org/protobuf v1.34.2
k8s.io/api v0.31.0
k8s.io/client-go v0.31.0
k8s.io/component-base v0.31.0
Expand All @@ -30,6 +29,7 @@ require (
github.com/go-openapi/swag v0.22.4 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/protobuf v1.5.4 // indirect
github.com/google/gnostic-models v0.6.8 // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/google/gofuzz v1.2.0 // indirect
Expand All @@ -51,13 +51,13 @@ require (
github.com/x448/float16 v0.8.4 // indirect
go.opentelemetry.io/otel v1.28.0 // indirect
go.opentelemetry.io/otel/metric v1.28.0 // indirect
golang.org/x/net v0.26.0 // indirect
golang.org/x/oauth2 v0.21.0 // indirect
golang.org/x/sys v0.22.0 // indirect
golang.org/x/term v0.21.0 // indirect
golang.org/x/text v0.16.0 // indirect
golang.org/x/time v0.3.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240701130421-f6361c86f094 // indirect
google.golang.org/protobuf v1.34.2 // indirect
gopkg.in/inf.v0 v0.9.1 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UF
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cncf/xds/go v0.0.0-20240423153145-555b57ec207b h1:ga8SEFjZ60pxLcmhnThWgvH2wg8376yUJmPhEH4H3kw=
github.com/cncf/xds/go v0.0.0-20240423153145-555b57ec207b/go.mod h1:W+zGtBO5Y1IgJhy4+A9GOqVhqLpfZi+vwmdNXUehLA8=
github.com/container-storage-interface/spec v1.9.0 h1:zKtX4STsq31Knz3gciCYCi1SXtO2HJDecIjDVboYavY=
github.com/container-storage-interface/spec v1.9.0/go.mod h1:ZfDu+3ZRyeVqxZM0Ds19MVLkN2d1XJ5MAfi1L3VjlT0=
github.com/container-storage-interface/spec v1.10.0 h1:YkzWPV39x+ZMTa6Ax2czJLLwpryrQ+dPesB34mrRMXA=
github.com/container-storage-interface/spec v1.10.0/go.mod h1:DtUvaQszPml1YJfIK7c00mlv6/g4wNMLanLgiUbKFRI=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
Expand Down
168 changes: 56 additions & 112 deletions protosanitizer/protosanitizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,10 @@ package protosanitizer
import (
"encoding/json"
"fmt"
"reflect"
"strings"

"github.com/golang/protobuf/descriptor"
"github.com/golang/protobuf/proto"
protobufdescriptor "github.com/golang/protobuf/protoc-gen-go/descriptor"
"github.com/container-storage-interface/spec/lib/go/csi"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
)

// StripSecrets returns a wrapper around the original CSI gRPC message
Expand All @@ -42,135 +40,81 @@ import (
// result to logging functions which may or may not end up serializing
// the parameter depending on the current log level.
func StripSecrets(msg interface{}) fmt.Stringer {
return &stripSecrets{msg, isCSI1Secret}
}

// StripSecretsCSI03 is like StripSecrets, except that it works
// for messages based on CSI 0.3 and older. It does not work
// for CSI 1.0, use StripSecrets for that.
func StripSecretsCSI03(msg interface{}) fmt.Stringer {
return &stripSecrets{msg, isCSI03Secret}
return &stripSecrets{msg}
}

type stripSecrets struct {
msg interface{}

isSecretField func(field *protobufdescriptor.FieldDescriptorProto) bool
msg any
}

func (s *stripSecrets) String() string {
// First convert to a generic representation. That's less efficient
// than using reflect directly, but easier to work with.
var parsed interface{}
b, err := json.Marshal(s.msg)
if err != nil {
return fmt.Sprintf("<<json.Marshal %T: %s>>", s.msg, err)
}
if err := json.Unmarshal(b, &parsed); err != nil {
return fmt.Sprintf("<<json.Unmarshal %T: %s>>", s.msg, err)
}
stripped := s.msg

// Now remove secrets from the generic representation of the message.
s.strip(parsed, s.msg)
// also support scalar types like string, int, etc.
msg, ok := s.msg.(proto.Message)
if ok {
stripped = stripMessage(msg.ProtoReflect())
}

// Re-encode the stripped representation and return that.
b, err = json.Marshal(parsed)
b, err := json.Marshal(stripped)
if err != nil {
return fmt.Sprintf("<<json.Marshal %T: %s>>", s.msg, err)
}
return string(b)
}

func (s *stripSecrets) strip(parsed interface{}, msg interface{}) {
protobufMsg, ok := msg.(descriptor.Message)
if !ok {
// Not a protobuf message, so we are done.
return
// stripSingleValue converts one singular (non-list, non-map) field value
// into a plain Go value for the stripped representation of a message.
//
// Message-typed values, including group-encoded ones, are stripped
// recursively through stripMessage so nested secret fields are never
// emitted verbatim. Enum values are rendered by their declared name.
// ByNumber returns nil for a number that has no declared value (proto3
// enums are open, so a peer built against a newer .proto can send one);
// in that case the numeric value is returned instead of dereferencing
// the nil descriptor. Every other kind is returned as the Go value held
// by v.
func stripSingleValue(field protoreflect.FieldDescriptor, v protoreflect.Value) any {
	switch field.Kind() {
	case protoreflect.MessageKind, protoreflect.GroupKind:
		return stripMessage(v.Message())
	case protoreflect.EnumKind:
		num := v.Enum()
		if desc := field.Enum().Values().ByNumber(num); desc != nil {
			return desc.Name()
		}
		// Unknown enum number: fall back to the raw number rather than panic.
		return int32(num)
	default:
		return v.Interface()
	}
}

// The corresponding map in the parsed JSON representation.
parsedFields, ok := parsed.(map[string]interface{})
if !ok {
// Probably nil.
return
func stripValue(field protoreflect.FieldDescriptor, v protoreflect.Value) any {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can stripValue() and stripSingleValue() be combined into a single function with a bigger switch?

switch {
case field.IsList():
...

case field.IsMap():
...

case field.Kind() == protoreflect.MessageKind:
...
case field.Kind() == ...

default: return v.Interface()

Copy link
Contributor Author

@huww98 huww98 Dec 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is possible, but less clear, I think, because for a field where IsList() is true, Kind() still returns the kind of the list element, so multiple case statements would match.

My code structure actually looks very much like the official encoder:
https://github.com/protocolbuffers/protobuf-go/blob/b98563540c0a4edb38526bcd6e6c97f9fac1f453/encoding/prototext/encode.go#L201-L212

(I wasn't referencing it while writing, but converged on the same structure in the end :)

if field.IsList() {
l := v.List()
res := make([]any, l.Len())
for i := range l.Len() {
res[i] = stripSingleValue(field, l.Get(i))
Comment on lines +81 to +82
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For example, what happens when a list contains a list? Will it call stripSingleValue() with a list, which then falls into the default: branch there and thus prints the list with potential secrets?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A list is represented as a repeated field in protobuf, like:

  repeated VolumeCapability volume_capabilities = 3;

And it cannot be nested without going through a level of message. i.e.:

  repeated repeated VolumeCapability volume_capabilities = 3;

is not valid, but this one is:

  repeated Topology accessible_topology = 5;

message Topology {
  map<string, string> segments = 1;
}

Similarly, map<string, map<string, string>> or repeated map<string, string> are not valid.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the explanation!
/lgtm
/approve

}
return res
} else if field.IsMap() {
m := v.Map()
res := make(map[string]any, m.Len())
m.Range(func(mk protoreflect.MapKey, v protoreflect.Value) bool {
res[mk.String()] = stripSingleValue(field.MapValue(), v)
return true
})
return res
} else {
return stripSingleValue(field, v)
}
}

func stripMessage(msg protoreflect.Message) map[string]any {
stripped := make(map[string]any)

// Walk through all fields and replace those with ***stripped*** that
// are marked as secret. This relies on protobuf adding "json:" tags
// on each field where the name matches the field name in the protobuf
// spec (like volume_capabilities). The field.GetJsonName() method returns
// a different name (volumeCapabilities) which we don't use.
_, md := descriptor.ForMessage(protobufMsg)
fields := md.GetField()
if fields != nil {
for _, field := range fields {
if s.isSecretField(field) {
// Overwrite only if already set.
if _, ok := parsedFields[field.GetName()]; ok {
parsedFields[field.GetName()] = "***stripped***"
}
} else if field.GetType() == protobufdescriptor.FieldDescriptorProto_TYPE_MESSAGE {
// When we get here,
// the type name is something like ".csi.v1.CapacityRange" (leading dot!)
// and looking up "csi.v1.CapacityRange"
// returns the type of a pointer to a pointer
// to CapacityRange. We need a pointer to such
// a value for recursive stripping.
typeName := field.GetTypeName()
if strings.HasPrefix(typeName, ".") {
typeName = typeName[1:]
}
t := proto.MessageType(typeName)
if t == nil || t.Kind() != reflect.Ptr {
// Shouldn't happen, but
// better check anyway instead
// of panicking.
continue
}
v := reflect.New(t.Elem())

// Recursively strip the message(s) that
// the field contains.
i := v.Interface()
entry := parsedFields[field.GetName()]
if slice, ok := entry.([]interface{}); ok {
// Array of values, like VolumeCapabilities in CreateVolumeRequest.
for _, entry := range slice {
s.strip(entry, i)
}
} else {
// Single value.
s.strip(entry, i)
}
}
// are marked as secret.
msg.Range(func(field protoreflect.FieldDescriptor, v protoreflect.Value) bool {
name := field.TextName()
if isCSI1Secret(field) {
stripped[name] = "***stripped***"
} else {
stripped[name] = stripValue(field, v)
}
}
return true
})
return stripped
}

// isCSI1Secret uses the csi.E_CsiSecret extension from CSI 1.0 to
// determine whether a field contains secrets.
func isCSI1Secret(field *protobufdescriptor.FieldDescriptorProto) bool {
ex, err := proto.GetExtension(field.Options, e_CsiSecret)
return err == nil && ex != nil && *ex.(*bool)
}

// Copied from the CSI 1.0 spec (https://github.com/container-storage-interface/spec/blob/37e74064635d27c8e33537c863b37ccb1182d4f8/lib/go/csi/csi.pb.go#L4520-L4527)
// to avoid a package dependency that would prevent usage of this package
// in repos using an older version of the spec.
//
// Future revisions of the CSI spec must not change this extension, otherwise
// they will break filtering in binaries based on the 1.0 version of the spec.
var e_CsiSecret = &proto.ExtensionDesc{
ExtendedType: (*protobufdescriptor.FieldOptions)(nil),
ExtensionType: (*bool)(nil),
Field: 1059,
Name: "csi.v1.csi_secret",
Tag: "varint,1059,opt,name=csi_secret,json=csiSecret",
Filename: "github.com/container-storage-interface/spec/csi.proto",
}

// isCSI03Secret relies on the naming convention in CSI <= 0.3
// to determine whether a field contains secrets.
func isCSI03Secret(field *protobufdescriptor.FieldDescriptorProto) bool {
return strings.HasSuffix(field.GetName(), "_secrets")
// isCSI1Secret reports whether the csi.E_CsiSecret field option is set
// on the given field descriptor.
func isCSI1Secret(desc protoreflect.FieldDescriptor) bool {
	secret := proto.GetExtension(desc.Options(), csi.E_CsiSecret).(bool)
	return secret
}
Loading