From fd05de9e0d3fe959662f877b8245a545eea1f55c Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Tue, 13 Sep 2022 02:42:13 +0000 Subject: [PATCH 01/19] chore: Refactor Enterprise code to layer on top of AGPL This is an experiment to invert the import order of the Enterprise code to layer on top of AGPL. --- cli/root.go | 4 +- cli/server.go | 9 +- coderd/audit/request.go | 18 +- coderd/authorize.go | 2 +- coderd/coderd.go | 61 +- coderd/coderd_test.go | 30 +- coderd/coderdtest/coderdtest.go | 39 +- coderd/features.go | 97 ---- coderd/features/features.go | 4 - coderd/features_internal_test.go | 100 ---- coderd/licenses.go | 24 - coderd/provisionerdaemons.go | 2 +- coderd/provisionerjobs_internal_test.go | 2 +- coderd/templates.go | 34 +- coderd/templateversions.go | 16 +- coderd/users.go | 46 +- coderd/workspaces.go | 34 +- enterprise/cli/licenses_test.go | 8 +- enterprise/cli/root.go | 13 +- enterprise/coderd/auth_internal_test.go | 80 --- enterprise/coderd/coderd.go | 240 ++++++-- enterprise/coderd/coderd_test.go | 93 +++ .../coderd/coderdenttest/coderdenttest.go | 124 ++++ .../coderdenttest/coderdenttest_test.go | 314 +++++----- enterprise/coderd/features.go | 327 ----------- enterprise/coderd/features_internal_test.go | 545 ------------------ enterprise/coderd/licenses.go | 249 ++++---- enterprise/coderd/licenses_internal_test.go | 316 ---------- enterprise/coderd/licenses_test.go | 168 ++++++ 29 files changed, 1048 insertions(+), 1951 deletions(-) delete mode 100644 coderd/features.go delete mode 100644 coderd/features_internal_test.go delete mode 100644 coderd/licenses.go delete mode 100644 enterprise/coderd/auth_internal_test.go create mode 100644 enterprise/coderd/coderd_test.go create mode 100644 enterprise/coderd/coderdenttest/coderdenttest.go rename coderd/coderdtest/authtest.go => enterprise/coderd/coderdenttest/coderdenttest_test.go (90%) delete mode 100644 enterprise/coderd/features.go delete mode 100644 enterprise/coderd/features_internal_test.go delete mode 100644 enterprise/coderd/licenses_internal_test.go create mode 100644 enterprise/coderd/licenses_test.go diff --git a/cli/root.go b/cli/root.go index 779c47c07c1c3..430201d049506 100644 --- a/cli/root.go +++ b/cli/root.go @@ -96,7 +96,9 @@ func Core() []*cobra.Command { } func AGPL() []*cobra.Command { - all := append(Core(), Server(coderd.New)) + all := append(Core(), Server(func(_ context.Context, o *coderd.Options) (*coderd.API, error) { + return coderd.New(o), nil + })) return all } diff --git a/cli/server.go b/cli/server.go index 925a1a619a840..f64f49a24ce7b 100644 --- a/cli/server.go +++ b/cli/server.go @@ -70,7 +70,7 @@ import ( ) // nolint:gocyclo -func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command { +func Server(newAPI func(context.Context, *coderd.Options) (*coderd.API, error)) *cobra.Command { var ( accessURL string address string @@ -506,7 +506,10 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command { ), promAddress, "prometheus")() } - coderAPI := newAPI(options) + coderAPI, err := newAPI(ctx, options) + if err != nil { + return err + } defer coderAPI.Close() client := codersdk.New(localURL) @@ -553,7 +556,7 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command { // These errors are typically noise like "TLS: EOF". Vault does similar: // https://github.com/hashicorp/vault/blob/e2490059d0711635e529a4efcbaa1b26998d6e1c/command/server.go#L2714 ErrorLog: log.New(io.Discard, "", 0), - Handler: coderAPI.Handler, + Handler: coderAPI.RootHandler, BaseContext: func(_ net.Listener) context.Context { return shutdownConnsCtx }, diff --git a/coderd/audit/request.go b/coderd/audit/request.go index 77fb1580de3ad..1be65bd084c59 100644 --- a/coderd/audit/request.go +++ b/coderd/audit/request.go @@ -12,14 +12,13 @@ import ( "cdr.dev/slog" "github.com/coder/coder/coderd/database" - "github.com/coder/coder/coderd/features" "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/coderd/httpmw" ) type RequestParams struct { - Features features.Service - Log slog.Logger + Audit Auditor + Log slog.Logger Request *http.Request Action database.AuditAction @@ -102,15 +101,6 @@ func InitRequest[T Auditable](w http.ResponseWriter, p *RequestParams) (*Request params: p, } - feats := struct { - Audit Auditor - }{} - err := p.Features.Get(&feats) - if err != nil { - p.Log.Error(p.Request.Context(), "unable to get auditor interface", slog.Error(err)) - return req, func() {} - } - return req, func() { ctx := context.Background() logCtx := p.Request.Context() @@ -120,7 +110,7 @@ func InitRequest[T Auditable](w http.ResponseWriter, p *RequestParams) (*Request return } - diff := Diff(feats.Audit, req.Old, req.New) + diff := Diff(p.Audit, req.Old, req.New) diffRaw, _ := json.Marshal(diff) ip, err := parseIP(p.Request.RemoteAddr) @@ -128,7 +118,7 @@ func InitRequest[T Auditable](w http.ResponseWriter, p *RequestParams) (*Request p.Log.Warn(logCtx, "parse ip", slog.Error(err)) } - err = feats.Audit.Export(ctx, database.AuditLog{ + err = p.Audit.Export(ctx, database.AuditLog{ ID: uuid.New(), Time: database.Now(), UserID: httpmw.APIKey(p.Request).UserID, diff --git a/coderd/authorize.go b/coderd/authorize.go index 55310cee78755..21099fcffc982 100644 --- a/coderd/authorize.go +++ b/coderd/authorize.go @@ -42,7 +42,7 @@ type HTTPAuthorizer struct { // return // } func (api *API) Authorize(r *http.Request, action rbac.Action, object rbac.Objecter) bool { - return api.httpAuth.Authorize(r, action, object) + return api.HTTPAuth.Authorize(r, action, object) } // Authorize will return false if the user is not authorized to do the action. diff --git a/coderd/coderd.go b/coderd/coderd.go index 14a18e7986aae..7b16e4856e54d 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -7,6 +7,7 @@ import ( "net/url" "path/filepath" "sync" + "sync/atomic" "time" "github.com/andybalholm/brotli" @@ -25,9 +26,9 @@ import ( "cdr.dev/slog" "github.com/coder/coder/buildinfo" + "github.com/coder/coder/coderd/audit" "github.com/coder/coder/coderd/awsidentity" "github.com/coder/coder/coderd/database" - "github.com/coder/coder/coderd/features" "github.com/coder/coder/coderd/gitsshkey" "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/coderd/httpmw" @@ -52,6 +53,7 @@ type Options struct { // CacheDir is used for caching files served by the API. CacheDir string + Auditor audit.Auditor AgentConnectionUpdateFrequency time.Duration AgentInactiveDisconnectTimeout time.Duration // APIRateLimit is the minutely throughput rate limit per user or ip. @@ -72,8 +74,6 @@ type Options struct { TURNServer *turnconn.Server TracerProvider *sdktrace.TracerProvider AutoImportTemplates []AutoImportTemplate - LicenseHandler http.Handler - FeaturesService features.Service TailscaleEnable bool TailnetCoordinator *tailnet.Coordinator @@ -85,6 +85,9 @@ type Options struct { // New constructs a Coder API handler. func New(options *Options) *API { + if options == nil { + options = &Options{} + } if options.AgentConnectionUpdateFrequency == 0 { options.AgentConnectionUpdateFrequency = 3 * time.Second } @@ -110,11 +113,8 @@ func New(options *Options) *API { if options.TailnetCoordinator == nil { options.TailnetCoordinator = tailnet.NewCoordinator() } - if options.LicenseHandler == nil { - options.LicenseHandler = licenses() - } - if options.FeaturesService == nil { - options.FeaturesService = &featuresService{} + if options.Auditor == nil { + options.Auditor = audit.NewNop() } siteCacheDir := options.CacheDir @@ -135,14 +135,17 @@ func New(options *Options) *API { r := chi.NewRouter() api := &API{ Options: options, - Handler: r, + RootHandler: r, siteHandler: site.Handler(site.FS(), binFS), - httpAuth: &HTTPAuthorizer{ + HTTPAuth: &HTTPAuthorizer{ Authorizer: options.Authorizer, Logger: options.Logger, }, metricsCache: metricsCache, + Auditor: atomic.Pointer[audit.Auditor]{}, } + api.Auditor.Store(&options.Auditor) + if options.TailscaleEnable { api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgentTailnet, 0) } else { @@ -194,6 +197,8 @@ func New(options *Options) *API { }) r.Route("/api/v2", func(r chi.Router) { + api.APIHandler = r + r.NotFound(func(rw http.ResponseWriter, r *http.Request) { httpapi.Write(rw, http.StatusNotFound, codersdk.Response{ Message: "Route not found.", @@ -460,12 +465,9 @@ func New(options *Options) *API { }) r.Route("/entitlements", func(r chi.Router) { r.Use(apiKeyMiddleware) - r.Get("/", api.FeaturesService.EntitlementsAPI) - }) - r.Route("/licenses", func(r chi.Router) { - r.Use(apiKeyMiddleware) - r.Mount("/", options.LicenseHandler) + r.Get("/", entitlements) }) + r.HandleFunc("/licenses", unsupported) }) r.NotFound(compressHandler(http.HandlerFunc(api.siteHandler.ServeHTTP)).ServeHTTP) @@ -477,12 +479,14 @@ type API struct { derpServer *derp.Server - Handler chi.Router + Auditor atomic.Pointer[audit.Auditor] + RootHandler chi.Router + APIHandler chi.Router siteHandler http.Handler websocketWaitMutex sync.Mutex websocketWaitGroup sync.WaitGroup workspaceAgentCache *wsconncache.Cache - httpAuth *HTTPAuthorizer + HTTPAuth *HTTPAuthorizer metricsCache *metricscache.Cache } @@ -517,3 +521,26 @@ func compressHandler(h http.Handler) http.Handler { return cmp.Handler(h) } + +func entitlements(rw http.ResponseWriter, _ *http.Request) { + feats := make(map[string]codersdk.Feature) + for _, f := range codersdk.FeatureNames { + feats[f] = codersdk.Feature{ + Entitlement: codersdk.EntitlementNotEntitled, + Enabled: false, + } + } + httpapi.Write(rw, http.StatusOK, codersdk.Entitlements{ + Features: feats, + Warnings: []string{}, + HasLicense: false, + }) +} + +func unsupported(rw http.ResponseWriter, _ *http.Request) { + httpapi.Write(rw, http.StatusNotFound, codersdk.Response{ + Message: "Unsupported", + Detail: "These endpoints are not supported in AGPL-licensed Coder", + Validations: nil, + }) +} diff --git a/coderd/coderd_test.go b/coderd/coderd_test.go index 87206873b1073..ea4fafaae533e 100644 --- a/coderd/coderd_test.go +++ b/coderd/coderd_test.go @@ -17,6 +17,7 @@ import ( "github.com/coder/coder/buildinfo" "github.com/coder/coder/coderd/coderdtest" + "github.com/coder/coder/codersdk" "github.com/coder/coder/tailnet" "github.com/coder/coder/testutil" ) @@ -38,16 +39,6 @@ func TestBuildInfo(t *testing.T) { require.Equal(t, buildinfo.Version(), buildInfo.Version, "version") } -// TestAuthorizeAllEndpoints will check `authorize` is called on every endpoint registered. -func TestAuthorizeAllEndpoints(t *testing.T) { - t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - a := coderdtest.NewAuthTester(ctx, t, nil) - skipRoutes, assertRoute := coderdtest.AGPLRoutes(a) - a.Test(ctx, assertRoute, skipRoutes) -} - func TestDERP(t *testing.T) { t.Parallel() client := coderdtest.New(t, nil) @@ -124,3 +115,22 @@ func TestDERPLatencyCheck(t *testing.T) { defer res.Body.Close() require.Equal(t, http.StatusOK, res.StatusCode) } + +func TestEntitlements(t *testing.T) { + t.Parallel() + t.Run("GET", func(t *testing.T) { + t.Parallel() + client := coderdtest.New(t, nil) + _ = coderdtest.CreateFirstUser(t, client) + result, err := client.Entitlements(context.Background()) + require.NoError(t, err) + assert.False(t, result.HasLicense) + assert.Empty(t, result.Warnings) + for _, f := range codersdk.FeatureNames { + require.Contains(t, result.Features, f) + fe := result.Features[f] + assert.False(t, fe.Enabled) + assert.Equal(t, codersdk.EntitlementNotEntitled, fe.Entitlement) + } + }) +} diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index 490dce5a125a6..e9e371e8dc958 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -79,7 +79,6 @@ type Options struct { // IncludeProvisionerDaemon when true means to start an in-memory provisionerD IncludeProvisionerDaemon bool - APIBuilder func(*coderd.Options) *coderd.API MetricsCacheRefreshInterval time.Duration AgentStatsRefreshInterval time.Duration } @@ -115,10 +114,7 @@ func newWithCloser(t *testing.T, options *Options) (*codersdk.Client, io.Closer) return client, closer } -// newWithAPI constructs an in-memory API instance and returns a client to talk to it. -// Most tests never need a reference to the API, but AuthorizationTest in this module uses it. -// Do not expose the API or wrath shall descend upon thee. -func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *coderd.API) { +func NewOptions(t *testing.T, options *Options) (*httptest.Server, *coderd.Options) { if options == nil { options = &Options{} } @@ -139,9 +135,6 @@ func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c close(options.AutobuildStats) }) } - if options.APIBuilder == nil { - options.APIBuilder = coderd.New - } // This can be hotswapped for a live database instance. db := databasefake.New() @@ -199,13 +192,7 @@ func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c _ = turnServer.Close() }) - features := coderd.DisabledImplementations - if options.Auditor != nil { - features.Auditor = options.Auditor - } - - // We set the handler after server creation for the access URL. - coderAPI := options.APIBuilder(&coderd.Options{ + return srv, &coderd.Options{ AgentConnectionUpdateFrequency: 150 * time.Millisecond, // Force a long disconnection timeout to ensure // agents are not marked as disconnected during slow tests. @@ -216,6 +203,7 @@ func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c Database: db, Pubsub: pubsub, + Auditor: options.Auditor, AWSCertificates: options.AWSCertificates, AzureCertificates: options.AzureCertificates, GithubOAuth2Config: options.GithubOAuth2Config, @@ -247,13 +235,23 @@ func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c AutoImportTemplates: options.AutoImportTemplates, MetricsCacheRefreshInterval: options.MetricsCacheRefreshInterval, AgentStatsRefreshInterval: options.AgentStatsRefreshInterval, - FeaturesService: coderd.NewMockFeaturesService(features), - }) + } +} + +// newWithAPI constructs an in-memory API instance and returns a client to talk to it. +// Most tests never need a reference to the API, but AuthorizationTest in this module uses it. +// Do not expose the API or wrath shall descend upon thee. +func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *coderd.API) { + if options == nil { + options = &Options{} + } + srv, newOptions := NewOptions(t, options) + // We set the handler after server creation for the access URL. + coderAPI := coderd.New(newOptions) t.Cleanup(func() { _ = coderAPI.Close() }) - srv.Config.Handler = coderAPI.Handler - + srv.Config.Handler = coderAPI.RootHandler var provisionerCloser io.Closer = nopcloser{} if options.IncludeProvisionerDaemon { provisionerCloser = NewProvisionerDaemon(t, coderAPI) @@ -261,8 +259,7 @@ func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c t.Cleanup(func() { _ = provisionerCloser.Close() }) - - return codersdk.New(serverURL), provisionerCloser, coderAPI + return codersdk.New(coderAPI.AccessURL), provisionerCloser, coderAPI } // NewProvisionerDaemon launches a provisionerd instance configured to work diff --git a/coderd/features.go b/coderd/features.go deleted file mode 100644 index 594fad2e38423..0000000000000 --- a/coderd/features.go +++ /dev/null @@ -1,97 +0,0 @@ -package coderd - -import ( - "net/http" - "reflect" - - "golang.org/x/xerrors" - - "github.com/coder/coder/coderd/audit" - "github.com/coder/coder/coderd/features" - "github.com/coder/coder/coderd/httpapi" - "github.com/coder/coder/codersdk" -) - -func NewMockFeaturesService(feats FeatureInterfaces) features.Service { - return &featuresService{ - feats: &feats, - } -} - -type featuresService struct { - feats *FeatureInterfaces -} - -func (*featuresService) EntitlementsAPI(rw http.ResponseWriter, _ *http.Request) { - feats := make(map[string]codersdk.Feature) - for _, f := range codersdk.FeatureNames { - feats[f] = codersdk.Feature{ - Entitlement: codersdk.EntitlementNotEntitled, - Enabled: false, - } - } - httpapi.Write(rw, http.StatusOK, codersdk.Entitlements{ - Features: feats, - Warnings: []string{}, - HasLicense: false, - }) -} - -// Get returns the implementations for feature interfaces. Parameter `s` must be a pointer to a -// struct type containing feature interfaces as fields. The AGPL featureService always returns the -// "disabled" version of the feature interface because it doesn't include any enterprise features -// by definition. -func (f *featuresService) Get(ps any) error { - if reflect.TypeOf(ps).Kind() != reflect.Pointer { - return xerrors.New("input must be pointer to struct") - } - vs := reflect.ValueOf(ps).Elem() - if vs.Kind() != reflect.Struct { - return xerrors.New("input must be pointer to struct") - } - for i := 0; i < vs.NumField(); i++ { - vf := vs.Field(i) - tf := vf.Type() - if tf.Kind() != reflect.Interface { - return xerrors.Errorf("fields of input struct must be interfaces: %s", tf.String()) - } - err := f.setImplementation(vf, tf) - if err != nil { - return err - } - } - return nil -} - -// setImplementation finds the correct implementation for the field's type, and sets it on the -// struct. It returns an error if unsuccessful -func (f *featuresService) setImplementation(vf reflect.Value, tf reflect.Type) error { - feats := f.feats - if feats == nil { - feats = &DisabledImplementations - } - - // when we get more than a few features it might make sense to have a data structure for finding - // the correct implementation that's faster than just a linear search, but for now just spin - // through the implementations we have. - vd := reflect.ValueOf(*feats) - for j := 0; j < vd.NumField(); j++ { - vdf := vd.Field(j) - if vdf.Type() == tf { - vf.Set(vdf) - return nil - } - } - return xerrors.Errorf("unable to find implementation of interface %s", tf.String()) -} - -// FeatureInterfaces contains a field for each interface controlled by an enterprise feature. -type FeatureInterfaces struct { - Auditor audit.Auditor -} - -// DisabledImplementations includes all the implementations of turned-off features. There are no -// turned-on implementations in AGPL code. -var DisabledImplementations = FeatureInterfaces{ - Auditor: audit.NewNop(), -} diff --git a/coderd/features/features.go b/coderd/features/features.go index d44bd5f2e40d1..f086931fa8003 100644 --- a/coderd/features/features.go +++ b/coderd/features/features.go @@ -1,11 +1,7 @@ package features -import "net/http" - // Service is the interface for interacting with enterprise features. type Service interface { - EntitlementsAPI(w http.ResponseWriter, r *http.Request) - // Get returns the implementations for feature interfaces. Parameter `s` must be a pointer to a // struct type containing feature interfaces as fields. The FeatureService sets all fields to // the correct implementations depending on whether the features are turned on. diff --git a/coderd/features_internal_test.go b/coderd/features_internal_test.go deleted file mode 100644 index cba3f3da89e50..0000000000000 --- a/coderd/features_internal_test.go +++ /dev/null @@ -1,100 +0,0 @@ -package coderd - -import ( - "encoding/json" - "net/http" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/coder/coder/coderd/audit" - "github.com/coder/coder/codersdk" -) - -func TestEntitlements(t *testing.T) { - t.Parallel() - t.Run("GET", func(t *testing.T) { - t.Parallel() - r := httptest.NewRequest("GET", "https://example.com/api/v2/entitlements", nil) - rw := httptest.NewRecorder() - (&featuresService{}).EntitlementsAPI(rw, r) - resp := rw.Result() - defer resp.Body.Close() - assert.Equal(t, http.StatusOK, resp.StatusCode) - dec := json.NewDecoder(resp.Body) - var result codersdk.Entitlements - err := dec.Decode(&result) - require.NoError(t, err) - assert.False(t, result.HasLicense) - assert.Empty(t, result.Warnings) - for _, f := range codersdk.FeatureNames { - require.Contains(t, result.Features, f) - fe := result.Features[f] - assert.False(t, fe.Enabled) - assert.Equal(t, codersdk.EntitlementNotEntitled, fe.Entitlement) - } - }) -} - -func TestFeaturesServiceGet(t *testing.T) { - t.Parallel() - t.Run("Auditor", func(t *testing.T) { - t.Parallel() - uut := featuresService{} - target := struct { - Auditor audit.Auditor - }{} - err := uut.Get(&target) - require.NoError(t, err) - assert.NotNil(t, target.Auditor) - }) - - t.Run("NotPointer", func(t *testing.T) { - t.Parallel() - uut := featuresService{} - target := struct { - Auditor audit.Auditor - }{} - err := uut.Get(target) - require.Error(t, err) - assert.Nil(t, target.Auditor) - }) - - t.Run("UnknownInterface", func(t *testing.T) { - t.Parallel() - uut := featuresService{} - target := struct { - test testInterface - }{} - err := uut.Get(&target) - require.Error(t, err) - assert.Nil(t, target.test) - }) - - t.Run("PointerToNonStruct", func(t *testing.T) { - t.Parallel() - uut := featuresService{} - var target audit.Auditor - err := uut.Get(&target) - require.Error(t, err) - assert.Nil(t, target) - }) - - t.Run("StructWithNonInterfaces", func(t *testing.T) { - t.Parallel() - uut := featuresService{} - target := struct { - N int64 - Auditor audit.Auditor - }{} - err := uut.Get(&target) - require.Error(t, err) - assert.Nil(t, target.Auditor) - }) -} - -type testInterface interface { - Test() error -} diff --git a/coderd/licenses.go b/coderd/licenses.go deleted file mode 100644 index 28a0b1d418043..0000000000000 --- a/coderd/licenses.go +++ /dev/null @@ -1,24 +0,0 @@ -package coderd - -import ( - "net/http" - - "github.com/go-chi/chi/v5" - - "github.com/coder/coder/coderd/httpapi" - "github.com/coder/coder/codersdk" -) - -func licenses() http.Handler { - r := chi.NewRouter() - r.NotFound(unsupported) - return r -} - -func unsupported(rw http.ResponseWriter, _ *http.Request) { - httpapi.Write(rw, http.StatusNotFound, codersdk.Response{ - Message: "Unsupported", - Detail: "These endpoints are not supported in AGPL-licensed Coder", - Validations: nil, - }) -} diff --git a/coderd/provisionerdaemons.go b/coderd/provisionerdaemons.go index df4197a70ae18..247a31d015380 100644 --- a/coderd/provisionerdaemons.go +++ b/coderd/provisionerdaemons.go @@ -48,7 +48,7 @@ func (api *API) provisionerDaemons(rw http.ResponseWriter, r *http.Request) { if daemons == nil { daemons = []database.ProvisionerDaemon{} } - daemons, err = AuthorizeFilter(api.httpAuth, r, rbac.ActionRead, daemons) + daemons, err = AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, daemons) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching provisioner daemons.", diff --git a/coderd/provisionerjobs_internal_test.go b/coderd/provisionerjobs_internal_test.go index 67004661d9583..3dd7b527a53fc 100644 --- a/coderd/provisionerjobs_internal_test.go +++ b/coderd/provisionerjobs_internal_test.go @@ -41,7 +41,7 @@ func TestProvisionerJobLogs_Unit(t *testing.T) { api := New(&opts) defer api.Close() - server := httptest.NewServer(api.Handler) + server := httptest.NewServer(api.RootHandler) defer server.Close() userID := uuid.New() keyID, keySecret, err := generateAPIKeyIDSecret() diff --git a/coderd/templates.go b/coderd/templates.go index c48531a25c226..06aea92fd7d5b 100644 --- a/coderd/templates.go +++ b/coderd/templates.go @@ -86,10 +86,10 @@ func (api *API) deleteTemplate(rw http.ResponseWriter, r *http.Request) { var ( template = httpmw.TemplateParam(r) aReq, commitAudit = audit.InitRequest[database.Template](rw, &audit.RequestParams{ - Features: api.FeaturesService, - Log: api.Logger, - Request: r, - Action: database.AuditActionDelete, + Audit: *api.Auditor.Load(), + Log: api.Logger, + Request: r, + Action: database.AuditActionDelete, }) ) defer commitAudit() @@ -140,16 +140,16 @@ func (api *API) postTemplateByOrganization(rw http.ResponseWriter, r *http.Reque organization = httpmw.OrganizationParam(r) apiKey = httpmw.APIKey(r) templateAudit, commitTemplateAudit = audit.InitRequest[database.Template](rw, &audit.RequestParams{ - Features: api.FeaturesService, - Log: api.Logger, - Request: r, - Action: database.AuditActionCreate, + Audit: *api.Auditor.Load(), + Log: api.Logger, + Request: r, + Action: database.AuditActionCreate, }) templateVersionAudit, commitTemplateVersionAudit = audit.InitRequest[database.TemplateVersion](rw, &audit.RequestParams{ - Features: api.FeaturesService, - Log: api.Logger, - Request: r, - Action: database.AuditActionWrite, + Audit: *api.Auditor.Load(), + Log: api.Logger, + Request: r, + Action: database.AuditActionWrite, }) ) defer commitTemplateAudit() @@ -340,7 +340,7 @@ func (api *API) templatesByOrganization(rw http.ResponseWriter, r *http.Request) } // Filter templates based on rbac permissions - templates, err = AuthorizeFilter(api.httpAuth, r, rbac.ActionRead, templates) + templates, err = AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, templates) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching templates.", @@ -436,10 +436,10 @@ func (api *API) patchTemplateMeta(rw http.ResponseWriter, r *http.Request) { var ( template = httpmw.TemplateParam(r) aReq, commitAudit = audit.InitRequest[database.Template](rw, &audit.RequestParams{ - Features: api.FeaturesService, - Log: api.Logger, - Request: r, - Action: database.AuditActionWrite, + Audit: *api.Auditor.Load(), + Log: api.Logger, + Request: r, + Action: database.AuditActionWrite, }) ) defer commitAudit() diff --git a/coderd/templateversions.go b/coderd/templateversions.go index c14e3ab9ca07c..2bb691d53e3b2 100644 --- a/coderd/templateversions.go +++ b/coderd/templateversions.go @@ -560,10 +560,10 @@ func (api *API) patchActiveTemplateVersion(rw http.ResponseWriter, r *http.Reque var ( template = httpmw.TemplateParam(r) aReq, commitAudit = audit.InitRequest[database.Template](rw, &audit.RequestParams{ - Features: api.FeaturesService, - Log: api.Logger, - Request: r, - Action: database.AuditActionWrite, + Audit: *api.Auditor.Load(), + Log: api.Logger, + Request: r, + Action: database.AuditActionWrite, }) ) defer commitAudit() @@ -632,10 +632,10 @@ func (api *API) postTemplateVersionsByOrganization(rw http.ResponseWriter, r *ht apiKey = httpmw.APIKey(r) organization = httpmw.OrganizationParam(r) aReq, commitAudit = audit.InitRequest[database.TemplateVersion](rw, &audit.RequestParams{ - Features: api.FeaturesService, - Log: api.Logger, - Request: r, - Action: database.AuditActionCreate, + Audit: *api.Auditor.Load(), + Log: api.Logger, + Request: r, + Action: database.AuditActionCreate, }) req codersdk.CreateTemplateVersionRequest diff --git a/coderd/users.go b/coderd/users.go index 6c8046d94665d..58e0055218520 100644 --- a/coderd/users.go +++ b/coderd/users.go @@ -220,7 +220,7 @@ func (api *API) users(rw http.ResponseWriter, r *http.Request) { return } - users, err = AuthorizeFilter(api.httpAuth, r, rbac.ActionRead, users) + users, err = AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, users) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching users.", @@ -256,10 +256,10 @@ func (api *API) users(rw http.ResponseWriter, r *http.Request) { // Creates a new user. func (api *API) postUser(rw http.ResponseWriter, r *http.Request) { aReq, commitAudit := audit.InitRequest[database.User](rw, &audit.RequestParams{ - Features: api.FeaturesService, - Log: api.Logger, - Request: r, - Action: database.AuditActionCreate, + Audit: *api.Auditor.Load(), + Log: api.Logger, + Request: r, + Action: database.AuditActionCreate, }) defer commitAudit() @@ -364,10 +364,10 @@ func (api *API) putUserProfile(rw http.ResponseWriter, r *http.Request) { var ( user = httpmw.UserParam(r) aReq, commitAudit = audit.InitRequest[database.User](rw, &audit.RequestParams{ - Features: api.FeaturesService, - Log: api.Logger, - Request: r, - Action: database.AuditActionWrite, + Audit: *api.Auditor.Load(), + Log: api.Logger, + Request: r, + Action: database.AuditActionWrite, }) ) defer commitAudit() @@ -444,10 +444,10 @@ func (api *API) putUserStatus(status database.UserStatus) func(rw http.ResponseW user = httpmw.UserParam(r) apiKey = httpmw.APIKey(r) aReq, commitAudit = audit.InitRequest[database.User](rw, &audit.RequestParams{ - Features: api.FeaturesService, - Log: api.Logger, - Request: r, - Action: database.AuditActionWrite, + Audit: *api.Auditor.Load(), + Log: api.Logger, + Request: r, + Action: database.AuditActionWrite, }) ) defer commitAudit() @@ -510,10 +510,10 @@ func (api *API) putUserPassword(rw http.ResponseWriter, r *http.Request) { user = httpmw.UserParam(r) params codersdk.UpdateUserPasswordRequest aReq, commitAudit = audit.InitRequest[database.User](rw, &audit.RequestParams{ - Features: api.FeaturesService, - Log: api.Logger, - Request: r, - Action: database.AuditActionWrite, + Audit: *api.Auditor.Load(), + Log: api.Logger, + Request: r, + Action: database.AuditActionWrite, }) ) defer commitAudit() @@ -622,7 +622,7 @@ func (api *API) userRoles(rw http.ResponseWriter, r *http.Request) { } // Only include ones we can read from RBAC. - memberships, err = AuthorizeFilter(api.httpAuth, r, rbac.ActionRead, memberships) + memberships, err = AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, memberships) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching memberships.", @@ -648,10 +648,10 @@ func (api *API) putUserRoles(rw http.ResponseWriter, r *http.Request) { actorRoles = httpmw.AuthorizationUserRoles(r) apiKey = httpmw.APIKey(r) aReq, commitAudit = audit.InitRequest[database.User](rw, &audit.RequestParams{ - Features: api.FeaturesService, - Log: api.Logger, - Request: r, - Action: database.AuditActionWrite, + Audit: *api.Auditor.Load(), + Log: api.Logger, + Request: r, + Action: database.AuditActionWrite, }) ) defer commitAudit() @@ -761,7 +761,7 @@ func (api *API) organizationsByUser(rw http.ResponseWriter, r *http.Request) { } // Only return orgs the user can read. - organizations, err = AuthorizeFilter(api.httpAuth, r, rbac.ActionRead, organizations) + organizations, err = AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, organizations) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching organizations.", diff --git a/coderd/workspaces.go b/coderd/workspaces.go index c021f9a033133..3c6ad3314db8c 100644 --- a/coderd/workspaces.go +++ b/coderd/workspaces.go @@ -144,7 +144,7 @@ func (api *API) workspaces(rw http.ResponseWriter, r *http.Request) { } // Only return workspaces the user can read - workspaces, err = AuthorizeFilter(api.httpAuth, r, rbac.ActionRead, workspaces) + workspaces, err = AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, workspaces) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching workspaces.", @@ -253,10 +253,10 @@ func (api *API) postWorkspacesByOrganization(rw http.ResponseWriter, r *http.Req organization = httpmw.OrganizationParam(r) apiKey = httpmw.APIKey(r) aReq, commitAudit = audit.InitRequest[database.Workspace](rw, &audit.RequestParams{ - Features: api.FeaturesService, - Log: api.Logger, - Request: r, - Action: database.AuditActionCreate, + Audit: *api.Auditor.Load(), + Log: api.Logger, + Request: r, + Action: database.AuditActionCreate, }) ) defer commitAudit() @@ -492,10 +492,10 @@ func (api *API) patchWorkspace(rw http.ResponseWriter, r *http.Request) { var ( workspace = httpmw.WorkspaceParam(r) aReq, commitAudit = audit.InitRequest[database.Workspace](rw, &audit.RequestParams{ - Features: api.FeaturesService, - Log: api.Logger, - Request: r, - Action: database.AuditActionWrite, + Audit: *api.Auditor.Load(), + Log: api.Logger, + Request: r, + Action: database.AuditActionWrite, }) ) defer commitAudit() @@ -568,10 +568,10 @@ func (api *API) putWorkspaceAutostart(rw http.ResponseWriter, r *http.Request) { var ( workspace = httpmw.WorkspaceParam(r) aReq, commitAudit = audit.InitRequest[database.Workspace](rw, &audit.RequestParams{ - Features: api.FeaturesService, - Log: api.Logger, - Request: r, - Action: database.AuditActionWrite, + Audit: *api.Auditor.Load(), + Log: api.Logger, + Request: r, + Action: database.AuditActionWrite, }) ) defer commitAudit() @@ -628,10 +628,10 @@ func (api *API) putWorkspaceTTL(rw http.ResponseWriter, r *http.Request) { var ( workspace = httpmw.WorkspaceParam(r) aReq, commitAudit = audit.InitRequest[database.Workspace](rw, &audit.RequestParams{ - Features: api.FeaturesService, - Log: api.Logger, - Request: r, - Action: database.AuditActionWrite, + Audit: *api.Auditor.Load(), + Log: api.Logger, + Request: r, + Action: database.AuditActionWrite, }) ) defer commitAudit() diff --git a/enterprise/cli/licenses_test.go b/enterprise/cli/licenses_test.go index 8a7f2076d56e6..a56e4a73277b6 100644 --- a/enterprise/cli/licenses_test.go +++ b/enterprise/cli/licenses_test.go @@ -23,7 +23,7 @@ import ( "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/codersdk" "github.com/coder/coder/enterprise/cli" - "github.com/coder/coder/enterprise/coderd" + "github.com/coder/coder/enterprise/coderd/coderdenttest" "github.com/coder/coder/pty/ptytest" "github.com/coder/coder/testutil" ) @@ -124,7 +124,7 @@ func TestLicensesAddReal(t *testing.T) { t.Parallel() t.Run("Fails", func(t *testing.T) { t.Parallel() - client := coderdtest.New(t, &coderdtest.Options{APIBuilder: coderd.NewEnterprise}) + client := coderdenttest.New(t, nil) coderdtest.CreateFirstUser(t, client) cmd, root := clitest.NewWithSubcommands(t, cli.EnterpriseSubcommands(), "licenses", "add", "-l", fakeLicenseJWT) @@ -175,7 +175,7 @@ func TestLicensesListReal(t *testing.T) { t.Parallel() t.Run("Empty", func(t *testing.T) { t.Parallel() - client := coderdtest.New(t, &coderdtest.Options{APIBuilder: coderd.NewEnterprise}) + client := coderdenttest.New(t, nil) coderdtest.CreateFirstUser(t, client) cmd, root := clitest.NewWithSubcommands(t, cli.EnterpriseSubcommands(), "licenses", "list") @@ -219,7 +219,7 @@ func TestLicensesDeleteReal(t *testing.T) { t.Parallel() t.Run("Empty", func(t *testing.T) { t.Parallel() - client := coderdtest.New(t, &coderdtest.Options{APIBuilder: coderd.NewEnterprise}) + client := coderdenttest.New(t, nil) coderdtest.CreateFirstUser(t, client) cmd, root := clitest.NewWithSubcommands(t, cli.EnterpriseSubcommands(), "licenses", "delete", "1") diff --git a/enterprise/cli/root.go b/enterprise/cli/root.go index 31546b5d679d0..111dd9885f075 100644 --- a/enterprise/cli/root.go +++ b/enterprise/cli/root.go @@ -1,15 +1,26 @@ package cli import ( + "context" + "github.com/spf13/cobra" agpl "github.com/coder/coder/cli" + agplcoderd "github.com/coder/coder/coderd" "github.com/coder/coder/enterprise/coderd" ) func enterpriseOnly() []*cobra.Command { return []*cobra.Command{ - agpl.Server(coderd.NewEnterprise), + agpl.Server(func(ctx context.Context, options *agplcoderd.Options) (*agplcoderd.API, error) { + api, err := coderd.New(ctx, &coderd.Options{ + Options: options, + }) + if err != nil { + return nil, err + } + return api.AGPL, nil + }), licenses(), } } diff --git a/enterprise/coderd/auth_internal_test.go b/enterprise/coderd/auth_internal_test.go deleted file mode 100644 index 853b6f44c4eda..0000000000000 --- a/enterprise/coderd/auth_internal_test.go +++ /dev/null @@ -1,80 +0,0 @@ -package coderd - -import ( - "context" - "crypto/ed25519" - "crypto/rand" - "fmt" - "net/http" - "testing" - "time" - - "github.com/golang-jwt/jwt/v4" - "github.com/stretchr/testify/require" - - "github.com/coder/coder/coderd/coderdtest" - "github.com/coder/coder/coderd/rbac" - "github.com/coder/coder/codersdk" - "github.com/coder/coder/testutil" -) - -// TestAuthorizeAllEndpoints will check `authorize` is called on every endpoint registered. -// these tests patch the map of license keys, so cannot be run in parallel -// nolint:paralleltest -func TestAuthorizeAllEndpoints(t *testing.T) { - pubKey, privKey, err := ed25519.GenerateKey(rand.Reader) - require.NoError(t, err) - keyID := "testing" - oldKeys := keys - defer func() { - t.Log("restoring keys") - keys = oldKeys - }() - keys = map[string]ed25519.PublicKey{keyID: pubKey} - - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - a := coderdtest.NewAuthTester(ctx, t, &coderdtest.Options{APIBuilder: NewEnterprise}) - - // We need a license in the DB, so that when we call GET api/v2/licenses there is one in the - // list to check authz on. - claims := &Claims{ - RegisteredClaims: jwt.RegisteredClaims{ - Issuer: "test@coder.test", - IssuedAt: jwt.NewNumericDate(time.Now()), - NotBefore: jwt.NewNumericDate(time.Now()), - ExpiresAt: jwt.NewNumericDate(time.Now().Add(2 * time.Hour)), - }, - LicenseExpires: jwt.NewNumericDate(time.Now().Add(time.Hour)), - AccountType: AccountTypeSalesforce, - AccountID: "testing", - Version: CurrentVersion, - Features: Features{ - UserLimit: 0, - AuditLog: 1, - }, - } - lic, err := makeLicense(claims, privKey, keyID) - require.NoError(t, err) - license, err := a.Client.AddLicense(ctx, codersdk.AddLicenseRequest{ - License: lic, - }) - require.NoError(t, err) - a.URLParams["licenses/{id}"] = fmt.Sprintf("licenses/%d", license.ID) - - skipRoutes, assertRoute := coderdtest.AGPLRoutes(a) - assertRoute["POST:/api/v2/licenses"] = coderdtest.RouteCheck{ - AssertAction: rbac.ActionCreate, - AssertObject: rbac.ResourceLicense, - } - assertRoute["GET:/api/v2/licenses"] = coderdtest.RouteCheck{ - StatusCode: http.StatusOK, - AssertAction: rbac.ActionRead, - AssertObject: rbac.ResourceLicense, - } - assertRoute["DELETE:/api/v2/licenses/{id}"] = coderdtest.RouteCheck{ - AssertAction: rbac.ActionDelete, - AssertObject: rbac.ResourceLicense, - } - a.Test(ctx, assertRoute, skipRoutes) -} diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index 598c32f11b367..0058e830271ea 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -2,48 +2,216 @@ package coderd import ( "context" - "os" - "strings" + "crypto/ed25519" + "fmt" + "net/http" + "sync" + "time" "golang.org/x/xerrors" + "github.com/go-chi/chi/v5" + + "cdr.dev/slog" "github.com/coder/coder/coderd" - "github.com/coder/coder/coderd/rbac" + agplaudit "github.com/coder/coder/coderd/audit" + "github.com/coder/coder/coderd/httpapi" + "github.com/coder/coder/coderd/httpmw" + "github.com/coder/coder/codersdk" + "github.com/coder/coder/enterprise/audit" + "github.com/coder/coder/enterprise/audit/backends" ) -const EnvAuditLogEnable = "CODER_AUDIT_LOG_ENABLE" +// New constructs an Enterprise coderd API instance. +// This handler is designed to wrap the AGPL Coder code and +// layer Enterprise functionality on top as much as possible. +func New(ctx context.Context, options *Options) (*API, error) { + if options.EntitlementsUpdateInterval == 0 { + options.EntitlementsUpdateInterval = 10 * time.Minute + } + if options.Keys == nil { + options.Keys = Keys + } + ctx, cancelFunc := context.WithCancel(ctx) + api := &API{ + AGPL: coderd.New(options.Options), + Options: options, + + auditLogs: codersdk.EntitlementNotEntitled, + cancelEntitlementsLoop: cancelFunc, + } + oauthConfigs := &httpmw.OAuth2Configs{ + Github: options.GithubOAuth2Config, + OIDC: options.OIDCConfig, + } + apiKeyMiddleware := httpmw.ExtractAPIKey(options.Database, oauthConfigs, false) + + api.AGPL.APIHandler.Group(func(r chi.Router) { + r.Get("/entitlements", api.entitlements) + r.Route("/licenses", func(r chi.Router) { + r.Use(apiKeyMiddleware) + r.Post("/", api.postLicense) + r.Get("/", api.licenses) + r.Delete("/{id}", api.deleteLicense) + }) + }) + + err := api.updateEntitlements(ctx) + if err != nil { + return nil, xerrors.Errorf("update entitlements: %w", err) + } + api.closeLicenseSubscribe, err = api.Pubsub.Subscribe(pubSubEventLicenses, func(ctx context.Context, message []byte) { + _ = api.updateEntitlements(ctx) + }) + if err != nil { + return nil, xerrors.Errorf("subscribe to license updates: %w", err) + } + go api.runEntitlementsLoop(ctx) + + return api, nil +} + +type Options struct { + *coderd.Options + + EntitlementsUpdateInterval time.Duration + Keys map[string]ed25519.PublicKey +} + +type API struct { + AGPL *coderd.API + *Options + + closeLicenseSubscribe func() + cancelEntitlementsLoop func() + mutex sync.RWMutex + hasLicense bool + activeUsers codersdk.Feature + auditLogs codersdk.Entitlement +} + +func (api *API) Close() error { + api.closeLicenseSubscribe() + api.cancelEntitlementsLoop() + return api.AGPL.Close() +} + +func (api *API) runEntitlementsLoop(ctx context.Context) { + ticker := time.NewTicker(api.EntitlementsUpdateInterval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + } + err := api.updateEntitlements(ctx) + if err != nil { + api.Logger.Warn(ctx, "failed to get feature entitlements", slog.Error(err)) + continue + } + } +} + +func (api *API) updateEntitlements(ctx context.Context) error { + licenses, err := api.Database.GetUnexpiredLicenses(ctx) + if err != nil { + return err + } + api.mutex.Lock() + defer api.mutex.Unlock() + now := time.Now() + auditLogs := api.auditLogs + for _, l := range licenses { + claims, err := validateDBLicense(l, api.Keys) + if err != nil { + api.Logger.Debug(ctx, "skipping invalid license", + slog.F("id", l.ID), slog.Error(err)) + continue + } + api.hasLicense = true + entitlement := codersdk.EntitlementEntitled + if now.After(claims.LicenseExpires.Time) { + // if the grace period were over, the validation fails, so if we are after + // LicenseExpires we must be in grace period. + entitlement = codersdk.EntitlementGracePeriod + } + if claims.Features.UserLimit > 0 { + api.activeUsers.Enabled = true + api.activeUsers.Entitlement = entitlement + currentLimit := int64(0) + if api.activeUsers.Limit != nil { + currentLimit = *api.activeUsers.Limit + } + limit := max(currentLimit, claims.Features.UserLimit) + api.activeUsers.Limit = &limit + } + if claims.Features.AuditLog > 0 { + api.auditLogs = entitlement + } + } + if auditLogs != api.auditLogs { + auditor := agplaudit.NewNop() + if api.auditLogs == codersdk.EntitlementEntitled { + auditor = audit.NewAuditor( + audit.DefaultFilter, + backends.NewPostgres(api.Database, true), + backends.NewSlog(api.Logger), + ) + } + api.AGPL.Auditor.Store(&auditor) + } + return nil +} + +func (api *API) entitlements(rw http.ResponseWriter, r *http.Request) { + api.mutex.RLock() + hasLicense := api.hasLicense + activeUsers := api.activeUsers + auditLogs := api.auditLogs + api.mutex.RUnlock() -func NewEnterprise(options *coderd.Options) *coderd.API { - var eOpts = *options - if eOpts.Authorizer == nil { - var err error - eOpts.Authorizer, err = rbac.NewAuthorizer() + resp := codersdk.Entitlements{ + Features: make(map[string]codersdk.Feature), + Warnings: make([]string, 0), + HasLicense: hasLicense, + } + + if activeUsers.Limit != nil { + activeUserCount, err := api.Database.GetActiveUserCount(r.Context()) if err != nil { - // This should never happen, as the unit tests would fail if the - // default built in authorizer failed. - panic(xerrors.Errorf("rego authorize panic: %w", err)) - } - } - eOpts.LicenseHandler = newLicenseAPI( - eOpts.Logger, - eOpts.Database, - eOpts.Pubsub, - &coderd.HTTPAuthorizer{ - Authorizer: eOpts.Authorizer, - Logger: eOpts.Logger, - }).handler() - en := Enablements{AuditLogs: true} - auditLog := os.Getenv(EnvAuditLogEnable) - auditLog = strings.ToLower(auditLog) - if auditLog == "disable" || auditLog == "false" || auditLog == "0" || auditLog == "no" { - en.AuditLogs = false - } - eOpts.FeaturesService = newFeaturesService( - context.Background(), - eOpts.Logger, - eOpts.Database, - eOpts.Pubsub, - en, - ) - return coderd.New(&eOpts) + httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Unable to query database", + Detail: err.Error(), + }) + return + } + activeUsers.Actual = &activeUserCount + if activeUserCount > *activeUsers.Limit { + resp.Warnings = append(resp.Warnings, + fmt.Sprintf( + "Your deployment has %d active users but is only licensed for %d.", + activeUserCount, *activeUsers.Limit)) + } + } + resp.Features[codersdk.FeatureUserLimit] = activeUsers + + // Audit logs + resp.Features[codersdk.FeatureAuditLog] = codersdk.Feature{ + Entitlement: auditLogs, + Enabled: true, + } + if auditLogs == codersdk.EntitlementGracePeriod { + resp.Warnings = append(resp.Warnings, + "Audit logging is enabled but your license for this feature is expired.") + } + + httpapi.Write(rw, http.StatusOK, resp) +} + +func max(a, b int64) int64 { + if a > b { + return a + } + return b } diff --git a/enterprise/coderd/coderd_test.go b/enterprise/coderd/coderd_test.go new file mode 100644 index 0000000000000..114a02d2ff472 --- /dev/null +++ b/enterprise/coderd/coderd_test.go @@ -0,0 +1,93 @@ +package coderd_test + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/goleak" + + "github.com/coder/coder/coderd/coderdtest" + "github.com/coder/coder/codersdk" + "github.com/coder/coder/enterprise/coderd/coderdenttest" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} + +func TestEntitlements(t *testing.T) { + t.Parallel() + t.Run("NoLicense", func(t *testing.T) { + t.Parallel() + client := coderdenttest.New(t, nil) + res, err := client.Entitlements(context.Background()) + require.NoError(t, err) + require.False(t, res.HasLicense) + require.Empty(t, res.Warnings) + }) + t.Run("NoLicense", func(t *testing.T) { + t.Parallel() + client := coderdenttest.New(t, nil) + res, err := client.Entitlements(context.Background()) + require.NoError(t, err) + require.False(t, res.HasLicense) + require.Empty(t, res.Warnings) + }) + t.Run("FullLicense", func(t *testing.T) { + t.Parallel() + client := coderdenttest.New(t, nil) + _ = coderdtest.CreateFirstUser(t, client) + coderdenttest.AddLicense(t, client, coderdenttest.AddLicenseOptions{ + UserLimit: 100, + AuditLog: true, + }) + res, err := client.Entitlements(context.Background()) + require.NoError(t, err) + assert.True(t, res.HasLicense) + ul := res.Features[codersdk.FeatureUserLimit] + assert.Equal(t, codersdk.EntitlementEntitled, ul.Entitlement) + assert.Equal(t, int64(100), *ul.Limit) + assert.Equal(t, int64(1), *ul.Actual) + assert.True(t, ul.Enabled) + al := res.Features[codersdk.FeatureAuditLog] + assert.Equal(t, codersdk.EntitlementEntitled, al.Entitlement) + assert.True(t, al.Enabled) + assert.Nil(t, al.Limit) + assert.Nil(t, al.Actual) + assert.Empty(t, res.Warnings) + }) + t.Run("Warnings", func(t *testing.T) { + t.Parallel() + client := coderdenttest.New(t, nil) + first := coderdtest.CreateFirstUser(t, client) + for i := 0; i < 4; i++ { + coderdtest.CreateAnotherUser(t, client, first.OrganizationID) + } + coderdenttest.AddLicense(t, client, coderdenttest.AddLicenseOptions{ + UserLimit: 4, + AuditLog: true, + GraceAt: time.Now().Add(-time.Second), + }) + res, err := client.Entitlements(context.Background()) + require.NoError(t, err) + assert.True(t, res.HasLicense) + ul := res.Features[codersdk.FeatureUserLimit] + assert.Equal(t, codersdk.EntitlementGracePeriod, ul.Entitlement) + assert.Equal(t, int64(4), *ul.Limit) + assert.Equal(t, int64(5), *ul.Actual) + assert.True(t, ul.Enabled) + al := res.Features[codersdk.FeatureAuditLog] + assert.Equal(t, codersdk.EntitlementGracePeriod, al.Entitlement) + assert.True(t, al.Enabled) + assert.Nil(t, al.Limit) + assert.Nil(t, al.Actual) + assert.Len(t, res.Warnings, 2) + assert.Contains(t, res.Warnings, + "Your deployment has 5 active users but is only licensed for 4.") + assert.Contains(t, res.Warnings, + "Audit logging is enabled but your license for this feature is expired.") + }) +} diff --git a/enterprise/coderd/coderdenttest/coderdenttest.go b/enterprise/coderd/coderdenttest/coderdenttest.go new file mode 100644 index 0000000000000..01c6bcbeeca9a --- /dev/null +++ b/enterprise/coderd/coderdenttest/coderdenttest.go @@ -0,0 +1,124 @@ +package coderdenttest + +import ( + "context" + "crypto/ed25519" + "crypto/rand" + "io" + "testing" + "time" + + "github.com/golang-jwt/jwt/v4" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/coderd/coderdtest" + "github.com/coder/coder/codersdk" + "github.com/coder/coder/enterprise/coderd" +) + +const ( + testKeyID = "enterprise-test" +) + +var ( + testPrivateKey ed25519.PrivateKey + testPublicKey ed25519.PublicKey +) + +func init() { + var err error + testPublicKey, testPrivateKey, err = ed25519.GenerateKey(rand.Reader) + if err != nil { + panic(err) + } +} + +type Options struct { + *coderdtest.Options +} + +// New constructs a codersdk client connected to an in-memory Enterprise API instance. +func New(t *testing.T, options *Options) *codersdk.Client { + client, _, _ := NewWithAPI(t, options) + return client +} + +func NewWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *coderd.API) { + if options == nil { + options = &Options{} + } + if options.Options == nil { + options.Options = &coderdtest.Options{} + } + srv, oop := coderdtest.NewOptions(t, options.Options) + coderAPI, err := coderd.New(context.Background(), &coderd.Options{ + Options: oop, + Keys: map[string]ed25519.PublicKey{ + testKeyID: testPublicKey, + }, + }) + assert.NoError(t, err) + srv.Config.Handler = coderAPI.AGPL.RootHandler + var provisionerCloser io.Closer = nopcloser{} + if options.IncludeProvisionerDaemon { + provisionerCloser = coderdtest.NewProvisionerDaemon(t, coderAPI.AGPL) + } + t.Cleanup(func() { + _ = provisionerCloser.Close() + _ = coderAPI.Close() + }) + return codersdk.New(coderAPI.AccessURL), provisionerCloser, coderAPI +} + +type AddLicenseOptions struct { + AccountType string + AccountID string + GraceAt time.Time + ExpiresAt time.Time + UserLimit int64 + AuditLog bool +} + +// AddLicense generates a new license with the options provided and inserts it. +func AddLicense(t *testing.T, client *codersdk.Client, options AddLicenseOptions) codersdk.License { + if options.ExpiresAt.IsZero() { + options.ExpiresAt = time.Now().Add(time.Hour) + } + if options.GraceAt.IsZero() { + options.GraceAt = time.Now().Add(time.Hour) + } + auditLog := int64(0) + if options.AuditLog { + auditLog = 1 + } + c := &coderd.Claims{ + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: "test@testing.test", + ExpiresAt: jwt.NewNumericDate(options.ExpiresAt), + NotBefore: jwt.NewNumericDate(time.Now().Add(-time.Minute)), + IssuedAt: jwt.NewNumericDate(time.Now().Add(-time.Minute)), + }, + LicenseExpires: jwt.NewNumericDate(options.GraceAt), + AccountType: options.AccountType, + AccountID: options.AccountID, + Version: coderd.CurrentVersion, + Features: coderd.Features{ + UserLimit: options.UserLimit, + AuditLog: auditLog, + }, + } + tok := jwt.NewWithClaims(jwt.SigningMethodEdDSA, c) + tok.Header[coderd.HeaderKeyID] = testKeyID + signedTok, err := tok.SignedString(testPrivateKey) + require.NoError(t, err) + license, err := client.AddLicense(context.Background(), codersdk.AddLicenseRequest{ + License: signedTok, + }) + require.NoError(t, err) + return license +} + +type nopcloser struct{} + +func (nopcloser) Close() error { return nil } diff --git a/coderd/coderdtest/authtest.go b/enterprise/coderd/coderdenttest/coderdenttest_test.go similarity index 90% rename from coderd/coderdtest/authtest.go rename to enterprise/coderd/coderdenttest/coderdenttest_test.go index 6eb3df8ac6bc5..41b426f8db8b6 100644 --- a/coderd/coderdtest/authtest.go +++ b/enterprise/coderd/coderdenttest/coderdenttest_test.go @@ -1,4 +1,4 @@ -package coderdtest +package coderdenttest_test import ( "context" @@ -14,152 +14,25 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/xerrors" - "github.com/coder/coder/coderd" + "github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/coderd/rbac" "github.com/coder/coder/codersdk" + "github.com/coder/coder/enterprise/coderd" + "github.com/coder/coder/enterprise/coderd/coderdenttest" "github.com/coder/coder/provisioner/echo" "github.com/coder/coder/provisionersdk/proto" "github.com/coder/coder/testutil" ) -type RouteCheck struct { - NoAuthorize bool - AssertAction rbac.Action - AssertObject rbac.Object - StatusCode int +func TestNew(t *testing.T) { + t.Parallel() + _ = coderdenttest.New(t, nil) } -type AuthTester struct { - t *testing.T - api *coderd.API - authorizer *recordingAuthorizer - - Client *codersdk.Client - Workspace codersdk.Workspace - Organization codersdk.Organization - Admin codersdk.CreateFirstUserResponse - Template codersdk.Template - Version codersdk.TemplateVersion - WorkspaceResource codersdk.WorkspaceResource - File codersdk.UploadResponse - TemplateVersionDryRun codersdk.ProvisionerJob - TemplateParam codersdk.Parameter - URLParams map[string]string -} - -func NewAuthTester(ctx context.Context, t *testing.T, options *Options) *AuthTester { - authorizer := &recordingAuthorizer{} - if options == nil { - options = &Options{} - } - if options.Authorizer != nil { - t.Error("NewAuthTester cannot be called with custom Authorizer") - } - options.Authorizer = authorizer - options.IncludeProvisionerDaemon = true - - client, _, api := newWithAPI(t, options) - admin := CreateFirstUser(t, client) - // The provisioner will call to coderd and register itself. This is async, - // so we wait for it to occur. - require.Eventually(t, func() bool { - provisionerds, err := client.ProvisionerDaemons(ctx) - return assert.NoError(t, err) && len(provisionerds) > 0 - }, testutil.WaitLong, testutil.IntervalSlow) - - provisionerds, err := client.ProvisionerDaemons(ctx) - require.NoError(t, err, "fetch provisioners") - require.Len(t, provisionerds, 1) - - organization, err := client.Organization(ctx, admin.OrganizationID) - require.NoError(t, err, "fetch org") - - // Setup some data in the database. - version := CreateTemplateVersion(t, client, admin.OrganizationID, &echo.Responses{ - Parse: echo.ParseComplete, - Provision: []*proto.Provision_Response{{ - Type: &proto.Provision_Response_Complete{ - Complete: &proto.Provision_Complete{ - // Return a workspace resource - Resources: []*proto.Resource{{ - Name: "some", - Type: "example", - Agents: []*proto.Agent{{ - Name: "agent", - Id: "something", - Auth: &proto.Agent_Token{}, - Apps: []*proto.App{{ - Name: "testapp", - Url: "http://localhost:3000", - }}, - }}, - }}, - }, - }, - }}, - }) - AwaitTemplateVersionJob(t, client, version.ID) - template := CreateTemplate(t, client, admin.OrganizationID, version.ID) - workspace := CreateWorkspace(t, client, admin.OrganizationID, template.ID) - AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) - file, err := client.Upload(ctx, codersdk.ContentTypeTar, make([]byte, 1024)) - require.NoError(t, err, "upload file") - workspaceResources, err := client.WorkspaceResourcesByBuild(ctx, workspace.LatestBuild.ID) - require.NoError(t, err, "workspace resources") - templateVersionDryRun, err := client.CreateTemplateVersionDryRun(ctx, version.ID, codersdk.CreateTemplateVersionDryRunRequest{ - ParameterValues: []codersdk.CreateParameterRequest{}, - }) - require.NoError(t, err, "template version dry-run") - - templateParam, err := client.CreateParameter(ctx, codersdk.ParameterTemplate, template.ID, codersdk.CreateParameterRequest{ - Name: "test-param", - SourceValue: "hello world", - SourceScheme: codersdk.ParameterSourceSchemeData, - DestinationScheme: codersdk.ParameterDestinationSchemeProvisionerVariable, - }) - require.NoError(t, err, "create template param") +func TestAuthorizeAllEndpoints(t *testing.T) { + t.Parallel() + a := newAuthTester(context.Background(), t) - urlParameters := map[string]string{ - "{organization}": admin.OrganizationID.String(), - "{user}": admin.UserID.String(), - "{organizationname}": organization.Name, - "{workspace}": workspace.ID.String(), - "{workspacebuild}": workspace.LatestBuild.ID.String(), - "{workspacename}": workspace.Name, - "{workspaceagent}": workspaceResources[0].Agents[0].ID.String(), - "{buildnumber}": strconv.FormatInt(int64(workspace.LatestBuild.BuildNumber), 10), - "{template}": template.ID.String(), - "{hash}": file.Hash, - "{workspaceresource}": workspaceResources[0].ID.String(), - "{workspaceapp}": workspaceResources[0].Agents[0].Apps[0].Name, - "{templateversion}": version.ID.String(), - "{jobID}": templateVersionDryRun.ID.String(), - "{templatename}": template.Name, - "{workspace_and_agent}": workspace.Name + "." + workspaceResources[0].Agents[0].Name, - // Only checking template scoped params here - "parameters/{scope}/{id}": fmt.Sprintf("parameters/%s/%s", - string(templateParam.Scope), templateParam.ScopeID.String()), - } - - return &AuthTester{ - t: t, - api: api, - authorizer: authorizer, - Client: client, - Workspace: workspace, - Organization: organization, - Admin: admin, - Template: template, - Version: version, - WorkspaceResource: workspaceResources[0], - File: file, - TemplateVersionDryRun: templateVersionDryRun, - TemplateParam: templateParam, - URLParams: urlParameters, - } -} - -func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) { // Some quick reused objects workspaceRBACObj := rbac.ResourceWorkspace.InOrg(a.Organization.ID).WithOwner(a.Workspace.OwnerID.String()) workspaceExecObj := rbac.ResourceWorkspaceExecution.InOrg(a.Organization.ID).WithOwner(a.Workspace.OwnerID.String()) @@ -170,7 +43,7 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) { "GET:/derp/latency-check": "This always returns a 200!", } - assertRoute := map[string]RouteCheck{ + assertRoute := map[string]routeCheck{ // These endpoints do not require auth "GET:/api/v2": {NoAuthorize: true}, "GET:/api/v2/buildinfo": {NoAuthorize: true}, @@ -391,11 +264,26 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) { "PUT:/api/v2/organizations/{organization}/members/{user}/roles": {NoAuthorize: true}, "POST:/api/v2/workspaces/{workspace}/builds": {StatusCode: http.StatusBadRequest, NoAuthorize: true}, "POST:/api/v2/organizations/{organization}/templateversions": {StatusCode: http.StatusBadRequest, NoAuthorize: true}, + + // Enterprise only endpoints + "POST:/api/v2/licenses": { + AssertAction: rbac.ActionCreate, + AssertObject: rbac.ResourceLicense, + }, + "GET:/api/v2/licenses": { + StatusCode: http.StatusOK, + AssertAction: rbac.ActionRead, + AssertObject: rbac.ResourceLicense, + }, + "DELETE:/api/v2/licenses/{id}": { + AssertAction: rbac.ActionDelete, + AssertObject: rbac.ResourceLicense, + }, } // Routes like proxy routes support all HTTP methods. A helper func to expand // 1 url to all http methods. - assertAllHTTPMethods := func(url string, check RouteCheck) { + assertAllHTTPMethods := func(url string, check routeCheck) { methods := []string{http.MethodGet, http.MethodHead, http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete, http.MethodConnect, http.MethodOptions, http.MethodTrace} @@ -406,19 +294,155 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) { } } - assertAllHTTPMethods("/%40{user}/{workspace_and_agent}/apps/{workspaceapp}/*", RouteCheck{ + assertAllHTTPMethods("/%40{user}/{workspace_and_agent}/apps/{workspaceapp}/*", routeCheck{ AssertAction: rbac.ActionCreate, AssertObject: workspaceExecObj, }) - assertAllHTTPMethods("/@{user}/{workspace_and_agent}/apps/{workspaceapp}/*", RouteCheck{ + assertAllHTTPMethods("/@{user}/{workspace_and_agent}/apps/{workspaceapp}/*", routeCheck{ AssertAction: rbac.ActionCreate, AssertObject: workspaceExecObj, }) - return skipRoutes, assertRoute + a.Test(context.Background(), assertRoute, skipRoutes) +} + +type routeCheck struct { + NoAuthorize bool + AssertAction rbac.Action + AssertObject rbac.Object + StatusCode int +} + +type authTester struct { + t *testing.T + api *coderd.API + authorizer *recordingAuthorizer + + Client *codersdk.Client + Workspace codersdk.Workspace + Organization codersdk.Organization + Admin codersdk.CreateFirstUserResponse + Template codersdk.Template + Version codersdk.TemplateVersion + WorkspaceResource codersdk.WorkspaceResource + File codersdk.UploadResponse + TemplateVersionDryRun codersdk.ProvisionerJob + TemplateParam codersdk.Parameter + URLParams map[string]string +} + +func newAuthTester(ctx context.Context, t *testing.T) *authTester { + authorizer := &recordingAuthorizer{} + options := &coderdenttest.Options{ + Options: &coderdtest.Options{ + Authorizer: authorizer, + IncludeProvisionerDaemon: true, + }, + } + + client, _, api := coderdenttest.NewWithAPI(t, options) + admin := coderdtest.CreateFirstUser(t, client) + // The provisioner will call to coderd and register itself. This is async, + // so we wait for it to occur. + require.Eventually(t, func() bool { + provisionerds, err := client.ProvisionerDaemons(ctx) + return assert.NoError(t, err) && len(provisionerds) > 0 + }, testutil.WaitLong, testutil.IntervalSlow) + + provisionerds, err := client.ProvisionerDaemons(ctx) + require.NoError(t, err, "fetch provisioners") + require.Len(t, provisionerds, 1) + + organization, err := client.Organization(ctx, admin.OrganizationID) + require.NoError(t, err, "fetch org") + + // Setup some data in the database. + version := coderdtest.CreateTemplateVersion(t, client, admin.OrganizationID, &echo.Responses{ + Parse: echo.ParseComplete, + Provision: []*proto.Provision_Response{{ + Type: &proto.Provision_Response_Complete{ + Complete: &proto.Provision_Complete{ + // Return a workspace resource + Resources: []*proto.Resource{{ + Name: "some", + Type: "example", + Agents: []*proto.Agent{{ + Name: "agent", + Id: "something", + Auth: &proto.Agent_Token{}, + Apps: []*proto.App{{ + Name: "testapp", + Url: "http://localhost:3000", + }}, + }}, + }}, + }, + }, + }}, + }) + coderdtest.AwaitTemplateVersionJob(t, client, version.ID) + template := coderdtest.CreateTemplate(t, client, admin.OrganizationID, version.ID) + workspace := coderdtest.CreateWorkspace(t, client, admin.OrganizationID, template.ID) + coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) + file, err := client.Upload(ctx, codersdk.ContentTypeTar, make([]byte, 1024)) + require.NoError(t, err, "upload file") + workspaceResources, err := client.WorkspaceResourcesByBuild(ctx, workspace.LatestBuild.ID) + require.NoError(t, err, "workspace resources") + templateVersionDryRun, err := client.CreateTemplateVersionDryRun(ctx, version.ID, codersdk.CreateTemplateVersionDryRunRequest{ + ParameterValues: []codersdk.CreateParameterRequest{}, + }) + require.NoError(t, err, "template version dry-run") + + templateParam, err := client.CreateParameter(ctx, codersdk.ParameterTemplate, template.ID, codersdk.CreateParameterRequest{ + Name: "test-param", + SourceValue: "hello world", + SourceScheme: codersdk.ParameterSourceSchemeData, + DestinationScheme: codersdk.ParameterDestinationSchemeProvisionerVariable, + }) + require.NoError(t, err, "create template param") + license := coderdenttest.AddLicense(t, client, coderdenttest.AddLicenseOptions{}) + urlParameters := map[string]string{ + "{organization}": admin.OrganizationID.String(), + "{user}": admin.UserID.String(), + "{organizationname}": organization.Name, + "{workspace}": workspace.ID.String(), + "{workspacebuild}": workspace.LatestBuild.ID.String(), + "{workspacename}": workspace.Name, + "{workspaceagent}": workspaceResources[0].Agents[0].ID.String(), + "{buildnumber}": strconv.FormatInt(int64(workspace.LatestBuild.BuildNumber), 10), + "{template}": template.ID.String(), + "{hash}": file.Hash, + "{workspaceresource}": workspaceResources[0].ID.String(), + "{workspaceapp}": workspaceResources[0].Agents[0].Apps[0].Name, + "{templateversion}": version.ID.String(), + "{jobID}": templateVersionDryRun.ID.String(), + "{templatename}": template.Name, + "{workspace_and_agent}": workspace.Name + "." + workspaceResources[0].Agents[0].Name, + // Only checking template scoped params here + "parameters/{scope}/{id}": fmt.Sprintf("parameters/%s/%s", + string(templateParam.Scope), templateParam.ScopeID.String()), + "licenses/{id}": fmt.Sprintf("licenses/%d", license.ID), + } + + return &authTester{ + t: t, + api: api, + authorizer: authorizer, + Client: client, + Workspace: workspace, + Organization: organization, + Admin: admin, + Template: template, + Version: version, + WorkspaceResource: workspaceResources[0], + File: file, + TemplateVersionDryRun: templateVersionDryRun, + TemplateParam: templateParam, + URLParams: urlParameters, + } } -func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck, skipRoutes map[string]string) { +func (a *authTester) Test(ctx context.Context, assertRoute map[string]routeCheck, skipRoutes map[string]string) { // Always fail auth from this point forward a.authorizer.AlwaysReturn = rbac.ForbiddenWithInternal(xerrors.New("fake implementation"), nil, nil) @@ -443,7 +467,7 @@ func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck } err := chi.Walk( - a.api.Handler, + a.api.AGPL.RootHandler, func( method string, route string, @@ -466,7 +490,7 @@ func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck routeAssertions, ok := assertRoute[routeKey] if !ok { // By default, all omitted routes check for just "authorize" called - routeAssertions = RouteCheck{} + routeAssertions = routeCheck{} } delete(routeMissing, routeKey) diff --git a/enterprise/coderd/features.go b/enterprise/coderd/features.go deleted file mode 100644 index bc9977ff18441..0000000000000 --- a/enterprise/coderd/features.go +++ /dev/null @@ -1,327 +0,0 @@ -package coderd - -import ( - "context" - "crypto/ed25519" - "fmt" - "net/http" - "reflect" - "sync" - "time" - - "github.com/coder/coder/enterprise/audit/backends" - - "github.com/cenkalti/backoff/v4" - "golang.org/x/xerrors" - - "cdr.dev/slog" - - agpl "github.com/coder/coder/coderd" - agplAudit "github.com/coder/coder/coderd/audit" - "github.com/coder/coder/coderd/database" - "github.com/coder/coder/coderd/features" - "github.com/coder/coder/coderd/httpapi" - "github.com/coder/coder/codersdk" - "github.com/coder/coder/enterprise/audit" -) - -type Enablements struct { - AuditLogs bool -} - -type featuresService struct { - logger slog.Logger - database database.Store - pubsub database.Pubsub - keys map[string]ed25519.PublicKey - enablements Enablements - resyncInterval time.Duration - // enabledImplementations includes an "enabled" implementation of every feature. This is - // initialized at start of day and remains static. The consequence of this is that these things - // are hanging around using memory even if not licensed or in use, but it greatly simplifies the - // logic because we don't have to bother creating and destroying them as entitlements change. - // If we have a particularly memory-hungry feature in future, we might wish to reconsider this - // choice. - enabledImplementations agpl.FeatureInterfaces - - mu sync.RWMutex - entitlements entitlements -} - -// newFeaturesService creates a FeaturesService and starts it. It will continue running for the -// duration of the passed ctx. -func newFeaturesService( - ctx context.Context, - logger slog.Logger, - db database.Store, - pubsub database.Pubsub, - enablements Enablements, -) features.Service { - fs := &featuresService{ - logger: logger, - database: db, - pubsub: pubsub, - keys: keys, - enablements: enablements, - enabledImplementations: agpl.FeatureInterfaces{ - Auditor: audit.NewAuditor( - audit.DefaultFilter, - backends.NewPostgres(db, true), - backends.NewSlog(logger), - ), - }, - resyncInterval: 10 * time.Minute, - entitlements: entitlements{ - activeUsers: numericalEntitlement{ - entitlementLimit: entitlementLimit{ - unlimited: true, - }, - }, - }, - } - go fs.syncEntitlements(ctx) - return fs -} - -func (s *featuresService) EntitlementsAPI(rw http.ResponseWriter, r *http.Request) { - s.mu.RLock() - e := s.entitlements - s.mu.RUnlock() - - resp := codersdk.Entitlements{ - Features: make(map[string]codersdk.Feature), - Warnings: make([]string, 0), - HasLicense: e.hasLicense, - } - - // User limit - uf := codersdk.Feature{ - Entitlement: e.activeUsers.state.toSDK(), - Enabled: true, - } - if !e.activeUsers.unlimited { - n, err := s.database.GetActiveUserCount(r.Context()) - if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Unable to query database", - Detail: err.Error(), - }) - return - } - uf.Actual = &n - uf.Limit = &e.activeUsers.limit - if n > e.activeUsers.limit { - resp.Warnings = append(resp.Warnings, - fmt.Sprintf( - "Your deployment has %d active users but is only licensed for %d.", - n, e.activeUsers.limit)) - } - } - resp.Features[codersdk.FeatureUserLimit] = uf - - // Audit logs - resp.Features[codersdk.FeatureAuditLog] = codersdk.Feature{ - Entitlement: e.auditLogs.state.toSDK(), - Enabled: s.enablements.AuditLogs, - } - if e.auditLogs.state == gracePeriod && s.enablements.AuditLogs { - resp.Warnings = append(resp.Warnings, - "Audit logging is enabled but your license for this feature is expired.") - } - - httpapi.Write(rw, http.StatusOK, resp) -} - -type entitlementState int - -const ( - notEntitled entitlementState = iota - gracePeriod - entitled -) - -type entitlementLimit struct { - unlimited bool - limit int64 -} - -type entitlement struct { - state entitlementState -} - -func (s entitlementState) toSDK() codersdk.Entitlement { - switch s { - case notEntitled: - return codersdk.EntitlementNotEntitled - case gracePeriod: - return codersdk.EntitlementGracePeriod - case entitled: - return codersdk.EntitlementEntitled - default: - panic("unknown entitlementState") - } -} - -type numericalEntitlement struct { - entitlement - entitlementLimit -} - -type entitlements struct { - hasLicense bool - activeUsers numericalEntitlement - auditLogs entitlement -} - -func (s *featuresService) getEntitlements(ctx context.Context) (entitlements, error) { - licenses, err := s.database.GetUnexpiredLicenses(ctx) - if err != nil { - return entitlements{}, err - } - now := time.Now() - e := entitlements{ - activeUsers: numericalEntitlement{ - entitlementLimit: entitlementLimit{ - unlimited: true, - }, - }, - } - for _, l := range licenses { - claims, err := validateDBLicense(l, s.keys) - if err != nil { - s.logger.Debug(ctx, "skipping invalid license", - slog.F("id", l.ID), slog.Error(err)) - continue - } - e.hasLicense = true - thisEntitlement := entitled - if now.After(claims.LicenseExpires.Time) { - // if the grace period were over, the validation fails, so if we are after - // LicenseExpires we must be in grace period. - thisEntitlement = gracePeriod - } - if claims.Features.UserLimit > 0 { - e.activeUsers.state = thisEntitlement - e.activeUsers.unlimited = false - e.activeUsers.limit = max(e.activeUsers.limit, claims.Features.UserLimit) - } - if claims.Features.AuditLog > 0 { - e.auditLogs.state = thisEntitlement - } - } - return e, nil -} - -func (s *featuresService) syncEntitlements(ctx context.Context) { - eb := backoff.NewExponentialBackOff() - eb.MaxElapsedTime = 0 // retry indefinitely - b := backoff.WithContext(eb, ctx) - updates := make(chan struct{}, 1) - subscribed := false - - for { - select { - case <-ctx.Done(): - return - default: - // pass - } - if !subscribed { - cancel, err := s.pubsub.Subscribe(PubSubEventLicenses, func(_ context.Context, _ []byte) { - // don't block. If the channel is full, drop the event, as there is a resync - // scheduled already. - select { - case updates <- struct{}{}: - // pass - default: - // pass - } - }) - if err != nil { - s.logger.Warn(ctx, "failed to subscribe to license updates", slog.Error(err)) - time.Sleep(b.NextBackOff()) - continue - } - // nolint: revive - defer cancel() - subscribed = true - s.logger.Debug(ctx, "successfully subscribed to pubsub") - } - - s.logger.Info(ctx, "syncing licensed entitlements") - ents, err := s.getEntitlements(ctx) - if err != nil { - s.logger.Warn(ctx, "failed to get feature entitlements", slog.Error(err)) - time.Sleep(b.NextBackOff()) - continue - } - b.Reset() - - s.mu.Lock() - s.entitlements = ents - s.mu.Unlock() - s.logger.Debug(ctx, "synced licensed entitlements") - - select { - case <-ctx.Done(): - return - case <-time.After(s.resyncInterval): - continue - case <-updates: - s.logger.Debug(ctx, "got pubsub update") - continue - } - } -} - -func max(a, b int64) int64 { - if a > b { - return a - } - return b -} - -func (s *featuresService) Get(ps any) error { - if reflect.TypeOf(ps).Kind() != reflect.Pointer { - return xerrors.New("input must be pointer to struct") - } - vs := reflect.ValueOf(ps).Elem() - if vs.Kind() != reflect.Struct { - return xerrors.New("input must be pointer to struct") - } - // grab a local copy of entitlements so that we have a consistent set, but aren't keeping it - // locked from updates while we process. - s.mu.RLock() - ent := s.entitlements - s.mu.RUnlock() - - for i := 0; i < vs.NumField(); i++ { - vf := vs.Field(i) - tf := vf.Type() - if tf.Kind() != reflect.Interface { - return xerrors.Errorf("fields of input struct must be interfaces: %s", tf.String()) - } - - err := s.setImplementation(ent, vf, tf) - if err != nil { - return err - } - } - return nil -} - -func (s *featuresService) setImplementation(ent entitlements, vf reflect.Value, tf reflect.Type) error { - // c.f. https://stackoverflow.com/questions/7132848/how-to-get-the-reflect-type-of-an-interface - switch tf { - case reflect.TypeOf((*agplAudit.Auditor)(nil)).Elem(): - // Audit logging - if !s.enablements.AuditLogs || ent.auditLogs.state == notEntitled { - vf.Set(reflect.ValueOf(agpl.DisabledImplementations.Auditor)) - return nil - } - vf.Set(reflect.ValueOf(s.enabledImplementations.Auditor)) - return nil - default: - return xerrors.Errorf("unable to find implementation of interface %s", tf.String()) - } -} diff --git a/enterprise/coderd/features_internal_test.go b/enterprise/coderd/features_internal_test.go deleted file mode 100644 index a195c2ffe784b..0000000000000 --- a/enterprise/coderd/features_internal_test.go +++ /dev/null @@ -1,545 +0,0 @@ -package coderd - -import ( - "context" - "crypto/ed25519" - "crypto/rand" - "encoding/json" - "net/http" - "net/http/httptest" - "reflect" - "testing" - "time" - - "github.com/golang-jwt/jwt/v4" - "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "cdr.dev/slog/sloggers/slogtest" - - agplCoderd "github.com/coder/coder/coderd" - agplAudit "github.com/coder/coder/coderd/audit" - "github.com/coder/coder/coderd/database" - "github.com/coder/coder/coderd/database/databasefake" - "github.com/coder/coder/coderd/features" - "github.com/coder/coder/codersdk" - "github.com/coder/coder/enterprise/audit" - "github.com/coder/coder/enterprise/audit/backends" - "github.com/coder/coder/testutil" -) - -func TestFeaturesService_EntitlementsAPI(t *testing.T) { - t.Parallel() - logger := slogtest.Make(t, nil) - - // Note that these are not actually used because we don't run the syncEntitlements - // routine in this test. - pubsub := database.NewPubsubInMemory() - pub, _, err := ed25519.GenerateKey(rand.Reader) - require.NoError(t, err) - keyID := "testing" - - t.Run("NoLicense", func(t *testing.T) { - t.Parallel() - db := databasefake.New() - uut := &featuresService{ - logger: logger, - database: db, - pubsub: pubsub, - keys: map[string]ed25519.PublicKey{keyID: pub}, - enablements: Enablements{AuditLogs: true}, - entitlements: entitlements{ - hasLicense: false, - activeUsers: numericalEntitlement{ - entitlement{notEntitled}, - entitlementLimit{ - unlimited: true, - }, - }, - auditLogs: entitlement{notEntitled}, - }, - } - result := requestEntitlements(t, uut) - assert.False(t, result.HasLicense) - assert.Empty(t, result.Warnings) - assert.Equal(t, codersdk.EntitlementNotEntitled, result.Features[codersdk.FeatureUserLimit].Entitlement) - assert.Equal(t, codersdk.EntitlementNotEntitled, result.Features[codersdk.FeatureAuditLog].Entitlement) - }) - - t.Run("FullLicense", func(t *testing.T) { - t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - defer cancel() - db := databasefake.New() - uut := &featuresService{ - logger: logger, - database: db, - pubsub: pubsub, - keys: map[string]ed25519.PublicKey{keyID: pub}, - enablements: Enablements{AuditLogs: true}, - entitlements: entitlements{ - hasLicense: true, - activeUsers: numericalEntitlement{ - entitlement{entitled}, - entitlementLimit{ - unlimited: false, - limit: 100, - }, - }, - auditLogs: entitlement{entitled}, - }, - } - _, err := db.InsertUser(ctx, database.InsertUserParams{ - ID: uuid.UUID{}, - Email: "", - Username: "", - HashedPassword: nil, - CreatedAt: time.Time{}, - UpdatedAt: time.Time{}, - RBACRoles: nil, - LoginType: "", - }) - require.NoError(t, err) - result := requestEntitlements(t, uut) - assert.True(t, result.HasLicense) - ul := result.Features[codersdk.FeatureUserLimit] - assert.Equal(t, codersdk.EntitlementEntitled, ul.Entitlement) - assert.Equal(t, int64(100), *ul.Limit) - assert.Equal(t, int64(1), *ul.Actual) - assert.True(t, ul.Enabled) - al := result.Features[codersdk.FeatureAuditLog] - assert.Equal(t, codersdk.EntitlementEntitled, al.Entitlement) - assert.True(t, al.Enabled) - assert.Nil(t, al.Limit) - assert.Nil(t, al.Actual) - assert.Empty(t, result.Warnings) - }) - - t.Run("Warnings", func(t *testing.T) { - t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - defer cancel() - db := databasefake.New() - uut := &featuresService{ - logger: logger, - database: db, - pubsub: pubsub, - keys: map[string]ed25519.PublicKey{keyID: pub}, - enablements: Enablements{AuditLogs: true}, - entitlements: entitlements{ - hasLicense: true, - activeUsers: numericalEntitlement{ - entitlement{gracePeriod}, - entitlementLimit{ - unlimited: false, - limit: 4, - }, - }, - auditLogs: entitlement{gracePeriod}, - }, - } - for i := byte(0); i < 5; i++ { - _, err := db.InsertUser(ctx, database.InsertUserParams{ - ID: uuid.UUID{i}, - Email: "", - Username: "", - HashedPassword: nil, - CreatedAt: time.Time{}, - UpdatedAt: time.Time{}, - RBACRoles: nil, - LoginType: "", - }) - require.NoError(t, err) - } - result := requestEntitlements(t, uut) - assert.True(t, result.HasLicense) - ul := result.Features[codersdk.FeatureUserLimit] - assert.Equal(t, codersdk.EntitlementGracePeriod, ul.Entitlement) - assert.Equal(t, int64(4), *ul.Limit) - assert.Equal(t, int64(5), *ul.Actual) - assert.True(t, ul.Enabled) - al := result.Features[codersdk.FeatureAuditLog] - assert.Equal(t, codersdk.EntitlementGracePeriod, al.Entitlement) - assert.True(t, al.Enabled) - assert.Nil(t, al.Limit) - assert.Nil(t, al.Actual) - assert.Len(t, result.Warnings, 2) - assert.Contains(t, result.Warnings, - "Your deployment has 5 active users but is only licensed for 4.") - assert.Contains(t, result.Warnings, - "Audit logging is enabled but your license for this feature is expired.") - }) -} - -func TestFeaturesServiceSyncEntitlements(t *testing.T) { - t.Parallel() - pub, priv, err := ed25519.GenerateKey(rand.Reader) - require.NoError(t, err) - keyID := "testing" - - // This tests that pubsub updates work by setting the resync interval very long - t.Run("PubSub", func(t *testing.T) { - t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - defer cancel() - logger := slogtest.Make(t, nil) - pubsub := database.NewPubsubInMemory() - db := databasefake.New() - uut := &featuresService{ - logger: logger, - database: db, - pubsub: pubsub, - keys: map[string]ed25519.PublicKey{keyID: pub}, - enablements: Enablements{AuditLogs: true}, - resyncInterval: time.Hour, // no resyncs during test - entitlements: entitlements{}, - } - - _, invalidKey, err := ed25519.GenerateKey(rand.Reader) - require.NoError(t, err) - - // Start of day, 3 licenses, one expired, one invalid - _ = putLicense(ctx, t, db, priv, keyID, 1000, -2*time.Hour, -1*time.Hour) - _ = putLicense(ctx, t, db, invalidKey, "invalid", 900, time.Hour, 2*time.Hour) - l0 := putLicense(ctx, t, db, priv, keyID, 300, time.Hour, 2*time.Hour) - - go uut.syncEntitlements(ctx) - - testutil.Eventually(ctx, t, userLimitIs(uut, 300), testutil.IntervalFast) - - // New license - l1 := putLicense(ctx, t, db, priv, keyID, 305, time.Hour, 2*time.Hour) - err = pubsub.Publish(PubSubEventLicenses, []byte("add")) - require.NoError(t, err) - - // User limit goes up, because 305 > 300 - testutil.Eventually(ctx, t, userLimitIs(uut, 305), testutil.IntervalFast) - - // New license with lower limit - _ = putLicense(ctx, t, db, priv, keyID, 295, time.Hour, 2*time.Hour) - err = pubsub.Publish(PubSubEventLicenses, []byte("add")) - require.NoError(t, err) - - // Need to delete the others before the limit lowers - _, err = db.DeleteLicense(ctx, l1.ID) - require.NoError(t, err) - err = pubsub.Publish(PubSubEventLicenses, []byte("delete")) - require.NoError(t, err) - testutil.Eventually(ctx, t, userLimitIs(uut, 300), testutil.IntervalFast) - - _, err = db.DeleteLicense(ctx, l0.ID) - require.NoError(t, err) - err = pubsub.Publish(PubSubEventLicenses, []byte("delete")) - require.NoError(t, err) - testutil.Eventually(ctx, t, userLimitIs(uut, 295), testutil.IntervalFast) - }) - - // This tests that periodic resyncs work by setting the resync interval very fast and - // not sending any pubsub updates. - t.Run("Resyncs", func(t *testing.T) { - t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - defer cancel() - logger := slogtest.Make(t, nil) - pubsub := database.NewPubsubInMemory() - db := databasefake.New() - uut := &featuresService{ - logger: logger, - database: db, - pubsub: pubsub, - keys: map[string]ed25519.PublicKey{keyID: pub}, - enablements: Enablements{AuditLogs: true}, - resyncInterval: 10 * time.Millisecond, - entitlements: entitlements{}, - } - - _, invalidKey, err := ed25519.GenerateKey(rand.Reader) - require.NoError(t, err) - - // Start of day, 3 licenses, one expired, one invalid - _ = putLicense(ctx, t, db, priv, keyID, 1000, -2*time.Hour, -1*time.Hour) - _ = putLicense(ctx, t, db, invalidKey, "invalid", 900, time.Hour, 2*time.Hour) - l0 := putLicense(ctx, t, db, priv, keyID, 300, time.Hour, 2*time.Hour) - - go uut.syncEntitlements(ctx) - - testutil.Eventually(ctx, t, userLimitIs(uut, 300), testutil.IntervalFast) - - // New license - l1 := putLicense(ctx, t, db, priv, keyID, 305, time.Hour, 2*time.Hour) - - // User limit goes up, because 305 > 300 - testutil.Eventually(ctx, t, userLimitIs(uut, 305), testutil.IntervalFast) - - // New license with lower limit - _ = putLicense(ctx, t, db, priv, keyID, 295, time.Hour, 2*time.Hour) - - // Need to delete the others before the limit lowers - _, err = db.DeleteLicense(ctx, l1.ID) - require.NoError(t, err) - testutil.Eventually(ctx, t, userLimitIs(uut, 300), testutil.IntervalFast) - - _, err = db.DeleteLicense(ctx, l0.ID) - require.NoError(t, err) - testutil.Eventually(ctx, t, userLimitIs(uut, 295), testutil.IntervalFast) - }) -} - -func requestEntitlements(t *testing.T, uut features.Service) codersdk.Entitlements { - t.Helper() - r := httptest.NewRequest("GET", "https://example.com/api/v2/entitlements", nil) - rw := httptest.NewRecorder() - uut.EntitlementsAPI(rw, r) - resp := rw.Result() - defer resp.Body.Close() - assert.Equal(t, http.StatusOK, resp.StatusCode) - dec := json.NewDecoder(resp.Body) - var result codersdk.Entitlements - err := dec.Decode(&result) - require.NoError(t, err) - return result -} - -func putLicense( - ctx context.Context, t *testing.T, db database.Store, - k ed25519.PrivateKey, keyID string, userLimit int64, - timeToGrace, timeToExpire time.Duration, -) database.License { - t.Helper() - c := &Claims{ - RegisteredClaims: jwt.RegisteredClaims{ - Issuer: "test@testing.test", - ExpiresAt: jwt.NewNumericDate(time.Now().Add(timeToExpire)), - NotBefore: jwt.NewNumericDate(time.Now().Add(-time.Minute)), - IssuedAt: jwt.NewNumericDate(time.Now().Add(-time.Minute)), - }, - LicenseExpires: jwt.NewNumericDate(time.Now().Add(timeToGrace)), - Version: CurrentVersion, - Features: Features{ - UserLimit: userLimit, - AuditLog: 1, - }, - } - j, err := makeLicense(c, k, keyID) - require.NoError(t, err) - l, err := db.InsertLicense(ctx, database.InsertLicenseParams{ - UploadedAt: c.IssuedAt.Time, - JWT: j, - Exp: c.ExpiresAt.Time, - }) - require.NoError(t, err) - return l -} - -func userLimitIs(fs *featuresService, limit int64) func(context.Context) bool { - return func(_ context.Context) bool { - fs.mu.RLock() - defer fs.mu.RUnlock() - return fs.entitlements.activeUsers.limit == limit - } -} - -func TestFeaturesServiceGet(t *testing.T) { - t.Parallel() - logger := slogtest.Make(t, nil) - - // Note that these are not actually used because we don't run the syncEntitlements - // routine in this test. - pubsub := database.NewPubsubInMemory() - pub, _, err := ed25519.GenerateKey(rand.Reader) - require.NoError(t, err) - keyID := "testing" - db := databasefake.New() - - t.Run("AuditorOff", func(t *testing.T) { - t.Parallel() - uut := &featuresService{ - logger: logger, - database: db, - pubsub: pubsub, - keys: map[string]ed25519.PublicKey{keyID: pub}, - enablements: Enablements{AuditLogs: true}, - enabledImplementations: agplCoderd.FeatureInterfaces{ - Auditor: audit.NewAuditor(audit.DefaultFilter), - }, - entitlements: entitlements{ - hasLicense: false, - activeUsers: numericalEntitlement{ - entitlement{notEntitled}, - entitlementLimit{ - unlimited: true, - }, - }, - auditLogs: entitlement{notEntitled}, - }, - } - target := struct { - Auditor agplAudit.Auditor - }{} - err := uut.Get(&target) - require.NoError(t, err) - assert.NotNil(t, target.Auditor) - nop := agplAudit.NewNop() - assert.Equal(t, reflect.ValueOf(nop).Type(), reflect.ValueOf(target.Auditor).Type()) - }) - - t.Run("AuditorOn", func(t *testing.T) { - t.Parallel() - uut := &featuresService{ - logger: logger, - database: db, - pubsub: pubsub, - keys: map[string]ed25519.PublicKey{keyID: pub}, - enablements: Enablements{AuditLogs: true}, - enabledImplementations: agplCoderd.FeatureInterfaces{ - Auditor: audit.NewAuditor(audit.DefaultFilter), - }, - entitlements: entitlements{ - hasLicense: false, - activeUsers: numericalEntitlement{ - entitlement{notEntitled}, - entitlementLimit{ - unlimited: true, - }, - }, - auditLogs: entitlement{entitled}, - }, - } - target := struct { - Auditor agplAudit.Auditor - }{} - err := uut.Get(&target) - require.NoError(t, err) - assert.NotNil(t, target.Auditor) - ea := audit.NewAuditor( - audit.DefaultFilter, - backends.NewPostgres(db, true), - backends.NewSlog(logger), - ) - assert.Equal(t, reflect.ValueOf(ea).Type(), reflect.ValueOf(target.Auditor).Type()) - }) - - t.Run("NotPointer", func(t *testing.T) { - t.Parallel() - uut := &featuresService{ - logger: logger, - database: db, - pubsub: pubsub, - keys: map[string]ed25519.PublicKey{keyID: pub}, - enablements: Enablements{AuditLogs: true}, - enabledImplementations: agplCoderd.FeatureInterfaces{ - Auditor: audit.NewAuditor(audit.DefaultFilter), - }, - entitlements: entitlements{ - hasLicense: false, - activeUsers: numericalEntitlement{ - entitlement{notEntitled}, - entitlementLimit{ - unlimited: true, - }, - }, - auditLogs: entitlement{notEntitled}, - }, - } - target := struct { - Auditor agplAudit.Auditor - }{} - err := uut.Get(target) - require.Error(t, err) - assert.Nil(t, target.Auditor) - }) - - t.Run("UnknownInterface", func(t *testing.T) { - t.Parallel() - uut := &featuresService{ - logger: logger, - database: db, - pubsub: pubsub, - keys: map[string]ed25519.PublicKey{keyID: pub}, - enablements: Enablements{AuditLogs: true}, - enabledImplementations: agplCoderd.FeatureInterfaces{ - Auditor: audit.NewAuditor(audit.DefaultFilter), - }, - entitlements: entitlements{ - hasLicense: false, - activeUsers: numericalEntitlement{ - entitlement{notEntitled}, - entitlementLimit{ - unlimited: true, - }, - }, - auditLogs: entitlement{notEntitled}, - }, - } - target := struct { - test testInterface - }{} - err := uut.Get(&target) - require.Error(t, err) - assert.Nil(t, target.test) - }) - - t.Run("PointerToNonStruct", func(t *testing.T) { - t.Parallel() - uut := &featuresService{ - logger: logger, - database: db, - pubsub: pubsub, - keys: map[string]ed25519.PublicKey{keyID: pub}, - enablements: Enablements{AuditLogs: true}, - enabledImplementations: agplCoderd.FeatureInterfaces{ - Auditor: audit.NewAuditor(audit.DefaultFilter), - }, - entitlements: entitlements{ - hasLicense: false, - activeUsers: numericalEntitlement{ - entitlement{notEntitled}, - entitlementLimit{ - unlimited: true, - }, - }, - auditLogs: entitlement{notEntitled}, - }, - } - var target agplAudit.Auditor - err := uut.Get(&target) - require.Error(t, err) - assert.Nil(t, target) - }) - - t.Run("StructWithNonInterfaces", func(t *testing.T) { - t.Parallel() - uut := &featuresService{ - logger: logger, - database: db, - pubsub: pubsub, - keys: map[string]ed25519.PublicKey{keyID: pub}, - enablements: Enablements{AuditLogs: true}, - enabledImplementations: agplCoderd.FeatureInterfaces{ - Auditor: audit.NewAuditor(audit.DefaultFilter), - }, - entitlements: entitlements{ - hasLicense: false, - activeUsers: numericalEntitlement{ - entitlement{notEntitled}, - entitlementLimit{ - unlimited: true, - }, - }, - auditLogs: entitlement{notEntitled}, - }, - } - target := struct { - N int64 - Auditor agplAudit.Auditor - }{} - err := uut.Get(&target) - require.Error(t, err) - assert.Nil(t, target.Auditor) - }) -} - -type testInterface interface { - Test() error -} diff --git a/enterprise/coderd/licenses.go b/enterprise/coderd/licenses.go index 9f75796da19bd..2a544cd8b05c0 100644 --- a/enterprise/coderd/licenses.go +++ b/enterprise/coderd/licenses.go @@ -30,7 +30,8 @@ const ( HeaderKeyID = "kid" AccountTypeSalesforce = "salesforce" VersionClaim = "version" - PubSubEventLicenses = "licenses" + + pubSubEventLicenses = "licenses" ) var ValidMethods = []string{"EdDSA"} @@ -41,7 +42,7 @@ var ValidMethods = []string{"EdDSA"} //go:embed keys/2022-08-12 var key20220812 []byte -var keys = map[string]ed25519.PublicKey{"2022-08-12": ed25519.PublicKey(key20220812)} +var Keys = map[string]ed25519.PublicKey{"2022-08-12": ed25519.PublicKey(key20220812)} type Features struct { UserLimit int64 `json:"user_limit"` @@ -68,96 +69,6 @@ var ( ErrMissingLicenseExpires = xerrors.New("license missing license_expires") ) -// parseLicense parses the license and returns the claims. If the license's signature is invalid or -// is not parsable, an error is returned. -func parseLicense(l string, keys map[string]ed25519.PublicKey) (jwt.MapClaims, error) { - tok, err := jwt.Parse( - l, - keyFunc(keys), - jwt.WithValidMethods(ValidMethods), - ) - if err != nil { - return nil, err - } - if claims, ok := tok.Claims.(jwt.MapClaims); ok && tok.Valid { - version, ok := claims[VersionClaim].(float64) - if !ok { - return nil, ErrInvalidVersion - } - if int64(version) != CurrentVersion { - return nil, ErrInvalidVersion - } - return claims, nil - } - return nil, xerrors.New("unable to parse Claims") -} - -// validateDBLicense validates a database.License record, and if valid, returns the claims. If -// unparsable or invalid, it returns an error -func validateDBLicense(l database.License, keys map[string]ed25519.PublicKey) (*Claims, error) { - tok, err := jwt.ParseWithClaims( - l.JWT, - &Claims{}, - keyFunc(keys), - jwt.WithValidMethods(ValidMethods), - ) - if err != nil { - return nil, err - } - if claims, ok := tok.Claims.(*Claims); ok && tok.Valid { - if claims.Version != uint64(CurrentVersion) { - return nil, ErrInvalidVersion - } - if claims.LicenseExpires == nil { - return nil, ErrMissingLicenseExpires - } - return claims, nil - } - return nil, xerrors.New("unable to parse Claims") -} - -func keyFunc(keys map[string]ed25519.PublicKey) func(*jwt.Token) (interface{}, error) { - return func(j *jwt.Token) (interface{}, error) { - keyID, ok := j.Header[HeaderKeyID].(string) - if !ok { - return nil, ErrMissingKeyID - } - k, ok := keys[keyID] - if !ok { - return nil, xerrors.Errorf("no key with ID %s", keyID) - } - return k, nil - } -} - -// licenseAPI handles enterprise licenses, and attaches to the main coderd.API via the -// LicenseHandler option, so that it serves all routes under /api/v2/licenses -type licenseAPI struct { - router chi.Router - logger slog.Logger - database database.Store - pubsub database.Pubsub - auth *coderd.HTTPAuthorizer -} - -func newLicenseAPI( - l slog.Logger, - db database.Store, - ps database.Pubsub, - auth *coderd.HTTPAuthorizer, -) *licenseAPI { - r := chi.NewRouter() - a := &licenseAPI{router: r, logger: l, database: db, pubsub: ps, auth: auth} - r.Post("/", a.postLicense) - r.Get("/", a.licenses) - r.Delete("/{id}", a.delete) - return a -} - -func (a *licenseAPI) handler() http.Handler { - return a.router -} - // postLicense adds a new Enterprise license to the cluster. We allow multiple different licenses // in the cluster at one time for several reasons: // @@ -167,8 +78,8 @@ func (a *licenseAPI) handler() http.Handler { // we generally don't want the old features to immediately break without warning. With a grace // period on the license, features will continue to work from the old license until its grace // period, then the users will get a warning allowing them to gracefully stop using the feature. -func (a *licenseAPI) postLicense(rw http.ResponseWriter, r *http.Request) { - if !a.auth.Authorize(r, rbac.ActionCreate, rbac.ResourceLicense) { +func (api *API) postLicense(rw http.ResponseWriter, r *http.Request) { + if !api.AGPL.Authorize(r, rbac.ActionCreate, rbac.ResourceLicense) { httpapi.Forbidden(rw) return } @@ -178,7 +89,7 @@ func (a *licenseAPI) postLicense(rw http.ResponseWriter, r *http.Request) { return } - claims, err := parseLicense(addLicense.License, keys) + claims, err := parseLicense(addLicense.License, api.Keys) if err != nil { httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{ Message: "Invalid license", @@ -196,7 +107,7 @@ func (a *licenseAPI) postLicense(rw http.ResponseWriter, r *http.Request) { } expTime := time.Unix(int64(exp), 0) - dl, err := a.database.InsertLicense(r.Context(), database.InsertLicenseParams{ + dl, err := api.Database.InsertLicense(r.Context(), database.InsertLicenseParams{ UploadedAt: database.Now(), JWT: addLicense.License, Exp: expTime, @@ -208,25 +119,17 @@ func (a *licenseAPI) postLicense(rw http.ResponseWriter, r *http.Request) { }) return } - err = a.pubsub.Publish(PubSubEventLicenses, []byte("add")) + err = api.Pubsub.Publish(pubSubEventLicenses, []byte("add")) if err != nil { - a.logger.Error(context.Background(), "failed to publish license add", slog.Error(err)) + api.Logger.Error(context.Background(), "failed to publish license add", slog.Error(err)) // don't fail the HTTP request, since we did write it successfully to the database } httpapi.Write(rw, http.StatusCreated, convertLicense(dl, claims)) } -func convertLicense(dl database.License, c jwt.MapClaims) codersdk.License { - return codersdk.License{ - ID: dl.ID, - UploadedAt: dl.UploadedAt, - Claims: c, - } -} - -func (a *licenseAPI) licenses(rw http.ResponseWriter, r *http.Request) { - licenses, err := a.database.GetLicenses(r.Context()) +func (api *API) licenses(rw http.ResponseWriter, r *http.Request) { + licenses, err := api.Database.GetLicenses(r.Context()) if xerrors.Is(err, sql.ErrNoRows) { httpapi.Write(rw, http.StatusOK, []codersdk.License{}) return @@ -239,7 +142,7 @@ func (a *licenseAPI) licenses(rw http.ResponseWriter, r *http.Request) { return } - licenses, err = coderd.AuthorizeFilter(a.auth, r, rbac.ActionRead, licenses) + licenses, err = coderd.AuthorizeFilter(api.AGPL.HTTPAuth, r, rbac.ActionRead, licenses) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching licenses.", @@ -258,6 +161,52 @@ func (a *licenseAPI) licenses(rw http.ResponseWriter, r *http.Request) { httpapi.Write(rw, http.StatusOK, sdkLicenses) } +func (api *API) deleteLicense(rw http.ResponseWriter, r *http.Request) { + if !api.AGPL.Authorize(r, rbac.ActionDelete, rbac.ResourceLicense) { + httpapi.Forbidden(rw) + return + } + + idStr := chi.URLParam(r, "id") + id, err := strconv.ParseInt(idStr, 10, 32) + if err != nil { + httpapi.Write(rw, http.StatusNotFound, codersdk.Response{ + Message: "License ID must be an integer", + }) + return + } + + _, err = api.Database.DeleteLicense(r.Context(), int32(id)) + if xerrors.Is(err, sql.ErrNoRows) { + httpapi.Write(rw, http.StatusNotFound, codersdk.Response{ + Message: "Unknown license ID", + }) + return + } + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error deleting license", + Detail: err.Error(), + }) + return + } + + err = api.Pubsub.Publish(pubSubEventLicenses, []byte("delete")) + if err != nil { + api.Logger.Error(context.Background(), "failed to publish license delete", slog.Error(err)) + // don't fail the HTTP request, since we did write it successfully to the database + } + rw.WriteHeader(http.StatusOK) +} + +func convertLicense(dl database.License, c jwt.MapClaims) codersdk.License { + return codersdk.License{ + ID: dl.ID, + UploadedAt: dl.UploadedAt, + Claims: c, + } +} + func convertLicenses(licenses []database.License) ([]codersdk.License, error) { var out []codersdk.License for _, l := range licenses { @@ -292,40 +241,64 @@ func decodeClaims(l database.License) (jwt.MapClaims, error) { return c, err } -func (a *licenseAPI) delete(rw http.ResponseWriter, r *http.Request) { - if !a.auth.Authorize(r, rbac.ActionDelete, rbac.ResourceLicense) { - httpapi.Forbidden(rw) - return - } - - idStr := chi.URLParam(r, "id") - id, err := strconv.ParseInt(idStr, 10, 32) +// parseLicense parses the license and returns the claims. If the license's signature is invalid or +// is not parsable, an error is returned. +func parseLicense(l string, keys map[string]ed25519.PublicKey) (jwt.MapClaims, error) { + tok, err := jwt.Parse( + l, + keyFunc(keys), + jwt.WithValidMethods(ValidMethods), + ) if err != nil { - httpapi.Write(rw, http.StatusNotFound, codersdk.Response{ - Message: "License ID must be an integer", - }) - return + return nil, err } - - _, err = a.database.DeleteLicense(r.Context(), int32(id)) - if xerrors.Is(err, sql.ErrNoRows) { - httpapi.Write(rw, http.StatusNotFound, codersdk.Response{ - Message: "Unknown license ID", - }) - return + if claims, ok := tok.Claims.(jwt.MapClaims); ok && tok.Valid { + version, ok := claims[VersionClaim].(float64) + if !ok { + return nil, ErrInvalidVersion + } + if int64(version) != CurrentVersion { + return nil, ErrInvalidVersion + } + return claims, nil } + return nil, xerrors.New("unable to parse Claims") +} + +// validateDBLicense validates a database.License record, and if valid, returns the claims. If +// unparsable or invalid, it returns an error +func validateDBLicense(l database.License, keys map[string]ed25519.PublicKey) (*Claims, error) { + tok, err := jwt.ParseWithClaims( + l.JWT, + &Claims{}, + keyFunc(keys), + jwt.WithValidMethods(ValidMethods), + ) if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Internal error deleting license", - Detail: err.Error(), - }) - return + return nil, err + } + if claims, ok := tok.Claims.(*Claims); ok && tok.Valid { + if claims.Version != uint64(CurrentVersion) { + return nil, ErrInvalidVersion + } + if claims.LicenseExpires == nil { + return nil, ErrMissingLicenseExpires + } + return claims, nil } + return nil, xerrors.New("unable to parse Claims") +} - err = a.pubsub.Publish(PubSubEventLicenses, []byte("delete")) - if err != nil { - a.logger.Error(context.Background(), "failed to publish license delete", slog.Error(err)) - // don't fail the HTTP request, since we did write it successfully to the database +func keyFunc(keys map[string]ed25519.PublicKey) func(*jwt.Token) (interface{}, error) { + return func(j *jwt.Token) (interface{}, error) { + keyID, ok := j.Header[HeaderKeyID].(string) + if !ok { + return nil, ErrMissingKeyID + } + k, ok := keys[keyID] + if !ok { + return nil, xerrors.Errorf("no key with ID %s", keyID) + } + return k, nil } - rw.WriteHeader(http.StatusOK) } diff --git a/enterprise/coderd/licenses_internal_test.go b/enterprise/coderd/licenses_internal_test.go deleted file mode 100644 index 5695ca0df5233..0000000000000 --- a/enterprise/coderd/licenses_internal_test.go +++ /dev/null @@ -1,316 +0,0 @@ -package coderd - -import ( - "context" - "crypto/ed25519" - "crypto/rand" - "encoding/json" - "net/http" - "testing" - "time" - - "golang.org/x/xerrors" - - "github.com/stretchr/testify/assert" - - "github.com/golang-jwt/jwt/v4" - "github.com/stretchr/testify/require" - - "github.com/coder/coder/coderd/coderdtest" - "github.com/coder/coder/codersdk" - "github.com/coder/coder/testutil" -) - -// these tests patch the map of license keys, so cannot be run in parallel -// nolint:paralleltest -func TestPostLicense(t *testing.T) { - pubKey, privKey, err := ed25519.GenerateKey(rand.Reader) - require.NoError(t, err) - keyID := "testing" - oldKeys := keys - defer func() { - t.Log("restoring keys") - keys = oldKeys - }() - keys = map[string]ed25519.PublicKey{keyID: pubKey} - - t.Run("POST", func(t *testing.T) { - client := coderdtest.New(t, &coderdtest.Options{APIBuilder: NewEnterprise}) - _ = coderdtest.CreateFirstUser(t, client) - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - - claims := &Claims{ - RegisteredClaims: jwt.RegisteredClaims{ - Issuer: "test@coder.test", - IssuedAt: jwt.NewNumericDate(time.Now()), - NotBefore: jwt.NewNumericDate(time.Now()), - ExpiresAt: jwt.NewNumericDate(time.Now().Add(2 * time.Hour)), - }, - LicenseExpires: jwt.NewNumericDate(time.Now().Add(time.Hour)), - AccountType: AccountTypeSalesforce, - AccountID: "testing", - Version: CurrentVersion, - Features: Features{ - UserLimit: 0, - AuditLog: 1, - }, - } - lic, err := makeLicense(claims, privKey, keyID) - require.NoError(t, err) - - respLic, err := client.AddLicense(ctx, codersdk.AddLicenseRequest{ - License: lic, - }) - require.NoError(t, err) - assert.GreaterOrEqual(t, respLic.ID, int32(0)) - // just a couple spot checks for sanity - assert.Equal(t, claims.AccountID, respLic.Claims["account_id"]) - features, ok := respLic.Claims["features"].(map[string]interface{}) - require.True(t, ok) - assert.Equal(t, json.Number("1"), features[codersdk.FeatureAuditLog]) - }) - - t.Run("POST_unathorized", func(t *testing.T) { - client := coderdtest.New(t, &coderdtest.Options{APIBuilder: NewEnterprise}) - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - defer cancel() - - claims := &Claims{ - RegisteredClaims: jwt.RegisteredClaims{ - Issuer: "test@coder.test", - IssuedAt: jwt.NewNumericDate(time.Now()), - NotBefore: jwt.NewNumericDate(time.Now()), - ExpiresAt: jwt.NewNumericDate(time.Now().Add(2 * time.Hour)), - }, - LicenseExpires: jwt.NewNumericDate(time.Now().Add(time.Hour)), - AccountType: AccountTypeSalesforce, - AccountID: "testing", - Version: CurrentVersion, - Features: Features{ - UserLimit: 0, - AuditLog: 1, - }, - } - lic, err := makeLicense(claims, privKey, keyID) - require.NoError(t, err) - - _, err = client.AddLicense(ctx, codersdk.AddLicenseRequest{ - License: lic, - }) - errResp := &codersdk.Error{} - if xerrors.As(err, &errResp) { - assert.Equal(t, 401, errResp.StatusCode()) - } else { - t.Error("expected to get error status 401") - } - }) - - t.Run("POST_corrupted", func(t *testing.T) { - client := coderdtest.New(t, &coderdtest.Options{APIBuilder: NewEnterprise}) - _ = coderdtest.CreateFirstUser(t, client) - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - defer cancel() - - claims := &Claims{ - RegisteredClaims: jwt.RegisteredClaims{ - Issuer: "test@coder.test", - IssuedAt: jwt.NewNumericDate(time.Now()), - NotBefore: jwt.NewNumericDate(time.Now()), - ExpiresAt: jwt.NewNumericDate(time.Now().Add(2 * time.Hour)), - }, - LicenseExpires: jwt.NewNumericDate(time.Now().Add(time.Hour)), - AccountType: AccountTypeSalesforce, - AccountID: "testing", - Version: CurrentVersion, - Features: Features{ - UserLimit: 0, - AuditLog: 1, - }, - } - lic, err := makeLicense(claims, privKey, keyID) - require.NoError(t, err) - - _, err = client.AddLicense(ctx, codersdk.AddLicenseRequest{ - License: "h" + lic, - }) - errResp := &codersdk.Error{} - if xerrors.As(err, &errResp) { - assert.Equal(t, 400, errResp.StatusCode()) - } else { - t.Error("expected to get error status 400") - } - }) -} - -// these tests patch the map of license keys, so cannot be run in parallel -// nolint:paralleltest -func TestGetLicense(t *testing.T) { - pubKey, privKey, err := ed25519.GenerateKey(rand.Reader) - require.NoError(t, err) - keyID := "testing" - oldKeys := keys - defer func() { - t.Log("restoring keys") - keys = oldKeys - }() - keys = map[string]ed25519.PublicKey{keyID: pubKey} - - t.Run("GET", func(t *testing.T) { - client := coderdtest.New(t, &coderdtest.Options{APIBuilder: NewEnterprise}) - _ = coderdtest.CreateFirstUser(t, client) - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - - claims := &Claims{ - RegisteredClaims: jwt.RegisteredClaims{ - Issuer: "test@coder.test", - IssuedAt: jwt.NewNumericDate(time.Now()), - NotBefore: jwt.NewNumericDate(time.Now()), - ExpiresAt: jwt.NewNumericDate(time.Now().Add(2 * time.Hour)), - }, - LicenseExpires: jwt.NewNumericDate(time.Now().Add(time.Hour)), - AccountType: AccountTypeSalesforce, - AccountID: "testing", - Version: CurrentVersion, - Features: Features{ - UserLimit: 0, - AuditLog: 1, - }, - } - lic, err := makeLicense(claims, privKey, keyID) - require.NoError(t, err) - _, err = client.AddLicense(ctx, codersdk.AddLicenseRequest{ - License: lic, - }) - require.NoError(t, err) - - // 2nd license - claims.AccountID = "testing2" - claims.Features.UserLimit = 200 - lic2, err := makeLicense(claims, privKey, keyID) - require.NoError(t, err) - _, err = client.AddLicense(ctx, codersdk.AddLicenseRequest{ - License: lic2, - }) - require.NoError(t, err) - - licenses, err := client.Licenses(ctx) - require.NoError(t, err) - require.Len(t, licenses, 2) - assert.Equal(t, int32(1), licenses[0].ID) - assert.Equal(t, "testing", licenses[0].Claims["account_id"]) - assert.Equal(t, map[string]interface{}{ - codersdk.FeatureUserLimit: json.Number("0"), - codersdk.FeatureAuditLog: json.Number("1"), - }, licenses[0].Claims["features"]) - assert.Equal(t, int32(2), licenses[1].ID) - assert.Equal(t, "testing2", licenses[1].Claims["account_id"]) - assert.Equal(t, map[string]interface{}{ - codersdk.FeatureUserLimit: json.Number("200"), - codersdk.FeatureAuditLog: json.Number("1"), - }, licenses[1].Claims["features"]) - }) -} - -// these tests patch the map of license keys, so cannot be run in parallel -// nolint:paralleltest -func TestDeleteLicense(t *testing.T) { - pubKey, privKey, err := ed25519.GenerateKey(rand.Reader) - require.NoError(t, err) - keyID := "testing" - oldKeys := keys - defer func() { - t.Log("restoring keys") - keys = oldKeys - }() - keys = map[string]ed25519.PublicKey{keyID: pubKey} - - t.Run("DELETE_empty", func(t *testing.T) { - client := coderdtest.New(t, &coderdtest.Options{APIBuilder: NewEnterprise}) - _ = coderdtest.CreateFirstUser(t, client) - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - - err := client.DeleteLicense(ctx, 1) - errResp := &codersdk.Error{} - if xerrors.As(err, &errResp) { - assert.Equal(t, 404, errResp.StatusCode()) - } else { - t.Error("expected to get error status 404") - } - }) - - t.Run("DELETE_bad_id", func(t *testing.T) { - client := coderdtest.New(t, &coderdtest.Options{APIBuilder: NewEnterprise}) - _ = coderdtest.CreateFirstUser(t, client) - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - - resp, err := client.Request(ctx, http.MethodDelete, "/api/v2/licenses/drivers", nil) - require.NoError(t, err) - assert.Equal(t, http.StatusNotFound, resp.StatusCode) - require.NoError(t, resp.Body.Close()) - }) - - t.Run("DELETE", func(t *testing.T) { - client := coderdtest.New(t, &coderdtest.Options{APIBuilder: NewEnterprise}) - _ = coderdtest.CreateFirstUser(t, client) - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - - claims := &Claims{ - RegisteredClaims: jwt.RegisteredClaims{ - Issuer: "test@coder.test", - IssuedAt: jwt.NewNumericDate(time.Now()), - NotBefore: jwt.NewNumericDate(time.Now()), - ExpiresAt: jwt.NewNumericDate(time.Now().Add(2 * time.Hour)), - }, - LicenseExpires: jwt.NewNumericDate(time.Now().Add(time.Hour)), - AccountType: AccountTypeSalesforce, - AccountID: "testing", - Version: CurrentVersion, - Features: Features{ - UserLimit: 0, - AuditLog: 1, - }, - } - lic, err := makeLicense(claims, privKey, keyID) - require.NoError(t, err) - _, err = client.AddLicense(ctx, codersdk.AddLicenseRequest{ - License: lic, - }) - require.NoError(t, err) - - // 2nd license - claims.AccountID = "testing2" - claims.Features.UserLimit = 200 - lic2, err := makeLicense(claims, privKey, keyID) - require.NoError(t, err) - _, err = client.AddLicense(ctx, codersdk.AddLicenseRequest{ - License: lic2, - }) - require.NoError(t, err) - - licenses, err := client.Licenses(ctx) - require.NoError(t, err) - assert.Len(t, licenses, 2) - for _, l := range licenses { - err = client.DeleteLicense(ctx, l.ID) - require.NoError(t, err) - } - licenses, err = client.Licenses(ctx) - require.NoError(t, err) - assert.Len(t, licenses, 0) - }) -} - -func makeLicense(c *Claims, privateKey ed25519.PrivateKey, keyID string) (string, error) { - tok := jwt.NewWithClaims(jwt.SigningMethodEdDSA, c) - tok.Header[HeaderKeyID] = keyID - signedTok, err := tok.SignedString(privateKey) - if err != nil { - return "", xerrors.Errorf("sign license: %w", err) - } - return signedTok, nil -} diff --git a/enterprise/coderd/licenses_test.go b/enterprise/coderd/licenses_test.go new file mode 100644 index 0000000000000..b518a684d6c05 --- /dev/null +++ b/enterprise/coderd/licenses_test.go @@ -0,0 +1,168 @@ +package coderd_test + +import ( + "context" + "encoding/json" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + + "github.com/coder/coder/coderd/coderdtest" + "github.com/coder/coder/codersdk" + "github.com/coder/coder/enterprise/coderd" + "github.com/coder/coder/enterprise/coderd/coderdenttest" + "github.com/coder/coder/testutil" +) + +func TestPostLicense(t *testing.T) { + t.Parallel() + + t.Run("POST", func(t *testing.T) { + t.Parallel() + client := coderdenttest.New(t, nil) + _ = coderdtest.CreateFirstUser(t, client) + respLic := coderdenttest.AddLicense(t, client, coderdenttest.AddLicenseOptions{ + AccountType: coderd.AccountTypeSalesforce, + AccountID: "testing", + AuditLog: true, + }) + assert.GreaterOrEqual(t, respLic.ID, int32(0)) + // just a couple spot checks for sanity + assert.Equal(t, "testing", respLic.Claims["account_id"]) + features, ok := respLic.Claims["features"].(map[string]interface{}) + require.True(t, ok) + assert.Equal(t, json.Number("1"), features[codersdk.FeatureAuditLog]) + }) + + t.Run("POST_unauthorized", func(t *testing.T) { + t.Parallel() + client := coderdenttest.New(t, nil) + _, err := client.AddLicense(context.Background(), codersdk.AddLicenseRequest{ + License: "content", + }) + errResp := &codersdk.Error{} + if xerrors.As(err, &errResp) { + assert.Equal(t, 401, errResp.StatusCode()) + } else { + t.Error("expected to get error status 401") + } + }) + + t.Run("POST_corrupted", func(t *testing.T) { + t.Parallel() + client := coderdenttest.New(t, nil) + _ = coderdtest.CreateFirstUser(t, client) + coderdenttest.AddLicense(t, client, coderdenttest.AddLicenseOptions{}) + _, err := client.AddLicense(context.Background(), codersdk.AddLicenseRequest{ + License: "invalid", + }) + errResp := &codersdk.Error{} + if xerrors.As(err, &errResp) { + assert.Equal(t, 400, errResp.StatusCode()) + } else { + t.Error("expected to get error status 400") + } + }) +} + +func TestGetLicense(t *testing.T) { + t.Parallel() + t.Run("GET", func(t *testing.T) { + t.Parallel() + client := coderdenttest.New(t, nil) + _ = coderdtest.CreateFirstUser(t, client) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + coderdenttest.AddLicense(t, client, coderdenttest.AddLicenseOptions{ + AccountID: "testing", + AuditLog: true, + }) + + coderdenttest.AddLicense(t, client, coderdenttest.AddLicenseOptions{ + AccountID: "testing2", + AuditLog: true, + UserLimit: 200, + }) + + licenses, err := client.Licenses(ctx) + require.NoError(t, err) + require.Len(t, licenses, 2) + assert.Equal(t, int32(1), licenses[0].ID) + assert.Equal(t, "testing", licenses[0].Claims["account_id"]) + assert.Equal(t, map[string]interface{}{ + codersdk.FeatureUserLimit: json.Number("0"), + codersdk.FeatureAuditLog: json.Number("1"), + }, licenses[0].Claims["features"]) + assert.Equal(t, int32(2), licenses[1].ID) + assert.Equal(t, "testing2", licenses[1].Claims["account_id"]) + assert.Equal(t, map[string]interface{}{ + codersdk.FeatureUserLimit: json.Number("200"), + codersdk.FeatureAuditLog: json.Number("1"), + }, licenses[1].Claims["features"]) + }) +} + +func TestDeleteLicense(t *testing.T) { + t.Parallel() + t.Run("DELETE_empty", func(t *testing.T) { + t.Parallel() + client := coderdenttest.New(t, nil) + _ = coderdtest.CreateFirstUser(t, client) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + err := client.DeleteLicense(ctx, 1) + errResp := &codersdk.Error{} + if xerrors.As(err, &errResp) { + assert.Equal(t, 404, errResp.StatusCode()) + } else { + t.Error("expected to get error status 404") + } + }) + + t.Run("DELETE_bad_id", func(t *testing.T) { + t.Parallel() + client := coderdenttest.New(t, nil) + _ = coderdtest.CreateFirstUser(t, client) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + resp, err := client.Request(ctx, http.MethodDelete, "/api/v2/licenses/drivers", nil) + require.NoError(t, err) + assert.Equal(t, http.StatusNotFound, resp.StatusCode) + require.NoError(t, resp.Body.Close()) + }) + + t.Run("DELETE", func(t *testing.T) { + t.Parallel() + client := coderdenttest.New(t, nil) + _ = coderdtest.CreateFirstUser(t, client) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + coderdenttest.AddLicense(t, client, coderdenttest.AddLicenseOptions{ + AccountID: "testing", + AuditLog: true, + }) + coderdenttest.AddLicense(t, client, coderdenttest.AddLicenseOptions{ + AccountID: "testing2", + AuditLog: true, + UserLimit: 200, + }) + + licenses, err := client.Licenses(ctx) + require.NoError(t, err) + assert.Len(t, licenses, 2) + for _, l := range licenses { + err = client.DeleteLicense(ctx, l.ID) + require.NoError(t, err) + } + licenses, err = client.Licenses(ctx) + require.NoError(t, err) + assert.Len(t, licenses, 0) + }) +} From aad7be440c5e9dad2733a09b235260967cf07e71 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Tue, 13 Sep 2022 17:37:18 +0000 Subject: [PATCH 02/19] Fix Garrett's comments --- coderd/coderd.go | 4 ++-- enterprise/coderd/coderd.go | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/coderd/coderd.go b/coderd/coderd.go index fa76c10950f9a..d399fb331ae48 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -466,7 +466,7 @@ func New(options *Options) *API { }) r.Route("/entitlements", func(r chi.Router) { r.Use(apiKeyMiddleware) - r.Get("/", entitlements) + r.Get("/", nopEntitlements) }) r.HandleFunc("/licenses", unsupported) }) @@ -523,7 +523,7 @@ func compressHandler(h http.Handler) http.Handler { return cmp.Handler(h) } -func entitlements(rw http.ResponseWriter, _ *http.Request) { +func nopEntitlements(rw http.ResponseWriter, _ *http.Request) { feats := make(map[string]codersdk.Feature) for _, f := range codersdk.FeatureNames { feats[f] = codersdk.Feature{ diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index 0058e830271ea..a4e1bce59ebf8 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -152,6 +152,8 @@ func (api *API) updateEntitlements(ctx context.Context) error { } if auditLogs != api.auditLogs { auditor := agplaudit.NewNop() + // A flag could be added to the options that would allow disabling + // enhanced audit logging here! if api.auditLogs == codersdk.EntitlementEntitled { auditor = audit.NewAuditor( audit.DefaultFilter, From 27f53aa77def69fb15c043782784958b58ca4969 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Thu, 15 Sep 2022 03:24:16 +0000 Subject: [PATCH 03/19] Add pointer.Handle to atomically obtain references This uses a context to ensure the same value persists through multiple executions to `Load()`. --- coderd/coderd.go | 28 +++++++++---------------- coderd/pointer/pointer.go | 37 ++++++++++++++++++++++++++++++++++ coderd/pointer/pointer_test.go | 34 +++++++++++++++++++++++++++++++ coderd/templates.go | 11 ++++++---- coderd/templateversions.go | 6 ++++-- coderd/users.go | 18 +++++++++++------ coderd/workspaces.go | 12 +++++++---- enterprise/coderd/coderd.go | 2 +- 8 files changed, 112 insertions(+), 36 deletions(-) create mode 100644 coderd/pointer/pointer.go create mode 100644 coderd/pointer/pointer_test.go diff --git a/coderd/coderd.go b/coderd/coderd.go index 11e4a4cd3a8e1..ea31815f10c09 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -7,7 +7,6 @@ import ( "net/url" "path/filepath" "sync" - "sync/atomic" "time" "github.com/andybalholm/brotli" @@ -33,6 +32,7 @@ import ( "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/coderd/httpmw" "github.com/coder/coder/coderd/metricscache" + "github.com/coder/coder/coderd/pointer" "github.com/coder/coder/coderd/rbac" "github.com/coder/coder/coderd/telemetry" "github.com/coder/coder/coderd/tracing" @@ -148,9 +148,8 @@ func New(options *Options) *API { Logger: options.Logger, }, metricsCache: metricsCache, - Auditor: atomic.Pointer[audit.Auditor]{}, + Auditor: pointer.New(options.Auditor), } - api.Auditor.Store(&options.Auditor) if options.TailscaleEnable { api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgentTailnet, 0) @@ -495,7 +494,6 @@ func New(options *Options) *API { r.Use(apiKeyMiddleware) r.Get("/", nopEntitlements) }) - r.HandleFunc("/licenses", unsupported) }) r.NotFound(compressHandler(http.HandlerFunc(api.siteHandler.ServeHTTP)).ServeHTTP) @@ -504,19 +502,19 @@ func New(options *Options) *API { type API struct { *Options + Auditor *pointer.Handle[audit.Auditor] + HTTPAuth *HTTPAuthorizer - derpServer *derp.Server + // APIHandler serves "/api/v2" and all children routes. + APIHandler chi.Router + RootHandler chi.Router - Auditor atomic.Pointer[audit.Auditor] - RootHandler chi.Router - APIHandler chi.Router + derpServer *derp.Server + metricsCache *metricscache.Cache siteHandler http.Handler websocketWaitMutex sync.Mutex websocketWaitGroup sync.WaitGroup workspaceAgentCache *wsconncache.Cache - HTTPAuth *HTTPAuthorizer - - metricsCache *metricscache.Cache } // Close waits for all WebSocket connections to drain before returning. @@ -564,11 +562,3 @@ func nopEntitlements(rw http.ResponseWriter, _ *http.Request) { HasLicense: false, }) } - -func unsupported(rw http.ResponseWriter, _ *http.Request) { - httpapi.Write(rw, http.StatusNotFound, codersdk.Response{ - Message: "Unsupported", - Detail: "These endpoints are not supported in AGPL-licensed Coder", - Validations: nil, - }) -} diff --git a/coderd/pointer/pointer.go b/coderd/pointer/pointer.go new file mode 100644 index 0000000000000..ce0c0abc12973 --- /dev/null +++ b/coderd/pointer/pointer.go @@ -0,0 +1,37 @@ +package pointer + +import ( + "context" + + "go.uber.org/atomic" +) + +func New[T any](value T) *Handle[T] { + h := &Handle[T]{ + key: struct{}{}, + ptr: atomic.Pointer[T]{}, + } + h.Store(value) + return h +} + +// Handle loads the stored value into a context, and returns +// a context with the attached value. It's intention is to +// hold a single handle for the lifecycle of a request. +type Handle[T any] struct { + key struct{} + ptr atomic.Pointer[T] +} + +func (p *Handle[T]) Load(ctx context.Context) (context.Context, T) { + value, ok := ctx.Value(&p.key).(T) + if !ok { + ctx = context.WithValue(ctx, &p.key, *p.ptr.Load()) + return p.Load(ctx) + } + return ctx, value +} + +func (p *Handle[T]) Store(t T) { + p.ptr.Store(&t) +} diff --git a/coderd/pointer/pointer_test.go b/coderd/pointer/pointer_test.go new file mode 100644 index 0000000000000..f41ef8cd36bd5 --- /dev/null +++ b/coderd/pointer/pointer_test.go @@ -0,0 +1,34 @@ +package pointer_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/coderd/pointer" +) + +func TestHandle(t *testing.T) { + t.Parallel() + t.Run("Single", func(t *testing.T) { + t.Parallel() + ptr := pointer.New("hello") + ctx := context.Background() + ctx, value := ptr.Load(ctx) + require.Equal(t, "hello", value) + ptr.Store("world") + _, value = ptr.Load(ctx) + require.Equal(t, "hello", value) + }) + t.Run("Multiple", func(t *testing.T) { + t.Parallel() + ptr1 := pointer.New("1") + ptr2 := pointer.New("2") + ctx := context.Background() + ctx, v1 := ptr1.Load(ctx) + require.Equal(t, "1", v1) + _, v2 := ptr2.Load(ctx) + require.Equal(t, "2", v2) + }) +} diff --git a/coderd/templates.go b/coderd/templates.go index 06aea92fd7d5b..50dd788200930 100644 --- a/coderd/templates.go +++ b/coderd/templates.go @@ -85,8 +85,9 @@ func (api *API) template(rw http.ResponseWriter, r *http.Request) { func (api *API) deleteTemplate(rw http.ResponseWriter, r *http.Request) { var ( template = httpmw.TemplateParam(r) + _, auditor = api.Auditor.Load(r.Context()) aReq, commitAudit = audit.InitRequest[database.Template](rw, &audit.RequestParams{ - Audit: *api.Auditor.Load(), + Audit: auditor, Log: api.Logger, Request: r, Action: database.AuditActionDelete, @@ -139,14 +140,15 @@ func (api *API) postTemplateByOrganization(rw http.ResponseWriter, r *http.Reque createTemplate codersdk.CreateTemplateRequest organization = httpmw.OrganizationParam(r) apiKey = httpmw.APIKey(r) + _, auditor = api.Auditor.Load(r.Context()) templateAudit, commitTemplateAudit = audit.InitRequest[database.Template](rw, &audit.RequestParams{ - Audit: *api.Auditor.Load(), + Audit: auditor, Log: api.Logger, Request: r, Action: database.AuditActionCreate, }) templateVersionAudit, commitTemplateVersionAudit = audit.InitRequest[database.TemplateVersion](rw, &audit.RequestParams{ - Audit: *api.Auditor.Load(), + Audit: auditor, Log: api.Logger, Request: r, Action: database.AuditActionWrite, @@ -435,8 +437,9 @@ func (api *API) templateByOrganizationAndName(rw http.ResponseWriter, r *http.Re func (api *API) patchTemplateMeta(rw http.ResponseWriter, r *http.Request) { var ( template = httpmw.TemplateParam(r) + _, auditor = api.Auditor.Load(r.Context()) aReq, commitAudit = audit.InitRequest[database.Template](rw, &audit.RequestParams{ - Audit: *api.Auditor.Load(), + Audit: auditor, Log: api.Logger, Request: r, Action: database.AuditActionWrite, diff --git a/coderd/templateversions.go b/coderd/templateversions.go index 843a2b8d015f0..37517073aa254 100644 --- a/coderd/templateversions.go +++ b/coderd/templateversions.go @@ -559,8 +559,9 @@ func (api *API) templateVersionByName(rw http.ResponseWriter, r *http.Request) { func (api *API) patchActiveTemplateVersion(rw http.ResponseWriter, r *http.Request) { var ( template = httpmw.TemplateParam(r) + _, auditor = api.Auditor.Load(r.Context()) aReq, commitAudit = audit.InitRequest[database.Template](rw, &audit.RequestParams{ - Audit: *api.Auditor.Load(), + Audit: auditor, Log: api.Logger, Request: r, Action: database.AuditActionWrite, @@ -631,8 +632,9 @@ func (api *API) postTemplateVersionsByOrganization(rw http.ResponseWriter, r *ht var ( apiKey = httpmw.APIKey(r) organization = httpmw.OrganizationParam(r) + _, auditor = api.Auditor.Load(r.Context()) aReq, commitAudit = audit.InitRequest[database.TemplateVersion](rw, &audit.RequestParams{ - Audit: *api.Auditor.Load(), + Audit: auditor, Log: api.Logger, Request: r, Action: database.AuditActionCreate, diff --git a/coderd/users.go b/coderd/users.go index 67194ecf9a9ca..683710791b2c8 100644 --- a/coderd/users.go +++ b/coderd/users.go @@ -255,8 +255,9 @@ func (api *API) users(rw http.ResponseWriter, r *http.Request) { // Creates a new user. func (api *API) postUser(rw http.ResponseWriter, r *http.Request) { + _, auditor := api.Auditor.Load(r.Context()) aReq, commitAudit := audit.InitRequest[database.User](rw, &audit.RequestParams{ - Audit: *api.Auditor.Load(), + Audit: auditor, Log: api.Logger, Request: r, Action: database.AuditActionCreate, @@ -339,9 +340,10 @@ func (api *API) postUser(rw http.ResponseWriter, r *http.Request) { } func (api *API) deleteUser(rw http.ResponseWriter, r *http.Request) { + _, auditor := api.Auditor.Load(r.Context()) user := httpmw.UserParam(r) aReq, commitAudit := audit.InitRequest[database.User](rw, &audit.RequestParams{ - Audit: *api.Auditor.Load(), + Audit: auditor, Log: api.Logger, Request: r, Action: database.AuditActionDelete, @@ -414,8 +416,9 @@ func (api *API) userByName(rw http.ResponseWriter, r *http.Request) { func (api *API) putUserProfile(rw http.ResponseWriter, r *http.Request) { var ( user = httpmw.UserParam(r) + _, auditor = api.Auditor.Load(r.Context()) aReq, commitAudit = audit.InitRequest[database.User](rw, &audit.RequestParams{ - Audit: *api.Auditor.Load(), + Audit: auditor, Log: api.Logger, Request: r, Action: database.AuditActionWrite, @@ -494,8 +497,9 @@ func (api *API) putUserStatus(status database.UserStatus) func(rw http.ResponseW var ( user = httpmw.UserParam(r) apiKey = httpmw.APIKey(r) + _, auditor = api.Auditor.Load(r.Context()) aReq, commitAudit = audit.InitRequest[database.User](rw, &audit.RequestParams{ - Audit: *api.Auditor.Load(), + Audit: auditor, Log: api.Logger, Request: r, Action: database.AuditActionWrite, @@ -560,8 +564,9 @@ func (api *API) putUserPassword(rw http.ResponseWriter, r *http.Request) { var ( user = httpmw.UserParam(r) params codersdk.UpdateUserPasswordRequest + _, auditor = api.Auditor.Load(r.Context()) aReq, commitAudit = audit.InitRequest[database.User](rw, &audit.RequestParams{ - Audit: *api.Auditor.Load(), + Audit: auditor, Log: api.Logger, Request: r, Action: database.AuditActionWrite, @@ -698,8 +703,9 @@ func (api *API) putUserRoles(rw http.ResponseWriter, r *http.Request) { user = httpmw.UserParam(r) actorRoles = httpmw.AuthorizationUserRoles(r) apiKey = httpmw.APIKey(r) + _, auditor = api.Auditor.Load(r.Context()) aReq, commitAudit = audit.InitRequest[database.User](rw, &audit.RequestParams{ - Audit: *api.Auditor.Load(), + Audit: auditor, Log: api.Logger, Request: r, Action: database.AuditActionWrite, diff --git a/coderd/workspaces.go b/coderd/workspaces.go index 26d4c7b280204..1a6dbd1ec0cc0 100644 --- a/coderd/workspaces.go +++ b/coderd/workspaces.go @@ -254,8 +254,9 @@ func (api *API) postWorkspacesByOrganization(rw http.ResponseWriter, r *http.Req var ( organization = httpmw.OrganizationParam(r) apiKey = httpmw.APIKey(r) + _, auditor = api.Auditor.Load(r.Context()) aReq, commitAudit = audit.InitRequest[database.Workspace](rw, &audit.RequestParams{ - Audit: *api.Auditor.Load(), + Audit: auditor, Log: api.Logger, Request: r, Action: database.AuditActionCreate, @@ -495,8 +496,9 @@ func (api *API) postWorkspacesByOrganization(rw http.ResponseWriter, r *http.Req func (api *API) patchWorkspace(rw http.ResponseWriter, r *http.Request) { var ( workspace = httpmw.WorkspaceParam(r) + _, auditor = api.Auditor.Load(r.Context()) aReq, commitAudit = audit.InitRequest[database.Workspace](rw, &audit.RequestParams{ - Audit: *api.Auditor.Load(), + Audit: auditor, Log: api.Logger, Request: r, Action: database.AuditActionWrite, @@ -571,8 +573,9 @@ func (api *API) patchWorkspace(rw http.ResponseWriter, r *http.Request) { func (api *API) putWorkspaceAutostart(rw http.ResponseWriter, r *http.Request) { var ( workspace = httpmw.WorkspaceParam(r) + _, auditor = api.Auditor.Load(r.Context()) aReq, commitAudit = audit.InitRequest[database.Workspace](rw, &audit.RequestParams{ - Audit: *api.Auditor.Load(), + Audit: auditor, Log: api.Logger, Request: r, Action: database.AuditActionWrite, @@ -631,8 +634,9 @@ func (api *API) putWorkspaceAutostart(rw http.ResponseWriter, r *http.Request) { func (api *API) putWorkspaceTTL(rw http.ResponseWriter, r *http.Request) { var ( workspace = httpmw.WorkspaceParam(r) + _, auditor = api.Auditor.Load(r.Context()) aReq, commitAudit = audit.InitRequest[database.Workspace](rw, &audit.RequestParams{ - Audit: *api.Auditor.Load(), + Audit: auditor, Log: api.Logger, Request: r, Action: database.AuditActionWrite, diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index a4e1bce59ebf8..ff1c29fbf7675 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -161,7 +161,7 @@ func (api *API) updateEntitlements(ctx context.Context) error { backends.NewSlog(api.Logger), ) } - api.AGPL.Auditor.Store(&auditor) + api.AGPL.Auditor.Store(auditor) } return nil } From ed5b96df8b7b49fbbcdddcde32de35b6afc64a05 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Thu, 15 Sep 2022 04:03:51 +0000 Subject: [PATCH 04/19] Remove entitlements API from AGPL coderd --- cli/root.go | 1 - coderd/coderd.go | 24 +++-------------- coderd/pointer/pointer.go | 1 + {cli => enterprise/cli}/features.go | 9 ++++++- {cli => enterprise/cli}/features_test.go | 10 ++++--- enterprise/cli/root.go | 1 + enterprise/coderd/coderd.go | 4 +++ site/src/api/api.ts | 27 +++++++++++++++++-- .../entitlements/entitlementsXService.ts | 2 +- 9 files changed, 49 insertions(+), 30 deletions(-) rename {cli => enterprise/cli}/features.go (88%) rename {cli => enterprise/cli}/features_test.go (80%) diff --git a/cli/root.go b/cli/root.go index 430201d049506..adb4713ccb2b4 100644 --- a/cli/root.go +++ b/cli/root.go @@ -91,7 +91,6 @@ func Core() []*cobra.Command { users(), versionCmd(), workspaceAgent(), - features(), } } diff --git a/coderd/coderd.go b/coderd/coderd.go index ea31815f10c09..666a0501f4de3 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -490,10 +490,6 @@ func New(options *Options) *API { r.Get("/resources", api.workspaceBuildResources) r.Get("/state", api.workspaceBuildState) }) - r.Route("/entitlements", func(r chi.Router) { - r.Use(apiKeyMiddleware) - r.Get("/", nopEntitlements) - }) }) r.NotFound(compressHandler(http.HandlerFunc(api.siteHandler.ServeHTTP)).ServeHTTP) @@ -505,8 +501,9 @@ type API struct { Auditor *pointer.Handle[audit.Auditor] HTTPAuth *HTTPAuthorizer - // APIHandler serves "/api/v2" and all children routes. - APIHandler chi.Router + // APIHandler serves "/api/v2" + APIHandler chi.Router + // RootHandler serves "/" RootHandler chi.Router derpServer *derp.Server @@ -547,18 +544,3 @@ func compressHandler(h http.Handler) http.Handler { return cmp.Handler(h) } - -func nopEntitlements(rw http.ResponseWriter, _ *http.Request) { - feats := make(map[string]codersdk.Feature) - for _, f := range codersdk.FeatureNames { - feats[f] = codersdk.Feature{ - Entitlement: codersdk.EntitlementNotEntitled, - Enabled: false, - } - } - httpapi.Write(rw, http.StatusOK, codersdk.Entitlements{ - Features: feats, - Warnings: []string{}, - HasLicense: false, - }) -} diff --git a/coderd/pointer/pointer.go b/coderd/pointer/pointer.go index ce0c0abc12973..733b022d2bc35 100644 --- a/coderd/pointer/pointer.go +++ b/coderd/pointer/pointer.go @@ -6,6 +6,7 @@ import ( "go.uber.org/atomic" ) +// New constructs a Handle with an initialized value. func New[T any](value T) *Handle[T] { h := &Handle[T]{ key: struct{}{}, diff --git a/cli/features.go b/enterprise/cli/features.go similarity index 88% rename from cli/features.go rename to enterprise/cli/features.go index f430534330816..25db628dea87d 100644 --- a/cli/features.go +++ b/enterprise/cli/features.go @@ -3,12 +3,15 @@ package cli import ( "bytes" "encoding/json" + "errors" "fmt" + "net/http" "strings" "github.com/spf13/cobra" "golang.org/x/xerrors" + agpl "github.com/coder/coder/cli" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/codersdk" ) @@ -36,11 +39,15 @@ func featuresList() *cobra.Command { Use: "list", Aliases: []string{"ls"}, RunE: func(cmd *cobra.Command, args []string) error { - client, err := CreateClient(cmd) + client, err := agpl.CreateClient(cmd) if err != nil { return err } entitlements, err := client.Entitlements(cmd.Context()) + var apiError *codersdk.Error + if errors.As(err, &apiError) && apiError.StatusCode() == http.StatusNotFound { + return xerrors.New("You are on the AGPL licensed version of Coder that does not have Enterprise functionality!") + } if err != nil { return err } diff --git a/cli/features_test.go b/enterprise/cli/features_test.go similarity index 80% rename from cli/features_test.go rename to enterprise/cli/features_test.go index 6c39fec81011a..7f7d13a5180d6 100644 --- a/cli/features_test.go +++ b/enterprise/cli/features_test.go @@ -11,6 +11,8 @@ import ( "github.com/coder/coder/cli/clitest" "github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/codersdk" + "github.com/coder/coder/enterprise/cli" + "github.com/coder/coder/enterprise/coderd/coderdenttest" "github.com/coder/coder/pty/ptytest" ) @@ -18,9 +20,9 @@ func TestFeaturesList(t *testing.T) { t.Parallel() t.Run("Table", func(t *testing.T) { t.Parallel() - client := coderdtest.New(t, nil) + client := coderdenttest.New(t, nil) coderdtest.CreateFirstUser(t, client) - cmd, root := clitest.New(t, "features", "list") + cmd, root := clitest.NewWithSubcommands(t, cli.EnterpriseSubcommands(), "features", "list") clitest.SetupConfig(t, client, root) pty := ptytest.New(t) cmd.SetIn(pty.Input()) @@ -36,9 +38,9 @@ func TestFeaturesList(t *testing.T) { t.Run("JSON", func(t *testing.T) { t.Parallel() - client := coderdtest.New(t, nil) + client := coderdenttest.New(t, nil) coderdtest.CreateFirstUser(t, client) - cmd, root := clitest.New(t, "features", "list", "-o", "json") + cmd, root := clitest.NewWithSubcommands(t, cli.EnterpriseSubcommands(), "features", "list", "-o", "json") clitest.SetupConfig(t, client, root) doneChan := make(chan struct{}) diff --git a/enterprise/cli/root.go b/enterprise/cli/root.go index 111dd9885f075..c250b9d9f9ccd 100644 --- a/enterprise/cli/root.go +++ b/enterprise/cli/root.go @@ -21,6 +21,7 @@ func enterpriseOnly() []*cobra.Command { } return api.AGPL, nil }), + features(), licenses(), } } diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index ff1c29fbf7675..ecae2ab0801f4 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -37,6 +37,10 @@ func New(ctx context.Context, options *Options) (*API, error) { AGPL: coderd.New(options.Options), Options: options, + activeUsers: codersdk.Feature{ + Entitlement: codersdk.EntitlementNotEntitled, + Enabled: false, + }, auditLogs: codersdk.EntitlementNotEntitled, cancelEntitlementsLoop: cancelFunc, } diff --git a/site/src/api/api.ts b/site/src/api/api.ts index 3e2ea86c16d8b..86f2d10bb2f27 100644 --- a/site/src/api/api.ts +++ b/site/src/api/api.ts @@ -16,6 +16,22 @@ export const hardCodedCSRFCookie = (): string => { return csrfToken } +// defaultEntitlements has a default set of disabled functionality. +export const defaultEntitlements = (): TypesGen.Entitlements => { + const features: TypesGen.Entitlements["features"] = {} + for (const feature in Types.FeatureNames) { + features[feature] = { + enabled: false, + entitlement: "not_entitled", + } + } + return { + features: features, + has_license: false, + warnings: [], + } +} + // Always attach CSRF token to all requests. // In puppeteer the document is undefined. In those cases, just // do nothing. @@ -424,8 +440,15 @@ export const putWorkspaceExtension = async ( } export const getEntitlements = async (): Promise => { - const response = await axios.get("/api/v2/entitlements") - return response.data + try { + const response = await axios.get("/api/v2/entitlements") + return response.data + } catch (error) { + if (axios.isAxiosError(error) && error.response?.status === 404) { + return defaultEntitlements() + } + throw error + } } interface GetAuditLogsOptions { diff --git a/site/src/xServices/entitlements/entitlementsXService.ts b/site/src/xServices/entitlements/entitlementsXService.ts index 3eee8a5e43ac6..eb3792bd650e9 100644 --- a/site/src/xServices/entitlements/entitlementsXService.ts +++ b/site/src/xServices/entitlements/entitlementsXService.ts @@ -84,7 +84,7 @@ export const entitlementsMachine = createMachine( }), }, services: { - getEntitlements: () => API.getEntitlements(), + getEntitlements: API.getEntitlements, }, }, ) From be96fff2684c6b9704c5368dcffef54d0ac05842 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Thu, 15 Sep 2022 04:20:06 +0000 Subject: [PATCH 05/19] Remove AGPL Coder entitlements endpoint test --- coderd/coderd_test.go | 20 -------------------- coderd/users.go | 6 +++--- coderd/users_test.go | 6 +++++- 3 files changed, 8 insertions(+), 24 deletions(-) diff --git a/coderd/coderd_test.go b/coderd/coderd_test.go index ea4fafaae533e..9fc459fc9e18e 100644 --- a/coderd/coderd_test.go +++ b/coderd/coderd_test.go @@ -17,7 +17,6 @@ import ( "github.com/coder/coder/buildinfo" "github.com/coder/coder/coderd/coderdtest" - "github.com/coder/coder/codersdk" "github.com/coder/coder/tailnet" "github.com/coder/coder/testutil" ) @@ -115,22 +114,3 @@ func TestDERPLatencyCheck(t *testing.T) { defer res.Body.Close() require.Equal(t, http.StatusOK, res.StatusCode) } - -func TestEntitlements(t *testing.T) { - t.Parallel() - t.Run("GET", func(t *testing.T) { - t.Parallel() - client := coderdtest.New(t, nil) - _ = coderdtest.CreateFirstUser(t, client) - result, err := client.Entitlements(context.Background()) - require.NoError(t, err) - assert.False(t, result.HasLicense) - assert.Empty(t, result.Warnings) - for _, f := range codersdk.FeatureNames { - require.Contains(t, result.Features, f) - fe := result.Features[f] - assert.False(t, fe.Enabled) - assert.Equal(t, codersdk.EntitlementNotEntitled, fe.Entitlement) - } - }) -} diff --git a/coderd/users.go b/coderd/users.go index 683710791b2c8..e64e9e9609c8e 100644 --- a/coderd/users.go +++ b/coderd/users.go @@ -1181,9 +1181,9 @@ func (api *API) createUser(ctx context.Context, store database.Store, req create func (api *API) setAuthCookie(rw http.ResponseWriter, cookie *http.Cookie) { http.SetCookie(rw, cookie) - devurlCookie := api.applicationCookie(cookie) - if devurlCookie != nil { - http.SetCookie(rw, devurlCookie) + appCookie := api.applicationCookie(cookie) + if appCookie != nil { + http.SetCookie(rw, appCookie) } } diff --git a/coderd/users_test.go b/coderd/users_test.go index 2378adc7d07e1..e23bbc78f381f 100644 --- a/coderd/users_test.go +++ b/coderd/users_test.go @@ -32,7 +32,11 @@ func TestFirstUser(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - _, err := client.CreateFirstUser(ctx, codersdk.CreateFirstUserRequest{}) + has, err := client.HasFirstUser(context.Background()) + require.NoError(t, err) + require.False(t, has) + + _, err = client.CreateFirstUser(ctx, codersdk.CreateFirstUserRequest{}) require.Error(t, err) }) From b4cbd6ca7fca7ed40d1b2747ed9ff7fc900c6e23 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Thu, 15 Sep 2022 05:00:21 +0000 Subject: [PATCH 06/19] Fix warnings output --- cli/root.go | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/cli/root.go b/cli/root.go index adb4713ccb2b4..955cbf9051618 100644 --- a/cli/root.go +++ b/cli/root.go @@ -549,13 +549,11 @@ func checkWarnings(cmd *cobra.Command, client *codersdk.Client) error { defer cancel() entitlements, err := client.Entitlements(ctx) - if err != nil { - return xerrors.Errorf("get entitlements to show warnings: %w", err) - } - for _, w := range entitlements.Warnings { - _, _ = fmt.Fprintln(cmd.ErrOrStderr(), cliui.Styles.Warn.Render(w)) + if err == nil { + for _, w := range entitlements.Warnings { + _, _ = fmt.Fprintln(cmd.ErrOrStderr(), cliui.Styles.Warn.Render(w)) + } } - return nil } From 1964a64f3d0dc4b1362cd2f5551fdd9356a9ffe0 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Thu, 15 Sep 2022 05:26:50 +0000 Subject: [PATCH 07/19] Add command-line flag to toggle audit logging --- enterprise/cli/root.go | 14 +------------- enterprise/cli/server.go | 33 ++++++++++++++++++++++++++++++++ enterprise/coderd/coderd.go | 33 ++++++++++++++++++++++---------- enterprise/coderd/coderd_test.go | 25 ++++++++++++++++++++++++ 4 files changed, 82 insertions(+), 23 deletions(-) create mode 100644 enterprise/cli/server.go diff --git a/enterprise/cli/root.go b/enterprise/cli/root.go index c250b9d9f9ccd..52decb3266226 100644 --- a/enterprise/cli/root.go +++ b/enterprise/cli/root.go @@ -1,26 +1,14 @@ package cli import ( - "context" - "github.com/spf13/cobra" agpl "github.com/coder/coder/cli" - agplcoderd "github.com/coder/coder/coderd" - "github.com/coder/coder/enterprise/coderd" ) func enterpriseOnly() []*cobra.Command { return []*cobra.Command{ - agpl.Server(func(ctx context.Context, options *agplcoderd.Options) (*agplcoderd.API, error) { - api, err := coderd.New(ctx, &coderd.Options{ - Options: options, - }) - if err != nil { - return nil, err - } - return api.AGPL, nil - }), + server(), features(), licenses(), } diff --git a/enterprise/cli/server.go b/enterprise/cli/server.go new file mode 100644 index 0000000000000..8fd9e542af4cc --- /dev/null +++ b/enterprise/cli/server.go @@ -0,0 +1,33 @@ +package cli + +import ( + "context" + + "github.com/spf13/cobra" + + "github.com/coder/coder/cli/cliflag" + "github.com/coder/coder/enterprise/coderd" + + agpl "github.com/coder/coder/cli" + agplcoderd "github.com/coder/coder/coderd" +) + +func server() *cobra.Command { + var ( + auditLogging bool + ) + cmd := agpl.Server(func(ctx context.Context, options *agplcoderd.Options) (*agplcoderd.API, error) { + api, err := coderd.New(ctx, &coderd.Options{ + AuditLogging: auditLogging, + Options: options, + }) + if err != nil { + return nil, err + } + return api.AGPL, nil + }) + cliflag.BoolVarP(cmd.Flags(), &auditLogging, "audit-logging", "", "CODER_AUDIT_LOGGING", true, + "Specifies whether audit logging is enabled.") + + return cmd +} diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index ecae2ab0801f4..67382ad6a166b 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -78,6 +78,7 @@ func New(ctx context.Context, options *Options) (*API, error) { type Options struct { *coderd.Options + AuditLogging bool EntitlementsUpdateInterval time.Duration Keys map[string]ed25519.PublicKey } @@ -125,7 +126,14 @@ func (api *API) updateEntitlements(ctx context.Context) error { api.mutex.Lock() defer api.mutex.Unlock() now := time.Now() - auditLogs := api.auditLogs + + // Default all entitlements to be disabled. + activeUsers := codersdk.Feature{ + Enabled: false, + Entitlement: codersdk.EntitlementNotEntitled, + } + auditLogs := codersdk.EntitlementNotEntitled + for _, l := range licenses { claims, err := validateDBLicense(l, api.Keys) if err != nil { @@ -141,24 +149,25 @@ func (api *API) updateEntitlements(ctx context.Context) error { entitlement = codersdk.EntitlementGracePeriod } if claims.Features.UserLimit > 0 { - api.activeUsers.Enabled = true - api.activeUsers.Entitlement = entitlement + activeUsers.Enabled = true + activeUsers.Entitlement = entitlement currentLimit := int64(0) - if api.activeUsers.Limit != nil { - currentLimit = *api.activeUsers.Limit + if activeUsers.Limit != nil { + currentLimit = *activeUsers.Limit } limit := max(currentLimit, claims.Features.UserLimit) - api.activeUsers.Limit = &limit + activeUsers.Limit = &limit } if claims.Features.AuditLog > 0 { - api.auditLogs = entitlement + auditLogs = entitlement } } + if auditLogs != api.auditLogs { auditor := agplaudit.NewNop() // A flag could be added to the options that would allow disabling // enhanced audit logging here! - if api.auditLogs == codersdk.EntitlementEntitled { + if api.auditLogs == codersdk.EntitlementEntitled && api.AuditLogging { auditor = audit.NewAuditor( audit.DefaultFilter, backends.NewPostgres(api.Database, true), @@ -167,6 +176,10 @@ func (api *API) updateEntitlements(ctx context.Context) error { } api.AGPL.Auditor.Store(auditor) } + + api.activeUsers = activeUsers + api.auditLogs = auditLogs + return nil } @@ -205,9 +218,9 @@ func (api *API) entitlements(rw http.ResponseWriter, r *http.Request) { // Audit logs resp.Features[codersdk.FeatureAuditLog] = codersdk.Feature{ Entitlement: auditLogs, - Enabled: true, + Enabled: api.AuditLogging, } - if auditLogs == codersdk.EntitlementGracePeriod { + if auditLogs == codersdk.EntitlementGracePeriod && api.AuditLogging { resp.Warnings = append(resp.Warnings, "Audit logging is enabled but your license for this feature is expired.") } diff --git a/enterprise/coderd/coderd_test.go b/enterprise/coderd/coderd_test.go index 114a02d2ff472..84f1656a1bcc8 100644 --- a/enterprise/coderd/coderd_test.go +++ b/enterprise/coderd/coderd_test.go @@ -59,6 +59,31 @@ func TestEntitlements(t *testing.T) { assert.Nil(t, al.Actual) assert.Empty(t, res.Warnings) }) + t.Run("FullLicenseToNone", func(t *testing.T) { + t.Parallel() + client := coderdenttest.New(t, nil) + _ = coderdtest.CreateFirstUser(t, client) + license := coderdenttest.AddLicense(t, client, coderdenttest.AddLicenseOptions{ + UserLimit: 100, + AuditLog: true, + }) + res, err := client.Entitlements(context.Background()) + require.NoError(t, err) + assert.True(t, res.HasLicense) + al := res.Features[codersdk.FeatureAuditLog] + assert.Equal(t, codersdk.EntitlementEntitled, al.Entitlement) + assert.True(t, al.Enabled) + + err = client.DeleteLicense(context.Background(), license.ID) + require.NoError(t, err) + + res, err = client.Entitlements(context.Background()) + require.NoError(t, err) + assert.True(t, res.HasLicense) + al = res.Features[codersdk.FeatureAuditLog] + assert.Equal(t, codersdk.EntitlementNotEntitled, al.Entitlement) + assert.True(t, al.Enabled) + }) t.Run("Warnings", func(t *testing.T) { t.Parallel() client := coderdenttest.New(t, nil) From 0e88f51e43d7287409cc05d240bd580f588731d2 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Thu, 15 Sep 2022 05:27:52 +0000 Subject: [PATCH 08/19] Fix hasLicense being set --- enterprise/coderd/coderd.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index 67382ad6a166b..71e1db15ec685 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -128,6 +128,7 @@ func (api *API) updateEntitlements(ctx context.Context) error { now := time.Now() // Default all entitlements to be disabled. + hasLicense := false activeUsers := codersdk.Feature{ Enabled: false, Entitlement: codersdk.EntitlementNotEntitled, @@ -141,7 +142,7 @@ func (api *API) updateEntitlements(ctx context.Context) error { slog.F("id", l.ID), slog.Error(err)) continue } - api.hasLicense = true + hasLicense = true entitlement := codersdk.EntitlementEntitled if now.After(claims.LicenseExpires.Time) { // if the grace period were over, the validation fails, so if we are after @@ -177,6 +178,7 @@ func (api *API) updateEntitlements(ctx context.Context) error { api.AGPL.Auditor.Store(auditor) } + api.hasLicense = hasLicense api.activeUsers = activeUsers api.auditLogs = auditLogs From 586b3e8d82236f0313fe55a788336b4a43162b88 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Thu, 15 Sep 2022 05:28:34 +0000 Subject: [PATCH 09/19] Remove features interface --- coderd/features/features.go | 9 --------- 1 file changed, 9 deletions(-) delete mode 100644 coderd/features/features.go diff --git a/coderd/features/features.go b/coderd/features/features.go deleted file mode 100644 index f086931fa8003..0000000000000 --- a/coderd/features/features.go +++ /dev/null @@ -1,9 +0,0 @@ -package features - -// Service is the interface for interacting with enterprise features. -type Service interface { - // Get returns the implementations for feature interfaces. Parameter `s` must be a pointer to a - // struct type containing feature interfaces as fields. The FeatureService sets all fields to - // the correct implementations depending on whether the features are turned on. - Get(s any) error -} From 1bb6d0f9fd3b1a6a2d5642b3a707b4ff9a071e5f Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Thu, 15 Sep 2022 05:34:03 +0000 Subject: [PATCH 10/19] Fix audit logging default --- enterprise/coderd/coderdenttest/coderdenttest.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/enterprise/coderd/coderdenttest/coderdenttest.go b/enterprise/coderd/coderdenttest/coderdenttest.go index 01c6bcbeeca9a..f17f2455de6ef 100644 --- a/enterprise/coderd/coderdenttest/coderdenttest.go +++ b/enterprise/coderd/coderdenttest/coderdenttest.go @@ -53,7 +53,8 @@ func NewWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c } srv, oop := coderdtest.NewOptions(t, options.Options) coderAPI, err := coderd.New(context.Background(), &coderd.Options{ - Options: oop, + AuditLogging: true, + Options: oop, Keys: map[string]ed25519.PublicKey{ testKeyID: testPublicKey, }, From 8c603f606724d054280f15f642732e1eecdf4c6f Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Fri, 16 Sep 2022 14:12:10 -0500 Subject: [PATCH 11/19] Add bash as a dependency --- flake.nix | 1 + 1 file changed, 1 insertion(+) diff --git a/flake.nix b/flake.nix index a5e4816b19e63..dfc44b91df36f 100644 --- a/flake.nix +++ b/flake.nix @@ -16,6 +16,7 @@ formatter = pkgs.nixpkgs-fmt; devShells.default = pkgs.mkShell { buildInputs = with pkgs; [ + bash bat drpc.defaultPackage.${system} exa From 2dcf1f884d1335cdc81d6cde4759efd53aa5fd81 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Fri, 16 Sep 2022 14:52:23 -0500 Subject: [PATCH 12/19] Add comment --- enterprise/coderd/coderd.go | 1 + enterprise/coderd/licenses_test.go | 14 +++++++------- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index 71e1db15ec685..a5cd7456b8824 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -135,6 +135,7 @@ func (api *API) updateEntitlements(ctx context.Context) error { } auditLogs := codersdk.EntitlementNotEntitled + // Here we loop through licenses to detect enabled features. for _, l := range licenses { claims, err := validateDBLicense(l, api.Keys) if err != nil { diff --git a/enterprise/coderd/licenses_test.go b/enterprise/coderd/licenses_test.go index b518a684d6c05..ba380932b7fb8 100644 --- a/enterprise/coderd/licenses_test.go +++ b/enterprise/coderd/licenses_test.go @@ -20,7 +20,7 @@ import ( func TestPostLicense(t *testing.T) { t.Parallel() - t.Run("POST", func(t *testing.T) { + t.Run("Success", func(t *testing.T) { t.Parallel() client := coderdenttest.New(t, nil) _ = coderdtest.CreateFirstUser(t, client) @@ -37,7 +37,7 @@ func TestPostLicense(t *testing.T) { assert.Equal(t, json.Number("1"), features[codersdk.FeatureAuditLog]) }) - t.Run("POST_unauthorized", func(t *testing.T) { + t.Run("Unauthorized", func(t *testing.T) { t.Parallel() client := coderdenttest.New(t, nil) _, err := client.AddLicense(context.Background(), codersdk.AddLicenseRequest{ @@ -51,7 +51,7 @@ func TestPostLicense(t *testing.T) { } }) - t.Run("POST_corrupted", func(t *testing.T) { + t.Run("Corrupted", func(t *testing.T) { t.Parallel() client := coderdenttest.New(t, nil) _ = coderdtest.CreateFirstUser(t, client) @@ -70,7 +70,7 @@ func TestPostLicense(t *testing.T) { func TestGetLicense(t *testing.T) { t.Parallel() - t.Run("GET", func(t *testing.T) { + t.Run("Success", func(t *testing.T) { t.Parallel() client := coderdenttest.New(t, nil) _ = coderdtest.CreateFirstUser(t, client) @@ -108,7 +108,7 @@ func TestGetLicense(t *testing.T) { func TestDeleteLicense(t *testing.T) { t.Parallel() - t.Run("DELETE_empty", func(t *testing.T) { + t.Run("Empty", func(t *testing.T) { t.Parallel() client := coderdenttest.New(t, nil) _ = coderdtest.CreateFirstUser(t, client) @@ -124,7 +124,7 @@ func TestDeleteLicense(t *testing.T) { } }) - t.Run("DELETE_bad_id", func(t *testing.T) { + t.Run("BadID", func(t *testing.T) { t.Parallel() client := coderdenttest.New(t, nil) _ = coderdtest.CreateFirstUser(t, client) @@ -137,7 +137,7 @@ func TestDeleteLicense(t *testing.T) { require.NoError(t, resp.Body.Close()) }) - t.Run("DELETE", func(t *testing.T) { + t.Run("Success", func(t *testing.T) { t.Parallel() client := coderdenttest.New(t, nil) _ = coderdtest.CreateFirstUser(t, client) From 1955ae3aa02ef6b152c4a496412452dc04cff6d2 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Mon, 19 Sep 2022 19:04:58 +0000 Subject: [PATCH 13/19] Add tests for resync and pubsub, and add back previous exp backoff retry --- enterprise/coderd/coderd.go | 80 +++++++++++++----- enterprise/coderd/coderd_test.go | 83 ++++++++++++++++++- .../coderd/coderdenttest/coderdenttest.go | 25 ++++-- .../coderdenttest/coderdenttest_test.go | 2 +- enterprise/coderd/licenses.go | 6 +- enterprise/coderd/licenses_test.go | 12 +-- 6 files changed, 166 insertions(+), 42 deletions(-) diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index a5cd7456b8824..d8bea430694c6 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -10,6 +10,7 @@ import ( "golang.org/x/xerrors" + "github.com/cenkalti/backoff/v4" "github.com/go-chi/chi/v5" "cdr.dev/slog" @@ -64,7 +65,7 @@ func New(ctx context.Context, options *Options) (*API, error) { if err != nil { return nil, xerrors.Errorf("update entitlements: %w", err) } - api.closeLicenseSubscribe, err = api.Pubsub.Subscribe(pubSubEventLicenses, func(ctx context.Context, message []byte) { + api.closeLicenseSubscribe, err = api.Pubsub.Subscribe(PubsubEventLicenses, func(ctx context.Context, message []byte) { _ = api.updateEntitlements(ctx) }) if err != nil { @@ -101,23 +102,6 @@ func (api *API) Close() error { return api.AGPL.Close() } -func (api *API) runEntitlementsLoop(ctx context.Context) { - ticker := time.NewTicker(api.EntitlementsUpdateInterval) - defer ticker.Stop() - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - } - err := api.updateEntitlements(ctx) - if err != nil { - api.Logger.Warn(ctx, "failed to get feature entitlements", slog.Error(err)) - continue - } - } -} - func (api *API) updateEntitlements(ctx context.Context) error { licenses, err := api.Database.GetUnexpiredLicenses(ctx) if err != nil { @@ -169,7 +153,7 @@ func (api *API) updateEntitlements(ctx context.Context) error { auditor := agplaudit.NewNop() // A flag could be added to the options that would allow disabling // enhanced audit logging here! - if api.auditLogs == codersdk.EntitlementEntitled && api.AuditLogging { + if auditLogs == codersdk.EntitlementEntitled && api.AuditLogging { auditor = audit.NewAuditor( audit.DefaultFilter, backends.NewPostgres(api.Database, true), @@ -231,6 +215,64 @@ func (api *API) entitlements(rw http.ResponseWriter, r *http.Request) { httpapi.Write(rw, http.StatusOK, resp) } +func (api *API) runEntitlementsLoop(ctx context.Context) { + eb := backoff.NewExponentialBackOff() + eb.MaxElapsedTime = 0 // retry indefinitely + b := backoff.WithContext(eb, ctx) + updates := make(chan struct{}, 1) + subscribed := false + + for { + select { + case <-ctx.Done(): + return + default: + // pass + } + if !subscribed { + cancel, err := api.Pubsub.Subscribe(PubsubEventLicenses, func(_ context.Context, _ []byte) { + // don't block. If the channel is full, drop the event, as there is a resync + // scheduled already. + select { + case updates <- struct{}{}: + // pass + default: + // pass + } + }) + if err != nil { + api.Logger.Warn(ctx, "failed to subscribe to license updates", slog.Error(err)) + time.Sleep(b.NextBackOff()) + continue + } + // nolint: revive + defer cancel() + subscribed = true + api.Logger.Debug(ctx, "successfully subscribed to pubsub") + } + + api.Logger.Info(ctx, "syncing licensed entitlements") + err := api.updateEntitlements(ctx) + if err != nil { + api.Logger.Warn(ctx, "failed to get feature entitlements", slog.Error(err)) + time.Sleep(b.NextBackOff()) + continue + } + b.Reset() + api.Logger.Debug(ctx, "synced licensed entitlements") + + select { + case <-ctx.Done(): + return + case <-time.After(api.EntitlementsUpdateInterval): + continue + case <-updates: + api.Logger.Debug(ctx, "got pubsub update") + continue + } + } +} + func max(a, b int64) int64 { if a > b { return a diff --git a/enterprise/coderd/coderd_test.go b/enterprise/coderd/coderd_test.go index 84f1656a1bcc8..72c4befa5c9ce 100644 --- a/enterprise/coderd/coderd_test.go +++ b/enterprise/coderd/coderd_test.go @@ -2,6 +2,7 @@ package coderd_test import ( "context" + "reflect" "testing" "time" @@ -9,9 +10,14 @@ import ( "github.com/stretchr/testify/require" "go.uber.org/goleak" + agplaudit "github.com/coder/coder/coderd/audit" "github.com/coder/coder/coderd/coderdtest" + "github.com/coder/coder/coderd/database" "github.com/coder/coder/codersdk" + "github.com/coder/coder/enterprise/audit" + "github.com/coder/coder/enterprise/coderd" "github.com/coder/coder/enterprise/coderd/coderdenttest" + "github.com/coder/coder/testutil" ) func TestMain(m *testing.M) { @@ -40,7 +46,7 @@ func TestEntitlements(t *testing.T) { t.Parallel() client := coderdenttest.New(t, nil) _ = coderdtest.CreateFirstUser(t, client) - coderdenttest.AddLicense(t, client, coderdenttest.AddLicenseOptions{ + coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ UserLimit: 100, AuditLog: true, }) @@ -63,7 +69,7 @@ func TestEntitlements(t *testing.T) { t.Parallel() client := coderdenttest.New(t, nil) _ = coderdtest.CreateFirstUser(t, client) - license := coderdenttest.AddLicense(t, client, coderdenttest.AddLicenseOptions{ + license := coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ UserLimit: 100, AuditLog: true, }) @@ -79,7 +85,7 @@ func TestEntitlements(t *testing.T) { res, err = client.Entitlements(context.Background()) require.NoError(t, err) - assert.True(t, res.HasLicense) + assert.False(t, res.HasLicense) al = res.Features[codersdk.FeatureAuditLog] assert.Equal(t, codersdk.EntitlementNotEntitled, al.Entitlement) assert.True(t, al.Enabled) @@ -91,7 +97,7 @@ func TestEntitlements(t *testing.T) { for i := 0; i < 4; i++ { coderdtest.CreateAnotherUser(t, client, first.OrganizationID) } - coderdenttest.AddLicense(t, client, coderdenttest.AddLicenseOptions{ + coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ UserLimit: 4, AuditLog: true, GraceAt: time.Now().Add(-time.Second), @@ -115,4 +121,73 @@ func TestEntitlements(t *testing.T) { assert.Contains(t, res.Warnings, "Audit logging is enabled but your license for this feature is expired.") }) + t.Run("Pubsub", func(t *testing.T) { + t.Parallel() + client, _, api := coderdenttest.NewWithAPI(t, nil) + entitlements, err := client.Entitlements(context.Background()) + require.NoError(t, err) + require.False(t, entitlements.HasLicense) + coderdtest.CreateFirstUser(t, client) + _, err = api.Database.InsertLicense(context.Background(), database.InsertLicenseParams{ + UploadedAt: database.Now(), + Exp: database.Now().AddDate(1, 0, 0), + JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{ + AuditLog: true, + }), + }) + require.NoError(t, err) + err = api.Pubsub.Publish(coderd.PubsubEventLicenses, []byte{}) + require.NoError(t, err) + require.Eventually(t, func() bool { + entitlements, err := client.Entitlements(context.Background()) + assert.NoError(t, err) + return entitlements.HasLicense + }, testutil.WaitShort, testutil.IntervalFast) + }) + t.Run("Resync", func(t *testing.T) { + t.Parallel() + client, _, api := coderdenttest.NewWithAPI(t, &coderdenttest.Options{ + EntitlementsUpdateInterval: 25 * time.Millisecond, + }) + entitlements, err := client.Entitlements(context.Background()) + require.NoError(t, err) + require.False(t, entitlements.HasLicense) + coderdtest.CreateFirstUser(t, client) + _, err = api.Database.InsertLicense(context.Background(), database.InsertLicenseParams{ + UploadedAt: database.Now(), + Exp: database.Now().AddDate(1, 0, 0), + JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{ + AuditLog: true, + }), + }) + require.NoError(t, err) + require.Eventually(t, func() bool { + entitlements, err := client.Entitlements(context.Background()) + assert.NoError(t, err) + return entitlements.HasLicense + }, testutil.WaitShort, testutil.IntervalFast) + }) +} + +func TestAuditLogging(t *testing.T) { + t.Parallel() + t.Run("Enabled", func(t *testing.T) { + t.Parallel() + client, _, api := coderdenttest.NewWithAPI(t, nil) + coderdtest.CreateFirstUser(t, client) + coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ + AuditLog: true, + }) + _, auditor := api.AGPL.Auditor.Load(context.Background()) + ea := audit.NewAuditor(audit.DefaultFilter) + assert.Equal(t, reflect.ValueOf(ea).Type(), reflect.ValueOf(auditor).Type()) + }) + t.Run("Disabled", func(t *testing.T) { + t.Parallel() + client, _, api := coderdenttest.NewWithAPI(t, nil) + coderdtest.CreateFirstUser(t, client) + _, auditor := api.AGPL.Auditor.Load(context.Background()) + ea := agplaudit.NewNop() + assert.Equal(t, reflect.ValueOf(ea).Type(), reflect.ValueOf(auditor).Type()) + }) } diff --git a/enterprise/coderd/coderdenttest/coderdenttest.go b/enterprise/coderd/coderdenttest/coderdenttest.go index f17f2455de6ef..813fddf473bae 100644 --- a/enterprise/coderd/coderdenttest/coderdenttest.go +++ b/enterprise/coderd/coderdenttest/coderdenttest.go @@ -36,6 +36,7 @@ func init() { type Options struct { *coderdtest.Options + EntitlementsUpdateInterval time.Duration } // New constructs a codersdk client connected to an in-memory Enterprise API instance. @@ -53,8 +54,9 @@ func NewWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c } srv, oop := coderdtest.NewOptions(t, options.Options) coderAPI, err := coderd.New(context.Background(), &coderd.Options{ - AuditLogging: true, - Options: oop, + AuditLogging: true, + Options: oop, + EntitlementsUpdateInterval: options.EntitlementsUpdateInterval, Keys: map[string]ed25519.PublicKey{ testKeyID: testPublicKey, }, @@ -72,7 +74,7 @@ func NewWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c return codersdk.New(coderAPI.AccessURL), provisionerCloser, coderAPI } -type AddLicenseOptions struct { +type LicenseOptions struct { AccountType string AccountID string GraceAt time.Time @@ -82,7 +84,16 @@ type AddLicenseOptions struct { } // AddLicense generates a new license with the options provided and inserts it. -func AddLicense(t *testing.T, client *codersdk.Client, options AddLicenseOptions) codersdk.License { +func AddLicense(t *testing.T, client *codersdk.Client, options LicenseOptions) codersdk.License { + license, err := client.AddLicense(context.Background(), codersdk.AddLicenseRequest{ + License: GenerateLicense(t, options), + }) + require.NoError(t, err) + return license +} + +// GenerateLicense returns a signed JWT using the test key. +func GenerateLicense(t *testing.T, options LicenseOptions) string { if options.ExpiresAt.IsZero() { options.ExpiresAt = time.Now().Add(time.Hour) } @@ -113,11 +124,7 @@ func AddLicense(t *testing.T, client *codersdk.Client, options AddLicenseOptions tok.Header[coderd.HeaderKeyID] = testKeyID signedTok, err := tok.SignedString(testPrivateKey) require.NoError(t, err) - license, err := client.AddLicense(context.Background(), codersdk.AddLicenseRequest{ - License: signedTok, - }) - require.NoError(t, err) - return license + return signedTok } type nopcloser struct{} diff --git a/enterprise/coderd/coderdenttest/coderdenttest_test.go b/enterprise/coderd/coderdenttest/coderdenttest_test.go index e18aabb2bffbc..929a8fc383280 100644 --- a/enterprise/coderd/coderdenttest/coderdenttest_test.go +++ b/enterprise/coderd/coderdenttest/coderdenttest_test.go @@ -402,7 +402,7 @@ func newAuthTester(ctx context.Context, t *testing.T) *authTester { DestinationScheme: codersdk.ParameterDestinationSchemeProvisionerVariable, }) require.NoError(t, err, "create template param") - license := coderdenttest.AddLicense(t, client, coderdenttest.AddLicenseOptions{}) + license := coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{}) urlParameters := map[string]string{ "{organization}": admin.OrganizationID.String(), "{user}": admin.UserID.String(), diff --git a/enterprise/coderd/licenses.go b/enterprise/coderd/licenses.go index 4ffaf42cf585e..5b8273f2ffe60 100644 --- a/enterprise/coderd/licenses.go +++ b/enterprise/coderd/licenses.go @@ -31,7 +31,7 @@ const ( AccountTypeSalesforce = "salesforce" VersionClaim = "version" - pubSubEventLicenses = "licenses" + PubsubEventLicenses = "licenses" ) var ValidMethods = []string{"EdDSA"} @@ -127,7 +127,7 @@ func (api *API) postLicense(rw http.ResponseWriter, r *http.Request) { }) return } - err = api.Pubsub.Publish(pubSubEventLicenses, []byte("add")) + err = api.Pubsub.Publish(PubsubEventLicenses, []byte("add")) if err != nil { api.Logger.Error(context.Background(), "failed to publish license add", slog.Error(err)) // don't fail the HTTP request, since we did write it successfully to the database @@ -206,7 +206,7 @@ func (api *API) deleteLicense(rw http.ResponseWriter, r *http.Request) { }) return } - err = api.Pubsub.Publish(pubSubEventLicenses, []byte("delete")) + err = api.Pubsub.Publish(PubsubEventLicenses, []byte("delete")) if err != nil { api.Logger.Error(context.Background(), "failed to publish license delete", slog.Error(err)) // don't fail the HTTP request, since we did write it successfully to the database diff --git a/enterprise/coderd/licenses_test.go b/enterprise/coderd/licenses_test.go index ba380932b7fb8..243898a43ca73 100644 --- a/enterprise/coderd/licenses_test.go +++ b/enterprise/coderd/licenses_test.go @@ -24,7 +24,7 @@ func TestPostLicense(t *testing.T) { t.Parallel() client := coderdenttest.New(t, nil) _ = coderdtest.CreateFirstUser(t, client) - respLic := coderdenttest.AddLicense(t, client, coderdenttest.AddLicenseOptions{ + respLic := coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ AccountType: coderd.AccountTypeSalesforce, AccountID: "testing", AuditLog: true, @@ -55,7 +55,7 @@ func TestPostLicense(t *testing.T) { t.Parallel() client := coderdenttest.New(t, nil) _ = coderdtest.CreateFirstUser(t, client) - coderdenttest.AddLicense(t, client, coderdenttest.AddLicenseOptions{}) + coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{}) _, err := client.AddLicense(context.Background(), codersdk.AddLicenseRequest{ License: "invalid", }) @@ -77,12 +77,12 @@ func TestGetLicense(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - coderdenttest.AddLicense(t, client, coderdenttest.AddLicenseOptions{ + coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ AccountID: "testing", AuditLog: true, }) - coderdenttest.AddLicense(t, client, coderdenttest.AddLicenseOptions{ + coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ AccountID: "testing2", AuditLog: true, UserLimit: 200, @@ -144,11 +144,11 @@ func TestDeleteLicense(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - coderdenttest.AddLicense(t, client, coderdenttest.AddLicenseOptions{ + coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ AccountID: "testing", AuditLog: true, }) - coderdenttest.AddLicense(t, client, coderdenttest.AddLicenseOptions{ + coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ AccountID: "testing2", AuditLog: true, UserLimit: 200, From e01a6ceef1501d51c12ebefaf7a24dff90295f3d Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Mon, 19 Sep 2022 19:25:15 +0000 Subject: [PATCH 14/19] Separate authz code again --- coderd/coderdtest/authorize.go | 560 +++++++++++++++++ coderd/coderdtest/authorize_test.go | 20 + coderd/coderdtest/coderdtest.go | 6 +- .../coderdenttest/coderdenttest_test.go | 581 +----------------- 4 files changed, 602 insertions(+), 565 deletions(-) create mode 100644 coderd/coderdtest/authorize.go create mode 100644 coderd/coderdtest/authorize_test.go diff --git a/coderd/coderdtest/authorize.go b/coderd/coderdtest/authorize.go new file mode 100644 index 0000000000000..8dd3ff7257555 --- /dev/null +++ b/coderd/coderdtest/authorize.go @@ -0,0 +1,560 @@ +package coderdtest + +import ( + "context" + "fmt" + "io" + "net/http" + "strconv" + "strings" + "testing" + + "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + + "github.com/coder/coder/coderd" + "github.com/coder/coder/coderd/rbac" + "github.com/coder/coder/codersdk" + "github.com/coder/coder/provisioner/echo" + "github.com/coder/coder/provisionersdk/proto" + "github.com/coder/coder/testutil" +) + +func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) { + // Some quick reused objects + workspaceRBACObj := rbac.ResourceWorkspace.InOrg(a.Organization.ID).WithOwner(a.Workspace.OwnerID.String()) + workspaceExecObj := rbac.ResourceWorkspaceExecution.InOrg(a.Organization.ID).WithOwner(a.Workspace.OwnerID.String()) + applicationConnectObj := rbac.ResourceWorkspaceApplicationConnect.InOrg(a.Organization.ID).WithOwner(a.Workspace.OwnerID.String()) + + // skipRoutes allows skipping routes from being checked. + skipRoutes := map[string]string{ + "POST:/api/v2/users/logout": "Logging out deletes the API Key for other routes", + "GET:/derp": "This requires a WebSocket upgrade!", + "GET:/derp/latency-check": "This always returns a 200!", + } + + assertRoute := map[string]RouteCheck{ + // These endpoints do not require auth + "GET:/api/v2": {NoAuthorize: true}, + "GET:/api/v2/buildinfo": {NoAuthorize: true}, + "GET:/api/v2/users/first": {NoAuthorize: true}, + "POST:/api/v2/users/first": {NoAuthorize: true}, + "POST:/api/v2/users/login": {NoAuthorize: true}, + "GET:/api/v2/users/authmethods": {NoAuthorize: true}, + "POST:/api/v2/csp/reports": {NoAuthorize: true}, + + // Has it's own auth + "GET:/api/v2/users/oauth2/github/callback": {NoAuthorize: true}, + "GET:/api/v2/users/oidc/callback": {NoAuthorize: true}, + + // All workspaceagents endpoints do not use rbac + "POST:/api/v2/workspaceagents/aws-instance-identity": {NoAuthorize: true}, + "POST:/api/v2/workspaceagents/azure-instance-identity": {NoAuthorize: true}, + "POST:/api/v2/workspaceagents/google-instance-identity": {NoAuthorize: true}, + "GET:/api/v2/workspaceagents/me/gitsshkey": {NoAuthorize: true}, + "GET:/api/v2/workspaceagents/me/iceservers": {NoAuthorize: true}, + "GET:/api/v2/workspaceagents/me/listen": {NoAuthorize: true}, + "GET:/api/v2/workspaceagents/me/metadata": {NoAuthorize: true}, + "GET:/api/v2/workspaceagents/me/turn": {NoAuthorize: true}, + "GET:/api/v2/workspaceagents/me/coordinate": {NoAuthorize: true}, + "POST:/api/v2/workspaceagents/me/version": {NoAuthorize: true}, + "GET:/api/v2/workspaceagents/me/report-stats": {NoAuthorize: true}, + "GET:/api/v2/workspaceagents/{workspaceagent}/iceservers": {NoAuthorize: true}, + + // These endpoints have more assertions. This is good, add more endpoints to assert if you can! + "GET:/api/v2/organizations/{organization}": {AssertObject: rbac.ResourceOrganization.InOrg(a.Admin.OrganizationID)}, + "GET:/api/v2/users/{user}/organizations": {StatusCode: http.StatusOK, AssertObject: rbac.ResourceOrganization}, + "GET:/api/v2/users/{user}/workspace/{workspacename}": { + AssertObject: rbac.ResourceWorkspace, + AssertAction: rbac.ActionRead, + }, + "GET:/api/v2/users/{user}/workspace/{workspacename}/builds/{buildnumber}": { + AssertObject: rbac.ResourceWorkspace, + AssertAction: rbac.ActionRead, + }, + "GET:/api/v2/workspacebuilds/{workspacebuild}": { + AssertAction: rbac.ActionRead, + AssertObject: workspaceRBACObj, + }, + "GET:/api/v2/workspacebuilds/{workspacebuild}/logs": { + AssertAction: rbac.ActionRead, + AssertObject: workspaceRBACObj, + }, + "GET:/api/v2/workspaces/{workspace}/builds": { + AssertAction: rbac.ActionRead, + AssertObject: workspaceRBACObj, + }, + "GET:/api/v2/workspaces/{workspace}": { + AssertAction: rbac.ActionRead, + AssertObject: workspaceRBACObj, + }, + "PUT:/api/v2/workspaces/{workspace}/autostart": { + AssertAction: rbac.ActionUpdate, + AssertObject: workspaceRBACObj, + }, + "PUT:/api/v2/workspaces/{workspace}/ttl": { + AssertAction: rbac.ActionUpdate, + AssertObject: workspaceRBACObj, + }, + "GET:/api/v2/workspaceresources/{workspaceresource}": { + AssertAction: rbac.ActionRead, + AssertObject: workspaceRBACObj, + }, + "PATCH:/api/v2/workspacebuilds/{workspacebuild}/cancel": { + AssertAction: rbac.ActionUpdate, + AssertObject: workspaceRBACObj, + }, + "GET:/api/v2/workspacebuilds/{workspacebuild}/resources": { + AssertAction: rbac.ActionRead, + AssertObject: workspaceRBACObj, + }, + "GET:/api/v2/workspacebuilds/{workspacebuild}/state": { + AssertAction: rbac.ActionRead, + AssertObject: workspaceRBACObj, + }, + "GET:/api/v2/workspaceagents/{workspaceagent}": { + AssertAction: rbac.ActionRead, + AssertObject: workspaceRBACObj, + }, + "GET:/api/v2/workspaceagents/{workspaceagent}/dial": { + AssertAction: rbac.ActionCreate, + AssertObject: workspaceExecObj, + }, + "GET:/api/v2/workspaceagents/{workspaceagent}/turn": { + AssertAction: rbac.ActionCreate, + AssertObject: workspaceExecObj, + }, + "GET:/api/v2/workspaceagents/{workspaceagent}/pty": { + AssertAction: rbac.ActionCreate, + AssertObject: workspaceExecObj, + }, + "GET:/api/v2/workspaceagents/{workspaceagent}/coordinate": { + AssertAction: rbac.ActionCreate, + AssertObject: workspaceExecObj, + }, + "GET:/api/v2/workspaces/": { + StatusCode: http.StatusOK, + AssertAction: rbac.ActionRead, + AssertObject: workspaceRBACObj, + }, + "GET:/api/v2/organizations/{organization}/templates": { + StatusCode: http.StatusOK, + AssertAction: rbac.ActionRead, + AssertObject: rbac.ResourceTemplate.InOrg(a.Template.OrganizationID), + }, + "POST:/api/v2/organizations/{organization}/templates": { + AssertAction: rbac.ActionCreate, + AssertObject: rbac.ResourceTemplate.InOrg(a.Organization.ID), + }, + "DELETE:/api/v2/templates/{template}": { + AssertAction: rbac.ActionDelete, + AssertObject: rbac.ResourceTemplate.InOrg(a.Template.OrganizationID), + }, + "GET:/api/v2/templates/{template}": { + AssertAction: rbac.ActionRead, + AssertObject: rbac.ResourceTemplate.InOrg(a.Template.OrganizationID), + }, + "POST:/api/v2/files": {AssertAction: rbac.ActionCreate, AssertObject: rbac.ResourceFile}, + "GET:/api/v2/files/{hash}": { + AssertAction: rbac.ActionRead, + AssertObject: rbac.ResourceFile.WithOwner(a.Admin.UserID.String()), + }, + "GET:/api/v2/templates/{template}/versions": { + AssertAction: rbac.ActionRead, + AssertObject: rbac.ResourceTemplate.InOrg(a.Template.OrganizationID), + }, + "PATCH:/api/v2/templates/{template}/versions": { + AssertAction: rbac.ActionUpdate, + AssertObject: rbac.ResourceTemplate.InOrg(a.Template.OrganizationID), + }, + "GET:/api/v2/templates/{template}/versions/{templateversionname}": { + AssertAction: rbac.ActionRead, + AssertObject: rbac.ResourceTemplate.InOrg(a.Template.OrganizationID), + }, + "GET:/api/v2/templateversions/{templateversion}": { + AssertAction: rbac.ActionRead, + AssertObject: rbac.ResourceTemplate.InOrg(a.Template.OrganizationID), + }, + "PATCH:/api/v2/templateversions/{templateversion}/cancel": { + AssertAction: rbac.ActionUpdate, + AssertObject: rbac.ResourceTemplate.InOrg(a.Template.OrganizationID), + }, + "GET:/api/v2/templateversions/{templateversion}/logs": { + AssertAction: rbac.ActionRead, + AssertObject: rbac.ResourceTemplate.InOrg(a.Template.OrganizationID), + }, + "GET:/api/v2/templateversions/{templateversion}/parameters": { + AssertAction: rbac.ActionRead, + AssertObject: rbac.ResourceTemplate.InOrg(a.Template.OrganizationID), + }, + "GET:/api/v2/templateversions/{templateversion}/resources": { + AssertAction: rbac.ActionRead, + AssertObject: rbac.ResourceTemplate.InOrg(a.Template.OrganizationID), + }, + "GET:/api/v2/templateversions/{templateversion}/schema": { + AssertAction: rbac.ActionRead, + AssertObject: rbac.ResourceTemplate.InOrg(a.Template.OrganizationID), + }, + "POST:/api/v2/templateversions/{templateversion}/dry-run": { + // The first check is to read the template + AssertAction: rbac.ActionRead, + AssertObject: rbac.ResourceTemplate.InOrg(a.Version.OrganizationID), + }, + "GET:/api/v2/templateversions/{templateversion}/dry-run/{jobID}": { + AssertAction: rbac.ActionRead, + AssertObject: rbac.ResourceTemplate.InOrg(a.Version.OrganizationID), + }, + "GET:/api/v2/templateversions/{templateversion}/dry-run/{jobID}/resources": { + AssertAction: rbac.ActionRead, + AssertObject: rbac.ResourceTemplate.InOrg(a.Version.OrganizationID), + }, + "GET:/api/v2/templateversions/{templateversion}/dry-run/{jobID}/logs": { + AssertAction: rbac.ActionRead, + AssertObject: rbac.ResourceTemplate.InOrg(a.Version.OrganizationID), + }, + "PATCH:/api/v2/templateversions/{templateversion}/dry-run/{jobID}/cancel": { + AssertAction: rbac.ActionRead, + AssertObject: rbac.ResourceTemplate.InOrg(a.Version.OrganizationID), + }, + "GET:/api/v2/provisionerdaemons": { + StatusCode: http.StatusOK, + AssertObject: rbac.ResourceProvisionerDaemon, + }, + + "POST:/api/v2/parameters/{scope}/{id}": { + AssertAction: rbac.ActionUpdate, + AssertObject: rbac.ResourceTemplate, + }, + "GET:/api/v2/parameters/{scope}/{id}": { + AssertAction: rbac.ActionRead, + AssertObject: rbac.ResourceTemplate, + }, + "DELETE:/api/v2/parameters/{scope}/{id}/{name}": { + AssertAction: rbac.ActionUpdate, + AssertObject: rbac.ResourceTemplate, + }, + "GET:/api/v2/organizations/{organization}/templates/{templatename}": { + AssertAction: rbac.ActionRead, + AssertObject: rbac.ResourceTemplate.InOrg(a.Template.OrganizationID), + }, + "POST:/api/v2/organizations/{organization}/workspaces": { + AssertAction: rbac.ActionCreate, + // No ID when creating + AssertObject: workspaceRBACObj, + }, + "GET:/api/v2/workspaces/{workspace}/watch": { + AssertAction: rbac.ActionRead, + AssertObject: workspaceRBACObj, + }, + "GET:/api/v2/users": {StatusCode: http.StatusOK, AssertObject: rbac.ResourceUser}, + + // These endpoints need payloads to get to the auth part. Payloads will be required + "PUT:/api/v2/users/{user}/roles": {StatusCode: http.StatusBadRequest, NoAuthorize: true}, + "PUT:/api/v2/organizations/{organization}/members/{user}/roles": {NoAuthorize: true}, + "POST:/api/v2/workspaces/{workspace}/builds": {StatusCode: http.StatusBadRequest, NoAuthorize: true}, + "POST:/api/v2/organizations/{organization}/templateversions": {StatusCode: http.StatusBadRequest, NoAuthorize: true}, + } + + // Routes like proxy routes support all HTTP methods. A helper func to expand + // 1 url to all http methods. + assertAllHTTPMethods := func(url string, check RouteCheck) { + methods := []string{http.MethodGet, http.MethodHead, http.MethodPost, + http.MethodPut, http.MethodPatch, http.MethodDelete, + http.MethodConnect, http.MethodOptions, http.MethodTrace} + + for _, method := range methods { + route := method + ":" + url + assertRoute[route] = check + } + } + + assertAllHTTPMethods("/%40{user}/{workspace_and_agent}/apps/{workspaceapp}/*", RouteCheck{ + AssertAction: rbac.ActionCreate, + AssertObject: applicationConnectObj, + }) + assertAllHTTPMethods("/@{user}/{workspace_and_agent}/apps/{workspaceapp}/*", RouteCheck{ + AssertAction: rbac.ActionCreate, + AssertObject: applicationConnectObj, + }) + + return skipRoutes, assertRoute +} + +type RouteCheck struct { + NoAuthorize bool + AssertAction rbac.Action + AssertObject rbac.Object + StatusCode int +} + +type AuthTester struct { + t *testing.T + api *coderd.API + authorizer *RecordingAuthorizer + + Client *codersdk.Client + Workspace codersdk.Workspace + Organization codersdk.Organization + Admin codersdk.CreateFirstUserResponse + Template codersdk.Template + Version codersdk.TemplateVersion + WorkspaceResource codersdk.WorkspaceResource + File codersdk.UploadResponse + TemplateVersionDryRun codersdk.ProvisionerJob + TemplateParam codersdk.Parameter + URLParams map[string]string +} + +func NewAuthTester(ctx context.Context, t *testing.T, client *codersdk.Client, api *coderd.API, admin codersdk.CreateFirstUserResponse) *AuthTester { + authorizer, ok := api.Authorizer.(*RecordingAuthorizer) + if !ok { + t.Fail() + } + // The provisioner will call to coderd and register itself. This is async, + // so we wait for it to occur. + require.Eventually(t, func() bool { + provisionerds, err := client.ProvisionerDaemons(ctx) + return assert.NoError(t, err) && len(provisionerds) > 0 + }, testutil.WaitLong, testutil.IntervalSlow) + + provisionerds, err := client.ProvisionerDaemons(ctx) + require.NoError(t, err, "fetch provisioners") + require.Len(t, provisionerds, 1) + + organization, err := client.Organization(ctx, admin.OrganizationID) + require.NoError(t, err, "fetch org") + + // Setup some data in the database. + version := CreateTemplateVersion(t, client, admin.OrganizationID, &echo.Responses{ + Parse: echo.ParseComplete, + Provision: []*proto.Provision_Response{{ + Type: &proto.Provision_Response_Complete{ + Complete: &proto.Provision_Complete{ + // Return a workspace resource + Resources: []*proto.Resource{{ + Name: "some", + Type: "example", + Agents: []*proto.Agent{{ + Name: "agent", + Id: "something", + Auth: &proto.Agent_Token{}, + Apps: []*proto.App{{ + Name: "testapp", + Url: "http://localhost:3000", + }}, + }}, + }}, + }, + }, + }}, + }) + AwaitTemplateVersionJob(t, client, version.ID) + template := CreateTemplate(t, client, admin.OrganizationID, version.ID) + workspace := CreateWorkspace(t, client, admin.OrganizationID, template.ID) + AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) + file, err := client.Upload(ctx, codersdk.ContentTypeTar, make([]byte, 1024)) + require.NoError(t, err, "upload file") + workspaceResources, err := client.WorkspaceResourcesByBuild(ctx, workspace.LatestBuild.ID) + require.NoError(t, err, "workspace resources") + templateVersionDryRun, err := client.CreateTemplateVersionDryRun(ctx, version.ID, codersdk.CreateTemplateVersionDryRunRequest{ + ParameterValues: []codersdk.CreateParameterRequest{}, + }) + require.NoError(t, err, "template version dry-run") + + templateParam, err := client.CreateParameter(ctx, codersdk.ParameterTemplate, template.ID, codersdk.CreateParameterRequest{ + Name: "test-param", + SourceValue: "hello world", + SourceScheme: codersdk.ParameterSourceSchemeData, + DestinationScheme: codersdk.ParameterDestinationSchemeProvisionerVariable, + }) + require.NoError(t, err, "create template param") + urlParameters := map[string]string{ + "{organization}": admin.OrganizationID.String(), + "{user}": admin.UserID.String(), + "{organizationname}": organization.Name, + "{workspace}": workspace.ID.String(), + "{workspacebuild}": workspace.LatestBuild.ID.String(), + "{workspacename}": workspace.Name, + "{workspaceagent}": workspaceResources[0].Agents[0].ID.String(), + "{buildnumber}": strconv.FormatInt(int64(workspace.LatestBuild.BuildNumber), 10), + "{template}": template.ID.String(), + "{hash}": file.Hash, + "{workspaceresource}": workspaceResources[0].ID.String(), + "{workspaceapp}": workspaceResources[0].Agents[0].Apps[0].Name, + "{templateversion}": version.ID.String(), + "{jobID}": templateVersionDryRun.ID.String(), + "{templatename}": template.Name, + "{workspace_and_agent}": workspace.Name + "." + workspaceResources[0].Agents[0].Name, + // Only checking template scoped params here + "parameters/{scope}/{id}": fmt.Sprintf("parameters/%s/%s", + string(templateParam.Scope), templateParam.ScopeID.String()), + } + + return &AuthTester{ + t: t, + api: api, + authorizer: authorizer, + Client: client, + Workspace: workspace, + Organization: organization, + Admin: admin, + Template: template, + Version: version, + WorkspaceResource: workspaceResources[0], + File: file, + TemplateVersionDryRun: templateVersionDryRun, + TemplateParam: templateParam, + URLParams: urlParameters, + } +} + +func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck, skipRoutes map[string]string) { + // Always fail auth from this point forward + a.authorizer.AlwaysReturn = rbac.ForbiddenWithInternal(xerrors.New("fake implementation"), nil, nil) + + routeMissing := make(map[string]bool) + for k, v := range assertRoute { + noTrailSlash := strings.TrimRight(k, "/") + if _, ok := assertRoute[noTrailSlash]; ok && noTrailSlash != k { + a.t.Errorf("route %q & %q is declared twice", noTrailSlash, k) + a.t.FailNow() + } + assertRoute[noTrailSlash] = v + routeMissing[noTrailSlash] = true + } + + for k, v := range skipRoutes { + noTrailSlash := strings.TrimRight(k, "/") + if _, ok := skipRoutes[noTrailSlash]; ok && noTrailSlash != k { + a.t.Errorf("route %q & %q is declared twice", noTrailSlash, k) + a.t.FailNow() + } + skipRoutes[noTrailSlash] = v + } + + err := chi.Walk( + a.api.RootHandler, + func( + method string, + route string, + handler http.Handler, + middlewares ...func(http.Handler) http.Handler, + ) error { + // work around chi's bugged handling of /*/*/ which can occur if we + // r.Mount("/", someHandler()) in our tree + for strings.Contains(route, "/*/") { + route = strings.Replace(route, "/*/", "/", -1) + } + name := method + ":" + route + if _, ok := skipRoutes[strings.TrimRight(name, "/")]; ok { + return nil + } + a.t.Run(name, func(t *testing.T) { + a.authorizer.reset() + routeKey := strings.TrimRight(name, "/") + + routeAssertions, ok := assertRoute[routeKey] + if !ok { + // By default, all omitted routes check for just "authorize" called + routeAssertions = RouteCheck{} + } + delete(routeMissing, routeKey) + + // Replace all url params with known values + for k, v := range a.URLParams { + route = strings.ReplaceAll(route, k, v) + } + + resp, err := a.Client.Request(ctx, method, route, nil) + require.NoError(t, err, "do req") + body, _ := io.ReadAll(resp.Body) + t.Logf("Response Body: %q", string(body)) + _ = resp.Body.Close() + + if !routeAssertions.NoAuthorize { + assert.NotNil(t, a.authorizer.Called, "authorizer expected") + if routeAssertions.StatusCode != 0 { + assert.Equal(t, routeAssertions.StatusCode, resp.StatusCode, "expect unauthorized") + } else { + // It's either a 404 or 403. + if resp.StatusCode != http.StatusNotFound { + assert.Equal(t, http.StatusForbidden, resp.StatusCode, "expect unauthorized") + } + } + if a.authorizer.Called != nil { + if routeAssertions.AssertAction != "" { + assert.Equal(t, routeAssertions.AssertAction, a.authorizer.Called.Action, "resource action") + } + if routeAssertions.AssertObject.Type != "" { + assert.Equal(t, routeAssertions.AssertObject.Type, a.authorizer.Called.Object.Type, "resource type") + } + if routeAssertions.AssertObject.Owner != "" { + assert.Equal(t, routeAssertions.AssertObject.Owner, a.authorizer.Called.Object.Owner, "resource owner") + } + if routeAssertions.AssertObject.OrgID != "" { + assert.Equal(t, routeAssertions.AssertObject.OrgID, a.authorizer.Called.Object.OrgID, "resource org") + } + } + } else { + assert.Nil(t, a.authorizer.Called, "authorize not expected") + } + }) + return nil + }) + require.NoError(a.t, err) + require.Len(a.t, routeMissing, 0, "didn't walk some asserted routes: %v", routeMissing) +} + +type authCall struct { + SubjectID string + Roles []string + Scope rbac.Scope + Action rbac.Action + Object rbac.Object +} + +type RecordingAuthorizer struct { + Called *authCall + AlwaysReturn error +} + +var _ rbac.Authorizer = (*RecordingAuthorizer)(nil) + +func (r *RecordingAuthorizer) ByRoleName(_ context.Context, subjectID string, roleNames []string, scope rbac.Scope, action rbac.Action, object rbac.Object) error { + r.Called = &authCall{ + SubjectID: subjectID, + Roles: roleNames, + Scope: scope, + Action: action, + Object: object, + } + return r.AlwaysReturn +} + +func (r *RecordingAuthorizer) PrepareByRoleName(_ context.Context, subjectID string, roles []string, scope rbac.Scope, action rbac.Action, _ string) (rbac.PreparedAuthorized, error) { + return &fakePreparedAuthorizer{ + Original: r, + SubjectID: subjectID, + Roles: roles, + Scope: scope, + Action: action, + }, nil +} + +func (r *RecordingAuthorizer) reset() { + r.Called = nil +} + +type fakePreparedAuthorizer struct { + Original *RecordingAuthorizer + SubjectID string + Roles []string + Scope rbac.Scope + Action rbac.Action +} + +func (f *fakePreparedAuthorizer) Authorize(ctx context.Context, object rbac.Object) error { + return f.Original.ByRoleName(ctx, f.SubjectID, f.Roles, f.Scope, f.Action, object) +} diff --git a/coderd/coderdtest/authorize_test.go b/coderd/coderdtest/authorize_test.go new file mode 100644 index 0000000000000..c8ef64065a290 --- /dev/null +++ b/coderd/coderdtest/authorize_test.go @@ -0,0 +1,20 @@ +package coderdtest_test + +import ( + "context" + "testing" + + "github.com/coder/coder/coderd/coderdtest" +) + +func TestAuthorizeAllEndpoints(t *testing.T) { + t.Parallel() + client, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{ + Authorizer: &coderdtest.RecordingAuthorizer{}, + IncludeProvisionerDaemon: true, + }) + admin := coderdtest.CreateFirstUser(t, client) + a := coderdtest.NewAuthTester(context.Background(), t, client, api, admin) + skipRoute, assertRoute := coderdtest.AGPLRoutes(a) + a.Test(context.Background(), assertRoute, skipRoute) +} diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index 8023a3bce8d4b..55ed03f486960 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -112,7 +112,7 @@ func NewWithProvisionerCloser(t *testing.T, options *Options) (*codersdk.Client, // and is a temporary measure while the API to register provisioners is ironed // out. func newWithCloser(t *testing.T, options *Options) (*codersdk.Client, io.Closer) { - client, closer, _ := newWithAPI(t, options) + client, closer, _ := NewWithAPI(t, options) return client, closer } @@ -247,10 +247,10 @@ func NewOptions(t *testing.T, options *Options) (*httptest.Server, *coderd.Optio } } -// newWithAPI constructs an in-memory API instance and returns a client to talk to it. +// NewWithAPI constructs an in-memory API instance and returns a client to talk to it. // Most tests never need a reference to the API, but AuthorizationTest in this module uses it. // Do not expose the API or wrath shall descend upon thee. -func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *coderd.API) { +func NewWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *coderd.API) { if options == nil { options = &Options{} } diff --git a/enterprise/coderd/coderdenttest/coderdenttest_test.go b/enterprise/coderd/coderdenttest/coderdenttest_test.go index 929a8fc383280..ccea80cf9b968 100644 --- a/enterprise/coderd/coderdenttest/coderdenttest_test.go +++ b/enterprise/coderd/coderdenttest/coderdenttest_test.go @@ -3,25 +3,12 @@ package coderdenttest_test import ( "context" "fmt" - "io" "net/http" - "strconv" - "strings" "testing" - "github.com/go-chi/chi/v5" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "golang.org/x/xerrors" - "github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/coderd/rbac" - "github.com/coder/coder/codersdk" - "github.com/coder/coder/enterprise/coderd" "github.com/coder/coder/enterprise/coderd/coderdenttest" - "github.com/coder/coder/provisioner/echo" - "github.com/coder/coder/provisionersdk/proto" - "github.com/coder/coder/testutil" ) func TestNew(t *testing.T) { @@ -31,564 +18,34 @@ func TestNew(t *testing.T) { func TestAuthorizeAllEndpoints(t *testing.T) { t.Parallel() - a := newAuthTester(context.Background(), t) - - // Some quick reused objects - workspaceRBACObj := rbac.ResourceWorkspace.InOrg(a.Organization.ID).WithOwner(a.Workspace.OwnerID.String()) - workspaceExecObj := rbac.ResourceWorkspaceExecution.InOrg(a.Organization.ID).WithOwner(a.Workspace.OwnerID.String()) - applicationConnectObj := rbac.ResourceWorkspaceApplicationConnect.InOrg(a.Organization.ID).WithOwner(a.Workspace.OwnerID.String()) - - // skipRoutes allows skipping routes from being checked. - skipRoutes := map[string]string{ - "POST:/api/v2/users/logout": "Logging out deletes the API Key for other routes", - "GET:/derp": "This requires a WebSocket upgrade!", - "GET:/derp/latency-check": "This always returns a 200!", - } - - assertRoute := map[string]routeCheck{ - // These endpoints do not require auth - "GET:/api/v2": {NoAuthorize: true}, - "GET:/api/v2/buildinfo": {NoAuthorize: true}, - "GET:/api/v2/users/first": {NoAuthorize: true}, - "POST:/api/v2/users/first": {NoAuthorize: true}, - "POST:/api/v2/users/login": {NoAuthorize: true}, - "GET:/api/v2/users/authmethods": {NoAuthorize: true}, - "POST:/api/v2/csp/reports": {NoAuthorize: true}, - "GET:/api/v2/entitlements": {NoAuthorize: true}, - - // Has it's own auth - "GET:/api/v2/users/oauth2/github/callback": {NoAuthorize: true}, - "GET:/api/v2/users/oidc/callback": {NoAuthorize: true}, - - // All workspaceagents endpoints do not use rbac - "POST:/api/v2/workspaceagents/aws-instance-identity": {NoAuthorize: true}, - "POST:/api/v2/workspaceagents/azure-instance-identity": {NoAuthorize: true}, - "POST:/api/v2/workspaceagents/google-instance-identity": {NoAuthorize: true}, - "GET:/api/v2/workspaceagents/me/gitsshkey": {NoAuthorize: true}, - "GET:/api/v2/workspaceagents/me/iceservers": {NoAuthorize: true}, - "GET:/api/v2/workspaceagents/me/listen": {NoAuthorize: true}, - "GET:/api/v2/workspaceagents/me/metadata": {NoAuthorize: true}, - "GET:/api/v2/workspaceagents/me/turn": {NoAuthorize: true}, - "GET:/api/v2/workspaceagents/me/coordinate": {NoAuthorize: true}, - "POST:/api/v2/workspaceagents/me/version": {NoAuthorize: true}, - "GET:/api/v2/workspaceagents/me/report-stats": {NoAuthorize: true}, - "GET:/api/v2/workspaceagents/{workspaceagent}/iceservers": {NoAuthorize: true}, - - // These endpoints have more assertions. This is good, add more endpoints to assert if you can! - "GET:/api/v2/organizations/{organization}": {AssertObject: rbac.ResourceOrganization.InOrg(a.Admin.OrganizationID)}, - "GET:/api/v2/users/{user}/organizations": {StatusCode: http.StatusOK, AssertObject: rbac.ResourceOrganization}, - "GET:/api/v2/users/{user}/workspace/{workspacename}": { - AssertObject: rbac.ResourceWorkspace, - AssertAction: rbac.ActionRead, - }, - "GET:/api/v2/users/{user}/workspace/{workspacename}/builds/{buildnumber}": { - AssertObject: rbac.ResourceWorkspace, - AssertAction: rbac.ActionRead, - }, - "GET:/api/v2/workspacebuilds/{workspacebuild}": { - AssertAction: rbac.ActionRead, - AssertObject: workspaceRBACObj, - }, - "GET:/api/v2/workspacebuilds/{workspacebuild}/logs": { - AssertAction: rbac.ActionRead, - AssertObject: workspaceRBACObj, - }, - "GET:/api/v2/workspaces/{workspace}/builds": { - AssertAction: rbac.ActionRead, - AssertObject: workspaceRBACObj, - }, - "GET:/api/v2/workspaces/{workspace}": { - AssertAction: rbac.ActionRead, - AssertObject: workspaceRBACObj, - }, - "PUT:/api/v2/workspaces/{workspace}/autostart": { - AssertAction: rbac.ActionUpdate, - AssertObject: workspaceRBACObj, - }, - "PUT:/api/v2/workspaces/{workspace}/ttl": { - AssertAction: rbac.ActionUpdate, - AssertObject: workspaceRBACObj, - }, - "GET:/api/v2/workspaceresources/{workspaceresource}": { - AssertAction: rbac.ActionRead, - AssertObject: workspaceRBACObj, - }, - "PATCH:/api/v2/workspacebuilds/{workspacebuild}/cancel": { - AssertAction: rbac.ActionUpdate, - AssertObject: workspaceRBACObj, - }, - "GET:/api/v2/workspacebuilds/{workspacebuild}/resources": { - AssertAction: rbac.ActionRead, - AssertObject: workspaceRBACObj, - }, - "GET:/api/v2/workspacebuilds/{workspacebuild}/state": { - AssertAction: rbac.ActionRead, - AssertObject: workspaceRBACObj, - }, - "GET:/api/v2/workspaceagents/{workspaceagent}": { - AssertAction: rbac.ActionRead, - AssertObject: workspaceRBACObj, - }, - "GET:/api/v2/workspaceagents/{workspaceagent}/dial": { - AssertAction: rbac.ActionCreate, - AssertObject: workspaceExecObj, - }, - "GET:/api/v2/workspaceagents/{workspaceagent}/turn": { - AssertAction: rbac.ActionCreate, - AssertObject: workspaceExecObj, - }, - "GET:/api/v2/workspaceagents/{workspaceagent}/pty": { - AssertAction: rbac.ActionCreate, - AssertObject: workspaceExecObj, - }, - "GET:/api/v2/workspaceagents/{workspaceagent}/coordinate": { - AssertAction: rbac.ActionCreate, - AssertObject: workspaceExecObj, - }, - "GET:/api/v2/workspaces/": { - StatusCode: http.StatusOK, - AssertAction: rbac.ActionRead, - AssertObject: workspaceRBACObj, - }, - "GET:/api/v2/organizations/{organization}/templates": { - StatusCode: http.StatusOK, - AssertAction: rbac.ActionRead, - AssertObject: rbac.ResourceTemplate.InOrg(a.Template.OrganizationID), - }, - "POST:/api/v2/organizations/{organization}/templates": { - AssertAction: rbac.ActionCreate, - AssertObject: rbac.ResourceTemplate.InOrg(a.Organization.ID), - }, - "DELETE:/api/v2/templates/{template}": { - AssertAction: rbac.ActionDelete, - AssertObject: rbac.ResourceTemplate.InOrg(a.Template.OrganizationID), - }, - "GET:/api/v2/templates/{template}": { - AssertAction: rbac.ActionRead, - AssertObject: rbac.ResourceTemplate.InOrg(a.Template.OrganizationID), - }, - "POST:/api/v2/files": {AssertAction: rbac.ActionCreate, AssertObject: rbac.ResourceFile}, - "GET:/api/v2/files/{hash}": { - AssertAction: rbac.ActionRead, - AssertObject: rbac.ResourceFile.WithOwner(a.Admin.UserID.String()), - }, - "GET:/api/v2/templates/{template}/versions": { - AssertAction: rbac.ActionRead, - AssertObject: rbac.ResourceTemplate.InOrg(a.Template.OrganizationID), - }, - "PATCH:/api/v2/templates/{template}/versions": { - AssertAction: rbac.ActionUpdate, - AssertObject: rbac.ResourceTemplate.InOrg(a.Template.OrganizationID), - }, - "GET:/api/v2/templates/{template}/versions/{templateversionname}": { - AssertAction: rbac.ActionRead, - AssertObject: rbac.ResourceTemplate.InOrg(a.Template.OrganizationID), - }, - "GET:/api/v2/templateversions/{templateversion}": { - AssertAction: rbac.ActionRead, - AssertObject: rbac.ResourceTemplate.InOrg(a.Template.OrganizationID), - }, - "PATCH:/api/v2/templateversions/{templateversion}/cancel": { - AssertAction: rbac.ActionUpdate, - AssertObject: rbac.ResourceTemplate.InOrg(a.Template.OrganizationID), - }, - "GET:/api/v2/templateversions/{templateversion}/logs": { - AssertAction: rbac.ActionRead, - AssertObject: rbac.ResourceTemplate.InOrg(a.Template.OrganizationID), - }, - "GET:/api/v2/templateversions/{templateversion}/parameters": { - AssertAction: rbac.ActionRead, - AssertObject: rbac.ResourceTemplate.InOrg(a.Template.OrganizationID), - }, - "GET:/api/v2/templateversions/{templateversion}/resources": { - AssertAction: rbac.ActionRead, - AssertObject: rbac.ResourceTemplate.InOrg(a.Template.OrganizationID), - }, - "GET:/api/v2/templateversions/{templateversion}/schema": { - AssertAction: rbac.ActionRead, - AssertObject: rbac.ResourceTemplate.InOrg(a.Template.OrganizationID), - }, - "POST:/api/v2/templateversions/{templateversion}/dry-run": { - // The first check is to read the template - AssertAction: rbac.ActionRead, - AssertObject: rbac.ResourceTemplate.InOrg(a.Version.OrganizationID), - }, - "GET:/api/v2/templateversions/{templateversion}/dry-run/{jobID}": { - AssertAction: rbac.ActionRead, - AssertObject: rbac.ResourceTemplate.InOrg(a.Version.OrganizationID), - }, - "GET:/api/v2/templateversions/{templateversion}/dry-run/{jobID}/resources": { - AssertAction: rbac.ActionRead, - AssertObject: rbac.ResourceTemplate.InOrg(a.Version.OrganizationID), - }, - "GET:/api/v2/templateversions/{templateversion}/dry-run/{jobID}/logs": { - AssertAction: rbac.ActionRead, - AssertObject: rbac.ResourceTemplate.InOrg(a.Version.OrganizationID), - }, - "PATCH:/api/v2/templateversions/{templateversion}/dry-run/{jobID}/cancel": { - AssertAction: rbac.ActionRead, - AssertObject: rbac.ResourceTemplate.InOrg(a.Version.OrganizationID), - }, - "GET:/api/v2/provisionerdaemons": { - StatusCode: http.StatusOK, - AssertObject: rbac.ResourceProvisionerDaemon, - }, - - "POST:/api/v2/parameters/{scope}/{id}": { - AssertAction: rbac.ActionUpdate, - AssertObject: rbac.ResourceTemplate, - }, - "GET:/api/v2/parameters/{scope}/{id}": { - AssertAction: rbac.ActionRead, - AssertObject: rbac.ResourceTemplate, - }, - "DELETE:/api/v2/parameters/{scope}/{id}/{name}": { - AssertAction: rbac.ActionUpdate, - AssertObject: rbac.ResourceTemplate, - }, - "GET:/api/v2/organizations/{organization}/templates/{templatename}": { - AssertAction: rbac.ActionRead, - AssertObject: rbac.ResourceTemplate.InOrg(a.Template.OrganizationID), - }, - "POST:/api/v2/organizations/{organization}/workspaces": { - AssertAction: rbac.ActionCreate, - // No ID when creating - AssertObject: workspaceRBACObj, - }, - "GET:/api/v2/workspaces/{workspace}/watch": { - AssertAction: rbac.ActionRead, - AssertObject: workspaceRBACObj, - }, - "GET:/api/v2/users": {StatusCode: http.StatusOK, AssertObject: rbac.ResourceUser}, - - // These endpoints need payloads to get to the auth part. Payloads will be required - "PUT:/api/v2/users/{user}/roles": {StatusCode: http.StatusBadRequest, NoAuthorize: true}, - "PUT:/api/v2/organizations/{organization}/members/{user}/roles": {NoAuthorize: true}, - "POST:/api/v2/workspaces/{workspace}/builds": {StatusCode: http.StatusBadRequest, NoAuthorize: true}, - "POST:/api/v2/organizations/{organization}/templateversions": {StatusCode: http.StatusBadRequest, NoAuthorize: true}, - - // Enterprise only endpoints - "POST:/api/v2/licenses": { - AssertAction: rbac.ActionCreate, - AssertObject: rbac.ResourceLicense, - }, - "GET:/api/v2/licenses": { - StatusCode: http.StatusOK, - AssertAction: rbac.ActionRead, - AssertObject: rbac.ResourceLicense, - }, - "DELETE:/api/v2/licenses/{id}": { - AssertAction: rbac.ActionDelete, - AssertObject: rbac.ResourceLicense, - }, - } - - // Routes like proxy routes support all HTTP methods. A helper func to expand - // 1 url to all http methods. - assertAllHTTPMethods := func(url string, check routeCheck) { - methods := []string{http.MethodGet, http.MethodHead, http.MethodPost, - http.MethodPut, http.MethodPatch, http.MethodDelete, - http.MethodConnect, http.MethodOptions, http.MethodTrace} - - for _, method := range methods { - route := method + ":" + url - assertRoute[route] = check - } - } - - assertAllHTTPMethods("/%40{user}/{workspace_and_agent}/apps/{workspaceapp}/*", routeCheck{ - AssertAction: rbac.ActionCreate, - AssertObject: applicationConnectObj, - }) - assertAllHTTPMethods("/@{user}/{workspace_and_agent}/apps/{workspaceapp}/*", routeCheck{ - AssertAction: rbac.ActionCreate, - AssertObject: applicationConnectObj, - }) - - a.Test(context.Background(), assertRoute, skipRoutes) -} - -type routeCheck struct { - NoAuthorize bool - AssertAction rbac.Action - AssertObject rbac.Object - StatusCode int -} - -type authTester struct { - t *testing.T - api *coderd.API - authorizer *recordingAuthorizer - - Client *codersdk.Client - Workspace codersdk.Workspace - Organization codersdk.Organization - Admin codersdk.CreateFirstUserResponse - Template codersdk.Template - Version codersdk.TemplateVersion - WorkspaceResource codersdk.WorkspaceResource - File codersdk.UploadResponse - TemplateVersionDryRun codersdk.ProvisionerJob - TemplateParam codersdk.Parameter - URLParams map[string]string -} - -func newAuthTester(ctx context.Context, t *testing.T) *authTester { - authorizer := &recordingAuthorizer{} - options := &coderdenttest.Options{ + client, _, api := coderdenttest.NewWithAPI(t, &coderdenttest.Options{ Options: &coderdtest.Options{ - Authorizer: authorizer, + Authorizer: &coderdtest.RecordingAuthorizer{}, IncludeProvisionerDaemon: true, }, - } - - client, _, api := coderdenttest.NewWithAPI(t, options) - admin := coderdtest.CreateFirstUser(t, client) - // The provisioner will call to coderd and register itself. This is async, - // so we wait for it to occur. - require.Eventually(t, func() bool { - provisionerds, err := client.ProvisionerDaemons(ctx) - return assert.NoError(t, err) && len(provisionerds) > 0 - }, testutil.WaitLong, testutil.IntervalSlow) - - provisionerds, err := client.ProvisionerDaemons(ctx) - require.NoError(t, err, "fetch provisioners") - require.Len(t, provisionerds, 1) - - organization, err := client.Organization(ctx, admin.OrganizationID) - require.NoError(t, err, "fetch org") - - // Setup some data in the database. - version := coderdtest.CreateTemplateVersion(t, client, admin.OrganizationID, &echo.Responses{ - Parse: echo.ParseComplete, - Provision: []*proto.Provision_Response{{ - Type: &proto.Provision_Response_Complete{ - Complete: &proto.Provision_Complete{ - // Return a workspace resource - Resources: []*proto.Resource{{ - Name: "some", - Type: "example", - Agents: []*proto.Agent{{ - Name: "agent", - Id: "something", - Auth: &proto.Agent_Token{}, - Apps: []*proto.App{{ - Name: "testapp", - Url: "http://localhost:3000", - }}, - }}, - }}, - }, - }, - }}, - }) - coderdtest.AwaitTemplateVersionJob(t, client, version.ID) - template := coderdtest.CreateTemplate(t, client, admin.OrganizationID, version.ID) - workspace := coderdtest.CreateWorkspace(t, client, admin.OrganizationID, template.ID) - coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) - file, err := client.Upload(ctx, codersdk.ContentTypeTar, make([]byte, 1024)) - require.NoError(t, err, "upload file") - workspaceResources, err := client.WorkspaceResourcesByBuild(ctx, workspace.LatestBuild.ID) - require.NoError(t, err, "workspace resources") - templateVersionDryRun, err := client.CreateTemplateVersionDryRun(ctx, version.ID, codersdk.CreateTemplateVersionDryRunRequest{ - ParameterValues: []codersdk.CreateParameterRequest{}, - }) - require.NoError(t, err, "template version dry-run") - - templateParam, err := client.CreateParameter(ctx, codersdk.ParameterTemplate, template.ID, codersdk.CreateParameterRequest{ - Name: "test-param", - SourceValue: "hello world", - SourceScheme: codersdk.ParameterSourceSchemeData, - DestinationScheme: codersdk.ParameterDestinationSchemeProvisionerVariable, }) - require.NoError(t, err, "create template param") + admin := coderdtest.CreateFirstUser(t, client) license := coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{}) - urlParameters := map[string]string{ - "{organization}": admin.OrganizationID.String(), - "{user}": admin.UserID.String(), - "{organizationname}": organization.Name, - "{workspace}": workspace.ID.String(), - "{workspacebuild}": workspace.LatestBuild.ID.String(), - "{workspacename}": workspace.Name, - "{workspaceagent}": workspaceResources[0].Agents[0].ID.String(), - "{buildnumber}": strconv.FormatInt(int64(workspace.LatestBuild.BuildNumber), 10), - "{template}": template.ID.String(), - "{hash}": file.Hash, - "{workspaceresource}": workspaceResources[0].ID.String(), - "{workspaceapp}": workspaceResources[0].Agents[0].Apps[0].Name, - "{templateversion}": version.ID.String(), - "{jobID}": templateVersionDryRun.ID.String(), - "{templatename}": template.Name, - "{workspace_and_agent}": workspace.Name + "." + workspaceResources[0].Agents[0].Name, - // Only checking template scoped params here - "parameters/{scope}/{id}": fmt.Sprintf("parameters/%s/%s", - string(templateParam.Scope), templateParam.ScopeID.String()), - "licenses/{id}": fmt.Sprintf("licenses/%d", license.ID), - } + a := coderdtest.NewAuthTester(context.Background(), t, client, api.AGPL, admin) + a.URLParams["licenses/{id}"] = fmt.Sprintf("licenses/%d", license.ID) - return &authTester{ - t: t, - api: api, - authorizer: authorizer, - Client: client, - Workspace: workspace, - Organization: organization, - Admin: admin, - Template: template, - Version: version, - WorkspaceResource: workspaceResources[0], - File: file, - TemplateVersionDryRun: templateVersionDryRun, - TemplateParam: templateParam, - URLParams: urlParameters, + skipRoutes, assertRoute := coderdtest.AGPLRoutes(a) + assertRoute["GET:/api/v2/entitlements"] = coderdtest.RouteCheck{ + NoAuthorize: true, } -} - -func (a *authTester) Test(ctx context.Context, assertRoute map[string]routeCheck, skipRoutes map[string]string) { - // Always fail auth from this point forward - a.authorizer.AlwaysReturn = rbac.ForbiddenWithInternal(xerrors.New("fake implementation"), nil, nil) - - routeMissing := make(map[string]bool) - for k, v := range assertRoute { - noTrailSlash := strings.TrimRight(k, "/") - if _, ok := assertRoute[noTrailSlash]; ok && noTrailSlash != k { - a.t.Errorf("route %q & %q is declared twice", noTrailSlash, k) - a.t.FailNow() - } - assertRoute[noTrailSlash] = v - routeMissing[noTrailSlash] = true + assertRoute["POST:/api/v2/licenses"] = coderdtest.RouteCheck{ + AssertAction: rbac.ActionCreate, + AssertObject: rbac.ResourceLicense, } - - for k, v := range skipRoutes { - noTrailSlash := strings.TrimRight(k, "/") - if _, ok := skipRoutes[noTrailSlash]; ok && noTrailSlash != k { - a.t.Errorf("route %q & %q is declared twice", noTrailSlash, k) - a.t.FailNow() - } - skipRoutes[noTrailSlash] = v + assertRoute["GET:/api/v2/licenses"] = coderdtest.RouteCheck{ + StatusCode: http.StatusOK, + AssertAction: rbac.ActionRead, + AssertObject: rbac.ResourceLicense, } - - err := chi.Walk( - a.api.AGPL.RootHandler, - func( - method string, - route string, - handler http.Handler, - middlewares ...func(http.Handler) http.Handler, - ) error { - // work around chi's bugged handling of /*/*/ which can occur if we - // r.Mount("/", someHandler()) in our tree - for strings.Contains(route, "/*/") { - route = strings.Replace(route, "/*/", "/", -1) - } - name := method + ":" + route - if _, ok := skipRoutes[strings.TrimRight(name, "/")]; ok { - return nil - } - a.t.Run(name, func(t *testing.T) { - a.authorizer.reset() - routeKey := strings.TrimRight(name, "/") - - routeAssertions, ok := assertRoute[routeKey] - if !ok { - // By default, all omitted routes check for just "authorize" called - routeAssertions = routeCheck{} - } - delete(routeMissing, routeKey) - - // Replace all url params with known values - for k, v := range a.URLParams { - route = strings.ReplaceAll(route, k, v) - } - - resp, err := a.Client.Request(ctx, method, route, nil) - require.NoError(t, err, "do req") - body, _ := io.ReadAll(resp.Body) - t.Logf("Response Body: %q", string(body)) - _ = resp.Body.Close() - - if !routeAssertions.NoAuthorize { - assert.NotNil(t, a.authorizer.Called, "authorizer expected") - if routeAssertions.StatusCode != 0 { - assert.Equal(t, routeAssertions.StatusCode, resp.StatusCode, "expect unauthorized") - } else { - // It's either a 404 or 403. - if resp.StatusCode != http.StatusNotFound { - assert.Equal(t, http.StatusForbidden, resp.StatusCode, "expect unauthorized") - } - } - if a.authorizer.Called != nil { - if routeAssertions.AssertAction != "" { - assert.Equal(t, routeAssertions.AssertAction, a.authorizer.Called.Action, "resource action") - } - if routeAssertions.AssertObject.Type != "" { - assert.Equal(t, routeAssertions.AssertObject.Type, a.authorizer.Called.Object.Type, "resource type") - } - if routeAssertions.AssertObject.Owner != "" { - assert.Equal(t, routeAssertions.AssertObject.Owner, a.authorizer.Called.Object.Owner, "resource owner") - } - if routeAssertions.AssertObject.OrgID != "" { - assert.Equal(t, routeAssertions.AssertObject.OrgID, a.authorizer.Called.Object.OrgID, "resource org") - } - } - } else { - assert.Nil(t, a.authorizer.Called, "authorize not expected") - } - }) - return nil - }) - require.NoError(a.t, err) - require.Len(a.t, routeMissing, 0, "didn't walk some asserted routes: %v", routeMissing) -} - -type authCall struct { - SubjectID string - Roles []string - Scope rbac.Scope - Action rbac.Action - Object rbac.Object -} - -type recordingAuthorizer struct { - Called *authCall - AlwaysReturn error -} - -var _ rbac.Authorizer = (*recordingAuthorizer)(nil) - -func (r *recordingAuthorizer) ByRoleName(_ context.Context, subjectID string, roleNames []string, scope rbac.Scope, action rbac.Action, object rbac.Object) error { - r.Called = &authCall{ - SubjectID: subjectID, - Roles: roleNames, - Scope: scope, - Action: action, - Object: object, + assertRoute["DELETE:/api/v2/licenses/{id}"] = coderdtest.RouteCheck{ + AssertAction: rbac.ActionDelete, + AssertObject: rbac.ResourceLicense, } - return r.AlwaysReturn -} -func (r *recordingAuthorizer) PrepareByRoleName(_ context.Context, subjectID string, roles []string, scope rbac.Scope, action rbac.Action, _ string) (rbac.PreparedAuthorized, error) { - return &fakePreparedAuthorizer{ - Original: r, - SubjectID: subjectID, - Roles: roles, - Scope: scope, - Action: action, - }, nil -} - -func (r *recordingAuthorizer) reset() { - r.Called = nil -} - -type fakePreparedAuthorizer struct { - Original *recordingAuthorizer - SubjectID string - Roles []string - Scope rbac.Scope - Action rbac.Action -} - -func (f *fakePreparedAuthorizer) Authorize(ctx context.Context, object rbac.Object) error { - return f.Original.ByRoleName(ctx, f.SubjectID, f.Roles, f.Scope, f.Action, object) + a.Test(context.Background(), assertRoute, skipRoutes) } From b0a3c95603858e32660108e441d6b93ff181c3d5 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Mon, 19 Sep 2022 19:26:10 +0000 Subject: [PATCH 15/19] Add pointer loading example from comment --- coderd/pointer/pointer_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/coderd/pointer/pointer_test.go b/coderd/pointer/pointer_test.go index f41ef8cd36bd5..e4c7f0f3f9773 100644 --- a/coderd/pointer/pointer_test.go +++ b/coderd/pointer/pointer_test.go @@ -18,6 +18,8 @@ func TestHandle(t *testing.T) { ctx, value := ptr.Load(ctx) require.Equal(t, "hello", value) ptr.Store("world") + ctx, value = ptr.Load(ctx) + require.Equal(t, "hello", value) _, value = ptr.Load(ctx) require.Equal(t, "hello", value) }) From 2b87ae1ca3e7d4618e185a9003e21b59823eab20 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Mon, 19 Sep 2022 20:37:32 +0000 Subject: [PATCH 16/19] Fix duplicate test, remove pointer.Handle --- coderd/coderd.go | 7 +++--- coderd/pointer/pointer.go | 38 -------------------------------- coderd/pointer/pointer_test.go | 36 ------------------------------ coderd/templates.go | 6 ++--- coderd/templateversions.go | 4 ++-- coderd/users.go | 12 +++++----- coderd/workspaces.go | 16 +++++++------- enterprise/coderd/coderd.go | 16 +++++--------- enterprise/coderd/coderd_test.go | 31 +++++++++++++++++--------- 9 files changed, 50 insertions(+), 116 deletions(-) delete mode 100644 coderd/pointer/pointer.go delete mode 100644 coderd/pointer/pointer_test.go diff --git a/coderd/coderd.go b/coderd/coderd.go index fce094d1f0d67..e9e9de8edd771 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -7,6 +7,7 @@ import ( "net/url" "path/filepath" "sync" + "sync/atomic" "time" "github.com/andybalholm/brotli" @@ -32,7 +33,6 @@ import ( "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/coderd/httpmw" "github.com/coder/coder/coderd/metricscache" - "github.com/coder/coder/coderd/pointer" "github.com/coder/coder/coderd/rbac" "github.com/coder/coder/coderd/telemetry" "github.com/coder/coder/coderd/tracing" @@ -148,8 +148,9 @@ func New(options *Options) *API { Logger: options.Logger, }, metricsCache: metricsCache, - Auditor: pointer.New(options.Auditor), + Auditor: atomic.Pointer[audit.Auditor]{}, } + api.Auditor.Store(&options.Auditor) if options.TailscaleEnable { api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgentTailnet, 0) @@ -498,7 +499,7 @@ func New(options *Options) *API { type API struct { *Options - Auditor *pointer.Handle[audit.Auditor] + Auditor atomic.Pointer[audit.Auditor] HTTPAuth *HTTPAuthorizer // APIHandler serves "/api/v2" diff --git a/coderd/pointer/pointer.go b/coderd/pointer/pointer.go deleted file mode 100644 index 733b022d2bc35..0000000000000 --- a/coderd/pointer/pointer.go +++ /dev/null @@ -1,38 +0,0 @@ -package pointer - -import ( - "context" - - "go.uber.org/atomic" -) - -// New constructs a Handle with an initialized value. -func New[T any](value T) *Handle[T] { - h := &Handle[T]{ - key: struct{}{}, - ptr: atomic.Pointer[T]{}, - } - h.Store(value) - return h -} - -// Handle loads the stored value into a context, and returns -// a context with the attached value. It's intention is to -// hold a single handle for the lifecycle of a request. -type Handle[T any] struct { - key struct{} - ptr atomic.Pointer[T] -} - -func (p *Handle[T]) Load(ctx context.Context) (context.Context, T) { - value, ok := ctx.Value(&p.key).(T) - if !ok { - ctx = context.WithValue(ctx, &p.key, *p.ptr.Load()) - return p.Load(ctx) - } - return ctx, value -} - -func (p *Handle[T]) Store(t T) { - p.ptr.Store(&t) -} diff --git a/coderd/pointer/pointer_test.go b/coderd/pointer/pointer_test.go deleted file mode 100644 index e4c7f0f3f9773..0000000000000 --- a/coderd/pointer/pointer_test.go +++ /dev/null @@ -1,36 +0,0 @@ -package pointer_test - -import ( - "context" - "testing" - - "github.com/stretchr/testify/require" - - "github.com/coder/coder/coderd/pointer" -) - -func TestHandle(t *testing.T) { - t.Parallel() - t.Run("Single", func(t *testing.T) { - t.Parallel() - ptr := pointer.New("hello") - ctx := context.Background() - ctx, value := ptr.Load(ctx) - require.Equal(t, "hello", value) - ptr.Store("world") - ctx, value = ptr.Load(ctx) - require.Equal(t, "hello", value) - _, value = ptr.Load(ctx) - require.Equal(t, "hello", value) - }) - t.Run("Multiple", func(t *testing.T) { - t.Parallel() - ptr1 := pointer.New("1") - ptr2 := pointer.New("2") - ctx := context.Background() - ctx, v1 := ptr1.Load(ctx) - require.Equal(t, "1", v1) - _, v2 := ptr2.Load(ctx) - require.Equal(t, "2", v2) - }) -} diff --git a/coderd/templates.go b/coderd/templates.go index 50dd788200930..96bf39dd268de 100644 --- a/coderd/templates.go +++ b/coderd/templates.go @@ -85,7 +85,7 @@ func (api *API) template(rw http.ResponseWriter, r *http.Request) { func (api *API) deleteTemplate(rw http.ResponseWriter, r *http.Request) { var ( template = httpmw.TemplateParam(r) - _, auditor = api.Auditor.Load(r.Context()) + auditor = *api.Auditor.Load() aReq, commitAudit = audit.InitRequest[database.Template](rw, &audit.RequestParams{ Audit: auditor, Log: api.Logger, @@ -140,7 +140,7 @@ func (api *API) postTemplateByOrganization(rw http.ResponseWriter, r *http.Reque createTemplate codersdk.CreateTemplateRequest organization = httpmw.OrganizationParam(r) apiKey = httpmw.APIKey(r) - _, auditor = api.Auditor.Load(r.Context()) + auditor = *api.Auditor.Load() templateAudit, commitTemplateAudit = audit.InitRequest[database.Template](rw, &audit.RequestParams{ Audit: auditor, Log: api.Logger, @@ -437,7 +437,7 @@ func (api *API) templateByOrganizationAndName(rw http.ResponseWriter, r *http.Re func (api *API) patchTemplateMeta(rw http.ResponseWriter, r *http.Request) { var ( template = httpmw.TemplateParam(r) - _, auditor = api.Auditor.Load(r.Context()) + auditor = *api.Auditor.Load() aReq, commitAudit = audit.InitRequest[database.Template](rw, &audit.RequestParams{ Audit: auditor, Log: api.Logger, diff --git a/coderd/templateversions.go b/coderd/templateversions.go index 37517073aa254..69a98e03b352c 100644 --- a/coderd/templateversions.go +++ b/coderd/templateversions.go @@ -559,7 +559,7 @@ func (api *API) templateVersionByName(rw http.ResponseWriter, r *http.Request) { func (api *API) patchActiveTemplateVersion(rw http.ResponseWriter, r *http.Request) { var ( template = httpmw.TemplateParam(r) - _, auditor = api.Auditor.Load(r.Context()) + auditor = *api.Auditor.Load() aReq, commitAudit = audit.InitRequest[database.Template](rw, &audit.RequestParams{ Audit: auditor, Log: api.Logger, @@ -632,7 +632,7 @@ func (api *API) postTemplateVersionsByOrganization(rw http.ResponseWriter, r *ht var ( apiKey = httpmw.APIKey(r) organization = httpmw.OrganizationParam(r) - _, auditor = api.Auditor.Load(r.Context()) + auditor = *api.Auditor.Load() aReq, commitAudit = audit.InitRequest[database.TemplateVersion](rw, &audit.RequestParams{ Audit: auditor, Log: api.Logger, diff --git a/coderd/users.go b/coderd/users.go index 604a28859ab39..631e660eb5770 100644 --- a/coderd/users.go +++ b/coderd/users.go @@ -255,7 +255,7 @@ func (api *API) users(rw http.ResponseWriter, r *http.Request) { // Creates a new user. func (api *API) postUser(rw http.ResponseWriter, r *http.Request) { - _, auditor := api.Auditor.Load(r.Context()) + auditor := *api.Auditor.Load() aReq, commitAudit := audit.InitRequest[database.User](rw, &audit.RequestParams{ Audit: auditor, Log: api.Logger, @@ -340,7 +340,7 @@ func (api *API) postUser(rw http.ResponseWriter, r *http.Request) { } func (api *API) deleteUser(rw http.ResponseWriter, r *http.Request) { - _, auditor := api.Auditor.Load(r.Context()) + auditor := *api.Auditor.Load() user := httpmw.UserParam(r) aReq, commitAudit := audit.InitRequest[database.User](rw, &audit.RequestParams{ Audit: auditor, @@ -416,7 +416,7 @@ func (api *API) userByName(rw http.ResponseWriter, r *http.Request) { func (api *API) putUserProfile(rw http.ResponseWriter, r *http.Request) { var ( user = httpmw.UserParam(r) - _, auditor = api.Auditor.Load(r.Context()) + auditor = *api.Auditor.Load() aReq, commitAudit = audit.InitRequest[database.User](rw, &audit.RequestParams{ Audit: auditor, Log: api.Logger, @@ -497,7 +497,7 @@ func (api *API) putUserStatus(status database.UserStatus) func(rw http.ResponseW var ( user = httpmw.UserParam(r) apiKey = httpmw.APIKey(r) - _, auditor = api.Auditor.Load(r.Context()) + auditor = *api.Auditor.Load() aReq, commitAudit = audit.InitRequest[database.User](rw, &audit.RequestParams{ Audit: auditor, Log: api.Logger, @@ -564,7 +564,7 @@ func (api *API) putUserPassword(rw http.ResponseWriter, r *http.Request) { var ( user = httpmw.UserParam(r) params codersdk.UpdateUserPasswordRequest - _, auditor = api.Auditor.Load(r.Context()) + auditor = *api.Auditor.Load() aReq, commitAudit = audit.InitRequest[database.User](rw, &audit.RequestParams{ Audit: auditor, Log: api.Logger, @@ -703,7 +703,7 @@ func (api *API) putUserRoles(rw http.ResponseWriter, r *http.Request) { user = httpmw.UserParam(r) actorRoles = httpmw.UserAuthorization(r) apiKey = httpmw.APIKey(r) - _, auditor = api.Auditor.Load(r.Context()) + auditor = *api.Auditor.Load() aReq, commitAudit = audit.InitRequest[database.User](rw, &audit.RequestParams{ Audit: auditor, Log: api.Logger, diff --git a/coderd/workspaces.go b/coderd/workspaces.go index f81a10f42ca92..f3375c6c75e6a 100644 --- a/coderd/workspaces.go +++ b/coderd/workspaces.go @@ -217,9 +217,9 @@ func (api *API) postWorkspacesByOrganization(rw http.ResponseWriter, r *http.Req var ( organization = httpmw.OrganizationParam(r) apiKey = httpmw.APIKey(r) - _, auditor = api.Auditor.Load(r.Context()) + auditor = api.Auditor.Load() aReq, commitAudit = audit.InitRequest[database.Workspace](rw, &audit.RequestParams{ - Audit: auditor, + Audit: *auditor, Log: api.Logger, Request: r, Action: database.AuditActionCreate, @@ -481,9 +481,9 @@ func (api *API) postWorkspacesByOrganization(rw http.ResponseWriter, r *http.Req func (api *API) patchWorkspace(rw http.ResponseWriter, r *http.Request) { var ( workspace = httpmw.WorkspaceParam(r) - _, auditor = api.Auditor.Load(r.Context()) + auditor = api.Auditor.Load() aReq, commitAudit = audit.InitRequest[database.Workspace](rw, &audit.RequestParams{ - Audit: auditor, + Audit: *auditor, Log: api.Logger, Request: r, Action: database.AuditActionWrite, @@ -558,9 +558,9 @@ func (api *API) patchWorkspace(rw http.ResponseWriter, r *http.Request) { func (api *API) putWorkspaceAutostart(rw http.ResponseWriter, r *http.Request) { var ( workspace = httpmw.WorkspaceParam(r) - _, auditor = api.Auditor.Load(r.Context()) + auditor = api.Auditor.Load() aReq, commitAudit = audit.InitRequest[database.Workspace](rw, &audit.RequestParams{ - Audit: auditor, + Audit: *auditor, Log: api.Logger, Request: r, Action: database.AuditActionWrite, @@ -619,9 +619,9 @@ func (api *API) putWorkspaceAutostart(rw http.ResponseWriter, r *http.Request) { func (api *API) putWorkspaceTTL(rw http.ResponseWriter, r *http.Request) { var ( workspace = httpmw.WorkspaceParam(r) - _, auditor = api.Auditor.Load(r.Context()) + auditor = api.Auditor.Load() aReq, commitAudit = audit.InitRequest[database.Workspace](rw, &audit.RequestParams{ - Audit: auditor, + Audit: *auditor, Log: api.Logger, Request: r, Action: database.AuditActionWrite, diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index d8bea430694c6..8f41725c73e23 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -65,12 +65,6 @@ func New(ctx context.Context, options *Options) (*API, error) { if err != nil { return nil, xerrors.Errorf("update entitlements: %w", err) } - api.closeLicenseSubscribe, err = api.Pubsub.Subscribe(PubsubEventLicenses, func(ctx context.Context, message []byte) { - _ = api.updateEntitlements(ctx) - }) - if err != nil { - return nil, xerrors.Errorf("subscribe to license updates: %w", err) - } go api.runEntitlementsLoop(ctx) return api, nil @@ -88,7 +82,6 @@ type API struct { AGPL *coderd.API *Options - closeLicenseSubscribe func() cancelEntitlementsLoop func() mutex sync.RWMutex hasLicense bool @@ -97,7 +90,6 @@ type API struct { } func (api *API) Close() error { - api.closeLicenseSubscribe() api.cancelEntitlementsLoop() return api.AGPL.Close() } @@ -160,7 +152,7 @@ func (api *API) updateEntitlements(ctx context.Context) error { backends.NewSlog(api.Logger), ) } - api.AGPL.Auditor.Store(auditor) + api.AGPL.Auditor.Store(&auditor) } api.hasLicense = hasLicense @@ -242,7 +234,11 @@ func (api *API) runEntitlementsLoop(ctx context.Context) { }) if err != nil { api.Logger.Warn(ctx, "failed to subscribe to license updates", slog.Error(err)) - time.Sleep(b.NextBackOff()) + select { + case <-ctx.Done(): + return + case <-time.After(b.NextBackOff()): + } continue } // nolint: revive diff --git a/enterprise/coderd/coderd_test.go b/enterprise/coderd/coderd_test.go index 72c4befa5c9ce..a1a20c8db843c 100644 --- a/enterprise/coderd/coderd_test.go +++ b/enterprise/coderd/coderd_test.go @@ -34,14 +34,6 @@ func TestEntitlements(t *testing.T) { require.False(t, res.HasLicense) require.Empty(t, res.Warnings) }) - t.Run("NoLicense", func(t *testing.T) { - t.Parallel() - client := coderdenttest.New(t, nil) - res, err := client.Entitlements(context.Background()) - require.NoError(t, err) - require.False(t, res.HasLicense) - require.Empty(t, res.Warnings) - }) t.Run("FullLicense", func(t *testing.T) { t.Parallel() client := coderdenttest.New(t, nil) @@ -153,6 +145,7 @@ func TestEntitlements(t *testing.T) { require.NoError(t, err) require.False(t, entitlements.HasLicense) coderdtest.CreateFirstUser(t, client) + // Valid _, err = api.Database.InsertLicense(context.Background(), database.InsertLicenseParams{ UploadedAt: database.Now(), Exp: database.Now().AddDate(1, 0, 0), @@ -161,6 +154,22 @@ func TestEntitlements(t *testing.T) { }), }) require.NoError(t, err) + // Expired + _, err = api.Database.InsertLicense(context.Background(), database.InsertLicenseParams{ + UploadedAt: database.Now(), + Exp: database.Now().AddDate(-1, 0, 0), + JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{ + AuditLog: true, + }), + }) + require.NoError(t, err) + // Invalid + _, err = api.Database.InsertLicense(context.Background(), database.InsertLicenseParams{ + UploadedAt: database.Now(), + Exp: database.Now().AddDate(1, 0, 0), + JWT: "invalid", + }) + require.NoError(t, err) require.Eventually(t, func() bool { entitlements, err := client.Entitlements(context.Background()) assert.NoError(t, err) @@ -178,16 +187,18 @@ func TestAuditLogging(t *testing.T) { coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ AuditLog: true, }) - _, auditor := api.AGPL.Auditor.Load(context.Background()) + auditor := *api.AGPL.Auditor.Load() ea := audit.NewAuditor(audit.DefaultFilter) + t.Logf("%T = %T", auditor, ea) assert.Equal(t, reflect.ValueOf(ea).Type(), reflect.ValueOf(auditor).Type()) }) t.Run("Disabled", func(t *testing.T) { t.Parallel() client, _, api := coderdenttest.NewWithAPI(t, nil) coderdtest.CreateFirstUser(t, client) - _, auditor := api.AGPL.Auditor.Load(context.Background()) + auditor := *api.AGPL.Auditor.Load() ea := agplaudit.NewNop() + t.Logf("%T = %T", auditor, ea) assert.Equal(t, reflect.ValueOf(ea).Type(), reflect.ValueOf(auditor).Type()) }) } From 7fd09038b7540eeed3aea9f7ef79138a23dc6d34 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Mon, 19 Sep 2022 21:08:45 +0000 Subject: [PATCH 17/19] Fix expired license --- enterprise/coderd/coderd_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/enterprise/coderd/coderd_test.go b/enterprise/coderd/coderd_test.go index a1a20c8db843c..aedd79417be41 100644 --- a/enterprise/coderd/coderd_test.go +++ b/enterprise/coderd/coderd_test.go @@ -159,7 +159,7 @@ func TestEntitlements(t *testing.T) { UploadedAt: database.Now(), Exp: database.Now().AddDate(-1, 0, 0), JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{ - AuditLog: true, + ExpiresAt: database.Now().AddDate(-1, 0, 0), }), }) require.NoError(t, err) From c67605b639a0afed21a3e1fdd18d583969598c25 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Mon, 19 Sep 2022 23:28:08 +0000 Subject: [PATCH 18/19] Add entitlements struct --- enterprise/coderd/coderd.go | 90 ++++++++++++++++++++----------------- 1 file changed, 48 insertions(+), 42 deletions(-) diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index 8f41725c73e23..20140f0e80d83 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -38,11 +38,13 @@ func New(ctx context.Context, options *Options) (*API, error) { AGPL: coderd.New(options.Options), Options: options, - activeUsers: codersdk.Feature{ - Entitlement: codersdk.EntitlementNotEntitled, - Enabled: false, + entitlements: entitlements{ + activeUsers: codersdk.Feature{ + Entitlement: codersdk.EntitlementNotEntitled, + Enabled: false, + }, + auditLogs: codersdk.EntitlementNotEntitled, }, - auditLogs: codersdk.EntitlementNotEntitled, cancelEntitlementsLoop: cancelFunc, } oauthConfigs := &httpmw.OAuth2Configs{ @@ -52,7 +54,7 @@ func New(ctx context.Context, options *Options) (*API, error) { apiKeyMiddleware := httpmw.ExtractAPIKey(options.Database, oauthConfigs, false) api.AGPL.APIHandler.Group(func(r chi.Router) { - r.Get("/entitlements", api.entitlements) + r.Get("/entitlements", api.serveEntitlements) r.Route("/licenses", func(r chi.Router) { r.Use(apiKeyMiddleware) r.Post("/", api.postLicense) @@ -83,10 +85,14 @@ type API struct { *Options cancelEntitlementsLoop func() - mutex sync.RWMutex - hasLicense bool - activeUsers codersdk.Feature - auditLogs codersdk.Entitlement + entitlementsMu sync.RWMutex + entitlements entitlements +} + +type entitlements struct { + hasLicense bool + activeUsers codersdk.Feature + auditLogs codersdk.Entitlement } func (api *API) Close() error { @@ -99,17 +105,19 @@ func (api *API) updateEntitlements(ctx context.Context) error { if err != nil { return err } - api.mutex.Lock() - defer api.mutex.Unlock() + api.entitlementsMu.Lock() + defer api.entitlementsMu.Unlock() now := time.Now() // Default all entitlements to be disabled. - hasLicense := false - activeUsers := codersdk.Feature{ - Enabled: false, - Entitlement: codersdk.EntitlementNotEntitled, + entitlements := entitlements{ + hasLicense: false, + activeUsers: codersdk.Feature{ + Enabled: false, + Entitlement: codersdk.EntitlementNotEntitled, + }, + auditLogs: codersdk.EntitlementNotEntitled, } - auditLogs := codersdk.EntitlementNotEntitled // Here we loop through licenses to detect enabled features. for _, l := range licenses { @@ -119,7 +127,7 @@ func (api *API) updateEntitlements(ctx context.Context) error { slog.F("id", l.ID), slog.Error(err)) continue } - hasLicense = true + entitlements.hasLicense = true entitlement := codersdk.EntitlementEntitled if now.After(claims.LicenseExpires.Time) { // if the grace period were over, the validation fails, so if we are after @@ -127,25 +135,27 @@ func (api *API) updateEntitlements(ctx context.Context) error { entitlement = codersdk.EntitlementGracePeriod } if claims.Features.UserLimit > 0 { - activeUsers.Enabled = true - activeUsers.Entitlement = entitlement + entitlements.activeUsers = codersdk.Feature{ + Enabled: true, + Entitlement: entitlement, + } currentLimit := int64(0) - if activeUsers.Limit != nil { - currentLimit = *activeUsers.Limit + if entitlements.activeUsers.Limit != nil { + currentLimit = *entitlements.activeUsers.Limit } limit := max(currentLimit, claims.Features.UserLimit) - activeUsers.Limit = &limit + entitlements.activeUsers.Limit = &limit } if claims.Features.AuditLog > 0 { - auditLogs = entitlement + entitlements.auditLogs = entitlement } } - if auditLogs != api.auditLogs { + if entitlements.auditLogs != api.entitlements.auditLogs { auditor := agplaudit.NewNop() // A flag could be added to the options that would allow disabling // enhanced audit logging here! - if auditLogs == codersdk.EntitlementEntitled && api.AuditLogging { + if entitlements.auditLogs == codersdk.EntitlementEntitled && api.AuditLogging { auditor = audit.NewAuditor( audit.DefaultFilter, backends.NewPostgres(api.Database, true), @@ -155,27 +165,23 @@ func (api *API) updateEntitlements(ctx context.Context) error { api.AGPL.Auditor.Store(&auditor) } - api.hasLicense = hasLicense - api.activeUsers = activeUsers - api.auditLogs = auditLogs + api.entitlements = entitlements return nil } -func (api *API) entitlements(rw http.ResponseWriter, r *http.Request) { - api.mutex.RLock() - hasLicense := api.hasLicense - activeUsers := api.activeUsers - auditLogs := api.auditLogs - api.mutex.RUnlock() +func (api *API) serveEntitlements(rw http.ResponseWriter, r *http.Request) { + api.entitlementsMu.RLock() + entitlements := api.entitlements + api.entitlementsMu.RUnlock() resp := codersdk.Entitlements{ Features: make(map[string]codersdk.Feature), Warnings: make([]string, 0), - HasLicense: hasLicense, + HasLicense: entitlements.hasLicense, } - if activeUsers.Limit != nil { + if entitlements.activeUsers.Limit != nil { activeUserCount, err := api.Database.GetActiveUserCount(r.Context()) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ @@ -184,22 +190,22 @@ func (api *API) entitlements(rw http.ResponseWriter, r *http.Request) { }) return } - activeUsers.Actual = &activeUserCount - if activeUserCount > *activeUsers.Limit { + entitlements.activeUsers.Actual = &activeUserCount + if activeUserCount > *entitlements.activeUsers.Limit { resp.Warnings = append(resp.Warnings, fmt.Sprintf( "Your deployment has %d active users but is only licensed for %d.", - activeUserCount, *activeUsers.Limit)) + activeUserCount, *entitlements.activeUsers.Limit)) } } - resp.Features[codersdk.FeatureUserLimit] = activeUsers + resp.Features[codersdk.FeatureUserLimit] = entitlements.activeUsers // Audit logs resp.Features[codersdk.FeatureAuditLog] = codersdk.Feature{ - Entitlement: auditLogs, + Entitlement: entitlements.auditLogs, Enabled: api.AuditLogging, } - if auditLogs == codersdk.EntitlementGracePeriod && api.AuditLogging { + if entitlements.auditLogs == codersdk.EntitlementGracePeriod && api.AuditLogging { resp.Warnings = append(resp.Warnings, "Audit logging is enabled but your license for this feature is expired.") } From 7f9ed3991ed7eabcf8a922487a573d1b7a74722f Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Tue, 20 Sep 2022 02:09:35 +0000 Subject: [PATCH 19/19] Fix context passing --- coderd/coderdtest/coderdtest.go | 13 +++++-------- enterprise/coderd/coderdenttest/coderdenttest.go | 3 ++- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index 58319d1a89fd3..fd31cd55230b9 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -115,7 +115,7 @@ func newWithCloser(t *testing.T, options *Options) (*codersdk.Client, io.Closer) return client, closer } -func NewOptions(t *testing.T, options *Options) (*httptest.Server, *coderd.Options) { +func NewOptions(t *testing.T, options *Options) (*httptest.Server, context.CancelFunc, *coderd.Options) { if options == nil { options = &Options{} } @@ -159,8 +159,6 @@ func NewOptions(t *testing.T, options *Options) (*httptest.Server, *coderd.Optio } ctx, cancelFunc := context.WithCancel(context.Background()) - defer t.Cleanup(cancelFunc) // Defer to ensure cancelFunc is executed first. - lifecycleExecutor := executor.New( ctx, db, @@ -194,7 +192,7 @@ func NewOptions(t *testing.T, options *Options) (*httptest.Server, *coderd.Optio options.SSHKeygenAlgorithm = gitsshkey.AlgorithmEd25519 } - return srv, &coderd.Options{ + return srv, cancelFunc, &coderd.Options{ AgentConnectionUpdateFrequency: 150 * time.Millisecond, // Force a long disconnection timeout to ensure // agents are not marked as disconnected during slow tests. @@ -246,19 +244,18 @@ func NewWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c if options == nil { options = &Options{} } - srv, newOptions := NewOptions(t, options) + srv, cancelFunc, newOptions := NewOptions(t, options) // We set the handler after server creation for the access URL. coderAPI := coderd.New(newOptions) - t.Cleanup(func() { - _ = coderAPI.Close() - }) srv.Config.Handler = coderAPI.RootHandler var provisionerCloser io.Closer = nopcloser{} if options.IncludeProvisionerDaemon { provisionerCloser = NewProvisionerDaemon(t, coderAPI) } t.Cleanup(func() { + cancelFunc() _ = provisionerCloser.Close() + _ = coderAPI.Close() }) return codersdk.New(coderAPI.AccessURL), provisionerCloser, coderAPI } diff --git a/enterprise/coderd/coderdenttest/coderdenttest.go b/enterprise/coderd/coderdenttest/coderdenttest.go index 813fddf473bae..572b858bea31f 100644 --- a/enterprise/coderd/coderdenttest/coderdenttest.go +++ b/enterprise/coderd/coderdenttest/coderdenttest.go @@ -52,7 +52,7 @@ func NewWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c if options.Options == nil { options.Options = &coderdtest.Options{} } - srv, oop := coderdtest.NewOptions(t, options.Options) + srv, cancelFunc, oop := coderdtest.NewOptions(t, options.Options) coderAPI, err := coderd.New(context.Background(), &coderd.Options{ AuditLogging: true, Options: oop, @@ -68,6 +68,7 @@ func NewWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c provisionerCloser = coderdtest.NewProvisionerDaemon(t, coderAPI.AGPL) } t.Cleanup(func() { + cancelFunc() _ = provisionerCloser.Close() _ = coderAPI.Close() })