Skip to content

Commit

Permalink
Refactored agent enroll, to create a new enrollment request if input …
Browse files Browse the repository at this point in the history
…is different from currently.
  • Loading branch information
SlyngDK committed Nov 26, 2021
1 parent e4ac087 commit e7e0f04
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 73 deletions.
190 changes: 117 additions & 73 deletions cmd/agent/enrollment.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ import (
"fmt"
"io"
"io/ioutil"
"net"
"os"
"strings"
"time"

"github.com/slackhq/nebula/cert"
Expand All @@ -28,22 +28,7 @@ var enrollCmd = &cobra.Command{
}
defer agent.Close()

token, err := cmd.Flags().GetString("token")
if err != nil {
l.WithError(err).Fatalln("failed to get token")
}
groups, err := cmd.Flags().GetString("groups")
if err != nil {
l.WithError(err).Fatalln("failed to get groups")
}
ip, err := cmd.Flags().GetString("ip")
if err != nil {
l.WithError(err).Fatalln("failed to get ip")
}

if err = enroll(agent, token, groups, ip); err != nil {
l.WithError(err).Fatalln("failed to enroll to server")
}
updateEnrollmentRequest(agent, cmd)
},
}

Expand Down Expand Up @@ -91,82 +76,146 @@ var enrollWaitCmd = &cobra.Command{
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()

status := getStatus(agent)
if status == 0 {
token, _ := cmd.Flags().GetString("token")
if token != "" {
groups, err := cmd.Flags().GetString("groups")
if err != nil {
l.WithError(err).Fatalln("failed to get groups")
}
ip, err := cmd.Flags().GetString("ip")
if err != nil {
l.WithError(err).Fatalln("failed to get ip")
}
updateEnrollmentRequest(agent, cmd)

if err := enroll(agent, token, groups, ip); err != nil {
l.WithError(err).Fatalln("failed to enroll to server")
os.Exit(2)
}
} else {
l.Errorln("Require the enrollment has been started, you can provide enrollment token and doing it in one process.")
os.Exit(1)
}
}
if status == 2 {
if isEnrollDone(agent) {
return
}

for {
select {
case _ = <-ticker.C:
status := getStatus(agent)
if status == 2 {
l.Info("Agent is now enrolled")
if isEnrollDone(agent) {
return
}
}
}
},
}

func getStatus(agent *agentClient) int8 {
func init() {
enrollCmd.Flags().StringP("token", "t", "", "Enrollment token")
enrollCmd.Flags().StringSliceP("groups", "g", []string{}, "Comma separated list of groups")
enrollCmd.Flags().StringP("ip", "i", "", "Requesting for this specific nebula ip")
enrollCmd.MarkFlagRequired("token")
enrollWaitCmd.Flags().StringP("token", "t", "", "Enrollment token")
enrollWaitCmd.Flags().StringSliceP("groups", "g", []string{}, "Comma separated list of groups")
enrollWaitCmd.Flags().StringP("ip", "i", "", "Requesting for this specific nebula ip")
enrollWaitCmd.MarkFlagRequired("token")

enrollCmd.AddCommand(enrollStatusCmd)
enrollCmd.AddCommand(enrollWaitCmd)
}

func updateEnrollmentRequest(agent *agentClient, cmd *cobra.Command) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

res, err := agent.client.GetEnrollStatus(ctx, &emptypb.Empty{})
status, err := agent.client.GetEnrollStatus(ctx, &emptypb.Empty{})
if err != nil {
l.WithError(err).Fatalln("failed to get enrollment status")
os.Exit(1)
}

if res.IsEnrolled {
l.Info("Agent is enrolled")
l.Infof("IssuedAt: %s, ExpiresAt: %s\n", res.IssuedAt.AsTime().Format(time.RFC3339), res.ExpiresAt.AsTime().Format(time.RFC3339))
return 2
} else if res.IsEnrollmentRequested {
l.Info("Agent has requested to be enrolled")
return 1
token, err := cmd.Flags().GetString("token")
if err != nil {
l.WithError(err).Fatalln("failed to get token")
}
if token == "" {
l.Errorln("Require the enrollment has been started, you can provide enrollment token and doing it in one process.")
os.Exit(1)
}
groups, err := cmd.Flags().GetStringSlice("groups")
if err != nil {
l.WithError(err).Fatalln("failed to get groups")
}
ip, err := cmd.Flags().GetString("ip")
if err != nil {
l.WithError(err).Fatalln("failed to get ip")
} else {
l.Info("Agent enrollment not started")
return 0
parseIP := net.ParseIP(ip)
if parseIP == nil && ip != "" {
l.WithError(err).Fatalln("ip is of invalid format")
os.Exit(1)
}
}
hostname, err := os.Hostname()
if err != nil {
l.WithError(err).Errorln("error when getting hostname")
os.Exit(1)
}

var diff = false

if status.EnrollmentRequest != nil {
l.Debug("comparing against existing enrollment request")
l.Debugf("hostname compare %s <=> %s", hostname, status.EnrollmentRequest.Name)
if hostname != status.EnrollmentRequest.Name {
diff = true
l.Debugf("diff on hostname")
}

l.Debugf("ip compare %s <=> %s", ip, status.EnrollmentRequest.RequestedIP)
if ip != status.EnrollmentRequest.RequestedIP {
diff = true
l.Debugf("diff on ip")
}

l.Debugf("groups compare %s <=> %s", groups, status.EnrollmentRequest.Groups)
if !stringSlicesEqual(groups, status.EnrollmentRequest.Groups) {
diff = true
l.Debugf("diff on groups")
}
} else if status.IsEnrolled {
l.Debug("comparing against enrolled agent")
l.Debugf("hostname compare %s <=> %s", hostname, status.Name)
if hostname != status.Name {
diff = true
l.Debugf("diff on hostname")
}

l.Debugf("ip compare %s <=> %s", ip, status.AssignedIP)
if ip != status.AssignedIP {
diff = true
l.Debugf("diff on ip")
}

l.Debugf("groups compare %s <=> %s", groups, status.Groups)
if !stringSlicesEqual(groups, status.Groups) {
diff = true
l.Debugf("diff on groups")
}
} else {
diff = true
}

if diff {
l.Info("adding enrollment request")
if err := enroll(agent, token, ip, groups); err != nil {
l.WithError(err).Fatalln("failed to enroll agent")
os.Exit(2)
}
}
}

func init() {
enrollCmd.Flags().StringP("token", "t", "", "Enrollment token")
enrollCmd.Flags().StringP("groups", "g", "", "Comma separated list of groups")
enrollCmd.Flags().StringP("ip", "i", "", "Requesting for this specific nebula ip")
enrollCmd.MarkFlagRequired("token")
enrollWaitCmd.Flags().StringP("token", "t", "", "Enrollment token")
enrollWaitCmd.Flags().StringP("groups", "g", "", "Comma separated list of groups")
enrollWaitCmd.Flags().StringP("ip", "i", "", "Requesting for this specific nebula ip")
enrollWaitCmd.MarkFlagRequired("token")
func isEnrollDone(agent *agentClient) bool {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
status, err := agent.client.GetEnrollStatus(ctx, &emptypb.Empty{})
if err != nil {
l.WithError(err).Fatalln("failed to get enrollment status")
os.Exit(1)
}

enrollCmd.AddCommand(enrollStatusCmd)
enrollCmd.AddCommand(enrollWaitCmd)
if status.IsEnrolled && !status.IsEnrollmentRequested {
l.Info("Agent is enrolled")
return true
}

return false
}

func enroll(c *agentClient, enrollmentToken, groups, ip string) error {
func enroll(c *agentClient, enrollmentToken, ip string, groups []string) error {
if enrollmentToken == "" {
return fmt.Errorf("requires enrollmentToken")
}
Expand All @@ -181,11 +230,8 @@ func enroll(c *agentClient, enrollmentToken, groups, ip string) error {
CsrPEM: string(csr),
}

if groups != "" {
g := strings.Split(groups, ",")
if len(g) > 0 {
enrollRequest.Groups = g
}
if len(groups) > 0 {
enrollRequest.Groups = groups
}

if ip != "" {
Expand All @@ -200,13 +246,11 @@ func enroll(c *agentClient, enrollmentToken, groups, ip string) error {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

res, err := c.client.Enroll(ctx, enrollRequest)
_, err = c.client.Enroll(ctx, enrollRequest)
if err != nil {
return err
}

c.l.Println(res)

return nil
}

Expand Down
16 changes: 16 additions & 0 deletions cmd/agent/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main
import (
"os"
"path/filepath"
"sort"
)

func fileExists(filepath string) (bool, os.FileInfo) {
Expand Down Expand Up @@ -52,3 +53,18 @@ func resolvePath(path string) string {
}
return path
}

func stringSlicesEqual(a, b []string) bool {
if len(a) != len(b) {
return false
}
sort.Strings(a)
sort.Strings(b)

for i, v := range a {
if v != b[i] {
return false
}
}
return true
}

0 comments on commit e7e0f04

Please sign in to comment.