Skip to content

Commit 7a64a4e

Browse files
committed
fixup! refactor dbcrypt: add Ciphers to wrap multiple AES256
1 parent 4142fb2 commit 7a64a4e

File tree

3 files changed

+24
-43
lines changed

3 files changed

+24
-43
lines changed

enterprise/cli/dbcrypt_rotate.go

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@ import (
77
"cdr.dev/slog"
88
"context"
99
"encoding/base64"
10-
"sync/atomic"
11-
1210
"github.com/coder/coder/v2/cli"
1311
"github.com/coder/coder/v2/cli/clibase"
1412
"github.com/coder/coder/v2/coderd/database"
@@ -59,8 +57,6 @@ func (r *RootCmd) dbcryptRotate() *clibase.Cmd {
5957
return xerrors.Errorf("old and new keys must be different")
6058
}
6159

62-
primaryCipherPtr := &atomic.Pointer[dbcrypt.Cipher]{}
63-
secondaryCipherPtr := &atomic.Pointer[dbcrypt.Cipher]{}
6460
primaryCipher, err := dbcrypt.CipherAES256(newKey)
6561
if err != nil {
6662
return xerrors.Errorf("create primary cipher: %w", err)
@@ -69,8 +65,7 @@ func (r *RootCmd) dbcryptRotate() *clibase.Cmd {
6965
if err != nil {
7066
return xerrors.Errorf("create secondary cipher: %w", err)
7167
}
72-
primaryCipherPtr.Store(&primaryCipher)
73-
secondaryCipherPtr.Store(&secondaryCipher)
68+
ciphers := dbcrypt.NewCiphers(primaryCipher, secondaryCipher)
7469

7570
sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, "postgres", vals.PostgresURL.Value())
7671
if err != nil {
@@ -83,11 +78,7 @@ func (r *RootCmd) dbcryptRotate() *clibase.Cmd {
8378

8479
db := database.New(sqlDB)
8580

86-
cryptDB, err := dbcrypt.New(ctx, db, &dbcrypt.Options{
87-
PrimaryCipher: primaryCipherPtr,
88-
SecondaryCipher: secondaryCipherPtr,
89-
Logger: logger.Named("cryptdb"),
90-
})
81+
cryptDB, err := dbcrypt.New(ctx, db, ciphers)
9182
if err != nil {
9283
return xerrors.Errorf("create cryptdb: %w", err)
9384
}

enterprise/cli/dbcrypt_rotate_test.go

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,8 @@ import (
66
"encoding/base64"
77
"fmt"
88

9-
"sync/atomic"
109
"testing"
1110

12-
"cdr.dev/slog/sloggers/slogtest"
13-
1411
"github.com/coder/coder/v2/coderd/database"
1512
"github.com/coder/coder/v2/coderd/database/dbgen"
1613
"github.com/coder/coder/v2/coderd/database/dbtestutil"
@@ -47,17 +44,10 @@ func TestDBCryptRotate(t *testing.T) {
4744
keyA := mustString(t, 32)
4845
cA, err := dbcrypt.CipherAES256([]byte(keyA))
4946
require.NoError(t, err)
50-
cipherA := &atomic.Pointer[dbcrypt.Cipher]{}
51-
cipherB := &atomic.Pointer[dbcrypt.Cipher]{}
52-
cipherA.Store(&cA)
47+
ciphers := dbcrypt.NewCiphers(cA)
5348

5449
// Create an encrypted database
55-
log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
56-
cryptdb, err := dbcrypt.New(ctx, db, &dbcrypt.Options{
57-
PrimaryCipher: cipherA,
58-
SecondaryCipher: cipherB,
59-
Logger: log,
60-
})
50+
cryptdb, err := dbcrypt.New(ctx, db, ciphers)
6151
require.NoError(t, err)
6252

6353
// Populate the database with some data encrypted with cipher A.
@@ -102,31 +92,37 @@ func TestDBCryptRotate(t *testing.T) {
10292
require.NoError(t, err)
10393

10494
// Validate that all data has been updated with the checksum of the new cipher.
105-
expectedPrefixA := fmt.Sprintf("dbcrypt-%s-", cA.HexDigest()[:7])
106-
expectedPrefixB := fmt.Sprintf("dbcrypt-%s-", cB.HexDigest()[:7])
10795
for _, usr := range users {
10896
ul, err := db.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{
10997
UserID: usr.ID,
11098
LoginType: usr.LoginType,
11199
})
112100
require.NoError(t, err, "failed to get user link for user %s", usr.ID)
113-
require.NotContains(t, ul.OAuthAccessToken, expectedPrefixA, "user_link.oauth_access_token should not contain the old cipher checksum")
114-
require.NotContains(t, ul.OAuthRefreshToken, expectedPrefixA, "user_link.oauth_refresh_token should not contain the old cipher checksum")
115-
require.Contains(t, ul.OAuthAccessToken, expectedPrefixB, "user_link.oauth_access_token should contain the new cipher checksum")
116-
require.Contains(t, ul.OAuthRefreshToken, expectedPrefixB, "user_link.oauth_refresh_token should contain the new cipher checksum")
101+
requireEncrypted(t, cB, ul.OAuthAccessToken)
102+
requireEncrypted(t, cB, ul.OAuthRefreshToken)
117103

118104
gal, err := db.GetGitAuthLink(ctx, database.GetGitAuthLinkParams{
119105
UserID: usr.ID,
120106
ProviderID: "fake",
121107
})
122108
require.NoError(t, err, "failed to get git auth link for user %s", usr.ID)
123-
require.NotContains(t, gal.OAuthAccessToken, expectedPrefixA, "git_auth_link.oauth_access_token should not contain the old cipher checksum")
124-
require.NotContains(t, gal.OAuthRefreshToken, expectedPrefixA, "git_auth_link.oauth_refresh_token should not contain the old cipher checksum")
125-
require.Contains(t, gal.OAuthAccessToken, expectedPrefixB, "git_auth_link.oauth_access_token should contain the new cipher checksum")
126-
require.Contains(t, gal.OAuthRefreshToken, expectedPrefixB, "git_auth_link.oauth_refresh_token should contain the new cipher checksum")
109+
requireEncrypted(t, cB, gal.OAuthAccessToken)
110+
requireEncrypted(t, cB, gal.OAuthRefreshToken)
127111
}
128112
}
129113

114+
func requireEncrypted(t *testing.T, c dbcrypt.Cipher, s string) {
115+
t.Helper()
116+
require.Greater(t, len(s), 8, "encrypted string is too short")
117+
require.Equal(t, dbcrypt.MagicPrefix, s[:8], "missing magic prefix")
118+
decodedVal, err := base64.StdEncoding.DecodeString(s[8:])
119+
require.NoError(t, err, "failed to decode base64 string")
120+
require.Greater(t, len(decodedVal), 8, "base64-decoded value is too short")
121+
require.Equal(t, c.HexDigest(), string(decodedVal[:7]), "cipher digest does not match")
122+
_, err = c.Decrypt(decodedVal[8:])
123+
require.NoError(t, err, "failed to decrypt value")
124+
}
125+
130126
func mustString(t *testing.T, n int) string {
131127
t.Helper()
132128
s, err := cryptorand.String(n)

enterprise/coderd/coderd.go

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ import (
1212
"strconv"
1313
"strings"
1414
"sync"
15-
"sync/atomic"
1615
"time"
1716

1817
"golang.org/x/xerrors"
@@ -65,17 +64,12 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
6564
ctx, cancelFunc := context.WithCancel(ctx)
6665

6766
if options.PrimaryExternalTokenEncryption != nil {
68-
primaryExternalTokenCipher := atomic.Pointer[dbcrypt.Cipher]{}
69-
primaryExternalTokenCipher.Store(&options.PrimaryExternalTokenEncryption)
70-
secondaryExternalTokenCipher := atomic.Pointer[dbcrypt.Cipher]{}
67+
cs := make([]dbcrypt.Cipher, 0)
68+
cs = append(cs, options.PrimaryExternalTokenEncryption)
7169
if options.SecondaryExternalTokenEncryption != nil {
72-
secondaryExternalTokenCipher.Store(&options.SecondaryExternalTokenEncryption)
70+
cs = append(cs, options.SecondaryExternalTokenEncryption)
7371
}
74-
cryptDB, err := dbcrypt.New(ctx, options.Database, &dbcrypt.Options{
75-
PrimaryCipher: &primaryExternalTokenCipher,
76-
SecondaryCipher: &secondaryExternalTokenCipher,
77-
})
78-
72+
cryptDB, err := dbcrypt.New(ctx, options.Database, dbcrypt.NewCiphers(cs...))
7973
if err != nil {
8074
cancelFunc()
8175
return nil, xerrors.Errorf("init dbcrypt: %w", err)

0 commit comments

Comments
 (0)