diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index fc74ae281fca8..ddba75da35269 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -498,7 +498,10 @@ func (api *API) updateEntitlements(ctx context.Context) error { if initial, changed, enabled := featureChanged(codersdk.FeatureTemplateRBAC); shouldUpdate(initial, changed, enabled) { if enabled { - committer := committer{Database: api.Database} + committer := committer{ + Log: api.Logger.Named("quota_committer"), + Database: api.Database, + } ptr := proto.QuotaCommitter(&committer) api.AGPL.QuotaCommitter.Store(&ptr) } else { diff --git a/enterprise/coderd/workspacequota.go b/enterprise/coderd/workspacequota.go index bb25771f6775e..44ea3f302ff37 100644 --- a/enterprise/coderd/workspacequota.go +++ b/enterprise/coderd/workspacequota.go @@ -3,10 +3,12 @@ package coderd import ( "context" "database/sql" + "errors" "net/http" "github.com/google/uuid" - "golang.org/x/xerrors" + + "cdr.dev/slog" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/httpapi" @@ -17,6 +19,7 @@ import ( ) type committer struct { + Log slog.Logger Database database.Store } @@ -28,12 +31,12 @@ func (c *committer) CommitQuota( return nil, err } - build, err := c.Database.GetWorkspaceBuildByJobID(ctx, jobID) + nextBuild, err := c.Database.GetWorkspaceBuildByJobID(ctx, jobID) if err != nil { return nil, err } - workspace, err := c.Database.GetWorkspaceByID(ctx, build.WorkspaceID) + workspace, err := c.Database.GetWorkspaceByID(ctx, nextBuild.WorkspaceID) if err != nil { return nil, err } @@ -58,25 +61,35 @@ func (c *committer) CommitQuota( // If the new build will reduce overall quota consumption, then we // allow it even if the user is over quota. netIncrease := true - previousBuild, err := s.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx, database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams{ + prevBuild, err := s.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx, database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams{ WorkspaceID: workspace.ID, - BuildNumber: build.BuildNumber - 1, + BuildNumber: nextBuild.BuildNumber - 1, }) if err == nil { - if build.DailyCost < previousBuild.DailyCost { - netIncrease = false - } - } else if !xerrors.Is(err, sql.ErrNoRows) { + netIncrease = request.DailyCost >= prevBuild.DailyCost + c.Log.Debug( + ctx, "previous build cost", + slog.F("prev_cost", prevBuild.DailyCost), + slog.F("next_cost", request.DailyCost), + slog.F("net_increase", netIncrease), + ) + } else if !errors.Is(err, sql.ErrNoRows) { return err } newConsumed := int64(request.DailyCost) + consumed if newConsumed > budget && netIncrease { + c.Log.Debug( + ctx, "over quota, rejecting", + slog.F("prev_consumed", consumed), + slog.F("next_consumed", newConsumed), + slog.F("budget", budget), + ) return nil } err = s.UpdateWorkspaceBuildCostByID(ctx, database.UpdateWorkspaceBuildCostByIDParams{ - ID: build.ID, + ID: nextBuild.ID, DailyCost: request.DailyCost, }) if err != nil { diff --git a/enterprise/coderd/workspacequota_test.go b/enterprise/coderd/workspacequota_test.go index a1e80da7c8f75..3119168696a36 100644 --- a/enterprise/coderd/workspacequota_test.go +++ b/enterprise/coderd/workspacequota_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/require" "github.com/coder/coder/v2/coderd/coderdtest" + "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/util/ptr" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/enterprise/coderd/coderdenttest" @@ -31,12 +32,13 @@ func verifyQuota(ctx context.Context, t *testing.T, client *codersdk.Client, con } func TestWorkspaceQuota(t *testing.T) { - // TODO: refactor for new impl - t.Parallel() - t.Run("BlocksBuild", func(t *testing.T) { + // This first test verifies the behavior of creating and deleting workspaces. + // It also tests multi-group quota stacking and the everyone group. + t.Run("CreateDelete", func(t *testing.T) { t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() max := 1 @@ -49,8 +51,6 @@ func TestWorkspaceQuota(t *testing.T) { }, }) coderdtest.NewProvisionerDaemon(t, api.AGPL) - coderdtest.NewProvisionerDaemon(t, api.AGPL) - coderdtest.NewProvisionerDaemon(t, api.AGPL) verifyQuota(ctx, t, client, 0, 0) @@ -157,4 +157,104 @@ func TestWorkspaceQuota(t *testing.T) { verifyQuota(ctx, t, client, 4, 4) require.Equal(t, codersdk.WorkspaceStatusRunning, build.Status) }) + + t.Run("StartStop", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + max := 1 + client, _, api, user := coderdenttest.NewWithAPI(t, &coderdenttest.Options{ + UserWorkspaceQuota: max, + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureTemplateRBAC: 1, + }, + }, + }) + coderdtest.NewProvisionerDaemon(t, api.AGPL) + + verifyQuota(ctx, t, client, 0, 0) + + // Patch the 'Everyone' group to verify its quota allowance is being accounted for. + _, err := client.PatchGroup(ctx, user.OrganizationID, codersdk.PatchGroupRequest{ + QuotaAllowance: ptr.Ref(4), + }) + require.NoError(t, err) + verifyQuota(ctx, t, client, 0, 4) + + stopResp := []*proto.Provision_Response{{ + Type: &proto.Provision_Response_Complete{ + Complete: &proto.Provision_Complete{ + Resources: []*proto.Resource{{ + Name: "example", + Type: "aws_instance", + DailyCost: 1, + }}, + }, + }, + }} + + startResp := []*proto.Provision_Response{{ + Type: &proto.Provision_Response_Complete{ + Complete: &proto.Provision_Complete{ + Resources: []*proto.Resource{{ + Name: "example", + Type: "aws_instance", + DailyCost: 2, + }}, + }, + }, + }} + + version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{ + Parse: echo.ParseComplete, + ProvisionPlanMap: map[proto.WorkspaceTransition][]*proto.Provision_Response{ + proto.WorkspaceTransition_START: startResp, + proto.WorkspaceTransition_STOP: stopResp, + }, + ProvisionApplyMap: map[proto.WorkspaceTransition][]*proto.Provision_Response{ + proto.WorkspaceTransition_START: startResp, + proto.WorkspaceTransition_STOP: stopResp, + }, + }) + + coderdtest.AwaitTemplateVersionJob(t, client, version.ID) + template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) + + // Spin up two workspaces. + var wg sync.WaitGroup + var workspaces []codersdk.Workspace + for i := 0; i < 2; i++ { + workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) + workspaces = append(workspaces, workspace) + build := coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) + assert.Equal(t, codersdk.WorkspaceStatusRunning, build.Status) + } + wg.Wait() + verifyQuota(ctx, t, client, 4, 4) + + // Next one must fail + workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) + build := coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) + require.Contains(t, build.Job.Error, "quota") + + // Consumed shouldn't bump + verifyQuota(ctx, t, client, 4, 4) + require.Equal(t, codersdk.WorkspaceStatusFailed, build.Status) + + build = coderdtest.CreateWorkspaceBuild(t, client, workspaces[0], database.WorkspaceTransitionStop) + build = coderdtest.AwaitWorkspaceBuildJob(t, client, build.ID) + + // Quota goes down one + verifyQuota(ctx, t, client, 3, 4) + require.Equal(t, codersdk.WorkspaceStatusStopped, build.Status) + + build = coderdtest.CreateWorkspaceBuild(t, client, workspaces[0], database.WorkspaceTransitionStart) + build = coderdtest.AwaitWorkspaceBuildJob(t, client, build.ID) + + // Quota goes back up + verifyQuota(ctx, t, client, 4, 4) + require.Equal(t, codersdk.WorkspaceStatusRunning, build.Status) + }) } diff --git a/enterprise/tailnet/workspaceproxy.go b/enterprise/tailnet/workspaceproxy.go index 3011f100e6a5d..3150890c13fa9 100644 --- a/enterprise/tailnet/workspaceproxy.go +++ b/enterprise/tailnet/workspaceproxy.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "net" "time" @@ -26,6 +27,9 @@ func ServeWorkspaceProxy(ctx context.Context, conn net.Conn, ma agpl.MultiAgentC var msg wsproxysdk.CoordinateMessage err := decoder.Decode(&msg) if err != nil { + if errors.Is(err, net.ErrClosed) { + return nil + } return xerrors.Errorf("read json: %w", err) } diff --git a/provisioner/echo/serve.go b/provisioner/echo/serve.go index 50f6cc60b2257..c057254704f3d 100644 --- a/provisioner/echo/serve.go +++ b/provisioner/echo/serve.go @@ -127,19 +127,35 @@ func (e *echo) Provision(stream proto.DRPCProvisioner_ProvisionStream) error { return nil } - for index := 0; ; index++ { +outer: + for i := 0; ; i++ { var extension string if msg.GetPlan() != nil { extension = ".plan.protobuf" } else { extension = ".apply.protobuf" } - path := filepath.Join(config.Directory, fmt.Sprintf("%d.provision"+extension, index)) - _, err := e.filesystem.Stat(path) - if err != nil { - if index == 0 { - // Error if nothing is around to enable failed states. - return xerrors.New("no state") + var ( + path string + pathIndex int + ) + // Try more specific path first, then fallback to generic. + paths := []string{ + filepath.Join(config.Directory, fmt.Sprintf("%d.%s.provision"+extension, i, strings.ToLower(config.GetMetadata().GetWorkspaceTransition().String()))), + filepath.Join(config.Directory, fmt.Sprintf("%d.provision"+extension, i)), + } + for pathIndex, path = range paths { + _, err := e.filesystem.Stat(path) + if err != nil && pathIndex == len(paths)-1 { + // If there are zero messages, something is wrong. + if i == 0 { + // Error if nothing is around to enable failed states. + return xerrors.New("no state") + } + // Otherwise, we're done with the entire provision. + break outer + } else if err != nil { + continue } break } @@ -170,16 +186,28 @@ func (*echo) Shutdown(_ context.Context, _ *proto.Empty) (*proto.Empty, error) { return &proto.Empty{}, nil } +// Responses is a collection of mocked responses to Provision operations. type Responses struct { - Parse []*proto.Parse_Response + Parse []*proto.Parse_Response + + // ProvisionApply and ProvisionPlan are used to mock ALL responses of + // Apply and Plan, regardless of transition. ProvisionApply []*proto.Provision_Response ProvisionPlan []*proto.Provision_Response + + // ProvisionApplyMap and ProvisionPlanMap are used to mock specific + // transition responses. They are prioritized over the generic responses. + ProvisionApplyMap map[proto.WorkspaceTransition][]*proto.Provision_Response + ProvisionPlanMap map[proto.WorkspaceTransition][]*proto.Provision_Response } // Tar returns a tar archive of responses to provisioner operations. func Tar(responses *Responses) ([]byte, error) { if responses == nil { - responses = &Responses{ParseComplete, ProvisionComplete, ProvisionComplete} + responses = &Responses{ + ParseComplete, ProvisionComplete, ProvisionComplete, + nil, nil, + } } if responses.ProvisionPlan == nil { responses.ProvisionPlan = responses.ProvisionApply @@ -187,58 +215,61 @@ func Tar(responses *Responses) ([]byte, error) { var buffer bytes.Buffer writer := tar.NewWriter(&buffer) - for index, response := range responses.Parse { - data, err := protobuf.Marshal(response) + + writeProto := func(name string, message protobuf.Message) error { + data, err := protobuf.Marshal(message) if err != nil { - return nil, err + return err } + err = writer.WriteHeader(&tar.Header{ - Name: fmt.Sprintf("%d.parse.protobuf", index), + Name: name, Size: int64(len(data)), Mode: 0o644, }) if err != nil { - return nil, err + return err } + _, err = writer.Write(data) if err != nil { - return nil, err + return err } + + return nil } - for index, response := range responses.ProvisionApply { - data, err := protobuf.Marshal(response) - if err != nil { - return nil, err - } - err = writer.WriteHeader(&tar.Header{ - Name: fmt.Sprintf("%d.provision.apply.protobuf", index), - Size: int64(len(data)), - Mode: 0o644, - }) + for index, response := range responses.Parse { + err := writeProto(fmt.Sprintf("%d.parse.protobuf", index), response) if err != nil { return nil, err } - _, err = writer.Write(data) + } + for index, response := range responses.ProvisionApply { + err := writeProto(fmt.Sprintf("%d.provision.apply.protobuf", index), response) if err != nil { return nil, err } } for index, response := range responses.ProvisionPlan { - data, err := protobuf.Marshal(response) + err := writeProto(fmt.Sprintf("%d.provision.plan.protobuf", index), response) if err != nil { return nil, err } - err = writer.WriteHeader(&tar.Header{ - Name: fmt.Sprintf("%d.provision.plan.protobuf", index), - Size: int64(len(data)), - Mode: 0o644, - }) - if err != nil { - return nil, err + } + for trans, m := range responses.ProvisionApplyMap { + for i, rs := range m { + err := writeProto(fmt.Sprintf("%d.%s.provision.apply.protobuf", i, strings.ToLower(trans.String())), rs) + if err != nil { + return nil, err + } } - _, err = writer.Write(data) - if err != nil { - return nil, err + } + for trans, m := range responses.ProvisionPlanMap { + for i, rs := range m { + err := writeProto(fmt.Sprintf("%d.%s.provision.plan.protobuf", i, strings.ToLower(trans.String())), rs) + if err != nil { + return nil, err + } } } err := writer.Flush() diff --git a/provisioner/echo/serve_test.go b/provisioner/echo/serve_test.go index 539fab0c57536..01b283f8a55f5 100644 --- a/provisioner/echo/serve_test.go +++ b/provisioner/echo/serve_test.go @@ -112,6 +112,92 @@ func TestEcho(t *testing.T) { complete.GetComplete().Resources[0].Name) }) + t.Run("ProvisionStop", func(t *testing.T) { + t.Parallel() + + // Stop responses should be returned when the workspace is being stopped. + + defaultResponses := []*proto.Provision_Response{{ + Type: &proto.Provision_Response_Complete{ + Complete: &proto.Provision_Complete{ + Resources: []*proto.Resource{{ + Name: "DEFAULT", + }}, + }, + }, + }} + stopResponses := []*proto.Provision_Response{{ + Type: &proto.Provision_Response_Complete{ + Complete: &proto.Provision_Complete{ + Resources: []*proto.Resource{{ + Name: "STOP", + }}, + }, + }, + }} + data, err := echo.Tar(&echo.Responses{ + ProvisionApply: defaultResponses, + ProvisionPlan: defaultResponses, + ProvisionPlanMap: map[proto.WorkspaceTransition][]*proto.Provision_Response{ + proto.WorkspaceTransition_STOP: stopResponses, + }, + ProvisionApplyMap: map[proto.WorkspaceTransition][]*proto.Provision_Response{ + proto.WorkspaceTransition_STOP: stopResponses, + }, + }) + require.NoError(t, err) + + client, err := api.Provision(ctx) + require.NoError(t, err) + + // Do stop. + err = client.Send(&proto.Provision_Request{ + Type: &proto.Provision_Request_Plan{ + Plan: &proto.Provision_Plan{ + Config: &proto.Provision_Config{ + Directory: unpackTar(t, fs, data), + Metadata: &proto.Provision_Metadata{ + WorkspaceTransition: proto.WorkspaceTransition_STOP, + }, + }, + }, + }, + }) + require.NoError(t, err) + + complete, err := client.Recv() + require.NoError(t, err) + require.Equal(t, + stopResponses[0].GetComplete().Resources[0].Name, + complete.GetComplete().Resources[0].Name, + ) + + // Do start. + client, err = api.Provision(ctx) + require.NoError(t, err) + + err = client.Send(&proto.Provision_Request{ + Type: &proto.Provision_Request_Plan{ + Plan: &proto.Provision_Plan{ + Config: &proto.Provision_Config{ + Directory: unpackTar(t, fs, data), + Metadata: &proto.Provision_Metadata{ + WorkspaceTransition: proto.WorkspaceTransition_START, + }, + }, + }, + }, + }) + require.NoError(t, err) + + complete, err = client.Recv() + require.NoError(t, err) + require.Equal(t, + defaultResponses[0].GetComplete().Resources[0].Name, + complete.GetComplete().Resources[0].Name, + ) + }) + t.Run("ProvisionWithLogLevel", func(t *testing.T) { t.Parallel() diff --git a/provisionerd/runner/runner.go b/provisionerd/runner/runner.go index a42ce5c2da375..5911004f98e2e 100644 --- a/provisionerd/runner/runner.go +++ b/provisionerd/runner/runner.go @@ -964,10 +964,11 @@ func (r *Runner) buildWorkspace(ctx context.Context, stage string, req *sdkproto } func (r *Runner) commitQuota(ctx context.Context, resources []*sdkproto.Resource) *proto.FailedJob { + cost := sumDailyCost(resources) r.logger.Debug(ctx, "committing quota", slog.F("resources", resources), + slog.F("cost", cost), ) - cost := sumDailyCost(resources) if cost == 0 { return nil }