From 89dff937d9974afcb4eb421657cb6e19bff12228 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Thu, 22 Feb 2024 13:32:49 +0000 Subject: [PATCH 1/4] fix(coderd): only allow untagged provisioners to pick up untagged jobs --- coderd/database/dbmem/dbmem.go | 39 +++++--- coderd/database/queries.sql.go | 9 +- coderd/database/queries/provisionerjobs.sql | 9 +- coderd/provisionerdserver/acquirer_test.go | 103 ++++++++++++++++++++ 4 files changed, 144 insertions(+), 16 deletions(-) diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index 638fbef175636..d4190b94b5150 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -752,6 +752,25 @@ func (*FakeQuerier) AcquireLock(_ context.Context, _ int64) error { return xerrors.New("AcquireLock must only be called within a transaction") } +// m1 and m2 are equal iff |m1| = |m2| ^ m2 ⊆ m1 +func tagsEqual(m1, m2 map[string]string) bool { + return len(m1) == len(m2) && tagsSubset(m1, m2) +} + +// m2 is a subset of m1 if each key in m1 exists in m2 +// with the same value +func tagsSubset(m1, m2 map[string]string) bool { + for k, v1 := range m1 { + if v2, found := m2[k]; !found || v1 != v2 { + return false + } + } + return true +} + +// default tags when no tag is specified for a provisioner or job +var tagsUntagged = provisionersdk.MutateTags(uuid.Nil, nil) + func (q *FakeQuerier) AcquireProvisionerJob(_ context.Context, arg database.AcquireProvisionerJobParams) (database.ProvisionerJob, error) { if err := validateDatabaseType(arg); err != nil { return database.ProvisionerJob{}, err @@ -783,19 +802,15 @@ func (q *FakeQuerier) AcquireProvisionerJob(_ context.Context, arg database.Acqu } } - missing := false - for key, value := range provisionerJob.Tags { - provided, found := tags[key] - if !found { - missing = true - break - } - if provided != value { - missing = true - break - } + // Special case for untagged provisioners: only match untagged jobs. + // Ref: coderd/database/queries/provisionerjobs.sql:24-30 + // CASE WHEN nested.tags :: jsonb = '{"scope": "organization", "owner": ""}' :: jsonb + // THEN nested.tags :: jsonb = @tags :: jsonb + if tagsEqual(provisionerJob.Tags, tagsUntagged) && !tagsEqual(provisionerJob.Tags, tags) { + continue } - if missing { + // ELSE nested.tags :: jsonb <@ @tags :: jsonb + if !tagsSubset(provisionerJob.Tags, tags) { continue } provisionerJob.StartedAt = arg.StartedAt diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index a4ce07d09b900..687390ec6111d 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -3936,8 +3936,13 @@ WHERE nested.started_at IS NULL -- Ensure the caller has the correct provisioner. AND nested.provisioner = ANY($3 :: provisioner_type [ ]) - -- Ensure the caller satisfies all job tags. - AND nested.tags <@ $4 :: jsonb + AND CASE + -- Special case for untagged provisioners: only match untagged jobs. + WHEN nested.tags :: jsonb = '{"scope": "organization", "owner": ""}' :: jsonb + THEN nested.tags :: jsonb = $4 :: jsonb + -- Ensure the caller satisfies all job tags. + ELSE nested.tags :: jsonb <@ $4 :: jsonb + END ORDER BY nested.created_at FOR UPDATE diff --git a/coderd/database/queries/provisionerjobs.sql b/coderd/database/queries/provisionerjobs.sql index b4c113c888dd4..1746cd1c59c5d 100644 --- a/coderd/database/queries/provisionerjobs.sql +++ b/coderd/database/queries/provisionerjobs.sql @@ -21,8 +21,13 @@ WHERE nested.started_at IS NULL -- Ensure the caller has the correct provisioner. AND nested.provisioner = ANY(@types :: provisioner_type [ ]) - -- Ensure the caller satisfies all job tags. - AND nested.tags <@ @tags :: jsonb + AND CASE + -- Special case for untagged provisioners: only match untagged jobs. + WHEN nested.tags :: jsonb = '{"scope": "organization", "owner": ""}' :: jsonb + THEN nested.tags :: jsonb = @tags :: jsonb + -- Ensure the caller satisfies all job tags. + ELSE nested.tags :: jsonb <@ @tags :: jsonb + END ORDER BY nested.created_at FOR UPDATE diff --git a/coderd/provisionerdserver/acquirer_test.go b/coderd/provisionerdserver/acquirer_test.go index bed8eccb68aca..72793ab89dccc 100644 --- a/coderd/provisionerdserver/acquirer_test.go +++ b/coderd/provisionerdserver/acquirer_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/google/uuid" + "github.com/sqlc-dev/pqtype" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/goleak" @@ -18,6 +19,8 @@ import ( "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbmem" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/database/provisionerjobs" "github.com/coder/coder/v2/coderd/database/pubsub" "github.com/coder/coder/v2/coderd/provisionerdserver" @@ -315,6 +318,106 @@ func TestAcquirer_UnblockOnCancel(t *testing.T) { require.Equal(t, jobID, job.ID) } +func TestAcquirer_MatchTags(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("skipping this test due to -short") + } + + for _, tt := range []struct { + name string + provisionerJobTags map[string]string + acquireJobTags map[string]string + expectAcquire bool + }{ + { + name: "untagged provisioner and untagged job", + provisionerJobTags: map[string]string{"scope": "organization", "owner": ""}, + acquireJobTags: map[string]string{"scope": "organization", "owner": ""}, + expectAcquire: true, + }, + { + name: "untagged provisioner and tagged job", + provisionerJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar"}, + acquireJobTags: map[string]string{"scope": "organization", "owner": ""}, + expectAcquire: false, + }, + { + name: "tagged provisioner and untagged job", + provisionerJobTags: map[string]string{"scope": "organization", "owner": ""}, + acquireJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar"}, + expectAcquire: false, + }, + { + name: "tagged provisioner and tagged job", + provisionerJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar"}, + acquireJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar"}, + expectAcquire: true, + }, + { + name: "tagged provisioner and double-tagged job", + provisionerJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar", "baz": "zap"}, + acquireJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar"}, + expectAcquire: false, + }, + { + name: "double-tagged provisioner and tagged job", + provisionerJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar"}, + acquireJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar", "baz": "zap"}, + expectAcquire: true, + }, + { + name: "double-tagged provisioner and double-tagged job", + provisionerJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar", "baz": "zap"}, + acquireJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar", "baz": "zap"}, + expectAcquire: true, + }, + } { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort/2) + // NOTE: explicitly not using fake store for this test. + db, ps := dbtestutil.NewDB(t) + log := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + org, err := db.InsertOrganization(ctx, database.InsertOrganizationParams{ + ID: uuid.New(), + Name: "test org", + Description: "the organization of testing", + CreatedAt: dbtime.Now(), + UpdatedAt: dbtime.Now(), + }) + require.NoError(t, err) + pj, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ + ID: uuid.New(), + CreatedAt: dbtime.Now(), + UpdatedAt: dbtime.Now(), + OrganizationID: org.ID, + InitiatorID: uuid.New(), + Provisioner: database.ProvisionerTypeEcho, + StorageMethod: database.ProvisionerStorageMethodFile, + FileID: uuid.New(), + Type: database.ProvisionerJobTypeWorkspaceBuild, + Input: []byte("{}"), + Tags: tt.provisionerJobTags, + TraceMetadata: pqtype.NullRawMessage{}, + }) + require.NoError(t, err) + ptypes := []database.ProvisionerType{database.ProvisionerTypeEcho} + acq := provisionerdserver.NewAcquirer(ctx, log, db, ps) + aj, err := acq.AcquireJob(ctx, uuid.New(), ptypes, tt.acquireJobTags) + if tt.expectAcquire { + assert.NoError(t, err) + assert.Equal(t, pj.ID, aj.ID) + } else { + assert.Empty(t, aj, "should not have acquired job") + assert.ErrorIs(t, err, context.DeadlineExceeded, "should have timed out") + } + }) + } +} + func postJob(t *testing.T, ps pubsub.Pubsub, pt database.ProvisionerType, tags provisionerdserver.Tags) { t.Helper() msg, err := json.Marshal(provisionerjobs.JobPosting{ From f8f7ce9a1478b7a5e8aa9272061c8eb83569e933 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Thu, 22 Feb 2024 14:02:00 +0000 Subject: [PATCH 2/4] add test cases for owner-scoped jobs --- coderd/provisionerdserver/acquirer_test.go | 27 ++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/coderd/provisionerdserver/acquirer_test.go b/coderd/provisionerdserver/acquirer_test.go index 72793ab89dccc..088fb68a5acaf 100644 --- a/coderd/provisionerdserver/acquirer_test.go +++ b/coderd/provisionerdserver/acquirer_test.go @@ -324,6 +324,9 @@ func TestAcquirer_MatchTags(t *testing.T) { t.Skip("skipping this test due to -short") } + someID := uuid.NewString() + someOtherID := uuid.NewString() + for _, tt := range []struct { name string provisionerJobTags map[string]string @@ -372,6 +375,30 @@ func TestAcquirer_MatchTags(t *testing.T) { acquireJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar", "baz": "zap"}, expectAcquire: true, }, + { + name: "owner-scoped provisioner and untagged job", + provisionerJobTags: map[string]string{"scope": "organization", "owner": ""}, + acquireJobTags: map[string]string{"scope": "owner", "owner": someID}, + expectAcquire: false, + }, + { + name: "owner-scoped provisioner and owner-scoped job", + provisionerJobTags: map[string]string{"scope": "owner", "owner": someID}, + acquireJobTags: map[string]string{"scope": "owner", "owner": someID}, + expectAcquire: true, + }, + { + name: "owner-scoped provisioner and different owner-scoped job", + provisionerJobTags: map[string]string{"scope": "owner", "owner": someOtherID}, + acquireJobTags: map[string]string{"scope": "owner", "owner": someID}, + expectAcquire: false, + }, + { + name: "org-scoped provisioner and owner-scoped job", + provisionerJobTags: map[string]string{"scope": "owner", "owner": someID}, + acquireJobTags: map[string]string{"scope": "organization", "owner": ""}, + expectAcquire: false, + }, } { tt := tt t.Run(tt.name, func(t *testing.T) { From 039f2d21870ab8684a376f094a8303e6bfae41b7 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Thu, 22 Feb 2024 14:04:16 +0000 Subject: [PATCH 3/4] make gen --- coderd/database/dbmem/dbmem.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index d4190b94b5150..fad8441c689df 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -748,10 +748,6 @@ var deletedUserLinkError = &pq.Error{ Routine: "exec_stmt_raise", } -func (*FakeQuerier) AcquireLock(_ context.Context, _ int64) error { - return xerrors.New("AcquireLock must only be called within a transaction") -} - // m1 and m2 are equal iff |m1| = |m2| ^ m2 ⊆ m1 func tagsEqual(m1, m2 map[string]string) bool { return len(m1) == len(m2) && tagsSubset(m1, m2) @@ -771,6 +767,10 @@ func tagsSubset(m1, m2 map[string]string) bool { // default tags when no tag is specified for a provisioner or job var tagsUntagged = provisionersdk.MutateTags(uuid.Nil, nil) +func (*FakeQuerier) AcquireLock(_ context.Context, _ int64) error { + return xerrors.New("AcquireLock must only be called within a transaction") +} + func (q *FakeQuerier) AcquireProvisionerJob(_ context.Context, arg database.AcquireProvisionerJobParams) (database.ProvisionerJob, error) { if err := validateDatabaseType(arg); err != nil { return database.ProvisionerJob{}, err From 71d83b0b28940c348cd6a313a3ebd48e49c73803 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Thu, 22 Feb 2024 14:26:13 +0000 Subject: [PATCH 4/4] update message in provisionerd cli --- enterprise/cli/provisionerdaemons.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/enterprise/cli/provisionerdaemons.go b/enterprise/cli/provisionerdaemons.go index b41ec75197aa9..f48b82e239c0f 100644 --- a/enterprise/cli/provisionerdaemons.go +++ b/enterprise/cli/provisionerdaemons.go @@ -117,9 +117,8 @@ func (r *RootCmd) provisionerDaemonStart() *clibase.Cmd { defer closeLogger() } - if len(tags) != 0 { - logger.Info(ctx, "note: tagged provisioners can currently pick up jobs from untagged templates") - logger.Info(ctx, "see https://github.com/coder/coder/issues/6442 for details") + if len(tags) == 0 { + logger.Info(ctx, "note: untagged provisioners can only pick up jobs from untagged templates") } // When authorizing with a PSK, we automatically scope the provisionerd