Skip to content

fix(coderd): only allow untagged provisioners to pick up untagged jobs #12269

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 27 additions & 12 deletions coderd/database/dbmem/dbmem.go
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,25 @@ var deletedUserLinkError = &pq.Error{
Routine: "exec_stmt_raise",
}

// m1 and m2 are equal iff |m1| = |m2| ^ m2 ⊆ m1
func tagsEqual(m1, m2 map[string]string) bool {
return len(m1) == len(m2) && tagsSubset(m1, m2)
}

// m2 is a subset of m1 if each key in m1 exists in m2
// with the same value
func tagsSubset(m1, m2 map[string]string) bool {
for k, v1 := range m1 {
if v2, found := m2[k]; !found || v1 != v2 {
return false
}
}
return true
}

// default tags when no tag is specified for a provisioner or job
var tagsUntagged = provisionersdk.MutateTags(uuid.Nil, nil)

func (*FakeQuerier) AcquireLock(_ context.Context, _ int64) error {
return xerrors.New("AcquireLock must only be called within a transaction")
}
Expand Down Expand Up @@ -783,19 +802,15 @@ func (q *FakeQuerier) AcquireProvisionerJob(_ context.Context, arg database.Acqu
}
}

missing := false
for key, value := range provisionerJob.Tags {
provided, found := tags[key]
if !found {
missing = true
break
}
if provided != value {
missing = true
break
}
// Special case for untagged provisioners: only match untagged jobs.
// Ref: coderd/database/queries/provisionerjobs.sql:24-30
// CASE WHEN nested.tags :: jsonb = '{"scope": "organization", "owner": ""}' :: jsonb
// THEN nested.tags :: jsonb = @tags :: jsonb
if tagsEqual(provisionerJob.Tags, tagsUntagged) && !tagsEqual(provisionerJob.Tags, tags) {
continue
}
if missing {
// ELSE nested.tags :: jsonb <@ @tags :: jsonb
if !tagsSubset(provisionerJob.Tags, tags) {
continue
}
provisionerJob.StartedAt = arg.StartedAt
Expand Down
9 changes: 7 additions & 2 deletions coderd/database/queries.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 7 additions & 2 deletions coderd/database/queries/provisionerjobs.sql
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,13 @@ WHERE
nested.started_at IS NULL
-- Ensure the caller has the correct provisioner.
AND nested.provisioner = ANY(@types :: provisioner_type [ ])
-- Ensure the caller satisfies all job tags.
AND nested.tags <@ @tags :: jsonb
AND CASE
-- Special case for untagged provisioners: only match untagged jobs.
WHEN nested.tags :: jsonb = '{"scope": "organization", "owner": ""}' :: jsonb
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When we assign @ tags, will it always contain scope = organization when untagged and used live?

Copy link
Member Author

@johnstcn johnstcn Feb 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We make sure to call provisionersdk.MutateTags before serializing tags, so it will always contain scope and owner keys.

THEN nested.tags :: jsonb = @tags :: jsonb
-- Ensure the caller satisfies all job tags.
ELSE nested.tags :: jsonb <@ @tags :: jsonb
END
ORDER BY
nested.created_at
FOR UPDATE
Expand Down
130 changes: 130 additions & 0 deletions coderd/provisionerdserver/acquirer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"time"

"github.com/google/uuid"
"github.com/sqlc-dev/pqtype"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
Expand All @@ -18,6 +19,8 @@ import (
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbmem"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/database/provisionerjobs"
"github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/coderd/provisionerdserver"
Expand Down Expand Up @@ -315,6 +318,133 @@ func TestAcquirer_UnblockOnCancel(t *testing.T) {
require.Equal(t, jobID, job.ID)
}

func TestAcquirer_MatchTags(t *testing.T) {
t.Parallel()
if testing.Short() {
t.Skip("skipping this test due to -short")
}

someID := uuid.NewString()
someOtherID := uuid.NewString()

for _, tt := range []struct {
name string
provisionerJobTags map[string]string
acquireJobTags map[string]string
expectAcquire bool
}{
{
name: "untagged provisioner and untagged job",
provisionerJobTags: map[string]string{"scope": "organization", "owner": ""},
acquireJobTags: map[string]string{"scope": "organization", "owner": ""},
expectAcquire: true,
},
{
name: "untagged provisioner and tagged job",
provisionerJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar"},
acquireJobTags: map[string]string{"scope": "organization", "owner": ""},
expectAcquire: false,
},
{
name: "tagged provisioner and untagged job",
provisionerJobTags: map[string]string{"scope": "organization", "owner": ""},
acquireJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar"},
expectAcquire: false,
},
Comment on lines +348 to +353
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

review: this is the behaviour that changes

{
name: "tagged provisioner and tagged job",
provisionerJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar"},
acquireJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar"},
expectAcquire: true,
},
{
name: "tagged provisioner and double-tagged job",
provisionerJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar", "baz": "zap"},
acquireJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar"},
expectAcquire: false,
},
{
name: "double-tagged provisioner and tagged job",
provisionerJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar"},
acquireJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar", "baz": "zap"},
expectAcquire: true,
},
{
name: "double-tagged provisioner and double-tagged job",
provisionerJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar", "baz": "zap"},
acquireJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar", "baz": "zap"},
expectAcquire: true,
},
{
name: "owner-scoped provisioner and untagged job",
provisionerJobTags: map[string]string{"scope": "organization", "owner": ""},
acquireJobTags: map[string]string{"scope": "owner", "owner": someID},
expectAcquire: false,
},
{
name: "owner-scoped provisioner and owner-scoped job",
provisionerJobTags: map[string]string{"scope": "owner", "owner": someID},
acquireJobTags: map[string]string{"scope": "owner", "owner": someID},
expectAcquire: true,
},
{
name: "owner-scoped provisioner and different owner-scoped job",
provisionerJobTags: map[string]string{"scope": "owner", "owner": someOtherID},
acquireJobTags: map[string]string{"scope": "owner", "owner": someID},
expectAcquire: false,
},
{
name: "org-scoped provisioner and owner-scoped job",
provisionerJobTags: map[string]string{"scope": "owner", "owner": someID},
acquireJobTags: map[string]string{"scope": "organization", "owner": ""},
expectAcquire: false,
},
} {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

ctx := testutil.Context(t, testutil.WaitShort/2)
// NOTE: explicitly not using fake store for this test.
db, ps := dbtestutil.NewDB(t)
log := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
org, err := db.InsertOrganization(ctx, database.InsertOrganizationParams{
ID: uuid.New(),
Name: "test org",
Description: "the organization of testing",
CreatedAt: dbtime.Now(),
UpdatedAt: dbtime.Now(),
})
require.NoError(t, err)
pj, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
ID: uuid.New(),
CreatedAt: dbtime.Now(),
UpdatedAt: dbtime.Now(),
OrganizationID: org.ID,
InitiatorID: uuid.New(),
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
FileID: uuid.New(),
Type: database.ProvisionerJobTypeWorkspaceBuild,
Input: []byte("{}"),
Tags: tt.provisionerJobTags,
TraceMetadata: pqtype.NullRawMessage{},
})
require.NoError(t, err)
ptypes := []database.ProvisionerType{database.ProvisionerTypeEcho}
acq := provisionerdserver.NewAcquirer(ctx, log, db, ps)
aj, err := acq.AcquireJob(ctx, uuid.New(), ptypes, tt.acquireJobTags)
if tt.expectAcquire {
assert.NoError(t, err)
assert.Equal(t, pj.ID, aj.ID)
} else {
assert.Empty(t, aj, "should not have acquired job")
assert.ErrorIs(t, err, context.DeadlineExceeded, "should have timed out")
}
})
}
}

func postJob(t *testing.T, ps pubsub.Pubsub, pt database.ProvisionerType, tags provisionerdserver.Tags) {
t.Helper()
msg, err := json.Marshal(provisionerjobs.JobPosting{
Expand Down
5 changes: 2 additions & 3 deletions enterprise/cli/provisionerdaemons.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,8 @@ func (r *RootCmd) provisionerDaemonStart() *clibase.Cmd {
defer closeLogger()
}

if len(tags) != 0 {
logger.Info(ctx, "note: tagged provisioners can currently pick up jobs from untagged templates")
logger.Info(ctx, "see https://github.com/coder/coder/issues/6442 for details")
if len(tags) == 0 {
logger.Info(ctx, "note: untagged provisioners can only pick up jobs from untagged templates")
}

// When authorizing with a PSK, we automatically scope the provisionerd
Expand Down