diff --git a/agent/job_runner.go b/agent/job_runner.go index c99aa02e15..06bbd52e78 100644 --- a/agent/job_runner.go +++ b/agent/job_runner.go @@ -732,7 +732,7 @@ func (w LogWriter) Write(bytes []byte) (int, error) { func (r *JobRunner) executePreBootstrapHook(ctx context.Context, hook string) (bool, error) { r.logger.Info("Running pre-bootstrap hook %q", hook) - sh, err := shell.New() + sh, err := shell.New(ctx, shell.Config{}) if err != nil { return false, err } diff --git a/bootstrap/bootstrap.go b/bootstrap/bootstrap.go index 88932131e2..ddc9867988 100644 --- a/bootstrap/bootstrap.go +++ b/bootstrap/bootstrap.go @@ -70,16 +70,20 @@ func New(conf Config) *Bootstrap { func (b *Bootstrap) Run(ctx context.Context) (exitCode int) { // Check if not nil to allow for tests to overwrite shell if b.shell == nil { - var err error - b.shell, err = shell.New() + // The Agent API socket could be important for file locking. + cfg := shell.Config{ + SocketsPath: b.Config.SocketsPath, + } + sh, err := shell.New(ctx, cfg) if err != nil { fmt.Printf("Error creating shell: %v", err) return 1 } - b.shell.PTY = b.Config.RunInPty - b.shell.Debug = b.Config.Debug - b.shell.InterruptSignal = b.Config.CancelSignal + sh.PTY = b.Config.RunInPty + sh.Debug = b.Config.Debug + sh.InterruptSignal = b.Config.CancelSignal + b.shell = sh } if experiments.IsEnabled(experiments.KubernetesExec) { kubernetesClient := &kubernetes.Client{} diff --git a/bootstrap/bootstrap_test.go b/bootstrap/bootstrap_test.go index 1dce8bdf9a..71234cc422 100644 --- a/bootstrap/bootstrap_test.go +++ b/bootstrap/bootstrap_test.go @@ -86,7 +86,7 @@ func TestStartTracing_NoTracingBackend(t *testing.T) { b := New(Config{}) oriCtx := context.Background() - b.shell, err = shell.New() + b.shell, err = shell.New(oriCtx, shell.Config{}) assert.NoError(t, err) span, _, stopper := b.startTracing(oriCtx) @@ -107,7 +107,7 @@ func TestStartTracing_Datadog(t *testing.T) { b := New(cfg) oriCtx := context.Background() - b.shell, err = shell.New() + b.shell, err = shell.New(oriCtx, shell.Config{}) assert.NoError(t, err) span, ctx, stopper := b.startTracing(oriCtx) diff --git a/bootstrap/shell/shell.go b/bootstrap/shell/shell.go index 6d9eb1e07d..c96227c695 100644 --- a/bootstrap/shell/shell.go +++ b/bootstrap/shell/shell.go @@ -23,7 +23,9 @@ import ( "github.com/opentracing/opentracing-go" "github.com/buildkite/agent/v3/env" + "github.com/buildkite/agent/v3/experiments" "github.com/buildkite/agent/v3/internal/shellscript" + "github.com/buildkite/agent/v3/lock" "github.com/buildkite/agent/v3/logger" "github.com/buildkite/agent/v3/process" "github.com/buildkite/agent/v3/tracetools" @@ -67,23 +69,45 @@ type Shell struct { cmd *command cmdLock sync.Mutex + // Lock service client, if available + lockClient *lock.Client + // The signal to use to interrupt the command InterruptSignal process.Signal } -// New returns a new Shell -func New() (*Shell, error) { +// Config contains configuration options for creating a Shell. +type Config struct { + SocketsPath string +} + +// New returns a new Shell. +func New(ctx context.Context, cfg Config) (*Shell, error) { wd, err := os.Getwd() if err != nil { return nil, fmt.Errorf("Failed to find current working directory: %w", err) } - return &Shell{ + sh := &Shell{ Logger: StderrLogger, Env: env.FromSlice(os.Environ()), Writer: os.Stdout, wd: wd, - }, nil + } + + // Use the Agent API for locking? + if cfg.SocketsPath != "" && experiments.IsEnabled(experiments.AgentAPI) { + ctx, canc := context.WithTimeout(ctx, 10*time.Second) + defer canc() + lc, err := lock.NewClient(ctx, cfg.SocketsPath) + if err != nil { + sh.Logger.Errorf("Couldn't use Agent API for locking, so falling back to using flock-based locks: %v", err) + lc = nil + } + sh.lockClient = lc + } + + return sh, nil } // WithStdin returns a copy of the Shell with the provided io.Reader set as the @@ -181,8 +205,8 @@ func (s *Shell) WaitStatus() (process.WaitStatus, error) { return s.cmd.proc.WaitStatus(), nil } -// LockFile is a pid-based lock for cross-process locking -type LockFile interface { +// Unlocker types can unlock a cross-process lock (such as an flock). +type Unlocker interface { Unlock() error } @@ -222,8 +246,44 @@ func (s *Shell) flock(ctx context.Context, path string, timeout time.Duration) ( return lock, err } +// agentAPILock contains all the information required to unlock an Agent API +// lock-service lock. +type agentAPILock struct { + client *lock.Client + key, token string +} + +func (l *agentAPILock) Unlock() error { + return l.client.Unlock(context.Background(), l.key, l.token) +} + +// lockWithAgentAPI acquires a lock in the Agent API lock service. +func (s *Shell) lockWithAgentAPI(ctx context.Context, path string, timeout time.Duration) (*agentAPILock, error) { + absolutePathToLock, err := filepath.Abs(path) + if err != nil { + return nil, fmt.Errorf("Failed to find absolute path to lock \"%s\" (%v)", path, err) + } + + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + token, err := s.lockClient.Lock(ctx, absolutePathToLock) + if err != nil { + return nil, fmt.Errorf("Failed to acquire lock for %q: %v", path, err) + } + + return &agentAPILock{ + client: s.lockClient, + key: absolutePathToLock, + token: token, + }, err +} + // Create a cross-process file-based lock based on pid files -func (s *Shell) LockFile(ctx context.Context, path string, timeout time.Duration) (LockFile, error) { +func (s *Shell) LockFile(ctx context.Context, path string, timeout time.Duration) (Unlocker, error) { + if s.lockClient != nil { + return s.lockWithAgentAPI(ctx, path, timeout) + } return s.flock(ctx, path, timeout) } diff --git a/bootstrap/shell/shell_test.go b/bootstrap/shell/shell_test.go index a5e35a3a60..7bffe40916 100644 --- a/bootstrap/shell/shell_test.go +++ b/bootstrap/shell/shell_test.go @@ -131,7 +131,7 @@ func TestContextCancelTerminates(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - sh, err := shell.New() + sh, err := shell.New(ctx, shell.Config{}) if err != nil { t.Fatalf("shell.New() error = %v", err) } @@ -165,7 +165,7 @@ func TestInterrupt(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - sh, err := shell.New() + sh, err := shell.New(ctx, shell.Config{}) if err != nil { t.Fatalf("shell.New() error = %v", err) } @@ -190,7 +190,7 @@ func TestInterrupt(t *testing.T) { } func TestDefaultWorkingDirFromSystem(t *testing.T) { - sh, err := shell.New() + sh, err := shell.New(context.Background(), shell.Config{}) if err != nil { t.Fatalf("shell.New() error = %v", err) } @@ -231,7 +231,7 @@ func TestWorkingDir(t *testing.T) { t.Fatalf("os.Getwd() error = %v", err) } - sh, err := shell.New() + sh, err := shell.New(context.Background(), shell.Config{}) if err != nil { t.Fatalf("shell.New() error = %v", err) } @@ -353,7 +353,7 @@ func TestAcquiringLockHelperProcess(t *testing.T) { } func newShellForTest(t *testing.T) *shell.Shell { - sh, err := shell.New() + sh, err := shell.New(context.Background(), shell.Config{}) if err != nil { t.Fatalf("shell.New() error = %v", err) } @@ -362,7 +362,7 @@ func newShellForTest(t *testing.T) *shell.Shell { } func TestRunWithoutPrompt(t *testing.T) { - sh, err := shell.New() + sh, err := shell.New(context.Background(), shell.Config{}) if err != nil { t.Fatalf("shell.New() error = %v", err) } diff --git a/bootstrap/shell/test.go b/bootstrap/shell/test.go index a4c62a7ccd..9b60817640 100644 --- a/bootstrap/shell/test.go +++ b/bootstrap/shell/test.go @@ -1,6 +1,7 @@ package shell import ( + "context" "io" "os" "runtime" @@ -11,7 +12,7 @@ import ( // NewTestShell creates a minimal shell suitable for tests. func NewTestShell(t *testing.T) *Shell { - sh, err := New() + sh, err := New(context.Background(), Config{}) if err != nil { t.Fatalf("shell.New() error = %v", err) } diff --git a/bootstrap/ssh_test.go b/bootstrap/ssh_test.go index 705db89b83..60cb2c3501 100644 --- a/bootstrap/ssh_test.go +++ b/bootstrap/ssh_test.go @@ -18,7 +18,7 @@ func init() { func TestFindingSSHTools(t *testing.T) { t.Parallel() - sh, err := shell.New() + sh, err := shell.New(context.Background(), shell.Config{}) if err != nil { t.Fatalf("shell.New() error = %v", err) } diff --git a/clicommand/agent_start.go b/clicommand/agent_start.go index d415db2ff1..a4bd751900 100644 --- a/clicommand/agent_start.go +++ b/clicommand/agent_start.go @@ -963,10 +963,10 @@ var AgentStartCommand = cli.Command{ pool := agent.NewAgentPool(workers) // Agent-wide shutdown hook. Once per agent, for all workers on the agent. - defer agentShutdownHook(l, cfg) + defer agentShutdownHook(ctx, l, cfg) // Once the shutdown hook has been setup, trigger the startup hook. - if err := agentStartupHook(l, cfg); err != nil { + if err := agentStartupHook(ctx, l, cfg); err != nil { l.Fatal("%s", err) } @@ -1052,17 +1052,17 @@ func handlePoolSignals(ctx context.Context, l logger.Logger, pool *agent.AgentPo return signals } -func agentStartupHook(log logger.Logger, cfg AgentStartConfig) error { - return agentLifecycleHook("agent-startup", log, cfg) +func agentStartupHook(ctx context.Context, log logger.Logger, cfg AgentStartConfig) error { + return agentLifecycleHook(ctx, "agent-startup", log, cfg) } -func agentShutdownHook(log logger.Logger, cfg AgentStartConfig) { - _ = agentLifecycleHook("agent-shutdown", log, cfg) +func agentShutdownHook(ctx context.Context, log logger.Logger, cfg AgentStartConfig) { + _ = agentLifecycleHook(ctx, "agent-shutdown", log, cfg) } // agentLifecycleHook looks for a hook script in the hooks path // and executes it if found. Output (stdout + stderr) is streamed into the main // agent logger. Exit status failure is logged and returned for the caller to handle -func agentLifecycleHook(hookName string, log logger.Logger, cfg AgentStartConfig) error { +func agentLifecycleHook(ctx context.Context, hookName string, log logger.Logger, cfg AgentStartConfig) error { // search for hook (including .bat & .ps1 files on Windows) p, err := hook.Find(cfg.HooksPath, hookName) if err != nil { @@ -1072,7 +1072,7 @@ func agentLifecycleHook(hookName string, log logger.Logger, cfg AgentStartConfig } return nil } - sh, err := shell.New() + sh, err := shell.New(ctx, shell.Config{}) if err != nil { log.Error("creating shell for %q hook: %v", hookName, err) return err diff --git a/clicommand/agent_start_test.go b/clicommand/agent_start_test.go index 1f105bc3c2..aa01841511 100644 --- a/clicommand/agent_start_test.go +++ b/clicommand/agent_start_test.go @@ -1,6 +1,7 @@ package clicommand import ( + "context" "os" "path/filepath" "runtime" @@ -50,7 +51,7 @@ func TestAgentStartupHook(t *testing.T) { defer closer() filepath := writeAgentHook(t, hooksPath, "agent-startup") log := logger.NewBuffer() - err := agentStartupHook(log, cfg(hooksPath)) + err := agentStartupHook(context.Background(), log, cfg(hooksPath)) if assert.NoError(t, err, log.Messages) { assert.Equal(t, []string{ @@ -64,14 +65,14 @@ func TestAgentStartupHook(t *testing.T) { defer closer() log := logger.NewBuffer() - err := agentStartupHook(log, cfg(hooksPath)) + err := agentStartupHook(context.Background(), log, cfg(hooksPath)) if assert.NoError(t, err, log.Messages) { assert.Equal(t, []string{}, log.Messages) } }) t.Run("with bad hooks path", func(t *testing.T) { log := logger.NewBuffer() - err := agentStartupHook(log, cfg("zxczxczxc")) + err := agentStartupHook(context.Background(), log, cfg("zxczxczxc")) if assert.NoError(t, err, log.Messages) { assert.Equal(t, []string{}, log.Messages) @@ -95,7 +96,7 @@ func TestAgentShutdownHook(t *testing.T) { defer closer() filepath := writeAgentHook(t, hooksPath, "agent-shutdown") log := logger.NewBuffer() - agentShutdownHook(log, cfg(hooksPath)) + agentShutdownHook(context.Background(), log, cfg(hooksPath)) assert.Equal(t, []string{ "[info] " + prompt + " " + filepath, // prompt @@ -107,12 +108,12 @@ func TestAgentShutdownHook(t *testing.T) { defer closer() log := logger.NewBuffer() - agentShutdownHook(log, cfg(hooksPath)) + agentShutdownHook(context.Background(), log, cfg(hooksPath)) assert.Equal(t, []string{}, log.Messages) }) t.Run("with bad hooks path", func(t *testing.T) { log := logger.NewBuffer() - agentShutdownHook(log, cfg("zxczxczxc")) + agentShutdownHook(context.Background(), log, cfg("zxczxczxc")) assert.Equal(t, []string{}, log.Messages) }) }