Skip to content

Commit 8cb07ba

Browse files
committed
decrypt fields when inserting and updating!
1 parent fe21f26 commit 8cb07ba

File tree

3 files changed

+44
-7
lines changed

3 files changed

+44
-7
lines changed

enterprise/coderd/coderdenttest/coderdenttest.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,12 @@ import (
2121
"github.com/coder/coder/v2/codersdk"
2222
"github.com/coder/coder/v2/enterprise/coderd"
2323
"github.com/coder/coder/v2/enterprise/coderd/license"
24+
"github.com/coder/coder/v2/enterprise/dbcrypt"
2425
)
2526

2627
const (
27-
testKeyID = "enterprise-test"
28+
testKeyID = "enterprise-test"
29+
testEncryptionKey = "coder-coder-coder-coder-coder-1!" // nolint:gosec
2830
)
2931

3032
var (
@@ -56,6 +58,7 @@ type Options struct {
5658
DontAddLicense bool
5759
DontAddFirstUser bool
5860
ReplicaSyncUpdateInterval time.Duration
61+
ExternalTokenEncryption *dbcrypt.Ciphers
5962
ProvisionerDaemonPSK string
6063
}
6164

@@ -82,6 +85,11 @@ func NewWithAPI(t *testing.T, options *Options) (
8285
err := oop.DeploymentValues.UserQuietHoursSchedule.DefaultSchedule.Set("0 0 * * *")
8386
require.NoError(t, err)
8487
}
88+
if options.ExternalTokenEncryption == nil {
89+
c, err := dbcrypt.CipherAES256([]byte(testEncryptionKey))
90+
require.NoError(t, err)
91+
options.ExternalTokenEncryption = dbcrypt.NewCiphers(c)
92+
}
8593
coderAPI, err := coderd.New(context.Background(), &coderd.Options{
8694
RBAC: true,
8795
AuditLogging: options.AuditLogging,
@@ -96,6 +104,7 @@ func NewWithAPI(t *testing.T, options *Options) (
96104
ProxyHealthInterval: options.ProxyHealthInterval,
97105
DefaultQuietHoursSchedule: oop.DeploymentValues.UserQuietHoursSchedule.DefaultSchedule.Value(),
98106
ProvisionerDaemonPSK: options.ProvisionerDaemonPSK,
107+
ExternalTokenEncryption: options.ExternalTokenEncryption,
99108
})
100109
require.NoError(t, err)
101110
setHandler(coderAPI.AGPL.RootHandler)

enterprise/dbcrypt/dbcrypt.go

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -145,23 +145,35 @@ func (db *dbCrypt) InsertUserLink(ctx context.Context, params database.InsertUse
145145
if err != nil {
146146
return database.UserLink{}, err
147147
}
148-
return db.Store.InsertUserLink(ctx, params)
148+
link, err := db.Store.InsertUserLink(ctx, params)
149+
if err != nil {
150+
return database.UserLink{}, err
151+
}
152+
return link, db.decryptFields(&link.OAuthAccessToken, &link.OAuthRefreshToken)
149153
}
150154

151155
func (db *dbCrypt) UpdateUserLink(ctx context.Context, params database.UpdateUserLinkParams) (database.UserLink, error) {
152156
err := db.encryptFields(&params.OAuthAccessToken, &params.OAuthRefreshToken)
153157
if err != nil {
154158
return database.UserLink{}, err
155159
}
156-
return db.Store.UpdateUserLink(ctx, params)
160+
updated, err := db.Store.UpdateUserLink(ctx, params)
161+
if err != nil {
162+
return database.UserLink{}, err
163+
}
164+
return updated, db.decryptFields(&updated.OAuthAccessToken, &updated.OAuthRefreshToken)
157165
}
158166

159167
func (db *dbCrypt) InsertGitAuthLink(ctx context.Context, params database.InsertGitAuthLinkParams) (database.GitAuthLink, error) {
160168
err := db.encryptFields(&params.OAuthAccessToken, &params.OAuthRefreshToken)
161169
if err != nil {
162170
return database.GitAuthLink{}, err
163171
}
164-
return db.Store.InsertGitAuthLink(ctx, params)
172+
link, err := db.Store.InsertGitAuthLink(ctx, params)
173+
if err != nil {
174+
return database.GitAuthLink{}, err
175+
}
176+
return link, db.decryptFields(&link.OAuthAccessToken, &link.OAuthRefreshToken)
165177
}
166178

167179
func (db *dbCrypt) GetGitAuthLink(ctx context.Context, params database.GetGitAuthLinkParams) (database.GitAuthLink, error) {
@@ -190,7 +202,11 @@ func (db *dbCrypt) UpdateGitAuthLink(ctx context.Context, params database.Update
190202
if err != nil {
191203
return database.GitAuthLink{}, err
192204
}
193-
return db.Store.UpdateGitAuthLink(ctx, params)
205+
updated, err := db.Store.UpdateGitAuthLink(ctx, params)
206+
if err != nil {
207+
return database.GitAuthLink{}, err
208+
}
209+
return updated, db.decryptFields(&updated.OAuthAccessToken, &updated.OAuthRefreshToken)
194210
}
195211

196212
func (db *dbCrypt) SetDBCryptSentinelValue(ctx context.Context, value string) error {

enterprise/dbcrypt/dbcrypt_test.go

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ func TestUserLinks(t *testing.T) {
2929
OAuthAccessToken: "access",
3030
OAuthRefreshToken: "refresh",
3131
})
32+
require.Equal(t, link.OAuthAccessToken, "access")
33+
require.Equal(t, link.OAuthRefreshToken, "refresh")
34+
3235
link, err := db.GetUserLinkByLinkedID(ctx, link.LinkedID)
3336
require.NoError(t, err)
3437
requireEncryptedEquals(t, cipher, link.OAuthAccessToken, "access")
@@ -42,13 +45,16 @@ func TestUserLinks(t *testing.T) {
4245
link := dbgen.UserLink(t, crypt, database.UserLink{
4346
UserID: user.ID,
4447
})
45-
_, err := crypt.UpdateUserLink(ctx, database.UpdateUserLinkParams{
48+
updated, err := crypt.UpdateUserLink(ctx, database.UpdateUserLinkParams{
4649
OAuthAccessToken: "access",
4750
OAuthRefreshToken: "refresh",
4851
UserID: link.UserID,
4952
LoginType: link.LoginType,
5053
})
5154
require.NoError(t, err)
55+
require.Equal(t, updated.OAuthAccessToken, "access")
56+
require.Equal(t, updated.OAuthRefreshToken, "refresh")
57+
5258
link, err = db.GetUserLinkByLinkedID(ctx, link.LinkedID)
5359
require.NoError(t, err)
5460
requireEncryptedEquals(t, cipher, link.OAuthAccessToken, "access")
@@ -100,6 +106,9 @@ func TestGitAuthLinks(t *testing.T) {
100106
OAuthAccessToken: "access",
101107
OAuthRefreshToken: "refresh",
102108
})
109+
require.Equal(t, link.OAuthAccessToken, "access")
110+
require.Equal(t, link.OAuthRefreshToken, "refresh")
111+
103112
link, err := db.GetGitAuthLink(ctx, database.GetGitAuthLinkParams{
104113
ProviderID: link.ProviderID,
105114
UserID: link.UserID,
@@ -113,13 +122,16 @@ func TestGitAuthLinks(t *testing.T) {
113122
t.Parallel()
114123
db, crypt, cipher := setup(t)
115124
link := dbgen.GitAuthLink(t, crypt, database.GitAuthLink{})
116-
_, err := crypt.UpdateGitAuthLink(ctx, database.UpdateGitAuthLinkParams{
125+
updated, err := crypt.UpdateGitAuthLink(ctx, database.UpdateGitAuthLinkParams{
117126
ProviderID: link.ProviderID,
118127
UserID: link.UserID,
119128
OAuthAccessToken: "access",
120129
OAuthRefreshToken: "refresh",
121130
})
122131
require.NoError(t, err)
132+
require.Equal(t, updated.OAuthAccessToken, "access")
133+
require.Equal(t, updated.OAuthRefreshToken, "refresh")
134+
123135
link, err = db.GetGitAuthLink(ctx, database.GetGitAuthLinkParams{
124136
ProviderID: link.ProviderID,
125137
UserID: link.UserID,

0 commit comments

Comments
 (0)