From 4ce4adaecfedd033377260f8c1e7cbda4ec1ed5b Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Thu, 22 Aug 2024 14:36:13 +0400 Subject: [PATCH] chore: add dbauthz to unhanger tests --- coderd/unhanger/detector_test.go | 95 +++++++++++++++++++++++--------- 1 file changed, 70 insertions(+), 25 deletions(-) diff --git a/coderd/unhanger/detector_test.go b/coderd/unhanger/detector_test.go index 99705fb159211..28bb2575b9ee7 100644 --- a/coderd/unhanger/detector_test.go +++ b/coderd/unhanger/detector_test.go @@ -9,14 +9,20 @@ import ( "time" "github.com/google/uuid" + "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/goleak" + "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbgen" "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/provisionerdserver" + "github.com/coder/coder/v2/coderd/rbac" "github.com/coder/coder/v2/coderd/unhanger" "github.com/coder/coder/v2/provisionersdk" "github.com/coder/coder/v2/testutil" @@ -37,7 +43,7 @@ func TestDetectorNoJobs(t *testing.T) { statsCh = make(chan unhanger.Stats) ) - detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh) + detector := unhanger.New(ctx, wrapDBAuthz(db, log), pubsub, log, tickCh).WithStatsChannel(statsCh) detector.Start() tickCh <- time.Now() @@ -84,7 +90,7 @@ func TestDetectorNoHungJobs(t *testing.T) { }) } - detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh) + detector := unhanger.New(ctx, wrapDBAuthz(db, log), pubsub, log, tickCh).WithStatsChannel(statsCh) detector.Start() tickCh <- now @@ -190,7 +196,7 @@ func TestDetectorHungWorkspaceBuild(t *testing.T) { t.Log("previous job ID: ", previousWorkspaceBuildJob.ID) t.Log("current job ID: ", currentWorkspaceBuildJob.ID) - detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh) + detector := unhanger.New(ctx, wrapDBAuthz(db, log), pubsub, log, tickCh).WithStatsChannel(statsCh) detector.Start() tickCh <- now @@ -313,7 +319,7 @@ func TestDetectorHungWorkspaceBuildNoOverrideState(t *testing.T) { t.Log("previous job ID: ", previousWorkspaceBuildJob.ID) t.Log("current job ID: ", currentWorkspaceBuildJob.ID) - detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh) + detector := unhanger.New(ctx, wrapDBAuthz(db, log), pubsub, log, tickCh).WithStatsChannel(statsCh) detector.Start() tickCh <- now @@ -406,7 +412,7 @@ func TestDetectorHungWorkspaceBuildNoOverrideStateIfNoExistingBuild(t *testing.T t.Log("current job ID: ", currentWorkspaceBuildJob.ID) - detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh) + detector := unhanger.New(ctx, wrapDBAuthz(db, log), pubsub, log, tickCh).WithStatsChannel(statsCh) detector.Start() tickCh <- now @@ -469,29 +475,42 @@ func TestDetectorHungOtherJobTypes(t *testing.T) { Type: database.ProvisionerJobTypeTemplateVersionImport, Input: []byte("{}"), }) - - // Template dry-run job. - templateDryRunJob = dbgen.ProvisionerJob(t, db, pubsub, database.ProvisionerJob{ - CreatedAt: tenMinAgo, - UpdatedAt: sixMinAgo, - StartedAt: sql.NullTime{ - Time: tenMinAgo, - Valid: true, - }, + _ = dbgen.TemplateVersion(t, db, database.TemplateVersion{ OrganizationID: org.ID, - InitiatorID: user.ID, - Provisioner: database.ProvisionerTypeEcho, - StorageMethod: database.ProvisionerStorageMethodFile, - FileID: file.ID, - Type: database.ProvisionerJobTypeTemplateVersionDryRun, - Input: []byte("{}"), + JobID: templateImportJob.ID, + CreatedBy: user.ID, }) ) + // Template dry-run job. + dryRunVersion := dbgen.TemplateVersion(t, db, database.TemplateVersion{ + OrganizationID: org.ID, + CreatedBy: user.ID, + }) + input, err := json.Marshal(provisionerdserver.TemplateVersionDryRunJob{ + TemplateVersionID: dryRunVersion.ID, + }) + require.NoError(t, err) + templateDryRunJob := dbgen.ProvisionerJob(t, db, pubsub, database.ProvisionerJob{ + CreatedAt: tenMinAgo, + UpdatedAt: sixMinAgo, + StartedAt: sql.NullTime{ + Time: tenMinAgo, + Valid: true, + }, + OrganizationID: org.ID, + InitiatorID: user.ID, + Provisioner: database.ProvisionerTypeEcho, + StorageMethod: database.ProvisionerStorageMethodFile, + FileID: file.ID, + Type: database.ProvisionerJobTypeTemplateVersionDryRun, + Input: input, + }) + t.Log("template import job ID: ", templateImportJob.ID) t.Log("template dry-run job ID: ", templateDryRunJob.ID) - detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh) + detector := unhanger.New(ctx, wrapDBAuthz(db, log), pubsub, log, tickCh).WithStatsChannel(statsCh) detector.Start() tickCh <- now @@ -564,11 +583,16 @@ func TestDetectorHungCanceledJob(t *testing.T) { Type: database.ProvisionerJobTypeTemplateVersionImport, Input: []byte("{}"), }) + _ = dbgen.TemplateVersion(t, db, database.TemplateVersion{ + OrganizationID: org.ID, + JobID: templateImportJob.ID, + CreatedBy: user.ID, + }) ) t.Log("template import job ID: ", templateImportJob.ID) - detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh) + detector := unhanger.New(ctx, wrapDBAuthz(db, log), pubsub, log, tickCh).WithStatsChannel(statsCh) detector.Start() tickCh <- now @@ -657,6 +681,11 @@ func TestDetectorPushesLogs(t *testing.T) { Type: database.ProvisionerJobTypeTemplateVersionImport, Input: []byte("{}"), }) + _ = dbgen.TemplateVersion(t, db, database.TemplateVersion{ + OrganizationID: org.ID, + JobID: templateImportJob.ID, + CreatedBy: user.ID, + }) ) t.Log("template import job ID: ", templateImportJob.ID) @@ -678,7 +707,7 @@ func TestDetectorPushesLogs(t *testing.T) { require.Len(t, logs, 10) } - detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh) + detector := unhanger.New(ctx, wrapDBAuthz(db, log), pubsub, log, tickCh).WithStatsChannel(statsCh) detector.Start() // Create pubsub subscription to listen for new log events. @@ -752,7 +781,7 @@ func TestDetectorMaxJobsPerRun(t *testing.T) { // Create unhanger.MaxJobsPerRun + 1 hung jobs. now := time.Now() for i := 0; i < unhanger.MaxJobsPerRun+1; i++ { - dbgen.ProvisionerJob(t, db, pubsub, database.ProvisionerJob{ + pj := dbgen.ProvisionerJob(t, db, pubsub, database.ProvisionerJob{ CreatedAt: now.Add(-time.Hour), UpdatedAt: now.Add(-time.Hour), StartedAt: sql.NullTime{ @@ -767,9 +796,14 @@ func TestDetectorMaxJobsPerRun(t *testing.T) { Type: database.ProvisionerJobTypeTemplateVersionImport, Input: []byte("{}"), }) + _ = dbgen.TemplateVersion(t, db, database.TemplateVersion{ + OrganizationID: org.ID, + JobID: pj.ID, + CreatedBy: user.ID, + }) } - detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh) + detector := unhanger.New(ctx, wrapDBAuthz(db, log), pubsub, log, tickCh).WithStatsChannel(statsCh) detector.Start() tickCh <- now @@ -788,3 +822,14 @@ func TestDetectorMaxJobsPerRun(t *testing.T) { detector.Close() detector.Wait() } + +// wrapDBAuthz adds our Authorization/RBAC around the given database store, to +// ensure the unhanger has the right permissions to do its work. +func wrapDBAuthz(db database.Store, logger slog.Logger) database.Store { + return dbauthz.New( + db, + rbac.NewStrictCachingAuthorizer(prometheus.NewRegistry()), + logger, + coderdtest.AccessControlStorePointer(), + ) +}