Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
chore: refactor entitlements to be a safe object to use
  • Loading branch information
Emyrk committed Aug 22, 2024
commit 33b819940eaac9de3bfe58d55b83869fae80f166
58 changes: 21 additions & 37 deletions enterprise/coderd/coderd.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/coder/coder/v2/coderd/database"
agplportsharing "github.com/coder/coder/v2/coderd/portsharing"
"github.com/coder/coder/v2/coderd/rbac/policy"
"github.com/coder/coder/v2/enterprise/coderd/entitlements"
"github.com/coder/coder/v2/enterprise/coderd/portsharing"

"golang.org/x/xerrors"
Expand Down Expand Up @@ -553,9 +554,7 @@ type API struct {
// ProxyHealth checks the reachability of all workspace proxies.
ProxyHealth *proxyhealth.ProxyHealth

entitlementsUpdateMu sync.Mutex
entitlementsMu sync.RWMutex
entitlements codersdk.Entitlements
entitlements *entitlements.Set

provisionerDaemonAuth *provisionerDaemonAuth

Expand Down Expand Up @@ -588,11 +587,8 @@ func (api *API) writeEntitlementWarningsHeader(a rbac.Subject, header http.Heade
// has no roles. This is a normal user!
return
}
api.entitlementsMu.RLock()
defer api.entitlementsMu.RUnlock()
for _, warning := range api.entitlements.Warnings {
header.Add(codersdk.EntitlementsWarningHeader, warning)
}

api.entitlements.WriteEntitlementWarningHeaders(header)
}

func (api *API) Close() error {
Expand All @@ -614,9 +610,6 @@ func (api *API) Close() error {
}

func (api *API) updateEntitlements(ctx context.Context) error {
api.entitlementsUpdateMu.Lock()
defer api.entitlementsUpdateMu.Unlock()

replicas := api.replicaManager.AllPrimary()
agedReplicas := make([]database.Replica, 0, len(replicas))
for _, replica := range replicas {
Expand All @@ -632,7 +625,7 @@ func (api *API) updateEntitlements(ctx context.Context) error {
agedReplicas = append(agedReplicas, replica)
}

entitlements, err := license.Entitlements(
reloadedEntitlements, err := license.Entitlements(
ctx, api.Database,
len(agedReplicas), len(api.ExternalAuthConfigs), api.LicenseKeys, map[codersdk.FeatureName]bool{
codersdk.FeatureAuditLog: api.AuditLogging,
Expand All @@ -652,29 +645,24 @@ func (api *API) updateEntitlements(ctx context.Context) error {
return err
}

if entitlements.RequireTelemetry && !api.DeploymentValues.Telemetry.Enable.Value() {
if reloadedEntitlements.RequireTelemetry && !api.DeploymentValues.Telemetry.Enable.Value() {
// We can't fail because then the user couldn't remove the offending
// license w/o a restart.
//
// We don't simply append to entitlement.Errors since we don't want any
// enterprise features enabled.
api.entitlements.Errors = []string{
"License requires telemetry but telemetry is disabled",
}
api.entitlements.Update(func(entitlements *codersdk.Entitlements) {
entitlements.Errors = []string{
"License requires telemetry but telemetry is disabled",
}
})

api.Logger.Error(ctx, "license requires telemetry enabled")
return nil
}

featureChanged := func(featureName codersdk.FeatureName) (initial, changed, enabled bool) {
if api.entitlements.Features == nil {
return true, false, entitlements.Features[featureName].Enabled
}
oldFeature := api.entitlements.Features[featureName]
newFeature := entitlements.Features[featureName]
if oldFeature.Enabled != newFeature.Enabled {
return false, true, newFeature.Enabled
}
return false, false, newFeature.Enabled
return api.entitlements.FeatureChanged(featureName, reloadedEntitlements.Features[featureName])
}

shouldUpdate := func(initial, changed, enabled bool) bool {
Expand Down Expand Up @@ -831,20 +819,18 @@ func (api *API) updateEntitlements(ctx context.Context) error {
}

// External token encryption is soft-enforced
featureExternalTokenEncryption := entitlements.Features[codersdk.FeatureExternalTokenEncryption]
featureExternalTokenEncryption := reloadedEntitlements.Features[codersdk.FeatureExternalTokenEncryption]
featureExternalTokenEncryption.Enabled = len(api.ExternalTokenEncryption) > 0
if featureExternalTokenEncryption.Enabled && featureExternalTokenEncryption.Entitlement != codersdk.EntitlementEntitled {
msg := fmt.Sprintf("%s is enabled (due to setting external token encryption keys) but your license is not entitled to this feature.", codersdk.FeatureExternalTokenEncryption.Humanize())
api.Logger.Warn(ctx, msg)
entitlements.Warnings = append(entitlements.Warnings, msg)
reloadedEntitlements.Warnings = append(reloadedEntitlements.Warnings, msg)
}
entitlements.Features[codersdk.FeatureExternalTokenEncryption] = featureExternalTokenEncryption
reloadedEntitlements.Features[codersdk.FeatureExternalTokenEncryption] = featureExternalTokenEncryption

api.entitlementsMu.Lock()
defer api.entitlementsMu.Unlock()
api.entitlements = entitlements
api.licenseMetricsCollector.Entitlements.Store(&entitlements)
api.AGPL.SiteHandler.Entitlements.Store(&entitlements)
api.entitlements.Replace(reloadedEntitlements)
api.licenseMetricsCollector.Entitlements.Store(&reloadedEntitlements)
api.AGPL.SiteHandler.Entitlements.Store(&reloadedEntitlements)
return nil
}

Expand Down Expand Up @@ -1024,10 +1010,8 @@ func derpMapper(logger slog.Logger, proxyHealth *proxyhealth.ProxyHealth) func(*
// @Router /entitlements [get]
func (api *API) serveEntitlements(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
api.entitlementsMu.RLock()
entitlements := api.entitlements
api.entitlementsMu.RUnlock()
httpapi.Write(ctx, rw, http.StatusOK, entitlements)
// TODO: verify this works
httpapi.Write(ctx, rw, http.StatusOK, api.entitlements.AsJSON())
}

func (api *API) runEntitlementsLoop(ctx context.Context) {
Expand Down
83 changes: 83 additions & 0 deletions enterprise/coderd/entitlements/entitlements.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package entitlements

import (
"encoding/json"
"net/http"
"sync"

"github.com/coder/coder/v2/codersdk"
)

type Set struct {
entitlementsMu sync.RWMutex
entitlements codersdk.Entitlements
}

func New(entitlements codersdk.Entitlements) *Set {
return &Set{
entitlements: entitlements,
}
}

func (l *Set) Replace(entitlements codersdk.Entitlements) {
l.entitlementsMu.Lock()
defer l.entitlementsMu.Unlock()

l.entitlements = entitlements
}

func (l *Set) Feature(name codersdk.FeatureName) (codersdk.Feature, bool) {
l.entitlementsMu.RLock()
defer l.entitlementsMu.RUnlock()

f, ok := l.entitlements.Features[name]
return f, ok
}

func (l *Set) Enabled(feature codersdk.FeatureName) bool {
l.entitlementsMu.Lock()
defer l.entitlementsMu.Unlock()

f, ok := l.entitlements.Features[feature]
if !ok {
return false
}
return f.Enabled
}

// AsJSON is used to return this to the api without exposing the entitlements for
// mutation.
func (l *Set) AsJSON() json.RawMessage {
l.entitlementsMu.Lock()
defer l.entitlementsMu.Unlock()

b, _ := json.Marshal(l.entitlements)
return b
}

func (l *Set) Update(do func(entitlements *codersdk.Entitlements)) {
l.entitlementsMu.Lock()
defer l.entitlementsMu.Unlock()

do(&l.entitlements)
}

func (l *Set) FeatureChanged(featureName codersdk.FeatureName, newFeature codersdk.Feature) (initial, changed, enabled bool) {
l.entitlementsMu.Lock()
defer l.entitlementsMu.Unlock()

oldFeature := l.entitlements.Features[featureName]
if oldFeature.Enabled != newFeature.Enabled {
return false, true, newFeature.Enabled
}
return false, false, newFeature.Enabled
}

func (l *Set) WriteEntitlementWarningHeaders(header http.Header) {
l.entitlementsMu.RLock()
defer l.entitlementsMu.RUnlock()

for _, warning := range l.entitlements.Warnings {
header.Add(codersdk.EntitlementsWarningHeader, warning)
}
}
12 changes: 3 additions & 9 deletions enterprise/coderd/license/metricscollector.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
package license

import (
"sync/atomic"

"github.com/prometheus/client_golang/prometheus"

"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/enterprise/coderd/entitlements"
)

var (
Expand All @@ -15,7 +14,7 @@ var (
)

type MetricsCollector struct {
Entitlements atomic.Pointer[codersdk.Entitlements]
Entitlements *entitlements.Set
}

var _ prometheus.Collector = new(MetricsCollector)
Expand All @@ -27,12 +26,7 @@ func (*MetricsCollector) Describe(descCh chan<- *prometheus.Desc) {
}

func (mc *MetricsCollector) Collect(metricsCh chan<- prometheus.Metric) {
entitlements := mc.Entitlements.Load()
if entitlements == nil || entitlements.Features == nil {
return
}

userLimitEntitlement, ok := entitlements.Features[codersdk.FeatureUserLimit]
userLimitEntitlement, ok := mc.Entitlements.Feature(codersdk.FeatureUserLimit)
if !ok {
return
}
Expand Down