Skip to content

Commit 69035a6

Browse files
committed
feat: block disabling auto off if template has max ttl
1 parent 4e1d948 commit 69035a6

File tree

4 files changed

+154
-10
lines changed

4 files changed

+154
-10
lines changed

coderd/provisionerdserver/provisionerdserver.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ func (server *Server) includeLastVariableValues(ctx context.Context, templateVer
308308

309309
templateVersion, err := server.Database.GetTemplateVersionByID(ctx, templateVersionID)
310310
if err != nil {
311-
return nil, fmt.Errorf("get template version: %w", err)
311+
return nil, xerrors.Errorf("get template version: %w", err)
312312
}
313313

314314
if templateVersion.TemplateID.UUID == uuid.Nil {
@@ -317,7 +317,7 @@ func (server *Server) includeLastVariableValues(ctx context.Context, templateVer
317317

318318
template, err := server.Database.GetTemplateByID(ctx, templateVersion.TemplateID.UUID)
319319
if err != nil {
320-
return nil, fmt.Errorf("get template: %w", err)
320+
return nil, xerrors.Errorf("get template: %w", err)
321321
}
322322

323323
if template.ActiveVersionID == uuid.Nil {
@@ -326,7 +326,7 @@ func (server *Server) includeLastVariableValues(ctx context.Context, templateVer
326326

327327
templateVariables, err := server.Database.GetTemplateVersionVariables(ctx, template.ActiveVersionID)
328328
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
329-
return nil, fmt.Errorf("get template version variables: %w", err)
329+
return nil, xerrors.Errorf("get template version variables: %w", err)
330330
}
331331

332332
for _, templateVariable := range templateVariables {

coderd/workspaces.go

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,16 @@ func (api *API) postWorkspacesByOrganization(rw http.ResponseWriter, r *http.Req
362362
return
363363
}
364364

365-
dbTTL, err := validWorkspaceTTLMillis(createWorkspace.TTLMillis, template.DefaultTTL)
365+
templateSchedule, err := (*api.TemplateScheduleStore.Load()).GetTemplateScheduleOptions(ctx, api.Database, template.ID)
366+
if err != nil {
367+
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
368+
Message: "Internal error fetching template schedule.",
369+
Detail: err.Error(),
370+
})
371+
return
372+
}
373+
374+
dbTTL, err := validWorkspaceTTLMillis(createWorkspace.TTLMillis, templateSchedule.DefaultTTL, templateSchedule.MaxTTL)
366375
if err != nil {
367376
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
368377
Message: "Invalid Workspace Time to Shutdown.",
@@ -799,9 +808,15 @@ func (api *API) putWorkspaceTTL(rw http.ResponseWriter, r *http.Request) {
799808
var dbTTL sql.NullInt64
800809

801810
err := api.Database.InTx(func(s database.Store) error {
811+
templateSchedule, err := (*api.TemplateScheduleStore.Load()).GetTemplateScheduleOptions(ctx, s, workspace.TemplateID)
812+
if err != nil {
813+
return xerrors.Errorf("get template schedule: %w", err)
814+
}
815+
816+
// don't override 0 ttl with template default here because it indicates
817+
// disabled auto-stop
802818
var validityErr error
803-
// don't override 0 ttl with template default here because it indicates disabled auto-stop
804-
dbTTL, validityErr = validWorkspaceTTLMillis(req.TTLMillis, 0)
819+
dbTTL, validityErr = validWorkspaceTTLMillis(req.TTLMillis, 0, templateSchedule.MaxTTL)
805820
if validityErr != nil {
806821
return codersdk.ValidationError{Field: "ttl_ms", Detail: validityErr.Error()}
807822
}
@@ -1187,14 +1202,25 @@ func convertWorkspaceTTLMillis(i sql.NullInt64) *int64 {
11871202
return &millis
11881203
}
11891204

1190-
func validWorkspaceTTLMillis(millis *int64, def int64) (sql.NullInt64, error) {
1205+
func validWorkspaceTTLMillis(millis *int64, templateDefault, templateMax time.Duration) (sql.NullInt64, error) {
1206+
if templateDefault == 0 && templateMax != 0 || (templateMax > 0 && templateDefault > templateMax) {
1207+
templateDefault = templateMax
1208+
}
1209+
11911210
if ptr.NilOrZero(millis) {
1192-
if def == 0 {
1211+
if templateDefault == 0 {
1212+
if templateMax > 0 {
1213+
return sql.NullInt64{
1214+
Int64: int64(templateMax),
1215+
Valid: true,
1216+
}, nil
1217+
}
1218+
11931219
return sql.NullInt64{}, nil
11941220
}
11951221

11961222
return sql.NullInt64{
1197-
Int64: def,
1223+
Int64: int64(templateDefault),
11981224
Valid: true,
11991225
}, nil
12001226
}
@@ -1209,6 +1235,10 @@ func validWorkspaceTTLMillis(millis *int64, def int64) (sql.NullInt64, error) {
12091235
return sql.NullInt64{}, errTTLMax
12101236
}
12111237

1238+
if templateMax > 0 && truncated > templateMax {
1239+
return sql.NullInt64{}, xerrors.Errorf("time until shutdown must be less than or equal to the template's maximum TTL %q", templateMax.String())
1240+
}
1241+
12121242
return sql.NullInt64{
12131243
Valid: true,
12141244
Int64: int64(truncated),

coderd/workspaces_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ func TestPostWorkspacesByOrganization(t *testing.T) {
331331
})
332332
// TTL should be set by the template
333333
require.Equal(t, template.DefaultTTLMillis, templateTTL)
334-
require.Equal(t, template.DefaultTTLMillis, template.DefaultTTLMillis, workspace.TTLMillis)
334+
require.Equal(t, template.DefaultTTLMillis, *workspace.TTLMillis)
335335
})
336336

337337
t.Run("InvalidTTL", func(t *testing.T) {

enterprise/coderd/templates_test.go

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,120 @@ func TestTemplates(t *testing.T) {
5858
require.NoError(t, err)
5959
require.Equal(t, 2*time.Hour, time.Duration(template.MaxTTLMillis)*time.Millisecond)
6060
})
61+
62+
t.Run("CreateUpdateWorkspaceMaxTTL", func(t *testing.T) {
63+
t.Parallel()
64+
client := coderdenttest.New(t, &coderdenttest.Options{
65+
Options: &coderdtest.Options{
66+
IncludeProvisionerDaemon: true,
67+
},
68+
})
69+
user := coderdtest.CreateFirstUser(t, client)
70+
_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
71+
Features: license.Features{
72+
codersdk.FeatureAdvancedTemplateScheduling: 1,
73+
},
74+
})
75+
76+
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
77+
exp := 24 * time.Hour.Milliseconds()
78+
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID, func(ctr *codersdk.CreateTemplateRequest) {
79+
ctr.DefaultTTLMillis = &exp
80+
ctr.MaxTTLMillis = &exp
81+
})
82+
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
83+
84+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
85+
defer cancel()
86+
87+
// No TTL provided should use template default
88+
req := codersdk.CreateWorkspaceRequest{
89+
TemplateID: template.ID,
90+
Name: "testing",
91+
}
92+
ws, err := client.CreateWorkspace(ctx, template.OrganizationID, codersdk.Me, req)
93+
require.NoError(t, err)
94+
require.EqualValues(t, exp, *ws.TTLMillis)
95+
96+
// Editing a workspace to have a higher TTL than the template's max
97+
// should error
98+
exp = exp + time.Minute.Milliseconds()
99+
err = client.UpdateWorkspaceTTL(ctx, ws.ID, codersdk.UpdateWorkspaceTTLRequest{
100+
TTLMillis: &exp,
101+
})
102+
require.Error(t, err)
103+
var apiErr *codersdk.Error
104+
require.ErrorAs(t, err, &apiErr)
105+
require.Equal(t, http.StatusBadRequest, apiErr.StatusCode())
106+
require.Len(t, apiErr.Validations, 1)
107+
require.Equal(t, apiErr.Validations[0].Field, "ttl_ms")
108+
require.Contains(t, apiErr.Validations[0].Detail, "time until shutdown must be less than or equal to the template's maximum TTL")
109+
110+
// Creating workspace with TTL higher than max should error
111+
req.Name = "testing2"
112+
req.TTLMillis = &exp
113+
ws, err = client.CreateWorkspace(ctx, template.OrganizationID, codersdk.Me, req)
114+
require.Error(t, err)
115+
apiErr = nil
116+
require.ErrorAs(t, err, &apiErr)
117+
require.Equal(t, http.StatusBadRequest, apiErr.StatusCode())
118+
require.Len(t, apiErr.Validations, 1)
119+
require.Equal(t, apiErr.Validations[0].Field, "ttl_ms")
120+
require.Contains(t, apiErr.Validations[0].Detail, "time until shutdown must be less than or equal to the template's maximum TTL")
121+
})
122+
123+
t.Run("BlockDisablingAutoOffWithMaxTTL", func(t *testing.T) {
124+
t.Parallel()
125+
client := coderdenttest.New(t, &coderdenttest.Options{
126+
Options: &coderdtest.Options{
127+
IncludeProvisionerDaemon: true,
128+
},
129+
})
130+
user := coderdtest.CreateFirstUser(t, client)
131+
_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
132+
Features: license.Features{
133+
codersdk.FeatureAdvancedTemplateScheduling: 1,
134+
},
135+
})
136+
137+
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
138+
exp := 24 * time.Hour.Milliseconds()
139+
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID, func(ctr *codersdk.CreateTemplateRequest) {
140+
ctr.MaxTTLMillis = &exp
141+
})
142+
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
143+
144+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
145+
defer cancel()
146+
147+
// No TTL provided should use template default
148+
req := codersdk.CreateWorkspaceRequest{
149+
TemplateID: template.ID,
150+
Name: "testing",
151+
}
152+
ws, err := client.CreateWorkspace(ctx, template.OrganizationID, codersdk.Me, req)
153+
require.NoError(t, err)
154+
require.EqualValues(t, exp, *ws.TTLMillis)
155+
156+
// Editing a workspace to disable the TTL should do nothing
157+
err = client.UpdateWorkspaceTTL(ctx, ws.ID, codersdk.UpdateWorkspaceTTLRequest{
158+
TTLMillis: nil,
159+
})
160+
require.NoError(t, err)
161+
ws, err = client.Workspace(ctx, ws.ID)
162+
require.NoError(t, err)
163+
require.EqualValues(t, exp, *ws.TTLMillis)
164+
165+
// Editing a workspace to have a TTL of 0 should do nothing
166+
zero := int64(0)
167+
err = client.UpdateWorkspaceTTL(ctx, ws.ID, codersdk.UpdateWorkspaceTTLRequest{
168+
TTLMillis: &zero,
169+
})
170+
require.NoError(t, err)
171+
ws, err = client.Workspace(ctx, ws.ID)
172+
require.NoError(t, err)
173+
require.EqualValues(t, exp, *ws.TTLMillis)
174+
})
61175
}
62176

63177
func TestTemplateACL(t *testing.T) {

0 commit comments

Comments
 (0)