Skip to content

Commit 5a0161c

Browse files
committed
refactor: add Ciphers to abstract over multiple ciphers
1 parent 67ee610 commit 5a0161c

File tree

2 files changed

+95
-5
lines changed

2 files changed

+95
-5
lines changed

enterprise/dbcrypt/cipher.go

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,14 @@ func CipherAES256(key []byte) (Cipher, error) {
3030
if err != nil {
3131
return nil, err
3232
}
33-
digest := sha256.Sum256(key)
34-
return &aes256{aead: aead, digest: digest[:]}, nil
33+
digest := fmt.Sprintf("%x", sha256.Sum256(key))[:7]
34+
return &aes256{aead: aead, digest: digest}, nil
3535
}
3636

3737
type aes256 struct {
38-
aead cipher.AEAD
39-
digest []byte
38+
aead cipher.AEAD
39+
// digest is the first 7 bytes of the hex-encoded SHA-256 digest of aead.
40+
digest string
4041
}
4142

4243
func (a *aes256) Encrypt(plaintext []byte) ([]byte, error) {
@@ -60,5 +61,54 @@ func (a *aes256) Decrypt(ciphertext []byte) ([]byte, error) {
6061
}
6162

6263
func (a *aes256) HexDigest() string {
63-
return fmt.Sprintf("%x", a.digest)
64+
return a.digest
65+
}
66+
67+
type CipherDigest string
68+
type Ciphers struct {
69+
primary string
70+
m map[string]Cipher
71+
}
72+
73+
// CiphersAES256 returns a new Ciphers instance with the given ciphers.
74+
// The first cipher in the list is the primary cipher. Any ciphers after the
75+
// first are considered secondary ciphers and are only used for decryption.
76+
func CiphersAES256(cs ...Cipher) Ciphers {
77+
var primary string
78+
m := make(map[string]Cipher)
79+
for idx, c := range cs {
80+
m[c.HexDigest()] = c
81+
if idx == 0 {
82+
primary = c.HexDigest()
83+
}
84+
}
85+
return Ciphers{primary: primary, m: m}
86+
}
87+
88+
// Encrypt encrypts the given plaintext using the primary cipher and returns the
89+
// ciphertext. The ciphertext is prefixed with the primary cipher's digest.
90+
func (cs Ciphers) Encrypt(plaintext []byte) ([]byte, error) {
91+
c, ok := cs.m[cs.primary]
92+
if !ok {
93+
return nil, xerrors.Errorf("no ciphers configured")
94+
}
95+
prefix := []byte(c.HexDigest() + "-")
96+
crypted, err := c.Encrypt(plaintext)
97+
if err != nil {
98+
return nil, err
99+
}
100+
return append(prefix, crypted...), nil
101+
}
102+
103+
// Decrypt decrypts the given ciphertext using the cipher indicated by the
104+
// ciphertext's prefix. The prefix is the first 7 bytes of the hex-encoded
105+
// SHA-256 digest of the cipher's key. Decryption will fail if the prefix
106+
// does not match any of the configured ciphers.
107+
func (cs Ciphers) Decrypt(ciphertext []byte) ([]byte, error) {
108+
requiredPrefix := string(ciphertext[:7])
109+
c, ok := cs.m[requiredPrefix]
110+
if !ok {
111+
return nil, xerrors.Errorf("missing required decryption cipher %s", requiredPrefix)
112+
}
113+
return c.Decrypt(ciphertext[8:])
64114
}

enterprise/dbcrypt/cipher_test.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,43 @@ func TestCipherAES256(t *testing.T) {
4343
require.ErrorContains(t, err, "key must be 32 bytes")
4444
})
4545
}
46+
47+
func TestCiphersAES256(t *testing.T) {
48+
t.Parallel()
49+
50+
// Given: two ciphers
51+
key1 := bytes.Repeat([]byte{'a'}, 32)
52+
key2 := bytes.Repeat([]byte{'b'}, 32)
53+
cipher1, err := dbcrypt.CipherAES256(key1)
54+
require.NoError(t, err)
55+
cipher2, err := dbcrypt.CipherAES256(key2)
56+
require.NoError(t, err)
57+
58+
ciphers := dbcrypt.CiphersAES256(
59+
cipher1,
60+
cipher2,
61+
)
62+
63+
// Then: it should encrypt with the cipher1
64+
output, err := ciphers.Encrypt([]byte("hello world"))
65+
require.NoError(t, err)
66+
// The first 7 bytes of the output should be the hex digest of cipher1
67+
require.Equal(t, cipher1.HexDigest(), string(output[:7]))
68+
69+
// And: it should decrypt successfully
70+
decrypted, err := ciphers.Decrypt(output)
71+
require.NoError(t, err)
72+
require.Equal(t, "hello world", string(decrypted))
73+
74+
// Decryption of the above should fail with cipher2
75+
_, err = cipher2.Decrypt(output)
76+
var decryptErr *dbcrypt.DecryptFailedError
77+
require.ErrorAs(t, err, &decryptErr)
78+
79+
// Decryption of data encrypted with cipher2 should succeed
80+
output2, err := cipher2.Encrypt([]byte("hello world"))
81+
require.NoError(t, err)
82+
decrypted2, err := ciphers.Decrypt(bytes.Join([][]byte{[]byte(cipher2.HexDigest()), output2}, []byte{'-'}))
83+
require.NoError(t, err)
84+
require.Equal(t, "hello world", string(decrypted2))
85+
}

0 commit comments

Comments
 (0)