Skip to content

Commit 53e8f9c

Browse files
authored
fix(coderd): only allow untagged provisioners to pick up untagged jobs (#12269)
Alternative solution to #6442 Modifies the behaviour of AcquireProvisionerJob and adds a special case for 'un-tagged' jobs such that they can only be picked up by 'un-tagged' provisioners. Also adds comprehensive test coverage for AcquireJob given various combinations of tags.
1 parent aa7a12a commit 53e8f9c

File tree

5 files changed

+173
-19
lines changed

5 files changed

+173
-19
lines changed

coderd/database/dbmem/dbmem.go

+27-12
Original file line numberDiff line numberDiff line change
@@ -748,6 +748,25 @@ var deletedUserLinkError = &pq.Error{
748748
Routine: "exec_stmt_raise",
749749
}
750750

751+
// m1 and m2 are equal iff |m1| = |m2| ^ m2 ⊆ m1
752+
func tagsEqual(m1, m2 map[string]string) bool {
753+
return len(m1) == len(m2) && tagsSubset(m1, m2)
754+
}
755+
756+
// m2 is a subset of m1 if each key in m1 exists in m2
757+
// with the same value
758+
func tagsSubset(m1, m2 map[string]string) bool {
759+
for k, v1 := range m1 {
760+
if v2, found := m2[k]; !found || v1 != v2 {
761+
return false
762+
}
763+
}
764+
return true
765+
}
766+
767+
// default tags when no tag is specified for a provisioner or job
768+
var tagsUntagged = provisionersdk.MutateTags(uuid.Nil, nil)
769+
751770
func (*FakeQuerier) AcquireLock(_ context.Context, _ int64) error {
752771
return xerrors.New("AcquireLock must only be called within a transaction")
753772
}
@@ -783,19 +802,15 @@ func (q *FakeQuerier) AcquireProvisionerJob(_ context.Context, arg database.Acqu
783802
}
784803
}
785804

786-
missing := false
787-
for key, value := range provisionerJob.Tags {
788-
provided, found := tags[key]
789-
if !found {
790-
missing = true
791-
break
792-
}
793-
if provided != value {
794-
missing = true
795-
break
796-
}
805+
// Special case for untagged provisioners: only match untagged jobs.
806+
// Ref: coderd/database/queries/provisionerjobs.sql:24-30
807+
// CASE WHEN nested.tags :: jsonb = '{"scope": "organization", "owner": ""}' :: jsonb
808+
// THEN nested.tags :: jsonb = @tags :: jsonb
809+
if tagsEqual(provisionerJob.Tags, tagsUntagged) && !tagsEqual(provisionerJob.Tags, tags) {
810+
continue
797811
}
798-
if missing {
812+
// ELSE nested.tags :: jsonb <@ @tags :: jsonb
813+
if !tagsSubset(provisionerJob.Tags, tags) {
799814
continue
800815
}
801816
provisionerJob.StartedAt = arg.StartedAt

coderd/database/queries.sql.go

+7-2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries/provisionerjobs.sql

+7-2
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,13 @@ WHERE
2121
nested.started_at IS NULL
2222
-- Ensure the caller has the correct provisioner.
2323
AND nested.provisioner = ANY(@types :: provisioner_type [ ])
24-
-- Ensure the caller satisfies all job tags.
25-
AND nested.tags <@ @tags :: jsonb
24+
AND CASE
25+
-- Special case for untagged provisioners: only match untagged jobs.
26+
WHEN nested.tags :: jsonb = '{"scope": "organization", "owner": ""}' :: jsonb
27+
THEN nested.tags :: jsonb = @tags :: jsonb
28+
-- Ensure the caller satisfies all job tags.
29+
ELSE nested.tags :: jsonb <@ @tags :: jsonb
30+
END
2631
ORDER BY
2732
nested.created_at
2833
FOR UPDATE

coderd/provisionerdserver/acquirer_test.go

+130
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"time"
1010

1111
"github.com/google/uuid"
12+
"github.com/sqlc-dev/pqtype"
1213
"github.com/stretchr/testify/assert"
1314
"github.com/stretchr/testify/require"
1415
"go.uber.org/goleak"
@@ -18,6 +19,8 @@ import (
1819
"cdr.dev/slog/sloggers/slogtest"
1920
"github.com/coder/coder/v2/coderd/database"
2021
"github.com/coder/coder/v2/coderd/database/dbmem"
22+
"github.com/coder/coder/v2/coderd/database/dbtestutil"
23+
"github.com/coder/coder/v2/coderd/database/dbtime"
2124
"github.com/coder/coder/v2/coderd/database/provisionerjobs"
2225
"github.com/coder/coder/v2/coderd/database/pubsub"
2326
"github.com/coder/coder/v2/coderd/provisionerdserver"
@@ -315,6 +318,133 @@ func TestAcquirer_UnblockOnCancel(t *testing.T) {
315318
require.Equal(t, jobID, job.ID)
316319
}
317320

321+
func TestAcquirer_MatchTags(t *testing.T) {
322+
t.Parallel()
323+
if testing.Short() {
324+
t.Skip("skipping this test due to -short")
325+
}
326+
327+
someID := uuid.NewString()
328+
someOtherID := uuid.NewString()
329+
330+
for _, tt := range []struct {
331+
name string
332+
provisionerJobTags map[string]string
333+
acquireJobTags map[string]string
334+
expectAcquire bool
335+
}{
336+
{
337+
name: "untagged provisioner and untagged job",
338+
provisionerJobTags: map[string]string{"scope": "organization", "owner": ""},
339+
acquireJobTags: map[string]string{"scope": "organization", "owner": ""},
340+
expectAcquire: true,
341+
},
342+
{
343+
name: "untagged provisioner and tagged job",
344+
provisionerJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar"},
345+
acquireJobTags: map[string]string{"scope": "organization", "owner": ""},
346+
expectAcquire: false,
347+
},
348+
{
349+
name: "tagged provisioner and untagged job",
350+
provisionerJobTags: map[string]string{"scope": "organization", "owner": ""},
351+
acquireJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar"},
352+
expectAcquire: false,
353+
},
354+
{
355+
name: "tagged provisioner and tagged job",
356+
provisionerJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar"},
357+
acquireJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar"},
358+
expectAcquire: true,
359+
},
360+
{
361+
name: "tagged provisioner and double-tagged job",
362+
provisionerJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar", "baz": "zap"},
363+
acquireJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar"},
364+
expectAcquire: false,
365+
},
366+
{
367+
name: "double-tagged provisioner and tagged job",
368+
provisionerJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar"},
369+
acquireJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar", "baz": "zap"},
370+
expectAcquire: true,
371+
},
372+
{
373+
name: "double-tagged provisioner and double-tagged job",
374+
provisionerJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar", "baz": "zap"},
375+
acquireJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar", "baz": "zap"},
376+
expectAcquire: true,
377+
},
378+
{
379+
name: "owner-scoped provisioner and untagged job",
380+
provisionerJobTags: map[string]string{"scope": "organization", "owner": ""},
381+
acquireJobTags: map[string]string{"scope": "owner", "owner": someID},
382+
expectAcquire: false,
383+
},
384+
{
385+
name: "owner-scoped provisioner and owner-scoped job",
386+
provisionerJobTags: map[string]string{"scope": "owner", "owner": someID},
387+
acquireJobTags: map[string]string{"scope": "owner", "owner": someID},
388+
expectAcquire: true,
389+
},
390+
{
391+
name: "owner-scoped provisioner and different owner-scoped job",
392+
provisionerJobTags: map[string]string{"scope": "owner", "owner": someOtherID},
393+
acquireJobTags: map[string]string{"scope": "owner", "owner": someID},
394+
expectAcquire: false,
395+
},
396+
{
397+
name: "org-scoped provisioner and owner-scoped job",
398+
provisionerJobTags: map[string]string{"scope": "owner", "owner": someID},
399+
acquireJobTags: map[string]string{"scope": "organization", "owner": ""},
400+
expectAcquire: false,
401+
},
402+
} {
403+
tt := tt
404+
t.Run(tt.name, func(t *testing.T) {
405+
t.Parallel()
406+
407+
ctx := testutil.Context(t, testutil.WaitShort/2)
408+
// NOTE: explicitly not using fake store for this test.
409+
db, ps := dbtestutil.NewDB(t)
410+
log := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
411+
org, err := db.InsertOrganization(ctx, database.InsertOrganizationParams{
412+
ID: uuid.New(),
413+
Name: "test org",
414+
Description: "the organization of testing",
415+
CreatedAt: dbtime.Now(),
416+
UpdatedAt: dbtime.Now(),
417+
})
418+
require.NoError(t, err)
419+
pj, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
420+
ID: uuid.New(),
421+
CreatedAt: dbtime.Now(),
422+
UpdatedAt: dbtime.Now(),
423+
OrganizationID: org.ID,
424+
InitiatorID: uuid.New(),
425+
Provisioner: database.ProvisionerTypeEcho,
426+
StorageMethod: database.ProvisionerStorageMethodFile,
427+
FileID: uuid.New(),
428+
Type: database.ProvisionerJobTypeWorkspaceBuild,
429+
Input: []byte("{}"),
430+
Tags: tt.provisionerJobTags,
431+
TraceMetadata: pqtype.NullRawMessage{},
432+
})
433+
require.NoError(t, err)
434+
ptypes := []database.ProvisionerType{database.ProvisionerTypeEcho}
435+
acq := provisionerdserver.NewAcquirer(ctx, log, db, ps)
436+
aj, err := acq.AcquireJob(ctx, uuid.New(), ptypes, tt.acquireJobTags)
437+
if tt.expectAcquire {
438+
assert.NoError(t, err)
439+
assert.Equal(t, pj.ID, aj.ID)
440+
} else {
441+
assert.Empty(t, aj, "should not have acquired job")
442+
assert.ErrorIs(t, err, context.DeadlineExceeded, "should have timed out")
443+
}
444+
})
445+
}
446+
}
447+
318448
func postJob(t *testing.T, ps pubsub.Pubsub, pt database.ProvisionerType, tags provisionerdserver.Tags) {
319449
t.Helper()
320450
msg, err := json.Marshal(provisionerjobs.JobPosting{

enterprise/cli/provisionerdaemons.go

+2-3
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,8 @@ func (r *RootCmd) provisionerDaemonStart() *clibase.Cmd {
117117
defer closeLogger()
118118
}
119119

120-
if len(tags) != 0 {
121-
logger.Info(ctx, "note: tagged provisioners can currently pick up jobs from untagged templates")
122-
logger.Info(ctx, "see https://github.com/coder/coder/issues/6442 for details")
120+
if len(tags) == 0 {
121+
logger.Info(ctx, "note: untagged provisioners can only pick up jobs from untagged templates")
123122
}
124123

125124
// When authorizing with a PSK, we automatically scope the provisionerd

0 commit comments

Comments
 (0)