Skip to content

Commit e22539b

Browse files
committed
pr comments
1 parent 15c9aaa commit e22539b

File tree

2 files changed

+219
-0
lines changed

2 files changed

+219
-0
lines changed

cli/cliflag/cliflag.go

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
// Package cliflags extends flagset with environment variable defaults.
2+
// Helpers will set flags to their default value if the environment variable and flag are unset.
3+
// Helpers inject environment variable into flag usage docs if provided.
4+
//
5+
// Usage:
6+
//
7+
// cliflag.String(root.Flags(), &address, "address", "a", "CODER_ADDRESS", "127.0.0.1:3000", "The address to serve the API and dashboard")
8+
//
9+
// Will produce the following usage docs:
10+
//
11+
// -a, --address string The address to serve the API and dashboard (uses $CODER_ADDRESS). (default "127.0.0.1:3000")
12+
//
13+
package cliflag
14+
15+
import (
16+
"fmt"
17+
"os"
18+
"strconv"
19+
20+
"github.com/spf13/pflag"
21+
)
22+
23+
// StringVarP sets a string flag on the given flag set.
24+
func StringVarP(flagset *pflag.FlagSet, p *string, name string, shorthand string, env string, def string, usage string) {
25+
v, ok := os.LookupEnv(env)
26+
if !ok || v == "" {
27+
v = def
28+
}
29+
flagset.StringVarP(p, name, shorthand, v, fmtUsage(usage, env))
30+
}
31+
32+
// Uint8VarP sets a uint8 flag on the given flag set.
33+
func Uint8VarP(flagset *pflag.FlagSet, ptr *uint8, name string, shorthand string, env string, def uint8, usage string) {
34+
val, ok := os.LookupEnv(env)
35+
if !ok || val == "" {
36+
flagset.Uint8VarP(ptr, name, shorthand, def, fmtUsage(usage, env))
37+
return
38+
}
39+
40+
vi64, err := strconv.ParseUint(val, 10, 8)
41+
if err != nil {
42+
flagset.Uint8VarP(ptr, name, shorthand, def, fmtUsage(usage, env))
43+
return
44+
}
45+
46+
flagset.Uint8VarP(ptr, name, shorthand, uint8(vi64), fmtUsage(usage, env))
47+
}
48+
49+
// BoolVarP sets a bool flag on the given flag set.
50+
func BoolVarP(flagset *pflag.FlagSet, ptr *bool, name string, shorthand string, env string, def bool, usage string) {
51+
val, ok := os.LookupEnv(env)
52+
if !ok || val == "" {
53+
flagset.BoolVarP(ptr, name, shorthand, def, fmtUsage(usage, env))
54+
return
55+
}
56+
57+
valb, err := strconv.ParseBool(val)
58+
if err != nil {
59+
flagset.BoolVarP(ptr, name, shorthand, def, fmtUsage(usage, env))
60+
return
61+
}
62+
63+
flagset.BoolVarP(ptr, name, shorthand, valb, fmtUsage(usage, env))
64+
}
65+
66+
func fmtUsage(u string, env string) string {
67+
if env == "" {
68+
return fmt.Sprintf("%s.", u)
69+
}
70+
71+
return fmt.Sprintf("%s (uses $%s).", u, env)
72+
}

cli/cliflag/cliflag_test.go

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
package cliflag_test
2+
3+
import (
4+
"fmt"
5+
"strconv"
6+
"testing"
7+
8+
"github.com/spf13/pflag"
9+
"github.com/stretchr/testify/require"
10+
11+
"github.com/coder/coder/cli/cliflag"
12+
"github.com/coder/coder/cryptorand"
13+
)
14+
15+
// Testcliflag cannot run in parallel because it uses t.Setenv.
16+
//nolint:paralleltest
17+
func TestCliflag(t *testing.T) {
18+
19+
t.Run("StringDefault", func(t *testing.T) {
20+
var p string
21+
flagset, name, shorthand, env, usage := randomFlag()
22+
def, _ := cryptorand.String(10)
23+
24+
cliflag.StringVarP(flagset, &p, name, shorthand, env, def, usage)
25+
got, err := flagset.GetString(name)
26+
require.NoError(t, err)
27+
require.Equal(t, def, got)
28+
require.Contains(t, flagset.FlagUsages(), usage)
29+
require.Contains(t, flagset.FlagUsages(), fmt.Sprintf("(uses $%s).", env))
30+
})
31+
32+
t.Run("StringEnvVar", func(t *testing.T) {
33+
var p string
34+
flagset, name, shorthand, env, usage := randomFlag()
35+
envValue, _ := cryptorand.String(10)
36+
t.Setenv(env, envValue)
37+
def, _ := cryptorand.String(10)
38+
39+
cliflag.StringVarP(flagset, &p, name, shorthand, env, def, usage)
40+
got, err := flagset.GetString(name)
41+
require.NoError(t, err)
42+
require.Equal(t, envValue, got)
43+
})
44+
45+
t.Run("EmptyEnvVar", func(t *testing.T) {
46+
var p string
47+
flagset, name, shorthand, env, usage := randomFlag()
48+
env = ""
49+
def, _ := cryptorand.String(10)
50+
51+
cliflag.StringVarP(flagset, &p, name, shorthand, env, def, usage)
52+
got, err := flagset.GetString(name)
53+
require.NoError(t, err)
54+
require.Equal(t, def, got)
55+
require.Contains(t, flagset.FlagUsages(), usage)
56+
require.NotContains(t, flagset.FlagUsages(), fmt.Sprintf("(uses $%s).", env))
57+
})
58+
59+
t.Run("IntDefault", func(t *testing.T) {
60+
var p uint8
61+
flagset, name, shorthand, env, usage := randomFlag()
62+
def, _ := cryptorand.Int63n(10)
63+
64+
cliflag.Uint8VarP(flagset, &p, name, shorthand, env, uint8(def), usage)
65+
got, err := flagset.GetUint8(name)
66+
require.NoError(t, err)
67+
require.Equal(t, uint8(def), got)
68+
require.Contains(t, flagset.FlagUsages(), usage)
69+
require.Contains(t, flagset.FlagUsages(), fmt.Sprintf("(uses $%s).", env))
70+
})
71+
72+
t.Run("IntEnvVar", func(t *testing.T) {
73+
var p uint8
74+
flagset, name, shorthand, env, usage := randomFlag()
75+
envValue, _ := cryptorand.Int63n(10)
76+
t.Setenv(env, strconv.FormatUint(uint64(envValue), 10))
77+
def, _ := cryptorand.Int()
78+
79+
cliflag.Uint8VarP(flagset, &p, name, shorthand, env, uint8(def), usage)
80+
got, err := flagset.GetUint8(name)
81+
require.NoError(t, err)
82+
require.Equal(t, uint8(envValue), got)
83+
})
84+
85+
t.Run("IntFailParse", func(t *testing.T) {
86+
var p uint8
87+
flagset, name, shorthand, env, usage := randomFlag()
88+
envValue, _ := cryptorand.String(10)
89+
t.Setenv(env, envValue)
90+
def, _ := cryptorand.Int63n(10)
91+
92+
cliflag.Uint8VarP(flagset, &p, name, shorthand, env, uint8(def), usage)
93+
got, err := flagset.GetUint8(name)
94+
require.NoError(t, err)
95+
require.Equal(t, uint8(def), got)
96+
})
97+
98+
t.Run("BoolDefault", func(t *testing.T) {
99+
var p bool
100+
flagset, name, shorthand, env, usage := randomFlag()
101+
def, _ := cryptorand.Bool()
102+
103+
cliflag.BoolVarP(flagset, &p, name, shorthand, env, def, usage)
104+
got, err := flagset.GetBool(name)
105+
require.NoError(t, err)
106+
require.Equal(t, def, got)
107+
require.Contains(t, flagset.FlagUsages(), usage)
108+
require.Contains(t, flagset.FlagUsages(), fmt.Sprintf("(uses $%s).", env))
109+
})
110+
111+
t.Run("BoolEnvVar", func(t *testing.T) {
112+
var p bool
113+
flagset, name, shorthand, env, usage := randomFlag()
114+
envValue, _ := cryptorand.Bool()
115+
t.Setenv(env, strconv.FormatBool(envValue))
116+
def, _ := cryptorand.Bool()
117+
118+
cliflag.BoolVarP(flagset, &p, name, shorthand, env, def, usage)
119+
got, err := flagset.GetBool(name)
120+
require.NoError(t, err)
121+
require.Equal(t, envValue, got)
122+
})
123+
124+
t.Run("BoolFailParse", func(t *testing.T) {
125+
var p bool
126+
flagset, name, shorthand, env, usage := randomFlag()
127+
envValue, _ := cryptorand.String(10)
128+
t.Setenv(env, envValue)
129+
def, _ := cryptorand.Bool()
130+
131+
cliflag.BoolVarP(flagset, &p, name, shorthand, env, def, usage)
132+
got, err := flagset.GetBool(name)
133+
require.NoError(t, err)
134+
require.Equal(t, def, got)
135+
})
136+
}
137+
138+
func randomFlag() (*pflag.FlagSet, string, string, string, string) {
139+
fsname, _ := cryptorand.String(10)
140+
flagset := pflag.NewFlagSet(fsname, pflag.PanicOnError)
141+
name, _ := cryptorand.String(10)
142+
shorthand, _ := cryptorand.String(1)
143+
env, _ := cryptorand.String(10)
144+
usage, _ := cryptorand.String(10)
145+
146+
return flagset, name, shorthand, env, usage
147+
}

0 commit comments

Comments
 (0)