diff --git a/Makefile b/Makefile index 0765346500975..be74b27013a23 100644 --- a/Makefile +++ b/Makefile @@ -537,7 +537,8 @@ gen/mark-fresh: tailnet/tailnettest/coordinatormock.go \ tailnet/tailnettest/coordinateemock.go \ tailnet/tailnettest/multiagentmock.go \ - " + " + for file in $$files; do echo "$$file" if [ ! -f "$$file" ]; then diff --git a/coderd/cryptokeys/dbkeycache.go b/coderd/cryptokeys/dbkeycache.go index 4986f1669c4e5..aa0a2444b35f2 100644 --- a/coderd/cryptokeys/dbkeycache.go +++ b/coderd/cryptokeys/dbkeycache.go @@ -2,6 +2,7 @@ package cryptokeys import ( "context" + "strconv" "sync" "time" @@ -9,16 +10,14 @@ import ( "cdr.dev/slog" "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/db2sdk" - "github.com/coder/coder/v2/codersdk" "github.com/coder/quartz" ) // never represents the maximum value for a time.Duration. const never = 1<<63 - 1 -// DBCache implements Keycache for callers with access to the database. -type DBCache struct { +// dbCache implements Keycache for callers with access to the database. +type dbCache struct { db database.Store feature database.CryptoKeyFeature logger slog.Logger @@ -34,18 +33,34 @@ type DBCache struct { closed bool } -type DBCacheOption func(*DBCache) +type DBCacheOption func(*dbCache) func WithDBCacheClock(clock quartz.Clock) DBCacheOption { - return func(d *DBCache) { + return func(d *dbCache) { d.clock = clock } } -// NewDBCache creates a new DBCache. Close should be called to +// NewSigningCache creates a new DBCache. Close should be called to // release resources associated with its internal timer. -func NewDBCache(logger slog.Logger, db database.Store, feature database.CryptoKeyFeature, opts ...func(*DBCache)) *DBCache { - d := &DBCache{ +func NewSigningCache(logger slog.Logger, db database.Store, feature database.CryptoKeyFeature, opts ...func(*dbCache)) (SigningKeycache, error) { + if !isSigningKeyFeature(feature) { + return nil, ErrInvalidFeature + } + + return newDBCache(logger, db, feature, opts...), nil +} + +func NewEncryptionCache(logger slog.Logger, db database.Store, feature database.CryptoKeyFeature, opts ...func(*dbCache)) (EncryptionKeycache, error) { + if !isEncryptionKeyFeature(feature) { + return nil, ErrInvalidFeature + } + + return newDBCache(logger, db, feature, opts...), nil +} + +func newDBCache(logger slog.Logger, db database.Store, feature database.CryptoKeyFeature, opts ...func(*dbCache)) *dbCache { + d := &dbCache{ db: db, feature: feature, clock: quartz.NewReal(), @@ -56,23 +71,61 @@ func NewDBCache(logger slog.Logger, db database.Store, feature database.CryptoKe opt(d) } + // Initialize the timer. This will get properly initialized the first time we fetch. d.timer = d.clock.AfterFunc(never, d.clear) return d } -// Verifying returns the CryptoKey with the given sequence number, provided that +func (d *dbCache) EncryptingKey(ctx context.Context) (id string, key interface{}, err error) { + if !isEncryptionKeyFeature(d.feature) { + return "", nil, ErrInvalidFeature + } + + return d.latest(ctx) +} + +func (d *dbCache) DecryptingKey(ctx context.Context, id string) (key interface{}, err error) { + if !isEncryptionKeyFeature(d.feature) { + return nil, ErrInvalidFeature + } + + return d.sequence(ctx, id) +} + +func (d *dbCache) SigningKey(ctx context.Context) (id string, key interface{}, err error) { + if !isSigningKeyFeature(d.feature) { + return "", nil, ErrInvalidFeature + } + + return d.latest(ctx) +} + +func (d *dbCache) VerifyingKey(ctx context.Context, id string) (key interface{}, err error) { + if !isSigningKeyFeature(d.feature) { + return nil, ErrInvalidFeature + } + + return d.sequence(ctx, id) +} + +// sequence returns the CryptoKey with the given sequence number, provided that // it is neither deleted nor has breached its deletion date. It should only be // used for verifying or decrypting payloads. To sign/encrypt call Signing. -func (d *DBCache) Verifying(ctx context.Context, sequence int32) (codersdk.CryptoKey, error) { +func (d *dbCache) sequence(ctx context.Context, id string) (interface{}, error) { + sequence, err := strconv.ParseInt(id, 10, 32) + if err != nil { + return nil, xerrors.Errorf("expecting sequence number got %q: %w", id, err) + } + d.keysMu.RLock() if d.closed { d.keysMu.RUnlock() - return codersdk.CryptoKey{}, ErrClosed + return nil, ErrClosed } now := d.clock.Now() - key, ok := d.keys[sequence] + key, ok := d.keys[int32(sequence)] d.keysMu.RUnlock() if ok { return checkKey(key, now) @@ -82,35 +135,35 @@ func (d *DBCache) Verifying(ctx context.Context, sequence int32) (codersdk.Crypt defer d.keysMu.Unlock() if d.closed { - return codersdk.CryptoKey{}, ErrClosed + return nil, ErrClosed } - key, ok = d.keys[sequence] + key, ok = d.keys[int32(sequence)] if ok { return checkKey(key, now) } - err := d.fetch(ctx) + err = d.fetch(ctx) if err != nil { - return codersdk.CryptoKey{}, xerrors.Errorf("fetch: %w", err) + return nil, xerrors.Errorf("fetch: %w", err) } - key, ok = d.keys[sequence] + key, ok = d.keys[int32(sequence)] if !ok { - return codersdk.CryptoKey{}, ErrKeyNotFound + return nil, ErrKeyNotFound } return checkKey(key, now) } -// Signing returns the latest valid key for signing. A valid key is one that is +// latest returns the latest valid key for signing. A valid key is one that is // both past its start time and before its deletion time. -func (d *DBCache) Signing(ctx context.Context) (codersdk.CryptoKey, error) { +func (d *dbCache) latest(ctx context.Context) (string, interface{}, error) { d.keysMu.RLock() if d.closed { d.keysMu.RUnlock() - return codersdk.CryptoKey{}, ErrClosed + return "", nil, ErrClosed } latest := d.latestKey @@ -118,31 +171,31 @@ func (d *DBCache) Signing(ctx context.Context) (codersdk.CryptoKey, error) { now := d.clock.Now() if latest.CanSign(now) { - return db2sdk.CryptoKey(latest), nil + return idSecret(latest) } d.keysMu.Lock() defer d.keysMu.Unlock() if d.closed { - return codersdk.CryptoKey{}, ErrClosed + return "", nil, ErrClosed } if d.latestKey.CanSign(now) { - return db2sdk.CryptoKey(d.latestKey), nil + return idSecret(d.latestKey) } // Refetch all keys for this feature so we can find the latest valid key. err := d.fetch(ctx) if err != nil { - return codersdk.CryptoKey{}, xerrors.Errorf("fetch: %w", err) + return "", nil, xerrors.Errorf("fetch: %w", err) } - return db2sdk.CryptoKey(d.latestKey), nil + return idSecret(d.latestKey) } // clear invalidates the cache. This forces the subsequent call to fetch fresh keys. -func (d *DBCache) clear() { +func (d *dbCache) clear() { now := d.clock.Now("DBCache", "clear") d.keysMu.Lock() defer d.keysMu.Unlock() @@ -158,7 +211,7 @@ func (d *DBCache) clear() { // fetch fetches all keys for the given feature and determines the latest key. // It must be called while holding the keysMu lock. -func (d *DBCache) fetch(ctx context.Context) error { +func (d *dbCache) fetch(ctx context.Context) error { keys, err := d.db.GetCryptoKeysByFeature(ctx, d.feature) if err != nil { return xerrors.Errorf("get crypto keys by feature: %w", err) @@ -189,22 +242,45 @@ func (d *DBCache) fetch(ctx context.Context) error { return nil } -func checkKey(key database.CryptoKey, now time.Time) (codersdk.CryptoKey, error) { +func checkKey(key database.CryptoKey, now time.Time) (interface{}, error) { if !key.CanVerify(now) { - return codersdk.CryptoKey{}, ErrKeyInvalid + return nil, ErrKeyInvalid } - return db2sdk.CryptoKey(key), nil + return key.DecodeString() } -func (d *DBCache) Close() { +func (d *dbCache) Close() error { d.keysMu.Lock() defer d.keysMu.Unlock() if d.closed { - return + return nil } d.timer.Stop() d.closed = true + return nil +} + +func isEncryptionKeyFeature(feature database.CryptoKeyFeature) bool { + return feature == database.CryptoKeyFeatureWorkspaceApps +} + +func isSigningKeyFeature(feature database.CryptoKeyFeature) bool { + switch feature { + case database.CryptoKeyFeatureTailnetResume, database.CryptoKeyFeatureOidcConvert: + return true + default: + return false + } +} + +func idSecret(k database.CryptoKey) (string, interface{}, error) { + key, err := k.DecodeString() + if err != nil { + return "", nil, xerrors.Errorf("decode key: %w", err) + } + + return strconv.FormatInt(int64(k.Sequence), 10), key, nil } diff --git a/coderd/cryptokeys/dbkeycache_internal_test.go b/coderd/cryptokeys/dbkeycache_internal_test.go index a3450f5f5e0d9..c27bc5b8468ad 100644 --- a/coderd/cryptokeys/dbkeycache_internal_test.go +++ b/coderd/cryptokeys/dbkeycache_internal_test.go @@ -2,6 +2,7 @@ package cryptokeys import ( "database/sql" + "strconv" "testing" "time" @@ -11,13 +12,12 @@ import ( "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/db2sdk" "github.com/coder/coder/v2/coderd/database/dbmock" "github.com/coder/coder/v2/testutil" "github.com/coder/quartz" ) -func Test_Verifying(t *testing.T) { +func Test_version(t *testing.T) { t.Parallel() t.Run("HitsCache", func(t *testing.T) { @@ -35,7 +35,7 @@ func Test_Verifying(t *testing.T) { Feature: database.CryptoKeyFeatureWorkspaceApps, Sequence: 32, Secret: sql.NullString{ - String: "secret", + String: mustGenerateKey(t), Valid: true, }, } @@ -44,13 +44,13 @@ func Test_Verifying(t *testing.T) { 32: expectedKey, } - k := NewDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) + k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) defer k.Close() k.keys = cache - got, err := k.Verifying(ctx, 32) + secret, err := k.sequence(ctx, keyID(expectedKey)) require.NoError(t, err) - require.Equal(t, db2sdk.CryptoKey(expectedKey), got) + require.Equal(t, decodedSecret(t, expectedKey), secret) }) t.Run("MissesCache", func(t *testing.T) { @@ -69,20 +69,19 @@ func Test_Verifying(t *testing.T) { Sequence: 33, StartsAt: clock.Now(), Secret: sql.NullString{ - String: "secret", + String: mustGenerateKey(t), Valid: true, }, } mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{expectedKey}, nil) - k := NewDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) + k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) defer k.Close() - got, err := k.Verifying(ctx, 33) + got, err := k.sequence(ctx, keyID(expectedKey)) require.NoError(t, err) - require.Equal(t, db2sdk.CryptoKey(expectedKey), got) - require.Equal(t, db2sdk.CryptoKey(expectedKey), db2sdk.CryptoKey(k.latestKey)) + require.Equal(t, decodedSecret(t, expectedKey), got) }) t.Run("InvalidCachedKey", func(t *testing.T) { @@ -101,7 +100,7 @@ func Test_Verifying(t *testing.T) { Feature: database.CryptoKeyFeatureWorkspaceApps, Sequence: 32, Secret: sql.NullString{ - String: "secret", + String: mustGenerateKey(t), Valid: true, }, DeletesAt: sql.NullTime{ @@ -111,11 +110,11 @@ func Test_Verifying(t *testing.T) { }, } - k := NewDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) + k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) defer k.Close() k.keys = cache - _, err := k.Verifying(ctx, 32) + _, err := k.sequence(ctx, "32") require.ErrorIs(t, err, ErrKeyInvalid) }) @@ -134,7 +133,7 @@ func Test_Verifying(t *testing.T) { Feature: database.CryptoKeyFeatureWorkspaceApps, Sequence: 32, Secret: sql.NullString{ - String: "secret", + String: mustGenerateKey(t), Valid: true, }, DeletesAt: sql.NullTime{ @@ -144,15 +143,15 @@ func Test_Verifying(t *testing.T) { } mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{invalidKey}, nil) - k := NewDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) + k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) defer k.Close() - _, err := k.Verifying(ctx, 32) + _, err := k.sequence(ctx, keyID(invalidKey)) require.ErrorIs(t, err, ErrKeyInvalid) }) } -func Test_Signing(t *testing.T) { +func Test_latest(t *testing.T) { t.Parallel() t.Run("HitsCache", func(t *testing.T) { @@ -170,19 +169,20 @@ func Test_Signing(t *testing.T) { Feature: database.CryptoKeyFeatureWorkspaceApps, Sequence: 32, Secret: sql.NullString{ - String: "secret", + String: mustGenerateKey(t), Valid: true, }, StartsAt: clock.Now(), } - k := NewDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) + k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) defer k.Close() k.latestKey = latestKey - got, err := k.Signing(ctx) + id, secret, err := k.latest(ctx) require.NoError(t, err) - require.Equal(t, db2sdk.CryptoKey(latestKey), got) + require.Equal(t, keyID(latestKey), id) + require.Equal(t, decodedSecret(t, latestKey), secret) }) t.Run("InvalidCachedKey", func(t *testing.T) { @@ -200,7 +200,7 @@ func Test_Signing(t *testing.T) { Feature: database.CryptoKeyFeatureWorkspaceApps, Sequence: 33, Secret: sql.NullString{ - String: "secret", + String: mustGenerateKey(t), Valid: true, }, StartsAt: clock.Now(), @@ -210,7 +210,7 @@ func Test_Signing(t *testing.T) { Feature: database.CryptoKeyFeatureWorkspaceApps, Sequence: 32, Secret: sql.NullString{ - String: "secret", + String: mustGenerateKey(t), Valid: true, }, StartsAt: clock.Now().Add(-time.Hour), @@ -222,13 +222,14 @@ func Test_Signing(t *testing.T) { mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{latestKey}, nil) - k := NewDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) + k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) defer k.Close() k.latestKey = invalidKey - got, err := k.Signing(ctx) + id, secret, err := k.latest(ctx) require.NoError(t, err) - require.Equal(t, db2sdk.CryptoKey(latestKey), got) + require.Equal(t, keyID(latestKey), id) + require.Equal(t, decodedSecret(t, latestKey), secret) }) t.Run("UsesActiveKey", func(t *testing.T) { @@ -246,7 +247,7 @@ func Test_Signing(t *testing.T) { Feature: database.CryptoKeyFeatureWorkspaceApps, Sequence: 32, Secret: sql.NullString{ - String: "secret", + String: mustGenerateKey(t), Valid: true, }, StartsAt: clock.Now().Add(time.Hour), @@ -256,7 +257,7 @@ func Test_Signing(t *testing.T) { Feature: database.CryptoKeyFeatureWorkspaceApps, Sequence: 33, Secret: sql.NullString{ - String: "secret", + String: mustGenerateKey(t), Valid: true, }, StartsAt: clock.Now(), @@ -264,12 +265,13 @@ func Test_Signing(t *testing.T) { mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{inactiveKey, activeKey}, nil) - k := NewDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) + k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) defer k.Close() - got, err := k.Signing(ctx) + id, secret, err := k.latest(ctx) require.NoError(t, err) - require.Equal(t, db2sdk.CryptoKey(activeKey), got) + require.Equal(t, keyID(activeKey), id) + require.Equal(t, decodedSecret(t, activeKey), secret) }) t.Run("NoValidKeys", func(t *testing.T) { @@ -287,7 +289,7 @@ func Test_Signing(t *testing.T) { Feature: database.CryptoKeyFeatureWorkspaceApps, Sequence: 32, Secret: sql.NullString{ - String: "secret", + String: mustGenerateKey(t), Valid: true, }, StartsAt: clock.Now().Add(time.Hour), @@ -297,7 +299,7 @@ func Test_Signing(t *testing.T) { Feature: database.CryptoKeyFeatureWorkspaceApps, Sequence: 33, Secret: sql.NullString{ - String: "secret", + String: mustGenerateKey(t), Valid: true, }, StartsAt: clock.Now().Add(-time.Hour), @@ -309,10 +311,10 @@ func Test_Signing(t *testing.T) { mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{inactiveKey, invalidKey}, nil) - k := NewDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) + k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) defer k.Close() - _, err := k.Signing(ctx) + _, _, err := k.latest(ctx) require.ErrorIs(t, err, ErrKeyInvalid) }) } @@ -331,14 +333,14 @@ func Test_clear(t *testing.T) { logger = slogtest.Make(t, nil) ) - k := NewDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) + k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) defer k.Close() activeKey := database.CryptoKey{ Feature: database.CryptoKeyFeatureWorkspaceApps, Sequence: 33, Secret: sql.NullString{ - String: "secret", + String: mustGenerateKey(t), Valid: true, }, StartsAt: clock.Now(), @@ -346,7 +348,7 @@ func Test_clear(t *testing.T) { mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{activeKey}, nil) - _, err := k.Signing(ctx) + _, _, err := k.latest(ctx) require.NoError(t, err) dur, wait := clock.AdvanceNext() @@ -367,14 +369,14 @@ func Test_clear(t *testing.T) { logger = slogtest.Make(t, nil) ) - k := NewDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) + k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) defer k.Close() key := database.CryptoKey{ Feature: database.CryptoKeyFeatureWorkspaceApps, Sequence: 32, Secret: sql.NullString{ - String: "secret", + String: mustGenerateKey(t), Valid: true, }, StartsAt: clock.Now(), @@ -386,9 +388,10 @@ func Test_clear(t *testing.T) { // timer is reset and doesn't fire after another five minute. clock.Advance(time.Minute * 5) - latest, err := k.Signing(ctx) + id, secret, err := k.latest(ctx) require.NoError(t, err) - require.Equal(t, db2sdk.CryptoKey(key), latest) + require.Equal(t, keyID(key), id) + require.Equal(t, decodedSecret(t, key), secret) // Advancing the clock now should require 10 minutes // before the timer fires again. @@ -415,14 +418,14 @@ func Test_clear(t *testing.T) { trap := clock.Trap().Now("clear") - k := NewDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) + k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) defer k.Close() key := database.CryptoKey{ Feature: database.CryptoKeyFeatureWorkspaceApps, Sequence: 32, Secret: sql.NullString{ - String: "secret", + String: mustGenerateKey(t), Valid: true, }, StartsAt: clock.Now(), @@ -431,9 +434,10 @@ func Test_clear(t *testing.T) { mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{key}, nil).Times(2) // Move us past the initial timer. - latest, err := k.Signing(ctx) + id, secret, err := k.latest(ctx) require.NoError(t, err) - require.Equal(t, db2sdk.CryptoKey(key), latest) + require.Equal(t, keyID(key), id) + require.Equal(t, decodedSecret(t, key), secret) // Null these out so that we refetch. k.keys = nil k.latestKey = database.CryptoKey{} @@ -445,9 +449,10 @@ func Test_clear(t *testing.T) { call := trap.MustWait(ctx) // Refetch keys. - latest, err = k.Signing(ctx) + id, secret, err = k.latest(ctx) require.NoError(t, err) - require.Equal(t, db2sdk.CryptoKey(key), latest) + require.Equal(t, keyID(key), id) + require.Equal(t, decodedSecret(t, key), secret) // Let the rest of the timer function run. // It should see that we have refetched keys and @@ -465,3 +470,21 @@ func Test_clear(t *testing.T) { require.Equal(t, database.CryptoKey{}, k.latestKey) }) } + +func mustGenerateKey(t *testing.T) string { + t.Helper() + key, err := generateKey(64) + require.NoError(t, err) + return key +} + +func keyID(key database.CryptoKey) string { + return strconv.FormatInt(int64(key.Sequence), 10) +} + +func decodedSecret(t *testing.T, key database.CryptoKey) []byte { + t.Helper() + decoded, err := key.DecodeString() + require.NoError(t, err) + return decoded +} diff --git a/coderd/cryptokeys/dbkeycache_test.go b/coderd/cryptokeys/dbkeycache_test.go index 8c92cf3a90aa6..e24ef16660db1 100644 --- a/coderd/cryptokeys/dbkeycache_test.go +++ b/coderd/cryptokeys/dbkeycache_test.go @@ -1,6 +1,7 @@ package cryptokeys_test import ( + "strconv" "testing" "github.com/stretchr/testify/require" @@ -10,7 +11,6 @@ import ( "github.com/coder/coder/v2/coderd/cryptokeys" "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/db2sdk" "github.com/coder/coder/v2/coderd/database/dbgen" "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/testutil" @@ -24,7 +24,7 @@ func TestMain(m *testing.M) { func TestDBKeyCache(t *testing.T) { t.Parallel() - t.Run("Verifying", func(t *testing.T) { + t.Run("VerifyingKey", func(t *testing.T) { t.Parallel() t.Run("HitsCache", func(t *testing.T) { @@ -38,17 +38,18 @@ func TestDBKeyCache(t *testing.T) { ) key := dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureOidcConvert, Sequence: 1, StartsAt: clock.Now().UTC(), }) - k := cryptokeys.NewDBCache(logger, db, database.CryptoKeyFeatureWorkspaceApps, cryptokeys.WithDBCacheClock(clock)) + k, err := cryptokeys.NewSigningCache(logger, db, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock)) + require.NoError(t, err) defer k.Close() - got, err := k.Verifying(ctx, key.Sequence) + got, err := k.VerifyingKey(ctx, keyID(key)) require.NoError(t, err) - require.Equal(t, db2sdk.CryptoKey(key), got) + require.Equal(t, decodedSecret(t, key), got) }) t.Run("NotFound", func(t *testing.T) { @@ -61,10 +62,11 @@ func TestDBKeyCache(t *testing.T) { logger = slogtest.Make(t, nil) ) - k := cryptokeys.NewDBCache(logger, db, database.CryptoKeyFeatureWorkspaceApps, cryptokeys.WithDBCacheClock(clock)) + k, err := cryptokeys.NewSigningCache(logger, db, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock)) + require.NoError(t, err) defer k.Close() - _, err := k.Verifying(ctx, 123) + _, err = k.VerifyingKey(ctx, "123") require.ErrorIs(t, err, cryptokeys.ErrKeyNotFound) }) }) @@ -80,29 +82,31 @@ func TestDBKeyCache(t *testing.T) { ) _ = dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureOidcConvert, Sequence: 10, StartsAt: clock.Now().UTC(), }) expectedKey := dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureOidcConvert, Sequence: 12, StartsAt: clock.Now().UTC(), }) _ = dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureOidcConvert, Sequence: 2, StartsAt: clock.Now().UTC(), }) - k := cryptokeys.NewDBCache(logger, db, database.CryptoKeyFeatureWorkspaceApps, cryptokeys.WithDBCacheClock(clock)) + k, err := cryptokeys.NewSigningCache(logger, db, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock)) + require.NoError(t, err) defer k.Close() - got, err := k.Signing(ctx) + id, key, err := k.SigningKey(ctx) require.NoError(t, err) - require.Equal(t, db2sdk.CryptoKey(expectedKey), got) + require.Equal(t, keyID(expectedKey), id) + require.Equal(t, decodedSecret(t, expectedKey), key) }) t.Run("Closed", func(t *testing.T) { @@ -116,28 +120,97 @@ func TestDBKeyCache(t *testing.T) { ) expectedKey := dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureOidcConvert, Sequence: 10, StartsAt: clock.Now(), }) - k := cryptokeys.NewDBCache(logger, db, database.CryptoKeyFeatureWorkspaceApps, cryptokeys.WithDBCacheClock(clock)) + k, err := cryptokeys.NewSigningCache(logger, db, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock)) + require.NoError(t, err) defer k.Close() - got, err := k.Signing(ctx) + id, key, err := k.SigningKey(ctx) require.NoError(t, err) - require.Equal(t, db2sdk.CryptoKey(expectedKey), got) + require.Equal(t, keyID(expectedKey), id) + require.Equal(t, decodedSecret(t, expectedKey), key) - got, err = k.Verifying(ctx, expectedKey.Sequence) + key, err = k.VerifyingKey(ctx, keyID(expectedKey)) require.NoError(t, err) - require.Equal(t, db2sdk.CryptoKey(expectedKey), got) + require.Equal(t, decodedSecret(t, expectedKey), key) k.Close() - _, err = k.Signing(ctx) + _, _, err = k.SigningKey(ctx) require.ErrorIs(t, err, cryptokeys.ErrClosed) - _, err = k.Verifying(ctx, expectedKey.Sequence) + _, err = k.VerifyingKey(ctx, keyID(expectedKey)) require.ErrorIs(t, err, cryptokeys.ErrClosed) }) + + t.Run("InvalidSigningFeature", func(t *testing.T) { + t.Parallel() + + var ( + db, _ = dbtestutil.NewDB(t) + clock = quartz.NewMock(t) + logger = slogtest.Make(t, nil) + ctx = testutil.Context(t, testutil.WaitShort) + ) + + _, err := cryptokeys.NewSigningCache(logger, db, database.CryptoKeyFeatureWorkspaceApps, cryptokeys.WithDBCacheClock(clock)) + require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature) + + // Instantiate a signing cache and try to use it as an encryption cache. + sc, err := cryptokeys.NewSigningCache(logger, db, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock)) + require.NoError(t, err) + defer sc.Close() + + ec, ok := sc.(cryptokeys.EncryptionKeycache) + require.True(t, ok) + _, _, err = ec.EncryptingKey(ctx) + require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature) + + _, err = ec.DecryptingKey(ctx, "123") + require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature) + }) + + t.Run("InvalidEncryptionFeature", func(t *testing.T) { + t.Parallel() + + var ( + db, _ = dbtestutil.NewDB(t) + clock = quartz.NewMock(t) + logger = slogtest.Make(t, nil) + ctx = testutil.Context(t, testutil.WaitShort) + ) + + _, err := cryptokeys.NewEncryptionCache(logger, db, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock)) + require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature) + + // Instantiate an encryption cache and try to use it as a signing cache. + ec, err := cryptokeys.NewEncryptionCache(logger, db, database.CryptoKeyFeatureWorkspaceApps, cryptokeys.WithDBCacheClock(clock)) + require.NoError(t, err) + defer ec.Close() + + sc, ok := ec.(cryptokeys.SigningKeycache) + require.True(t, ok) + _, _, err = sc.SigningKey(ctx) + require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature) + + _, err = sc.VerifyingKey(ctx, "123") + require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature) + }) +} + +func keyID(key database.CryptoKey) string { + return strconv.FormatInt(int64(key.Sequence), 10) +} + +func decodedSecret(t *testing.T, key database.CryptoKey) []byte { + t.Helper() + + secret, err := key.DecodeString() + require.NoError(t, err) + + return secret } diff --git a/coderd/cryptokeys/doc.go b/coderd/cryptokeys/doc.go new file mode 100644 index 0000000000000..b2494f9f0da8d --- /dev/null +++ b/coderd/cryptokeys/doc.go @@ -0,0 +1,2 @@ +// Package cryptokeys provides an abstraction for fetching internally used cryptographic keys mainly for JWT signing and verification. +package cryptokeys diff --git a/coderd/cryptokeys/keycache.go b/coderd/cryptokeys/keycache.go index 8c4ebfa13f64e..05c80a15b2378 100644 --- a/coderd/cryptokeys/keycache.go +++ b/coderd/cryptokeys/keycache.go @@ -2,20 +2,40 @@ package cryptokeys import ( "context" + "io" "golang.org/x/xerrors" - - "github.com/coder/coder/v2/codersdk" ) var ( - ErrKeyNotFound = xerrors.New("key not found") - ErrKeyInvalid = xerrors.New("key is invalid for use") - ErrClosed = xerrors.New("closed") + ErrKeyNotFound = xerrors.New("key not found") + ErrKeyInvalid = xerrors.New("key is invalid for use") + ErrClosed = xerrors.New("closed") + ErrInvalidFeature = xerrors.New("invalid feature for this operation") ) -// Keycache provides an abstraction for fetching signing keys. -type Keycache interface { - Signing(ctx context.Context) (codersdk.CryptoKey, error) - Verifying(ctx context.Context, sequence int32) (codersdk.CryptoKey, error) +type EncryptionKeycache interface { + // EncryptingKey returns the latest valid key for encrypting payloads. A valid + // key is one that is both past its start time and before its deletion time. + EncryptingKey(ctx context.Context) (id string, key interface{}, err error) + // DecryptingKey returns the key with the provided id which maps to its sequence + // number. The key is valid for decryption as long as it is not deleted or past + // its deletion date. We must allow for keys prior to their start time to + // account for clock skew between peers (one key may be past its start time on + // one machine while another is not). + DecryptingKey(ctx context.Context, id string) (key interface{}, err error) + io.Closer +} + +type SigningKeycache interface { + // SigningKey returns the latest valid key for signing. A valid key is one + // that is both past its start time and before its deletion time. + SigningKey(ctx context.Context) (id string, key interface{}, err error) + // VerifyingKey returns the key with the provided id which should map to its + // sequence number. The key is valid for verifying as long as it is not deleted + // or past its deletion date. We must allow for keys prior to their start time + // to account for clock skew between peers (one key may be past its start time + // on one machine while another is not). + VerifyingKey(ctx context.Context, id string) (key interface{}, err error) + io.Closer } diff --git a/coderd/cryptokeys/rotate.go b/coderd/cryptokeys/rotate.go index 224b9100d5bf8..14a623e2156db 100644 --- a/coderd/cryptokeys/rotate.go +++ b/coderd/cryptokeys/rotate.go @@ -227,9 +227,9 @@ func (k *rotator) rotateKey(ctx context.Context, tx database.Store, key database func generateNewSecret(feature database.CryptoKeyFeature) (string, error) { switch feature { case database.CryptoKeyFeatureWorkspaceApps: - return generateKey(96) - case database.CryptoKeyFeatureOidcConvert: return generateKey(32) + case database.CryptoKeyFeatureOidcConvert: + return generateKey(64) case database.CryptoKeyFeatureTailnetResume: return generateKey(64) } diff --git a/coderd/cryptokeys/rotate_internal_test.go b/coderd/cryptokeys/rotate_internal_test.go index 36ecf4fa9d76d..43754c1d8750f 100644 --- a/coderd/cryptokeys/rotate_internal_test.go +++ b/coderd/cryptokeys/rotate_internal_test.go @@ -588,9 +588,9 @@ func requireKey(t *testing.T, key database.CryptoKey, feature database.CryptoKey switch key.Feature { case database.CryptoKeyFeatureOidcConvert: - require.Len(t, secret, 32) + require.Len(t, secret, 64) case database.CryptoKeyFeatureWorkspaceApps: - require.Len(t, secret, 96) + require.Len(t, secret, 32) case database.CryptoKeyFeatureTailnetResume: require.Len(t, secret, 64) default: diff --git a/coderd/database/dbgen/dbgen.go b/coderd/database/dbgen/dbgen.go index 1a2f052a279b3..93439fd0f2b77 100644 --- a/coderd/database/dbgen/dbgen.go +++ b/coderd/database/dbgen/dbgen.go @@ -988,9 +988,9 @@ func takeFirst[Value comparable](values ...Value) Value { func newCryptoKeySecret(feature database.CryptoKeyFeature) (string, error) { switch feature { case database.CryptoKeyFeatureWorkspaceApps: - return generateCryptoKey(96) - case database.CryptoKeyFeatureOidcConvert: return generateCryptoKey(32) + case database.CryptoKeyFeatureOidcConvert: + return generateCryptoKey(64) case database.CryptoKeyFeatureTailnetResume: return generateCryptoKey(64) } diff --git a/coderd/jwtutils/jwe.go b/coderd/jwtutils/jwe.go new file mode 100644 index 0000000000000..f50cacb62de7c --- /dev/null +++ b/coderd/jwtutils/jwe.go @@ -0,0 +1,121 @@ +package jwtutils + +import ( + "context" + "encoding/json" + "time" + + "github.com/go-jose/go-jose/v4" + "github.com/go-jose/go-jose/v4/jwt" + "golang.org/x/xerrors" +) + +const ( + encryptKeyAlgo = jose.A256GCMKW + encryptContentAlgo = jose.A256GCM +) + +type EncryptKeyProvider interface { + EncryptingKey(ctx context.Context) (id string, key interface{}, err error) +} + +type DecryptKeyProvider interface { + DecryptingKey(ctx context.Context, id string) (key interface{}, err error) +} + +// Encrypt encrypts a token and returns it as a string. +func Encrypt(ctx context.Context, e EncryptKeyProvider, claims Claims) (string, error) { + id, key, err := e.EncryptingKey(ctx) + if err != nil { + return "", xerrors.Errorf("get signing key: %w", err) + } + + encrypter, err := jose.NewEncrypter( + encryptContentAlgo, + jose.Recipient{ + Algorithm: encryptKeyAlgo, + Key: key, + }, + &jose.EncrypterOptions{ + Compression: jose.DEFLATE, + ExtraHeaders: map[jose.HeaderKey]interface{}{ + keyIDHeaderKey: id, + }, + }, + ) + if err != nil { + return "", xerrors.Errorf("initialize encrypter: %w", err) + } + + payload, err := json.Marshal(claims) + if err != nil { + return "", xerrors.Errorf("marshal payload: %w", err) + } + + encrypted, err := encrypter.Encrypt(payload) + if err != nil { + return "", xerrors.Errorf("encrypt: %w", err) + } + + compact, err := encrypted.CompactSerialize() + if err != nil { + return "", xerrors.Errorf("compact serialize: %w", err) + } + + return compact, nil +} + +// DecryptOptions are options for decrypting a JWE. +type DecryptOptions struct { + RegisteredClaims jwt.Expected + KeyAlgorithm jose.KeyAlgorithm + ContentEncryptionAlgorithm jose.ContentEncryption +} + +// Decrypt decrypts the token using the provided key. It unmarshals into the provided claims. +func Decrypt(ctx context.Context, d DecryptKeyProvider, token string, claims Claims, opts ...func(*DecryptOptions)) error { + options := DecryptOptions{ + RegisteredClaims: jwt.Expected{ + Time: time.Now(), + }, + KeyAlgorithm: encryptKeyAlgo, + ContentEncryptionAlgorithm: encryptContentAlgo, + } + + for _, opt := range opts { + opt(&options) + } + + object, err := jose.ParseEncrypted(token, + []jose.KeyAlgorithm{options.KeyAlgorithm}, + []jose.ContentEncryption{options.ContentEncryptionAlgorithm}, + ) + if err != nil { + return xerrors.Errorf("parse jwe: %w", err) + } + + if object.Header.Algorithm != string(encryptKeyAlgo) { + return xerrors.Errorf("expected JWE algorithm to be %q, got %q", encryptKeyAlgo, object.Header.Algorithm) + } + + kid := object.Header.KeyID + if kid == "" { + return xerrors.Errorf("expected %q header to be a string", keyIDHeaderKey) + } + + key, err := d.DecryptingKey(ctx, kid) + if err != nil { + return xerrors.Errorf("key with id %q: %w", kid, err) + } + + decrypted, err := object.Decrypt(key) + if err != nil { + return xerrors.Errorf("decrypt: %w", err) + } + + if err := json.Unmarshal(decrypted, &claims); err != nil { + return xerrors.Errorf("unmarshal: %w", err) + } + + return claims.Validate(options.RegisteredClaims) +} diff --git a/coderd/jwtutils/jws.go b/coderd/jwtutils/jws.go new file mode 100644 index 0000000000000..73f35e672492d --- /dev/null +++ b/coderd/jwtutils/jws.go @@ -0,0 +1,127 @@ +package jwtutils + +import ( + "context" + "encoding/json" + "time" + + "github.com/go-jose/go-jose/v4" + "github.com/go-jose/go-jose/v4/jwt" + "golang.org/x/xerrors" +) + +const ( + keyIDHeaderKey = "kid" +) + +// Claims defines the payload for a JWT. Most callers +// should embed jwt.Claims +type Claims interface { + Validate(jwt.Expected) error +} + +const ( + signingAlgo = jose.HS512 +) + +type SigningKeyProvider interface { + SigningKey(ctx context.Context) (id string, key interface{}, err error) +} + +type VerifyKeyProvider interface { + VerifyingKey(ctx context.Context, id string) (key interface{}, err error) +} + +// Sign signs a token and returns it as a string. +func Sign(ctx context.Context, s SigningKeyProvider, claims Claims) (string, error) { + id, key, err := s.SigningKey(ctx) + if err != nil { + return "", xerrors.Errorf("get signing key: %w", err) + } + + signer, err := jose.NewSigner(jose.SigningKey{ + Algorithm: signingAlgo, + Key: key, + }, &jose.SignerOptions{ + ExtraHeaders: map[jose.HeaderKey]interface{}{ + keyIDHeaderKey: id, + }, + }) + if err != nil { + return "", xerrors.Errorf("new signer: %w", err) + } + + payload, err := json.Marshal(claims) + if err != nil { + return "", xerrors.Errorf("marshal claims: %w", err) + } + + signed, err := signer.Sign(payload) + if err != nil { + return "", xerrors.Errorf("sign payload: %w", err) + } + + compact, err := signed.CompactSerialize() + if err != nil { + return "", xerrors.Errorf("compact serialize: %w", err) + } + + return compact, nil +} + +// VerifyOptions are options for verifying a JWT. +type VerifyOptions struct { + RegisteredClaims jwt.Expected + SignatureAlgorithm jose.SignatureAlgorithm +} + +// Verify verifies that a token was signed by the provided key. It unmarshals into the provided claims. +func Verify(ctx context.Context, v VerifyKeyProvider, token string, claims Claims, opts ...func(*VerifyOptions)) error { + options := VerifyOptions{ + RegisteredClaims: jwt.Expected{ + Time: time.Now(), + }, + SignatureAlgorithm: signingAlgo, + } + + for _, opt := range opts { + opt(&options) + } + + object, err := jose.ParseSigned(token, []jose.SignatureAlgorithm{options.SignatureAlgorithm}) + if err != nil { + return xerrors.Errorf("parse JWS: %w", err) + } + + if len(object.Signatures) != 1 { + return xerrors.New("expected 1 signature") + } + + signature := object.Signatures[0] + + if signature.Header.Algorithm != string(signingAlgo) { + return xerrors.Errorf("expected JWS algorithm to be %q, got %q", signingAlgo, object.Signatures[0].Header.Algorithm) + } + + kid := signature.Header.KeyID + if kid == "" { + return xerrors.Errorf("expected %q header to be a string", keyIDHeaderKey) + } + + key, err := v.VerifyingKey(ctx, kid) + if err != nil { + return xerrors.Errorf("key with id %q: %w", kid, err) + } + + payload, err := object.Verify(key) + if err != nil { + return xerrors.Errorf("verify payload: %w", err) + } + + err = json.Unmarshal(payload, &claims) + if err != nil { + return xerrors.Errorf("unmarshal payload: %w", err) + } + + return claims.Validate(options.RegisteredClaims) +} diff --git a/coderd/jwtutils/jwt_test.go b/coderd/jwtutils/jwt_test.go new file mode 100644 index 0000000000000..ff30f7716b310 --- /dev/null +++ b/coderd/jwtutils/jwt_test.go @@ -0,0 +1,436 @@ +package jwtutils_test + +import ( + "context" + "crypto/rand" + "testing" + "time" + + "github.com/go-jose/go-jose/v4" + "github.com/go-jose/go-jose/v4/jwt" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "cdr.dev/slog/sloggers/slogtest" + + "github.com/coder/coder/v2/coderd/cryptokeys" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/jwtutils" + "github.com/coder/coder/v2/testutil" +) + +func TestClaims(t *testing.T) { + t.Parallel() + + type tokenType struct { + Name string + KeySize int + Sign bool + } + + types := []tokenType{ + { + Name: "JWE", + Sign: false, + KeySize: 32, + }, + { + Name: "JWS", + Sign: true, + KeySize: 64, + }, + } + + type testcase struct { + name string + claims jwtutils.Claims + expectedClaims jwt.Expected + expectedErr error + } + + cases := []testcase{ + { + name: "OK", + claims: jwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jwt.Audience{"coder"}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now()), + }, + }, + { + name: "WrongIssuer", + claims: jwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jwt.Audience{"coder"}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now()), + }, + expectedClaims: jwt.Expected{ + Issuer: "coder2", + }, + expectedErr: jwt.ErrInvalidIssuer, + }, + { + name: "WrongSubject", + claims: jwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jwt.Audience{"coder"}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now()), + }, + expectedClaims: jwt.Expected{ + Subject: "user2@coder.com", + }, + expectedErr: jwt.ErrInvalidSubject, + }, + { + name: "WrongAudience", + claims: jwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jwt.Audience{"coder"}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now()), + }, + }, + { + name: "Expired", + claims: jwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jwt.Audience{"coder"}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Minute)), + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now()), + }, + expectedClaims: jwt.Expected{ + Time: time.Now().Add(time.Minute * 3), + }, + expectedErr: jwt.ErrExpired, + }, + { + name: "IssuedInFuture", + claims: jwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jwt.Audience{"coder"}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Minute)), + IssuedAt: jwt.NewNumericDate(time.Now()), + }, + expectedClaims: jwt.Expected{ + Time: time.Now().Add(-time.Minute * 3), + }, + expectedErr: jwt.ErrIssuedInTheFuture, + }, + { + name: "IsBefore", + claims: jwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jwt.Audience{"coder"}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now().Add(time.Minute * 5)), + }, + expectedClaims: jwt.Expected{ + Time: time.Now().Add(time.Minute * 3), + }, + expectedErr: jwt.ErrNotValidYet, + }, + } + + for _, tt := range types { + tt := tt + + t.Run(tt.Name, func(t *testing.T) { + t.Parallel() + for _, c := range cases { + c := c + t.Run(c.name, func(t *testing.T) { + t.Parallel() + + var ( + ctx = testutil.Context(t, testutil.WaitShort) + key = newKey(t, tt.KeySize) + token string + err error + ) + + if tt.Sign { + token, err = jwtutils.Sign(ctx, key, c.claims) + } else { + token, err = jwtutils.Encrypt(ctx, key, c.claims) + } + require.NoError(t, err) + + var actual jwt.Claims + if tt.Sign { + err = jwtutils.Verify(ctx, key, token, &actual, withVerifyExpected(c.expectedClaims)) + } else { + err = jwtutils.Decrypt(ctx, key, token, &actual, withDecryptExpected(c.expectedClaims)) + } + if c.expectedErr != nil { + require.ErrorIs(t, err, c.expectedErr) + } else { + require.NoError(t, err) + require.Equal(t, c.claims, actual) + } + }) + } + }) + } +} + +func TestJWS(t *testing.T) { + t.Parallel() + t.Run("WrongSignatureAlgorithm", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + + key := newKey(t, 64) + + token, err := jwtutils.Sign(ctx, key, jwt.Claims{}) + require.NoError(t, err) + + var actual testClaims + err = jwtutils.Verify(ctx, key, token, &actual, withSignatureAlgorithm(jose.HS256)) + require.Error(t, err) + }) + + t.Run("CustomClaims", func(t *testing.T) { + t.Parallel() + + var ( + ctx = testutil.Context(t, testutil.WaitShort) + key = newKey(t, 64) + ) + + expected := testClaims{ + MyClaim: "my_value", + } + token, err := jwtutils.Sign(ctx, key, expected) + require.NoError(t, err) + + var actual testClaims + err = jwtutils.Verify(ctx, key, token, &actual, withVerifyExpected(jwt.Expected{})) + require.NoError(t, err) + require.Equal(t, expected, actual) + }) + + t.Run("WithKeycache", func(t *testing.T) { + t.Parallel() + + var ( + ctx = testutil.Context(t, testutil.WaitShort) + db, _ = dbtestutil.NewDB(t) + _ = dbgen.CryptoKey(t, db, database.CryptoKey{ + Feature: database.CryptoKeyFeatureOidcConvert, + StartsAt: time.Now(), + }) + log = slogtest.Make(t, nil) + ) + + cache, err := cryptokeys.NewSigningCache(log, db, database.CryptoKeyFeatureOidcConvert) + require.NoError(t, err) + + claims := testClaims{ + MyClaim: "my_value", + Claims: jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + }, + } + + token, err := jwtutils.Sign(ctx, cache, claims) + require.NoError(t, err) + + var actual testClaims + err = jwtutils.Verify(ctx, cache, token, &actual) + require.NoError(t, err) + require.Equal(t, claims, actual) + }) +} + +func TestJWE(t *testing.T) { + t.Parallel() + + t.Run("WrongKeyAlgorithm", func(t *testing.T) { + t.Parallel() + + var ( + ctx = testutil.Context(t, testutil.WaitShort) + key = newKey(t, 32) + ) + + token, err := jwtutils.Encrypt(ctx, key, jwt.Claims{}) + require.NoError(t, err) + + var actual testClaims + err = jwtutils.Decrypt(ctx, key, token, &actual, withKeyAlgorithm(jose.A128GCMKW)) + require.Error(t, err) + }) + + t.Run("WrongContentyEncryption", func(t *testing.T) { + t.Parallel() + + var ( + ctx = testutil.Context(t, testutil.WaitShort) + key = newKey(t, 32) + ) + + token, err := jwtutils.Encrypt(ctx, key, jwt.Claims{}) + require.NoError(t, err) + + var actual testClaims + err = jwtutils.Decrypt(ctx, key, token, &actual, withContentEncryptionAlgorithm(jose.A128GCM)) + require.Error(t, err) + }) + + t.Run("CustomClaims", func(t *testing.T) { + t.Parallel() + + var ( + ctx = testutil.Context(t, testutil.WaitShort) + key = newKey(t, 32) + ) + + expected := testClaims{ + MyClaim: "my_value", + } + + token, err := jwtutils.Encrypt(ctx, key, expected) + require.NoError(t, err) + + var actual testClaims + err = jwtutils.Decrypt(ctx, key, token, &actual, withDecryptExpected(jwt.Expected{})) + require.NoError(t, err) + require.Equal(t, expected, actual) + }) + + t.Run("WithKeycache", func(t *testing.T) { + t.Parallel() + + var ( + ctx = testutil.Context(t, testutil.WaitShort) + db, _ = dbtestutil.NewDB(t) + _ = dbgen.CryptoKey(t, db, database.CryptoKey{ + Feature: database.CryptoKeyFeatureWorkspaceApps, + StartsAt: time.Now(), + }) + log = slogtest.Make(t, nil) + ) + + cache, err := cryptokeys.NewEncryptionCache(log, db, database.CryptoKeyFeatureWorkspaceApps) + require.NoError(t, err) + + claims := testClaims{ + MyClaim: "my_value", + Claims: jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + }, + } + + token, err := jwtutils.Encrypt(ctx, cache, claims) + require.NoError(t, err) + + var actual testClaims + err = jwtutils.Decrypt(ctx, cache, token, &actual) + require.NoError(t, err) + require.Equal(t, claims, actual) + }) +} + +func generateSecret(t *testing.T, keySize int) []byte { + t.Helper() + + b := make([]byte, keySize) + _, err := rand.Read(b) + require.NoError(t, err) + return b +} + +type testClaims struct { + MyClaim string `json:"my_claim"` + jwt.Claims +} + +func withDecryptExpected(e jwt.Expected) func(*jwtutils.DecryptOptions) { + return func(opts *jwtutils.DecryptOptions) { + opts.RegisteredClaims = e + } +} + +func withVerifyExpected(e jwt.Expected) func(*jwtutils.VerifyOptions) { + return func(opts *jwtutils.VerifyOptions) { + opts.RegisteredClaims = e + } +} + +func withSignatureAlgorithm(alg jose.SignatureAlgorithm) func(*jwtutils.VerifyOptions) { + return func(opts *jwtutils.VerifyOptions) { + opts.SignatureAlgorithm = alg + } +} + +func withKeyAlgorithm(alg jose.KeyAlgorithm) func(*jwtutils.DecryptOptions) { + return func(opts *jwtutils.DecryptOptions) { + opts.KeyAlgorithm = alg + } +} + +func withContentEncryptionAlgorithm(alg jose.ContentEncryption) func(*jwtutils.DecryptOptions) { + return func(opts *jwtutils.DecryptOptions) { + opts.ContentEncryptionAlgorithm = alg + } +} + +type key struct { + t testing.TB + id string + secret []byte +} + +func newKey(t *testing.T, size int) *key { + t.Helper() + + id := uuid.New().String() + secret := generateSecret(t, size) + + return &key{ + t: t, + id: id, + secret: secret, + } +} + +func (k *key) SigningKey(_ context.Context) (id string, key interface{}, err error) { + return k.id, k.secret, nil +} + +func (k *key) VerifyingKey(_ context.Context, id string) (key interface{}, err error) { + k.t.Helper() + + require.Equal(k.t, k.id, id) + return k.secret, nil +} + +func (k *key) EncryptingKey(_ context.Context) (id string, key interface{}, err error) { + return k.id, k.secret, nil +} + +func (k *key) DecryptingKey(_ context.Context, id string) (key interface{}, err error) { + k.t.Helper() + + require.Equal(k.t, k.id, id) + return k.secret, nil +} diff --git a/go.mod b/go.mod index d5de2c7dbf769..eea7c2b647a7d 100644 --- a/go.mod +++ b/go.mod @@ -207,6 +207,7 @@ require ( github.com/coder/serpent v0.8.0 github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21 github.com/emersion/go-smtp v0.21.2 + github.com/go-jose/go-jose/v4 v4.0.2 github.com/gomarkdown/markdown v0.0.0-20231222211730-1d6d20845b47 github.com/google/go-github/v61 v61.0.0 github.com/mocktools/go-smtp-mock/v2 v2.3.0 @@ -224,7 +225,6 @@ require ( github.com/charmbracelet/x/ansi v0.2.3 // indirect github.com/charmbracelet/x/term v0.2.0 // indirect github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect - github.com/go-jose/go-jose/v4 v4.0.2 // indirect github.com/go-viper/mapstructure/v2 v2.0.0 // indirect github.com/hashicorp/go-plugin v1.6.1 // indirect github.com/hashicorp/go-retryablehttp v0.7.7 // indirect