Skip to content

Commit 01957da

Browse files
authored
chore: Add helper for uniform flags and env vars (coder#588)
1 parent be8389f commit 01957da

File tree

4 files changed

+248
-57
lines changed

4 files changed

+248
-57
lines changed

cli/cliflag/cliflag.go

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
// Package cliflag extends flagset with environment variable defaults.
2+
//
3+
// Usage:
4+
//
5+
// cliflag.String(root.Flags(), &address, "address", "a", "CODER_ADDRESS", "127.0.0.1:3000", "The address to serve the API and dashboard")
6+
//
7+
// Will produce the following usage docs:
8+
//
9+
// -a, --address string The address to serve the API and dashboard (uses $CODER_ADDRESS). (default "127.0.0.1:3000")
10+
//
11+
package cliflag
12+
13+
import (
14+
"fmt"
15+
"os"
16+
"strconv"
17+
18+
"github.com/spf13/pflag"
19+
)
20+
21+
// StringVarP sets a string flag on the given flag set.
22+
func StringVarP(flagset *pflag.FlagSet, p *string, name string, shorthand string, env string, def string, usage string) {
23+
v, ok := os.LookupEnv(env)
24+
if !ok || v == "" {
25+
v = def
26+
}
27+
flagset.StringVarP(p, name, shorthand, v, fmtUsage(usage, env))
28+
}
29+
30+
// Uint8VarP sets a uint8 flag on the given flag set.
31+
func Uint8VarP(flagset *pflag.FlagSet, ptr *uint8, name string, shorthand string, env string, def uint8, usage string) {
32+
val, ok := os.LookupEnv(env)
33+
if !ok || val == "" {
34+
flagset.Uint8VarP(ptr, name, shorthand, def, fmtUsage(usage, env))
35+
return
36+
}
37+
38+
vi64, err := strconv.ParseUint(val, 10, 8)
39+
if err != nil {
40+
flagset.Uint8VarP(ptr, name, shorthand, def, fmtUsage(usage, env))
41+
return
42+
}
43+
44+
flagset.Uint8VarP(ptr, name, shorthand, uint8(vi64), fmtUsage(usage, env))
45+
}
46+
47+
// BoolVarP sets a bool flag on the given flag set.
48+
func BoolVarP(flagset *pflag.FlagSet, ptr *bool, name string, shorthand string, env string, def bool, usage string) {
49+
val, ok := os.LookupEnv(env)
50+
if !ok || val == "" {
51+
flagset.BoolVarP(ptr, name, shorthand, def, fmtUsage(usage, env))
52+
return
53+
}
54+
55+
valb, err := strconv.ParseBool(val)
56+
if err != nil {
57+
flagset.BoolVarP(ptr, name, shorthand, def, fmtUsage(usage, env))
58+
return
59+
}
60+
61+
flagset.BoolVarP(ptr, name, shorthand, valb, fmtUsage(usage, env))
62+
}
63+
64+
func fmtUsage(u string, env string) string {
65+
if env == "" {
66+
return fmt.Sprintf("%s.", u)
67+
}
68+
69+
return fmt.Sprintf("%s - consumes $%s.", u, env)
70+
}

cli/cliflag/cliflag_test.go

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
package cliflag_test
2+
3+
import (
4+
"fmt"
5+
"strconv"
6+
"testing"
7+
8+
"github.com/spf13/pflag"
9+
"github.com/stretchr/testify/require"
10+
11+
"github.com/coder/coder/cli/cliflag"
12+
"github.com/coder/coder/cryptorand"
13+
)
14+
15+
// Testcliflag cannot run in parallel because it uses t.Setenv.
16+
//nolint:paralleltest
17+
func TestCliflag(t *testing.T) {
18+
t.Run("StringDefault", func(t *testing.T) {
19+
var p string
20+
flagset, name, shorthand, env, usage := randomFlag()
21+
def, _ := cryptorand.String(10)
22+
23+
cliflag.StringVarP(flagset, &p, name, shorthand, env, def, usage)
24+
got, err := flagset.GetString(name)
25+
require.NoError(t, err)
26+
require.Equal(t, def, got)
27+
require.Contains(t, flagset.FlagUsages(), usage)
28+
require.Contains(t, flagset.FlagUsages(), fmt.Sprintf(" - consumes $%s", env))
29+
})
30+
31+
t.Run("StringEnvVar", func(t *testing.T) {
32+
var p string
33+
flagset, name, shorthand, env, usage := randomFlag()
34+
envValue, _ := cryptorand.String(10)
35+
t.Setenv(env, envValue)
36+
def, _ := cryptorand.String(10)
37+
38+
cliflag.StringVarP(flagset, &p, name, shorthand, env, def, usage)
39+
got, err := flagset.GetString(name)
40+
require.NoError(t, err)
41+
require.Equal(t, envValue, got)
42+
})
43+
44+
t.Run("EmptyEnvVar", func(t *testing.T) {
45+
var p string
46+
flagset, name, shorthand, _, usage := randomFlag()
47+
def, _ := cryptorand.String(10)
48+
49+
cliflag.StringVarP(flagset, &p, name, shorthand, "", def, usage)
50+
got, err := flagset.GetString(name)
51+
require.NoError(t, err)
52+
require.Equal(t, def, got)
53+
require.Contains(t, flagset.FlagUsages(), usage)
54+
require.NotContains(t, flagset.FlagUsages(), " - consumes")
55+
})
56+
57+
t.Run("IntDefault", func(t *testing.T) {
58+
var p uint8
59+
flagset, name, shorthand, env, usage := randomFlag()
60+
def, _ := cryptorand.Int63n(10)
61+
62+
cliflag.Uint8VarP(flagset, &p, name, shorthand, env, uint8(def), usage)
63+
got, err := flagset.GetUint8(name)
64+
require.NoError(t, err)
65+
require.Equal(t, uint8(def), got)
66+
require.Contains(t, flagset.FlagUsages(), usage)
67+
require.Contains(t, flagset.FlagUsages(), fmt.Sprintf(" - consumes $%s", env))
68+
})
69+
70+
t.Run("IntEnvVar", func(t *testing.T) {
71+
var p uint8
72+
flagset, name, shorthand, env, usage := randomFlag()
73+
envValue, _ := cryptorand.Int63n(10)
74+
t.Setenv(env, strconv.FormatUint(uint64(envValue), 10))
75+
def, _ := cryptorand.Int()
76+
77+
cliflag.Uint8VarP(flagset, &p, name, shorthand, env, uint8(def), usage)
78+
got, err := flagset.GetUint8(name)
79+
require.NoError(t, err)
80+
require.Equal(t, uint8(envValue), got)
81+
})
82+
83+
t.Run("IntFailParse", func(t *testing.T) {
84+
var p uint8
85+
flagset, name, shorthand, env, usage := randomFlag()
86+
envValue, _ := cryptorand.String(10)
87+
t.Setenv(env, envValue)
88+
def, _ := cryptorand.Int63n(10)
89+
90+
cliflag.Uint8VarP(flagset, &p, name, shorthand, env, uint8(def), usage)
91+
got, err := flagset.GetUint8(name)
92+
require.NoError(t, err)
93+
require.Equal(t, uint8(def), got)
94+
})
95+
96+
t.Run("BoolDefault", func(t *testing.T) {
97+
var p bool
98+
flagset, name, shorthand, env, usage := randomFlag()
99+
def, _ := cryptorand.Bool()
100+
101+
cliflag.BoolVarP(flagset, &p, name, shorthand, env, def, usage)
102+
got, err := flagset.GetBool(name)
103+
require.NoError(t, err)
104+
require.Equal(t, def, got)
105+
require.Contains(t, flagset.FlagUsages(), usage)
106+
require.Contains(t, flagset.FlagUsages(), fmt.Sprintf(" - consumes $%s", env))
107+
})
108+
109+
t.Run("BoolEnvVar", func(t *testing.T) {
110+
var p bool
111+
flagset, name, shorthand, env, usage := randomFlag()
112+
envValue, _ := cryptorand.Bool()
113+
t.Setenv(env, strconv.FormatBool(envValue))
114+
def, _ := cryptorand.Bool()
115+
116+
cliflag.BoolVarP(flagset, &p, name, shorthand, env, def, usage)
117+
got, err := flagset.GetBool(name)
118+
require.NoError(t, err)
119+
require.Equal(t, envValue, got)
120+
})
121+
122+
t.Run("BoolFailParse", func(t *testing.T) {
123+
var p bool
124+
flagset, name, shorthand, env, usage := randomFlag()
125+
envValue, _ := cryptorand.String(10)
126+
t.Setenv(env, envValue)
127+
def, _ := cryptorand.Bool()
128+
129+
cliflag.BoolVarP(flagset, &p, name, shorthand, env, def, usage)
130+
got, err := flagset.GetBool(name)
131+
require.NoError(t, err)
132+
require.Equal(t, def, got)
133+
})
134+
}
135+
136+
func randomFlag() (*pflag.FlagSet, string, string, string, string) {
137+
fsname, _ := cryptorand.String(10)
138+
flagset := pflag.NewFlagSet(fsname, pflag.PanicOnError)
139+
name, _ := cryptorand.String(10)
140+
shorthand, _ := cryptorand.String(1)
141+
env, _ := cryptorand.String(10)
142+
usage, _ := cryptorand.String(10)
143+
144+
return flagset, name, shorthand, env, usage
145+
}

cli/start.go

Lines changed: 25 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ import (
1313
"net/url"
1414
"os"
1515
"os/signal"
16-
"strconv"
1716
"time"
1817

1918
"github.com/briandowns/spinner"
@@ -25,6 +24,7 @@ import (
2524

2625
"cdr.dev/slog"
2726
"cdr.dev/slog/sloggers/sloghuman"
27+
"github.com/coder/coder/cli/cliflag"
2828
"github.com/coder/coder/cli/cliui"
2929
"github.com/coder/coder/cli/config"
3030
"github.com/coder/coder/coderd"
@@ -40,10 +40,11 @@ import (
4040

4141
func start() *cobra.Command {
4242
var (
43-
accessURL string
44-
address string
45-
dev bool
46-
postgresURL string
43+
accessURL string
44+
address string
45+
dev bool
46+
postgresURL string
47+
// provisionerDaemonCount is a uint8 to ensure a number > 0.
4748
provisionerDaemonCount uint8
4849
tlsCertFile string
4950
tlsClientCAFile string
@@ -57,10 +58,6 @@ func start() *cobra.Command {
5758
Use: "start",
5859
RunE: func(cmd *cobra.Command, args []string) error {
5960
printLogo(cmd)
60-
if postgresURL == "" {
61-
// Default to the environment variable!
62-
postgresURL = os.Getenv("CODER_PG_CONNECTION_URL")
63-
}
6461

6562
listener, err := net.Listen("tcp", address)
6663
if err != nil {
@@ -163,7 +160,7 @@ func start() *cobra.Command {
163160
}
164161

165162
provisionerDaemons := make([]*provisionerd.Server, 0)
166-
for i := uint8(0); i < provisionerDaemonCount; i++ {
163+
for i := 0; uint8(i) < provisionerDaemonCount; i++ {
167164
daemonClose, err := newProvisionerDaemon(cmd.Context(), client, logger)
168165
if err != nil {
169166
return xerrors.Errorf("create provisioner daemon: %w", err)
@@ -305,46 +302,27 @@ func start() *cobra.Command {
305302
return nil
306303
},
307304
}
308-
defaultAddress := os.Getenv("CODER_ADDRESS")
309-
if defaultAddress == "" {
310-
defaultAddress = "127.0.0.1:3000"
311-
}
312-
root.Flags().StringVarP(&accessURL, "access-url", "", os.Getenv("CODER_ACCESS_URL"), "Specifies the external URL to access Coder (uses $CODER_ACCESS_URL).")
313-
root.Flags().StringVarP(&address, "address", "a", defaultAddress, "The address to serve the API and dashboard (uses $CODER_ADDRESS).")
314-
defaultDev, _ := strconv.ParseBool(os.Getenv("CODER_DEV_MODE"))
315-
root.Flags().BoolVarP(&dev, "dev", "", defaultDev, "Serve Coder in dev mode for tinkering (uses $CODER_DEV_MODE).")
316-
root.Flags().StringVarP(&postgresURL, "postgres-url", "", "",
317-
"URL of a PostgreSQL database to connect to (defaults to $CODER_PG_CONNECTION_URL).")
318-
root.Flags().Uint8VarP(&provisionerDaemonCount, "provisioner-daemons", "", 1, "The amount of provisioner daemons to create on start.")
319-
defaultTLSEnable, _ := strconv.ParseBool(os.Getenv("CODER_TLS_ENABLE"))
320-
root.Flags().BoolVarP(&tlsEnable, "tls-enable", "", defaultTLSEnable, "Specifies if TLS will be enabled (uses $CODER_TLS_ENABLE).")
321-
root.Flags().StringVarP(&tlsCertFile, "tls-cert-file", "", os.Getenv("CODER_TLS_CERT_FILE"),
305+
306+
cliflag.StringVarP(root.Flags(), &accessURL, "access-url", "", "CODER_ACCESS_URL", "", "Specifies the external URL to access Coder")
307+
cliflag.StringVarP(root.Flags(), &address, "address", "a", "CODER_ADDRESS", "127.0.0.1:3000", "The address to serve the API and dashboard")
308+
cliflag.BoolVarP(root.Flags(), &dev, "dev", "", "CODER_DEV_MODE", false, "Serve Coder in dev mode for tinkering")
309+
cliflag.StringVarP(root.Flags(), &postgresURL, "postgres-url", "", "CODER_PG_CONNECTION_URL", "", "URL of a PostgreSQL database to connect to")
310+
cliflag.Uint8VarP(root.Flags(), &provisionerDaemonCount, "provisioner-daemons", "", "CODER_PROVISIONER_DAEMONS", 1, "The amount of provisioner daemons to create on start.")
311+
cliflag.BoolVarP(root.Flags(), &tlsEnable, "tls-enable", "", "CODER_TLS_ENABLE", false, "Specifies if TLS will be enabled")
312+
cliflag.StringVarP(root.Flags(), &tlsCertFile, "tls-cert-file", "", "CODER_TLS_CERT_FILE", "",
322313
"Specifies the path to the certificate for TLS. It requires a PEM-encoded file. "+
323314
"To configure the listener to use a CA certificate, concatenate the primary certificate "+
324-
"and the CA certificate together. The primary certificate should appear first in the combined file (uses $CODER_TLS_CERT_FILE).")
325-
root.Flags().StringVarP(&tlsClientCAFile, "tls-client-ca-file", "", os.Getenv("CODER_TLS_CLIENT_CA_FILE"),
326-
"PEM-encoded Certificate Authority file used for checking the authenticity of client (uses $CODER_TLS_CLIENT_CA_FILE).")
327-
defaultTLSClientAuth := os.Getenv("CODER_TLS_CLIENT_AUTH")
328-
if defaultTLSClientAuth == "" {
329-
defaultTLSClientAuth = "request"
330-
}
331-
root.Flags().StringVarP(&tlsClientAuth, "tls-client-auth", "", defaultTLSClientAuth,
315+
"and the CA certificate together. The primary certificate should appear first in the combined file")
316+
cliflag.StringVarP(root.Flags(), &tlsClientCAFile, "tls-client-ca-file", "", "CODER_TLS_CLIENT_CA_FILE", "",
317+
"PEM-encoded Certificate Authority file used for checking the authenticity of client")
318+
cliflag.StringVarP(root.Flags(), &tlsClientAuth, "tls-client-auth", "", "CODER_TLS_CLIENT_AUTH", "request",
332319
`Specifies the policy the server will follow for TLS Client Authentication. `+
333-
`Accepted values are "none", "request", "require-any", "verify-if-given", or "require-and-verify" (uses $CODER_TLS_CLIENT_AUTH).`)
334-
root.Flags().StringVarP(&tlsKeyFile, "tls-key-file", "", os.Getenv("CODER_TLS_KEY_FILE"),
335-
"Specifies the path to the private key for the certificate. It requires a PEM-encoded file (uses $CODER_TLS_KEY_FILE).")
336-
defaultTLSMinVersion := os.Getenv("CODER_TLS_MIN_VERSION")
337-
if defaultTLSMinVersion == "" {
338-
defaultTLSMinVersion = "tls12"
339-
}
340-
root.Flags().StringVarP(&tlsMinVersion, "tls-min-version", "", defaultTLSMinVersion,
341-
`Specifies the minimum supported version of TLS. Accepted values are "tls10", "tls11", "tls12" or "tls13" (uses $CODER_TLS_MIN_VERSION).`)
342-
defaultTunnelRaw := os.Getenv("CODER_DEV_TUNNEL")
343-
if defaultTunnelRaw == "" {
344-
defaultTunnelRaw = "true"
345-
}
346-
defaultTunnel, _ := strconv.ParseBool(defaultTunnelRaw)
347-
root.Flags().BoolVarP(&useTunnel, "tunnel", "", defaultTunnel, "Serve dev mode through a Cloudflare Tunnel for easy setup (uses $CODER_DEV_TUNNEL).")
320+
`Accepted values are "none", "request", "require-any", "verify-if-given", or "require-and-verify"`)
321+
cliflag.StringVarP(root.Flags(), &tlsKeyFile, "tls-key-file", "", "CODER_TLS_KEY_FILE", "",
322+
"Specifies the path to the private key for the certificate. It requires a PEM-encoded file")
323+
cliflag.StringVarP(root.Flags(), &tlsMinVersion, "tls-min-version", "", "CODER_TLS_MIN_VERSION", "tls12",
324+
`Specifies the minimum supported version of TLS. Accepted values are "tls10", "tls11", "tls12" or "tls13"`)
325+
cliflag.BoolVarP(root.Flags(), &useTunnel, "tunnel", "", "CODER_DEV_TUNNEL", false, "Serve dev mode through a Cloudflare Tunnel for easy setup")
348326
_ = root.Flags().MarkHidden("tunnel")
349327

350328
return root

cli/workspaceagent.go

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package cli
33
import (
44
"context"
55
"net/url"
6-
"os"
76
"time"
87

98
"cloud.google.com/go/compute/metadata"
@@ -14,6 +13,7 @@ import (
1413
"cdr.dev/slog/sloggers/sloghuman"
1514

1615
"github.com/coder/coder/agent"
16+
"github.com/coder/coder/cli/cliflag"
1717
"github.com/coder/coder/codersdk"
1818
"github.com/coder/coder/peer"
1919
"github.com/coder/retry"
@@ -23,6 +23,7 @@ func workspaceAgent() *cobra.Command {
2323
var (
2424
rawURL string
2525
auth string
26+
token string
2627
)
2728
cmd := &cobra.Command{
2829
Use: "agent",
@@ -40,11 +41,10 @@ func workspaceAgent() *cobra.Command {
4041
client := codersdk.New(coderURL)
4142
switch auth {
4243
case "token":
43-
sessionToken, exists := os.LookupEnv("CODER_TOKEN")
44-
if !exists {
44+
if token == "" {
4545
return xerrors.Errorf("CODER_TOKEN must be set for token auth")
4646
}
47-
client.SessionToken = sessionToken
47+
client.SessionToken = token
4848
case "google-instance-identity":
4949
// This is *only* done for testing to mock client authentication.
5050
// This will never be set in a production scenario.
@@ -83,12 +83,10 @@ func workspaceAgent() *cobra.Command {
8383
return closer.Close()
8484
},
8585
}
86-
defaultAuth := os.Getenv("CODER_AUTH")
87-
if defaultAuth == "" {
88-
defaultAuth = "token"
89-
}
90-
cmd.Flags().StringVarP(&auth, "auth", "", defaultAuth, "Specify the authentication type to use for the agent.")
91-
cmd.Flags().StringVarP(&rawURL, "url", "", os.Getenv("CODER_URL"), "Specify the URL to access Coder.")
86+
87+
cliflag.StringVarP(cmd.Flags(), &auth, "auth", "", "CODER_AUTH", "token", "Specify the authentication type to use for the agent")
88+
cliflag.StringVarP(cmd.Flags(), &rawURL, "url", "", "CODER_URL", "", "Specify the URL to access Coder")
89+
cliflag.StringVarP(cmd.Flags(), &auth, "token", "", "CODER_TOKEN", "", "Specifies the authentication token to access Coder")
9290

9391
return cmd
9492
}

0 commit comments

Comments
 (0)