Skip to content

feat: add jwt pkg #14928

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 25 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
b5d939e
feat: add jwt pkg
sreya Oct 1, 2024
6025c7b
update make gen
sreya Oct 2, 2024
8b235be
Refactor JWT package to modularize key functions
sreya Oct 2, 2024
843de38
Remove unused JWT test file from repository
sreya Oct 2, 2024
099544f
Refactor JWT key functions and add tests
sreya Oct 2, 2024
acc4db3
Rename VerifyFn to ParseFn in JWT tests
sreya Oct 2, 2024
b4973a8
Remove unused JWE test file
sreya Oct 2, 2024
f7d7c95
Refactor JWT test structs to use public field names
sreya Oct 2, 2024
3ba8ad3
Refactor JWT to use new crypto key management system
sreya Oct 3, 2024
73c902c
Refactor JWT package for improved modularity and clarity
sreya Oct 3, 2024
e348a7a
mv dir
sreya Oct 3, 2024
c7489b4
update references
sreya Oct 3, 2024
d890ea2
refactor interfaces
sreya Oct 3, 2024
67ccd5c
refactor dbkeycache
sreya Oct 3, 2024
1a81c7a
Refactor JWT utility options for flexibility
sreya Oct 3, 2024
e529c4a
Enhance key generation and JWT error messages
sreya Oct 3, 2024
437e587
Update cryptographic key length requirements
sreya Oct 3, 2024
54214e2
Refactor key provider interfaces in JWT utilities
sreya Oct 3, 2024
93603a2
Refactor dbCache to remove feature validation
sreya Oct 3, 2024
e654a65
Refactor cryptokeys and jwtutils interfaces and logic
sreya Oct 3, 2024
0efabfd
Remove unused test code and mock cleanup
sreya Oct 3, 2024
e065356
Remove cryptokeys keycachemock from Makefile
sreya Oct 3, 2024
938bdda
Add feature validation to dbCache key methods
sreya Oct 3, 2024
48b1b3b
fmt
sreya Oct 3, 2024
1dd2205
Add initialization comment for db key cache timer
sreya Oct 4, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,8 @@ gen/mark-fresh:
tailnet/tailnettest/coordinatormock.go \
tailnet/tailnettest/coordinateemock.go \
tailnet/tailnettest/multiagentmock.go \
"
"

for file in $$files; do
echo "$$file"
if [ ! -f "$$file" ]; then
Expand Down
144 changes: 110 additions & 34 deletions coderd/cryptokeys/dbkeycache.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,22 @@ package cryptokeys

import (
"context"
"strconv"
"sync"
"time"

"golang.org/x/xerrors"

"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/quartz"
)

// never represents the maximum value for a time.Duration.
const never = 1<<63 - 1

// DBCache implements Keycache for callers with access to the database.
type DBCache struct {
// dbCache implements Keycache for callers with access to the database.
type dbCache struct {
db database.Store
feature database.CryptoKeyFeature
logger slog.Logger
Expand All @@ -34,18 +33,34 @@ type DBCache struct {
closed bool
}

type DBCacheOption func(*DBCache)
type DBCacheOption func(*dbCache)

func WithDBCacheClock(clock quartz.Clock) DBCacheOption {
return func(d *DBCache) {
return func(d *dbCache) {
d.clock = clock
}
}

// NewDBCache creates a new DBCache. Close should be called to
// NewSigningCache creates a new DBCache. Close should be called to
// release resources associated with its internal timer.
func NewDBCache(logger slog.Logger, db database.Store, feature database.CryptoKeyFeature, opts ...func(*DBCache)) *DBCache {
d := &DBCache{
func NewSigningCache(logger slog.Logger, db database.Store, feature database.CryptoKeyFeature, opts ...func(*dbCache)) (SigningKeycache, error) {
if !isSigningKeyFeature(feature) {
return nil, ErrInvalidFeature
}

return newDBCache(logger, db, feature, opts...), nil
}

func NewEncryptionCache(logger slog.Logger, db database.Store, feature database.CryptoKeyFeature, opts ...func(*dbCache)) (EncryptionKeycache, error) {
if !isEncryptionKeyFeature(feature) {
return nil, ErrInvalidFeature
}

return newDBCache(logger, db, feature, opts...), nil
}

func newDBCache(logger slog.Logger, db database.Store, feature database.CryptoKeyFeature, opts ...func(*dbCache)) *dbCache {
d := &dbCache{
db: db,
feature: feature,
clock: quartz.NewReal(),
Expand All @@ -56,23 +71,61 @@ func NewDBCache(logger slog.Logger, db database.Store, feature database.CryptoKe
opt(d)
}

// Initialize the timer. This will get properly initialized the first time we fetch.
d.timer = d.clock.AfterFunc(never, d.clear)

return d
}

// Verifying returns the CryptoKey with the given sequence number, provided that
func (d *dbCache) EncryptingKey(ctx context.Context) (id string, key interface{}, err error) {
if !isEncryptionKeyFeature(d.feature) {
return "", nil, ErrInvalidFeature
}

return d.latest(ctx)
}

func (d *dbCache) DecryptingKey(ctx context.Context, id string) (key interface{}, err error) {
if !isEncryptionKeyFeature(d.feature) {
return nil, ErrInvalidFeature
}

return d.sequence(ctx, id)
}

func (d *dbCache) SigningKey(ctx context.Context) (id string, key interface{}, err error) {
if !isSigningKeyFeature(d.feature) {
return "", nil, ErrInvalidFeature
}

return d.latest(ctx)
}

func (d *dbCache) VerifyingKey(ctx context.Context, id string) (key interface{}, err error) {
if !isSigningKeyFeature(d.feature) {
return nil, ErrInvalidFeature
}

return d.sequence(ctx, id)
}

// sequence returns the CryptoKey with the given sequence number, provided that
// it is neither deleted nor has breached its deletion date. It should only be
// used for verifying or decrypting payloads. To sign/encrypt call Signing.
func (d *DBCache) Verifying(ctx context.Context, sequence int32) (codersdk.CryptoKey, error) {
func (d *dbCache) sequence(ctx context.Context, id string) (interface{}, error) {
sequence, err := strconv.ParseInt(id, 10, 32)
if err != nil {
return nil, xerrors.Errorf("expecting sequence number got %q: %w", id, err)
}

d.keysMu.RLock()
if d.closed {
d.keysMu.RUnlock()
return codersdk.CryptoKey{}, ErrClosed
return nil, ErrClosed
}

now := d.clock.Now()
key, ok := d.keys[sequence]
key, ok := d.keys[int32(sequence)]
d.keysMu.RUnlock()
if ok {
return checkKey(key, now)
Expand All @@ -82,67 +135,67 @@ func (d *DBCache) Verifying(ctx context.Context, sequence int32) (codersdk.Crypt
defer d.keysMu.Unlock()

if d.closed {
return codersdk.CryptoKey{}, ErrClosed
return nil, ErrClosed
}

key, ok = d.keys[sequence]
key, ok = d.keys[int32(sequence)]
if ok {
return checkKey(key, now)
}

err := d.fetch(ctx)
err = d.fetch(ctx)
if err != nil {
return codersdk.CryptoKey{}, xerrors.Errorf("fetch: %w", err)
return nil, xerrors.Errorf("fetch: %w", err)
}

key, ok = d.keys[sequence]
key, ok = d.keys[int32(sequence)]
if !ok {
return codersdk.CryptoKey{}, ErrKeyNotFound
return nil, ErrKeyNotFound
}

return checkKey(key, now)
}

// Signing returns the latest valid key for signing. A valid key is one that is
// latest returns the latest valid key for signing. A valid key is one that is
// both past its start time and before its deletion time.
func (d *DBCache) Signing(ctx context.Context) (codersdk.CryptoKey, error) {
func (d *dbCache) latest(ctx context.Context) (string, interface{}, error) {
d.keysMu.RLock()

if d.closed {
d.keysMu.RUnlock()
return codersdk.CryptoKey{}, ErrClosed
return "", nil, ErrClosed
}

latest := d.latestKey
d.keysMu.RUnlock()

now := d.clock.Now()
if latest.CanSign(now) {
return db2sdk.CryptoKey(latest), nil
return idSecret(latest)
}

d.keysMu.Lock()
defer d.keysMu.Unlock()

if d.closed {
return codersdk.CryptoKey{}, ErrClosed
return "", nil, ErrClosed
}

if d.latestKey.CanSign(now) {
return db2sdk.CryptoKey(d.latestKey), nil
return idSecret(d.latestKey)
}

// Refetch all keys for this feature so we can find the latest valid key.
err := d.fetch(ctx)
if err != nil {
return codersdk.CryptoKey{}, xerrors.Errorf("fetch: %w", err)
return "", nil, xerrors.Errorf("fetch: %w", err)
}

return db2sdk.CryptoKey(d.latestKey), nil
return idSecret(d.latestKey)
}

// clear invalidates the cache. This forces the subsequent call to fetch fresh keys.
func (d *DBCache) clear() {
func (d *dbCache) clear() {
now := d.clock.Now("DBCache", "clear")
d.keysMu.Lock()
defer d.keysMu.Unlock()
Expand All @@ -158,7 +211,7 @@ func (d *DBCache) clear() {

// fetch fetches all keys for the given feature and determines the latest key.
// It must be called while holding the keysMu lock.
func (d *DBCache) fetch(ctx context.Context) error {
func (d *dbCache) fetch(ctx context.Context) error {
keys, err := d.db.GetCryptoKeysByFeature(ctx, d.feature)
if err != nil {
return xerrors.Errorf("get crypto keys by feature: %w", err)
Expand Down Expand Up @@ -189,22 +242,45 @@ func (d *DBCache) fetch(ctx context.Context) error {
return nil
}

func checkKey(key database.CryptoKey, now time.Time) (codersdk.CryptoKey, error) {
func checkKey(key database.CryptoKey, now time.Time) (interface{}, error) {
if !key.CanVerify(now) {
return codersdk.CryptoKey{}, ErrKeyInvalid
return nil, ErrKeyInvalid
}

return db2sdk.CryptoKey(key), nil
return key.DecodeString()
}

func (d *DBCache) Close() {
func (d *dbCache) Close() error {
d.keysMu.Lock()
defer d.keysMu.Unlock()

if d.closed {
return
return nil
}

d.timer.Stop()
d.closed = true
return nil
}

func isEncryptionKeyFeature(feature database.CryptoKeyFeature) bool {
return feature == database.CryptoKeyFeatureWorkspaceApps
}

func isSigningKeyFeature(feature database.CryptoKeyFeature) bool {
switch feature {
case database.CryptoKeyFeatureTailnetResume, database.CryptoKeyFeatureOidcConvert:
return true
default:
return false
}
}

func idSecret(k database.CryptoKey) (string, interface{}, error) {
key, err := k.DecodeString()
if err != nil {
return "", nil, xerrors.Errorf("decode key: %w", err)
}

return strconv.FormatInt(int64(k.Sequence), 10), key, nil
}
Loading
Loading