Skip to content

feat: add count to get users endpoint #5016

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Nov 14, 2022
Merged
4 changes: 2 additions & 2 deletions cli/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,12 @@ func list() *cobra.Command {
_, _ = fmt.Fprintln(cmd.ErrOrStderr())
return nil
}
users, err := client.Users(cmd.Context(), codersdk.UsersRequest{})
userRes, err := client.Users(cmd.Context(), codersdk.UsersRequest{})
if err != nil {
return err
}
usersByID := map[uuid.UUID]codersdk.User{}
for _, user := range users {
for _, user := range userRes.Users {
usersByID[user.ID] = user
}

Expand Down
6 changes: 3 additions & 3 deletions cli/userlist.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,20 @@ func userList() *cobra.Command {
if err != nil {
return err
}
users, err := client.Users(cmd.Context(), codersdk.UsersRequest{})
res, err := client.Users(cmd.Context(), codersdk.UsersRequest{})
if err != nil {
return err
}

out := ""
switch outputFormat {
case "table", "":
out, err = cliui.DisplayTable(users, "Username", columns)
out, err = cliui.DisplayTable(res.Users, "Username", columns)
if err != nil {
return xerrors.Errorf("render table: %w", err)
}
case "json":
outBytes, err := json.Marshal(users)
outBytes, err := json.Marshal(res.Users)
if err != nil {
return xerrors.Errorf("marshal users to JSON: %w", err)
}
Expand Down
1 change: 0 additions & 1 deletion coderd/coderd.go
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,6 @@ func New(options *Options) *API {
)
r.Post("/", api.postUser)
r.Get("/", api.users)
r.Get("/count", api.userCount)
r.Post("/logout", api.postLogout)
// These routes query information about site wide roles.
r.Route("/roles", func(r chi.Router) {
Expand Down
1 change: 0 additions & 1 deletion coderd/coderdtest/authorize.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,6 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {

// Endpoints that use the SQLQuery filter.
"GET:/api/v2/workspaces/": {StatusCode: http.StatusOK, NoAuthorize: true},
"GET:/api/v2/users/count": {StatusCode: http.StatusOK, NoAuthorize: true},
}

// Routes like proxy routes support all HTTP methods. A helper func to expand
Expand Down
33 changes: 29 additions & 4 deletions coderd/database/databasefake/databasefake.go
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ func (q *fakeQuerier) UpdateUserDeletedByID(_ context.Context, params database.U
return sql.ErrNoRows
}

func (q *fakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams) ([]database.User, error) {
func (q *fakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams) ([]database.GetUsersRow, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()

Expand Down Expand Up @@ -579,7 +579,7 @@ func (q *fakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams

// If no users after the time, then we return an empty list.
if !found {
return nil, sql.ErrNoRows
return []database.GetUsersRow{}, nil
}
}

Expand Down Expand Up @@ -617,9 +617,11 @@ func (q *fakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams
users = usersFilteredByRole
}

beforePageCount := len(users)

if params.OffsetOpt > 0 {
if int(params.OffsetOpt) > len(users)-1 {
return nil, sql.ErrNoRows
return []database.GetUsersRow{}, nil
}
users = users[params.OffsetOpt:]
}
Expand All @@ -631,7 +633,30 @@ func (q *fakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams
users = users[:params.LimitOpt]
}

return users, nil
return convertUsers(users, int64(beforePageCount)), nil
}

func convertUsers(users []database.User, count int64) []database.GetUsersRow {
rows := make([]database.GetUsersRow, len(users))
for i, u := range users {
rows[i] = database.GetUsersRow{
ID: u.ID,
Email: u.Email,
Username: u.Username,
HashedPassword: u.HashedPassword,
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
Status: u.Status,
RBACRoles: u.RBACRoles,
LoginType: u.LoginType,
AvatarURL: u.AvatarURL,
Deleted: u.Deleted,
LastSeenAt: u.LastSeenAt,
Count: count,
}
}

return rows
}

func (q *fakeQuerier) GetUsersByIDs(_ context.Context, ids []uuid.UUID) ([]database.User, error) {
Expand Down
22 changes: 22 additions & 0 deletions coderd/database/modelmethods.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,25 @@ func (User) RBACObject() rbac.Object {
func (License) RBACObject() rbac.Object {
return rbac.ResourceLicense
}

func ConvertUserRows(rows []GetUsersRow) []User {
users := make([]User, len(rows))
for i, r := range rows {
users[i] = User{
ID: r.ID,
Email: r.Email,
Username: r.Username,
HashedPassword: r.HashedPassword,
CreatedAt: r.CreatedAt,
UpdatedAt: r.UpdatedAt,
Status: r.Status,
RBACRoles: r.RBACRoles,
LoginType: r.LoginType,
AvatarURL: r.AvatarURL,
Deleted: r.Deleted,
LastSeenAt: r.LastSeenAt,
}
}

return users
}
2 changes: 1 addition & 1 deletion coderd/database/querier.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 21 additions & 4 deletions coderd/database/queries.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion coderd/database/queries/users.sql
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ WHERE

-- name: GetUsers :many
SELECT
*
*, COUNT(*) OVER() AS count
FROM
users
WHERE
Expand Down
3 changes: 2 additions & 1 deletion coderd/telemetry/telemetry.go
Original file line number Diff line number Diff line change
Expand Up @@ -344,10 +344,11 @@ func (r *remoteReporter) createSnapshot() (*Snapshot, error) {
return nil
})
eg.Go(func() error {
users, err := r.options.Database.GetUsers(ctx, database.GetUsersParams{})
userRows, err := r.options.Database.GetUsers(ctx, database.GetUsersParams{})
if err != nil {
return xerrors.Errorf("get users: %w", err)
}
users := database.ConvertUserRows(userRows)
var firstUser database.User
for _, dbUser := range users {
if dbUser.Status != database.UserStatusActive {
Expand Down
56 changes: 14 additions & 42 deletions coderd/users.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,27 +198,32 @@ func (api *API) users(rw http.ResponseWriter, r *http.Request) {
return
}

users, err := api.Database.GetUsers(ctx, database.GetUsersParams{
userRows, err := api.Database.GetUsers(ctx, database.GetUsersParams{
AfterID: paginationParams.AfterID,
OffsetOpt: int32(paginationParams.Offset),
LimitOpt: int32(paginationParams.Limit),
Search: params.Search,
Status: params.Status,
RbacRole: params.RbacRole,
})
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(ctx, rw, http.StatusOK, []codersdk.User{})
return
}
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching users.",
Detail: err.Error(),
})
return
}
// GetUsers does not return ErrNoRows because it uses a window function to get the count.
// So we need to check if the userRows is empty and return an empty array if so.
if len(userRows) == 0 {
httpapi.Write(ctx, rw, http.StatusOK, codersdk.GetUsersResponse{
Users: []codersdk.User{},
Count: 0,
})
return
}
Comment on lines +216 to +224
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it ever could return that error anyways since it returns an array, and only :one queries will error. I think this simplifies the code below though, which is nice.


users, err = AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, users)
users, err := AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, database.ConvertUserRows(userRows))
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching users.",
Expand Down Expand Up @@ -248,42 +253,9 @@ func (api *API) users(rw http.ResponseWriter, r *http.Request) {
}

render.Status(r, http.StatusOK)
render.JSON(rw, r, convertUsers(users, organizationIDsByUserID))
}

func (api *API) userCount(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
query := r.URL.Query().Get("q")
params, errs := userSearchQuery(query)
if len(errs) > 0 {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid user search query.",
Validations: errs,
})
return
}

sqlFilter, err := api.HTTPAuth.AuthorizeSQLFilter(r, rbac.ActionRead, rbac.ResourceUser.Type)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error preparing sql filter.",
Detail: err.Error(),
})
return
}

count, err := api.Database.GetAuthorizedUserCount(ctx, database.GetFilteredUserCountParams{
Search: params.Search,
Status: params.Status,
RbacRole: params.RbacRole,
}, sqlFilter)
if err != nil {
httpapi.InternalServerError(rw, err)
return
}

httpapi.Write(ctx, rw, http.StatusOK, codersdk.UserCountResponse{
Count: count,
render.JSON(rw, r, codersdk.GetUsersResponse{
Users: convertUsers(users, organizationIDsByUserID),
Count: int(userRows[0].Count),
})
}

Expand Down
Loading