Skip to content

fix: allow overridding default string array #6873

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Mar 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 63 additions & 98 deletions cli/clibase/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,8 @@ type Invocation struct {

// WithOS returns the invocation as a main package, filling in the invocation's unset
// fields with OS defaults.
func (i *Invocation) WithOS() *Invocation {
return i.with(func(i *Invocation) {
func (inv *Invocation) WithOS() *Invocation {
return inv.with(func(i *Invocation) {
i.Stdout = os.Stdout
i.Stderr = os.Stderr
i.Stdin = os.Stdin
Expand All @@ -182,18 +182,18 @@ func (i *Invocation) WithOS() *Invocation {
})
}

func (i *Invocation) Context() context.Context {
if i.ctx == nil {
func (inv *Invocation) Context() context.Context {
if inv.ctx == nil {
return context.Background()
}
return i.ctx
return inv.ctx
}

func (i *Invocation) ParsedFlags() *pflag.FlagSet {
if i.parsedFlags == nil {
func (inv *Invocation) ParsedFlags() *pflag.FlagSet {
if inv.parsedFlags == nil {
panic("flags not parsed, has Run() been called?")
}
return i.parsedFlags
return inv.parsedFlags
}

type runState struct {
Expand All @@ -218,39 +218,17 @@ func copyFlagSetWithout(fs *pflag.FlagSet, without string) *pflag.FlagSet {
// run recursively executes the command and its children.
// allArgs is wired through the stack so that global flags can be accepted
// anywhere in the command invocation.
func (i *Invocation) run(state *runState) error {
err := i.Command.Options.SetDefaults()
if err != nil {
return xerrors.Errorf("setting defaults: %w", err)
}

// If we set the Default of an array but later see a flag for it, we
// don't want to append, we want to replace. So, we need to keep the state
// of defaulted array options.
defaultedArrays := make(map[string]int)
for _, opt := range i.Command.Options {
sv, ok := opt.Value.(pflag.SliceValue)
if !ok {
continue
}

if opt.Flag == "" {
continue
}

defaultedArrays[opt.Flag] = len(sv.GetSlice())
}

err = i.Command.Options.ParseEnv(i.Environ)
func (inv *Invocation) run(state *runState) error {
err := inv.Command.Options.ParseEnv(inv.Environ)
if err != nil {
return xerrors.Errorf("parsing env: %w", err)
}

// Now the fun part, argument parsing!

children := make(map[string]*Cmd)
for _, child := range i.Command.Children {
child.Parent = i.Command
for _, child := range inv.Command.Children {
child.Parent = inv.Command
for _, name := range append(child.Aliases, child.Name()) {
if _, ok := children[name]; ok {
return xerrors.Errorf("duplicate command name: %s", name)
Expand All @@ -259,57 +237,44 @@ func (i *Invocation) run(state *runState) error {
}
}

if i.parsedFlags == nil {
i.parsedFlags = pflag.NewFlagSet(i.Command.Name(), pflag.ContinueOnError)
if inv.parsedFlags == nil {
inv.parsedFlags = pflag.NewFlagSet(inv.Command.Name(), pflag.ContinueOnError)
// We handle Usage ourselves.
i.parsedFlags.Usage = func() {}
inv.parsedFlags.Usage = func() {}
}

// If we find a duplicate flag, we want the deeper command's flag to override
// the shallow one. Unfortunately, pflag has no way to remove a flag, so we
// have to create a copy of the flagset without a value.
i.Command.Options.FlagSet().VisitAll(func(f *pflag.Flag) {
if i.parsedFlags.Lookup(f.Name) != nil {
i.parsedFlags = copyFlagSetWithout(i.parsedFlags, f.Name)
inv.Command.Options.FlagSet().VisitAll(func(f *pflag.Flag) {
if inv.parsedFlags.Lookup(f.Name) != nil {
inv.parsedFlags = copyFlagSetWithout(inv.parsedFlags, f.Name)
}
i.parsedFlags.AddFlag(f)
inv.parsedFlags.AddFlag(f)
})

var parsedArgs []string

if !i.Command.RawArgs {
if !inv.Command.RawArgs {
// Flag parsing will fail on intermediate commands in the command tree,
// so we check the error after looking for a child command.
state.flagParseErr = i.parsedFlags.Parse(state.allArgs)
parsedArgs = i.parsedFlags.Args()

i.parsedFlags.VisitAll(func(f *pflag.Flag) {
i, ok := defaultedArrays[f.Name]
if !ok {
return
}

if !f.Changed {
return
}
state.flagParseErr = inv.parsedFlags.Parse(state.allArgs)
parsedArgs = inv.parsedFlags.Args()
}

// If flag was changed, we need to remove the default values.
sv, ok := f.Value.(pflag.SliceValue)
if !ok {
panic("defaulted array option is not a slice value")
}
ss := sv.GetSlice()
if len(ss) == 0 {
// Slice likely zeroed by a flag.
// E.g. "--fruit" may default to "apples,oranges" but the user
// provided "--fruit=""".
return
}
err := sv.Replace(ss[i:])
if err != nil {
panic(err)
}
})
// Set defaults for flags that weren't set by the user.
skipDefaults := make(map[int]struct{}, len(inv.Command.Options))
for i, opt := range inv.Command.Options {
if fl := inv.parsedFlags.Lookup(opt.Flag); fl != nil && fl.Changed {
skipDefaults[i] = struct{}{}
}
if opt.envChanged {
skipDefaults[i] = struct{}{}
}
}
err = inv.Command.Options.SetDefaults(skipDefaults)
if err != nil {
return xerrors.Errorf("setting defaults: %w", err)
}

// Run child command if found (next child only)
Expand All @@ -318,64 +283,64 @@ func (i *Invocation) run(state *runState) error {
if len(parsedArgs) > state.commandDepth {
nextArg := parsedArgs[state.commandDepth]
if child, ok := children[nextArg]; ok {
child.Parent = i.Command
i.Command = child
child.Parent = inv.Command
inv.Command = child
state.commandDepth++
return i.run(state)
return inv.run(state)
}
}

// Flag parse errors are irrelevant for raw args commands.
if !i.Command.RawArgs && state.flagParseErr != nil && !errors.Is(state.flagParseErr, pflag.ErrHelp) {
if !inv.Command.RawArgs && state.flagParseErr != nil && !errors.Is(state.flagParseErr, pflag.ErrHelp) {
return xerrors.Errorf(
"parsing flags (%v) for %q: %w",
state.allArgs,
i.Command.FullName(), state.flagParseErr,
inv.Command.FullName(), state.flagParseErr,
)
}

if i.Command.RawArgs {
if inv.Command.RawArgs {
// If we're at the root command, then the name is omitted
// from the arguments, so we can just use the entire slice.
if state.commandDepth == 0 {
i.Args = state.allArgs
inv.Args = state.allArgs
} else {
argPos, err := findArg(i.Command.Name(), state.allArgs, i.parsedFlags)
argPos, err := findArg(inv.Command.Name(), state.allArgs, inv.parsedFlags)
if err != nil {
panic(err)
}
i.Args = state.allArgs[argPos+1:]
inv.Args = state.allArgs[argPos+1:]
}
} else {
// In non-raw-arg mode, we want to skip over flags.
i.Args = parsedArgs[state.commandDepth:]
inv.Args = parsedArgs[state.commandDepth:]
}

mw := i.Command.Middleware
mw := inv.Command.Middleware
if mw == nil {
mw = Chain()
}

ctx := i.ctx
ctx := inv.ctx
if ctx == nil {
ctx = context.Background()
}

ctx, cancel := context.WithCancel(ctx)
defer cancel()
i = i.WithContext(ctx)
inv = inv.WithContext(ctx)

if i.Command.Handler == nil || errors.Is(state.flagParseErr, pflag.ErrHelp) {
if i.Command.HelpHandler == nil {
return xerrors.Errorf("no handler or help for command %s", i.Command.FullName())
if inv.Command.Handler == nil || errors.Is(state.flagParseErr, pflag.ErrHelp) {
if inv.Command.HelpHandler == nil {
return xerrors.Errorf("no handler or help for command %s", inv.Command.FullName())
}
return i.Command.HelpHandler(i)
return inv.Command.HelpHandler(inv)
}

err = mw(i.Command.Handler)(i)
err = mw(inv.Command.Handler)(inv)
if err != nil {
return &RunCommandError{
Cmd: i.Command,
Cmd: inv.Command,
Err: err,
}
}
Expand Down Expand Up @@ -438,33 +403,33 @@ func findArg(want string, args []string, fs *pflag.FlagSet) (int, error) {
// If two command share a flag name, the first command wins.
//
//nolint:revive
func (i *Invocation) Run() (err error) {
func (inv *Invocation) Run() (err error) {
defer func() {
// Pflag is panicky, so additional context is helpful in tests.
if flag.Lookup("test.v") == nil {
return
}
if r := recover(); r != nil {
err = xerrors.Errorf("panic recovered for %s: %v", i.Command.FullName(), r)
err = xerrors.Errorf("panic recovered for %s: %v", inv.Command.FullName(), r)
panic(err)
}
}()
err = i.run(&runState{
allArgs: i.Args,
err = inv.run(&runState{
allArgs: inv.Args,
})
return err
}

// WithContext returns a copy of the Invocation with the given context.
func (i *Invocation) WithContext(ctx context.Context) *Invocation {
return i.with(func(i *Invocation) {
func (inv *Invocation) WithContext(ctx context.Context) *Invocation {
return inv.with(func(i *Invocation) {
i.ctx = ctx
})
}

// with returns a copy of the Invocation with the given function applied.
func (i *Invocation) with(fn func(*Invocation)) *Invocation {
i2 := *i
func (inv *Invocation) with(fn func(*Invocation)) *Invocation {
i2 := *inv
fn(&i2)
return &i2
}
Expand Down
10 changes: 9 additions & 1 deletion cli/clibase/cmd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ func TestCommand_FlagOverride(t *testing.T) {
Use: "1",
Options: clibase.OptionSet{
{
Name: "flag",
Flag: "f",
Value: clibase.DiscardValue,
},
Expand All @@ -256,6 +257,7 @@ func TestCommand_FlagOverride(t *testing.T) {
Use: "2",
Options: clibase.OptionSet{
{
Name: "flag",
Flag: "f",
Value: clibase.StringOf(&flag),
},
Expand Down Expand Up @@ -527,11 +529,17 @@ func TestCommand_EmptySlice(t *testing.T) {
}
}

// Base-case
// Base-case, uses default.
err := cmd("bad", "bad", "bad").Invoke().Run()
require.NoError(t, err)

// Reset to nothing at all.
inv := cmd().Invoke("--arr", "")
err = inv.Run()
require.NoError(t, err)

// Override
inv = cmd("great").Invoke("--arr", "great")
err = inv.Run()
require.NoError(t, err)
}
19 changes: 15 additions & 4 deletions cli/clibase/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ type Option struct {
UseInstead []Option `json:"use_instead,omitempty"`

Hidden bool `json:"hidden,omitempty"`

envChanged bool
}

// OptionSet is a group of options that can be applied to a command.
Expand Down Expand Up @@ -133,6 +135,7 @@ func (s *OptionSet) ParseEnv(vs []EnvVar) error {
continue
}

opt.envChanged = true
if err := opt.Value.Set(envVal); err != nil {
merr = multierror.Append(
merr, xerrors.Errorf("parse %q: %w", opt.Name, err),
Expand All @@ -143,19 +146,27 @@ func (s *OptionSet) ParseEnv(vs []EnvVar) error {
return merr.ErrorOrNil()
}

// SetDefaults sets the default values for each Option.
// It should be called before all parsing (e.g. ParseFlags, ParseEnv).
func (s *OptionSet) SetDefaults() error {
// SetDefaults sets the default values for each Option, skipping values
// that have already been set as indicated by the skip map.
func (s *OptionSet) SetDefaults(skip map[int]struct{}) error {
if s == nil {
return nil
}

var merr *multierror.Error

for _, opt := range *s {
for i, opt := range *s {
// Skip values that may have already been set by the user.
if len(skip) > 0 {
if _, ok := skip[i]; ok {
continue
}
}

if opt.Default == "" {
continue
}

if opt.Value == nil {
merr = multierror.Append(
merr,
Expand Down
4 changes: 2 additions & 2 deletions cli/clibase/option_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func TestOptionSet_ParseFlags(t *testing.T) {
},
}

err := os.SetDefaults()
err := os.SetDefaults(nil)
require.NoError(t, err)

err = os.FlagSet().Parse([]string{"--name", "foo", "--name", "bar"})
Expand Down Expand Up @@ -111,7 +111,7 @@ func TestOptionSet_ParseEnv(t *testing.T) {
},
}

err := os.SetDefaults()
err := os.SetDefaults(nil)
require.NoError(t, err)

err = os.ParseEnv(clibase.ParseEnviron([]string{"CODER_WORKSPACE_NAME="}, "CODER_"))
Expand Down
2 changes: 1 addition & 1 deletion cli/clibase/yaml_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func TestOption_ToYAML(t *testing.T) {
},
}

err := os.SetDefaults()
err := os.SetDefaults(nil)
require.NoError(t, err)

n, err := os.ToYAML()
Expand Down
Loading