From c57b41b7f9777beca0dfb1a55e75ab64c1c0c391 Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Thu, 27 Apr 2023 21:13:43 +0000 Subject: [PATCH 1/3] fix(audit): ensure template creation errors are audited --- coderd/templates.go | 35 +++++++++++++++++++++-------------- coderd/util/ptr/ptr.go | 15 ++++++++++++--- coderd/util/ptr/ptr_test.go | 22 ++++++++++++++++++++++ 3 files changed, 55 insertions(+), 17 deletions(-) diff --git a/coderd/templates.go b/coderd/templates.go index 44def2d4b00b4..43038b8ab3e26 100644 --- a/coderd/templates.go +++ b/coderd/templates.go @@ -21,6 +21,7 @@ import ( "github.com/coder/coder/coderd/rbac" "github.com/coder/coder/coderd/schedule" "github.com/coder/coder/coderd/telemetry" + "github.com/coder/coder/coderd/util/ptr" "github.com/coder/coder/codersdk" "github.com/coder/coder/examples" ) @@ -149,6 +150,19 @@ func (api *API) postTemplateByOrganization(rw http.ResponseWriter, r *http.Reque if !httpapi.Read(ctx, rw, r, &createTemplate) { return } + + // Make a temporary struct to represent the template. This is used for + // auditing if any of the following checks fails. It will be overwritten + // when the template is inserted into the db. + templateAudit.New = database.Template{ + OrganizationID: organization.ID, + Name: createTemplate.Name, + Description: createTemplate.Description, + CreatedBy: apiKey.UserID, + Icon: createTemplate.Icon, + DisplayName: createTemplate.DisplayName, + } + _, err := api.Database.GetTemplateByOrganizationAndName(ctx, database.GetTemplateByOrganizationAndNameParams{ OrganizationID: organization.ID, Name: createTemplate.Name, @@ -170,6 +184,7 @@ func (api *API) postTemplateByOrganization(rw http.ResponseWriter, r *http.Reque }) return } + templateVersion, err := api.Database.GetTemplateVersionByID(ctx, createTemplate.VersionID) if errors.Is(err, sql.ErrNoRows) { httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{ @@ -228,22 +243,14 @@ func (api *API) postTemplateByOrganization(rw http.ResponseWriter, r *http.Reque } var ( - allowUserCancelWorkspaceJobs bool - allowUserAutostart = true - allowUserAutostop = true + dbTemplate database.Template + template codersdk.Template + + allowUserCancelWorkspaceJobs = ptr.NilToDefault(createTemplate.AllowUserCancelWorkspaceJobs, false) + allowUserAutostart = ptr.NilToDefault(createTemplate.AllowUserAutostart, true) + allowUserAutostop = ptr.NilToDefault(createTemplate.AllowUserAutostop, true) ) - if createTemplate.AllowUserCancelWorkspaceJobs != nil { - allowUserCancelWorkspaceJobs = *createTemplate.AllowUserCancelWorkspaceJobs - } - if createTemplate.AllowUserAutostart != nil { - allowUserAutostart = *createTemplate.AllowUserAutostart - } - if createTemplate.AllowUserAutostop != nil { - allowUserAutostop = *createTemplate.AllowUserAutostop - } - var dbTemplate database.Template - var template codersdk.Template err = api.Database.InTx(func(tx database.Store) error { now := database.Now() dbTemplate, err = tx.InsertTemplate(ctx, database.InsertTemplateParams{ diff --git a/coderd/util/ptr/ptr.go b/coderd/util/ptr/ptr.go index eef582b95d16d..3500805c6fed0 100644 --- a/coderd/util/ptr/ptr.go +++ b/coderd/util/ptr/ptr.go @@ -17,10 +17,19 @@ func NilOrEmpty(s *string) bool { return s == nil || *s == "" } -// NilToEmpty coalesces a nil str to the empty string. -func NilToEmpty(s *string) string { +// NilToEmpty coalesces a nil value to the empty value. +func NilToEmpty[T any](s *T) T { + var def T if s == nil { - return "" + return def + } + return *s +} + +// NilToDefault coalesces a nil value to the provided default value. +func NilToDefault[T any](s *T, def T) T { + if s == nil { + return def } return *s } diff --git a/coderd/util/ptr/ptr_test.go b/coderd/util/ptr/ptr_test.go index d43e9ccd1122f..1157f22294732 100644 --- a/coderd/util/ptr/ptr_test.go +++ b/coderd/util/ptr/ptr_test.go @@ -52,6 +52,28 @@ func Test_NilOrEmpty(t *testing.T) { assert.False(t, ptr.NilOrEmpty(&nonEmptyString)) } +func Test_NilToEmpty(t *testing.T) { + t.Parallel() + + assert.False(t, ptr.NilToEmpty((*bool)(nil))) + assert.Empty(t, ptr.NilToEmpty((*int64)(nil))) + assert.Empty(t, ptr.NilToEmpty((*string)(nil))) + assert.Equal(t, true, ptr.NilToEmpty(ptr.Ref(true))) +} + +func Test_NilToDefault(t *testing.T) { + t.Parallel() + + assert.True(t, ptr.NilToDefault(ptr.Ref(true), false)) + assert.True(t, ptr.NilToDefault((*bool)(nil), true)) + + assert.Equal(t, 4, ptr.NilToDefault(ptr.Ref(4), 5)) + assert.Equal(t, 5, ptr.NilToDefault((*int64)(nil), 5)) + + assert.Equal(t, "hi", ptr.NilToDefault((*string)(nil), "hi")) + assert.Equal(t, "hello", ptr.NilToDefault(ptr.Ref("hello"), "hi")) +} + func Test_NilOrZero(t *testing.T) { t.Parallel() From dd2d73521ee15ea1da644f07ee49c54d90218737 Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Thu, 27 Apr 2023 21:16:19 +0000 Subject: [PATCH 2/3] fixup! fix(audit): ensure template creation errors are audited --- coderd/templates.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/coderd/templates.go b/coderd/templates.go index 43038b8ab3e26..2250d52698c78 100644 --- a/coderd/templates.go +++ b/coderd/templates.go @@ -152,8 +152,8 @@ func (api *API) postTemplateByOrganization(rw http.ResponseWriter, r *http.Reque } // Make a temporary struct to represent the template. This is used for - // auditing if any of the following checks fails. It will be overwritten - // when the template is inserted into the db. + // auditing if any of the following checks fail. It will be overwritten when + // the template is inserted into the db. templateAudit.New = database.Template{ OrganizationID: organization.ID, Name: createTemplate.Name, From 3c7701e849467a861ac206ecaf97dcba83357bff Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Thu, 27 Apr 2023 21:44:23 +0000 Subject: [PATCH 3/3] fixup! fix(audit): ensure template creation errors are audited --- coderd/util/ptr/ptr_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/coderd/util/ptr/ptr_test.go b/coderd/util/ptr/ptr_test.go index 1157f22294732..2dee346c8f5e4 100644 --- a/coderd/util/ptr/ptr_test.go +++ b/coderd/util/ptr/ptr_test.go @@ -67,8 +67,8 @@ func Test_NilToDefault(t *testing.T) { assert.True(t, ptr.NilToDefault(ptr.Ref(true), false)) assert.True(t, ptr.NilToDefault((*bool)(nil), true)) - assert.Equal(t, 4, ptr.NilToDefault(ptr.Ref(4), 5)) - assert.Equal(t, 5, ptr.NilToDefault((*int64)(nil), 5)) + assert.Equal(t, int64(4), ptr.NilToDefault(ptr.Ref[int64](4), 5)) + assert.Equal(t, int64(5), ptr.NilToDefault((*int64)(nil), 5)) assert.Equal(t, "hi", ptr.NilToDefault((*string)(nil), "hi")) assert.Equal(t, "hello", ptr.NilToDefault(ptr.Ref("hello"), "hi"))