diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 6156329cf7ddd..55a7eb8acefab 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -664,6 +664,15 @@ func (q *querier) ActivityBumpWorkspace(ctx context.Context, arg uuid.UUID) erro return update(q.log, q.auth, fetch, q.db.ActivityBumpWorkspace)(ctx, arg) } +func (q *querier) AllUserIDs(ctx context.Context) ([]uuid.UUID, error) { + // Although this technically only reads users, only system-related functions should be + // allowed to call this. + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { + return nil, err + } + return q.db.AllUserIDs(ctx) +} + func (q *querier) CleanTailnetCoordinators(ctx context.Context) error { if err := q.authorizeContext(ctx, rbac.ActionDelete, rbac.ResourceTailnetCoordinator); err != nil { return err diff --git a/coderd/database/dbfake/dbfake.go b/coderd/database/dbfake/dbfake.go index ab7363b275a2f..9ac8ed640af2f 100644 --- a/coderd/database/dbfake/dbfake.go +++ b/coderd/database/dbfake/dbfake.go @@ -812,6 +812,16 @@ func (q *FakeQuerier) ActivityBumpWorkspace(ctx context.Context, workspaceID uui return sql.ErrNoRows } +func (q *FakeQuerier) AllUserIDs(_ context.Context) ([]uuid.UUID, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + userIDs := make([]uuid.UUID, 0, len(q.users)) + for idx := range q.users { + userIDs[idx] = q.users[idx].ID + } + return userIDs, nil +} + func (*FakeQuerier) CleanTailnetCoordinators(_ context.Context) error { return ErrUnimplemented } diff --git a/coderd/database/dbgen/dbgen.go b/coderd/database/dbgen/dbgen.go index f68a8cebadcc8..56f85d10ef476 100644 --- a/coderd/database/dbgen/dbgen.go +++ b/coderd/database/dbgen/dbgen.go @@ -227,7 +227,7 @@ func User(t testing.TB, db database.Store, orig database.User) database.User { user, err = db.UpdateUserStatus(genCtx, database.UpdateUserStatusParams{ ID: user.ID, - Status: database.UserStatusActive, + Status: takeFirst(orig.Status, database.UserStatusActive), UpdatedAt: dbtime.Now(), }) require.NoError(t, err, "insert user") @@ -240,6 +240,14 @@ func User(t testing.TB, db database.Store, orig database.User) database.User { }) require.NoError(t, err, "user last seen") } + + if orig.Deleted { + err = db.UpdateUserDeletedByID(genCtx, database.UpdateUserDeletedByIDParams{ + ID: user.ID, + Deleted: orig.Deleted, + }) + require.NoError(t, err, "set user as deleted") + } return user } diff --git a/coderd/database/dbmetrics/dbmetrics.go b/coderd/database/dbmetrics/dbmetrics.go index 768c1d4adbcca..82e4d5e68ec70 100644 --- a/coderd/database/dbmetrics/dbmetrics.go +++ b/coderd/database/dbmetrics/dbmetrics.go @@ -100,6 +100,13 @@ func (m metricsStore) ActivityBumpWorkspace(ctx context.Context, arg uuid.UUID) return r0 } +func (m metricsStore) AllUserIDs(ctx context.Context) ([]uuid.UUID, error) { + start := time.Now() + r0, r1 := m.s.AllUserIDs(ctx) + m.queryLatencies.WithLabelValues("AllUserIDs").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m metricsStore) CleanTailnetCoordinators(ctx context.Context) error { start := time.Now() err := m.s.CleanTailnetCoordinators(ctx) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 641dd7315b936..331e279b0a925 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -82,6 +82,21 @@ func (mr *MockStoreMockRecorder) ActivityBumpWorkspace(arg0, arg1 interface{}) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ActivityBumpWorkspace", reflect.TypeOf((*MockStore)(nil).ActivityBumpWorkspace), arg0, arg1) } +// AllUserIDs mocks base method. +func (m *MockStore) AllUserIDs(arg0 context.Context) ([]uuid.UUID, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AllUserIDs", arg0) + ret0, _ := ret[0].([]uuid.UUID) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AllUserIDs indicates an expected call of AllUserIDs. +func (mr *MockStoreMockRecorder) AllUserIDs(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllUserIDs", reflect.TypeOf((*MockStore)(nil).AllUserIDs), arg0) +} + // CleanTailnetCoordinators mocks base method. func (m *MockStore) CleanTailnetCoordinators(arg0 context.Context) error { m.ctrl.T.Helper() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 63c1f7321dc15..b9dc7f5fc4e64 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -31,6 +31,8 @@ type sqlcQuerier interface { // We only bump if workspace shutdown is manual. // We only bump when 5% of the deadline has elapsed. ActivityBumpWorkspace(ctx context.Context, workspaceID uuid.UUID) error + // AllUserIDs returns all UserIDs regardless of user status or deletion. + AllUserIDs(ctx context.Context) ([]uuid.UUID, error) CleanTailnetCoordinators(ctx context.Context) error DeleteAPIKeyByID(ctx context.Context, id string) error DeleteAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 67b7b782b9e29..1f7df65d50f86 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -5840,6 +5840,34 @@ func (q *sqlQuerier) UpdateUserLinkedID(ctx context.Context, arg UpdateUserLinke return i, err } +const allUserIDs = `-- name: AllUserIDs :many +SELECT DISTINCT id FROM USERS +` + +// AllUserIDs returns all UserIDs regardless of user status or deletion. +func (q *sqlQuerier) AllUserIDs(ctx context.Context) ([]uuid.UUID, error) { + rows, err := q.db.QueryContext(ctx, allUserIDs) + if err != nil { + return nil, err + } + defer rows.Close() + var items []uuid.UUID + for rows.Next() { + var id uuid.UUID + if err := rows.Scan(&id); err != nil { + return nil, err + } + items = append(items, id) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getActiveUserCount = `-- name: GetActiveUserCount :one SELECT COUNT(*) diff --git a/coderd/database/queries/dbcrypt.sql b/coderd/database/queries/dbcrypt.sql index ef1021609d5a7..41dc051c8331d 100644 --- a/coderd/database/queries/dbcrypt.sql +++ b/coderd/database/queries/dbcrypt.sql @@ -16,3 +16,4 @@ AND INSERT INTO dbcrypt_keys (number, active_key_digest, created_at, test) VALUES (@number::int, @active_key_digest::text, CURRENT_TIMESTAMP, @test::text); + diff --git a/coderd/database/queries/users.sql b/coderd/database/queries/users.sql index 8560bf0abf696..8caa74a92e588 100644 --- a/coderd/database/queries/users.sql +++ b/coderd/database/queries/users.sql @@ -262,3 +262,8 @@ WHERE last_seen_at < @last_seen_after :: timestamp AND status = 'active'::user_status RETURNING id, email, last_seen_at; + +-- AllUserIDs returns all UserIDs regardless of user status or deletion. +-- name: AllUserIDs :many +SELECT DISTINCT id FROM USERS; + diff --git a/enterprise/cli/server_dbcrypt_test.go b/enterprise/cli/server_dbcrypt_test.go index cebf014a7ce58..22ec5e1a99b5f 100644 --- a/enterprise/cli/server_dbcrypt_test.go +++ b/enterprise/cli/server_dbcrypt_test.go @@ -20,6 +20,9 @@ import ( "github.com/coder/coder/v2/pty/ptytest" ) +// TestServerDBCrypt tests end-to-end encryption, decryption, and deletion +// of encrypted user data. +// // nolint: paralleltest // use of t.Setenv func TestServerDBCrypt(t *testing.T) { if !dbtestutil.WillUsePostgres() { @@ -41,15 +44,38 @@ func TestServerDBCrypt(t *testing.T) { }) db := database.New(sqlDB) + t.Cleanup(func() { + if t.Failed() { + t.Logf("Dumping data due to failed test. I hope you find what you're looking for!") + dumpUsers(t, sqlDB) + } + }) + // Populate the database with some unencrypted data. - users := genData(t, db, 10) + t.Logf("Generating unencrypted data") + users := genData(t, db) - // Setup an initial cipher + // Setup an initial cipher A keyA := mustString(t, 32) cipherA, err := dbcrypt.NewCiphers([]byte(keyA)) require.NoError(t, err) + // Create an encrypted database + cryptdb, err := dbcrypt.New(ctx, db, cipherA...) + require.NoError(t, err) + + // Populate the database with some encrypted data using cipher A. + t.Logf("Generating data encrypted with cipher A") + newUsers := genData(t, cryptdb) + + // Validate that newly created users were encrypted with cipher A + for _, usr := range newUsers { + requireEncryptedWithCipher(ctx, t, db, cipherA[0], usr.ID) + } + users = append(users, newUsers...) + // Encrypt all the data with the initial cipher. + t.Logf("Encrypting all data with cipher A") inv, _ := newCLI(t, "server", "dbcrypt", "rotate", "--postgres-url", connectionURL, "--new-key", base64.StdEncoding.EncodeToString([]byte(keyA)), @@ -65,18 +91,12 @@ func TestServerDBCrypt(t *testing.T) { requireEncryptedWithCipher(ctx, t, db, cipherA[0], usr.ID) } - // Create an encrypted database - cryptdb, err := dbcrypt.New(ctx, db, cipherA...) - require.NoError(t, err) - - // Populate the database with some encrypted data using cipher A. - users = append(users, genData(t, cryptdb, 10)...) - // Re-encrypt all existing data with a new cipher. keyB := mustString(t, 32) cipherBA, err := dbcrypt.NewCiphers([]byte(keyB), []byte(keyA)) require.NoError(t, err) + t.Logf("Enrypting all data with cipher B") inv, _ = newCLI(t, "server", "dbcrypt", "rotate", "--postgres-url", connectionURL, "--new-key", base64.StdEncoding.EncodeToString([]byte(keyB)), @@ -94,6 +114,7 @@ func TestServerDBCrypt(t *testing.T) { } // Assert that we can revoke the old key. + t.Logf("Revoking cipher A") err = db.RevokeDBCryptKey(ctx, cipherA[0].HexDigest()) require.NoError(t, err, "failed to revoke old key") @@ -109,6 +130,7 @@ func TestServerDBCrypt(t *testing.T) { require.Empty(t, oldKey.ActiveKeyDigest.String, "expected the old key to not be active") // Revoking the new key should fail. + t.Logf("Attempting to revoke cipher B should fail as it is still in use") err = db.RevokeDBCryptKey(ctx, cipherBA[0].HexDigest()) require.Error(t, err, "expected to fail to revoke the new key") var pgErr *pq.Error @@ -116,6 +138,7 @@ func TestServerDBCrypt(t *testing.T) { require.EqualValues(t, "23503", pgErr.Code, "expected a foreign key constraint violation error") // Decrypt the data using only cipher B. This should result in the key being revoked. + t.Logf("Decrypting with cipher B") inv, _ = newCLI(t, "server", "dbcrypt", "decrypt", "--postgres-url", connectionURL, "--keys", base64.StdEncoding.EncodeToString([]byte(keyB)), @@ -144,6 +167,7 @@ func TestServerDBCrypt(t *testing.T) { cipherC, err := dbcrypt.NewCiphers([]byte(keyC)) require.NoError(t, err) + t.Logf("Re-encrypting with cipher C") inv, _ = newCLI(t, "server", "dbcrypt", "rotate", "--postgres-url", connectionURL, "--new-key", base64.StdEncoding.EncodeToString([]byte(keyC)), @@ -161,6 +185,7 @@ func TestServerDBCrypt(t *testing.T) { } // Now delete all the encrypted data. + t.Logf("Deleting all encrypted data") inv, _ = newCLI(t, "server", "dbcrypt", "delete", "--postgres-url", connectionURL, "--external-token-encryption-keys", base64.StdEncoding.EncodeToString([]byte(keyC)), @@ -191,30 +216,84 @@ func TestServerDBCrypt(t *testing.T) { } } -func genData(t *testing.T, db database.Store, n int) []database.User { +func genData(t *testing.T, db database.Store) []database.User { t.Helper() var users []database.User - for i := 0; i < n; i++ { - usr := dbgen.User(t, db, database.User{ - LoginType: database.LoginTypeOIDC, - }) - _ = dbgen.UserLink(t, db, database.UserLink{ - UserID: usr.ID, - LoginType: usr.LoginType, - OAuthAccessToken: "access-" + usr.ID.String(), - OAuthRefreshToken: "refresh-" + usr.ID.String(), - }) - _ = dbgen.GitAuthLink(t, db, database.GitAuthLink{ - UserID: usr.ID, - ProviderID: "fake", - OAuthAccessToken: "access-" + usr.ID.String(), - OAuthRefreshToken: "refresh-" + usr.ID.String(), - }) - users = append(users, usr) + // Make some users + for _, status := range database.AllUserStatusValues() { + for _, loginType := range database.AllLoginTypeValues() { + for _, deleted := range []bool{false, true} { + usr := dbgen.User(t, db, database.User{ + LoginType: loginType, + Status: status, + Deleted: deleted, + }) + _ = dbgen.GitAuthLink(t, db, database.GitAuthLink{ + UserID: usr.ID, + ProviderID: "fake", + OAuthAccessToken: "access-" + usr.ID.String(), + OAuthRefreshToken: "refresh-" + usr.ID.String(), + }) + // Fun fact: our schema allows _all_ login types to have + // a user_link. Even though I'm not sure how it could occur + // in practice, making sure to test all combinations here. + _ = dbgen.UserLink(t, db, database.UserLink{ + UserID: usr.ID, + LoginType: usr.LoginType, + OAuthAccessToken: "access-" + usr.ID.String(), + OAuthRefreshToken: "refresh-" + usr.ID.String(), + }) + users = append(users, usr) + } + } } return users } +func dumpUsers(t *testing.T, db *sql.DB) { + t.Helper() + rows, err := db.QueryContext(context.Background(), `SELECT + u.id, + u.login_type, + u.status, + u.deleted, + ul.oauth_access_token_key_id AS uloatkid, + ul.oauth_refresh_token_key_id AS ulortkid, + gal.oauth_access_token_key_id AS galoatkid, + gal.oauth_refresh_token_key_id AS galortkid +FROM users u +LEFT OUTER JOIN user_links ul ON u.id = ul.user_id +LEFT OUTER JOIN git_auth_links gal ON u.id = gal.user_id +ORDER BY u.created_at ASC;`) + require.NoError(t, err) + defer rows.Close() + for rows.Next() { + var ( + id string + loginType string + status string + deleted bool + UlOatKid sql.NullString + UlOrtKid sql.NullString + GalOatKid sql.NullString + GalOrtKid sql.NullString + ) + require.NoError(t, rows.Scan( + &id, + &loginType, + &status, + &deleted, + &UlOatKid, + &UlOrtKid, + &GalOatKid, + &GalOrtKid, + )) + t.Logf("user: id:%s login_type:%-8s status:%-9s deleted:%-5t ul_kids{at:%-7s rt:%-7s} gal_kids{at:%-7s rt:%-7s}", + id, loginType, status, deleted, UlOatKid.String, UlOrtKid.String, GalOatKid.String, GalOrtKid.String, + ) + } +} + func mustString(t *testing.T, n int) string { t.Helper() s, err := cryptorand.String(n) diff --git a/enterprise/dbcrypt/cliutil.go b/enterprise/dbcrypt/cliutil.go index 7f68e284afe77..3601d0c539c2e 100644 --- a/enterprise/dbcrypt/cliutil.go +++ b/enterprise/dbcrypt/cliutil.go @@ -19,45 +19,45 @@ func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciphe return xerrors.Errorf("create cryptdb: %w", err) } - users, err := cryptDB.GetUsers(ctx, database.GetUsersParams{}) + userIDs, err := db.AllUserIDs(ctx) if err != nil { return xerrors.Errorf("get users: %w", err) } - log.Info(ctx, "encrypting user tokens", slog.F("user_count", len(users))) - for idx, usr := range users { + log.Info(ctx, "encrypting user tokens", slog.F("user_count", len(userIDs))) + for idx, uid := range userIDs { err := cryptDB.InTx(func(tx database.Store) error { - userLinks, err := tx.GetUserLinksByUserID(ctx, usr.ID) + userLinks, err := tx.GetUserLinksByUserID(ctx, uid) if err != nil { return xerrors.Errorf("get user links for user: %w", err) } for _, userLink := range userLinks { if userLink.OAuthAccessTokenKeyID.String == ciphers[0].HexDigest() && userLink.OAuthRefreshTokenKeyID.String == ciphers[0].HexDigest() { - log.Debug(ctx, "skipping user link", slog.F("user_id", usr.ID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest())) + log.Debug(ctx, "skipping user link", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest())) continue } if _, err := tx.UpdateUserLink(ctx, database.UpdateUserLinkParams{ OAuthAccessToken: userLink.OAuthAccessToken, OAuthRefreshToken: userLink.OAuthRefreshToken, OAuthExpiry: userLink.OAuthExpiry, - UserID: usr.ID, - LoginType: usr.LoginType, + UserID: uid, + LoginType: userLink.LoginType, }); err != nil { return xerrors.Errorf("update user link user_id=%s linked_id=%s: %w", userLink.UserID, userLink.LinkedID, err) } } - gitAuthLinks, err := tx.GetGitAuthLinksByUserID(ctx, usr.ID) + gitAuthLinks, err := tx.GetGitAuthLinksByUserID(ctx, uid) if err != nil { return xerrors.Errorf("get git auth links for user: %w", err) } for _, gitAuthLink := range gitAuthLinks { if gitAuthLink.OAuthAccessTokenKeyID.String == ciphers[0].HexDigest() && gitAuthLink.OAuthRefreshTokenKeyID.String == ciphers[0].HexDigest() { - log.Debug(ctx, "skipping git auth link", slog.F("user_id", usr.ID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest())) + log.Debug(ctx, "skipping git auth link", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest())) continue } if _, err := tx.UpdateGitAuthLink(ctx, database.UpdateGitAuthLinkParams{ ProviderID: gitAuthLink.ProviderID, - UserID: usr.ID, + UserID: uid, UpdatedAt: gitAuthLink.UpdatedAt, OAuthAccessToken: gitAuthLink.OAuthAccessToken, OAuthRefreshToken: gitAuthLink.OAuthRefreshToken, @@ -73,7 +73,7 @@ func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciphe if err != nil { return xerrors.Errorf("update user links: %w", err) } - log.Debug(ctx, "encrypted user tokens", slog.F("user_id", usr.ID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest())) + log.Debug(ctx, "encrypted user tokens", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest())) } // Revoke old keys @@ -103,45 +103,45 @@ func Decrypt(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciph } cryptDB.primaryCipherDigest = "" - users, err := cryptDB.GetUsers(ctx, database.GetUsersParams{}) + userIDs, err := db.AllUserIDs(ctx) if err != nil { return xerrors.Errorf("get users: %w", err) } - log.Info(ctx, "decrypting user tokens", slog.F("user_count", len(users))) - for idx, usr := range users { + log.Info(ctx, "decrypting user tokens", slog.F("user_count", len(userIDs))) + for idx, uid := range userIDs { err := cryptDB.InTx(func(tx database.Store) error { - userLinks, err := tx.GetUserLinksByUserID(ctx, usr.ID) + userLinks, err := tx.GetUserLinksByUserID(ctx, uid) if err != nil { return xerrors.Errorf("get user links for user: %w", err) } for _, userLink := range userLinks { if !userLink.OAuthAccessTokenKeyID.Valid && !userLink.OAuthRefreshTokenKeyID.Valid { - log.Debug(ctx, "skipping user link", slog.F("user_id", usr.ID), slog.F("current", idx+1)) + log.Debug(ctx, "skipping user link", slog.F("user_id", uid), slog.F("current", idx+1)) continue } if _, err := tx.UpdateUserLink(ctx, database.UpdateUserLinkParams{ OAuthAccessToken: userLink.OAuthAccessToken, OAuthRefreshToken: userLink.OAuthRefreshToken, OAuthExpiry: userLink.OAuthExpiry, - UserID: usr.ID, - LoginType: usr.LoginType, + UserID: uid, + LoginType: userLink.LoginType, }); err != nil { return xerrors.Errorf("update user link user_id=%s linked_id=%s: %w", userLink.UserID, userLink.LinkedID, err) } } - gitAuthLinks, err := tx.GetGitAuthLinksByUserID(ctx, usr.ID) + gitAuthLinks, err := tx.GetGitAuthLinksByUserID(ctx, uid) if err != nil { return xerrors.Errorf("get git auth links for user: %w", err) } for _, gitAuthLink := range gitAuthLinks { if !gitAuthLink.OAuthAccessTokenKeyID.Valid && !gitAuthLink.OAuthRefreshTokenKeyID.Valid { - log.Debug(ctx, "skipping git auth link", slog.F("user_id", usr.ID), slog.F("current", idx+1)) + log.Debug(ctx, "skipping git auth link", slog.F("user_id", uid), slog.F("current", idx+1)) continue } if _, err := tx.UpdateGitAuthLink(ctx, database.UpdateGitAuthLinkParams{ ProviderID: gitAuthLink.ProviderID, - UserID: usr.ID, + UserID: uid, UpdatedAt: gitAuthLink.UpdatedAt, OAuthAccessToken: gitAuthLink.OAuthAccessToken, OAuthRefreshToken: gitAuthLink.OAuthRefreshToken, @@ -157,7 +157,7 @@ func Decrypt(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciph if err != nil { return xerrors.Errorf("update user links: %w", err) } - log.Debug(ctx, "decrypted user tokens", slog.F("user_id", usr.ID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest())) + log.Debug(ctx, "decrypted user tokens", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest())) } // Revoke _all_ keys