diff --git a/cmd/verify/main.go b/cmd/verify/main.go index 718cca5..67a6c87 100644 --- a/cmd/verify/main.go +++ b/cmd/verify/main.go @@ -2,7 +2,9 @@ package main import ( "context" + "encoding/csv" "os" + "strings" "github.com/rs/zerolog/log" @@ -27,12 +29,22 @@ func main() { firestoreProjectID = v } + var extraCaCerts []string + if v, ok := os.LookupEnv("EXTRA_CA_CERTS"); ok { + var err error + extraCaCerts, err = csv.NewReader(strings.NewReader(v)).Read() + if err != nil { + log.Fatal().Err(err).Msg("failed to parse $EXTRA_CA_CERTS (expected comma-separated list of file paths)") + } + } + srv := verify.New( verify.WithBindAddress(addr), verify.WithFirestoreProjectID(firestoreProjectID), verify.WithJWKSEndpoint(jwksEndpoint), verify.WithExpectedJWTIssuer(os.Getenv("EXPECTED_JWT_ISSUER")), verify.WithExpectedJWTAudience(os.Getenv("EXPECTED_JWT_AUDIENCE")), + verify.WithExtraCACerts(extraCaCerts...), ) err := srv.Run(context.Background()) if err != nil { diff --git a/config.go b/config.go index 6036f73..3a71bfb 100644 --- a/config.go +++ b/config.go @@ -15,6 +15,7 @@ type config struct { jwksEndpoint string expectedJWTIssuer string expectedJWTAudience string + extraCACerts []string } // An Option customizes the config. @@ -57,6 +58,15 @@ func WithFirestoreProjectID(projectID string) Option { } } +// WithExtraCACerts adds paths to custom CA certificates to the config. +// Certificates added with this option will be used in addition to the system +// default pool. +func WithExtraCACerts(paths ...string) Option { + return func(cfg *config) { + cfg.extraCACerts = append(cfg.extraCACerts, paths...) + } +} + func getConfig(options ...Option) *config { cfg := new(config) WithBindAddress(DefaultBindAddress)(cfg) diff --git a/tls.go b/tls.go index 4945c06..531d06f 100644 --- a/tls.go +++ b/tls.go @@ -14,13 +14,21 @@ import ( const maxRemoteWait = 5 * time.Second +type tlsVerifierOptions struct { + rootCAs *x509.CertPool +} + type tlsVerifier struct { + tlsVerifierOptions mu sync.Mutex errors map[string]error } -func newTLSVerifier() *tlsVerifier { - return &tlsVerifier{errors: make(map[string]error)} +func newTLSVerifier(opts tlsVerifierOptions) *tlsVerifier { + return &tlsVerifier{ + tlsVerifierOptions: opts, + errors: make(map[string]error), + } } func (v *tlsVerifier) DialTLSContext(ctx context.Context, network, addr string) (net.Conn, error) { @@ -57,6 +65,7 @@ func (v *tlsVerifier) VerifyPeerCertificate(serverName string, rawCerts [][]byte opts := x509.VerifyOptions{ DNSName: serverName, + Roots: v.rootCAs, Intermediates: x509.NewCertPool(), } for _, cert := range certs[1:] { diff --git a/verify.go b/verify.go index 16b1e9a..78db123 100644 --- a/verify.go +++ b/verify.go @@ -2,8 +2,10 @@ package verify import ( "context" + "crypto/x509" "net" "net/http" + "os" "cloud.google.com/go/firestore" "github.com/go-chi/chi" @@ -26,9 +28,30 @@ type Server struct { // New creates a new Server. func New(options ...Option) *Server { cfg := getConfig(options...) + + var verifierOpts tlsVerifierOptions + if len(cfg.extraCACerts) > 0 { + pool, err := x509.SystemCertPool() + if err != nil { + panic(err) + } + for _, certPath := range cfg.extraCACerts { + cert, err := os.ReadFile(certPath) + if err != nil { + log.Fatal().Err(err).Str("path", certPath).Msg("failed to read CA cert") + } else { + log.Info().Str("path", certPath).Msg("adding extra CA cert") + } + ok := pool.AppendCertsFromPEM(cert) + if !ok { + log.Warn().Str("path", certPath).Msg("no CA certs found in file") + } + } + verifierOpts.rootCAs = pool + } return &Server{ cfg: cfg, - tlsVerifier: newTLSVerifier(), + tlsVerifier: newTLSVerifier(verifierOpts), } }