6
6
7
7
"golang.org/x/xerrors"
8
8
9
+ "github.com/google/uuid"
10
+
9
11
"cdr.dev/slog"
10
12
"github.com/coder/coder/v2/coderd/database"
11
13
)
@@ -19,45 +21,45 @@ func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciphe
19
21
return xerrors .Errorf ("create cryptdb: %w" , err )
20
22
}
21
23
22
- users , err := cryptDB . GetUsers (ctx , database. GetUsersParams {} )
24
+ userIDs , err := allUserIDs (ctx , sqlDB )
23
25
if err != nil {
24
26
return xerrors .Errorf ("get users: %w" , err )
25
27
}
26
- log .Info (ctx , "encrypting user tokens" , slog .F ("user_count" , len (users )))
27
- for idx , usr := range users {
28
+ log .Info (ctx , "encrypting user tokens" , slog .F ("user_count" , len (userIDs )))
29
+ for idx , uid := range userIDs {
28
30
err := cryptDB .InTx (func (tx database.Store ) error {
29
- userLinks , err := tx .GetUserLinksByUserID (ctx , usr . ID )
31
+ userLinks , err := tx .GetUserLinksByUserID (ctx , uid )
30
32
if err != nil {
31
33
return xerrors .Errorf ("get user links for user: %w" , err )
32
34
}
33
35
for _ , userLink := range userLinks {
34
36
if userLink .OAuthAccessTokenKeyID .String == ciphers [0 ].HexDigest () && userLink .OAuthRefreshTokenKeyID .String == ciphers [0 ].HexDigest () {
35
- log .Debug (ctx , "skipping user link" , slog .F ("user_id" , usr . ID ), slog .F ("current" , idx + 1 ), slog .F ("cipher" , ciphers [0 ].HexDigest ()))
37
+ log .Debug (ctx , "skipping user link" , slog .F ("user_id" , uid ), slog .F ("current" , idx + 1 ), slog .F ("cipher" , ciphers [0 ].HexDigest ()))
36
38
continue
37
39
}
38
40
if _ , err := tx .UpdateUserLink (ctx , database.UpdateUserLinkParams {
39
41
OAuthAccessToken : userLink .OAuthAccessToken ,
40
42
OAuthRefreshToken : userLink .OAuthRefreshToken ,
41
43
OAuthExpiry : userLink .OAuthExpiry ,
42
- UserID : usr . ID ,
43
- LoginType : usr .LoginType ,
44
+ UserID : uid ,
45
+ LoginType : userLink .LoginType ,
44
46
}); err != nil {
45
47
return xerrors .Errorf ("update user link user_id=%s linked_id=%s: %w" , userLink .UserID , userLink .LinkedID , err )
46
48
}
47
49
}
48
50
49
- gitAuthLinks , err := tx .GetGitAuthLinksByUserID (ctx , usr . ID )
51
+ gitAuthLinks , err := tx .GetGitAuthLinksByUserID (ctx , uid )
50
52
if err != nil {
51
53
return xerrors .Errorf ("get git auth links for user: %w" , err )
52
54
}
53
55
for _ , gitAuthLink := range gitAuthLinks {
54
56
if gitAuthLink .OAuthAccessTokenKeyID .String == ciphers [0 ].HexDigest () && gitAuthLink .OAuthRefreshTokenKeyID .String == ciphers [0 ].HexDigest () {
55
- log .Debug (ctx , "skipping git auth link" , slog .F ("user_id" , usr . ID ), slog .F ("current" , idx + 1 ), slog .F ("cipher" , ciphers [0 ].HexDigest ()))
57
+ log .Debug (ctx , "skipping git auth link" , slog .F ("user_id" , uid ), slog .F ("current" , idx + 1 ), slog .F ("cipher" , ciphers [0 ].HexDigest ()))
56
58
continue
57
59
}
58
60
if _ , err := tx .UpdateGitAuthLink (ctx , database.UpdateGitAuthLinkParams {
59
61
ProviderID : gitAuthLink .ProviderID ,
60
- UserID : usr . ID ,
62
+ UserID : uid ,
61
63
UpdatedAt : gitAuthLink .UpdatedAt ,
62
64
OAuthAccessToken : gitAuthLink .OAuthAccessToken ,
63
65
OAuthRefreshToken : gitAuthLink .OAuthRefreshToken ,
@@ -73,7 +75,7 @@ func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciphe
73
75
if err != nil {
74
76
return xerrors .Errorf ("update user links: %w" , err )
75
77
}
76
- log .Debug (ctx , "encrypted user tokens" , slog .F ("user_id" , usr . ID ), slog .F ("current" , idx + 1 ), slog .F ("cipher" , ciphers [0 ].HexDigest ()))
78
+ log .Debug (ctx , "encrypted user tokens" , slog .F ("user_id" , uid ), slog .F ("current" , idx + 1 ), slog .F ("cipher" , ciphers [0 ].HexDigest ()))
77
79
}
78
80
79
81
// Revoke old keys
@@ -103,45 +105,45 @@ func Decrypt(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciph
103
105
}
104
106
cryptDB .primaryCipherDigest = ""
105
107
106
- users , err := cryptDB . GetUsers (ctx , database. GetUsersParams {} )
108
+ userIDs , err := allUserIDs (ctx , sqlDB )
107
109
if err != nil {
108
110
return xerrors .Errorf ("get users: %w" , err )
109
111
}
110
- log .Info (ctx , "decrypting user tokens" , slog .F ("user_count" , len (users )))
111
- for idx , usr := range users {
112
+ log .Info (ctx , "decrypting user tokens" , slog .F ("user_count" , len (userIDs )))
113
+ for idx , uid := range userIDs {
112
114
err := cryptDB .InTx (func (tx database.Store ) error {
113
- userLinks , err := tx .GetUserLinksByUserID (ctx , usr . ID )
115
+ userLinks , err := tx .GetUserLinksByUserID (ctx , uid )
114
116
if err != nil {
115
117
return xerrors .Errorf ("get user links for user: %w" , err )
116
118
}
117
119
for _ , userLink := range userLinks {
118
120
if ! userLink .OAuthAccessTokenKeyID .Valid && ! userLink .OAuthRefreshTokenKeyID .Valid {
119
- log .Debug (ctx , "skipping user link" , slog .F ("user_id" , usr . ID ), slog .F ("current" , idx + 1 ))
121
+ log .Debug (ctx , "skipping user link" , slog .F ("user_id" , uid ), slog .F ("current" , idx + 1 ))
120
122
continue
121
123
}
122
124
if _ , err := tx .UpdateUserLink (ctx , database.UpdateUserLinkParams {
123
125
OAuthAccessToken : userLink .OAuthAccessToken ,
124
126
OAuthRefreshToken : userLink .OAuthRefreshToken ,
125
127
OAuthExpiry : userLink .OAuthExpiry ,
126
- UserID : usr . ID ,
127
- LoginType : usr .LoginType ,
128
+ UserID : uid ,
129
+ LoginType : userLink .LoginType ,
128
130
}); err != nil {
129
131
return xerrors .Errorf ("update user link user_id=%s linked_id=%s: %w" , userLink .UserID , userLink .LinkedID , err )
130
132
}
131
133
}
132
134
133
- gitAuthLinks , err := tx .GetGitAuthLinksByUserID (ctx , usr . ID )
135
+ gitAuthLinks , err := tx .GetGitAuthLinksByUserID (ctx , uid )
134
136
if err != nil {
135
137
return xerrors .Errorf ("get git auth links for user: %w" , err )
136
138
}
137
139
for _ , gitAuthLink := range gitAuthLinks {
138
140
if ! gitAuthLink .OAuthAccessTokenKeyID .Valid && ! gitAuthLink .OAuthRefreshTokenKeyID .Valid {
139
- log .Debug (ctx , "skipping git auth link" , slog .F ("user_id" , usr . ID ), slog .F ("current" , idx + 1 ))
141
+ log .Debug (ctx , "skipping git auth link" , slog .F ("user_id" , uid ), slog .F ("current" , idx + 1 ))
140
142
continue
141
143
}
142
144
if _ , err := tx .UpdateGitAuthLink (ctx , database.UpdateGitAuthLinkParams {
143
145
ProviderID : gitAuthLink .ProviderID ,
144
- UserID : usr . ID ,
146
+ UserID : uid ,
145
147
UpdatedAt : gitAuthLink .UpdatedAt ,
146
148
OAuthAccessToken : gitAuthLink .OAuthAccessToken ,
147
149
OAuthRefreshToken : gitAuthLink .OAuthRefreshToken ,
@@ -157,7 +159,7 @@ func Decrypt(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciph
157
159
if err != nil {
158
160
return xerrors .Errorf ("update user links: %w" , err )
159
161
}
160
- log .Debug (ctx , "decrypted user tokens" , slog .F ("user_id" , usr . ID ), slog .F ("current" , idx + 1 ), slog .F ("cipher" , ciphers [0 ].HexDigest ()))
162
+ log .Debug (ctx , "decrypted user tokens" , slog .F ("user_id" , uid ), slog .F ("current" , idx + 1 ), slog .F ("cipher" , ciphers [0 ].HexDigest ()))
161
163
}
162
164
163
165
// Revoke _all_ keys
@@ -212,3 +214,25 @@ func Delete(ctx context.Context, log slog.Logger, sqlDB *sql.DB) error {
212
214
213
215
return nil
214
216
}
217
+
218
+ // allUserIDs returns _all_ user IDs we know about, regardless of status or deletion.
219
+ // We need to encrypt / decrypt tokens regardless of user status or deletion as they
220
+ // may still be valid. While we could check the expiry, we also don't know if the
221
+ // provider is lying about expiry.
222
+ // This function will likely only ever be used here, so keeping it here instead
223
+ // of exposing it in all of our database-related interfaces.
224
+ func allUserIDs (ctx context.Context , sqlDB * sql.DB ) ([]uuid.UUID , error ) {
225
+ var id uuid.UUID
226
+ userIDs := make ([]uuid.UUID , 0 )
227
+ rows , err := sqlDB .QueryContext (ctx , `SELECT DISTINCT id FROM users` )
228
+ if err != nil {
229
+ return nil , xerrors .Errorf ("failed to query all user ids: %w" , err )
230
+ }
231
+ for rows .Next () {
232
+ if err := rows .Scan (& id ); err != nil {
233
+ return nil , xerrors .Errorf ("failed to scan user_id: %w" , err )
234
+ }
235
+ userIDs = append (userIDs , id )
236
+ }
237
+ return userIDs , nil
238
+ }
0 commit comments