diff --git a/cli/cliflag/cliflag.go b/cli/cliflag/cliflag.go new file mode 100644 index 0000000000000..e846d5fc391ae --- /dev/null +++ b/cli/cliflag/cliflag.go @@ -0,0 +1,70 @@ +// Package cliflag extends flagset with environment variable defaults. +// +// Usage: +// +// cliflag.String(root.Flags(), &address, "address", "a", "CODER_ADDRESS", "127.0.0.1:3000", "The address to serve the API and dashboard") +// +// Will produce the following usage docs: +// +// -a, --address string The address to serve the API and dashboard (uses $CODER_ADDRESS). (default "127.0.0.1:3000") +// +package cliflag + +import ( + "fmt" + "os" + "strconv" + + "github.com/spf13/pflag" +) + +// StringVarP sets a string flag on the given flag set. +func StringVarP(flagset *pflag.FlagSet, p *string, name string, shorthand string, env string, def string, usage string) { + v, ok := os.LookupEnv(env) + if !ok || v == "" { + v = def + } + flagset.StringVarP(p, name, shorthand, v, fmtUsage(usage, env)) +} + +// Uint8VarP sets a uint8 flag on the given flag set. +func Uint8VarP(flagset *pflag.FlagSet, ptr *uint8, name string, shorthand string, env string, def uint8, usage string) { + val, ok := os.LookupEnv(env) + if !ok || val == "" { + flagset.Uint8VarP(ptr, name, shorthand, def, fmtUsage(usage, env)) + return + } + + vi64, err := strconv.ParseUint(val, 10, 8) + if err != nil { + flagset.Uint8VarP(ptr, name, shorthand, def, fmtUsage(usage, env)) + return + } + + flagset.Uint8VarP(ptr, name, shorthand, uint8(vi64), fmtUsage(usage, env)) +} + +// BoolVarP sets a bool flag on the given flag set. +func BoolVarP(flagset *pflag.FlagSet, ptr *bool, name string, shorthand string, env string, def bool, usage string) { + val, ok := os.LookupEnv(env) + if !ok || val == "" { + flagset.BoolVarP(ptr, name, shorthand, def, fmtUsage(usage, env)) + return + } + + valb, err := strconv.ParseBool(val) + if err != nil { + flagset.BoolVarP(ptr, name, shorthand, def, fmtUsage(usage, env)) + return + } + + flagset.BoolVarP(ptr, name, shorthand, valb, fmtUsage(usage, env)) +} + +func fmtUsage(u string, env string) string { + if env == "" { + return fmt.Sprintf("%s.", u) + } + + return fmt.Sprintf("%s - consumes $%s.", u, env) +} diff --git a/cli/cliflag/cliflag_test.go b/cli/cliflag/cliflag_test.go new file mode 100644 index 0000000000000..542bb04abfd9d --- /dev/null +++ b/cli/cliflag/cliflag_test.go @@ -0,0 +1,145 @@ +package cliflag_test + +import ( + "fmt" + "strconv" + "testing" + + "github.com/spf13/pflag" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/cli/cliflag" + "github.com/coder/coder/cryptorand" +) + +// Testcliflag cannot run in parallel because it uses t.Setenv. +//nolint:paralleltest +func TestCliflag(t *testing.T) { + t.Run("StringDefault", func(t *testing.T) { + var p string + flagset, name, shorthand, env, usage := randomFlag() + def, _ := cryptorand.String(10) + + cliflag.StringVarP(flagset, &p, name, shorthand, env, def, usage) + got, err := flagset.GetString(name) + require.NoError(t, err) + require.Equal(t, def, got) + require.Contains(t, flagset.FlagUsages(), usage) + require.Contains(t, flagset.FlagUsages(), fmt.Sprintf(" - consumes $%s", env)) + }) + + t.Run("StringEnvVar", func(t *testing.T) { + var p string + flagset, name, shorthand, env, usage := randomFlag() + envValue, _ := cryptorand.String(10) + t.Setenv(env, envValue) + def, _ := cryptorand.String(10) + + cliflag.StringVarP(flagset, &p, name, shorthand, env, def, usage) + got, err := flagset.GetString(name) + require.NoError(t, err) + require.Equal(t, envValue, got) + }) + + t.Run("EmptyEnvVar", func(t *testing.T) { + var p string + flagset, name, shorthand, _, usage := randomFlag() + def, _ := cryptorand.String(10) + + cliflag.StringVarP(flagset, &p, name, shorthand, "", def, usage) + got, err := flagset.GetString(name) + require.NoError(t, err) + require.Equal(t, def, got) + require.Contains(t, flagset.FlagUsages(), usage) + require.NotContains(t, flagset.FlagUsages(), " - consumes") + }) + + t.Run("IntDefault", func(t *testing.T) { + var p uint8 + flagset, name, shorthand, env, usage := randomFlag() + def, _ := cryptorand.Int63n(10) + + cliflag.Uint8VarP(flagset, &p, name, shorthand, env, uint8(def), usage) + got, err := flagset.GetUint8(name) + require.NoError(t, err) + require.Equal(t, uint8(def), got) + require.Contains(t, flagset.FlagUsages(), usage) + require.Contains(t, flagset.FlagUsages(), fmt.Sprintf(" - consumes $%s", env)) + }) + + t.Run("IntEnvVar", func(t *testing.T) { + var p uint8 + flagset, name, shorthand, env, usage := randomFlag() + envValue, _ := cryptorand.Int63n(10) + t.Setenv(env, strconv.FormatUint(uint64(envValue), 10)) + def, _ := cryptorand.Int() + + cliflag.Uint8VarP(flagset, &p, name, shorthand, env, uint8(def), usage) + got, err := flagset.GetUint8(name) + require.NoError(t, err) + require.Equal(t, uint8(envValue), got) + }) + + t.Run("IntFailParse", func(t *testing.T) { + var p uint8 + flagset, name, shorthand, env, usage := randomFlag() + envValue, _ := cryptorand.String(10) + t.Setenv(env, envValue) + def, _ := cryptorand.Int63n(10) + + cliflag.Uint8VarP(flagset, &p, name, shorthand, env, uint8(def), usage) + got, err := flagset.GetUint8(name) + require.NoError(t, err) + require.Equal(t, uint8(def), got) + }) + + t.Run("BoolDefault", func(t *testing.T) { + var p bool + flagset, name, shorthand, env, usage := randomFlag() + def, _ := cryptorand.Bool() + + cliflag.BoolVarP(flagset, &p, name, shorthand, env, def, usage) + got, err := flagset.GetBool(name) + require.NoError(t, err) + require.Equal(t, def, got) + require.Contains(t, flagset.FlagUsages(), usage) + require.Contains(t, flagset.FlagUsages(), fmt.Sprintf(" - consumes $%s", env)) + }) + + t.Run("BoolEnvVar", func(t *testing.T) { + var p bool + flagset, name, shorthand, env, usage := randomFlag() + envValue, _ := cryptorand.Bool() + t.Setenv(env, strconv.FormatBool(envValue)) + def, _ := cryptorand.Bool() + + cliflag.BoolVarP(flagset, &p, name, shorthand, env, def, usage) + got, err := flagset.GetBool(name) + require.NoError(t, err) + require.Equal(t, envValue, got) + }) + + t.Run("BoolFailParse", func(t *testing.T) { + var p bool + flagset, name, shorthand, env, usage := randomFlag() + envValue, _ := cryptorand.String(10) + t.Setenv(env, envValue) + def, _ := cryptorand.Bool() + + cliflag.BoolVarP(flagset, &p, name, shorthand, env, def, usage) + got, err := flagset.GetBool(name) + require.NoError(t, err) + require.Equal(t, def, got) + }) +} + +func randomFlag() (*pflag.FlagSet, string, string, string, string) { + fsname, _ := cryptorand.String(10) + flagset := pflag.NewFlagSet(fsname, pflag.PanicOnError) + name, _ := cryptorand.String(10) + shorthand, _ := cryptorand.String(1) + env, _ := cryptorand.String(10) + usage, _ := cryptorand.String(10) + + return flagset, name, shorthand, env, usage +} diff --git a/cli/start.go b/cli/start.go index 99260ed78565d..7f7aa6d4c11dc 100644 --- a/cli/start.go +++ b/cli/start.go @@ -13,7 +13,6 @@ import ( "net/url" "os" "os/signal" - "strconv" "time" "github.com/briandowns/spinner" @@ -25,6 +24,7 @@ import ( "cdr.dev/slog" "cdr.dev/slog/sloggers/sloghuman" + "github.com/coder/coder/cli/cliflag" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/cli/config" "github.com/coder/coder/coderd" @@ -40,10 +40,11 @@ import ( func start() *cobra.Command { var ( - accessURL string - address string - dev bool - postgresURL string + accessURL string + address string + dev bool + postgresURL string + // provisionerDaemonCount is a uint8 to ensure a number > 0. provisionerDaemonCount uint8 tlsCertFile string tlsClientCAFile string @@ -57,10 +58,6 @@ func start() *cobra.Command { Use: "start", RunE: func(cmd *cobra.Command, args []string) error { printLogo(cmd) - if postgresURL == "" { - // Default to the environment variable! - postgresURL = os.Getenv("CODER_PG_CONNECTION_URL") - } listener, err := net.Listen("tcp", address) if err != nil { @@ -163,7 +160,7 @@ func start() *cobra.Command { } provisionerDaemons := make([]*provisionerd.Server, 0) - for i := uint8(0); i < provisionerDaemonCount; i++ { + for i := 0; uint8(i) < provisionerDaemonCount; i++ { daemonClose, err := newProvisionerDaemon(cmd.Context(), client, logger) if err != nil { return xerrors.Errorf("create provisioner daemon: %w", err) @@ -305,46 +302,27 @@ func start() *cobra.Command { return nil }, } - defaultAddress := os.Getenv("CODER_ADDRESS") - if defaultAddress == "" { - defaultAddress = "127.0.0.1:3000" - } - root.Flags().StringVarP(&accessURL, "access-url", "", os.Getenv("CODER_ACCESS_URL"), "Specifies the external URL to access Coder (uses $CODER_ACCESS_URL).") - root.Flags().StringVarP(&address, "address", "a", defaultAddress, "The address to serve the API and dashboard (uses $CODER_ADDRESS).") - defaultDev, _ := strconv.ParseBool(os.Getenv("CODER_DEV_MODE")) - root.Flags().BoolVarP(&dev, "dev", "", defaultDev, "Serve Coder in dev mode for tinkering (uses $CODER_DEV_MODE).") - root.Flags().StringVarP(&postgresURL, "postgres-url", "", "", - "URL of a PostgreSQL database to connect to (defaults to $CODER_PG_CONNECTION_URL).") - root.Flags().Uint8VarP(&provisionerDaemonCount, "provisioner-daemons", "", 1, "The amount of provisioner daemons to create on start.") - defaultTLSEnable, _ := strconv.ParseBool(os.Getenv("CODER_TLS_ENABLE")) - root.Flags().BoolVarP(&tlsEnable, "tls-enable", "", defaultTLSEnable, "Specifies if TLS will be enabled (uses $CODER_TLS_ENABLE).") - root.Flags().StringVarP(&tlsCertFile, "tls-cert-file", "", os.Getenv("CODER_TLS_CERT_FILE"), + + cliflag.StringVarP(root.Flags(), &accessURL, "access-url", "", "CODER_ACCESS_URL", "", "Specifies the external URL to access Coder") + cliflag.StringVarP(root.Flags(), &address, "address", "a", "CODER_ADDRESS", "127.0.0.1:3000", "The address to serve the API and dashboard") + cliflag.BoolVarP(root.Flags(), &dev, "dev", "", "CODER_DEV_MODE", false, "Serve Coder in dev mode for tinkering") + cliflag.StringVarP(root.Flags(), &postgresURL, "postgres-url", "", "CODER_PG_CONNECTION_URL", "", "URL of a PostgreSQL database to connect to") + cliflag.Uint8VarP(root.Flags(), &provisionerDaemonCount, "provisioner-daemons", "", "CODER_PROVISIONER_DAEMONS", 1, "The amount of provisioner daemons to create on start.") + cliflag.BoolVarP(root.Flags(), &tlsEnable, "tls-enable", "", "CODER_TLS_ENABLE", false, "Specifies if TLS will be enabled") + cliflag.StringVarP(root.Flags(), &tlsCertFile, "tls-cert-file", "", "CODER_TLS_CERT_FILE", "", "Specifies the path to the certificate for TLS. It requires a PEM-encoded file. "+ "To configure the listener to use a CA certificate, concatenate the primary certificate "+ - "and the CA certificate together. The primary certificate should appear first in the combined file (uses $CODER_TLS_CERT_FILE).") - root.Flags().StringVarP(&tlsClientCAFile, "tls-client-ca-file", "", os.Getenv("CODER_TLS_CLIENT_CA_FILE"), - "PEM-encoded Certificate Authority file used for checking the authenticity of client (uses $CODER_TLS_CLIENT_CA_FILE).") - defaultTLSClientAuth := os.Getenv("CODER_TLS_CLIENT_AUTH") - if defaultTLSClientAuth == "" { - defaultTLSClientAuth = "request" - } - root.Flags().StringVarP(&tlsClientAuth, "tls-client-auth", "", defaultTLSClientAuth, + "and the CA certificate together. The primary certificate should appear first in the combined file") + cliflag.StringVarP(root.Flags(), &tlsClientCAFile, "tls-client-ca-file", "", "CODER_TLS_CLIENT_CA_FILE", "", + "PEM-encoded Certificate Authority file used for checking the authenticity of client") + cliflag.StringVarP(root.Flags(), &tlsClientAuth, "tls-client-auth", "", "CODER_TLS_CLIENT_AUTH", "request", `Specifies the policy the server will follow for TLS Client Authentication. `+ - `Accepted values are "none", "request", "require-any", "verify-if-given", or "require-and-verify" (uses $CODER_TLS_CLIENT_AUTH).`) - root.Flags().StringVarP(&tlsKeyFile, "tls-key-file", "", os.Getenv("CODER_TLS_KEY_FILE"), - "Specifies the path to the private key for the certificate. It requires a PEM-encoded file (uses $CODER_TLS_KEY_FILE).") - defaultTLSMinVersion := os.Getenv("CODER_TLS_MIN_VERSION") - if defaultTLSMinVersion == "" { - defaultTLSMinVersion = "tls12" - } - root.Flags().StringVarP(&tlsMinVersion, "tls-min-version", "", defaultTLSMinVersion, - `Specifies the minimum supported version of TLS. Accepted values are "tls10", "tls11", "tls12" or "tls13" (uses $CODER_TLS_MIN_VERSION).`) - defaultTunnelRaw := os.Getenv("CODER_DEV_TUNNEL") - if defaultTunnelRaw == "" { - defaultTunnelRaw = "true" - } - defaultTunnel, _ := strconv.ParseBool(defaultTunnelRaw) - root.Flags().BoolVarP(&useTunnel, "tunnel", "", defaultTunnel, "Serve dev mode through a Cloudflare Tunnel for easy setup (uses $CODER_DEV_TUNNEL).") + `Accepted values are "none", "request", "require-any", "verify-if-given", or "require-and-verify"`) + cliflag.StringVarP(root.Flags(), &tlsKeyFile, "tls-key-file", "", "CODER_TLS_KEY_FILE", "", + "Specifies the path to the private key for the certificate. It requires a PEM-encoded file") + cliflag.StringVarP(root.Flags(), &tlsMinVersion, "tls-min-version", "", "CODER_TLS_MIN_VERSION", "tls12", + `Specifies the minimum supported version of TLS. Accepted values are "tls10", "tls11", "tls12" or "tls13"`) + cliflag.BoolVarP(root.Flags(), &useTunnel, "tunnel", "", "CODER_DEV_TUNNEL", false, "Serve dev mode through a Cloudflare Tunnel for easy setup") _ = root.Flags().MarkHidden("tunnel") return root diff --git a/cli/workspaceagent.go b/cli/workspaceagent.go index 369fe8010445b..ec1929ddf4659 100644 --- a/cli/workspaceagent.go +++ b/cli/workspaceagent.go @@ -3,7 +3,6 @@ package cli import ( "context" "net/url" - "os" "time" "cloud.google.com/go/compute/metadata" @@ -14,6 +13,7 @@ import ( "cdr.dev/slog/sloggers/sloghuman" "github.com/coder/coder/agent" + "github.com/coder/coder/cli/cliflag" "github.com/coder/coder/codersdk" "github.com/coder/coder/peer" "github.com/coder/retry" @@ -23,6 +23,7 @@ func workspaceAgent() *cobra.Command { var ( rawURL string auth string + token string ) cmd := &cobra.Command{ Use: "agent", @@ -40,11 +41,10 @@ func workspaceAgent() *cobra.Command { client := codersdk.New(coderURL) switch auth { case "token": - sessionToken, exists := os.LookupEnv("CODER_TOKEN") - if !exists { + if token == "" { return xerrors.Errorf("CODER_TOKEN must be set for token auth") } - client.SessionToken = sessionToken + client.SessionToken = token case "google-instance-identity": // This is *only* done for testing to mock client authentication. // This will never be set in a production scenario. @@ -83,12 +83,10 @@ func workspaceAgent() *cobra.Command { return closer.Close() }, } - defaultAuth := os.Getenv("CODER_AUTH") - if defaultAuth == "" { - defaultAuth = "token" - } - cmd.Flags().StringVarP(&auth, "auth", "", defaultAuth, "Specify the authentication type to use for the agent.") - cmd.Flags().StringVarP(&rawURL, "url", "", os.Getenv("CODER_URL"), "Specify the URL to access Coder.") + + cliflag.StringVarP(cmd.Flags(), &auth, "auth", "", "CODER_AUTH", "token", "Specify the authentication type to use for the agent") + cliflag.StringVarP(cmd.Flags(), &rawURL, "url", "", "CODER_URL", "", "Specify the URL to access Coder") + cliflag.StringVarP(cmd.Flags(), &auth, "token", "", "CODER_TOKEN", "", "Specifies the authentication token to access Coder") return cmd }