diff --git a/coderd/database/databasefake/databasefake.go b/coderd/database/databasefake/databasefake.go index fcc214a5745b7..228d89c2f4444 100644 --- a/coderd/database/databasefake/databasefake.go +++ b/coderd/database/databasefake/databasefake.go @@ -321,41 +321,47 @@ func (q *fakeQuerier) GetWorkspacesWithFilter(_ context.Context, arg database.Ge workspaces := make([]database.Workspace, 0) for _, workspace := range q.workspaces { - if arg.OrganizationID != uuid.Nil && workspace.OrganizationID != arg.OrganizationID { - continue - } if arg.OwnerID != uuid.Nil && workspace.OwnerID != arg.OwnerID { continue } + if arg.OwnerUsername != "" { + owner, err := q.GetUserByID(context.Background(), workspace.OwnerID) + if err == nil && arg.OwnerUsername != owner.Username { + continue + } + } + if arg.TemplateName != "" { + templates, err := q.GetTemplatesWithFilter(context.Background(), database.GetTemplatesWithFilterParams{ + ExactName: arg.TemplateName, + }) + // Add to later param + if err == nil { + for _, t := range templates { + arg.TemplateIds = append(arg.TemplateIds, t.ID) + } + } + } if !arg.Deleted && workspace.Deleted { continue } if arg.Name != "" && !strings.Contains(workspace.Name, arg.Name) { continue } - workspaces = append(workspaces, workspace) - } - - return workspaces, nil -} - -func (q *fakeQuerier) GetWorkspacesByTemplateID(_ context.Context, arg database.GetWorkspacesByTemplateIDParams) ([]database.Workspace, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - workspaces := make([]database.Workspace, 0) - for _, workspace := range q.workspaces { - if workspace.TemplateID.String() != arg.TemplateID.String() { - continue - } - if workspace.Deleted != arg.Deleted { - continue + if len(arg.TemplateIds) > 0 { + match := false + for _, id := range arg.TemplateIds { + if workspace.TemplateID == id { + match = true + break + } + } + if !match { + continue + } } workspaces = append(workspaces, workspace) } - if len(workspaces) == 0 { - return nil, sql.ErrNoRows - } + return workspaces, nil } @@ -641,25 +647,6 @@ func (q *fakeQuerier) GetWorkspaceBuildByWorkspaceIDAndBuildNumber(_ context.Con return database.WorkspaceBuild{}, sql.ErrNoRows } -func (q *fakeQuerier) GetWorkspacesByOrganizationIDs(_ context.Context, req database.GetWorkspacesByOrganizationIDsParams) ([]database.Workspace, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - workspaces := make([]database.Workspace, 0) - for _, workspace := range q.workspaces { - for _, id := range req.Ids { - if workspace.OrganizationID != id { - continue - } - if workspace.Deleted != req.Deleted { - continue - } - workspaces = append(workspaces, workspace) - } - } - return workspaces, nil -} - func (q *fakeQuerier) GetOrganizations(_ context.Context) ([]database.Organization, error) { q.mutex.RLock() defer q.mutex.RUnlock() @@ -786,6 +773,44 @@ func (q *fakeQuerier) UpdateTemplateMetaByID(_ context.Context, arg database.Upd return sql.ErrNoRows } +func (q *fakeQuerier) GetTemplatesWithFilter(_ context.Context, arg database.GetTemplatesWithFilterParams) ([]database.Template, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + var templates []database.Template + for _, template := range q.templates { + if template.Deleted != arg.Deleted { + continue + } + if arg.OrganizationID != uuid.Nil && template.OrganizationID != arg.OrganizationID { + continue + } + + if arg.ExactName != "" && !strings.EqualFold(template.Name, arg.ExactName) { + continue + } + + if len(arg.Ids) > 0 { + match := false + for _, id := range arg.Ids { + if template.ID == id { + match = true + break + } + } + if !match { + continue + } + } + templates = append(templates, template) + } + if len(templates) > 0 { + return templates, nil + } + + return nil, sql.ErrNoRows +} + func (q *fakeQuerier) GetTemplateVersionsByTemplateID(_ context.Context, arg database.GetTemplateVersionsByTemplateIDParams) (version []database.TemplateVersion, err error) { q.mutex.RLock() defer q.mutex.RUnlock() @@ -923,45 +948,6 @@ func (q *fakeQuerier) GetParameterValueByScopeAndName(_ context.Context, arg dat return database.ParameterValue{}, sql.ErrNoRows } -func (q *fakeQuerier) GetTemplatesByOrganization(_ context.Context, arg database.GetTemplatesByOrganizationParams) ([]database.Template, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - templates := make([]database.Template, 0) - for _, template := range q.templates { - if template.Deleted != arg.Deleted { - continue - } - if template.OrganizationID != arg.OrganizationID { - continue - } - templates = append(templates, template) - } - if len(templates) == 0 { - return nil, sql.ErrNoRows - } - return templates, nil -} - -func (q *fakeQuerier) GetTemplatesByIDs(_ context.Context, ids []uuid.UUID) ([]database.Template, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - templates := make([]database.Template, 0) - for _, template := range q.templates { - for _, id := range ids { - if template.ID.String() != id.String() { - continue - } - templates = append(templates, template) - } - } - if len(templates) == 0 { - return nil, sql.ErrNoRows - } - return templates, nil -} - func (q *fakeQuerier) GetOrganizationMemberByUserID(_ context.Context, arg database.GetOrganizationMemberByUserIDParams) (database.OrganizationMember, error) { q.mutex.RLock() defer q.mutex.RUnlock() diff --git a/coderd/database/databasefake/databasefake_test.go b/coderd/database/databasefake/databasefake_test.go new file mode 100644 index 0000000000000..6a3416b59c7ac --- /dev/null +++ b/coderd/database/databasefake/databasefake_test.go @@ -0,0 +1,61 @@ +package databasefake_test + +import ( + "fmt" + "reflect" + "testing" + + "github.com/coder/coder/coderd/database" + + "github.com/coder/coder/coderd/database/databasefake" +) + +// TestExactMethods will ensure the fake database does not hold onto excessive +// functions. The fake database is a manual implementation, so it is possible +// we forget to delete functions that we remove. This unit test just ensures +// we remove the extra methods. +func TestExactMethods(t *testing.T) { + t.Parallel() + + // extraFakeMethods contains the extra allowed methods that are not a part + // of the database.Store interface. + extraFakeMethods := map[string]string{ + // Example + // "SortFakeLists": "Helper function used", + } + + fake := reflect.TypeOf(databasefake.New()) + fakeMethods := methods(fake) + + store := reflect.TypeOf((*database.Store)(nil)).Elem() + storeMethods := methods(store) + + // Store should be a subset + for k := range storeMethods { + _, ok := fakeMethods[k] + if !ok { + panic(fmt.Sprintf("This should never happen. FakeDB missing method %s, so doesn't fit the interface", k)) + } + delete(storeMethods, k) + delete(fakeMethods, k) + } + + for k := range fakeMethods { + _, ok := extraFakeMethods[k] + if ok { + continue + } + // If you are seeing this error, you have an extra function not required + // for the database.Store. If you still want to keep it, add it to + // 'extraFakeMethods' to allow it. + t.Errorf("Fake method '%s()' is excessive and not needed to fit interface, delete it", k) + } +} + +func methods(rt reflect.Type) map[string]bool { + methods := make(map[string]bool) + for i := 0; i < rt.NumMethod(); i++ { + methods[rt.Method(i).Name] = true + } + return methods +} diff --git a/coderd/database/querier.go b/coderd/database/querier.go index c3f57d3a9f795..ffac6902a13e1 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -53,8 +53,7 @@ type querier interface { GetTemplateVersionByJobID(ctx context.Context, jobID uuid.UUID) (TemplateVersion, error) GetTemplateVersionByTemplateIDAndName(ctx context.Context, arg GetTemplateVersionByTemplateIDAndNameParams) (TemplateVersion, error) GetTemplateVersionsByTemplateID(ctx context.Context, arg GetTemplateVersionsByTemplateIDParams) ([]TemplateVersion, error) - GetTemplatesByIDs(ctx context.Context, ids []uuid.UUID) ([]Template, error) - GetTemplatesByOrganization(ctx context.Context, arg GetTemplatesByOrganizationParams) ([]Template, error) + GetTemplatesWithFilter(ctx context.Context, arg GetTemplatesWithFilterParams) ([]Template, error) GetUserByEmailOrUsername(ctx context.Context, arg GetUserByEmailOrUsernameParams) (User, error) GetUserByID(ctx context.Context, id uuid.UUID) (User, error) GetUserCount(ctx context.Context) (int64, error) @@ -78,8 +77,6 @@ type querier interface { GetWorkspaceResourceByID(ctx context.Context, id uuid.UUID) (WorkspaceResource, error) GetWorkspaceResourcesByJobID(ctx context.Context, jobID uuid.UUID) ([]WorkspaceResource, error) GetWorkspacesAutostart(ctx context.Context) ([]Workspace, error) - GetWorkspacesByOrganizationIDs(ctx context.Context, arg GetWorkspacesByOrganizationIDsParams) ([]Workspace, error) - GetWorkspacesByTemplateID(ctx context.Context, arg GetWorkspacesByTemplateIDParams) ([]Workspace, error) GetWorkspacesWithFilter(ctx context.Context, arg GetWorkspacesWithFilterParams) ([]Workspace, error) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (APIKey, error) InsertAuditLog(ctx context.Context, arg InsertAuditLogParams) (AuditLog, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index edb2fc774589c..519abbf6f25b5 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -1671,68 +1671,48 @@ func (q *sqlQuerier) GetTemplateByOrganizationAndName(ctx context.Context, arg G return i, err } -const getTemplatesByIDs = `-- name: GetTemplatesByIDs :many +const getTemplatesWithFilter = `-- name: GetTemplatesWithFilter :many SELECT id, created_at, updated_at, organization_id, deleted, name, provisioner, active_version_id, description, max_ttl, min_autostart_interval, created_by FROM templates WHERE - id = ANY($1 :: uuid [ ]) -` - -func (q *sqlQuerier) GetTemplatesByIDs(ctx context.Context, ids []uuid.UUID) ([]Template, error) { - rows, err := q.db.QueryContext(ctx, getTemplatesByIDs, pq.Array(ids)) - if err != nil { - return nil, err - } - defer rows.Close() - var items []Template - for rows.Next() { - var i Template - if err := rows.Scan( - &i.ID, - &i.CreatedAt, - &i.UpdatedAt, - &i.OrganizationID, - &i.Deleted, - &i.Name, - &i.Provisioner, - &i.ActiveVersionID, - &i.Description, - &i.MaxTtl, - &i.MinAutostartInterval, - &i.CreatedBy, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const getTemplatesByOrganization = `-- name: GetTemplatesByOrganization :many -SELECT - id, created_at, updated_at, organization_id, deleted, name, provisioner, active_version_id, description, max_ttl, min_autostart_interval, created_by -FROM - templates -WHERE - organization_id = $1 - AND deleted = $2 + -- Optionally include deleted templates + templates.deleted = $1 + -- Filter by organization_id + AND CASE + WHEN $2 :: uuid != '00000000-00000000-00000000-00000000' THEN + organization_id = $2 + ELSE true + END + -- Filter by exact name + AND CASE + WHEN $3 :: text != '' THEN + LOWER("name") = LOWER($3) + ELSE true + END + -- Filter by ids + AND CASE + WHEN array_length($4 :: uuid[], 1) > 0 THEN + id = ANY($4) + ELSE true + END ` -type GetTemplatesByOrganizationParams struct { - OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"` - Deleted bool `db:"deleted" json:"deleted"` +type GetTemplatesWithFilterParams struct { + Deleted bool `db:"deleted" json:"deleted"` + OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"` + ExactName string `db:"exact_name" json:"exact_name"` + Ids []uuid.UUID `db:"ids" json:"ids"` } -func (q *sqlQuerier) GetTemplatesByOrganization(ctx context.Context, arg GetTemplatesByOrganizationParams) ([]Template, error) { - rows, err := q.db.QueryContext(ctx, getTemplatesByOrganization, arg.OrganizationID, arg.Deleted) +func (q *sqlQuerier) GetTemplatesWithFilter(ctx context.Context, arg GetTemplatesWithFilterParams) ([]Template, error) { + rows, err := q.db.QueryContext(ctx, getTemplatesWithFilter, + arg.Deleted, + arg.OrganizationID, + arg.ExactName, + pq.Array(arg.Ids), + ) if err != nil { return nil, err } @@ -3639,98 +3619,6 @@ func (q *sqlQuerier) GetWorkspacesAutostart(ctx context.Context) ([]Workspace, e return items, nil } -const getWorkspacesByOrganizationIDs = `-- name: GetWorkspacesByOrganizationIDs :many -SELECT id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl FROM workspaces WHERE organization_id = ANY($1 :: uuid [ ]) AND deleted = $2 -` - -type GetWorkspacesByOrganizationIDsParams struct { - Ids []uuid.UUID `db:"ids" json:"ids"` - Deleted bool `db:"deleted" json:"deleted"` -} - -func (q *sqlQuerier) GetWorkspacesByOrganizationIDs(ctx context.Context, arg GetWorkspacesByOrganizationIDsParams) ([]Workspace, error) { - rows, err := q.db.QueryContext(ctx, getWorkspacesByOrganizationIDs, pq.Array(arg.Ids), arg.Deleted) - if err != nil { - return nil, err - } - defer rows.Close() - var items []Workspace - for rows.Next() { - var i Workspace - if err := rows.Scan( - &i.ID, - &i.CreatedAt, - &i.UpdatedAt, - &i.OwnerID, - &i.OrganizationID, - &i.TemplateID, - &i.Deleted, - &i.Name, - &i.AutostartSchedule, - &i.Ttl, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const getWorkspacesByTemplateID = `-- name: GetWorkspacesByTemplateID :many -SELECT - id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl -FROM - workspaces -WHERE - template_id = $1 - AND deleted = $2 -` - -type GetWorkspacesByTemplateIDParams struct { - TemplateID uuid.UUID `db:"template_id" json:"template_id"` - Deleted bool `db:"deleted" json:"deleted"` -} - -func (q *sqlQuerier) GetWorkspacesByTemplateID(ctx context.Context, arg GetWorkspacesByTemplateIDParams) ([]Workspace, error) { - rows, err := q.db.QueryContext(ctx, getWorkspacesByTemplateID, arg.TemplateID, arg.Deleted) - if err != nil { - return nil, err - } - defer rows.Close() - var items []Workspace - for rows.Next() { - var i Workspace - if err := rows.Scan( - &i.ID, - &i.CreatedAt, - &i.UpdatedAt, - &i.OwnerID, - &i.OrganizationID, - &i.TemplateID, - &i.Deleted, - &i.Name, - &i.AutostartSchedule, - &i.Ttl, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - const getWorkspacesWithFilter = `-- name: GetWorkspacesWithFilter :many SELECT id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl @@ -3738,39 +3626,57 @@ FROM workspaces WHERE -- Optionally include deleted workspaces - deleted = $1 - -- Filter by organization_id + workspaces.deleted = $1 + -- Filter by owner_id AND CASE WHEN $2 :: uuid != '00000000-00000000-00000000-00000000' THEN - organization_id = $2 + owner_id = $2 ELSE true END - -- Filter by owner_id + -- Filter by owner_name AND CASE - WHEN $3 :: uuid != '00000000-00000000-00000000-00000000' THEN - owner_id = $3 - ELSE true + WHEN $3 :: text != '' THEN + owner_id = (SELECT id FROM users WHERE username = $3) + ELSE true + END + -- Filter by template_name + -- There can be more than 1 template with the same name across organizations. + -- Use the organization filter to restrict to 1 org if needed. + AND CASE + WHEN $4 :: text != '' THEN + template_id = ANY(SELECT id FROM templates WHERE name = $4) + ELSE true + END + -- Filter by template_ids + AND CASE + WHEN array_length($5 :: uuid[], 1) > 0 THEN + template_id = ANY($5) + ELSE true END -- Filter by name, matching on substring AND CASE - WHEN $4 :: text != '' THEN - LOWER(name) LIKE '%' || LOWER($4) || '%' - ELSE true + WHEN $6 :: text != '' THEN + LOWER(name) LIKE '%' || LOWER($6) || '%' + ELSE true END ` type GetWorkspacesWithFilterParams struct { - Deleted bool `db:"deleted" json:"deleted"` - OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"` - OwnerID uuid.UUID `db:"owner_id" json:"owner_id"` - Name string `db:"name" json:"name"` + Deleted bool `db:"deleted" json:"deleted"` + OwnerID uuid.UUID `db:"owner_id" json:"owner_id"` + OwnerUsername string `db:"owner_username" json:"owner_username"` + TemplateName string `db:"template_name" json:"template_name"` + TemplateIds []uuid.UUID `db:"template_ids" json:"template_ids"` + Name string `db:"name" json:"name"` } func (q *sqlQuerier) GetWorkspacesWithFilter(ctx context.Context, arg GetWorkspacesWithFilterParams) ([]Workspace, error) { rows, err := q.db.QueryContext(ctx, getWorkspacesWithFilter, arg.Deleted, - arg.OrganizationID, arg.OwnerID, + arg.OwnerUsername, + arg.TemplateName, + pq.Array(arg.TemplateIds), arg.Name, ) if err != nil { diff --git a/coderd/database/queries/templates.sql b/coderd/database/queries/templates.sql index c3b3753083351..f0c4f802ac388 100644 --- a/coderd/database/queries/templates.sql +++ b/coderd/database/queries/templates.sql @@ -8,13 +8,33 @@ WHERE LIMIT 1; --- name: GetTemplatesByIDs :many +-- name: GetTemplatesWithFilter :many SELECT * FROM templates WHERE - id = ANY(@ids :: uuid [ ]); + -- Optionally include deleted templates + templates.deleted = @deleted + -- Filter by organization_id + AND CASE + WHEN @organization_id :: uuid != '00000000-00000000-00000000-00000000' THEN + organization_id = @organization_id + ELSE true + END + -- Filter by exact name + AND CASE + WHEN @exact_name :: text != '' THEN + LOWER("name") = LOWER(@exact_name) + ELSE true + END + -- Filter by ids + AND CASE + WHEN array_length(@ids :: uuid[], 1) > 0 THEN + id = ANY(@ids) + ELSE true + END +; -- name: GetTemplateByOrganizationAndName :one SELECT @@ -28,15 +48,6 @@ WHERE LIMIT 1; --- name: GetTemplatesByOrganization :many -SELECT - * -FROM - templates -WHERE - organization_id = $1 - AND deleted = $2; - -- name: InsertTemplate :one INSERT INTO templates ( diff --git a/coderd/database/queries/workspaces.sql b/coderd/database/queries/workspaces.sql index 1b9f6a88f6256..dc68186fdc794 100644 --- a/coderd/database/queries/workspaces.sql +++ b/coderd/database/queries/workspaces.sql @@ -15,30 +15,41 @@ FROM workspaces WHERE -- Optionally include deleted workspaces - deleted = @deleted - -- Filter by organization_id + workspaces.deleted = @deleted + -- Filter by owner_id AND CASE - WHEN @organization_id :: uuid != '00000000-00000000-00000000-00000000' THEN - organization_id = @organization_id + WHEN @owner_id :: uuid != '00000000-00000000-00000000-00000000' THEN + owner_id = @owner_id ELSE true END - -- Filter by owner_id + -- Filter by owner_name + AND CASE + WHEN @owner_username :: text != '' THEN + owner_id = (SELECT id FROM users WHERE username = @owner_username) + ELSE true + END + -- Filter by template_name + -- There can be more than 1 template with the same name across organizations. + -- Use the organization filter to restrict to 1 org if needed. AND CASE - WHEN @owner_id :: uuid != '00000000-00000000-00000000-00000000' THEN - owner_id = @owner_id - ELSE true + WHEN @template_name :: text != '' THEN + template_id = ANY(SELECT id FROM templates WHERE name = @template_name) + ELSE true + END + -- Filter by template_ids + AND CASE + WHEN array_length(@template_ids :: uuid[], 1) > 0 THEN + template_id = ANY(@template_ids) + ELSE true END -- Filter by name, matching on substring AND CASE - WHEN @name :: text != '' THEN - LOWER(name) LIKE '%' || LOWER(@name) || '%' - ELSE true + WHEN @name :: text != '' THEN + LOWER(name) LIKE '%' || LOWER(@name) || '%' + ELSE true END ; --- name: GetWorkspacesByOrganizationIDs :many -SELECT * FROM workspaces WHERE organization_id = ANY(@ids :: uuid [ ]) AND deleted = @deleted; - -- name: GetWorkspacesAutostart :many SELECT * @@ -53,15 +64,6 @@ AND (ttl IS NOT NULL AND ttl > 0) ); --- name: GetWorkspacesByTemplateID :many -SELECT - * -FROM - workspaces -WHERE - template_id = $1 - AND deleted = $2; - -- name: GetWorkspaceByOwnerIDAndName :one SELECT * diff --git a/coderd/httpapi/queryparams.go b/coderd/httpapi/queryparams.go new file mode 100644 index 0000000000000..ea30480e16b4e --- /dev/null +++ b/coderd/httpapi/queryparams.go @@ -0,0 +1,105 @@ +package httpapi + +import ( + "fmt" + "net/url" + "strconv" + "strings" + + "github.com/google/uuid" + + "golang.org/x/xerrors" +) + +// QueryParamParser is a helper for parsing all query params and gathering all +// errors in 1 sweep. This means all invalid fields are returned at once, +// rather than only returning the first error +type QueryParamParser struct { + // Errors is the set of errors to return via the API. If the length + // of this set is 0, there are no errors!. + Errors []Error +} + +func NewQueryParamParser() *QueryParamParser { + return &QueryParamParser{ + Errors: []Error{}, + } +} + +func (p *QueryParamParser) Int(vals url.Values, def int, queryParam string) int { + v, err := parseQueryParam(vals, strconv.Atoi, def, queryParam) + if err != nil { + p.Errors = append(p.Errors, Error{ + Field: queryParam, + Detail: fmt.Sprintf("Query param %q must be a valid integer (%s)", queryParam, err.Error()), + }) + } + return v +} + +func (p *QueryParamParser) UUIDorMe(vals url.Values, def uuid.UUID, me uuid.UUID, queryParam string) uuid.UUID { + if vals.Get(queryParam) == "me" { + return me + } + return p.UUID(vals, def, queryParam) +} + +func (p *QueryParamParser) UUID(vals url.Values, def uuid.UUID, queryParam string) uuid.UUID { + v, err := parseQueryParam(vals, uuid.Parse, def, queryParam) + if err != nil { + p.Errors = append(p.Errors, Error{ + Field: queryParam, + Detail: fmt.Sprintf("Query param %q must be a valid uuid", queryParam), + }) + } + return v +} + +func (p *QueryParamParser) UUIDs(vals url.Values, def []uuid.UUID, queryParam string) []uuid.UUID { + v, err := parseQueryParam(vals, func(v string) ([]uuid.UUID, error) { + var badValues []string + strs := strings.Split(v, ",") + ids := make([]uuid.UUID, 0, len(strs)) + for _, s := range strs { + id, err := uuid.Parse(strings.TrimSpace(s)) + if err != nil { + badValues = append(badValues, v) + continue + } + ids = append(ids, id) + } + + if len(badValues) > 0 { + return []uuid.UUID{}, xerrors.Errorf("%s", strings.Join(badValues, ",")) + } + return ids, nil + }, def, queryParam) + if err != nil { + p.Errors = append(p.Errors, Error{ + Field: queryParam, + Detail: fmt.Sprintf("Query param %q has invalid uuids: %q", queryParam, err.Error()), + }) + } + return v +} + +func (p *QueryParamParser) String(vals url.Values, def string, queryParam string) string { + v, err := parseQueryParam(vals, func(v string) (string, error) { + return v, nil + }, def, queryParam) + if err != nil { + p.Errors = append(p.Errors, Error{ + Field: queryParam, + Detail: fmt.Sprintf("Query param %q must be a valid string", queryParam), + }) + } + return v +} + +func parseQueryParam[T any](vals url.Values, parse func(v string) (T, error), def T, queryParam string) (T, error) { + if !vals.Has(queryParam) || vals.Get(queryParam) == "" { + return def, nil + } + str := vals.Get(queryParam) + return parse(str) +} diff --git a/coderd/httpapi/queryparams_test.go b/coderd/httpapi/queryparams_test.go new file mode 100644 index 0000000000000..f4ff580b4dc22 --- /dev/null +++ b/coderd/httpapi/queryparams_test.go @@ -0,0 +1,201 @@ +package httpapi_test + +import ( + "fmt" + "net/http" + "net/url" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/coderd/httpapi" +) + +type queryParamTestCase[T any] struct { + QueryParam string + // No set does not set the query param, rather than setting the empty value + NoSet bool + Value string + Default T + Expected T + ExpectedErrorContains string + Parse func(r *http.Request, def T, queryParam string) T +} + +func TestParseQueryParams(t *testing.T) { + t.Parallel() + + t.Run("UUID", func(t *testing.T) { + t.Parallel() + me := uuid.New() + expParams := []queryParamTestCase[uuid.UUID]{ + { + QueryParam: "valid_id", + Value: "afe39fbf-0f52-4a62-b0cc-58670145d773", + Expected: uuid.MustParse("afe39fbf-0f52-4a62-b0cc-58670145d773"), + }, + { + QueryParam: "me", + Value: "me", + Expected: me, + }, + { + QueryParam: "invalid_id", + Value: "bogus", + ExpectedErrorContains: "must be a valid uuid", + }, + { + QueryParam: "long_id", + Value: "afe39fbf-0f52-4a62-b0cc-58670145d773-123", + ExpectedErrorContains: "must be a valid uuid", + }, + { + QueryParam: "no_value", + NoSet: true, + Default: uuid.MustParse("11111111-1111-1111-1111-111111111111"), + Expected: uuid.MustParse("11111111-1111-1111-1111-111111111111"), + }, + { + QueryParam: "empty", + Value: "", + Expected: uuid.Nil, + }, + } + + parser := httpapi.NewQueryParamParser() + testQueryParams(t, expParams, parser, func(vals url.Values, def uuid.UUID, queryParam string) uuid.UUID { + return parser.UUIDorMe(vals, def, me, queryParam) + }) + }) + + t.Run("String", func(t *testing.T) { + t.Parallel() + expParams := []queryParamTestCase[string]{ + { + QueryParam: "valid_string", + Value: "random", + Expected: "random", + }, + { + QueryParam: "empty", + Value: "", + Expected: "", + }, + { + QueryParam: "no_value", + NoSet: true, + Default: "default", + Expected: "default", + }, + } + + parser := httpapi.NewQueryParamParser() + testQueryParams(t, expParams, parser, parser.String) + }) + + t.Run("Int", func(t *testing.T) { + t.Parallel() + expParams := []queryParamTestCase[int]{ + { + QueryParam: "valid_integer", + Value: "100", + Expected: 100, + }, + { + QueryParam: "empty", + Value: "", + Expected: 0, + }, + { + QueryParam: "no_value", + NoSet: true, + Default: 5, + Expected: 5, + }, + { + QueryParam: "negative", + Value: "-10", + Expected: -10, + Default: 5, + }, + { + QueryParam: "invalid_integer", + Value: "bogus", + Expected: 0, + ExpectedErrorContains: "must be a valid integer", + }, + } + + parser := httpapi.NewQueryParamParser() + testQueryParams(t, expParams, parser, parser.Int) + }) + + t.Run("UUIDs", func(t *testing.T) { + t.Parallel() + expParams := []queryParamTestCase[[]uuid.UUID]{ + { + QueryParam: "valid_ids_with_spaces", + Value: "6c8ef17d-5dd8-4b92-bac9-41944f90f237, 65fb05f3-12c8-4a0a-801f-40439cf9e681 , 01b94888-1eab-4bbf-aed0-dc7a8010da97", + Expected: []uuid.UUID{ + uuid.MustParse("6c8ef17d-5dd8-4b92-bac9-41944f90f237"), + uuid.MustParse("65fb05f3-12c8-4a0a-801f-40439cf9e681"), + uuid.MustParse("01b94888-1eab-4bbf-aed0-dc7a8010da97"), + }, + }, + { + QueryParam: "empty", + Value: "", + Default: []uuid.UUID{}, + Expected: []uuid.UUID{}, + }, + { + QueryParam: "no_value", + NoSet: true, + Default: []uuid.UUID{}, + Expected: []uuid.UUID{}, + }, + { + QueryParam: "default", + NoSet: true, + Default: []uuid.UUID{uuid.Nil}, + Expected: []uuid.UUID{uuid.Nil}, + }, + { + QueryParam: "invalid_id_in_set", + Value: "6c8ef17d-5dd8-4b92-bac9-41944f90f237,bogus", + Expected: []uuid.UUID{}, + Default: []uuid.UUID{}, + ExpectedErrorContains: "bogus", + }, + } + + parser := httpapi.NewQueryParamParser() + testQueryParams(t, expParams, parser, parser.UUIDs) + }) +} + +func testQueryParams[T any](t *testing.T, testCases []queryParamTestCase[T], parser *httpapi.QueryParamParser, parse func(vals url.Values, def T, queryParam string) T) { + v := url.Values{} + for _, c := range testCases { + if c.NoSet { + continue + } + v.Set(c.QueryParam, c.Value) + } + + for _, c := range testCases { + // !! Do not run these in parallel !! + t.Run(c.QueryParam, func(t *testing.T) { + v := parse(v, c.Default, c.QueryParam) + require.Equal(t, c.Expected, v, fmt.Sprintf("param=%q value=%q", c.QueryParam, c.Value)) + if c.ExpectedErrorContains != "" { + errors := parser.Errors + require.True(t, len(errors) > 0, "error exist") + last := errors[len(errors)-1] + require.True(t, last.Field == c.QueryParam, fmt.Sprintf("query param %q did not fail", c.QueryParam)) + require.Contains(t, last.Detail, c.ExpectedErrorContains, "correct error") + } + }) + } +} diff --git a/coderd/pagination.go b/coderd/pagination.go index 1dc1a28886221..07f7b0fe743db 100644 --- a/coderd/pagination.go +++ b/coderd/pagination.go @@ -2,7 +2,6 @@ package coderd import ( "net/http" - "strconv" "github.com/google/uuid" @@ -13,53 +12,21 @@ import ( // parsePagination extracts pagination query params from the http request. // If an error is encountered, the error is written to w and ok is set to false. func parsePagination(w http.ResponseWriter, r *http.Request) (p codersdk.Pagination, ok bool) { - var ( - afterID = uuid.Nil - limit = -1 // Default to no limit and return all results. - offset = 0 - ) - - var err error - if s := r.URL.Query().Get("after_id"); s != "" { - afterID, err = uuid.Parse(r.URL.Query().Get("after_id")) - if err != nil { - httpapi.Write(w, http.StatusBadRequest, httpapi.Response{ - Message: "Query param 'after_id' must be a valid UUID.", - Validations: []httpapi.Error{ - {Field: "after_id", Detail: err.Error()}, - }, - }) - return p, false - } - } - if s := r.URL.Query().Get("limit"); s != "" { - limit, err = strconv.Atoi(s) - if err != nil { - httpapi.Write(w, http.StatusBadRequest, httpapi.Response{ - Message: "Query param 'limit' must be a valid integer.", - Validations: []httpapi.Error{ - {Field: "limit", Detail: err.Error()}, - }, - }) - return p, false - } + queryParams := r.URL.Query() + parser := httpapi.NewQueryParamParser() + params := codersdk.Pagination{ + AfterID: parser.UUID(queryParams, uuid.Nil, "after_id"), + // Limit default to "-1" which returns all results + Limit: parser.Int(queryParams, -1, "limit"), + Offset: parser.Int(queryParams, 0, "offset"), } - if s := r.URL.Query().Get("offset"); s != "" { - offset, err = strconv.Atoi(s) - if err != nil { - httpapi.Write(w, http.StatusBadRequest, httpapi.Response{ - Message: "Query param 'offset' must be a valid integer.", - Validations: []httpapi.Error{ - {Field: "offset", Detail: err.Error()}, - }, - }) - return p, false - } + if len(parser.Errors) > 0 { + httpapi.Write(w, http.StatusBadRequest, httpapi.Response{ + Message: "Query parameters have invalid values.", + Validations: parser.Errors, + }) + return params, false } - return codersdk.Pagination{ - AfterID: afterID, - Limit: limit, - Offset: offset, - }, true + return params, true } diff --git a/coderd/pagination_internal_test.go b/coderd/pagination_internal_test.go index 97eba7275ea6f..5504ef267e165 100644 --- a/coderd/pagination_internal_test.go +++ b/coderd/pagination_internal_test.go @@ -14,6 +14,7 @@ import ( func TestPagination(t *testing.T) { t.Parallel() + const invalidValues = "Query parameters have invalid values" testCases := []struct { Name string @@ -27,27 +28,27 @@ func TestPagination(t *testing.T) { { Name: "BadAfterID", AfterID: "bogus", - ExpectedError: "Query param 'after_id' must be a valid UUID", + ExpectedError: invalidValues, }, { Name: "ShortAfterID", AfterID: "ff22a7b-bb6f-43d8-83e1-eefe0a1f5197", - ExpectedError: "Query param 'after_id' must be a valid UUID", + ExpectedError: invalidValues, }, { Name: "LongAfterID", AfterID: "cff22a7b-bb6f-43d8-83e1-eefe0a1f51972", - ExpectedError: "Query param 'after_id' must be a valid UUID", + ExpectedError: invalidValues, }, { Name: "BadLimit", Limit: "bogus", - ExpectedError: "Query param 'limit' must be a valid integer", + ExpectedError: invalidValues, }, { Name: "BadOffset", Offset: "bogus", - ExpectedError: "Query param 'offset' must be a valid integer", + ExpectedError: invalidValues, }, // Valid values diff --git a/coderd/templates.go b/coderd/templates.go index d79bd19f70fa2..6f0fa929d9923 100644 --- a/coderd/templates.go +++ b/coderd/templates.go @@ -68,8 +68,8 @@ func (api *API) deleteTemplate(rw http.ResponseWriter, r *http.Request) { return } - workspaces, err := api.Database.GetWorkspacesByTemplateID(r.Context(), database.GetWorkspacesByTemplateIDParams{ - TemplateID: template.ID, + workspaces, err := api.Database.GetWorkspacesWithFilter(r.Context(), database.GetWorkspacesWithFilterParams{ + TemplateIds: []uuid.UUID{template.ID}, }) if errors.Is(err, sql.ErrNoRows) { err = nil @@ -244,7 +244,7 @@ func (api *API) postTemplateByOrganization(rw http.ResponseWriter, r *http.Reque func (api *API) templatesByOrganization(rw http.ResponseWriter, r *http.Request) { organization := httpmw.OrganizationParam(r) - templates, err := api.Database.GetTemplatesByOrganization(r.Context(), database.GetTemplatesByOrganizationParams{ + templates, err := api.Database.GetTemplatesWithFilter(r.Context(), database.GetTemplatesWithFilterParams{ OrganizationID: organization.ID, }) if errors.Is(err, sql.ErrNoRows) { diff --git a/coderd/workspaces.go b/coderd/workspaces.go index b9e09f50513b4..920adb85295da 100644 --- a/coderd/workspaces.go +++ b/coderd/workspaces.go @@ -7,7 +7,9 @@ import ( "errors" "fmt" "net/http" + "net/url" "strconv" + "strings" "time" "github.com/go-chi/chi/v5" @@ -103,38 +105,19 @@ func (api *API) workspace(rw http.ResponseWriter, r *http.Request) { func (api *API) workspaces(rw http.ResponseWriter, r *http.Request) { apiKey := httpmw.APIKey(r) - // Empty strings mean no filter - orgFilter := r.URL.Query().Get("organization_id") - ownerFilter := r.URL.Query().Get("owner") - nameFilter := r.URL.Query().Get("name") - - filter := database.GetWorkspacesWithFilterParams{Deleted: false} - if orgFilter != "" { - orgID, err := uuid.Parse(orgFilter) - if err == nil { - filter.OrganizationID = orgID - } + queryStr := r.URL.Query().Get("q") + filter, errs := workspaceSearchQuery(queryStr) + if len(errs) > 0 { + httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ + Message: "Invalid workspace search query.", + Validations: errs, + }) + return } - if ownerFilter == "me" { + + if filter.OwnerUsername == "me" { filter.OwnerID = apiKey.UserID - } else if ownerFilter != "" { - userID, err := uuid.Parse(ownerFilter) - if err != nil { - // Maybe it's a username - user, err := api.Database.GetUserByEmailOrUsername(r.Context(), database.GetUserByEmailOrUsernameParams{ - // Why not just accept 1 arg and use it for both in the sql? - Username: ownerFilter, - Email: ownerFilter, - }) - if err == nil { - filter.OwnerID = user.ID - } - } else { - filter.OwnerID = userID - } - } - if nameFilter != "" { - filter.Name = nameFilter + filter.OwnerUsername = "" } workspaces, err := api.Database.GetWorkspacesWithFilter(r.Context(), filter) @@ -276,26 +259,13 @@ func (api *API) postWorkspacesByOrganization(rw http.ResponseWriter, r *http.Req return } - if organization.ID != template.OrganizationID { - httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{ - Message: fmt.Sprintf("Template is not in organization %q.", organization.Name), - }) + if !api.Authorize(rw, r, rbac.ActionRead, template) { return } - _, err = api.Database.GetOrganizationMemberByUserID(r.Context(), database.GetOrganizationMemberByUserIDParams{ - OrganizationID: template.OrganizationID, - UserID: apiKey.UserID, - }) - if errors.Is(err, sql.ErrNoRows) { + + if organization.ID != template.OrganizationID { httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{ - Message: "You aren't allowed to access templates in that organization.", - }) - return - } - if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ - Message: "Internal error fetching organization member.", - Detail: err.Error(), + Message: fmt.Sprintf("Template is not in organization %q.", organization.Name), }) return } @@ -791,7 +761,9 @@ func convertWorkspaces(ctx context.Context, db database.Store, workspaces []data if err != nil { return nil, xerrors.Errorf("get workspace builds: %w", err) } - templates, err := db.GetTemplatesByIDs(ctx, templateIDs) + templates, err := db.GetTemplatesWithFilter(ctx, database.GetTemplatesWithFilterParams{ + Ids: templateIDs, + }) if errors.Is(err, sql.ErrNoRows) { err = nil } @@ -974,3 +946,81 @@ func validWorkspaceSchedule(s *string, min time.Duration) (sql.NullString, error String: *s, }, nil } + +// workspaceSearchQuery takes a query string and returns the workspace filter. +// It also can return the list of validation errors to return to the api. +func workspaceSearchQuery(query string) (database.GetWorkspacesWithFilterParams, []httpapi.Error) { + searchParams := make(url.Values) + if query == "" { + // No filter + return database.GetWorkspacesWithFilterParams{}, nil + } + // Because we do this in 2 passes, we want to maintain quotes on the first + // pass.Further splitting occurs on the second pass and quotes will be + // dropped. + elements := splitQueryParameterByDelimiter(query, ' ', true) + for _, element := range elements { + parts := splitQueryParameterByDelimiter(element, ':', false) + switch len(parts) { + case 1: + // No key:value pair. It is a workspace name, and maybe includes an owner + parts = splitQueryParameterByDelimiter(element, '/', false) + switch len(parts) { + case 1: + searchParams.Set("name", parts[0]) + case 2: + searchParams.Set("owner", parts[0]) + searchParams.Set("name", parts[1]) + default: + return database.GetWorkspacesWithFilterParams{}, []httpapi.Error{ + {Field: "q", Detail: fmt.Sprintf("Query element %q can only contain 1 '/'", element)}, + } + } + case 2: + searchParams.Set(parts[0], parts[1]) + default: + return database.GetWorkspacesWithFilterParams{}, []httpapi.Error{ + {Field: "q", Detail: fmt.Sprintf("Query element %q can only contain 1 ':'", element)}, + } + } + } + + // Using the query param parser here just returns consistent errors with + // other parsing. + parser := httpapi.NewQueryParamParser() + filter := database.GetWorkspacesWithFilterParams{ + Deleted: false, + OwnerUsername: parser.String(searchParams, "", "owner"), + TemplateName: parser.String(searchParams, "", "template"), + Name: parser.String(searchParams, "", "name"), + } + + return filter, parser.Errors +} + +// splitQueryParameterByDelimiter takes a query string and splits it into the individual elements +// of the query. Each element is separated by a delimiter. All quoted strings are +// kept as a single element. +// +// Although all our names cannot have spaces, that is a validation error. +// We should still parse the quoted string as a single value so that validation +// can properly fail on the space. If we do not, a value of `template:"my name"` +// will search `template:"my name:name"`, which produces an empty list instead of +// an error. +// nolint:revive +func splitQueryParameterByDelimiter(query string, delimiter rune, maintainQuotes bool) []string { + quoted := false + parts := strings.FieldsFunc(query, func(r rune) bool { + if r == '"' { + quoted = !quoted + } + return !quoted && r == delimiter + }) + if !maintainQuotes { + for i, part := range parts { + parts[i] = strings.Trim(part, "\"") + } + } + + return parts +} diff --git a/coderd/workspaces_internal_test.go b/coderd/workspaces_internal_test.go new file mode 100644 index 0000000000000..dc783b417ebfc --- /dev/null +++ b/coderd/workspaces_internal_test.go @@ -0,0 +1,144 @@ +package coderd + +import ( + "fmt" + "strings" + "testing" + + "github.com/coder/coder/coderd/database" + + "github.com/stretchr/testify/require" +) + +func TestSearchWorkspace(t *testing.T) { + t.Parallel() + testCases := []struct { + Name string + Query string + Expected database.GetWorkspacesWithFilterParams + ExpectedErrorContains string + }{ + { + Name: "Empty", + Query: "", + Expected: database.GetWorkspacesWithFilterParams{}, + }, + { + Name: "Owner/Name", + Query: "Foo/Bar", + Expected: database.GetWorkspacesWithFilterParams{ + OwnerUsername: "Foo", + Name: "Bar", + }, + }, + { + Name: "Name", + Query: "workspace-name", + Expected: database.GetWorkspacesWithFilterParams{ + Name: "workspace-name", + }, + }, + { + Name: "Name+Param", + Query: "workspace-name template:docker", + Expected: database.GetWorkspacesWithFilterParams{ + Name: "workspace-name", + TemplateName: "docker", + }, + }, + { + Name: "OnlyParams", + Query: "name:workspace-name template:docker owner:alice", + Expected: database.GetWorkspacesWithFilterParams{ + Name: "workspace-name", + TemplateName: "docker", + OwnerUsername: "alice", + }, + }, + { + Name: "QuotedParam", + Query: `name:workspace-name template:"docker template" owner:alice`, + Expected: database.GetWorkspacesWithFilterParams{ + Name: "workspace-name", + TemplateName: "docker template", + OwnerUsername: "alice", + }, + }, + { + Name: "QuotedKey", + Query: `"name":baz "template":foo "owner":bar`, + Expected: database.GetWorkspacesWithFilterParams{ + Name: "baz", + TemplateName: "foo", + OwnerUsername: "bar", + }, + }, + { + // This will not return an error + Name: "ExtraKeys", + Query: `foo:bar`, + Expected: database.GetWorkspacesWithFilterParams{}, + }, + { + // Quotes keep elements together + Name: "QuotedSpecial", + Query: `name:"workspace:name"`, + Expected: database.GetWorkspacesWithFilterParams{ + Name: "workspace:name", + }, + }, + { + Name: "QuotedMadness", + Query: `"name":"foo:bar:baz/baz/zoo:zonk"`, + Expected: database.GetWorkspacesWithFilterParams{ + Name: "foo:bar:baz/baz/zoo:zonk", + }, + }, + { + Name: "QuotedName", + Query: `"foo/bar"`, + Expected: database.GetWorkspacesWithFilterParams{ + Name: "foo/bar", + }, + }, + { + Name: "QuotedOwner/Name", + Query: `"foo"/"bar"`, + Expected: database.GetWorkspacesWithFilterParams{ + Name: "bar", + OwnerUsername: "foo", + }, + }, + + // Failures + { + Name: "ExtraSlashes", + Query: `foo/bar/baz`, + ExpectedErrorContains: "can only contain 1 '/'", + }, + { + Name: "ExtraColon", + Query: `owner:name:extra`, + ExpectedErrorContains: "can only contain 1 ':'", + }, + } + + for _, c := range testCases { + c := c + t.Run(c.Name, func(t *testing.T) { + t.Parallel() + values, errs := workspaceSearchQuery(c.Query) + if c.ExpectedErrorContains != "" { + require.True(t, len(errs) > 0, "expect some errors") + var s strings.Builder + for _, err := range errs { + _, _ = s.WriteString(fmt.Sprintf("%s: %s\n", err.Field, err.Detail)) + } + require.Contains(t, s.String(), c.ExpectedErrorContains) + } else { + require.Len(t, errs, 0, "expected no error") + require.Equal(t, c.Expected, values, "expected values") + } + }) + } +} diff --git a/coderd/workspaces_test.go b/coderd/workspaces_test.go index cda6688f4d7af..9e207362fe18d 100644 --- a/coderd/workspaces_test.go +++ b/coderd/workspaces_test.go @@ -4,18 +4,19 @@ import ( "context" "fmt" "net/http" + "strings" "testing" "time" - "github.com/coder/coder/coderd/rbac" - "github.com/coder/coder/coderd/util/ptr" - "github.com/google/uuid" "github.com/stretchr/testify/require" "github.com/coder/coder/coderd/autobuild/schedule" "github.com/coder/coder/coderd/coderdtest" + "github.com/coder/coder/coderd/rbac" + "github.com/coder/coder/coderd/util/ptr" "github.com/coder/coder/codersdk" + "github.com/coder/coder/cryptorand" "github.com/coder/coder/provisioner/echo" "github.com/coder/coder/provisionersdk/proto" ) @@ -336,8 +337,164 @@ func TestWorkspaceByOwnerAndName(t *testing.T) { }) } +// TestWorkspaceFilter creates a set of workspaces, users, and organizations +// to run various filters against for testing. func TestWorkspaceFilter(t *testing.T) { t.Parallel() + type coderUser struct { + *codersdk.Client + User codersdk.User + Org codersdk.Organization + } + + client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerD: true}) + first := coderdtest.CreateFirstUser(t, client) + + users := make([]coderUser, 0) + for i := 0; i < 10; i++ { + userClient := coderdtest.CreateAnotherUser(t, client, first.OrganizationID, rbac.RoleAdmin()) + user, err := userClient.User(context.Background(), codersdk.Me) + require.NoError(t, err, "fetch me") + + org, err := userClient.CreateOrganization(context.Background(), codersdk.CreateOrganizationRequest{ + Name: user.Username + "-org", + }) + require.NoError(t, err, "create org") + + users = append(users, coderUser{ + Client: userClient, + User: user, + Org: org, + }) + } + + type madeWorkspace struct { + Owner codersdk.User + Workspace codersdk.Workspace + Template codersdk.Template + } + + availTemplates := make([]codersdk.Template, 0) + allWorkspaces := make([]madeWorkspace, 0) + + // Create some random workspaces + for _, user := range users { + version := coderdtest.CreateTemplateVersion(t, client, user.Org.ID, nil) + + // Create a template & workspace in the user's org + coderdtest.AwaitTemplateVersionJob(t, client, version.ID) + template := coderdtest.CreateTemplate(t, client, user.Org.ID, version.ID) + availTemplates = append(availTemplates, template) + workspace := coderdtest.CreateWorkspace(t, user.Client, template.OrganizationID, template.ID) + allWorkspaces = append(allWorkspaces, madeWorkspace{ + Workspace: workspace, + Template: template, + Owner: user.User, + }) + + // Make a workspace with a random template + idx, _ := cryptorand.Intn(len(availTemplates)) + randTemplate := availTemplates[idx] + randWorkspace := coderdtest.CreateWorkspace(t, user.Client, randTemplate.OrganizationID, randTemplate.ID) + allWorkspaces = append(allWorkspaces, madeWorkspace{ + Workspace: randWorkspace, + Template: randTemplate, + Owner: user.User, + }) + } + + // Make sure all workspaces are done. Do it after all are made + for i, w := range allWorkspaces { + latest := coderdtest.AwaitWorkspaceBuildJob(t, client, w.Workspace.LatestBuild.ID) + allWorkspaces[i].Workspace.LatestBuild = latest + } + + // --- Setup done --- + testCases := []struct { + Name string + Filter codersdk.WorkspaceFilter + // If FilterF is true, we include it in the expected results + FilterF func(f codersdk.WorkspaceFilter, workspace madeWorkspace) bool + }{ + { + Name: "All", + Filter: codersdk.WorkspaceFilter{}, + FilterF: func(_ codersdk.WorkspaceFilter, _ madeWorkspace) bool { + return true + }, + }, + { + Name: "Owner", + Filter: codersdk.WorkspaceFilter{ + Owner: users[2].User.Username, + }, + FilterF: func(f codersdk.WorkspaceFilter, workspace madeWorkspace) bool { + return workspace.Owner.Username == f.Owner + }, + }, + { + Name: "TemplateName", + Filter: codersdk.WorkspaceFilter{ + Template: allWorkspaces[5].Template.Name, + }, + FilterF: func(f codersdk.WorkspaceFilter, workspace madeWorkspace) bool { + return workspace.Template.Name == f.Template + }, + }, + { + Name: "Name", + Filter: codersdk.WorkspaceFilter{ + // Use a common letter... one has to have this letter in it + Name: "a", + }, + FilterF: func(f codersdk.WorkspaceFilter, workspace madeWorkspace) bool { + return strings.Contains(workspace.Workspace.Name, f.Name) + }, + }, + { + Name: "Q-Owner/Name", + Filter: codersdk.WorkspaceFilter{ + FilterQuery: allWorkspaces[5].Owner.Username + "/" + allWorkspaces[5].Workspace.Name, + }, + FilterF: func(_ codersdk.WorkspaceFilter, workspace madeWorkspace) bool { + return workspace.Workspace.ID == allWorkspaces[5].Workspace.ID + }, + }, + { + Name: "Many filters", + Filter: codersdk.WorkspaceFilter{ + Owner: allWorkspaces[3].Owner.Username, + Template: allWorkspaces[3].Template.Name, + Name: allWorkspaces[3].Workspace.Name, + }, + FilterF: func(f codersdk.WorkspaceFilter, workspace madeWorkspace) bool { + return workspace.Workspace.ID == allWorkspaces[3].Workspace.ID + }, + }, + } + + for _, c := range testCases { + c := c + t.Run(c.Name, func(t *testing.T) { + t.Parallel() + workspaces, err := client.Workspaces(context.Background(), c.Filter) + require.NoError(t, err, "fetch workspaces") + + exp := make([]codersdk.Workspace, 0) + for _, made := range allWorkspaces { + if c.FilterF(c.Filter, made) { + exp = append(exp, made.Workspace) + } + } + require.ElementsMatch(t, exp, workspaces, "expected workspaces returned") + }) + } +} + +// TestWorkspaceFilterManual runs some specific setups with basic checks. +func TestWorkspaceFilterManual(t *testing.T) { + t.Parallel() + t.Run("Name", func(t *testing.T) { t.Parallel() client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerD: true}) @@ -370,6 +527,49 @@ func TestWorkspaceFilter(t *testing.T) { require.NoError(t, err) require.Len(t, ws, 0) }) + t.Run("Template", func(t *testing.T) { + t.Parallel() + client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerD: true}) + user := coderdtest.CreateFirstUser(t, client) + version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil) + coderdtest.AwaitTemplateVersionJob(t, client, version.ID) + template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) + template2 := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) + workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) + _ = coderdtest.CreateWorkspace(t, client, user.OrganizationID, template2.ID) + + // empty + ws, err := client.Workspaces(context.Background(), codersdk.WorkspaceFilter{}) + require.NoError(t, err) + require.Len(t, ws, 2) + + // single template + ws, err = client.Workspaces(context.Background(), codersdk.WorkspaceFilter{ + Template: template.Name, + }) + require.NoError(t, err) + require.Len(t, ws, 1) + require.Equal(t, workspace.ID, ws[0].ID) + }) + t.Run("FilterQuery", func(t *testing.T) { + t.Parallel() + client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerD: true}) + user := coderdtest.CreateFirstUser(t, client) + version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil) + coderdtest.AwaitTemplateVersionJob(t, client, version.ID) + template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) + template2 := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) + workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) + _ = coderdtest.CreateWorkspace(t, client, user.OrganizationID, template2.ID) + + // single workspace + ws, err := client.Workspaces(context.Background(), codersdk.WorkspaceFilter{ + FilterQuery: fmt.Sprintf("template:%s %s/%s", template.Name, workspace.OwnerName, workspace.Name), + }) + require.NoError(t, err) + require.Len(t, ws, 1) + require.Equal(t, workspace.ID, ws[0].ID) + }) } func TestPostWorkspaceBuild(t *testing.T) { diff --git a/codersdk/workspaces.go b/codersdk/workspaces.go index fbc1be91ab0e5..f5139c8e5200a 100644 --- a/codersdk/workspaces.go +++ b/codersdk/workspaces.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "net/http" + "strings" "time" "github.com/google/uuid" @@ -217,26 +218,39 @@ func (c *Client) PutExtendWorkspace(ctx context.Context, id uuid.UUID, req PutEx } type WorkspaceFilter struct { - OrganizationID uuid.UUID `json:"organization_id,omitempty"` - // Owner can be a user_id (uuid), "me", or a username - Owner string `json:"owner,omitempty"` - Name string `json:"name,omitempty"` + // Owner can be "me" or a username + Owner string `json:"owner,omitempty" typescript:"-"` + // Template is a template name + Template string `json:"template,omitempty" typescript:"-"` + // Name will return partial matches + Name string `json:"name,omitempty" typescript:"-"` + // FilterQuery supports a raw filter query string + FilterQuery string `json:"q,omitempty"` } // asRequestOption returns a function that can be used in (*Client).Request. // It modifies the request query parameters. func (f WorkspaceFilter) asRequestOption() requestOption { return func(r *http.Request) { - q := r.URL.Query() - if f.OrganizationID != uuid.Nil { - q.Set("organization_id", f.OrganizationID.String()) - } + var params []string + // Make sure all user input is quoted to ensure it's parsed as a single + // string. if f.Owner != "" { - q.Set("owner", f.Owner) + params = append(params, fmt.Sprintf("owner:%q", f.Owner)) } if f.Name != "" { - q.Set("name", f.Name) + params = append(params, fmt.Sprintf("name:%q", f.Name)) + } + if f.Template != "" { + params = append(params, fmt.Sprintf("template:%q", f.Template)) } + if f.FilterQuery != "" { + // If custom stuff is added, just add it on here. + params = append(params, f.FilterQuery) + } + + q := r.URL.Query() + q.Set("q", strings.Join(params, " ")) r.URL.RawQuery = q.Encode() } } diff --git a/site/src/api/api.test.ts b/site/src/api/api.test.ts index 083eb177fb6ac..f7455a9910b32 100644 --- a/site/src/api/api.test.ts +++ b/site/src/api/api.test.ts @@ -118,10 +118,10 @@ describe("api.ts", () => { it.each<[TypesGen.WorkspaceFilter | undefined, string]>([ [undefined, "/api/v2/workspaces"], - [{ organization_id: "1", owner: "" }, "/api/v2/workspaces?organization_id=1"], - [{ organization_id: "", owner: "1" }, "/api/v2/workspaces?owner=1"], + [{ q: "" }, "/api/v2/workspaces"], + [{ q: "owner:1" }, "/api/v2/workspaces?q=owner%3A1"], - [{ organization_id: "1", owner: "me" }, "/api/v2/workspaces?organization_id=1&owner=me"], + [{ q: "owner:me" }, "/api/v2/workspaces?q=owner%3Ame"], ])(`getWorkspacesURL(%p) returns %p`, (filter, expected) => { expect(getWorkspacesURL(filter)).toBe(expected) }) diff --git a/site/src/api/api.ts b/site/src/api/api.ts index 09c14645fbffb..e46dc2708748a 100644 --- a/site/src/api/api.ts +++ b/site/src/api/api.ts @@ -120,14 +120,8 @@ export const getWorkspacesURL = (filter?: TypesGen.WorkspaceFilter): string => { const basePath = "/api/v2/workspaces" const searchParams = new URLSearchParams() - if (filter?.organization_id) { - searchParams.append("organization_id", filter.organization_id) - } - if (filter?.owner) { - searchParams.append("owner", filter.owner) - } - if (filter?.name) { - searchParams.append("name", filter.name) + if (filter?.q && filter.q !== "") { + searchParams.append("q", filter.q) } const searchString = searchParams.toString() diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index 79423268ac6a2..d63b9b46c3505 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -90,7 +90,7 @@ export interface CreateUserRequest { readonly organization_id: string } -// From codersdk/workspaces.go:34:6 +// From codersdk/workspaces.go:35:6 export interface CreateWorkspaceBuildRequest { readonly template_version_id?: string readonly transition: WorkspaceTransition @@ -223,7 +223,7 @@ export interface ProvisionerJobLog { readonly output: string } -// From codersdk/workspaces.go:201:6 +// From codersdk/workspaces.go:202:6 export interface PutExtendWorkspaceRequest { readonly deadline: string } @@ -311,12 +311,12 @@ export interface UpdateUserProfileRequest { readonly username: string } -// From codersdk/workspaces.go:160:6 +// From codersdk/workspaces.go:161:6 export interface UpdateWorkspaceAutostartRequest { readonly schedule?: string } -// From codersdk/workspaces.go:180:6 +// From codersdk/workspaces.go:181:6 export interface UpdateWorkspaceTTLRequest { readonly ttl_ms?: number } @@ -371,7 +371,7 @@ export interface UsersRequest extends Pagination { readonly status?: string } -// From codersdk/workspaces.go:18:6 +// From codersdk/workspaces.go:19:6 export interface Workspace { readonly id: string readonly created_at: string @@ -461,19 +461,17 @@ export interface WorkspaceBuild { readonly deadline: string } -// From codersdk/workspaces.go:83:6 +// From codersdk/workspaces.go:84:6 export interface WorkspaceBuildsRequest extends Pagination { readonly WorkspaceID: string } -// From codersdk/workspaces.go:219:6 +// From codersdk/workspaces.go:220:6 export interface WorkspaceFilter { - readonly organization_id?: string - readonly owner?: string - readonly name?: string + readonly q?: string } -// From codersdk/workspaces.go:41:6 +// From codersdk/workspaces.go:42:6 export interface WorkspaceOptions { readonly include_deleted?: boolean } diff --git a/site/src/util/workspace.test.ts b/site/src/util/workspace.test.ts index a7789dfbf8285..cdfd833fe17a8 100644 --- a/site/src/util/workspace.test.ts +++ b/site/src/util/workspace.test.ts @@ -104,13 +104,13 @@ describe("util > workspace", () => { describe("workspaceQueryToFilter", () => { it.each<[string | undefined, TypesGen.WorkspaceFilter]>([ [undefined, {}], - ["", {}], - ["asdkfvjn", { name: "asdkfvjn" }], - ["owner:me", { owner: "me" }], - ["owner:me owner:me2", { owner: "me" }], - ["me/dev", { owner: "me", name: "dev" }], - ["me/", { owner: "me" }], - [" key:val owner:me ", { owner: "me" }], + ["", { q: "" }], + ["asdkfvjn", { q: "asdkfvjn" }], + ["owner:me", { q: "owner:me" }], + ["owner:me owner:me2", { q: "owner:me owner:me2" }], + ["me/dev", { q: "me/dev" }], + ["me/", { q: "me/" }], + [" key:val owner:me ", { q: "key:val owner:me" }], ])(`query=%p, filter=%p`, (query, filter) => { expect(workspaceQueryToFilter(query)).toEqual(filter) }) diff --git a/site/src/util/workspace.ts b/site/src/util/workspace.ts index 94758aebd00d7..42c997f103f69 100644 --- a/site/src/util/workspace.ts +++ b/site/src/util/workspace.ts @@ -263,42 +263,9 @@ export const defaultWorkspaceExtension = (__startDate?: dayjs.Dayjs): TypesGen.P } export const workspaceQueryToFilter = (query?: string): TypesGen.WorkspaceFilter => { - const defaultFilter: TypesGen.WorkspaceFilter = {} const preparedQuery = query?.trim().replace(/ +/g, " ") - - if (!preparedQuery) { - return defaultFilter - } else { - const parts = preparedQuery.split(" ") - - for (const part of parts) { - if (part.includes(":")) { - const [key, val] = part.split(":") - if (key && val) { - if (key === "owner") { - return { - owner: val, - } - } - // skip invalid key pairs - continue - } - } - - if (part.includes("/")) { - const [username, name] = part.split("/") - return { - owner: username, - name: name === "" ? undefined : name, - } - } - - return { - name: part, - } - } - - return defaultFilter + return { + q: preparedQuery, } }