@@ -27,7 +27,6 @@ func TestUserLinks(t *testing.T) {
27
27
t .Run ("InsertUserLink" , func (t * testing.T ) {
28
28
t .Parallel ()
29
29
db , crypt , cipher := setup (t )
30
- initCipher (t , cipher )
31
30
user := dbgen .User (t , crypt , database.User {})
32
31
link := dbgen .UserLink (t , crypt , database.UserLink {
33
32
UserID : user .ID ,
@@ -43,7 +42,6 @@ func TestUserLinks(t *testing.T) {
43
42
t .Run ("UpdateUserLink" , func (t * testing.T ) {
44
43
t .Parallel ()
45
44
db , crypt , cipher := setup (t )
46
- initCipher (t , cipher )
47
45
user := dbgen .User (t , crypt , database.User {})
48
46
link := dbgen .UserLink (t , crypt , database.UserLink {
49
47
UserID : user .ID ,
@@ -64,7 +62,6 @@ func TestUserLinks(t *testing.T) {
64
62
t .Run ("GetUserLinkByLinkedID" , func (t * testing.T ) {
65
63
t .Parallel ()
66
64
db , crypt , cipher := setup (t )
67
- initCipher (t , cipher )
68
65
user := dbgen .User (t , crypt , database.User {})
69
66
link := dbgen .UserLink (t , crypt , database.UserLink {
70
67
UserID : user .ID ,
@@ -86,7 +83,6 @@ func TestUserLinks(t *testing.T) {
86
83
t .Run ("GetUserLinkByUserIDLoginType" , func (t * testing.T ) {
87
84
t .Parallel ()
88
85
db , crypt , cipher := setup (t )
89
- initCipher (t , cipher )
90
86
user := dbgen .User (t , crypt , database.User {})
91
87
link := dbgen .UserLink (t , crypt , database.UserLink {
92
88
UserID : user .ID ,
@@ -119,7 +115,6 @@ func TestGitAuthLinks(t *testing.T) {
119
115
t .Run ("InsertGitAuthLink" , func (t * testing.T ) {
120
116
t .Parallel ()
121
117
db , crypt , cipher := setup (t )
122
- initCipher (t , cipher )
123
118
link := dbgen .GitAuthLink (t , crypt , database.GitAuthLink {
124
119
OAuthAccessToken : "access" ,
125
120
OAuthRefreshToken : "refresh" ,
@@ -136,7 +131,6 @@ func TestGitAuthLinks(t *testing.T) {
136
131
t .Run ("UpdateGitAuthLink" , func (t * testing.T ) {
137
132
t .Parallel ()
138
133
db , crypt , cipher := setup (t )
139
- initCipher (t , cipher )
140
134
link := dbgen .GitAuthLink (t , crypt , database.GitAuthLink {})
141
135
_ , err := crypt .UpdateGitAuthLink (ctx , database.UpdateGitAuthLinkParams {
142
136
ProviderID : link .ProviderID ,
@@ -157,7 +151,6 @@ func TestGitAuthLinks(t *testing.T) {
157
151
t .Run ("GetGitAuthLink" , func (t * testing.T ) {
158
152
t .Parallel ()
159
153
db , crypt , cipher := setup (t )
160
- initCipher (t , cipher )
161
154
link := dbgen .GitAuthLink (t , crypt , database.GitAuthLink {
162
155
OAuthAccessToken : "access" ,
163
156
OAuthRefreshToken : "refresh" ,
@@ -181,37 +174,90 @@ func TestGitAuthLinks(t *testing.T) {
181
174
})
182
175
}
183
176
184
- func TestDBCryptSentinelValue (t * testing.T ) {
177
+ func TestNew (t * testing.T ) {
185
178
t .Parallel ()
186
- ctx := context .Background ()
187
- db , crypt , cipher := setup (t )
188
- // Initially, the database will not be encrypted.
189
- _ , err := db .GetDBCryptSentinelValue (ctx )
190
- require .ErrorIs (t , err , sql .ErrNoRows )
191
- _ , err = crypt .GetDBCryptSentinelValue (ctx )
192
- require .EqualError (t , err , dbcrypt .ErrNotEncrypted .Error ())
193
179
194
- // Now, we'll encrypt the value.
195
- initCipher (t , cipher )
196
- err = crypt .SetDBCryptSentinelValue (ctx , "coder" )
197
- require .NoError (t , err )
180
+ t .Run ("OK" , func (t * testing.T ) {
181
+ // Given: a cipher is loaded
182
+ cipher := & atomic.Pointer [dbcrypt.Cipher ]{}
183
+ initCipher (t , cipher )
184
+ ctx , cancel := context .WithCancel (context .Background ())
185
+ t .Cleanup (cancel )
186
+ rawDB , _ := dbtestutil .NewDB (t )
198
187
199
- // The value should be encrypted in the database.
200
- crypted , err := db .GetDBCryptSentinelValue (ctx )
201
- require .NoError (t , err )
202
- require .NotEqual (t , "coder" , crypted )
203
- decrypted , err := crypt .GetDBCryptSentinelValue (ctx )
204
- require .NoError (t , err )
205
- require .Equal (t , "coder" , decrypted )
206
- requireEncryptedEquals (t , cipher , crypted , "coder" )
188
+ // When: we init the crypt db
189
+ cryptDB , err := dbcrypt .New (ctx , rawDB , & dbcrypt.Options {
190
+ ExternalTokenCipher : cipher ,
191
+ Logger : slogtest .Make (t , nil ).Leveled (slog .LevelDebug ),
192
+ })
193
+ require .NoError (t , err )
207
194
208
- // Reset the key and empty values should be returned!
209
- initCipher (t , cipher )
195
+ // Then: the sentinel value is encrypted
196
+ cryptVal , err := cryptDB .GetDBCryptSentinelValue (ctx )
197
+ require .NoError (t , err )
198
+ require .Equal (t , "coder" , cryptVal )
210
199
211
- _ , err = db .GetDBCryptSentinelValue (ctx ) // We can still read the raw value
212
- require .NoError (t , err )
213
- _ , err = crypt .GetDBCryptSentinelValue (ctx ) // Decryption should fail
214
- require .ErrorIs (t , err , sql .ErrNoRows )
200
+ rawVal , err := rawDB .GetDBCryptSentinelValue (ctx )
201
+ require .NoError (t , err )
202
+ require .Contains (t , rawVal , dbcrypt .MagicPrefix )
203
+ })
204
+
205
+ t .Run ("NoCipher" , func (t * testing.T ) {
206
+ // Given: no cipher is loaded
207
+ cipher := & atomic.Pointer [dbcrypt.Cipher ]{}
208
+ // initCipher(t, cipher)
209
+ ctx , cancel := context .WithCancel (context .Background ())
210
+ t .Cleanup (cancel )
211
+ rawDB , _ := dbtestutil .NewDB (t )
212
+
213
+ // When: we init the crypt db
214
+ cryptDB , err := dbcrypt .New (ctx , rawDB , & dbcrypt.Options {
215
+ ExternalTokenCipher : cipher ,
216
+ Logger : slogtest .Make (t , nil ).Leveled (slog .LevelDebug ),
217
+ })
218
+ require .NoError (t , err )
219
+
220
+ // Then: the sentinel value is not encrypted
221
+ cryptVal , err := cryptDB .GetDBCryptSentinelValue (ctx )
222
+ require .NoError (t , err )
223
+ require .Equal (t , "coder" , cryptVal )
224
+
225
+ rawVal , err := rawDB .GetDBCryptSentinelValue (ctx )
226
+ require .NoError (t , err )
227
+ require .Equal (t , "coder" , rawVal )
228
+ })
229
+
230
+ t .Run ("CipherChanged" , func (t * testing.T ) {
231
+ // Given: no cipher is loaded
232
+ cipher := & atomic.Pointer [dbcrypt.Cipher ]{}
233
+ initCipher (t , cipher )
234
+ ctx , cancel := context .WithCancel (context .Background ())
235
+ t .Cleanup (cancel )
236
+ rawDB , _ := dbtestutil .NewDB (t )
237
+
238
+ // And: the sentinel value is encrypted with a different cipher
239
+ cipher2 := & atomic.Pointer [dbcrypt.Cipher ]{}
240
+ initCipher (t , cipher2 )
241
+ field := "coder"
242
+ encrypted , err := (* cipher2 .Load ()).Encrypt ([]byte (field ))
243
+ require .NoError (t , err )
244
+ b64encrypted := base64 .StdEncoding .EncodeToString (encrypted )
245
+ require .NoError (t , rawDB .SetDBCryptSentinelValue (ctx , b64encrypted ))
246
+
247
+ // When: we init the crypt db
248
+ _ , err = dbcrypt .New (ctx , rawDB , & dbcrypt.Options {
249
+ ExternalTokenCipher : cipher ,
250
+ Logger : slogtest .Make (t , nil ).Leveled (slog .LevelDebug ),
251
+ })
252
+ // Then: an error is returned
253
+ // TODO: when we implement key rotation, this should not fail.
254
+ require .ErrorContains (t , err , "database is already encrypted with a different key" )
255
+
256
+ // And the sentinel value should remain unchanged. For now.
257
+ rawVal , err := rawDB .GetDBCryptSentinelValue (ctx )
258
+ require .NoError (t , err )
259
+ require .Equal (t , b64encrypted , rawVal )
260
+ })
215
261
}
216
262
217
263
func requireEncryptedEquals (t * testing.T , cipher * atomic.Pointer [dbcrypt.Cipher ], value , expected string ) {
@@ -238,10 +284,28 @@ func initCipher(t *testing.T, cipher *atomic.Pointer[dbcrypt.Cipher]) {
238
284
239
285
func setup (t * testing.T ) (db , cryptodb database.Store , cipher * atomic.Pointer [dbcrypt.Cipher ]) {
240
286
t .Helper ()
287
+ ctx , cancel := context .WithCancel (context .Background ())
288
+ t .Cleanup (cancel )
241
289
rawDB , _ := dbtestutil .NewDB (t )
290
+
291
+ _ , err := rawDB .GetDBCryptSentinelValue (ctx )
292
+ require .ErrorIs (t , err , sql .ErrNoRows )
293
+
242
294
cipher = & atomic.Pointer [dbcrypt.Cipher ]{}
243
- return rawDB , dbcrypt .New (rawDB , & dbcrypt.Options {
295
+ initCipher (t , cipher )
296
+ cryptDB , err := dbcrypt .New (ctx , rawDB , & dbcrypt.Options {
244
297
ExternalTokenCipher : cipher ,
245
298
Logger : slogtest .Make (t , nil ).Leveled (slog .LevelDebug ),
246
- }), cipher
299
+ })
300
+ require .NoError (t , err )
301
+
302
+ rawVal , err := rawDB .GetDBCryptSentinelValue (ctx )
303
+ require .NoError (t , err )
304
+ require .Contains (t , rawVal , dbcrypt .MagicPrefix )
305
+
306
+ cryptVal , err := cryptDB .GetDBCryptSentinelValue (ctx )
307
+ require .NoError (t , err )
308
+ require .Equal (t , "coder" , cryptVal )
309
+
310
+ return rawDB , cryptDB , cipher
247
311
}
0 commit comments