Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix tests
  • Loading branch information
sreya committed Oct 2, 2024
commit 0fef6b0569dc903bc1a0c4c3e6b8d57ef0705c8d
73 changes: 51 additions & 22 deletions enterprise/wsproxy/keycache.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,14 @@ func NewCryptoKeyCache(ctx context.Context, log slog.Logger, client *wsproxysdk.
opt(cache)
}

m, latest, err := cache.fetch(ctx)
cache.refreshCtx, cache.refreshCancel = context.WithCancel(ctx)
cache.refresher = cache.Clock.AfterFunc(time.Minute*10, cache.refresh)
m, latest, err := cache.fetchKeys(ctx)
if err != nil {
cache.refreshCancel()
return nil, xerrors.Errorf("initial fetch: %w", err)
}
cache.keys, cache.latest = m, latest
cache.refresher = cache.Clock.AfterFunc(time.Minute*10, cache.refresh)

return cache, nil
}
Expand All @@ -77,9 +79,12 @@ func (k *CryptoKeyCache) Signing(ctx context.Context) (codersdk.CryptoKey, error
}

k.keysMu.RLock()
if k.latest.CanSign(now) {
k.keysMu.RUnlock()
return k.latest, nil
latest = k.latest
k.keysMu.RUnlock()

now = k.Clock.Now()
if latest.CanSign(now) {
return latest, nil
}

_, latest, err := k.fetch(ctx)
Expand All @@ -91,27 +96,28 @@ func (k *CryptoKeyCache) Signing(ctx context.Context) (codersdk.CryptoKey, error
}

func (k *CryptoKeyCache) Verifying(ctx context.Context, sequence int32) (codersdk.CryptoKey, error) {
now := k.Clock.Now()
k.keysMu.RLock()
if k.isClosed() {
k.keysMu.RUnlock()
return codersdk.CryptoKey{}, cryptokeys.ErrClosed
}

now := k.Clock.Now()
k.keysMu.RLock()
key, ok := k.keys[sequence]
k.keysMu.RUnlock()
if ok {
return validKey(key, now)
}

k.keysMu.Lock()
defer k.keysMu.Unlock()
k.fetchLock.Lock()
defer k.fetchLock.Unlock()

if k.isClosed() {
return codersdk.CryptoKey{}, cryptokeys.ErrClosed
}

k.keysMu.RLock()
key, ok = k.keys[sequence]
k.keysMu.RUnlock()
if ok {
return validKey(key, now)
}
Expand All @@ -134,14 +140,23 @@ func (k *CryptoKeyCache) refresh() {
return
}

k.keysMu.RLock()
if k.Clock.Now().Sub(k.lastFetch) < time.Minute*10 {
k.keysMu.Unlock()
k.fetchLock.Lock()
defer k.fetchLock.Unlock()

if k.isClosed() {
return
}

k.fetchLock.Lock()
defer k.fetchLock.Unlock()
k.keysMu.RLock()
lastFetch := k.lastFetch
k.keysMu.RUnlock()

// There's a window we must account for where the timer fires while a fetch
// is ongoing but prior to the timer getting reset. In this case we want to
// avoid double fetching.
if k.Clock.Now().Sub(lastFetch) < time.Minute*10 {
return
}

_, _, err := k.fetch(k.refreshCtx)
if err != nil {
Expand All @@ -150,19 +165,28 @@ func (k *CryptoKeyCache) refresh() {
}
}

func (k *CryptoKeyCache) fetch(ctx context.Context) (map[int32]codersdk.CryptoKey, codersdk.CryptoKey, error) {

func (k *CryptoKeyCache) fetchKeys(ctx context.Context) (map[int32]codersdk.CryptoKey, codersdk.CryptoKey, error) {
keys, err := k.client.CryptoKeys(ctx)
if err != nil {
return nil, codersdk.CryptoKey{}, xerrors.Errorf("get security keys: %w", err)
return nil, codersdk.CryptoKey{}, xerrors.Errorf("crypto keys: %w", err)
}
cache, latest := toKeyMap(keys.CryptoKeys, k.Clock.Now())
return cache, latest, nil
}

if len(keys.CryptoKeys) == 0 {
// fetch fetches the keys from the control plane and updates the cache. The fetchMu
// must be held when calling this function to avoid multiple concurrent fetches.
func (k *CryptoKeyCache) fetch(ctx context.Context) (map[int32]codersdk.CryptoKey, codersdk.CryptoKey, error) {
keys, latest, err := k.fetchKeys(ctx)
if err != nil {
return nil, codersdk.CryptoKey{}, xerrors.Errorf("fetch keys: %w", err)
}

if len(keys) == 0 {
return nil, codersdk.CryptoKey{}, cryptokeys.ErrKeyNotFound
}

now := k.Clock.Now()
kmap, latest := toKeyMap(keys.CryptoKeys, now)
if !latest.CanSign(now) {
return nil, codersdk.CryptoKey{}, cryptokeys.ErrKeyInvalid
}
Expand All @@ -172,9 +196,9 @@ func (k *CryptoKeyCache) fetch(ctx context.Context) (map[int32]codersdk.CryptoKe

k.lastFetch = k.Clock.Now()
k.refresher.Reset(time.Minute * 10)
k.keys, k.latest = kmap, latest
k.keys, k.latest = keys, latest

return kmap, latest, nil
return keys, latest, nil
}

func toKeyMap(keys []codersdk.CryptoKey, now time.Time) (map[int32]codersdk.CryptoKey, codersdk.CryptoKey) {
Expand Down Expand Up @@ -202,6 +226,11 @@ func (k *CryptoKeyCache) isClosed() bool {
}

func (k *CryptoKeyCache) Close() {
// The fetch lock must always be held before holding the keys lock
// otherwise we risk a deadlock.
k.fetchLock.Lock()
defer k.fetchLock.Unlock()

k.keysMu.Lock()
defer k.keysMu.Unlock()

Expand Down
13 changes: 4 additions & 9 deletions enterprise/wsproxy/keycache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,6 @@ func TestCryptoKeyCache(t *testing.T) {
clock = quartz.NewMock(t)
)

trap := clock.Trap().TickerFunc()

now := clock.Now().UTC()
expected := codersdk.CryptoKey{
Feature: codersdk.CryptoKeyFeatureWorkspaceApp,
Expand All @@ -339,8 +337,6 @@ func TestCryptoKeyCache(t *testing.T) {
require.Equal(t, expected, got)
require.Equal(t, 1, fc.called)

wait := trap.MustWait(ctx)

newKey := codersdk.CryptoKey{
Feature: codersdk.CryptoKeyFeatureWorkspaceApp,
Secret: "key2",
Expand All @@ -349,8 +345,6 @@ func TestCryptoKeyCache(t *testing.T) {
}
fc.keys = []codersdk.CryptoKey{newKey}

wait.Release()

// The ticker should fire and cause a request to coderd.
_, advance := clock.AdvanceNext()
advance.MustWait(ctx)
Expand All @@ -362,9 +356,10 @@ func TestCryptoKeyCache(t *testing.T) {
require.Equal(t, newKey, got)
require.Equal(t, 2, fc.called)

// Assert we do not have the old key.
_, err = cache.Verifying(ctx, expected.Sequence)
require.Error(t, err)
// The ticker should fire and cause a request to coderd.
_, advance = clock.AdvanceNext()
advance.MustWait(ctx)
require.Equal(t, 3, fc.called)
})

t.Run("Closed", func(t *testing.T) {
Expand Down