Skip to content

feat: Add GitHub OAuth #1050

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Apr 23, 2022
2 changes: 2 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
"nolint",
"nosec",
"ntqry",
"OIDC",
"oneof",
"parameterscopeid",
"pqtype",
Expand All @@ -46,6 +47,7 @@
"ptytest",
"retrier",
"sdkproto",
"Signup",
"stretchr",
"TCGETS",
"tcpip",
Expand Down
9 changes: 9 additions & 0 deletions cli/cliflag/cliflag.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"fmt"
"os"
"strconv"
"strings"

"github.com/spf13/pflag"
)
Expand All @@ -27,6 +28,14 @@ func StringVarP(flagset *pflag.FlagSet, p *string, name string, shorthand string
flagset.StringVarP(p, name, shorthand, v, fmtUsage(usage, env))
}

func StringArrayVarP(flagset *pflag.FlagSet, ptr *[]string, name string, shorthand string, env string, def []string, usage string) {
val, ok := os.LookupEnv(env)
if ok {
def = strings.Split(val, ",")
}
flagset.StringArrayVarP(ptr, name, shorthand, def, usage)
}

// Uint8VarP sets a uint8 flag on the given flag set.
func Uint8VarP(flagset *pflag.FlagSet, ptr *uint8, name string, shorthand string, env string, def uint8, usage string) {
val, ok := os.LookupEnv(env)
Expand Down
20 changes: 20 additions & 0 deletions cli/cliflag/cliflag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,26 @@ func TestCliflag(t *testing.T) {
require.NotContains(t, flagset.FlagUsages(), " - consumes")
})

t.Run("StringArrayDefault", func(t *testing.T) {
var ptr []string
flagset, name, shorthand, env, usage := randomFlag()
def := []string{"hello"}
cliflag.StringArrayVarP(flagset, &ptr, name, shorthand, env, def, usage)
got, err := flagset.GetStringArray(name)
require.NoError(t, err)
require.Equal(t, def, got)
})

t.Run("StringArrayEnvVar", func(t *testing.T) {
var ptr []string
flagset, name, shorthand, env, usage := randomFlag()
t.Setenv(env, "wow,test")
cliflag.StringArrayVarP(flagset, &ptr, name, shorthand, env, nil, usage)
got, err := flagset.GetStringArray(name)
require.NoError(t, err)
require.Equal(t, []string{"wow", "test"}, got)
})

t.Run("IntDefault", func(t *testing.T) {
var ptr uint8
flagset, name, shorthand, env, usage := randomFlag()
Expand Down
80 changes: 69 additions & 11 deletions cli/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ import (

"github.com/briandowns/spinner"
"github.com/coreos/go-systemd/daemon"
"github.com/google/go-github/v43/github"
"github.com/spf13/cobra"
"golang.org/x/oauth2"
xgithub "golang.org/x/oauth2/github"
"golang.org/x/xerrors"
"google.golang.org/api/idtoken"
"google.golang.org/api/option"
Expand Down Expand Up @@ -49,17 +52,21 @@ func start() *cobra.Command {
dev bool
postgresURL string
// provisionerDaemonCount is a uint8 to ensure a number > 0.
provisionerDaemonCount uint8
tlsCertFile string
tlsClientCAFile string
tlsClientAuth string
tlsEnable bool
tlsKeyFile string
tlsMinVersion string
skipTunnel bool
traceDatadog bool
secureAuthCookie bool
sshKeygenAlgorithmRaw string
provisionerDaemonCount uint8
oauth2GithubClientID string
oauth2GithubClientSecret string
oauth2GithubAllowedOrganizations []string
oauth2GithubAllowSignups bool
tlsCertFile string
tlsClientCAFile string
tlsClientAuth string
tlsEnable bool
tlsKeyFile string
tlsMinVersion string
skipTunnel bool
traceDatadog bool
secureAuthCookie bool
sshKeygenAlgorithmRaw string
)
root := &cobra.Command{
Use: "start",
Expand Down Expand Up @@ -164,6 +171,13 @@ func start() *cobra.Command {
SSHKeygenAlgorithm: sshKeygenAlgorithm,
}

if oauth2GithubClientSecret != "" {
options.GithubOAuth2Config, err = configureGithubOAuth2(accessURLParsed, oauth2GithubClientID, oauth2GithubClientSecret, oauth2GithubAllowSignups, oauth2GithubAllowedOrganizations)
if err != nil {
return xerrors.Errorf("configure github oauth2: %w", err)
}
}

_, _ = fmt.Fprintf(cmd.ErrOrStderr(), "access-url: %s\n", accessURL)
_, _ = fmt.Fprintf(cmd.ErrOrStderr(), "provisioner-daemons: %d\n", provisionerDaemonCount)
_, _ = fmt.Fprintln(cmd.ErrOrStderr())
Expand Down Expand Up @@ -357,6 +371,14 @@ func start() *cobra.Command {
cliflag.BoolVarP(root.Flags(), &dev, "dev", "", "CODER_DEV_MODE", false, "Serve Coder in dev mode for tinkering")
cliflag.StringVarP(root.Flags(), &postgresURL, "postgres-url", "", "CODER_PG_CONNECTION_URL", "", "URL of a PostgreSQL database to connect to")
cliflag.Uint8VarP(root.Flags(), &provisionerDaemonCount, "provisioner-daemons", "", "CODER_PROVISIONER_DAEMONS", 1, "The amount of provisioner daemons to create on start.")
cliflag.StringVarP(root.Flags(), &oauth2GithubClientID, "oauth2-github-client-id", "", "CODER_OAUTH2_GITHUB_CLIENT_ID", "",
"Specifies a client ID to use for oauth2 with GitHub.")
cliflag.StringVarP(root.Flags(), &oauth2GithubClientSecret, "oauth2-github-client-secret", "", "CODER_OAUTH2_GITHUB_CLIENT_SECRET", "",
"Specifies a client secret to use for oauth2 with GitHub.")
cliflag.StringArrayVarP(root.Flags(), &oauth2GithubAllowedOrganizations, "oauth2-github-allowed-orgs", "", "CODER_OAUTH2_GITHUB_ALLOWED_ORGS", nil,
"Specifies organizations the user must be a member of to authenticate with GitHub.")
cliflag.BoolVarP(root.Flags(), &oauth2GithubAllowSignups, "oauth2-github-allow-signups", "", "CODER_OAUTH2_GITHUB_ALLOW_SIGNUPS", false,
"Specifies whether new users can sign up with GitHub.")
cliflag.BoolVarP(root.Flags(), &tlsEnable, "tls-enable", "", "CODER_TLS_ENABLE", false, "Specifies if TLS will be enabled")
cliflag.StringVarP(root.Flags(), &tlsCertFile, "tls-cert-file", "", "CODER_TLS_CERT_FILE", "",
"Specifies the path to the certificate for TLS. It requires a PEM-encoded file. "+
Expand Down Expand Up @@ -534,3 +556,39 @@ func configureTLS(listener net.Listener, tlsMinVersion, tlsClientAuth, tlsCertFi

return tls.NewListener(listener, tlsConfig), nil
}

func configureGithubOAuth2(accessURL *url.URL, clientID, clientSecret string, allowSignups bool, allowOrgs []string) (*coderd.GithubOAuth2Config, error) {
redirectURL, err := accessURL.Parse("/api/v2/users/oauth2/github/callback")
if err != nil {
return nil, xerrors.Errorf("parse github oauth callback url: %w", err)
}
return &coderd.GithubOAuth2Config{
OAuth2Config: &oauth2.Config{
ClientID: clientID,
ClientSecret: clientSecret,
Endpoint: xgithub.Endpoint,
RedirectURL: redirectURL.String(),
Scopes: []string{
"read:user",
"read:org",
"user:email",
},
},
AllowSignups: allowSignups,
AllowOrganizations: allowOrgs,
AuthenticatedUser: func(ctx context.Context, client *http.Client) (*github.User, error) {
user, _, err := github.NewClient(client).Users.Get(ctx, "")
return user, err
},
ListEmails: func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error) {
emails, _, err := github.NewClient(client).Users.ListEmails(ctx, &github.ListOptions{})
return emails, err
},
ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
memberships, _, err := github.NewClient(client).Organizations.ListOrgMemberships(ctx, &github.ListOrgMembershipsOptions{
State: "active",
})
return memberships, err
},
}, nil
}
31 changes: 21 additions & 10 deletions coderd/coderd.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ type Options struct {

AWSCertificates awsidentity.Certificates
GoogleTokenValidator *idtoken.Validator
GithubOAuth2Config *GithubOAuth2Config

SecureAuthCookie bool
SSHKeygenAlgorithm gitsshkey.Algorithm
Expand All @@ -50,6 +51,9 @@ func New(options *Options) (http.Handler, func()) {
api := &api{
Options: options,
}
apiKeyMiddleware := httpmw.ExtractAPIKey(options.Database, &httpmw.OAuth2Configs{
Github: options.GithubOAuth2Config,
})

r := chi.NewRouter()
r.Route("/api/v2", func(r chi.Router) {
Expand All @@ -74,7 +78,7 @@ func New(options *Options) (http.Handler, func()) {
})
r.Route("/files", func(r chi.Router) {
r.Use(
httpmw.ExtractAPIKey(options.Database, nil),
apiKeyMiddleware,
// This number is arbitrary, but reading/writing
// file content is expensive so it should be small.
httpmw.RateLimitPerMinute(12),
Expand All @@ -84,7 +88,7 @@ func New(options *Options) (http.Handler, func()) {
})
r.Route("/organizations/{organization}", func(r chi.Router) {
r.Use(
httpmw.ExtractAPIKey(options.Database, nil),
apiKeyMiddleware,
httpmw.ExtractOrganizationParam(options.Database),
)
r.Get("/", api.organization)
Expand All @@ -97,7 +101,7 @@ func New(options *Options) (http.Handler, func()) {
})
})
r.Route("/parameters/{scope}/{id}", func(r chi.Router) {
r.Use(httpmw.ExtractAPIKey(options.Database, nil))
r.Use(apiKeyMiddleware)
r.Post("/", api.postParameter)
r.Get("/", api.parameters)
r.Route("/{name}", func(r chi.Router) {
Expand All @@ -106,7 +110,7 @@ func New(options *Options) (http.Handler, func()) {
})
r.Route("/templates/{template}", func(r chi.Router) {
r.Use(
httpmw.ExtractAPIKey(options.Database, nil),
apiKeyMiddleware,
httpmw.ExtractTemplateParam(options.Database),
httpmw.ExtractOrganizationParam(options.Database),
)
Expand All @@ -120,7 +124,7 @@ func New(options *Options) (http.Handler, func()) {
})
r.Route("/templateversions/{templateversion}", func(r chi.Router) {
r.Use(
httpmw.ExtractAPIKey(options.Database, nil),
apiKeyMiddleware,
httpmw.ExtractTemplateVersionParam(options.Database),
httpmw.ExtractOrganizationParam(options.Database),
)
Expand All @@ -142,8 +146,15 @@ func New(options *Options) (http.Handler, func()) {
r.Post("/first", api.postFirstUser)
r.Post("/login", api.postLogin)
r.Post("/logout", api.postLogout)
r.Get("/authmethods", api.userAuthMethods)
r.Route("/oauth2", func(r chi.Router) {
r.Route("/github", func(r chi.Router) {
r.Use(httpmw.ExtractOAuth2(options.GithubOAuth2Config))
r.Get("/callback", api.userOAuth2Github)
})
})
r.Group(func(r chi.Router) {
r.Use(httpmw.ExtractAPIKey(options.Database, nil))
r.Use(apiKeyMiddleware)
r.Post("/", api.postUsers)
r.Route("/{user}", func(r chi.Router) {
r.Use(httpmw.ExtractUserParam(options.Database))
Expand Down Expand Up @@ -177,7 +188,7 @@ func New(options *Options) (http.Handler, func()) {
})
r.Route("/{workspaceagent}", func(r chi.Router) {
r.Use(
httpmw.ExtractAPIKey(options.Database, nil),
apiKeyMiddleware,
httpmw.ExtractWorkspaceAgentParam(options.Database),
)
r.Get("/", api.workspaceAgent)
Expand All @@ -186,15 +197,15 @@ func New(options *Options) (http.Handler, func()) {
})
r.Route("/workspaceresources/{workspaceresource}", func(r chi.Router) {
r.Use(
httpmw.ExtractAPIKey(options.Database, nil),
apiKeyMiddleware,
httpmw.ExtractWorkspaceResourceParam(options.Database),
httpmw.ExtractWorkspaceParam(options.Database),
)
r.Get("/", api.workspaceResource)
})
r.Route("/workspaces/{workspace}", func(r chi.Router) {
r.Use(
httpmw.ExtractAPIKey(options.Database, nil),
apiKeyMiddleware,
httpmw.ExtractWorkspaceParam(options.Database),
)
r.Get("/", api.workspace)
Expand All @@ -212,7 +223,7 @@ func New(options *Options) (http.Handler, func()) {
})
r.Route("/workspacebuilds/{workspacebuild}", func(r chi.Router) {
r.Use(
httpmw.ExtractAPIKey(options.Database, nil),
apiKeyMiddleware,
httpmw.ExtractWorkspaceBuildParam(options.Database),
httpmw.ExtractWorkspaceParam(options.Database),
)
Expand Down
2 changes: 2 additions & 0 deletions coderd/coderdtest/coderdtest.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ import (

type Options struct {
AWSInstanceIdentity awsidentity.Certificates
GithubOAuth2Config *coderd.GithubOAuth2Config
GoogleInstanceIdentity *idtoken.Validator
SSHKeygenAlgorithm gitsshkey.Algorithm
}
Expand Down Expand Up @@ -115,6 +116,7 @@ func New(t *testing.T, options *Options) *codersdk.Client {
Pubsub: pubsub,

AWSCertificates: options.AWSInstanceIdentity,
GithubOAuth2Config: options.GithubOAuth2Config,
GoogleTokenValidator: options.GoogleInstanceIdentity,
SSHKeygenAlgorithm: options.SSHKeygenAlgorithm,
})
Expand Down
43 changes: 25 additions & 18 deletions coderd/database/databasefake/databasefake.go
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,16 @@ func (q *fakeQuerier) GetWorkspacesByUserID(_ context.Context, req database.GetW
return workspaces, nil
}

func (q *fakeQuerier) GetOrganizations(_ context.Context) ([]database.Organization, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()

if len(q.organizations) == 0 {
return nil, sql.ErrNoRows
}
return q.organizations, nil
}

func (q *fakeQuerier) GetOrganizationByID(_ context.Context, id uuid.UUID) (database.Organization, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
Expand Down Expand Up @@ -787,21 +797,18 @@ func (q *fakeQuerier) InsertAPIKey(_ context.Context, arg database.InsertAPIKeyP

//nolint:gosimple
key := database.APIKey{
ID: arg.ID,
HashedSecret: arg.HashedSecret,
UserID: arg.UserID,
Application: arg.Application,
Name: arg.Name,
LastUsed: arg.LastUsed,
ExpiresAt: arg.ExpiresAt,
CreatedAt: arg.CreatedAt,
UpdatedAt: arg.UpdatedAt,
LoginType: arg.LoginType,
OIDCAccessToken: arg.OIDCAccessToken,
OIDCRefreshToken: arg.OIDCRefreshToken,
OIDCIDToken: arg.OIDCIDToken,
OIDCExpiry: arg.OIDCExpiry,
DevurlToken: arg.DevurlToken,
ID: arg.ID,
HashedSecret: arg.HashedSecret,
UserID: arg.UserID,
ExpiresAt: arg.ExpiresAt,
CreatedAt: arg.CreatedAt,
UpdatedAt: arg.UpdatedAt,
LastUsed: arg.LastUsed,
LoginType: arg.LoginType,
OAuthAccessToken: arg.OAuthAccessToken,
OAuthRefreshToken: arg.OAuthRefreshToken,
OAuthIDToken: arg.OAuthIDToken,
OAuthExpiry: arg.OAuthExpiry,
}
q.apiKeys = append(q.apiKeys, key)
return key, nil
Expand Down Expand Up @@ -1116,9 +1123,9 @@ func (q *fakeQuerier) UpdateAPIKeyByID(_ context.Context, arg database.UpdateAPI
}
apiKey.LastUsed = arg.LastUsed
apiKey.ExpiresAt = arg.ExpiresAt
apiKey.OIDCAccessToken = arg.OIDCAccessToken
apiKey.OIDCRefreshToken = arg.OIDCRefreshToken
apiKey.OIDCExpiry = arg.OIDCExpiry
apiKey.OAuthAccessToken = arg.OAuthAccessToken
apiKey.OAuthRefreshToken = arg.OAuthRefreshToken
apiKey.OAuthExpiry = arg.OAuthExpiry
q.apiKeys[index] = apiKey
return nil
}
Expand Down
Loading