Skip to content

Commit 04de868

Browse files
committed
add some more dbcrypt tests
1 parent 03262de commit 04de868

File tree

4 files changed

+99
-3
lines changed

4 files changed

+99
-3
lines changed

coderd/database/queries.sql.go

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries/crypto_keys.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ WHERE feature = $1
2020

2121
-- name: DeleteCryptoKey :one
2222
UPDATE crypto_keys
23-
SET secret = NULL
23+
SET secret = NULL, secret_key_id = NULL
2424
WHERE feature = $1 AND sequence = $2 RETURNING *;
2525

2626
-- name: InsertCryptoKey :one

enterprise/dbcrypt/dbcrypt.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,30 @@ func (db *dbCrypt) UpdateExternalAuthLink(ctx context.Context, params database.U
261261
return link, nil
262262
}
263263

264+
func (db *dbCrypt) GetCryptoKeys(ctx context.Context) ([]database.CryptoKey, error) {
265+
keys, err := db.Store.GetCryptoKeys(ctx)
266+
if err != nil {
267+
return nil, err
268+
}
269+
for i := range keys {
270+
if err := db.decryptField(&keys[i].Secret.String, keys[i].SecretKeyID); err != nil {
271+
return nil, err
272+
}
273+
}
274+
return keys, nil
275+
}
276+
277+
func (db *dbCrypt) GetLatestCryptoKeyByFeature(ctx context.Context, feature database.CryptoKeyFeature) (database.CryptoKey, error) {
278+
key, err := db.Store.GetLatestCryptoKeyByFeature(ctx, feature)
279+
if err != nil {
280+
return database.CryptoKey{}, err
281+
}
282+
if err := db.decryptField(&key.Secret.String, key.SecretKeyID); err != nil {
283+
return database.CryptoKey{}, err
284+
}
285+
return key, nil
286+
}
287+
264288
func (db *dbCrypt) GetCryptoKeyByFeatureAndSequence(ctx context.Context, params database.GetCryptoKeyByFeatureAndSequenceParams) (database.CryptoKey, error) {
265289
key, err := db.Store.GetCryptoKeyByFeatureAndSequence(ctx, params)
266290
if err != nil {
@@ -286,6 +310,17 @@ func (db *dbCrypt) InsertCryptoKey(ctx context.Context, params database.InsertCr
286310
return key, nil
287311
}
288312

313+
func (db *dbCrypt) UpdateCryptoKeyDeletesAt(ctx context.Context, arg database.UpdateCryptoKeyDeletesAtParams) (database.CryptoKey, error) {
314+
key, err := db.Store.UpdateCryptoKeyDeletesAt(ctx, arg)
315+
if err != nil {
316+
return database.CryptoKey{}, err
317+
}
318+
if err := db.decryptField(&key.Secret.String, key.SecretKeyID); err != nil {
319+
return database.CryptoKey{}, err
320+
}
321+
return key, nil
322+
}
323+
289324
func (db *dbCrypt) encryptField(field *string, digest *sql.NullString) error {
290325
// If no cipher is loaded, then we can't encrypt anything!
291326
if db.ciphers == nil || db.primaryCipherDigest == "" {

enterprise/dbcrypt/dbcrypt_internal_test.go

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -352,12 +352,13 @@ func TestExternalAuthLinks(t *testing.T) {
352352
func TestCryptoKeys(t *testing.T) {
353353
t.Parallel()
354354
ctx := context.Background()
355-
db, crypt, ciphers := setup(t)
356355

357356
// We don't write a GetCryptoKeyByFeatureAndSequence test
358357
// because it's basically the same as InsertCryptoKey.
359358
t.Run("InsertCryptoKey", func(t *testing.T) {
360359
t.Parallel()
360+
361+
db, crypt, ciphers := setup(t)
361362
key := dbgen.CryptoKey(t, crypt, database.CryptoKey{
362363
Secret: sql.NullString{String: "test", Valid: true},
363364
})
@@ -371,6 +372,66 @@ func TestCryptoKeys(t *testing.T) {
371372
require.Equal(t, ciphers[0].HexDigest(), key.SecretKeyID.String)
372373
requireEncryptedEquals(t, ciphers[0], key.Secret.String, "test")
373374
})
375+
376+
t.Run("GetCryptoKeys", func(t *testing.T) {
377+
t.Parallel()
378+
db, crypt, ciphers := setup(t)
379+
_ = dbgen.CryptoKey(t, crypt, database.CryptoKey{
380+
Secret: sql.NullString{String: "test", Valid: true},
381+
})
382+
keys, err := crypt.GetCryptoKeys(ctx)
383+
require.NoError(t, err)
384+
require.Len(t, keys, 1)
385+
require.Equal(t, "test", keys[0].Secret.String)
386+
require.Equal(t, ciphers[0].HexDigest(), keys[0].SecretKeyID.String)
387+
388+
keys, err = db.GetCryptoKeys(ctx)
389+
require.NoError(t, err)
390+
require.Len(t, keys, 1)
391+
requireEncryptedEquals(t, ciphers[0], keys[0].Secret.String, "test")
392+
require.Equal(t, ciphers[0].HexDigest(), keys[0].SecretKeyID.String)
393+
})
394+
395+
t.Run("GetLatestCryptoKeyByFeature", func(t *testing.T) {
396+
t.Parallel()
397+
db, crypt, ciphers := setup(t)
398+
_ = dbgen.CryptoKey(t, crypt, database.CryptoKey{
399+
Secret: sql.NullString{String: "test", Valid: true},
400+
})
401+
key, err := crypt.GetLatestCryptoKeyByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps)
402+
require.NoError(t, err)
403+
require.Equal(t, "test", key.Secret.String)
404+
require.Equal(t, ciphers[0].HexDigest(), key.SecretKeyID.String)
405+
406+
key, err = db.GetLatestCryptoKeyByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps)
407+
require.NoError(t, err)
408+
requireEncryptedEquals(t, ciphers[0], key.Secret.String, "test")
409+
require.Equal(t, ciphers[0].HexDigest(), key.SecretKeyID.String)
410+
})
411+
412+
t.Run("GetCryptoKeyByFeatureAndSequence", func(t *testing.T) {
413+
t.Parallel()
414+
db, crypt, ciphers := setup(t)
415+
key := dbgen.CryptoKey(t, crypt, database.CryptoKey{
416+
Secret: sql.NullString{String: "test", Valid: true},
417+
})
418+
key, err := crypt.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{
419+
Feature: database.CryptoKeyFeatureWorkspaceApps,
420+
Sequence: key.Sequence,
421+
})
422+
require.NoError(t, err)
423+
require.Equal(t, "test", key.Secret.String)
424+
require.Equal(t, ciphers[0].HexDigest(), key.SecretKeyID.String)
425+
426+
key, err = db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{
427+
Feature: database.CryptoKeyFeatureWorkspaceApps,
428+
Sequence: key.Sequence,
429+
})
430+
require.NoError(t, err)
431+
requireEncryptedEquals(t, ciphers[0], key.Secret.String, "test")
432+
require.Equal(t, ciphers[0].HexDigest(), key.SecretKeyID.String)
433+
})
434+
374435
t.Run("DecryptErr", func(t *testing.T) {
375436
t.Parallel()
376437
db, crypt, ciphers := setup(t)

0 commit comments

Comments
 (0)