@@ -172,8 +172,8 @@ type Invocation struct {
172
172
173
173
// WithOS returns the invocation as a main package, filling in the invocation's unset
174
174
// fields with OS defaults.
175
- func (inv * Invocation ) WithOS () * Invocation {
176
- return inv .with (func (i * Invocation ) {
175
+ func (i * Invocation ) WithOS () * Invocation {
176
+ return i .with (func (i * Invocation ) {
177
177
i .Stdout = os .Stdout
178
178
i .Stderr = os .Stderr
179
179
i .Stdin = os .Stdin
@@ -182,18 +182,18 @@ func (inv *Invocation) WithOS() *Invocation {
182
182
})
183
183
}
184
184
185
- func (inv * Invocation ) Context () context.Context {
186
- if inv .ctx == nil {
185
+ func (i * Invocation ) Context () context.Context {
186
+ if i .ctx == nil {
187
187
return context .Background ()
188
188
}
189
- return inv .ctx
189
+ return i .ctx
190
190
}
191
191
192
- func (inv * Invocation ) ParsedFlags () * pflag.FlagSet {
193
- if inv .parsedFlags == nil {
192
+ func (i * Invocation ) ParsedFlags () * pflag.FlagSet {
193
+ if i .parsedFlags == nil {
194
194
panic ("flags not parsed, has Run() been called?" )
195
195
}
196
- return inv .parsedFlags
196
+ return i .parsedFlags
197
197
}
198
198
199
199
type runState struct {
@@ -218,17 +218,39 @@ func copyFlagSetWithout(fs *pflag.FlagSet, without string) *pflag.FlagSet {
218
218
// run recursively executes the command and its children.
219
219
// allArgs is wired through the stack so that global flags can be accepted
220
220
// anywhere in the command invocation.
221
- func (inv * Invocation ) run (state * runState ) error {
222
- err := inv .Command .Options .ParseEnv (inv .Environ )
221
+ func (i * Invocation ) run (state * runState ) error {
222
+ err := i .Command .Options .SetDefaults ()
223
+ if err != nil {
224
+ return xerrors .Errorf ("setting defaults: %w" , err )
225
+ }
226
+
227
+ // If we set the Default of an array but later see a flag for it, we
228
+ // don't want to append, we want to replace. So, we need to keep the state
229
+ // of defaulted array options.
230
+ defaultedArrays := make (map [string ]int )
231
+ for _ , opt := range i .Command .Options {
232
+ sv , ok := opt .Value .(pflag.SliceValue )
233
+ if ! ok {
234
+ continue
235
+ }
236
+
237
+ if opt .Flag == "" {
238
+ continue
239
+ }
240
+
241
+ defaultedArrays [opt .Flag ] = len (sv .GetSlice ())
242
+ }
243
+
244
+ err = i .Command .Options .ParseEnv (i .Environ )
223
245
if err != nil {
224
246
return xerrors .Errorf ("parsing env: %w" , err )
225
247
}
226
248
227
249
// Now the fun part, argument parsing!
228
250
229
251
children := make (map [string ]* Cmd )
230
- for _ , child := range inv .Command .Children {
231
- child .Parent = inv .Command
252
+ for _ , child := range i .Command .Children {
253
+ child .Parent = i .Command
232
254
for _ , name := range append (child .Aliases , child .Name ()) {
233
255
if _ , ok := children [name ]; ok {
234
256
return xerrors .Errorf ("duplicate command name: %s" , name )
@@ -237,44 +259,57 @@ func (inv *Invocation) run(state *runState) error {
237
259
}
238
260
}
239
261
240
- if inv .parsedFlags == nil {
241
- inv .parsedFlags = pflag .NewFlagSet (inv .Command .Name (), pflag .ContinueOnError )
262
+ if i .parsedFlags == nil {
263
+ i .parsedFlags = pflag .NewFlagSet (i .Command .Name (), pflag .ContinueOnError )
242
264
// We handle Usage ourselves.
243
- inv .parsedFlags .Usage = func () {}
265
+ i .parsedFlags .Usage = func () {}
244
266
}
245
267
246
268
// If we find a duplicate flag, we want the deeper command's flag to override
247
269
// the shallow one. Unfortunately, pflag has no way to remove a flag, so we
248
270
// have to create a copy of the flagset without a value.
249
- inv .Command .Options .FlagSet ().VisitAll (func (f * pflag.Flag ) {
250
- if inv .parsedFlags .Lookup (f .Name ) != nil {
251
- inv .parsedFlags = copyFlagSetWithout (inv .parsedFlags , f .Name )
271
+ i .Command .Options .FlagSet ().VisitAll (func (f * pflag.Flag ) {
272
+ if i .parsedFlags .Lookup (f .Name ) != nil {
273
+ i .parsedFlags = copyFlagSetWithout (i .parsedFlags , f .Name )
252
274
}
253
- inv .parsedFlags .AddFlag (f )
275
+ i .parsedFlags .AddFlag (f )
254
276
})
255
277
256
278
var parsedArgs []string
257
279
258
- if ! inv .Command .RawArgs {
280
+ if ! i .Command .RawArgs {
259
281
// Flag parsing will fail on intermediate commands in the command tree,
260
282
// so we check the error after looking for a child command.
261
- state .flagParseErr = inv .parsedFlags .Parse (state .allArgs )
262
- parsedArgs = inv .parsedFlags .Args ()
263
- }
283
+ state .flagParseErr = i .parsedFlags .Parse (state .allArgs )
284
+ parsedArgs = i .parsedFlags .Args ()
264
285
265
- // Set defaults for flags that weren't set by the user.
266
- skipDefaults := make (map [int ]struct {}, len (inv .Command .Options ))
267
- for i , opt := range inv .Command .Options {
268
- if fl := inv .parsedFlags .Lookup (opt .Flag ); fl != nil && fl .Changed {
269
- skipDefaults [i ] = struct {}{}
270
- }
271
- if opt .envChanged {
272
- skipDefaults [i ] = struct {}{}
273
- }
274
- }
275
- err = inv .Command .Options .SetDefaults (skipDefaults )
276
- if err != nil {
277
- return xerrors .Errorf ("setting defaults: %w" , err )
286
+ i .parsedFlags .VisitAll (func (f * pflag.Flag ) {
287
+ i , ok := defaultedArrays [f .Name ]
288
+ if ! ok {
289
+ return
290
+ }
291
+
292
+ if ! f .Changed {
293
+ return
294
+ }
295
+
296
+ // If flag was changed, we need to remove the default values.
297
+ sv , ok := f .Value .(pflag.SliceValue )
298
+ if ! ok {
299
+ panic ("defaulted array option is not a slice value" )
300
+ }
301
+ ss := sv .GetSlice ()
302
+ if len (ss ) == 0 {
303
+ // Slice likely zeroed by a flag.
304
+ // E.g. "--fruit" may default to "apples,oranges" but the user
305
+ // provided "--fruit=""".
306
+ return
307
+ }
308
+ err := sv .Replace (ss [i :])
309
+ if err != nil {
310
+ panic (err )
311
+ }
312
+ })
278
313
}
279
314
280
315
// Run child command if found (next child only)
@@ -283,64 +318,64 @@ func (inv *Invocation) run(state *runState) error {
283
318
if len (parsedArgs ) > state .commandDepth {
284
319
nextArg := parsedArgs [state .commandDepth ]
285
320
if child , ok := children [nextArg ]; ok {
286
- child .Parent = inv .Command
287
- inv .Command = child
321
+ child .Parent = i .Command
322
+ i .Command = child
288
323
state .commandDepth ++
289
- return inv .run (state )
324
+ return i .run (state )
290
325
}
291
326
}
292
327
293
328
// Flag parse errors are irrelevant for raw args commands.
294
- if ! inv .Command .RawArgs && state .flagParseErr != nil && ! errors .Is (state .flagParseErr , pflag .ErrHelp ) {
329
+ if ! i .Command .RawArgs && state .flagParseErr != nil && ! errors .Is (state .flagParseErr , pflag .ErrHelp ) {
295
330
return xerrors .Errorf (
296
331
"parsing flags (%v) for %q: %w" ,
297
332
state .allArgs ,
298
- inv .Command .FullName (), state .flagParseErr ,
333
+ i .Command .FullName (), state .flagParseErr ,
299
334
)
300
335
}
301
336
302
- if inv .Command .RawArgs {
337
+ if i .Command .RawArgs {
303
338
// If we're at the root command, then the name is omitted
304
339
// from the arguments, so we can just use the entire slice.
305
340
if state .commandDepth == 0 {
306
- inv .Args = state .allArgs
341
+ i .Args = state .allArgs
307
342
} else {
308
- argPos , err := findArg (inv .Command .Name (), state .allArgs , inv .parsedFlags )
343
+ argPos , err := findArg (i .Command .Name (), state .allArgs , i .parsedFlags )
309
344
if err != nil {
310
345
panic (err )
311
346
}
312
- inv .Args = state .allArgs [argPos + 1 :]
347
+ i .Args = state .allArgs [argPos + 1 :]
313
348
}
314
349
} else {
315
350
// In non-raw-arg mode, we want to skip over flags.
316
- inv .Args = parsedArgs [state .commandDepth :]
351
+ i .Args = parsedArgs [state .commandDepth :]
317
352
}
318
353
319
- mw := inv .Command .Middleware
354
+ mw := i .Command .Middleware
320
355
if mw == nil {
321
356
mw = Chain ()
322
357
}
323
358
324
- ctx := inv .ctx
359
+ ctx := i .ctx
325
360
if ctx == nil {
326
361
ctx = context .Background ()
327
362
}
328
363
329
364
ctx , cancel := context .WithCancel (ctx )
330
365
defer cancel ()
331
- inv = inv .WithContext (ctx )
366
+ i = i .WithContext (ctx )
332
367
333
- if inv .Command .Handler == nil || errors .Is (state .flagParseErr , pflag .ErrHelp ) {
334
- if inv .Command .HelpHandler == nil {
335
- return xerrors .Errorf ("no handler or help for command %s" , inv .Command .FullName ())
368
+ if i .Command .Handler == nil || errors .Is (state .flagParseErr , pflag .ErrHelp ) {
369
+ if i .Command .HelpHandler == nil {
370
+ return xerrors .Errorf ("no handler or help for command %s" , i .Command .FullName ())
336
371
}
337
- return inv .Command .HelpHandler (inv )
372
+ return i .Command .HelpHandler (i )
338
373
}
339
374
340
- err = mw (inv .Command .Handler )(inv )
375
+ err = mw (i .Command .Handler )(i )
341
376
if err != nil {
342
377
return & RunCommandError {
343
- Cmd : inv .Command ,
378
+ Cmd : i .Command ,
344
379
Err : err ,
345
380
}
346
381
}
@@ -403,33 +438,33 @@ func findArg(want string, args []string, fs *pflag.FlagSet) (int, error) {
403
438
// If two command share a flag name, the first command wins.
404
439
//
405
440
//nolint:revive
406
- func (inv * Invocation ) Run () (err error ) {
441
+ func (i * Invocation ) Run () (err error ) {
407
442
defer func () {
408
443
// Pflag is panicky, so additional context is helpful in tests.
409
444
if flag .Lookup ("test.v" ) == nil {
410
445
return
411
446
}
412
447
if r := recover (); r != nil {
413
- err = xerrors .Errorf ("panic recovered for %s: %v" , inv .Command .FullName (), r )
448
+ err = xerrors .Errorf ("panic recovered for %s: %v" , i .Command .FullName (), r )
414
449
panic (err )
415
450
}
416
451
}()
417
- err = inv .run (& runState {
418
- allArgs : inv .Args ,
452
+ err = i .run (& runState {
453
+ allArgs : i .Args ,
419
454
})
420
455
return err
421
456
}
422
457
423
458
// WithContext returns a copy of the Invocation with the given context.
424
- func (inv * Invocation ) WithContext (ctx context.Context ) * Invocation {
425
- return inv .with (func (i * Invocation ) {
459
+ func (i * Invocation ) WithContext (ctx context.Context ) * Invocation {
460
+ return i .with (func (i * Invocation ) {
426
461
i .ctx = ctx
427
462
})
428
463
}
429
464
430
465
// with returns a copy of the Invocation with the given function applied.
431
- func (inv * Invocation ) with (fn func (* Invocation )) * Invocation {
432
- i2 := * inv
466
+ func (i * Invocation ) with (fn func (* Invocation )) * Invocation {
467
+ i2 := * i
433
468
fn (& i2 )
434
469
return & i2
435
470
}
0 commit comments