Skip to content

chore: add dbauthz to unhanger tests #14394

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 22, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 70 additions & 25 deletions coderd/unhanger/detector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,20 @@ import (
"time"

"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"

"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/provisionerdserver"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/unhanger"
"github.com/coder/coder/v2/provisionersdk"
"github.com/coder/coder/v2/testutil"
Expand All @@ -37,7 +43,7 @@ func TestDetectorNoJobs(t *testing.T) {
statsCh = make(chan unhanger.Stats)
)

detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh)
detector := unhanger.New(ctx, wrapDBAuthz(db, log), pubsub, log, tickCh).WithStatsChannel(statsCh)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feels odd to have a raw db laying around outside this function call scope. I'd probably do this to not have a db laying around. Up to you 🤷‍♂️

	var (
		ctx         = testutil.Context(t, testutil.WaitLong)
		rdb, pubsub = dbtestutil.NewDB(t)
		db          = wrapDBAuthz(rdb, slogtest.Make(t, nil))
		log         = slogtest.Make(t, nil)
		tickCh      = make(chan time.Time)
		statsCh     = make(chan unhanger.Stats)
	)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The annoying thing about that is that I'd have to set up a dbauthz principal for the calls that write the job, user, template, etc. Authz for those operations is irrelevant to this test.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, yea it is annoying. dbgen gets around that with a genCtx:

var genCtx = dbauthz.As(context.Background(), rbac.Subject{

If you use dbgen to handle all the initial state, you can ignore that requirement.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Almost, but not quite all state is dbgen. It make sense to me to keep this way, since it really is just the unhanger under test here, so only it need the dbauthz.

detector.Start()
tickCh <- time.Now()

Expand Down Expand Up @@ -84,7 +90,7 @@ func TestDetectorNoHungJobs(t *testing.T) {
})
}

detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh)
detector := unhanger.New(ctx, wrapDBAuthz(db, log), pubsub, log, tickCh).WithStatsChannel(statsCh)
detector.Start()
tickCh <- now

Expand Down Expand Up @@ -190,7 +196,7 @@ func TestDetectorHungWorkspaceBuild(t *testing.T) {
t.Log("previous job ID: ", previousWorkspaceBuildJob.ID)
t.Log("current job ID: ", currentWorkspaceBuildJob.ID)

detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh)
detector := unhanger.New(ctx, wrapDBAuthz(db, log), pubsub, log, tickCh).WithStatsChannel(statsCh)
detector.Start()
tickCh <- now

Expand Down Expand Up @@ -313,7 +319,7 @@ func TestDetectorHungWorkspaceBuildNoOverrideState(t *testing.T) {
t.Log("previous job ID: ", previousWorkspaceBuildJob.ID)
t.Log("current job ID: ", currentWorkspaceBuildJob.ID)

detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh)
detector := unhanger.New(ctx, wrapDBAuthz(db, log), pubsub, log, tickCh).WithStatsChannel(statsCh)
detector.Start()
tickCh <- now

Expand Down Expand Up @@ -406,7 +412,7 @@ func TestDetectorHungWorkspaceBuildNoOverrideStateIfNoExistingBuild(t *testing.T

t.Log("current job ID: ", currentWorkspaceBuildJob.ID)

detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh)
detector := unhanger.New(ctx, wrapDBAuthz(db, log), pubsub, log, tickCh).WithStatsChannel(statsCh)
detector.Start()
tickCh <- now

Expand Down Expand Up @@ -469,29 +475,42 @@ func TestDetectorHungOtherJobTypes(t *testing.T) {
Type: database.ProvisionerJobTypeTemplateVersionImport,
Input: []byte("{}"),
})

// Template dry-run job.
templateDryRunJob = dbgen.ProvisionerJob(t, db, pubsub, database.ProvisionerJob{
CreatedAt: tenMinAgo,
UpdatedAt: sixMinAgo,
StartedAt: sql.NullTime{
Time: tenMinAgo,
Valid: true,
},
_ = dbgen.TemplateVersion(t, db, database.TemplateVersion{
OrganizationID: org.ID,
InitiatorID: user.ID,
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
FileID: file.ID,
Type: database.ProvisionerJobTypeTemplateVersionDryRun,
Input: []byte("{}"),
JobID: templateImportJob.ID,
CreatedBy: user.ID,
})
)

// Template dry-run job.
dryRunVersion := dbgen.TemplateVersion(t, db, database.TemplateVersion{
OrganizationID: org.ID,
CreatedBy: user.ID,
})
input, err := json.Marshal(provisionerdserver.TemplateVersionDryRunJob{
TemplateVersionID: dryRunVersion.ID,
})
require.NoError(t, err)
templateDryRunJob := dbgen.ProvisionerJob(t, db, pubsub, database.ProvisionerJob{
CreatedAt: tenMinAgo,
UpdatedAt: sixMinAgo,
StartedAt: sql.NullTime{
Time: tenMinAgo,
Valid: true,
},
OrganizationID: org.ID,
InitiatorID: user.ID,
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
FileID: file.ID,
Type: database.ProvisionerJobTypeTemplateVersionDryRun,
Input: input,
})

t.Log("template import job ID: ", templateImportJob.ID)
t.Log("template dry-run job ID: ", templateDryRunJob.ID)

detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh)
detector := unhanger.New(ctx, wrapDBAuthz(db, log), pubsub, log, tickCh).WithStatsChannel(statsCh)
detector.Start()
tickCh <- now

Expand Down Expand Up @@ -564,11 +583,16 @@ func TestDetectorHungCanceledJob(t *testing.T) {
Type: database.ProvisionerJobTypeTemplateVersionImport,
Input: []byte("{}"),
})
_ = dbgen.TemplateVersion(t, db, database.TemplateVersion{
OrganizationID: org.ID,
JobID: templateImportJob.ID,
CreatedBy: user.ID,
})
)

t.Log("template import job ID: ", templateImportJob.ID)

detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh)
detector := unhanger.New(ctx, wrapDBAuthz(db, log), pubsub, log, tickCh).WithStatsChannel(statsCh)
detector.Start()
tickCh <- now

Expand Down Expand Up @@ -657,6 +681,11 @@ func TestDetectorPushesLogs(t *testing.T) {
Type: database.ProvisionerJobTypeTemplateVersionImport,
Input: []byte("{}"),
})
_ = dbgen.TemplateVersion(t, db, database.TemplateVersion{
OrganizationID: org.ID,
JobID: templateImportJob.ID,
CreatedBy: user.ID,
})
)

t.Log("template import job ID: ", templateImportJob.ID)
Expand All @@ -678,7 +707,7 @@ func TestDetectorPushesLogs(t *testing.T) {
require.Len(t, logs, 10)
}

detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh)
detector := unhanger.New(ctx, wrapDBAuthz(db, log), pubsub, log, tickCh).WithStatsChannel(statsCh)
detector.Start()

// Create pubsub subscription to listen for new log events.
Expand Down Expand Up @@ -752,7 +781,7 @@ func TestDetectorMaxJobsPerRun(t *testing.T) {
// Create unhanger.MaxJobsPerRun + 1 hung jobs.
now := time.Now()
for i := 0; i < unhanger.MaxJobsPerRun+1; i++ {
dbgen.ProvisionerJob(t, db, pubsub, database.ProvisionerJob{
pj := dbgen.ProvisionerJob(t, db, pubsub, database.ProvisionerJob{
CreatedAt: now.Add(-time.Hour),
UpdatedAt: now.Add(-time.Hour),
StartedAt: sql.NullTime{
Expand All @@ -767,9 +796,14 @@ func TestDetectorMaxJobsPerRun(t *testing.T) {
Type: database.ProvisionerJobTypeTemplateVersionImport,
Input: []byte("{}"),
})
_ = dbgen.TemplateVersion(t, db, database.TemplateVersion{
OrganizationID: org.ID,
JobID: pj.ID,
CreatedBy: user.ID,
})
}

detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh)
detector := unhanger.New(ctx, wrapDBAuthz(db, log), pubsub, log, tickCh).WithStatsChannel(statsCh)
detector.Start()
tickCh <- now

Expand All @@ -788,3 +822,14 @@ func TestDetectorMaxJobsPerRun(t *testing.T) {
detector.Close()
detector.Wait()
}

// wrapDBAuthz adds our Authorization/RBAC around the given database store, to
// ensure the unhanger has the right permissions to do its work.
func wrapDBAuthz(db database.Store, logger slog.Logger) database.Store {
return dbauthz.New(
db,
rbac.NewStrictCachingAuthorizer(prometheus.NewRegistry()),
logger,
coderdtest.AccessControlStorePointer(),
)
}
Loading