Skip to content

Commit 834a148

Browse files
committed
Add flags for configuring GitHub auth
1 parent 02ff066 commit 834a148

File tree

3 files changed

+62
-20
lines changed

3 files changed

+62
-20
lines changed

cli/cliflag/cliflag.go

+9
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"fmt"
1515
"os"
1616
"strconv"
17+
"strings"
1718

1819
"github.com/spf13/pflag"
1920
)
@@ -27,6 +28,14 @@ func StringVarP(flagset *pflag.FlagSet, p *string, name string, shorthand string
2728
flagset.StringVarP(p, name, shorthand, v, fmtUsage(usage, env))
2829
}
2930

31+
func StringArrayVarP(flagset *pflag.FlagSet, ptr *[]string, name string, shorthand string, env string, def []string, usage string) {
32+
val, ok := os.LookupEnv(env)
33+
if ok {
34+
def = strings.Split(val, ",")
35+
}
36+
flagset.StringArrayVarP(ptr, name, shorthand, def, usage)
37+
}
38+
3039
// Uint8VarP sets a uint8 flag on the given flag set.
3140
func Uint8VarP(flagset *pflag.FlagSet, ptr *uint8, name string, shorthand string, env string, def uint8, usage string) {
3241
val, ok := os.LookupEnv(env)

cli/cliflag/cliflag_test.go

+20
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,26 @@ func TestCliflag(t *testing.T) {
5454
require.NotContains(t, flagset.FlagUsages(), " - consumes")
5555
})
5656

57+
t.Run("StringArrayDefault", func(t *testing.T) {
58+
var ptr []string
59+
flagset, name, shorthand, env, usage := randomFlag()
60+
def := []string{"hello"}
61+
cliflag.StringArrayVarP(flagset, &ptr, name, shorthand, env, def, usage)
62+
got, err := flagset.GetStringArray(name)
63+
require.NoError(t, err)
64+
require.Equal(t, def, got)
65+
})
66+
67+
t.Run("StringArrayEnvVar", func(t *testing.T) {
68+
var ptr []string
69+
flagset, name, shorthand, env, usage := randomFlag()
70+
t.Setenv(env, "wow,test")
71+
cliflag.StringArrayVarP(flagset, &ptr, name, shorthand, env, nil, usage)
72+
got, err := flagset.GetStringArray(name)
73+
require.NoError(t, err)
74+
require.Equal(t, []string{"wow", "test"}, got)
75+
})
76+
5777
t.Run("IntDefault", func(t *testing.T) {
5878
var ptr uint8
5979
flagset, name, shorthand, env, usage := randomFlag()

cli/start.go

+33-20
Original file line numberDiff line numberDiff line change
@@ -52,17 +52,21 @@ func start() *cobra.Command {
5252
dev bool
5353
postgresURL string
5454
// provisionerDaemonCount is a uint8 to ensure a number > 0.
55-
provisionerDaemonCount uint8
56-
tlsCertFile string
57-
tlsClientCAFile string
58-
tlsClientAuth string
59-
tlsEnable bool
60-
tlsKeyFile string
61-
tlsMinVersion string
62-
skipTunnel bool
63-
traceDatadog bool
64-
secureAuthCookie bool
65-
sshKeygenAlgorithmRaw string
55+
provisionerDaemonCount uint8
56+
oauth2GithubClientID string
57+
oauth2GithubClientSecret string
58+
oauth2GithubAllowedOrganizations []string
59+
oauth2GithubAllowSignups bool
60+
tlsCertFile string
61+
tlsClientCAFile string
62+
tlsClientAuth string
63+
tlsEnable bool
64+
tlsKeyFile string
65+
tlsMinVersion string
66+
skipTunnel bool
67+
traceDatadog bool
68+
secureAuthCookie bool
69+
sshKeygenAlgorithmRaw string
6670
)
6771
root := &cobra.Command{
6872
Use: "start",
@@ -156,23 +160,24 @@ func start() *cobra.Command {
156160
return xerrors.Errorf("parse ssh keygen algorithm %s: %w", sshKeygenAlgorithmRaw, err)
157161
}
158162

159-
githubOAuth2Config, err := configureGithubOAuth2(accessURLParsed, "", "")
160-
if err != nil {
161-
return xerrors.Errorf("configure github oauth2: %w", err)
162-
}
163-
164163
logger := slog.Make(sloghuman.Sink(os.Stderr))
165164
options := &coderd.Options{
166165
AccessURL: accessURLParsed,
167166
Logger: logger.Named("coderd"),
168167
Database: databasefake.New(),
169168
Pubsub: database.NewPubsubInMemory(),
170169
GoogleTokenValidator: validator,
171-
GithubOAuth2Config: githubOAuth2Config,
172170
SecureAuthCookie: secureAuthCookie,
173171
SSHKeygenAlgorithm: sshKeygenAlgorithm,
174172
}
175173

174+
if oauth2GithubClientSecret != "" {
175+
options.GithubOAuth2Config, err = configureGithubOAuth2(accessURLParsed, oauth2GithubClientID, oauth2GithubClientSecret, oauth2GithubAllowSignups, oauth2GithubAllowedOrganizations)
176+
if err != nil {
177+
return xerrors.Errorf("configure github oauth2: %w", err)
178+
}
179+
}
180+
176181
_, _ = fmt.Fprintf(cmd.ErrOrStderr(), "access-url: %s\n", accessURL)
177182
_, _ = fmt.Fprintf(cmd.ErrOrStderr(), "provisioner-daemons: %d\n", provisionerDaemonCount)
178183
_, _ = fmt.Fprintln(cmd.ErrOrStderr())
@@ -366,6 +371,14 @@ func start() *cobra.Command {
366371
cliflag.BoolVarP(root.Flags(), &dev, "dev", "", "CODER_DEV_MODE", false, "Serve Coder in dev mode for tinkering")
367372
cliflag.StringVarP(root.Flags(), &postgresURL, "postgres-url", "", "CODER_PG_CONNECTION_URL", "", "URL of a PostgreSQL database to connect to")
368373
cliflag.Uint8VarP(root.Flags(), &provisionerDaemonCount, "provisioner-daemons", "", "CODER_PROVISIONER_DAEMONS", 1, "The amount of provisioner daemons to create on start.")
374+
cliflag.StringVarP(root.Flags(), &oauth2GithubClientID, "oauth2-github-client-id", "", "CODER_OAUTH2_GITHUB_CLIENT_ID", "",
375+
"Specifies a client ID to use for oauth2 with GitHub.")
376+
cliflag.StringVarP(root.Flags(), &oauth2GithubClientSecret, "oauth2-github-client-secret", "", "CODER_OAUTH2_GITHUB_CLIENT_SECRET", "",
377+
"Specifies a client secret to use for oauth2 with GitHub.")
378+
cliflag.StringArrayVarP(root.Flags(), &oauth2GithubAllowedOrganizations, "oauth2-github-allowed-orgs", "", "CODER_OAUTH2_GITHUB_ALLOWED_ORGS", nil,
379+
"Specifies organizations the user must be a member of to authenticate with GitHub.")
380+
cliflag.BoolVarP(root.Flags(), &oauth2GithubAllowSignups, "oauth2-github-allow-signups", "", "CODER_OAUTH2_GITHUB_ALLOW_SIGNUPS", false,
381+
"Specifies whether new users can sign up with GitHub.")
369382
cliflag.BoolVarP(root.Flags(), &tlsEnable, "tls-enable", "", "CODER_TLS_ENABLE", false, "Specifies if TLS will be enabled")
370383
cliflag.StringVarP(root.Flags(), &tlsCertFile, "tls-cert-file", "", "CODER_TLS_CERT_FILE", "",
371384
"Specifies the path to the certificate for TLS. It requires a PEM-encoded file. "+
@@ -544,7 +557,7 @@ func configureTLS(listener net.Listener, tlsMinVersion, tlsClientAuth, tlsCertFi
544557
return tls.NewListener(listener, tlsConfig), nil
545558
}
546559

547-
func configureGithubOAuth2(accessURL *url.URL, clientID, clientSecret string) (*coderd.GithubOAuth2Config, error) {
560+
func configureGithubOAuth2(accessURL *url.URL, clientID, clientSecret string, allowSignups bool, allowOrgs []string) (*coderd.GithubOAuth2Config, error) {
548561
redirectURL, err := accessURL.Parse("/api/v2/users/oauth2/github/callback")
549562
if err != nil {
550563
return nil, xerrors.Errorf("parse github oauth callback url: %w", err)
@@ -561,8 +574,8 @@ func configureGithubOAuth2(accessURL *url.URL, clientID, clientSecret string) (*
561574
"user:email",
562575
},
563576
},
564-
AllowSignups: true,
565-
AllowOrganizations: []string{"coder"},
577+
AllowSignups: allowSignups,
578+
AllowOrganizations: allowOrgs,
566579
AuthenticatedUser: func(ctx context.Context, client *http.Client) (*github.User, error) {
567580
user, _, err := github.NewClient(client).Users.Get(ctx, "")
568581
return user, err

0 commit comments

Comments
 (0)