diff --git a/coderd/audit/table.go b/coderd/audit/table.go index c842956e6cf24..6a44bdc88b653 100644 --- a/coderd/audit/table.go +++ b/coderd/audit/table.go @@ -94,6 +94,7 @@ var AuditableResources = auditMap(map[any]map[string]Action{ "updated_at": ActionIgnore, // Changes, but is implicit and not helpful in a diff. "status": ActionTrack, "rbac_roles": ActionTrack, + "login_type": ActionIgnore, }, &database.Workspace{}: { "id": ActionTrack, diff --git a/coderd/database/databasefake/databasefake.go b/coderd/database/databasefake/databasefake.go index 6623d338af7c4..2edc3f35397ad 100644 --- a/coderd/database/databasefake/databasefake.go +++ b/coderd/database/databasefake/databasefake.go @@ -73,6 +73,7 @@ type data struct { organizations []database.Organization organizationMembers []database.OrganizationMember users []database.User + userLinks []database.UserLink // New tables auditLogs []database.AuditLog @@ -1453,20 +1454,16 @@ func (q *fakeQuerier) InsertAPIKey(_ context.Context, arg database.InsertAPIKeyP //nolint:gosimple key := database.APIKey{ - ID: arg.ID, - LifetimeSeconds: arg.LifetimeSeconds, - HashedSecret: arg.HashedSecret, - IPAddress: arg.IPAddress, - UserID: arg.UserID, - ExpiresAt: arg.ExpiresAt, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - LastUsed: arg.LastUsed, - LoginType: arg.LoginType, - OAuthAccessToken: arg.OAuthAccessToken, - OAuthRefreshToken: arg.OAuthRefreshToken, - OAuthIDToken: arg.OAuthIDToken, - OAuthExpiry: arg.OAuthExpiry, + ID: arg.ID, + LifetimeSeconds: arg.LifetimeSeconds, + HashedSecret: arg.HashedSecret, + IPAddress: arg.IPAddress, + UserID: arg.UserID, + ExpiresAt: arg.ExpiresAt, + CreatedAt: arg.CreatedAt, + UpdatedAt: arg.UpdatedAt, + LastUsed: arg.LastUsed, + LoginType: arg.LoginType, } q.apiKeys = append(q.apiKeys, key) return key, nil @@ -1743,6 +1740,7 @@ func (q *fakeQuerier) InsertUser(_ context.Context, arg database.InsertUserParam Username: arg.Username, Status: database.UserStatusActive, RBACRoles: arg.RBACRoles, + LoginType: arg.LoginType, } q.users = append(q.users, user) return user, nil @@ -1898,9 +1896,6 @@ func (q *fakeQuerier) UpdateAPIKeyByID(_ context.Context, arg database.UpdateAPI apiKey.LastUsed = arg.LastUsed apiKey.ExpiresAt = arg.ExpiresAt apiKey.IPAddress = arg.IPAddress - apiKey.OAuthAccessToken = arg.OAuthAccessToken - apiKey.OAuthRefreshToken = arg.OAuthRefreshToken - apiKey.OAuthExpiry = arg.OAuthExpiry q.apiKeys[index] = apiKey return nil } @@ -2259,3 +2254,80 @@ func (q *fakeQuerier) GetDeploymentID(_ context.Context) (string, error) { return q.deploymentID, nil } + +func (q *fakeQuerier) GetUserLinkByLinkedID(_ context.Context, id string) (database.UserLink, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + for _, link := range q.userLinks { + if link.LinkedID == id { + return link, nil + } + } + return database.UserLink{}, sql.ErrNoRows +} + +func (q *fakeQuerier) GetUserLinkByUserIDLoginType(_ context.Context, params database.GetUserLinkByUserIDLoginTypeParams) (database.UserLink, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + for _, link := range q.userLinks { + if link.UserID == params.UserID && link.LoginType == params.LoginType { + return link, nil + } + } + return database.UserLink{}, sql.ErrNoRows +} + +func (q *fakeQuerier) InsertUserLink(_ context.Context, args database.InsertUserLinkParams) (database.UserLink, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + //nolint:gosimple + link := database.UserLink{ + UserID: args.UserID, + LoginType: args.LoginType, + LinkedID: args.LinkedID, + OAuthAccessToken: args.OAuthAccessToken, + OAuthRefreshToken: args.OAuthRefreshToken, + OAuthExpiry: args.OAuthExpiry, + } + + q.userLinks = append(q.userLinks, link) + + return link, nil +} + +func (q *fakeQuerier) UpdateUserLinkedID(_ context.Context, params database.UpdateUserLinkedIDParams) (database.UserLink, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + for i, link := range q.userLinks { + if link.UserID == params.UserID && link.LoginType == params.LoginType { + link.LinkedID = params.LinkedID + + q.userLinks[i] = link + return link, nil + } + } + + return database.UserLink{}, sql.ErrNoRows +} + +func (q *fakeQuerier) UpdateUserLink(_ context.Context, params database.UpdateUserLinkParams) (database.UserLink, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + for i, link := range q.userLinks { + if link.UserID == params.UserID && link.LoginType == params.LoginType { + link.OAuthAccessToken = params.OAuthAccessToken + link.OAuthRefreshToken = params.OAuthRefreshToken + link.OAuthExpiry = params.OAuthExpiry + + q.userLinks[i] = link + return link, nil + } + } + + return database.UserLink{}, sql.ErrNoRows +} diff --git a/coderd/database/db_test.go b/coderd/database/db_test.go index 324e048e9156c..3ac9ac0519f3a 100644 --- a/coderd/database/db_test.go +++ b/coderd/database/db_test.go @@ -37,6 +37,7 @@ func TestNestedInTx(t *testing.T) { CreatedAt: database.Now(), UpdatedAt: database.Now(), RBACRoles: []string{}, + LoginType: database.LoginTypeGithub, }) return err }) diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 87853f23fe16e..1db9442d93543 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -96,10 +96,6 @@ CREATE TABLE api_keys ( created_at timestamp with time zone NOT NULL, updated_at timestamp with time zone NOT NULL, login_type login_type NOT NULL, - oauth_access_token text DEFAULT ''::text NOT NULL, - oauth_refresh_token text DEFAULT ''::text NOT NULL, - oauth_id_token text DEFAULT ''::text NOT NULL, - oauth_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL, lifetime_seconds bigint DEFAULT 86400 NOT NULL, ip_address inet DEFAULT '0.0.0.0'::inet NOT NULL ); @@ -267,6 +263,15 @@ CREATE TABLE templates ( created_by uuid NOT NULL ); +CREATE TABLE user_links ( + user_id uuid NOT NULL, + login_type login_type NOT NULL, + linked_id text DEFAULT ''::text NOT NULL, + oauth_access_token text DEFAULT ''::text NOT NULL, + oauth_refresh_token text DEFAULT ''::text NOT NULL, + oauth_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL +); + CREATE TABLE users ( id uuid NOT NULL, email text NOT NULL, @@ -275,7 +280,8 @@ CREATE TABLE users ( created_at timestamp with time zone NOT NULL, updated_at timestamp with time zone NOT NULL, status user_status DEFAULT 'active'::public.user_status NOT NULL, - rbac_roles text[] DEFAULT '{}'::text[] NOT NULL + rbac_roles text[] DEFAULT '{}'::text[] NOT NULL, + login_type login_type DEFAULT 'password'::public.login_type NOT NULL ); CREATE TABLE workspace_agents ( @@ -416,6 +422,9 @@ ALTER TABLE ONLY template_versions ALTER TABLE ONLY templates ADD CONSTRAINT templates_pkey PRIMARY KEY (id); +ALTER TABLE ONLY user_links + ADD CONSTRAINT user_links_pkey PRIMARY KEY (user_id, login_type); + ALTER TABLE ONLY users ADD CONSTRAINT users_pkey PRIMARY KEY (id); @@ -513,6 +522,9 @@ ALTER TABLE ONLY templates ALTER TABLE ONLY templates ADD CONSTRAINT templates_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE; +ALTER TABLE ONLY user_links + ADD CONSTRAINT user_links_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; + ALTER TABLE ONLY workspace_agents ADD CONSTRAINT workspace_agents_resource_id_fkey FOREIGN KEY (resource_id) REFERENCES workspace_resources(id) ON DELETE CASCADE; diff --git a/coderd/database/generate.sh b/coderd/database/generate.sh index e00b0ae73a425..326fa096b90d1 100755 --- a/coderd/database/generate.sh +++ b/coderd/database/generate.sh @@ -13,6 +13,8 @@ SCRIPT_DIR=$(dirname "${BASH_SOURCE[0]}") ( cd "$SCRIPT_DIR" + # Dump the updated schema. + go run dump/main.go # The logic below depends on the exact version being correct :( go run github.com/kyleconroy/sqlc/cmd/sqlc@v1.13.0 generate diff --git a/coderd/database/migrations/000035_linked_user_id.down.sql b/coderd/database/migrations/000035_linked_user_id.down.sql new file mode 100644 index 0000000000000..4b75aad6abd7f --- /dev/null +++ b/coderd/database/migrations/000035_linked_user_id.down.sql @@ -0,0 +1,23 @@ +-- This migration makes no attempt to try to populate +-- the oauth_access_token, oauth_refresh_token, and oauth_expiry +-- columns of api_key rows with the values from the dropped user_links +-- table. +BEGIN; + +DROP TABLE IF EXISTS user_links; + +ALTER TABLE + api_keys +ADD COLUMN oauth_access_token text DEFAULT ''::text NOT NULL; + +ALTER TABLE + api_keys +ADD COLUMN oauth_refresh_token text DEFAULT ''::text NOT NULL; + +ALTER TABLE + api_keys +ADD COLUMN oauth_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL; + +ALTER TABLE users DROP COLUMN login_type; + +COMMIT; diff --git a/coderd/database/migrations/000035_linked_user_id.up.sql b/coderd/database/migrations/000035_linked_user_id.up.sql new file mode 100644 index 0000000000000..d86d5771165e6 --- /dev/null +++ b/coderd/database/migrations/000035_linked_user_id.up.sql @@ -0,0 +1,74 @@ +BEGIN; + +CREATE TABLE IF NOT EXISTS user_links ( + user_id uuid NOT NULL, + login_type login_type NOT NULL, + linked_id text DEFAULT ''::text NOT NULL, + oauth_access_token text DEFAULT ''::text NOT NULL, + oauth_refresh_token text DEFAULT ''::text NOT NULL, + oauth_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL, + PRIMARY KEY(user_id, login_type), + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE +); + +-- This migrates columns on api_keys to the new user_links table. +-- It does this by finding all the API keys for each user, choosing +-- the most recently updated for each one and then assigning its relevant +-- values to the user_links table. +-- A user should at most have a row for an OIDC account and a Github account. +-- 'password' login types are ignored. + +INSERT INTO user_links + ( + user_id, + login_type, + linked_id, + oauth_access_token, + oauth_refresh_token, + oauth_expiry + ) +SELECT + keys.user_id, + keys.login_type, + '', + keys.oauth_access_token, + keys.oauth_refresh_token, + keys.oauth_expiry +FROM + ( + SELECT + row_number() OVER (partition by user_id, login_type ORDER BY last_used DESC) AS x, + api_keys.* FROM api_keys + ) as keys +WHERE x=1 AND keys.login_type != 'password'; + +-- Drop columns that have been migrated to user_links. +-- It appears the 'oauth_id_token' was unused and so it has +-- been dropped here as well to avoid future confusion. +ALTER TABLE api_keys + DROP COLUMN oauth_access_token, + DROP COLUMN oauth_refresh_token, + DROP COLUMN oauth_id_token, + DROP COLUMN oauth_expiry; + +ALTER TABLE users ADD COLUMN login_type login_type NOT NULL DEFAULT 'password'; + +UPDATE + users +SET + login_type = ( + SELECT + login_type + FROM + user_links + WHERE + user_links.user_id = users.id + ORDER BY oauth_expiry DESC + LIMIT 1 + ) +FROM + user_links +WHERE + user_links.user_id = users.id; + +COMMIT; diff --git a/coderd/database/models.go b/coderd/database/models.go index 6cf4c07761674..25ed8eedf3190 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -313,20 +313,16 @@ func (e *WorkspaceTransition) Scan(src interface{}) error { } type APIKey struct { - ID string `db:"id" json:"id"` - HashedSecret []byte `db:"hashed_secret" json:"hashed_secret"` - UserID uuid.UUID `db:"user_id" json:"user_id"` - LastUsed time.Time `db:"last_used" json:"last_used"` - ExpiresAt time.Time `db:"expires_at" json:"expires_at"` - CreatedAt time.Time `db:"created_at" json:"created_at"` - UpdatedAt time.Time `db:"updated_at" json:"updated_at"` - LoginType LoginType `db:"login_type" json:"login_type"` - OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"` - OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"` - OAuthIDToken string `db:"oauth_id_token" json:"oauth_id_token"` - OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"` - LifetimeSeconds int64 `db:"lifetime_seconds" json:"lifetime_seconds"` - IPAddress pqtype.Inet `db:"ip_address" json:"ip_address"` + ID string `db:"id" json:"id"` + HashedSecret []byte `db:"hashed_secret" json:"hashed_secret"` + UserID uuid.UUID `db:"user_id" json:"user_id"` + LastUsed time.Time `db:"last_used" json:"last_used"` + ExpiresAt time.Time `db:"expires_at" json:"expires_at"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + LoginType LoginType `db:"login_type" json:"login_type"` + LifetimeSeconds int64 `db:"lifetime_seconds" json:"lifetime_seconds"` + IPAddress pqtype.Inet `db:"ip_address" json:"ip_address"` } type AuditLog struct { @@ -491,6 +487,16 @@ type User struct { UpdatedAt time.Time `db:"updated_at" json:"updated_at"` Status UserStatus `db:"status" json:"status"` RBACRoles []string `db:"rbac_roles" json:"rbac_roles"` + LoginType LoginType `db:"login_type" json:"login_type"` +} + +type UserLink struct { + UserID uuid.UUID `db:"user_id" json:"user_id"` + LoginType LoginType `db:"login_type" json:"login_type"` + LinkedID string `db:"linked_id" json:"linked_id"` + OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"` + OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"` + OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"` } type Workspace struct { diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 90e9a3a0a1385..272efe40fd446 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -64,6 +64,8 @@ type querier interface { GetUserByEmailOrUsername(ctx context.Context, arg GetUserByEmailOrUsernameParams) (User, error) GetUserByID(ctx context.Context, id uuid.UUID) (User, error) GetUserCount(ctx context.Context) (int64, error) + GetUserLinkByLinkedID(ctx context.Context, linkedID string) (UserLink, error) + GetUserLinkByUserIDLoginType(ctx context.Context, arg GetUserLinkByUserIDLoginTypeParams) (UserLink, error) GetUsers(ctx context.Context, arg GetUsersParams) ([]User, error) GetUsersByIDs(ctx context.Context, ids []uuid.UUID) ([]User, error) GetWorkspaceAgentByAuthToken(ctx context.Context, authToken uuid.UUID) (WorkspaceAgent, error) @@ -106,6 +108,7 @@ type querier interface { InsertTemplate(ctx context.Context, arg InsertTemplateParams) (Template, error) InsertTemplateVersion(ctx context.Context, arg InsertTemplateVersionParams) (TemplateVersion, error) InsertUser(ctx context.Context, arg InsertUserParams) (User, error) + InsertUserLink(ctx context.Context, arg InsertUserLinkParams) (UserLink, error) InsertWorkspace(ctx context.Context, arg InsertWorkspaceParams) (Workspace, error) InsertWorkspaceAgent(ctx context.Context, arg InsertWorkspaceAgentParams) (WorkspaceAgent, error) InsertWorkspaceApp(ctx context.Context, arg InsertWorkspaceAppParams) (WorkspaceApp, error) @@ -127,6 +130,8 @@ type querier interface { UpdateTemplateVersionByID(ctx context.Context, arg UpdateTemplateVersionByIDParams) error UpdateTemplateVersionDescriptionByJobID(ctx context.Context, arg UpdateTemplateVersionDescriptionByJobIDParams) error UpdateUserHashedPassword(ctx context.Context, arg UpdateUserHashedPasswordParams) error + UpdateUserLink(ctx context.Context, arg UpdateUserLinkParams) (UserLink, error) + UpdateUserLinkedID(ctx context.Context, arg UpdateUserLinkedIDParams) (UserLink, error) UpdateUserProfile(ctx context.Context, arg UpdateUserProfileParams) (User, error) UpdateUserRoles(ctx context.Context, arg UpdateUserRolesParams) (User, error) UpdateUserStatus(ctx context.Context, arg UpdateUserStatusParams) (User, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index d36262a121ee2..d8d022039191a 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -31,7 +31,7 @@ func (q *sqlQuerier) DeleteAPIKeyByID(ctx context.Context, id string) error { const getAPIKeyByID = `-- name: GetAPIKeyByID :one SELECT - id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, oauth_access_token, oauth_refresh_token, oauth_id_token, oauth_expiry, lifetime_seconds, ip_address + id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, lifetime_seconds, ip_address FROM api_keys WHERE @@ -52,10 +52,6 @@ func (q *sqlQuerier) GetAPIKeyByID(ctx context.Context, id string) (APIKey, erro &i.CreatedAt, &i.UpdatedAt, &i.LoginType, - &i.OAuthAccessToken, - &i.OAuthRefreshToken, - &i.OAuthIDToken, - &i.OAuthExpiry, &i.LifetimeSeconds, &i.IPAddress, ) @@ -63,7 +59,7 @@ func (q *sqlQuerier) GetAPIKeyByID(ctx context.Context, id string) (APIKey, erro } const getAPIKeysLastUsedAfter = `-- name: GetAPIKeysLastUsedAfter :many -SELECT id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, oauth_access_token, oauth_refresh_token, oauth_id_token, oauth_expiry, lifetime_seconds, ip_address FROM api_keys WHERE last_used > $1 +SELECT id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, lifetime_seconds, ip_address FROM api_keys WHERE last_used > $1 ` func (q *sqlQuerier) GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Time) ([]APIKey, error) { @@ -84,10 +80,6 @@ func (q *sqlQuerier) GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time. &i.CreatedAt, &i.UpdatedAt, &i.LoginType, - &i.OAuthAccessToken, - &i.OAuthRefreshToken, - &i.OAuthIDToken, - &i.OAuthExpiry, &i.LifetimeSeconds, &i.IPAddress, ); err != nil { @@ -116,11 +108,7 @@ INSERT INTO expires_at, created_at, updated_at, - login_type, - oauth_access_token, - oauth_refresh_token, - oauth_id_token, - oauth_expiry + login_type ) VALUES ($1, @@ -129,24 +117,20 @@ VALUES WHEN 0 THEN 86400 ELSE $2::bigint END - , $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) RETURNING id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, oauth_access_token, oauth_refresh_token, oauth_id_token, oauth_expiry, lifetime_seconds, ip_address + , $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, lifetime_seconds, ip_address ` type InsertAPIKeyParams struct { - ID string `db:"id" json:"id"` - LifetimeSeconds int64 `db:"lifetime_seconds" json:"lifetime_seconds"` - HashedSecret []byte `db:"hashed_secret" json:"hashed_secret"` - IPAddress pqtype.Inet `db:"ip_address" json:"ip_address"` - UserID uuid.UUID `db:"user_id" json:"user_id"` - LastUsed time.Time `db:"last_used" json:"last_used"` - ExpiresAt time.Time `db:"expires_at" json:"expires_at"` - CreatedAt time.Time `db:"created_at" json:"created_at"` - UpdatedAt time.Time `db:"updated_at" json:"updated_at"` - LoginType LoginType `db:"login_type" json:"login_type"` - OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"` - OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"` - OAuthIDToken string `db:"oauth_id_token" json:"oauth_id_token"` - OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"` + ID string `db:"id" json:"id"` + LifetimeSeconds int64 `db:"lifetime_seconds" json:"lifetime_seconds"` + HashedSecret []byte `db:"hashed_secret" json:"hashed_secret"` + IPAddress pqtype.Inet `db:"ip_address" json:"ip_address"` + UserID uuid.UUID `db:"user_id" json:"user_id"` + LastUsed time.Time `db:"last_used" json:"last_used"` + ExpiresAt time.Time `db:"expires_at" json:"expires_at"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + LoginType LoginType `db:"login_type" json:"login_type"` } func (q *sqlQuerier) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (APIKey, error) { @@ -161,10 +145,6 @@ func (q *sqlQuerier) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) ( arg.CreatedAt, arg.UpdatedAt, arg.LoginType, - arg.OAuthAccessToken, - arg.OAuthRefreshToken, - arg.OAuthIDToken, - arg.OAuthExpiry, ) var i APIKey err := row.Scan( @@ -176,10 +156,6 @@ func (q *sqlQuerier) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) ( &i.CreatedAt, &i.UpdatedAt, &i.LoginType, - &i.OAuthAccessToken, - &i.OAuthRefreshToken, - &i.OAuthIDToken, - &i.OAuthExpiry, &i.LifetimeSeconds, &i.IPAddress, ) @@ -192,22 +168,16 @@ UPDATE SET last_used = $2, expires_at = $3, - ip_address = $4, - oauth_access_token = $5, - oauth_refresh_token = $6, - oauth_expiry = $7 + ip_address = $4 WHERE id = $1 ` type UpdateAPIKeyByIDParams struct { - ID string `db:"id" json:"id"` - LastUsed time.Time `db:"last_used" json:"last_used"` - ExpiresAt time.Time `db:"expires_at" json:"expires_at"` - IPAddress pqtype.Inet `db:"ip_address" json:"ip_address"` - OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"` - OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"` - OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"` + ID string `db:"id" json:"id"` + LastUsed time.Time `db:"last_used" json:"last_used"` + ExpiresAt time.Time `db:"expires_at" json:"expires_at"` + IPAddress pqtype.Inet `db:"ip_address" json:"ip_address"` } func (q *sqlQuerier) UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error { @@ -216,9 +186,6 @@ func (q *sqlQuerier) UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDP arg.LastUsed, arg.ExpiresAt, arg.IPAddress, - arg.OAuthAccessToken, - arg.OAuthRefreshToken, - arg.OAuthExpiry, ) return err } @@ -2444,6 +2411,169 @@ func (q *sqlQuerier) UpdateTemplateVersionDescriptionByJobID(ctx context.Context return err } +const getUserLinkByLinkedID = `-- name: GetUserLinkByLinkedID :one +SELECT + user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry +FROM + user_links +WHERE + linked_id = $1 +` + +func (q *sqlQuerier) GetUserLinkByLinkedID(ctx context.Context, linkedID string) (UserLink, error) { + row := q.db.QueryRowContext(ctx, getUserLinkByLinkedID, linkedID) + var i UserLink + err := row.Scan( + &i.UserID, + &i.LoginType, + &i.LinkedID, + &i.OAuthAccessToken, + &i.OAuthRefreshToken, + &i.OAuthExpiry, + ) + return i, err +} + +const getUserLinkByUserIDLoginType = `-- name: GetUserLinkByUserIDLoginType :one +SELECT + user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry +FROM + user_links +WHERE + user_id = $1 AND login_type = $2 +` + +type GetUserLinkByUserIDLoginTypeParams struct { + UserID uuid.UUID `db:"user_id" json:"user_id"` + LoginType LoginType `db:"login_type" json:"login_type"` +} + +func (q *sqlQuerier) GetUserLinkByUserIDLoginType(ctx context.Context, arg GetUserLinkByUserIDLoginTypeParams) (UserLink, error) { + row := q.db.QueryRowContext(ctx, getUserLinkByUserIDLoginType, arg.UserID, arg.LoginType) + var i UserLink + err := row.Scan( + &i.UserID, + &i.LoginType, + &i.LinkedID, + &i.OAuthAccessToken, + &i.OAuthRefreshToken, + &i.OAuthExpiry, + ) + return i, err +} + +const insertUserLink = `-- name: InsertUserLink :one +INSERT INTO + user_links ( + user_id, + login_type, + linked_id, + oauth_access_token, + oauth_refresh_token, + oauth_expiry + ) +VALUES + ( $1, $2, $3, $4, $5, $6 ) RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry +` + +type InsertUserLinkParams struct { + UserID uuid.UUID `db:"user_id" json:"user_id"` + LoginType LoginType `db:"login_type" json:"login_type"` + LinkedID string `db:"linked_id" json:"linked_id"` + OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"` + OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"` + OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"` +} + +func (q *sqlQuerier) InsertUserLink(ctx context.Context, arg InsertUserLinkParams) (UserLink, error) { + row := q.db.QueryRowContext(ctx, insertUserLink, + arg.UserID, + arg.LoginType, + arg.LinkedID, + arg.OAuthAccessToken, + arg.OAuthRefreshToken, + arg.OAuthExpiry, + ) + var i UserLink + err := row.Scan( + &i.UserID, + &i.LoginType, + &i.LinkedID, + &i.OAuthAccessToken, + &i.OAuthRefreshToken, + &i.OAuthExpiry, + ) + return i, err +} + +const updateUserLink = `-- name: UpdateUserLink :one +UPDATE + user_links +SET + oauth_access_token = $1, + oauth_refresh_token = $2, + oauth_expiry = $3 +WHERE + user_id = $4 AND login_type = $5 RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry +` + +type UpdateUserLinkParams struct { + OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"` + OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"` + OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"` + UserID uuid.UUID `db:"user_id" json:"user_id"` + LoginType LoginType `db:"login_type" json:"login_type"` +} + +func (q *sqlQuerier) UpdateUserLink(ctx context.Context, arg UpdateUserLinkParams) (UserLink, error) { + row := q.db.QueryRowContext(ctx, updateUserLink, + arg.OAuthAccessToken, + arg.OAuthRefreshToken, + arg.OAuthExpiry, + arg.UserID, + arg.LoginType, + ) + var i UserLink + err := row.Scan( + &i.UserID, + &i.LoginType, + &i.LinkedID, + &i.OAuthAccessToken, + &i.OAuthRefreshToken, + &i.OAuthExpiry, + ) + return i, err +} + +const updateUserLinkedID = `-- name: UpdateUserLinkedID :one +UPDATE + user_links +SET + linked_id = $1 +WHERE + user_id = $2 AND login_type = $3 RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry +` + +type UpdateUserLinkedIDParams struct { + LinkedID string `db:"linked_id" json:"linked_id"` + UserID uuid.UUID `db:"user_id" json:"user_id"` + LoginType LoginType `db:"login_type" json:"login_type"` +} + +func (q *sqlQuerier) UpdateUserLinkedID(ctx context.Context, arg UpdateUserLinkedIDParams) (UserLink, error) { + row := q.db.QueryRowContext(ctx, updateUserLinkedID, arg.LinkedID, arg.UserID, arg.LoginType) + var i UserLink + err := row.Scan( + &i.UserID, + &i.LoginType, + &i.LinkedID, + &i.OAuthAccessToken, + &i.OAuthRefreshToken, + &i.OAuthExpiry, + ) + return i, err +} + const getAuthorizationUserRoles = `-- name: GetAuthorizationUserRoles :one SELECT -- username is returned just to help for logging purposes @@ -2455,13 +2585,13 @@ SELECT array_append(users.rbac_roles, 'member'), -- All org_members get the org-member role for their orgs array_append(organization_members.roles, 'organization-member:'||organization_members.organization_id::text)) :: text[] - AS roles + AS roles FROM users LEFT JOIN organization_members ON id = user_id WHERE - id = $1 + id = $1 ` type GetAuthorizationUserRolesRow struct { @@ -2487,7 +2617,7 @@ func (q *sqlQuerier) GetAuthorizationUserRoles(ctx context.Context, userID uuid. const getUserByEmailOrUsername = `-- name: GetUserByEmailOrUsername :one SELECT - id, email, username, hashed_password, created_at, updated_at, status, rbac_roles + id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type FROM users WHERE @@ -2514,13 +2644,14 @@ func (q *sqlQuerier) GetUserByEmailOrUsername(ctx context.Context, arg GetUserBy &i.UpdatedAt, &i.Status, pq.Array(&i.RBACRoles), + &i.LoginType, ) return i, err } const getUserByID = `-- name: GetUserByID :one SELECT - id, email, username, hashed_password, created_at, updated_at, status, rbac_roles + id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type FROM users WHERE @@ -2541,6 +2672,7 @@ func (q *sqlQuerier) GetUserByID(ctx context.Context, id uuid.UUID) (User, error &i.UpdatedAt, &i.Status, pq.Array(&i.RBACRoles), + &i.LoginType, ) return i, err } @@ -2561,7 +2693,7 @@ func (q *sqlQuerier) GetUserCount(ctx context.Context) (int64, error) { const getUsers = `-- name: GetUsers :many SELECT - id, email, username, hashed_password, created_at, updated_at, status, rbac_roles + id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type FROM users WHERE @@ -2611,8 +2743,8 @@ WHERE END -- End of filters ORDER BY - -- Deterministic and consistent ordering of all users, even if they share - -- a timestamp. This is to ensure consistent pagination. + -- Deterministic and consistent ordering of all users, even if they share + -- a timestamp. This is to ensure consistent pagination. (created_at, id) ASC OFFSET $5 LIMIT -- A null limit means "no limit", so 0 means return all @@ -2653,6 +2785,7 @@ func (q *sqlQuerier) GetUsers(ctx context.Context, arg GetUsersParams) ([]User, &i.UpdatedAt, &i.Status, pq.Array(&i.RBACRoles), + &i.LoginType, ); err != nil { return nil, err } @@ -2668,7 +2801,7 @@ func (q *sqlQuerier) GetUsers(ctx context.Context, arg GetUsersParams) ([]User, } const getUsersByIDs = `-- name: GetUsersByIDs :many -SELECT id, email, username, hashed_password, created_at, updated_at, status, rbac_roles FROM users WHERE id = ANY($1 :: uuid [ ]) +SELECT id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type FROM users WHERE id = ANY($1 :: uuid [ ]) ` func (q *sqlQuerier) GetUsersByIDs(ctx context.Context, ids []uuid.UUID) ([]User, error) { @@ -2689,6 +2822,7 @@ func (q *sqlQuerier) GetUsersByIDs(ctx context.Context, ids []uuid.UUID) ([]User &i.UpdatedAt, &i.Status, pq.Array(&i.RBACRoles), + &i.LoginType, ); err != nil { return nil, err } @@ -2712,10 +2846,11 @@ INSERT INTO hashed_password, created_at, updated_at, - rbac_roles + rbac_roles, + login_type ) VALUES - ($1, $2, $3, $4, $5, $6, $7) RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles + ($1, $2, $3, $4, $5, $6, $7, $8) RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type ` type InsertUserParams struct { @@ -2726,6 +2861,7 @@ type InsertUserParams struct { CreatedAt time.Time `db:"created_at" json:"created_at"` UpdatedAt time.Time `db:"updated_at" json:"updated_at"` RBACRoles []string `db:"rbac_roles" json:"rbac_roles"` + LoginType LoginType `db:"login_type" json:"login_type"` } func (q *sqlQuerier) InsertUser(ctx context.Context, arg InsertUserParams) (User, error) { @@ -2737,6 +2873,7 @@ func (q *sqlQuerier) InsertUser(ctx context.Context, arg InsertUserParams) (User arg.CreatedAt, arg.UpdatedAt, pq.Array(arg.RBACRoles), + arg.LoginType, ) var i User err := row.Scan( @@ -2748,6 +2885,7 @@ func (q *sqlQuerier) InsertUser(ctx context.Context, arg InsertUserParams) (User &i.UpdatedAt, &i.Status, pq.Array(&i.RBACRoles), + &i.LoginType, ) return i, err } @@ -2779,7 +2917,7 @@ SET username = $3, updated_at = $4 WHERE - id = $1 RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles + id = $1 RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type ` type UpdateUserProfileParams struct { @@ -2806,19 +2944,20 @@ func (q *sqlQuerier) UpdateUserProfile(ctx context.Context, arg UpdateUserProfil &i.UpdatedAt, &i.Status, pq.Array(&i.RBACRoles), + &i.LoginType, ) return i, err } const updateUserRoles = `-- name: UpdateUserRoles :one UPDATE - users + users SET -- Remove all duplicates from the roles. rbac_roles = ARRAY(SELECT DISTINCT UNNEST($1 :: text[])) WHERE - id = $2 -RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles + id = $2 +RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type ` type UpdateUserRolesParams struct { @@ -2838,6 +2977,7 @@ func (q *sqlQuerier) UpdateUserRoles(ctx context.Context, arg UpdateUserRolesPar &i.UpdatedAt, &i.Status, pq.Array(&i.RBACRoles), + &i.LoginType, ) return i, err } @@ -2849,7 +2989,7 @@ SET status = $2, updated_at = $3 WHERE - id = $1 RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles + id = $1 RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type ` type UpdateUserStatusParams struct { @@ -2870,6 +3010,7 @@ func (q *sqlQuerier) UpdateUserStatus(ctx context.Context, arg UpdateUserStatusP &i.UpdatedAt, &i.Status, pq.Array(&i.RBACRoles), + &i.LoginType, ) return i, err } diff --git a/coderd/database/queries/apikeys.sql b/coderd/database/queries/apikeys.sql index 692ac3e69c8a8..22ce2e6057f3e 100644 --- a/coderd/database/queries/apikeys.sql +++ b/coderd/database/queries/apikeys.sql @@ -23,11 +23,7 @@ INSERT INTO expires_at, created_at, updated_at, - login_type, - oauth_access_token, - oauth_refresh_token, - oauth_id_token, - oauth_expiry + login_type ) VALUES (@id, @@ -36,7 +32,7 @@ VALUES WHEN 0 THEN 86400 ELSE @lifetime_seconds::bigint END - , @hashed_secret, @ip_address, @user_id, @last_used, @expires_at, @created_at, @updated_at, @login_type, @oauth_access_token, @oauth_refresh_token, @oauth_id_token, @oauth_expiry) RETURNING *; + , @hashed_secret, @ip_address, @user_id, @last_used, @expires_at, @created_at, @updated_at, @login_type) RETURNING *; -- name: UpdateAPIKeyByID :exec UPDATE @@ -44,10 +40,7 @@ UPDATE SET last_used = $2, expires_at = $3, - ip_address = $4, - oauth_access_token = $5, - oauth_refresh_token = $6, - oauth_expiry = $7 + ip_address = $4 WHERE id = $1; diff --git a/coderd/database/queries/user_links.sql b/coderd/database/queries/user_links.sql new file mode 100644 index 0000000000000..2390cb9782b30 --- /dev/null +++ b/coderd/database/queries/user_links.sql @@ -0,0 +1,46 @@ +-- name: GetUserLinkByLinkedID :one +SELECT + * +FROM + user_links +WHERE + linked_id = $1; + +-- name: GetUserLinkByUserIDLoginType :one +SELECT + * +FROM + user_links +WHERE + user_id = $1 AND login_type = $2; + +-- name: InsertUserLink :one +INSERT INTO + user_links ( + user_id, + login_type, + linked_id, + oauth_access_token, + oauth_refresh_token, + oauth_expiry + ) +VALUES + ( $1, $2, $3, $4, $5, $6 ) RETURNING *; + +-- name: UpdateUserLinkedID :one +UPDATE + user_links +SET + linked_id = $1 +WHERE + user_id = $2 AND login_type = $3 RETURNING *; + +-- name: UpdateUserLink :one +UPDATE + user_links +SET + oauth_access_token = $1, + oauth_refresh_token = $2, + oauth_expiry = $3 +WHERE + user_id = $4 AND login_type = $5 RETURNING *; diff --git a/coderd/database/queries/users.sql b/coderd/database/queries/users.sql index 19fe8a7701744..1d9caa758625e 100644 --- a/coderd/database/queries/users.sql +++ b/coderd/database/queries/users.sql @@ -37,10 +37,11 @@ INSERT INTO hashed_password, created_at, updated_at, - rbac_roles + rbac_roles, + login_type ) VALUES - ($1, $2, $3, $4, $5, $6, $7) RETURNING *; + ($1, $2, $3, $4, $5, $6, $7, $8) RETURNING *; -- name: UpdateUserProfile :one UPDATE @@ -54,12 +55,12 @@ WHERE -- name: UpdateUserRoles :one UPDATE - users + users SET -- Remove all duplicates from the roles. rbac_roles = ARRAY(SELECT DISTINCT UNNEST(@granted_roles :: text[])) WHERE - id = @id + id = @id RETURNING *; -- name: UpdateUserHashedPassword :exec @@ -122,8 +123,8 @@ WHERE END -- End of filters ORDER BY - -- Deterministic and consistent ordering of all users, even if they share - -- a timestamp. This is to ensure consistent pagination. + -- Deterministic and consistent ordering of all users, even if they share + -- a timestamp. This is to ensure consistent pagination. (created_at, id) ASC OFFSET @offset_opt LIMIT -- A null limit means "no limit", so 0 means return all @@ -152,10 +153,10 @@ SELECT array_append(users.rbac_roles, 'member'), -- All org_members get the org-member role for their orgs array_append(organization_members.roles, 'organization-member:'||organization_members.organization_id::text)) :: text[] - AS roles + AS roles FROM users LEFT JOIN organization_members ON id = user_id WHERE - id = @user_id; + id = @user_id; diff --git a/coderd/httpmw/apikey.go b/coderd/httpmw/apikey.go index 80586bc976f49..ed2a1e617f567 100644 --- a/coderd/httpmw/apikey.go +++ b/coderd/httpmw/apikey.go @@ -14,6 +14,7 @@ import ( "golang.org/x/oauth2" + "github.com/google/uuid" "github.com/tabbed/pqtype" "github.com/coder/coder/coderd/database" @@ -149,9 +150,21 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool // Tracks if the API key has properties updated! changed := false + var link database.UserLink if key.LoginType != database.LoginTypePassword { + link, err = db.GetUserLinkByUserIDLoginType(r.Context(), database.GetUserLinkByUserIDLoginTypeParams{ + UserID: key.UserID, + LoginType: key.LoginType, + }) + if err != nil { + write(http.StatusInternalServerError, codersdk.Response{ + Message: "A database error occurred", + Detail: fmt.Sprintf("get user link by user ID and login type: %s", err.Error()), + }) + return + } // Check if the OAuth token is expired! - if key.OAuthExpiry.Before(now) && !key.OAuthExpiry.IsZero() { + if link.OAuthExpiry.Before(now) && !link.OAuthExpiry.IsZero() { var oauthConfig OAuth2Config switch key.LoginType { case database.LoginTypeGithub: @@ -167,9 +180,9 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool } // If it is, let's refresh it from the provided config! token, err := oauthConfig.TokenSource(r.Context(), &oauth2.Token{ - AccessToken: key.OAuthAccessToken, - RefreshToken: key.OAuthRefreshToken, - Expiry: key.OAuthExpiry, + AccessToken: link.OAuthAccessToken, + RefreshToken: link.OAuthRefreshToken, + Expiry: link.OAuthExpiry, }).Token() if err != nil { write(http.StatusUnauthorized, codersdk.Response{ @@ -178,9 +191,9 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool }) return } - key.OAuthAccessToken = token.AccessToken - key.OAuthRefreshToken = token.RefreshToken - key.OAuthExpiry = token.Expiry + link.OAuthAccessToken = token.AccessToken + link.OAuthRefreshToken = token.RefreshToken + link.OAuthExpiry = token.Expiry key.ExpiresAt = token.Expiry changed = true } @@ -222,13 +235,10 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool } if changed { err := db.UpdateAPIKeyByID(r.Context(), database.UpdateAPIKeyByIDParams{ - ID: key.ID, - LastUsed: key.LastUsed, - ExpiresAt: key.ExpiresAt, - IPAddress: key.IPAddress, - OAuthAccessToken: key.OAuthAccessToken, - OAuthRefreshToken: key.OAuthRefreshToken, - OAuthExpiry: key.OAuthExpiry, + ID: key.ID, + LastUsed: key.LastUsed, + ExpiresAt: key.ExpiresAt, + IPAddress: key.IPAddress, }) if err != nil { write(http.StatusInternalServerError, codersdk.Response{ @@ -237,6 +247,24 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool }) return } + // If the API Key is associated with a user_link (e.g. Github/OIDC) + // then we want to update the relevant oauth fields. + if link.UserID != uuid.Nil { + link, err = db.UpdateUserLink(r.Context(), database.UpdateUserLinkParams{ + UserID: link.UserID, + LoginType: link.LoginType, + OAuthAccessToken: link.OAuthAccessToken, + OAuthRefreshToken: link.OAuthRefreshToken, + OAuthExpiry: link.OAuthExpiry, + }) + if err != nil { + write(http.StatusInternalServerError, codersdk.Response{ + Message: internalErrorMessage, + Detail: fmt.Sprintf("update user_link: %s.", err.Error()), + }) + return + } + } } // If the key is valid, we also fetch the user roles and status. diff --git a/coderd/httpmw/apikey_test.go b/coderd/httpmw/apikey_test.go index 0d29c84653c6f..adc13b2f176fa 100644 --- a/coderd/httpmw/apikey_test.go +++ b/coderd/httpmw/apikey_test.go @@ -187,6 +187,7 @@ func TestAPIKey(t *testing.T) { ID: id, HashedSecret: hashed[:], UserID: user.ID, + LoginType: database.LoginTypePassword, }) require.NoError(t, err) httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r) @@ -215,6 +216,7 @@ func TestAPIKey(t *testing.T) { HashedSecret: hashed[:], ExpiresAt: database.Now().AddDate(0, 0, 1), UserID: user.ID, + LoginType: database.LoginTypePassword, }) require.NoError(t, err) httpmw.ExtractAPIKey(db, nil, false)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { @@ -253,6 +255,7 @@ func TestAPIKey(t *testing.T) { HashedSecret: hashed[:], ExpiresAt: database.Now().AddDate(0, 0, 1), UserID: user.ID, + LoginType: database.LoginTypePassword, }) require.NoError(t, err) httpmw.ExtractAPIKey(db, nil, false)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { @@ -288,6 +291,7 @@ func TestAPIKey(t *testing.T) { LastUsed: database.Now().AddDate(0, 0, -1), ExpiresAt: database.Now().AddDate(0, 0, 1), UserID: user.ID, + LoginType: database.LoginTypePassword, }) require.NoError(t, err) httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r) @@ -323,6 +327,7 @@ func TestAPIKey(t *testing.T) { LastUsed: database.Now(), ExpiresAt: database.Now().Add(time.Minute), UserID: user.ID, + LoginType: database.LoginTypePassword, }) require.NoError(t, err) httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r) @@ -361,6 +366,13 @@ func TestAPIKey(t *testing.T) { UserID: user.ID, }) require.NoError(t, err) + + _, err = db.InsertUserLink(r.Context(), database.InsertUserLinkParams{ + UserID: user.ID, + LoginType: database.LoginTypeGithub, + }) + require.NoError(t, err) + httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r) res := rw.Result() defer res.Body.Close() @@ -393,10 +405,16 @@ func TestAPIKey(t *testing.T) { HashedSecret: hashed[:], LoginType: database.LoginTypeGithub, LastUsed: database.Now(), - OAuthExpiry: database.Now().AddDate(0, 0, -1), UserID: user.ID, }) require.NoError(t, err) + _, err = db.InsertUserLink(r.Context(), database.InsertUserLinkParams{ + UserID: user.ID, + LoginType: database.LoginTypeGithub, + OAuthExpiry: database.Now().AddDate(0, 0, -1), + }) + require.NoError(t, err) + token := &oauth2.Token{ AccessToken: "wow", RefreshToken: "moo", @@ -418,7 +436,6 @@ func TestAPIKey(t *testing.T) { require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed) require.Equal(t, token.Expiry, gotAPIKey.ExpiresAt) - require.Equal(t, token.AccessToken, gotAPIKey.OAuthAccessToken) }) t.Run("RemoteIPUpdates", func(t *testing.T) { @@ -443,6 +460,7 @@ func TestAPIKey(t *testing.T) { LastUsed: database.Now().AddDate(0, 0, -1), ExpiresAt: database.Now().AddDate(0, 0, 1), UserID: user.ID, + LoginType: database.LoginTypePassword, }) require.NoError(t, err) httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r) diff --git a/coderd/httpmw/authorize_test.go b/coderd/httpmw/authorize_test.go index 2077926e6f989..14ab9611066d0 100644 --- a/coderd/httpmw/authorize_test.go +++ b/coderd/httpmw/authorize_test.go @@ -124,6 +124,7 @@ func addUser(t *testing.T, db database.Store, roles ...string) (database.User, s HashedSecret: hashed[:], LastUsed: database.Now(), ExpiresAt: database.Now().Add(time.Minute), + LoginType: database.LoginTypePassword, }) require.NoError(t, err) diff --git a/coderd/httpmw/organizationparam_test.go b/coderd/httpmw/organizationparam_test.go index d17c441741914..bdf442391091c 100644 --- a/coderd/httpmw/organizationparam_test.go +++ b/coderd/httpmw/organizationparam_test.go @@ -53,6 +53,7 @@ func TestOrganizationParam(t *testing.T) { HashedSecret: hashed[:], LastUsed: database.Now(), ExpiresAt: database.Now().Add(time.Minute), + LoginType: database.LoginTypePassword, }) require.NoError(t, err) r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, chi.NewRouteContext())) diff --git a/coderd/httpmw/templateparam_test.go b/coderd/httpmw/templateparam_test.go index 94abfe82cf5fb..7da35889aa536 100644 --- a/coderd/httpmw/templateparam_test.go +++ b/coderd/httpmw/templateparam_test.go @@ -53,6 +53,7 @@ func TestTemplateParam(t *testing.T) { HashedSecret: hashed[:], LastUsed: database.Now(), ExpiresAt: database.Now().Add(time.Minute), + LoginType: database.LoginTypePassword, }) require.NoError(t, err) diff --git a/coderd/httpmw/templateversionparam_test.go b/coderd/httpmw/templateversionparam_test.go index 5b49f75010bf9..fe4ebba9dfcc3 100644 --- a/coderd/httpmw/templateversionparam_test.go +++ b/coderd/httpmw/templateversionparam_test.go @@ -53,6 +53,7 @@ func TestTemplateVersionParam(t *testing.T) { HashedSecret: hashed[:], LastUsed: database.Now(), ExpiresAt: database.Now().Add(time.Minute), + LoginType: database.LoginTypePassword, }) require.NoError(t, err) diff --git a/coderd/httpmw/userparam_test.go b/coderd/httpmw/userparam_test.go index 866df68ef1eec..d21f2a583d58f 100644 --- a/coderd/httpmw/userparam_test.go +++ b/coderd/httpmw/userparam_test.go @@ -47,6 +47,7 @@ func TestUserParam(t *testing.T) { HashedSecret: hashed[:], LastUsed: database.Now(), ExpiresAt: database.Now().Add(time.Minute), + LoginType: database.LoginTypePassword, }) require.NoError(t, err) diff --git a/coderd/httpmw/workspaceagentparam_test.go b/coderd/httpmw/workspaceagentparam_test.go index a2afaee534c9f..c2d047fdda983 100644 --- a/coderd/httpmw/workspaceagentparam_test.go +++ b/coderd/httpmw/workspaceagentparam_test.go @@ -53,6 +53,7 @@ func TestWorkspaceAgentParam(t *testing.T) { HashedSecret: hashed[:], LastUsed: database.Now(), ExpiresAt: database.Now().Add(time.Minute), + LoginType: database.LoginTypePassword, }) require.NoError(t, err) diff --git a/coderd/httpmw/workspacebuildparam_test.go b/coderd/httpmw/workspacebuildparam_test.go index 6d402f01fc62b..c993a0cb66fb6 100644 --- a/coderd/httpmw/workspacebuildparam_test.go +++ b/coderd/httpmw/workspacebuildparam_test.go @@ -53,6 +53,7 @@ func TestWorkspaceBuildParam(t *testing.T) { HashedSecret: hashed[:], LastUsed: database.Now(), ExpiresAt: database.Now().Add(time.Minute), + LoginType: database.LoginTypePassword, }) require.NoError(t, err) diff --git a/coderd/httpmw/workspaceparam_test.go b/coderd/httpmw/workspaceparam_test.go index eac847a584f3b..ecbebf13c5a4c 100644 --- a/coderd/httpmw/workspaceparam_test.go +++ b/coderd/httpmw/workspaceparam_test.go @@ -53,6 +53,7 @@ func TestWorkspaceParam(t *testing.T) { HashedSecret: hashed[:], LastUsed: database.Now(), ExpiresAt: database.Now().Add(time.Minute), + LoginType: database.LoginTypePassword, }) require.NoError(t, err) diff --git a/coderd/provisionerjobs_internal_test.go b/coderd/provisionerjobs_internal_test.go index 44bdb6d7fd9e7..da593e68f91ce 100644 --- a/coderd/provisionerjobs_internal_test.go +++ b/coderd/provisionerjobs_internal_test.go @@ -73,6 +73,7 @@ func TestProvisionerJobLogs_Unit(t *testing.T) { HashedSecret: hashed[:], UserID: userID, ExpiresAt: time.Now().Add(5 * time.Hour), + LoginType: database.LoginTypePassword, }) require.NoError(t, err) _, err = fDB.InsertUser(ctx, database.InsertUserParams{ diff --git a/coderd/userauth.go b/coderd/userauth.go index 0ddb3f34d7a21..5e7654cd6a08f 100644 --- a/coderd/userauth.go +++ b/coderd/userauth.go @@ -6,12 +6,14 @@ import ( "errors" "fmt" "net/http" + "strconv" "strings" "github.com/coreos/go-oidc/v3/oidc" "github.com/google/go-github/v43/github" "github.com/google/uuid" "golang.org/x/oauth2" + "golang.org/x/xerrors" "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/httpapi" @@ -47,10 +49,13 @@ func (api *API) userAuthMethods(rw http.ResponseWriter, _ *http.Request) { } func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) { - state := httpmw.OAuth2(r) + var ( + ctx = r.Context() + state = httpmw.OAuth2(r) + ) - oauthClient := oauth2.NewClient(r.Context(), oauth2.StaticTokenSource(state.Token)) - memberships, err := api.GithubOAuth2Config.ListOrganizationMemberships(r.Context(), oauthClient) + oauthClient := oauth2.NewClient(ctx, oauth2.StaticTokenSource(state.Token)) + memberships, err := api.GithubOAuth2Config.ListOrganizationMemberships(ctx, oauthClient) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching authenticated Github user organizations.", @@ -75,7 +80,7 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) { return } - ghUser, err := api.GithubOAuth2Config.AuthenticatedUser(r.Context(), oauthClient) + ghUser, err := api.GithubOAuth2Config.AuthenticatedUser(ctx, oauthClient) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching authenticated Github user.", @@ -94,7 +99,7 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) { continue } - allowedTeam, err = api.GithubOAuth2Config.TeamMembership(r.Context(), oauthClient, allowTeam.Organization, allowTeam.Slug, *ghUser.Login) + allowedTeam, err = api.GithubOAuth2Config.TeamMembership(ctx, oauthClient, allowTeam.Organization, allowTeam.Slug, *ghUser.Login) // The calling user may not have permission to the requested team! if err != nil { continue @@ -108,7 +113,7 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) { } } - emails, err := api.GithubOAuth2Config.ListEmails(r.Context(), oauthClient) + emails, err := api.GithubOAuth2Config.ListEmails(ctx, oauthClient) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching personal Github user.", @@ -117,33 +122,35 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) { return } - var user database.User - // Search for existing users with matching and verified emails. - // If a verified GitHub email matches a Coder user, we will return. + verifiedEmails := make([]string, 0, len(emails)) for _, email := range emails { if !email.GetVerified() { continue } - user, err = api.Database.GetUserByEmailOrUsername(r.Context(), database.GetUserByEmailOrUsernameParams{ - Email: *email.Email, + verifiedEmails = append(verifiedEmails, email.GetEmail()) + } + + if len(verifiedEmails) == 0 { + httpapi.Write(rw, http.StatusForbidden, codersdk.Response{ + Message: "Verify an email address on Github to authenticate!", }) - if errors.Is(err, sql.ErrNoRows) { - continue - } - if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ - Message: fmt.Sprintf("Internal error fetching user by email %q.", *email.Email), - Detail: err.Error(), - }) - return - } - if !*email.Verified { - httpapi.Write(rw, http.StatusForbidden, codersdk.Response{ - Message: fmt.Sprintf("Verify the %q email address on Github to authenticate!", *email.Email), - }) - return - } - break + return + } + + user, link, err := findLinkedUser(ctx, api.Database, githubLinkedID(ghUser), verifiedEmails...) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ + Message: "An internal error occurred.", + Detail: err.Error(), + }) + return + } + + if user.ID != uuid.Nil && user.LoginType != database.LoginTypeGithub { + httpapi.Write(rw, http.StatusForbidden, codersdk.Response{ + Message: fmt.Sprintf("Incorrect login type, attempting to use %q but user is of login type %q", database.LoginTypeGithub, user.LoginType), + }) + return } // If the user doesn't exist, create a new one! @@ -177,10 +184,13 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) { }) return } - user, _, err = api.createUser(r.Context(), codersdk.CreateUserRequest{ - Email: *verifiedEmail.Email, - Username: *ghUser.Login, - OrganizationID: organizationID, + user, _, err = api.createUser(ctx, createUserRequest{ + CreateUserRequest: codersdk.CreateUserRequest{ + Email: *verifiedEmail.Email, + Username: *ghUser.Login, + OrganizationID: organizationID, + }, + LoginType: database.LoginTypeGithub, }) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ @@ -191,12 +201,49 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) { } } - _, created := api.createAPIKey(rw, r, database.InsertAPIKeyParams{ - UserID: user.ID, - LoginType: database.LoginTypeGithub, - OAuthAccessToken: state.Token.AccessToken, - OAuthRefreshToken: state.Token.RefreshToken, - OAuthExpiry: state.Token.Expiry, + // This can happen if a user is a built-in user but is signing in + // with Github for the first time. + if link.UserID == uuid.Nil { + link, err = api.Database.InsertUserLink(ctx, database.InsertUserLinkParams{ + UserID: user.ID, + LoginType: database.LoginTypeGithub, + LinkedID: githubLinkedID(ghUser), + OAuthAccessToken: state.Token.AccessToken, + OAuthRefreshToken: state.Token.RefreshToken, + OAuthExpiry: state.Token.Expiry, + }) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ + Message: "A database error occurred.", + Detail: fmt.Sprintf("insert user link: %s", err.Error()), + }) + return + } + } + + // LEGACY: Remove 10/2022. + // We started tracking linked IDs later so it's possible for a user to be a + // pre-existing Github user and not have a linked ID. The migration + // to user_links did not populate this field as it requires calling out + // to Github to query the user's ID. + if link.LinkedID == "" { + link, err = api.Database.UpdateUserLinkedID(ctx, database.UpdateUserLinkedIDParams{ + UserID: user.ID, + LinkedID: githubLinkedID(ghUser), + LoginType: database.LoginTypeGithub, + }) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ + Message: "A database error occurred.", + Detail: xerrors.Errorf("update user link: %w", err.Error).Error(), + }) + return + } + } + + _, created := api.createAPIKey(rw, r, createAPIKeyParams{ + UserID: user.ID, + LoginType: database.LoginTypeGithub, }) if !created { return @@ -219,7 +266,10 @@ type OIDCConfig struct { } func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) { - state := httpmw.OAuth2(r) + var ( + ctx = r.Context() + state = httpmw.OAuth2(r) + ) // See the example here: https://github.com/coreos/go-oidc rawIDToken, ok := state.Token.Extra("id_token").(string) @@ -230,7 +280,7 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) { return } - idToken, err := api.OIDCConfig.Verifier.Verify(r.Context(), rawIDToken) + idToken, err := api.OIDCConfig.Verifier.Verify(ctx, rawIDToken) if err != nil { httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{ Message: "Failed to verify OIDC token.", @@ -285,29 +335,48 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) { } } - var user database.User - user, err = api.Database.GetUserByEmailOrUsername(r.Context(), database.GetUserByEmailOrUsernameParams{ - Email: claims.Email, - }) - if errors.Is(err, sql.ErrNoRows) { - if !api.OIDCConfig.AllowSignups { - httpapi.Write(rw, http.StatusForbidden, codersdk.Response{ - Message: "Signups are disabled for OIDC authentication!", - }) - return - } + user, link, err := findLinkedUser(ctx, api.Database, oidcLinkedID(idToken), claims.Email) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to find user.", + Detail: err.Error(), + }) + return + } + + if user.ID == uuid.Nil && !api.OIDCConfig.AllowSignups { + httpapi.Write(rw, http.StatusForbidden, codersdk.Response{ + Message: "Signups are disabled for OIDC authentication!", + }) + return + } + + if user.ID != uuid.Nil && user.LoginType != database.LoginTypeOIDC { + httpapi.Write(rw, http.StatusForbidden, codersdk.Response{ + Message: fmt.Sprintf("Incorrect login type, attempting to use %q but user is of login type %q", database.LoginTypeOIDC, user.LoginType), + }) + return + } + + // This can happen if a user is a built-in user but is signing in + // with OIDC for the first time. + if user.ID == uuid.Nil { var organizationID uuid.UUID - organizations, _ := api.Database.GetOrganizations(r.Context()) + organizations, _ := api.Database.GetOrganizations(ctx) if len(organizations) > 0 { // Add the user to the first organization. Once multi-organization // support is added, we should enable a configuration map of user // email to organization. organizationID = organizations[0].ID } - user, _, err = api.createUser(r.Context(), codersdk.CreateUserRequest{ - Email: claims.Email, - Username: claims.Username, - OrganizationID: organizationID, + + user, _, err = api.createUser(ctx, createUserRequest{ + CreateUserRequest: codersdk.CreateUserRequest{ + Email: claims.Email, + Username: claims.Username, + OrganizationID: organizationID, + }, + LoginType: database.LoginTypeOIDC, }) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ @@ -316,21 +385,81 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) { }) return } + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to insert user auth metadata.", + Detail: err.Error(), + }) + return + } } - if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Failed to get user by email.", - Detail: err.Error(), + + if link.UserID == uuid.Nil { + link, err = api.Database.InsertUserLink(ctx, database.InsertUserLinkParams{ + UserID: user.ID, + LoginType: database.LoginTypeOIDC, + LinkedID: oidcLinkedID(idToken), + OAuthAccessToken: state.Token.AccessToken, + OAuthRefreshToken: state.Token.RefreshToken, + OAuthExpiry: state.Token.Expiry, }) - return + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ + Message: "A database error occurred.", + Detail: fmt.Sprintf("insert user link: %s", err.Error()), + }) + return + } + } + + // LEGACY: Remove 10/2022. + // We started tracking linked IDs later so it's possible for a user to be a + // pre-existing OIDC user and not have a linked ID. + // The migration that added the user_links table could not populate + // the 'linked_id' field since it requires fields off the access token. + if link.LinkedID == "" { + link, err = api.Database.UpdateUserLinkedID(ctx, database.UpdateUserLinkedIDParams{ + UserID: user.ID, + LinkedID: oidcLinkedID(idToken), + LoginType: database.LoginTypeGithub, + }) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ + Message: "A database error occurred.", + Detail: xerrors.Errorf("update user link: %w", err.Error).Error(), + }) + return + } } - _, created := api.createAPIKey(rw, r, database.InsertAPIKeyParams{ - UserID: user.ID, - LoginType: database.LoginTypeOIDC, - OAuthAccessToken: state.Token.AccessToken, - OAuthRefreshToken: state.Token.RefreshToken, - OAuthExpiry: state.Token.Expiry, + // If the upstream email or username has changed we should mirror + // that in Coder. Many enterprises use a user's email/username as + // security auditing fields so they need to stay synced. + if user.Email != claims.Email || user.Username != claims.Username { + // TODO(JonA): Since we're processing updates to a user's upstream + // email/username, it's possible for a different built-in user to + // have already claimed the username. + // In such cases in the current implementation this user can now no + // longer sign in until an administrator finds the offending built-in + // user and changes their username. + user, err = api.Database.UpdateUserProfile(ctx, database.UpdateUserProfileParams{ + ID: user.ID, + Email: claims.Email, + Username: claims.Username, + UpdatedAt: database.Now(), + }) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to update user profile.", + Detail: fmt.Sprintf("update user profile: %s", err.Error()), + }) + return + } + } + + _, created := api.createAPIKey(rw, r, createAPIKeyParams{ + UserID: user.ID, + LoginType: database.LoginTypeOIDC, }) if !created { return @@ -342,3 +471,66 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) { } http.Redirect(rw, r, redirect, http.StatusTemporaryRedirect) } + +// githubLinkedID returns the unique ID for a GitHub user. +func githubLinkedID(u *github.User) string { + return strconv.FormatInt(u.GetID(), 10) +} + +// oidcLinkedID returns the uniqued ID for an OIDC user. +// See https://openid.net/specs/openid-connect-core-1_0.html#ClaimStability . +func oidcLinkedID(tok *oidc.IDToken) string { + return strings.Join([]string{tok.Issuer, tok.Subject}, "||") +} + +// findLinkedUser tries to find a user by their unique OAuth-linked ID. +// If it doesn't not find it, it returns the user by their email. +func findLinkedUser(ctx context.Context, db database.Store, linkedID string, emails ...string) (database.User, database.UserLink, error) { + var ( + user database.User + link database.UserLink + ) + link, err := db.GetUserLinkByLinkedID(ctx, linkedID) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return user, link, xerrors.Errorf("get user auth by linked ID: %w", err) + } + + if err == nil { + user, err = db.GetUserByID(ctx, link.UserID) + if err != nil { + return database.User{}, database.UserLink{}, xerrors.Errorf("get user by id: %w", err) + } + return user, link, nil + } + + for _, email := range emails { + user, err = db.GetUserByEmailOrUsername(ctx, database.GetUserByEmailOrUsernameParams{ + Email: email, + }) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return user, link, xerrors.Errorf("get user by email: %w", err) + } + if errors.Is(err, sql.ErrNoRows) { + continue + } + break + } + + if user.ID == uuid.Nil { + // No user found. + return database.User{}, database.UserLink{}, nil + } + + // LEGACY: This is annoying but we have to search for the user_link + // again except this time we search by user_id and login_type. It's + // possible that a user_link exists without a populated 'linked_id'. + link, err = db.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{ + UserID: user.ID, + LoginType: user.LoginType, + }) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return database.User{}, database.UserLink{}, xerrors.Errorf("get user link by user id and login type: %w", err) + } + + return user, link, nil +} diff --git a/coderd/userauth_test.go b/coderd/userauth_test.go index 6d4c6af34bd30..25fb476673c71 100644 --- a/coderd/userauth_test.go +++ b/coderd/userauth_test.go @@ -175,13 +175,12 @@ func TestUserOAuth2Github(t *testing.T) { resp := oauth2Callback(t, client) require.Equal(t, http.StatusForbidden, resp.StatusCode) }) - t.Run("Signup", func(t *testing.T) { + t.Run("MultiLoginNotAllowed", func(t *testing.T) { t.Parallel() client := coderdtest.New(t, &coderdtest.Options{ GithubOAuth2Config: &coderd.GithubOAuth2Config{ OAuth2Config: &oauth2Config{}, AllowOrganizations: []string{"coder"}, - AllowSignups: true, ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) { return []*github.Membership{{ Organization: &github.Organization{ @@ -190,28 +189,30 @@ func TestUserOAuth2Github(t *testing.T) { }}, nil }, AuthenticatedUser: func(ctx context.Context, client *http.Client) (*github.User, error) { - return &github.User{ - Login: github.String("kyle"), - }, nil + return &github.User{}, nil }, ListEmails: func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error) { return []*github.UserEmail{{ - Email: github.String("kyle@coder.com"), + Email: github.String("testuser@coder.com"), Verified: github.Bool(true), - Primary: github.Bool(true), }}, nil }, }, }) + // Creates the first user with login_type 'password'. + _ = coderdtest.CreateFirstUser(t, client) + // Attempting to login should give us a 403 since the user + // already has a login_type of 'password'. resp := oauth2Callback(t, client) - require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) + require.Equal(t, http.StatusForbidden, resp.StatusCode) }) - t.Run("Login", func(t *testing.T) { + t.Run("Signup", func(t *testing.T) { t.Parallel() client := coderdtest.New(t, &coderdtest.Options{ GithubOAuth2Config: &coderd.GithubOAuth2Config{ OAuth2Config: &oauth2Config{}, AllowOrganizations: []string{"coder"}, + AllowSignups: true, ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) { return []*github.Membership{{ Organization: &github.Organization{ @@ -220,19 +221,28 @@ func TestUserOAuth2Github(t *testing.T) { }}, nil }, AuthenticatedUser: func(ctx context.Context, client *http.Client) (*github.User, error) { - return &github.User{}, nil + return &github.User{ + Login: github.String("kyle"), + ID: i64ptr(1234), + }, nil }, ListEmails: func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error) { return []*github.UserEmail{{ - Email: github.String("testuser@coder.com"), + Email: github.String("kyle@coder.com"), Verified: github.Bool(true), + Primary: github.Bool(true), }}, nil }, }, }) - _ = coderdtest.CreateFirstUser(t, client) resp := oauth2Callback(t, client) require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) + + client.SessionToken = resp.Cookies()[0].Value + user, err := client.User(context.Background(), "me") + require.NoError(t, err) + require.Equal(t, "kyle@coder.com", user.Email) + require.Equal(t, "kyle", user.Username) }) t.Run("SignupAllowedTeam", func(t *testing.T) { t.Parallel() @@ -415,11 +425,13 @@ func createOIDCConfig(t *testing.T, claims jwt.MapClaims) *coderd.OIDCConfig { // https://datatracker.ietf.org/doc/html/rfc7519#section-4.1 claims["exp"] = time.Now().Add(time.Hour).UnixMilli() + claims["iss"] = "https://coder.com" + claims["sub"] = "hello" signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(key) require.NoError(t, err) - verifier := oidc.NewVerifier("", &oidc.StaticKeySet{ + verifier := oidc.NewVerifier("https://coder.com", &oidc.StaticKeySet{ PublicKeys: []crypto.PublicKey{key.Public()}, }, &oidc.Config{ SkipClientIDCheck: true, @@ -480,3 +492,7 @@ func oidcCallback(t *testing.T, client *codersdk.Client) *http.Response { t.Log(string(data)) return res } + +func i64ptr(i int64) *int64 { + return &i +} diff --git a/coderd/users.go b/coderd/users.go index 9d3297eee68b3..e77364eeba6ef 100644 --- a/coderd/users.go +++ b/coderd/users.go @@ -77,10 +77,13 @@ func (api *API) postFirstUser(rw http.ResponseWriter, r *http.Request) { return } - user, organizationID, err := api.createUser(r.Context(), codersdk.CreateUserRequest{ - Email: createUser.Email, - Username: createUser.Username, - Password: createUser.Password, + user, organizationID, err := api.createUser(r.Context(), createUserRequest{ + CreateUserRequest: codersdk.CreateUserRequest{ + Email: createUser.Email, + Username: createUser.Username, + Password: createUser.Password, + }, + LoginType: database.LoginTypePassword, }) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ @@ -196,14 +199,14 @@ func (api *API) postUser(rw http.ResponseWriter, r *http.Request) { return } - var createUser codersdk.CreateUserRequest - if !httpapi.Read(rw, r, &createUser) { + var req codersdk.CreateUserRequest + if !httpapi.Read(rw, r, &req) { return } // Create the organization member in the org. if !api.Authorize(r, rbac.ActionCreate, - rbac.ResourceOrganizationMember.InOrg(createUser.OrganizationID)) { + rbac.ResourceOrganizationMember.InOrg(req.OrganizationID)) { httpapi.ResourceNotFound(rw) return } @@ -211,8 +214,8 @@ func (api *API) postUser(rw http.ResponseWriter, r *http.Request) { // TODO: @emyrk Authorize the organization create if the createUser will do that. _, err := api.Database.GetUserByEmailOrUsername(r.Context(), database.GetUserByEmailOrUsernameParams{ - Username: createUser.Username, - Email: createUser.Email, + Username: req.Username, + Email: req.Email, }) if err == nil { httpapi.Write(rw, http.StatusConflict, codersdk.Response{ @@ -228,10 +231,10 @@ func (api *API) postUser(rw http.ResponseWriter, r *http.Request) { return } - _, err = api.Database.GetOrganizationByID(r.Context(), createUser.OrganizationID) + _, err = api.Database.GetOrganizationByID(r.Context(), req.OrganizationID) if errors.Is(err, sql.ErrNoRows) { httpapi.Write(rw, http.StatusNotFound, codersdk.Response{ - Message: fmt.Sprintf("Organization does not exist with the provided id %q.", createUser.OrganizationID), + Message: fmt.Sprintf("Organization does not exist with the provided id %q.", req.OrganizationID), }) return } @@ -243,7 +246,10 @@ func (api *API) postUser(rw http.ResponseWriter, r *http.Request) { return } - user, _, err := api.createUser(r.Context(), createUser) + user, _, err := api.createUser(r.Context(), createUserRequest{ + CreateUserRequest: req, + LoginType: database.LoginTypePassword, + }) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error creating user.", @@ -257,7 +263,7 @@ func (api *API) postUser(rw http.ResponseWriter, r *http.Request) { Users: []telemetry.User{telemetry.ConvertUser(user)}, }) - httpapi.Write(rw, http.StatusCreated, convertUser(user, []uuid.UUID{createUser.OrganizationID})) + httpapi.Write(rw, http.StatusCreated, convertUser(user, []uuid.UUID{req.OrganizationID})) } // Returns the parameterized user requested. All validation @@ -701,6 +707,13 @@ func (api *API) postLogin(rw http.ResponseWriter, r *http.Request) { return } + if user.LoginType != database.LoginTypePassword { + httpapi.Write(rw, http.StatusForbidden, codersdk.Response{ + Message: fmt.Sprintf("Incorrect login type, attempting to use %q but user is of login type %q", database.LoginTypePassword, user.LoginType), + }) + return + } + // If the user logged into a suspended account, reject the login request. if user.Status != database.UserStatusActive { httpapi.Write(rw, http.StatusUnauthorized, codersdk.Response{ @@ -709,7 +722,7 @@ func (api *API) postLogin(rw http.ResponseWriter, r *http.Request) { return } - sessionToken, created := api.createAPIKey(rw, r, database.InsertAPIKeyParams{ + sessionToken, created := api.createAPIKey(rw, r, createAPIKeyParams{ UserID: user.ID, LoginType: database.LoginTypePassword, }) @@ -732,7 +745,7 @@ func (api *API) postAPIKey(rw http.ResponseWriter, r *http.Request) { } lifeTime := time.Hour * 24 * 7 - sessionToken, created := api.createAPIKey(rw, r, database.InsertAPIKeyParams{ + sessionToken, created := api.createAPIKey(rw, r, createAPIKeyParams{ UserID: user.ID, LoginType: database.LoginTypePassword, // All api generated keys will last 1 week. Browser login tokens have @@ -818,7 +831,16 @@ func generateAPIKeyIDSecret() (id string, secret string, err error) { return id, secret, nil } -func (api *API) createAPIKey(rw http.ResponseWriter, r *http.Request, params database.InsertAPIKeyParams) (string, bool) { +type createAPIKeyParams struct { + UserID uuid.UUID + LoginType database.LoginType + + // Optional. + ExpiresAt time.Time + LifetimeSeconds int64 +} + +func (api *API) createAPIKey(rw http.ResponseWriter, r *http.Request, params createAPIKeyParams) (string, bool) { keyID, keySecret, err := generateAPIKeyIDSecret() if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ @@ -856,15 +878,11 @@ func (api *API) createAPIKey(rw http.ResponseWriter, r *http.Request, params dat Valid: true, }, // Make sure in UTC time for common time zone - ExpiresAt: params.ExpiresAt.UTC(), - CreatedAt: database.Now(), - UpdatedAt: database.Now(), - HashedSecret: hashed[:], - LoginType: params.LoginType, - OAuthAccessToken: params.OAuthAccessToken, - OAuthRefreshToken: params.OAuthRefreshToken, - OAuthIDToken: params.OAuthIDToken, - OAuthExpiry: params.OAuthExpiry, + ExpiresAt: params.ExpiresAt.UTC(), + CreatedAt: database.Now(), + UpdatedAt: database.Now(), + HashedSecret: hashed[:], + LoginType: params.LoginType, }) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ @@ -891,7 +909,12 @@ func (api *API) createAPIKey(rw http.ResponseWriter, r *http.Request, params dat return sessionToken, true } -func (api *API) createUser(ctx context.Context, req codersdk.CreateUserRequest) (database.User, uuid.UUID, error) { +type createUserRequest struct { + codersdk.CreateUserRequest + LoginType database.LoginType +} + +func (api *API) createUser(ctx context.Context, req createUserRequest) (database.User, uuid.UUID, error) { var user database.User return user, req.OrganizationID, api.Database.InTx(func(db database.Store) error { orgRoles := make([]string, 0) @@ -918,6 +941,7 @@ func (api *API) createUser(ctx context.Context, req codersdk.CreateUserRequest) UpdatedAt: database.Now(), // All new users are defaulted to members of the site. RBACRoles: []string{}, + LoginType: req.LoginType, } // If a user signs up with OAuth, they can have no password! if req.Password != "" { diff --git a/codersdk/users.go b/codersdk/users.go index 17252c20405c3..72b51306a5bb8 100644 --- a/codersdk/users.go +++ b/codersdk/users.go @@ -27,6 +27,7 @@ type LoginType string const ( LoginTypePassword LoginType = "password" LoginTypeGithub LoginType = "github" + LoginTypeOIDC LoginType = "oidc" ) type UsersRequest struct { diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index e849b1a1f051f..b42497dec88f4 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -542,7 +542,7 @@ export type LogLevel = "debug" | "error" | "info" | "trace" | "warn" export type LogSource = "provisioner" | "provisioner_daemon" // From codersdk/users.go -export type LoginType = "github" | "password" +export type LoginType = "github" | "oidc" | "password" // From codersdk/parameters.go export type ParameterDestinationScheme = "environment_variable" | "none" | "provisioner_variable"