Skip to content

Commit 03f1345

Browse files
chore: update dbmem
1 parent 61a9f58 commit 03f1345

File tree

2 files changed

+101
-85
lines changed

2 files changed

+101
-85
lines changed

coderd/database/dbmem/dbmem.go

Lines changed: 101 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -1147,94 +1147,116 @@ func getOwnerFromTags(tags map[string]string) string {
11471147
return ""
11481148
}
11491149

1150-
func (q *FakeQuerier) getProvisionerJobsByIDsWithQueuePositionLocked(_ context.Context, ids []uuid.UUID) ([]database.GetProvisionerJobsByIDsWithQueuePositionRow, error) {
1151-
// WITH pending_jobs AS (
1152-
// SELECT
1153-
// id, created_at
1154-
// FROM
1155-
// provisioner_jobs
1156-
// WHERE
1157-
// started_at IS NULL
1158-
// AND
1159-
// canceled_at IS NULL
1160-
// AND
1161-
// completed_at IS NULL
1162-
// AND
1163-
// error IS NULL
1164-
// ),
1165-
type pendingJobRow struct {
1166-
ID uuid.UUID
1167-
CreatedAt time.Time
1168-
}
1169-
pendingJobs := make([]pendingJobRow, 0)
1170-
for _, job := range q.provisionerJobs {
1171-
if job.StartedAt.Valid ||
1172-
job.CanceledAt.Valid ||
1173-
job.CompletedAt.Valid ||
1174-
job.Error.Valid {
1175-
continue
1150+
// provisionerTagsetContains checks if daemonTags contain all key-value pairs from jobTags
1151+
func provisionerTagsetContains(daemonTags, jobTags map[string]string) bool {
1152+
for jobKey, jobValue := range jobTags {
1153+
if daemonValue, exists := daemonTags[jobKey]; !exists || daemonValue != jobValue {
1154+
return false
11761155
}
1177-
pendingJobs = append(pendingJobs, pendingJobRow{
1178-
ID: job.ID,
1179-
CreatedAt: job.CreatedAt,
1180-
})
11811156
}
1157+
return true
1158+
}
11821159

1183-
// queue_position AS (
1184-
// SELECT
1185-
// id,
1186-
// ROW_NUMBER() OVER (ORDER BY created_at ASC) AS queue_position
1187-
// FROM
1188-
// pending_jobs
1189-
// ),
1190-
slices.SortFunc(pendingJobs, func(a, b pendingJobRow) int {
1191-
c := a.CreatedAt.Compare(b.CreatedAt)
1192-
return c
1193-
})
1160+
// GetProvisionerJobsByIDsWithQueuePosition mimics the SQL logic in pure Go
1161+
func (q *FakeQuerier) getProvisionerJobsByIDsWithQueuePositionLocked(_ context.Context, jobIDs []uuid.UUID) ([]database.GetProvisionerJobsByIDsWithQueuePositionRow, error) {
1162+
// Step 1: Filter provisionerJobs based on jobIDs
1163+
filteredJobs := make(map[uuid.UUID]database.ProvisionerJob)
1164+
for _, job := range q.provisionerJobs {
1165+
for _, id := range jobIDs {
1166+
if job.ID == id {
1167+
filteredJobs[job.ID] = job
1168+
}
1169+
}
1170+
}
11941171

1195-
queuePosition := make(map[uuid.UUID]int64)
1196-
for idx, pj := range pendingJobs {
1197-
queuePosition[pj.ID] = int64(idx + 1)
1198-
}
1199-
1200-
// queue_size AS (
1201-
// SELECT COUNT(*) AS count FROM pending_jobs
1202-
// ),
1203-
queueSize := len(pendingJobs)
1204-
1205-
// SELECT
1206-
// sqlc.embed(pj),
1207-
// COALESCE(qp.queue_position, 0) AS queue_position,
1208-
// COALESCE(qs.count, 0) AS queue_size
1209-
// FROM
1210-
// provisioner_jobs pj
1211-
// LEFT JOIN
1212-
// queue_position qp ON pj.id = qp.id
1213-
// LEFT JOIN
1214-
// queue_size qs ON TRUE
1215-
// WHERE
1216-
// pj.id IN (...)
1217-
jobs := make([]database.GetProvisionerJobsByIDsWithQueuePositionRow, 0)
1172+
// Step 2: Identify pending jobs
1173+
pendingJobs := make(map[uuid.UUID]database.ProvisionerJob)
12181174
for _, job := range q.provisionerJobs {
1219-
if ids != nil && !slices.Contains(ids, job.ID) {
1220-
continue
1175+
if job.JobStatus == "pending" {
1176+
pendingJobs[job.ID] = job
12211177
}
1222-
// clone the Tags before appending, since maps are reference types and
1223-
// we don't want the caller to be able to mutate the map we have inside
1224-
// dbmem!
1225-
job.Tags = maps.Clone(job.Tags)
1226-
job := database.GetProvisionerJobsByIDsWithQueuePositionRow{
1227-
// sqlc.embed(pj),
1228-
ProvisionerJob: job,
1229-
// COALESCE(qp.queue_position, 0) AS queue_position,
1230-
QueuePosition: queuePosition[job.ID],
1231-
// COALESCE(qs.count, 0) AS queue_size
1232-
QueueSize: int64(queueSize),
1178+
}
1179+
1180+
// Step 3: Identify pending jobs that have a matching provisioner
1181+
matchedJobs := make(map[uuid.UUID]struct{})
1182+
for _, job := range pendingJobs {
1183+
for _, daemon := range q.provisionerDaemons {
1184+
if provisionerTagsetContains(daemon.Tags, job.Tags) {
1185+
matchedJobs[job.ID] = struct{}{}
1186+
break
1187+
}
12331188
}
1234-
jobs = append(jobs, job)
12351189
}
12361190

1237-
return jobs, nil
1191+
// Step 4: Rank pending jobs per provisioner
1192+
jobRanks := make(map[uuid.UUID][]database.ProvisionerJob)
1193+
for _, job := range pendingJobs {
1194+
for _, daemon := range q.provisionerDaemons {
1195+
if provisionerTagsetContains(daemon.Tags, job.Tags) {
1196+
jobRanks[daemon.ID] = append(jobRanks[daemon.ID], job)
1197+
}
1198+
}
1199+
}
1200+
1201+
// Sort jobs per provisioner by CreatedAt
1202+
for daemonID := range jobRanks {
1203+
sort.Slice(jobRanks[daemonID], func(i, j int) bool {
1204+
return jobRanks[daemonID][i].CreatedAt.Before(jobRanks[daemonID][j].CreatedAt)
1205+
})
1206+
}
1207+
1208+
// Step 5: Compute queue position & max queue size across all provisioners
1209+
jobQueueStats := make(map[uuid.UUID]database.GetProvisionerJobsByIDsWithQueuePositionRow)
1210+
for _, jobs := range jobRanks {
1211+
queueSize := int64(len(jobs)) // Queue size per provisioner
1212+
for i, job := range jobs {
1213+
queuePosition := int64(i + 1)
1214+
1215+
// If the job already exists, update only if this queuePosition is better
1216+
if existing, exists := jobQueueStats[job.ID]; exists {
1217+
jobQueueStats[job.ID] = database.GetProvisionerJobsByIDsWithQueuePositionRow{
1218+
ID: job.ID,
1219+
CreatedAt: job.CreatedAt,
1220+
ProvisionerJob: job,
1221+
QueuePosition: min(existing.QueuePosition, queuePosition),
1222+
QueueSize: max(existing.QueueSize, queueSize), // Take the maximum queue size across provisioners
1223+
}
1224+
} else {
1225+
jobQueueStats[job.ID] = database.GetProvisionerJobsByIDsWithQueuePositionRow{
1226+
ID: job.ID,
1227+
CreatedAt: job.CreatedAt,
1228+
ProvisionerJob: job,
1229+
QueuePosition: queuePosition,
1230+
QueueSize: queueSize,
1231+
}
1232+
}
1233+
}
1234+
}
1235+
1236+
// Step 6: Compute the final results with minimal checks
1237+
var results []database.GetProvisionerJobsByIDsWithQueuePositionRow
1238+
for _, job := range filteredJobs {
1239+
// If the job has a computed rank, use it
1240+
if rank, found := jobQueueStats[job.ID]; found {
1241+
results = append(results, rank)
1242+
} else {
1243+
// Otherwise, return (0,0) for non-pending jobs and unranked pending jobs
1244+
results = append(results, database.GetProvisionerJobsByIDsWithQueuePositionRow{
1245+
ID: job.ID,
1246+
CreatedAt: job.CreatedAt,
1247+
ProvisionerJob: job,
1248+
QueuePosition: 0,
1249+
QueueSize: 0,
1250+
})
1251+
}
1252+
}
1253+
1254+
// Step 7: Sort results by CreatedAt
1255+
sort.Slice(results, func(i, j int) bool {
1256+
return results[i].CreatedAt.Before(results[j].CreatedAt)
1257+
})
1258+
1259+
return results, nil
12381260
}
12391261

12401262
func (*FakeQuerier) AcquireLock(_ context.Context, _ int64) error {

coderd/database/querier_test.go

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2168,9 +2168,6 @@ func TestExpectOne(t *testing.T) {
21682168

21692169
func TestGetProvisionerJobsByIDsWithQueuePosition(t *testing.T) {
21702170
t.Parallel()
2171-
if !dbtestutil.WillUsePostgres() {
2172-
t.SkipNow()
2173-
}
21742171

21752172
now := dbtime.Now()
21762173
ctx := testutil.Context(t, testutil.WaitShort)
@@ -2613,9 +2610,6 @@ func TestGetProvisionerJobsByIDsWithQueuePosition_MixedStatuses(t *testing.T) {
26132610

26142611
func TestGetProvisionerJobsByIDsWithQueuePosition_OrderValidation(t *testing.T) {
26152612
t.Parallel()
2616-
if !dbtestutil.WillUsePostgres() {
2617-
t.SkipNow()
2618-
}
26192613

26202614
db, _ := dbtestutil.NewDB(t)
26212615
now := dbtime.Now()

0 commit comments

Comments
 (0)