From 305f696bb3127b6a30047af6aab16d052e19a73d Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Tue, 9 Aug 2022 19:47:53 +0000 Subject: [PATCH 01/32] fix: use unique ID for linked accounts --- coderd/database/generate.sh | 2 + .../migrations/000034_linked_user_id.down.sql | 6 + .../migrations/000034_linked_user_id.up.sql | 19 +++ coderd/database/models.go | 6 + coderd/database/querier.go | 3 + coderd/database/queries.sql.go | 56 +++++++++ coderd/database/queries/user_auth.sql | 24 ++++ coderd/userauth.go | 117 +++++++++++++++--- 8 files changed, 213 insertions(+), 20 deletions(-) create mode 100644 coderd/database/migrations/000034_linked_user_id.down.sql create mode 100644 coderd/database/migrations/000034_linked_user_id.up.sql create mode 100644 coderd/database/queries/user_auth.sql diff --git a/coderd/database/generate.sh b/coderd/database/generate.sh index e00b0ae73a425..60f9cd2e226fb 100755 --- a/coderd/database/generate.sh +++ b/coderd/database/generate.sh @@ -13,6 +13,8 @@ SCRIPT_DIR=$(dirname "${BASH_SOURCE[0]}") ( cd "$SCRIPT_DIR" + # Dump the updated schema. + go run dump/main.go # The logic below depends on the exact version being correct :( go run github.com/kyleconroy/sqlc/cmd/sqlc@v1.13.0 generate diff --git a/coderd/database/migrations/000034_linked_user_id.down.sql b/coderd/database/migrations/000034_linked_user_id.down.sql new file mode 100644 index 0000000000000..6e8d37f7e7cf7 --- /dev/null +++ b/coderd/database/migrations/000034_linked_user_id.down.sql @@ -0,0 +1,6 @@ +BEGIN; + +ALTER TABLE users DROP COLUMN linked_id; +ALTER TABLE users DROP COLUMN login_type; + +COMMIT; diff --git a/coderd/database/migrations/000034_linked_user_id.up.sql b/coderd/database/migrations/000034_linked_user_id.up.sql new file mode 100644 index 0000000000000..51b90a807de3f --- /dev/null +++ b/coderd/database/migrations/000034_linked_user_id.up.sql @@ -0,0 +1,19 @@ +BEGIN; + +ALTER TABLE users ADD COLUMN login_type login_type NOT NULL DEFAULT 'password'; +ALTER TABLE users ADD COLUMN linked_id text NOT NULL DEFAULT ''; + +UPDATE + users +SET + login_type = ( + SELECT + login_type + FROM + api_keys + WHERE + api_keys.user_id = users.id + ORDER BY updated_at DESC + LIMIT 1 + ) +COMMIT; diff --git a/coderd/database/models.go b/coderd/database/models.go index 6cf4c07761674..a4df558504761 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -493,6 +493,12 @@ type User struct { RBACRoles []string `db:"rbac_roles" json:"rbac_roles"` } +type UserAuth struct { + UserID uuid.UUID `db:"user_id" json:"user_id"` + LoginType LoginType `db:"login_type" json:"login_type"` + LinkedID string `db:"linked_id" json:"linked_id"` +} + type Workspace struct { ID uuid.UUID `db:"id" json:"id"` CreatedAt time.Time `db:"created_at" json:"created_at"` diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 90e9a3a0a1385..f37424e58df78 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -61,6 +61,8 @@ type querier interface { GetTemplateVersionsCreatedAfter(ctx context.Context, createdAt time.Time) ([]TemplateVersion, error) GetTemplates(ctx context.Context) ([]Template, error) GetTemplatesWithFilter(ctx context.Context, arg GetTemplatesWithFilterParams) ([]Template, error) + GetUserAuthByLinkedID(ctx context.Context, linkedID string) (UserAuth, error) + GetUserAuthByUserID(ctx context.Context, userID uuid.UUID) (UserAuth, error) GetUserByEmailOrUsername(ctx context.Context, arg GetUserByEmailOrUsernameParams) (User, error) GetUserByID(ctx context.Context, id uuid.UUID) (User, error) GetUserCount(ctx context.Context) (int64, error) @@ -106,6 +108,7 @@ type querier interface { InsertTemplate(ctx context.Context, arg InsertTemplateParams) (Template, error) InsertTemplateVersion(ctx context.Context, arg InsertTemplateVersionParams) (TemplateVersion, error) InsertUser(ctx context.Context, arg InsertUserParams) (User, error) + InsertUserAuth(ctx context.Context, arg InsertUserAuthParams) (UserAuth, error) InsertWorkspace(ctx context.Context, arg InsertWorkspaceParams) (Workspace, error) InsertWorkspaceAgent(ctx context.Context, arg InsertWorkspaceAgentParams) (WorkspaceAgent, error) InsertWorkspaceApp(ctx context.Context, arg InsertWorkspaceAppParams) (WorkspaceApp, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index d36262a121ee2..3fde41e01a8c0 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -2444,6 +2444,62 @@ func (q *sqlQuerier) UpdateTemplateVersionDescriptionByJobID(ctx context.Context return err } +const getUserAuthByLinkedID = `-- name: GetUserAuthByLinkedID :one +SELECT + user_id, login_type, linked_id +FROM + user_auth +WHERE + linked_id = $1 +` + +func (q *sqlQuerier) GetUserAuthByLinkedID(ctx context.Context, linkedID string) (UserAuth, error) { + row := q.db.QueryRowContext(ctx, getUserAuthByLinkedID, linkedID) + var i UserAuth + err := row.Scan(&i.UserID, &i.LoginType, &i.LinkedID) + return i, err +} + +const getUserAuthByUserID = `-- name: GetUserAuthByUserID :one +SELECT + user_id, login_type, linked_id +FROM + user_auth +WHERE + user_id = $1 +` + +func (q *sqlQuerier) GetUserAuthByUserID(ctx context.Context, userID uuid.UUID) (UserAuth, error) { + row := q.db.QueryRowContext(ctx, getUserAuthByUserID, userID) + var i UserAuth + err := row.Scan(&i.UserID, &i.LoginType, &i.LinkedID) + return i, err +} + +const insertUserAuth = `-- name: InsertUserAuth :one +INSERT INTO + user_auth ( + user_id, + login_type, + linked_id + ) +VALUES + ( $1, $2, $3) RETURNING user_id, login_type, linked_id +` + +type InsertUserAuthParams struct { + UserID uuid.UUID `db:"user_id" json:"user_id"` + LoginType LoginType `db:"login_type" json:"login_type"` + LinkedID string `db:"linked_id" json:"linked_id"` +} + +func (q *sqlQuerier) InsertUserAuth(ctx context.Context, arg InsertUserAuthParams) (UserAuth, error) { + row := q.db.QueryRowContext(ctx, insertUserAuth, arg.UserID, arg.LoginType, arg.LinkedID) + var i UserAuth + err := row.Scan(&i.UserID, &i.LoginType, &i.LinkedID) + return i, err +} + const getAuthorizationUserRoles = `-- name: GetAuthorizationUserRoles :one SELECT -- username is returned just to help for logging purposes diff --git a/coderd/database/queries/user_auth.sql b/coderd/database/queries/user_auth.sql new file mode 100644 index 0000000000000..9228c96ff6116 --- /dev/null +++ b/coderd/database/queries/user_auth.sql @@ -0,0 +1,24 @@ +-- name: GetUserAuthByUserID :one +SELECT + * +FROM + user_auth +WHERE + user_id = $1; +-- name: InsertUserAuth :one +INSERT INTO + user_auth ( + user_id, + login_type, + linked_id + ) +VALUES + ( $1, $2, $3) RETURNING *; +-- name: GetUserAuthByLinkedID :one +SELECT + * +FROM + user_auth +WHERE + linked_id = $1; + diff --git a/coderd/userauth.go b/coderd/userauth.go index 0ddb3f34d7a21..6336b765b4923 100644 --- a/coderd/userauth.go +++ b/coderd/userauth.go @@ -12,6 +12,7 @@ import ( "github.com/google/go-github/v43/github" "github.com/google/uuid" "golang.org/x/oauth2" + "golang.org/x/xerrors" "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/httpapi" @@ -219,7 +220,10 @@ type OIDCConfig struct { } func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) { - state := httpmw.OAuth2(r) + var ( + ctx = r.Context() + state = httpmw.OAuth2(r) + ) // See the example here: https://github.com/coreos/go-oidc rawIDToken, ok := state.Token.Extra("id_token").(string) @@ -230,7 +234,7 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) { return } - idToken, err := api.OIDCConfig.Verifier.Verify(r.Context(), rawIDToken) + idToken, err := api.OIDCConfig.Verifier.Verify(ctx, rawIDToken) if err != nil { httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{ Message: "Failed to verify OIDC token.", @@ -285,26 +289,38 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) { } } - var user database.User - user, err = api.Database.GetUserByEmailOrUsername(r.Context(), database.GetUserByEmailOrUsernameParams{ - Email: claims.Email, - }) - if errors.Is(err, sql.ErrNoRows) { - if !api.OIDCConfig.AllowSignups { - httpapi.Write(rw, http.StatusForbidden, codersdk.Response{ - Message: "Signups are disabled for OIDC authentication!", - }) - return + api.Database.InTx( + func(store database.Store) error { + } + ) + + user, found, err := findLinkedUser(ctx, api.Database, database.LoginTypeOIDC, uniqueUserOIDC(idToken), claims.Email) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to find user.", + Detail: err.Error(), + }) + return + } + + if !found && !api.OIDCConfig.AllowSignups { + httpapi.Write(rw, http.StatusForbidden, codersdk.Response{ + Message: "Signups are disabled for OIDC authentication!", + }) + return + } + + if !found { var organizationID uuid.UUID - organizations, _ := api.Database.GetOrganizations(r.Context()) + organizations, _ := api.Database.GetOrganizations(ctx) if len(organizations) > 0 { // Add the user to the first organization. Once multi-organization // support is added, we should enable a configuration map of user // email to organization. organizationID = organizations[0].ID } - user, _, err = api.createUser(r.Context(), codersdk.CreateUserRequest{ + user, _, err = api.createUser(ctx, codersdk.CreateUserRequest{ Email: claims.Email, Username: claims.Username, OrganizationID: organizationID, @@ -316,15 +332,22 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) { }) return } - } - if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Failed to get user by email.", - Detail: err.Error(), + _, err = api.Database.InsertUserAuth(ctx, database.InsertUserAuthParams{ + UserID: user.ID, + LoginType: database.LoginTypeOIDC, + LinkedID: uniqueUserOIDC(idToken), }) - return + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to insert user auth metadata.", + Detail: err.Error(), + }) + return + } } + if user.Email != claims.Email || user.Username != claims.Username { + } _, created := api.createAPIKey(rw, r, database.InsertAPIKeyParams{ UserID: user.ID, LoginType: database.LoginTypeOIDC, @@ -342,3 +365,57 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) { } http.Redirect(rw, r, redirect, http.StatusTemporaryRedirect) } + +func uniqueUserOIDC(tok *oidc.IDToken) string { + return strings.Join([]string{tok.Issuer, tok.Subject}, "||") +} + +func findLinkedUser(ctx context.Context, db database.Store, authType database.LoginType, linkedID string, email string) (database.User, bool, error) { + var user database.User + + uauth, err := db.GetUserAuthByLinkedID(ctx, linkedID) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return user, false, xerrors.Errorf("get user auth by linked ID: %w", err) + } + + if err == nil { + user, err := db.GetUserByID(ctx, uauth.UserID) + if err != nil { + return user, false, xerrors.Errorf("get user by ID: %w", err) + } + return user, true, nil + } + + user, err = db.GetUserByEmailOrUsername(ctx, database.GetUserByEmailOrUsernameParams{ + Email: email, + }) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return user, false, xerrors.Errorf("get user by email: %w", err) + } + if errors.Is(err, sql.ErrNoRows) { + return user, false, nil + } + + // Try getting the UAuth by user ID instead now. Maybe the user + // logged in using a different login type. + uauth, err = db.GetUserAuthByUserID(ctx, user.ID) + if err != nil && errors.Is(err, sql.ErrNoRows) { + return user, false, xerrors.Errorf("get user auth by user ID: %w", err) + } + if uauth.LoginType != authType { + return user, false, xerrors.Errorf("cannot login with %q with account is already linked with %q", authType, uauth.LoginType) + } + if err == nil { + return user, false, xerrors.Errorf("user auth already exists with different linked ID? Expecting %q but got %q", linkedID, uauth.LinkedID) + } + + _, err = db.InsertUserAuth(ctx, database.InsertUserAuthParams{ + UserID: user.ID, + LoginType: authType, + LinkedID: linkedID, + }) + if err != nil { + return user, false, xerrors.Errorf("insert user auth: %w", err) + } + return user, true, nil +} From b4ab3018b22886f6ba19ce5fd94afb5857141530 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Tue, 9 Aug 2022 22:42:01 +0000 Subject: [PATCH 02/32] fixup a bunch of stuff --- coderd/database/databasefake/databasefake.go | 2 + coderd/database/db_test.go | 2 + coderd/database/dump.sql | 4 +- .../migrations/000034_linked_user_id.up.sql | 3 +- coderd/database/models.go | 8 +- coderd/database/querier.go | 5 +- coderd/database/queries.sql.go | 155 +++++++----- coderd/database/queries/user_auth.sql | 24 -- coderd/database/queries/users.sql | 22 +- coderd/httpmw/apikey_test.go | 2 + coderd/httpmw/authorize_test.go | 2 + coderd/httpmw/organizationparam_test.go | 2 + coderd/httpmw/templateparam_test.go | 2 + coderd/httpmw/templateversionparam_test.go | 2 + coderd/httpmw/userparam_test.go | 8 +- coderd/httpmw/workspaceagentparam_test.go | 2 + coderd/httpmw/workspacebuildparam_test.go | 2 + coderd/httpmw/workspaceparam_test.go | 2 + coderd/provisionerjobs_internal_test.go | 2 + coderd/telemetry/telemetry_test.go | 4 + coderd/userauth.go | 230 +++++++++++------- coderd/users.go | 28 ++- 22 files changed, 313 insertions(+), 200 deletions(-) delete mode 100644 coderd/database/queries/user_auth.sql diff --git a/coderd/database/databasefake/databasefake.go b/coderd/database/databasefake/databasefake.go index 6623d338af7c4..1013a79d92d4f 100644 --- a/coderd/database/databasefake/databasefake.go +++ b/coderd/database/databasefake/databasefake.go @@ -1743,6 +1743,8 @@ func (q *fakeQuerier) InsertUser(_ context.Context, arg database.InsertUserParam Username: arg.Username, Status: database.UserStatusActive, RBACRoles: arg.RBACRoles, + LoginType: arg.LoginType, + LinkedID: arg.LinkedID, } q.users = append(q.users, user) return user, nil diff --git a/coderd/database/db_test.go b/coderd/database/db_test.go index 324e048e9156c..1fbdc4f34c2da 100644 --- a/coderd/database/db_test.go +++ b/coderd/database/db_test.go @@ -37,6 +37,8 @@ func TestNestedInTx(t *testing.T) { CreatedAt: database.Now(), UpdatedAt: database.Now(), RBACRoles: []string{}, + LoginType: database.LoginTypePassword, + LinkedID: uuid.NewString(), }) return err }) diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 87853f23fe16e..f9e640e2f840b 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -275,7 +275,9 @@ CREATE TABLE users ( created_at timestamp with time zone NOT NULL, updated_at timestamp with time zone NOT NULL, status user_status DEFAULT 'active'::public.user_status NOT NULL, - rbac_roles text[] DEFAULT '{}'::text[] NOT NULL + rbac_roles text[] DEFAULT '{}'::text[] NOT NULL, + login_type login_type DEFAULT 'password'::public.login_type NOT NULL, + linked_id text DEFAULT ''::text NOT NULL ); CREATE TABLE workspace_agents ( diff --git a/coderd/database/migrations/000034_linked_user_id.up.sql b/coderd/database/migrations/000034_linked_user_id.up.sql index 51b90a807de3f..003bb389f1393 100644 --- a/coderd/database/migrations/000034_linked_user_id.up.sql +++ b/coderd/database/migrations/000034_linked_user_id.up.sql @@ -15,5 +15,6 @@ SET api_keys.user_id = users.id ORDER BY updated_at DESC LIMIT 1 - ) + ); + COMMIT; diff --git a/coderd/database/models.go b/coderd/database/models.go index a4df558504761..38162027a2ec1 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -491,12 +491,8 @@ type User struct { UpdatedAt time.Time `db:"updated_at" json:"updated_at"` Status UserStatus `db:"status" json:"status"` RBACRoles []string `db:"rbac_roles" json:"rbac_roles"` -} - -type UserAuth struct { - UserID uuid.UUID `db:"user_id" json:"user_id"` - LoginType LoginType `db:"login_type" json:"login_type"` - LinkedID string `db:"linked_id" json:"linked_id"` + LoginType LoginType `db:"login_type" json:"login_type"` + LinkedID string `db:"linked_id" json:"linked_id"` } type Workspace struct { diff --git a/coderd/database/querier.go b/coderd/database/querier.go index f37424e58df78..b418b34dd3e3f 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -61,10 +61,9 @@ type querier interface { GetTemplateVersionsCreatedAfter(ctx context.Context, createdAt time.Time) ([]TemplateVersion, error) GetTemplates(ctx context.Context) ([]Template, error) GetTemplatesWithFilter(ctx context.Context, arg GetTemplatesWithFilterParams) ([]Template, error) - GetUserAuthByLinkedID(ctx context.Context, linkedID string) (UserAuth, error) - GetUserAuthByUserID(ctx context.Context, userID uuid.UUID) (UserAuth, error) GetUserByEmailOrUsername(ctx context.Context, arg GetUserByEmailOrUsernameParams) (User, error) GetUserByID(ctx context.Context, id uuid.UUID) (User, error) + GetUserByLinkedID(ctx context.Context, linkedID string) (User, error) GetUserCount(ctx context.Context) (int64, error) GetUsers(ctx context.Context, arg GetUsersParams) ([]User, error) GetUsersByIDs(ctx context.Context, ids []uuid.UUID) ([]User, error) @@ -108,7 +107,6 @@ type querier interface { InsertTemplate(ctx context.Context, arg InsertTemplateParams) (Template, error) InsertTemplateVersion(ctx context.Context, arg InsertTemplateVersionParams) (TemplateVersion, error) InsertUser(ctx context.Context, arg InsertUserParams) (User, error) - InsertUserAuth(ctx context.Context, arg InsertUserAuthParams) (UserAuth, error) InsertWorkspace(ctx context.Context, arg InsertWorkspaceParams) (Workspace, error) InsertWorkspaceAgent(ctx context.Context, arg InsertWorkspaceAgentParams) (WorkspaceAgent, error) InsertWorkspaceApp(ctx context.Context, arg InsertWorkspaceAppParams) (WorkspaceApp, error) @@ -130,6 +128,7 @@ type querier interface { UpdateTemplateVersionByID(ctx context.Context, arg UpdateTemplateVersionByIDParams) error UpdateTemplateVersionDescriptionByJobID(ctx context.Context, arg UpdateTemplateVersionDescriptionByJobIDParams) error UpdateUserHashedPassword(ctx context.Context, arg UpdateUserHashedPasswordParams) error + UpdateUserLinkedID(ctx context.Context, arg UpdateUserLinkedIDParams) (User, error) UpdateUserProfile(ctx context.Context, arg UpdateUserProfileParams) (User, error) UpdateUserRoles(ctx context.Context, arg UpdateUserRolesParams) (User, error) UpdateUserStatus(ctx context.Context, arg UpdateUserStatusParams) (User, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 3fde41e01a8c0..268dd631436b6 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -2444,62 +2444,6 @@ func (q *sqlQuerier) UpdateTemplateVersionDescriptionByJobID(ctx context.Context return err } -const getUserAuthByLinkedID = `-- name: GetUserAuthByLinkedID :one -SELECT - user_id, login_type, linked_id -FROM - user_auth -WHERE - linked_id = $1 -` - -func (q *sqlQuerier) GetUserAuthByLinkedID(ctx context.Context, linkedID string) (UserAuth, error) { - row := q.db.QueryRowContext(ctx, getUserAuthByLinkedID, linkedID) - var i UserAuth - err := row.Scan(&i.UserID, &i.LoginType, &i.LinkedID) - return i, err -} - -const getUserAuthByUserID = `-- name: GetUserAuthByUserID :one -SELECT - user_id, login_type, linked_id -FROM - user_auth -WHERE - user_id = $1 -` - -func (q *sqlQuerier) GetUserAuthByUserID(ctx context.Context, userID uuid.UUID) (UserAuth, error) { - row := q.db.QueryRowContext(ctx, getUserAuthByUserID, userID) - var i UserAuth - err := row.Scan(&i.UserID, &i.LoginType, &i.LinkedID) - return i, err -} - -const insertUserAuth = `-- name: InsertUserAuth :one -INSERT INTO - user_auth ( - user_id, - login_type, - linked_id - ) -VALUES - ( $1, $2, $3) RETURNING user_id, login_type, linked_id -` - -type InsertUserAuthParams struct { - UserID uuid.UUID `db:"user_id" json:"user_id"` - LoginType LoginType `db:"login_type" json:"login_type"` - LinkedID string `db:"linked_id" json:"linked_id"` -} - -func (q *sqlQuerier) InsertUserAuth(ctx context.Context, arg InsertUserAuthParams) (UserAuth, error) { - row := q.db.QueryRowContext(ctx, insertUserAuth, arg.UserID, arg.LoginType, arg.LinkedID) - var i UserAuth - err := row.Scan(&i.UserID, &i.LoginType, &i.LinkedID) - return i, err -} - const getAuthorizationUserRoles = `-- name: GetAuthorizationUserRoles :one SELECT -- username is returned just to help for logging purposes @@ -2543,7 +2487,7 @@ func (q *sqlQuerier) GetAuthorizationUserRoles(ctx context.Context, userID uuid. const getUserByEmailOrUsername = `-- name: GetUserByEmailOrUsername :one SELECT - id, email, username, hashed_password, created_at, updated_at, status, rbac_roles + id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type, linked_id FROM users WHERE @@ -2570,13 +2514,15 @@ func (q *sqlQuerier) GetUserByEmailOrUsername(ctx context.Context, arg GetUserBy &i.UpdatedAt, &i.Status, pq.Array(&i.RBACRoles), + &i.LoginType, + &i.LinkedID, ) return i, err } const getUserByID = `-- name: GetUserByID :one SELECT - id, email, username, hashed_password, created_at, updated_at, status, rbac_roles + id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type, linked_id FROM users WHERE @@ -2597,6 +2543,35 @@ func (q *sqlQuerier) GetUserByID(ctx context.Context, id uuid.UUID) (User, error &i.UpdatedAt, &i.Status, pq.Array(&i.RBACRoles), + &i.LoginType, + &i.LinkedID, + ) + return i, err +} + +const getUserByLinkedID = `-- name: GetUserByLinkedID :one +SELECT + id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type, linked_id +FROM + users +WHERE + linked_id = $1 +` + +func (q *sqlQuerier) GetUserByLinkedID(ctx context.Context, linkedID string) (User, error) { + row := q.db.QueryRowContext(ctx, getUserByLinkedID, linkedID) + var i User + err := row.Scan( + &i.ID, + &i.Email, + &i.Username, + &i.HashedPassword, + &i.CreatedAt, + &i.UpdatedAt, + &i.Status, + pq.Array(&i.RBACRoles), + &i.LoginType, + &i.LinkedID, ) return i, err } @@ -2617,7 +2592,7 @@ func (q *sqlQuerier) GetUserCount(ctx context.Context) (int64, error) { const getUsers = `-- name: GetUsers :many SELECT - id, email, username, hashed_password, created_at, updated_at, status, rbac_roles + id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type, linked_id FROM users WHERE @@ -2709,6 +2684,8 @@ func (q *sqlQuerier) GetUsers(ctx context.Context, arg GetUsersParams) ([]User, &i.UpdatedAt, &i.Status, pq.Array(&i.RBACRoles), + &i.LoginType, + &i.LinkedID, ); err != nil { return nil, err } @@ -2724,7 +2701,7 @@ func (q *sqlQuerier) GetUsers(ctx context.Context, arg GetUsersParams) ([]User, } const getUsersByIDs = `-- name: GetUsersByIDs :many -SELECT id, email, username, hashed_password, created_at, updated_at, status, rbac_roles FROM users WHERE id = ANY($1 :: uuid [ ]) +SELECT id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type, linked_id FROM users WHERE id = ANY($1 :: uuid [ ]) ` func (q *sqlQuerier) GetUsersByIDs(ctx context.Context, ids []uuid.UUID) ([]User, error) { @@ -2745,6 +2722,8 @@ func (q *sqlQuerier) GetUsersByIDs(ctx context.Context, ids []uuid.UUID) ([]User &i.UpdatedAt, &i.Status, pq.Array(&i.RBACRoles), + &i.LoginType, + &i.LinkedID, ); err != nil { return nil, err } @@ -2768,10 +2747,12 @@ INSERT INTO hashed_password, created_at, updated_at, - rbac_roles + rbac_roles, + login_type, + linked_id ) VALUES - ($1, $2, $3, $4, $5, $6, $7) RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles + ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type, linked_id ` type InsertUserParams struct { @@ -2782,6 +2763,8 @@ type InsertUserParams struct { CreatedAt time.Time `db:"created_at" json:"created_at"` UpdatedAt time.Time `db:"updated_at" json:"updated_at"` RBACRoles []string `db:"rbac_roles" json:"rbac_roles"` + LoginType LoginType `db:"login_type" json:"login_type"` + LinkedID string `db:"linked_id" json:"linked_id"` } func (q *sqlQuerier) InsertUser(ctx context.Context, arg InsertUserParams) (User, error) { @@ -2793,6 +2776,8 @@ func (q *sqlQuerier) InsertUser(ctx context.Context, arg InsertUserParams) (User arg.CreatedAt, arg.UpdatedAt, pq.Array(arg.RBACRoles), + arg.LoginType, + arg.LinkedID, ) var i User err := row.Scan( @@ -2804,6 +2789,8 @@ func (q *sqlQuerier) InsertUser(ctx context.Context, arg InsertUserParams) (User &i.UpdatedAt, &i.Status, pq.Array(&i.RBACRoles), + &i.LoginType, + &i.LinkedID, ) return i, err } @@ -2827,6 +2814,38 @@ func (q *sqlQuerier) UpdateUserHashedPassword(ctx context.Context, arg UpdateUse return err } +const updateUserLinkedID = `-- name: UpdateUserLinkedID :one +UPDATE + users +SET + linked_id = $2 +WHERE + id = $1 RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type, linked_id +` + +type UpdateUserLinkedIDParams struct { + ID uuid.UUID `db:"id" json:"id"` + LinkedID string `db:"linked_id" json:"linked_id"` +} + +func (q *sqlQuerier) UpdateUserLinkedID(ctx context.Context, arg UpdateUserLinkedIDParams) (User, error) { + row := q.db.QueryRowContext(ctx, updateUserLinkedID, arg.ID, arg.LinkedID) + var i User + err := row.Scan( + &i.ID, + &i.Email, + &i.Username, + &i.HashedPassword, + &i.CreatedAt, + &i.UpdatedAt, + &i.Status, + pq.Array(&i.RBACRoles), + &i.LoginType, + &i.LinkedID, + ) + return i, err +} + const updateUserProfile = `-- name: UpdateUserProfile :one UPDATE users @@ -2835,7 +2854,7 @@ SET username = $3, updated_at = $4 WHERE - id = $1 RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles + id = $1 RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type, linked_id ` type UpdateUserProfileParams struct { @@ -2862,6 +2881,8 @@ func (q *sqlQuerier) UpdateUserProfile(ctx context.Context, arg UpdateUserProfil &i.UpdatedAt, &i.Status, pq.Array(&i.RBACRoles), + &i.LoginType, + &i.LinkedID, ) return i, err } @@ -2874,7 +2895,7 @@ SET rbac_roles = ARRAY(SELECT DISTINCT UNNEST($1 :: text[])) WHERE id = $2 -RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles +RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type, linked_id ` type UpdateUserRolesParams struct { @@ -2894,6 +2915,8 @@ func (q *sqlQuerier) UpdateUserRoles(ctx context.Context, arg UpdateUserRolesPar &i.UpdatedAt, &i.Status, pq.Array(&i.RBACRoles), + &i.LoginType, + &i.LinkedID, ) return i, err } @@ -2905,7 +2928,7 @@ SET status = $2, updated_at = $3 WHERE - id = $1 RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles + id = $1 RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type, linked_id ` type UpdateUserStatusParams struct { @@ -2926,6 +2949,8 @@ func (q *sqlQuerier) UpdateUserStatus(ctx context.Context, arg UpdateUserStatusP &i.UpdatedAt, &i.Status, pq.Array(&i.RBACRoles), + &i.LoginType, + &i.LinkedID, ) return i, err } diff --git a/coderd/database/queries/user_auth.sql b/coderd/database/queries/user_auth.sql deleted file mode 100644 index 9228c96ff6116..0000000000000 --- a/coderd/database/queries/user_auth.sql +++ /dev/null @@ -1,24 +0,0 @@ --- name: GetUserAuthByUserID :one -SELECT - * -FROM - user_auth -WHERE - user_id = $1; --- name: InsertUserAuth :one -INSERT INTO - user_auth ( - user_id, - login_type, - linked_id - ) -VALUES - ( $1, $2, $3) RETURNING *; --- name: GetUserAuthByLinkedID :one -SELECT - * -FROM - user_auth -WHERE - linked_id = $1; - diff --git a/coderd/database/queries/users.sql b/coderd/database/queries/users.sql index 19fe8a7701744..d7edfb8c00c47 100644 --- a/coderd/database/queries/users.sql +++ b/coderd/database/queries/users.sql @@ -37,10 +37,12 @@ INSERT INTO hashed_password, created_at, updated_at, - rbac_roles + rbac_roles, + login_type, + linked_id ) VALUES - ($1, $2, $3, $4, $5, $6, $7) RETURNING *; + ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING *; -- name: UpdateUserProfile :one UPDATE @@ -159,3 +161,19 @@ LEFT JOIN organization_members ON id = user_id WHERE id = @user_id; + +-- name: GetUserByLinkedID :one +SELECT + * +FROM + users +WHERE + linked_id = $1; + +-- name: UpdateUserLinkedID :one +UPDATE + users +SET + linked_id = $2 +WHERE + id = $1 RETURNING *; diff --git a/coderd/httpmw/apikey_test.go b/coderd/httpmw/apikey_test.go index 0d29c84653c6f..58456aafd59bf 100644 --- a/coderd/httpmw/apikey_test.go +++ b/coderd/httpmw/apikey_test.go @@ -466,6 +466,8 @@ func createUser(ctx context.Context, t *testing.T, db database.Store) database.U CreatedAt: time.Now(), UpdatedAt: time.Now(), RBACRoles: []string{}, + LinkedID: uuid.NewString(), + LoginType: database.LoginTypePassword, }) require.NoError(t, err, "create user") return user diff --git a/coderd/httpmw/authorize_test.go b/coderd/httpmw/authorize_test.go index 997ac44350340..8ce23ca4ed946 100644 --- a/coderd/httpmw/authorize_test.go +++ b/coderd/httpmw/authorize_test.go @@ -115,6 +115,8 @@ func addUser(t *testing.T, db database.Store, roles ...string) (database.User, s Email: "admin@email.com", Username: "admin", RBACRoles: roles, + LinkedID: uuid.NewString(), + LoginType: database.LoginTypePassword, }) require.NoError(t, err) diff --git a/coderd/httpmw/organizationparam_test.go b/coderd/httpmw/organizationparam_test.go index d17c441741914..7562eb1e0f896 100644 --- a/coderd/httpmw/organizationparam_test.go +++ b/coderd/httpmw/organizationparam_test.go @@ -45,6 +45,8 @@ func TestOrganizationParam(t *testing.T) { Username: username, CreatedAt: database.Now(), UpdatedAt: database.Now(), + LinkedID: uuid.NewString(), + LoginType: database.LoginTypePassword, }) require.NoError(t, err) _, err = db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ diff --git a/coderd/httpmw/templateparam_test.go b/coderd/httpmw/templateparam_test.go index 94abfe82cf5fb..f3f574659e69a 100644 --- a/coderd/httpmw/templateparam_test.go +++ b/coderd/httpmw/templateparam_test.go @@ -44,6 +44,8 @@ func TestTemplateParam(t *testing.T) { Username: username, CreatedAt: database.Now(), UpdatedAt: database.Now(), + LinkedID: uuid.NewString(), + LoginType: database.LoginTypePassword, }) require.NoError(t, err) diff --git a/coderd/httpmw/templateversionparam_test.go b/coderd/httpmw/templateversionparam_test.go index 5b49f75010bf9..c638be5172b31 100644 --- a/coderd/httpmw/templateversionparam_test.go +++ b/coderd/httpmw/templateversionparam_test.go @@ -44,6 +44,8 @@ func TestTemplateVersionParam(t *testing.T) { Username: username, CreatedAt: database.Now(), UpdatedAt: database.Now(), + LinkedID: uuid.NewString(), + LoginType: database.LoginTypePassword, }) require.NoError(t, err) diff --git a/coderd/httpmw/userparam_test.go b/coderd/httpmw/userparam_test.go index 866df68ef1eec..3ebedd400a5ee 100644 --- a/coderd/httpmw/userparam_test.go +++ b/coderd/httpmw/userparam_test.go @@ -35,9 +35,11 @@ func TestUserParam(t *testing.T) { }) user, err := db.InsertUser(r.Context(), database.InsertUserParams{ - ID: uuid.New(), - Email: "admin@email.com", - Username: "admin", + ID: uuid.New(), + Email: "admin@email.com", + Username: "admin", + LinkedID: uuid.NewString(), + LoginType: database.LoginTypePassword, }) require.NoError(t, err) diff --git a/coderd/httpmw/workspaceagentparam_test.go b/coderd/httpmw/workspaceagentparam_test.go index a2afaee534c9f..1e19e67c46fc2 100644 --- a/coderd/httpmw/workspaceagentparam_test.go +++ b/coderd/httpmw/workspaceagentparam_test.go @@ -44,6 +44,8 @@ func TestWorkspaceAgentParam(t *testing.T) { Username: username, CreatedAt: database.Now(), UpdatedAt: database.Now(), + LinkedID: uuid.NewString(), + LoginType: database.LoginTypePassword, }) require.NoError(t, err) diff --git a/coderd/httpmw/workspacebuildparam_test.go b/coderd/httpmw/workspacebuildparam_test.go index 6d402f01fc62b..4fe606533bd0f 100644 --- a/coderd/httpmw/workspacebuildparam_test.go +++ b/coderd/httpmw/workspacebuildparam_test.go @@ -44,6 +44,8 @@ func TestWorkspaceBuildParam(t *testing.T) { Username: username, CreatedAt: database.Now(), UpdatedAt: database.Now(), + LinkedID: uuid.NewString(), + LoginType: database.LoginTypePassword, }) require.NoError(t, err) diff --git a/coderd/httpmw/workspaceparam_test.go b/coderd/httpmw/workspaceparam_test.go index eac847a584f3b..633275a645247 100644 --- a/coderd/httpmw/workspaceparam_test.go +++ b/coderd/httpmw/workspaceparam_test.go @@ -44,6 +44,8 @@ func TestWorkspaceParam(t *testing.T) { Username: username, CreatedAt: database.Now(), UpdatedAt: database.Now(), + LinkedID: uuid.NewString(), + LoginType: database.LoginTypePassword, }) require.NoError(t, err) diff --git a/coderd/provisionerjobs_internal_test.go b/coderd/provisionerjobs_internal_test.go index 4d215f6bb2a92..bf2d996aba2d0 100644 --- a/coderd/provisionerjobs_internal_test.go +++ b/coderd/provisionerjobs_internal_test.go @@ -78,6 +78,8 @@ func TestProvisionerJobLogs_Unit(t *testing.T) { _, err = fDB.InsertUser(ctx, database.InsertUserParams{ ID: userID, RBACRoles: []string{"admin"}, + LinkedID: uuid.NewString(), + LoginType: database.LoginTypePassword, }) require.NoError(t, err) _, err = fDB.InsertWorkspaceBuild(ctx, database.InsertWorkspaceBuildParams{ diff --git a/coderd/telemetry/telemetry_test.go b/coderd/telemetry/telemetry_test.go index 4e78b19ec8c54..25fe91d348a6b 100644 --- a/coderd/telemetry/telemetry_test.go +++ b/coderd/telemetry/telemetry_test.go @@ -59,6 +59,8 @@ func TestTelemetry(t *testing.T) { _, err = db.InsertUser(ctx, database.InsertUserParams{ ID: uuid.New(), CreatedAt: database.Now(), + LinkedID: uuid.NewString(), + LoginType: database.LoginTypePassword, }) require.NoError(t, err) _, err = db.InsertWorkspace(ctx, database.InsertWorkspaceParams{ @@ -105,6 +107,8 @@ func TestTelemetry(t *testing.T) { ID: uuid.New(), Email: "kyle@coder.com", CreatedAt: database.Now(), + LinkedID: uuid.NewString(), + LoginType: database.LoginTypePassword, }) require.NoError(t, err) snapshot := collectSnapshot(t, db) diff --git a/coderd/userauth.go b/coderd/userauth.go index 6336b765b4923..e88e759d91b1f 100644 --- a/coderd/userauth.go +++ b/coderd/userauth.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "net/http" + "strconv" "strings" "github.com/coreos/go-oidc/v3/oidc" @@ -48,10 +49,13 @@ func (api *API) userAuthMethods(rw http.ResponseWriter, _ *http.Request) { } func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) { - state := httpmw.OAuth2(r) + var ( + ctx = r.Context() + state = httpmw.OAuth2(r) + ) - oauthClient := oauth2.NewClient(r.Context(), oauth2.StaticTokenSource(state.Token)) - memberships, err := api.GithubOAuth2Config.ListOrganizationMemberships(r.Context(), oauthClient) + oauthClient := oauth2.NewClient(ctx, oauth2.StaticTokenSource(state.Token)) + memberships, err := api.GithubOAuth2Config.ListOrganizationMemberships(ctx, oauthClient) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching authenticated Github user organizations.", @@ -76,7 +80,7 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) { return } - ghUser, err := api.GithubOAuth2Config.AuthenticatedUser(r.Context(), oauthClient) + ghUser, err := api.GithubOAuth2Config.AuthenticatedUser(ctx, oauthClient) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching authenticated Github user.", @@ -95,7 +99,7 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) { continue } - allowedTeam, err = api.GithubOAuth2Config.TeamMembership(r.Context(), oauthClient, allowTeam.Organization, allowTeam.Slug, *ghUser.Login) + allowedTeam, err = api.GithubOAuth2Config.TeamMembership(ctx, oauthClient, allowTeam.Organization, allowTeam.Slug, *ghUser.Login) // The calling user may not have permission to the requested team! if err != nil { continue @@ -109,7 +113,7 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) { } } - emails, err := api.GithubOAuth2Config.ListEmails(r.Context(), oauthClient) + emails, err := api.GithubOAuth2Config.ListEmails(ctx, oauthClient) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching personal Github user.", @@ -118,37 +122,38 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) { return } - var user database.User - // Search for existing users with matching and verified emails. - // If a verified GitHub email matches a Coder user, we will return. + verifiedEmails := make([]string, 0, len(emails)) for _, email := range emails { if !email.GetVerified() { continue } - user, err = api.Database.GetUserByEmailOrUsername(r.Context(), database.GetUserByEmailOrUsernameParams{ - Email: *email.Email, + verifiedEmails = append(verifiedEmails, email.GetEmail()) + } + + if len(verifiedEmails) == 0 { + httpapi.Write(rw, http.StatusForbidden, codersdk.Response{ + Message: "Verify an email address on Github to authenticate!", }) - if errors.Is(err, sql.ErrNoRows) { - continue - } - if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ - Message: fmt.Sprintf("Internal error fetching user by email %q.", *email.Email), - Detail: err.Error(), - }) - return - } - if !*email.Verified { - httpapi.Write(rw, http.StatusForbidden, codersdk.Response{ - Message: fmt.Sprintf("Verify the %q email address on Github to authenticate!", *email.Email), - }) - return - } - break + return + } + + user, found, err := findLinkedUser(ctx, api.Database, githubLinkedID(ghUser), verifiedEmails...) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to find user.", + }) + return + } + + if found && user.LoginType != database.LoginTypeGithub { + httpapi.Write(rw, http.StatusConflict, codersdk.Response{ + Message: fmt.Sprintf("Incorrect login type, attempting to use %q but user is of login type %q", database.LoginTypeOIDC, user.LoginType), + }) + return } // If the user doesn't exist, create a new one! - if user.ID == uuid.Nil { + if !found { if !api.GithubOAuth2Config.AllowSignups { httpapi.Write(rw, http.StatusForbidden, codersdk.Response{ Message: "Signups are disabled for Github authentication!", @@ -178,10 +183,14 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) { }) return } - user, _, err = api.createUser(r.Context(), codersdk.CreateUserRequest{ - Email: *verifiedEmail.Email, - Username: *ghUser.Login, - OrganizationID: organizationID, + user, _, err = api.createUser(r.Context(), createUserRequest{ + CreateUserRequest: codersdk.CreateUserRequest{ + Email: *verifiedEmail.Email, + Username: *ghUser.Login, + OrganizationID: organizationID, + }, + LinkedID: githubLinkedID(ghUser), + LoginType: database.LoginTypeGithub, }) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ @@ -192,6 +201,23 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) { } } + // LEGACY: Remove 10/2022. + // We started tracking linked IDs later so it's possible for a user to be a + // pre-existing Github user and not have a linked ID. + if user.LinkedID == "" { + user, err = api.Database.UpdateUserLinkedID(ctx, database.UpdateUserLinkedIDParams{ + ID: user.ID, + LinkedID: githubLinkedID(ghUser), + }) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to set user linked ID.", + Detail: err.Error(), + }) + return + } + } + _, created := api.createAPIKey(rw, r, database.InsertAPIKeyParams{ UserID: user.ID, LoginType: database.LoginTypeGithub, @@ -289,13 +315,7 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) { } } - api.Database.InTx( - func(store database.Store) error { - - } - ) - - user, found, err := findLinkedUser(ctx, api.Database, database.LoginTypeOIDC, uniqueUserOIDC(idToken), claims.Email) + user, found, err := findLinkedUser(ctx, api.Database, oidcLinkedID(idToken), claims.Email) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ Message: "Failed to find user.", @@ -311,6 +331,13 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) { return } + if found && user.LoginType != database.LoginTypeOIDC { + httpapi.Write(rw, http.StatusConflict, codersdk.Response{ + Message: fmt.Sprintf("Incorrect login type, attempting to use %q but user is of login type %q", database.LoginTypeOIDC, user.LoginType), + }) + return + } + if !found { var organizationID uuid.UUID organizations, _ := api.Database.GetOrganizations(ctx) @@ -320,10 +347,14 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) { // email to organization. organizationID = organizations[0].ID } - user, _, err = api.createUser(ctx, codersdk.CreateUserRequest{ - Email: claims.Email, - Username: claims.Username, - OrganizationID: organizationID, + user, _, err = api.createUser(ctx, createUserRequest{ + CreateUserRequest: codersdk.CreateUserRequest{ + Email: claims.Email, + Username: claims.Username, + OrganizationID: organizationID, + }, + LinkedID: oidcLinkedID(idToken), + LoginType: database.LoginTypeOIDC, }) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ @@ -332,11 +363,6 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) { }) return } - _, err = api.Database.InsertUserAuth(ctx, database.InsertUserAuthParams{ - UserID: user.ID, - LoginType: database.LoginTypeOIDC, - LinkedID: uniqueUserOIDC(idToken), - }) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ Message: "Failed to insert user auth metadata.", @@ -345,9 +371,49 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) { return } } - if user.Email != claims.Email || user.Username != claims.Username { + // LEGACY: Remove 10/2022. + // We started tracking linked IDs later so it's possible for a user to be a + // pre-existing OIDC user and not have a linked ID. + if user.LinkedID == "" { + user, err = api.Database.UpdateUserLinkedID(ctx, database.UpdateUserLinkedIDParams{ + ID: user.ID, + LinkedID: oidcLinkedID(idToken), + }) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to update user linked ID.", + Detail: err.Error(), + }) + return + } + } + + // If the upstream email or username has changed we should mirror + // that in Coder. Many enterprises use a user's email/username as + // security auditing fields so they need to stay synced. + if user.Email != claims.Email || user.Username != claims.Username { + // TODO(JonA): Since we're processing updates to a user's upstream + // email/username, it's possible for a different built-in user to + // have already claimed the username. + // In such cases in the current implementation this user can now no + // longer sign in until an administrator finds the offending built-in + // user and changes their username. + user, err = api.Database.UpdateUserProfile(ctx, database.UpdateUserProfileParams{ + ID: user.ID, + Email: claims.Email, + Username: claims.Username, + UpdatedAt: database.Now(), + }) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to update user profile.", + Detail: err.Error(), + }) + return + } } + _, created := api.createAPIKey(rw, r, database.InsertAPIKeyParams{ UserID: user.ID, LoginType: database.LoginTypeOIDC, @@ -366,56 +432,42 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) { http.Redirect(rw, r, redirect, http.StatusTemporaryRedirect) } -func uniqueUserOIDC(tok *oidc.IDToken) string { - return strings.Join([]string{tok.Issuer, tok.Subject}, "||") +// githubLinkedID returns the unique ID for a GitHub user. +func githubLinkedID(u *github.User) string { + return strconv.FormatInt(u.GetID(), 10) } -func findLinkedUser(ctx context.Context, db database.Store, authType database.LoginType, linkedID string, email string) (database.User, bool, error) { - var user database.User +// oidcLinkedID returns the uniqued ID for an OIDC user. +// See https://openid.net/specs/openid-connect-core-1_0.html#ClaimStability. +func oidcLinkedID(tok *oidc.IDToken) string { + return strings.Join([]string{tok.Issuer, tok.Subject}, "||") +} - uauth, err := db.GetUserAuthByLinkedID(ctx, linkedID) +// findLinkedUser tries to find a user by their unique OAuth-linked ID. +// If it doesn't not find it, it returns the user by their email. +func findLinkedUser(ctx context.Context, db database.Store, linkedID string, emails ...string) (database.User, bool, error) { + user, err := db.GetUserByLinkedID(ctx, linkedID) if err != nil && !errors.Is(err, sql.ErrNoRows) { return user, false, xerrors.Errorf("get user auth by linked ID: %w", err) } if err == nil { - user, err := db.GetUserByID(ctx, uauth.UserID) - if err != nil { - return user, false, xerrors.Errorf("get user by ID: %w", err) - } return user, true, nil } - user, err = db.GetUserByEmailOrUsername(ctx, database.GetUserByEmailOrUsernameParams{ - Email: email, - }) - if err != nil && !errors.Is(err, sql.ErrNoRows) { - return user, false, xerrors.Errorf("get user by email: %w", err) - } - if errors.Is(err, sql.ErrNoRows) { - return user, false, nil - } - - // Try getting the UAuth by user ID instead now. Maybe the user - // logged in using a different login type. - uauth, err = db.GetUserAuthByUserID(ctx, user.ID) - if err != nil && errors.Is(err, sql.ErrNoRows) { - return user, false, xerrors.Errorf("get user auth by user ID: %w", err) - } - if uauth.LoginType != authType { - return user, false, xerrors.Errorf("cannot login with %q with account is already linked with %q", authType, uauth.LoginType) - } - if err == nil { - return user, false, xerrors.Errorf("user auth already exists with different linked ID? Expecting %q but got %q", linkedID, uauth.LinkedID) + for _, email := range emails { + user, err = db.GetUserByEmailOrUsername(ctx, database.GetUserByEmailOrUsernameParams{ + Email: email, + }) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return user, false, xerrors.Errorf("get user by email: %w", err) + } + if errors.Is(err, sql.ErrNoRows) { + continue + } + return user, true, nil } - _, err = db.InsertUserAuth(ctx, database.InsertUserAuthParams{ - UserID: user.ID, - LoginType: authType, - LinkedID: linkedID, - }) - if err != nil { - return user, false, xerrors.Errorf("insert user auth: %w", err) - } - return user, true, nil + // No user found. + return user, false, nil } diff --git a/coderd/users.go b/coderd/users.go index 51509b0cc800e..f5622f54e6c9f 100644 --- a/coderd/users.go +++ b/coderd/users.go @@ -77,10 +77,14 @@ func (api *API) postFirstUser(rw http.ResponseWriter, r *http.Request) { return } - user, organizationID, err := api.createUser(r.Context(), codersdk.CreateUserRequest{ - Email: createUser.Email, - Username: createUser.Username, - Password: createUser.Password, + user, organizationID, err := api.createUser(r.Context(), createUserRequest{ + CreateUserRequest: codersdk.CreateUserRequest{ + Email: createUser.Email, + Username: createUser.Username, + Password: createUser.Password, + }, + LoginType: database.LoginTypePassword, + LinkedID: createUser.Email, }) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ @@ -235,7 +239,11 @@ func (api *API) postUser(rw http.ResponseWriter, r *http.Request) { return } - user, _, err := api.createUser(r.Context(), createUser) + user, _, err := api.createUser(r.Context(), createUserRequest{ + CreateUserRequest: createUser, + LinkedID: createUser.Email, + LoginType: database.LoginTypePassword, + }) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error creating user.", @@ -869,7 +877,13 @@ func (api *API) createAPIKey(rw http.ResponseWriter, r *http.Request, params dat return sessionToken, true } -func (api *API) createUser(ctx context.Context, req codersdk.CreateUserRequest) (database.User, uuid.UUID, error) { +type createUserRequest struct { + codersdk.CreateUserRequest + LoginType database.LoginType + LinkedID string +} + +func (api *API) createUser(ctx context.Context, req createUserRequest) (database.User, uuid.UUID, error) { var user database.User return user, req.OrganizationID, api.Database.InTx(func(db database.Store) error { orgRoles := make([]string, 0) @@ -896,6 +910,8 @@ func (api *API) createUser(ctx context.Context, req codersdk.CreateUserRequest) UpdatedAt: database.Now(), // All new users are defaulted to members of the site. RBACRoles: []string{}, + LoginType: req.LoginType, + LinkedID: req.LinkedID, } // If a user signs up with OAuth, they can have no password! if req.Password != "" { From dd2df9c933ef5cc8a39223fc5d2919fc9eb6e7e1 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Tue, 9 Aug 2022 22:43:44 +0000 Subject: [PATCH 03/32] gofmt --- coderd/database/generate.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/coderd/database/generate.sh b/coderd/database/generate.sh index 60f9cd2e226fb..326fa096b90d1 100755 --- a/coderd/database/generate.sh +++ b/coderd/database/generate.sh @@ -13,8 +13,8 @@ SCRIPT_DIR=$(dirname "${BASH_SOURCE[0]}") ( cd "$SCRIPT_DIR" - # Dump the updated schema. - go run dump/main.go + # Dump the updated schema. + go run dump/main.go # The logic below depends on the exact version being correct :( go run github.com/kyleconroy/sqlc/cmd/sqlc@v1.13.0 generate From 0356f469e0138ebc1cc68ed53edb1b964a49d01b Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Tue, 9 Aug 2022 22:54:04 +0000 Subject: [PATCH 04/32] make fake db happy --- coderd/database/databasefake/databasefake.go | 27 ++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/coderd/database/databasefake/databasefake.go b/coderd/database/databasefake/databasefake.go index 1013a79d92d4f..978f8ca4dc926 100644 --- a/coderd/database/databasefake/databasefake.go +++ b/coderd/database/databasefake/databasefake.go @@ -2261,3 +2261,30 @@ func (q *fakeQuerier) GetDeploymentID(_ context.Context) (string, error) { return q.deploymentID, nil } + +func (q *fakeQuerier) UpdateUserLinkedID(_ context.Context, arg database.UpdateUserLinkedIDParams) (database.User, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + + for index, user := range q.users { + if user.ID != arg.ID { + continue + } + user.LinkedID = arg.LinkedID + q.users[index] = user + return user, nil + } + return database.User{}, sql.ErrNoRows +} + +func (q *fakeQuerier) GetUserByLinkedID(_ context.Context, linkedID string) (database.User, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + for _, user := range q.users { + if user.LinkedID == linkedID { + return user, nil + } + } + return database.User{}, sql.ErrNoRows +} From 6b1b9007932bb81c9c3b7446fcad0b9951e43f77 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Tue, 9 Aug 2022 23:05:16 +0000 Subject: [PATCH 05/32] make audit happy --- coderd/audit/diff_test.go | 4 ++++ coderd/audit/table.go | 2 ++ 2 files changed, 6 insertions(+) diff --git a/coderd/audit/diff_test.go b/coderd/audit/diff_test.go index fc9c41b7cb16d..472e7cfad8cd3 100644 --- a/coderd/audit/diff_test.go +++ b/coderd/audit/diff_test.go @@ -158,6 +158,8 @@ func TestDiff(t *testing.T) { UpdatedAt: time.Now(), Status: database.UserStatusActive, RBACRoles: []string{"omega admin"}, + LoginType: database.LoginTypePassword, + LinkedID: "foobar", }, exp: audit.Map{ "id": uuid.UUID{1}.String(), @@ -166,6 +168,8 @@ func TestDiff(t *testing.T) { "hashed_password": ([]byte)(nil), "status": database.UserStatusActive, "rbac_roles": []string{"omega admin"}, + "login_type": database.LoginTypePassword, + "linked_id": "foobar", }, }, }) diff --git a/coderd/audit/table.go b/coderd/audit/table.go index c842956e6cf24..8d3d47ca62060 100644 --- a/coderd/audit/table.go +++ b/coderd/audit/table.go @@ -94,6 +94,8 @@ var AuditableResources = auditMap(map[any]map[string]Action{ "updated_at": ActionIgnore, // Changes, but is implicit and not helpful in a diff. "status": ActionTrack, "rbac_roles": ActionTrack, + "login_type": ActionTrack, + "linked_id": ActionTrack, }, &database.Workspace{}: { "id": ActionTrack, From 8f63d5cb9de73bccc7eb8a82aad3788bd3166015 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Wed, 10 Aug 2022 00:27:05 +0000 Subject: [PATCH 06/32] fix some tests --- coderd/userauth_test.go | 58 ++++++++++++++++++++--------------------- coderd/users.go | 2 ++ codersdk/users.go | 3 +++ 3 files changed, 34 insertions(+), 29 deletions(-) diff --git a/coderd/userauth_test.go b/coderd/userauth_test.go index 6d4c6af34bd30..3a0818a62b5ae 100644 --- a/coderd/userauth_test.go +++ b/coderd/userauth_test.go @@ -206,34 +206,6 @@ func TestUserOAuth2Github(t *testing.T) { resp := oauth2Callback(t, client) require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) }) - t.Run("Login", func(t *testing.T) { - t.Parallel() - client := coderdtest.New(t, &coderdtest.Options{ - GithubOAuth2Config: &coderd.GithubOAuth2Config{ - OAuth2Config: &oauth2Config{}, - AllowOrganizations: []string{"coder"}, - ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) { - return []*github.Membership{{ - Organization: &github.Organization{ - Login: github.String("coder"), - }, - }}, nil - }, - AuthenticatedUser: func(ctx context.Context, client *http.Client) (*github.User, error) { - return &github.User{}, nil - }, - ListEmails: func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error) { - return []*github.UserEmail{{ - Email: github.String("testuser@coder.com"), - Verified: github.Bool(true), - }}, nil - }, - }, - }) - _ = coderdtest.CreateFirstUser(t, client) - resp := oauth2Callback(t, client) - require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) - }) t.Run("SignupAllowedTeam", func(t *testing.T) { t.Parallel() client := coderdtest.New(t, &coderdtest.Options{ @@ -361,6 +333,7 @@ func TestUserOIDC(t *testing.T) { user, err := client.User(ctx, "me") require.NoError(t, err) require.Equal(t, tc.Username, user.Username) + require.Equal(t, "https://coder.com||hello", user.LinkedID) } }) } @@ -404,6 +377,27 @@ func TestUserOIDC(t *testing.T) { resp := oidcCallback(t, client) require.Equal(t, http.StatusBadRequest, resp.StatusCode) }) + + // Test that we do not allow collisions with pre-existing accounts + // of differing login types. + t.Run("InvalidLoginType", func(t *testing.T) { + t.Parallel() + config := createOIDCConfig(t, jwt.MapClaims{ + "email": "kyle@kwc.io", + "email_verified": true, + "preferred_username": "kyle", + }) + + client := coderdtest.New(t, &coderdtest.Options{ + OIDCConfig: config, + }) + + config.AllowSignups = true + config.EmailDomain = "kwc.io" + + resp := oidcCallback(t, client) + assert.Equal(t, http.StatusConflict, resp.StatusCode) + }) } // createOIDCConfig generates a new OIDCConfig that returns a static token @@ -415,11 +409,13 @@ func createOIDCConfig(t *testing.T, claims jwt.MapClaims) *coderd.OIDCConfig { // https://datatracker.ietf.org/doc/html/rfc7519#section-4.1 claims["exp"] = time.Now().Add(time.Hour).UnixMilli() + claims["iss"] = "https://coder.com" + claims["sub"] = "hello" signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(key) require.NoError(t, err) - verifier := oidc.NewVerifier("", &oidc.StaticKeySet{ + verifier := oidc.NewVerifier("https://coder.com", &oidc.StaticKeySet{ PublicKeys: []crypto.PublicKey{key.Public()}, }, &oidc.Config{ SkipClientIDCheck: true, @@ -480,3 +476,7 @@ func oidcCallback(t *testing.T, client *codersdk.Client) *http.Response { t.Log(string(data)) return res } + +func i64ptr(i int64) *int64 { + return &i +} diff --git a/coderd/users.go b/coderd/users.go index f5622f54e6c9f..4c67838be01f2 100644 --- a/coderd/users.go +++ b/coderd/users.go @@ -966,6 +966,8 @@ func convertUser(user database.User, organizationIDs []uuid.UUID) codersdk.User Status: codersdk.UserStatus(user.Status), OrganizationIDs: organizationIDs, Roles: make([]codersdk.Role, 0), + LoginType: codersdk.LoginType(user.LoginType), + LinkedID: user.LinkedID, } for _, roleName := range user.RBACRoles { diff --git a/codersdk/users.go b/codersdk/users.go index 17252c20405c3..e0d9ead70381f 100644 --- a/codersdk/users.go +++ b/codersdk/users.go @@ -27,6 +27,7 @@ type LoginType string const ( LoginTypePassword LoginType = "password" LoginTypeGithub LoginType = "github" + LoginTypeOIDC LoginType = "oidc" ) type UsersRequest struct { @@ -49,6 +50,8 @@ type User struct { Status UserStatus `json:"status"` OrganizationIDs []uuid.UUID `json:"organization_ids"` Roles []Role `json:"roles"` + LoginType LoginType `json:"login_type"` + LinkedID string `json:"linked_id"` } type APIKey struct { From de7db33e7cee9b40ed080508254fff2387131b70 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Wed, 10 Aug 2022 00:31:41 +0000 Subject: [PATCH 07/32] make gen --- coderd/database/queries/users.sql | 10 +++++----- site/src/api/typesGenerated.ts | 4 +++- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/coderd/database/queries/users.sql b/coderd/database/queries/users.sql index d7edfb8c00c47..7d98f1b1cdb7d 100644 --- a/coderd/database/queries/users.sql +++ b/coderd/database/queries/users.sql @@ -38,8 +38,8 @@ INSERT INTO created_at, updated_at, rbac_roles, - login_type, - linked_id + login_type, + linked_id ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING *; @@ -164,11 +164,11 @@ WHERE -- name: GetUserByLinkedID :one SELECT - * + * FROM - users + users WHERE - linked_id = $1; + linked_id = $1; -- name: UpdateUserLinkedID :one UPDATE diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index def9cd07e894a..3adcefb0e45db 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -352,6 +352,8 @@ export interface User { readonly status: UserStatus readonly organization_ids: string[] readonly roles: Role[] + readonly login_type: LoginType + readonly linked_id: string } // From codersdk/users.go @@ -537,7 +539,7 @@ export type LogLevel = "debug" | "error" | "info" | "trace" | "warn" export type LogSource = "provisioner" | "provisioner_daemon" // From codersdk/users.go -export type LoginType = "github" | "password" +export type LoginType = "github" | "oidc" | "password" // From codersdk/parameters.go export type ParameterDestinationScheme = "environment_variable" | "none" | "provisioner_variable" From 5fdf899a9af8e15ab5f40f5869deb9f3118c1805 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Wed, 10 Aug 2022 00:37:28 +0000 Subject: [PATCH 08/32] fix tests --- coderd/userauth_test.go | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/coderd/userauth_test.go b/coderd/userauth_test.go index 3a0818a62b5ae..f3def2084804a 100644 --- a/coderd/userauth_test.go +++ b/coderd/userauth_test.go @@ -192,6 +192,7 @@ func TestUserOAuth2Github(t *testing.T) { AuthenticatedUser: func(ctx context.Context, client *http.Client) (*github.User, error) { return &github.User{ Login: github.String("kyle"), + ID: i64ptr(1234), }, nil }, ListEmails: func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error) { @@ -205,6 +206,13 @@ func TestUserOAuth2Github(t *testing.T) { }) resp := oauth2Callback(t, client) require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) + + client.SessionToken = resp.Cookies()[0].Value + user, err := client.User(context.Background(), "me") + require.NoError(t, err) + require.Equal(t, "1234", user.LinkedID) + require.Equal(t, "kyle@coder.com", user.Email) + require.Equal(t, "kyle", user.Username) }) t.Run("SignupAllowedTeam", func(t *testing.T) { t.Parallel() @@ -392,6 +400,14 @@ func TestUserOIDC(t *testing.T) { OIDCConfig: config, }) + _, err := client.CreateFirstUser(context.Background(), codersdk.CreateFirstUserRequest{ + Email: "kyle@kwc.io", + Username: "kyle", + Password: "yeah", + OrganizationName: "default", + }) + require.NoError(t, err) + config.AllowSignups = true config.EmailDomain = "kwc.io" From 3a4d049fd022798d0bfd99faeda86bd760db4ab2 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Wed, 10 Aug 2022 23:45:40 +0000 Subject: [PATCH 09/32] fmt --- coderd/database/queries.sql.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 268dd631436b6..f64619266087c 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -2551,11 +2551,11 @@ func (q *sqlQuerier) GetUserByID(ctx context.Context, id uuid.UUID) (User, error const getUserByLinkedID = `-- name: GetUserByLinkedID :one SELECT - id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type, linked_id + id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type, linked_id FROM - users + users WHERE - linked_id = $1 + linked_id = $1 ` func (q *sqlQuerier) GetUserByLinkedID(ctx context.Context, linkedID string) (User, error) { @@ -2748,8 +2748,8 @@ INSERT INTO created_at, updated_at, rbac_roles, - login_type, - linked_id + login_type, + linked_id ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type, linked_id From 4108ece22a96b0153db22565dccb7edeb01ce782 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Thu, 11 Aug 2022 21:16:28 +0000 Subject: [PATCH 10/32] begin refactoring PR --- .../database/migrations/000034_linked_user_id.up.sql | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/coderd/database/migrations/000034_linked_user_id.up.sql b/coderd/database/migrations/000034_linked_user_id.up.sql index 003bb389f1393..89974dc347309 100644 --- a/coderd/database/migrations/000034_linked_user_id.up.sql +++ b/coderd/database/migrations/000034_linked_user_id.up.sql @@ -1,6 +1,15 @@ BEGIN; -ALTER TABLE users ADD COLUMN login_type login_type NOT NULL DEFAULT 'password'; +CREATE TABLE IF NOT EXISTS users ( + user_id uuid NOT NULL, + login_type login_type NOT NULL, + linked_id text NOT NULL DEFAULT ''::text NOT NULL, + oauth_access_token text DEFAULT ''::text NOT NULL, + oauth_refresh_token text DEFAULT ''::text NOT NULL, + oauth_id_token text DEFAULT ''::text NOT NULL, + oauth_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL, + UNIQUE(user_id, login_type), +) ALTER TABLE users ADD COLUMN linked_id text NOT NULL DEFAULT ''; UPDATE From 14b5382ce9406049015cc3db0b9800fca1088f11 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Fri, 12 Aug 2022 00:17:01 +0000 Subject: [PATCH 11/32] finish migration --- .../migrations/000034_linked_user_id.down.sql | 3 - .../migrations/000034_linked_user_id.up.sql | 59 ++++++++++++------- 2 files changed, 37 insertions(+), 25 deletions(-) diff --git a/coderd/database/migrations/000034_linked_user_id.down.sql b/coderd/database/migrations/000034_linked_user_id.down.sql index 6e8d37f7e7cf7..c4c4836d78fcb 100644 --- a/coderd/database/migrations/000034_linked_user_id.down.sql +++ b/coderd/database/migrations/000034_linked_user_id.down.sql @@ -1,6 +1,3 @@ BEGIN; -ALTER TABLE users DROP COLUMN linked_id; -ALTER TABLE users DROP COLUMN login_type; - COMMIT; diff --git a/coderd/database/migrations/000034_linked_user_id.up.sql b/coderd/database/migrations/000034_linked_user_id.up.sql index 89974dc347309..3daa2e78cfd97 100644 --- a/coderd/database/migrations/000034_linked_user_id.up.sql +++ b/coderd/database/migrations/000034_linked_user_id.up.sql @@ -1,29 +1,44 @@ BEGIN; -CREATE TABLE IF NOT EXISTS users ( +ALTER TYPE login_type ADD VALUE 'oidc'; + +CREATE TABLE IF NOT EXISTS user_links ( user_id uuid NOT NULL, login_type login_type NOT NULL, - linked_id text NOT NULL DEFAULT ''::text NOT NULL, - oauth_access_token text DEFAULT ''::text NOT NULL, - oauth_refresh_token text DEFAULT ''::text NOT NULL, - oauth_id_token text DEFAULT ''::text NOT NULL, - oauth_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL, - UNIQUE(user_id, login_type), -) -ALTER TABLE users ADD COLUMN linked_id text NOT NULL DEFAULT ''; + linked_id text DEFAULT ''::text NOT NULL, + oauth_access_token text DEFAULT ''::text NOT NULL, + oauth_refresh_token text DEFAULT ''::text NOT NULL, + oauth_id_token text DEFAULT ''::text NOT NULL, + oauth_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL, + UNIQUE(user_id, login_type) +); + +INSERT INTO user_links + ( + user_id, + login_type, + linked_id, + oauth_access_token, + oauth_refresh_token, + oauth_id_token, + oauth_expiry + ) +SELECT + keys.user_id, + keys.login_type, + '', + keys.oauth_access_token, + keys.oauth_refresh_token, + keys.oauth_id_token, + keys.oauth_expiry +FROM + ( + SELECT + row_number() OVER (partition by user_id, login_type ORDER BY updated_at DESC) AS x, + api_keys.* FROM api_keys + ) as keys + WHERE x=1 AND keys.login_type != 'password'; -UPDATE - users -SET - login_type = ( - SELECT - login_type - FROM - api_keys - WHERE - api_keys.user_id = users.id - ORDER BY updated_at DESC - LIMIT 1 - ); +ALTER TABLE api_keys RENAME COLUMN login_type TO _login_type COMMIT; From 85535011dff23615eee5ee19c0d086a801b51342 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Fri, 12 Aug 2022 00:17:51 +0000 Subject: [PATCH 12/32] use main sql.dump --- coderd/database/dump.sql | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index f9e640e2f840b..87853f23fe16e 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -275,9 +275,7 @@ CREATE TABLE users ( created_at timestamp with time zone NOT NULL, updated_at timestamp with time zone NOT NULL, status user_status DEFAULT 'active'::public.user_status NOT NULL, - rbac_roles text[] DEFAULT '{}'::text[] NOT NULL, - login_type login_type DEFAULT 'password'::public.login_type NOT NULL, - linked_id text DEFAULT ''::text NOT NULL + rbac_roles text[] DEFAULT '{}'::text[] NOT NULL ); CREATE TABLE workspace_agents ( From f748d3d0ff2b7dea99a09643aabc8e4888c68771 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Fri, 12 Aug 2022 00:23:18 +0000 Subject: [PATCH 13/32] lift error --- coderd/database/postgres/postgres.go | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/coderd/database/postgres/postgres.go b/coderd/database/postgres/postgres.go index d1ef7b3084197..2d992625885fa 100644 --- a/coderd/database/postgres/postgres.go +++ b/coderd/database/postgres/postgres.go @@ -122,27 +122,29 @@ func Open() (string, func(), error) { return "", nil, xerrors.Errorf("expire resource: %w", err) } - pool.MaxWait = 120 * time.Second + pool.MaxWait = 15 * time.Second + var retryErr error err = pool.Retry(func() error { - db, err := sql.Open("postgres", dbURL) - if err != nil { - return xerrors.Errorf("open postgres: %w", err) + var db *sql.DB + db, retryErr := sql.Open("postgres", dbURL) + if retryErr != nil { + return xerrors.Errorf("open postgres: %w", retryErr) } defer db.Close() - err = db.Ping() - if err != nil { - return xerrors.Errorf("ping postgres: %w", err) + retryErr = db.Ping() + if retryErr != nil { + return xerrors.Errorf("ping postgres: %w", retryErr) } - err = database.MigrateUp(db) - if err != nil { - return xerrors.Errorf("migrate db: %w", err) + retryErr = database.MigrateUp(db) + if retryErr != nil { + return xerrors.Errorf("migrate db: %w", retryErr) } return nil }) if err != nil { - return "", nil, err + return "", nil, retryErr } return dbURL, func() { _ = pool.Purge(resource) From c1b987166da666f47eef56b2346584ccb84b2fee Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Fri, 12 Aug 2022 01:26:05 +0000 Subject: [PATCH 14/32] new migration --- coderd/database/dump.sql | 17 +- .../migrations/000034_linked_user_id.up.sql | 8 +- coderd/database/models.go | 36 ++-- coderd/database/postgres/postgres.go | 1 + coderd/database/querier.go | 2 - coderd/database/queries.sql.go | 171 +++--------------- coderd/database/queries/apikeys.sql | 14 +- coderd/database/queries/users.sql | 22 +-- 8 files changed, 71 insertions(+), 200 deletions(-) diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 87853f23fe16e..bf62f67a10481 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -96,10 +96,6 @@ CREATE TABLE api_keys ( created_at timestamp with time zone NOT NULL, updated_at timestamp with time zone NOT NULL, login_type login_type NOT NULL, - oauth_access_token text DEFAULT ''::text NOT NULL, - oauth_refresh_token text DEFAULT ''::text NOT NULL, - oauth_id_token text DEFAULT ''::text NOT NULL, - oauth_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL, lifetime_seconds bigint DEFAULT 86400 NOT NULL, ip_address inet DEFAULT '0.0.0.0'::inet NOT NULL ); @@ -267,6 +263,16 @@ CREATE TABLE templates ( created_by uuid NOT NULL ); +CREATE TABLE user_links ( + user_id uuid NOT NULL, + login_type login_type NOT NULL, + linked_id text DEFAULT ''::text NOT NULL, + oauth_access_token text DEFAULT ''::text NOT NULL, + oauth_refresh_token text DEFAULT ''::text NOT NULL, + oauth_id_token text DEFAULT ''::text NOT NULL, + oauth_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL +); + CREATE TABLE users ( id uuid NOT NULL, email text NOT NULL, @@ -416,6 +422,9 @@ ALTER TABLE ONLY template_versions ALTER TABLE ONLY templates ADD CONSTRAINT templates_pkey PRIMARY KEY (id); +ALTER TABLE ONLY user_links + ADD CONSTRAINT user_links_user_id_login_type_key UNIQUE (user_id, login_type); + ALTER TABLE ONLY users ADD CONSTRAINT users_pkey PRIMARY KEY (id); diff --git a/coderd/database/migrations/000034_linked_user_id.up.sql b/coderd/database/migrations/000034_linked_user_id.up.sql index 3daa2e78cfd97..8bcddeb692bbd 100644 --- a/coderd/database/migrations/000034_linked_user_id.up.sql +++ b/coderd/database/migrations/000034_linked_user_id.up.sql @@ -1,7 +1,5 @@ BEGIN; -ALTER TYPE login_type ADD VALUE 'oidc'; - CREATE TABLE IF NOT EXISTS user_links ( user_id uuid NOT NULL, login_type login_type NOT NULL, @@ -39,6 +37,10 @@ FROM ) as keys WHERE x=1 AND keys.login_type != 'password'; -ALTER TABLE api_keys RENAME COLUMN login_type TO _login_type +ALTER TABLE api_keys + DROP COLUMN oauth_access_token, + DROP COLUMN oauth_refresh_token, + DROP COLUMN oauth_id_token, + DROP COLUMN oauth_expiry; COMMIT; diff --git a/coderd/database/models.go b/coderd/database/models.go index 38162027a2ec1..bbb98e9856226 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -313,20 +313,16 @@ func (e *WorkspaceTransition) Scan(src interface{}) error { } type APIKey struct { - ID string `db:"id" json:"id"` - HashedSecret []byte `db:"hashed_secret" json:"hashed_secret"` - UserID uuid.UUID `db:"user_id" json:"user_id"` - LastUsed time.Time `db:"last_used" json:"last_used"` - ExpiresAt time.Time `db:"expires_at" json:"expires_at"` - CreatedAt time.Time `db:"created_at" json:"created_at"` - UpdatedAt time.Time `db:"updated_at" json:"updated_at"` - LoginType LoginType `db:"login_type" json:"login_type"` - OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"` - OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"` - OAuthIDToken string `db:"oauth_id_token" json:"oauth_id_token"` - OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"` - LifetimeSeconds int64 `db:"lifetime_seconds" json:"lifetime_seconds"` - IPAddress pqtype.Inet `db:"ip_address" json:"ip_address"` + ID string `db:"id" json:"id"` + HashedSecret []byte `db:"hashed_secret" json:"hashed_secret"` + UserID uuid.UUID `db:"user_id" json:"user_id"` + LastUsed time.Time `db:"last_used" json:"last_used"` + ExpiresAt time.Time `db:"expires_at" json:"expires_at"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + LoginType LoginType `db:"login_type" json:"login_type"` + LifetimeSeconds int64 `db:"lifetime_seconds" json:"lifetime_seconds"` + IPAddress pqtype.Inet `db:"ip_address" json:"ip_address"` } type AuditLog struct { @@ -491,8 +487,16 @@ type User struct { UpdatedAt time.Time `db:"updated_at" json:"updated_at"` Status UserStatus `db:"status" json:"status"` RBACRoles []string `db:"rbac_roles" json:"rbac_roles"` - LoginType LoginType `db:"login_type" json:"login_type"` - LinkedID string `db:"linked_id" json:"linked_id"` +} + +type UserLink struct { + UserID uuid.UUID `db:"user_id" json:"user_id"` + LoginType LoginType `db:"login_type" json:"login_type"` + LinkedID string `db:"linked_id" json:"linked_id"` + OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"` + OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"` + OAuthIDToken string `db:"oauth_id_token" json:"oauth_id_token"` + OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"` } type Workspace struct { diff --git a/coderd/database/postgres/postgres.go b/coderd/database/postgres/postgres.go index 2d992625885fa..f9ee2cd84a2c5 100644 --- a/coderd/database/postgres/postgres.go +++ b/coderd/database/postgres/postgres.go @@ -138,6 +138,7 @@ func Open() (string, func(), error) { } retryErr = database.MigrateUp(db) if retryErr != nil { + fmt.Printf("err: %v\n", retryErr) return xerrors.Errorf("migrate db: %w", retryErr) } diff --git a/coderd/database/querier.go b/coderd/database/querier.go index b418b34dd3e3f..90e9a3a0a1385 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -63,7 +63,6 @@ type querier interface { GetTemplatesWithFilter(ctx context.Context, arg GetTemplatesWithFilterParams) ([]Template, error) GetUserByEmailOrUsername(ctx context.Context, arg GetUserByEmailOrUsernameParams) (User, error) GetUserByID(ctx context.Context, id uuid.UUID) (User, error) - GetUserByLinkedID(ctx context.Context, linkedID string) (User, error) GetUserCount(ctx context.Context) (int64, error) GetUsers(ctx context.Context, arg GetUsersParams) ([]User, error) GetUsersByIDs(ctx context.Context, ids []uuid.UUID) ([]User, error) @@ -128,7 +127,6 @@ type querier interface { UpdateTemplateVersionByID(ctx context.Context, arg UpdateTemplateVersionByIDParams) error UpdateTemplateVersionDescriptionByJobID(ctx context.Context, arg UpdateTemplateVersionDescriptionByJobIDParams) error UpdateUserHashedPassword(ctx context.Context, arg UpdateUserHashedPasswordParams) error - UpdateUserLinkedID(ctx context.Context, arg UpdateUserLinkedIDParams) (User, error) UpdateUserProfile(ctx context.Context, arg UpdateUserProfileParams) (User, error) UpdateUserRoles(ctx context.Context, arg UpdateUserRolesParams) (User, error) UpdateUserStatus(ctx context.Context, arg UpdateUserStatusParams) (User, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index f64619266087c..d71fe079bc730 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -31,7 +31,7 @@ func (q *sqlQuerier) DeleteAPIKeyByID(ctx context.Context, id string) error { const getAPIKeyByID = `-- name: GetAPIKeyByID :one SELECT - id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, oauth_access_token, oauth_refresh_token, oauth_id_token, oauth_expiry, lifetime_seconds, ip_address + id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, lifetime_seconds, ip_address FROM api_keys WHERE @@ -52,10 +52,6 @@ func (q *sqlQuerier) GetAPIKeyByID(ctx context.Context, id string) (APIKey, erro &i.CreatedAt, &i.UpdatedAt, &i.LoginType, - &i.OAuthAccessToken, - &i.OAuthRefreshToken, - &i.OAuthIDToken, - &i.OAuthExpiry, &i.LifetimeSeconds, &i.IPAddress, ) @@ -63,7 +59,7 @@ func (q *sqlQuerier) GetAPIKeyByID(ctx context.Context, id string) (APIKey, erro } const getAPIKeysLastUsedAfter = `-- name: GetAPIKeysLastUsedAfter :many -SELECT id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, oauth_access_token, oauth_refresh_token, oauth_id_token, oauth_expiry, lifetime_seconds, ip_address FROM api_keys WHERE last_used > $1 +SELECT id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, lifetime_seconds, ip_address FROM api_keys WHERE last_used > $1 ` func (q *sqlQuerier) GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Time) ([]APIKey, error) { @@ -84,10 +80,6 @@ func (q *sqlQuerier) GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time. &i.CreatedAt, &i.UpdatedAt, &i.LoginType, - &i.OAuthAccessToken, - &i.OAuthRefreshToken, - &i.OAuthIDToken, - &i.OAuthExpiry, &i.LifetimeSeconds, &i.IPAddress, ); err != nil { @@ -115,12 +107,7 @@ INSERT INTO last_used, expires_at, created_at, - updated_at, - login_type, - oauth_access_token, - oauth_refresh_token, - oauth_id_token, - oauth_expiry + updated_at ) VALUES ($1, @@ -129,24 +116,19 @@ VALUES WHEN 0 THEN 86400 ELSE $2::bigint END - , $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) RETURNING id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, oauth_access_token, oauth_refresh_token, oauth_id_token, oauth_expiry, lifetime_seconds, ip_address + , $3, $4, $5, $6, $7, $8, $9) RETURNING id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, lifetime_seconds, ip_address ` type InsertAPIKeyParams struct { - ID string `db:"id" json:"id"` - LifetimeSeconds int64 `db:"lifetime_seconds" json:"lifetime_seconds"` - HashedSecret []byte `db:"hashed_secret" json:"hashed_secret"` - IPAddress pqtype.Inet `db:"ip_address" json:"ip_address"` - UserID uuid.UUID `db:"user_id" json:"user_id"` - LastUsed time.Time `db:"last_used" json:"last_used"` - ExpiresAt time.Time `db:"expires_at" json:"expires_at"` - CreatedAt time.Time `db:"created_at" json:"created_at"` - UpdatedAt time.Time `db:"updated_at" json:"updated_at"` - LoginType LoginType `db:"login_type" json:"login_type"` - OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"` - OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"` - OAuthIDToken string `db:"oauth_id_token" json:"oauth_id_token"` - OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"` + ID string `db:"id" json:"id"` + LifetimeSeconds int64 `db:"lifetime_seconds" json:"lifetime_seconds"` + HashedSecret []byte `db:"hashed_secret" json:"hashed_secret"` + IPAddress pqtype.Inet `db:"ip_address" json:"ip_address"` + UserID uuid.UUID `db:"user_id" json:"user_id"` + LastUsed time.Time `db:"last_used" json:"last_used"` + ExpiresAt time.Time `db:"expires_at" json:"expires_at"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` } func (q *sqlQuerier) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (APIKey, error) { @@ -160,11 +142,6 @@ func (q *sqlQuerier) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) ( arg.ExpiresAt, arg.CreatedAt, arg.UpdatedAt, - arg.LoginType, - arg.OAuthAccessToken, - arg.OAuthRefreshToken, - arg.OAuthIDToken, - arg.OAuthExpiry, ) var i APIKey err := row.Scan( @@ -176,10 +153,6 @@ func (q *sqlQuerier) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) ( &i.CreatedAt, &i.UpdatedAt, &i.LoginType, - &i.OAuthAccessToken, - &i.OAuthRefreshToken, - &i.OAuthIDToken, - &i.OAuthExpiry, &i.LifetimeSeconds, &i.IPAddress, ) @@ -192,22 +165,16 @@ UPDATE SET last_used = $2, expires_at = $3, - ip_address = $4, - oauth_access_token = $5, - oauth_refresh_token = $6, - oauth_expiry = $7 + ip_address = $4 WHERE id = $1 ` type UpdateAPIKeyByIDParams struct { - ID string `db:"id" json:"id"` - LastUsed time.Time `db:"last_used" json:"last_used"` - ExpiresAt time.Time `db:"expires_at" json:"expires_at"` - IPAddress pqtype.Inet `db:"ip_address" json:"ip_address"` - OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"` - OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"` - OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"` + ID string `db:"id" json:"id"` + LastUsed time.Time `db:"last_used" json:"last_used"` + ExpiresAt time.Time `db:"expires_at" json:"expires_at"` + IPAddress pqtype.Inet `db:"ip_address" json:"ip_address"` } func (q *sqlQuerier) UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error { @@ -216,9 +183,6 @@ func (q *sqlQuerier) UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDP arg.LastUsed, arg.ExpiresAt, arg.IPAddress, - arg.OAuthAccessToken, - arg.OAuthRefreshToken, - arg.OAuthExpiry, ) return err } @@ -2487,7 +2451,7 @@ func (q *sqlQuerier) GetAuthorizationUserRoles(ctx context.Context, userID uuid. const getUserByEmailOrUsername = `-- name: GetUserByEmailOrUsername :one SELECT - id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type, linked_id + id, email, username, hashed_password, created_at, updated_at, status, rbac_roles FROM users WHERE @@ -2514,15 +2478,13 @@ func (q *sqlQuerier) GetUserByEmailOrUsername(ctx context.Context, arg GetUserBy &i.UpdatedAt, &i.Status, pq.Array(&i.RBACRoles), - &i.LoginType, - &i.LinkedID, ) return i, err } const getUserByID = `-- name: GetUserByID :one SELECT - id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type, linked_id + id, email, username, hashed_password, created_at, updated_at, status, rbac_roles FROM users WHERE @@ -2543,35 +2505,6 @@ func (q *sqlQuerier) GetUserByID(ctx context.Context, id uuid.UUID) (User, error &i.UpdatedAt, &i.Status, pq.Array(&i.RBACRoles), - &i.LoginType, - &i.LinkedID, - ) - return i, err -} - -const getUserByLinkedID = `-- name: GetUserByLinkedID :one -SELECT - id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type, linked_id -FROM - users -WHERE - linked_id = $1 -` - -func (q *sqlQuerier) GetUserByLinkedID(ctx context.Context, linkedID string) (User, error) { - row := q.db.QueryRowContext(ctx, getUserByLinkedID, linkedID) - var i User - err := row.Scan( - &i.ID, - &i.Email, - &i.Username, - &i.HashedPassword, - &i.CreatedAt, - &i.UpdatedAt, - &i.Status, - pq.Array(&i.RBACRoles), - &i.LoginType, - &i.LinkedID, ) return i, err } @@ -2592,7 +2525,7 @@ func (q *sqlQuerier) GetUserCount(ctx context.Context) (int64, error) { const getUsers = `-- name: GetUsers :many SELECT - id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type, linked_id + id, email, username, hashed_password, created_at, updated_at, status, rbac_roles FROM users WHERE @@ -2684,8 +2617,6 @@ func (q *sqlQuerier) GetUsers(ctx context.Context, arg GetUsersParams) ([]User, &i.UpdatedAt, &i.Status, pq.Array(&i.RBACRoles), - &i.LoginType, - &i.LinkedID, ); err != nil { return nil, err } @@ -2701,7 +2632,7 @@ func (q *sqlQuerier) GetUsers(ctx context.Context, arg GetUsersParams) ([]User, } const getUsersByIDs = `-- name: GetUsersByIDs :many -SELECT id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type, linked_id FROM users WHERE id = ANY($1 :: uuid [ ]) +SELECT id, email, username, hashed_password, created_at, updated_at, status, rbac_roles FROM users WHERE id = ANY($1 :: uuid [ ]) ` func (q *sqlQuerier) GetUsersByIDs(ctx context.Context, ids []uuid.UUID) ([]User, error) { @@ -2722,8 +2653,6 @@ func (q *sqlQuerier) GetUsersByIDs(ctx context.Context, ids []uuid.UUID) ([]User &i.UpdatedAt, &i.Status, pq.Array(&i.RBACRoles), - &i.LoginType, - &i.LinkedID, ); err != nil { return nil, err } @@ -2747,12 +2676,10 @@ INSERT INTO hashed_password, created_at, updated_at, - rbac_roles, - login_type, - linked_id + rbac_roles ) VALUES - ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type, linked_id + ($1, $2, $3, $4, $5, $6, $7) RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles ` type InsertUserParams struct { @@ -2763,8 +2690,6 @@ type InsertUserParams struct { CreatedAt time.Time `db:"created_at" json:"created_at"` UpdatedAt time.Time `db:"updated_at" json:"updated_at"` RBACRoles []string `db:"rbac_roles" json:"rbac_roles"` - LoginType LoginType `db:"login_type" json:"login_type"` - LinkedID string `db:"linked_id" json:"linked_id"` } func (q *sqlQuerier) InsertUser(ctx context.Context, arg InsertUserParams) (User, error) { @@ -2776,8 +2701,6 @@ func (q *sqlQuerier) InsertUser(ctx context.Context, arg InsertUserParams) (User arg.CreatedAt, arg.UpdatedAt, pq.Array(arg.RBACRoles), - arg.LoginType, - arg.LinkedID, ) var i User err := row.Scan( @@ -2789,8 +2712,6 @@ func (q *sqlQuerier) InsertUser(ctx context.Context, arg InsertUserParams) (User &i.UpdatedAt, &i.Status, pq.Array(&i.RBACRoles), - &i.LoginType, - &i.LinkedID, ) return i, err } @@ -2814,38 +2735,6 @@ func (q *sqlQuerier) UpdateUserHashedPassword(ctx context.Context, arg UpdateUse return err } -const updateUserLinkedID = `-- name: UpdateUserLinkedID :one -UPDATE - users -SET - linked_id = $2 -WHERE - id = $1 RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type, linked_id -` - -type UpdateUserLinkedIDParams struct { - ID uuid.UUID `db:"id" json:"id"` - LinkedID string `db:"linked_id" json:"linked_id"` -} - -func (q *sqlQuerier) UpdateUserLinkedID(ctx context.Context, arg UpdateUserLinkedIDParams) (User, error) { - row := q.db.QueryRowContext(ctx, updateUserLinkedID, arg.ID, arg.LinkedID) - var i User - err := row.Scan( - &i.ID, - &i.Email, - &i.Username, - &i.HashedPassword, - &i.CreatedAt, - &i.UpdatedAt, - &i.Status, - pq.Array(&i.RBACRoles), - &i.LoginType, - &i.LinkedID, - ) - return i, err -} - const updateUserProfile = `-- name: UpdateUserProfile :one UPDATE users @@ -2854,7 +2743,7 @@ SET username = $3, updated_at = $4 WHERE - id = $1 RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type, linked_id + id = $1 RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles ` type UpdateUserProfileParams struct { @@ -2881,8 +2770,6 @@ func (q *sqlQuerier) UpdateUserProfile(ctx context.Context, arg UpdateUserProfil &i.UpdatedAt, &i.Status, pq.Array(&i.RBACRoles), - &i.LoginType, - &i.LinkedID, ) return i, err } @@ -2895,7 +2782,7 @@ SET rbac_roles = ARRAY(SELECT DISTINCT UNNEST($1 :: text[])) WHERE id = $2 -RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type, linked_id +RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles ` type UpdateUserRolesParams struct { @@ -2915,8 +2802,6 @@ func (q *sqlQuerier) UpdateUserRoles(ctx context.Context, arg UpdateUserRolesPar &i.UpdatedAt, &i.Status, pq.Array(&i.RBACRoles), - &i.LoginType, - &i.LinkedID, ) return i, err } @@ -2928,7 +2813,7 @@ SET status = $2, updated_at = $3 WHERE - id = $1 RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type, linked_id + id = $1 RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles ` type UpdateUserStatusParams struct { @@ -2949,8 +2834,6 @@ func (q *sqlQuerier) UpdateUserStatus(ctx context.Context, arg UpdateUserStatusP &i.UpdatedAt, &i.Status, pq.Array(&i.RBACRoles), - &i.LoginType, - &i.LinkedID, ) return i, err } diff --git a/coderd/database/queries/apikeys.sql b/coderd/database/queries/apikeys.sql index 692ac3e69c8a8..e5a10cb989c9e 100644 --- a/coderd/database/queries/apikeys.sql +++ b/coderd/database/queries/apikeys.sql @@ -22,12 +22,7 @@ INSERT INTO last_used, expires_at, created_at, - updated_at, - login_type, - oauth_access_token, - oauth_refresh_token, - oauth_id_token, - oauth_expiry + updated_at ) VALUES (@id, @@ -36,7 +31,7 @@ VALUES WHEN 0 THEN 86400 ELSE @lifetime_seconds::bigint END - , @hashed_secret, @ip_address, @user_id, @last_used, @expires_at, @created_at, @updated_at, @login_type, @oauth_access_token, @oauth_refresh_token, @oauth_id_token, @oauth_expiry) RETURNING *; + , @hashed_secret, @ip_address, @user_id, @last_used, @expires_at, @created_at, @updated_at) RETURNING *; -- name: UpdateAPIKeyByID :exec UPDATE @@ -44,10 +39,7 @@ UPDATE SET last_used = $2, expires_at = $3, - ip_address = $4, - oauth_access_token = $5, - oauth_refresh_token = $6, - oauth_expiry = $7 + ip_address = $4 WHERE id = $1; diff --git a/coderd/database/queries/users.sql b/coderd/database/queries/users.sql index 7d98f1b1cdb7d..19fe8a7701744 100644 --- a/coderd/database/queries/users.sql +++ b/coderd/database/queries/users.sql @@ -37,12 +37,10 @@ INSERT INTO hashed_password, created_at, updated_at, - rbac_roles, - login_type, - linked_id + rbac_roles ) VALUES - ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING *; + ($1, $2, $3, $4, $5, $6, $7) RETURNING *; -- name: UpdateUserProfile :one UPDATE @@ -161,19 +159,3 @@ LEFT JOIN organization_members ON id = user_id WHERE id = @user_id; - --- name: GetUserByLinkedID :one -SELECT - * -FROM - users -WHERE - linked_id = $1; - --- name: UpdateUserLinkedID :one -UPDATE - users -SET - linked_id = $2 -WHERE - id = $1 RETURNING *; From e41c1033dc9fa91d29afb958ad0c211b2643a79e Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Fri, 12 Aug 2022 03:06:12 +0000 Subject: [PATCH 15/32] more rewriting --- coderd/audit/diff_test.go | 4 - coderd/database/databasefake/databasefake.go | 56 +----- coderd/database/db_test.go | 2 - coderd/database/querier.go | 5 + coderd/database/queries.sql.go | 181 ++++++++++++++++++- coderd/database/queries/apikeys.sql | 5 +- coderd/database/queries/user_links.sql | 46 +++++ coderd/httpmw/apikey.go | 47 +++-- coderd/httpmw/apikey_test.go | 4 - coderd/userauth.go | 146 +++++++++------ coderd/users.go | 75 ++++---- 11 files changed, 397 insertions(+), 174 deletions(-) create mode 100644 coderd/database/queries/user_links.sql diff --git a/coderd/audit/diff_test.go b/coderd/audit/diff_test.go index 472e7cfad8cd3..fc9c41b7cb16d 100644 --- a/coderd/audit/diff_test.go +++ b/coderd/audit/diff_test.go @@ -158,8 +158,6 @@ func TestDiff(t *testing.T) { UpdatedAt: time.Now(), Status: database.UserStatusActive, RBACRoles: []string{"omega admin"}, - LoginType: database.LoginTypePassword, - LinkedID: "foobar", }, exp: audit.Map{ "id": uuid.UUID{1}.String(), @@ -168,8 +166,6 @@ func TestDiff(t *testing.T) { "hashed_password": ([]byte)(nil), "status": database.UserStatusActive, "rbac_roles": []string{"omega admin"}, - "login_type": database.LoginTypePassword, - "linked_id": "foobar", }, }, }) diff --git a/coderd/database/databasefake/databasefake.go b/coderd/database/databasefake/databasefake.go index 978f8ca4dc926..1e4f83001390d 100644 --- a/coderd/database/databasefake/databasefake.go +++ b/coderd/database/databasefake/databasefake.go @@ -1453,20 +1453,16 @@ func (q *fakeQuerier) InsertAPIKey(_ context.Context, arg database.InsertAPIKeyP //nolint:gosimple key := database.APIKey{ - ID: arg.ID, - LifetimeSeconds: arg.LifetimeSeconds, - HashedSecret: arg.HashedSecret, - IPAddress: arg.IPAddress, - UserID: arg.UserID, - ExpiresAt: arg.ExpiresAt, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - LastUsed: arg.LastUsed, - LoginType: arg.LoginType, - OAuthAccessToken: arg.OAuthAccessToken, - OAuthRefreshToken: arg.OAuthRefreshToken, - OAuthIDToken: arg.OAuthIDToken, - OAuthExpiry: arg.OAuthExpiry, + ID: arg.ID, + LifetimeSeconds: arg.LifetimeSeconds, + HashedSecret: arg.HashedSecret, + IPAddress: arg.IPAddress, + UserID: arg.UserID, + ExpiresAt: arg.ExpiresAt, + CreatedAt: arg.CreatedAt, + UpdatedAt: arg.UpdatedAt, + LastUsed: arg.LastUsed, + LoginType: arg.LoginType, } q.apiKeys = append(q.apiKeys, key) return key, nil @@ -1743,8 +1739,6 @@ func (q *fakeQuerier) InsertUser(_ context.Context, arg database.InsertUserParam Username: arg.Username, Status: database.UserStatusActive, RBACRoles: arg.RBACRoles, - LoginType: arg.LoginType, - LinkedID: arg.LinkedID, } q.users = append(q.users, user) return user, nil @@ -1900,9 +1894,6 @@ func (q *fakeQuerier) UpdateAPIKeyByID(_ context.Context, arg database.UpdateAPI apiKey.LastUsed = arg.LastUsed apiKey.ExpiresAt = arg.ExpiresAt apiKey.IPAddress = arg.IPAddress - apiKey.OAuthAccessToken = arg.OAuthAccessToken - apiKey.OAuthRefreshToken = arg.OAuthRefreshToken - apiKey.OAuthExpiry = arg.OAuthExpiry q.apiKeys[index] = apiKey return nil } @@ -2261,30 +2252,3 @@ func (q *fakeQuerier) GetDeploymentID(_ context.Context) (string, error) { return q.deploymentID, nil } - -func (q *fakeQuerier) UpdateUserLinkedID(_ context.Context, arg database.UpdateUserLinkedIDParams) (database.User, error) { - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, user := range q.users { - if user.ID != arg.ID { - continue - } - user.LinkedID = arg.LinkedID - q.users[index] = user - return user, nil - } - return database.User{}, sql.ErrNoRows -} - -func (q *fakeQuerier) GetUserByLinkedID(_ context.Context, linkedID string) (database.User, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, user := range q.users { - if user.LinkedID == linkedID { - return user, nil - } - } - return database.User{}, sql.ErrNoRows -} diff --git a/coderd/database/db_test.go b/coderd/database/db_test.go index 1fbdc4f34c2da..324e048e9156c 100644 --- a/coderd/database/db_test.go +++ b/coderd/database/db_test.go @@ -37,8 +37,6 @@ func TestNestedInTx(t *testing.T) { CreatedAt: database.Now(), UpdatedAt: database.Now(), RBACRoles: []string{}, - LoginType: database.LoginTypePassword, - LinkedID: uuid.NewString(), }) return err }) diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 90e9a3a0a1385..272efe40fd446 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -64,6 +64,8 @@ type querier interface { GetUserByEmailOrUsername(ctx context.Context, arg GetUserByEmailOrUsernameParams) (User, error) GetUserByID(ctx context.Context, id uuid.UUID) (User, error) GetUserCount(ctx context.Context) (int64, error) + GetUserLinkByLinkedID(ctx context.Context, linkedID string) (UserLink, error) + GetUserLinkByUserIDLoginType(ctx context.Context, arg GetUserLinkByUserIDLoginTypeParams) (UserLink, error) GetUsers(ctx context.Context, arg GetUsersParams) ([]User, error) GetUsersByIDs(ctx context.Context, ids []uuid.UUID) ([]User, error) GetWorkspaceAgentByAuthToken(ctx context.Context, authToken uuid.UUID) (WorkspaceAgent, error) @@ -106,6 +108,7 @@ type querier interface { InsertTemplate(ctx context.Context, arg InsertTemplateParams) (Template, error) InsertTemplateVersion(ctx context.Context, arg InsertTemplateVersionParams) (TemplateVersion, error) InsertUser(ctx context.Context, arg InsertUserParams) (User, error) + InsertUserLink(ctx context.Context, arg InsertUserLinkParams) (UserLink, error) InsertWorkspace(ctx context.Context, arg InsertWorkspaceParams) (Workspace, error) InsertWorkspaceAgent(ctx context.Context, arg InsertWorkspaceAgentParams) (WorkspaceAgent, error) InsertWorkspaceApp(ctx context.Context, arg InsertWorkspaceAppParams) (WorkspaceApp, error) @@ -127,6 +130,8 @@ type querier interface { UpdateTemplateVersionByID(ctx context.Context, arg UpdateTemplateVersionByIDParams) error UpdateTemplateVersionDescriptionByJobID(ctx context.Context, arg UpdateTemplateVersionDescriptionByJobIDParams) error UpdateUserHashedPassword(ctx context.Context, arg UpdateUserHashedPasswordParams) error + UpdateUserLink(ctx context.Context, arg UpdateUserLinkParams) (UserLink, error) + UpdateUserLinkedID(ctx context.Context, arg UpdateUserLinkedIDParams) (UserLink, error) UpdateUserProfile(ctx context.Context, arg UpdateUserProfileParams) (User, error) UpdateUserRoles(ctx context.Context, arg UpdateUserRolesParams) (User, error) UpdateUserStatus(ctx context.Context, arg UpdateUserStatusParams) (User, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index d71fe079bc730..8a25b0740183f 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -107,7 +107,8 @@ INSERT INTO last_used, expires_at, created_at, - updated_at + updated_at, + login_type ) VALUES ($1, @@ -116,7 +117,7 @@ VALUES WHEN 0 THEN 86400 ELSE $2::bigint END - , $3, $4, $5, $6, $7, $8, $9) RETURNING id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, lifetime_seconds, ip_address + , $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, lifetime_seconds, ip_address ` type InsertAPIKeyParams struct { @@ -129,6 +130,7 @@ type InsertAPIKeyParams struct { ExpiresAt time.Time `db:"expires_at" json:"expires_at"` CreatedAt time.Time `db:"created_at" json:"created_at"` UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + LoginType LoginType `db:"login_type" json:"login_type"` } func (q *sqlQuerier) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (APIKey, error) { @@ -142,6 +144,7 @@ func (q *sqlQuerier) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) ( arg.ExpiresAt, arg.CreatedAt, arg.UpdatedAt, + arg.LoginType, ) var i APIKey err := row.Scan( @@ -2408,6 +2411,180 @@ func (q *sqlQuerier) UpdateTemplateVersionDescriptionByJobID(ctx context.Context return err } +const getUserLinkByLinkedID = `-- name: GetUserLinkByLinkedID :one +SELECT + user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_id_token, oauth_expiry +FROM + user_links +WHERE + linked_id = $1 +` + +func (q *sqlQuerier) GetUserLinkByLinkedID(ctx context.Context, linkedID string) (UserLink, error) { + row := q.db.QueryRowContext(ctx, getUserLinkByLinkedID, linkedID) + var i UserLink + err := row.Scan( + &i.UserID, + &i.LoginType, + &i.LinkedID, + &i.OAuthAccessToken, + &i.OAuthRefreshToken, + &i.OAuthIDToken, + &i.OAuthExpiry, + ) + return i, err +} + +const getUserLinkByUserIDLoginType = `-- name: GetUserLinkByUserIDLoginType :one +SELECT + user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_id_token, oauth_expiry +FROM + user_links +WHERE + user_id = $1 AND login_type = $2 +` + +type GetUserLinkByUserIDLoginTypeParams struct { + UserID uuid.UUID `db:"user_id" json:"user_id"` + LoginType LoginType `db:"login_type" json:"login_type"` +} + +func (q *sqlQuerier) GetUserLinkByUserIDLoginType(ctx context.Context, arg GetUserLinkByUserIDLoginTypeParams) (UserLink, error) { + row := q.db.QueryRowContext(ctx, getUserLinkByUserIDLoginType, arg.UserID, arg.LoginType) + var i UserLink + err := row.Scan( + &i.UserID, + &i.LoginType, + &i.LinkedID, + &i.OAuthAccessToken, + &i.OAuthRefreshToken, + &i.OAuthIDToken, + &i.OAuthExpiry, + ) + return i, err +} + +const insertUserLink = `-- name: InsertUserLink :one +INSERT INTO + user_links ( + user_id, + login_type, + linked_id, + oauth_access_token, + oauth_refresh_token, + oauth_id_token, + oauth_expiry + ) +VALUES + ( $1, $2, $3, $4, $5, $6, $7 ) RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_id_token, oauth_expiry +` + +type InsertUserLinkParams struct { + UserID uuid.UUID `db:"user_id" json:"user_id"` + LoginType LoginType `db:"login_type" json:"login_type"` + LinkedID string `db:"linked_id" json:"linked_id"` + OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"` + OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"` + OAuthIDToken string `db:"oauth_id_token" json:"oauth_id_token"` + OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"` +} + +func (q *sqlQuerier) InsertUserLink(ctx context.Context, arg InsertUserLinkParams) (UserLink, error) { + row := q.db.QueryRowContext(ctx, insertUserLink, + arg.UserID, + arg.LoginType, + arg.LinkedID, + arg.OAuthAccessToken, + arg.OAuthRefreshToken, + arg.OAuthIDToken, + arg.OAuthExpiry, + ) + var i UserLink + err := row.Scan( + &i.UserID, + &i.LoginType, + &i.LinkedID, + &i.OAuthAccessToken, + &i.OAuthRefreshToken, + &i.OAuthIDToken, + &i.OAuthExpiry, + ) + return i, err +} + +const updateUserLink = `-- name: UpdateUserLink :one +UPDATE + user_links +SET + oauth_access_token = $1, + oauth_refresh_token = $2, + oauth_id_token = $3, + oauth_expiry = $4 +WHERE + user_id = $5 AND login_type = $6 RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_id_token, oauth_expiry +` + +type UpdateUserLinkParams struct { + OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"` + OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"` + OAuthIDToken string `db:"oauth_id_token" json:"oauth_id_token"` + OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"` + UserID uuid.UUID `db:"user_id" json:"user_id"` + LoginType LoginType `db:"login_type" json:"login_type"` +} + +func (q *sqlQuerier) UpdateUserLink(ctx context.Context, arg UpdateUserLinkParams) (UserLink, error) { + row := q.db.QueryRowContext(ctx, updateUserLink, + arg.OAuthAccessToken, + arg.OAuthRefreshToken, + arg.OAuthIDToken, + arg.OAuthExpiry, + arg.UserID, + arg.LoginType, + ) + var i UserLink + err := row.Scan( + &i.UserID, + &i.LoginType, + &i.LinkedID, + &i.OAuthAccessToken, + &i.OAuthRefreshToken, + &i.OAuthIDToken, + &i.OAuthExpiry, + ) + return i, err +} + +const updateUserLinkedID = `-- name: UpdateUserLinkedID :one +UPDATE + user_links +SET + linked_id = $1 +WHERE + user_id = $2 AND login_type = $3 RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_id_token, oauth_expiry +` + +type UpdateUserLinkedIDParams struct { + LinkedID string `db:"linked_id" json:"linked_id"` + UserID uuid.UUID `db:"user_id" json:"user_id"` + LoginType LoginType `db:"login_type" json:"login_type"` +} + +func (q *sqlQuerier) UpdateUserLinkedID(ctx context.Context, arg UpdateUserLinkedIDParams) (UserLink, error) { + row := q.db.QueryRowContext(ctx, updateUserLinkedID, arg.LinkedID, arg.UserID, arg.LoginType) + var i UserLink + err := row.Scan( + &i.UserID, + &i.LoginType, + &i.LinkedID, + &i.OAuthAccessToken, + &i.OAuthRefreshToken, + &i.OAuthIDToken, + &i.OAuthExpiry, + ) + return i, err +} + const getAuthorizationUserRoles = `-- name: GetAuthorizationUserRoles :one SELECT -- username is returned just to help for logging purposes diff --git a/coderd/database/queries/apikeys.sql b/coderd/database/queries/apikeys.sql index e5a10cb989c9e..22ce2e6057f3e 100644 --- a/coderd/database/queries/apikeys.sql +++ b/coderd/database/queries/apikeys.sql @@ -22,7 +22,8 @@ INSERT INTO last_used, expires_at, created_at, - updated_at + updated_at, + login_type ) VALUES (@id, @@ -31,7 +32,7 @@ VALUES WHEN 0 THEN 86400 ELSE @lifetime_seconds::bigint END - , @hashed_secret, @ip_address, @user_id, @last_used, @expires_at, @created_at, @updated_at) RETURNING *; + , @hashed_secret, @ip_address, @user_id, @last_used, @expires_at, @created_at, @updated_at, @login_type) RETURNING *; -- name: UpdateAPIKeyByID :exec UPDATE diff --git a/coderd/database/queries/user_links.sql b/coderd/database/queries/user_links.sql new file mode 100644 index 0000000000000..94120a5793a3c --- /dev/null +++ b/coderd/database/queries/user_links.sql @@ -0,0 +1,46 @@ +-- name: GetUserLinkByLinkedID :one +SELECT + * +FROM + user_links +WHERE + linked_id = $1; + +-- name: GetUserLinkByUserIDLoginType :one +SELECT + * +FROM + user_links +WHERE + user_id = $1 AND login_type = $2; + +-- name: InsertUserLink :one +INSERT INTO + user_links ( + user_id, + login_type, + linked_id, + oauth_access_token, + oauth_refresh_token, + oauth_expiry + ) +VALUES + ( $1, $2, $3, $4, $5, $6 ) RETURNING *; + +-- name: UpdateUserLinkedID :one +UPDATE + user_links +SET + linked_id = $1 +WHERE + user_id = $2 AND login_type = $3 RETURNING *; + +-- name: UpdateUserLink :one +UPDATE + user_links +SET + oauth_access_token = $1, + oauth_refresh_token = $2, + oauth_expiry = $3 +WHERE + user_id = $4 AND login_type = $5 RETURNING *; diff --git a/coderd/httpmw/apikey.go b/coderd/httpmw/apikey.go index 80586bc976f49..d6c9288e92071 100644 --- a/coderd/httpmw/apikey.go +++ b/coderd/httpmw/apikey.go @@ -14,6 +14,7 @@ import ( "golang.org/x/oauth2" + "github.com/google/uuid" "github.com/tabbed/pqtype" "github.com/coder/coder/coderd/database" @@ -149,9 +150,21 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool // Tracks if the API key has properties updated! changed := false + var link database.UserLink if key.LoginType != database.LoginTypePassword { + link, err = db.GetUserLinkByUserIDLoginType(r.Context(), database.GetUserLinkByUserIDLoginTypeParams{ + UserID: key.UserID, + LoginType: key.LoginType, + }) + if err != nil { + write(http.StatusInternalServerError, codersdk.Response{ + Message: "A database error occurred", + Detail: err.Error(), + }) + return + } // Check if the OAuth token is expired! - if key.OAuthExpiry.Before(now) && !key.OAuthExpiry.IsZero() { + if link.OAuthExpiry.Before(now) && !link.OAuthExpiry.IsZero() { var oauthConfig OAuth2Config switch key.LoginType { case database.LoginTypeGithub: @@ -167,9 +180,9 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool } // If it is, let's refresh it from the provided config! token, err := oauthConfig.TokenSource(r.Context(), &oauth2.Token{ - AccessToken: key.OAuthAccessToken, - RefreshToken: key.OAuthRefreshToken, - Expiry: key.OAuthExpiry, + AccessToken: link.OAuthAccessToken, + RefreshToken: link.OAuthRefreshToken, + Expiry: link.OAuthExpiry, }).Token() if err != nil { write(http.StatusUnauthorized, codersdk.Response{ @@ -178,9 +191,9 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool }) return } - key.OAuthAccessToken = token.AccessToken - key.OAuthRefreshToken = token.RefreshToken - key.OAuthExpiry = token.Expiry + link.OAuthAccessToken = token.AccessToken + link.OAuthRefreshToken = token.RefreshToken + link.OAuthExpiry = token.Expiry key.ExpiresAt = token.Expiry changed = true } @@ -222,13 +235,10 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool } if changed { err := db.UpdateAPIKeyByID(r.Context(), database.UpdateAPIKeyByIDParams{ - ID: key.ID, - LastUsed: key.LastUsed, - ExpiresAt: key.ExpiresAt, - IPAddress: key.IPAddress, - OAuthAccessToken: key.OAuthAccessToken, - OAuthRefreshToken: key.OAuthRefreshToken, - OAuthExpiry: key.OAuthExpiry, + ID: key.ID, + LastUsed: key.LastUsed, + ExpiresAt: key.ExpiresAt, + IPAddress: key.IPAddress, }) if err != nil { write(http.StatusInternalServerError, codersdk.Response{ @@ -237,6 +247,15 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool }) return } + if link.UserID != uuid.Nil { + link, err = db.UpdateUserLink(r.Context(), database.UpdateUserLinkParams{ + UserID: link.UserID, + OAuthAccessToken: link.OAuthAccessToken, + OAuthRefreshToken: link.OAuthRefreshToken, + OAuthExpiry: link.OAuthExpiry, + }) + + } } // If the key is valid, we also fetch the user roles and status. diff --git a/coderd/httpmw/apikey_test.go b/coderd/httpmw/apikey_test.go index 58456aafd59bf..28343b7f9bded 100644 --- a/coderd/httpmw/apikey_test.go +++ b/coderd/httpmw/apikey_test.go @@ -393,7 +393,6 @@ func TestAPIKey(t *testing.T) { HashedSecret: hashed[:], LoginType: database.LoginTypeGithub, LastUsed: database.Now(), - OAuthExpiry: database.Now().AddDate(0, 0, -1), UserID: user.ID, }) require.NoError(t, err) @@ -418,7 +417,6 @@ func TestAPIKey(t *testing.T) { require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed) require.Equal(t, token.Expiry, gotAPIKey.ExpiresAt) - require.Equal(t, token.AccessToken, gotAPIKey.OAuthAccessToken) }) t.Run("RemoteIPUpdates", func(t *testing.T) { @@ -466,8 +464,6 @@ func createUser(ctx context.Context, t *testing.T, db database.Store) database.U CreatedAt: time.Now(), UpdatedAt: time.Now(), RBACRoles: []string{}, - LinkedID: uuid.NewString(), - LoginType: database.LoginTypePassword, }) require.NoError(t, err, "create user") return user diff --git a/coderd/userauth.go b/coderd/userauth.go index e88e759d91b1f..c402a42ed7944 100644 --- a/coderd/userauth.go +++ b/coderd/userauth.go @@ -137,7 +137,7 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) { return } - user, found, err := findLinkedUser(ctx, api.Database, githubLinkedID(ghUser), verifiedEmails...) + user, link, err := findLinkedUser(ctx, api.Database, githubLinkedID(ghUser), verifiedEmails...) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ Message: "Failed to find user.", @@ -145,15 +145,15 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) { return } - if found && user.LoginType != database.LoginTypeGithub { + if link.UserID != uuid.Nil && link.LoginType != database.LoginTypeGithub { httpapi.Write(rw, http.StatusConflict, codersdk.Response{ - Message: fmt.Sprintf("Incorrect login type, attempting to use %q but user is of login type %q", database.LoginTypeOIDC, user.LoginType), + Message: fmt.Sprintf("Incorrect login type, attempting to use %q but user is of login type %q", database.LoginTypeOIDC, link.LoginType), }) return } // If the user doesn't exist, create a new one! - if !found { + if user.ID == uuid.Nil { if !api.GithubOAuth2Config.AllowSignups { httpapi.Write(rw, http.StatusForbidden, codersdk.Response{ Message: "Signups are disabled for Github authentication!", @@ -183,14 +183,10 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) { }) return } - user, _, err = api.createUser(r.Context(), createUserRequest{ - CreateUserRequest: codersdk.CreateUserRequest{ - Email: *verifiedEmail.Email, - Username: *ghUser.Login, - OrganizationID: organizationID, - }, - LinkedID: githubLinkedID(ghUser), - LoginType: database.LoginTypeGithub, + user, _, err = api.createUser(r.Context(), codersdk.CreateUserRequest{ + Email: *verifiedEmail.Email, + Username: *ghUser.Login, + OrganizationID: organizationID, }) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ @@ -199,31 +195,50 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) { }) return } + + } + + // This can happen if a user is a built-in user but is signing in + // with Github for the first time. + if link.UserID == uuid.Nil { + link, err = api.Database.InsertUserLink(ctx, database.InsertUserLinkParams{ + UserID: user.ID, + LoginType: database.LoginTypeGithub, + LinkedID: githubLinkedID(ghUser), + OAuthAccessToken: state.Token.AccessToken, + OAuthRefreshToken: state.Token.RefreshToken, + OAuthExpiry: state.Token.Expiry, + }) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ + Message: "A database error occurred.", + Detail: xerrors.Errorf("insert user link: %w", err.Error).Error(), + }) + return + } } // LEGACY: Remove 10/2022. // We started tracking linked IDs later so it's possible for a user to be a // pre-existing Github user and not have a linked ID. - if user.LinkedID == "" { - user, err = api.Database.UpdateUserLinkedID(ctx, database.UpdateUserLinkedIDParams{ - ID: user.ID, - LinkedID: githubLinkedID(ghUser), + if link.LinkedID == "" { + link, err = api.Database.UpdateUserLinkedID(ctx, database.UpdateUserLinkedIDParams{ + UserID: user.ID, + LinkedID: githubLinkedID(ghUser), + LoginType: database.LoginTypeGithub, }) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Failed to set user linked ID.", - Detail: err.Error(), + Message: "A database error occurred.", + Detail: xerrors.Errorf("update user link: %w", err.Error).Error(), }) return } } - _, created := api.createAPIKey(rw, r, database.InsertAPIKeyParams{ - UserID: user.ID, - LoginType: database.LoginTypeGithub, - OAuthAccessToken: state.Token.AccessToken, - OAuthRefreshToken: state.Token.RefreshToken, - OAuthExpiry: state.Token.Expiry, + _, created := api.createAPIKey(rw, r, createAPIKeyParams{ + UserID: user.ID, + LoginType: database.LoginTypeGithub, }) if !created { return @@ -315,7 +330,7 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) { } } - user, found, err := findLinkedUser(ctx, api.Database, oidcLinkedID(idToken), claims.Email) + user, link, err := findLinkedUser(ctx, api.Database, oidcLinkedID(idToken), claims.Email) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ Message: "Failed to find user.", @@ -324,21 +339,21 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) { return } - if !found && !api.OIDCConfig.AllowSignups { + if user.ID == uuid.Nil && !api.OIDCConfig.AllowSignups { httpapi.Write(rw, http.StatusForbidden, codersdk.Response{ Message: "Signups are disabled for OIDC authentication!", }) return } - if found && user.LoginType != database.LoginTypeOIDC { + if link.UserID != uuid.Nil && link.LoginType != database.LoginTypeOIDC { httpapi.Write(rw, http.StatusConflict, codersdk.Response{ - Message: fmt.Sprintf("Incorrect login type, attempting to use %q but user is of login type %q", database.LoginTypeOIDC, user.LoginType), + Message: fmt.Sprintf("Incorrect login type, attempting to use %q but user is of login type %q", database.LoginTypeOIDC, link.LoginType), }) return } - if !found { + if user.ID == uuid.Nil { var organizationID uuid.UUID organizations, _ := api.Database.GetOrganizations(ctx) if len(organizations) > 0 { @@ -347,14 +362,11 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) { // email to organization. organizationID = organizations[0].ID } - user, _, err = api.createUser(ctx, createUserRequest{ - CreateUserRequest: codersdk.CreateUserRequest{ - Email: claims.Email, - Username: claims.Username, - OrganizationID: organizationID, - }, - LinkedID: oidcLinkedID(idToken), - LoginType: database.LoginTypeOIDC, + + user, _, err = api.createUser(ctx, codersdk.CreateUserRequest{ + Email: claims.Email, + Username: claims.Username, + OrganizationID: organizationID, }) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ @@ -372,18 +384,38 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) { } } + if link.UserID == uuid.Nil { + link, err = api.Database.InsertUserLink(ctx, database.InsertUserLinkParams{ + UserID: user.ID, + LoginType: database.LoginTypeGithub, + LinkedID: oidcLinkedID(idToken), + OAuthAccessToken: state.Token.AccessToken, + OAuthRefreshToken: state.Token.RefreshToken, + OAuthExpiry: state.Token.Expiry, + }) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ + Message: "A database error occurred.", + Detail: xerrors.Errorf("insert user link: %w", err.Error).Error(), + }) + return + } + + } + // LEGACY: Remove 10/2022. // We started tracking linked IDs later so it's possible for a user to be a // pre-existing OIDC user and not have a linked ID. - if user.LinkedID == "" { - user, err = api.Database.UpdateUserLinkedID(ctx, database.UpdateUserLinkedIDParams{ - ID: user.ID, - LinkedID: oidcLinkedID(idToken), + if link.LinkedID == "" { + link, err = api.Database.UpdateUserLinkedID(ctx, database.UpdateUserLinkedIDParams{ + UserID: user.ID, + LinkedID: oidcLinkedID(idToken), + LoginType: database.LoginTypeGithub, }) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Failed to update user linked ID.", - Detail: err.Error(), + Message: "A database error occurred.", + Detail: xerrors.Errorf("update user link: %w", err.Error).Error(), }) return } @@ -414,12 +446,9 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) { } } - _, created := api.createAPIKey(rw, r, database.InsertAPIKeyParams{ - UserID: user.ID, - LoginType: database.LoginTypeOIDC, - OAuthAccessToken: state.Token.AccessToken, - OAuthRefreshToken: state.Token.RefreshToken, - OAuthExpiry: state.Token.Expiry, + _, created := api.createAPIKey(rw, r, createAPIKeyParams{ + UserID: user.ID, + LoginType: database.LoginTypeOIDC, }) if !created { return @@ -445,14 +474,19 @@ func oidcLinkedID(tok *oidc.IDToken) string { // findLinkedUser tries to find a user by their unique OAuth-linked ID. // If it doesn't not find it, it returns the user by their email. -func findLinkedUser(ctx context.Context, db database.Store, linkedID string, emails ...string) (database.User, bool, error) { - user, err := db.GetUserByLinkedID(ctx, linkedID) +func findLinkedUser(ctx context.Context, db database.Store, linkedID string, emails ...string) (database.User, database.UserLink, error) { + var ( + user database.User + link database.UserLink + ) + link, err := db.GetUserLinkByLinkedID(ctx, linkedID) if err != nil && !errors.Is(err, sql.ErrNoRows) { - return user, false, xerrors.Errorf("get user auth by linked ID: %w", err) + return user, link, xerrors.Errorf("get user auth by linked ID: %w", err) } if err == nil { - return user, true, nil + user, err = db.GetUserByID(ctx, link.UserID) + return user, link, nil } for _, email := range emails { @@ -460,14 +494,14 @@ func findLinkedUser(ctx context.Context, db database.Store, linkedID string, ema Email: email, }) if err != nil && !errors.Is(err, sql.ErrNoRows) { - return user, false, xerrors.Errorf("get user by email: %w", err) + return user, link, xerrors.Errorf("get user by email: %w", err) } if errors.Is(err, sql.ErrNoRows) { continue } - return user, true, nil + return user, link, nil } // No user found. - return user, false, nil + return user, link, nil } diff --git a/coderd/users.go b/coderd/users.go index 4c67838be01f2..ccf3cd8856f1f 100644 --- a/coderd/users.go +++ b/coderd/users.go @@ -77,14 +77,10 @@ func (api *API) postFirstUser(rw http.ResponseWriter, r *http.Request) { return } - user, organizationID, err := api.createUser(r.Context(), createUserRequest{ - CreateUserRequest: codersdk.CreateUserRequest{ - Email: createUser.Email, - Username: createUser.Username, - Password: createUser.Password, - }, - LoginType: database.LoginTypePassword, - LinkedID: createUser.Email, + user, organizationID, err := api.createUser(r.Context(), codersdk.CreateUserRequest{ + Email: createUser.Email, + Username: createUser.Username, + Password: createUser.Password, }) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ @@ -192,14 +188,14 @@ func (api *API) postUser(rw http.ResponseWriter, r *http.Request) { return } - var createUser codersdk.CreateUserRequest - if !httpapi.Read(rw, r, &createUser) { + var req codersdk.CreateUserRequest + if !httpapi.Read(rw, r, &req) { return } // Create the organization member in the org. if !api.Authorize(r, rbac.ActionCreate, - rbac.ResourceOrganizationMember.InOrg(createUser.OrganizationID)) { + rbac.ResourceOrganizationMember.InOrg(req.OrganizationID)) { httpapi.ResourceNotFound(rw) return } @@ -207,8 +203,8 @@ func (api *API) postUser(rw http.ResponseWriter, r *http.Request) { // TODO: @emyrk Authorize the organization create if the createUser will do that. _, err := api.Database.GetUserByEmailOrUsername(r.Context(), database.GetUserByEmailOrUsernameParams{ - Username: createUser.Username, - Email: createUser.Email, + Username: req.Username, + Email: req.Email, }) if err == nil { httpapi.Write(rw, http.StatusConflict, codersdk.Response{ @@ -224,10 +220,10 @@ func (api *API) postUser(rw http.ResponseWriter, r *http.Request) { return } - _, err = api.Database.GetOrganizationByID(r.Context(), createUser.OrganizationID) + _, err = api.Database.GetOrganizationByID(r.Context(), req.OrganizationID) if errors.Is(err, sql.ErrNoRows) { httpapi.Write(rw, http.StatusNotFound, codersdk.Response{ - Message: fmt.Sprintf("Organization does not exist with the provided id %q.", createUser.OrganizationID), + Message: fmt.Sprintf("Organization does not exist with the provided id %q.", req.OrganizationID), }) return } @@ -239,11 +235,7 @@ func (api *API) postUser(rw http.ResponseWriter, r *http.Request) { return } - user, _, err := api.createUser(r.Context(), createUserRequest{ - CreateUserRequest: createUser, - LinkedID: createUser.Email, - LoginType: database.LoginTypePassword, - }) + user, _, err := api.createUser(r.Context(), req) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error creating user.", @@ -257,7 +249,7 @@ func (api *API) postUser(rw http.ResponseWriter, r *http.Request) { Users: []telemetry.User{telemetry.ConvertUser(user)}, }) - httpapi.Write(rw, http.StatusCreated, convertUser(user, []uuid.UUID{createUser.OrganizationID})) + httpapi.Write(rw, http.StatusCreated, convertUser(user, []uuid.UUID{req.OrganizationID})) } // Returns the parameterized user requested. All validation @@ -695,7 +687,7 @@ func (api *API) postLogin(rw http.ResponseWriter, r *http.Request) { return } - sessionToken, created := api.createAPIKey(rw, r, database.InsertAPIKeyParams{ + sessionToken, created := api.createAPIKey(rw, r, createAPIKeyParams{ UserID: user.ID, LoginType: database.LoginTypePassword, }) @@ -718,7 +710,7 @@ func (api *API) postAPIKey(rw http.ResponseWriter, r *http.Request) { } lifeTime := time.Hour * 24 * 7 - sessionToken, created := api.createAPIKey(rw, r, database.InsertAPIKeyParams{ + sessionToken, created := api.createAPIKey(rw, r, createAPIKeyParams{ UserID: user.ID, LoginType: database.LoginTypePassword, // All api generated keys will last 1 week. Browser login tokens have @@ -804,7 +796,16 @@ func generateAPIKeyIDSecret() (id string, secret string, err error) { return id, secret, nil } -func (api *API) createAPIKey(rw http.ResponseWriter, r *http.Request, params database.InsertAPIKeyParams) (string, bool) { +type createAPIKeyParams struct { + UserID uuid.UUID + LoginType database.LoginType + + // Optional. + ExpiresAt time.Time + LifetimeSeconds int64 +} + +func (api *API) createAPIKey(rw http.ResponseWriter, r *http.Request, params createAPIKeyParams) (string, bool) { keyID, keySecret, err := generateAPIKeyIDSecret() if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ @@ -842,15 +843,11 @@ func (api *API) createAPIKey(rw http.ResponseWriter, r *http.Request, params dat Valid: true, }, // Make sure in UTC time for common time zone - ExpiresAt: params.ExpiresAt.UTC(), - CreatedAt: database.Now(), - UpdatedAt: database.Now(), - HashedSecret: hashed[:], - LoginType: params.LoginType, - OAuthAccessToken: params.OAuthAccessToken, - OAuthRefreshToken: params.OAuthRefreshToken, - OAuthIDToken: params.OAuthIDToken, - OAuthExpiry: params.OAuthExpiry, + ExpiresAt: params.ExpiresAt.UTC(), + CreatedAt: database.Now(), + UpdatedAt: database.Now(), + HashedSecret: hashed[:], + LoginType: params.LoginType, }) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ @@ -877,13 +874,7 @@ func (api *API) createAPIKey(rw http.ResponseWriter, r *http.Request, params dat return sessionToken, true } -type createUserRequest struct { - codersdk.CreateUserRequest - LoginType database.LoginType - LinkedID string -} - -func (api *API) createUser(ctx context.Context, req createUserRequest) (database.User, uuid.UUID, error) { +func (api *API) createUser(ctx context.Context, req codersdk.CreateUserRequest) (database.User, uuid.UUID, error) { var user database.User return user, req.OrganizationID, api.Database.InTx(func(db database.Store) error { orgRoles := make([]string, 0) @@ -910,8 +901,6 @@ func (api *API) createUser(ctx context.Context, req createUserRequest) (database UpdatedAt: database.Now(), // All new users are defaulted to members of the site. RBACRoles: []string{}, - LoginType: req.LoginType, - LinkedID: req.LinkedID, } // If a user signs up with OAuth, they can have no password! if req.Password != "" { @@ -966,8 +955,6 @@ func convertUser(user database.User, organizationIDs []uuid.UUID) codersdk.User Status: codersdk.UserStatus(user.Status), OrganizationIDs: organizationIDs, Roles: make([]codersdk.Role, 0), - LoginType: codersdk.LoginType(user.LoginType), - LinkedID: user.LinkedID, } for _, roleName := range user.RBACRoles { From bb9b77760543a6f4bb0fda0f773f3364e53ce9e1 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Fri, 12 Aug 2022 03:23:58 +0000 Subject: [PATCH 16/32] even more rewriting --- coderd/database/databasefake/databasefake.go | 77 +++++++++++++++++++ coderd/database/dump.sql | 1 - .../migrations/000034_linked_user_id.up.sql | 3 - coderd/database/models.go | 1 - coderd/database/queries.sql.go | 23 ++---- coderd/httpmw/apikey.go | 8 +- coderd/httpmw/authorize_test.go | 2 - coderd/httpmw/organizationparam_test.go | 2 - coderd/httpmw/templateparam_test.go | 2 - coderd/httpmw/templateversionparam_test.go | 2 - coderd/httpmw/userparam_test.go | 8 +- coderd/httpmw/workspaceagentparam_test.go | 2 - coderd/httpmw/workspacebuildparam_test.go | 2 - coderd/httpmw/workspaceparam_test.go | 2 - coderd/provisionerjobs_internal_test.go | 2 - coderd/telemetry/telemetry_test.go | 4 - 16 files changed, 93 insertions(+), 48 deletions(-) diff --git a/coderd/database/databasefake/databasefake.go b/coderd/database/databasefake/databasefake.go index 1e4f83001390d..d978702d237c6 100644 --- a/coderd/database/databasefake/databasefake.go +++ b/coderd/database/databasefake/databasefake.go @@ -73,6 +73,7 @@ type data struct { organizations []database.Organization organizationMembers []database.OrganizationMember users []database.User + userLinks []database.UserLink // New tables auditLogs []database.AuditLog @@ -2252,3 +2253,79 @@ func (q *fakeQuerier) GetDeploymentID(_ context.Context) (string, error) { return q.deploymentID, nil } + +func (q *fakeQuerier) GetUserLinkByLinkedID(_ context.Context, id string) (database.UserLink, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + for _, link := range q.userLinks { + if link.LinkedID == id { + return link, nil + } + } + return database.UserLink{}, sql.ErrNoRows +} + +func (q *fakeQuerier) GetUserLinkByUserIDLoginType(_ context.Context, params database.GetUserLinkByUserIDLoginTypeParams) (database.UserLink, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + for _, link := range q.userLinks { + if link.UserID == params.UserID && link.LoginType == params.LoginType { + return link, nil + } + } + return database.UserLink{}, sql.ErrNoRows +} + +func (q *fakeQuerier) InsertUserLink(_ context.Context, params database.InsertUserLinkParams) (database.UserLink, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + link := database.UserLink{ + UserID: params.UserID, + LoginType: params.LoginType, + LinkedID: params.LinkedID, + OAuthAccessToken: params.OAuthAccessToken, + OAuthRefreshToken: params.OAuthRefreshToken, + OAuthExpiry: params.OAuthExpiry, + } + + q.userLinks = append(q.userLinks, link) + + return link, nil +} + +func (q *fakeQuerier) UpdateUserLinkedID(_ context.Context, params database.UpdateUserLinkedIDParams) (database.UserLink, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + for i, link := range q.userLinks { + if link.UserID == params.UserID && link.LoginType == params.LoginType { + link.LinkedID = params.LinkedID + + q.userLinks[i] = link + return link, nil + } + } + + return database.UserLink{}, sql.ErrNoRows +} + +func (q *fakeQuerier) UpdateUserLink(_ context.Context, params database.UpdateUserLinkParams) (database.UserLink, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + for i, link := range q.userLinks { + if link.UserID == params.UserID && link.LoginType == params.LoginType { + link.OAuthAccessToken = params.OAuthAccessToken + link.OAuthRefreshToken = params.OAuthRefreshToken + link.OAuthExpiry = params.OAuthExpiry + + q.userLinks[i] = link + return link, nil + } + } + + return database.UserLink{}, sql.ErrNoRows +} diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index bf62f67a10481..0f095446661fe 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -269,7 +269,6 @@ CREATE TABLE user_links ( linked_id text DEFAULT ''::text NOT NULL, oauth_access_token text DEFAULT ''::text NOT NULL, oauth_refresh_token text DEFAULT ''::text NOT NULL, - oauth_id_token text DEFAULT ''::text NOT NULL, oauth_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL ); diff --git a/coderd/database/migrations/000034_linked_user_id.up.sql b/coderd/database/migrations/000034_linked_user_id.up.sql index 8bcddeb692bbd..e9e1a79a5f12a 100644 --- a/coderd/database/migrations/000034_linked_user_id.up.sql +++ b/coderd/database/migrations/000034_linked_user_id.up.sql @@ -6,7 +6,6 @@ CREATE TABLE IF NOT EXISTS user_links ( linked_id text DEFAULT ''::text NOT NULL, oauth_access_token text DEFAULT ''::text NOT NULL, oauth_refresh_token text DEFAULT ''::text NOT NULL, - oauth_id_token text DEFAULT ''::text NOT NULL, oauth_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL, UNIQUE(user_id, login_type) ); @@ -18,7 +17,6 @@ INSERT INTO user_links linked_id, oauth_access_token, oauth_refresh_token, - oauth_id_token, oauth_expiry ) SELECT @@ -27,7 +25,6 @@ SELECT '', keys.oauth_access_token, keys.oauth_refresh_token, - keys.oauth_id_token, keys.oauth_expiry FROM ( diff --git a/coderd/database/models.go b/coderd/database/models.go index bbb98e9856226..fcf96ab197ca7 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -495,7 +495,6 @@ type UserLink struct { LinkedID string `db:"linked_id" json:"linked_id"` OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"` OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"` - OAuthIDToken string `db:"oauth_id_token" json:"oauth_id_token"` OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"` } diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 8a25b0740183f..e13cc923f2a96 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -2413,7 +2413,7 @@ func (q *sqlQuerier) UpdateTemplateVersionDescriptionByJobID(ctx context.Context const getUserLinkByLinkedID = `-- name: GetUserLinkByLinkedID :one SELECT - user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_id_token, oauth_expiry + user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry FROM user_links WHERE @@ -2429,7 +2429,6 @@ func (q *sqlQuerier) GetUserLinkByLinkedID(ctx context.Context, linkedID string) &i.LinkedID, &i.OAuthAccessToken, &i.OAuthRefreshToken, - &i.OAuthIDToken, &i.OAuthExpiry, ) return i, err @@ -2437,7 +2436,7 @@ func (q *sqlQuerier) GetUserLinkByLinkedID(ctx context.Context, linkedID string) const getUserLinkByUserIDLoginType = `-- name: GetUserLinkByUserIDLoginType :one SELECT - user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_id_token, oauth_expiry + user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry FROM user_links WHERE @@ -2458,7 +2457,6 @@ func (q *sqlQuerier) GetUserLinkByUserIDLoginType(ctx context.Context, arg GetUs &i.LinkedID, &i.OAuthAccessToken, &i.OAuthRefreshToken, - &i.OAuthIDToken, &i.OAuthExpiry, ) return i, err @@ -2472,11 +2470,10 @@ INSERT INTO linked_id, oauth_access_token, oauth_refresh_token, - oauth_id_token, oauth_expiry ) VALUES - ( $1, $2, $3, $4, $5, $6, $7 ) RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_id_token, oauth_expiry + ( $1, $2, $3, $4, $5, $6 ) RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry ` type InsertUserLinkParams struct { @@ -2485,7 +2482,6 @@ type InsertUserLinkParams struct { LinkedID string `db:"linked_id" json:"linked_id"` OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"` OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"` - OAuthIDToken string `db:"oauth_id_token" json:"oauth_id_token"` OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"` } @@ -2496,7 +2492,6 @@ func (q *sqlQuerier) InsertUserLink(ctx context.Context, arg InsertUserLinkParam arg.LinkedID, arg.OAuthAccessToken, arg.OAuthRefreshToken, - arg.OAuthIDToken, arg.OAuthExpiry, ) var i UserLink @@ -2506,7 +2501,6 @@ func (q *sqlQuerier) InsertUserLink(ctx context.Context, arg InsertUserLinkParam &i.LinkedID, &i.OAuthAccessToken, &i.OAuthRefreshToken, - &i.OAuthIDToken, &i.OAuthExpiry, ) return i, err @@ -2518,16 +2512,14 @@ UPDATE SET oauth_access_token = $1, oauth_refresh_token = $2, - oauth_id_token = $3, - oauth_expiry = $4 + oauth_expiry = $3 WHERE - user_id = $5 AND login_type = $6 RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_id_token, oauth_expiry + user_id = $4 AND login_type = $5 RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry ` type UpdateUserLinkParams struct { OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"` OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"` - OAuthIDToken string `db:"oauth_id_token" json:"oauth_id_token"` OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"` UserID uuid.UUID `db:"user_id" json:"user_id"` LoginType LoginType `db:"login_type" json:"login_type"` @@ -2537,7 +2529,6 @@ func (q *sqlQuerier) UpdateUserLink(ctx context.Context, arg UpdateUserLinkParam row := q.db.QueryRowContext(ctx, updateUserLink, arg.OAuthAccessToken, arg.OAuthRefreshToken, - arg.OAuthIDToken, arg.OAuthExpiry, arg.UserID, arg.LoginType, @@ -2549,7 +2540,6 @@ func (q *sqlQuerier) UpdateUserLink(ctx context.Context, arg UpdateUserLinkParam &i.LinkedID, &i.OAuthAccessToken, &i.OAuthRefreshToken, - &i.OAuthIDToken, &i.OAuthExpiry, ) return i, err @@ -2561,7 +2551,7 @@ UPDATE SET linked_id = $1 WHERE - user_id = $2 AND login_type = $3 RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_id_token, oauth_expiry + user_id = $2 AND login_type = $3 RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry ` type UpdateUserLinkedIDParams struct { @@ -2579,7 +2569,6 @@ func (q *sqlQuerier) UpdateUserLinkedID(ctx context.Context, arg UpdateUserLinke &i.LinkedID, &i.OAuthAccessToken, &i.OAuthRefreshToken, - &i.OAuthIDToken, &i.OAuthExpiry, ) return i, err diff --git a/coderd/httpmw/apikey.go b/coderd/httpmw/apikey.go index d6c9288e92071..ad2d8c359f338 100644 --- a/coderd/httpmw/apikey.go +++ b/coderd/httpmw/apikey.go @@ -254,7 +254,13 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool OAuthRefreshToken: link.OAuthRefreshToken, OAuthExpiry: link.OAuthExpiry, }) - + if err != nil { + write(http.StatusInternalServerError, codersdk.Response{ + Message: internalErrorMessage, + Detail: fmt.Sprintf("update user_link: %s.", err.Error()), + }) + return + } } } diff --git a/coderd/httpmw/authorize_test.go b/coderd/httpmw/authorize_test.go index 8ce23ca4ed946..997ac44350340 100644 --- a/coderd/httpmw/authorize_test.go +++ b/coderd/httpmw/authorize_test.go @@ -115,8 +115,6 @@ func addUser(t *testing.T, db database.Store, roles ...string) (database.User, s Email: "admin@email.com", Username: "admin", RBACRoles: roles, - LinkedID: uuid.NewString(), - LoginType: database.LoginTypePassword, }) require.NoError(t, err) diff --git a/coderd/httpmw/organizationparam_test.go b/coderd/httpmw/organizationparam_test.go index 7562eb1e0f896..d17c441741914 100644 --- a/coderd/httpmw/organizationparam_test.go +++ b/coderd/httpmw/organizationparam_test.go @@ -45,8 +45,6 @@ func TestOrganizationParam(t *testing.T) { Username: username, CreatedAt: database.Now(), UpdatedAt: database.Now(), - LinkedID: uuid.NewString(), - LoginType: database.LoginTypePassword, }) require.NoError(t, err) _, err = db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ diff --git a/coderd/httpmw/templateparam_test.go b/coderd/httpmw/templateparam_test.go index f3f574659e69a..94abfe82cf5fb 100644 --- a/coderd/httpmw/templateparam_test.go +++ b/coderd/httpmw/templateparam_test.go @@ -44,8 +44,6 @@ func TestTemplateParam(t *testing.T) { Username: username, CreatedAt: database.Now(), UpdatedAt: database.Now(), - LinkedID: uuid.NewString(), - LoginType: database.LoginTypePassword, }) require.NoError(t, err) diff --git a/coderd/httpmw/templateversionparam_test.go b/coderd/httpmw/templateversionparam_test.go index c638be5172b31..5b49f75010bf9 100644 --- a/coderd/httpmw/templateversionparam_test.go +++ b/coderd/httpmw/templateversionparam_test.go @@ -44,8 +44,6 @@ func TestTemplateVersionParam(t *testing.T) { Username: username, CreatedAt: database.Now(), UpdatedAt: database.Now(), - LinkedID: uuid.NewString(), - LoginType: database.LoginTypePassword, }) require.NoError(t, err) diff --git a/coderd/httpmw/userparam_test.go b/coderd/httpmw/userparam_test.go index 3ebedd400a5ee..866df68ef1eec 100644 --- a/coderd/httpmw/userparam_test.go +++ b/coderd/httpmw/userparam_test.go @@ -35,11 +35,9 @@ func TestUserParam(t *testing.T) { }) user, err := db.InsertUser(r.Context(), database.InsertUserParams{ - ID: uuid.New(), - Email: "admin@email.com", - Username: "admin", - LinkedID: uuid.NewString(), - LoginType: database.LoginTypePassword, + ID: uuid.New(), + Email: "admin@email.com", + Username: "admin", }) require.NoError(t, err) diff --git a/coderd/httpmw/workspaceagentparam_test.go b/coderd/httpmw/workspaceagentparam_test.go index 1e19e67c46fc2..a2afaee534c9f 100644 --- a/coderd/httpmw/workspaceagentparam_test.go +++ b/coderd/httpmw/workspaceagentparam_test.go @@ -44,8 +44,6 @@ func TestWorkspaceAgentParam(t *testing.T) { Username: username, CreatedAt: database.Now(), UpdatedAt: database.Now(), - LinkedID: uuid.NewString(), - LoginType: database.LoginTypePassword, }) require.NoError(t, err) diff --git a/coderd/httpmw/workspacebuildparam_test.go b/coderd/httpmw/workspacebuildparam_test.go index 4fe606533bd0f..6d402f01fc62b 100644 --- a/coderd/httpmw/workspacebuildparam_test.go +++ b/coderd/httpmw/workspacebuildparam_test.go @@ -44,8 +44,6 @@ func TestWorkspaceBuildParam(t *testing.T) { Username: username, CreatedAt: database.Now(), UpdatedAt: database.Now(), - LinkedID: uuid.NewString(), - LoginType: database.LoginTypePassword, }) require.NoError(t, err) diff --git a/coderd/httpmw/workspaceparam_test.go b/coderd/httpmw/workspaceparam_test.go index 633275a645247..eac847a584f3b 100644 --- a/coderd/httpmw/workspaceparam_test.go +++ b/coderd/httpmw/workspaceparam_test.go @@ -44,8 +44,6 @@ func TestWorkspaceParam(t *testing.T) { Username: username, CreatedAt: database.Now(), UpdatedAt: database.Now(), - LinkedID: uuid.NewString(), - LoginType: database.LoginTypePassword, }) require.NoError(t, err) diff --git a/coderd/provisionerjobs_internal_test.go b/coderd/provisionerjobs_internal_test.go index bf2d996aba2d0..4d215f6bb2a92 100644 --- a/coderd/provisionerjobs_internal_test.go +++ b/coderd/provisionerjobs_internal_test.go @@ -78,8 +78,6 @@ func TestProvisionerJobLogs_Unit(t *testing.T) { _, err = fDB.InsertUser(ctx, database.InsertUserParams{ ID: userID, RBACRoles: []string{"admin"}, - LinkedID: uuid.NewString(), - LoginType: database.LoginTypePassword, }) require.NoError(t, err) _, err = fDB.InsertWorkspaceBuild(ctx, database.InsertWorkspaceBuildParams{ diff --git a/coderd/telemetry/telemetry_test.go b/coderd/telemetry/telemetry_test.go index 25fe91d348a6b..4e78b19ec8c54 100644 --- a/coderd/telemetry/telemetry_test.go +++ b/coderd/telemetry/telemetry_test.go @@ -59,8 +59,6 @@ func TestTelemetry(t *testing.T) { _, err = db.InsertUser(ctx, database.InsertUserParams{ ID: uuid.New(), CreatedAt: database.Now(), - LinkedID: uuid.NewString(), - LoginType: database.LoginTypePassword, }) require.NoError(t, err) _, err = db.InsertWorkspace(ctx, database.InsertWorkspaceParams{ @@ -107,8 +105,6 @@ func TestTelemetry(t *testing.T) { ID: uuid.New(), Email: "kyle@coder.com", CreatedAt: database.Now(), - LinkedID: uuid.NewString(), - LoginType: database.LoginTypePassword, }) require.NoError(t, err) snapshot := collectSnapshot(t, db) From d940daedb37e4b26fc08c176cfe421329a07d031 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Fri, 12 Aug 2022 04:16:43 +0000 Subject: [PATCH 17/32] finish up some test fixing --- coderd/audit/table.go | 2 - coderd/database/databasefake/databasefake.go | 15 ++--- .../migrations/000034_linked_user_id.up.sql | 10 ++++ coderd/database/postgres/postgres.go | 25 ++++---- coderd/httpmw/apikey.go | 3 +- coderd/userauth.go | 18 ++++-- coderd/userauth_test.go | 59 +++++++++---------- codersdk/users.go | 2 - site/src/api/typesGenerated.ts | 2 - 9 files changed, 71 insertions(+), 65 deletions(-) diff --git a/coderd/audit/table.go b/coderd/audit/table.go index 8d3d47ca62060..c842956e6cf24 100644 --- a/coderd/audit/table.go +++ b/coderd/audit/table.go @@ -94,8 +94,6 @@ var AuditableResources = auditMap(map[any]map[string]Action{ "updated_at": ActionIgnore, // Changes, but is implicit and not helpful in a diff. "status": ActionTrack, "rbac_roles": ActionTrack, - "login_type": ActionTrack, - "linked_id": ActionTrack, }, &database.Workspace{}: { "id": ActionTrack, diff --git a/coderd/database/databasefake/databasefake.go b/coderd/database/databasefake/databasefake.go index d978702d237c6..233fb63c59fa9 100644 --- a/coderd/database/databasefake/databasefake.go +++ b/coderd/database/databasefake/databasefake.go @@ -2278,17 +2278,18 @@ func (q *fakeQuerier) GetUserLinkByUserIDLoginType(_ context.Context, params dat return database.UserLink{}, sql.ErrNoRows } -func (q *fakeQuerier) InsertUserLink(_ context.Context, params database.InsertUserLinkParams) (database.UserLink, error) { +func (q *fakeQuerier) InsertUserLink(_ context.Context, args database.InsertUserLinkParams) (database.UserLink, error) { q.mutex.RLock() defer q.mutex.RUnlock() + //nolint:gosimple link := database.UserLink{ - UserID: params.UserID, - LoginType: params.LoginType, - LinkedID: params.LinkedID, - OAuthAccessToken: params.OAuthAccessToken, - OAuthRefreshToken: params.OAuthRefreshToken, - OAuthExpiry: params.OAuthExpiry, + UserID: args.UserID, + LoginType: args.LoginType, + LinkedID: args.LinkedID, + OAuthAccessToken: args.OAuthAccessToken, + OAuthRefreshToken: args.OAuthRefreshToken, + OAuthExpiry: args.OAuthExpiry, } q.userLinks = append(q.userLinks, link) diff --git a/coderd/database/migrations/000034_linked_user_id.up.sql b/coderd/database/migrations/000034_linked_user_id.up.sql index e9e1a79a5f12a..a64f2d0c84258 100644 --- a/coderd/database/migrations/000034_linked_user_id.up.sql +++ b/coderd/database/migrations/000034_linked_user_id.up.sql @@ -10,6 +10,13 @@ CREATE TABLE IF NOT EXISTS user_links ( UNIQUE(user_id, login_type) ); +-- This migrates columns on api_keys to the new user_links table. +-- It does this by finding all the API keys for each user, choosing +-- the most recently updated for each one and then assigning its relevant +-- values to the user_links table. +-- A user should at most have a row for an OIDC account and a Github account. +-- 'password' login types are ignored. + INSERT INTO user_links ( user_id, @@ -34,6 +41,9 @@ FROM ) as keys WHERE x=1 AND keys.login_type != 'password'; +-- Drop columns that have been migrated to user_links. +-- It appears the 'oauth_id_token' was unused and so it has +-- been dropped here as well to avoid future confusion. ALTER TABLE api_keys DROP COLUMN oauth_access_token, DROP COLUMN oauth_refresh_token, diff --git a/coderd/database/postgres/postgres.go b/coderd/database/postgres/postgres.go index f9ee2cd84a2c5..d1ef7b3084197 100644 --- a/coderd/database/postgres/postgres.go +++ b/coderd/database/postgres/postgres.go @@ -122,30 +122,27 @@ func Open() (string, func(), error) { return "", nil, xerrors.Errorf("expire resource: %w", err) } - pool.MaxWait = 15 * time.Second - var retryErr error + pool.MaxWait = 120 * time.Second err = pool.Retry(func() error { - var db *sql.DB - db, retryErr := sql.Open("postgres", dbURL) - if retryErr != nil { - return xerrors.Errorf("open postgres: %w", retryErr) + db, err := sql.Open("postgres", dbURL) + if err != nil { + return xerrors.Errorf("open postgres: %w", err) } defer db.Close() - retryErr = db.Ping() - if retryErr != nil { - return xerrors.Errorf("ping postgres: %w", retryErr) + err = db.Ping() + if err != nil { + return xerrors.Errorf("ping postgres: %w", err) } - retryErr = database.MigrateUp(db) - if retryErr != nil { - fmt.Printf("err: %v\n", retryErr) - return xerrors.Errorf("migrate db: %w", retryErr) + err = database.MigrateUp(db) + if err != nil { + return xerrors.Errorf("migrate db: %w", err) } return nil }) if err != nil { - return "", nil, retryErr + return "", nil, err } return dbURL, func() { _ = pool.Purge(resource) diff --git a/coderd/httpmw/apikey.go b/coderd/httpmw/apikey.go index ad2d8c359f338..498e48f05c1eb 100644 --- a/coderd/httpmw/apikey.go +++ b/coderd/httpmw/apikey.go @@ -159,7 +159,7 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool if err != nil { write(http.StatusInternalServerError, codersdk.Response{ Message: "A database error occurred", - Detail: err.Error(), + Detail: fmt.Sprintf("get user link by user ID and login type: %s", err.Error()), }) return } @@ -250,6 +250,7 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool if link.UserID != uuid.Nil { link, err = db.UpdateUserLink(r.Context(), database.UpdateUserLinkParams{ UserID: link.UserID, + LoginType: link.LoginType, OAuthAccessToken: link.OAuthAccessToken, OAuthRefreshToken: link.OAuthRefreshToken, OAuthExpiry: link.OAuthExpiry, diff --git a/coderd/userauth.go b/coderd/userauth.go index c402a42ed7944..93fe54b23aec4 100644 --- a/coderd/userauth.go +++ b/coderd/userauth.go @@ -140,7 +140,8 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) { user, link, err := findLinkedUser(ctx, api.Database, githubLinkedID(ghUser), verifiedEmails...) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Failed to find user.", + Message: "An internal error occurred.", + Detail: err.Error(), }) return } @@ -195,7 +196,6 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) { }) return } - } // This can happen if a user is a built-in user but is signing in @@ -220,7 +220,9 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) { // LEGACY: Remove 10/2022. // We started tracking linked IDs later so it's possible for a user to be a - // pre-existing Github user and not have a linked ID. + // pre-existing Github user and not have a linked ID. The migration + // to user_links did not populate this field as it requires calling out + // to Github to query the user's ID. if link.LinkedID == "" { link, err = api.Database.UpdateUserLinkedID(ctx, database.UpdateUserLinkedIDParams{ UserID: user.ID, @@ -387,7 +389,7 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) { if link.UserID == uuid.Nil { link, err = api.Database.InsertUserLink(ctx, database.InsertUserLinkParams{ UserID: user.ID, - LoginType: database.LoginTypeGithub, + LoginType: database.LoginTypeOIDC, LinkedID: oidcLinkedID(idToken), OAuthAccessToken: state.Token.AccessToken, OAuthRefreshToken: state.Token.RefreshToken, @@ -400,12 +402,13 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) { }) return } - } // LEGACY: Remove 10/2022. // We started tracking linked IDs later so it's possible for a user to be a // pre-existing OIDC user and not have a linked ID. + // The migration that added the user_links table could not populate + // the 'linked_id' field since it requires fields off the access token. if link.LinkedID == "" { link, err = api.Database.UpdateUserLinkedID(ctx, database.UpdateUserLinkedIDParams{ UserID: user.ID, @@ -440,7 +443,7 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) { if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ Message: "Failed to update user profile.", - Detail: err.Error(), + Detail: fmt.Sprintf("udpate user profile: %s", err.Error()), }) return } @@ -486,6 +489,9 @@ func findLinkedUser(ctx context.Context, db database.Store, linkedID string, ema if err == nil { user, err = db.GetUserByID(ctx, link.UserID) + if err != nil { + return database.User{}, database.UserLink{}, xerrors.Errorf("get user by id: %w", err) + } return user, link, nil } diff --git a/coderd/userauth_test.go b/coderd/userauth_test.go index f3def2084804a..9db68b4138e18 100644 --- a/coderd/userauth_test.go +++ b/coderd/userauth_test.go @@ -175,6 +175,34 @@ func TestUserOAuth2Github(t *testing.T) { resp := oauth2Callback(t, client) require.Equal(t, http.StatusForbidden, resp.StatusCode) }) + t.Run("Login", func(t *testing.T) { + t.Parallel() + client := coderdtest.New(t, &coderdtest.Options{ + GithubOAuth2Config: &coderd.GithubOAuth2Config{ + OAuth2Config: &oauth2Config{}, + AllowOrganizations: []string{"coder"}, + ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) { + return []*github.Membership{{ + Organization: &github.Organization{ + Login: github.String("coder"), + }, + }}, nil + }, + AuthenticatedUser: func(ctx context.Context, client *http.Client) (*github.User, error) { + return &github.User{}, nil + }, + ListEmails: func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error) { + return []*github.UserEmail{{ + Email: github.String("testuser@coder.com"), + Verified: github.Bool(true), + }}, nil + }, + }, + }) + _ = coderdtest.CreateFirstUser(t, client) + resp := oauth2Callback(t, client) + require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) + }) t.Run("Signup", func(t *testing.T) { t.Parallel() client := coderdtest.New(t, &coderdtest.Options{ @@ -210,7 +238,6 @@ func TestUserOAuth2Github(t *testing.T) { client.SessionToken = resp.Cookies()[0].Value user, err := client.User(context.Background(), "me") require.NoError(t, err) - require.Equal(t, "1234", user.LinkedID) require.Equal(t, "kyle@coder.com", user.Email) require.Equal(t, "kyle", user.Username) }) @@ -341,7 +368,6 @@ func TestUserOIDC(t *testing.T) { user, err := client.User(ctx, "me") require.NoError(t, err) require.Equal(t, tc.Username, user.Username) - require.Equal(t, "https://coder.com||hello", user.LinkedID) } }) } @@ -385,35 +411,6 @@ func TestUserOIDC(t *testing.T) { resp := oidcCallback(t, client) require.Equal(t, http.StatusBadRequest, resp.StatusCode) }) - - // Test that we do not allow collisions with pre-existing accounts - // of differing login types. - t.Run("InvalidLoginType", func(t *testing.T) { - t.Parallel() - config := createOIDCConfig(t, jwt.MapClaims{ - "email": "kyle@kwc.io", - "email_verified": true, - "preferred_username": "kyle", - }) - - client := coderdtest.New(t, &coderdtest.Options{ - OIDCConfig: config, - }) - - _, err := client.CreateFirstUser(context.Background(), codersdk.CreateFirstUserRequest{ - Email: "kyle@kwc.io", - Username: "kyle", - Password: "yeah", - OrganizationName: "default", - }) - require.NoError(t, err) - - config.AllowSignups = true - config.EmailDomain = "kwc.io" - - resp := oidcCallback(t, client) - assert.Equal(t, http.StatusConflict, resp.StatusCode) - }) } // createOIDCConfig generates a new OIDCConfig that returns a static token diff --git a/codersdk/users.go b/codersdk/users.go index e0d9ead70381f..72b51306a5bb8 100644 --- a/codersdk/users.go +++ b/codersdk/users.go @@ -50,8 +50,6 @@ type User struct { Status UserStatus `json:"status"` OrganizationIDs []uuid.UUID `json:"organization_ids"` Roles []Role `json:"roles"` - LoginType LoginType `json:"login_type"` - LinkedID string `json:"linked_id"` } type APIKey struct { diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index 3adcefb0e45db..5ece107be5af8 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -352,8 +352,6 @@ export interface User { readonly status: UserStatus readonly organization_ids: string[] readonly roles: Role[] - readonly login_type: LoginType - readonly linked_id: string } // From codersdk/users.go From c97d57206adc4b4733c47ecf5a0f08343dac7c13 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Fri, 12 Aug 2022 04:19:58 +0000 Subject: [PATCH 18/32] typos --- coderd/userauth.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/coderd/userauth.go b/coderd/userauth.go index 93fe54b23aec4..3cb3d94776385 100644 --- a/coderd/userauth.go +++ b/coderd/userauth.go @@ -443,7 +443,7 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) { if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ Message: "Failed to update user profile.", - Detail: fmt.Sprintf("udpate user profile: %s", err.Error()), + Detail: fmt.Sprintf("update user profile: %s", err.Error()), }) return } From 28a37f1604d33783e964e371a38ec6fcc84bf8a5 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Fri, 12 Aug 2022 04:38:41 +0000 Subject: [PATCH 19/32] fix some remaining tests --- coderd/httpmw/apikey.go | 4 +++- coderd/httpmw/apikey_test.go | 14 ++++++++++++++ coderd/httpmw/userparam_test.go | 1 + coderd/httpmw/workspacebuildparam_test.go | 1 + coderd/provisionerjobs_internal_test.go | 1 + 5 files changed, 20 insertions(+), 1 deletion(-) diff --git a/coderd/httpmw/apikey.go b/coderd/httpmw/apikey.go index 498e48f05c1eb..7e2c700dbf8fc 100644 --- a/coderd/httpmw/apikey.go +++ b/coderd/httpmw/apikey.go @@ -151,7 +151,9 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool changed := false var link database.UserLink - if key.LoginType != database.LoginTypePassword { + // The login_type should never be empty but sometimes it is + // for tests. + if key.LoginType != "" && key.LoginType != database.LoginTypePassword { link, err = db.GetUserLinkByUserIDLoginType(r.Context(), database.GetUserLinkByUserIDLoginTypeParams{ UserID: key.UserID, LoginType: key.LoginType, diff --git a/coderd/httpmw/apikey_test.go b/coderd/httpmw/apikey_test.go index 28343b7f9bded..64d0d5e198e88 100644 --- a/coderd/httpmw/apikey_test.go +++ b/coderd/httpmw/apikey_test.go @@ -361,6 +361,13 @@ func TestAPIKey(t *testing.T) { UserID: user.ID, }) require.NoError(t, err) + + _, err = db.InsertUserLink(r.Context(), database.InsertUserLinkParams{ + UserID: user.ID, + LoginType: database.LoginTypeGithub, + }) + require.NoError(t, err) + httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r) res := rw.Result() defer res.Body.Close() @@ -396,6 +403,13 @@ func TestAPIKey(t *testing.T) { UserID: user.ID, }) require.NoError(t, err) + _, err = db.InsertUserLink(r.Context(), database.InsertUserLinkParams{ + UserID: user.ID, + LoginType: database.LoginTypeGithub, + OAuthExpiry: database.Now().AddDate(0, 0, -1), + }) + require.NoError(t, err) + token := &oauth2.Token{ AccessToken: "wow", RefreshToken: "moo", diff --git a/coderd/httpmw/userparam_test.go b/coderd/httpmw/userparam_test.go index 866df68ef1eec..d21f2a583d58f 100644 --- a/coderd/httpmw/userparam_test.go +++ b/coderd/httpmw/userparam_test.go @@ -47,6 +47,7 @@ func TestUserParam(t *testing.T) { HashedSecret: hashed[:], LastUsed: database.Now(), ExpiresAt: database.Now().Add(time.Minute), + LoginType: database.LoginTypePassword, }) require.NoError(t, err) diff --git a/coderd/httpmw/workspacebuildparam_test.go b/coderd/httpmw/workspacebuildparam_test.go index 6d402f01fc62b..c993a0cb66fb6 100644 --- a/coderd/httpmw/workspacebuildparam_test.go +++ b/coderd/httpmw/workspacebuildparam_test.go @@ -53,6 +53,7 @@ func TestWorkspaceBuildParam(t *testing.T) { HashedSecret: hashed[:], LastUsed: database.Now(), ExpiresAt: database.Now().Add(time.Minute), + LoginType: database.LoginTypePassword, }) require.NoError(t, err) diff --git a/coderd/provisionerjobs_internal_test.go b/coderd/provisionerjobs_internal_test.go index 4d215f6bb2a92..4d7bcb42a45da 100644 --- a/coderd/provisionerjobs_internal_test.go +++ b/coderd/provisionerjobs_internal_test.go @@ -73,6 +73,7 @@ func TestProvisionerJobLogs_Unit(t *testing.T) { HashedSecret: hashed[:], UserID: userID, ExpiresAt: time.Now().Add(5 * time.Hour), + LoginType: database.LoginTypePassword, }) require.NoError(t, err) _, err = fDB.InsertUser(ctx, database.InsertUserParams{ From c889bf0046f4b6b593ee5124439b1402fba86957 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Fri, 12 Aug 2022 04:47:50 +0000 Subject: [PATCH 20/32] fix a gnarly bug --- coderd/userauth.go | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/coderd/userauth.go b/coderd/userauth.go index 3cb3d94776385..e017a5ce38386 100644 --- a/coderd/userauth.go +++ b/coderd/userauth.go @@ -137,7 +137,7 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) { return } - user, link, err := findLinkedUser(ctx, api.Database, githubLinkedID(ghUser), verifiedEmails...) + user, link, err := findLinkedUser(ctx, api.Database, database.LoginTypeGithub, githubLinkedID(ghUser), verifiedEmails...) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ Message: "An internal error occurred.", @@ -332,7 +332,7 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) { } } - user, link, err := findLinkedUser(ctx, api.Database, oidcLinkedID(idToken), claims.Email) + user, link, err := findLinkedUser(ctx, api.Database, database.LoginTypeOIDC, oidcLinkedID(idToken), claims.Email) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ Message: "Failed to find user.", @@ -477,7 +477,7 @@ func oidcLinkedID(tok *oidc.IDToken) string { // findLinkedUser tries to find a user by their unique OAuth-linked ID. // If it doesn't not find it, it returns the user by their email. -func findLinkedUser(ctx context.Context, db database.Store, linkedID string, emails ...string) (database.User, database.UserLink, error) { +func findLinkedUser(ctx context.Context, db database.Store, loginType database.LoginType, linkedID string, emails ...string) (database.User, database.UserLink, error) { var ( user database.User link database.UserLink @@ -505,9 +505,24 @@ func findLinkedUser(ctx context.Context, db database.Store, linkedID string, ema if errors.Is(err, sql.ErrNoRows) { continue } - return user, link, nil + break + } + + if user.ID == uuid.Nil { + // No user found. + return database.User{}, database.UserLink{}, nil + } + + // LEGACY: This is annoying but we have to search for the user_link + // again except this time we search by user_id and login_type. It's + // possible that a user_link exists without a populated 'linked_id'. + link, err = db.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{ + UserID: user.ID, + LoginType: loginType, + }) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return database.User{}, database.UserLink{}, xerrors.Errorf("get user link by user id and login type: %w", err) } - // No user found. return user, link, nil } From 0196a496881f81cbd14416a64e08cc5f250ee3b4 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Fri, 12 Aug 2022 04:56:04 +0000 Subject: [PATCH 21/32] add a down migration --- .../migrations/000034_linked_user_id.down.sql | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/coderd/database/migrations/000034_linked_user_id.down.sql b/coderd/database/migrations/000034_linked_user_id.down.sql index c4c4836d78fcb..a1f4e9112e4fb 100644 --- a/coderd/database/migrations/000034_linked_user_id.down.sql +++ b/coderd/database/migrations/000034_linked_user_id.down.sql @@ -1,3 +1,22 @@ +-- This migration makes no attempt to try to populate +-- the oauth_access_token, oauth_refresh_token, and oauth_expiry +-- columns of api_key rows with the values from the dropped user_links +-- table. BEGIN; +DROP TABLE IF EXISTS user_links; + +ALTER TABLE + api_keys +ADD COLUMN oauth_access_token text DEFAULT ''::text NOT NULL; + +ALTER TABLE + api_keys +ADD COLUMN oauth_refresh_token text DEFAULT ''::text NOT NULL; + +ALTER TABLE + api_keys +ADD COLUMN oauth_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL; + + COMMIT; From b5dc95b57cc5e5812c5a9e08bd26c64086c76aa1 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Fri, 12 Aug 2022 19:49:38 +0000 Subject: [PATCH 22/32] add fkey on user_links, fix tests, add comments --- coderd/database/dump.sql | 3 +++ coderd/database/migrations/000034_linked_user_id.up.sql | 3 ++- coderd/database/postgres/postgres.go | 1 + coderd/httpmw/apikey.go | 6 +++--- coderd/httpmw/apikey_test.go | 6 ++++++ coderd/httpmw/authorize_test.go | 1 + coderd/httpmw/organizationparam_test.go | 1 + coderd/httpmw/templateparam_test.go | 1 + coderd/httpmw/templateversionparam_test.go | 1 + coderd/httpmw/workspaceagentparam_test.go | 1 + coderd/httpmw/workspaceparam_test.go | 1 + 11 files changed, 21 insertions(+), 4 deletions(-) diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 0f095446661fe..cc003937fa139 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -521,6 +521,9 @@ ALTER TABLE ONLY templates ALTER TABLE ONLY templates ADD CONSTRAINT templates_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE; +ALTER TABLE ONLY user_links + ADD CONSTRAINT user_links_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; + ALTER TABLE ONLY workspace_agents ADD CONSTRAINT workspace_agents_resource_id_fkey FOREIGN KEY (resource_id) REFERENCES workspace_resources(id) ON DELETE CASCADE; diff --git a/coderd/database/migrations/000034_linked_user_id.up.sql b/coderd/database/migrations/000034_linked_user_id.up.sql index a64f2d0c84258..5aa110a306b70 100644 --- a/coderd/database/migrations/000034_linked_user_id.up.sql +++ b/coderd/database/migrations/000034_linked_user_id.up.sql @@ -7,7 +7,8 @@ CREATE TABLE IF NOT EXISTS user_links ( oauth_access_token text DEFAULT ''::text NOT NULL, oauth_refresh_token text DEFAULT ''::text NOT NULL, oauth_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL, - UNIQUE(user_id, login_type) + UNIQUE(user_id, login_type), + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE ); -- This migrates columns on api_keys to the new user_links table. diff --git a/coderd/database/postgres/postgres.go b/coderd/database/postgres/postgres.go index d1ef7b3084197..d401c2ecc6341 100644 --- a/coderd/database/postgres/postgres.go +++ b/coderd/database/postgres/postgres.go @@ -136,6 +136,7 @@ func Open() (string, func(), error) { } err = database.MigrateUp(db) if err != nil { + fmt.Printf("err: %v\n", err) return xerrors.Errorf("migrate db: %w", err) } diff --git a/coderd/httpmw/apikey.go b/coderd/httpmw/apikey.go index 7e2c700dbf8fc..ed2a1e617f567 100644 --- a/coderd/httpmw/apikey.go +++ b/coderd/httpmw/apikey.go @@ -151,9 +151,7 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool changed := false var link database.UserLink - // The login_type should never be empty but sometimes it is - // for tests. - if key.LoginType != "" && key.LoginType != database.LoginTypePassword { + if key.LoginType != database.LoginTypePassword { link, err = db.GetUserLinkByUserIDLoginType(r.Context(), database.GetUserLinkByUserIDLoginTypeParams{ UserID: key.UserID, LoginType: key.LoginType, @@ -249,6 +247,8 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool }) return } + // If the API Key is associated with a user_link (e.g. Github/OIDC) + // then we want to update the relevant oauth fields. if link.UserID != uuid.Nil { link, err = db.UpdateUserLink(r.Context(), database.UpdateUserLinkParams{ UserID: link.UserID, diff --git a/coderd/httpmw/apikey_test.go b/coderd/httpmw/apikey_test.go index 64d0d5e198e88..adc13b2f176fa 100644 --- a/coderd/httpmw/apikey_test.go +++ b/coderd/httpmw/apikey_test.go @@ -187,6 +187,7 @@ func TestAPIKey(t *testing.T) { ID: id, HashedSecret: hashed[:], UserID: user.ID, + LoginType: database.LoginTypePassword, }) require.NoError(t, err) httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r) @@ -215,6 +216,7 @@ func TestAPIKey(t *testing.T) { HashedSecret: hashed[:], ExpiresAt: database.Now().AddDate(0, 0, 1), UserID: user.ID, + LoginType: database.LoginTypePassword, }) require.NoError(t, err) httpmw.ExtractAPIKey(db, nil, false)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { @@ -253,6 +255,7 @@ func TestAPIKey(t *testing.T) { HashedSecret: hashed[:], ExpiresAt: database.Now().AddDate(0, 0, 1), UserID: user.ID, + LoginType: database.LoginTypePassword, }) require.NoError(t, err) httpmw.ExtractAPIKey(db, nil, false)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { @@ -288,6 +291,7 @@ func TestAPIKey(t *testing.T) { LastUsed: database.Now().AddDate(0, 0, -1), ExpiresAt: database.Now().AddDate(0, 0, 1), UserID: user.ID, + LoginType: database.LoginTypePassword, }) require.NoError(t, err) httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r) @@ -323,6 +327,7 @@ func TestAPIKey(t *testing.T) { LastUsed: database.Now(), ExpiresAt: database.Now().Add(time.Minute), UserID: user.ID, + LoginType: database.LoginTypePassword, }) require.NoError(t, err) httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r) @@ -455,6 +460,7 @@ func TestAPIKey(t *testing.T) { LastUsed: database.Now().AddDate(0, 0, -1), ExpiresAt: database.Now().AddDate(0, 0, 1), UserID: user.ID, + LoginType: database.LoginTypePassword, }) require.NoError(t, err) httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r) diff --git a/coderd/httpmw/authorize_test.go b/coderd/httpmw/authorize_test.go index 997ac44350340..8dc9c252634e9 100644 --- a/coderd/httpmw/authorize_test.go +++ b/coderd/httpmw/authorize_test.go @@ -124,6 +124,7 @@ func addUser(t *testing.T, db database.Store, roles ...string) (database.User, s HashedSecret: hashed[:], LastUsed: database.Now(), ExpiresAt: database.Now().Add(time.Minute), + LoginType: database.LoginTypePassword, }) require.NoError(t, err) diff --git a/coderd/httpmw/organizationparam_test.go b/coderd/httpmw/organizationparam_test.go index d17c441741914..bdf442391091c 100644 --- a/coderd/httpmw/organizationparam_test.go +++ b/coderd/httpmw/organizationparam_test.go @@ -53,6 +53,7 @@ func TestOrganizationParam(t *testing.T) { HashedSecret: hashed[:], LastUsed: database.Now(), ExpiresAt: database.Now().Add(time.Minute), + LoginType: database.LoginTypePassword, }) require.NoError(t, err) r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, chi.NewRouteContext())) diff --git a/coderd/httpmw/templateparam_test.go b/coderd/httpmw/templateparam_test.go index 94abfe82cf5fb..7da35889aa536 100644 --- a/coderd/httpmw/templateparam_test.go +++ b/coderd/httpmw/templateparam_test.go @@ -53,6 +53,7 @@ func TestTemplateParam(t *testing.T) { HashedSecret: hashed[:], LastUsed: database.Now(), ExpiresAt: database.Now().Add(time.Minute), + LoginType: database.LoginTypePassword, }) require.NoError(t, err) diff --git a/coderd/httpmw/templateversionparam_test.go b/coderd/httpmw/templateversionparam_test.go index 5b49f75010bf9..fe4ebba9dfcc3 100644 --- a/coderd/httpmw/templateversionparam_test.go +++ b/coderd/httpmw/templateversionparam_test.go @@ -53,6 +53,7 @@ func TestTemplateVersionParam(t *testing.T) { HashedSecret: hashed[:], LastUsed: database.Now(), ExpiresAt: database.Now().Add(time.Minute), + LoginType: database.LoginTypePassword, }) require.NoError(t, err) diff --git a/coderd/httpmw/workspaceagentparam_test.go b/coderd/httpmw/workspaceagentparam_test.go index a2afaee534c9f..c2d047fdda983 100644 --- a/coderd/httpmw/workspaceagentparam_test.go +++ b/coderd/httpmw/workspaceagentparam_test.go @@ -53,6 +53,7 @@ func TestWorkspaceAgentParam(t *testing.T) { HashedSecret: hashed[:], LastUsed: database.Now(), ExpiresAt: database.Now().Add(time.Minute), + LoginType: database.LoginTypePassword, }) require.NoError(t, err) diff --git a/coderd/httpmw/workspaceparam_test.go b/coderd/httpmw/workspaceparam_test.go index eac847a584f3b..ecbebf13c5a4c 100644 --- a/coderd/httpmw/workspaceparam_test.go +++ b/coderd/httpmw/workspaceparam_test.go @@ -53,6 +53,7 @@ func TestWorkspaceParam(t *testing.T) { HashedSecret: hashed[:], LastUsed: database.Now(), ExpiresAt: database.Now().Add(time.Minute), + LoginType: database.LoginTypePassword, }) require.NoError(t, err) From f2f76e97c30f9e8c4943097fa5818c37215f9fe8 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Fri, 12 Aug 2022 23:47:19 +0000 Subject: [PATCH 23/32] add login_type to users table --- coderd/database/dump.sql | 3 +- .../migrations/000034_linked_user_id.up.sql | 15 +++++++ coderd/database/models.go | 1 + coderd/database/postgres/postgres.go | 1 - coderd/database/queries.sql.go | 29 +++++++++---- coderd/database/queries/users.sql | 5 ++- coderd/userauth.go | 42 +++++++++++-------- coderd/userauth_test.go | 3 +- coderd/users.go | 31 +++++++++++--- 9 files changed, 92 insertions(+), 38 deletions(-) diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index cc003937fa139..7782646c94d07 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -280,7 +280,8 @@ CREATE TABLE users ( created_at timestamp with time zone NOT NULL, updated_at timestamp with time zone NOT NULL, status user_status DEFAULT 'active'::public.user_status NOT NULL, - rbac_roles text[] DEFAULT '{}'::text[] NOT NULL + rbac_roles text[] DEFAULT '{}'::text[] NOT NULL, + login_type login_type DEFAULT 'password'::public.login_type NOT NULL ); CREATE TABLE workspace_agents ( diff --git a/coderd/database/migrations/000034_linked_user_id.up.sql b/coderd/database/migrations/000034_linked_user_id.up.sql index 5aa110a306b70..fb1034da9f43d 100644 --- a/coderd/database/migrations/000034_linked_user_id.up.sql +++ b/coderd/database/migrations/000034_linked_user_id.up.sql @@ -51,4 +51,19 @@ ALTER TABLE api_keys DROP COLUMN oauth_id_token, DROP COLUMN oauth_expiry; +ALTER TABLE users ADD COLUMN login_type login_type NOT NULL DEFAULT 'password'; + +UPDATE + users +SET + login_type = ( + SELECT + login_type + FROM + user_links + WHERE + user_links.user_id = users.id + LIMIT 1 + ); + COMMIT; diff --git a/coderd/database/models.go b/coderd/database/models.go index fcf96ab197ca7..25ed8eedf3190 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -487,6 +487,7 @@ type User struct { UpdatedAt time.Time `db:"updated_at" json:"updated_at"` Status UserStatus `db:"status" json:"status"` RBACRoles []string `db:"rbac_roles" json:"rbac_roles"` + LoginType LoginType `db:"login_type" json:"login_type"` } type UserLink struct { diff --git a/coderd/database/postgres/postgres.go b/coderd/database/postgres/postgres.go index d401c2ecc6341..d1ef7b3084197 100644 --- a/coderd/database/postgres/postgres.go +++ b/coderd/database/postgres/postgres.go @@ -136,7 +136,6 @@ func Open() (string, func(), error) { } err = database.MigrateUp(db) if err != nil { - fmt.Printf("err: %v\n", err) return xerrors.Errorf("migrate db: %w", err) } diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index e13cc923f2a96..bdb8b9f05d4e8 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -2617,7 +2617,7 @@ func (q *sqlQuerier) GetAuthorizationUserRoles(ctx context.Context, userID uuid. const getUserByEmailOrUsername = `-- name: GetUserByEmailOrUsername :one SELECT - id, email, username, hashed_password, created_at, updated_at, status, rbac_roles + id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type FROM users WHERE @@ -2644,13 +2644,14 @@ func (q *sqlQuerier) GetUserByEmailOrUsername(ctx context.Context, arg GetUserBy &i.UpdatedAt, &i.Status, pq.Array(&i.RBACRoles), + &i.LoginType, ) return i, err } const getUserByID = `-- name: GetUserByID :one SELECT - id, email, username, hashed_password, created_at, updated_at, status, rbac_roles + id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type FROM users WHERE @@ -2671,6 +2672,7 @@ func (q *sqlQuerier) GetUserByID(ctx context.Context, id uuid.UUID) (User, error &i.UpdatedAt, &i.Status, pq.Array(&i.RBACRoles), + &i.LoginType, ) return i, err } @@ -2691,7 +2693,7 @@ func (q *sqlQuerier) GetUserCount(ctx context.Context) (int64, error) { const getUsers = `-- name: GetUsers :many SELECT - id, email, username, hashed_password, created_at, updated_at, status, rbac_roles + id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type FROM users WHERE @@ -2783,6 +2785,7 @@ func (q *sqlQuerier) GetUsers(ctx context.Context, arg GetUsersParams) ([]User, &i.UpdatedAt, &i.Status, pq.Array(&i.RBACRoles), + &i.LoginType, ); err != nil { return nil, err } @@ -2798,7 +2801,7 @@ func (q *sqlQuerier) GetUsers(ctx context.Context, arg GetUsersParams) ([]User, } const getUsersByIDs = `-- name: GetUsersByIDs :many -SELECT id, email, username, hashed_password, created_at, updated_at, status, rbac_roles FROM users WHERE id = ANY($1 :: uuid [ ]) +SELECT id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type FROM users WHERE id = ANY($1 :: uuid [ ]) ` func (q *sqlQuerier) GetUsersByIDs(ctx context.Context, ids []uuid.UUID) ([]User, error) { @@ -2819,6 +2822,7 @@ func (q *sqlQuerier) GetUsersByIDs(ctx context.Context, ids []uuid.UUID) ([]User &i.UpdatedAt, &i.Status, pq.Array(&i.RBACRoles), + &i.LoginType, ); err != nil { return nil, err } @@ -2842,10 +2846,11 @@ INSERT INTO hashed_password, created_at, updated_at, - rbac_roles + rbac_roles, + login_type ) VALUES - ($1, $2, $3, $4, $5, $6, $7) RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles + ($1, $2, $3, $4, $5, $6, $7, $8) RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type ` type InsertUserParams struct { @@ -2856,6 +2861,7 @@ type InsertUserParams struct { CreatedAt time.Time `db:"created_at" json:"created_at"` UpdatedAt time.Time `db:"updated_at" json:"updated_at"` RBACRoles []string `db:"rbac_roles" json:"rbac_roles"` + LoginType LoginType `db:"login_type" json:"login_type"` } func (q *sqlQuerier) InsertUser(ctx context.Context, arg InsertUserParams) (User, error) { @@ -2867,6 +2873,7 @@ func (q *sqlQuerier) InsertUser(ctx context.Context, arg InsertUserParams) (User arg.CreatedAt, arg.UpdatedAt, pq.Array(arg.RBACRoles), + arg.LoginType, ) var i User err := row.Scan( @@ -2878,6 +2885,7 @@ func (q *sqlQuerier) InsertUser(ctx context.Context, arg InsertUserParams) (User &i.UpdatedAt, &i.Status, pq.Array(&i.RBACRoles), + &i.LoginType, ) return i, err } @@ -2909,7 +2917,7 @@ SET username = $3, updated_at = $4 WHERE - id = $1 RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles + id = $1 RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type ` type UpdateUserProfileParams struct { @@ -2936,6 +2944,7 @@ func (q *sqlQuerier) UpdateUserProfile(ctx context.Context, arg UpdateUserProfil &i.UpdatedAt, &i.Status, pq.Array(&i.RBACRoles), + &i.LoginType, ) return i, err } @@ -2948,7 +2957,7 @@ SET rbac_roles = ARRAY(SELECT DISTINCT UNNEST($1 :: text[])) WHERE id = $2 -RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles +RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type ` type UpdateUserRolesParams struct { @@ -2968,6 +2977,7 @@ func (q *sqlQuerier) UpdateUserRoles(ctx context.Context, arg UpdateUserRolesPar &i.UpdatedAt, &i.Status, pq.Array(&i.RBACRoles), + &i.LoginType, ) return i, err } @@ -2979,7 +2989,7 @@ SET status = $2, updated_at = $3 WHERE - id = $1 RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles + id = $1 RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type ` type UpdateUserStatusParams struct { @@ -3000,6 +3010,7 @@ func (q *sqlQuerier) UpdateUserStatus(ctx context.Context, arg UpdateUserStatusP &i.UpdatedAt, &i.Status, pq.Array(&i.RBACRoles), + &i.LoginType, ) return i, err } diff --git a/coderd/database/queries/users.sql b/coderd/database/queries/users.sql index 19fe8a7701744..e5e6908a9ceb2 100644 --- a/coderd/database/queries/users.sql +++ b/coderd/database/queries/users.sql @@ -37,10 +37,11 @@ INSERT INTO hashed_password, created_at, updated_at, - rbac_roles + rbac_roles, + login_type ) VALUES - ($1, $2, $3, $4, $5, $6, $7) RETURNING *; + ($1, $2, $3, $4, $5, $6, $7, $8) RETURNING *; -- name: UpdateUserProfile :one UPDATE diff --git a/coderd/userauth.go b/coderd/userauth.go index e017a5ce38386..1fe739d6bfb05 100644 --- a/coderd/userauth.go +++ b/coderd/userauth.go @@ -137,7 +137,7 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) { return } - user, link, err := findLinkedUser(ctx, api.Database, database.LoginTypeGithub, githubLinkedID(ghUser), verifiedEmails...) + user, link, err := findLinkedUser(ctx, api.Database, githubLinkedID(ghUser), verifiedEmails...) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ Message: "An internal error occurred.", @@ -146,9 +146,9 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) { return } - if link.UserID != uuid.Nil && link.LoginType != database.LoginTypeGithub { - httpapi.Write(rw, http.StatusConflict, codersdk.Response{ - Message: fmt.Sprintf("Incorrect login type, attempting to use %q but user is of login type %q", database.LoginTypeOIDC, link.LoginType), + if user.ID != uuid.Nil && user.LoginType != database.LoginTypeGithub { + httpapi.Write(rw, http.StatusForbidden, codersdk.Response{ + Message: fmt.Sprintf("Incorrect login type, attempting to use %q but user is of login type %q", database.LoginTypeOIDC, user.LoginType), }) return } @@ -184,10 +184,13 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) { }) return } - user, _, err = api.createUser(r.Context(), codersdk.CreateUserRequest{ - Email: *verifiedEmail.Email, - Username: *ghUser.Login, - OrganizationID: organizationID, + user, _, err = api.createUser(ctx, createUserRequest{ + CreateUserRequest: codersdk.CreateUserRequest{ + Email: *verifiedEmail.Email, + Username: *ghUser.Login, + OrganizationID: organizationID, + }, + LoginType: database.LoginTypeGithub, }) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ @@ -332,7 +335,7 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) { } } - user, link, err := findLinkedUser(ctx, api.Database, database.LoginTypeOIDC, oidcLinkedID(idToken), claims.Email) + user, link, err := findLinkedUser(ctx, api.Database, oidcLinkedID(idToken), claims.Email) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ Message: "Failed to find user.", @@ -348,9 +351,9 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) { return } - if link.UserID != uuid.Nil && link.LoginType != database.LoginTypeOIDC { - httpapi.Write(rw, http.StatusConflict, codersdk.Response{ - Message: fmt.Sprintf("Incorrect login type, attempting to use %q but user is of login type %q", database.LoginTypeOIDC, link.LoginType), + if user.ID != uuid.Nil && user.LoginType != database.LoginTypeOIDC { + httpapi.Write(rw, http.StatusForbidden, codersdk.Response{ + Message: fmt.Sprintf("Incorrect login type, attempting to use %q but user is of login type %q", database.LoginTypeOIDC, user.LoginType), }) return } @@ -365,10 +368,13 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) { organizationID = organizations[0].ID } - user, _, err = api.createUser(ctx, codersdk.CreateUserRequest{ - Email: claims.Email, - Username: claims.Username, - OrganizationID: organizationID, + user, _, err = api.createUser(ctx, createUserRequest{ + CreateUserRequest: codersdk.CreateUserRequest{ + Email: claims.Email, + Username: claims.Username, + OrganizationID: organizationID, + }, + LoginType: database.LoginTypeOIDC, }) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ @@ -477,7 +483,7 @@ func oidcLinkedID(tok *oidc.IDToken) string { // findLinkedUser tries to find a user by their unique OAuth-linked ID. // If it doesn't not find it, it returns the user by their email. -func findLinkedUser(ctx context.Context, db database.Store, loginType database.LoginType, linkedID string, emails ...string) (database.User, database.UserLink, error) { +func findLinkedUser(ctx context.Context, db database.Store, linkedID string, emails ...string) (database.User, database.UserLink, error) { var ( user database.User link database.UserLink @@ -518,7 +524,7 @@ func findLinkedUser(ctx context.Context, db database.Store, loginType database.L // possible that a user_link exists without a populated 'linked_id'. link, err = db.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{ UserID: user.ID, - LoginType: loginType, + LoginType: user.LoginType, }) if err != nil && !errors.Is(err, sql.ErrNoRows) { return database.User{}, database.UserLink{}, xerrors.Errorf("get user link by user id and login type: %w", err) diff --git a/coderd/userauth_test.go b/coderd/userauth_test.go index 9db68b4138e18..643c9fbb75094 100644 --- a/coderd/userauth_test.go +++ b/coderd/userauth_test.go @@ -175,7 +175,7 @@ func TestUserOAuth2Github(t *testing.T) { resp := oauth2Callback(t, client) require.Equal(t, http.StatusForbidden, resp.StatusCode) }) - t.Run("Login", func(t *testing.T) { + t.Run("MultiLoginNotAllowed", func(t *testing.T) { t.Parallel() client := coderdtest.New(t, &coderdtest.Options{ GithubOAuth2Config: &coderd.GithubOAuth2Config{ @@ -199,6 +199,7 @@ func TestUserOAuth2Github(t *testing.T) { }, }, }) + // Creates the first user with login_type 'password'. _ = coderdtest.CreateFirstUser(t, client) resp := oauth2Callback(t, client) require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) diff --git a/coderd/users.go b/coderd/users.go index 487ac448c11f0..fd4c0cde24175 100644 --- a/coderd/users.go +++ b/coderd/users.go @@ -77,10 +77,13 @@ func (api *API) postFirstUser(rw http.ResponseWriter, r *http.Request) { return } - user, organizationID, err := api.createUser(r.Context(), codersdk.CreateUserRequest{ - Email: createUser.Email, - Username: createUser.Username, - Password: createUser.Password, + user, organizationID, err := api.createUser(r.Context(), createUserRequest{ + CreateUserRequest: codersdk.CreateUserRequest{ + Email: createUser.Email, + Username: createUser.Username, + Password: createUser.Password, + }, + LoginType: database.LoginTypePassword, }) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ @@ -243,7 +246,10 @@ func (api *API) postUser(rw http.ResponseWriter, r *http.Request) { return } - user, _, err := api.createUser(r.Context(), req) + user, _, err := api.createUser(r.Context(), createUserRequest{ + CreateUserRequest: req, + LoginType: database.LoginTypePassword, + }) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error creating user.", @@ -684,6 +690,13 @@ func (api *API) postLogin(rw http.ResponseWriter, r *http.Request) { return } + if user.LoginType != database.LoginTypePassword { + httpapi.Write(rw, http.StatusForbidden, codersdk.Response{ + Message: fmt.Sprintf("Incorrect login type, attempting to use %q but user is of login type %q", database.LoginTypeOIDC, user.LoginType), + }) + return + } + // If the user doesn't exist, it will be a default struct. equal, err := userpassword.Compare(string(user.HashedPassword), loginWithPassword.Password) if err != nil { @@ -896,7 +909,12 @@ func (api *API) createAPIKey(rw http.ResponseWriter, r *http.Request, params cre return sessionToken, true } -func (api *API) createUser(ctx context.Context, req codersdk.CreateUserRequest) (database.User, uuid.UUID, error) { +type createUserRequest struct { + codersdk.CreateUserRequest + LoginType database.LoginType +} + +func (api *API) createUser(ctx context.Context, req createUserRequest) (database.User, uuid.UUID, error) { var user database.User return user, req.OrganizationID, api.Database.InTx(func(db database.Store) error { orgRoles := make([]string, 0) @@ -923,6 +941,7 @@ func (api *API) createUser(ctx context.Context, req codersdk.CreateUserRequest) UpdatedAt: database.Now(), // All new users are defaulted to members of the site. RBACRoles: []string{}, + LoginType: req.LoginType, } // If a user signs up with OAuth, they can have no password! if req.Password != "" { From eb266dba1ec399505eaaa1a29d200755fbaf994d Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Sat, 13 Aug 2022 00:40:54 +0000 Subject: [PATCH 24/32] fix login_type query --- coderd/database/migrations/000034_linked_user_id.up.sql | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/coderd/database/migrations/000034_linked_user_id.up.sql b/coderd/database/migrations/000034_linked_user_id.up.sql index fb1034da9f43d..732a6eb6ec4f8 100644 --- a/coderd/database/migrations/000034_linked_user_id.up.sql +++ b/coderd/database/migrations/000034_linked_user_id.up.sql @@ -37,7 +37,7 @@ SELECT FROM ( SELECT - row_number() OVER (partition by user_id, login_type ORDER BY updated_at DESC) AS x, + row_number() OVER (partition by user_id, login_type ORDER BY last_used DESC) AS x, api_keys.* FROM api_keys ) as keys WHERE x=1 AND keys.login_type != 'password'; @@ -63,7 +63,12 @@ SET user_links WHERE user_links.user_id = users.id + ORDER BY oauth_expiry DESC LIMIT 1 - ); + ) +FROM + user_links +WHERE + user_links.user_id = users.id; COMMIT; From 4671bf66e3a420633fc474cd83dbbc908e9c1945 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Sat, 13 Aug 2022 02:02:37 +0000 Subject: [PATCH 25/32] fix tests --- coderd/database/databasefake/databasefake.go | 1 + coderd/userauth_test.go | 4 +++- coderd/users.go | 14 +++++++------- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/coderd/database/databasefake/databasefake.go b/coderd/database/databasefake/databasefake.go index 233fb63c59fa9..2edc3f35397ad 100644 --- a/coderd/database/databasefake/databasefake.go +++ b/coderd/database/databasefake/databasefake.go @@ -1740,6 +1740,7 @@ func (q *fakeQuerier) InsertUser(_ context.Context, arg database.InsertUserParam Username: arg.Username, Status: database.UserStatusActive, RBACRoles: arg.RBACRoles, + LoginType: arg.LoginType, } q.users = append(q.users, user) return user, nil diff --git a/coderd/userauth_test.go b/coderd/userauth_test.go index 643c9fbb75094..25fb476673c71 100644 --- a/coderd/userauth_test.go +++ b/coderd/userauth_test.go @@ -201,8 +201,10 @@ func TestUserOAuth2Github(t *testing.T) { }) // Creates the first user with login_type 'password'. _ = coderdtest.CreateFirstUser(t, client) + // Attempting to login should give us a 403 since the user + // already has a login_type of 'password'. resp := oauth2Callback(t, client) - require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) + require.Equal(t, http.StatusForbidden, resp.StatusCode) }) t.Run("Signup", func(t *testing.T) { t.Parallel() diff --git a/coderd/users.go b/coderd/users.go index fd4c0cde24175..9e30edce21321 100644 --- a/coderd/users.go +++ b/coderd/users.go @@ -690,13 +690,6 @@ func (api *API) postLogin(rw http.ResponseWriter, r *http.Request) { return } - if user.LoginType != database.LoginTypePassword { - httpapi.Write(rw, http.StatusForbidden, codersdk.Response{ - Message: fmt.Sprintf("Incorrect login type, attempting to use %q but user is of login type %q", database.LoginTypeOIDC, user.LoginType), - }) - return - } - // If the user doesn't exist, it will be a default struct. equal, err := userpassword.Compare(string(user.HashedPassword), loginWithPassword.Password) if err != nil { @@ -714,6 +707,13 @@ func (api *API) postLogin(rw http.ResponseWriter, r *http.Request) { return } + if user.LoginType != database.LoginTypePassword { + httpapi.Write(rw, http.StatusForbidden, codersdk.Response{ + Message: fmt.Sprintf("Incorrect login type, attempting to use %q but user is of login type %q", database.LoginTypePassword, user.LoginType), + }) + return + } + // If the user logged into a suspended account, reject the login request. if user.Status != database.UserStatusActive { httpapi.Write(rw, http.StatusUnauthorized, codersdk.Response{ From c41f4e65bd45e27845e0eb8a7beab2dbfec3dc5b Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Sat, 13 Aug 2022 02:08:32 +0000 Subject: [PATCH 26/32] fix audit --- coderd/audit/table.go | 1 + 1 file changed, 1 insertion(+) diff --git a/coderd/audit/table.go b/coderd/audit/table.go index c842956e6cf24..6a44bdc88b653 100644 --- a/coderd/audit/table.go +++ b/coderd/audit/table.go @@ -94,6 +94,7 @@ var AuditableResources = auditMap(map[any]map[string]Action{ "updated_at": ActionIgnore, // Changes, but is implicit and not helpful in a diff. "status": ActionTrack, "rbac_roles": ActionTrack, + "login_type": ActionIgnore, }, &database.Workspace{}: { "id": ActionTrack, From f3d839219319de3e968e3eb22784ddfb29eae267 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Sat, 13 Aug 2022 02:22:10 +0000 Subject: [PATCH 27/32] fix down --- coderd/database/migrations/000034_linked_user_id.down.sql | 1 + 1 file changed, 1 insertion(+) diff --git a/coderd/database/migrations/000034_linked_user_id.down.sql b/coderd/database/migrations/000034_linked_user_id.down.sql index a1f4e9112e4fb..4b75aad6abd7f 100644 --- a/coderd/database/migrations/000034_linked_user_id.down.sql +++ b/coderd/database/migrations/000034_linked_user_id.down.sql @@ -18,5 +18,6 @@ ALTER TABLE api_keys ADD COLUMN oauth_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL; +ALTER TABLE users DROP COLUMN login_type; COMMIT; From cc8400bb868b514e0bfdae0d56c23602dc2c0c7a Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Sat, 13 Aug 2022 02:42:27 +0000 Subject: [PATCH 28/32] fix one more test --- coderd/database/db_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/coderd/database/db_test.go b/coderd/database/db_test.go index 324e048e9156c..3ac9ac0519f3a 100644 --- a/coderd/database/db_test.go +++ b/coderd/database/db_test.go @@ -37,6 +37,7 @@ func TestNestedInTx(t *testing.T) { CreatedAt: database.Now(), UpdatedAt: database.Now(), RBACRoles: []string{}, + LoginType: database.LoginTypeGithub, }) return err }) From 083d256c65640d68ae848f2201390f05c9263b0e Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Wed, 17 Aug 2022 01:48:04 +0000 Subject: [PATCH 29/32] pr comments --- .../migrations/000034_linked_user_id.up.sql | 48 +++++++++---------- coderd/database/queries/user_links.sql | 46 +++++++++--------- coderd/database/queries/users.sql | 14 +++--- coderd/userauth.go | 8 ++-- 4 files changed, 59 insertions(+), 57 deletions(-) diff --git a/coderd/database/migrations/000034_linked_user_id.up.sql b/coderd/database/migrations/000034_linked_user_id.up.sql index 732a6eb6ec4f8..d86d5771165e6 100644 --- a/coderd/database/migrations/000034_linked_user_id.up.sql +++ b/coderd/database/migrations/000034_linked_user_id.up.sql @@ -7,8 +7,8 @@ CREATE TABLE IF NOT EXISTS user_links ( oauth_access_token text DEFAULT ''::text NOT NULL, oauth_refresh_token text DEFAULT ''::text NOT NULL, oauth_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL, - UNIQUE(user_id, login_type), - FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE + PRIMARY KEY(user_id, login_type), + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE ); -- This migrates columns on api_keys to the new user_links table. @@ -18,8 +18,8 @@ CREATE TABLE IF NOT EXISTS user_links ( -- A user should at most have a row for an OIDC account and a Github account. -- 'password' login types are ignored. -INSERT INTO user_links - ( +INSERT INTO user_links + ( user_id, login_type, linked_id, @@ -27,25 +27,25 @@ INSERT INTO user_links oauth_refresh_token, oauth_expiry ) -SELECT - keys.user_id, +SELECT + keys.user_id, keys.login_type, '', keys.oauth_access_token, keys.oauth_refresh_token, - keys.oauth_expiry -FROM - ( - SELECT - row_number() OVER (partition by user_id, login_type ORDER BY last_used DESC) AS x, + keys.oauth_expiry +FROM + ( + SELECT + row_number() OVER (partition by user_id, login_type ORDER BY last_used DESC) AS x, api_keys.* FROM api_keys ) as keys - WHERE x=1 AND keys.login_type != 'password'; +WHERE x=1 AND keys.login_type != 'password'; -- Drop columns that have been migrated to user_links. -- It appears the 'oauth_id_token' was unused and so it has -- been dropped here as well to avoid future confusion. -ALTER TABLE api_keys +ALTER TABLE api_keys DROP COLUMN oauth_access_token, DROP COLUMN oauth_refresh_token, DROP COLUMN oauth_id_token, @@ -54,18 +54,18 @@ ALTER TABLE api_keys ALTER TABLE users ADD COLUMN login_type login_type NOT NULL DEFAULT 'password'; UPDATE - users + users SET - login_type = ( - SELECT - login_type - FROM - user_links - WHERE - user_links.user_id = users.id - ORDER BY oauth_expiry DESC - LIMIT 1 - ) + login_type = ( + SELECT + login_type + FROM + user_links + WHERE + user_links.user_id = users.id + ORDER BY oauth_expiry DESC + LIMIT 1 + ) FROM user_links WHERE diff --git a/coderd/database/queries/user_links.sql b/coderd/database/queries/user_links.sql index 94120a5793a3c..2390cb9782b30 100644 --- a/coderd/database/queries/user_links.sql +++ b/coderd/database/queries/user_links.sql @@ -1,46 +1,46 @@ -- name: GetUserLinkByLinkedID :one SELECT - * + * FROM - user_links + user_links WHERE - linked_id = $1; + linked_id = $1; -- name: GetUserLinkByUserIDLoginType :one SELECT - * + * FROM - user_links + user_links WHERE - user_id = $1 AND login_type = $2; + user_id = $1 AND login_type = $2; -- name: InsertUserLink :one INSERT INTO - user_links ( - user_id, - login_type, - linked_id, - oauth_access_token, - oauth_refresh_token, - oauth_expiry - ) + user_links ( + user_id, + login_type, + linked_id, + oauth_access_token, + oauth_refresh_token, + oauth_expiry + ) VALUES - ( $1, $2, $3, $4, $5, $6 ) RETURNING *; + ( $1, $2, $3, $4, $5, $6 ) RETURNING *; -- name: UpdateUserLinkedID :one UPDATE - user_links + user_links SET - linked_id = $1 + linked_id = $1 WHERE - user_id = $2 AND login_type = $3 RETURNING *; + user_id = $2 AND login_type = $3 RETURNING *; -- name: UpdateUserLink :one UPDATE - user_links + user_links SET - oauth_access_token = $1, - oauth_refresh_token = $2, - oauth_expiry = $3 + oauth_access_token = $1, + oauth_refresh_token = $2, + oauth_expiry = $3 WHERE - user_id = $4 AND login_type = $5 RETURNING *; + user_id = $4 AND login_type = $5 RETURNING *; diff --git a/coderd/database/queries/users.sql b/coderd/database/queries/users.sql index e5e6908a9ceb2..1d9caa758625e 100644 --- a/coderd/database/queries/users.sql +++ b/coderd/database/queries/users.sql @@ -38,7 +38,7 @@ INSERT INTO created_at, updated_at, rbac_roles, - login_type + login_type ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) RETURNING *; @@ -55,12 +55,12 @@ WHERE -- name: UpdateUserRoles :one UPDATE - users + users SET -- Remove all duplicates from the roles. rbac_roles = ARRAY(SELECT DISTINCT UNNEST(@granted_roles :: text[])) WHERE - id = @id + id = @id RETURNING *; -- name: UpdateUserHashedPassword :exec @@ -123,8 +123,8 @@ WHERE END -- End of filters ORDER BY - -- Deterministic and consistent ordering of all users, even if they share - -- a timestamp. This is to ensure consistent pagination. + -- Deterministic and consistent ordering of all users, even if they share + -- a timestamp. This is to ensure consistent pagination. (created_at, id) ASC OFFSET @offset_opt LIMIT -- A null limit means "no limit", so 0 means return all @@ -153,10 +153,10 @@ SELECT array_append(users.rbac_roles, 'member'), -- All org_members get the org-member role for their orgs array_append(organization_members.roles, 'organization-member:'||organization_members.organization_id::text)) :: text[] - AS roles + AS roles FROM users LEFT JOIN organization_members ON id = user_id WHERE - id = @user_id; + id = @user_id; diff --git a/coderd/userauth.go b/coderd/userauth.go index 1fe739d6bfb05..0e99aa9f47a6d 100644 --- a/coderd/userauth.go +++ b/coderd/userauth.go @@ -148,7 +148,7 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) { if user.ID != uuid.Nil && user.LoginType != database.LoginTypeGithub { httpapi.Write(rw, http.StatusForbidden, codersdk.Response{ - Message: fmt.Sprintf("Incorrect login type, attempting to use %q but user is of login type %q", database.LoginTypeOIDC, user.LoginType), + Message: fmt.Sprintf("Incorrect login type, attempting to use %q but user is of login type %q", database.LoginTypeGithub, user.LoginType), }) return } @@ -215,7 +215,7 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) { if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ Message: "A database error occurred.", - Detail: xerrors.Errorf("insert user link: %w", err.Error).Error(), + Detail: fmt.Sprintf("insert user link: %s", err.Error()), }) return } @@ -358,6 +358,8 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) { return } + // This can happen if a user is a built-in user but is signing in + // with OIDC for the first time. if user.ID == uuid.Nil { var organizationID uuid.UUID organizations, _ := api.Database.GetOrganizations(ctx) @@ -404,7 +406,7 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) { if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ Message: "A database error occurred.", - Detail: xerrors.Errorf("insert user link: %w", err.Error).Error(), + Detail: fmt.Sprintf("insert user link: %s", err.Error()), }) return } From 92c185db1a260ee1d9e98105fc50b39e315e5cb9 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Wed, 17 Aug 2022 01:49:58 +0000 Subject: [PATCH 30/32] fix conflicting migration file --- ...034_linked_user_id.down.sql => 000035_linked_user_id.down.sql} | 0 ...{000034_linked_user_id.up.sql => 000035_linked_user_id.up.sql} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename coderd/database/migrations/{000034_linked_user_id.down.sql => 000035_linked_user_id.down.sql} (100%) rename coderd/database/migrations/{000034_linked_user_id.up.sql => 000035_linked_user_id.up.sql} (100%) diff --git a/coderd/database/migrations/000034_linked_user_id.down.sql b/coderd/database/migrations/000035_linked_user_id.down.sql similarity index 100% rename from coderd/database/migrations/000034_linked_user_id.down.sql rename to coderd/database/migrations/000035_linked_user_id.down.sql diff --git a/coderd/database/migrations/000034_linked_user_id.up.sql b/coderd/database/migrations/000035_linked_user_id.up.sql similarity index 100% rename from coderd/database/migrations/000034_linked_user_id.up.sql rename to coderd/database/migrations/000035_linked_user_id.up.sql From 05595d8d360f2c878ea49b94c2a84722d9ae2bdf Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Wed, 17 Aug 2022 01:50:51 +0000 Subject: [PATCH 31/32] generate.sh --- coderd/database/dump.sql | 2 +- coderd/database/queries.sql.go | 60 +++++++++++++++++----------------- 2 files changed, 31 insertions(+), 31 deletions(-) diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 7782646c94d07..1db9442d93543 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -423,7 +423,7 @@ ALTER TABLE ONLY templates ADD CONSTRAINT templates_pkey PRIMARY KEY (id); ALTER TABLE ONLY user_links - ADD CONSTRAINT user_links_user_id_login_type_key UNIQUE (user_id, login_type); + ADD CONSTRAINT user_links_pkey PRIMARY KEY (user_id, login_type); ALTER TABLE ONLY users ADD CONSTRAINT users_pkey PRIMARY KEY (id); diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index bdb8b9f05d4e8..d8d022039191a 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -2413,11 +2413,11 @@ func (q *sqlQuerier) UpdateTemplateVersionDescriptionByJobID(ctx context.Context const getUserLinkByLinkedID = `-- name: GetUserLinkByLinkedID :one SELECT - user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry + user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry FROM - user_links + user_links WHERE - linked_id = $1 + linked_id = $1 ` func (q *sqlQuerier) GetUserLinkByLinkedID(ctx context.Context, linkedID string) (UserLink, error) { @@ -2436,11 +2436,11 @@ func (q *sqlQuerier) GetUserLinkByLinkedID(ctx context.Context, linkedID string) const getUserLinkByUserIDLoginType = `-- name: GetUserLinkByUserIDLoginType :one SELECT - user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry + user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry FROM - user_links + user_links WHERE - user_id = $1 AND login_type = $2 + user_id = $1 AND login_type = $2 ` type GetUserLinkByUserIDLoginTypeParams struct { @@ -2464,16 +2464,16 @@ func (q *sqlQuerier) GetUserLinkByUserIDLoginType(ctx context.Context, arg GetUs const insertUserLink = `-- name: InsertUserLink :one INSERT INTO - user_links ( - user_id, - login_type, - linked_id, - oauth_access_token, - oauth_refresh_token, - oauth_expiry - ) + user_links ( + user_id, + login_type, + linked_id, + oauth_access_token, + oauth_refresh_token, + oauth_expiry + ) VALUES - ( $1, $2, $3, $4, $5, $6 ) RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry + ( $1, $2, $3, $4, $5, $6 ) RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry ` type InsertUserLinkParams struct { @@ -2508,13 +2508,13 @@ func (q *sqlQuerier) InsertUserLink(ctx context.Context, arg InsertUserLinkParam const updateUserLink = `-- name: UpdateUserLink :one UPDATE - user_links + user_links SET - oauth_access_token = $1, - oauth_refresh_token = $2, - oauth_expiry = $3 + oauth_access_token = $1, + oauth_refresh_token = $2, + oauth_expiry = $3 WHERE - user_id = $4 AND login_type = $5 RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry + user_id = $4 AND login_type = $5 RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry ` type UpdateUserLinkParams struct { @@ -2547,11 +2547,11 @@ func (q *sqlQuerier) UpdateUserLink(ctx context.Context, arg UpdateUserLinkParam const updateUserLinkedID = `-- name: UpdateUserLinkedID :one UPDATE - user_links + user_links SET - linked_id = $1 + linked_id = $1 WHERE - user_id = $2 AND login_type = $3 RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry + user_id = $2 AND login_type = $3 RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry ` type UpdateUserLinkedIDParams struct { @@ -2585,13 +2585,13 @@ SELECT array_append(users.rbac_roles, 'member'), -- All org_members get the org-member role for their orgs array_append(organization_members.roles, 'organization-member:'||organization_members.organization_id::text)) :: text[] - AS roles + AS roles FROM users LEFT JOIN organization_members ON id = user_id WHERE - id = $1 + id = $1 ` type GetAuthorizationUserRolesRow struct { @@ -2743,8 +2743,8 @@ WHERE END -- End of filters ORDER BY - -- Deterministic and consistent ordering of all users, even if they share - -- a timestamp. This is to ensure consistent pagination. + -- Deterministic and consistent ordering of all users, even if they share + -- a timestamp. This is to ensure consistent pagination. (created_at, id) ASC OFFSET $5 LIMIT -- A null limit means "no limit", so 0 means return all @@ -2847,7 +2847,7 @@ INSERT INTO created_at, updated_at, rbac_roles, - login_type + login_type ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type @@ -2951,12 +2951,12 @@ func (q *sqlQuerier) UpdateUserProfile(ctx context.Context, arg UpdateUserProfil const updateUserRoles = `-- name: UpdateUserRoles :one UPDATE - users + users SET -- Remove all duplicates from the roles. rbac_roles = ARRAY(SELECT DISTINCT UNNEST($1 :: text[])) WHERE - id = $2 + id = $2 RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type ` From aa9014876b4f34c677b24053d00d10bf0485d5cc Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Wed, 17 Aug 2022 01:51:28 +0000 Subject: [PATCH 32/32] butcher the english language to appease colin --- coderd/userauth.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/coderd/userauth.go b/coderd/userauth.go index 0e99aa9f47a6d..5e7654cd6a08f 100644 --- a/coderd/userauth.go +++ b/coderd/userauth.go @@ -478,7 +478,7 @@ func githubLinkedID(u *github.User) string { } // oidcLinkedID returns the uniqued ID for an OIDC user. -// See https://openid.net/specs/openid-connect-core-1_0.html#ClaimStability. +// See https://openid.net/specs/openid-connect-core-1_0.html#ClaimStability . func oidcLinkedID(tok *oidc.IDToken) string { return strings.Join([]string{tok.Issuer, tok.Subject}, "||") }