Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Refactor CryptoKey handling to use codersdk package
Refactored the CryptoKey handling logic to leverage the codersdk
package. This enhances consistency and modularity by centralizing
CryptoKey management and related operations. The refactoring also
includes renaming functions to better reflect their functionality
and using the db2sdk package for database to SDK conversions.
  • Loading branch information
sreya committed Sep 26, 2024
commit 46503b66b3e1a5e27bf1958cd393cfdb25363712
71 changes: 38 additions & 33 deletions coderd/cryptokeys/dbkeycache.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (

"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/quartz"
)

Expand Down Expand Up @@ -59,84 +61,88 @@ func NewDBCache(ctx context.Context, logger slog.Logger, db database.Store, feat
}

// Version returns the CryptoKey with the given sequence number, provided that
// it is not deleted or has breached its deletion date.
func (d *DBCache) Version(ctx context.Context, sequence int32) (database.CryptoKey, error) {
// it is neither deleted nor has breached its deletion date.
func (d *DBCache) Version(ctx context.Context, sequence int32) (codersdk.CryptoKey, error) {
now := d.clock.Now().UTC()
d.cacheMu.RLock()
key, ok := d.cache[sequence]
d.cacheMu.RUnlock()
if ok {
if key.IsInvalid(now) {
return database.CryptoKey{}, ErrKeyNotFound
if !key.CanVerify(now) {
return codersdk.CryptoKey{}, ErrKeyInvalid
}
return key, nil
return db2sdk.CryptoKey(key), nil
}

d.cacheMu.Lock()
defer d.cacheMu.Unlock()

key, ok = d.cache[sequence]
if ok {
return key, nil
return db2sdk.CryptoKey(key), nil
}

key, err := d.db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{
Feature: d.feature,
Sequence: sequence,
})
if xerrors.Is(err, sql.ErrNoRows) {
return database.CryptoKey{}, ErrKeyNotFound
return codersdk.CryptoKey{}, ErrKeyNotFound
}
if err != nil {
return database.CryptoKey{}, err
return codersdk.CryptoKey{}, err
}

if key.IsInvalid(now) {
return database.CryptoKey{}, ErrKeyInvalid
if !key.CanVerify(now) {
return codersdk.CryptoKey{}, ErrKeyInvalid
}

if key.IsActive(now) && key.Sequence > d.latestKey.Sequence {
// If this key is valid for signing then mark it as the latest key.
if key.CanSign(now) && key.Sequence > d.latestKey.Sequence {
d.latestKey = key
}

d.cache[sequence] = key

return key, nil
return db2sdk.CryptoKey(key), nil
}

func (d *DBCache) Latest(ctx context.Context) (database.CryptoKey, error) {
// Latest returns the latest valid key for signing. A valid key is one that is
// both past its start time and before its deletion time.
func (d *DBCache) Latest(ctx context.Context) (codersdk.CryptoKey, error) {
d.cacheMu.RLock()
latest := d.latestKey
d.cacheMu.RUnlock()

now := d.clock.Now().UTC()
if latest.IsActive(now) {
return latest, nil
if latest.CanSign(now) {
return checkKey(latest, now)
}

d.cacheMu.Lock()
defer d.cacheMu.Unlock()

if latest.IsActive(now) {
return latest, nil
if latest.CanSign(now) {
return checkKey(latest, now)
}

// Refetch all keys for this feature so we can find the latest valid key.
cache, latest, err := d.newCache(ctx)
if err != nil {
return database.CryptoKey{}, xerrors.Errorf("new cache: %w", err)
return codersdk.CryptoKey{}, xerrors.Errorf("new cache: %w", err)
}

if len(cache) == 0 {
return database.CryptoKey{}, ErrKeyNotFound
return codersdk.CryptoKey{}, ErrKeyNotFound
}

if !latest.IsActive(now) {
return database.CryptoKey{}, ErrKeyInvalid
if !latest.CanSign(now) {
return codersdk.CryptoKey{}, ErrKeyInvalid
}

d.cache, d.latestKey = cache, latest

return d.latestKey, nil
return checkKey(latest, now)
}

func (d *DBCache) refresh(ctx context.Context) {
Expand All @@ -154,30 +160,29 @@ func (d *DBCache) refresh(ctx context.Context) {
})
}

// newCache fetches all keys for the given feature and determines the latest key.
func (d *DBCache) newCache(ctx context.Context) (map[int32]database.CryptoKey, database.CryptoKey, error) {
now := d.clock.Now().UTC()
keys, err := d.db.GetCryptoKeysByFeature(ctx, d.feature)
if err != nil {
return nil, database.CryptoKey{}, xerrors.Errorf("get crypto keys by feature: %w", err)
}
cache := toMap(keys)
cache := make(map[int32]database.CryptoKey)
var latest database.CryptoKey
// Keys are returned in order from highest sequence to lowest.
for _, key := range keys {
if !key.IsActive(now) {
continue
cache[key.Sequence] = key
if key.CanSign(now) && key.Sequence > latest.Sequence {
latest = key
}
latest = key
break
}

return cache, latest, nil
}

func toMap(keys []database.CryptoKey) map[int32]database.CryptoKey {
m := make(map[int32]database.CryptoKey)
for _, key := range keys {
m[key.Sequence] = key
func checkKey(key database.CryptoKey, now time.Time) (codersdk.CryptoKey, error) {
if !key.CanVerify(now) {
return codersdk.CryptoKey{}, ErrKeyInvalid
}
return m

return db2sdk.CryptoKey(key), nil
}
15 changes: 8 additions & 7 deletions coderd/cryptokeys/dbkeycache_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"go.uber.org/mock/gomock"

"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
Expand Down Expand Up @@ -49,7 +50,7 @@ func Test_Version(t *testing.T) {

got, err := k.Version(ctx, 32)
require.NoError(t, err)
require.Equal(t, expectedKey, got)
require.Equal(t, db2sdk.CryptoKey(expectedKey), got)
})

t.Run("MissesCache", func(t *testing.T) {
Expand Down Expand Up @@ -86,8 +87,8 @@ func Test_Version(t *testing.T) {

got, err := k.Version(ctx, 33)
require.NoError(t, err)
require.Equal(t, expectedKey, got)
require.Equal(t, expectedKey, k.latestKey)
require.Equal(t, db2sdk.CryptoKey(expectedKey), got)
require.Equal(t, db2sdk.CryptoKey(expectedKey), db2sdk.CryptoKey(k.latestKey))
})

t.Run("InvalidCachedKey", func(t *testing.T) {
Expand Down Expand Up @@ -123,7 +124,7 @@ func Test_Version(t *testing.T) {
}

_, err := k.Version(ctx, 32)
require.ErrorIs(t, err, ErrKeyNotFound)
require.ErrorIs(t, err, ErrKeyInvalid)
})

t.Run("InvalidDBKey", func(t *testing.T) {
Expand Down Expand Up @@ -196,7 +197,7 @@ func Test_Latest(t *testing.T) {

got, err := k.Latest(ctx)
require.NoError(t, err)
require.Equal(t, latestKey, got)
require.Equal(t, db2sdk.CryptoKey(latestKey), got)
})

t.Run("InvalidCachedKey", func(t *testing.T) {
Expand Down Expand Up @@ -242,7 +243,7 @@ func Test_Latest(t *testing.T) {

got, err := k.Latest(ctx)
require.NoError(t, err)
require.Equal(t, latestKey, got)
require.Equal(t, db2sdk.CryptoKey(latestKey), got)
})

t.Run("UsesActiveKey", func(t *testing.T) {
Expand Down Expand Up @@ -286,7 +287,7 @@ func Test_Latest(t *testing.T) {

got, err := k.Latest(ctx)
require.NoError(t, err)
require.Equal(t, activeKey, got)
require.Equal(t, db2sdk.CryptoKey(activeKey), got)
})

t.Run("NoValidKeys", func(t *testing.T) {
Expand Down
30 changes: 22 additions & 8 deletions coderd/cryptokeys/dbkeycache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/coder/coder/v2/coderd/cryptokeys"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/testutil"
Expand Down Expand Up @@ -62,7 +63,7 @@ func TestDBKeyCache(t *testing.T) {

got, err := k.Version(ctx, key.Sequence)
require.NoError(t, err)
require.Equal(t, key, got)
require.Equal(t, db2sdk.CryptoKey(key), got)
})

t.Run("MissesCache", func(t *testing.T) {
Expand Down Expand Up @@ -100,7 +101,7 @@ func TestDBKeyCache(t *testing.T) {

got, err := k.Version(ctx, key.Sequence)
require.NoError(t, err)
require.Equal(t, key, got)
require.Equal(t, db2sdk.CryptoKey(key), got)
})
})

Expand Down Expand Up @@ -137,7 +138,7 @@ func TestDBKeyCache(t *testing.T) {

got, err := k.Latest(ctx)
require.NoError(t, err)
require.Equal(t, expectedKey, got)
require.Equal(t, db2sdk.CryptoKey(expectedKey), got)
})

t.Run("CacheRefreshes", func(t *testing.T) {
Expand Down Expand Up @@ -168,14 +169,24 @@ func TestDBKeyCache(t *testing.T) {
Valid: true,
},
})

wrongFeature := dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureOidcConvert,
Sequence: 30,
StartsAt: clock.Now().UTC(),
})

trap := clock.Trap().TickerFunc()
k, err := cryptokeys.NewDBCache(ctx, logger, db, database.CryptoKeyFeatureWorkspaceApps, cryptokeys.WithDBCacheClock(clock))
require.NoError(t, err)

// Should be able to fetch the expiring key since it's still valid.
got, err := k.Version(ctx, expiringKey.Sequence)
require.NoError(t, err)
require.Equal(t, expiringKey, got)
require.Equal(t, db2sdk.CryptoKey(expiringKey), got)

_, err = k.Version(ctx, wrongFeature.Sequence)
require.ErrorIs(t, err, cryptokeys.ErrKeyNotFound)

newLatest := dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Expand All @@ -190,7 +201,7 @@ func TestDBKeyCache(t *testing.T) {
// The latest key should not be the one we just generated.
got, err = k.Latest(ctx)
require.NoError(t, err)
require.Equal(t, latest, got)
require.Equal(t, db2sdk.CryptoKey(latest), got)

// Wait for the ticker to fire and the cache to refresh.
trap.MustWait(ctx).Release()
Expand All @@ -200,11 +211,14 @@ func TestDBKeyCache(t *testing.T) {
// The latest key should be the one we just generated.
got, err = k.Latest(ctx)
require.NoError(t, err)
require.Equal(t, newLatest, got)

// The expiring key should be gone.
require.Equal(t, db2sdk.CryptoKey(newLatest), got)

// The expiring key should be invalid.
_, err = k.Version(ctx, expiringKey.Sequence)
require.ErrorIs(t, err, cryptokeys.ErrKeyInvalid)

// Sanity check that the wrong feature is still not found.
_, err = k.Version(ctx, wrongFeature.Sequence)
require.ErrorIs(t, err, cryptokeys.ErrKeyNotFound)
})
}
6 changes: 3 additions & 3 deletions coderd/cryptokeys/keycache.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (

"golang.org/x/xerrors"

"github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk"
"github.com/coder/coder/v2/codersdk"
)

var ErrKeyNotFound = xerrors.New("key not found")
Expand All @@ -14,6 +14,6 @@ var ErrKeyInvalid = xerrors.New("key is invalid for use")

// Keycache provides an abstraction for fetching signing keys.
type Keycache interface {
Latest(ctx context.Context) (wsproxysdk.CryptoKey, error)
Version(ctx context.Context, sequence int32) (wsproxysdk.CryptoKey, error)
Latest(ctx context.Context) (codersdk.CryptoKey, error)
Version(ctx context.Context, sequence int32) (codersdk.CryptoKey, error)
}
14 changes: 14 additions & 0 deletions coderd/database/db2sdk/db2sdk.go
Original file line number Diff line number Diff line change
Expand Up @@ -659,3 +659,17 @@ func Organization(organization database.Organization) codersdk.Organization {
IsDefault: organization.IsDefault,
}
}

func CryptoKeys(keys []database.CryptoKey) []codersdk.CryptoKey {
return List(keys, CryptoKey)
}

func CryptoKey(key database.CryptoKey) codersdk.CryptoKey {
return codersdk.CryptoKey{
Feature: codersdk.CryptoKeyFeature(key.Feature),
Sequence: key.Sequence,
StartsAt: key.StartsAt.UTC(),
DeletesAt: key.DeletesAt.Time.UTC(),
Secret: key.Secret.String,
}
}
12 changes: 6 additions & 6 deletions coderd/database/modelmethods.go
Original file line number Diff line number Diff line change
Expand Up @@ -457,15 +457,15 @@ func (k CryptoKey) DecodeString() ([]byte, error) {
return hex.DecodeString(k.Secret.String)
}

func (k CryptoKey) IsActive(now time.Time) bool {
func (k CryptoKey) CanSign(now time.Time) bool {
now = now.UTC()
isAfterStart := !k.StartsAt.IsZero() && !now.Before(k.StartsAt.UTC())
return isAfterStart && !k.IsInvalid(now)
return isAfterStart && k.CanVerify(now)
}

func (k CryptoKey) IsInvalid(now time.Time) bool {
func (k CryptoKey) CanVerify(now time.Time) bool {
now = now.UTC()
isDeleted := !k.Secret.Valid
isPastDeletion := k.DeletesAt.Valid && !now.Before(k.DeletesAt.Time.UTC())
return isDeleted || isPastDeletion
hasSecret := k.Secret.Valid
isBeforeDeletion := !k.DeletesAt.Valid || now.Before(k.DeletesAt.Time.UTC())
return hasSecret && isBeforeDeletion
}
Loading
Loading