Skip to content

Commit b949d1a

Browse files
committed
chore: Add helper for uniform flags and env vars
1 parent 3a48e40 commit b949d1a

File tree

4 files changed

+290
-52
lines changed

4 files changed

+290
-52
lines changed

cli/cliflags/cliflags.go

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
// Package cliflags provides helpers for uniform flags, env vars, and usage docs.
2+
// Helpers will set flags to their default value if the environment variable and flag are unset.
3+
// Helpers inject environment variable into flag usage docs if provided.
4+
//
5+
// Usage:
6+
//
7+
// cliflags.String(root.Flags(), &address, "address", "a", "CODER_ADDRESS", "127.0.0.1:3000", "The address to serve the API and dashboard")
8+
//
9+
// Will produce the following usage docs:
10+
//
11+
// -a, --address string The address to serve the API and dashboard (uses $CODER_ADDRESS). (default "127.0.0.1:3000")
12+
//
13+
package cliflags
14+
15+
import (
16+
"fmt"
17+
"os"
18+
"strconv"
19+
20+
"github.com/spf13/pflag"
21+
)
22+
23+
// String sets a string flag on the given flag set.
24+
func String(f *pflag.FlagSet, p *string, name string, shorthand string, env string, def string, usage string) {
25+
f.StringVarP(p, name, shorthand, envOrDefaultString(env, def), fmtUsage(usage, env))
26+
}
27+
28+
// Int sets a int flag on the given flag set.
29+
func Int(f *pflag.FlagSet, p *int, name string, shorthand string, env string, def int, usage string) {
30+
f.IntVarP(p, name, shorthand, envOrDefaultInt(env, def), fmtUsage(usage, env))
31+
}
32+
33+
// Bool sets a bool flag on the given flag set.
34+
func Bool(f *pflag.FlagSet, p *bool, name string, shorthand string, env string, def bool, usage string) {
35+
f.BoolVarP(p, name, shorthand, envOrDefaultBool(env, def), fmtUsage(usage, env))
36+
}
37+
38+
func envOrDefaultString(env string, def string) string {
39+
v, ok := os.LookupEnv(env)
40+
if !ok {
41+
return def
42+
}
43+
44+
return v
45+
}
46+
47+
func envOrDefaultInt(env string, def int) int {
48+
v, ok := os.LookupEnv(env)
49+
if !ok {
50+
return def
51+
}
52+
53+
i, err := strconv.Atoi(v)
54+
if err != nil {
55+
return def
56+
}
57+
58+
return i
59+
}
60+
61+
func envOrDefaultBool(env string, def bool) bool {
62+
v, ok := os.LookupEnv(env)
63+
if !ok {
64+
return def
65+
}
66+
67+
i, err := strconv.ParseBool(v)
68+
if err != nil {
69+
return def
70+
}
71+
72+
return i
73+
}
74+
75+
func fmtUsage(u string, env string) string {
76+
if env == "" {
77+
return fmt.Sprintf("%s.", u)
78+
}
79+
80+
return fmt.Sprintf("%s (uses $%s).", u, env)
81+
}

cli/cliflags/cliflags_test.go

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
package cliflags
2+
3+
import (
4+
"fmt"
5+
"os"
6+
"strconv"
7+
"testing"
8+
9+
"github.com/coder/coder/cryptorand"
10+
"github.com/spf13/pflag"
11+
"github.com/stretchr/testify/require"
12+
)
13+
14+
func TestCliFlags(t *testing.T) {
15+
t.Parallel()
16+
17+
t.Run("StringDefault", func(t *testing.T) {
18+
var p string
19+
fsname, _ := cryptorand.String(10)
20+
f := pflag.NewFlagSet(fsname, pflag.PanicOnError)
21+
name, _ := cryptorand.String(10)
22+
shorthand, _ := cryptorand.String(1)
23+
env, _ := cryptorand.String(10)
24+
def, _ := cryptorand.String(10)
25+
usage, _ := cryptorand.String(10)
26+
27+
String(f, &p, name, shorthand, env, def, usage)
28+
got, err := f.GetString(name)
29+
require.NoError(t, err)
30+
require.Equal(t, def, got)
31+
require.Contains(t, f.FlagUsages(), usage)
32+
require.Contains(t, f.FlagUsages(), fmt.Sprintf("(uses $%s).", env))
33+
})
34+
35+
t.Run("StringEnvVar", func(t *testing.T) {
36+
var p string
37+
fsname, _ := cryptorand.String(10)
38+
f := pflag.NewFlagSet(fsname, pflag.PanicOnError)
39+
name, _ := cryptorand.String(10)
40+
shorthand, _ := cryptorand.String(1)
41+
env, _ := cryptorand.String(10)
42+
envValue, _ := cryptorand.String(10)
43+
os.Setenv(env, envValue)
44+
defer os.Unsetenv(env)
45+
def, _ := cryptorand.String(10)
46+
usage, _ := cryptorand.String(10)
47+
48+
String(f, &p, name, shorthand, env, def, usage)
49+
got, err := f.GetString(name)
50+
require.NoError(t, err)
51+
require.Equal(t, envValue, got)
52+
})
53+
54+
t.Run("EmptyEnvVar", func(t *testing.T) {
55+
var p string
56+
fsname, _ := cryptorand.String(10)
57+
f := pflag.NewFlagSet(fsname, pflag.PanicOnError)
58+
name, _ := cryptorand.String(10)
59+
shorthand, _ := cryptorand.String(1)
60+
env := ""
61+
def, _ := cryptorand.String(10)
62+
usage, _ := cryptorand.String(10)
63+
64+
String(f, &p, name, shorthand, env, def, usage)
65+
got, err := f.GetString(name)
66+
require.NoError(t, err)
67+
require.Equal(t, def, got)
68+
require.Contains(t, f.FlagUsages(), usage)
69+
require.NotContains(t, f.FlagUsages(), fmt.Sprintf("(uses $%s).", env))
70+
})
71+
72+
t.Run("IntDefault", func(t *testing.T) {
73+
var p int
74+
fsname, _ := cryptorand.String(10)
75+
f := pflag.NewFlagSet(fsname, pflag.PanicOnError)
76+
name, _ := cryptorand.String(10)
77+
shorthand, _ := cryptorand.String(1)
78+
env, _ := cryptorand.String(10)
79+
def, _ := cryptorand.Int()
80+
usage, _ := cryptorand.String(10)
81+
82+
Int(f, &p, name, shorthand, env, def, usage)
83+
got, err := f.GetInt(name)
84+
require.NoError(t, err)
85+
require.Equal(t, def, got)
86+
require.Contains(t, f.FlagUsages(), usage)
87+
require.Contains(t, f.FlagUsages(), fmt.Sprintf("(uses $%s).", env))
88+
})
89+
90+
t.Run("IntEnvVar", func(t *testing.T) {
91+
var p int
92+
fsname, _ := cryptorand.String(10)
93+
f := pflag.NewFlagSet(fsname, pflag.PanicOnError)
94+
name, _ := cryptorand.String(10)
95+
shorthand, _ := cryptorand.String(1)
96+
env, _ := cryptorand.String(10)
97+
envValue, _ := cryptorand.Int()
98+
os.Setenv(env, strconv.Itoa(envValue))
99+
defer os.Unsetenv(env)
100+
def, _ := cryptorand.Int()
101+
usage, _ := cryptorand.String(10)
102+
103+
Int(f, &p, name, shorthand, env, def, usage)
104+
got, err := f.GetInt(name)
105+
require.NoError(t, err)
106+
require.Equal(t, envValue, got)
107+
})
108+
109+
t.Run("IntFailParse", func(t *testing.T) {
110+
var p int
111+
fsname, _ := cryptorand.String(10)
112+
f := pflag.NewFlagSet(fsname, pflag.PanicOnError)
113+
name, _ := cryptorand.String(10)
114+
shorthand, _ := cryptorand.String(1)
115+
env, _ := cryptorand.String(10)
116+
envValue, _ := cryptorand.String(10)
117+
os.Setenv(env, envValue)
118+
defer os.Unsetenv(env)
119+
def, _ := cryptorand.Int()
120+
usage, _ := cryptorand.String(10)
121+
122+
Int(f, &p, name, shorthand, env, def, usage)
123+
got, err := f.GetInt(name)
124+
require.NoError(t, err)
125+
require.Equal(t, def, got)
126+
})
127+
128+
t.Run("BoolDefault", func(t *testing.T) {
129+
var p bool
130+
fsname, _ := cryptorand.String(10)
131+
f := pflag.NewFlagSet(fsname, pflag.PanicOnError)
132+
name, _ := cryptorand.String(10)
133+
shorthand, _ := cryptorand.String(1)
134+
env, _ := cryptorand.String(10)
135+
def, _ := cryptorand.Bool()
136+
usage, _ := cryptorand.String(10)
137+
138+
Bool(f, &p, name, shorthand, env, def, usage)
139+
got, err := f.GetBool(name)
140+
require.NoError(t, err)
141+
require.Equal(t, def, got)
142+
require.Contains(t, f.FlagUsages(), usage)
143+
require.Contains(t, f.FlagUsages(), fmt.Sprintf("(uses $%s).", env))
144+
})
145+
146+
t.Run("BoolEnvVar", func(t *testing.T) {
147+
var p bool
148+
fsname, _ := cryptorand.String(10)
149+
f := pflag.NewFlagSet(fsname, pflag.PanicOnError)
150+
name, _ := cryptorand.String(10)
151+
shorthand, _ := cryptorand.String(1)
152+
env, _ := cryptorand.String(10)
153+
envValue, _ := cryptorand.Bool()
154+
os.Setenv(env, strconv.FormatBool(envValue))
155+
defer os.Unsetenv(env)
156+
def, _ := cryptorand.Bool()
157+
usage, _ := cryptorand.String(10)
158+
159+
Bool(f, &p, name, shorthand, env, def, usage)
160+
got, err := f.GetBool(name)
161+
require.NoError(t, err)
162+
require.Equal(t, envValue, got)
163+
})
164+
165+
t.Run("BoolFailParse", func(t *testing.T) {
166+
var p bool
167+
fsname, _ := cryptorand.String(10)
168+
f := pflag.NewFlagSet(fsname, pflag.PanicOnError)
169+
name, _ := cryptorand.String(10)
170+
shorthand, _ := cryptorand.String(1)
171+
env, _ := cryptorand.String(10)
172+
envValue, _ := cryptorand.String(10)
173+
os.Setenv(env, envValue)
174+
defer os.Unsetenv(env)
175+
def, _ := cryptorand.Bool()
176+
usage, _ := cryptorand.String(10)
177+
178+
Bool(f, &p, name, shorthand, env, def, usage)
179+
got, err := f.GetBool(name)
180+
require.NoError(t, err)
181+
require.Equal(t, def, got)
182+
})
183+
}

cli/start.go

Lines changed: 18 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ import (
1313
"net/url"
1414
"os"
1515
"os/signal"
16-
"strconv"
1716
"time"
1817

1918
"github.com/briandowns/spinner"
@@ -25,6 +24,7 @@ import (
2524

2625
"cdr.dev/slog"
2726
"cdr.dev/slog/sloggers/sloghuman"
27+
"github.com/coder/coder/cli/cliflags"
2828
"github.com/coder/coder/cli/cliui"
2929
"github.com/coder/coder/cli/config"
3030
"github.com/coder/coder/coderd"
@@ -57,10 +57,6 @@ func start() *cobra.Command {
5757
Use: "start",
5858
RunE: func(cmd *cobra.Command, args []string) error {
5959
printLogo(cmd)
60-
if postgresURL == "" {
61-
// Default to the environment variable!
62-
postgresURL = os.Getenv("CODER_PG_CONNECTION_URL")
63-
}
6460

6561
listener, err := net.Listen("tcp", address)
6662
if err != nil {
@@ -306,46 +302,26 @@ func start() *cobra.Command {
306302
return nil
307303
},
308304
}
309-
defaultAddress := os.Getenv("CODER_ADDRESS")
310-
if defaultAddress == "" {
311-
defaultAddress = "127.0.0.1:3000"
312-
}
313-
root.Flags().StringVarP(&accessURL, "access-url", "", os.Getenv("CODER_ACCESS_URL"), "Specifies the external URL to access Coder (uses $CODER_ACCESS_URL).")
314-
root.Flags().StringVarP(&address, "address", "a", defaultAddress, "The address to serve the API and dashboard (uses $CODER_ADDRESS).")
315-
defaultDev, _ := strconv.ParseBool(os.Getenv("CODER_DEV_MODE"))
316-
root.Flags().BoolVarP(&dev, "dev", "", defaultDev, "Serve Coder in dev mode for tinkering (uses $CODER_DEV_MODE).")
317-
root.Flags().StringVarP(&postgresURL, "postgres-url", "", "",
318-
"URL of a PostgreSQL database to connect to (defaults to $CODER_PG_CONNECTION_URL).")
319-
root.Flags().Uint8VarP(&provisionerDaemonCount, "provisioner-daemons", "", 1, "The amount of provisioner daemons to create on start.")
320-
defaultTLSEnable, _ := strconv.ParseBool(os.Getenv("CODER_TLS_ENABLE"))
321-
root.Flags().BoolVarP(&tlsEnable, "tls-enable", "", defaultTLSEnable, "Specifies if TLS will be enabled (uses $CODER_TLS_ENABLE).")
322-
root.Flags().StringVarP(&tlsCertFile, "tls-cert-file", "", os.Getenv("CODER_TLS_CERT_FILE"),
305+
306+
cliflags.String(root.Flags(), &accessURL, "access-url", "", "CODER_ACCESS_URL", "", "Specifies the external URL to access Coder")
307+
cliflags.String(root.Flags(), &address, "address", "a", "CODER_ADDRESS", "127.0.0.1:3000", "The address to serve the API and dashboard")
308+
cliflags.Bool(root.Flags(), &dev, "dev", "", "CODER_DEV_MODE", false, "Serve Coder in dev mode for tinkering")
309+
cliflags.String(root.Flags(), &postgresURL, "postgres-url", "", "CODER_PG_CONNECTION_URL", "", "URL of a PostgreSQL database to connect to")
310+
cliflags.Bool(root.Flags(), &tlsEnable, "tls-enable", "", "CODER_TLS_ENABLE", false, "Specifies if TLS will be enabled")
311+
cliflags.String(root.Flags(), &tlsCertFile, "tls-cert-file", "", "CODER_TLS_CERT_FILE", "",
323312
"Specifies the path to the certificate for TLS. It requires a PEM-encoded file. "+
324313
"To configure the listener to use a CA certificate, concatenate the primary certificate "+
325-
"and the CA certificate together. The primary certificate should appear first in the combined file (uses $CODER_TLS_CERT_FILE).")
326-
root.Flags().StringVarP(&tlsClientCAFile, "tls-client-ca-file", "", os.Getenv("CODER_TLS_CLIENT_CA_FILE"),
327-
"PEM-encoded Certificate Authority file used for checking the authenticity of client (uses $CODER_TLS_CLIENT_CA_FILE).")
328-
defaultTLSClientAuth := os.Getenv("CODER_TLS_CLIENT_AUTH")
329-
if defaultTLSClientAuth == "" {
330-
defaultTLSClientAuth = "request"
331-
}
332-
root.Flags().StringVarP(&tlsClientAuth, "tls-client-auth", "", defaultTLSClientAuth,
314+
"and the CA certificate together. The primary certificate should appear first in the combined file")
315+
cliflags.String(root.Flags(), &tlsClientCAFile, "tls-client-ca-file", "", "CODER_TLS_CLIENT_CA_FILE", "",
316+
"PEM-encoded Certificate Authority file used for checking the authenticity of client")
317+
cliflags.String(root.Flags(), &tlsClientAuth, "tls-client-auth", "", "CODER_TLS_CLIENT_AUTH", "request",
333318
`Specifies the policy the server will follow for TLS Client Authentication. `+
334-
`Accepted values are "none", "request", "require-any", "verify-if-given", or "require-and-verify" (uses $CODER_TLS_CLIENT_AUTH).`)
335-
root.Flags().StringVarP(&tlsKeyFile, "tls-key-file", "", os.Getenv("CODER_TLS_KEY_FILE"),
336-
"Specifies the path to the private key for the certificate. It requires a PEM-encoded file (uses $CODER_TLS_KEY_FILE).")
337-
defaultTLSMinVersion := os.Getenv("CODER_TLS_MIN_VERSION")
338-
if defaultTLSMinVersion == "" {
339-
defaultTLSMinVersion = "tls12"
340-
}
341-
root.Flags().StringVarP(&tlsMinVersion, "tls-min-version", "", defaultTLSMinVersion,
342-
`Specifies the minimum supported version of TLS. Accepted values are "tls10", "tls11", "tls12" or "tls13" (uses $CODER_TLS_MIN_VERSION).`)
343-
defaultTunnelRaw := os.Getenv("CODER_DEV_TUNNEL")
344-
if defaultTunnelRaw == "" {
345-
defaultTunnelRaw = "true"
346-
}
347-
defaultTunnel, _ := strconv.ParseBool(defaultTunnelRaw)
348-
root.Flags().BoolVarP(&useTunnel, "tunnel", "", defaultTunnel, "Serve dev mode through a Cloudflare Tunnel for easy setup (uses $CODER_DEV_TUNNEL).")
319+
`Accepted values are "none", "request", "require-any", "verify-if-given", or "require-and-verify"`)
320+
cliflags.String(root.Flags(), &tlsKeyFile, "tls-key-file", "", "CODER_TLS_KEY_FILE", "",
321+
"Specifies the path to the private key for the certificate. It requires a PEM-encoded file")
322+
cliflags.String(root.Flags(), &tlsMinVersion, "tls-min-version", "", "CODER_TLS_MIN_VERSION", "tls12",
323+
`Specifies the minimum supported version of TLS. Accepted values are "tls10", "tls11", "tls12" or "tls13"`)
324+
cliflags.Bool(root.Flags(), &useTunnel, "tunnel", "", "CODER_DEV_TUNNEL", false, "Serve dev mode through a Cloudflare Tunnel for easy setup")
349325
_ = root.Flags().MarkHidden("tunnel")
350326

351327
return root

0 commit comments

Comments
 (0)