Skip to content

Commit 61c37d0

Browse files
committed
fix: serialize updateEntitlements()
1 parent 9acf6ac commit 61c37d0

File tree

7 files changed

+304
-208
lines changed

7 files changed

+304
-208
lines changed

coderd/entitlements/entitlements.go

+55-9
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,31 @@
11
package entitlements
22

33
import (
4+
"context"
45
"encoding/json"
56
"net/http"
67
"sync"
78
"time"
89

10+
"golang.org/x/exp/slices"
11+
"golang.org/x/xerrors"
12+
913
"github.com/coder/coder/v2/codersdk"
1014
)
1115

1216
type Set struct {
1317
entitlementsMu sync.RWMutex
1418
entitlements codersdk.Entitlements
19+
// right2Update works like a semaphore. Reading from the chan gives the right to update the set,
20+
// and you send on the chan when you are done. We only allow one simultaneous update, so this
21+
// serve to serialize them. You MUST NOT attempt to read from this channel while holding the
22+
// entitlementsMu lock. It is permissible to acquire the entitlementsMu lock while holding the
23+
// right2Update token.
24+
right2Update chan struct{}
1525
}
1626

1727
func New() *Set {
18-
return &Set{
28+
s := &Set{
1929
// Some defaults for an unlicensed instance.
2030
// These will be updated when coderd is initialized.
2131
entitlements: codersdk.Entitlements{
@@ -27,7 +37,44 @@ func New() *Set {
2737
RequireTelemetry: false,
2838
RefreshedAt: time.Time{},
2939
},
40+
right2Update: make(chan struct{}, 1),
3041
}
42+
s.right2Update <- struct{}{} // one token, serialized updates
43+
return s
44+
}
45+
46+
// ErrLicenseRequiresTelemetry is an error returned by a fetch passed to Update to indicate that the
47+
// fetched license cannot be used because it requires telemetry.
48+
var ErrLicenseRequiresTelemetry = xerrors.New("License requires telemetry but telemetry is disabled")
49+
50+
func (l *Set) Update(ctx context.Context, fetch func(context.Context) (codersdk.Entitlements, error)) error {
51+
select {
52+
case <-ctx.Done():
53+
return ctx.Err()
54+
case <-l.right2Update:
55+
defer func() {
56+
l.right2Update <- struct{}{}
57+
}()
58+
}
59+
ents, err := fetch(ctx)
60+
if xerrors.Is(err, ErrLicenseRequiresTelemetry) {
61+
// We can't fail because then the user couldn't remove the offending
62+
// license w/o a restart.
63+
//
64+
// We don't simply append to entitlement.Errors since we don't want any
65+
// enterprise features enabled.
66+
l.Modify(func(entitlements *codersdk.Entitlements) {
67+
entitlements.Errors = []string{err.Error()}
68+
})
69+
return nil
70+
}
71+
if err != nil {
72+
return err
73+
}
74+
l.entitlementsMu.Lock()
75+
defer l.entitlementsMu.Unlock()
76+
l.entitlements = ents
77+
return nil
3178
}
3279

3380
// AllowRefresh returns whether the entitlements are allowed to be refreshed.
@@ -74,14 +121,7 @@ func (l *Set) AsJSON() json.RawMessage {
74121
return b
75122
}
76123

77-
func (l *Set) Replace(entitlements codersdk.Entitlements) {
78-
l.entitlementsMu.Lock()
79-
defer l.entitlementsMu.Unlock()
80-
81-
l.entitlements = entitlements
82-
}
83-
84-
func (l *Set) Update(do func(entitlements *codersdk.Entitlements)) {
124+
func (l *Set) Modify(do func(entitlements *codersdk.Entitlements)) {
85125
l.entitlementsMu.Lock()
86126
defer l.entitlementsMu.Unlock()
87127

@@ -107,3 +147,9 @@ func (l *Set) WriteEntitlementWarningHeaders(header http.Header) {
107147
header.Add(codersdk.EntitlementsWarningHeader, warning)
108148
}
109149
}
150+
151+
func (l *Set) Errors() []string {
152+
l.entitlementsMu.RLock()
153+
defer l.entitlementsMu.RUnlock()
154+
return slices.Clone(l.entitlements.Errors)
155+
}
+73-12
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,24 @@
11
package entitlements_test
22

33
import (
4+
"context"
45
"testing"
56
"time"
67

78
"github.com/stretchr/testify/require"
89

910
"github.com/coder/coder/v2/coderd/entitlements"
1011
"github.com/coder/coder/v2/codersdk"
12+
"github.com/coder/coder/v2/testutil"
1113
)
1214

13-
func TestUpdate(t *testing.T) {
15+
func TestModify(t *testing.T) {
1416
t.Parallel()
1517

1618
set := entitlements.New()
1719
require.False(t, set.Enabled(codersdk.FeatureMultipleOrganizations))
1820

19-
set.Update(func(entitlements *codersdk.Entitlements) {
21+
set.Modify(func(entitlements *codersdk.Entitlements) {
2022
entitlements.Features[codersdk.FeatureMultipleOrganizations] = codersdk.Feature{
2123
Enabled: true,
2224
Entitlement: codersdk.EntitlementEntitled,
@@ -30,15 +32,15 @@ func TestAllowRefresh(t *testing.T) {
3032

3133
now := time.Now()
3234
set := entitlements.New()
33-
set.Update(func(entitlements *codersdk.Entitlements) {
35+
set.Modify(func(entitlements *codersdk.Entitlements) {
3436
entitlements.RefreshedAt = now
3537
})
3638

3739
ok, wait := set.AllowRefresh(now)
3840
require.False(t, ok)
3941
require.InDelta(t, time.Minute.Seconds(), wait.Seconds(), 5)
4042

41-
set.Update(func(entitlements *codersdk.Entitlements) {
43+
set.Modify(func(entitlements *codersdk.Entitlements) {
4244
entitlements.RefreshedAt = now.Add(time.Minute * -2)
4345
})
4446

@@ -47,17 +49,76 @@ func TestAllowRefresh(t *testing.T) {
4749
require.Equal(t, time.Duration(0), wait)
4850
}
4951

50-
func TestReplace(t *testing.T) {
52+
func TestUpdate(t *testing.T) {
5153
t.Parallel()
54+
ctx := testutil.Context(t, testutil.WaitShort)
5255

5356
set := entitlements.New()
5457
require.False(t, set.Enabled(codersdk.FeatureMultipleOrganizations))
55-
set.Replace(codersdk.Entitlements{
56-
Features: map[codersdk.FeatureName]codersdk.Feature{
57-
codersdk.FeatureMultipleOrganizations: {
58-
Enabled: true,
59-
},
60-
},
61-
})
58+
fetchStarted := make(chan struct{})
59+
firstDone := make(chan struct{})
60+
errCh := make(chan error, 2)
61+
go func() {
62+
err := set.Update(ctx, func(_ context.Context) (codersdk.Entitlements, error) {
63+
close(fetchStarted)
64+
select {
65+
case <-firstDone:
66+
// OK!
67+
case <-ctx.Done():
68+
t.Error("timeout")
69+
return codersdk.Entitlements{}, ctx.Err()
70+
}
71+
return codersdk.Entitlements{
72+
Features: map[codersdk.FeatureName]codersdk.Feature{
73+
codersdk.FeatureMultipleOrganizations: {
74+
Enabled: true,
75+
},
76+
},
77+
}, nil
78+
})
79+
errCh <- err
80+
}()
81+
testutil.RequireRecvCtx(ctx, t, fetchStarted)
82+
require.False(t, set.Enabled(codersdk.FeatureMultipleOrganizations))
83+
// start a second update while the first one is in progress
84+
go func() {
85+
err := set.Update(ctx, func(_ context.Context) (codersdk.Entitlements, error) {
86+
return codersdk.Entitlements{
87+
Features: map[codersdk.FeatureName]codersdk.Feature{
88+
codersdk.FeatureMultipleOrganizations: {
89+
Enabled: true,
90+
},
91+
codersdk.FeatureAppearance: {
92+
Enabled: true,
93+
},
94+
},
95+
}, nil
96+
})
97+
errCh <- err
98+
}()
99+
close(firstDone)
100+
err := testutil.RequireRecvCtx(ctx, t, errCh)
101+
require.NoError(t, err)
102+
err = testutil.RequireRecvCtx(ctx, t, errCh)
103+
require.NoError(t, err)
62104
require.True(t, set.Enabled(codersdk.FeatureMultipleOrganizations))
105+
require.True(t, set.Enabled(codersdk.FeatureAppearance))
106+
}
107+
108+
func TestUpdate_LicenseRequiresTelemetry(t *testing.T) {
109+
t.Parallel()
110+
ctx := testutil.Context(t, testutil.WaitShort)
111+
set := entitlements.New()
112+
set.Modify(func(entitlements *codersdk.Entitlements) {
113+
entitlements.Errors = []string{"some error"}
114+
entitlements.Features[codersdk.FeatureAppearance] = codersdk.Feature{
115+
Enabled: true,
116+
}
117+
})
118+
err := set.Update(ctx, func(_ context.Context) (codersdk.Entitlements, error) {
119+
return codersdk.Entitlements{}, entitlements.ErrLicenseRequiresTelemetry
120+
})
121+
require.NoError(t, err)
122+
require.True(t, set.Enabled(codersdk.FeatureAppearance))
123+
require.Equal(t, []string{entitlements.ErrLicenseRequiresTelemetry.Error()}, set.Errors())
63124
}

0 commit comments

Comments
 (0)