Skip to content

Commit f537193

Browse files
authored
chore: refactor keycache implementation to reduce duplication (coder#15100)
1 parent 8e254cb commit f537193

File tree

10 files changed

+512
-1339
lines changed

10 files changed

+512
-1339
lines changed

coderd/cryptokeys/cache.go

Lines changed: 369 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,369 @@
1+
package cryptokeys
2+
3+
import (
4+
"context"
5+
"encoding/hex"
6+
"io"
7+
"strconv"
8+
"sync"
9+
"time"
10+
11+
"golang.org/x/xerrors"
12+
13+
"cdr.dev/slog"
14+
"github.com/coder/coder/v2/coderd/database"
15+
"github.com/coder/coder/v2/coderd/database/db2sdk"
16+
"github.com/coder/coder/v2/codersdk"
17+
"github.com/coder/quartz"
18+
)
19+
20+
var (
21+
ErrKeyNotFound = xerrors.New("key not found")
22+
ErrKeyInvalid = xerrors.New("key is invalid for use")
23+
ErrClosed = xerrors.New("closed")
24+
ErrInvalidFeature = xerrors.New("invalid feature for this operation")
25+
)
26+
27+
type Fetcher interface {
28+
Fetch(ctx context.Context) ([]codersdk.CryptoKey, error)
29+
}
30+
31+
// EncryptionKeycache hands out keys for encrypting and decrypting payloads.
type EncryptionKeycache interface {
	io.Closer

	// EncryptingKey returns the latest valid key for encrypting payloads,
	// together with its id. A key is valid when it is past its start time and
	// before its deletion time.
	EncryptingKey(ctx context.Context) (id string, key interface{}, err error)

	// DecryptingKey returns the key identified by id, where id maps to the
	// key's sequence number. A key may be used for decryption as long as it is
	// not deleted or past its deletion date. Keys that have not yet reached
	// their start time are allowed on purpose: clock skew between peers means
	// a key can already be active on one machine while it is not on another.
	DecryptingKey(ctx context.Context, id string) (key interface{}, err error)
}
43+
44+
// SigningKeycache hands out keys for signing and for verifying signatures.
type SigningKeycache interface {
	io.Closer

	// SigningKey returns the latest valid key for signing, together with its
	// id. A key is valid when it is past its start time and before its
	// deletion time.
	SigningKey(ctx context.Context) (id string, key interface{}, err error)

	// VerifyingKey returns the key identified by id, where id should map to the
	// key's sequence number. A key may be used for verification as long as it
	// is not deleted or past its deletion date. Keys that have not yet reached
	// their start time are allowed on purpose: clock skew between peers means
	// a key can already be active on one machine while it is not on another.
	VerifyingKey(ctx context.Context, id string) (key interface{}, err error)
}
56+
57+
const (
	// latestSequence is a sentinel sequence number standing for "the latest
	// key" rather than any concrete key.
	latestSequence = -1
	// refreshInterval is how often the key cache refreshes itself.
	refreshInterval = 10 * time.Minute
)
63+
64+
type DBFetcher struct {
65+
DB database.Store
66+
Feature database.CryptoKeyFeature
67+
}
68+
69+
func (d *DBFetcher) Fetch(ctx context.Context) ([]codersdk.CryptoKey, error) {
70+
keys, err := d.DB.GetCryptoKeysByFeature(ctx, d.Feature)
71+
if err != nil {
72+
return nil, xerrors.Errorf("get crypto keys by feature: %w", err)
73+
}
74+
75+
return db2sdk.CryptoKeys(keys), nil
76+
}
77+
78+
// cache implements the caching functionality for both signing and encryption keys.
79+
type cache struct {
80+
clock quartz.Clock
81+
refreshCtx context.Context
82+
refreshCancel context.CancelFunc
83+
fetcher Fetcher
84+
logger slog.Logger
85+
feature codersdk.CryptoKeyFeature
86+
87+
mu sync.Mutex
88+
keys map[int32]codersdk.CryptoKey
89+
lastFetch time.Time
90+
refresher *quartz.Timer
91+
fetching bool
92+
closed bool
93+
cond *sync.Cond
94+
}
95+
96+
// CacheOption customizes a cache while it is being constructed.
type CacheOption func(*cache)
97+
98+
func WithCacheClock(clock quartz.Clock) CacheOption {
99+
return func(d *cache) {
100+
d.clock = clock
101+
}
102+
}
103+
104+
// NewSigningCache instantiates a cache. Close should be called to release resources
105+
// associated with its internal timer.
106+
func NewSigningCache(ctx context.Context, logger slog.Logger, fetcher Fetcher,
107+
feature codersdk.CryptoKeyFeature, opts ...func(*cache),
108+
) (SigningKeycache, error) {
109+
if !isSigningKeyFeature(feature) {
110+
return nil, xerrors.Errorf("invalid feature: %s", feature)
111+
}
112+
return newCache(ctx, logger, fetcher, feature, opts...)
113+
}
114+
115+
func NewEncryptionCache(ctx context.Context, logger slog.Logger, fetcher Fetcher,
116+
feature codersdk.CryptoKeyFeature, opts ...func(*cache),
117+
) (EncryptionKeycache, error) {
118+
if !isEncryptionKeyFeature(feature) {
119+
return nil, xerrors.Errorf("invalid feature: %s", feature)
120+
}
121+
return newCache(ctx, logger, fetcher, feature, opts...)
122+
}
123+
124+
func newCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, feature codersdk.CryptoKeyFeature, opts ...func(*cache)) (*cache, error) {
125+
cache := &cache{
126+
clock: quartz.NewReal(),
127+
logger: logger,
128+
fetcher: fetcher,
129+
feature: feature,
130+
}
131+
132+
for _, opt := range opts {
133+
opt(cache)
134+
}
135+
136+
cache.cond = sync.NewCond(&cache.mu)
137+
cache.refreshCtx, cache.refreshCancel = context.WithCancel(ctx)
138+
cache.refresher = cache.clock.AfterFunc(refreshInterval, cache.refresh)
139+
140+
keys, err := cache.cryptoKeys(ctx)
141+
if err != nil {
142+
cache.refreshCancel()
143+
return nil, xerrors.Errorf("initial fetch: %w", err)
144+
}
145+
cache.keys = keys
146+
return cache, nil
147+
}
148+
149+
func (c *cache) EncryptingKey(ctx context.Context) (string, interface{}, error) {
150+
if !isEncryptionKeyFeature(c.feature) {
151+
return "", nil, ErrInvalidFeature
152+
}
153+
154+
return c.cryptoKey(ctx, latestSequence)
155+
}
156+
157+
func (c *cache) DecryptingKey(ctx context.Context, id string) (interface{}, error) {
158+
if !isEncryptionKeyFeature(c.feature) {
159+
return nil, ErrInvalidFeature
160+
}
161+
162+
seq, err := strconv.ParseInt(id, 10, 64)
163+
if err != nil {
164+
return nil, xerrors.Errorf("parse id: %w", err)
165+
}
166+
167+
_, secret, err := c.cryptoKey(ctx, int32(seq))
168+
if err != nil {
169+
return nil, xerrors.Errorf("crypto key: %w", err)
170+
}
171+
return secret, nil
172+
}
173+
174+
func (c *cache) SigningKey(ctx context.Context) (string, interface{}, error) {
175+
if !isSigningKeyFeature(c.feature) {
176+
return "", nil, ErrInvalidFeature
177+
}
178+
179+
return c.cryptoKey(ctx, latestSequence)
180+
}
181+
182+
func (c *cache) VerifyingKey(ctx context.Context, id string) (interface{}, error) {
183+
if !isSigningKeyFeature(c.feature) {
184+
return nil, ErrInvalidFeature
185+
}
186+
187+
seq, err := strconv.ParseInt(id, 10, 64)
188+
if err != nil {
189+
return nil, xerrors.Errorf("parse id: %w", err)
190+
}
191+
192+
_, secret, err := c.cryptoKey(ctx, int32(seq))
193+
if err != nil {
194+
return nil, xerrors.Errorf("crypto key: %w", err)
195+
}
196+
197+
return secret, nil
198+
}
199+
200+
func isEncryptionKeyFeature(feature codersdk.CryptoKeyFeature) bool {
201+
return feature == codersdk.CryptoKeyFeatureWorkspaceApp
202+
}
203+
204+
func isSigningKeyFeature(feature codersdk.CryptoKeyFeature) bool {
205+
switch feature {
206+
case codersdk.CryptoKeyFeatureTailnetResume, codersdk.CryptoKeyFeatureOIDCConvert:
207+
return true
208+
default:
209+
return false
210+
}
211+
}
212+
213+
func idSecret(k codersdk.CryptoKey) (string, []byte, error) {
214+
key, err := hex.DecodeString(k.Secret)
215+
if err != nil {
216+
return "", nil, xerrors.Errorf("decode key: %w", err)
217+
}
218+
219+
return strconv.FormatInt(int64(k.Sequence), 10), key, nil
220+
}
221+
222+
func (c *cache) cryptoKey(ctx context.Context, sequence int32) (string, []byte, error) {
223+
c.mu.Lock()
224+
defer c.mu.Unlock()
225+
226+
if c.closed {
227+
return "", nil, ErrClosed
228+
}
229+
230+
var key codersdk.CryptoKey
231+
var ok bool
232+
for key, ok = c.key(sequence); !ok && c.fetching && !c.closed; {
233+
c.cond.Wait()
234+
}
235+
236+
if c.closed {
237+
return "", nil, ErrClosed
238+
}
239+
240+
if ok {
241+
return checkKey(key, sequence, c.clock.Now())
242+
}
243+
244+
c.fetching = true
245+
c.mu.Unlock()
246+
247+
keys, err := c.cryptoKeys(ctx)
248+
if err != nil {
249+
return "", nil, xerrors.Errorf("get keys: %w", err)
250+
}
251+
252+
c.mu.Lock()
253+
c.lastFetch = c.clock.Now()
254+
c.refresher.Reset(refreshInterval)
255+
c.keys = keys
256+
c.fetching = false
257+
c.cond.Broadcast()
258+
259+
key, ok = c.key(sequence)
260+
if !ok {
261+
return "", nil, ErrKeyNotFound
262+
}
263+
264+
return checkKey(key, sequence, c.clock.Now())
265+
}
266+
267+
func (c *cache) key(sequence int32) (codersdk.CryptoKey, bool) {
268+
if sequence == latestSequence {
269+
return c.keys[latestSequence], c.keys[latestSequence].CanSign(c.clock.Now())
270+
}
271+
272+
key, ok := c.keys[sequence]
273+
return key, ok
274+
}
275+
276+
func checkKey(key codersdk.CryptoKey, sequence int32, now time.Time) (string, []byte, error) {
277+
if sequence == latestSequence {
278+
if !key.CanSign(now) {
279+
return "", nil, ErrKeyInvalid
280+
}
281+
return idSecret(key)
282+
}
283+
284+
if !key.CanVerify(now) {
285+
return "", nil, ErrKeyInvalid
286+
}
287+
288+
return idSecret(key)
289+
}
290+
291+
// refresh fetches the keys and updates the cache.
292+
func (c *cache) refresh() {
293+
now := c.clock.Now("CryptoKeyCache", "refresh")
294+
c.mu.Lock()
295+
defer c.mu.Unlock()
296+
297+
if c.closed {
298+
return
299+
}
300+
301+
// If something's already fetching, we don't need to do anything.
302+
if c.fetching {
303+
return
304+
}
305+
306+
// There's a window we must account for where the timer fires while a fetch
307+
// is ongoing but prior to the timer getting reset. In this case we want to
308+
// avoid double fetching.
309+
if now.Sub(c.lastFetch) < refreshInterval {
310+
return
311+
}
312+
313+
c.fetching = true
314+
315+
c.mu.Unlock()
316+
keys, err := c.cryptoKeys(c.refreshCtx)
317+
if err != nil {
318+
c.logger.Error(c.refreshCtx, "fetch crypto keys", slog.Error(err))
319+
return
320+
}
321+
322+
// We don't defer an unlock here due to the deferred unlock at the top of the function.
323+
c.mu.Lock()
324+
325+
c.lastFetch = c.clock.Now()
326+
c.refresher.Reset(refreshInterval)
327+
c.keys = keys
328+
c.fetching = false
329+
c.cond.Broadcast()
330+
}
331+
332+
// cryptoKeys queries the control plane for the crypto keys.
333+
// Outside of initialization, this should only be called by fetch.
334+
func (c *cache) cryptoKeys(ctx context.Context) (map[int32]codersdk.CryptoKey, error) {
335+
keys, err := c.fetcher.Fetch(ctx)
336+
if err != nil {
337+
return nil, xerrors.Errorf("crypto keys: %w", err)
338+
}
339+
cache := toKeyMap(keys, c.clock.Now())
340+
return cache, nil
341+
}
342+
343+
func toKeyMap(keys []codersdk.CryptoKey, now time.Time) map[int32]codersdk.CryptoKey {
344+
m := make(map[int32]codersdk.CryptoKey)
345+
var latest codersdk.CryptoKey
346+
for _, key := range keys {
347+
m[key.Sequence] = key
348+
if key.Sequence > latest.Sequence && key.CanSign(now) {
349+
m[latestSequence] = key
350+
}
351+
}
352+
return m
353+
}
354+
355+
func (c *cache) Close() error {
356+
c.mu.Lock()
357+
defer c.mu.Unlock()
358+
359+
if c.closed {
360+
return nil
361+
}
362+
363+
c.closed = true
364+
c.refreshCancel()
365+
c.refresher.Stop()
366+
c.cond.Broadcast()
367+
368+
return nil
369+
}

0 commit comments

Comments
 (0)