Skip to content

Commit 196399c

Browse files
committed
chore: add dbauthz to unhanger tests
1 parent 3514ca3 commit 196399c

File tree

1 file changed

+68
-25
lines changed

1 file changed

+68
-25
lines changed

coderd/unhanger/detector_test.go

+68-25
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,24 @@ import (
55
"database/sql"
66
"encoding/json"
77
"fmt"
8+
"sync/atomic"
89
"testing"
910
"time"
1011

1112
"github.com/google/uuid"
13+
"github.com/prometheus/client_golang/prometheus"
1214
"github.com/stretchr/testify/assert"
1315
"github.com/stretchr/testify/require"
1416
"go.uber.org/goleak"
1517

18+
"cdr.dev/slog"
1619
"cdr.dev/slog/sloggers/slogtest"
1720
"github.com/coder/coder/v2/coderd/database"
21+
"github.com/coder/coder/v2/coderd/database/dbauthz"
1822
"github.com/coder/coder/v2/coderd/database/dbgen"
1923
"github.com/coder/coder/v2/coderd/database/dbtestutil"
24+
"github.com/coder/coder/v2/coderd/provisionerdserver"
25+
"github.com/coder/coder/v2/coderd/rbac"
2026
"github.com/coder/coder/v2/coderd/unhanger"
2127
"github.com/coder/coder/v2/provisionersdk"
2228
"github.com/coder/coder/v2/testutil"
@@ -37,7 +43,7 @@ func TestDetectorNoJobs(t *testing.T) {
3743
statsCh = make(chan unhanger.Stats)
3844
)
3945

40-
detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh)
46+
detector := unhanger.New(ctx, wrapDBAuthz(db, log), pubsub, log, tickCh).WithStatsChannel(statsCh)
4147
detector.Start()
4248
tickCh <- time.Now()
4349

@@ -84,7 +90,7 @@ func TestDetectorNoHungJobs(t *testing.T) {
8490
})
8591
}
8692

87-
detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh)
93+
detector := unhanger.New(ctx, wrapDBAuthz(db, log), pubsub, log, tickCh).WithStatsChannel(statsCh)
8894
detector.Start()
8995
tickCh <- now
9096

@@ -190,7 +196,7 @@ func TestDetectorHungWorkspaceBuild(t *testing.T) {
190196
t.Log("previous job ID: ", previousWorkspaceBuildJob.ID)
191197
t.Log("current job ID: ", currentWorkspaceBuildJob.ID)
192198

193-
detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh)
199+
detector := unhanger.New(ctx, wrapDBAuthz(db, log), pubsub, log, tickCh).WithStatsChannel(statsCh)
194200
detector.Start()
195201
tickCh <- now
196202

@@ -313,7 +319,7 @@ func TestDetectorHungWorkspaceBuildNoOverrideState(t *testing.T) {
313319
t.Log("previous job ID: ", previousWorkspaceBuildJob.ID)
314320
t.Log("current job ID: ", currentWorkspaceBuildJob.ID)
315321

316-
detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh)
322+
detector := unhanger.New(ctx, wrapDBAuthz(db, log), pubsub, log, tickCh).WithStatsChannel(statsCh)
317323
detector.Start()
318324
tickCh <- now
319325

@@ -406,7 +412,7 @@ func TestDetectorHungWorkspaceBuildNoOverrideStateIfNoExistingBuild(t *testing.T
406412

407413
t.Log("current job ID: ", currentWorkspaceBuildJob.ID)
408414

409-
detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh)
415+
detector := unhanger.New(ctx, wrapDBAuthz(db, log), pubsub, log, tickCh).WithStatsChannel(statsCh)
410416
detector.Start()
411417
tickCh <- now
412418

@@ -469,29 +475,40 @@ func TestDetectorHungOtherJobTypes(t *testing.T) {
469475
Type: database.ProvisionerJobTypeTemplateVersionImport,
470476
Input: []byte("{}"),
471477
})
472-
473-
// Template dry-run job.
474-
templateDryRunJob = dbgen.ProvisionerJob(t, db, pubsub, database.ProvisionerJob{
475-
CreatedAt: tenMinAgo,
476-
UpdatedAt: sixMinAgo,
477-
StartedAt: sql.NullTime{
478-
Time: tenMinAgo,
479-
Valid: true,
480-
},
478+
_ = dbgen.TemplateVersion(t, db, database.TemplateVersion{
481479
OrganizationID: org.ID,
482-
InitiatorID: user.ID,
483-
Provisioner: database.ProvisionerTypeEcho,
484-
StorageMethod: database.ProvisionerStorageMethodFile,
485-
FileID: file.ID,
486-
Type: database.ProvisionerJobTypeTemplateVersionDryRun,
487-
Input: []byte("{}"),
480+
JobID: templateImportJob.ID,
488481
})
489482
)
490483

484+
// Template dry-run job.
485+
dryRunVersion := dbgen.TemplateVersion(t, db, database.TemplateVersion{
486+
OrganizationID: org.ID,
487+
})
488+
input, err := json.Marshal(provisionerdserver.TemplateVersionDryRunJob{
489+
TemplateVersionID: dryRunVersion.ID,
490+
})
491+
require.NoError(t, err)
492+
templateDryRunJob := dbgen.ProvisionerJob(t, db, pubsub, database.ProvisionerJob{
493+
CreatedAt: tenMinAgo,
494+
UpdatedAt: sixMinAgo,
495+
StartedAt: sql.NullTime{
496+
Time: tenMinAgo,
497+
Valid: true,
498+
},
499+
OrganizationID: org.ID,
500+
InitiatorID: user.ID,
501+
Provisioner: database.ProvisionerTypeEcho,
502+
StorageMethod: database.ProvisionerStorageMethodFile,
503+
FileID: file.ID,
504+
Type: database.ProvisionerJobTypeTemplateVersionDryRun,
505+
Input: input,
506+
})
507+
491508
t.Log("template import job ID: ", templateImportJob.ID)
492509
t.Log("template dry-run job ID: ", templateDryRunJob.ID)
493510

494-
detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh)
511+
detector := unhanger.New(ctx, wrapDBAuthz(db, log), pubsub, log, tickCh).WithStatsChannel(statsCh)
495512
detector.Start()
496513
tickCh <- now
497514

@@ -564,11 +581,15 @@ func TestDetectorHungCanceledJob(t *testing.T) {
564581
Type: database.ProvisionerJobTypeTemplateVersionImport,
565582
Input: []byte("{}"),
566583
})
584+
_ = dbgen.TemplateVersion(t, db, database.TemplateVersion{
585+
OrganizationID: org.ID,
586+
JobID: templateImportJob.ID,
587+
})
567588
)
568589

569590
t.Log("template import job ID: ", templateImportJob.ID)
570591

571-
detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh)
592+
detector := unhanger.New(ctx, wrapDBAuthz(db, log), pubsub, log, tickCh).WithStatsChannel(statsCh)
572593
detector.Start()
573594
tickCh <- now
574595

@@ -657,6 +678,10 @@ func TestDetectorPushesLogs(t *testing.T) {
657678
Type: database.ProvisionerJobTypeTemplateVersionImport,
658679
Input: []byte("{}"),
659680
})
681+
_ = dbgen.TemplateVersion(t, db, database.TemplateVersion{
682+
OrganizationID: org.ID,
683+
JobID: templateImportJob.ID,
684+
})
660685
)
661686

662687
t.Log("template import job ID: ", templateImportJob.ID)
@@ -678,7 +703,7 @@ func TestDetectorPushesLogs(t *testing.T) {
678703
require.Len(t, logs, 10)
679704
}
680705

681-
detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh)
706+
detector := unhanger.New(ctx, wrapDBAuthz(db, log), pubsub, log, tickCh).WithStatsChannel(statsCh)
682707
detector.Start()
683708

684709
// Create pubsub subscription to listen for new log events.
@@ -752,7 +777,7 @@ func TestDetectorMaxJobsPerRun(t *testing.T) {
752777
// Create unhanger.MaxJobsPerRun + 1 hung jobs.
753778
now := time.Now()
754779
for i := 0; i < unhanger.MaxJobsPerRun+1; i++ {
755-
dbgen.ProvisionerJob(t, db, pubsub, database.ProvisionerJob{
780+
pj := dbgen.ProvisionerJob(t, db, pubsub, database.ProvisionerJob{
756781
CreatedAt: now.Add(-time.Hour),
757782
UpdatedAt: now.Add(-time.Hour),
758783
StartedAt: sql.NullTime{
@@ -767,9 +792,13 @@ func TestDetectorMaxJobsPerRun(t *testing.T) {
767792
Type: database.ProvisionerJobTypeTemplateVersionImport,
768793
Input: []byte("{}"),
769794
})
795+
_ = dbgen.TemplateVersion(t, db, database.TemplateVersion{
796+
OrganizationID: org.ID,
797+
JobID: pj.ID,
798+
})
770799
}
771800

772-
detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh)
801+
detector := unhanger.New(ctx, wrapDBAuthz(db, log), pubsub, log, tickCh).WithStatsChannel(statsCh)
773802
detector.Start()
774803
tickCh <- now
775804

@@ -788,3 +817,17 @@ func TestDetectorMaxJobsPerRun(t *testing.T) {
788817
detector.Close()
789818
detector.Wait()
790819
}
820+
821+
// wrapDBAuthz adds our Authorization/RBAC around the given database store, to
822+
// ensure the unhanger has the right permissions to do its work.
823+
func wrapDBAuthz(db database.Store, logger slog.Logger) database.Store {
824+
accessControlStore := &atomic.Pointer[dbauthz.AccessControlStore]{}
825+
var acs dbauthz.AccessControlStore = dbauthz.AGPLTemplateAccessControlStore{}
826+
accessControlStore.Store(&acs)
827+
return dbauthz.New(
828+
db,
829+
rbac.NewStrictCachingAuthorizer(prometheus.NewRegistry()),
830+
logger,
831+
accessControlStore,
832+
)
833+
}

0 commit comments

Comments
 (0)