Skip to content

Commit e3d8557

Browse files
committed
chore: add tests for updating max_ttl on template
1 parent 25f7d2c commit e3d8557

File tree

2 files changed

+210
-2
lines changed

2 files changed

+210
-2
lines changed

coderd/activitybump_test.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,24 @@ import (
1818

1919
type mockTemplateScheduleStore struct {
2020
getFn func(ctx context.Context, db database.Store, templateID uuid.UUID) (provisionerdserver.TemplateScheduleOptions, error)
21+
setFn func(ctx context.Context, db database.Store, template database.Template, options provisionerdserver.TemplateScheduleOptions) (database.Template, error)
2122
}
2223

2324
var _ provisionerdserver.TemplateScheduleStore = mockTemplateScheduleStore{}
2425

2526
func (m mockTemplateScheduleStore) GetTemplateScheduleOptions(ctx context.Context, db database.Store, templateID uuid.UUID) (provisionerdserver.TemplateScheduleOptions, error) {
26-
return m.getFn(ctx, db, templateID)
27+
if m.getFn != nil {
28+
return m.getFn(ctx, db, templateID)
29+
}
30+
31+
return provisionerdserver.NewAGPLTemplateScheduleStore().GetTemplateScheduleOptions(ctx, db, templateID)
2732
}
2833

29-
func (mockTemplateScheduleStore) SetTemplateScheduleOptions(ctx context.Context, db database.Store, template database.Template, options provisionerdserver.TemplateScheduleOptions) (database.Template, error) {
34+
func (m mockTemplateScheduleStore) SetTemplateScheduleOptions(ctx context.Context, db database.Store, template database.Template, options provisionerdserver.TemplateScheduleOptions) (database.Template, error) {
35+
if m.setFn != nil {
36+
return m.setFn(ctx, db, template, options)
37+
}
38+
3039
return provisionerdserver.NewAGPLTemplateScheduleStore().SetTemplateScheduleOptions(ctx, db, template, options)
3140
}
3241

coderd/templates_test.go

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package coderd_test
33
import (
44
"context"
55
"net/http"
6+
"sync/atomic"
67
"testing"
78
"time"
89

@@ -15,6 +16,7 @@ import (
1516
"github.com/coder/coder/coderd/audit"
1617
"github.com/coder/coder/coderd/coderdtest"
1718
"github.com/coder/coder/coderd/database"
19+
"github.com/coder/coder/coderd/provisionerdserver"
1820
"github.com/coder/coder/coderd/util/ptr"
1921
"github.com/coder/coder/codersdk"
2022
"github.com/coder/coder/codersdk/agentsdk"
@@ -143,6 +145,95 @@ func TestPostTemplateByOrganization(t *testing.T) {
143145
require.Contains(t, err.Error(), "Try logging in using 'coder login <url>'.")
144146
})
145147

148+
t.Run("MaxTTL", func(t *testing.T) {
149+
t.Parallel()
150+
151+
const (
152+
defaultTTL = 1 * time.Hour
153+
maxTTL = 24 * time.Hour
154+
)
155+
156+
t.Run("OK", func(t *testing.T) {
157+
t.Parallel()
158+
159+
var setCalled int64
160+
client := coderdtest.New(t, &coderdtest.Options{
161+
TemplateScheduleStore: mockTemplateScheduleStore{
162+
setFn: func(ctx context.Context, db database.Store, template database.Template, options provisionerdserver.TemplateScheduleOptions) (database.Template, error) {
163+
atomic.AddInt64(&setCalled, 1)
164+
require.Equal(t, maxTTL, options.MaxTTL)
165+
template.DefaultTTL = int64(options.DefaultTTL)
166+
template.MaxTTL = int64(options.MaxTTL)
167+
return template, nil
168+
},
169+
},
170+
})
171+
user := coderdtest.CreateFirstUser(t, client)
172+
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
173+
174+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
175+
defer cancel()
176+
177+
got, err := client.CreateTemplate(ctx, user.OrganizationID, codersdk.CreateTemplateRequest{
178+
Name: "testing",
179+
VersionID: version.ID,
180+
DefaultTTLMillis: ptr.Ref(int64(0)),
181+
MaxTTLMillis: ptr.Ref(maxTTL.Milliseconds()),
182+
})
183+
require.NoError(t, err)
184+
185+
require.EqualValues(t, 1, atomic.LoadInt64(&setCalled))
186+
require.EqualValues(t, 0, got.DefaultTTLMillis)
187+
require.Equal(t, maxTTL.Milliseconds(), got.MaxTTLMillis)
188+
})
189+
190+
t.Run("DefaultTTLBigger", func(t *testing.T) {
191+
t.Parallel()
192+
193+
client := coderdtest.New(t, nil)
194+
user := coderdtest.CreateFirstUser(t, client)
195+
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
196+
197+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
198+
defer cancel()
199+
200+
_, err := client.CreateTemplate(ctx, user.OrganizationID, codersdk.CreateTemplateRequest{
201+
Name: "testing",
202+
VersionID: version.ID,
203+
DefaultTTLMillis: ptr.Ref((maxTTL * 2).Milliseconds()),
204+
MaxTTLMillis: ptr.Ref(maxTTL.Milliseconds()),
205+
})
206+
require.Error(t, err)
207+
var sdkErr *codersdk.Error
208+
require.ErrorAs(t, err, &sdkErr)
209+
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
210+
require.Len(t, sdkErr.Validations, 1)
211+
require.Equal(t, "default_ttl_ms", sdkErr.Validations[0].Field)
212+
require.Contains(t, sdkErr.Validations[0].Detail, "Must be less than or equal to max_ttl_ms")
213+
})
214+
215+
t.Run("IgnoredUnlicensed", func(t *testing.T) {
216+
t.Parallel()
217+
218+
client := coderdtest.New(t, nil)
219+
user := coderdtest.CreateFirstUser(t, client)
220+
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
221+
222+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
223+
defer cancel()
224+
225+
got, err := client.CreateTemplate(ctx, user.OrganizationID, codersdk.CreateTemplateRequest{
226+
Name: "testing",
227+
VersionID: version.ID,
228+
DefaultTTLMillis: ptr.Ref(defaultTTL.Milliseconds()),
229+
MaxTTLMillis: ptr.Ref(maxTTL.Milliseconds()),
230+
})
231+
require.NoError(t, err)
232+
require.Equal(t, defaultTTL.Milliseconds(), got.DefaultTTLMillis)
233+
require.Zero(t, got.MaxTTLMillis)
234+
})
235+
})
236+
146237
t.Run("NoVersion", func(t *testing.T) {
147238
t.Parallel()
148239
client := coderdtest.New(t, nil)
@@ -345,6 +436,114 @@ func TestPatchTemplateMeta(t *testing.T) {
345436
assert.Equal(t, updated.DefaultTTLMillis, template.DefaultTTLMillis)
346437
})
347438

439+
t.Run("MaxTTL", func(t *testing.T) {
440+
t.Parallel()
441+
442+
const (
443+
defaultTTL = 1 * time.Hour
444+
maxTTL = 24 * time.Hour
445+
)
446+
447+
t.Run("OK", func(t *testing.T) {
448+
t.Parallel()
449+
450+
var setCalled int64
451+
client := coderdtest.New(t, &coderdtest.Options{
452+
TemplateScheduleStore: mockTemplateScheduleStore{
453+
setFn: func(ctx context.Context, db database.Store, template database.Template, options provisionerdserver.TemplateScheduleOptions) (database.Template, error) {
454+
if atomic.AddInt64(&setCalled, 1) == 2 {
455+
require.Equal(t, maxTTL, options.MaxTTL)
456+
}
457+
template.DefaultTTL = int64(options.DefaultTTL)
458+
template.MaxTTL = int64(options.MaxTTL)
459+
return template, nil
460+
},
461+
},
462+
})
463+
user := coderdtest.CreateFirstUser(t, client)
464+
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
465+
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID, func(ctr *codersdk.CreateTemplateRequest) {
466+
ctr.DefaultTTLMillis = ptr.Ref(24 * time.Hour.Milliseconds())
467+
})
468+
469+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
470+
defer cancel()
471+
472+
got, err := client.UpdateTemplateMeta(ctx, template.ID, codersdk.UpdateTemplateMeta{
473+
Name: template.Name,
474+
DisplayName: template.DisplayName,
475+
Description: template.Description,
476+
Icon: template.Icon,
477+
DefaultTTLMillis: 0,
478+
MaxTTLMillis: maxTTL.Milliseconds(),
479+
AllowUserCancelWorkspaceJobs: template.AllowUserCancelWorkspaceJobs,
480+
})
481+
require.NoError(t, err)
482+
483+
require.EqualValues(t, 2, atomic.LoadInt64(&setCalled))
484+
require.EqualValues(t, 0, got.DefaultTTLMillis)
485+
require.Equal(t, maxTTL.Milliseconds(), got.MaxTTLMillis)
486+
})
487+
488+
t.Run("DefaultTTLBigger", func(t *testing.T) {
489+
t.Parallel()
490+
491+
client := coderdtest.New(t, nil)
492+
user := coderdtest.CreateFirstUser(t, client)
493+
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
494+
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID, func(ctr *codersdk.CreateTemplateRequest) {
495+
ctr.DefaultTTLMillis = ptr.Ref(24 * time.Hour.Milliseconds())
496+
})
497+
498+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
499+
defer cancel()
500+
501+
_, err := client.UpdateTemplateMeta(ctx, template.ID, codersdk.UpdateTemplateMeta{
502+
Name: template.Name,
503+
DisplayName: template.DisplayName,
504+
Description: template.Description,
505+
Icon: template.Icon,
506+
DefaultTTLMillis: (maxTTL * 2).Milliseconds(),
507+
MaxTTLMillis: maxTTL.Milliseconds(),
508+
AllowUserCancelWorkspaceJobs: template.AllowUserCancelWorkspaceJobs,
509+
})
510+
require.Error(t, err)
511+
var sdkErr *codersdk.Error
512+
require.ErrorAs(t, err, &sdkErr)
513+
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
514+
require.Len(t, sdkErr.Validations, 1)
515+
require.Equal(t, "default_ttl_ms", sdkErr.Validations[0].Field)
516+
require.Contains(t, sdkErr.Validations[0].Detail, "Must be less than or equal to max_ttl_ms")
517+
})
518+
519+
t.Run("IgnoredUnlicensed", func(t *testing.T) {
520+
t.Parallel()
521+
522+
client := coderdtest.New(t, nil)
523+
user := coderdtest.CreateFirstUser(t, client)
524+
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
525+
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID, func(ctr *codersdk.CreateTemplateRequest) {
526+
ctr.DefaultTTLMillis = ptr.Ref(24 * time.Hour.Milliseconds())
527+
})
528+
529+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
530+
defer cancel()
531+
532+
got, err := client.UpdateTemplateMeta(ctx, template.ID, codersdk.UpdateTemplateMeta{
533+
Name: template.Name,
534+
DisplayName: template.DisplayName,
535+
Description: template.Description,
536+
Icon: template.Icon,
537+
DefaultTTLMillis: defaultTTL.Milliseconds(),
538+
MaxTTLMillis: maxTTL.Milliseconds(),
539+
AllowUserCancelWorkspaceJobs: template.AllowUserCancelWorkspaceJobs,
540+
})
541+
require.NoError(t, err)
542+
require.Equal(t, defaultTTL.Milliseconds(), got.DefaultTTLMillis)
543+
require.Zero(t, got.MaxTTLMillis)
544+
})
545+
})
546+
348547
t.Run("NotModified", func(t *testing.T) {
349548
t.Parallel()
350549

0 commit comments

Comments
 (0)