From 355b2c1059f4637f1b526008fe9abd7a69753396 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Fri, 15 Sep 2023 10:25:25 +0100 Subject: [PATCH 01/11] add some more logging to test --- enterprise/cli/server_dbcrypt_test.go | 59 +++++++++++++++++++++++---- 1 file changed, 51 insertions(+), 8 deletions(-) diff --git a/enterprise/cli/server_dbcrypt_test.go b/enterprise/cli/server_dbcrypt_test.go index cebf014a7ce58..9a66288ec9781 100644 --- a/enterprise/cli/server_dbcrypt_test.go +++ b/enterprise/cli/server_dbcrypt_test.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "encoding/base64" + "strings" "testing" "github.com/google/uuid" @@ -43,12 +44,26 @@ func TestServerDBCrypt(t *testing.T) { // Populate the database with some unencrypted data. users := genData(t, db, 10) + dumpUsers(t, sqlDB, "NOT ENCRYPTED") - // Setup an initial cipher + // Setup an initial cipher A keyA := mustString(t, 32) cipherA, err := dbcrypt.NewCiphers([]byte(keyA)) require.NoError(t, err) + // Create an encrypted database + cryptdb, err := dbcrypt.New(ctx, db, cipherA...) + require.NoError(t, err) + + // Populate the database with some encrypted data using cipher A. + users = append(users, genData(t, cryptdb, 10)...) + dumpUsers(t, sqlDB, "PARTIALLY ENCRYPTED A") + + // Validate that newly created users were encrypted with cipher A + for _, usr := range users[10:] { + requireEncryptedWithCipher(ctx, t, db, cipherA[0], usr.ID) + } + // Encrypt all the data with the initial cipher. inv, _ := newCLI(t, "server", "dbcrypt", "rotate", "--postgres-url", connectionURL, @@ -60,18 +75,12 @@ func TestServerDBCrypt(t *testing.T) { err = inv.Run() require.NoError(t, err) + dumpUsers(t, sqlDB, "ENCRYPTED A") // Validate that all existing data has been encrypted with cipher A. for _, usr := range users { requireEncryptedWithCipher(ctx, t, db, cipherA[0], usr.ID) } - // Create an encrypted database - cryptdb, err := dbcrypt.New(ctx, db, cipherA...) - require.NoError(t, err) - - // Populate the database with some encrypted data using cipher A. - users = append(users, genData(t, cryptdb, 10)...) - // Re-encrypt all existing data with a new cipher. keyB := mustString(t, 32) cipherBA, err := dbcrypt.NewCiphers([]byte(keyB), []byte(keyA)) @@ -89,6 +98,7 @@ func TestServerDBCrypt(t *testing.T) { require.NoError(t, err) // Validate that all data has been re-encrypted with cipher B. + dumpUsers(t, sqlDB, "ENCRYPTED B") for _, usr := range users { requireEncryptedWithCipher(ctx, t, db, cipherBA[0], usr.ID) } @@ -135,6 +145,7 @@ func TestServerDBCrypt(t *testing.T) { } // Validate that all data has been decrypted. + dumpUsers(t, sqlDB, "DECRYPTED") for _, usr := range users { requireEncryptedWithCipher(ctx, t, db, &nullCipher{}, usr.ID) } @@ -156,6 +167,7 @@ func TestServerDBCrypt(t *testing.T) { require.NoError(t, err) // Validate that all data has been re-encrypted with cipher C. + dumpUsers(t, sqlDB, "ENCRYPTED C") for _, usr := range users { requireEncryptedWithCipher(ctx, t, db, cipherC[0], usr.ID) } @@ -172,6 +184,7 @@ func TestServerDBCrypt(t *testing.T) { require.NoError(t, err) // Assert that no user links remain. + dumpUsers(t, sqlDB, "DELETED") for _, usr := range users { userLinks, err := db.GetUserLinksByUserID(ctx, usr.ID) require.NoError(t, err, "failed to get user links for user %s", usr.ID) @@ -215,6 +228,36 @@ func genData(t *testing.T, db database.Store, n int) []database.User { return users } +func dumpUsers(t *testing.T, db *sql.DB, header string) { + t.Logf("%s %s %s", strings.Repeat("=", 20), header, strings.Repeat("=", 20)) + rows, err := db.QueryContext(context.Background(), `select u.id, u.status, u.deleted, ul.oauth_access_token_key_id as uloatkid, ul.oauth_refresh_token_key_id as ulortkid, gal.oauth_access_token_key_id as galoatkid, gal.oauth_refresh_token_key_id as galortkid from users u left outer join user_links ul on u.id = ul.user_id left outer join git_auth_links gal on u.id = gal.user_id;`) + require.NoError(t, err) + defer rows.Close() + for rows.Next() { + var ( + id string + status string + deleted bool + UlOatKid sql.NullString + UlOrtKid sql.NullString + GalOatKid sql.NullString + GalOrtKid sql.NullString + ) + require.NoError(t, rows.Scan( + &id, + &status, + &deleted, + &UlOatKid, + &UlOrtKid, + &GalOatKid, + &GalOrtKid, + )) + t.Logf("user: id:%s status:%-9s deleted:%-5t ul_kids{at:%-7s rt:%-7s} gal_kids{at:%-7s rt:%-7s}", + id, status, deleted, UlOatKid.String, UlOrtKid.String, GalOatKid.String, GalOrtKid.String, + ) + } +} + func mustString(t *testing.T, n int) string { t.Helper() s, err := cryptorand.String(n) From 39aa29e981411ecd332b8d3fc8df0f1760c333bb Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Fri, 15 Sep 2023 10:29:12 +0100 Subject: [PATCH 02/11] validate working with multiple user statuses --- coderd/database/dbgen/dbgen.go | 2 +- enterprise/cli/server_dbcrypt_test.go | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/coderd/database/dbgen/dbgen.go b/coderd/database/dbgen/dbgen.go index f68a8cebadcc8..291fb5b996547 100644 --- a/coderd/database/dbgen/dbgen.go +++ b/coderd/database/dbgen/dbgen.go @@ -227,7 +227,7 @@ func User(t testing.TB, db database.Store, orig database.User) database.User { user, err = db.UpdateUserStatus(genCtx, database.UpdateUserStatusParams{ ID: user.ID, - Status: database.UserStatusActive, + Status: takeFirst(orig.Status, database.UserStatusActive), UpdatedAt: dbtime.Now(), }) require.NoError(t, err, "insert user") diff --git a/enterprise/cli/server_dbcrypt_test.go b/enterprise/cli/server_dbcrypt_test.go index 9a66288ec9781..417014a9c5a07 100644 --- a/enterprise/cli/server_dbcrypt_test.go +++ b/enterprise/cli/server_dbcrypt_test.go @@ -207,9 +207,17 @@ func TestServerDBCrypt(t *testing.T) { func genData(t *testing.T, db database.Store, n int) []database.User { t.Helper() var users []database.User + // Make some users for i := 0; i < n; i++ { + status := database.UserStatusActive + if i%2 == 0 { + status = database.UserStatusSuspended + } else if i%3 == 0 { + status = database.UserStatusDormant + } usr := dbgen.User(t, db, database.User{ LoginType: database.LoginTypeOIDC, + Status: status, }) _ = dbgen.UserLink(t, db, database.UserLink{ UserID: usr.ID, From 3950c097a47b09ec35d82109a565c701660aaf1a Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Fri, 15 Sep 2023 10:36:36 +0100 Subject: [PATCH 03/11] fmt query --- enterprise/cli/server_dbcrypt_test.go | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/enterprise/cli/server_dbcrypt_test.go b/enterprise/cli/server_dbcrypt_test.go index 417014a9c5a07..3076ac3bc97c9 100644 --- a/enterprise/cli/server_dbcrypt_test.go +++ b/enterprise/cli/server_dbcrypt_test.go @@ -238,7 +238,18 @@ func genData(t *testing.T, db database.Store, n int) []database.User { func dumpUsers(t *testing.T, db *sql.DB, header string) { t.Logf("%s %s %s", strings.Repeat("=", 20), header, strings.Repeat("=", 20)) - rows, err := db.QueryContext(context.Background(), `select u.id, u.status, u.deleted, ul.oauth_access_token_key_id as uloatkid, ul.oauth_refresh_token_key_id as ulortkid, gal.oauth_access_token_key_id as galoatkid, gal.oauth_refresh_token_key_id as galortkid from users u left outer join user_links ul on u.id = ul.user_id left outer join git_auth_links gal on u.id = gal.user_id;`) + rows, err := db.QueryContext(context.Background(), `SELECT + u.id, + u.status, + u.deleted, + ul.oauth_access_token_key_id AS uloatkid, + ul.oauth_refresh_token_key_id AS ulortkid, + gal.oauth_access_token_key_id AS galoatkid, + gal.oauth_refresh_token_key_id AS galortkid +FROM users u +LEFT OUTER JOIN user_links ul ON u.id = ul.user_id +LEFT OUTER JOIN git_auth_links gal ON u.id = gal.user_id +ORDER BY u.created_at ASC;`) require.NoError(t, err) defer rows.Close() for rows.Next() { From 49b9d90f006e787ab5f391b5f9084484f963f3b9 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Fri, 15 Sep 2023 10:38:53 +0100 Subject: [PATCH 04/11] create deleted users --- coderd/database/dbgen/dbgen.go | 8 ++++++++ enterprise/cli/server_dbcrypt_test.go | 4 ++++ 2 files changed, 12 insertions(+) diff --git a/coderd/database/dbgen/dbgen.go b/coderd/database/dbgen/dbgen.go index 291fb5b996547..56f85d10ef476 100644 --- a/coderd/database/dbgen/dbgen.go +++ b/coderd/database/dbgen/dbgen.go @@ -240,6 +240,14 @@ func User(t testing.TB, db database.Store, orig database.User) database.User { }) require.NoError(t, err, "user last seen") } + + if orig.Deleted { + err = db.UpdateUserDeletedByID(genCtx, database.UpdateUserDeletedByIDParams{ + ID: user.ID, + Deleted: orig.Deleted, + }) + require.NoError(t, err, "set user as deleted") + } return user } diff --git a/enterprise/cli/server_dbcrypt_test.go b/enterprise/cli/server_dbcrypt_test.go index 3076ac3bc97c9..d7a8cb7a1af68 100644 --- a/enterprise/cli/server_dbcrypt_test.go +++ b/enterprise/cli/server_dbcrypt_test.go @@ -209,15 +209,19 @@ func genData(t *testing.T, db database.Store, n int) []database.User { var users []database.User // Make some users for i := 0; i < n; i++ { + var deleted bool status := database.UserStatusActive if i%2 == 0 { status = database.UserStatusSuspended } else if i%3 == 0 { status = database.UserStatusDormant + } else if i%5 == 0 { + deleted = true } usr := dbgen.User(t, db, database.User{ LoginType: database.LoginTypeOIDC, Status: status, + Deleted: deleted, }) _ = dbgen.UserLink(t, db, database.UserLink{ UserID: usr.ID, From 5c1b687a34c5d0da9cb258b79aaf7aa95335b979 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Fri, 15 Sep 2023 11:11:25 +0100 Subject: [PATCH 05/11] robustificate data generation --- enterprise/cli/server_dbcrypt_test.go | 69 ++++++++++++++------------- 1 file changed, 36 insertions(+), 33 deletions(-) diff --git a/enterprise/cli/server_dbcrypt_test.go b/enterprise/cli/server_dbcrypt_test.go index d7a8cb7a1af68..b17f8d326b03e 100644 --- a/enterprise/cli/server_dbcrypt_test.go +++ b/enterprise/cli/server_dbcrypt_test.go @@ -43,7 +43,7 @@ func TestServerDBCrypt(t *testing.T) { db := database.New(sqlDB) // Populate the database with some unencrypted data. - users := genData(t, db, 10) + users := genData(t, db) dumpUsers(t, sqlDB, "NOT ENCRYPTED") // Setup an initial cipher A @@ -56,13 +56,14 @@ func TestServerDBCrypt(t *testing.T) { require.NoError(t, err) // Populate the database with some encrypted data using cipher A. - users = append(users, genData(t, cryptdb, 10)...) + newUsers := genData(t, cryptdb) dumpUsers(t, sqlDB, "PARTIALLY ENCRYPTED A") // Validate that newly created users were encrypted with cipher A - for _, usr := range users[10:] { + for _, usr := range newUsers { requireEncryptedWithCipher(ctx, t, db, cipherA[0], usr.ID) } + users = append(users, newUsers...) // Encrypt all the data with the initial cipher. inv, _ := newCLI(t, "server", "dbcrypt", "rotate", @@ -86,6 +87,10 @@ func TestServerDBCrypt(t *testing.T) { cipherBA, err := dbcrypt.NewCiphers([]byte(keyB), []byte(keyA)) require.NoError(t, err) + // Generate some more encrypted data using the new cipher + users = append(users, genData(t, db)...) + dumpUsers(t, sqlDB, "ENCRYPTED AB") + inv, _ = newCLI(t, "server", "dbcrypt", "rotate", "--postgres-url", connectionURL, "--new-key", base64.StdEncoding.EncodeToString([]byte(keyB)), @@ -204,38 +209,33 @@ func TestServerDBCrypt(t *testing.T) { } } -func genData(t *testing.T, db database.Store, n int) []database.User { +func genData(t *testing.T, db database.Store) []database.User { t.Helper() var users []database.User // Make some users - for i := 0; i < n; i++ { - var deleted bool - status := database.UserStatusActive - if i%2 == 0 { - status = database.UserStatusSuspended - } else if i%3 == 0 { - status = database.UserStatusDormant - } else if i%5 == 0 { - deleted = true + for _, status := range database.AllUserStatusValues() { + for _, loginType := range database.AllLoginTypeValues() { + for _, deleted := range []bool{false, true} { + usr := dbgen.User(t, db, database.User{ + LoginType: loginType, + Status: status, + Deleted: deleted, + }) + _ = dbgen.GitAuthLink(t, db, database.GitAuthLink{ + UserID: usr.ID, + ProviderID: "fake", + OAuthAccessToken: "access-" + usr.ID.String(), + OAuthRefreshToken: "refresh-" + usr.ID.String(), + }) + _ = dbgen.UserLink(t, db, database.UserLink{ + UserID: usr.ID, + LoginType: usr.LoginType, + OAuthAccessToken: "access-" + usr.ID.String(), + OAuthRefreshToken: "refresh-" + usr.ID.String(), + }) + users = append(users, usr) + } } - usr := dbgen.User(t, db, database.User{ - LoginType: database.LoginTypeOIDC, - Status: status, - Deleted: deleted, - }) - _ = dbgen.UserLink(t, db, database.UserLink{ - UserID: usr.ID, - LoginType: usr.LoginType, - OAuthAccessToken: "access-" + usr.ID.String(), - OAuthRefreshToken: "refresh-" + usr.ID.String(), - }) - _ = dbgen.GitAuthLink(t, db, database.GitAuthLink{ - UserID: usr.ID, - ProviderID: "fake", - OAuthAccessToken: "access-" + usr.ID.String(), - OAuthRefreshToken: "refresh-" + usr.ID.String(), - }) - users = append(users, usr) } return users } @@ -244,6 +244,7 @@ func dumpUsers(t *testing.T, db *sql.DB, header string) { t.Logf("%s %s %s", strings.Repeat("=", 20), header, strings.Repeat("=", 20)) rows, err := db.QueryContext(context.Background(), `SELECT u.id, + u.login_type, u.status, u.deleted, ul.oauth_access_token_key_id AS uloatkid, @@ -259,6 +260,7 @@ ORDER BY u.created_at ASC;`) for rows.Next() { var ( id string + loginType string status string deleted bool UlOatKid sql.NullString @@ -268,6 +270,7 @@ ORDER BY u.created_at ASC;`) ) require.NoError(t, rows.Scan( &id, + &loginType, &status, &deleted, &UlOatKid, @@ -275,8 +278,8 @@ ORDER BY u.created_at ASC;`) &GalOatKid, &GalOrtKid, )) - t.Logf("user: id:%s status:%-9s deleted:%-5t ul_kids{at:%-7s rt:%-7s} gal_kids{at:%-7s rt:%-7s}", - id, status, deleted, UlOatKid.String, UlOrtKid.String, GalOatKid.String, GalOrtKid.String, + t.Logf("user: id:%s login_type:%-8s status:%-9s deleted:%-5t ul_kids{at:%-7s rt:%-7s} gal_kids{at:%-7s rt:%-7s}", + id, loginType, status, deleted, UlOatKid.String, UlOrtKid.String, GalOatKid.String, GalOrtKid.String, ) } } From 9bb1a5d4ab57a73ba415235e094553e224f0acd5 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Fri, 15 Sep 2023 11:16:12 +0100 Subject: [PATCH 06/11] only dump on test failure --- enterprise/cli/server_dbcrypt_test.go | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/enterprise/cli/server_dbcrypt_test.go b/enterprise/cli/server_dbcrypt_test.go index b17f8d326b03e..4f2860a1f31ad 100644 --- a/enterprise/cli/server_dbcrypt_test.go +++ b/enterprise/cli/server_dbcrypt_test.go @@ -4,7 +4,6 @@ import ( "context" "database/sql" "encoding/base64" - "strings" "testing" "github.com/google/uuid" @@ -42,9 +41,15 @@ func TestServerDBCrypt(t *testing.T) { }) db := database.New(sqlDB) + t.Cleanup(func() { + if t.Failed() { + t.Logf("Dumping data due to failed test. I hope you find what you're looking for!") + dumpUsers(t, sqlDB) + } + }) + // Populate the database with some unencrypted data. users := genData(t, db) - dumpUsers(t, sqlDB, "NOT ENCRYPTED") // Setup an initial cipher A keyA := mustString(t, 32) @@ -57,7 +62,6 @@ func TestServerDBCrypt(t *testing.T) { // Populate the database with some encrypted data using cipher A. newUsers := genData(t, cryptdb) - dumpUsers(t, sqlDB, "PARTIALLY ENCRYPTED A") // Validate that newly created users were encrypted with cipher A for _, usr := range newUsers { @@ -76,7 +80,6 @@ func TestServerDBCrypt(t *testing.T) { err = inv.Run() require.NoError(t, err) - dumpUsers(t, sqlDB, "ENCRYPTED A") // Validate that all existing data has been encrypted with cipher A. for _, usr := range users { requireEncryptedWithCipher(ctx, t, db, cipherA[0], usr.ID) @@ -89,7 +92,6 @@ func TestServerDBCrypt(t *testing.T) { // Generate some more encrypted data using the new cipher users = append(users, genData(t, db)...) - dumpUsers(t, sqlDB, "ENCRYPTED AB") inv, _ = newCLI(t, "server", "dbcrypt", "rotate", "--postgres-url", connectionURL, @@ -103,7 +105,6 @@ func TestServerDBCrypt(t *testing.T) { require.NoError(t, err) // Validate that all data has been re-encrypted with cipher B. - dumpUsers(t, sqlDB, "ENCRYPTED B") for _, usr := range users { requireEncryptedWithCipher(ctx, t, db, cipherBA[0], usr.ID) } @@ -150,7 +151,6 @@ func TestServerDBCrypt(t *testing.T) { } // Validate that all data has been decrypted. - dumpUsers(t, sqlDB, "DECRYPTED") for _, usr := range users { requireEncryptedWithCipher(ctx, t, db, &nullCipher{}, usr.ID) } @@ -172,7 +172,6 @@ func TestServerDBCrypt(t *testing.T) { require.NoError(t, err) // Validate that all data has been re-encrypted with cipher C. - dumpUsers(t, sqlDB, "ENCRYPTED C") for _, usr := range users { requireEncryptedWithCipher(ctx, t, db, cipherC[0], usr.ID) } @@ -189,7 +188,6 @@ func TestServerDBCrypt(t *testing.T) { require.NoError(t, err) // Assert that no user links remain. - dumpUsers(t, sqlDB, "DELETED") for _, usr := range users { userLinks, err := db.GetUserLinksByUserID(ctx, usr.ID) require.NoError(t, err, "failed to get user links for user %s", usr.ID) @@ -227,6 +225,9 @@ func genData(t *testing.T, db database.Store) []database.User { OAuthAccessToken: "access-" + usr.ID.String(), OAuthRefreshToken: "refresh-" + usr.ID.String(), }) + // Fun fact: our schema allows _all_ login types to have + // a user_link. Even though I'm not sure how it could occur + // in practice, making sure to test all combinations here. _ = dbgen.UserLink(t, db, database.UserLink{ UserID: usr.ID, LoginType: usr.LoginType, @@ -240,8 +241,8 @@ func genData(t *testing.T, db database.Store) []database.User { return users } -func dumpUsers(t *testing.T, db *sql.DB, header string) { - t.Logf("%s %s %s", strings.Repeat("=", 20), header, strings.Repeat("=", 20)) +func dumpUsers(t *testing.T, db *sql.DB) { + t.Helper() rows, err := db.QueryContext(context.Background(), `SELECT u.id, u.login_type, From 7da597249f4bbc2f7cc704d4ee449ffb4b325de6 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Fri, 15 Sep 2023 11:35:33 +0100 Subject: [PATCH 07/11] more logging --- enterprise/cli/server_dbcrypt_test.go | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/enterprise/cli/server_dbcrypt_test.go b/enterprise/cli/server_dbcrypt_test.go index 4f2860a1f31ad..22ec5e1a99b5f 100644 --- a/enterprise/cli/server_dbcrypt_test.go +++ b/enterprise/cli/server_dbcrypt_test.go @@ -20,6 +20,9 @@ import ( "github.com/coder/coder/v2/pty/ptytest" ) +// TestServerDBCrypt tests end-to-end encryption, decryption, and deletion +// of encrypted user data. +// // nolint: paralleltest // use of t.Setenv func TestServerDBCrypt(t *testing.T) { if !dbtestutil.WillUsePostgres() { @@ -49,6 +52,7 @@ func TestServerDBCrypt(t *testing.T) { }) // Populate the database with some unencrypted data. + t.Logf("Generating unencrypted data") users := genData(t, db) // Setup an initial cipher A @@ -61,6 +65,7 @@ func TestServerDBCrypt(t *testing.T) { require.NoError(t, err) // Populate the database with some encrypted data using cipher A. + t.Logf("Generating data encrypted with cipher A") newUsers := genData(t, cryptdb) // Validate that newly created users were encrypted with cipher A @@ -70,6 +75,7 @@ func TestServerDBCrypt(t *testing.T) { users = append(users, newUsers...) // Encrypt all the data with the initial cipher. + t.Logf("Encrypting all data with cipher A") inv, _ := newCLI(t, "server", "dbcrypt", "rotate", "--postgres-url", connectionURL, "--new-key", base64.StdEncoding.EncodeToString([]byte(keyA)), @@ -90,9 +96,7 @@ func TestServerDBCrypt(t *testing.T) { cipherBA, err := dbcrypt.NewCiphers([]byte(keyB), []byte(keyA)) require.NoError(t, err) - // Generate some more encrypted data using the new cipher - users = append(users, genData(t, db)...) - + t.Logf("Enrypting all data with cipher B") inv, _ = newCLI(t, "server", "dbcrypt", "rotate", "--postgres-url", connectionURL, "--new-key", base64.StdEncoding.EncodeToString([]byte(keyB)), @@ -110,6 +114,7 @@ func TestServerDBCrypt(t *testing.T) { } // Assert that we can revoke the old key. + t.Logf("Revoking cipher A") err = db.RevokeDBCryptKey(ctx, cipherA[0].HexDigest()) require.NoError(t, err, "failed to revoke old key") @@ -125,6 +130,7 @@ func TestServerDBCrypt(t *testing.T) { require.Empty(t, oldKey.ActiveKeyDigest.String, "expected the old key to not be active") // Revoking the new key should fail. + t.Logf("Attempting to revoke cipher B should fail as it is still in use") err = db.RevokeDBCryptKey(ctx, cipherBA[0].HexDigest()) require.Error(t, err, "expected to fail to revoke the new key") var pgErr *pq.Error @@ -132,6 +138,7 @@ func TestServerDBCrypt(t *testing.T) { require.EqualValues(t, "23503", pgErr.Code, "expected a foreign key constraint violation error") // Decrypt the data using only cipher B. This should result in the key being revoked. + t.Logf("Decrypting with cipher B") inv, _ = newCLI(t, "server", "dbcrypt", "decrypt", "--postgres-url", connectionURL, "--keys", base64.StdEncoding.EncodeToString([]byte(keyB)), @@ -160,6 +167,7 @@ func TestServerDBCrypt(t *testing.T) { cipherC, err := dbcrypt.NewCiphers([]byte(keyC)) require.NoError(t, err) + t.Logf("Re-encrypting with cipher C") inv, _ = newCLI(t, "server", "dbcrypt", "rotate", "--postgres-url", connectionURL, "--new-key", base64.StdEncoding.EncodeToString([]byte(keyC)), @@ -177,6 +185,7 @@ func TestServerDBCrypt(t *testing.T) { } // Now delete all the encrypted data. + t.Logf("Deleting all encrypted data") inv, _ = newCLI(t, "server", "dbcrypt", "delete", "--postgres-url", connectionURL, "--external-token-encryption-keys", base64.StdEncoding.EncodeToString([]byte(keyC)), From 2f63e43dd9f1b30773558d99f7a3aea505f093f9 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Fri, 15 Sep 2023 11:36:23 +0100 Subject: [PATCH 08/11] do not skip deleted users when encrypting or decrypting --- enterprise/dbcrypt/cliutil.go | 68 +++++++++++++++++++++++------------ 1 file changed, 46 insertions(+), 22 deletions(-) diff --git a/enterprise/dbcrypt/cliutil.go b/enterprise/dbcrypt/cliutil.go index 7f68e284afe77..f12213947dca4 100644 --- a/enterprise/dbcrypt/cliutil.go +++ b/enterprise/dbcrypt/cliutil.go @@ -6,6 +6,8 @@ import ( "golang.org/x/xerrors" + "github.com/google/uuid" + "cdr.dev/slog" "github.com/coder/coder/v2/coderd/database" ) @@ -19,45 +21,45 @@ func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciphe return xerrors.Errorf("create cryptdb: %w", err) } - users, err := cryptDB.GetUsers(ctx, database.GetUsersParams{}) + userIDs, err := allUserIDs(ctx, sqlDB) if err != nil { return xerrors.Errorf("get users: %w", err) } - log.Info(ctx, "encrypting user tokens", slog.F("user_count", len(users))) - for idx, usr := range users { + log.Info(ctx, "encrypting user tokens", slog.F("user_count", len(userIDs))) + for idx, uid := range userIDs { err := cryptDB.InTx(func(tx database.Store) error { - userLinks, err := tx.GetUserLinksByUserID(ctx, usr.ID) + userLinks, err := tx.GetUserLinksByUserID(ctx, uid) if err != nil { return xerrors.Errorf("get user links for user: %w", err) } for _, userLink := range userLinks { if userLink.OAuthAccessTokenKeyID.String == ciphers[0].HexDigest() && userLink.OAuthRefreshTokenKeyID.String == ciphers[0].HexDigest() { - log.Debug(ctx, "skipping user link", slog.F("user_id", usr.ID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest())) + log.Debug(ctx, "skipping user link", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest())) continue } if _, err := tx.UpdateUserLink(ctx, database.UpdateUserLinkParams{ OAuthAccessToken: userLink.OAuthAccessToken, OAuthRefreshToken: userLink.OAuthRefreshToken, OAuthExpiry: userLink.OAuthExpiry, - UserID: usr.ID, - LoginType: usr.LoginType, + UserID: uid, + LoginType: userLink.LoginType, }); err != nil { return xerrors.Errorf("update user link user_id=%s linked_id=%s: %w", userLink.UserID, userLink.LinkedID, err) } } - gitAuthLinks, err := tx.GetGitAuthLinksByUserID(ctx, usr.ID) + gitAuthLinks, err := tx.GetGitAuthLinksByUserID(ctx, uid) if err != nil { return xerrors.Errorf("get git auth links for user: %w", err) } for _, gitAuthLink := range gitAuthLinks { if gitAuthLink.OAuthAccessTokenKeyID.String == ciphers[0].HexDigest() && gitAuthLink.OAuthRefreshTokenKeyID.String == ciphers[0].HexDigest() { - log.Debug(ctx, "skipping git auth link", slog.F("user_id", usr.ID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest())) + log.Debug(ctx, "skipping git auth link", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest())) continue } if _, err := tx.UpdateGitAuthLink(ctx, database.UpdateGitAuthLinkParams{ ProviderID: gitAuthLink.ProviderID, - UserID: usr.ID, + UserID: uid, UpdatedAt: gitAuthLink.UpdatedAt, OAuthAccessToken: gitAuthLink.OAuthAccessToken, OAuthRefreshToken: gitAuthLink.OAuthRefreshToken, @@ -73,7 +75,7 @@ func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciphe if err != nil { return xerrors.Errorf("update user links: %w", err) } - log.Debug(ctx, "encrypted user tokens", slog.F("user_id", usr.ID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest())) + log.Debug(ctx, "encrypted user tokens", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest())) } // Revoke old keys @@ -103,45 +105,45 @@ func Decrypt(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciph } cryptDB.primaryCipherDigest = "" - users, err := cryptDB.GetUsers(ctx, database.GetUsersParams{}) + userIDs, err := allUserIDs(ctx, sqlDB) if err != nil { return xerrors.Errorf("get users: %w", err) } - log.Info(ctx, "decrypting user tokens", slog.F("user_count", len(users))) - for idx, usr := range users { + log.Info(ctx, "decrypting user tokens", slog.F("user_count", len(userIDs))) + for idx, uid := range userIDs { err := cryptDB.InTx(func(tx database.Store) error { - userLinks, err := tx.GetUserLinksByUserID(ctx, usr.ID) + userLinks, err := tx.GetUserLinksByUserID(ctx, uid) if err != nil { return xerrors.Errorf("get user links for user: %w", err) } for _, userLink := range userLinks { if !userLink.OAuthAccessTokenKeyID.Valid && !userLink.OAuthRefreshTokenKeyID.Valid { - log.Debug(ctx, "skipping user link", slog.F("user_id", usr.ID), slog.F("current", idx+1)) + log.Debug(ctx, "skipping user link", slog.F("user_id", uid), slog.F("current", idx+1)) continue } if _, err := tx.UpdateUserLink(ctx, database.UpdateUserLinkParams{ OAuthAccessToken: userLink.OAuthAccessToken, OAuthRefreshToken: userLink.OAuthRefreshToken, OAuthExpiry: userLink.OAuthExpiry, - UserID: usr.ID, - LoginType: usr.LoginType, + UserID: uid, + LoginType: userLink.LoginType, }); err != nil { return xerrors.Errorf("update user link user_id=%s linked_id=%s: %w", userLink.UserID, userLink.LinkedID, err) } } - gitAuthLinks, err := tx.GetGitAuthLinksByUserID(ctx, usr.ID) + gitAuthLinks, err := tx.GetGitAuthLinksByUserID(ctx, uid) if err != nil { return xerrors.Errorf("get git auth links for user: %w", err) } for _, gitAuthLink := range gitAuthLinks { if !gitAuthLink.OAuthAccessTokenKeyID.Valid && !gitAuthLink.OAuthRefreshTokenKeyID.Valid { - log.Debug(ctx, "skipping git auth link", slog.F("user_id", usr.ID), slog.F("current", idx+1)) + log.Debug(ctx, "skipping git auth link", slog.F("user_id", uid), slog.F("current", idx+1)) continue } if _, err := tx.UpdateGitAuthLink(ctx, database.UpdateGitAuthLinkParams{ ProviderID: gitAuthLink.ProviderID, - UserID: usr.ID, + UserID: uid, UpdatedAt: gitAuthLink.UpdatedAt, OAuthAccessToken: gitAuthLink.OAuthAccessToken, OAuthRefreshToken: gitAuthLink.OAuthRefreshToken, @@ -157,7 +159,7 @@ func Decrypt(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciph if err != nil { return xerrors.Errorf("update user links: %w", err) } - log.Debug(ctx, "decrypted user tokens", slog.F("user_id", usr.ID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest())) + log.Debug(ctx, "decrypted user tokens", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest())) } // Revoke _all_ keys @@ -212,3 +214,25 @@ func Delete(ctx context.Context, log slog.Logger, sqlDB *sql.DB) error { return nil } + +// allUserIDs returns _all_ user IDs we know about, regardless of status or deletion. +// We need to encrypt / decrypt tokens regardless of user status or deletion as they +// may still be valid. While we could check the expiry, we also don't know if the +// provider is lying about expiry. +// This function will likely only ever be used here, so keeping it here instead +// of exposing it in all of our database-related interfaces. +func allUserIDs(ctx context.Context, sqlDB *sql.DB) ([]uuid.UUID, error) { + var id uuid.UUID + userIDs := make([]uuid.UUID, 0) + rows, err := sqlDB.QueryContext(ctx, `SELECT DISTINCT id FROM users`) + if err != nil { + return nil, xerrors.Errorf("failed to query all user ids: %w", err) + } + for rows.Next() { + if err := rows.Scan(&id); err != nil { + return nil, xerrors.Errorf("failed to scan user_id: %w", err) + } + userIDs = append(userIDs, id) + } + return userIDs, nil +} From 11f85bff7436a1006a91bc47a5bcf447c727f159 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Fri, 15 Sep 2023 14:31:49 +0100 Subject: [PATCH 09/11] make AllUserIDs a fully-fledged query citizen --- coderd/database/dbauthz/dbauthz.go | 9 +++++++++ coderd/database/dbfake/dbfake.go | 8 ++++++++ coderd/database/dbmetrics/dbmetrics.go | 7 +++++++ coderd/database/dbmock/dbmock.go | 15 ++++++++++++++ coderd/database/querier.go | 2 ++ coderd/database/queries.sql.go | 28 ++++++++++++++++++++++++++ coderd/database/queries/dbcrypt.sql | 1 + coderd/database/queries/users.sql | 5 +++++ enterprise/dbcrypt/cliutil.go | 28 ++------------------------ 9 files changed, 77 insertions(+), 26 deletions(-) diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 6156329cf7ddd..55a7eb8acefab 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -664,6 +664,15 @@ func (q *querier) ActivityBumpWorkspace(ctx context.Context, arg uuid.UUID) erro return update(q.log, q.auth, fetch, q.db.ActivityBumpWorkspace)(ctx, arg) } +func (q *querier) AllUserIDs(ctx context.Context) ([]uuid.UUID, error) { + // Although this technically only reads users, only system-related functions should be + // allowed to call this. + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { + return nil, err + } + return q.db.AllUserIDs(ctx) +} + func (q *querier) CleanTailnetCoordinators(ctx context.Context) error { if err := q.authorizeContext(ctx, rbac.ActionDelete, rbac.ResourceTailnetCoordinator); err != nil { return err diff --git a/coderd/database/dbfake/dbfake.go b/coderd/database/dbfake/dbfake.go index ab7363b275a2f..b570a49b1dfcc 100644 --- a/coderd/database/dbfake/dbfake.go +++ b/coderd/database/dbfake/dbfake.go @@ -812,6 +812,14 @@ func (q *FakeQuerier) ActivityBumpWorkspace(ctx context.Context, workspaceID uui return sql.ErrNoRows } +func (q *FakeQuerier) AllUserIDs(ctx context.Context) ([]uuid.UUID, error) { + userIDs := make([]uuid.UUID, 0, len(q.users)) + for idx := range q.users { + userIDs[idx] = q.users[idx].ID + } + return userIDs, nil +} + func (*FakeQuerier) CleanTailnetCoordinators(_ context.Context) error { return ErrUnimplemented } diff --git a/coderd/database/dbmetrics/dbmetrics.go b/coderd/database/dbmetrics/dbmetrics.go index 768c1d4adbcca..82e4d5e68ec70 100644 --- a/coderd/database/dbmetrics/dbmetrics.go +++ b/coderd/database/dbmetrics/dbmetrics.go @@ -100,6 +100,13 @@ func (m metricsStore) ActivityBumpWorkspace(ctx context.Context, arg uuid.UUID) return r0 } +func (m metricsStore) AllUserIDs(ctx context.Context) ([]uuid.UUID, error) { + start := time.Now() + r0, r1 := m.s.AllUserIDs(ctx) + m.queryLatencies.WithLabelValues("AllUserIDs").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m metricsStore) CleanTailnetCoordinators(ctx context.Context) error { start := time.Now() err := m.s.CleanTailnetCoordinators(ctx) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 641dd7315b936..331e279b0a925 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -82,6 +82,21 @@ func (mr *MockStoreMockRecorder) ActivityBumpWorkspace(arg0, arg1 interface{}) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ActivityBumpWorkspace", reflect.TypeOf((*MockStore)(nil).ActivityBumpWorkspace), arg0, arg1) } +// AllUserIDs mocks base method. +func (m *MockStore) AllUserIDs(arg0 context.Context) ([]uuid.UUID, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AllUserIDs", arg0) + ret0, _ := ret[0].([]uuid.UUID) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AllUserIDs indicates an expected call of AllUserIDs. +func (mr *MockStoreMockRecorder) AllUserIDs(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllUserIDs", reflect.TypeOf((*MockStore)(nil).AllUserIDs), arg0) +} + // CleanTailnetCoordinators mocks base method. func (m *MockStore) CleanTailnetCoordinators(arg0 context.Context) error { m.ctrl.T.Helper() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 63c1f7321dc15..b9dc7f5fc4e64 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -31,6 +31,8 @@ type sqlcQuerier interface { // We only bump if workspace shutdown is manual. // We only bump when 5% of the deadline has elapsed. ActivityBumpWorkspace(ctx context.Context, workspaceID uuid.UUID) error + // AllUserIDs returns all UserIDs regardless of user status or deletion. + AllUserIDs(ctx context.Context) ([]uuid.UUID, error) CleanTailnetCoordinators(ctx context.Context) error DeleteAPIKeyByID(ctx context.Context, id string) error DeleteAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 67b7b782b9e29..1f7df65d50f86 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -5840,6 +5840,34 @@ func (q *sqlQuerier) UpdateUserLinkedID(ctx context.Context, arg UpdateUserLinke return i, err } +const allUserIDs = `-- name: AllUserIDs :many +SELECT DISTINCT id FROM USERS +` + +// AllUserIDs returns all UserIDs regardless of user status or deletion. +func (q *sqlQuerier) AllUserIDs(ctx context.Context) ([]uuid.UUID, error) { + rows, err := q.db.QueryContext(ctx, allUserIDs) + if err != nil { + return nil, err + } + defer rows.Close() + var items []uuid.UUID + for rows.Next() { + var id uuid.UUID + if err := rows.Scan(&id); err != nil { + return nil, err + } + items = append(items, id) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getActiveUserCount = `-- name: GetActiveUserCount :one SELECT COUNT(*) diff --git a/coderd/database/queries/dbcrypt.sql b/coderd/database/queries/dbcrypt.sql index ef1021609d5a7..41dc051c8331d 100644 --- a/coderd/database/queries/dbcrypt.sql +++ b/coderd/database/queries/dbcrypt.sql @@ -16,3 +16,4 @@ AND INSERT INTO dbcrypt_keys (number, active_key_digest, created_at, test) VALUES (@number::int, @active_key_digest::text, CURRENT_TIMESTAMP, @test::text); + diff --git a/coderd/database/queries/users.sql b/coderd/database/queries/users.sql index 8560bf0abf696..8caa74a92e588 100644 --- a/coderd/database/queries/users.sql +++ b/coderd/database/queries/users.sql @@ -262,3 +262,8 @@ WHERE last_seen_at < @last_seen_after :: timestamp AND status = 'active'::user_status RETURNING id, email, last_seen_at; + +-- AllUserIDs returns all UserIDs regardless of user status or deletion. +-- name: AllUserIDs :many +SELECT DISTINCT id FROM USERS; + diff --git a/enterprise/dbcrypt/cliutil.go b/enterprise/dbcrypt/cliutil.go index f12213947dca4..3601d0c539c2e 100644 --- a/enterprise/dbcrypt/cliutil.go +++ b/enterprise/dbcrypt/cliutil.go @@ -6,8 +6,6 @@ import ( "golang.org/x/xerrors" - "github.com/google/uuid" - "cdr.dev/slog" "github.com/coder/coder/v2/coderd/database" ) @@ -21,7 +19,7 @@ func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciphe return xerrors.Errorf("create cryptdb: %w", err) } - userIDs, err := allUserIDs(ctx, sqlDB) + userIDs, err := db.AllUserIDs(ctx) if err != nil { return xerrors.Errorf("get users: %w", err) } @@ -105,7 +103,7 @@ func Decrypt(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciph } cryptDB.primaryCipherDigest = "" - userIDs, err := allUserIDs(ctx, sqlDB) + userIDs, err := db.AllUserIDs(ctx) if err != nil { return xerrors.Errorf("get users: %w", err) } @@ -214,25 +212,3 @@ func Delete(ctx context.Context, log slog.Logger, sqlDB *sql.DB) error { return nil } - -// allUserIDs returns _all_ user IDs we know about, regardless of status or deletion. -// We need to encrypt / decrypt tokens regardless of user status or deletion as they -// may still be valid. While we could check the expiry, we also don't know if the -// provider is lying about expiry. -// This function will likely only ever be used here, so keeping it here instead -// of exposing it in all of our database-related interfaces. -func allUserIDs(ctx context.Context, sqlDB *sql.DB) ([]uuid.UUID, error) { - var id uuid.UUID - userIDs := make([]uuid.UUID, 0) - rows, err := sqlDB.QueryContext(ctx, `SELECT DISTINCT id FROM users`) - if err != nil { - return nil, xerrors.Errorf("failed to query all user ids: %w", err) - } - for rows.Next() { - if err := rows.Scan(&id); err != nil { - return nil, xerrors.Errorf("failed to scan user_id: %w", err) - } - userIDs = append(userIDs, id) - } - return userIDs, nil -} From 0c11d6c52767b9c1d579b5988855e553cef8c53f Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Fri, 15 Sep 2023 14:48:07 +0100 Subject: [PATCH 10/11] fixup! make AllUserIDs a fully-fledged query citizen --- coderd/database/dbfake/dbfake.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/coderd/database/dbfake/dbfake.go b/coderd/database/dbfake/dbfake.go index b570a49b1dfcc..cccfe28a8922b 100644 --- a/coderd/database/dbfake/dbfake.go +++ b/coderd/database/dbfake/dbfake.go @@ -812,7 +812,7 @@ func (q *FakeQuerier) ActivityBumpWorkspace(ctx context.Context, workspaceID uui return sql.ErrNoRows } -func (q *FakeQuerier) AllUserIDs(ctx context.Context) ([]uuid.UUID, error) { +func (q *FakeQuerier) AllUserIDs(_ context.Context) ([]uuid.UUID, error) { userIDs := make([]uuid.UUID, 0, len(q.users)) for idx := range q.users { userIDs[idx] = q.users[idx].ID From 0f164a3e122ee438660385408e207d9c8a74f7f7 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Fri, 15 Sep 2023 13:48:47 +0000 Subject: [PATCH 11/11] fixup! fixup! make AllUserIDs a fully-fledged query citizen --- coderd/database/dbfake/dbfake.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/coderd/database/dbfake/dbfake.go b/coderd/database/dbfake/dbfake.go index cccfe28a8922b..9ac8ed640af2f 100644 --- a/coderd/database/dbfake/dbfake.go +++ b/coderd/database/dbfake/dbfake.go @@ -813,6 +813,8 @@ func (q *FakeQuerier) ActivityBumpWorkspace(ctx context.Context, workspaceID uui } func (q *FakeQuerier) AllUserIDs(_ context.Context) ([]uuid.UUID, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() userIDs := make([]uuid.UUID, 0, len(q.users)) for idx := range q.users { userIDs[idx] = q.users[idx].ID