diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index 638fbef175636..fad8441c689df 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -748,6 +748,25 @@ var deletedUserLinkError = &pq.Error{ Routine: "exec_stmt_raise", } +// m1 and m2 are equal iff |m1| = |m2| ^ m2 ⊆ m1 +func tagsEqual(m1, m2 map[string]string) bool { + return len(m1) == len(m2) && tagsSubset(m1, m2) +} + +// m2 is a subset of m1 if each key in m1 exists in m2 +// with the same value +func tagsSubset(m1, m2 map[string]string) bool { + for k, v1 := range m1 { + if v2, found := m2[k]; !found || v1 != v2 { + return false + } + } + return true +} + +// default tags when no tag is specified for a provisioner or job +var tagsUntagged = provisionersdk.MutateTags(uuid.Nil, nil) + func (*FakeQuerier) AcquireLock(_ context.Context, _ int64) error { return xerrors.New("AcquireLock must only be called within a transaction") } @@ -783,19 +802,15 @@ func (q *FakeQuerier) AcquireProvisionerJob(_ context.Context, arg database.Acqu } } - missing := false - for key, value := range provisionerJob.Tags { - provided, found := tags[key] - if !found { - missing = true - break - } - if provided != value { - missing = true - break - } + // Special case for untagged provisioners: only match untagged jobs. + // Ref: coderd/database/queries/provisionerjobs.sql:24-30 + // CASE WHEN nested.tags :: jsonb = '{"scope": "organization", "owner": ""}' :: jsonb + // THEN nested.tags :: jsonb = @tags :: jsonb + if tagsEqual(provisionerJob.Tags, tagsUntagged) && !tagsEqual(provisionerJob.Tags, tags) { + continue } - if missing { + // ELSE nested.tags :: jsonb <@ @tags :: jsonb + if !tagsSubset(provisionerJob.Tags, tags) { continue } provisionerJob.StartedAt = arg.StartedAt diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index a4ce07d09b900..687390ec6111d 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -3936,8 +3936,13 @@ WHERE nested.started_at IS NULL -- Ensure the caller has the correct provisioner. AND nested.provisioner = ANY($3 :: provisioner_type [ ]) - -- Ensure the caller satisfies all job tags. - AND nested.tags <@ $4 :: jsonb + AND CASE + -- Special case for untagged provisioners: only match untagged jobs. + WHEN nested.tags :: jsonb = '{"scope": "organization", "owner": ""}' :: jsonb + THEN nested.tags :: jsonb = $4 :: jsonb + -- Ensure the caller satisfies all job tags. + ELSE nested.tags :: jsonb <@ $4 :: jsonb + END ORDER BY nested.created_at FOR UPDATE diff --git a/coderd/database/queries/provisionerjobs.sql b/coderd/database/queries/provisionerjobs.sql index b4c113c888dd4..1746cd1c59c5d 100644 --- a/coderd/database/queries/provisionerjobs.sql +++ b/coderd/database/queries/provisionerjobs.sql @@ -21,8 +21,13 @@ WHERE nested.started_at IS NULL -- Ensure the caller has the correct provisioner. AND nested.provisioner = ANY(@types :: provisioner_type [ ]) - -- Ensure the caller satisfies all job tags. - AND nested.tags <@ @tags :: jsonb + AND CASE + -- Special case for untagged provisioners: only match untagged jobs. + WHEN nested.tags :: jsonb = '{"scope": "organization", "owner": ""}' :: jsonb + THEN nested.tags :: jsonb = @tags :: jsonb + -- Ensure the caller satisfies all job tags. + ELSE nested.tags :: jsonb <@ @tags :: jsonb + END ORDER BY nested.created_at FOR UPDATE diff --git a/coderd/provisionerdserver/acquirer_test.go b/coderd/provisionerdserver/acquirer_test.go index bed8eccb68aca..088fb68a5acaf 100644 --- a/coderd/provisionerdserver/acquirer_test.go +++ b/coderd/provisionerdserver/acquirer_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/google/uuid" + "github.com/sqlc-dev/pqtype" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/goleak" @@ -18,6 +19,8 @@ import ( "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbmem" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/database/provisionerjobs" "github.com/coder/coder/v2/coderd/database/pubsub" "github.com/coder/coder/v2/coderd/provisionerdserver" @@ -315,6 +318,133 @@ func TestAcquirer_UnblockOnCancel(t *testing.T) { require.Equal(t, jobID, job.ID) } +func TestAcquirer_MatchTags(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("skipping this test due to -short") + } + + someID := uuid.NewString() + someOtherID := uuid.NewString() + + for _, tt := range []struct { + name string + provisionerJobTags map[string]string + acquireJobTags map[string]string + expectAcquire bool + }{ + { + name: "untagged provisioner and untagged job", + provisionerJobTags: map[string]string{"scope": "organization", "owner": ""}, + acquireJobTags: map[string]string{"scope": "organization", "owner": ""}, + expectAcquire: true, + }, + { + name: "untagged provisioner and tagged job", + provisionerJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar"}, + acquireJobTags: map[string]string{"scope": "organization", "owner": ""}, + expectAcquire: false, + }, + { + name: "tagged provisioner and untagged job", + provisionerJobTags: map[string]string{"scope": "organization", "owner": ""}, + acquireJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar"}, + expectAcquire: false, + }, + { + name: "tagged provisioner and tagged job", + provisionerJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar"}, + acquireJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar"}, + expectAcquire: true, + }, + { + name: "tagged provisioner and double-tagged job", + provisionerJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar", "baz": "zap"}, + acquireJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar"}, + expectAcquire: false, + }, + { + name: "double-tagged provisioner and tagged job", + provisionerJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar"}, + acquireJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar", "baz": "zap"}, + expectAcquire: true, + }, + { + name: "double-tagged provisioner and double-tagged job", + provisionerJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar", "baz": "zap"}, + acquireJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar", "baz": "zap"}, + expectAcquire: true, + }, + { + name: "owner-scoped provisioner and untagged job", + provisionerJobTags: map[string]string{"scope": "organization", "owner": ""}, + acquireJobTags: map[string]string{"scope": "owner", "owner": someID}, + expectAcquire: false, + }, + { + name: "owner-scoped provisioner and owner-scoped job", + provisionerJobTags: map[string]string{"scope": "owner", "owner": someID}, + acquireJobTags: map[string]string{"scope": "owner", "owner": someID}, + expectAcquire: true, + }, + { + name: "owner-scoped provisioner and different owner-scoped job", + provisionerJobTags: map[string]string{"scope": "owner", "owner": someOtherID}, + acquireJobTags: map[string]string{"scope": "owner", "owner": someID}, + expectAcquire: false, + }, + { + name: "org-scoped provisioner and owner-scoped job", + provisionerJobTags: map[string]string{"scope": "owner", "owner": someID}, + acquireJobTags: map[string]string{"scope": "organization", "owner": ""}, + expectAcquire: false, + }, + } { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort/2) + // NOTE: explicitly not using fake store for this test. + db, ps := dbtestutil.NewDB(t) + log := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + org, err := db.InsertOrganization(ctx, database.InsertOrganizationParams{ + ID: uuid.New(), + Name: "test org", + Description: "the organization of testing", + CreatedAt: dbtime.Now(), + UpdatedAt: dbtime.Now(), + }) + require.NoError(t, err) + pj, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ + ID: uuid.New(), + CreatedAt: dbtime.Now(), + UpdatedAt: dbtime.Now(), + OrganizationID: org.ID, + InitiatorID: uuid.New(), + Provisioner: database.ProvisionerTypeEcho, + StorageMethod: database.ProvisionerStorageMethodFile, + FileID: uuid.New(), + Type: database.ProvisionerJobTypeWorkspaceBuild, + Input: []byte("{}"), + Tags: tt.provisionerJobTags, + TraceMetadata: pqtype.NullRawMessage{}, + }) + require.NoError(t, err) + ptypes := []database.ProvisionerType{database.ProvisionerTypeEcho} + acq := provisionerdserver.NewAcquirer(ctx, log, db, ps) + aj, err := acq.AcquireJob(ctx, uuid.New(), ptypes, tt.acquireJobTags) + if tt.expectAcquire { + assert.NoError(t, err) + assert.Equal(t, pj.ID, aj.ID) + } else { + assert.Empty(t, aj, "should not have acquired job") + assert.ErrorIs(t, err, context.DeadlineExceeded, "should have timed out") + } + }) + } +} + func postJob(t *testing.T, ps pubsub.Pubsub, pt database.ProvisionerType, tags provisionerdserver.Tags) { t.Helper() msg, err := json.Marshal(provisionerjobs.JobPosting{ diff --git a/enterprise/cli/provisionerdaemons.go b/enterprise/cli/provisionerdaemons.go index b41ec75197aa9..f48b82e239c0f 100644 --- a/enterprise/cli/provisionerdaemons.go +++ b/enterprise/cli/provisionerdaemons.go @@ -117,9 +117,8 @@ func (r *RootCmd) provisionerDaemonStart() *clibase.Cmd { defer closeLogger() } - if len(tags) != 0 { - logger.Info(ctx, "note: tagged provisioners can currently pick up jobs from untagged templates") - logger.Info(ctx, "see https://github.com/coder/coder/issues/6442 for details") + if len(tags) == 0 { + logger.Info(ctx, "note: untagged provisioners can only pick up jobs from untagged templates") } // When authorizing with a PSK, we automatically scope the provisionerd