Skip to content

Commit d21ab21

Browse files
authored
feat: Backend api for filtering users using filter query string (#2553)
* User search query string
1 parent 981fb27 commit d21ab21

File tree

13 files changed

+343
-84
lines changed

13 files changed

+343
-84
lines changed

coderd/coderd_test.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -338,10 +338,11 @@ func TestAuthorizeAllEndpoints(t *testing.T) {
338338
AssertAction: rbac.ActionRead,
339339
AssertObject: workspaceRBACObj,
340340
},
341-
"POST:/api/v2/users/{user}/organizations/": {
341+
"POST:/api/v2/users/{user}/organizations": {
342342
AssertAction: rbac.ActionCreate,
343343
AssertObject: rbac.ResourceOrganization,
344344
},
345+
"GET:/api/v2/users": {StatusCode: http.StatusOK, AssertObject: rbac.ResourceUser},
345346

346347
// These endpoints need payloads to get to the auth part. Payloads will be required
347348
"PUT:/api/v2/users/{user}/roles": {StatusCode: http.StatusBadRequest, NoAuthorize: true},

coderd/database/databasefake/databasefake.go

+14-8
Original file line numberDiff line numberDiff line change
@@ -285,19 +285,25 @@ func (q *fakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams
285285
users = tmp
286286
}
287287

288-
if len(params.Status) == 0 {
289-
params.Status = []database.UserStatus{database.UserStatusActive}
288+
if len(params.Status) > 0 {
289+
usersFilteredByStatus := make([]database.User, 0, len(users))
290+
for i, user := range users {
291+
if slice.Contains(params.Status, user.Status) {
292+
usersFilteredByStatus = append(usersFilteredByStatus, users[i])
293+
}
294+
}
295+
users = usersFilteredByStatus
290296
}
291297

292-
usersFilteredByStatus := make([]database.User, 0, len(users))
293-
for i, user := range users {
294-
for _, status := range params.Status {
295-
if user.Status == status {
296-
usersFilteredByStatus = append(usersFilteredByStatus, users[i])
298+
if len(params.RbacRole) > 0 {
299+
usersFilteredByRole := make([]database.User, 0, len(users))
300+
for i, user := range users {
301+
if slice.Overlap(params.RbacRole, user.RBACRoles) {
302+
usersFilteredByRole = append(usersFilteredByRole, users[i])
297303
}
298304
}
305+
users = usersFilteredByRole
299306
}
300-
users = usersFilteredByStatus
301307

302308
if params.OffsetOpt > 0 {
303309
if int(params.OffsetOpt) > len(users)-1 {

coderd/database/modelmethods.go

+7
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,10 @@ func (d ProvisionerDaemon) RBACObject() rbac.Object {
3030
func (f File) RBACObject() rbac.Object {
3131
return rbac.ResourceFile.WithID(f.Hash).WithOwner(f.CreatedBy.String())
3232
}
33+
34+
// RBACObject returns the RBAC object for the site wide user resource.
35+
// If you are trying to get the RBAC object for the UserData, use
36+
// rbac.ResourceUserData
37+
func (u User) RBACObject() rbac.Object {
38+
return rbac.ResourceUser.WithID(u.ID.String())
39+
}

coderd/database/queries.sql.go

+14-7
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries/users.sql

+10-5
Original file line numberDiff line numberDiff line change
@@ -108,12 +108,17 @@ WHERE
108108
AND CASE
109109
-- @status needs to be a text because it can be empty, If it was
110110
-- user_status enum, it would not.
111-
WHEN cardinality(@status :: user_status[]) > 0 THEN (
111+
WHEN cardinality(@status :: user_status[]) > 0 THEN
112112
status = ANY(@status :: user_status[])
113-
)
114-
ELSE
115-
-- Only show active by default
116-
status = 'active'
113+
ELSE true
114+
END
115+
-- Filter by rbac_roles
116+
AND CASE
117+
-- @rbac_role allows filtering by rbac roles. If 'member' is included, show everyone, as
118+
-- everyone is a member.
119+
WHEN cardinality(@rbac_role :: text[]) > 0 AND 'member' != ANY(@rbac_role :: text[]) THEN
120+
rbac_roles && @rbac_role :: text[]
121+
ELSE true
117122
END
118123
-- End of filters
119124
ORDER BY

coderd/httpapi/queryparams.go

+21-4
Original file line numberDiff line numberDiff line change
@@ -83,14 +83,31 @@ func (p *QueryParamParser) UUIDs(vals url.Values, def []uuid.UUID, queryParam st
8383
return v
8484
}
8585

86-
func (p *QueryParamParser) String(vals url.Values, def string, queryParam string) string {
87-
v, err := parseQueryParam(vals, func(v string) (string, error) {
86+
func (*QueryParamParser) String(vals url.Values, def string, queryParam string) string {
87+
v, _ := parseQueryParam(vals, func(v string) (string, error) {
8888
return v, nil
8989
}, def, queryParam)
90+
return v
91+
}
92+
93+
func (*QueryParamParser) Strings(vals url.Values, def []string, queryParam string) []string {
94+
v, _ := parseQueryParam(vals, func(v string) ([]string, error) {
95+
if v == "" {
96+
return []string{}, nil
97+
}
98+
return strings.Split(v, ","), nil
99+
}, def, queryParam)
100+
return v
101+
}
102+
103+
// ParseCustom has to be a function, not a method on QueryParamParser because generics
104+
// cannot be used on struct methods.
105+
func ParseCustom[T any](parser *QueryParamParser, vals url.Values, def T, queryParam string, parseFunc func(v string) (T, error)) T {
106+
v, err := parseQueryParam(vals, parseFunc, def, queryParam)
90107
if err != nil {
91-
p.Errors = append(p.Errors, Error{
108+
parser.Errors = append(parser.Errors, Error{
92109
Field: queryParam,
93-
Detail: fmt.Sprintf("Query param %q must be a valid string", queryParam),
110+
Detail: fmt.Sprintf("Query param %q has invalid uuids: %q", queryParam, err.Error()),
94111
})
95112
}
96113
return v

coderd/users.go

+65-31
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"fmt"
99
"net"
1010
"net/http"
11+
"net/url"
1112
"strings"
1213
"time"
1314

@@ -119,35 +120,13 @@ func (api *API) postFirstUser(rw http.ResponseWriter, r *http.Request) {
119120
}
120121

121122
func (api *API) users(rw http.ResponseWriter, r *http.Request) {
122-
var (
123-
searchName = r.URL.Query().Get("search")
124-
statusFilters = r.URL.Query().Get("status")
125-
)
126-
127-
statuses := make([]database.UserStatus, 0)
128-
129-
if statusFilters != "" {
130-
// Split on commas if present to account for it being a list
131-
for _, filter := range strings.Split(statusFilters, ",") {
132-
switch database.UserStatus(filter) {
133-
case database.UserStatusSuspended, database.UserStatusActive:
134-
statuses = append(statuses, database.UserStatus(filter))
135-
default:
136-
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
137-
Message: fmt.Sprintf("%q is not a valid user status.", filter),
138-
Validations: []httpapi.Error{
139-
{Field: "status", Detail: "invalid status"},
140-
},
141-
})
142-
return
143-
}
144-
}
145-
}
146-
147-
// Reading all users across the site.
148-
if !api.Authorize(r, rbac.ActionRead, rbac.ResourceUser) {
149-
httpapi.Forbidden(rw)
150-
return
123+
query := r.URL.Query().Get("q")
124+
params, errs := userSearchQuery(query)
125+
if len(errs) > 0 {
126+
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
127+
Message: "Invalid user search query.",
128+
Validations: errs,
129+
})
151130
}
152131

153132
paginationParams, ok := parsePagination(rw, r)
@@ -159,8 +138,9 @@ func (api *API) users(rw http.ResponseWriter, r *http.Request) {
159138
AfterID: paginationParams.AfterID,
160139
OffsetOpt: int32(paginationParams.Offset),
161140
LimitOpt: int32(paginationParams.Limit),
162-
Search: searchName,
163-
Status: statuses,
141+
Search: params.Search,
142+
Status: params.Status,
143+
RbacRole: params.RbacRole,
164144
})
165145
if errors.Is(err, sql.ErrNoRows) {
166146
httpapi.Write(rw, http.StatusOK, []codersdk.User{})
@@ -174,6 +154,7 @@ func (api *API) users(rw http.ResponseWriter, r *http.Request) {
174154
return
175155
}
176156

157+
users = AuthorizeFilter(api, r, rbac.ActionRead, users)
177158
userIDs := make([]uuid.UUID, 0, len(users))
178159
for _, user := range users {
179160
userIDs = append(userIDs, user.ID)
@@ -971,3 +952,56 @@ func findUser(id uuid.UUID, users []database.User) *database.User {
971952
}
972953
return nil
973954
}
955+
956+
func userSearchQuery(query string) (database.GetUsersParams, []httpapi.Error) {
957+
searchParams := make(url.Values)
958+
if query == "" {
959+
// No filter
960+
return database.GetUsersParams{}, nil
961+
}
962+
// Because we do this in 2 passes, we want to maintain quotes on the first
963+
// pass.Further splitting occurs on the second pass and quotes will be
964+
// dropped.
965+
elements := splitQueryParameterByDelimiter(query, ' ', true)
966+
for _, element := range elements {
967+
parts := splitQueryParameterByDelimiter(element, ':', false)
968+
switch len(parts) {
969+
case 1:
970+
// No key:value pair.
971+
searchParams.Set("search", parts[0])
972+
case 2:
973+
searchParams.Set(parts[0], parts[1])
974+
default:
975+
return database.GetUsersParams{}, []httpapi.Error{
976+
{Field: "q", Detail: fmt.Sprintf("Query element %q can only contain 1 ':'", element)},
977+
}
978+
}
979+
}
980+
981+
parser := httpapi.NewQueryParamParser()
982+
filter := database.GetUsersParams{
983+
Search: parser.String(searchParams, "", "search"),
984+
Status: httpapi.ParseCustom(parser, searchParams, []database.UserStatus{}, "status", parseUserStatus),
985+
RbacRole: parser.Strings(searchParams, []string{}, "role"),
986+
}
987+
988+
return filter, parser.Errors
989+
}
990+
991+
// parseUserStatus ensures proper enums are used for user statuses
992+
func parseUserStatus(v string) ([]database.UserStatus, error) {
993+
var statuses []database.UserStatus
994+
if v == "" {
995+
return statuses, nil
996+
}
997+
parts := strings.Split(v, ",")
998+
for _, part := range parts {
999+
switch database.UserStatus(part) {
1000+
case database.UserStatusActive, database.UserStatusSuspended:
1001+
statuses = append(statuses, database.UserStatus(part))
1002+
default:
1003+
return []database.UserStatus{}, xerrors.Errorf("%q is not a valid user status", part)
1004+
}
1005+
}
1006+
return statuses, nil
1007+
}

0 commit comments

Comments
 (0)