diff --git a/coderd/coderd.go b/coderd/coderd.go index 532eba43bf711..4545f9531c153 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -187,7 +187,7 @@ func New(options *Options) *API { options.PrometheusRegistry = prometheus.NewRegistry() } if options.Authorizer == nil { - options.Authorizer = rbac.NewAuthorizer(options.PrometheusRegistry) + options.Authorizer = rbac.NewCachingAuthorizer(options.PrometheusRegistry) } if options.TailnetCoordinator == nil { options.TailnetCoordinator = tailnet.NewCoordinator() @@ -289,6 +289,7 @@ func New(options *Options) *API { tracing.StatusWriterMiddleware, tracing.Middleware(api.TracerProvider), httpmw.AttachRequestID, + httpmw.AttachAuthzCache, httpmw.ExtractRealIP(api.RealIPConfig), httpmw.Logger(api.Logger), httpmw.Prometheus(options.PrometheusRegistry), diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index de14693313e41..3f444780aaa22 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -184,7 +184,7 @@ func NewOptions(t *testing.T, options *Options) (func(http.Handler), context.Can if strings.Contains(os.Getenv("CODER_EXPERIMENTS_TEST"), string(codersdk.ExperimentAuthzQuerier)) { if options.Authorizer == nil { options.Authorizer = &RecordingAuthorizer{ - Wrapped: rbac.NewAuthorizer(prometheus.NewRegistry()), + Wrapped: rbac.NewCachingAuthorizer(prometheus.NewRegistry()), } } options.Database = dbauthz.New(options.Database, options.Authorizer, slogtest.Make(t, nil).Leveled(slog.LevelDebug)) diff --git a/coderd/httpmw/authz.go b/coderd/httpmw/authz.go index 5bfe69d47c956..9decc3eb31649 100644 --- a/coderd/httpmw/authz.go +++ b/coderd/httpmw/authz.go @@ -3,9 +3,10 @@ package httpmw import ( "net/http" - "github.com/coder/coder/coderd/database/dbauthz" - "github.com/go-chi/chi/v5" + + "github.com/coder/coder/coderd/database/dbauthz" + "github.com/coder/coder/coderd/rbac" ) // AsAuthzSystem is a chained handler that temporarily sets the dbauthz context @@ -35,3 +36,16 @@ func AsAuthzSystem(mws ...func(http.Handler) http.Handler) func(http.Handler) ht }) } } + +// AttachAuthzCache enables the authz cache for the authorizer. All rbac checks will +// run against the cache, meaning duplicate checks will not be performed. +// +// Note the cache is safe for multiple actors. So mixing user and system checks +// is ok. +func AttachAuthzCache(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + ctx := rbac.WithCacheCtx(r.Context()) + + next.ServeHTTP(rw, r.WithContext(ctx)) + }) +} diff --git a/coderd/rbac/authz.go b/coderd/rbac/authz.go index 385feae294ff1..3f2b40bacebad 100644 --- a/coderd/rbac/authz.go +++ b/coderd/rbac/authz.go @@ -158,6 +158,13 @@ var ( partialQuery rego.PreparedPartialQuery ) +// NewCachingAuthorizer returns a new RegoAuthorizer that supports context based +// caching. To utilize the caching, the context passed to Authorize() must be +// created with 'WithCacheCtx(ctx)'. +func NewCachingAuthorizer(registry prometheus.Registerer) Authorizer { + return Cacher(NewAuthorizer(registry)) +} + func NewAuthorizer(registry prometheus.Registerer) *RegoAuthorizer { queryOnce.Do(func() { var err error diff --git a/coderd/rbac/authz_test.go b/coderd/rbac/authz_test.go index bcd560a05f3fa..8d024b55088f5 100644 --- a/coderd/rbac/authz_test.go +++ b/coderd/rbac/authz_test.go @@ -109,7 +109,7 @@ func BenchmarkRBACAuthorize(b *testing.B) { uuid.MustParse("0632b012-49e0-4d70-a5b3-f4398f1dcd52"), uuid.MustParse("70dbaa7a-ea9c-4f68-a781-97b08af8461d"), ) - authorizer := rbac.NewAuthorizer(prometheus.NewRegistry()) + authorizer := rbac.NewCachingAuthorizer(prometheus.NewRegistry()) // This benchmarks all the simple cases using just user permissions. Groups // are added as noise, but do not do anything. for _, c := range benchCases { @@ -136,7 +136,7 @@ func BenchmarkRBACAuthorizeGroups(b *testing.B) { uuid.MustParse("0632b012-49e0-4d70-a5b3-f4398f1dcd52"), uuid.MustParse("70dbaa7a-ea9c-4f68-a781-97b08af8461d"), ) - authorizer := rbac.NewAuthorizer(prometheus.NewRegistry()) + authorizer := rbac.NewCachingAuthorizer(prometheus.NewRegistry()) // Same benchmark cases, but this time groups will be used to match. // Some '*' permissions will still match, but using a fake action reduces @@ -188,7 +188,7 @@ func BenchmarkRBACFilter(b *testing.B) { uuid.MustParse("70dbaa7a-ea9c-4f68-a781-97b08af8461d"), ) - authorizer := rbac.NewAuthorizer(prometheus.NewRegistry()) + authorizer := rbac.NewCachingAuthorizer(prometheus.NewRegistry()) for _, c := range benchCases { b.Run("PrepareOnly-"+c.Name, func(b *testing.B) { diff --git a/coderd/rbac/builtin_test.go b/coderd/rbac/builtin_test.go index 6e5b67b6474a8..0a83f987f7244 100644 --- a/coderd/rbac/builtin_test.go +++ b/coderd/rbac/builtin_test.go @@ -23,7 +23,7 @@ type authSubject struct { func TestRolePermissions(t *testing.T) { t.Parallel() - auth := rbac.NewAuthorizer(prometheus.NewRegistry()) + auth := rbac.NewCachingAuthorizer(prometheus.NewRegistry()) // currentUser is anything that references "me", "mine", or "my". currentUser := uuid.New() diff --git a/coderd/rbac/cache.go b/coderd/rbac/cache.go index 7ee71bb6e32ab..1bd2bb53258ed 100644 --- a/coderd/rbac/cache.go +++ b/coderd/rbac/cache.go @@ -20,6 +20,8 @@ type cachedCalls struct { // multiple calls are made to the Authorizer for the same subject, action, and // object. The cache is on each `ctx` and is not shared between requests. // If no cache is found on the context, the Authorizer is called as normal. +// +// Cacher is safe for multiple actors. func Cacher(authz Authorizer) Authorizer { return &cachedCalls{authz: authz} } diff --git a/coderd/rbac/cache_test.go b/coderd/rbac/cache_test.go index 4670da307b756..03f7068f415a7 100644 --- a/coderd/rbac/cache_test.go +++ b/coderd/rbac/cache_test.go @@ -2,6 +2,7 @@ package rbac_test import ( "context" + "fmt" "testing" "github.com/stretchr/testify/require" @@ -10,6 +11,33 @@ import ( "github.com/coder/coder/coderd/rbac" ) +// BenchmarkCacher benchmarks the performance of the cacher with a given +// cache size. The expected cache size in prod will usually be 1-2. In Filter +// cases it can get as high as 10. +func BenchmarkCacher(b *testing.B) { + b.ResetTimer() + // Size of the cache. + sizes := []int{1, 10, 100, 1000} + for _, size := range sizes { + b.Run(fmt.Sprintf("Size%d", size), func(b *testing.B) { + ctx := rbac.WithCacheCtx(context.Background()) + authz := rbac.Cacher(&coderdtest.FakeAuthorizer{AlwaysReturn: nil}) + for i := 0; i < size; i++ { + // Preload the cache of a given size + subj, obj, action := coderdtest.RandomRBACSubject(), coderdtest.RandomRBACObject(), coderdtest.RandomRBACAction() + _ = authz.Authorize(ctx, subj, action, obj) + } + + // Cache is loaded as a slice, so this cache hit is always the last element. + subj, obj, action := coderdtest.RandomRBACSubject(), coderdtest.RandomRBACObject(), coderdtest.RandomRBACAction() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = authz.Authorize(ctx, subj, action, obj) + } + }) + } +} + func TestCacher(t *testing.T) { t.Parallel() diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index 20d984a3b946c..9e195735767ef 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -47,7 +47,7 @@ func New(ctx context.Context, options *Options) (*API, error) { options.PrometheusRegistry = prometheus.NewRegistry() } if options.Options.Authorizer == nil { - options.Options.Authorizer = rbac.NewAuthorizer(options.PrometheusRegistry) + options.Options.Authorizer = rbac.NewCachingAuthorizer(options.PrometheusRegistry) } ctx, cancelFunc := context.WithCancel(ctx) api := &API{