diff --git a/coderd/cryptokeys/cache.go b/coderd/cryptokeys/cache.go new file mode 100644 index 0000000000000..74fb025d416fd --- /dev/null +++ b/coderd/cryptokeys/cache.go @@ -0,0 +1,369 @@ +package cryptokeys + +import ( + "context" + "encoding/hex" + "io" + "strconv" + "sync" + "time" + + "golang.org/x/xerrors" + + "cdr.dev/slog" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/db2sdk" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/quartz" +) + +var ( + ErrKeyNotFound = xerrors.New("key not found") + ErrKeyInvalid = xerrors.New("key is invalid for use") + ErrClosed = xerrors.New("closed") + ErrInvalidFeature = xerrors.New("invalid feature for this operation") +) + +type Fetcher interface { + Fetch(ctx context.Context) ([]codersdk.CryptoKey, error) +} + +type EncryptionKeycache interface { + // EncryptingKey returns the latest valid key for encrypting payloads. A valid + // key is one that is both past its start time and before its deletion time. + EncryptingKey(ctx context.Context) (id string, key interface{}, err error) + // DecryptingKey returns the key with the provided id which maps to its sequence + // number. The key is valid for decryption as long as it is not deleted or past + // its deletion date. We must allow for keys prior to their start time to + // account for clock skew between peers (one key may be past its start time on + // one machine while another is not). + DecryptingKey(ctx context.Context, id string) (key interface{}, err error) + io.Closer +} + +type SigningKeycache interface { + // SigningKey returns the latest valid key for signing. A valid key is one + // that is both past its start time and before its deletion time. + SigningKey(ctx context.Context) (id string, key interface{}, err error) + // VerifyingKey returns the key with the provided id which should map to its + // sequence number. The key is valid for verifying as long as it is not deleted + // or past its deletion date. We must allow for keys prior to their start time + // to account for clock skew between peers (one key may be past its start time + // on one machine while another is not). + VerifyingKey(ctx context.Context, id string) (key interface{}, err error) + io.Closer +} + +const ( + // latestSequence is a special sequence number that represents the latest key. + latestSequence = -1 + // refreshInterval is the interval at which the key cache will refresh. + refreshInterval = time.Minute * 10 +) + +type DBFetcher struct { + DB database.Store + Feature database.CryptoKeyFeature +} + +func (d *DBFetcher) Fetch(ctx context.Context) ([]codersdk.CryptoKey, error) { + keys, err := d.DB.GetCryptoKeysByFeature(ctx, d.Feature) + if err != nil { + return nil, xerrors.Errorf("get crypto keys by feature: %w", err) + } + + return db2sdk.CryptoKeys(keys), nil +} + +// cache implements the caching functionality for both signing and encryption keys. +type cache struct { + clock quartz.Clock + refreshCtx context.Context + refreshCancel context.CancelFunc + fetcher Fetcher + logger slog.Logger + feature codersdk.CryptoKeyFeature + + mu sync.Mutex + keys map[int32]codersdk.CryptoKey + lastFetch time.Time + refresher *quartz.Timer + fetching bool + closed bool + cond *sync.Cond +} + +type CacheOption func(*cache) + +func WithCacheClock(clock quartz.Clock) CacheOption { + return func(d *cache) { + d.clock = clock + } +} + +// NewSigningCache instantiates a cache. Close should be called to release resources +// associated with its internal timer. +func NewSigningCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, + feature codersdk.CryptoKeyFeature, opts ...func(*cache), +) (SigningKeycache, error) { + if !isSigningKeyFeature(feature) { + return nil, xerrors.Errorf("invalid feature: %s", feature) + } + return newCache(ctx, logger, fetcher, feature, opts...) +} + +func NewEncryptionCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, + feature codersdk.CryptoKeyFeature, opts ...func(*cache), +) (EncryptionKeycache, error) { + if !isEncryptionKeyFeature(feature) { + return nil, xerrors.Errorf("invalid feature: %s", feature) + } + return newCache(ctx, logger, fetcher, feature, opts...) +} + +func newCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, feature codersdk.CryptoKeyFeature, opts ...func(*cache)) (*cache, error) { + cache := &cache{ + clock: quartz.NewReal(), + logger: logger, + fetcher: fetcher, + feature: feature, + } + + for _, opt := range opts { + opt(cache) + } + + cache.cond = sync.NewCond(&cache.mu) + cache.refreshCtx, cache.refreshCancel = context.WithCancel(ctx) + cache.refresher = cache.clock.AfterFunc(refreshInterval, cache.refresh) + + keys, err := cache.cryptoKeys(ctx) + if err != nil { + cache.refreshCancel() + return nil, xerrors.Errorf("initial fetch: %w", err) + } + cache.keys = keys + return cache, nil +} + +func (c *cache) EncryptingKey(ctx context.Context) (string, interface{}, error) { + if !isEncryptionKeyFeature(c.feature) { + return "", nil, ErrInvalidFeature + } + + return c.cryptoKey(ctx, latestSequence) +} + +func (c *cache) DecryptingKey(ctx context.Context, id string) (interface{}, error) { + if !isEncryptionKeyFeature(c.feature) { + return nil, ErrInvalidFeature + } + + seq, err := strconv.ParseInt(id, 10, 64) + if err != nil { + return nil, xerrors.Errorf("parse id: %w", err) + } + + _, secret, err := c.cryptoKey(ctx, int32(seq)) + if err != nil { + return nil, xerrors.Errorf("crypto key: %w", err) + } + return secret, nil +} + +func (c *cache) SigningKey(ctx context.Context) (string, interface{}, error) { + if !isSigningKeyFeature(c.feature) { + return "", nil, ErrInvalidFeature + } + + return c.cryptoKey(ctx, latestSequence) +} + +func (c *cache) VerifyingKey(ctx context.Context, id string) (interface{}, error) { + if !isSigningKeyFeature(c.feature) { + return nil, ErrInvalidFeature + } + + seq, err := strconv.ParseInt(id, 10, 64) + if err != nil { + return nil, xerrors.Errorf("parse id: %w", err) + } + + _, secret, err := c.cryptoKey(ctx, int32(seq)) + if err != nil { + return nil, xerrors.Errorf("crypto key: %w", err) + } + + return secret, nil +} + +func isEncryptionKeyFeature(feature codersdk.CryptoKeyFeature) bool { + return feature == codersdk.CryptoKeyFeatureWorkspaceApp +} + +func isSigningKeyFeature(feature codersdk.CryptoKeyFeature) bool { + switch feature { + case codersdk.CryptoKeyFeatureTailnetResume, codersdk.CryptoKeyFeatureOIDCConvert: + return true + default: + return false + } +} + +func idSecret(k codersdk.CryptoKey) (string, []byte, error) { + key, err := hex.DecodeString(k.Secret) + if err != nil { + return "", nil, xerrors.Errorf("decode key: %w", err) + } + + return strconv.FormatInt(int64(k.Sequence), 10), key, nil +} + +func (c *cache) cryptoKey(ctx context.Context, sequence int32) (string, []byte, error) { + c.mu.Lock() + defer c.mu.Unlock() + + if c.closed { + return "", nil, ErrClosed + } + + var key codersdk.CryptoKey + var ok bool + for key, ok = c.key(sequence); !ok && c.fetching && !c.closed; { + c.cond.Wait() + } + + if c.closed { + return "", nil, ErrClosed + } + + if ok { + return checkKey(key, sequence, c.clock.Now()) + } + + c.fetching = true + c.mu.Unlock() + + keys, err := c.cryptoKeys(ctx) + if err != nil { + return "", nil, xerrors.Errorf("get keys: %w", err) + } + + c.mu.Lock() + c.lastFetch = c.clock.Now() + c.refresher.Reset(refreshInterval) + c.keys = keys + c.fetching = false + c.cond.Broadcast() + + key, ok = c.key(sequence) + if !ok { + return "", nil, ErrKeyNotFound + } + + return checkKey(key, sequence, c.clock.Now()) +} + +func (c *cache) key(sequence int32) (codersdk.CryptoKey, bool) { + if sequence == latestSequence { + return c.keys[latestSequence], c.keys[latestSequence].CanSign(c.clock.Now()) + } + + key, ok := c.keys[sequence] + return key, ok +} + +func checkKey(key codersdk.CryptoKey, sequence int32, now time.Time) (string, []byte, error) { + if sequence == latestSequence { + if !key.CanSign(now) { + return "", nil, ErrKeyInvalid + } + return idSecret(key) + } + + if !key.CanVerify(now) { + return "", nil, ErrKeyInvalid + } + + return idSecret(key) +} + +// refresh fetches the keys and updates the cache. +func (c *cache) refresh() { + now := c.clock.Now("CryptoKeyCache", "refresh") + c.mu.Lock() + defer c.mu.Unlock() + + if c.closed { + return + } + + // If something's already fetching, we don't need to do anything. + if c.fetching { + return + } + + // There's a window we must account for where the timer fires while a fetch + // is ongoing but prior to the timer getting reset. In this case we want to + // avoid double fetching. + if now.Sub(c.lastFetch) < refreshInterval { + return + } + + c.fetching = true + + c.mu.Unlock() + keys, err := c.cryptoKeys(c.refreshCtx) + if err != nil { + c.logger.Error(c.refreshCtx, "fetch crypto keys", slog.Error(err)) + return + } + + // We don't defer an unlock here due to the deferred unlock at the top of the function. + c.mu.Lock() + + c.lastFetch = c.clock.Now() + c.refresher.Reset(refreshInterval) + c.keys = keys + c.fetching = false + c.cond.Broadcast() +} + +// cryptoKeys queries the control plane for the crypto keys. +// Outside of initialization, this should only be called by fetch. +func (c *cache) cryptoKeys(ctx context.Context) (map[int32]codersdk.CryptoKey, error) { + keys, err := c.fetcher.Fetch(ctx) + if err != nil { + return nil, xerrors.Errorf("crypto keys: %w", err) + } + cache := toKeyMap(keys, c.clock.Now()) + return cache, nil +} + +func toKeyMap(keys []codersdk.CryptoKey, now time.Time) map[int32]codersdk.CryptoKey { + m := make(map[int32]codersdk.CryptoKey) + var latest codersdk.CryptoKey + for _, key := range keys { + m[key.Sequence] = key + if key.Sequence > latest.Sequence && key.CanSign(now) { + m[latestSequence] = key + } + } + return m +} + +func (c *cache) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + + if c.closed { + return nil + } + + c.closed = true + c.refreshCancel() + c.refresher.Stop() + c.cond.Broadcast() + + return nil +} diff --git a/enterprise/wsproxy/keycache_test.go b/coderd/cryptokeys/cache_test.go similarity index 60% rename from enterprise/wsproxy/keycache_test.go rename to coderd/cryptokeys/cache_test.go index 210e04f9edf76..92fc4527ae7b3 100644 --- a/enterprise/wsproxy/keycache_test.go +++ b/coderd/cryptokeys/cache_test.go @@ -1,21 +1,28 @@ -package wsproxy_test +package cryptokeys_test import ( "context" + "crypto/rand" + "encoding/hex" + "strconv" "testing" "time" "github.com/stretchr/testify/require" + "go.uber.org/goleak" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/coderd/cryptokeys" "github.com/coder/coder/v2/codersdk" - "github.com/coder/coder/v2/enterprise/wsproxy" "github.com/coder/coder/v2/testutil" "github.com/coder/quartz" ) +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} + func TestCryptoKeyCache(t *testing.T) { t.Parallel() @@ -32,8 +39,8 @@ func TestCryptoKeyCache(t *testing.T) { now := clock.Now().UTC() expected := codersdk.CryptoKey{ - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key2", + Feature: codersdk.CryptoKeyFeatureTailnetResume, + Secret: generateKey(t, 64), Sequence: 2, StartsAt: now, } @@ -42,12 +49,13 @@ func TestCryptoKeyCache(t *testing.T) { keys: []codersdk.CryptoKey{expected}, } - cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) + cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock)) require.NoError(t, err) - got, err := cache.Signing(ctx) + id, got, err := cache.SigningKey(ctx) require.NoError(t, err) - require.Equal(t, expected, got) + require.Equal(t, keyID(expected), id) + require.Equal(t, decodedSecret(t, expected), got) require.Equal(t, 1, ff.called) }) @@ -63,27 +71,29 @@ func TestCryptoKeyCache(t *testing.T) { keys: []codersdk.CryptoKey{}, } - cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) + cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock)) require.NoError(t, err) expected := codersdk.CryptoKey{ - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key1", + Feature: codersdk.CryptoKeyFeatureTailnetResume, + Secret: generateKey(t, 64), Sequence: 12, StartsAt: clock.Now().UTC(), } ff.keys = []codersdk.CryptoKey{expected} - got, err := cache.Signing(ctx) + id, got, err := cache.SigningKey(ctx) require.NoError(t, err) - require.Equal(t, expected, got) + require.Equal(t, decodedSecret(t, expected), got) + require.Equal(t, keyID(expected), id) // 1 on startup + missing cache. require.Equal(t, 2, ff.called) // Ensure the cache gets hit this time. - got, err = cache.Signing(ctx) + id, got, err = cache.SigningKey(ctx) require.NoError(t, err) - require.Equal(t, expected, got) + require.Equal(t, decodedSecret(t, expected), got) + require.Equal(t, keyID(expected), id) // 1 on startup + missing cache. require.Equal(t, 2, ff.called) }) @@ -97,9 +107,10 @@ func TestCryptoKeyCache(t *testing.T) { clock = quartz.NewMock(t) ) now := clock.Now().UTC() + expected := codersdk.CryptoKey{ - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key1", + Feature: codersdk.CryptoKeyFeatureTailnetResume, + Secret: generateKey(t, 64), Sequence: 1, StartsAt: clock.Now().UTC(), } @@ -108,8 +119,8 @@ func TestCryptoKeyCache(t *testing.T) { keys: []codersdk.CryptoKey{ expected, { - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key2", + Feature: codersdk.CryptoKeyFeatureTailnetResume, + Secret: generateKey(t, 64), Sequence: 2, StartsAt: now.Add(-time.Second), DeletesAt: now, @@ -117,12 +128,13 @@ func TestCryptoKeyCache(t *testing.T) { }, } - cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) + cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock)) require.NoError(t, err) - got, err := cache.Signing(ctx) + id, got, err := cache.SigningKey(ctx) require.NoError(t, err) - require.Equal(t, expected, got) + require.Equal(t, decodedSecret(t, expected), got) + require.Equal(t, keyID(expected), id) require.Equal(t, 1, ff.called) }) @@ -132,17 +144,16 @@ func TestCryptoKeyCache(t *testing.T) { var ( ctx = testutil.Context(t, testutil.WaitShort) logger = slogtest.Make(t, nil) - clock = quartz.NewMock(t) ) ff := &fakeFetcher{ keys: []codersdk.CryptoKey{}, } - cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) + cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume) require.NoError(t, err) - _, err = cache.Signing(ctx) + _, _, err = cache.SigningKey(ctx) require.ErrorIs(t, err, cryptokeys.ErrKeyNotFound) }) }) @@ -161,8 +172,8 @@ func TestCryptoKeyCache(t *testing.T) { now := clock.Now().UTC() expected := codersdk.CryptoKey{ - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key1", + Feature: codersdk.CryptoKeyFeatureTailnetResume, + Secret: generateKey(t, 64), Sequence: 12, StartsAt: now, } @@ -170,20 +181,20 @@ func TestCryptoKeyCache(t *testing.T) { keys: []codersdk.CryptoKey{ expected, { - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key2", + Feature: codersdk.CryptoKeyFeatureTailnetResume, + Secret: generateKey(t, 64), Sequence: 13, StartsAt: now, }, }, } - cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) + cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock)) require.NoError(t, err) - got, err := cache.Verifying(ctx, expected.Sequence) + got, err := cache.VerifyingKey(ctx, keyID(expected)) require.NoError(t, err) - require.Equal(t, expected, got) + require.Equal(t, decodedSecret(t, expected), got) require.Equal(t, 1, ff.called) }) @@ -199,26 +210,26 @@ func TestCryptoKeyCache(t *testing.T) { keys: []codersdk.CryptoKey{}, } - cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) + cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock)) require.NoError(t, err) expected := codersdk.CryptoKey{ - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key1", + Feature: codersdk.CryptoKeyFeatureTailnetResume, + Secret: generateKey(t, 64), Sequence: 12, StartsAt: clock.Now().UTC(), } ff.keys = []codersdk.CryptoKey{expected} - got, err := cache.Verifying(ctx, expected.Sequence) + got, err := cache.VerifyingKey(ctx, keyID(expected)) require.NoError(t, err) - require.Equal(t, expected, got) + require.Equal(t, decodedSecret(t, expected), got) require.Equal(t, 2, ff.called) // Ensure the cache gets hit this time. - got, err = cache.Verifying(ctx, expected.Sequence) + got, err = cache.VerifyingKey(ctx, keyID(expected)) require.NoError(t, err) - require.Equal(t, expected, got) + require.Equal(t, decodedSecret(t, expected), got) require.Equal(t, 2, ff.called) }) @@ -233,8 +244,8 @@ func TestCryptoKeyCache(t *testing.T) { now := clock.Now().UTC() expected := codersdk.CryptoKey{ - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key1", + Feature: codersdk.CryptoKeyFeatureTailnetResume, + Secret: generateKey(t, 64), Sequence: 12, StartsAt: now.Add(-time.Second), } @@ -245,16 +256,16 @@ func TestCryptoKeyCache(t *testing.T) { }, } - cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) + cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock)) require.NoError(t, err) - got, err := cache.Verifying(ctx, expected.Sequence) + got, err := cache.VerifyingKey(ctx, keyID(expected)) require.NoError(t, err) - require.Equal(t, expected, got) + require.Equal(t, decodedSecret(t, expected), got) require.Equal(t, 1, ff.called) }) - t.Run("KeyInvalid", func(t *testing.T) { + t.Run("KeyPastDeletesAt", func(t *testing.T) { t.Parallel() var ( @@ -265,8 +276,8 @@ func TestCryptoKeyCache(t *testing.T) { now := clock.Now().UTC() expected := codersdk.CryptoKey{ - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key1", + Feature: codersdk.CryptoKeyFeatureTailnetResume, + Secret: generateKey(t, 64), Sequence: 12, StartsAt: now.Add(-time.Second), DeletesAt: now, @@ -278,10 +289,10 @@ func TestCryptoKeyCache(t *testing.T) { }, } - cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) + cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock)) require.NoError(t, err) - _, err = cache.Verifying(ctx, expected.Sequence) + _, err = cache.VerifyingKey(ctx, keyID(expected)) require.ErrorIs(t, err, cryptokeys.ErrKeyInvalid) require.Equal(t, 1, ff.called) }) @@ -299,10 +310,10 @@ func TestCryptoKeyCache(t *testing.T) { keys: []codersdk.CryptoKey{}, } - cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) + cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock)) require.NoError(t, err) - _, err = cache.Verifying(ctx, 1) + _, err = cache.VerifyingKey(ctx, "1") require.ErrorIs(t, err, cryptokeys.ErrKeyNotFound) }) }) @@ -318,8 +329,8 @@ func TestCryptoKeyCache(t *testing.T) { now := clock.Now().UTC() expected := codersdk.CryptoKey{ - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key1", + Feature: codersdk.CryptoKeyFeatureTailnetResume, + Secret: generateKey(t, 64), Sequence: 12, StartsAt: now, DeletesAt: now.Add(time.Minute * 10), @@ -330,17 +341,18 @@ func TestCryptoKeyCache(t *testing.T) { }, } - cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) + cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock)) require.NoError(t, err) - got, err := cache.Signing(ctx) + id, got, err := cache.SigningKey(ctx) require.NoError(t, err) - require.Equal(t, expected, got) + require.Equal(t, decodedSecret(t, expected), got) + require.Equal(t, keyID(expected), id) require.Equal(t, 1, ff.called) newKey := codersdk.CryptoKey{ - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key2", + Feature: codersdk.CryptoKeyFeatureTailnetResume, + Secret: generateKey(t, 64), Sequence: 13, StartsAt: now, } @@ -353,9 +365,10 @@ func TestCryptoKeyCache(t *testing.T) { require.Equal(t, time.Minute*10, dur) // Assert hits cache. - got, err = cache.Signing(ctx) + id, got, err = cache.SigningKey(ctx) require.NoError(t, err) - require.Equal(t, newKey, got) + require.Equal(t, keyID(newKey), id) + require.Equal(t, decodedSecret(t, newKey), got) require.Equal(t, 2, ff.called) // We check again to ensure the timer has been reset. @@ -379,8 +392,8 @@ func TestCryptoKeyCache(t *testing.T) { now := clock.Now().UTC() expected := codersdk.CryptoKey{ - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key1", + Feature: codersdk.CryptoKeyFeatureTailnetResume, + Secret: generateKey(t, 64), Sequence: 12, StartsAt: now, DeletesAt: now.Add(time.Minute * 10), @@ -393,23 +406,24 @@ func TestCryptoKeyCache(t *testing.T) { // Create a trap that blocks when the refresh timer fires. trap := clock.Trap().Now("refresh") - cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) + cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock)) require.NoError(t, err) _, wait := clock.AdvanceNext() trapped := trap.MustWait(ctx) newKey := codersdk.CryptoKey{ - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key2", + Feature: codersdk.CryptoKeyFeatureTailnetResume, + Secret: generateKey(t, 64), Sequence: 13, StartsAt: now, } ff.keys = []codersdk.CryptoKey{newKey} - _, err = cache.Verifying(ctx, newKey.Sequence) + key, err := cache.VerifyingKey(ctx, keyID(newKey)) require.NoError(t, err) require.Equal(t, 2, ff.called) + require.Equal(t, decodedSecret(t, newKey), key) trapped.Release() wait.MustWait(ctx) @@ -434,8 +448,8 @@ func TestCryptoKeyCache(t *testing.T) { now := clock.Now() expected := codersdk.CryptoKey{ - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key1", + Feature: codersdk.CryptoKeyFeatureTailnetResume, + Secret: generateKey(t, 64), Sequence: 12, StartsAt: now, } @@ -445,25 +459,26 @@ func TestCryptoKeyCache(t *testing.T) { }, } - cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) + cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock)) require.NoError(t, err) - got, err := cache.Signing(ctx) + id, got, err := cache.SigningKey(ctx) require.NoError(t, err) - require.Equal(t, expected, got) + require.Equal(t, keyID(expected), id) + require.Equal(t, decodedSecret(t, expected), got) require.Equal(t, 1, ff.called) - got, err = cache.Verifying(ctx, expected.Sequence) + key, err := cache.VerifyingKey(ctx, keyID(expected)) require.NoError(t, err) - require.Equal(t, expected, got) + require.Equal(t, decodedSecret(t, expected), key) require.Equal(t, 1, ff.called) cache.Close() - _, err = cache.Signing(ctx) + _, _, err = cache.SigningKey(ctx) require.ErrorIs(t, err, cryptokeys.ErrClosed) - _, err = cache.Verifying(ctx, expected.Sequence) + _, err = cache.VerifyingKey(ctx, keyID(expected)) require.ErrorIs(t, err, cryptokeys.ErrClosed) }) } @@ -478,8 +493,25 @@ func (f *fakeFetcher) Fetch(_ context.Context) ([]codersdk.CryptoKey, error) { return f.keys, nil } -func withClock(clock quartz.Clock) func(*wsproxy.CryptoKeyCache) { - return func(cache *wsproxy.CryptoKeyCache) { - cache.Clock = clock - } +func keyID(key codersdk.CryptoKey) string { + return strconv.FormatInt(int64(key.Sequence), 10) +} + +func decodedSecret(t *testing.T, key codersdk.CryptoKey) []byte { + t.Helper() + + secret, err := hex.DecodeString(key.Secret) + require.NoError(t, err) + + return secret +} + +func generateKey(t *testing.T, size int) string { + t.Helper() + + key := make([]byte, size) + _, err := rand.Read(key) + require.NoError(t, err) + + return hex.EncodeToString(key) } diff --git a/coderd/cryptokeys/dbkeycache.go b/coderd/cryptokeys/dbkeycache.go deleted file mode 100644 index aa0a2444b35f2..0000000000000 --- a/coderd/cryptokeys/dbkeycache.go +++ /dev/null @@ -1,286 +0,0 @@ -package cryptokeys - -import ( - "context" - "strconv" - "sync" - "time" - - "golang.org/x/xerrors" - - "cdr.dev/slog" - "github.com/coder/coder/v2/coderd/database" - "github.com/coder/quartz" -) - -// never represents the maximum value for a time.Duration. -const never = 1<<63 - 1 - -// dbCache implements Keycache for callers with access to the database. -type dbCache struct { - db database.Store - feature database.CryptoKeyFeature - logger slog.Logger - clock quartz.Clock - - // The following are initialized by NewDBCache. - keysMu sync.RWMutex - keys map[int32]database.CryptoKey - latestKey database.CryptoKey - timer *quartz.Timer - // invalidateAt is the time at which the keys cache should be invalidated. - invalidateAt time.Time - closed bool -} - -type DBCacheOption func(*dbCache) - -func WithDBCacheClock(clock quartz.Clock) DBCacheOption { - return func(d *dbCache) { - d.clock = clock - } -} - -// NewSigningCache creates a new DBCache. Close should be called to -// release resources associated with its internal timer. -func NewSigningCache(logger slog.Logger, db database.Store, feature database.CryptoKeyFeature, opts ...func(*dbCache)) (SigningKeycache, error) { - if !isSigningKeyFeature(feature) { - return nil, ErrInvalidFeature - } - - return newDBCache(logger, db, feature, opts...), nil -} - -func NewEncryptionCache(logger slog.Logger, db database.Store, feature database.CryptoKeyFeature, opts ...func(*dbCache)) (EncryptionKeycache, error) { - if !isEncryptionKeyFeature(feature) { - return nil, ErrInvalidFeature - } - - return newDBCache(logger, db, feature, opts...), nil -} - -func newDBCache(logger slog.Logger, db database.Store, feature database.CryptoKeyFeature, opts ...func(*dbCache)) *dbCache { - d := &dbCache{ - db: db, - feature: feature, - clock: quartz.NewReal(), - logger: logger, - } - - for _, opt := range opts { - opt(d) - } - - // Initialize the timer. This will get properly initialized the first time we fetch. - d.timer = d.clock.AfterFunc(never, d.clear) - - return d -} - -func (d *dbCache) EncryptingKey(ctx context.Context) (id string, key interface{}, err error) { - if !isEncryptionKeyFeature(d.feature) { - return "", nil, ErrInvalidFeature - } - - return d.latest(ctx) -} - -func (d *dbCache) DecryptingKey(ctx context.Context, id string) (key interface{}, err error) { - if !isEncryptionKeyFeature(d.feature) { - return nil, ErrInvalidFeature - } - - return d.sequence(ctx, id) -} - -func (d *dbCache) SigningKey(ctx context.Context) (id string, key interface{}, err error) { - if !isSigningKeyFeature(d.feature) { - return "", nil, ErrInvalidFeature - } - - return d.latest(ctx) -} - -func (d *dbCache) VerifyingKey(ctx context.Context, id string) (key interface{}, err error) { - if !isSigningKeyFeature(d.feature) { - return nil, ErrInvalidFeature - } - - return d.sequence(ctx, id) -} - -// sequence returns the CryptoKey with the given sequence number, provided that -// it is neither deleted nor has breached its deletion date. It should only be -// used for verifying or decrypting payloads. To sign/encrypt call Signing. -func (d *dbCache) sequence(ctx context.Context, id string) (interface{}, error) { - sequence, err := strconv.ParseInt(id, 10, 32) - if err != nil { - return nil, xerrors.Errorf("expecting sequence number got %q: %w", id, err) - } - - d.keysMu.RLock() - if d.closed { - d.keysMu.RUnlock() - return nil, ErrClosed - } - - now := d.clock.Now() - key, ok := d.keys[int32(sequence)] - d.keysMu.RUnlock() - if ok { - return checkKey(key, now) - } - - d.keysMu.Lock() - defer d.keysMu.Unlock() - - if d.closed { - return nil, ErrClosed - } - - key, ok = d.keys[int32(sequence)] - if ok { - return checkKey(key, now) - } - - err = d.fetch(ctx) - if err != nil { - return nil, xerrors.Errorf("fetch: %w", err) - } - - key, ok = d.keys[int32(sequence)] - if !ok { - return nil, ErrKeyNotFound - } - - return checkKey(key, now) -} - -// latest returns the latest valid key for signing. A valid key is one that is -// both past its start time and before its deletion time. -func (d *dbCache) latest(ctx context.Context) (string, interface{}, error) { - d.keysMu.RLock() - - if d.closed { - d.keysMu.RUnlock() - return "", nil, ErrClosed - } - - latest := d.latestKey - d.keysMu.RUnlock() - - now := d.clock.Now() - if latest.CanSign(now) { - return idSecret(latest) - } - - d.keysMu.Lock() - defer d.keysMu.Unlock() - - if d.closed { - return "", nil, ErrClosed - } - - if d.latestKey.CanSign(now) { - return idSecret(d.latestKey) - } - - // Refetch all keys for this feature so we can find the latest valid key. - err := d.fetch(ctx) - if err != nil { - return "", nil, xerrors.Errorf("fetch: %w", err) - } - - return idSecret(d.latestKey) -} - -// clear invalidates the cache. This forces the subsequent call to fetch fresh keys. -func (d *dbCache) clear() { - now := d.clock.Now("DBCache", "clear") - d.keysMu.Lock() - defer d.keysMu.Unlock() - // Check if we raced with a fetch. It's possible that the timer fired and we - // lost the race to the mutex. We want to avoid invalidating - // a cache that was just refetched. - if now.Before(d.invalidateAt) { - return - } - d.keys = nil - d.latestKey = database.CryptoKey{} -} - -// fetch fetches all keys for the given feature and determines the latest key. -// It must be called while holding the keysMu lock. -func (d *dbCache) fetch(ctx context.Context) error { - keys, err := d.db.GetCryptoKeysByFeature(ctx, d.feature) - if err != nil { - return xerrors.Errorf("get crypto keys by feature: %w", err) - } - - now := d.clock.Now() - _ = d.timer.Reset(time.Minute * 10) - d.invalidateAt = now.Add(time.Minute * 10) - - cache := make(map[int32]database.CryptoKey) - var latest database.CryptoKey - for _, key := range keys { - cache[key.Sequence] = key - if key.CanSign(now) && key.Sequence > latest.Sequence { - latest = key - } - } - - if len(cache) == 0 { - return ErrKeyNotFound - } - - if !latest.CanSign(now) { - return ErrKeyInvalid - } - - d.keys, d.latestKey = cache, latest - return nil -} - -func checkKey(key database.CryptoKey, now time.Time) (interface{}, error) { - if !key.CanVerify(now) { - return nil, ErrKeyInvalid - } - - return key.DecodeString() -} - -func (d *dbCache) Close() error { - d.keysMu.Lock() - defer d.keysMu.Unlock() - - if d.closed { - return nil - } - - d.timer.Stop() - d.closed = true - return nil -} - -func isEncryptionKeyFeature(feature database.CryptoKeyFeature) bool { - return feature == database.CryptoKeyFeatureWorkspaceApps -} - -func isSigningKeyFeature(feature database.CryptoKeyFeature) bool { - switch feature { - case database.CryptoKeyFeatureTailnetResume, database.CryptoKeyFeatureOidcConvert: - return true - default: - return false - } -} - -func idSecret(k database.CryptoKey) (string, interface{}, error) { - key, err := k.DecodeString() - if err != nil { - return "", nil, xerrors.Errorf("decode key: %w", err) - } - - return strconv.FormatInt(int64(k.Sequence), 10), key, nil -} diff --git a/coderd/cryptokeys/dbkeycache_internal_test.go b/coderd/cryptokeys/dbkeycache_internal_test.go deleted file mode 100644 index c27bc5b8468ad..0000000000000 --- a/coderd/cryptokeys/dbkeycache_internal_test.go +++ /dev/null @@ -1,490 +0,0 @@ -package cryptokeys - -import ( - "database/sql" - "strconv" - "testing" - "time" - - "github.com/stretchr/testify/require" - "go.uber.org/mock/gomock" - - "cdr.dev/slog/sloggers/slogtest" - - "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/dbmock" - "github.com/coder/coder/v2/testutil" - "github.com/coder/quartz" -) - -func Test_version(t *testing.T) { - t.Parallel() - - t.Run("HitsCache", func(t *testing.T) { - t.Parallel() - - var ( - ctrl = gomock.NewController(t) - mockDB = dbmock.NewMockStore(ctrl) - clock = quartz.NewMock(t) - logger = slogtest.Make(t, nil) - ctx = testutil.Context(t, testutil.WaitShort) - ) - - expectedKey := database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, - Sequence: 32, - Secret: sql.NullString{ - String: mustGenerateKey(t), - Valid: true, - }, - } - - cache := map[int32]database.CryptoKey{ - 32: expectedKey, - } - - k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) - defer k.Close() - k.keys = cache - - secret, err := k.sequence(ctx, keyID(expectedKey)) - require.NoError(t, err) - require.Equal(t, decodedSecret(t, expectedKey), secret) - }) - - t.Run("MissesCache", func(t *testing.T) { - t.Parallel() - - var ( - ctrl = gomock.NewController(t) - mockDB = dbmock.NewMockStore(ctrl) - clock = quartz.NewMock(t) - ctx = testutil.Context(t, testutil.WaitShort) - logger = slogtest.Make(t, nil) - ) - - expectedKey := database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, - Sequence: 33, - StartsAt: clock.Now(), - Secret: sql.NullString{ - String: mustGenerateKey(t), - Valid: true, - }, - } - - mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{expectedKey}, nil) - - k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) - defer k.Close() - - got, err := k.sequence(ctx, keyID(expectedKey)) - require.NoError(t, err) - require.Equal(t, decodedSecret(t, expectedKey), got) - }) - - t.Run("InvalidCachedKey", func(t *testing.T) { - t.Parallel() - - var ( - ctrl = gomock.NewController(t) - mockDB = dbmock.NewMockStore(ctrl) - clock = quartz.NewMock(t) - ctx = testutil.Context(t, testutil.WaitShort) - logger = slogtest.Make(t, nil) - ) - - cache := map[int32]database.CryptoKey{ - 32: { - Feature: database.CryptoKeyFeatureWorkspaceApps, - Sequence: 32, - Secret: sql.NullString{ - String: mustGenerateKey(t), - Valid: true, - }, - DeletesAt: sql.NullTime{ - Time: clock.Now(), - Valid: true, - }, - }, - } - - k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) - defer k.Close() - k.keys = cache - - _, err := k.sequence(ctx, "32") - require.ErrorIs(t, err, ErrKeyInvalid) - }) - - t.Run("InvalidDBKey", func(t *testing.T) { - t.Parallel() - - var ( - ctrl = gomock.NewController(t) - mockDB = dbmock.NewMockStore(ctrl) - clock = quartz.NewMock(t) - ctx = testutil.Context(t, testutil.WaitShort) - logger = slogtest.Make(t, nil) - ) - - invalidKey := database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, - Sequence: 32, - Secret: sql.NullString{ - String: mustGenerateKey(t), - Valid: true, - }, - DeletesAt: sql.NullTime{ - Time: clock.Now(), - Valid: true, - }, - } - mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{invalidKey}, nil) - - k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) - defer k.Close() - - _, err := k.sequence(ctx, keyID(invalidKey)) - require.ErrorIs(t, err, ErrKeyInvalid) - }) -} - -func Test_latest(t *testing.T) { - t.Parallel() - - t.Run("HitsCache", func(t *testing.T) { - t.Parallel() - - var ( - ctrl = gomock.NewController(t) - mockDB = dbmock.NewMockStore(ctrl) - clock = quartz.NewMock(t) - ctx = testutil.Context(t, testutil.WaitShort) - logger = slogtest.Make(t, nil) - ) - - latestKey := database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, - Sequence: 32, - Secret: sql.NullString{ - String: mustGenerateKey(t), - Valid: true, - }, - StartsAt: clock.Now(), - } - k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) - defer k.Close() - - k.latestKey = latestKey - - id, secret, err := k.latest(ctx) - require.NoError(t, err) - require.Equal(t, keyID(latestKey), id) - require.Equal(t, decodedSecret(t, latestKey), secret) - }) - - t.Run("InvalidCachedKey", func(t *testing.T) { - t.Parallel() - - var ( - ctrl = gomock.NewController(t) - mockDB = dbmock.NewMockStore(ctrl) - clock = quartz.NewMock(t) - ctx = testutil.Context(t, testutil.WaitShort) - logger = slogtest.Make(t, nil) - ) - - latestKey := database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, - Sequence: 33, - Secret: sql.NullString{ - String: mustGenerateKey(t), - Valid: true, - }, - StartsAt: clock.Now(), - } - - invalidKey := database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, - Sequence: 32, - Secret: sql.NullString{ - String: mustGenerateKey(t), - Valid: true, - }, - StartsAt: clock.Now().Add(-time.Hour), - DeletesAt: sql.NullTime{ - Time: clock.Now(), - Valid: true, - }, - } - - mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{latestKey}, nil) - - k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) - defer k.Close() - k.latestKey = invalidKey - - id, secret, err := k.latest(ctx) - require.NoError(t, err) - require.Equal(t, keyID(latestKey), id) - require.Equal(t, decodedSecret(t, latestKey), secret) - }) - - t.Run("UsesActiveKey", func(t *testing.T) { - t.Parallel() - - var ( - ctrl = gomock.NewController(t) - mockDB = dbmock.NewMockStore(ctrl) - clock = quartz.NewMock(t) - ctx = testutil.Context(t, testutil.WaitShort) - logger = slogtest.Make(t, nil) - ) - - inactiveKey := database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, - Sequence: 32, - Secret: sql.NullString{ - String: mustGenerateKey(t), - Valid: true, - }, - StartsAt: clock.Now().Add(time.Hour), - } - - activeKey := database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, - Sequence: 33, - Secret: sql.NullString{ - String: mustGenerateKey(t), - Valid: true, - }, - StartsAt: clock.Now(), - } - - mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{inactiveKey, activeKey}, nil) - - k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) - defer k.Close() - - id, secret, err := k.latest(ctx) - require.NoError(t, err) - require.Equal(t, keyID(activeKey), id) - require.Equal(t, decodedSecret(t, activeKey), secret) - }) - - t.Run("NoValidKeys", func(t *testing.T) { - t.Parallel() - - var ( - ctrl = gomock.NewController(t) - mockDB = dbmock.NewMockStore(ctrl) - clock = quartz.NewMock(t) - ctx = testutil.Context(t, testutil.WaitShort) - logger = slogtest.Make(t, nil) - ) - - inactiveKey := database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, - Sequence: 32, - Secret: sql.NullString{ - String: mustGenerateKey(t), - Valid: true, - }, - StartsAt: clock.Now().Add(time.Hour), - } - - invalidKey := database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, - Sequence: 33, - Secret: sql.NullString{ - String: mustGenerateKey(t), - Valid: true, - }, - StartsAt: clock.Now().Add(-time.Hour), - DeletesAt: sql.NullTime{ - Time: clock.Now(), - Valid: true, - }, - } - - mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{inactiveKey, invalidKey}, nil) - - k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) - defer k.Close() - - _, _, err := k.latest(ctx) - require.ErrorIs(t, err, ErrKeyInvalid) - }) -} - -func Test_clear(t *testing.T) { - t.Parallel() - - t.Run("InvalidatesCache", func(t *testing.T) { - t.Parallel() - - var ( - ctrl = gomock.NewController(t) - mockDB = dbmock.NewMockStore(ctrl) - clock = quartz.NewMock(t) - ctx = testutil.Context(t, testutil.WaitShort) - logger = slogtest.Make(t, nil) - ) - - k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) - defer k.Close() - - activeKey := database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, - Sequence: 33, - Secret: sql.NullString{ - String: mustGenerateKey(t), - Valid: true, - }, - StartsAt: clock.Now(), - } - - mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{activeKey}, nil) - - _, _, err := k.latest(ctx) - require.NoError(t, err) - - dur, wait := clock.AdvanceNext() - wait.MustWait(ctx) - require.Equal(t, time.Minute*10, dur) - require.Len(t, k.keys, 0) - require.Equal(t, database.CryptoKey{}, k.latestKey) - }) - - t.Run("ResetsTimer", func(t *testing.T) { - t.Parallel() - - var ( - ctrl = gomock.NewController(t) - mockDB = dbmock.NewMockStore(ctrl) - clock = quartz.NewMock(t) - ctx = testutil.Context(t, testutil.WaitShort) - logger = slogtest.Make(t, nil) - ) - - k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) - defer k.Close() - - key := database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, - Sequence: 32, - Secret: sql.NullString{ - String: mustGenerateKey(t), - Valid: true, - }, - StartsAt: clock.Now(), - } - - mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{key}, nil) - - // Advance it five minutes so that we can test that the - // timer is reset and doesn't fire after another five minute. - clock.Advance(time.Minute * 5) - - id, secret, err := k.latest(ctx) - require.NoError(t, err) - require.Equal(t, keyID(key), id) - require.Equal(t, decodedSecret(t, key), secret) - - // Advancing the clock now should require 10 minutes - // before the timer fires again. - dur, wait := clock.AdvanceNext() - wait.MustWait(ctx) - require.Equal(t, time.Minute*10, dur) - require.Len(t, k.keys, 0) - require.Equal(t, database.CryptoKey{}, k.latestKey) - }) - - // InvalidateAt tests that we have accounted for the race condition where a - // timer fires to invalidate the cache at the same time we are fetching new - // keys. In such cases we want to skip invalidation. - t.Run("InvalidateAt", func(t *testing.T) { - t.Parallel() - - var ( - ctrl = gomock.NewController(t) - mockDB = dbmock.NewMockStore(ctrl) - clock = quartz.NewMock(t) - ctx = testutil.Context(t, testutil.WaitShort) - logger = slogtest.Make(t, nil) - ) - - trap := clock.Trap().Now("clear") - - k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) - defer k.Close() - - key := database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, - Sequence: 32, - Secret: sql.NullString{ - String: mustGenerateKey(t), - Valid: true, - }, - StartsAt: clock.Now(), - } - - mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{key}, nil).Times(2) - - // Move us past the initial timer. - id, secret, err := k.latest(ctx) - require.NoError(t, err) - require.Equal(t, keyID(key), id) - require.Equal(t, decodedSecret(t, key), secret) - // Null these out so that we refetch. - k.keys = nil - k.latestKey = database.CryptoKey{} - - // Initiate firing the timer. - dur, wait := clock.AdvanceNext() - require.Equal(t, time.Minute*10, dur) - // Trap the function just before acquiring the mutex. - call := trap.MustWait(ctx) - - // Refetch keys. - id, secret, err = k.latest(ctx) - require.NoError(t, err) - require.Equal(t, keyID(key), id) - require.Equal(t, decodedSecret(t, key), secret) - - // Let the rest of the timer function run. - // It should see that we have refetched keys and - // not invalidate. - call.Release() - wait.MustWait(ctx) - require.Len(t, k.keys, 1) - require.Equal(t, key, k.latestKey) - trap.Close() - - // Refetching the keys should've instantiated a new timer. This one should invalidate keys. - _, wait = clock.AdvanceNext() - wait.MustWait(ctx) - require.Len(t, k.keys, 0) - require.Equal(t, database.CryptoKey{}, k.latestKey) - }) -} - -func mustGenerateKey(t *testing.T) string { - t.Helper() - key, err := generateKey(64) - require.NoError(t, err) - return key -} - -func keyID(key database.CryptoKey) string { - return strconv.FormatInt(int64(key.Sequence), 10) -} - -func decodedSecret(t *testing.T, key database.CryptoKey) []byte { - t.Helper() - decoded, err := key.DecodeString() - require.NoError(t, err) - return decoded -} diff --git a/coderd/cryptokeys/dbkeycache_test.go b/coderd/cryptokeys/dbkeycache_test.go deleted file mode 100644 index e24ef16660db1..0000000000000 --- a/coderd/cryptokeys/dbkeycache_test.go +++ /dev/null @@ -1,216 +0,0 @@ -package cryptokeys_test - -import ( - "strconv" - "testing" - - "github.com/stretchr/testify/require" - "go.uber.org/goleak" - - "cdr.dev/slog/sloggers/slogtest" - - "github.com/coder/coder/v2/coderd/cryptokeys" - "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/dbgen" - "github.com/coder/coder/v2/coderd/database/dbtestutil" - "github.com/coder/coder/v2/testutil" - "github.com/coder/quartz" -) - -func TestMain(m *testing.M) { - goleak.VerifyTestMain(m) -} - -func TestDBKeyCache(t *testing.T) { - t.Parallel() - - t.Run("VerifyingKey", func(t *testing.T) { - t.Parallel() - - t.Run("HitsCache", func(t *testing.T) { - t.Parallel() - - var ( - db, _ = dbtestutil.NewDB(t) - clock = quartz.NewMock(t) - ctx = testutil.Context(t, testutil.WaitShort) - logger = slogtest.Make(t, nil) - ) - - key := dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureOidcConvert, - Sequence: 1, - StartsAt: clock.Now().UTC(), - }) - - k, err := cryptokeys.NewSigningCache(logger, db, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock)) - require.NoError(t, err) - defer k.Close() - - got, err := k.VerifyingKey(ctx, keyID(key)) - require.NoError(t, err) - require.Equal(t, decodedSecret(t, key), got) - }) - - t.Run("NotFound", func(t *testing.T) { - t.Parallel() - - var ( - db, _ = dbtestutil.NewDB(t) - clock = quartz.NewMock(t) - ctx = testutil.Context(t, testutil.WaitShort) - logger = slogtest.Make(t, nil) - ) - - k, err := cryptokeys.NewSigningCache(logger, db, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock)) - require.NoError(t, err) - defer k.Close() - - _, err = k.VerifyingKey(ctx, "123") - require.ErrorIs(t, err, cryptokeys.ErrKeyNotFound) - }) - }) - - t.Run("Signing", func(t *testing.T) { - t.Parallel() - - var ( - db, _ = dbtestutil.NewDB(t) - clock = quartz.NewMock(t) - ctx = testutil.Context(t, testutil.WaitShort) - logger = slogtest.Make(t, nil) - ) - - _ = dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureOidcConvert, - Sequence: 10, - StartsAt: clock.Now().UTC(), - }) - - expectedKey := dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureOidcConvert, - Sequence: 12, - StartsAt: clock.Now().UTC(), - }) - - _ = dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureOidcConvert, - Sequence: 2, - StartsAt: clock.Now().UTC(), - }) - - k, err := cryptokeys.NewSigningCache(logger, db, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock)) - require.NoError(t, err) - defer k.Close() - - id, key, err := k.SigningKey(ctx) - require.NoError(t, err) - require.Equal(t, keyID(expectedKey), id) - require.Equal(t, decodedSecret(t, expectedKey), key) - }) - - t.Run("Closed", func(t *testing.T) { - t.Parallel() - - var ( - db, _ = dbtestutil.NewDB(t) - clock = quartz.NewMock(t) - ctx = testutil.Context(t, testutil.WaitShort) - logger = slogtest.Make(t, nil) - ) - - expectedKey := dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureOidcConvert, - Sequence: 10, - StartsAt: clock.Now(), - }) - - k, err := cryptokeys.NewSigningCache(logger, db, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock)) - require.NoError(t, err) - defer k.Close() - - id, key, err := k.SigningKey(ctx) - require.NoError(t, err) - require.Equal(t, keyID(expectedKey), id) - require.Equal(t, decodedSecret(t, expectedKey), key) - - key, err = k.VerifyingKey(ctx, keyID(expectedKey)) - require.NoError(t, err) - require.Equal(t, decodedSecret(t, expectedKey), key) - - k.Close() - - _, _, err = k.SigningKey(ctx) - require.ErrorIs(t, err, cryptokeys.ErrClosed) - - _, err = k.VerifyingKey(ctx, keyID(expectedKey)) - require.ErrorIs(t, err, cryptokeys.ErrClosed) - }) - - t.Run("InvalidSigningFeature", func(t *testing.T) { - t.Parallel() - - var ( - db, _ = dbtestutil.NewDB(t) - clock = quartz.NewMock(t) - logger = slogtest.Make(t, nil) - ctx = testutil.Context(t, testutil.WaitShort) - ) - - _, err := cryptokeys.NewSigningCache(logger, db, database.CryptoKeyFeatureWorkspaceApps, cryptokeys.WithDBCacheClock(clock)) - require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature) - - // Instantiate a signing cache and try to use it as an encryption cache. - sc, err := cryptokeys.NewSigningCache(logger, db, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock)) - require.NoError(t, err) - defer sc.Close() - - ec, ok := sc.(cryptokeys.EncryptionKeycache) - require.True(t, ok) - _, _, err = ec.EncryptingKey(ctx) - require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature) - - _, err = ec.DecryptingKey(ctx, "123") - require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature) - }) - - t.Run("InvalidEncryptionFeature", func(t *testing.T) { - t.Parallel() - - var ( - db, _ = dbtestutil.NewDB(t) - clock = quartz.NewMock(t) - logger = slogtest.Make(t, nil) - ctx = testutil.Context(t, testutil.WaitShort) - ) - - _, err := cryptokeys.NewEncryptionCache(logger, db, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock)) - require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature) - - // Instantiate an encryption cache and try to use it as a signing cache. - ec, err := cryptokeys.NewEncryptionCache(logger, db, database.CryptoKeyFeatureWorkspaceApps, cryptokeys.WithDBCacheClock(clock)) - require.NoError(t, err) - defer ec.Close() - - sc, ok := ec.(cryptokeys.SigningKeycache) - require.True(t, ok) - _, _, err = sc.SigningKey(ctx) - require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature) - - _, err = sc.VerifyingKey(ctx, "123") - require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature) - }) -} - -func keyID(key database.CryptoKey) string { - return strconv.FormatInt(int64(key.Sequence), 10) -} - -func decodedSecret(t *testing.T, key database.CryptoKey) []byte { - t.Helper() - - secret, err := key.DecodeString() - require.NoError(t, err) - - return secret -} diff --git a/coderd/cryptokeys/keycache.go b/coderd/cryptokeys/keycache.go deleted file mode 100644 index 05c80a15b2378..0000000000000 --- a/coderd/cryptokeys/keycache.go +++ /dev/null @@ -1,41 +0,0 @@ -package cryptokeys - -import ( - "context" - "io" - - "golang.org/x/xerrors" -) - -var ( - ErrKeyNotFound = xerrors.New("key not found") - ErrKeyInvalid = xerrors.New("key is invalid for use") - ErrClosed = xerrors.New("closed") - ErrInvalidFeature = xerrors.New("invalid feature for this operation") -) - -type EncryptionKeycache interface { - // EncryptingKey returns the latest valid key for encrypting payloads. A valid - // key is one that is both past its start time and before its deletion time. - EncryptingKey(ctx context.Context) (id string, key interface{}, err error) - // DecryptingKey returns the key with the provided id which maps to its sequence - // number. The key is valid for decryption as long as it is not deleted or past - // its deletion date. We must allow for keys prior to their start time to - // account for clock skew between peers (one key may be past its start time on - // one machine while another is not). - DecryptingKey(ctx context.Context, id string) (key interface{}, err error) - io.Closer -} - -type SigningKeycache interface { - // SigningKey returns the latest valid key for signing. A valid key is one - // that is both past its start time and before its deletion time. - SigningKey(ctx context.Context) (id string, key interface{}, err error) - // VerifyingKey returns the key with the provided id which should map to its - // sequence number. The key is valid for verifying as long as it is not deleted - // or past its deletion date. We must allow for keys prior to their start time - // to account for clock skew between peers (one key may be past its start time - // on one machine while another is not). - VerifyingKey(ctx context.Context, id string) (key interface{}, err error) - io.Closer -} diff --git a/coderd/jwtutils/jwe.go b/coderd/jwtutils/jwe.go index f50cacb62de7c..d03816a477a26 100644 --- a/coderd/jwtutils/jwe.go +++ b/coderd/jwtutils/jwe.go @@ -27,7 +27,7 @@ type DecryptKeyProvider interface { func Encrypt(ctx context.Context, e EncryptKeyProvider, claims Claims) (string, error) { id, key, err := e.EncryptingKey(ctx) if err != nil { - return "", xerrors.Errorf("get signing key: %w", err) + return "", xerrors.Errorf("encrypting key: %w", err) } encrypter, err := jose.NewEncrypter( diff --git a/coderd/jwtutils/jwt_test.go b/coderd/jwtutils/jwt_test.go index ff30f7716b310..697e5d210d858 100644 --- a/coderd/jwtutils/jwt_test.go +++ b/coderd/jwtutils/jwt_test.go @@ -18,6 +18,7 @@ import ( "github.com/coder/coder/v2/coderd/database/dbgen" "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/jwtutils" + "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/testutil" ) @@ -238,10 +239,11 @@ func TestJWS(t *testing.T) { Feature: database.CryptoKeyFeatureOidcConvert, StartsAt: time.Now(), }) - log = slogtest.Make(t, nil) + log = slogtest.Make(t, nil) + fetcher = &cryptokeys.DBFetcher{DB: db, Feature: database.CryptoKeyFeatureOidcConvert} ) - cache, err := cryptokeys.NewSigningCache(log, db, database.CryptoKeyFeatureOidcConvert) + cache, err := cryptokeys.NewSigningCache(ctx, log, fetcher, codersdk.CryptoKeyFeatureOIDCConvert) require.NoError(t, err) claims := testClaims{ @@ -328,9 +330,11 @@ func TestJWE(t *testing.T) { StartsAt: time.Now(), }) log = slogtest.Make(t, nil) + + fetcher = &cryptokeys.DBFetcher{DB: db, Feature: database.CryptoKeyFeatureWorkspaceApps} ) - cache, err := cryptokeys.NewEncryptionCache(log, db, database.CryptoKeyFeatureWorkspaceApps) + cache, err := cryptokeys.NewEncryptionCache(ctx, log, fetcher, codersdk.CryptoKeyFeatureWorkspaceApp) require.NoError(t, err) claims := testClaims{ diff --git a/enterprise/wsproxy/keycache.go b/enterprise/wsproxy/keycache.go deleted file mode 100644 index a877b9757d250..0000000000000 --- a/enterprise/wsproxy/keycache.go +++ /dev/null @@ -1,224 +0,0 @@ -package wsproxy - -import ( - "context" - "sync" - "time" - - "golang.org/x/xerrors" - - "cdr.dev/slog" - - "github.com/coder/coder/v2/coderd/cryptokeys" - "github.com/coder/coder/v2/codersdk" - "github.com/coder/quartz" -) - -const ( - // latestSequence is a special sequence number that represents the latest key. - latestSequence = -1 - // refreshInterval is the interval at which the key cache will refresh. - refreshInterval = time.Minute * 10 -) - -type Fetcher interface { - Fetch(ctx context.Context) ([]codersdk.CryptoKey, error) -} - -type CryptoKeyCache struct { - Clock quartz.Clock - refreshCtx context.Context - refreshCancel context.CancelFunc - fetcher Fetcher - logger slog.Logger - - mu sync.Mutex - keys map[int32]codersdk.CryptoKey - lastFetch time.Time - refresher *quartz.Timer - fetching bool - closed bool - cond *sync.Cond -} - -func NewCryptoKeyCache(ctx context.Context, log slog.Logger, client Fetcher, opts ...func(*CryptoKeyCache)) (*CryptoKeyCache, error) { - cache := &CryptoKeyCache{ - Clock: quartz.NewReal(), - logger: log, - fetcher: client, - } - - for _, opt := range opts { - opt(cache) - } - - cache.cond = sync.NewCond(&cache.mu) - cache.refreshCtx, cache.refreshCancel = context.WithCancel(ctx) - cache.refresher = cache.Clock.AfterFunc(refreshInterval, cache.refresh) - - keys, err := cache.cryptoKeys(ctx) - if err != nil { - cache.refreshCancel() - return nil, xerrors.Errorf("initial fetch: %w", err) - } - cache.keys = keys - - return cache, nil -} - -func (k *CryptoKeyCache) Signing(ctx context.Context) (codersdk.CryptoKey, error) { - return k.cryptoKey(ctx, latestSequence) -} - -func (k *CryptoKeyCache) Verifying(ctx context.Context, sequence int32) (codersdk.CryptoKey, error) { - return k.cryptoKey(ctx, sequence) -} - -func (k *CryptoKeyCache) cryptoKey(ctx context.Context, sequence int32) (codersdk.CryptoKey, error) { - k.mu.Lock() - defer k.mu.Unlock() - - if k.closed { - return codersdk.CryptoKey{}, cryptokeys.ErrClosed - } - - var key codersdk.CryptoKey - var ok bool - for key, ok = k.key(sequence); !ok && k.fetching && !k.closed; { - k.cond.Wait() - } - - if k.closed { - return codersdk.CryptoKey{}, cryptokeys.ErrClosed - } - - if ok { - return checkKey(key, sequence, k.Clock.Now()) - } - - k.fetching = true - k.mu.Unlock() - - keys, err := k.cryptoKeys(ctx) - if err != nil { - return codersdk.CryptoKey{}, xerrors.Errorf("get keys: %w", err) - } - - k.mu.Lock() - k.lastFetch = k.Clock.Now() - k.refresher.Reset(refreshInterval) - k.keys = keys - k.fetching = false - k.cond.Broadcast() - - key, ok = k.key(sequence) - if !ok { - return codersdk.CryptoKey{}, cryptokeys.ErrKeyNotFound - } - - return checkKey(key, sequence, k.Clock.Now()) -} - -func (k *CryptoKeyCache) key(sequence int32) (codersdk.CryptoKey, bool) { - if sequence == latestSequence { - return k.keys[latestSequence], k.keys[latestSequence].CanSign(k.Clock.Now()) - } - - key, ok := k.keys[sequence] - return key, ok -} - -func checkKey(key codersdk.CryptoKey, sequence int32, now time.Time) (codersdk.CryptoKey, error) { - if sequence == latestSequence { - if !key.CanSign(now) { - return codersdk.CryptoKey{}, cryptokeys.ErrKeyInvalid - } - return key, nil - } - - if !key.CanVerify(now) { - return codersdk.CryptoKey{}, cryptokeys.ErrKeyInvalid - } - - return key, nil -} - -// refresh fetches the keys from the control plane and updates the cache. -func (k *CryptoKeyCache) refresh() { - now := k.Clock.Now("CryptoKeyCache", "refresh") - k.mu.Lock() - - if k.closed { - k.mu.Unlock() - return - } - - // If something's already fetching, we don't need to do anything. - if k.fetching { - k.mu.Unlock() - return - } - - // There's a window we must account for where the timer fires while a fetch - // is ongoing but prior to the timer getting reset. In this case we want to - // avoid double fetching. - if now.Sub(k.lastFetch) < refreshInterval { - k.mu.Unlock() - return - } - - k.fetching = true - - k.mu.Unlock() - keys, err := k.cryptoKeys(k.refreshCtx) - if err != nil { - k.logger.Error(k.refreshCtx, "fetch crypto keys", slog.Error(err)) - return - } - - k.mu.Lock() - defer k.mu.Unlock() - - k.lastFetch = k.Clock.Now() - k.refresher.Reset(refreshInterval) - k.keys = keys - k.fetching = false - k.cond.Broadcast() -} - -// cryptoKeys queries the control plane for the crypto keys. -// Outside of initialization, this should only be called by fetch. -func (k *CryptoKeyCache) cryptoKeys(ctx context.Context) (map[int32]codersdk.CryptoKey, error) { - keys, err := k.fetcher.Fetch(ctx) - if err != nil { - return nil, xerrors.Errorf("crypto keys: %w", err) - } - cache := toKeyMap(keys, k.Clock.Now()) - return cache, nil -} - -func toKeyMap(keys []codersdk.CryptoKey, now time.Time) map[int32]codersdk.CryptoKey { - m := make(map[int32]codersdk.CryptoKey) - var latest codersdk.CryptoKey - for _, key := range keys { - m[key.Sequence] = key - if key.Sequence > latest.Sequence && key.CanSign(now) { - m[latestSequence] = key - } - } - return m -} - -func (k *CryptoKeyCache) Close() { - k.mu.Lock() - defer k.mu.Unlock() - - if k.closed { - return - } - - k.closed = true - k.refreshCancel() - k.refresher.Stop() - k.cond.Broadcast() -} diff --git a/enterprise/wsproxy/keyfetcher.go b/enterprise/wsproxy/keyfetcher.go new file mode 100644 index 0000000000000..81b71301b610f --- /dev/null +++ b/enterprise/wsproxy/keyfetcher.go @@ -0,0 +1,25 @@ +package wsproxy + +import ( + "context" + + "github.com/coder/coder/v2/coderd/cryptokeys" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk" + "golang.org/x/xerrors" +) + +var _ cryptokeys.Fetcher = &ProxyFetcher{} + +type ProxyFetcher struct { + Client *wsproxysdk.Client + Feature codersdk.CryptoKeyFeature +} + +func (p *ProxyFetcher) Fetch(ctx context.Context) ([]codersdk.CryptoKey, error) { + keys, err := p.Client.CryptoKeys(ctx) + if err != nil { + return nil, xerrors.Errorf("crypto keys: %w", err) + } + return keys.CryptoKeys, nil +}