Skip to content

Commit 6651fe1

Browse files
committed
feat: encrypt oidc and git auth tokens in the database
1 parent 71c52ea commit 6651fe1

File tree

11 files changed

+497
-15
lines changed

11 files changed

+497
-15
lines changed

coderd/authorize.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ func (h *HTTPAuthorizer) Authorize(r *http.Request, action rbac.Action, object r
7575
}
7676
// Log information for debugging. This will be very helpful
7777
// in the early days
78-
logger.Warn(r.Context(), "unauthorized",
78+
logger.Debug(r.Context(), "unauthorized",
7979
slog.F("roles", roles.Actor.SafeRoleNames()),
8080
slog.F("actor_id", roles.Actor.ID),
8181
slog.F("actor_name", roles.ActorName),

coderd/database/dbcrypt/dbcrypt.go

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
package dbcrypt
2+
3+
import (
4+
"context"
5+
"database/sql"
6+
"strings"
7+
"sync/atomic"
8+
9+
"golang.org/x/xerrors"
10+
11+
"github.com/coder/coder/coderd/database"
12+
"github.com/coder/coder/cryptorand"
13+
)
14+
15+
// MagicPrefix is prepended to all encrypted values in the database.
16+
// This is used to determine if a value is encrypted or not.
17+
// If it is encrypted but a key is not provided, an error is returned.
18+
const MagicPrefix = "dbcrypt-"
19+
20+
// ErrInvalidCipher is returned when an invalid cipher is provided
21+
// for the encrypted data.
22+
var ErrInvalidCipher = xerrors.New("an invalid encryption cipher was provided for the encrypted data")
23+
24+
type Options struct {
25+
// ExternalTokenCipher is an optional cipher that is used
26+
// to encrypt/decrypt user link and git auth link tokens. If this is nil,
27+
// then no encryption/decryption will be performed.
28+
ExternalTokenCipher *atomic.Pointer[cryptorand.Cipher]
29+
}
30+
31+
// New creates a database.Store wrapper that encrypts/decrypts values
32+
// stored at rest in the database.
33+
func New(db database.Store, options *Options) database.Store {
34+
return &dbCrypt{
35+
Options: options,
36+
Store: db,
37+
}
38+
}
39+
40+
type dbCrypt struct {
41+
*Options
42+
database.Store
43+
}
44+
45+
func (db *dbCrypt) InTx(function func(database.Store) error, txOpts *sql.TxOptions) error {
46+
return db.Store.InTx(func(s database.Store) error {
47+
return function(&dbCrypt{
48+
Options: db.Options,
49+
Store: s,
50+
})
51+
}, txOpts)
52+
}
53+
54+
func (db *dbCrypt) GetUserLinkByLinkedID(ctx context.Context, linkedID string) (database.UserLink, error) {
55+
link, err := db.Store.GetUserLinkByLinkedID(ctx, linkedID)
56+
if err != nil {
57+
return database.UserLink{}, err
58+
}
59+
return link, db.decryptFields(&link.OAuthAccessToken, &link.OAuthRefreshToken)
60+
}
61+
62+
func (db *dbCrypt) GetUserLinkByUserIDLoginType(ctx context.Context, params database.GetUserLinkByUserIDLoginTypeParams) (database.UserLink, error) {
63+
link, err := db.Store.GetUserLinkByUserIDLoginType(ctx, params)
64+
if err != nil {
65+
return database.UserLink{}, err
66+
}
67+
return link, db.decryptFields(&link.OAuthAccessToken, &link.OAuthRefreshToken)
68+
}
69+
70+
func (db *dbCrypt) InsertUserLink(ctx context.Context, params database.InsertUserLinkParams) (database.UserLink, error) {
71+
err := db.encryptFields(&params.OAuthAccessToken, &params.OAuthRefreshToken)
72+
if err != nil {
73+
return database.UserLink{}, err
74+
}
75+
return db.Store.InsertUserLink(ctx, params)
76+
}
77+
78+
func (db *dbCrypt) UpdateUserLink(ctx context.Context, params database.UpdateUserLinkParams) (database.UserLink, error) {
79+
err := db.encryptFields(&params.OAuthAccessToken, &params.OAuthRefreshToken)
80+
if err != nil {
81+
return database.UserLink{}, err
82+
}
83+
return db.Store.UpdateUserLink(ctx, params)
84+
}
85+
86+
func (db *dbCrypt) InsertGitAuthLink(ctx context.Context, params database.InsertGitAuthLinkParams) (database.GitAuthLink, error) {
87+
err := db.encryptFields(&params.OAuthAccessToken, &params.OAuthRefreshToken)
88+
if err != nil {
89+
return database.GitAuthLink{}, err
90+
}
91+
return db.Store.InsertGitAuthLink(ctx, params)
92+
}
93+
94+
func (db *dbCrypt) GetGitAuthLink(ctx context.Context, params database.GetGitAuthLinkParams) (database.GitAuthLink, error) {
95+
link, err := db.Store.GetGitAuthLink(ctx, params)
96+
if err != nil {
97+
return database.GitAuthLink{}, err
98+
}
99+
return link, db.decryptFields(&link.OAuthAccessToken, &link.OAuthRefreshToken)
100+
}
101+
102+
func (db *dbCrypt) UpdateGitAuthLink(ctx context.Context, params database.UpdateGitAuthLinkParams) (database.GitAuthLink, error) {
103+
err := db.encryptFields(&params.OAuthAccessToken, &params.OAuthRefreshToken)
104+
if err != nil {
105+
return database.GitAuthLink{}, err
106+
}
107+
return db.Store.UpdateGitAuthLink(ctx, params)
108+
}
109+
110+
func (db *dbCrypt) encryptFields(fields ...*string) error {
111+
cipherPtr := db.ExternalTokenCipher.Load()
112+
// If no cipher is loaded, then we don't need to encrypt or decrypt anything!
113+
if cipherPtr == nil {
114+
return nil
115+
}
116+
cipher := *cipherPtr
117+
for _, field := range fields {
118+
if field == nil {
119+
continue
120+
}
121+
122+
encrypted, err := cipher.Encrypt([]byte(*field))
123+
if err != nil {
124+
return err
125+
}
126+
*field = MagicPrefix + string(encrypted)
127+
}
128+
return nil
129+
}
130+
131+
// decryptFields decrypts the given fields in place.
132+
// If the value fails to decrypt, sql.ErrNoRows will be returned.
133+
func (db *dbCrypt) decryptFields(fields ...*string) error {
134+
cipherPtr := db.ExternalTokenCipher.Load()
135+
// If no cipher is loaded, then we don't need to encrypt or decrypt anything!
136+
if cipherPtr == nil {
137+
for _, field := range fields {
138+
if field == nil {
139+
continue
140+
}
141+
if strings.HasPrefix(*field, MagicPrefix) {
142+
return ErrInvalidCipher
143+
}
144+
}
145+
return nil
146+
}
147+
148+
cipher := *cipherPtr
149+
for _, field := range fields {
150+
if field == nil {
151+
continue
152+
}
153+
if len(*field) < len(MagicPrefix) || !strings.HasPrefix(*field, MagicPrefix) {
154+
continue
155+
}
156+
157+
decrypted, err := cipher.Decrypt([]byte((*field)[len(MagicPrefix):]))
158+
if err != nil {
159+
return xerrors.Errorf("%w: %s", ErrInvalidCipher, err)
160+
}
161+
*field = string(decrypted)
162+
}
163+
return nil
164+
}
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
package dbcrypt_test
2+
3+
import (
4+
"context"
5+
"crypto/rand"
6+
"io"
7+
"sync/atomic"
8+
"testing"
9+
10+
"github.com/stretchr/testify/require"
11+
12+
"github.com/coder/coder/coderd/database"
13+
"github.com/coder/coder/coderd/database/dbcrypt"
14+
"github.com/coder/coder/coderd/database/dbfake"
15+
"github.com/coder/coder/coderd/database/dbgen"
16+
"github.com/coder/coder/cryptorand"
17+
)
18+
19+
func TestUserLinks(t *testing.T) {
20+
t.Parallel()
21+
ctx := context.Background()
22+
23+
t.Run("InsertUserLink", func(t *testing.T) {
24+
t.Parallel()
25+
db, crypt, cipher := setup(t)
26+
initCipher(t, cipher)
27+
link := dbgen.UserLink(t, crypt, database.UserLink{
28+
OAuthAccessToken: "access",
29+
OAuthRefreshToken: "refresh",
30+
})
31+
link, err := db.GetUserLinkByLinkedID(ctx, link.LinkedID)
32+
require.NoError(t, err)
33+
requireEncryptedEquals(t, cipher, link.OAuthAccessToken, "access")
34+
requireEncryptedEquals(t, cipher, link.OAuthRefreshToken, "refresh")
35+
})
36+
37+
t.Run("UpdateUserLink", func(t *testing.T) {
38+
t.Parallel()
39+
db, crypt, cipher := setup(t)
40+
initCipher(t, cipher)
41+
link := dbgen.UserLink(t, crypt, database.UserLink{})
42+
_, err := crypt.UpdateUserLink(ctx, database.UpdateUserLinkParams{
43+
OAuthAccessToken: "access",
44+
OAuthRefreshToken: "refresh",
45+
UserID: link.UserID,
46+
LoginType: link.LoginType,
47+
})
48+
require.NoError(t, err)
49+
link, err = db.GetUserLinkByLinkedID(ctx, link.LinkedID)
50+
require.NoError(t, err)
51+
requireEncryptedEquals(t, cipher, link.OAuthAccessToken, "access")
52+
requireEncryptedEquals(t, cipher, link.OAuthRefreshToken, "refresh")
53+
})
54+
55+
t.Run("GetUserLinkByLinkedID", func(t *testing.T) {
56+
t.Parallel()
57+
db, crypt, cipher := setup(t)
58+
initCipher(t, cipher)
59+
link := dbgen.UserLink(t, crypt, database.UserLink{
60+
OAuthAccessToken: "access",
61+
OAuthRefreshToken: "refresh",
62+
})
63+
link, err := db.GetUserLinkByLinkedID(ctx, link.LinkedID)
64+
require.NoError(t, err)
65+
requireEncryptedEquals(t, cipher, link.OAuthAccessToken, "access")
66+
requireEncryptedEquals(t, cipher, link.OAuthRefreshToken, "refresh")
67+
68+
// Reset the key and empty values should be returned!
69+
initCipher(t, cipher)
70+
71+
link, err = crypt.GetUserLinkByLinkedID(ctx, link.LinkedID)
72+
require.ErrorIs(t, err, dbcrypt.ErrInvalidCipher)
73+
})
74+
75+
t.Run("GetUserLinkByUserIDLoginType", func(t *testing.T) {
76+
t.Parallel()
77+
db, crypt, cipher := setup(t)
78+
initCipher(t, cipher)
79+
link := dbgen.UserLink(t, crypt, database.UserLink{
80+
OAuthAccessToken: "access",
81+
OAuthRefreshToken: "refresh",
82+
})
83+
link, err := db.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{
84+
UserID: link.UserID,
85+
LoginType: link.LoginType,
86+
})
87+
require.NoError(t, err)
88+
requireEncryptedEquals(t, cipher, link.OAuthAccessToken, "access")
89+
requireEncryptedEquals(t, cipher, link.OAuthRefreshToken, "refresh")
90+
91+
// Reset the key and empty values should be returned!
92+
initCipher(t, cipher)
93+
94+
link, err = crypt.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{
95+
UserID: link.UserID,
96+
LoginType: link.LoginType,
97+
})
98+
require.ErrorIs(t, err, dbcrypt.ErrInvalidCipher)
99+
})
100+
}
101+
102+
func TestGitAuthLinks(t *testing.T) {
103+
t.Parallel()
104+
ctx := context.Background()
105+
106+
t.Run("InsertGitAuthLink", func(t *testing.T) {
107+
t.Parallel()
108+
db, crypt, cipher := setup(t)
109+
initCipher(t, cipher)
110+
link := dbgen.GitAuthLink(t, crypt, database.GitAuthLink{
111+
OAuthAccessToken: "access",
112+
OAuthRefreshToken: "refresh",
113+
})
114+
link, err := db.GetGitAuthLink(ctx, database.GetGitAuthLinkParams{
115+
ProviderID: link.ProviderID,
116+
UserID: link.UserID,
117+
})
118+
require.NoError(t, err)
119+
requireEncryptedEquals(t, cipher, link.OAuthAccessToken, "access")
120+
requireEncryptedEquals(t, cipher, link.OAuthRefreshToken, "refresh")
121+
})
122+
123+
t.Run("UpdateGitAuthLink", func(t *testing.T) {
124+
t.Parallel()
125+
db, crypt, cipher := setup(t)
126+
initCipher(t, cipher)
127+
link := dbgen.GitAuthLink(t, crypt, database.GitAuthLink{})
128+
_, err := crypt.UpdateGitAuthLink(ctx, database.UpdateGitAuthLinkParams{
129+
ProviderID: link.ProviderID,
130+
UserID: link.UserID,
131+
OAuthAccessToken: "access",
132+
OAuthRefreshToken: "refresh",
133+
})
134+
require.NoError(t, err)
135+
link, err = db.GetGitAuthLink(ctx, database.GetGitAuthLinkParams{
136+
ProviderID: link.ProviderID,
137+
UserID: link.UserID,
138+
})
139+
require.NoError(t, err)
140+
requireEncryptedEquals(t, cipher, link.OAuthAccessToken, "access")
141+
requireEncryptedEquals(t, cipher, link.OAuthRefreshToken, "refresh")
142+
})
143+
144+
t.Run("GetGitAuthLink", func(t *testing.T) {
145+
t.Parallel()
146+
db, crypt, cipher := setup(t)
147+
initCipher(t, cipher)
148+
link := dbgen.GitAuthLink(t, crypt, database.GitAuthLink{
149+
OAuthAccessToken: "access",
150+
OAuthRefreshToken: "refresh",
151+
})
152+
link, err := db.GetGitAuthLink(ctx, database.GetGitAuthLinkParams{
153+
UserID: link.UserID,
154+
ProviderID: link.ProviderID,
155+
})
156+
require.NoError(t, err)
157+
requireEncryptedEquals(t, cipher, link.OAuthAccessToken, "access")
158+
requireEncryptedEquals(t, cipher, link.OAuthRefreshToken, "refresh")
159+
160+
// Reset the key and empty values should be returned!
161+
initCipher(t, cipher)
162+
163+
link, err = crypt.GetGitAuthLink(ctx, database.GetGitAuthLinkParams{
164+
UserID: link.UserID,
165+
ProviderID: link.ProviderID,
166+
})
167+
require.ErrorIs(t, err, dbcrypt.ErrInvalidCipher)
168+
})
169+
}
170+
171+
func requireEncryptedEquals(t *testing.T, cipher *atomic.Pointer[cryptorand.Cipher], value, expected string) {
172+
t.Helper()
173+
c := (*cipher.Load())
174+
got, err := c.Decrypt([]byte(value[len(dbcrypt.MagicPrefix):]))
175+
require.NoError(t, err)
176+
require.Equal(t, expected, string(got))
177+
}
178+
179+
func initCipher(t *testing.T, cipher *atomic.Pointer[cryptorand.Cipher]) {
180+
t.Helper()
181+
key := make([]byte, 32) // AES-256 key size is 32 bytes
182+
_, err := io.ReadFull(rand.Reader, key)
183+
require.NoError(t, err)
184+
c, err := cryptorand.CipherAES256(key)
185+
require.NoError(t, err)
186+
cipher.Store(&c)
187+
}
188+
189+
func setup(t *testing.T) (db, cryptodb database.Store, cipher *atomic.Pointer[cryptorand.Cipher]) {
190+
t.Helper()
191+
rawDB := dbfake.New()
192+
cipher = &atomic.Pointer[cryptorand.Cipher]{}
193+
return rawDB, dbcrypt.New(rawDB, &dbcrypt.Options{
194+
ExternalTokenCipher: cipher,
195+
}), cipher
196+
}

coderd/database/dbgen/generator.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -385,7 +385,7 @@ func UserLink(t testing.TB, db database.Store, orig database.UserLink) database.
385385
LoginType: takeFirst(orig.LoginType, database.LoginTypeGithub),
386386
LinkedID: takeFirst(orig.LinkedID),
387387
OAuthAccessToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()),
388-
OAuthRefreshToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()),
388+
OAuthRefreshToken: takeFirst(orig.OAuthRefreshToken, uuid.NewString()),
389389
OAuthExpiry: takeFirst(orig.OAuthExpiry, database.Now().Add(time.Hour*24)),
390390
})
391391

@@ -398,7 +398,7 @@ func GitAuthLink(t testing.TB, db database.Store, orig database.GitAuthLink) dat
398398
ProviderID: takeFirst(orig.ProviderID, uuid.New().String()),
399399
UserID: takeFirst(orig.UserID, uuid.New()),
400400
OAuthAccessToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()),
401-
OAuthRefreshToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()),
401+
OAuthRefreshToken: takeFirst(orig.OAuthRefreshToken, uuid.NewString()),
402402
OAuthExpiry: takeFirst(orig.OAuthExpiry, database.Now().Add(time.Hour*24)),
403403
CreatedAt: takeFirst(orig.CreatedAt, database.Now()),
404404
UpdatedAt: takeFirst(orig.UpdatedAt, database.Now()),

0 commit comments

Comments
 (0)