diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index a210599d17cc4..027f5b78d7a77 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -2341,7 +2341,7 @@ func (q *querier) GetProvisionerJobsByIDs(ctx context.Context, ids []uuid.UUID) return provisionerJobs, nil } -func (q *querier) GetProvisionerJobsByIDsWithQueuePosition(ctx context.Context, ids []uuid.UUID) ([]database.GetProvisionerJobsByIDsWithQueuePositionRow, error) { +func (q *querier) GetProvisionerJobsByIDsWithQueuePosition(ctx context.Context, ids database.GetProvisionerJobsByIDsWithQueuePositionParams) ([]database.GetProvisionerJobsByIDsWithQueuePositionRow, error) { // TODO: Remove this once we have a proper rbac check for provisioner jobs. // Details in https://github.com/coder/coder/issues/16160 return q.db.GetProvisionerJobsByIDsWithQueuePosition(ctx, ids) diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 703e51d739c47..3876ccac55dc6 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -4345,7 +4345,7 @@ func (s *MethodTestSuite) TestSystemFunctions() { check.Args([]uuid.UUID{uuid.New()}).Asserts(rbac.ResourceSystem, policy.ActionRead) })) s.Run("GetProvisionerJobsByIDsWithQueuePosition", s.Subtest(func(db database.Store, check *expects) { - check.Args([]uuid.UUID{}).Asserts() + check.Args(database.GetProvisionerJobsByIDsWithQueuePositionParams{}).Asserts() })) s.Run("GetReplicaByID", s.Subtest(func(db database.Store, check *expects) { check.Args(uuid.New()).Asserts(rbac.ResourceSystem, policy.ActionRead).Errors(sql.ErrNoRows) diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index a13dd33466bf6..7ba2a5a649a0e 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -4684,14 +4684,14 @@ func (q *FakeQuerier) GetProvisionerJobsByIDs(_ context.Context, ids []uuid.UUID return jobs, nil } -func (q *FakeQuerier) GetProvisionerJobsByIDsWithQueuePosition(ctx context.Context, ids []uuid.UUID) ([]database.GetProvisionerJobsByIDsWithQueuePositionRow, error) { +func (q *FakeQuerier) GetProvisionerJobsByIDsWithQueuePosition(ctx context.Context, arg database.GetProvisionerJobsByIDsWithQueuePositionParams) ([]database.GetProvisionerJobsByIDsWithQueuePositionRow, error) { q.mutex.RLock() defer q.mutex.RUnlock() - if ids == nil { - ids = []uuid.UUID{} + if arg.IDs == nil { + arg.IDs = []uuid.UUID{} } - return q.getProvisionerJobsByIDsWithQueuePositionLockedTagBasedQueue(ctx, ids) + return q.getProvisionerJobsByIDsWithQueuePositionLockedTagBasedQueue(ctx, arg.IDs) } func (q *FakeQuerier) GetProvisionerJobsByOrganizationAndStatusWithQueuePositionAndProvisioner(ctx context.Context, arg database.GetProvisionerJobsByOrganizationAndStatusWithQueuePositionAndProvisionerParams) ([]database.GetProvisionerJobsByOrganizationAndStatusWithQueuePositionAndProvisionerRow, error) { diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index e35ec11b02453..6229c74af2abf 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -1215,7 +1215,7 @@ func (m queryMetricsStore) GetProvisionerJobsByIDs(ctx context.Context, ids []uu return jobs, err } -func (m queryMetricsStore) GetProvisionerJobsByIDsWithQueuePosition(ctx context.Context, ids []uuid.UUID) ([]database.GetProvisionerJobsByIDsWithQueuePositionRow, error) { +func (m queryMetricsStore) GetProvisionerJobsByIDsWithQueuePosition(ctx context.Context, ids database.GetProvisionerJobsByIDsWithQueuePositionParams) ([]database.GetProvisionerJobsByIDsWithQueuePositionRow, error) { start := time.Now() r0, r1 := m.s.GetProvisionerJobsByIDsWithQueuePosition(ctx, ids) m.queryLatencies.WithLabelValues("GetProvisionerJobsByIDsWithQueuePosition").Observe(time.Since(start).Seconds()) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 7a1fc0c4b2a6f..bf64897208963 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -2494,18 +2494,18 @@ func (mr *MockStoreMockRecorder) GetProvisionerJobsByIDs(ctx, ids any) *gomock.C } // GetProvisionerJobsByIDsWithQueuePosition mocks base method. -func (m *MockStore) GetProvisionerJobsByIDsWithQueuePosition(ctx context.Context, ids []uuid.UUID) ([]database.GetProvisionerJobsByIDsWithQueuePositionRow, error) { +func (m *MockStore) GetProvisionerJobsByIDsWithQueuePosition(ctx context.Context, arg database.GetProvisionerJobsByIDsWithQueuePositionParams) ([]database.GetProvisionerJobsByIDsWithQueuePositionRow, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetProvisionerJobsByIDsWithQueuePosition", ctx, ids) + ret := m.ctrl.Call(m, "GetProvisionerJobsByIDsWithQueuePosition", ctx, arg) ret0, _ := ret[0].([]database.GetProvisionerJobsByIDsWithQueuePositionRow) ret1, _ := ret[1].(error) return ret0, ret1 } // GetProvisionerJobsByIDsWithQueuePosition indicates an expected call of GetProvisionerJobsByIDsWithQueuePosition. -func (mr *MockStoreMockRecorder) GetProvisionerJobsByIDsWithQueuePosition(ctx, ids any) *gomock.Call { +func (mr *MockStoreMockRecorder) GetProvisionerJobsByIDsWithQueuePosition(ctx, arg any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProvisionerJobsByIDsWithQueuePosition", reflect.TypeOf((*MockStore)(nil).GetProvisionerJobsByIDsWithQueuePosition), ctx, ids) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProvisionerJobsByIDsWithQueuePosition", reflect.TypeOf((*MockStore)(nil).GetProvisionerJobsByIDsWithQueuePosition), ctx, arg) } // GetProvisionerJobsByOrganizationAndStatusWithQueuePositionAndProvisioner mocks base method. diff --git a/coderd/database/querier.go b/coderd/database/querier.go index ac7497b641a05..c24016867b31a 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -278,7 +278,7 @@ type sqlcQuerier interface { GetProvisionerJobByIDForUpdate(ctx context.Context, id uuid.UUID) (ProvisionerJob, error) GetProvisionerJobTimingsByJobID(ctx context.Context, jobID uuid.UUID) ([]ProvisionerJobTiming, error) GetProvisionerJobsByIDs(ctx context.Context, ids []uuid.UUID) ([]ProvisionerJob, error) - GetProvisionerJobsByIDsWithQueuePosition(ctx context.Context, ids []uuid.UUID) ([]GetProvisionerJobsByIDsWithQueuePositionRow, error) + GetProvisionerJobsByIDsWithQueuePosition(ctx context.Context, arg GetProvisionerJobsByIDsWithQueuePositionParams) ([]GetProvisionerJobsByIDsWithQueuePositionRow, error) GetProvisionerJobsByOrganizationAndStatusWithQueuePositionAndProvisioner(ctx context.Context, arg GetProvisionerJobsByOrganizationAndStatusWithQueuePositionAndProvisionerParams) ([]GetProvisionerJobsByOrganizationAndStatusWithQueuePositionAndProvisionerRow, error) GetProvisionerJobsCreatedAfter(ctx context.Context, createdAt time.Time) ([]ProvisionerJob, error) // To avoid repeatedly attempting to reap the same jobs, we randomly order and limit to @max_jobs. diff --git a/coderd/database/querier_test.go b/coderd/database/querier_test.go index 5bafa58796b7a..b5ed8b019c1cb 100644 --- a/coderd/database/querier_test.go +++ b/coderd/database/querier_test.go @@ -15,7 +15,6 @@ import ( "github.com/stretchr/testify/require" "cdr.dev/slog/sloggers/slogtest" - "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/db2sdk" @@ -27,6 +26,7 @@ import ( "github.com/coder/coder/v2/coderd/database/migrations" "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/coderd/prebuilds" + "github.com/coder/coder/v2/coderd/provisionerdserver" "github.com/coder/coder/v2/coderd/rbac" "github.com/coder/coder/v2/coderd/rbac/policy" "github.com/coder/coder/v2/provisionersdk" @@ -1268,7 +1268,10 @@ func TestQueuePosition(t *testing.T) { Tags: database.StringMap{}, }) - queued, err := db.GetProvisionerJobsByIDsWithQueuePosition(ctx, jobIDs) + queued, err := db.GetProvisionerJobsByIDsWithQueuePosition(ctx, database.GetProvisionerJobsByIDsWithQueuePositionParams{ + IDs: jobIDs, + StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(), + }) require.NoError(t, err) require.Len(t, queued, jobCount) sort.Slice(queued, func(i, j int) bool { @@ -1296,7 +1299,10 @@ func TestQueuePosition(t *testing.T) { require.NoError(t, err) require.Equal(t, jobs[0].ID, job.ID) - queued, err = db.GetProvisionerJobsByIDsWithQueuePosition(ctx, jobIDs) + queued, err = db.GetProvisionerJobsByIDsWithQueuePosition(ctx, database.GetProvisionerJobsByIDsWithQueuePositionParams{ + IDs: jobIDs, + StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(), + }) require.NoError(t, err) require.Len(t, queued, jobCount) sort.Slice(queued, func(i, j int) bool { @@ -2550,7 +2556,10 @@ func TestGetProvisionerJobsByIDsWithQueuePosition(t *testing.T) { } // When: we fetch the jobs by their IDs - actualJobs, err := db.GetProvisionerJobsByIDsWithQueuePosition(ctx, filteredJobIDs) + actualJobs, err := db.GetProvisionerJobsByIDsWithQueuePosition(ctx, database.GetProvisionerJobsByIDsWithQueuePositionParams{ + IDs: filteredJobIDs, + StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(), + }) require.NoError(t, err) require.Len(t, actualJobs, len(filteredJobs), "should return all unskipped jobs") @@ -2693,7 +2702,10 @@ func TestGetProvisionerJobsByIDsWithQueuePosition_MixedStatuses(t *testing.T) { } // When: we fetch the jobs by their IDs - actualJobs, err := db.GetProvisionerJobsByIDsWithQueuePosition(ctx, jobIDs) + actualJobs, err := db.GetProvisionerJobsByIDsWithQueuePosition(ctx, database.GetProvisionerJobsByIDsWithQueuePositionParams{ + IDs: jobIDs, + StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(), + }) require.NoError(t, err) require.Len(t, actualJobs, len(allJobs), "should return all jobs") @@ -2788,7 +2800,10 @@ func TestGetProvisionerJobsByIDsWithQueuePosition_OrderValidation(t *testing.T) } // When: we fetch the jobs by their IDs - actualJobs, err := db.GetProvisionerJobsByIDsWithQueuePosition(ctx, jobIDs) + actualJobs, err := db.GetProvisionerJobsByIDsWithQueuePosition(ctx, database.GetProvisionerJobsByIDsWithQueuePositionParams{ + IDs: jobIDs, + StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(), + }) require.NoError(t, err) require.Len(t, actualJobs, len(allJobs), "should return all jobs") diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index c166dd5fed89a..0d620706d011c 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -7663,17 +7663,21 @@ pending_jobs AS ( WHERE job_status = 'pending' ), +online_provisioner_daemons AS ( + SELECT id, tags FROM provisioner_daemons pd + WHERE pd.last_seen_at IS NOT NULL AND pd.last_seen_at >= (NOW() - ($2::bigint || ' ms')::interval) +), ranked_jobs AS ( -- Step 3: Rank only pending jobs based on provisioner availability SELECT pj.id, pj.created_at, - ROW_NUMBER() OVER (PARTITION BY pd.id ORDER BY pj.created_at ASC) AS queue_position, - COUNT(*) OVER (PARTITION BY pd.id) AS queue_size + ROW_NUMBER() OVER (PARTITION BY opd.id ORDER BY pj.created_at ASC) AS queue_position, + COUNT(*) OVER (PARTITION BY opd.id) AS queue_size FROM pending_jobs pj - INNER JOIN provisioner_daemons pd - ON provisioner_tagset_contains(pd.tags, pj.tags) -- Join only on the small pending set + INNER JOIN online_provisioner_daemons opd + ON provisioner_tagset_contains(opd.tags, pj.tags) -- Join only on the small pending set ), final_jobs AS ( -- Step 4: Compute best queue position and max queue size per job @@ -7705,6 +7709,11 @@ ORDER BY fj.created_at ` +type GetProvisionerJobsByIDsWithQueuePositionParams struct { + IDs []uuid.UUID `db:"ids" json:"ids"` + StaleIntervalMS int64 `db:"stale_interval_ms" json:"stale_interval_ms"` +} + type GetProvisionerJobsByIDsWithQueuePositionRow struct { ID uuid.UUID `db:"id" json:"id"` CreatedAt time.Time `db:"created_at" json:"created_at"` @@ -7713,8 +7722,8 @@ type GetProvisionerJobsByIDsWithQueuePositionRow struct { QueueSize int64 `db:"queue_size" json:"queue_size"` } -func (q *sqlQuerier) GetProvisionerJobsByIDsWithQueuePosition(ctx context.Context, ids []uuid.UUID) ([]GetProvisionerJobsByIDsWithQueuePositionRow, error) { - rows, err := q.db.QueryContext(ctx, getProvisionerJobsByIDsWithQueuePosition, pq.Array(ids)) +func (q *sqlQuerier) GetProvisionerJobsByIDsWithQueuePosition(ctx context.Context, arg GetProvisionerJobsByIDsWithQueuePositionParams) ([]GetProvisionerJobsByIDsWithQueuePositionRow, error) { + rows, err := q.db.QueryContext(ctx, getProvisionerJobsByIDsWithQueuePosition, pq.Array(arg.IDs), arg.StaleIntervalMS) if err != nil { return nil, err } diff --git a/coderd/database/queries/provisionerjobs.sql b/coderd/database/queries/provisionerjobs.sql index 88bacc705601c..f3902ba2ddd38 100644 --- a/coderd/database/queries/provisionerjobs.sql +++ b/coderd/database/queries/provisionerjobs.sql @@ -80,17 +80,21 @@ pending_jobs AS ( WHERE job_status = 'pending' ), +online_provisioner_daemons AS ( + SELECT id, tags FROM provisioner_daemons pd + WHERE pd.last_seen_at IS NOT NULL AND pd.last_seen_at >= (NOW() - (@stale_interval_ms::bigint || ' ms')::interval) +), ranked_jobs AS ( -- Step 3: Rank only pending jobs based on provisioner availability SELECT pj.id, pj.created_at, - ROW_NUMBER() OVER (PARTITION BY pd.id ORDER BY pj.created_at ASC) AS queue_position, - COUNT(*) OVER (PARTITION BY pd.id) AS queue_size + ROW_NUMBER() OVER (PARTITION BY opd.id ORDER BY pj.created_at ASC) AS queue_position, + COUNT(*) OVER (PARTITION BY opd.id) AS queue_size FROM pending_jobs pj - INNER JOIN provisioner_daemons pd - ON provisioner_tagset_contains(pd.tags, pj.tags) -- Join only on the small pending set + INNER JOIN online_provisioner_daemons opd + ON provisioner_tagset_contains(opd.tags, pj.tags) -- Join only on the small pending set ), final_jobs AS ( -- Step 4: Compute best queue position and max queue size per job diff --git a/coderd/templateversions.go b/coderd/templateversions.go index 7b682eac14ea0..8dd523374d69f 100644 --- a/coderd/templateversions.go +++ b/coderd/templateversions.go @@ -53,7 +53,10 @@ func (api *API) templateVersion(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() templateVersion := httpmw.TemplateVersionParam(r) - jobs, err := api.Database.GetProvisionerJobsByIDsWithQueuePosition(ctx, []uuid.UUID{templateVersion.JobID}) + jobs, err := api.Database.GetProvisionerJobsByIDsWithQueuePosition(ctx, database.GetProvisionerJobsByIDsWithQueuePositionParams{ + IDs: []uuid.UUID{templateVersion.JobID}, + StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(), + }) if err != nil || len(jobs) == 0 { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching provisioner job.", @@ -182,7 +185,10 @@ func (api *API) patchTemplateVersion(rw http.ResponseWriter, r *http.Request) { return } - jobs, err := api.Database.GetProvisionerJobsByIDsWithQueuePosition(ctx, []uuid.UUID{templateVersion.JobID}) + jobs, err := api.Database.GetProvisionerJobsByIDsWithQueuePosition(ctx, database.GetProvisionerJobsByIDsWithQueuePositionParams{ + IDs: []uuid.UUID{templateVersion.JobID}, + StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(), + }) if err != nil || len(jobs) == 0 { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching provisioner job.", @@ -733,7 +739,10 @@ func (api *API) fetchTemplateVersionDryRunJob(rw http.ResponseWriter, r *http.Re return database.GetProvisionerJobsByIDsWithQueuePositionRow{}, false } - jobs, err := api.Database.GetProvisionerJobsByIDsWithQueuePosition(ctx, []uuid.UUID{jobUUID}) + jobs, err := api.Database.GetProvisionerJobsByIDsWithQueuePosition(ctx, database.GetProvisionerJobsByIDsWithQueuePositionParams{ + IDs: []uuid.UUID{jobUUID}, + StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(), + }) if httpapi.Is404Error(err) { httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{ Message: fmt.Sprintf("Provisioner job %q not found.", jobUUID), @@ -865,7 +874,10 @@ func (api *API) templateVersionsByTemplate(rw http.ResponseWriter, r *http.Reque for _, version := range versions { jobIDs = append(jobIDs, version.JobID) } - jobs, err := store.GetProvisionerJobsByIDsWithQueuePosition(ctx, jobIDs) + jobs, err := store.GetProvisionerJobsByIDsWithQueuePosition(ctx, database.GetProvisionerJobsByIDsWithQueuePositionParams{ + IDs: jobIDs, + StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(), + }) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching provisioner job.", @@ -933,7 +945,10 @@ func (api *API) templateVersionByName(rw http.ResponseWriter, r *http.Request) { }) return } - jobs, err := api.Database.GetProvisionerJobsByIDsWithQueuePosition(ctx, []uuid.UUID{templateVersion.JobID}) + jobs, err := api.Database.GetProvisionerJobsByIDsWithQueuePosition(ctx, database.GetProvisionerJobsByIDsWithQueuePositionParams{ + IDs: []uuid.UUID{templateVersion.JobID}, + StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(), + }) if err != nil || len(jobs) == 0 { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching provisioner job.", @@ -1013,7 +1028,10 @@ func (api *API) templateVersionByOrganizationTemplateAndName(rw http.ResponseWri }) return } - jobs, err := api.Database.GetProvisionerJobsByIDsWithQueuePosition(ctx, []uuid.UUID{templateVersion.JobID}) + jobs, err := api.Database.GetProvisionerJobsByIDsWithQueuePosition(ctx, database.GetProvisionerJobsByIDsWithQueuePositionParams{ + IDs: []uuid.UUID{templateVersion.JobID}, + StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(), + }) if err != nil || len(jobs) == 0 { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching provisioner job.", @@ -1115,7 +1133,10 @@ func (api *API) previousTemplateVersionByOrganizationTemplateAndName(rw http.Res return } - jobs, err := api.Database.GetProvisionerJobsByIDsWithQueuePosition(ctx, []uuid.UUID{previousTemplateVersion.JobID}) + jobs, err := api.Database.GetProvisionerJobsByIDsWithQueuePosition(ctx, database.GetProvisionerJobsByIDsWithQueuePositionParams{ + IDs: []uuid.UUID{previousTemplateVersion.JobID}, + StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(), + }) if err != nil || len(jobs) == 0 { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching provisioner job.", diff --git a/coderd/workspacebuilds.go b/coderd/workspacebuilds.go index 1fd0c95ff3a77..1d14c4518602c 100644 --- a/coderd/workspacebuilds.go +++ b/coderd/workspacebuilds.go @@ -797,7 +797,10 @@ func (api *API) workspaceBuildsData(ctx context.Context, workspaceBuilds []datab for _, build := range workspaceBuilds { jobIDs = append(jobIDs, build.JobID) } - jobs, err := api.Database.GetProvisionerJobsByIDsWithQueuePosition(ctx, jobIDs) + jobs, err := api.Database.GetProvisionerJobsByIDsWithQueuePosition(ctx, database.GetProvisionerJobsByIDsWithQueuePositionParams{ + IDs: jobIDs, + StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(), + }) if err != nil && !errors.Is(err, sql.ErrNoRows) { return workspaceBuildsData{}, xerrors.Errorf("get provisioner jobs: %w", err) }