Skip to content

Commit 7837f71

Browse files
committed
move cipher to dbcrypt package
1 parent 2b99db9 commit 7837f71

File tree

5 files changed

+83
-46
lines changed

5 files changed

+83
-46
lines changed

cryptorand/cipher.go renamed to coderd/database/dbcrypt/cipher.go

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
package cryptorand
1+
package dbcrypt
22

33
import (
44
"crypto/aes"
55
"crypto/cipher"
66
"crypto/rand"
7+
"errors"
78
"io"
89

910
"golang.org/x/xerrors"
@@ -14,6 +15,23 @@ type Cipher interface {
1415
Decrypt([]byte) ([]byte, error)
1516
}
1617

18+
type DecryptFailedError struct {
19+
Inner error
20+
}
21+
22+
func (e *DecryptFailedError) Error() string {
23+
return xerrors.Errorf("decrypt failed: %w", e.Inner).Error()
24+
}
25+
26+
func (e *DecryptFailedError) Unwrap() error {
27+
return e.Inner
28+
}
29+
30+
func IsDecryptFailedError(err error) bool {
31+
var e *DecryptFailedError
32+
return errors.As(err, &e)
33+
}
34+
1735
// CipherAES256 returns a new AES-256 cipher.
1836
func CipherAES256(key []byte) (Cipher, error) {
1937
block, err := aes.NewCipher(key)
@@ -44,5 +62,9 @@ func (a *aes256) Decrypt(ciphertext []byte) ([]byte, error) {
4462
if len(ciphertext) < a.aead.NonceSize() {
4563
return nil, xerrors.Errorf("ciphertext too short")
4664
}
47-
return a.aead.Open(nil, ciphertext[:a.aead.NonceSize()], ciphertext[a.aead.NonceSize():], nil)
65+
decrypted, err := a.aead.Open(nil, ciphertext[:a.aead.NonceSize()], ciphertext[a.aead.NonceSize():], nil)
66+
if err != nil {
67+
return nil, &DecryptFailedError{Inner: err}
68+
}
69+
return decrypted, nil
4870
}
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package dbcrypt_test
2+
3+
import (
4+
"bytes"
5+
"testing"
6+
7+
"github.com/stretchr/testify/require"
8+
9+
"github.com/coder/coder/v2/coderd/database/dbcrypt"
10+
)
11+
12+
func TestCipherAES256(t *testing.T) {
13+
t.Parallel()
14+
15+
t.Run("ValidInput", func(t *testing.T) {
16+
t.Parallel()
17+
key := bytes.Repeat([]byte{'a'}, 32)
18+
cipher, err := dbcrypt.CipherAES256(key)
19+
require.NoError(t, err)
20+
21+
output, err := cipher.Encrypt([]byte("hello world"))
22+
require.NoError(t, err)
23+
24+
response, err := cipher.Decrypt(output)
25+
require.NoError(t, err)
26+
require.Equal(t, "hello world", string(response))
27+
})
28+
29+
t.Run("InvalidInput", func(t *testing.T) {
30+
t.Parallel()
31+
key := bytes.Repeat([]byte{'a'}, 32)
32+
cipher, err := dbcrypt.CipherAES256(key)
33+
require.NoError(t, err)
34+
_, err = cipher.Decrypt(bytes.Repeat([]byte{'a'}, 100))
35+
var decryptErr *dbcrypt.DecryptFailedError
36+
require.ErrorAs(t, err, &decryptErr)
37+
})
38+
39+
t.Run("InvalidKeySize", func(t *testing.T) {
40+
t.Parallel()
41+
42+
_, err := dbcrypt.CipherAES256(bytes.Repeat([]byte{'a'}, 31))
43+
require.ErrorContains(t, err, "invalid key size")
44+
})
45+
}

coderd/database/dbcrypt/dbcrypt.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@ import (
88
"strings"
99
"sync/atomic"
1010

11-
"cdr.dev/slog"
1211
"golang.org/x/xerrors"
1312

13+
"cdr.dev/slog"
14+
1415
"github.com/coder/coder/v2/coderd/database"
15-
"github.com/coder/coder/v2/cryptorand"
1616
)
1717

1818
// MagicPrefix is prepended to all encrypted values in the database.
@@ -24,7 +24,7 @@ type Options struct {
2424
// ExternalTokenCipher is an optional cipher that is used
2525
// to encrypt/decrypt user link and git auth link tokens. If this is nil,
2626
// then no encryption/decryption will be performed.
27-
ExternalTokenCipher *atomic.Pointer[cryptorand.Cipher]
27+
ExternalTokenCipher *atomic.Pointer[Cipher]
2828
Logger slog.Logger
2929
}
3030

@@ -141,7 +141,7 @@ func (db *dbCrypt) encryptFields(fields ...*string) error {
141141
// decryptFields decrypts the given fields in place.
142142
// If the value fails to decrypt, sql.ErrNoRows will be returned.
143143
func (db *dbCrypt) decryptFields(deleteFn func() error, fields ...*string) error {
144-
delete := func(reason string) error {
144+
doDelete := func(reason string) error {
145145
err := deleteFn()
146146
if err != nil {
147147
return xerrors.Errorf("delete encrypted row: %w", err)
@@ -164,7 +164,7 @@ func (db *dbCrypt) decryptFields(deleteFn func() error, fields ...*string) error
164164
if strings.HasPrefix(*field, MagicPrefix) {
165165
// If we have a magic prefix but encryption is disabled,
166166
// we should delete the row.
167-
return delete("encryption disabled")
167+
return doDelete("encryption disabled")
168168
}
169169
}
170170
return nil
@@ -183,12 +183,12 @@ func (db *dbCrypt) decryptFields(deleteFn func() error, fields ...*string) error
183183
data, err := base64.StdEncoding.DecodeString((*field)[len(MagicPrefix):])
184184
if err != nil {
185185
// If it's not base64 with the prefix, we should delete the row.
186-
return delete("stored value was not base64 encoded")
186+
return doDelete("stored value was not base64 encoded")
187187
}
188188
decrypted, err := cipher.Decrypt(data)
189189
if err != nil {
190190
// If the encryption key changed, we should delete the row.
191-
return delete("encryption key changed")
191+
return doDelete("encryption key changed")
192192
}
193193
*field = string(decrypted)
194194
}

coderd/database/dbcrypt/dbcrypt_test.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,15 @@ import (
99
"sync/atomic"
1010
"testing"
1111

12+
"github.com/stretchr/testify/require"
13+
1214
"cdr.dev/slog"
1315
"cdr.dev/slog/sloggers/slogtest"
14-
"github.com/stretchr/testify/require"
1516

1617
"github.com/coder/coder/v2/coderd/database"
1718
"github.com/coder/coder/v2/coderd/database/dbcrypt"
1819
"github.com/coder/coder/v2/coderd/database/dbfake"
1920
"github.com/coder/coder/v2/coderd/database/dbgen"
20-
"github.com/coder/coder/v2/cryptorand"
2121
)
2222

2323
func TestUserLinks(t *testing.T) {
@@ -172,7 +172,7 @@ func TestGitAuthLinks(t *testing.T) {
172172
})
173173
}
174174

175-
func requireEncryptedEquals(t *testing.T, cipher *atomic.Pointer[cryptorand.Cipher], value, expected string) {
175+
func requireEncryptedEquals(t *testing.T, cipher *atomic.Pointer[dbcrypt.Cipher], value, expected string) {
176176
t.Helper()
177177
c := (*cipher.Load())
178178
data, err := base64.StdEncoding.DecodeString(value[len(dbcrypt.MagicPrefix):])
@@ -182,20 +182,20 @@ func requireEncryptedEquals(t *testing.T, cipher *atomic.Pointer[cryptorand.Ciph
182182
require.Equal(t, expected, string(got))
183183
}
184184

185-
func initCipher(t *testing.T, cipher *atomic.Pointer[cryptorand.Cipher]) {
185+
func initCipher(t *testing.T, cipher *atomic.Pointer[dbcrypt.Cipher]) {
186186
t.Helper()
187187
key := make([]byte, 32) // AES-256 key size is 32 bytes
188188
_, err := io.ReadFull(rand.Reader, key)
189189
require.NoError(t, err)
190-
c, err := cryptorand.CipherAES256(key)
190+
c, err := dbcrypt.CipherAES256(key)
191191
require.NoError(t, err)
192192
cipher.Store(&c)
193193
}
194194

195-
func setup(t *testing.T) (db, cryptodb database.Store, cipher *atomic.Pointer[cryptorand.Cipher]) {
195+
func setup(t *testing.T) (db, cryptodb database.Store, cipher *atomic.Pointer[dbcrypt.Cipher]) {
196196
t.Helper()
197197
rawDB := dbfake.New()
198-
cipher = &atomic.Pointer[cryptorand.Cipher]{}
198+
cipher = &atomic.Pointer[dbcrypt.Cipher]{}
199199
return rawDB, dbcrypt.New(rawDB, &dbcrypt.Options{
200200
ExternalTokenCipher: cipher,
201201
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),

cryptorand/cipher_test.go

Lines changed: 0 additions & 30 deletions
This file was deleted.

0 commit comments

Comments
 (0)