Skip to content

Commit

Permalink
Use atomic.Pointer for certState (#833)
Browse files Browse the repository at this point in the history
  • Loading branch information
nbrownus authored Mar 30, 2023
1 parent 2801fb2 commit 6b3d42e
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 11 deletions.
6 changes: 3 additions & 3 deletions connection_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,12 @@ func Test_NewConnectionManagerTest(t *testing.T) {
hostMap: hostMap,
inside: &test.NoopTun{},
outside: &udp.Conn{},
certState: cs,
firewall: &Firewall{},
lightHouse: lh,
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig),
l: l,
}
ifce.certState.Store(cs)
now := time.Now()

// Create manager
Expand Down Expand Up @@ -130,12 +130,12 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
hostMap: hostMap,
inside: &test.NoopTun{},
outside: &udp.Conn{},
certState: cs,
firewall: &Firewall{},
lightHouse: lh,
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig),
l: l,
}
ifce.certState.Store(cs)
now := time.Now()

// Create manager
Expand Down Expand Up @@ -245,14 +245,14 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
hostMap: hostMap,
inside: &test.NoopTun{},
outside: &udp.Conn{},
certState: cs,
firewall: &Firewall{},
lightHouse: lh,
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig),
l: l,
disconnectInvalid: true,
caPool: ncp,
}
ifce.certState.Store(cs)

// Create manager
ctx, cancel := context.WithCancel(context.Background())
Expand Down
2 changes: 1 addition & 1 deletion connection_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern
cs = noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
}

curCertState := f.certState
curCertState := f.certState.Load()
static := noise.DHKey{Private: curCertState.privateKey, Public: curCertState.publicKey}

b := NewBits(ReplayWindow)
Expand Down
2 changes: 1 addition & 1 deletion control_tester.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,5 +161,5 @@ func (c *Control) GetHostmap() *HostMap {
}

func (c *Control) GetCert() *cert.NebulaCertificate {
return c.f.certState.certificate
return c.f.certState.Load().certificate
}
11 changes: 6 additions & 5 deletions interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ type Interface struct {
hostMap *HostMap
outside *udp.Conn
inside overlay.Device
certState *CertState
certState atomic.Pointer[CertState]
cipher string
firewall *Firewall
connectionManager *connectionManager
Expand Down Expand Up @@ -141,7 +141,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
hostMap: c.HostMap,
outside: c.Outside,
inside: c.Inside,
certState: c.certState,
cipher: c.Cipher,
firewall: c.Firewall,
serveDns: c.ServeDns,
Expand Down Expand Up @@ -172,6 +171,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
l: c.l,
}

ifce.certState.Store(c.certState)
ifce.connectionManager = newConnectionManager(ctx, c.l, ifce, c.checkInterval, c.pendingDeletionInterval)

return ifce, nil
Expand Down Expand Up @@ -298,14 +298,15 @@ func (f *Interface) reloadCertKey(c *config.C) {
}

// did IP in cert change? if so, don't set
oldIPs := f.certState.certificate.Details.Ips
currentCert := f.certState.Load().certificate
oldIPs := currentCert.Details.Ips
newIPs := cs.certificate.Details.Ips
if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() {
f.l.WithField("new_ip", newIPs[0]).WithField("old_ip", oldIPs[0]).Error("IP in new cert was different from old")
return
}

f.certState = cs
f.certState.Store(cs)
f.l.WithField("cert", cs.certificate).Info("Client cert refreshed from disk")
}

Expand All @@ -316,7 +317,7 @@ func (f *Interface) reloadFirewall(c *config.C) {
return
}

fw, err := NewFirewallFromConfig(f.l, f.certState.certificate, c)
fw, err := NewFirewallFromConfig(f.l, f.certState.Load().certificate, c)
if err != nil {
f.l.WithError(err).Error("Error while creating firewall during reload")
return
Expand Down
2 changes: 1 addition & 1 deletion ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,7 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
return nil
}

cert := ifce.certState.certificate
cert := ifce.certState.Load().certificate
if len(a) > 0 {
parsedIp := net.ParseIP(a[0])
if parsedIp == nil {
Expand Down

0 comments on commit 6b3d42e

Please sign in to comment.