Skip to content

Commit a30b3c6

Browse files
committed
Refactor wsproxy to use cryptokeys interface
- Implement `cryptokeys.Keycache` interface in `CryptoKeyCache`. - Introduce context management for graceful shutdowns. - Simplify function signatures and improve concurrency handling. - Ensure functions return errors when cache is closed.
1 parent 87cb577 commit a30b3c6

File tree

2 files changed

+102
-19
lines changed

2 files changed

+102
-19
lines changed

enterprise/wsproxy/keycache.go

Lines changed: 46 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,25 @@ import (
99

1010
"cdr.dev/slog"
1111

12+
"github.com/coder/coder/v2/coderd/cryptokeys"
1213
"github.com/coder/coder/v2/codersdk"
1314
"github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk"
1415
"github.com/coder/quartz"
1516
)
1617

18+
var _ cryptokeys.Keycache = &CryptoKeyCache{}
19+
1720
type CryptoKeyCache struct {
21+
ctx context.Context
22+
cancel context.CancelFunc
1823
client *wsproxysdk.Client
1924
logger slog.Logger
2025
Clock quartz.Clock
2126

2227
keysMu sync.RWMutex
2328
keys map[int32]codersdk.CryptoKey
2429
latest codersdk.CryptoKey
30+
closed bool
2531
}
2632

2733
func NewCryptoKeyCache(ctx context.Context, log slog.Logger, client *wsproxysdk.Client, opts ...func(*CryptoKeyCache)) (*CryptoKeyCache, error) {
@@ -40,14 +46,21 @@ func NewCryptoKeyCache(ctx context.Context, log slog.Logger, client *wsproxysdk.
4046
return nil, xerrors.Errorf("initial fetch: %w", err)
4147
}
4248
cache.keys, cache.latest = m, latest
49+
cache.ctx, cache.cancel = context.WithCancel(ctx)
4350

44-
go cache.refresh(ctx)
51+
go cache.refresh()
4552

4653
return cache, nil
4754
}
4855

49-
func (k *CryptoKeyCache) Latest(ctx context.Context) (codersdk.CryptoKey, error) {
56+
func (k *CryptoKeyCache) Signing(ctx context.Context) (codersdk.CryptoKey, error) {
5057
k.keysMu.RLock()
58+
59+
if k.closed {
60+
k.keysMu.RUnlock()
61+
return codersdk.CryptoKey{}, cryptokeys.ErrClosed
62+
}
63+
5164
latest := k.latest
5265
k.keysMu.RUnlock()
5366

@@ -59,6 +72,10 @@ func (k *CryptoKeyCache) Latest(ctx context.Context) (codersdk.CryptoKey, error)
5972
k.keysMu.Lock()
6073
defer k.keysMu.Unlock()
6174

75+
if k.closed {
76+
return codersdk.CryptoKey{}, cryptokeys.ErrClosed
77+
}
78+
6279
if k.latest.CanSign(now) {
6380
return k.latest, nil
6481
}
@@ -76,9 +93,14 @@ func (k *CryptoKeyCache) Latest(ctx context.Context) (codersdk.CryptoKey, error)
7693
return k.latest, nil
7794
}
7895

79-
func (k *CryptoKeyCache) Version(ctx context.Context, sequence int32) (codersdk.CryptoKey, error) {
96+
func (k *CryptoKeyCache) Verifying(ctx context.Context, sequence int32) (codersdk.CryptoKey, error) {
8097
now := k.Clock.Now().UTC()
8198
k.keysMu.RLock()
99+
if k.closed {
100+
k.keysMu.RUnlock()
101+
return codersdk.CryptoKey{}, cryptokeys.ErrClosed
102+
}
103+
82104
key, ok := k.keys[sequence]
83105
k.keysMu.RUnlock()
84106
if ok {
@@ -87,6 +109,11 @@ func (k *CryptoKeyCache) Version(ctx context.Context, sequence int32) (codersdk.
87109

88110
k.keysMu.Lock()
89111
defer k.keysMu.Unlock()
112+
113+
if k.closed {
114+
return codersdk.CryptoKey{}, cryptokeys.ErrClosed
115+
}
116+
90117
key, ok = k.keys[sequence]
91118
if ok {
92119
return validKey(key, now)
@@ -106,11 +133,11 @@ func (k *CryptoKeyCache) Version(ctx context.Context, sequence int32) (codersdk.
106133
return validKey(key, now)
107134
}
108135

109-
func (k *CryptoKeyCache) refresh(ctx context.Context) {
110-
k.Clock.TickerFunc(ctx, time.Minute*10, func() error {
111-
kmap, latest, err := k.fetch(ctx)
136+
func (k *CryptoKeyCache) refresh() {
137+
k.Clock.TickerFunc(k.ctx, time.Minute*10, func() error {
138+
kmap, latest, err := k.fetch(k.ctx)
112139
if err != nil {
113-
k.logger.Error(ctx, "failed to fetch crypto keys", slog.Error(err))
140+
k.logger.Error(k.ctx, "failed to fetch crypto keys", slog.Error(err))
114141
return nil
115142
}
116143

@@ -151,3 +178,15 @@ func validKey(key codersdk.CryptoKey, now time.Time) (codersdk.CryptoKey, error)
151178

152179
return key, nil
153180
}
181+
182+
func (k *CryptoKeyCache) Close() {
183+
k.keysMu.Lock()
184+
defer k.keysMu.Unlock()
185+
186+
if k.closed {
187+
return
188+
}
189+
190+
k.cancel()
191+
k.closed = true
192+
}

enterprise/wsproxy/keycache_test.go

Lines changed: 56 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212

1313
"cdr.dev/slog/sloggers/slogtest"
1414

15+
"github.com/coder/coder/v2/coderd/cryptokeys"
1516
"github.com/coder/coder/v2/codersdk"
1617
"github.com/coder/coder/v2/enterprise/wsproxy"
1718
"github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk"
@@ -61,7 +62,7 @@ func TestCryptoKeyCache(t *testing.T) {
6162
cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, wsproxysdk.New(fc.url), withClock(clock))
6263
require.NoError(t, err)
6364

64-
got, err := cache.Latest(ctx)
65+
got, err := cache.Signing(ctx)
6566
require.NoError(t, err)
6667
require.Equal(t, expected, got)
6768
require.Equal(t, 1, fc.called)
@@ -88,14 +89,14 @@ func TestCryptoKeyCache(t *testing.T) {
8889
}
8990
fc.keys = []codersdk.CryptoKey{expected}
9091

91-
got, err := cache.Latest(ctx)
92+
got, err := cache.Signing(ctx)
9293
require.NoError(t, err)
9394
require.Equal(t, expected, got)
9495
// 1 on startup + missing cache.
9596
require.Equal(t, 2, fc.called)
9697

9798
// Ensure the cache gets hit this time.
98-
got, err = cache.Latest(ctx)
99+
got, err = cache.Signing(ctx)
99100
require.NoError(t, err)
100101
require.Equal(t, expected, got)
101102
// 1 on startup + missing cache.
@@ -132,7 +133,7 @@ func TestCryptoKeyCache(t *testing.T) {
132133
cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, wsproxysdk.New(fc.url), withClock(clock))
133134
require.NoError(t, err)
134135

135-
got, err := cache.Latest(ctx)
136+
got, err := cache.Signing(ctx)
136137
require.NoError(t, err)
137138
require.Equal(t, expected, got)
138139
require.Equal(t, 1, fc.called)
@@ -171,7 +172,7 @@ func TestCryptoKeyCache(t *testing.T) {
171172
cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, wsproxysdk.New(fc.url), withClock(clock))
172173
require.NoError(t, err)
173174

174-
got, err := cache.Version(ctx, expected.Sequence)
175+
got, err := cache.Verifying(ctx, expected.Sequence)
175176
require.NoError(t, err)
176177
require.Equal(t, expected, got)
177178
require.Equal(t, 1, fc.called)
@@ -198,13 +199,13 @@ func TestCryptoKeyCache(t *testing.T) {
198199
}
199200
fc.keys = []codersdk.CryptoKey{expected}
200201

201-
got, err := cache.Version(ctx, expected.Sequence)
202+
got, err := cache.Verifying(ctx, expected.Sequence)
202203
require.NoError(t, err)
203204
require.Equal(t, expected, got)
204205
require.Equal(t, 2, fc.called)
205206

206207
// Ensure the cache gets hit this time.
207-
got, err = cache.Version(ctx, expected.Sequence)
208+
got, err = cache.Verifying(ctx, expected.Sequence)
208209
require.NoError(t, err)
209210
require.Equal(t, expected, got)
210211
require.Equal(t, 2, fc.called)
@@ -234,7 +235,7 @@ func TestCryptoKeyCache(t *testing.T) {
234235
cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, wsproxysdk.New(fc.url), withClock(clock))
235236
require.NoError(t, err)
236237

237-
got, err := cache.Version(ctx, expected.Sequence)
238+
got, err := cache.Verifying(ctx, expected.Sequence)
238239
require.NoError(t, err)
239240
require.Equal(t, expected, got)
240241
require.Equal(t, 1, fc.called)
@@ -265,7 +266,7 @@ func TestCryptoKeyCache(t *testing.T) {
265266
cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, wsproxysdk.New(fc.url), withClock(clock))
266267
require.NoError(t, err)
267268

268-
_, err = cache.Version(ctx, expected.Sequence)
269+
_, err = cache.Verifying(ctx, expected.Sequence)
269270
require.Error(t, err)
270271
require.Equal(t, 1, fc.called)
271272
})
@@ -297,7 +298,7 @@ func TestCryptoKeyCache(t *testing.T) {
297298
cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, wsproxysdk.New(fc.url), withClock(clock))
298299
require.NoError(t, err)
299300

300-
got, err := cache.Latest(ctx)
301+
got, err := cache.Signing(ctx)
301302
require.NoError(t, err)
302303
require.Equal(t, expected, got)
303304
require.Equal(t, 1, fc.called)
@@ -320,15 +321,58 @@ func TestCryptoKeyCache(t *testing.T) {
320321
require.Equal(t, 2, fc.called)
321322

322323
// Assert hits cache.
323-
got, err = cache.Latest(ctx)
324+
got, err = cache.Signing(ctx)
324325
require.NoError(t, err)
325326
require.Equal(t, newKey, got)
326327
require.Equal(t, 2, fc.called)
327328

328329
// Assert we do not have the old key.
329-
_, err = cache.Version(ctx, expected.Sequence)
330+
_, err = cache.Verifying(ctx, expected.Sequence)
330331
require.Error(t, err)
331332
})
333+
334+
t.Run("Closed", func(t *testing.T) {
335+
t.Parallel()
336+
337+
var (
338+
ctx = testutil.Context(t, testutil.WaitShort)
339+
logger = slogtest.Make(t, nil)
340+
clock = quartz.NewMock(t)
341+
)
342+
343+
now := clock.Now()
344+
expected := codersdk.CryptoKey{
345+
Feature: codersdk.CryptoKeyFeatureWorkspaceApp,
346+
Secret: "key1",
347+
Sequence: 12,
348+
StartsAt: now,
349+
DeletesAt: now.Add(time.Minute * 10),
350+
}
351+
fc := newFakeCoderd(t, []codersdk.CryptoKey{
352+
expected,
353+
})
354+
355+
cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, wsproxysdk.New(fc.url), withClock(clock))
356+
require.NoError(t, err)
357+
358+
got, err := cache.Signing(ctx)
359+
require.NoError(t, err)
360+
require.Equal(t, expected, got)
361+
require.Equal(t, 1, fc.called)
362+
363+
got, err = cache.Verifying(ctx, expected.Sequence)
364+
require.NoError(t, err)
365+
require.Equal(t, expected, got)
366+
require.Equal(t, 1, fc.called)
367+
368+
cache.Close()
369+
370+
_, err = cache.Signing(ctx)
371+
require.ErrorIs(t, err, cryptokeys.ErrClosed)
372+
373+
_, err = cache.Verifying(ctx, expected.Sequence)
374+
require.ErrorIs(t, err, cryptokeys.ErrClosed)
375+
})
332376
}
333377

334378
type fakeCoderd struct {

0 commit comments

Comments
 (0)