Skip to content

Commit 59efa4a

Browse files
authored
fix(audit): ensure template creation errors are audited (#7315)
1 parent 77d9937 commit 59efa4a

File tree

3 files changed

+55
-17
lines changed

3 files changed

+55
-17
lines changed

coderd/templates.go

+21-14
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"github.com/coder/coder/coderd/rbac"
2222
"github.com/coder/coder/coderd/schedule"
2323
"github.com/coder/coder/coderd/telemetry"
24+
"github.com/coder/coder/coderd/util/ptr"
2425
"github.com/coder/coder/codersdk"
2526
"github.com/coder/coder/examples"
2627
)
@@ -149,6 +150,19 @@ func (api *API) postTemplateByOrganization(rw http.ResponseWriter, r *http.Reque
149150
if !httpapi.Read(ctx, rw, r, &createTemplate) {
150151
return
151152
}
153+
154+
// Make a temporary struct to represent the template. This is used for
155+
// auditing if any of the following checks fail. It will be overwritten when
156+
// the template is inserted into the db.
157+
templateAudit.New = database.Template{
158+
OrganizationID: organization.ID,
159+
Name: createTemplate.Name,
160+
Description: createTemplate.Description,
161+
CreatedBy: apiKey.UserID,
162+
Icon: createTemplate.Icon,
163+
DisplayName: createTemplate.DisplayName,
164+
}
165+
152166
_, err := api.Database.GetTemplateByOrganizationAndName(ctx, database.GetTemplateByOrganizationAndNameParams{
153167
OrganizationID: organization.ID,
154168
Name: createTemplate.Name,
@@ -170,6 +184,7 @@ func (api *API) postTemplateByOrganization(rw http.ResponseWriter, r *http.Reque
170184
})
171185
return
172186
}
187+
173188
templateVersion, err := api.Database.GetTemplateVersionByID(ctx, createTemplate.VersionID)
174189
if errors.Is(err, sql.ErrNoRows) {
175190
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
@@ -228,22 +243,14 @@ func (api *API) postTemplateByOrganization(rw http.ResponseWriter, r *http.Reque
228243
}
229244

230245
var (
231-
allowUserCancelWorkspaceJobs bool
232-
allowUserAutostart = true
233-
allowUserAutostop = true
246+
dbTemplate database.Template
247+
template codersdk.Template
248+
249+
allowUserCancelWorkspaceJobs = ptr.NilToDefault(createTemplate.AllowUserCancelWorkspaceJobs, false)
250+
allowUserAutostart = ptr.NilToDefault(createTemplate.AllowUserAutostart, true)
251+
allowUserAutostop = ptr.NilToDefault(createTemplate.AllowUserAutostop, true)
234252
)
235-
if createTemplate.AllowUserCancelWorkspaceJobs != nil {
236-
allowUserCancelWorkspaceJobs = *createTemplate.AllowUserCancelWorkspaceJobs
237-
}
238-
if createTemplate.AllowUserAutostart != nil {
239-
allowUserAutostart = *createTemplate.AllowUserAutostart
240-
}
241-
if createTemplate.AllowUserAutostop != nil {
242-
allowUserAutostop = *createTemplate.AllowUserAutostop
243-
}
244253

245-
var dbTemplate database.Template
246-
var template codersdk.Template
247254
err = api.Database.InTx(func(tx database.Store) error {
248255
now := database.Now()
249256
dbTemplate, err = tx.InsertTemplate(ctx, database.InsertTemplateParams{

coderd/util/ptr/ptr.go

+12-3
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,19 @@ func NilOrEmpty(s *string) bool {
1717
return s == nil || *s == ""
1818
}
1919

20-
// NilToEmpty coalesces a nil str to the empty string.
21-
func NilToEmpty(s *string) string {
20+
// NilToEmpty coalesces a nil value to the empty value.
21+
func NilToEmpty[T any](s *T) T {
22+
var def T
2223
if s == nil {
23-
return ""
24+
return def
25+
}
26+
return *s
27+
}
28+
29+
// NilToDefault coalesces a nil value to the provided default value.
30+
func NilToDefault[T any](s *T, def T) T {
31+
if s == nil {
32+
return def
2433
}
2534
return *s
2635
}

coderd/util/ptr/ptr_test.go

+22
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,28 @@ func Test_NilOrEmpty(t *testing.T) {
5252
assert.False(t, ptr.NilOrEmpty(&nonEmptyString))
5353
}
5454

55+
func Test_NilToEmpty(t *testing.T) {
56+
t.Parallel()
57+
58+
assert.False(t, ptr.NilToEmpty((*bool)(nil)))
59+
assert.Empty(t, ptr.NilToEmpty((*int64)(nil)))
60+
assert.Empty(t, ptr.NilToEmpty((*string)(nil)))
61+
assert.Equal(t, true, ptr.NilToEmpty(ptr.Ref(true)))
62+
}
63+
64+
func Test_NilToDefault(t *testing.T) {
65+
t.Parallel()
66+
67+
assert.True(t, ptr.NilToDefault(ptr.Ref(true), false))
68+
assert.True(t, ptr.NilToDefault((*bool)(nil), true))
69+
70+
assert.Equal(t, int64(4), ptr.NilToDefault(ptr.Ref[int64](4), 5))
71+
assert.Equal(t, int64(5), ptr.NilToDefault((*int64)(nil), 5))
72+
73+
assert.Equal(t, "hi", ptr.NilToDefault((*string)(nil), "hi"))
74+
assert.Equal(t, "hello", ptr.NilToDefault(ptr.Ref("hello"), "hi"))
75+
}
76+
5577
func Test_NilOrZero(t *testing.T) {
5678
t.Parallel()
5779

0 commit comments

Comments
 (0)