diff --git a/coderd/apidoc/docs.go b/coderd/apidoc/docs.go index 6c770c18232ac..d03c8b9be989d 100644 --- a/coderd/apidoc/docs.go +++ b/coderd/apidoc/docs.go @@ -2941,6 +2941,12 @@ const docTemplate = `{ "name": "organization", "in": "path", "required": true + }, + { + "type": "object", + "description": "Provisioner tags to filter by (JSON of the form {'tag1':'value1','tag2':'value2'})", + "name": "tags", + "in": "query" } ], "responses": { diff --git a/coderd/apidoc/swagger.json b/coderd/apidoc/swagger.json index 4f5ca444f703e..b740ea82e8abb 100644 --- a/coderd/apidoc/swagger.json +++ b/coderd/apidoc/swagger.json @@ -2579,6 +2579,12 @@ "name": "organization", "in": "path", "required": true + }, + { + "type": "object", + "description": "Provisioner tags to filter by (JSON of the form {'tag1':'value1','tag2':'value2'})", + "name": "tags", + "in": "query" } ], "responses": { diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index c855d5a1984df..35a052586fbaf 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -1888,7 +1888,7 @@ func (q *querier) GetProvisionerDaemons(ctx context.Context) ([]database.Provisi return fetchWithPostFilter(q.auth, policy.ActionRead, fetch)(ctx, nil) } -func (q *querier) GetProvisionerDaemonsByOrganization(ctx context.Context, organizationID uuid.UUID) ([]database.ProvisionerDaemon, error) { +func (q *querier) GetProvisionerDaemonsByOrganization(ctx context.Context, organizationID database.GetProvisionerDaemonsByOrganizationParams) ([]database.ProvisionerDaemon, error) { return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetProvisionerDaemonsByOrganization)(ctx, organizationID) } diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index b610efe0349f5..413d79a84a400 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -2066,9 +2066,9 @@ func (s *MethodTestSuite) TestExtraMethods() { }), }) s.NoError(err, "insert provisioner daemon") - ds, err := db.GetProvisionerDaemonsByOrganization(context.Background(), org.ID) + ds, err := db.GetProvisionerDaemonsByOrganization(context.Background(), database.GetProvisionerDaemonsByOrganizationParams{OrganizationID: org.ID}) s.NoError(err, "get provisioner daemon by org") - check.Args(org.ID).Asserts(d, policy.ActionRead).Returns(ds) + check.Args(database.GetProvisionerDaemonsByOrganizationParams{OrganizationID: org.ID}).Asserts(d, policy.ActionRead).Returns(ds) })) s.Run("DeleteOldProvisionerDaemons", s.Subtest(func(db database.Store, check *expects) { _, err := db.UpsertProvisionerDaemon(context.Background(), database.UpsertProvisionerDaemonParams{ @@ -2560,7 +2560,7 @@ func (s *MethodTestSuite) TestSystemFunctions() { j := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{ StartedAt: sql.NullTime{Valid: false}, }) - check.Args(database.AcquireProvisionerJobParams{OrganizationID: j.OrganizationID, Types: []database.ProvisionerType{j.Provisioner}, Tags: must(json.Marshal(j.Tags))}). + check.Args(database.AcquireProvisionerJobParams{OrganizationID: j.OrganizationID, Types: []database.ProvisionerType{j.Provisioner}, ProvisionerTags: must(json.Marshal(j.Tags))}). Asserts( /*rbac.ResourceSystem, policy.ActionUpdate*/ ) })) s.Run("UpdateProvisionerJobWithCompleteByID", s.Subtest(func(db database.Store, check *expects) { diff --git a/coderd/database/dbfake/dbfake.go b/coderd/database/dbfake/dbfake.go index ca514479cab6a..9c5a09f40ff65 100644 --- a/coderd/database/dbfake/dbfake.go +++ b/coderd/database/dbfake/dbfake.go @@ -194,8 +194,8 @@ func (b WorkspaceBuildBuilder) Do() WorkspaceResponse { UUID: uuid.New(), Valid: true, }, - Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, - Tags: []byte(`{"scope": "organization"}`), + Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, + ProvisionerTags: []byte(`{"scope": "organization"}`), }) require.NoError(b.t, err, "acquire starting job") if j.ID == job.ID { diff --git a/coderd/database/dbgen/dbgen.go b/coderd/database/dbgen/dbgen.go index 5e83125a93b84..793795af7762a 100644 --- a/coderd/database/dbgen/dbgen.go +++ b/coderd/database/dbgen/dbgen.go @@ -531,11 +531,11 @@ func ProvisionerJob(t testing.TB, db database.Store, ps pubsub.Pubsub, orig data } if !orig.StartedAt.Time.IsZero() { job, err = db.AcquireProvisionerJob(genCtx, database.AcquireProvisionerJobParams{ - StartedAt: orig.StartedAt, - OrganizationID: job.OrganizationID, - Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, - Tags: must(json.Marshal(orig.Tags)), - WorkerID: uuid.NullUUID{}, + StartedAt: orig.StartedAt, + OrganizationID: job.OrganizationID, + Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, + ProvisionerTags: must(json.Marshal(orig.Tags)), + WorkerID: uuid.NullUUID{}, }) require.NoError(t, err) // There is no easy way to make sure we acquire the correct job. diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index 9a306db09785e..002170da508c8 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -1194,8 +1194,8 @@ func (q *FakeQuerier) AcquireProvisionerJob(_ context.Context, arg database.Acqu continue } tags := map[string]string{} - if arg.Tags != nil { - err := json.Unmarshal(arg.Tags, &tags) + if arg.ProvisionerTags != nil { + err := json.Unmarshal(arg.ProvisionerTags, &tags) if err != nil { return provisionerJob, xerrors.Errorf("unmarshal: %w", err) } @@ -3625,16 +3625,28 @@ func (q *FakeQuerier) GetProvisionerDaemons(_ context.Context) ([]database.Provi return out, nil } -func (q *FakeQuerier) GetProvisionerDaemonsByOrganization(_ context.Context, organizationID uuid.UUID) ([]database.ProvisionerDaemon, error) { +func (q *FakeQuerier) GetProvisionerDaemonsByOrganization(_ context.Context, arg database.GetProvisionerDaemonsByOrganizationParams) ([]database.ProvisionerDaemon, error) { q.mutex.RLock() defer q.mutex.RUnlock() daemons := make([]database.ProvisionerDaemon, 0) for _, daemon := range q.provisionerDaemons { - if daemon.OrganizationID == organizationID { - daemon.Tags = maps.Clone(daemon.Tags) - daemons = append(daemons, daemon) + if daemon.OrganizationID != arg.OrganizationID { + continue + } + // Special case for untagged provisioners: only match untagged jobs. + // Ref: coderd/database/queries/provisionerjobs.sql:24-30 + // CASE WHEN nested.tags :: jsonb = '{"scope": "organization", "owner": ""}' :: jsonb + // THEN nested.tags :: jsonb = @tags :: jsonb + if tagsEqual(arg.WantTags, tagsUntagged) && !tagsEqual(arg.WantTags, daemon.Tags) { + continue + } + // ELSE nested.tags :: jsonb <@ @tags :: jsonb + if !tagsSubset(arg.WantTags, daemon.Tags) { + continue } + daemon.Tags = maps.Clone(daemon.Tags) + daemons = append(daemons, daemon) } return daemons, nil diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index cee25e482bbaa..4592100992352 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -959,9 +959,9 @@ func (m queryMetricsStore) GetProvisionerDaemons(ctx context.Context) ([]databas return daemons, err } -func (m queryMetricsStore) GetProvisionerDaemonsByOrganization(ctx context.Context, organizationID uuid.UUID) ([]database.ProvisionerDaemon, error) { +func (m queryMetricsStore) GetProvisionerDaemonsByOrganization(ctx context.Context, arg database.GetProvisionerDaemonsByOrganizationParams) ([]database.ProvisionerDaemon, error) { start := time.Now() - r0, r1 := m.s.GetProvisionerDaemonsByOrganization(ctx, organizationID) + r0, r1 := m.s.GetProvisionerDaemonsByOrganization(ctx, arg) m.queryLatencies.WithLabelValues("GetProvisionerDaemonsByOrganization").Observe(time.Since(start).Seconds()) return r0, r1 } diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index d8721f56d3f4e..1b13c8c25a1e1 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -1973,7 +1973,7 @@ func (mr *MockStoreMockRecorder) GetProvisionerDaemons(arg0 any) *gomock.Call { } // GetProvisionerDaemonsByOrganization mocks base method. -func (m *MockStore) GetProvisionerDaemonsByOrganization(arg0 context.Context, arg1 uuid.UUID) ([]database.ProvisionerDaemon, error) { +func (m *MockStore) GetProvisionerDaemonsByOrganization(arg0 context.Context, arg1 database.GetProvisionerDaemonsByOrganizationParams) ([]database.ProvisionerDaemon, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetProvisionerDaemonsByOrganization", arg0, arg1) ret0, _ := ret[0].([]database.ProvisionerDaemon) diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 557b5c2dd9325..fa0c4b6845230 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -198,6 +198,10 @@ CREATE TYPE startup_script_behavior AS ENUM ( 'non-blocking' ); +CREATE DOMAIN tagset AS jsonb; + +COMMENT ON DOMAIN tagset IS 'A set of tags that match provisioner daemons to provisioner jobs, which can originate from workspaces or templates. tagset is a narrowed type over jsonb. It is expected to be the JSON representation of map[string]string. That is, {"key1": "value1", "key2": "value2"}. We need the narrowed type instead of just using jsonb so that we can give sqlc a type hint, otherwise it defaults to json.RawMessage. json.RawMessage is a suboptimal type to use in the context that we need tagset for.'; + CREATE TYPE tailnet_status AS ENUM ( 'ok', 'lost' @@ -376,6 +380,21 @@ BEGIN END; $$; +CREATE FUNCTION provisioner_tagset_contains(provisioner_tags tagset, job_tags tagset) RETURNS boolean + LANGUAGE plpgsql + AS $$ +BEGIN + RETURN CASE + -- Special case for untagged provisioners, where only an exact match should count + WHEN job_tags::jsonb = '{"scope": "organization", "owner": ""}'::jsonb THEN job_tags::jsonb = provisioner_tags::jsonb + -- General case + ELSE job_tags::jsonb <@ provisioner_tags::jsonb + END; +END; +$$; + +COMMENT ON FUNCTION provisioner_tagset_contains(provisioner_tags tagset, job_tags tagset) IS 'Returns true if the provisioner_tags contains the job_tags, or if the job_tags represents an untagged provisioner and the superset is exactly equal to the subset.'; + CREATE FUNCTION remove_organization_member_role() RETURNS trigger LANGUAGE plpgsql AS $$ diff --git a/coderd/database/migrations/000274_check_tags.down.sql b/coderd/database/migrations/000274_check_tags.down.sql new file mode 100644 index 0000000000000..623a3e9dac6e5 --- /dev/null +++ b/coderd/database/migrations/000274_check_tags.down.sql @@ -0,0 +1,3 @@ +DROP FUNCTION IF EXISTS provisioner_tagset_contains(tagset, tagset); + +DROP DOMAIN IF EXISTS tagset; diff --git a/coderd/database/migrations/000274_check_tags.up.sql b/coderd/database/migrations/000274_check_tags.up.sql new file mode 100644 index 0000000000000..b897e5e8ea124 --- /dev/null +++ b/coderd/database/migrations/000274_check_tags.up.sql @@ -0,0 +1,17 @@ +CREATE DOMAIN tagset AS jsonb; + +COMMENT ON DOMAIN tagset IS 'A set of tags that match provisioner daemons to provisioner jobs, which can originate from workspaces or templates. tagset is a narrowed type over jsonb. It is expected to be the JSON representation of map[string]string. That is, {"key1": "value1", "key2": "value2"}. We need the narrowed type instead of just using jsonb so that we can give sqlc a type hint, otherwise it defaults to json.RawMessage. json.RawMessage is a suboptimal type to use in the context that we need tagset for.'; + +CREATE OR REPLACE FUNCTION provisioner_tagset_contains(provisioner_tags tagset, job_tags tagset) +RETURNS boolean AS $$ +BEGIN + RETURN CASE + -- Special case for untagged provisioners, where only an exact match should count + WHEN job_tags::jsonb = '{"scope": "organization", "owner": ""}'::jsonb THEN job_tags::jsonb = provisioner_tags::jsonb + -- General case + ELSE job_tags::jsonb <@ provisioner_tags::jsonb + END; +END; +$$ LANGUAGE plpgsql; + +COMMENT ON FUNCTION provisioner_tagset_contains(tagset, tagset) IS 'Returns true if the provisioner_tags contains the job_tags, or if the job_tags represents an untagged provisioner and the superset is exactly equal to the subset.'; diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 46d1b1ae5b322..c7bf9e837c7b8 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -196,7 +196,7 @@ type sqlcQuerier interface { GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]ParameterSchema, error) GetPreviousTemplateVersion(ctx context.Context, arg GetPreviousTemplateVersionParams) (TemplateVersion, error) GetProvisionerDaemons(ctx context.Context) ([]ProvisionerDaemon, error) - GetProvisionerDaemonsByOrganization(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerDaemon, error) + GetProvisionerDaemonsByOrganization(ctx context.Context, arg GetProvisionerDaemonsByOrganizationParams) ([]ProvisionerDaemon, error) GetProvisionerJobByID(ctx context.Context, id uuid.UUID) (ProvisionerJob, error) GetProvisionerJobTimingsByJobID(ctx context.Context, jobID uuid.UUID) ([]ProvisionerJobTiming, error) GetProvisionerJobsByIDs(ctx context.Context, ids []uuid.UUID) ([]ProvisionerJob, error) diff --git a/coderd/database/querier_test.go b/coderd/database/querier_test.go index b3fde5a558e6b..619e9868b612f 100644 --- a/coderd/database/querier_test.go +++ b/coderd/database/querier_test.go @@ -1020,7 +1020,7 @@ func TestQueuePosition(t *testing.T) { UUID: uuid.New(), Valid: true, }, - Tags: json.RawMessage("{}"), + ProvisionerTags: json.RawMessage("{}"), }) require.NoError(t, err) require.Equal(t, jobs[0].ID, job.ID) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 87d3c17f5400f..564246e32638c 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -5269,11 +5269,20 @@ SELECT FROM provisioner_daemons WHERE - organization_id = $1 + -- This is the original search criteria: + organization_id = $1 :: uuid + AND + -- adding support for searching by tags: + ($2 :: tagset = 'null' :: tagset OR provisioner_tagset_contains(provisioner_daemons.tags::tagset, $2::tagset)) ` -func (q *sqlQuerier) GetProvisionerDaemonsByOrganization(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerDaemon, error) { - rows, err := q.db.QueryContext(ctx, getProvisionerDaemonsByOrganization, organizationID) +type GetProvisionerDaemonsByOrganizationParams struct { + OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"` + WantTags StringMap `db:"want_tags" json:"want_tags"` +} + +func (q *sqlQuerier) GetProvisionerDaemonsByOrganization(ctx context.Context, arg GetProvisionerDaemonsByOrganizationParams) ([]ProvisionerDaemon, error) { + rows, err := q.db.QueryContext(ctx, getProvisionerDaemonsByOrganization, arg.OrganizationID, arg.WantTags) if err != nil { return nil, err } @@ -5523,21 +5532,17 @@ WHERE SELECT id FROM - provisioner_jobs AS nested + provisioner_jobs AS potential_job WHERE - nested.started_at IS NULL - AND nested.organization_id = $3 + potential_job.started_at IS NULL + AND potential_job.organization_id = $3 -- Ensure the caller has the correct provisioner. - AND nested.provisioner = ANY($4 :: provisioner_type [ ]) - AND CASE - -- Special case for untagged provisioners: only match untagged jobs. - WHEN nested.tags :: jsonb = '{"scope": "organization", "owner": ""}' :: jsonb - THEN nested.tags :: jsonb = $5 :: jsonb - -- Ensure the caller satisfies all job tags. - ELSE nested.tags :: jsonb <@ $5 :: jsonb - END + AND potential_job.provisioner = ANY($4 :: provisioner_type [ ]) + -- elsewhere, we use the tagset type, but here we use jsonb for backward compatibility + -- they are aliases and the code that calls this query already relies on a different type + AND provisioner_tagset_contains($5 :: jsonb, potential_job.tags :: jsonb) ORDER BY - nested.created_at + potential_job.created_at FOR UPDATE SKIP LOCKED LIMIT @@ -5546,11 +5551,11 @@ WHERE ` type AcquireProvisionerJobParams struct { - StartedAt sql.NullTime `db:"started_at" json:"started_at"` - WorkerID uuid.NullUUID `db:"worker_id" json:"worker_id"` - OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"` - Types []ProvisionerType `db:"types" json:"types"` - Tags json.RawMessage `db:"tags" json:"tags"` + StartedAt sql.NullTime `db:"started_at" json:"started_at"` + WorkerID uuid.NullUUID `db:"worker_id" json:"worker_id"` + OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"` + Types []ProvisionerType `db:"types" json:"types"` + ProvisionerTags json.RawMessage `db:"provisioner_tags" json:"provisioner_tags"` } // Acquires the lock for a single job that isn't started, completed, @@ -5565,7 +5570,7 @@ func (q *sqlQuerier) AcquireProvisionerJob(ctx context.Context, arg AcquireProvi arg.WorkerID, arg.OrganizationID, pq.Array(arg.Types), - arg.Tags, + arg.ProvisionerTags, ) var i ProvisionerJob err := row.Scan( diff --git a/coderd/database/queries/provisionerdaemons.sql b/coderd/database/queries/provisionerdaemons.sql index bee1c6e92ff4b..a6633c91158a9 100644 --- a/coderd/database/queries/provisionerdaemons.sql +++ b/coderd/database/queries/provisionerdaemons.sql @@ -10,7 +10,11 @@ SELECT FROM provisioner_daemons WHERE - organization_id = @organization_id; + -- This is the original search criteria: + organization_id = @organization_id :: uuid + AND + -- adding support for searching by tags: + (@want_tags :: tagset = 'null' :: tagset OR provisioner_tagset_contains(provisioner_daemons.tags::tagset, @want_tags::tagset)); -- name: DeleteOldProvisionerDaemons :exec -- Delete provisioner daemons that have been created at least a week ago diff --git a/coderd/database/queries/provisionerjobs.sql b/coderd/database/queries/provisionerjobs.sql index 95a84fcd3c824..95e8a88b84e6d 100644 --- a/coderd/database/queries/provisionerjobs.sql +++ b/coderd/database/queries/provisionerjobs.sql @@ -16,21 +16,17 @@ WHERE SELECT id FROM - provisioner_jobs AS nested + provisioner_jobs AS potential_job WHERE - nested.started_at IS NULL - AND nested.organization_id = @organization_id + potential_job.started_at IS NULL + AND potential_job.organization_id = @organization_id -- Ensure the caller has the correct provisioner. - AND nested.provisioner = ANY(@types :: provisioner_type [ ]) - AND CASE - -- Special case for untagged provisioners: only match untagged jobs. - WHEN nested.tags :: jsonb = '{"scope": "organization", "owner": ""}' :: jsonb - THEN nested.tags :: jsonb = @tags :: jsonb - -- Ensure the caller satisfies all job tags. - ELSE nested.tags :: jsonb <@ @tags :: jsonb - END + AND potential_job.provisioner = ANY(@types :: provisioner_type [ ]) + -- elsewhere, we use the tagset type, but here we use jsonb for backward compatibility + -- they are aliases and the code that calls this query already relies on a different type + AND provisioner_tagset_contains(@provisioner_tags :: jsonb, potential_job.tags :: jsonb) ORDER BY - nested.created_at + potential_job.created_at FOR UPDATE SKIP LOCKED LIMIT @@ -160,4 +156,4 @@ RETURNING *; -- name: GetProvisionerJobTimingsByJobID :many SELECT * FROM provisioner_job_timings WHERE job_id = $1 -ORDER BY started_at ASC; \ No newline at end of file +ORDER BY started_at ASC; diff --git a/coderd/database/sqlc.yaml b/coderd/database/sqlc.yaml index 2161feb47e1c3..6a9a66ee45a9b 100644 --- a/coderd/database/sqlc.yaml +++ b/coderd/database/sqlc.yaml @@ -35,6 +35,9 @@ sql: - db_type: "name_organization_pair" go_type: type: "NameOrganizationPair" + - db_type: "tagset" + go_type: + type: "StringMap" - column: "custom_roles.site_permissions" go_type: type: "CustomRolePermissions" diff --git a/coderd/provisionerdserver/acquirer.go b/coderd/provisionerdserver/acquirer.go index 36e0d51df44f8..4c2fe6b1d49a9 100644 --- a/coderd/provisionerdserver/acquirer.go +++ b/coderd/provisionerdserver/acquirer.go @@ -130,8 +130,8 @@ func (a *Acquirer) AcquireJob( UUID: worker, Valid: true, }, - Types: pt, - Tags: dbTags, + Types: pt, + ProvisionerTags: dbTags, }) if xerrors.Is(err, sql.ErrNoRows) { logger.Debug(ctx, "no job available") diff --git a/coderd/provisionerdserver/acquirer_test.go b/coderd/provisionerdserver/acquirer_test.go index a916cb68fba1f..6f0face1ebb4c 100644 --- a/coderd/provisionerdserver/acquirer_test.go +++ b/coderd/provisionerdserver/acquirer_test.go @@ -649,7 +649,7 @@ func (s *fakeTaggedStore) AcquireProvisionerJob( ) { defer func() { s.params <- params }() var tags provisionerdserver.Tags - err := json.Unmarshal(params.Tags, &tags) + err := json.Unmarshal(params.ProvisionerTags, &tags) if !assert.NoError(s.t, err) { return database.ProvisionerJob{}, err } diff --git a/codersdk/organizations.go b/codersdk/organizations.go index 77e24a2be3e10..4966b7a41809c 100644 --- a/codersdk/organizations.go +++ b/codersdk/organizations.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "net/http" + "net/url" "strings" "time" @@ -314,11 +315,21 @@ func (c *Client) ProvisionerDaemons(ctx context.Context) ([]ProvisionerDaemon, e return daemons, json.NewDecoder(res.Body).Decode(&daemons) } -func (c *Client) OrganizationProvisionerDaemons(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerDaemon, error) { - res, err := c.Request(ctx, http.MethodGet, - fmt.Sprintf("/api/v2/organizations/%s/provisionerdaemons", organizationID.String()), - nil, - ) +func (c *Client) OrganizationProvisionerDaemons(ctx context.Context, organizationID uuid.UUID, tags map[string]string) ([]ProvisionerDaemon, error) { + baseURL := fmt.Sprintf("/api/v2/organizations/%s/provisionerdaemons", organizationID.String()) + + queryParams := url.Values{} + tagsJSON, err := json.Marshal(tags) + if err != nil { + return nil, xerrors.Errorf("marshal tags: %w", err) + } + + queryParams.Add("tags", string(tagsJSON)) + if len(queryParams) > 0 { + baseURL = fmt.Sprintf("%s?%s", baseURL, queryParams.Encode()) + } + + res, err := c.Request(ctx, http.MethodGet, baseURL, nil) if err != nil { return nil, xerrors.Errorf("execute request: %w", err) } diff --git a/docs/reference/api/enterprise.md b/docs/reference/api/enterprise.md index 57ffa5260edde..b8764fd89b449 100644 --- a/docs/reference/api/enterprise.md +++ b/docs/reference/api/enterprise.md @@ -1480,9 +1480,10 @@ curl -X GET http://coder-server:8080/api/v2/organizations/{organization}/provisi ### Parameters -| Name | In | Type | Required | Description | -| -------------- | ---- | ------------ | -------- | --------------- | -| `organization` | path | string(uuid) | true | Organization ID | +| Name | In | Type | Required | Description | +| -------------- | ----- | ------------ | -------- | ---------------------------------------------------------------------------------- | +| `organization` | path | string(uuid) | true | Organization ID | +| `tags` | query | object | false | Provisioner tags to filter by (JSON of the form {'tag1':'value1','tag2':'value2'}) | ### Example responses diff --git a/enterprise/cli/provisionerdaemonstart_test.go b/enterprise/cli/provisionerdaemonstart_test.go index 3132e80a4c68e..763ac49b92996 100644 --- a/enterprise/cli/provisionerdaemonstart_test.go +++ b/enterprise/cli/provisionerdaemonstart_test.go @@ -236,7 +236,7 @@ func TestProvisionerDaemon_SessionToken(t *testing.T) { var daemons []codersdk.ProvisionerDaemon var err error require.Eventually(t, func() bool { - daemons, err = client.OrganizationProvisionerDaemons(ctx, anotherOrg.ID) + daemons, err = client.OrganizationProvisionerDaemons(ctx, anotherOrg.ID, nil) if err != nil { return false } @@ -282,7 +282,7 @@ func TestProvisionerDaemon_ProvisionerKey(t *testing.T) { var daemons []codersdk.ProvisionerDaemon require.Eventually(t, func() bool { - daemons, err = client.OrganizationProvisionerDaemons(ctx, user.OrganizationID) + daemons, err = client.OrganizationProvisionerDaemons(ctx, user.OrganizationID, nil) if err != nil { return false } @@ -376,7 +376,7 @@ func TestProvisionerDaemon_ProvisionerKey(t *testing.T) { var daemons []codersdk.ProvisionerDaemon require.Eventually(t, func() bool { - daemons, err = client.OrganizationProvisionerDaemons(ctx, anotherOrg.ID) + daemons, err = client.OrganizationProvisionerDaemons(ctx, anotherOrg.ID, nil) if err != nil { return false } diff --git a/enterprise/coderd/provisionerdaemons.go b/enterprise/coderd/provisionerdaemons.go index 6f8cd1a3ec0b6..0eb3a51db57dd 100644 --- a/enterprise/coderd/provisionerdaemons.go +++ b/enterprise/coderd/provisionerdaemons.go @@ -56,13 +56,33 @@ func (api *API) provisionerDaemonsEnabledMW(next http.Handler) http.Handler { // @Produce json // @Tags Enterprise // @Param organization path string true "Organization ID" format(uuid) +// @Param tags query object false "Provisioner tags to filter by (JSON of the form {'tag1':'value1','tag2':'value2'})" // @Success 200 {array} codersdk.ProvisionerDaemon // @Router /organizations/{organization}/provisionerdaemons [get] func (api *API) provisionerDaemons(rw http.ResponseWriter, r *http.Request) { - ctx := r.Context() - org := httpmw.OrganizationParam(r) + var ( + ctx = r.Context() + org = httpmw.OrganizationParam(r) + tagParam = r.URL.Query().Get("tags") + tags = database.StringMap{} + err = tags.Scan([]byte(tagParam)) + ) + + if tagParam != "" && err != nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid tags query parameter", + Detail: err.Error(), + }) + return + } - daemons, err := api.Database.GetProvisionerDaemonsByOrganization(ctx, org.ID) + daemons, err := api.Database.GetProvisionerDaemonsByOrganization( + ctx, + database.GetProvisionerDaemonsByOrganizationParams{ + OrganizationID: org.ID, + WantTags: tags, + }, + ) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching provisioner daemons.", diff --git a/enterprise/coderd/provisionerdaemons_test.go b/enterprise/coderd/provisionerdaemons_test.go index 09a4de5806e88..d8d770097c156 100644 --- a/enterprise/coderd/provisionerdaemons_test.go +++ b/enterprise/coderd/provisionerdaemons_test.go @@ -3,10 +3,12 @@ package coderd_test import ( "bytes" "context" + "database/sql" "fmt" "io" "net/http" "testing" + "time" "github.com/google/uuid" "github.com/stretchr/testify/assert" @@ -18,6 +20,7 @@ import ( "github.com/coder/coder/v2/buildinfo" "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/db2sdk" "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/provisionerkey" "github.com/coder/coder/v2/coderd/rbac" @@ -768,7 +771,7 @@ func TestGetProvisionerDaemons(t *testing.T) { require.NoError(t, err) srv.DRPCConn().Close() - daemons, err := orgAdmin.OrganizationProvisionerDaemons(ctx, org.ID) + daemons, err := orgAdmin.OrganizationProvisionerDaemons(ctx, org.ID, nil) require.NoError(t, err) require.Len(t, daemons, 1) @@ -794,4 +797,207 @@ func TestGetProvisionerDaemons(t *testing.T) { _, err = outsideOrg.ListProvisionerKeyDaemons(ctx, org.ID) require.Error(t, err) }) + + t.Run("filtered by tags", func(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + tagsToFilterBy map[string]string + provisionerDaemonTags map[string]string + expectToGetDaemon bool + }{ + { + name: "only an empty tagset finds an untagged provisioner", + tagsToFilterBy: map[string]string{"scope": "organization", "owner": ""}, + provisionerDaemonTags: map[string]string{"scope": "organization", "owner": ""}, + expectToGetDaemon: true, + }, + { + name: "an exact match with a single optional tag finds a provisioner daemon", + tagsToFilterBy: map[string]string{"scope": "organization", "owner": "", "environment": "on-prem"}, + provisionerDaemonTags: map[string]string{"scope": "organization", "owner": "", "environment": "on-prem"}, + expectToGetDaemon: true, + }, + { + name: "a subset of filter tags finds a daemon with a superset of tags", + tagsToFilterBy: map[string]string{"scope": "organization", "owner": "", "environment": "on-prem"}, + provisionerDaemonTags: map[string]string{"scope": "organization", "owner": "", "environment": "on-prem", "datacenter": "chicago"}, + expectToGetDaemon: true, + }, + { + name: "an exact match with two additional tags finds a provisioner daemon", + tagsToFilterBy: map[string]string{"scope": "organization", "owner": "", "environment": "on-prem", "datacenter": "chicago"}, + provisionerDaemonTags: map[string]string{"scope": "organization", "owner": "", "environment": "on-prem", "datacenter": "chicago"}, + expectToGetDaemon: true, + }, + { + name: "a user scoped filter tag set finds a user scoped provisioner daemon", + tagsToFilterBy: map[string]string{"scope": "user", "owner": "aaa"}, + provisionerDaemonTags: map[string]string{"scope": "user", "owner": "aaa"}, + expectToGetDaemon: true, + }, + { + name: "a user scoped filter tag set finds a user scoped provisioner daemon with an additional tag", + tagsToFilterBy: map[string]string{"scope": "user", "owner": "aaa"}, + provisionerDaemonTags: map[string]string{"scope": "user", "owner": "aaa", "environment": "on-prem"}, + expectToGetDaemon: true, + }, + { + name: "user-scoped provisioner with tags and user-scoped filter with tags", + tagsToFilterBy: map[string]string{"scope": "user", "owner": "aaa", "environment": "on-prem"}, + provisionerDaemonTags: map[string]string{"scope": "user", "owner": "aaa", "environment": "on-prem"}, + expectToGetDaemon: true, + }, + { + name: "user-scoped provisioner with multiple tags and user-scoped filter with a subset of tags", + tagsToFilterBy: map[string]string{"scope": "user", "owner": "aaa", "environment": "on-prem"}, + provisionerDaemonTags: map[string]string{"scope": "user", "owner": "aaa", "environment": "on-prem", "datacenter": "chicago"}, + expectToGetDaemon: true, + }, + { + name: "user-scoped provisioner with multiple tags and user-scoped filter with multiple tags", + tagsToFilterBy: map[string]string{"scope": "user", "owner": "aaa", "environment": "on-prem", "datacenter": "chicago"}, + provisionerDaemonTags: map[string]string{"scope": "user", "owner": "aaa", "environment": "on-prem", "datacenter": "chicago"}, + expectToGetDaemon: true, + }, + { + name: "untagged provisioner and tagged filter", + tagsToFilterBy: map[string]string{"scope": "organization", "owner": "", "environment": "on-prem"}, + provisionerDaemonTags: map[string]string{"scope": "organization", "owner": ""}, + expectToGetDaemon: false, + }, + { + name: "tagged provisioner and untagged filter", + tagsToFilterBy: map[string]string{"scope": "organization", "owner": ""}, + provisionerDaemonTags: map[string]string{"scope": "organization", "owner": "", "environment": "on-prem"}, + expectToGetDaemon: false, + }, + { + name: "tagged provisioner and double-tagged filter", + tagsToFilterBy: map[string]string{"scope": "organization", "owner": "", "environment": "on-prem", "datacenter": "chicago"}, + provisionerDaemonTags: map[string]string{"scope": "organization", "owner": "", "environment": "on-prem"}, + expectToGetDaemon: false, + }, + { + name: "double-tagged provisioner and double-tagged filter with differing tags", + tagsToFilterBy: map[string]string{"scope": "organization", "owner": "", "environment": "on-prem", "datacenter": "chicago"}, + provisionerDaemonTags: map[string]string{"scope": "organization", "owner": "", "environment": "on-prem", "datacenter": "new_york"}, + expectToGetDaemon: false, + }, + { + name: "user-scoped provisioner and untagged filter", + tagsToFilterBy: map[string]string{"scope": "organization", "owner": ""}, + provisionerDaemonTags: map[string]string{"scope": "user", "owner": "aaa"}, + expectToGetDaemon: false, + }, + { + name: "user-scoped provisioner and different user-scoped filter", + tagsToFilterBy: map[string]string{"scope": "user", "owner": "bbb"}, + provisionerDaemonTags: map[string]string{"scope": "user", "owner": "aaa"}, + expectToGetDaemon: false, + }, + { + name: "org-scoped provisioner and user-scoped filter", + tagsToFilterBy: map[string]string{"scope": "user", "owner": "aaa"}, + provisionerDaemonTags: map[string]string{"scope": "organization", "owner": ""}, + expectToGetDaemon: false, + }, + { + name: "user-scoped provisioner and org-scoped filter with tags", + tagsToFilterBy: map[string]string{"scope": "user", "owner": "aaa", "environment": "on-prem"}, + provisionerDaemonTags: map[string]string{"scope": "organization", "owner": ""}, + expectToGetDaemon: false, + }, + { + name: "user-scoped provisioner and user-scoped filter with tags", + tagsToFilterBy: map[string]string{"scope": "user", "owner": "aaa", "environment": "on-prem"}, + provisionerDaemonTags: map[string]string{"scope": "user", "owner": "aaa"}, + expectToGetDaemon: false, + }, + { + name: "user-scoped provisioner with tags and user-scoped filter with multiple tags", + tagsToFilterBy: map[string]string{"scope": "user", "owner": "aaa", "environment": "on-prem", "datacenter": "chicago"}, + provisionerDaemonTags: map[string]string{"scope": "user", "owner": "aaa", "environment": "on-prem"}, + expectToGetDaemon: false, + }, + { + name: "user-scoped provisioner with tags and user-scoped filter with differing tags", + tagsToFilterBy: map[string]string{"scope": "user", "owner": "aaa", "environment": "on-prem", "datacenter": "new_york"}, + provisionerDaemonTags: map[string]string{"scope": "user", "owner": "aaa", "environment": "on-prem", "datacenter": "chicago"}, + expectToGetDaemon: false, + }, + } + for _, tt := range testCases { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + dv := coderdtest.DeploymentValues(t) + client, db, _ := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + DeploymentValues: dv, + }, + ProvisionerDaemonPSK: "provisionersftw", + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureExternalProvisionerDaemons: 1, + codersdk.FeatureMultipleOrganizations: 1, + }, + }, + }) + ctx := testutil.Context(t, testutil.WaitShort) + + org := coderdenttest.CreateOrganization(t, client, coderdenttest.CreateOrganizationOptions{ + IncludeProvisionerDaemon: false, + }) + orgAdmin, _ := coderdtest.CreateAnotherUser(t, client, org.ID, rbac.ScopedRoleOrgMember(org.ID)) + + daemonCreatedAt := time.Now() + + //nolint:gocritic // We're not testing auth on the following in this test + provisionerKey, err := db.InsertProvisionerKey(dbauthz.AsSystemRestricted(ctx), database.InsertProvisionerKeyParams{ + Name: "Test Provisioner Key", + ID: uuid.New(), + CreatedAt: daemonCreatedAt, + OrganizationID: org.ID, + HashedSecret: []byte{}, + Tags: tt.provisionerDaemonTags, + }) + require.NoError(t, err, "should be able to create a provisioner key") + + //nolint:gocritic // We're not testing auth on the following in this test + pd, err := db.UpsertProvisionerDaemon(dbauthz.AsSystemRestricted(ctx), database.UpsertProvisionerDaemonParams{ + CreatedAt: daemonCreatedAt, + Name: "Test Provisioner Daemon", + Provisioners: []database.ProvisionerType{}, + Tags: tt.provisionerDaemonTags, + LastSeenAt: sql.NullTime{ + Time: daemonCreatedAt, + Valid: true, + }, + Version: "", + OrganizationID: org.ID, + APIVersion: "", + KeyID: provisionerKey.ID, + }) + require.NoError(t, err, "should be able to create provisioner daemon") + daemonAsCreated := db2sdk.ProvisionerDaemon(pd) + + allDaemons, err := orgAdmin.OrganizationProvisionerDaemons(ctx, org.ID, nil) + require.NoError(t, err) + require.Len(t, allDaemons, 1) + + daemonsAsFound, err := orgAdmin.OrganizationProvisionerDaemons(ctx, org.ID, tt.tagsToFilterBy) + if tt.expectToGetDaemon { + require.NoError(t, err) + require.Len(t, daemonsAsFound, 1) + require.Equal(t, daemonAsCreated.Tags, daemonsAsFound[0].Tags, "found daemon should have the same tags as created daemon") + require.Equal(t, daemonsAsFound[0].KeyID, provisionerKey.ID) + } else { + require.NoError(t, err) + assert.Empty(t, daemonsAsFound, "should not have found daemon") + } + }) + } + }) } diff --git a/enterprise/coderd/provisionerkeys.go b/enterprise/coderd/provisionerkeys.go index 0d153ffef1791..0d715c707b779 100644 --- a/enterprise/coderd/provisionerkeys.go +++ b/enterprise/coderd/provisionerkeys.go @@ -137,7 +137,7 @@ func (api *API) provisionerKeyDaemons(rw http.ResponseWriter, r *http.Request) { } sdkKeys := convertProvisionerKeys(pks) - daemons, err := api.Database.GetProvisionerDaemonsByOrganization(ctx, organization.ID) + daemons, err := api.Database.GetProvisionerDaemonsByOrganization(ctx, database.GetProvisionerDaemonsByOrganizationParams{OrganizationID: organization.ID}) if err != nil { httpapi.InternalServerError(rw, err) return diff --git a/enterprise/coderd/schedule/template_test.go b/enterprise/coderd/schedule/template_test.go index c85c2c6ea1b0e..f2871928c0ac0 100644 --- a/enterprise/coderd/schedule/template_test.go +++ b/enterprise/coderd/schedule/template_test.go @@ -248,8 +248,8 @@ func TestTemplateUpdateBuildDeadlines(t *testing.T) { UUID: uuid.New(), Valid: true, }, - Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, - Tags: json.RawMessage(fmt.Sprintf(`{%q: "yeah"}`, c.name)), + Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, + ProvisionerTags: json.RawMessage(fmt.Sprintf(`{%q: "yeah"}`, c.name)), }) require.NoError(t, err) require.Equal(t, job.ID, acquiredJob.ID) @@ -532,8 +532,8 @@ func TestTemplateUpdateBuildDeadlinesSkip(t *testing.T) { UUID: uuid.New(), Valid: true, }, - Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, - Tags: json.RawMessage(fmt.Sprintf(`{%q: "yeah"}`, wsID)), + Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, + ProvisionerTags: json.RawMessage(fmt.Sprintf(`{%q: "yeah"}`, wsID)), }) require.NoError(t, err) require.Equal(t, job.ID, acquiredJob.ID)