-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtls-cache.go
166 lines (147 loc) · 4.22 KB
/
tls-cache.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
package fleet
import (
"crypto"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"log/slog"
"sync"
"time"
"github.com/KarpelesLab/tpmlib"
)
// crtCache caches a TLS certificate loaded through an Agent under the key
// prefix k, together with the outcome and time of the last load attempt.
type crtCache struct {
	a   *Agent           // agent whose dbFleetLoad/dbFleetGet/dbFleetDel are used to access the stored cert and key
	k   string           // key prefix; the certificate is read from k+":crt" and the private key from k+":key"
	lk  sync.Mutex       // held while the cached fields below are refreshed
	t   time.Time        // time of the last load attempt (set just before loadCert is called)
	crt *tls.Certificate // result of the last load, nil if it failed
	exp time.Time        // expiration time
	err error            // error returned by the last load, nil on success
}
// GetClientCertificate returns the cached certificate for use as a TLS client
// certificate. It never fails: when no certificate can be obtained it hands
// back an empty tls.Certificate so the handshake proceeds without one.
func (c *crtCache) GetClientCertificate(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
	// The load error is dropped on purpose: a failure simply means we attempt
	// the connection without a client certificate.
	if cert, _ := c.GetCertificate(nil); cert != nil {
		return cert, nil
	}
	// Go documentation: unlike GetCertificate, GetClientCertificate must
	// return a non-nil certificate.
	return &tls.Certificate{}, nil
}
// GetCertificate returns the cached certificate, reloading it first when the
// cache is stale. The cache is considered fresh while c.exp lies in the future
// and the last load attempt is less than 24 hours old.
//
// The returned values are the result of the most recent load: either a
// certificate and a nil error, or a nil certificate and the load error. A
// failed load is logged as a warning.
func (c *crtCache) GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) {
	// All accesses to the cached fields go through c.lk. Checking freshness
	// before taking the lock would read c.exp, c.t, c.crt and c.err while
	// another goroutine may be writing them, which is a data race and can
	// expose a half-updated state (e.g. a fresh c.exp with c.crt still nil).
	c.lk.Lock()
	defer c.lk.Unlock()

	if time.Until(c.exp) > 0 && time.Since(c.t) < 24*time.Hour {
		return c.crt, c.err
	}

	// Record the attempt time before loading so that a failed load is not
	// retried until the freshness window says so.
	c.t = time.Now()
	c.crt, c.err = c.loadCert(true)
	if c.err != nil {
		slog.Warn(fmt.Sprintf("[tls] Failed to fetch %s certificate: %s", c.k, c.err), "event", "fleet:tls:fetch_fail")
	}
	return c.crt, c.err
}
// PrivateKey returns the private key of the cached certificate, reloading the
// certificate first when the last load attempt is 24 hours old or more.
//
// Unlike GetCertificate, freshness here depends only on the time of the last
// load attempt (c.t), not on c.exp. If the most recent load failed, its error
// is returned and the key is nil.
func (c *crtCache) PrivateKey() (crypto.PrivateKey, error) {
	// Hold the lock for the whole call. c.t is updated before the load
	// completes, so an unlocked check could observe a recent c.t while c.crt
	// and c.err are both still nil and then dereference a nil certificate.
	c.lk.Lock()
	defer c.lk.Unlock()

	if time.Since(c.t) >= 24*time.Hour {
		c.t = time.Now()
		c.crt, c.err = c.loadCert(true)
		if c.err != nil {
			slog.Warn(fmt.Sprintf("[tls] Failed to fetch %s certificate: %s", c.k, c.err), "event", "fleet:tls:fetch_fail")
		}
	}

	if c.err != nil {
		return nil, c.err
	}
	return c.crt.PrivateKey, nil
}
// loadCert actually fetches the certificate and instantiates a tls.Certificate.
//
// The PEM certificate is read from c.k+":crt" and the PEM private key from
// c.k+":key". When the key cannot be read and c.k is "internal_key", the TPM is
// asked for a signer instead; if it provides one, the certificate chain is
// paired with that signer without checking that they match.
//
// On success c.exp is set to 24 hours before the leaf's NotAfter. If the
// keypair cannot be instantiated or fails the validity check and allowRetry is
// true, the stored entries are deleted and the load is attempted once more; on
// the final failure c.exp is set one hour ahead and an error is returned.
//
// loadCert writes c.exp without locking; the callers in this file invoke it
// while holding c.lk.
func (c *crtCache) loadCert(allowRetry bool) (*tls.Certificate, error) {
	crt, err := c.a.dbFleetLoad(c.k + ":crt")
	if err != nil {
		return nil, err
	}

	key, err := c.a.dbFleetGet(c.k + ":key")
	if err != nil && c.k == "internal_key" {
		// check for tpm key; a TPM failure is kept in its own variable so the
		// key lookup error above is what gets reported further down
		var s crypto.Signer
		s, tpmErr := tpmlib.GetKey()
		if tpmErr == nil {
			// we need to generate the appropriate object to use this certificate with the tpm
			res := &tls.Certificate{}
			var derBlock *pem.Block
			for {
				derBlock, crt = pem.Decode(crt)
				if derBlock == nil {
					break
				}
				if derBlock.Type == "CERTIFICATE" {
					res.Certificate = append(res.Certificate, derBlock.Bytes)
				}
			}
			if len(res.Certificate) == 0 {
				return nil, errors.New("tls: failed to find any PEM data in internal_key:crt certificate input")
			}
			// note that we aren't checking if the certificate matches the key, maybe we should but it's not cheap on an external auth device
			res.PrivateKey = s

			// Record the expiration like the regular path does. Without it
			// c.exp is never set for TPM-backed certificates, so
			// GetCertificate's freshness check never passes and every call
			// reloads the certificate and queries the TPM again. A leaf that
			// fails to parse is tolerated and simply leaves c.exp untouched.
			if leaf, parseErr := x509.ParseCertificate(res.Certificate[0]); parseErr == nil {
				res.Leaf = leaf
				c.exp = leaf.NotAfter.Add(-24 * time.Hour)
			}
			return res, nil
		}
	}
	if err != nil {
		return nil, err
	}

	res, err := crtCacheLoadAndCheck(crt, key)
	if err != nil {
		// remove from local data cache and try again to see if that helps
		if allowRetry {
			c.a.dbFleetDel(c.k+":crt", c.k+":key")
			return c.loadCert(false)
		}
		// give up
		c.exp = time.Now().Add(time.Hour) // force retry in 1h
		return nil, fmt.Errorf("while instanciating tls keypair: %w", err)
	}

	// set expiration 24 hours before actual expiration, typically we fetch the new cert sooner
	c.exp = res.Leaf.NotAfter.Add(-24 * time.Hour)
	return &res, nil
}
// crtCacheLoadAndCheck builds a tls.Certificate from a PEM-encoded certificate
// chain and private key, makes sure its Leaf field is populated, and verifies
// that the leaf is inside the validity window accepted by the cache.
//
// An error is returned when the keypair cannot be parsed, when the current
// time is before the leaf's NotBefore, or when it is later than 24 hours
// before the leaf's NotAfter. Returning an error in those cases lets the
// caller clear its cache and fetch a new certificate.
func crtCacheLoadAndCheck(certPEMBlock, keyPEMBlock []byte) (tls.Certificate, error) {
	pair, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock)
	if err != nil {
		return pair, err
	}

	// tls.X509KeyPair does not necessarily fill in Leaf; parse it ourselves
	// when it is missing.
	leaf := pair.Leaf
	if leaf == nil {
		if leaf, err = x509.ParseCertificate(pair.Certificate[0]); err != nil {
			return pair, err
		}
		pair.Leaf = leaf
	}

	now := time.Now()
	switch {
	case now.Before(leaf.NotBefore):
		return pair, fmt.Errorf("certificate is not valid yet (now=%s notbefore=%s)", now, leaf.NotBefore)
	case now.After(leaf.NotAfter.Add(-24 * time.Hour)):
		// Treated as expired one day early so a replacement is fetched in time.
		return pair, fmt.Errorf("certificate has expired (now=%s notafter=%s)", now, leaf.NotAfter)
	}

	// all good
	return pair, nil
}