Skip to content

Commit 2f63e43

Browse files
committed
do not skip deleted users when encrypting or decrypting
1 parent 7da5972 commit 2f63e43

File tree

1 file changed

+46
-22
lines changed

1 file changed

+46
-22
lines changed

enterprise/dbcrypt/cliutil.go

Lines changed: 46 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import (
66

77
"golang.org/x/xerrors"
88

9+
"github.com/google/uuid"
10+
911
"cdr.dev/slog"
1012
"github.com/coder/coder/v2/coderd/database"
1113
)
@@ -19,45 +21,45 @@ func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciphe
1921
return xerrors.Errorf("create cryptdb: %w", err)
2022
}
2123

22-
users, err := cryptDB.GetUsers(ctx, database.GetUsersParams{})
24+
userIDs, err := allUserIDs(ctx, sqlDB)
2325
if err != nil {
2426
return xerrors.Errorf("get users: %w", err)
2527
}
26-
log.Info(ctx, "encrypting user tokens", slog.F("user_count", len(users)))
27-
for idx, usr := range users {
28+
log.Info(ctx, "encrypting user tokens", slog.F("user_count", len(userIDs)))
29+
for idx, uid := range userIDs {
2830
err := cryptDB.InTx(func(tx database.Store) error {
29-
userLinks, err := tx.GetUserLinksByUserID(ctx, usr.ID)
31+
userLinks, err := tx.GetUserLinksByUserID(ctx, uid)
3032
if err != nil {
3133
return xerrors.Errorf("get user links for user: %w", err)
3234
}
3335
for _, userLink := range userLinks {
3436
if userLink.OAuthAccessTokenKeyID.String == ciphers[0].HexDigest() && userLink.OAuthRefreshTokenKeyID.String == ciphers[0].HexDigest() {
35-
log.Debug(ctx, "skipping user link", slog.F("user_id", usr.ID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
37+
log.Debug(ctx, "skipping user link", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
3638
continue
3739
}
3840
if _, err := tx.UpdateUserLink(ctx, database.UpdateUserLinkParams{
3941
OAuthAccessToken: userLink.OAuthAccessToken,
4042
OAuthRefreshToken: userLink.OAuthRefreshToken,
4143
OAuthExpiry: userLink.OAuthExpiry,
42-
UserID: usr.ID,
43-
LoginType: usr.LoginType,
44+
UserID: uid,
45+
LoginType: userLink.LoginType,
4446
}); err != nil {
4547
return xerrors.Errorf("update user link user_id=%s linked_id=%s: %w", userLink.UserID, userLink.LinkedID, err)
4648
}
4749
}
4850

49-
gitAuthLinks, err := tx.GetGitAuthLinksByUserID(ctx, usr.ID)
51+
gitAuthLinks, err := tx.GetGitAuthLinksByUserID(ctx, uid)
5052
if err != nil {
5153
return xerrors.Errorf("get git auth links for user: %w", err)
5254
}
5355
for _, gitAuthLink := range gitAuthLinks {
5456
if gitAuthLink.OAuthAccessTokenKeyID.String == ciphers[0].HexDigest() && gitAuthLink.OAuthRefreshTokenKeyID.String == ciphers[0].HexDigest() {
55-
log.Debug(ctx, "skipping git auth link", slog.F("user_id", usr.ID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
57+
log.Debug(ctx, "skipping git auth link", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
5658
continue
5759
}
5860
if _, err := tx.UpdateGitAuthLink(ctx, database.UpdateGitAuthLinkParams{
5961
ProviderID: gitAuthLink.ProviderID,
60-
UserID: usr.ID,
62+
UserID: uid,
6163
UpdatedAt: gitAuthLink.UpdatedAt,
6264
OAuthAccessToken: gitAuthLink.OAuthAccessToken,
6365
OAuthRefreshToken: gitAuthLink.OAuthRefreshToken,
@@ -73,7 +75,7 @@ func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciphe
7375
if err != nil {
7476
return xerrors.Errorf("update user links: %w", err)
7577
}
76-
log.Debug(ctx, "encrypted user tokens", slog.F("user_id", usr.ID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
78+
log.Debug(ctx, "encrypted user tokens", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
7779
}
7880

7981
// Revoke old keys
@@ -103,45 +105,45 @@ func Decrypt(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciph
103105
}
104106
cryptDB.primaryCipherDigest = ""
105107

106-
users, err := cryptDB.GetUsers(ctx, database.GetUsersParams{})
108+
userIDs, err := allUserIDs(ctx, sqlDB)
107109
if err != nil {
108110
return xerrors.Errorf("get users: %w", err)
109111
}
110-
log.Info(ctx, "decrypting user tokens", slog.F("user_count", len(users)))
111-
for idx, usr := range users {
112+
log.Info(ctx, "decrypting user tokens", slog.F("user_count", len(userIDs)))
113+
for idx, uid := range userIDs {
112114
err := cryptDB.InTx(func(tx database.Store) error {
113-
userLinks, err := tx.GetUserLinksByUserID(ctx, usr.ID)
115+
userLinks, err := tx.GetUserLinksByUserID(ctx, uid)
114116
if err != nil {
115117
return xerrors.Errorf("get user links for user: %w", err)
116118
}
117119
for _, userLink := range userLinks {
118120
if !userLink.OAuthAccessTokenKeyID.Valid && !userLink.OAuthRefreshTokenKeyID.Valid {
119-
log.Debug(ctx, "skipping user link", slog.F("user_id", usr.ID), slog.F("current", idx+1))
121+
log.Debug(ctx, "skipping user link", slog.F("user_id", uid), slog.F("current", idx+1))
120122
continue
121123
}
122124
if _, err := tx.UpdateUserLink(ctx, database.UpdateUserLinkParams{
123125
OAuthAccessToken: userLink.OAuthAccessToken,
124126
OAuthRefreshToken: userLink.OAuthRefreshToken,
125127
OAuthExpiry: userLink.OAuthExpiry,
126-
UserID: usr.ID,
127-
LoginType: usr.LoginType,
128+
UserID: uid,
129+
LoginType: userLink.LoginType,
128130
}); err != nil {
129131
return xerrors.Errorf("update user link user_id=%s linked_id=%s: %w", userLink.UserID, userLink.LinkedID, err)
130132
}
131133
}
132134

133-
gitAuthLinks, err := tx.GetGitAuthLinksByUserID(ctx, usr.ID)
135+
gitAuthLinks, err := tx.GetGitAuthLinksByUserID(ctx, uid)
134136
if err != nil {
135137
return xerrors.Errorf("get git auth links for user: %w", err)
136138
}
137139
for _, gitAuthLink := range gitAuthLinks {
138140
if !gitAuthLink.OAuthAccessTokenKeyID.Valid && !gitAuthLink.OAuthRefreshTokenKeyID.Valid {
139-
log.Debug(ctx, "skipping git auth link", slog.F("user_id", usr.ID), slog.F("current", idx+1))
141+
log.Debug(ctx, "skipping git auth link", slog.F("user_id", uid), slog.F("current", idx+1))
140142
continue
141143
}
142144
if _, err := tx.UpdateGitAuthLink(ctx, database.UpdateGitAuthLinkParams{
143145
ProviderID: gitAuthLink.ProviderID,
144-
UserID: usr.ID,
146+
UserID: uid,
145147
UpdatedAt: gitAuthLink.UpdatedAt,
146148
OAuthAccessToken: gitAuthLink.OAuthAccessToken,
147149
OAuthRefreshToken: gitAuthLink.OAuthRefreshToken,
@@ -157,7 +159,7 @@ func Decrypt(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciph
157159
if err != nil {
158160
return xerrors.Errorf("update user links: %w", err)
159161
}
160-
log.Debug(ctx, "decrypted user tokens", slog.F("user_id", usr.ID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
162+
log.Debug(ctx, "decrypted user tokens", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
161163
}
162164

163165
// Revoke _all_ keys
@@ -212,3 +214,25 @@ func Delete(ctx context.Context, log slog.Logger, sqlDB *sql.DB) error {
212214

213215
return nil
214216
}
217+
218+
// allUserIDs returns _all_ user IDs we know about, regardless of status or deletion.
219+
// We need to encrypt / decrypt tokens regardless of user status or deletion as they
220+
// may still be valid. While we could check the expiry, we also don't know if the
221+
// provider is lying about expiry.
222+
// This function will likely only ever be used here, so keeping it here instead
223+
// of exposing it in all of our database-related interfaces.
224+
func allUserIDs(ctx context.Context, sqlDB *sql.DB) ([]uuid.UUID, error) {
225+
var id uuid.UUID
226+
userIDs := make([]uuid.UUID, 0)
227+
rows, err := sqlDB.QueryContext(ctx, `SELECT DISTINCT id FROM users`)
228+
if err != nil {
229+
return nil, xerrors.Errorf("failed to query all user ids: %w", err)
230+
}
231+
for rows.Next() {
232+
if err := rows.Scan(&id); err != nil {
233+
return nil, xerrors.Errorf("failed to scan user_id: %w", err)
234+
}
235+
userIDs = append(userIDs, id)
236+
}
237+
return userIDs, nil
238+
}

0 commit comments

Comments
 (0)