Skip to content

fix: serialize updateEntitlements() #14974

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 55 additions & 9 deletions coderd/entitlements/entitlements.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,31 @@
package entitlements

import (
"context"
"encoding/json"
"net/http"
"sync"
"time"

"golang.org/x/exp/slices"
"golang.org/x/xerrors"

"github.com/coder/coder/v2/codersdk"
)

type Set struct {
entitlementsMu sync.RWMutex
entitlements codersdk.Entitlements
// right2Update works like a semaphore. Reading from the chan gives the right to update the set,
// and you send on the chan when you are done. We only allow one simultaneous update, so this
// serve to serialize them. You MUST NOT attempt to read from this channel while holding the
// entitlementsMu lock. It is permissible to acquire the entitlementsMu lock while holding the
// right2Update token.
right2Update chan struct{}
}

func New() *Set {
return &Set{
s := &Set{
// Some defaults for an unlicensed instance.
// These will be updated when coderd is initialized.
entitlements: codersdk.Entitlements{
Expand All @@ -27,7 +37,44 @@ func New() *Set {
RequireTelemetry: false,
RefreshedAt: time.Time{},
},
right2Update: make(chan struct{}, 1),
}
s.right2Update <- struct{}{} // one token, serialized updates
return s
}

// ErrLicenseRequiresTelemetry is an error returned by a fetch passed to Update to indicate that the
// fetched license cannot be used because it requires telemetry.
var ErrLicenseRequiresTelemetry = xerrors.New("License requires telemetry but telemetry is disabled")

func (l *Set) Update(ctx context.Context, fetch func(context.Context) (codersdk.Entitlements, error)) error {
select {
case <-ctx.Done():
return ctx.Err()
case <-l.right2Update:
defer func() {
l.right2Update <- struct{}{}
}()
}
ents, err := fetch(ctx)
if xerrors.Is(err, ErrLicenseRequiresTelemetry) {
// We can't fail because then the user couldn't remove the offending
// license w/o a restart.
//
// We don't simply append to entitlement.Errors since we don't want any
// enterprise features enabled.
l.Modify(func(entitlements *codersdk.Entitlements) {
entitlements.Errors = []string{err.Error()}
})
return nil
}
if err != nil {
return err
}
l.entitlementsMu.Lock()
defer l.entitlementsMu.Unlock()
l.entitlements = ents
return nil
}

// AllowRefresh returns whether the entitlements are allowed to be refreshed.
Expand Down Expand Up @@ -74,14 +121,7 @@ func (l *Set) AsJSON() json.RawMessage {
return b
}

func (l *Set) Replace(entitlements codersdk.Entitlements) {
l.entitlementsMu.Lock()
defer l.entitlementsMu.Unlock()

l.entitlements = entitlements
}

func (l *Set) Update(do func(entitlements *codersdk.Entitlements)) {
func (l *Set) Modify(do func(entitlements *codersdk.Entitlements)) {
l.entitlementsMu.Lock()
defer l.entitlementsMu.Unlock()

Expand All @@ -107,3 +147,9 @@ func (l *Set) WriteEntitlementWarningHeaders(header http.Header) {
header.Add(codersdk.EntitlementsWarningHeader, warning)
}
}

func (l *Set) Errors() []string {
l.entitlementsMu.RLock()
defer l.entitlementsMu.RUnlock()
return slices.Clone(l.entitlements.Errors)
}
85 changes: 73 additions & 12 deletions coderd/entitlements/entitlements_test.go
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
package entitlements_test

import (
"context"
"testing"
"time"

"github.com/stretchr/testify/require"

"github.com/coder/coder/v2/coderd/entitlements"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)

func TestUpdate(t *testing.T) {
func TestModify(t *testing.T) {
t.Parallel()

set := entitlements.New()
require.False(t, set.Enabled(codersdk.FeatureMultipleOrganizations))

set.Update(func(entitlements *codersdk.Entitlements) {
set.Modify(func(entitlements *codersdk.Entitlements) {
entitlements.Features[codersdk.FeatureMultipleOrganizations] = codersdk.Feature{
Enabled: true,
Entitlement: codersdk.EntitlementEntitled,
Expand All @@ -30,15 +32,15 @@ func TestAllowRefresh(t *testing.T) {

now := time.Now()
set := entitlements.New()
set.Update(func(entitlements *codersdk.Entitlements) {
set.Modify(func(entitlements *codersdk.Entitlements) {
entitlements.RefreshedAt = now
})

ok, wait := set.AllowRefresh(now)
require.False(t, ok)
require.InDelta(t, time.Minute.Seconds(), wait.Seconds(), 5)

set.Update(func(entitlements *codersdk.Entitlements) {
set.Modify(func(entitlements *codersdk.Entitlements) {
entitlements.RefreshedAt = now.Add(time.Minute * -2)
})

Expand All @@ -47,17 +49,76 @@ func TestAllowRefresh(t *testing.T) {
require.Equal(t, time.Duration(0), wait)
}

func TestReplace(t *testing.T) {
func TestUpdate(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)

set := entitlements.New()
require.False(t, set.Enabled(codersdk.FeatureMultipleOrganizations))
set.Replace(codersdk.Entitlements{
Features: map[codersdk.FeatureName]codersdk.Feature{
codersdk.FeatureMultipleOrganizations: {
Enabled: true,
},
},
})
fetchStarted := make(chan struct{})
firstDone := make(chan struct{})
errCh := make(chan error, 2)
go func() {
err := set.Update(ctx, func(_ context.Context) (codersdk.Entitlements, error) {
close(fetchStarted)
select {
case <-firstDone:
// OK!
case <-ctx.Done():
t.Error("timeout")
return codersdk.Entitlements{}, ctx.Err()
}
return codersdk.Entitlements{
Features: map[codersdk.FeatureName]codersdk.Feature{
codersdk.FeatureMultipleOrganizations: {
Enabled: true,
},
},
}, nil
})
errCh <- err
}()
testutil.RequireRecvCtx(ctx, t, fetchStarted)
require.False(t, set.Enabled(codersdk.FeatureMultipleOrganizations))
// start a second update while the first one is in progress
go func() {
err := set.Update(ctx, func(_ context.Context) (codersdk.Entitlements, error) {
return codersdk.Entitlements{
Features: map[codersdk.FeatureName]codersdk.Feature{
codersdk.FeatureMultipleOrganizations: {
Enabled: true,
},
codersdk.FeatureAppearance: {
Enabled: true,
},
},
}, nil
})
errCh <- err
}()
close(firstDone)
err := testutil.RequireRecvCtx(ctx, t, errCh)
require.NoError(t, err)
err = testutil.RequireRecvCtx(ctx, t, errCh)
require.NoError(t, err)
require.True(t, set.Enabled(codersdk.FeatureMultipleOrganizations))
require.True(t, set.Enabled(codersdk.FeatureAppearance))
}

func TestUpdate_LicenseRequiresTelemetry(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
set := entitlements.New()
set.Modify(func(entitlements *codersdk.Entitlements) {
entitlements.Errors = []string{"some error"}
entitlements.Features[codersdk.FeatureAppearance] = codersdk.Feature{
Enabled: true,
}
})
err := set.Update(ctx, func(_ context.Context) (codersdk.Entitlements, error) {
return codersdk.Entitlements{}, entitlements.ErrLicenseRequiresTelemetry
})
require.NoError(t, err)
require.True(t, set.Enabled(codersdk.FeatureAppearance))
require.Equal(t, []string{entitlements.ErrLicenseRequiresTelemetry.Error()}, set.Errors())
}
Loading
Loading