@@ -6,11 +6,8 @@ import (
6
6
"encoding/base64"
7
7
"fmt"
8
8
9
- "sync/atomic"
10
9
"testing"
11
10
12
- "cdr.dev/slog/sloggers/slogtest"
13
-
14
11
"github.com/coder/coder/v2/coderd/database"
15
12
"github.com/coder/coder/v2/coderd/database/dbgen"
16
13
"github.com/coder/coder/v2/coderd/database/dbtestutil"
@@ -47,17 +44,10 @@ func TestDBCryptRotate(t *testing.T) {
47
44
keyA := mustString (t , 32 )
48
45
cA , err := dbcrypt .CipherAES256 ([]byte (keyA ))
49
46
require .NoError (t , err )
50
- cipherA := & atomic.Pointer [dbcrypt.Cipher ]{}
51
- cipherB := & atomic.Pointer [dbcrypt.Cipher ]{}
52
- cipherA .Store (& cA )
47
+ ciphers := dbcrypt .NewCiphers (cA )
53
48
54
49
// Create an encrypted database
55
- log := slogtest .Make (t , & slogtest.Options {IgnoreErrors : true })
56
- cryptdb , err := dbcrypt .New (ctx , db , & dbcrypt.Options {
57
- PrimaryCipher : cipherA ,
58
- SecondaryCipher : cipherB ,
59
- Logger : log ,
60
- })
50
+ cryptdb , err := dbcrypt .New (ctx , db , ciphers )
61
51
require .NoError (t , err )
62
52
63
53
// Populate the database with some data encrypted with cipher A.
@@ -102,31 +92,37 @@ func TestDBCryptRotate(t *testing.T) {
102
92
require .NoError (t , err )
103
93
104
94
// Validate that all data has been updated with the checksum of the new cipher.
105
- expectedPrefixA := fmt .Sprintf ("dbcrypt-%s-" , cA .HexDigest ()[:7 ])
106
- expectedPrefixB := fmt .Sprintf ("dbcrypt-%s-" , cB .HexDigest ()[:7 ])
107
95
for _ , usr := range users {
108
96
ul , err := db .GetUserLinkByUserIDLoginType (ctx , database.GetUserLinkByUserIDLoginTypeParams {
109
97
UserID : usr .ID ,
110
98
LoginType : usr .LoginType ,
111
99
})
112
100
require .NoError (t , err , "failed to get user link for user %s" , usr .ID )
113
- require .NotContains (t , ul .OAuthAccessToken , expectedPrefixA , "user_link.oauth_access_token should not contain the old cipher checksum" )
114
- require .NotContains (t , ul .OAuthRefreshToken , expectedPrefixA , "user_link.oauth_refresh_token should not contain the old cipher checksum" )
115
- require .Contains (t , ul .OAuthAccessToken , expectedPrefixB , "user_link.oauth_access_token should contain the new cipher checksum" )
116
- require .Contains (t , ul .OAuthRefreshToken , expectedPrefixB , "user_link.oauth_refresh_token should contain the new cipher checksum" )
101
+ requireEncrypted (t , cB , ul .OAuthAccessToken )
102
+ requireEncrypted (t , cB , ul .OAuthRefreshToken )
117
103
118
104
gal , err := db .GetGitAuthLink (ctx , database.GetGitAuthLinkParams {
119
105
UserID : usr .ID ,
120
106
ProviderID : "fake" ,
121
107
})
122
108
require .NoError (t , err , "failed to get git auth link for user %s" , usr .ID )
123
- require .NotContains (t , gal .OAuthAccessToken , expectedPrefixA , "git_auth_link.oauth_access_token should not contain the old cipher checksum" )
124
- require .NotContains (t , gal .OAuthRefreshToken , expectedPrefixA , "git_auth_link.oauth_refresh_token should not contain the old cipher checksum" )
125
- require .Contains (t , gal .OAuthAccessToken , expectedPrefixB , "git_auth_link.oauth_access_token should contain the new cipher checksum" )
126
- require .Contains (t , gal .OAuthRefreshToken , expectedPrefixB , "git_auth_link.oauth_refresh_token should contain the new cipher checksum" )
109
+ requireEncrypted (t , cB , gal .OAuthAccessToken )
110
+ requireEncrypted (t , cB , gal .OAuthRefreshToken )
127
111
}
128
112
}
129
113
114
+ func requireEncrypted (t * testing.T , c dbcrypt.Cipher , s string ) {
115
+ t .Helper ()
116
+ require .Greater (t , len (s ), 8 , "encrypted string is too short" )
117
+ require .Equal (t , dbcrypt .MagicPrefix , s [:8 ], "missing magic prefix" )
118
+ decodedVal , err := base64 .StdEncoding .DecodeString (s [8 :])
119
+ require .NoError (t , err , "failed to decode base64 string" )
120
+ require .Greater (t , len (decodedVal ), 8 , "base64-decoded value is too short" )
121
+ require .Equal (t , c .HexDigest (), string (decodedVal [:7 ]), "cipher digest does not match" )
122
+ _ , err = c .Decrypt (decodedVal [8 :])
123
+ require .NoError (t , err , "failed to decrypt value" )
124
+ }
125
+
130
126
func mustString (t * testing.T , n int ) string {
131
127
t .Helper ()
132
128
s , err := cryptorand .String (n )
0 commit comments