Skip to content
Prev Previous commit
Next Next commit
add some more dbcrypt tests
  • Loading branch information
sreya committed Sep 13, 2024
commit 04de868f99c41593c3723c76fa3f62012bbc4c50
2 changes: 1 addition & 1 deletion coderd/database/queries.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion coderd/database/queries/crypto_keys.sql
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ WHERE feature = $1

-- name: DeleteCryptoKey :one
UPDATE crypto_keys
SET secret = NULL
SET secret = NULL, secret_key_id = NULL
WHERE feature = $1 AND sequence = $2 RETURNING *;

-- name: InsertCryptoKey :one
Expand Down
35 changes: 35 additions & 0 deletions enterprise/dbcrypt/dbcrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,30 @@ func (db *dbCrypt) UpdateExternalAuthLink(ctx context.Context, params database.U
return link, nil
}

func (db *dbCrypt) GetCryptoKeys(ctx context.Context) ([]database.CryptoKey, error) {
keys, err := db.Store.GetCryptoKeys(ctx)
if err != nil {
return nil, err
}
for i := range keys {
if err := db.decryptField(&keys[i].Secret.String, keys[i].SecretKeyID); err != nil {
return nil, err
}
}
return keys, nil
}

func (db *dbCrypt) GetLatestCryptoKeyByFeature(ctx context.Context, feature database.CryptoKeyFeature) (database.CryptoKey, error) {
key, err := db.Store.GetLatestCryptoKeyByFeature(ctx, feature)
if err != nil {
return database.CryptoKey{}, err
}
if err := db.decryptField(&key.Secret.String, key.SecretKeyID); err != nil {
return database.CryptoKey{}, err
}
return key, nil
}

func (db *dbCrypt) GetCryptoKeyByFeatureAndSequence(ctx context.Context, params database.GetCryptoKeyByFeatureAndSequenceParams) (database.CryptoKey, error) {
key, err := db.Store.GetCryptoKeyByFeatureAndSequence(ctx, params)
if err != nil {
Expand All @@ -286,6 +310,17 @@ func (db *dbCrypt) InsertCryptoKey(ctx context.Context, params database.InsertCr
return key, nil
}

func (db *dbCrypt) UpdateCryptoKeyDeletesAt(ctx context.Context, arg database.UpdateCryptoKeyDeletesAtParams) (database.CryptoKey, error) {
key, err := db.Store.UpdateCryptoKeyDeletesAt(ctx, arg)
if err != nil {
return database.CryptoKey{}, err
}
if err := db.decryptField(&key.Secret.String, key.SecretKeyID); err != nil {
return database.CryptoKey{}, err
}
return key, nil
}

func (db *dbCrypt) encryptField(field *string, digest *sql.NullString) error {
// If no cipher is loaded, then we can't encrypt anything!
if db.ciphers == nil || db.primaryCipherDigest == "" {
Expand Down
63 changes: 62 additions & 1 deletion enterprise/dbcrypt/dbcrypt_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -352,12 +352,13 @@ func TestExternalAuthLinks(t *testing.T) {
func TestCryptoKeys(t *testing.T) {
t.Parallel()
ctx := context.Background()
db, crypt, ciphers := setup(t)

// We don't write a GetCryptoKeyByFeatureAndSequence test
// because it's basically the same as InsertCryptoKey.
t.Run("InsertCryptoKey", func(t *testing.T) {
t.Parallel()

db, crypt, ciphers := setup(t)
key := dbgen.CryptoKey(t, crypt, database.CryptoKey{
Secret: sql.NullString{String: "test", Valid: true},
})
Expand All @@ -371,6 +372,66 @@ func TestCryptoKeys(t *testing.T) {
require.Equal(t, ciphers[0].HexDigest(), key.SecretKeyID.String)
requireEncryptedEquals(t, ciphers[0], key.Secret.String, "test")
})

t.Run("GetCryptoKeys", func(t *testing.T) {
t.Parallel()
db, crypt, ciphers := setup(t)
_ = dbgen.CryptoKey(t, crypt, database.CryptoKey{
Secret: sql.NullString{String: "test", Valid: true},
})
keys, err := crypt.GetCryptoKeys(ctx)
require.NoError(t, err)
require.Len(t, keys, 1)
require.Equal(t, "test", keys[0].Secret.String)
require.Equal(t, ciphers[0].HexDigest(), keys[0].SecretKeyID.String)

keys, err = db.GetCryptoKeys(ctx)
require.NoError(t, err)
require.Len(t, keys, 1)
requireEncryptedEquals(t, ciphers[0], keys[0].Secret.String, "test")
require.Equal(t, ciphers[0].HexDigest(), keys[0].SecretKeyID.String)
})

t.Run("GetLatestCryptoKeyByFeature", func(t *testing.T) {
t.Parallel()
db, crypt, ciphers := setup(t)
_ = dbgen.CryptoKey(t, crypt, database.CryptoKey{
Secret: sql.NullString{String: "test", Valid: true},
})
key, err := crypt.GetLatestCryptoKeyByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps)
require.NoError(t, err)
require.Equal(t, "test", key.Secret.String)
require.Equal(t, ciphers[0].HexDigest(), key.SecretKeyID.String)

key, err = db.GetLatestCryptoKeyByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps)
require.NoError(t, err)
requireEncryptedEquals(t, ciphers[0], key.Secret.String, "test")
require.Equal(t, ciphers[0].HexDigest(), key.SecretKeyID.String)
})

t.Run("GetCryptoKeyByFeatureAndSequence", func(t *testing.T) {
t.Parallel()
db, crypt, ciphers := setup(t)
key := dbgen.CryptoKey(t, crypt, database.CryptoKey{
Secret: sql.NullString{String: "test", Valid: true},
})
key, err := crypt.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Sequence: key.Sequence,
})
require.NoError(t, err)
require.Equal(t, "test", key.Secret.String)
require.Equal(t, ciphers[0].HexDigest(), key.SecretKeyID.String)

key, err = db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Sequence: key.Sequence,
})
require.NoError(t, err)
requireEncryptedEquals(t, ciphers[0], key.Secret.String, "test")
require.Equal(t, ciphers[0].HexDigest(), key.SecretKeyID.String)
})

t.Run("DecryptErr", func(t *testing.T) {
t.Parallel()
db, crypt, ciphers := setup(t)
Expand Down
Loading