Skip to content

Commit

Permalink
Added auto-renewal of nebula certificates and GC on badgerDB (#22)
Browse files Browse the repository at this point in the history
* Added auto-renewal of nebula certificates and GC on badgerDB
  • Loading branch information
SlyngDK authored Nov 28, 2021
1 parent c89e42a commit b0c22d3
Show file tree
Hide file tree
Showing 6 changed files with 160 additions and 6 deletions.
6 changes: 6 additions & 0 deletions examples/server.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,9 @@
# authUrl: ""
# tokenUrl: ""
# userInfoUrl: ""

#tasks:
# certRenew:
# interval: 1h
# dbGC:
# interval: 5m
9 changes: 8 additions & 1 deletion server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type server struct {
ipManager *store.IPManager
unixGrpc *grpc.Server
agentService *grpc.Server
tasks *tasks
}

func Main(config *nebula.Config, buildVersion string, logger *logrus.Logger) (*Control, error) {
Expand All @@ -33,7 +34,7 @@ func Main(config *nebula.Config, buildVersion string, logger *logrus.Logger) (*C
FullTimestamp: true,
}

server := server{l, config, buildVersion, false, nil, nil, nil, nil}
server := server{l, config, buildVersion, false, nil, nil, nil, nil, nil}

return &Control{l, server.start, server.stop, make(chan interface{})}, nil
}
Expand Down Expand Up @@ -91,12 +92,18 @@ func (s *server) start() error {
if err != nil {
return err
}
s.tasks = NewTasks(s.l, s.config, s.store)
s.tasks.Start()
}

return nil
}

func (s *server) stop() {
if s.tasks != nil {
s.tasks.Stop()
}

if err := s.stopAgentService(); err != nil {
s.l.WithError(err).Error("Failed to stop agentService server")
}
Expand Down
65 changes: 65 additions & 0 deletions server/store/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package store

import (
"fmt"
"time"

"github.com/dgraph-io/badger/v3"
"github.com/golang/protobuf/proto"
Expand Down Expand Up @@ -50,6 +51,23 @@ func (s *Store) ListAgentByNetwork(networkName string) ([]*Agent, error) {
return s.listAgentByNetwork(txn, networkName)
}

func (s Store) RenewCertForAgents() error {
txn := s.db.NewTransaction(true)
defer txn.Discard()

err := s.renewCertForAgents(txn)
if err != nil {
return fmt.Errorf("failed to new certificates for agents")
}

err = txn.Commit()
if err != nil {
return fmt.Errorf("failed to add enrollment token: %s", err)
}

return nil
}

func (s *Store) isAgentEnrolled(txn *badger.Txn, fingerprint []byte) bool {
return exists(txn, prefix_agent, fingerprint)
}
Expand Down Expand Up @@ -136,3 +154,50 @@ func (s *Store) deleteAgent(txn *badger.Txn, fingerprint []byte) error {
}
return nil
}

func (s Store) renewCertForAgents(txn *badger.Txn) error {
renewThreshold := 7 * 24 * time.Hour

opts := badger.DefaultIteratorOptions
opts.PrefetchSize = 10
opts.Prefix = prefix_agent
it := txn.NewIterator(opts)
defer it.Close()

for it.Seek(prefix_agent); it.ValidForPrefix(prefix_agent); it.Next() {
item := it.Item()
err := item.Value(func(v []byte) error {
agent := &Agent{}
if err := proto.Unmarshal(v, agent); err != nil {
s.l.WithError(err).Error("Failed to parse agent")
return nil
}

untilExpires := time.Until(agent.ExpiresAt.AsTime())

if untilExpires.Hours() < renewThreshold.Hours() {
s.l.Debugf("Renewing certificate for agent: %s %x", agent.Name, agent.Fingerprint)
ip, err := assignedIPToIPNet(agent.AssignedIP)
if err != nil {
return fmt.Errorf("failed to parse ip of agent: %s", err)
}

agent, err = s.signCSR(txn, agent, ip)
if err != nil {
return fmt.Errorf("failed to sign agent csr: %s", err)
}

agent, err = s.updateAgent(txn, agent)
if err != nil {
return fmt.Errorf("failed to update agent as part of renewing agent cerfiticate: %s", err)
}
}

return nil
})
if err != nil {
return err
}
}
return nil
}
6 changes: 1 addition & 5 deletions server/store/enrollment.go
Original file line number Diff line number Diff line change
Expand Up @@ -343,14 +343,10 @@ func (s *Store) approveEnrollmentRequest(txn *badger.Txn, ipManager *IPManager,
}

if enrolled {
i, n, err := net.ParseCIDR(agent.AssignedIP)
ip, err = assignedIPToIPNet(agent.AssignedIP)
if err != nil {
return nil, fmt.Errorf("failed to parse ip of existing agent: %s", err)
}
ip = &net.IPNet{
IP: i,
Mask: n.Mask,
}
}

if ip == nil {
Expand Down
20 changes: 20 additions & 0 deletions server/store/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/hex"
"fmt"
"io/ioutil"
"net"
"os"
"path"
"path/filepath"
Expand Down Expand Up @@ -135,6 +136,14 @@ func (s *Store) Unseal(keyPart string, removeExistingParts bool) error {
return nil
}

func (s *Store) GC() {
again:
err := s.db.RunValueLogGC(0.7)
if err == nil {
goto again
}
}

func NewStore(l *logrus.Logger, dataDir string, unsealed chan interface{}, encryptionEnabled bool) (*Store, error) {
dbPath := filepath.Join(dataDir, "db")
stat, err := os.Stat(dbPath)
Expand Down Expand Up @@ -208,3 +217,14 @@ func containsByteSlice(array [][]byte, value []byte) bool {
}
return false
}

func assignedIPToIPNet(assignedIP string) (*net.IPNet, error) {
i, n, err := net.ParseCIDR(assignedIP)
if err != nil {
return nil, err
}
return &net.IPNet{
IP: i,
Mask: n.Mask,
}, nil
}
60 changes: 60 additions & 0 deletions server/tasks.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package server

import (
"time"

"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
"github.com/slyngdk/nebula-provisioner/server/store"
)

type tasks struct {
config *nebula.Config
store *store.Store

l *logrus.Logger
quit chan interface{}
}

func NewTasks(l *logrus.Logger, config *nebula.Config, store *store.Store) *tasks {
return &tasks{l: l, config: config, store: store, quit: make(chan interface{})}
}

func (t *tasks) Start() {
t.l.Infoln("Starting task scheduler")

renewCertTicker := time.NewTicker(t.config.GetDuration("tasks.certRenew.interval", 1*time.Hour))
dbGCTicker := time.NewTicker(t.config.GetDuration("tasks.dbGC.interval", 5*time.Minute))

go func() {
for {
select {
case <-renewCertTicker.C:
t.renewCerts()
case <-dbGCTicker.C:
t.dbGC()
case <-t.quit:
renewCertTicker.Stop()
return
}
}
}()
}

func (t *tasks) Stop() {
t.l.Infoln("Stopping task scheduler")
t.quit <- struct{}{}
}

func (t *tasks) renewCerts() {
t.l.Infoln("Task: renew agent certificates")
err := t.store.RenewCertForAgents()
if err != nil {
t.l.WithError(err).Errorln("error when renewing certificates for agents")
}
}

func (t *tasks) dbGC() {
t.l.Debugln("Task: db garbage collection")
t.store.GC()
}

0 comments on commit b0c22d3

Please sign in to comment.