Skip to content

Commit af125c3

Browse files
authored
chore: refactor entitlements to be a safe object to use (coder#14406)
* chore: refactor entitlements to be passable as an argument Previously, all usage of entitlements requires mutex usage on the api struct directly. This prevents passing the entitlements to a sub package. It also creates the possibility for misuse.
1 parent cb6a472 commit af125c3

17 files changed

+247
-124
lines changed

coderd/coderd.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ import (
3737
"tailscale.com/util/singleflight"
3838

3939
"cdr.dev/slog"
40+
"github.com/coder/coder/v2/coderd/entitlements"
4041
"github.com/coder/quartz"
4142
"github.com/coder/serpent"
4243

@@ -157,6 +158,9 @@ type Options struct {
157158
TrialGenerator func(ctx context.Context, body codersdk.LicensorTrialRequest) error
158159
// RefreshEntitlements is used to set correct entitlements after creating first user and generating trial license.
159160
RefreshEntitlements func(ctx context.Context) error
161+
// Entitlements can come from the enterprise caller if enterprise code is
162+
// included.
163+
Entitlements *entitlements.Set
160164
// PostAuthAdditionalHeadersFunc is used to add additional headers to the response
161165
// after a successful authentication.
162166
// This is somewhat janky, but seemingly the only reasonable way to add a header
@@ -263,6 +267,9 @@ func New(options *Options) *API {
263267
if options == nil {
264268
options = &Options{}
265269
}
270+
if options.Entitlements == nil {
271+
options.Entitlements = entitlements.New()
272+
}
266273
if options.NewTicker == nil {
267274
options.NewTicker = func(duration time.Duration) (tick <-chan time.Time, done func()) {
268275
ticker := time.NewTicker(duration)
@@ -500,6 +507,7 @@ func New(options *Options) *API {
500507
DocsURL: options.DeploymentValues.DocsURL.String(),
501508
AppearanceFetcher: &api.AppearanceFetcher,
502509
BuildInfo: buildInfo,
510+
Entitlements: options.Entitlements,
503511
})
504512
api.SiteHandler.Experiments.Store(&experiments)
505513

coderd/entitlements/entitlements.go

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
package entitlements
2+
3+
import (
4+
"encoding/json"
5+
"net/http"
6+
"sync"
7+
"time"
8+
9+
"github.com/coder/coder/v2/codersdk"
10+
)
11+
12+
type Set struct {
13+
entitlementsMu sync.RWMutex
14+
entitlements codersdk.Entitlements
15+
}
16+
17+
func New() *Set {
18+
return &Set{
19+
// Some defaults for an unlicensed instance.
20+
// These will be updated when coderd is initialized.
21+
entitlements: codersdk.Entitlements{
22+
Features: map[codersdk.FeatureName]codersdk.Feature{},
23+
Warnings: nil,
24+
Errors: nil,
25+
HasLicense: false,
26+
Trial: false,
27+
RequireTelemetry: false,
28+
RefreshedAt: time.Time{},
29+
},
30+
}
31+
}
32+
33+
// AllowRefresh returns whether the entitlements are allowed to be refreshed.
34+
// If it returns false, that means it was recently refreshed and the caller should
35+
// wait the returned duration before trying again.
36+
func (l *Set) AllowRefresh(now time.Time) (bool, time.Duration) {
37+
l.entitlementsMu.RLock()
38+
defer l.entitlementsMu.RUnlock()
39+
40+
diff := now.Sub(l.entitlements.RefreshedAt)
41+
if diff < time.Minute {
42+
return false, time.Minute - diff
43+
}
44+
45+
return true, 0
46+
}
47+
48+
func (l *Set) Feature(name codersdk.FeatureName) (codersdk.Feature, bool) {
49+
l.entitlementsMu.RLock()
50+
defer l.entitlementsMu.RUnlock()
51+
52+
f, ok := l.entitlements.Features[name]
53+
return f, ok
54+
}
55+
56+
func (l *Set) Enabled(feature codersdk.FeatureName) bool {
57+
l.entitlementsMu.RLock()
58+
defer l.entitlementsMu.RUnlock()
59+
60+
f, ok := l.entitlements.Features[feature]
61+
if !ok {
62+
return false
63+
}
64+
return f.Enabled
65+
}
66+
67+
// AsJSON is used to return this to the api without exposing the entitlements for
68+
// mutation.
69+
func (l *Set) AsJSON() json.RawMessage {
70+
l.entitlementsMu.RLock()
71+
defer l.entitlementsMu.RUnlock()
72+
73+
b, _ := json.Marshal(l.entitlements)
74+
return b
75+
}
76+
77+
func (l *Set) Replace(entitlements codersdk.Entitlements) {
78+
l.entitlementsMu.Lock()
79+
defer l.entitlementsMu.Unlock()
80+
81+
l.entitlements = entitlements
82+
}
83+
84+
func (l *Set) Update(do func(entitlements *codersdk.Entitlements)) {
85+
l.entitlementsMu.Lock()
86+
defer l.entitlementsMu.Unlock()
87+
88+
do(&l.entitlements)
89+
}
90+
91+
func (l *Set) FeatureChanged(featureName codersdk.FeatureName, newFeature codersdk.Feature) (initial, changed, enabled bool) {
92+
l.entitlementsMu.RLock()
93+
defer l.entitlementsMu.RUnlock()
94+
95+
oldFeature := l.entitlements.Features[featureName]
96+
if oldFeature.Enabled != newFeature.Enabled {
97+
return false, true, newFeature.Enabled
98+
}
99+
return false, false, newFeature.Enabled
100+
}
101+
102+
func (l *Set) WriteEntitlementWarningHeaders(header http.Header) {
103+
l.entitlementsMu.RLock()
104+
defer l.entitlementsMu.RUnlock()
105+
106+
for _, warning := range l.entitlements.Warnings {
107+
header.Add(codersdk.EntitlementsWarningHeader, warning)
108+
}
109+
}
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
package entitlements_test
2+
3+
import (
4+
"testing"
5+
"time"
6+
7+
"github.com/stretchr/testify/require"
8+
9+
"github.com/coder/coder/v2/coderd/entitlements"
10+
"github.com/coder/coder/v2/codersdk"
11+
)
12+
13+
func TestUpdate(t *testing.T) {
14+
t.Parallel()
15+
16+
set := entitlements.New()
17+
require.False(t, set.Enabled(codersdk.FeatureMultipleOrganizations))
18+
19+
set.Update(func(entitlements *codersdk.Entitlements) {
20+
entitlements.Features[codersdk.FeatureMultipleOrganizations] = codersdk.Feature{
21+
Enabled: true,
22+
Entitlement: codersdk.EntitlementEntitled,
23+
}
24+
})
25+
require.True(t, set.Enabled(codersdk.FeatureMultipleOrganizations))
26+
}
27+
28+
func TestAllowRefresh(t *testing.T) {
29+
t.Parallel()
30+
31+
now := time.Now()
32+
set := entitlements.New()
33+
set.Update(func(entitlements *codersdk.Entitlements) {
34+
entitlements.RefreshedAt = now
35+
})
36+
37+
ok, wait := set.AllowRefresh(now)
38+
require.False(t, ok)
39+
require.InDelta(t, time.Minute.Seconds(), wait.Seconds(), 5)
40+
41+
set.Update(func(entitlements *codersdk.Entitlements) {
42+
entitlements.RefreshedAt = now.Add(time.Minute * -2)
43+
})
44+
45+
ok, wait = set.AllowRefresh(now)
46+
require.True(t, ok)
47+
require.Equal(t, time.Duration(0), wait)
48+
}
49+
50+
func TestReplace(t *testing.T) {
51+
t.Parallel()
52+
53+
set := entitlements.New()
54+
require.False(t, set.Enabled(codersdk.FeatureMultipleOrganizations))
55+
set.Replace(codersdk.Entitlements{
56+
Features: map[codersdk.FeatureName]codersdk.Feature{
57+
codersdk.FeatureMultipleOrganizations: {
58+
Enabled: true,
59+
},
60+
},
61+
})
62+
require.True(t, set.Enabled(codersdk.FeatureMultipleOrganizations))
63+
}

codersdk/deployment.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,12 @@ const (
3535
EntitlementNotEntitled Entitlement = "not_entitled"
3636
)
3737

38+
// Entitled returns if the entitlement can be used. So this is true if it
39+
// is entitled or still in it's grace period.
40+
func (e Entitlement) Entitled() bool {
41+
return e == EntitlementEntitled || e == EntitlementGracePeriod
42+
}
43+
3844
// Weight converts the enum types to a numerical value for easier
3945
// comparisons. Easier than sets of if statements.
4046
func (e Entitlement) Weight() int {

enterprise/coderd/coderd.go

Lines changed: 30 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"github.com/coder/coder/v2/buildinfo"
1616
"github.com/coder/coder/v2/coderd/appearance"
1717
"github.com/coder/coder/v2/coderd/database"
18+
"github.com/coder/coder/v2/coderd/entitlements"
1819
agplportsharing "github.com/coder/coder/v2/coderd/portsharing"
1920
"github.com/coder/coder/v2/coderd/rbac/policy"
2021
"github.com/coder/coder/v2/enterprise/coderd/portsharing"
@@ -103,19 +104,26 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
103104
}
104105
return nil, xerrors.Errorf("init database encryption: %w", err)
105106
}
107+
108+
entitlementsSet := entitlements.New()
106109
options.Database = cryptDB
107110
api := &API{
108-
ctx: ctx,
109-
cancel: cancelFunc,
110-
Options: options,
111+
ctx: ctx,
112+
cancel: cancelFunc,
113+
Options: options,
114+
entitlements: entitlementsSet,
111115
provisionerDaemonAuth: &provisionerDaemonAuth{
112116
psk: options.ProvisionerDaemonPSK,
113117
authorizer: options.Authorizer,
114118
db: options.Database,
115119
},
120+
licenseMetricsCollector: &license.MetricsCollector{
121+
Entitlements: entitlementsSet,
122+
},
116123
}
117124
// This must happen before coderd initialization!
118125
options.PostAuthAdditionalHeadersFunc = api.writeEntitlementWarningsHeader
126+
options.Options.Entitlements = api.entitlements
119127
api.AGPL = coderd.New(options.Options)
120128
defer func() {
121129
if err != nil {
@@ -493,7 +501,7 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
493501
}
494502
api.AGPL.WorkspaceProxiesFetchUpdater.Store(&fetchUpdater)
495503

496-
err = api.PrometheusRegistry.Register(&api.licenseMetricsCollector)
504+
err = api.PrometheusRegistry.Register(api.licenseMetricsCollector)
497505
if err != nil {
498506
return nil, xerrors.Errorf("unable to register license metrics collector")
499507
}
@@ -553,13 +561,11 @@ type API struct {
553561
// ProxyHealth checks the reachability of all workspace proxies.
554562
ProxyHealth *proxyhealth.ProxyHealth
555563

556-
entitlementsUpdateMu sync.Mutex
557-
entitlementsMu sync.RWMutex
558-
entitlements codersdk.Entitlements
564+
entitlements *entitlements.Set
559565

560566
provisionerDaemonAuth *provisionerDaemonAuth
561567

562-
licenseMetricsCollector license.MetricsCollector
568+
licenseMetricsCollector *license.MetricsCollector
563569
tailnetService *tailnet.ClientService
564570
}
565571

@@ -588,11 +594,8 @@ func (api *API) writeEntitlementWarningsHeader(a rbac.Subject, header http.Heade
588594
// has no roles. This is a normal user!
589595
return
590596
}
591-
api.entitlementsMu.RLock()
592-
defer api.entitlementsMu.RUnlock()
593-
for _, warning := range api.entitlements.Warnings {
594-
header.Add(codersdk.EntitlementsWarningHeader, warning)
595-
}
597+
598+
api.entitlements.WriteEntitlementWarningHeaders(header)
596599
}
597600

598601
func (api *API) Close() error {
@@ -614,9 +617,6 @@ func (api *API) Close() error {
614617
}
615618

616619
func (api *API) updateEntitlements(ctx context.Context) error {
617-
api.entitlementsUpdateMu.Lock()
618-
defer api.entitlementsUpdateMu.Unlock()
619-
620620
replicas := api.replicaManager.AllPrimary()
621621
agedReplicas := make([]database.Replica, 0, len(replicas))
622622
for _, replica := range replicas {
@@ -632,7 +632,7 @@ func (api *API) updateEntitlements(ctx context.Context) error {
632632
agedReplicas = append(agedReplicas, replica)
633633
}
634634

635-
entitlements, err := license.Entitlements(
635+
reloadedEntitlements, err := license.Entitlements(
636636
ctx, api.Database,
637637
len(agedReplicas), len(api.ExternalAuthConfigs), api.LicenseKeys, map[codersdk.FeatureName]bool{
638638
codersdk.FeatureAuditLog: api.AuditLogging,
@@ -652,29 +652,24 @@ func (api *API) updateEntitlements(ctx context.Context) error {
652652
return err
653653
}
654654

655-
if entitlements.RequireTelemetry && !api.DeploymentValues.Telemetry.Enable.Value() {
655+
if reloadedEntitlements.RequireTelemetry && !api.DeploymentValues.Telemetry.Enable.Value() {
656656
// We can't fail because then the user couldn't remove the offending
657657
// license w/o a restart.
658658
//
659659
// We don't simply append to entitlement.Errors since we don't want any
660660
// enterprise features enabled.
661-
api.entitlements.Errors = []string{
662-
"License requires telemetry but telemetry is disabled",
663-
}
661+
api.entitlements.Update(func(entitlements *codersdk.Entitlements) {
662+
entitlements.Errors = []string{
663+
"License requires telemetry but telemetry is disabled",
664+
}
665+
})
666+
664667
api.Logger.Error(ctx, "license requires telemetry enabled")
665668
return nil
666669
}
667670

668671
featureChanged := func(featureName codersdk.FeatureName) (initial, changed, enabled bool) {
669-
if api.entitlements.Features == nil {
670-
return true, false, entitlements.Features[featureName].Enabled
671-
}
672-
oldFeature := api.entitlements.Features[featureName]
673-
newFeature := entitlements.Features[featureName]
674-
if oldFeature.Enabled != newFeature.Enabled {
675-
return false, true, newFeature.Enabled
676-
}
677-
return false, false, newFeature.Enabled
672+
return api.entitlements.FeatureChanged(featureName, reloadedEntitlements.Features[featureName])
678673
}
679674

680675
shouldUpdate := func(initial, changed, enabled bool) bool {
@@ -831,20 +826,16 @@ func (api *API) updateEntitlements(ctx context.Context) error {
831826
}
832827

833828
// External token encryption is soft-enforced
834-
featureExternalTokenEncryption := entitlements.Features[codersdk.FeatureExternalTokenEncryption]
829+
featureExternalTokenEncryption := reloadedEntitlements.Features[codersdk.FeatureExternalTokenEncryption]
835830
featureExternalTokenEncryption.Enabled = len(api.ExternalTokenEncryption) > 0
836831
if featureExternalTokenEncryption.Enabled && featureExternalTokenEncryption.Entitlement != codersdk.EntitlementEntitled {
837832
msg := fmt.Sprintf("%s is enabled (due to setting external token encryption keys) but your license is not entitled to this feature.", codersdk.FeatureExternalTokenEncryption.Humanize())
838833
api.Logger.Warn(ctx, msg)
839-
entitlements.Warnings = append(entitlements.Warnings, msg)
834+
reloadedEntitlements.Warnings = append(reloadedEntitlements.Warnings, msg)
840835
}
841-
entitlements.Features[codersdk.FeatureExternalTokenEncryption] = featureExternalTokenEncryption
836+
reloadedEntitlements.Features[codersdk.FeatureExternalTokenEncryption] = featureExternalTokenEncryption
842837

843-
api.entitlementsMu.Lock()
844-
defer api.entitlementsMu.Unlock()
845-
api.entitlements = entitlements
846-
api.licenseMetricsCollector.Entitlements.Store(&entitlements)
847-
api.AGPL.SiteHandler.Entitlements.Store(&entitlements)
838+
api.entitlements.Replace(reloadedEntitlements)
848839
return nil
849840
}
850841

@@ -1024,10 +1015,7 @@ func derpMapper(logger slog.Logger, proxyHealth *proxyhealth.ProxyHealth) func(*
10241015
// @Router /entitlements [get]
10251016
func (api *API) serveEntitlements(rw http.ResponseWriter, r *http.Request) {
10261017
ctx := r.Context()
1027-
api.entitlementsMu.RLock()
1028-
entitlements := api.entitlements
1029-
api.entitlementsMu.RUnlock()
1030-
httpapi.Write(ctx, rw, http.StatusOK, entitlements)
1018+
httpapi.Write(ctx, rw, http.StatusOK, api.entitlements.AsJSON())
10311019
}
10321020

10331021
func (api *API) runEntitlementsLoop(ctx context.Context) {

0 commit comments

Comments
 (0)