Skip to content

Commit d5bf66c

Browse files
committed
chore: Add helper for uniform flags and env vars
1 parent 3a48e40 commit d5bf66c

File tree

4 files changed

+313
-54
lines changed

4 files changed

+313
-54
lines changed

cli/cliflags/cliflags.go

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
// Package cliflags provides helpers for uniform flags, env vars, and usage docs.
2+
// Helpers will set flags to their default value if the environment variable and flag are unset.
3+
// Helpers inject environment variable into flag usage docs if provided.
4+
//
5+
// Usage:
6+
//
7+
// cliflags.String(root.Flags(), &address, "address", "a", "CODER_ADDRESS", "127.0.0.1:3000", "The address to serve the API and dashboard")
8+
//
9+
// Will produce the following usage docs:
10+
//
11+
// -a, --address string The address to serve the API and dashboard (uses $CODER_ADDRESS). (default "127.0.0.1:3000")
12+
//
13+
package cliflags
14+
15+
import (
16+
"fmt"
17+
"os"
18+
"strconv"
19+
20+
"github.com/spf13/pflag"
21+
)
22+
23+
// String sets a string flag on the given flag set.
24+
func String(flagset *pflag.FlagSet, p *string, name string, shorthand string, env string, def string, usage string) {
25+
flagset.StringVarP(p, name, shorthand, envOrDefaultString(env, def), fmtUsage(usage, env))
26+
}
27+
28+
// Int sets a int flag on the given flag set.
29+
func Int(flagset *pflag.FlagSet, p *int, name string, shorthand string, env string, def int, usage string) {
30+
flagset.IntVarP(p, name, shorthand, envOrDefaultInt(env, def), fmtUsage(usage, env))
31+
}
32+
33+
// Bool sets a bool flag on the given flag set.
34+
func Bool(flagset *pflag.FlagSet, p *bool, name string, shorthand string, env string, def bool, usage string) {
35+
flagset.BoolVarP(p, name, shorthand, envOrDefaultBool(env, def), fmtUsage(usage, env))
36+
}
37+
38+
func envOrDefaultString(env string, def string) string {
39+
v, ok := os.LookupEnv(env)
40+
if !ok {
41+
return def
42+
}
43+
44+
return v
45+
}
46+
47+
func envOrDefaultInt(env string, def int) int {
48+
v, ok := os.LookupEnv(env)
49+
if !ok {
50+
return def
51+
}
52+
53+
i, err := strconv.Atoi(v)
54+
if err != nil {
55+
return def
56+
}
57+
58+
return i
59+
}
60+
61+
func envOrDefaultBool(env string, def bool) bool {
62+
v, ok := os.LookupEnv(env)
63+
if !ok {
64+
return def
65+
}
66+
67+
i, err := strconv.ParseBool(v)
68+
if err != nil {
69+
return def
70+
}
71+
72+
return i
73+
}
74+
75+
func fmtUsage(u string, env string) string {
76+
if env == "" {
77+
return fmt.Sprintf("%s.", u)
78+
}
79+
80+
return fmt.Sprintf("%s (uses $%s).", u, env)
81+
}

cli/cliflags/cliflags_test.go

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
package cliflags_test
2+
3+
import (
4+
"fmt"
5+
"os"
6+
"strconv"
7+
"testing"
8+
9+
"github.com/spf13/pflag"
10+
"github.com/stretchr/testify/require"
11+
12+
"github.com/coder/coder/cli/cliflags"
13+
"github.com/coder/coder/cryptorand"
14+
)
15+
16+
func TestCliFlags(t *testing.T) {
17+
t.Parallel()
18+
19+
t.Run("StringDefault", func(t *testing.T) {
20+
t.Parallel()
21+
22+
var p string
23+
fsname, _ := cryptorand.String(10)
24+
flagset := pflag.NewFlagSet(fsname, pflag.PanicOnError)
25+
name, _ := cryptorand.String(10)
26+
shorthand, _ := cryptorand.String(1)
27+
env, _ := cryptorand.String(10)
28+
def, _ := cryptorand.String(10)
29+
usage, _ := cryptorand.String(10)
30+
31+
cliflags.String(flagset, &p, name, shorthand, env, def, usage)
32+
got, err := flagset.GetString(name)
33+
require.NoError(t, err)
34+
require.Equal(t, def, got)
35+
require.Contains(t, flagset.FlagUsages(), usage)
36+
require.Contains(t, flagset.FlagUsages(), fmt.Sprintf("(uses $%s).", env))
37+
})
38+
39+
t.Run("StringEnvVar", func(t *testing.T) {
40+
t.Parallel()
41+
42+
var p string
43+
fsname, _ := cryptorand.String(10)
44+
flagset := pflag.NewFlagSet(fsname, pflag.PanicOnError)
45+
name, _ := cryptorand.String(10)
46+
shorthand, _ := cryptorand.String(1)
47+
env, _ := cryptorand.String(10)
48+
envValue, _ := cryptorand.String(10)
49+
os.Setenv(env, envValue)
50+
defer os.Unsetenv(env)
51+
def, _ := cryptorand.String(10)
52+
usage, _ := cryptorand.String(10)
53+
54+
cliflags.String(flagset, &p, name, shorthand, env, def, usage)
55+
got, err := flagset.GetString(name)
56+
require.NoError(t, err)
57+
require.Equal(t, envValue, got)
58+
})
59+
60+
t.Run("EmptyEnvVar", func(t *testing.T) {
61+
t.Parallel()
62+
63+
var p string
64+
fsname, _ := cryptorand.String(10)
65+
flagset := pflag.NewFlagSet(fsname, pflag.PanicOnError)
66+
name, _ := cryptorand.String(10)
67+
shorthand, _ := cryptorand.String(1)
68+
env := ""
69+
def, _ := cryptorand.String(10)
70+
usage, _ := cryptorand.String(10)
71+
72+
cliflags.String(flagset, &p, name, shorthand, env, def, usage)
73+
got, err := flagset.GetString(name)
74+
require.NoError(t, err)
75+
require.Equal(t, def, got)
76+
require.Contains(t, flagset.FlagUsages(), usage)
77+
require.NotContains(t, flagset.FlagUsages(), fmt.Sprintf("(uses $%s).", env))
78+
})
79+
80+
t.Run("IntDefault", func(t *testing.T) {
81+
t.Parallel()
82+
83+
var p int
84+
fsname, _ := cryptorand.String(10)
85+
flagset := pflag.NewFlagSet(fsname, pflag.PanicOnError)
86+
name, _ := cryptorand.String(10)
87+
shorthand, _ := cryptorand.String(1)
88+
env, _ := cryptorand.String(10)
89+
def, _ := cryptorand.Int()
90+
usage, _ := cryptorand.String(10)
91+
92+
cliflags.Int(flagset, &p, name, shorthand, env, def, usage)
93+
got, err := flagset.GetInt(name)
94+
require.NoError(t, err)
95+
require.Equal(t, def, got)
96+
require.Contains(t, flagset.FlagUsages(), usage)
97+
require.Contains(t, flagset.FlagUsages(), fmt.Sprintf("(uses $%s).", env))
98+
})
99+
100+
t.Run("IntEnvVar", func(t *testing.T) {
101+
t.Parallel()
102+
103+
var p int
104+
fsname, _ := cryptorand.String(10)
105+
flagset := pflag.NewFlagSet(fsname, pflag.PanicOnError)
106+
name, _ := cryptorand.String(10)
107+
shorthand, _ := cryptorand.String(1)
108+
env, _ := cryptorand.String(10)
109+
envValue, _ := cryptorand.Int()
110+
os.Setenv(env, strconv.Itoa(envValue))
111+
defer os.Unsetenv(env)
112+
def, _ := cryptorand.Int()
113+
usage, _ := cryptorand.String(10)
114+
115+
cliflags.Int(flagset, &p, name, shorthand, env, def, usage)
116+
got, err := flagset.GetInt(name)
117+
require.NoError(t, err)
118+
require.Equal(t, envValue, got)
119+
})
120+
121+
t.Run("IntFailParse", func(t *testing.T) {
122+
t.Parallel()
123+
124+
var p int
125+
fsname, _ := cryptorand.String(10)
126+
flagset := pflag.NewFlagSet(fsname, pflag.PanicOnError)
127+
name, _ := cryptorand.String(10)
128+
shorthand, _ := cryptorand.String(1)
129+
env, _ := cryptorand.String(10)
130+
envValue, _ := cryptorand.String(10)
131+
os.Setenv(env, envValue)
132+
defer os.Unsetenv(env)
133+
def, _ := cryptorand.Int()
134+
usage, _ := cryptorand.String(10)
135+
136+
cliflags.Int(flagset, &p, name, shorthand, env, def, usage)
137+
got, err := flagset.GetInt(name)
138+
require.NoError(t, err)
139+
require.Equal(t, def, got)
140+
})
141+
142+
t.Run("BoolDefault", func(t *testing.T) {
143+
t.Parallel()
144+
145+
var p bool
146+
fsname, _ := cryptorand.String(10)
147+
flagset := pflag.NewFlagSet(fsname, pflag.PanicOnError)
148+
name, _ := cryptorand.String(10)
149+
shorthand, _ := cryptorand.String(1)
150+
env, _ := cryptorand.String(10)
151+
def, _ := cryptorand.Bool()
152+
usage, _ := cryptorand.String(10)
153+
154+
cliflags.Bool(flagset, &p, name, shorthand, env, def, usage)
155+
got, err := flagset.GetBool(name)
156+
require.NoError(t, err)
157+
require.Equal(t, def, got)
158+
require.Contains(t, flagset.FlagUsages(), usage)
159+
require.Contains(t, flagset.FlagUsages(), fmt.Sprintf("(uses $%s).", env))
160+
})
161+
162+
t.Run("BoolEnvVar", func(t *testing.T) {
163+
t.Parallel()
164+
165+
var p bool
166+
fsname, _ := cryptorand.String(10)
167+
flagset := pflag.NewFlagSet(fsname, pflag.PanicOnError)
168+
name, _ := cryptorand.String(10)
169+
shorthand, _ := cryptorand.String(1)
170+
env, _ := cryptorand.String(10)
171+
envValue, _ := cryptorand.Bool()
172+
os.Setenv(env, strconv.FormatBool(envValue))
173+
defer os.Unsetenv(env)
174+
def, _ := cryptorand.Bool()
175+
usage, _ := cryptorand.String(10)
176+
177+
cliflags.Bool(flagset, &p, name, shorthand, env, def, usage)
178+
got, err := flagset.GetBool(name)
179+
require.NoError(t, err)
180+
require.Equal(t, envValue, got)
181+
})
182+
183+
t.Run("BoolFailParse", func(t *testing.T) {
184+
t.Parallel()
185+
186+
var p bool
187+
fsname, _ := cryptorand.String(10)
188+
flagset := pflag.NewFlagSet(fsname, pflag.PanicOnError)
189+
name, _ := cryptorand.String(10)
190+
shorthand, _ := cryptorand.String(1)
191+
env, _ := cryptorand.String(10)
192+
envValue, _ := cryptorand.String(10)
193+
os.Setenv(env, envValue)
194+
defer os.Unsetenv(env)
195+
def, _ := cryptorand.Bool()
196+
usage, _ := cryptorand.String(10)
197+
198+
cliflags.Bool(flagset, &p, name, shorthand, env, def, usage)
199+
got, err := flagset.GetBool(name)
200+
require.NoError(t, err)
201+
require.Equal(t, def, got)
202+
})
203+
}

cli/start.go

Lines changed: 21 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ import (
1313
"net/url"
1414
"os"
1515
"os/signal"
16-
"strconv"
1716
"time"
1817

1918
"github.com/briandowns/spinner"
@@ -25,6 +24,7 @@ import (
2524

2625
"cdr.dev/slog"
2726
"cdr.dev/slog/sloggers/sloghuman"
27+
"github.com/coder/coder/cli/cliflags"
2828
"github.com/coder/coder/cli/cliui"
2929
"github.com/coder/coder/cli/config"
3030
"github.com/coder/coder/coderd"
@@ -44,7 +44,7 @@ func start() *cobra.Command {
4444
address string
4545
dev bool
4646
postgresURL string
47-
provisionerDaemonCount uint8
47+
provisionerDaemonCount int
4848
tlsCertFile string
4949
tlsClientCAFile string
5050
tlsClientAuth string
@@ -57,10 +57,6 @@ func start() *cobra.Command {
5757
Use: "start",
5858
RunE: func(cmd *cobra.Command, args []string) error {
5959
printLogo(cmd)
60-
if postgresURL == "" {
61-
// Default to the environment variable!
62-
postgresURL = os.Getenv("CODER_PG_CONNECTION_URL")
63-
}
6460

6561
listener, err := net.Listen("tcp", address)
6662
if err != nil {
@@ -163,7 +159,7 @@ func start() *cobra.Command {
163159
}
164160

165161
provisionerDaemons := make([]*provisionerd.Server, 0)
166-
for i := uint8(0); i < provisionerDaemonCount; i++ {
162+
for i := 0; i < provisionerDaemonCount; i++ {
167163
daemonClose, err := newProvisionerDaemon(cmd.Context(), client, logger)
168164
if err != nil {
169165
return xerrors.Errorf("create provisioner daemon: %w", err)
@@ -306,46 +302,27 @@ func start() *cobra.Command {
306302
return nil
307303
},
308304
}
309-
defaultAddress := os.Getenv("CODER_ADDRESS")
310-
if defaultAddress == "" {
311-
defaultAddress = "127.0.0.1:3000"
312-
}
313-
root.Flags().StringVarP(&accessURL, "access-url", "", os.Getenv("CODER_ACCESS_URL"), "Specifies the external URL to access Coder (uses $CODER_ACCESS_URL).")
314-
root.Flags().StringVarP(&address, "address", "a", defaultAddress, "The address to serve the API and dashboard (uses $CODER_ADDRESS).")
315-
defaultDev, _ := strconv.ParseBool(os.Getenv("CODER_DEV_MODE"))
316-
root.Flags().BoolVarP(&dev, "dev", "", defaultDev, "Serve Coder in dev mode for tinkering (uses $CODER_DEV_MODE).")
317-
root.Flags().StringVarP(&postgresURL, "postgres-url", "", "",
318-
"URL of a PostgreSQL database to connect to (defaults to $CODER_PG_CONNECTION_URL).")
319-
root.Flags().Uint8VarP(&provisionerDaemonCount, "provisioner-daemons", "", 1, "The amount of provisioner daemons to create on start.")
320-
defaultTLSEnable, _ := strconv.ParseBool(os.Getenv("CODER_TLS_ENABLE"))
321-
root.Flags().BoolVarP(&tlsEnable, "tls-enable", "", defaultTLSEnable, "Specifies if TLS will be enabled (uses $CODER_TLS_ENABLE).")
322-
root.Flags().StringVarP(&tlsCertFile, "tls-cert-file", "", os.Getenv("CODER_TLS_CERT_FILE"),
305+
306+
cliflags.String(root.Flags(), &accessURL, "access-url", "", "CODER_ACCESS_URL", "", "Specifies the external URL to access Coder")
307+
cliflags.String(root.Flags(), &address, "address", "a", "CODER_ADDRESS", "127.0.0.1:3000", "The address to serve the API and dashboard")
308+
cliflags.Bool(root.Flags(), &dev, "dev", "", "CODER_DEV_MODE", false, "Serve Coder in dev mode for tinkering")
309+
cliflags.String(root.Flags(), &postgresURL, "postgres-url", "", "CODER_PG_CONNECTION_URL", "", "URL of a PostgreSQL database to connect to")
310+
cliflags.Int(root.Flags(), &provisionerDaemonCount, "provisioner-daemons", "", "CODER_PROVISIONER_DAEMONS", 1, "The amount of provisioner daemons to create on start.")
311+
cliflags.Bool(root.Flags(), &tlsEnable, "tls-enable", "", "CODER_TLS_ENABLE", false, "Specifies if TLS will be enabled")
312+
cliflags.String(root.Flags(), &tlsCertFile, "tls-cert-file", "", "CODER_TLS_CERT_FILE", "",
323313
"Specifies the path to the certificate for TLS. It requires a PEM-encoded file. "+
324314
"To configure the listener to use a CA certificate, concatenate the primary certificate "+
325-
"and the CA certificate together. The primary certificate should appear first in the combined file (uses $CODER_TLS_CERT_FILE).")
326-
root.Flags().StringVarP(&tlsClientCAFile, "tls-client-ca-file", "", os.Getenv("CODER_TLS_CLIENT_CA_FILE"),
327-
"PEM-encoded Certificate Authority file used for checking the authenticity of client (uses $CODER_TLS_CLIENT_CA_FILE).")
328-
defaultTLSClientAuth := os.Getenv("CODER_TLS_CLIENT_AUTH")
329-
if defaultTLSClientAuth == "" {
330-
defaultTLSClientAuth = "request"
331-
}
332-
root.Flags().StringVarP(&tlsClientAuth, "tls-client-auth", "", defaultTLSClientAuth,
315+
"and the CA certificate together. The primary certificate should appear first in the combined file")
316+
cliflags.String(root.Flags(), &tlsClientCAFile, "tls-client-ca-file", "", "CODER_TLS_CLIENT_CA_FILE", "",
317+
"PEM-encoded Certificate Authority file used for checking the authenticity of client")
318+
cliflags.String(root.Flags(), &tlsClientAuth, "tls-client-auth", "", "CODER_TLS_CLIENT_AUTH", "request",
333319
`Specifies the policy the server will follow for TLS Client Authentication. `+
334-
`Accepted values are "none", "request", "require-any", "verify-if-given", or "require-and-verify" (uses $CODER_TLS_CLIENT_AUTH).`)
335-
root.Flags().StringVarP(&tlsKeyFile, "tls-key-file", "", os.Getenv("CODER_TLS_KEY_FILE"),
336-
"Specifies the path to the private key for the certificate. It requires a PEM-encoded file (uses $CODER_TLS_KEY_FILE).")
337-
defaultTLSMinVersion := os.Getenv("CODER_TLS_MIN_VERSION")
338-
if defaultTLSMinVersion == "" {
339-
defaultTLSMinVersion = "tls12"
340-
}
341-
root.Flags().StringVarP(&tlsMinVersion, "tls-min-version", "", defaultTLSMinVersion,
342-
`Specifies the minimum supported version of TLS. Accepted values are "tls10", "tls11", "tls12" or "tls13" (uses $CODER_TLS_MIN_VERSION).`)
343-
defaultTunnelRaw := os.Getenv("CODER_DEV_TUNNEL")
344-
if defaultTunnelRaw == "" {
345-
defaultTunnelRaw = "true"
346-
}
347-
defaultTunnel, _ := strconv.ParseBool(defaultTunnelRaw)
348-
root.Flags().BoolVarP(&useTunnel, "tunnel", "", defaultTunnel, "Serve dev mode through a Cloudflare Tunnel for easy setup (uses $CODER_DEV_TUNNEL).")
320+
`Accepted values are "none", "request", "require-any", "verify-if-given", or "require-and-verify"`)
321+
cliflags.String(root.Flags(), &tlsKeyFile, "tls-key-file", "", "CODER_TLS_KEY_FILE", "",
322+
"Specifies the path to the private key for the certificate. It requires a PEM-encoded file")
323+
cliflags.String(root.Flags(), &tlsMinVersion, "tls-min-version", "", "CODER_TLS_MIN_VERSION", "tls12",
324+
`Specifies the minimum supported version of TLS. Accepted values are "tls10", "tls11", "tls12" or "tls13"`)
325+
cliflags.Bool(root.Flags(), &useTunnel, "tunnel", "", "CODER_DEV_TUNNEL", false, "Serve dev mode through a Cloudflare Tunnel for easy setup")
349326
_ = root.Flags().MarkHidden("tunnel")
350327

351328
return root

0 commit comments

Comments
 (0)