diff --git a/.vscode/settings.json b/.vscode/settings.json index 160a8cd98a01e..771981cf7a686 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -35,6 +35,7 @@ "nolint", "nosec", "ntqry", + "OIDC", "oneof", "parameterscopeid", "pqtype", @@ -46,6 +47,7 @@ "ptytest", "retrier", "sdkproto", + "Signup", "stretchr", "TCGETS", "tcpip", diff --git a/cli/cliflag/cliflag.go b/cli/cliflag/cliflag.go index e846d5fc391ae..be2117b4d44fa 100644 --- a/cli/cliflag/cliflag.go +++ b/cli/cliflag/cliflag.go @@ -14,6 +14,7 @@ import ( "fmt" "os" "strconv" + "strings" "github.com/spf13/pflag" ) @@ -27,6 +28,14 @@ func StringVarP(flagset *pflag.FlagSet, p *string, name string, shorthand string flagset.StringVarP(p, name, shorthand, v, fmtUsage(usage, env)) } +func StringArrayVarP(flagset *pflag.FlagSet, ptr *[]string, name string, shorthand string, env string, def []string, usage string) { + val, ok := os.LookupEnv(env) + if ok { + def = strings.Split(val, ",") + } + flagset.StringArrayVarP(ptr, name, shorthand, def, usage) +} + // Uint8VarP sets a uint8 flag on the given flag set. func Uint8VarP(flagset *pflag.FlagSet, ptr *uint8, name string, shorthand string, env string, def uint8, usage string) { val, ok := os.LookupEnv(env) diff --git a/cli/cliflag/cliflag_test.go b/cli/cliflag/cliflag_test.go index 2228b7e10bbc9..b0684fedb1d98 100644 --- a/cli/cliflag/cliflag_test.go +++ b/cli/cliflag/cliflag_test.go @@ -54,6 +54,26 @@ func TestCliflag(t *testing.T) { require.NotContains(t, flagset.FlagUsages(), " - consumes") }) + t.Run("StringArrayDefault", func(t *testing.T) { + var ptr []string + flagset, name, shorthand, env, usage := randomFlag() + def := []string{"hello"} + cliflag.StringArrayVarP(flagset, &ptr, name, shorthand, env, def, usage) + got, err := flagset.GetStringArray(name) + require.NoError(t, err) + require.Equal(t, def, got) + }) + + t.Run("StringArrayEnvVar", func(t *testing.T) { + var ptr []string + flagset, name, shorthand, env, usage := randomFlag() + t.Setenv(env, "wow,test") + cliflag.StringArrayVarP(flagset, &ptr, name, shorthand, env, nil, usage) + got, err := flagset.GetStringArray(name) + require.NoError(t, err) + require.Equal(t, []string{"wow", "test"}, got) + }) + t.Run("IntDefault", func(t *testing.T) { var ptr uint8 flagset, name, shorthand, env, usage := randomFlag() diff --git a/cli/server.go b/cli/server.go index f442fda0b9992..a2e858b9f463e 100644 --- a/cli/server.go +++ b/cli/server.go @@ -18,8 +18,11 @@ import ( "github.com/briandowns/spinner" "github.com/coreos/go-systemd/daemon" + "github.com/google/go-github/v43/github" "github.com/pion/turn/v2" "github.com/spf13/cobra" + "golang.org/x/oauth2" + xgithub "golang.org/x/oauth2/github" "golang.org/x/xerrors" "google.golang.org/api/idtoken" "google.golang.org/api/option" @@ -51,19 +54,23 @@ func server() *cobra.Command { dev bool postgresURL string // provisionerDaemonCount is a uint8 to ensure a number > 0. - provisionerDaemonCount uint8 - tlsCertFile string - tlsClientCAFile string - tlsClientAuth string - tlsEnable bool - tlsKeyFile string - tlsMinVersion string - turnRelayAddress string - skipTunnel bool - traceDatadog bool - secureAuthCookie bool - sshKeygenAlgorithmRaw string - spooky bool + provisionerDaemonCount uint8 + oauth2GithubClientID string + oauth2GithubClientSecret string + oauth2GithubAllowedOrganizations []string + oauth2GithubAllowSignups bool + tlsCertFile string + tlsClientCAFile string + tlsClientAuth string + tlsEnable bool + tlsKeyFile string + tlsMinVersion string + turnRelayAddress string + skipTunnel bool + traceDatadog bool + secureAuthCookie bool + sshKeygenAlgorithmRaw string + spooky bool ) root := &cobra.Command{ @@ -180,6 +187,13 @@ func server() *cobra.Command { TURNServer: turnServer, } + if oauth2GithubClientSecret != "" { + options.GithubOAuth2Config, err = configureGithubOAuth2(accessURLParsed, oauth2GithubClientID, oauth2GithubClientSecret, oauth2GithubAllowSignups, oauth2GithubAllowedOrganizations) + if err != nil { + return xerrors.Errorf("configure github oauth2: %w", err) + } + } + _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "access-url: %s\n", accessURL) _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "provisioner-daemons: %d\n", provisionerDaemonCount) _, _ = fmt.Fprintln(cmd.ErrOrStderr()) @@ -373,6 +387,14 @@ func server() *cobra.Command { cliflag.BoolVarP(root.Flags(), &dev, "dev", "", "CODER_DEV_MODE", false, "Serve Coder in dev mode for tinkering") cliflag.StringVarP(root.Flags(), &postgresURL, "postgres-url", "", "CODER_PG_CONNECTION_URL", "", "URL of a PostgreSQL database to connect to") cliflag.Uint8VarP(root.Flags(), &provisionerDaemonCount, "provisioner-daemons", "", "CODER_PROVISIONER_DAEMONS", 1, "The amount of provisioner daemons to create on start.") + cliflag.StringVarP(root.Flags(), &oauth2GithubClientID, "oauth2-github-client-id", "", "CODER_OAUTH2_GITHUB_CLIENT_ID", "", + "Specifies a client ID to use for oauth2 with GitHub.") + cliflag.StringVarP(root.Flags(), &oauth2GithubClientSecret, "oauth2-github-client-secret", "", "CODER_OAUTH2_GITHUB_CLIENT_SECRET", "", + "Specifies a client secret to use for oauth2 with GitHub.") + cliflag.StringArrayVarP(root.Flags(), &oauth2GithubAllowedOrganizations, "oauth2-github-allowed-orgs", "", "CODER_OAUTH2_GITHUB_ALLOWED_ORGS", nil, + "Specifies organizations the user must be a member of to authenticate with GitHub.") + cliflag.BoolVarP(root.Flags(), &oauth2GithubAllowSignups, "oauth2-github-allow-signups", "", "CODER_OAUTH2_GITHUB_ALLOW_SIGNUPS", false, + "Specifies whether new users can sign up with GitHub.") cliflag.BoolVarP(root.Flags(), &tlsEnable, "tls-enable", "", "CODER_TLS_ENABLE", false, "Specifies if TLS will be enabled") cliflag.StringVarP(root.Flags(), &tlsCertFile, "tls-cert-file", "", "CODER_TLS_CERT_FILE", "", "Specifies the path to the certificate for TLS. It requires a PEM-encoded file. "+ @@ -572,6 +594,42 @@ func configureTLS(listener net.Listener, tlsMinVersion, tlsClientAuth, tlsCertFi return tls.NewListener(listener, tlsConfig), nil } +func configureGithubOAuth2(accessURL *url.URL, clientID, clientSecret string, allowSignups bool, allowOrgs []string) (*coderd.GithubOAuth2Config, error) { + redirectURL, err := accessURL.Parse("/api/v2/users/oauth2/github/callback") + if err != nil { + return nil, xerrors.Errorf("parse github oauth callback url: %w", err) + } + return &coderd.GithubOAuth2Config{ + OAuth2Config: &oauth2.Config{ + ClientID: clientID, + ClientSecret: clientSecret, + Endpoint: xgithub.Endpoint, + RedirectURL: redirectURL.String(), + Scopes: []string{ + "read:user", + "read:org", + "user:email", + }, + }, + AllowSignups: allowSignups, + AllowOrganizations: allowOrgs, + AuthenticatedUser: func(ctx context.Context, client *http.Client) (*github.User, error) { + user, _, err := github.NewClient(client).Users.Get(ctx, "") + return user, err + }, + ListEmails: func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error) { + emails, _, err := github.NewClient(client).Users.ListEmails(ctx, &github.ListOptions{}) + return emails, err + }, + ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) { + memberships, _, err := github.NewClient(client).Organizations.ListOrgMemberships(ctx, &github.ListOrgMembershipsOptions{ + State: "active", + }) + return memberships, err + }, + }, nil +} + type datadogLogger struct { logger slog.Logger } diff --git a/coderd/coderd.go b/coderd/coderd.go index f69ad623d79b0..2f8bde36e4638 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -42,6 +42,7 @@ type Options struct { AWSCertificates awsidentity.Certificates AzureCertificates x509.VerifyOptions GoogleTokenValidator *idtoken.Validator + GithubOAuth2Config *GithubOAuth2Config ICEServers []webrtc.ICEServer SecureAuthCookie bool SSHKeygenAlgorithm gitsshkey.Algorithm @@ -62,6 +63,9 @@ func New(options *Options) (http.Handler, func()) { api := &api{ Options: options, } + apiKeyMiddleware := httpmw.ExtractAPIKey(options.Database, &httpmw.OAuth2Configs{ + Github: options.GithubOAuth2Config, + }) r := chi.NewRouter() r.Route("/api/v2", func(r chi.Router) { @@ -86,7 +90,7 @@ func New(options *Options) (http.Handler, func()) { }) r.Route("/files", func(r chi.Router) { r.Use( - httpmw.ExtractAPIKey(options.Database, nil), + apiKeyMiddleware, // This number is arbitrary, but reading/writing // file content is expensive so it should be small. httpmw.RateLimitPerMinute(12), @@ -96,7 +100,7 @@ func New(options *Options) (http.Handler, func()) { }) r.Route("/organizations/{organization}", func(r chi.Router) { r.Use( - httpmw.ExtractAPIKey(options.Database, nil), + apiKeyMiddleware, httpmw.ExtractOrganizationParam(options.Database), ) r.Get("/", api.organization) @@ -109,7 +113,7 @@ func New(options *Options) (http.Handler, func()) { }) }) r.Route("/parameters/{scope}/{id}", func(r chi.Router) { - r.Use(httpmw.ExtractAPIKey(options.Database, nil)) + r.Use(apiKeyMiddleware) r.Post("/", api.postParameter) r.Get("/", api.parameters) r.Route("/{name}", func(r chi.Router) { @@ -118,7 +122,7 @@ func New(options *Options) (http.Handler, func()) { }) r.Route("/templates/{template}", func(r chi.Router) { r.Use( - httpmw.ExtractAPIKey(options.Database, nil), + apiKeyMiddleware, httpmw.ExtractTemplateParam(options.Database), httpmw.ExtractOrganizationParam(options.Database), ) @@ -132,7 +136,7 @@ func New(options *Options) (http.Handler, func()) { }) r.Route("/templateversions/{templateversion}", func(r chi.Router) { r.Use( - httpmw.ExtractAPIKey(options.Database, nil), + apiKeyMiddleware, httpmw.ExtractTemplateVersionParam(options.Database), httpmw.ExtractOrganizationParam(options.Database), ) @@ -154,8 +158,15 @@ func New(options *Options) (http.Handler, func()) { r.Post("/first", api.postFirstUser) r.Post("/login", api.postLogin) r.Post("/logout", api.postLogout) + r.Get("/authmethods", api.userAuthMethods) + r.Route("/oauth2", func(r chi.Router) { + r.Route("/github", func(r chi.Router) { + r.Use(httpmw.ExtractOAuth2(options.GithubOAuth2Config)) + r.Get("/callback", api.userOAuth2Github) + }) + }) r.Group(func(r chi.Router) { - r.Use(httpmw.ExtractAPIKey(options.Database, nil)) + r.Use(apiKeyMiddleware) r.Post("/", api.postUsers) r.Get("/", api.users) r.Route("/{user}", func(r chi.Router) { @@ -193,7 +204,7 @@ func New(options *Options) (http.Handler, func()) { }) r.Route("/{workspaceagent}", func(r chi.Router) { r.Use( - httpmw.ExtractAPIKey(options.Database, nil), + apiKeyMiddleware, httpmw.ExtractWorkspaceAgentParam(options.Database), ) r.Get("/", api.workspaceAgent) @@ -204,7 +215,7 @@ func New(options *Options) (http.Handler, func()) { }) r.Route("/workspaceresources/{workspaceresource}", func(r chi.Router) { r.Use( - httpmw.ExtractAPIKey(options.Database, nil), + apiKeyMiddleware, httpmw.ExtractWorkspaceResourceParam(options.Database), httpmw.ExtractWorkspaceParam(options.Database), ) @@ -212,7 +223,7 @@ func New(options *Options) (http.Handler, func()) { }) r.Route("/workspaces/{workspace}", func(r chi.Router) { r.Use( - httpmw.ExtractAPIKey(options.Database, nil), + apiKeyMiddleware, httpmw.ExtractWorkspaceParam(options.Database), ) r.Get("/", api.workspace) @@ -230,7 +241,7 @@ func New(options *Options) (http.Handler, func()) { }) r.Route("/workspacebuilds/{workspacebuild}", func(r chi.Router) { r.Use( - httpmw.ExtractAPIKey(options.Database, nil), + apiKeyMiddleware, httpmw.ExtractWorkspaceBuildParam(options.Database), httpmw.ExtractWorkspaceParam(options.Database), ) diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index c31b7585f56df..ab9db83d81871 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -53,6 +53,7 @@ import ( type Options struct { AWSCertificates awsidentity.Certificates AzureCertificates x509.VerifyOptions + GithubOAuth2Config *coderd.GithubOAuth2Config GoogleTokenValidator *idtoken.Validator SSHKeygenAlgorithm gitsshkey.Algorithm APIRateLimit int @@ -123,6 +124,7 @@ func New(t *testing.T, options *Options) *codersdk.Client { AWSCertificates: options.AWSCertificates, AzureCertificates: options.AzureCertificates, + GithubOAuth2Config: options.GithubOAuth2Config, GoogleTokenValidator: options.GoogleTokenValidator, SSHKeygenAlgorithm: options.SSHKeygenAlgorithm, TURNServer: turnServer, diff --git a/coderd/database/databasefake/databasefake.go b/coderd/database/databasefake/databasefake.go index da9c750bf3609..0f521117439b4 100644 --- a/coderd/database/databasefake/databasefake.go +++ b/coderd/database/databasefake/databasefake.go @@ -434,6 +434,16 @@ func (q *fakeQuerier) GetWorkspacesByUserID(_ context.Context, req database.GetW return workspaces, nil } +func (q *fakeQuerier) GetOrganizations(_ context.Context) ([]database.Organization, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + if len(q.organizations) == 0 { + return nil, sql.ErrNoRows + } + return q.organizations, nil +} + func (q *fakeQuerier) GetOrganizationByID(_ context.Context, id uuid.UUID) (database.Organization, error) { q.mutex.RLock() defer q.mutex.RUnlock() @@ -856,21 +866,18 @@ func (q *fakeQuerier) InsertAPIKey(_ context.Context, arg database.InsertAPIKeyP //nolint:gosimple key := database.APIKey{ - ID: arg.ID, - HashedSecret: arg.HashedSecret, - UserID: arg.UserID, - Application: arg.Application, - Name: arg.Name, - LastUsed: arg.LastUsed, - ExpiresAt: arg.ExpiresAt, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - LoginType: arg.LoginType, - OIDCAccessToken: arg.OIDCAccessToken, - OIDCRefreshToken: arg.OIDCRefreshToken, - OIDCIDToken: arg.OIDCIDToken, - OIDCExpiry: arg.OIDCExpiry, - DevurlToken: arg.DevurlToken, + ID: arg.ID, + HashedSecret: arg.HashedSecret, + UserID: arg.UserID, + ExpiresAt: arg.ExpiresAt, + CreatedAt: arg.CreatedAt, + UpdatedAt: arg.UpdatedAt, + LastUsed: arg.LastUsed, + LoginType: arg.LoginType, + OAuthAccessToken: arg.OAuthAccessToken, + OAuthRefreshToken: arg.OAuthRefreshToken, + OAuthIDToken: arg.OAuthIDToken, + OAuthExpiry: arg.OAuthExpiry, } q.apiKeys = append(q.apiKeys, key) return key, nil @@ -1185,9 +1192,9 @@ func (q *fakeQuerier) UpdateAPIKeyByID(_ context.Context, arg database.UpdateAPI } apiKey.LastUsed = arg.LastUsed apiKey.ExpiresAt = arg.ExpiresAt - apiKey.OIDCAccessToken = arg.OIDCAccessToken - apiKey.OIDCRefreshToken = arg.OIDCRefreshToken - apiKey.OIDCExpiry = arg.OIDCExpiry + apiKey.OAuthAccessToken = arg.OAuthAccessToken + apiKey.OAuthRefreshToken = arg.OAuthRefreshToken + apiKey.OAuthExpiry = arg.OAuthExpiry q.apiKeys[index] = apiKey return nil } diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index fb8621e2f2f3d..0319082ec4ab2 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -14,9 +14,8 @@ CREATE TYPE log_source AS ENUM ( ); CREATE TYPE login_type AS ENUM ( - 'built-in', - 'saml', - 'oidc' + 'password', + 'github' ); CREATE TYPE parameter_destination_scheme AS ENUM ( @@ -67,18 +66,15 @@ CREATE TABLE api_keys ( id text NOT NULL, hashed_secret bytea NOT NULL, user_id uuid NOT NULL, - application boolean NOT NULL, - name text NOT NULL, last_used timestamp with time zone NOT NULL, expires_at timestamp with time zone NOT NULL, created_at timestamp with time zone NOT NULL, updated_at timestamp with time zone NOT NULL, login_type login_type NOT NULL, - oidc_access_token text DEFAULT ''::text NOT NULL, - oidc_refresh_token text DEFAULT ''::text NOT NULL, - oidc_id_token text DEFAULT ''::text NOT NULL, - oidc_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL, - devurl_token boolean DEFAULT false NOT NULL + oauth_access_token text DEFAULT ''::text NOT NULL, + oauth_refresh_token text DEFAULT ''::text NOT NULL, + oauth_id_token text DEFAULT ''::text NOT NULL, + oauth_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL ); CREATE TABLE files ( diff --git a/coderd/database/migrations/000001_base.up.sql b/coderd/database/migrations/000001_base.up.sql index 65fbbf8fd4805..81cd3a4f75b3e 100644 --- a/coderd/database/migrations/000001_base.up.sql +++ b/coderd/database/migrations/000001_base.up.sql @@ -4,14 +4,9 @@ -- All tables and types are stolen from: -- https://github.com/coder/m/blob/47b6fc383347b9f9fab424d829c482defd3e1fe2/product/coder/pkg/database/dump.sql --- --- Name: users; Type: TABLE; Schema: public; Owner: coder --- - CREATE TYPE login_type AS ENUM ( - 'built-in', - 'saml', - 'oidc' + 'password', + 'github' ); CREATE TABLE IF NOT EXISTS users ( @@ -31,10 +26,6 @@ CREATE UNIQUE INDEX IF NOT EXISTS idx_users_email ON users USING btree (email); CREATE UNIQUE INDEX IF NOT EXISTS idx_users_username ON users USING btree (username); CREATE UNIQUE INDEX IF NOT EXISTS users_username_lower_idx ON users USING btree (lower(username)); --- --- Name: organizations; Type: TABLE; Schema: Owner: coder --- - CREATE TABLE IF NOT EXISTS organizations ( id uuid NOT NULL, name text NOT NULL, @@ -68,18 +59,15 @@ CREATE TABLE IF NOT EXISTS api_keys ( id text NOT NULL, hashed_secret bytea NOT NULL, user_id uuid NOT NULL, - application boolean NOT NULL, - name text NOT NULL, last_used timestamp with time zone NOT NULL, expires_at timestamp with time zone NOT NULL, created_at timestamp with time zone NOT NULL, updated_at timestamp with time zone NOT NULL, login_type login_type NOT NULL, - oidc_access_token text DEFAULT ''::text NOT NULL, - oidc_refresh_token text DEFAULT ''::text NOT NULL, - oidc_id_token text DEFAULT ''::text NOT NULL, - oidc_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL, - devurl_token boolean DEFAULT false NOT NULL, + oauth_access_token text DEFAULT ''::text NOT NULL, + oauth_refresh_token text DEFAULT ''::text NOT NULL, + oauth_id_token text DEFAULT ''::text NOT NULL, + oauth_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL, PRIMARY KEY (id) ); diff --git a/coderd/database/models.go b/coderd/database/models.go index a8d311194139e..2857bf391e4ea 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -56,9 +56,8 @@ func (e *LogSource) Scan(src interface{}) error { type LoginType string const ( - LoginTypeBuiltIn LoginType = "built-in" - LoginTypeSaml LoginType = "saml" - LoginTypeOIDC LoginType = "oidc" + LoginTypePassword LoginType = "password" + LoginTypeGithub LoginType = "github" ) func (e *LoginType) Scan(src interface{}) error { @@ -230,21 +229,18 @@ func (e *WorkspaceTransition) Scan(src interface{}) error { } type APIKey struct { - ID string `db:"id" json:"id"` - HashedSecret []byte `db:"hashed_secret" json:"hashed_secret"` - UserID uuid.UUID `db:"user_id" json:"user_id"` - Application bool `db:"application" json:"application"` - Name string `db:"name" json:"name"` - LastUsed time.Time `db:"last_used" json:"last_used"` - ExpiresAt time.Time `db:"expires_at" json:"expires_at"` - CreatedAt time.Time `db:"created_at" json:"created_at"` - UpdatedAt time.Time `db:"updated_at" json:"updated_at"` - LoginType LoginType `db:"login_type" json:"login_type"` - OIDCAccessToken string `db:"oidc_access_token" json:"oidc_access_token"` - OIDCRefreshToken string `db:"oidc_refresh_token" json:"oidc_refresh_token"` - OIDCIDToken string `db:"oidc_id_token" json:"oidc_id_token"` - OIDCExpiry time.Time `db:"oidc_expiry" json:"oidc_expiry"` - DevurlToken bool `db:"devurl_token" json:"devurl_token"` + ID string `db:"id" json:"id"` + HashedSecret []byte `db:"hashed_secret" json:"hashed_secret"` + UserID uuid.UUID `db:"user_id" json:"user_id"` + LastUsed time.Time `db:"last_used" json:"last_used"` + ExpiresAt time.Time `db:"expires_at" json:"expires_at"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + LoginType LoginType `db:"login_type" json:"login_type"` + OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"` + OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"` + OAuthIDToken string `db:"oauth_id_token" json:"oauth_id_token"` + OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"` } type File struct { diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 4d0de28747589..3b8f317b620b3 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -18,6 +18,7 @@ type querier interface { GetOrganizationByID(ctx context.Context, id uuid.UUID) (Organization, error) GetOrganizationByName(ctx context.Context, name string) (Organization, error) GetOrganizationMemberByUserID(ctx context.Context, arg GetOrganizationMemberByUserIDParams) (OrganizationMember, error) + GetOrganizations(ctx context.Context) ([]Organization, error) GetOrganizationsByUserID(ctx context.Context, userID uuid.UUID) ([]Organization, error) GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]ParameterSchema, error) GetParameterValueByScopeAndName(ctx context.Context, arg GetParameterValueByScopeAndNameParams) (ParameterValue, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 3fe7595e441d4..c1e7025b34936 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -15,7 +15,7 @@ import ( const getAPIKeyByID = `-- name: GetAPIKeyByID :one SELECT - id, hashed_secret, user_id, application, name, last_used, expires_at, created_at, updated_at, login_type, oidc_access_token, oidc_refresh_token, oidc_id_token, oidc_expiry, devurl_token + id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, oauth_access_token, oauth_refresh_token, oauth_id_token, oauth_expiry FROM api_keys WHERE @@ -31,18 +31,15 @@ func (q *sqlQuerier) GetAPIKeyByID(ctx context.Context, id string) (APIKey, erro &i.ID, &i.HashedSecret, &i.UserID, - &i.Application, - &i.Name, &i.LastUsed, &i.ExpiresAt, &i.CreatedAt, &i.UpdatedAt, &i.LoginType, - &i.OIDCAccessToken, - &i.OIDCRefreshToken, - &i.OIDCIDToken, - &i.OIDCExpiry, - &i.DevurlToken, + &i.OAuthAccessToken, + &i.OAuthRefreshToken, + &i.OAuthIDToken, + &i.OAuthExpiry, ) return i, err } @@ -53,55 +50,33 @@ INSERT INTO id, hashed_secret, user_id, - application, - "name", last_used, expires_at, created_at, updated_at, login_type, - oidc_access_token, - oidc_refresh_token, - oidc_id_token, - oidc_expiry, - devurl_token + oauth_access_token, + oauth_refresh_token, + oauth_id_token, + oauth_expiry ) VALUES - ( - $1, - $2, - $3, - $4, - $5, - $6, - $7, - $8, - $9, - $10, - $11, - $12, - $13, - $14, - $15 - ) RETURNING id, hashed_secret, user_id, application, name, last_used, expires_at, created_at, updated_at, login_type, oidc_access_token, oidc_refresh_token, oidc_id_token, oidc_expiry, devurl_token + ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) RETURNING id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, oauth_access_token, oauth_refresh_token, oauth_id_token, oauth_expiry ` type InsertAPIKeyParams struct { - ID string `db:"id" json:"id"` - HashedSecret []byte `db:"hashed_secret" json:"hashed_secret"` - UserID uuid.UUID `db:"user_id" json:"user_id"` - Application bool `db:"application" json:"application"` - Name string `db:"name" json:"name"` - LastUsed time.Time `db:"last_used" json:"last_used"` - ExpiresAt time.Time `db:"expires_at" json:"expires_at"` - CreatedAt time.Time `db:"created_at" json:"created_at"` - UpdatedAt time.Time `db:"updated_at" json:"updated_at"` - LoginType LoginType `db:"login_type" json:"login_type"` - OIDCAccessToken string `db:"oidc_access_token" json:"oidc_access_token"` - OIDCRefreshToken string `db:"oidc_refresh_token" json:"oidc_refresh_token"` - OIDCIDToken string `db:"oidc_id_token" json:"oidc_id_token"` - OIDCExpiry time.Time `db:"oidc_expiry" json:"oidc_expiry"` - DevurlToken bool `db:"devurl_token" json:"devurl_token"` + ID string `db:"id" json:"id"` + HashedSecret []byte `db:"hashed_secret" json:"hashed_secret"` + UserID uuid.UUID `db:"user_id" json:"user_id"` + LastUsed time.Time `db:"last_used" json:"last_used"` + ExpiresAt time.Time `db:"expires_at" json:"expires_at"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + LoginType LoginType `db:"login_type" json:"login_type"` + OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"` + OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"` + OAuthIDToken string `db:"oauth_id_token" json:"oauth_id_token"` + OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"` } func (q *sqlQuerier) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (APIKey, error) { @@ -109,36 +84,30 @@ func (q *sqlQuerier) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) ( arg.ID, arg.HashedSecret, arg.UserID, - arg.Application, - arg.Name, arg.LastUsed, arg.ExpiresAt, arg.CreatedAt, arg.UpdatedAt, arg.LoginType, - arg.OIDCAccessToken, - arg.OIDCRefreshToken, - arg.OIDCIDToken, - arg.OIDCExpiry, - arg.DevurlToken, + arg.OAuthAccessToken, + arg.OAuthRefreshToken, + arg.OAuthIDToken, + arg.OAuthExpiry, ) var i APIKey err := row.Scan( &i.ID, &i.HashedSecret, &i.UserID, - &i.Application, - &i.Name, &i.LastUsed, &i.ExpiresAt, &i.CreatedAt, &i.UpdatedAt, &i.LoginType, - &i.OIDCAccessToken, - &i.OIDCRefreshToken, - &i.OIDCIDToken, - &i.OIDCExpiry, - &i.DevurlToken, + &i.OAuthAccessToken, + &i.OAuthRefreshToken, + &i.OAuthIDToken, + &i.OAuthExpiry, ) return i, err } @@ -149,20 +118,20 @@ UPDATE SET last_used = $2, expires_at = $3, - oidc_access_token = $4, - oidc_refresh_token = $5, - oidc_expiry = $6 + oauth_access_token = $4, + oauth_refresh_token = $5, + oauth_expiry = $6 WHERE id = $1 ` type UpdateAPIKeyByIDParams struct { - ID string `db:"id" json:"id"` - LastUsed time.Time `db:"last_used" json:"last_used"` - ExpiresAt time.Time `db:"expires_at" json:"expires_at"` - OIDCAccessToken string `db:"oidc_access_token" json:"oidc_access_token"` - OIDCRefreshToken string `db:"oidc_refresh_token" json:"oidc_refresh_token"` - OIDCExpiry time.Time `db:"oidc_expiry" json:"oidc_expiry"` + ID string `db:"id" json:"id"` + LastUsed time.Time `db:"last_used" json:"last_used"` + ExpiresAt time.Time `db:"expires_at" json:"expires_at"` + OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"` + OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"` + OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"` } func (q *sqlQuerier) UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error { @@ -170,9 +139,9 @@ func (q *sqlQuerier) UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDP arg.ID, arg.LastUsed, arg.ExpiresAt, - arg.OIDCAccessToken, - arg.OIDCRefreshToken, - arg.OIDCExpiry, + arg.OAuthAccessToken, + arg.OAuthRefreshToken, + arg.OAuthExpiry, ) return err } @@ -453,6 +422,42 @@ func (q *sqlQuerier) GetOrganizationByName(ctx context.Context, name string) (Or return i, err } +const getOrganizations = `-- name: GetOrganizations :many +SELECT + id, name, description, created_at, updated_at +FROM + organizations +` + +func (q *sqlQuerier) GetOrganizations(ctx context.Context) ([]Organization, error) { + rows, err := q.db.QueryContext(ctx, getOrganizations) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Organization + for rows.Next() { + var i Organization + if err := rows.Scan( + &i.ID, + &i.Name, + &i.Description, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getOrganizationsByUserID = `-- name: GetOrganizationsByUserID :many SELECT id, name, description, created_at, updated_at diff --git a/coderd/database/queries/apikeys.sql b/coderd/database/queries/apikeys.sql index 62dc38ed2ca59..1af2016f491bf 100644 --- a/coderd/database/queries/apikeys.sql +++ b/coderd/database/queries/apikeys.sql @@ -14,37 +14,18 @@ INSERT INTO id, hashed_secret, user_id, - application, - "name", last_used, expires_at, created_at, updated_at, login_type, - oidc_access_token, - oidc_refresh_token, - oidc_id_token, - oidc_expiry, - devurl_token + oauth_access_token, + oauth_refresh_token, + oauth_id_token, + oauth_expiry ) VALUES - ( - $1, - $2, - $3, - $4, - $5, - $6, - $7, - $8, - $9, - $10, - $11, - $12, - $13, - $14, - $15 - ) RETURNING *; + ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) RETURNING *; -- name: UpdateAPIKeyByID :exec UPDATE @@ -52,8 +33,8 @@ UPDATE SET last_used = $2, expires_at = $3, - oidc_access_token = $4, - oidc_refresh_token = $5, - oidc_expiry = $6 + oauth_access_token = $4, + oauth_refresh_token = $5, + oauth_expiry = $6 WHERE id = $1; diff --git a/coderd/database/queries/organizations.sql b/coderd/database/queries/organizations.sql index 1682c04a8fd95..87c403049efd2 100644 --- a/coderd/database/queries/organizations.sql +++ b/coderd/database/queries/organizations.sql @@ -1,3 +1,9 @@ +-- name: GetOrganizations :many +SELECT + * +FROM + organizations; + -- name: GetOrganizationByID :one SELECT * diff --git a/coderd/database/sqlc.yaml b/coderd/database/sqlc.yaml index a009644cdf520..abde7029c3c79 100644 --- a/coderd/database/sqlc.yaml +++ b/coderd/database/sqlc.yaml @@ -21,10 +21,10 @@ overrides: rename: api_key: APIKey login_type_oidc: LoginTypeOIDC - oidc_access_token: OIDCAccessToken - oidc_expiry: OIDCExpiry - oidc_id_token: OIDCIDToken - oidc_refresh_token: OIDCRefreshToken + oauth_access_token: OAuthAccessToken + oauth_expiry: OAuthExpiry + oauth_id_token: OAuthIDToken + oauth_refresh_token: OAuthRefreshToken parameter_type_system_hcl: ParameterTypeSystemHCL userstatus: UserStatus gitsshkey: GitSSHKey diff --git a/coderd/httpmw/apikey.go b/coderd/httpmw/apikey.go index 1b18bc56bcde6..c3038ace73b6a 100644 --- a/coderd/httpmw/apikey.go +++ b/coderd/httpmw/apikey.go @@ -20,12 +20,6 @@ import ( // AuthCookie represents the name of the cookie the API key is stored in. const AuthCookie = "session_token" -// OAuth2Config contains a subset of functions exposed from oauth2.Config. -// It is abstracted for simple testing. -type OAuth2Config interface { - TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource -} - type apiKeyContextKey struct{} // APIKey returns the API key from the ExtractAPIKey handler. @@ -37,10 +31,16 @@ func APIKey(r *http.Request) database.APIKey { return apiKey } +// OAuth2Configs is a collection of configurations for OAuth-based authentication. +// This should be extended to support other authentication types in the future. +type OAuth2Configs struct { + Github OAuth2Config +} + // ExtractAPIKey requires authentication using a valid API key. // It handles extending an API key if it comes close to expiry, // updating the last used time in the database. -func ExtractAPIKey(db database.Store, oauthConfig OAuth2Config) func(http.Handler) http.Handler { +func ExtractAPIKey(db database.Store, oauth *OAuth2Configs) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { cookie, err := r.Cookie(AuthCookie) @@ -99,14 +99,24 @@ func ExtractAPIKey(db database.Store, oauthConfig OAuth2Config) func(http.Handle // Tracks if the API key has properties updated! changed := false - if key.LoginType == database.LoginTypeOIDC { - // Check if the OIDC token is expired! - if key.OIDCExpiry.Before(now) && !key.OIDCExpiry.IsZero() { + if key.LoginType != database.LoginTypePassword { + // Check if the OAuth token is expired! + if key.OAuthExpiry.Before(now) && !key.OAuthExpiry.IsZero() { + var oauthConfig OAuth2Config + switch key.LoginType { + case database.LoginTypeGithub: + oauthConfig = oauth.Github + default: + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("unexpected authentication type %q", key.LoginType), + }) + return + } // If it is, let's refresh it from the provided config! token, err := oauthConfig.TokenSource(r.Context(), &oauth2.Token{ - AccessToken: key.OIDCAccessToken, - RefreshToken: key.OIDCRefreshToken, - Expiry: key.OIDCExpiry, + AccessToken: key.OAuthAccessToken, + RefreshToken: key.OAuthRefreshToken, + Expiry: key.OAuthExpiry, }).Token() if err != nil { httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{ @@ -114,9 +124,9 @@ func ExtractAPIKey(db database.Store, oauthConfig OAuth2Config) func(http.Handle }) return } - key.OIDCAccessToken = token.AccessToken - key.OIDCRefreshToken = token.RefreshToken - key.OIDCExpiry = token.Expiry + key.OAuthAccessToken = token.AccessToken + key.OAuthRefreshToken = token.RefreshToken + key.OAuthExpiry = token.Expiry key.ExpiresAt = token.Expiry changed = true } @@ -136,21 +146,20 @@ func ExtractAPIKey(db database.Store, oauthConfig OAuth2Config) func(http.Handle changed = true } // Only update the ExpiresAt once an hour to prevent database spam. - // We extend the ExpiresAt to reduce reauthentication. + // We extend the ExpiresAt to reduce re-authentication. apiKeyLifetime := 24 * time.Hour if key.ExpiresAt.Sub(now) <= apiKeyLifetime-time.Hour { key.ExpiresAt = now.Add(apiKeyLifetime) changed = true } - if changed { err := db.UpdateAPIKeyByID(r.Context(), database.UpdateAPIKeyByIDParams{ - ID: key.ID, - ExpiresAt: key.ExpiresAt, - LastUsed: key.LastUsed, - OIDCAccessToken: key.OIDCAccessToken, - OIDCRefreshToken: key.OIDCRefreshToken, - OIDCExpiry: key.OIDCExpiry, + ID: key.ID, + LastUsed: key.LastUsed, + ExpiresAt: key.ExpiresAt, + OAuthAccessToken: key.OAuthAccessToken, + OAuthRefreshToken: key.OAuthRefreshToken, + OAuthExpiry: key.OAuthExpiry, }) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ diff --git a/coderd/httpmw/apikey_test.go b/coderd/httpmw/apikey_test.go index 2d4e7c3a6be67..0c8d8d396e55b 100644 --- a/coderd/httpmw/apikey_test.go +++ b/coderd/httpmw/apikey_test.go @@ -189,7 +189,6 @@ func TestAPIKey(t *testing.T) { sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ ID: id, HashedSecret: hashed[:], - LastUsed: database.Now(), ExpiresAt: database.Now().AddDate(0, 0, 1), }) require.NoError(t, err) @@ -207,7 +206,6 @@ func TestAPIKey(t *testing.T) { gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id) require.NoError(t, err) - require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed) require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt) }) @@ -277,7 +275,7 @@ func TestAPIKey(t *testing.T) { require.NotEqual(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt) }) - t.Run("OIDCNotExpired", func(t *testing.T) { + t.Run("OAuthNotExpired", func(t *testing.T) { t.Parallel() var ( db = databasefake.New() @@ -294,7 +292,7 @@ func TestAPIKey(t *testing.T) { sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ ID: id, HashedSecret: hashed[:], - LoginType: database.LoginTypeOIDC, + LoginType: database.LoginTypeGithub, LastUsed: database.Now(), ExpiresAt: database.Now().AddDate(0, 0, 1), }) @@ -311,7 +309,7 @@ func TestAPIKey(t *testing.T) { require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt) }) - t.Run("OIDCRefresh", func(t *testing.T) { + t.Run("OAuthRefresh", func(t *testing.T) { t.Parallel() var ( db = databasefake.New() @@ -328,9 +326,9 @@ func TestAPIKey(t *testing.T) { sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ ID: id, HashedSecret: hashed[:], - LoginType: database.LoginTypeOIDC, + LoginType: database.LoginTypeGithub, LastUsed: database.Now(), - OIDCExpiry: database.Now().AddDate(0, 0, -1), + OAuthExpiry: database.Now().AddDate(0, 0, -1), }) require.NoError(t, err) token := &oauth2.Token{ @@ -338,11 +336,11 @@ func TestAPIKey(t *testing.T) { RefreshToken: "moo", Expiry: database.Now().AddDate(0, 0, 1), } - httpmw.ExtractAPIKey(db, &oauth2Config{ - tokenSource: &oauth2TokenSource{ - token: func() (*oauth2.Token, error) { + httpmw.ExtractAPIKey(db, &httpmw.OAuth2Configs{ + Github: &oauth2Config{ + tokenSource: oauth2TokenSource(func() (*oauth2.Token, error) { return token, nil - }, + }), }, })(successHandler).ServeHTTP(rw, r) res := rw.Result() @@ -354,22 +352,28 @@ func TestAPIKey(t *testing.T) { require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed) require.Equal(t, token.Expiry, gotAPIKey.ExpiresAt) - require.Equal(t, token.AccessToken, gotAPIKey.OIDCAccessToken) + require.Equal(t, token.AccessToken, gotAPIKey.OAuthAccessToken) }) } type oauth2Config struct { - tokenSource *oauth2TokenSource + tokenSource oauth2TokenSource } -func (o *oauth2Config) TokenSource(_ context.Context, _ *oauth2.Token) oauth2.TokenSource { +func (o *oauth2Config) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource { return o.tokenSource } -type oauth2TokenSource struct { - token func() (*oauth2.Token, error) +func (*oauth2Config) AuthCodeURL(string, ...oauth2.AuthCodeOption) string { + return "" } -func (o *oauth2TokenSource) Token() (*oauth2.Token, error) { - return o.token() +func (*oauth2Config) Exchange(context.Context, string, ...oauth2.AuthCodeOption) (*oauth2.Token, error) { + return &oauth2.Token{}, nil +} + +type oauth2TokenSource func() (*oauth2.Token, error) + +func (o oauth2TokenSource) Token() (*oauth2.Token, error) { + return o() } diff --git a/coderd/httpmw/oauth2.go b/coderd/httpmw/oauth2.go new file mode 100644 index 0000000000000..c3e2e0f00519f --- /dev/null +++ b/coderd/httpmw/oauth2.go @@ -0,0 +1,132 @@ +package httpmw + +import ( + "context" + "fmt" + "net/http" + + "golang.org/x/oauth2" + + "github.com/coder/coder/coderd/httpapi" + "github.com/coder/coder/cryptorand" +) + +const ( + oauth2StateCookieName = "oauth_state" + oauth2RedirectCookieName = "oauth_redirect" +) + +type oauth2StateKey struct{} + +type OAuth2State struct { + Token *oauth2.Token + Redirect string +} + +// OAuth2Config exposes a subset of *oauth2.Config functions for easier testing. +// *oauth2.Config should be used instead of implementing this in production. +type OAuth2Config interface { + AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string + Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) + TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource +} + +// OAuth2 returns the state from an oauth request. +func OAuth2(r *http.Request) OAuth2State { + oauth, ok := r.Context().Value(oauth2StateKey{}).(OAuth2State) + if !ok { + panic("developer error: oauth middleware not provided") + } + return oauth +} + +// ExtractOAuth2 is a middleware for automatically redirecting to OAuth +// URLs, and handling the exchange inbound. Any route that does not have +// a "code" URL parameter will be redirected. +func ExtractOAuth2(config OAuth2Config) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + if config == nil { + httpapi.Write(rw, http.StatusPreconditionRequired, httpapi.Response{ + Message: fmt.Sprintf("The oauth2 method requested is not configured!"), + }) + return + } + + code := r.URL.Query().Get("code") + state := r.URL.Query().Get("state") + + if code == "" { + // If the code isn't provided, we'll redirect! + state, err := cryptorand.String(32) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("generate state string: %s", err), + }) + return + } + + http.SetCookie(rw, &http.Cookie{ + Name: oauth2StateCookieName, + Value: state, + Path: "/", + HttpOnly: true, + SameSite: http.SameSiteStrictMode, + }) + // Redirect must always be specified, otherwise + // an old redirect could apply! + http.SetCookie(rw, &http.Cookie{ + Name: oauth2RedirectCookieName, + Value: r.URL.Query().Get("redirect"), + Path: "/", + HttpOnly: true, + SameSite: http.SameSiteStrictMode, + }) + + http.Redirect(rw, r, config.AuthCodeURL(state, oauth2.AccessTypeOffline), http.StatusTemporaryRedirect) + return + } + + if state == "" { + httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ + Message: "state must be provided", + }) + return + } + + stateCookie, err := r.Cookie(oauth2StateCookieName) + if err != nil { + httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{ + Message: fmt.Sprintf("%q cookie must be provided", oauth2StateCookieName), + }) + return + } + if stateCookie.Value != state { + httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{ + Message: "state mismatched", + }) + return + } + + var redirect string + stateRedirect, err := r.Cookie(oauth2RedirectCookieName) + if err == nil { + redirect = stateRedirect.Value + } + + oauthToken, err := config.Exchange(r.Context(), code) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("exchange oauth code: %s", err), + }) + return + } + + ctx := context.WithValue(r.Context(), oauth2StateKey{}, OAuth2State{ + Token: oauthToken, + Redirect: redirect, + }) + next.ServeHTTP(rw, r.WithContext(ctx)) + }) + } +} diff --git a/coderd/httpmw/oauth2_test.go b/coderd/httpmw/oauth2_test.go new file mode 100644 index 0000000000000..31803b7351487 --- /dev/null +++ b/coderd/httpmw/oauth2_test.go @@ -0,0 +1,98 @@ +package httpmw_test + +import ( + "context" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" + + "github.com/coder/coder/coderd/httpmw" +) + +type testOAuth2Provider struct { +} + +func (*testOAuth2Provider) AuthCodeURL(state string, _ ...oauth2.AuthCodeOption) string { + return "?state=" + url.QueryEscape(state) +} + +func (*testOAuth2Provider) Exchange(_ context.Context, _ string, _ ...oauth2.AuthCodeOption) (*oauth2.Token, error) { + return &oauth2.Token{ + AccessToken: "hello", + }, nil +} + +func (*testOAuth2Provider) TokenSource(_ context.Context, _ *oauth2.Token) oauth2.TokenSource { + return nil +} + +func TestOAuth2(t *testing.T) { + t.Parallel() + t.Run("NotSetup", func(t *testing.T) { + t.Parallel() + req := httptest.NewRequest("GET", "/", nil) + res := httptest.NewRecorder() + httpmw.ExtractOAuth2(nil)(nil).ServeHTTP(res, req) + require.Equal(t, http.StatusPreconditionRequired, res.Result().StatusCode) + }) + t.Run("RedirectWithoutCode", func(t *testing.T) { + t.Parallel() + req := httptest.NewRequest("GET", "/?redirect="+url.QueryEscape("/dashboard"), nil) + res := httptest.NewRecorder() + httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req) + location := res.Header().Get("Location") + if !assert.NotEmpty(t, location) { + return + } + require.Len(t, res.Result().Cookies(), 2) + cookie := res.Result().Cookies()[1] + require.Equal(t, "/dashboard", cookie.Value) + }) + t.Run("NoState", func(t *testing.T) { + t.Parallel() + req := httptest.NewRequest("GET", "/?code=something", nil) + res := httptest.NewRecorder() + httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req) + require.Equal(t, http.StatusBadRequest, res.Result().StatusCode) + }) + t.Run("NoStateCookie", func(t *testing.T) { + t.Parallel() + req := httptest.NewRequest("GET", "/?code=something&state=test", nil) + res := httptest.NewRecorder() + httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req) + require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode) + }) + t.Run("MismatchedState", func(t *testing.T) { + t.Parallel() + req := httptest.NewRequest("GET", "/?code=something&state=test", nil) + req.AddCookie(&http.Cookie{ + Name: "oauth_state", + Value: "mismatch", + }) + res := httptest.NewRecorder() + httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req) + require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode) + }) + t.Run("ExchangeCodeAndState", func(t *testing.T) { + t.Parallel() + req := httptest.NewRequest("GET", "/?code=test&state=something", nil) + req.AddCookie(&http.Cookie{ + Name: "oauth_state", + Value: "something", + }) + req.AddCookie(&http.Cookie{ + Name: "oauth_redirect", + Value: "/dashboard", + }) + res := httptest.NewRecorder() + httpmw.ExtractOAuth2(&testOAuth2Provider{})(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + state := httpmw.OAuth2(r) + require.Equal(t, "/dashboard", state.Redirect) + })).ServeHTTP(res, req) + }) +} diff --git a/coderd/httpmw/organizationparam_test.go b/coderd/httpmw/organizationparam_test.go index 02887260feea0..a5bd256a66bbd 100644 --- a/coderd/httpmw/organizationparam_test.go +++ b/coderd/httpmw/organizationparam_test.go @@ -41,7 +41,7 @@ func TestOrganizationParam(t *testing.T) { ID: userID, Email: "testaccount@coder.com", Name: "example", - LoginType: database.LoginTypeBuiltIn, + LoginType: database.LoginTypePassword, HashedPassword: hashed[:], Username: username, CreatedAt: database.Now(), diff --git a/coderd/httpmw/templateparam_test.go b/coderd/httpmw/templateparam_test.go index 47089713d612f..fb3d2324d0490 100644 --- a/coderd/httpmw/templateparam_test.go +++ b/coderd/httpmw/templateparam_test.go @@ -40,7 +40,7 @@ func TestTemplateParam(t *testing.T) { ID: userID, Email: "testaccount@coder.com", Name: "example", - LoginType: database.LoginTypeBuiltIn, + LoginType: database.LoginTypePassword, HashedPassword: hashed[:], Username: username, CreatedAt: database.Now(), diff --git a/coderd/httpmw/templateversionparam_test.go b/coderd/httpmw/templateversionparam_test.go index 025b646f2ae58..7207a14d7bf92 100644 --- a/coderd/httpmw/templateversionparam_test.go +++ b/coderd/httpmw/templateversionparam_test.go @@ -40,7 +40,7 @@ func TestTemplateVersionParam(t *testing.T) { ID: userID, Email: "testaccount@coder.com", Name: "example", - LoginType: database.LoginTypeBuiltIn, + LoginType: database.LoginTypePassword, HashedPassword: hashed[:], Username: username, CreatedAt: database.Now(), diff --git a/coderd/httpmw/workspaceagentparam_test.go b/coderd/httpmw/workspaceagentparam_test.go index f014a8bd55b55..575a144c6efde 100644 --- a/coderd/httpmw/workspaceagentparam_test.go +++ b/coderd/httpmw/workspaceagentparam_test.go @@ -40,7 +40,7 @@ func TestWorkspaceAgentParam(t *testing.T) { ID: userID, Email: "testaccount@coder.com", Name: "example", - LoginType: database.LoginTypeBuiltIn, + LoginType: database.LoginTypePassword, HashedPassword: hashed[:], Username: username, CreatedAt: database.Now(), diff --git a/coderd/httpmw/workspacebuildparam_test.go b/coderd/httpmw/workspacebuildparam_test.go index 62eb6f975765c..39722ea644944 100644 --- a/coderd/httpmw/workspacebuildparam_test.go +++ b/coderd/httpmw/workspacebuildparam_test.go @@ -40,7 +40,7 @@ func TestWorkspaceBuildParam(t *testing.T) { ID: userID, Email: "testaccount@coder.com", Name: "example", - LoginType: database.LoginTypeBuiltIn, + LoginType: database.LoginTypePassword, HashedPassword: hashed[:], Username: username, CreatedAt: database.Now(), diff --git a/coderd/httpmw/workspaceparam_test.go b/coderd/httpmw/workspaceparam_test.go index 5c169a0d10218..0f10c0e129ada 100644 --- a/coderd/httpmw/workspaceparam_test.go +++ b/coderd/httpmw/workspaceparam_test.go @@ -40,7 +40,7 @@ func TestWorkspaceParam(t *testing.T) { ID: userID, Email: "testaccount@coder.com", Name: "example", - LoginType: database.LoginTypeBuiltIn, + LoginType: database.LoginTypePassword, HashedPassword: hashed[:], Username: username, CreatedAt: database.Now(), diff --git a/coderd/userauth.go b/coderd/userauth.go new file mode 100644 index 0000000000000..087a9adb78115 --- /dev/null +++ b/coderd/userauth.go @@ -0,0 +1,155 @@ +package coderd + +import ( + "context" + "database/sql" + "errors" + "fmt" + "net/http" + + "github.com/google/go-github/v43/github" + "github.com/google/uuid" + "golang.org/x/oauth2" + + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/httpapi" + "github.com/coder/coder/coderd/httpmw" + "github.com/coder/coder/codersdk" +) + +// GithubOAuth2Provider exposes required functions for the Github authentication flow. +type GithubOAuth2Config struct { + httpmw.OAuth2Config + AuthenticatedUser func(ctx context.Context, client *http.Client) (*github.User, error) + ListEmails func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error) + ListOrganizationMemberships func(ctx context.Context, client *http.Client) ([]*github.Membership, error) + + AllowSignups bool + AllowOrganizations []string +} + +func (api *api) userAuthMethods(rw http.ResponseWriter, _ *http.Request) { + httpapi.Write(rw, http.StatusOK, codersdk.AuthMethods{ + Password: true, + Github: api.GithubOAuth2Config != nil, + }) +} + +func (api *api) userOAuth2Github(rw http.ResponseWriter, r *http.Request) { + state := httpmw.OAuth2(r) + + oauthClient := oauth2.NewClient(r.Context(), oauth2.StaticTokenSource(state.Token)) + memberships, err := api.GithubOAuth2Config.ListOrganizationMemberships(r.Context(), oauthClient) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get authenticated github user organizations: %s", err), + }) + return + } + var selectedMembership *github.Membership + for _, membership := range memberships { + for _, allowed := range api.GithubOAuth2Config.AllowOrganizations { + if *membership.Organization.Login != allowed { + continue + } + selectedMembership = membership + break + } + } + if selectedMembership == nil { + httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{ + Message: fmt.Sprintf("You aren't a member of the authorized Github organizations!"), + }) + return + } + + emails, err := api.GithubOAuth2Config.ListEmails(r.Context(), oauthClient) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get personal github user: %s", err), + }) + return + } + + var user database.User + // Search for existing users with matching and verified emails. + // If a verified GitHub email matches a Coder user, we will return. + for _, email := range emails { + if email.Verified == nil { + continue + } + user, err = api.Database.GetUserByEmailOrUsername(r.Context(), database.GetUserByEmailOrUsernameParams{ + Email: *email.Email, + }) + if errors.Is(err, sql.ErrNoRows) { + continue + } + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get user by email: %s", err), + }) + return + } + if !*email.Verified { + httpapi.Write(rw, http.StatusForbidden, httpapi.Response{ + Message: fmt.Sprintf("Verify the %q email address on Github to authenticate!", *email.Email), + }) + return + } + break + } + + // If the user doesn't exist, create a new one! + if user.ID == uuid.Nil { + if !api.GithubOAuth2Config.AllowSignups { + httpapi.Write(rw, http.StatusForbidden, httpapi.Response{ + Message: "Signups are disabled for Github authentication!", + }) + return + } + + var organizationID uuid.UUID + organizations, _ := api.Database.GetOrganizations(r.Context()) + if len(organizations) > 0 { + // Add the user to the first organization. Once multi-organization + // support is added, we should enable a configuration map of user + // email to organization. + organizationID = organizations[0].ID + } + ghUser, err := api.GithubOAuth2Config.AuthenticatedUser(r.Context(), oauthClient) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get authenticated github user: %s", err), + }) + return + } + user, _, err = api.createUser(r.Context(), codersdk.CreateUserRequest{ + Email: *ghUser.Email, + Username: *ghUser.Login, + OrganizationID: organizationID, + }) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("create user: %s", err), + }) + return + } + } + + _, created := api.createAPIKey(rw, r, database.InsertAPIKeyParams{ + UserID: user.ID, + LoginType: database.LoginTypeGithub, + OAuthAccessToken: state.Token.AccessToken, + OAuthRefreshToken: state.Token.RefreshToken, + OAuthExpiry: state.Token.Expiry, + }) + if !created { + return + } + + redirect := state.Redirect + if redirect == "" { + redirect = "/" + } + http.Redirect(rw, r, redirect, http.StatusTemporaryRedirect) +} diff --git a/coderd/userauth_test.go b/coderd/userauth_test.go new file mode 100644 index 0000000000000..b5103b9d2da83 --- /dev/null +++ b/coderd/userauth_test.go @@ -0,0 +1,205 @@ +package coderd_test + +import ( + "context" + "net/http" + "net/url" + "testing" + + "github.com/google/go-github/v43/github" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" + + "github.com/coder/coder/coderd" + "github.com/coder/coder/coderd/coderdtest" + "github.com/coder/coder/codersdk" +) + +type oauth2Config struct{} + +func (*oauth2Config) AuthCodeURL(state string, _ ...oauth2.AuthCodeOption) string { + return "/?state=" + url.QueryEscape(state) +} + +func (*oauth2Config) Exchange(context.Context, string, ...oauth2.AuthCodeOption) (*oauth2.Token, error) { + return &oauth2.Token{ + AccessToken: "token", + }, nil +} + +func (*oauth2Config) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource { + return nil +} + +func TestUserAuthMethods(t *testing.T) { + t.Parallel() + t.Run("Password", func(t *testing.T) { + t.Parallel() + client := coderdtest.New(t, nil) + methods, err := client.AuthMethods(context.Background()) + require.NoError(t, err) + require.True(t, methods.Password) + require.False(t, methods.Github) + }) + t.Run("Github", func(t *testing.T) { + t.Parallel() + client := coderdtest.New(t, &coderdtest.Options{ + GithubOAuth2Config: &coderd.GithubOAuth2Config{}, + }) + methods, err := client.AuthMethods(context.Background()) + require.NoError(t, err) + require.True(t, methods.Password) + require.True(t, methods.Github) + }) +} + +func TestUserOAuth2Github(t *testing.T) { + t.Parallel() + t.Run("NotInAllowedOrganization", func(t *testing.T) { + t.Parallel() + client := coderdtest.New(t, &coderdtest.Options{ + GithubOAuth2Config: &coderd.GithubOAuth2Config{ + OAuth2Config: &oauth2Config{}, + ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) { + return []*github.Membership{{ + Organization: &github.Organization{ + Login: github.String("kyle"), + }, + }}, nil + }, + }, + }) + + resp := oauth2Callback(t, client) + require.Equal(t, http.StatusUnauthorized, resp.StatusCode) + }) + t.Run("UnverifiedEmail", func(t *testing.T) { + t.Parallel() + client := coderdtest.New(t, &coderdtest.Options{ + GithubOAuth2Config: &coderd.GithubOAuth2Config{ + OAuth2Config: &oauth2Config{}, + AllowOrganizations: []string{"coder"}, + ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) { + return []*github.Membership{{ + Organization: &github.Organization{ + Login: github.String("coder"), + }, + }}, nil + }, + AuthenticatedUser: func(ctx context.Context, client *http.Client) (*github.User, error) { + return &github.User{}, nil + }, + ListEmails: func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error) { + return []*github.UserEmail{{ + Email: github.String("testuser@coder.com"), + Verified: github.Bool(false), + }}, nil + }, + }, + }) + _ = coderdtest.CreateFirstUser(t, client) + resp := oauth2Callback(t, client) + require.Equal(t, http.StatusForbidden, resp.StatusCode) + }) + t.Run("BlockSignups", func(t *testing.T) { + t.Parallel() + client := coderdtest.New(t, &coderdtest.Options{ + GithubOAuth2Config: &coderd.GithubOAuth2Config{ + OAuth2Config: &oauth2Config{}, + AllowOrganizations: []string{"coder"}, + ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) { + return []*github.Membership{{ + Organization: &github.Organization{ + Login: github.String("coder"), + }, + }}, nil + }, + AuthenticatedUser: func(ctx context.Context, client *http.Client) (*github.User, error) { + return &github.User{}, nil + }, + ListEmails: func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error) { + return []*github.UserEmail{}, nil + }, + }, + }) + resp := oauth2Callback(t, client) + require.Equal(t, http.StatusForbidden, resp.StatusCode) + }) + t.Run("Signup", func(t *testing.T) { + t.Parallel() + client := coderdtest.New(t, &coderdtest.Options{ + GithubOAuth2Config: &coderd.GithubOAuth2Config{ + OAuth2Config: &oauth2Config{}, + AllowOrganizations: []string{"coder"}, + AllowSignups: true, + ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) { + return []*github.Membership{{ + Organization: &github.Organization{ + Login: github.String("coder"), + }, + }}, nil + }, + AuthenticatedUser: func(ctx context.Context, client *http.Client) (*github.User, error) { + return &github.User{ + Login: github.String("kyle"), + Email: github.String("kyle@coder.com"), + }, nil + }, + ListEmails: func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error) { + return []*github.UserEmail{}, nil + }, + }, + }) + resp := oauth2Callback(t, client) + require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) + }) + t.Run("Login", func(t *testing.T) { + t.Parallel() + client := coderdtest.New(t, &coderdtest.Options{ + GithubOAuth2Config: &coderd.GithubOAuth2Config{ + OAuth2Config: &oauth2Config{}, + AllowOrganizations: []string{"coder"}, + ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) { + return []*github.Membership{{ + Organization: &github.Organization{ + Login: github.String("coder"), + }, + }}, nil + }, + AuthenticatedUser: func(ctx context.Context, client *http.Client) (*github.User, error) { + return &github.User{}, nil + }, + ListEmails: func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error) { + return []*github.UserEmail{{ + Email: github.String("testuser@coder.com"), + Verified: github.Bool(true), + }}, nil + }, + }, + }) + _ = coderdtest.CreateFirstUser(t, client) + resp := oauth2Callback(t, client) + require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) + }) +} + +func oauth2Callback(t *testing.T, client *codersdk.Client) *http.Response { + client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + } + state := "somestate" + oauthURL, err := client.URL.Parse("/api/v2/users/oauth2/github/callback?code=asd&state=" + state) + require.NoError(t, err) + req, err := http.NewRequest("GET", oauthURL.String(), nil) + require.NoError(t, err) + req.AddCookie(&http.Cookie{ + Name: "oauth_state", + Value: state, + }) + res, err := client.HTTPClient.Do(req) + require.NoError(t, err) + t.Cleanup(func() { + _ = res.Body.Close() + }) + return res +} diff --git a/coderd/users.go b/coderd/users.go index 41f1b28d76ca0..046a26ef74c30 100644 --- a/coderd/users.go +++ b/coderd/users.go @@ -1,6 +1,7 @@ package coderd import ( + "context" "crypto/sha256" "database/sql" "encoding/json" @@ -71,66 +72,10 @@ func (api *api) postFirstUser(rw http.ResponseWriter, r *http.Request) { return } - hashedPassword, err := userpassword.Hash(createUser.Password) - if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ - Message: fmt.Sprintf("hash password: %s", err.Error()), - }) - return - } - - // Create the user, organization, and membership to the user. - var user database.User - var organization database.Organization - err = api.Database.InTx(func(db database.Store) error { - user, err = api.Database.InsertUser(r.Context(), database.InsertUserParams{ - ID: uuid.New(), - Email: createUser.Email, - HashedPassword: []byte(hashedPassword), - Username: createUser.Username, - LoginType: database.LoginTypeBuiltIn, - CreatedAt: database.Now(), - UpdatedAt: database.Now(), - }) - if err != nil { - return xerrors.Errorf("create user: %w", err) - } - - privateKey, publicKey, err := gitsshkey.Generate(api.SSHKeygenAlgorithm) - if err != nil { - return xerrors.Errorf("generate user gitsshkey: %w", err) - } - _, err = db.InsertGitSSHKey(r.Context(), database.InsertGitSSHKeyParams{ - UserID: user.ID, - CreatedAt: database.Now(), - UpdatedAt: database.Now(), - PrivateKey: privateKey, - PublicKey: publicKey, - }) - if err != nil { - return xerrors.Errorf("insert user gitsshkey: %w", err) - } - - organization, err = api.Database.InsertOrganization(r.Context(), database.InsertOrganizationParams{ - ID: uuid.New(), - Name: createUser.OrganizationName, - CreatedAt: database.Now(), - UpdatedAt: database.Now(), - }) - if err != nil { - return xerrors.Errorf("create organization: %w", err) - } - _, err = api.Database.InsertOrganizationMember(r.Context(), database.InsertOrganizationMemberParams{ - OrganizationID: organization.ID, - UserID: user.ID, - CreatedAt: database.Now(), - UpdatedAt: database.Now(), - Roles: []string{"organization-admin"}, - }) - if err != nil { - return xerrors.Errorf("create organization member: %w", err) - } - return nil + user, organizationID, err := api.createUser(r.Context(), codersdk.CreateUserRequest{ + Email: createUser.Email, + Username: createUser.Username, + Password: createUser.Password, }) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ @@ -141,7 +86,7 @@ func (api *api) postFirstUser(rw http.ResponseWriter, r *http.Request) { httpapi.Write(rw, http.StatusCreated, codersdk.CreateFirstUserResponse{ UserID: user.ID, - OrganizationID: organization.ID, + OrganizationID: organizationID, }) } @@ -262,56 +207,7 @@ func (api *api) postUsers(rw http.ResponseWriter, r *http.Request) { return } - hashedPassword, err := userpassword.Hash(createUser.Password) - if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ - Message: fmt.Sprintf("hash password: %s", err.Error()), - }) - return - } - - var user database.User - err = api.Database.InTx(func(db database.Store) error { - user, err = db.InsertUser(r.Context(), database.InsertUserParams{ - ID: uuid.New(), - Email: createUser.Email, - HashedPassword: []byte(hashedPassword), - Username: createUser.Username, - LoginType: database.LoginTypeBuiltIn, - CreatedAt: database.Now(), - UpdatedAt: database.Now(), - }) - if err != nil { - return xerrors.Errorf("create user: %w", err) - } - - privateKey, publicKey, err := gitsshkey.Generate(api.SSHKeygenAlgorithm) - if err != nil { - return xerrors.Errorf("generate user gitsshkey: %w", err) - } - _, err = db.InsertGitSSHKey(r.Context(), database.InsertGitSSHKeyParams{ - UserID: user.ID, - CreatedAt: database.Now(), - UpdatedAt: database.Now(), - PrivateKey: privateKey, - PublicKey: publicKey, - }) - if err != nil { - return xerrors.Errorf("insert user gitsshkey: %w", err) - } - - _, err = db.InsertOrganizationMember(r.Context(), database.InsertOrganizationMemberParams{ - OrganizationID: organization.ID, - UserID: user.ID, - CreatedAt: database.Now(), - UpdatedAt: database.Now(), - Roles: []string{}, - }) - if err != nil { - return xerrors.Errorf("create organization member: %w", err) - } - return nil - }) + user, _, err := api.createUser(r.Context(), createUser) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ Message: err.Error(), @@ -542,42 +438,14 @@ func (api *api) postLogin(rw http.ResponseWriter, r *http.Request) { return } - keyID, keySecret, err := generateAPIKeyIDSecret() - if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ - Message: fmt.Sprintf("generate api key parts: %s", err.Error()), - }) - return - } - hashed := sha256.Sum256([]byte(keySecret)) - - _, err = api.Database.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ - ID: keyID, - UserID: user.ID, - ExpiresAt: database.Now().Add(24 * time.Hour), - CreatedAt: database.Now(), - UpdatedAt: database.Now(), - HashedSecret: hashed[:], - LoginType: database.LoginTypeBuiltIn, + sessionToken, created := api.createAPIKey(rw, r, database.InsertAPIKeyParams{ + UserID: user.ID, + LoginType: database.LoginTypePassword, }) - if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ - Message: fmt.Sprintf("insert api key: %s", err.Error()), - }) + if !created { return } - // This format is consumed by the APIKey middleware. - sessionToken := fmt.Sprintf("%s-%s", keyID, keySecret) - http.SetCookie(rw, &http.Cookie{ - Name: httpmw.AuthCookie, - Value: sessionToken, - Path: "/", - HttpOnly: true, - SameSite: http.SameSiteLaxMode, - Secure: api.SecureAuthCookie, - }) - httpapi.Write(rw, http.StatusCreated, codersdk.LoginWithPasswordResponse{ SessionToken: sessionToken, }) @@ -595,35 +463,15 @@ func (api *api) postAPIKey(rw http.ResponseWriter, r *http.Request) { return } - keyID, keySecret, err := generateAPIKeyIDSecret() - if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ - Message: fmt.Sprintf("generate api key parts: %s", err.Error()), - }) - return - } - hashed := sha256.Sum256([]byte(keySecret)) - - _, err = api.Database.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ - ID: keyID, - UserID: apiKey.UserID, - ExpiresAt: database.Now().AddDate(1, 0, 0), // Expire after 1 year (same as v1) - CreatedAt: database.Now(), - UpdatedAt: database.Now(), - HashedSecret: hashed[:], - LoginType: database.LoginTypeBuiltIn, + sessionToken, created := api.createAPIKey(rw, r, database.InsertAPIKeyParams{ + UserID: user.ID, + LoginType: database.LoginTypePassword, }) - if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ - Message: fmt.Sprintf("insert api key: %s", err.Error()), - }) + if !created { return } - // This format is consumed by the APIKey middleware. - generatedAPIKey := fmt.Sprintf("%s-%s", keyID, keySecret) - - httpapi.Write(rw, http.StatusCreated, codersdk.GenerateAPIKeyResponse{Key: generatedAPIKey}) + httpapi.Write(rw, http.StatusCreated, codersdk.GenerateAPIKeyResponse{Key: sessionToken}) } // Clear the user's session cookie @@ -984,6 +832,117 @@ func generateAPIKeyIDSecret() (id string, secret string, err error) { return id, secret, nil } +func (api *api) createAPIKey(rw http.ResponseWriter, r *http.Request, params database.InsertAPIKeyParams) (string, bool) { + keyID, keySecret, err := generateAPIKeyIDSecret() + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("generate api key parts: %s", err.Error()), + }) + return "", false + } + hashed := sha256.Sum256([]byte(keySecret)) + + _, err = api.Database.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ + ID: keyID, + UserID: params.UserID, + ExpiresAt: database.Now().Add(24 * time.Hour), + CreatedAt: database.Now(), + UpdatedAt: database.Now(), + HashedSecret: hashed[:], + LoginType: params.LoginType, + OAuthAccessToken: params.OAuthAccessToken, + OAuthRefreshToken: params.OAuthRefreshToken, + OAuthIDToken: params.OAuthIDToken, + OAuthExpiry: params.OAuthExpiry, + }) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("insert api key: %s", err.Error()), + }) + return "", false + } + + // This format is consumed by the APIKey middleware. + sessionToken := fmt.Sprintf("%s-%s", keyID, keySecret) + http.SetCookie(rw, &http.Cookie{ + Name: httpmw.AuthCookie, + Value: sessionToken, + Path: "/", + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + Secure: api.SecureAuthCookie, + }) + return sessionToken, true +} + +func (api *api) createUser(ctx context.Context, req codersdk.CreateUserRequest) (database.User, uuid.UUID, error) { + var user database.User + return user, req.OrganizationID, api.Database.InTx(func(db database.Store) error { + // If no organization is provided, create a new one for the user. + if req.OrganizationID == uuid.Nil { + organization, err := db.InsertOrganization(ctx, database.InsertOrganizationParams{ + ID: uuid.New(), + Name: req.Username, + CreatedAt: database.Now(), + UpdatedAt: database.Now(), + }) + if err != nil { + return xerrors.Errorf("create organization: %w", err) + } + req.OrganizationID = organization.ID + } + + params := database.InsertUserParams{ + ID: uuid.New(), + Email: req.Email, + Username: req.Username, + LoginType: database.LoginTypePassword, + CreatedAt: database.Now(), + UpdatedAt: database.Now(), + } + // If a user signs up with OAuth, they can have no password! + if req.Password != "" { + hashedPassword, err := userpassword.Hash(req.Password) + if err != nil { + return xerrors.Errorf("hash password: %w", err) + } + params.HashedPassword = []byte(hashedPassword) + } + + var err error + user, err = db.InsertUser(ctx, params) + if err != nil { + return xerrors.Errorf("create user: %w", err) + } + + privateKey, publicKey, err := gitsshkey.Generate(api.SSHKeygenAlgorithm) + if err != nil { + return xerrors.Errorf("generate user gitsshkey: %w", err) + } + _, err = db.InsertGitSSHKey(ctx, database.InsertGitSSHKeyParams{ + UserID: user.ID, + CreatedAt: database.Now(), + UpdatedAt: database.Now(), + PrivateKey: privateKey, + PublicKey: publicKey, + }) + if err != nil { + return xerrors.Errorf("insert user gitsshkey: %w", err) + } + _, err = db.InsertOrganizationMember(ctx, database.InsertOrganizationMemberParams{ + OrganizationID: req.OrganizationID, + UserID: user.ID, + CreatedAt: database.Now(), + UpdatedAt: database.Now(), + Roles: []string{}, + }) + if err != nil { + return xerrors.Errorf("create organization member: %w", err) + } + return nil + }) +} + func convertUser(user database.User) codersdk.User { return codersdk.User{ ID: user.ID, diff --git a/coderd/users_test.go b/coderd/users_test.go index 0e08462f7537b..6c677da34d115 100644 --- a/coderd/users_test.go +++ b/coderd/users_test.go @@ -241,13 +241,14 @@ func TestUpdateUserProfile(t *testing.T) { t.Parallel() client := coderdtest.New(t, nil) user := coderdtest.CreateFirstUser(t, client) - existentUser, _ := client.CreateUser(context.Background(), codersdk.CreateUserRequest{ + existentUser, err := client.CreateUser(context.Background(), codersdk.CreateUserRequest{ Email: "bruno@coder.com", Username: "bruno", Password: "password", OrganizationID: user.OrganizationID, }) - _, err := client.UpdateUserProfile(context.Background(), codersdk.Me, codersdk.UpdateUserProfileRequest{ + require.NoError(t, err) + _, err = client.UpdateUserProfile(context.Background(), codersdk.Me, codersdk.UpdateUserProfileRequest{ Username: existentUser.Username, Email: "newemail@coder.com", }) diff --git a/codersdk/users.go b/codersdk/users.go index a4e3f7d0e12c2..283db8c93eb69 100644 --- a/codersdk/users.go +++ b/codersdk/users.go @@ -92,6 +92,12 @@ type CreateWorkspaceRequest struct { ParameterValues []CreateParameterRequest `json:"parameter_values"` } +// AuthMethods contains whether authentication types are enabled or not. +type AuthMethods struct { + Password bool `json:"password"` + Github bool `json:"github"` +} + // HasFirstUser returns whether the first user has been created. func (c *Client) HasFirstUser(ctx context.Context) (bool, error) { res, err := c.request(ctx, http.MethodGet, "/api/v2/users/first", nil) @@ -330,6 +336,22 @@ func (c *Client) WorkspaceByName(ctx context.Context, userID uuid.UUID, name str return workspace, json.NewDecoder(res.Body).Decode(&workspace) } +// AuthMethods returns types of authentication available to the user. +func (c *Client) AuthMethods(ctx context.Context) (AuthMethods, error) { + res, err := c.request(ctx, http.MethodGet, "/api/v2/users/authmethods", nil) + if err != nil { + return AuthMethods{}, err + } + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + return AuthMethods{}, readBodyAsError(res) + } + + var userAuth AuthMethods + return userAuth, json.NewDecoder(res.Body).Decode(&userAuth) +} + // uuidOrMe returns the provided uuid as a string if it's valid, ortherwise // `me`. func uuidOrMe(id uuid.UUID) string { diff --git a/go.mod b/go.mod index 3af02b66e1d28..56615fd79dc7a 100644 --- a/go.mod +++ b/go.mod @@ -61,6 +61,7 @@ require ( github.com/gohugoio/hugo v0.97.2 github.com/golang-jwt/jwt v3.2.2+incompatible github.com/golang-migrate/migrate/v4 v4.15.1 + github.com/google/go-github/v43 v43.0.1-0.20220414155304-00e42332e405 github.com/google/uuid v1.3.0 github.com/hashicorp/go-version v1.4.0 github.com/hashicorp/hc-install v0.3.1 @@ -157,6 +158,7 @@ require ( github.com/golang/protobuf v1.5.2 // indirect github.com/golang/snappy v0.0.4 // indirect github.com/google/go-cmp v0.5.7 // indirect + github.com/google/go-querystring v1.1.0 // indirect github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect github.com/gorilla/mux v1.8.0 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect diff --git a/go.sum b/go.sum index c1911803d9df2..3a237b6c71cc8 100644 --- a/go.sum +++ b/go.sum @@ -784,7 +784,11 @@ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.7 h1:81/ik6ipDQS2aGcBfIN5dHDB36BwrStyeAQquSYCV4o= github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= github.com/google/go-github/v35 v35.2.0/go.mod h1:s0515YVTI+IMrDoy9Y4pHt9ShGpzHvHO8rZ7L7acgvs= +github.com/google/go-github/v43 v43.0.1-0.20220414155304-00e42332e405 h1:DdHws/YnnPrSywrjNYu2lEHqYHWp/LnEx56w59esd54= +github.com/google/go-github/v43 v43.0.1-0.20220414155304-00e42332e405/go.mod h1:4RgUDSnsxP19d65zJWqvqJ/poJxBCvmna50eXmIvoR8= github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= +github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= +github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= github.com/google/gofuzz v0.0.0-20161122191042-44d81051d367/go.mod h1:HP5RmnzzSNb993RKQDq4+1A4ia9nllfqcQFTQJedwGI= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.1.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= diff --git a/site/src/api/index.ts b/site/src/api/index.ts index 232a188a324ac..cd6a387f6d99c 100644 --- a/site/src/api/index.ts +++ b/site/src/api/index.ts @@ -2,6 +2,7 @@ import axios, { AxiosRequestHeaders } from "axios" import { mutate } from "swr" import { MockPager, MockUser, MockUser2 } from "../testHelpers/entities" import * as Types from "./types" +import * as TypesGen from "./typesGenerated" const CONTENT_TYPE_JSON: AxiosRequestHeaders = { "Content-Type": "application/json", @@ -65,6 +66,11 @@ export const getUser = async (): Promise => { return response.data } +export const getAuthMethods = async (): Promise => { + const response = await axios.get("/api/v2/users/authmethods") + return response.data +} + export const getApiKey = async (): Promise => { const response = await axios.post("/api/v2/users/me/keys") return response.data diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index 8bdbf3759af11..e0c38e4225ba5 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -132,6 +132,12 @@ export interface CreateWorkspaceRequest { readonly name: string } +// From codersdk/users.go:96:6. +export interface AuthMethods { + readonly password: boolean + readonly github: boolean +} + // From codersdk/workspaceagents.go:31:6. export interface GoogleInstanceIdentityToken { readonly json_web_token: string diff --git a/site/src/components/SignInForm/SignInForm.stories.tsx b/site/src/components/SignInForm/SignInForm.stories.tsx index bc6a80840c2d3..90bdf0a39a73f 100644 --- a/site/src/components/SignInForm/SignInForm.stories.tsx +++ b/site/src/components/SignInForm/SignInForm.stories.tsx @@ -24,7 +24,26 @@ SignedOut.args = { } export const Loading = Template.bind({}) -Loading.args = { ...SignedOut.args, isLoading: true } +Loading.args = { + ...SignedOut.args, + isLoading: true, + authMethods: { + github: true, + password: true, + }, +} + +export const WithLoginError = Template.bind({}) +WithLoginError.args = { ...SignedOut.args, authErrorMessage: "Email or password was invalid" } -export const WithError = Template.bind({}) -WithError.args = { ...SignedOut.args, authErrorMessage: "Email or password was invalid" } +export const WithAuthMethodsError = Template.bind({}) +WithAuthMethodsError.args = { ...SignedOut.args, methodsErrorMessage: "Failed to fetch auth methods" } + +export const WithGithub = Template.bind({}) +WithGithub.args = { + ...SignedOut.args, + authMethods: { + password: true, + github: true, + }, +} diff --git a/site/src/components/SignInForm/SignInForm.tsx b/site/src/components/SignInForm/SignInForm.tsx index 75e9ba43defdd..f1c11659112e8 100644 --- a/site/src/components/SignInForm/SignInForm.tsx +++ b/site/src/components/SignInForm/SignInForm.tsx @@ -1,9 +1,12 @@ +import Button from "@material-ui/core/Button" import FormHelperText from "@material-ui/core/FormHelperText" +import Link from "@material-ui/core/Link" import { makeStyles } from "@material-ui/core/styles" import TextField from "@material-ui/core/TextField" import { FormikContextType, useFormik } from "formik" import React from "react" import * as Yup from "yup" +import { AuthMethods } from "../../api/typesGenerated" import { getFormHelpers, onChangeTrimmed } from "../../util/formUtils" import { Welcome } from "../Welcome/Welcome" import { LoadingButton } from "./../LoadingButton/LoadingButton" @@ -24,7 +27,9 @@ export const Language = { emailInvalid: "Please enter a valid email address.", emailRequired: "Please enter an email address.", authErrorMessage: "Incorrect email or password.", - signIn: "Sign In", + methodsErrorMessage: "Unable to fetch auth methods.", + passwordSignIn: "Sign In", + githubSignIn: "GitHub", } const validationSchema = Yup.object({ @@ -49,10 +54,18 @@ const useStyles = makeStyles((theme) => ({ export interface SignInFormProps { isLoading: boolean authErrorMessage?: string + methodsErrorMessage?: string + authMethods?: AuthMethods onSubmit: ({ email, password }: { email: string; password: string }) => Promise } -export const SignInForm: React.FC = ({ isLoading, authErrorMessage, onSubmit }) => { +export const SignInForm: React.FC = ({ + authMethods, + isLoading, + authErrorMessage, + methodsErrorMessage, + onSubmit, +}) => { const styles = useStyles() const form: FormikContextType = useFormik({ @@ -76,6 +89,7 @@ export const SignInForm: React.FC = ({ isLoading, authErrorMess className={styles.loginTextField} fullWidth label={Language.emailLabel} + type="email" variant="outlined" /> = ({ isLoading, authErrorMess variant="outlined" /> {authErrorMessage && {Language.authErrorMessage}} + {methodsErrorMessage && {Language.methodsErrorMessage}}
- {isLoading ? "" : Language.signIn} + {isLoading ? "" : Language.passwordSignIn}
+ {authMethods?.github && ( +
+ + + +
+ )} ) } diff --git a/site/src/pages/LoginPage/LoginPage.test.tsx b/site/src/pages/LoginPage/LoginPage.test.tsx index f9c4fb8ebee5f..1d5e8c2abf771 100644 --- a/site/src/pages/LoginPage/LoginPage.test.tsx +++ b/site/src/pages/LoginPage/LoginPage.test.tsx @@ -23,7 +23,7 @@ describe("LoginPage", () => { render() // Then - await screen.findByText(Language.signIn) + await screen.findByText(Language.passwordSignIn) }) it("shows an error message if SignIn fails", async () => { @@ -42,7 +42,7 @@ describe("LoginPage", () => { await userEvent.type(email, "test@coder.com") await userEvent.type(password, "password") // Click sign-in - const signInButton = await screen.findByText(Language.signIn) + const signInButton = await screen.findByText(Language.passwordSignIn) act(() => signInButton.click()) // Then @@ -50,4 +50,43 @@ describe("LoginPage", () => { expect(errorMessage).toBeDefined() expect(history.location.pathname).toEqual("/login") }) + + it("shows an error if fetching auth methods fails", async () => { + // Given + server.use( + // Make login fail + rest.get("/api/v2/users/authmethods", async (req, res, ctx) => { + return res(ctx.status(500), ctx.json({ message: "nope" })) + }), + ) + + // When + render() + + // Then + const errorMessage = await screen.findByText(Language.methodsErrorMessage) + expect(errorMessage).toBeDefined() + }) + + it("shows github authentication when enabled", async () => { + // Given + server.use( + rest.get("/api/v2/users/authmethods", async (req, res, ctx) => { + return res( + ctx.status(200), + ctx.json({ + password: true, + github: true, + }), + ) + }), + ) + + // When + render() + + // Then + await screen.findByText(Language.passwordSignIn) + await screen.findByText(Language.githubSignIn) + }) }) diff --git a/site/src/pages/LoginPage/LoginPage.tsx b/site/src/pages/LoginPage/LoginPage.tsx index 46b0ab0859abf..75f9f6d8a0928 100644 --- a/site/src/pages/LoginPage/LoginPage.tsx +++ b/site/src/pages/LoginPage/LoginPage.tsx @@ -35,6 +35,9 @@ export const LoginPage: React.FC = () => { const isLoading = authState.hasTag("loading") const redirectTo = retrieveRedirect(location.search) const authErrorMessage = authState.context.authError ? (authState.context.authError as Error).message : undefined + const getMethodsError = authState.context.getMethodsError + ? (authState.context.getMethodsError as Error).message + : undefined const onSubmit = async ({ email, password }: { email: string; password: string }) => { authSend({ type: "SIGN_IN", email, password }) @@ -47,7 +50,13 @@ export const LoginPage: React.FC = () => {
- +