From a63d27b72a3c54401d2de08701f6159aae92e409 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Fri, 15 Apr 2022 14:40:17 +0000 Subject: [PATCH 1/9] Initial oauth --- coderd/coderd.go | 7 + coderd/coderdtest/coderdtest.go | 2 + coderd/database/databasefake/databasefake.go | 10 + coderd/database/querier.go | 1 + coderd/database/queries.sql.go | 36 +++ coderd/database/queries/organizations.sql | 6 + coderd/httpmw/oauth.go | 124 +++++++++ coderd/httpmw/oauth_test.go | 97 +++++++ coderd/userauth.go | 132 ++++++++++ coderd/userauth_test.go | 147 +++++++++++ coderd/users.go | 252 ++++++++----------- go.mod | 2 + go.sum | 4 + peerbroker/proto/peerbroker.pb.go | 2 +- provisionerd/proto/provisionerd.pb.go | 2 +- provisionersdk/proto/provisioner.pb.go | 2 +- 16 files changed, 672 insertions(+), 154 deletions(-) create mode 100644 coderd/httpmw/oauth.go create mode 100644 coderd/httpmw/oauth_test.go create mode 100644 coderd/userauth.go create mode 100644 coderd/userauth_test.go diff --git a/coderd/coderd.go b/coderd/coderd.go index dde99869c5464..153092bf07283 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -34,6 +34,7 @@ type Options struct { AWSCertificates awsidentity.Certificates GoogleTokenValidator *idtoken.Validator + GithubOAuth2Provider GithubOAuth2Provider SecureAuthCookie bool SSHKeygenAlgorithm gitsshkey.Algorithm @@ -142,6 +143,12 @@ func New(options *Options) (http.Handler, func()) { r.Post("/first", api.postFirstUser) r.Post("/login", api.postLogin) r.Post("/logout", api.postLogout) + r.Route("/auth", func(r chi.Router) { + r.Route("/callback/github", func(r chi.Router) { + r.Use(httpmw.ExtractOAuth2(options.GithubOAuth2Provider)) + r.Get("/", api.userAuthGithub) + }) + }) r.Group(func(r chi.Router) { r.Use(httpmw.ExtractAPIKey(options.Database, nil)) r.Post("/", api.postUsers) diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index 7f30e992d23d6..c2679238a2df4 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -49,6 +49,7 @@ import ( type Options struct { AWSInstanceIdentity awsidentity.Certificates + GithubOAuth2Provider coderd.GithubOAuth2Provider GoogleInstanceIdentity *idtoken.Validator SSHKeygenAlgorithm gitsshkey.Algorithm } @@ -115,6 +116,7 @@ func New(t *testing.T, options *Options) *codersdk.Client { Pubsub: pubsub, AWSCertificates: options.AWSInstanceIdentity, + GithubOAuth2Provider: options.GithubOAuth2Provider, GoogleTokenValidator: options.GoogleInstanceIdentity, SSHKeygenAlgorithm: options.SSHKeygenAlgorithm, }) diff --git a/coderd/database/databasefake/databasefake.go b/coderd/database/databasefake/databasefake.go index 8e582d85c5ddc..556ba5d94c1f3 100644 --- a/coderd/database/databasefake/databasefake.go +++ b/coderd/database/databasefake/databasefake.go @@ -365,6 +365,16 @@ func (q *fakeQuerier) GetWorkspacesByUserID(_ context.Context, req database.GetW return workspaces, nil } +func (q *fakeQuerier) GetOrganizations(_ context.Context) ([]database.Organization, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + if len(q.organizations) == 0 { + return nil, sql.ErrNoRows + } + return q.organizations, nil +} + func (q *fakeQuerier) GetOrganizationByID(_ context.Context, id uuid.UUID) (database.Organization, error) { q.mutex.RLock() defer q.mutex.RUnlock() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 073113a2451df..51c9e7cf30e51 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -18,6 +18,7 @@ type querier interface { GetOrganizationByID(ctx context.Context, id uuid.UUID) (Organization, error) GetOrganizationByName(ctx context.Context, name string) (Organization, error) GetOrganizationMemberByUserID(ctx context.Context, arg GetOrganizationMemberByUserIDParams) (OrganizationMember, error) + GetOrganizations(ctx context.Context) ([]Organization, error) GetOrganizationsByUserID(ctx context.Context, userID uuid.UUID) ([]Organization, error) GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]ParameterSchema, error) GetParameterValueByScopeAndName(ctx context.Context, arg GetParameterValueByScopeAndNameParams) (ParameterValue, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index b8550e38a1180..c43a84959ddd6 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -453,6 +453,42 @@ func (q *sqlQuerier) GetOrganizationByName(ctx context.Context, name string) (Or return i, err } +const getOrganizations = `-- name: GetOrganizations :many +SELECT + id, name, description, created_at, updated_at +FROM + organizations +` + +func (q *sqlQuerier) GetOrganizations(ctx context.Context) ([]Organization, error) { + rows, err := q.db.QueryContext(ctx, getOrganizations) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Organization + for rows.Next() { + var i Organization + if err := rows.Scan( + &i.ID, + &i.Name, + &i.Description, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getOrganizationsByUserID = `-- name: GetOrganizationsByUserID :many SELECT id, name, description, created_at, updated_at diff --git a/coderd/database/queries/organizations.sql b/coderd/database/queries/organizations.sql index 1682c04a8fd95..87c403049efd2 100644 --- a/coderd/database/queries/organizations.sql +++ b/coderd/database/queries/organizations.sql @@ -1,3 +1,9 @@ +-- name: GetOrganizations :many +SELECT + * +FROM + organizations; + -- name: GetOrganizationByID :one SELECT * diff --git a/coderd/httpmw/oauth.go b/coderd/httpmw/oauth.go new file mode 100644 index 0000000000000..f96b112fbaa2d --- /dev/null +++ b/coderd/httpmw/oauth.go @@ -0,0 +1,124 @@ +package httpmw + +import ( + "context" + "fmt" + "net/http" + + "golang.org/x/oauth2" + + "github.com/coder/coder/coderd/httpapi" + "github.com/coder/coder/cryptorand" +) + +const ( + oauth2StateCookieName = "oauth_state" + oauth2RedirectCookieName = "oauth_redirect" +) + +type oauth2StateKey struct{} + +type OAuth2State struct { + Token *oauth2.Token + Redirect string +} + +// OAuth2Provider exposes a subset of *oauth2.Config functions for easier testing. +// *oauth2.Config should be used instead of implementing this in production. +type OAuth2Provider interface { + AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string + Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) +} + +// OAuth2 returns the state from an oauth request. +func OAuth2(r *http.Request) OAuth2State { + oauth, ok := r.Context().Value(oauth2StateKey{}).(OAuth2State) + if !ok { + panic("developer error: oauth middleware not provided") + } + return oauth +} + +// ExtractOAuth2 adds a middleware for handling OAuth2 callbacks. +// Any route that does not have a "code" URL parameter will be redirected +// to the handler configuration provided. +func ExtractOAuth2(provider OAuth2Provider) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + code := r.URL.Query().Get("code") + state := r.URL.Query().Get("state") + + if code == "" { + // If the code isn't provided, we'll redirect! + state, err := cryptorand.String(32) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("generate state string: %s", err), + }) + return + } + + http.SetCookie(rw, &http.Cookie{ + Name: oauth2StateCookieName, + Value: state, + Path: "/", + HttpOnly: true, + SameSite: http.SameSiteStrictMode, + }) + // Redirect must always be specified, otherwise + // an old redirect could apply! + http.SetCookie(rw, &http.Cookie{ + Name: oauth2RedirectCookieName, + Value: r.URL.Query().Get("redirect"), + Path: "/", + HttpOnly: true, + SameSite: http.SameSiteStrictMode, + }) + + http.Redirect(rw, r, provider.AuthCodeURL(state, oauth2.AccessTypeOffline), http.StatusTemporaryRedirect) + return + } + + if state == "" { + httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ + Message: "state must be provided", + }) + return + } + + stateCookie, err := r.Cookie(oauth2StateCookieName) + if err != nil { + httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{ + Message: fmt.Sprintf("%q cookie must be provided", oauth2StateCookieName), + }) + return + } + if stateCookie.Value != state { + httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{ + Message: "state mismatched", + }) + return + } + + var redirect string + stateRedirect, err := r.Cookie(oauth2RedirectCookieName) + if err == nil { + redirect = stateRedirect.Value + } + + oauthToken, err := provider.Exchange(r.Context(), code) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("exchange oauth code: %s", err), + }) + return + } + + ctx := context.WithValue(r.Context(), oauth2StateKey{}, OAuth2State{ + Token: oauthToken, + Redirect: redirect, + }) + next.ServeHTTP(rw, r.WithContext(ctx)) + }) + } +} diff --git a/coderd/httpmw/oauth_test.go b/coderd/httpmw/oauth_test.go new file mode 100644 index 0000000000000..6142b15853243 --- /dev/null +++ b/coderd/httpmw/oauth_test.go @@ -0,0 +1,97 @@ +package httpmw_test + +import ( + "context" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" + + "github.com/coder/coder/coderd/httpmw" +) + +type testOAuth2Provider struct { +} + +func (*testOAuth2Provider) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string { + return "?state=" + url.QueryEscape(state) +} + +func (*testOAuth2Provider) Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) { + return &oauth2.Token{ + AccessToken: "hello", + }, nil +} + +func TestOAuth2(t *testing.T) { + t.Parallel() + t.Run("RedirectWithoutCode", func(t *testing.T) { + t.Parallel() + req := httptest.NewRequest("GET", "/?redirect="+url.QueryEscape("/dashboard"), nil) + res := httptest.NewRecorder() + httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req) + location := res.Header().Get("Location") + if !assert.NotEmpty(t, location) { + return + } + require.Len(t, res.Result().Cookies(), 2) + cookie := res.Result().Cookies()[1] + require.Equal(t, "/dashboard", cookie.Value) + }) + t.Run("NoState", func(t *testing.T) { + t.Parallel() + req := httptest.NewRequest("GET", "/?code=something", nil) + res := httptest.NewRecorder() + httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req) + require.Equal(t, http.StatusBadRequest, res.Result().StatusCode) + }) + t.Run("NoStateCookie", func(t *testing.T) { + t.Parallel() + req := httptest.NewRequest("GET", "/?code=something&state=test", nil) + res := httptest.NewRecorder() + httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req) + require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode) + }) + t.Run("MismatchedState", func(t *testing.T) { + t.Parallel() + req := httptest.NewRequest("GET", "/?code=something&state=test", nil) + req.AddCookie(&http.Cookie{ + Name: "oauth_state", + Value: "mismatch", + }) + res := httptest.NewRecorder() + httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req) + require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode) + }) + t.Run("ExchangeCodeAndState", func(t *testing.T) { + t.Parallel() + req := httptest.NewRequest("GET", "/?code=test&state=something", nil) + req.AddCookie(&http.Cookie{ + Name: "oauth_state", + Value: "something", + }) + req.AddCookie(&http.Cookie{ + Name: "oauth_redirect", + Value: "/dashboard", + }) + res := httptest.NewRecorder() + httpmw.ExtractOAuth2(&testOAuth2Provider{})(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + state := httpmw.OAuth2(r) + require.Equal(t, "/dashboard", state.Redirect) + })).ServeHTTP(res, req) + }) + + // t.Run("ExchangeCodeAndState", func(t *testing.T) { + // t.Parallel() + // req := httptest.NewRequest("GET", "/?code=test&state="+url.QueryEscape(state), nil) + // res := httptest.NewRecorder() + // ExtractOAuth(log, cipher, &testOAuthProvider{})(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + // rw.WriteHeader(http.StatusOK) + // })).ServeHTTP(res, req) + // assert.Equal(t, res.Result().StatusCode, http.StatusOK) + // }) +} diff --git a/coderd/userauth.go b/coderd/userauth.go new file mode 100644 index 0000000000000..080045ec1e4fe --- /dev/null +++ b/coderd/userauth.go @@ -0,0 +1,132 @@ +package coderd + +import ( + "context" + "database/sql" + "errors" + "fmt" + "net/http" + + "github.com/google/go-github/v43/github" + "github.com/google/uuid" + "golang.org/x/oauth2" + + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/httpapi" + "github.com/coder/coder/coderd/httpmw" + "github.com/coder/coder/codersdk" +) + +// GithubOAuth2Provider exposes required functions for the Github authentication flow. +type GithubOAuth2Provider interface { + httpmw.OAuth2Provider + PersonalUser(ctx context.Context, client *github.Client) (*github.User, error) + ListEmails(ctx context.Context, client *github.Client) ([]*github.UserEmail, error) +} + +func (api *api) userAuthGithub(rw http.ResponseWriter, r *http.Request) { + state := httpmw.OAuth2(r) + + ghClient := github.NewClient(oauth2.NewClient(r.Context(), oauth2.StaticTokenSource(state.Token))) + ghUser, err := api.GithubOAuth2Provider.PersonalUser(r.Context(), ghClient) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get personal github user: %s", err), + }) + return + } + emails, err := api.GithubOAuth2Provider.ListEmails(r.Context(), ghClient) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get personal github user: %s", err), + }) + return + } + var user database.User + // Search for existing users with matching and verified emails. + // If a verified GitHub email matches a Coder user, we will + // return. + for _, email := range emails { + if email.Verified == nil { + continue + } + if !*email.Verified { + continue + } + user, err = api.Database.GetUserByEmailOrUsername(r.Context(), database.GetUserByEmailOrUsernameParams{ + Username: *ghUser.Name, + Email: *email.Email, + }) + if errors.Is(err, sql.ErrNoRows) { + continue + } + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get user by email: %s", err), + }) + return + } + break + } + // If the user doesn't exist, create a new one! + if user.ID == uuid.Nil { + userCount, err := api.Database.GetUserCount(r.Context()) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get user count: %s", err.Error()), + }) + return + } + var organization database.Organization + // If there aren't any users yet, create one! + if userCount == 0 { + organization, err = api.Database.InsertOrganization(r.Context(), database.InsertOrganizationParams{ + ID: uuid.New(), + Name: *ghUser.Name, + CreatedAt: database.Now(), + UpdatedAt: database.Now(), + }) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("create organization: %s", err), + }) + return + } + } else { + organizations, err := api.Database.GetOrganizations(r.Context()) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get organizations: %s", err), + }) + return + } + // Add the user to the first organization. Once multi-organization + // support is added, we should enable a configuration map of user + // email to organization. + organization = organizations[0] + } + + user, err = api.createUser(r.Context(), api.Database, codersdk.CreateUserRequest{ + Email: *ghUser.Email, + Username: *ghUser.Name, + OrganizationID: organization.ID, + }) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("create user: %s", err), + }) + return + } + } + + _, created := api.createAPIKey(rw, r, user.ID) + if !created { + return + } + + redirect := state.Redirect + if redirect == "" { + redirect = "/" + } + http.Redirect(rw, r, redirect, http.StatusTemporaryRedirect) +} diff --git a/coderd/userauth_test.go b/coderd/userauth_test.go new file mode 100644 index 0000000000000..3db48546bd406 --- /dev/null +++ b/coderd/userauth_test.go @@ -0,0 +1,147 @@ +package coderd_test + +import ( + "context" + "net/http" + "net/url" + "testing" + + "github.com/coder/coder/coderd/coderdtest" + "github.com/coder/coder/codersdk" + "github.com/google/go-github/v43/github" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" +) + +type githubOAuthProvider struct{} + +func (g *githubOAuthProvider) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string { + return "/?state=" + url.QueryEscape(state) +} + +func (g *githubOAuthProvider) Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) { + return &oauth2.Token{ + AccessToken: "token", + }, nil +} + +func (g *githubOAuthProvider) PersonalUser(ctx context.Context, client *github.Client) (*github.User, error) { + return &github.User{ + ID: github.Int64(1), + Login: github.String("testuser"), + Name: github.String("some user"), + Email: github.String("wow@test.io"), + AvatarURL: github.String("https://coder.com/avatar.png"), + }, nil +} + +func (g *githubOAuthProvider) ListEmails(ctx context.Context, client *github.Client) ([]*github.UserEmail, error) { + return []*github.UserEmail{{ + Email: github.String("someone@io.io"), + Primary: github.Bool(true), + Verified: github.Bool(true), + }, { + Email: github.String("ok@io.io"), + Primary: github.Bool(false), + Verified: github.Bool(false), + }}, nil +} + +func TestUserAuthGithub(t *testing.T) { + t.Parallel() + t.Run("FirstUser", func(t *testing.T) { + t.Parallel() + client := coderdtest.New(t, &coderdtest.Options{ + GithubOAuth2Provider: &githubOAuthProvider{}, + }) + client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + } + + state := "somestate" + oauthURL, err := client.URL.Parse("/api/v2/users/auth/callback/github?code=asd&state=" + state) + require.NoError(t, err) + req, err := http.NewRequest("GET", oauthURL.String(), nil) + require.NoError(t, err) + req.AddCookie(&http.Cookie{ + Name: "oauth_state", + Value: state, + }) + res, err := client.HTTPClient.Do(req) + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, http.StatusTemporaryRedirect, res.StatusCode) + }) + t.Run("NewUser", func(t *testing.T) { + t.Parallel() + client := coderdtest.New(t, &coderdtest.Options{ + GithubOAuth2Provider: &githubOAuthProvider{}, + }) + client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + } + _, err := client.CreateFirstUser(context.Background(), codersdk.CreateFirstUserRequest{ + Email: "someone@io.io", + Username: "someone", + Password: "testing", + OrganizationName: "acme-corp", + }) + require.NoError(t, err) + token, err := client.LoginWithPassword(context.Background(), codersdk.LoginWithPasswordRequest{ + Email: "someone@io.io", + Password: "testing", + }) + require.NoError(t, err) + client.SessionToken = token.SessionToken + + state := "somestate" + oauthURL, err := client.URL.Parse("/api/v2/users/auth/callback/github?code=asd&state=" + state) + require.NoError(t, err) + req, err := http.NewRequest("GET", oauthURL.String(), nil) + require.NoError(t, err) + req.AddCookie(&http.Cookie{ + Name: "oauth_state", + Value: state, + }) + res, err := client.HTTPClient.Do(req) + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, http.StatusTemporaryRedirect, res.StatusCode) + }) + t.Run("ExistingUser", func(t *testing.T) { + t.Parallel() + client := coderdtest.New(t, &coderdtest.Options{ + GithubOAuth2Provider: &githubOAuthProvider{}, + }) + client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + } + _, err := client.CreateFirstUser(context.Background(), codersdk.CreateFirstUserRequest{ + Email: "someone@io.io", + Username: "someone", + Password: "testing", + OrganizationName: "acme-corp", + }) + require.NoError(t, err) + token, err := client.LoginWithPassword(context.Background(), codersdk.LoginWithPasswordRequest{ + Email: "someone@io.io", + Password: "testing", + }) + require.NoError(t, err) + client.SessionToken = token.SessionToken + + state := "somestate" + oauthURL, err := client.URL.Parse("/api/v2/users/auth/callback/github?code=asd&state=" + state) + require.NoError(t, err) + req, err := http.NewRequest("GET", oauthURL.String(), nil) + require.NoError(t, err) + req.AddCookie(&http.Cookie{ + Name: "oauth_state", + Value: state, + }) + res, err := client.HTTPClient.Do(req) + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, http.StatusTemporaryRedirect, res.StatusCode) + }) +} diff --git a/coderd/users.go b/coderd/users.go index 6fcab0814033f..030c22976aedb 100644 --- a/coderd/users.go +++ b/coderd/users.go @@ -1,6 +1,7 @@ package coderd import ( + "context" "crypto/sha256" "database/sql" "encoding/json" @@ -69,46 +70,10 @@ func (api *api) postFirstUser(rw http.ResponseWriter, r *http.Request) { return } - hashedPassword, err := userpassword.Hash(createUser.Password) - if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ - Message: fmt.Sprintf("hash password: %s", err.Error()), - }) - return - } - // Create the user, organization, and membership to the user. var user database.User var organization database.Organization err = api.Database.InTx(func(db database.Store) error { - user, err = api.Database.InsertUser(r.Context(), database.InsertUserParams{ - ID: uuid.New(), - Email: createUser.Email, - HashedPassword: []byte(hashedPassword), - Username: createUser.Username, - LoginType: database.LoginTypeBuiltIn, - CreatedAt: database.Now(), - UpdatedAt: database.Now(), - }) - if err != nil { - return xerrors.Errorf("create user: %w", err) - } - - privateKey, publicKey, err := gitsshkey.Generate(api.SSHKeygenAlgorithm) - if err != nil { - return xerrors.Errorf("generate user gitsshkey: %w", err) - } - _, err = db.InsertGitSSHKey(r.Context(), database.InsertGitSSHKeyParams{ - UserID: user.ID, - CreatedAt: database.Now(), - UpdatedAt: database.Now(), - PrivateKey: privateKey, - PublicKey: publicKey, - }) - if err != nil { - return xerrors.Errorf("insert user gitsshkey: %w", err) - } - organization, err = api.Database.InsertOrganization(r.Context(), database.InsertOrganizationParams{ ID: uuid.New(), Name: createUser.OrganizationName, @@ -118,15 +83,14 @@ func (api *api) postFirstUser(rw http.ResponseWriter, r *http.Request) { if err != nil { return xerrors.Errorf("create organization: %w", err) } - _, err = api.Database.InsertOrganizationMember(r.Context(), database.InsertOrganizationMemberParams{ + user, err = api.createUser(r.Context(), db, codersdk.CreateUserRequest{ + Email: createUser.Email, + Username: createUser.Username, + Password: createUser.Password, OrganizationID: organization.ID, - UserID: user.ID, - CreatedAt: database.Now(), - UpdatedAt: database.Now(), - Roles: []string{"organization-admin"}, }) if err != nil { - return xerrors.Errorf("create organization member: %w", err) + return xerrors.Errorf("create user: %w", err) } return nil }) @@ -199,56 +163,7 @@ func (api *api) postUsers(rw http.ResponseWriter, r *http.Request) { return } - hashedPassword, err := userpassword.Hash(createUser.Password) - if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ - Message: fmt.Sprintf("hash password: %s", err.Error()), - }) - return - } - - var user database.User - err = api.Database.InTx(func(db database.Store) error { - user, err = db.InsertUser(r.Context(), database.InsertUserParams{ - ID: uuid.New(), - Email: createUser.Email, - HashedPassword: []byte(hashedPassword), - Username: createUser.Username, - LoginType: database.LoginTypeBuiltIn, - CreatedAt: database.Now(), - UpdatedAt: database.Now(), - }) - if err != nil { - return xerrors.Errorf("create user: %w", err) - } - - privateKey, publicKey, err := gitsshkey.Generate(api.SSHKeygenAlgorithm) - if err != nil { - return xerrors.Errorf("generate user gitsshkey: %w", err) - } - _, err = db.InsertGitSSHKey(r.Context(), database.InsertGitSSHKeyParams{ - UserID: user.ID, - CreatedAt: database.Now(), - UpdatedAt: database.Now(), - PrivateKey: privateKey, - PublicKey: publicKey, - }) - if err != nil { - return xerrors.Errorf("insert user gitsshkey: %w", err) - } - - _, err = db.InsertOrganizationMember(r.Context(), database.InsertOrganizationMemberParams{ - OrganizationID: organization.ID, - UserID: user.ID, - CreatedAt: database.Now(), - UpdatedAt: database.Now(), - Roles: []string{}, - }) - if err != nil { - return xerrors.Errorf("create organization member: %w", err) - } - return nil - }) + user, err := api.createUser(r.Context(), api.Database, createUser) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ Message: err.Error(), @@ -479,41 +394,10 @@ func (api *api) postLogin(rw http.ResponseWriter, r *http.Request) { return } - keyID, keySecret, err := generateAPIKeyIDSecret() - if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ - Message: fmt.Sprintf("generate api key parts: %s", err.Error()), - }) + sessionToken, created := api.createAPIKey(rw, r, user.ID) + if !created { return } - hashed := sha256.Sum256([]byte(keySecret)) - - _, err = api.Database.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ - ID: keyID, - UserID: user.ID, - ExpiresAt: database.Now().Add(24 * time.Hour), - CreatedAt: database.Now(), - UpdatedAt: database.Now(), - HashedSecret: hashed[:], - LoginType: database.LoginTypeBuiltIn, - }) - if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ - Message: fmt.Sprintf("insert api key: %s", err.Error()), - }) - return - } - - // This format is consumed by the APIKey middleware. - sessionToken := fmt.Sprintf("%s-%s", keyID, keySecret) - http.SetCookie(rw, &http.Cookie{ - Name: httpmw.AuthCookie, - Value: sessionToken, - Path: "/", - HttpOnly: true, - SameSite: http.SameSiteLaxMode, - Secure: api.SecureAuthCookie, - }) httpapi.Write(rw, http.StatusCreated, codersdk.LoginWithPasswordResponse{ SessionToken: sessionToken, @@ -532,35 +416,12 @@ func (api *api) postAPIKey(rw http.ResponseWriter, r *http.Request) { return } - keyID, keySecret, err := generateAPIKeyIDSecret() - if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ - Message: fmt.Sprintf("generate api key parts: %s", err.Error()), - }) - return - } - hashed := sha256.Sum256([]byte(keySecret)) - - _, err = api.Database.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ - ID: keyID, - UserID: apiKey.UserID, - ExpiresAt: database.Now().AddDate(1, 0, 0), // Expire after 1 year (same as v1) - CreatedAt: database.Now(), - UpdatedAt: database.Now(), - HashedSecret: hashed[:], - LoginType: database.LoginTypeBuiltIn, - }) - if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ - Message: fmt.Sprintf("insert api key: %s", err.Error()), - }) + sessionToken, created := api.createAPIKey(rw, r, user.ID) + if !created { return } - // This format is consumed by the APIKey middleware. - generatedAPIKey := fmt.Sprintf("%s-%s", keyID, keySecret) - - httpapi.Write(rw, http.StatusCreated, codersdk.GenerateAPIKeyResponse{Key: generatedAPIKey}) + httpapi.Write(rw, http.StatusCreated, codersdk.GenerateAPIKeyResponse{Key: sessionToken}) } // Clear the user's session cookie @@ -930,3 +791,92 @@ func convertUser(user database.User) codersdk.User { Name: user.Name, } } + +func (api *api) createAPIKey(rw http.ResponseWriter, r *http.Request, userID uuid.UUID) (string, bool) { + keyID, keySecret, err := generateAPIKeyIDSecret() + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("generate api key parts: %s", err.Error()), + }) + return "", false + } + hashed := sha256.Sum256([]byte(keySecret)) + + _, err = api.Database.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ + ID: keyID, + UserID: userID, + ExpiresAt: database.Now().Add(24 * time.Hour), + CreatedAt: database.Now(), + UpdatedAt: database.Now(), + HashedSecret: hashed[:], + LoginType: database.LoginTypeBuiltIn, + }) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("insert api key: %s", err.Error()), + }) + return "", false + } + + // This format is consumed by the APIKey middleware. + sessionToken := fmt.Sprintf("%s-%s", keyID, keySecret) + http.SetCookie(rw, &http.Cookie{ + Name: httpmw.AuthCookie, + Value: sessionToken, + Path: "/", + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + Secure: api.SecureAuthCookie, + }) + return sessionToken, true +} + +func (api *api) createUser(ctx context.Context, db database.Store, req codersdk.CreateUserRequest) (database.User, error) { + params := database.InsertUserParams{ + ID: uuid.New(), + Email: req.Email, + Username: req.Username, + LoginType: database.LoginTypeBuiltIn, + CreatedAt: database.Now(), + UpdatedAt: database.Now(), + } + // If a user signs up with OAuth, they can have no password! + if req.Password != "" { + hashedPassword, err := userpassword.Hash(req.Password) + if err != nil { + return database.User{}, xerrors.Errorf("hash password: %w", err) + } + params.HashedPassword = []byte(hashedPassword) + } + + user, err := db.InsertUser(ctx, params) + if err != nil { + return database.User{}, xerrors.Errorf("create user: %w", err) + } + + privateKey, publicKey, err := gitsshkey.Generate(api.SSHKeygenAlgorithm) + if err != nil { + return database.User{}, xerrors.Errorf("generate user gitsshkey: %w", err) + } + _, err = db.InsertGitSSHKey(ctx, database.InsertGitSSHKeyParams{ + UserID: user.ID, + CreatedAt: database.Now(), + UpdatedAt: database.Now(), + PrivateKey: privateKey, + PublicKey: publicKey, + }) + if err != nil { + return database.User{}, xerrors.Errorf("insert user gitsshkey: %w", err) + } + _, err = db.InsertOrganizationMember(ctx, database.InsertOrganizationMemberParams{ + OrganizationID: req.OrganizationID, + UserID: user.ID, + CreatedAt: database.Now(), + UpdatedAt: database.Now(), + Roles: []string{}, + }) + if err != nil { + return database.User{}, xerrors.Errorf("create organization member: %w", err) + } + return user, nil +} diff --git a/go.mod b/go.mod index 435047b296f38..af27fc2bdbc82 100644 --- a/go.mod +++ b/go.mod @@ -60,6 +60,7 @@ require ( github.com/gohugoio/hugo v0.96.0 github.com/golang-jwt/jwt v3.2.2+incompatible github.com/golang-migrate/migrate/v4 v4.15.1 + github.com/google/go-github/v43 v43.0.1-0.20220414155304-00e42332e405 github.com/google/uuid v1.3.0 github.com/hashicorp/go-version v1.4.0 github.com/hashicorp/hc-install v0.3.1 @@ -151,6 +152,7 @@ require ( github.com/golang/protobuf v1.5.2 // indirect github.com/golang/snappy v0.0.4 // indirect github.com/google/go-cmp v0.5.7 // indirect + github.com/google/go-querystring v1.1.0 // indirect github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect github.com/gorilla/mux v1.8.0 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect diff --git a/go.sum b/go.sum index 9466a75731c4a..3d1b6e6803c60 100644 --- a/go.sum +++ b/go.sum @@ -782,7 +782,11 @@ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.7 h1:81/ik6ipDQS2aGcBfIN5dHDB36BwrStyeAQquSYCV4o= github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= github.com/google/go-github/v35 v35.2.0/go.mod h1:s0515YVTI+IMrDoy9Y4pHt9ShGpzHvHO8rZ7L7acgvs= +github.com/google/go-github/v43 v43.0.1-0.20220414155304-00e42332e405 h1:DdHws/YnnPrSywrjNYu2lEHqYHWp/LnEx56w59esd54= +github.com/google/go-github/v43 v43.0.1-0.20220414155304-00e42332e405/go.mod h1:4RgUDSnsxP19d65zJWqvqJ/poJxBCvmna50eXmIvoR8= github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= +github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= +github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= github.com/google/gofuzz v0.0.0-20161122191042-44d81051d367/go.mod h1:HP5RmnzzSNb993RKQDq4+1A4ia9nllfqcQFTQJedwGI= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.1.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= diff --git a/peerbroker/proto/peerbroker.pb.go b/peerbroker/proto/peerbroker.pb.go index b1a880bf8ce36..8a443e6e42192 100644 --- a/peerbroker/proto/peerbroker.pb.go +++ b/peerbroker/proto/peerbroker.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v3.19.4 +// protoc v3.6.1 // source: peerbroker/proto/peerbroker.proto package proto diff --git a/provisionerd/proto/provisionerd.pb.go b/provisionerd/proto/provisionerd.pb.go index d7d835df2275d..27f3301b083cf 100644 --- a/provisionerd/proto/provisionerd.pb.go +++ b/provisionerd/proto/provisionerd.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v3.19.4 +// protoc v3.6.1 // source: provisionerd/proto/provisionerd.proto package proto diff --git a/provisionersdk/proto/provisioner.pb.go b/provisionersdk/proto/provisioner.pb.go index 72d37a0083f94..cdb0f4ee271a0 100644 --- a/provisionersdk/proto/provisioner.pb.go +++ b/provisionersdk/proto/provisioner.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v3.19.4 +// protoc v3.6.1 // source: provisionersdk/proto/provisioner.proto package proto From 05b6a370d2e72746ad2fae6426524bfe5bd93ceb Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Sun, 17 Apr 2022 19:33:59 +0000 Subject: [PATCH 2/9] Add Github authentication --- .vscode/settings.json | 2 + cli/start.go | 42 +++++ coderd/coderd.go | 33 ++-- coderd/coderdtest/coderdtest.go | 4 +- coderd/database/databasefake/databasefake.go | 33 ++-- coderd/database/dump.sql | 15 +- coderd/database/migrations/000001_base.up.sql | 23 +-- coderd/database/models.go | 33 ++-- coderd/database/queries.sql.go | 113 +++++------- coderd/database/queries/apikeys.sql | 35 +--- coderd/database/sqlc.yaml | 8 +- coderd/httpmw/apikey.go | 57 +++--- coderd/httpmw/apikey_test.go | 40 ++-- coderd/httpmw/{oauth.go => oauth2.go} | 18 +- .../httpmw/{oauth_test.go => oauth2_test.go} | 25 +-- coderd/httpmw/organizationparam_test.go | 2 +- coderd/httpmw/templateparam_test.go | 2 +- coderd/httpmw/templateversionparam_test.go | 2 +- coderd/httpmw/workspaceagentparam_test.go | 2 +- coderd/httpmw/workspacebuildparam_test.go | 2 +- coderd/httpmw/workspaceparam_test.go | 2 +- coderd/userauth.go | 132 ------------- coderd/userauth_test.go | 147 --------------- coderd/useroauth2.go | 145 +++++++++++++++ coderd/useroauth2_test.go | 173 ++++++++++++++++++ coderd/users.go | 138 +++++++------- coderd/users_test.go | 5 +- 27 files changed, 636 insertions(+), 597 deletions(-) rename coderd/httpmw/{oauth.go => oauth2.go} (83%) rename coderd/httpmw/{oauth_test.go => oauth2_test.go} (78%) delete mode 100644 coderd/userauth.go delete mode 100644 coderd/userauth_test.go create mode 100644 coderd/useroauth2.go create mode 100644 coderd/useroauth2_test.go diff --git a/.vscode/settings.json b/.vscode/settings.json index 2988daa9f0d75..d5215c265a0ed 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -35,6 +35,7 @@ "nolint", "nosec", "ntqry", + "OIDC", "oneof", "parameterscopeid", "pqtype", @@ -46,6 +47,7 @@ "ptytest", "retrier", "sdkproto", + "Signup", "stretchr", "TCGETS", "tcpip", diff --git a/cli/start.go b/cli/start.go index 3d328cab1300a..ce3613ef5e609 100644 --- a/cli/start.go +++ b/cli/start.go @@ -18,7 +18,10 @@ import ( "github.com/briandowns/spinner" "github.com/coreos/go-systemd/daemon" + "github.com/google/go-github/v43/github" "github.com/spf13/cobra" + "golang.org/x/oauth2" + xgithub "golang.org/x/oauth2/github" "golang.org/x/xerrors" "google.golang.org/api/idtoken" "google.golang.org/api/option" @@ -153,6 +156,11 @@ func start() *cobra.Command { return xerrors.Errorf("parse ssh keygen algorithm %s: %w", sshKeygenAlgorithmRaw, err) } + githubOAuth2Config, err := configureGithubOAuth2(accessURLParsed, "", "") + if err != nil { + return xerrors.Errorf("configure github oauth2: %w", err) + } + logger := slog.Make(sloghuman.Sink(os.Stderr)) options := &coderd.Options{ AccessURL: accessURLParsed, @@ -160,6 +168,7 @@ func start() *cobra.Command { Database: databasefake.New(), Pubsub: database.NewPubsubInMemory(), GoogleTokenValidator: validator, + GithubOAuth2Config: githubOAuth2Config, SecureAuthCookie: secureAuthCookie, SSHKeygenAlgorithm: sshKeygenAlgorithm, } @@ -534,3 +543,36 @@ func configureTLS(listener net.Listener, tlsMinVersion, tlsClientAuth, tlsCertFi return tls.NewListener(listener, tlsConfig), nil } + +func configureGithubOAuth2(accessURL *url.URL, clientID, clientSecret string) (*coderd.GithubOAuth2Config, error) { + redirectURL, err := accessURL.Parse("/api/v2/users/oauth2/github/callback") + if err != nil { + return nil, xerrors.Errorf("parse github oauth callback url: %w", err) + } + return &coderd.GithubOAuth2Config{ + OAuth2Config: &oauth2.Config{ + ClientID: clientID, + ClientSecret: clientSecret, + Endpoint: xgithub.Endpoint, + RedirectURL: redirectURL.String(), + Scopes: []string{ + "read:user", + "user:email", + }, + }, + AllowSignups: true, + AllowOrganizations: []string{"coder"}, + AuthenticatedUser: func(ctx context.Context, client *http.Client) (*github.User, error) { + user, _, err := github.NewClient(client).Users.Get(ctx, "") + return user, err + }, + ListEmails: func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error) { + emails, _, err := github.NewClient(client).Users.ListEmails(ctx, &github.ListOptions{}) + return emails, err + }, + ListOrganizations: func(ctx context.Context, client *http.Client) ([]*github.Organization, error) { + orgs, _, err := github.NewClient(client).Organizations.List(ctx, "", &github.ListOptions{}) + return orgs, err + }, + }, nil +} diff --git a/coderd/coderd.go b/coderd/coderd.go index 153092bf07283..fe1f5bf8b4fd7 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -34,7 +34,7 @@ type Options struct { AWSCertificates awsidentity.Certificates GoogleTokenValidator *idtoken.Validator - GithubOAuth2Provider GithubOAuth2Provider + GithubOAuth2Config *GithubOAuth2Config SecureAuthCookie bool SSHKeygenAlgorithm gitsshkey.Algorithm @@ -51,6 +51,9 @@ func New(options *Options) (http.Handler, func()) { api := &api{ Options: options, } + apiKeyMiddleware := httpmw.ExtractAPIKey(options.Database, &httpmw.OAuth2Configs{ + Github: options.GithubOAuth2Config, + }) r := chi.NewRouter() r.Route("/api/v2", func(r chi.Router) { @@ -75,7 +78,7 @@ func New(options *Options) (http.Handler, func()) { }) r.Route("/files", func(r chi.Router) { r.Use( - httpmw.ExtractAPIKey(options.Database, nil), + apiKeyMiddleware, // This number is arbitrary, but reading/writing // file content is expensive so it should be small. httpmw.RateLimitPerMinute(12), @@ -85,7 +88,7 @@ func New(options *Options) (http.Handler, func()) { }) r.Route("/organizations/{organization}", func(r chi.Router) { r.Use( - httpmw.ExtractAPIKey(options.Database, nil), + apiKeyMiddleware, httpmw.ExtractOrganizationParam(options.Database), ) r.Get("/", api.organization) @@ -98,7 +101,7 @@ func New(options *Options) (http.Handler, func()) { }) }) r.Route("/parameters/{scope}/{id}", func(r chi.Router) { - r.Use(httpmw.ExtractAPIKey(options.Database, nil)) + r.Use(apiKeyMiddleware) r.Post("/", api.postParameter) r.Get("/", api.parameters) r.Route("/{name}", func(r chi.Router) { @@ -107,7 +110,7 @@ func New(options *Options) (http.Handler, func()) { }) r.Route("/templates/{template}", func(r chi.Router) { r.Use( - httpmw.ExtractAPIKey(options.Database, nil), + apiKeyMiddleware, httpmw.ExtractTemplateParam(options.Database), httpmw.ExtractOrganizationParam(options.Database), ) @@ -121,7 +124,7 @@ func New(options *Options) (http.Handler, func()) { }) r.Route("/templateversions/{templateversion}", func(r chi.Router) { r.Use( - httpmw.ExtractAPIKey(options.Database, nil), + apiKeyMiddleware, httpmw.ExtractTemplateVersionParam(options.Database), httpmw.ExtractOrganizationParam(options.Database), ) @@ -143,14 +146,14 @@ func New(options *Options) (http.Handler, func()) { r.Post("/first", api.postFirstUser) r.Post("/login", api.postLogin) r.Post("/logout", api.postLogout) - r.Route("/auth", func(r chi.Router) { - r.Route("/callback/github", func(r chi.Router) { - r.Use(httpmw.ExtractOAuth2(options.GithubOAuth2Provider)) - r.Get("/", api.userAuthGithub) + r.Route("/oauth2", func(r chi.Router) { + r.Route("/github", func(r chi.Router) { + r.Use(httpmw.ExtractOAuth2(options.GithubOAuth2Config)) + r.Get("/callback", api.userOAuth2Github) }) }) r.Group(func(r chi.Router) { - r.Use(httpmw.ExtractAPIKey(options.Database, nil)) + r.Use(apiKeyMiddleware) r.Post("/", api.postUsers) r.Route("/{user}", func(r chi.Router) { r.Use(httpmw.ExtractUserParam(options.Database)) @@ -184,7 +187,7 @@ func New(options *Options) (http.Handler, func()) { }) r.Route("/{workspaceagent}", func(r chi.Router) { r.Use( - httpmw.ExtractAPIKey(options.Database, nil), + apiKeyMiddleware, httpmw.ExtractWorkspaceAgentParam(options.Database), ) r.Get("/", api.workspaceAgent) @@ -193,7 +196,7 @@ func New(options *Options) (http.Handler, func()) { }) r.Route("/workspaceresources/{workspaceresource}", func(r chi.Router) { r.Use( - httpmw.ExtractAPIKey(options.Database, nil), + apiKeyMiddleware, httpmw.ExtractWorkspaceResourceParam(options.Database), httpmw.ExtractWorkspaceParam(options.Database), ) @@ -201,7 +204,7 @@ func New(options *Options) (http.Handler, func()) { }) r.Route("/workspaces/{workspace}", func(r chi.Router) { r.Use( - httpmw.ExtractAPIKey(options.Database, nil), + apiKeyMiddleware, httpmw.ExtractWorkspaceParam(options.Database), ) r.Get("/", api.workspace) @@ -219,7 +222,7 @@ func New(options *Options) (http.Handler, func()) { }) r.Route("/workspacebuilds/{workspacebuild}", func(r chi.Router) { r.Use( - httpmw.ExtractAPIKey(options.Database, nil), + apiKeyMiddleware, httpmw.ExtractWorkspaceBuildParam(options.Database), httpmw.ExtractWorkspaceParam(options.Database), ) diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index c2679238a2df4..9bc8b4c699f4b 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -49,7 +49,7 @@ import ( type Options struct { AWSInstanceIdentity awsidentity.Certificates - GithubOAuth2Provider coderd.GithubOAuth2Provider + GithubOAuth2Config *coderd.GithubOAuth2Config GoogleInstanceIdentity *idtoken.Validator SSHKeygenAlgorithm gitsshkey.Algorithm } @@ -116,7 +116,7 @@ func New(t *testing.T, options *Options) *codersdk.Client { Pubsub: pubsub, AWSCertificates: options.AWSInstanceIdentity, - GithubOAuth2Provider: options.GithubOAuth2Provider, + GithubOAuth2Config: options.GithubOAuth2Config, GoogleTokenValidator: options.GoogleInstanceIdentity, SSHKeygenAlgorithm: options.SSHKeygenAlgorithm, }) diff --git a/coderd/database/databasefake/databasefake.go b/coderd/database/databasefake/databasefake.go index 556ba5d94c1f3..75771fbb13f2a 100644 --- a/coderd/database/databasefake/databasefake.go +++ b/coderd/database/databasefake/databasefake.go @@ -797,21 +797,18 @@ func (q *fakeQuerier) InsertAPIKey(_ context.Context, arg database.InsertAPIKeyP //nolint:gosimple key := database.APIKey{ - ID: arg.ID, - HashedSecret: arg.HashedSecret, - UserID: arg.UserID, - Application: arg.Application, - Name: arg.Name, - LastUsed: arg.LastUsed, - ExpiresAt: arg.ExpiresAt, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - LoginType: arg.LoginType, - OIDCAccessToken: arg.OIDCAccessToken, - OIDCRefreshToken: arg.OIDCRefreshToken, - OIDCIDToken: arg.OIDCIDToken, - OIDCExpiry: arg.OIDCExpiry, - DevurlToken: arg.DevurlToken, + ID: arg.ID, + HashedSecret: arg.HashedSecret, + UserID: arg.UserID, + ExpiresAt: arg.ExpiresAt, + CreatedAt: arg.CreatedAt, + UpdatedAt: arg.UpdatedAt, + LastUsed: arg.LastUsed, + LoginType: arg.LoginType, + OAuthAccessToken: arg.OAuthAccessToken, + OAuthRefreshToken: arg.OAuthRefreshToken, + OAuthIDToken: arg.OAuthIDToken, + OAuthExpiry: arg.OAuthExpiry, } q.apiKeys = append(q.apiKeys, key) return key, nil @@ -1126,9 +1123,9 @@ func (q *fakeQuerier) UpdateAPIKeyByID(_ context.Context, arg database.UpdateAPI } apiKey.LastUsed = arg.LastUsed apiKey.ExpiresAt = arg.ExpiresAt - apiKey.OIDCAccessToken = arg.OIDCAccessToken - apiKey.OIDCRefreshToken = arg.OIDCRefreshToken - apiKey.OIDCExpiry = arg.OIDCExpiry + apiKey.OAuthAccessToken = arg.OAuthAccessToken + apiKey.OAuthRefreshToken = arg.OAuthRefreshToken + apiKey.OAuthExpiry = arg.OAuthExpiry q.apiKeys[index] = apiKey return nil } diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index fb8621e2f2f3d..406f07c95e130 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -14,9 +14,8 @@ CREATE TYPE log_source AS ENUM ( ); CREATE TYPE login_type AS ENUM ( - 'built-in', - 'saml', - 'oidc' + 'basic', + 'github' ); CREATE TYPE parameter_destination_scheme AS ENUM ( @@ -67,18 +66,16 @@ CREATE TABLE api_keys ( id text NOT NULL, hashed_secret bytea NOT NULL, user_id uuid NOT NULL, - application boolean NOT NULL, name text NOT NULL, last_used timestamp with time zone NOT NULL, expires_at timestamp with time zone NOT NULL, created_at timestamp with time zone NOT NULL, updated_at timestamp with time zone NOT NULL, login_type login_type NOT NULL, - oidc_access_token text DEFAULT ''::text NOT NULL, - oidc_refresh_token text DEFAULT ''::text NOT NULL, - oidc_id_token text DEFAULT ''::text NOT NULL, - oidc_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL, - devurl_token boolean DEFAULT false NOT NULL + oauth_access_token text DEFAULT ''::text NOT NULL, + oauth_refresh_token text DEFAULT ''::text NOT NULL, + oauth_id_token text DEFAULT ''::text NOT NULL, + oauth_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL ); CREATE TABLE files ( diff --git a/coderd/database/migrations/000001_base.up.sql b/coderd/database/migrations/000001_base.up.sql index 65fbbf8fd4805..92d73d61d95c9 100644 --- a/coderd/database/migrations/000001_base.up.sql +++ b/coderd/database/migrations/000001_base.up.sql @@ -4,14 +4,9 @@ -- All tables and types are stolen from: -- https://github.com/coder/m/blob/47b6fc383347b9f9fab424d829c482defd3e1fe2/product/coder/pkg/database/dump.sql --- --- Name: users; Type: TABLE; Schema: public; Owner: coder --- - CREATE TYPE login_type AS ENUM ( - 'built-in', - 'saml', - 'oidc' + 'basic', + 'github' ); CREATE TABLE IF NOT EXISTS users ( @@ -31,10 +26,6 @@ CREATE UNIQUE INDEX IF NOT EXISTS idx_users_email ON users USING btree (email); CREATE UNIQUE INDEX IF NOT EXISTS idx_users_username ON users USING btree (username); CREATE UNIQUE INDEX IF NOT EXISTS users_username_lower_idx ON users USING btree (lower(username)); --- --- Name: organizations; Type: TABLE; Schema: Owner: coder --- - CREATE TABLE IF NOT EXISTS organizations ( id uuid NOT NULL, name text NOT NULL, @@ -68,18 +59,16 @@ CREATE TABLE IF NOT EXISTS api_keys ( id text NOT NULL, hashed_secret bytea NOT NULL, user_id uuid NOT NULL, - application boolean NOT NULL, name text NOT NULL, last_used timestamp with time zone NOT NULL, expires_at timestamp with time zone NOT NULL, created_at timestamp with time zone NOT NULL, updated_at timestamp with time zone NOT NULL, login_type login_type NOT NULL, - oidc_access_token text DEFAULT ''::text NOT NULL, - oidc_refresh_token text DEFAULT ''::text NOT NULL, - oidc_id_token text DEFAULT ''::text NOT NULL, - oidc_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL, - devurl_token boolean DEFAULT false NOT NULL, + oauth_access_token text DEFAULT ''::text NOT NULL, + oauth_refresh_token text DEFAULT ''::text NOT NULL, + oauth_id_token text DEFAULT ''::text NOT NULL, + oauth_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL, PRIMARY KEY (id) ); diff --git a/coderd/database/models.go b/coderd/database/models.go index a8d311194139e..4de655d09b22a 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -56,9 +56,8 @@ func (e *LogSource) Scan(src interface{}) error { type LoginType string const ( - LoginTypeBuiltIn LoginType = "built-in" - LoginTypeSaml LoginType = "saml" - LoginTypeOIDC LoginType = "oidc" + LoginTypeBasic LoginType = "basic" + LoginTypeGithub LoginType = "github" ) func (e *LoginType) Scan(src interface{}) error { @@ -230,21 +229,19 @@ func (e *WorkspaceTransition) Scan(src interface{}) error { } type APIKey struct { - ID string `db:"id" json:"id"` - HashedSecret []byte `db:"hashed_secret" json:"hashed_secret"` - UserID uuid.UUID `db:"user_id" json:"user_id"` - Application bool `db:"application" json:"application"` - Name string `db:"name" json:"name"` - LastUsed time.Time `db:"last_used" json:"last_used"` - ExpiresAt time.Time `db:"expires_at" json:"expires_at"` - CreatedAt time.Time `db:"created_at" json:"created_at"` - UpdatedAt time.Time `db:"updated_at" json:"updated_at"` - LoginType LoginType `db:"login_type" json:"login_type"` - OIDCAccessToken string `db:"oidc_access_token" json:"oidc_access_token"` - OIDCRefreshToken string `db:"oidc_refresh_token" json:"oidc_refresh_token"` - OIDCIDToken string `db:"oidc_id_token" json:"oidc_id_token"` - OIDCExpiry time.Time `db:"oidc_expiry" json:"oidc_expiry"` - DevurlToken bool `db:"devurl_token" json:"devurl_token"` + ID string `db:"id" json:"id"` + HashedSecret []byte `db:"hashed_secret" json:"hashed_secret"` + UserID uuid.UUID `db:"user_id" json:"user_id"` + Name string `db:"name" json:"name"` + LastUsed time.Time `db:"last_used" json:"last_used"` + ExpiresAt time.Time `db:"expires_at" json:"expires_at"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + LoginType LoginType `db:"login_type" json:"login_type"` + OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"` + OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"` + OAuthIDToken string `db:"oauth_id_token" json:"oauth_id_token"` + OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"` } type File struct { diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index c43a84959ddd6..eef9f5e57f4cf 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -15,7 +15,7 @@ import ( const getAPIKeyByID = `-- name: GetAPIKeyByID :one SELECT - id, hashed_secret, user_id, application, name, last_used, expires_at, created_at, updated_at, login_type, oidc_access_token, oidc_refresh_token, oidc_id_token, oidc_expiry, devurl_token + id, hashed_secret, user_id, name, last_used, expires_at, created_at, updated_at, login_type, oauth_access_token, oauth_refresh_token, oauth_id_token, oauth_expiry FROM api_keys WHERE @@ -31,18 +31,16 @@ func (q *sqlQuerier) GetAPIKeyByID(ctx context.Context, id string) (APIKey, erro &i.ID, &i.HashedSecret, &i.UserID, - &i.Application, &i.Name, &i.LastUsed, &i.ExpiresAt, &i.CreatedAt, &i.UpdatedAt, &i.LoginType, - &i.OIDCAccessToken, - &i.OIDCRefreshToken, - &i.OIDCIDToken, - &i.OIDCExpiry, - &i.DevurlToken, + &i.OAuthAccessToken, + &i.OAuthRefreshToken, + &i.OAuthIDToken, + &i.OAuthExpiry, ) return i, err } @@ -53,55 +51,33 @@ INSERT INTO id, hashed_secret, user_id, - application, - "name", last_used, expires_at, created_at, updated_at, login_type, - oidc_access_token, - oidc_refresh_token, - oidc_id_token, - oidc_expiry, - devurl_token + oauth_access_token, + oauth_refresh_token, + oauth_id_token, + oauth_expiry ) VALUES - ( - $1, - $2, - $3, - $4, - $5, - $6, - $7, - $8, - $9, - $10, - $11, - $12, - $13, - $14, - $15 - ) RETURNING id, hashed_secret, user_id, application, name, last_used, expires_at, created_at, updated_at, login_type, oidc_access_token, oidc_refresh_token, oidc_id_token, oidc_expiry, devurl_token + ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) RETURNING id, hashed_secret, user_id, name, last_used, expires_at, created_at, updated_at, login_type, oauth_access_token, oauth_refresh_token, oauth_id_token, oauth_expiry ` type InsertAPIKeyParams struct { - ID string `db:"id" json:"id"` - HashedSecret []byte `db:"hashed_secret" json:"hashed_secret"` - UserID uuid.UUID `db:"user_id" json:"user_id"` - Application bool `db:"application" json:"application"` - Name string `db:"name" json:"name"` - LastUsed time.Time `db:"last_used" json:"last_used"` - ExpiresAt time.Time `db:"expires_at" json:"expires_at"` - CreatedAt time.Time `db:"created_at" json:"created_at"` - UpdatedAt time.Time `db:"updated_at" json:"updated_at"` - LoginType LoginType `db:"login_type" json:"login_type"` - OIDCAccessToken string `db:"oidc_access_token" json:"oidc_access_token"` - OIDCRefreshToken string `db:"oidc_refresh_token" json:"oidc_refresh_token"` - OIDCIDToken string `db:"oidc_id_token" json:"oidc_id_token"` - OIDCExpiry time.Time `db:"oidc_expiry" json:"oidc_expiry"` - DevurlToken bool `db:"devurl_token" json:"devurl_token"` + ID string `db:"id" json:"id"` + HashedSecret []byte `db:"hashed_secret" json:"hashed_secret"` + UserID uuid.UUID `db:"user_id" json:"user_id"` + LastUsed time.Time `db:"last_used" json:"last_used"` + ExpiresAt time.Time `db:"expires_at" json:"expires_at"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + LoginType LoginType `db:"login_type" json:"login_type"` + OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"` + OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"` + OAuthIDToken string `db:"oauth_id_token" json:"oauth_id_token"` + OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"` } func (q *sqlQuerier) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (APIKey, error) { @@ -109,36 +85,31 @@ func (q *sqlQuerier) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) ( arg.ID, arg.HashedSecret, arg.UserID, - arg.Application, - arg.Name, arg.LastUsed, arg.ExpiresAt, arg.CreatedAt, arg.UpdatedAt, arg.LoginType, - arg.OIDCAccessToken, - arg.OIDCRefreshToken, - arg.OIDCIDToken, - arg.OIDCExpiry, - arg.DevurlToken, + arg.OAuthAccessToken, + arg.OAuthRefreshToken, + arg.OAuthIDToken, + arg.OAuthExpiry, ) var i APIKey err := row.Scan( &i.ID, &i.HashedSecret, &i.UserID, - &i.Application, &i.Name, &i.LastUsed, &i.ExpiresAt, &i.CreatedAt, &i.UpdatedAt, &i.LoginType, - &i.OIDCAccessToken, - &i.OIDCRefreshToken, - &i.OIDCIDToken, - &i.OIDCExpiry, - &i.DevurlToken, + &i.OAuthAccessToken, + &i.OAuthRefreshToken, + &i.OAuthIDToken, + &i.OAuthExpiry, ) return i, err } @@ -149,20 +120,20 @@ UPDATE SET last_used = $2, expires_at = $3, - oidc_access_token = $4, - oidc_refresh_token = $5, - oidc_expiry = $6 + oauth_access_token = $4, + oauth_refresh_token = $5, + oauth_expiry = $6 WHERE id = $1 ` type UpdateAPIKeyByIDParams struct { - ID string `db:"id" json:"id"` - LastUsed time.Time `db:"last_used" json:"last_used"` - ExpiresAt time.Time `db:"expires_at" json:"expires_at"` - OIDCAccessToken string `db:"oidc_access_token" json:"oidc_access_token"` - OIDCRefreshToken string `db:"oidc_refresh_token" json:"oidc_refresh_token"` - OIDCExpiry time.Time `db:"oidc_expiry" json:"oidc_expiry"` + ID string `db:"id" json:"id"` + LastUsed time.Time `db:"last_used" json:"last_used"` + ExpiresAt time.Time `db:"expires_at" json:"expires_at"` + OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"` + OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"` + OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"` } func (q *sqlQuerier) UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error { @@ -170,9 +141,9 @@ func (q *sqlQuerier) UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDP arg.ID, arg.LastUsed, arg.ExpiresAt, - arg.OIDCAccessToken, - arg.OIDCRefreshToken, - arg.OIDCExpiry, + arg.OAuthAccessToken, + arg.OAuthRefreshToken, + arg.OAuthExpiry, ) return err } diff --git a/coderd/database/queries/apikeys.sql b/coderd/database/queries/apikeys.sql index 62dc38ed2ca59..1af2016f491bf 100644 --- a/coderd/database/queries/apikeys.sql +++ b/coderd/database/queries/apikeys.sql @@ -14,37 +14,18 @@ INSERT INTO id, hashed_secret, user_id, - application, - "name", last_used, expires_at, created_at, updated_at, login_type, - oidc_access_token, - oidc_refresh_token, - oidc_id_token, - oidc_expiry, - devurl_token + oauth_access_token, + oauth_refresh_token, + oauth_id_token, + oauth_expiry ) VALUES - ( - $1, - $2, - $3, - $4, - $5, - $6, - $7, - $8, - $9, - $10, - $11, - $12, - $13, - $14, - $15 - ) RETURNING *; + ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) RETURNING *; -- name: UpdateAPIKeyByID :exec UPDATE @@ -52,8 +33,8 @@ UPDATE SET last_used = $2, expires_at = $3, - oidc_access_token = $4, - oidc_refresh_token = $5, - oidc_expiry = $6 + oauth_access_token = $4, + oauth_refresh_token = $5, + oauth_expiry = $6 WHERE id = $1; diff --git a/coderd/database/sqlc.yaml b/coderd/database/sqlc.yaml index a009644cdf520..abde7029c3c79 100644 --- a/coderd/database/sqlc.yaml +++ b/coderd/database/sqlc.yaml @@ -21,10 +21,10 @@ overrides: rename: api_key: APIKey login_type_oidc: LoginTypeOIDC - oidc_access_token: OIDCAccessToken - oidc_expiry: OIDCExpiry - oidc_id_token: OIDCIDToken - oidc_refresh_token: OIDCRefreshToken + oauth_access_token: OAuthAccessToken + oauth_expiry: OAuthExpiry + oauth_id_token: OAuthIDToken + oauth_refresh_token: OAuthRefreshToken parameter_type_system_hcl: ParameterTypeSystemHCL userstatus: UserStatus gitsshkey: GitSSHKey diff --git a/coderd/httpmw/apikey.go b/coderd/httpmw/apikey.go index 1b18bc56bcde6..e6defe8b2395b 100644 --- a/coderd/httpmw/apikey.go +++ b/coderd/httpmw/apikey.go @@ -20,12 +20,6 @@ import ( // AuthCookie represents the name of the cookie the API key is stored in. const AuthCookie = "session_token" -// OAuth2Config contains a subset of functions exposed from oauth2.Config. -// It is abstracted for simple testing. -type OAuth2Config interface { - TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource -} - type apiKeyContextKey struct{} // APIKey returns the API key from the ExtractAPIKey handler. @@ -37,10 +31,16 @@ func APIKey(r *http.Request) database.APIKey { return apiKey } +// OAuth2Configs is a collection of configurations for OAuth-based authentication. +// This should be extended to support other authentication types in the future. +type OAuth2Configs struct { + Github OAuth2Config +} + // ExtractAPIKey requires authentication using a valid API key. // It handles extending an API key if it comes close to expiry, // updating the last used time in the database. -func ExtractAPIKey(db database.Store, oauthConfig OAuth2Config) func(http.Handler) http.Handler { +func ExtractAPIKey(db database.Store, oauth *OAuth2Configs) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { cookie, err := r.Cookie(AuthCookie) @@ -99,14 +99,24 @@ func ExtractAPIKey(db database.Store, oauthConfig OAuth2Config) func(http.Handle // Tracks if the API key has properties updated! changed := false - if key.LoginType == database.LoginTypeOIDC { - // Check if the OIDC token is expired! - if key.OIDCExpiry.Before(now) && !key.OIDCExpiry.IsZero() { + if key.LoginType != database.LoginTypeBasic { + // Check if the OAuth token is expired! + if key.OAuthExpiry.Before(now) && !key.OAuthExpiry.IsZero() { + var oauthConfig OAuth2Config + switch key.LoginType { + case database.LoginTypeGithub: + oauthConfig = oauth.Github + default: + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("unexpected authentication type %q", key.LoginType), + }) + return + } // If it is, let's refresh it from the provided config! token, err := oauthConfig.TokenSource(r.Context(), &oauth2.Token{ - AccessToken: key.OIDCAccessToken, - RefreshToken: key.OIDCRefreshToken, - Expiry: key.OIDCExpiry, + AccessToken: key.OAuthAccessToken, + RefreshToken: key.OAuthRefreshToken, + Expiry: key.OAuthExpiry, }).Token() if err != nil { httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{ @@ -114,9 +124,9 @@ func ExtractAPIKey(db database.Store, oauthConfig OAuth2Config) func(http.Handle }) return } - key.OIDCAccessToken = token.AccessToken - key.OIDCRefreshToken = token.RefreshToken - key.OIDCExpiry = token.Expiry + key.OAuthAccessToken = token.AccessToken + key.OAuthRefreshToken = token.RefreshToken + key.OAuthExpiry = token.Expiry key.ExpiresAt = token.Expiry changed = true } @@ -136,21 +146,20 @@ func ExtractAPIKey(db database.Store, oauthConfig OAuth2Config) func(http.Handle changed = true } // Only update the ExpiresAt once an hour to prevent database spam. - // We extend the ExpiresAt to reduce reauthentication. + // We extend the ExpiresAt to reduce re-authentication. apiKeyLifetime := 24 * time.Hour if key.ExpiresAt.Sub(now) <= apiKeyLifetime-time.Hour { key.ExpiresAt = now.Add(apiKeyLifetime) changed = true } - if changed { err := db.UpdateAPIKeyByID(r.Context(), database.UpdateAPIKeyByIDParams{ - ID: key.ID, - ExpiresAt: key.ExpiresAt, - LastUsed: key.LastUsed, - OIDCAccessToken: key.OIDCAccessToken, - OIDCRefreshToken: key.OIDCRefreshToken, - OIDCExpiry: key.OIDCExpiry, + ID: key.ID, + LastUsed: key.LastUsed, + ExpiresAt: key.ExpiresAt, + OAuthAccessToken: key.OAuthAccessToken, + OAuthRefreshToken: key.OAuthRefreshToken, + OAuthExpiry: key.OAuthExpiry, }) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ diff --git a/coderd/httpmw/apikey_test.go b/coderd/httpmw/apikey_test.go index 2d4e7c3a6be67..0c8d8d396e55b 100644 --- a/coderd/httpmw/apikey_test.go +++ b/coderd/httpmw/apikey_test.go @@ -189,7 +189,6 @@ func TestAPIKey(t *testing.T) { sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ ID: id, HashedSecret: hashed[:], - LastUsed: database.Now(), ExpiresAt: database.Now().AddDate(0, 0, 1), }) require.NoError(t, err) @@ -207,7 +206,6 @@ func TestAPIKey(t *testing.T) { gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id) require.NoError(t, err) - require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed) require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt) }) @@ -277,7 +275,7 @@ func TestAPIKey(t *testing.T) { require.NotEqual(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt) }) - t.Run("OIDCNotExpired", func(t *testing.T) { + t.Run("OAuthNotExpired", func(t *testing.T) { t.Parallel() var ( db = databasefake.New() @@ -294,7 +292,7 @@ func TestAPIKey(t *testing.T) { sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ ID: id, HashedSecret: hashed[:], - LoginType: database.LoginTypeOIDC, + LoginType: database.LoginTypeGithub, LastUsed: database.Now(), ExpiresAt: database.Now().AddDate(0, 0, 1), }) @@ -311,7 +309,7 @@ func TestAPIKey(t *testing.T) { require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt) }) - t.Run("OIDCRefresh", func(t *testing.T) { + t.Run("OAuthRefresh", func(t *testing.T) { t.Parallel() var ( db = databasefake.New() @@ -328,9 +326,9 @@ func TestAPIKey(t *testing.T) { sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ ID: id, HashedSecret: hashed[:], - LoginType: database.LoginTypeOIDC, + LoginType: database.LoginTypeGithub, LastUsed: database.Now(), - OIDCExpiry: database.Now().AddDate(0, 0, -1), + OAuthExpiry: database.Now().AddDate(0, 0, -1), }) require.NoError(t, err) token := &oauth2.Token{ @@ -338,11 +336,11 @@ func TestAPIKey(t *testing.T) { RefreshToken: "moo", Expiry: database.Now().AddDate(0, 0, 1), } - httpmw.ExtractAPIKey(db, &oauth2Config{ - tokenSource: &oauth2TokenSource{ - token: func() (*oauth2.Token, error) { + httpmw.ExtractAPIKey(db, &httpmw.OAuth2Configs{ + Github: &oauth2Config{ + tokenSource: oauth2TokenSource(func() (*oauth2.Token, error) { return token, nil - }, + }), }, })(successHandler).ServeHTTP(rw, r) res := rw.Result() @@ -354,22 +352,28 @@ func TestAPIKey(t *testing.T) { require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed) require.Equal(t, token.Expiry, gotAPIKey.ExpiresAt) - require.Equal(t, token.AccessToken, gotAPIKey.OIDCAccessToken) + require.Equal(t, token.AccessToken, gotAPIKey.OAuthAccessToken) }) } type oauth2Config struct { - tokenSource *oauth2TokenSource + tokenSource oauth2TokenSource } -func (o *oauth2Config) TokenSource(_ context.Context, _ *oauth2.Token) oauth2.TokenSource { +func (o *oauth2Config) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource { return o.tokenSource } -type oauth2TokenSource struct { - token func() (*oauth2.Token, error) +func (*oauth2Config) AuthCodeURL(string, ...oauth2.AuthCodeOption) string { + return "" } -func (o *oauth2TokenSource) Token() (*oauth2.Token, error) { - return o.token() +func (*oauth2Config) Exchange(context.Context, string, ...oauth2.AuthCodeOption) (*oauth2.Token, error) { + return &oauth2.Token{}, nil +} + +type oauth2TokenSource func() (*oauth2.Token, error) + +func (o oauth2TokenSource) Token() (*oauth2.Token, error) { + return o() } diff --git a/coderd/httpmw/oauth.go b/coderd/httpmw/oauth2.go similarity index 83% rename from coderd/httpmw/oauth.go rename to coderd/httpmw/oauth2.go index f96b112fbaa2d..899a994475798 100644 --- a/coderd/httpmw/oauth.go +++ b/coderd/httpmw/oauth2.go @@ -23,11 +23,12 @@ type OAuth2State struct { Redirect string } -// OAuth2Provider exposes a subset of *oauth2.Config functions for easier testing. +// OAuth2Config exposes a subset of *oauth2.Config functions for easier testing. // *oauth2.Config should be used instead of implementing this in production. -type OAuth2Provider interface { +type OAuth2Config interface { AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) + TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource } // OAuth2 returns the state from an oauth request. @@ -42,9 +43,16 @@ func OAuth2(r *http.Request) OAuth2State { // ExtractOAuth2 adds a middleware for handling OAuth2 callbacks. // Any route that does not have a "code" URL parameter will be redirected // to the handler configuration provided. -func ExtractOAuth2(provider OAuth2Provider) func(http.Handler) http.Handler { +func ExtractOAuth2(config OAuth2Config) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + if config == nil { + httpapi.Write(rw, http.StatusPreconditionRequired, httpapi.Response{ + Message: fmt.Sprintf("The oauth2 method requested is not configured!"), + }) + return + } + code := r.URL.Query().Get("code") state := r.URL.Query().Get("state") @@ -75,7 +83,7 @@ func ExtractOAuth2(provider OAuth2Provider) func(http.Handler) http.Handler { SameSite: http.SameSiteStrictMode, }) - http.Redirect(rw, r, provider.AuthCodeURL(state, oauth2.AccessTypeOffline), http.StatusTemporaryRedirect) + http.Redirect(rw, r, config.AuthCodeURL(state, oauth2.AccessTypeOffline), http.StatusTemporaryRedirect) return } @@ -106,7 +114,7 @@ func ExtractOAuth2(provider OAuth2Provider) func(http.Handler) http.Handler { redirect = stateRedirect.Value } - oauthToken, err := provider.Exchange(r.Context(), code) + oauthToken, err := config.Exchange(r.Context(), code) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ Message: fmt.Sprintf("exchange oauth code: %s", err), diff --git a/coderd/httpmw/oauth_test.go b/coderd/httpmw/oauth2_test.go similarity index 78% rename from coderd/httpmw/oauth_test.go rename to coderd/httpmw/oauth2_test.go index 6142b15853243..31803b7351487 100644 --- a/coderd/httpmw/oauth_test.go +++ b/coderd/httpmw/oauth2_test.go @@ -17,18 +17,29 @@ import ( type testOAuth2Provider struct { } -func (*testOAuth2Provider) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string { +func (*testOAuth2Provider) AuthCodeURL(state string, _ ...oauth2.AuthCodeOption) string { return "?state=" + url.QueryEscape(state) } -func (*testOAuth2Provider) Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) { +func (*testOAuth2Provider) Exchange(_ context.Context, _ string, _ ...oauth2.AuthCodeOption) (*oauth2.Token, error) { return &oauth2.Token{ AccessToken: "hello", }, nil } +func (*testOAuth2Provider) TokenSource(_ context.Context, _ *oauth2.Token) oauth2.TokenSource { + return nil +} + func TestOAuth2(t *testing.T) { t.Parallel() + t.Run("NotSetup", func(t *testing.T) { + t.Parallel() + req := httptest.NewRequest("GET", "/", nil) + res := httptest.NewRecorder() + httpmw.ExtractOAuth2(nil)(nil).ServeHTTP(res, req) + require.Equal(t, http.StatusPreconditionRequired, res.Result().StatusCode) + }) t.Run("RedirectWithoutCode", func(t *testing.T) { t.Parallel() req := httptest.NewRequest("GET", "/?redirect="+url.QueryEscape("/dashboard"), nil) @@ -84,14 +95,4 @@ func TestOAuth2(t *testing.T) { require.Equal(t, "/dashboard", state.Redirect) })).ServeHTTP(res, req) }) - - // t.Run("ExchangeCodeAndState", func(t *testing.T) { - // t.Parallel() - // req := httptest.NewRequest("GET", "/?code=test&state="+url.QueryEscape(state), nil) - // res := httptest.NewRecorder() - // ExtractOAuth(log, cipher, &testOAuthProvider{})(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - // rw.WriteHeader(http.StatusOK) - // })).ServeHTTP(res, req) - // assert.Equal(t, res.Result().StatusCode, http.StatusOK) - // }) } diff --git a/coderd/httpmw/organizationparam_test.go b/coderd/httpmw/organizationparam_test.go index 02887260feea0..33fb7ee311c3d 100644 --- a/coderd/httpmw/organizationparam_test.go +++ b/coderd/httpmw/organizationparam_test.go @@ -41,7 +41,7 @@ func TestOrganizationParam(t *testing.T) { ID: userID, Email: "testaccount@coder.com", Name: "example", - LoginType: database.LoginTypeBuiltIn, + LoginType: database.LoginTypeBasic, HashedPassword: hashed[:], Username: username, CreatedAt: database.Now(), diff --git a/coderd/httpmw/templateparam_test.go b/coderd/httpmw/templateparam_test.go index 47089713d612f..69269ff7bfa00 100644 --- a/coderd/httpmw/templateparam_test.go +++ b/coderd/httpmw/templateparam_test.go @@ -40,7 +40,7 @@ func TestTemplateParam(t *testing.T) { ID: userID, Email: "testaccount@coder.com", Name: "example", - LoginType: database.LoginTypeBuiltIn, + LoginType: database.LoginTypeBasic, HashedPassword: hashed[:], Username: username, CreatedAt: database.Now(), diff --git a/coderd/httpmw/templateversionparam_test.go b/coderd/httpmw/templateversionparam_test.go index 025b646f2ae58..1254b3c79c2b1 100644 --- a/coderd/httpmw/templateversionparam_test.go +++ b/coderd/httpmw/templateversionparam_test.go @@ -40,7 +40,7 @@ func TestTemplateVersionParam(t *testing.T) { ID: userID, Email: "testaccount@coder.com", Name: "example", - LoginType: database.LoginTypeBuiltIn, + LoginType: database.LoginTypeBasic, HashedPassword: hashed[:], Username: username, CreatedAt: database.Now(), diff --git a/coderd/httpmw/workspaceagentparam_test.go b/coderd/httpmw/workspaceagentparam_test.go index f014a8bd55b55..3c8bf884a799d 100644 --- a/coderd/httpmw/workspaceagentparam_test.go +++ b/coderd/httpmw/workspaceagentparam_test.go @@ -40,7 +40,7 @@ func TestWorkspaceAgentParam(t *testing.T) { ID: userID, Email: "testaccount@coder.com", Name: "example", - LoginType: database.LoginTypeBuiltIn, + LoginType: database.LoginTypeBasic, HashedPassword: hashed[:], Username: username, CreatedAt: database.Now(), diff --git a/coderd/httpmw/workspacebuildparam_test.go b/coderd/httpmw/workspacebuildparam_test.go index 62eb6f975765c..ff1338ed31e93 100644 --- a/coderd/httpmw/workspacebuildparam_test.go +++ b/coderd/httpmw/workspacebuildparam_test.go @@ -40,7 +40,7 @@ func TestWorkspaceBuildParam(t *testing.T) { ID: userID, Email: "testaccount@coder.com", Name: "example", - LoginType: database.LoginTypeBuiltIn, + LoginType: database.LoginTypeBasic, HashedPassword: hashed[:], Username: username, CreatedAt: database.Now(), diff --git a/coderd/httpmw/workspaceparam_test.go b/coderd/httpmw/workspaceparam_test.go index 5c169a0d10218..72107530b7d60 100644 --- a/coderd/httpmw/workspaceparam_test.go +++ b/coderd/httpmw/workspaceparam_test.go @@ -40,7 +40,7 @@ func TestWorkspaceParam(t *testing.T) { ID: userID, Email: "testaccount@coder.com", Name: "example", - LoginType: database.LoginTypeBuiltIn, + LoginType: database.LoginTypeBasic, HashedPassword: hashed[:], Username: username, CreatedAt: database.Now(), diff --git a/coderd/userauth.go b/coderd/userauth.go deleted file mode 100644 index 080045ec1e4fe..0000000000000 --- a/coderd/userauth.go +++ /dev/null @@ -1,132 +0,0 @@ -package coderd - -import ( - "context" - "database/sql" - "errors" - "fmt" - "net/http" - - "github.com/google/go-github/v43/github" - "github.com/google/uuid" - "golang.org/x/oauth2" - - "github.com/coder/coder/coderd/database" - "github.com/coder/coder/coderd/httpapi" - "github.com/coder/coder/coderd/httpmw" - "github.com/coder/coder/codersdk" -) - -// GithubOAuth2Provider exposes required functions for the Github authentication flow. -type GithubOAuth2Provider interface { - httpmw.OAuth2Provider - PersonalUser(ctx context.Context, client *github.Client) (*github.User, error) - ListEmails(ctx context.Context, client *github.Client) ([]*github.UserEmail, error) -} - -func (api *api) userAuthGithub(rw http.ResponseWriter, r *http.Request) { - state := httpmw.OAuth2(r) - - ghClient := github.NewClient(oauth2.NewClient(r.Context(), oauth2.StaticTokenSource(state.Token))) - ghUser, err := api.GithubOAuth2Provider.PersonalUser(r.Context(), ghClient) - if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ - Message: fmt.Sprintf("get personal github user: %s", err), - }) - return - } - emails, err := api.GithubOAuth2Provider.ListEmails(r.Context(), ghClient) - if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ - Message: fmt.Sprintf("get personal github user: %s", err), - }) - return - } - var user database.User - // Search for existing users with matching and verified emails. - // If a verified GitHub email matches a Coder user, we will - // return. - for _, email := range emails { - if email.Verified == nil { - continue - } - if !*email.Verified { - continue - } - user, err = api.Database.GetUserByEmailOrUsername(r.Context(), database.GetUserByEmailOrUsernameParams{ - Username: *ghUser.Name, - Email: *email.Email, - }) - if errors.Is(err, sql.ErrNoRows) { - continue - } - if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ - Message: fmt.Sprintf("get user by email: %s", err), - }) - return - } - break - } - // If the user doesn't exist, create a new one! - if user.ID == uuid.Nil { - userCount, err := api.Database.GetUserCount(r.Context()) - if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ - Message: fmt.Sprintf("get user count: %s", err.Error()), - }) - return - } - var organization database.Organization - // If there aren't any users yet, create one! - if userCount == 0 { - organization, err = api.Database.InsertOrganization(r.Context(), database.InsertOrganizationParams{ - ID: uuid.New(), - Name: *ghUser.Name, - CreatedAt: database.Now(), - UpdatedAt: database.Now(), - }) - if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ - Message: fmt.Sprintf("create organization: %s", err), - }) - return - } - } else { - organizations, err := api.Database.GetOrganizations(r.Context()) - if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ - Message: fmt.Sprintf("get organizations: %s", err), - }) - return - } - // Add the user to the first organization. Once multi-organization - // support is added, we should enable a configuration map of user - // email to organization. - organization = organizations[0] - } - - user, err = api.createUser(r.Context(), api.Database, codersdk.CreateUserRequest{ - Email: *ghUser.Email, - Username: *ghUser.Name, - OrganizationID: organization.ID, - }) - if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ - Message: fmt.Sprintf("create user: %s", err), - }) - return - } - } - - _, created := api.createAPIKey(rw, r, user.ID) - if !created { - return - } - - redirect := state.Redirect - if redirect == "" { - redirect = "/" - } - http.Redirect(rw, r, redirect, http.StatusTemporaryRedirect) -} diff --git a/coderd/userauth_test.go b/coderd/userauth_test.go deleted file mode 100644 index 3db48546bd406..0000000000000 --- a/coderd/userauth_test.go +++ /dev/null @@ -1,147 +0,0 @@ -package coderd_test - -import ( - "context" - "net/http" - "net/url" - "testing" - - "github.com/coder/coder/coderd/coderdtest" - "github.com/coder/coder/codersdk" - "github.com/google/go-github/v43/github" - "github.com/stretchr/testify/require" - "golang.org/x/oauth2" -) - -type githubOAuthProvider struct{} - -func (g *githubOAuthProvider) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string { - return "/?state=" + url.QueryEscape(state) -} - -func (g *githubOAuthProvider) Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) { - return &oauth2.Token{ - AccessToken: "token", - }, nil -} - -func (g *githubOAuthProvider) PersonalUser(ctx context.Context, client *github.Client) (*github.User, error) { - return &github.User{ - ID: github.Int64(1), - Login: github.String("testuser"), - Name: github.String("some user"), - Email: github.String("wow@test.io"), - AvatarURL: github.String("https://coder.com/avatar.png"), - }, nil -} - -func (g *githubOAuthProvider) ListEmails(ctx context.Context, client *github.Client) ([]*github.UserEmail, error) { - return []*github.UserEmail{{ - Email: github.String("someone@io.io"), - Primary: github.Bool(true), - Verified: github.Bool(true), - }, { - Email: github.String("ok@io.io"), - Primary: github.Bool(false), - Verified: github.Bool(false), - }}, nil -} - -func TestUserAuthGithub(t *testing.T) { - t.Parallel() - t.Run("FirstUser", func(t *testing.T) { - t.Parallel() - client := coderdtest.New(t, &coderdtest.Options{ - GithubOAuth2Provider: &githubOAuthProvider{}, - }) - client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - } - - state := "somestate" - oauthURL, err := client.URL.Parse("/api/v2/users/auth/callback/github?code=asd&state=" + state) - require.NoError(t, err) - req, err := http.NewRequest("GET", oauthURL.String(), nil) - require.NoError(t, err) - req.AddCookie(&http.Cookie{ - Name: "oauth_state", - Value: state, - }) - res, err := client.HTTPClient.Do(req) - require.NoError(t, err) - defer res.Body.Close() - require.Equal(t, http.StatusTemporaryRedirect, res.StatusCode) - }) - t.Run("NewUser", func(t *testing.T) { - t.Parallel() - client := coderdtest.New(t, &coderdtest.Options{ - GithubOAuth2Provider: &githubOAuthProvider{}, - }) - client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - } - _, err := client.CreateFirstUser(context.Background(), codersdk.CreateFirstUserRequest{ - Email: "someone@io.io", - Username: "someone", - Password: "testing", - OrganizationName: "acme-corp", - }) - require.NoError(t, err) - token, err := client.LoginWithPassword(context.Background(), codersdk.LoginWithPasswordRequest{ - Email: "someone@io.io", - Password: "testing", - }) - require.NoError(t, err) - client.SessionToken = token.SessionToken - - state := "somestate" - oauthURL, err := client.URL.Parse("/api/v2/users/auth/callback/github?code=asd&state=" + state) - require.NoError(t, err) - req, err := http.NewRequest("GET", oauthURL.String(), nil) - require.NoError(t, err) - req.AddCookie(&http.Cookie{ - Name: "oauth_state", - Value: state, - }) - res, err := client.HTTPClient.Do(req) - require.NoError(t, err) - defer res.Body.Close() - require.Equal(t, http.StatusTemporaryRedirect, res.StatusCode) - }) - t.Run("ExistingUser", func(t *testing.T) { - t.Parallel() - client := coderdtest.New(t, &coderdtest.Options{ - GithubOAuth2Provider: &githubOAuthProvider{}, - }) - client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - } - _, err := client.CreateFirstUser(context.Background(), codersdk.CreateFirstUserRequest{ - Email: "someone@io.io", - Username: "someone", - Password: "testing", - OrganizationName: "acme-corp", - }) - require.NoError(t, err) - token, err := client.LoginWithPassword(context.Background(), codersdk.LoginWithPasswordRequest{ - Email: "someone@io.io", - Password: "testing", - }) - require.NoError(t, err) - client.SessionToken = token.SessionToken - - state := "somestate" - oauthURL, err := client.URL.Parse("/api/v2/users/auth/callback/github?code=asd&state=" + state) - require.NoError(t, err) - req, err := http.NewRequest("GET", oauthURL.String(), nil) - require.NoError(t, err) - req.AddCookie(&http.Cookie{ - Name: "oauth_state", - Value: state, - }) - res, err := client.HTTPClient.Do(req) - require.NoError(t, err) - defer res.Body.Close() - require.Equal(t, http.StatusTemporaryRedirect, res.StatusCode) - }) -} diff --git a/coderd/useroauth2.go b/coderd/useroauth2.go new file mode 100644 index 0000000000000..427eea1de5e41 --- /dev/null +++ b/coderd/useroauth2.go @@ -0,0 +1,145 @@ +package coderd + +import ( + "context" + "database/sql" + "errors" + "fmt" + "net/http" + + "github.com/google/go-github/v43/github" + "github.com/google/uuid" + "golang.org/x/oauth2" + + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/httpapi" + "github.com/coder/coder/coderd/httpmw" + "github.com/coder/coder/codersdk" +) + +// GithubOAuth2Provider exposes required functions for the Github authentication flow. +type GithubOAuth2Config struct { + httpmw.OAuth2Config + AuthenticatedUser func(ctx context.Context, client *http.Client) (*github.User, error) + ListEmails func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error) + ListOrganizations func(ctx context.Context, client *http.Client) ([]*github.Organization, error) + + AllowSignups bool + AllowOrganizations []string +} + +func (api *api) userOAuth2Github(rw http.ResponseWriter, r *http.Request) { + state := httpmw.OAuth2(r) + + oauthClient := oauth2.NewClient(r.Context(), oauth2.StaticTokenSource(state.Token)) + organizations, err := api.GithubOAuth2Config.ListOrganizations(r.Context(), oauthClient) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get authenticated github user organizations: %s", err), + }) + return + } + var selectedOrganization *github.Organization + for _, organization := range organizations { + if organization.Login == nil { + continue + } + for _, allowed := range api.GithubOAuth2Config.AllowOrganizations { + if *organization.Login != allowed { + continue + } + selectedOrganization = organization + break + } + } + if selectedOrganization == nil { + httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{ + Message: fmt.Sprintf("You aren't a member of the authorized Github organizations!"), + }) + return + } + + emails, err := api.GithubOAuth2Config.ListEmails(r.Context(), oauthClient) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get personal github user: %s", err), + }) + return + } + + var user database.User + // Search for existing users with matching and verified emails. + // If a verified GitHub email matches a Coder user, we will return. + for _, email := range emails { + if email.Verified == nil { + continue + } + user, err = api.Database.GetUserByEmailOrUsername(r.Context(), database.GetUserByEmailOrUsernameParams{ + Email: *email.Email, + }) + if errors.Is(err, sql.ErrNoRows) { + continue + } + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get user by email: %s", err), + }) + return + } + if !*email.Verified { + httpapi.Write(rw, http.StatusForbidden, httpapi.Response{ + Message: fmt.Sprintf("Verify the %q email address on Github to authenticate!", *email.Email), + }) + return + } + break + } + + // If the user doesn't exist, create a new one! + if user.ID == uuid.Nil { + if !api.GithubOAuth2Config.AllowSignups { + httpapi.Write(rw, http.StatusForbidden, httpapi.Response{ + Message: "Signups are disabled for Github authentication!", + }) + return + } + + var organizationID uuid.UUID + organizations, err := api.Database.GetOrganizations(r.Context()) + if err == nil { + // Add the user to the first organization. Once multi-organization + // support is added, we should enable a configuration map of user + // email to organization. + organizationID = organizations[0].ID + } + ghUser, err := api.GithubOAuth2Config.AuthenticatedUser(r.Context(), oauthClient) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get authenticated github user: %s", err), + }) + return + } + user, _, err = api.createUser(r.Context(), codersdk.CreateUserRequest{ + Email: *ghUser.Email, + Username: *ghUser.Login, + OrganizationID: organizationID, + }) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("create user: %s", err), + }) + return + } + } + + _, created := api.createAPIKey(rw, r, user.ID) + if !created { + return + } + + redirect := state.Redirect + if redirect == "" { + redirect = "/" + } + http.Redirect(rw, r, redirect, http.StatusTemporaryRedirect) +} diff --git a/coderd/useroauth2_test.go b/coderd/useroauth2_test.go new file mode 100644 index 0000000000000..7f649418a46d0 --- /dev/null +++ b/coderd/useroauth2_test.go @@ -0,0 +1,173 @@ +package coderd_test + +import ( + "context" + "net/http" + "net/url" + "testing" + + "github.com/google/go-github/v43/github" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" + + "github.com/coder/coder/coderd" + "github.com/coder/coder/coderd/coderdtest" + "github.com/coder/coder/codersdk" +) + +type oauth2Config struct{} + +func (*oauth2Config) AuthCodeURL(state string, _ ...oauth2.AuthCodeOption) string { + return "/?state=" + url.QueryEscape(state) +} + +func (*oauth2Config) Exchange(context.Context, string, ...oauth2.AuthCodeOption) (*oauth2.Token, error) { + return &oauth2.Token{ + AccessToken: "token", + }, nil +} + +func (*oauth2Config) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource { + return nil +} + +func TestUserOAuth2Github(t *testing.T) { + t.Parallel() + t.Run("NotInAllowedOrganization", func(t *testing.T) { + t.Parallel() + client := coderdtest.New(t, &coderdtest.Options{ + GithubOAuth2Config: &coderd.GithubOAuth2Config{ + OAuth2Config: &oauth2Config{}, + ListOrganizations: func(ctx context.Context, client *http.Client) ([]*github.Organization, error) { + return []*github.Organization{{ + Login: github.String("kyle"), + }}, nil + }, + }, + }) + + resp := oauth2Callback(t, client) + require.Equal(t, http.StatusUnauthorized, resp.StatusCode) + }) + t.Run("UnverifiedEmail", func(t *testing.T) { + t.Parallel() + client := coderdtest.New(t, &coderdtest.Options{ + GithubOAuth2Config: &coderd.GithubOAuth2Config{ + OAuth2Config: &oauth2Config{}, + AllowOrganizations: []string{"coder"}, + ListOrganizations: func(ctx context.Context, client *http.Client) ([]*github.Organization, error) { + return []*github.Organization{{ + Login: github.String("coder"), + }}, nil + }, + AuthenticatedUser: func(ctx context.Context, client *http.Client) (*github.User, error) { + return &github.User{}, nil + }, + ListEmails: func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error) { + return []*github.UserEmail{{ + Email: github.String("testuser@coder.com"), + Verified: github.Bool(false), + }}, nil + }, + }, + }) + _ = coderdtest.CreateFirstUser(t, client) + resp := oauth2Callback(t, client) + require.Equal(t, http.StatusForbidden, resp.StatusCode) + }) + t.Run("BlockSignups", func(t *testing.T) { + t.Parallel() + client := coderdtest.New(t, &coderdtest.Options{ + GithubOAuth2Config: &coderd.GithubOAuth2Config{ + OAuth2Config: &oauth2Config{}, + AllowOrganizations: []string{"coder"}, + ListOrganizations: func(ctx context.Context, client *http.Client) ([]*github.Organization, error) { + return []*github.Organization{{ + Login: github.String("coder"), + }}, nil + }, + AuthenticatedUser: func(ctx context.Context, client *http.Client) (*github.User, error) { + return &github.User{}, nil + }, + ListEmails: func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error) { + return []*github.UserEmail{}, nil + }, + }, + }) + resp := oauth2Callback(t, client) + require.Equal(t, http.StatusForbidden, resp.StatusCode) + }) + t.Run("Signup", func(t *testing.T) { + t.Parallel() + client := coderdtest.New(t, &coderdtest.Options{ + GithubOAuth2Config: &coderd.GithubOAuth2Config{ + OAuth2Config: &oauth2Config{}, + AllowOrganizations: []string{"coder"}, + AllowSignups: true, + ListOrganizations: func(ctx context.Context, client *http.Client) ([]*github.Organization, error) { + return []*github.Organization{{ + Login: github.String("coder"), + }}, nil + }, + AuthenticatedUser: func(ctx context.Context, client *http.Client) (*github.User, error) { + return &github.User{ + Login: github.String("kyle"), + Email: github.String("kyle@coder.com"), + }, nil + }, + ListEmails: func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error) { + return []*github.UserEmail{}, nil + }, + }, + }) + resp := oauth2Callback(t, client) + require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) + }) + t.Run("Login", func(t *testing.T) { + t.Parallel() + client := coderdtest.New(t, &coderdtest.Options{ + GithubOAuth2Config: &coderd.GithubOAuth2Config{ + OAuth2Config: &oauth2Config{}, + AllowOrganizations: []string{"coder"}, + ListOrganizations: func(ctx context.Context, client *http.Client) ([]*github.Organization, error) { + return []*github.Organization{{ + Login: github.String("coder"), + }}, nil + }, + AuthenticatedUser: func(ctx context.Context, client *http.Client) (*github.User, error) { + return &github.User{}, nil + }, + ListEmails: func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error) { + return []*github.UserEmail{{ + Email: github.String("testuser@coder.com"), + Verified: github.Bool(true), + }}, nil + }, + }, + }) + _ = coderdtest.CreateFirstUser(t, client) + resp := oauth2Callback(t, client) + require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) + }) +} + +func oauth2Callback(t *testing.T, client *codersdk.Client) *http.Response { + client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + } + state := "somestate" + oauthURL, err := client.URL.Parse("/api/v2/users/oauth2/github/callback?code=asd&state=" + state) + require.NoError(t, err) + req, err := http.NewRequest("GET", oauthURL.String(), nil) + require.NoError(t, err) + req.AddCookie(&http.Cookie{ + Name: "oauth_state", + Value: state, + }) + res, err := client.HTTPClient.Do(req) + require.NoError(t, err) + t.Cleanup(func() { + _ = res.Body.Close() + }) + return res +} diff --git a/coderd/users.go b/coderd/users.go index 030c22976aedb..ee9fd97f5496c 100644 --- a/coderd/users.go +++ b/coderd/users.go @@ -70,29 +70,10 @@ func (api *api) postFirstUser(rw http.ResponseWriter, r *http.Request) { return } - // Create the user, organization, and membership to the user. - var user database.User - var organization database.Organization - err = api.Database.InTx(func(db database.Store) error { - organization, err = api.Database.InsertOrganization(r.Context(), database.InsertOrganizationParams{ - ID: uuid.New(), - Name: createUser.OrganizationName, - CreatedAt: database.Now(), - UpdatedAt: database.Now(), - }) - if err != nil { - return xerrors.Errorf("create organization: %w", err) - } - user, err = api.createUser(r.Context(), db, codersdk.CreateUserRequest{ - Email: createUser.Email, - Username: createUser.Username, - Password: createUser.Password, - OrganizationID: organization.ID, - }) - if err != nil { - return xerrors.Errorf("create user: %w", err) - } - return nil + user, organizationID, err := api.createUser(r.Context(), codersdk.CreateUserRequest{ + Email: createUser.Email, + Username: createUser.Username, + Password: createUser.Password, }) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ @@ -103,7 +84,7 @@ func (api *api) postFirstUser(rw http.ResponseWriter, r *http.Request) { httpapi.Write(rw, http.StatusCreated, codersdk.CreateFirstUserResponse{ UserID: user.ID, - OrganizationID: organization.ID, + OrganizationID: organizationID, }) } @@ -163,7 +144,7 @@ func (api *api) postUsers(rw http.ResponseWriter, r *http.Request) { return } - user, err := api.createUser(r.Context(), api.Database, createUser) + user, _, err := api.createUser(r.Context(), createUser) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ Message: err.Error(), @@ -809,7 +790,6 @@ func (api *api) createAPIKey(rw http.ResponseWriter, r *http.Request, userID uui CreatedAt: database.Now(), UpdatedAt: database.Now(), HashedSecret: hashed[:], - LoginType: database.LoginTypeBuiltIn, }) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ @@ -831,52 +811,70 @@ func (api *api) createAPIKey(rw http.ResponseWriter, r *http.Request, userID uui return sessionToken, true } -func (api *api) createUser(ctx context.Context, db database.Store, req codersdk.CreateUserRequest) (database.User, error) { - params := database.InsertUserParams{ - ID: uuid.New(), - Email: req.Email, - Username: req.Username, - LoginType: database.LoginTypeBuiltIn, - CreatedAt: database.Now(), - UpdatedAt: database.Now(), - } - // If a user signs up with OAuth, they can have no password! - if req.Password != "" { - hashedPassword, err := userpassword.Hash(req.Password) - if err != nil { - return database.User{}, xerrors.Errorf("hash password: %w", err) +func (api *api) createUser(ctx context.Context, req codersdk.CreateUserRequest) (database.User, uuid.UUID, error) { + var user database.User + return user, req.OrganizationID, api.Database.InTx(func(db database.Store) error { + // If no organization is provided, create a new one for the user. + if req.OrganizationID == uuid.Nil { + organization, err := db.InsertOrganization(ctx, database.InsertOrganizationParams{ + ID: uuid.New(), + Name: req.Username, + CreatedAt: database.Now(), + UpdatedAt: database.Now(), + }) + if err != nil { + return xerrors.Errorf("create organization: %w", err) + } + req.OrganizationID = organization.ID } - params.HashedPassword = []byte(hashedPassword) - } - user, err := db.InsertUser(ctx, params) - if err != nil { - return database.User{}, xerrors.Errorf("create user: %w", err) - } + params := database.InsertUserParams{ + ID: uuid.New(), + Email: req.Email, + Username: req.Username, + LoginType: database.LoginTypeBasic, + CreatedAt: database.Now(), + UpdatedAt: database.Now(), + } + // If a user signs up with OAuth, they can have no password! + if req.Password != "" { + hashedPassword, err := userpassword.Hash(req.Password) + if err != nil { + return xerrors.Errorf("hash password: %w", err) + } + params.HashedPassword = []byte(hashedPassword) + } - privateKey, publicKey, err := gitsshkey.Generate(api.SSHKeygenAlgorithm) - if err != nil { - return database.User{}, xerrors.Errorf("generate user gitsshkey: %w", err) - } - _, err = db.InsertGitSSHKey(ctx, database.InsertGitSSHKeyParams{ - UserID: user.ID, - CreatedAt: database.Now(), - UpdatedAt: database.Now(), - PrivateKey: privateKey, - PublicKey: publicKey, - }) - if err != nil { - return database.User{}, xerrors.Errorf("insert user gitsshkey: %w", err) - } - _, err = db.InsertOrganizationMember(ctx, database.InsertOrganizationMemberParams{ - OrganizationID: req.OrganizationID, - UserID: user.ID, - CreatedAt: database.Now(), - UpdatedAt: database.Now(), - Roles: []string{}, + var err error + user, err = db.InsertUser(ctx, params) + if err != nil { + return xerrors.Errorf("create user: %w", err) + } + + privateKey, publicKey, err := gitsshkey.Generate(api.SSHKeygenAlgorithm) + if err != nil { + return xerrors.Errorf("generate user gitsshkey: %w", err) + } + _, err = db.InsertGitSSHKey(ctx, database.InsertGitSSHKeyParams{ + UserID: user.ID, + CreatedAt: database.Now(), + UpdatedAt: database.Now(), + PrivateKey: privateKey, + PublicKey: publicKey, + }) + if err != nil { + return xerrors.Errorf("insert user gitsshkey: %w", err) + } + _, err = db.InsertOrganizationMember(ctx, database.InsertOrganizationMemberParams{ + OrganizationID: req.OrganizationID, + UserID: user.ID, + CreatedAt: database.Now(), + UpdatedAt: database.Now(), + Roles: []string{}, + }) + if err != nil { + return xerrors.Errorf("create organization member: %w", err) + } + return nil }) - if err != nil { - return database.User{}, xerrors.Errorf("create organization member: %w", err) - } - return user, nil } diff --git a/coderd/users_test.go b/coderd/users_test.go index d733f022ae560..2caf8167325d2 100644 --- a/coderd/users_test.go +++ b/coderd/users_test.go @@ -240,13 +240,14 @@ func TestUpdateUserProfile(t *testing.T) { t.Parallel() client := coderdtest.New(t, nil) user := coderdtest.CreateFirstUser(t, client) - existentUser, _ := client.CreateUser(context.Background(), codersdk.CreateUserRequest{ + existentUser, err := client.CreateUser(context.Background(), codersdk.CreateUserRequest{ Email: "bruno@coder.com", Username: "bruno", Password: "password", OrganizationID: user.OrganizationID, }) - _, err := client.UpdateUserProfile(context.Background(), codersdk.Me, codersdk.UpdateUserProfileRequest{ + require.NoError(t, err) + _, err = client.UpdateUserProfile(context.Background(), codersdk.Me, codersdk.UpdateUserProfileRequest{ Username: existentUser.Username, Email: "newemail@coder.com", }) From 53c2bf444a28366e77b081bba3a2835c5fc8db5f Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Sun, 17 Apr 2022 22:19:29 +0000 Subject: [PATCH 3/9] Add AuthMethods endpoint --- cli/start.go | 9 ++- coderd/coderd.go | 1 + coderd/{useroauth2.go => userauth.go} | 36 +++++++---- .../{useroauth2_test.go => userauth_test.go} | 62 ++++++++++++++----- coderd/users.go | 29 ++++++--- codersdk/users.go | 22 +++++++ 6 files changed, 119 insertions(+), 40 deletions(-) rename coderd/{useroauth2.go => userauth.go} (76%) rename coderd/{useroauth2_test.go => userauth_test.go} (72%) diff --git a/cli/start.go b/cli/start.go index ce3613ef5e609..c9b215b1c4d41 100644 --- a/cli/start.go +++ b/cli/start.go @@ -557,6 +557,7 @@ func configureGithubOAuth2(accessURL *url.URL, clientID, clientSecret string) (* RedirectURL: redirectURL.String(), Scopes: []string{ "read:user", + "read:org", "user:email", }, }, @@ -570,9 +571,11 @@ func configureGithubOAuth2(accessURL *url.URL, clientID, clientSecret string) (* emails, _, err := github.NewClient(client).Users.ListEmails(ctx, &github.ListOptions{}) return emails, err }, - ListOrganizations: func(ctx context.Context, client *http.Client) ([]*github.Organization, error) { - orgs, _, err := github.NewClient(client).Organizations.List(ctx, "", &github.ListOptions{}) - return orgs, err + ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) { + memberships, _, err := github.NewClient(client).Organizations.ListOrgMemberships(ctx, &github.ListOrgMembershipsOptions{ + State: "active", + }) + return memberships, err }, }, nil } diff --git a/coderd/coderd.go b/coderd/coderd.go index fe1f5bf8b4fd7..33ea38a60584e 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -146,6 +146,7 @@ func New(options *Options) (http.Handler, func()) { r.Post("/first", api.postFirstUser) r.Post("/login", api.postLogin) r.Post("/logout", api.postLogout) + r.Get("/authmethods", api.userAuthMethods) r.Route("/oauth2", func(r chi.Router) { r.Route("/github", func(r chi.Router) { r.Use(httpmw.ExtractOAuth2(options.GithubOAuth2Config)) diff --git a/coderd/useroauth2.go b/coderd/userauth.go similarity index 76% rename from coderd/useroauth2.go rename to coderd/userauth.go index 427eea1de5e41..57f2ba9bdab82 100644 --- a/coderd/useroauth2.go +++ b/coderd/userauth.go @@ -20,39 +20,43 @@ import ( // GithubOAuth2Provider exposes required functions for the Github authentication flow. type GithubOAuth2Config struct { httpmw.OAuth2Config - AuthenticatedUser func(ctx context.Context, client *http.Client) (*github.User, error) - ListEmails func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error) - ListOrganizations func(ctx context.Context, client *http.Client) ([]*github.Organization, error) + AuthenticatedUser func(ctx context.Context, client *http.Client) (*github.User, error) + ListEmails func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error) + ListOrganizationMemberships func(ctx context.Context, client *http.Client) ([]*github.Membership, error) AllowSignups bool AllowOrganizations []string } +func (api *api) userAuthMethods(rw http.ResponseWriter, _ *http.Request) { + httpapi.Write(rw, http.StatusOK, codersdk.AuthMethods{ + Password: true, + Github: api.GithubOAuth2Config != nil, + }) +} + func (api *api) userOAuth2Github(rw http.ResponseWriter, r *http.Request) { state := httpmw.OAuth2(r) oauthClient := oauth2.NewClient(r.Context(), oauth2.StaticTokenSource(state.Token)) - organizations, err := api.GithubOAuth2Config.ListOrganizations(r.Context(), oauthClient) + memberships, err := api.GithubOAuth2Config.ListOrganizationMemberships(r.Context(), oauthClient) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ Message: fmt.Sprintf("get authenticated github user organizations: %s", err), }) return } - var selectedOrganization *github.Organization - for _, organization := range organizations { - if organization.Login == nil { - continue - } + var selectedMembership *github.Membership + for _, membership := range memberships { for _, allowed := range api.GithubOAuth2Config.AllowOrganizations { - if *organization.Login != allowed { + if *membership.Organization.Login != allowed { continue } - selectedOrganization = organization + selectedMembership = membership break } } - if selectedOrganization == nil { + if selectedMembership == nil { httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{ Message: fmt.Sprintf("You aren't a member of the authorized Github organizations!"), }) @@ -132,7 +136,13 @@ func (api *api) userOAuth2Github(rw http.ResponseWriter, r *http.Request) { } } - _, created := api.createAPIKey(rw, r, user.ID) + _, created := api.createAPIKey(rw, r, database.InsertAPIKeyParams{ + UserID: user.ID, + LoginType: database.LoginTypeGithub, + OAuthAccessToken: state.Token.AccessToken, + OAuthRefreshToken: state.Token.RefreshToken, + OAuthExpiry: state.Token.Expiry, + }) if !created { return } diff --git a/coderd/useroauth2_test.go b/coderd/userauth_test.go similarity index 72% rename from coderd/useroauth2_test.go rename to coderd/userauth_test.go index 7f649418a46d0..b16af4f1250e7 100644 --- a/coderd/useroauth2_test.go +++ b/coderd/userauth_test.go @@ -31,6 +31,28 @@ func (*oauth2Config) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSou return nil } +func TestUserAuthMethods(t *testing.T) { + t.Parallel() + t.Run("Basic", func(t *testing.T) { + t.Parallel() + client := coderdtest.New(t, nil) + methods, err := client.AuthMethods(context.Background()) + require.NoError(t, err) + require.True(t, methods.Password) + require.False(t, methods.Github) + }) + t.Run("Github", func(t *testing.T) { + t.Parallel() + client := coderdtest.New(t, &coderdtest.Options{ + GithubOAuth2Config: &coderd.GithubOAuth2Config{}, + }) + methods, err := client.AuthMethods(context.Background()) + require.NoError(t, err) + require.True(t, methods.Password) + require.True(t, methods.Github) + }) +} + func TestUserOAuth2Github(t *testing.T) { t.Parallel() t.Run("NotInAllowedOrganization", func(t *testing.T) { @@ -38,9 +60,11 @@ func TestUserOAuth2Github(t *testing.T) { client := coderdtest.New(t, &coderdtest.Options{ GithubOAuth2Config: &coderd.GithubOAuth2Config{ OAuth2Config: &oauth2Config{}, - ListOrganizations: func(ctx context.Context, client *http.Client) ([]*github.Organization, error) { - return []*github.Organization{{ - Login: github.String("kyle"), + ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) { + return []*github.Membership{{ + Organization: &github.Organization{ + Login: github.String("kyle"), + }, }}, nil }, }, @@ -55,9 +79,11 @@ func TestUserOAuth2Github(t *testing.T) { GithubOAuth2Config: &coderd.GithubOAuth2Config{ OAuth2Config: &oauth2Config{}, AllowOrganizations: []string{"coder"}, - ListOrganizations: func(ctx context.Context, client *http.Client) ([]*github.Organization, error) { - return []*github.Organization{{ - Login: github.String("coder"), + ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) { + return []*github.Membership{{ + Organization: &github.Organization{ + Login: github.String("coder"), + }, }}, nil }, AuthenticatedUser: func(ctx context.Context, client *http.Client) (*github.User, error) { @@ -81,9 +107,11 @@ func TestUserOAuth2Github(t *testing.T) { GithubOAuth2Config: &coderd.GithubOAuth2Config{ OAuth2Config: &oauth2Config{}, AllowOrganizations: []string{"coder"}, - ListOrganizations: func(ctx context.Context, client *http.Client) ([]*github.Organization, error) { - return []*github.Organization{{ - Login: github.String("coder"), + ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) { + return []*github.Membership{{ + Organization: &github.Organization{ + Login: github.String("coder"), + }, }}, nil }, AuthenticatedUser: func(ctx context.Context, client *http.Client) (*github.User, error) { @@ -104,9 +132,11 @@ func TestUserOAuth2Github(t *testing.T) { OAuth2Config: &oauth2Config{}, AllowOrganizations: []string{"coder"}, AllowSignups: true, - ListOrganizations: func(ctx context.Context, client *http.Client) ([]*github.Organization, error) { - return []*github.Organization{{ - Login: github.String("coder"), + ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) { + return []*github.Membership{{ + Organization: &github.Organization{ + Login: github.String("coder"), + }, }}, nil }, AuthenticatedUser: func(ctx context.Context, client *http.Client) (*github.User, error) { @@ -129,9 +159,11 @@ func TestUserOAuth2Github(t *testing.T) { GithubOAuth2Config: &coderd.GithubOAuth2Config{ OAuth2Config: &oauth2Config{}, AllowOrganizations: []string{"coder"}, - ListOrganizations: func(ctx context.Context, client *http.Client) ([]*github.Organization, error) { - return []*github.Organization{{ - Login: github.String("coder"), + ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) { + return []*github.Membership{{ + Organization: &github.Organization{ + Login: github.String("coder"), + }, }}, nil }, AuthenticatedUser: func(ctx context.Context, client *http.Client) (*github.User, error) { diff --git a/coderd/users.go b/coderd/users.go index ee9fd97f5496c..406940b2b81e9 100644 --- a/coderd/users.go +++ b/coderd/users.go @@ -375,7 +375,10 @@ func (api *api) postLogin(rw http.ResponseWriter, r *http.Request) { return } - sessionToken, created := api.createAPIKey(rw, r, user.ID) + sessionToken, created := api.createAPIKey(rw, r, database.InsertAPIKeyParams{ + UserID: user.ID, + LoginType: database.LoginTypeBasic, + }) if !created { return } @@ -397,7 +400,10 @@ func (api *api) postAPIKey(rw http.ResponseWriter, r *http.Request) { return } - sessionToken, created := api.createAPIKey(rw, r, user.ID) + sessionToken, created := api.createAPIKey(rw, r, database.InsertAPIKeyParams{ + UserID: user.ID, + LoginType: database.LoginTypeBasic, + }) if !created { return } @@ -773,7 +779,7 @@ func convertUser(user database.User) codersdk.User { } } -func (api *api) createAPIKey(rw http.ResponseWriter, r *http.Request, userID uuid.UUID) (string, bool) { +func (api *api) createAPIKey(rw http.ResponseWriter, r *http.Request, params database.InsertAPIKeyParams) (string, bool) { keyID, keySecret, err := generateAPIKeyIDSecret() if err != nil { httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ @@ -784,12 +790,17 @@ func (api *api) createAPIKey(rw http.ResponseWriter, r *http.Request, userID uui hashed := sha256.Sum256([]byte(keySecret)) _, err = api.Database.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ - ID: keyID, - UserID: userID, - ExpiresAt: database.Now().Add(24 * time.Hour), - CreatedAt: database.Now(), - UpdatedAt: database.Now(), - HashedSecret: hashed[:], + ID: keyID, + UserID: params.UserID, + ExpiresAt: database.Now().Add(24 * time.Hour), + CreatedAt: database.Now(), + UpdatedAt: database.Now(), + HashedSecret: hashed[:], + LoginType: params.LoginType, + OAuthAccessToken: params.OAuthAccessToken, + OAuthRefreshToken: params.OAuthRefreshToken, + OAuthIDToken: params.OAuthIDToken, + OAuthExpiry: params.OAuthExpiry, }) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ diff --git a/codersdk/users.go b/codersdk/users.go index d6a920a4c4bdc..363e83d4ad3ea 100644 --- a/codersdk/users.go +++ b/codersdk/users.go @@ -77,6 +77,12 @@ type CreateWorkspaceRequest struct { ParameterValues []CreateParameterRequest `json:"parameter_values"` } +// AuthMethods contains whether authentication types are enabled or not. +type AuthMethods struct { + Password bool `json:"password"` + Github bool `json:"github"` +} + // HasFirstUser returns whether the first user has been created. func (c *Client) HasFirstUser(ctx context.Context) (bool, error) { res, err := c.request(ctx, http.MethodGet, "/api/v2/users/first", nil) @@ -287,6 +293,22 @@ func (c *Client) WorkspaceByName(ctx context.Context, userID uuid.UUID, name str return workspace, json.NewDecoder(res.Body).Decode(&workspace) } +// AuthMethods returns types of authentication available to the user. +func (c *Client) AuthMethods(ctx context.Context) (AuthMethods, error) { + res, err := c.request(ctx, http.MethodGet, "/api/v2/users/authmethods", nil) + if err != nil { + return AuthMethods{}, err + } + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + return AuthMethods{}, readBodyAsError(res) + } + + var userAuth AuthMethods + return userAuth, json.NewDecoder(res.Body).Decode(&userAuth) +} + // uuidOrMe returns the provided uuid as a string if it's valid, ortherwise // `me`. func uuidOrMe(id uuid.UUID) string { From 259515685c56fce2364bb35341c90a29fae885f9 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Sun, 17 Apr 2022 22:25:15 +0000 Subject: [PATCH 4/9] Add frontend --- site/src/api/index.ts | 5 +++ site/src/api/types.ts | 5 +++ site/src/components/SignIn/SignInForm.tsx | 23 +++++++++-- site/src/pages/login.test.tsx | 31 +++++++++++++- site/src/pages/login.tsx | 2 +- site/src/xServices/auth/authXService.ts | 49 +++++++++++++++++++++-- 6 files changed, 105 insertions(+), 10 deletions(-) diff --git a/site/src/api/index.ts b/site/src/api/index.ts index 7d2b664f357fa..d79914bef74ec 100644 --- a/site/src/api/index.ts +++ b/site/src/api/index.ts @@ -65,6 +65,11 @@ export const getUser = async (): Promise => { return response.data } +export const getAuthMethods = async (): Promise => { + const response = await axios.get("/api/v2/users/authmethods") + return response.data +} + export const getApiKey = async (): Promise => { const response = await axios.post("/api/v2/users/me/keys") return response.data diff --git a/site/src/api/types.ts b/site/src/api/types.ts index 7b95a64743174..ec90a2c5acd9f 100644 --- a/site/src/api/types.ts +++ b/site/src/api/types.ts @@ -18,6 +18,11 @@ export interface UserResponse { readonly name: string } +export interface AuthMethods { + readonly password: boolean + readonly github: boolean +} + /** * `Organization` must be kept in sync with the go struct in organizations.go */ diff --git a/site/src/components/SignIn/SignInForm.tsx b/site/src/components/SignIn/SignInForm.tsx index 917211082a906..ddee53fe20baf 100644 --- a/site/src/components/SignIn/SignInForm.tsx +++ b/site/src/components/SignIn/SignInForm.tsx @@ -1,11 +1,14 @@ +import Button from "@material-ui/core/Button" import FormHelperText from "@material-ui/core/FormHelperText" +import Link from "@material-ui/core/Link" import { makeStyles } from "@material-ui/core/styles" import TextField from "@material-ui/core/TextField" import { FormikContextType, useFormik } from "formik" import React from "react" import * as Yup from "yup" +import { AuthMethods } from "../../api/types" +import { LoadingButton } from "../Button" import { getFormHelpers, onChangeTrimmed } from "../Form" -import { LoadingButton } from "./../Button" import { Welcome } from "./Welcome" /** @@ -24,7 +27,8 @@ export const Language = { emailInvalid: "Please enter a valid email address.", emailRequired: "Please enter an email address.", authErrorMessage: "Incorrect email or password.", - signIn: "Sign In", + basicSignIn: "Sign In", + githubSignIn: "GitHub", } const validationSchema = Yup.object({ @@ -49,10 +53,11 @@ const useStyles = makeStyles((theme) => ({ export interface SignInFormProps { isLoading: boolean authErrorMessage?: string + authMethods?: AuthMethods onSubmit: ({ email, password }: { email: string; password: string }) => Promise } -export const SignInForm: React.FC = ({ isLoading, authErrorMessage, onSubmit }) => { +export const SignInForm: React.FC = ({ authMethods, isLoading, authErrorMessage, onSubmit }) => { const styles = useStyles() const form: FormikContextType = useFormik({ @@ -76,6 +81,7 @@ export const SignInForm: React.FC = ({ isLoading, authErrorMess className={styles.loginTextField} fullWidth label={Language.emailLabel} + type="email" variant="outlined" /> = ({ isLoading, authErrorMess {authErrorMessage && {Language.authErrorMessage}}
- {isLoading ? "" : Language.signIn} + {isLoading ? "" : Language.basicSignIn}
+ {authMethods?.github && ( +
+ + + +
+ )} ) } diff --git a/site/src/pages/login.test.tsx b/site/src/pages/login.test.tsx index f54379f7490fb..4d3d643b31292 100644 --- a/site/src/pages/login.test.tsx +++ b/site/src/pages/login.test.tsx @@ -16,6 +16,15 @@ describe("SignInPage", () => { return res(ctx.status(401), ctx.json({ message: "no user here" })) }), ) + // only leave password auth enabled by default + server.use( + rest.get("/api/v2/users/auth", (req, res, ctx) => { + return res(ctx.status(200), ctx.json({ + password: true, + github: false, + })) + }) + ) }) it("renders the sign-in form", async () => { @@ -23,7 +32,7 @@ describe("SignInPage", () => { render() // Then - await screen.findByText(Language.signIn) + await screen.findByText(Language.basicSignIn) }) it("shows an error message if SignIn fails", async () => { @@ -42,7 +51,7 @@ describe("SignInPage", () => { await userEvent.type(email, "test@coder.com") await userEvent.type(password, "password") // Click sign-in - const signInButton = await screen.findByText(Language.signIn) + const signInButton = await screen.findByText(Language.basicSignIn) act(() => signInButton.click()) // Then @@ -50,4 +59,22 @@ describe("SignInPage", () => { expect(errorMessage).toBeDefined() expect(history.location.pathname).toEqual("/login") }) + + it("shows github authentication when enabled", async () => { + // Given + server.use( + rest.get("/api/v2/users/auth", async (req, res, ctx) => { + return res(ctx.status(200), ctx.json({ + password: true, + github: true, + })) + }), + ) + + // When + render() + + // Then + await screen.findByText(Language.githubSignIn) + }) }) diff --git a/site/src/pages/login.tsx b/site/src/pages/login.tsx index fdde9fcc61f33..93782e48d2e73 100644 --- a/site/src/pages/login.tsx +++ b/site/src/pages/login.tsx @@ -47,7 +47,7 @@ export const SignInPage: React.FC = () => {
- +