diff --git a/cli/server.go b/cli/server.go index 5adb44c3c0a7d..2c1f8fab10c1d 100644 --- a/cli/server.go +++ b/cli/server.go @@ -10,7 +10,6 @@ import ( "crypto/tls" "crypto/x509" "database/sql" - "encoding/hex" "errors" "flag" "fmt" @@ -62,6 +61,7 @@ import ( "github.com/coder/serpent" "github.com/coder/wgtunnel/tunnelsdk" + "github.com/coder/coder/v2/coderd/cryptokeys" "github.com/coder/coder/v2/coderd/entitlements" "github.com/coder/coder/v2/coderd/notifications/reports" "github.com/coder/coder/v2/coderd/runtimeconfig" @@ -97,7 +97,6 @@ import ( "github.com/coder/coder/v2/coderd/updatecheck" "github.com/coder/coder/v2/coderd/util/slice" stringutil "github.com/coder/coder/v2/coderd/util/strings" - "github.com/coder/coder/v2/coderd/workspaceapps" "github.com/coder/coder/v2/coderd/workspaceapps/appurl" "github.com/coder/coder/v2/coderd/workspacestats" "github.com/coder/coder/v2/codersdk" @@ -741,90 +740,31 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. return xerrors.Errorf("set deployment id: %w", err) } } - - // Read the app signing key from the DB. We store it hex encoded - // since the config table uses strings for the value and we - // don't want to deal with automatic encoding issues. - appSecurityKeyStr, err := tx.GetAppSecurityKey(ctx) - if err != nil && !xerrors.Is(err, sql.ErrNoRows) { - return xerrors.Errorf("get app signing key: %w", err) - } - // If the string in the DB is an invalid hex string or the - // length is not equal to the current key length, generate a new - // one. - // - // If the key is regenerated, old signed tokens and encrypted - // strings will become invalid. New signed app tokens will be - // generated automatically on failure. Any workspace app token - // smuggling operations in progress may fail, although with a - // helpful error. - if decoded, err := hex.DecodeString(appSecurityKeyStr); err != nil || len(decoded) != len(workspaceapps.SecurityKey{}) { - b := make([]byte, len(workspaceapps.SecurityKey{})) - _, err := rand.Read(b) - if err != nil { - return xerrors.Errorf("generate fresh app signing key: %w", err) - } - - appSecurityKeyStr = hex.EncodeToString(b) - err = tx.UpsertAppSecurityKey(ctx, appSecurityKeyStr) - if err != nil { - return xerrors.Errorf("insert freshly generated app signing key to database: %w", err) - } - } - - appSecurityKey, err := workspaceapps.KeyFromString(appSecurityKeyStr) - if err != nil { - return xerrors.Errorf("decode app signing key from database: %w", err) - } - - options.AppSecurityKey = appSecurityKey - - // Read the oauth signing key from the database. Like the app security, generate a new one - // if it is invalid for any reason. - oauthSigningKeyStr, err := tx.GetOAuthSigningKey(ctx) - if err != nil && !xerrors.Is(err, sql.ErrNoRows) { - return xerrors.Errorf("get app oauth signing key: %w", err) - } - if decoded, err := hex.DecodeString(oauthSigningKeyStr); err != nil || len(decoded) != len(options.OAuthSigningKey) { - b := make([]byte, len(options.OAuthSigningKey)) - _, err := rand.Read(b) - if err != nil { - return xerrors.Errorf("generate fresh oauth signing key: %w", err) - } - - oauthSigningKeyStr = hex.EncodeToString(b) - err = tx.UpsertOAuthSigningKey(ctx, oauthSigningKeyStr) - if err != nil { - return xerrors.Errorf("insert freshly generated oauth signing key to database: %w", err) - } - } - - oauthKeyBytes, err := hex.DecodeString(oauthSigningKeyStr) - if err != nil { - return xerrors.Errorf("decode oauth signing key from database: %w", err) - } - if len(oauthKeyBytes) != len(options.OAuthSigningKey) { - return xerrors.Errorf("oauth signing key in database is not the correct length, expect %d got %d", len(options.OAuthSigningKey), len(oauthKeyBytes)) - } - copy(options.OAuthSigningKey[:], oauthKeyBytes) - if options.OAuthSigningKey == [32]byte{} { - return xerrors.Errorf("oauth signing key in database is empty") - } - - // Read the coordinator resume token signing key from the - // database. - resumeTokenKey, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, tx) - if err != nil { - return xerrors.Errorf("get coordinator resume token key from database: %w", err) - } - options.CoordinatorResumeTokenProvider = tailnet.NewResumeTokenKeyProvider(resumeTokenKey, quartz.NewReal(), tailnet.DefaultResumeTokenExpiry) - return nil }, nil) if err != nil { - return err + return xerrors.Errorf("set deployment id: %w", err) + } + + fetcher := &cryptokeys.DBFetcher{ + DB: options.Database, + } + + resumeKeycache, err := cryptokeys.NewSigningCache(ctx, + logger, + fetcher, + codersdk.CryptoKeyFeatureTailnetResume, + ) + if err != nil { + logger.Critical(ctx, "failed to properly instantiate tailnet resume signing cache", slog.Error(err)) } + options.CoordinatorResumeTokenProvider = tailnet.NewResumeTokenKeyProvider( + resumeKeycache, + quartz.NewReal(), + tailnet.DefaultResumeTokenExpiry, + ) + options.RuntimeConfig = runtimeconfig.NewManager() // This should be output before the logs start streaming. diff --git a/coderd/apidoc/docs.go b/coderd/apidoc/docs.go index 76084b1ff54dd..09f070046066a 100644 --- a/coderd/apidoc/docs.go +++ b/coderd/apidoc/docs.go @@ -7646,6 +7646,15 @@ const docTemplate = `{ ], "summary": "Get workspace proxy crypto keys", "operationId": "get-workspace-proxy-crypto-keys", + "parameters": [ + { + "type": "string", + "description": "Feature key", + "name": "feature", + "in": "query", + "required": true + } + ], "responses": { "200": { "description": "OK", @@ -10011,12 +10020,14 @@ const docTemplate = `{ "codersdk.CryptoKeyFeature": { "type": "string", "enum": [ - "workspace_apps", + "workspace_apps_api_key", + "workspace_apps_token", "oidc_convert", "tailnet_resume" ], "x-enum-varnames": [ - "CryptoKeyFeatureWorkspaceApp", + "CryptoKeyFeatureWorkspaceAppsAPIKey", + "CryptoKeyFeatureWorkspaceAppsToken", "CryptoKeyFeatureOIDCConvert", "CryptoKeyFeatureTailnetResume" ] @@ -16244,9 +16255,6 @@ const docTemplate = `{ "wsproxysdk.RegisterWorkspaceProxyResponse": { "type": "object", "properties": { - "app_security_key": { - "type": "string" - }, "derp_force_websockets": { "type": "boolean" }, diff --git a/coderd/apidoc/swagger.json b/coderd/apidoc/swagger.json index beff69ca22373..42b34d576509a 100644 --- a/coderd/apidoc/swagger.json +++ b/coderd/apidoc/swagger.json @@ -6758,6 +6758,15 @@ "tags": ["Enterprise"], "summary": "Get workspace proxy crypto keys", "operationId": "get-workspace-proxy-crypto-keys", + "parameters": [ + { + "type": "string", + "description": "Feature key", + "name": "feature", + "in": "query", + "required": true + } + ], "responses": { "200": { "description": "OK", @@ -8914,9 +8923,15 @@ }, "codersdk.CryptoKeyFeature": { "type": "string", - "enum": ["workspace_apps", "oidc_convert", "tailnet_resume"], + "enum": [ + "workspace_apps_api_key", + "workspace_apps_token", + "oidc_convert", + "tailnet_resume" + ], "x-enum-varnames": [ - "CryptoKeyFeatureWorkspaceApp", + "CryptoKeyFeatureWorkspaceAppsAPIKey", + "CryptoKeyFeatureWorkspaceAppsToken", "CryptoKeyFeatureOIDCConvert", "CryptoKeyFeatureTailnetResume" ] @@ -14853,9 +14868,6 @@ "wsproxysdk.RegisterWorkspaceProxyResponse": { "type": "object", "properties": { - "app_security_key": { - "type": "string" - }, "derp_force_websockets": { "type": "boolean" }, diff --git a/coderd/coderd.go b/coderd/coderd.go index cb0884808ef27..3011c2d58d39c 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -40,6 +40,7 @@ import ( "github.com/coder/quartz" "github.com/coder/serpent" + "github.com/coder/coder/v2/coderd/cryptokeys" "github.com/coder/coder/v2/coderd/entitlements" "github.com/coder/coder/v2/coderd/idpsync" "github.com/coder/coder/v2/coderd/runtimeconfig" @@ -185,9 +186,6 @@ type Options struct { TemplateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore] UserQuietHoursScheduleStore *atomic.Pointer[schedule.UserQuietHoursScheduleStore] AccessControlStore *atomic.Pointer[dbauthz.AccessControlStore] - // AppSecurityKey is the crypto key used to sign and encrypt tokens related to - // workspace applications. It consists of both a signing and encryption key. - AppSecurityKey workspaceapps.SecurityKey // CoordinatorResumeTokenProvider is used to provide and validate resume // tokens issued by and passed to the coordinator DRPC API. CoordinatorResumeTokenProvider tailnet.ResumeTokenProvider @@ -251,6 +249,12 @@ type Options struct { // OneTimePasscodeValidityPeriod specifies how long a one time passcode should be valid for. OneTimePasscodeValidityPeriod time.Duration + + // Keycaches + AppSigningKeyCache cryptokeys.SigningKeycache + AppEncryptionKeyCache cryptokeys.EncryptionKeycache + OIDCConvertKeyCache cryptokeys.SigningKeycache + Clock quartz.Clock } // @title Coder API @@ -352,6 +356,9 @@ func New(options *Options) *API { if options.PrometheusRegistry == nil { options.PrometheusRegistry = prometheus.NewRegistry() } + if options.Clock == nil { + options.Clock = quartz.NewReal() + } if options.DERPServer == nil && options.DeploymentValues.DERP.Server.Enable { options.DERPServer = derp.NewServer(key.NewNode(), tailnet.Logger(options.Logger.Named("derp"))) } @@ -444,6 +451,49 @@ func New(options *Options) *API { if err != nil { panic(xerrors.Errorf("get deployment ID: %w", err)) } + + fetcher := &cryptokeys.DBFetcher{ + DB: options.Database, + } + + if options.OIDCConvertKeyCache == nil { + options.OIDCConvertKeyCache, err = cryptokeys.NewSigningCache(ctx, + options.Logger.Named("oidc_convert_keycache"), + fetcher, + codersdk.CryptoKeyFeatureOIDCConvert, + ) + if err != nil { + options.Logger.Critical(ctx, "failed to properly instantiate oidc convert signing cache", slog.Error(err)) + } + } + + if options.AppSigningKeyCache == nil { + options.AppSigningKeyCache, err = cryptokeys.NewSigningCache(ctx, + options.Logger.Named("app_signing_keycache"), + fetcher, + codersdk.CryptoKeyFeatureWorkspaceAppsToken, + ) + if err != nil { + options.Logger.Critical(ctx, "failed to properly instantiate app signing key cache", slog.Error(err)) + } + } + + if options.AppEncryptionKeyCache == nil { + options.AppEncryptionKeyCache, err = cryptokeys.NewEncryptionCache(ctx, + options.Logger, + fetcher, + codersdk.CryptoKeyFeatureWorkspaceAppsAPIKey, + ) + if err != nil { + options.Logger.Critical(ctx, "failed to properly instantiate app encryption key cache", slog.Error(err)) + } + } + + // Start a background process that rotates keys. We intentionally start this after the caches + // are created to force initial requests for a key to populate the caches. This helps catch + // bugs that may only occur when a key isn't precached in tests and the latency cost is minimal. + cryptokeys.StartRotator(ctx, options.Logger, options.Database) + api := &API{ ctx: ctx, cancel: cancel, @@ -464,7 +514,7 @@ func New(options *Options) *API { options.DeploymentValues, oauthConfigs, options.AgentInactiveDisconnectTimeout, - options.AppSecurityKey, + options.AppSigningKeyCache, ), metricsCache: metricsCache, Auditor: atomic.Pointer[audit.Auditor]{}, @@ -606,7 +656,7 @@ func New(options *Options) *API { ResumeTokenProvider: api.Options.CoordinatorResumeTokenProvider, }) if err != nil { - api.Logger.Fatal(api.ctx, "failed to initialize tailnet client service", slog.Error(err)) + api.Logger.Fatal(context.Background(), "failed to initialize tailnet client service", slog.Error(err)) } api.statsReporter = workspacestats.NewReporter(workspacestats.ReporterOptions{ @@ -628,9 +678,6 @@ func New(options *Options) *API { options.WorkspaceAppsStatsCollectorOptions.Reporter = api.statsReporter } - if options.AppSecurityKey.IsZero() { - api.Logger.Fatal(api.ctx, "app security key cannot be zero") - } api.workspaceAppServer = &workspaceapps.Server{ Logger: workspaceAppsLogger, @@ -642,11 +689,11 @@ func New(options *Options) *API { SignedTokenProvider: api.WorkspaceAppsProvider, AgentProvider: api.agentProvider, - AppSecurityKey: options.AppSecurityKey, StatsCollector: workspaceapps.NewStatsCollector(options.WorkspaceAppsStatsCollectorOptions), - DisablePathApps: options.DeploymentValues.DisablePathApps.Value(), - SecureAuthCookie: options.DeploymentValues.SecureAuthCookie.Value(), + DisablePathApps: options.DeploymentValues.DisablePathApps.Value(), + SecureAuthCookie: options.DeploymentValues.SecureAuthCookie.Value(), + APIKeyEncryptionKeycache: options.AppEncryptionKeyCache, } apiKeyMiddleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ @@ -1434,6 +1481,9 @@ func (api *API) Close() error { _ = api.agentProvider.Close() _ = api.statsReporter.Close() _ = api.NetworkTelemetryBatcher.Close() + _ = api.OIDCConvertKeyCache.Close() + _ = api.AppSigningKeyCache.Close() + _ = api.AppEncryptionKeyCache.Close() return nil } diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index 05c31f35bd20a..d94a6fbe93c4e 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -55,6 +55,7 @@ import ( "github.com/coder/coder/v2/coderd/audit" "github.com/coder/coder/v2/coderd/autobuild" "github.com/coder/coder/v2/coderd/awsidentity" + "github.com/coder/coder/v2/coderd/cryptokeys" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/db2sdk" "github.com/coder/coder/v2/coderd/database/dbauthz" @@ -88,12 +89,9 @@ import ( sdkproto "github.com/coder/coder/v2/provisionersdk/proto" "github.com/coder/coder/v2/tailnet" "github.com/coder/coder/v2/testutil" + "github.com/coder/quartz" ) -// AppSecurityKey is a 96-byte key used to sign JWTs and encrypt JWEs for -// workspace app tokens in tests. -var AppSecurityKey = must(workspaceapps.KeyFromString("6465616e207761732068657265206465616e207761732068657265206465616e207761732068657265206465616e207761732068657265206465616e207761732068657265206465616e207761732068657265206465616e2077617320686572")) - type Options struct { // AccessURL denotes a custom access URL. By default we use the httptest // server's URL. Setting this may result in unexpected behavior (especially @@ -161,8 +159,10 @@ type Options struct { DatabaseRolluper *dbrollup.Rolluper WorkspaceUsageTrackerFlush chan int WorkspaceUsageTrackerTick chan time.Time - - NotificationsEnqueuer notifications.Enqueuer + NotificationsEnqueuer notifications.Enqueuer + APIKeyEncryptionCache cryptokeys.EncryptionKeycache + OIDCConvertKeyCache cryptokeys.SigningKeycache + Clock quartz.Clock } // New constructs a codersdk client connected to an in-memory API instance. @@ -525,7 +525,6 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can DeploymentOptions: codersdk.DeploymentOptionsWithoutSecrets(options.DeploymentValues.Options()), UpdateCheckOptions: options.UpdateCheckOptions, SwaggerEndpoint: options.SwaggerEndpoint, - AppSecurityKey: AppSecurityKey, SSHConfig: options.ConfigSSH, HealthcheckFunc: options.HealthcheckFunc, HealthcheckTimeout: options.HealthcheckTimeout, @@ -538,6 +537,9 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can WorkspaceUsageTracker: wuTracker, NotificationsEnqueuer: options.NotificationsEnqueuer, OneTimePasscodeValidityPeriod: options.OneTimePasscodeValidityPeriod, + Clock: options.Clock, + AppEncryptionKeyCache: options.APIKeyEncryptionCache, + OIDCConvertKeyCache: options.OIDCConvertKeyCache, } } diff --git a/coderd/cryptokeys/cache.go b/coderd/cryptokeys/cache.go index 74fb025d416fd..7777d5f75b942 100644 --- a/coderd/cryptokeys/cache.go +++ b/coderd/cryptokeys/cache.go @@ -3,6 +3,7 @@ package cryptokeys import ( "context" "encoding/hex" + "fmt" "io" "strconv" "sync" @@ -12,7 +13,7 @@ import ( "cdr.dev/slog" "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/db2sdk" + "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/codersdk" "github.com/coder/quartz" ) @@ -25,7 +26,7 @@ var ( ) type Fetcher interface { - Fetch(ctx context.Context) ([]codersdk.CryptoKey, error) + Fetch(ctx context.Context, feature codersdk.CryptoKeyFeature) ([]codersdk.CryptoKey, error) } type EncryptionKeycache interface { @@ -62,27 +63,26 @@ const ( ) type DBFetcher struct { - DB database.Store - Feature database.CryptoKeyFeature + DB database.Store } -func (d *DBFetcher) Fetch(ctx context.Context) ([]codersdk.CryptoKey, error) { - keys, err := d.DB.GetCryptoKeysByFeature(ctx, d.Feature) +func (d *DBFetcher) Fetch(ctx context.Context, feature codersdk.CryptoKeyFeature) ([]codersdk.CryptoKey, error) { + keys, err := d.DB.GetCryptoKeysByFeature(ctx, database.CryptoKeyFeature(feature)) if err != nil { return nil, xerrors.Errorf("get crypto keys by feature: %w", err) } - return db2sdk.CryptoKeys(keys), nil + return toSDKKeys(keys), nil } // cache implements the caching functionality for both signing and encryption keys. type cache struct { - clock quartz.Clock - refreshCtx context.Context - refreshCancel context.CancelFunc - fetcher Fetcher - logger slog.Logger - feature codersdk.CryptoKeyFeature + ctx context.Context + cancel context.CancelFunc + clock quartz.Clock + fetcher Fetcher + logger slog.Logger + feature codersdk.CryptoKeyFeature mu sync.Mutex keys map[int32]codersdk.CryptoKey @@ -109,7 +109,8 @@ func NewSigningCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, if !isSigningKeyFeature(feature) { return nil, xerrors.Errorf("invalid feature: %s", feature) } - return newCache(ctx, logger, fetcher, feature, opts...) + logger = logger.Named(fmt.Sprintf("%s_signing_keycache", feature)) + return newCache(ctx, logger, fetcher, feature, opts...), nil } func NewEncryptionCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, @@ -118,10 +119,11 @@ func NewEncryptionCache(ctx context.Context, logger slog.Logger, fetcher Fetcher if !isEncryptionKeyFeature(feature) { return nil, xerrors.Errorf("invalid feature: %s", feature) } - return newCache(ctx, logger, fetcher, feature, opts...) + logger = logger.Named(fmt.Sprintf("%s_encryption_keycache", feature)) + return newCache(ctx, logger, fetcher, feature, opts...), nil } -func newCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, feature codersdk.CryptoKeyFeature, opts ...func(*cache)) (*cache, error) { +func newCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, feature codersdk.CryptoKeyFeature, opts ...func(*cache)) *cache { cache := &cache{ clock: quartz.NewReal(), logger: logger, @@ -134,16 +136,16 @@ func newCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, feature } cache.cond = sync.NewCond(&cache.mu) - cache.refreshCtx, cache.refreshCancel = context.WithCancel(ctx) + //nolint:gocritic // We need to be able to read the keys in order to cache them. + cache.ctx, cache.cancel = context.WithCancel(dbauthz.AsKeyReader(ctx)) cache.refresher = cache.clock.AfterFunc(refreshInterval, cache.refresh) - keys, err := cache.cryptoKeys(ctx) + keys, err := cache.cryptoKeys(cache.ctx) if err != nil { - cache.refreshCancel() - return nil, xerrors.Errorf("initial fetch: %w", err) + cache.logger.Critical(cache.ctx, "failed initial fetch", slog.Error(err)) } cache.keys = keys - return cache, nil + return cache } func (c *cache) EncryptingKey(ctx context.Context) (string, interface{}, error) { @@ -151,6 +153,8 @@ func (c *cache) EncryptingKey(ctx context.Context) (string, interface{}, error) return "", nil, ErrInvalidFeature } + //nolint:gocritic // cache can only read crypto keys. + ctx = dbauthz.AsKeyReader(ctx) return c.cryptoKey(ctx, latestSequence) } @@ -164,6 +168,8 @@ func (c *cache) DecryptingKey(ctx context.Context, id string) (interface{}, erro return nil, xerrors.Errorf("parse id: %w", err) } + //nolint:gocritic // cache can only read crypto keys. + ctx = dbauthz.AsKeyReader(ctx) _, secret, err := c.cryptoKey(ctx, int32(seq)) if err != nil { return nil, xerrors.Errorf("crypto key: %w", err) @@ -176,6 +182,8 @@ func (c *cache) SigningKey(ctx context.Context) (string, interface{}, error) { return "", nil, ErrInvalidFeature } + //nolint:gocritic // cache can only read crypto keys. + ctx = dbauthz.AsKeyReader(ctx) return c.cryptoKey(ctx, latestSequence) } @@ -188,7 +196,8 @@ func (c *cache) VerifyingKey(ctx context.Context, id string) (interface{}, error if err != nil { return nil, xerrors.Errorf("parse id: %w", err) } - + //nolint:gocritic // cache can only read crypto keys. + ctx = dbauthz.AsKeyReader(ctx) _, secret, err := c.cryptoKey(ctx, int32(seq)) if err != nil { return nil, xerrors.Errorf("crypto key: %w", err) @@ -198,12 +207,12 @@ func (c *cache) VerifyingKey(ctx context.Context, id string) (interface{}, error } func isEncryptionKeyFeature(feature codersdk.CryptoKeyFeature) bool { - return feature == codersdk.CryptoKeyFeatureWorkspaceApp + return feature == codersdk.CryptoKeyFeatureWorkspaceAppsAPIKey } func isSigningKeyFeature(feature codersdk.CryptoKeyFeature) bool { switch feature { - case codersdk.CryptoKeyFeatureTailnetResume, codersdk.CryptoKeyFeatureOIDCConvert: + case codersdk.CryptoKeyFeatureTailnetResume, codersdk.CryptoKeyFeatureOIDCConvert, codersdk.CryptoKeyFeatureWorkspaceAppsToken: return true default: return false @@ -292,14 +301,15 @@ func checkKey(key codersdk.CryptoKey, sequence int32, now time.Time) (string, [] func (c *cache) refresh() { now := c.clock.Now("CryptoKeyCache", "refresh") c.mu.Lock() - defer c.mu.Unlock() if c.closed { + c.mu.Unlock() return } // If something's already fetching, we don't need to do anything. if c.fetching { + c.mu.Unlock() return } @@ -307,20 +317,21 @@ func (c *cache) refresh() { // is ongoing but prior to the timer getting reset. In this case we want to // avoid double fetching. if now.Sub(c.lastFetch) < refreshInterval { + c.mu.Unlock() return } c.fetching = true c.mu.Unlock() - keys, err := c.cryptoKeys(c.refreshCtx) + keys, err := c.cryptoKeys(c.ctx) if err != nil { - c.logger.Error(c.refreshCtx, "fetch crypto keys", slog.Error(err)) + c.logger.Error(c.ctx, "fetch crypto keys", slog.Error(err)) return } - // We don't defer an unlock here due to the deferred unlock at the top of the function. c.mu.Lock() + defer c.mu.Unlock() c.lastFetch = c.clock.Now() c.refresher.Reset(refreshInterval) @@ -332,9 +343,9 @@ func (c *cache) refresh() { // cryptoKeys queries the control plane for the crypto keys. // Outside of initialization, this should only be called by fetch. func (c *cache) cryptoKeys(ctx context.Context) (map[int32]codersdk.CryptoKey, error) { - keys, err := c.fetcher.Fetch(ctx) + keys, err := c.fetcher.Fetch(ctx, c.feature) if err != nil { - return nil, xerrors.Errorf("crypto keys: %w", err) + return nil, xerrors.Errorf("fetch: %w", err) } cache := toKeyMap(keys, c.clock.Now()) return cache, nil @@ -361,9 +372,28 @@ func (c *cache) Close() error { } c.closed = true - c.refreshCancel() + c.cancel() c.refresher.Stop() c.cond.Broadcast() return nil } + +// We have to do this to avoid a circular dependency on db2sdk (cryptokeys -> db2sdk -> tailnet -> cryptokeys) +func toSDKKeys(keys []database.CryptoKey) []codersdk.CryptoKey { + into := make([]codersdk.CryptoKey, 0, len(keys)) + for _, key := range keys { + into = append(into, toSDK(key)) + } + return into +} + +func toSDK(key database.CryptoKey) codersdk.CryptoKey { + return codersdk.CryptoKey{ + Feature: codersdk.CryptoKeyFeature(key.Feature), + Sequence: key.Sequence, + StartsAt: key.StartsAt, + DeletesAt: key.DeletesAt.Time, + Secret: key.Secret.String, + } +} diff --git a/coderd/cryptokeys/cache_test.go b/coderd/cryptokeys/cache_test.go index 92fc4527ae7b3..cda87315605a4 100644 --- a/coderd/cryptokeys/cache_test.go +++ b/coderd/cryptokeys/cache_test.go @@ -488,7 +488,7 @@ type fakeFetcher struct { called int } -func (f *fakeFetcher) Fetch(_ context.Context) ([]codersdk.CryptoKey, error) { +func (f *fakeFetcher) Fetch(_ context.Context, _ codersdk.CryptoKeyFeature) ([]codersdk.CryptoKey, error) { f.called++ return f.keys, nil } diff --git a/coderd/cryptokeys/rotate.go b/coderd/cryptokeys/rotate.go index 14a623e2156db..5d7d7b33b9dec 100644 --- a/coderd/cryptokeys/rotate.go +++ b/coderd/cryptokeys/rotate.go @@ -11,6 +11,7 @@ import ( "cdr.dev/slog" "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/quartz" ) @@ -53,10 +54,12 @@ func WithKeyDuration(keyDuration time.Duration) RotatorOption { // StartRotator starts a background process that rotates keys in the database. // It ensures there's at least one valid key per feature prior to returning. // Canceling the provided context will stop the background process. -func StartRotator(ctx context.Context, logger slog.Logger, db database.Store, opts ...RotatorOption) error { +func StartRotator(ctx context.Context, logger slog.Logger, db database.Store, opts ...RotatorOption) { + //nolint:gocritic // KeyRotator can only rotate crypto keys. + ctx = dbauthz.AsKeyRotator(ctx) kr := &rotator{ db: db, - logger: logger, + logger: logger.Named("keyrotator"), clock: quartz.NewReal(), keyDuration: DefaultKeyDuration, features: database.AllCryptoKeyFeatureValues(), @@ -68,12 +71,10 @@ func StartRotator(ctx context.Context, logger slog.Logger, db database.Store, op err := kr.rotateKeys(ctx) if err != nil { - return xerrors.Errorf("rotate keys: %w", err) + kr.logger.Critical(ctx, "failed to rotate keys", slog.Error(err)) } go kr.start(ctx) - - return nil } // start begins the process of rotating keys. @@ -226,9 +227,11 @@ func (k *rotator) rotateKey(ctx context.Context, tx database.Store, key database func generateNewSecret(feature database.CryptoKeyFeature) (string, error) { switch feature { - case database.CryptoKeyFeatureWorkspaceApps: + case database.CryptoKeyFeatureWorkspaceAppsAPIKey: return generateKey(32) - case database.CryptoKeyFeatureOidcConvert: + case database.CryptoKeyFeatureWorkspaceAppsToken: + return generateKey(64) + case database.CryptoKeyFeatureOIDCConvert: return generateKey(64) case database.CryptoKeyFeatureTailnetResume: return generateKey(64) @@ -247,9 +250,11 @@ func generateKey(length int) (string, error) { func tokenDuration(feature database.CryptoKeyFeature) time.Duration { switch feature { - case database.CryptoKeyFeatureWorkspaceApps: + case database.CryptoKeyFeatureWorkspaceAppsAPIKey: + return WorkspaceAppsTokenDuration + case database.CryptoKeyFeatureWorkspaceAppsToken: return WorkspaceAppsTokenDuration - case database.CryptoKeyFeatureOidcConvert: + case database.CryptoKeyFeatureOIDCConvert: return OIDCConvertTokenDuration case database.CryptoKeyFeatureTailnetResume: return TailnetResumeTokenDuration diff --git a/coderd/cryptokeys/rotate_internal_test.go b/coderd/cryptokeys/rotate_internal_test.go index 43754c1d8750f..e427a3c6216ac 100644 --- a/coderd/cryptokeys/rotate_internal_test.go +++ b/coderd/cryptokeys/rotate_internal_test.go @@ -38,7 +38,7 @@ func Test_rotateKeys(t *testing.T) { clock: clock, logger: logger, features: []database.CryptoKeyFeature{ - database.CryptoKeyFeatureWorkspaceApps, + database.CryptoKeyFeatureWorkspaceAppsAPIKey, }, } @@ -46,7 +46,7 @@ func Test_rotateKeys(t *testing.T) { // Seed the database with an existing key. oldKey := dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, StartsAt: now, Sequence: 15, }) @@ -69,11 +69,11 @@ func Test_rotateKeys(t *testing.T) { // The new key should be created and have a starts_at of the old key's expires_at. newKey, err := db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, Sequence: oldKey.Sequence + 1, }) require.NoError(t, err) - requireKey(t, newKey, database.CryptoKeyFeatureWorkspaceApps, oldKey.ExpiresAt(keyDuration), nullTime, oldKey.Sequence+1) + requireKey(t, newKey, database.CryptoKeyFeatureWorkspaceAppsAPIKey, oldKey.ExpiresAt(keyDuration), nullTime, oldKey.Sequence+1) // Advance the clock just before the keys delete time. clock.Advance(oldKey.DeletesAt.Time.UTC().Sub(now) - time.Second) @@ -123,7 +123,7 @@ func Test_rotateKeys(t *testing.T) { clock: clock, logger: logger, features: []database.CryptoKeyFeature{ - database.CryptoKeyFeatureWorkspaceApps, + database.CryptoKeyFeatureWorkspaceAppsAPIKey, }, } @@ -131,7 +131,7 @@ func Test_rotateKeys(t *testing.T) { // Seed the database with an existing key existingKey := dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, StartsAt: now, Sequence: 123, }) @@ -179,7 +179,7 @@ func Test_rotateKeys(t *testing.T) { clock: clock, logger: logger, features: []database.CryptoKeyFeature{ - database.CryptoKeyFeatureWorkspaceApps, + database.CryptoKeyFeatureWorkspaceAppsAPIKey, }, } @@ -187,7 +187,7 @@ func Test_rotateKeys(t *testing.T) { // Seed the database with an existing key deletingKey := dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, StartsAt: now.Add(-keyDuration), Sequence: 789, DeletesAt: sql.NullTime{ @@ -232,7 +232,7 @@ func Test_rotateKeys(t *testing.T) { clock: clock, logger: logger, features: []database.CryptoKeyFeature{ - database.CryptoKeyFeatureWorkspaceApps, + database.CryptoKeyFeatureWorkspaceAppsAPIKey, }, } @@ -240,7 +240,7 @@ func Test_rotateKeys(t *testing.T) { // Seed the database with an existing key deletingKey := dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, StartsAt: now, Sequence: 456, DeletesAt: sql.NullTime{ @@ -281,7 +281,7 @@ func Test_rotateKeys(t *testing.T) { clock: clock, logger: logger, features: []database.CryptoKeyFeature{ - database.CryptoKeyFeatureWorkspaceApps, + database.CryptoKeyFeatureWorkspaceAppsAPIKey, }, } @@ -291,7 +291,7 @@ func Test_rotateKeys(t *testing.T) { keys, err := db.GetCryptoKeys(ctx) require.NoError(t, err) require.Len(t, keys, 1) - requireKey(t, keys[0], database.CryptoKeyFeatureWorkspaceApps, clock.Now().UTC(), nullTime, 1) + requireKey(t, keys[0], database.CryptoKeyFeatureWorkspaceAppsAPIKey, clock.Now().UTC(), nullTime, 1) }) // Assert we insert a new key when the only key was manually deleted. @@ -312,14 +312,14 @@ func Test_rotateKeys(t *testing.T) { clock: clock, logger: logger, features: []database.CryptoKeyFeature{ - database.CryptoKeyFeatureWorkspaceApps, + database.CryptoKeyFeatureWorkspaceAppsAPIKey, }, } now := dbnow(clock) deletedkey := dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, StartsAt: now, Sequence: 19, DeletesAt: sql.NullTime{ @@ -338,7 +338,7 @@ func Test_rotateKeys(t *testing.T) { keys, err := db.GetCryptoKeys(ctx) require.NoError(t, err) require.Len(t, keys, 1) - requireKey(t, keys[0], database.CryptoKeyFeatureWorkspaceApps, now, nullTime, deletedkey.Sequence+1) + requireKey(t, keys[0], database.CryptoKeyFeatureWorkspaceAppsAPIKey, now, nullTime, deletedkey.Sequence+1) }) // This tests ensures that rotation works with multiple @@ -365,9 +365,11 @@ func Test_rotateKeys(t *testing.T) { now := dbnow(clock) - // We'll test a scenario where one feature has no valid keys. - // Another has a key that should be rotate. And one that - // has a valid key that shouldn't trigger an action. + // We'll test a scenario where: + // - One feature has no valid keys. + // - One has a key that should be rotated. + // - One has a valid key that shouldn't trigger an action. + // - One has no keys at all. _ = dbgen.CryptoKey(t, db, database.CryptoKey{ Feature: database.CryptoKeyFeatureTailnetResume, StartsAt: now.Add(-keyDuration), @@ -377,6 +379,7 @@ func Test_rotateKeys(t *testing.T) { Valid: false, }, }) + // Generate another deleted key to ensure we insert after the latest sequence. deletedKey := dbgen.CryptoKey(t, db, database.CryptoKey{ Feature: database.CryptoKeyFeatureTailnetResume, StartsAt: now.Add(-keyDuration), @@ -389,14 +392,14 @@ func Test_rotateKeys(t *testing.T) { // Insert a key that should be rotated. rotatedKey := dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, StartsAt: now.Add(-keyDuration + time.Hour), Sequence: 42, }) // Insert a key that should not trigger an action. validKey := dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureOidcConvert, + Feature: database.CryptoKeyFeatureOIDCConvert, StartsAt: now, Sequence: 17, }) @@ -406,26 +409,28 @@ func Test_rotateKeys(t *testing.T) { keys, err := db.GetCryptoKeys(ctx) require.NoError(t, err) - require.Len(t, keys, 4) + require.Len(t, keys, 5) kbf, err := keysByFeature(keys, database.AllCryptoKeyFeatureValues()) require.NoError(t, err) // No actions on OIDC convert. - require.Len(t, kbf[database.CryptoKeyFeatureOidcConvert], 1) + require.Len(t, kbf[database.CryptoKeyFeatureOIDCConvert], 1) // Workspace apps should have been rotated. - require.Len(t, kbf[database.CryptoKeyFeatureWorkspaceApps], 2) + require.Len(t, kbf[database.CryptoKeyFeatureWorkspaceAppsAPIKey], 2) // No existing key for tailnet resume should've // caused a key to be inserted. require.Len(t, kbf[database.CryptoKeyFeatureTailnetResume], 1) + require.Len(t, kbf[database.CryptoKeyFeatureWorkspaceAppsToken], 1) - oidcKey := kbf[database.CryptoKeyFeatureOidcConvert][0] + oidcKey := kbf[database.CryptoKeyFeatureOIDCConvert][0] tailnetKey := kbf[database.CryptoKeyFeatureTailnetResume][0] - requireKey(t, oidcKey, database.CryptoKeyFeatureOidcConvert, now, nullTime, validKey.Sequence) + appTokenKey := kbf[database.CryptoKeyFeatureWorkspaceAppsToken][0] + requireKey(t, oidcKey, database.CryptoKeyFeatureOIDCConvert, now, nullTime, validKey.Sequence) requireKey(t, tailnetKey, database.CryptoKeyFeatureTailnetResume, now, nullTime, deletedKey.Sequence+1) - - newKey := kbf[database.CryptoKeyFeatureWorkspaceApps][0] - oldKey := kbf[database.CryptoKeyFeatureWorkspaceApps][1] + requireKey(t, appTokenKey, database.CryptoKeyFeatureWorkspaceAppsToken, now, nullTime, 1) + newKey := kbf[database.CryptoKeyFeatureWorkspaceAppsAPIKey][0] + oldKey := kbf[database.CryptoKeyFeatureWorkspaceAppsAPIKey][1] if newKey.Sequence == rotatedKey.Sequence { oldKey, newKey = newKey, oldKey } @@ -433,8 +438,8 @@ func Test_rotateKeys(t *testing.T) { Time: rotatedKey.ExpiresAt(keyDuration).Add(WorkspaceAppsTokenDuration + time.Hour), Valid: true, } - requireKey(t, oldKey, database.CryptoKeyFeatureWorkspaceApps, rotatedKey.StartsAt.UTC(), deletesAt, rotatedKey.Sequence) - requireKey(t, newKey, database.CryptoKeyFeatureWorkspaceApps, rotatedKey.ExpiresAt(keyDuration), nullTime, rotatedKey.Sequence+1) + requireKey(t, oldKey, database.CryptoKeyFeatureWorkspaceAppsAPIKey, rotatedKey.StartsAt.UTC(), deletesAt, rotatedKey.Sequence) + requireKey(t, newKey, database.CryptoKeyFeatureWorkspaceAppsAPIKey, rotatedKey.ExpiresAt(keyDuration), nullTime, rotatedKey.Sequence+1) }) t.Run("UnknownFeature", func(t *testing.T) { @@ -478,11 +483,11 @@ func Test_rotateKeys(t *testing.T) { keyDuration: keyDuration, clock: clock, logger: logger, - features: []database.CryptoKeyFeature{database.CryptoKeyFeatureWorkspaceApps}, + features: []database.CryptoKeyFeature{database.CryptoKeyFeatureWorkspaceAppsAPIKey}, } expiringKey := dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, StartsAt: now.Add(-keyDuration), Sequence: 345, }) @@ -522,19 +527,19 @@ func Test_rotateKeys(t *testing.T) { keyDuration: keyDuration, clock: clock, logger: logger, - features: []database.CryptoKeyFeature{database.CryptoKeyFeatureWorkspaceApps}, + features: []database.CryptoKeyFeature{database.CryptoKeyFeatureWorkspaceAppsAPIKey}, } now := dbnow(clock) expiredKey := dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, StartsAt: now.Add(-keyDuration - 2*time.Hour), Sequence: 19, }) deletedKey := dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, StartsAt: now, Sequence: 20, Secret: sql.NullString{ @@ -587,9 +592,11 @@ func requireKey(t *testing.T, key database.CryptoKey, feature database.CryptoKey require.NoError(t, err) switch key.Feature { - case database.CryptoKeyFeatureOidcConvert: + case database.CryptoKeyFeatureOIDCConvert: + require.Len(t, secret, 64) + case database.CryptoKeyFeatureWorkspaceAppsToken: require.Len(t, secret, 64) - case database.CryptoKeyFeatureWorkspaceApps: + case database.CryptoKeyFeatureWorkspaceAppsAPIKey: require.Len(t, secret, 32) case database.CryptoKeyFeatureTailnetResume: require.Len(t, secret, 64) diff --git a/coderd/cryptokeys/rotate_test.go b/coderd/cryptokeys/rotate_test.go index 190ad213b1153..9e147c8f921f0 100644 --- a/coderd/cryptokeys/rotate_test.go +++ b/coderd/cryptokeys/rotate_test.go @@ -34,8 +34,7 @@ func TestRotator(t *testing.T) { require.NoError(t, err) require.Len(t, dbkeys, 0) - err = cryptokeys.StartRotator(ctx, logger, db, cryptokeys.WithClock(clock)) - require.NoError(t, err) + cryptokeys.StartRotator(ctx, logger, db, cryptokeys.WithClock(clock)) // Fetch the keys from the database and ensure they // are as expected. @@ -58,7 +57,7 @@ func TestRotator(t *testing.T) { now := clock.Now().UTC() rotatingKey := dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, StartsAt: now.Add(-cryptokeys.DefaultKeyDuration + time.Hour + time.Minute), Sequence: 12345, }) @@ -66,8 +65,7 @@ func TestRotator(t *testing.T) { trap := clock.Trap().TickerFunc() t.Cleanup(trap.Close) - err := cryptokeys.StartRotator(ctx, logger, db, cryptokeys.WithClock(clock)) - require.NoError(t, err) + cryptokeys.StartRotator(ctx, logger, db, cryptokeys.WithClock(clock)) initialKeyLen := len(database.AllCryptoKeyFeatureValues()) // Fetch the keys from the database and ensure they @@ -85,7 +83,7 @@ func TestRotator(t *testing.T) { require.NoError(t, err) require.Len(t, keys, initialKeyLen+1) - newKey, err := db.GetLatestCryptoKeyByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps) + newKey, err := db.GetLatestCryptoKeyByFeature(ctx, database.CryptoKeyFeatureWorkspaceAppsAPIKey) require.NoError(t, err) require.Equal(t, rotatingKey.Sequence+1, newKey.Sequence) require.Equal(t, rotatingKey.ExpiresAt(cryptokeys.DefaultKeyDuration), newKey.StartsAt.UTC()) diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 052f25450e6a5..35e4f09250ff8 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -228,6 +228,42 @@ var ( Scope: rbac.ScopeAll, }.WithCachedASTValue() + // See cryptokeys package. + subjectCryptoKeyRotator = rbac.Subject{ + FriendlyName: "Crypto Key Rotator", + ID: uuid.Nil.String(), + Roles: rbac.Roles([]rbac.Role{ + { + Identifier: rbac.RoleIdentifier{Name: "keyrotator"}, + DisplayName: "Key Rotator", + Site: rbac.Permissions(map[string][]policy.Action{ + rbac.ResourceCryptoKey.Type: {policy.WildcardSymbol}, + }), + Org: map[string][]rbac.Permission{}, + User: []rbac.Permission{}, + }, + }), + Scope: rbac.ScopeAll, + }.WithCachedASTValue() + + // See cryptokeys package. + subjectCryptoKeyReader = rbac.Subject{ + FriendlyName: "Crypto Key Reader", + ID: uuid.Nil.String(), + Roles: rbac.Roles([]rbac.Role{ + { + Identifier: rbac.RoleIdentifier{Name: "keyrotator"}, + DisplayName: "Key Rotator", + Site: rbac.Permissions(map[string][]policy.Action{ + rbac.ResourceCryptoKey.Type: {policy.WildcardSymbol}, + }), + Org: map[string][]rbac.Permission{}, + User: []rbac.Permission{}, + }, + }), + Scope: rbac.ScopeAll, + }.WithCachedASTValue() + subjectSystemRestricted = rbac.Subject{ FriendlyName: "System", ID: uuid.Nil.String(), @@ -281,6 +317,16 @@ func AsHangDetector(ctx context.Context) context.Context { return context.WithValue(ctx, authContextKey{}, subjectHangDetector) } +// AsKeyRotator returns a context with an actor that has permissions required for rotating crypto keys. +func AsKeyRotator(ctx context.Context) context.Context { + return context.WithValue(ctx, authContextKey{}, subjectCryptoKeyRotator) +} + +// AsKeyReader returns a context with an actor that has permissions required for reading crypto keys. +func AsKeyReader(ctx context.Context) context.Context { + return context.WithValue(ctx, authContextKey{}, subjectCryptoKeyReader) +} + // AsSystemRestricted returns a context with an actor that has permissions // required for various system operations (login, logout, metrics cache). func AsSystemRestricted(ctx context.Context) context.Context { diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 6a34e88104ce1..439cf1bdaec19 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -2243,13 +2243,13 @@ func (s *MethodTestSuite) TestCryptoKeys() { })) s.Run("InsertCryptoKey", s.Subtest(func(db database.Store, check *expects) { check.Args(database.InsertCryptoKeyParams{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, }). Asserts(rbac.ResourceCryptoKey, policy.ActionCreate) })) s.Run("DeleteCryptoKey", s.Subtest(func(db database.Store, check *expects) { key := dbgen.CryptoKey(s.T(), db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, Sequence: 4, }) check.Args(database.DeleteCryptoKeyParams{ @@ -2259,7 +2259,7 @@ func (s *MethodTestSuite) TestCryptoKeys() { })) s.Run("GetCryptoKeyByFeatureAndSequence", s.Subtest(func(db database.Store, check *expects) { key := dbgen.CryptoKey(s.T(), db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, Sequence: 4, }) check.Args(database.GetCryptoKeyByFeatureAndSequenceParams{ @@ -2269,14 +2269,14 @@ func (s *MethodTestSuite) TestCryptoKeys() { })) s.Run("GetLatestCryptoKeyByFeature", s.Subtest(func(db database.Store, check *expects) { dbgen.CryptoKey(s.T(), db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, Sequence: 4, }) - check.Args(database.CryptoKeyFeatureWorkspaceApps).Asserts(rbac.ResourceCryptoKey, policy.ActionRead) + check.Args(database.CryptoKeyFeatureWorkspaceAppsAPIKey).Asserts(rbac.ResourceCryptoKey, policy.ActionRead) })) s.Run("UpdateCryptoKeyDeletesAt", s.Subtest(func(db database.Store, check *expects) { key := dbgen.CryptoKey(s.T(), db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, Sequence: 4, }) check.Args(database.UpdateCryptoKeyDeletesAtParams{ @@ -2286,7 +2286,7 @@ func (s *MethodTestSuite) TestCryptoKeys() { }).Asserts(rbac.ResourceCryptoKey, policy.ActionUpdate) })) s.Run("GetCryptoKeysByFeature", s.Subtest(func(db database.Store, check *expects) { - check.Args(database.CryptoKeyFeatureWorkspaceApps). + check.Args(database.CryptoKeyFeatureWorkspaceAppsAPIKey). Asserts(rbac.ResourceCryptoKey, policy.ActionRead) })) } diff --git a/coderd/database/dbgen/dbgen.go b/coderd/database/dbgen/dbgen.go index 255c62f82aef2..69419b98c79b1 100644 --- a/coderd/database/dbgen/dbgen.go +++ b/coderd/database/dbgen/dbgen.go @@ -943,7 +943,7 @@ func CustomRole(t testing.TB, db database.Store, seed database.CustomRole) datab func CryptoKey(t testing.TB, db database.Store, seed database.CryptoKey) database.CryptoKey { t.Helper() - seed.Feature = takeFirst(seed.Feature, database.CryptoKeyFeatureWorkspaceApps) + seed.Feature = takeFirst(seed.Feature, database.CryptoKeyFeatureWorkspaceAppsAPIKey) // An empty string for the secret is interpreted as // a caller wanting a new secret to be generated. @@ -1048,9 +1048,11 @@ func takeFirst[Value comparable](values ...Value) Value { func newCryptoKeySecret(feature database.CryptoKeyFeature) (string, error) { switch feature { - case database.CryptoKeyFeatureWorkspaceApps: + case database.CryptoKeyFeatureWorkspaceAppsAPIKey: return generateCryptoKey(32) - case database.CryptoKeyFeatureOidcConvert: + case database.CryptoKeyFeatureWorkspaceAppsToken: + return generateCryptoKey(64) + case database.CryptoKeyFeatureOIDCConvert: return generateCryptoKey(64) case database.CryptoKeyFeatureTailnetResume: return generateCryptoKey(64) diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 3a9a5a7a2d8f6..fc7819e38f218 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -38,7 +38,8 @@ CREATE TYPE build_reason AS ENUM ( ); CREATE TYPE crypto_key_feature AS ENUM ( - 'workspace_apps', + 'workspace_apps_token', + 'workspace_apps_api_key', 'oidc_convert', 'tailnet_resume' ); diff --git a/coderd/database/migrations/000271_cryptokey_features.down.sql b/coderd/database/migrations/000271_cryptokey_features.down.sql new file mode 100644 index 0000000000000..7cdd00d222da8 --- /dev/null +++ b/coderd/database/migrations/000271_cryptokey_features.down.sql @@ -0,0 +1,18 @@ +-- Step 1: Remove the new entries from crypto_keys table +DELETE FROM crypto_keys +WHERE feature IN ('workspace_apps_token', 'workspace_apps_api_key'); + +CREATE TYPE old_crypto_key_feature AS ENUM ( + 'workspace_apps', + 'oidc_convert', + 'tailnet_resume' +); + +ALTER TABLE crypto_keys + ALTER COLUMN feature TYPE old_crypto_key_feature + USING (feature::text::old_crypto_key_feature); + +DROP TYPE crypto_key_feature; + +ALTER TYPE old_crypto_key_feature RENAME TO crypto_key_feature; + diff --git a/coderd/database/migrations/000271_cryptokey_features.up.sql b/coderd/database/migrations/000271_cryptokey_features.up.sql new file mode 100644 index 0000000000000..bca75d220d0c7 --- /dev/null +++ b/coderd/database/migrations/000271_cryptokey_features.up.sql @@ -0,0 +1,18 @@ +-- Create a new enum type with the desired values +CREATE TYPE new_crypto_key_feature AS ENUM ( + 'workspace_apps_token', + 'workspace_apps_api_key', + 'oidc_convert', + 'tailnet_resume' +); + +DELETE FROM crypto_keys WHERE feature = 'workspace_apps'; + +-- Drop the old type and rename the new one +ALTER TABLE crypto_keys + ALTER COLUMN feature TYPE new_crypto_key_feature + USING (feature::text::new_crypto_key_feature); + +DROP TYPE crypto_key_feature; + +ALTER TYPE new_crypto_key_feature RENAME TO crypto_key_feature; diff --git a/coderd/database/migrations/testdata/fixtures/000271_cryptokey_features.up.sql b/coderd/database/migrations/testdata/fixtures/000271_cryptokey_features.up.sql new file mode 100644 index 0000000000000..5cb2cd4c95509 --- /dev/null +++ b/coderd/database/migrations/testdata/fixtures/000271_cryptokey_features.up.sql @@ -0,0 +1,40 @@ +INSERT INTO crypto_keys (feature, sequence, secret, secret_key_id, starts_at, deletes_at) +VALUES ( + 'workspace_apps_token', + 1, + 'abc', + NULL, + '1970-01-01 00:00:00 UTC'::timestamptz, + '2100-01-01 00:00:00 UTC'::timestamptz +); + +INSERT INTO crypto_keys (feature, sequence, secret, secret_key_id, starts_at, deletes_at) +VALUES ( + 'workspace_apps_api_key', + 1, + 'def', + NULL, + '1970-01-01 00:00:00 UTC'::timestamptz, + '2100-01-01 00:00:00 UTC'::timestamptz +); + +INSERT INTO crypto_keys (feature, sequence, secret, secret_key_id, starts_at, deletes_at) +VALUES ( + 'oidc_convert', + 2, + 'ghi', + NULL, + '1970-01-01 00:00:00 UTC'::timestamptz, + '2100-01-01 00:00:00 UTC'::timestamptz +); + +INSERT INTO crypto_keys (feature, sequence, secret, secret_key_id, starts_at, deletes_at) +VALUES ( + 'tailnet_resume', + 2, + 'jkl', + NULL, + '1970-01-01 00:00:00 UTC'::timestamptz, + '2100-01-01 00:00:00 UTC'::timestamptz +); + diff --git a/coderd/database/models.go b/coderd/database/models.go index 1207587d46529..e7d90acf5ea94 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -345,9 +345,10 @@ func AllBuildReasonValues() []BuildReason { type CryptoKeyFeature string const ( - CryptoKeyFeatureWorkspaceApps CryptoKeyFeature = "workspace_apps" - CryptoKeyFeatureOidcConvert CryptoKeyFeature = "oidc_convert" - CryptoKeyFeatureTailnetResume CryptoKeyFeature = "tailnet_resume" + CryptoKeyFeatureWorkspaceAppsToken CryptoKeyFeature = "workspace_apps_token" + CryptoKeyFeatureWorkspaceAppsAPIKey CryptoKeyFeature = "workspace_apps_api_key" + CryptoKeyFeatureOIDCConvert CryptoKeyFeature = "oidc_convert" + CryptoKeyFeatureTailnetResume CryptoKeyFeature = "tailnet_resume" ) func (e *CryptoKeyFeature) Scan(src interface{}) error { @@ -387,8 +388,9 @@ func (ns NullCryptoKeyFeature) Value() (driver.Value, error) { func (e CryptoKeyFeature) Valid() bool { switch e { - case CryptoKeyFeatureWorkspaceApps, - CryptoKeyFeatureOidcConvert, + case CryptoKeyFeatureWorkspaceAppsToken, + CryptoKeyFeatureWorkspaceAppsAPIKey, + CryptoKeyFeatureOIDCConvert, CryptoKeyFeatureTailnetResume: return true } @@ -397,8 +399,9 @@ func (e CryptoKeyFeature) Valid() bool { func AllCryptoKeyFeatureValues() []CryptoKeyFeature { return []CryptoKeyFeature{ - CryptoKeyFeatureWorkspaceApps, - CryptoKeyFeatureOidcConvert, + CryptoKeyFeatureWorkspaceAppsToken, + CryptoKeyFeatureWorkspaceAppsAPIKey, + CryptoKeyFeatureOIDCConvert, CryptoKeyFeatureTailnetResume, } } diff --git a/coderd/database/sqlc.yaml b/coderd/database/sqlc.yaml index a70e45a522989..257c95ddb2d7a 100644 --- a/coderd/database/sqlc.yaml +++ b/coderd/database/sqlc.yaml @@ -135,6 +135,8 @@ sql: api_key_id: APIKeyID callback_url: CallbackURL login_type_oauth2_provider_app: LoginTypeOAuth2ProviderApp + crypto_key_feature_workspace_apps_api_key: CryptoKeyFeatureWorkspaceAppsAPIKey + crypto_key_feature_oidc_convert: CryptoKeyFeatureOIDCConvert rules: - name: do-not-use-public-schema-in-queries message: "do not use public schema in queries" diff --git a/coderd/jwtutils/jwe.go b/coderd/jwtutils/jwe.go index d03816a477a26..bc9d0ddd2a9c8 100644 --- a/coderd/jwtutils/jwe.go +++ b/coderd/jwtutils/jwe.go @@ -65,6 +65,12 @@ func Encrypt(ctx context.Context, e EncryptKeyProvider, claims Claims) (string, return compact, nil } +func WithDecryptExpected(expected jwt.Expected) func(*DecryptOptions) { + return func(opts *DecryptOptions) { + opts.RegisteredClaims = expected + } +} + // DecryptOptions are options for decrypting a JWE. type DecryptOptions struct { RegisteredClaims jwt.Expected @@ -100,7 +106,7 @@ func Decrypt(ctx context.Context, d DecryptKeyProvider, token string, claims Cla kid := object.Header.KeyID if kid == "" { - return xerrors.Errorf("expected %q header to be a string", keyIDHeaderKey) + return ErrMissingKeyID } key, err := d.DecryptingKey(ctx, kid) diff --git a/coderd/jwtutils/jws.go b/coderd/jwtutils/jws.go index 73f35e672492d..0c8ca9aa30f39 100644 --- a/coderd/jwtutils/jws.go +++ b/coderd/jwtutils/jws.go @@ -10,10 +10,27 @@ import ( "golang.org/x/xerrors" ) +var ErrMissingKeyID = xerrors.New("missing key ID") + const ( keyIDHeaderKey = "kid" ) +// RegisteredClaims is a convenience type for embedding jwt.Claims. It should be +// preferred over embedding jwt.Claims directly since it will ensure that certain fields are set. +type RegisteredClaims jwt.Claims + +func (r RegisteredClaims) Validate(e jwt.Expected) error { + if r.Expiry == nil { + return xerrors.Errorf("expiry is required") + } + if e.Time.IsZero() { + return xerrors.Errorf("expected time is required") + } + + return (jwt.Claims(r)).Validate(e) +} + // Claims defines the payload for a JWT. Most callers // should embed jwt.Claims type Claims interface { @@ -24,6 +41,11 @@ const ( signingAlgo = jose.HS512 ) +type SigningKeyManager interface { + SigningKeyProvider + VerifyKeyProvider +} + type SigningKeyProvider interface { SigningKey(ctx context.Context) (id string, key interface{}, err error) } @@ -75,6 +97,12 @@ type VerifyOptions struct { SignatureAlgorithm jose.SignatureAlgorithm } +func WithVerifyExpected(expected jwt.Expected) func(*VerifyOptions) { + return func(opts *VerifyOptions) { + opts.RegisteredClaims = expected + } +} + // Verify verifies that a token was signed by the provided key. It unmarshals into the provided claims. func Verify(ctx context.Context, v VerifyKeyProvider, token string, claims Claims, opts ...func(*VerifyOptions)) error { options := VerifyOptions{ @@ -105,7 +133,7 @@ func Verify(ctx context.Context, v VerifyKeyProvider, token string, claims Claim kid := signature.Header.KeyID if kid == "" { - return xerrors.Errorf("expected %q header to be a string", keyIDHeaderKey) + return ErrMissingKeyID } key, err := v.VerifyingKey(ctx, kid) @@ -125,3 +153,35 @@ func Verify(ctx context.Context, v VerifyKeyProvider, token string, claims Claim return claims.Validate(options.RegisteredClaims) } + +// StaticKey fulfills the SigningKeycache and EncryptionKeycache interfaces. Useful for testing. +type StaticKey struct { + ID string + Key interface{} +} + +func (s StaticKey) SigningKey(_ context.Context) (string, interface{}, error) { + return s.ID, s.Key, nil +} + +func (s StaticKey) VerifyingKey(_ context.Context, id string) (interface{}, error) { + if id != s.ID { + return nil, xerrors.Errorf("invalid id %q", id) + } + return s.Key, nil +} + +func (s StaticKey) EncryptingKey(_ context.Context) (string, interface{}, error) { + return s.ID, s.Key, nil +} + +func (s StaticKey) DecryptingKey(_ context.Context, id string) (interface{}, error) { + if id != s.ID { + return nil, xerrors.Errorf("invalid id %q", id) + } + return s.Key, nil +} + +func (StaticKey) Close() error { + return nil +} diff --git a/coderd/jwtutils/jwt_test.go b/coderd/jwtutils/jwt_test.go index 697e5d210d858..5d1f4d48bdb4a 100644 --- a/coderd/jwtutils/jwt_test.go +++ b/coderd/jwtutils/jwt_test.go @@ -236,11 +236,11 @@ func TestJWS(t *testing.T) { ctx = testutil.Context(t, testutil.WaitShort) db, _ = dbtestutil.NewDB(t) _ = dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureOidcConvert, + Feature: database.CryptoKeyFeatureOIDCConvert, StartsAt: time.Now(), }) log = slogtest.Make(t, nil) - fetcher = &cryptokeys.DBFetcher{DB: db, Feature: database.CryptoKeyFeatureOidcConvert} + fetcher = &cryptokeys.DBFetcher{DB: db} ) cache, err := cryptokeys.NewSigningCache(ctx, log, fetcher, codersdk.CryptoKeyFeatureOIDCConvert) @@ -326,15 +326,15 @@ func TestJWE(t *testing.T) { ctx = testutil.Context(t, testutil.WaitShort) db, _ = dbtestutil.NewDB(t) _ = dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, StartsAt: time.Now(), }) log = slogtest.Make(t, nil) - fetcher = &cryptokeys.DBFetcher{DB: db, Feature: database.CryptoKeyFeatureWorkspaceApps} + fetcher = &cryptokeys.DBFetcher{DB: db} ) - cache, err := cryptokeys.NewEncryptionCache(ctx, log, fetcher, codersdk.CryptoKeyFeatureWorkspaceApp) + cache, err := cryptokeys.NewEncryptionCache(ctx, log, fetcher, codersdk.CryptoKeyFeatureWorkspaceAppsAPIKey) require.NoError(t, err) claims := testClaims{ diff --git a/coderd/userauth.go b/coderd/userauth.go index 85ab0d77e6cc1..f1a19d77d23d0 100644 --- a/coderd/userauth.go +++ b/coderd/userauth.go @@ -15,7 +15,8 @@ import ( "time" "github.com/coreos/go-oidc/v3/oidc" - "github.com/golang-jwt/jwt/v4" + "github.com/go-jose/go-jose/v4" + "github.com/go-jose/go-jose/v4/jwt" "github.com/google/go-github/v43/github" "github.com/google/uuid" "github.com/moby/moby/pkg/namesgenerator" @@ -23,6 +24,9 @@ import ( "golang.org/x/xerrors" "cdr.dev/slog" + "github.com/coder/coder/v2/coderd/cryptokeys" + "github.com/coder/coder/v2/coderd/idpsync" + "github.com/coder/coder/v2/coderd/jwtutils" "github.com/coder/coder/v2/coderd/apikey" "github.com/coder/coder/v2/coderd/audit" @@ -32,7 +36,6 @@ import ( "github.com/coder/coder/v2/coderd/externalauth" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpmw" - "github.com/coder/coder/v2/coderd/idpsync" "github.com/coder/coder/v2/coderd/notifications" "github.com/coder/coder/v2/coderd/promoauth" "github.com/coder/coder/v2/coderd/rbac" @@ -49,7 +52,7 @@ const ( ) type OAuthConvertStateClaims struct { - jwt.RegisteredClaims + jwtutils.RegisteredClaims UserID uuid.UUID `json:"user_id"` State string `json:"state"` @@ -57,6 +60,10 @@ type OAuthConvertStateClaims struct { ToLoginType codersdk.LoginType `json:"to_login_type"` } +func (o *OAuthConvertStateClaims) Validate(e jwt.Expected) error { + return o.RegisteredClaims.Validate(e) +} + // postConvertLoginType replies with an oauth state token capable of converting // the user to an oauth user. // @@ -149,11 +156,11 @@ func (api *API) postConvertLoginType(rw http.ResponseWriter, r *http.Request) { // Eg: Developers with more than 1 deployment. now := time.Now() claims := &OAuthConvertStateClaims{ - RegisteredClaims: jwt.RegisteredClaims{ + RegisteredClaims: jwtutils.RegisteredClaims{ Issuer: api.DeploymentID, Subject: stateString, Audience: []string{user.ID.String()}, - ExpiresAt: jwt.NewNumericDate(now.Add(time.Minute * 5)), + Expiry: jwt.NewNumericDate(now.Add(time.Minute * 5)), NotBefore: jwt.NewNumericDate(now.Add(time.Second * -1)), IssuedAt: jwt.NewNumericDate(now), ID: uuid.NewString(), @@ -164,9 +171,7 @@ func (api *API) postConvertLoginType(rw http.ResponseWriter, r *http.Request) { ToLoginType: req.ToType, } - token := jwt.NewWithClaims(jwt.SigningMethodHS512, claims) - // Key must be a byte slice, not an array. So make sure to include the [:] - tokenString, err := token.SignedString(api.OAuthSigningKey[:]) + token, err := jwtutils.Sign(ctx, api.OIDCConvertKeyCache, claims) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error signing state jwt.", @@ -176,8 +181,8 @@ func (api *API) postConvertLoginType(rw http.ResponseWriter, r *http.Request) { } aReq.New = database.AuditOAuthConvertState{ - CreatedAt: claims.IssuedAt.Time, - ExpiresAt: claims.ExpiresAt.Time, + CreatedAt: claims.IssuedAt.Time(), + ExpiresAt: claims.Expiry.Time(), FromLoginType: database.LoginType(claims.FromLoginType), ToLoginType: database.LoginType(claims.ToLoginType), UserID: claims.UserID, @@ -186,8 +191,8 @@ func (api *API) postConvertLoginType(rw http.ResponseWriter, r *http.Request) { http.SetCookie(rw, &http.Cookie{ Name: OAuthConvertCookieValue, Path: "/", - Value: tokenString, - Expires: claims.ExpiresAt.Time, + Value: token, + Expires: claims.Expiry.Time(), Secure: api.SecureAuthCookie, HttpOnly: true, // Must be SameSite to work on the redirected auth flow from the @@ -196,7 +201,7 @@ func (api *API) postConvertLoginType(rw http.ResponseWriter, r *http.Request) { }) httpapi.Write(ctx, rw, http.StatusCreated, codersdk.OAuthConversionResponse{ StateString: stateString, - ExpiresAt: claims.ExpiresAt.Time, + ExpiresAt: claims.Expiry.Time(), ToType: claims.ToLoginType, UserID: claims.UserID, }) @@ -1677,10 +1682,9 @@ func (api *API) convertUserToOauth(ctx context.Context, r *http.Request, db data } } var claims OAuthConvertStateClaims - token, err := jwt.ParseWithClaims(jwtCookie.Value, &claims, func(_ *jwt.Token) (interface{}, error) { - return api.OAuthSigningKey[:], nil - }) - if xerrors.Is(err, jwt.ErrSignatureInvalid) || !token.Valid { + + err = jwtutils.Verify(ctx, api.OIDCConvertKeyCache, jwtCookie.Value, &claims) + if xerrors.Is(err, cryptokeys.ErrKeyNotFound) || xerrors.Is(err, cryptokeys.ErrKeyInvalid) || xerrors.Is(err, jose.ErrCryptoFailure) || xerrors.Is(err, jwtutils.ErrMissingKeyID) { // These errors are probably because the user is mixing 2 coder deployments. return database.User{}, idpsync.HTTPError{ Code: http.StatusBadRequest, @@ -1709,7 +1713,7 @@ func (api *API) convertUserToOauth(ctx context.Context, r *http.Request, db data oauthConvertAudit.UserID = claims.UserID oauthConvertAudit.Old = user - if claims.RegisteredClaims.Issuer != api.DeploymentID { + if claims.Issuer != api.DeploymentID { return database.User{}, idpsync.HTTPError{ Code: http.StatusForbidden, Msg: "Request to convert login type failed. Issuer mismatch. Found a cookie from another coder deployment, please try again.", diff --git a/coderd/userauth_test.go b/coderd/userauth_test.go index 20dfe7f723899..6386be7eb8be4 100644 --- a/coderd/userauth_test.go +++ b/coderd/userauth_test.go @@ -3,6 +3,8 @@ package coderd_test import ( "context" "crypto" + "crypto/rand" + "encoding/json" "fmt" "io" "net/http" @@ -13,6 +15,7 @@ import ( "time" "github.com/coreos/go-oidc/v3/oidc" + "github.com/go-jose/go-jose/v4" "github.com/golang-jwt/jwt/v4" "github.com/google/go-github/v43/github" "github.com/google/uuid" @@ -27,10 +30,12 @@ import ( "github.com/coder/coder/v2/coderd/audit" "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/coderdtest/oidctest" + "github.com/coder/coder/v2/coderd/cryptokeys" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbgen" "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/jwtutils" "github.com/coder/coder/v2/coderd/notifications" "github.com/coder/coder/v2/coderd/promoauth" "github.com/coder/coder/v2/codersdk" @@ -1316,6 +1321,7 @@ func TestUserOIDC(t *testing.T) { owner := coderdtest.CreateFirstUser(t, client) user, userData := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID) + require.Equal(t, codersdk.LoginTypePassword, userData.LoginType) claims := jwt.MapClaims{ "email": userData.Email, @@ -1323,15 +1329,17 @@ func TestUserOIDC(t *testing.T) { var err error user.HTTPClient.Jar, err = cookiejar.New(nil) require.NoError(t, err) + user.HTTPClient.Transport = http.DefaultTransport.(*http.Transport).Clone() ctx := testutil.Context(t, testutil.WaitShort) + convertResponse, err := user.ConvertLoginType(ctx, codersdk.ConvertLoginRequest{ ToType: codersdk.LoginTypeOIDC, Password: "SomeSecurePassword!", }) require.NoError(t, err) - fake.LoginWithClient(t, user, claims, func(r *http.Request) { + _, _ = fake.LoginWithClient(t, user, claims, func(r *http.Request) { r.URL.RawQuery = url.Values{ "oidc_merge_state": {convertResponse.StateString}, }.Encode() @@ -1341,6 +1349,99 @@ func TestUserOIDC(t *testing.T) { r.AddCookie(cookie) } }) + + info, err := client.User(ctx, userData.ID.String()) + require.NoError(t, err) + require.Equal(t, codersdk.LoginTypeOIDC, info.LoginType) + }) + + t.Run("BadJWT", func(t *testing.T) { + t.Parallel() + + var ( + ctx = testutil.Context(t, testutil.WaitMedium) + logger = slogtest.Make(t, nil) + ) + + auditor := audit.NewMock() + fake := oidctest.NewFakeIDP(t, + oidctest.WithRefresh(func(_ string) error { + return xerrors.New("refreshing token should never occur") + }), + oidctest.WithServing(), + ) + cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) { + cfg.AllowSignups = true + }) + + db, ps := dbtestutil.NewDB(t) + fetcher := &cryptokeys.DBFetcher{ + DB: db, + } + + kc, err := cryptokeys.NewSigningCache(ctx, logger, fetcher, codersdk.CryptoKeyFeatureOIDCConvert) + require.NoError(t, err) + + client := coderdtest.New(t, &coderdtest.Options{ + Auditor: auditor, + OIDCConfig: cfg, + Database: db, + Pubsub: ps, + OIDCConvertKeyCache: kc, + }) + + owner := coderdtest.CreateFirstUser(t, client) + user, userData := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID) + + claims := jwt.MapClaims{ + "email": userData.Email, + } + user.HTTPClient.Jar, err = cookiejar.New(nil) + require.NoError(t, err) + user.HTTPClient.Transport = http.DefaultTransport.(*http.Transport).Clone() + + convertResponse, err := user.ConvertLoginType(ctx, codersdk.ConvertLoginRequest{ + ToType: codersdk.LoginTypeOIDC, + Password: "SomeSecurePassword!", + }) + require.NoError(t, err) + + // Update the cookie to use a bad signing key. We're asserting the behavior of the scenario + // where a JWT gets minted on an old version of Coder but gets verified on a new version. + _, resp := fake.AttemptLogin(t, user, claims, func(r *http.Request) { + r.URL.RawQuery = url.Values{ + "oidc_merge_state": {convertResponse.StateString}, + }.Encode() + r.Header.Set(codersdk.SessionTokenHeader, user.SessionToken()) + + cookies := user.HTTPClient.Jar.Cookies(user.URL) + for i, cookie := range cookies { + if cookie.Name != coderd.OAuthConvertCookieValue { + continue + } + + jwt := cookie.Value + var claims coderd.OAuthConvertStateClaims + err := jwtutils.Verify(ctx, kc, jwt, &claims) + require.NoError(t, err) + badJWT := generateBadJWT(t, claims) + cookie.Value = badJWT + cookies[i] = cookie + } + + user.HTTPClient.Jar.SetCookies(user.URL, cookies) + + for _, cookie := range cookies { + fmt.Printf("cookie: %+v\n", cookie) + r.AddCookie(cookie) + } + }) + defer resp.Body.Close() + require.Equal(t, http.StatusBadRequest, resp.StatusCode) + var respErr codersdk.Response + err = json.NewDecoder(resp.Body).Decode(&respErr) + require.NoError(t, err) + require.Contains(t, respErr.Message, "Using an invalid jwt to authorize this action.") }) t.Run("AlternateUsername", func(t *testing.T) { @@ -2022,3 +2123,24 @@ func inflateClaims(t testing.TB, seed jwt.MapClaims, size int) jwt.MapClaims { seed["random_data"] = junk return seed } + +// generateBadJWT generates a JWT with a random key. It's intended to emulate the old-style JWT's we generated. +func generateBadJWT(t *testing.T, claims interface{}) string { + t.Helper() + + var buf [64]byte + _, err := rand.Read(buf[:]) + require.NoError(t, err) + signer, err := jose.NewSigner(jose.SigningKey{ + Algorithm: jose.HS512, + Key: buf[:], + }, nil) + require.NoError(t, err) + payload, err := json.Marshal(claims) + require.NoError(t, err) + signed, err := signer.Sign(payload) + require.NoError(t, err) + compact, err := signed.CompactSerialize() + require.NoError(t, err) + return compact +} diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 6ea631f2e7d0c..a181697f27279 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -32,6 +32,7 @@ import ( "github.com/coder/coder/v2/coderd/externalauth" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/coderd/jwtutils" "github.com/coder/coder/v2/coderd/rbac/policy" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/agentsdk" @@ -852,8 +853,12 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R ) if resumeToken != "" { var err error - peerID, err = api.Options.CoordinatorResumeTokenProvider.VerifyResumeToken(resumeToken) - if err != nil { + peerID, err = api.Options.CoordinatorResumeTokenProvider.VerifyResumeToken(ctx, resumeToken) + // If the token is missing the key ID, it's probably an old token in which + // case we just want to generate a new peer ID. + if xerrors.Is(err, jwtutils.ErrMissingKeyID) { + peerID = uuid.New() + } else if err != nil { httpapi.Write(ctx, rw, http.StatusUnauthorized, codersdk.Response{ Message: workspacesdk.CoordinateAPIInvalidResumeToken, Detail: err.Error(), @@ -862,9 +867,10 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R }, }) return + } else { + api.Logger.Debug(ctx, "accepted coordinate resume token for peer", + slog.F("peer_id", peerID.String())) } - api.Logger.Debug(ctx, "accepted coordinate resume token for peer", - slog.F("peer_id", peerID.String())) } api.WebsocketWaitMutex.Lock() diff --git a/coderd/workspaceagents_test.go b/coderd/workspaceagents_test.go index 8c0801a914d61..ba677975471d6 100644 --- a/coderd/workspaceagents_test.go +++ b/coderd/workspaceagents_test.go @@ -13,6 +13,7 @@ import ( "testing" "time" + "github.com/go-jose/go-jose/v4/jwt" "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -36,6 +37,7 @@ import ( "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/database/pubsub" "github.com/coder/coder/v2/coderd/externalauth" + "github.com/coder/coder/v2/coderd/jwtutils" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/agentsdk" "github.com/coder/coder/v2/codersdk/workspacesdk" @@ -531,20 +533,20 @@ func newResumeTokenRecordingProvider(t testing.TB, underlying tailnet.ResumeToke } } -func (r *resumeTokenRecordingProvider) GenerateResumeToken(peerID uuid.UUID) (*tailnetproto.RefreshResumeTokenResponse, error) { +func (r *resumeTokenRecordingProvider) GenerateResumeToken(ctx context.Context, peerID uuid.UUID) (*tailnetproto.RefreshResumeTokenResponse, error) { select { case r.generateCalls <- peerID: - return r.ResumeTokenProvider.GenerateResumeToken(peerID) + return r.ResumeTokenProvider.GenerateResumeToken(ctx, peerID) default: r.t.Error("generateCalls full") return nil, xerrors.New("generateCalls full") } } -func (r *resumeTokenRecordingProvider) VerifyResumeToken(token string) (uuid.UUID, error) { +func (r *resumeTokenRecordingProvider) VerifyResumeToken(ctx context.Context, token string) (uuid.UUID, error) { select { case r.verifyCalls <- token: - return r.ResumeTokenProvider.VerifyResumeToken(token) + return r.ResumeTokenProvider.VerifyResumeToken(ctx, token) default: r.t.Error("verifyCalls full") return uuid.Nil, xerrors.New("verifyCalls full") @@ -554,69 +556,136 @@ func (r *resumeTokenRecordingProvider) VerifyResumeToken(token string) (uuid.UUI func TestWorkspaceAgentClientCoordinate_ResumeToken(t *testing.T) { t.Parallel() - logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - clock := quartz.NewMock(t) - resumeTokenSigningKey, err := tailnet.GenerateResumeTokenSigningKey() - require.NoError(t, err) - resumeTokenProvider := newResumeTokenRecordingProvider( - t, - tailnet.NewResumeTokenKeyProvider(resumeTokenSigningKey, clock, time.Hour), - ) - client, closer, api := coderdtest.NewWithAPI(t, &coderdtest.Options{ - Coordinator: tailnet.NewCoordinator(logger), - CoordinatorResumeTokenProvider: resumeTokenProvider, + t.Run("OK", func(t *testing.T) { + t.Parallel() + + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + clock := quartz.NewMock(t) + resumeTokenSigningKey, err := tailnet.GenerateResumeTokenSigningKey() + mgr := jwtutils.StaticKey{ + ID: uuid.New().String(), + Key: resumeTokenSigningKey[:], + } + require.NoError(t, err) + resumeTokenProvider := newResumeTokenRecordingProvider( + t, + tailnet.NewResumeTokenKeyProvider(mgr, clock, time.Hour), + ) + client, closer, api := coderdtest.NewWithAPI(t, &coderdtest.Options{ + Coordinator: tailnet.NewCoordinator(logger), + CoordinatorResumeTokenProvider: resumeTokenProvider, + }) + defer closer.Close() + user := coderdtest.CreateFirstUser(t, client) + + // Create a workspace with an agent. No need to connect it since clients can + // still connect to the coordinator while the agent isn't connected. + r := dbfake.WorkspaceBuild(t, api.Database, database.WorkspaceTable{ + OrganizationID: user.OrganizationID, + OwnerID: user.UserID, + }).WithAgent().Do() + agentTokenUUID, err := uuid.Parse(r.AgentToken) + require.NoError(t, err) + ctx := testutil.Context(t, testutil.WaitLong) + agentAndBuild, err := api.Database.GetWorkspaceAgentAndLatestBuildByAuthToken(dbauthz.AsSystemRestricted(ctx), agentTokenUUID) //nolint + require.NoError(t, err) + + // Connect with no resume token, and ensure that the peer ID is set to a + // random value. + originalResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, "") + require.NoError(t, err) + originalPeerID := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.generateCalls) + require.NotEqual(t, originalPeerID, uuid.Nil) + + // Connect with a valid resume token, and ensure that the peer ID is set to + // the stored value. + clock.Advance(time.Second) + newResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, originalResumeToken) + require.NoError(t, err) + verifiedToken := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.verifyCalls) + require.Equal(t, originalResumeToken, verifiedToken) + newPeerID := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.generateCalls) + require.Equal(t, originalPeerID, newPeerID) + require.NotEqual(t, originalResumeToken, newResumeToken) + + // Connect with an invalid resume token, and ensure that the request is + // rejected. + clock.Advance(time.Second) + _, err = connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, "invalid") + require.Error(t, err) + var sdkErr *codersdk.Error + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, http.StatusUnauthorized, sdkErr.StatusCode()) + require.Len(t, sdkErr.Validations, 1) + require.Equal(t, "resume_token", sdkErr.Validations[0].Field) + verifiedToken = testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.verifyCalls) + require.Equal(t, "invalid", verifiedToken) + + select { + case <-resumeTokenProvider.generateCalls: + t.Fatal("unexpected peer ID in channel") + default: + } }) - defer closer.Close() - user := coderdtest.CreateFirstUser(t, client) - // Create a workspace with an agent. No need to connect it since clients can - // still connect to the coordinator while the agent isn't connected. - r := dbfake.WorkspaceBuild(t, api.Database, database.WorkspaceTable{ - OrganizationID: user.OrganizationID, - OwnerID: user.UserID, - }).WithAgent().Do() - agentTokenUUID, err := uuid.Parse(r.AgentToken) - require.NoError(t, err) - ctx := testutil.Context(t, testutil.WaitLong) - agentAndBuild, err := api.Database.GetWorkspaceAgentAndLatestBuildByAuthToken(dbauthz.AsSystemRestricted(ctx), agentTokenUUID) //nolint - require.NoError(t, err) + t.Run("BadJWT", func(t *testing.T) { + t.Parallel() - // Connect with no resume token, and ensure that the peer ID is set to a - // random value. - originalResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, "") - require.NoError(t, err) - originalPeerID := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.generateCalls) - require.NotEqual(t, originalPeerID, uuid.Nil) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + clock := quartz.NewMock(t) + resumeTokenSigningKey, err := tailnet.GenerateResumeTokenSigningKey() + mgr := jwtutils.StaticKey{ + ID: uuid.New().String(), + Key: resumeTokenSigningKey[:], + } + require.NoError(t, err) + resumeTokenProvider := newResumeTokenRecordingProvider( + t, + tailnet.NewResumeTokenKeyProvider(mgr, clock, time.Hour), + ) + client, closer, api := coderdtest.NewWithAPI(t, &coderdtest.Options{ + Coordinator: tailnet.NewCoordinator(logger), + CoordinatorResumeTokenProvider: resumeTokenProvider, + }) + defer closer.Close() + user := coderdtest.CreateFirstUser(t, client) - // Connect with a valid resume token, and ensure that the peer ID is set to - // the stored value. - clock.Advance(time.Second) - newResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, originalResumeToken) - require.NoError(t, err) - verifiedToken := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.verifyCalls) - require.Equal(t, originalResumeToken, verifiedToken) - newPeerID := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.generateCalls) - require.Equal(t, originalPeerID, newPeerID) - require.NotEqual(t, originalResumeToken, newResumeToken) - - // Connect with an invalid resume token, and ensure that the request is - // rejected. - clock.Advance(time.Second) - _, err = connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, "invalid") - require.Error(t, err) - var sdkErr *codersdk.Error - require.ErrorAs(t, err, &sdkErr) - require.Equal(t, http.StatusUnauthorized, sdkErr.StatusCode()) - require.Len(t, sdkErr.Validations, 1) - require.Equal(t, "resume_token", sdkErr.Validations[0].Field) - verifiedToken = testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.verifyCalls) - require.Equal(t, "invalid", verifiedToken) + // Create a workspace with an agent. No need to connect it since clients can + // still connect to the coordinator while the agent isn't connected. + r := dbfake.WorkspaceBuild(t, api.Database, database.WorkspaceTable{ + OrganizationID: user.OrganizationID, + OwnerID: user.UserID, + }).WithAgent().Do() + agentTokenUUID, err := uuid.Parse(r.AgentToken) + require.NoError(t, err) + ctx := testutil.Context(t, testutil.WaitLong) + agentAndBuild, err := api.Database.GetWorkspaceAgentAndLatestBuildByAuthToken(dbauthz.AsSystemRestricted(ctx), agentTokenUUID) //nolint + require.NoError(t, err) - select { - case <-resumeTokenProvider.generateCalls: - t.Fatal("unexpected peer ID in channel") - default: - } + // Connect with no resume token, and ensure that the peer ID is set to a + // random value. + originalResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, "") + require.NoError(t, err) + originalPeerID := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.generateCalls) + require.NotEqual(t, originalPeerID, uuid.Nil) + + // Connect with an outdated token, and ensure that the peer ID is set to a + // random value. We don't want to fail requests just because + // a user got unlucky during a deployment upgrade. + outdatedToken := generateBadJWT(t, jwtutils.RegisteredClaims{ + Subject: originalPeerID.String(), + Expiry: jwt.NewNumericDate(clock.Now().Add(time.Minute)), + }) + + clock.Advance(time.Second) + newResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, outdatedToken) + require.NoError(t, err) + verifiedToken := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.verifyCalls) + require.Equal(t, outdatedToken, verifiedToken) + newPeerID := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.generateCalls) + require.NotEqual(t, originalPeerID, newPeerID) + require.NotEqual(t, originalResumeToken, newResumeToken) + }) } // connectToCoordinatorAndFetchResumeToken connects to the tailnet coordinator diff --git a/coderd/workspaceapps.go b/coderd/workspaceapps.go index d2fa11b9ea2ea..e264dbd80b58d 100644 --- a/coderd/workspaceapps.go +++ b/coderd/workspaceapps.go @@ -16,6 +16,7 @@ import ( "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/coderd/jwtutils" "github.com/coder/coder/v2/coderd/rbac/policy" "github.com/coder/coder/v2/coderd/workspaceapps" "github.com/coder/coder/v2/coderd/workspaceapps/appurl" @@ -122,10 +123,11 @@ func (api *API) workspaceApplicationAuth(rw http.ResponseWriter, r *http.Request return } - // Encrypt the API key. - encryptedAPIKey, err := api.AppSecurityKey.EncryptAPIKey(workspaceapps.EncryptedAPIKeyPayload{ + payload := workspaceapps.EncryptedAPIKeyPayload{ APIKey: cookie.Value, - }) + } + payload.Fill(api.Clock.Now()) + encryptedAPIKey, err := jwtutils.Encrypt(ctx, api.AppEncryptionKeyCache, payload) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Failed to encrypt API key.", diff --git a/coderd/workspaceapps/apptest/apptest.go b/coderd/workspaceapps/apptest/apptest.go index 14adf2d61d362..c6e251806230d 100644 --- a/coderd/workspaceapps/apptest/apptest.go +++ b/coderd/workspaceapps/apptest/apptest.go @@ -3,6 +3,7 @@ package apptest import ( "bufio" "context" + "crypto/rand" "encoding/json" "fmt" "io" @@ -408,6 +409,67 @@ func Run(t *testing.T, appHostIsPrimary bool, factory DeploymentFactory) { require.Equal(t, http.StatusInternalServerError, resp.StatusCode) assertWorkspaceLastUsedAtNotUpdated(t, appDetails) }) + + t.Run("BadJWT", func(t *testing.T) { + t.Parallel() + + appDetails := setupProxyTest(t, nil) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + u := appDetails.PathAppURL(appDetails.Apps.Owner) + resp, err := requestWithRetries(ctx, t, appDetails.AppClient(t), http.MethodGet, u.String(), nil) + require.NoError(t, err) + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, proxyTestAppBody, string(body)) + require.Equal(t, http.StatusOK, resp.StatusCode) + + appTokenCookie := findCookie(resp.Cookies(), codersdk.SignedAppTokenCookie) + require.NotNil(t, appTokenCookie, "no signed app token cookie in response") + require.Equal(t, appTokenCookie.Path, u.Path, "incorrect path on app token cookie") + + object, err := jose.ParseSigned(appTokenCookie.Value) + require.NoError(t, err) + require.Len(t, object.Signatures, 1) + + // Parse the payload. + var tok workspaceapps.SignedToken + //nolint:gosec + err = json.Unmarshal(object.UnsafePayloadWithoutVerification(), &tok) + require.NoError(t, err) + + appTokenClient := appDetails.AppClient(t) + apiKey := appTokenClient.SessionToken() + appTokenClient.SetSessionToken("") + appTokenClient.HTTPClient.Jar, err = cookiejar.New(nil) + require.NoError(t, err) + // Sign the token with an old-style key. + appTokenCookie.Value = generateBadJWT(t, tok) + appTokenClient.HTTPClient.Jar.SetCookies(u, + []*http.Cookie{ + appTokenCookie, + { + Name: codersdk.PathAppSessionTokenCookie, + Value: apiKey, + }, + }, + ) + + resp, err = requestWithRetries(ctx, t, appTokenClient, http.MethodGet, u.String(), nil) + require.NoError(t, err) + defer resp.Body.Close() + body, err = io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, proxyTestAppBody, string(body)) + require.Equal(t, http.StatusOK, resp.StatusCode) + assertWorkspaceLastUsedAtUpdated(t, appDetails) + + // Since the old token is invalid, the signed app token cookie should have a new value. + newTokenCookie := findCookie(resp.Cookies(), codersdk.SignedAppTokenCookie) + require.NotEqual(t, appTokenCookie.Value, newTokenCookie.Value) + }) }) t.Run("WorkspaceApplicationAuth", func(t *testing.T) { @@ -463,7 +525,7 @@ func Run(t *testing.T, appHostIsPrimary bool, factory DeploymentFactory) { appClient.SetSessionToken("") // Try to load the application without authentication. - u := c.appURL + u := *c.appURL u.Path = path.Join(u.Path, "/test") req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) require.NoError(t, err) @@ -500,7 +562,7 @@ func Run(t *testing.T, appHostIsPrimary bool, factory DeploymentFactory) { // Copy the query parameters and then check equality. u.RawQuery = gotLocation.RawQuery - require.Equal(t, u, gotLocation) + require.Equal(t, u, *gotLocation) // Verify the API key is set. encryptedAPIKey := gotLocation.Query().Get(workspaceapps.SubdomainProxyAPIKeyParam) @@ -580,6 +642,38 @@ func Run(t *testing.T, appHostIsPrimary bool, factory DeploymentFactory) { resp.Body.Close() require.Equal(t, http.StatusOK, resp.StatusCode) }) + + t.Run("BadJWE", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + currentKeyStr := appDetails.SDKClient.SessionToken() + appClient := appDetails.AppClient(t) + appClient.SetSessionToken("") + u := *c.appURL + u.Path = path.Join(u.Path, "/test") + badToken := generateBadJWE(t, workspaceapps.EncryptedAPIKeyPayload{ + APIKey: currentKeyStr, + }) + + u.RawQuery = (url.Values{ + workspaceapps.SubdomainProxyAPIKeyParam: {badToken}, + }).Encode() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) + require.NoError(t, err) + + var resp *http.Response + resp, err = doWithRetries(t, appClient, req) + require.NoError(t, err) + defer resp.Body.Close() + require.Equal(t, http.StatusBadRequest, resp.StatusCode) + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Contains(t, string(body), "Could not decrypt API key. Please remove the query parameter and try again.") + }) } }) }) @@ -1077,6 +1171,68 @@ func Run(t *testing.T, appHostIsPrimary bool, factory DeploymentFactory) { assertWorkspaceLastUsedAtNotUpdated(t, appDetails) }) }) + + t.Run("BadJWT", func(t *testing.T) { + t.Parallel() + + appDetails := setupProxyTest(t, nil) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + u := appDetails.SubdomainAppURL(appDetails.Apps.Owner) + resp, err := requestWithRetries(ctx, t, appDetails.AppClient(t), http.MethodGet, u.String(), nil) + require.NoError(t, err) + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, proxyTestAppBody, string(body)) + require.Equal(t, http.StatusOK, resp.StatusCode) + + appTokenCookie := findCookie(resp.Cookies(), codersdk.SignedAppTokenCookie) + require.NotNil(t, appTokenCookie, "no signed token cookie in response") + require.Equal(t, appTokenCookie.Path, "/", "incorrect path on signed token cookie") + + object, err := jose.ParseSigned(appTokenCookie.Value) + require.NoError(t, err) + require.Len(t, object.Signatures, 1) + + // Parse the payload. + var tok workspaceapps.SignedToken + //nolint:gosec + err = json.Unmarshal(object.UnsafePayloadWithoutVerification(), &tok) + require.NoError(t, err) + + appTokenClient := appDetails.AppClient(t) + apiKey := appTokenClient.SessionToken() + appTokenClient.SetSessionToken("") + appTokenClient.HTTPClient.Jar, err = cookiejar.New(nil) + require.NoError(t, err) + // Sign the token with an old-style key. + appTokenCookie.Value = generateBadJWT(t, tok) + appTokenClient.HTTPClient.Jar.SetCookies(u, + []*http.Cookie{ + appTokenCookie, + { + Name: codersdk.SubdomainAppSessionTokenCookie, + Value: apiKey, + }, + }, + ) + + // We should still be able to successfully proxy. + resp, err = requestWithRetries(ctx, t, appTokenClient, http.MethodGet, u.String(), nil) + require.NoError(t, err) + defer resp.Body.Close() + body, err = io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, proxyTestAppBody, string(body)) + require.Equal(t, http.StatusOK, resp.StatusCode) + assertWorkspaceLastUsedAtUpdated(t, appDetails) + + // Since the old token is invalid, the signed app token cookie should have a new value. + newTokenCookie := findCookie(resp.Cookies(), codersdk.SignedAppTokenCookie) + require.NotEqual(t, appTokenCookie.Value, newTokenCookie.Value) + }) }) t.Run("PortSharing", func(t *testing.T) { @@ -1789,3 +1945,57 @@ func assertWorkspaceLastUsedAtNotUpdated(t testing.TB, details *Details) { require.NoError(t, err) require.Equal(t, before.LastUsedAt, after.LastUsedAt, "workspace LastUsedAt updated when it should not have been") } + +func generateBadJWE(t *testing.T, claims interface{}) string { + t.Helper() + var buf [32]byte + _, err := rand.Read(buf[:]) + require.NoError(t, err) + encrypt, err := jose.NewEncrypter( + jose.A256GCM, + jose.Recipient{ + Algorithm: jose.A256GCMKW, + Key: buf[:], + }, &jose.EncrypterOptions{ + Compression: jose.DEFLATE, + }, + ) + require.NoError(t, err) + payload, err := json.Marshal(claims) + require.NoError(t, err) + signed, err := encrypt.Encrypt(payload) + require.NoError(t, err) + compact, err := signed.CompactSerialize() + require.NoError(t, err) + return compact +} + +// generateBadJWT generates a JWT with a random key. It's intended to emulate the old-style JWT's we generated. +func generateBadJWT(t *testing.T, claims interface{}) string { + t.Helper() + + var buf [64]byte + _, err := rand.Read(buf[:]) + require.NoError(t, err) + signer, err := jose.NewSigner(jose.SigningKey{ + Algorithm: jose.HS512, + Key: buf[:], + }, nil) + require.NoError(t, err) + payload, err := json.Marshal(claims) + require.NoError(t, err) + signed, err := signer.Sign(payload) + require.NoError(t, err) + compact, err := signed.CompactSerialize() + require.NoError(t, err) + return compact +} + +func findCookie(cookies []*http.Cookie, name string) *http.Cookie { + for _, cookie := range cookies { + if cookie.Name == name { + return cookie + } + } + return nil +} diff --git a/coderd/workspaceapps/db.go b/coderd/workspaceapps/db.go index 1b369cf6d6ef4..1aa4dfe91bdd0 100644 --- a/coderd/workspaceapps/db.go +++ b/coderd/workspaceapps/db.go @@ -13,11 +13,15 @@ import ( "golang.org/x/exp/slices" "golang.org/x/xerrors" + "github.com/go-jose/go-jose/v4/jwt" + "cdr.dev/slog" + "github.com/coder/coder/v2/coderd/cryptokeys" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/coderd/jwtutils" "github.com/coder/coder/v2/coderd/rbac" "github.com/coder/coder/v2/coderd/rbac/policy" "github.com/coder/coder/v2/codersdk" @@ -35,12 +39,20 @@ type DBTokenProvider struct { DeploymentValues *codersdk.DeploymentValues OAuth2Configs *httpmw.OAuth2Configs WorkspaceAgentInactiveTimeout time.Duration - SigningKey SecurityKey + Keycache cryptokeys.SigningKeycache } var _ SignedTokenProvider = &DBTokenProvider{} -func NewDBTokenProvider(log slog.Logger, accessURL *url.URL, authz rbac.Authorizer, db database.Store, cfg *codersdk.DeploymentValues, oauth2Cfgs *httpmw.OAuth2Configs, workspaceAgentInactiveTimeout time.Duration, signingKey SecurityKey) SignedTokenProvider { +func NewDBTokenProvider(log slog.Logger, + accessURL *url.URL, + authz rbac.Authorizer, + db database.Store, + cfg *codersdk.DeploymentValues, + oauth2Cfgs *httpmw.OAuth2Configs, + workspaceAgentInactiveTimeout time.Duration, + signer cryptokeys.SigningKeycache, +) SignedTokenProvider { if workspaceAgentInactiveTimeout == 0 { workspaceAgentInactiveTimeout = 1 * time.Minute } @@ -53,12 +65,12 @@ func NewDBTokenProvider(log slog.Logger, accessURL *url.URL, authz rbac.Authoriz DeploymentValues: cfg, OAuth2Configs: oauth2Cfgs, WorkspaceAgentInactiveTimeout: workspaceAgentInactiveTimeout, - SigningKey: signingKey, + Keycache: signer, } } func (p *DBTokenProvider) FromRequest(r *http.Request) (*SignedToken, bool) { - return FromRequest(r, p.SigningKey) + return FromRequest(r, p.Keycache) } func (p *DBTokenProvider) Issue(ctx context.Context, rw http.ResponseWriter, r *http.Request, issueReq IssueTokenRequest) (*SignedToken, string, bool) { @@ -70,7 +82,7 @@ func (p *DBTokenProvider) Issue(ctx context.Context, rw http.ResponseWriter, r * dangerousSystemCtx := dbauthz.AsSystemRestricted(ctx) appReq := issueReq.AppRequest.Normalize() - err := appReq.Validate() + err := appReq.Check() if err != nil { WriteWorkspaceApp500(p.Logger, p.DashboardURL, rw, r, &appReq, err, "invalid app request") return nil, "", false @@ -210,9 +222,11 @@ func (p *DBTokenProvider) Issue(ctx context.Context, rw http.ResponseWriter, r * return nil, "", false } + token.RegisteredClaims = jwtutils.RegisteredClaims{ + Expiry: jwt.NewNumericDate(time.Now().Add(DefaultTokenExpiry)), + } // Sign the token. - token.Expiry = time.Now().Add(DefaultTokenExpiry) - tokenStr, err := p.SigningKey.SignToken(token) + tokenStr, err := jwtutils.Sign(ctx, p.Keycache, token) if err != nil { WriteWorkspaceApp500(p.Logger, p.DashboardURL, rw, r, &appReq, err, "generate token") return nil, "", false diff --git a/coderd/workspaceapps/db_test.go b/coderd/workspaceapps/db_test.go index 6c5a0212aff2b..bf364f1ce62b3 100644 --- a/coderd/workspaceapps/db_test.go +++ b/coderd/workspaceapps/db_test.go @@ -13,6 +13,7 @@ import ( "testing" "time" + "github.com/go-jose/go-jose/v4/jwt" "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -20,6 +21,7 @@ import ( "github.com/coder/coder/v2/agent/agenttest" "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/coderd/jwtutils" "github.com/coder/coder/v2/coderd/workspaceapps" "github.com/coder/coder/v2/coderd/workspaceapps/appurl" "github.com/coder/coder/v2/codersdk" @@ -94,8 +96,7 @@ func Test_ResolveRequest(t *testing.T) { _ = closer.Close() }) - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) - t.Cleanup(cancel) + ctx := testutil.Context(t, testutil.WaitMedium) firstUser := coderdtest.CreateFirstUser(t, client) me, err := client.User(ctx, codersdk.Me) @@ -276,15 +277,17 @@ func Test_ResolveRequest(t *testing.T) { _ = w.Body.Close() require.Equal(t, &workspaceapps.SignedToken{ + RegisteredClaims: jwtutils.RegisteredClaims{ + Expiry: jwt.NewNumericDate(token.Expiry.Time()), + }, Request: req, - Expiry: token.Expiry, // ignored to avoid flakiness UserID: me.ID, WorkspaceID: workspace.ID, AgentID: agentID, AppURL: appURL, }, token) require.NotZero(t, token.Expiry) - require.WithinDuration(t, time.Now().Add(workspaceapps.DefaultTokenExpiry), token.Expiry, time.Minute) + require.WithinDuration(t, time.Now().Add(workspaceapps.DefaultTokenExpiry), token.Expiry.Time(), time.Minute) // Check that the token was set in the response and is valid. require.Len(t, w.Cookies(), 1) @@ -292,10 +295,11 @@ func Test_ResolveRequest(t *testing.T) { require.Equal(t, codersdk.SignedAppTokenCookie, cookie.Name) require.Equal(t, req.BasePath, cookie.Path) - parsedToken, err := api.AppSecurityKey.VerifySignedToken(cookie.Value) + var parsedToken workspaceapps.SignedToken + err := jwtutils.Verify(ctx, api.AppSigningKeyCache, cookie.Value, &parsedToken) require.NoError(t, err) // normalize expiry - require.WithinDuration(t, token.Expiry, parsedToken.Expiry, 2*time.Second) + require.WithinDuration(t, token.Expiry.Time(), parsedToken.Expiry.Time(), 2*time.Second) parsedToken.Expiry = token.Expiry require.Equal(t, token, &parsedToken) @@ -314,7 +318,7 @@ func Test_ResolveRequest(t *testing.T) { }) require.True(t, ok) // normalize expiry - require.WithinDuration(t, token.Expiry, secondToken.Expiry, 2*time.Second) + require.WithinDuration(t, token.Expiry.Time(), secondToken.Expiry.Time(), 2*time.Second) secondToken.Expiry = token.Expiry require.Equal(t, token, secondToken) } @@ -540,13 +544,16 @@ func Test_ResolveRequest(t *testing.T) { // App name differs AppSlugOrPort: appNamePublic, }).Normalize(), - Expiry: time.Now().Add(time.Minute), + RegisteredClaims: jwtutils.RegisteredClaims{ + Expiry: jwt.NewNumericDate(time.Now().Add(time.Minute)), + }, UserID: me.ID, WorkspaceID: workspace.ID, AgentID: agentID, AppURL: appURL, } - badTokenStr, err := api.AppSecurityKey.SignToken(badToken) + + badTokenStr, err := jwtutils.Sign(ctx, api.AppSigningKeyCache, badToken) require.NoError(t, err) req := (workspaceapps.Request{ @@ -589,7 +596,8 @@ func Test_ResolveRequest(t *testing.T) { require.Len(t, cookies, 1) require.Equal(t, cookies[0].Name, codersdk.SignedAppTokenCookie) require.NotEqual(t, cookies[0].Value, badTokenStr) - parsedToken, err := api.AppSecurityKey.VerifySignedToken(cookies[0].Value) + var parsedToken workspaceapps.SignedToken + err = jwtutils.Verify(ctx, api.AppSigningKeyCache, cookies[0].Value, &parsedToken) require.NoError(t, err) require.Equal(t, appNameOwner, parsedToken.AppSlugOrPort) }) diff --git a/coderd/workspaceapps/provider.go b/coderd/workspaceapps/provider.go index 8d4b7fd149800..1887036e35cbf 100644 --- a/coderd/workspaceapps/provider.go +++ b/coderd/workspaceapps/provider.go @@ -38,7 +38,7 @@ type ResolveRequestOptions struct { func ResolveRequest(rw http.ResponseWriter, r *http.Request, opts ResolveRequestOptions) (*SignedToken, bool) { appReq := opts.AppRequest.Normalize() - err := appReq.Validate() + err := appReq.Check() if err != nil { // This is a 500 since it's a coder server or proxy that's making this // request struct based on details from the request. The values should @@ -79,7 +79,7 @@ func ResolveRequest(rw http.ResponseWriter, r *http.Request, opts ResolveRequest Name: codersdk.SignedAppTokenCookie, Value: tokenStr, Path: appReq.BasePath, - Expires: token.Expiry, + Expires: token.Expiry.Time(), }) return token, true diff --git a/coderd/workspaceapps/proxy.go b/coderd/workspaceapps/proxy.go index 69f1aadca49b2..84cea4fa86678 100644 --- a/coderd/workspaceapps/proxy.go +++ b/coderd/workspaceapps/proxy.go @@ -11,17 +11,21 @@ import ( "strconv" "strings" "sync" + "time" "github.com/go-chi/chi/v5" + "github.com/go-jose/go-jose/v4/jwt" "github.com/google/uuid" "go.opentelemetry.io/otel/trace" "nhooyr.io/websocket" "cdr.dev/slog" "github.com/coder/coder/v2/agent/agentssh" + "github.com/coder/coder/v2/coderd/cryptokeys" "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/coderd/jwtutils" "github.com/coder/coder/v2/coderd/tracing" "github.com/coder/coder/v2/coderd/util/slice" "github.com/coder/coder/v2/coderd/workspaceapps/appurl" @@ -97,8 +101,8 @@ type Server struct { HostnameRegex *regexp.Regexp RealIPConfig *httpmw.RealIPConfig - SignedTokenProvider SignedTokenProvider - AppSecurityKey SecurityKey + SignedTokenProvider SignedTokenProvider + APIKeyEncryptionKeycache cryptokeys.EncryptionKeycache // DisablePathApps disables path-based apps. This is a security feature as path // based apps share the same cookie as the dashboard, and are susceptible to XSS @@ -176,7 +180,10 @@ func (s *Server) handleAPIKeySmuggling(rw http.ResponseWriter, r *http.Request, } // Exchange the encoded API key for a real one. - token, err := s.AppSecurityKey.DecryptAPIKey(encryptedAPIKey) + var payload EncryptedAPIKeyPayload + err := jwtutils.Decrypt(ctx, s.APIKeyEncryptionKeycache, encryptedAPIKey, &payload, jwtutils.WithDecryptExpected(jwt.Expected{ + Time: time.Now(), + })) if err != nil { s.Logger.Debug(ctx, "could not decrypt smuggled workspace app API key", slog.Error(err)) site.RenderStaticErrorPage(rw, r, site.ErrorPageData{ @@ -225,7 +232,7 @@ func (s *Server) handleAPIKeySmuggling(rw http.ResponseWriter, r *http.Request, // server using the wrong value. http.SetCookie(rw, &http.Cookie{ Name: AppConnectSessionTokenCookieName(accessMethod), - Value: token, + Value: payload.APIKey, Domain: domain, Path: "/", MaxAge: 0, diff --git a/coderd/workspaceapps/request.go b/coderd/workspaceapps/request.go index 4f6a6f3a64e65..0833ab731fe67 100644 --- a/coderd/workspaceapps/request.go +++ b/coderd/workspaceapps/request.go @@ -124,9 +124,9 @@ func (r Request) Normalize() Request { return req } -// Validate ensures the request is correct and contains the necessary +// Check ensures the request is correct and contains the necessary // parameters. -func (r Request) Validate() error { +func (r Request) Check() error { switch r.AccessMethod { case AccessMethodPath, AccessMethodSubdomain, AccessMethodTerminal: default: diff --git a/coderd/workspaceapps/request_test.go b/coderd/workspaceapps/request_test.go index b6e4bb7a2e65f..fbabc840745e9 100644 --- a/coderd/workspaceapps/request_test.go +++ b/coderd/workspaceapps/request_test.go @@ -279,7 +279,7 @@ func Test_RequestValidate(t *testing.T) { if !c.noNormalize { req = c.req.Normalize() } - err := req.Validate() + err := req.Check() if c.errContains == "" { require.NoError(t, err) } else { diff --git a/coderd/workspaceapps/token.go b/coderd/workspaceapps/token.go index 33428b0e25f13..dcd8c5a0e5c34 100644 --- a/coderd/workspaceapps/token.go +++ b/coderd/workspaceapps/token.go @@ -1,35 +1,27 @@ package workspaceapps import ( - "encoding/base64" - "encoding/hex" - "encoding/json" "net/http" "strings" "time" - "github.com/go-jose/go-jose/v3" + "github.com/go-jose/go-jose/v4/jwt" "github.com/google/uuid" "golang.org/x/xerrors" - "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/coderd/cryptokeys" + "github.com/coder/coder/v2/coderd/jwtutils" "github.com/coder/coder/v2/codersdk" ) -const ( - tokenSigningAlgorithm = jose.HS512 - apiKeyEncryptionAlgorithm = jose.A256GCMKW -) - // SignedToken is the struct data contained inside a workspace app JWE. It // contains the details of the workspace app that the token is valid for to // avoid database queries. type SignedToken struct { + jwtutils.RegisteredClaims // Request details. Request `json:"request"` - // Trusted resolved details. - Expiry time.Time `json:"expiry"` // set by GenerateToken if unset UserID uuid.UUID `json:"user_id"` WorkspaceID uuid.UUID `json:"workspace_id"` AgentID uuid.UUID `json:"agent_id"` @@ -57,191 +49,32 @@ func (t SignedToken) MatchesRequest(req Request) bool { t.AppSlugOrPort == req.AppSlugOrPort } -// SecurityKey is used for signing and encrypting app tokens and API keys. -// -// The first 64 bytes of the key are used for signing tokens with HMAC-SHA256, -// and the last 32 bytes are used for encrypting API keys with AES-256-GCM. -// We use a single key for both operations to avoid having to store and manage -// two keys. -type SecurityKey [96]byte - -func (k SecurityKey) IsZero() bool { - return k == SecurityKey{} -} - -func (k SecurityKey) String() string { - return hex.EncodeToString(k[:]) -} - -func (k SecurityKey) signingKey() []byte { - return k[:64] -} - -func (k SecurityKey) encryptionKey() []byte { - return k[64:] -} - -func KeyFromString(str string) (SecurityKey, error) { - var key SecurityKey - decoded, err := hex.DecodeString(str) - if err != nil { - return key, xerrors.Errorf("decode key: %w", err) - } - if len(decoded) != len(key) { - return key, xerrors.Errorf("expected key to be %d bytes, got %d", len(key), len(decoded)) - } - copy(key[:], decoded) - - return key, nil -} - -// SignToken generates a signed workspace app token with the given payload. If -// the payload doesn't have an expiry, it will be set to the current time plus -// the default expiry. -func (k SecurityKey) SignToken(payload SignedToken) (string, error) { - if payload.Expiry.IsZero() { - payload.Expiry = time.Now().Add(DefaultTokenExpiry) - } - payloadBytes, err := json.Marshal(payload) - if err != nil { - return "", xerrors.Errorf("marshal payload to JSON: %w", err) - } - - signer, err := jose.NewSigner(jose.SigningKey{ - Algorithm: tokenSigningAlgorithm, - Key: k.signingKey(), - }, nil) - if err != nil { - return "", xerrors.Errorf("create signer: %w", err) - } - - signedObject, err := signer.Sign(payloadBytes) - if err != nil { - return "", xerrors.Errorf("sign payload: %w", err) - } - - serialized, err := signedObject.CompactSerialize() - if err != nil { - return "", xerrors.Errorf("serialize JWS: %w", err) - } - - return serialized, nil -} - -// VerifySignedToken parses a signed workspace app token with the given key and -// returns the payload. If the token is invalid or expired, an error is -// returned. -func (k SecurityKey) VerifySignedToken(str string) (SignedToken, error) { - object, err := jose.ParseSigned(str) - if err != nil { - return SignedToken{}, xerrors.Errorf("parse JWS: %w", err) - } - if len(object.Signatures) != 1 { - return SignedToken{}, xerrors.New("expected 1 signature") - } - if object.Signatures[0].Header.Algorithm != string(tokenSigningAlgorithm) { - return SignedToken{}, xerrors.Errorf("expected token signing algorithm to be %q, got %q", tokenSigningAlgorithm, object.Signatures[0].Header.Algorithm) - } - - output, err := object.Verify(k.signingKey()) - if err != nil { - return SignedToken{}, xerrors.Errorf("verify JWS: %w", err) - } - - var tok SignedToken - err = json.Unmarshal(output, &tok) - if err != nil { - return SignedToken{}, xerrors.Errorf("unmarshal payload: %w", err) - } - if tok.Expiry.Before(time.Now()) { - return SignedToken{}, xerrors.New("signed app token expired") - } - - return tok, nil -} - type EncryptedAPIKeyPayload struct { - APIKey string `json:"api_key"` - ExpiresAt time.Time `json:"expires_at"` + jwtutils.RegisteredClaims + APIKey string `json:"api_key"` } -// EncryptAPIKey encrypts an API key for subdomain token smuggling. -func (k SecurityKey) EncryptAPIKey(payload EncryptedAPIKeyPayload) (string, error) { - if payload.APIKey == "" { - return "", xerrors.New("API key is empty") - } - if payload.ExpiresAt.IsZero() { - // Very short expiry as these keys are only used once as part of an - // automatic redirection flow. - payload.ExpiresAt = dbtime.Now().Add(time.Minute) - } - - payloadBytes, err := json.Marshal(payload) - if err != nil { - return "", xerrors.Errorf("marshal payload: %w", err) - } - - // JWEs seem to apply a nonce themselves. - encrypter, err := jose.NewEncrypter( - jose.A256GCM, - jose.Recipient{ - Algorithm: apiKeyEncryptionAlgorithm, - Key: k.encryptionKey(), - }, - &jose.EncrypterOptions{ - Compression: jose.DEFLATE, - }, - ) - if err != nil { - return "", xerrors.Errorf("initializer jose encrypter: %w", err) - } - encryptedObject, err := encrypter.Encrypt(payloadBytes) - if err != nil { - return "", xerrors.Errorf("encrypt jwe: %w", err) - } - - encrypted := encryptedObject.FullSerialize() - return base64.RawURLEncoding.EncodeToString([]byte(encrypted)), nil +func (e *EncryptedAPIKeyPayload) Fill(now time.Time) { + e.Issuer = "coderd" + e.Audience = jwt.Audience{"wsproxy"} + e.Expiry = jwt.NewNumericDate(now.Add(time.Minute)) + e.NotBefore = jwt.NewNumericDate(now.Add(-time.Minute)) } -// DecryptAPIKey undoes EncryptAPIKey and is used in the subdomain app handler. -func (k SecurityKey) DecryptAPIKey(encryptedAPIKey string) (string, error) { - encrypted, err := base64.RawURLEncoding.DecodeString(encryptedAPIKey) - if err != nil { - return "", xerrors.Errorf("base64 decode encrypted API key: %w", err) +func (e EncryptedAPIKeyPayload) Validate(ex jwt.Expected) error { + if e.NotBefore == nil { + return xerrors.Errorf("not before is required") } - object, err := jose.ParseEncrypted(string(encrypted)) - if err != nil { - return "", xerrors.Errorf("parse encrypted API key: %w", err) - } - if object.Header.Algorithm != string(apiKeyEncryptionAlgorithm) { - return "", xerrors.Errorf("expected API key encryption algorithm to be %q, got %q", apiKeyEncryptionAlgorithm, object.Header.Algorithm) - } - - // Decrypt using the hashed secret. - decrypted, err := object.Decrypt(k.encryptionKey()) - if err != nil { - return "", xerrors.Errorf("decrypt API key: %w", err) - } - - // Unmarshal the payload. - var payload EncryptedAPIKeyPayload - if err := json.Unmarshal(decrypted, &payload); err != nil { - return "", xerrors.Errorf("unmarshal decrypted payload: %w", err) - } - - // Validate expiry. - if payload.ExpiresAt.Before(dbtime.Now()) { - return "", xerrors.New("encrypted API key expired") - } + ex.Issuer = "coderd" + ex.AnyAudience = jwt.Audience{"wsproxy"} - return payload.APIKey, nil + return e.RegisteredClaims.Validate(ex) } // FromRequest returns the signed token from the request, if it exists and is // valid. The caller must check that the token matches the request. -func FromRequest(r *http.Request, key SecurityKey) (*SignedToken, bool) { +func FromRequest(r *http.Request, mgr cryptokeys.SigningKeycache) (*SignedToken, bool) { // Get all signed app tokens from the request. This includes the query // parameter and all matching cookies sent with the request. If there are // somehow multiple signed app token cookies, we want to try all of them @@ -270,8 +103,12 @@ func FromRequest(r *http.Request, key SecurityKey) (*SignedToken, bool) { tokens = tokens[:4] } + ctx := r.Context() for _, tokenStr := range tokens { - token, err := key.VerifySignedToken(tokenStr) + var token SignedToken + err := jwtutils.Verify(ctx, mgr, tokenStr, &token, jwtutils.WithVerifyExpected(jwt.Expected{ + Time: time.Now(), + })) if err == nil { req := token.Request.Normalize() if hasQueryParam && req.AccessMethod != AccessMethodTerminal { @@ -280,7 +117,7 @@ func FromRequest(r *http.Request, key SecurityKey) (*SignedToken, bool) { return nil, false } - err := req.Validate() + err := req.Check() if err == nil { // The request has a valid signed app token, which is a valid // token signed by us. The caller must check that it matches diff --git a/coderd/workspaceapps/token_test.go b/coderd/workspaceapps/token_test.go index c656ae2ab77b8..db070268fa196 100644 --- a/coderd/workspaceapps/token_test.go +++ b/coderd/workspaceapps/token_test.go @@ -1,22 +1,22 @@ package workspaceapps_test import ( - "fmt" + "crypto/rand" "net/http" "net/http/httptest" "testing" "time" + "github.com/go-jose/go-jose/v4/jwt" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/testutil" - "github.com/go-jose/go-jose/v3" "github.com/google/uuid" "github.com/stretchr/testify/require" - "github.com/coder/coder/v2/coderd/coderdtest" - "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/coderd/jwtutils" "github.com/coder/coder/v2/coderd/workspaceapps" - "github.com/coder/coder/v2/cryptorand" ) func Test_TokenMatchesRequest(t *testing.T) { @@ -283,129 +283,6 @@ func Test_TokenMatchesRequest(t *testing.T) { } } -func Test_GenerateToken(t *testing.T) { - t.Parallel() - - t.Run("SetExpiry", func(t *testing.T) { - t.Parallel() - - tokenStr, err := coderdtest.AppSecurityKey.SignToken(workspaceapps.SignedToken{ - Request: workspaceapps.Request{ - AccessMethod: workspaceapps.AccessMethodPath, - BasePath: "/app", - UsernameOrID: "foo", - WorkspaceNameOrID: "bar", - AgentNameOrID: "baz", - AppSlugOrPort: "qux", - }, - - Expiry: time.Time{}, - UserID: uuid.MustParse("b1530ba9-76f3-415e-b597-4ddd7cd466a4"), - WorkspaceID: uuid.MustParse("1e6802d3-963e-45ac-9d8c-bf997016ffed"), - AgentID: uuid.MustParse("9ec18681-d2c9-4c9e-9186-f136efb4edbe"), - AppURL: "http://127.0.0.1:8080", - }) - require.NoError(t, err) - - token, err := coderdtest.AppSecurityKey.VerifySignedToken(tokenStr) - require.NoError(t, err) - - require.WithinDuration(t, time.Now().Add(time.Minute), token.Expiry, 15*time.Second) - }) - - future := time.Now().Add(time.Hour) - cases := []struct { - name string - token workspaceapps.SignedToken - parseErrContains string - }{ - { - name: "OK1", - token: workspaceapps.SignedToken{ - Request: workspaceapps.Request{ - AccessMethod: workspaceapps.AccessMethodPath, - BasePath: "/app", - UsernameOrID: "foo", - WorkspaceNameOrID: "bar", - AgentNameOrID: "baz", - AppSlugOrPort: "qux", - }, - - Expiry: future, - UserID: uuid.MustParse("b1530ba9-76f3-415e-b597-4ddd7cd466a4"), - WorkspaceID: uuid.MustParse("1e6802d3-963e-45ac-9d8c-bf997016ffed"), - AgentID: uuid.MustParse("9ec18681-d2c9-4c9e-9186-f136efb4edbe"), - AppURL: "http://127.0.0.1:8080", - }, - }, - { - name: "OK2", - token: workspaceapps.SignedToken{ - Request: workspaceapps.Request{ - AccessMethod: workspaceapps.AccessMethodSubdomain, - BasePath: "/", - UsernameOrID: "oof", - WorkspaceNameOrID: "rab", - AgentNameOrID: "zab", - AppSlugOrPort: "xuq", - }, - - Expiry: future, - UserID: uuid.MustParse("6fa684a3-11aa-49fd-8512-ab527bd9b900"), - WorkspaceID: uuid.MustParse("b2d816cc-505c-441d-afdf-dae01781bc0b"), - AgentID: uuid.MustParse("6c4396e1-af88-4a8a-91a3-13ea54fc29fb"), - AppURL: "http://localhost:9090", - }, - }, - { - name: "Expired", - token: workspaceapps.SignedToken{ - Request: workspaceapps.Request{ - AccessMethod: workspaceapps.AccessMethodSubdomain, - BasePath: "/", - UsernameOrID: "foo", - WorkspaceNameOrID: "bar", - AgentNameOrID: "baz", - AppSlugOrPort: "qux", - }, - - Expiry: time.Now().Add(-time.Hour), - UserID: uuid.MustParse("b1530ba9-76f3-415e-b597-4ddd7cd466a4"), - WorkspaceID: uuid.MustParse("1e6802d3-963e-45ac-9d8c-bf997016ffed"), - AgentID: uuid.MustParse("9ec18681-d2c9-4c9e-9186-f136efb4edbe"), - AppURL: "http://127.0.0.1:8080", - }, - parseErrContains: "token expired", - }, - } - - for _, c := range cases { - c := c - - t.Run(c.name, func(t *testing.T) { - t.Parallel() - - str, err := coderdtest.AppSecurityKey.SignToken(c.token) - require.NoError(t, err) - - // Tokens aren't deterministic as they have a random nonce, so we - // can't compare them directly. - - token, err := coderdtest.AppSecurityKey.VerifySignedToken(str) - if c.parseErrContains != "" { - require.Error(t, err) - require.ErrorContains(t, err, c.parseErrContains) - } else { - require.NoError(t, err) - // normalize the expiry - require.WithinDuration(t, c.token.Expiry, token.Expiry, 10*time.Second) - c.token.Expiry = token.Expiry - require.Equal(t, c.token, token) - } - }) - } -} - func Test_FromRequest(t *testing.T) { t.Parallel() @@ -419,7 +296,13 @@ func Test_FromRequest(t *testing.T) { Value: "invalid", }) + ctx := testutil.Context(t, testutil.WaitShort) + signer := newSigner(t) + token := workspaceapps.SignedToken{ + RegisteredClaims: jwtutils.RegisteredClaims{ + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + }, Request: workspaceapps.Request{ AccessMethod: workspaceapps.AccessMethodSubdomain, BasePath: "/", @@ -429,7 +312,6 @@ func Test_FromRequest(t *testing.T) { AgentNameOrID: "agent", AppSlugOrPort: "app", }, - Expiry: time.Now().Add(time.Hour), UserID: uuid.New(), WorkspaceID: uuid.New(), AgentID: uuid.New(), @@ -438,16 +320,15 @@ func Test_FromRequest(t *testing.T) { // Add an expired cookie expired := token - expired.Expiry = time.Now().Add(time.Hour * -1) - expiredStr, err := coderdtest.AppSecurityKey.SignToken(token) + expired.RegisteredClaims.Expiry = jwt.NewNumericDate(time.Now().Add(time.Hour * -1)) + expiredStr, err := jwtutils.Sign(ctx, signer, expired) require.NoError(t, err) r.AddCookie(&http.Cookie{ Name: codersdk.SignedAppTokenCookie, Value: expiredStr, }) - // Add a valid token - validStr, err := coderdtest.AppSecurityKey.SignToken(token) + validStr, err := jwtutils.Sign(ctx, signer, token) require.NoError(t, err) r.AddCookie(&http.Cookie{ @@ -455,147 +336,27 @@ func Test_FromRequest(t *testing.T) { Value: validStr, }) - signed, ok := workspaceapps.FromRequest(r, coderdtest.AppSecurityKey) + signed, ok := workspaceapps.FromRequest(r, signer) require.True(t, ok, "expected a token to be found") // Confirm it is the correct token. require.Equal(t, signed.UserID, token.UserID) }) } -// The ParseToken fn is tested quite thoroughly in the GenerateToken test as -// well. -func Test_ParseToken(t *testing.T) { - t.Parallel() - - t.Run("InvalidJWS", func(t *testing.T) { - t.Parallel() - - token, err := coderdtest.AppSecurityKey.VerifySignedToken("invalid") - require.Error(t, err) - require.ErrorContains(t, err, "parse JWS") - require.Equal(t, workspaceapps.SignedToken{}, token) - }) - - t.Run("VerifySignature", func(t *testing.T) { - t.Parallel() +func newSigner(t *testing.T) jwtutils.StaticKey { + t.Helper() - // Create a valid token using a different key. - var otherKey workspaceapps.SecurityKey - copy(otherKey[:], coderdtest.AppSecurityKey[:]) - for i := range otherKey { - otherKey[i] ^= 0xff - } - require.NotEqual(t, coderdtest.AppSecurityKey, otherKey) - - tokenStr, err := otherKey.SignToken(workspaceapps.SignedToken{ - Request: workspaceapps.Request{ - AccessMethod: workspaceapps.AccessMethodPath, - BasePath: "/app", - UsernameOrID: "foo", - WorkspaceNameOrID: "bar", - AgentNameOrID: "baz", - AppSlugOrPort: "qux", - }, - - Expiry: time.Now().Add(time.Hour), - UserID: uuid.MustParse("b1530ba9-76f3-415e-b597-4ddd7cd466a4"), - WorkspaceID: uuid.MustParse("1e6802d3-963e-45ac-9d8c-bf997016ffed"), - AgentID: uuid.MustParse("9ec18681-d2c9-4c9e-9186-f136efb4edbe"), - AppURL: "http://127.0.0.1:8080", - }) - require.NoError(t, err) - - // Verify the token is invalid. - token, err := coderdtest.AppSecurityKey.VerifySignedToken(tokenStr) - require.Error(t, err) - require.ErrorContains(t, err, "verify JWS") - require.Equal(t, workspaceapps.SignedToken{}, token) - }) - - t.Run("InvalidBody", func(t *testing.T) { - t.Parallel() - - // Create a signature for an invalid body. - signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS512, Key: coderdtest.AppSecurityKey[:64]}, nil) - require.NoError(t, err) - signedObject, err := signer.Sign([]byte("hi")) - require.NoError(t, err) - serialized, err := signedObject.CompactSerialize() - require.NoError(t, err) - - token, err := coderdtest.AppSecurityKey.VerifySignedToken(serialized) - require.Error(t, err) - require.ErrorContains(t, err, "unmarshal payload") - require.Equal(t, workspaceapps.SignedToken{}, token) - }) -} - -func TestAPIKeyEncryption(t *testing.T) { - t.Parallel() - - genAPIKey := func(t *testing.T) string { - id, _ := cryptorand.String(10) - secret, _ := cryptorand.String(22) - - return fmt.Sprintf("%s-%s", id, secret) + return jwtutils.StaticKey{ + ID: "test", + Key: generateSecret(t, 64), } +} - t.Run("OK", func(t *testing.T) { - t.Parallel() - - key := genAPIKey(t) - encrypted, err := coderdtest.AppSecurityKey.EncryptAPIKey(workspaceapps.EncryptedAPIKeyPayload{ - APIKey: key, - }) - require.NoError(t, err) - - decryptedKey, err := coderdtest.AppSecurityKey.DecryptAPIKey(encrypted) - require.NoError(t, err) - require.Equal(t, key, decryptedKey) - }) - - t.Run("Verifies", func(t *testing.T) { - t.Parallel() - - t.Run("Expiry", func(t *testing.T) { - t.Parallel() - - key := genAPIKey(t) - encrypted, err := coderdtest.AppSecurityKey.EncryptAPIKey(workspaceapps.EncryptedAPIKeyPayload{ - APIKey: key, - ExpiresAt: dbtime.Now().Add(-1 * time.Hour), - }) - require.NoError(t, err) - - decryptedKey, err := coderdtest.AppSecurityKey.DecryptAPIKey(encrypted) - require.Error(t, err) - require.ErrorContains(t, err, "expired") - require.Empty(t, decryptedKey) - }) - - t.Run("EncryptionKey", func(t *testing.T) { - t.Parallel() - - // Create a valid token using a different key. - var otherKey workspaceapps.SecurityKey - copy(otherKey[:], coderdtest.AppSecurityKey[:]) - for i := range otherKey { - otherKey[i] ^= 0xff - } - require.NotEqual(t, coderdtest.AppSecurityKey, otherKey) - - // Encrypt with the other key. - key := genAPIKey(t) - encrypted, err := otherKey.EncryptAPIKey(workspaceapps.EncryptedAPIKeyPayload{ - APIKey: key, - }) - require.NoError(t, err) +func generateSecret(t *testing.T, size int) []byte { + t.Helper() - // Decrypt with the original key. - decryptedKey, err := coderdtest.AppSecurityKey.DecryptAPIKey(encrypted) - require.Error(t, err) - require.ErrorContains(t, err, "decrypt API key") - require.Empty(t, decryptedKey) - }) - }) + secret := make([]byte, size) + _, err := rand.Read(secret) + require.NoError(t, err) + return secret } diff --git a/coderd/workspaceapps_test.go b/coderd/workspaceapps_test.go index 1d00b7daa7bd9..52b3e18b4e6ad 100644 --- a/coderd/workspaceapps_test.go +++ b/coderd/workspaceapps_test.go @@ -5,16 +5,23 @@ import ( "net/http" "net/url" "testing" + "time" + "github.com/go-jose/go-jose/v4/jwt" "github.com/stretchr/testify/require" + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/coderdtest" + "github.com/coder/coder/v2/coderd/cryptokeys" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbgen" "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/jwtutils" "github.com/coder/coder/v2/coderd/workspaceapps" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/testutil" + "github.com/coder/quartz" ) func TestGetAppHost(t *testing.T) { @@ -181,16 +188,28 @@ func TestWorkspaceApplicationAuth(t *testing.T) { t.Run(c.name, func(t *testing.T) { t.Parallel() - db, pubsub := dbtestutil.NewDB(t) - + ctx := testutil.Context(t, testutil.WaitMedium) + logger := slogtest.Make(t, nil) accessURL, err := url.Parse(c.accessURL) require.NoError(t, err) + db, ps := dbtestutil.NewDB(t) + fetcher := &cryptokeys.DBFetcher{ + DB: db, + } + + kc, err := cryptokeys.NewEncryptionCache(ctx, logger, fetcher, codersdk.CryptoKeyFeatureWorkspaceAppsAPIKey) + require.NoError(t, err) + + clock := quartz.NewMock(t) + client := coderdtest.New(t, &coderdtest.Options{ - Database: db, - Pubsub: pubsub, - AccessURL: accessURL, - AppHostname: c.appHostname, + AccessURL: accessURL, + AppHostname: c.appHostname, + Database: db, + Pubsub: ps, + APIKeyEncryptionCache: kc, + Clock: clock, }) _ = coderdtest.CreateFirstUser(t, client) @@ -240,7 +259,15 @@ func TestWorkspaceApplicationAuth(t *testing.T) { loc.RawQuery = q.Encode() require.Equal(t, c.expectRedirect, loc.String()) - // The decrypted key is verified in the apptest test suite. + var token workspaceapps.EncryptedAPIKeyPayload + err = jwtutils.Decrypt(ctx, kc, encryptedAPIKey, &token, jwtutils.WithDecryptExpected(jwt.Expected{ + Time: clock.Now(), + AnyAudience: jwt.Audience{"wsproxy"}, + Issuer: "coderd", + })) + require.NoError(t, err) + require.Equal(t, jwt.NewNumericDate(clock.Now().Add(time.Minute)), token.Expiry) + require.Equal(t, jwt.NewNumericDate(clock.Now().Add(-time.Minute)), token.NotBefore) }) } } diff --git a/codersdk/deployment.go b/codersdk/deployment.go index d6840df504b85..391d0039f0369 100644 --- a/codersdk/deployment.go +++ b/codersdk/deployment.go @@ -3109,9 +3109,11 @@ func (c *Client) SSHConfiguration(ctx context.Context) (SSHConfigResponse, error type CryptoKeyFeature string const ( - CryptoKeyFeatureWorkspaceApp CryptoKeyFeature = "workspace_apps" - CryptoKeyFeatureOIDCConvert CryptoKeyFeature = "oidc_convert" - CryptoKeyFeatureTailnetResume CryptoKeyFeature = "tailnet_resume" + CryptoKeyFeatureWorkspaceAppsAPIKey CryptoKeyFeature = "workspace_apps_api_key" + //nolint:gosec // This denotes a type of key, not a literal. + CryptoKeyFeatureWorkspaceAppsToken CryptoKeyFeature = "workspace_apps_token" + CryptoKeyFeatureOIDCConvert CryptoKeyFeature = "oidc_convert" + CryptoKeyFeatureTailnetResume CryptoKeyFeature = "tailnet_resume" ) type CryptoKey struct { diff --git a/codersdk/workspacesdk/connector_internal_test.go b/codersdk/workspacesdk/connector_internal_test.go index 7a339a0079ba2..19f1930c89bc5 100644 --- a/codersdk/workspacesdk/connector_internal_test.go +++ b/codersdk/workspacesdk/connector_internal_test.go @@ -25,6 +25,7 @@ import ( "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/apiversion" "github.com/coder/coder/v2/coderd/httpapi" + "github.com/coder/coder/v2/coderd/jwtutils" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/tailnet" "github.com/coder/coder/v2/tailnet/proto" @@ -61,7 +62,7 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) { CoordPtr: &coordPtr, DERPMapUpdateFrequency: time.Millisecond, DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh }, - NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) {}, + NetworkTelemetryHandler: func([]*proto.TelemetryEvent) {}, ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(), }) require.NoError(t, err) @@ -165,13 +166,17 @@ func TestTailnetAPIConnector_ResumeToken(t *testing.T) { clock := quartz.NewMock(t) resumeTokenSigningKey, err := tailnet.GenerateResumeTokenSigningKey() require.NoError(t, err) - resumeTokenProvider := tailnet.NewResumeTokenKeyProvider(resumeTokenSigningKey, clock, time.Hour) + mgr := jwtutils.StaticKey{ + ID: "123", + Key: resumeTokenSigningKey[:], + } + resumeTokenProvider := tailnet.NewResumeTokenKeyProvider(mgr, clock, time.Hour) svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{ Logger: logger, CoordPtr: &coordPtr, DERPMapUpdateFrequency: time.Millisecond, DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh }, - NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) {}, + NetworkTelemetryHandler: func([]*proto.TelemetryEvent) {}, ResumeTokenProvider: resumeTokenProvider, }) require.NoError(t, err) @@ -190,7 +195,7 @@ func TestTailnetAPIConnector_ResumeToken(t *testing.T) { t.Logf("received resume token: %s", resumeToken) assert.Equal(t, expectResumeToken, resumeToken) if resumeToken != "" { - peerID, err = resumeTokenProvider.VerifyResumeToken(resumeToken) + peerID, err = resumeTokenProvider.VerifyResumeToken(ctx, resumeToken) assert.NoError(t, err, "failed to parse resume token") if err != nil { httpapi.Write(ctx, w, http.StatusUnauthorized, codersdk.Response{ @@ -280,13 +285,17 @@ func TestTailnetAPIConnector_ResumeTokenFailure(t *testing.T) { clock := quartz.NewMock(t) resumeTokenSigningKey, err := tailnet.GenerateResumeTokenSigningKey() require.NoError(t, err) - resumeTokenProvider := tailnet.NewResumeTokenKeyProvider(resumeTokenSigningKey, clock, time.Hour) + mgr := jwtutils.StaticKey{ + ID: uuid.New().String(), + Key: resumeTokenSigningKey[:], + } + resumeTokenProvider := tailnet.NewResumeTokenKeyProvider(mgr, clock, time.Hour) svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{ Logger: logger, CoordPtr: &coordPtr, DERPMapUpdateFrequency: time.Millisecond, DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh }, - NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) {}, + NetworkTelemetryHandler: func(_ []*proto.TelemetryEvent) {}, ResumeTokenProvider: resumeTokenProvider, }) require.NoError(t, err) diff --git a/docs/reference/api/schemas.md b/docs/reference/api/schemas.md index ed3800b3a27cd..f4e683305029b 100644 --- a/docs/reference/api/schemas.md +++ b/docs/reference/api/schemas.md @@ -1454,7 +1454,7 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o ```json { "deletes_at": "2019-08-24T14:15:22Z", - "feature": "workspace_apps", + "feature": "workspace_apps_api_key", "secret": "string", "sequence": 0, "starts_at": "2019-08-24T14:15:22Z" @@ -1474,18 +1474,19 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o ## codersdk.CryptoKeyFeature ```json -"workspace_apps" +"workspace_apps_api_key" ``` ### Properties #### Enumerated Values -| Value | -| ---------------- | -| `workspace_apps` | -| `oidc_convert` | -| `tailnet_resume` | +| Value | +| ------------------------ | +| `workspace_apps_api_key` | +| `workspace_apps_token` | +| `oidc_convert` | +| `tailnet_resume` | ## codersdk.CustomRoleRequest @@ -9893,7 +9894,7 @@ _None_ "crypto_keys": [ { "deletes_at": "2019-08-24T14:15:22Z", - "feature": "workspace_apps", + "feature": "workspace_apps_api_key", "secret": "string", "sequence": 0, "starts_at": "2019-08-24T14:15:22Z" @@ -9971,7 +9972,6 @@ _None_ ```json { - "app_security_key": "string", "derp_force_websockets": true, "derp_map": { "homeParams": { @@ -10052,7 +10052,6 @@ _None_ | Name | Type | Required | Restrictions | Description | | ----------------------- | --------------------------------------------- | -------- | ------------ | -------------------------------------------------------------------------------------- | -| `app_security_key` | string | false | | | | `derp_force_websockets` | boolean | false | | | | `derp_map` | [tailcfg.DERPMap](#tailcfgderpmap) | false | | | | `derp_mesh_key` | string | false | | | diff --git a/enterprise/coderd/coderdenttest/proxytest.go b/enterprise/coderd/coderdenttest/proxytest.go index 6e5a822bdf251..a6f2c7384b16f 100644 --- a/enterprise/coderd/coderdenttest/proxytest.go +++ b/enterprise/coderd/coderdenttest/proxytest.go @@ -65,6 +65,8 @@ type WorkspaceProxy struct { // owner client. If a token is provided, the proxy will become a replica of the // existing proxy region. func NewWorkspaceProxyReplica(t *testing.T, coderdAPI *coderd.API, owner *codersdk.Client, options *ProxyOptions) WorkspaceProxy { + t.Helper() + ctx, cancelFunc := context.WithCancel(context.Background()) t.Cleanup(cancelFunc) @@ -142,8 +144,10 @@ func NewWorkspaceProxyReplica(t *testing.T, coderdAPI *coderd.API, owner *coders statsCollectorOptions.Flush = options.FlushStats } + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug).With(slog.F("server_url", serverURL.String())) + wssrv, err := wsproxy.New(ctx, &wsproxy.Options{ - Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug).With(slog.F("server_url", serverURL.String())), + Logger: logger, Experiments: options.Experiments, DashboardURL: coderdAPI.AccessURL, AccessURL: accessURL, diff --git a/enterprise/coderd/workspaceproxy.go b/enterprise/coderd/workspaceproxy.go index 47bdf53493489..4008de69e4faa 100644 --- a/enterprise/coderd/workspaceproxy.go +++ b/enterprise/coderd/workspaceproxy.go @@ -7,6 +7,7 @@ import ( "fmt" "net/http" "net/url" + "slices" "strings" "time" @@ -33,6 +34,13 @@ import ( "github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk" ) +// whitelistedCryptoKeyFeatures is a list of crypto key features that are +// allowed to be queried with workspace proxies. +var whitelistedCryptoKeyFeatures = []database.CryptoKeyFeature{ + database.CryptoKeyFeatureWorkspaceAppsToken, + database.CryptoKeyFeatureWorkspaceAppsAPIKey, +} + // forceWorkspaceProxyHealthUpdate forces an update of the proxy health. // This is useful when a proxy is created or deleted. Errors will be logged. func (api *API) forceWorkspaceProxyHealthUpdate(ctx context.Context) { @@ -700,7 +708,6 @@ func (api *API) workspaceProxyRegister(rw http.ResponseWriter, r *http.Request) } httpapi.Write(ctx, rw, http.StatusCreated, wsproxysdk.RegisterWorkspaceProxyResponse{ - AppSecurityKey: api.AppSecurityKey.String(), DERPMeshKey: api.DERPServer.MeshKey(), DERPRegionID: regionID, DERPMap: api.AGPL.DERPMap(), @@ -721,13 +728,29 @@ func (api *API) workspaceProxyRegister(rw http.ResponseWriter, r *http.Request) // @Security CoderSessionToken // @Produce json // @Tags Enterprise +// @Param feature query string true "Feature key" // @Success 200 {object} wsproxysdk.CryptoKeysResponse // @Router /workspaceproxies/me/crypto-keys [get] // @x-apidocgen {"skip": true} func (api *API) workspaceProxyCryptoKeys(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() - keys, err := api.Database.GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps) + feature := database.CryptoKeyFeature(r.URL.Query().Get("feature")) + if feature == "" { + httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{ + Message: "Missing feature query parameter.", + }) + return + } + + if !slices.Contains(whitelistedCryptoKeyFeatures, feature) { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: fmt.Sprintf("Invalid feature: %q", feature), + }) + return + } + + keys, err := api.Database.GetCryptoKeysByFeature(ctx, feature) if err != nil { httpapi.InternalServerError(rw, err) return diff --git a/enterprise/coderd/workspaceproxy_test.go b/enterprise/coderd/workspaceproxy_test.go index 5231a0b0c4241..0be112b532b7a 100644 --- a/enterprise/coderd/workspaceproxy_test.go +++ b/enterprise/coderd/workspaceproxy_test.go @@ -320,7 +320,6 @@ func TestProxyRegisterDeregister(t *testing.T) { } registerRes1, err := proxyClient.RegisterWorkspaceProxy(ctx, req) require.NoError(t, err) - require.NotEmpty(t, registerRes1.AppSecurityKey) require.NotEmpty(t, registerRes1.DERPMeshKey) require.EqualValues(t, 10001, registerRes1.DERPRegionID) require.Empty(t, registerRes1.SiblingReplicas) @@ -609,11 +608,8 @@ func TestProxyRegisterDeregister(t *testing.T) { func TestIssueSignedAppToken(t *testing.T) { t.Parallel() - db, pubsub := dbtestutil.NewDB(t) client, user := coderdenttest.New(t, &coderdenttest.Options{ Options: &coderdtest.Options{ - Database: db, - Pubsub: pubsub, IncludeProvisionerDaemon: true, }, LicenseOptions: &coderdenttest.LicenseOptions{ @@ -716,6 +712,10 @@ func TestReconnectingPTYSignedToken(t *testing.T) { closer.Close() }) + _ = dbgen.CryptoKey(t, db, database.CryptoKey{ + Feature: database.CryptoKeyFeatureWorkspaceAppsToken, + }) + // Create a workspace + apps authToken := uuid.NewString() version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{ @@ -915,51 +915,86 @@ func TestGetCryptoKeys(t *testing.T) { now := time.Now() expectedKey1 := dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, StartsAt: now.Add(-time.Hour), Sequence: 2, }) - key1 := db2sdk.CryptoKey(expectedKey1) + encryptionKey := db2sdk.CryptoKey(expectedKey1) expectedKey2 := dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsToken, StartsAt: now, Sequence: 3, }) - key2 := db2sdk.CryptoKey(expectedKey2) + signingKey := db2sdk.CryptoKey(expectedKey2) // Create a deleted key. _ = dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, StartsAt: now.Add(-time.Hour), Secret: sql.NullString{ String: "secret1", Valid: false, }, - Sequence: 1, - }) - - // Create a key with different features. - _ = dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureTailnetResume, - StartsAt: now.Add(-time.Hour), - Sequence: 1, - }) - _ = dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureOidcConvert, - StartsAt: now.Add(-time.Hour), - Sequence: 1, + Sequence: 4, }) proxy := coderdenttest.NewWorkspaceProxyReplica(t, api, cclient, &coderdenttest.ProxyOptions{ Name: testutil.GetRandomName(t), }) - keys, err := proxy.SDKClient.CryptoKeys(ctx) + keys, err := proxy.SDKClient.CryptoKeys(ctx, codersdk.CryptoKeyFeatureWorkspaceAppsAPIKey) require.NoError(t, err) require.NotEmpty(t, keys) + // 1 key is generated on startup, the other we manually generated. require.Equal(t, 2, len(keys.CryptoKeys)) - requireContainsKeys(t, keys.CryptoKeys, key1, key2) + requireContainsKeys(t, keys.CryptoKeys, encryptionKey) + requireNotContainsKeys(t, keys.CryptoKeys, signingKey) + + keys, err = proxy.SDKClient.CryptoKeys(ctx, codersdk.CryptoKeyFeatureWorkspaceAppsToken) + require.NoError(t, err) + require.NotEmpty(t, keys) + // 1 key is generated on startup, the other we manually generated. + require.Equal(t, 2, len(keys.CryptoKeys)) + requireContainsKeys(t, keys.CryptoKeys, signingKey) + requireNotContainsKeys(t, keys.CryptoKeys, encryptionKey) + }) + + t.Run("InvalidFeature", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitMedium) + db, pubsub := dbtestutil.NewDB(t) + cclient, _, api, _ := coderdenttest.NewWithAPI(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + Database: db, + Pubsub: pubsub, + IncludeProvisionerDaemon: true, + }, + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureWorkspaceProxy: 1, + }, + }, + }) + + proxy := coderdenttest.NewWorkspaceProxyReplica(t, api, cclient, &coderdenttest.ProxyOptions{ + Name: testutil.GetRandomName(t), + }) + + _, err := proxy.SDKClient.CryptoKeys(ctx, codersdk.CryptoKeyFeatureOIDCConvert) + require.Error(t, err) + var sdkErr *codersdk.Error + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode()) + _, err = proxy.SDKClient.CryptoKeys(ctx, codersdk.CryptoKeyFeatureTailnetResume) + require.Error(t, err) + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode()) + _, err = proxy.SDKClient.CryptoKeys(ctx, "invalid") + require.Error(t, err) + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode()) }) t.Run("Unauthorized", func(t *testing.T) { @@ -987,7 +1022,7 @@ func TestGetCryptoKeys(t *testing.T) { client := wsproxysdk.New(cclient.URL) client.SetSessionToken(cclient.SessionToken()) - _, err := client.CryptoKeys(ctx) + _, err := client.CryptoKeys(ctx, codersdk.CryptoKeyFeatureWorkspaceAppsAPIKey) require.Error(t, err) var sdkErr *codersdk.Error require.ErrorAs(t, err, &sdkErr) @@ -995,6 +1030,18 @@ func TestGetCryptoKeys(t *testing.T) { }) } +func requireNotContainsKeys(t *testing.T, keys []codersdk.CryptoKey, unexpected ...codersdk.CryptoKey) { + t.Helper() + + for _, unexpectedKey := range unexpected { + for _, key := range keys { + if key.Feature == unexpectedKey.Feature && key.Sequence == unexpectedKey.Sequence { + t.Fatalf("unexpected key %+v found", unexpectedKey) + } + } + } +} + func requireContainsKeys(t *testing.T, keys []codersdk.CryptoKey, expected ...codersdk.CryptoKey) { t.Helper() diff --git a/enterprise/dbcrypt/dbcrypt_internal_test.go b/enterprise/dbcrypt/dbcrypt_internal_test.go index 432dc90061677..a96c32aaa8aae 100644 --- a/enterprise/dbcrypt/dbcrypt_internal_test.go +++ b/enterprise/dbcrypt/dbcrypt_internal_test.go @@ -397,12 +397,12 @@ func TestCryptoKeys(t *testing.T) { _ = dbgen.CryptoKey(t, crypt, database.CryptoKey{ Secret: sql.NullString{String: "test", Valid: true}, }) - key, err := crypt.GetLatestCryptoKeyByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps) + key, err := crypt.GetLatestCryptoKeyByFeature(ctx, database.CryptoKeyFeatureWorkspaceAppsAPIKey) require.NoError(t, err) require.Equal(t, "test", key.Secret.String) require.Equal(t, ciphers[0].HexDigest(), key.SecretKeyID.String) - key, err = db.GetLatestCryptoKeyByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps) + key, err = db.GetLatestCryptoKeyByFeature(ctx, database.CryptoKeyFeatureWorkspaceAppsAPIKey) require.NoError(t, err) requireEncryptedEquals(t, ciphers[0], key.Secret.String, "test") require.Equal(t, ciphers[0].HexDigest(), key.SecretKeyID.String) @@ -415,7 +415,7 @@ func TestCryptoKeys(t *testing.T) { Secret: sql.NullString{String: "test", Valid: true}, }) key, err := crypt.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, Sequence: key.Sequence, }) require.NoError(t, err) @@ -423,7 +423,7 @@ func TestCryptoKeys(t *testing.T) { require.Equal(t, ciphers[0].HexDigest(), key.SecretKeyID.String) key, err = db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, Sequence: key.Sequence, }) require.NoError(t, err) @@ -459,7 +459,7 @@ func TestCryptoKeys(t *testing.T) { Secret: sql.NullString{String: "test", Valid: true}, }) _ = dbgen.CryptoKey(t, crypt, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, Sequence: 43, }) keys, err := crypt.GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureTailnetResume) diff --git a/enterprise/workspaceapps_test.go b/enterprise/workspaceapps_test.go index f4ba577f13e33..51d0314c45767 100644 --- a/enterprise/workspaceapps_test.go +++ b/enterprise/workspaceapps_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/coder/coder/v2/coderd/coderdtest" + "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/coderd/workspaceapps/apptest" "github.com/coder/coder/v2/codersdk" @@ -36,6 +37,9 @@ func TestWorkspaceApps(t *testing.T) { flushStatsCollectorCh <- flushStatsCollectorDone <-flushStatsCollectorDone } + + db, pubsub := dbtestutil.NewDB(t) + client, _, _, user := coderdenttest.NewWithAPI(t, &coderdenttest.Options{ Options: &coderdtest.Options{ DeploymentValues: deploymentValues, @@ -51,6 +55,8 @@ func TestWorkspaceApps(t *testing.T) { }, }, WorkspaceAppsStatsCollectorOptions: opts.StatsCollectorOptions, + Database: db, + Pubsub: pubsub, }, LicenseOptions: &coderdenttest.LicenseOptions{ Features: license.Features{ diff --git a/enterprise/wsproxy/keyfetcher.go b/enterprise/wsproxy/keyfetcher.go index f30fffb2cd093..1a1745d6ccd2d 100644 --- a/enterprise/wsproxy/keyfetcher.go +++ b/enterprise/wsproxy/keyfetcher.go @@ -13,12 +13,11 @@ import ( var _ cryptokeys.Fetcher = &ProxyFetcher{} type ProxyFetcher struct { - Client *wsproxysdk.Client - Feature codersdk.CryptoKeyFeature + Client *wsproxysdk.Client } -func (p *ProxyFetcher) Fetch(ctx context.Context) ([]codersdk.CryptoKey, error) { - keys, err := p.Client.CryptoKeys(ctx) +func (p *ProxyFetcher) Fetch(ctx context.Context, feature codersdk.CryptoKeyFeature) ([]codersdk.CryptoKey, error) { + keys, err := p.Client.CryptoKeys(ctx, feature) if err != nil { return nil, xerrors.Errorf("crypto keys: %w", err) } diff --git a/enterprise/wsproxy/tokenprovider.go b/enterprise/wsproxy/tokenprovider.go index 38822a4e7a22d..5093c6015725e 100644 --- a/enterprise/wsproxy/tokenprovider.go +++ b/enterprise/wsproxy/tokenprovider.go @@ -7,6 +7,8 @@ import ( "cdr.dev/slog" + "github.com/coder/coder/v2/coderd/cryptokeys" + "github.com/coder/coder/v2/coderd/jwtutils" "github.com/coder/coder/v2/coderd/workspaceapps" "github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk" ) @@ -18,18 +20,19 @@ type TokenProvider struct { AccessURL *url.URL AppHostname string - Client *wsproxysdk.Client - SecurityKey workspaceapps.SecurityKey - Logger slog.Logger + Client *wsproxysdk.Client + TokenSigningKeycache cryptokeys.SigningKeycache + APIKeyEncryptionKeycache cryptokeys.EncryptionKeycache + Logger slog.Logger } func (p *TokenProvider) FromRequest(r *http.Request) (*workspaceapps.SignedToken, bool) { - return workspaceapps.FromRequest(r, p.SecurityKey) + return workspaceapps.FromRequest(r, p.TokenSigningKeycache) } func (p *TokenProvider) Issue(ctx context.Context, rw http.ResponseWriter, r *http.Request, issueReq workspaceapps.IssueTokenRequest) (*workspaceapps.SignedToken, string, bool) { appReq := issueReq.AppRequest.Normalize() - err := appReq.Validate() + err := appReq.Check() if err != nil { workspaceapps.WriteWorkspaceApp500(p.Logger, p.DashboardURL, rw, r, &appReq, err, "invalid app request") return nil, "", false @@ -42,7 +45,8 @@ func (p *TokenProvider) Issue(ctx context.Context, rw http.ResponseWriter, r *ht } // Check that it verifies properly and matches the string. - token, err := p.SecurityKey.VerifySignedToken(resp.SignedTokenStr) + var token workspaceapps.SignedToken + err = jwtutils.Verify(ctx, p.TokenSigningKeycache, resp.SignedTokenStr, &token) if err != nil { workspaceapps.WriteWorkspaceApp500(p.Logger, p.DashboardURL, rw, r, &appReq, err, "failed to verify newly generated signed token") return nil, "", false diff --git a/enterprise/wsproxy/wsproxy.go b/enterprise/wsproxy/wsproxy.go index 2a7e9e81e0cda..fe900fa433530 100644 --- a/enterprise/wsproxy/wsproxy.go +++ b/enterprise/wsproxy/wsproxy.go @@ -31,6 +31,7 @@ import ( "github.com/coder/coder/v2/buildinfo" "github.com/coder/coder/v2/cli/cliutil" "github.com/coder/coder/v2/coderd" + "github.com/coder/coder/v2/coderd/cryptokeys" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/coderd/tracing" @@ -130,6 +131,13 @@ type Server struct { // the moon's token. SDKClient *wsproxysdk.Client + // apiKeyEncryptionKeycache manages the encryption keys for smuggling API + // tokens to the alternate domain when using workspace apps. + apiKeyEncryptionKeycache cryptokeys.EncryptionKeycache + // appTokenSigningKeycache manages the signing keys for signing the app + // tokens we use for workspace apps. + appTokenSigningKeycache cryptokeys.SigningKeycache + // DERP derpMesh *derpmesh.Mesh derpMeshTLSConfig *tls.Config @@ -195,19 +203,42 @@ func New(ctx context.Context, opts *Options) (*Server, error) { derpServer := derp.NewServer(key.NewNode(), tailnet.Logger(opts.Logger.Named("net.derp"))) ctx, cancel := context.WithCancel(context.Background()) + + encryptionCache, err := cryptokeys.NewEncryptionCache(ctx, + opts.Logger, + &ProxyFetcher{Client: client}, + codersdk.CryptoKeyFeatureWorkspaceAppsAPIKey, + ) + if err != nil { + cancel() + return nil, xerrors.Errorf("create api key encryption cache: %w", err) + } + signingCache, err := cryptokeys.NewSigningCache(ctx, + opts.Logger, + &ProxyFetcher{Client: client}, + codersdk.CryptoKeyFeatureWorkspaceAppsToken, + ) + if err != nil { + cancel() + return nil, xerrors.Errorf("create api token signing cache: %w", err) + } + r := chi.NewRouter() s := &Server{ - Options: opts, - Handler: r, - DashboardURL: opts.DashboardURL, - Logger: opts.Logger.Named("net.workspace-proxy"), - TracerProvider: opts.Tracing, - PrometheusRegistry: opts.PrometheusRegistry, - SDKClient: client, - derpMesh: derpmesh.New(opts.Logger.Named("net.derpmesh"), derpServer, meshTLSConfig), - derpMeshTLSConfig: meshTLSConfig, - ctx: ctx, - cancel: cancel, + ctx: ctx, + cancel: cancel, + + Options: opts, + Handler: r, + DashboardURL: opts.DashboardURL, + Logger: opts.Logger.Named("net.workspace-proxy"), + TracerProvider: opts.Tracing, + PrometheusRegistry: opts.PrometheusRegistry, + SDKClient: client, + derpMesh: derpmesh.New(opts.Logger.Named("net.derpmesh"), derpServer, meshTLSConfig), + derpMeshTLSConfig: meshTLSConfig, + apiKeyEncryptionKeycache: encryptionCache, + appTokenSigningKeycache: signingCache, } // Register the workspace proxy with the primary coderd instance and start a @@ -240,11 +271,6 @@ func New(ctx context.Context, opts *Options) (*Server, error) { return nil, xerrors.Errorf("handle register: %w", err) } - secKey, err := workspaceapps.KeyFromString(regResp.AppSecurityKey) - if err != nil { - return nil, xerrors.Errorf("parse app security key: %w", err) - } - agentProvider, err := coderd.NewServerTailnet(ctx, s.Logger, nil, @@ -277,20 +303,21 @@ func New(ctx context.Context, opts *Options) (*Server, error) { HostnameRegex: opts.AppHostnameRegex, RealIPConfig: opts.RealIPConfig, SignedTokenProvider: &TokenProvider{ - DashboardURL: opts.DashboardURL, - AccessURL: opts.AccessURL, - AppHostname: opts.AppHostname, - Client: client, - SecurityKey: secKey, - Logger: s.Logger.Named("proxy_token_provider"), + DashboardURL: opts.DashboardURL, + AccessURL: opts.AccessURL, + AppHostname: opts.AppHostname, + Client: client, + TokenSigningKeycache: signingCache, + APIKeyEncryptionKeycache: encryptionCache, + Logger: s.Logger.Named("proxy_token_provider"), }, - AppSecurityKey: secKey, DisablePathApps: opts.DisablePathApps, SecureAuthCookie: opts.SecureAuthCookie, - AgentProvider: agentProvider, - StatsCollector: workspaceapps.NewStatsCollector(opts.StatsCollectorOptions), + AgentProvider: agentProvider, + StatsCollector: workspaceapps.NewStatsCollector(opts.StatsCollectorOptions), + APIKeyEncryptionKeycache: encryptionCache, } derpHandler := derphttp.Handler(derpServer) @@ -435,6 +462,8 @@ func (s *Server) Close() error { err = multierror.Append(err, agentProviderErr) } s.SDKClient.SDKClient.HTTPClient.CloseIdleConnections() + _ = s.appTokenSigningKeycache.Close() + _ = s.apiKeyEncryptionKeycache.Close() return err } diff --git a/enterprise/wsproxy/wsproxy_test.go b/enterprise/wsproxy/wsproxy_test.go index 3d3926c5afae7..4add46af9bc0a 100644 --- a/enterprise/wsproxy/wsproxy_test.go +++ b/enterprise/wsproxy/wsproxy_test.go @@ -25,6 +25,9 @@ import ( "github.com/coder/coder/v2/agent/agenttest" "github.com/coder/coder/v2/buildinfo" "github.com/coder/coder/v2/coderd/coderdtest" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/healthcheck/derphealth" "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/coderd/workspaceapps/apptest" @@ -932,6 +935,9 @@ func TestWorkspaceProxyWorkspaceApps(t *testing.T) { if opts.PrimaryAppHost == "" { opts.PrimaryAppHost = "*.primary.test.coder.com" } + + db, pubsub := dbtestutil.NewDB(t) + client, closer, api, user := coderdenttest.NewWithAPI(t, &coderdenttest.Options{ Options: &coderdtest.Options{ DeploymentValues: deploymentValues, @@ -947,6 +953,8 @@ func TestWorkspaceProxyWorkspaceApps(t *testing.T) { }, }, WorkspaceAppsStatsCollectorOptions: opts.StatsCollectorOptions, + Database: db, + Pubsub: pubsub, }, LicenseOptions: &coderdenttest.LicenseOptions{ Features: license.Features{ @@ -959,6 +967,13 @@ func TestWorkspaceProxyWorkspaceApps(t *testing.T) { _ = closer.Close() }) + _ = dbgen.CryptoKey(t, db, database.CryptoKey{ + Feature: database.CryptoKeyFeatureWorkspaceAppsToken, + }) + _ = dbgen.CryptoKey(t, db, database.CryptoKey{ + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, + }) + // Create the external proxy if opts.DisableSubdomainApps { opts.AppHost = "" @@ -1002,6 +1017,8 @@ func TestWorkspaceProxyWorkspaceApps_BlockDirect(t *testing.T) { if opts.PrimaryAppHost == "" { opts.PrimaryAppHost = "*.primary.test.coder.com" } + + db, pubsub := dbtestutil.NewDB(t) client, closer, api, user := coderdenttest.NewWithAPI(t, &coderdenttest.Options{ Options: &coderdtest.Options{ DeploymentValues: deploymentValues, @@ -1017,6 +1034,8 @@ func TestWorkspaceProxyWorkspaceApps_BlockDirect(t *testing.T) { }, }, WorkspaceAppsStatsCollectorOptions: opts.StatsCollectorOptions, + Database: db, + Pubsub: pubsub, }, LicenseOptions: &coderdenttest.LicenseOptions{ Features: license.Features{ @@ -1029,6 +1048,13 @@ func TestWorkspaceProxyWorkspaceApps_BlockDirect(t *testing.T) { _ = closer.Close() }) + _ = dbgen.CryptoKey(t, db, database.CryptoKey{ + Feature: database.CryptoKeyFeatureWorkspaceAppsToken, + }) + _ = dbgen.CryptoKey(t, db, database.CryptoKey{ + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, + }) + // Create the external proxy if opts.DisableSubdomainApps { opts.AppHost = "" diff --git a/enterprise/wsproxy/wsproxysdk/wsproxysdk.go b/enterprise/wsproxy/wsproxysdk/wsproxysdk.go index 77d36561c6de8..a8f22c2b93063 100644 --- a/enterprise/wsproxy/wsproxysdk/wsproxysdk.go +++ b/enterprise/wsproxy/wsproxysdk/wsproxysdk.go @@ -205,7 +205,6 @@ type RegisterWorkspaceProxyRequest struct { } type RegisterWorkspaceProxyResponse struct { - AppSecurityKey string `json:"app_security_key"` DERPMeshKey string `json:"derp_mesh_key"` DERPRegionID int32 `json:"derp_region_id"` DERPMap *tailcfg.DERPMap `json:"derp_map"` @@ -372,12 +371,6 @@ func (l *RegisterWorkspaceProxyLoop) Start(ctx context.Context) (RegisterWorkspa } failedAttempts = 0 - // Check for consistency. - if originalRes.AppSecurityKey != resp.AppSecurityKey { - l.failureFn(xerrors.New("app security key has changed, proxy must be restarted")) - return - } - if originalRes.DERPMeshKey != resp.DERPMeshKey { l.failureFn(xerrors.New("DERP mesh key has changed, proxy must be restarted")) return @@ -586,10 +579,10 @@ type CryptoKeysResponse struct { CryptoKeys []codersdk.CryptoKey `json:"crypto_keys"` } -func (c *Client) CryptoKeys(ctx context.Context) (CryptoKeysResponse, error) { +func (c *Client) CryptoKeys(ctx context.Context, feature codersdk.CryptoKeyFeature) (CryptoKeysResponse, error) { res, err := c.Request(ctx, http.MethodGet, - "/api/v2/workspaceproxies/me/crypto-keys", - nil, + "/api/v2/workspaceproxies/me/crypto-keys", nil, + codersdk.WithQueryParam("feature", string(feature)), ) if err != nil { return CryptoKeysResponse{}, xerrors.Errorf("make request: %w", err) diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index e55167ef03f88..d687fb68ec61f 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -2110,8 +2110,8 @@ export type BuildReason = "autostart" | "autostop" | "initiator" export const BuildReasons: BuildReason[] = ["autostart", "autostop", "initiator"] // From codersdk/deployment.go -export type CryptoKeyFeature = "oidc_convert" | "tailnet_resume" | "workspace_apps" -export const CryptoKeyFeatures: CryptoKeyFeature[] = ["oidc_convert", "tailnet_resume", "workspace_apps"] +export type CryptoKeyFeature = "oidc_convert" | "tailnet_resume" | "workspace_apps_api_key" | "workspace_apps_token" +export const CryptoKeyFeatures: CryptoKeyFeature[] = ["oidc_convert", "tailnet_resume", "workspace_apps_api_key", "workspace_apps_token"] // From codersdk/workspaceagents.go export type DisplayApp = "port_forwarding_helper" | "ssh_helper" | "vscode" | "vscode_insiders" | "web_terminal" diff --git a/tailnet/resume.go b/tailnet/resume.go index b9443064a37f9..2975fa35f1674 100644 --- a/tailnet/resume.go +++ b/tailnet/resume.go @@ -3,32 +3,23 @@ package tailnet import ( "context" "crypto/rand" - "database/sql" - "encoding/hex" - "encoding/json" "time" - "github.com/go-jose/go-jose/v3" + "github.com/go-jose/go-jose/v4/jwt" "github.com/google/uuid" "golang.org/x/xerrors" "google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/timestamppb" + "github.com/coder/coder/v2/coderd/jwtutils" "github.com/coder/coder/v2/tailnet/proto" "github.com/coder/quartz" ) const ( DefaultResumeTokenExpiry = 24 * time.Hour - - resumeTokenSigningAlgorithm = jose.HS512 ) -// resumeTokenSigningKeyID is a fixed key ID for the resume token signing key. -// If/when we add support for multiple keys (e.g. key rotation), this will move -// to the database instead. -var resumeTokenSigningKeyID = uuid.MustParse("97166747-9309-4d7f-9071-a230e257c2a4") - // NewInsecureTestResumeTokenProvider returns a ResumeTokenProvider that uses a // random key with short expiry for testing purposes. If any errors occur while // generating the key, the function panics. @@ -37,12 +28,15 @@ func NewInsecureTestResumeTokenProvider() ResumeTokenProvider { if err != nil { panic(err) } - return NewResumeTokenKeyProvider(key, quartz.NewReal(), time.Hour) + return NewResumeTokenKeyProvider(jwtutils.StaticKey{ + ID: uuid.New().String(), + Key: key[:], + }, quartz.NewReal(), time.Hour) } type ResumeTokenProvider interface { - GenerateResumeToken(peerID uuid.UUID) (*proto.RefreshResumeTokenResponse, error) - VerifyResumeToken(token string) (uuid.UUID, error) + GenerateResumeToken(ctx context.Context, peerID uuid.UUID) (*proto.RefreshResumeTokenResponse, error) + VerifyResumeToken(ctx context.Context, token string) (uuid.UUID, error) } type ResumeTokenSigningKey [64]byte @@ -56,104 +50,37 @@ func GenerateResumeTokenSigningKey() (ResumeTokenSigningKey, error) { return key, nil } -type ResumeTokenSigningKeyDatabaseStore interface { - GetCoordinatorResumeTokenSigningKey(ctx context.Context) (string, error) - UpsertCoordinatorResumeTokenSigningKey(ctx context.Context, key string) error -} - -// ResumeTokenSigningKeyFromDatabase retrieves the coordinator resume token -// signing key from the database. If the key is not found, a new key is -// generated and inserted into the database. -func ResumeTokenSigningKeyFromDatabase(ctx context.Context, db ResumeTokenSigningKeyDatabaseStore) (ResumeTokenSigningKey, error) { - var resumeTokenKey ResumeTokenSigningKey - resumeTokenKeyStr, err := db.GetCoordinatorResumeTokenSigningKey(ctx) - if err != nil && !xerrors.Is(err, sql.ErrNoRows) { - return resumeTokenKey, xerrors.Errorf("get coordinator resume token key: %w", err) - } - if decoded, err := hex.DecodeString(resumeTokenKeyStr); err != nil || len(decoded) != len(resumeTokenKey) { - newKey, err := GenerateResumeTokenSigningKey() - if err != nil { - return resumeTokenKey, xerrors.Errorf("generate fresh coordinator resume token key: %w", err) - } - - resumeTokenKeyStr = hex.EncodeToString(newKey[:]) - err = db.UpsertCoordinatorResumeTokenSigningKey(ctx, resumeTokenKeyStr) - if err != nil { - return resumeTokenKey, xerrors.Errorf("insert freshly generated coordinator resume token key to database: %w", err) - } - } - - resumeTokenKeyBytes, err := hex.DecodeString(resumeTokenKeyStr) - if err != nil { - return resumeTokenKey, xerrors.Errorf("decode coordinator resume token key from database: %w", err) - } - if len(resumeTokenKeyBytes) != len(resumeTokenKey) { - return resumeTokenKey, xerrors.Errorf("coordinator resume token key in database is not the correct length, expect %d got %d", len(resumeTokenKey), len(resumeTokenKeyBytes)) - } - copy(resumeTokenKey[:], resumeTokenKeyBytes) - if resumeTokenKey == [64]byte{} { - return resumeTokenKey, xerrors.Errorf("coordinator resume token key in database is empty") - } - return resumeTokenKey, nil -} - type ResumeTokenKeyProvider struct { - key ResumeTokenSigningKey + key jwtutils.SigningKeyManager clock quartz.Clock expiry time.Duration } -func NewResumeTokenKeyProvider(key ResumeTokenSigningKey, clock quartz.Clock, expiry time.Duration) ResumeTokenProvider { +func NewResumeTokenKeyProvider(key jwtutils.SigningKeyManager, clock quartz.Clock, expiry time.Duration) ResumeTokenProvider { if expiry <= 0 { expiry = DefaultResumeTokenExpiry } return ResumeTokenKeyProvider{ key: key, clock: clock, - expiry: DefaultResumeTokenExpiry, + expiry: expiry, } } -type resumeTokenPayload struct { - PeerID uuid.UUID `json:"sub"` - Expiry int64 `json:"exp"` -} - -func (p ResumeTokenKeyProvider) GenerateResumeToken(peerID uuid.UUID) (*proto.RefreshResumeTokenResponse, error) { +func (p ResumeTokenKeyProvider) GenerateResumeToken(ctx context.Context, peerID uuid.UUID) (*proto.RefreshResumeTokenResponse, error) { exp := p.clock.Now().Add(p.expiry) - payload := resumeTokenPayload{ - PeerID: peerID, - Expiry: exp.Unix(), - } - payloadBytes, err := json.Marshal(payload) - if err != nil { - return nil, xerrors.Errorf("marshal payload to JSON: %w", err) - } - - signer, err := jose.NewSigner(jose.SigningKey{ - Algorithm: resumeTokenSigningAlgorithm, - Key: p.key[:], - }, &jose.SignerOptions{ - ExtraHeaders: map[jose.HeaderKey]interface{}{ - "kid": resumeTokenSigningKeyID.String(), - }, - }) - if err != nil { - return nil, xerrors.Errorf("create signer: %w", err) + payload := jwtutils.RegisteredClaims{ + Subject: peerID.String(), + Expiry: jwt.NewNumericDate(exp), } - signedObject, err := signer.Sign(payloadBytes) + token, err := jwtutils.Sign(ctx, p.key, payload) if err != nil { return nil, xerrors.Errorf("sign payload: %w", err) } - serialized, err := signedObject.CompactSerialize() - if err != nil { - return nil, xerrors.Errorf("serialize JWS: %w", err) - } - return &proto.RefreshResumeTokenResponse{ - Token: serialized, + Token: token, RefreshIn: durationpb.New(p.expiry / 2), ExpiresAt: timestamppb.New(exp), }, nil @@ -162,35 +89,17 @@ func (p ResumeTokenKeyProvider) GenerateResumeToken(peerID uuid.UUID) (*proto.Re // VerifyResumeToken parses a signed tailnet resume token with the given key and // returns the payload. If the token is invalid or expired, an error is // returned. -func (p ResumeTokenKeyProvider) VerifyResumeToken(str string) (uuid.UUID, error) { - object, err := jose.ParseSigned(str) - if err != nil { - return uuid.Nil, xerrors.Errorf("parse JWS: %w", err) - } - if len(object.Signatures) != 1 { - return uuid.Nil, xerrors.New("expected 1 signature") - } - if object.Signatures[0].Header.Algorithm != string(resumeTokenSigningAlgorithm) { - return uuid.Nil, xerrors.Errorf("expected token signing algorithm to be %q, got %q", resumeTokenSigningAlgorithm, object.Signatures[0].Header.Algorithm) - } - if object.Signatures[0].Header.KeyID != resumeTokenSigningKeyID.String() { - return uuid.Nil, xerrors.Errorf("expected token key ID to be %q, got %q", resumeTokenSigningKeyID, object.Signatures[0].Header.KeyID) - } - - output, err := object.Verify(p.key[:]) +func (p ResumeTokenKeyProvider) VerifyResumeToken(ctx context.Context, str string) (uuid.UUID, error) { + var tok jwt.Claims + err := jwtutils.Verify(ctx, p.key, str, &tok, jwtutils.WithVerifyExpected(jwt.Expected{ + Time: p.clock.Now(), + })) if err != nil { - return uuid.Nil, xerrors.Errorf("verify JWS: %w", err) + return uuid.Nil, xerrors.Errorf("verify payload: %w", err) } - - var tok resumeTokenPayload - err = json.Unmarshal(output, &tok) + parsed, err := uuid.Parse(tok.Subject) if err != nil { - return uuid.Nil, xerrors.Errorf("unmarshal payload: %w", err) + return uuid.Nil, xerrors.Errorf("parse peerID from token: %w", err) } - exp := time.Unix(tok.Expiry, 0) - if exp.Before(p.clock.Now()) { - return uuid.Nil, xerrors.New("signed resume token expired") - } - - return tok.PeerID, nil + return parsed, nil } diff --git a/tailnet/resume_test.go b/tailnet/resume_test.go index 3f63887cbfef3..6f32fba4c511e 100644 --- a/tailnet/resume_test.go +++ b/tailnet/resume_test.go @@ -1,117 +1,20 @@ package tailnet_test import ( - "context" - "encoding/hex" "testing" "time" + "github.com/go-jose/go-jose/v4" + "github.com/go-jose/go-jose/v4/jwt" "github.com/google/uuid" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "go.uber.org/mock/gomock" - "github.com/coder/coder/v2/coderd/database/dbmock" - "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/jwtutils" "github.com/coder/coder/v2/tailnet" "github.com/coder/coder/v2/testutil" "github.com/coder/quartz" ) -func TestResumeTokenSigningKeyFromDatabase(t *testing.T) { - t.Parallel() - - assertRandomKey := func(t *testing.T, key tailnet.ResumeTokenSigningKey) { - t.Helper() - assert.NotEqual(t, tailnet.ResumeTokenSigningKey{}, key, "key should not be empty") - assert.NotEqualValues(t, [64]byte{1}, key, "key should not be all 1s") - } - - t.Run("GenerateRetrieve", func(t *testing.T) { - t.Parallel() - - db, _ := dbtestutil.NewDB(t) - ctx := testutil.Context(t, testutil.WaitShort) - key1, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db) - require.NoError(t, err) - assertRandomKey(t, key1) - - key2, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db) - require.NoError(t, err) - require.Equal(t, key1, key2, "keys should not be different") - }) - - t.Run("GetError", func(t *testing.T) { - t.Parallel() - - db := dbmock.NewMockStore(gomock.NewController(t)) - db.EXPECT().GetCoordinatorResumeTokenSigningKey(gomock.Any()).Return("", assert.AnError) - - ctx := testutil.Context(t, testutil.WaitShort) - _, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db) - require.ErrorIs(t, err, assert.AnError) - }) - - t.Run("UpsertError", func(t *testing.T) { - t.Parallel() - - db := dbmock.NewMockStore(gomock.NewController(t)) - db.EXPECT().GetCoordinatorResumeTokenSigningKey(gomock.Any()).Return("", nil) - db.EXPECT().UpsertCoordinatorResumeTokenSigningKey(gomock.Any(), gomock.Any()).Return(assert.AnError) - - ctx := testutil.Context(t, testutil.WaitShort) - _, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db) - require.ErrorIs(t, err, assert.AnError) - }) - - t.Run("DecodeErrorShouldRegenerate", func(t *testing.T) { - t.Parallel() - - db := dbmock.NewMockStore(gomock.NewController(t)) - db.EXPECT().GetCoordinatorResumeTokenSigningKey(gomock.Any()).Return("invalid", nil) - - var storedKey tailnet.ResumeTokenSigningKey - db.EXPECT().UpsertCoordinatorResumeTokenSigningKey(gomock.Any(), gomock.Any()).Do(func(_ context.Context, value string) error { - keyBytes, err := hex.DecodeString(value) - require.NoError(t, err) - require.Len(t, keyBytes, len(storedKey)) - copy(storedKey[:], keyBytes) - return nil - }) - - ctx := testutil.Context(t, testutil.WaitShort) - key, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db) - require.NoError(t, err) - assertRandomKey(t, key) - require.Equal(t, storedKey, key, "key should match stored value") - }) - - t.Run("LengthErrorShouldRegenerate", func(t *testing.T) { - t.Parallel() - - db := dbmock.NewMockStore(gomock.NewController(t)) - db.EXPECT().GetCoordinatorResumeTokenSigningKey(gomock.Any()).Return("deadbeef", nil) - db.EXPECT().UpsertCoordinatorResumeTokenSigningKey(gomock.Any(), gomock.Any()).Return(nil) - - ctx := testutil.Context(t, testutil.WaitShort) - key, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db) - require.NoError(t, err) - assertRandomKey(t, key) - }) - - t.Run("EmptyError", func(t *testing.T) { - t.Parallel() - - db := dbmock.NewMockStore(gomock.NewController(t)) - emptyKey := hex.EncodeToString(make([]byte, 64)) - db.EXPECT().GetCoordinatorResumeTokenSigningKey(gomock.Any()).Return(emptyKey, nil) - - ctx := testutil.Context(t, testutil.WaitShort) - _, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db) - require.ErrorContains(t, err, "is empty") - }) -} - func TestResumeTokenKeyProvider(t *testing.T) { t.Parallel() @@ -121,17 +24,18 @@ func TestResumeTokenKeyProvider(t *testing.T) { t.Run("OK", func(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) id := uuid.New() clock := quartz.NewMock(t) - provider := tailnet.NewResumeTokenKeyProvider(key, clock, tailnet.DefaultResumeTokenExpiry) - token, err := provider.GenerateResumeToken(id) + provider := tailnet.NewResumeTokenKeyProvider(newKeySigner(key), clock, tailnet.DefaultResumeTokenExpiry) + token, err := provider.GenerateResumeToken(ctx, id) require.NoError(t, err) require.NotNil(t, token) require.NotEmpty(t, token.Token) require.Equal(t, tailnet.DefaultResumeTokenExpiry/2, token.RefreshIn.AsDuration()) require.WithinDuration(t, clock.Now().Add(tailnet.DefaultResumeTokenExpiry), token.ExpiresAt.AsTime(), time.Second) - gotID, err := provider.VerifyResumeToken(token.Token) + gotID, err := provider.VerifyResumeToken(ctx, token.Token) require.NoError(t, err) require.Equal(t, id, gotID) }) @@ -139,43 +43,57 @@ func TestResumeTokenKeyProvider(t *testing.T) { t.Run("Expired", func(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) id := uuid.New() clock := quartz.NewMock(t) - provider := tailnet.NewResumeTokenKeyProvider(key, clock, tailnet.DefaultResumeTokenExpiry) - token, err := provider.GenerateResumeToken(id) + provider := tailnet.NewResumeTokenKeyProvider(newKeySigner(key), clock, tailnet.DefaultResumeTokenExpiry) + token, err := provider.GenerateResumeToken(ctx, id) require.NoError(t, err) require.NotNil(t, token) require.NotEmpty(t, token.Token) require.Equal(t, tailnet.DefaultResumeTokenExpiry/2, token.RefreshIn.AsDuration()) require.WithinDuration(t, clock.Now().Add(tailnet.DefaultResumeTokenExpiry), token.ExpiresAt.AsTime(), time.Second) - // Advance time past expiry - _ = clock.Advance(tailnet.DefaultResumeTokenExpiry + time.Second) + // Advance time past expiry. Account for leeway. + _ = clock.Advance(tailnet.DefaultResumeTokenExpiry + time.Second*61) - _, err = provider.VerifyResumeToken(token.Token) - require.ErrorContains(t, err, "expired") + _, err = provider.VerifyResumeToken(ctx, token.Token) + require.Error(t, err) + require.ErrorIs(t, err, jwt.ErrExpired) }) t.Run("InvalidToken", func(t *testing.T) { t.Parallel() - provider := tailnet.NewResumeTokenKeyProvider(key, quartz.NewMock(t), tailnet.DefaultResumeTokenExpiry) - _, err := provider.VerifyResumeToken("invalid") + ctx := testutil.Context(t, testutil.WaitShort) + provider := tailnet.NewResumeTokenKeyProvider(newKeySigner(key), quartz.NewMock(t), tailnet.DefaultResumeTokenExpiry) + _, err := provider.VerifyResumeToken(ctx, "invalid") require.ErrorContains(t, err, "parse JWS") }) t.Run("VerifyError", func(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) // Generate a resume token with a different key otherKey, err := tailnet.GenerateResumeTokenSigningKey() require.NoError(t, err) - otherProvider := tailnet.NewResumeTokenKeyProvider(otherKey, quartz.NewMock(t), tailnet.DefaultResumeTokenExpiry) - token, err := otherProvider.GenerateResumeToken(uuid.New()) + otherSigner := newKeySigner(otherKey) + otherProvider := tailnet.NewResumeTokenKeyProvider(otherSigner, quartz.NewMock(t), tailnet.DefaultResumeTokenExpiry) + token, err := otherProvider.GenerateResumeToken(ctx, uuid.New()) require.NoError(t, err) - provider := tailnet.NewResumeTokenKeyProvider(key, quartz.NewMock(t), tailnet.DefaultResumeTokenExpiry) - _, err = provider.VerifyResumeToken(token.Token) - require.ErrorContains(t, err, "verify JWS") + signer := newKeySigner(key) + signer.ID = otherSigner.ID + provider := tailnet.NewResumeTokenKeyProvider(signer, quartz.NewMock(t), tailnet.DefaultResumeTokenExpiry) + _, err = provider.VerifyResumeToken(ctx, token.Token) + require.ErrorIs(t, err, jose.ErrCryptoFailure) }) } + +func newKeySigner(key tailnet.ResumeTokenSigningKey) jwtutils.StaticKey { + return jwtutils.StaticKey{ + ID: "123", + Key: key[:], + } +} diff --git a/tailnet/service.go b/tailnet/service.go index 28a054dd8d671..7f38f63a589b3 100644 --- a/tailnet/service.go +++ b/tailnet/service.go @@ -177,7 +177,7 @@ func (s *DRPCService) RefreshResumeToken(ctx context.Context, _ *proto.RefreshRe return nil, xerrors.New("no Stream ID") } - res, err := s.ResumeTokenProvider.GenerateResumeToken(streamID.ID) + res, err := s.ResumeTokenProvider.GenerateResumeToken(ctx, streamID.ID) if err != nil { return nil, xerrors.Errorf("generate resume token: %w", err) }