Skip to content

Commit aa39fcc

Browse files
committed
refactor: move rotate logic into dbcrypt
1 parent d51ec66 commit aa39fcc

File tree

2 files changed

+90
-64
lines changed

2 files changed

+90
-64
lines changed

enterprise/cli/server_dbcrypt_rotate.go

Lines changed: 2 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,12 @@ package cli
55
import (
66
"bytes"
77
"context"
8-
"database/sql"
98
"encoding/base64"
109

1110
"cdr.dev/slog"
1211
"cdr.dev/slog/sloggers/sloghuman"
1312
"github.com/coder/coder/v2/cli"
1413
"github.com/coder/coder/v2/cli/clibase"
15-
"github.com/coder/coder/v2/coderd/database"
1614
"github.com/coder/coder/v2/codersdk"
1715
"github.com/coder/coder/v2/enterprise/dbcrypt"
1816

@@ -95,70 +93,10 @@ func (*RootCmd) dbcryptRotateCmd() *clibase.Cmd {
9593
_ = sqlDB.Close()
9694
}()
9795
logger.Info(ctx, "connected to postgres")
98-
99-
db := database.New(sqlDB)
100-
101-
cryptDB, err := dbcrypt.New(ctx, db, ciphers...)
102-
if err != nil {
103-
return xerrors.Errorf("create cryptdb: %w", err)
104-
}
105-
106-
users, err := cryptDB.GetUsers(ctx, database.GetUsersParams{})
107-
if err != nil {
108-
return xerrors.Errorf("get users: %w", err)
109-
}
110-
logger.Info(ctx, "encrypting user tokens", slog.F("user_count", len(users)))
111-
for idx, usr := range users {
112-
err := cryptDB.InTx(func(tx database.Store) error {
113-
userLinks, err := tx.GetUserLinksByUserID(ctx, usr.ID)
114-
if err != nil {
115-
return xerrors.Errorf("get user links for user: %w", err)
116-
}
117-
for _, userLink := range userLinks {
118-
if _, err := tx.UpdateUserLink(ctx, database.UpdateUserLinkParams{
119-
OAuthAccessToken: userLink.OAuthAccessToken,
120-
OAuthRefreshToken: userLink.OAuthRefreshToken,
121-
OAuthExpiry: userLink.OAuthExpiry,
122-
UserID: usr.ID,
123-
LoginType: usr.LoginType,
124-
}); err != nil {
125-
return xerrors.Errorf("update user link user_id=%s linked_id=%s: %w", userLink.UserID, userLink.LinkedID, err)
126-
}
127-
}
128-
gitAuthLinks, err := tx.GetGitAuthLinksByUserID(ctx, usr.ID)
129-
if err != nil {
130-
return xerrors.Errorf("get git auth links for user: %w", err)
131-
}
132-
for _, gitAuthLink := range gitAuthLinks {
133-
if _, err := tx.UpdateGitAuthLink(ctx, database.UpdateGitAuthLinkParams{
134-
ProviderID: gitAuthLink.ProviderID,
135-
UserID: usr.ID,
136-
UpdatedAt: gitAuthLink.UpdatedAt,
137-
OAuthAccessToken: gitAuthLink.OAuthAccessToken,
138-
OAuthRefreshToken: gitAuthLink.OAuthRefreshToken,
139-
OAuthExpiry: gitAuthLink.OAuthExpiry,
140-
}); err != nil {
141-
return xerrors.Errorf("update git auth link user_id=%s provider_id=%s: %w", gitAuthLink.UserID, gitAuthLink.ProviderID, err)
142-
}
143-
}
144-
return nil
145-
}, &sql.TxOptions{
146-
Isolation: sql.LevelRepeatableRead,
147-
})
148-
if err != nil {
149-
return xerrors.Errorf("update user links: %w", err)
150-
}
151-
logger.Debug(ctx, "encrypted user tokens", slog.F("user_id", usr.ID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
96+
if err := dbcrypt.Rotate(ctx, logger, sqlDB, ciphers); err != nil {
97+
return xerrors.Errorf("rotate ciphers: %w", err)
15298
}
15399
logger.Info(ctx, "operation completed successfully")
154-
155-
// Revoke old keys
156-
for _, c := range ciphers[1:] {
157-
if err := db.RevokeDBCryptKey(ctx, c.HexDigest()); err != nil {
158-
return xerrors.Errorf("revoke key: %w", err)
159-
}
160-
logger.Info(ctx, "revoked unused key", slog.F("digest", c.HexDigest()))
161-
}
162100
return nil
163101
},
164102
}

enterprise/dbcrypt/rotate.go

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
package dbcrypt
2+
3+
import (
4+
"context"
5+
"database/sql"
6+
7+
"golang.org/x/xerrors"
8+
9+
"cdr.dev/slog"
10+
"github.com/coder/coder/v2/coderd/database"
11+
)
12+
13+
// Rotate rotates the database encryption keys by re-encrypting all user tokens
14+
// with the first cipher and revoking all other ciphers.
15+
func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Cipher) error {
16+
db := database.New(sqlDB)
17+
cryptDB, err := New(ctx, db, ciphers...)
18+
if err != nil {
19+
return xerrors.Errorf("create cryptdb: %w", err)
20+
}
21+
22+
users, err := cryptDB.GetUsers(ctx, database.GetUsersParams{})
23+
if err != nil {
24+
return xerrors.Errorf("get users: %w", err)
25+
}
26+
log.Info(ctx, "encrypting user tokens", slog.F("user_count", len(users)))
27+
for idx, usr := range users {
28+
err := cryptDB.InTx(func(tx database.Store) error {
29+
userLinks, err := tx.GetUserLinksByUserID(ctx, usr.ID)
30+
if err != nil {
31+
return xerrors.Errorf("get user links for user: %w", err)
32+
}
33+
for _, userLink := range userLinks {
34+
if userLink.OAuthAccessTokenKeyID.String == ciphers[0].HexDigest() && userLink.OAuthRefreshTokenKeyID.String == ciphers[0].HexDigest() {
35+
log.Debug(ctx, "skipping user link", slog.F("user_id", usr.ID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
36+
continue
37+
}
38+
if _, err := tx.UpdateUserLink(ctx, database.UpdateUserLinkParams{
39+
OAuthAccessToken: userLink.OAuthAccessToken,
40+
OAuthRefreshToken: userLink.OAuthRefreshToken,
41+
OAuthExpiry: userLink.OAuthExpiry,
42+
UserID: usr.ID,
43+
LoginType: usr.LoginType,
44+
}); err != nil {
45+
return xerrors.Errorf("update user link user_id=%s linked_id=%s: %w", userLink.UserID, userLink.LinkedID, err)
46+
}
47+
}
48+
49+
gitAuthLinks, err := tx.GetGitAuthLinksByUserID(ctx, usr.ID)
50+
if err != nil {
51+
return xerrors.Errorf("get git auth links for user: %w", err)
52+
}
53+
for _, gitAuthLink := range gitAuthLinks {
54+
if gitAuthLink.OAuthAccessTokenKeyID.String == ciphers[0].HexDigest() && gitAuthLink.OAuthRefreshTokenKeyID.String == ciphers[0].HexDigest() {
55+
log.Debug(ctx, "skipping git auth link", slog.F("user_id", usr.ID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
56+
continue
57+
}
58+
if _, err := tx.UpdateGitAuthLink(ctx, database.UpdateGitAuthLinkParams{
59+
ProviderID: gitAuthLink.ProviderID,
60+
UserID: usr.ID,
61+
UpdatedAt: gitAuthLink.UpdatedAt,
62+
OAuthAccessToken: gitAuthLink.OAuthAccessToken,
63+
OAuthRefreshToken: gitAuthLink.OAuthRefreshToken,
64+
OAuthExpiry: gitAuthLink.OAuthExpiry,
65+
}); err != nil {
66+
return xerrors.Errorf("update git auth link user_id=%s provider_id=%s: %w", gitAuthLink.UserID, gitAuthLink.ProviderID, err)
67+
}
68+
}
69+
return nil
70+
}, &sql.TxOptions{
71+
Isolation: sql.LevelRepeatableRead,
72+
})
73+
if err != nil {
74+
return xerrors.Errorf("update user links: %w", err)
75+
}
76+
log.Debug(ctx, "encrypted user tokens", slog.F("user_id", usr.ID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
77+
}
78+
79+
// Revoke old keys
80+
for _, c := range ciphers[1:] {
81+
if err := db.RevokeDBCryptKey(ctx, c.HexDigest()); err != nil {
82+
return xerrors.Errorf("revoke key: %w", err)
83+
}
84+
log.Info(ctx, "revoked unused key", slog.F("digest", c.HexDigest()))
85+
}
86+
87+
return nil
88+
}

0 commit comments

Comments
 (0)