Skip to content

Commit 9a52b4b

Browse files
committed
fix(coderd/rbac): do not cache context cancellation errors
1 parent 52c08a9 commit 9a52b4b

File tree

2 files changed

+54
-6
lines changed

2 files changed

+54
-6
lines changed

coderd/rbac/authz.go

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"crypto/sha256"
66
_ "embed"
77
"encoding/json"
8+
"errors"
89
"strings"
910
"sync"
1011
"time"
@@ -653,10 +654,10 @@ type authCache struct {
653654
authz Authorizer
654655
}
655656

656-
// Cacher returns an Authorizer that can use a cache stored on a context
657-
// to short circuit duplicate calls to the Authorizer. This is useful when
658-
// multiple calls are made to the Authorizer for the same subject, action, and
659-
// object. The cache is on each `ctx` and is not shared between requests.
657+
// Cacher returns an Authorizer that can use a cache to short circuit duplicate
658+
// calls to the Authorizer. This is useful when multiple calls are made to the
659+
// Authorizer for the same subject, action, and object.
660+
// This is a GLOBAL cache shared between all requests.
660661
// If no cache is found on the context, the Authorizer is called as normal.
661662
//
662663
// Cacher is safe for multiple actors.
@@ -676,8 +677,12 @@ func (c *authCache) Authorize(ctx context.Context, subject Subject, action Actio
676677
err, _, ok := c.cache.Get(authorizeCacheKey)
677678
if !ok {
678679
err = c.authz.Authorize(ctx, subject, action, object)
679-
// In case there is a caching bug, bound the TTL to 1 minute.
680-
c.cache.Set(authorizeCacheKey, err, time.Minute)
680+
// If there is a transient error such as a context cancellation, do not
681+
// cache it.
682+
if !errors.Is(err, context.Canceled) {
683+
// In case there is a caching bug, bound the TTL to 1 minute.
684+
c.cache.Set(authorizeCacheKey, err, time.Minute)
685+
}
681686
}
682687

683688
return err

coderd/rbac/authz_test.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,12 @@ import (
77

88
"github.com/google/uuid"
99
"github.com/prometheus/client_golang/prometheus"
10+
"github.com/stretchr/testify/assert"
1011
"github.com/stretchr/testify/require"
1112

1213
"github.com/coder/coder/v2/coderd/coderdtest"
1314
"github.com/coder/coder/v2/coderd/rbac"
15+
"github.com/coder/coder/v2/testutil"
1416
)
1517

1618
type benchmarkCase struct {
@@ -351,6 +353,47 @@ func TestCacher(t *testing.T) {
351353
require.NoError(t, rec.AllAsserted(), "all assertions should have been made")
352354
})
353355

356+
t.Run("DontCacheTransientErrors", func(t *testing.T) {
357+
t.Parallel()
358+
359+
var (
360+
ctx = testutil.Context(t, testutil.WaitShort)
361+
authOut = make(chan error, 1) // buffered to not block
362+
authorizeFunc = func(ctx context.Context, subject rbac.Subject, action rbac.Action, object rbac.Object) error {
363+
// Just return what you're told.
364+
return testutil.RequireRecvCtx(ctx, t, authOut)
365+
}
366+
ma = &rbac.MockAuthorizer{AuthorizeFunc: authorizeFunc}
367+
rec = &coderdtest.RecordingAuthorizer{Wrapped: ma}
368+
authz = rbac.Cacher(rec)
369+
subj, obj, action = coderdtest.RandomRBACSubject(), coderdtest.RandomRBACObject(), coderdtest.RandomRBACAction()
370+
)
371+
372+
// First call will result in a transient error. This should not be cached.
373+
testutil.RequireSendCtx(ctx, t, authOut, context.Canceled)
374+
err := authz.Authorize(ctx, subj, action, obj)
375+
assert.ErrorIs(t, err, context.Canceled)
376+
377+
// A subsequent call should still hit the authorizer.
378+
testutil.RequireSendCtx(ctx, t, authOut, nil)
379+
err = authz.Authorize(ctx, subj, action, obj)
380+
assert.NoError(t, err)
381+
// This should be cached and not hit the wrapped authorizer again.
382+
err = authz.Authorize(ctx, subj, action, obj)
383+
assert.NoError(t, err)
384+
385+
// Let's change the subject.
386+
subj, obj, action = coderdtest.RandomRBACSubject(), coderdtest.RandomRBACObject(), coderdtest.RandomRBACAction()
387+
388+
// A third will be a legit error
389+
testutil.RequireSendCtx(ctx, t, authOut, assert.AnError)
390+
err = authz.Authorize(ctx, subj, action, obj)
391+
assert.EqualError(t, err, assert.AnError.Error())
392+
// This should be cached and not hit the wrapped authorizer again.
393+
err = authz.Authorize(ctx, subj, action, obj)
394+
assert.EqualError(t, err, assert.AnError.Error())
395+
})
396+
354397
t.Run("MultipleSubjects", func(t *testing.T) {
355398
t.Parallel()
356399

0 commit comments

Comments
 (0)