From cab9961092366c13a7bb44b8d8f3029ca39483ee Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Tue, 15 Oct 2024 01:48:50 +0000 Subject: [PATCH 1/7] chore: refactor keycache implemenation to reduce duplication --- coderd/cryptokeys/dbkeycache.go | 388 +++++---- coderd/cryptokeys/dbkeycache_internal_test.go | 736 +++++++++--------- coderd/cryptokeys/dbkeycache_test.go | 70 +- coderd/cryptokeys/keycache.go | 6 + coderd/jwtutils/jwt_test.go | 9 +- 5 files changed, 650 insertions(+), 559 deletions(-) diff --git a/coderd/cryptokeys/dbkeycache.go b/coderd/cryptokeys/dbkeycache.go index aa0a2444b35f2..c2fb3b55372a0 100644 --- a/coderd/cryptokeys/dbkeycache.go +++ b/coderd/cryptokeys/dbkeycache.go @@ -2,6 +2,7 @@ package cryptokeys import ( "context" + "encoding/hex" "strconv" "sync" "time" @@ -10,277 +11,342 @@ import ( "cdr.dev/slog" "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/db2sdk" + "github.com/coder/coder/v2/codersdk" "github.com/coder/quartz" ) -// never represents the maximum value for a time.Duration. -const never = 1<<63 - 1 - -// dbCache implements Keycache for callers with access to the database. -type dbCache struct { - db database.Store - feature database.CryptoKeyFeature - logger slog.Logger - clock quartz.Clock - - // The following are initialized by NewDBCache. - keysMu sync.RWMutex - keys map[int32]database.CryptoKey - latestKey database.CryptoKey - timer *quartz.Timer - // invalidateAt is the time at which the keys cache should be invalidated. - invalidateAt time.Time - closed bool +const ( + // latestSequence is a special sequence number that represents the latest key. + latestSequence = -1 + // refreshInterval is the interval at which the key cache will refresh. + refreshInterval = time.Minute * 10 +) + +type DBFetcher struct { + DB database.Store + Feature database.CryptoKeyFeature +} + +func (d *DBFetcher) Fetch(ctx context.Context) ([]codersdk.CryptoKey, error) { + keys, err := d.DB.GetCryptoKeysByFeature(ctx, d.Feature) + if err != nil { + return nil, xerrors.Errorf("get crypto keys by feature: %w", err) + } + + return db2sdk.CryptoKeys(keys), nil } -type DBCacheOption func(*dbCache) +// CryptoKeyCache implements Keycache for callers with access to the database. +type CryptoKeyCache struct { + clock quartz.Clock + refreshCtx context.Context + refreshCancel context.CancelFunc + fetcher Fetcher + logger slog.Logger + feature database.CryptoKeyFeature + + mu sync.Mutex + keys map[int32]codersdk.CryptoKey + lastFetch time.Time + refresher *quartz.Timer + fetching bool + closed bool + cond *sync.Cond +} + +type DBCacheOption func(*CryptoKeyCache) func WithDBCacheClock(clock quartz.Clock) DBCacheOption { - return func(d *dbCache) { + return func(d *CryptoKeyCache) { d.clock = clock } } // NewSigningCache creates a new DBCache. Close should be called to // release resources associated with its internal timer. -func NewSigningCache(logger slog.Logger, db database.Store, feature database.CryptoKeyFeature, opts ...func(*dbCache)) (SigningKeycache, error) { +func NewSigningCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, feature database.CryptoKeyFeature, opts ...func(*CryptoKeyCache)) (SigningKeycache, error) { if !isSigningKeyFeature(feature) { return nil, ErrInvalidFeature } - return newDBCache(logger, db, feature, opts...), nil + return newDBCache(ctx, logger, fetcher, feature, opts...) } -func NewEncryptionCache(logger slog.Logger, db database.Store, feature database.CryptoKeyFeature, opts ...func(*dbCache)) (EncryptionKeycache, error) { +func NewEncryptionCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, feature database.CryptoKeyFeature, opts ...func(*CryptoKeyCache)) (EncryptionKeycache, error) { if !isEncryptionKeyFeature(feature) { return nil, ErrInvalidFeature } - return newDBCache(logger, db, feature, opts...), nil + return newDBCache(ctx, logger, fetcher, feature, opts...) } -func newDBCache(logger slog.Logger, db database.Store, feature database.CryptoKeyFeature, opts ...func(*dbCache)) *dbCache { - d := &dbCache{ - db: db, - feature: feature, +func newDBCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, feature database.CryptoKeyFeature, opts ...func(*CryptoKeyCache)) (*CryptoKeyCache, error) { + cache := &CryptoKeyCache{ clock: quartz.NewReal(), logger: logger, + fetcher: fetcher, + feature: feature, } for _, opt := range opts { - opt(d) + opt(cache) } - // Initialize the timer. This will get properly initialized the first time we fetch. - d.timer = d.clock.AfterFunc(never, d.clear) + cache.cond = sync.NewCond(&cache.mu) + cache.refreshCtx, cache.refreshCancel = context.WithCancel(ctx) + cache.refresher = cache.clock.AfterFunc(refreshInterval, cache.refresh) - return d + keys, err := cache.cryptoKeys(ctx) + if err != nil { + cache.refreshCancel() + return nil, xerrors.Errorf("initial fetch: %w", err) + } + cache.keys = keys + return cache, nil } -func (d *dbCache) EncryptingKey(ctx context.Context) (id string, key interface{}, err error) { +func (d *CryptoKeyCache) EncryptingKey(ctx context.Context) (string, interface{}, error) { if !isEncryptionKeyFeature(d.feature) { return "", nil, ErrInvalidFeature } - return d.latest(ctx) + key, err := d.cryptoKey(ctx, latestSequence) + if err != nil { + return "", nil, xerrors.Errorf("crypto key: %w", err) + } + + secret, err := hex.DecodeString(key.Secret) + if err != nil { + return "", nil, xerrors.Errorf("decode key: %w", err) + } + + return strconv.FormatInt(int64(key.Sequence), 10), secret, nil } -func (d *dbCache) DecryptingKey(ctx context.Context, id string) (key interface{}, err error) { +func (d *CryptoKeyCache) DecryptingKey(ctx context.Context, id string) (interface{}, error) { if !isEncryptionKeyFeature(d.feature) { return nil, ErrInvalidFeature } - return d.sequence(ctx, id) + i, err := strconv.ParseInt(id, 10, 64) + if err != nil { + return nil, xerrors.Errorf("parse id: %w", err) + } + + key, err := d.cryptoKey(ctx, int32(i)) + if err != nil { + return nil, xerrors.Errorf("crypto key: %w", err) + } + + secret, err := hex.DecodeString(key.Secret) + if err != nil { + return nil, xerrors.Errorf("decode key: %w", err) + } + + return secret, nil } -func (d *dbCache) SigningKey(ctx context.Context) (id string, key interface{}, err error) { +func (d *CryptoKeyCache) SigningKey(ctx context.Context) (string, interface{}, error) { if !isSigningKeyFeature(d.feature) { return "", nil, ErrInvalidFeature } - return d.latest(ctx) + key, err := d.cryptoKey(ctx, latestSequence) + if err != nil { + return "", nil, xerrors.Errorf("crypto key: %w", err) + } + + return strconv.FormatInt(int64(key.Sequence), 10), key.Secret, nil } -func (d *dbCache) VerifyingKey(ctx context.Context, id string) (key interface{}, err error) { +func (d *CryptoKeyCache) VerifyingKey(ctx context.Context, sequence string) (interface{}, error) { if !isSigningKeyFeature(d.feature) { return nil, ErrInvalidFeature } - return d.sequence(ctx, id) -} - -// sequence returns the CryptoKey with the given sequence number, provided that -// it is neither deleted nor has breached its deletion date. It should only be -// used for verifying or decrypting payloads. To sign/encrypt call Signing. -func (d *dbCache) sequence(ctx context.Context, id string) (interface{}, error) { - sequence, err := strconv.ParseInt(id, 10, 32) + i, err := strconv.ParseInt(sequence, 10, 64) if err != nil { - return nil, xerrors.Errorf("expecting sequence number got %q: %w", id, err) + return nil, xerrors.Errorf("parse id: %w", err) } - d.keysMu.RLock() - if d.closed { - d.keysMu.RUnlock() - return nil, ErrClosed - } - - now := d.clock.Now() - key, ok := d.keys[int32(sequence)] - d.keysMu.RUnlock() - if ok { - return checkKey(key, now) + key, err := d.cryptoKey(ctx, int32(i)) + if err != nil { + return nil, xerrors.Errorf("crypto key: %w", err) } - d.keysMu.Lock() - defer d.keysMu.Unlock() + return key.Secret, nil +} - if d.closed { - return nil, ErrClosed - } +func isEncryptionKeyFeature(feature database.CryptoKeyFeature) bool { + return feature == database.CryptoKeyFeatureWorkspaceApps +} - key, ok = d.keys[int32(sequence)] - if ok { - return checkKey(key, now) +func isSigningKeyFeature(feature database.CryptoKeyFeature) bool { + switch feature { + case database.CryptoKeyFeatureTailnetResume, database.CryptoKeyFeatureOidcConvert: + return true + default: + return false } +} - err = d.fetch(ctx) +func idSecret(k database.CryptoKey) (string, interface{}, error) { + key, err := k.DecodeString() if err != nil { - return nil, xerrors.Errorf("fetch: %w", err) - } - - key, ok = d.keys[int32(sequence)] - if !ok { - return nil, ErrKeyNotFound + return "", nil, xerrors.Errorf("decode key: %w", err) } - return checkKey(key, now) + return strconv.FormatInt(int64(k.Sequence), 10), key, nil } -// latest returns the latest valid key for signing. A valid key is one that is -// both past its start time and before its deletion time. -func (d *dbCache) latest(ctx context.Context) (string, interface{}, error) { - d.keysMu.RLock() +func (k *CryptoKeyCache) cryptoKey(ctx context.Context, sequence int32) (codersdk.CryptoKey, error) { + k.mu.Lock() + defer k.mu.Unlock() - if d.closed { - d.keysMu.RUnlock() - return "", nil, ErrClosed + if k.closed { + return codersdk.CryptoKey{}, ErrClosed } - latest := d.latestKey - d.keysMu.RUnlock() - - now := d.clock.Now() - if latest.CanSign(now) { - return idSecret(latest) + var key codersdk.CryptoKey + var ok bool + for key, ok = k.key(sequence); !ok && k.fetching && !k.closed; { + k.cond.Wait() } - d.keysMu.Lock() - defer d.keysMu.Unlock() - - if d.closed { - return "", nil, ErrClosed + if k.closed { + return codersdk.CryptoKey{}, ErrClosed } - if d.latestKey.CanSign(now) { - return idSecret(d.latestKey) + if ok { + return checkKey(key, sequence, k.clock.Now()) } - // Refetch all keys for this feature so we can find the latest valid key. - err := d.fetch(ctx) + k.fetching = true + k.mu.Unlock() + + keys, err := k.cryptoKeys(ctx) if err != nil { - return "", nil, xerrors.Errorf("fetch: %w", err) + return codersdk.CryptoKey{}, xerrors.Errorf("get keys: %w", err) } - return idSecret(d.latestKey) -} + k.mu.Lock() + k.lastFetch = k.clock.Now() + k.refresher.Reset(refreshInterval) + k.keys = keys + k.fetching = false + k.cond.Broadcast() -// clear invalidates the cache. This forces the subsequent call to fetch fresh keys. -func (d *dbCache) clear() { - now := d.clock.Now("DBCache", "clear") - d.keysMu.Lock() - defer d.keysMu.Unlock() - // Check if we raced with a fetch. It's possible that the timer fired and we - // lost the race to the mutex. We want to avoid invalidating - // a cache that was just refetched. - if now.Before(d.invalidateAt) { - return + key, ok = k.key(sequence) + if !ok { + return codersdk.CryptoKey{}, ErrKeyNotFound } - d.keys = nil - d.latestKey = database.CryptoKey{} + + return checkKey(key, sequence, k.clock.Now()) } -// fetch fetches all keys for the given feature and determines the latest key. -// It must be called while holding the keysMu lock. -func (d *dbCache) fetch(ctx context.Context) error { - keys, err := d.db.GetCryptoKeysByFeature(ctx, d.feature) - if err != nil { - return xerrors.Errorf("get crypto keys by feature: %w", err) +func (k *CryptoKeyCache) key(sequence int32) (codersdk.CryptoKey, bool) { + if sequence == latestSequence { + return k.keys[latestSequence], k.keys[latestSequence].CanSign(k.clock.Now()) } - now := d.clock.Now() - _ = d.timer.Reset(time.Minute * 10) - d.invalidateAt = now.Add(time.Minute * 10) + key, ok := k.keys[sequence] + return key, ok +} - cache := make(map[int32]database.CryptoKey) - var latest database.CryptoKey - for _, key := range keys { - cache[key.Sequence] = key - if key.CanSign(now) && key.Sequence > latest.Sequence { - latest = key +func checkKey(key codersdk.CryptoKey, sequence int32, now time.Time) (codersdk.CryptoKey, error) { + if sequence == latestSequence { + if !key.CanSign(now) { + return codersdk.CryptoKey{}, ErrKeyInvalid } + return key, nil } - if len(cache) == 0 { - return ErrKeyNotFound - } - - if !latest.CanSign(now) { - return ErrKeyInvalid + if !key.CanVerify(now) { + return codersdk.CryptoKey{}, ErrKeyInvalid } - d.keys, d.latestKey = cache, latest - return nil + return key, nil } -func checkKey(key database.CryptoKey, now time.Time) (interface{}, error) { - if !key.CanVerify(now) { - return nil, ErrKeyInvalid +// refresh fetches the keys and updates the cache. +func (k *CryptoKeyCache) refresh() { + now := k.clock.Now("CryptoKeyCache", "refresh") + k.mu.Lock() + + if k.closed { + k.mu.Unlock() + return } - return key.DecodeString() -} + // If something's already fetching, we don't need to do anything. + if k.fetching { + k.mu.Unlock() + return + } + + // There's a window we must account for where the timer fires while a fetch + // is ongoing but prior to the timer getting reset. In this case we want to + // avoid double fetching. + if now.Sub(k.lastFetch) < refreshInterval { + k.mu.Unlock() + return + } -func (d *dbCache) Close() error { - d.keysMu.Lock() - defer d.keysMu.Unlock() + k.fetching = true - if d.closed { - return nil + k.mu.Unlock() + keys, err := k.cryptoKeys(k.refreshCtx) + if err != nil { + k.logger.Error(k.refreshCtx, "fetch crypto keys", slog.Error(err)) + return } - d.timer.Stop() - d.closed = true - return nil + k.mu.Lock() + defer k.mu.Unlock() + + k.lastFetch = k.clock.Now() + k.refresher.Reset(refreshInterval) + k.keys = keys + k.fetching = false + k.cond.Broadcast() } -func isEncryptionKeyFeature(feature database.CryptoKeyFeature) bool { - return feature == database.CryptoKeyFeatureWorkspaceApps +// cryptoKeys queries the control plane for the crypto keys. +// Outside of initialization, this should only be called by fetch. +func (k *CryptoKeyCache) cryptoKeys(ctx context.Context) (map[int32]codersdk.CryptoKey, error) { + keys, err := k.fetcher.Fetch(ctx) + if err != nil { + return nil, xerrors.Errorf("crypto keys: %w", err) + } + cache := toKeyMap(keys, k.clock.Now()) + return cache, nil } -func isSigningKeyFeature(feature database.CryptoKeyFeature) bool { - switch feature { - case database.CryptoKeyFeatureTailnetResume, database.CryptoKeyFeatureOidcConvert: - return true - default: - return false +func toKeyMap(keys []codersdk.CryptoKey, now time.Time) map[int32]codersdk.CryptoKey { + m := make(map[int32]codersdk.CryptoKey) + var latest codersdk.CryptoKey + for _, key := range keys { + m[key.Sequence] = key + if key.Sequence > latest.Sequence && key.CanSign(now) { + m[latestSequence] = key + } } + return m } -func idSecret(k database.CryptoKey) (string, interface{}, error) { - key, err := k.DecodeString() - if err != nil { - return "", nil, xerrors.Errorf("decode key: %w", err) +func (k *CryptoKeyCache) Close() error { + k.mu.Lock() + defer k.mu.Unlock() + + if k.closed { + return nil } - return strconv.FormatInt(int64(k.Sequence), 10), key, nil + k.closed = true + k.refreshCancel() + k.refresher.Stop() + k.cond.Broadcast() + + return nil } diff --git a/coderd/cryptokeys/dbkeycache_internal_test.go b/coderd/cryptokeys/dbkeycache_internal_test.go index c27bc5b8468ad..4a01c0bc0d05e 100644 --- a/coderd/cryptokeys/dbkeycache_internal_test.go +++ b/coderd/cryptokeys/dbkeycache_internal_test.go @@ -2,9 +2,8 @@ package cryptokeys import ( "database/sql" - "strconv" + "encoding/hex" "testing" - "time" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" @@ -13,6 +12,7 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbmock" + "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/testutil" "github.com/coder/quartz" ) @@ -24,31 +24,35 @@ func Test_version(t *testing.T) { t.Parallel() var ( - ctrl = gomock.NewController(t) - mockDB = dbmock.NewMockStore(ctrl) - clock = quartz.NewMock(t) - logger = slogtest.Make(t, nil) - ctx = testutil.Context(t, testutil.WaitShort) + ctrl = gomock.NewController(t) + mockDB = dbmock.NewMockStore(ctrl) + clock = quartz.NewMock(t) + logger = slogtest.Make(t, nil) + ctx = testutil.Context(t, testutil.WaitShort) + fetcher = &DBFetcher{DB: mockDB, Feature: database.CryptoKeyFeatureWorkspaceApps} ) - expectedKey := database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{}, nil) + + expectedKey := codersdk.CryptoKey{ + Feature: codersdk.CryptoKeyFeatureWorkspaceApp, Sequence: 32, - Secret: sql.NullString{ - String: mustGenerateKey(t), - Valid: true, - }, + Secret: mustGenerateKey(t), } - cache := map[int32]database.CryptoKey{ - 32: expectedKey, + cache := map[int32]codersdk.CryptoKey{ + 32: { + Feature: codersdk.CryptoKeyFeatureWorkspaceApp, + Sequence: 32, + Secret: mustGenerateKey(t), + }, } - k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) + k, err := newDBCache(ctx, logger, fetcher, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) defer k.Close() k.keys = cache - secret, err := k.sequence(ctx, keyID(expectedKey)) + secret, err := k.cryptoKey(ctx, keyID(expectedKey)) require.NoError(t, err) require.Equal(t, decodedSecret(t, expectedKey), secret) }) @@ -64,22 +68,20 @@ func Test_version(t *testing.T) { logger = slogtest.Make(t, nil) ) - expectedKey := database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + expectedKey := codersdk.CryptoKey{ + Feature: codersdk.CryptoKeyFeatureWorkspaceApp, Sequence: 33, - StartsAt: clock.Now(), - Secret: sql.NullString{ - String: mustGenerateKey(t), - Valid: true, - }, + Secret: mustGenerateKey(t), } - mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{expectedKey}, nil) + mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{toDBKey(expectedKey)}, nil) - k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) + fetcher := &DBFetcher{DB: mockDB, Feature: database.CryptoKeyFeatureWorkspaceApps} + + k, err := newDBCache(ctx, logger, fetcher, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) defer k.Close() - got, err := k.sequence(ctx, keyID(expectedKey)) + got, err := k.cryptoKey(ctx, keyID(expectedKey)) require.NoError(t, err) require.Equal(t, decodedSecret(t, expectedKey), got) }) @@ -95,26 +97,21 @@ func Test_version(t *testing.T) { logger = slogtest.Make(t, nil) ) - cache := map[int32]database.CryptoKey{ + cache := map[int32]codersdk.CryptoKey{ 32: { - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: codersdk.CryptoKeyFeatureWorkspaceApp, Sequence: 32, - Secret: sql.NullString{ - String: mustGenerateKey(t), - Valid: true, - }, - DeletesAt: sql.NullTime{ - Time: clock.Now(), - Valid: true, - }, + Secret: mustGenerateKey(t), }, } - k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) + fetcher := &DBFetcher{DB: mockDB, Feature: database.CryptoKeyFeatureWorkspaceApps} + + k, err := newDBCache(ctx, logger, fetcher, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) defer k.Close() k.keys = cache - _, err := k.sequence(ctx, "32") + _, err = k.cryptoKey(ctx, 32) require.ErrorIs(t, err, ErrKeyInvalid) }) @@ -143,333 +140,335 @@ func Test_version(t *testing.T) { } mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{invalidKey}, nil) - k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) - defer k.Close() - - _, err := k.sequence(ctx, keyID(invalidKey)) - require.ErrorIs(t, err, ErrKeyInvalid) - }) -} - -func Test_latest(t *testing.T) { - t.Parallel() - - t.Run("HitsCache", func(t *testing.T) { - t.Parallel() - - var ( - ctrl = gomock.NewController(t) - mockDB = dbmock.NewMockStore(ctrl) - clock = quartz.NewMock(t) - ctx = testutil.Context(t, testutil.WaitShort) - logger = slogtest.Make(t, nil) - ) - - latestKey := database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, - Sequence: 32, - Secret: sql.NullString{ - String: mustGenerateKey(t), - Valid: true, - }, - StartsAt: clock.Now(), - } - k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) - defer k.Close() - - k.latestKey = latestKey - - id, secret, err := k.latest(ctx) - require.NoError(t, err) - require.Equal(t, keyID(latestKey), id) - require.Equal(t, decodedSecret(t, latestKey), secret) - }) - - t.Run("InvalidCachedKey", func(t *testing.T) { - t.Parallel() - - var ( - ctrl = gomock.NewController(t) - mockDB = dbmock.NewMockStore(ctrl) - clock = quartz.NewMock(t) - ctx = testutil.Context(t, testutil.WaitShort) - logger = slogtest.Make(t, nil) - ) - - latestKey := database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, - Sequence: 33, - Secret: sql.NullString{ - String: mustGenerateKey(t), - Valid: true, - }, - StartsAt: clock.Now(), - } - - invalidKey := database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, - Sequence: 32, - Secret: sql.NullString{ - String: mustGenerateKey(t), - Valid: true, - }, - StartsAt: clock.Now().Add(-time.Hour), - DeletesAt: sql.NullTime{ - Time: clock.Now(), - Valid: true, - }, - } - - mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{latestKey}, nil) - - k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) - defer k.Close() - k.latestKey = invalidKey - - id, secret, err := k.latest(ctx) - require.NoError(t, err) - require.Equal(t, keyID(latestKey), id) - require.Equal(t, decodedSecret(t, latestKey), secret) - }) - - t.Run("UsesActiveKey", func(t *testing.T) { - t.Parallel() - - var ( - ctrl = gomock.NewController(t) - mockDB = dbmock.NewMockStore(ctrl) - clock = quartz.NewMock(t) - ctx = testutil.Context(t, testutil.WaitShort) - logger = slogtest.Make(t, nil) - ) - - inactiveKey := database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, - Sequence: 32, - Secret: sql.NullString{ - String: mustGenerateKey(t), - Valid: true, - }, - StartsAt: clock.Now().Add(time.Hour), - } - - activeKey := database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, - Sequence: 33, - Secret: sql.NullString{ - String: mustGenerateKey(t), - Valid: true, - }, - StartsAt: clock.Now(), - } - - mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{inactiveKey, activeKey}, nil) + fetcher := &DBFetcher{DB: mockDB, Feature: database.CryptoKeyFeatureWorkspaceApps} - k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) + k, err := newDBCache(ctx, logger, fetcher, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) defer k.Close() - id, secret, err := k.latest(ctx) - require.NoError(t, err) - require.Equal(t, keyID(activeKey), id) - require.Equal(t, decodedSecret(t, activeKey), secret) - }) - - t.Run("NoValidKeys", func(t *testing.T) { - t.Parallel() - - var ( - ctrl = gomock.NewController(t) - mockDB = dbmock.NewMockStore(ctrl) - clock = quartz.NewMock(t) - ctx = testutil.Context(t, testutil.WaitShort) - logger = slogtest.Make(t, nil) - ) - - inactiveKey := database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, - Sequence: 32, - Secret: sql.NullString{ - String: mustGenerateKey(t), - Valid: true, - }, - StartsAt: clock.Now().Add(time.Hour), - } - - invalidKey := database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, - Sequence: 33, - Secret: sql.NullString{ - String: mustGenerateKey(t), - Valid: true, - }, - StartsAt: clock.Now().Add(-time.Hour), - DeletesAt: sql.NullTime{ - Time: clock.Now(), - Valid: true, - }, - } - - mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{inactiveKey, invalidKey}, nil) - - k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) - defer k.Close() - - _, _, err := k.latest(ctx) + _, err = k.cryptoKey(ctx, invalidKey.Sequence) require.ErrorIs(t, err, ErrKeyInvalid) }) } -func Test_clear(t *testing.T) { - t.Parallel() - - t.Run("InvalidatesCache", func(t *testing.T) { - t.Parallel() - - var ( - ctrl = gomock.NewController(t) - mockDB = dbmock.NewMockStore(ctrl) - clock = quartz.NewMock(t) - ctx = testutil.Context(t, testutil.WaitShort) - logger = slogtest.Make(t, nil) - ) - - k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) - defer k.Close() - - activeKey := database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, - Sequence: 33, - Secret: sql.NullString{ - String: mustGenerateKey(t), - Valid: true, - }, - StartsAt: clock.Now(), - } - - mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{activeKey}, nil) - - _, _, err := k.latest(ctx) - require.NoError(t, err) - - dur, wait := clock.AdvanceNext() - wait.MustWait(ctx) - require.Equal(t, time.Minute*10, dur) - require.Len(t, k.keys, 0) - require.Equal(t, database.CryptoKey{}, k.latestKey) - }) - - t.Run("ResetsTimer", func(t *testing.T) { - t.Parallel() - - var ( - ctrl = gomock.NewController(t) - mockDB = dbmock.NewMockStore(ctrl) - clock = quartz.NewMock(t) - ctx = testutil.Context(t, testutil.WaitShort) - logger = slogtest.Make(t, nil) - ) - - k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) - defer k.Close() - - key := database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, - Sequence: 32, - Secret: sql.NullString{ - String: mustGenerateKey(t), - Valid: true, - }, - StartsAt: clock.Now(), - } - - mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{key}, nil) - - // Advance it five minutes so that we can test that the - // timer is reset and doesn't fire after another five minute. - clock.Advance(time.Minute * 5) - - id, secret, err := k.latest(ctx) - require.NoError(t, err) - require.Equal(t, keyID(key), id) - require.Equal(t, decodedSecret(t, key), secret) - - // Advancing the clock now should require 10 minutes - // before the timer fires again. - dur, wait := clock.AdvanceNext() - wait.MustWait(ctx) - require.Equal(t, time.Minute*10, dur) - require.Len(t, k.keys, 0) - require.Equal(t, database.CryptoKey{}, k.latestKey) - }) - - // InvalidateAt tests that we have accounted for the race condition where a - // timer fires to invalidate the cache at the same time we are fetching new - // keys. In such cases we want to skip invalidation. - t.Run("InvalidateAt", func(t *testing.T) { - t.Parallel() - - var ( - ctrl = gomock.NewController(t) - mockDB = dbmock.NewMockStore(ctrl) - clock = quartz.NewMock(t) - ctx = testutil.Context(t, testutil.WaitShort) - logger = slogtest.Make(t, nil) - ) - - trap := clock.Trap().Now("clear") - - k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) - defer k.Close() - - key := database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, - Sequence: 32, - Secret: sql.NullString{ - String: mustGenerateKey(t), - Valid: true, - }, - StartsAt: clock.Now(), - } - - mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{key}, nil).Times(2) - - // Move us past the initial timer. - id, secret, err := k.latest(ctx) - require.NoError(t, err) - require.Equal(t, keyID(key), id) - require.Equal(t, decodedSecret(t, key), secret) - // Null these out so that we refetch. - k.keys = nil - k.latestKey = database.CryptoKey{} - - // Initiate firing the timer. - dur, wait := clock.AdvanceNext() - require.Equal(t, time.Minute*10, dur) - // Trap the function just before acquiring the mutex. - call := trap.MustWait(ctx) - - // Refetch keys. - id, secret, err = k.latest(ctx) - require.NoError(t, err) - require.Equal(t, keyID(key), id) - require.Equal(t, decodedSecret(t, key), secret) - - // Let the rest of the timer function run. - // It should see that we have refetched keys and - // not invalidate. - call.Release() - wait.MustWait(ctx) - require.Len(t, k.keys, 1) - require.Equal(t, key, k.latestKey) - trap.Close() - - // Refetching the keys should've instantiated a new timer. This one should invalidate keys. - _, wait = clock.AdvanceNext() - wait.MustWait(ctx) - require.Len(t, k.keys, 0) - require.Equal(t, database.CryptoKey{}, k.latestKey) - }) -} +// func Test_latest(t *testing.T) { +// t.Parallel() + +// t.Run("HitsCache", func(t *testing.T) { +// t.Parallel() + +// var ( +// ctrl = gomock.NewController(t) +// mockDB = dbmock.NewMockStore(ctrl) +// clock = quartz.NewMock(t) +// ctx = testutil.Context(t, testutil.WaitShort) +// logger = slogtest.Make(t, nil) +// ) + +// latestKey := database.CryptoKey{ +// Feature: database.CryptoKeyFeatureWorkspaceApps, +// Sequence: 32, +// Secret: sql.NullString{ +// String: mustGenerateKey(t), +// Valid: true, +// }, +// StartsAt: clock.Now(), +// } +// fetcher := &DBFetcher{DB: mockDB, Feature: database.CryptoKeyFeatureWorkspaceApps} + +// k, err := newDBCache(ctx, logger, fetcher, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) +// defer k.Close() + +// id, secret, err := k.latest(ctx) +// require.NoError(t, err) +// require.Equal(t, keyID(latestKey), id) +// require.Equal(t, decodedSecret(t, latestKey), secret) +// }) + +// t.Run("InvalidCachedKey", func(t *testing.T) { +// t.Parallel() + +// var ( +// ctrl = gomock.NewController(t) +// mockDB = dbmock.NewMockStore(ctrl) +// clock = quartz.NewMock(t) +// ctx = testutil.Context(t, testutil.WaitShort) +// logger = slogtest.Make(t, nil) +// ) + +// latestKey := database.CryptoKey{ +// Feature: database.CryptoKeyFeatureWorkspaceApps, +// Sequence: 33, +// Secret: sql.NullString{ +// String: mustGenerateKey(t), +// Valid: true, +// }, +// StartsAt: clock.Now(), +// } + +// invalidKey := database.CryptoKey{ +// Feature: database.CryptoKeyFeatureWorkspaceApps, +// Sequence: 32, +// Secret: sql.NullString{ +// String: mustGenerateKey(t), +// Valid: true, +// }, +// StartsAt: clock.Now().Add(-time.Hour), +// DeletesAt: sql.NullTime{ +// Time: clock.Now(), +// Valid: true, +// }, +// } + +// mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{latestKey}, nil) + +// k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) +// defer k.Close() +// k.latestKey = invalidKey + +// id, secret, err := k.latest(ctx) +// require.NoError(t, err) +// require.Equal(t, keyID(latestKey), id) +// require.Equal(t, decodedSecret(t, latestKey), secret) +// }) + +// t.Run("UsesActiveKey", func(t *testing.T) { +// t.Parallel() + +// var ( +// ctrl = gomock.NewController(t) +// mockDB = dbmock.NewMockStore(ctrl) +// clock = quartz.NewMock(t) +// ctx = testutil.Context(t, testutil.WaitShort) +// logger = slogtest.Make(t, nil) +// ) + +// inactiveKey := database.CryptoKey{ +// Feature: database.CryptoKeyFeatureWorkspaceApps, +// Sequence: 32, +// Secret: sql.NullString{ +// String: mustGenerateKey(t), +// Valid: true, +// }, +// StartsAt: clock.Now().Add(time.Hour), +// } + +// activeKey := database.CryptoKey{ +// Feature: database.CryptoKeyFeatureWorkspaceApps, +// Sequence: 33, +// Secret: sql.NullString{ +// String: mustGenerateKey(t), +// Valid: true, +// }, +// StartsAt: clock.Now(), +// } + +// mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{inactiveKey, activeKey}, nil) + +// k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) +// defer k.Close() + +// id, secret, err := k.latest(ctx) +// require.NoError(t, err) +// require.Equal(t, keyID(activeKey), id) +// require.Equal(t, decodedSecret(t, activeKey), secret) +// }) + +// t.Run("NoValidKeys", func(t *testing.T) { +// t.Parallel() + +// var ( +// ctrl = gomock.NewController(t) +// mockDB = dbmock.NewMockStore(ctrl) +// clock = quartz.NewMock(t) +// ctx = testutil.Context(t, testutil.WaitShort) +// logger = slogtest.Make(t, nil) +// ) + +// inactiveKey := database.CryptoKey{ +// Feature: database.CryptoKeyFeatureWorkspaceApps, +// Sequence: 32, +// Secret: sql.NullString{ +// String: mustGenerateKey(t), +// Valid: true, +// }, +// StartsAt: clock.Now().Add(time.Hour), +// } + +// invalidKey := database.CryptoKey{ +// Feature: database.CryptoKeyFeatureWorkspaceApps, +// Sequence: 33, +// Secret: sql.NullString{ +// String: mustGenerateKey(t), +// Valid: true, +// }, +// StartsAt: clock.Now().Add(-time.Hour), +// DeletesAt: sql.NullTime{ +// Time: clock.Now(), +// Valid: true, +// }, +// } + +// mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{inactiveKey, invalidKey}, nil) + +// k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) +// defer k.Close() + +// _, _, err := k.latest(ctx) +// require.ErrorIs(t, err, ErrKeyInvalid) +// }) +// } + +// func Test_clear(t *testing.T) { +// t.Parallel() + +// t.Run("InvalidatesCache", func(t *testing.T) { +// t.Parallel() + +// var ( +// ctrl = gomock.NewController(t) +// mockDB = dbmock.NewMockStore(ctrl) +// clock = quartz.NewMock(t) +// ctx = testutil.Context(t, testutil.WaitShort) +// logger = slogtest.Make(t, nil) +// ) + +// k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) +// defer k.Close() + +// activeKey := database.CryptoKey{ +// Feature: database.CryptoKeyFeatureWorkspaceApps, +// Sequence: 33, +// Secret: sql.NullString{ +// String: mustGenerateKey(t), +// Valid: true, +// }, +// StartsAt: clock.Now(), +// } + +// mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{activeKey}, nil) + +// _, _, err := k.latest(ctx) +// require.NoError(t, err) + +// dur, wait := clock.AdvanceNext() +// wait.MustWait(ctx) +// require.Equal(t, time.Minute*10, dur) +// require.Len(t, k.keys, 0) +// require.Equal(t, database.CryptoKey{}, k.latestKey) +// }) + +// t.Run("ResetsTimer", func(t *testing.T) { +// t.Parallel() + +// var ( +// ctrl = gomock.NewController(t) +// mockDB = dbmock.NewMockStore(ctrl) +// clock = quartz.NewMock(t) +// ctx = testutil.Context(t, testutil.WaitShort) +// logger = slogtest.Make(t, nil) +// ) + +// k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) +// defer k.Close() + +// key := database.CryptoKey{ +// Feature: database.CryptoKeyFeatureWorkspaceApps, +// Sequence: 32, +// Secret: sql.NullString{ +// String: mustGenerateKey(t), +// Valid: true, +// }, +// StartsAt: clock.Now(), +// } + +// mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{key}, nil) + +// // Advance it five minutes so that we can test that the +// // timer is reset and doesn't fire after another five minute. +// clock.Advance(time.Minute * 5) + +// id, secret, err := k.latest(ctx) +// require.NoError(t, err) +// require.Equal(t, keyID(key), id) +// require.Equal(t, decodedSecret(t, key), secret) + +// // Advancing the clock now should require 10 minutes +// // before the timer fires again. +// dur, wait := clock.AdvanceNext() +// wait.MustWait(ctx) +// require.Equal(t, time.Minute*10, dur) +// require.Len(t, k.keys, 0) +// require.Equal(t, database.CryptoKey{}, k.latestKey) +// }) + +// // InvalidateAt tests that we have accounted for the race condition where a +// // timer fires to invalidate the cache at the same time we are fetching new +// // keys. In such cases we want to skip invalidation. +// t.Run("InvalidateAt", func(t *testing.T) { +// t.Parallel() + +// var ( +// ctrl = gomock.NewController(t) +// mockDB = dbmock.NewMockStore(ctrl) +// clock = quartz.NewMock(t) +// ctx = testutil.Context(t, testutil.WaitShort) +// logger = slogtest.Make(t, nil) +// ) + +// trap := clock.Trap().Now("clear") + +// k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) +// defer k.Close() + +// key := database.CryptoKey{ +// Feature: database.CryptoKeyFeatureWorkspaceApps, +// Sequence: 32, +// Secret: sql.NullString{ +// String: mustGenerateKey(t), +// Valid: true, +// }, +// StartsAt: clock.Now(), +// } + +// mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{key}, nil).Times(2) + +// // Move us past the initial timer. +// id, secret, err := k.latest(ctx) +// require.NoError(t, err) +// require.Equal(t, keyID(key), id) +// require.Equal(t, decodedSecret(t, key), secret) +// // Null these out so that we refetch. +// k.keys = nil +// k.latestKey = database.CryptoKey{} + +// // Initiate firing the timer. +// dur, wait := clock.AdvanceNext() +// require.Equal(t, time.Minute*10, dur) +// // Trap the function just before acquiring the mutex. +// call := trap.MustWait(ctx) + +// // Refetch keys. +// id, secret, err = k.latest(ctx) +// require.NoError(t, err) +// require.Equal(t, keyID(key), id) +// require.Equal(t, decodedSecret(t, key), secret) + +// // Let the rest of the timer function run. +// // It should see that we have refetched keys and +// // not invalidate. +// call.Release() +// wait.MustWait(ctx) +// require.Len(t, k.keys, 1) +// require.Equal(t, key, k.latestKey) +// trap.Close() + +// // Refetching the keys should've instantiated a new timer. This one should invalidate keys. +// _, wait = clock.AdvanceNext() +// wait.MustWait(ctx) +// require.Len(t, k.keys, 0) +// require.Equal(t, database.CryptoKey{}, k.latestKey) +// }) +// } func mustGenerateKey(t *testing.T) string { t.Helper() @@ -478,13 +477,24 @@ func mustGenerateKey(t *testing.T) string { return key } -func keyID(key database.CryptoKey) string { - return strconv.FormatInt(int64(key.Sequence), 10) +func keyID(key codersdk.CryptoKey) int32 { + return key.Sequence } -func decodedSecret(t *testing.T, key database.CryptoKey) []byte { +func decodedSecret(t *testing.T, key codersdk.CryptoKey) []byte { t.Helper() - decoded, err := key.DecodeString() + decoded, err := hex.DecodeString(key.Secret) require.NoError(t, err) return decoded } + +func toDBKey(key codersdk.CryptoKey) database.CryptoKey { + return database.CryptoKey{ + Feature: database.CryptoKeyFeatureWorkspaceApps, + Sequence: key.Sequence, + Secret: sql.NullString{ + String: key.Secret, + Valid: key.Secret != "", + }, + } +} diff --git a/coderd/cryptokeys/dbkeycache_test.go b/coderd/cryptokeys/dbkeycache_test.go index e24ef16660db1..cd68a196c493f 100644 --- a/coderd/cryptokeys/dbkeycache_test.go +++ b/coderd/cryptokeys/dbkeycache_test.go @@ -31,10 +31,11 @@ func TestDBKeyCache(t *testing.T) { t.Parallel() var ( - db, _ = dbtestutil.NewDB(t) - clock = quartz.NewMock(t) - ctx = testutil.Context(t, testutil.WaitShort) - logger = slogtest.Make(t, nil) + db, _ = dbtestutil.NewDB(t) + clock = quartz.NewMock(t) + ctx = testutil.Context(t, testutil.WaitShort) + logger = slogtest.Make(t, nil) + fetcher = &cryptokeys.DBFetcher{DB: db, Feature: database.CryptoKeyFeatureOidcConvert} ) key := dbgen.CryptoKey(t, db, database.CryptoKey{ @@ -43,7 +44,7 @@ func TestDBKeyCache(t *testing.T) { StartsAt: clock.Now().UTC(), }) - k, err := cryptokeys.NewSigningCache(logger, db, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock)) + k, err := cryptokeys.NewSigningCache(ctx, logger, fetcher, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock)) require.NoError(t, err) defer k.Close() @@ -56,13 +57,14 @@ func TestDBKeyCache(t *testing.T) { t.Parallel() var ( - db, _ = dbtestutil.NewDB(t) - clock = quartz.NewMock(t) - ctx = testutil.Context(t, testutil.WaitShort) - logger = slogtest.Make(t, nil) + db, _ = dbtestutil.NewDB(t) + clock = quartz.NewMock(t) + ctx = testutil.Context(t, testutil.WaitShort) + logger = slogtest.Make(t, nil) + fetcher = &cryptokeys.DBFetcher{DB: db, Feature: database.CryptoKeyFeatureOidcConvert} ) - k, err := cryptokeys.NewSigningCache(logger, db, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock)) + k, err := cryptokeys.NewSigningCache(ctx, logger, fetcher, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock)) require.NoError(t, err) defer k.Close() @@ -75,10 +77,11 @@ func TestDBKeyCache(t *testing.T) { t.Parallel() var ( - db, _ = dbtestutil.NewDB(t) - clock = quartz.NewMock(t) - ctx = testutil.Context(t, testutil.WaitShort) - logger = slogtest.Make(t, nil) + db, _ = dbtestutil.NewDB(t) + clock = quartz.NewMock(t) + ctx = testutil.Context(t, testutil.WaitShort) + logger = slogtest.Make(t, nil) + fetcher = &cryptokeys.DBFetcher{DB: db, Feature: database.CryptoKeyFeatureOidcConvert} ) _ = dbgen.CryptoKey(t, db, database.CryptoKey{ @@ -99,7 +102,7 @@ func TestDBKeyCache(t *testing.T) { StartsAt: clock.Now().UTC(), }) - k, err := cryptokeys.NewSigningCache(logger, db, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock)) + k, err := cryptokeys.NewSigningCache(ctx, logger, fetcher, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock)) require.NoError(t, err) defer k.Close() @@ -113,10 +116,11 @@ func TestDBKeyCache(t *testing.T) { t.Parallel() var ( - db, _ = dbtestutil.NewDB(t) - clock = quartz.NewMock(t) - ctx = testutil.Context(t, testutil.WaitShort) - logger = slogtest.Make(t, nil) + db, _ = dbtestutil.NewDB(t) + clock = quartz.NewMock(t) + ctx = testutil.Context(t, testutil.WaitShort) + logger = slogtest.Make(t, nil) + fetcher = &cryptokeys.DBFetcher{DB: db, Feature: database.CryptoKeyFeatureOidcConvert} ) expectedKey := dbgen.CryptoKey(t, db, database.CryptoKey{ @@ -125,7 +129,7 @@ func TestDBKeyCache(t *testing.T) { StartsAt: clock.Now(), }) - k, err := cryptokeys.NewSigningCache(logger, db, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock)) + k, err := cryptokeys.NewSigningCache(ctx, logger, fetcher, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock)) require.NoError(t, err) defer k.Close() @@ -151,17 +155,18 @@ func TestDBKeyCache(t *testing.T) { t.Parallel() var ( - db, _ = dbtestutil.NewDB(t) - clock = quartz.NewMock(t) - logger = slogtest.Make(t, nil) - ctx = testutil.Context(t, testutil.WaitShort) + db, _ = dbtestutil.NewDB(t) + clock = quartz.NewMock(t) + logger = slogtest.Make(t, nil) + fetcher = &cryptokeys.DBFetcher{DB: db, Feature: database.CryptoKeyFeatureOidcConvert} + ctx = testutil.Context(t, testutil.WaitShort) ) - _, err := cryptokeys.NewSigningCache(logger, db, database.CryptoKeyFeatureWorkspaceApps, cryptokeys.WithDBCacheClock(clock)) + _, err := cryptokeys.NewSigningCache(ctx, logger, fetcher, database.CryptoKeyFeatureWorkspaceApps, cryptokeys.WithDBCacheClock(clock)) require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature) // Instantiate a signing cache and try to use it as an encryption cache. - sc, err := cryptokeys.NewSigningCache(logger, db, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock)) + sc, err := cryptokeys.NewSigningCache(ctx, logger, fetcher, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock)) require.NoError(t, err) defer sc.Close() @@ -178,17 +183,18 @@ func TestDBKeyCache(t *testing.T) { t.Parallel() var ( - db, _ = dbtestutil.NewDB(t) - clock = quartz.NewMock(t) - logger = slogtest.Make(t, nil) - ctx = testutil.Context(t, testutil.WaitShort) + db, _ = dbtestutil.NewDB(t) + clock = quartz.NewMock(t) + logger = slogtest.Make(t, nil) + fetcher = &cryptokeys.DBFetcher{DB: db, Feature: database.CryptoKeyFeatureOidcConvert} + ctx = testutil.Context(t, testutil.WaitShort) ) - _, err := cryptokeys.NewEncryptionCache(logger, db, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock)) + _, err := cryptokeys.NewEncryptionCache(ctx, logger, fetcher, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock)) require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature) // Instantiate an encryption cache and try to use it as a signing cache. - ec, err := cryptokeys.NewEncryptionCache(logger, db, database.CryptoKeyFeatureWorkspaceApps, cryptokeys.WithDBCacheClock(clock)) + ec, err := cryptokeys.NewEncryptionCache(ctx, logger, fetcher, database.CryptoKeyFeatureWorkspaceApps, cryptokeys.WithDBCacheClock(clock)) require.NoError(t, err) defer ec.Close() diff --git a/coderd/cryptokeys/keycache.go b/coderd/cryptokeys/keycache.go index 05c80a15b2378..076256448d659 100644 --- a/coderd/cryptokeys/keycache.go +++ b/coderd/cryptokeys/keycache.go @@ -5,6 +5,8 @@ import ( "io" "golang.org/x/xerrors" + + "github.com/coder/coder/v2/codersdk" ) var ( @@ -14,6 +16,10 @@ var ( ErrInvalidFeature = xerrors.New("invalid feature for this operation") ) +type Fetcher interface { + Fetch(ctx context.Context) ([]codersdk.CryptoKey, error) +} + type EncryptionKeycache interface { // EncryptingKey returns the latest valid key for encrypting payloads. A valid // key is one that is both past its start time and before its deletion time. diff --git a/coderd/jwtutils/jwt_test.go b/coderd/jwtutils/jwt_test.go index ff30f7716b310..46bb45a2cecd5 100644 --- a/coderd/jwtutils/jwt_test.go +++ b/coderd/jwtutils/jwt_test.go @@ -238,10 +238,11 @@ func TestJWS(t *testing.T) { Feature: database.CryptoKeyFeatureOidcConvert, StartsAt: time.Now(), }) - log = slogtest.Make(t, nil) + log = slogtest.Make(t, nil) + fetcher = &cryptokeys.DBFetcher{DB: db, Feature: database.CryptoKeyFeatureOidcConvert} ) - cache, err := cryptokeys.NewSigningCache(log, db, database.CryptoKeyFeatureOidcConvert) + cache, err := cryptokeys.NewSigningCache(ctx, log, fetcher, database.CryptoKeyFeatureOidcConvert) require.NoError(t, err) claims := testClaims{ @@ -328,9 +329,11 @@ func TestJWE(t *testing.T) { StartsAt: time.Now(), }) log = slogtest.Make(t, nil) + + fetcher = &cryptokeys.DBFetcher{DB: db, Feature: database.CryptoKeyFeatureOidcConvert} ) - cache, err := cryptokeys.NewEncryptionCache(log, db, database.CryptoKeyFeatureWorkspaceApps) + cache, err := cryptokeys.NewEncryptionCache(ctx, log, fetcher, database.CryptoKeyFeatureWorkspaceApps) require.NoError(t, err) claims := testClaims{ From 50bb6e85f0f63bac1925b6f3790351c8958620ad Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Tue, 15 Oct 2024 17:29:06 +0000 Subject: [PATCH 2/7] wip --- coderd/cryptokeys/dbkeycache.go | 16 +- coderd/cryptokeys/dbkeycache_internal_test.go | 837 +++++++++--------- enterprise/wsproxy/keycache_test.go | 485 ---------- 3 files changed, 420 insertions(+), 918 deletions(-) diff --git a/coderd/cryptokeys/dbkeycache.go b/coderd/cryptokeys/dbkeycache.go index c2fb3b55372a0..26941192619b2 100644 --- a/coderd/cryptokeys/dbkeycache.go +++ b/coderd/cryptokeys/dbkeycache.go @@ -44,7 +44,7 @@ type CryptoKeyCache struct { refreshCancel context.CancelFunc fetcher Fetcher logger slog.Logger - feature database.CryptoKeyFeature + feature codersdk.CryptoKeyFeature mu sync.Mutex keys map[int32]codersdk.CryptoKey @@ -65,7 +65,7 @@ func WithDBCacheClock(clock quartz.Clock) DBCacheOption { // NewSigningCache creates a new DBCache. Close should be called to // release resources associated with its internal timer. -func NewSigningCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, feature database.CryptoKeyFeature, opts ...func(*CryptoKeyCache)) (SigningKeycache, error) { +func NewSigningCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, feature codersdk.CryptoKeyFeature, opts ...func(*CryptoKeyCache)) (SigningKeycache, error) { if !isSigningKeyFeature(feature) { return nil, ErrInvalidFeature } @@ -73,7 +73,7 @@ func NewSigningCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, f return newDBCache(ctx, logger, fetcher, feature, opts...) } -func NewEncryptionCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, feature database.CryptoKeyFeature, opts ...func(*CryptoKeyCache)) (EncryptionKeycache, error) { +func NewEncryptionCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, feature codersdk.CryptoKeyFeature, opts ...func(*CryptoKeyCache)) (EncryptionKeycache, error) { if !isEncryptionKeyFeature(feature) { return nil, ErrInvalidFeature } @@ -81,7 +81,7 @@ func NewEncryptionCache(ctx context.Context, logger slog.Logger, fetcher Fetcher return newDBCache(ctx, logger, fetcher, feature, opts...) } -func newDBCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, feature database.CryptoKeyFeature, opts ...func(*CryptoKeyCache)) (*CryptoKeyCache, error) { +func newDBCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, feature codersdk.CryptoKeyFeature, opts ...func(*CryptoKeyCache)) (*CryptoKeyCache, error) { cache := &CryptoKeyCache{ clock: quartz.NewReal(), logger: logger, @@ -178,13 +178,13 @@ func (d *CryptoKeyCache) VerifyingKey(ctx context.Context, sequence string) (int return key.Secret, nil } -func isEncryptionKeyFeature(feature database.CryptoKeyFeature) bool { - return feature == database.CryptoKeyFeatureWorkspaceApps +func isEncryptionKeyFeature(feature codersdk.CryptoKeyFeature) bool { + return feature == codersdk.CryptoKeyFeatureWorkspaceApp } -func isSigningKeyFeature(feature database.CryptoKeyFeature) bool { +func isSigningKeyFeature(feature codersdk.CryptoKeyFeature) bool { switch feature { - case database.CryptoKeyFeatureTailnetResume, database.CryptoKeyFeatureOidcConvert: + case codersdk.CryptoKeyFeatureTailnetResume, codersdk.CryptoKeyFeatureOIDCConvert: return true default: return false diff --git a/coderd/cryptokeys/dbkeycache_internal_test.go b/coderd/cryptokeys/dbkeycache_internal_test.go index 4a01c0bc0d05e..73b647c6e61bf 100644 --- a/coderd/cryptokeys/dbkeycache_internal_test.go +++ b/coderd/cryptokeys/dbkeycache_internal_test.go @@ -1,500 +1,487 @@ package cryptokeys import ( - "database/sql" - "encoding/hex" + "context" "testing" + "time" "github.com/stretchr/testify/require" - "go.uber.org/mock/gomock" "cdr.dev/slog/sloggers/slogtest" - "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/dbmock" "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/enterprise/wsproxy" "github.com/coder/coder/v2/testutil" "github.com/coder/quartz" ) -func Test_version(t *testing.T) { +func TestCryptoKeyCache(t *testing.T) { t.Parallel() - t.Run("HitsCache", func(t *testing.T) { + t.Run("Signing", func(t *testing.T) { t.Parallel() - var ( - ctrl = gomock.NewController(t) - mockDB = dbmock.NewMockStore(ctrl) - clock = quartz.NewMock(t) - logger = slogtest.Make(t, nil) - ctx = testutil.Context(t, testutil.WaitShort) - fetcher = &DBFetcher{DB: mockDB, Feature: database.CryptoKeyFeatureWorkspaceApps} - ) + t.Run("HitsCache", func(t *testing.T) { + t.Parallel() + var ( + ctx = testutil.Context(t, testutil.WaitShort) + logger = slogtest.Make(t, nil) + clock = quartz.NewMock(t) + ) - mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{}, nil) + now := clock.Now().UTC() + expected := codersdk.CryptoKey{ + Feature: codersdk.CryptoKeyFeatureWorkspaceApp, + Secret: "key2", + Sequence: 2, + StartsAt: now, + } + + ff := &fakeFetcher{ + keys: []codersdk.CryptoKey{expected}, + } + + cache, err := NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureWorkspaceApp, WithDBCacheClock(clock)) + require.NoError(t, err) + + id, got, err := cache.SigningKey(ctx) + require.NoError(t, err) + require.Equal(t, expected, got) + require.Equal(t, 1, ff.called) + require.Equal(t, "2", id) + }) + + t.Run("MissesCache", func(t *testing.T) { + t.Parallel() + var ( + ctx = testutil.Context(t, testutil.WaitShort) + logger = slogtest.Make(t, nil) + clock = quartz.NewMock(t) + ) + + ff := &fakeFetcher{ + keys: []codersdk.CryptoKey{}, + } + + cache, err := NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureWorkspaceApp, WithDBCacheClock(clock)) + require.NoError(t, err) + + expected := codersdk.CryptoKey{ + Feature: codersdk.CryptoKeyFeatureWorkspaceApp, + Secret: "key1", + Sequence: 12, + StartsAt: clock.Now().UTC(), + } + ff.keys = []codersdk.CryptoKey{expected} + + id, got, err := cache.SigningKey(ctx) + require.NoError(t, err) + require.Equal(t, expected, got) + require.Equal(t, "12", id) + // 1 on startup + missing cache. + require.Equal(t, 2, ff.called) + + // Ensure the cache gets hit this time. + id, got, err = cache.SigningKey(ctx) + require.NoError(t, err) + require.Equal(t, expected, got) + require.Equal(t, "12", id) + // 1 on startup + missing cache. + require.Equal(t, 2, ff.called) + }) + + t.Run("IgnoresInvalid", func(t *testing.T) { + t.Parallel() + + var ( + ctx = testutil.Context(t, testutil.WaitShort) + logger = slogtest.Make(t, nil) + clock = quartz.NewMock(t) + ) + now := clock.Now().UTC() + expected := codersdk.CryptoKey{ + Feature: codersdk.CryptoKeyFeatureWorkspaceApp, + Secret: "key1", + Sequence: 1, + StartsAt: clock.Now().UTC(), + } + + ff := &fakeFetcher{ + keys: []codersdk.CryptoKey{ + expected, + { + Feature: codersdk.CryptoKeyFeatureWorkspaceApp, + Secret: "key2", + Sequence: 2, + StartsAt: now.Add(-time.Second), + DeletesAt: now, + }, + }, + } + + cache, err := NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureWorkspaceApp, WithDBCacheClock(clock)) + require.NoError(t, err) + + id, got, err := cache.SigningKey(ctx) + require.NoError(t, err) + require.Equal(t, expected, got) + require.Equal(t, "2", id) + require.Equal(t, 1, ff.called) + }) + + t.Run("KeyNotFound", func(t *testing.T) { + t.Parallel() + + var ( + ctx = testutil.Context(t, testutil.WaitShort) + logger = slogtest.Make(t, nil) + ) + + ff := &fakeFetcher{ + keys: []codersdk.CryptoKey{}, + } + + cache, err := NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureWorkspaceApp) + require.NoError(t, err) + + _, _, err = cache.SigningKey(ctx) + require.ErrorIs(t, err, ErrKeyNotFound) + }) + }) - expectedKey := codersdk.CryptoKey{ - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Sequence: 32, - Secret: mustGenerateKey(t), - } + t.Run("Verifying", func(t *testing.T) { + t.Parallel() - cache := map[int32]codersdk.CryptoKey{ - 32: { - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Sequence: 32, - Secret: mustGenerateKey(t), - }, - } + t.Run("HitsCache", func(t *testing.T) { + t.Parallel() - k, err := newDBCache(ctx, logger, fetcher, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) - defer k.Close() - k.keys = cache + var ( + ctx = testutil.Context(t, testutil.WaitShort) + logger = slogtest.Make(t, nil) + clock = quartz.NewMock(t) + ) - secret, err := k.cryptoKey(ctx, keyID(expectedKey)) - require.NoError(t, err) - require.Equal(t, decodedSecret(t, expectedKey), secret) + now := clock.Now().UTC() + expected := codersdk.CryptoKey{ + Feature: codersdk.CryptoKeyFeatureWorkspaceApp, + Secret: "key1", + Sequence: 12, + StartsAt: now, + } + ff := &fakeFetcher{ + keys: []codersdk.CryptoKey{ + expected, + { + Feature: codersdk.CryptoKeyFeatureWorkspaceApp, + Secret: "key2", + Sequence: 13, + StartsAt: now, + }, + }, + } + + cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) + require.NoError(t, err) + + got, err := cache.Verifying(ctx, expected.Sequence) + require.NoError(t, err) + require.Equal(t, expected, got) + require.Equal(t, 1, ff.called) + }) + + t.Run("MissesCache", func(t *testing.T) { + t.Parallel() + var ( + ctx = testutil.Context(t, testutil.WaitShort) + logger = slogtest.Make(t, nil) + clock = quartz.NewMock(t) + ) + + ff := &fakeFetcher{ + keys: []codersdk.CryptoKey{}, + } + + cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) + require.NoError(t, err) + + expected := codersdk.CryptoKey{ + Feature: codersdk.CryptoKeyFeatureWorkspaceApp, + Secret: "key1", + Sequence: 12, + StartsAt: clock.Now().UTC(), + } + ff.keys = []codersdk.CryptoKey{expected} + + got, err := cache.Verifying(ctx, expected.Sequence) + require.NoError(t, err) + require.Equal(t, expected, got) + require.Equal(t, 2, ff.called) + + // Ensure the cache gets hit this time. + got, err = cache.Verifying(ctx, expected.Sequence) + require.NoError(t, err) + require.Equal(t, expected, got) + require.Equal(t, 2, ff.called) + }) + + t.Run("AllowsBeforeStartsAt", func(t *testing.T) { + t.Parallel() + + var ( + ctx = testutil.Context(t, testutil.WaitShort) + logger = slogtest.Make(t, nil) + clock = quartz.NewMock(t) + ) + + now := clock.Now().UTC() + expected := codersdk.CryptoKey{ + Feature: codersdk.CryptoKeyFeatureWorkspaceApp, + Secret: "key1", + Sequence: 12, + StartsAt: now.Add(-time.Second), + } + + ff := &fakeFetcher{ + keys: []codersdk.CryptoKey{ + expected, + }, + } + + cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) + require.NoError(t, err) + + got, err := cache.Verifying(ctx, expected.Sequence) + require.NoError(t, err) + require.Equal(t, expected, got) + require.Equal(t, 1, ff.called) + }) + + t.Run("KeyInvalid", func(t *testing.T) { + t.Parallel() + + var ( + ctx = testutil.Context(t, testutil.WaitShort) + logger = slogtest.Make(t, nil) + clock = quartz.NewMock(t) + ) + + now := clock.Now().UTC() + expected := codersdk.CryptoKey{ + Feature: codersdk.CryptoKeyFeatureWorkspaceApp, + Secret: "key1", + Sequence: 12, + StartsAt: now.Add(-time.Second), + DeletesAt: now, + } + + ff := &fakeFetcher{ + keys: []codersdk.CryptoKey{ + expected, + }, + } + + cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) + require.NoError(t, err) + + _, err = cache.Verifying(ctx, expected.Sequence) + require.ErrorIs(t, err, ErrKeyInvalid) + require.Equal(t, 1, ff.called) + }) + + t.Run("KeyNotFound", func(t *testing.T) { + t.Parallel() + + var ( + ctx = testutil.Context(t, testutil.WaitShort) + logger = slogtest.Make(t, nil) + clock = quartz.NewMock(t) + ) + + ff := &fakeFetcher{ + keys: []codersdk.CryptoKey{}, + } + + cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) + require.NoError(t, err) + + _, err = cache.Verifying(ctx, 1) + require.ErrorIs(t, err, ErrKeyNotFound) + }) }) - t.Run("MissesCache", func(t *testing.T) { + t.Run("CacheRefreshes", func(t *testing.T) { t.Parallel() var ( - ctrl = gomock.NewController(t) - mockDB = dbmock.NewMockStore(ctrl) - clock = quartz.NewMock(t) ctx = testutil.Context(t, testutil.WaitShort) logger = slogtest.Make(t, nil) + clock = quartz.NewMock(t) ) - expectedKey := codersdk.CryptoKey{ - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Sequence: 33, - Secret: mustGenerateKey(t), + now := clock.Now().UTC() + expected := codersdk.CryptoKey{ + Feature: codersdk.CryptoKeyFeatureWorkspaceApp, + Secret: "key1", + Sequence: 12, + StartsAt: now, + DeletesAt: now.Add(time.Minute * 10), + } + ff := &fakeFetcher{ + keys: []codersdk.CryptoKey{ + expected, + }, } - mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{toDBKey(expectedKey)}, nil) + cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) + require.NoError(t, err) - fetcher := &DBFetcher{DB: mockDB, Feature: database.CryptoKeyFeatureWorkspaceApps} + got, err := cache.Signing(ctx) + require.NoError(t, err) + require.Equal(t, expected, got) + require.Equal(t, 1, ff.called) - k, err := newDBCache(ctx, logger, fetcher, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) - defer k.Close() + newKey := codersdk.CryptoKey{ + Feature: codersdk.CryptoKeyFeatureWorkspaceApp, + Secret: "key2", + Sequence: 13, + StartsAt: now, + } + ff.keys = []codersdk.CryptoKey{newKey} - got, err := k.cryptoKey(ctx, keyID(expectedKey)) + // The ticker should fire and cause a request to coderd. + dur, advance := clock.AdvanceNext() + advance.MustWait(ctx) + require.Equal(t, 2, ff.called) + require.Equal(t, time.Minute*10, dur) + + // Assert hits cache. + got, err = cache.Signing(ctx) require.NoError(t, err) - require.Equal(t, decodedSecret(t, expectedKey), got) + require.Equal(t, newKey, got) + require.Equal(t, 2, ff.called) + + // We check again to ensure the timer has been reset. + _, advance = clock.AdvanceNext() + advance.MustWait(ctx) + require.Equal(t, 3, ff.called) + require.Equal(t, time.Minute*10, dur) }) - t.Run("InvalidCachedKey", func(t *testing.T) { + // This test ensures that if the refresh timer races with an inflight request + // and loses that it doesn't cause a redundant fetch. + + t.Run("RefreshNoDoubleFetch", func(t *testing.T) { t.Parallel() var ( - ctrl = gomock.NewController(t) - mockDB = dbmock.NewMockStore(ctrl) - clock = quartz.NewMock(t) ctx = testutil.Context(t, testutil.WaitShort) logger = slogtest.Make(t, nil) + clock = quartz.NewMock(t) ) - cache := map[int32]codersdk.CryptoKey{ - 32: { - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Sequence: 32, - Secret: mustGenerateKey(t), + now := clock.Now().UTC() + expected := codersdk.CryptoKey{ + Feature: codersdk.CryptoKeyFeatureWorkspaceApp, + Secret: "key1", + Sequence: 12, + StartsAt: now, + DeletesAt: now.Add(time.Minute * 10), + } + ff := &fakeFetcher{ + keys: []codersdk.CryptoKey{ + expected, }, } - fetcher := &DBFetcher{DB: mockDB, Feature: database.CryptoKeyFeatureWorkspaceApps} + // Create a trap that blocks when the refresh timer fires. + trap := clock.Trap().Now("refresh") + cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) + require.NoError(t, err) + + _, wait := clock.AdvanceNext() + trapped := trap.MustWait(ctx) - k, err := newDBCache(ctx, logger, fetcher, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) - defer k.Close() - k.keys = cache + newKey := codersdk.CryptoKey{ + Feature: codersdk.CryptoKeyFeatureWorkspaceApp, + Secret: "key2", + Sequence: 13, + StartsAt: now, + } + ff.keys = []codersdk.CryptoKey{newKey} - _, err = k.cryptoKey(ctx, 32) - require.ErrorIs(t, err, ErrKeyInvalid) + _, err = cache.Verifying(ctx, newKey.Sequence) + require.NoError(t, err) + require.Equal(t, 2, ff.called) + + trapped.Release() + wait.MustWait(ctx) + require.Equal(t, 2, ff.called) + trap.Close() + + // The next timer should fire in 10 minutes. + dur, wait := clock.AdvanceNext() + wait.MustWait(ctx) + require.Equal(t, time.Minute*10, dur) + require.Equal(t, 3, ff.called) }) - t.Run("InvalidDBKey", func(t *testing.T) { + t.Run("Closed", func(t *testing.T) { t.Parallel() var ( - ctrl = gomock.NewController(t) - mockDB = dbmock.NewMockStore(ctrl) - clock = quartz.NewMock(t) ctx = testutil.Context(t, testutil.WaitShort) logger = slogtest.Make(t, nil) + clock = quartz.NewMock(t) ) - invalidKey := database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, - Sequence: 32, - Secret: sql.NullString{ - String: mustGenerateKey(t), - Valid: true, - }, - DeletesAt: sql.NullTime{ - Time: clock.Now(), - Valid: true, + now := clock.Now() + expected := codersdk.CryptoKey{ + Feature: codersdk.CryptoKeyFeatureWorkspaceApp, + Secret: "key1", + Sequence: 12, + StartsAt: now, + } + ff := &fakeFetcher{ + keys: []codersdk.CryptoKey{ + expected, }, } - mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{invalidKey}, nil) - fetcher := &DBFetcher{DB: mockDB, Feature: database.CryptoKeyFeatureWorkspaceApps} + cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) + require.NoError(t, err) + + got, err := cache.Signing(ctx) + require.NoError(t, err) + require.Equal(t, expected, got) + require.Equal(t, 1, ff.called) - k, err := newDBCache(ctx, logger, fetcher, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) - defer k.Close() + got, err = cache.Verifying(ctx, expected.Sequence) + require.NoError(t, err) + require.Equal(t, expected, got) + require.Equal(t, 1, ff.called) - _, err = k.cryptoKey(ctx, invalidKey.Sequence) - require.ErrorIs(t, err, ErrKeyInvalid) - }) -} + cache.Close() + + _, err = cache.Signing(ctx) + require.ErrorIs(t, err, ErrClosed) -// func Test_latest(t *testing.T) { -// t.Parallel() - -// t.Run("HitsCache", func(t *testing.T) { -// t.Parallel() - -// var ( -// ctrl = gomock.NewController(t) -// mockDB = dbmock.NewMockStore(ctrl) -// clock = quartz.NewMock(t) -// ctx = testutil.Context(t, testutil.WaitShort) -// logger = slogtest.Make(t, nil) -// ) - -// latestKey := database.CryptoKey{ -// Feature: database.CryptoKeyFeatureWorkspaceApps, -// Sequence: 32, -// Secret: sql.NullString{ -// String: mustGenerateKey(t), -// Valid: true, -// }, -// StartsAt: clock.Now(), -// } -// fetcher := &DBFetcher{DB: mockDB, Feature: database.CryptoKeyFeatureWorkspaceApps} - -// k, err := newDBCache(ctx, logger, fetcher, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) -// defer k.Close() - -// id, secret, err := k.latest(ctx) -// require.NoError(t, err) -// require.Equal(t, keyID(latestKey), id) -// require.Equal(t, decodedSecret(t, latestKey), secret) -// }) - -// t.Run("InvalidCachedKey", func(t *testing.T) { -// t.Parallel() - -// var ( -// ctrl = gomock.NewController(t) -// mockDB = dbmock.NewMockStore(ctrl) -// clock = quartz.NewMock(t) -// ctx = testutil.Context(t, testutil.WaitShort) -// logger = slogtest.Make(t, nil) -// ) - -// latestKey := database.CryptoKey{ -// Feature: database.CryptoKeyFeatureWorkspaceApps, -// Sequence: 33, -// Secret: sql.NullString{ -// String: mustGenerateKey(t), -// Valid: true, -// }, -// StartsAt: clock.Now(), -// } - -// invalidKey := database.CryptoKey{ -// Feature: database.CryptoKeyFeatureWorkspaceApps, -// Sequence: 32, -// Secret: sql.NullString{ -// String: mustGenerateKey(t), -// Valid: true, -// }, -// StartsAt: clock.Now().Add(-time.Hour), -// DeletesAt: sql.NullTime{ -// Time: clock.Now(), -// Valid: true, -// }, -// } - -// mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{latestKey}, nil) - -// k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) -// defer k.Close() -// k.latestKey = invalidKey - -// id, secret, err := k.latest(ctx) -// require.NoError(t, err) -// require.Equal(t, keyID(latestKey), id) -// require.Equal(t, decodedSecret(t, latestKey), secret) -// }) - -// t.Run("UsesActiveKey", func(t *testing.T) { -// t.Parallel() - -// var ( -// ctrl = gomock.NewController(t) -// mockDB = dbmock.NewMockStore(ctrl) -// clock = quartz.NewMock(t) -// ctx = testutil.Context(t, testutil.WaitShort) -// logger = slogtest.Make(t, nil) -// ) - -// inactiveKey := database.CryptoKey{ -// Feature: database.CryptoKeyFeatureWorkspaceApps, -// Sequence: 32, -// Secret: sql.NullString{ -// String: mustGenerateKey(t), -// Valid: true, -// }, -// StartsAt: clock.Now().Add(time.Hour), -// } - -// activeKey := database.CryptoKey{ -// Feature: database.CryptoKeyFeatureWorkspaceApps, -// Sequence: 33, -// Secret: sql.NullString{ -// String: mustGenerateKey(t), -// Valid: true, -// }, -// StartsAt: clock.Now(), -// } - -// mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{inactiveKey, activeKey}, nil) - -// k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) -// defer k.Close() - -// id, secret, err := k.latest(ctx) -// require.NoError(t, err) -// require.Equal(t, keyID(activeKey), id) -// require.Equal(t, decodedSecret(t, activeKey), secret) -// }) - -// t.Run("NoValidKeys", func(t *testing.T) { -// t.Parallel() - -// var ( -// ctrl = gomock.NewController(t) -// mockDB = dbmock.NewMockStore(ctrl) -// clock = quartz.NewMock(t) -// ctx = testutil.Context(t, testutil.WaitShort) -// logger = slogtest.Make(t, nil) -// ) - -// inactiveKey := database.CryptoKey{ -// Feature: database.CryptoKeyFeatureWorkspaceApps, -// Sequence: 32, -// Secret: sql.NullString{ -// String: mustGenerateKey(t), -// Valid: true, -// }, -// StartsAt: clock.Now().Add(time.Hour), -// } - -// invalidKey := database.CryptoKey{ -// Feature: database.CryptoKeyFeatureWorkspaceApps, -// Sequence: 33, -// Secret: sql.NullString{ -// String: mustGenerateKey(t), -// Valid: true, -// }, -// StartsAt: clock.Now().Add(-time.Hour), -// DeletesAt: sql.NullTime{ -// Time: clock.Now(), -// Valid: true, -// }, -// } - -// mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{inactiveKey, invalidKey}, nil) - -// k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) -// defer k.Close() - -// _, _, err := k.latest(ctx) -// require.ErrorIs(t, err, ErrKeyInvalid) -// }) -// } - -// func Test_clear(t *testing.T) { -// t.Parallel() - -// t.Run("InvalidatesCache", func(t *testing.T) { -// t.Parallel() - -// var ( -// ctrl = gomock.NewController(t) -// mockDB = dbmock.NewMockStore(ctrl) -// clock = quartz.NewMock(t) -// ctx = testutil.Context(t, testutil.WaitShort) -// logger = slogtest.Make(t, nil) -// ) - -// k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) -// defer k.Close() - -// activeKey := database.CryptoKey{ -// Feature: database.CryptoKeyFeatureWorkspaceApps, -// Sequence: 33, -// Secret: sql.NullString{ -// String: mustGenerateKey(t), -// Valid: true, -// }, -// StartsAt: clock.Now(), -// } - -// mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{activeKey}, nil) - -// _, _, err := k.latest(ctx) -// require.NoError(t, err) - -// dur, wait := clock.AdvanceNext() -// wait.MustWait(ctx) -// require.Equal(t, time.Minute*10, dur) -// require.Len(t, k.keys, 0) -// require.Equal(t, database.CryptoKey{}, k.latestKey) -// }) - -// t.Run("ResetsTimer", func(t *testing.T) { -// t.Parallel() - -// var ( -// ctrl = gomock.NewController(t) -// mockDB = dbmock.NewMockStore(ctrl) -// clock = quartz.NewMock(t) -// ctx = testutil.Context(t, testutil.WaitShort) -// logger = slogtest.Make(t, nil) -// ) - -// k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) -// defer k.Close() - -// key := database.CryptoKey{ -// Feature: database.CryptoKeyFeatureWorkspaceApps, -// Sequence: 32, -// Secret: sql.NullString{ -// String: mustGenerateKey(t), -// Valid: true, -// }, -// StartsAt: clock.Now(), -// } - -// mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{key}, nil) - -// // Advance it five minutes so that we can test that the -// // timer is reset and doesn't fire after another five minute. -// clock.Advance(time.Minute * 5) - -// id, secret, err := k.latest(ctx) -// require.NoError(t, err) -// require.Equal(t, keyID(key), id) -// require.Equal(t, decodedSecret(t, key), secret) - -// // Advancing the clock now should require 10 minutes -// // before the timer fires again. -// dur, wait := clock.AdvanceNext() -// wait.MustWait(ctx) -// require.Equal(t, time.Minute*10, dur) -// require.Len(t, k.keys, 0) -// require.Equal(t, database.CryptoKey{}, k.latestKey) -// }) - -// // InvalidateAt tests that we have accounted for the race condition where a -// // timer fires to invalidate the cache at the same time we are fetching new -// // keys. In such cases we want to skip invalidation. -// t.Run("InvalidateAt", func(t *testing.T) { -// t.Parallel() - -// var ( -// ctrl = gomock.NewController(t) -// mockDB = dbmock.NewMockStore(ctrl) -// clock = quartz.NewMock(t) -// ctx = testutil.Context(t, testutil.WaitShort) -// logger = slogtest.Make(t, nil) -// ) - -// trap := clock.Trap().Now("clear") - -// k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) -// defer k.Close() - -// key := database.CryptoKey{ -// Feature: database.CryptoKeyFeatureWorkspaceApps, -// Sequence: 32, -// Secret: sql.NullString{ -// String: mustGenerateKey(t), -// Valid: true, -// }, -// StartsAt: clock.Now(), -// } - -// mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{key}, nil).Times(2) - -// // Move us past the initial timer. -// id, secret, err := k.latest(ctx) -// require.NoError(t, err) -// require.Equal(t, keyID(key), id) -// require.Equal(t, decodedSecret(t, key), secret) -// // Null these out so that we refetch. -// k.keys = nil -// k.latestKey = database.CryptoKey{} - -// // Initiate firing the timer. -// dur, wait := clock.AdvanceNext() -// require.Equal(t, time.Minute*10, dur) -// // Trap the function just before acquiring the mutex. -// call := trap.MustWait(ctx) - -// // Refetch keys. -// id, secret, err = k.latest(ctx) -// require.NoError(t, err) -// require.Equal(t, keyID(key), id) -// require.Equal(t, decodedSecret(t, key), secret) - -// // Let the rest of the timer function run. -// // It should see that we have refetched keys and -// // not invalidate. -// call.Release() -// wait.MustWait(ctx) -// require.Len(t, k.keys, 1) -// require.Equal(t, key, k.latestKey) -// trap.Close() - -// // Refetching the keys should've instantiated a new timer. This one should invalidate keys. -// _, wait = clock.AdvanceNext() -// wait.MustWait(ctx) -// require.Len(t, k.keys, 0) -// require.Equal(t, database.CryptoKey{}, k.latestKey) -// }) -// } - -func mustGenerateKey(t *testing.T) string { - t.Helper() - key, err := generateKey(64) - require.NoError(t, err) - return key + _, err = cache.Verifying(ctx, expected.Sequence) + require.ErrorIs(t, err, ErrClosed) + }) } -func keyID(key codersdk.CryptoKey) int32 { - return key.Sequence +type fakeFetcher struct { + keys []codersdk.CryptoKey + called int } -func decodedSecret(t *testing.T, key codersdk.CryptoKey) []byte { - t.Helper() - decoded, err := hex.DecodeString(key.Secret) - require.NoError(t, err) - return decoded +func (f *fakeFetcher) Fetch(_ context.Context) ([]codersdk.CryptoKey, error) { + f.called++ + return f.keys, nil } -func toDBKey(key codersdk.CryptoKey) database.CryptoKey { - return database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, - Sequence: key.Sequence, - Secret: sql.NullString{ - String: key.Secret, - Valid: key.Secret != "", - }, +func withClock(clock quartz.Clock) func(*wsproxy.CryptoKeyCache) { + return func(cache *wsproxy.CryptoKeyCache) { + cache.Clock = clock } } diff --git a/enterprise/wsproxy/keycache_test.go b/enterprise/wsproxy/keycache_test.go index 210e04f9edf76..e69de29bb2d1d 100644 --- a/enterprise/wsproxy/keycache_test.go +++ b/enterprise/wsproxy/keycache_test.go @@ -1,485 +0,0 @@ -package wsproxy_test - -import ( - "context" - "testing" - "time" - - "github.com/stretchr/testify/require" - - "cdr.dev/slog/sloggers/slogtest" - - "github.com/coder/coder/v2/coderd/cryptokeys" - "github.com/coder/coder/v2/codersdk" - "github.com/coder/coder/v2/enterprise/wsproxy" - "github.com/coder/coder/v2/testutil" - "github.com/coder/quartz" -) - -func TestCryptoKeyCache(t *testing.T) { - t.Parallel() - - t.Run("Signing", func(t *testing.T) { - t.Parallel() - - t.Run("HitsCache", func(t *testing.T) { - t.Parallel() - var ( - ctx = testutil.Context(t, testutil.WaitShort) - logger = slogtest.Make(t, nil) - clock = quartz.NewMock(t) - ) - - now := clock.Now().UTC() - expected := codersdk.CryptoKey{ - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key2", - Sequence: 2, - StartsAt: now, - } - - ff := &fakeFetcher{ - keys: []codersdk.CryptoKey{expected}, - } - - cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) - require.NoError(t, err) - - got, err := cache.Signing(ctx) - require.NoError(t, err) - require.Equal(t, expected, got) - require.Equal(t, 1, ff.called) - }) - - t.Run("MissesCache", func(t *testing.T) { - t.Parallel() - var ( - ctx = testutil.Context(t, testutil.WaitShort) - logger = slogtest.Make(t, nil) - clock = quartz.NewMock(t) - ) - - ff := &fakeFetcher{ - keys: []codersdk.CryptoKey{}, - } - - cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) - require.NoError(t, err) - - expected := codersdk.CryptoKey{ - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key1", - Sequence: 12, - StartsAt: clock.Now().UTC(), - } - ff.keys = []codersdk.CryptoKey{expected} - - got, err := cache.Signing(ctx) - require.NoError(t, err) - require.Equal(t, expected, got) - // 1 on startup + missing cache. - require.Equal(t, 2, ff.called) - - // Ensure the cache gets hit this time. - got, err = cache.Signing(ctx) - require.NoError(t, err) - require.Equal(t, expected, got) - // 1 on startup + missing cache. - require.Equal(t, 2, ff.called) - }) - - t.Run("IgnoresInvalid", func(t *testing.T) { - t.Parallel() - - var ( - ctx = testutil.Context(t, testutil.WaitShort) - logger = slogtest.Make(t, nil) - clock = quartz.NewMock(t) - ) - now := clock.Now().UTC() - expected := codersdk.CryptoKey{ - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key1", - Sequence: 1, - StartsAt: clock.Now().UTC(), - } - - ff := &fakeFetcher{ - keys: []codersdk.CryptoKey{ - expected, - { - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key2", - Sequence: 2, - StartsAt: now.Add(-time.Second), - DeletesAt: now, - }, - }, - } - - cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) - require.NoError(t, err) - - got, err := cache.Signing(ctx) - require.NoError(t, err) - require.Equal(t, expected, got) - require.Equal(t, 1, ff.called) - }) - - t.Run("KeyNotFound", func(t *testing.T) { - t.Parallel() - - var ( - ctx = testutil.Context(t, testutil.WaitShort) - logger = slogtest.Make(t, nil) - clock = quartz.NewMock(t) - ) - - ff := &fakeFetcher{ - keys: []codersdk.CryptoKey{}, - } - - cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) - require.NoError(t, err) - - _, err = cache.Signing(ctx) - require.ErrorIs(t, err, cryptokeys.ErrKeyNotFound) - }) - }) - - t.Run("Verifying", func(t *testing.T) { - t.Parallel() - - t.Run("HitsCache", func(t *testing.T) { - t.Parallel() - - var ( - ctx = testutil.Context(t, testutil.WaitShort) - logger = slogtest.Make(t, nil) - clock = quartz.NewMock(t) - ) - - now := clock.Now().UTC() - expected := codersdk.CryptoKey{ - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key1", - Sequence: 12, - StartsAt: now, - } - ff := &fakeFetcher{ - keys: []codersdk.CryptoKey{ - expected, - { - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key2", - Sequence: 13, - StartsAt: now, - }, - }, - } - - cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) - require.NoError(t, err) - - got, err := cache.Verifying(ctx, expected.Sequence) - require.NoError(t, err) - require.Equal(t, expected, got) - require.Equal(t, 1, ff.called) - }) - - t.Run("MissesCache", func(t *testing.T) { - t.Parallel() - var ( - ctx = testutil.Context(t, testutil.WaitShort) - logger = slogtest.Make(t, nil) - clock = quartz.NewMock(t) - ) - - ff := &fakeFetcher{ - keys: []codersdk.CryptoKey{}, - } - - cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) - require.NoError(t, err) - - expected := codersdk.CryptoKey{ - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key1", - Sequence: 12, - StartsAt: clock.Now().UTC(), - } - ff.keys = []codersdk.CryptoKey{expected} - - got, err := cache.Verifying(ctx, expected.Sequence) - require.NoError(t, err) - require.Equal(t, expected, got) - require.Equal(t, 2, ff.called) - - // Ensure the cache gets hit this time. - got, err = cache.Verifying(ctx, expected.Sequence) - require.NoError(t, err) - require.Equal(t, expected, got) - require.Equal(t, 2, ff.called) - }) - - t.Run("AllowsBeforeStartsAt", func(t *testing.T) { - t.Parallel() - - var ( - ctx = testutil.Context(t, testutil.WaitShort) - logger = slogtest.Make(t, nil) - clock = quartz.NewMock(t) - ) - - now := clock.Now().UTC() - expected := codersdk.CryptoKey{ - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key1", - Sequence: 12, - StartsAt: now.Add(-time.Second), - } - - ff := &fakeFetcher{ - keys: []codersdk.CryptoKey{ - expected, - }, - } - - cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) - require.NoError(t, err) - - got, err := cache.Verifying(ctx, expected.Sequence) - require.NoError(t, err) - require.Equal(t, expected, got) - require.Equal(t, 1, ff.called) - }) - - t.Run("KeyInvalid", func(t *testing.T) { - t.Parallel() - - var ( - ctx = testutil.Context(t, testutil.WaitShort) - logger = slogtest.Make(t, nil) - clock = quartz.NewMock(t) - ) - - now := clock.Now().UTC() - expected := codersdk.CryptoKey{ - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key1", - Sequence: 12, - StartsAt: now.Add(-time.Second), - DeletesAt: now, - } - - ff := &fakeFetcher{ - keys: []codersdk.CryptoKey{ - expected, - }, - } - - cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) - require.NoError(t, err) - - _, err = cache.Verifying(ctx, expected.Sequence) - require.ErrorIs(t, err, cryptokeys.ErrKeyInvalid) - require.Equal(t, 1, ff.called) - }) - - t.Run("KeyNotFound", func(t *testing.T) { - t.Parallel() - - var ( - ctx = testutil.Context(t, testutil.WaitShort) - logger = slogtest.Make(t, nil) - clock = quartz.NewMock(t) - ) - - ff := &fakeFetcher{ - keys: []codersdk.CryptoKey{}, - } - - cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) - require.NoError(t, err) - - _, err = cache.Verifying(ctx, 1) - require.ErrorIs(t, err, cryptokeys.ErrKeyNotFound) - }) - }) - - t.Run("CacheRefreshes", func(t *testing.T) { - t.Parallel() - - var ( - ctx = testutil.Context(t, testutil.WaitShort) - logger = slogtest.Make(t, nil) - clock = quartz.NewMock(t) - ) - - now := clock.Now().UTC() - expected := codersdk.CryptoKey{ - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key1", - Sequence: 12, - StartsAt: now, - DeletesAt: now.Add(time.Minute * 10), - } - ff := &fakeFetcher{ - keys: []codersdk.CryptoKey{ - expected, - }, - } - - cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) - require.NoError(t, err) - - got, err := cache.Signing(ctx) - require.NoError(t, err) - require.Equal(t, expected, got) - require.Equal(t, 1, ff.called) - - newKey := codersdk.CryptoKey{ - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key2", - Sequence: 13, - StartsAt: now, - } - ff.keys = []codersdk.CryptoKey{newKey} - - // The ticker should fire and cause a request to coderd. - dur, advance := clock.AdvanceNext() - advance.MustWait(ctx) - require.Equal(t, 2, ff.called) - require.Equal(t, time.Minute*10, dur) - - // Assert hits cache. - got, err = cache.Signing(ctx) - require.NoError(t, err) - require.Equal(t, newKey, got) - require.Equal(t, 2, ff.called) - - // We check again to ensure the timer has been reset. - _, advance = clock.AdvanceNext() - advance.MustWait(ctx) - require.Equal(t, 3, ff.called) - require.Equal(t, time.Minute*10, dur) - }) - - // This test ensures that if the refresh timer races with an inflight request - // and loses that it doesn't cause a redundant fetch. - - t.Run("RefreshNoDoubleFetch", func(t *testing.T) { - t.Parallel() - - var ( - ctx = testutil.Context(t, testutil.WaitShort) - logger = slogtest.Make(t, nil) - clock = quartz.NewMock(t) - ) - - now := clock.Now().UTC() - expected := codersdk.CryptoKey{ - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key1", - Sequence: 12, - StartsAt: now, - DeletesAt: now.Add(time.Minute * 10), - } - ff := &fakeFetcher{ - keys: []codersdk.CryptoKey{ - expected, - }, - } - - // Create a trap that blocks when the refresh timer fires. - trap := clock.Trap().Now("refresh") - cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) - require.NoError(t, err) - - _, wait := clock.AdvanceNext() - trapped := trap.MustWait(ctx) - - newKey := codersdk.CryptoKey{ - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key2", - Sequence: 13, - StartsAt: now, - } - ff.keys = []codersdk.CryptoKey{newKey} - - _, err = cache.Verifying(ctx, newKey.Sequence) - require.NoError(t, err) - require.Equal(t, 2, ff.called) - - trapped.Release() - wait.MustWait(ctx) - require.Equal(t, 2, ff.called) - trap.Close() - - // The next timer should fire in 10 minutes. - dur, wait := clock.AdvanceNext() - wait.MustWait(ctx) - require.Equal(t, time.Minute*10, dur) - require.Equal(t, 3, ff.called) - }) - - t.Run("Closed", func(t *testing.T) { - t.Parallel() - - var ( - ctx = testutil.Context(t, testutil.WaitShort) - logger = slogtest.Make(t, nil) - clock = quartz.NewMock(t) - ) - - now := clock.Now() - expected := codersdk.CryptoKey{ - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key1", - Sequence: 12, - StartsAt: now, - } - ff := &fakeFetcher{ - keys: []codersdk.CryptoKey{ - expected, - }, - } - - cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) - require.NoError(t, err) - - got, err := cache.Signing(ctx) - require.NoError(t, err) - require.Equal(t, expected, got) - require.Equal(t, 1, ff.called) - - got, err = cache.Verifying(ctx, expected.Sequence) - require.NoError(t, err) - require.Equal(t, expected, got) - require.Equal(t, 1, ff.called) - - cache.Close() - - _, err = cache.Signing(ctx) - require.ErrorIs(t, err, cryptokeys.ErrClosed) - - _, err = cache.Verifying(ctx, expected.Sequence) - require.ErrorIs(t, err, cryptokeys.ErrClosed) - }) -} - -type fakeFetcher struct { - keys []codersdk.CryptoKey - called int -} - -func (f *fakeFetcher) Fetch(_ context.Context) ([]codersdk.CryptoKey, error) { - f.called++ - return f.keys, nil -} - -func withClock(clock quartz.Clock) func(*wsproxy.CryptoKeyCache) { - return func(cache *wsproxy.CryptoKeyCache) { - cache.Clock = clock - } -} From a26df46a57b043025f64f5fb1fc4009f6b580b1d Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Wed, 16 Oct 2024 00:19:56 +0000 Subject: [PATCH 3/7] finish refactor --- coderd/cryptokeys/cache.go | 370 ++++++++++++++++++ coderd/cryptokeys/cache_internal_test.go | 1 + ...eycache_internal_test.go => cache_test.go} | 160 +++++--- coderd/cryptokeys/dbkeycache.go | 352 ----------------- coderd/cryptokeys/dbkeycache_test.go | 222 ----------- coderd/cryptokeys/keycache.go | 47 --- coderd/jwtutils/jwt_test.go | 5 +- enterprise/wsproxy/keycache.go | 224 ----------- enterprise/wsproxy/keycache_test.go | 0 enterprise/wsproxy/keyfetcher.go | 25 ++ 10 files changed, 494 insertions(+), 912 deletions(-) create mode 100644 coderd/cryptokeys/cache.go create mode 100644 coderd/cryptokeys/cache_internal_test.go rename coderd/cryptokeys/{dbkeycache_internal_test.go => cache_test.go} (65%) delete mode 100644 coderd/cryptokeys/dbkeycache.go delete mode 100644 coderd/cryptokeys/dbkeycache_test.go delete mode 100644 coderd/cryptokeys/keycache.go delete mode 100644 enterprise/wsproxy/keycache.go delete mode 100644 enterprise/wsproxy/keycache_test.go create mode 100644 enterprise/wsproxy/keyfetcher.go diff --git a/coderd/cryptokeys/cache.go b/coderd/cryptokeys/cache.go new file mode 100644 index 0000000000000..8436b7099ad1c --- /dev/null +++ b/coderd/cryptokeys/cache.go @@ -0,0 +1,370 @@ +package cryptokeys + +import ( + "context" + "encoding/hex" + "io" + "strconv" + "sync" + "time" + + "golang.org/x/xerrors" + + "cdr.dev/slog" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/db2sdk" + "github.com/coder/coder/v2/coderd/jwtutils" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/quartz" +) + +var ( + ErrKeyNotFound = xerrors.New("key not found") + ErrKeyInvalid = xerrors.New("key is invalid for use") + ErrClosed = xerrors.New("closed") + ErrInvalidFeature = xerrors.New("invalid feature for this operation") +) + +var _ jwtutils.SigningKeyProvider = &cache{} + +type Fetcher interface { + Fetch(ctx context.Context) ([]codersdk.CryptoKey, error) +} + +type EncryptionKeycache interface { + // EncryptingKey returns the latest valid key for encrypting payloads. A valid + // key is one that is both past its start time and before its deletion time. + EncryptingKey(ctx context.Context) (id string, key interface{}, err error) + // DecryptingKey returns the key with the provided id which maps to its sequence + // number. The key is valid for decryption as long as it is not deleted or past + // its deletion date. We must allow for keys prior to their start time to + // account for clock skew between peers (one key may be past its start time on + // one machine while another is not). + DecryptingKey(ctx context.Context, id string) (key interface{}, err error) + io.Closer +} + +type SigningKeycache interface { + // SigningKey returns the latest valid key for signing. A valid key is one + // that is both past its start time and before its deletion time. + SigningKey(ctx context.Context) (id string, key interface{}, err error) + // VerifyingKey returns the key with the provided id which should map to its + // sequence number. The key is valid for verifying as long as it is not deleted + // or past its deletion date. We must allow for keys prior to their start time + // to account for clock skew between peers (one key may be past its start time + // on one machine while another is not). + VerifyingKey(ctx context.Context, id string) (key interface{}, err error) + io.Closer +} + +const ( + // latestSequence is a special sequence number that represents the latest key. + latestSequence = -1 + // refreshInterval is the interval at which the key cache will refresh. + refreshInterval = time.Minute * 10 +) + +type DBFetcher struct { + DB database.Store + Feature database.CryptoKeyFeature +} + +func (d *DBFetcher) Fetch(ctx context.Context) ([]codersdk.CryptoKey, error) { + keys, err := d.DB.GetCryptoKeysByFeature(ctx, d.Feature) + if err != nil { + return nil, xerrors.Errorf("get crypto keys by feature: %w", err) + } + + return db2sdk.CryptoKeys(keys), nil +} + +// cache implements the caching functionality for both signing and encryption keys. +type cache struct { + clock quartz.Clock + refreshCtx context.Context + refreshCancel context.CancelFunc + fetcher Fetcher + logger slog.Logger + feature codersdk.CryptoKeyFeature + + mu sync.Mutex + keys map[int32]codersdk.CryptoKey + lastFetch time.Time + refresher *quartz.Timer + fetching bool + closed bool + cond *sync.Cond +} + +type CacheOption func(*cache) + +func WithCacheClock(clock quartz.Clock) CacheOption { + return func(d *cache) { + d.clock = clock + } +} + +// NewSigningCache instantiates a cache. Close should be called to +// release resources associated with its internal timer. +func NewSigningCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, feature codersdk.CryptoKeyFeature, opts ...func(*cache)) (SigningKeycache, error) { + if !isSigningKeyFeature(feature) { + return nil, xerrors.Errorf("invalid feature: %s", feature) + } + return newCache(ctx, logger, fetcher, feature, opts...) +} + +func NewEncryptionCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, feature codersdk.CryptoKeyFeature, opts ...func(*cache)) (EncryptionKeycache, error) { + if !isEncryptionKeyFeature(feature) { + return nil, xerrors.Errorf("invalid feature: %s", feature) + } + return newCache(ctx, logger, fetcher, feature, opts...) +} + +func newCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, feature codersdk.CryptoKeyFeature, opts ...func(*cache)) (*cache, error) { + cache := &cache{ + clock: quartz.NewReal(), + logger: logger, + fetcher: fetcher, + feature: feature, + } + + for _, opt := range opts { + opt(cache) + } + + cache.cond = sync.NewCond(&cache.mu) + cache.refreshCtx, cache.refreshCancel = context.WithCancel(ctx) + cache.refresher = cache.clock.AfterFunc(refreshInterval, cache.refresh) + + keys, err := cache.cryptoKeys(ctx) + if err != nil { + cache.refreshCancel() + return nil, xerrors.Errorf("initial fetch: %w", err) + } + cache.keys = keys + return cache, nil +} + +func (c *cache) EncryptingKey(ctx context.Context) (string, interface{}, error) { + if !isEncryptionKeyFeature(c.feature) { + return "", nil, ErrInvalidFeature + } + + return c.cryptoKey(ctx, latestSequence) +} + +func (c *cache) DecryptingKey(ctx context.Context, id string) (interface{}, error) { + if !isEncryptionKeyFeature(c.feature) { + return nil, ErrInvalidFeature + } + + seq, err := strconv.ParseInt(id, 10, 64) + if err != nil { + return nil, xerrors.Errorf("parse id: %w", err) + } + + _, secret, err := c.cryptoKey(ctx, int32(seq)) + if err != nil { + return nil, xerrors.Errorf("crypto key: %w", err) + } + return secret, nil +} + +func (c *cache) SigningKey(ctx context.Context) (string, interface{}, error) { + if !isSigningKeyFeature(c.feature) { + return "", nil, ErrInvalidFeature + } + + return c.cryptoKey(ctx, latestSequence) +} + +func (c *cache) VerifyingKey(ctx context.Context, id string) (interface{}, error) { + if !isSigningKeyFeature(c.feature) { + return nil, ErrInvalidFeature + } + + seq, err := strconv.ParseInt(id, 10, 64) + if err != nil { + return nil, xerrors.Errorf("parse id: %w", err) + } + + _, secret, err := c.cryptoKey(ctx, int32(seq)) + if err != nil { + return nil, xerrors.Errorf("crypto key: %w", err) + } + + return secret, nil +} + +func isEncryptionKeyFeature(feature codersdk.CryptoKeyFeature) bool { + return feature == codersdk.CryptoKeyFeatureWorkspaceApp +} + +func isSigningKeyFeature(feature codersdk.CryptoKeyFeature) bool { + switch feature { + case codersdk.CryptoKeyFeatureTailnetResume, codersdk.CryptoKeyFeatureOIDCConvert: + return true + default: + return false + } +} + +func idSecret(k codersdk.CryptoKey) (string, []byte, error) { + key, err := hex.DecodeString(k.Secret) + if err != nil { + return "", nil, xerrors.Errorf("decode key: %w", err) + } + + return strconv.FormatInt(int64(k.Sequence), 10), key, nil +} + +func (c *cache) cryptoKey(ctx context.Context, sequence int32) (string, []byte, error) { + c.mu.Lock() + defer c.mu.Unlock() + + if c.closed { + return "", nil, ErrClosed + } + + var key codersdk.CryptoKey + var ok bool + for key, ok = c.key(sequence); !ok && c.fetching && !c.closed; { + c.cond.Wait() + } + + if c.closed { + return "", nil, ErrClosed + } + + if ok { + return checkKey(key, sequence, c.clock.Now()) + } + + c.fetching = true + c.mu.Unlock() + + keys, err := c.cryptoKeys(ctx) + if err != nil { + return "", nil, xerrors.Errorf("get keys: %w", err) + } + + c.mu.Lock() + c.lastFetch = c.clock.Now() + c.refresher.Reset(refreshInterval) + c.keys = keys + c.fetching = false + c.cond.Broadcast() + + key, ok = c.key(sequence) + if !ok { + return "", nil, ErrKeyNotFound + } + + return checkKey(key, sequence, c.clock.Now()) +} + +func (c *cache) key(sequence int32) (codersdk.CryptoKey, bool) { + if sequence == latestSequence { + return c.keys[latestSequence], c.keys[latestSequence].CanSign(c.clock.Now()) + } + + key, ok := c.keys[sequence] + return key, ok +} + +func checkKey(key codersdk.CryptoKey, sequence int32, now time.Time) (string, []byte, error) { + if sequence == latestSequence { + if !key.CanSign(now) { + return "", nil, ErrKeyInvalid + } + return idSecret(key) + } + + if !key.CanVerify(now) { + return "", nil, ErrKeyInvalid + } + + return idSecret(key) +} + +// refresh fetches the keys and updates the cache. +func (c *cache) refresh() { + now := c.clock.Now("CryptoKeyCache", "refresh") + c.mu.Lock() + + if c.closed { + c.mu.Unlock() + return + } + + // If something's already fetching, we don't need to do anything. + if c.fetching { + c.mu.Unlock() + return + } + + // There's a window we must account for where the timer fires while a fetch + // is ongoing but prior to the timer getting reset. In this case we want to + // avoid double fetching. + if now.Sub(c.lastFetch) < refreshInterval { + c.mu.Unlock() + return + } + + c.fetching = true + + c.mu.Unlock() + keys, err := c.cryptoKeys(c.refreshCtx) + if err != nil { + c.logger.Error(c.refreshCtx, "fetch crypto keys", slog.Error(err)) + return + } + + c.mu.Lock() + defer c.mu.Unlock() + + c.lastFetch = c.clock.Now() + c.refresher.Reset(refreshInterval) + c.keys = keys + c.fetching = false + c.cond.Broadcast() +} + +// cryptoKeys queries the control plane for the crypto keys. +// Outside of initialization, this should only be called by fetch. +func (c *cache) cryptoKeys(ctx context.Context) (map[int32]codersdk.CryptoKey, error) { + keys, err := c.fetcher.Fetch(ctx) + if err != nil { + return nil, xerrors.Errorf("crypto keys: %w", err) + } + cache := toKeyMap(keys, c.clock.Now()) + return cache, nil +} + +func toKeyMap(keys []codersdk.CryptoKey, now time.Time) map[int32]codersdk.CryptoKey { + m := make(map[int32]codersdk.CryptoKey) + var latest codersdk.CryptoKey + for _, key := range keys { + m[key.Sequence] = key + if key.Sequence > latest.Sequence && key.CanSign(now) { + m[latestSequence] = key + } + } + return m +} + +func (c *cache) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + + if c.closed { + return nil + } + + c.closed = true + c.refreshCancel() + c.refresher.Stop() + c.cond.Broadcast() + + return nil +} diff --git a/coderd/cryptokeys/cache_internal_test.go b/coderd/cryptokeys/cache_internal_test.go new file mode 100644 index 0000000000000..4845ac9198d2c --- /dev/null +++ b/coderd/cryptokeys/cache_internal_test.go @@ -0,0 +1 @@ +package cryptokeys diff --git a/coderd/cryptokeys/dbkeycache_internal_test.go b/coderd/cryptokeys/cache_test.go similarity index 65% rename from coderd/cryptokeys/dbkeycache_internal_test.go rename to coderd/cryptokeys/cache_test.go index 73b647c6e61bf..58699b0b4b09f 100644 --- a/coderd/cryptokeys/dbkeycache_internal_test.go +++ b/coderd/cryptokeys/cache_test.go @@ -1,20 +1,28 @@ -package cryptokeys +package cryptokeys_test import ( "context" + "crypto/rand" + "encoding/hex" + "strconv" "testing" "time" "github.com/stretchr/testify/require" + "go.uber.org/goleak" "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/cryptokeys" "github.com/coder/coder/v2/codersdk" - "github.com/coder/coder/v2/enterprise/wsproxy" "github.com/coder/coder/v2/testutil" "github.com/coder/quartz" ) +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} + func TestCryptoKeyCache(t *testing.T) { t.Parallel() @@ -32,7 +40,7 @@ func TestCryptoKeyCache(t *testing.T) { now := clock.Now().UTC() expected := codersdk.CryptoKey{ Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key2", + Secret: generateKey(t, 64), Sequence: 2, StartsAt: now, } @@ -41,14 +49,14 @@ func TestCryptoKeyCache(t *testing.T) { keys: []codersdk.CryptoKey{expected}, } - cache, err := NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureWorkspaceApp, WithDBCacheClock(clock)) + cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureWorkspaceApp, cryptokeys.WithCacheClock(clock)) require.NoError(t, err) id, got, err := cache.SigningKey(ctx) require.NoError(t, err) - require.Equal(t, expected, got) + require.Equal(t, keyID(expected), id) + require.Equal(t, decodedSecret(t, expected), got) require.Equal(t, 1, ff.called) - require.Equal(t, "2", id) }) t.Run("MissesCache", func(t *testing.T) { @@ -63,12 +71,12 @@ func TestCryptoKeyCache(t *testing.T) { keys: []codersdk.CryptoKey{}, } - cache, err := NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureWorkspaceApp, WithDBCacheClock(clock)) + cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureWorkspaceApp, cryptokeys.WithCacheClock(clock)) require.NoError(t, err) expected := codersdk.CryptoKey{ Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key1", + Secret: generateKey(t, 64), Sequence: 12, StartsAt: clock.Now().UTC(), } @@ -76,16 +84,16 @@ func TestCryptoKeyCache(t *testing.T) { id, got, err := cache.SigningKey(ctx) require.NoError(t, err) - require.Equal(t, expected, got) - require.Equal(t, "12", id) + require.Equal(t, decodedSecret(t, expected), got) + require.Equal(t, keyID(expected), id) // 1 on startup + missing cache. require.Equal(t, 2, ff.called) // Ensure the cache gets hit this time. id, got, err = cache.SigningKey(ctx) require.NoError(t, err) - require.Equal(t, expected, got) - require.Equal(t, "12", id) + require.Equal(t, decodedSecret(t, expected), got) + require.Equal(t, keyID(expected), id) // 1 on startup + missing cache. require.Equal(t, 2, ff.called) }) @@ -99,9 +107,10 @@ func TestCryptoKeyCache(t *testing.T) { clock = quartz.NewMock(t) ) now := clock.Now().UTC() + expected := codersdk.CryptoKey{ Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key1", + Secret: generateKey(t, 64), Sequence: 1, StartsAt: clock.Now().UTC(), } @@ -111,7 +120,7 @@ func TestCryptoKeyCache(t *testing.T) { expected, { Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key2", + Secret: generateKey(t, 64), Sequence: 2, StartsAt: now.Add(-time.Second), DeletesAt: now, @@ -119,13 +128,13 @@ func TestCryptoKeyCache(t *testing.T) { }, } - cache, err := NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureWorkspaceApp, WithDBCacheClock(clock)) + cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureWorkspaceApp, cryptokeys.WithCacheClock(clock)) require.NoError(t, err) id, got, err := cache.SigningKey(ctx) require.NoError(t, err) - require.Equal(t, expected, got) - require.Equal(t, "2", id) + require.Equal(t, decodedSecret(t, expected), got) + require.Equal(t, keyID(expected), id) require.Equal(t, 1, ff.called) }) @@ -141,11 +150,11 @@ func TestCryptoKeyCache(t *testing.T) { keys: []codersdk.CryptoKey{}, } - cache, err := NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureWorkspaceApp) + cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureWorkspaceApp) require.NoError(t, err) _, _, err = cache.SigningKey(ctx) - require.ErrorIs(t, err, ErrKeyNotFound) + require.ErrorIs(t, err, cryptokeys.ErrKeyNotFound) }) }) @@ -164,7 +173,7 @@ func TestCryptoKeyCache(t *testing.T) { now := clock.Now().UTC() expected := codersdk.CryptoKey{ Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key1", + Secret: generateKey(t, 64), Sequence: 12, StartsAt: now, } @@ -173,19 +182,19 @@ func TestCryptoKeyCache(t *testing.T) { expected, { Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key2", + Secret: generateKey(t, 64), Sequence: 13, StartsAt: now, }, }, } - cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) + cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureWorkspaceApp, cryptokeys.WithCacheClock(clock)) require.NoError(t, err) - got, err := cache.Verifying(ctx, expected.Sequence) + got, err := cache.VerifyingKey(ctx, keyID(expected)) require.NoError(t, err) - require.Equal(t, expected, got) + require.Equal(t, decodedSecret(t, expected), got) require.Equal(t, 1, ff.called) }) @@ -201,26 +210,26 @@ func TestCryptoKeyCache(t *testing.T) { keys: []codersdk.CryptoKey{}, } - cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) + cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureWorkspaceApp, cryptokeys.WithCacheClock(clock)) require.NoError(t, err) expected := codersdk.CryptoKey{ Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key1", + Secret: generateKey(t, 64), Sequence: 12, StartsAt: clock.Now().UTC(), } ff.keys = []codersdk.CryptoKey{expected} - got, err := cache.Verifying(ctx, expected.Sequence) + got, err := cache.VerifyingKey(ctx, keyID(expected)) require.NoError(t, err) - require.Equal(t, expected, got) + require.Equal(t, decodedSecret(t, expected), got) require.Equal(t, 2, ff.called) // Ensure the cache gets hit this time. - got, err = cache.Verifying(ctx, expected.Sequence) + got, err = cache.VerifyingKey(ctx, keyID(expected)) require.NoError(t, err) - require.Equal(t, expected, got) + require.Equal(t, decodedSecret(t, expected), got) require.Equal(t, 2, ff.called) }) @@ -236,7 +245,7 @@ func TestCryptoKeyCache(t *testing.T) { now := clock.Now().UTC() expected := codersdk.CryptoKey{ Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key1", + Secret: generateKey(t, 64), Sequence: 12, StartsAt: now.Add(-time.Second), } @@ -247,16 +256,16 @@ func TestCryptoKeyCache(t *testing.T) { }, } - cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) + cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureWorkspaceApp, cryptokeys.WithCacheClock(clock)) require.NoError(t, err) - got, err := cache.Verifying(ctx, expected.Sequence) + got, err := cache.VerifyingKey(ctx, keyID(expected)) require.NoError(t, err) - require.Equal(t, expected, got) + require.Equal(t, decodedSecret(t, expected), got) require.Equal(t, 1, ff.called) }) - t.Run("KeyInvalid", func(t *testing.T) { + t.Run("KeyPastDeletesAt", func(t *testing.T) { t.Parallel() var ( @@ -268,7 +277,7 @@ func TestCryptoKeyCache(t *testing.T) { now := clock.Now().UTC() expected := codersdk.CryptoKey{ Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key1", + Secret: generateKey(t, 64), Sequence: 12, StartsAt: now.Add(-time.Second), DeletesAt: now, @@ -280,11 +289,11 @@ func TestCryptoKeyCache(t *testing.T) { }, } - cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) + cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureWorkspaceApp, cryptokeys.WithCacheClock(clock)) require.NoError(t, err) - _, err = cache.Verifying(ctx, expected.Sequence) - require.ErrorIs(t, err, ErrKeyInvalid) + _, err = cache.VerifyingKey(ctx, keyID(expected)) + require.ErrorIs(t, err, cryptokeys.ErrKeyInvalid) require.Equal(t, 1, ff.called) }) @@ -301,11 +310,11 @@ func TestCryptoKeyCache(t *testing.T) { keys: []codersdk.CryptoKey{}, } - cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) + cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureWorkspaceApp, cryptokeys.WithCacheClock(clock)) require.NoError(t, err) - _, err = cache.Verifying(ctx, 1) - require.ErrorIs(t, err, ErrKeyNotFound) + _, err = cache.VerifyingKey(ctx, "1") + require.ErrorIs(t, err, cryptokeys.ErrKeyNotFound) }) }) @@ -321,7 +330,7 @@ func TestCryptoKeyCache(t *testing.T) { now := clock.Now().UTC() expected := codersdk.CryptoKey{ Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key1", + Secret: generateKey(t, 64), Sequence: 12, StartsAt: now, DeletesAt: now.Add(time.Minute * 10), @@ -332,17 +341,18 @@ func TestCryptoKeyCache(t *testing.T) { }, } - cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) + cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureWorkspaceApp, cryptokeys.WithCacheClock(clock)) require.NoError(t, err) - got, err := cache.Signing(ctx) + id, got, err := cache.SigningKey(ctx) require.NoError(t, err) - require.Equal(t, expected, got) + require.Equal(t, decodedSecret(t, expected), got) + require.Equal(t, keyID(expected), id) require.Equal(t, 1, ff.called) newKey := codersdk.CryptoKey{ Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key2", + Secret: generateKey(t, 64), Sequence: 13, StartsAt: now, } @@ -355,9 +365,10 @@ func TestCryptoKeyCache(t *testing.T) { require.Equal(t, time.Minute*10, dur) // Assert hits cache. - got, err = cache.Signing(ctx) + id, got, err = cache.SigningKey(ctx) require.NoError(t, err) - require.Equal(t, newKey, got) + require.Equal(t, keyID(newKey), id) + require.Equal(t, decodedSecret(t, newKey), got) require.Equal(t, 2, ff.called) // We check again to ensure the timer has been reset. @@ -382,7 +393,7 @@ func TestCryptoKeyCache(t *testing.T) { now := clock.Now().UTC() expected := codersdk.CryptoKey{ Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key1", + Secret: generateKey(t, 64), Sequence: 12, StartsAt: now, DeletesAt: now.Add(time.Minute * 10), @@ -395,7 +406,7 @@ func TestCryptoKeyCache(t *testing.T) { // Create a trap that blocks when the refresh timer fires. trap := clock.Trap().Now("refresh") - cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) + cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureWorkspaceApp, cryptokeys.WithCacheClock(clock)) require.NoError(t, err) _, wait := clock.AdvanceNext() @@ -409,9 +420,10 @@ func TestCryptoKeyCache(t *testing.T) { } ff.keys = []codersdk.CryptoKey{newKey} - _, err = cache.Verifying(ctx, newKey.Sequence) + key, err := cache.VerifyingKey(ctx, keyID(newKey)) require.NoError(t, err) require.Equal(t, 2, ff.called) + require.Equal(t, decodedSecret(t, newKey), key) trapped.Release() wait.MustWait(ctx) @@ -447,26 +459,27 @@ func TestCryptoKeyCache(t *testing.T) { }, } - cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) + cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureWorkspaceApp, cryptokeys.WithCacheClock(clock)) require.NoError(t, err) - got, err := cache.Signing(ctx) + id, got, err := cache.SigningKey(ctx) require.NoError(t, err) - require.Equal(t, expected, got) + require.Equal(t, keyID(expected), id) + require.Equal(t, decodedSecret(t, expected), got) require.Equal(t, 1, ff.called) - got, err = cache.Verifying(ctx, expected.Sequence) + key, err := cache.VerifyingKey(ctx, keyID(expected)) require.NoError(t, err) - require.Equal(t, expected, got) + require.Equal(t, decodedSecret(t, expected), key) require.Equal(t, 1, ff.called) cache.Close() - _, err = cache.Signing(ctx) - require.ErrorIs(t, err, ErrClosed) + _, _, err = cache.SigningKey(ctx) + require.ErrorIs(t, err, cryptokeys.ErrClosed) - _, err = cache.Verifying(ctx, expected.Sequence) - require.ErrorIs(t, err, ErrClosed) + _, err = cache.VerifyingKey(ctx, keyID(expected)) + require.ErrorIs(t, err, cryptokeys.ErrClosed) }) } @@ -480,8 +493,25 @@ func (f *fakeFetcher) Fetch(_ context.Context) ([]codersdk.CryptoKey, error) { return f.keys, nil } -func withClock(clock quartz.Clock) func(*wsproxy.CryptoKeyCache) { - return func(cache *wsproxy.CryptoKeyCache) { - cache.Clock = clock - } +func keyID(key codersdk.CryptoKey) string { + return strconv.FormatInt(int64(key.Sequence), 10) +} + +func decodedSecret(t *testing.T, key codersdk.CryptoKey) []byte { + t.Helper() + + secret, err := hex.DecodeString(key.Secret) + require.NoError(t, err) + + return secret +} + +func generateKey(t *testing.T, size int) string { + t.Helper() + + key := make([]byte, size) + _, err := rand.Read(key) + require.NoError(t, err) + + return hex.EncodeToString(key) } diff --git a/coderd/cryptokeys/dbkeycache.go b/coderd/cryptokeys/dbkeycache.go deleted file mode 100644 index 26941192619b2..0000000000000 --- a/coderd/cryptokeys/dbkeycache.go +++ /dev/null @@ -1,352 +0,0 @@ -package cryptokeys - -import ( - "context" - "encoding/hex" - "strconv" - "sync" - "time" - - "golang.org/x/xerrors" - - "cdr.dev/slog" - "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/db2sdk" - "github.com/coder/coder/v2/codersdk" - "github.com/coder/quartz" -) - -const ( - // latestSequence is a special sequence number that represents the latest key. - latestSequence = -1 - // refreshInterval is the interval at which the key cache will refresh. - refreshInterval = time.Minute * 10 -) - -type DBFetcher struct { - DB database.Store - Feature database.CryptoKeyFeature -} - -func (d *DBFetcher) Fetch(ctx context.Context) ([]codersdk.CryptoKey, error) { - keys, err := d.DB.GetCryptoKeysByFeature(ctx, d.Feature) - if err != nil { - return nil, xerrors.Errorf("get crypto keys by feature: %w", err) - } - - return db2sdk.CryptoKeys(keys), nil -} - -// CryptoKeyCache implements Keycache for callers with access to the database. -type CryptoKeyCache struct { - clock quartz.Clock - refreshCtx context.Context - refreshCancel context.CancelFunc - fetcher Fetcher - logger slog.Logger - feature codersdk.CryptoKeyFeature - - mu sync.Mutex - keys map[int32]codersdk.CryptoKey - lastFetch time.Time - refresher *quartz.Timer - fetching bool - closed bool - cond *sync.Cond -} - -type DBCacheOption func(*CryptoKeyCache) - -func WithDBCacheClock(clock quartz.Clock) DBCacheOption { - return func(d *CryptoKeyCache) { - d.clock = clock - } -} - -// NewSigningCache creates a new DBCache. Close should be called to -// release resources associated with its internal timer. -func NewSigningCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, feature codersdk.CryptoKeyFeature, opts ...func(*CryptoKeyCache)) (SigningKeycache, error) { - if !isSigningKeyFeature(feature) { - return nil, ErrInvalidFeature - } - - return newDBCache(ctx, logger, fetcher, feature, opts...) -} - -func NewEncryptionCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, feature codersdk.CryptoKeyFeature, opts ...func(*CryptoKeyCache)) (EncryptionKeycache, error) { - if !isEncryptionKeyFeature(feature) { - return nil, ErrInvalidFeature - } - - return newDBCache(ctx, logger, fetcher, feature, opts...) -} - -func newDBCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, feature codersdk.CryptoKeyFeature, opts ...func(*CryptoKeyCache)) (*CryptoKeyCache, error) { - cache := &CryptoKeyCache{ - clock: quartz.NewReal(), - logger: logger, - fetcher: fetcher, - feature: feature, - } - - for _, opt := range opts { - opt(cache) - } - - cache.cond = sync.NewCond(&cache.mu) - cache.refreshCtx, cache.refreshCancel = context.WithCancel(ctx) - cache.refresher = cache.clock.AfterFunc(refreshInterval, cache.refresh) - - keys, err := cache.cryptoKeys(ctx) - if err != nil { - cache.refreshCancel() - return nil, xerrors.Errorf("initial fetch: %w", err) - } - cache.keys = keys - return cache, nil -} - -func (d *CryptoKeyCache) EncryptingKey(ctx context.Context) (string, interface{}, error) { - if !isEncryptionKeyFeature(d.feature) { - return "", nil, ErrInvalidFeature - } - - key, err := d.cryptoKey(ctx, latestSequence) - if err != nil { - return "", nil, xerrors.Errorf("crypto key: %w", err) - } - - secret, err := hex.DecodeString(key.Secret) - if err != nil { - return "", nil, xerrors.Errorf("decode key: %w", err) - } - - return strconv.FormatInt(int64(key.Sequence), 10), secret, nil -} - -func (d *CryptoKeyCache) DecryptingKey(ctx context.Context, id string) (interface{}, error) { - if !isEncryptionKeyFeature(d.feature) { - return nil, ErrInvalidFeature - } - - i, err := strconv.ParseInt(id, 10, 64) - if err != nil { - return nil, xerrors.Errorf("parse id: %w", err) - } - - key, err := d.cryptoKey(ctx, int32(i)) - if err != nil { - return nil, xerrors.Errorf("crypto key: %w", err) - } - - secret, err := hex.DecodeString(key.Secret) - if err != nil { - return nil, xerrors.Errorf("decode key: %w", err) - } - - return secret, nil -} - -func (d *CryptoKeyCache) SigningKey(ctx context.Context) (string, interface{}, error) { - if !isSigningKeyFeature(d.feature) { - return "", nil, ErrInvalidFeature - } - - key, err := d.cryptoKey(ctx, latestSequence) - if err != nil { - return "", nil, xerrors.Errorf("crypto key: %w", err) - } - - return strconv.FormatInt(int64(key.Sequence), 10), key.Secret, nil -} - -func (d *CryptoKeyCache) VerifyingKey(ctx context.Context, sequence string) (interface{}, error) { - if !isSigningKeyFeature(d.feature) { - return nil, ErrInvalidFeature - } - - i, err := strconv.ParseInt(sequence, 10, 64) - if err != nil { - return nil, xerrors.Errorf("parse id: %w", err) - } - - key, err := d.cryptoKey(ctx, int32(i)) - if err != nil { - return nil, xerrors.Errorf("crypto key: %w", err) - } - - return key.Secret, nil -} - -func isEncryptionKeyFeature(feature codersdk.CryptoKeyFeature) bool { - return feature == codersdk.CryptoKeyFeatureWorkspaceApp -} - -func isSigningKeyFeature(feature codersdk.CryptoKeyFeature) bool { - switch feature { - case codersdk.CryptoKeyFeatureTailnetResume, codersdk.CryptoKeyFeatureOIDCConvert: - return true - default: - return false - } -} - -func idSecret(k database.CryptoKey) (string, interface{}, error) { - key, err := k.DecodeString() - if err != nil { - return "", nil, xerrors.Errorf("decode key: %w", err) - } - - return strconv.FormatInt(int64(k.Sequence), 10), key, nil -} - -func (k *CryptoKeyCache) cryptoKey(ctx context.Context, sequence int32) (codersdk.CryptoKey, error) { - k.mu.Lock() - defer k.mu.Unlock() - - if k.closed { - return codersdk.CryptoKey{}, ErrClosed - } - - var key codersdk.CryptoKey - var ok bool - for key, ok = k.key(sequence); !ok && k.fetching && !k.closed; { - k.cond.Wait() - } - - if k.closed { - return codersdk.CryptoKey{}, ErrClosed - } - - if ok { - return checkKey(key, sequence, k.clock.Now()) - } - - k.fetching = true - k.mu.Unlock() - - keys, err := k.cryptoKeys(ctx) - if err != nil { - return codersdk.CryptoKey{}, xerrors.Errorf("get keys: %w", err) - } - - k.mu.Lock() - k.lastFetch = k.clock.Now() - k.refresher.Reset(refreshInterval) - k.keys = keys - k.fetching = false - k.cond.Broadcast() - - key, ok = k.key(sequence) - if !ok { - return codersdk.CryptoKey{}, ErrKeyNotFound - } - - return checkKey(key, sequence, k.clock.Now()) -} - -func (k *CryptoKeyCache) key(sequence int32) (codersdk.CryptoKey, bool) { - if sequence == latestSequence { - return k.keys[latestSequence], k.keys[latestSequence].CanSign(k.clock.Now()) - } - - key, ok := k.keys[sequence] - return key, ok -} - -func checkKey(key codersdk.CryptoKey, sequence int32, now time.Time) (codersdk.CryptoKey, error) { - if sequence == latestSequence { - if !key.CanSign(now) { - return codersdk.CryptoKey{}, ErrKeyInvalid - } - return key, nil - } - - if !key.CanVerify(now) { - return codersdk.CryptoKey{}, ErrKeyInvalid - } - - return key, nil -} - -// refresh fetches the keys and updates the cache. -func (k *CryptoKeyCache) refresh() { - now := k.clock.Now("CryptoKeyCache", "refresh") - k.mu.Lock() - - if k.closed { - k.mu.Unlock() - return - } - - // If something's already fetching, we don't need to do anything. - if k.fetching { - k.mu.Unlock() - return - } - - // There's a window we must account for where the timer fires while a fetch - // is ongoing but prior to the timer getting reset. In this case we want to - // avoid double fetching. - if now.Sub(k.lastFetch) < refreshInterval { - k.mu.Unlock() - return - } - - k.fetching = true - - k.mu.Unlock() - keys, err := k.cryptoKeys(k.refreshCtx) - if err != nil { - k.logger.Error(k.refreshCtx, "fetch crypto keys", slog.Error(err)) - return - } - - k.mu.Lock() - defer k.mu.Unlock() - - k.lastFetch = k.clock.Now() - k.refresher.Reset(refreshInterval) - k.keys = keys - k.fetching = false - k.cond.Broadcast() -} - -// cryptoKeys queries the control plane for the crypto keys. -// Outside of initialization, this should only be called by fetch. -func (k *CryptoKeyCache) cryptoKeys(ctx context.Context) (map[int32]codersdk.CryptoKey, error) { - keys, err := k.fetcher.Fetch(ctx) - if err != nil { - return nil, xerrors.Errorf("crypto keys: %w", err) - } - cache := toKeyMap(keys, k.clock.Now()) - return cache, nil -} - -func toKeyMap(keys []codersdk.CryptoKey, now time.Time) map[int32]codersdk.CryptoKey { - m := make(map[int32]codersdk.CryptoKey) - var latest codersdk.CryptoKey - for _, key := range keys { - m[key.Sequence] = key - if key.Sequence > latest.Sequence && key.CanSign(now) { - m[latestSequence] = key - } - } - return m -} - -func (k *CryptoKeyCache) Close() error { - k.mu.Lock() - defer k.mu.Unlock() - - if k.closed { - return nil - } - - k.closed = true - k.refreshCancel() - k.refresher.Stop() - k.cond.Broadcast() - - return nil -} diff --git a/coderd/cryptokeys/dbkeycache_test.go b/coderd/cryptokeys/dbkeycache_test.go deleted file mode 100644 index cd68a196c493f..0000000000000 --- a/coderd/cryptokeys/dbkeycache_test.go +++ /dev/null @@ -1,222 +0,0 @@ -package cryptokeys_test - -import ( - "strconv" - "testing" - - "github.com/stretchr/testify/require" - "go.uber.org/goleak" - - "cdr.dev/slog/sloggers/slogtest" - - "github.com/coder/coder/v2/coderd/cryptokeys" - "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/dbgen" - "github.com/coder/coder/v2/coderd/database/dbtestutil" - "github.com/coder/coder/v2/testutil" - "github.com/coder/quartz" -) - -func TestMain(m *testing.M) { - goleak.VerifyTestMain(m) -} - -func TestDBKeyCache(t *testing.T) { - t.Parallel() - - t.Run("VerifyingKey", func(t *testing.T) { - t.Parallel() - - t.Run("HitsCache", func(t *testing.T) { - t.Parallel() - - var ( - db, _ = dbtestutil.NewDB(t) - clock = quartz.NewMock(t) - ctx = testutil.Context(t, testutil.WaitShort) - logger = slogtest.Make(t, nil) - fetcher = &cryptokeys.DBFetcher{DB: db, Feature: database.CryptoKeyFeatureOidcConvert} - ) - - key := dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureOidcConvert, - Sequence: 1, - StartsAt: clock.Now().UTC(), - }) - - k, err := cryptokeys.NewSigningCache(ctx, logger, fetcher, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock)) - require.NoError(t, err) - defer k.Close() - - got, err := k.VerifyingKey(ctx, keyID(key)) - require.NoError(t, err) - require.Equal(t, decodedSecret(t, key), got) - }) - - t.Run("NotFound", func(t *testing.T) { - t.Parallel() - - var ( - db, _ = dbtestutil.NewDB(t) - clock = quartz.NewMock(t) - ctx = testutil.Context(t, testutil.WaitShort) - logger = slogtest.Make(t, nil) - fetcher = &cryptokeys.DBFetcher{DB: db, Feature: database.CryptoKeyFeatureOidcConvert} - ) - - k, err := cryptokeys.NewSigningCache(ctx, logger, fetcher, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock)) - require.NoError(t, err) - defer k.Close() - - _, err = k.VerifyingKey(ctx, "123") - require.ErrorIs(t, err, cryptokeys.ErrKeyNotFound) - }) - }) - - t.Run("Signing", func(t *testing.T) { - t.Parallel() - - var ( - db, _ = dbtestutil.NewDB(t) - clock = quartz.NewMock(t) - ctx = testutil.Context(t, testutil.WaitShort) - logger = slogtest.Make(t, nil) - fetcher = &cryptokeys.DBFetcher{DB: db, Feature: database.CryptoKeyFeatureOidcConvert} - ) - - _ = dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureOidcConvert, - Sequence: 10, - StartsAt: clock.Now().UTC(), - }) - - expectedKey := dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureOidcConvert, - Sequence: 12, - StartsAt: clock.Now().UTC(), - }) - - _ = dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureOidcConvert, - Sequence: 2, - StartsAt: clock.Now().UTC(), - }) - - k, err := cryptokeys.NewSigningCache(ctx, logger, fetcher, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock)) - require.NoError(t, err) - defer k.Close() - - id, key, err := k.SigningKey(ctx) - require.NoError(t, err) - require.Equal(t, keyID(expectedKey), id) - require.Equal(t, decodedSecret(t, expectedKey), key) - }) - - t.Run("Closed", func(t *testing.T) { - t.Parallel() - - var ( - db, _ = dbtestutil.NewDB(t) - clock = quartz.NewMock(t) - ctx = testutil.Context(t, testutil.WaitShort) - logger = slogtest.Make(t, nil) - fetcher = &cryptokeys.DBFetcher{DB: db, Feature: database.CryptoKeyFeatureOidcConvert} - ) - - expectedKey := dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureOidcConvert, - Sequence: 10, - StartsAt: clock.Now(), - }) - - k, err := cryptokeys.NewSigningCache(ctx, logger, fetcher, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock)) - require.NoError(t, err) - defer k.Close() - - id, key, err := k.SigningKey(ctx) - require.NoError(t, err) - require.Equal(t, keyID(expectedKey), id) - require.Equal(t, decodedSecret(t, expectedKey), key) - - key, err = k.VerifyingKey(ctx, keyID(expectedKey)) - require.NoError(t, err) - require.Equal(t, decodedSecret(t, expectedKey), key) - - k.Close() - - _, _, err = k.SigningKey(ctx) - require.ErrorIs(t, err, cryptokeys.ErrClosed) - - _, err = k.VerifyingKey(ctx, keyID(expectedKey)) - require.ErrorIs(t, err, cryptokeys.ErrClosed) - }) - - t.Run("InvalidSigningFeature", func(t *testing.T) { - t.Parallel() - - var ( - db, _ = dbtestutil.NewDB(t) - clock = quartz.NewMock(t) - logger = slogtest.Make(t, nil) - fetcher = &cryptokeys.DBFetcher{DB: db, Feature: database.CryptoKeyFeatureOidcConvert} - ctx = testutil.Context(t, testutil.WaitShort) - ) - - _, err := cryptokeys.NewSigningCache(ctx, logger, fetcher, database.CryptoKeyFeatureWorkspaceApps, cryptokeys.WithDBCacheClock(clock)) - require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature) - - // Instantiate a signing cache and try to use it as an encryption cache. - sc, err := cryptokeys.NewSigningCache(ctx, logger, fetcher, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock)) - require.NoError(t, err) - defer sc.Close() - - ec, ok := sc.(cryptokeys.EncryptionKeycache) - require.True(t, ok) - _, _, err = ec.EncryptingKey(ctx) - require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature) - - _, err = ec.DecryptingKey(ctx, "123") - require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature) - }) - - t.Run("InvalidEncryptionFeature", func(t *testing.T) { - t.Parallel() - - var ( - db, _ = dbtestutil.NewDB(t) - clock = quartz.NewMock(t) - logger = slogtest.Make(t, nil) - fetcher = &cryptokeys.DBFetcher{DB: db, Feature: database.CryptoKeyFeatureOidcConvert} - ctx = testutil.Context(t, testutil.WaitShort) - ) - - _, err := cryptokeys.NewEncryptionCache(ctx, logger, fetcher, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock)) - require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature) - - // Instantiate an encryption cache and try to use it as a signing cache. - ec, err := cryptokeys.NewEncryptionCache(ctx, logger, fetcher, database.CryptoKeyFeatureWorkspaceApps, cryptokeys.WithDBCacheClock(clock)) - require.NoError(t, err) - defer ec.Close() - - sc, ok := ec.(cryptokeys.SigningKeycache) - require.True(t, ok) - _, _, err = sc.SigningKey(ctx) - require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature) - - _, err = sc.VerifyingKey(ctx, "123") - require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature) - }) -} - -func keyID(key database.CryptoKey) string { - return strconv.FormatInt(int64(key.Sequence), 10) -} - -func decodedSecret(t *testing.T, key database.CryptoKey) []byte { - t.Helper() - - secret, err := key.DecodeString() - require.NoError(t, err) - - return secret -} diff --git a/coderd/cryptokeys/keycache.go b/coderd/cryptokeys/keycache.go deleted file mode 100644 index 076256448d659..0000000000000 --- a/coderd/cryptokeys/keycache.go +++ /dev/null @@ -1,47 +0,0 @@ -package cryptokeys - -import ( - "context" - "io" - - "golang.org/x/xerrors" - - "github.com/coder/coder/v2/codersdk" -) - -var ( - ErrKeyNotFound = xerrors.New("key not found") - ErrKeyInvalid = xerrors.New("key is invalid for use") - ErrClosed = xerrors.New("closed") - ErrInvalidFeature = xerrors.New("invalid feature for this operation") -) - -type Fetcher interface { - Fetch(ctx context.Context) ([]codersdk.CryptoKey, error) -} - -type EncryptionKeycache interface { - // EncryptingKey returns the latest valid key for encrypting payloads. A valid - // key is one that is both past its start time and before its deletion time. - EncryptingKey(ctx context.Context) (id string, key interface{}, err error) - // DecryptingKey returns the key with the provided id which maps to its sequence - // number. The key is valid for decryption as long as it is not deleted or past - // its deletion date. We must allow for keys prior to their start time to - // account for clock skew between peers (one key may be past its start time on - // one machine while another is not). - DecryptingKey(ctx context.Context, id string) (key interface{}, err error) - io.Closer -} - -type SigningKeycache interface { - // SigningKey returns the latest valid key for signing. A valid key is one - // that is both past its start time and before its deletion time. - SigningKey(ctx context.Context) (id string, key interface{}, err error) - // VerifyingKey returns the key with the provided id which should map to its - // sequence number. The key is valid for verifying as long as it is not deleted - // or past its deletion date. We must allow for keys prior to their start time - // to account for clock skew between peers (one key may be past its start time - // on one machine while another is not). - VerifyingKey(ctx context.Context, id string) (key interface{}, err error) - io.Closer -} diff --git a/coderd/jwtutils/jwt_test.go b/coderd/jwtutils/jwt_test.go index 46bb45a2cecd5..f40ded21c9223 100644 --- a/coderd/jwtutils/jwt_test.go +++ b/coderd/jwtutils/jwt_test.go @@ -18,6 +18,7 @@ import ( "github.com/coder/coder/v2/coderd/database/dbgen" "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/jwtutils" + "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/testutil" ) @@ -242,7 +243,7 @@ func TestJWS(t *testing.T) { fetcher = &cryptokeys.DBFetcher{DB: db, Feature: database.CryptoKeyFeatureOidcConvert} ) - cache, err := cryptokeys.NewSigningCache(ctx, log, fetcher, database.CryptoKeyFeatureOidcConvert) + cache, err := cryptokeys.NewSigningCache(ctx, log, fetcher, codersdk.CryptoKeyFeatureOIDCConvert) require.NoError(t, err) claims := testClaims{ @@ -333,7 +334,7 @@ func TestJWE(t *testing.T) { fetcher = &cryptokeys.DBFetcher{DB: db, Feature: database.CryptoKeyFeatureOidcConvert} ) - cache, err := cryptokeys.NewEncryptionCache(ctx, log, fetcher, database.CryptoKeyFeatureWorkspaceApps) + cache, err := cryptokeys.NewEncryptionCache(ctx, log, fetcher, codersdk.CryptoKeyFeatureWorkspaceApp) require.NoError(t, err) claims := testClaims{ diff --git a/enterprise/wsproxy/keycache.go b/enterprise/wsproxy/keycache.go deleted file mode 100644 index a877b9757d250..0000000000000 --- a/enterprise/wsproxy/keycache.go +++ /dev/null @@ -1,224 +0,0 @@ -package wsproxy - -import ( - "context" - "sync" - "time" - - "golang.org/x/xerrors" - - "cdr.dev/slog" - - "github.com/coder/coder/v2/coderd/cryptokeys" - "github.com/coder/coder/v2/codersdk" - "github.com/coder/quartz" -) - -const ( - // latestSequence is a special sequence number that represents the latest key. - latestSequence = -1 - // refreshInterval is the interval at which the key cache will refresh. - refreshInterval = time.Minute * 10 -) - -type Fetcher interface { - Fetch(ctx context.Context) ([]codersdk.CryptoKey, error) -} - -type CryptoKeyCache struct { - Clock quartz.Clock - refreshCtx context.Context - refreshCancel context.CancelFunc - fetcher Fetcher - logger slog.Logger - - mu sync.Mutex - keys map[int32]codersdk.CryptoKey - lastFetch time.Time - refresher *quartz.Timer - fetching bool - closed bool - cond *sync.Cond -} - -func NewCryptoKeyCache(ctx context.Context, log slog.Logger, client Fetcher, opts ...func(*CryptoKeyCache)) (*CryptoKeyCache, error) { - cache := &CryptoKeyCache{ - Clock: quartz.NewReal(), - logger: log, - fetcher: client, - } - - for _, opt := range opts { - opt(cache) - } - - cache.cond = sync.NewCond(&cache.mu) - cache.refreshCtx, cache.refreshCancel = context.WithCancel(ctx) - cache.refresher = cache.Clock.AfterFunc(refreshInterval, cache.refresh) - - keys, err := cache.cryptoKeys(ctx) - if err != nil { - cache.refreshCancel() - return nil, xerrors.Errorf("initial fetch: %w", err) - } - cache.keys = keys - - return cache, nil -} - -func (k *CryptoKeyCache) Signing(ctx context.Context) (codersdk.CryptoKey, error) { - return k.cryptoKey(ctx, latestSequence) -} - -func (k *CryptoKeyCache) Verifying(ctx context.Context, sequence int32) (codersdk.CryptoKey, error) { - return k.cryptoKey(ctx, sequence) -} - -func (k *CryptoKeyCache) cryptoKey(ctx context.Context, sequence int32) (codersdk.CryptoKey, error) { - k.mu.Lock() - defer k.mu.Unlock() - - if k.closed { - return codersdk.CryptoKey{}, cryptokeys.ErrClosed - } - - var key codersdk.CryptoKey - var ok bool - for key, ok = k.key(sequence); !ok && k.fetching && !k.closed; { - k.cond.Wait() - } - - if k.closed { - return codersdk.CryptoKey{}, cryptokeys.ErrClosed - } - - if ok { - return checkKey(key, sequence, k.Clock.Now()) - } - - k.fetching = true - k.mu.Unlock() - - keys, err := k.cryptoKeys(ctx) - if err != nil { - return codersdk.CryptoKey{}, xerrors.Errorf("get keys: %w", err) - } - - k.mu.Lock() - k.lastFetch = k.Clock.Now() - k.refresher.Reset(refreshInterval) - k.keys = keys - k.fetching = false - k.cond.Broadcast() - - key, ok = k.key(sequence) - if !ok { - return codersdk.CryptoKey{}, cryptokeys.ErrKeyNotFound - } - - return checkKey(key, sequence, k.Clock.Now()) -} - -func (k *CryptoKeyCache) key(sequence int32) (codersdk.CryptoKey, bool) { - if sequence == latestSequence { - return k.keys[latestSequence], k.keys[latestSequence].CanSign(k.Clock.Now()) - } - - key, ok := k.keys[sequence] - return key, ok -} - -func checkKey(key codersdk.CryptoKey, sequence int32, now time.Time) (codersdk.CryptoKey, error) { - if sequence == latestSequence { - if !key.CanSign(now) { - return codersdk.CryptoKey{}, cryptokeys.ErrKeyInvalid - } - return key, nil - } - - if !key.CanVerify(now) { - return codersdk.CryptoKey{}, cryptokeys.ErrKeyInvalid - } - - return key, nil -} - -// refresh fetches the keys from the control plane and updates the cache. -func (k *CryptoKeyCache) refresh() { - now := k.Clock.Now("CryptoKeyCache", "refresh") - k.mu.Lock() - - if k.closed { - k.mu.Unlock() - return - } - - // If something's already fetching, we don't need to do anything. - if k.fetching { - k.mu.Unlock() - return - } - - // There's a window we must account for where the timer fires while a fetch - // is ongoing but prior to the timer getting reset. In this case we want to - // avoid double fetching. - if now.Sub(k.lastFetch) < refreshInterval { - k.mu.Unlock() - return - } - - k.fetching = true - - k.mu.Unlock() - keys, err := k.cryptoKeys(k.refreshCtx) - if err != nil { - k.logger.Error(k.refreshCtx, "fetch crypto keys", slog.Error(err)) - return - } - - k.mu.Lock() - defer k.mu.Unlock() - - k.lastFetch = k.Clock.Now() - k.refresher.Reset(refreshInterval) - k.keys = keys - k.fetching = false - k.cond.Broadcast() -} - -// cryptoKeys queries the control plane for the crypto keys. -// Outside of initialization, this should only be called by fetch. -func (k *CryptoKeyCache) cryptoKeys(ctx context.Context) (map[int32]codersdk.CryptoKey, error) { - keys, err := k.fetcher.Fetch(ctx) - if err != nil { - return nil, xerrors.Errorf("crypto keys: %w", err) - } - cache := toKeyMap(keys, k.Clock.Now()) - return cache, nil -} - -func toKeyMap(keys []codersdk.CryptoKey, now time.Time) map[int32]codersdk.CryptoKey { - m := make(map[int32]codersdk.CryptoKey) - var latest codersdk.CryptoKey - for _, key := range keys { - m[key.Sequence] = key - if key.Sequence > latest.Sequence && key.CanSign(now) { - m[latestSequence] = key - } - } - return m -} - -func (k *CryptoKeyCache) Close() { - k.mu.Lock() - defer k.mu.Unlock() - - if k.closed { - return - } - - k.closed = true - k.refreshCancel() - k.refresher.Stop() - k.cond.Broadcast() -} diff --git a/enterprise/wsproxy/keycache_test.go b/enterprise/wsproxy/keycache_test.go deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/enterprise/wsproxy/keyfetcher.go b/enterprise/wsproxy/keyfetcher.go new file mode 100644 index 0000000000000..81b71301b610f --- /dev/null +++ b/enterprise/wsproxy/keyfetcher.go @@ -0,0 +1,25 @@ +package wsproxy + +import ( + "context" + + "github.com/coder/coder/v2/coderd/cryptokeys" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk" + "golang.org/x/xerrors" +) + +var _ cryptokeys.Fetcher = &ProxyFetcher{} + +type ProxyFetcher struct { + Client *wsproxysdk.Client + Feature codersdk.CryptoKeyFeature +} + +func (p *ProxyFetcher) Fetch(ctx context.Context) ([]codersdk.CryptoKey, error) { + keys, err := p.Client.CryptoKeys(ctx) + if err != nil { + return nil, xerrors.Errorf("crypto keys: %w", err) + } + return keys.CryptoKeys, nil +} From b991aac2cd7f72e364fc981821039d0af45b3cdc Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Wed, 16 Oct 2024 00:29:13 +0000 Subject: [PATCH 4/7] fix tests --- coderd/cryptokeys/cache.go | 3 -- coderd/cryptokeys/cache_internal_test.go | 1 - coderd/cryptokeys/cache_test.go | 56 ++++++++++++------------ 3 files changed, 28 insertions(+), 32 deletions(-) delete mode 100644 coderd/cryptokeys/cache_internal_test.go diff --git a/coderd/cryptokeys/cache.go b/coderd/cryptokeys/cache.go index 8436b7099ad1c..e074f5de4b83b 100644 --- a/coderd/cryptokeys/cache.go +++ b/coderd/cryptokeys/cache.go @@ -13,7 +13,6 @@ import ( "cdr.dev/slog" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/db2sdk" - "github.com/coder/coder/v2/coderd/jwtutils" "github.com/coder/coder/v2/codersdk" "github.com/coder/quartz" ) @@ -25,8 +24,6 @@ var ( ErrInvalidFeature = xerrors.New("invalid feature for this operation") ) -var _ jwtutils.SigningKeyProvider = &cache{} - type Fetcher interface { Fetch(ctx context.Context) ([]codersdk.CryptoKey, error) } diff --git a/coderd/cryptokeys/cache_internal_test.go b/coderd/cryptokeys/cache_internal_test.go deleted file mode 100644 index 4845ac9198d2c..0000000000000 --- a/coderd/cryptokeys/cache_internal_test.go +++ /dev/null @@ -1 +0,0 @@ -package cryptokeys diff --git a/coderd/cryptokeys/cache_test.go b/coderd/cryptokeys/cache_test.go index 58699b0b4b09f..92fc4527ae7b3 100644 --- a/coderd/cryptokeys/cache_test.go +++ b/coderd/cryptokeys/cache_test.go @@ -39,7 +39,7 @@ func TestCryptoKeyCache(t *testing.T) { now := clock.Now().UTC() expected := codersdk.CryptoKey{ - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, + Feature: codersdk.CryptoKeyFeatureTailnetResume, Secret: generateKey(t, 64), Sequence: 2, StartsAt: now, @@ -49,7 +49,7 @@ func TestCryptoKeyCache(t *testing.T) { keys: []codersdk.CryptoKey{expected}, } - cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureWorkspaceApp, cryptokeys.WithCacheClock(clock)) + cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock)) require.NoError(t, err) id, got, err := cache.SigningKey(ctx) @@ -71,11 +71,11 @@ func TestCryptoKeyCache(t *testing.T) { keys: []codersdk.CryptoKey{}, } - cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureWorkspaceApp, cryptokeys.WithCacheClock(clock)) + cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock)) require.NoError(t, err) expected := codersdk.CryptoKey{ - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, + Feature: codersdk.CryptoKeyFeatureTailnetResume, Secret: generateKey(t, 64), Sequence: 12, StartsAt: clock.Now().UTC(), @@ -109,7 +109,7 @@ func TestCryptoKeyCache(t *testing.T) { now := clock.Now().UTC() expected := codersdk.CryptoKey{ - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, + Feature: codersdk.CryptoKeyFeatureTailnetResume, Secret: generateKey(t, 64), Sequence: 1, StartsAt: clock.Now().UTC(), @@ -119,7 +119,7 @@ func TestCryptoKeyCache(t *testing.T) { keys: []codersdk.CryptoKey{ expected, { - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, + Feature: codersdk.CryptoKeyFeatureTailnetResume, Secret: generateKey(t, 64), Sequence: 2, StartsAt: now.Add(-time.Second), @@ -128,7 +128,7 @@ func TestCryptoKeyCache(t *testing.T) { }, } - cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureWorkspaceApp, cryptokeys.WithCacheClock(clock)) + cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock)) require.NoError(t, err) id, got, err := cache.SigningKey(ctx) @@ -150,7 +150,7 @@ func TestCryptoKeyCache(t *testing.T) { keys: []codersdk.CryptoKey{}, } - cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureWorkspaceApp) + cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume) require.NoError(t, err) _, _, err = cache.SigningKey(ctx) @@ -172,7 +172,7 @@ func TestCryptoKeyCache(t *testing.T) { now := clock.Now().UTC() expected := codersdk.CryptoKey{ - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, + Feature: codersdk.CryptoKeyFeatureTailnetResume, Secret: generateKey(t, 64), Sequence: 12, StartsAt: now, @@ -181,7 +181,7 @@ func TestCryptoKeyCache(t *testing.T) { keys: []codersdk.CryptoKey{ expected, { - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, + Feature: codersdk.CryptoKeyFeatureTailnetResume, Secret: generateKey(t, 64), Sequence: 13, StartsAt: now, @@ -189,7 +189,7 @@ func TestCryptoKeyCache(t *testing.T) { }, } - cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureWorkspaceApp, cryptokeys.WithCacheClock(clock)) + cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock)) require.NoError(t, err) got, err := cache.VerifyingKey(ctx, keyID(expected)) @@ -210,11 +210,11 @@ func TestCryptoKeyCache(t *testing.T) { keys: []codersdk.CryptoKey{}, } - cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureWorkspaceApp, cryptokeys.WithCacheClock(clock)) + cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock)) require.NoError(t, err) expected := codersdk.CryptoKey{ - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, + Feature: codersdk.CryptoKeyFeatureTailnetResume, Secret: generateKey(t, 64), Sequence: 12, StartsAt: clock.Now().UTC(), @@ -244,7 +244,7 @@ func TestCryptoKeyCache(t *testing.T) { now := clock.Now().UTC() expected := codersdk.CryptoKey{ - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, + Feature: codersdk.CryptoKeyFeatureTailnetResume, Secret: generateKey(t, 64), Sequence: 12, StartsAt: now.Add(-time.Second), @@ -256,7 +256,7 @@ func TestCryptoKeyCache(t *testing.T) { }, } - cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureWorkspaceApp, cryptokeys.WithCacheClock(clock)) + cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock)) require.NoError(t, err) got, err := cache.VerifyingKey(ctx, keyID(expected)) @@ -276,7 +276,7 @@ func TestCryptoKeyCache(t *testing.T) { now := clock.Now().UTC() expected := codersdk.CryptoKey{ - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, + Feature: codersdk.CryptoKeyFeatureTailnetResume, Secret: generateKey(t, 64), Sequence: 12, StartsAt: now.Add(-time.Second), @@ -289,7 +289,7 @@ func TestCryptoKeyCache(t *testing.T) { }, } - cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureWorkspaceApp, cryptokeys.WithCacheClock(clock)) + cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock)) require.NoError(t, err) _, err = cache.VerifyingKey(ctx, keyID(expected)) @@ -310,7 +310,7 @@ func TestCryptoKeyCache(t *testing.T) { keys: []codersdk.CryptoKey{}, } - cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureWorkspaceApp, cryptokeys.WithCacheClock(clock)) + cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock)) require.NoError(t, err) _, err = cache.VerifyingKey(ctx, "1") @@ -329,7 +329,7 @@ func TestCryptoKeyCache(t *testing.T) { now := clock.Now().UTC() expected := codersdk.CryptoKey{ - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, + Feature: codersdk.CryptoKeyFeatureTailnetResume, Secret: generateKey(t, 64), Sequence: 12, StartsAt: now, @@ -341,7 +341,7 @@ func TestCryptoKeyCache(t *testing.T) { }, } - cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureWorkspaceApp, cryptokeys.WithCacheClock(clock)) + cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock)) require.NoError(t, err) id, got, err := cache.SigningKey(ctx) @@ -351,7 +351,7 @@ func TestCryptoKeyCache(t *testing.T) { require.Equal(t, 1, ff.called) newKey := codersdk.CryptoKey{ - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, + Feature: codersdk.CryptoKeyFeatureTailnetResume, Secret: generateKey(t, 64), Sequence: 13, StartsAt: now, @@ -392,7 +392,7 @@ func TestCryptoKeyCache(t *testing.T) { now := clock.Now().UTC() expected := codersdk.CryptoKey{ - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, + Feature: codersdk.CryptoKeyFeatureTailnetResume, Secret: generateKey(t, 64), Sequence: 12, StartsAt: now, @@ -406,15 +406,15 @@ func TestCryptoKeyCache(t *testing.T) { // Create a trap that blocks when the refresh timer fires. trap := clock.Trap().Now("refresh") - cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureWorkspaceApp, cryptokeys.WithCacheClock(clock)) + cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock)) require.NoError(t, err) _, wait := clock.AdvanceNext() trapped := trap.MustWait(ctx) newKey := codersdk.CryptoKey{ - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key2", + Feature: codersdk.CryptoKeyFeatureTailnetResume, + Secret: generateKey(t, 64), Sequence: 13, StartsAt: now, } @@ -448,8 +448,8 @@ func TestCryptoKeyCache(t *testing.T) { now := clock.Now() expected := codersdk.CryptoKey{ - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key1", + Feature: codersdk.CryptoKeyFeatureTailnetResume, + Secret: generateKey(t, 64), Sequence: 12, StartsAt: now, } @@ -459,7 +459,7 @@ func TestCryptoKeyCache(t *testing.T) { }, } - cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureWorkspaceApp, cryptokeys.WithCacheClock(clock)) + cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock)) require.NoError(t, err) id, got, err := cache.SigningKey(ctx) From f6e2c266ec1d47272aa807ade177f1de4bb1f399 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Wed, 16 Oct 2024 00:37:22 +0000 Subject: [PATCH 5/7] fix tests --- coderd/jwtutils/jwe.go | 2 +- coderd/jwtutils/jwt_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/coderd/jwtutils/jwe.go b/coderd/jwtutils/jwe.go index f50cacb62de7c..d03816a477a26 100644 --- a/coderd/jwtutils/jwe.go +++ b/coderd/jwtutils/jwe.go @@ -27,7 +27,7 @@ type DecryptKeyProvider interface { func Encrypt(ctx context.Context, e EncryptKeyProvider, claims Claims) (string, error) { id, key, err := e.EncryptingKey(ctx) if err != nil { - return "", xerrors.Errorf("get signing key: %w", err) + return "", xerrors.Errorf("encrypting key: %w", err) } encrypter, err := jose.NewEncrypter( diff --git a/coderd/jwtutils/jwt_test.go b/coderd/jwtutils/jwt_test.go index f40ded21c9223..697e5d210d858 100644 --- a/coderd/jwtutils/jwt_test.go +++ b/coderd/jwtutils/jwt_test.go @@ -331,7 +331,7 @@ func TestJWE(t *testing.T) { }) log = slogtest.Make(t, nil) - fetcher = &cryptokeys.DBFetcher{DB: db, Feature: database.CryptoKeyFeatureOidcConvert} + fetcher = &cryptokeys.DBFetcher{DB: db, Feature: database.CryptoKeyFeatureWorkspaceApps} ) cache, err := cryptokeys.NewEncryptionCache(ctx, log, fetcher, codersdk.CryptoKeyFeatureWorkspaceApp) From e11e2d216fe264f98a10cc769fb20da8cfe0869b Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Wed, 16 Oct 2024 18:35:57 +0000 Subject: [PATCH 6/7] pr comments --- coderd/cryptokeys/cache.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/coderd/cryptokeys/cache.go b/coderd/cryptokeys/cache.go index e074f5de4b83b..a693c9b562853 100644 --- a/coderd/cryptokeys/cache.go +++ b/coderd/cryptokeys/cache.go @@ -101,16 +101,18 @@ func WithCacheClock(clock quartz.Clock) CacheOption { } } -// NewSigningCache instantiates a cache. Close should be called to -// release resources associated with its internal timer. -func NewSigningCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, feature codersdk.CryptoKeyFeature, opts ...func(*cache)) (SigningKeycache, error) { +// NewSigningCache instantiates a cache. Close should be called to release resources +// associated with its internal timer. +func NewSigningCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, + feature codersdk.CryptoKeyFeature, opts ...func(*cache)) (SigningKeycache, error) { if !isSigningKeyFeature(feature) { return nil, xerrors.Errorf("invalid feature: %s", feature) } return newCache(ctx, logger, fetcher, feature, opts...) } -func NewEncryptionCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, feature codersdk.CryptoKeyFeature, opts ...func(*cache)) (EncryptionKeycache, error) { +func NewEncryptionCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, + feature codersdk.CryptoKeyFeature, opts ...func(*cache)) (EncryptionKeycache, error) { if !isEncryptionKeyFeature(feature) { return nil, xerrors.Errorf("invalid feature: %s", feature) } @@ -288,15 +290,14 @@ func checkKey(key codersdk.CryptoKey, sequence int32, now time.Time) (string, [] func (c *cache) refresh() { now := c.clock.Now("CryptoKeyCache", "refresh") c.mu.Lock() + defer c.mu.Unlock() if c.closed { - c.mu.Unlock() return } // If something's already fetching, we don't need to do anything. if c.fetching { - c.mu.Unlock() return } @@ -304,7 +305,6 @@ func (c *cache) refresh() { // is ongoing but prior to the timer getting reset. In this case we want to // avoid double fetching. if now.Sub(c.lastFetch) < refreshInterval { - c.mu.Unlock() return } @@ -317,8 +317,8 @@ func (c *cache) refresh() { return } + // We don't defer an unlock here due to the deferred unlock at the top of the function. c.mu.Lock() - defer c.mu.Unlock() c.lastFetch = c.clock.Now() c.refresher.Reset(refreshInterval) From 37928610fb4fd3ac945690cc1ea3eeb6869d924c Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Wed, 16 Oct 2024 18:43:00 +0000 Subject: [PATCH 7/7] fmt --- coderd/cryptokeys/cache.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/coderd/cryptokeys/cache.go b/coderd/cryptokeys/cache.go index a693c9b562853..74fb025d416fd 100644 --- a/coderd/cryptokeys/cache.go +++ b/coderd/cryptokeys/cache.go @@ -104,7 +104,8 @@ func WithCacheClock(clock quartz.Clock) CacheOption { // NewSigningCache instantiates a cache. Close should be called to release resources // associated with its internal timer. func NewSigningCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, - feature codersdk.CryptoKeyFeature, opts ...func(*cache)) (SigningKeycache, error) { + feature codersdk.CryptoKeyFeature, opts ...func(*cache), +) (SigningKeycache, error) { if !isSigningKeyFeature(feature) { return nil, xerrors.Errorf("invalid feature: %s", feature) } @@ -112,7 +113,8 @@ func NewSigningCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, } func NewEncryptionCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, - feature codersdk.CryptoKeyFeature, opts ...func(*cache)) (EncryptionKeycache, error) { + feature codersdk.CryptoKeyFeature, opts ...func(*cache), +) (EncryptionKeycache, error) { if !isEncryptionKeyFeature(feature) { return nil, xerrors.Errorf("invalid feature: %s", feature) }