Skip to content

Commit

Permalink
Use Agent API for locking mirrors and known_hosts
Browse files Browse the repository at this point in the history
  • Loading branch information
DrJosh9000 committed Jun 8, 2023
1 parent 17c13fa commit 97a4fdf
Show file tree
Hide file tree
Showing 9 changed files with 103 additions and 37 deletions.
2 changes: 1 addition & 1 deletion agent/job_runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -732,7 +732,7 @@ func (w LogWriter) Write(bytes []byte) (int, error) {
func (r *JobRunner) executePreBootstrapHook(ctx context.Context, hook string) (bool, error) {
r.logger.Info("Running pre-bootstrap hook %q", hook)

sh, err := shell.New()
sh, err := shell.New(ctx, shell.Config{})
if err != nil {
return false, err
}
Expand Down
14 changes: 9 additions & 5 deletions bootstrap/bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,20 @@ func New(conf Config) *Bootstrap {
func (b *Bootstrap) Run(ctx context.Context) (exitCode int) {
// Check if not nil to allow for tests to overwrite shell
if b.shell == nil {
var err error
b.shell, err = shell.New()
// The Agent API socket could be important for file locking.
cfg := shell.Config{
SocketsPath: b.Config.SocketsPath,
}
sh, err := shell.New(ctx, cfg)
if err != nil {
fmt.Printf("Error creating shell: %v", err)
return 1
}

b.shell.PTY = b.Config.RunInPty
b.shell.Debug = b.Config.Debug
b.shell.InterruptSignal = b.Config.CancelSignal
sh.PTY = b.Config.RunInPty
sh.Debug = b.Config.Debug
sh.InterruptSignal = b.Config.CancelSignal
b.shell = sh
}
if experiments.IsEnabled(experiments.KubernetesExec) {
kubernetesClient := &kubernetes.Client{}
Expand Down
4 changes: 2 additions & 2 deletions bootstrap/bootstrap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func TestStartTracing_NoTracingBackend(t *testing.T) {
b := New(Config{})

oriCtx := context.Background()
b.shell, err = shell.New()
b.shell, err = shell.New(oriCtx, shell.Config{})
assert.NoError(t, err)

span, _, stopper := b.startTracing(oriCtx)
Expand All @@ -107,7 +107,7 @@ func TestStartTracing_Datadog(t *testing.T) {
b := New(cfg)

oriCtx := context.Background()
b.shell, err = shell.New()
b.shell, err = shell.New(oriCtx, shell.Config{})
assert.NoError(t, err)

span, ctx, stopper := b.startTracing(oriCtx)
Expand Down
74 changes: 67 additions & 7 deletions bootstrap/shell/shell.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ import (
"github.com/opentracing/opentracing-go"

"github.com/buildkite/agent/v3/env"
"github.com/buildkite/agent/v3/experiments"
"github.com/buildkite/agent/v3/internal/shellscript"
"github.com/buildkite/agent/v3/lock"
"github.com/buildkite/agent/v3/logger"
"github.com/buildkite/agent/v3/process"
"github.com/buildkite/agent/v3/tracetools"
Expand Down Expand Up @@ -67,23 +69,45 @@ type Shell struct {
cmd *command
cmdLock sync.Mutex

// Lock service client, if available
lockClient *lock.Client

// The signal to use to interrupt the command
InterruptSignal process.Signal
}

// New returns a new Shell
func New() (*Shell, error) {
// Config contains configuration options for creating a Shell.
type Config struct {
SocketsPath string
}

// New returns a new Shell.
func New(ctx context.Context, cfg Config) (*Shell, error) {
wd, err := os.Getwd()
if err != nil {
return nil, fmt.Errorf("Failed to find current working directory: %w", err)
}

return &Shell{
sh := &Shell{
Logger: StderrLogger,
Env: env.FromSlice(os.Environ()),
Writer: os.Stdout,
wd: wd,
}, nil
}

// Use the Agent API for locking?
if cfg.SocketsPath != "" && experiments.IsEnabled(experiments.AgentAPI) {
ctx, canc := context.WithTimeout(ctx, 10*time.Second)
defer canc()
lc, err := lock.NewClient(ctx, cfg.SocketsPath)
if err != nil {
sh.Logger.Errorf("Couldn't use Agent API for locking, so falling back to using flock-based locks: %v", err)
lc = nil
}
sh.lockClient = lc
}

return sh, nil
}

// WithStdin returns a copy of the Shell with the provided io.Reader set as the
Expand Down Expand Up @@ -181,8 +205,8 @@ func (s *Shell) WaitStatus() (process.WaitStatus, error) {
return s.cmd.proc.WaitStatus(), nil
}

// LockFile is a pid-based lock for cross-process locking
type LockFile interface {
// Unlocker types can unlock a cross-process lock (such as an flock).
type Unlocker interface {
Unlock() error
}

Expand Down Expand Up @@ -222,8 +246,44 @@ func (s *Shell) flock(ctx context.Context, path string, timeout time.Duration) (
return lock, err
}

// agentAPILock contains all the information required to unlock an Agent API
// lock-service lock.
type agentAPILock struct {
client *lock.Client
key, token string
}

func (l *agentAPILock) Unlock() error {
return l.client.Unlock(context.Background(), l.key, l.token)
}

// lockWithAgentAPI acquires a lock in the Agent API lock service.
func (s *Shell) lockWithAgentAPI(ctx context.Context, path string, timeout time.Duration) (*agentAPILock, error) {
absolutePathToLock, err := filepath.Abs(path)
if err != nil {
return nil, fmt.Errorf("Failed to find absolute path to lock \"%s\" (%v)", path, err)
}

ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()

token, err := s.lockClient.Lock(ctx, absolutePathToLock)
if err != nil {
return nil, fmt.Errorf("Failed to acquire lock for %q: %v", path, err)
}

return &agentAPILock{
client: s.lockClient,
key: absolutePathToLock,
token: token,
}, err
}

// Create a cross-process file-based lock based on pid files
func (s *Shell) LockFile(ctx context.Context, path string, timeout time.Duration) (LockFile, error) {
func (s *Shell) LockFile(ctx context.Context, path string, timeout time.Duration) (Unlocker, error) {
if s.lockClient != nil {
return s.lockWithAgentAPI(ctx, path, timeout)
}
return s.flock(ctx, path, timeout)
}

Expand Down
12 changes: 6 additions & 6 deletions bootstrap/shell/shell_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func TestContextCancelTerminates(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

sh, err := shell.New()
sh, err := shell.New(ctx, shell.Config{})
if err != nil {
t.Fatalf("shell.New() error = %v", err)
}
Expand Down Expand Up @@ -165,7 +165,7 @@ func TestInterrupt(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

sh, err := shell.New()
sh, err := shell.New(ctx, shell.Config{})
if err != nil {
t.Fatalf("shell.New() error = %v", err)
}
Expand All @@ -190,7 +190,7 @@ func TestInterrupt(t *testing.T) {
}

func TestDefaultWorkingDirFromSystem(t *testing.T) {
sh, err := shell.New()
sh, err := shell.New(context.Background(), shell.Config{})
if err != nil {
t.Fatalf("shell.New() error = %v", err)
}
Expand Down Expand Up @@ -231,7 +231,7 @@ func TestWorkingDir(t *testing.T) {
t.Fatalf("os.Getwd() error = %v", err)
}

sh, err := shell.New()
sh, err := shell.New(context.Background(), shell.Config{})
if err != nil {
t.Fatalf("shell.New() error = %v", err)
}
Expand Down Expand Up @@ -353,7 +353,7 @@ func TestAcquiringLockHelperProcess(t *testing.T) {
}

func newShellForTest(t *testing.T) *shell.Shell {
sh, err := shell.New()
sh, err := shell.New(context.Background(), shell.Config{})
if err != nil {
t.Fatalf("shell.New() error = %v", err)
}
Expand All @@ -362,7 +362,7 @@ func newShellForTest(t *testing.T) *shell.Shell {
}

func TestRunWithoutPrompt(t *testing.T) {
sh, err := shell.New()
sh, err := shell.New(context.Background(), shell.Config{})
if err != nil {
t.Fatalf("shell.New() error = %v", err)
}
Expand Down
3 changes: 2 additions & 1 deletion bootstrap/shell/test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package shell

import (
"context"
"io"
"os"
"runtime"
Expand All @@ -11,7 +12,7 @@ import (

// NewTestShell creates a minimal shell suitable for tests.
func NewTestShell(t *testing.T) *Shell {
sh, err := New()
sh, err := New(context.Background(), Config{})
if err != nil {
t.Fatalf("shell.New() error = %v", err)
}
Expand Down
2 changes: 1 addition & 1 deletion bootstrap/ssh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func init() {
func TestFindingSSHTools(t *testing.T) {
t.Parallel()

sh, err := shell.New()
sh, err := shell.New(context.Background(), shell.Config{})
if err != nil {
t.Fatalf("shell.New() error = %v", err)
}
Expand Down
16 changes: 8 additions & 8 deletions clicommand/agent_start.go
Original file line number Diff line number Diff line change
Expand Up @@ -963,10 +963,10 @@ var AgentStartCommand = cli.Command{
pool := agent.NewAgentPool(workers)

// Agent-wide shutdown hook. Once per agent, for all workers on the agent.
defer agentShutdownHook(l, cfg)
defer agentShutdownHook(ctx, l, cfg)

// Once the shutdown hook has been setup, trigger the startup hook.
if err := agentStartupHook(l, cfg); err != nil {
if err := agentStartupHook(ctx, l, cfg); err != nil {
l.Fatal("%s", err)
}

Expand Down Expand Up @@ -1052,17 +1052,17 @@ func handlePoolSignals(ctx context.Context, l logger.Logger, pool *agent.AgentPo
return signals
}

func agentStartupHook(log logger.Logger, cfg AgentStartConfig) error {
return agentLifecycleHook("agent-startup", log, cfg)
func agentStartupHook(ctx context.Context, log logger.Logger, cfg AgentStartConfig) error {
return agentLifecycleHook(ctx, "agent-startup", log, cfg)
}
func agentShutdownHook(log logger.Logger, cfg AgentStartConfig) {
_ = agentLifecycleHook("agent-shutdown", log, cfg)
func agentShutdownHook(ctx context.Context, log logger.Logger, cfg AgentStartConfig) {
_ = agentLifecycleHook(ctx, "agent-shutdown", log, cfg)
}

// agentLifecycleHook looks for a hook script in the hooks path
// and executes it if found. Output (stdout + stderr) is streamed into the main
// agent logger. Exit status failure is logged and returned for the caller to handle
func agentLifecycleHook(hookName string, log logger.Logger, cfg AgentStartConfig) error {
func agentLifecycleHook(ctx context.Context, hookName string, log logger.Logger, cfg AgentStartConfig) error {
// search for hook (including .bat & .ps1 files on Windows)
p, err := hook.Find(cfg.HooksPath, hookName)
if err != nil {
Expand All @@ -1072,7 +1072,7 @@ func agentLifecycleHook(hookName string, log logger.Logger, cfg AgentStartConfig
}
return nil
}
sh, err := shell.New()
sh, err := shell.New(ctx, shell.Config{})
if err != nil {
log.Error("creating shell for %q hook: %v", hookName, err)
return err
Expand Down
13 changes: 7 additions & 6 deletions clicommand/agent_start_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package clicommand

import (
"context"
"os"
"path/filepath"
"runtime"
Expand Down Expand Up @@ -50,7 +51,7 @@ func TestAgentStartupHook(t *testing.T) {
defer closer()
filepath := writeAgentHook(t, hooksPath, "agent-startup")
log := logger.NewBuffer()
err := agentStartupHook(log, cfg(hooksPath))
err := agentStartupHook(context.Background(), log, cfg(hooksPath))

if assert.NoError(t, err, log.Messages) {
assert.Equal(t, []string{
Expand All @@ -64,14 +65,14 @@ func TestAgentStartupHook(t *testing.T) {
defer closer()

log := logger.NewBuffer()
err := agentStartupHook(log, cfg(hooksPath))
err := agentStartupHook(context.Background(), log, cfg(hooksPath))
if assert.NoError(t, err, log.Messages) {
assert.Equal(t, []string{}, log.Messages)
}
})
t.Run("with bad hooks path", func(t *testing.T) {
log := logger.NewBuffer()
err := agentStartupHook(log, cfg("zxczxczxc"))
err := agentStartupHook(context.Background(), log, cfg("zxczxczxc"))

if assert.NoError(t, err, log.Messages) {
assert.Equal(t, []string{}, log.Messages)
Expand All @@ -95,7 +96,7 @@ func TestAgentShutdownHook(t *testing.T) {
defer closer()
filepath := writeAgentHook(t, hooksPath, "agent-shutdown")
log := logger.NewBuffer()
agentShutdownHook(log, cfg(hooksPath))
agentShutdownHook(context.Background(), log, cfg(hooksPath))

assert.Equal(t, []string{
"[info] " + prompt + " " + filepath, // prompt
Expand All @@ -107,12 +108,12 @@ func TestAgentShutdownHook(t *testing.T) {
defer closer()

log := logger.NewBuffer()
agentShutdownHook(log, cfg(hooksPath))
agentShutdownHook(context.Background(), log, cfg(hooksPath))
assert.Equal(t, []string{}, log.Messages)
})
t.Run("with bad hooks path", func(t *testing.T) {
log := logger.NewBuffer()
agentShutdownHook(log, cfg("zxczxczxc"))
agentShutdownHook(context.Background(), log, cfg("zxczxczxc"))
assert.Equal(t, []string{}, log.Messages)
})
}

0 comments on commit 97a4fdf

Please sign in to comment.