Skip to content

Commit 3634330

Browse files
committed
Merge branch 'main' into userscmd
2 parents bb02b53 + 7496c3d commit 3634330

38 files changed

+1254
-426
lines changed

.vscode/settings.json

+2
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
"nolint",
3636
"nosec",
3737
"ntqry",
38+
"OIDC",
3839
"oneof",
3940
"parameterscopeid",
4041
"pqtype",
@@ -46,6 +47,7 @@
4647
"ptytest",
4748
"retrier",
4849
"sdkproto",
50+
"Signup",
4951
"stretchr",
5052
"TCGETS",
5153
"tcpip",

cli/cliflag/cliflag.go

+9
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"fmt"
1515
"os"
1616
"strconv"
17+
"strings"
1718

1819
"github.com/spf13/pflag"
1920
)
@@ -27,6 +28,14 @@ func StringVarP(flagset *pflag.FlagSet, p *string, name string, shorthand string
2728
flagset.StringVarP(p, name, shorthand, v, fmtUsage(usage, env))
2829
}
2930

31+
func StringArrayVarP(flagset *pflag.FlagSet, ptr *[]string, name string, shorthand string, env string, def []string, usage string) {
32+
val, ok := os.LookupEnv(env)
33+
if ok {
34+
def = strings.Split(val, ",")
35+
}
36+
flagset.StringArrayVarP(ptr, name, shorthand, def, usage)
37+
}
38+
3039
// Uint8VarP sets a uint8 flag on the given flag set.
3140
func Uint8VarP(flagset *pflag.FlagSet, ptr *uint8, name string, shorthand string, env string, def uint8, usage string) {
3241
val, ok := os.LookupEnv(env)

cli/cliflag/cliflag_test.go

+20
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,26 @@ func TestCliflag(t *testing.T) {
5454
require.NotContains(t, flagset.FlagUsages(), " - consumes")
5555
})
5656

57+
t.Run("StringArrayDefault", func(t *testing.T) {
58+
var ptr []string
59+
flagset, name, shorthand, env, usage := randomFlag()
60+
def := []string{"hello"}
61+
cliflag.StringArrayVarP(flagset, &ptr, name, shorthand, env, def, usage)
62+
got, err := flagset.GetStringArray(name)
63+
require.NoError(t, err)
64+
require.Equal(t, def, got)
65+
})
66+
67+
t.Run("StringArrayEnvVar", func(t *testing.T) {
68+
var ptr []string
69+
flagset, name, shorthand, env, usage := randomFlag()
70+
t.Setenv(env, "wow,test")
71+
cliflag.StringArrayVarP(flagset, &ptr, name, shorthand, env, nil, usage)
72+
got, err := flagset.GetStringArray(name)
73+
require.NoError(t, err)
74+
require.Equal(t, []string{"wow", "test"}, got)
75+
})
76+
5777
t.Run("IntDefault", func(t *testing.T) {
5878
var ptr uint8
5979
flagset, name, shorthand, env, usage := randomFlag()

cli/server.go

+71-13
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,11 @@ import (
1818

1919
"github.com/briandowns/spinner"
2020
"github.com/coreos/go-systemd/daemon"
21+
"github.com/google/go-github/v43/github"
2122
"github.com/pion/turn/v2"
2223
"github.com/spf13/cobra"
24+
"golang.org/x/oauth2"
25+
xgithub "golang.org/x/oauth2/github"
2326
"golang.org/x/xerrors"
2427
"google.golang.org/api/idtoken"
2528
"google.golang.org/api/option"
@@ -51,19 +54,23 @@ func server() *cobra.Command {
5154
dev bool
5255
postgresURL string
5356
// provisionerDaemonCount is a uint8 to ensure a number > 0.
54-
provisionerDaemonCount uint8
55-
tlsCertFile string
56-
tlsClientCAFile string
57-
tlsClientAuth string
58-
tlsEnable bool
59-
tlsKeyFile string
60-
tlsMinVersion string
61-
turnRelayAddress string
62-
skipTunnel bool
63-
traceDatadog bool
64-
secureAuthCookie bool
65-
sshKeygenAlgorithmRaw string
66-
spooky bool
57+
provisionerDaemonCount uint8
58+
oauth2GithubClientID string
59+
oauth2GithubClientSecret string
60+
oauth2GithubAllowedOrganizations []string
61+
oauth2GithubAllowSignups bool
62+
tlsCertFile string
63+
tlsClientCAFile string
64+
tlsClientAuth string
65+
tlsEnable bool
66+
tlsKeyFile string
67+
tlsMinVersion string
68+
turnRelayAddress string
69+
skipTunnel bool
70+
traceDatadog bool
71+
secureAuthCookie bool
72+
sshKeygenAlgorithmRaw string
73+
spooky bool
6774
)
6875

6976
root := &cobra.Command{
@@ -180,6 +187,13 @@ func server() *cobra.Command {
180187
TURNServer: turnServer,
181188
}
182189

190+
if oauth2GithubClientSecret != "" {
191+
options.GithubOAuth2Config, err = configureGithubOAuth2(accessURLParsed, oauth2GithubClientID, oauth2GithubClientSecret, oauth2GithubAllowSignups, oauth2GithubAllowedOrganizations)
192+
if err != nil {
193+
return xerrors.Errorf("configure github oauth2: %w", err)
194+
}
195+
}
196+
183197
_, _ = fmt.Fprintf(cmd.ErrOrStderr(), "access-url: %s\n", accessURL)
184198
_, _ = fmt.Fprintf(cmd.ErrOrStderr(), "provisioner-daemons: %d\n", provisionerDaemonCount)
185199
_, _ = fmt.Fprintln(cmd.ErrOrStderr())
@@ -373,6 +387,14 @@ func server() *cobra.Command {
373387
cliflag.BoolVarP(root.Flags(), &dev, "dev", "", "CODER_DEV_MODE", false, "Serve Coder in dev mode for tinkering")
374388
cliflag.StringVarP(root.Flags(), &postgresURL, "postgres-url", "", "CODER_PG_CONNECTION_URL", "", "URL of a PostgreSQL database to connect to")
375389
cliflag.Uint8VarP(root.Flags(), &provisionerDaemonCount, "provisioner-daemons", "", "CODER_PROVISIONER_DAEMONS", 1, "The amount of provisioner daemons to create on start.")
390+
cliflag.StringVarP(root.Flags(), &oauth2GithubClientID, "oauth2-github-client-id", "", "CODER_OAUTH2_GITHUB_CLIENT_ID", "",
391+
"Specifies a client ID to use for oauth2 with GitHub.")
392+
cliflag.StringVarP(root.Flags(), &oauth2GithubClientSecret, "oauth2-github-client-secret", "", "CODER_OAUTH2_GITHUB_CLIENT_SECRET", "",
393+
"Specifies a client secret to use for oauth2 with GitHub.")
394+
cliflag.StringArrayVarP(root.Flags(), &oauth2GithubAllowedOrganizations, "oauth2-github-allowed-orgs", "", "CODER_OAUTH2_GITHUB_ALLOWED_ORGS", nil,
395+
"Specifies organizations the user must be a member of to authenticate with GitHub.")
396+
cliflag.BoolVarP(root.Flags(), &oauth2GithubAllowSignups, "oauth2-github-allow-signups", "", "CODER_OAUTH2_GITHUB_ALLOW_SIGNUPS", false,
397+
"Specifies whether new users can sign up with GitHub.")
376398
cliflag.BoolVarP(root.Flags(), &tlsEnable, "tls-enable", "", "CODER_TLS_ENABLE", false, "Specifies if TLS will be enabled")
377399
cliflag.StringVarP(root.Flags(), &tlsCertFile, "tls-cert-file", "", "CODER_TLS_CERT_FILE", "",
378400
"Specifies the path to the certificate for TLS. It requires a PEM-encoded file. "+
@@ -572,6 +594,42 @@ func configureTLS(listener net.Listener, tlsMinVersion, tlsClientAuth, tlsCertFi
572594
return tls.NewListener(listener, tlsConfig), nil
573595
}
574596

597+
func configureGithubOAuth2(accessURL *url.URL, clientID, clientSecret string, allowSignups bool, allowOrgs []string) (*coderd.GithubOAuth2Config, error) {
598+
redirectURL, err := accessURL.Parse("/api/v2/users/oauth2/github/callback")
599+
if err != nil {
600+
return nil, xerrors.Errorf("parse github oauth callback url: %w", err)
601+
}
602+
return &coderd.GithubOAuth2Config{
603+
OAuth2Config: &oauth2.Config{
604+
ClientID: clientID,
605+
ClientSecret: clientSecret,
606+
Endpoint: xgithub.Endpoint,
607+
RedirectURL: redirectURL.String(),
608+
Scopes: []string{
609+
"read:user",
610+
"read:org",
611+
"user:email",
612+
},
613+
},
614+
AllowSignups: allowSignups,
615+
AllowOrganizations: allowOrgs,
616+
AuthenticatedUser: func(ctx context.Context, client *http.Client) (*github.User, error) {
617+
user, _, err := github.NewClient(client).Users.Get(ctx, "")
618+
return user, err
619+
},
620+
ListEmails: func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error) {
621+
emails, _, err := github.NewClient(client).Users.ListEmails(ctx, &github.ListOptions{})
622+
return emails, err
623+
},
624+
ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
625+
memberships, _, err := github.NewClient(client).Organizations.ListOrgMemberships(ctx, &github.ListOrgMembershipsOptions{
626+
State: "active",
627+
})
628+
return memberships, err
629+
},
630+
}, nil
631+
}
632+
575633
type datadogLogger struct {
576634
logger slog.Logger
577635
}

coderd/coderd.go

+21-10
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ type Options struct {
4242
AWSCertificates awsidentity.Certificates
4343
AzureCertificates x509.VerifyOptions
4444
GoogleTokenValidator *idtoken.Validator
45+
GithubOAuth2Config *GithubOAuth2Config
4546
ICEServers []webrtc.ICEServer
4647
SecureAuthCookie bool
4748
SSHKeygenAlgorithm gitsshkey.Algorithm
@@ -62,6 +63,9 @@ func New(options *Options) (http.Handler, func()) {
6263
api := &api{
6364
Options: options,
6465
}
66+
apiKeyMiddleware := httpmw.ExtractAPIKey(options.Database, &httpmw.OAuth2Configs{
67+
Github: options.GithubOAuth2Config,
68+
})
6569

6670
r := chi.NewRouter()
6771
r.Route("/api/v2", func(r chi.Router) {
@@ -86,7 +90,7 @@ func New(options *Options) (http.Handler, func()) {
8690
})
8791
r.Route("/files", func(r chi.Router) {
8892
r.Use(
89-
httpmw.ExtractAPIKey(options.Database, nil),
93+
apiKeyMiddleware,
9094
// This number is arbitrary, but reading/writing
9195
// file content is expensive so it should be small.
9296
httpmw.RateLimitPerMinute(12),
@@ -96,7 +100,7 @@ func New(options *Options) (http.Handler, func()) {
96100
})
97101
r.Route("/organizations/{organization}", func(r chi.Router) {
98102
r.Use(
99-
httpmw.ExtractAPIKey(options.Database, nil),
103+
apiKeyMiddleware,
100104
httpmw.ExtractOrganizationParam(options.Database),
101105
)
102106
r.Get("/", api.organization)
@@ -109,7 +113,7 @@ func New(options *Options) (http.Handler, func()) {
109113
})
110114
})
111115
r.Route("/parameters/{scope}/{id}", func(r chi.Router) {
112-
r.Use(httpmw.ExtractAPIKey(options.Database, nil))
116+
r.Use(apiKeyMiddleware)
113117
r.Post("/", api.postParameter)
114118
r.Get("/", api.parameters)
115119
r.Route("/{name}", func(r chi.Router) {
@@ -118,7 +122,7 @@ func New(options *Options) (http.Handler, func()) {
118122
})
119123
r.Route("/templates/{template}", func(r chi.Router) {
120124
r.Use(
121-
httpmw.ExtractAPIKey(options.Database, nil),
125+
apiKeyMiddleware,
122126
httpmw.ExtractTemplateParam(options.Database),
123127
httpmw.ExtractOrganizationParam(options.Database),
124128
)
@@ -132,7 +136,7 @@ func New(options *Options) (http.Handler, func()) {
132136
})
133137
r.Route("/templateversions/{templateversion}", func(r chi.Router) {
134138
r.Use(
135-
httpmw.ExtractAPIKey(options.Database, nil),
139+
apiKeyMiddleware,
136140
httpmw.ExtractTemplateVersionParam(options.Database),
137141
httpmw.ExtractOrganizationParam(options.Database),
138142
)
@@ -154,8 +158,15 @@ func New(options *Options) (http.Handler, func()) {
154158
r.Post("/first", api.postFirstUser)
155159
r.Post("/login", api.postLogin)
156160
r.Post("/logout", api.postLogout)
161+
r.Get("/authmethods", api.userAuthMethods)
162+
r.Route("/oauth2", func(r chi.Router) {
163+
r.Route("/github", func(r chi.Router) {
164+
r.Use(httpmw.ExtractOAuth2(options.GithubOAuth2Config))
165+
r.Get("/callback", api.userOAuth2Github)
166+
})
167+
})
157168
r.Group(func(r chi.Router) {
158-
r.Use(httpmw.ExtractAPIKey(options.Database, nil))
169+
r.Use(apiKeyMiddleware)
159170
r.Post("/", api.postUsers)
160171
r.Get("/", api.users)
161172
r.Route("/{user}", func(r chi.Router) {
@@ -193,7 +204,7 @@ func New(options *Options) (http.Handler, func()) {
193204
})
194205
r.Route("/{workspaceagent}", func(r chi.Router) {
195206
r.Use(
196-
httpmw.ExtractAPIKey(options.Database, nil),
207+
apiKeyMiddleware,
197208
httpmw.ExtractWorkspaceAgentParam(options.Database),
198209
)
199210
r.Get("/", api.workspaceAgent)
@@ -204,15 +215,15 @@ func New(options *Options) (http.Handler, func()) {
204215
})
205216
r.Route("/workspaceresources/{workspaceresource}", func(r chi.Router) {
206217
r.Use(
207-
httpmw.ExtractAPIKey(options.Database, nil),
218+
apiKeyMiddleware,
208219
httpmw.ExtractWorkspaceResourceParam(options.Database),
209220
httpmw.ExtractWorkspaceParam(options.Database),
210221
)
211222
r.Get("/", api.workspaceResource)
212223
})
213224
r.Route("/workspaces/{workspace}", func(r chi.Router) {
214225
r.Use(
215-
httpmw.ExtractAPIKey(options.Database, nil),
226+
apiKeyMiddleware,
216227
httpmw.ExtractWorkspaceParam(options.Database),
217228
)
218229
r.Get("/", api.workspace)
@@ -230,7 +241,7 @@ func New(options *Options) (http.Handler, func()) {
230241
})
231242
r.Route("/workspacebuilds/{workspacebuild}", func(r chi.Router) {
232243
r.Use(
233-
httpmw.ExtractAPIKey(options.Database, nil),
244+
apiKeyMiddleware,
234245
httpmw.ExtractWorkspaceBuildParam(options.Database),
235246
httpmw.ExtractWorkspaceParam(options.Database),
236247
)

coderd/coderdtest/coderdtest.go

+2
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ import (
5353
type Options struct {
5454
AWSCertificates awsidentity.Certificates
5555
AzureCertificates x509.VerifyOptions
56+
GithubOAuth2Config *coderd.GithubOAuth2Config
5657
GoogleTokenValidator *idtoken.Validator
5758
SSHKeygenAlgorithm gitsshkey.Algorithm
5859
APIRateLimit int
@@ -123,6 +124,7 @@ func New(t *testing.T, options *Options) *codersdk.Client {
123124

124125
AWSCertificates: options.AWSCertificates,
125126
AzureCertificates: options.AzureCertificates,
127+
GithubOAuth2Config: options.GithubOAuth2Config,
126128
GoogleTokenValidator: options.GoogleTokenValidator,
127129
SSHKeygenAlgorithm: options.SSHKeygenAlgorithm,
128130
TURNServer: turnServer,

coderd/database/databasefake/databasefake.go

+25-18
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,16 @@ func (q *fakeQuerier) GetWorkspacesByUserID(_ context.Context, req database.GetW
432432
return workspaces, nil
433433
}
434434

435+
func (q *fakeQuerier) GetOrganizations(_ context.Context) ([]database.Organization, error) {
436+
q.mutex.RLock()
437+
defer q.mutex.RUnlock()
438+
439+
if len(q.organizations) == 0 {
440+
return nil, sql.ErrNoRows
441+
}
442+
return q.organizations, nil
443+
}
444+
435445
func (q *fakeQuerier) GetOrganizationByID(_ context.Context, id uuid.UUID) (database.Organization, error) {
436446
q.mutex.RLock()
437447
defer q.mutex.RUnlock()
@@ -854,21 +864,18 @@ func (q *fakeQuerier) InsertAPIKey(_ context.Context, arg database.InsertAPIKeyP
854864

855865
//nolint:gosimple
856866
key := database.APIKey{
857-
ID: arg.ID,
858-
HashedSecret: arg.HashedSecret,
859-
UserID: arg.UserID,
860-
Application: arg.Application,
861-
Name: arg.Name,
862-
LastUsed: arg.LastUsed,
863-
ExpiresAt: arg.ExpiresAt,
864-
CreatedAt: arg.CreatedAt,
865-
UpdatedAt: arg.UpdatedAt,
866-
LoginType: arg.LoginType,
867-
OIDCAccessToken: arg.OIDCAccessToken,
868-
OIDCRefreshToken: arg.OIDCRefreshToken,
869-
OIDCIDToken: arg.OIDCIDToken,
870-
OIDCExpiry: arg.OIDCExpiry,
871-
DevurlToken: arg.DevurlToken,
867+
ID: arg.ID,
868+
HashedSecret: arg.HashedSecret,
869+
UserID: arg.UserID,
870+
ExpiresAt: arg.ExpiresAt,
871+
CreatedAt: arg.CreatedAt,
872+
UpdatedAt: arg.UpdatedAt,
873+
LastUsed: arg.LastUsed,
874+
LoginType: arg.LoginType,
875+
OAuthAccessToken: arg.OAuthAccessToken,
876+
OAuthRefreshToken: arg.OAuthRefreshToken,
877+
OAuthIDToken: arg.OAuthIDToken,
878+
OAuthExpiry: arg.OAuthExpiry,
872879
}
873880
q.apiKeys = append(q.apiKeys, key)
874881
return key, nil
@@ -1180,9 +1187,9 @@ func (q *fakeQuerier) UpdateAPIKeyByID(_ context.Context, arg database.UpdateAPI
11801187
}
11811188
apiKey.LastUsed = arg.LastUsed
11821189
apiKey.ExpiresAt = arg.ExpiresAt
1183-
apiKey.OIDCAccessToken = arg.OIDCAccessToken
1184-
apiKey.OIDCRefreshToken = arg.OIDCRefreshToken
1185-
apiKey.OIDCExpiry = arg.OIDCExpiry
1190+
apiKey.OAuthAccessToken = arg.OAuthAccessToken
1191+
apiKey.OAuthRefreshToken = arg.OAuthRefreshToken
1192+
apiKey.OAuthExpiry = arg.OAuthExpiry
11861193
q.apiKeys[index] = apiKey
11871194
return nil
11881195
}

0 commit comments

Comments
 (0)