Skip to content

Commit 4142fb2

Browse files
committed
refactor dbcrypt: add Ciphers to wrap multiple AES256
1 parent 5a0161c commit 4142fb2

File tree

4 files changed

+87
-177
lines changed

4 files changed

+87
-177
lines changed

enterprise/dbcrypt/cipher.go

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ type Cipher interface {
1818
}
1919

2020
// CipherAES256 returns a new AES-256 cipher.
21-
func CipherAES256(key []byte) (Cipher, error) {
21+
func CipherAES256(key []byte) (*AES256, error) {
2222
if len(key) != 32 {
2323
return nil, xerrors.Errorf("key must be 32 bytes")
2424
}
@@ -31,16 +31,16 @@ func CipherAES256(key []byte) (Cipher, error) {
3131
return nil, err
3232
}
3333
digest := fmt.Sprintf("%x", sha256.Sum256(key))[:7]
34-
return &aes256{aead: aead, digest: digest}, nil
34+
return &AES256{aead: aead, digest: digest}, nil
3535
}
3636

37-
type aes256 struct {
37+
type AES256 struct {
3838
aead cipher.AEAD
3939
// digest is the first 7 bytes of the hex-encoded SHA-256 digest of aead.
4040
digest string
4141
}
4242

43-
func (a *aes256) Encrypt(plaintext []byte) ([]byte, error) {
43+
func (a *AES256) Encrypt(plaintext []byte) ([]byte, error) {
4444
nonce := make([]byte, a.aead.NonceSize())
4545
_, err := io.ReadFull(rand.Reader, nonce)
4646
if err != nil {
@@ -49,7 +49,7 @@ func (a *aes256) Encrypt(plaintext []byte) ([]byte, error) {
4949
return a.aead.Seal(nonce, nonce, plaintext, nil), nil
5050
}
5151

52-
func (a *aes256) Decrypt(ciphertext []byte) ([]byte, error) {
52+
func (a *AES256) Decrypt(ciphertext []byte) ([]byte, error) {
5353
if len(ciphertext) < a.aead.NonceSize() {
5454
return nil, xerrors.Errorf("ciphertext too short")
5555
}
@@ -60,7 +60,7 @@ func (a *aes256) Decrypt(ciphertext []byte) ([]byte, error) {
6060
return decrypted, nil
6161
}
6262

63-
func (a *aes256) HexDigest() string {
63+
func (a *AES256) HexDigest() string {
6464
return a.digest
6565
}
6666

@@ -70,19 +70,22 @@ type Ciphers struct {
7070
m map[string]Cipher
7171
}
7272

73-
// CiphersAES256 returns a new Ciphers instance with the given ciphers.
73+
// NewCiphers returns a new Ciphers instance with the given ciphers.
7474
// The first cipher in the list is the primary cipher. Any ciphers after the
7575
// first are considered secondary ciphers and are only used for decryption.
76-
func CiphersAES256(cs ...Cipher) Ciphers {
76+
func NewCiphers(cs ...Cipher) *Ciphers {
7777
var primary string
7878
m := make(map[string]Cipher)
7979
for idx, c := range cs {
80+
if _, ok := c.(*Ciphers); ok {
81+
panic("developer error: do not nest Ciphers")
82+
}
8083
m[c.HexDigest()] = c
8184
if idx == 0 {
8285
primary = c.HexDigest()
8386
}
8487
}
85-
return Ciphers{primary: primary, m: m}
88+
return &Ciphers{primary: primary, m: m}
8689
}
8790

8891
// Encrypt encrypts the given plaintext using the primary cipher and returns the
@@ -112,3 +115,8 @@ func (cs Ciphers) Decrypt(ciphertext []byte) ([]byte, error) {
112115
}
113116
return c.Decrypt(ciphertext[8:])
114117
}
118+
119+
// HexDigest returns the digest of the primary cipher.
120+
func (cs Ciphers) HexDigest() string {
121+
return cs.primary
122+
}

enterprise/dbcrypt/cipher_test.go

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ func TestCipherAES256(t *testing.T) {
4444
})
4545
}
4646

47-
func TestCiphersAES256(t *testing.T) {
47+
func TestCiphers(t *testing.T) {
4848
t.Parallel()
4949

5050
// Given: two ciphers
@@ -55,10 +55,7 @@ func TestCiphersAES256(t *testing.T) {
5555
cipher2, err := dbcrypt.CipherAES256(key2)
5656
require.NoError(t, err)
5757

58-
ciphers := dbcrypt.CiphersAES256(
59-
cipher1,
60-
cipher2,
61-
)
58+
ciphers := dbcrypt.NewCiphers(cipher1, cipher2)
6259

6360
// Then: it should encrypt with the cipher1
6461
output, err := ciphers.Encrypt([]byte("hello world"))
@@ -82,4 +79,16 @@ func TestCiphersAES256(t *testing.T) {
8279
decrypted2, err := ciphers.Decrypt(bytes.Join([][]byte{[]byte(cipher2.HexDigest()), output2}, []byte{'-'}))
8380
require.NoError(t, err)
8481
require.Equal(t, "hello world", string(decrypted2))
82+
83+
// Decryption of data encrypted with cipher1 should succeed
84+
output1, err := cipher1.Encrypt([]byte("hello world"))
85+
require.NoError(t, err)
86+
decrypted1, err := ciphers.Decrypt(bytes.Join([][]byte{[]byte(cipher1.HexDigest()), output1}, []byte{'-'}))
87+
require.NoError(t, err)
88+
require.Equal(t, "hello world", string(decrypted1))
89+
90+
// Wrapping a Ciphers with itself should panic.
91+
require.PanicsWithValue(t, "developer error: do not nest Ciphers", func() {
92+
_ = dbcrypt.NewCiphers(ciphers)
93+
})
8594
}

enterprise/dbcrypt/dbcrypt.go

Lines changed: 21 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,14 @@
1212
// - database.DBCryptSentinelValue
1313
//
1414
// Encrypted fields are stored in the following format:
15-
// "dbcrypt-<first 7 characters of cipher's SHA256 digest>-<base64-encoded encrypted value>"
15+
// "dbcrypt-${b64encode(<first 7 digits of cipher's SHA256 digest>-<encrypted value>)}"
1616
//
1717
// The first 7 characters of the cipher's SHA256 digest are used to identify the cipher
1818
// used to encrypt the value.
1919
//
20-
// Two ciphers can be provided to support key rotation. The primary cipher is used to encrypt
21-
// and decrypt all values. We only use the secondary cipher to decrypt values if decryption
22-
// with the primary cipher fails.
20+
// Multiple ciphers can be provided to support key rotation. The primary cipher is used
21+
// to encrypt and decrypt all data. Secondary ciphers are only used for decryption.
22+
// We currently only use a single secondary cipher.
2323
package dbcrypt
2424

2525
import (
@@ -28,16 +28,12 @@ import (
2828
"encoding/base64"
2929
"errors"
3030
"strings"
31-
"sync/atomic"
32-
33-
"github.com/google/uuid"
34-
"github.com/hashicorp/go-multierror"
35-
"golang.org/x/xerrors"
36-
37-
"cdr.dev/slog"
3831

3932
"github.com/coder/coder/v2/coderd/database"
4033
"github.com/coder/coder/v2/coderd/database/dbauthz"
34+
35+
"github.com/google/uuid"
36+
"golang.org/x/xerrors"
4137
)
4238

4339
// MagicPrefix is prepended to all encrypted values in the database.
@@ -48,10 +44,6 @@ import (
4844
// encrypted value.
4945
const MagicPrefix = "dbcrypt-"
5046

51-
// MagicPrefixLength is the length of the entire prefix used to identify
52-
// encrypted values.
53-
const MagicPrefixLength = len(MagicPrefix) + 8
54-
5547
// sentinelValue is the value that is stored in the database to indicate
5648
// whether encryption is enabled. If not enabled, the value either not
5749
// present, or is the raw string "coder".
@@ -79,31 +71,14 @@ func (*DecryptFailedError) Unwrap() error {
7971
return sql.ErrNoRows
8072
}
8173

82-
func IsDecryptFailedError(err error) bool {
83-
var e *DecryptFailedError
84-
return errors.As(err, &e)
85-
}
86-
87-
type Options struct {
88-
// PrimaryCipher is an optional cipher that is used
89-
// to encrypt/decrypt user link and git auth link tokens. If this is nil,
90-
// then no encryption/decryption will be performed.
91-
PrimaryCipher *atomic.Pointer[Cipher]
92-
// SecondaryCipher is an optional cipher that is only used
93-
// to decrypt user link and git auth link tokens.
94-
// This should only be used when rotating the primary cipher.
95-
SecondaryCipher *atomic.Pointer[Cipher]
96-
Logger slog.Logger
97-
}
98-
9974
// New creates a database.Store wrapper that encrypts/decrypts values
10075
// stored at rest in the database.
101-
func New(ctx context.Context, db database.Store, options *Options) (database.Store, error) {
102-
if options.PrimaryCipher.Load() == nil {
103-
return nil, xerrors.Errorf("at least one cipher is required")
76+
func New(ctx context.Context, db database.Store, cs *Ciphers) (database.Store, error) {
77+
if cs == nil {
78+
return nil, xerrors.Errorf("no ciphers configured")
10479
}
10580
dbc := &dbCrypt{
106-
Options: options,
81+
ciphers: cs,
10782
Store: db,
10883
}
10984
if err := ensureEncrypted(dbauthz.AsSystemRestricted(ctx), dbc); err != nil {
@@ -113,14 +88,14 @@ func New(ctx context.Context, db database.Store, options *Options) (database.Sto
11388
}
11489

11590
type dbCrypt struct {
116-
*Options
91+
ciphers *Ciphers
11792
database.Store
11893
}
11994

12095
func (db *dbCrypt) InTx(function func(database.Store) error, txOpts *sql.TxOptions) error {
12196
return db.Store.InTx(func(s database.Store) error {
12297
return function(&dbCrypt{
123-
Options: db.Options,
98+
ciphers: db.ciphers,
12499
Store: s,
125100
})
126101
}, txOpts)
@@ -225,83 +200,52 @@ func (db *dbCrypt) SetDBCryptSentinelValue(ctx context.Context, value string) er
225200
}
226201

227202
func (db *dbCrypt) encryptFields(fields ...*string) error {
228-
// Encryption ALWAYS happens with the primary cipher.
229-
cipherPtr := db.PrimaryCipher.Load()
230203
// If no cipher is loaded, then we can't encrypt anything!
231-
if cipherPtr == nil {
204+
if db.ciphers == nil {
232205
return ErrNotEnabled
233206
}
234-
cipher := *cipherPtr
207+
235208
for _, field := range fields {
236209
if field == nil {
237210
continue
238211
}
239212

240-
encrypted, err := cipher.Encrypt([]byte(*field))
213+
encrypted, err := db.ciphers.Encrypt([]byte(*field))
241214
if err != nil {
242215
return err
243216
}
244217
// Base64 is used to support UTF-8 encoding in PostgreSQL.
245-
*field = MagicPrefix + cipher.HexDigest()[:7] + "-" + b64encode(encrypted)
218+
*field = MagicPrefix + b64encode(encrypted)
246219
}
247220
return nil
248221
}
249222

250223
// decryptFields decrypts the given fields in place.
251224
// If the value fails to decrypt, sql.ErrNoRows will be returned.
252225
func (db *dbCrypt) decryptFields(fields ...*string) error {
253-
var merr *multierror.Error
254-
255-
// We try to decrypt with both the primary and secondary cipher.
256-
primaryCipherPtr := db.PrimaryCipher.Load()
257-
if err := decryptWithCipher(primaryCipherPtr, fields...); err == nil {
258-
return nil
259-
} else {
260-
merr = multierror.Append(merr, err)
261-
}
262-
secondaryCipherPtr := db.SecondaryCipher.Load()
263-
if err := decryptWithCipher(secondaryCipherPtr, fields...); err == nil {
264-
return nil
265-
} else {
266-
merr = multierror.Append(merr, err)
267-
}
268-
return merr
269-
}
270-
271-
func decryptWithCipher(cipherPtr *Cipher, fields ...*string) error {
272-
// If no cipher is loaded, then we can't decrypt anything!
273-
if cipherPtr == nil {
226+
if db.ciphers == nil {
274227
return ErrNotEnabled
275228
}
276229

277-
cipher := *cipherPtr
278230
for _, field := range fields {
279231
if field == nil {
280232
continue
281233
}
282234

283-
if len(*field) < 16 || !strings.HasPrefix(*field, MagicPrefix) {
235+
if len(*field) < 8 || !strings.HasPrefix(*field, MagicPrefix) {
284236
// We do not force decryption of unencrypted rows. This could be damaging
285237
// to the deployment, and admins can always manually purge data.
286238
continue
287239
}
288240

289-
// The first 7 characters of the digest are used to identify the cipher.
290-
// If the cipher changes, we should complain loudly.
291-
encPrefix := cipher.HexDigest()[:7]
292-
if !strings.HasPrefix((*field)[8:15], encPrefix) {
293-
return &DecryptFailedError{
294-
Inner: xerrors.Errorf("cipher mismatch: expected %q, got %q", encPrefix, (*field)[8:15]),
295-
}
296-
}
297-
data, err := b64decode((*field)[16:])
241+
data, err := b64decode((*field)[8:])
298242
if err != nil {
299243
// If it's not base64 with the prefix, we should complain loudly.
300244
return &DecryptFailedError{
301245
Inner: xerrors.Errorf("malformed encrypted field %q: %w", *field, err),
302246
}
303247
}
304-
decrypted, err := cipher.Decrypt(data)
248+
decrypted, err := db.ciphers.Decrypt(data)
305249
if err != nil {
306250
// If the encryption key changed, return our special error that unwraps to sql.ErrNoRows.
307251
return &DecryptFailedError{Inner: err}

0 commit comments

Comments
 (0)