From 7f1d7c4240d927f15247bb6bb1c379f2ec3f6161 Mon Sep 17 00:00:00 2001
From: Ammar Bandukwala
Date: Tue, 22 Aug 2023 02:45:00 +0000
Subject: [PATCH] fix: correctly reject quota-violating builds

Due to a logical error in CommitQuota, all workspace Stop->Start operations
were being accepted, regardless of the quota limit. This issue only appeared
after #9201, so this was a minor regression in main for about 3 days.

This PR adds a test to make sure this kind of bug doesn't recur.

To make the new test possible, we give the echo provisioner the ability to
simulate responses to specific transitions.

As a side note, I changed workspaceproxy.go to return a nil value in one case
to pass staticcheck.
---
 enterprise/coderd/coderd.go              |   5 +-
 enterprise/coderd/workspacequota.go      |  33 ++++---
 enterprise/coderd/workspacequota_test.go | 110 +++++++++++++++++++++--
 enterprise/tailnet/workspaceproxy.go     |   4 +
 provisioner/echo/serve.go                | 105 ++++++++++++++--------
 provisioner/echo/serve_test.go           |  86 ++++++++++++++++++
 provisionerd/runner/runner.go            |   3 +-
 7 files changed, 292 insertions(+), 54 deletions(-)

diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go
index fc74ae281fca8..ddba75da35269 100644
--- a/enterprise/coderd/coderd.go
+++ b/enterprise/coderd/coderd.go
@@ -498,7 +498,10 @@ func (api *API) updateEntitlements(ctx context.Context) error {
 
 	if initial, changed, enabled := featureChanged(codersdk.FeatureTemplateRBAC); shouldUpdate(initial, changed, enabled) {
 		if enabled {
-			committer := committer{Database: api.Database}
+			committer := committer{
+				Log:      api.Logger.Named("quota_committer"),
+				Database: api.Database,
+			}
 			ptr := proto.QuotaCommitter(&committer)
 			api.AGPL.QuotaCommitter.Store(&ptr)
 		} else {
diff --git a/enterprise/coderd/workspacequota.go b/enterprise/coderd/workspacequota.go
index bb25771f6775e..44ea3f302ff37 100644
--- a/enterprise/coderd/workspacequota.go
+++ b/enterprise/coderd/workspacequota.go
@@ -3,10 +3,12 @@ package coderd
 import (
 	"context"
 	"database/sql"
+	"errors"
 	"net/http"
 
 	"github.com/google/uuid"
-	"golang.org/x/xerrors"
+
+	"cdr.dev/slog"
 
 	"github.com/coder/coder/v2/coderd/database"
 	"github.com/coder/coder/v2/coderd/httpapi"
@@ -17,6 +19,7 @@ import (
 )
 
 type committer struct {
+	Log      slog.Logger
 	Database database.Store
 }
 
@@ -28,12 +31,12 @@ func (c *committer) CommitQuota(
 		return nil, err
 	}
 
-	build, err := c.Database.GetWorkspaceBuildByJobID(ctx, jobID)
+	nextBuild, err := c.Database.GetWorkspaceBuildByJobID(ctx, jobID)
 	if err != nil {
 		return nil, err
 	}
 
-	workspace, err := c.Database.GetWorkspaceByID(ctx, build.WorkspaceID)
+	workspace, err := c.Database.GetWorkspaceByID(ctx, nextBuild.WorkspaceID)
 	if err != nil {
 		return nil, err
 	}
@@ -58,25 +61,35 @@ func (c *committer) CommitQuota(
 		// If the new build will reduce overall quota consumption, then we
 		// allow it even if the user is over quota.
 		netIncrease := true
-		previousBuild, err := s.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx, database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams{
+		prevBuild, err := s.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx, database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams{
 			WorkspaceID: workspace.ID,
-			BuildNumber: build.BuildNumber - 1,
+			BuildNumber: nextBuild.BuildNumber - 1,
 		})
 		if err == nil {
-			if build.DailyCost < previousBuild.DailyCost {
-				netIncrease = false
-			}
-		} else if !xerrors.Is(err, sql.ErrNoRows) {
+			netIncrease = request.DailyCost >= prevBuild.DailyCost
+			c.Log.Debug(
+				ctx, "previous build cost",
+				slog.F("prev_cost", prevBuild.DailyCost),
+				slog.F("next_cost", request.DailyCost),
+				slog.F("net_increase", netIncrease),
+			)
+		} else if !errors.Is(err, sql.ErrNoRows) {
 			return err
 		}
 
 		newConsumed := int64(request.DailyCost) + consumed
 		if newConsumed > budget && netIncrease {
+			c.Log.Debug(
+				ctx, "over quota, rejecting",
+				slog.F("prev_consumed", consumed),
+				slog.F("next_consumed", newConsumed),
+				slog.F("budget", budget),
+			)
 			return nil
 		}
 
 		err = s.UpdateWorkspaceBuildCostByID(ctx, database.UpdateWorkspaceBuildCostByIDParams{
-			ID:        build.ID,
+			ID:        nextBuild.ID,
 			DailyCost: request.DailyCost,
 		})
 		if err != nil {
diff --git a/enterprise/coderd/workspacequota_test.go b/enterprise/coderd/workspacequota_test.go
index a1e80da7c8f75..3119168696a36 100644
--- a/enterprise/coderd/workspacequota_test.go
+++ b/enterprise/coderd/workspacequota_test.go
@@ -10,6 +10,7 @@ import (
 	"github.com/stretchr/testify/require"
 
 	"github.com/coder/coder/v2/coderd/coderdtest"
+	"github.com/coder/coder/v2/coderd/database"
 	"github.com/coder/coder/v2/coderd/util/ptr"
 	"github.com/coder/coder/v2/codersdk"
 	"github.com/coder/coder/v2/enterprise/coderd/coderdenttest"
@@ -31,12 +32,13 @@ func verifyQuota(ctx context.Context, t *testing.T, client *codersdk.Client, con
 }
 
 func TestWorkspaceQuota(t *testing.T) {
-	// TODO: refactor for new impl
-
 	t.Parallel()
 
-	t.Run("BlocksBuild", func(t *testing.T) {
+	// This first test verifies the behavior of creating and deleting workspaces.
+	// It also tests multi-group quota stacking and the everyone group.
+	t.Run("CreateDelete", func(t *testing.T) {
 		t.Parallel()
+
 		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
 		defer cancel()
 		max := 1
@@ -49,8 +51,6 @@ func TestWorkspaceQuota(t *testing.T) {
 			},
 		})
 		coderdtest.NewProvisionerDaemon(t, api.AGPL)
-		coderdtest.NewProvisionerDaemon(t, api.AGPL)
-		coderdtest.NewProvisionerDaemon(t, api.AGPL)
 
 		verifyQuota(ctx, t, client, 0, 0)
 
@@ -157,4 +157,104 @@ func TestWorkspaceQuota(t *testing.T) {
 		verifyQuota(ctx, t, client, 4, 4)
 		require.Equal(t, codersdk.WorkspaceStatusRunning, build.Status)
 	})
+
+	t.Run("StartStop", func(t *testing.T) {
+		t.Parallel()
+
+		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
+		defer cancel()
+		max := 1
+		client, _, api, user := coderdenttest.NewWithAPI(t, &coderdenttest.Options{
+			UserWorkspaceQuota: max,
+			LicenseOptions: &coderdenttest.LicenseOptions{
+				Features: license.Features{
+					codersdk.FeatureTemplateRBAC: 1,
+				},
+			},
+		})
+		coderdtest.NewProvisionerDaemon(t, api.AGPL)
+
+		verifyQuota(ctx, t, client, 0, 0)
+
+		// Patch the 'Everyone' group to verify its quota allowance is being accounted for.
+		_, err := client.PatchGroup(ctx, user.OrganizationID, codersdk.PatchGroupRequest{
+			QuotaAllowance: ptr.Ref(4),
+		})
+		require.NoError(t, err)
+		verifyQuota(ctx, t, client, 0, 4)
+
+		stopResp := []*proto.Provision_Response{{
+			Type: &proto.Provision_Response_Complete{
+				Complete: &proto.Provision_Complete{
+					Resources: []*proto.Resource{{
+						Name:      "example",
+						Type:      "aws_instance",
+						DailyCost: 1,
+					}},
+				},
+			},
+		}}
+
+		startResp := []*proto.Provision_Response{{
+			Type: &proto.Provision_Response_Complete{
+				Complete: &proto.Provision_Complete{
+					Resources: []*proto.Resource{{
+						Name:      "example",
+						Type:      "aws_instance",
+						DailyCost: 2,
+					}},
+				},
+			},
+		}}
+
+		version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
+			Parse: echo.ParseComplete,
+			ProvisionPlanMap: map[proto.WorkspaceTransition][]*proto.Provision_Response{
+				proto.WorkspaceTransition_START: startResp,
+				proto.WorkspaceTransition_STOP:  stopResp,
+			},
+			ProvisionApplyMap: map[proto.WorkspaceTransition][]*proto.Provision_Response{
+				proto.WorkspaceTransition_START: startResp,
+				proto.WorkspaceTransition_STOP:  stopResp,
+			},
+		})
+
+		coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
+		template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
+
+		// Spin up two workspaces.
+		var wg sync.WaitGroup
+		var workspaces []codersdk.Workspace
+		for i := 0; i < 2; i++ {
+			workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
+			workspaces = append(workspaces, workspace)
+			build := coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
+			assert.Equal(t, codersdk.WorkspaceStatusRunning, build.Status)
+		}
+		wg.Wait()
+		verifyQuota(ctx, t, client, 4, 4)
+
+		// Next one must fail
+		workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
+		build := coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
+		require.Contains(t, build.Job.Error, "quota")
+
+		// Consumed shouldn't bump
+		verifyQuota(ctx, t, client, 4, 4)
+		require.Equal(t, codersdk.WorkspaceStatusFailed, build.Status)
+
+		build = coderdtest.CreateWorkspaceBuild(t, client, workspaces[0], database.WorkspaceTransitionStop)
+		build = coderdtest.AwaitWorkspaceBuildJob(t, client, build.ID)
+
+		// Quota goes down one
+		verifyQuota(ctx, t, client, 3, 4)
+		require.Equal(t, codersdk.WorkspaceStatusStopped, build.Status)
+
+		build = coderdtest.CreateWorkspaceBuild(t, client, workspaces[0], database.WorkspaceTransitionStart)
+		build = coderdtest.AwaitWorkspaceBuildJob(t, client, build.ID)
+
+		// Quota goes back up
+		verifyQuota(ctx, t, client, 4, 4)
+		require.Equal(t, codersdk.WorkspaceStatusRunning, build.Status)
+	})
 }
diff --git a/enterprise/tailnet/workspaceproxy.go b/enterprise/tailnet/workspaceproxy.go
index 3011f100e6a5d..3150890c13fa9 100644
--- a/enterprise/tailnet/workspaceproxy.go
+++ b/enterprise/tailnet/workspaceproxy.go
@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"context"
 	"encoding/json"
+	"errors"
 	"net"
 	"time"
 
@@ -26,6 +27,9 @@ func ServeWorkspaceProxy(ctx context.Context, conn net.Conn, ma agpl.MultiAgentC
 		var msg wsproxysdk.CoordinateMessage
 		err := decoder.Decode(&msg)
 		if err != nil {
+			if errors.Is(err, net.ErrClosed) {
+				return nil
+			}
 			return xerrors.Errorf("read json: %w", err)
 		}
 
diff --git a/provisioner/echo/serve.go b/provisioner/echo/serve.go
index 50f6cc60b2257..c057254704f3d 100644
--- a/provisioner/echo/serve.go
+++ b/provisioner/echo/serve.go
@@ -127,19 +127,35 @@ func (e *echo) Provision(stream proto.DRPCProvisioner_ProvisionStream) error {
 		return nil
 	}
 
-	for index := 0; ; index++ {
+outer:
+	for i := 0; ; i++ {
 		var extension string
 		if msg.GetPlan() != nil {
 			extension = ".plan.protobuf"
 		} else {
 			extension = ".apply.protobuf"
 		}
-		path := filepath.Join(config.Directory, fmt.Sprintf("%d.provision"+extension, index))
-		_, err := e.filesystem.Stat(path)
-		if err != nil {
-			if index == 0 {
-				// Error if nothing is around to enable failed states.
-				return xerrors.New("no state")
+		var (
+			path      string
+			pathIndex int
+		)
+		// Try more specific path first, then fallback to generic.
+		paths := []string{
+			filepath.Join(config.Directory, fmt.Sprintf("%d.%s.provision"+extension, i, strings.ToLower(config.GetMetadata().GetWorkspaceTransition().String()))),
+			filepath.Join(config.Directory, fmt.Sprintf("%d.provision"+extension, i)),
+		}
+		for pathIndex, path = range paths {
+			_, err := e.filesystem.Stat(path)
+			if err != nil && pathIndex == len(paths)-1 {
+				// If there are zero messages, something is wrong.
+				if i == 0 {
+					// Error if nothing is around to enable failed states.
+					return xerrors.New("no state")
+				}
+				// Otherwise, we're done with the entire provision.
+				break outer
+			} else if err != nil {
+				continue
 			}
 			break
 		}
@@ -170,16 +186,28 @@ func (*echo) Shutdown(_ context.Context, _ *proto.Empty) (*proto.Empty, error) {
 	return &proto.Empty{}, nil
 }
 
+// Responses is a collection of mocked responses to Provision operations.
 type Responses struct {
-	Parse          []*proto.Parse_Response
+	Parse []*proto.Parse_Response
+
+	// ProvisionApply and ProvisionPlan are used to mock ALL responses of
+	// Apply and Plan, regardless of transition.
 	ProvisionApply []*proto.Provision_Response
 	ProvisionPlan  []*proto.Provision_Response
+
+	// ProvisionApplyMap and ProvisionPlanMap are used to mock specific
+	// transition responses. They are prioritized over the generic responses.
+	ProvisionApplyMap map[proto.WorkspaceTransition][]*proto.Provision_Response
+	ProvisionPlanMap  map[proto.WorkspaceTransition][]*proto.Provision_Response
 }
 
 // Tar returns a tar archive of responses to provisioner operations.
 func Tar(responses *Responses) ([]byte, error) {
 	if responses == nil {
-		responses = &Responses{ParseComplete, ProvisionComplete, ProvisionComplete}
+		responses = &Responses{
+			ParseComplete, ProvisionComplete, ProvisionComplete,
+			nil, nil,
+		}
 	}
 	if responses.ProvisionPlan == nil {
 		responses.ProvisionPlan = responses.ProvisionApply
@@ -187,58 +215,61 @@ func Tar(responses *Responses) ([]byte, error) {
 
 	var buffer bytes.Buffer
 	writer := tar.NewWriter(&buffer)
-	for index, response := range responses.Parse {
-		data, err := protobuf.Marshal(response)
+
+	writeProto := func(name string, message protobuf.Message) error {
+		data, err := protobuf.Marshal(message)
 		if err != nil {
-			return nil, err
+			return err
 		}
+
 		err = writer.WriteHeader(&tar.Header{
-			Name: fmt.Sprintf("%d.parse.protobuf", index),
+			Name: name,
 			Size: int64(len(data)),
 			Mode: 0o644,
 		})
 		if err != nil {
-			return nil, err
+			return err
 		}
+
 		_, err = writer.Write(data)
 		if err != nil {
-			return nil, err
+			return err
 		}
+
+		return nil
 	}
-	for index, response := range responses.ProvisionApply {
-		data, err := protobuf.Marshal(response)
-		if err != nil {
-			return nil, err
-		}
-		err = writer.WriteHeader(&tar.Header{
-			Name: fmt.Sprintf("%d.provision.apply.protobuf", index),
-			Size: int64(len(data)),
-			Mode: 0o644,
-		})
+	for index, response := range responses.Parse {
+		err := writeProto(fmt.Sprintf("%d.parse.protobuf", index), response)
 		if err != nil {
 			return nil, err
 		}
-		_, err = writer.Write(data)
+	}
+	for index, response := range responses.ProvisionApply {
+		err := writeProto(fmt.Sprintf("%d.provision.apply.protobuf", index), response)
 		if err != nil {
 			return nil, err
 		}
 	}
 	for index, response := range responses.ProvisionPlan {
-		data, err := protobuf.Marshal(response)
+		err := writeProto(fmt.Sprintf("%d.provision.plan.protobuf", index), response)
 		if err != nil {
 			return nil, err
 		}
-		err = writer.WriteHeader(&tar.Header{
-			Name: fmt.Sprintf("%d.provision.plan.protobuf", index),
-			Size: int64(len(data)),
-			Mode: 0o644,
-		})
-		if err != nil {
-			return nil, err
+	}
+	for trans, m := range responses.ProvisionApplyMap {
+		for i, rs := range m {
+			err := writeProto(fmt.Sprintf("%d.%s.provision.apply.protobuf", i, strings.ToLower(trans.String())), rs)
+			if err != nil {
+				return nil, err
+			}
 		}
-		_, err = writer.Write(data)
-		if err != nil {
-			return nil, err
+	}
+	for trans, m := range responses.ProvisionPlanMap {
+		for i, rs := range m {
+			err := writeProto(fmt.Sprintf("%d.%s.provision.plan.protobuf", i, strings.ToLower(trans.String())), rs)
+			if err != nil {
+				return nil, err
+			}
 		}
 	}
 	err := writer.Flush()
diff --git a/provisioner/echo/serve_test.go b/provisioner/echo/serve_test.go
index 539fab0c57536..01b283f8a55f5 100644
--- a/provisioner/echo/serve_test.go
+++ b/provisioner/echo/serve_test.go
@@ -112,6 +112,92 @@ func TestEcho(t *testing.T) {
 			complete.GetComplete().Resources[0].Name)
 	})
 
+	t.Run("ProvisionStop", func(t *testing.T) {
+		t.Parallel()
+
+		// Stop responses should be returned when the workspace is being stopped.
+
+		defaultResponses := []*proto.Provision_Response{{
+			Type: &proto.Provision_Response_Complete{
+				Complete: &proto.Provision_Complete{
+					Resources: []*proto.Resource{{
+						Name: "DEFAULT",
+					}},
+				},
+			},
+		}}
+		stopResponses := []*proto.Provision_Response{{
+			Type: &proto.Provision_Response_Complete{
+				Complete: &proto.Provision_Complete{
+					Resources: []*proto.Resource{{
+						Name: "STOP",
+					}},
+				},
+			},
+		}}
+		data, err := echo.Tar(&echo.Responses{
+			ProvisionApply: defaultResponses,
+			ProvisionPlan:  defaultResponses,
+			ProvisionPlanMap: map[proto.WorkspaceTransition][]*proto.Provision_Response{
+				proto.WorkspaceTransition_STOP: stopResponses,
+			},
+			ProvisionApplyMap: map[proto.WorkspaceTransition][]*proto.Provision_Response{
+				proto.WorkspaceTransition_STOP: stopResponses,
+			},
+		})
+		require.NoError(t, err)
+
+		client, err := api.Provision(ctx)
+		require.NoError(t, err)
+
+		// Do stop.
+		err = client.Send(&proto.Provision_Request{
+			Type: &proto.Provision_Request_Plan{
+				Plan: &proto.Provision_Plan{
+					Config: &proto.Provision_Config{
+						Directory: unpackTar(t, fs, data),
+						Metadata: &proto.Provision_Metadata{
+							WorkspaceTransition: proto.WorkspaceTransition_STOP,
+						},
+					},
+				},
+			},
+		})
+		require.NoError(t, err)
+
+		complete, err := client.Recv()
+		require.NoError(t, err)
+		require.Equal(t,
+			stopResponses[0].GetComplete().Resources[0].Name,
+			complete.GetComplete().Resources[0].Name,
+		)
+
+		// Do start.
+		client, err = api.Provision(ctx)
+		require.NoError(t, err)
+
+		err = client.Send(&proto.Provision_Request{
+			Type: &proto.Provision_Request_Plan{
+				Plan: &proto.Provision_Plan{
+					Config: &proto.Provision_Config{
+						Directory: unpackTar(t, fs, data),
+						Metadata: &proto.Provision_Metadata{
+							WorkspaceTransition: proto.WorkspaceTransition_START,
+						},
+					},
+				},
+			},
+		})
+		require.NoError(t, err)
+
+		complete, err = client.Recv()
+		require.NoError(t, err)
+		require.Equal(t,
+			defaultResponses[0].GetComplete().Resources[0].Name,
+			complete.GetComplete().Resources[0].Name,
+		)
+	})
+
 	t.Run("ProvisionWithLogLevel", func(t *testing.T) {
 		t.Parallel()
 
diff --git a/provisionerd/runner/runner.go b/provisionerd/runner/runner.go
index a42ce5c2da375..5911004f98e2e 100644
--- a/provisionerd/runner/runner.go
+++ b/provisionerd/runner/runner.go
@@ -964,10 +964,11 @@ func (r *Runner) buildWorkspace(ctx context.Context, stage string, req *sdkproto
 }
 
 func (r *Runner) commitQuota(ctx context.Context, resources []*sdkproto.Resource) *proto.FailedJob {
+	cost := sumDailyCost(resources)
 	r.logger.Debug(ctx, "committing quota",
 		slog.F("resources", resources),
+		slog.F("cost", cost),
 	)
-	cost := sumDailyCost(resources)
 	if cost == 0 {
 		return nil
 	}
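
To make the corrected check easier to reason about in isolation, here is a small
self-contained sketch of the decision CommitQuota now makes. It is illustrative
only: permitBuild and its parameters are hypothetical simplifications (the real
code runs inside a database transaction against coderd's store), but it shows
why a quota-violating Stop->Start is now rejected while cost-reducing builds are
still allowed.

package main

import "fmt"

// permitBuild is a simplified, hypothetical model of the corrected CommitQuota
// decision. consumedOther is the daily cost already consumed by the user's
// other workspace builds, prevCost is the cost of this workspace's previous
// build (negative if there is none, mirroring the sql.ErrNoRows branch), and
// nextCost is the cost of the build being committed.
func permitBuild(prevCost, nextCost, consumedOther, budget int64) bool {
	// Builds that reduce consumption are allowed even when over budget.
	netIncrease := true
	if prevCost >= 0 {
		netIncrease = nextCost >= prevCost
	}

	newConsumed := consumedOther + nextCost
	if newConsumed > budget && netIncrease {
		// Over quota and not freeing anything up: reject.
		return false
	}
	return true
}

func main() {
	// The regression: a stopped workspace (previous build cost 0) restarting at
	// cost 2, while other workspaces already consume the whole budget of 4,
	// must be rejected rather than silently accepted.
	fmt.Println(permitBuild(0, 2, 4, 4)) // false

	// Stopping a running workspace reduces cost, so it stays permitted even
	// when the user is over budget.
	fmt.Println(permitBuild(2, 0, 5, 4)) // true
}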