Skip to content

Commit 28f54a5

Browse files
presleypf0ssel
authored andcommitted
feat: add count endpoint for users, enabling better pagination (#4848)
* Start on backend * Hook up frontend * Add to frontend test * Add go test, wip * Fix some test bugs * Fix test * Format * Add to authorize.go * copy user array into local variable * Authorize route * Log count error * Authorize better * Tweaks to authorization * More authorization tweaks * Make gen * Fix test Co-authored-by: Garrett <garrett@coder.com>
1 parent 366aa25 commit 28f54a5

File tree

19 files changed

+612
-208
lines changed

19 files changed

+612
-208
lines changed

coderd/coderd.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,7 @@ func New(options *Options) *API {
437437
)
438438
r.Post("/", api.postUser)
439439
r.Get("/", api.users)
440+
r.Get("/count", api.userCount)
440441
r.Post("/logout", api.postLogout)
441442
// These routes query information about site wide roles.
442443
r.Route("/roles", func(r chi.Router) {

coderd/coderdtest/authorize.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,7 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
246246
// Endpoints that use the SQLQuery filter.
247247
"GET:/api/v2/workspaces/": {StatusCode: http.StatusOK, NoAuthorize: true},
248248
"GET:/api/v2/workspaces/count": {StatusCode: http.StatusOK, NoAuthorize: true},
249+
"GET:/api/v2/users/count": {StatusCode: http.StatusOK, NoAuthorize: true},
249250
}
250251

251252
// Routes like proxy routes support all HTTP methods. A helper func to expand

coderd/database/databasefake/databasefake.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,72 @@ func (q *fakeQuerier) GetActiveUserCount(_ context.Context) (int64, error) {
457457
return active, nil
458458
}
459459

460+
func (q *fakeQuerier) GetFilteredUserCount(ctx context.Context, arg database.GetFilteredUserCountParams) (int64, error) {
461+
count, err := q.GetAuthorizedUserCount(ctx, arg, nil)
462+
return count, err
463+
}
464+
465+
func (q *fakeQuerier) GetAuthorizedUserCount(_ context.Context, params database.GetFilteredUserCountParams, authorizedFilter rbac.AuthorizeFilter) (int64, error) {
466+
q.mutex.RLock()
467+
defer q.mutex.RUnlock()
468+
469+
users := append([]database.User{}, q.users...)
470+
471+
if params.Deleted {
472+
tmp := make([]database.User, 0, len(users))
473+
for _, user := range users {
474+
if user.Deleted {
475+
tmp = append(tmp, user)
476+
}
477+
}
478+
users = tmp
479+
}
480+
481+
if params.Search != "" {
482+
tmp := make([]database.User, 0, len(users))
483+
for i, user := range users {
484+
if strings.Contains(strings.ToLower(user.Email), strings.ToLower(params.Search)) {
485+
tmp = append(tmp, users[i])
486+
} else if strings.Contains(strings.ToLower(user.Username), strings.ToLower(params.Search)) {
487+
tmp = append(tmp, users[i])
488+
}
489+
}
490+
users = tmp
491+
}
492+
493+
if len(params.Status) > 0 {
494+
usersFilteredByStatus := make([]database.User, 0, len(users))
495+
for i, user := range users {
496+
if slice.ContainsCompare(params.Status, user.Status, func(a, b database.UserStatus) bool {
497+
return strings.EqualFold(string(a), string(b))
498+
}) {
499+
usersFilteredByStatus = append(usersFilteredByStatus, users[i])
500+
}
501+
}
502+
users = usersFilteredByStatus
503+
}
504+
505+
if len(params.RbacRole) > 0 && !slice.Contains(params.RbacRole, rbac.RoleMember()) {
506+
usersFilteredByRole := make([]database.User, 0, len(users))
507+
for i, user := range users {
508+
if slice.OverlapCompare(params.RbacRole, user.RBACRoles, strings.EqualFold) {
509+
usersFilteredByRole = append(usersFilteredByRole, users[i])
510+
}
511+
}
512+
513+
users = usersFilteredByRole
514+
}
515+
516+
for _, user := range q.workspaces {
517+
// If the filter exists, ensure the object is authorized.
518+
if authorizedFilter != nil && !authorizedFilter.Eval(user.RBACObject()) {
519+
continue
520+
}
521+
}
522+
523+
return int64(len(users)), nil
524+
}
525+
460526
func (q *fakeQuerier) UpdateUserDeletedByID(_ context.Context, params database.UpdateUserDeletedByIDParams) error {
461527
q.mutex.Lock()
462528
defer q.mutex.Unlock()

coderd/database/modelqueries.go

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
type customQuerier interface {
2020
templateQuerier
2121
workspaceQuerier
22+
userQuerier
2223
}
2324

2425
type templateQuerier interface {
@@ -169,8 +170,6 @@ func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspa
169170
}
170171

171172
func (q *sqlQuerier) GetAuthorizedWorkspaceCount(ctx context.Context, arg GetWorkspaceCountParams, authorizedFilter rbac.AuthorizeFilter) (int64, error) {
172-
// In order to properly use ORDER BY, OFFSET, and LIMIT, we need to inject the
173-
// authorizedFilter between the end of the where clause and those statements.
174173
filter := strings.Replace(getWorkspaceCount, "-- @authorize_filter", fmt.Sprintf(" AND %s", authorizedFilter.SQLString(rbac.NoACLConfig())), 1)
175174
// The name comment is for metric tracking
176175
query := fmt.Sprintf("-- name: GetAuthorizedWorkspaceCount :one\n%s", filter)
@@ -187,3 +186,21 @@ func (q *sqlQuerier) GetAuthorizedWorkspaceCount(ctx context.Context, arg GetWor
187186
err := row.Scan(&count)
188187
return count, err
189188
}
189+
190+
type userQuerier interface {
191+
GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, authorizedFilter rbac.AuthorizeFilter) (int64, error)
192+
}
193+
194+
func (q *sqlQuerier) GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, authorizedFilter rbac.AuthorizeFilter) (int64, error) {
195+
filter := strings.Replace(getFilteredUserCount, "-- @authorize_filter", fmt.Sprintf(" AND %s", authorizedFilter.SQLString(rbac.NoACLConfig())), 1)
196+
query := fmt.Sprintf("-- name: GetAuthorizedUserCount :one\n%s", filter)
197+
row := q.db.QueryRowContext(ctx, query,
198+
arg.Deleted,
199+
arg.Search,
200+
pq.Array(arg.Status),
201+
pq.Array(arg.RbacRole),
202+
)
203+
var count int64
204+
err := row.Scan(&count)
205+
return count, err
206+
}

coderd/database/querier.go

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries.sql.go

Lines changed: 54 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries/users.sql

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,41 @@ FROM
3939
WHERE
4040
status = 'active'::user_status AND deleted = false;
4141

42+
-- name: GetFilteredUserCount :one
43+
SELECT
44+
COUNT(*)
45+
FROM
46+
users
47+
WHERE
48+
users.deleted = @deleted
49+
-- Start filters
50+
-- Filter by name, email or username
51+
AND CASE
52+
WHEN @search :: text != '' THEN (
53+
email ILIKE concat('%', @search, '%')
54+
OR username ILIKE concat('%', @search, '%')
55+
)
56+
ELSE true
57+
END
58+
-- Filter by status
59+
AND CASE
60+
-- @status needs to be a text because it can be empty, If it was
61+
-- user_status enum, it would not.
62+
WHEN cardinality(@status :: user_status[]) > 0 THEN
63+
status = ANY(@status :: user_status[])
64+
ELSE true
65+
END
66+
-- Filter by rbac_roles
67+
AND CASE
68+
-- @rbac_role allows filtering by rbac roles. If 'member' is included, show everyone, as everyone is a member.
69+
WHEN cardinality(@rbac_role :: text[]) > 0 AND 'member' != ANY(@rbac_role :: text[])
70+
THEN rbac_roles && @rbac_role :: text[]
71+
ELSE true
72+
END
73+
-- Authorize Filter clause will be injected below in GetAuthorizedUserCount
74+
-- @authorize_filter
75+
;
76+
4277
-- name: InsertUser :one
4378
INSERT INTO
4479
users (

coderd/users.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,42 @@ func (api *API) users(rw http.ResponseWriter, r *http.Request) {
251251
render.JSON(rw, r, convertUsers(users, organizationIDsByUserID))
252252
}
253253

254+
func (api *API) userCount(rw http.ResponseWriter, r *http.Request) {
255+
ctx := r.Context()
256+
query := r.URL.Query().Get("q")
257+
params, errs := userSearchQuery(query)
258+
if len(errs) > 0 {
259+
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
260+
Message: "Invalid user search query.",
261+
Validations: errs,
262+
})
263+
return
264+
}
265+
266+
sqlFilter, err := api.HTTPAuth.AuthorizeSQLFilter(r, rbac.ActionRead, rbac.ResourceUser.Type)
267+
if err != nil {
268+
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
269+
Message: "Internal error preparing sql filter.",
270+
Detail: err.Error(),
271+
})
272+
return
273+
}
274+
275+
count, err := api.Database.GetAuthorizedUserCount(ctx, database.GetFilteredUserCountParams{
276+
Search: params.Search,
277+
Status: params.Status,
278+
RbacRole: params.RbacRole,
279+
}, sqlFilter)
280+
if err != nil {
281+
httpapi.InternalServerError(rw, err)
282+
return
283+
}
284+
285+
httpapi.Write(ctx, rw, http.StatusOK, codersdk.UserCountResponse{
286+
Count: count,
287+
})
288+
}
289+
254290
// Creates a new user.
255291
func (api *API) postUser(rw http.ResponseWriter, r *http.Request) {
256292
ctx := r.Context()

coderd/users_test.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1255,6 +1255,58 @@ func TestGetUsers(t *testing.T) {
12551255
})
12561256
}
12571257

1258+
func TestGetFilteredUserCount(t *testing.T) {
1259+
t.Parallel()
1260+
t.Run("AllUsers", func(t *testing.T) {
1261+
t.Parallel()
1262+
client := coderdtest.New(t, nil)
1263+
user := coderdtest.CreateFirstUser(t, client)
1264+
1265+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
1266+
defer cancel()
1267+
1268+
client.CreateUser(ctx, codersdk.CreateUserRequest{
1269+
Email: "alice@email.com",
1270+
Username: "alice",
1271+
Password: "password",
1272+
OrganizationID: user.OrganizationID,
1273+
})
1274+
// No params is all users
1275+
response, err := client.UserCount(ctx, codersdk.UserCountRequest{})
1276+
require.NoError(t, err)
1277+
require.Equal(t, 2, int(response.Count))
1278+
})
1279+
t.Run("ActiveUsers", func(t *testing.T) {
1280+
t.Parallel()
1281+
client := coderdtest.New(t, nil)
1282+
first := coderdtest.CreateFirstUser(t, client)
1283+
1284+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
1285+
defer cancel()
1286+
1287+
_, err := client.User(ctx, first.UserID.String())
1288+
require.NoError(t, err, "")
1289+
1290+
// Alice will be suspended
1291+
alice, err := client.CreateUser(ctx, codersdk.CreateUserRequest{
1292+
Email: "alice@email.com",
1293+
Username: "alice",
1294+
Password: "password",
1295+
OrganizationID: first.OrganizationID,
1296+
})
1297+
require.NoError(t, err)
1298+
1299+
_, err = client.UpdateUserStatus(ctx, alice.Username, codersdk.UserStatusSuspended)
1300+
require.NoError(t, err)
1301+
1302+
response, err := client.UserCount(ctx, codersdk.UserCountRequest{
1303+
Status: codersdk.UserStatusActive,
1304+
})
1305+
require.NoError(t, err)
1306+
require.Equal(t, 1, int(response.Count))
1307+
})
1308+
}
1309+
12581310
func TestPostTokens(t *testing.T) {
12591311
t.Parallel()
12601312
client := coderdtest.New(t, nil)

coderd/workspaces.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ func (api *API) workspaceCount(rw http.ResponseWriter, r *http.Request) {
166166
filter, errs := workspaceSearchQuery(queryStr, codersdk.Pagination{})
167167
if len(errs) > 0 {
168168
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
169-
Message: "Invalid audit search query.",
169+
Message: "Invalid workspace search query.",
170170
Validations: errs,
171171
})
172172
return

0 commit comments

Comments
 (0)