Skip to content

Commit 9de1e3a

Browse files
committed
fix: correctly reject quota-violating builds
- Modify the Echo provisioner to make it easier to test the difference between stopped and started states. - Expand Quotas tests
1 parent 4ee6fd5 commit 9de1e3a

File tree

6 files changed

+284
-53
lines changed

6 files changed

+284
-53
lines changed

enterprise/coderd/coderd.go

+4-1
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,10 @@ func (api *API) updateEntitlements(ctx context.Context) error {
498498

499499
if initial, changed, enabled := featureChanged(codersdk.FeatureTemplateRBAC); shouldUpdate(initial, changed, enabled) {
500500
if enabled {
501-
committer := committer{Database: api.Database}
501+
committer := committer{
502+
Log: api.Logger.Named("quota_committer"),
503+
Database: api.Database,
504+
}
502505
ptr := proto.QuotaCommitter(&committer)
503506
api.AGPL.QuotaCommitter.Store(&ptr)
504507
} else {

enterprise/coderd/workspacequota.go

+23-10
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@ package coderd
33
import (
44
"context"
55
"database/sql"
6+
"errors"
67
"net/http"
78

89
"github.com/google/uuid"
9-
"golang.org/x/xerrors"
10+
11+
"cdr.dev/slog"
1012

1113
"github.com/coder/coder/v2/coderd/database"
1214
"github.com/coder/coder/v2/coderd/httpapi"
@@ -17,6 +19,7 @@ import (
1719
)
1820

1921
type committer struct {
22+
Log slog.Logger
2023
Database database.Store
2124
}
2225

@@ -28,12 +31,12 @@ func (c *committer) CommitQuota(
2831
return nil, err
2932
}
3033

31-
build, err := c.Database.GetWorkspaceBuildByJobID(ctx, jobID)
34+
nextBuild, err := c.Database.GetWorkspaceBuildByJobID(ctx, jobID)
3235
if err != nil {
3336
return nil, err
3437
}
3538

36-
workspace, err := c.Database.GetWorkspaceByID(ctx, build.WorkspaceID)
39+
workspace, err := c.Database.GetWorkspaceByID(ctx, nextBuild.WorkspaceID)
3740
if err != nil {
3841
return nil, err
3942
}
@@ -58,25 +61,35 @@ func (c *committer) CommitQuota(
5861
// If the new build will reduce overall quota consumption, then we
5962
// allow it even if the user is over quota.
6063
netIncrease := true
61-
previousBuild, err := s.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx, database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams{
64+
prevBuild, err := s.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx, database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams{
6265
WorkspaceID: workspace.ID,
63-
BuildNumber: build.BuildNumber - 1,
66+
BuildNumber: nextBuild.BuildNumber - 1,
6467
})
6568
if err == nil {
66-
if build.DailyCost < previousBuild.DailyCost {
67-
netIncrease = false
68-
}
69-
} else if !xerrors.Is(err, sql.ErrNoRows) {
69+
netIncrease = request.DailyCost >= prevBuild.DailyCost
70+
c.Log.Debug(
71+
ctx, "previous build cost",
72+
slog.F("prev_cost", prevBuild.DailyCost),
73+
slog.F("next_cost", request.DailyCost),
74+
slog.F("net_increase", netIncrease),
75+
)
76+
} else if !errors.Is(err, sql.ErrNoRows) {
7077
return err
7178
}
7279

7380
newConsumed := int64(request.DailyCost) + consumed
7481
if newConsumed > budget && netIncrease {
82+
c.Log.Debug(
83+
ctx, "over quota, rejecting",
84+
slog.F("prev_consumed", consumed),
85+
slog.F("next_consumed", newConsumed),
86+
slog.F("budget", budget),
87+
)
7588
return nil
7689
}
7790

7891
err = s.UpdateWorkspaceBuildCostByID(ctx, database.UpdateWorkspaceBuildCostByIDParams{
79-
ID: build.ID,
92+
ID: nextBuild.ID,
8093
DailyCost: request.DailyCost,
8194
})
8295
if err != nil {

enterprise/coderd/workspacequota_test.go

+105-5
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"github.com/stretchr/testify/require"
1111

1212
"github.com/coder/coder/v2/coderd/coderdtest"
13+
"github.com/coder/coder/v2/coderd/database"
1314
"github.com/coder/coder/v2/coderd/util/ptr"
1415
"github.com/coder/coder/v2/codersdk"
1516
"github.com/coder/coder/v2/enterprise/coderd/coderdenttest"
@@ -31,12 +32,13 @@ func verifyQuota(ctx context.Context, t *testing.T, client *codersdk.Client, con
3132
}
3233

3334
func TestWorkspaceQuota(t *testing.T) {
34-
// TODO: refactor for new impl
35-
3635
t.Parallel()
3736

38-
t.Run("BlocksBuild", func(t *testing.T) {
37+
// This first test verifies the behavior of creating and deleting workspaces.
38+
// It also tests multi-group quota stacking and the everyone group.
39+
t.Run("CreateDelete", func(t *testing.T) {
3940
t.Parallel()
41+
4042
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
4143
defer cancel()
4244
max := 1
@@ -49,8 +51,6 @@ func TestWorkspaceQuota(t *testing.T) {
4951
},
5052
})
5153
coderdtest.NewProvisionerDaemon(t, api.AGPL)
52-
coderdtest.NewProvisionerDaemon(t, api.AGPL)
53-
coderdtest.NewProvisionerDaemon(t, api.AGPL)
5454

5555
verifyQuota(ctx, t, client, 0, 0)
5656

@@ -157,4 +157,104 @@ func TestWorkspaceQuota(t *testing.T) {
157157
verifyQuota(ctx, t, client, 4, 4)
158158
require.Equal(t, codersdk.WorkspaceStatusRunning, build.Status)
159159
})
160+
161+
t.Run("StartStop", func(t *testing.T) {
162+
t.Parallel()
163+
164+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
165+
defer cancel()
166+
max := 1
167+
client, _, api, user := coderdenttest.NewWithAPI(t, &coderdenttest.Options{
168+
UserWorkspaceQuota: max,
169+
LicenseOptions: &coderdenttest.LicenseOptions{
170+
Features: license.Features{
171+
codersdk.FeatureTemplateRBAC: 1,
172+
},
173+
},
174+
})
175+
coderdtest.NewProvisionerDaemon(t, api.AGPL)
176+
177+
verifyQuota(ctx, t, client, 0, 0)
178+
179+
// Patch the 'Everyone' group to verify its quota allowance is being accounted for.
180+
_, err := client.PatchGroup(ctx, user.OrganizationID, codersdk.PatchGroupRequest{
181+
QuotaAllowance: ptr.Ref(4),
182+
})
183+
require.NoError(t, err)
184+
verifyQuota(ctx, t, client, 0, 4)
185+
186+
stopResp := []*proto.Provision_Response{{
187+
Type: &proto.Provision_Response_Complete{
188+
Complete: &proto.Provision_Complete{
189+
Resources: []*proto.Resource{{
190+
Name: "example",
191+
Type: "aws_instance",
192+
DailyCost: 1,
193+
}},
194+
},
195+
},
196+
}}
197+
198+
startResp := []*proto.Provision_Response{{
199+
Type: &proto.Provision_Response_Complete{
200+
Complete: &proto.Provision_Complete{
201+
Resources: []*proto.Resource{{
202+
Name: "example",
203+
Type: "aws_instance",
204+
DailyCost: 2,
205+
}},
206+
},
207+
},
208+
}}
209+
210+
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
211+
Parse: echo.ParseComplete,
212+
ProvisionPlanMap: map[proto.WorkspaceTransition][]*proto.Provision_Response{
213+
proto.WorkspaceTransition_START: startResp,
214+
proto.WorkspaceTransition_STOP: stopResp,
215+
},
216+
ProvisionApplyMap: map[proto.WorkspaceTransition][]*proto.Provision_Response{
217+
proto.WorkspaceTransition_START: startResp,
218+
proto.WorkspaceTransition_STOP: stopResp,
219+
},
220+
})
221+
222+
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
223+
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
224+
225+
// Spin up two workspaces.
226+
var wg sync.WaitGroup
227+
var workspaces []codersdk.Workspace
228+
for i := 0; i < 2; i++ {
229+
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
230+
workspaces = append(workspaces, workspace)
231+
build := coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
232+
assert.Equal(t, codersdk.WorkspaceStatusRunning, build.Status)
233+
}
234+
wg.Wait()
235+
verifyQuota(ctx, t, client, 4, 4)
236+
237+
// Next one must fail
238+
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
239+
build := coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
240+
require.Contains(t, build.Job.Error, "quota")
241+
242+
// Consumed shouldn't bump
243+
verifyQuota(ctx, t, client, 4, 4)
244+
require.Equal(t, codersdk.WorkspaceStatusFailed, build.Status)
245+
246+
build = coderdtest.CreateWorkspaceBuild(t, client, workspaces[0], database.WorkspaceTransitionStop)
247+
build = coderdtest.AwaitWorkspaceBuildJob(t, client, build.ID)
248+
249+
// Quota goes down one
250+
verifyQuota(ctx, t, client, 3, 4)
251+
require.Equal(t, codersdk.WorkspaceStatusStopped, build.Status)
252+
253+
build = coderdtest.CreateWorkspaceBuild(t, client, workspaces[0], database.WorkspaceTransitionStart)
254+
build = coderdtest.AwaitWorkspaceBuildJob(t, client, build.ID)
255+
256+
// Quota goes back up
257+
verifyQuota(ctx, t, client, 4, 4)
258+
require.Equal(t, codersdk.WorkspaceStatusRunning, build.Status)
259+
})
160260
}

provisioner/echo/serve.go

+64-36
Original file line numberDiff line numberDiff line change
@@ -127,19 +127,35 @@ func (e *echo) Provision(stream proto.DRPCProvisioner_ProvisionStream) error {
127127
return nil
128128
}
129129

130-
for index := 0; ; index++ {
130+
outer:
131+
for i := 0; ; i++ {
131132
var extension string
132133
if msg.GetPlan() != nil {
133134
extension = ".plan.protobuf"
134135
} else {
135136
extension = ".apply.protobuf"
136137
}
137-
path := filepath.Join(config.Directory, fmt.Sprintf("%d.provision"+extension, index))
138-
_, err := e.filesystem.Stat(path)
139-
if err != nil {
140-
if index == 0 {
141-
// Error if nothing is around to enable failed states.
142-
return xerrors.New("no state")
138+
var (
139+
path string
140+
pathIndex int
141+
)
142+
// Try more specific path first, then fallback to generic.
143+
paths := []string{
144+
filepath.Join(config.Directory, fmt.Sprintf("%d.%s.provision"+extension, i, strings.ToLower(config.GetMetadata().GetWorkspaceTransition().String()))),
145+
filepath.Join(config.Directory, fmt.Sprintf("%d.provision"+extension, i)),
146+
}
147+
for pathIndex, path = range paths {
148+
_, err := e.filesystem.Stat(path)
149+
if err != nil && pathIndex == len(paths)-1 {
150+
// If there are zero messages, something is wrong.
151+
if i == 0 {
152+
// Error if nothing is around to enable failed states.
153+
return xerrors.New("no state")
154+
}
155+
// Otherwise, we're done with the entire provision.
156+
break outer
157+
} else if err != nil {
158+
continue
143159
}
144160
break
145161
}
@@ -170,75 +186,87 @@ func (*echo) Shutdown(_ context.Context, _ *proto.Empty) (*proto.Empty, error) {
170186
return &proto.Empty{}, nil
171187
}
172188

189+
// Responses is a collection of mocked responses to Provision operations.
173190
type Responses struct {
174191
Parse []*proto.Parse_Response
175192
ProvisionApply []*proto.Provision_Response
176193
ProvisionPlan []*proto.Provision_Response
194+
195+
// ProvisionApplyMap and ProvisionPlanMap are used to mock specific
196+
// transition responses.
197+
ProvisionApplyMap map[proto.WorkspaceTransition][]*proto.Provision_Response
198+
ProvisionPlanMap map[proto.WorkspaceTransition][]*proto.Provision_Response
177199
}
178200

179201
// Tar returns a tar archive of responses to provisioner operations.
180202
func Tar(responses *Responses) ([]byte, error) {
181203
if responses == nil {
182-
responses = &Responses{ParseComplete, ProvisionComplete, ProvisionComplete}
204+
responses = &Responses{
205+
ParseComplete, ProvisionComplete, ProvisionComplete,
206+
nil, nil,
207+
}
183208
}
184209
if responses.ProvisionPlan == nil {
185210
responses.ProvisionPlan = responses.ProvisionApply
186211
}
187212

188213
var buffer bytes.Buffer
189214
writer := tar.NewWriter(&buffer)
190-
for index, response := range responses.Parse {
191-
data, err := protobuf.Marshal(response)
215+
216+
writeProto := func(name string, message protobuf.Message) error {
217+
data, err := protobuf.Marshal(message)
192218
if err != nil {
193-
return nil, err
219+
return err
194220
}
221+
195222
err = writer.WriteHeader(&tar.Header{
196-
Name: fmt.Sprintf("%d.parse.protobuf", index),
223+
Name: name,
197224
Size: int64(len(data)),
198225
Mode: 0o644,
199226
})
200227
if err != nil {
201-
return nil, err
228+
return err
202229
}
230+
203231
_, err = writer.Write(data)
204232
if err != nil {
205-
return nil, err
233+
return err
206234
}
235+
236+
return nil
207237
}
208-
for index, response := range responses.ProvisionApply {
209-
data, err := protobuf.Marshal(response)
210-
if err != nil {
211-
return nil, err
212-
}
213-
err = writer.WriteHeader(&tar.Header{
214-
Name: fmt.Sprintf("%d.provision.apply.protobuf", index),
215-
Size: int64(len(data)),
216-
Mode: 0o644,
217-
})
238+
for index, response := range responses.Parse {
239+
err := writeProto(fmt.Sprintf("%d.parse.protobuf", index), response)
218240
if err != nil {
219241
return nil, err
220242
}
221-
_, err = writer.Write(data)
243+
}
244+
for index, response := range responses.ProvisionApply {
245+
err := writeProto(fmt.Sprintf("%d.provision.apply.protobuf", index), response)
222246
if err != nil {
223247
return nil, err
224248
}
225249
}
226250
for index, response := range responses.ProvisionPlan {
227-
data, err := protobuf.Marshal(response)
251+
err := writeProto(fmt.Sprintf("%d.provision.plan.protobuf", index), response)
228252
if err != nil {
229253
return nil, err
230254
}
231-
err = writer.WriteHeader(&tar.Header{
232-
Name: fmt.Sprintf("%d.provision.plan.protobuf", index),
233-
Size: int64(len(data)),
234-
Mode: 0o644,
235-
})
236-
if err != nil {
237-
return nil, err
255+
}
256+
for trans, m := range responses.ProvisionApplyMap {
257+
for i, rs := range m {
258+
err := writeProto(fmt.Sprintf("%d.%s.provision.apply.protobuf", i, strings.ToLower(trans.String())), rs)
259+
if err != nil {
260+
return nil, err
261+
}
238262
}
239-
_, err = writer.Write(data)
240-
if err != nil {
241-
return nil, err
263+
}
264+
for trans, m := range responses.ProvisionPlanMap {
265+
for i, rs := range m {
266+
err := writeProto(fmt.Sprintf("%d.%s.provision.plan.protobuf", i, strings.ToLower(trans.String())), rs)
267+
if err != nil {
268+
return nil, err
269+
}
242270
}
243271
}
244272
err := writer.Flush()

0 commit comments

Comments
 (0)