Skip to content

Commit 599699b

Browse files
authored
fix: truly allow overridding default string array (coder#6874)
1 parent 96ff400 commit 599699b

File tree

7 files changed

+148
-112
lines changed

7 files changed

+148
-112
lines changed

cli/clibase/cmd.go

+63-98
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,8 @@ type Invocation struct {
172172

173173
// WithOS returns the invocation as a main package, filling in the invocation's unset
174174
// fields with OS defaults.
175-
func (i *Invocation) WithOS() *Invocation {
176-
return i.with(func(i *Invocation) {
175+
func (inv *Invocation) WithOS() *Invocation {
176+
return inv.with(func(i *Invocation) {
177177
i.Stdout = os.Stdout
178178
i.Stderr = os.Stderr
179179
i.Stdin = os.Stdin
@@ -182,18 +182,18 @@ func (i *Invocation) WithOS() *Invocation {
182182
})
183183
}
184184

185-
func (i *Invocation) Context() context.Context {
186-
if i.ctx == nil {
185+
func (inv *Invocation) Context() context.Context {
186+
if inv.ctx == nil {
187187
return context.Background()
188188
}
189-
return i.ctx
189+
return inv.ctx
190190
}
191191

192-
func (i *Invocation) ParsedFlags() *pflag.FlagSet {
193-
if i.parsedFlags == nil {
192+
func (inv *Invocation) ParsedFlags() *pflag.FlagSet {
193+
if inv.parsedFlags == nil {
194194
panic("flags not parsed, has Run() been called?")
195195
}
196-
return i.parsedFlags
196+
return inv.parsedFlags
197197
}
198198

199199
type runState struct {
@@ -218,39 +218,17 @@ func copyFlagSetWithout(fs *pflag.FlagSet, without string) *pflag.FlagSet {
218218
// run recursively executes the command and its children.
219219
// allArgs is wired through the stack so that global flags can be accepted
220220
// anywhere in the command invocation.
221-
func (i *Invocation) run(state *runState) error {
222-
err := i.Command.Options.SetDefaults()
223-
if err != nil {
224-
return xerrors.Errorf("setting defaults: %w", err)
225-
}
226-
227-
// If we set the Default of an array but later see a flag for it, we
228-
// don't want to append, we want to replace. So, we need to keep the state
229-
// of defaulted array options.
230-
defaultedArrays := make(map[string]int)
231-
for _, opt := range i.Command.Options {
232-
sv, ok := opt.Value.(pflag.SliceValue)
233-
if !ok {
234-
continue
235-
}
236-
237-
if opt.Flag == "" {
238-
continue
239-
}
240-
241-
defaultedArrays[opt.Flag] = len(sv.GetSlice())
242-
}
243-
244-
err = i.Command.Options.ParseEnv(i.Environ)
221+
func (inv *Invocation) run(state *runState) error {
222+
err := inv.Command.Options.ParseEnv(inv.Environ)
245223
if err != nil {
246224
return xerrors.Errorf("parsing env: %w", err)
247225
}
248226

249227
// Now the fun part, argument parsing!
250228

251229
children := make(map[string]*Cmd)
252-
for _, child := range i.Command.Children {
253-
child.Parent = i.Command
230+
for _, child := range inv.Command.Children {
231+
child.Parent = inv.Command
254232
for _, name := range append(child.Aliases, child.Name()) {
255233
if _, ok := children[name]; ok {
256234
return xerrors.Errorf("duplicate command name: %s", name)
@@ -259,57 +237,44 @@ func (i *Invocation) run(state *runState) error {
259237
}
260238
}
261239

262-
if i.parsedFlags == nil {
263-
i.parsedFlags = pflag.NewFlagSet(i.Command.Name(), pflag.ContinueOnError)
240+
if inv.parsedFlags == nil {
241+
inv.parsedFlags = pflag.NewFlagSet(inv.Command.Name(), pflag.ContinueOnError)
264242
// We handle Usage ourselves.
265-
i.parsedFlags.Usage = func() {}
243+
inv.parsedFlags.Usage = func() {}
266244
}
267245

268246
// If we find a duplicate flag, we want the deeper command's flag to override
269247
// the shallow one. Unfortunately, pflag has no way to remove a flag, so we
270248
// have to create a copy of the flagset without a value.
271-
i.Command.Options.FlagSet().VisitAll(func(f *pflag.Flag) {
272-
if i.parsedFlags.Lookup(f.Name) != nil {
273-
i.parsedFlags = copyFlagSetWithout(i.parsedFlags, f.Name)
249+
inv.Command.Options.FlagSet().VisitAll(func(f *pflag.Flag) {
250+
if inv.parsedFlags.Lookup(f.Name) != nil {
251+
inv.parsedFlags = copyFlagSetWithout(inv.parsedFlags, f.Name)
274252
}
275-
i.parsedFlags.AddFlag(f)
253+
inv.parsedFlags.AddFlag(f)
276254
})
277255

278256
var parsedArgs []string
279257

280-
if !i.Command.RawArgs {
258+
if !inv.Command.RawArgs {
281259
// Flag parsing will fail on intermediate commands in the command tree,
282260
// so we check the error after looking for a child command.
283-
state.flagParseErr = i.parsedFlags.Parse(state.allArgs)
284-
parsedArgs = i.parsedFlags.Args()
285-
286-
i.parsedFlags.VisitAll(func(f *pflag.Flag) {
287-
i, ok := defaultedArrays[f.Name]
288-
if !ok {
289-
return
290-
}
291-
292-
if !f.Changed {
293-
return
294-
}
261+
state.flagParseErr = inv.parsedFlags.Parse(state.allArgs)
262+
parsedArgs = inv.parsedFlags.Args()
263+
}
295264

296-
// If flag was changed, we need to remove the default values.
297-
sv, ok := f.Value.(pflag.SliceValue)
298-
if !ok {
299-
panic("defaulted array option is not a slice value")
300-
}
301-
ss := sv.GetSlice()
302-
if len(ss) == 0 {
303-
// Slice likely zeroed by a flag.
304-
// E.g. "--fruit" may default to "apples,oranges" but the user
305-
// provided "--fruit=""".
306-
return
307-
}
308-
err := sv.Replace(ss[i:])
309-
if err != nil {
310-
panic(err)
311-
}
312-
})
265+
// Set defaults for flags that weren't set by the user.
266+
skipDefaults := make(map[int]struct{}, len(inv.Command.Options))
267+
for i, opt := range inv.Command.Options {
268+
if fl := inv.parsedFlags.Lookup(opt.Flag); fl != nil && fl.Changed {
269+
skipDefaults[i] = struct{}{}
270+
}
271+
if opt.envChanged {
272+
skipDefaults[i] = struct{}{}
273+
}
274+
}
275+
err = inv.Command.Options.SetDefaults(skipDefaults)
276+
if err != nil {
277+
return xerrors.Errorf("setting defaults: %w", err)
313278
}
314279

315280
// Run child command if found (next child only)
@@ -318,64 +283,64 @@ func (i *Invocation) run(state *runState) error {
318283
if len(parsedArgs) > state.commandDepth {
319284
nextArg := parsedArgs[state.commandDepth]
320285
if child, ok := children[nextArg]; ok {
321-
child.Parent = i.Command
322-
i.Command = child
286+
child.Parent = inv.Command
287+
inv.Command = child
323288
state.commandDepth++
324-
return i.run(state)
289+
return inv.run(state)
325290
}
326291
}
327292

328293
// Flag parse errors are irrelevant for raw args commands.
329-
if !i.Command.RawArgs && state.flagParseErr != nil && !errors.Is(state.flagParseErr, pflag.ErrHelp) {
294+
if !inv.Command.RawArgs && state.flagParseErr != nil && !errors.Is(state.flagParseErr, pflag.ErrHelp) {
330295
return xerrors.Errorf(
331296
"parsing flags (%v) for %q: %w",
332297
state.allArgs,
333-
i.Command.FullName(), state.flagParseErr,
298+
inv.Command.FullName(), state.flagParseErr,
334299
)
335300
}
336301

337-
if i.Command.RawArgs {
302+
if inv.Command.RawArgs {
338303
// If we're at the root command, then the name is omitted
339304
// from the arguments, so we can just use the entire slice.
340305
if state.commandDepth == 0 {
341-
i.Args = state.allArgs
306+
inv.Args = state.allArgs
342307
} else {
343-
argPos, err := findArg(i.Command.Name(), state.allArgs, i.parsedFlags)
308+
argPos, err := findArg(inv.Command.Name(), state.allArgs, inv.parsedFlags)
344309
if err != nil {
345310
panic(err)
346311
}
347-
i.Args = state.allArgs[argPos+1:]
312+
inv.Args = state.allArgs[argPos+1:]
348313
}
349314
} else {
350315
// In non-raw-arg mode, we want to skip over flags.
351-
i.Args = parsedArgs[state.commandDepth:]
316+
inv.Args = parsedArgs[state.commandDepth:]
352317
}
353318

354-
mw := i.Command.Middleware
319+
mw := inv.Command.Middleware
355320
if mw == nil {
356321
mw = Chain()
357322
}
358323

359-
ctx := i.ctx
324+
ctx := inv.ctx
360325
if ctx == nil {
361326
ctx = context.Background()
362327
}
363328

364329
ctx, cancel := context.WithCancel(ctx)
365330
defer cancel()
366-
i = i.WithContext(ctx)
331+
inv = inv.WithContext(ctx)
367332

368-
if i.Command.Handler == nil || errors.Is(state.flagParseErr, pflag.ErrHelp) {
369-
if i.Command.HelpHandler == nil {
370-
return xerrors.Errorf("no handler or help for command %s", i.Command.FullName())
333+
if inv.Command.Handler == nil || errors.Is(state.flagParseErr, pflag.ErrHelp) {
334+
if inv.Command.HelpHandler == nil {
335+
return xerrors.Errorf("no handler or help for command %s", inv.Command.FullName())
371336
}
372-
return i.Command.HelpHandler(i)
337+
return inv.Command.HelpHandler(inv)
373338
}
374339

375-
err = mw(i.Command.Handler)(i)
340+
err = mw(inv.Command.Handler)(inv)
376341
if err != nil {
377342
return &RunCommandError{
378-
Cmd: i.Command,
343+
Cmd: inv.Command,
379344
Err: err,
380345
}
381346
}
@@ -438,33 +403,33 @@ func findArg(want string, args []string, fs *pflag.FlagSet) (int, error) {
438403
// If two command share a flag name, the first command wins.
439404
//
440405
//nolint:revive
441-
func (i *Invocation) Run() (err error) {
406+
func (inv *Invocation) Run() (err error) {
442407
defer func() {
443408
// Pflag is panicky, so additional context is helpful in tests.
444409
if flag.Lookup("test.v") == nil {
445410
return
446411
}
447412
if r := recover(); r != nil {
448-
err = xerrors.Errorf("panic recovered for %s: %v", i.Command.FullName(), r)
413+
err = xerrors.Errorf("panic recovered for %s: %v", inv.Command.FullName(), r)
449414
panic(err)
450415
}
451416
}()
452-
err = i.run(&runState{
453-
allArgs: i.Args,
417+
err = inv.run(&runState{
418+
allArgs: inv.Args,
454419
})
455420
return err
456421
}
457422

458423
// WithContext returns a copy of the Invocation with the given context.
459-
func (i *Invocation) WithContext(ctx context.Context) *Invocation {
460-
return i.with(func(i *Invocation) {
424+
func (inv *Invocation) WithContext(ctx context.Context) *Invocation {
425+
return inv.with(func(i *Invocation) {
461426
i.ctx = ctx
462427
})
463428
}
464429

465430
// with returns a copy of the Invocation with the given function applied.
466-
func (i *Invocation) with(fn func(*Invocation)) *Invocation {
467-
i2 := *i
431+
func (inv *Invocation) with(fn func(*Invocation)) *Invocation {
432+
i2 := *inv
468433
fn(&i2)
469434
return &i2
470435
}

cli/clibase/cmd_test.go

+63-4
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package clibase_test
33
import (
44
"bytes"
55
"context"
6+
"fmt"
67
"strings"
78
"testing"
89

@@ -247,6 +248,7 @@ func TestCommand_FlagOverride(t *testing.T) {
247248
Use: "1",
248249
Options: clibase.OptionSet{
249250
{
251+
Name: "flag",
250252
Flag: "f",
251253
Value: clibase.DiscardValue,
252254
},
@@ -256,6 +258,7 @@ func TestCommand_FlagOverride(t *testing.T) {
256258
Use: "2",
257259
Options: clibase.OptionSet{
258260
{
261+
Name: "flag",
259262
Flag: "f",
260263
Value: clibase.StringOf(&flag),
261264
},
@@ -515,7 +518,7 @@ func TestCommand_EmptySlice(t *testing.T) {
515518
{
516519
Name: "arr",
517520
Flag: "arr",
518-
Default: "bad,bad,bad",
521+
Default: "def,def,def",
519522
Env: "ARR",
520523
Value: clibase.StringArrayOf(&got),
521524
},
@@ -527,11 +530,67 @@ func TestCommand_EmptySlice(t *testing.T) {
527530
}
528531
}
529532

530-
// Base-case
531-
err := cmd("bad", "bad", "bad").Invoke().Run()
533+
// Base-case, uses default.
534+
err := cmd("def", "def", "def").Invoke().Run()
535+
require.NoError(t, err)
536+
537+
// Empty-env uses default, too.
538+
inv := cmd("def", "def", "def").Invoke()
539+
inv.Environ.Set("ARR", "")
540+
require.NoError(t, err)
541+
542+
// Reset to nothing at all via flag.
543+
inv = cmd().Invoke("--arr", "")
544+
inv.Environ.Set("ARR", "cant see")
545+
err = inv.Run()
546+
require.NoError(t, err)
547+
548+
// Reset to a specific value with flag.
549+
inv = cmd("great").Invoke("--arr", "great")
550+
inv.Environ.Set("ARR", "")
551+
err = inv.Run()
552+
require.NoError(t, err)
553+
}
554+
555+
func TestCommand_DefaultsOverride(t *testing.T) {
556+
t.Parallel()
557+
558+
var got string
559+
cmd := &clibase.Cmd{
560+
Options: clibase.OptionSet{
561+
{
562+
Name: "url",
563+
Flag: "url",
564+
Default: "def.com",
565+
Env: "URL",
566+
Value: clibase.StringOf(&got),
567+
},
568+
},
569+
Handler: (func(i *clibase.Invocation) error {
570+
_, _ = fmt.Fprintf(i.Stdout, "%s", got)
571+
return nil
572+
}),
573+
}
574+
575+
// Base case
576+
inv := cmd.Invoke()
577+
stdio := fakeIO(inv)
578+
err := inv.Run()
579+
require.NoError(t, err)
580+
require.Equal(t, "def.com", stdio.Stdout.String())
581+
582+
// Flag overrides
583+
inv = cmd.Invoke("--url", "good.com")
584+
stdio = fakeIO(inv)
585+
err = inv.Run()
532586
require.NoError(t, err)
587+
require.Equal(t, "good.com", stdio.Stdout.String())
533588

534-
inv := cmd().Invoke("--arr", "")
589+
// Env overrides
590+
inv = cmd.Invoke()
591+
inv.Environ.Set("URL", "good.com")
592+
stdio = fakeIO(inv)
535593
err = inv.Run()
536594
require.NoError(t, err)
595+
require.Equal(t, "good.com", stdio.Stdout.String())
537596
}

0 commit comments

Comments
 (0)