Skip to content

Commit 2d5c068

Browse files
authored
feat: implement key rotation system (#14710)
1 parent dbe6b6c commit 2d5c068

File tree

5 files changed

+1029
-1
lines changed

5 files changed

+1029
-1
lines changed

coderd/database/dbgen/dbgen.go

+5-1
Original file line numberDiff line numberDiff line change
@@ -902,7 +902,11 @@ func CryptoKey(t testing.TB, db database.Store, seed database.CryptoKey) databas
902902

903903
seed.Feature = takeFirst(seed.Feature, database.CryptoKeyFeatureWorkspaceApps)
904904

905-
if !seed.Secret.Valid {
905+
// An empty string for the secret is interpreted as
906+
// a caller wanting a new secret to be generated.
907+
// To generate a key with a NULL secret set Valid=false
908+
// and String to a non-empty string.
909+
if seed.Secret.String == "" {
906910
secret, err := newCryptoKeySecret(seed.Feature)
907911
require.NoError(t, err, "generate secret")
908912
seed.Secret = sql.NullString{

coderd/database/lock.go

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ const (
1111
LockIDDBRollup
1212
LockIDDBPurge
1313
LockIDNotificationsReportGenerator
14+
LockIDCryptoKeyRotation
1415
)
1516

1617
// GenLockID generates a unique and consistent lock ID from a given string.

coderd/keyrotate/rotate.go

+298
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,298 @@
1+
package keyrotate
2+
3+
import (
4+
"context"
5+
"crypto/rand"
6+
"database/sql"
7+
"encoding/hex"
8+
"time"
9+
10+
"golang.org/x/xerrors"
11+
12+
"cdr.dev/slog"
13+
"github.com/coder/coder/v2/coderd/database"
14+
"github.com/coder/coder/v2/coderd/database/dbtime"
15+
"github.com/coder/quartz"
16+
)
17+
18+
const (
19+
WorkspaceAppsTokenDuration = time.Minute
20+
OIDCConvertTokenDuration = time.Minute * 5
21+
TailnetResumeTokenDuration = time.Hour * 24
22+
23+
// defaultRotationInterval is the default interval at which keys are checked for rotation.
24+
defaultRotationInterval = time.Minute * 10
25+
// DefaultKeyDuration is the default duration for which a key is valid. It applies to all features.
26+
DefaultKeyDuration = time.Hour * 24 * 30
27+
)
28+
29+
// rotator is responsible for rotating keys in the database.
30+
type rotator struct {
31+
db database.Store
32+
logger slog.Logger
33+
clock quartz.Clock
34+
keyDuration time.Duration
35+
36+
features []database.CryptoKeyFeature
37+
}
38+
39+
type Option func(*rotator)
40+
41+
func WithClock(clock quartz.Clock) Option {
42+
return func(r *rotator) {
43+
r.clock = clock
44+
}
45+
}
46+
47+
func WithKeyDuration(keyDuration time.Duration) Option {
48+
return func(r *rotator) {
49+
r.keyDuration = keyDuration
50+
}
51+
}
52+
53+
// StartRotator starts a background process that rotates keys in the database.
54+
// It ensures there's at least one valid key per feature prior to returning.
55+
// Canceling the provided context will stop the background process.
56+
func StartRotator(ctx context.Context, logger slog.Logger, db database.Store, opts ...Option) error {
57+
kr := &rotator{
58+
db: db,
59+
logger: logger,
60+
clock: quartz.NewReal(),
61+
keyDuration: DefaultKeyDuration,
62+
features: database.AllCryptoKeyFeatureValues(),
63+
}
64+
65+
for _, opt := range opts {
66+
opt(kr)
67+
}
68+
69+
err := kr.rotateKeys(ctx)
70+
if err != nil {
71+
return xerrors.Errorf("rotate keys: %w", err)
72+
}
73+
74+
go kr.start(ctx)
75+
76+
return nil
77+
}
78+
79+
// start begins the process of rotating keys.
80+
// Canceling the context will stop the rotation process.
81+
func (k *rotator) start(ctx context.Context) {
82+
k.clock.TickerFunc(ctx, defaultRotationInterval, func() error {
83+
err := k.rotateKeys(ctx)
84+
if err != nil {
85+
k.logger.Error(ctx, "failed to rotate keys", slog.Error(err))
86+
}
87+
return nil
88+
})
89+
k.logger.Debug(ctx, "ctx canceled, stopping key rotation")
90+
}
91+
92+
// rotateKeys checks for any keys needing rotation or deletion and
93+
// may insert a new key if it detects that a valid one does
94+
// not exist for a feature.
95+
func (k *rotator) rotateKeys(ctx context.Context) error {
96+
return k.db.InTx(
97+
func(tx database.Store) error {
98+
err := tx.AcquireLock(ctx, database.LockIDCryptoKeyRotation)
99+
if err != nil {
100+
return xerrors.Errorf("acquire lock: %w", err)
101+
}
102+
103+
cryptokeys, err := tx.GetCryptoKeys(ctx)
104+
if err != nil {
105+
return xerrors.Errorf("get keys: %w", err)
106+
}
107+
108+
featureKeys, err := keysByFeature(cryptokeys, k.features)
109+
if err != nil {
110+
return xerrors.Errorf("keys by feature: %w", err)
111+
}
112+
113+
now := dbtime.Time(k.clock.Now().UTC())
114+
for feature, keys := range featureKeys {
115+
// We'll use a counter to determine if we should insert a new key. We should always have at least one key for a feature.
116+
var validKeys int
117+
for _, key := range keys {
118+
switch {
119+
case shouldDeleteKey(key, now):
120+
_, err := tx.DeleteCryptoKey(ctx, database.DeleteCryptoKeyParams{
121+
Feature: key.Feature,
122+
Sequence: key.Sequence,
123+
})
124+
if err != nil {
125+
return xerrors.Errorf("delete key: %w", err)
126+
}
127+
k.logger.Debug(ctx, "deleted key",
128+
slog.F("key", key.Sequence),
129+
slog.F("feature", key.Feature),
130+
)
131+
case shouldRotateKey(key, k.keyDuration, now):
132+
_, err := k.rotateKey(ctx, tx, key, now)
133+
if err != nil {
134+
return xerrors.Errorf("rotate key: %w", err)
135+
}
136+
k.logger.Debug(ctx, "rotated key",
137+
slog.F("key", key.Sequence),
138+
slog.F("feature", key.Feature),
139+
)
140+
validKeys++
141+
default:
142+
// We only consider keys without a populated deletes_at field as valid.
143+
// This is because under normal circumstances the deletes_at field
144+
// is set during rotation (meaning a new key was generated)
145+
// but it's possible if the database was manually altered to
146+
// delete the new key we may be in a situation where there
147+
// isn't a key to replace the one scheduled for deletion.
148+
if !key.DeletesAt.Valid {
149+
validKeys++
150+
}
151+
}
152+
}
153+
if validKeys == 0 {
154+
k.logger.Info(ctx, "no valid keys detected, inserting new key",
155+
slog.F("feature", feature),
156+
)
157+
_, err := k.insertNewKey(ctx, tx, feature, now)
158+
if err != nil {
159+
return xerrors.Errorf("insert new key: %w", err)
160+
}
161+
}
162+
}
163+
return nil
164+
}, &sql.TxOptions{
165+
Isolation: sql.LevelRepeatableRead,
166+
})
167+
}
168+
169+
func (k *rotator) insertNewKey(ctx context.Context, tx database.Store, feature database.CryptoKeyFeature, startsAt time.Time) (database.CryptoKey, error) {
170+
secret, err := generateNewSecret(feature)
171+
if err != nil {
172+
return database.CryptoKey{}, xerrors.Errorf("generate new secret: %w", err)
173+
}
174+
175+
latestKey, err := tx.GetLatestCryptoKeyByFeature(ctx, feature)
176+
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
177+
return database.CryptoKey{}, xerrors.Errorf("get latest key: %w", err)
178+
}
179+
180+
newKey, err := tx.InsertCryptoKey(ctx, database.InsertCryptoKeyParams{
181+
Feature: feature,
182+
Sequence: latestKey.Sequence + 1,
183+
Secret: sql.NullString{
184+
String: secret,
185+
Valid: true,
186+
},
187+
// Set by dbcrypt if it's required.
188+
SecretKeyID: sql.NullString{},
189+
StartsAt: startsAt.UTC(),
190+
})
191+
if err != nil {
192+
return database.CryptoKey{}, xerrors.Errorf("inserting new key: %w", err)
193+
}
194+
195+
k.logger.Info(ctx, "inserted new key for feature", slog.F("feature", feature))
196+
return newKey, nil
197+
}
198+
199+
func (k *rotator) rotateKey(ctx context.Context, tx database.Store, key database.CryptoKey, now time.Time) ([]database.CryptoKey, error) {
200+
startsAt := minStartsAt(key, now, k.keyDuration)
201+
newKey, err := k.insertNewKey(ctx, tx, key.Feature, startsAt)
202+
if err != nil {
203+
return nil, xerrors.Errorf("insert new key: %w", err)
204+
}
205+
206+
// Set old key's deletes_at to an hour + however long the token
207+
// for this feature is expected to be valid for. This should
208+
// allow for sufficient time for the new key to propagate to
209+
// dependent services (i.e. Workspace Proxies).
210+
deletesAt := startsAt.Add(time.Hour).Add(tokenDuration(key.Feature))
211+
212+
updatedKey, err := tx.UpdateCryptoKeyDeletesAt(ctx, database.UpdateCryptoKeyDeletesAtParams{
213+
Feature: key.Feature,
214+
Sequence: key.Sequence,
215+
DeletesAt: sql.NullTime{
216+
Time: deletesAt.UTC(),
217+
Valid: true,
218+
},
219+
})
220+
if err != nil {
221+
return nil, xerrors.Errorf("update old key's deletes_at: %w", err)
222+
}
223+
224+
return []database.CryptoKey{updatedKey, newKey}, nil
225+
}
226+
227+
func generateNewSecret(feature database.CryptoKeyFeature) (string, error) {
228+
switch feature {
229+
case database.CryptoKeyFeatureWorkspaceApps:
230+
return generateKey(96)
231+
case database.CryptoKeyFeatureOidcConvert:
232+
return generateKey(32)
233+
case database.CryptoKeyFeatureTailnetResume:
234+
return generateKey(64)
235+
}
236+
return "", xerrors.Errorf("unknown feature: %s", feature)
237+
}
238+
239+
func generateKey(length int) (string, error) {
240+
b := make([]byte, length)
241+
_, err := rand.Read(b)
242+
if err != nil {
243+
return "", xerrors.Errorf("rand read: %w", err)
244+
}
245+
return hex.EncodeToString(b), nil
246+
}
247+
248+
func tokenDuration(feature database.CryptoKeyFeature) time.Duration {
249+
switch feature {
250+
case database.CryptoKeyFeatureWorkspaceApps:
251+
return WorkspaceAppsTokenDuration
252+
case database.CryptoKeyFeatureOidcConvert:
253+
return OIDCConvertTokenDuration
254+
case database.CryptoKeyFeatureTailnetResume:
255+
return TailnetResumeTokenDuration
256+
default:
257+
return 0
258+
}
259+
}
260+
261+
func shouldDeleteKey(key database.CryptoKey, now time.Time) bool {
262+
return key.DeletesAt.Valid && !now.Before(key.DeletesAt.Time.UTC())
263+
}
264+
265+
func shouldRotateKey(key database.CryptoKey, keyDuration time.Duration, now time.Time) bool {
266+
// If deletes_at is set, we've already inserted a key.
267+
if key.DeletesAt.Valid {
268+
return false
269+
}
270+
expirationTime := key.ExpiresAt(keyDuration)
271+
return !now.Add(time.Hour).UTC().Before(expirationTime)
272+
}
273+
274+
func keysByFeature(keys []database.CryptoKey, features []database.CryptoKeyFeature) (map[database.CryptoKeyFeature][]database.CryptoKey, error) {
275+
m := map[database.CryptoKeyFeature][]database.CryptoKey{}
276+
for _, feature := range features {
277+
m[feature] = []database.CryptoKey{}
278+
}
279+
for _, key := range keys {
280+
if _, ok := m[key.Feature]; !ok {
281+
return nil, xerrors.Errorf("unknown feature: %s", key.Feature)
282+
}
283+
284+
m[key.Feature] = append(m[key.Feature], key)
285+
}
286+
return m, nil
287+
}
288+
289+
// minStartsAt ensures the minimum starts_at time we use for a new
290+
// key is no less than 3*the default rotation interval.
291+
func minStartsAt(key database.CryptoKey, now time.Time, keyDuration time.Duration) time.Time {
292+
expiresAt := key.ExpiresAt(keyDuration)
293+
minStartsAt := now.Add(3 * defaultRotationInterval)
294+
if expiresAt.Before(minStartsAt) {
295+
return minStartsAt
296+
}
297+
return expiresAt
298+
}

0 commit comments

Comments
 (0)