From b5d939e41e3695667a6ba18dc538ddcbb46e509d Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Tue, 1 Oct 2024 23:34:02 +0000 Subject: [PATCH 01/25] feat: add jwt pkg --- coderd/cryptokeys/doc.go | 4 + coderd/cryptokeys/mock_keycache.go | 85 ++++++ coderd/jwt/jwe.go | 100 +++++++ coderd/jwt/jwe_test.go | 1 + coderd/jwt/jws.go | 163 ++++++++++++ coderd/jwt/jwt_test.go | 406 +++++++++++++++++++++++++++++ go.mod | 2 +- 7 files changed, 760 insertions(+), 1 deletion(-) create mode 100644 coderd/cryptokeys/doc.go create mode 100644 coderd/cryptokeys/mock_keycache.go create mode 100644 coderd/jwt/jwe.go create mode 100644 coderd/jwt/jwe_test.go create mode 100644 coderd/jwt/jws.go create mode 100644 coderd/jwt/jwt_test.go diff --git a/coderd/cryptokeys/doc.go b/coderd/cryptokeys/doc.go new file mode 100644 index 0000000000000..efe2968d9cac7 --- /dev/null +++ b/coderd/cryptokeys/doc.go @@ -0,0 +1,4 @@ +// Package cryptokeys provides an abstraction for fetching internally used cryptographic keys mainly for JWT signing and verification. +package cryptokeys + +//go:generate mockgen -destination mock_keycache.go -package cryptokeys . Keycache diff --git a/coderd/cryptokeys/mock_keycache.go b/coderd/cryptokeys/mock_keycache.go new file mode 100644 index 0000000000000..e365fa8bc803b --- /dev/null +++ b/coderd/cryptokeys/mock_keycache.go @@ -0,0 +1,85 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/coder/coder/v2/coderd/cryptokeys (interfaces: Keycache) +// +// Generated by this command: +// +// mockgen -destination mock_keycache.go -package cryptokeys . Keycache +// + +// Package cryptokeys is a generated GoMock package. +package cryptokeys + +import ( + context "context" + reflect "reflect" + + codersdk "github.com/coder/coder/v2/codersdk" + gomock "go.uber.org/mock/gomock" +) + +// MockKeycache is a mock of Keycache interface. +type MockKeycache struct { + ctrl *gomock.Controller + recorder *MockKeycacheMockRecorder +} + +// MockKeycacheMockRecorder is the mock recorder for MockKeycache. +type MockKeycacheMockRecorder struct { + mock *MockKeycache +} + +// NewMockKeycache creates a new mock instance. +func NewMockKeycache(ctrl *gomock.Controller) *MockKeycache { + mock := &MockKeycache{ctrl: ctrl} + mock.recorder = &MockKeycacheMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockKeycache) EXPECT() *MockKeycacheMockRecorder { + return m.recorder +} + +// Feature mocks base method. +func (m *MockKeycache) Feature() codersdk.CryptoKeyFeature { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Feature") + ret0, _ := ret[0].(codersdk.CryptoKeyFeature) + return ret0 +} + +// Feature indicates an expected call of Feature. +func (mr *MockKeycacheMockRecorder) Feature() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Feature", reflect.TypeOf((*MockKeycache)(nil).Feature)) +} + +// Signing mocks base method. +func (m *MockKeycache) Signing(arg0 context.Context) (codersdk.CryptoKey, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Signing", arg0) + ret0, _ := ret[0].(codersdk.CryptoKey) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Signing indicates an expected call of Signing. +func (mr *MockKeycacheMockRecorder) Signing(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Signing", reflect.TypeOf((*MockKeycache)(nil).Signing), arg0) +} + +// Verifying mocks base method. +func (m *MockKeycache) Verifying(arg0 context.Context, arg1 int32) (codersdk.CryptoKey, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Verifying", arg0, arg1) + ret0, _ := ret[0].(codersdk.CryptoKey) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Verifying indicates an expected call of Verifying. +func (mr *MockKeycacheMockRecorder) Verifying(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Verifying", reflect.TypeOf((*MockKeycache)(nil).Verifying), arg0, arg1) +} diff --git a/coderd/jwt/jwe.go b/coderd/jwt/jwe.go new file mode 100644 index 0000000000000..fc9c91c492cb8 --- /dev/null +++ b/coderd/jwt/jwe.go @@ -0,0 +1,100 @@ +package jwt + +import ( + "encoding/base64" + "encoding/json" + "time" + + "github.com/go-jose/go-jose/v4" + jjwt "github.com/go-jose/go-jose/v4/jwt" + "golang.org/x/xerrors" +) + +const ( + encryptKeyAlgo = jose.A256GCMKW + encryptContentAlgo = jose.A256GCM +) + +func Encrypt(claims Claims, keyFn SecuringKeyFn) (string, error) { + kid, key, err := keyFn() + if err != nil { + return "", xerrors.Errorf("get key: %w", err) + } + + encrypter, err := jose.NewEncrypter( + encryptContentAlgo, + jose.Recipient{ + Algorithm: encryptKeyAlgo, + Key: key, + }, + &jose.EncrypterOptions{ + Compression: jose.DEFLATE, + ExtraHeaders: map[jose.HeaderKey]interface{}{ + keyIDHeaderKey: kid, + }, + }, + ) + if err != nil { + return "", xerrors.Errorf("initialize encrypter: %w", err) + } + + payload, err := json.Marshal(claims) + if err != nil { + return "", xerrors.Errorf("marshal payload: %w", err) + } + + encrypted, err := encrypter.Encrypt(payload) + if err != nil { + return "", xerrors.Errorf("encrypt: %w", err) + } + + serialized := []byte(encrypted.FullSerialize()) + return base64.RawURLEncoding.EncodeToString(serialized), nil +} + +func Decrypt(token string, claims Claims, keyFn KeyFunc, opts ...func(*ParseOptions)) error { + options := ParseOptions{ + RegisteredClaims: jjwt.Expected{ + Time: time.Now(), + }, + KeyAlgorithm: encryptKeyAlgo, + ContentEncryptionAlgorithm: encryptContentAlgo, + } + + for _, opt := range opts { + opt(&options) + } + + encrypted, err := base64.RawURLEncoding.DecodeString(token) + if err != nil { + return xerrors.Errorf("decode: %w", err) + } + + object, err := jose.ParseEncrypted(string(encrypted), + []jose.KeyAlgorithm{options.KeyAlgorithm}, + []jose.ContentEncryption{options.ContentEncryptionAlgorithm}, + ) + if err != nil { + return xerrors.Errorf("parse encrypted API key: %w", err) + } + + if object.Header.Algorithm != string(encryptKeyAlgo) { + return xerrors.Errorf("expected API key encryption algorithm to be %q, got %q", encryptKeyAlgo, object.Header.Algorithm) + } + + key, err := keyFn(object.Header) + if err != nil { + return xerrors.Errorf("get key: %w", err) + } + + decrypted, err := object.Decrypt(key) + if err != nil { + return xerrors.Errorf("decrypt: %w", err) + } + + if err := json.Unmarshal(decrypted, &claims); err != nil { + return xerrors.Errorf("unmarshal: %w", err) + } + + return claims.Validate(options.RegisteredClaims) +} diff --git a/coderd/jwt/jwe_test.go b/coderd/jwt/jwe_test.go new file mode 100644 index 0000000000000..30619ec668540 --- /dev/null +++ b/coderd/jwt/jwe_test.go @@ -0,0 +1 @@ +package jwt_test diff --git a/coderd/jwt/jws.go b/coderd/jwt/jws.go new file mode 100644 index 0000000000000..2c39ae72586aa --- /dev/null +++ b/coderd/jwt/jws.go @@ -0,0 +1,163 @@ +package jwt + +import ( + "context" + "encoding/hex" + "encoding/json" + "strconv" + "time" + + "github.com/go-jose/go-jose/v4" + jjwt "github.com/go-jose/go-jose/v4/jwt" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/cryptokeys" +) + +type Claims interface { + Validate(jjwt.Expected) error +} + +const ( + defaultSigningAlgo = jose.HS512 + featureHeaderKey = "feat" + keyIDHeaderKey = "kid" +) + +type SecuringKeyFn func() (id string, key interface{}, err error) + +func KeycacheSecure(ctx context.Context, keys cryptokeys.Keycache) SecuringKeyFn { + return func() (id string, key interface{}, err error) { + signing, err := keys.Signing(ctx) + if err != nil { + return "", nil, xerrors.Errorf("get signing key: %w", err) + } + + decoded, err := hex.DecodeString(signing.Secret) + if err != nil { + return "", nil, xerrors.Errorf("decode signing key: %w", err) + } + + return strconv.FormatInt(int64(signing.Sequence), 10), decoded, nil + } +} + +func Sign(claims Claims, keyFn SecuringKeyFn) (string, error) { + kid, key, err := keyFn() + if err != nil { + return "", xerrors.Errorf("get key: %w", err) + } + + signer, err := jose.NewSigner(jose.SigningKey{ + Algorithm: defaultSigningAlgo, + Key: key, + }, &jose.SignerOptions{ + ExtraHeaders: map[jose.HeaderKey]interface{}{ + keyIDHeaderKey: kid, + }, + }) + if err != nil { + return "", xerrors.Errorf("new signer: %w", err) + } + + payload, err := json.Marshal(claims) + if err != nil { + return "", xerrors.Errorf("marshal claims: %w", err) + } + + signed, err := signer.Sign(payload) + if err != nil { + return "", xerrors.Errorf("sign payload: %w", err) + } + + compact, err := signed.CompactSerialize() + if err != nil { + return "", xerrors.Errorf("compact serialize: %w", err) + } + + return compact, nil +} + +type KeyFunc func(jose.Header) (interface{}, error) + +func KeycacheVerify(ctx context.Context, keys cryptokeys.Keycache) KeyFunc { + return func(header jose.Header) (interface{}, error) { + sequenceStr := header.KeyID + if sequenceStr == "" { + return nil, xerrors.Errorf("expected %q header to be a string", keyIDHeaderKey) + } + + sequence, err := strconv.ParseInt(sequenceStr, 10, 32) + if err != nil { + return nil, xerrors.Errorf("parse sequence: %w", err) + } + + key, err := keys.Verifying(ctx, int32(sequence)) + if err != nil { + return nil, xerrors.Errorf("version: %w", err) + } + + decoded, err := hex.DecodeString(key.Secret) + if err != nil { + return nil, xerrors.Errorf("decode key: %w", err) + } + + return decoded, nil + } +} + +type ParseOptions struct { + RegisteredClaims jjwt.Expected + + // The following are only used for JWSs. + SignatureAlgorithm jose.SignatureAlgorithm + + // The following should only be used for JWEs. + KeyAlgorithm jose.KeyAlgorithm + ContentEncryptionAlgorithm jose.ContentEncryption +} + +func Verify(token string, claims Claims, keyFn KeyFunc, opts ...func(*ParseOptions)) error { + options := ParseOptions{ + RegisteredClaims: jjwt.Expected{ + Time: time.Now(), + }, + SignatureAlgorithm: defaultSigningAlgo, + } + + for _, opt := range opts { + opt(&options) + } + + object, err := jose.ParseSigned(token, []jose.SignatureAlgorithm{options.SignatureAlgorithm}) + if err != nil { + return xerrors.Errorf("parse JWS: %w", err) + } + + if len(object.Signatures) != 1 { + return xerrors.New("expected 1 signature") + } + + signature := object.Signatures[0] + + if signature.Header.Algorithm != string(defaultSigningAlgo) { + return xerrors.Errorf("expected token signing algorithm to be %q, got %q", defaultSigningAlgo, object.Signatures[0].Header.Algorithm) + } + + key, err := keyFn(signature.Header) + if err != nil { + return xerrors.Errorf("get key: %w", err) + } + + payload, err := object.Verify(key) + if err != nil { + return xerrors.Errorf("verify payload: %w", err) + } + + err = json.Unmarshal(payload, &claims) + if err != nil { + return xerrors.Errorf("unmarshal payload: %w", err) + } + + return claims.Validate(options.RegisteredClaims) +} diff --git a/coderd/jwt/jwt_test.go b/coderd/jwt/jwt_test.go new file mode 100644 index 0000000000000..7fa464e815c0d --- /dev/null +++ b/coderd/jwt/jwt_test.go @@ -0,0 +1,406 @@ +package jwt_test + +import ( + "crypto/rand" + "encoding/hex" + "testing" + "time" + + "github.com/go-jose/go-jose/v4" + jjwt "github.com/go-jose/go-jose/v4/jwt" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/cryptokeys" + "github.com/coder/coder/v2/coderd/jwt" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/testutil" +) + +func TestJWT(t *testing.T) { + t.Parallel() + + type tokenType struct { + name string + SecureFn func(jwt.Claims, jwt.SecuringKeyFn) (string, error) + VerifyFn func(string, jwt.Claims, jwt.KeyFunc, ...func(*jwt.ParseOptions)) error + KeySize int + } + + types := []tokenType{ + { + name: "JWE", + SecureFn: jwt.Encrypt, + VerifyFn: jwt.Decrypt, + KeySize: 32, + }, + { + name: "JWS", + SecureFn: jwt.Sign, + VerifyFn: jwt.Verify, + KeySize: 64, + }, + } + + for _, tt := range types { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + t.Run("Basic", func(t *testing.T) { + t.Parallel() + + id := uuid.New().String() + key := generateSecret(t, tt.KeySize) + + claims := jjwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jjwt.Audience{"coder"}, + Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jjwt.NewNumericDate(time.Now()), + NotBefore: jjwt.NewNumericDate(time.Now()), + } + + token, err := tt.SecureFn(claims, securingKeyFn(id, key)) + require.NoError(t, err) + + var actual testClaims + err = tt.VerifyFn(token, &actual, verifyingKeyFn(id, key)) + require.NoError(t, err) + }) + + t.Run("Keycache", func(t *testing.T) { + t.Parallel() + + var ( + ctx = testutil.Context(t, testutil.WaitShort) + ctrl = gomock.NewController(t) + keycache = cryptokeys.NewMockKeycache(ctrl) + now = time.Now() + ) + + key := generateCryptoKey(t, 1234567890, now, tt.KeySize) + + keycache.EXPECT().Signing(gomock.Any()).Return(key, nil) + keycache.EXPECT().Verifying(gomock.Any(), key.Sequence).Return(key, nil) + + claims := jjwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jjwt.Audience{"coder"}, + Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jjwt.NewNumericDate(time.Now()), + NotBefore: jjwt.NewNumericDate(time.Now()), + } + + token, err := tt.SecureFn(claims, jwt.KeycacheSecure(ctx, keycache)) + require.NoError(t, err) + + var actual testClaims + err = tt.VerifyFn(token, &actual, jwt.KeycacheVerify(ctx, keycache)) + require.NoError(t, err) + }) + + t.Run("WrongIssuer", func(t *testing.T) { + t.Parallel() + + id := uuid.New().String() + key := generateSecret(t, tt.KeySize) + + claims := jjwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jjwt.Audience{"coder"}, + Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jjwt.NewNumericDate(time.Now()), + NotBefore: jjwt.NewNumericDate(time.Now()), + } + + token, err := tt.SecureFn(claims, securingKeyFn(id, key)) + require.NoError(t, err) + + var actual testClaims + err = tt.VerifyFn(token, &actual, verifyingKeyFn(id, key), withExpected(jjwt.Expected{ + Issuer: "coder2", + })) + require.ErrorIs(t, err, jjwt.ErrInvalidIssuer) + }) + + t.Run("WrongSubject", func(t *testing.T) { + t.Parallel() + + id := uuid.New().String() + key := generateSecret(t, tt.KeySize) + + claims := jjwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jjwt.Audience{"coder"}, + Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jjwt.NewNumericDate(time.Now()), + NotBefore: jjwt.NewNumericDate(time.Now()), + } + + token, err := tt.SecureFn(claims, securingKeyFn(id, key)) + require.NoError(t, err) + + var actual testClaims + err = tt.VerifyFn(token, &actual, verifyingKeyFn(id, key), withExpected(jjwt.Expected{ + Subject: "user2@coder.com", + })) + require.ErrorIs(t, err, jjwt.ErrInvalidSubject) + }) + + t.Run("WrongAudience", func(t *testing.T) { + t.Parallel() + + id := uuid.New().String() + key := generateSecret(t, tt.KeySize) + + claims := jjwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jjwt.Audience{"coder"}, + Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jjwt.NewNumericDate(time.Now()), + NotBefore: jjwt.NewNumericDate(time.Now()), + } + + token, err := tt.SecureFn(claims, securingKeyFn(id, key)) + require.NoError(t, err) + + var actual testClaims + err = tt.VerifyFn(token, &actual, verifyingKeyFn(id, key), withExpected(jjwt.Expected{ + AnyAudience: jjwt.Audience{"coder2"}, + })) + require.ErrorIs(t, err, jjwt.ErrInvalidAudience) + }) + + t.Run("Expired", func(t *testing.T) { + t.Parallel() + + id := uuid.New().String() + key := generateSecret(t, tt.KeySize) + + claims := jjwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jjwt.Audience{"coder"}, + Expiry: jjwt.NewNumericDate(time.Now().Add(time.Minute)), + IssuedAt: jjwt.NewNumericDate(time.Now()), + NotBefore: jjwt.NewNumericDate(time.Now()), + } + + token, err := tt.SecureFn(claims, securingKeyFn(id, key)) + require.NoError(t, err) + + var actual testClaims + err = tt.VerifyFn(token, &actual, verifyingKeyFn(id, key), withExpected(jjwt.Expected{ + Time: time.Now().Add(time.Minute * 3), + })) + require.ErrorIs(t, err, jjwt.ErrExpired) + }) + + t.Run("IssuedInFuture", func(t *testing.T) { + t.Parallel() + + id := uuid.New().String() + key := generateSecret(t, tt.KeySize) + + claims := jjwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jjwt.Audience{"coder"}, + Expiry: jjwt.NewNumericDate(time.Now().Add(time.Minute)), + IssuedAt: jjwt.NewNumericDate(time.Now()), + } + + token, err := tt.SecureFn(claims, securingKeyFn(id, key)) + require.NoError(t, err) + + var actual testClaims + err = tt.VerifyFn(token, &actual, verifyingKeyFn(id, key), withExpected(jjwt.Expected{ + Time: time.Now().Add(-time.Minute * 3), + })) + require.ErrorIs(t, err, jjwt.ErrIssuedInTheFuture) + }) + + t.Run("IsBefore", func(t *testing.T) { + t.Parallel() + + id := uuid.New().String() + key := generateSecret(t, tt.KeySize) + + claims := jjwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jjwt.Audience{"coder"}, + Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jjwt.NewNumericDate(time.Now()), + NotBefore: jjwt.NewNumericDate(time.Now().Add(time.Minute * 5)), + } + + token, err := tt.SecureFn(claims, securingKeyFn(id, key)) + require.NoError(t, err) + + var actual testClaims + err = tt.VerifyFn(token, &actual, verifyingKeyFn(id, key), withExpected(jjwt.Expected{ + Time: time.Now().Add(time.Minute * 3), + })) + require.ErrorIs(t, err, jjwt.ErrNotValidYet) + }) + + t.Run("WrongSignatureAlgorithm", func(t *testing.T) { + t.Parallel() + + if tt.name == "JWE" { + t.Skip("JWE does not support this") + } + + id := uuid.New().String() + key := generateSecret(t, tt.KeySize) + + claims := jjwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jjwt.Audience{"coder"}, + Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jjwt.NewNumericDate(time.Now()), + NotBefore: jjwt.NewNumericDate(time.Now().Add(time.Minute * 5)), + } + + token, err := tt.SecureFn(claims, securingKeyFn(id, key)) + require.NoError(t, err) + + var actual testClaims + err = tt.VerifyFn(token, &actual, verifyingKeyFn(id, key), withSignatureAlgorithm(jose.HS256)) + require.Error(t, err) + }) + + t.Run("WrongKeyAlgorithm", func(t *testing.T) { + t.Parallel() + + if tt.name == "JWS" { + t.Skip("JWS does not support this") + } + + id := uuid.New().String() + key := generateSecret(t, tt.KeySize) + + claims := jjwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jjwt.Audience{"coder"}, + Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jjwt.NewNumericDate(time.Now()), + NotBefore: jjwt.NewNumericDate(time.Now().Add(time.Minute * 5)), + } + + token, err := tt.SecureFn(claims, securingKeyFn(id, key)) + require.NoError(t, err) + + var actual testClaims + err = tt.VerifyFn(token, &actual, verifyingKeyFn(id, key), withKeyAlgorithm(jose.A128GCMKW)) + require.Error(t, err) + }) + + t.Run("WrongContentyEncryption", func(t *testing.T) { + t.Parallel() + + if tt.name == "JWS" { + t.Skip("JWS does not support this") + } + + id := uuid.New().String() + key := generateSecret(t, tt.KeySize) + + claims := jjwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jjwt.Audience{"coder"}, + Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jjwt.NewNumericDate(time.Now()), + NotBefore: jjwt.NewNumericDate(time.Now().Add(time.Minute * 5)), + } + + token, err := tt.SecureFn(claims, securingKeyFn(id, key)) + require.NoError(t, err) + + var actual testClaims + err = tt.VerifyFn(token, &actual, verifyingKeyFn(id, key), withContentEncryptionAlgorithm(jose.A128GCM)) + require.Error(t, err) + + }) + }) + } +} + +func generateCryptoKey(t *testing.T, seq int32, now time.Time, keySize int) codersdk.CryptoKey { + t.Helper() + + secret := generateSecret(t, keySize) + + return codersdk.CryptoKey{ + Feature: codersdk.CryptoKeyFeatureTailnetResume, + Secret: hex.EncodeToString(secret), + Sequence: seq, + StartsAt: now, + } +} + +func generateSecret(t *testing.T, keySize int) []byte { + t.Helper() + + b := make([]byte, keySize) + _, err := rand.Read(b) + require.NoError(t, err) + return b +} + +type testClaims struct { + MyClaim string `json:"my_claim"` + jjwt.Claims +} + +func securingKeyFn(id string, key []byte) jwt.SecuringKeyFn { + return func() (string, interface{}, error) { + return id, key, nil + } +} + +func verifyingKeyFn(id string, key []byte) jwt.KeyFunc { + return func(header jose.Header) (interface{}, error) { + if header.KeyID != id { + return nil, xerrors.Errorf("expected key ID %q, got %q", id, header.KeyID) + } + return key, nil + } +} + +func withExpected(e jjwt.Expected) func(*jwt.ParseOptions) { + return func(opts *jwt.ParseOptions) { + opts.RegisteredClaims = e + } +} + +func withSignatureAlgorithm(alg jose.SignatureAlgorithm) func(*jwt.ParseOptions) { + return func(opts *jwt.ParseOptions) { + opts.SignatureAlgorithm = alg + } +} + +func withKeyAlgorithm(alg jose.KeyAlgorithm) func(*jwt.ParseOptions) { + return func(opts *jwt.ParseOptions) { + opts.KeyAlgorithm = alg + } +} + +func withContentEncryptionAlgorithm(alg jose.ContentEncryption) func(*jwt.ParseOptions) { + return func(opts *jwt.ParseOptions) { + opts.ContentEncryptionAlgorithm = alg + } +} diff --git a/go.mod b/go.mod index d5de2c7dbf769..eea7c2b647a7d 100644 --- a/go.mod +++ b/go.mod @@ -207,6 +207,7 @@ require ( github.com/coder/serpent v0.8.0 github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21 github.com/emersion/go-smtp v0.21.2 + github.com/go-jose/go-jose/v4 v4.0.2 github.com/gomarkdown/markdown v0.0.0-20231222211730-1d6d20845b47 github.com/google/go-github/v61 v61.0.0 github.com/mocktools/go-smtp-mock/v2 v2.3.0 @@ -224,7 +225,6 @@ require ( github.com/charmbracelet/x/ansi v0.2.3 // indirect github.com/charmbracelet/x/term v0.2.0 // indirect github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect - github.com/go-jose/go-jose/v4 v4.0.2 // indirect github.com/go-viper/mapstructure/v2 v2.0.0 // indirect github.com/hashicorp/go-plugin v1.6.1 // indirect github.com/hashicorp/go-retryablehttp v0.7.7 // indirect From 6025c7b48b112a65e80111bb1db62c0e04d52364 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Wed, 2 Oct 2024 04:54:33 +0000 Subject: [PATCH 02/25] update make gen --- Makefile | 7 ++++++- coderd/cryptokeys/doc.go | 2 +- .../{mock_keycache.go => keycachemock.go} | 16 +--------------- 3 files changed, 8 insertions(+), 17 deletions(-) rename coderd/cryptokeys/{mock_keycache.go => keycachemock.go} (79%) diff --git a/Makefile b/Makefile index 0765346500975..4bfcc73a06966 100644 --- a/Makefile +++ b/Makefile @@ -507,7 +507,8 @@ gen: \ examples/examples.gen.json \ tailnet/tailnettest/coordinatormock.go \ tailnet/tailnettest/coordinateemock.go \ - tailnet/tailnettest/multiagentmock.go + tailnet/tailnettest/multiagentmock.go \ + coderd/cryptokeys/keycachemock.go .PHONY: gen # Mark all generated files as fresh so make thinks they're up-to-date. This is @@ -537,6 +538,7 @@ gen/mark-fresh: tailnet/tailnettest/coordinatormock.go \ tailnet/tailnettest/coordinateemock.go \ tailnet/tailnettest/multiagentmock.go \ + coderd/cryptokeys/keycachemock.go " for file in $$files; do echo "$$file" @@ -628,6 +630,9 @@ examples/examples.gen.json: scripts/examplegen/main.go examples/examples.go $(sh coderd/rbac/object_gen.go: scripts/rbacgen/rbacobject.gotmpl scripts/rbacgen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go go run scripts/rbacgen/main.go rbac > coderd/rbac/object_gen.go +coderd/cryptokeys/keycachemock.go: coderd/cryptokeys/keycache.go + go generate ./coderd/cryptokeys + codersdk/rbacresources_gen.go: scripts/rbacgen/codersdk.gotmpl scripts/rbacgen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go go run scripts/rbacgen/main.go codersdk > codersdk/rbacresources_gen.go diff --git a/coderd/cryptokeys/doc.go b/coderd/cryptokeys/doc.go index efe2968d9cac7..8cee81c28bd69 100644 --- a/coderd/cryptokeys/doc.go +++ b/coderd/cryptokeys/doc.go @@ -1,4 +1,4 @@ // Package cryptokeys provides an abstraction for fetching internally used cryptographic keys mainly for JWT signing and verification. package cryptokeys -//go:generate mockgen -destination mock_keycache.go -package cryptokeys . Keycache +//go:generate mockgen -destination keycachemock.go -package cryptokeys . Keycache diff --git a/coderd/cryptokeys/mock_keycache.go b/coderd/cryptokeys/keycachemock.go similarity index 79% rename from coderd/cryptokeys/mock_keycache.go rename to coderd/cryptokeys/keycachemock.go index e365fa8bc803b..7a7b2e5b0ca13 100644 --- a/coderd/cryptokeys/mock_keycache.go +++ b/coderd/cryptokeys/keycachemock.go @@ -3,7 +3,7 @@ // // Generated by this command: // -// mockgen -destination mock_keycache.go -package cryptokeys . Keycache +// mockgen -destination keycachemock.go -package cryptokeys . Keycache // // Package cryptokeys is a generated GoMock package. @@ -40,20 +40,6 @@ func (m *MockKeycache) EXPECT() *MockKeycacheMockRecorder { return m.recorder } -// Feature mocks base method. -func (m *MockKeycache) Feature() codersdk.CryptoKeyFeature { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Feature") - ret0, _ := ret[0].(codersdk.CryptoKeyFeature) - return ret0 -} - -// Feature indicates an expected call of Feature. -func (mr *MockKeycacheMockRecorder) Feature() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Feature", reflect.TypeOf((*MockKeycache)(nil).Feature)) -} - // Signing mocks base method. func (m *MockKeycache) Signing(arg0 context.Context) (codersdk.CryptoKey, error) { m.ctrl.T.Helper() From 8b235beb6184dc32a82d4d8fb1db77b161709623 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Wed, 2 Oct 2024 05:07:19 +0000 Subject: [PATCH 03/25] Refactor JWT package to modularize key functions This refactor encapsulates key management for JWT signing and verification by: - Introducing `ParseKeyFunc` and `SecuringKeyFn` to streamline key operations. - Moving key-related functions to `jwt.go` to improve package modularity. - Simplifying `Encrypt` and `Decrypt` function documentation and usage. - Updating tests to align with the new function signatures and logic flow. --- coderd/jwt/jwe.go | 4 ++- coderd/jwt/jws.go | 70 ++----------------------------------- coderd/jwt/jwt.go | 79 ++++++++++++++++++++++++++++++++++++++++++ coderd/jwt/jwt_test.go | 6 ++-- 4 files changed, 88 insertions(+), 71 deletions(-) create mode 100644 coderd/jwt/jwt.go diff --git a/coderd/jwt/jwe.go b/coderd/jwt/jwe.go index fc9c91c492cb8..15174049f6181 100644 --- a/coderd/jwt/jwe.go +++ b/coderd/jwt/jwe.go @@ -15,6 +15,7 @@ const ( encryptContentAlgo = jose.A256GCM ) +// Encrypt encrypts a token and returns it as a string. func Encrypt(claims Claims, keyFn SecuringKeyFn) (string, error) { kid, key, err := keyFn() if err != nil { @@ -52,7 +53,8 @@ func Encrypt(claims Claims, keyFn SecuringKeyFn) (string, error) { return base64.RawURLEncoding.EncodeToString(serialized), nil } -func Decrypt(token string, claims Claims, keyFn KeyFunc, opts ...func(*ParseOptions)) error { +// Decrypt decrypts the token using the provided key. It unmarshals into the provided claims. +func Decrypt(token string, claims Claims, keyFn ParseKeyFunc, opts ...func(*ParseOptions)) error { options := ParseOptions{ RegisteredClaims: jjwt.Expected{ Time: time.Now(), diff --git a/coderd/jwt/jws.go b/coderd/jwt/jws.go index 2c39ae72586aa..2ed410cc20fa8 100644 --- a/coderd/jwt/jws.go +++ b/coderd/jwt/jws.go @@ -1,47 +1,21 @@ package jwt import ( - "context" - "encoding/hex" "encoding/json" - "strconv" "time" "github.com/go-jose/go-jose/v4" jjwt "github.com/go-jose/go-jose/v4/jwt" "golang.org/x/xerrors" - - "github.com/coder/coder/v2/coderd/cryptokeys" ) -type Claims interface { - Validate(jjwt.Expected) error -} - const ( defaultSigningAlgo = jose.HS512 featureHeaderKey = "feat" keyIDHeaderKey = "kid" ) -type SecuringKeyFn func() (id string, key interface{}, err error) - -func KeycacheSecure(ctx context.Context, keys cryptokeys.Keycache) SecuringKeyFn { - return func() (id string, key interface{}, err error) { - signing, err := keys.Signing(ctx) - if err != nil { - return "", nil, xerrors.Errorf("get signing key: %w", err) - } - - decoded, err := hex.DecodeString(signing.Secret) - if err != nil { - return "", nil, xerrors.Errorf("decode signing key: %w", err) - } - - return strconv.FormatInt(int64(signing.Sequence), 10), decoded, nil - } -} - +// Sign signs a token and returns it as a string. func Sign(claims Claims, keyFn SecuringKeyFn) (string, error) { kid, key, err := keyFn() if err != nil { @@ -78,46 +52,8 @@ func Sign(claims Claims, keyFn SecuringKeyFn) (string, error) { return compact, nil } -type KeyFunc func(jose.Header) (interface{}, error) - -func KeycacheVerify(ctx context.Context, keys cryptokeys.Keycache) KeyFunc { - return func(header jose.Header) (interface{}, error) { - sequenceStr := header.KeyID - if sequenceStr == "" { - return nil, xerrors.Errorf("expected %q header to be a string", keyIDHeaderKey) - } - - sequence, err := strconv.ParseInt(sequenceStr, 10, 32) - if err != nil { - return nil, xerrors.Errorf("parse sequence: %w", err) - } - - key, err := keys.Verifying(ctx, int32(sequence)) - if err != nil { - return nil, xerrors.Errorf("version: %w", err) - } - - decoded, err := hex.DecodeString(key.Secret) - if err != nil { - return nil, xerrors.Errorf("decode key: %w", err) - } - - return decoded, nil - } -} - -type ParseOptions struct { - RegisteredClaims jjwt.Expected - - // The following are only used for JWSs. - SignatureAlgorithm jose.SignatureAlgorithm - - // The following should only be used for JWEs. - KeyAlgorithm jose.KeyAlgorithm - ContentEncryptionAlgorithm jose.ContentEncryption -} - -func Verify(token string, claims Claims, keyFn KeyFunc, opts ...func(*ParseOptions)) error { +// Verify verifies that a token was signed by the provided key. It unmarshals into the provided claims. +func Verify(token string, claims Claims, keyFn ParseKeyFunc, opts ...func(*ParseOptions)) error { options := ParseOptions{ RegisteredClaims: jjwt.Expected{ Time: time.Now(), diff --git a/coderd/jwt/jwt.go b/coderd/jwt/jwt.go new file mode 100644 index 0000000000000..55b157b8f260a --- /dev/null +++ b/coderd/jwt/jwt.go @@ -0,0 +1,79 @@ +package jwt + +import ( + "context" + "encoding/hex" + "strconv" + + "github.com/go-jose/go-jose/v4" + jjwt "github.com/go-jose/go-jose/v4/jwt" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/cryptokeys" +) + +type Claims interface { + Validate(jjwt.Expected) error +} + +// SecuringKeyFn returns a key for signing or encrypting. +type SecuringKeyFn func() (id string, key interface{}, err error) + +// KeycacheSecure returns the appropriate key for signing or encrypting. +func KeycacheSecure(ctx context.Context, keys cryptokeys.Keycache) SecuringKeyFn { + return func() (id string, key interface{}, err error) { + signing, err := keys.Signing(ctx) + if err != nil { + return "", nil, xerrors.Errorf("get signing key: %w", err) + } + + decoded, err := hex.DecodeString(signing.Secret) + if err != nil { + return "", nil, xerrors.Errorf("decode signing key: %w", err) + } + + return strconv.FormatInt(int64(signing.Sequence), 10), decoded, nil + } +} + +// ParseKeyFunc returns a key for verifying or decrypting a token. +type ParseKeyFunc func(jose.Header) (interface{}, error) + +// KeycacheParse returns the appropriate key to decrypt or verify a token. +func KeycacheParse(ctx context.Context, keys cryptokeys.Keycache) ParseKeyFunc { + return func(header jose.Header) (interface{}, error) { + sequenceStr := header.KeyID + if sequenceStr == "" { + return nil, xerrors.Errorf("expected %q header to be a string", keyIDHeaderKey) + } + + sequence, err := strconv.ParseInt(sequenceStr, 10, 32) + if err != nil { + return nil, xerrors.Errorf("parse sequence: %w", err) + } + + key, err := keys.Verifying(ctx, int32(sequence)) + if err != nil { + return nil, xerrors.Errorf("version: %w", err) + } + + decoded, err := hex.DecodeString(key.Secret) + if err != nil { + return nil, xerrors.Errorf("decode key: %w", err) + } + + return decoded, nil + } +} + +// ParseOptions are options for parsing a JWT. +type ParseOptions struct { + RegisteredClaims jjwt.Expected + + // The following are only used for JWSs. + SignatureAlgorithm jose.SignatureAlgorithm + + // The following should only be used for JWEs. + KeyAlgorithm jose.KeyAlgorithm + ContentEncryptionAlgorithm jose.ContentEncryption +} diff --git a/coderd/jwt/jwt_test.go b/coderd/jwt/jwt_test.go index 7fa464e815c0d..ea96f4b3afde6 100644 --- a/coderd/jwt/jwt_test.go +++ b/coderd/jwt/jwt_test.go @@ -25,7 +25,7 @@ func TestJWT(t *testing.T) { type tokenType struct { name string SecureFn func(jwt.Claims, jwt.SecuringKeyFn) (string, error) - VerifyFn func(string, jwt.Claims, jwt.KeyFunc, ...func(*jwt.ParseOptions)) error + VerifyFn func(string, jwt.Claims, jwt.ParseKeyFunc, ...func(*jwt.ParseOptions)) error KeySize int } @@ -101,7 +101,7 @@ func TestJWT(t *testing.T) { require.NoError(t, err) var actual testClaims - err = tt.VerifyFn(token, &actual, jwt.KeycacheVerify(ctx, keycache)) + err = tt.VerifyFn(token, &actual, jwt.KeycacheParse(ctx, keycache)) require.NoError(t, err) }) @@ -372,7 +372,7 @@ func securingKeyFn(id string, key []byte) jwt.SecuringKeyFn { } } -func verifyingKeyFn(id string, key []byte) jwt.KeyFunc { +func verifyingKeyFn(id string, key []byte) jwt.ParseKeyFunc { return func(header jose.Header) (interface{}, error) { if header.KeyID != id { return nil, xerrors.Errorf("expected key ID %q, got %q", id, header.KeyID) From 843de3802396c684d3ae9f512c345b68a2405101 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Wed, 2 Oct 2024 05:08:54 +0000 Subject: [PATCH 04/25] Remove unused JWT test file from repository The JWT functionality has been refactored, and its associated test values are no longer required. Removing the test file helps maintain code clarity and prevents outdated test logic from impacting the current codebase. --- coderd/jwt/jwt_test.go | 406 ----------------------------------------- 1 file changed, 406 deletions(-) delete mode 100644 coderd/jwt/jwt_test.go diff --git a/coderd/jwt/jwt_test.go b/coderd/jwt/jwt_test.go deleted file mode 100644 index ea96f4b3afde6..0000000000000 --- a/coderd/jwt/jwt_test.go +++ /dev/null @@ -1,406 +0,0 @@ -package jwt_test - -import ( - "crypto/rand" - "encoding/hex" - "testing" - "time" - - "github.com/go-jose/go-jose/v4" - jjwt "github.com/go-jose/go-jose/v4/jwt" - "github.com/google/uuid" - "github.com/stretchr/testify/require" - "go.uber.org/mock/gomock" - "golang.org/x/xerrors" - - "github.com/coder/coder/v2/coderd/cryptokeys" - "github.com/coder/coder/v2/coderd/jwt" - "github.com/coder/coder/v2/codersdk" - "github.com/coder/coder/v2/testutil" -) - -func TestJWT(t *testing.T) { - t.Parallel() - - type tokenType struct { - name string - SecureFn func(jwt.Claims, jwt.SecuringKeyFn) (string, error) - VerifyFn func(string, jwt.Claims, jwt.ParseKeyFunc, ...func(*jwt.ParseOptions)) error - KeySize int - } - - types := []tokenType{ - { - name: "JWE", - SecureFn: jwt.Encrypt, - VerifyFn: jwt.Decrypt, - KeySize: 32, - }, - { - name: "JWS", - SecureFn: jwt.Sign, - VerifyFn: jwt.Verify, - KeySize: 64, - }, - } - - for _, tt := range types { - tt := tt - - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - t.Run("Basic", func(t *testing.T) { - t.Parallel() - - id := uuid.New().String() - key := generateSecret(t, tt.KeySize) - - claims := jjwt.Claims{ - Issuer: "coder", - Subject: "user@coder.com", - Audience: jjwt.Audience{"coder"}, - Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), - IssuedAt: jjwt.NewNumericDate(time.Now()), - NotBefore: jjwt.NewNumericDate(time.Now()), - } - - token, err := tt.SecureFn(claims, securingKeyFn(id, key)) - require.NoError(t, err) - - var actual testClaims - err = tt.VerifyFn(token, &actual, verifyingKeyFn(id, key)) - require.NoError(t, err) - }) - - t.Run("Keycache", func(t *testing.T) { - t.Parallel() - - var ( - ctx = testutil.Context(t, testutil.WaitShort) - ctrl = gomock.NewController(t) - keycache = cryptokeys.NewMockKeycache(ctrl) - now = time.Now() - ) - - key := generateCryptoKey(t, 1234567890, now, tt.KeySize) - - keycache.EXPECT().Signing(gomock.Any()).Return(key, nil) - keycache.EXPECT().Verifying(gomock.Any(), key.Sequence).Return(key, nil) - - claims := jjwt.Claims{ - Issuer: "coder", - Subject: "user@coder.com", - Audience: jjwt.Audience{"coder"}, - Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), - IssuedAt: jjwt.NewNumericDate(time.Now()), - NotBefore: jjwt.NewNumericDate(time.Now()), - } - - token, err := tt.SecureFn(claims, jwt.KeycacheSecure(ctx, keycache)) - require.NoError(t, err) - - var actual testClaims - err = tt.VerifyFn(token, &actual, jwt.KeycacheParse(ctx, keycache)) - require.NoError(t, err) - }) - - t.Run("WrongIssuer", func(t *testing.T) { - t.Parallel() - - id := uuid.New().String() - key := generateSecret(t, tt.KeySize) - - claims := jjwt.Claims{ - Issuer: "coder", - Subject: "user@coder.com", - Audience: jjwt.Audience{"coder"}, - Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), - IssuedAt: jjwt.NewNumericDate(time.Now()), - NotBefore: jjwt.NewNumericDate(time.Now()), - } - - token, err := tt.SecureFn(claims, securingKeyFn(id, key)) - require.NoError(t, err) - - var actual testClaims - err = tt.VerifyFn(token, &actual, verifyingKeyFn(id, key), withExpected(jjwt.Expected{ - Issuer: "coder2", - })) - require.ErrorIs(t, err, jjwt.ErrInvalidIssuer) - }) - - t.Run("WrongSubject", func(t *testing.T) { - t.Parallel() - - id := uuid.New().String() - key := generateSecret(t, tt.KeySize) - - claims := jjwt.Claims{ - Issuer: "coder", - Subject: "user@coder.com", - Audience: jjwt.Audience{"coder"}, - Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), - IssuedAt: jjwt.NewNumericDate(time.Now()), - NotBefore: jjwt.NewNumericDate(time.Now()), - } - - token, err := tt.SecureFn(claims, securingKeyFn(id, key)) - require.NoError(t, err) - - var actual testClaims - err = tt.VerifyFn(token, &actual, verifyingKeyFn(id, key), withExpected(jjwt.Expected{ - Subject: "user2@coder.com", - })) - require.ErrorIs(t, err, jjwt.ErrInvalidSubject) - }) - - t.Run("WrongAudience", func(t *testing.T) { - t.Parallel() - - id := uuid.New().String() - key := generateSecret(t, tt.KeySize) - - claims := jjwt.Claims{ - Issuer: "coder", - Subject: "user@coder.com", - Audience: jjwt.Audience{"coder"}, - Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), - IssuedAt: jjwt.NewNumericDate(time.Now()), - NotBefore: jjwt.NewNumericDate(time.Now()), - } - - token, err := tt.SecureFn(claims, securingKeyFn(id, key)) - require.NoError(t, err) - - var actual testClaims - err = tt.VerifyFn(token, &actual, verifyingKeyFn(id, key), withExpected(jjwt.Expected{ - AnyAudience: jjwt.Audience{"coder2"}, - })) - require.ErrorIs(t, err, jjwt.ErrInvalidAudience) - }) - - t.Run("Expired", func(t *testing.T) { - t.Parallel() - - id := uuid.New().String() - key := generateSecret(t, tt.KeySize) - - claims := jjwt.Claims{ - Issuer: "coder", - Subject: "user@coder.com", - Audience: jjwt.Audience{"coder"}, - Expiry: jjwt.NewNumericDate(time.Now().Add(time.Minute)), - IssuedAt: jjwt.NewNumericDate(time.Now()), - NotBefore: jjwt.NewNumericDate(time.Now()), - } - - token, err := tt.SecureFn(claims, securingKeyFn(id, key)) - require.NoError(t, err) - - var actual testClaims - err = tt.VerifyFn(token, &actual, verifyingKeyFn(id, key), withExpected(jjwt.Expected{ - Time: time.Now().Add(time.Minute * 3), - })) - require.ErrorIs(t, err, jjwt.ErrExpired) - }) - - t.Run("IssuedInFuture", func(t *testing.T) { - t.Parallel() - - id := uuid.New().String() - key := generateSecret(t, tt.KeySize) - - claims := jjwt.Claims{ - Issuer: "coder", - Subject: "user@coder.com", - Audience: jjwt.Audience{"coder"}, - Expiry: jjwt.NewNumericDate(time.Now().Add(time.Minute)), - IssuedAt: jjwt.NewNumericDate(time.Now()), - } - - token, err := tt.SecureFn(claims, securingKeyFn(id, key)) - require.NoError(t, err) - - var actual testClaims - err = tt.VerifyFn(token, &actual, verifyingKeyFn(id, key), withExpected(jjwt.Expected{ - Time: time.Now().Add(-time.Minute * 3), - })) - require.ErrorIs(t, err, jjwt.ErrIssuedInTheFuture) - }) - - t.Run("IsBefore", func(t *testing.T) { - t.Parallel() - - id := uuid.New().String() - key := generateSecret(t, tt.KeySize) - - claims := jjwt.Claims{ - Issuer: "coder", - Subject: "user@coder.com", - Audience: jjwt.Audience{"coder"}, - Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), - IssuedAt: jjwt.NewNumericDate(time.Now()), - NotBefore: jjwt.NewNumericDate(time.Now().Add(time.Minute * 5)), - } - - token, err := tt.SecureFn(claims, securingKeyFn(id, key)) - require.NoError(t, err) - - var actual testClaims - err = tt.VerifyFn(token, &actual, verifyingKeyFn(id, key), withExpected(jjwt.Expected{ - Time: time.Now().Add(time.Minute * 3), - })) - require.ErrorIs(t, err, jjwt.ErrNotValidYet) - }) - - t.Run("WrongSignatureAlgorithm", func(t *testing.T) { - t.Parallel() - - if tt.name == "JWE" { - t.Skip("JWE does not support this") - } - - id := uuid.New().String() - key := generateSecret(t, tt.KeySize) - - claims := jjwt.Claims{ - Issuer: "coder", - Subject: "user@coder.com", - Audience: jjwt.Audience{"coder"}, - Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), - IssuedAt: jjwt.NewNumericDate(time.Now()), - NotBefore: jjwt.NewNumericDate(time.Now().Add(time.Minute * 5)), - } - - token, err := tt.SecureFn(claims, securingKeyFn(id, key)) - require.NoError(t, err) - - var actual testClaims - err = tt.VerifyFn(token, &actual, verifyingKeyFn(id, key), withSignatureAlgorithm(jose.HS256)) - require.Error(t, err) - }) - - t.Run("WrongKeyAlgorithm", func(t *testing.T) { - t.Parallel() - - if tt.name == "JWS" { - t.Skip("JWS does not support this") - } - - id := uuid.New().String() - key := generateSecret(t, tt.KeySize) - - claims := jjwt.Claims{ - Issuer: "coder", - Subject: "user@coder.com", - Audience: jjwt.Audience{"coder"}, - Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), - IssuedAt: jjwt.NewNumericDate(time.Now()), - NotBefore: jjwt.NewNumericDate(time.Now().Add(time.Minute * 5)), - } - - token, err := tt.SecureFn(claims, securingKeyFn(id, key)) - require.NoError(t, err) - - var actual testClaims - err = tt.VerifyFn(token, &actual, verifyingKeyFn(id, key), withKeyAlgorithm(jose.A128GCMKW)) - require.Error(t, err) - }) - - t.Run("WrongContentyEncryption", func(t *testing.T) { - t.Parallel() - - if tt.name == "JWS" { - t.Skip("JWS does not support this") - } - - id := uuid.New().String() - key := generateSecret(t, tt.KeySize) - - claims := jjwt.Claims{ - Issuer: "coder", - Subject: "user@coder.com", - Audience: jjwt.Audience{"coder"}, - Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), - IssuedAt: jjwt.NewNumericDate(time.Now()), - NotBefore: jjwt.NewNumericDate(time.Now().Add(time.Minute * 5)), - } - - token, err := tt.SecureFn(claims, securingKeyFn(id, key)) - require.NoError(t, err) - - var actual testClaims - err = tt.VerifyFn(token, &actual, verifyingKeyFn(id, key), withContentEncryptionAlgorithm(jose.A128GCM)) - require.Error(t, err) - - }) - }) - } -} - -func generateCryptoKey(t *testing.T, seq int32, now time.Time, keySize int) codersdk.CryptoKey { - t.Helper() - - secret := generateSecret(t, keySize) - - return codersdk.CryptoKey{ - Feature: codersdk.CryptoKeyFeatureTailnetResume, - Secret: hex.EncodeToString(secret), - Sequence: seq, - StartsAt: now, - } -} - -func generateSecret(t *testing.T, keySize int) []byte { - t.Helper() - - b := make([]byte, keySize) - _, err := rand.Read(b) - require.NoError(t, err) - return b -} - -type testClaims struct { - MyClaim string `json:"my_claim"` - jjwt.Claims -} - -func securingKeyFn(id string, key []byte) jwt.SecuringKeyFn { - return func() (string, interface{}, error) { - return id, key, nil - } -} - -func verifyingKeyFn(id string, key []byte) jwt.ParseKeyFunc { - return func(header jose.Header) (interface{}, error) { - if header.KeyID != id { - return nil, xerrors.Errorf("expected key ID %q, got %q", id, header.KeyID) - } - return key, nil - } -} - -func withExpected(e jjwt.Expected) func(*jwt.ParseOptions) { - return func(opts *jwt.ParseOptions) { - opts.RegisteredClaims = e - } -} - -func withSignatureAlgorithm(alg jose.SignatureAlgorithm) func(*jwt.ParseOptions) { - return func(opts *jwt.ParseOptions) { - opts.SignatureAlgorithm = alg - } -} - -func withKeyAlgorithm(alg jose.KeyAlgorithm) func(*jwt.ParseOptions) { - return func(opts *jwt.ParseOptions) { - opts.KeyAlgorithm = alg - } -} - -func withContentEncryptionAlgorithm(alg jose.ContentEncryption) func(*jwt.ParseOptions) { - return func(opts *jwt.ParseOptions) { - opts.ContentEncryptionAlgorithm = alg - } -} From 099544f9222fbd7d1ea735b23e8edb059ed4628d Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Wed, 2 Oct 2024 05:12:38 +0000 Subject: [PATCH 05/25] Refactor JWT key functions and add tests - Renamed key functions to clarify their use. - Removed unused `featureHeaderKey`. - Added comprehensive tests for JWT signing and verifying. --- coderd/jwt/jws.go | 1 - coderd/jwt/jwt.go | 8 +- coderd/jwt/jwt_test.go | 405 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 409 insertions(+), 5 deletions(-) create mode 100644 coderd/jwt/jwt_test.go diff --git a/coderd/jwt/jws.go b/coderd/jwt/jws.go index 2ed410cc20fa8..c9a9c5ce8131e 100644 --- a/coderd/jwt/jws.go +++ b/coderd/jwt/jws.go @@ -11,7 +11,6 @@ import ( const ( defaultSigningAlgo = jose.HS512 - featureHeaderKey = "feat" keyIDHeaderKey = "kid" ) diff --git a/coderd/jwt/jwt.go b/coderd/jwt/jwt.go index 55b157b8f260a..fbcf51603accb 100644 --- a/coderd/jwt/jwt.go +++ b/coderd/jwt/jwt.go @@ -19,8 +19,8 @@ type Claims interface { // SecuringKeyFn returns a key for signing or encrypting. type SecuringKeyFn func() (id string, key interface{}, err error) -// KeycacheSecure returns the appropriate key for signing or encrypting. -func KeycacheSecure(ctx context.Context, keys cryptokeys.Keycache) SecuringKeyFn { +// SecuringKeyFromCache returns the appropriate key for signing or encrypting. +func SecuringKeyFromCache(ctx context.Context, keys cryptokeys.Keycache) SecuringKeyFn { return func() (id string, key interface{}, err error) { signing, err := keys.Signing(ctx) if err != nil { @@ -39,8 +39,8 @@ func KeycacheSecure(ctx context.Context, keys cryptokeys.Keycache) SecuringKeyFn // ParseKeyFunc returns a key for verifying or decrypting a token. type ParseKeyFunc func(jose.Header) (interface{}, error) -// KeycacheParse returns the appropriate key to decrypt or verify a token. -func KeycacheParse(ctx context.Context, keys cryptokeys.Keycache) ParseKeyFunc { +// ParseKeyFromCache returns the appropriate key to decrypt or verify a token. +func ParseKeyFromCache(ctx context.Context, keys cryptokeys.Keycache) ParseKeyFunc { return func(header jose.Header) (interface{}, error) { sequenceStr := header.KeyID if sequenceStr == "" { diff --git a/coderd/jwt/jwt_test.go b/coderd/jwt/jwt_test.go new file mode 100644 index 0000000000000..57fb77849eb0e --- /dev/null +++ b/coderd/jwt/jwt_test.go @@ -0,0 +1,405 @@ +package jwt_test + +import ( + "crypto/rand" + "encoding/hex" + "testing" + "time" + + "github.com/go-jose/go-jose/v4" + jjwt "github.com/go-jose/go-jose/v4/jwt" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/cryptokeys" + "github.com/coder/coder/v2/coderd/jwt" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/testutil" +) + +func TestJWT(t *testing.T) { + t.Parallel() + + type tokenType struct { + name string + SecureFn func(jwt.Claims, jwt.SecuringKeyFn) (string, error) + VerifyFn func(string, jwt.Claims, jwt.ParseKeyFunc, ...func(*jwt.ParseOptions)) error + KeySize int + } + + types := []tokenType{ + { + name: "JWE", + SecureFn: jwt.Encrypt, + VerifyFn: jwt.Decrypt, + KeySize: 32, + }, + { + name: "JWS", + SecureFn: jwt.Sign, + VerifyFn: jwt.Verify, + KeySize: 64, + }, + } + + for _, tt := range types { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + t.Run("Basic", func(t *testing.T) { + t.Parallel() + + id := uuid.New().String() + key := generateSecret(t, tt.KeySize) + + claims := jjwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jjwt.Audience{"coder"}, + Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jjwt.NewNumericDate(time.Now()), + NotBefore: jjwt.NewNumericDate(time.Now()), + } + + token, err := tt.SecureFn(claims, securingKeyFn(id, key)) + require.NoError(t, err) + + var actual testClaims + err = tt.VerifyFn(token, &actual, verifyingKeyFn(id, key)) + require.NoError(t, err) + }) + + t.Run("Keycache", func(t *testing.T) { + t.Parallel() + + var ( + ctx = testutil.Context(t, testutil.WaitShort) + ctrl = gomock.NewController(t) + keycache = cryptokeys.NewMockKeycache(ctrl) + now = time.Now() + ) + + key := generateCryptoKey(t, 1234567890, now, tt.KeySize) + + keycache.EXPECT().Signing(gomock.Any()).Return(key, nil) + keycache.EXPECT().Verifying(gomock.Any(), key.Sequence).Return(key, nil) + + claims := jjwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jjwt.Audience{"coder"}, + Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jjwt.NewNumericDate(time.Now()), + NotBefore: jjwt.NewNumericDate(time.Now()), + } + + token, err := tt.SecureFn(claims, jwt.SecuringKeyFromCache(ctx, keycache)) + require.NoError(t, err) + + var actual testClaims + err = tt.VerifyFn(token, &actual, jwt.ParseKeyFromCache(ctx, keycache)) + require.NoError(t, err) + }) + + t.Run("WrongIssuer", func(t *testing.T) { + t.Parallel() + + id := uuid.New().String() + key := generateSecret(t, tt.KeySize) + + claims := jjwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jjwt.Audience{"coder"}, + Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jjwt.NewNumericDate(time.Now()), + NotBefore: jjwt.NewNumericDate(time.Now()), + } + + token, err := tt.SecureFn(claims, securingKeyFn(id, key)) + require.NoError(t, err) + + var actual testClaims + err = tt.VerifyFn(token, &actual, verifyingKeyFn(id, key), withExpected(jjwt.Expected{ + Issuer: "coder2", + })) + require.ErrorIs(t, err, jjwt.ErrInvalidIssuer) + }) + + t.Run("WrongSubject", func(t *testing.T) { + t.Parallel() + + id := uuid.New().String() + key := generateSecret(t, tt.KeySize) + + claims := jjwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jjwt.Audience{"coder"}, + Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jjwt.NewNumericDate(time.Now()), + NotBefore: jjwt.NewNumericDate(time.Now()), + } + + token, err := tt.SecureFn(claims, securingKeyFn(id, key)) + require.NoError(t, err) + + var actual testClaims + err = tt.VerifyFn(token, &actual, verifyingKeyFn(id, key), withExpected(jjwt.Expected{ + Subject: "user2@coder.com", + })) + require.ErrorIs(t, err, jjwt.ErrInvalidSubject) + }) + + t.Run("WrongAudience", func(t *testing.T) { + t.Parallel() + + id := uuid.New().String() + key := generateSecret(t, tt.KeySize) + + claims := jjwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jjwt.Audience{"coder"}, + Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jjwt.NewNumericDate(time.Now()), + NotBefore: jjwt.NewNumericDate(time.Now()), + } + + token, err := tt.SecureFn(claims, securingKeyFn(id, key)) + require.NoError(t, err) + + var actual testClaims + err = tt.VerifyFn(token, &actual, verifyingKeyFn(id, key), withExpected(jjwt.Expected{ + AnyAudience: jjwt.Audience{"coder2"}, + })) + require.ErrorIs(t, err, jjwt.ErrInvalidAudience) + }) + + t.Run("Expired", func(t *testing.T) { + t.Parallel() + + id := uuid.New().String() + key := generateSecret(t, tt.KeySize) + + claims := jjwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jjwt.Audience{"coder"}, + Expiry: jjwt.NewNumericDate(time.Now().Add(time.Minute)), + IssuedAt: jjwt.NewNumericDate(time.Now()), + NotBefore: jjwt.NewNumericDate(time.Now()), + } + + token, err := tt.SecureFn(claims, securingKeyFn(id, key)) + require.NoError(t, err) + + var actual testClaims + err = tt.VerifyFn(token, &actual, verifyingKeyFn(id, key), withExpected(jjwt.Expected{ + Time: time.Now().Add(time.Minute * 3), + })) + require.ErrorIs(t, err, jjwt.ErrExpired) + }) + + t.Run("IssuedInFuture", func(t *testing.T) { + t.Parallel() + + id := uuid.New().String() + key := generateSecret(t, tt.KeySize) + + claims := jjwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jjwt.Audience{"coder"}, + Expiry: jjwt.NewNumericDate(time.Now().Add(time.Minute)), + IssuedAt: jjwt.NewNumericDate(time.Now()), + } + + token, err := tt.SecureFn(claims, securingKeyFn(id, key)) + require.NoError(t, err) + + var actual testClaims + err = tt.VerifyFn(token, &actual, verifyingKeyFn(id, key), withExpected(jjwt.Expected{ + Time: time.Now().Add(-time.Minute * 3), + })) + require.ErrorIs(t, err, jjwt.ErrIssuedInTheFuture) + }) + + t.Run("IsBefore", func(t *testing.T) { + t.Parallel() + + id := uuid.New().String() + key := generateSecret(t, tt.KeySize) + + claims := jjwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jjwt.Audience{"coder"}, + Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jjwt.NewNumericDate(time.Now()), + NotBefore: jjwt.NewNumericDate(time.Now().Add(time.Minute * 5)), + } + + token, err := tt.SecureFn(claims, securingKeyFn(id, key)) + require.NoError(t, err) + + var actual testClaims + err = tt.VerifyFn(token, &actual, verifyingKeyFn(id, key), withExpected(jjwt.Expected{ + Time: time.Now().Add(time.Minute * 3), + })) + require.ErrorIs(t, err, jjwt.ErrNotValidYet) + }) + + t.Run("WrongSignatureAlgorithm", func(t *testing.T) { + t.Parallel() + + if tt.name == "JWE" { + t.Skip("JWE does not support this") + } + + id := uuid.New().String() + key := generateSecret(t, tt.KeySize) + + claims := jjwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jjwt.Audience{"coder"}, + Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jjwt.NewNumericDate(time.Now()), + NotBefore: jjwt.NewNumericDate(time.Now().Add(time.Minute * 5)), + } + + token, err := tt.SecureFn(claims, securingKeyFn(id, key)) + require.NoError(t, err) + + var actual testClaims + err = tt.VerifyFn(token, &actual, verifyingKeyFn(id, key), withSignatureAlgorithm(jose.HS256)) + require.Error(t, err) + }) + + t.Run("WrongKeyAlgorithm", func(t *testing.T) { + t.Parallel() + + if tt.name == "JWS" { + t.Skip("JWS does not support this") + } + + id := uuid.New().String() + key := generateSecret(t, tt.KeySize) + + claims := jjwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jjwt.Audience{"coder"}, + Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jjwt.NewNumericDate(time.Now()), + NotBefore: jjwt.NewNumericDate(time.Now().Add(time.Minute * 5)), + } + + token, err := tt.SecureFn(claims, securingKeyFn(id, key)) + require.NoError(t, err) + + var actual testClaims + err = tt.VerifyFn(token, &actual, verifyingKeyFn(id, key), withKeyAlgorithm(jose.A128GCMKW)) + require.Error(t, err) + }) + + t.Run("WrongContentyEncryption", func(t *testing.T) { + t.Parallel() + + if tt.name == "JWS" { + t.Skip("JWS does not support this") + } + + id := uuid.New().String() + key := generateSecret(t, tt.KeySize) + + claims := jjwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jjwt.Audience{"coder"}, + Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jjwt.NewNumericDate(time.Now()), + NotBefore: jjwt.NewNumericDate(time.Now().Add(time.Minute * 5)), + } + + token, err := tt.SecureFn(claims, securingKeyFn(id, key)) + require.NoError(t, err) + + var actual testClaims + err = tt.VerifyFn(token, &actual, verifyingKeyFn(id, key), withContentEncryptionAlgorithm(jose.A128GCM)) + require.Error(t, err) + }) + }) + } +} + +func generateCryptoKey(t *testing.T, seq int32, now time.Time, keySize int) codersdk.CryptoKey { + t.Helper() + + secret := generateSecret(t, keySize) + + return codersdk.CryptoKey{ + Feature: codersdk.CryptoKeyFeatureTailnetResume, + Secret: hex.EncodeToString(secret), + Sequence: seq, + StartsAt: now, + } +} + +func generateSecret(t *testing.T, keySize int) []byte { + t.Helper() + + b := make([]byte, keySize) + _, err := rand.Read(b) + require.NoError(t, err) + return b +} + +type testClaims struct { + MyClaim string `json:"my_claim"` + jjwt.Claims +} + +func securingKeyFn(id string, key []byte) jwt.SecuringKeyFn { + return func() (string, interface{}, error) { + return id, key, nil + } +} + +func verifyingKeyFn(id string, key []byte) jwt.ParseKeyFunc { + return func(header jose.Header) (interface{}, error) { + if header.KeyID != id { + return nil, xerrors.Errorf("expected key ID %q, got %q", id, header.KeyID) + } + return key, nil + } +} + +func withExpected(e jjwt.Expected) func(*jwt.ParseOptions) { + return func(opts *jwt.ParseOptions) { + opts.RegisteredClaims = e + } +} + +func withSignatureAlgorithm(alg jose.SignatureAlgorithm) func(*jwt.ParseOptions) { + return func(opts *jwt.ParseOptions) { + opts.SignatureAlgorithm = alg + } +} + +func withKeyAlgorithm(alg jose.KeyAlgorithm) func(*jwt.ParseOptions) { + return func(opts *jwt.ParseOptions) { + opts.KeyAlgorithm = alg + } +} + +func withContentEncryptionAlgorithm(alg jose.ContentEncryption) func(*jwt.ParseOptions) { + return func(opts *jwt.ParseOptions) { + opts.ContentEncryptionAlgorithm = alg + } +} From acc4db384d9b6392d7ee65330d2243d06093aa7a Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Wed, 2 Oct 2024 05:14:19 +0000 Subject: [PATCH 06/25] Rename VerifyFn to ParseFn in JWT tests --- coderd/jwt/jwt_test.go | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/coderd/jwt/jwt_test.go b/coderd/jwt/jwt_test.go index 57fb77849eb0e..123719c7e06c3 100644 --- a/coderd/jwt/jwt_test.go +++ b/coderd/jwt/jwt_test.go @@ -25,7 +25,7 @@ func TestJWT(t *testing.T) { type tokenType struct { name string SecureFn func(jwt.Claims, jwt.SecuringKeyFn) (string, error) - VerifyFn func(string, jwt.Claims, jwt.ParseKeyFunc, ...func(*jwt.ParseOptions)) error + ParseFn func(string, jwt.Claims, jwt.ParseKeyFunc, ...func(*jwt.ParseOptions)) error KeySize int } @@ -33,13 +33,13 @@ func TestJWT(t *testing.T) { { name: "JWE", SecureFn: jwt.Encrypt, - VerifyFn: jwt.Decrypt, + ParseFn: jwt.Decrypt, KeySize: 32, }, { name: "JWS", SecureFn: jwt.Sign, - VerifyFn: jwt.Verify, + ParseFn: jwt.Verify, KeySize: 64, }, } @@ -69,7 +69,7 @@ func TestJWT(t *testing.T) { require.NoError(t, err) var actual testClaims - err = tt.VerifyFn(token, &actual, verifyingKeyFn(id, key)) + err = tt.ParseFn(token, &actual, verifyingKeyFn(id, key)) require.NoError(t, err) }) @@ -101,7 +101,7 @@ func TestJWT(t *testing.T) { require.NoError(t, err) var actual testClaims - err = tt.VerifyFn(token, &actual, jwt.ParseKeyFromCache(ctx, keycache)) + err = tt.ParseFn(token, &actual, jwt.ParseKeyFromCache(ctx, keycache)) require.NoError(t, err) }) @@ -124,7 +124,7 @@ func TestJWT(t *testing.T) { require.NoError(t, err) var actual testClaims - err = tt.VerifyFn(token, &actual, verifyingKeyFn(id, key), withExpected(jjwt.Expected{ + err = tt.ParseFn(token, &actual, verifyingKeyFn(id, key), withExpected(jjwt.Expected{ Issuer: "coder2", })) require.ErrorIs(t, err, jjwt.ErrInvalidIssuer) @@ -149,7 +149,7 @@ func TestJWT(t *testing.T) { require.NoError(t, err) var actual testClaims - err = tt.VerifyFn(token, &actual, verifyingKeyFn(id, key), withExpected(jjwt.Expected{ + err = tt.ParseFn(token, &actual, verifyingKeyFn(id, key), withExpected(jjwt.Expected{ Subject: "user2@coder.com", })) require.ErrorIs(t, err, jjwt.ErrInvalidSubject) @@ -174,7 +174,7 @@ func TestJWT(t *testing.T) { require.NoError(t, err) var actual testClaims - err = tt.VerifyFn(token, &actual, verifyingKeyFn(id, key), withExpected(jjwt.Expected{ + err = tt.ParseFn(token, &actual, verifyingKeyFn(id, key), withExpected(jjwt.Expected{ AnyAudience: jjwt.Audience{"coder2"}, })) require.ErrorIs(t, err, jjwt.ErrInvalidAudience) @@ -199,7 +199,7 @@ func TestJWT(t *testing.T) { require.NoError(t, err) var actual testClaims - err = tt.VerifyFn(token, &actual, verifyingKeyFn(id, key), withExpected(jjwt.Expected{ + err = tt.ParseFn(token, &actual, verifyingKeyFn(id, key), withExpected(jjwt.Expected{ Time: time.Now().Add(time.Minute * 3), })) require.ErrorIs(t, err, jjwt.ErrExpired) @@ -223,7 +223,7 @@ func TestJWT(t *testing.T) { require.NoError(t, err) var actual testClaims - err = tt.VerifyFn(token, &actual, verifyingKeyFn(id, key), withExpected(jjwt.Expected{ + err = tt.ParseFn(token, &actual, verifyingKeyFn(id, key), withExpected(jjwt.Expected{ Time: time.Now().Add(-time.Minute * 3), })) require.ErrorIs(t, err, jjwt.ErrIssuedInTheFuture) @@ -248,7 +248,7 @@ func TestJWT(t *testing.T) { require.NoError(t, err) var actual testClaims - err = tt.VerifyFn(token, &actual, verifyingKeyFn(id, key), withExpected(jjwt.Expected{ + err = tt.ParseFn(token, &actual, verifyingKeyFn(id, key), withExpected(jjwt.Expected{ Time: time.Now().Add(time.Minute * 3), })) require.ErrorIs(t, err, jjwt.ErrNotValidYet) @@ -277,7 +277,7 @@ func TestJWT(t *testing.T) { require.NoError(t, err) var actual testClaims - err = tt.VerifyFn(token, &actual, verifyingKeyFn(id, key), withSignatureAlgorithm(jose.HS256)) + err = tt.ParseFn(token, &actual, verifyingKeyFn(id, key), withSignatureAlgorithm(jose.HS256)) require.Error(t, err) }) @@ -304,7 +304,7 @@ func TestJWT(t *testing.T) { require.NoError(t, err) var actual testClaims - err = tt.VerifyFn(token, &actual, verifyingKeyFn(id, key), withKeyAlgorithm(jose.A128GCMKW)) + err = tt.ParseFn(token, &actual, verifyingKeyFn(id, key), withKeyAlgorithm(jose.A128GCMKW)) require.Error(t, err) }) @@ -331,7 +331,7 @@ func TestJWT(t *testing.T) { require.NoError(t, err) var actual testClaims - err = tt.VerifyFn(token, &actual, verifyingKeyFn(id, key), withContentEncryptionAlgorithm(jose.A128GCM)) + err = tt.ParseFn(token, &actual, verifyingKeyFn(id, key), withContentEncryptionAlgorithm(jose.A128GCM)) require.Error(t, err) }) }) From b4973a80264baa5e11f4c0d657ecf663d89197d2 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Wed, 2 Oct 2024 05:17:31 +0000 Subject: [PATCH 07/25] Remove unused JWE test file --- coderd/jwt/jwe_test.go | 1 - 1 file changed, 1 deletion(-) delete mode 100644 coderd/jwt/jwe_test.go diff --git a/coderd/jwt/jwe_test.go b/coderd/jwt/jwe_test.go deleted file mode 100644 index 30619ec668540..0000000000000 --- a/coderd/jwt/jwe_test.go +++ /dev/null @@ -1 +0,0 @@ -package jwt_test From f7d7c95b694e0baf15a57ea7790152929ecdc3d2 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Wed, 2 Oct 2024 05:18:21 +0000 Subject: [PATCH 08/25] Refactor JWT test structs to use public field names --- coderd/jwt/jwt_test.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/coderd/jwt/jwt_test.go b/coderd/jwt/jwt_test.go index 123719c7e06c3..b9d50d7503e30 100644 --- a/coderd/jwt/jwt_test.go +++ b/coderd/jwt/jwt_test.go @@ -23,7 +23,7 @@ func TestJWT(t *testing.T) { t.Parallel() type tokenType struct { - name string + Name string SecureFn func(jwt.Claims, jwt.SecuringKeyFn) (string, error) ParseFn func(string, jwt.Claims, jwt.ParseKeyFunc, ...func(*jwt.ParseOptions)) error KeySize int @@ -31,13 +31,13 @@ func TestJWT(t *testing.T) { types := []tokenType{ { - name: "JWE", + Name: "JWE", SecureFn: jwt.Encrypt, ParseFn: jwt.Decrypt, KeySize: 32, }, { - name: "JWS", + Name: "JWS", SecureFn: jwt.Sign, ParseFn: jwt.Verify, KeySize: 64, @@ -47,7 +47,7 @@ func TestJWT(t *testing.T) { for _, tt := range types { tt := tt - t.Run(tt.name, func(t *testing.T) { + t.Run(tt.Name, func(t *testing.T) { t.Parallel() t.Run("Basic", func(t *testing.T) { @@ -257,7 +257,7 @@ func TestJWT(t *testing.T) { t.Run("WrongSignatureAlgorithm", func(t *testing.T) { t.Parallel() - if tt.name == "JWE" { + if tt.Name == "JWE" { t.Skip("JWE does not support this") } @@ -284,7 +284,7 @@ func TestJWT(t *testing.T) { t.Run("WrongKeyAlgorithm", func(t *testing.T) { t.Parallel() - if tt.name == "JWS" { + if tt.Name == "JWS" { t.Skip("JWS does not support this") } @@ -311,7 +311,7 @@ func TestJWT(t *testing.T) { t.Run("WrongContentyEncryption", func(t *testing.T) { t.Parallel() - if tt.name == "JWS" { + if tt.Name == "JWS" { t.Skip("JWS does not support this") } From 3ba8ad33257b8ace4f35fdc878ded17f72efc450 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Thu, 3 Oct 2024 01:19:16 +0000 Subject: [PATCH 09/25] Refactor JWT to use new crypto key management system - Streamlines key retrieval by integrating with the cryptokeys package. - Simplifies key handling by decoding directly using hex utilities. - Enhances validation by using context for managing key operations. - Improves test robustness with clearer setup and consistent mocking. --- coderd/jwt/jwe.go | 43 ++++- coderd/jwt/jws.go | 43 ++++- coderd/jwt/jwt.go | 59 +------ coderd/jwt/jwt_test.go | 382 ++++++++++++++++++++++++----------------- 4 files changed, 295 insertions(+), 232 deletions(-) diff --git a/coderd/jwt/jwe.go b/coderd/jwt/jwe.go index 15174049f6181..1529d75265916 100644 --- a/coderd/jwt/jwe.go +++ b/coderd/jwt/jwe.go @@ -1,13 +1,18 @@ package jwt import ( + "context" "encoding/base64" + "encoding/hex" "encoding/json" + "strconv" "time" "github.com/go-jose/go-jose/v4" jjwt "github.com/go-jose/go-jose/v4/jwt" "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/cryptokeys" ) const ( @@ -16,22 +21,27 @@ const ( ) // Encrypt encrypts a token and returns it as a string. -func Encrypt(claims Claims, keyFn SecuringKeyFn) (string, error) { - kid, key, err := keyFn() +func Encrypt(ctx context.Context, keys cryptokeys.Keycache, claims Claims) (string, error) { + signing, err := keys.Signing(ctx) + if err != nil { + return "", xerrors.Errorf("get signing key: %w", err) + } + + decoded, err := hex.DecodeString(signing.Secret) if err != nil { - return "", xerrors.Errorf("get key: %w", err) + return "", xerrors.Errorf("decode signing key: %w", err) } encrypter, err := jose.NewEncrypter( encryptContentAlgo, jose.Recipient{ Algorithm: encryptKeyAlgo, - Key: key, + Key: decoded, }, &jose.EncrypterOptions{ Compression: jose.DEFLATE, ExtraHeaders: map[jose.HeaderKey]interface{}{ - keyIDHeaderKey: kid, + keyIDHeaderKey: strconv.FormatInt(int64(signing.Sequence), 10), }, }, ) @@ -54,7 +64,7 @@ func Encrypt(claims Claims, keyFn SecuringKeyFn) (string, error) { } // Decrypt decrypts the token using the provided key. It unmarshals into the provided claims. -func Decrypt(token string, claims Claims, keyFn ParseKeyFunc, opts ...func(*ParseOptions)) error { +func Decrypt(ctx context.Context, keys cryptokeys.Keycache, token string, claims Claims, opts ...func(*ParseOptions)) error { options := ParseOptions{ RegisteredClaims: jjwt.Expected{ Time: time.Now(), @@ -84,12 +94,27 @@ func Decrypt(token string, claims Claims, keyFn ParseKeyFunc, opts ...func(*Pars return xerrors.Errorf("expected API key encryption algorithm to be %q, got %q", encryptKeyAlgo, object.Header.Algorithm) } - key, err := keyFn(object.Header) + sequenceStr := object.Header.KeyID + if sequenceStr == "" { + return xerrors.Errorf("expected %q header to be a string", keyIDHeaderKey) + } + + sequence, err := strconv.ParseInt(sequenceStr, 10, 32) + if err != nil { + return xerrors.Errorf("parse sequence: %w", err) + } + + key, err := keys.Verifying(ctx, int32(sequence)) + if err != nil { + return xerrors.Errorf("version: %w", err) + } + + decoded, err := hex.DecodeString(key.Secret) if err != nil { - return xerrors.Errorf("get key: %w", err) + return xerrors.Errorf("decode key: %w", err) } - decrypted, err := object.Decrypt(key) + decrypted, err := object.Decrypt(decoded) if err != nil { return xerrors.Errorf("decrypt: %w", err) } diff --git a/coderd/jwt/jws.go b/coderd/jwt/jws.go index c9a9c5ce8131e..e26d30b92da69 100644 --- a/coderd/jwt/jws.go +++ b/coderd/jwt/jws.go @@ -1,12 +1,17 @@ package jwt import ( + "context" + "encoding/hex" "encoding/json" + "strconv" "time" "github.com/go-jose/go-jose/v4" jjwt "github.com/go-jose/go-jose/v4/jwt" "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/cryptokeys" ) const ( @@ -15,18 +20,23 @@ const ( ) // Sign signs a token and returns it as a string. -func Sign(claims Claims, keyFn SecuringKeyFn) (string, error) { - kid, key, err := keyFn() +func Sign(ctx context.Context, keys cryptokeys.Keycache, claims Claims) (string, error) { + signing, err := keys.Signing(ctx) + if err != nil { + return "", xerrors.Errorf("get signing key: %w", err) + } + + decoded, err := hex.DecodeString(signing.Secret) if err != nil { - return "", xerrors.Errorf("get key: %w", err) + return "", xerrors.Errorf("decode signing key: %w", err) } signer, err := jose.NewSigner(jose.SigningKey{ Algorithm: defaultSigningAlgo, - Key: key, + Key: decoded, }, &jose.SignerOptions{ ExtraHeaders: map[jose.HeaderKey]interface{}{ - keyIDHeaderKey: kid, + keyIDHeaderKey: strconv.FormatInt(int64(signing.Sequence), 10), }, }) if err != nil { @@ -52,7 +62,7 @@ func Sign(claims Claims, keyFn SecuringKeyFn) (string, error) { } // Verify verifies that a token was signed by the provided key. It unmarshals into the provided claims. -func Verify(token string, claims Claims, keyFn ParseKeyFunc, opts ...func(*ParseOptions)) error { +func Verify(ctx context.Context, keys cryptokeys.Keycache, token string, claims Claims, opts ...func(*ParseOptions)) error { options := ParseOptions{ RegisteredClaims: jjwt.Expected{ Time: time.Now(), @@ -79,12 +89,27 @@ func Verify(token string, claims Claims, keyFn ParseKeyFunc, opts ...func(*Parse return xerrors.Errorf("expected token signing algorithm to be %q, got %q", defaultSigningAlgo, object.Signatures[0].Header.Algorithm) } - key, err := keyFn(signature.Header) + sequenceStr := signature.Header.KeyID + if sequenceStr == "" { + return xerrors.Errorf("expected %q header to be a string", keyIDHeaderKey) + } + + sequence, err := strconv.ParseInt(sequenceStr, 10, 32) + if err != nil { + return xerrors.Errorf("parse sequence: %w", err) + } + + key, err := keys.Verifying(ctx, int32(sequence)) + if err != nil { + return xerrors.Errorf("version: %w", err) + } + + decoded, err := hex.DecodeString(key.Secret) if err != nil { - return xerrors.Errorf("get key: %w", err) + return xerrors.Errorf("decode key: %w", err) } - payload, err := object.Verify(key) + payload, err := object.Verify(decoded) if err != nil { return xerrors.Errorf("verify payload: %w", err) } diff --git a/coderd/jwt/jwt.go b/coderd/jwt/jwt.go index fbcf51603accb..6bf84e4ee46c1 100644 --- a/coderd/jwt/jwt.go +++ b/coderd/jwt/jwt.go @@ -1,71 +1,16 @@ package jwt import ( - "context" - "encoding/hex" - "strconv" - "github.com/go-jose/go-jose/v4" jjwt "github.com/go-jose/go-jose/v4/jwt" - "golang.org/x/xerrors" - - "github.com/coder/coder/v2/coderd/cryptokeys" ) +// Claims defines the payload for a JWT. Most callers +// should ember go-jose/jwt.Claims type Claims interface { Validate(jjwt.Expected) error } -// SecuringKeyFn returns a key for signing or encrypting. -type SecuringKeyFn func() (id string, key interface{}, err error) - -// SecuringKeyFromCache returns the appropriate key for signing or encrypting. -func SecuringKeyFromCache(ctx context.Context, keys cryptokeys.Keycache) SecuringKeyFn { - return func() (id string, key interface{}, err error) { - signing, err := keys.Signing(ctx) - if err != nil { - return "", nil, xerrors.Errorf("get signing key: %w", err) - } - - decoded, err := hex.DecodeString(signing.Secret) - if err != nil { - return "", nil, xerrors.Errorf("decode signing key: %w", err) - } - - return strconv.FormatInt(int64(signing.Sequence), 10), decoded, nil - } -} - -// ParseKeyFunc returns a key for verifying or decrypting a token. -type ParseKeyFunc func(jose.Header) (interface{}, error) - -// ParseKeyFromCache returns the appropriate key to decrypt or verify a token. -func ParseKeyFromCache(ctx context.Context, keys cryptokeys.Keycache) ParseKeyFunc { - return func(header jose.Header) (interface{}, error) { - sequenceStr := header.KeyID - if sequenceStr == "" { - return nil, xerrors.Errorf("expected %q header to be a string", keyIDHeaderKey) - } - - sequence, err := strconv.ParseInt(sequenceStr, 10, 32) - if err != nil { - return nil, xerrors.Errorf("parse sequence: %w", err) - } - - key, err := keys.Verifying(ctx, int32(sequence)) - if err != nil { - return nil, xerrors.Errorf("version: %w", err) - } - - decoded, err := hex.DecodeString(key.Secret) - if err != nil { - return nil, xerrors.Errorf("decode key: %w", err) - } - - return decoded, nil - } -} - // ParseOptions are options for parsing a JWT. type ParseOptions struct { RegisteredClaims jjwt.Expected diff --git a/coderd/jwt/jwt_test.go b/coderd/jwt/jwt_test.go index b9d50d7503e30..ade68ecad2e26 100644 --- a/coderd/jwt/jwt_test.go +++ b/coderd/jwt/jwt_test.go @@ -1,6 +1,7 @@ package jwt_test import ( + "context" "crypto/rand" "encoding/hex" "testing" @@ -8,10 +9,8 @@ import ( "github.com/go-jose/go-jose/v4" jjwt "github.com/go-jose/go-jose/v4/jwt" - "github.com/google/uuid" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" - "golang.org/x/xerrors" "github.com/coder/coder/v2/coderd/cryptokeys" "github.com/coder/coder/v2/coderd/jwt" @@ -24,22 +23,22 @@ func TestJWT(t *testing.T) { type tokenType struct { Name string - SecureFn func(jwt.Claims, jwt.SecuringKeyFn) (string, error) - ParseFn func(string, jwt.Claims, jwt.ParseKeyFunc, ...func(*jwt.ParseOptions)) error + SignFn func(ctx context.Context, keys cryptokeys.Keycache, claims jwt.Claims) (string, error) + VerifyFn func(ctx context.Context, keys cryptokeys.Keycache, token string, claims jwt.Claims, opts ...func(*jwt.ParseOptions)) error KeySize int } types := []tokenType{ { Name: "JWE", - SecureFn: jwt.Encrypt, - ParseFn: jwt.Decrypt, + SignFn: jwt.Encrypt, + VerifyFn: jwt.Decrypt, KeySize: 32, }, { Name: "JWS", - SecureFn: jwt.Sign, - ParseFn: jwt.Verify, + SignFn: jwt.Sign, + VerifyFn: jwt.Verify, KeySize: 64, }, } @@ -50,30 +49,7 @@ func TestJWT(t *testing.T) { t.Run(tt.Name, func(t *testing.T) { t.Parallel() - t.Run("Basic", func(t *testing.T) { - t.Parallel() - - id := uuid.New().String() - key := generateSecret(t, tt.KeySize) - - claims := jjwt.Claims{ - Issuer: "coder", - Subject: "user@coder.com", - Audience: jjwt.Audience{"coder"}, - Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), - IssuedAt: jjwt.NewNumericDate(time.Now()), - NotBefore: jjwt.NewNumericDate(time.Now()), - } - - token, err := tt.SecureFn(claims, securingKeyFn(id, key)) - require.NoError(t, err) - - var actual testClaims - err = tt.ParseFn(token, &actual, verifyingKeyFn(id, key)) - require.NoError(t, err) - }) - - t.Run("Keycache", func(t *testing.T) { + t.Run("OK", func(t *testing.T) { t.Parallel() var ( @@ -85,46 +61,62 @@ func TestJWT(t *testing.T) { key := generateCryptoKey(t, 1234567890, now, tt.KeySize) - keycache.EXPECT().Signing(gomock.Any()).Return(key, nil) - keycache.EXPECT().Verifying(gomock.Any(), key.Sequence).Return(key, nil) - - claims := jjwt.Claims{ - Issuer: "coder", - Subject: "user@coder.com", - Audience: jjwt.Audience{"coder"}, - Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), - IssuedAt: jjwt.NewNumericDate(time.Now()), - NotBefore: jjwt.NewNumericDate(time.Now()), + keycache.EXPECT().Signing(ctx).Return(key, nil) + keycache.EXPECT().Verifying(ctx, key.Sequence).Return(key, nil) + + claims := testClaims{ + Claims: jjwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jjwt.Audience{"coder"}, + Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jjwt.NewNumericDate(time.Now()), + NotBefore: jjwt.NewNumericDate(time.Now()), + }, + MyClaim: "my_value", } - token, err := tt.SecureFn(claims, jwt.SecuringKeyFromCache(ctx, keycache)) + token, err := tt.SignFn(ctx, keycache, claims) require.NoError(t, err) var actual testClaims - err = tt.ParseFn(token, &actual, jwt.ParseKeyFromCache(ctx, keycache)) + err = tt.VerifyFn(ctx, keycache, token, &actual) require.NoError(t, err) + require.Equal(t, claims, actual) }) t.Run("WrongIssuer", func(t *testing.T) { t.Parallel() - id := uuid.New().String() - key := generateSecret(t, tt.KeySize) + var ( + ctx = testutil.Context(t, testutil.WaitShort) + ctrl = gomock.NewController(t) + keycache = cryptokeys.NewMockKeycache(ctrl) + now = time.Now() + ) + + key := generateCryptoKey(t, 1234567890, now, tt.KeySize) - claims := jjwt.Claims{ - Issuer: "coder", - Subject: "user@coder.com", - Audience: jjwt.Audience{"coder"}, - Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), - IssuedAt: jjwt.NewNumericDate(time.Now()), - NotBefore: jjwt.NewNumericDate(time.Now()), + keycache.EXPECT().Signing(ctx).Return(key, nil) + keycache.EXPECT().Verifying(ctx, key.Sequence).Return(key, nil) + + claims := testClaims{ + Claims: jjwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jjwt.Audience{"coder"}, + Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jjwt.NewNumericDate(time.Now()), + NotBefore: jjwt.NewNumericDate(time.Now()), + }, + MyClaim: "my_value", } - token, err := tt.SecureFn(claims, securingKeyFn(id, key)) + token, err := tt.SignFn(ctx, keycache, claims) require.NoError(t, err) var actual testClaims - err = tt.ParseFn(token, &actual, verifyingKeyFn(id, key), withExpected(jjwt.Expected{ + err = tt.VerifyFn(ctx, keycache, token, &actual, withExpected(jjwt.Expected{ Issuer: "coder2", })) require.ErrorIs(t, err, jjwt.ErrInvalidIssuer) @@ -133,23 +125,35 @@ func TestJWT(t *testing.T) { t.Run("WrongSubject", func(t *testing.T) { t.Parallel() - id := uuid.New().String() - key := generateSecret(t, tt.KeySize) + var ( + ctx = testutil.Context(t, testutil.WaitShort) + ctrl = gomock.NewController(t) + keycache = cryptokeys.NewMockKeycache(ctrl) + now = time.Now() + ) + + key := generateCryptoKey(t, 1234567890, now, tt.KeySize) - claims := jjwt.Claims{ - Issuer: "coder", - Subject: "user@coder.com", - Audience: jjwt.Audience{"coder"}, - Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), - IssuedAt: jjwt.NewNumericDate(time.Now()), - NotBefore: jjwt.NewNumericDate(time.Now()), + keycache.EXPECT().Signing(ctx).Return(key, nil) + keycache.EXPECT().Verifying(ctx, key.Sequence).Return(key, nil) + + claims := testClaims{ + Claims: jjwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jjwt.Audience{"coder"}, + Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jjwt.NewNumericDate(time.Now()), + NotBefore: jjwt.NewNumericDate(time.Now()), + }, + MyClaim: "my_value", } - token, err := tt.SecureFn(claims, securingKeyFn(id, key)) + token, err := tt.SignFn(ctx, keycache, claims) require.NoError(t, err) var actual testClaims - err = tt.ParseFn(token, &actual, verifyingKeyFn(id, key), withExpected(jjwt.Expected{ + err = tt.VerifyFn(ctx, keycache, token, &actual, withExpected(jjwt.Expected{ Subject: "user2@coder.com", })) require.ErrorIs(t, err, jjwt.ErrInvalidSubject) @@ -158,23 +162,34 @@ func TestJWT(t *testing.T) { t.Run("WrongAudience", func(t *testing.T) { t.Parallel() - id := uuid.New().String() - key := generateSecret(t, tt.KeySize) + var ( + ctx = testutil.Context(t, testutil.WaitShort) + ctrl = gomock.NewController(t) + keycache = cryptokeys.NewMockKeycache(ctrl) + now = time.Now() + key = generateCryptoKey(t, 1234567890, now, tt.KeySize) + ) - claims := jjwt.Claims{ - Issuer: "coder", - Subject: "user@coder.com", - Audience: jjwt.Audience{"coder"}, - Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), - IssuedAt: jjwt.NewNumericDate(time.Now()), - NotBefore: jjwt.NewNumericDate(time.Now()), + keycache.EXPECT().Signing(ctx).Return(key, nil) + keycache.EXPECT().Verifying(ctx, key.Sequence).Return(key, nil) + + claims := testClaims{ + Claims: jjwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jjwt.Audience{"coder"}, + Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jjwt.NewNumericDate(time.Now()), + NotBefore: jjwt.NewNumericDate(time.Now()), + }, + MyClaim: "my_value", } - token, err := tt.SecureFn(claims, securingKeyFn(id, key)) + token, err := tt.SignFn(ctx, keycache, claims) require.NoError(t, err) var actual testClaims - err = tt.ParseFn(token, &actual, verifyingKeyFn(id, key), withExpected(jjwt.Expected{ + err = tt.VerifyFn(ctx, keycache, token, &actual, withExpected(jjwt.Expected{ AnyAudience: jjwt.Audience{"coder2"}, })) require.ErrorIs(t, err, jjwt.ErrInvalidAudience) @@ -183,23 +198,34 @@ func TestJWT(t *testing.T) { t.Run("Expired", func(t *testing.T) { t.Parallel() - id := uuid.New().String() - key := generateSecret(t, tt.KeySize) + var ( + ctx = testutil.Context(t, testutil.WaitShort) + ctrl = gomock.NewController(t) + keycache = cryptokeys.NewMockKeycache(ctrl) + now = time.Now() + key = generateCryptoKey(t, 1234567890, now, tt.KeySize) + ) - claims := jjwt.Claims{ - Issuer: "coder", - Subject: "user@coder.com", - Audience: jjwt.Audience{"coder"}, - Expiry: jjwt.NewNumericDate(time.Now().Add(time.Minute)), - IssuedAt: jjwt.NewNumericDate(time.Now()), - NotBefore: jjwt.NewNumericDate(time.Now()), + keycache.EXPECT().Signing(ctx).Return(key, nil) + keycache.EXPECT().Verifying(ctx, key.Sequence).Return(key, nil) + + claims := testClaims{ + Claims: jjwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jjwt.Audience{"coder"}, + Expiry: jjwt.NewNumericDate(time.Now().Add(time.Minute)), + IssuedAt: jjwt.NewNumericDate(time.Now()), + NotBefore: jjwt.NewNumericDate(time.Now()), + }, + MyClaim: "my_value", } - token, err := tt.SecureFn(claims, securingKeyFn(id, key)) + token, err := tt.SignFn(ctx, keycache, claims) require.NoError(t, err) var actual testClaims - err = tt.ParseFn(token, &actual, verifyingKeyFn(id, key), withExpected(jjwt.Expected{ + err = tt.VerifyFn(ctx, keycache, token, &actual, withExpected(jjwt.Expected{ Time: time.Now().Add(time.Minute * 3), })) require.ErrorIs(t, err, jjwt.ErrExpired) @@ -208,22 +234,34 @@ func TestJWT(t *testing.T) { t.Run("IssuedInFuture", func(t *testing.T) { t.Parallel() - id := uuid.New().String() - key := generateSecret(t, tt.KeySize) + var ( + ctx = testutil.Context(t, testutil.WaitShort) + ctrl = gomock.NewController(t) + keycache = cryptokeys.NewMockKeycache(ctrl) + now = time.Now() + ) + + key := generateCryptoKey(t, 1234567890, now, tt.KeySize) - claims := jjwt.Claims{ - Issuer: "coder", - Subject: "user@coder.com", - Audience: jjwt.Audience{"coder"}, - Expiry: jjwt.NewNumericDate(time.Now().Add(time.Minute)), - IssuedAt: jjwt.NewNumericDate(time.Now()), + keycache.EXPECT().Signing(ctx).Return(key, nil) + keycache.EXPECT().Verifying(ctx, key.Sequence).Return(key, nil) + + claims := testClaims{ + Claims: jjwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jjwt.Audience{"coder"}, + Expiry: jjwt.NewNumericDate(time.Now().Add(time.Minute)), + IssuedAt: jjwt.NewNumericDate(time.Now()), + }, + MyClaim: "my_value", } - token, err := tt.SecureFn(claims, securingKeyFn(id, key)) + token, err := tt.SignFn(ctx, keycache, claims) require.NoError(t, err) var actual testClaims - err = tt.ParseFn(token, &actual, verifyingKeyFn(id, key), withExpected(jjwt.Expected{ + err = tt.VerifyFn(ctx, keycache, token, &actual, withExpected(jjwt.Expected{ Time: time.Now().Add(-time.Minute * 3), })) require.ErrorIs(t, err, jjwt.ErrIssuedInTheFuture) @@ -232,23 +270,35 @@ func TestJWT(t *testing.T) { t.Run("IsBefore", func(t *testing.T) { t.Parallel() - id := uuid.New().String() - key := generateSecret(t, tt.KeySize) + var ( + ctx = testutil.Context(t, testutil.WaitShort) + ctrl = gomock.NewController(t) + keycache = cryptokeys.NewMockKeycache(ctrl) + now = time.Now() + ) + + key := generateCryptoKey(t, 1234567890, now, tt.KeySize) - claims := jjwt.Claims{ - Issuer: "coder", - Subject: "user@coder.com", - Audience: jjwt.Audience{"coder"}, - Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), - IssuedAt: jjwt.NewNumericDate(time.Now()), - NotBefore: jjwt.NewNumericDate(time.Now().Add(time.Minute * 5)), + keycache.EXPECT().Signing(ctx).Return(key, nil) + keycache.EXPECT().Verifying(ctx, key.Sequence).Return(key, nil) + + claims := testClaims{ + Claims: jjwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jjwt.Audience{"coder"}, + Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jjwt.NewNumericDate(time.Now()), + NotBefore: jjwt.NewNumericDate(time.Now().Add(time.Minute * 5)), + }, + MyClaim: "my_value", } - token, err := tt.SecureFn(claims, securingKeyFn(id, key)) + token, err := tt.SignFn(ctx, keycache, claims) require.NoError(t, err) var actual testClaims - err = tt.ParseFn(token, &actual, verifyingKeyFn(id, key), withExpected(jjwt.Expected{ + err = tt.VerifyFn(ctx, keycache, token, &actual, withExpected(jjwt.Expected{ Time: time.Now().Add(time.Minute * 3), })) require.ErrorIs(t, err, jjwt.ErrNotValidYet) @@ -261,23 +311,34 @@ func TestJWT(t *testing.T) { t.Skip("JWE does not support this") } - id := uuid.New().String() - key := generateSecret(t, tt.KeySize) + var ( + ctx = testutil.Context(t, testutil.WaitShort) + ctrl = gomock.NewController(t) + keycache = cryptokeys.NewMockKeycache(ctrl) + now = time.Now() + ) + + key := generateCryptoKey(t, 1234567890, now, tt.KeySize) - claims := jjwt.Claims{ - Issuer: "coder", - Subject: "user@coder.com", - Audience: jjwt.Audience{"coder"}, - Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), - IssuedAt: jjwt.NewNumericDate(time.Now()), - NotBefore: jjwt.NewNumericDate(time.Now().Add(time.Minute * 5)), + keycache.EXPECT().Signing(ctx).Return(key, nil) + + claims := testClaims{ + Claims: jjwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jjwt.Audience{"coder"}, + Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jjwt.NewNumericDate(time.Now()), + NotBefore: jjwt.NewNumericDate(time.Now().Add(time.Minute * 5)), + }, + MyClaim: "my_value", } - token, err := tt.SecureFn(claims, securingKeyFn(id, key)) + token, err := tt.SignFn(ctx, keycache, claims) require.NoError(t, err) var actual testClaims - err = tt.ParseFn(token, &actual, verifyingKeyFn(id, key), withSignatureAlgorithm(jose.HS256)) + err = tt.VerifyFn(ctx, keycache, token, &actual, withSignatureAlgorithm(jose.HS256)) require.Error(t, err) }) @@ -288,23 +349,34 @@ func TestJWT(t *testing.T) { t.Skip("JWS does not support this") } - id := uuid.New().String() - key := generateSecret(t, tt.KeySize) + var ( + ctx = testutil.Context(t, testutil.WaitShort) + ctrl = gomock.NewController(t) + keycache = cryptokeys.NewMockKeycache(ctrl) + now = time.Now() + ) + + key := generateCryptoKey(t, 1234567890, now, tt.KeySize) - claims := jjwt.Claims{ - Issuer: "coder", - Subject: "user@coder.com", - Audience: jjwt.Audience{"coder"}, - Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), - IssuedAt: jjwt.NewNumericDate(time.Now()), - NotBefore: jjwt.NewNumericDate(time.Now().Add(time.Minute * 5)), + keycache.EXPECT().Signing(ctx).Return(key, nil) + + claims := testClaims{ + Claims: jjwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jjwt.Audience{"coder"}, + Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jjwt.NewNumericDate(time.Now()), + NotBefore: jjwt.NewNumericDate(time.Now().Add(time.Minute * 5)), + }, + MyClaim: "my_value", } - token, err := tt.SecureFn(claims, securingKeyFn(id, key)) + token, err := tt.SignFn(ctx, keycache, claims) require.NoError(t, err) var actual testClaims - err = tt.ParseFn(token, &actual, verifyingKeyFn(id, key), withKeyAlgorithm(jose.A128GCMKW)) + err = tt.VerifyFn(ctx, keycache, token, &actual, withKeyAlgorithm(jose.A128GCMKW)) require.Error(t, err) }) @@ -315,23 +387,34 @@ func TestJWT(t *testing.T) { t.Skip("JWS does not support this") } - id := uuid.New().String() - key := generateSecret(t, tt.KeySize) + var ( + ctx = testutil.Context(t, testutil.WaitShort) + ctrl = gomock.NewController(t) + keycache = cryptokeys.NewMockKeycache(ctrl) + now = time.Now() + ) + + key := generateCryptoKey(t, 1234567890, now, tt.KeySize) + + keycache.EXPECT().Signing(gomock.Any()).Return(key, nil) - claims := jjwt.Claims{ - Issuer: "coder", - Subject: "user@coder.com", - Audience: jjwt.Audience{"coder"}, - Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), - IssuedAt: jjwt.NewNumericDate(time.Now()), - NotBefore: jjwt.NewNumericDate(time.Now().Add(time.Minute * 5)), + claims := testClaims{ + Claims: jjwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jjwt.Audience{"coder"}, + Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jjwt.NewNumericDate(time.Now()), + NotBefore: jjwt.NewNumericDate(time.Now().Add(time.Minute * 5)), + }, + MyClaim: "my_value", } - token, err := tt.SecureFn(claims, securingKeyFn(id, key)) + token, err := tt.SignFn(ctx, keycache, claims) require.NoError(t, err) var actual testClaims - err = tt.ParseFn(token, &actual, verifyingKeyFn(id, key), withContentEncryptionAlgorithm(jose.A128GCM)) + err = tt.VerifyFn(ctx, keycache, token, &actual, withContentEncryptionAlgorithm(jose.A128GCM)) require.Error(t, err) }) }) @@ -365,21 +448,6 @@ type testClaims struct { jjwt.Claims } -func securingKeyFn(id string, key []byte) jwt.SecuringKeyFn { - return func() (string, interface{}, error) { - return id, key, nil - } -} - -func verifyingKeyFn(id string, key []byte) jwt.ParseKeyFunc { - return func(header jose.Header) (interface{}, error) { - if header.KeyID != id { - return nil, xerrors.Errorf("expected key ID %q, got %q", id, header.KeyID) - } - return key, nil - } -} - func withExpected(e jjwt.Expected) func(*jwt.ParseOptions) { return func(opts *jwt.ParseOptions) { opts.RegisteredClaims = e From 73c902ca24d47da88f66b211e0581629990dc9b4 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Thu, 3 Oct 2024 01:26:37 +0000 Subject: [PATCH 10/25] Refactor JWT package for improved modularity and clarity - Rename package `jwt` to `jwtutils` for better context. - Consolidate constants and enhance error messages. - Simplify key management for JWT signing and verification. --- coderd/jwt/jwe.go | 10 +++++----- coderd/jwt/jws.go | 21 ++++++++++----------- coderd/jwt/jwt.go | 14 +++++++++----- coderd/jwt/jwt_test.go | 2 +- 4 files changed, 25 insertions(+), 22 deletions(-) diff --git a/coderd/jwt/jwe.go b/coderd/jwt/jwe.go index 1529d75265916..222b38e27a3d0 100644 --- a/coderd/jwt/jwe.go +++ b/coderd/jwt/jwe.go @@ -1,4 +1,4 @@ -package jwt +package jwtutils import ( "context" @@ -9,7 +9,7 @@ import ( "time" "github.com/go-jose/go-jose/v4" - jjwt "github.com/go-jose/go-jose/v4/jwt" + "github.com/go-jose/go-jose/v4/jwt" "golang.org/x/xerrors" "github.com/coder/coder/v2/coderd/cryptokeys" @@ -66,7 +66,7 @@ func Encrypt(ctx context.Context, keys cryptokeys.Keycache, claims Claims) (stri // Decrypt decrypts the token using the provided key. It unmarshals into the provided claims. func Decrypt(ctx context.Context, keys cryptokeys.Keycache, token string, claims Claims, opts ...func(*ParseOptions)) error { options := ParseOptions{ - RegisteredClaims: jjwt.Expected{ + RegisteredClaims: jwt.Expected{ Time: time.Now(), }, KeyAlgorithm: encryptKeyAlgo, @@ -87,11 +87,11 @@ func Decrypt(ctx context.Context, keys cryptokeys.Keycache, token string, claims []jose.ContentEncryption{options.ContentEncryptionAlgorithm}, ) if err != nil { - return xerrors.Errorf("parse encrypted API key: %w", err) + return xerrors.Errorf("parse jwe: %w", err) } if object.Header.Algorithm != string(encryptKeyAlgo) { - return xerrors.Errorf("expected API key encryption algorithm to be %q, got %q", encryptKeyAlgo, object.Header.Algorithm) + return xerrors.Errorf("expected JWE algorithm to be %q, got %q", encryptKeyAlgo, object.Header.Algorithm) } sequenceStr := object.Header.KeyID diff --git a/coderd/jwt/jws.go b/coderd/jwt/jws.go index e26d30b92da69..462846a1cb202 100644 --- a/coderd/jwt/jws.go +++ b/coderd/jwt/jws.go @@ -1,4 +1,4 @@ -package jwt +package jwtutils import ( "context" @@ -8,15 +8,14 @@ import ( "time" "github.com/go-jose/go-jose/v4" - jjwt "github.com/go-jose/go-jose/v4/jwt" + "github.com/go-jose/go-jose/v4/jwt" "golang.org/x/xerrors" "github.com/coder/coder/v2/coderd/cryptokeys" ) const ( - defaultSigningAlgo = jose.HS512 - keyIDHeaderKey = "kid" + signingAlgo = jose.HS512 ) // Sign signs a token and returns it as a string. @@ -32,7 +31,7 @@ func Sign(ctx context.Context, keys cryptokeys.Keycache, claims Claims) (string, } signer, err := jose.NewSigner(jose.SigningKey{ - Algorithm: defaultSigningAlgo, + Algorithm: signingAlgo, Key: decoded, }, &jose.SignerOptions{ ExtraHeaders: map[jose.HeaderKey]interface{}{ @@ -64,10 +63,10 @@ func Sign(ctx context.Context, keys cryptokeys.Keycache, claims Claims) (string, // Verify verifies that a token was signed by the provided key. It unmarshals into the provided claims. func Verify(ctx context.Context, keys cryptokeys.Keycache, token string, claims Claims, opts ...func(*ParseOptions)) error { options := ParseOptions{ - RegisteredClaims: jjwt.Expected{ + RegisteredClaims: jwt.Expected{ Time: time.Now(), }, - SignatureAlgorithm: defaultSigningAlgo, + SignatureAlgorithm: signingAlgo, } for _, opt := range opts { @@ -85,8 +84,8 @@ func Verify(ctx context.Context, keys cryptokeys.Keycache, token string, claims signature := object.Signatures[0] - if signature.Header.Algorithm != string(defaultSigningAlgo) { - return xerrors.Errorf("expected token signing algorithm to be %q, got %q", defaultSigningAlgo, object.Signatures[0].Header.Algorithm) + if signature.Header.Algorithm != string(signingAlgo) { + return xerrors.Errorf("expected JWS algorithm to be %q, got %q", signingAlgo, object.Signatures[0].Header.Algorithm) } sequenceStr := signature.Header.KeyID @@ -96,12 +95,12 @@ func Verify(ctx context.Context, keys cryptokeys.Keycache, token string, claims sequence, err := strconv.ParseInt(sequenceStr, 10, 32) if err != nil { - return xerrors.Errorf("parse sequence: %w", err) + return xerrors.Errorf("parse sequence %q: %w", sequenceStr, err) } key, err := keys.Verifying(ctx, int32(sequence)) if err != nil { - return xerrors.Errorf("version: %w", err) + return xerrors.Errorf("verifying key for seq %v: %w", sequence, err) } decoded, err := hex.DecodeString(key.Secret) diff --git a/coderd/jwt/jwt.go b/coderd/jwt/jwt.go index 6bf84e4ee46c1..eb8bbba915813 100644 --- a/coderd/jwt/jwt.go +++ b/coderd/jwt/jwt.go @@ -1,19 +1,23 @@ -package jwt +package jwtutils import ( "github.com/go-jose/go-jose/v4" - jjwt "github.com/go-jose/go-jose/v4/jwt" + "github.com/go-jose/go-jose/v4/jwt" +) + +const ( + keyIDHeaderKey = "kid" ) // Claims defines the payload for a JWT. Most callers -// should ember go-jose/jwt.Claims +// should embed jwt.Claims type Claims interface { - Validate(jjwt.Expected) error + Validate(jwt.Expected) error } // ParseOptions are options for parsing a JWT. type ParseOptions struct { - RegisteredClaims jjwt.Expected + RegisteredClaims jwt.Expected // The following are only used for JWSs. SignatureAlgorithm jose.SignatureAlgorithm diff --git a/coderd/jwt/jwt_test.go b/coderd/jwt/jwt_test.go index ade68ecad2e26..af1a9e5b8edb1 100644 --- a/coderd/jwt/jwt_test.go +++ b/coderd/jwt/jwt_test.go @@ -1,4 +1,4 @@ -package jwt_test +package jwtutils_test import ( "context" From e348a7afcc311d2436dae4a4a29157014ea7d78c Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Thu, 3 Oct 2024 02:15:30 +0000 Subject: [PATCH 11/25] mv dir --- coderd/{jwt => jwtutils}/jwe.go | 0 coderd/{jwt => jwtutils}/jws.go | 0 coderd/{jwt => jwtutils}/jwt.go | 0 coderd/{jwt => jwtutils}/jwt_test.go | 102 +++++++++++++-------------- 4 files changed, 51 insertions(+), 51 deletions(-) rename coderd/{jwt => jwtutils}/jwe.go (100%) rename coderd/{jwt => jwtutils}/jws.go (100%) rename coderd/{jwt => jwtutils}/jwt.go (100%) rename coderd/{jwt => jwtutils}/jwt_test.go (82%) diff --git a/coderd/jwt/jwe.go b/coderd/jwtutils/jwe.go similarity index 100% rename from coderd/jwt/jwe.go rename to coderd/jwtutils/jwe.go diff --git a/coderd/jwt/jws.go b/coderd/jwtutils/jws.go similarity index 100% rename from coderd/jwt/jws.go rename to coderd/jwtutils/jws.go diff --git a/coderd/jwt/jwt.go b/coderd/jwtutils/jwt.go similarity index 100% rename from coderd/jwt/jwt.go rename to coderd/jwtutils/jwt.go diff --git a/coderd/jwt/jwt_test.go b/coderd/jwtutils/jwt_test.go similarity index 82% rename from coderd/jwt/jwt_test.go rename to coderd/jwtutils/jwt_test.go index af1a9e5b8edb1..c068dc7e3d9b3 100644 --- a/coderd/jwt/jwt_test.go +++ b/coderd/jwtutils/jwt_test.go @@ -7,13 +7,13 @@ import ( "testing" "time" - "github.com/go-jose/go-jose/v4" - jjwt "github.com/go-jose/go-jose/v4/jwt" + "github.com/go-jose/go-jose/v4/jwt" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" + "gopkg.in/square/go-jose.v2" "github.com/coder/coder/v2/coderd/cryptokeys" - "github.com/coder/coder/v2/coderd/jwt" + "github.com/coder/coder/v2/coderd/jwtutils" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/testutil" ) @@ -23,22 +23,22 @@ func TestJWT(t *testing.T) { type tokenType struct { Name string - SignFn func(ctx context.Context, keys cryptokeys.Keycache, claims jwt.Claims) (string, error) - VerifyFn func(ctx context.Context, keys cryptokeys.Keycache, token string, claims jwt.Claims, opts ...func(*jwt.ParseOptions)) error + SignFn func(ctx context.Context, keys cryptokeys.Keycache, claims jwtutils.Claims) (string, error) + VerifyFn func(ctx context.Context, keys cryptokeys.Keycache, token string, claims jwtutils.Claims, opts ...func(*jwtutils.ParseOptions)) error KeySize int } types := []tokenType{ { Name: "JWE", - SignFn: jwt.Encrypt, - VerifyFn: jwt.Decrypt, + SignFn: jwtutils.Encrypt, + VerifyFn: jwtutils.Decrypt, KeySize: 32, }, { Name: "JWS", - SignFn: jwt.Sign, - VerifyFn: jwt.Verify, + SignFn: jwtutils.Sign, + VerifyFn: jwtutils.Verify, KeySize: 64, }, } @@ -65,13 +65,13 @@ func TestJWT(t *testing.T) { keycache.EXPECT().Verifying(ctx, key.Sequence).Return(key, nil) claims := testClaims{ - Claims: jjwt.Claims{ + Claims: jwt.Claims{ Issuer: "coder", Subject: "user@coder.com", - Audience: jjwt.Audience{"coder"}, - Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), - IssuedAt: jjwt.NewNumericDate(time.Now()), - NotBefore: jjwt.NewNumericDate(time.Now()), + Audience: jwt.Audience{"coder"}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now()), }, MyClaim: "my_value", } @@ -101,13 +101,13 @@ func TestJWT(t *testing.T) { keycache.EXPECT().Verifying(ctx, key.Sequence).Return(key, nil) claims := testClaims{ - Claims: jjwt.Claims{ + Claims: jwt.Claims{ Issuer: "coder", Subject: "user@coder.com", - Audience: jjwt.Audience{"coder"}, - Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), - IssuedAt: jjwt.NewNumericDate(time.Now()), - NotBefore: jjwt.NewNumericDate(time.Now()), + Audience: jwt.Audience{"coder"}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now()), }, MyClaim: "my_value", } @@ -138,13 +138,13 @@ func TestJWT(t *testing.T) { keycache.EXPECT().Verifying(ctx, key.Sequence).Return(key, nil) claims := testClaims{ - Claims: jjwt.Claims{ + Claims: jwt.Claims{ Issuer: "coder", Subject: "user@coder.com", - Audience: jjwt.Audience{"coder"}, - Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), - IssuedAt: jjwt.NewNumericDate(time.Now()), - NotBefore: jjwt.NewNumericDate(time.Now()), + Audience: jwt.Audience{"coder"}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now()), }, MyClaim: "my_value", } @@ -174,13 +174,13 @@ func TestJWT(t *testing.T) { keycache.EXPECT().Verifying(ctx, key.Sequence).Return(key, nil) claims := testClaims{ - Claims: jjwt.Claims{ + Claims: jwt.Claims{ Issuer: "coder", Subject: "user@coder.com", - Audience: jjwt.Audience{"coder"}, - Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), - IssuedAt: jjwt.NewNumericDate(time.Now()), - NotBefore: jjwt.NewNumericDate(time.Now()), + Audience: jwt.Audience{"coder"}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now()), }, MyClaim: "my_value", } @@ -210,13 +210,13 @@ func TestJWT(t *testing.T) { keycache.EXPECT().Verifying(ctx, key.Sequence).Return(key, nil) claims := testClaims{ - Claims: jjwt.Claims{ + Claims: jwt.Claims{ Issuer: "coder", Subject: "user@coder.com", - Audience: jjwt.Audience{"coder"}, - Expiry: jjwt.NewNumericDate(time.Now().Add(time.Minute)), - IssuedAt: jjwt.NewNumericDate(time.Now()), - NotBefore: jjwt.NewNumericDate(time.Now()), + Audience: jwt.Audience{"coder"}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Minute)), + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now()), }, MyClaim: "my_value", } @@ -247,12 +247,12 @@ func TestJWT(t *testing.T) { keycache.EXPECT().Verifying(ctx, key.Sequence).Return(key, nil) claims := testClaims{ - Claims: jjwt.Claims{ + Claims: jwt.Claims{ Issuer: "coder", Subject: "user@coder.com", - Audience: jjwt.Audience{"coder"}, - Expiry: jjwt.NewNumericDate(time.Now().Add(time.Minute)), - IssuedAt: jjwt.NewNumericDate(time.Now()), + Audience: jwt.Audience{"coder"}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Minute)), + IssuedAt: jwt.NewNumericDate(time.Now()), }, MyClaim: "my_value", } @@ -283,13 +283,13 @@ func TestJWT(t *testing.T) { keycache.EXPECT().Verifying(ctx, key.Sequence).Return(key, nil) claims := testClaims{ - Claims: jjwt.Claims{ + Claims: jwt.Claims{ Issuer: "coder", Subject: "user@coder.com", - Audience: jjwt.Audience{"coder"}, - Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), - IssuedAt: jjwt.NewNumericDate(time.Now()), - NotBefore: jjwt.NewNumericDate(time.Now().Add(time.Minute * 5)), + Audience: jwt.Audience{"coder"}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now().Add(time.Minute * 5)), }, MyClaim: "my_value", } @@ -327,9 +327,9 @@ func TestJWT(t *testing.T) { Issuer: "coder", Subject: "user@coder.com", Audience: jjwt.Audience{"coder"}, - Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), - IssuedAt: jjwt.NewNumericDate(time.Now()), - NotBefore: jjwt.NewNumericDate(time.Now().Add(time.Minute * 5)), + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now().Add(time.Minute * 5)), }, MyClaim: "my_value", } @@ -399,13 +399,13 @@ func TestJWT(t *testing.T) { keycache.EXPECT().Signing(gomock.Any()).Return(key, nil) claims := testClaims{ - Claims: jjwt.Claims{ + Claims: jwt.Claims{ Issuer: "coder", Subject: "user@coder.com", - Audience: jjwt.Audience{"coder"}, - Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), - IssuedAt: jjwt.NewNumericDate(time.Now()), - NotBefore: jjwt.NewNumericDate(time.Now().Add(time.Minute * 5)), + Audience: jwt.Audience{"coder"}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now().Add(time.Minute * 5)), }, MyClaim: "my_value", } From c7489b40879ec7692d5d82ddf96d7431bdde054d Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Thu, 3 Oct 2024 02:18:24 +0000 Subject: [PATCH 12/25] update references --- coderd/jwtutils/jwt_test.go | 60 ++++++++++++++++++------------------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/coderd/jwtutils/jwt_test.go b/coderd/jwtutils/jwt_test.go index c068dc7e3d9b3..a28314bfc91c5 100644 --- a/coderd/jwtutils/jwt_test.go +++ b/coderd/jwtutils/jwt_test.go @@ -7,10 +7,10 @@ import ( "testing" "time" + "github.com/go-jose/go-jose/v4" "github.com/go-jose/go-jose/v4/jwt" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" - "gopkg.in/square/go-jose.v2" "github.com/coder/coder/v2/coderd/cryptokeys" "github.com/coder/coder/v2/coderd/jwtutils" @@ -116,10 +116,10 @@ func TestJWT(t *testing.T) { require.NoError(t, err) var actual testClaims - err = tt.VerifyFn(ctx, keycache, token, &actual, withExpected(jjwt.Expected{ + err = tt.VerifyFn(ctx, keycache, token, &actual, withExpected(jwt.Expected{ Issuer: "coder2", })) - require.ErrorIs(t, err, jjwt.ErrInvalidIssuer) + require.ErrorIs(t, err, jwt.ErrInvalidIssuer) }) t.Run("WrongSubject", func(t *testing.T) { @@ -153,10 +153,10 @@ func TestJWT(t *testing.T) { require.NoError(t, err) var actual testClaims - err = tt.VerifyFn(ctx, keycache, token, &actual, withExpected(jjwt.Expected{ + err = tt.VerifyFn(ctx, keycache, token, &actual, withExpected(jwt.Expected{ Subject: "user2@coder.com", })) - require.ErrorIs(t, err, jjwt.ErrInvalidSubject) + require.ErrorIs(t, err, jwt.ErrInvalidSubject) }) t.Run("WrongAudience", func(t *testing.T) { @@ -189,10 +189,10 @@ func TestJWT(t *testing.T) { require.NoError(t, err) var actual testClaims - err = tt.VerifyFn(ctx, keycache, token, &actual, withExpected(jjwt.Expected{ - AnyAudience: jjwt.Audience{"coder2"}, + err = tt.VerifyFn(ctx, keycache, token, &actual, withExpected(jwt.Expected{ + AnyAudience: jwt.Audience{"coder2"}, })) - require.ErrorIs(t, err, jjwt.ErrInvalidAudience) + require.ErrorIs(t, err, jwt.ErrInvalidAudience) }) t.Run("Expired", func(t *testing.T) { @@ -225,10 +225,10 @@ func TestJWT(t *testing.T) { require.NoError(t, err) var actual testClaims - err = tt.VerifyFn(ctx, keycache, token, &actual, withExpected(jjwt.Expected{ + err = tt.VerifyFn(ctx, keycache, token, &actual, withExpected(jwt.Expected{ Time: time.Now().Add(time.Minute * 3), })) - require.ErrorIs(t, err, jjwt.ErrExpired) + require.ErrorIs(t, err, jwt.ErrExpired) }) t.Run("IssuedInFuture", func(t *testing.T) { @@ -261,10 +261,10 @@ func TestJWT(t *testing.T) { require.NoError(t, err) var actual testClaims - err = tt.VerifyFn(ctx, keycache, token, &actual, withExpected(jjwt.Expected{ + err = tt.VerifyFn(ctx, keycache, token, &actual, withExpected(jwt.Expected{ Time: time.Now().Add(-time.Minute * 3), })) - require.ErrorIs(t, err, jjwt.ErrIssuedInTheFuture) + require.ErrorIs(t, err, jwt.ErrIssuedInTheFuture) }) t.Run("IsBefore", func(t *testing.T) { @@ -298,10 +298,10 @@ func TestJWT(t *testing.T) { require.NoError(t, err) var actual testClaims - err = tt.VerifyFn(ctx, keycache, token, &actual, withExpected(jjwt.Expected{ + err = tt.VerifyFn(ctx, keycache, token, &actual, withExpected(jwt.Expected{ Time: time.Now().Add(time.Minute * 3), })) - require.ErrorIs(t, err, jjwt.ErrNotValidYet) + require.ErrorIs(t, err, jwt.ErrNotValidYet) }) t.Run("WrongSignatureAlgorithm", func(t *testing.T) { @@ -323,10 +323,10 @@ func TestJWT(t *testing.T) { keycache.EXPECT().Signing(ctx).Return(key, nil) claims := testClaims{ - Claims: jjwt.Claims{ + Claims: jwt.Claims{ Issuer: "coder", Subject: "user@coder.com", - Audience: jjwt.Audience{"coder"}, + Audience: jwt.Audience{"coder"}, Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), IssuedAt: jwt.NewNumericDate(time.Now()), NotBefore: jwt.NewNumericDate(time.Now().Add(time.Minute * 5)), @@ -361,13 +361,13 @@ func TestJWT(t *testing.T) { keycache.EXPECT().Signing(ctx).Return(key, nil) claims := testClaims{ - Claims: jjwt.Claims{ + Claims: jwt.Claims{ Issuer: "coder", Subject: "user@coder.com", - Audience: jjwt.Audience{"coder"}, - Expiry: jjwt.NewNumericDate(time.Now().Add(time.Hour)), - IssuedAt: jjwt.NewNumericDate(time.Now()), - NotBefore: jjwt.NewNumericDate(time.Now().Add(time.Minute * 5)), + Audience: jwt.Audience{"coder"}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now().Add(time.Minute * 5)), }, MyClaim: "my_value", } @@ -445,29 +445,29 @@ func generateSecret(t *testing.T, keySize int) []byte { type testClaims struct { MyClaim string `json:"my_claim"` - jjwt.Claims + jwt.Claims } -func withExpected(e jjwt.Expected) func(*jwt.ParseOptions) { - return func(opts *jwt.ParseOptions) { +func withExpected(e jwt.Expected) func(*jwtutils.ParseOptions) { + return func(opts *jwtutils.ParseOptions) { opts.RegisteredClaims = e } } -func withSignatureAlgorithm(alg jose.SignatureAlgorithm) func(*jwt.ParseOptions) { - return func(opts *jwt.ParseOptions) { +func withSignatureAlgorithm(alg jose.SignatureAlgorithm) func(*jwtutils.ParseOptions) { + return func(opts *jwtutils.ParseOptions) { opts.SignatureAlgorithm = alg } } -func withKeyAlgorithm(alg jose.KeyAlgorithm) func(*jwt.ParseOptions) { - return func(opts *jwt.ParseOptions) { +func withKeyAlgorithm(alg jose.KeyAlgorithm) func(*jwtutils.ParseOptions) { + return func(opts *jwtutils.ParseOptions) { opts.KeyAlgorithm = alg } } -func withContentEncryptionAlgorithm(alg jose.ContentEncryption) func(*jwt.ParseOptions) { - return func(opts *jwt.ParseOptions) { +func withContentEncryptionAlgorithm(alg jose.ContentEncryption) func(*jwtutils.ParseOptions) { + return func(opts *jwtutils.ParseOptions) { opts.ContentEncryptionAlgorithm = alg } } From d890ea2fdf05bf0cb0accad9bfb4e4fe493aa1c2 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Thu, 3 Oct 2024 04:24:36 +0000 Subject: [PATCH 13/25] refactor interfaces --- coderd/cryptokeys/dbkeycache.go | 99 +++++++++++++++---- coderd/cryptokeys/dbkeycache_internal_test.go | 3 +- coderd/cryptokeys/keycache.go | 18 +++- coderd/jwtutils/jwe.go | 41 +++----- coderd/jwtutils/jws.go | 35 +++---- coderd/jwtutils/jwt_test.go | 57 ++++++++--- 6 files changed, 168 insertions(+), 85 deletions(-) diff --git a/coderd/cryptokeys/dbkeycache.go b/coderd/cryptokeys/dbkeycache.go index 4986f1669c4e5..f4c60e7fa95b9 100644 --- a/coderd/cryptokeys/dbkeycache.go +++ b/coderd/cryptokeys/dbkeycache.go @@ -2,6 +2,7 @@ package cryptokeys import ( "context" + "strconv" "sync" "time" @@ -9,8 +10,6 @@ import ( "cdr.dev/slog" "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/db2sdk" - "github.com/coder/coder/v2/codersdk" "github.com/coder/quartz" ) @@ -61,18 +60,54 @@ func NewDBCache(logger slog.Logger, db database.Store, feature database.CryptoKe return d } +func (d *DBCache) EncryptingKey(ctx context.Context) (id string, key interface{}, err error) { + if !isEncryptionKeyFeature(d.feature) { + return "", nil, xerrors.Errorf("invalid feature: %s", d.feature) + } + return d.Signing(ctx) +} + +func (d *DBCache) DecryptingKey(ctx context.Context, id string) (key interface{}, err error) { + if !isEncryptionKeyFeature(d.feature) { + return nil, xerrors.Errorf("invalid feature: %s", d.feature) + } + + return d.Verifying(ctx, id) +} + +func (d *DBCache) SigningKey(ctx context.Context) (id string, key interface{}, err error) { + if !isSigningKeyFeature(d.feature) { + return "", nil, xerrors.Errorf("invalid feature: %s", d.feature) + } + + return d.Signing(ctx) +} + +func (d *DBCache) VerifyingKey(ctx context.Context, id string) (key interface{}, err error) { + if !isSigningKeyFeature(d.feature) { + return nil, xerrors.Errorf("invalid feature: %s", d.feature) + } + + return d.Verifying(ctx, id) +} + // Verifying returns the CryptoKey with the given sequence number, provided that // it is neither deleted nor has breached its deletion date. It should only be // used for verifying or decrypting payloads. To sign/encrypt call Signing. -func (d *DBCache) Verifying(ctx context.Context, sequence int32) (codersdk.CryptoKey, error) { +func (d *DBCache) Verifying(ctx context.Context, id string) (interface{}, error) { + sequence, err := strconv.ParseInt(id, 10, 32) + if err != nil { + return nil, xerrors.Errorf("expecting sequence number got %q: %w", id, err) + } + d.keysMu.RLock() if d.closed { d.keysMu.RUnlock() - return codersdk.CryptoKey{}, ErrClosed + return nil, ErrClosed } now := d.clock.Now() - key, ok := d.keys[sequence] + key, ok := d.keys[int32(sequence)] d.keysMu.RUnlock() if ok { return checkKey(key, now) @@ -82,22 +117,22 @@ func (d *DBCache) Verifying(ctx context.Context, sequence int32) (codersdk.Crypt defer d.keysMu.Unlock() if d.closed { - return codersdk.CryptoKey{}, ErrClosed + return nil, ErrClosed } - key, ok = d.keys[sequence] + key, ok = d.keys[int32(sequence)] if ok { return checkKey(key, now) } - err := d.fetch(ctx) + err = d.fetch(ctx) if err != nil { - return codersdk.CryptoKey{}, xerrors.Errorf("fetch: %w", err) + return nil, xerrors.Errorf("fetch: %w", err) } - key, ok = d.keys[sequence] + key, ok = d.keys[int32(sequence)] if !ok { - return codersdk.CryptoKey{}, ErrKeyNotFound + return nil, ErrKeyNotFound } return checkKey(key, now) @@ -105,12 +140,12 @@ func (d *DBCache) Verifying(ctx context.Context, sequence int32) (codersdk.Crypt // Signing returns the latest valid key for signing. A valid key is one that is // both past its start time and before its deletion time. -func (d *DBCache) Signing(ctx context.Context) (codersdk.CryptoKey, error) { +func (d *DBCache) Signing(ctx context.Context) (string, interface{}, error) { d.keysMu.RLock() if d.closed { d.keysMu.RUnlock() - return codersdk.CryptoKey{}, ErrClosed + return "", nil, ErrClosed } latest := d.latestKey @@ -118,27 +153,27 @@ func (d *DBCache) Signing(ctx context.Context) (codersdk.CryptoKey, error) { now := d.clock.Now() if latest.CanSign(now) { - return db2sdk.CryptoKey(latest), nil + return idSecret(latest) } d.keysMu.Lock() defer d.keysMu.Unlock() if d.closed { - return codersdk.CryptoKey{}, ErrClosed + return "", nil, ErrClosed } if d.latestKey.CanSign(now) { - return db2sdk.CryptoKey(d.latestKey), nil + return idSecret(d.latestKey) } // Refetch all keys for this feature so we can find the latest valid key. err := d.fetch(ctx) if err != nil { - return codersdk.CryptoKey{}, xerrors.Errorf("fetch: %w", err) + return "", nil, xerrors.Errorf("fetch: %w", err) } - return db2sdk.CryptoKey(d.latestKey), nil + return idSecret(d.latestKey) } // clear invalidates the cache. This forces the subsequent call to fetch fresh keys. @@ -189,12 +224,12 @@ func (d *DBCache) fetch(ctx context.Context) error { return nil } -func checkKey(key database.CryptoKey, now time.Time) (codersdk.CryptoKey, error) { +func checkKey(key database.CryptoKey, now time.Time) (interface{}, error) { if !key.CanVerify(now) { - return codersdk.CryptoKey{}, ErrKeyInvalid + return nil, ErrKeyInvalid } - return db2sdk.CryptoKey(key), nil + return key.DecodeString() } func (d *DBCache) Close() { @@ -208,3 +243,25 @@ func (d *DBCache) Close() { d.timer.Stop() d.closed = true } + +func isEncryptionKeyFeature(feature database.CryptoKeyFeature) bool { + return feature == database.CryptoKeyFeatureWorkspaceApps +} + +func isSigningKeyFeature(feature database.CryptoKeyFeature) bool { + switch feature { + case database.CryptoKeyFeatureTailnetResume, database.CryptoKeyFeatureOidcConvert: + return true + default: + return false + } +} + +func idSecret(k database.CryptoKey) (string, interface{}, error) { + key, err := k.DecodeString() + if err != nil { + return "", nil, xerrors.Errorf("decode key: %w", err) + } + + return strconv.FormatInt(int64(k.Sequence), 10), key, nil +} diff --git a/coderd/cryptokeys/dbkeycache_internal_test.go b/coderd/cryptokeys/dbkeycache_internal_test.go index a3450f5f5e0d9..8611196749a4a 100644 --- a/coderd/cryptokeys/dbkeycache_internal_test.go +++ b/coderd/cryptokeys/dbkeycache_internal_test.go @@ -48,8 +48,9 @@ func Test_Verifying(t *testing.T) { defer k.Close() k.keys = cache - got, err := k.Verifying(ctx, 32) + id, secret, err := k.SigningKey(ctx) require.NoError(t, err) + require.Equal(t, "32", id) require.Equal(t, db2sdk.CryptoKey(expectedKey), got) }) diff --git a/coderd/cryptokeys/keycache.go b/coderd/cryptokeys/keycache.go index 8c4ebfa13f64e..8df4beab99077 100644 --- a/coderd/cryptokeys/keycache.go +++ b/coderd/cryptokeys/keycache.go @@ -4,8 +4,6 @@ import ( "context" "golang.org/x/xerrors" - - "github.com/coder/coder/v2/codersdk" ) var ( @@ -14,8 +12,18 @@ var ( ErrClosed = xerrors.New("closed") ) -// Keycache provides an abstraction for fetching signing keys. +// Keycache provides an abstraction for fetching cryptographic keys used for signing or encrypting payloads. type Keycache interface { - Signing(ctx context.Context) (codersdk.CryptoKey, error) - Verifying(ctx context.Context, sequence int32) (codersdk.CryptoKey, error) + SigningKeycache + EncryptionKeycache +} + +type EncryptionKeycache interface { + EncryptingKey(ctx context.Context) (id string, key interface{}, err error) + DecryptingKey(ctx context.Context, id string) (key interface{}, err error) +} + +type SigningKeycache interface { + SigningKey(ctx context.Context) (id string, key interface{}, err error) + VerifyingKey(ctx context.Context, id string) (key interface{}, err error) } diff --git a/coderd/jwtutils/jwe.go b/coderd/jwtutils/jwe.go index 222b38e27a3d0..31e524eded55a 100644 --- a/coderd/jwtutils/jwe.go +++ b/coderd/jwtutils/jwe.go @@ -3,16 +3,12 @@ package jwtutils import ( "context" "encoding/base64" - "encoding/hex" "encoding/json" - "strconv" "time" "github.com/go-jose/go-jose/v4" "github.com/go-jose/go-jose/v4/jwt" "golang.org/x/xerrors" - - "github.com/coder/coder/v2/coderd/cryptokeys" ) const ( @@ -20,28 +16,31 @@ const ( encryptContentAlgo = jose.A256GCM ) +type EncryptKeyer interface { + EncryptingKey(ctx context.Context) (id string, key interface{}, err error) +} + +type DecryptKeyer interface { + DecryptingKey(ctx context.Context, id string) (key interface{}, err error) +} + // Encrypt encrypts a token and returns it as a string. -func Encrypt(ctx context.Context, keys cryptokeys.Keycache, claims Claims) (string, error) { - signing, err := keys.Signing(ctx) +func Encrypt(ctx context.Context, e EncryptKeyer, claims Claims) (string, error) { + id, key, err := e.EncryptingKey(ctx) if err != nil { return "", xerrors.Errorf("get signing key: %w", err) } - decoded, err := hex.DecodeString(signing.Secret) - if err != nil { - return "", xerrors.Errorf("decode signing key: %w", err) - } - encrypter, err := jose.NewEncrypter( encryptContentAlgo, jose.Recipient{ Algorithm: encryptKeyAlgo, - Key: decoded, + Key: key, }, &jose.EncrypterOptions{ Compression: jose.DEFLATE, ExtraHeaders: map[jose.HeaderKey]interface{}{ - keyIDHeaderKey: strconv.FormatInt(int64(signing.Sequence), 10), + keyIDHeaderKey: id, }, }, ) @@ -64,7 +63,7 @@ func Encrypt(ctx context.Context, keys cryptokeys.Keycache, claims Claims) (stri } // Decrypt decrypts the token using the provided key. It unmarshals into the provided claims. -func Decrypt(ctx context.Context, keys cryptokeys.Keycache, token string, claims Claims, opts ...func(*ParseOptions)) error { +func Decrypt(ctx context.Context, d DecryptKeyer, token string, claims Claims, opts ...func(*ParseOptions)) error { options := ParseOptions{ RegisteredClaims: jwt.Expected{ Time: time.Now(), @@ -99,22 +98,12 @@ func Decrypt(ctx context.Context, keys cryptokeys.Keycache, token string, claims return xerrors.Errorf("expected %q header to be a string", keyIDHeaderKey) } - sequence, err := strconv.ParseInt(sequenceStr, 10, 32) - if err != nil { - return xerrors.Errorf("parse sequence: %w", err) - } - - key, err := keys.Verifying(ctx, int32(sequence)) + key, err := d.DecryptingKey(ctx, sequenceStr) if err != nil { return xerrors.Errorf("version: %w", err) } - decoded, err := hex.DecodeString(key.Secret) - if err != nil { - return xerrors.Errorf("decode key: %w", err) - } - - decrypted, err := object.Decrypt(decoded) + decrypted, err := object.Decrypt(key) if err != nil { return xerrors.Errorf("decrypt: %w", err) } diff --git a/coderd/jwtutils/jws.go b/coderd/jwtutils/jws.go index 462846a1cb202..b27744bccc7bf 100644 --- a/coderd/jwtutils/jws.go +++ b/coderd/jwtutils/jws.go @@ -2,7 +2,6 @@ package jwtutils import ( "context" - "encoding/hex" "encoding/json" "strconv" "time" @@ -10,32 +9,33 @@ import ( "github.com/go-jose/go-jose/v4" "github.com/go-jose/go-jose/v4/jwt" "golang.org/x/xerrors" - - "github.com/coder/coder/v2/coderd/cryptokeys" ) const ( signingAlgo = jose.HS512 ) +type SignKeyer interface { + SigningKey(ctx context.Context) (id string, key interface{}, err error) +} + +type VerifyKeyer interface { + VerifyingKey(ctx context.Context, id string) (key interface{}, err error) +} + // Sign signs a token and returns it as a string. -func Sign(ctx context.Context, keys cryptokeys.Keycache, claims Claims) (string, error) { - signing, err := keys.Signing(ctx) +func Sign(ctx context.Context, s SignKeyer, claims Claims) (string, error) { + id, key, err := s.SigningKey(ctx) if err != nil { return "", xerrors.Errorf("get signing key: %w", err) } - decoded, err := hex.DecodeString(signing.Secret) - if err != nil { - return "", xerrors.Errorf("decode signing key: %w", err) - } - signer, err := jose.NewSigner(jose.SigningKey{ Algorithm: signingAlgo, - Key: decoded, + Key: key, }, &jose.SignerOptions{ ExtraHeaders: map[jose.HeaderKey]interface{}{ - keyIDHeaderKey: strconv.FormatInt(int64(signing.Sequence), 10), + keyIDHeaderKey: id, }, }) if err != nil { @@ -61,7 +61,7 @@ func Sign(ctx context.Context, keys cryptokeys.Keycache, claims Claims) (string, } // Verify verifies that a token was signed by the provided key. It unmarshals into the provided claims. -func Verify(ctx context.Context, keys cryptokeys.Keycache, token string, claims Claims, opts ...func(*ParseOptions)) error { +func Verify(ctx context.Context, v VerifyKeyer, token string, claims Claims, opts ...func(*ParseOptions)) error { options := ParseOptions{ RegisteredClaims: jwt.Expected{ Time: time.Now(), @@ -98,17 +98,12 @@ func Verify(ctx context.Context, keys cryptokeys.Keycache, token string, claims return xerrors.Errorf("parse sequence %q: %w", sequenceStr, err) } - key, err := keys.Verifying(ctx, int32(sequence)) + key, err := v.VerifyingKey(ctx, sequenceStr) if err != nil { return xerrors.Errorf("verifying key for seq %v: %w", sequence, err) } - decoded, err := hex.DecodeString(key.Secret) - if err != nil { - return xerrors.Errorf("decode key: %w", err) - } - - payload, err := object.Verify(decoded) + payload, err := object.Verify(key) if err != nil { return xerrors.Errorf("verify payload: %w", err) } diff --git a/coderd/jwtutils/jwt_test.go b/coderd/jwtutils/jwt_test.go index a28314bfc91c5..6a99cebd998f3 100644 --- a/coderd/jwtutils/jwt_test.go +++ b/coderd/jwtutils/jwt_test.go @@ -22,24 +22,18 @@ func TestJWT(t *testing.T) { t.Parallel() type tokenType struct { - Name string - SignFn func(ctx context.Context, keys cryptokeys.Keycache, claims jwtutils.Claims) (string, error) - VerifyFn func(ctx context.Context, keys cryptokeys.Keycache, token string, claims jwtutils.Claims, opts ...func(*jwtutils.ParseOptions)) error - KeySize int + Name string + KeySize int } types := []tokenType{ { - Name: "JWE", - SignFn: jwtutils.Encrypt, - VerifyFn: jwtutils.Decrypt, - KeySize: 32, + Name: "JWE", + KeySize: 32, }, { - Name: "JWS", - SignFn: jwtutils.Sign, - VerifyFn: jwtutils.Verify, - KeySize: 64, + Name: "JWS", + KeySize: 64, }, } @@ -76,6 +70,17 @@ func TestJWT(t *testing.T) { MyClaim: "my_value", } + var token string + var err error + + if tt.Name == "JWE" { + token, err = jwtutils.Encrypt(ctx, keycache, claims) + require.NoError(t, err) + } else { + token, err = jwtutils.Sign(ctx, keycache, claims) + require.NoError(t, err) + } + token, err := tt.SignFn(ctx, keycache, claims) require.NoError(t, err) @@ -471,3 +476,31 @@ func withContentEncryptionAlgorithm(alg jose.ContentEncryption) func(*jwtutils.P opts.ContentEncryptionAlgorithm = alg } } + +type godkey interface { + jwtutils.SignKeyer + jwtutils.VerifyKeyer + jwtutils.EncryptKeyer + jwtutils.DecryptKeyer +} + +type key struct { + signFn func(context.Context) (string, interface{}, error) + verifyFn func(context.Context, string) (interface{}, error) +} + +func (k *key) SigningKey(ctx context.Context) (string, interface{}, error) { + return k.signFn(ctx) +} + +func (k *key) VerifyingKey(ctx context.Context, id string) (interface{}, error) { + return k.verifyFn(ctx, id) +} + +func (k *key) EncryptingKey(ctx context.Context) (string, interface{}, error) { + return k.signFn(ctx) +} + +func (k *key) DecryptingKey(ctx context.Context, id string) (interface{}, error) { + return k.verifyFn(ctx, id) +} From 67ccd5cf26750d02c2a822ac20d2cbb36f9a49a6 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Thu, 3 Oct 2024 06:41:39 +0000 Subject: [PATCH 14/25] refactor dbkeycache --- coderd/cryptokeys/dbkeycache.go | 65 ++++++---- coderd/cryptokeys/dbkeycache_internal_test.go | 122 +++++++++++------- coderd/cryptokeys/dbkeycache_test.go | 91 +++++++++---- coderd/cryptokeys/keycache.go | 16 +-- 4 files changed, 189 insertions(+), 105 deletions(-) diff --git a/coderd/cryptokeys/dbkeycache.go b/coderd/cryptokeys/dbkeycache.go index f4c60e7fa95b9..e6b9952ade90d 100644 --- a/coderd/cryptokeys/dbkeycache.go +++ b/coderd/cryptokeys/dbkeycache.go @@ -16,8 +16,8 @@ import ( // never represents the maximum value for a time.Duration. const never = 1<<63 - 1 -// DBCache implements Keycache for callers with access to the database. -type DBCache struct { +// dbCache implements Keycache for callers with access to the database. +type dbCache struct { db database.Store feature database.CryptoKeyFeature logger slog.Logger @@ -33,18 +33,34 @@ type DBCache struct { closed bool } -type DBCacheOption func(*DBCache) +type DBCacheOption func(*dbCache) func WithDBCacheClock(clock quartz.Clock) DBCacheOption { - return func(d *DBCache) { + return func(d *dbCache) { d.clock = clock } } -// NewDBCache creates a new DBCache. Close should be called to +// NewSigningCache creates a new DBCache. Close should be called to // release resources associated with its internal timer. -func NewDBCache(logger slog.Logger, db database.Store, feature database.CryptoKeyFeature, opts ...func(*DBCache)) *DBCache { - d := &DBCache{ +func NewSigningCache(logger slog.Logger, db database.Store, feature database.CryptoKeyFeature, opts ...func(*dbCache)) (SigningKeycache, error) { + if !isSigningKeyFeature(feature) { + return nil, ErrInvalidFeature + } + + return newDBCache(logger, db, feature, opts...), nil +} + +func NewEncryptionCache(logger slog.Logger, db database.Store, feature database.CryptoKeyFeature, opts ...func(*dbCache)) (EncryptionKeycache, error) { + if !isEncryptionKeyFeature(feature) { + return nil, ErrInvalidFeature + } + + return newDBCache(logger, db, feature, opts...), nil +} + +func newDBCache(logger slog.Logger, db database.Store, feature database.CryptoKeyFeature, opts ...func(*dbCache)) *dbCache { + d := &dbCache{ db: db, feature: feature, clock: quartz.NewReal(), @@ -60,41 +76,41 @@ func NewDBCache(logger slog.Logger, db database.Store, feature database.CryptoKe return d } -func (d *DBCache) EncryptingKey(ctx context.Context) (id string, key interface{}, err error) { +func (d *dbCache) EncryptingKey(ctx context.Context) (id string, key interface{}, err error) { if !isEncryptionKeyFeature(d.feature) { return "", nil, xerrors.Errorf("invalid feature: %s", d.feature) } - return d.Signing(ctx) + return d.latest(ctx) } -func (d *DBCache) DecryptingKey(ctx context.Context, id string) (key interface{}, err error) { +func (d *dbCache) DecryptingKey(ctx context.Context, id string) (key interface{}, err error) { if !isEncryptionKeyFeature(d.feature) { return nil, xerrors.Errorf("invalid feature: %s", d.feature) } - return d.Verifying(ctx, id) + return d.sequence(ctx, id) } -func (d *DBCache) SigningKey(ctx context.Context) (id string, key interface{}, err error) { +func (d *dbCache) SigningKey(ctx context.Context) (id string, key interface{}, err error) { if !isSigningKeyFeature(d.feature) { return "", nil, xerrors.Errorf("invalid feature: %s", d.feature) } - return d.Signing(ctx) + return d.latest(ctx) } -func (d *DBCache) VerifyingKey(ctx context.Context, id string) (key interface{}, err error) { - if !isSigningKeyFeature(d.feature) { +func (d *dbCache) VerifyingKey(ctx context.Context, id string) (key interface{}, err error) { + if !isSigningKeyFeature(d.feature) { return nil, xerrors.Errorf("invalid feature: %s", d.feature) } - return d.Verifying(ctx, id) + return d.sequence(ctx, id) } -// Verifying returns the CryptoKey with the given sequence number, provided that +// sequence returns the CryptoKey with the given sequence number, provided that // it is neither deleted nor has breached its deletion date. It should only be // used for verifying or decrypting payloads. To sign/encrypt call Signing. -func (d *DBCache) Verifying(ctx context.Context, id string) (interface{}, error) { +func (d *dbCache) sequence(ctx context.Context, id string) (interface{}, error) { sequence, err := strconv.ParseInt(id, 10, 32) if err != nil { return nil, xerrors.Errorf("expecting sequence number got %q: %w", id, err) @@ -138,9 +154,9 @@ func (d *DBCache) Verifying(ctx context.Context, id string) (interface{}, error) return checkKey(key, now) } -// Signing returns the latest valid key for signing. A valid key is one that is +// latest returns the latest valid key for signing. A valid key is one that is // both past its start time and before its deletion time. -func (d *DBCache) Signing(ctx context.Context) (string, interface{}, error) { +func (d *dbCache) latest(ctx context.Context) (string, interface{}, error) { d.keysMu.RLock() if d.closed { @@ -177,7 +193,7 @@ func (d *DBCache) Signing(ctx context.Context) (string, interface{}, error) { } // clear invalidates the cache. This forces the subsequent call to fetch fresh keys. -func (d *DBCache) clear() { +func (d *dbCache) clear() { now := d.clock.Now("DBCache", "clear") d.keysMu.Lock() defer d.keysMu.Unlock() @@ -193,7 +209,7 @@ func (d *DBCache) clear() { // fetch fetches all keys for the given feature and determines the latest key. // It must be called while holding the keysMu lock. -func (d *DBCache) fetch(ctx context.Context) error { +func (d *dbCache) fetch(ctx context.Context) error { keys, err := d.db.GetCryptoKeysByFeature(ctx, d.feature) if err != nil { return xerrors.Errorf("get crypto keys by feature: %w", err) @@ -232,16 +248,17 @@ func checkKey(key database.CryptoKey, now time.Time) (interface{}, error) { return key.DecodeString() } -func (d *DBCache) Close() { +func (d *dbCache) Close() error { d.keysMu.Lock() defer d.keysMu.Unlock() if d.closed { - return + return nil } d.timer.Stop() d.closed = true + return nil } func isEncryptionKeyFeature(feature database.CryptoKeyFeature) bool { diff --git a/coderd/cryptokeys/dbkeycache_internal_test.go b/coderd/cryptokeys/dbkeycache_internal_test.go index 8611196749a4a..c27bc5b8468ad 100644 --- a/coderd/cryptokeys/dbkeycache_internal_test.go +++ b/coderd/cryptokeys/dbkeycache_internal_test.go @@ -2,6 +2,7 @@ package cryptokeys import ( "database/sql" + "strconv" "testing" "time" @@ -11,13 +12,12 @@ import ( "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/db2sdk" "github.com/coder/coder/v2/coderd/database/dbmock" "github.com/coder/coder/v2/testutil" "github.com/coder/quartz" ) -func Test_Verifying(t *testing.T) { +func Test_version(t *testing.T) { t.Parallel() t.Run("HitsCache", func(t *testing.T) { @@ -35,7 +35,7 @@ func Test_Verifying(t *testing.T) { Feature: database.CryptoKeyFeatureWorkspaceApps, Sequence: 32, Secret: sql.NullString{ - String: "secret", + String: mustGenerateKey(t), Valid: true, }, } @@ -44,14 +44,13 @@ func Test_Verifying(t *testing.T) { 32: expectedKey, } - k := NewDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) + k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) defer k.Close() k.keys = cache - id, secret, err := k.SigningKey(ctx) + secret, err := k.sequence(ctx, keyID(expectedKey)) require.NoError(t, err) - require.Equal(t, "32", id) - require.Equal(t, db2sdk.CryptoKey(expectedKey), got) + require.Equal(t, decodedSecret(t, expectedKey), secret) }) t.Run("MissesCache", func(t *testing.T) { @@ -70,20 +69,19 @@ func Test_Verifying(t *testing.T) { Sequence: 33, StartsAt: clock.Now(), Secret: sql.NullString{ - String: "secret", + String: mustGenerateKey(t), Valid: true, }, } mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{expectedKey}, nil) - k := NewDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) + k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) defer k.Close() - got, err := k.Verifying(ctx, 33) + got, err := k.sequence(ctx, keyID(expectedKey)) require.NoError(t, err) - require.Equal(t, db2sdk.CryptoKey(expectedKey), got) - require.Equal(t, db2sdk.CryptoKey(expectedKey), db2sdk.CryptoKey(k.latestKey)) + require.Equal(t, decodedSecret(t, expectedKey), got) }) t.Run("InvalidCachedKey", func(t *testing.T) { @@ -102,7 +100,7 @@ func Test_Verifying(t *testing.T) { Feature: database.CryptoKeyFeatureWorkspaceApps, Sequence: 32, Secret: sql.NullString{ - String: "secret", + String: mustGenerateKey(t), Valid: true, }, DeletesAt: sql.NullTime{ @@ -112,11 +110,11 @@ func Test_Verifying(t *testing.T) { }, } - k := NewDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) + k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) defer k.Close() k.keys = cache - _, err := k.Verifying(ctx, 32) + _, err := k.sequence(ctx, "32") require.ErrorIs(t, err, ErrKeyInvalid) }) @@ -135,7 +133,7 @@ func Test_Verifying(t *testing.T) { Feature: database.CryptoKeyFeatureWorkspaceApps, Sequence: 32, Secret: sql.NullString{ - String: "secret", + String: mustGenerateKey(t), Valid: true, }, DeletesAt: sql.NullTime{ @@ -145,15 +143,15 @@ func Test_Verifying(t *testing.T) { } mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{invalidKey}, nil) - k := NewDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) + k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) defer k.Close() - _, err := k.Verifying(ctx, 32) + _, err := k.sequence(ctx, keyID(invalidKey)) require.ErrorIs(t, err, ErrKeyInvalid) }) } -func Test_Signing(t *testing.T) { +func Test_latest(t *testing.T) { t.Parallel() t.Run("HitsCache", func(t *testing.T) { @@ -171,19 +169,20 @@ func Test_Signing(t *testing.T) { Feature: database.CryptoKeyFeatureWorkspaceApps, Sequence: 32, Secret: sql.NullString{ - String: "secret", + String: mustGenerateKey(t), Valid: true, }, StartsAt: clock.Now(), } - k := NewDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) + k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) defer k.Close() k.latestKey = latestKey - got, err := k.Signing(ctx) + id, secret, err := k.latest(ctx) require.NoError(t, err) - require.Equal(t, db2sdk.CryptoKey(latestKey), got) + require.Equal(t, keyID(latestKey), id) + require.Equal(t, decodedSecret(t, latestKey), secret) }) t.Run("InvalidCachedKey", func(t *testing.T) { @@ -201,7 +200,7 @@ func Test_Signing(t *testing.T) { Feature: database.CryptoKeyFeatureWorkspaceApps, Sequence: 33, Secret: sql.NullString{ - String: "secret", + String: mustGenerateKey(t), Valid: true, }, StartsAt: clock.Now(), @@ -211,7 +210,7 @@ func Test_Signing(t *testing.T) { Feature: database.CryptoKeyFeatureWorkspaceApps, Sequence: 32, Secret: sql.NullString{ - String: "secret", + String: mustGenerateKey(t), Valid: true, }, StartsAt: clock.Now().Add(-time.Hour), @@ -223,13 +222,14 @@ func Test_Signing(t *testing.T) { mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{latestKey}, nil) - k := NewDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) + k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) defer k.Close() k.latestKey = invalidKey - got, err := k.Signing(ctx) + id, secret, err := k.latest(ctx) require.NoError(t, err) - require.Equal(t, db2sdk.CryptoKey(latestKey), got) + require.Equal(t, keyID(latestKey), id) + require.Equal(t, decodedSecret(t, latestKey), secret) }) t.Run("UsesActiveKey", func(t *testing.T) { @@ -247,7 +247,7 @@ func Test_Signing(t *testing.T) { Feature: database.CryptoKeyFeatureWorkspaceApps, Sequence: 32, Secret: sql.NullString{ - String: "secret", + String: mustGenerateKey(t), Valid: true, }, StartsAt: clock.Now().Add(time.Hour), @@ -257,7 +257,7 @@ func Test_Signing(t *testing.T) { Feature: database.CryptoKeyFeatureWorkspaceApps, Sequence: 33, Secret: sql.NullString{ - String: "secret", + String: mustGenerateKey(t), Valid: true, }, StartsAt: clock.Now(), @@ -265,12 +265,13 @@ func Test_Signing(t *testing.T) { mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{inactiveKey, activeKey}, nil) - k := NewDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) + k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) defer k.Close() - got, err := k.Signing(ctx) + id, secret, err := k.latest(ctx) require.NoError(t, err) - require.Equal(t, db2sdk.CryptoKey(activeKey), got) + require.Equal(t, keyID(activeKey), id) + require.Equal(t, decodedSecret(t, activeKey), secret) }) t.Run("NoValidKeys", func(t *testing.T) { @@ -288,7 +289,7 @@ func Test_Signing(t *testing.T) { Feature: database.CryptoKeyFeatureWorkspaceApps, Sequence: 32, Secret: sql.NullString{ - String: "secret", + String: mustGenerateKey(t), Valid: true, }, StartsAt: clock.Now().Add(time.Hour), @@ -298,7 +299,7 @@ func Test_Signing(t *testing.T) { Feature: database.CryptoKeyFeatureWorkspaceApps, Sequence: 33, Secret: sql.NullString{ - String: "secret", + String: mustGenerateKey(t), Valid: true, }, StartsAt: clock.Now().Add(-time.Hour), @@ -310,10 +311,10 @@ func Test_Signing(t *testing.T) { mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{inactiveKey, invalidKey}, nil) - k := NewDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) + k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) defer k.Close() - _, err := k.Signing(ctx) + _, _, err := k.latest(ctx) require.ErrorIs(t, err, ErrKeyInvalid) }) } @@ -332,14 +333,14 @@ func Test_clear(t *testing.T) { logger = slogtest.Make(t, nil) ) - k := NewDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) + k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) defer k.Close() activeKey := database.CryptoKey{ Feature: database.CryptoKeyFeatureWorkspaceApps, Sequence: 33, Secret: sql.NullString{ - String: "secret", + String: mustGenerateKey(t), Valid: true, }, StartsAt: clock.Now(), @@ -347,7 +348,7 @@ func Test_clear(t *testing.T) { mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{activeKey}, nil) - _, err := k.Signing(ctx) + _, _, err := k.latest(ctx) require.NoError(t, err) dur, wait := clock.AdvanceNext() @@ -368,14 +369,14 @@ func Test_clear(t *testing.T) { logger = slogtest.Make(t, nil) ) - k := NewDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) + k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) defer k.Close() key := database.CryptoKey{ Feature: database.CryptoKeyFeatureWorkspaceApps, Sequence: 32, Secret: sql.NullString{ - String: "secret", + String: mustGenerateKey(t), Valid: true, }, StartsAt: clock.Now(), @@ -387,9 +388,10 @@ func Test_clear(t *testing.T) { // timer is reset and doesn't fire after another five minute. clock.Advance(time.Minute * 5) - latest, err := k.Signing(ctx) + id, secret, err := k.latest(ctx) require.NoError(t, err) - require.Equal(t, db2sdk.CryptoKey(key), latest) + require.Equal(t, keyID(key), id) + require.Equal(t, decodedSecret(t, key), secret) // Advancing the clock now should require 10 minutes // before the timer fires again. @@ -416,14 +418,14 @@ func Test_clear(t *testing.T) { trap := clock.Trap().Now("clear") - k := NewDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) + k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock)) defer k.Close() key := database.CryptoKey{ Feature: database.CryptoKeyFeatureWorkspaceApps, Sequence: 32, Secret: sql.NullString{ - String: "secret", + String: mustGenerateKey(t), Valid: true, }, StartsAt: clock.Now(), @@ -432,9 +434,10 @@ func Test_clear(t *testing.T) { mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{key}, nil).Times(2) // Move us past the initial timer. - latest, err := k.Signing(ctx) + id, secret, err := k.latest(ctx) require.NoError(t, err) - require.Equal(t, db2sdk.CryptoKey(key), latest) + require.Equal(t, keyID(key), id) + require.Equal(t, decodedSecret(t, key), secret) // Null these out so that we refetch. k.keys = nil k.latestKey = database.CryptoKey{} @@ -446,9 +449,10 @@ func Test_clear(t *testing.T) { call := trap.MustWait(ctx) // Refetch keys. - latest, err = k.Signing(ctx) + id, secret, err = k.latest(ctx) require.NoError(t, err) - require.Equal(t, db2sdk.CryptoKey(key), latest) + require.Equal(t, keyID(key), id) + require.Equal(t, decodedSecret(t, key), secret) // Let the rest of the timer function run. // It should see that we have refetched keys and @@ -466,3 +470,21 @@ func Test_clear(t *testing.T) { require.Equal(t, database.CryptoKey{}, k.latestKey) }) } + +func mustGenerateKey(t *testing.T) string { + t.Helper() + key, err := generateKey(64) + require.NoError(t, err) + return key +} + +func keyID(key database.CryptoKey) string { + return strconv.FormatInt(int64(key.Sequence), 10) +} + +func decodedSecret(t *testing.T, key database.CryptoKey) []byte { + t.Helper() + decoded, err := key.DecodeString() + require.NoError(t, err) + return decoded +} diff --git a/coderd/cryptokeys/dbkeycache_test.go b/coderd/cryptokeys/dbkeycache_test.go index 8c92cf3a90aa6..c421eeaf4b86f 100644 --- a/coderd/cryptokeys/dbkeycache_test.go +++ b/coderd/cryptokeys/dbkeycache_test.go @@ -1,6 +1,7 @@ package cryptokeys_test import ( + "strconv" "testing" "github.com/stretchr/testify/require" @@ -10,7 +11,6 @@ import ( "github.com/coder/coder/v2/coderd/cryptokeys" "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/db2sdk" "github.com/coder/coder/v2/coderd/database/dbgen" "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/testutil" @@ -24,7 +24,7 @@ func TestMain(m *testing.M) { func TestDBKeyCache(t *testing.T) { t.Parallel() - t.Run("Verifying", func(t *testing.T) { + t.Run("VerifyingKey", func(t *testing.T) { t.Parallel() t.Run("HitsCache", func(t *testing.T) { @@ -38,17 +38,18 @@ func TestDBKeyCache(t *testing.T) { ) key := dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureOidcConvert, Sequence: 1, StartsAt: clock.Now().UTC(), }) - k := cryptokeys.NewDBCache(logger, db, database.CryptoKeyFeatureWorkspaceApps, cryptokeys.WithDBCacheClock(clock)) + k, err := cryptokeys.NewSigningCache(logger, db, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock)) + require.NoError(t, err) defer k.Close() - got, err := k.Verifying(ctx, key.Sequence) + got, err := k.VerifyingKey(ctx, keyID(key)) require.NoError(t, err) - require.Equal(t, db2sdk.CryptoKey(key), got) + require.Equal(t, decodedSecret(t, key), got) }) t.Run("NotFound", func(t *testing.T) { @@ -61,12 +62,14 @@ func TestDBKeyCache(t *testing.T) { logger = slogtest.Make(t, nil) ) - k := cryptokeys.NewDBCache(logger, db, database.CryptoKeyFeatureWorkspaceApps, cryptokeys.WithDBCacheClock(clock)) + k, err := cryptokeys.NewSigningCache(logger, db, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock)) + require.NoError(t, err) defer k.Close() - _, err := k.Verifying(ctx, 123) + _, err = k.VerifyingKey(ctx, "123") require.ErrorIs(t, err, cryptokeys.ErrKeyNotFound) }) + }) t.Run("Signing", func(t *testing.T) { @@ -80,29 +83,31 @@ func TestDBKeyCache(t *testing.T) { ) _ = dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureOidcConvert, Sequence: 10, StartsAt: clock.Now().UTC(), }) expectedKey := dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureOidcConvert, Sequence: 12, StartsAt: clock.Now().UTC(), }) _ = dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureOidcConvert, Sequence: 2, StartsAt: clock.Now().UTC(), }) - k := cryptokeys.NewDBCache(logger, db, database.CryptoKeyFeatureWorkspaceApps, cryptokeys.WithDBCacheClock(clock)) + k, err := cryptokeys.NewSigningCache(logger, db, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock)) + require.NoError(t, err) defer k.Close() - got, err := k.Signing(ctx) + id, key, err := k.SigningKey(ctx) require.NoError(t, err) - require.Equal(t, db2sdk.CryptoKey(expectedKey), got) + require.Equal(t, keyID(expectedKey), id) + require.Equal(t, decodedSecret(t, expectedKey), key) }) t.Run("Closed", func(t *testing.T) { @@ -116,28 +121,70 @@ func TestDBKeyCache(t *testing.T) { ) expectedKey := dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureOidcConvert, Sequence: 10, StartsAt: clock.Now(), }) - k := cryptokeys.NewDBCache(logger, db, database.CryptoKeyFeatureWorkspaceApps, cryptokeys.WithDBCacheClock(clock)) + k, err := cryptokeys.NewSigningCache(logger, db, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock)) + require.NoError(t, err) defer k.Close() - got, err := k.Signing(ctx) + id, key, err := k.SigningKey(ctx) require.NoError(t, err) - require.Equal(t, db2sdk.CryptoKey(expectedKey), got) + require.Equal(t, keyID(expectedKey), id) + require.Equal(t, decodedSecret(t, expectedKey), key) - got, err = k.Verifying(ctx, expectedKey.Sequence) + key, err = k.VerifyingKey(ctx, keyID(expectedKey)) require.NoError(t, err) - require.Equal(t, db2sdk.CryptoKey(expectedKey), got) + require.Equal(t, decodedSecret(t, expectedKey), key) k.Close() - _, err = k.Signing(ctx) + _, _, err = k.SigningKey(ctx) require.ErrorIs(t, err, cryptokeys.ErrClosed) - _, err = k.Verifying(ctx, expectedKey.Sequence) + _, err = k.VerifyingKey(ctx, keyID(expectedKey)) require.ErrorIs(t, err, cryptokeys.ErrClosed) }) + + t.Run("InvalidSigningFeature", func(t *testing.T) { + t.Parallel() + + var ( + db, _ = dbtestutil.NewDB(t) + clock = quartz.NewMock(t) + logger = slogtest.Make(t, nil) + ) + + _, err := cryptokeys.NewSigningCache(logger, db, database.CryptoKeyFeatureWorkspaceApps, cryptokeys.WithDBCacheClock(clock)) + require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature) + }) + + t.Run("InvalidEncryptionFeature", func(t *testing.T) { + t.Parallel() + + var ( + db, _ = dbtestutil.NewDB(t) + clock = quartz.NewMock(t) + logger = slogtest.Make(t, nil) + ) + + _, err := cryptokeys.NewEncryptionCache(logger, db, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock)) + require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature) + }) + +} + +func keyID(key database.CryptoKey) string { + return strconv.FormatInt(int64(key.Sequence), 10) +} + +func decodedSecret(t *testing.T, key database.CryptoKey) []byte { + t.Helper() + + secret, err := key.DecodeString() + require.NoError(t, err) + + return secret } diff --git a/coderd/cryptokeys/keycache.go b/coderd/cryptokeys/keycache.go index 8df4beab99077..a5e5d087a5ee2 100644 --- a/coderd/cryptokeys/keycache.go +++ b/coderd/cryptokeys/keycache.go @@ -2,28 +2,26 @@ package cryptokeys import ( "context" + "io" "golang.org/x/xerrors" ) var ( - ErrKeyNotFound = xerrors.New("key not found") - ErrKeyInvalid = xerrors.New("key is invalid for use") - ErrClosed = xerrors.New("closed") + ErrKeyNotFound = xerrors.New("key not found") + ErrKeyInvalid = xerrors.New("key is invalid for use") + ErrClosed = xerrors.New("closed") + ErrInvalidFeature = xerrors.New("invalid feature for this operation") ) -// Keycache provides an abstraction for fetching cryptographic keys used for signing or encrypting payloads. -type Keycache interface { - SigningKeycache - EncryptionKeycache -} - type EncryptionKeycache interface { EncryptingKey(ctx context.Context) (id string, key interface{}, err error) DecryptingKey(ctx context.Context, id string) (key interface{}, err error) + io.Closer } type SigningKeycache interface { SigningKey(ctx context.Context) (id string, key interface{}, err error) VerifyingKey(ctx context.Context, id string) (key interface{}, err error) + io.Closer } From 1a81c7ad548182aee42af81539d11ed04d15aaa5 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Thu, 3 Oct 2024 07:43:39 +0000 Subject: [PATCH 15/25] Refactor JWT utility options for flexibility Refactor JWT utility code to replace `ParseOptions` with more specific `VerifyOptions` and `DecryptOptions`. This change enhances clarity and flexibility, accommodating different needs for JWS and JWE operations. Update tests to align with the refactored structure, ensuring robust coverage and functionality verification. --- coderd/jwtutils/jwe.go | 13 +- coderd/jwtutils/jws.go | 22 +- coderd/jwtutils/jwt.go | 28 -- coderd/jwtutils/jwt_test.go | 634 +++++++++++++----------------------- 4 files changed, 260 insertions(+), 437 deletions(-) delete mode 100644 coderd/jwtutils/jwt.go diff --git a/coderd/jwtutils/jwe.go b/coderd/jwtutils/jwe.go index 31e524eded55a..8dd33933b5bc1 100644 --- a/coderd/jwtutils/jwe.go +++ b/coderd/jwtutils/jwe.go @@ -62,9 +62,18 @@ func Encrypt(ctx context.Context, e EncryptKeyer, claims Claims) (string, error) return base64.RawURLEncoding.EncodeToString(serialized), nil } +// DecryptOptions are options for decrypting a JWE. +type DecryptOptions struct { + RegisteredClaims jwt.Expected + + // The following should only be used for JWEs. + KeyAlgorithm jose.KeyAlgorithm + ContentEncryptionAlgorithm jose.ContentEncryption +} + // Decrypt decrypts the token using the provided key. It unmarshals into the provided claims. -func Decrypt(ctx context.Context, d DecryptKeyer, token string, claims Claims, opts ...func(*ParseOptions)) error { - options := ParseOptions{ +func Decrypt(ctx context.Context, d DecryptKeyer, token string, claims Claims, opts ...func(*DecryptOptions)) error { + options := DecryptOptions{ RegisteredClaims: jwt.Expected{ Time: time.Now(), }, diff --git a/coderd/jwtutils/jws.go b/coderd/jwtutils/jws.go index b27744bccc7bf..432e93a8a78ab 100644 --- a/coderd/jwtutils/jws.go +++ b/coderd/jwtutils/jws.go @@ -11,6 +11,16 @@ import ( "golang.org/x/xerrors" ) +const ( + keyIDHeaderKey = "kid" +) + +// Claims defines the payload for a JWT. Most callers +// should embed jwt.Claims +type Claims interface { + Validate(jwt.Expected) error +} + const ( signingAlgo = jose.HS512 ) @@ -60,9 +70,17 @@ func Sign(ctx context.Context, s SignKeyer, claims Claims) (string, error) { return compact, nil } +// VerifyOptions are options for verifying a JWT. +type VerifyOptions struct { + RegisteredClaims jwt.Expected + + // The following are only used for JWSs. + SignatureAlgorithm jose.SignatureAlgorithm +} + // Verify verifies that a token was signed by the provided key. It unmarshals into the provided claims. -func Verify(ctx context.Context, v VerifyKeyer, token string, claims Claims, opts ...func(*ParseOptions)) error { - options := ParseOptions{ +func Verify(ctx context.Context, v VerifyKeyer, token string, claims Claims, opts ...func(*VerifyOptions)) error { + options := VerifyOptions{ RegisteredClaims: jwt.Expected{ Time: time.Now(), }, diff --git a/coderd/jwtutils/jwt.go b/coderd/jwtutils/jwt.go deleted file mode 100644 index eb8bbba915813..0000000000000 --- a/coderd/jwtutils/jwt.go +++ /dev/null @@ -1,28 +0,0 @@ -package jwtutils - -import ( - "github.com/go-jose/go-jose/v4" - "github.com/go-jose/go-jose/v4/jwt" -) - -const ( - keyIDHeaderKey = "kid" -) - -// Claims defines the payload for a JWT. Most callers -// should embed jwt.Claims -type Claims interface { - Validate(jwt.Expected) error -} - -// ParseOptions are options for parsing a JWT. -type ParseOptions struct { - RegisteredClaims jwt.Expected - - // The following are only used for JWSs. - SignatureAlgorithm jose.SignatureAlgorithm - - // The following should only be used for JWEs. - KeyAlgorithm jose.KeyAlgorithm - ContentEncryptionAlgorithm jose.ContentEncryption -} diff --git a/coderd/jwtutils/jwt_test.go b/coderd/jwtutils/jwt_test.go index 6a99cebd998f3..0496155a547f0 100644 --- a/coderd/jwtutils/jwt_test.go +++ b/coderd/jwtutils/jwt_test.go @@ -9,423 +9,234 @@ import ( "github.com/go-jose/go-jose/v4" "github.com/go-jose/go-jose/v4/jwt" + "github.com/google/uuid" "github.com/stretchr/testify/require" - "go.uber.org/mock/gomock" - "github.com/coder/coder/v2/coderd/cryptokeys" "github.com/coder/coder/v2/coderd/jwtutils" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/testutil" ) -func TestJWT(t *testing.T) { +func TestClaims(t *testing.T) { t.Parallel() type tokenType struct { Name string KeySize int + Sign bool } types := []tokenType{ { Name: "JWE", + Sign: false, KeySize: 32, }, { Name: "JWS", + Sign: true, KeySize: 64, }, } + type testcase struct { + name string + claims jwt.Claims + expectedClaims jwt.Expected + expectedErr error + } + + cases := []testcase{ + { + name: "OK", + claims: jwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jwt.Audience{"coder"}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now()), + }, + }, + { + name: "WrongIssuer", + claims: jwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jwt.Audience{"coder"}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now()), + }, + expectedClaims: jwt.Expected{ + Issuer: "coder2", + }, + expectedErr: jwt.ErrInvalidIssuer, + }, + { + name: "WrongSubject", + claims: jwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jwt.Audience{"coder"}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now()), + }, + expectedClaims: jwt.Expected{ + Subject: "user2@coder.com", + }, + expectedErr: jwt.ErrInvalidSubject, + }, + { + name: "WrongAudience", + claims: jwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jwt.Audience{"coder"}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now()), + }, + }, + { + name: "Expired", + claims: jwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jwt.Audience{"coder"}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Minute)), + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now()), + }, + expectedClaims: jwt.Expected{ + Time: time.Now().Add(time.Minute * 3), + }, + expectedErr: jwt.ErrExpired, + }, + { + name: "IssuedInFuture", + claims: jwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jwt.Audience{"coder"}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Minute)), + IssuedAt: jwt.NewNumericDate(time.Now()), + }, + expectedClaims: jwt.Expected{ + Time: time.Now().Add(-time.Minute * 3), + }, + expectedErr: jwt.ErrIssuedInTheFuture, + }, + { + name: "IsBefore", + claims: jwt.Claims{ + Issuer: "coder", + Subject: "user@coder.com", + Audience: jwt.Audience{"coder"}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now().Add(time.Minute * 5)), + }, + expectedClaims: jwt.Expected{ + Time: time.Now().Add(time.Minute * 3), + }, + expectedErr: jwt.ErrNotValidYet, + }, + } + for _, tt := range types { tt := tt t.Run(tt.Name, func(t *testing.T) { t.Parallel() - - t.Run("OK", func(t *testing.T) { - t.Parallel() - - var ( - ctx = testutil.Context(t, testutil.WaitShort) - ctrl = gomock.NewController(t) - keycache = cryptokeys.NewMockKeycache(ctrl) - now = time.Now() - ) - - key := generateCryptoKey(t, 1234567890, now, tt.KeySize) - - keycache.EXPECT().Signing(ctx).Return(key, nil) - keycache.EXPECT().Verifying(ctx, key.Sequence).Return(key, nil) - - claims := testClaims{ - Claims: jwt.Claims{ - Issuer: "coder", - Subject: "user@coder.com", - Audience: jwt.Audience{"coder"}, - Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), - IssuedAt: jwt.NewNumericDate(time.Now()), - NotBefore: jwt.NewNumericDate(time.Now()), - }, - MyClaim: "my_value", - } - - var token string - var err error - - if tt.Name == "JWE" { - token, err = jwtutils.Encrypt(ctx, keycache, claims) + for _, c := range cases { + c := c + t.Run(c.name, func(t *testing.T) { + t.Parallel() + + var ( + ctx = testutil.Context(t, testutil.WaitShort) + key = newKey(t, tt.KeySize) + token string + err error + ) + + if tt.Sign { + token, err = jwtutils.Sign(ctx, key, c.claims) + } else { + token, err = jwtutils.Encrypt(ctx, key, c.claims) + } require.NoError(t, err) - } else { - token, err = jwtutils.Sign(ctx, keycache, claims) + + var actual testClaims + if tt.Sign { + err = jwtutils.Verify(ctx, key, token, &actual) + } else { + err = jwtutils.Decrypt(ctx, key, token, &actual) + } require.NoError(t, err) - } - - token, err := tt.SignFn(ctx, keycache, claims) - require.NoError(t, err) - - var actual testClaims - err = tt.VerifyFn(ctx, keycache, token, &actual) - require.NoError(t, err) - require.Equal(t, claims, actual) - }) - - t.Run("WrongIssuer", func(t *testing.T) { - t.Parallel() - - var ( - ctx = testutil.Context(t, testutil.WaitShort) - ctrl = gomock.NewController(t) - keycache = cryptokeys.NewMockKeycache(ctrl) - now = time.Now() - ) - - key := generateCryptoKey(t, 1234567890, now, tt.KeySize) - - keycache.EXPECT().Signing(ctx).Return(key, nil) - keycache.EXPECT().Verifying(ctx, key.Sequence).Return(key, nil) - - claims := testClaims{ - Claims: jwt.Claims{ - Issuer: "coder", - Subject: "user@coder.com", - Audience: jwt.Audience{"coder"}, - Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), - IssuedAt: jwt.NewNumericDate(time.Now()), - NotBefore: jwt.NewNumericDate(time.Now()), - }, - MyClaim: "my_value", - } - - token, err := tt.SignFn(ctx, keycache, claims) - require.NoError(t, err) - - var actual testClaims - err = tt.VerifyFn(ctx, keycache, token, &actual, withExpected(jwt.Expected{ - Issuer: "coder2", - })) - require.ErrorIs(t, err, jwt.ErrInvalidIssuer) - }) - - t.Run("WrongSubject", func(t *testing.T) { - t.Parallel() - - var ( - ctx = testutil.Context(t, testutil.WaitShort) - ctrl = gomock.NewController(t) - keycache = cryptokeys.NewMockKeycache(ctrl) - now = time.Now() - ) - - key := generateCryptoKey(t, 1234567890, now, tt.KeySize) - - keycache.EXPECT().Signing(ctx).Return(key, nil) - keycache.EXPECT().Verifying(ctx, key.Sequence).Return(key, nil) - - claims := testClaims{ - Claims: jwt.Claims{ - Issuer: "coder", - Subject: "user@coder.com", - Audience: jwt.Audience{"coder"}, - Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), - IssuedAt: jwt.NewNumericDate(time.Now()), - NotBefore: jwt.NewNumericDate(time.Now()), - }, - MyClaim: "my_value", - } - - token, err := tt.SignFn(ctx, keycache, claims) - require.NoError(t, err) - - var actual testClaims - err = tt.VerifyFn(ctx, keycache, token, &actual, withExpected(jwt.Expected{ - Subject: "user2@coder.com", - })) - require.ErrorIs(t, err, jwt.ErrInvalidSubject) - }) - - t.Run("WrongAudience", func(t *testing.T) { - t.Parallel() - - var ( - ctx = testutil.Context(t, testutil.WaitShort) - ctrl = gomock.NewController(t) - keycache = cryptokeys.NewMockKeycache(ctrl) - now = time.Now() - key = generateCryptoKey(t, 1234567890, now, tt.KeySize) - ) - - keycache.EXPECT().Signing(ctx).Return(key, nil) - keycache.EXPECT().Verifying(ctx, key.Sequence).Return(key, nil) - - claims := testClaims{ - Claims: jwt.Claims{ - Issuer: "coder", - Subject: "user@coder.com", - Audience: jwt.Audience{"coder"}, - Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), - IssuedAt: jwt.NewNumericDate(time.Now()), - NotBefore: jwt.NewNumericDate(time.Now()), - }, - MyClaim: "my_value", - } - - token, err := tt.SignFn(ctx, keycache, claims) - require.NoError(t, err) - - var actual testClaims - err = tt.VerifyFn(ctx, keycache, token, &actual, withExpected(jwt.Expected{ - AnyAudience: jwt.Audience{"coder2"}, - })) - require.ErrorIs(t, err, jwt.ErrInvalidAudience) - }) - - t.Run("Expired", func(t *testing.T) { - t.Parallel() - - var ( - ctx = testutil.Context(t, testutil.WaitShort) - ctrl = gomock.NewController(t) - keycache = cryptokeys.NewMockKeycache(ctrl) - now = time.Now() - key = generateCryptoKey(t, 1234567890, now, tt.KeySize) - ) - - keycache.EXPECT().Signing(ctx).Return(key, nil) - keycache.EXPECT().Verifying(ctx, key.Sequence).Return(key, nil) - - claims := testClaims{ - Claims: jwt.Claims{ - Issuer: "coder", - Subject: "user@coder.com", - Audience: jwt.Audience{"coder"}, - Expiry: jwt.NewNumericDate(time.Now().Add(time.Minute)), - IssuedAt: jwt.NewNumericDate(time.Now()), - NotBefore: jwt.NewNumericDate(time.Now()), - }, - MyClaim: "my_value", - } - - token, err := tt.SignFn(ctx, keycache, claims) - require.NoError(t, err) - - var actual testClaims - err = tt.VerifyFn(ctx, keycache, token, &actual, withExpected(jwt.Expected{ - Time: time.Now().Add(time.Minute * 3), - })) - require.ErrorIs(t, err, jwt.ErrExpired) - }) - - t.Run("IssuedInFuture", func(t *testing.T) { - t.Parallel() - - var ( - ctx = testutil.Context(t, testutil.WaitShort) - ctrl = gomock.NewController(t) - keycache = cryptokeys.NewMockKeycache(ctrl) - now = time.Now() - ) - - key := generateCryptoKey(t, 1234567890, now, tt.KeySize) - - keycache.EXPECT().Signing(ctx).Return(key, nil) - keycache.EXPECT().Verifying(ctx, key.Sequence).Return(key, nil) - - claims := testClaims{ - Claims: jwt.Claims{ - Issuer: "coder", - Subject: "user@coder.com", - Audience: jwt.Audience{"coder"}, - Expiry: jwt.NewNumericDate(time.Now().Add(time.Minute)), - IssuedAt: jwt.NewNumericDate(time.Now()), - }, - MyClaim: "my_value", - } - - token, err := tt.SignFn(ctx, keycache, claims) - require.NoError(t, err) - - var actual testClaims - err = tt.VerifyFn(ctx, keycache, token, &actual, withExpected(jwt.Expected{ - Time: time.Now().Add(-time.Minute * 3), - })) - require.ErrorIs(t, err, jwt.ErrIssuedInTheFuture) - }) - - t.Run("IsBefore", func(t *testing.T) { - t.Parallel() - - var ( - ctx = testutil.Context(t, testutil.WaitShort) - ctrl = gomock.NewController(t) - keycache = cryptokeys.NewMockKeycache(ctrl) - now = time.Now() - ) - - key := generateCryptoKey(t, 1234567890, now, tt.KeySize) - - keycache.EXPECT().Signing(ctx).Return(key, nil) - keycache.EXPECT().Verifying(ctx, key.Sequence).Return(key, nil) - - claims := testClaims{ - Claims: jwt.Claims{ - Issuer: "coder", - Subject: "user@coder.com", - Audience: jwt.Audience{"coder"}, - Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), - IssuedAt: jwt.NewNumericDate(time.Now()), - NotBefore: jwt.NewNumericDate(time.Now().Add(time.Minute * 5)), - }, - MyClaim: "my_value", - } - - token, err := tt.SignFn(ctx, keycache, claims) - require.NoError(t, err) - - var actual testClaims - err = tt.VerifyFn(ctx, keycache, token, &actual, withExpected(jwt.Expected{ - Time: time.Now().Add(time.Minute * 3), - })) - require.ErrorIs(t, err, jwt.ErrNotValidYet) - }) - - t.Run("WrongSignatureAlgorithm", func(t *testing.T) { - t.Parallel() - - if tt.Name == "JWE" { - t.Skip("JWE does not support this") - } - - var ( - ctx = testutil.Context(t, testutil.WaitShort) - ctrl = gomock.NewController(t) - keycache = cryptokeys.NewMockKeycache(ctrl) - now = time.Now() - ) - - key := generateCryptoKey(t, 1234567890, now, tt.KeySize) - - keycache.EXPECT().Signing(ctx).Return(key, nil) - - claims := testClaims{ - Claims: jwt.Claims{ - Issuer: "coder", - Subject: "user@coder.com", - Audience: jwt.Audience{"coder"}, - Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), - IssuedAt: jwt.NewNumericDate(time.Now()), - NotBefore: jwt.NewNumericDate(time.Now().Add(time.Minute * 5)), - }, - MyClaim: "my_value", - } - - token, err := tt.SignFn(ctx, keycache, claims) - require.NoError(t, err) - - var actual testClaims - err = tt.VerifyFn(ctx, keycache, token, &actual, withSignatureAlgorithm(jose.HS256)) - require.Error(t, err) - }) - - t.Run("WrongKeyAlgorithm", func(t *testing.T) { - t.Parallel() - - if tt.Name == "JWS" { - t.Skip("JWS does not support this") - } - - var ( - ctx = testutil.Context(t, testutil.WaitShort) - ctrl = gomock.NewController(t) - keycache = cryptokeys.NewMockKeycache(ctrl) - now = time.Now() - ) - - key := generateCryptoKey(t, 1234567890, now, tt.KeySize) - - keycache.EXPECT().Signing(ctx).Return(key, nil) - - claims := testClaims{ - Claims: jwt.Claims{ - Issuer: "coder", - Subject: "user@coder.com", - Audience: jwt.Audience{"coder"}, - Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), - IssuedAt: jwt.NewNumericDate(time.Now()), - NotBefore: jwt.NewNumericDate(time.Now().Add(time.Minute * 5)), - }, - MyClaim: "my_value", - } - - token, err := tt.SignFn(ctx, keycache, claims) - require.NoError(t, err) - - var actual testClaims - err = tt.VerifyFn(ctx, keycache, token, &actual, withKeyAlgorithm(jose.A128GCMKW)) - require.Error(t, err) - }) - - t.Run("WrongContentyEncryption", func(t *testing.T) { - t.Parallel() - - if tt.Name == "JWS" { - t.Skip("JWS does not support this") - } - - var ( - ctx = testutil.Context(t, testutil.WaitShort) - ctrl = gomock.NewController(t) - keycache = cryptokeys.NewMockKeycache(ctrl) - now = time.Now() - ) - - key := generateCryptoKey(t, 1234567890, now, tt.KeySize) - - keycache.EXPECT().Signing(gomock.Any()).Return(key, nil) - - claims := testClaims{ - Claims: jwt.Claims{ - Issuer: "coder", - Subject: "user@coder.com", - Audience: jwt.Audience{"coder"}, - Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), - IssuedAt: jwt.NewNumericDate(time.Now()), - NotBefore: jwt.NewNumericDate(time.Now().Add(time.Minute * 5)), - }, - MyClaim: "my_value", - } - - token, err := tt.SignFn(ctx, keycache, claims) - require.NoError(t, err) - - var actual testClaims - err = tt.VerifyFn(ctx, keycache, token, &actual, withContentEncryptionAlgorithm(jose.A128GCM)) - require.Error(t, err) - }) + require.Equal(t, c.claims, actual) + }) + } }) } } +func TestJWS(t *testing.T) { + t.Run("WrongSignatureAlgorithm", func(t *testing.T) { + t.Parallel() + + var ( + ctx = testutil.Context(t, testutil.WaitShort) + ) + + key := newKey(t, 64) + + token, err := jwtutils.Sign(ctx, key, jwt.Claims{}) + require.NoError(t, err) + + var actual testClaims + err = jwtutils.Verify(ctx, key, token, &actual, withSignatureAlgorithm(jose.HS256)) + require.Error(t, err) + + }) +} + +func TestJWE(t *testing.T) { + t.Run("WrongKeyAlgorithm", func(t *testing.T) { + t.Parallel() + + var ( + ctx = testutil.Context(t, testutil.WaitShort) + key = newKey(t, 32) + ) + + token, err := jwtutils.Encrypt(ctx, key, jwt.Claims{}) + require.NoError(t, err) + + var actual testClaims + err = jwtutils.Decrypt(ctx, key, token, &actual, withKeyAlgorithm(jose.A128GCMKW)) + require.Error(t, err) + }) + + t.Run("WrongContentyEncryption", func(t *testing.T) { + t.Parallel() + + var ( + ctx = testutil.Context(t, testutil.WaitShort) + key = newKey(t, 32) + ) + + token, err := jwtutils.Encrypt(ctx, key, jwt.Claims{}) + require.NoError(t, err) + + var actual testClaims + err = jwtutils.Decrypt(ctx, key, token, &actual, withContentEncryptionAlgorithm(jose.A128GCM)) + require.Error(t, err) + }) +} + func generateCryptoKey(t *testing.T, seq int32, now time.Time, keySize int) codersdk.CryptoKey { t.Helper() @@ -453,54 +264,67 @@ type testClaims struct { jwt.Claims } -func withExpected(e jwt.Expected) func(*jwtutils.ParseOptions) { - return func(opts *jwtutils.ParseOptions) { +func withExpected(e jwt.Expected) func(*jwtutils.VerifyOptions) { + return func(opts *jwtutils.VerifyOptions) { opts.RegisteredClaims = e } } -func withSignatureAlgorithm(alg jose.SignatureAlgorithm) func(*jwtutils.ParseOptions) { - return func(opts *jwtutils.ParseOptions) { +func withSignatureAlgorithm(alg jose.SignatureAlgorithm) func(*jwtutils.VerifyOptions) { + return func(opts *jwtutils.VerifyOptions) { opts.SignatureAlgorithm = alg } } -func withKeyAlgorithm(alg jose.KeyAlgorithm) func(*jwtutils.ParseOptions) { - return func(opts *jwtutils.ParseOptions) { +func withKeyAlgorithm(alg jose.KeyAlgorithm) func(*jwtutils.DecryptOptions) { + return func(opts *jwtutils.DecryptOptions) { opts.KeyAlgorithm = alg } } -func withContentEncryptionAlgorithm(alg jose.ContentEncryption) func(*jwtutils.ParseOptions) { - return func(opts *jwtutils.ParseOptions) { +func withContentEncryptionAlgorithm(alg jose.ContentEncryption) func(*jwtutils.DecryptOptions) { + return func(opts *jwtutils.DecryptOptions) { opts.ContentEncryptionAlgorithm = alg } } -type godkey interface { - jwtutils.SignKeyer - jwtutils.VerifyKeyer - jwtutils.EncryptKeyer - jwtutils.DecryptKeyer +type key struct { + t testing.TB + id string + secret []byte } -type key struct { - signFn func(context.Context) (string, interface{}, error) - verifyFn func(context.Context, string) (interface{}, error) +func newKey(t *testing.T, size int) *key { + t.Helper() + + id := uuid.New().String() + secret := generateSecret(t, size) + + return &key{ + t: t, + id: id, + secret: secret, + } } -func (k *key) SigningKey(ctx context.Context) (string, interface{}, error) { - return k.signFn(ctx) +func (k *key) SigningKey(ctx context.Context) (id string, key interface{}, err error) { + return k.id, k.secret, nil } -func (k *key) VerifyingKey(ctx context.Context, id string) (interface{}, error) { - return k.verifyFn(ctx, id) +func (k *key) VerifyingKey(ctx context.Context, id string) (key interface{}, err error) { + k.t.Helper() + + require.Equal(k.t, k.id, id) + return k.secret, nil } -func (k *key) EncryptingKey(ctx context.Context) (string, interface{}, error) { - return k.signFn(ctx) +func (k *key) EncryptingKey(ctx context.Context) (id string, key interface{}, err error) { + return k.id, k.secret, nil } -func (k *key) DecryptingKey(ctx context.Context, id string) (interface{}, error) { - return k.verifyFn(ctx, id) +func (k *key) DecryptingKey(ctx context.Context, id string) (key interface{}, err error) { + k.t.Helper() + + require.Equal(k.t, k.id, id) + return k.secret, nil } From e529c4abf10bacbc7bf46b26c8bdba2038e89dc8 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Thu, 3 Oct 2024 08:06:10 +0000 Subject: [PATCH 16/25] Enhance key generation and JWT error messages Increase OIDC secret length for added security and clarity. Optimize JWT test coverage with custom claims and key cache tests. --- coderd/cryptokeys/rotate.go | 4 +- coderd/database/dbgen/dbgen.go | 4 +- coderd/jwtutils/jwe.go | 8 +- coderd/jwtutils/jws.go | 14 +--- coderd/jwtutils/jwt_test.go | 138 +++++++++++++++++++++++++++++++-- 5 files changed, 143 insertions(+), 25 deletions(-) diff --git a/coderd/cryptokeys/rotate.go b/coderd/cryptokeys/rotate.go index 224b9100d5bf8..14a623e2156db 100644 --- a/coderd/cryptokeys/rotate.go +++ b/coderd/cryptokeys/rotate.go @@ -227,9 +227,9 @@ func (k *rotator) rotateKey(ctx context.Context, tx database.Store, key database func generateNewSecret(feature database.CryptoKeyFeature) (string, error) { switch feature { case database.CryptoKeyFeatureWorkspaceApps: - return generateKey(96) - case database.CryptoKeyFeatureOidcConvert: return generateKey(32) + case database.CryptoKeyFeatureOidcConvert: + return generateKey(64) case database.CryptoKeyFeatureTailnetResume: return generateKey(64) } diff --git a/coderd/database/dbgen/dbgen.go b/coderd/database/dbgen/dbgen.go index 1a2f052a279b3..93439fd0f2b77 100644 --- a/coderd/database/dbgen/dbgen.go +++ b/coderd/database/dbgen/dbgen.go @@ -988,9 +988,9 @@ func takeFirst[Value comparable](values ...Value) Value { func newCryptoKeySecret(feature database.CryptoKeyFeature) (string, error) { switch feature { case database.CryptoKeyFeatureWorkspaceApps: - return generateCryptoKey(96) - case database.CryptoKeyFeatureOidcConvert: return generateCryptoKey(32) + case database.CryptoKeyFeatureOidcConvert: + return generateCryptoKey(64) case database.CryptoKeyFeatureTailnetResume: return generateCryptoKey(64) } diff --git a/coderd/jwtutils/jwe.go b/coderd/jwtutils/jwe.go index 8dd33933b5bc1..0621d7d95696f 100644 --- a/coderd/jwtutils/jwe.go +++ b/coderd/jwtutils/jwe.go @@ -102,14 +102,14 @@ func Decrypt(ctx context.Context, d DecryptKeyer, token string, claims Claims, o return xerrors.Errorf("expected JWE algorithm to be %q, got %q", encryptKeyAlgo, object.Header.Algorithm) } - sequenceStr := object.Header.KeyID - if sequenceStr == "" { + kid := object.Header.KeyID + if kid == "" { return xerrors.Errorf("expected %q header to be a string", keyIDHeaderKey) } - key, err := d.DecryptingKey(ctx, sequenceStr) + key, err := d.DecryptingKey(ctx, kid) if err != nil { - return xerrors.Errorf("version: %w", err) + return xerrors.Errorf("key with id %q: %w", kid, err) } decrypted, err := object.Decrypt(key) diff --git a/coderd/jwtutils/jws.go b/coderd/jwtutils/jws.go index 432e93a8a78ab..ac369eaaa379d 100644 --- a/coderd/jwtutils/jws.go +++ b/coderd/jwtutils/jws.go @@ -3,7 +3,6 @@ package jwtutils import ( "context" "encoding/json" - "strconv" "time" "github.com/go-jose/go-jose/v4" @@ -106,19 +105,14 @@ func Verify(ctx context.Context, v VerifyKeyer, token string, claims Claims, opt return xerrors.Errorf("expected JWS algorithm to be %q, got %q", signingAlgo, object.Signatures[0].Header.Algorithm) } - sequenceStr := signature.Header.KeyID - if sequenceStr == "" { + kid := signature.Header.KeyID + if kid == "" { return xerrors.Errorf("expected %q header to be a string", keyIDHeaderKey) } - sequence, err := strconv.ParseInt(sequenceStr, 10, 32) + key, err := v.VerifyingKey(ctx, kid) if err != nil { - return xerrors.Errorf("parse sequence %q: %w", sequenceStr, err) - } - - key, err := v.VerifyingKey(ctx, sequenceStr) - if err != nil { - return xerrors.Errorf("verifying key for seq %v: %w", sequence, err) + return xerrors.Errorf("key with id %q: %w", kid, err) } payload, err := object.Verify(key) diff --git a/coderd/jwtutils/jwt_test.go b/coderd/jwtutils/jwt_test.go index 0496155a547f0..59e23d4ff4064 100644 --- a/coderd/jwtutils/jwt_test.go +++ b/coderd/jwtutils/jwt_test.go @@ -12,6 +12,12 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/require" + "cdr.dev/slog/sloggers/slogtest" + + "github.com/coder/coder/v2/coderd/cryptokeys" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/jwtutils" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/testutil" @@ -41,7 +47,7 @@ func TestClaims(t *testing.T) { type testcase struct { name string - claims jwt.Claims + claims jwtutils.Claims expectedClaims jwt.Expected expectedErr error } @@ -169,14 +175,18 @@ func TestClaims(t *testing.T) { } require.NoError(t, err) - var actual testClaims + var actual jwt.Claims if tt.Sign { - err = jwtutils.Verify(ctx, key, token, &actual) + err = jwtutils.Verify(ctx, key, token, &actual, withVerifyExpected(c.expectedClaims)) } else { - err = jwtutils.Decrypt(ctx, key, token, &actual) + err = jwtutils.Decrypt(ctx, key, token, &actual, withDecryptExpected(c.expectedClaims)) + } + if c.expectedErr != nil { + require.ErrorIs(t, err, c.expectedErr) + } else { + require.NoError(t, err) + require.Equal(t, c.claims, actual) } - require.NoError(t, err) - require.Equal(t, c.claims, actual) }) } }) @@ -184,6 +194,7 @@ func TestClaims(t *testing.T) { } func TestJWS(t *testing.T) { + t.Parallel() t.Run("WrongSignatureAlgorithm", func(t *testing.T) { t.Parallel() @@ -199,11 +210,65 @@ func TestJWS(t *testing.T) { var actual testClaims err = jwtutils.Verify(ctx, key, token, &actual, withSignatureAlgorithm(jose.HS256)) require.Error(t, err) + }) + + t.Run("CustomClaims", func(t *testing.T) { + t.Parallel() + + var ( + ctx = testutil.Context(t, testutil.WaitShort) + key = newKey(t, 64) + ) + + expected := testClaims{ + MyClaim: "my_value", + } + token, err := jwtutils.Sign(ctx, key, expected) + require.NoError(t, err) + + var actual testClaims + err = jwtutils.Verify(ctx, key, token, &actual, withVerifyExpected(jwt.Expected{})) + require.NoError(t, err) + require.Equal(t, expected, actual) + }) + + t.Run("WithKeycache", func(t *testing.T) { + t.Parallel() + var ( + ctx = testutil.Context(t, testutil.WaitShort) + db, _ = dbtestutil.NewDB(t) + _ = dbgen.CryptoKey(t, db, database.CryptoKey{ + Feature: database.CryptoKeyFeatureOidcConvert, + StartsAt: time.Now(), + }) + log = slogtest.Make(t, nil) + ) + + cache, err := cryptokeys.NewSigningCache(log, db, database.CryptoKeyFeatureOidcConvert) + require.NoError(t, err) + + claims := testClaims{ + MyClaim: "my_value", + Claims: jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + }, + } + + token, err := jwtutils.Sign(ctx, cache, claims) + require.NoError(t, err) + + var actual testClaims + err = jwtutils.Verify(ctx, cache, token, &actual) + require.NoError(t, err) + require.Equal(t, claims, actual) }) + } func TestJWE(t *testing.T) { + t.Parallel() + t.Run("WrongKeyAlgorithm", func(t *testing.T) { t.Parallel() @@ -235,6 +300,59 @@ func TestJWE(t *testing.T) { err = jwtutils.Decrypt(ctx, key, token, &actual, withContentEncryptionAlgorithm(jose.A128GCM)) require.Error(t, err) }) + + t.Run("CustomClaims", func(t *testing.T) { + t.Parallel() + + var ( + ctx = testutil.Context(t, testutil.WaitShort) + key = newKey(t, 32) + ) + + expected := testClaims{ + MyClaim: "my_value", + } + + token, err := jwtutils.Encrypt(ctx, key, expected) + require.NoError(t, err) + + var actual testClaims + err = jwtutils.Decrypt(ctx, key, token, &actual, withDecryptExpected(jwt.Expected{})) + require.NoError(t, err) + require.Equal(t, expected, actual) + }) + + t.Run("WithKeycache", func(t *testing.T) { + t.Parallel() + + var ( + ctx = testutil.Context(t, testutil.WaitShort) + db, _ = dbtestutil.NewDB(t) + _ = dbgen.CryptoKey(t, db, database.CryptoKey{ + Feature: database.CryptoKeyFeatureWorkspaceApps, + StartsAt: time.Now(), + }) + log = slogtest.Make(t, nil) + ) + + cache, err := cryptokeys.NewEncryptionCache(log, db, database.CryptoKeyFeatureWorkspaceApps) + require.NoError(t, err) + + claims := testClaims{ + MyClaim: "my_value", + Claims: jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + }, + } + + token, err := jwtutils.Encrypt(ctx, cache, claims) + require.NoError(t, err) + + var actual testClaims + err = jwtutils.Decrypt(ctx, cache, token, &actual) + require.NoError(t, err) + require.Equal(t, claims, actual) + }) } func generateCryptoKey(t *testing.T, seq int32, now time.Time, keySize int) codersdk.CryptoKey { @@ -264,7 +382,13 @@ type testClaims struct { jwt.Claims } -func withExpected(e jwt.Expected) func(*jwtutils.VerifyOptions) { +func withDecryptExpected(e jwt.Expected) func(*jwtutils.DecryptOptions) { + return func(opts *jwtutils.DecryptOptions) { + opts.RegisteredClaims = e + } +} + +func withVerifyExpected(e jwt.Expected) func(*jwtutils.VerifyOptions) { return func(opts *jwtutils.VerifyOptions) { opts.RegisteredClaims = e } From 437e587e36724a07c7e8afb91599f652e76bcefa Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Thu, 3 Oct 2024 08:07:00 +0000 Subject: [PATCH 17/25] Update cryptographic key length requirements --- coderd/cryptokeys/rotate_internal_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/coderd/cryptokeys/rotate_internal_test.go b/coderd/cryptokeys/rotate_internal_test.go index 36ecf4fa9d76d..43754c1d8750f 100644 --- a/coderd/cryptokeys/rotate_internal_test.go +++ b/coderd/cryptokeys/rotate_internal_test.go @@ -588,9 +588,9 @@ func requireKey(t *testing.T, key database.CryptoKey, feature database.CryptoKey switch key.Feature { case database.CryptoKeyFeatureOidcConvert: - require.Len(t, secret, 32) + require.Len(t, secret, 64) case database.CryptoKeyFeatureWorkspaceApps: - require.Len(t, secret, 96) + require.Len(t, secret, 32) case database.CryptoKeyFeatureTailnetResume: require.Len(t, secret, 64) default: From 54214e29aedc86d0fa968399fa47c34defba038d Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Thu, 3 Oct 2024 08:10:24 +0000 Subject: [PATCH 18/25] Refactor key provider interfaces in JWT utilities --- coderd/jwtutils/jwe.go | 8 ++++---- coderd/jwtutils/jws.go | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/coderd/jwtutils/jwe.go b/coderd/jwtutils/jwe.go index 0621d7d95696f..35c7aa2cdbd14 100644 --- a/coderd/jwtutils/jwe.go +++ b/coderd/jwtutils/jwe.go @@ -16,16 +16,16 @@ const ( encryptContentAlgo = jose.A256GCM ) -type EncryptKeyer interface { +type EncryptKeyProvider interface { EncryptingKey(ctx context.Context) (id string, key interface{}, err error) } -type DecryptKeyer interface { +type DecryptKeyProvider interface { DecryptingKey(ctx context.Context, id string) (key interface{}, err error) } // Encrypt encrypts a token and returns it as a string. -func Encrypt(ctx context.Context, e EncryptKeyer, claims Claims) (string, error) { +func Encrypt(ctx context.Context, e EncryptKeyProvider, claims Claims) (string, error) { id, key, err := e.EncryptingKey(ctx) if err != nil { return "", xerrors.Errorf("get signing key: %w", err) @@ -72,7 +72,7 @@ type DecryptOptions struct { } // Decrypt decrypts the token using the provided key. It unmarshals into the provided claims. -func Decrypt(ctx context.Context, d DecryptKeyer, token string, claims Claims, opts ...func(*DecryptOptions)) error { +func Decrypt(ctx context.Context, d DecryptKeyProvider, token string, claims Claims, opts ...func(*DecryptOptions)) error { options := DecryptOptions{ RegisteredClaims: jwt.Expected{ Time: time.Now(), diff --git a/coderd/jwtutils/jws.go b/coderd/jwtutils/jws.go index ac369eaaa379d..cb04a1273bb97 100644 --- a/coderd/jwtutils/jws.go +++ b/coderd/jwtutils/jws.go @@ -24,16 +24,16 @@ const ( signingAlgo = jose.HS512 ) -type SignKeyer interface { +type SigningKeyProvider interface { SigningKey(ctx context.Context) (id string, key interface{}, err error) } -type VerifyKeyer interface { +type VerifyKeyProvider interface { VerifyingKey(ctx context.Context, id string) (key interface{}, err error) } // Sign signs a token and returns it as a string. -func Sign(ctx context.Context, s SignKeyer, claims Claims) (string, error) { +func Sign(ctx context.Context, s SigningKeyProvider, claims Claims) (string, error) { id, key, err := s.SigningKey(ctx) if err != nil { return "", xerrors.Errorf("get signing key: %w", err) @@ -78,7 +78,7 @@ type VerifyOptions struct { } // Verify verifies that a token was signed by the provided key. It unmarshals into the provided claims. -func Verify(ctx context.Context, v VerifyKeyer, token string, claims Claims, opts ...func(*VerifyOptions)) error { +func Verify(ctx context.Context, v VerifyKeyProvider, token string, claims Claims, opts ...func(*VerifyOptions)) error { options := VerifyOptions{ RegisteredClaims: jwt.Expected{ Time: time.Now(), From 93603a2a393ae8e6c62e6f241ce42a810aa76ac2 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Thu, 3 Oct 2024 08:12:17 +0000 Subject: [PATCH 19/25] Refactor dbCache to remove feature validation --- coderd/cryptokeys/dbkeycache.go | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/coderd/cryptokeys/dbkeycache.go b/coderd/cryptokeys/dbkeycache.go index e6b9952ade90d..14f16f682f5af 100644 --- a/coderd/cryptokeys/dbkeycache.go +++ b/coderd/cryptokeys/dbkeycache.go @@ -77,33 +77,18 @@ func newDBCache(logger slog.Logger, db database.Store, feature database.CryptoKe } func (d *dbCache) EncryptingKey(ctx context.Context) (id string, key interface{}, err error) { - if !isEncryptionKeyFeature(d.feature) { - return "", nil, xerrors.Errorf("invalid feature: %s", d.feature) - } return d.latest(ctx) } func (d *dbCache) DecryptingKey(ctx context.Context, id string) (key interface{}, err error) { - if !isEncryptionKeyFeature(d.feature) { - return nil, xerrors.Errorf("invalid feature: %s", d.feature) - } - return d.sequence(ctx, id) } func (d *dbCache) SigningKey(ctx context.Context) (id string, key interface{}, err error) { - if !isSigningKeyFeature(d.feature) { - return "", nil, xerrors.Errorf("invalid feature: %s", d.feature) - } - return d.latest(ctx) } func (d *dbCache) VerifyingKey(ctx context.Context, id string) (key interface{}, err error) { - if !isSigningKeyFeature(d.feature) { - return nil, xerrors.Errorf("invalid feature: %s", d.feature) - } - return d.sequence(ctx, id) } From e654a6552ca36914b689228eb1b4b07041f90a83 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Thu, 3 Oct 2024 22:30:07 +0000 Subject: [PATCH 20/25] Refactor cryptokeys and jwtutils interfaces and logic - Enhance comments for key interfaces to clarify usage and considerations for time validity and clock skew. - Refactor JWE/JWS logic to simplify serialization and deserialization processes, ensuring more efficient and concise handling of JWTs. Implement compact serialization and remove unnecessary base64 encoding. --- coderd/cryptokeys/keycache.go | 14 ++++++++++++++ coderd/jwtutils/jwe.go | 20 ++++++++------------ coderd/jwtutils/jws.go | 4 +--- 3 files changed, 23 insertions(+), 15 deletions(-) diff --git a/coderd/cryptokeys/keycache.go b/coderd/cryptokeys/keycache.go index a5e5d087a5ee2..05c80a15b2378 100644 --- a/coderd/cryptokeys/keycache.go +++ b/coderd/cryptokeys/keycache.go @@ -15,13 +15,27 @@ var ( ) type EncryptionKeycache interface { + // EncryptingKey returns the latest valid key for encrypting payloads. A valid + // key is one that is both past its start time and before its deletion time. EncryptingKey(ctx context.Context) (id string, key interface{}, err error) + // DecryptingKey returns the key with the provided id which maps to its sequence + // number. The key is valid for decryption as long as it is not deleted or past + // its deletion date. We must allow for keys prior to their start time to + // account for clock skew between peers (one key may be past its start time on + // one machine while another is not). DecryptingKey(ctx context.Context, id string) (key interface{}, err error) io.Closer } type SigningKeycache interface { + // SigningKey returns the latest valid key for signing. A valid key is one + // that is both past its start time and before its deletion time. SigningKey(ctx context.Context) (id string, key interface{}, err error) + // VerifyingKey returns the key with the provided id which should map to its + // sequence number. The key is valid for verifying as long as it is not deleted + // or past its deletion date. We must allow for keys prior to their start time + // to account for clock skew between peers (one key may be past its start time + // on one machine while another is not). VerifyingKey(ctx context.Context, id string) (key interface{}, err error) io.Closer } diff --git a/coderd/jwtutils/jwe.go b/coderd/jwtutils/jwe.go index 35c7aa2cdbd14..f50cacb62de7c 100644 --- a/coderd/jwtutils/jwe.go +++ b/coderd/jwtutils/jwe.go @@ -2,7 +2,6 @@ package jwtutils import ( "context" - "encoding/base64" "encoding/json" "time" @@ -58,15 +57,17 @@ func Encrypt(ctx context.Context, e EncryptKeyProvider, claims Claims) (string, return "", xerrors.Errorf("encrypt: %w", err) } - serialized := []byte(encrypted.FullSerialize()) - return base64.RawURLEncoding.EncodeToString(serialized), nil + compact, err := encrypted.CompactSerialize() + if err != nil { + return "", xerrors.Errorf("compact serialize: %w", err) + } + + return compact, nil } // DecryptOptions are options for decrypting a JWE. type DecryptOptions struct { - RegisteredClaims jwt.Expected - - // The following should only be used for JWEs. + RegisteredClaims jwt.Expected KeyAlgorithm jose.KeyAlgorithm ContentEncryptionAlgorithm jose.ContentEncryption } @@ -85,12 +86,7 @@ func Decrypt(ctx context.Context, d DecryptKeyProvider, token string, claims Cla opt(&options) } - encrypted, err := base64.RawURLEncoding.DecodeString(token) - if err != nil { - return xerrors.Errorf("decode: %w", err) - } - - object, err := jose.ParseEncrypted(string(encrypted), + object, err := jose.ParseEncrypted(token, []jose.KeyAlgorithm{options.KeyAlgorithm}, []jose.ContentEncryption{options.ContentEncryptionAlgorithm}, ) diff --git a/coderd/jwtutils/jws.go b/coderd/jwtutils/jws.go index cb04a1273bb97..73f35e672492d 100644 --- a/coderd/jwtutils/jws.go +++ b/coderd/jwtutils/jws.go @@ -71,9 +71,7 @@ func Sign(ctx context.Context, s SigningKeyProvider, claims Claims) (string, err // VerifyOptions are options for verifying a JWT. type VerifyOptions struct { - RegisteredClaims jwt.Expected - - // The following are only used for JWSs. + RegisteredClaims jwt.Expected SignatureAlgorithm jose.SignatureAlgorithm } From 0efabfd277a6f5768162e718e2c1ffe036804216 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Thu, 3 Oct 2024 22:36:26 +0000 Subject: [PATCH 21/25] Remove unused test code and mock cleanup Removing unused test and mock code simplifies maintenance and reduces clutter. --- coderd/cryptokeys/dbkeycache_test.go | 2 - coderd/cryptokeys/doc.go | 2 - coderd/cryptokeys/keycachemock.go | 71 ---------------------------- coderd/jwtutils/jwt_test.go | 23 ++------- 4 files changed, 4 insertions(+), 94 deletions(-) delete mode 100644 coderd/cryptokeys/keycachemock.go diff --git a/coderd/cryptokeys/dbkeycache_test.go b/coderd/cryptokeys/dbkeycache_test.go index c421eeaf4b86f..3e0a8fcc033c2 100644 --- a/coderd/cryptokeys/dbkeycache_test.go +++ b/coderd/cryptokeys/dbkeycache_test.go @@ -69,7 +69,6 @@ func TestDBKeyCache(t *testing.T) { _, err = k.VerifyingKey(ctx, "123") require.ErrorIs(t, err, cryptokeys.ErrKeyNotFound) }) - }) t.Run("Signing", func(t *testing.T) { @@ -173,7 +172,6 @@ func TestDBKeyCache(t *testing.T) { _, err := cryptokeys.NewEncryptionCache(logger, db, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock)) require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature) }) - } func keyID(key database.CryptoKey) string { diff --git a/coderd/cryptokeys/doc.go b/coderd/cryptokeys/doc.go index 8cee81c28bd69..b2494f9f0da8d 100644 --- a/coderd/cryptokeys/doc.go +++ b/coderd/cryptokeys/doc.go @@ -1,4 +1,2 @@ // Package cryptokeys provides an abstraction for fetching internally used cryptographic keys mainly for JWT signing and verification. package cryptokeys - -//go:generate mockgen -destination keycachemock.go -package cryptokeys . Keycache diff --git a/coderd/cryptokeys/keycachemock.go b/coderd/cryptokeys/keycachemock.go deleted file mode 100644 index 7a7b2e5b0ca13..0000000000000 --- a/coderd/cryptokeys/keycachemock.go +++ /dev/null @@ -1,71 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/coder/coder/v2/coderd/cryptokeys (interfaces: Keycache) -// -// Generated by this command: -// -// mockgen -destination keycachemock.go -package cryptokeys . Keycache -// - -// Package cryptokeys is a generated GoMock package. -package cryptokeys - -import ( - context "context" - reflect "reflect" - - codersdk "github.com/coder/coder/v2/codersdk" - gomock "go.uber.org/mock/gomock" -) - -// MockKeycache is a mock of Keycache interface. -type MockKeycache struct { - ctrl *gomock.Controller - recorder *MockKeycacheMockRecorder -} - -// MockKeycacheMockRecorder is the mock recorder for MockKeycache. -type MockKeycacheMockRecorder struct { - mock *MockKeycache -} - -// NewMockKeycache creates a new mock instance. -func NewMockKeycache(ctrl *gomock.Controller) *MockKeycache { - mock := &MockKeycache{ctrl: ctrl} - mock.recorder = &MockKeycacheMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockKeycache) EXPECT() *MockKeycacheMockRecorder { - return m.recorder -} - -// Signing mocks base method. -func (m *MockKeycache) Signing(arg0 context.Context) (codersdk.CryptoKey, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Signing", arg0) - ret0, _ := ret[0].(codersdk.CryptoKey) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Signing indicates an expected call of Signing. -func (mr *MockKeycacheMockRecorder) Signing(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Signing", reflect.TypeOf((*MockKeycache)(nil).Signing), arg0) -} - -// Verifying mocks base method. -func (m *MockKeycache) Verifying(arg0 context.Context, arg1 int32) (codersdk.CryptoKey, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Verifying", arg0, arg1) - ret0, _ := ret[0].(codersdk.CryptoKey) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Verifying indicates an expected call of Verifying. -func (mr *MockKeycacheMockRecorder) Verifying(arg0, arg1 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Verifying", reflect.TypeOf((*MockKeycache)(nil).Verifying), arg0, arg1) -} diff --git a/coderd/jwtutils/jwt_test.go b/coderd/jwtutils/jwt_test.go index 59e23d4ff4064..dc66bfbbdf013 100644 --- a/coderd/jwtutils/jwt_test.go +++ b/coderd/jwtutils/jwt_test.go @@ -3,7 +3,6 @@ package jwtutils_test import ( "context" "crypto/rand" - "encoding/hex" "testing" "time" @@ -19,7 +18,6 @@ import ( "github.com/coder/coder/v2/coderd/database/dbgen" "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/jwtutils" - "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/testutil" ) @@ -355,19 +353,6 @@ func TestJWE(t *testing.T) { }) } -func generateCryptoKey(t *testing.T, seq int32, now time.Time, keySize int) codersdk.CryptoKey { - t.Helper() - - secret := generateSecret(t, keySize) - - return codersdk.CryptoKey{ - Feature: codersdk.CryptoKeyFeatureTailnetResume, - Secret: hex.EncodeToString(secret), - Sequence: seq, - StartsAt: now, - } -} - func generateSecret(t *testing.T, keySize int) []byte { t.Helper() @@ -431,22 +416,22 @@ func newKey(t *testing.T, size int) *key { } } -func (k *key) SigningKey(ctx context.Context) (id string, key interface{}, err error) { +func (k *key) SigningKey(_ context.Context) (id string, key interface{}, err error) { return k.id, k.secret, nil } -func (k *key) VerifyingKey(ctx context.Context, id string) (key interface{}, err error) { +func (k *key) VerifyingKey(_ context.Context, id string) (key interface{}, err error) { k.t.Helper() require.Equal(k.t, k.id, id) return k.secret, nil } -func (k *key) EncryptingKey(ctx context.Context) (id string, key interface{}, err error) { +func (k *key) EncryptingKey(_ context.Context) (id string, key interface{}, err error) { return k.id, k.secret, nil } -func (k *key) DecryptingKey(ctx context.Context, id string) (key interface{}, err error) { +func (k *key) DecryptingKey(_ context.Context, id string) (key interface{}, err error) { k.t.Helper() require.Equal(k.t, k.id, id) From e065356a6a0caf1bcdc968ba2eae7ab3d762102c Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Thu, 3 Oct 2024 23:02:05 +0000 Subject: [PATCH 22/25] Remove cryptokeys keycachemock from Makefile Deleting unnecessary cryptokeys keycachemock ensures clarity and reduces maintenance. --- Makefile | 10 +++------- coderd/jwtutils/jwt_test.go | 1 - 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/Makefile b/Makefile index 4bfcc73a06966..be74b27013a23 100644 --- a/Makefile +++ b/Makefile @@ -507,8 +507,7 @@ gen: \ examples/examples.gen.json \ tailnet/tailnettest/coordinatormock.go \ tailnet/tailnettest/coordinateemock.go \ - tailnet/tailnettest/multiagentmock.go \ - coderd/cryptokeys/keycachemock.go + tailnet/tailnettest/multiagentmock.go .PHONY: gen # Mark all generated files as fresh so make thinks they're up-to-date. This is @@ -538,8 +537,8 @@ gen/mark-fresh: tailnet/tailnettest/coordinatormock.go \ tailnet/tailnettest/coordinateemock.go \ tailnet/tailnettest/multiagentmock.go \ - coderd/cryptokeys/keycachemock.go - " + " + for file in $$files; do echo "$$file" if [ ! -f "$$file" ]; then @@ -630,9 +629,6 @@ examples/examples.gen.json: scripts/examplegen/main.go examples/examples.go $(sh coderd/rbac/object_gen.go: scripts/rbacgen/rbacobject.gotmpl scripts/rbacgen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go go run scripts/rbacgen/main.go rbac > coderd/rbac/object_gen.go -coderd/cryptokeys/keycachemock.go: coderd/cryptokeys/keycache.go - go generate ./coderd/cryptokeys - codersdk/rbacresources_gen.go: scripts/rbacgen/codersdk.gotmpl scripts/rbacgen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go go run scripts/rbacgen/main.go codersdk > codersdk/rbacresources_gen.go diff --git a/coderd/jwtutils/jwt_test.go b/coderd/jwtutils/jwt_test.go index dc66bfbbdf013..eb849da88e4a0 100644 --- a/coderd/jwtutils/jwt_test.go +++ b/coderd/jwtutils/jwt_test.go @@ -261,7 +261,6 @@ func TestJWS(t *testing.T) { require.NoError(t, err) require.Equal(t, claims, actual) }) - } func TestJWE(t *testing.T) { From 938bdda0594ff373df62c991db6c9fc49bd4c6fa Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Thu, 3 Oct 2024 23:08:05 +0000 Subject: [PATCH 23/25] Add feature validation to dbCache key methods Ensure that the dbCache methods for encrypting, decrypting, signing, and verifying keys validate the feature flag before proceeding with operations. This validation step prevents using keys for unintended purposes, maintaining proper alignment with their intended cryptographic feature. --- coderd/cryptokeys/dbkeycache.go | 16 ++++++++++++++++ coderd/cryptokeys/dbkeycache_test.go | 28 ++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/coderd/cryptokeys/dbkeycache.go b/coderd/cryptokeys/dbkeycache.go index 14f16f682f5af..64fa5f8ce06a4 100644 --- a/coderd/cryptokeys/dbkeycache.go +++ b/coderd/cryptokeys/dbkeycache.go @@ -77,18 +77,34 @@ func newDBCache(logger slog.Logger, db database.Store, feature database.CryptoKe } func (d *dbCache) EncryptingKey(ctx context.Context) (id string, key interface{}, err error) { + if !isEncryptionKeyFeature(d.feature) { + return "", nil, ErrInvalidFeature + } + return d.latest(ctx) } func (d *dbCache) DecryptingKey(ctx context.Context, id string) (key interface{}, err error) { + if !isEncryptionKeyFeature(d.feature) { + return nil, ErrInvalidFeature + } + return d.sequence(ctx, id) } func (d *dbCache) SigningKey(ctx context.Context) (id string, key interface{}, err error) { + if !isSigningKeyFeature(d.feature) { + return "", nil, ErrInvalidFeature + } + return d.latest(ctx) } func (d *dbCache) VerifyingKey(ctx context.Context, id string) (key interface{}, err error) { + if !isSigningKeyFeature(d.feature) { + return nil, ErrInvalidFeature + } + return d.sequence(ctx, id) } diff --git a/coderd/cryptokeys/dbkeycache_test.go b/coderd/cryptokeys/dbkeycache_test.go index 3e0a8fcc033c2..e24ef16660db1 100644 --- a/coderd/cryptokeys/dbkeycache_test.go +++ b/coderd/cryptokeys/dbkeycache_test.go @@ -154,10 +154,24 @@ func TestDBKeyCache(t *testing.T) { db, _ = dbtestutil.NewDB(t) clock = quartz.NewMock(t) logger = slogtest.Make(t, nil) + ctx = testutil.Context(t, testutil.WaitShort) ) _, err := cryptokeys.NewSigningCache(logger, db, database.CryptoKeyFeatureWorkspaceApps, cryptokeys.WithDBCacheClock(clock)) require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature) + + // Instantiate a signing cache and try to use it as an encryption cache. + sc, err := cryptokeys.NewSigningCache(logger, db, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock)) + require.NoError(t, err) + defer sc.Close() + + ec, ok := sc.(cryptokeys.EncryptionKeycache) + require.True(t, ok) + _, _, err = ec.EncryptingKey(ctx) + require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature) + + _, err = ec.DecryptingKey(ctx, "123") + require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature) }) t.Run("InvalidEncryptionFeature", func(t *testing.T) { @@ -167,10 +181,24 @@ func TestDBKeyCache(t *testing.T) { db, _ = dbtestutil.NewDB(t) clock = quartz.NewMock(t) logger = slogtest.Make(t, nil) + ctx = testutil.Context(t, testutil.WaitShort) ) _, err := cryptokeys.NewEncryptionCache(logger, db, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock)) require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature) + + // Instantiate an encryption cache and try to use it as a signing cache. + ec, err := cryptokeys.NewEncryptionCache(logger, db, database.CryptoKeyFeatureWorkspaceApps, cryptokeys.WithDBCacheClock(clock)) + require.NoError(t, err) + defer ec.Close() + + sc, ok := ec.(cryptokeys.SigningKeycache) + require.True(t, ok) + _, _, err = sc.SigningKey(ctx) + require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature) + + _, err = sc.VerifyingKey(ctx, "123") + require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature) }) } From 48b1b3bf48857cc4abb944fd06ecf86288ec91fb Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Thu, 3 Oct 2024 23:09:57 +0000 Subject: [PATCH 24/25] fmt --- coderd/jwtutils/jwt_test.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/coderd/jwtutils/jwt_test.go b/coderd/jwtutils/jwt_test.go index eb849da88e4a0..ff30f7716b310 100644 --- a/coderd/jwtutils/jwt_test.go +++ b/coderd/jwtutils/jwt_test.go @@ -196,9 +196,7 @@ func TestJWS(t *testing.T) { t.Run("WrongSignatureAlgorithm", func(t *testing.T) { t.Parallel() - var ( - ctx = testutil.Context(t, testutil.WaitShort) - ) + ctx := testutil.Context(t, testutil.WaitShort) key := newKey(t, 64) From 1dd2205a209ae7e9610d100e8bfe7902b48e070e Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Fri, 4 Oct 2024 01:52:10 +0000 Subject: [PATCH 25/25] Add initialization comment for db key cache timer --- coderd/cryptokeys/dbkeycache.go | 1 + 1 file changed, 1 insertion(+) diff --git a/coderd/cryptokeys/dbkeycache.go b/coderd/cryptokeys/dbkeycache.go index 64fa5f8ce06a4..aa0a2444b35f2 100644 --- a/coderd/cryptokeys/dbkeycache.go +++ b/coderd/cryptokeys/dbkeycache.go @@ -71,6 +71,7 @@ func newDBCache(logger slog.Logger, db database.Store, feature database.CryptoKe opt(d) } + // Initialize the timer. This will get properly initialized the first time we fetch. d.timer = d.clock.AfterFunc(never, d.clear) return d