From d34aa7e362ef16d95c59d3a14ea163d5f788dafb Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 12 Apr 2022 17:22:41 -0500 Subject: [PATCH 1/3] feat: Basic user paginations - Currently "before" does not work - Filters not added yet --- coderd/coderd.go | 1 + coderd/database/databasefake/databasefake.go | 5 ++ coderd/database/querier.go | 1 + coderd/database/queries.sql.go | 60 ++++++++++++++++ coderd/database/queries/users.sql | 23 ++++++ coderd/users.go | 76 ++++++++++++++++++++ coderd/users_test.go | 42 +++++++++++ codersdk/users.go | 32 +++++++++ 8 files changed, 240 insertions(+) diff --git a/coderd/coderd.go b/coderd/coderd.go index 83bb6ec78e4ca..536f53dfd3757 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -140,6 +140,7 @@ func New(options *Options) (http.Handler, func()) { }) }) r.Route("/users", func(r chi.Router) { + r.Get("/", api.getPaginatedUsers) r.Get("/first", api.firstUser) r.Post("/first", api.postFirstUser) r.Post("/login", api.postLogin) diff --git a/coderd/database/databasefake/databasefake.go b/coderd/database/databasefake/databasefake.go index 36b1d822dbcda..ce84b59e8499b 100644 --- a/coderd/database/databasefake/databasefake.go +++ b/coderd/database/databasefake/databasefake.go @@ -3,6 +3,7 @@ package databasefake import ( "context" "database/sql" + "golang.org/x/xerrors" "strings" "sync" @@ -1373,3 +1374,7 @@ func (q *fakeQuerier) DeleteGitSSHKey(_ context.Context, userID uuid.UUID) error } return sql.ErrNoRows } + +func (q *fakeQuerier) PaginatedUsers(ctx context.Context, arg database.PaginatedUsersParams) ([]database.User, error) { + return nil, xerrors.Errorf("not implemented") +} diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 073113a2451df..92fcc0a6969a1 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -72,6 +72,7 @@ type querier interface { InsertWorkspaceAgent(ctx context.Context, arg InsertWorkspaceAgentParams) (WorkspaceAgent, error) InsertWorkspaceBuild(ctx context.Context, arg InsertWorkspaceBuildParams) (WorkspaceBuild, error) InsertWorkspaceResource(ctx context.Context, arg InsertWorkspaceResourceParams) (WorkspaceResource, error) + PaginatedUsers(ctx context.Context, arg PaginatedUsersParams) ([]User, error) UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error UpdateGitSSHKey(ctx context.Context, arg UpdateGitSSHKeyParams) error UpdateProvisionerDaemonByID(ctx context.Context, arg UpdateProvisionerDaemonByIDParams) error diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index e816b90c57a17..f48aae6daa88e 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -1905,6 +1905,66 @@ func (q *sqlQuerier) InsertUser(ctx context.Context, arg InsertUserParams) (User return i, err } +const paginatedUsers = `-- name: PaginatedUsers :many +SELECT + id, email, name, revoked, login_type, hashed_password, created_at, updated_at, username +FROM + users +WHERE + CASE + WHEN $1::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + created_at > (SELECT created_at FROM users WHERE id = $1) + WHEN $2::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + created_at < (SELECT created_at FROM users WHERE id = $2) + ELSE true + END +ORDER BY + -- TODO: When doing 'before', we need to flip this to DESC. + -- You cannot put 'ASC' or 'DESC' in a CASE statement. :'( + created_at ASC +LIMIT + $3 +` + +type PaginatedUsersParams struct { + After uuid.UUID `db:"after" json:"after"` + Before uuid.UUID `db:"before" json:"before"` + LimitOpt int32 `db:"limit_opt" json:"limit_opt"` +} + +func (q *sqlQuerier) PaginatedUsers(ctx context.Context, arg PaginatedUsersParams) ([]User, error) { + rows, err := q.db.QueryContext(ctx, paginatedUsers, arg.After, arg.Before, arg.LimitOpt) + if err != nil { + return nil, err + } + defer rows.Close() + var items []User + for rows.Next() { + var i User + if err := rows.Scan( + &i.ID, + &i.Email, + &i.Name, + &i.Revoked, + &i.LoginType, + &i.HashedPassword, + &i.CreatedAt, + &i.UpdatedAt, + &i.Username, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const updateUserProfile = `-- name: UpdateUserProfile :one UPDATE users diff --git a/coderd/database/queries/users.sql b/coderd/database/queries/users.sql index f41a96877d4d6..d3709445bb607 100644 --- a/coderd/database/queries/users.sql +++ b/coderd/database/queries/users.sql @@ -51,3 +51,26 @@ SET updated_at = $5 WHERE id = $1 RETURNING *; + + +-- name: PaginatedUsers :many +SELECT + * +FROM + users +WHERE + CASE + WHEN @after::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + created_at > (SELECT created_at FROM users WHERE id = @after) + WHEN @before::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + created_at < (SELECT created_at FROM users WHERE id = @before) + ELSE true + END +ORDER BY + -- TODO: When doing 'before', we need to flip this to DESC. + -- You cannot put 'ASC' or 'DESC' in a CASE statement. :'( + -- Until we figure this out, before is broken. + -- Another option is to do a subquery above + created_at ASC +LIMIT + @limit_opt; diff --git a/coderd/users.go b/coderd/users.go index e9a93d0f2a50e..e6cb2c011e3d7 100644 --- a/coderd/users.go +++ b/coderd/users.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "net/http" + "strconv" "time" "github.com/go-chi/chi/v5" @@ -145,6 +146,67 @@ func (api *api) postFirstUser(rw http.ResponseWriter, r *http.Request) { }) } +func (api *api) getPaginatedUsers(rw http.ResponseWriter, r *http.Request) { + var ( + beforeArg = r.URL.Query().Get("before") + afterArg = r.URL.Query().Get("after") + limitArg = r.URL.Query().Get("limit") + ) + + limit, err := strconv.Atoi(limitArg) + if limitArg != "" && err != nil { + httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ + Message: fmt.Sprintf("limit must be an integer: %s", err.Error()), + }) + return + } + if limit <= 0 { + // Default + limit = 10 + } + + var before uuid.UUID + var after uuid.UUID + if beforeArg != "" { + before, err = uuid.Parse(beforeArg) + if err != nil { + httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ + Message: fmt.Sprintf("before must be a uuid: %s", err.Error()), + }) + return + } + } + + if afterArg != "" { + after, err = uuid.Parse(afterArg) + if err != nil { + httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ + Message: fmt.Sprintf("after must be a uuid: %s", err.Error()), + }) + return + } + } + + var _, _ = before, after + users, err := api.Database.PaginatedUsers(r.Context(), database.PaginatedUsersParams{ + Before: before, + After: after, + LimitOpt: int32(limit), + }) + //users, err := api.Database.PaginatedUsers(r.Context(), int32(limit)) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: err.Error(), + }) + return + } + + render.Status(r, http.StatusOK) + render.JSON(rw, r, codersdk.PaginatedUsers{ + Page: convertUsers(users), + }) +} + // Creates a new user. func (api *api) postUsers(rw http.ResponseWriter, r *http.Request) { apiKey := httpmw.APIKey(r) @@ -939,3 +1001,17 @@ func convertUser(user database.User) codersdk.User { Name: user.Name, } } + +func convertUsers(users []database.User) []codersdk.User { + converted := make([]codersdk.User, 0, len(users)) + for _, u := range users { + converted = append(converted, codersdk.User{ + ID: u.ID, + Email: u.Email, + CreatedAt: u.CreatedAt, + Username: u.Username, + Name: u.Name, + }) + } + return converted +} diff --git a/coderd/users_test.go b/coderd/users_test.go index d733f022ae560..f0abea934d912 100644 --- a/coderd/users_test.go +++ b/coderd/users_test.go @@ -2,6 +2,7 @@ package coderd_test import ( "context" + "fmt" "net/http" "testing" @@ -531,3 +532,44 @@ func TestWorkspaceByUserAndName(t *testing.T) { require.NoError(t, err) }) } + +func TestPaginatedUsers(t *testing.T) { + t.Parallel() + ctx := context.Background() + client := coderdtest.New(t, nil) + coderdtest.CreateFirstUser(t, client) + me, err := client.User(context.Background(), codersdk.Me) + require.NoError(t, err) + + allUsers := make([]codersdk.User, 0) + allUsers = append(allUsers, me) + + org, err := client.CreateOrganization(ctx, me.ID, codersdk.CreateOrganizationRequest{ + Name: "default", + }) + require.NoError(t, err) + + total := 100 + // Create users + for i := 0; i < total; i++ { + newUser, err := client.CreateUser(context.Background(), codersdk.CreateUserRequest{ + Email: fmt.Sprintf("%d@coder.com", i), + Username: fmt.Sprintf("user%d", i), + Password: "password", + OrganizationID: org.ID, + }) + require.NoError(t, err) + allUsers = append(allUsers, newUser) + } + + limit := 10 + users, err := client.PaginatedUsers(ctx, codersdk.PaginatedUsersRequest{ + Limit: limit, + }) + require.NoError(t, err) + require.Equal(t, users.Page, allUsers[:limit]) + + users, err = client.PaginatedUsers(ctx, codersdk.PaginatedUsersRequest{After: users.Page[len(users.Page)-1].ID}) + require.NoError(t, err) + require.Equal(t, users.Page, allUsers[limit:limit*2]) +} diff --git a/codersdk/users.go b/codersdk/users.go index d6a920a4c4bdc..87ac69c1ea319 100644 --- a/codersdk/users.go +++ b/codersdk/users.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "net/http" + "strconv" "time" "github.com/google/uuid" @@ -13,6 +14,16 @@ import ( // Me is used as a replacement for your own ID. var Me = uuid.Nil +type PaginatedUsersRequest struct { + After uuid.UUID + Before uuid.UUID + Limit int +} + +type PaginatedUsers struct { + Page []User `json:"page"` +} + // User represents a user in Coder. type User struct { ID uuid.UUID `json:"id" validate:"required"` @@ -197,6 +208,27 @@ func (c *Client) User(ctx context.Context, id uuid.UUID) (User, error) { return user, json.NewDecoder(res.Body).Decode(&user) } +func (c *Client) PaginatedUsers(ctx context.Context, req PaginatedUsersRequest) (PaginatedUsers, error) { + res, err := c.request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/users"), nil, func(r *http.Request) { + q := r.URL.Query() + q.Set("before", req.Before.String()) + q.Set("after", req.After.String()) + q.Set("limit", strconv.Itoa(req.Limit)) + r.URL.RawQuery = q.Encode() + }) + if err != nil { + return PaginatedUsers{}, err + } + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + return PaginatedUsers{}, readBodyAsError(res) + } + + var users PaginatedUsers + return users, json.NewDecoder(res.Body).Decode(&users) +} + // OrganizationsByUser returns all organizations the user is a member of. func (c *Client) OrganizationsByUser(ctx context.Context, userID uuid.UUID) ([]Organization, error) { res, err := c.request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/users/%s/organizations", uuidOrMe(userID)), nil) From cbd7cebaa7baeaf912df8ef2d770abe2fef27164 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Mon, 18 Apr 2022 10:46:19 -0500 Subject: [PATCH 2/3] User pagination unit testing --- coderd/database/databasefake/databasefake.go | 6 +- coderd/database/querier.go | 3 +- coderd/database/queries.sql.go | 90 ++++++++++++++++++-- coderd/database/queries/users.sql | 46 ++++++++-- 4 files changed, 126 insertions(+), 19 deletions(-) diff --git a/coderd/database/databasefake/databasefake.go b/coderd/database/databasefake/databasefake.go index ce84b59e8499b..5754286a1cf41 100644 --- a/coderd/database/databasefake/databasefake.go +++ b/coderd/database/databasefake/databasefake.go @@ -1375,6 +1375,10 @@ func (q *fakeQuerier) DeleteGitSSHKey(_ context.Context, userID uuid.UUID) error return sql.ErrNoRows } -func (q *fakeQuerier) PaginatedUsers(ctx context.Context, arg database.PaginatedUsersParams) ([]database.User, error) { +func (q *fakeQuerier) PaginatedUsersAfter(ctx context.Context, arg database.PaginatedUsersAfterParams) ([]database.User, error) { + return nil, xerrors.Errorf("not implemented") +} + +func (q *fakeQuerier) PaginatedUsersBefore(ctx context.Context, arg database.PaginatedUsersBeforeParams) ([]database.User, error) { return nil, xerrors.Errorf("not implemented") } diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 92fcc0a6969a1..0970ec7de776e 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -72,7 +72,8 @@ type querier interface { InsertWorkspaceAgent(ctx context.Context, arg InsertWorkspaceAgentParams) (WorkspaceAgent, error) InsertWorkspaceBuild(ctx context.Context, arg InsertWorkspaceBuildParams) (WorkspaceBuild, error) InsertWorkspaceResource(ctx context.Context, arg InsertWorkspaceResourceParams) (WorkspaceResource, error) - PaginatedUsers(ctx context.Context, arg PaginatedUsersParams) ([]User, error) + PaginatedUsersAfter(ctx context.Context, arg PaginatedUsersAfterParams) ([]User, error) + PaginatedUsersBefore(ctx context.Context, arg PaginatedUsersBeforeParams) ([]User, error) UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error UpdateGitSSHKey(ctx context.Context, arg UpdateGitSSHKeyParams) error UpdateProvisionerDaemonByID(ctx context.Context, arg UpdateProvisionerDaemonByIDParams) error diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index f48aae6daa88e..c2b298a37c286 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -1905,7 +1905,7 @@ func (q *sqlQuerier) InsertUser(ctx context.Context, arg InsertUserParams) (User return i, err } -const paginatedUsers = `-- name: PaginatedUsers :many +const paginatedUsersAfter = `-- name: PaginatedUsersAfter :many SELECT id, email, name, revoked, login_type, hashed_password, created_at, updated_at, username FROM @@ -1914,26 +1914,98 @@ WHERE CASE WHEN $1::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN created_at > (SELECT created_at FROM users WHERE id = $1) - WHEN $2::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN - created_at < (SELECT created_at FROM users WHERE id = $2) - ELSE true + -- If the after field is not provided, just return the first page + ELSE true + END + AND + CASE + WHEN $2::text != '' THEN + email LIKE '%' || $2 || '%' + ELSE true END ORDER BY - -- TODO: When doing 'before', we need to flip this to DESC. - -- You cannot put 'ASC' or 'DESC' in a CASE statement. :'( created_at ASC LIMIT $3 ` -type PaginatedUsersParams struct { +type PaginatedUsersAfterParams struct { After uuid.UUID `db:"after" json:"after"` + Email string `db:"email" json:"email"` + LimitOpt int32 `db:"limit_opt" json:"limit_opt"` +} + +func (q *sqlQuerier) PaginatedUsersAfter(ctx context.Context, arg PaginatedUsersAfterParams) ([]User, error) { + rows, err := q.db.QueryContext(ctx, paginatedUsersAfter, arg.After, arg.Email, arg.LimitOpt) + if err != nil { + return nil, err + } + defer rows.Close() + var items []User + for rows.Next() { + var i User + if err := rows.Scan( + &i.ID, + &i.Email, + &i.Name, + &i.Revoked, + &i.LoginType, + &i.HashedPassword, + &i.CreatedAt, + &i.UpdatedAt, + &i.Username, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const paginatedUsersBefore = `-- name: PaginatedUsersBefore :many +SELECT users_before.id, users_before.email, users_before.name, users_before.revoked, users_before.login_type, users_before.hashed_password, users_before.created_at, users_before.updated_at, users_before.username FROM + (SELECT + id, email, name, revoked, login_type, hashed_password, created_at, updated_at, username + FROM + users + WHERE + CASE + WHEN $1::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + created_at < (SELECT created_at FROM users WHERE id = $1) + -- If the 'before' field is not provided, this will return the last page. + -- Kinda odd, it's just a consequence of spliting the pagination queries into 2 + -- functions. + ELSE true + END + AND + CASE + WHEN $2::text != '' THEN + email LIKE '%' || $2 || '%' + ELSE true + END + ORDER BY + created_at DESC + LIMIT + $3) AS users_before +ORDER BY users_before.created_at ASC +` + +type PaginatedUsersBeforeParams struct { Before uuid.UUID `db:"before" json:"before"` + Email string `db:"email" json:"email"` LimitOpt int32 `db:"limit_opt" json:"limit_opt"` } -func (q *sqlQuerier) PaginatedUsers(ctx context.Context, arg PaginatedUsersParams) ([]User, error) { - rows, err := q.db.QueryContext(ctx, paginatedUsers, arg.After, arg.Before, arg.LimitOpt) +// Maintain the original ordering of the rows so the pages are the same order +// as PaginatedUsersAfter. +func (q *sqlQuerier) PaginatedUsersBefore(ctx context.Context, arg PaginatedUsersBeforeParams) ([]User, error) { + rows, err := q.db.QueryContext(ctx, paginatedUsersBefore, arg.Before, arg.Email, arg.LimitOpt) if err != nil { return nil, err } diff --git a/coderd/database/queries/users.sql b/coderd/database/queries/users.sql index d3709445bb607..7cd11ccf12623 100644 --- a/coderd/database/queries/users.sql +++ b/coderd/database/queries/users.sql @@ -53,7 +53,7 @@ WHERE id = $1 RETURNING *; --- name: PaginatedUsers :many +-- name: PaginatedUsersAfter :many SELECT * FROM @@ -62,15 +62,45 @@ WHERE CASE WHEN @after::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN created_at > (SELECT created_at FROM users WHERE id = @after) - WHEN @before::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN - created_at < (SELECT created_at FROM users WHERE id = @before) - ELSE true + -- If the after field is not provided, just return the first page + ELSE true + END + AND + CASE + WHEN @email::text != '' THEN + email LIKE '%' || @email || '%' + ELSE true END ORDER BY - -- TODO: When doing 'before', we need to flip this to DESC. - -- You cannot put 'ASC' or 'DESC' in a CASE statement. :'( - -- Until we figure this out, before is broken. - -- Another option is to do a subquery above created_at ASC LIMIT @limit_opt; + +-- name: PaginatedUsersBefore :many +SELECT users_before.* FROM + (SELECT + * + FROM + users + WHERE + CASE + WHEN @before::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + created_at < (SELECT created_at FROM users WHERE id = @before) + -- If the 'before' field is not provided, this will return the last page. + -- Kinda odd, it's just a consequence of spliting the pagination queries into 2 + -- functions. + ELSE true + END + AND + CASE + WHEN @email::text != '' THEN + email LIKE '%' || @email || '%' + ELSE true + END + ORDER BY + created_at DESC + LIMIT + @limit_opt) AS users_before +-- Maintain the original ordering of the rows so the pages are the same order +-- as PaginatedUsersAfter. +ORDER BY users_before.created_at ASC; From 60e0059bbeda078aab52e97016c89e54cde29fdd Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Mon, 18 Apr 2022 14:29:00 -0500 Subject: [PATCH 3/3] Fix rate limit --- coderd/coderd.go | 6 +- coderd/coderdtest/coderdtest.go | 2 + coderd/httpmw/ratelimit.go | 6 ++ coderd/users.go | 106 +++++++++++++++++++++++--------- coderd/users_test.go | 83 +++++++++++++++++++++---- codersdk/users.go | 27 ++++++-- 6 files changed, 184 insertions(+), 46 deletions(-) diff --git a/coderd/coderd.go b/coderd/coderd.go index 536f53dfd3757..94d2ebdbbea20 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -38,6 +38,7 @@ type Options struct { SecureAuthCookie bool SSHKeygenAlgorithm gitsshkey.Algorithm + APIRateLimit int } // New constructs the Coder API into an HTTP handler. @@ -48,6 +49,9 @@ func New(options *Options) (http.Handler, func()) { if options.AgentConnectionUpdateFrequency == 0 { options.AgentConnectionUpdateFrequency = 3 * time.Second } + if options.APIRateLimit == 0 { + options.APIRateLimit = 512 + } api := &api{ Options: options, } @@ -57,7 +61,7 @@ func New(options *Options) (http.Handler, func()) { r.Use( chitrace.Middleware(), // Specific routes can specify smaller limits. - httpmw.RateLimitPerMinute(512), + httpmw.RateLimitPerMinute(options.APIRateLimit), debugLogRequest(api.Logger), ) r.Get("/", func(w http.ResponseWriter, r *http.Request) { diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index 7f30e992d23d6..f6d405bca1342 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -51,6 +51,7 @@ type Options struct { AWSInstanceIdentity awsidentity.Certificates GoogleInstanceIdentity *idtoken.Validator SSHKeygenAlgorithm gitsshkey.Algorithm + APIRateLimit int } // New constructs an in-memory coderd instance and returns @@ -117,6 +118,7 @@ func New(t *testing.T, options *Options) *codersdk.Client { AWSCertificates: options.AWSInstanceIdentity, GoogleTokenValidator: options.GoogleInstanceIdentity, SSHKeygenAlgorithm: options.SSHKeygenAlgorithm, + APIRateLimit: options.APIRateLimit, }) t.Cleanup(func() { srv.Close() diff --git a/coderd/httpmw/ratelimit.go b/coderd/httpmw/ratelimit.go index 293d71efae3e1..41a8c030f19ee 100644 --- a/coderd/httpmw/ratelimit.go +++ b/coderd/httpmw/ratelimit.go @@ -14,6 +14,12 @@ import ( // RateLimitPerMinute returns a handler that limits requests per-minute based // on IP, endpoint, and user ID (if available). func RateLimitPerMinute(count int) func(http.Handler) http.Handler { + // -1 is no rate limit + if count == -1 { + return func(handler http.Handler) http.Handler { + return handler + } + } return httprate.Limit( count, 1*time.Minute, diff --git a/coderd/users.go b/coderd/users.go index e6cb2c011e3d7..53866e31058d3 100644 --- a/coderd/users.go +++ b/coderd/users.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "net/http" + "net/url" "strconv" "time" @@ -148,9 +149,10 @@ func (api *api) postFirstUser(rw http.ResponseWriter, r *http.Request) { func (api *api) getPaginatedUsers(rw http.ResponseWriter, r *http.Request) { var ( - beforeArg = r.URL.Query().Get("before") - afterArg = r.URL.Query().Get("after") - limitArg = r.URL.Query().Get("limit") + beforeArg = r.URL.Query().Get("before") + afterArg = r.URL.Query().Get("after") + limitArg = r.URL.Query().Get("limit") + searchEmail = r.URL.Query().Get("email") ) limit, err := strconv.Atoi(limitArg) @@ -165,35 +167,57 @@ func (api *api) getPaginatedUsers(rw http.ResponseWriter, r *http.Request) { limit = 10 } - var before uuid.UUID - var after uuid.UUID - if beforeArg != "" { - before, err = uuid.Parse(beforeArg) - if err != nil { - httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ - Message: fmt.Sprintf("before must be a uuid: %s", err.Error()), - }) - return - } + if beforeArg != "" && afterArg != "" { + httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ + Message: fmt.Sprintf("cannot provide both 'after' and 'before'"), + }) + return } - if afterArg != "" { - after, err = uuid.Parse(afterArg) - if err != nil { - httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ - Message: fmt.Sprintf("after must be a uuid: %s", err.Error()), - }) - return + pagerFields := codersdk.PagerFields{ + Limit: limit, + } + + var useBefore bool + var cursor uuid.UUID + var cursorErr error + if beforeArg != "" && afterArg != "" { + httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ + Message: fmt.Sprintf("cannot provide both 'after' and 'before'"), + }) + return + } else if beforeArg != "" { + // Last is a special word to indicate the last page + if beforeArg != "last" { + cursor, cursorErr = uuid.Parse(beforeArg) } + useBefore = true + } else if afterArg != "" { + cursor, cursorErr = uuid.Parse(afterArg) + // Special keyword just incase this is the last page + pagerFields.EndingBefore = "last" + } + if cursorErr != nil { + httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ + Message: fmt.Sprintf("cursor must be a uuid: %s", cursorErr.Error()), + }) + return } - var _, _ = before, after - users, err := api.Database.PaginatedUsers(r.Context(), database.PaginatedUsersParams{ - Before: before, - After: after, - LimitOpt: int32(limit), - }) - //users, err := api.Database.PaginatedUsers(r.Context(), int32(limit)) + var users []database.User + if useBefore { + users, err = api.Database.PaginatedUsersBefore(r.Context(), database.PaginatedUsersBeforeParams{ + Before: cursor, + Email: searchEmail, + LimitOpt: int32(limit), + }) + } else { + users, err = api.Database.PaginatedUsersAfter(r.Context(), database.PaginatedUsersAfterParams{ + After: cursor, + Email: searchEmail, + LimitOpt: int32(limit), + }) + } if err != nil { httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ Message: err.Error(), @@ -201,9 +225,35 @@ func (api *api) getPaginatedUsers(rw http.ResponseWriter, r *http.Request) { return } + // If we have users, we can build the pager fields + if len(users) != 0 { + first := users[0].ID + last := users[len(users)-1].ID + vals := r.URL.Query() + vals.Del("before") + vals.Del("after") + + prev := make(url.Values) + for k, v := range vals { + prev[k] = v + } + next := vals + prev.Set("before", first.String()) + next.Set("after", last.String()) + + pagerFields = codersdk.PagerFields{ + EndingBefore: first.String(), + StartingAfter: last.String(), + NextURI: r.URL.Path + "?" + next.Encode(), + PreviousURI: r.URL.Path + "?" + prev.Encode(), + Limit: limit, + } + } + render.Status(r, http.StatusOK) render.JSON(rw, r, codersdk.PaginatedUsers{ - Page: convertUsers(users), + Pager: pagerFields, + Page: convertUsers(users), }) } diff --git a/coderd/users_test.go b/coderd/users_test.go index f0abea934d912..464edaf11e692 100644 --- a/coderd/users_test.go +++ b/coderd/users_test.go @@ -536,40 +536,101 @@ func TestWorkspaceByUserAndName(t *testing.T) { func TestPaginatedUsers(t *testing.T) { t.Parallel() ctx := context.Background() - client := coderdtest.New(t, nil) + client := coderdtest.New(t, &coderdtest.Options{APIRateLimit: -1}) coderdtest.CreateFirstUser(t, client) me, err := client.User(context.Background(), codersdk.Me) require.NoError(t, err) allUsers := make([]codersdk.User, 0) allUsers = append(allUsers, me) + gmailUsers := make([]codersdk.User, 0) org, err := client.CreateOrganization(ctx, me.ID, codersdk.CreateOrganizationRequest{ Name: "default", }) require.NoError(t, err) + // When 100 users exist total := 100 // Create users for i := 0; i < total; i++ { + email := fmt.Sprintf("%d@coder.com", i) + if i%2 == 0 { + email = fmt.Sprintf("%d@gmail.com", i) + } newUser, err := client.CreateUser(context.Background(), codersdk.CreateUserRequest{ - Email: fmt.Sprintf("%d@coder.com", i), + Email: email, Username: fmt.Sprintf("user%d", i), Password: "password", OrganizationID: org.ID, }) require.NoError(t, err) allUsers = append(allUsers, newUser) + if i%2 == 0 { + gmailUsers = append(gmailUsers, newUser) + } } - limit := 10 - users, err := client.PaginatedUsers(ctx, codersdk.PaginatedUsersRequest{ - Limit: limit, - }) - require.NoError(t, err) - require.Equal(t, users.Page, allUsers[:limit]) + assertPagination(t, ctx, client, 10, allUsers, nil) + assertPagination(t, ctx, client, 5, allUsers, nil) + assertPagination(t, ctx, client, 3, allUsers, nil) + assertPagination(t, ctx, client, 1, allUsers, nil) - users, err = client.PaginatedUsers(ctx, codersdk.PaginatedUsersRequest{After: users.Page[len(users.Page)-1].ID}) - require.NoError(t, err) - require.Equal(t, users.Page, allUsers[limit:limit*2]) + // Try a search + gmailSearch := func(request codersdk.PaginatedUsersRequest) codersdk.PaginatedUsersRequest { + request.SearchEmail = "gmail" + return request + } + assertPagination(t, ctx, client, 3, gmailUsers, gmailSearch) + assertPagination(t, ctx, client, 7, gmailUsers, gmailSearch) + assertPagination(t, ctx, client, 1, gmailUsers, gmailSearch) +} + +func assertPagination(t *testing.T, ctx context.Context, client *codersdk.Client, limit int, allUsers []codersdk.User, + opt func(request codersdk.PaginatedUsersRequest) codersdk.PaginatedUsersRequest) { + var count int + if opt == nil { + opt = func(request codersdk.PaginatedUsersRequest) codersdk.PaginatedUsersRequest { + return request + } + } + + // Check the first page + page, err := client.PaginatedUsers(ctx, opt(codersdk.PaginatedUsersRequest{ + Limit: limit, + })) + require.NoError(t, err, "first page") + require.Equal(t, page.Page, allUsers[:limit]) + require.Equal(t, page.Pager.Limit, limit, "expected limit") + count += len(page.Page) + + for { + if page.Pager.StartingAfter == "" { + break + } + // Assert each page is the next expected page + page, err = client.PaginatedUsers(ctx, opt(codersdk.PaginatedUsersRequest{ + Limit: limit, + After: page.Pager.StartingAfter, + })) + require.NoError(t, err, "next page") + + var expected []codersdk.User + if count+limit > len(allUsers) { + expected = allUsers[count:] + } else { + expected = allUsers[count : count+limit] + } + require.Equal(t, page.Page, expected, "next users") + + // Also check the before + prevPage, err := client.PaginatedUsers(ctx, opt(codersdk.PaginatedUsersRequest{ + After: "", + Before: page.Pager.EndingBefore, + Limit: limit, + })) + require.NoError(t, err, "prev page") + require.Equal(t, allUsers[count-limit:count], prevPage.Page, "prev users") + count += len(page.Page) + } } diff --git a/codersdk/users.go b/codersdk/users.go index 87ac69c1ea319..8f42992a8a740 100644 --- a/codersdk/users.go +++ b/codersdk/users.go @@ -15,13 +15,23 @@ import ( var Me = uuid.Nil type PaginatedUsersRequest struct { - After uuid.UUID - Before uuid.UUID - Limit int + After string `json:"after"` + Before string `json:"before"` + Limit int `json:"limit"` + SearchEmail string `json:"search_email"` +} + +type PagerFields struct { + EndingBefore string `json:"ending_before"` + StartingAfter string `json:"starting_after"` + NextURI string `json:"next_uri"` + PreviousURI string `json:"previous_uri"` + Limit int `json:"limit"` } type PaginatedUsers struct { - Page []User `json:"page"` + Pager PagerFields `json:"pager"` + Page []User `json:"page"` } // User represents a user in Coder. @@ -211,8 +221,13 @@ func (c *Client) User(ctx context.Context, id uuid.UUID) (User, error) { func (c *Client) PaginatedUsers(ctx context.Context, req PaginatedUsersRequest) (PaginatedUsers, error) { res, err := c.request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/users"), nil, func(r *http.Request) { q := r.URL.Query() - q.Set("before", req.Before.String()) - q.Set("after", req.After.String()) + if req.Before != "" { + q.Set("before", req.Before) + } + if req.After != "" { + q.Set("after", req.After) + } + q.Set("email", req.SearchEmail) q.Set("limit", strconv.Itoa(req.Limit)) r.URL.RawQuery = q.Encode() })