Skip to content

Commit 6981f89

Browse files
committed
Revert "fix: allow overridding default string array (coder#6873)"
This reverts commit 58d650c.
1 parent 58d650c commit 6981f89

File tree

6 files changed

+107
-91
lines changed

6 files changed

+107
-91
lines changed

cli/clibase/cmd.go

+98-63
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,8 @@ type Invocation struct {
172172

173173
// WithOS returns the invocation as a main package, filling in the invocation's unset
174174
// fields with OS defaults.
175-
func (inv *Invocation) WithOS() *Invocation {
176-
return inv.with(func(i *Invocation) {
175+
func (i *Invocation) WithOS() *Invocation {
176+
return i.with(func(i *Invocation) {
177177
i.Stdout = os.Stdout
178178
i.Stderr = os.Stderr
179179
i.Stdin = os.Stdin
@@ -182,18 +182,18 @@ func (inv *Invocation) WithOS() *Invocation {
182182
})
183183
}
184184

185-
func (inv *Invocation) Context() context.Context {
186-
if inv.ctx == nil {
185+
func (i *Invocation) Context() context.Context {
186+
if i.ctx == nil {
187187
return context.Background()
188188
}
189-
return inv.ctx
189+
return i.ctx
190190
}
191191

192-
func (inv *Invocation) ParsedFlags() *pflag.FlagSet {
193-
if inv.parsedFlags == nil {
192+
func (i *Invocation) ParsedFlags() *pflag.FlagSet {
193+
if i.parsedFlags == nil {
194194
panic("flags not parsed, has Run() been called?")
195195
}
196-
return inv.parsedFlags
196+
return i.parsedFlags
197197
}
198198

199199
type runState struct {
@@ -218,17 +218,39 @@ func copyFlagSetWithout(fs *pflag.FlagSet, without string) *pflag.FlagSet {
218218
// run recursively executes the command and its children.
219219
// allArgs is wired through the stack so that global flags can be accepted
220220
// anywhere in the command invocation.
221-
func (inv *Invocation) run(state *runState) error {
222-
err := inv.Command.Options.ParseEnv(inv.Environ)
221+
func (i *Invocation) run(state *runState) error {
222+
err := i.Command.Options.SetDefaults()
223+
if err != nil {
224+
return xerrors.Errorf("setting defaults: %w", err)
225+
}
226+
227+
// If we set the Default of an array but later see a flag for it, we
228+
// don't want to append, we want to replace. So, we need to keep the state
229+
// of defaulted array options.
230+
defaultedArrays := make(map[string]int)
231+
for _, opt := range i.Command.Options {
232+
sv, ok := opt.Value.(pflag.SliceValue)
233+
if !ok {
234+
continue
235+
}
236+
237+
if opt.Flag == "" {
238+
continue
239+
}
240+
241+
defaultedArrays[opt.Flag] = len(sv.GetSlice())
242+
}
243+
244+
err = i.Command.Options.ParseEnv(i.Environ)
223245
if err != nil {
224246
return xerrors.Errorf("parsing env: %w", err)
225247
}
226248

227249
// Now the fun part, argument parsing!
228250

229251
children := make(map[string]*Cmd)
230-
for _, child := range inv.Command.Children {
231-
child.Parent = inv.Command
252+
for _, child := range i.Command.Children {
253+
child.Parent = i.Command
232254
for _, name := range append(child.Aliases, child.Name()) {
233255
if _, ok := children[name]; ok {
234256
return xerrors.Errorf("duplicate command name: %s", name)
@@ -237,44 +259,57 @@ func (inv *Invocation) run(state *runState) error {
237259
}
238260
}
239261

240-
if inv.parsedFlags == nil {
241-
inv.parsedFlags = pflag.NewFlagSet(inv.Command.Name(), pflag.ContinueOnError)
262+
if i.parsedFlags == nil {
263+
i.parsedFlags = pflag.NewFlagSet(i.Command.Name(), pflag.ContinueOnError)
242264
// We handle Usage ourselves.
243-
inv.parsedFlags.Usage = func() {}
265+
i.parsedFlags.Usage = func() {}
244266
}
245267

246268
// If we find a duplicate flag, we want the deeper command's flag to override
247269
// the shallow one. Unfortunately, pflag has no way to remove a flag, so we
248270
// have to create a copy of the flagset without a value.
249-
inv.Command.Options.FlagSet().VisitAll(func(f *pflag.Flag) {
250-
if inv.parsedFlags.Lookup(f.Name) != nil {
251-
inv.parsedFlags = copyFlagSetWithout(inv.parsedFlags, f.Name)
271+
i.Command.Options.FlagSet().VisitAll(func(f *pflag.Flag) {
272+
if i.parsedFlags.Lookup(f.Name) != nil {
273+
i.parsedFlags = copyFlagSetWithout(i.parsedFlags, f.Name)
252274
}
253-
inv.parsedFlags.AddFlag(f)
275+
i.parsedFlags.AddFlag(f)
254276
})
255277

256278
var parsedArgs []string
257279

258-
if !inv.Command.RawArgs {
280+
if !i.Command.RawArgs {
259281
// Flag parsing will fail on intermediate commands in the command tree,
260282
// so we check the error after looking for a child command.
261-
state.flagParseErr = inv.parsedFlags.Parse(state.allArgs)
262-
parsedArgs = inv.parsedFlags.Args()
263-
}
283+
state.flagParseErr = i.parsedFlags.Parse(state.allArgs)
284+
parsedArgs = i.parsedFlags.Args()
264285

265-
// Set defaults for flags that weren't set by the user.
266-
skipDefaults := make(map[int]struct{}, len(inv.Command.Options))
267-
for i, opt := range inv.Command.Options {
268-
if fl := inv.parsedFlags.Lookup(opt.Flag); fl != nil && fl.Changed {
269-
skipDefaults[i] = struct{}{}
270-
}
271-
if opt.envChanged {
272-
skipDefaults[i] = struct{}{}
273-
}
274-
}
275-
err = inv.Command.Options.SetDefaults(skipDefaults)
276-
if err != nil {
277-
return xerrors.Errorf("setting defaults: %w", err)
286+
i.parsedFlags.VisitAll(func(f *pflag.Flag) {
287+
i, ok := defaultedArrays[f.Name]
288+
if !ok {
289+
return
290+
}
291+
292+
if !f.Changed {
293+
return
294+
}
295+
296+
// If flag was changed, we need to remove the default values.
297+
sv, ok := f.Value.(pflag.SliceValue)
298+
if !ok {
299+
panic("defaulted array option is not a slice value")
300+
}
301+
ss := sv.GetSlice()
302+
if len(ss) == 0 {
303+
// Slice likely zeroed by a flag.
304+
// E.g. "--fruit" may default to "apples,oranges" but the user
305+
// provided "--fruit=""".
306+
return
307+
}
308+
err := sv.Replace(ss[i:])
309+
if err != nil {
310+
panic(err)
311+
}
312+
})
278313
}
279314

280315
// Run child command if found (next child only)
@@ -283,64 +318,64 @@ func (inv *Invocation) run(state *runState) error {
283318
if len(parsedArgs) > state.commandDepth {
284319
nextArg := parsedArgs[state.commandDepth]
285320
if child, ok := children[nextArg]; ok {
286-
child.Parent = inv.Command
287-
inv.Command = child
321+
child.Parent = i.Command
322+
i.Command = child
288323
state.commandDepth++
289-
return inv.run(state)
324+
return i.run(state)
290325
}
291326
}
292327

293328
// Flag parse errors are irrelevant for raw args commands.
294-
if !inv.Command.RawArgs && state.flagParseErr != nil && !errors.Is(state.flagParseErr, pflag.ErrHelp) {
329+
if !i.Command.RawArgs && state.flagParseErr != nil && !errors.Is(state.flagParseErr, pflag.ErrHelp) {
295330
return xerrors.Errorf(
296331
"parsing flags (%v) for %q: %w",
297332
state.allArgs,
298-
inv.Command.FullName(), state.flagParseErr,
333+
i.Command.FullName(), state.flagParseErr,
299334
)
300335
}
301336

302-
if inv.Command.RawArgs {
337+
if i.Command.RawArgs {
303338
// If we're at the root command, then the name is omitted
304339
// from the arguments, so we can just use the entire slice.
305340
if state.commandDepth == 0 {
306-
inv.Args = state.allArgs
341+
i.Args = state.allArgs
307342
} else {
308-
argPos, err := findArg(inv.Command.Name(), state.allArgs, inv.parsedFlags)
343+
argPos, err := findArg(i.Command.Name(), state.allArgs, i.parsedFlags)
309344
if err != nil {
310345
panic(err)
311346
}
312-
inv.Args = state.allArgs[argPos+1:]
347+
i.Args = state.allArgs[argPos+1:]
313348
}
314349
} else {
315350
// In non-raw-arg mode, we want to skip over flags.
316-
inv.Args = parsedArgs[state.commandDepth:]
351+
i.Args = parsedArgs[state.commandDepth:]
317352
}
318353

319-
mw := inv.Command.Middleware
354+
mw := i.Command.Middleware
320355
if mw == nil {
321356
mw = Chain()
322357
}
323358

324-
ctx := inv.ctx
359+
ctx := i.ctx
325360
if ctx == nil {
326361
ctx = context.Background()
327362
}
328363

329364
ctx, cancel := context.WithCancel(ctx)
330365
defer cancel()
331-
inv = inv.WithContext(ctx)
366+
i = i.WithContext(ctx)
332367

333-
if inv.Command.Handler == nil || errors.Is(state.flagParseErr, pflag.ErrHelp) {
334-
if inv.Command.HelpHandler == nil {
335-
return xerrors.Errorf("no handler or help for command %s", inv.Command.FullName())
368+
if i.Command.Handler == nil || errors.Is(state.flagParseErr, pflag.ErrHelp) {
369+
if i.Command.HelpHandler == nil {
370+
return xerrors.Errorf("no handler or help for command %s", i.Command.FullName())
336371
}
337-
return inv.Command.HelpHandler(inv)
372+
return i.Command.HelpHandler(i)
338373
}
339374

340-
err = mw(inv.Command.Handler)(inv)
375+
err = mw(i.Command.Handler)(i)
341376
if err != nil {
342377
return &RunCommandError{
343-
Cmd: inv.Command,
378+
Cmd: i.Command,
344379
Err: err,
345380
}
346381
}
@@ -403,33 +438,33 @@ func findArg(want string, args []string, fs *pflag.FlagSet) (int, error) {
403438
// If two command share a flag name, the first command wins.
404439
//
405440
//nolint:revive
406-
func (inv *Invocation) Run() (err error) {
441+
func (i *Invocation) Run() (err error) {
407442
defer func() {
408443
// Pflag is panicky, so additional context is helpful in tests.
409444
if flag.Lookup("test.v") == nil {
410445
return
411446
}
412447
if r := recover(); r != nil {
413-
err = xerrors.Errorf("panic recovered for %s: %v", inv.Command.FullName(), r)
448+
err = xerrors.Errorf("panic recovered for %s: %v", i.Command.FullName(), r)
414449
panic(err)
415450
}
416451
}()
417-
err = inv.run(&runState{
418-
allArgs: inv.Args,
452+
err = i.run(&runState{
453+
allArgs: i.Args,
419454
})
420455
return err
421456
}
422457

423458
// WithContext returns a copy of the Invocation with the given context.
424-
func (inv *Invocation) WithContext(ctx context.Context) *Invocation {
425-
return inv.with(func(i *Invocation) {
459+
func (i *Invocation) WithContext(ctx context.Context) *Invocation {
460+
return i.with(func(i *Invocation) {
426461
i.ctx = ctx
427462
})
428463
}
429464

430465
// with returns a copy of the Invocation with the given function applied.
431-
func (inv *Invocation) with(fn func(*Invocation)) *Invocation {
432-
i2 := *inv
466+
func (i *Invocation) with(fn func(*Invocation)) *Invocation {
467+
i2 := *i
433468
fn(&i2)
434469
return &i2
435470
}

cli/clibase/cmd_test.go

+1-9
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,6 @@ func TestCommand_FlagOverride(t *testing.T) {
247247
Use: "1",
248248
Options: clibase.OptionSet{
249249
{
250-
Name: "flag",
251250
Flag: "f",
252251
Value: clibase.DiscardValue,
253252
},
@@ -257,7 +256,6 @@ func TestCommand_FlagOverride(t *testing.T) {
257256
Use: "2",
258257
Options: clibase.OptionSet{
259258
{
260-
Name: "flag",
261259
Flag: "f",
262260
Value: clibase.StringOf(&flag),
263261
},
@@ -529,17 +527,11 @@ func TestCommand_EmptySlice(t *testing.T) {
529527
}
530528
}
531529

532-
// Base-case, uses default.
530+
// Base-case
533531
err := cmd("bad", "bad", "bad").Invoke().Run()
534532
require.NoError(t, err)
535533

536-
// Reset to nothing at all.
537534
inv := cmd().Invoke("--arr", "")
538535
err = inv.Run()
539536
require.NoError(t, err)
540-
541-
// Override
542-
inv = cmd("great").Invoke("--arr", "great")
543-
err = inv.Run()
544-
require.NoError(t, err)
545537
}

cli/clibase/option.go

+4-15
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,6 @@ type Option struct {
4646
UseInstead []Option `json:"use_instead,omitempty"`
4747

4848
Hidden bool `json:"hidden,omitempty"`
49-
50-
envChanged bool
5149
}
5250

5351
// OptionSet is a group of options that can be applied to a command.
@@ -135,7 +133,6 @@ func (s *OptionSet) ParseEnv(vs []EnvVar) error {
135133
continue
136134
}
137135

138-
opt.envChanged = true
139136
if err := opt.Value.Set(envVal); err != nil {
140137
merr = multierror.Append(
141138
merr, xerrors.Errorf("parse %q: %w", opt.Name, err),
@@ -146,27 +143,19 @@ func (s *OptionSet) ParseEnv(vs []EnvVar) error {
146143
return merr.ErrorOrNil()
147144
}
148145

149-
// SetDefaults sets the default values for each Option, skipping values
150-
// that have already been set as indicated by the skip map.
151-
func (s *OptionSet) SetDefaults(skip map[int]struct{}) error {
146+
// SetDefaults sets the default values for each Option.
147+
// It should be called before all parsing (e.g. ParseFlags, ParseEnv).
148+
func (s *OptionSet) SetDefaults() error {
152149
if s == nil {
153150
return nil
154151
}
155152

156153
var merr *multierror.Error
157154

158-
for i, opt := range *s {
159-
// Skip values that may have already been set by the user.
160-
if len(skip) > 0 {
161-
if _, ok := skip[i]; ok {
162-
continue
163-
}
164-
}
165-
155+
for _, opt := range *s {
166156
if opt.Default == "" {
167157
continue
168158
}
169-
170159
if opt.Value == nil {
171160
merr = multierror.Append(
172161
merr,

cli/clibase/option_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ func TestOptionSet_ParseFlags(t *testing.T) {
4949
},
5050
}
5151

52-
err := os.SetDefaults(nil)
52+
err := os.SetDefaults()
5353
require.NoError(t, err)
5454

5555
err = os.FlagSet().Parse([]string{"--name", "foo", "--name", "bar"})
@@ -111,7 +111,7 @@ func TestOptionSet_ParseEnv(t *testing.T) {
111111
},
112112
}
113113

114-
err := os.SetDefaults(nil)
114+
err := os.SetDefaults()
115115
require.NoError(t, err)
116116

117117
err = os.ParseEnv(clibase.ParseEnviron([]string{"CODER_WORKSPACE_NAME="}, "CODER_"))

cli/clibase/yaml_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ func TestOption_ToYAML(t *testing.T) {
4444
},
4545
}
4646

47-
err := os.SetDefaults(nil)
47+
err := os.SetDefaults()
4848
require.NoError(t, err)
4949

5050
n, err := os.ToYAML()

0 commit comments

Comments
 (0)