Skip to content

Commit 760c718

Browse files
committed
fix: correctly reject quota-violating builds
Due to a logic error in CommitQuota, all workspace Stop->Start operations were being accepted, regardless of the quota limit. The fix makes the net-increase check compare the cost in the incoming request (request.DailyCost) against the previous build's cost, rather than using the new build's stored DailyCost. This issue only appeared after #9201, so it was a minor regression in main for about 3 days. This PR adds a test to make sure this kind of bug doesn't recur. To make the new test possible, we give the echo provisioner the ability to simulate responses to specific workspace transitions.
1 parent 4a9c773 commit 760c718

File tree

6 files changed

+288
-54
lines changed

6 files changed

+288
-54
lines changed

enterprise/coderd/coderd.go

+4-1
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,10 @@ func (api *API) updateEntitlements(ctx context.Context) error {
498498

499499
if initial, changed, enabled := featureChanged(codersdk.FeatureTemplateRBAC); shouldUpdate(initial, changed, enabled) {
500500
if enabled {
501-
committer := committer{Database: api.Database}
501+
committer := committer{
502+
Log: api.Logger.Named("quota_committer"),
503+
Database: api.Database,
504+
}
502505
ptr := proto.QuotaCommitter(&committer)
503506
api.AGPL.QuotaCommitter.Store(&ptr)
504507
} else {

enterprise/coderd/workspacequota.go

+23-10
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@ package coderd
33
import (
44
"context"
55
"database/sql"
6+
"errors"
67
"net/http"
78

89
"github.com/google/uuid"
9-
"golang.org/x/xerrors"
10+
11+
"cdr.dev/slog"
1012

1113
"github.com/coder/coder/v2/coderd/database"
1214
"github.com/coder/coder/v2/coderd/httpapi"
@@ -17,6 +19,7 @@ import (
1719
)
1820

1921
type committer struct {
22+
Log slog.Logger
2023
Database database.Store
2124
}
2225

@@ -28,12 +31,12 @@ func (c *committer) CommitQuota(
2831
return nil, err
2932
}
3033

31-
build, err := c.Database.GetWorkspaceBuildByJobID(ctx, jobID)
34+
nextBuild, err := c.Database.GetWorkspaceBuildByJobID(ctx, jobID)
3235
if err != nil {
3336
return nil, err
3437
}
3538

36-
workspace, err := c.Database.GetWorkspaceByID(ctx, build.WorkspaceID)
39+
workspace, err := c.Database.GetWorkspaceByID(ctx, nextBuild.WorkspaceID)
3740
if err != nil {
3841
return nil, err
3942
}
@@ -58,25 +61,35 @@ func (c *committer) CommitQuota(
5861
// If the new build will reduce overall quota consumption, then we
5962
// allow it even if the user is over quota.
6063
netIncrease := true
61-
previousBuild, err := s.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx, database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams{
64+
prevBuild, err := s.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx, database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams{
6265
WorkspaceID: workspace.ID,
63-
BuildNumber: build.BuildNumber - 1,
66+
BuildNumber: nextBuild.BuildNumber - 1,
6467
})
6568
if err == nil {
66-
if build.DailyCost < previousBuild.DailyCost {
67-
netIncrease = false
68-
}
69-
} else if !xerrors.Is(err, sql.ErrNoRows) {
69+
netIncrease = request.DailyCost >= prevBuild.DailyCost
70+
c.Log.Debug(
71+
ctx, "previous build cost",
72+
slog.F("prev_cost", prevBuild.DailyCost),
73+
slog.F("next_cost", request.DailyCost),
74+
slog.F("net_increase", netIncrease),
75+
)
76+
} else if !errors.Is(err, sql.ErrNoRows) {
7077
return err
7178
}
7279

7380
newConsumed := int64(request.DailyCost) + consumed
7481
if newConsumed > budget && netIncrease {
82+
c.Log.Debug(
83+
ctx, "over quota, rejecting",
84+
slog.F("prev_consumed", consumed),
85+
slog.F("next_consumed", newConsumed),
86+
slog.F("budget", budget),
87+
)
7588
return nil
7689
}
7790

7891
err = s.UpdateWorkspaceBuildCostByID(ctx, database.UpdateWorkspaceBuildCostByIDParams{
79-
ID: build.ID,
92+
ID: nextBuild.ID,
8093
DailyCost: request.DailyCost,
8194
})
8295
if err != nil {

enterprise/coderd/workspacequota_test.go

+105-5
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"github.com/stretchr/testify/require"
1111

1212
"github.com/coder/coder/v2/coderd/coderdtest"
13+
"github.com/coder/coder/v2/coderd/database"
1314
"github.com/coder/coder/v2/coderd/util/ptr"
1415
"github.com/coder/coder/v2/codersdk"
1516
"github.com/coder/coder/v2/enterprise/coderd/coderdenttest"
@@ -31,12 +32,13 @@ func verifyQuota(ctx context.Context, t *testing.T, client *codersdk.Client, con
3132
}
3233

3334
func TestWorkspaceQuota(t *testing.T) {
34-
// TODO: refactor for new impl
35-
3635
t.Parallel()
3736

38-
t.Run("BlocksBuild", func(t *testing.T) {
37+
// This first test verifies the behavior of creating and deleting workspaces.
38+
// It also tests multi-group quota stacking and the everyone group.
39+
t.Run("CreateDelete", func(t *testing.T) {
3940
t.Parallel()
41+
4042
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
4143
defer cancel()
4244
max := 1
@@ -49,8 +51,6 @@ func TestWorkspaceQuota(t *testing.T) {
4951
},
5052
})
5153
coderdtest.NewProvisionerDaemon(t, api.AGPL)
52-
coderdtest.NewProvisionerDaemon(t, api.AGPL)
53-
coderdtest.NewProvisionerDaemon(t, api.AGPL)
5454

5555
verifyQuota(ctx, t, client, 0, 0)
5656

@@ -157,4 +157,104 @@ func TestWorkspaceQuota(t *testing.T) {
157157
verifyQuota(ctx, t, client, 4, 4)
158158
require.Equal(t, codersdk.WorkspaceStatusRunning, build.Status)
159159
})
160+
161+
t.Run("StartStop", func(t *testing.T) {
162+
t.Parallel()
163+
164+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
165+
defer cancel()
166+
max := 1
167+
client, _, api, user := coderdenttest.NewWithAPI(t, &coderdenttest.Options{
168+
UserWorkspaceQuota: max,
169+
LicenseOptions: &coderdenttest.LicenseOptions{
170+
Features: license.Features{
171+
codersdk.FeatureTemplateRBAC: 1,
172+
},
173+
},
174+
})
175+
coderdtest.NewProvisionerDaemon(t, api.AGPL)
176+
177+
verifyQuota(ctx, t, client, 0, 0)
178+
179+
// Patch the 'Everyone' group to verify its quota allowance is being accounted for.
180+
_, err := client.PatchGroup(ctx, user.OrganizationID, codersdk.PatchGroupRequest{
181+
QuotaAllowance: ptr.Ref(4),
182+
})
183+
require.NoError(t, err)
184+
verifyQuota(ctx, t, client, 0, 4)
185+
186+
stopResp := []*proto.Provision_Response{{
187+
Type: &proto.Provision_Response_Complete{
188+
Complete: &proto.Provision_Complete{
189+
Resources: []*proto.Resource{{
190+
Name: "example",
191+
Type: "aws_instance",
192+
DailyCost: 1,
193+
}},
194+
},
195+
},
196+
}}
197+
198+
startResp := []*proto.Provision_Response{{
199+
Type: &proto.Provision_Response_Complete{
200+
Complete: &proto.Provision_Complete{
201+
Resources: []*proto.Resource{{
202+
Name: "example",
203+
Type: "aws_instance",
204+
DailyCost: 2,
205+
}},
206+
},
207+
},
208+
}}
209+
210+
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
211+
Parse: echo.ParseComplete,
212+
ProvisionPlanMap: map[proto.WorkspaceTransition][]*proto.Provision_Response{
213+
proto.WorkspaceTransition_START: startResp,
214+
proto.WorkspaceTransition_STOP: stopResp,
215+
},
216+
ProvisionApplyMap: map[proto.WorkspaceTransition][]*proto.Provision_Response{
217+
proto.WorkspaceTransition_START: startResp,
218+
proto.WorkspaceTransition_STOP: stopResp,
219+
},
220+
})
221+
222+
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
223+
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
224+
225+
// Spin up two workspaces.
226+
var wg sync.WaitGroup
227+
var workspaces []codersdk.Workspace
228+
for i := 0; i < 2; i++ {
229+
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
230+
workspaces = append(workspaces, workspace)
231+
build := coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
232+
assert.Equal(t, codersdk.WorkspaceStatusRunning, build.Status)
233+
}
234+
wg.Wait()
235+
verifyQuota(ctx, t, client, 4, 4)
236+
237+
// Next one must fail
238+
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
239+
build := coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
240+
require.Contains(t, build.Job.Error, "quota")
241+
242+
// Consumed shouldn't bump
243+
verifyQuota(ctx, t, client, 4, 4)
244+
require.Equal(t, codersdk.WorkspaceStatusFailed, build.Status)
245+
246+
build = coderdtest.CreateWorkspaceBuild(t, client, workspaces[0], database.WorkspaceTransitionStop)
247+
build = coderdtest.AwaitWorkspaceBuildJob(t, client, build.ID)
248+
249+
// Quota goes down one
250+
verifyQuota(ctx, t, client, 3, 4)
251+
require.Equal(t, codersdk.WorkspaceStatusStopped, build.Status)
252+
253+
build = coderdtest.CreateWorkspaceBuild(t, client, workspaces[0], database.WorkspaceTransitionStart)
254+
build = coderdtest.AwaitWorkspaceBuildJob(t, client, build.ID)
255+
256+
// Quota goes back up
257+
verifyQuota(ctx, t, client, 4, 4)
258+
require.Equal(t, codersdk.WorkspaceStatusRunning, build.Status)
259+
})
160260
}

provisioner/echo/serve.go

+68-37
Original file line numberDiff line numberDiff line change
@@ -127,19 +127,35 @@ func (e *echo) Provision(stream proto.DRPCProvisioner_ProvisionStream) error {
127127
return nil
128128
}
129129

130-
for index := 0; ; index++ {
130+
outer:
131+
for i := 0; ; i++ {
131132
var extension string
132133
if msg.GetPlan() != nil {
133134
extension = ".plan.protobuf"
134135
} else {
135136
extension = ".apply.protobuf"
136137
}
137-
path := filepath.Join(config.Directory, fmt.Sprintf("%d.provision"+extension, index))
138-
_, err := e.filesystem.Stat(path)
139-
if err != nil {
140-
if index == 0 {
141-
// Error if nothing is around to enable failed states.
142-
return xerrors.New("no state")
138+
var (
139+
path string
140+
pathIndex int
141+
)
142+
// Try more specific path first, then fallback to generic.
143+
paths := []string{
144+
filepath.Join(config.Directory, fmt.Sprintf("%d.%s.provision"+extension, i, strings.ToLower(config.GetMetadata().GetWorkspaceTransition().String()))),
145+
filepath.Join(config.Directory, fmt.Sprintf("%d.provision"+extension, i)),
146+
}
147+
for pathIndex, path = range paths {
148+
_, err := e.filesystem.Stat(path)
149+
if err != nil && pathIndex == len(paths)-1 {
150+
// If there are zero messages, something is wrong.
151+
if i == 0 {
152+
// Error if nothing is around to enable failed states.
153+
return xerrors.New("no state")
154+
}
155+
// Otherwise, we're done with the entire provision.
156+
break outer
157+
} else if err != nil {
158+
continue
143159
}
144160
break
145161
}
@@ -170,75 +186,90 @@ func (*echo) Shutdown(_ context.Context, _ *proto.Empty) (*proto.Empty, error) {
170186
return &proto.Empty{}, nil
171187
}
172188

189+
// Responses is a collection of mocked responses to Provision operations.
173190
type Responses struct {
174-
Parse []*proto.Parse_Response
191+
Parse []*proto.Parse_Response
192+
193+
// ProvisionApply and ProvisionPlan are used to mock ALL responses of
194+
// Apply and Plan, regardless of transition.
175195
ProvisionApply []*proto.Provision_Response
176196
ProvisionPlan []*proto.Provision_Response
197+
198+
// ProvisionApplyMap and ProvisionPlanMap are used to mock specific
199+
// transition responses. They are prioritized over the generic responses.
200+
ProvisionApplyMap map[proto.WorkspaceTransition][]*proto.Provision_Response
201+
ProvisionPlanMap map[proto.WorkspaceTransition][]*proto.Provision_Response
177202
}
178203

179204
// Tar returns a tar archive of responses to provisioner operations.
180205
func Tar(responses *Responses) ([]byte, error) {
181206
if responses == nil {
182-
responses = &Responses{ParseComplete, ProvisionComplete, ProvisionComplete}
207+
responses = &Responses{
208+
ParseComplete, ProvisionComplete, ProvisionComplete,
209+
nil, nil,
210+
}
183211
}
184212
if responses.ProvisionPlan == nil {
185213
responses.ProvisionPlan = responses.ProvisionApply
186214
}
187215

188216
var buffer bytes.Buffer
189217
writer := tar.NewWriter(&buffer)
190-
for index, response := range responses.Parse {
191-
data, err := protobuf.Marshal(response)
218+
219+
writeProto := func(name string, message protobuf.Message) error {
220+
data, err := protobuf.Marshal(message)
192221
if err != nil {
193-
return nil, err
222+
return err
194223
}
224+
195225
err = writer.WriteHeader(&tar.Header{
196-
Name: fmt.Sprintf("%d.parse.protobuf", index),
226+
Name: name,
197227
Size: int64(len(data)),
198228
Mode: 0o644,
199229
})
200230
if err != nil {
201-
return nil, err
231+
return err
202232
}
233+
203234
_, err = writer.Write(data)
204235
if err != nil {
205-
return nil, err
236+
return err
206237
}
238+
239+
return nil
207240
}
208-
for index, response := range responses.ProvisionApply {
209-
data, err := protobuf.Marshal(response)
210-
if err != nil {
211-
return nil, err
212-
}
213-
err = writer.WriteHeader(&tar.Header{
214-
Name: fmt.Sprintf("%d.provision.apply.protobuf", index),
215-
Size: int64(len(data)),
216-
Mode: 0o644,
217-
})
241+
for index, response := range responses.Parse {
242+
err := writeProto(fmt.Sprintf("%d.parse.protobuf", index), response)
218243
if err != nil {
219244
return nil, err
220245
}
221-
_, err = writer.Write(data)
246+
}
247+
for index, response := range responses.ProvisionApply {
248+
err := writeProto(fmt.Sprintf("%d.provision.apply.protobuf", index), response)
222249
if err != nil {
223250
return nil, err
224251
}
225252
}
226253
for index, response := range responses.ProvisionPlan {
227-
data, err := protobuf.Marshal(response)
254+
err := writeProto(fmt.Sprintf("%d.provision.plan.protobuf", index), response)
228255
if err != nil {
229256
return nil, err
230257
}
231-
err = writer.WriteHeader(&tar.Header{
232-
Name: fmt.Sprintf("%d.provision.plan.protobuf", index),
233-
Size: int64(len(data)),
234-
Mode: 0o644,
235-
})
236-
if err != nil {
237-
return nil, err
258+
}
259+
for trans, m := range responses.ProvisionApplyMap {
260+
for i, rs := range m {
261+
err := writeProto(fmt.Sprintf("%d.%s.provision.apply.protobuf", i, strings.ToLower(trans.String())), rs)
262+
if err != nil {
263+
return nil, err
264+
}
238265
}
239-
_, err = writer.Write(data)
240-
if err != nil {
241-
return nil, err
266+
}
267+
for trans, m := range responses.ProvisionPlanMap {
268+
for i, rs := range m {
269+
err := writeProto(fmt.Sprintf("%d.%s.provision.plan.protobuf", i, strings.ToLower(trans.String())), rs)
270+
if err != nil {
271+
return nil, err
272+
}
242273
}
243274
}
244275
err := writer.Flush()

0 commit comments

Comments
 (0)