Skip to content

feat: Backend api for filtering users using filter query string #2553

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jun 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion coderd/coderd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -338,10 +338,11 @@ func TestAuthorizeAllEndpoints(t *testing.T) {
AssertAction: rbac.ActionRead,
AssertObject: workspaceRBACObj,
},
"POST:/api/v2/users/{user}/organizations/": {
"POST:/api/v2/users/{user}/organizations": {
AssertAction: rbac.ActionCreate,
AssertObject: rbac.ResourceOrganization,
},
"GET:/api/v2/users": {StatusCode: http.StatusOK, AssertObject: rbac.ResourceUser},

// These endpoints need payloads to get to the auth part. Payloads will be required
"PUT:/api/v2/users/{user}/roles": {StatusCode: http.StatusBadRequest, NoAuthorize: true},
Expand Down
22 changes: 14 additions & 8 deletions coderd/database/databasefake/databasefake.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,19 +285,25 @@ func (q *fakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams
users = tmp
}

if len(params.Status) == 0 {
params.Status = []database.UserStatus{database.UserStatusActive}
if len(params.Status) > 0 {
usersFilteredByStatus := make([]database.User, 0, len(users))
for i, user := range users {
if slice.Contains(params.Status, user.Status) {
usersFilteredByStatus = append(usersFilteredByStatus, users[i])
}
}
users = usersFilteredByStatus
}

usersFilteredByStatus := make([]database.User, 0, len(users))
for i, user := range users {
for _, status := range params.Status {
if user.Status == status {
usersFilteredByStatus = append(usersFilteredByStatus, users[i])
if len(params.RbacRole) > 0 {
usersFilteredByRole := make([]database.User, 0, len(users))
for i, user := range users {
if slice.Overlap(params.RbacRole, user.RBACRoles) {
usersFilteredByRole = append(usersFilteredByRole, users[i])
}
}
users = usersFilteredByRole
}
users = usersFilteredByStatus

if params.OffsetOpt > 0 {
if int(params.OffsetOpt) > len(users)-1 {
Expand Down
7 changes: 7 additions & 0 deletions coderd/database/modelmethods.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,10 @@ func (d ProvisionerDaemon) RBACObject() rbac.Object {
func (f File) RBACObject() rbac.Object {
return rbac.ResourceFile.WithID(f.Hash).WithOwner(f.CreatedBy.String())
}

// RBACObject returns the RBAC object for the site wide user resource.
// If you are trying to get the RBAC object for the UserData, use
// rbac.ResourceUserData
func (u User) RBACObject() rbac.Object {
return rbac.ResourceUser.WithID(u.ID.String())
}
21 changes: 14 additions & 7 deletions coderd/database/queries.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 10 additions & 5 deletions coderd/database/queries/users.sql
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,17 @@ WHERE
AND CASE
-- @status needs to be a text because it can be empty, If it was
-- user_status enum, it would not.
WHEN cardinality(@status :: user_status[]) > 0 THEN (
WHEN cardinality(@status :: user_status[]) > 0 THEN
status = ANY(@status :: user_status[])
)
ELSE
-- Only show active by default
status = 'active'
ELSE true
END
-- Filter by rbac_roles
AND CASE
-- @rbac_role allows filtering by rbac roles. If 'member' is included, show everyone, as
-- everyone is a member.
WHEN cardinality(@rbac_role :: text[]) > 0 AND 'member' != ANY(@rbac_role :: text[]) THEN
rbac_roles && @rbac_role :: text[]
ELSE true
END
-- End of filters
ORDER BY
Expand Down
25 changes: 21 additions & 4 deletions coderd/httpapi/queryparams.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,31 @@ func (p *QueryParamParser) UUIDs(vals url.Values, def []uuid.UUID, queryParam st
return v
}

func (p *QueryParamParser) String(vals url.Values, def string, queryParam string) string {
v, err := parseQueryParam(vals, func(v string) (string, error) {
func (*QueryParamParser) String(vals url.Values, def string, queryParam string) string {
v, _ := parseQueryParam(vals, func(v string) (string, error) {
return v, nil
}, def, queryParam)
return v
}

func (*QueryParamParser) Strings(vals url.Values, def []string, queryParam string) []string {
v, _ := parseQueryParam(vals, func(v string) ([]string, error) {
if v == "" {
return []string{}, nil
}
return strings.Split(v, ","), nil
}, def, queryParam)
return v
}

// ParseCustom has to be a function, not a method on QueryParamParser because generics
// cannot be used on struct methods.
func ParseCustom[T any](parser *QueryParamParser, vals url.Values, def T, queryParam string, parseFunc func(v string) (T, error)) T {
v, err := parseQueryParam(vals, parseFunc, def, queryParam)
if err != nil {
p.Errors = append(p.Errors, Error{
parser.Errors = append(parser.Errors, Error{
Field: queryParam,
Detail: fmt.Sprintf("Query param %q must be a valid string", queryParam),
Detail: fmt.Sprintf("Query param %q has invalid uuids: %q", queryParam, err.Error()),
})
}
return v
Expand Down
96 changes: 65 additions & 31 deletions coderd/users.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"net"
"net/http"
"net/url"
"strings"
"time"

Expand Down Expand Up @@ -119,35 +120,13 @@ func (api *API) postFirstUser(rw http.ResponseWriter, r *http.Request) {
}

func (api *API) users(rw http.ResponseWriter, r *http.Request) {
var (
searchName = r.URL.Query().Get("search")
statusFilters = r.URL.Query().Get("status")
)

statuses := make([]database.UserStatus, 0)

if statusFilters != "" {
// Split on commas if present to account for it being a list
for _, filter := range strings.Split(statusFilters, ",") {
switch database.UserStatus(filter) {
case database.UserStatusSuspended, database.UserStatusActive:
statuses = append(statuses, database.UserStatus(filter))
default:
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
Message: fmt.Sprintf("%q is not a valid user status.", filter),
Validations: []httpapi.Error{
{Field: "status", Detail: "invalid status"},
},
})
return
}
}
}

// Reading all users across the site.
if !api.Authorize(r, rbac.ActionRead, rbac.ResourceUser) {
httpapi.Forbidden(rw)
return
query := r.URL.Query().Get("q")
params, errs := userSearchQuery(query)
if len(errs) > 0 {
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
Message: "Invalid user search query.",
Validations: errs,
})
}

paginationParams, ok := parsePagination(rw, r)
Expand All @@ -159,8 +138,9 @@ func (api *API) users(rw http.ResponseWriter, r *http.Request) {
AfterID: paginationParams.AfterID,
OffsetOpt: int32(paginationParams.Offset),
LimitOpt: int32(paginationParams.Limit),
Search: searchName,
Status: statuses,
Search: params.Search,
Status: params.Status,
RbacRole: params.RbacRole,
})
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusOK, []codersdk.User{})
Expand All @@ -174,6 +154,7 @@ func (api *API) users(rw http.ResponseWriter, r *http.Request) {
return
}

users = AuthorizeFilter(api, r, rbac.ActionRead, users)
userIDs := make([]uuid.UUID, 0, len(users))
for _, user := range users {
userIDs = append(userIDs, user.ID)
Expand Down Expand Up @@ -971,3 +952,56 @@ func findUser(id uuid.UUID, users []database.User) *database.User {
}
return nil
}

func userSearchQuery(query string) (database.GetUsersParams, []httpapi.Error) {
searchParams := make(url.Values)
if query == "" {
// No filter
return database.GetUsersParams{}, nil
}
// Because we do this in 2 passes, we want to maintain quotes on the first
// pass.Further splitting occurs on the second pass and quotes will be
// dropped.
elements := splitQueryParameterByDelimiter(query, ' ', true)
for _, element := range elements {
parts := splitQueryParameterByDelimiter(element, ':', false)
switch len(parts) {
case 1:
// No key:value pair.
searchParams.Set("search", parts[0])
case 2:
searchParams.Set(parts[0], parts[1])
default:
return database.GetUsersParams{}, []httpapi.Error{
{Field: "q", Detail: fmt.Sprintf("Query element %q can only contain 1 ':'", element)},
}
}
}

parser := httpapi.NewQueryParamParser()
filter := database.GetUsersParams{
Search: parser.String(searchParams, "", "search"),
Status: httpapi.ParseCustom(parser, searchParams, []database.UserStatus{}, "status", parseUserStatus),
RbacRole: parser.Strings(searchParams, []string{}, "role"),
}

return filter, parser.Errors
}

// parseUserStatus ensures proper enums are used for user statuses
func parseUserStatus(v string) ([]database.UserStatus, error) {
var statuses []database.UserStatus
if v == "" {
return statuses, nil
}
parts := strings.Split(v, ",")
for _, part := range parts {
switch database.UserStatus(part) {
case database.UserStatusActive, database.UserStatusSuspended:
statuses = append(statuses, database.UserStatus(part))
default:
return []database.UserStatus{}, xerrors.Errorf("%q is not a valid user status", part)
}
}
return statuses, nil
}
Loading