diff --git a/certtostore.go b/certtostore.go index 0a5a993..07c4ebd 100644 --- a/certtostore.go +++ b/certtostore.go @@ -18,6 +18,8 @@ package certtostore import ( "bytes" "crypto" + "crypto/ecdsa" + "crypto/elliptic" "crypto/rand" "crypto/rsa" "crypto/tls" @@ -28,6 +30,10 @@ import ( "io/ioutil" "os" "path/filepath" + + // BEGIN-INTERNAL + // intnernal content #1 + // END-INTERNAL ) const ( @@ -106,7 +112,7 @@ type FileStorage struct { certFile string caCertFile string keyFile string - key *rsa.PrivateKey + key crypto.Signer } // NewFileStorage sets up a new file storage struct for use by StoreCert. @@ -166,14 +172,29 @@ func (f FileStorage) CertificateChain() ([][]*x509.Certificate, error) { return certificateChain(cert, intermediate) } +var ecdsaCurves = map[int]elliptic.Curve{ + 224: elliptic.P224(), + 256: elliptic.P256(), + 384: elliptic.P384(), + 521: elliptic.P521(), +} + // Generate creates a new RSA private key and returns a signer that can be used to make a CSR for the key. func (f *FileStorage) Generate(opts GenerateOpts) (crypto.Signer, error) { + var err error switch opts.Algorithm { case RSA: - var err error f.key, err = rsa.GenerateKey(rand.Reader, opts.Size) return f.key, err + case EC: + curve, ok := ecdsaCurves[opts.Size] + if !ok { + return nil, fmt.Errorf("invalid ecdsa curve size: %d", opts.Size) + } + f.key, err = ecdsa.GenerateKey(curve, rand.Reader) + return f.key, err default: + return nil, fmt.Errorf("unsupported key type: %q", opts.Algorithm) } } diff --git a/certtostore_test.go b/certtostore_test.go index b6e8b21..2304bcb 100644 --- a/certtostore_test.go +++ b/certtostore_test.go @@ -173,67 +173,85 @@ func TestDecrypt(t *testing.T) { } func TestFileStore(t *testing.T) { - pem, err := testdata.Certificate() - if err != nil { - t.Fatalf("testdata.Certificate: %v", err) - } - xc, err := PEMToX509(pem) - if err != nil { - t.Fatalf("error decoding test certificate: %v", err) - } + for _, testCase := range []struct { + name string + opts GenerateOpts + }{ + { + name: "rsa-2048", + opts: GenerateOpts{ + Algorithm: RSA, + Size: 2048, + }, + }, + { + name: "ecdsa-p256", + opts: GenerateOpts{ + Algorithm: EC, + Size: 256, + }, + }, + } { + t.Run(testCase.name, func(t *testing.T) { + pem, err := testdata.Certificate() + if err != nil { + t.Fatalf("testdata.Certificate: %v", err) + } + xc, err := PEMToX509(pem) + if err != nil { + t.Fatalf("error decoding test certificate: %v", err) + } - dir, err := ioutil.TempDir("", "certstorage_test") - if err != nil { - t.Fatalf("failed to create temporary dir: %v", err) - } - tc := NewFileStorage(dir) - cert, err := tc.Cert() - if err != nil { - t.Errorf("error while reading empty cert: %v", err) - } - if cert != nil { - t.Errorf("expected cert on new file store to be nil, instead %v", cert) - } + dir, err := ioutil.TempDir("", "certstorage_test") + if err != nil { + t.Fatalf("failed to create temporary dir: %v", err) + } + tc := NewFileStorage(dir) + cert, err := tc.Cert() + if err != nil { + t.Errorf("error while reading empty cert: %v", err) + } + if cert != nil { + t.Errorf("expected cert on new file store to be nil, instead %v", cert) + } - cert, err = tc.Intermediate() - if err != nil { - t.Errorf("error while reading empty intermediate: %v", err) - } - if cert != nil { - t.Errorf("expected intermediate on new file store to be nil, instead %v", cert) - } + cert, err = tc.Intermediate() + if err != nil { + t.Errorf("error while reading empty intermediate: %v", err) + } + if cert != nil { + t.Errorf("expected intermediate on new file store to be nil, instead %v", cert) + } - opts := GenerateOpts{ - Algorithm: RSA, - Size: 2048, - } - signer, err := tc.Generate(opts) - if err != nil { - t.Errorf("failed to generate signer: %v", err) - } - _, err = x509.CreateCertificateRequest(rand.Reader, &x509.CertificateRequest{}, signer) - if err != nil { - t.Errorf("failed to create signed CSR with signer from Generate: %v", err) - } + signer, err := tc.Generate(testCase.opts) + if err != nil { + t.Errorf("failed to generate signer: %v", err) + } + _, err = x509.CreateCertificateRequest(rand.Reader, &x509.CertificateRequest{}, signer) + if err != nil { + t.Errorf("failed to create signed CSR with signer from Generate: %v", err) + } - if err := tc.Store(xc, xc); err != nil { - t.Errorf("store failed: %v", err) - } + if err := tc.Store(xc, xc); err != nil { + t.Errorf("store failed: %v", err) + } - cert, err = tc.Cert() - if err != nil { - t.Fatalf("error while reading back written cert: %v", err) - } - if !cert.Equal(xc) { - t.Errorf("expected read-back cert to match xc, instead it's %v", cert) - } + cert, err = tc.Cert() + if err != nil { + t.Fatalf("error while reading back written cert: %v", err) + } + if !cert.Equal(xc) { + t.Errorf("expected read-back cert to match xc, instead it's %v", cert) + } - cert, err = tc.Intermediate() - if err != nil { - t.Fatalf("error while reading back written intermediate: %v", err) - } - if !cert.Equal(xc) { - t.Errorf("expected read-back intermediate to match xc, instead it's %v", cert) + cert, err = tc.Intermediate() + if err != nil { + t.Fatalf("error while reading back written intermediate: %v", err) + } + if !cert.Equal(xc) { + t.Errorf("expected read-back intermediate to match xc, instead it's %v", cert) + } + }) } }