From 6c18525e83ddc4a46992dff29e02bb65c5334370 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Wed, 21 Feb 2024 11:57:49 +0000 Subject: [PATCH 1/3] feat(coderd/database): update AcquireProvisionerJob query to allow specifying exact tag match behaviour --- coderd/database/dbfake/dbfake.go | 5 +- coderd/database/dbgen/dbgen.go | 9 +- coderd/database/dbmem/dbmem.go | 33 ++++--- coderd/database/queries.sql.go | 18 ++-- coderd/database/queries/provisionerjobs.sql | 8 +- coderd/provisionerdserver/acquirer.go | 17 +++- coderd/provisionerdserver/acquirer_test.go | 88 +++++++++++++++++++ .../provisionerdserver_test.go | 3 +- enterprise/coderd/schedule/template_test.go | 5 +- 9 files changed, 154 insertions(+), 32 deletions(-) diff --git a/coderd/database/dbfake/dbfake.go b/coderd/database/dbfake/dbfake.go index ea49c78065657..766c162f69aa0 100644 --- a/coderd/database/dbfake/dbfake.go +++ b/coderd/database/dbfake/dbfake.go @@ -192,8 +192,9 @@ func (b WorkspaceBuildBuilder) Do() WorkspaceResponse { UUID: uuid.New(), Valid: true, }, - Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, - Tags: []byte(`{"scope": "organization"}`), + Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, + Tags: []byte(`{"scope": "organization"}`), + ExactTagMatch: false, }) require.NoError(b.t, err, "acquire starting job") if j.ID == job.ID { diff --git a/coderd/database/dbgen/dbgen.go b/coderd/database/dbgen/dbgen.go index c24f4cb826700..6b07e0c44de19 100644 --- a/coderd/database/dbgen/dbgen.go +++ b/coderd/database/dbgen/dbgen.go @@ -417,10 +417,11 @@ func ProvisionerJob(t testing.TB, db database.Store, ps pubsub.Pubsub, orig data } if !orig.StartedAt.Time.IsZero() { job, err = db.AcquireProvisionerJob(genCtx, database.AcquireProvisionerJobParams{ - StartedAt: orig.StartedAt, - Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, - Tags: must(json.Marshal(orig.Tags)), - WorkerID: uuid.NullUUID{}, + StartedAt: orig.StartedAt, + Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, + Tags: must(json.Marshal(orig.Tags)), + WorkerID: uuid.NullUUID{}, + ExactTagMatch: false, }) require.NoError(t, err) // There is no easy way to make sure we acquire the correct job. diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index 638fbef175636..30641f5ef113f 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -748,6 +748,23 @@ var deletedUserLinkError = &pq.Error{ Routine: "exec_stmt_raise", } +// m1 and m2 are equal if m1 is a subset of m2 +// and m2 is a subset of m1. +func tagsEqual(m1, m2 map[string]string) bool { + return tagsSubset(m1, m2) && tagsSubset(m2, m1) +} + +// m2 is a subset of m1 if each key in m1 exists in m2 +// with the same value +func tagsSubset(m1, m2 map[string]string) bool { + for k, v1 := range m1 { + if v2, found := m2[k]; !found || v1 != v2 { + return false + } + } + return true +} + func (*FakeQuerier) AcquireLock(_ context.Context, _ int64) error { return xerrors.New("AcquireLock must only be called within a transaction") } @@ -783,19 +800,11 @@ func (q *FakeQuerier) AcquireProvisionerJob(_ context.Context, arg database.Acqu } } - missing := false - for key, value := range provisionerJob.Tags { - provided, found := tags[key] - if !found { - missing = true - break - } - if provided != value { - missing = true - break - } + matchFunc := tagsSubset + if arg.ExactTagMatch { + matchFunc = tagsEqual } - if missing { + if !matchFunc(provisionerJob.Tags, tags) { continue } provisionerJob.StartedAt = arg.StartedAt diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index a23c9f97697de..9cbb086cdc391 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -3936,8 +3936,12 @@ WHERE nested.started_at IS NULL -- Ensure the caller has the correct provisioner. AND nested.provisioner = ANY($3 :: provisioner_type [ ]) - -- Ensure the caller satisfies all job tags. - AND nested.tags <@ $4 :: jsonb + -- Ensure the job matches satisfies all requested tags. + AND CASE + WHEN $4 :: boolean THEN nested.tags = $5 :: jsonb + ELSE + nested.tags <@ $5 :: jsonb + END ORDER BY nested.created_at FOR UPDATE @@ -3948,10 +3952,11 @@ WHERE ` type AcquireProvisionerJobParams struct { - StartedAt sql.NullTime `db:"started_at" json:"started_at"` - WorkerID uuid.NullUUID `db:"worker_id" json:"worker_id"` - Types []ProvisionerType `db:"types" json:"types"` - Tags json.RawMessage `db:"tags" json:"tags"` + StartedAt sql.NullTime `db:"started_at" json:"started_at"` + WorkerID uuid.NullUUID `db:"worker_id" json:"worker_id"` + Types []ProvisionerType `db:"types" json:"types"` + ExactTagMatch bool `db:"exact_tag_match" json:"exact_tag_match"` + Tags json.RawMessage `db:"tags" json:"tags"` } // Acquires the lock for a single job that isn't started, completed, @@ -3965,6 +3970,7 @@ func (q *sqlQuerier) AcquireProvisionerJob(ctx context.Context, arg AcquireProvi arg.StartedAt, arg.WorkerID, pq.Array(arg.Types), + arg.ExactTagMatch, arg.Tags, ) var i ProvisionerJob diff --git a/coderd/database/queries/provisionerjobs.sql b/coderd/database/queries/provisionerjobs.sql index b4c113c888dd4..9f6d59bc187c0 100644 --- a/coderd/database/queries/provisionerjobs.sql +++ b/coderd/database/queries/provisionerjobs.sql @@ -21,8 +21,12 @@ WHERE nested.started_at IS NULL -- Ensure the caller has the correct provisioner. AND nested.provisioner = ANY(@types :: provisioner_type [ ]) - -- Ensure the caller satisfies all job tags. - AND nested.tags <@ @tags :: jsonb + -- Ensure the job matches satisfies all requested tags. + AND CASE + WHEN @exact_tag_match :: boolean THEN nested.tags = @tags :: jsonb + ELSE + nested.tags <@ @tags :: jsonb + END ORDER BY nested.created_at FOR UPDATE diff --git a/coderd/provisionerdserver/acquirer.go b/coderd/provisionerdserver/acquirer.go index c9a43d660b671..b805937513aae 100644 --- a/coderd/provisionerdserver/acquirer.go +++ b/coderd/provisionerdserver/acquirer.go @@ -49,6 +49,8 @@ type Acquirer struct { mu sync.Mutex q map[dKey]domain + exactTagMatch bool + // testing only backupPollDuration time.Duration } @@ -61,6 +63,12 @@ func TestingBackupPollDuration(dur time.Duration) AcquirerOption { } } +func WithExactTagMatch() AcquirerOption { + return func(a *Acquirer) { + a.exactTagMatch = true + } +} + // AcquirerStore is the subset of database.Store that the Acquirer needs type AcquirerStore interface { AcquireProvisionerJob(context.Context, database.AcquireProvisionerJobParams) (database.ProvisionerJob, error) @@ -76,6 +84,7 @@ func NewAcquirer(ctx context.Context, logger slog.Logger, store AcquirerStore, p ps: ps, q: make(map[dKey]domain), backupPollDuration: backupPollDuration, + exactTagMatch: false, } for _, opt := range opts { opt(a) @@ -96,7 +105,8 @@ func (a *Acquirer) AcquireJob( logger := a.logger.With( slog.F("worker_id", worker), slog.F("provisioner_types", pt), - slog.F("tags", tags)) + slog.F("tags", tags), + slog.F("exact_tag_match", a.exactTagMatch)) logger.Debug(ctx, "acquiring job") dk := domainKey(pt, tags) dbTags, err := tags.ToJSON() @@ -128,8 +138,9 @@ func (a *Acquirer) AcquireJob( UUID: worker, Valid: true, }, - Types: pt, - Tags: dbTags, + Types: pt, + Tags: dbTags, + ExactTagMatch: a.exactTagMatch, }) if xerrors.Is(err, sql.ErrNoRows) { logger.Debug(ctx, "no job available") diff --git a/coderd/provisionerdserver/acquirer_test.go b/coderd/provisionerdserver/acquirer_test.go index bed8eccb68aca..1a4c418a2a7fd 100644 --- a/coderd/provisionerdserver/acquirer_test.go +++ b/coderd/provisionerdserver/acquirer_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/google/uuid" + "github.com/sqlc-dev/pqtype" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/goleak" @@ -18,6 +19,8 @@ import ( "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbmem" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/database/provisionerjobs" "github.com/coder/coder/v2/coderd/database/pubsub" "github.com/coder/coder/v2/coderd/provisionerdserver" @@ -315,6 +318,91 @@ func TestAcquirer_UnblockOnCancel(t *testing.T) { require.Equal(t, jobID, job.ID) } +func TestAcquirer_ExactTagMatch(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("skipping this test due to -short") + } + + for _, tt := range []struct { + name string + provisionerJobTags map[string]string + acquireJobTags map[string]string + expectAcquire bool + }{ + { + name: "match", + provisionerJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar"}, + acquireJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar"}, + expectAcquire: true, + }, + { + name: "subset", + provisionerJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar"}, + acquireJobTags: map[string]string{"scope": "organization", "owner": ""}, + expectAcquire: false, + }, + { + name: "key mismatch", + provisionerJobTags: map[string]string{"scope": "organization", "owner": "", "fop": "bar"}, + acquireJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar"}, + expectAcquire: false, + }, + { + name: "value mismatch", + provisionerJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "baz"}, + acquireJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar"}, + expectAcquire: false, + }, + } { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + // NOTE: explicitly not using fake store for this test. + db, ps := dbtestutil.NewDB(t) + log := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + org, err := db.InsertOrganization(ctx, database.InsertOrganizationParams{ + ID: uuid.New(), + Name: "test org", + Description: "the organization of testing", + CreatedAt: dbtime.Now(), + UpdatedAt: dbtime.Now(), + }) + require.NoError(t, err) + pj, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ + ID: uuid.New(), + CreatedAt: dbtime.Now(), + UpdatedAt: dbtime.Now(), + OrganizationID: org.ID, + InitiatorID: uuid.New(), + Provisioner: database.ProvisionerTypeEcho, + StorageMethod: database.ProvisionerStorageMethodFile, + FileID: uuid.New(), + Type: database.ProvisionerJobTypeWorkspaceBuild, + Input: []byte("{}"), + Tags: tt.provisionerJobTags, + TraceMetadata: pqtype.NullRawMessage{}, + }) + require.NoError(t, err) + ptypes := []database.ProvisionerType{database.ProvisionerTypeEcho} + opts := []provisionerdserver.AcquirerOption{ + provisionerdserver.WithExactTagMatch(), + } + acq := provisionerdserver.NewAcquirer(ctx, log, db, ps, opts...) + aj, err := acq.AcquireJob(ctx, uuid.New(), ptypes, tt.acquireJobTags) + if tt.expectAcquire { + require.NoError(t, err) + require.Equal(t, pj.ID, aj.ID) + } else { + require.ErrorIs(t, err, context.DeadlineExceeded, "should have timed out") + require.Empty(t, aj, "should not have acquired job") + } + }) + } +} + func postJob(t *testing.T, ps pubsub.Pubsub, pt database.ProvisionerType, tags provisionerdserver.Tags) { t.Helper() msg, err := json.Marshal(provisionerjobs.JobPosting{ diff --git a/coderd/provisionerdserver/provisionerdserver_test.go b/coderd/provisionerdserver/provisionerdserver_test.go index 679717657ccec..bb4f44c8c2dae 100644 --- a/coderd/provisionerdserver/provisionerdserver_test.go +++ b/coderd/provisionerdserver/provisionerdserver_test.go @@ -573,7 +573,8 @@ func TestUpdateJob(t *testing.T) { UUID: srvID, Valid: true, }, - Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, + Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, + ExactTagMatch: false, }) require.NoError(t, err) return job.ID diff --git a/enterprise/coderd/schedule/template_test.go b/enterprise/coderd/schedule/template_test.go index ac158a795b3a3..d0e9bfe114ff3 100644 --- a/enterprise/coderd/schedule/template_test.go +++ b/enterprise/coderd/schedule/template_test.go @@ -181,8 +181,9 @@ func TestTemplateUpdateBuildDeadlines(t *testing.T) { UUID: uuid.New(), Valid: true, }, - Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, - Tags: json.RawMessage(fmt.Sprintf(`{%q: "yeah"}`, c.name)), + Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, + Tags: json.RawMessage(fmt.Sprintf(`{%q: "yeah"}`, c.name)), + ExactTagMatch: false, }) require.NoError(t, err) require.Equal(t, job.ID, acquiredJob.ID) From 53ba28fd11a9e3ac49af3843d24025fd28d99cb7 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Wed, 21 Feb 2024 15:15:42 +0000 Subject: [PATCH 2/3] address PR comments --- coderd/database/dbmem/dbmem.go | 5 ++--- coderd/database/queries.sql.go | 3 ++- coderd/database/queries/provisionerjobs.sql | 3 ++- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index 30641f5ef113f..f457116c32556 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -748,10 +748,9 @@ var deletedUserLinkError = &pq.Error{ Routine: "exec_stmt_raise", } -// m1 and m2 are equal if m1 is a subset of m2 -// and m2 is a subset of m1. +// m1 and m2 are equal iff |m1| = |m2| ^ m1 ⊆ m2 func tagsEqual(m1, m2 map[string]string) bool { - return tagsSubset(m1, m2) && tagsSubset(m2, m1) + return len(m1) == len(m2) && tagsSubset(m1, m2) } // m2 is a subset of m1 if each key in m1 exists in m2 diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 9cbb086cdc391..638513fe3867d 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -3936,9 +3936,10 @@ WHERE nested.started_at IS NULL -- Ensure the caller has the correct provisioner. AND nested.provisioner = ANY($3 :: provisioner_type [ ]) - -- Ensure the job matches satisfies all requested tags. + -- Ensure the caller satisfies all job tags if requested, AND CASE WHEN $4 :: boolean THEN nested.tags = $5 :: jsonb + -- Otherwise, ensure caller satisfies a subset of tags. ELSE nested.tags <@ $5 :: jsonb END diff --git a/coderd/database/queries/provisionerjobs.sql b/coderd/database/queries/provisionerjobs.sql index 9f6d59bc187c0..bfdcca657db87 100644 --- a/coderd/database/queries/provisionerjobs.sql +++ b/coderd/database/queries/provisionerjobs.sql @@ -21,9 +21,10 @@ WHERE nested.started_at IS NULL -- Ensure the caller has the correct provisioner. AND nested.provisioner = ANY(@types :: provisioner_type [ ]) - -- Ensure the job matches satisfies all requested tags. + -- Ensure the caller satisfies all job tags if requested, AND CASE WHEN @exact_tag_match :: boolean THEN nested.tags = @tags :: jsonb + -- Otherwise, ensure caller satisfies a subset of tags. ELSE nested.tags <@ @tags :: jsonb END From 14fd776f46677a42d3896bcf26d02938074c8f63 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Wed, 21 Feb 2024 19:22:17 +0000 Subject: [PATCH 3/3] fixup! address PR comments --- coderd/database/dbmem/dbmem.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index f457116c32556..9811b78bfa9ea 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -748,7 +748,7 @@ var deletedUserLinkError = &pq.Error{ Routine: "exec_stmt_raise", } -// m1 and m2 are equal iff |m1| = |m2| ^ m1 ⊆ m2 +// m1 and m2 are equal iff |m1| = |m2| ^ m2 ⊆ m1 func tagsEqual(m1, m2 map[string]string) bool { return len(m1) == len(m2) && tagsSubset(m1, m2) }