diff --git a/cli/features.go b/cli/features.go index 5d631fc04977f..1995153275eaf 100644 --- a/cli/features.go +++ b/cli/features.go @@ -1,6 +1,7 @@ package cli import ( + "bytes" "encoding/json" "fmt" "strings" @@ -53,12 +54,14 @@ func featuresList() *cobra.Command { return xerrors.Errorf("render table: %w", err) } case "json": - outBytes, err := json.Marshal(entitlements) + buf := new(bytes.Buffer) + enc := json.NewEncoder(buf) + enc.SetIndent("", " ") + err = enc.Encode(entitlements) if err != nil { return xerrors.Errorf("marshal features to JSON: %w", err) } - - out = string(outBytes) + out = buf.String() default: return xerrors.Errorf(`unknown output format %q, only "table" and "json" are supported`, outputFormat) } diff --git a/coderd/coderd.go b/coderd/coderd.go index f7b8603367b5e..be089523ec503 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -68,6 +68,7 @@ type Options struct { TracerProvider *sdktrace.TracerProvider AutoImportTemplates []AutoImportTemplate LicenseHandler http.Handler + FeaturesService FeaturesService } // New constructs a Coder API handler. @@ -97,6 +98,9 @@ func New(options *Options) *API { if options.LicenseHandler == nil { options.LicenseHandler = licenses() } + if options.FeaturesService == nil { + options.FeaturesService = featuresService{} + } siteCacheDir := options.CacheDir if siteCacheDir != "" { @@ -406,7 +410,7 @@ func New(options *Options) *API { }) r.Route("/entitlements", func(r chi.Router) { r.Use(apiKeyMiddleware) - r.Get("/", entitlements) + r.Get("/", api.FeaturesService.EntitlementsAPI) }) r.Route("/licenses", func(r chi.Router) { r.Use(apiKeyMiddleware) diff --git a/coderd/database/databasefake/databasefake.go b/coderd/database/databasefake/databasefake.go index 590d3c5685bd4..2b42d0cca7ad2 100644 --- a/coderd/database/databasefake/databasefake.go +++ b/coderd/database/databasefake/databasefake.go @@ -246,6 +246,19 @@ func (q *fakeQuerier) GetUserCount(_ context.Context) (int64, error) { return int64(len(q.users)), nil } +func (q *fakeQuerier) GetActiveUserCount(_ context.Context) (int64, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + active := int64(0) + for _, u := range q.users { + if u.Status == database.UserStatusActive { + active++ + } + } + return active, nil +} + func (q *fakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams) ([]database.User, error) { q.mutex.RLock() defer q.mutex.RUnlock() @@ -2322,6 +2335,21 @@ func (q *fakeQuerier) GetLicenses(_ context.Context) ([]database.License, error) return results, nil } +func (q *fakeQuerier) GetUnexpiredLicenses(_ context.Context) ([]database.License, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + now := time.Now() + var results []database.License + for _, l := range q.licenses { + if l.Exp.After(now) { + results = append(results, l) + } + } + sort.Slice(results, func(i, j int) bool { return results[i].ID < results[j].ID }) + return results, nil +} + func (q *fakeQuerier) DeleteLicense(_ context.Context, id int32) (int32, error) { q.mutex.Lock() defer q.mutex.Unlock() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 9603608ebbb05..389f15b385d6d 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -25,6 +25,7 @@ type querier interface { DeleteParameterValueByID(ctx context.Context, id uuid.UUID) error GetAPIKeyByID(ctx context.Context, id string) (APIKey, error) GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Time) ([]APIKey, error) + GetActiveUserCount(ctx context.Context) (int64, error) // GetAuditLogsBefore retrieves `limit` number of audit logs before the provided // ID. GetAuditLogsBefore(ctx context.Context, arg GetAuditLogsBeforeParams) ([]AuditLog, error) @@ -63,6 +64,7 @@ type querier interface { GetTemplateVersionsCreatedAfter(ctx context.Context, createdAt time.Time) ([]TemplateVersion, error) GetTemplates(ctx context.Context) ([]Template, error) GetTemplatesWithFilter(ctx context.Context, arg GetTemplatesWithFilterParams) ([]Template, error) + GetUnexpiredLicenses(ctx context.Context) ([]License, error) GetUserByEmailOrUsername(ctx context.Context, arg GetUserByEmailOrUsernameParams) (User, error) GetUserByID(ctx context.Context, id uuid.UUID) (User, error) GetUserCount(ctx context.Context) (int64, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 9a22d86d888a8..1e4a194fa740d 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -522,6 +522,41 @@ func (q *sqlQuerier) GetLicenses(ctx context.Context) ([]License, error) { return items, nil } +const getUnexpiredLicenses = `-- name: GetUnexpiredLicenses :many +SELECT id, uploaded_at, jwt, exp +FROM licenses +WHERE exp > NOW() +ORDER BY (id) +` + +func (q *sqlQuerier) GetUnexpiredLicenses(ctx context.Context) ([]License, error) { + rows, err := q.db.QueryContext(ctx, getUnexpiredLicenses) + if err != nil { + return nil, err + } + defer rows.Close() + var items []License + for rows.Next() { + var i License + if err := rows.Scan( + &i.ID, + &i.UploadedAt, + &i.JWT, + &i.Exp, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const insertLicense = `-- name: InsertLicense :one INSERT INTO licenses ( @@ -2664,6 +2699,22 @@ func (q *sqlQuerier) UpdateUserLinkedID(ctx context.Context, arg UpdateUserLinke return i, err } +const getActiveUserCount = `-- name: GetActiveUserCount :one +SELECT + COUNT(*) +FROM + users +WHERE + status = 'active'::public.user_status +` + +func (q *sqlQuerier) GetActiveUserCount(ctx context.Context) (int64, error) { + row := q.db.QueryRowContext(ctx, getActiveUserCount) + var count int64 + err := row.Scan(&count) + return count, err +} + const getAuthorizationUserRoles = `-- name: GetAuthorizationUserRoles :one SELECT -- username is returned just to help for logging purposes diff --git a/coderd/database/queries/licenses.sql b/coderd/database/queries/licenses.sql index e299589087119..39419c301761d 100644 --- a/coderd/database/queries/licenses.sql +++ b/coderd/database/queries/licenses.sql @@ -13,6 +13,12 @@ SELECT * FROM licenses ORDER BY (id); +-- name: GetUnexpiredLicenses :many +SELECT * +FROM licenses +WHERE exp > NOW() +ORDER BY (id); + -- name: DeleteLicense :one DELETE FROM licenses diff --git a/coderd/database/queries/users.sql b/coderd/database/queries/users.sql index 1d9caa758625e..12751fe064b47 100644 --- a/coderd/database/queries/users.sql +++ b/coderd/database/queries/users.sql @@ -28,6 +28,14 @@ SELECT FROM users; +-- name: GetActiveUserCount :one +SELECT + COUNT(*) +FROM + users +WHERE + status = 'active'::public.user_status; + -- name: InsertUser :one INSERT INTO users ( diff --git a/coderd/features.go b/coderd/features.go index a6eaeca9c545b..55ddd2af895f9 100644 --- a/coderd/features.go +++ b/coderd/features.go @@ -7,7 +7,20 @@ import ( "github.com/coder/coder/codersdk" ) -func entitlements(rw http.ResponseWriter, _ *http.Request) { +// FeaturesService is the interface for interacting with enterprise features. +type FeaturesService interface { + EntitlementsAPI(w http.ResponseWriter, r *http.Request) + + // TODO + // Get returns the implementations for feature interfaces. Parameter `s `must be a pointer to a + // struct type containing feature interfaces as fields. The FeatureService sets all fields to + // the correct implementations depending on whether the features are turned on. + // Get(s any) error +} + +type featuresService struct{} + +func (featuresService) EntitlementsAPI(rw http.ResponseWriter, _ *http.Request) { features := make(map[string]codersdk.Feature) for _, f := range codersdk.FeatureNames { features[f] = codersdk.Feature{ diff --git a/coderd/features_internal_test.go b/coderd/features_internal_test.go index 50c7e8f53e397..d06fc96e19626 100644 --- a/coderd/features_internal_test.go +++ b/coderd/features_internal_test.go @@ -18,7 +18,7 @@ func TestEntitlements(t *testing.T) { t.Parallel() r := httptest.NewRequest("GET", "https://example.com/api/v2/entitlements", nil) rw := httptest.NewRecorder() - entitlements(rw, r) + featuresService{}.EntitlementsAPI(rw, r) resp := rw.Result() defer resp.Body.Close() assert.Equal(t, http.StatusOK, resp.StatusCode) diff --git a/codersdk/features.go b/codersdk/features.go index 6bdfd5ff53bd4..37b0113c37dfb 100644 --- a/codersdk/features.go +++ b/codersdk/features.go @@ -24,8 +24,8 @@ var FeatureNames = []string{FeatureUserLimit, FeatureAuditLog} type Feature struct { Entitlement Entitlement `json:"entitlement"` Enabled bool `json:"enabled"` - Limit *int64 `json:"limit"` - Actual *int64 `json:"actual"` + Limit *int64 `json:"limit,omitempty"` + Actual *int64 `json:"actual,omitempty"` } type Entitlements struct { diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index 2be49052d3658..598c32f11b367 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -1,12 +1,18 @@ package coderd import ( + "context" + "os" + "strings" + "golang.org/x/xerrors" "github.com/coder/coder/coderd" "github.com/coder/coder/coderd/rbac" ) +const EnvAuditLogEnable = "CODER_AUDIT_LOG_ENABLE" + func NewEnterprise(options *coderd.Options) *coderd.API { var eOpts = *options if eOpts.Authorizer == nil { @@ -26,5 +32,18 @@ func NewEnterprise(options *coderd.Options) *coderd.API { Authorizer: eOpts.Authorizer, Logger: eOpts.Logger, }).handler() + en := Enablements{AuditLogs: true} + auditLog := os.Getenv(EnvAuditLogEnable) + auditLog = strings.ToLower(auditLog) + if auditLog == "disable" || auditLog == "false" || auditLog == "0" || auditLog == "no" { + en.AuditLogs = false + } + eOpts.FeaturesService = newFeaturesService( + context.Background(), + eOpts.Logger, + eOpts.Database, + eOpts.Pubsub, + en, + ) return coderd.New(&eOpts) } diff --git a/enterprise/coderd/features.go b/enterprise/coderd/features.go new file mode 100644 index 0000000000000..2102cdc0eb122 --- /dev/null +++ b/enterprise/coderd/features.go @@ -0,0 +1,261 @@ +package coderd + +import ( + "context" + "crypto/ed25519" + "fmt" + "net/http" + "sync" + "time" + + "github.com/cenkalti/backoff/v4" + + "cdr.dev/slog" + + agpl "github.com/coder/coder/coderd" + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/httpapi" + "github.com/coder/coder/codersdk" +) + +type Enablements struct { + AuditLogs bool +} + +type featuresService struct { + logger slog.Logger + database database.Store + pubsub database.Pubsub + keys map[string]ed25519.PublicKey + enablements Enablements + resyncInterval time.Duration + + mu sync.RWMutex + entitlements entitlements +} + +// newFeaturesService creates a FeaturesService and starts it. It will continue running for the +// duration of the passed ctx. +func newFeaturesService( + ctx context.Context, + logger slog.Logger, + db database.Store, + pubsub database.Pubsub, + enablements Enablements, +) agpl.FeaturesService { + fs := &featuresService{ + logger: logger, + database: db, + pubsub: pubsub, + keys: keys, + enablements: enablements, + resyncInterval: 10 * time.Minute, + entitlements: entitlements{ + activeUsers: numericalEntitlement{ + entitlementLimit: entitlementLimit{ + unlimited: true, + }, + }, + }, + } + go fs.syncEntitlements(ctx) + return fs +} + +func (s *featuresService) EntitlementsAPI(rw http.ResponseWriter, r *http.Request) { + s.mu.RLock() + e := s.entitlements + s.mu.RUnlock() + + resp := codersdk.Entitlements{ + Features: make(map[string]codersdk.Feature), + Warnings: make([]string, 0), + HasLicense: e.hasLicense, + } + + // User limit + uf := codersdk.Feature{ + Entitlement: e.activeUsers.state.toSDK(), + Enabled: true, + } + if !e.activeUsers.unlimited { + n, err := s.database.GetActiveUserCount(r.Context()) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Unable to query database", + Detail: err.Error(), + }) + return + } + uf.Actual = &n + uf.Limit = &e.activeUsers.limit + if n > e.activeUsers.limit { + resp.Warnings = append(resp.Warnings, + fmt.Sprintf( + "Your deployment has %d active users but is only licensed for %d", + n, e.activeUsers.limit)) + } + } + resp.Features[codersdk.FeatureUserLimit] = uf + + // Audit logs + resp.Features[codersdk.FeatureAuditLog] = codersdk.Feature{ + Entitlement: e.auditLogs.state.toSDK(), + Enabled: s.enablements.AuditLogs, + } + if e.auditLogs.state == gracePeriod && s.enablements.AuditLogs { + resp.Warnings = append(resp.Warnings, + "Audit logging is enabled but your license for this feature is expired") + } + + httpapi.Write(rw, http.StatusOK, resp) +} + +type entitlementState int + +const ( + notEntitled entitlementState = iota + gracePeriod + entitled +) + +type entitlementLimit struct { + unlimited bool + limit int64 +} + +type entitlement struct { + state entitlementState +} + +func (s entitlementState) toSDK() codersdk.Entitlement { + switch s { + case notEntitled: + return codersdk.EntitlementNotEntitled + case gracePeriod: + return codersdk.EntitlementGracePeriod + case entitled: + return codersdk.EntitlementEntitled + default: + panic("unknown entitlementState") + } +} + +type numericalEntitlement struct { + entitlement + entitlementLimit +} + +type entitlements struct { + hasLicense bool + activeUsers numericalEntitlement + auditLogs entitlement +} + +func (s *featuresService) getEntitlements(ctx context.Context) (entitlements, error) { + licenses, err := s.database.GetUnexpiredLicenses(ctx) + if err != nil { + return entitlements{}, err + } + now := time.Now() + e := entitlements{ + activeUsers: numericalEntitlement{ + entitlementLimit: entitlementLimit{ + unlimited: true, + }, + }, + } + for _, l := range licenses { + claims, err := validateDBLicense(l, s.keys) + if err != nil { + s.logger.Debug(ctx, "skipping invalid license", + slog.F("id", l.ID), slog.Error(err)) + continue + } + e.hasLicense = true + thisEntitlement := entitled + if now.After(claims.LicenseExpires.Time) { + // if the grace period were over, the validation fails, so if we are after + // LicenseExpires we must be in grace period. + thisEntitlement = gracePeriod + } + if claims.Features.UserLimit > 0 { + e.activeUsers.state = thisEntitlement + e.activeUsers.unlimited = false + e.activeUsers.limit = max(e.activeUsers.limit, claims.Features.UserLimit) + } + if claims.Features.AuditLog > 0 { + e.auditLogs.state = thisEntitlement + } + } + return e, nil +} + +func (s *featuresService) syncEntitlements(ctx context.Context) { + eb := backoff.NewExponentialBackOff() + eb.MaxElapsedTime = 0 // retry indefinitely + b := backoff.WithContext(eb, ctx) + updates := make(chan struct{}, 1) + subscribed := false + + for { + select { + case <-ctx.Done(): + return + default: + // pass + } + if !subscribed { + cancel, err := s.pubsub.Subscribe(PubSubEventLicenses, func(_ context.Context, _ []byte) { + // don't block. If the channel is full, drop the event, as there is a resync + // scheduled already. + select { + case updates <- struct{}{}: + // pass + default: + // pass + } + }) + if err != nil { + s.logger.Warn(ctx, "failed to subscribe to license updates", slog.Error(err)) + time.Sleep(b.NextBackOff()) + continue + } + // nolint: revive + defer cancel() + subscribed = true + s.logger.Debug(ctx, "successfully subscribed to pubsub") + } + + s.logger.Info(ctx, "syncing licensed entitlements") + ents, err := s.getEntitlements(ctx) + if err != nil { + s.logger.Warn(ctx, "failed to get feature entitlements", slog.Error(err)) + time.Sleep(b.NextBackOff()) + continue + } + b.Reset() + + s.mu.Lock() + s.entitlements = ents + s.mu.Unlock() + s.logger.Debug(ctx, "synced licensed entitlements") + + select { + case <-ctx.Done(): + return + case <-time.After(s.resyncInterval): + continue + case <-updates: + s.logger.Debug(ctx, "got pubsub update") + continue + } + } +} + +func max(a, b int64) int64 { + if a > b { + return a + } + return b +} diff --git a/enterprise/coderd/features_internal_test.go b/enterprise/coderd/features_internal_test.go new file mode 100644 index 0000000000000..bb1b14a57606d --- /dev/null +++ b/enterprise/coderd/features_internal_test.go @@ -0,0 +1,337 @@ +package coderd + +import ( + "context" + "crypto/ed25519" + "crypto/rand" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/golang-jwt/jwt/v4" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "cdr.dev/slog/sloggers/slogtest" + + "github.com/coder/coder/coderd" + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/databasefake" + "github.com/coder/coder/codersdk" + "github.com/coder/coder/testutil" +) + +func TestFeaturesService_EntitlementsAPI(t *testing.T) { + t.Parallel() + logger := slogtest.Make(t, nil) + + // Note that these are not actually used because we don't run the syncEntitlements + // routine in this test. + pubsub := database.NewPubsubInMemory() + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + keyID := "testing" + + t.Run("NoLicense", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + uut := &featuresService{ + logger: logger, + database: db, + pubsub: pubsub, + keys: map[string]ed25519.PublicKey{keyID: pub}, + enablements: Enablements{AuditLogs: true}, + entitlements: entitlements{ + hasLicense: false, + activeUsers: numericalEntitlement{ + entitlement{notEntitled}, + entitlementLimit{ + unlimited: true, + }, + }, + auditLogs: entitlement{notEntitled}, + }, + } + result := requestEntitlements(t, uut) + assert.False(t, result.HasLicense) + assert.Empty(t, result.Warnings) + assert.Equal(t, codersdk.EntitlementNotEntitled, result.Features[codersdk.FeatureUserLimit].Entitlement) + assert.Equal(t, codersdk.EntitlementNotEntitled, result.Features[codersdk.FeatureAuditLog].Entitlement) + }) + + t.Run("FullLicense", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + db := databasefake.New() + uut := &featuresService{ + logger: logger, + database: db, + pubsub: pubsub, + keys: map[string]ed25519.PublicKey{keyID: pub}, + enablements: Enablements{AuditLogs: true}, + entitlements: entitlements{ + hasLicense: true, + activeUsers: numericalEntitlement{ + entitlement{entitled}, + entitlementLimit{ + unlimited: false, + limit: 100, + }, + }, + auditLogs: entitlement{entitled}, + }, + } + _, err = db.InsertUser(ctx, database.InsertUserParams{ + ID: uuid.UUID{}, + Email: "", + Username: "", + HashedPassword: nil, + CreatedAt: time.Time{}, + UpdatedAt: time.Time{}, + RBACRoles: nil, + LoginType: "", + }) + require.NoError(t, err) + result := requestEntitlements(t, uut) + assert.True(t, result.HasLicense) + ul := result.Features[codersdk.FeatureUserLimit] + assert.Equal(t, codersdk.EntitlementEntitled, ul.Entitlement) + assert.Equal(t, int64(100), *ul.Limit) + assert.Equal(t, int64(1), *ul.Actual) + assert.True(t, ul.Enabled) + al := result.Features[codersdk.FeatureAuditLog] + assert.Equal(t, codersdk.EntitlementEntitled, al.Entitlement) + assert.True(t, al.Enabled) + assert.Nil(t, al.Limit) + assert.Nil(t, al.Actual) + assert.Empty(t, result.Warnings) + }) + + t.Run("Warnings", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + db := databasefake.New() + uut := &featuresService{ + logger: logger, + database: db, + pubsub: pubsub, + keys: map[string]ed25519.PublicKey{keyID: pub}, + enablements: Enablements{AuditLogs: true}, + entitlements: entitlements{ + hasLicense: true, + activeUsers: numericalEntitlement{ + entitlement{gracePeriod}, + entitlementLimit{ + unlimited: false, + limit: 4, + }, + }, + auditLogs: entitlement{gracePeriod}, + }, + } + for i := byte(0); i < 5; i++ { + _, err = db.InsertUser(ctx, database.InsertUserParams{ + ID: uuid.UUID{i}, + Email: "", + Username: "", + HashedPassword: nil, + CreatedAt: time.Time{}, + UpdatedAt: time.Time{}, + RBACRoles: nil, + LoginType: "", + }) + require.NoError(t, err) + } + result := requestEntitlements(t, uut) + assert.True(t, result.HasLicense) + ul := result.Features[codersdk.FeatureUserLimit] + assert.Equal(t, codersdk.EntitlementGracePeriod, ul.Entitlement) + assert.Equal(t, int64(4), *ul.Limit) + assert.Equal(t, int64(5), *ul.Actual) + assert.True(t, ul.Enabled) + al := result.Features[codersdk.FeatureAuditLog] + assert.Equal(t, codersdk.EntitlementGracePeriod, al.Entitlement) + assert.True(t, al.Enabled) + assert.Nil(t, al.Limit) + assert.Nil(t, al.Actual) + assert.Len(t, result.Warnings, 2) + assert.Contains(t, result.Warnings, + "Your deployment has 5 active users but is only licensed for 4") + assert.Contains(t, result.Warnings, + "Audit logging is enabled but your license for this feature is expired") + }) +} + +func TestFeaturesServiceSyncEntitlements(t *testing.T) { + t.Parallel() + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + keyID := "testing" + + // This tests that pubsub updates work by setting the resync interval very long + t.Run("PubSub", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + logger := slogtest.Make(t, nil) + pubsub := database.NewPubsubInMemory() + db := databasefake.New() + uut := &featuresService{ + logger: logger, + database: db, + pubsub: pubsub, + keys: map[string]ed25519.PublicKey{keyID: pub}, + enablements: Enablements{AuditLogs: true}, + resyncInterval: time.Hour, // no resyncs during test + entitlements: entitlements{}, + } + + _, invalidKey, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + // Start of day, 3 licenses, one expired, one invalid + _ = putLicense(ctx, t, db, priv, keyID, 1000, -2*time.Hour, -1*time.Hour) + _ = putLicense(ctx, t, db, invalidKey, "invalid", 900, time.Hour, 2*time.Hour) + l0 := putLicense(ctx, t, db, priv, keyID, 300, time.Hour, 2*time.Hour) + + go uut.syncEntitlements(ctx) + + testutil.Eventually(ctx, t, userLimitIs(uut, 300), testutil.IntervalFast) + + // New license + l1 := putLicense(ctx, t, db, priv, keyID, 305, time.Hour, 2*time.Hour) + err = pubsub.Publish(PubSubEventLicenses, []byte("add")) + require.NoError(t, err) + + // User limit goes up, because 305 > 300 + testutil.Eventually(ctx, t, userLimitIs(uut, 305), testutil.IntervalFast) + + // New license with lower limit + _ = putLicense(ctx, t, db, priv, keyID, 295, time.Hour, 2*time.Hour) + err = pubsub.Publish(PubSubEventLicenses, []byte("add")) + require.NoError(t, err) + + // Need to delete the others before the limit lowers + _, err = db.DeleteLicense(ctx, l1.ID) + require.NoError(t, err) + err = pubsub.Publish(PubSubEventLicenses, []byte("delete")) + require.NoError(t, err) + testutil.Eventually(ctx, t, userLimitIs(uut, 300), testutil.IntervalFast) + + _, err = db.DeleteLicense(ctx, l0.ID) + require.NoError(t, err) + err = pubsub.Publish(PubSubEventLicenses, []byte("delete")) + require.NoError(t, err) + testutil.Eventually(ctx, t, userLimitIs(uut, 295), testutil.IntervalFast) + }) + + // This tests that periodic resyncs work by setting the resync interval very fast and + // not sending any pubsub updates. + t.Run("Resyncs", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + logger := slogtest.Make(t, nil) + pubsub := database.NewPubsubInMemory() + db := databasefake.New() + uut := &featuresService{ + logger: logger, + database: db, + pubsub: pubsub, + keys: map[string]ed25519.PublicKey{keyID: pub}, + enablements: Enablements{AuditLogs: true}, + resyncInterval: 10 * time.Millisecond, + entitlements: entitlements{}, + } + + _, invalidKey, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + // Start of day, 3 licenses, one expired, one invalid + _ = putLicense(ctx, t, db, priv, keyID, 1000, -2*time.Hour, -1*time.Hour) + _ = putLicense(ctx, t, db, invalidKey, "invalid", 900, time.Hour, 2*time.Hour) + l0 := putLicense(ctx, t, db, priv, keyID, 300, time.Hour, 2*time.Hour) + + go uut.syncEntitlements(ctx) + + testutil.Eventually(ctx, t, userLimitIs(uut, 300), testutil.IntervalFast) + + // New license + l1 := putLicense(ctx, t, db, priv, keyID, 305, time.Hour, 2*time.Hour) + + // User limit goes up, because 305 > 300 + testutil.Eventually(ctx, t, userLimitIs(uut, 305), testutil.IntervalFast) + + // New license with lower limit + _ = putLicense(ctx, t, db, priv, keyID, 295, time.Hour, 2*time.Hour) + + // Need to delete the others before the limit lowers + _, err = db.DeleteLicense(ctx, l1.ID) + require.NoError(t, err) + testutil.Eventually(ctx, t, userLimitIs(uut, 300), testutil.IntervalFast) + + _, err = db.DeleteLicense(ctx, l0.ID) + require.NoError(t, err) + testutil.Eventually(ctx, t, userLimitIs(uut, 295), testutil.IntervalFast) + }) +} + +func requestEntitlements(t *testing.T, uut coderd.FeaturesService) codersdk.Entitlements { + t.Helper() + r := httptest.NewRequest("GET", "https://example.com/api/v2/entitlements", nil) + rw := httptest.NewRecorder() + uut.EntitlementsAPI(rw, r) + resp := rw.Result() + defer resp.Body.Close() + assert.Equal(t, http.StatusOK, resp.StatusCode) + dec := json.NewDecoder(resp.Body) + var result codersdk.Entitlements + err := dec.Decode(&result) + require.NoError(t, err) + return result +} + +func putLicense( + ctx context.Context, t *testing.T, db database.Store, + k ed25519.PrivateKey, keyID string, userLimit int64, + timeToGrace, timeToExpire time.Duration, +) database.License { + t.Helper() + c := &Claims{ + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: "test@testing.test", + ExpiresAt: jwt.NewNumericDate(time.Now().Add(timeToExpire)), + NotBefore: jwt.NewNumericDate(time.Now().Add(-time.Minute)), + IssuedAt: jwt.NewNumericDate(time.Now().Add(-time.Minute)), + }, + LicenseExpires: jwt.NewNumericDate(time.Now().Add(timeToGrace)), + Version: CurrentVersion, + Features: Features{ + UserLimit: userLimit, + AuditLog: 1, + }, + } + j, err := makeLicense(c, k, keyID) + require.NoError(t, err) + l, err := db.InsertLicense(ctx, database.InsertLicenseParams{ + UploadedAt: c.IssuedAt.Time, + JWT: j, + Exp: c.ExpiresAt.Time, + }) + require.NoError(t, err) + return l +} + +func userLimitIs(fs *featuresService, limit int64) func(context.Context) bool { + return func(_ context.Context) bool { + fs.mu.RLock() + defer fs.mu.RUnlock() + return fs.entitlements.activeUsers.limit == limit + } +} diff --git a/enterprise/coderd/licenses.go b/enterprise/coderd/licenses.go index 16592fcde2654..02bef91d52bcb 100644 --- a/enterprise/coderd/licenses.go +++ b/enterprise/coderd/licenses.go @@ -64,8 +64,9 @@ type Claims struct { } var ( - ErrInvalidVersion = xerrors.New("license must be version 3") - ErrMissingKeyID = xerrors.Errorf("JOSE header must contain %s", HeaderKeyID) + ErrInvalidVersion = xerrors.New("license must be version 3") + ErrMissingKeyID = xerrors.Errorf("JOSE header must contain %s", HeaderKeyID) + ErrMissingLicenseExpires = xerrors.New("license missing license_expires") ) // parseLicense parses the license and returns the claims. If the license's signature is invalid or @@ -92,6 +93,30 @@ func parseLicense(l string, keys map[string]ed25519.PublicKey) (jwt.MapClaims, e return nil, xerrors.New("unable to parse Claims") } +// validateDBLicense validates a database.License record, and if valid, returns the claims. If +// unparsable or invalid, it returns an error +func validateDBLicense(l database.License, keys map[string]ed25519.PublicKey) (*Claims, error) { + tok, err := jwt.ParseWithClaims( + l.JWT, + &Claims{}, + keyFunc(keys), + jwt.WithValidMethods(ValidMethods), + ) + if err != nil { + return nil, err + } + if claims, ok := tok.Claims.(*Claims); ok && tok.Valid { + if claims.Version != uint64(CurrentVersion) { + return nil, ErrInvalidVersion + } + if claims.LicenseExpires == nil { + return nil, ErrMissingLicenseExpires + } + return claims, nil + } + return nil, xerrors.New("unable to parse Claims") +} + func keyFunc(keys map[string]ed25519.PublicKey) func(*jwt.Token) (interface{}, error) { return func(j *jwt.Token) (interface{}, error) { keyID, ok := j.Header[HeaderKeyID].(string) @@ -297,5 +322,11 @@ func (a *licenseAPI) delete(rw http.ResponseWriter, r *http.Request) { }) return } + + err = a.pubsub.Publish(PubSubEventLicenses, []byte("delete")) + if err != nil { + a.logger.Error(context.Background(), "failed to publish license delete", slog.Error(err)) + // don't fail the HTTP request, since we did write it successfully to the database + } rw.WriteHeader(http.StatusOK) }