Skip to content

Commit 1d789ce

Browse files
johnstcnstirby
authored andcommitted
fix(coderd): ensure that clearing invalid oauth refresh tokens works with dbcrypt (#15721)
#15608 introduced a buggy behaviour with dbcrypt enabled. When clearing an oauth refresh token, we had been setting the value to the empty string. The database encryption package considers decrypting an empty string to be an error, as an empty encrypted string value will still have a nonce associated with it and thus not actually be empty when stored at rest. Instead of 'deleting' the refresh token, 'update' it to be the empty string. This plays nicely with dbcrypt. It also adds a 'utility test' in the dbcrypt package to help encrypt a value. This was useful when manually fixing users affected by this bug on our dogfood instance. (cherry picked from commit e744cde)
1 parent b359fb9 commit 1d789ce

File tree

13 files changed

+184
-94
lines changed

13 files changed

+184
-94
lines changed

coderd/database/dbauthz/dbauthz.go

+7-7
Original file line numberDiff line numberDiff line change
@@ -3330,13 +3330,6 @@ func (q *querier) RegisterWorkspaceProxy(ctx context.Context, arg database.Regis
33303330
return updateWithReturn(q.log, q.auth, fetch, q.db.RegisterWorkspaceProxy)(ctx, arg)
33313331
}
33323332

3333-
func (q *querier) RemoveRefreshToken(ctx context.Context, arg database.RemoveRefreshTokenParams) error {
3334-
fetch := func(ctx context.Context, arg database.RemoveRefreshTokenParams) (database.ExternalAuthLink, error) {
3335-
return q.db.GetExternalAuthLink(ctx, database.GetExternalAuthLinkParams{UserID: arg.UserID, ProviderID: arg.ProviderID})
3336-
}
3337-
return fetchAndExec(q.log, q.auth, policy.ActionUpdatePersonal, fetch, q.db.RemoveRefreshToken)(ctx, arg)
3338-
}
3339-
33403333
func (q *querier) RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) error {
33413334
// This is a system function to clear user groups in group sync.
33423335
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
@@ -3435,6 +3428,13 @@ func (q *querier) UpdateExternalAuthLink(ctx context.Context, arg database.Updat
34353428
return fetchAndQuery(q.log, q.auth, policy.ActionUpdatePersonal, fetch, q.db.UpdateExternalAuthLink)(ctx, arg)
34363429
}
34373430

3431+
func (q *querier) UpdateExternalAuthLinkRefreshToken(ctx context.Context, arg database.UpdateExternalAuthLinkRefreshTokenParams) error {
3432+
fetch := func(ctx context.Context, arg database.UpdateExternalAuthLinkRefreshTokenParams) (database.ExternalAuthLink, error) {
3433+
return q.db.GetExternalAuthLink(ctx, database.GetExternalAuthLinkParams{UserID: arg.UserID, ProviderID: arg.ProviderID})
3434+
}
3435+
return fetchAndExec(q.log, q.auth, policy.ActionUpdatePersonal, fetch, q.db.UpdateExternalAuthLinkRefreshToken)(ctx, arg)
3436+
}
3437+
34383438
func (q *querier) UpdateGitSSHKey(ctx context.Context, arg database.UpdateGitSSHKeyParams) (database.GitSSHKey, error) {
34393439
fetch := func(ctx context.Context, arg database.UpdateGitSSHKeyParams) (database.GitSSHKey, error) {
34403440
return q.db.GetGitSSHKey(ctx, arg.UserID)

coderd/database/dbauthz/dbauthz_test.go

+7-5
Original file line numberDiff line numberDiff line change
@@ -1282,12 +1282,14 @@ func (s *MethodTestSuite) TestUser() {
12821282
UserID: u.ID,
12831283
}).Asserts(u, policy.ActionUpdatePersonal)
12841284
}))
1285-
s.Run("RemoveRefreshToken", s.Subtest(func(db database.Store, check *expects) {
1285+
s.Run("UpdateExternalAuthLinkRefreshToken", s.Subtest(func(db database.Store, check *expects) {
12861286
link := dbgen.ExternalAuthLink(s.T(), db, database.ExternalAuthLink{})
1287-
check.Args(database.RemoveRefreshTokenParams{
1288-
ProviderID: link.ProviderID,
1289-
UserID: link.UserID,
1290-
UpdatedAt: link.UpdatedAt,
1287+
check.Args(database.UpdateExternalAuthLinkRefreshTokenParams{
1288+
OAuthRefreshToken: "",
1289+
OAuthRefreshTokenKeyID: "",
1290+
ProviderID: link.ProviderID,
1291+
UserID: link.UserID,
1292+
UpdatedAt: link.UpdatedAt,
12911293
}).Asserts(rbac.ResourceUserObject(link.UserID), policy.ActionUpdatePersonal)
12921294
}))
12931295
s.Run("UpdateExternalAuthLink", s.Subtest(func(db database.Store, check *expects) {

coderd/database/dbmem/dbmem.go

+23-23
Original file line numberDiff line numberDiff line change
@@ -8556,29 +8556,6 @@ func (q *FakeQuerier) RegisterWorkspaceProxy(_ context.Context, arg database.Reg
85568556
return database.WorkspaceProxy{}, sql.ErrNoRows
85578557
}
85588558

8559-
func (q *FakeQuerier) RemoveRefreshToken(_ context.Context, arg database.RemoveRefreshTokenParams) error {
8560-
if err := validateDatabaseType(arg); err != nil {
8561-
return err
8562-
}
8563-
8564-
q.mutex.Lock()
8565-
defer q.mutex.Unlock()
8566-
for index, gitAuthLink := range q.externalAuthLinks {
8567-
if gitAuthLink.ProviderID != arg.ProviderID {
8568-
continue
8569-
}
8570-
if gitAuthLink.UserID != arg.UserID {
8571-
continue
8572-
}
8573-
gitAuthLink.UpdatedAt = arg.UpdatedAt
8574-
gitAuthLink.OAuthRefreshToken = ""
8575-
q.externalAuthLinks[index] = gitAuthLink
8576-
8577-
return nil
8578-
}
8579-
return sql.ErrNoRows
8580-
}
8581-
85828559
func (q *FakeQuerier) RemoveUserFromAllGroups(_ context.Context, userID uuid.UUID) error {
85838560
q.mutex.Lock()
85848561
defer q.mutex.Unlock()
@@ -8798,6 +8775,29 @@ func (q *FakeQuerier) UpdateExternalAuthLink(_ context.Context, arg database.Upd
87988775
return database.ExternalAuthLink{}, sql.ErrNoRows
87998776
}
88008777

8778+
func (q *FakeQuerier) UpdateExternalAuthLinkRefreshToken(_ context.Context, arg database.UpdateExternalAuthLinkRefreshTokenParams) error {
8779+
if err := validateDatabaseType(arg); err != nil {
8780+
return err
8781+
}
8782+
8783+
q.mutex.Lock()
8784+
defer q.mutex.Unlock()
8785+
for index, gitAuthLink := range q.externalAuthLinks {
8786+
if gitAuthLink.ProviderID != arg.ProviderID {
8787+
continue
8788+
}
8789+
if gitAuthLink.UserID != arg.UserID {
8790+
continue
8791+
}
8792+
gitAuthLink.UpdatedAt = arg.UpdatedAt
8793+
gitAuthLink.OAuthRefreshToken = arg.OAuthRefreshToken
8794+
q.externalAuthLinks[index] = gitAuthLink
8795+
8796+
return nil
8797+
}
8798+
return sql.ErrNoRows
8799+
}
8800+
88018801
func (q *FakeQuerier) UpdateGitSSHKey(_ context.Context, arg database.UpdateGitSSHKeyParams) (database.GitSSHKey, error) {
88028802
if err := validateDatabaseType(arg); err != nil {
88038803
return database.GitSSHKey{}, err

coderd/database/dbmetrics/querymetrics.go

+7-7
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/dbmock/dbmock.go

+14-14
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/querier.go

+1-4
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries.sql.go

+34-23
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries/externalauth.sql

+9-6
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,16 @@ UPDATE external_auth_links SET
4343
oauth_extra = $9
4444
WHERE provider_id = $1 AND user_id = $2 RETURNING *;
4545

46-
-- name: RemoveRefreshToken :exec
47-
-- Removing the refresh token disables the refresh behavior for a given
48-
-- auth token. If a refresh token is marked invalid, it is better to remove it
49-
-- then continually attempt to refresh the token.
46+
-- name: UpdateExternalAuthLinkRefreshToken :exec
5047
UPDATE
5148
external_auth_links
5249
SET
53-
oauth_refresh_token = '',
50+
oauth_refresh_token = @oauth_refresh_token,
5451
updated_at = @updated_at
55-
WHERE provider_id = @provider_id AND user_id = @user_id;
52+
WHERE
53+
provider_id = @provider_id
54+
AND
55+
user_id = @user_id
56+
AND
57+
-- Required for sqlc to generate a parameter for the oauth_refresh_token_key_id
58+
@oauth_refresh_token_key_id :: text = @oauth_refresh_token_key_id :: text;

coderd/externalauth/externalauth.go

+6-4
Original file line numberDiff line numberDiff line change
@@ -143,10 +143,12 @@ func (c *Config) RefreshToken(ctx context.Context, db database.Store, externalAu
143143
// get rid of it. Keeping it around will cause additional refresh
144144
// attempts that will fail and cost us api rate limits.
145145
if isFailedRefresh(existingToken, err) {
146-
dbExecErr := db.RemoveRefreshToken(ctx, database.RemoveRefreshTokenParams{
147-
UpdatedAt: dbtime.Now(),
148-
ProviderID: externalAuthLink.ProviderID,
149-
UserID: externalAuthLink.UserID,
146+
dbExecErr := db.UpdateExternalAuthLinkRefreshToken(ctx, database.UpdateExternalAuthLinkRefreshTokenParams{
147+
OAuthRefreshToken: "", // It is better to clear the refresh token than to keep retrying.
148+
OAuthRefreshTokenKeyID: externalAuthLink.OAuthRefreshTokenKeyID.String,
149+
UpdatedAt: dbtime.Now(),
150+
ProviderID: externalAuthLink.ProviderID,
151+
UserID: externalAuthLink.UserID,
150152
})
151153
if dbExecErr != nil {
152154
// This error should be rare.

coderd/externalauth/externalauth_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ func TestRefreshToken(t *testing.T) {
190190

191191
// Try again with a bad refresh token error
192192
// Expect DB call to remove the refresh token
193-
mDB.EXPECT().RemoveRefreshToken(gomock.Any(), gomock.Any()).Return(nil).Times(1)
193+
mDB.EXPECT().UpdateExternalAuthLinkRefreshToken(gomock.Any(), gomock.Any()).Return(nil).Times(1)
194194
refreshErr = &oauth2.RetrieveError{ // github error
195195
Response: &http.Response{
196196
StatusCode: http.StatusOK,

enterprise/dbcrypt/cipher_internal_test.go

+34
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ package dbcrypt
33
import (
44
"bytes"
55
"encoding/base64"
6+
"os"
7+
"strings"
68
"testing"
79

810
"github.com/stretchr/testify/require"
@@ -89,3 +91,35 @@ func TestCiphersBackwardCompatibility(t *testing.T) {
8991
require.NoError(t, err, "decryption should succeed")
9092
require.Equal(t, msg, string(decrypted), "decrypted message should match original message")
9193
}
94+
95+
// If you're looking here, you're probably in trouble.
96+
// Here's what you need to do:
97+
// 1. Get the current CODER_EXTERNAL_TOKEN_ENCRYPTION_KEYS environment variable.
98+
// 2. Run the following command:
99+
// ENCRYPT_ME="<value to encrypt>" CODER_EXTERNAL_TOKEN_ENCRYPTION_KEYS="<secret keys here>" go test -v -count=1 ./enterprise/dbcrypt -test.run='^TestHelpMeEncryptSomeValue$'
100+
// 3. Copy the value from the test output and do what you need with it.
101+
func TestHelpMeEncryptSomeValue(t *testing.T) {
102+
t.Parallel()
103+
t.Skip("this only exists if you need to encrypt a value with dbcrypt, it does not actually test anything")
104+
105+
valueToEncrypt := os.Getenv("ENCRYPT_ME")
106+
t.Logf("valueToEncrypt: %q", valueToEncrypt)
107+
keys := os.Getenv("CODER_EXTERNAL_TOKEN_ENCRYPTION_KEYS")
108+
require.NotEmpty(t, keys, "Set the CODER_EXTERNAL_TOKEN_ENCRYPTION_KEYS environment variable to use this")
109+
110+
base64Keys := strings.Split(keys, ",")
111+
activeKey := base64Keys[0]
112+
113+
decodedKey, err := base64.StdEncoding.DecodeString(activeKey)
114+
require.NoError(t, err, "the active key should be valid base64")
115+
116+
cipher, err := cipherAES256(decodedKey)
117+
require.NoError(t, err)
118+
119+
t.Logf("cipher digest: %+v", cipher.HexDigest())
120+
121+
encryptedEmptyString, err := cipher.Encrypt([]byte(valueToEncrypt))
122+
require.NoError(t, err)
123+
124+
t.Logf("encrypted and base64-encoded: %q", base64.StdEncoding.EncodeToString(encryptedEmptyString))
125+
}

enterprise/dbcrypt/dbcrypt.go

+15
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,21 @@ func (db *dbCrypt) UpdateExternalAuthLink(ctx context.Context, params database.U
261261
return link, nil
262262
}
263263

264+
func (db *dbCrypt) UpdateExternalAuthLinkRefreshToken(ctx context.Context, params database.UpdateExternalAuthLinkRefreshTokenParams) error {
265+
// We would normally use a sql.NullString here, but sqlc does not want to make
266+
// a params struct with a nullable string.
267+
var digest sql.NullString
268+
if params.OAuthRefreshTokenKeyID != "" {
269+
digest.String = params.OAuthRefreshTokenKeyID
270+
digest.Valid = true
271+
}
272+
if err := db.encryptField(&params.OAuthRefreshToken, &digest); err != nil {
273+
return err
274+
}
275+
276+
return db.Store.UpdateExternalAuthLinkRefreshToken(ctx, params)
277+
}
278+
264279
func (db *dbCrypt) GetCryptoKeys(ctx context.Context) ([]database.CryptoKey, error) {
265280
keys, err := db.Store.GetCryptoKeys(ctx)
266281
if err != nil {

0 commit comments

Comments
 (0)