Skip to content

Commit 58d650c

Browse files
authored
fix: allow overridding default string array (coder#6873)
* fix: allow overridding default string array * Cleanup code * fixup! Cleanup code * fixup! Cleanup code * fixup! Cleanup code * fixup! Cleanup code
1 parent 1c7adc0 commit 58d650c

File tree

6 files changed

+91
-107
lines changed

6 files changed

+91
-107
lines changed

cli/clibase/cmd.go

+63-98
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,8 @@ type Invocation struct {
172172

173173
// WithOS returns the invocation as a main package, filling in the invocation's unset
174174
// fields with OS defaults.
175-
func (i *Invocation) WithOS() *Invocation {
176-
return i.with(func(i *Invocation) {
175+
func (inv *Invocation) WithOS() *Invocation {
176+
return inv.with(func(i *Invocation) {
177177
i.Stdout = os.Stdout
178178
i.Stderr = os.Stderr
179179
i.Stdin = os.Stdin
@@ -182,18 +182,18 @@ func (i *Invocation) WithOS() *Invocation {
182182
})
183183
}
184184

185-
func (i *Invocation) Context() context.Context {
186-
if i.ctx == nil {
185+
func (inv *Invocation) Context() context.Context {
186+
if inv.ctx == nil {
187187
return context.Background()
188188
}
189-
return i.ctx
189+
return inv.ctx
190190
}
191191

192-
func (i *Invocation) ParsedFlags() *pflag.FlagSet {
193-
if i.parsedFlags == nil {
192+
func (inv *Invocation) ParsedFlags() *pflag.FlagSet {
193+
if inv.parsedFlags == nil {
194194
panic("flags not parsed, has Run() been called?")
195195
}
196-
return i.parsedFlags
196+
return inv.parsedFlags
197197
}
198198

199199
type runState struct {
@@ -218,39 +218,17 @@ func copyFlagSetWithout(fs *pflag.FlagSet, without string) *pflag.FlagSet {
218218
// run recursively executes the command and its children.
219219
// allArgs is wired through the stack so that global flags can be accepted
220220
// anywhere in the command invocation.
221-
func (i *Invocation) run(state *runState) error {
222-
err := i.Command.Options.SetDefaults()
223-
if err != nil {
224-
return xerrors.Errorf("setting defaults: %w", err)
225-
}
226-
227-
// If we set the Default of an array but later see a flag for it, we
228-
// don't want to append, we want to replace. So, we need to keep the state
229-
// of defaulted array options.
230-
defaultedArrays := make(map[string]int)
231-
for _, opt := range i.Command.Options {
232-
sv, ok := opt.Value.(pflag.SliceValue)
233-
if !ok {
234-
continue
235-
}
236-
237-
if opt.Flag == "" {
238-
continue
239-
}
240-
241-
defaultedArrays[opt.Flag] = len(sv.GetSlice())
242-
}
243-
244-
err = i.Command.Options.ParseEnv(i.Environ)
221+
func (inv *Invocation) run(state *runState) error {
222+
err := inv.Command.Options.ParseEnv(inv.Environ)
245223
if err != nil {
246224
return xerrors.Errorf("parsing env: %w", err)
247225
}
248226

249227
// Now the fun part, argument parsing!
250228

251229
children := make(map[string]*Cmd)
252-
for _, child := range i.Command.Children {
253-
child.Parent = i.Command
230+
for _, child := range inv.Command.Children {
231+
child.Parent = inv.Command
254232
for _, name := range append(child.Aliases, child.Name()) {
255233
if _, ok := children[name]; ok {
256234
return xerrors.Errorf("duplicate command name: %s", name)
@@ -259,57 +237,44 @@ func (i *Invocation) run(state *runState) error {
259237
}
260238
}
261239

262-
if i.parsedFlags == nil {
263-
i.parsedFlags = pflag.NewFlagSet(i.Command.Name(), pflag.ContinueOnError)
240+
if inv.parsedFlags == nil {
241+
inv.parsedFlags = pflag.NewFlagSet(inv.Command.Name(), pflag.ContinueOnError)
264242
// We handle Usage ourselves.
265-
i.parsedFlags.Usage = func() {}
243+
inv.parsedFlags.Usage = func() {}
266244
}
267245

268246
// If we find a duplicate flag, we want the deeper command's flag to override
269247
// the shallow one. Unfortunately, pflag has no way to remove a flag, so we
270248
// have to create a copy of the flagset without a value.
271-
i.Command.Options.FlagSet().VisitAll(func(f *pflag.Flag) {
272-
if i.parsedFlags.Lookup(f.Name) != nil {
273-
i.parsedFlags = copyFlagSetWithout(i.parsedFlags, f.Name)
249+
inv.Command.Options.FlagSet().VisitAll(func(f *pflag.Flag) {
250+
if inv.parsedFlags.Lookup(f.Name) != nil {
251+
inv.parsedFlags = copyFlagSetWithout(inv.parsedFlags, f.Name)
274252
}
275-
i.parsedFlags.AddFlag(f)
253+
inv.parsedFlags.AddFlag(f)
276254
})
277255

278256
var parsedArgs []string
279257

280-
if !i.Command.RawArgs {
258+
if !inv.Command.RawArgs {
281259
// Flag parsing will fail on intermediate commands in the command tree,
282260
// so we check the error after looking for a child command.
283-
state.flagParseErr = i.parsedFlags.Parse(state.allArgs)
284-
parsedArgs = i.parsedFlags.Args()
285-
286-
i.parsedFlags.VisitAll(func(f *pflag.Flag) {
287-
i, ok := defaultedArrays[f.Name]
288-
if !ok {
289-
return
290-
}
291-
292-
if !f.Changed {
293-
return
294-
}
261+
state.flagParseErr = inv.parsedFlags.Parse(state.allArgs)
262+
parsedArgs = inv.parsedFlags.Args()
263+
}
295264

296-
// If flag was changed, we need to remove the default values.
297-
sv, ok := f.Value.(pflag.SliceValue)
298-
if !ok {
299-
panic("defaulted array option is not a slice value")
300-
}
301-
ss := sv.GetSlice()
302-
if len(ss) == 0 {
303-
// Slice likely zeroed by a flag.
304-
// E.g. "--fruit" may default to "apples,oranges" but the user
305-
// provided "--fruit=""".
306-
return
307-
}
308-
err := sv.Replace(ss[i:])
309-
if err != nil {
310-
panic(err)
311-
}
312-
})
265+
// Set defaults for flags that weren't set by the user.
266+
skipDefaults := make(map[int]struct{}, len(inv.Command.Options))
267+
for i, opt := range inv.Command.Options {
268+
if fl := inv.parsedFlags.Lookup(opt.Flag); fl != nil && fl.Changed {
269+
skipDefaults[i] = struct{}{}
270+
}
271+
if opt.envChanged {
272+
skipDefaults[i] = struct{}{}
273+
}
274+
}
275+
err = inv.Command.Options.SetDefaults(skipDefaults)
276+
if err != nil {
277+
return xerrors.Errorf("setting defaults: %w", err)
313278
}
314279

315280
// Run child command if found (next child only)
@@ -318,64 +283,64 @@ func (i *Invocation) run(state *runState) error {
318283
if len(parsedArgs) > state.commandDepth {
319284
nextArg := parsedArgs[state.commandDepth]
320285
if child, ok := children[nextArg]; ok {
321-
child.Parent = i.Command
322-
i.Command = child
286+
child.Parent = inv.Command
287+
inv.Command = child
323288
state.commandDepth++
324-
return i.run(state)
289+
return inv.run(state)
325290
}
326291
}
327292

328293
// Flag parse errors are irrelevant for raw args commands.
329-
if !i.Command.RawArgs && state.flagParseErr != nil && !errors.Is(state.flagParseErr, pflag.ErrHelp) {
294+
if !inv.Command.RawArgs && state.flagParseErr != nil && !errors.Is(state.flagParseErr, pflag.ErrHelp) {
330295
return xerrors.Errorf(
331296
"parsing flags (%v) for %q: %w",
332297
state.allArgs,
333-
i.Command.FullName(), state.flagParseErr,
298+
inv.Command.FullName(), state.flagParseErr,
334299
)
335300
}
336301

337-
if i.Command.RawArgs {
302+
if inv.Command.RawArgs {
338303
// If we're at the root command, then the name is omitted
339304
// from the arguments, so we can just use the entire slice.
340305
if state.commandDepth == 0 {
341-
i.Args = state.allArgs
306+
inv.Args = state.allArgs
342307
} else {
343-
argPos, err := findArg(i.Command.Name(), state.allArgs, i.parsedFlags)
308+
argPos, err := findArg(inv.Command.Name(), state.allArgs, inv.parsedFlags)
344309
if err != nil {
345310
panic(err)
346311
}
347-
i.Args = state.allArgs[argPos+1:]
312+
inv.Args = state.allArgs[argPos+1:]
348313
}
349314
} else {
350315
// In non-raw-arg mode, we want to skip over flags.
351-
i.Args = parsedArgs[state.commandDepth:]
316+
inv.Args = parsedArgs[state.commandDepth:]
352317
}
353318

354-
mw := i.Command.Middleware
319+
mw := inv.Command.Middleware
355320
if mw == nil {
356321
mw = Chain()
357322
}
358323

359-
ctx := i.ctx
324+
ctx := inv.ctx
360325
if ctx == nil {
361326
ctx = context.Background()
362327
}
363328

364329
ctx, cancel := context.WithCancel(ctx)
365330
defer cancel()
366-
i = i.WithContext(ctx)
331+
inv = inv.WithContext(ctx)
367332

368-
if i.Command.Handler == nil || errors.Is(state.flagParseErr, pflag.ErrHelp) {
369-
if i.Command.HelpHandler == nil {
370-
return xerrors.Errorf("no handler or help for command %s", i.Command.FullName())
333+
if inv.Command.Handler == nil || errors.Is(state.flagParseErr, pflag.ErrHelp) {
334+
if inv.Command.HelpHandler == nil {
335+
return xerrors.Errorf("no handler or help for command %s", inv.Command.FullName())
371336
}
372-
return i.Command.HelpHandler(i)
337+
return inv.Command.HelpHandler(inv)
373338
}
374339

375-
err = mw(i.Command.Handler)(i)
340+
err = mw(inv.Command.Handler)(inv)
376341
if err != nil {
377342
return &RunCommandError{
378-
Cmd: i.Command,
343+
Cmd: inv.Command,
379344
Err: err,
380345
}
381346
}
@@ -438,33 +403,33 @@ func findArg(want string, args []string, fs *pflag.FlagSet) (int, error) {
438403
// If two command share a flag name, the first command wins.
439404
//
440405
//nolint:revive
441-
func (i *Invocation) Run() (err error) {
406+
func (inv *Invocation) Run() (err error) {
442407
defer func() {
443408
// Pflag is panicky, so additional context is helpful in tests.
444409
if flag.Lookup("test.v") == nil {
445410
return
446411
}
447412
if r := recover(); r != nil {
448-
err = xerrors.Errorf("panic recovered for %s: %v", i.Command.FullName(), r)
413+
err = xerrors.Errorf("panic recovered for %s: %v", inv.Command.FullName(), r)
449414
panic(err)
450415
}
451416
}()
452-
err = i.run(&runState{
453-
allArgs: i.Args,
417+
err = inv.run(&runState{
418+
allArgs: inv.Args,
454419
})
455420
return err
456421
}
457422

458423
// WithContext returns a copy of the Invocation with the given context.
459-
func (i *Invocation) WithContext(ctx context.Context) *Invocation {
460-
return i.with(func(i *Invocation) {
424+
func (inv *Invocation) WithContext(ctx context.Context) *Invocation {
425+
return inv.with(func(i *Invocation) {
461426
i.ctx = ctx
462427
})
463428
}
464429

465430
// with returns a copy of the Invocation with the given function applied.
466-
func (i *Invocation) with(fn func(*Invocation)) *Invocation {
467-
i2 := *i
431+
func (inv *Invocation) with(fn func(*Invocation)) *Invocation {
432+
i2 := *inv
468433
fn(&i2)
469434
return &i2
470435
}

cli/clibase/cmd_test.go

+9-1
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,7 @@ func TestCommand_FlagOverride(t *testing.T) {
247247
Use: "1",
248248
Options: clibase.OptionSet{
249249
{
250+
Name: "flag",
250251
Flag: "f",
251252
Value: clibase.DiscardValue,
252253
},
@@ -256,6 +257,7 @@ func TestCommand_FlagOverride(t *testing.T) {
256257
Use: "2",
257258
Options: clibase.OptionSet{
258259
{
260+
Name: "flag",
259261
Flag: "f",
260262
Value: clibase.StringOf(&flag),
261263
},
@@ -527,11 +529,17 @@ func TestCommand_EmptySlice(t *testing.T) {
527529
}
528530
}
529531

530-
// Base-case
532+
// Base-case, uses default.
531533
err := cmd("bad", "bad", "bad").Invoke().Run()
532534
require.NoError(t, err)
533535

536+
// Reset to nothing at all.
534537
inv := cmd().Invoke("--arr", "")
535538
err = inv.Run()
536539
require.NoError(t, err)
540+
541+
// Override
542+
inv = cmd("great").Invoke("--arr", "great")
543+
err = inv.Run()
544+
require.NoError(t, err)
537545
}

cli/clibase/option.go

+15-4
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ type Option struct {
4646
UseInstead []Option `json:"use_instead,omitempty"`
4747

4848
Hidden bool `json:"hidden,omitempty"`
49+
50+
envChanged bool
4951
}
5052

5153
// OptionSet is a group of options that can be applied to a command.
@@ -133,6 +135,7 @@ func (s *OptionSet) ParseEnv(vs []EnvVar) error {
133135
continue
134136
}
135137

138+
opt.envChanged = true
136139
if err := opt.Value.Set(envVal); err != nil {
137140
merr = multierror.Append(
138141
merr, xerrors.Errorf("parse %q: %w", opt.Name, err),
@@ -143,19 +146,27 @@ func (s *OptionSet) ParseEnv(vs []EnvVar) error {
143146
return merr.ErrorOrNil()
144147
}
145148

146-
// SetDefaults sets the default values for each Option.
147-
// It should be called before all parsing (e.g. ParseFlags, ParseEnv).
148-
func (s *OptionSet) SetDefaults() error {
149+
// SetDefaults sets the default values for each Option, skipping values
150+
// that have already been set as indicated by the skip map.
151+
func (s *OptionSet) SetDefaults(skip map[int]struct{}) error {
149152
if s == nil {
150153
return nil
151154
}
152155

153156
var merr *multierror.Error
154157

155-
for _, opt := range *s {
158+
for i, opt := range *s {
159+
// Skip values that may have already been set by the user.
160+
if len(skip) > 0 {
161+
if _, ok := skip[i]; ok {
162+
continue
163+
}
164+
}
165+
156166
if opt.Default == "" {
157167
continue
158168
}
169+
159170
if opt.Value == nil {
160171
merr = multierror.Append(
161172
merr,

cli/clibase/option_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ func TestOptionSet_ParseFlags(t *testing.T) {
4949
},
5050
}
5151

52-
err := os.SetDefaults()
52+
err := os.SetDefaults(nil)
5353
require.NoError(t, err)
5454

5555
err = os.FlagSet().Parse([]string{"--name", "foo", "--name", "bar"})
@@ -111,7 +111,7 @@ func TestOptionSet_ParseEnv(t *testing.T) {
111111
},
112112
}
113113

114-
err := os.SetDefaults()
114+
err := os.SetDefaults(nil)
115115
require.NoError(t, err)
116116

117117
err = os.ParseEnv(clibase.ParseEnviron([]string{"CODER_WORKSPACE_NAME="}, "CODER_"))

cli/clibase/yaml_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ func TestOption_ToYAML(t *testing.T) {
4444
},
4545
}
4646

47-
err := os.SetDefaults()
47+
err := os.SetDefaults(nil)
4848
require.NoError(t, err)
4949

5050
n, err := os.ToYAML()

0 commit comments

Comments
 (0)