diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 14ad39e114c5b..2486c8e715f3c 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -2610,6 +2610,18 @@ func (q *querier) GetProvisionerJobByIDForUpdate(ctx context.Context, id uuid.UU return job, nil } +func (q *querier) GetProvisionerJobByIDWithLock(ctx context.Context, id uuid.UUID) (database.ProvisionerJob, error) { + job, err := q.db.GetProvisionerJobByIDWithLock(ctx, id) + if err != nil { + return database.ProvisionerJob{}, err + } + + if err := q.authorizeProvisionerJob(ctx, job); err != nil { + return database.ProvisionerJob{}, err + } + return job, nil +} + func (q *querier) GetProvisionerJobTimingsByJobID(ctx context.Context, jobID uuid.UUID) ([]database.ProvisionerJobTiming, error) { _, err := q.GetProvisionerJobByID(ctx, jobID) if err != nil { diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index bc369eed92b6c..a0a3e991e6989 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -664,6 +664,24 @@ func (s *MethodTestSuite) TestProvisionerJob() { dbm.EXPECT().GetProvisionerLogsAfterID(gomock.Any(), arg).Return([]database.ProvisionerJobLog{}, nil).AnyTimes() check.Args(arg).Asserts(ws, policy.ActionRead).Returns([]database.ProvisionerJobLog{}) })) + s.Run("Build/GetProvisionerJobByIDWithLock", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + ws := testutil.Fake(s.T(), faker, database.Workspace{}) + j := testutil.Fake(s.T(), faker, database.ProvisionerJob{Type: database.ProvisionerJobTypeWorkspaceBuild}) + build := testutil.Fake(s.T(), faker, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: j.ID}) + dbm.EXPECT().GetProvisionerJobByIDWithLock(gomock.Any(), j.ID).Return(j, nil).AnyTimes() + dbm.EXPECT().GetWorkspaceBuildByJobID(gomock.Any(), j.ID).Return(build, nil).AnyTimes() + dbm.EXPECT().GetWorkspaceByID(gomock.Any(), build.WorkspaceID).Return(ws, nil).AnyTimes() + check.Args(j.ID).Asserts(ws, policy.ActionRead).Returns(j) + })) + s.Run("TemplateVersion/GetProvisionerJobByIDWithLock", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + tpl := testutil.Fake(s.T(), faker, database.Template{}) + j := testutil.Fake(s.T(), faker, database.ProvisionerJob{Type: database.ProvisionerJobTypeTemplateVersionImport}) + v := testutil.Fake(s.T(), faker, database.TemplateVersion{JobID: j.ID, TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true}}) + dbm.EXPECT().GetProvisionerJobByIDWithLock(gomock.Any(), j.ID).Return(j, nil).AnyTimes() + dbm.EXPECT().GetTemplateVersionByJobID(gomock.Any(), j.ID).Return(v, nil).AnyTimes() + dbm.EXPECT().GetTemplateByID(gomock.Any(), tpl.ID).Return(tpl, nil).AnyTimes() + check.Args(j.ID).Asserts(v.RBACObject(tpl), policy.ActionRead).Returns(j) + })) } func (s *MethodTestSuite) TestLicense() { diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index 6fc79b10480a7..f89e68f02938d 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -1272,6 +1272,13 @@ func (m queryMetricsStore) GetProvisionerJobByIDForUpdate(ctx context.Context, i return r0, r1 } +func (m queryMetricsStore) GetProvisionerJobByIDWithLock(ctx context.Context, id uuid.UUID) (database.ProvisionerJob, error) { + start := time.Now() + r0, r1 := m.s.GetProvisionerJobByIDWithLock(ctx, id) + m.queryLatencies.WithLabelValues("GetProvisionerJobByIDWithLock").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m queryMetricsStore) GetProvisionerJobTimingsByJobID(ctx context.Context, jobID uuid.UUID) ([]database.ProvisionerJobTiming, error) { start := time.Now() r0, r1 := m.s.GetProvisionerJobTimingsByJobID(ctx, jobID) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index cdbcdc7678eb0..ce02050afb2f5 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -2670,6 +2670,21 @@ func (mr *MockStoreMockRecorder) GetProvisionerJobByIDForUpdate(ctx, id any) *go return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProvisionerJobByIDForUpdate", reflect.TypeOf((*MockStore)(nil).GetProvisionerJobByIDForUpdate), ctx, id) } +// GetProvisionerJobByIDWithLock mocks base method. +func (m *MockStore) GetProvisionerJobByIDWithLock(ctx context.Context, id uuid.UUID) (database.ProvisionerJob, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetProvisionerJobByIDWithLock", ctx, id) + ret0, _ := ret[0].(database.ProvisionerJob) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetProvisionerJobByIDWithLock indicates an expected call of GetProvisionerJobByIDWithLock. +func (mr *MockStoreMockRecorder) GetProvisionerJobByIDWithLock(ctx, id any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProvisionerJobByIDWithLock", reflect.TypeOf((*MockStore)(nil).GetProvisionerJobByIDWithLock), ctx, id) +} + // GetProvisionerJobTimingsByJobID mocks base method. func (m *MockStore) GetProvisionerJobTimingsByJobID(ctx context.Context, jobID uuid.UUID) ([]database.ProvisionerJobTiming, error) { m.ctrl.T.Helper() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 1d14130790114..b5ef14f6b86b5 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -296,6 +296,9 @@ type sqlcQuerier interface { // Gets a single provisioner job by ID for update. // This is used to securely reap jobs that have been hung/pending for a long time. GetProvisionerJobByIDForUpdate(ctx context.Context, id uuid.UUID) (ProvisionerJob, error) + // Gets a provisioner job by ID with exclusive lock. + // Blocks until the row is available for update. + GetProvisionerJobByIDWithLock(ctx context.Context, id uuid.UUID) (ProvisionerJob, error) GetProvisionerJobTimingsByJobID(ctx context.Context, jobID uuid.UUID) ([]ProvisionerJobTiming, error) GetProvisionerJobsByIDs(ctx context.Context, ids []uuid.UUID) ([]ProvisionerJob, error) GetProvisionerJobsByIDsWithQueuePosition(ctx context.Context, arg GetProvisionerJobsByIDsWithQueuePositionParams) ([]GetProvisionerJobsByIDsWithQueuePositionRow, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 7050ec11d31e5..af975247f6aa0 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -8875,6 +8875,47 @@ func (q *sqlQuerier) GetProvisionerJobByIDForUpdate(ctx context.Context, id uuid return i, err } +const getProvisionerJobByIDWithLock = `-- name: GetProvisionerJobByIDWithLock :one +SELECT + id, created_at, updated_at, started_at, canceled_at, completed_at, error, organization_id, initiator_id, provisioner, storage_method, type, input, worker_id, file_id, tags, error_code, trace_metadata, job_status, logs_length, logs_overflowed +FROM + provisioner_jobs +WHERE + id = $1 +FOR UPDATE +` + +// Gets a provisioner job by ID with exclusive lock. +// Blocks until the row is available for update. +func (q *sqlQuerier) GetProvisionerJobByIDWithLock(ctx context.Context, id uuid.UUID) (ProvisionerJob, error) { + row := q.db.QueryRowContext(ctx, getProvisionerJobByIDWithLock, id) + var i ProvisionerJob + err := row.Scan( + &i.ID, + &i.CreatedAt, + &i.UpdatedAt, + &i.StartedAt, + &i.CanceledAt, + &i.CompletedAt, + &i.Error, + &i.OrganizationID, + &i.InitiatorID, + &i.Provisioner, + &i.StorageMethod, + &i.Type, + &i.Input, + &i.WorkerID, + &i.FileID, + &i.Tags, + &i.ErrorCode, + &i.TraceMetadata, + &i.JobStatus, + &i.LogsLength, + &i.LogsOverflowed, + ) + return i, err +} + const getProvisionerJobTimingsByJobID = `-- name: GetProvisionerJobTimingsByJobID :many SELECT job_id, started_at, ended_at, stage, source, action, resource FROM provisioner_job_timings WHERE job_id = $1 diff --git a/coderd/database/queries/provisionerjobs.sql b/coderd/database/queries/provisionerjobs.sql index 3ba581646689e..dfc95a0bb4570 100644 --- a/coderd/database/queries/provisionerjobs.sql +++ b/coderd/database/queries/provisionerjobs.sql @@ -55,6 +55,17 @@ WHERE FOR UPDATE SKIP LOCKED; +-- name: GetProvisionerJobByIDWithLock :one +-- Gets a provisioner job by ID with exclusive lock. +-- Blocks until the row is available for update. +SELECT + * +FROM + provisioner_jobs +WHERE + id = $1 +FOR UPDATE; + -- name: GetProvisionerJobsByIDs :many SELECT * diff --git a/coderd/workspacebuilds.go b/coderd/workspacebuilds.go index 2fdb40a1e4661..b6409d8ed781d 100644 --- a/coderd/workspacebuilds.go +++ b/coderd/workspacebuilds.go @@ -663,7 +663,7 @@ func (api *API) patchCancelWorkspaceBuild(rw http.ResponseWriter, r *http.Reques return xerrors.New("user is not allowed to cancel workspace builds") } - job, err := db.GetProvisionerJobByIDForUpdate(ctx, workspaceBuild.JobID) + job, err := db.GetProvisionerJobByIDWithLock(ctx, workspaceBuild.JobID) if err != nil { code = http.StatusInternalServerError resp.Message = "Internal error fetching provisioner job." diff --git a/coderd/workspacebuilds_test.go b/coderd/workspacebuilds_test.go index 994411a8b3817..2c518a95e53a6 100644 --- a/coderd/workspacebuilds_test.go +++ b/coderd/workspacebuilds_test.go @@ -580,7 +580,7 @@ func TestPatchCancelWorkspaceBuild(t *testing.T) { require.Eventually(t, func() bool { err := client.CancelWorkspaceBuild(ctx, build.ID, codersdk.CancelWorkspaceBuildParams{}) - return assert.NoError(t, err) + return err == nil }, testutil.WaitShort, testutil.IntervalMedium) require.Eventually(t, func() bool {