diff --git a/internal/pkg/cli/interfaces.go b/internal/pkg/cli/interfaces.go index 2bf20e791a0..1bada558c00 100644 --- a/internal/pkg/cli/interfaces.go +++ b/internal/pkg/cli/interfaces.go @@ -706,12 +706,13 @@ type templateDiffer interface { type dockerEngineRunner interface { CheckDockerEngineRunning() error Run(context.Context, *dockerengine.RunOptions) error - DoesContainerExist(context.Context, string) (bool, error) IsContainerRunning(context.Context, string) (bool, error) Stop(context.Context, string) error - Rm(string) error Build(context.Context, *dockerengine.BuildArguments, io.Writer) error Exec(ctx context.Context, container string, out io.Writer, cmd string, args ...string) error + ContainerExitCode(ctx context.Context, containerName string) (int, error) + IsContainerHealthy(ctx context.Context, containerName string) (bool, error) + Rm(context.Context, string) error } type workloadStackGenerator interface { diff --git a/internal/pkg/cli/mocks/mock_interfaces.go b/internal/pkg/cli/mocks/mock_interfaces.go index f859295f531..decab61a443 100644 --- a/internal/pkg/cli/mocks/mock_interfaces.go +++ b/internal/pkg/cli/mocks/mock_interfaces.go @@ -7693,19 +7693,19 @@ func (mr *MockdockerEngineRunnerMockRecorder) CheckDockerEngineRunning() *gomock return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckDockerEngineRunning", reflect.TypeOf((*MockdockerEngineRunner)(nil).CheckDockerEngineRunning)) } -// DoesContainerExist mocks base method. -func (m *MockdockerEngineRunner) DoesContainerExist(arg0 context.Context, arg1 string) (bool, error) { +// ContainerExitCode mocks base method. +func (m *MockdockerEngineRunner) ContainerExitCode(ctx context.Context, containerName string) (int, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DoesContainerExist", arg0, arg1) - ret0, _ := ret[0].(bool) + ret := m.ctrl.Call(m, "ContainerExitCode", ctx, containerName) + ret0, _ := ret[0].(int) ret1, _ := ret[1].(error) return ret0, ret1 } -// DoesContainerExist indicates an expected call of DoesContainerExist. -func (mr *MockdockerEngineRunnerMockRecorder) DoesContainerExist(arg0, arg1 interface{}) *gomock.Call { +// ContainerExitCode indicates an expected call of ContainerExitCode. +func (mr *MockdockerEngineRunnerMockRecorder) ContainerExitCode(ctx, containerName interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DoesContainerExist", reflect.TypeOf((*MockdockerEngineRunner)(nil).DoesContainerExist), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ContainerExitCode", reflect.TypeOf((*MockdockerEngineRunner)(nil).ContainerExitCode), ctx, containerName) } // Exec mocks base method. @@ -7727,6 +7727,21 @@ func (mr *MockdockerEngineRunnerMockRecorder) Exec(ctx, container, out, cmd inte return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockdockerEngineRunner)(nil).Exec), varargs...) } +// IsContainerHealthy mocks base method. +func (m *MockdockerEngineRunner) IsContainerHealthy(ctx context.Context, containerName string) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsContainerHealthy", ctx, containerName) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// IsContainerHealthy indicates an expected call of IsContainerHealthy. 
+func (mr *MockdockerEngineRunnerMockRecorder) IsContainerHealthy(ctx, containerName interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsContainerHealthy", reflect.TypeOf((*MockdockerEngineRunner)(nil).IsContainerHealthy), ctx, containerName) +} + // IsContainerRunning mocks base method. func (m *MockdockerEngineRunner) IsContainerRunning(arg0 context.Context, arg1 string) (bool, error) { m.ctrl.T.Helper() @@ -7743,17 +7758,17 @@ func (mr *MockdockerEngineRunnerMockRecorder) IsContainerRunning(arg0, arg1 inte } // Rm mocks base method. -func (m *MockdockerEngineRunner) Rm(arg0 string) error { +func (m *MockdockerEngineRunner) Rm(arg0 context.Context, arg1 string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Rm", arg0) + ret := m.ctrl.Call(m, "Rm", arg0, arg1) ret0, _ := ret[0].(error) return ret0 } // Rm indicates an expected call of Rm. -func (mr *MockdockerEngineRunnerMockRecorder) Rm(arg0 interface{}) *gomock.Call { +func (mr *MockdockerEngineRunnerMockRecorder) Rm(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rm", reflect.TypeOf((*MockdockerEngineRunner)(nil).Rm), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rm", reflect.TypeOf((*MockdockerEngineRunner)(nil).Rm), arg0, arg1) } // Run mocks base method. diff --git a/internal/pkg/cli/run_local.go b/internal/pkg/cli/run_local.go index 3fd82a913e2..7dedcc01863 100644 --- a/internal/pkg/cli/run_local.go +++ b/internal/pkg/cli/run_local.go @@ -577,8 +577,6 @@ func (o *runLocalOpts) prepareTask(ctx context.Context) (orchestrator.Task, erro task.Containers[name] = ctr } - // TODO (Adi): Use this dependency order in orchestrator to start and stop containers. - // replace container dependencies with the local dependencies from manifest. containerDeps := manifest.ContainerDependencies(mft.Manifest()) for name, dep := range containerDeps { ctr, ok := task.Containers[name] diff --git a/internal/pkg/docker/dockerengine/dockerengine.go b/internal/pkg/docker/dockerengine/dockerengine.go index 2771bc12494..5f216a584b6 100644 --- a/internal/pkg/docker/dockerengine/dockerengine.go +++ b/internal/pkg/docker/dockerengine/dockerengine.go @@ -9,6 +9,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "os" @@ -263,7 +264,6 @@ func (c DockerCmdClient) Push(ctx context.Context, uri string, w io.Writer, tags func (in *RunOptions) generateRunArguments() []string { args := []string{"run"} - args = append(args, "--rm") if in.ContainerName != "" { args = append(args, "--name", in.ContainerName) @@ -303,6 +303,9 @@ func (in *RunOptions) generateRunArguments() []string { // Run runs a Docker container with the specified options.
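+// If the container exits on its own, Run returns an *ErrContainerExited wrapping the container's exit code, +// so callers can tell a plain exit apart from other failures. A caller-side sketch (client here is assumed +// to be a DockerCmdClient): +// +// var exited *ErrContainerExited +// if err := client.Run(ctx, opts); err != nil { +// if errors.As(err, &exited) { +// fmt.Printf("container exited with code %d\n", exited.ExitCode()) +// } +// }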
func (c DockerCmdClient) Run(ctx context.Context, options *RunOptions) error { + type exitCodeError interface { + ExitCode() int + } // set default options if options.LogOptions.Color == nil { options.LogOptions.Color = color.New() @@ -343,6 +346,13 @@ func (c DockerCmdClient) Run(ctx context.Context, options *RunOptions) error { exec.Stdout(stdout), exec.Stderr(stderr), exec.NewProcessGroup()); err != nil { + var ec exitCodeError + if errors.As(err, &ec) { + return &ErrContainerExited{ + name: options.ContainerName, + exitcode: ec.ExitCode(), + } + } return fmt.Errorf("running container: %w", err) } return nil @@ -351,15 +361,6 @@ func (c DockerCmdClient) Run(ctx context.Context, options *RunOptions) error { return g.Wait() } -// DoesContainerExist checks if a specific Docker container exists. -func (c DockerCmdClient) DoesContainerExist(ctx context.Context, name string) (bool, error) { - output, err := c.containerID(ctx, name) - if err != nil { - return false, err - } - return output != "", nil -} - // IsContainerRunning checks if a specific Docker container is running. func (c DockerCmdClient) IsContainerRunning(ctx context.Context, name string) (bool, error) { state, err := c.containerState(ctx, name) @@ -375,14 +376,14 @@ func (c DockerCmdClient) IsContainerRunning(ctx context.Context, name string) (b return false, nil } -// IsContainerCompleteOrSuccess returns true if a docker container exits with an exitcode. -func (c DockerCmdClient) IsContainerCompleteOrSuccess(ctx context.Context, containerName string) (int, error) { - state, err := c.containerState(ctx, containerName) +// ContainerExitCode returns the exit code of a container. +func (c DockerCmdClient) ContainerExitCode(ctx context.Context, name string) (int, error) { + state, err := c.containerState(ctx, name) if err != nil { return 0, err } if state.Status == containerStatusRunning { - return -1, nil + return 0, &ErrContainerNotExited{name: name} } return state.ExitCode, nil } @@ -452,18 +453,6 @@ func (d *DockerCmdClient) containerID(ctx context.Context, containerName string) return strings.TrimSpace(buf.String()), nil } -// ErrContainerExited represents an error when a Docker container has exited. -// It includes the container name and exit code in the error message. -type ErrContainerExited struct { - name string - exitcode int -} - -// ErrContainerExited represents docker container exited with an exitcode. -func (e *ErrContainerExited) Error() string { - return fmt.Sprintf("container %q exited with code %d", e.name, e.exitcode) -} - // Stop calls `docker stop` to stop a running container. func (c DockerCmdClient) Stop(ctx context.Context, containerID string) error { buf := &bytes.Buffer{} @@ -474,9 +463,9 @@ func (c DockerCmdClient) Stop(ctx context.Context, containerID string) error { } // Rm calls `docker rm` to remove a stopped container. 
-func (c DockerCmdClient) Rm(containerID string) error { +func (c DockerCmdClient) Rm(ctx context.Context, containerID string) error { buf := &bytes.Buffer{} - if err := c.runner.Run("docker", []string{"rm", containerID}, exec.Stdout(buf), exec.Stderr(buf)); err != nil { + if err := c.runner.RunWithContext(ctx, "docker", []string{"rm", containerID}, exec.Stdout(buf), exec.Stderr(buf)); err != nil { return fmt.Errorf("%s: %w", strings.TrimSpace(buf.String()), err) } return nil @@ -600,11 +589,3 @@ func userHomeDirectory() string { return home } - -type errEmptyImageTags struct { - uri string -} - -func (e *errEmptyImageTags) Error() string { - return fmt.Sprintf("tags to reference an image should not be empty for building and pushing into the ECR repository %s", e.uri) -} diff --git a/internal/pkg/docker/dockerengine/dockerengine_test.go b/internal/pkg/docker/dockerengine/dockerengine_test.go index d8cd130f254..a2c6d5609de 100644 --- a/internal/pkg/docker/dockerengine/dockerengine_test.go +++ b/internal/pkg/docker/dockerengine/dockerengine_test.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "io" + "os" osexec "os/exec" "path/filepath" "strings" @@ -680,13 +681,35 @@ func TestDockerCommand_Run(t *testing.T) { setupMocks: func(controller *gomock.Controller) { mockCmd = NewMockCmd(controller) mockCmd.EXPECT().RunWithContext(gomock.Any(), "docker", []string{"run", - "--rm", "--name", mockPauseContainer, "mockImageUri", "sleep", "infinity"}, gomock.Any(), gomock.Any(), gomock.Any()).Return(mockError) }, wantedError: fmt.Errorf("running container: %w", mockError), }, + + "should return error when container exits": { + containerName: mockPauseContainer, + command: mockCommand, + uri: mockImageURI, + setupMocks: func(controller *gomock.Controller) { + mockCmd = NewMockCmd(controller) + + mockCmd.EXPECT().RunWithContext(gomock.Any(), "docker", []string{"run", + "--name", mockPauseContainer, + "mockImageUri", + "sleep", "infinity"}, gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, name string, args []string, opts ...exec.CmdOption) error { + // Simulate a zero exit code.
+ return &osexec.ExitError{ProcessState: &os.ProcessState{}} + }) + }, + wantedError: &ErrContainerExited{ + name: mockPauseContainer, + exitcode: 0, + }, + }, + "success with run options for pause container": { containerName: mockPauseContainer, ports: mockContainerPorts, @@ -695,7 +718,6 @@ func TestDockerCommand_Run(t *testing.T) { setupMocks: func(controller *gomock.Controller) { mockCmd = NewMockCmd(controller) mockCmd.EXPECT().RunWithContext(gomock.Any(), "docker", gomock.InAnyOrder([]string{"run", - "--rm", "--name", mockPauseContainer, "--publish", "8080:8080", "--publish", "8081:8081", @@ -712,7 +734,6 @@ func TestDockerCommand_Run(t *testing.T) { setupMocks: func(controller *gomock.Controller) { mockCmd = NewMockCmd(controller) mockCmd.EXPECT().RunWithContext(gomock.Any(), "docker", gomock.InAnyOrder([]string{"run", - "--rm", "--name", mockContainerName, "--network", "container:pauseContainer", "--env", "DB_PASSWORD=mysecretPassword", @@ -733,7 +754,6 @@ func TestDockerCommand_Run(t *testing.T) { setupMocks: func(controller *gomock.Controller) { mockCmd = NewMockCmd(controller) mockCmd.EXPECT().RunWithContext(gomock.Any(), "docker", gomock.InAnyOrder([]string{"run", - "--rm", "--name", mockContainerName, "--network", "container:pauseContainer", "--env", "DB_PASSWORD=mysecretPassword", @@ -1032,7 +1052,7 @@ func TestDockerCommand_IsContainerHealthy(t *testing.T) { } } -func TestDockerCommand_IsContainerCompleteOrSuccess(t *testing.T) { +func TestDockerCommand_ContainerExitCode(t *testing.T) { tests := map[string]struct { mockContainerName string mockHealthStatus string @@ -1107,7 +1127,7 @@ func TestDockerCommand_IsContainerCompleteOrSuccess(t *testing.T) { }, wantErr: fmt.Errorf("run docker ps: some error"), }, - "return negative exitcode if container is running": { + "return err if container is running": { mockContainerName: "mockContainer", mockHealthStatus: "unhealthy", setupMocks: func(controller *gomock.Controller) *MockCmd { @@ -1134,7 +1154,7 @@ func TestDockerCommand_IsContainerCompleteOrSuccess(t *testing.T) { }) return mockCmd }, - wantExitCode: -1, + wantErr: fmt.Errorf(`container "mockContainer" has not exited`), }, } @@ -1147,7 +1167,7 @@ func TestDockerCommand_IsContainerCompleteOrSuccess(t *testing.T) { runner: tc.setupMocks(ctrl), } - expectedCode, err := s.IsContainerCompleteOrSuccess(context.Background(), tc.mockContainerName) + expectedCode, err := s.ContainerExitCode(context.Background(), tc.mockContainerName) require.Equal(t, tc.wantExitCode, expectedCode) if tc.wantErr != nil { require.EqualError(t, err, tc.wantErr.Error()) diff --git a/internal/pkg/docker/dockerengine/dockerenginetest/dockerenginetest.go b/internal/pkg/docker/dockerengine/dockerenginetest/dockerenginetest.go index a4ec3398a46..ac3416bd571 100644 --- a/internal/pkg/docker/dockerengine/dockerenginetest/dockerenginetest.go +++ b/internal/pkg/docker/dockerengine/dockerenginetest/dockerenginetest.go @@ -13,11 +13,13 @@ import ( // Double is a test double for dockerengine.DockerCmdClient type Double struct { StopFn func(context.Context, string) error - DoesContainerExistFn func(context.Context, string) (bool, error) IsContainerRunningFn func(context.Context, string) (bool, error) RunFn func(context.Context, *dockerengine.RunOptions) error BuildFn func(context.Context, *dockerengine.BuildArguments, io.Writer) error ExecFn func(context.Context, string, io.Writer, string, ...string) error + IsContainerHealthyFn func(ctx context.Context, containerName string) (bool, error) + ContainerExitCodeFn 
func(ctx context.Context, containerName string) (int, error) + RmFn func(context.Context, string) error } // Stop calls the stubbed function. @@ -28,14 +30,6 @@ func (d *Double) Stop(ctx context.Context, name string) error { return d.StopFn(ctx, name) } -// DoesContainerExist calls the stubbed function. -func (d *Double) DoesContainerExist(ctx context.Context, name string) (bool, error) { - if d.IsContainerRunningFn == nil { - return false, nil - } - return d.DoesContainerExistFn(ctx, name) -} - // IsContainerRunning calls the stubbed function. func (d *Double) IsContainerRunning(ctx context.Context, name string) (bool, error) { if d.IsContainerRunningFn == nil { @@ -67,3 +61,27 @@ func (d *Double) Exec(ctx context.Context, container string, out io.Writer, cmd } return d.ExecFn(ctx, container, out, cmd, args...) } + +// Rm calls the stubbed function. +func (d *Double) Rm(ctx context.Context, name string) error { + if d.RmFn == nil { + return nil + } + return d.RmFn(ctx, name) +} + +// ContainerExitCode implements orchestrator.DockerEngine. +func (d *Double) ContainerExitCode(ctx context.Context, containerName string) (int, error) { + if d.ContainerExitCodeFn == nil { + return 0, nil + } + return d.ContainerExitCodeFn(ctx, containerName) +} + +// IsContainerHealthy implements orchestrator.DockerEngine. +func (d *Double) IsContainerHealthy(ctx context.Context, containerName string) (bool, error) { + if d.IsContainerHealthyFn == nil { + return false, nil + } + return d.IsContainerHealthyFn(ctx, containerName) +} diff --git a/internal/pkg/docker/dockerengine/errors.go b/internal/pkg/docker/dockerengine/errors.go index 9d0f469dcee..d43fbb512ed 100644 --- a/internal/pkg/docker/dockerengine/errors.go +++ b/internal/pkg/docker/dockerengine/errors.go @@ -19,3 +19,38 @@ type ErrDockerDaemonNotResponsive struct { func (e ErrDockerDaemonNotResponsive) Error() string { return fmt.Sprintf("docker daemon is not responsive: %s", e.msg) } + +type errEmptyImageTags struct { + uri string +} + +func (e *errEmptyImageTags) Error() string { + return fmt.Sprintf("tags to reference an image should not be empty for building and pushing into the ECR repository %s", e.uri) +} + +// ErrContainerNotExited represents an error when a Docker container has not exited. +type ErrContainerNotExited struct { + name string +} + +// Error returns the error message. +func (e *ErrContainerNotExited) Error() string { + return fmt.Sprintf("container %q has not exited", e.name) +} + +// ErrContainerExited represents an error when a Docker container has exited. +// It includes the container name and exit code in the error message. +type ErrContainerExited struct { + name string + exitcode int +} + +// ExitCode returns the OS exit code configured for this error. +func (e *ErrContainerExited) ExitCode() int { + return e.exitcode +} + +// Error returns the error message.
+func (e *ErrContainerExited) Error() string { + return fmt.Sprintf("container %q exited with code %d", e.name, e.exitcode) +} diff --git a/internal/pkg/docker/orchestrator/orchestrator.go b/internal/pkg/docker/orchestrator/orchestrator.go index a6581e954da..4ed28be08d0 100644 --- a/internal/pkg/docker/orchestrator/orchestrator.go +++ b/internal/pkg/docker/orchestrator/orchestrator.go @@ -20,6 +20,10 @@ import ( "time" "github.com/aws/copilot-cli/internal/pkg/docker/dockerengine" + "github.com/aws/copilot-cli/internal/pkg/graph" + "github.com/aws/copilot-cli/internal/pkg/term/color" + "github.com/aws/copilot-cli/internal/pkg/term/log" + "golang.org/x/sync/errgroup" ) // Orchestrator manages running a Task. Only a single Task @@ -48,11 +52,13 @@ type logOptionsFunc func(name string, ctr ContainerDefinition) dockerengine.RunL // DockerEngine is used by Orchestrator to manage containers. type DockerEngine interface { Run(context.Context, *dockerengine.RunOptions) error - DoesContainerExist(context.Context, string) (bool, error) IsContainerRunning(context.Context, string) (bool, error) + ContainerExitCode(ctx context.Context, containerName string) (int, error) + IsContainerHealthy(ctx context.Context, containerName string) (bool, error) Stop(context.Context, string) error Build(ctx context.Context, args *dockerengine.BuildArguments, w io.Writer) error Exec(ctx context.Context, container string, out io.Writer, cmd string, args ...string) error + Rm(context.Context, string) error } const ( @@ -69,6 +75,13 @@ const ( proxyPortStart = uint16(50000) ) +const ( + ctrStateHealthy = "healthy" + ctrStateComplete = "complete" + ctrStateSuccess = "success" + ctrStateStart = "start" +) + //go:embed Pause-Dockerfile var pauseDockerfile string @@ -95,7 +108,9 @@ func New(docker DockerEngine, idPrefix string, logOptions logOptionsFunc) *Orche func (o *Orchestrator) Start() <-chan error { // close done when all goroutines created by Orchestrator have finished done := make(chan struct{}) - errs := make(chan error) + // buffered channel so that the orchestrator routine does not block and + // can always send the error from both runErrs and action.Do to errs. 
+ errs := make(chan error, 1) // orchestrator routine o.wg.Add(1) // decremented by stopAction @@ -187,7 +202,8 @@ func (a *runTaskAction) Do(o *Orchestrator) error { cancel() } }() - + prevTask := o.curTask + o.curTask = a.task if taskID == 1 { if err := o.buildPauseContainer(ctx); err != nil { return fmt.Errorf("build pause container: %w", err) @@ -195,7 +211,7 @@ func (a *runTaskAction) Do(o *Orchestrator) error { // start the pause container opts := o.pauseRunOptions(a.task) - o.run(pauseCtrTaskID, opts) + o.run(pauseCtrTaskID, opts, true, cancel) if err := o.waitForContainerToStart(ctx, opts.ContainerName); err != nil { return fmt.Errorf("wait for pause container to start: %w", err) } @@ -207,28 +223,58 @@ func (a *runTaskAction) Do(o *Orchestrator) error { } } else { // ensure no pause container changes - curOpts := o.pauseRunOptions(o.curTask) + prevOpts := o.pauseRunOptions(prevTask) newOpts := o.pauseRunOptions(a.task) - if !maps.Equal(curOpts.EnvVars, newOpts.EnvVars) || - !maps.Equal(curOpts.Secrets, newOpts.Secrets) || - !maps.Equal(curOpts.ContainerPorts, newOpts.ContainerPorts) { + if !maps.Equal(prevOpts.EnvVars, newOpts.EnvVars) || + !maps.Equal(prevOpts.Secrets, newOpts.Secrets) || + !maps.Equal(prevOpts.ContainerPorts, newOpts.ContainerPorts) { return errors.New("new task requires recreating pause container") } - if err := o.stopTask(ctx, o.curTask); err != nil { + if err := o.stopTask(ctx, prevTask); err != nil { return fmt.Errorf("stop existing task: %w", err) } } - - for name, ctr := range a.task.Containers { - name, ctr := name, ctr - o.run(taskID, o.containerRunOptions(name, ctr)) + depGraph := buildDependencyGraph(a.task.Containers) + err := depGraph.UpwardTraversal(ctx, func(ctx context.Context, containerName string) error { + if len(a.task.Containers[containerName].DependsOn) > 0 { + if err := o.waitForContainerDependencies(ctx, containerName, a.task.Containers); err != nil { + return fmt.Errorf("wait for container %s dependencies: %w", containerName, err) + } + } + o.run(taskID, o.containerRunOptions(containerName, a.task.Containers[containerName]), a.task.Containers[containerName].IsEssential, cancel) + var errContainerExited *dockerengine.ErrContainerExited + if err := o.waitForContainerToStart(ctx, o.containerID(containerName)); err != nil && !errors.As(err, &errContainerExited) { + return fmt.Errorf("wait for container %s to start: %w", containerName, err) + } + return nil + }) + if err != nil { + if errors.Is(err, context.Canceled) { + return nil + } + return fmt.Errorf("upward traversal: %w", err) } - - o.curTask = a.task return nil } +func buildDependencyGraph(containers map[string]ContainerDefinition) *graph.LabeledGraph[string] { + var vertices []string + for vertex := range containers { + vertices = append(vertices, vertex) + } + dependencyGraph := graph.NewLabeledGraph(vertices) + for containerName, container := range containers { + for depCtr := range container.DependsOn { + dependencyGraph.Add(graph.Edge[string]{ + From: containerName, + To: depCtr, + }) + } + } + return dependencyGraph +} + // setupProxyConnections creates proxy connections to a.hosts in pauseContainer. // It assumes that pauseContainer is already running. 
A unique proxy connection // is created for each host (in parallel) using AWS SSM Port Forwarding through @@ -353,7 +399,6 @@ func (o *Orchestrator) buildPauseContainer(ctx context.Context) error { func (o *Orchestrator) Stop() { o.stopOnce.Do(func() { close(o.stopped) - fmt.Printf("\nStopping task...\n") o.actions <- &stopAction{} }) } @@ -364,6 +409,7 @@ func (a *stopAction) Do(o *Orchestrator) error { defer o.wg.Done() // for the Orchestrator o.curTaskID.Store(orchestratorStoppedTaskID) // ignore runtime errors + fmt.Printf("\nStopping task...\n") // collect errors since we want to try to clean up everything we can var errs []error if err := o.stopTask(context.Background(), o.curTask); err != nil { @@ -371,12 +417,14 @@ func (a *stopAction) Do(o *Orchestrator) error { } // stop pause container - fmt.Printf("Stopping %q\n", "pause") + fmt.Printf("Stopping and removing %q\n", "pause") if err := o.docker.Stop(context.Background(), o.containerID("pause")); err != nil { errs = append(errs, fmt.Errorf("stop %q: %w", "pause", err)) } - fmt.Printf("Stopped %q\n", "pause") - + if err := o.docker.Rm(context.Background(), o.containerID("pause")); err != nil { + errs = append(errs, fmt.Errorf("remove %q: %w", "pause", err)) + } + fmt.Printf("Stopped and removed %q\n", "pause") return errors.Join(errs...) } @@ -387,39 +435,25 @@ func (o *Orchestrator) stopTask(ctx context.Context, task Task) error { } // errCh gets one error per container - errCh := make(chan error) - for name := range task.Containers { - name := name - go func() { - fmt.Printf("Stopping %q\n", name) - if err := o.docker.Stop(ctx, o.containerID(name)); err != nil { - errCh <- fmt.Errorf("stop %q: %w", name, err) - return - } - - // ensure that container is fully stopped before stopTask finishes blocking - for { - exists, err := o.docker.DoesContainerExist(ctx, o.containerID(name)) - if err != nil { - errCh <- fmt.Errorf("polling container %q for removal: %w", name, err) - return - } - - if exists { - select { - case <-time.After(1 * time.Second): - continue - case <-ctx.Done(): - errCh <- fmt.Errorf("check container %q stopped: %w", name, ctx.Err()) - return - } - } + errCh := make(chan error, len(task.Containers)) + depGraph := buildDependencyGraph(task.Containers) + err := depGraph.DownwardTraversal(ctx, func(ctx context.Context, name string) error { + fmt.Printf("Stopping and removing %q\n", name) + if err := o.docker.Stop(ctx, o.containerID(name)); err != nil { + errCh <- fmt.Errorf("stop %q: %w", name, err) + return nil + } + if err := o.docker.Rm(ctx, o.containerID(name)); err != nil { + errCh <- fmt.Errorf("remove %q: %w", name, err) + return nil + } + fmt.Printf("Stopped and removed %q\n", name) + errCh <- nil + return nil + }) - fmt.Printf("Stopped %q\n", name) - errCh <- nil - return - } - }() + if err != nil { + return fmt.Errorf("downward traversal: %w", err) } var errs []error @@ -441,6 +475,7 @@ func (o *Orchestrator) waitForContainerToStart(ctx context.Context, id string) e case err != nil: return fmt.Errorf("check if %q is running: %w", id, err) case isRunning: + log.Successf("Successfully started container %s\n", id) return nil } @@ -452,6 +487,70 @@ func (o *Orchestrator) waitForContainerToStart(ctx context.Context, id string) e } } +func (o *Orchestrator) waitForContainerDependencies(ctx context.Context, name string, definitions map[string]ContainerDefinition) error { + var deps []string + for depName, state := range definitions[name].DependsOn { + deps = append(deps, fmt.Sprintf("%s->%s", depName, state)) + } + 
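// Poll every dependency in parallel until it reaches its declared condition: start, healthy, complete, or success. +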
logMsg := strings.Join(deps, ", ") + fmt.Printf("Waiting for container %q dependencies: [%s]\n", name, color.Emphasize(logMsg)) + eg, ctx := errgroup.WithContext(ctx) + for name, state := range definitions[name].DependsOn { + name, state := name, state + eg.Go(func() error { + ctrId := o.containerID(name) + ticker := time.NewTicker(700 * time.Millisecond) + defer ticker.Stop() + for { + select { + case <-ticker.C: + case <-ctx.Done(): + return ctx.Err() + } + switch state { + case ctrStateStart: + return nil + case ctrStateHealthy: + healthy, err := o.docker.IsContainerHealthy(ctx, ctrId) + if err != nil { + return fmt.Errorf("wait for container %q to be healthy: %w", ctrId, err) + } + if healthy { + log.Successf("Dependency container %q is healthy\n", ctrId) + return nil + } + case ctrStateComplete: + exitCode, err := o.docker.ContainerExitCode(ctx, ctrId) + var errContainerNotExited *dockerengine.ErrContainerNotExited + if errors.As(err, &errContainerNotExited) { + continue + } + if err != nil { + return fmt.Errorf("wait for container %q to complete: %w", ctrId, err) + } + log.Successf("%q's dependency container %q exited with code: %d\n", name, ctrId, exitCode) + return nil + case ctrStateSuccess: + exitCode, err := o.docker.ContainerExitCode(ctx, ctrId) + var errContainerNotExited *dockerengine.ErrContainerNotExited + if errors.As(err, &errContainerNotExited) { + continue + } + if err != nil { + return fmt.Errorf("wait for container %q to succeed: %w", ctrId, err) + } + if exitCode != 0 { + return fmt.Errorf("dependency container %q exited with non-zero exit code %d", ctrId, exitCode) + } + log.Successf("%q's dependency container %q exited with code: %d\n", name, ctrId, exitCode) + return nil + } + } + }) + } + return eg.Wait() +} + // containerID returns the full ID for a container with name run by o. func (o *Orchestrator) containerID(name string) string { return o.idPrefix + name @@ -513,7 +612,7 @@ func (o *Orchestrator) containerRunOptions(name string, ctr ContainerDefinition) // run calls `docker run` using opts. Errors are only returned // to the main Orchestrator routine if the taskID the container was run with // matches the current taskID the Orchestrator is running. -func (o *Orchestrator) run(taskID int32, opts dockerengine.RunOptions) { +func (o *Orchestrator) run(taskID int32, opts dockerengine.RunOptions, isEssential bool, cancel context.CancelFunc) { o.wg.Add(1) go func() { defer o.wg.Done() @@ -529,9 +628,16 @@ func (o *Orchestrator) run(taskID int32, opts dockerengine.RunOptions) { // the error is from the pause container // or from the currently running task if taskID == pauseCtrTaskID || taskID == curTaskID { + var errContainerExited *dockerengine.ErrContainerExited + if !isEssential && (errors.As(err, &errContainerExited) || err == nil) { + fmt.Printf("non-essential container %q stopped\n", opts.ContainerName) + return + } if err == nil { err = errors.New("container stopped unexpectedly") } + // cancel the context to signal the other goroutines spawned by `graph.UpwardTraversal` to stop.
+ cancel() o.runErrs <- fmt.Errorf("run %q: %w", opts.ContainerName, err) } }() diff --git a/internal/pkg/docker/orchestrator/orchestrator_test.go b/internal/pkg/docker/orchestrator/orchestrator_test.go index 53ad0114aa8..756d953feb6 100644 --- a/internal/pkg/docker/orchestrator/orchestrator_test.go +++ b/internal/pkg/docker/orchestrator/orchestrator_test.go @@ -85,9 +85,6 @@ func TestOrchestrator(t *testing.T) { runUntilStopped: true, test: func(t *testing.T) (test, *dockerenginetest.Double) { de := &dockerenginetest.Double{ - DoesContainerExistFn: func(ctx context.Context, s string) (bool, error) { - return false, nil - }, IsContainerRunningFn: func(ctx context.Context, name string) (bool, error) { return true, nil }, @@ -114,33 +111,38 @@ func TestOrchestrator(t *testing.T) { `stop "bar": some error`, }, }, - "error polling tasks removed": { + "error removing task": { logOptions: noLogs, runUntilStopped: true, test: func(t *testing.T) (test, *dockerenginetest.Double) { de := &dockerenginetest.Double{ - DoesContainerExistFn: func(ctx context.Context, s string) (bool, error) { - return false, errors.New("some error") - }, IsContainerRunningFn: func(ctx context.Context, name string) (bool, error) { return true, nil }, StopFn: func(ctx context.Context, name string) error { return nil }, + RmFn: func(ctx context.Context, name string) error { + if name == "prefix-success" { + return nil + } + return errors.New("some error") + }, } return func(t *testing.T, o *Orchestrator) { o.RunTask(Task{ Containers: map[string]ContainerDefinition{ - "foo": {}, - "bar": {}, + "foo": {}, + "bar": {}, + "success": {}, }, }) }, de }, errs: []string{ - `polling container "foo" for removal: some error`, - `polling container "bar" for removal: some error`, + `remove "pause": some error`, + `remove "foo": some error`, + `remove "bar": some error`, }, }, "error restarting new task due to pause changes": { @@ -148,9 +150,6 @@ func TestOrchestrator(t *testing.T) { runUntilStopped: true, test: func(t *testing.T) (test, *dockerenginetest.Double) { de := &dockerenginetest.Double{ - DoesContainerExistFn: func(ctx context.Context, s string) (bool, error) { - return false, nil - }, IsContainerRunningFn: func(ctx context.Context, name string) (bool, error) { return true, nil }, @@ -185,9 +184,6 @@ func TestOrchestrator(t *testing.T) { runUntilStopped: true, test: func(t *testing.T) (test, *dockerenginetest.Double) { de := &dockerenginetest.Double{ - DoesContainerExistFn: func(ctx context.Context, s string) (bool, error) { - return false, nil - }, IsContainerRunningFn: func(ctx context.Context, name string) (bool, error) { return true, nil }, @@ -229,17 +225,320 @@ func TestOrchestrator(t *testing.T) { }, de }, }, - "container run stops early with error": { + + "return nil if non essential container exits": { + logOptions: noLogs, + runUntilStopped: true, + test: func(t *testing.T) (test, *dockerenginetest.Double) { + de := &dockerenginetest.Double{ + IsContainerRunningFn: func(ctx context.Context, name string) (bool, error) { + if name == "prefix-foo" { + return false, &dockerengine.ErrContainerExited{} + } + return true, nil + }, + RunFn: func(ctx context.Context, opts *dockerengine.RunOptions) error { + return nil + }, + StopFn: func(ctx context.Context, name string) error { + return nil + }, + } + return func(t *testing.T, o *Orchestrator) { + o.RunTask(Task{ + Containers: map[string]ContainerDefinition{ + "foo": { + IsEssential: false, + }, + }, + }) + }, de + }, + }, + + "success with dependsOn order": { logOptions: 
noLogs, test: func(t *testing.T) (test, *dockerenginetest.Double) { stopPause := make(chan struct{}) + stopFoo := make(chan struct{}) de := &dockerenginetest.Double{ - DoesContainerExistFn: func(ctx context.Context, s string) (bool, error) { - return false, nil + IsContainerRunningFn: func(ctx context.Context, name string) (bool, error) { + return true, nil + }, + RunFn: func(ctx context.Context, opts *dockerengine.RunOptions) error { + if opts.ContainerName == "prefix-pause" { + // block pause container until Stop(pause) + <-stopPause + } + if opts.ContainerName == "prefix-foo" { + // block foo container until Stop(foo) + <-stopFoo + } + return nil + }, + StopFn: func(ctx context.Context, s string) error { + if s == "prefix-pause" { + stopPause <- struct{}{} + } + if s == "prefix-foo" { + stopFoo <- struct{}{} + } + return nil }, + } + return func(t *testing.T, o *Orchestrator) { + o.RunTask(Task{ + Containers: map[string]ContainerDefinition{ + "foo": { + IsEssential: true, + }, + "bar": { + IsEssential: true, + DependsOn: map[string]string{ + "foo": ctrStateStart, + }, + }, + }, + }) + }, de + }, + }, + + "return error when dependency container is unhealthy": { + logOptions: noLogs, + stopAfterNErrs: 1, + test: func(t *testing.T) (test, *dockerenginetest.Double) { + stopPause := make(chan struct{}) + stopFoo := make(chan struct{}) + de := &dockerenginetest.Double{ + IsContainerRunningFn: func(ctx context.Context, name string) (bool, error) { + return true, nil + }, + RunFn: func(ctx context.Context, opts *dockerengine.RunOptions) error { + if opts.ContainerName == "prefix-bar" { + return errors.New("container `prefix-bar` exited with code 143") + } + if opts.ContainerName == "prefix-pause" { + // block pause container until Stop(pause) + <-stopPause + } + if opts.ContainerName == "prefix-foo" { + // block foo container until Stop(foo) + <-stopFoo + } + return nil + }, + StopFn: func(ctx context.Context, s string) error { + if s == "prefix-pause" { + stopPause <- struct{}{} + } + if s == "prefix-foo" { + stopFoo <- struct{}{} + } + return nil + }, + IsContainerHealthyFn: func(ctx context.Context, containerName string) (bool, error) { + if containerName == "prefix-foo" { + return false, fmt.Errorf("container `prefix-foo` is unhealthy") + } + return true, nil + }, + } + return func(t *testing.T, o *Orchestrator) { + o.RunTask(Task{ + Containers: map[string]ContainerDefinition{ + "foo": { + IsEssential: true, + }, + "bar": { + IsEssential: false, + DependsOn: map[string]string{ + "foo": ctrStateHealthy, + }, + }, + }, + }) + }, de + }, + errs: []string{"upward traversal: wait for container bar dependencies: wait for container \"prefix-foo\" to be healthy: container `prefix-foo` is unhealthy"}, + }, + + "return error when dependency container is not started": { + logOptions: noLogs, + stopAfterNErrs: 1, + test: func(t *testing.T) (test, *dockerenginetest.Double) { + stopPause := make(chan struct{}) + stopFoo := make(chan struct{}) + de := &dockerenginetest.Double{ + IsContainerRunningFn: func(ctx context.Context, name string) (bool, error) { + if name == "prefix-foo" { + return false, fmt.Errorf("some error") + } + return true, nil + }, + RunFn: func(ctx context.Context, opts *dockerengine.RunOptions) error { + if opts.ContainerName == "prefix-pause" { + // block pause container until Stop(pause) + <-stopPause + } + if opts.ContainerName == "prefix-foo" { + // block foo container until Stop(foo) + <-stopFoo + } + return nil + }, + StopFn: func(ctx context.Context, s string) error { + if s ==
"prefix-pause" { + stopPause <- struct{}{} + } + if s == "prefix-foo" { + stopFoo <- struct{}{} + } + return nil + }, + } + return func(t *testing.T, o *Orchestrator) { + o.RunTask(Task{ + Containers: map[string]ContainerDefinition{ + "foo": { + IsEssential: false, + }, + "bar": { + IsEssential: true, + DependsOn: map[string]string{ + "foo": ctrStateStart, + }, + }, + }, + }) + }, de + }, + errs: []string{"upward traversal: wait for container foo to start: check if \"prefix-foo\" is running: some error"}, + }, + + "return error when dependency container complete failed": { + logOptions: noLogs, + stopAfterNErrs: 1, + test: func(t *testing.T) (test, *dockerenginetest.Double) { + stopPause := make(chan struct{}) + stopFoo := make(chan struct{}) + de := &dockerenginetest.Double{ + IsContainerRunningFn: func(ctx context.Context, name string) (bool, error) { + return true, nil + }, + RunFn: func(ctx context.Context, opts *dockerengine.RunOptions) error { + if opts.ContainerName == "prefix-pause" { + // block pause container until Stop(pause) + <-stopPause + } + if opts.ContainerName == "prefix-foo" { + // block bar container until Stop(foo) + <-stopFoo + } + return nil + }, + StopFn: func(ctx context.Context, s string) error { + if s == "prefix-pause" { + stopPause <- struct{}{} + } + if s == "prefix-foo" { + stopFoo <- struct{}{} + } + return nil + }, + ContainerExitCodeFn: func(ctx context.Context, name string) (int, error) { + if name == "prefix-foo" { + return 143, fmt.Errorf("some error") + } + return 0, &dockerengine.ErrContainerNotExited{} + }, + } + return func(t *testing.T, o *Orchestrator) { + o.RunTask(Task{ + Containers: map[string]ContainerDefinition{ + "foo": { + IsEssential: false, + }, + "bar": { + IsEssential: true, + DependsOn: map[string]string{ + "foo": ctrStateComplete, + }, + }, + }, + }) + }, de + }, + errs: []string{"upward traversal: wait for container bar dependencies: wait for container \"prefix-foo\" to complete: some error"}, + }, + + "return error when container with non zero exitcode if condition is success": { + logOptions: noLogs, + stopAfterNErrs: 1, + test: func(t *testing.T) (test, *dockerenginetest.Double) { + stopPause := make(chan struct{}) + stopFoo := make(chan struct{}) + de := &dockerenginetest.Double{ + IsContainerRunningFn: func(ctx context.Context, name string) (bool, error) { + return true, nil + }, + RunFn: func(ctx context.Context, opts *dockerengine.RunOptions) error { + if opts.ContainerName == "prefix-pause" { + // block pause container until Stop(pause) + <-stopPause + } + if opts.ContainerName == "prefix-foo" { + // block bar container until Stop(foo) + <-stopFoo + } + return nil + }, + StopFn: func(ctx context.Context, s string) error { + if s == "prefix-pause" { + stopPause <- struct{}{} + } + if s == "prefix-foo" { + stopFoo <- struct{}{} + } + return nil + }, + ContainerExitCodeFn: func(ctx context.Context, containerName string) (int, error) { + if containerName == "prefix-foo" { + return 143, nil + } + return 0, &dockerengine.ErrContainerNotExited{} + }, + } + return func(t *testing.T, o *Orchestrator) { + o.RunTask(Task{ + Containers: map[string]ContainerDefinition{ + "foo": { + IsEssential: false, + }, + "bar": { + IsEssential: true, + DependsOn: map[string]string{ + "foo": ctrStateSuccess, + }, + }, + }, + }) + }, de + }, + errs: []string{"upward traversal: wait for container bar dependencies: dependency container \"prefix-foo\" exited with non-zero exit code 143"}, + }, + + "container run stops early with error": { + logOptions: 
noLogs, + test: func(t *testing.T) (test, *dockerenginetest.Double) { + stopPause := make(chan struct{}) + de := &dockerenginetest.Double{ + IsContainerRunningFn: func(ctx context.Context, name string) (bool, error) { + return true, nil + }, + ContainerExitCodeFn: func(ctx context.Context, containerName string) (int, error) { + return 0, nil + }, RunFn: func(ctx context.Context, opts *dockerengine.RunOptions) error { if opts.ContainerName == "prefix-foo" { return errors.New("some error") @@ -259,7 +558,15 @@ func TestOrchestrator(t *testing.T) { return func(t *testing.T, o *Orchestrator) { o.RunTask(Task{ Containers: map[string]ContainerDefinition{ - "foo": {}, + "foo": { + IsEssential: true, + }, + "bar": { + IsEssential: true, + DependsOn: map[string]string{ + "foo": "start", + }, + }, }, }) }, de @@ -272,9 +579,6 @@ test: func(t *testing.T) (test, *dockerenginetest.Double) { stopPause := make(chan struct{}) de := &dockerenginetest.Double{ - DoesContainerExistFn: func(ctx context.Context, s string) (bool, error) { - return false, nil - }, IsContainerRunningFn: func(ctx context.Context, name string) (bool, error) { return true, nil }, @@ -297,7 +601,9 @@ return func(t *testing.T, o *Orchestrator) { o.RunTask(Task{ Containers: map[string]ContainerDefinition{ - "foo": {}, + "foo": { + IsEssential: true, + }, }, }) }, de @@ -420,9 +726,6 @@ runUntilStopped: true, test: func(t *testing.T) (test, *dockerenginetest.Double) { de := &dockerenginetest.Double{ - DoesContainerExistFn: func(ctx context.Context, s string) (bool, error) { - return false, nil - }, IsContainerRunningFn: func(ctx context.Context, name string) (bool, error) { return true, nil }, diff --git a/internal/pkg/graph/graph.go b/internal/pkg/graph/graph.go index 0eb69f059dc..9769b7423c9 100644 --- a/internal/pkg/graph/graph.go +++ b/internal/pkg/graph/graph.go @@ -227,46 +227,32 @@ func TopologicalOrder[V comparable](digraph *Graph[V]) (*TopologicalSorter[V], e // It is concurrency-safe, utilizing a mutex lock for synchronized access. type LabeledGraph[V comparable] struct { *Graph[V] - status map[V]string + status map[V]vertexStatus lock sync.Mutex } -// NewLabeledGraph initializes a LabeledGraph with specified vertices and optional configurations. -// It creates a base Graph with the vertices and applies any LabeledGraphOption to configure additional properties. -func NewLabeledGraph[V comparable](vertices []V, opts ...LabeledGraphOption[V]) *LabeledGraph[V] { - g := New(vertices...) +// NewLabeledGraph initializes a LabeledGraph with specified vertices and sets the status of each vertex to unvisited. +func NewLabeledGraph[V comparable](vertices []V) *LabeledGraph[V] { lg := &LabeledGraph[V]{ - Graph: g, - status: make(map[V]string), + Graph: New(vertices...), + status: make(map[V]vertexStatus), + lock: sync.Mutex{}, } - for _, opt := range opts { - opt(lg) + for _, vertex := range vertices { + lg.status[vertex] = unvisited } return lg } -// LabeledGraphOption allows you to initialize Graph with additional properties. -type LabeledGraphOption[V comparable] func(g *LabeledGraph[V]) - -// WithStatus sets the status of each vertex in the Graph. -func WithStatus[V comparable](status string) func(g *LabeledGraph[V]) { - return func(g *LabeledGraph[V]) { - g.status = make(map[V]string) - for vertex := range g.vertices { - g.status[vertex] = status - } - } -} - // updateStatus updates the status of a vertex.
-func (lg *LabeledGraph[V]) updateStatus(vertex V, status string) { +func (lg *LabeledGraph[V]) updateStatus(vertex V, status vertexStatus) { lg.lock.Lock() defer lg.lock.Unlock() lg.status[vertex] = status } // getStatus gets the status of a vertex. -func (lg *LabeledGraph[V]) getStatus(vertex V) string { +func (lg *LabeledGraph[V]) getStatus(vertex V) vertexStatus { lg.lock.Lock() defer lg.lock.Unlock() return lg.status[vertex] @@ -306,7 +292,7 @@ func (lg *LabeledGraph[V]) children(vtx V) []V { } // filterParents filters parents based on the vertex status. -func (lg *LabeledGraph[V]) filterParents(vtx V, status string) []V { +func (lg *LabeledGraph[V]) filterParents(vtx V, status vertexStatus) []V { parents := lg.parents(vtx) var filtered []V for _, parent := range parents { @@ -318,7 +304,7 @@ func (lg *LabeledGraph[V]) filterParents(vtx V, status string) []V { } // filterChildren filters children based on the vertex status. -func (lg *LabeledGraph[V]) filterChildren(vtx V, status string) []V { +func (lg *LabeledGraph[V]) filterChildren(vtx V, status vertexStatus) []V { children := lg.children(vtx) var filtered []V for _, child := range children { @@ -330,57 +316,39 @@ func (lg *LabeledGraph[V]) filterChildren(vtx V, status string) []V { } /* -UpwardTraversal performs an upward traversal on the graph starting from leaves (nodes with no children) -and moving towards root nodes (nodes with children). -It applies the specified process function to each vertex in the graph, skipping vertices with the -"adjacentVertexSkipStatus" status, and continuing traversal until reaching vertices with the "requiredVertexStatus" status. -The traversal is concurrent and may process vertices in parallel. -Returns an error if the traversal encounters any issues, or nil if successful. +UpwardTraversal performs a traversal from leaf nodes (with no children) to root nodes (with children). +It processes each vertex using processVertexFunc, skipping any vertex that has already been visited or is being visited. +The traversal is concurrent, handling vertices in parallel, and returns an error if any issue occurs. */ -func (lg *LabeledGraph[V]) UpwardTraversal(ctx context.Context, processVertexFunc func(context.Context, V) error, nextVertexSkipStatus, requiredVertexStatus string) error { +func (lg *LabeledGraph[V]) UpwardTraversal(ctx context.Context, processVertexFunc func(context.Context, V) error) error { traversal := &graphTraversal[V]{ - mu: sync.Mutex{}, - seen: make(map[V]struct{}), findStartVertices: func(lg *LabeledGraph[V]) []V { return lg.leaves() }, findNextVertices: func(lg *LabeledGraph[V], v V) []V { return lg.parents(v) }, - filterPreviousVerticesByStatus: func(g *LabeledGraph[V], v V, status string) []V { return g.filterChildren(v, status) }, - requiredVertexStatus: requiredVertexStatus, - nextVertexSkipStatus: nextVertexSkipStatus, + filterPreviousVerticesByStatus: func(g *LabeledGraph[V], v V, status vertexStatus) []V { return g.filterChildren(v, status) }, processVertex: processVertexFunc, } return traversal.execute(ctx, lg) } /* -DownwardTraversal performs a downward traversal on the graph starting from root nodes (nodes with no parents) -and moving towards leaf nodes (nodes with parents). It applies the specified process function to each -vertex in the graph, skipping vertices with the "adjacentVertexSkipStatus" status, and continuing traversal -until reaching vertices with the "requiredVertexStatus" status. -The traversal is concurrent and may process vertices in parallel.
-Returns an error if the traversal encounters any issues. +DownwardTraversal performs a traversal from root nodes (with no parents) to leaf nodes (with parents). +It applies processVertexFunc to each vertex, skipping vertices that have already been visited or are being visited. +It processes vertices concurrently and returns an error if any issue occurs. */ -func (lg *LabeledGraph[V]) DownwardTraversal(ctx context.Context, processVertexFunc func(context.Context, V) error, adjacentVertexSkipStatus, requiredVertexStatus string) error { +func (lg *LabeledGraph[V]) DownwardTraversal(ctx context.Context, processVertexFunc func(context.Context, V) error) error { traversal := &graphTraversal[V]{ - mu: sync.Mutex{}, - seen: make(map[V]struct{}), findStartVertices: func(lg *LabeledGraph[V]) []V { return lg.Roots() }, findNextVertices: func(lg *LabeledGraph[V], v V) []V { return lg.children(v) }, - filterPreviousVerticesByStatus: func(lg *LabeledGraph[V], v V, status string) []V { return lg.filterParents(v, status) }, - requiredVertexStatus: requiredVertexStatus, - nextVertexSkipStatus: adjacentVertexSkipStatus, + filterPreviousVerticesByStatus: func(lg *LabeledGraph[V], v V, status vertexStatus) []V { return lg.filterParents(v, status) }, processVertex: processVertexFunc, } return traversal.execute(ctx, lg) } type graphTraversal[V comparable] struct { - mu sync.Mutex - seen map[V]struct{} findStartVertices func(*LabeledGraph[V]) []V findNextVertices func(*LabeledGraph[V], V) []V - filterPreviousVerticesByStatus func(*LabeledGraph[V], V, string) []V - requiredVertexStatus string - nextVertexSkipStatus string + filterPreviousVerticesByStatus func(*LabeledGraph[V], V, vertexStatus) []V processVertex func(context.Context, V) error } @@ -400,20 +368,23 @@ func (t *graphTraversal[V]) execute(ctx context.Context, lg *LabeledGraph[V]) er processVertices := func(ctx context.Context, graph *LabeledGraph[V], eg *errgroup.Group, vertices []V, vertexCh chan V) { for _, vertex := range vertices { vertex := vertex - // Delay processing this vertex if any of its dependent vertices are yet to be processed. - if len(t.filterPreviousVerticesByStatus(graph, vertex, t.nextVertexSkipStatus)) != 0 { + // If any of the vertices that should be visited before this vertex are yet to be processed, we delay processing it. + if len(t.filterPreviousVerticesByStatus(graph, vertex, unvisited)) != 0 || + len(t.filterPreviousVerticesByStatus(graph, vertex, visiting)) != 0 { continue } - if !t.markAsSeen(vertex) { - // Skip this vertex if it's already been processed by another routine. + // Check if the vertex is already visited or being visited + if graph.getStatus(vertex) != unvisited { continue } + // Mark the vertex as visiting + graph.updateStatus(vertex, visiting) eg.Go(func() error { if err := t.processVertex(ctx, vertex); err != nil { return err } // Assign new status to the vertex upon successful processing.
- graph.updateStatus(vertex, t.requiredVertexStatus) + graph.updateStatus(vertex, visited) vertexCh <- vertex return nil }) @@ -437,13 +408,3 @@ func (t *graphTraversal[V]) execute(ctx context.Context, lg *LabeledGraph[V]) er processVertices(ctx, lg, eg, t.findStartVertices(lg), vertexCh) return eg.Wait() } - -func (t *graphTraversal[V]) markAsSeen(vertex V) bool { - t.mu.Lock() - defer t.mu.Unlock() - if _, seen := t.seen[vertex]; seen { - return false - } - t.seen[vertex] = struct{}{} - return true -} diff --git a/internal/pkg/graph/graph_test.go b/internal/pkg/graph/graph_test.go index 1f4a818e1a1..cf770ccf441 100644 --- a/internal/pkg/graph/graph_test.go +++ b/internal/pkg/graph/graph_test.go @@ -373,7 +373,7 @@ func TestTopologicalOrder(t *testing.T) { func buildGraphWithSingleParent() *LabeledGraph[string] { vertices := []string{"A", "B", "C", "D"} - graph := NewLabeledGraph[string](vertices, WithStatus[string]("started")) + graph := NewLabeledGraph[string](vertices) graph.Add(Edge[string]{From: "D", To: "C"}) // D -> C graph.Add(Edge[string]{From: "C", To: "B"}) // C -> B graph.Add(Edge[string]{From: "B", To: "A"}) // B -> A @@ -388,14 +388,14 @@ func TestTraverseInDependencyOrder(t *testing.T) { visited = append(visited, v) return nil } - err := graph.UpwardTraversal(context.Background(), processFn, "started", "stopped") + err := graph.UpwardTraversal(context.Background(), processFn) require.NoError(t, err) expected := []string{"A", "B", "C", "D"} require.Equal(t, expected, visited) }) t.Run("graph with multiple parents and boundary nodes", func(t *testing.T) { vertices := []string{"A", "B", "C", "D"} - graph := NewLabeledGraph[string](vertices, WithStatus[string]("started")) + graph := NewLabeledGraph[string](vertices) graph.Add(Edge[string]{From: "A", To: "C"}) graph.Add(Edge[string]{From: "A", To: "D"}) graph.Add(Edge[string]{From: "B", To: "D"}) @@ -412,7 +412,7 @@ func TestTraverseInDependencyOrder(t *testing.T) { err := graph.DownwardTraversal(context.Background(), func(ctx context.Context, vtx string) error { vtxChan <- vtx return nil - }, "started", "stopped") + }) require.NoError(t, err, "Error during iteration") close(vtxChan) <-done @@ -432,7 +432,7 @@ func TestTraverseInReverseDependencyOrder(t *testing.T) { visited = append(visited, v) return nil } - err := graph.DownwardTraversal(context.Background(), processFn, "started", "stopped") + err := graph.DownwardTraversal(context.Background(), processFn) require.NoError(t, err) expected := []string{"D", "C", "B", "A"} require.Equal(t, expected, visited)
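Note: the graph.go hunks above reference a vertexStatus type and its unvisited, visiting, and visited values, but their declarations fall outside the hunk context shown. A minimal sketch consistent with that usage (the iota encoding is an assumption, not part of the diff):

// vertexStatus tracks how far a traversal has progressed for one vertex.
type vertexStatus int

const (
	unvisited vertexStatus = iota // not yet processed
	visiting                      // currently being processed
	visited                       // successfully processed
)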
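For reference, the orchestrator tests above drive the new dependency ordering through ContainerDefinition.DependsOn, a map keyed by container name whose values are the four condition constants. A minimal sketch of a task wired this way (the container names and the Orchestrator value o are illustrative):

// "api" is started only once "db" reports a healthy Docker health check;
// on Stop, the orchestrator tears the containers down in the reverse order.
o.RunTask(Task{
	Containers: map[string]ContainerDefinition{
		"db": {IsEssential: true},
		"api": {
			IsEssential: true,
			DependsOn:   map[string]string{"db": "healthy"},
		},
	},
})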
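Likewise, ContainerExitCode now returns an *ErrContainerNotExited while the container is still running, which waitForContainerDependencies relies on to keep polling. A caller-side sketch of the same pattern (waitExit is a hypothetical helper, not part of this diff):

// waitExit polls until the named container exits, then returns its exit code.
func waitExit(ctx context.Context, client dockerengine.DockerCmdClient, name string) (int, error) {
	for {
		code, err := client.ContainerExitCode(ctx, name)
		var notExited *dockerengine.ErrContainerNotExited
		if !errors.As(err, &notExited) {
			// The container exited (or a real error occurred); stop polling.
			return code, err
		}
		select {
		case <-time.After(time.Second):
		case <-ctx.Done():
			return 0, ctx.Err()
		}
	}
}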