Skip to content

Commit e744cde

Browse files
authored
fix(coderd): ensure that clearing invalid oauth refresh tokens works with dbcrypt (coder#15721)
coder#15608 introduced a buggy behaviour with dbcrypt enabled. When clearing an oauth refresh token, we had been setting the value to the empty string. The database encryption package considers decrypting an empty string to be an error, as an empty encrypted string value will still have a nonce associated with it and thus not actually be empty when stored at rest. Instead of 'deleting' the refresh token, 'update' it to be the empty string. This plays nicely with dbcrypt. It also adds a 'utility test' in the dbcrypt package to help encrypt a value. This was useful when manually fixing users affected by this bug on our dogfood instance.
1 parent ebfc133 commit e744cde

File tree

13 files changed

+184
-94
lines changed

13 files changed

+184
-94
lines changed

coderd/database/dbauthz/dbauthz.go

+7-7
Original file line numberDiff line numberDiff line change
@@ -3367,13 +3367,6 @@ func (q *querier) RegisterWorkspaceProxy(ctx context.Context, arg database.Regis
33673367
return updateWithReturn(q.log, q.auth, fetch, q.db.RegisterWorkspaceProxy)(ctx, arg)
33683368
}
33693369

3370-
func (q *querier) RemoveRefreshToken(ctx context.Context, arg database.RemoveRefreshTokenParams) error {
3371-
fetch := func(ctx context.Context, arg database.RemoveRefreshTokenParams) (database.ExternalAuthLink, error) {
3372-
return q.db.GetExternalAuthLink(ctx, database.GetExternalAuthLinkParams{UserID: arg.UserID, ProviderID: arg.ProviderID})
3373-
}
3374-
return fetchAndExec(q.log, q.auth, policy.ActionUpdatePersonal, fetch, q.db.RemoveRefreshToken)(ctx, arg)
3375-
}
3376-
33773370
func (q *querier) RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) error {
33783371
// This is a system function to clear user groups in group sync.
33793372
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
@@ -3472,6 +3465,13 @@ func (q *querier) UpdateExternalAuthLink(ctx context.Context, arg database.Updat
34723465
return fetchAndQuery(q.log, q.auth, policy.ActionUpdatePersonal, fetch, q.db.UpdateExternalAuthLink)(ctx, arg)
34733466
}
34743467

3468+
func (q *querier) UpdateExternalAuthLinkRefreshToken(ctx context.Context, arg database.UpdateExternalAuthLinkRefreshTokenParams) error {
3469+
fetch := func(ctx context.Context, arg database.UpdateExternalAuthLinkRefreshTokenParams) (database.ExternalAuthLink, error) {
3470+
return q.db.GetExternalAuthLink(ctx, database.GetExternalAuthLinkParams{UserID: arg.UserID, ProviderID: arg.ProviderID})
3471+
}
3472+
return fetchAndExec(q.log, q.auth, policy.ActionUpdatePersonal, fetch, q.db.UpdateExternalAuthLinkRefreshToken)(ctx, arg)
3473+
}
3474+
34753475
func (q *querier) UpdateGitSSHKey(ctx context.Context, arg database.UpdateGitSSHKeyParams) (database.GitSSHKey, error) {
34763476
fetch := func(ctx context.Context, arg database.UpdateGitSSHKeyParams) (database.GitSSHKey, error) {
34773477
return q.db.GetGitSSHKey(ctx, arg.UserID)

coderd/database/dbauthz/dbauthz_test.go

+7-5
Original file line numberDiff line numberDiff line change
@@ -1282,12 +1282,14 @@ func (s *MethodTestSuite) TestUser() {
12821282
UserID: u.ID,
12831283
}).Asserts(u, policy.ActionUpdatePersonal)
12841284
}))
1285-
s.Run("RemoveRefreshToken", s.Subtest(func(db database.Store, check *expects) {
1285+
s.Run("UpdateExternalAuthLinkRefreshToken", s.Subtest(func(db database.Store, check *expects) {
12861286
link := dbgen.ExternalAuthLink(s.T(), db, database.ExternalAuthLink{})
1287-
check.Args(database.RemoveRefreshTokenParams{
1288-
ProviderID: link.ProviderID,
1289-
UserID: link.UserID,
1290-
UpdatedAt: link.UpdatedAt,
1287+
check.Args(database.UpdateExternalAuthLinkRefreshTokenParams{
1288+
OAuthRefreshToken: "",
1289+
OAuthRefreshTokenKeyID: "",
1290+
ProviderID: link.ProviderID,
1291+
UserID: link.UserID,
1292+
UpdatedAt: link.UpdatedAt,
12911293
}).Asserts(rbac.ResourceUserObject(link.UserID), policy.ActionUpdatePersonal)
12921294
}))
12931295
s.Run("UpdateExternalAuthLink", s.Subtest(func(db database.Store, check *expects) {

coderd/database/dbmem/dbmem.go

+23-23
Original file line numberDiff line numberDiff line change
@@ -8607,29 +8607,6 @@ func (q *FakeQuerier) RegisterWorkspaceProxy(_ context.Context, arg database.Reg
86078607
return database.WorkspaceProxy{}, sql.ErrNoRows
86088608
}
86098609

8610-
func (q *FakeQuerier) RemoveRefreshToken(_ context.Context, arg database.RemoveRefreshTokenParams) error {
8611-
if err := validateDatabaseType(arg); err != nil {
8612-
return err
8613-
}
8614-
8615-
q.mutex.Lock()
8616-
defer q.mutex.Unlock()
8617-
for index, gitAuthLink := range q.externalAuthLinks {
8618-
if gitAuthLink.ProviderID != arg.ProviderID {
8619-
continue
8620-
}
8621-
if gitAuthLink.UserID != arg.UserID {
8622-
continue
8623-
}
8624-
gitAuthLink.UpdatedAt = arg.UpdatedAt
8625-
gitAuthLink.OAuthRefreshToken = ""
8626-
q.externalAuthLinks[index] = gitAuthLink
8627-
8628-
return nil
8629-
}
8630-
return sql.ErrNoRows
8631-
}
8632-
86338610
func (q *FakeQuerier) RemoveUserFromAllGroups(_ context.Context, userID uuid.UUID) error {
86348611
q.mutex.Lock()
86358612
defer q.mutex.Unlock()
@@ -8849,6 +8826,29 @@ func (q *FakeQuerier) UpdateExternalAuthLink(_ context.Context, arg database.Upd
88498826
return database.ExternalAuthLink{}, sql.ErrNoRows
88508827
}
88518828

8829+
func (q *FakeQuerier) UpdateExternalAuthLinkRefreshToken(_ context.Context, arg database.UpdateExternalAuthLinkRefreshTokenParams) error {
8830+
if err := validateDatabaseType(arg); err != nil {
8831+
return err
8832+
}
8833+
8834+
q.mutex.Lock()
8835+
defer q.mutex.Unlock()
8836+
for index, gitAuthLink := range q.externalAuthLinks {
8837+
if gitAuthLink.ProviderID != arg.ProviderID {
8838+
continue
8839+
}
8840+
if gitAuthLink.UserID != arg.UserID {
8841+
continue
8842+
}
8843+
gitAuthLink.UpdatedAt = arg.UpdatedAt
8844+
gitAuthLink.OAuthRefreshToken = arg.OAuthRefreshToken
8845+
q.externalAuthLinks[index] = gitAuthLink
8846+
8847+
return nil
8848+
}
8849+
return sql.ErrNoRows
8850+
}
8851+
88528852
func (q *FakeQuerier) UpdateGitSSHKey(_ context.Context, arg database.UpdateGitSSHKeyParams) (database.GitSSHKey, error) {
88538853
if err := validateDatabaseType(arg); err != nil {
88548854
return database.GitSSHKey{}, err

coderd/database/dbmetrics/querymetrics.go

+7-7
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/dbmock/dbmock.go

+14-14
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/querier.go

+1-4
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries.sql.go

+34-23
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries/externalauth.sql

+9-6
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,16 @@ UPDATE external_auth_links SET
4343
oauth_extra = $9
4444
WHERE provider_id = $1 AND user_id = $2 RETURNING *;
4545

46-
-- name: RemoveRefreshToken :exec
47-
-- Removing the refresh token disables the refresh behavior for a given
48-
-- auth token. If a refresh token is marked invalid, it is better to remove it
49-
-- then continually attempt to refresh the token.
46+
-- name: UpdateExternalAuthLinkRefreshToken :exec
5047
UPDATE
5148
external_auth_links
5249
SET
53-
oauth_refresh_token = '',
50+
oauth_refresh_token = @oauth_refresh_token,
5451
updated_at = @updated_at
55-
WHERE provider_id = @provider_id AND user_id = @user_id;
52+
WHERE
53+
provider_id = @provider_id
54+
AND
55+
user_id = @user_id
56+
AND
57+
-- Required for sqlc to generate a parameter for the oauth_refresh_token_key_id
58+
@oauth_refresh_token_key_id :: text = @oauth_refresh_token_key_id :: text;

coderd/externalauth/externalauth.go

+6-4
Original file line numberDiff line numberDiff line change
@@ -143,10 +143,12 @@ func (c *Config) RefreshToken(ctx context.Context, db database.Store, externalAu
143143
// get rid of it. Keeping it around will cause additional refresh
144144
// attempts that will fail and cost us api rate limits.
145145
if isFailedRefresh(existingToken, err) {
146-
dbExecErr := db.RemoveRefreshToken(ctx, database.RemoveRefreshTokenParams{
147-
UpdatedAt: dbtime.Now(),
148-
ProviderID: externalAuthLink.ProviderID,
149-
UserID: externalAuthLink.UserID,
146+
dbExecErr := db.UpdateExternalAuthLinkRefreshToken(ctx, database.UpdateExternalAuthLinkRefreshTokenParams{
147+
OAuthRefreshToken: "", // It is better to clear the refresh token than to keep retrying.
148+
OAuthRefreshTokenKeyID: externalAuthLink.OAuthRefreshTokenKeyID.String,
149+
UpdatedAt: dbtime.Now(),
150+
ProviderID: externalAuthLink.ProviderID,
151+
UserID: externalAuthLink.UserID,
150152
})
151153
if dbExecErr != nil {
152154
// This error should be rare.

coderd/externalauth/externalauth_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ func TestRefreshToken(t *testing.T) {
190190

191191
// Try again with a bad refresh token error
192192
// Expect DB call to remove the refresh token
193-
mDB.EXPECT().RemoveRefreshToken(gomock.Any(), gomock.Any()).Return(nil).Times(1)
193+
mDB.EXPECT().UpdateExternalAuthLinkRefreshToken(gomock.Any(), gomock.Any()).Return(nil).Times(1)
194194
refreshErr = &oauth2.RetrieveError{ // github error
195195
Response: &http.Response{
196196
StatusCode: http.StatusOK,

enterprise/dbcrypt/cipher_internal_test.go

+34
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ package dbcrypt
33
import (
44
"bytes"
55
"encoding/base64"
6+
"os"
7+
"strings"
68
"testing"
79

810
"github.com/stretchr/testify/require"
@@ -89,3 +91,35 @@ func TestCiphersBackwardCompatibility(t *testing.T) {
8991
require.NoError(t, err, "decryption should succeed")
9092
require.Equal(t, msg, string(decrypted), "decrypted message should match original message")
9193
}
94+
95+
// If you're looking here, you're probably in trouble.
96+
// Here's what you need to do:
97+
// 1. Get the current CODER_EXTERNAL_TOKEN_ENCRYPTION_KEYS environment variable.
98+
// 2. Run the following command:
99+
// ENCRYPT_ME="<value to encrypt>" CODER_EXTERNAL_TOKEN_ENCRYPTION_KEYS="<secret keys here>" go test -v -count=1 ./enterprise/dbcrypt -test.run='^TestHelpMeEncryptSomeValue$'
100+
// 3. Copy the value from the test output and do what you need with it.
101+
func TestHelpMeEncryptSomeValue(t *testing.T) {
102+
t.Parallel()
103+
t.Skip("this only exists if you need to encrypt a value with dbcrypt, it does not actually test anything")
104+
105+
valueToEncrypt := os.Getenv("ENCRYPT_ME")
106+
t.Logf("valueToEncrypt: %q", valueToEncrypt)
107+
keys := os.Getenv("CODER_EXTERNAL_TOKEN_ENCRYPTION_KEYS")
108+
require.NotEmpty(t, keys, "Set the CODER_EXTERNAL_TOKEN_ENCRYPTION_KEYS environment variable to use this")
109+
110+
base64Keys := strings.Split(keys, ",")
111+
activeKey := base64Keys[0]
112+
113+
decodedKey, err := base64.StdEncoding.DecodeString(activeKey)
114+
require.NoError(t, err, "the active key should be valid base64")
115+
116+
cipher, err := cipherAES256(decodedKey)
117+
require.NoError(t, err)
118+
119+
t.Logf("cipher digest: %+v", cipher.HexDigest())
120+
121+
encryptedEmptyString, err := cipher.Encrypt([]byte(valueToEncrypt))
122+
require.NoError(t, err)
123+
124+
t.Logf("encrypted and base64-encoded: %q", base64.StdEncoding.EncodeToString(encryptedEmptyString))
125+
}

enterprise/dbcrypt/dbcrypt.go

+15
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,21 @@ func (db *dbCrypt) UpdateExternalAuthLink(ctx context.Context, params database.U
261261
return link, nil
262262
}
263263

264+
func (db *dbCrypt) UpdateExternalAuthLinkRefreshToken(ctx context.Context, params database.UpdateExternalAuthLinkRefreshTokenParams) error {
265+
// We would normally use a sql.NullString here, but sqlc does not want to make
266+
// a params struct with a nullable string.
267+
var digest sql.NullString
268+
if params.OAuthRefreshTokenKeyID != "" {
269+
digest.String = params.OAuthRefreshTokenKeyID
270+
digest.Valid = true
271+
}
272+
if err := db.encryptField(&params.OAuthRefreshToken, &digest); err != nil {
273+
return err
274+
}
275+
276+
return db.Store.UpdateExternalAuthLinkRefreshToken(ctx, params)
277+
}
278+
264279
func (db *dbCrypt) GetCryptoKeys(ctx context.Context) ([]database.CryptoKey, error) {
265280
keys, err := db.Store.GetCryptoKeys(ctx)
266281
if err != nil {

0 commit comments

Comments
 (0)