From 352d20c8ffbbcd12bb500d8be31a479336c16f66 Mon Sep 17 00:00:00 2001 From: poy Date: Tue, 13 Oct 2020 22:34:28 -0600 Subject: [PATCH] Adds Persistent{Pre,Post}Run hook chaining PersistentPreRun and PersistentPostRun are chained together so that each child PersistentPreRun is ran, and the PersistentPostRun are ran in reverse order. For example: Commands: root -> subcommand-a -> subcommand-b root - PersistentPreRun subcommand-a - PersistentPreRun subcommand-b - PersistentPreRun subcommand-b - Run subcommand-b - PersistentPostRun subcommand-a - PersistentPostRun root - PersistentPostRun fixes #252 --- .golangci.yml | 2 +- command.go | 165 +++++++++++++++++++++------ command_test.go | 290 ++++++++++++++++++++++++++++++++++++++++++++---- 3 files changed, 399 insertions(+), 58 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index 0d6e61793..6676a7795 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -28,7 +28,7 @@ linters: - ineffassign - interfacer #- lll - - maligned + # - maligned - megacheck #- misspell #- nakedret diff --git a/command.go b/command.go index 5c85c899d..b098c330e 100644 --- a/command.go +++ b/command.go @@ -102,6 +102,21 @@ type Command struct { // * PersistentPostRun() // All functions get the same args, the arguments after the command name. // + // When TraverseChildrenHooks is set, PersistentPreRun and + // PersistentPostRun are chained together so that each child + // PersistentPreRun is ran, and the PersistentPostRun are ran in reverse + // order. For example: + // + // Commands: root -> subcommand-a -> subcommand-b + // + // root - PersistentPreRun + // subcommand-a - PersistentPreRun + // subcommand-b - PersistentPreRun + // subcommand-b - Run + // subcommand-b - PersistentPostRun + // subcommand-a - PersistentPostRun + // root - PersistentPostRun + // // PersistentPreRun: children of this command will inherit and execute. PersistentPreRun func(cmd *Command, args []string) // PersistentPreRunE: PersistentPreRun but returns an error. @@ -193,6 +208,11 @@ type Command struct { // TraverseChildren parses flags on all parents before executing child command. TraverseChildren bool + // TraverseChildrenHooks will have each subcommand's PersistentPreRun and + // PersistentPostRun instead of overriding. It should be set on the root + // command. + TraverseChildrenHooks bool + // Hidden defines, if this command is hidden and should NOT show up in the list of available commands. Hidden bool @@ -829,55 +849,130 @@ func (c *Command) execute(a []string) (err error) { return err } - for p := c; p != nil; p = p.Parent() { - if p.PersistentPreRunE != nil { - if err := p.PersistentPreRunE(c, argWoFlags); err != nil { - return err + // Look to see if TraverseChildrenHooks is set on the root command. + if _, err := c.runTree(c, argWoFlags, c.traverseChildrenHooks()); err != nil { + return err + } + + return nil +} + +func (c *Command) traverseChildrenHooks() bool { + if c.HasParent() { + return c.Parent().traverseChildrenHooks() + } + + return c.TraverseChildrenHooks +} + +func (c *Command) runTree( + cmd *Command, + args []string, + traverseChildrenHooks bool, +) ( + persistentPostRunEs []func(cmd *Command, args []string) error, + err error, +) { + if c == nil { + return nil, nil + } + + // Traverse command tree and save the PersistentPostRun{,E} functions. + persistentPostRunEs, err = c.Parent().runTree(cmd, args, traverseChildrenHooks) + if err != nil { + return nil, err + } + + if traverseChildrenHooks || c == cmd { + // PersistentPreRun/PersistentPreRunE + switch { + case c.PersistentPreRun != nil: + c.PersistentPreRun(cmd, args) + case c.PersistentPreRunE != nil: + if err := c.PersistentPreRunE(cmd, args); err != nil { + return nil, err } - break - } else if p.PersistentPreRun != nil { - p.PersistentPreRun(c, argWoFlags) - break + default: + // Doesn't have a registered PersistentPreRun{,E}. Move on... + } + + // PersistentPostRun/PersistentPostRunE + switch { + case c.PersistentPostRun != nil: + persistentPostRunEs = append( + persistentPostRunEs, + func(cmd *Command, args []string) error { + c.PersistentPostRun(cmd, args) + return nil + }, + ) + case c.PersistentPostRunE != nil: + persistentPostRunEs = append( + persistentPostRunEs, + c.PersistentPostRunE, + ) + default: + // Doesn't have a registered PersistentPostRun{,E}. Move on... } } - if c.PreRunE != nil { - if err := c.PreRunE(c, argWoFlags); err != nil { - return err + + if c != cmd { + // Don't run a parent command. + return persistentPostRunEs, nil + } + + // PreRun/PreRunE + switch { + case c.PreRun != nil: + c.PreRun(cmd, args) + case c.PreRunE != nil: + if err := c.PreRunE(cmd, args); err != nil { + return nil, err } - } else if c.PreRun != nil { - c.PreRun(c, argWoFlags) + default: + // Doesn't have a registered PreRun{,E}. Move on... } if err := c.validateRequiredFlags(); err != nil { - return err + return nil, err } - if c.RunE != nil { - if err := c.RunE(c, argWoFlags); err != nil { - return err + + // Run/RunE + switch { + case c.RunE != nil: + if err := c.RunE(cmd, args); err != nil { + return nil, err } - } else { - c.Run(c, argWoFlags) - } - if c.PostRunE != nil { - if err := c.PostRunE(c, argWoFlags); err != nil { - return err + case c.Run != nil: + c.Run(cmd, args) + default: + // Both RunE and Run are nil... + panic(fmt.Sprintf("command %q does not have a non-nil RunE or Run function", c.Use)) + } + + // PostRun/PostRunE + switch { + case c.PostRun != nil: + c.PostRun(cmd, args) + case c.PostRunE != nil: + if err := c.PostRunE(cmd, args); err != nil { + return nil, err } - } else if c.PostRun != nil { - c.PostRun(c, argWoFlags) + default: + // Doesn't have a registered PostRun{,E}. Move on... } - for p := c; p != nil; p = p.Parent() { - if p.PersistentPostRunE != nil { - if err := p.PersistentPostRunE(c, argWoFlags); err != nil { - return err - } - break - } else if p.PersistentPostRun != nil { - p.PersistentPostRun(c, argWoFlags) - break + + // PersistentPostRun/PersistentPostRunE + // Iterate through the list in reverse order. Similar to a defer, allow + // the topmost commands to cleanup first. + for i := range persistentPostRunEs { + r := persistentPostRunEs[len(persistentPostRunEs)-1-i] + if err := r(cmd, args); err != nil { + return nil, err } } - return nil + return nil, nil } func (c *Command) preRun() { diff --git a/command_test.go b/command_test.go index 583cb0235..8033ac2e2 100644 --- a/command_test.go +++ b/command_test.go @@ -3,6 +3,7 @@ package cobra import ( "bytes" "context" + "errors" "fmt" "io/ioutil" "os" @@ -1364,7 +1365,8 @@ func TestPersistentHooks(t *testing.T) { ) parentCmd := &Command{ - Use: "parent", + Use: "parent", + TraverseChildrenHooks: false, // Set explicitly to highlight setting. PersistentPreRun: func(_ *Command, args []string) { parentPersPreArgs = strings.Join(args, " ") }, @@ -1410,27 +1412,21 @@ func TestPersistentHooks(t *testing.T) { t.Errorf("Unexpected error: %v", err) } - for _, v := range []struct { - name string - got string - }{ - // TODO: currently PersistenPreRun* defined in parent does not - // run if the matchin child subcommand has PersistenPreRun. - // If the behavior changes (https://github.com/spf13/cobra/issues/252) - // this test must be fixed. - {"parentPersPreArgs", parentPersPreArgs}, - {"parentPreArgs", parentPreArgs}, - {"parentRunArgs", parentRunArgs}, - {"parentPostArgs", parentPostArgs}, - // TODO: currently PersistenPostRun* defined in parent does not - // run if the matchin child subcommand has PersistenPostRun. - // If the behavior changes (https://github.com/spf13/cobra/issues/252) - // this test must be fixed. - {"parentPersPostArgs", parentPersPostArgs}, - } { - if v.got != "" { - t.Errorf("Expected blank %s, got %q", v.name, v.got) - } + if parentPersPreArgs != "" { + t.Errorf("Expected blank parentPersPreArgs, got %q", parentPersPreArgs) + } + if parentPreArgs != "" { + t.Errorf("Expected blank parentPreArgs, got %q", parentPreArgs) + } + if parentRunArgs != "" { + t.Errorf("Expected blank parentRunArgs, got %q", parentRunArgs) + } + if parentPostArgs != "" { + t.Errorf("Expected blank parentPostArgs, got %q", parentPostArgs) + } + + if parentPersPostArgs != "" { + t.Errorf("Expected blank parentPersPostArgs, got %q", parentPersPostArgs) } for _, v := range []struct { @@ -1449,6 +1445,256 @@ func TestPersistentHooks(t *testing.T) { } } +func TestPersistentHooks_TraverseChildrenHooks(t *testing.T) { + var ( + parentPersPreArgs string + parentPreArgs string + parentRunArgs string + parentPostArgs string + parentPersPostArgs string + ) + + var ( + childPersPreArgs string + childPreArgs string + childRunArgs string + childPostArgs string + childPersPostArgs string + ) + + parentCmd := &Command{ + Use: "parent", + TraverseChildrenHooks: true, + PersistentPreRun: func(_ *Command, args []string) { + parentPersPreArgs = strings.Join(args, " ") + }, + PreRun: func(_ *Command, args []string) { + parentPreArgs = strings.Join(args, " ") + }, + Run: func(_ *Command, args []string) { + parentRunArgs = strings.Join(args, " ") + }, + PostRun: func(_ *Command, args []string) { + parentPostArgs = strings.Join(args, " ") + }, + PersistentPostRun: func(_ *Command, args []string) { + parentPersPostArgs = strings.Join(args, " ") + }, + } + + childCmd := &Command{ + Use: "child", + PersistentPreRun: func(_ *Command, args []string) { + childPersPreArgs = strings.Join(args, " ") + }, + PreRun: func(_ *Command, args []string) { + childPreArgs = strings.Join(args, " ") + }, + Run: func(_ *Command, args []string) { + childRunArgs = strings.Join(args, " ") + }, + PostRun: func(_ *Command, args []string) { + childPostArgs = strings.Join(args, " ") + }, + PersistentPostRun: func(_ *Command, args []string) { + childPersPostArgs = strings.Join(args, " ") + }, + } + parentCmd.AddCommand(childCmd) + + output, err := executeCommand(parentCmd, "child", "one", "two") + if output != "" { + t.Errorf("Unexpected output: %v", output) + } + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if parentPersPreArgs != onetwo { + t.Errorf("Expected parentPersPreArgs %q, got %q", onetwo, parentPersPreArgs) + } + if parentPreArgs != "" { + t.Errorf("Expected blank parentPreArgs, got %q", parentPreArgs) + } + if parentRunArgs != "" { + t.Errorf("Expected blank parentRunArgs, got %q", parentRunArgs) + } + if parentPostArgs != "" { + t.Errorf("Expected blank parentPostArgs, got %q", parentPostArgs) + } + if parentPersPostArgs != onetwo { + t.Errorf("Expected parentPersPostArgs %q, got %q", onetwo, parentPersPostArgs) + } + + if childPersPreArgs != onetwo { + t.Errorf("Expected childPersPreArgs %q, got %q", onetwo, childPersPreArgs) + } + if childPreArgs != onetwo { + t.Errorf("Expected childPreArgs %q, got %q", onetwo, childPreArgs) + } + if childRunArgs != onetwo { + t.Errorf("Expected childRunArgs %q, got %q", onetwo, childRunArgs) + } + if childPostArgs != onetwo { + t.Errorf("Expected childPostArgs %q, got %q", onetwo, childPostArgs) + } + if childPersPostArgs != onetwo { + t.Errorf("Expected childPersPostArgs %q, got %q", onetwo, childPersPostArgs) + } +} + +func TestPersistentHooks_persistentPostRun_ordering(t *testing.T) { + var uses []string + nopRun := func(*Command, []string) {} + printRun := func(name string) func(*Command, []string) { + return func(cmd *Command, args []string) { + uses = append(uses, name) + } + } + + rootCmd := &Command{ + Use: "root", + TraverseChildrenHooks: true, + Run: nopRun, + PersistentPostRun: printRun("root"), + } + childCmd := &Command{ + Use: "child", + Run: nopRun, + PersistentPostRun: printRun("child"), + } + granchildCmd := &Command{ + Use: "grandchild", + Run: nopRun, + PersistentPostRun: printRun("grandchild"), + } + + childCmd.AddCommand(granchildCmd) + rootCmd.AddCommand(childCmd) + if _, err := executeCommand(rootCmd, "child", "grandchild"); err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(uses, []string{"grandchild", "child", "root"}) { + t.Fatalf("incorrect ordering: %v", uses) + } +} + +func TestPersistentHooks_errs(t *testing.T) { + nopRun := func(*Command, []string) {} + + testCases := []struct { + name string + setup func() *Command + args []string + expectedErr error + }{ + { + name: "PersistentPreRunE", + expectedErr: errors.New("some-error"), + args: []string{"child"}, + setup: func() *Command { + parentCmd := &Command{ + Use: "parent", + TraverseChildrenHooks: true, + PersistentPreRunE: func(_ *Command, args []string) error { + return errors.New("some-error") + }, + Run: nopRun, + } + childCmd := &Command{ + Use: "child", + Run: nopRun, + PersistentPreRunE: func(_ *Command, args []string) error { + t.Fatal("should not be invoked") + return nil + }, + } + parentCmd.AddCommand(childCmd) + + return parentCmd + }, + }, + { + name: "PersistentPostRunE", + expectedErr: errors.New("some-error"), + args: []string{"child"}, + setup: func() *Command { + parentCmd := &Command{ + Use: "parent", + TraverseChildrenHooks: true, + PersistentPostRunE: func(_ *Command, args []string) error { + t.Fatal("should not be invoked") + return nil + }, + Run: nopRun, + } + childCmd := &Command{ + Use: "child", + Run: nopRun, + PersistentPostRunE: func(_ *Command, args []string) error { + return errors.New("some-error") + }, + } + parentCmd.AddCommand(childCmd) + + return parentCmd + }, + }, + { + name: "PreRunE", + expectedErr: errors.New("some-error"), + args: []string{"parent"}, + setup: func() *Command { + return &Command{ + Use: "parent", + PreRunE: func(_ *Command, args []string) error { + return errors.New("some-error") + }, + Run: nopRun, + } + }, + }, + { + name: "RunE", + expectedErr: errors.New("some-error"), + args: []string{"parent"}, + setup: func() *Command { + return &Command{ + Use: "parent", + RunE: func(_ *Command, args []string) error { + return errors.New("some-error") + }, + } + }, + }, + { + name: "PostRunE", + expectedErr: errors.New("some-error"), + args: []string{"parent"}, + setup: func() *Command { + return &Command{ + Use: "parent", + PostRunE: func(_ *Command, args []string) error { + return errors.New("some-error") + }, + Run: nopRun, + } + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + _, err := executeCommand(tc.setup(), tc.args...) + + if actual, expected := fmt.Sprint(err), fmt.Sprint(tc.expectedErr); expected != actual { + t.Fatalf("expected err %v, got %v", expected, actual) + } + }) + } +} + // Related to https://github.com/spf13/cobra/issues/521. func TestGlobalNormFuncPropagation(t *testing.T) { normFunc := func(f *pflag.FlagSet, name string) pflag.NormalizedName {