Skip to content

Commit cbd776f

Browse files
committed
dbcrypt.New now marks database as encrypted
1 parent 6556269 commit cbd776f

File tree

3 files changed

+141
-47
lines changed

3 files changed

+141
-47
lines changed

enterprise/coderd/coderd.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,14 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
6565
ctx, cancelFunc := context.WithCancel(ctx)
6666

6767
externalTokenCipher := &atomic.Pointer[dbcrypt.Cipher]{}
68-
options.Database = dbcrypt.New(options.Database, &dbcrypt.Options{
68+
cryptDB, err := dbcrypt.New(ctx, options.Database, &dbcrypt.Options{
6969
ExternalTokenCipher: externalTokenCipher,
7070
})
71+
if err != nil {
72+
cancelFunc()
73+
return nil, xerrors.Errorf("init dbcrypt: %w", err)
74+
}
75+
options.Database = cryptDB
7176

7277
api := &API{
7378
ctx: ctx,

enterprise/dbcrypt/dbcrypt.go

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"cdr.dev/slog"
1414

1515
"github.com/coder/coder/v2/coderd/database"
16+
"github.com/coder/coder/v2/coderd/database/dbauthz"
1617
)
1718

1819
// MagicPrefix is prepended to all encrypted values in the database.
@@ -25,7 +26,7 @@ const MagicPrefix = "dbcrypt-"
2526
// Otherwise, the value is encrypted.
2627
const sentinelValue = "coder"
2728

28-
var ErrNotEncrypted = xerrors.New("database is not encrypted")
29+
var ErrNotEnabled = xerrors.New("encryption is not enabled")
2930

3031
type Options struct {
3132
// ExternalTokenCipher is an optional cipher that is used
@@ -37,11 +38,15 @@ type Options struct {
3738

3839
// New creates a database.Store wrapper that encrypts/decrypts values
3940
// stored at rest in the database.
40-
func New(db database.Store, options *Options) database.Store {
41-
return &dbCrypt{
41+
func New(ctx context.Context, db database.Store, options *Options) (database.Store, error) {
42+
dbc := &dbCrypt{
4243
Options: options,
4344
Store: db,
4445
}
46+
if err := ensureEncrypted(dbauthz.AsSystemRestricted(ctx), dbc); err != nil {
47+
return nil, xerrors.Errorf("ensure encrypted database fields: %w", err)
48+
}
49+
return dbc, nil
4550
}
4651

4752
type dbCrypt struct {
@@ -61,14 +66,8 @@ func (db *dbCrypt) InTx(function func(database.Store) error, txOpts *sql.TxOptio
6166
func (db *dbCrypt) GetDBCryptSentinelValue(ctx context.Context) (string, error) {
6267
rawValue, err := db.Store.GetDBCryptSentinelValue(ctx)
6368
if err != nil {
64-
if errors.Is(err, sql.ErrNoRows) {
65-
return "", ErrNotEncrypted
66-
}
6769
return "", err
6870
}
69-
if rawValue == sentinelValue {
70-
return "", ErrNotEncrypted
71-
}
7271
return rawValue, db.decryptFields(&rawValue)
7372
}
7473

@@ -171,7 +170,7 @@ func (db *dbCrypt) decryptFields(fields ...*string) error {
171170
if strings.HasPrefix(*field, MagicPrefix) {
172171
// If we have a magic prefix but encryption is disabled,
173172
// complain loudly.
174-
return xerrors.Errorf("failed to decrypt field %q: encryption is disabled", *field)
173+
return xerrors.Errorf("failed to decrypt field %q: %w", *field, ErrNotEnabled)
175174
}
176175
}
177176
return nil
@@ -183,7 +182,7 @@ func (db *dbCrypt) decryptFields(fields ...*string) error {
183182
continue
184183
}
185184
if len(*field) < len(MagicPrefix) || !strings.HasPrefix(*field, MagicPrefix) {
186-
// We do not force encryption of unencrypted rows. This could be damaging
185+
// We do not force decryption of unencrypted rows. This could be damaging
187186
// to the deployment, and admins can always manually purge data.
188187
continue
189188
}
@@ -201,3 +200,29 @@ func (db *dbCrypt) decryptFields(fields ...*string) error {
201200
}
202201
return nil
203202
}
203+
204+
func ensureEncrypted(ctx context.Context, dbc *dbCrypt) error {
205+
return dbc.InTx(func(s database.Store) error {
206+
val, err := s.GetDBCryptSentinelValue(ctx)
207+
if err != nil {
208+
if !errors.Is(err, sql.ErrNoRows) {
209+
return err
210+
}
211+
}
212+
213+
if val != "" && val != sentinelValue {
214+
// TODO: Handle key rotation.
215+
return xerrors.Errorf("database is already encrypted with a different key and key rotation is not implemented yet")
216+
}
217+
218+
if val == sentinelValue {
219+
return nil // nothing to do!
220+
}
221+
222+
if err := s.SetDBCryptSentinelValue(ctx, sentinelValue); err != nil {
223+
return xerrors.Errorf("mark database as encrypted: %w", err)
224+
}
225+
226+
return nil
227+
}, nil)
228+
}

enterprise/dbcrypt/dbcrypt_test.go

Lines changed: 99 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ func TestUserLinks(t *testing.T) {
2727
t.Run("InsertUserLink", func(t *testing.T) {
2828
t.Parallel()
2929
db, crypt, cipher := setup(t)
30-
initCipher(t, cipher)
3130
user := dbgen.User(t, crypt, database.User{})
3231
link := dbgen.UserLink(t, crypt, database.UserLink{
3332
UserID: user.ID,
@@ -43,7 +42,6 @@ func TestUserLinks(t *testing.T) {
4342
t.Run("UpdateUserLink", func(t *testing.T) {
4443
t.Parallel()
4544
db, crypt, cipher := setup(t)
46-
initCipher(t, cipher)
4745
user := dbgen.User(t, crypt, database.User{})
4846
link := dbgen.UserLink(t, crypt, database.UserLink{
4947
UserID: user.ID,
@@ -64,7 +62,6 @@ func TestUserLinks(t *testing.T) {
6462
t.Run("GetUserLinkByLinkedID", func(t *testing.T) {
6563
t.Parallel()
6664
db, crypt, cipher := setup(t)
67-
initCipher(t, cipher)
6865
user := dbgen.User(t, crypt, database.User{})
6966
link := dbgen.UserLink(t, crypt, database.UserLink{
7067
UserID: user.ID,
@@ -86,7 +83,6 @@ func TestUserLinks(t *testing.T) {
8683
t.Run("GetUserLinkByUserIDLoginType", func(t *testing.T) {
8784
t.Parallel()
8885
db, crypt, cipher := setup(t)
89-
initCipher(t, cipher)
9086
user := dbgen.User(t, crypt, database.User{})
9187
link := dbgen.UserLink(t, crypt, database.UserLink{
9288
UserID: user.ID,
@@ -119,7 +115,6 @@ func TestGitAuthLinks(t *testing.T) {
119115
t.Run("InsertGitAuthLink", func(t *testing.T) {
120116
t.Parallel()
121117
db, crypt, cipher := setup(t)
122-
initCipher(t, cipher)
123118
link := dbgen.GitAuthLink(t, crypt, database.GitAuthLink{
124119
OAuthAccessToken: "access",
125120
OAuthRefreshToken: "refresh",
@@ -136,7 +131,6 @@ func TestGitAuthLinks(t *testing.T) {
136131
t.Run("UpdateGitAuthLink", func(t *testing.T) {
137132
t.Parallel()
138133
db, crypt, cipher := setup(t)
139-
initCipher(t, cipher)
140134
link := dbgen.GitAuthLink(t, crypt, database.GitAuthLink{})
141135
_, err := crypt.UpdateGitAuthLink(ctx, database.UpdateGitAuthLinkParams{
142136
ProviderID: link.ProviderID,
@@ -157,7 +151,6 @@ func TestGitAuthLinks(t *testing.T) {
157151
t.Run("GetGitAuthLink", func(t *testing.T) {
158152
t.Parallel()
159153
db, crypt, cipher := setup(t)
160-
initCipher(t, cipher)
161154
link := dbgen.GitAuthLink(t, crypt, database.GitAuthLink{
162155
OAuthAccessToken: "access",
163156
OAuthRefreshToken: "refresh",
@@ -181,37 +174,90 @@ func TestGitAuthLinks(t *testing.T) {
181174
})
182175
}
183176

184-
func TestDBCryptSentinelValue(t *testing.T) {
177+
func TestNew(t *testing.T) {
185178
t.Parallel()
186-
ctx := context.Background()
187-
db, crypt, cipher := setup(t)
188-
// Initially, the database will not be encrypted.
189-
_, err := db.GetDBCryptSentinelValue(ctx)
190-
require.ErrorIs(t, err, sql.ErrNoRows)
191-
_, err = crypt.GetDBCryptSentinelValue(ctx)
192-
require.EqualError(t, err, dbcrypt.ErrNotEncrypted.Error())
193179

194-
// Now, we'll encrypt the value.
195-
initCipher(t, cipher)
196-
err = crypt.SetDBCryptSentinelValue(ctx, "coder")
197-
require.NoError(t, err)
180+
t.Run("OK", func(t *testing.T) {
181+
// Given: a cipher is loaded
182+
cipher := &atomic.Pointer[dbcrypt.Cipher]{}
183+
initCipher(t, cipher)
184+
ctx, cancel := context.WithCancel(context.Background())
185+
t.Cleanup(cancel)
186+
rawDB, _ := dbtestutil.NewDB(t)
198187

199-
// The value should be encrypted in the database.
200-
crypted, err := db.GetDBCryptSentinelValue(ctx)
201-
require.NoError(t, err)
202-
require.NotEqual(t, "coder", crypted)
203-
decrypted, err := crypt.GetDBCryptSentinelValue(ctx)
204-
require.NoError(t, err)
205-
require.Equal(t, "coder", decrypted)
206-
requireEncryptedEquals(t, cipher, crypted, "coder")
188+
// When: we init the crypt db
189+
cryptDB, err := dbcrypt.New(ctx, rawDB, &dbcrypt.Options{
190+
ExternalTokenCipher: cipher,
191+
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),
192+
})
193+
require.NoError(t, err)
207194

208-
// Reset the key and empty values should be returned!
209-
initCipher(t, cipher)
195+
// Then: the sentinel value is encrypted
196+
cryptVal, err := cryptDB.GetDBCryptSentinelValue(ctx)
197+
require.NoError(t, err)
198+
require.Equal(t, "coder", cryptVal)
210199

211-
_, err = db.GetDBCryptSentinelValue(ctx) // We can still read the raw value
212-
require.NoError(t, err)
213-
_, err = crypt.GetDBCryptSentinelValue(ctx) // Decryption should fail
214-
require.ErrorIs(t, err, sql.ErrNoRows)
200+
rawVal, err := rawDB.GetDBCryptSentinelValue(ctx)
201+
require.NoError(t, err)
202+
require.Contains(t, rawVal, dbcrypt.MagicPrefix)
203+
})
204+
205+
t.Run("NoCipher", func(t *testing.T) {
206+
// Given: no cipher is loaded
207+
cipher := &atomic.Pointer[dbcrypt.Cipher]{}
208+
// initCipher(t, cipher)
209+
ctx, cancel := context.WithCancel(context.Background())
210+
t.Cleanup(cancel)
211+
rawDB, _ := dbtestutil.NewDB(t)
212+
213+
// When: we init the crypt db
214+
cryptDB, err := dbcrypt.New(ctx, rawDB, &dbcrypt.Options{
215+
ExternalTokenCipher: cipher,
216+
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),
217+
})
218+
require.NoError(t, err)
219+
220+
// Then: the sentinel value is not encrypted
221+
cryptVal, err := cryptDB.GetDBCryptSentinelValue(ctx)
222+
require.NoError(t, err)
223+
require.Equal(t, "coder", cryptVal)
224+
225+
rawVal, err := rawDB.GetDBCryptSentinelValue(ctx)
226+
require.NoError(t, err)
227+
require.Equal(t, "coder", rawVal)
228+
})
229+
230+
t.Run("CipherChanged", func(t *testing.T) {
231+
// Given: no cipher is loaded
232+
cipher := &atomic.Pointer[dbcrypt.Cipher]{}
233+
initCipher(t, cipher)
234+
ctx, cancel := context.WithCancel(context.Background())
235+
t.Cleanup(cancel)
236+
rawDB, _ := dbtestutil.NewDB(t)
237+
238+
// And: the sentinel value is encrypted with a different cipher
239+
cipher2 := &atomic.Pointer[dbcrypt.Cipher]{}
240+
initCipher(t, cipher2)
241+
field := "coder"
242+
encrypted, err := (*cipher2.Load()).Encrypt([]byte(field))
243+
require.NoError(t, err)
244+
b64encrypted := base64.StdEncoding.EncodeToString(encrypted)
245+
require.NoError(t, rawDB.SetDBCryptSentinelValue(ctx, b64encrypted))
246+
247+
// When: we init the crypt db
248+
_, err = dbcrypt.New(ctx, rawDB, &dbcrypt.Options{
249+
ExternalTokenCipher: cipher,
250+
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),
251+
})
252+
// Then: an error is returned
253+
// TODO: when we implement key rotation, this should not fail.
254+
require.ErrorContains(t, err, "database is already encrypted with a different key")
255+
256+
// And the sentinel value should remain unchanged. For now.
257+
rawVal, err := rawDB.GetDBCryptSentinelValue(ctx)
258+
require.NoError(t, err)
259+
require.Equal(t, b64encrypted, rawVal)
260+
})
215261
}
216262

217263
func requireEncryptedEquals(t *testing.T, cipher *atomic.Pointer[dbcrypt.Cipher], value, expected string) {
@@ -238,10 +284,28 @@ func initCipher(t *testing.T, cipher *atomic.Pointer[dbcrypt.Cipher]) {
238284

239285
func setup(t *testing.T) (db, cryptodb database.Store, cipher *atomic.Pointer[dbcrypt.Cipher]) {
240286
t.Helper()
287+
ctx, cancel := context.WithCancel(context.Background())
288+
t.Cleanup(cancel)
241289
rawDB, _ := dbtestutil.NewDB(t)
290+
291+
_, err := rawDB.GetDBCryptSentinelValue(ctx)
292+
require.ErrorIs(t, err, sql.ErrNoRows)
293+
242294
cipher = &atomic.Pointer[dbcrypt.Cipher]{}
243-
return rawDB, dbcrypt.New(rawDB, &dbcrypt.Options{
295+
initCipher(t, cipher)
296+
cryptDB, err := dbcrypt.New(ctx, rawDB, &dbcrypt.Options{
244297
ExternalTokenCipher: cipher,
245298
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),
246-
}), cipher
299+
})
300+
require.NoError(t, err)
301+
302+
rawVal, err := rawDB.GetDBCryptSentinelValue(ctx)
303+
require.NoError(t, err)
304+
require.Contains(t, rawVal, dbcrypt.MagicPrefix)
305+
306+
cryptVal, err := cryptDB.GetDBCryptSentinelValue(ctx)
307+
require.NoError(t, err)
308+
require.Equal(t, "coder", cryptVal)
309+
310+
return rawDB, cryptDB, cipher
247311
}

0 commit comments

Comments
 (0)