diff --git a/experimental/credentials/internal/spiffe.go b/experimental/credentials/internal/spiffe.go new file mode 100644 index 000000000000..74eb775388a3 --- /dev/null +++ b/experimental/credentials/internal/spiffe.go @@ -0,0 +1,75 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package credentials defines APIs for parsing SPIFFE ID. +// +// All APIs in this package are experimental. +package internal + +import ( + "crypto/tls" + "crypto/x509" + "net/url" + + "google.golang.org/grpc/grpclog" +) + +var logger = grpclog.Component("credentials") + +// SPIFFEIDFromState parses the SPIFFE ID from State. If the SPIFFE ID format +// is invalid, return nil with warning. +func SPIFFEIDFromState(state tls.ConnectionState) *url.URL { + if len(state.PeerCertificates) == 0 || len(state.PeerCertificates[0].URIs) == 0 { + return nil + } + return SPIFFEIDFromCert(state.PeerCertificates[0]) +} + +// SPIFFEIDFromCert parses the SPIFFE ID from x509.Certificate. If the SPIFFE +// ID format is invalid, return nil with warning. +func SPIFFEIDFromCert(cert *x509.Certificate) *url.URL { + if cert == nil || cert.URIs == nil { + return nil + } + var spiffeID *url.URL + for _, uri := range cert.URIs { + if uri == nil || uri.Scheme != "spiffe" || uri.Opaque != "" || (uri.User != nil && uri.User.Username() != "") { + continue + } + // From this point, we assume the uri is intended for a SPIFFE ID. + if len(uri.String()) > 2048 { + logger.Warning("invalid SPIFFE ID: total ID length larger than 2048 bytes") + return nil + } + if len(uri.Host) == 0 || len(uri.Path) == 0 { + logger.Warning("invalid SPIFFE ID: domain or workload ID is empty") + return nil + } + if len(uri.Host) > 255 { + logger.Warning("invalid SPIFFE ID: domain length larger than 255 characters") + return nil + } + // A valid SPIFFE certificate can only have exactly one URI SAN field. + if len(cert.URIs) > 1 { + logger.Warning("invalid SPIFFE ID: multiple URI SANs") + return nil + } + spiffeID = uri + } + return spiffeID +} diff --git a/experimental/credentials/internal/spiffe_test.go b/experimental/credentials/internal/spiffe_test.go new file mode 100644 index 000000000000..50af9ca5bff3 --- /dev/null +++ b/experimental/credentials/internal/spiffe_test.go @@ -0,0 +1,233 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package internal + +import ( + "crypto/tls" + "crypto/x509" + "encoding/pem" + "net/url" + "os" + "testing" + + "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/testdata" +) + +const wantURI = "spiffe://foo.bar.com/client/workload/1" + +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +func (s) TestSPIFFEIDFromState(t *testing.T) { + tests := []struct { + name string + urls []*url.URL + // If we expect a SPIFFE ID to be returned. + wantID bool + }{ + { + name: "empty URIs", + urls: []*url.URL{}, + wantID: false, + }, + { + name: "good SPIFFE ID", + urls: []*url.URL{ + { + Scheme: "spiffe", + Host: "foo.bar.com", + Path: "workload/wl1", + RawPath: "workload/wl1", + }, + }, + wantID: true, + }, + { + name: "invalid host", + urls: []*url.URL{ + { + Scheme: "spiffe", + Host: "", + Path: "workload/wl1", + RawPath: "workload/wl1", + }, + }, + wantID: false, + }, + { + name: "invalid path", + urls: []*url.URL{ + { + Scheme: "spiffe", + Host: "foo.bar.com", + Path: "", + RawPath: "", + }, + }, + wantID: false, + }, + { + name: "large path", + urls: []*url.URL{ + { + Scheme: "spiffe", + Host: "foo.bar.com", + Path: string(make([]byte, 2050)), + RawPath: string(make([]byte, 2050)), + }, + }, + wantID: false, + }, + { + name: "large host", + urls: []*url.URL{ + { + Scheme: "spiffe", + Host: string(make([]byte, 256)), + Path: "workload/wl1", + RawPath: "workload/wl1", + }, + }, + wantID: false, + }, + { + name: "multiple URI SANs", + urls: []*url.URL{ + { + Scheme: "spiffe", + Host: "foo.bar.com", + Path: "workload/wl1", + RawPath: "workload/wl1", + }, + { + Scheme: "spiffe", + Host: "bar.baz.com", + Path: "workload/wl2", + RawPath: "workload/wl2", + }, + { + Scheme: "https", + Host: "foo.bar.com", + Path: "workload/wl1", + RawPath: "workload/wl1", + }, + }, + wantID: false, + }, + { + name: "multiple URI SANs without SPIFFE ID", + urls: []*url.URL{ + { + Scheme: "https", + Host: "foo.bar.com", + Path: "workload/wl1", + RawPath: "workload/wl1", + }, + { + Scheme: "ssh", + Host: "foo.bar.com", + Path: "workload/wl1", + RawPath: "workload/wl1", + }, + }, + wantID: false, + }, + { + name: "multiple URI SANs with one SPIFFE ID", + urls: []*url.URL{ + { + Scheme: "spiffe", + Host: "foo.bar.com", + Path: "workload/wl1", + RawPath: "workload/wl1", + }, + { + Scheme: "https", + Host: "foo.bar.com", + Path: "workload/wl1", + RawPath: "workload/wl1", + }, + }, + wantID: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + state := tls.ConnectionState{PeerCertificates: []*x509.Certificate{{URIs: tt.urls}}} + id := SPIFFEIDFromState(state) + if got, want := id != nil, tt.wantID; got != want { + t.Errorf("want wantID = %v, but SPIFFE ID is %v", want, id) + } + }) + } +} + +func (s) TestSPIFFEIDFromCert(t *testing.T) { + tests := []struct { + name string + dataPath string + // If we expect a SPIFFE ID to be returned. + wantID bool + }{ + { + name: "good certificate with SPIFFE ID", + dataPath: "x509/spiffe_cert.pem", + wantID: true, + }, + { + name: "bad certificate with SPIFFE ID and another URI", + dataPath: "x509/multiple_uri_cert.pem", + wantID: false, + }, + { + name: "certificate without SPIFFE ID", + dataPath: "x509/client1_cert.pem", + wantID: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, err := os.ReadFile(testdata.Path(tt.dataPath)) + if err != nil { + t.Fatalf("os.ReadFile(%s) failed: %v", testdata.Path(tt.dataPath), err) + } + block, _ := pem.Decode(data) + if block == nil { + t.Fatalf("Failed to parse the certificate: byte block is nil") + } + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + t.Fatalf("x509.ParseCertificate(%b) failed: %v", block.Bytes, err) + } + uri := SPIFFEIDFromCert(cert) + if (uri != nil) != tt.wantID { + t.Fatalf("wantID got and want mismatch, got %t, want %t", uri != nil, tt.wantID) + } + if uri != nil && uri.String() != wantURI { + t.Fatalf("SPIFFE ID not expected, got %s, want %s", uri.String(), wantURI) + } + }) + } +} diff --git a/experimental/credentials/internal/syscallconn.go b/experimental/credentials/internal/syscallconn.go new file mode 100644 index 000000000000..6f5f88c62bfb --- /dev/null +++ b/experimental/credentials/internal/syscallconn.go @@ -0,0 +1,58 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package internal + +import ( + "net" + "syscall" +) + +type sysConn = syscall.Conn + +// syscallConn keeps reference of rawConn to support syscall.Conn for channelz. +// SyscallConn() (the method in interface syscall.Conn) is explicitly +// implemented on this type, +// +// Interface syscall.Conn is implemented by most net.Conn implementations (e.g. +// TCPConn, UnixConn), but is not part of net.Conn interface. So wrapper conns +// that embed net.Conn don't implement syscall.Conn. (Side note: tls.Conn +// doesn't embed net.Conn, so even if syscall.Conn is part of net.Conn, it won't +// help here). +type syscallConn struct { + net.Conn + // sysConn is a type alias of syscall.Conn. It's necessary because the name + // `Conn` collides with `net.Conn`. + sysConn +} + +// WrapSyscallConn tries to wrap rawConn and newConn into a net.Conn that +// implements syscall.Conn. rawConn will be used to support syscall, and newConn +// will be used for read/write. +// +// This function returns newConn if rawConn doesn't implement syscall.Conn. +func WrapSyscallConn(rawConn, newConn net.Conn) net.Conn { + sysConn, ok := rawConn.(syscall.Conn) + if !ok { + return newConn + } + return &syscallConn{ + Conn: newConn, + sysConn: sysConn, + } +} diff --git a/experimental/credentials/internal/syscallconn_test.go b/experimental/credentials/internal/syscallconn_test.go new file mode 100644 index 000000000000..12d1ad858314 --- /dev/null +++ b/experimental/credentials/internal/syscallconn_test.go @@ -0,0 +1,56 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package internal + +import ( + "net" + "syscall" + "testing" +) + +func (*syscallConn) SyscallConn() (syscall.RawConn, error) { + return nil, nil +} + +type nonSyscallConn struct { + net.Conn +} + +func (s) TestWrapSyscallConn(t *testing.T) { + sc := &syscallConn{} + nsc := &nonSyscallConn{} + + wrapConn := WrapSyscallConn(sc, nsc) + if _, ok := wrapConn.(syscall.Conn); !ok { + t.Errorf("returned conn (type %T) doesn't implement syscall.Conn, want implement", wrapConn) + } +} + +func (s) TestWrapSyscallConnNoWrap(t *testing.T) { + nscRaw := &nonSyscallConn{} + nsc := &nonSyscallConn{} + + wrapConn := WrapSyscallConn(nscRaw, nsc) + if _, ok := wrapConn.(syscall.Conn); ok { + t.Errorf("returned conn (type %T) implements syscall.Conn, want not implement", wrapConn) + } + if wrapConn != nsc { + t.Errorf("returned conn is %p, want %p (the passed-in newConn)", wrapConn, nsc) + } +} diff --git a/experimental/credentials/tls.go b/experimental/credentials/tls.go index 0a4457217237..b363db56520d 100644 --- a/experimental/credentials/tls.go +++ b/experimental/credentials/tls.go @@ -27,9 +27,10 @@ import ( "net/url" "os" + "golang.org/x/net/http2" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/experimental/credentials/internal" "google.golang.org/grpc/grpclog" - credinternal "google.golang.org/grpc/internal/credentials" ) var logger = grpclog.Component("credentials") @@ -91,7 +92,7 @@ func (c tlsCreds) Info() credentials.ProtocolInfo { func (c *tlsCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (_ net.Conn, _ credentials.AuthInfo, err error) { // use local cfg to avoid clobbering ServerName if using multiple endpoints - cfg := credinternal.CloneTLSConfig(c.config) + cfg := cloneTLSConfig(c.config) if cfg.ServerName == "" { serverName, _, err := net.SplitHostPort(authority) if err != nil { @@ -123,11 +124,11 @@ func (c *tlsCreds) ClientHandshake(ctx context.Context, authority string, rawCon SecurityLevel: credentials.PrivacyAndIntegrity, }, } - id := credinternal.SPIFFEIDFromState(conn.ConnectionState()) + id := internal.SPIFFEIDFromState(conn.ConnectionState()) if id != nil { tlsInfo.SPIFFEID = id } - return credinternal.WrapSyscallConn(rawConn, conn), tlsInfo, nil + return internal.WrapSyscallConn(rawConn, conn), tlsInfo, nil } func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { @@ -143,11 +144,11 @@ func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.Auth SecurityLevel: credentials.PrivacyAndIntegrity, }, } - id := credinternal.SPIFFEIDFromState(conn.ConnectionState()) + id := internal.SPIFFEIDFromState(conn.ConnectionState()) if id != nil { tlsInfo.SPIFFEID = id } - return credinternal.WrapSyscallConn(rawConn, conn), tlsInfo, nil + return internal.WrapSyscallConn(rawConn, conn), tlsInfo, nil } func (c *tlsCreds) Clone() credentials.TransportCredentials { @@ -175,8 +176,8 @@ var tls12ForbiddenCipherSuites = map[uint16]struct{}{ // NewTLSWithALPNDisabled uses c to construct a TransportCredentials based on // TLS. ALPN verification is disabled. func NewTLSWithALPNDisabled(c *tls.Config) credentials.TransportCredentials { - tc := &tlsCreds{credinternal.CloneTLSConfig(c)} - tc.config.NextProtos = credinternal.AppendH2ToNextProtos(tc.config.NextProtos) + tc := &tlsCreds{cloneTLSConfig(c)} + tc.config.NextProtos = appendH2ToNextProtos(tc.config.NextProtos) // If the user did not configure a MinVersion and did not configure a // MaxVersion < 1.2, use MinVersion=1.2, which is required by // https://datatracker.ietf.org/doc/html/rfc7540#section-9.2 @@ -257,3 +258,28 @@ type TLSChannelzSecurityValue struct { LocalCertificate []byte RemoteCertificate []byte } + +// cloneTLSConfig returns a shallow clone of the exported +// fields of cfg, ignoring the unexported sync.Once, which +// contains a mutex and must not be copied. +// +// If cfg is nil, a new zero tls.Config is returned. +func cloneTLSConfig(cfg *tls.Config) *tls.Config { + if cfg == nil { + return &tls.Config{} + } + + return cfg.Clone() +} + +// appendH2ToNextProtos appends h2 to next protos. +func appendH2ToNextProtos(ps []string) []string { + for _, p := range ps { + if p == http2.NextProtoTLS { + return ps + } + } + ret := make([]string, 0, len(ps)+1) + ret = append(ret, ps...) + return append(ret, http2.NextProtoTLS) +}