Skip to content

Commit 77954e1

Browse files
committed
fix(coderd): ensure that clearing invalid oauth refresh tokens works with dbcrypt
1 parent 7e1ac2e commit 77954e1

File tree

13 files changed

+177
-94
lines changed

13 files changed

+177
-94
lines changed

coderd/database/dbauthz/dbauthz.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3330,13 +3330,6 @@ func (q *querier) RegisterWorkspaceProxy(ctx context.Context, arg database.Regis
33303330
return updateWithReturn(q.log, q.auth, fetch, q.db.RegisterWorkspaceProxy)(ctx, arg)
33313331
}
33323332

3333-
func (q *querier) RemoveRefreshToken(ctx context.Context, arg database.RemoveRefreshTokenParams) error {
3334-
fetch := func(ctx context.Context, arg database.RemoveRefreshTokenParams) (database.ExternalAuthLink, error) {
3335-
return q.db.GetExternalAuthLink(ctx, database.GetExternalAuthLinkParams{UserID: arg.UserID, ProviderID: arg.ProviderID})
3336-
}
3337-
return fetchAndExec(q.log, q.auth, policy.ActionUpdatePersonal, fetch, q.db.RemoveRefreshToken)(ctx, arg)
3338-
}
3339-
33403333
func (q *querier) RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) error {
33413334
// This is a system function to clear user groups in group sync.
33423335
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
@@ -3435,6 +3428,13 @@ func (q *querier) UpdateExternalAuthLink(ctx context.Context, arg database.Updat
34353428
return fetchAndQuery(q.log, q.auth, policy.ActionUpdatePersonal, fetch, q.db.UpdateExternalAuthLink)(ctx, arg)
34363429
}
34373430

3431+
func (q *querier) UpdateExternalAuthLinkRefreshToken(ctx context.Context, arg database.UpdateExternalAuthLinkRefreshTokenParams) error {
3432+
fetch := func(ctx context.Context, arg database.UpdateExternalAuthLinkRefreshTokenParams) (database.ExternalAuthLink, error) {
3433+
return q.db.GetExternalAuthLink(ctx, database.GetExternalAuthLinkParams{UserID: arg.UserID, ProviderID: arg.ProviderID})
3434+
}
3435+
return fetchAndExec(q.log, q.auth, policy.ActionUpdatePersonal, fetch, q.db.UpdateExternalAuthLinkRefreshToken)(ctx, arg)
3436+
}
3437+
34383438
func (q *querier) UpdateGitSSHKey(ctx context.Context, arg database.UpdateGitSSHKeyParams) (database.GitSSHKey, error) {
34393439
fetch := func(ctx context.Context, arg database.UpdateGitSSHKeyParams) (database.GitSSHKey, error) {
34403440
return q.db.GetGitSSHKey(ctx, arg.UserID)

coderd/database/dbauthz/dbauthz_test.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1282,12 +1282,14 @@ func (s *MethodTestSuite) TestUser() {
12821282
UserID: u.ID,
12831283
}).Asserts(u, policy.ActionUpdatePersonal)
12841284
}))
1285-
s.Run("RemoveRefreshToken", s.Subtest(func(db database.Store, check *expects) {
1285+
s.Run("UpdateExternalAuthLinkRefreshToken", s.Subtest(func(db database.Store, check *expects) {
12861286
link := dbgen.ExternalAuthLink(s.T(), db, database.ExternalAuthLink{})
1287-
check.Args(database.RemoveRefreshTokenParams{
1288-
ProviderID: link.ProviderID,
1289-
UserID: link.UserID,
1290-
UpdatedAt: link.UpdatedAt,
1287+
check.Args(database.UpdateExternalAuthLinkRefreshTokenParams{
1288+
OAuthRefreshToken: "",
1289+
OAuthRefreshTokenKeyID: "",
1290+
ProviderID: link.ProviderID,
1291+
UserID: link.UserID,
1292+
UpdatedAt: link.UpdatedAt,
12911293
}).Asserts(rbac.ResourceUserObject(link.UserID), policy.ActionUpdatePersonal)
12921294
}))
12931295
s.Run("UpdateExternalAuthLink", s.Subtest(func(db database.Store, check *expects) {

coderd/database/dbmem/dbmem.go

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8556,29 +8556,6 @@ func (q *FakeQuerier) RegisterWorkspaceProxy(_ context.Context, arg database.Reg
85568556
return database.WorkspaceProxy{}, sql.ErrNoRows
85578557
}
85588558

8559-
func (q *FakeQuerier) RemoveRefreshToken(_ context.Context, arg database.RemoveRefreshTokenParams) error {
8560-
if err := validateDatabaseType(arg); err != nil {
8561-
return err
8562-
}
8563-
8564-
q.mutex.Lock()
8565-
defer q.mutex.Unlock()
8566-
for index, gitAuthLink := range q.externalAuthLinks {
8567-
if gitAuthLink.ProviderID != arg.ProviderID {
8568-
continue
8569-
}
8570-
if gitAuthLink.UserID != arg.UserID {
8571-
continue
8572-
}
8573-
gitAuthLink.UpdatedAt = arg.UpdatedAt
8574-
gitAuthLink.OAuthRefreshToken = ""
8575-
q.externalAuthLinks[index] = gitAuthLink
8576-
8577-
return nil
8578-
}
8579-
return sql.ErrNoRows
8580-
}
8581-
85828559
func (q *FakeQuerier) RemoveUserFromAllGroups(_ context.Context, userID uuid.UUID) error {
85838560
q.mutex.Lock()
85848561
defer q.mutex.Unlock()
@@ -8798,6 +8775,29 @@ func (q *FakeQuerier) UpdateExternalAuthLink(_ context.Context, arg database.Upd
87988775
return database.ExternalAuthLink{}, sql.ErrNoRows
87998776
}
88008777

8778+
func (q *FakeQuerier) UpdateExternalAuthLinkRefreshToken(ctx context.Context, arg database.UpdateExternalAuthLinkRefreshTokenParams) error {
8779+
if err := validateDatabaseType(arg); err != nil {
8780+
return err
8781+
}
8782+
8783+
q.mutex.Lock()
8784+
defer q.mutex.Unlock()
8785+
for index, gitAuthLink := range q.externalAuthLinks {
8786+
if gitAuthLink.ProviderID != arg.ProviderID {
8787+
continue
8788+
}
8789+
if gitAuthLink.UserID != arg.UserID {
8790+
continue
8791+
}
8792+
gitAuthLink.UpdatedAt = arg.UpdatedAt
8793+
gitAuthLink.OAuthRefreshToken = ""
8794+
q.externalAuthLinks[index] = gitAuthLink
8795+
8796+
return nil
8797+
}
8798+
return sql.ErrNoRows
8799+
}
8800+
88018801
func (q *FakeQuerier) UpdateGitSSHKey(_ context.Context, arg database.UpdateGitSSHKeyParams) (database.GitSSHKey, error) {
88028802
if err := validateDatabaseType(arg); err != nil {
88038803
return database.GitSSHKey{}, err

coderd/database/dbmetrics/querymetrics.go

Lines changed: 7 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/dbmock/dbmock.go

Lines changed: 14 additions & 14 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/querier.go

Lines changed: 1 addition & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries.sql.go

Lines changed: 34 additions & 23 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries/externalauth.sql

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,16 @@ UPDATE external_auth_links SET
4343
oauth_extra = $9
4444
WHERE provider_id = $1 AND user_id = $2 RETURNING *;
4545

46-
-- name: RemoveRefreshToken :exec
47-
-- Removing the refresh token disables the refresh behavior for a given
48-
-- auth token. If a refresh token is marked invalid, it is better to remove it
49-
-- then continually attempt to refresh the token.
46+
-- name: UpdateExternalAuthLinkRefreshToken :exec
5047
UPDATE
5148
external_auth_links
5249
SET
53-
oauth_refresh_token = '',
50+
oauth_refresh_token = @oauth_refresh_token,
5451
updated_at = @updated_at
55-
WHERE provider_id = @provider_id AND user_id = @user_id;
52+
WHERE
53+
provider_id = @provider_id
54+
AND
55+
user_id = @user_id
56+
AND
57+
-- Required for sqlc to generate a parameter for the oauth_refresh_token_key_id
58+
@oauth_refresh_token_key_id :: text = @oauth_refresh_token_key_id :: text;

coderd/externalauth/externalauth.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -143,10 +143,12 @@ func (c *Config) RefreshToken(ctx context.Context, db database.Store, externalAu
143143
// get rid of it. Keeping it around will cause additional refresh
144144
// attempts that will fail and cost us api rate limits.
145145
if isFailedRefresh(existingToken, err) {
146-
dbExecErr := db.RemoveRefreshToken(ctx, database.RemoveRefreshTokenParams{
147-
UpdatedAt: dbtime.Now(),
148-
ProviderID: externalAuthLink.ProviderID,
149-
UserID: externalAuthLink.UserID,
146+
dbExecErr := db.UpdateExternalAuthLinkRefreshToken(ctx, database.UpdateExternalAuthLinkRefreshTokenParams{
147+
OAuthRefreshToken: "", // It is better to clear the refresh token than to keep retrying.
148+
OAuthRefreshTokenKeyID: externalAuthLink.OAuthRefreshTokenKeyID.String,
149+
UpdatedAt: dbtime.Now(),
150+
ProviderID: externalAuthLink.ProviderID,
151+
UserID: externalAuthLink.UserID,
150152
})
151153
if dbExecErr != nil {
152154
// This error should be rare.

coderd/externalauth/externalauth_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ func TestRefreshToken(t *testing.T) {
190190

191191
// Try again with a bad refresh token error
192192
// Expect DB call to remove the refresh token
193-
mDB.EXPECT().RemoveRefreshToken(gomock.Any(), gomock.Any()).Return(nil).Times(1)
193+
mDB.EXPECT().UpdateExternalAuthLinkRefreshToken(gomock.Any(), gomock.Any()).Return(nil).Times(1)
194194
refreshErr = &oauth2.RetrieveError{ // github error
195195
Response: &http.Response{
196196
StatusCode: http.StatusOK,

enterprise/dbcrypt/cipher_internal_test.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ package dbcrypt
33
import (
44
"bytes"
55
"encoding/base64"
6+
"os"
7+
"strings"
68
"testing"
79

810
"github.com/stretchr/testify/require"
@@ -89,3 +91,35 @@ func TestCiphersBackwardCompatibility(t *testing.T) {
8991
require.NoError(t, err, "decryption should succeed")
9092
require.Equal(t, msg, string(decrypted), "decrypted message should match original message")
9193
}
94+
95+
// If you're looking here, you're probably in trouble.
96+
// Here's what you need to do:
97+
// 1. Get the current CODER_EXTERNAL_TOKEN_ENCRYPTION_KEYS environment variable.
98+
// 2. Run the following command:
99+
// ENCRYPT_ME="<value to encrypt>" CODER_EXTERNAL_TOKEN_ENCRYPTION_KEYS="<secret keys here>" go test -v -count=1 ./enterprise/dbcrypt -test.run='^TestHelpMeEncryptSomeValue$'
100+
// 3. Copy the value from the test output and do what you need with it.
101+
func TestHelpMeEncryptSomeValue(t *testing.T) {
102+
t.Parallel()
103+
t.Skip("this only exists if you need to encrypt a value with dbcrypt, it does not actually test anything")
104+
105+
valueToEncrypt := os.Getenv("ENCRYPT_ME")
106+
t.Logf("valueToEncrypt: %q", valueToEncrypt)
107+
keys := os.Getenv("CODER_EXTERNAL_TOKEN_ENCRYPTION_KEYS")
108+
require.NotEmpty(t, keys, "Set the CODER_EXTERNAL_TOKEN_ENCRYPTION_KEYS environment variable to use this")
109+
110+
base64Keys := strings.Split(keys, ",")
111+
activeKey := base64Keys[0]
112+
113+
decodedKey, err := base64.StdEncoding.DecodeString(activeKey)
114+
require.NoError(t, err, "the active key should be valid base64")
115+
116+
cipher, err := cipherAES256(decodedKey)
117+
require.NoError(t, err)
118+
119+
t.Logf("cipher digest: %+v", cipher.HexDigest())
120+
121+
encryptedEmptyString, err := cipher.Encrypt([]byte(valueToEncrypt))
122+
require.NoError(t, err)
123+
124+
t.Logf("encrypted and base64-encoded: %q", base64.StdEncoding.EncodeToString(encryptedEmptyString))
125+
}

enterprise/dbcrypt/dbcrypt.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,14 @@ func (db *dbCrypt) UpdateExternalAuthLink(ctx context.Context, params database.U
261261
return link, nil
262262
}
263263

264+
func (db *dbCrypt) UpdateExternalAuthLinkRefreshToken(ctx context.Context, params database.UpdateExternalAuthLinkRefreshTokenParams) error {
265+
if err := db.encryptField(&params.OAuthRefreshToken, &sql.NullString{String: params.OAuthRefreshTokenKeyID, Valid: true}); err != nil {
266+
return err
267+
}
268+
269+
return db.Store.UpdateExternalAuthLinkRefreshToken(ctx, params)
270+
}
271+
264272
func (db *dbCrypt) GetCryptoKeys(ctx context.Context) ([]database.CryptoKey, error) {
265273
keys, err := db.Store.GetCryptoKeys(ctx)
266274
if err != nil {

0 commit comments

Comments
 (0)