diff --git a/cli/clibase/cmd.go b/cli/clibase/cmd.go index 49a3cae718a18..2b9da500225ce 100644 --- a/cli/clibase/cmd.go +++ b/cli/clibase/cmd.go @@ -172,8 +172,8 @@ type Invocation struct { // WithOS returns the invocation as a main package, filling in the invocation's unset // fields with OS defaults. -func (i *Invocation) WithOS() *Invocation { - return i.with(func(i *Invocation) { +func (inv *Invocation) WithOS() *Invocation { + return inv.with(func(i *Invocation) { i.Stdout = os.Stdout i.Stderr = os.Stderr i.Stdin = os.Stdin @@ -182,18 +182,18 @@ func (i *Invocation) WithOS() *Invocation { }) } -func (i *Invocation) Context() context.Context { - if i.ctx == nil { +func (inv *Invocation) Context() context.Context { + if inv.ctx == nil { return context.Background() } - return i.ctx + return inv.ctx } -func (i *Invocation) ParsedFlags() *pflag.FlagSet { - if i.parsedFlags == nil { +func (inv *Invocation) ParsedFlags() *pflag.FlagSet { + if inv.parsedFlags == nil { panic("flags not parsed, has Run() been called?") } - return i.parsedFlags + return inv.parsedFlags } type runState struct { @@ -218,30 +218,8 @@ func copyFlagSetWithout(fs *pflag.FlagSet, without string) *pflag.FlagSet { // run recursively executes the command and its children. // allArgs is wired through the stack so that global flags can be accepted // anywhere in the command invocation. -func (i *Invocation) run(state *runState) error { - err := i.Command.Options.SetDefaults() - if err != nil { - return xerrors.Errorf("setting defaults: %w", err) - } - - // If we set the Default of an array but later see a flag for it, we - // don't want to append, we want to replace. So, we need to keep the state - // of defaulted array options. - defaultedArrays := make(map[string]int) - for _, opt := range i.Command.Options { - sv, ok := opt.Value.(pflag.SliceValue) - if !ok { - continue - } - - if opt.Flag == "" { - continue - } - - defaultedArrays[opt.Flag] = len(sv.GetSlice()) - } - - err = i.Command.Options.ParseEnv(i.Environ) +func (inv *Invocation) run(state *runState) error { + err := inv.Command.Options.ParseEnv(inv.Environ) if err != nil { return xerrors.Errorf("parsing env: %w", err) } @@ -249,8 +227,8 @@ func (i *Invocation) run(state *runState) error { // Now the fun part, argument parsing! children := make(map[string]*Cmd) - for _, child := range i.Command.Children { - child.Parent = i.Command + for _, child := range inv.Command.Children { + child.Parent = inv.Command for _, name := range append(child.Aliases, child.Name()) { if _, ok := children[name]; ok { return xerrors.Errorf("duplicate command name: %s", name) @@ -259,57 +237,44 @@ func (i *Invocation) run(state *runState) error { } } - if i.parsedFlags == nil { - i.parsedFlags = pflag.NewFlagSet(i.Command.Name(), pflag.ContinueOnError) + if inv.parsedFlags == nil { + inv.parsedFlags = pflag.NewFlagSet(inv.Command.Name(), pflag.ContinueOnError) // We handle Usage ourselves. - i.parsedFlags.Usage = func() {} + inv.parsedFlags.Usage = func() {} } // If we find a duplicate flag, we want the deeper command's flag to override // the shallow one. Unfortunately, pflag has no way to remove a flag, so we // have to create a copy of the flagset without a value. - i.Command.Options.FlagSet().VisitAll(func(f *pflag.Flag) { - if i.parsedFlags.Lookup(f.Name) != nil { - i.parsedFlags = copyFlagSetWithout(i.parsedFlags, f.Name) + inv.Command.Options.FlagSet().VisitAll(func(f *pflag.Flag) { + if inv.parsedFlags.Lookup(f.Name) != nil { + inv.parsedFlags = copyFlagSetWithout(inv.parsedFlags, f.Name) } - i.parsedFlags.AddFlag(f) + inv.parsedFlags.AddFlag(f) }) var parsedArgs []string - if !i.Command.RawArgs { + if !inv.Command.RawArgs { // Flag parsing will fail on intermediate commands in the command tree, // so we check the error after looking for a child command. - state.flagParseErr = i.parsedFlags.Parse(state.allArgs) - parsedArgs = i.parsedFlags.Args() - - i.parsedFlags.VisitAll(func(f *pflag.Flag) { - i, ok := defaultedArrays[f.Name] - if !ok { - return - } - - if !f.Changed { - return - } + state.flagParseErr = inv.parsedFlags.Parse(state.allArgs) + parsedArgs = inv.parsedFlags.Args() + } - // If flag was changed, we need to remove the default values. - sv, ok := f.Value.(pflag.SliceValue) - if !ok { - panic("defaulted array option is not a slice value") - } - ss := sv.GetSlice() - if len(ss) == 0 { - // Slice likely zeroed by a flag. - // E.g. "--fruit" may default to "apples,oranges" but the user - // provided "--fruit=""". - return - } - err := sv.Replace(ss[i:]) - if err != nil { - panic(err) - } - }) + // Set defaults for flags that weren't set by the user. + skipDefaults := make(map[int]struct{}, len(inv.Command.Options)) + for i, opt := range inv.Command.Options { + if fl := inv.parsedFlags.Lookup(opt.Flag); fl != nil && fl.Changed { + skipDefaults[i] = struct{}{} + } + if opt.envChanged { + skipDefaults[i] = struct{}{} + } + } + err = inv.Command.Options.SetDefaults(skipDefaults) + if err != nil { + return xerrors.Errorf("setting defaults: %w", err) } // Run child command if found (next child only) @@ -318,64 +283,64 @@ func (i *Invocation) run(state *runState) error { if len(parsedArgs) > state.commandDepth { nextArg := parsedArgs[state.commandDepth] if child, ok := children[nextArg]; ok { - child.Parent = i.Command - i.Command = child + child.Parent = inv.Command + inv.Command = child state.commandDepth++ - return i.run(state) + return inv.run(state) } } // Flag parse errors are irrelevant for raw args commands. - if !i.Command.RawArgs && state.flagParseErr != nil && !errors.Is(state.flagParseErr, pflag.ErrHelp) { + if !inv.Command.RawArgs && state.flagParseErr != nil && !errors.Is(state.flagParseErr, pflag.ErrHelp) { return xerrors.Errorf( "parsing flags (%v) for %q: %w", state.allArgs, - i.Command.FullName(), state.flagParseErr, + inv.Command.FullName(), state.flagParseErr, ) } - if i.Command.RawArgs { + if inv.Command.RawArgs { // If we're at the root command, then the name is omitted // from the arguments, so we can just use the entire slice. if state.commandDepth == 0 { - i.Args = state.allArgs + inv.Args = state.allArgs } else { - argPos, err := findArg(i.Command.Name(), state.allArgs, i.parsedFlags) + argPos, err := findArg(inv.Command.Name(), state.allArgs, inv.parsedFlags) if err != nil { panic(err) } - i.Args = state.allArgs[argPos+1:] + inv.Args = state.allArgs[argPos+1:] } } else { // In non-raw-arg mode, we want to skip over flags. - i.Args = parsedArgs[state.commandDepth:] + inv.Args = parsedArgs[state.commandDepth:] } - mw := i.Command.Middleware + mw := inv.Command.Middleware if mw == nil { mw = Chain() } - ctx := i.ctx + ctx := inv.ctx if ctx == nil { ctx = context.Background() } ctx, cancel := context.WithCancel(ctx) defer cancel() - i = i.WithContext(ctx) + inv = inv.WithContext(ctx) - if i.Command.Handler == nil || errors.Is(state.flagParseErr, pflag.ErrHelp) { - if i.Command.HelpHandler == nil { - return xerrors.Errorf("no handler or help for command %s", i.Command.FullName()) + if inv.Command.Handler == nil || errors.Is(state.flagParseErr, pflag.ErrHelp) { + if inv.Command.HelpHandler == nil { + return xerrors.Errorf("no handler or help for command %s", inv.Command.FullName()) } - return i.Command.HelpHandler(i) + return inv.Command.HelpHandler(inv) } - err = mw(i.Command.Handler)(i) + err = mw(inv.Command.Handler)(inv) if err != nil { return &RunCommandError{ - Cmd: i.Command, + Cmd: inv.Command, Err: err, } } @@ -438,33 +403,33 @@ func findArg(want string, args []string, fs *pflag.FlagSet) (int, error) { // If two command share a flag name, the first command wins. // //nolint:revive -func (i *Invocation) Run() (err error) { +func (inv *Invocation) Run() (err error) { defer func() { // Pflag is panicky, so additional context is helpful in tests. if flag.Lookup("test.v") == nil { return } if r := recover(); r != nil { - err = xerrors.Errorf("panic recovered for %s: %v", i.Command.FullName(), r) + err = xerrors.Errorf("panic recovered for %s: %v", inv.Command.FullName(), r) panic(err) } }() - err = i.run(&runState{ - allArgs: i.Args, + err = inv.run(&runState{ + allArgs: inv.Args, }) return err } // WithContext returns a copy of the Invocation with the given context. -func (i *Invocation) WithContext(ctx context.Context) *Invocation { - return i.with(func(i *Invocation) { +func (inv *Invocation) WithContext(ctx context.Context) *Invocation { + return inv.with(func(i *Invocation) { i.ctx = ctx }) } // with returns a copy of the Invocation with the given function applied. -func (i *Invocation) with(fn func(*Invocation)) *Invocation { - i2 := *i +func (inv *Invocation) with(fn func(*Invocation)) *Invocation { + i2 := *inv fn(&i2) return &i2 } diff --git a/cli/clibase/cmd_test.go b/cli/clibase/cmd_test.go index cf835327cf822..f5ed6f676396c 100644 --- a/cli/clibase/cmd_test.go +++ b/cli/clibase/cmd_test.go @@ -247,6 +247,7 @@ func TestCommand_FlagOverride(t *testing.T) { Use: "1", Options: clibase.OptionSet{ { + Name: "flag", Flag: "f", Value: clibase.DiscardValue, }, @@ -256,6 +257,7 @@ func TestCommand_FlagOverride(t *testing.T) { Use: "2", Options: clibase.OptionSet{ { + Name: "flag", Flag: "f", Value: clibase.StringOf(&flag), }, @@ -527,11 +529,17 @@ func TestCommand_EmptySlice(t *testing.T) { } } - // Base-case + // Base-case, uses default. err := cmd("bad", "bad", "bad").Invoke().Run() require.NoError(t, err) + // Reset to nothing at all. inv := cmd().Invoke("--arr", "") err = inv.Run() require.NoError(t, err) + + // Override + inv = cmd("great").Invoke("--arr", "great") + err = inv.Run() + require.NoError(t, err) } diff --git a/cli/clibase/option.go b/cli/clibase/option.go index 05b444c24803b..7b294f4884281 100644 --- a/cli/clibase/option.go +++ b/cli/clibase/option.go @@ -46,6 +46,8 @@ type Option struct { UseInstead []Option `json:"use_instead,omitempty"` Hidden bool `json:"hidden,omitempty"` + + envChanged bool } // OptionSet is a group of options that can be applied to a command. @@ -133,6 +135,7 @@ func (s *OptionSet) ParseEnv(vs []EnvVar) error { continue } + opt.envChanged = true if err := opt.Value.Set(envVal); err != nil { merr = multierror.Append( merr, xerrors.Errorf("parse %q: %w", opt.Name, err), @@ -143,19 +146,27 @@ func (s *OptionSet) ParseEnv(vs []EnvVar) error { return merr.ErrorOrNil() } -// SetDefaults sets the default values for each Option. -// It should be called before all parsing (e.g. ParseFlags, ParseEnv). -func (s *OptionSet) SetDefaults() error { +// SetDefaults sets the default values for each Option, skipping values +// that have already been set as indicated by the skip map. +func (s *OptionSet) SetDefaults(skip map[int]struct{}) error { if s == nil { return nil } var merr *multierror.Error - for _, opt := range *s { + for i, opt := range *s { + // Skip values that may have already been set by the user. + if len(skip) > 0 { + if _, ok := skip[i]; ok { + continue + } + } + if opt.Default == "" { continue } + if opt.Value == nil { merr = multierror.Append( merr, diff --git a/cli/clibase/option_test.go b/cli/clibase/option_test.go index d9d38cc6c7bd9..862e8098db573 100644 --- a/cli/clibase/option_test.go +++ b/cli/clibase/option_test.go @@ -49,7 +49,7 @@ func TestOptionSet_ParseFlags(t *testing.T) { }, } - err := os.SetDefaults() + err := os.SetDefaults(nil) require.NoError(t, err) err = os.FlagSet().Parse([]string{"--name", "foo", "--name", "bar"}) @@ -111,7 +111,7 @@ func TestOptionSet_ParseEnv(t *testing.T) { }, } - err := os.SetDefaults() + err := os.SetDefaults(nil) require.NoError(t, err) err = os.ParseEnv(clibase.ParseEnviron([]string{"CODER_WORKSPACE_NAME="}, "CODER_")) diff --git a/cli/clibase/yaml_test.go b/cli/clibase/yaml_test.go index 3efad6ee54ed8..62582a5252396 100644 --- a/cli/clibase/yaml_test.go +++ b/cli/clibase/yaml_test.go @@ -44,7 +44,7 @@ func TestOption_ToYAML(t *testing.T) { }, } - err := os.SetDefaults() + err := os.SetDefaults(nil) require.NoError(t, err) n, err := os.ToYAML() diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index ff9bf82addc98..aaf9779d2ef23 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -1075,7 +1075,7 @@ QastnN77KfUwdj3SJt44U/uh1jAIv4oSLBr8HYUkbnI8 func DeploymentValues(t *testing.T) *codersdk.DeploymentValues { var cfg codersdk.DeploymentValues opts := cfg.Options() - err := opts.SetDefaults() + err := opts.SetDefaults(nil) require.NoError(t, err) return &cfg }