Skip to content

Commit 288df75

Browse files
authored
fix: serialize updateEntitlements() (#14974)
fixes #14961 Adding the license and updating entitlements is flaky, especially at the start of our `coderdent` testing because, while the actual modifications to the `entitlements.Set` were threadsafe, we could have multiple goroutines reading from the database and writing to the set, so we could end up writing stale data. This enforces serialization on updates, so that if you modify the database and kick off an update, you know the state of the `Set` is at least as fresh as your database update.
1 parent ea3b13c commit 288df75

File tree

7 files changed

+304
-208
lines changed

7 files changed

+304
-208
lines changed

coderd/entitlements/entitlements.go

+55-9
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,31 @@
11
package entitlements
22

33
import (
4+
"context"
45
"encoding/json"
56
"net/http"
67
"sync"
78
"time"
89

10+
"golang.org/x/exp/slices"
11+
"golang.org/x/xerrors"
12+
913
"github.com/coder/coder/v2/codersdk"
1014
)
1115

1216
type Set struct {
1317
entitlementsMu sync.RWMutex
1418
entitlements codersdk.Entitlements
19+
// right2Update works like a semaphore. Reading from the chan gives the right to update the set,
20+
// and you send on the chan when you are done. We only allow one simultaneous update, so this
21+
// serve to serialize them. You MUST NOT attempt to read from this channel while holding the
22+
// entitlementsMu lock. It is permissible to acquire the entitlementsMu lock while holding the
23+
// right2Update token.
24+
right2Update chan struct{}
1525
}
1626

1727
func New() *Set {
18-
return &Set{
28+
s := &Set{
1929
// Some defaults for an unlicensed instance.
2030
// These will be updated when coderd is initialized.
2131
entitlements: codersdk.Entitlements{
@@ -27,7 +37,44 @@ func New() *Set {
2737
RequireTelemetry: false,
2838
RefreshedAt: time.Time{},
2939
},
40+
right2Update: make(chan struct{}, 1),
3041
}
42+
s.right2Update <- struct{}{} // one token, serialized updates
43+
return s
44+
}
45+
46+
// ErrLicenseRequiresTelemetry is an error returned by a fetch passed to Update to indicate that the
47+
// fetched license cannot be used because it requires telemetry.
48+
var ErrLicenseRequiresTelemetry = xerrors.New("License requires telemetry but telemetry is disabled")
49+
50+
func (l *Set) Update(ctx context.Context, fetch func(context.Context) (codersdk.Entitlements, error)) error {
51+
select {
52+
case <-ctx.Done():
53+
return ctx.Err()
54+
case <-l.right2Update:
55+
defer func() {
56+
l.right2Update <- struct{}{}
57+
}()
58+
}
59+
ents, err := fetch(ctx)
60+
if xerrors.Is(err, ErrLicenseRequiresTelemetry) {
61+
// We can't fail because then the user couldn't remove the offending
62+
// license w/o a restart.
63+
//
64+
// We don't simply append to entitlement.Errors since we don't want any
65+
// enterprise features enabled.
66+
l.Modify(func(entitlements *codersdk.Entitlements) {
67+
entitlements.Errors = []string{err.Error()}
68+
})
69+
return nil
70+
}
71+
if err != nil {
72+
return err
73+
}
74+
l.entitlementsMu.Lock()
75+
defer l.entitlementsMu.Unlock()
76+
l.entitlements = ents
77+
return nil
3178
}
3279

3380
// AllowRefresh returns whether the entitlements are allowed to be refreshed.
@@ -74,14 +121,7 @@ func (l *Set) AsJSON() json.RawMessage {
74121
return b
75122
}
76123

77-
func (l *Set) Replace(entitlements codersdk.Entitlements) {
78-
l.entitlementsMu.Lock()
79-
defer l.entitlementsMu.Unlock()
80-
81-
l.entitlements = entitlements
82-
}
83-
84-
func (l *Set) Update(do func(entitlements *codersdk.Entitlements)) {
124+
func (l *Set) Modify(do func(entitlements *codersdk.Entitlements)) {
85125
l.entitlementsMu.Lock()
86126
defer l.entitlementsMu.Unlock()
87127

@@ -107,3 +147,9 @@ func (l *Set) WriteEntitlementWarningHeaders(header http.Header) {
107147
header.Add(codersdk.EntitlementsWarningHeader, warning)
108148
}
109149
}
150+
151+
func (l *Set) Errors() []string {
152+
l.entitlementsMu.RLock()
153+
defer l.entitlementsMu.RUnlock()
154+
return slices.Clone(l.entitlements.Errors)
155+
}
+73-12
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,24 @@
11
package entitlements_test
22

33
import (
4+
"context"
45
"testing"
56
"time"
67

78
"github.com/stretchr/testify/require"
89

910
"github.com/coder/coder/v2/coderd/entitlements"
1011
"github.com/coder/coder/v2/codersdk"
12+
"github.com/coder/coder/v2/testutil"
1113
)
1214

13-
func TestUpdate(t *testing.T) {
15+
func TestModify(t *testing.T) {
1416
t.Parallel()
1517

1618
set := entitlements.New()
1719
require.False(t, set.Enabled(codersdk.FeatureMultipleOrganizations))
1820

19-
set.Update(func(entitlements *codersdk.Entitlements) {
21+
set.Modify(func(entitlements *codersdk.Entitlements) {
2022
entitlements.Features[codersdk.FeatureMultipleOrganizations] = codersdk.Feature{
2123
Enabled: true,
2224
Entitlement: codersdk.EntitlementEntitled,
@@ -30,15 +32,15 @@ func TestAllowRefresh(t *testing.T) {
3032

3133
now := time.Now()
3234
set := entitlements.New()
33-
set.Update(func(entitlements *codersdk.Entitlements) {
35+
set.Modify(func(entitlements *codersdk.Entitlements) {
3436
entitlements.RefreshedAt = now
3537
})
3638

3739
ok, wait := set.AllowRefresh(now)
3840
require.False(t, ok)
3941
require.InDelta(t, time.Minute.Seconds(), wait.Seconds(), 5)
4042

41-
set.Update(func(entitlements *codersdk.Entitlements) {
43+
set.Modify(func(entitlements *codersdk.Entitlements) {
4244
entitlements.RefreshedAt = now.Add(time.Minute * -2)
4345
})
4446

@@ -47,17 +49,76 @@ func TestAllowRefresh(t *testing.T) {
4749
require.Equal(t, time.Duration(0), wait)
4850
}
4951

50-
func TestReplace(t *testing.T) {
52+
func TestUpdate(t *testing.T) {
5153
t.Parallel()
54+
ctx := testutil.Context(t, testutil.WaitShort)
5255

5356
set := entitlements.New()
5457
require.False(t, set.Enabled(codersdk.FeatureMultipleOrganizations))
55-
set.Replace(codersdk.Entitlements{
56-
Features: map[codersdk.FeatureName]codersdk.Feature{
57-
codersdk.FeatureMultipleOrganizations: {
58-
Enabled: true,
59-
},
60-
},
61-
})
58+
fetchStarted := make(chan struct{})
59+
firstDone := make(chan struct{})
60+
errCh := make(chan error, 2)
61+
go func() {
62+
err := set.Update(ctx, func(_ context.Context) (codersdk.Entitlements, error) {
63+
close(fetchStarted)
64+
select {
65+
case <-firstDone:
66+
// OK!
67+
case <-ctx.Done():
68+
t.Error("timeout")
69+
return codersdk.Entitlements{}, ctx.Err()
70+
}
71+
return codersdk.Entitlements{
72+
Features: map[codersdk.FeatureName]codersdk.Feature{
73+
codersdk.FeatureMultipleOrganizations: {
74+
Enabled: true,
75+
},
76+
},
77+
}, nil
78+
})
79+
errCh <- err
80+
}()
81+
testutil.RequireRecvCtx(ctx, t, fetchStarted)
82+
require.False(t, set.Enabled(codersdk.FeatureMultipleOrganizations))
83+
// start a second update while the first one is in progress
84+
go func() {
85+
err := set.Update(ctx, func(_ context.Context) (codersdk.Entitlements, error) {
86+
return codersdk.Entitlements{
87+
Features: map[codersdk.FeatureName]codersdk.Feature{
88+
codersdk.FeatureMultipleOrganizations: {
89+
Enabled: true,
90+
},
91+
codersdk.FeatureAppearance: {
92+
Enabled: true,
93+
},
94+
},
95+
}, nil
96+
})
97+
errCh <- err
98+
}()
99+
close(firstDone)
100+
err := testutil.RequireRecvCtx(ctx, t, errCh)
101+
require.NoError(t, err)
102+
err = testutil.RequireRecvCtx(ctx, t, errCh)
103+
require.NoError(t, err)
62104
require.True(t, set.Enabled(codersdk.FeatureMultipleOrganizations))
105+
require.True(t, set.Enabled(codersdk.FeatureAppearance))
106+
}
107+
108+
func TestUpdate_LicenseRequiresTelemetry(t *testing.T) {
109+
t.Parallel()
110+
ctx := testutil.Context(t, testutil.WaitShort)
111+
set := entitlements.New()
112+
set.Modify(func(entitlements *codersdk.Entitlements) {
113+
entitlements.Errors = []string{"some error"}
114+
entitlements.Features[codersdk.FeatureAppearance] = codersdk.Feature{
115+
Enabled: true,
116+
}
117+
})
118+
err := set.Update(ctx, func(_ context.Context) (codersdk.Entitlements, error) {
119+
return codersdk.Entitlements{}, entitlements.ErrLicenseRequiresTelemetry
120+
})
121+
require.NoError(t, err)
122+
require.True(t, set.Enabled(codersdk.FeatureAppearance))
123+
require.Equal(t, []string{entitlements.ErrLicenseRequiresTelemetry.Error()}, set.Errors())
63124
}

0 commit comments

Comments
 (0)