diff --git a/cli/gitssh_test.go b/cli/gitssh_test.go index 78ebcb7d0ff52..db3224d03476f 100644 --- a/cli/gitssh_test.go +++ b/cli/gitssh_test.go @@ -20,8 +20,6 @@ import ( "github.com/stretchr/testify/require" gossh "golang.org/x/crypto/ssh" - "cdr.dev/slog" - "github.com/coder/coder/cli/clitest" "github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/codersdk" @@ -83,18 +81,7 @@ func prepareTestGitSSH(ctx context.Context, t *testing.T) (*codersdk.Client, str errC <- cmd.ExecuteContext(ctx) }() t.Cleanup(func() { require.NoError(t, <-errC) }) - coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID) - resources, err := client.WorkspaceResourcesByBuild(ctx, workspace.LatestBuild.ID) - require.NoError(t, err) - dialer, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, resources[0].Agents[0].ID) - require.NoError(t, err) - defer dialer.Close() - require.Eventually(t, func() bool { - _, err = dialer.Ping() - return err == nil - }, testutil.WaitMedium, testutil.IntervalFast) - return agentClient, agentToken, pubkey } diff --git a/cli/root.go b/cli/root.go index c3e8c6d244384..fc460265af345 100644 --- a/cli/root.go +++ b/cli/root.go @@ -91,12 +91,13 @@ func Core() []*cobra.Command { users(), versionCmd(), workspaceAgent(), - features(), } } func AGPL() []*cobra.Command { - all := append(Core(), Server(coderd.New)) + all := append(Core(), Server(func(_ context.Context, o *coderd.Options) (*coderd.API, error) { + return coderd.New(o), nil + })) return all } @@ -548,13 +549,11 @@ func checkWarnings(cmd *cobra.Command, client *codersdk.Client) error { defer cancel() entitlements, err := client.Entitlements(ctx) - if err != nil { - return xerrors.Errorf("get entitlements to show warnings: %w", err) - } - for _, w := range entitlements.Warnings { - _, _ = fmt.Fprintln(cmd.ErrOrStderr(), cliui.Styles.Warn.Render(w)) + if err == nil { + for _, w := range entitlements.Warnings { + _, _ = fmt.Fprintln(cmd.ErrOrStderr(), cliui.Styles.Warn.Render(w)) + } } - return nil } diff --git a/cli/server.go b/cli/server.go index 57d0880eb3755..d3486a5c0e6c0 100644 --- a/cli/server.go +++ b/cli/server.go @@ -68,7 +68,7 @@ import ( ) // nolint:gocyclo -func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command { +func Server(newAPI func(context.Context, *coderd.Options) (*coderd.API, error)) *cobra.Command { var ( accessURL string address string @@ -489,7 +489,10 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command { ), promAddress, "prometheus")() } - coderAPI := newAPI(options) + coderAPI, err := newAPI(ctx, options) + if err != nil { + return err + } defer coderAPI.Close() client := codersdk.New(localURL) @@ -536,7 +539,7 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command { // These errors are typically noise like "TLS: EOF". Vault does similar: // https://github.com/hashicorp/vault/blob/e2490059d0711635e529a4efcbaa1b26998d6e1c/command/server.go#L2714 ErrorLog: log.New(io.Discard, "", 0), - Handler: coderAPI.Handler, + Handler: coderAPI.RootHandler, BaseContext: func(_ net.Listener) context.Context { return shutdownConnsCtx }, diff --git a/coderd/audit/request.go b/coderd/audit/request.go index c8e86bf3ccaee..eed6ae5dc5afb 100644 --- a/coderd/audit/request.go +++ b/coderd/audit/request.go @@ -12,14 +12,13 @@ import ( "cdr.dev/slog" "github.com/coder/coder/coderd/database" - "github.com/coder/coder/coderd/features" "github.com/coder/coder/coderd/httpmw" "github.com/coder/coder/coderd/tracing" ) type RequestParams struct { - Features features.Service - Log slog.Logger + Audit Auditor + Log slog.Logger Request *http.Request Action database.AuditAction @@ -102,15 +101,6 @@ func InitRequest[T Auditable](w http.ResponseWriter, p *RequestParams) (*Request params: p, } - feats := struct { - Audit Auditor - }{} - err := p.Features.Get(&feats) - if err != nil { - p.Log.Error(p.Request.Context(), "unable to get auditor interface", slog.Error(err)) - return req, func() {} - } - return req, func() { ctx := context.Background() logCtx := p.Request.Context() @@ -120,7 +110,7 @@ func InitRequest[T Auditable](w http.ResponseWriter, p *RequestParams) (*Request return } - diff := Diff(feats.Audit, req.Old, req.New) + diff := Diff(p.Audit, req.Old, req.New) diffRaw, _ := json.Marshal(diff) ip, err := parseIP(p.Request.RemoteAddr) @@ -128,7 +118,7 @@ func InitRequest[T Auditable](w http.ResponseWriter, p *RequestParams) (*Request p.Log.Warn(logCtx, "parse ip", slog.Error(err)) } - err = feats.Audit.Export(ctx, database.AuditLog{ + err = p.Audit.Export(ctx, database.AuditLog{ ID: uuid.New(), Time: database.Now(), UserID: httpmw.APIKey(p.Request).UserID, diff --git a/coderd/authorize.go b/coderd/authorize.go index b21f6a19fcffe..6183092c18e8f 100644 --- a/coderd/authorize.go +++ b/coderd/authorize.go @@ -43,7 +43,7 @@ type HTTPAuthorizer struct { // return // } func (api *API) Authorize(r *http.Request, action rbac.Action, object rbac.Objecter) bool { - return api.httpAuth.Authorize(r, action, object) + return api.HTTPAuth.Authorize(r, action, object) } // Authorize will return false if the user is not authorized to do the action. diff --git a/coderd/coderd.go b/coderd/coderd.go index 607f15ff7931b..050d16b86911b 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -7,6 +7,7 @@ import ( "net/url" "path/filepath" "sync" + "sync/atomic" "time" "github.com/andybalholm/brotli" @@ -24,9 +25,9 @@ import ( "cdr.dev/slog" "github.com/coder/coder/buildinfo" + "github.com/coder/coder/coderd/audit" "github.com/coder/coder/coderd/awsidentity" "github.com/coder/coder/coderd/database" - "github.com/coder/coder/coderd/features" "github.com/coder/coder/coderd/gitsshkey" "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/coderd/httpmw" @@ -50,6 +51,7 @@ type Options struct { // CacheDir is used for caching files served by the API. CacheDir string + Auditor audit.Auditor AgentConnectionUpdateFrequency time.Duration AgentInactiveDisconnectTimeout time.Duration // APIRateLimit is the minutely throughput rate limit per user or ip. @@ -68,8 +70,6 @@ type Options struct { Telemetry telemetry.Reporter TracerProvider trace.TracerProvider AutoImportTemplates []AutoImportTemplate - LicenseHandler http.Handler - FeaturesService features.Service TailnetCoordinator *tailnet.Coordinator DERPMap *tailcfg.DERPMap @@ -80,6 +80,9 @@ type Options struct { // New constructs a Coder API handler. func New(options *Options) *API { + if options == nil { + options = &Options{} + } if options.AgentConnectionUpdateFrequency == 0 { options.AgentConnectionUpdateFrequency = 3 * time.Second } @@ -117,11 +120,8 @@ func New(options *Options) *API { if options.TailnetCoordinator == nil { options.TailnetCoordinator = tailnet.NewCoordinator() } - if options.LicenseHandler == nil { - options.LicenseHandler = licenses() - } - if options.FeaturesService == nil { - options.FeaturesService = &featuresService{} + if options.Auditor == nil { + options.Auditor = audit.NewNop() } siteCacheDir := options.CacheDir @@ -142,14 +142,16 @@ func New(options *Options) *API { r := chi.NewRouter() api := &API{ Options: options, - Handler: r, + RootHandler: r, siteHandler: site.Handler(site.FS(), binFS), - httpAuth: &HTTPAuthorizer{ + HTTPAuth: &HTTPAuthorizer{ Authorizer: options.Authorizer, Logger: options.Logger, }, metricsCache: metricsCache, + Auditor: atomic.Pointer[audit.Auditor]{}, } + api.Auditor.Store(&options.Auditor) api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgentTailnet, 0) api.derpServer = derp.NewServer(key.NewNode(), tailnet.Logger(options.Logger)) oauthConfigs := &httpmw.OAuth2Configs{ @@ -218,6 +220,8 @@ func New(options *Options) *API { }) r.Route("/api/v2", func(r chi.Router) { + api.APIHandler = r + r.NotFound(func(rw http.ResponseWriter, r *http.Request) { httpapi.Write(rw, http.StatusNotFound, codersdk.Response{ Message: "Route not found.", @@ -473,14 +477,6 @@ func New(options *Options) *API { r.Get("/resources", api.workspaceBuildResources) r.Get("/state", api.workspaceBuildState) }) - r.Route("/entitlements", func(r chi.Router) { - r.Use(apiKeyMiddleware) - r.Get("/", api.FeaturesService.EntitlementsAPI) - }) - r.Route("/licenses", func(r chi.Router) { - r.Use(apiKeyMiddleware) - r.Mount("/", options.LicenseHandler) - }) }) r.NotFound(compressHandler(http.HandlerFunc(api.siteHandler.ServeHTTP)).ServeHTTP) @@ -489,17 +485,20 @@ func New(options *Options) *API { type API struct { *Options + Auditor atomic.Pointer[audit.Auditor] + HTTPAuth *HTTPAuthorizer - derpServer *derp.Server + // APIHandler serves "/api/v2" + APIHandler chi.Router + // RootHandler serves "/" + RootHandler chi.Router - Handler chi.Router + derpServer *derp.Server + metricsCache *metricscache.Cache siteHandler http.Handler websocketWaitMutex sync.Mutex websocketWaitGroup sync.WaitGroup workspaceAgentCache *wsconncache.Cache - httpAuth *HTTPAuthorizer - - metricsCache *metricscache.Cache } // Close waits for all WebSocket connections to drain before returning. diff --git a/coderd/coderd_test.go b/coderd/coderd_test.go index 87206873b1073..9fc459fc9e18e 100644 --- a/coderd/coderd_test.go +++ b/coderd/coderd_test.go @@ -38,16 +38,6 @@ func TestBuildInfo(t *testing.T) { require.Equal(t, buildinfo.Version(), buildInfo.Version, "version") } -// TestAuthorizeAllEndpoints will check `authorize` is called on every endpoint registered. -func TestAuthorizeAllEndpoints(t *testing.T) { - t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - a := coderdtest.NewAuthTester(ctx, t, nil) - skipRoutes, assertRoute := coderdtest.AGPLRoutes(a) - a.Test(ctx, assertRoute, skipRoutes) -} - func TestDERP(t *testing.T) { t.Parallel() client := coderdtest.New(t, nil) diff --git a/coderd/coderdtest/authtest.go b/coderd/coderdtest/authorize.go similarity index 96% rename from coderd/coderdtest/authtest.go rename to coderd/coderdtest/authorize.go index 564bfbacf91cd..8b31b8e8eb2b5 100644 --- a/coderd/coderdtest/authtest.go +++ b/coderd/coderdtest/authorize.go @@ -22,143 +22,6 @@ import ( "github.com/coder/coder/testutil" ) -type RouteCheck struct { - NoAuthorize bool - AssertAction rbac.Action - AssertObject rbac.Object - StatusCode int -} - -type AuthTester struct { - t *testing.T - api *coderd.API - authorizer *recordingAuthorizer - - Client *codersdk.Client - Workspace codersdk.Workspace - Organization codersdk.Organization - Admin codersdk.CreateFirstUserResponse - Template codersdk.Template - Version codersdk.TemplateVersion - WorkspaceResource codersdk.WorkspaceResource - File codersdk.UploadResponse - TemplateVersionDryRun codersdk.ProvisionerJob - TemplateParam codersdk.Parameter - URLParams map[string]string -} - -func NewAuthTester(ctx context.Context, t *testing.T, options *Options) *AuthTester { - authorizer := &recordingAuthorizer{} - if options == nil { - options = &Options{} - } - if options.Authorizer != nil { - t.Error("NewAuthTester cannot be called with custom Authorizer") - } - options.Authorizer = authorizer - options.IncludeProvisionerDaemon = true - - client, _, api := newWithAPI(t, options) - admin := CreateFirstUser(t, client) - // The provisioner will call to coderd and register itself. This is async, - // so we wait for it to occur. - require.Eventually(t, func() bool { - provisionerds, err := client.ProvisionerDaemons(ctx) - return assert.NoError(t, err) && len(provisionerds) > 0 - }, testutil.WaitLong, testutil.IntervalSlow) - - provisionerds, err := client.ProvisionerDaemons(ctx) - require.NoError(t, err, "fetch provisioners") - require.Len(t, provisionerds, 1) - - organization, err := client.Organization(ctx, admin.OrganizationID) - require.NoError(t, err, "fetch org") - - // Setup some data in the database. - version := CreateTemplateVersion(t, client, admin.OrganizationID, &echo.Responses{ - Parse: echo.ParseComplete, - Provision: []*proto.Provision_Response{{ - Type: &proto.Provision_Response_Complete{ - Complete: &proto.Provision_Complete{ - // Return a workspace resource - Resources: []*proto.Resource{{ - Name: "some", - Type: "example", - Agents: []*proto.Agent{{ - Name: "agent", - Id: "something", - Auth: &proto.Agent_Token{}, - Apps: []*proto.App{{ - Name: "testapp", - Url: "http://localhost:3000", - }}, - }}, - }}, - }, - }, - }}, - }) - AwaitTemplateVersionJob(t, client, version.ID) - template := CreateTemplate(t, client, admin.OrganizationID, version.ID) - workspace := CreateWorkspace(t, client, admin.OrganizationID, template.ID) - AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) - file, err := client.Upload(ctx, codersdk.ContentTypeTar, make([]byte, 1024)) - require.NoError(t, err, "upload file") - workspaceResources, err := client.WorkspaceResourcesByBuild(ctx, workspace.LatestBuild.ID) - require.NoError(t, err, "workspace resources") - templateVersionDryRun, err := client.CreateTemplateVersionDryRun(ctx, version.ID, codersdk.CreateTemplateVersionDryRunRequest{ - ParameterValues: []codersdk.CreateParameterRequest{}, - }) - require.NoError(t, err, "template version dry-run") - - templateParam, err := client.CreateParameter(ctx, codersdk.ParameterTemplate, template.ID, codersdk.CreateParameterRequest{ - Name: "test-param", - SourceValue: "hello world", - SourceScheme: codersdk.ParameterSourceSchemeData, - DestinationScheme: codersdk.ParameterDestinationSchemeProvisionerVariable, - }) - require.NoError(t, err, "create template param") - - urlParameters := map[string]string{ - "{organization}": admin.OrganizationID.String(), - "{user}": admin.UserID.String(), - "{organizationname}": organization.Name, - "{workspace}": workspace.ID.String(), - "{workspacebuild}": workspace.LatestBuild.ID.String(), - "{workspacename}": workspace.Name, - "{workspaceagent}": workspaceResources[0].Agents[0].ID.String(), - "{buildnumber}": strconv.FormatInt(int64(workspace.LatestBuild.BuildNumber), 10), - "{template}": template.ID.String(), - "{hash}": file.Hash, - "{workspaceresource}": workspaceResources[0].ID.String(), - "{workspaceapp}": workspaceResources[0].Agents[0].Apps[0].Name, - "{templateversion}": version.ID.String(), - "{jobID}": templateVersionDryRun.ID.String(), - "{templatename}": template.Name, - "{workspace_and_agent}": workspace.Name + "." + workspaceResources[0].Agents[0].Name, - // Only checking template scoped params here - "parameters/{scope}/{id}": fmt.Sprintf("parameters/%s/%s", - string(templateParam.Scope), templateParam.ScopeID.String()), - } - - return &AuthTester{ - t: t, - api: api, - authorizer: authorizer, - Client: client, - Workspace: workspace, - Organization: organization, - Admin: admin, - Template: template, - Version: version, - WorkspaceResource: workspaceResources[0], - File: file, - TemplateVersionDryRun: templateVersionDryRun, - TemplateParam: templateParam, - URLParams: urlParameters, - } -} - func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) { // Some quick reused objects workspaceRBACObj := rbac.ResourceWorkspace.InOrg(a.Organization.ID).WithOwner(a.Workspace.OwnerID.String()) @@ -181,7 +44,6 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) { "POST:/api/v2/users/login": {NoAuthorize: true}, "GET:/api/v2/users/authmethods": {NoAuthorize: true}, "POST:/api/v2/csp/reports": {NoAuthorize: true}, - "GET:/api/v2/entitlements": {NoAuthorize: true}, // Has it's own auth "GET:/api/v2/users/oauth2/github/callback": {NoAuthorize: true}, @@ -408,6 +270,134 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) { return skipRoutes, assertRoute } +type RouteCheck struct { + NoAuthorize bool + AssertAction rbac.Action + AssertObject rbac.Object + StatusCode int +} + +type AuthTester struct { + t *testing.T + api *coderd.API + authorizer *RecordingAuthorizer + + Client *codersdk.Client + Workspace codersdk.Workspace + Organization codersdk.Organization + Admin codersdk.CreateFirstUserResponse + Template codersdk.Template + Version codersdk.TemplateVersion + WorkspaceResource codersdk.WorkspaceResource + File codersdk.UploadResponse + TemplateVersionDryRun codersdk.ProvisionerJob + TemplateParam codersdk.Parameter + URLParams map[string]string +} + +func NewAuthTester(ctx context.Context, t *testing.T, client *codersdk.Client, api *coderd.API, admin codersdk.CreateFirstUserResponse) *AuthTester { + authorizer, ok := api.Authorizer.(*RecordingAuthorizer) + if !ok { + t.Fail() + } + // The provisioner will call to coderd and register itself. This is async, + // so we wait for it to occur. + require.Eventually(t, func() bool { + provisionerds, err := client.ProvisionerDaemons(ctx) + return assert.NoError(t, err) && len(provisionerds) > 0 + }, testutil.WaitLong, testutil.IntervalSlow) + + provisionerds, err := client.ProvisionerDaemons(ctx) + require.NoError(t, err, "fetch provisioners") + require.Len(t, provisionerds, 1) + + organization, err := client.Organization(ctx, admin.OrganizationID) + require.NoError(t, err, "fetch org") + + // Setup some data in the database. + version := CreateTemplateVersion(t, client, admin.OrganizationID, &echo.Responses{ + Parse: echo.ParseComplete, + Provision: []*proto.Provision_Response{{ + Type: &proto.Provision_Response_Complete{ + Complete: &proto.Provision_Complete{ + // Return a workspace resource + Resources: []*proto.Resource{{ + Name: "some", + Type: "example", + Agents: []*proto.Agent{{ + Name: "agent", + Id: "something", + Auth: &proto.Agent_Token{}, + Apps: []*proto.App{{ + Name: "testapp", + Url: "http://localhost:3000", + }}, + }}, + }}, + }, + }, + }}, + }) + AwaitTemplateVersionJob(t, client, version.ID) + template := CreateTemplate(t, client, admin.OrganizationID, version.ID) + workspace := CreateWorkspace(t, client, admin.OrganizationID, template.ID) + AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) + file, err := client.Upload(ctx, codersdk.ContentTypeTar, make([]byte, 1024)) + require.NoError(t, err, "upload file") + workspaceResources, err := client.WorkspaceResourcesByBuild(ctx, workspace.LatestBuild.ID) + require.NoError(t, err, "workspace resources") + templateVersionDryRun, err := client.CreateTemplateVersionDryRun(ctx, version.ID, codersdk.CreateTemplateVersionDryRunRequest{ + ParameterValues: []codersdk.CreateParameterRequest{}, + }) + require.NoError(t, err, "template version dry-run") + + templateParam, err := client.CreateParameter(ctx, codersdk.ParameterTemplate, template.ID, codersdk.CreateParameterRequest{ + Name: "test-param", + SourceValue: "hello world", + SourceScheme: codersdk.ParameterSourceSchemeData, + DestinationScheme: codersdk.ParameterDestinationSchemeProvisionerVariable, + }) + require.NoError(t, err, "create template param") + urlParameters := map[string]string{ + "{organization}": admin.OrganizationID.String(), + "{user}": admin.UserID.String(), + "{organizationname}": organization.Name, + "{workspace}": workspace.ID.String(), + "{workspacebuild}": workspace.LatestBuild.ID.String(), + "{workspacename}": workspace.Name, + "{workspaceagent}": workspaceResources[0].Agents[0].ID.String(), + "{buildnumber}": strconv.FormatInt(int64(workspace.LatestBuild.BuildNumber), 10), + "{template}": template.ID.String(), + "{hash}": file.Hash, + "{workspaceresource}": workspaceResources[0].ID.String(), + "{workspaceapp}": workspaceResources[0].Agents[0].Apps[0].Name, + "{templateversion}": version.ID.String(), + "{jobID}": templateVersionDryRun.ID.String(), + "{templatename}": template.Name, + "{workspace_and_agent}": workspace.Name + "." + workspaceResources[0].Agents[0].Name, + // Only checking template scoped params here + "parameters/{scope}/{id}": fmt.Sprintf("parameters/%s/%s", + string(templateParam.Scope), templateParam.ScopeID.String()), + } + + return &AuthTester{ + t: t, + api: api, + authorizer: authorizer, + Client: client, + Workspace: workspace, + Organization: organization, + Admin: admin, + Template: template, + Version: version, + WorkspaceResource: workspaceResources[0], + File: file, + TemplateVersionDryRun: templateVersionDryRun, + TemplateParam: templateParam, + URLParams: urlParameters, + } +} + func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck, skipRoutes map[string]string) { // Always fail auth from this point forward a.authorizer.AlwaysReturn = rbac.ForbiddenWithInternal(xerrors.New("fake implementation"), nil, nil) @@ -433,7 +423,7 @@ func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck } err := chi.Walk( - a.api.Handler, + a.api.RootHandler, func( method string, route string, @@ -513,14 +503,14 @@ type authCall struct { Object rbac.Object } -type recordingAuthorizer struct { +type RecordingAuthorizer struct { Called *authCall AlwaysReturn error } -var _ rbac.Authorizer = (*recordingAuthorizer)(nil) +var _ rbac.Authorizer = (*RecordingAuthorizer)(nil) -func (r *recordingAuthorizer) ByRoleName(_ context.Context, subjectID string, roleNames []string, scope rbac.Scope, action rbac.Action, object rbac.Object) error { +func (r *RecordingAuthorizer) ByRoleName(_ context.Context, subjectID string, roleNames []string, scope rbac.Scope, action rbac.Action, object rbac.Object) error { r.Called = &authCall{ SubjectID: subjectID, Roles: roleNames, @@ -531,7 +521,7 @@ func (r *recordingAuthorizer) ByRoleName(_ context.Context, subjectID string, ro return r.AlwaysReturn } -func (r *recordingAuthorizer) PrepareByRoleName(_ context.Context, subjectID string, roles []string, scope rbac.Scope, action rbac.Action, _ string) (rbac.PreparedAuthorized, error) { +func (r *RecordingAuthorizer) PrepareByRoleName(_ context.Context, subjectID string, roles []string, scope rbac.Scope, action rbac.Action, _ string) (rbac.PreparedAuthorized, error) { return &fakePreparedAuthorizer{ Original: r, SubjectID: subjectID, @@ -541,12 +531,12 @@ func (r *recordingAuthorizer) PrepareByRoleName(_ context.Context, subjectID str }, nil } -func (r *recordingAuthorizer) reset() { +func (r *RecordingAuthorizer) reset() { r.Called = nil } type fakePreparedAuthorizer struct { - Original *recordingAuthorizer + Original *RecordingAuthorizer SubjectID string Roles []string Scope rbac.Scope diff --git a/coderd/coderdtest/authorize_test.go b/coderd/coderdtest/authorize_test.go new file mode 100644 index 0000000000000..c8ef64065a290 --- /dev/null +++ b/coderd/coderdtest/authorize_test.go @@ -0,0 +1,20 @@ +package coderdtest_test + +import ( + "context" + "testing" + + "github.com/coder/coder/coderd/coderdtest" +) + +func TestAuthorizeAllEndpoints(t *testing.T) { + t.Parallel() + client, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{ + Authorizer: &coderdtest.RecordingAuthorizer{}, + IncludeProvisionerDaemon: true, + }) + admin := coderdtest.CreateFirstUser(t, client) + a := coderdtest.NewAuthTester(context.Background(), t, client, api, admin) + skipRoute, assertRoute := coderdtest.AGPLRoutes(a) + a.Test(context.Background(), assertRoute, skipRoute) +} diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index 709078b154171..fd31cd55230b9 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -80,7 +80,6 @@ type Options struct { // IncludeProvisionerDaemon when true means to start an in-memory provisionerD IncludeProvisionerDaemon bool - APIBuilder func(*coderd.Options) *coderd.API MetricsCacheRefreshInterval time.Duration AgentStatsRefreshInterval time.Duration } @@ -112,14 +111,11 @@ func NewWithProvisionerCloser(t *testing.T, options *Options) (*codersdk.Client, // and is a temporary measure while the API to register provisioners is ironed // out. func newWithCloser(t *testing.T, options *Options) (*codersdk.Client, io.Closer) { - client, closer, _ := newWithAPI(t, options) + client, closer, _ := NewWithAPI(t, options) return client, closer } -// newWithAPI constructs an in-memory API instance and returns a client to talk to it. -// Most tests never need a reference to the API, but AuthorizationTest in this module uses it. -// Do not expose the API or wrath shall descend upon thee. -func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *coderd.API) { +func NewOptions(t *testing.T, options *Options) (*httptest.Server, context.CancelFunc, *coderd.Options) { if options == nil { options = &Options{} } @@ -140,9 +136,6 @@ func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c close(options.AutobuildStats) }) } - if options.APIBuilder == nil { - options.APIBuilder = coderd.New - } // This can be hotswapped for a live database instance. db := databasefake.New() @@ -166,8 +159,6 @@ func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c } ctx, cancelFunc := context.WithCancel(context.Background()) - defer t.Cleanup(cancelFunc) // Defer to ensure cancelFunc is executed first. - lifecycleExecutor := executor.New( ctx, db, @@ -201,13 +192,7 @@ func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c options.SSHKeygenAlgorithm = gitsshkey.AlgorithmEd25519 } - features := coderd.DisabledImplementations - if options.Auditor != nil { - features.Auditor = options.Auditor - } - - // We set the handler after server creation for the access URL. - coderAPI := options.APIBuilder(&coderd.Options{ + return srv, cancelFunc, &coderd.Options{ AgentConnectionUpdateFrequency: 150 * time.Millisecond, // Force a long disconnection timeout to ensure // agents are not marked as disconnected during slow tests. @@ -218,6 +203,7 @@ func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c Database: db, Pubsub: pubsub, + Auditor: options.Auditor, AWSCertificates: options.AWSCertificates, AzureCertificates: options.AzureCertificates, GithubOAuth2Config: options.GithubOAuth2Config, @@ -248,22 +234,30 @@ func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c AutoImportTemplates: options.AutoImportTemplates, MetricsCacheRefreshInterval: options.MetricsCacheRefreshInterval, AgentStatsRefreshInterval: options.AgentStatsRefreshInterval, - FeaturesService: coderd.NewMockFeaturesService(features), - }) - t.Cleanup(func() { - _ = coderAPI.Close() - }) - srv.Config.Handler = coderAPI.Handler + } +} +// NewWithAPI constructs an in-memory API instance and returns a client to talk to it. +// Most tests never need a reference to the API, but AuthorizationTest in this module uses it. +// Do not expose the API or wrath shall descend upon thee. +func NewWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *coderd.API) { + if options == nil { + options = &Options{} + } + srv, cancelFunc, newOptions := NewOptions(t, options) + // We set the handler after server creation for the access URL. + coderAPI := coderd.New(newOptions) + srv.Config.Handler = coderAPI.RootHandler var provisionerCloser io.Closer = nopcloser{} if options.IncludeProvisionerDaemon { provisionerCloser = NewProvisionerDaemon(t, coderAPI) } t.Cleanup(func() { + cancelFunc() _ = provisionerCloser.Close() + _ = coderAPI.Close() }) - - return codersdk.New(serverURL), provisionerCloser, coderAPI + return codersdk.New(coderAPI.AccessURL), provisionerCloser, coderAPI } // NewProvisionerDaemon launches a provisionerd instance configured to work diff --git a/coderd/features.go b/coderd/features.go deleted file mode 100644 index 594fad2e38423..0000000000000 --- a/coderd/features.go +++ /dev/null @@ -1,97 +0,0 @@ -package coderd - -import ( - "net/http" - "reflect" - - "golang.org/x/xerrors" - - "github.com/coder/coder/coderd/audit" - "github.com/coder/coder/coderd/features" - "github.com/coder/coder/coderd/httpapi" - "github.com/coder/coder/codersdk" -) - -func NewMockFeaturesService(feats FeatureInterfaces) features.Service { - return &featuresService{ - feats: &feats, - } -} - -type featuresService struct { - feats *FeatureInterfaces -} - -func (*featuresService) EntitlementsAPI(rw http.ResponseWriter, _ *http.Request) { - feats := make(map[string]codersdk.Feature) - for _, f := range codersdk.FeatureNames { - feats[f] = codersdk.Feature{ - Entitlement: codersdk.EntitlementNotEntitled, - Enabled: false, - } - } - httpapi.Write(rw, http.StatusOK, codersdk.Entitlements{ - Features: feats, - Warnings: []string{}, - HasLicense: false, - }) -} - -// Get returns the implementations for feature interfaces. Parameter `s` must be a pointer to a -// struct type containing feature interfaces as fields. The AGPL featureService always returns the -// "disabled" version of the feature interface because it doesn't include any enterprise features -// by definition. -func (f *featuresService) Get(ps any) error { - if reflect.TypeOf(ps).Kind() != reflect.Pointer { - return xerrors.New("input must be pointer to struct") - } - vs := reflect.ValueOf(ps).Elem() - if vs.Kind() != reflect.Struct { - return xerrors.New("input must be pointer to struct") - } - for i := 0; i < vs.NumField(); i++ { - vf := vs.Field(i) - tf := vf.Type() - if tf.Kind() != reflect.Interface { - return xerrors.Errorf("fields of input struct must be interfaces: %s", tf.String()) - } - err := f.setImplementation(vf, tf) - if err != nil { - return err - } - } - return nil -} - -// setImplementation finds the correct implementation for the field's type, and sets it on the -// struct. It returns an error if unsuccessful -func (f *featuresService) setImplementation(vf reflect.Value, tf reflect.Type) error { - feats := f.feats - if feats == nil { - feats = &DisabledImplementations - } - - // when we get more than a few features it might make sense to have a data structure for finding - // the correct implementation that's faster than just a linear search, but for now just spin - // through the implementations we have. - vd := reflect.ValueOf(*feats) - for j := 0; j < vd.NumField(); j++ { - vdf := vd.Field(j) - if vdf.Type() == tf { - vf.Set(vdf) - return nil - } - } - return xerrors.Errorf("unable to find implementation of interface %s", tf.String()) -} - -// FeatureInterfaces contains a field for each interface controlled by an enterprise feature. -type FeatureInterfaces struct { - Auditor audit.Auditor -} - -// DisabledImplementations includes all the implementations of turned-off features. There are no -// turned-on implementations in AGPL code. -var DisabledImplementations = FeatureInterfaces{ - Auditor: audit.NewNop(), -} diff --git a/coderd/features/features.go b/coderd/features/features.go deleted file mode 100644 index d44bd5f2e40d1..0000000000000 --- a/coderd/features/features.go +++ /dev/null @@ -1,13 +0,0 @@ -package features - -import "net/http" - -// Service is the interface for interacting with enterprise features. -type Service interface { - EntitlementsAPI(w http.ResponseWriter, r *http.Request) - - // Get returns the implementations for feature interfaces. Parameter `s` must be a pointer to a - // struct type containing feature interfaces as fields. The FeatureService sets all fields to - // the correct implementations depending on whether the features are turned on. - Get(s any) error -} diff --git a/coderd/features_internal_test.go b/coderd/features_internal_test.go deleted file mode 100644 index cba3f3da89e50..0000000000000 --- a/coderd/features_internal_test.go +++ /dev/null @@ -1,100 +0,0 @@ -package coderd - -import ( - "encoding/json" - "net/http" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/coder/coder/coderd/audit" - "github.com/coder/coder/codersdk" -) - -func TestEntitlements(t *testing.T) { - t.Parallel() - t.Run("GET", func(t *testing.T) { - t.Parallel() - r := httptest.NewRequest("GET", "https://example.com/api/v2/entitlements", nil) - rw := httptest.NewRecorder() - (&featuresService{}).EntitlementsAPI(rw, r) - resp := rw.Result() - defer resp.Body.Close() - assert.Equal(t, http.StatusOK, resp.StatusCode) - dec := json.NewDecoder(resp.Body) - var result codersdk.Entitlements - err := dec.Decode(&result) - require.NoError(t, err) - assert.False(t, result.HasLicense) - assert.Empty(t, result.Warnings) - for _, f := range codersdk.FeatureNames { - require.Contains(t, result.Features, f) - fe := result.Features[f] - assert.False(t, fe.Enabled) - assert.Equal(t, codersdk.EntitlementNotEntitled, fe.Entitlement) - } - }) -} - -func TestFeaturesServiceGet(t *testing.T) { - t.Parallel() - t.Run("Auditor", func(t *testing.T) { - t.Parallel() - uut := featuresService{} - target := struct { - Auditor audit.Auditor - }{} - err := uut.Get(&target) - require.NoError(t, err) - assert.NotNil(t, target.Auditor) - }) - - t.Run("NotPointer", func(t *testing.T) { - t.Parallel() - uut := featuresService{} - target := struct { - Auditor audit.Auditor - }{} - err := uut.Get(target) - require.Error(t, err) - assert.Nil(t, target.Auditor) - }) - - t.Run("UnknownInterface", func(t *testing.T) { - t.Parallel() - uut := featuresService{} - target := struct { - test testInterface - }{} - err := uut.Get(&target) - require.Error(t, err) - assert.Nil(t, target.test) - }) - - t.Run("PointerToNonStruct", func(t *testing.T) { - t.Parallel() - uut := featuresService{} - var target audit.Auditor - err := uut.Get(&target) - require.Error(t, err) - assert.Nil(t, target) - }) - - t.Run("StructWithNonInterfaces", func(t *testing.T) { - t.Parallel() - uut := featuresService{} - target := struct { - N int64 - Auditor audit.Auditor - }{} - err := uut.Get(&target) - require.Error(t, err) - assert.Nil(t, target.Auditor) - }) -} - -type testInterface interface { - Test() error -} diff --git a/coderd/licenses.go b/coderd/licenses.go deleted file mode 100644 index 28a0b1d418043..0000000000000 --- a/coderd/licenses.go +++ /dev/null @@ -1,24 +0,0 @@ -package coderd - -import ( - "net/http" - - "github.com/go-chi/chi/v5" - - "github.com/coder/coder/coderd/httpapi" - "github.com/coder/coder/codersdk" -) - -func licenses() http.Handler { - r := chi.NewRouter() - r.NotFound(unsupported) - return r -} - -func unsupported(rw http.ResponseWriter, _ *http.Request) { - httpapi.Write(rw, http.StatusNotFound, codersdk.Response{ - Message: "Unsupported", - Detail: "These endpoints are not supported in AGPL-licensed Coder", - Validations: nil, - }) -} diff --git a/coderd/provisionerdaemons.go b/coderd/provisionerdaemons.go index b3d68eca79f31..371ec0d649fd5 100644 --- a/coderd/provisionerdaemons.go +++ b/coderd/provisionerdaemons.go @@ -48,7 +48,7 @@ func (api *API) provisionerDaemons(rw http.ResponseWriter, r *http.Request) { if daemons == nil { daemons = []database.ProvisionerDaemon{} } - daemons, err = AuthorizeFilter(api.httpAuth, r, rbac.ActionRead, daemons) + daemons, err = AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, daemons) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching provisioner daemons.", diff --git a/coderd/provisionerjobs_internal_test.go b/coderd/provisionerjobs_internal_test.go index 9a58a27193d4d..34bfce841e4e9 100644 --- a/coderd/provisionerjobs_internal_test.go +++ b/coderd/provisionerjobs_internal_test.go @@ -41,7 +41,7 @@ func TestProvisionerJobLogs_Unit(t *testing.T) { api := New(&opts) defer api.Close() - server := httptest.NewServer(api.Handler) + server := httptest.NewServer(api.RootHandler) defer server.Close() userID := uuid.New() keyID, keySecret, err := generateAPIKeyIDSecret() diff --git a/coderd/templates.go b/coderd/templates.go index c48531a25c226..96bf39dd268de 100644 --- a/coderd/templates.go +++ b/coderd/templates.go @@ -85,11 +85,12 @@ func (api *API) template(rw http.ResponseWriter, r *http.Request) { func (api *API) deleteTemplate(rw http.ResponseWriter, r *http.Request) { var ( template = httpmw.TemplateParam(r) + auditor = *api.Auditor.Load() aReq, commitAudit = audit.InitRequest[database.Template](rw, &audit.RequestParams{ - Features: api.FeaturesService, - Log: api.Logger, - Request: r, - Action: database.AuditActionDelete, + Audit: auditor, + Log: api.Logger, + Request: r, + Action: database.AuditActionDelete, }) ) defer commitAudit() @@ -139,17 +140,18 @@ func (api *API) postTemplateByOrganization(rw http.ResponseWriter, r *http.Reque createTemplate codersdk.CreateTemplateRequest organization = httpmw.OrganizationParam(r) apiKey = httpmw.APIKey(r) + auditor = *api.Auditor.Load() templateAudit, commitTemplateAudit = audit.InitRequest[database.Template](rw, &audit.RequestParams{ - Features: api.FeaturesService, - Log: api.Logger, - Request: r, - Action: database.AuditActionCreate, + Audit: auditor, + Log: api.Logger, + Request: r, + Action: database.AuditActionCreate, }) templateVersionAudit, commitTemplateVersionAudit = audit.InitRequest[database.TemplateVersion](rw, &audit.RequestParams{ - Features: api.FeaturesService, - Log: api.Logger, - Request: r, - Action: database.AuditActionWrite, + Audit: auditor, + Log: api.Logger, + Request: r, + Action: database.AuditActionWrite, }) ) defer commitTemplateAudit() @@ -340,7 +342,7 @@ func (api *API) templatesByOrganization(rw http.ResponseWriter, r *http.Request) } // Filter templates based on rbac permissions - templates, err = AuthorizeFilter(api.httpAuth, r, rbac.ActionRead, templates) + templates, err = AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, templates) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching templates.", @@ -435,11 +437,12 @@ func (api *API) templateByOrganizationAndName(rw http.ResponseWriter, r *http.Re func (api *API) patchTemplateMeta(rw http.ResponseWriter, r *http.Request) { var ( template = httpmw.TemplateParam(r) + auditor = *api.Auditor.Load() aReq, commitAudit = audit.InitRequest[database.Template](rw, &audit.RequestParams{ - Features: api.FeaturesService, - Log: api.Logger, - Request: r, - Action: database.AuditActionWrite, + Audit: auditor, + Log: api.Logger, + Request: r, + Action: database.AuditActionWrite, }) ) defer commitAudit() diff --git a/coderd/templateversions.go b/coderd/templateversions.go index ef8a3ff85c94e..69a98e03b352c 100644 --- a/coderd/templateversions.go +++ b/coderd/templateversions.go @@ -559,11 +559,12 @@ func (api *API) templateVersionByName(rw http.ResponseWriter, r *http.Request) { func (api *API) patchActiveTemplateVersion(rw http.ResponseWriter, r *http.Request) { var ( template = httpmw.TemplateParam(r) + auditor = *api.Auditor.Load() aReq, commitAudit = audit.InitRequest[database.Template](rw, &audit.RequestParams{ - Features: api.FeaturesService, - Log: api.Logger, - Request: r, - Action: database.AuditActionWrite, + Audit: auditor, + Log: api.Logger, + Request: r, + Action: database.AuditActionWrite, }) ) defer commitAudit() @@ -631,11 +632,12 @@ func (api *API) postTemplateVersionsByOrganization(rw http.ResponseWriter, r *ht var ( apiKey = httpmw.APIKey(r) organization = httpmw.OrganizationParam(r) + auditor = *api.Auditor.Load() aReq, commitAudit = audit.InitRequest[database.TemplateVersion](rw, &audit.RequestParams{ - Features: api.FeaturesService, - Log: api.Logger, - Request: r, - Action: database.AuditActionCreate, + Audit: auditor, + Log: api.Logger, + Request: r, + Action: database.AuditActionCreate, }) req codersdk.CreateTemplateVersionRequest diff --git a/coderd/users.go b/coderd/users.go index 53e67937b87c3..631e660eb5770 100644 --- a/coderd/users.go +++ b/coderd/users.go @@ -220,7 +220,7 @@ func (api *API) users(rw http.ResponseWriter, r *http.Request) { return } - users, err = AuthorizeFilter(api.httpAuth, r, rbac.ActionRead, users) + users, err = AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, users) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching users.", @@ -255,11 +255,12 @@ func (api *API) users(rw http.ResponseWriter, r *http.Request) { // Creates a new user. func (api *API) postUser(rw http.ResponseWriter, r *http.Request) { + auditor := *api.Auditor.Load() aReq, commitAudit := audit.InitRequest[database.User](rw, &audit.RequestParams{ - Features: api.FeaturesService, - Log: api.Logger, - Request: r, - Action: database.AuditActionCreate, + Audit: auditor, + Log: api.Logger, + Request: r, + Action: database.AuditActionCreate, }) defer commitAudit() @@ -339,12 +340,13 @@ func (api *API) postUser(rw http.ResponseWriter, r *http.Request) { } func (api *API) deleteUser(rw http.ResponseWriter, r *http.Request) { + auditor := *api.Auditor.Load() user := httpmw.UserParam(r) aReq, commitAudit := audit.InitRequest[database.User](rw, &audit.RequestParams{ - Features: api.FeaturesService, - Log: api.Logger, - Request: r, - Action: database.AuditActionDelete, + Audit: auditor, + Log: api.Logger, + Request: r, + Action: database.AuditActionDelete, }) aReq.Old = user defer commitAudit() @@ -414,11 +416,12 @@ func (api *API) userByName(rw http.ResponseWriter, r *http.Request) { func (api *API) putUserProfile(rw http.ResponseWriter, r *http.Request) { var ( user = httpmw.UserParam(r) + auditor = *api.Auditor.Load() aReq, commitAudit = audit.InitRequest[database.User](rw, &audit.RequestParams{ - Features: api.FeaturesService, - Log: api.Logger, - Request: r, - Action: database.AuditActionWrite, + Audit: auditor, + Log: api.Logger, + Request: r, + Action: database.AuditActionWrite, }) ) defer commitAudit() @@ -494,11 +497,12 @@ func (api *API) putUserStatus(status database.UserStatus) func(rw http.ResponseW var ( user = httpmw.UserParam(r) apiKey = httpmw.APIKey(r) + auditor = *api.Auditor.Load() aReq, commitAudit = audit.InitRequest[database.User](rw, &audit.RequestParams{ - Features: api.FeaturesService, - Log: api.Logger, - Request: r, - Action: database.AuditActionWrite, + Audit: auditor, + Log: api.Logger, + Request: r, + Action: database.AuditActionWrite, }) ) defer commitAudit() @@ -560,11 +564,12 @@ func (api *API) putUserPassword(rw http.ResponseWriter, r *http.Request) { var ( user = httpmw.UserParam(r) params codersdk.UpdateUserPasswordRequest + auditor = *api.Auditor.Load() aReq, commitAudit = audit.InitRequest[database.User](rw, &audit.RequestParams{ - Features: api.FeaturesService, - Log: api.Logger, - Request: r, - Action: database.AuditActionWrite, + Audit: auditor, + Log: api.Logger, + Request: r, + Action: database.AuditActionWrite, }) ) defer commitAudit() @@ -673,7 +678,7 @@ func (api *API) userRoles(rw http.ResponseWriter, r *http.Request) { } // Only include ones we can read from RBAC. - memberships, err = AuthorizeFilter(api.httpAuth, r, rbac.ActionRead, memberships) + memberships, err = AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, memberships) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching memberships.", @@ -698,11 +703,12 @@ func (api *API) putUserRoles(rw http.ResponseWriter, r *http.Request) { user = httpmw.UserParam(r) actorRoles = httpmw.UserAuthorization(r) apiKey = httpmw.APIKey(r) + auditor = *api.Auditor.Load() aReq, commitAudit = audit.InitRequest[database.User](rw, &audit.RequestParams{ - Features: api.FeaturesService, - Log: api.Logger, - Request: r, - Action: database.AuditActionWrite, + Audit: auditor, + Log: api.Logger, + Request: r, + Action: database.AuditActionWrite, }) ) defer commitAudit() @@ -812,7 +818,7 @@ func (api *API) organizationsByUser(rw http.ResponseWriter, r *http.Request) { } // Only return orgs the user can read. - organizations, err = AuthorizeFilter(api.httpAuth, r, rbac.ActionRead, organizations) + organizations, err = AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, organizations) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching organizations.", @@ -1176,9 +1182,9 @@ func (api *API) createUser(ctx context.Context, store database.Store, req create func (api *API) setAuthCookie(rw http.ResponseWriter, cookie *http.Cookie) { http.SetCookie(rw, cookie) - devurlCookie := api.applicationCookie(cookie) - if devurlCookie != nil { - http.SetCookie(rw, devurlCookie) + appCookie := api.applicationCookie(cookie) + if appCookie != nil { + http.SetCookie(rw, appCookie) } } diff --git a/coderd/users_test.go b/coderd/users_test.go index 2378adc7d07e1..e23bbc78f381f 100644 --- a/coderd/users_test.go +++ b/coderd/users_test.go @@ -32,7 +32,11 @@ func TestFirstUser(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - _, err := client.CreateFirstUser(ctx, codersdk.CreateFirstUserRequest{}) + has, err := client.HasFirstUser(context.Background()) + require.NoError(t, err) + require.False(t, has) + + _, err = client.CreateFirstUser(ctx, codersdk.CreateFirstUserRequest{}) require.Error(t, err) }) diff --git a/coderd/workspaces.go b/coderd/workspaces.go index 9631499c29089..f3375c6c75e6a 100644 --- a/coderd/workspaces.go +++ b/coderd/workspaces.go @@ -119,7 +119,7 @@ func (api *API) workspaces(rw http.ResponseWriter, r *http.Request) { } // Only return workspaces the user can read - workspaces, err = AuthorizeFilter(api.httpAuth, r, rbac.ActionRead, workspaces) + workspaces, err = AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, workspaces) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching workspaces.", @@ -217,11 +217,12 @@ func (api *API) postWorkspacesByOrganization(rw http.ResponseWriter, r *http.Req var ( organization = httpmw.OrganizationParam(r) apiKey = httpmw.APIKey(r) + auditor = api.Auditor.Load() aReq, commitAudit = audit.InitRequest[database.Workspace](rw, &audit.RequestParams{ - Features: api.FeaturesService, - Log: api.Logger, - Request: r, - Action: database.AuditActionCreate, + Audit: *auditor, + Log: api.Logger, + Request: r, + Action: database.AuditActionCreate, }) ) defer commitAudit() @@ -480,11 +481,12 @@ func (api *API) postWorkspacesByOrganization(rw http.ResponseWriter, r *http.Req func (api *API) patchWorkspace(rw http.ResponseWriter, r *http.Request) { var ( workspace = httpmw.WorkspaceParam(r) + auditor = api.Auditor.Load() aReq, commitAudit = audit.InitRequest[database.Workspace](rw, &audit.RequestParams{ - Features: api.FeaturesService, - Log: api.Logger, - Request: r, - Action: database.AuditActionWrite, + Audit: *auditor, + Log: api.Logger, + Request: r, + Action: database.AuditActionWrite, }) ) defer commitAudit() @@ -556,11 +558,12 @@ func (api *API) patchWorkspace(rw http.ResponseWriter, r *http.Request) { func (api *API) putWorkspaceAutostart(rw http.ResponseWriter, r *http.Request) { var ( workspace = httpmw.WorkspaceParam(r) + auditor = api.Auditor.Load() aReq, commitAudit = audit.InitRequest[database.Workspace](rw, &audit.RequestParams{ - Features: api.FeaturesService, - Log: api.Logger, - Request: r, - Action: database.AuditActionWrite, + Audit: *auditor, + Log: api.Logger, + Request: r, + Action: database.AuditActionWrite, }) ) defer commitAudit() @@ -616,11 +619,12 @@ func (api *API) putWorkspaceAutostart(rw http.ResponseWriter, r *http.Request) { func (api *API) putWorkspaceTTL(rw http.ResponseWriter, r *http.Request) { var ( workspace = httpmw.WorkspaceParam(r) + auditor = api.Auditor.Load() aReq, commitAudit = audit.InitRequest[database.Workspace](rw, &audit.RequestParams{ - Features: api.FeaturesService, - Log: api.Logger, - Request: r, - Action: database.AuditActionWrite, + Audit: *auditor, + Log: api.Logger, + Request: r, + Action: database.AuditActionWrite, }) ) defer commitAudit() diff --git a/cli/features.go b/enterprise/cli/features.go similarity index 88% rename from cli/features.go rename to enterprise/cli/features.go index 8b5d8ba30d680..7a56c6ba2e538 100644 --- a/cli/features.go +++ b/enterprise/cli/features.go @@ -3,12 +3,15 @@ package cli import ( "bytes" "encoding/json" + "errors" "fmt" + "net/http" "strings" "github.com/spf13/cobra" "golang.org/x/xerrors" + agpl "github.com/coder/coder/cli" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/codersdk" ) @@ -36,11 +39,15 @@ func featuresList() *cobra.Command { Use: "list", Aliases: []string{"ls"}, RunE: func(cmd *cobra.Command, args []string) error { - client, err := CreateClient(cmd) + client, err := agpl.CreateClient(cmd) if err != nil { return err } entitlements, err := client.Entitlements(cmd.Context()) + var apiError *codersdk.Error + if errors.As(err, &apiError) && apiError.StatusCode() == http.StatusNotFound { + return xerrors.New("You are on the AGPL licensed version of Coder that does not have Enterprise functionality!") + } if err != nil { return err } diff --git a/cli/features_test.go b/enterprise/cli/features_test.go similarity index 80% rename from cli/features_test.go rename to enterprise/cli/features_test.go index 6c39fec81011a..7f7d13a5180d6 100644 --- a/cli/features_test.go +++ b/enterprise/cli/features_test.go @@ -11,6 +11,8 @@ import ( "github.com/coder/coder/cli/clitest" "github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/codersdk" + "github.com/coder/coder/enterprise/cli" + "github.com/coder/coder/enterprise/coderd/coderdenttest" "github.com/coder/coder/pty/ptytest" ) @@ -18,9 +20,9 @@ func TestFeaturesList(t *testing.T) { t.Parallel() t.Run("Table", func(t *testing.T) { t.Parallel() - client := coderdtest.New(t, nil) + client := coderdenttest.New(t, nil) coderdtest.CreateFirstUser(t, client) - cmd, root := clitest.New(t, "features", "list") + cmd, root := clitest.NewWithSubcommands(t, cli.EnterpriseSubcommands(), "features", "list") clitest.SetupConfig(t, client, root) pty := ptytest.New(t) cmd.SetIn(pty.Input()) @@ -36,9 +38,9 @@ func TestFeaturesList(t *testing.T) { t.Run("JSON", func(t *testing.T) { t.Parallel() - client := coderdtest.New(t, nil) + client := coderdenttest.New(t, nil) coderdtest.CreateFirstUser(t, client) - cmd, root := clitest.New(t, "features", "list", "-o", "json") + cmd, root := clitest.NewWithSubcommands(t, cli.EnterpriseSubcommands(), "features", "list", "-o", "json") clitest.SetupConfig(t, client, root) doneChan := make(chan struct{}) diff --git a/enterprise/cli/licenses_test.go b/enterprise/cli/licenses_test.go index 8a7f2076d56e6..a56e4a73277b6 100644 --- a/enterprise/cli/licenses_test.go +++ b/enterprise/cli/licenses_test.go @@ -23,7 +23,7 @@ import ( "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/codersdk" "github.com/coder/coder/enterprise/cli" - "github.com/coder/coder/enterprise/coderd" + "github.com/coder/coder/enterprise/coderd/coderdenttest" "github.com/coder/coder/pty/ptytest" "github.com/coder/coder/testutil" ) @@ -124,7 +124,7 @@ func TestLicensesAddReal(t *testing.T) { t.Parallel() t.Run("Fails", func(t *testing.T) { t.Parallel() - client := coderdtest.New(t, &coderdtest.Options{APIBuilder: coderd.NewEnterprise}) + client := coderdenttest.New(t, nil) coderdtest.CreateFirstUser(t, client) cmd, root := clitest.NewWithSubcommands(t, cli.EnterpriseSubcommands(), "licenses", "add", "-l", fakeLicenseJWT) @@ -175,7 +175,7 @@ func TestLicensesListReal(t *testing.T) { t.Parallel() t.Run("Empty", func(t *testing.T) { t.Parallel() - client := coderdtest.New(t, &coderdtest.Options{APIBuilder: coderd.NewEnterprise}) + client := coderdenttest.New(t, nil) coderdtest.CreateFirstUser(t, client) cmd, root := clitest.NewWithSubcommands(t, cli.EnterpriseSubcommands(), "licenses", "list") @@ -219,7 +219,7 @@ func TestLicensesDeleteReal(t *testing.T) { t.Parallel() t.Run("Empty", func(t *testing.T) { t.Parallel() - client := coderdtest.New(t, &coderdtest.Options{APIBuilder: coderd.NewEnterprise}) + client := coderdenttest.New(t, nil) coderdtest.CreateFirstUser(t, client) cmd, root := clitest.NewWithSubcommands(t, cli.EnterpriseSubcommands(), "licenses", "delete", "1") diff --git a/enterprise/cli/root.go b/enterprise/cli/root.go index 31546b5d679d0..52decb3266226 100644 --- a/enterprise/cli/root.go +++ b/enterprise/cli/root.go @@ -4,12 +4,12 @@ import ( "github.com/spf13/cobra" agpl "github.com/coder/coder/cli" - "github.com/coder/coder/enterprise/coderd" ) func enterpriseOnly() []*cobra.Command { return []*cobra.Command{ - agpl.Server(coderd.NewEnterprise), + server(), + features(), licenses(), } } diff --git a/enterprise/cli/server.go b/enterprise/cli/server.go new file mode 100644 index 0000000000000..8fd9e542af4cc --- /dev/null +++ b/enterprise/cli/server.go @@ -0,0 +1,33 @@ +package cli + +import ( + "context" + + "github.com/spf13/cobra" + + "github.com/coder/coder/cli/cliflag" + "github.com/coder/coder/enterprise/coderd" + + agpl "github.com/coder/coder/cli" + agplcoderd "github.com/coder/coder/coderd" +) + +func server() *cobra.Command { + var ( + auditLogging bool + ) + cmd := agpl.Server(func(ctx context.Context, options *agplcoderd.Options) (*agplcoderd.API, error) { + api, err := coderd.New(ctx, &coderd.Options{ + AuditLogging: auditLogging, + Options: options, + }) + if err != nil { + return nil, err + } + return api.AGPL, nil + }) + cliflag.BoolVarP(cmd.Flags(), &auditLogging, "audit-logging", "", "CODER_AUDIT_LOGGING", true, + "Specifies whether audit logging is enabled.") + + return cmd +} diff --git a/enterprise/coderd/auth_internal_test.go b/enterprise/coderd/auth_internal_test.go deleted file mode 100644 index 853b6f44c4eda..0000000000000 --- a/enterprise/coderd/auth_internal_test.go +++ /dev/null @@ -1,80 +0,0 @@ -package coderd - -import ( - "context" - "crypto/ed25519" - "crypto/rand" - "fmt" - "net/http" - "testing" - "time" - - "github.com/golang-jwt/jwt/v4" - "github.com/stretchr/testify/require" - - "github.com/coder/coder/coderd/coderdtest" - "github.com/coder/coder/coderd/rbac" - "github.com/coder/coder/codersdk" - "github.com/coder/coder/testutil" -) - -// TestAuthorizeAllEndpoints will check `authorize` is called on every endpoint registered. -// these tests patch the map of license keys, so cannot be run in parallel -// nolint:paralleltest -func TestAuthorizeAllEndpoints(t *testing.T) { - pubKey, privKey, err := ed25519.GenerateKey(rand.Reader) - require.NoError(t, err) - keyID := "testing" - oldKeys := keys - defer func() { - t.Log("restoring keys") - keys = oldKeys - }() - keys = map[string]ed25519.PublicKey{keyID: pubKey} - - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - a := coderdtest.NewAuthTester(ctx, t, &coderdtest.Options{APIBuilder: NewEnterprise}) - - // We need a license in the DB, so that when we call GET api/v2/licenses there is one in the - // list to check authz on. - claims := &Claims{ - RegisteredClaims: jwt.RegisteredClaims{ - Issuer: "test@coder.test", - IssuedAt: jwt.NewNumericDate(time.Now()), - NotBefore: jwt.NewNumericDate(time.Now()), - ExpiresAt: jwt.NewNumericDate(time.Now().Add(2 * time.Hour)), - }, - LicenseExpires: jwt.NewNumericDate(time.Now().Add(time.Hour)), - AccountType: AccountTypeSalesforce, - AccountID: "testing", - Version: CurrentVersion, - Features: Features{ - UserLimit: 0, - AuditLog: 1, - }, - } - lic, err := makeLicense(claims, privKey, keyID) - require.NoError(t, err) - license, err := a.Client.AddLicense(ctx, codersdk.AddLicenseRequest{ - License: lic, - }) - require.NoError(t, err) - a.URLParams["licenses/{id}"] = fmt.Sprintf("licenses/%d", license.ID) - - skipRoutes, assertRoute := coderdtest.AGPLRoutes(a) - assertRoute["POST:/api/v2/licenses"] = coderdtest.RouteCheck{ - AssertAction: rbac.ActionCreate, - AssertObject: rbac.ResourceLicense, - } - assertRoute["GET:/api/v2/licenses"] = coderdtest.RouteCheck{ - StatusCode: http.StatusOK, - AssertAction: rbac.ActionRead, - AssertObject: rbac.ResourceLicense, - } - assertRoute["DELETE:/api/v2/licenses/{id}"] = coderdtest.RouteCheck{ - AssertAction: rbac.ActionDelete, - AssertObject: rbac.ResourceLicense, - } - a.Test(ctx, assertRoute, skipRoutes) -} diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index 598c32f11b367..20140f0e80d83 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -2,48 +2,282 @@ package coderd import ( "context" - "os" - "strings" + "crypto/ed25519" + "fmt" + "net/http" + "sync" + "time" "golang.org/x/xerrors" + "github.com/cenkalti/backoff/v4" + "github.com/go-chi/chi/v5" + + "cdr.dev/slog" "github.com/coder/coder/coderd" - "github.com/coder/coder/coderd/rbac" + agplaudit "github.com/coder/coder/coderd/audit" + "github.com/coder/coder/coderd/httpapi" + "github.com/coder/coder/coderd/httpmw" + "github.com/coder/coder/codersdk" + "github.com/coder/coder/enterprise/audit" + "github.com/coder/coder/enterprise/audit/backends" ) -const EnvAuditLogEnable = "CODER_AUDIT_LOG_ENABLE" +// New constructs an Enterprise coderd API instance. +// This handler is designed to wrap the AGPL Coder code and +// layer Enterprise functionality on top as much as possible. +func New(ctx context.Context, options *Options) (*API, error) { + if options.EntitlementsUpdateInterval == 0 { + options.EntitlementsUpdateInterval = 10 * time.Minute + } + if options.Keys == nil { + options.Keys = Keys + } + ctx, cancelFunc := context.WithCancel(ctx) + api := &API{ + AGPL: coderd.New(options.Options), + Options: options, + + entitlements: entitlements{ + activeUsers: codersdk.Feature{ + Entitlement: codersdk.EntitlementNotEntitled, + Enabled: false, + }, + auditLogs: codersdk.EntitlementNotEntitled, + }, + cancelEntitlementsLoop: cancelFunc, + } + oauthConfigs := &httpmw.OAuth2Configs{ + Github: options.GithubOAuth2Config, + OIDC: options.OIDCConfig, + } + apiKeyMiddleware := httpmw.ExtractAPIKey(options.Database, oauthConfigs, false) + + api.AGPL.APIHandler.Group(func(r chi.Router) { + r.Get("/entitlements", api.serveEntitlements) + r.Route("/licenses", func(r chi.Router) { + r.Use(apiKeyMiddleware) + r.Post("/", api.postLicense) + r.Get("/", api.licenses) + r.Delete("/{id}", api.deleteLicense) + }) + }) + + err := api.updateEntitlements(ctx) + if err != nil { + return nil, xerrors.Errorf("update entitlements: %w", err) + } + go api.runEntitlementsLoop(ctx) + + return api, nil +} + +type Options struct { + *coderd.Options + + AuditLogging bool + EntitlementsUpdateInterval time.Duration + Keys map[string]ed25519.PublicKey +} + +type API struct { + AGPL *coderd.API + *Options + + cancelEntitlementsLoop func() + entitlementsMu sync.RWMutex + entitlements entitlements +} + +type entitlements struct { + hasLicense bool + activeUsers codersdk.Feature + auditLogs codersdk.Entitlement +} + +func (api *API) Close() error { + api.cancelEntitlementsLoop() + return api.AGPL.Close() +} + +func (api *API) updateEntitlements(ctx context.Context) error { + licenses, err := api.Database.GetUnexpiredLicenses(ctx) + if err != nil { + return err + } + api.entitlementsMu.Lock() + defer api.entitlementsMu.Unlock() + now := time.Now() + + // Default all entitlements to be disabled. + entitlements := entitlements{ + hasLicense: false, + activeUsers: codersdk.Feature{ + Enabled: false, + Entitlement: codersdk.EntitlementNotEntitled, + }, + auditLogs: codersdk.EntitlementNotEntitled, + } + + // Here we loop through licenses to detect enabled features. + for _, l := range licenses { + claims, err := validateDBLicense(l, api.Keys) + if err != nil { + api.Logger.Debug(ctx, "skipping invalid license", + slog.F("id", l.ID), slog.Error(err)) + continue + } + entitlements.hasLicense = true + entitlement := codersdk.EntitlementEntitled + if now.After(claims.LicenseExpires.Time) { + // if the grace period were over, the validation fails, so if we are after + // LicenseExpires we must be in grace period. + entitlement = codersdk.EntitlementGracePeriod + } + if claims.Features.UserLimit > 0 { + entitlements.activeUsers = codersdk.Feature{ + Enabled: true, + Entitlement: entitlement, + } + currentLimit := int64(0) + if entitlements.activeUsers.Limit != nil { + currentLimit = *entitlements.activeUsers.Limit + } + limit := max(currentLimit, claims.Features.UserLimit) + entitlements.activeUsers.Limit = &limit + } + if claims.Features.AuditLog > 0 { + entitlements.auditLogs = entitlement + } + } + + if entitlements.auditLogs != api.entitlements.auditLogs { + auditor := agplaudit.NewNop() + // A flag could be added to the options that would allow disabling + // enhanced audit logging here! + if entitlements.auditLogs == codersdk.EntitlementEntitled && api.AuditLogging { + auditor = audit.NewAuditor( + audit.DefaultFilter, + backends.NewPostgres(api.Database, true), + backends.NewSlog(api.Logger), + ) + } + api.AGPL.Auditor.Store(&auditor) + } + + api.entitlements = entitlements + + return nil +} + +func (api *API) serveEntitlements(rw http.ResponseWriter, r *http.Request) { + api.entitlementsMu.RLock() + entitlements := api.entitlements + api.entitlementsMu.RUnlock() -func NewEnterprise(options *coderd.Options) *coderd.API { - var eOpts = *options - if eOpts.Authorizer == nil { - var err error - eOpts.Authorizer, err = rbac.NewAuthorizer() + resp := codersdk.Entitlements{ + Features: make(map[string]codersdk.Feature), + Warnings: make([]string, 0), + HasLicense: entitlements.hasLicense, + } + + if entitlements.activeUsers.Limit != nil { + activeUserCount, err := api.Database.GetActiveUserCount(r.Context()) if err != nil { - // This should never happen, as the unit tests would fail if the - // default built in authorizer failed. - panic(xerrors.Errorf("rego authorize panic: %w", err)) - } - } - eOpts.LicenseHandler = newLicenseAPI( - eOpts.Logger, - eOpts.Database, - eOpts.Pubsub, - &coderd.HTTPAuthorizer{ - Authorizer: eOpts.Authorizer, - Logger: eOpts.Logger, - }).handler() - en := Enablements{AuditLogs: true} - auditLog := os.Getenv(EnvAuditLogEnable) - auditLog = strings.ToLower(auditLog) - if auditLog == "disable" || auditLog == "false" || auditLog == "0" || auditLog == "no" { - en.AuditLogs = false - } - eOpts.FeaturesService = newFeaturesService( - context.Background(), - eOpts.Logger, - eOpts.Database, - eOpts.Pubsub, - en, - ) - return coderd.New(&eOpts) + httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Unable to query database", + Detail: err.Error(), + }) + return + } + entitlements.activeUsers.Actual = &activeUserCount + if activeUserCount > *entitlements.activeUsers.Limit { + resp.Warnings = append(resp.Warnings, + fmt.Sprintf( + "Your deployment has %d active users but is only licensed for %d.", + activeUserCount, *entitlements.activeUsers.Limit)) + } + } + resp.Features[codersdk.FeatureUserLimit] = entitlements.activeUsers + + // Audit logs + resp.Features[codersdk.FeatureAuditLog] = codersdk.Feature{ + Entitlement: entitlements.auditLogs, + Enabled: api.AuditLogging, + } + if entitlements.auditLogs == codersdk.EntitlementGracePeriod && api.AuditLogging { + resp.Warnings = append(resp.Warnings, + "Audit logging is enabled but your license for this feature is expired.") + } + + httpapi.Write(rw, http.StatusOK, resp) +} + +func (api *API) runEntitlementsLoop(ctx context.Context) { + eb := backoff.NewExponentialBackOff() + eb.MaxElapsedTime = 0 // retry indefinitely + b := backoff.WithContext(eb, ctx) + updates := make(chan struct{}, 1) + subscribed := false + + for { + select { + case <-ctx.Done(): + return + default: + // pass + } + if !subscribed { + cancel, err := api.Pubsub.Subscribe(PubsubEventLicenses, func(_ context.Context, _ []byte) { + // don't block. If the channel is full, drop the event, as there is a resync + // scheduled already. + select { + case updates <- struct{}{}: + // pass + default: + // pass + } + }) + if err != nil { + api.Logger.Warn(ctx, "failed to subscribe to license updates", slog.Error(err)) + select { + case <-ctx.Done(): + return + case <-time.After(b.NextBackOff()): + } + continue + } + // nolint: revive + defer cancel() + subscribed = true + api.Logger.Debug(ctx, "successfully subscribed to pubsub") + } + + api.Logger.Info(ctx, "syncing licensed entitlements") + err := api.updateEntitlements(ctx) + if err != nil { + api.Logger.Warn(ctx, "failed to get feature entitlements", slog.Error(err)) + time.Sleep(b.NextBackOff()) + continue + } + b.Reset() + api.Logger.Debug(ctx, "synced licensed entitlements") + + select { + case <-ctx.Done(): + return + case <-time.After(api.EntitlementsUpdateInterval): + continue + case <-updates: + api.Logger.Debug(ctx, "got pubsub update") + continue + } + } +} + +func max(a, b int64) int64 { + if a > b { + return a + } + return b } diff --git a/enterprise/coderd/coderd_test.go b/enterprise/coderd/coderd_test.go new file mode 100644 index 0000000000000..aedd79417be41 --- /dev/null +++ b/enterprise/coderd/coderd_test.go @@ -0,0 +1,204 @@ +package coderd_test + +import ( + "context" + "reflect" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/goleak" + + agplaudit "github.com/coder/coder/coderd/audit" + "github.com/coder/coder/coderd/coderdtest" + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/codersdk" + "github.com/coder/coder/enterprise/audit" + "github.com/coder/coder/enterprise/coderd" + "github.com/coder/coder/enterprise/coderd/coderdenttest" + "github.com/coder/coder/testutil" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} + +func TestEntitlements(t *testing.T) { + t.Parallel() + t.Run("NoLicense", func(t *testing.T) { + t.Parallel() + client := coderdenttest.New(t, nil) + res, err := client.Entitlements(context.Background()) + require.NoError(t, err) + require.False(t, res.HasLicense) + require.Empty(t, res.Warnings) + }) + t.Run("FullLicense", func(t *testing.T) { + t.Parallel() + client := coderdenttest.New(t, nil) + _ = coderdtest.CreateFirstUser(t, client) + coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ + UserLimit: 100, + AuditLog: true, + }) + res, err := client.Entitlements(context.Background()) + require.NoError(t, err) + assert.True(t, res.HasLicense) + ul := res.Features[codersdk.FeatureUserLimit] + assert.Equal(t, codersdk.EntitlementEntitled, ul.Entitlement) + assert.Equal(t, int64(100), *ul.Limit) + assert.Equal(t, int64(1), *ul.Actual) + assert.True(t, ul.Enabled) + al := res.Features[codersdk.FeatureAuditLog] + assert.Equal(t, codersdk.EntitlementEntitled, al.Entitlement) + assert.True(t, al.Enabled) + assert.Nil(t, al.Limit) + assert.Nil(t, al.Actual) + assert.Empty(t, res.Warnings) + }) + t.Run("FullLicenseToNone", func(t *testing.T) { + t.Parallel() + client := coderdenttest.New(t, nil) + _ = coderdtest.CreateFirstUser(t, client) + license := coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ + UserLimit: 100, + AuditLog: true, + }) + res, err := client.Entitlements(context.Background()) + require.NoError(t, err) + assert.True(t, res.HasLicense) + al := res.Features[codersdk.FeatureAuditLog] + assert.Equal(t, codersdk.EntitlementEntitled, al.Entitlement) + assert.True(t, al.Enabled) + + err = client.DeleteLicense(context.Background(), license.ID) + require.NoError(t, err) + + res, err = client.Entitlements(context.Background()) + require.NoError(t, err) + assert.False(t, res.HasLicense) + al = res.Features[codersdk.FeatureAuditLog] + assert.Equal(t, codersdk.EntitlementNotEntitled, al.Entitlement) + assert.True(t, al.Enabled) + }) + t.Run("Warnings", func(t *testing.T) { + t.Parallel() + client := coderdenttest.New(t, nil) + first := coderdtest.CreateFirstUser(t, client) + for i := 0; i < 4; i++ { + coderdtest.CreateAnotherUser(t, client, first.OrganizationID) + } + coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ + UserLimit: 4, + AuditLog: true, + GraceAt: time.Now().Add(-time.Second), + }) + res, err := client.Entitlements(context.Background()) + require.NoError(t, err) + assert.True(t, res.HasLicense) + ul := res.Features[codersdk.FeatureUserLimit] + assert.Equal(t, codersdk.EntitlementGracePeriod, ul.Entitlement) + assert.Equal(t, int64(4), *ul.Limit) + assert.Equal(t, int64(5), *ul.Actual) + assert.True(t, ul.Enabled) + al := res.Features[codersdk.FeatureAuditLog] + assert.Equal(t, codersdk.EntitlementGracePeriod, al.Entitlement) + assert.True(t, al.Enabled) + assert.Nil(t, al.Limit) + assert.Nil(t, al.Actual) + assert.Len(t, res.Warnings, 2) + assert.Contains(t, res.Warnings, + "Your deployment has 5 active users but is only licensed for 4.") + assert.Contains(t, res.Warnings, + "Audit logging is enabled but your license for this feature is expired.") + }) + t.Run("Pubsub", func(t *testing.T) { + t.Parallel() + client, _, api := coderdenttest.NewWithAPI(t, nil) + entitlements, err := client.Entitlements(context.Background()) + require.NoError(t, err) + require.False(t, entitlements.HasLicense) + coderdtest.CreateFirstUser(t, client) + _, err = api.Database.InsertLicense(context.Background(), database.InsertLicenseParams{ + UploadedAt: database.Now(), + Exp: database.Now().AddDate(1, 0, 0), + JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{ + AuditLog: true, + }), + }) + require.NoError(t, err) + err = api.Pubsub.Publish(coderd.PubsubEventLicenses, []byte{}) + require.NoError(t, err) + require.Eventually(t, func() bool { + entitlements, err := client.Entitlements(context.Background()) + assert.NoError(t, err) + return entitlements.HasLicense + }, testutil.WaitShort, testutil.IntervalFast) + }) + t.Run("Resync", func(t *testing.T) { + t.Parallel() + client, _, api := coderdenttest.NewWithAPI(t, &coderdenttest.Options{ + EntitlementsUpdateInterval: 25 * time.Millisecond, + }) + entitlements, err := client.Entitlements(context.Background()) + require.NoError(t, err) + require.False(t, entitlements.HasLicense) + coderdtest.CreateFirstUser(t, client) + // Valid + _, err = api.Database.InsertLicense(context.Background(), database.InsertLicenseParams{ + UploadedAt: database.Now(), + Exp: database.Now().AddDate(1, 0, 0), + JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{ + AuditLog: true, + }), + }) + require.NoError(t, err) + // Expired + _, err = api.Database.InsertLicense(context.Background(), database.InsertLicenseParams{ + UploadedAt: database.Now(), + Exp: database.Now().AddDate(-1, 0, 0), + JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{ + ExpiresAt: database.Now().AddDate(-1, 0, 0), + }), + }) + require.NoError(t, err) + // Invalid + _, err = api.Database.InsertLicense(context.Background(), database.InsertLicenseParams{ + UploadedAt: database.Now(), + Exp: database.Now().AddDate(1, 0, 0), + JWT: "invalid", + }) + require.NoError(t, err) + require.Eventually(t, func() bool { + entitlements, err := client.Entitlements(context.Background()) + assert.NoError(t, err) + return entitlements.HasLicense + }, testutil.WaitShort, testutil.IntervalFast) + }) +} + +func TestAuditLogging(t *testing.T) { + t.Parallel() + t.Run("Enabled", func(t *testing.T) { + t.Parallel() + client, _, api := coderdenttest.NewWithAPI(t, nil) + coderdtest.CreateFirstUser(t, client) + coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ + AuditLog: true, + }) + auditor := *api.AGPL.Auditor.Load() + ea := audit.NewAuditor(audit.DefaultFilter) + t.Logf("%T = %T", auditor, ea) + assert.Equal(t, reflect.ValueOf(ea).Type(), reflect.ValueOf(auditor).Type()) + }) + t.Run("Disabled", func(t *testing.T) { + t.Parallel() + client, _, api := coderdenttest.NewWithAPI(t, nil) + coderdtest.CreateFirstUser(t, client) + auditor := *api.AGPL.Auditor.Load() + ea := agplaudit.NewNop() + t.Logf("%T = %T", auditor, ea) + assert.Equal(t, reflect.ValueOf(ea).Type(), reflect.ValueOf(auditor).Type()) + }) +} diff --git a/enterprise/coderd/coderdenttest/coderdenttest.go b/enterprise/coderd/coderdenttest/coderdenttest.go new file mode 100644 index 0000000000000..572b858bea31f --- /dev/null +++ b/enterprise/coderd/coderdenttest/coderdenttest.go @@ -0,0 +1,133 @@ +package coderdenttest + +import ( + "context" + "crypto/ed25519" + "crypto/rand" + "io" + "testing" + "time" + + "github.com/golang-jwt/jwt/v4" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/coderd/coderdtest" + "github.com/coder/coder/codersdk" + "github.com/coder/coder/enterprise/coderd" +) + +const ( + testKeyID = "enterprise-test" +) + +var ( + testPrivateKey ed25519.PrivateKey + testPublicKey ed25519.PublicKey +) + +func init() { + var err error + testPublicKey, testPrivateKey, err = ed25519.GenerateKey(rand.Reader) + if err != nil { + panic(err) + } +} + +type Options struct { + *coderdtest.Options + EntitlementsUpdateInterval time.Duration +} + +// New constructs a codersdk client connected to an in-memory Enterprise API instance. +func New(t *testing.T, options *Options) *codersdk.Client { + client, _, _ := NewWithAPI(t, options) + return client +} + +func NewWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *coderd.API) { + if options == nil { + options = &Options{} + } + if options.Options == nil { + options.Options = &coderdtest.Options{} + } + srv, cancelFunc, oop := coderdtest.NewOptions(t, options.Options) + coderAPI, err := coderd.New(context.Background(), &coderd.Options{ + AuditLogging: true, + Options: oop, + EntitlementsUpdateInterval: options.EntitlementsUpdateInterval, + Keys: map[string]ed25519.PublicKey{ + testKeyID: testPublicKey, + }, + }) + assert.NoError(t, err) + srv.Config.Handler = coderAPI.AGPL.RootHandler + var provisionerCloser io.Closer = nopcloser{} + if options.IncludeProvisionerDaemon { + provisionerCloser = coderdtest.NewProvisionerDaemon(t, coderAPI.AGPL) + } + t.Cleanup(func() { + cancelFunc() + _ = provisionerCloser.Close() + _ = coderAPI.Close() + }) + return codersdk.New(coderAPI.AccessURL), provisionerCloser, coderAPI +} + +type LicenseOptions struct { + AccountType string + AccountID string + GraceAt time.Time + ExpiresAt time.Time + UserLimit int64 + AuditLog bool +} + +// AddLicense generates a new license with the options provided and inserts it. +func AddLicense(t *testing.T, client *codersdk.Client, options LicenseOptions) codersdk.License { + license, err := client.AddLicense(context.Background(), codersdk.AddLicenseRequest{ + License: GenerateLicense(t, options), + }) + require.NoError(t, err) + return license +} + +// GenerateLicense returns a signed JWT using the test key. +func GenerateLicense(t *testing.T, options LicenseOptions) string { + if options.ExpiresAt.IsZero() { + options.ExpiresAt = time.Now().Add(time.Hour) + } + if options.GraceAt.IsZero() { + options.GraceAt = time.Now().Add(time.Hour) + } + auditLog := int64(0) + if options.AuditLog { + auditLog = 1 + } + c := &coderd.Claims{ + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: "test@testing.test", + ExpiresAt: jwt.NewNumericDate(options.ExpiresAt), + NotBefore: jwt.NewNumericDate(time.Now().Add(-time.Minute)), + IssuedAt: jwt.NewNumericDate(time.Now().Add(-time.Minute)), + }, + LicenseExpires: jwt.NewNumericDate(options.GraceAt), + AccountType: options.AccountType, + AccountID: options.AccountID, + Version: coderd.CurrentVersion, + Features: coderd.Features{ + UserLimit: options.UserLimit, + AuditLog: auditLog, + }, + } + tok := jwt.NewWithClaims(jwt.SigningMethodEdDSA, c) + tok.Header[coderd.HeaderKeyID] = testKeyID + signedTok, err := tok.SignedString(testPrivateKey) + require.NoError(t, err) + return signedTok +} + +type nopcloser struct{} + +func (nopcloser) Close() error { return nil } diff --git a/enterprise/coderd/coderdenttest/coderdenttest_test.go b/enterprise/coderd/coderdenttest/coderdenttest_test.go new file mode 100644 index 0000000000000..ccea80cf9b968 --- /dev/null +++ b/enterprise/coderd/coderdenttest/coderdenttest_test.go @@ -0,0 +1,51 @@ +package coderdenttest_test + +import ( + "context" + "fmt" + "net/http" + "testing" + + "github.com/coder/coder/coderd/coderdtest" + "github.com/coder/coder/coderd/rbac" + "github.com/coder/coder/enterprise/coderd/coderdenttest" +) + +func TestNew(t *testing.T) { + t.Parallel() + _ = coderdenttest.New(t, nil) +} + +func TestAuthorizeAllEndpoints(t *testing.T) { + t.Parallel() + client, _, api := coderdenttest.NewWithAPI(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + Authorizer: &coderdtest.RecordingAuthorizer{}, + IncludeProvisionerDaemon: true, + }, + }) + admin := coderdtest.CreateFirstUser(t, client) + license := coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{}) + a := coderdtest.NewAuthTester(context.Background(), t, client, api.AGPL, admin) + a.URLParams["licenses/{id}"] = fmt.Sprintf("licenses/%d", license.ID) + + skipRoutes, assertRoute := coderdtest.AGPLRoutes(a) + assertRoute["GET:/api/v2/entitlements"] = coderdtest.RouteCheck{ + NoAuthorize: true, + } + assertRoute["POST:/api/v2/licenses"] = coderdtest.RouteCheck{ + AssertAction: rbac.ActionCreate, + AssertObject: rbac.ResourceLicense, + } + assertRoute["GET:/api/v2/licenses"] = coderdtest.RouteCheck{ + StatusCode: http.StatusOK, + AssertAction: rbac.ActionRead, + AssertObject: rbac.ResourceLicense, + } + assertRoute["DELETE:/api/v2/licenses/{id}"] = coderdtest.RouteCheck{ + AssertAction: rbac.ActionDelete, + AssertObject: rbac.ResourceLicense, + } + + a.Test(context.Background(), assertRoute, skipRoutes) +} diff --git a/enterprise/coderd/features.go b/enterprise/coderd/features.go deleted file mode 100644 index bc9977ff18441..0000000000000 --- a/enterprise/coderd/features.go +++ /dev/null @@ -1,327 +0,0 @@ -package coderd - -import ( - "context" - "crypto/ed25519" - "fmt" - "net/http" - "reflect" - "sync" - "time" - - "github.com/coder/coder/enterprise/audit/backends" - - "github.com/cenkalti/backoff/v4" - "golang.org/x/xerrors" - - "cdr.dev/slog" - - agpl "github.com/coder/coder/coderd" - agplAudit "github.com/coder/coder/coderd/audit" - "github.com/coder/coder/coderd/database" - "github.com/coder/coder/coderd/features" - "github.com/coder/coder/coderd/httpapi" - "github.com/coder/coder/codersdk" - "github.com/coder/coder/enterprise/audit" -) - -type Enablements struct { - AuditLogs bool -} - -type featuresService struct { - logger slog.Logger - database database.Store - pubsub database.Pubsub - keys map[string]ed25519.PublicKey - enablements Enablements - resyncInterval time.Duration - // enabledImplementations includes an "enabled" implementation of every feature. This is - // initialized at start of day and remains static. The consequence of this is that these things - // are hanging around using memory even if not licensed or in use, but it greatly simplifies the - // logic because we don't have to bother creating and destroying them as entitlements change. - // If we have a particularly memory-hungry feature in future, we might wish to reconsider this - // choice. - enabledImplementations agpl.FeatureInterfaces - - mu sync.RWMutex - entitlements entitlements -} - -// newFeaturesService creates a FeaturesService and starts it. It will continue running for the -// duration of the passed ctx. -func newFeaturesService( - ctx context.Context, - logger slog.Logger, - db database.Store, - pubsub database.Pubsub, - enablements Enablements, -) features.Service { - fs := &featuresService{ - logger: logger, - database: db, - pubsub: pubsub, - keys: keys, - enablements: enablements, - enabledImplementations: agpl.FeatureInterfaces{ - Auditor: audit.NewAuditor( - audit.DefaultFilter, - backends.NewPostgres(db, true), - backends.NewSlog(logger), - ), - }, - resyncInterval: 10 * time.Minute, - entitlements: entitlements{ - activeUsers: numericalEntitlement{ - entitlementLimit: entitlementLimit{ - unlimited: true, - }, - }, - }, - } - go fs.syncEntitlements(ctx) - return fs -} - -func (s *featuresService) EntitlementsAPI(rw http.ResponseWriter, r *http.Request) { - s.mu.RLock() - e := s.entitlements - s.mu.RUnlock() - - resp := codersdk.Entitlements{ - Features: make(map[string]codersdk.Feature), - Warnings: make([]string, 0), - HasLicense: e.hasLicense, - } - - // User limit - uf := codersdk.Feature{ - Entitlement: e.activeUsers.state.toSDK(), - Enabled: true, - } - if !e.activeUsers.unlimited { - n, err := s.database.GetActiveUserCount(r.Context()) - if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Unable to query database", - Detail: err.Error(), - }) - return - } - uf.Actual = &n - uf.Limit = &e.activeUsers.limit - if n > e.activeUsers.limit { - resp.Warnings = append(resp.Warnings, - fmt.Sprintf( - "Your deployment has %d active users but is only licensed for %d.", - n, e.activeUsers.limit)) - } - } - resp.Features[codersdk.FeatureUserLimit] = uf - - // Audit logs - resp.Features[codersdk.FeatureAuditLog] = codersdk.Feature{ - Entitlement: e.auditLogs.state.toSDK(), - Enabled: s.enablements.AuditLogs, - } - if e.auditLogs.state == gracePeriod && s.enablements.AuditLogs { - resp.Warnings = append(resp.Warnings, - "Audit logging is enabled but your license for this feature is expired.") - } - - httpapi.Write(rw, http.StatusOK, resp) -} - -type entitlementState int - -const ( - notEntitled entitlementState = iota - gracePeriod - entitled -) - -type entitlementLimit struct { - unlimited bool - limit int64 -} - -type entitlement struct { - state entitlementState -} - -func (s entitlementState) toSDK() codersdk.Entitlement { - switch s { - case notEntitled: - return codersdk.EntitlementNotEntitled - case gracePeriod: - return codersdk.EntitlementGracePeriod - case entitled: - return codersdk.EntitlementEntitled - default: - panic("unknown entitlementState") - } -} - -type numericalEntitlement struct { - entitlement - entitlementLimit -} - -type entitlements struct { - hasLicense bool - activeUsers numericalEntitlement - auditLogs entitlement -} - -func (s *featuresService) getEntitlements(ctx context.Context) (entitlements, error) { - licenses, err := s.database.GetUnexpiredLicenses(ctx) - if err != nil { - return entitlements{}, err - } - now := time.Now() - e := entitlements{ - activeUsers: numericalEntitlement{ - entitlementLimit: entitlementLimit{ - unlimited: true, - }, - }, - } - for _, l := range licenses { - claims, err := validateDBLicense(l, s.keys) - if err != nil { - s.logger.Debug(ctx, "skipping invalid license", - slog.F("id", l.ID), slog.Error(err)) - continue - } - e.hasLicense = true - thisEntitlement := entitled - if now.After(claims.LicenseExpires.Time) { - // if the grace period were over, the validation fails, so if we are after - // LicenseExpires we must be in grace period. - thisEntitlement = gracePeriod - } - if claims.Features.UserLimit > 0 { - e.activeUsers.state = thisEntitlement - e.activeUsers.unlimited = false - e.activeUsers.limit = max(e.activeUsers.limit, claims.Features.UserLimit) - } - if claims.Features.AuditLog > 0 { - e.auditLogs.state = thisEntitlement - } - } - return e, nil -} - -func (s *featuresService) syncEntitlements(ctx context.Context) { - eb := backoff.NewExponentialBackOff() - eb.MaxElapsedTime = 0 // retry indefinitely - b := backoff.WithContext(eb, ctx) - updates := make(chan struct{}, 1) - subscribed := false - - for { - select { - case <-ctx.Done(): - return - default: - // pass - } - if !subscribed { - cancel, err := s.pubsub.Subscribe(PubSubEventLicenses, func(_ context.Context, _ []byte) { - // don't block. If the channel is full, drop the event, as there is a resync - // scheduled already. - select { - case updates <- struct{}{}: - // pass - default: - // pass - } - }) - if err != nil { - s.logger.Warn(ctx, "failed to subscribe to license updates", slog.Error(err)) - time.Sleep(b.NextBackOff()) - continue - } - // nolint: revive - defer cancel() - subscribed = true - s.logger.Debug(ctx, "successfully subscribed to pubsub") - } - - s.logger.Info(ctx, "syncing licensed entitlements") - ents, err := s.getEntitlements(ctx) - if err != nil { - s.logger.Warn(ctx, "failed to get feature entitlements", slog.Error(err)) - time.Sleep(b.NextBackOff()) - continue - } - b.Reset() - - s.mu.Lock() - s.entitlements = ents - s.mu.Unlock() - s.logger.Debug(ctx, "synced licensed entitlements") - - select { - case <-ctx.Done(): - return - case <-time.After(s.resyncInterval): - continue - case <-updates: - s.logger.Debug(ctx, "got pubsub update") - continue - } - } -} - -func max(a, b int64) int64 { - if a > b { - return a - } - return b -} - -func (s *featuresService) Get(ps any) error { - if reflect.TypeOf(ps).Kind() != reflect.Pointer { - return xerrors.New("input must be pointer to struct") - } - vs := reflect.ValueOf(ps).Elem() - if vs.Kind() != reflect.Struct { - return xerrors.New("input must be pointer to struct") - } - // grab a local copy of entitlements so that we have a consistent set, but aren't keeping it - // locked from updates while we process. - s.mu.RLock() - ent := s.entitlements - s.mu.RUnlock() - - for i := 0; i < vs.NumField(); i++ { - vf := vs.Field(i) - tf := vf.Type() - if tf.Kind() != reflect.Interface { - return xerrors.Errorf("fields of input struct must be interfaces: %s", tf.String()) - } - - err := s.setImplementation(ent, vf, tf) - if err != nil { - return err - } - } - return nil -} - -func (s *featuresService) setImplementation(ent entitlements, vf reflect.Value, tf reflect.Type) error { - // c.f. https://stackoverflow.com/questions/7132848/how-to-get-the-reflect-type-of-an-interface - switch tf { - case reflect.TypeOf((*agplAudit.Auditor)(nil)).Elem(): - // Audit logging - if !s.enablements.AuditLogs || ent.auditLogs.state == notEntitled { - vf.Set(reflect.ValueOf(agpl.DisabledImplementations.Auditor)) - return nil - } - vf.Set(reflect.ValueOf(s.enabledImplementations.Auditor)) - return nil - default: - return xerrors.Errorf("unable to find implementation of interface %s", tf.String()) - } -} diff --git a/enterprise/coderd/features_internal_test.go b/enterprise/coderd/features_internal_test.go deleted file mode 100644 index a195c2ffe784b..0000000000000 --- a/enterprise/coderd/features_internal_test.go +++ /dev/null @@ -1,545 +0,0 @@ -package coderd - -import ( - "context" - "crypto/ed25519" - "crypto/rand" - "encoding/json" - "net/http" - "net/http/httptest" - "reflect" - "testing" - "time" - - "github.com/golang-jwt/jwt/v4" - "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "cdr.dev/slog/sloggers/slogtest" - - agplCoderd "github.com/coder/coder/coderd" - agplAudit "github.com/coder/coder/coderd/audit" - "github.com/coder/coder/coderd/database" - "github.com/coder/coder/coderd/database/databasefake" - "github.com/coder/coder/coderd/features" - "github.com/coder/coder/codersdk" - "github.com/coder/coder/enterprise/audit" - "github.com/coder/coder/enterprise/audit/backends" - "github.com/coder/coder/testutil" -) - -func TestFeaturesService_EntitlementsAPI(t *testing.T) { - t.Parallel() - logger := slogtest.Make(t, nil) - - // Note that these are not actually used because we don't run the syncEntitlements - // routine in this test. - pubsub := database.NewPubsubInMemory() - pub, _, err := ed25519.GenerateKey(rand.Reader) - require.NoError(t, err) - keyID := "testing" - - t.Run("NoLicense", func(t *testing.T) { - t.Parallel() - db := databasefake.New() - uut := &featuresService{ - logger: logger, - database: db, - pubsub: pubsub, - keys: map[string]ed25519.PublicKey{keyID: pub}, - enablements: Enablements{AuditLogs: true}, - entitlements: entitlements{ - hasLicense: false, - activeUsers: numericalEntitlement{ - entitlement{notEntitled}, - entitlementLimit{ - unlimited: true, - }, - }, - auditLogs: entitlement{notEntitled}, - }, - } - result := requestEntitlements(t, uut) - assert.False(t, result.HasLicense) - assert.Empty(t, result.Warnings) - assert.Equal(t, codersdk.EntitlementNotEntitled, result.Features[codersdk.FeatureUserLimit].Entitlement) - assert.Equal(t, codersdk.EntitlementNotEntitled, result.Features[codersdk.FeatureAuditLog].Entitlement) - }) - - t.Run("FullLicense", func(t *testing.T) { - t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - defer cancel() - db := databasefake.New() - uut := &featuresService{ - logger: logger, - database: db, - pubsub: pubsub, - keys: map[string]ed25519.PublicKey{keyID: pub}, - enablements: Enablements{AuditLogs: true}, - entitlements: entitlements{ - hasLicense: true, - activeUsers: numericalEntitlement{ - entitlement{entitled}, - entitlementLimit{ - unlimited: false, - limit: 100, - }, - }, - auditLogs: entitlement{entitled}, - }, - } - _, err := db.InsertUser(ctx, database.InsertUserParams{ - ID: uuid.UUID{}, - Email: "", - Username: "", - HashedPassword: nil, - CreatedAt: time.Time{}, - UpdatedAt: time.Time{}, - RBACRoles: nil, - LoginType: "", - }) - require.NoError(t, err) - result := requestEntitlements(t, uut) - assert.True(t, result.HasLicense) - ul := result.Features[codersdk.FeatureUserLimit] - assert.Equal(t, codersdk.EntitlementEntitled, ul.Entitlement) - assert.Equal(t, int64(100), *ul.Limit) - assert.Equal(t, int64(1), *ul.Actual) - assert.True(t, ul.Enabled) - al := result.Features[codersdk.FeatureAuditLog] - assert.Equal(t, codersdk.EntitlementEntitled, al.Entitlement) - assert.True(t, al.Enabled) - assert.Nil(t, al.Limit) - assert.Nil(t, al.Actual) - assert.Empty(t, result.Warnings) - }) - - t.Run("Warnings", func(t *testing.T) { - t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - defer cancel() - db := databasefake.New() - uut := &featuresService{ - logger: logger, - database: db, - pubsub: pubsub, - keys: map[string]ed25519.PublicKey{keyID: pub}, - enablements: Enablements{AuditLogs: true}, - entitlements: entitlements{ - hasLicense: true, - activeUsers: numericalEntitlement{ - entitlement{gracePeriod}, - entitlementLimit{ - unlimited: false, - limit: 4, - }, - }, - auditLogs: entitlement{gracePeriod}, - }, - } - for i := byte(0); i < 5; i++ { - _, err := db.InsertUser(ctx, database.InsertUserParams{ - ID: uuid.UUID{i}, - Email: "", - Username: "", - HashedPassword: nil, - CreatedAt: time.Time{}, - UpdatedAt: time.Time{}, - RBACRoles: nil, - LoginType: "", - }) - require.NoError(t, err) - } - result := requestEntitlements(t, uut) - assert.True(t, result.HasLicense) - ul := result.Features[codersdk.FeatureUserLimit] - assert.Equal(t, codersdk.EntitlementGracePeriod, ul.Entitlement) - assert.Equal(t, int64(4), *ul.Limit) - assert.Equal(t, int64(5), *ul.Actual) - assert.True(t, ul.Enabled) - al := result.Features[codersdk.FeatureAuditLog] - assert.Equal(t, codersdk.EntitlementGracePeriod, al.Entitlement) - assert.True(t, al.Enabled) - assert.Nil(t, al.Limit) - assert.Nil(t, al.Actual) - assert.Len(t, result.Warnings, 2) - assert.Contains(t, result.Warnings, - "Your deployment has 5 active users but is only licensed for 4.") - assert.Contains(t, result.Warnings, - "Audit logging is enabled but your license for this feature is expired.") - }) -} - -func TestFeaturesServiceSyncEntitlements(t *testing.T) { - t.Parallel() - pub, priv, err := ed25519.GenerateKey(rand.Reader) - require.NoError(t, err) - keyID := "testing" - - // This tests that pubsub updates work by setting the resync interval very long - t.Run("PubSub", func(t *testing.T) { - t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - defer cancel() - logger := slogtest.Make(t, nil) - pubsub := database.NewPubsubInMemory() - db := databasefake.New() - uut := &featuresService{ - logger: logger, - database: db, - pubsub: pubsub, - keys: map[string]ed25519.PublicKey{keyID: pub}, - enablements: Enablements{AuditLogs: true}, - resyncInterval: time.Hour, // no resyncs during test - entitlements: entitlements{}, - } - - _, invalidKey, err := ed25519.GenerateKey(rand.Reader) - require.NoError(t, err) - - // Start of day, 3 licenses, one expired, one invalid - _ = putLicense(ctx, t, db, priv, keyID, 1000, -2*time.Hour, -1*time.Hour) - _ = putLicense(ctx, t, db, invalidKey, "invalid", 900, time.Hour, 2*time.Hour) - l0 := putLicense(ctx, t, db, priv, keyID, 300, time.Hour, 2*time.Hour) - - go uut.syncEntitlements(ctx) - - testutil.Eventually(ctx, t, userLimitIs(uut, 300), testutil.IntervalFast) - - // New license - l1 := putLicense(ctx, t, db, priv, keyID, 305, time.Hour, 2*time.Hour) - err = pubsub.Publish(PubSubEventLicenses, []byte("add")) - require.NoError(t, err) - - // User limit goes up, because 305 > 300 - testutil.Eventually(ctx, t, userLimitIs(uut, 305), testutil.IntervalFast) - - // New license with lower limit - _ = putLicense(ctx, t, db, priv, keyID, 295, time.Hour, 2*time.Hour) - err = pubsub.Publish(PubSubEventLicenses, []byte("add")) - require.NoError(t, err) - - // Need to delete the others before the limit lowers - _, err = db.DeleteLicense(ctx, l1.ID) - require.NoError(t, err) - err = pubsub.Publish(PubSubEventLicenses, []byte("delete")) - require.NoError(t, err) - testutil.Eventually(ctx, t, userLimitIs(uut, 300), testutil.IntervalFast) - - _, err = db.DeleteLicense(ctx, l0.ID) - require.NoError(t, err) - err = pubsub.Publish(PubSubEventLicenses, []byte("delete")) - require.NoError(t, err) - testutil.Eventually(ctx, t, userLimitIs(uut, 295), testutil.IntervalFast) - }) - - // This tests that periodic resyncs work by setting the resync interval very fast and - // not sending any pubsub updates. - t.Run("Resyncs", func(t *testing.T) { - t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - defer cancel() - logger := slogtest.Make(t, nil) - pubsub := database.NewPubsubInMemory() - db := databasefake.New() - uut := &featuresService{ - logger: logger, - database: db, - pubsub: pubsub, - keys: map[string]ed25519.PublicKey{keyID: pub}, - enablements: Enablements{AuditLogs: true}, - resyncInterval: 10 * time.Millisecond, - entitlements: entitlements{}, - } - - _, invalidKey, err := ed25519.GenerateKey(rand.Reader) - require.NoError(t, err) - - // Start of day, 3 licenses, one expired, one invalid - _ = putLicense(ctx, t, db, priv, keyID, 1000, -2*time.Hour, -1*time.Hour) - _ = putLicense(ctx, t, db, invalidKey, "invalid", 900, time.Hour, 2*time.Hour) - l0 := putLicense(ctx, t, db, priv, keyID, 300, time.Hour, 2*time.Hour) - - go uut.syncEntitlements(ctx) - - testutil.Eventually(ctx, t, userLimitIs(uut, 300), testutil.IntervalFast) - - // New license - l1 := putLicense(ctx, t, db, priv, keyID, 305, time.Hour, 2*time.Hour) - - // User limit goes up, because 305 > 300 - testutil.Eventually(ctx, t, userLimitIs(uut, 305), testutil.IntervalFast) - - // New license with lower limit - _ = putLicense(ctx, t, db, priv, keyID, 295, time.Hour, 2*time.Hour) - - // Need to delete the others before the limit lowers - _, err = db.DeleteLicense(ctx, l1.ID) - require.NoError(t, err) - testutil.Eventually(ctx, t, userLimitIs(uut, 300), testutil.IntervalFast) - - _, err = db.DeleteLicense(ctx, l0.ID) - require.NoError(t, err) - testutil.Eventually(ctx, t, userLimitIs(uut, 295), testutil.IntervalFast) - }) -} - -func requestEntitlements(t *testing.T, uut features.Service) codersdk.Entitlements { - t.Helper() - r := httptest.NewRequest("GET", "https://example.com/api/v2/entitlements", nil) - rw := httptest.NewRecorder() - uut.EntitlementsAPI(rw, r) - resp := rw.Result() - defer resp.Body.Close() - assert.Equal(t, http.StatusOK, resp.StatusCode) - dec := json.NewDecoder(resp.Body) - var result codersdk.Entitlements - err := dec.Decode(&result) - require.NoError(t, err) - return result -} - -func putLicense( - ctx context.Context, t *testing.T, db database.Store, - k ed25519.PrivateKey, keyID string, userLimit int64, - timeToGrace, timeToExpire time.Duration, -) database.License { - t.Helper() - c := &Claims{ - RegisteredClaims: jwt.RegisteredClaims{ - Issuer: "test@testing.test", - ExpiresAt: jwt.NewNumericDate(time.Now().Add(timeToExpire)), - NotBefore: jwt.NewNumericDate(time.Now().Add(-time.Minute)), - IssuedAt: jwt.NewNumericDate(time.Now().Add(-time.Minute)), - }, - LicenseExpires: jwt.NewNumericDate(time.Now().Add(timeToGrace)), - Version: CurrentVersion, - Features: Features{ - UserLimit: userLimit, - AuditLog: 1, - }, - } - j, err := makeLicense(c, k, keyID) - require.NoError(t, err) - l, err := db.InsertLicense(ctx, database.InsertLicenseParams{ - UploadedAt: c.IssuedAt.Time, - JWT: j, - Exp: c.ExpiresAt.Time, - }) - require.NoError(t, err) - return l -} - -func userLimitIs(fs *featuresService, limit int64) func(context.Context) bool { - return func(_ context.Context) bool { - fs.mu.RLock() - defer fs.mu.RUnlock() - return fs.entitlements.activeUsers.limit == limit - } -} - -func TestFeaturesServiceGet(t *testing.T) { - t.Parallel() - logger := slogtest.Make(t, nil) - - // Note that these are not actually used because we don't run the syncEntitlements - // routine in this test. - pubsub := database.NewPubsubInMemory() - pub, _, err := ed25519.GenerateKey(rand.Reader) - require.NoError(t, err) - keyID := "testing" - db := databasefake.New() - - t.Run("AuditorOff", func(t *testing.T) { - t.Parallel() - uut := &featuresService{ - logger: logger, - database: db, - pubsub: pubsub, - keys: map[string]ed25519.PublicKey{keyID: pub}, - enablements: Enablements{AuditLogs: true}, - enabledImplementations: agplCoderd.FeatureInterfaces{ - Auditor: audit.NewAuditor(audit.DefaultFilter), - }, - entitlements: entitlements{ - hasLicense: false, - activeUsers: numericalEntitlement{ - entitlement{notEntitled}, - entitlementLimit{ - unlimited: true, - }, - }, - auditLogs: entitlement{notEntitled}, - }, - } - target := struct { - Auditor agplAudit.Auditor - }{} - err := uut.Get(&target) - require.NoError(t, err) - assert.NotNil(t, target.Auditor) - nop := agplAudit.NewNop() - assert.Equal(t, reflect.ValueOf(nop).Type(), reflect.ValueOf(target.Auditor).Type()) - }) - - t.Run("AuditorOn", func(t *testing.T) { - t.Parallel() - uut := &featuresService{ - logger: logger, - database: db, - pubsub: pubsub, - keys: map[string]ed25519.PublicKey{keyID: pub}, - enablements: Enablements{AuditLogs: true}, - enabledImplementations: agplCoderd.FeatureInterfaces{ - Auditor: audit.NewAuditor(audit.DefaultFilter), - }, - entitlements: entitlements{ - hasLicense: false, - activeUsers: numericalEntitlement{ - entitlement{notEntitled}, - entitlementLimit{ - unlimited: true, - }, - }, - auditLogs: entitlement{entitled}, - }, - } - target := struct { - Auditor agplAudit.Auditor - }{} - err := uut.Get(&target) - require.NoError(t, err) - assert.NotNil(t, target.Auditor) - ea := audit.NewAuditor( - audit.DefaultFilter, - backends.NewPostgres(db, true), - backends.NewSlog(logger), - ) - assert.Equal(t, reflect.ValueOf(ea).Type(), reflect.ValueOf(target.Auditor).Type()) - }) - - t.Run("NotPointer", func(t *testing.T) { - t.Parallel() - uut := &featuresService{ - logger: logger, - database: db, - pubsub: pubsub, - keys: map[string]ed25519.PublicKey{keyID: pub}, - enablements: Enablements{AuditLogs: true}, - enabledImplementations: agplCoderd.FeatureInterfaces{ - Auditor: audit.NewAuditor(audit.DefaultFilter), - }, - entitlements: entitlements{ - hasLicense: false, - activeUsers: numericalEntitlement{ - entitlement{notEntitled}, - entitlementLimit{ - unlimited: true, - }, - }, - auditLogs: entitlement{notEntitled}, - }, - } - target := struct { - Auditor agplAudit.Auditor - }{} - err := uut.Get(target) - require.Error(t, err) - assert.Nil(t, target.Auditor) - }) - - t.Run("UnknownInterface", func(t *testing.T) { - t.Parallel() - uut := &featuresService{ - logger: logger, - database: db, - pubsub: pubsub, - keys: map[string]ed25519.PublicKey{keyID: pub}, - enablements: Enablements{AuditLogs: true}, - enabledImplementations: agplCoderd.FeatureInterfaces{ - Auditor: audit.NewAuditor(audit.DefaultFilter), - }, - entitlements: entitlements{ - hasLicense: false, - activeUsers: numericalEntitlement{ - entitlement{notEntitled}, - entitlementLimit{ - unlimited: true, - }, - }, - auditLogs: entitlement{notEntitled}, - }, - } - target := struct { - test testInterface - }{} - err := uut.Get(&target) - require.Error(t, err) - assert.Nil(t, target.test) - }) - - t.Run("PointerToNonStruct", func(t *testing.T) { - t.Parallel() - uut := &featuresService{ - logger: logger, - database: db, - pubsub: pubsub, - keys: map[string]ed25519.PublicKey{keyID: pub}, - enablements: Enablements{AuditLogs: true}, - enabledImplementations: agplCoderd.FeatureInterfaces{ - Auditor: audit.NewAuditor(audit.DefaultFilter), - }, - entitlements: entitlements{ - hasLicense: false, - activeUsers: numericalEntitlement{ - entitlement{notEntitled}, - entitlementLimit{ - unlimited: true, - }, - }, - auditLogs: entitlement{notEntitled}, - }, - } - var target agplAudit.Auditor - err := uut.Get(&target) - require.Error(t, err) - assert.Nil(t, target) - }) - - t.Run("StructWithNonInterfaces", func(t *testing.T) { - t.Parallel() - uut := &featuresService{ - logger: logger, - database: db, - pubsub: pubsub, - keys: map[string]ed25519.PublicKey{keyID: pub}, - enablements: Enablements{AuditLogs: true}, - enabledImplementations: agplCoderd.FeatureInterfaces{ - Auditor: audit.NewAuditor(audit.DefaultFilter), - }, - entitlements: entitlements{ - hasLicense: false, - activeUsers: numericalEntitlement{ - entitlement{notEntitled}, - entitlementLimit{ - unlimited: true, - }, - }, - auditLogs: entitlement{notEntitled}, - }, - } - target := struct { - N int64 - Auditor agplAudit.Auditor - }{} - err := uut.Get(&target) - require.Error(t, err) - assert.Nil(t, target.Auditor) - }) -} - -type testInterface interface { - Test() error -} diff --git a/enterprise/coderd/licenses.go b/enterprise/coderd/licenses.go index 9f75796da19bd..5b8273f2ffe60 100644 --- a/enterprise/coderd/licenses.go +++ b/enterprise/coderd/licenses.go @@ -30,7 +30,8 @@ const ( HeaderKeyID = "kid" AccountTypeSalesforce = "salesforce" VersionClaim = "version" - PubSubEventLicenses = "licenses" + + PubsubEventLicenses = "licenses" ) var ValidMethods = []string{"EdDSA"} @@ -41,7 +42,7 @@ var ValidMethods = []string{"EdDSA"} //go:embed keys/2022-08-12 var key20220812 []byte -var keys = map[string]ed25519.PublicKey{"2022-08-12": ed25519.PublicKey(key20220812)} +var Keys = map[string]ed25519.PublicKey{"2022-08-12": ed25519.PublicKey(key20220812)} type Features struct { UserLimit int64 `json:"user_limit"` @@ -68,96 +69,6 @@ var ( ErrMissingLicenseExpires = xerrors.New("license missing license_expires") ) -// parseLicense parses the license and returns the claims. If the license's signature is invalid or -// is not parsable, an error is returned. -func parseLicense(l string, keys map[string]ed25519.PublicKey) (jwt.MapClaims, error) { - tok, err := jwt.Parse( - l, - keyFunc(keys), - jwt.WithValidMethods(ValidMethods), - ) - if err != nil { - return nil, err - } - if claims, ok := tok.Claims.(jwt.MapClaims); ok && tok.Valid { - version, ok := claims[VersionClaim].(float64) - if !ok { - return nil, ErrInvalidVersion - } - if int64(version) != CurrentVersion { - return nil, ErrInvalidVersion - } - return claims, nil - } - return nil, xerrors.New("unable to parse Claims") -} - -// validateDBLicense validates a database.License record, and if valid, returns the claims. If -// unparsable or invalid, it returns an error -func validateDBLicense(l database.License, keys map[string]ed25519.PublicKey) (*Claims, error) { - tok, err := jwt.ParseWithClaims( - l.JWT, - &Claims{}, - keyFunc(keys), - jwt.WithValidMethods(ValidMethods), - ) - if err != nil { - return nil, err - } - if claims, ok := tok.Claims.(*Claims); ok && tok.Valid { - if claims.Version != uint64(CurrentVersion) { - return nil, ErrInvalidVersion - } - if claims.LicenseExpires == nil { - return nil, ErrMissingLicenseExpires - } - return claims, nil - } - return nil, xerrors.New("unable to parse Claims") -} - -func keyFunc(keys map[string]ed25519.PublicKey) func(*jwt.Token) (interface{}, error) { - return func(j *jwt.Token) (interface{}, error) { - keyID, ok := j.Header[HeaderKeyID].(string) - if !ok { - return nil, ErrMissingKeyID - } - k, ok := keys[keyID] - if !ok { - return nil, xerrors.Errorf("no key with ID %s", keyID) - } - return k, nil - } -} - -// licenseAPI handles enterprise licenses, and attaches to the main coderd.API via the -// LicenseHandler option, so that it serves all routes under /api/v2/licenses -type licenseAPI struct { - router chi.Router - logger slog.Logger - database database.Store - pubsub database.Pubsub - auth *coderd.HTTPAuthorizer -} - -func newLicenseAPI( - l slog.Logger, - db database.Store, - ps database.Pubsub, - auth *coderd.HTTPAuthorizer, -) *licenseAPI { - r := chi.NewRouter() - a := &licenseAPI{router: r, logger: l, database: db, pubsub: ps, auth: auth} - r.Post("/", a.postLicense) - r.Get("/", a.licenses) - r.Delete("/{id}", a.delete) - return a -} - -func (a *licenseAPI) handler() http.Handler { - return a.router -} - // postLicense adds a new Enterprise license to the cluster. We allow multiple different licenses // in the cluster at one time for several reasons: // @@ -167,8 +78,8 @@ func (a *licenseAPI) handler() http.Handler { // we generally don't want the old features to immediately break without warning. With a grace // period on the license, features will continue to work from the old license until its grace // period, then the users will get a warning allowing them to gracefully stop using the feature. -func (a *licenseAPI) postLicense(rw http.ResponseWriter, r *http.Request) { - if !a.auth.Authorize(r, rbac.ActionCreate, rbac.ResourceLicense) { +func (api *API) postLicense(rw http.ResponseWriter, r *http.Request) { + if !api.AGPL.Authorize(r, rbac.ActionCreate, rbac.ResourceLicense) { httpapi.Forbidden(rw) return } @@ -178,7 +89,7 @@ func (a *licenseAPI) postLicense(rw http.ResponseWriter, r *http.Request) { return } - claims, err := parseLicense(addLicense.License, keys) + claims, err := parseLicense(addLicense.License, api.Keys) if err != nil { httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{ Message: "Invalid license", @@ -196,7 +107,7 @@ func (a *licenseAPI) postLicense(rw http.ResponseWriter, r *http.Request) { } expTime := time.Unix(int64(exp), 0) - dl, err := a.database.InsertLicense(r.Context(), database.InsertLicenseParams{ + dl, err := api.Database.InsertLicense(r.Context(), database.InsertLicenseParams{ UploadedAt: database.Now(), JWT: addLicense.License, Exp: expTime, @@ -208,25 +119,25 @@ func (a *licenseAPI) postLicense(rw http.ResponseWriter, r *http.Request) { }) return } - err = a.pubsub.Publish(PubSubEventLicenses, []byte("add")) + err = api.updateEntitlements(r.Context()) if err != nil { - a.logger.Error(context.Background(), "failed to publish license add", slog.Error(err)) + httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to update entitlements", + Detail: err.Error(), + }) + return + } + err = api.Pubsub.Publish(PubsubEventLicenses, []byte("add")) + if err != nil { + api.Logger.Error(context.Background(), "failed to publish license add", slog.Error(err)) // don't fail the HTTP request, since we did write it successfully to the database } httpapi.Write(rw, http.StatusCreated, convertLicense(dl, claims)) } -func convertLicense(dl database.License, c jwt.MapClaims) codersdk.License { - return codersdk.License{ - ID: dl.ID, - UploadedAt: dl.UploadedAt, - Claims: c, - } -} - -func (a *licenseAPI) licenses(rw http.ResponseWriter, r *http.Request) { - licenses, err := a.database.GetLicenses(r.Context()) +func (api *API) licenses(rw http.ResponseWriter, r *http.Request) { + licenses, err := api.Database.GetLicenses(r.Context()) if xerrors.Is(err, sql.ErrNoRows) { httpapi.Write(rw, http.StatusOK, []codersdk.License{}) return @@ -239,7 +150,7 @@ func (a *licenseAPI) licenses(rw http.ResponseWriter, r *http.Request) { return } - licenses, err = coderd.AuthorizeFilter(a.auth, r, rbac.ActionRead, licenses) + licenses, err = coderd.AuthorizeFilter(api.AGPL.HTTPAuth, r, rbac.ActionRead, licenses) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching licenses.", @@ -258,6 +169,59 @@ func (a *licenseAPI) licenses(rw http.ResponseWriter, r *http.Request) { httpapi.Write(rw, http.StatusOK, sdkLicenses) } +func (api *API) deleteLicense(rw http.ResponseWriter, r *http.Request) { + if !api.AGPL.Authorize(r, rbac.ActionDelete, rbac.ResourceLicense) { + httpapi.Forbidden(rw) + return + } + + idStr := chi.URLParam(r, "id") + id, err := strconv.ParseInt(idStr, 10, 32) + if err != nil { + httpapi.Write(rw, http.StatusNotFound, codersdk.Response{ + Message: "License ID must be an integer", + }) + return + } + + _, err = api.Database.DeleteLicense(r.Context(), int32(id)) + if xerrors.Is(err, sql.ErrNoRows) { + httpapi.Write(rw, http.StatusNotFound, codersdk.Response{ + Message: "Unknown license ID", + }) + return + } + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error deleting license", + Detail: err.Error(), + }) + return + } + err = api.updateEntitlements(r.Context()) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to update entitlements", + Detail: err.Error(), + }) + return + } + err = api.Pubsub.Publish(PubsubEventLicenses, []byte("delete")) + if err != nil { + api.Logger.Error(context.Background(), "failed to publish license delete", slog.Error(err)) + // don't fail the HTTP request, since we did write it successfully to the database + } + rw.WriteHeader(http.StatusOK) +} + +func convertLicense(dl database.License, c jwt.MapClaims) codersdk.License { + return codersdk.License{ + ID: dl.ID, + UploadedAt: dl.UploadedAt, + Claims: c, + } +} + func convertLicenses(licenses []database.License) ([]codersdk.License, error) { var out []codersdk.License for _, l := range licenses { @@ -292,40 +256,64 @@ func decodeClaims(l database.License) (jwt.MapClaims, error) { return c, err } -func (a *licenseAPI) delete(rw http.ResponseWriter, r *http.Request) { - if !a.auth.Authorize(r, rbac.ActionDelete, rbac.ResourceLicense) { - httpapi.Forbidden(rw) - return - } - - idStr := chi.URLParam(r, "id") - id, err := strconv.ParseInt(idStr, 10, 32) +// parseLicense parses the license and returns the claims. If the license's signature is invalid or +// is not parsable, an error is returned. +func parseLicense(l string, keys map[string]ed25519.PublicKey) (jwt.MapClaims, error) { + tok, err := jwt.Parse( + l, + keyFunc(keys), + jwt.WithValidMethods(ValidMethods), + ) if err != nil { - httpapi.Write(rw, http.StatusNotFound, codersdk.Response{ - Message: "License ID must be an integer", - }) - return + return nil, err } - - _, err = a.database.DeleteLicense(r.Context(), int32(id)) - if xerrors.Is(err, sql.ErrNoRows) { - httpapi.Write(rw, http.StatusNotFound, codersdk.Response{ - Message: "Unknown license ID", - }) - return + if claims, ok := tok.Claims.(jwt.MapClaims); ok && tok.Valid { + version, ok := claims[VersionClaim].(float64) + if !ok { + return nil, ErrInvalidVersion + } + if int64(version) != CurrentVersion { + return nil, ErrInvalidVersion + } + return claims, nil } + return nil, xerrors.New("unable to parse Claims") +} + +// validateDBLicense validates a database.License record, and if valid, returns the claims. If +// unparsable or invalid, it returns an error +func validateDBLicense(l database.License, keys map[string]ed25519.PublicKey) (*Claims, error) { + tok, err := jwt.ParseWithClaims( + l.JWT, + &Claims{}, + keyFunc(keys), + jwt.WithValidMethods(ValidMethods), + ) if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Internal error deleting license", - Detail: err.Error(), - }) - return + return nil, err } + if claims, ok := tok.Claims.(*Claims); ok && tok.Valid { + if claims.Version != uint64(CurrentVersion) { + return nil, ErrInvalidVersion + } + if claims.LicenseExpires == nil { + return nil, ErrMissingLicenseExpires + } + return claims, nil + } + return nil, xerrors.New("unable to parse Claims") +} - err = a.pubsub.Publish(PubSubEventLicenses, []byte("delete")) - if err != nil { - a.logger.Error(context.Background(), "failed to publish license delete", slog.Error(err)) - // don't fail the HTTP request, since we did write it successfully to the database +func keyFunc(keys map[string]ed25519.PublicKey) func(*jwt.Token) (interface{}, error) { + return func(j *jwt.Token) (interface{}, error) { + keyID, ok := j.Header[HeaderKeyID].(string) + if !ok { + return nil, ErrMissingKeyID + } + k, ok := keys[keyID] + if !ok { + return nil, xerrors.Errorf("no key with ID %s", keyID) + } + return k, nil } - rw.WriteHeader(http.StatusOK) } diff --git a/enterprise/coderd/licenses_internal_test.go b/enterprise/coderd/licenses_internal_test.go deleted file mode 100644 index 5695ca0df5233..0000000000000 --- a/enterprise/coderd/licenses_internal_test.go +++ /dev/null @@ -1,316 +0,0 @@ -package coderd - -import ( - "context" - "crypto/ed25519" - "crypto/rand" - "encoding/json" - "net/http" - "testing" - "time" - - "golang.org/x/xerrors" - - "github.com/stretchr/testify/assert" - - "github.com/golang-jwt/jwt/v4" - "github.com/stretchr/testify/require" - - "github.com/coder/coder/coderd/coderdtest" - "github.com/coder/coder/codersdk" - "github.com/coder/coder/testutil" -) - -// these tests patch the map of license keys, so cannot be run in parallel -// nolint:paralleltest -func TestPostLicense(t *testing.T) { - pubKey, privKey, err := ed25519.GenerateKey(rand.Reader) - require.NoError(t, err) - keyID := "testing" - oldKeys := keys - defer func() { - t.Log("restoring keys") - keys = oldKeys - }() - keys = map[string]ed25519.PublicKey{keyID: pubKey} - - t.Run("POST", func(t *testing.T) { - client := coderdtest.New(t, &coderdtest.Options{APIBuilder: NewEnterprise}) - _ = coderdtest.CreateFirstUser(t, client) - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - - claims := &Claims{ - RegisteredClaims: jwt.RegisteredClaims{ - Issuer: "test@coder.test", - IssuedAt: jwt.NewNumericDate(time.Now()), - NotBefore: jwt.NewNumericDate(time.Now()), - ExpiresAt: jwt.NewNumericDate(time.Now().Add(2 * time.Hour)), - }, - LicenseExpires: jwt.NewNumericDate(time.Now().Add(time.Hour)), - AccountType: AccountTypeSalesforce, - AccountID: "testing", - Version: CurrentVersion, - Features: Features{ - UserLimit: 0, - AuditLog: 1, - }, - } - lic, err := makeLicense(claims, privKey, keyID) - require.NoError(t, err) - - respLic, err := client.AddLicense(ctx, codersdk.AddLicenseRequest{ - License: lic, - }) - require.NoError(t, err) - assert.GreaterOrEqual(t, respLic.ID, int32(0)) - // just a couple spot checks for sanity - assert.Equal(t, claims.AccountID, respLic.Claims["account_id"]) - features, ok := respLic.Claims["features"].(map[string]interface{}) - require.True(t, ok) - assert.Equal(t, json.Number("1"), features[codersdk.FeatureAuditLog]) - }) - - t.Run("POST_unathorized", func(t *testing.T) { - client := coderdtest.New(t, &coderdtest.Options{APIBuilder: NewEnterprise}) - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - defer cancel() - - claims := &Claims{ - RegisteredClaims: jwt.RegisteredClaims{ - Issuer: "test@coder.test", - IssuedAt: jwt.NewNumericDate(time.Now()), - NotBefore: jwt.NewNumericDate(time.Now()), - ExpiresAt: jwt.NewNumericDate(time.Now().Add(2 * time.Hour)), - }, - LicenseExpires: jwt.NewNumericDate(time.Now().Add(time.Hour)), - AccountType: AccountTypeSalesforce, - AccountID: "testing", - Version: CurrentVersion, - Features: Features{ - UserLimit: 0, - AuditLog: 1, - }, - } - lic, err := makeLicense(claims, privKey, keyID) - require.NoError(t, err) - - _, err = client.AddLicense(ctx, codersdk.AddLicenseRequest{ - License: lic, - }) - errResp := &codersdk.Error{} - if xerrors.As(err, &errResp) { - assert.Equal(t, 401, errResp.StatusCode()) - } else { - t.Error("expected to get error status 401") - } - }) - - t.Run("POST_corrupted", func(t *testing.T) { - client := coderdtest.New(t, &coderdtest.Options{APIBuilder: NewEnterprise}) - _ = coderdtest.CreateFirstUser(t, client) - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - defer cancel() - - claims := &Claims{ - RegisteredClaims: jwt.RegisteredClaims{ - Issuer: "test@coder.test", - IssuedAt: jwt.NewNumericDate(time.Now()), - NotBefore: jwt.NewNumericDate(time.Now()), - ExpiresAt: jwt.NewNumericDate(time.Now().Add(2 * time.Hour)), - }, - LicenseExpires: jwt.NewNumericDate(time.Now().Add(time.Hour)), - AccountType: AccountTypeSalesforce, - AccountID: "testing", - Version: CurrentVersion, - Features: Features{ - UserLimit: 0, - AuditLog: 1, - }, - } - lic, err := makeLicense(claims, privKey, keyID) - require.NoError(t, err) - - _, err = client.AddLicense(ctx, codersdk.AddLicenseRequest{ - License: "h" + lic, - }) - errResp := &codersdk.Error{} - if xerrors.As(err, &errResp) { - assert.Equal(t, 400, errResp.StatusCode()) - } else { - t.Error("expected to get error status 400") - } - }) -} - -// these tests patch the map of license keys, so cannot be run in parallel -// nolint:paralleltest -func TestGetLicense(t *testing.T) { - pubKey, privKey, err := ed25519.GenerateKey(rand.Reader) - require.NoError(t, err) - keyID := "testing" - oldKeys := keys - defer func() { - t.Log("restoring keys") - keys = oldKeys - }() - keys = map[string]ed25519.PublicKey{keyID: pubKey} - - t.Run("GET", func(t *testing.T) { - client := coderdtest.New(t, &coderdtest.Options{APIBuilder: NewEnterprise}) - _ = coderdtest.CreateFirstUser(t, client) - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - - claims := &Claims{ - RegisteredClaims: jwt.RegisteredClaims{ - Issuer: "test@coder.test", - IssuedAt: jwt.NewNumericDate(time.Now()), - NotBefore: jwt.NewNumericDate(time.Now()), - ExpiresAt: jwt.NewNumericDate(time.Now().Add(2 * time.Hour)), - }, - LicenseExpires: jwt.NewNumericDate(time.Now().Add(time.Hour)), - AccountType: AccountTypeSalesforce, - AccountID: "testing", - Version: CurrentVersion, - Features: Features{ - UserLimit: 0, - AuditLog: 1, - }, - } - lic, err := makeLicense(claims, privKey, keyID) - require.NoError(t, err) - _, err = client.AddLicense(ctx, codersdk.AddLicenseRequest{ - License: lic, - }) - require.NoError(t, err) - - // 2nd license - claims.AccountID = "testing2" - claims.Features.UserLimit = 200 - lic2, err := makeLicense(claims, privKey, keyID) - require.NoError(t, err) - _, err = client.AddLicense(ctx, codersdk.AddLicenseRequest{ - License: lic2, - }) - require.NoError(t, err) - - licenses, err := client.Licenses(ctx) - require.NoError(t, err) - require.Len(t, licenses, 2) - assert.Equal(t, int32(1), licenses[0].ID) - assert.Equal(t, "testing", licenses[0].Claims["account_id"]) - assert.Equal(t, map[string]interface{}{ - codersdk.FeatureUserLimit: json.Number("0"), - codersdk.FeatureAuditLog: json.Number("1"), - }, licenses[0].Claims["features"]) - assert.Equal(t, int32(2), licenses[1].ID) - assert.Equal(t, "testing2", licenses[1].Claims["account_id"]) - assert.Equal(t, map[string]interface{}{ - codersdk.FeatureUserLimit: json.Number("200"), - codersdk.FeatureAuditLog: json.Number("1"), - }, licenses[1].Claims["features"]) - }) -} - -// these tests patch the map of license keys, so cannot be run in parallel -// nolint:paralleltest -func TestDeleteLicense(t *testing.T) { - pubKey, privKey, err := ed25519.GenerateKey(rand.Reader) - require.NoError(t, err) - keyID := "testing" - oldKeys := keys - defer func() { - t.Log("restoring keys") - keys = oldKeys - }() - keys = map[string]ed25519.PublicKey{keyID: pubKey} - - t.Run("DELETE_empty", func(t *testing.T) { - client := coderdtest.New(t, &coderdtest.Options{APIBuilder: NewEnterprise}) - _ = coderdtest.CreateFirstUser(t, client) - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - - err := client.DeleteLicense(ctx, 1) - errResp := &codersdk.Error{} - if xerrors.As(err, &errResp) { - assert.Equal(t, 404, errResp.StatusCode()) - } else { - t.Error("expected to get error status 404") - } - }) - - t.Run("DELETE_bad_id", func(t *testing.T) { - client := coderdtest.New(t, &coderdtest.Options{APIBuilder: NewEnterprise}) - _ = coderdtest.CreateFirstUser(t, client) - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - - resp, err := client.Request(ctx, http.MethodDelete, "/api/v2/licenses/drivers", nil) - require.NoError(t, err) - assert.Equal(t, http.StatusNotFound, resp.StatusCode) - require.NoError(t, resp.Body.Close()) - }) - - t.Run("DELETE", func(t *testing.T) { - client := coderdtest.New(t, &coderdtest.Options{APIBuilder: NewEnterprise}) - _ = coderdtest.CreateFirstUser(t, client) - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - - claims := &Claims{ - RegisteredClaims: jwt.RegisteredClaims{ - Issuer: "test@coder.test", - IssuedAt: jwt.NewNumericDate(time.Now()), - NotBefore: jwt.NewNumericDate(time.Now()), - ExpiresAt: jwt.NewNumericDate(time.Now().Add(2 * time.Hour)), - }, - LicenseExpires: jwt.NewNumericDate(time.Now().Add(time.Hour)), - AccountType: AccountTypeSalesforce, - AccountID: "testing", - Version: CurrentVersion, - Features: Features{ - UserLimit: 0, - AuditLog: 1, - }, - } - lic, err := makeLicense(claims, privKey, keyID) - require.NoError(t, err) - _, err = client.AddLicense(ctx, codersdk.AddLicenseRequest{ - License: lic, - }) - require.NoError(t, err) - - // 2nd license - claims.AccountID = "testing2" - claims.Features.UserLimit = 200 - lic2, err := makeLicense(claims, privKey, keyID) - require.NoError(t, err) - _, err = client.AddLicense(ctx, codersdk.AddLicenseRequest{ - License: lic2, - }) - require.NoError(t, err) - - licenses, err := client.Licenses(ctx) - require.NoError(t, err) - assert.Len(t, licenses, 2) - for _, l := range licenses { - err = client.DeleteLicense(ctx, l.ID) - require.NoError(t, err) - } - licenses, err = client.Licenses(ctx) - require.NoError(t, err) - assert.Len(t, licenses, 0) - }) -} - -func makeLicense(c *Claims, privateKey ed25519.PrivateKey, keyID string) (string, error) { - tok := jwt.NewWithClaims(jwt.SigningMethodEdDSA, c) - tok.Header[HeaderKeyID] = keyID - signedTok, err := tok.SignedString(privateKey) - if err != nil { - return "", xerrors.Errorf("sign license: %w", err) - } - return signedTok, nil -} diff --git a/enterprise/coderd/licenses_test.go b/enterprise/coderd/licenses_test.go new file mode 100644 index 0000000000000..243898a43ca73 --- /dev/null +++ b/enterprise/coderd/licenses_test.go @@ -0,0 +1,168 @@ +package coderd_test + +import ( + "context" + "encoding/json" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + + "github.com/coder/coder/coderd/coderdtest" + "github.com/coder/coder/codersdk" + "github.com/coder/coder/enterprise/coderd" + "github.com/coder/coder/enterprise/coderd/coderdenttest" + "github.com/coder/coder/testutil" +) + +func TestPostLicense(t *testing.T) { + t.Parallel() + + t.Run("Success", func(t *testing.T) { + t.Parallel() + client := coderdenttest.New(t, nil) + _ = coderdtest.CreateFirstUser(t, client) + respLic := coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ + AccountType: coderd.AccountTypeSalesforce, + AccountID: "testing", + AuditLog: true, + }) + assert.GreaterOrEqual(t, respLic.ID, int32(0)) + // just a couple spot checks for sanity + assert.Equal(t, "testing", respLic.Claims["account_id"]) + features, ok := respLic.Claims["features"].(map[string]interface{}) + require.True(t, ok) + assert.Equal(t, json.Number("1"), features[codersdk.FeatureAuditLog]) + }) + + t.Run("Unauthorized", func(t *testing.T) { + t.Parallel() + client := coderdenttest.New(t, nil) + _, err := client.AddLicense(context.Background(), codersdk.AddLicenseRequest{ + License: "content", + }) + errResp := &codersdk.Error{} + if xerrors.As(err, &errResp) { + assert.Equal(t, 401, errResp.StatusCode()) + } else { + t.Error("expected to get error status 401") + } + }) + + t.Run("Corrupted", func(t *testing.T) { + t.Parallel() + client := coderdenttest.New(t, nil) + _ = coderdtest.CreateFirstUser(t, client) + coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{}) + _, err := client.AddLicense(context.Background(), codersdk.AddLicenseRequest{ + License: "invalid", + }) + errResp := &codersdk.Error{} + if xerrors.As(err, &errResp) { + assert.Equal(t, 400, errResp.StatusCode()) + } else { + t.Error("expected to get error status 400") + } + }) +} + +func TestGetLicense(t *testing.T) { + t.Parallel() + t.Run("Success", func(t *testing.T) { + t.Parallel() + client := coderdenttest.New(t, nil) + _ = coderdtest.CreateFirstUser(t, client) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ + AccountID: "testing", + AuditLog: true, + }) + + coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ + AccountID: "testing2", + AuditLog: true, + UserLimit: 200, + }) + + licenses, err := client.Licenses(ctx) + require.NoError(t, err) + require.Len(t, licenses, 2) + assert.Equal(t, int32(1), licenses[0].ID) + assert.Equal(t, "testing", licenses[0].Claims["account_id"]) + assert.Equal(t, map[string]interface{}{ + codersdk.FeatureUserLimit: json.Number("0"), + codersdk.FeatureAuditLog: json.Number("1"), + }, licenses[0].Claims["features"]) + assert.Equal(t, int32(2), licenses[1].ID) + assert.Equal(t, "testing2", licenses[1].Claims["account_id"]) + assert.Equal(t, map[string]interface{}{ + codersdk.FeatureUserLimit: json.Number("200"), + codersdk.FeatureAuditLog: json.Number("1"), + }, licenses[1].Claims["features"]) + }) +} + +func TestDeleteLicense(t *testing.T) { + t.Parallel() + t.Run("Empty", func(t *testing.T) { + t.Parallel() + client := coderdenttest.New(t, nil) + _ = coderdtest.CreateFirstUser(t, client) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + err := client.DeleteLicense(ctx, 1) + errResp := &codersdk.Error{} + if xerrors.As(err, &errResp) { + assert.Equal(t, 404, errResp.StatusCode()) + } else { + t.Error("expected to get error status 404") + } + }) + + t.Run("BadID", func(t *testing.T) { + t.Parallel() + client := coderdenttest.New(t, nil) + _ = coderdtest.CreateFirstUser(t, client) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + resp, err := client.Request(ctx, http.MethodDelete, "/api/v2/licenses/drivers", nil) + require.NoError(t, err) + assert.Equal(t, http.StatusNotFound, resp.StatusCode) + require.NoError(t, resp.Body.Close()) + }) + + t.Run("Success", func(t *testing.T) { + t.Parallel() + client := coderdenttest.New(t, nil) + _ = coderdtest.CreateFirstUser(t, client) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ + AccountID: "testing", + AuditLog: true, + }) + coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ + AccountID: "testing2", + AuditLog: true, + UserLimit: 200, + }) + + licenses, err := client.Licenses(ctx) + require.NoError(t, err) + assert.Len(t, licenses, 2) + for _, l := range licenses { + err = client.DeleteLicense(ctx, l.ID) + require.NoError(t, err) + } + licenses, err = client.Licenses(ctx) + require.NoError(t, err) + assert.Len(t, licenses, 0) + }) +} diff --git a/flake.nix b/flake.nix index a5e4816b19e63..dfc44b91df36f 100644 --- a/flake.nix +++ b/flake.nix @@ -16,6 +16,7 @@ formatter = pkgs.nixpkgs-fmt; devShells.default = pkgs.mkShell { buildInputs = with pkgs; [ + bash bat drpc.defaultPackage.${system} exa diff --git a/site/src/api/api.ts b/site/src/api/api.ts index ea16315c3896e..d3dae5a17c765 100644 --- a/site/src/api/api.ts +++ b/site/src/api/api.ts @@ -16,6 +16,22 @@ export const hardCodedCSRFCookie = (): string => { return csrfToken } +// defaultEntitlements has a default set of disabled functionality. +export const defaultEntitlements = (): TypesGen.Entitlements => { + const features: TypesGen.Entitlements["features"] = {} + for (const feature in Types.FeatureNames) { + features[feature] = { + enabled: false, + entitlement: "not_entitled", + } + } + return { + features: features, + has_license: false, + warnings: [], + } +} + // Always attach CSRF token to all requests. // In puppeteer the document is undefined. In those cases, just // do nothing. @@ -424,8 +440,15 @@ export const putWorkspaceExtension = async ( } export const getEntitlements = async (): Promise => { - const response = await axios.get("/api/v2/entitlements") - return response.data + try { + const response = await axios.get("/api/v2/entitlements") + return response.data + } catch (error) { + if (axios.isAxiosError(error) && error.response?.status === 404) { + return defaultEntitlements() + } + throw error + } } export const getAuditLogs = async ( diff --git a/site/src/xServices/entitlements/entitlementsXService.ts b/site/src/xServices/entitlements/entitlementsXService.ts index 3eee8a5e43ac6..eb3792bd650e9 100644 --- a/site/src/xServices/entitlements/entitlementsXService.ts +++ b/site/src/xServices/entitlements/entitlementsXService.ts @@ -84,7 +84,7 @@ export const entitlementsMachine = createMachine( }), }, services: { - getEntitlements: () => API.getEntitlements(), + getEntitlements: API.getEntitlements, }, }, )