diff --git a/coderd/coderd.go b/coderd/coderd.go index 83bb6ec78e4ca..94d2ebdbbea20 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -38,6 +38,7 @@ type Options struct { SecureAuthCookie bool SSHKeygenAlgorithm gitsshkey.Algorithm + APIRateLimit int } // New constructs the Coder API into an HTTP handler. @@ -48,6 +49,9 @@ func New(options *Options) (http.Handler, func()) { if options.AgentConnectionUpdateFrequency == 0 { options.AgentConnectionUpdateFrequency = 3 * time.Second } + if options.APIRateLimit == 0 { + options.APIRateLimit = 512 + } api := &api{ Options: options, } @@ -57,7 +61,7 @@ func New(options *Options) (http.Handler, func()) { r.Use( chitrace.Middleware(), // Specific routes can specify smaller limits. - httpmw.RateLimitPerMinute(512), + httpmw.RateLimitPerMinute(options.APIRateLimit), debugLogRequest(api.Logger), ) r.Get("/", func(w http.ResponseWriter, r *http.Request) { @@ -140,6 +144,7 @@ func New(options *Options) (http.Handler, func()) { }) }) r.Route("/users", func(r chi.Router) { + r.Get("/", api.getPaginatedUsers) r.Get("/first", api.firstUser) r.Post("/first", api.postFirstUser) r.Post("/login", api.postLogin) diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index 7f30e992d23d6..f6d405bca1342 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -51,6 +51,7 @@ type Options struct { AWSInstanceIdentity awsidentity.Certificates GoogleInstanceIdentity *idtoken.Validator SSHKeygenAlgorithm gitsshkey.Algorithm + APIRateLimit int } // New constructs an in-memory coderd instance and returns @@ -117,6 +118,7 @@ func New(t *testing.T, options *Options) *codersdk.Client { AWSCertificates: options.AWSInstanceIdentity, GoogleTokenValidator: options.GoogleInstanceIdentity, SSHKeygenAlgorithm: options.SSHKeygenAlgorithm, + APIRateLimit: options.APIRateLimit, }) t.Cleanup(func() { srv.Close() diff --git a/coderd/database/databasefake/databasefake.go b/coderd/database/databasefake/databasefake.go index 36b1d822dbcda..5754286a1cf41 100644 --- a/coderd/database/databasefake/databasefake.go +++ b/coderd/database/databasefake/databasefake.go @@ -3,6 +3,7 @@ package databasefake import ( "context" "database/sql" + "golang.org/x/xerrors" "strings" "sync" @@ -1373,3 +1374,11 @@ func (q *fakeQuerier) DeleteGitSSHKey(_ context.Context, userID uuid.UUID) error } return sql.ErrNoRows } + +func (q *fakeQuerier) PaginatedUsersAfter(ctx context.Context, arg database.PaginatedUsersAfterParams) ([]database.User, error) { + return nil, xerrors.Errorf("not implemented") +} + +func (q *fakeQuerier) PaginatedUsersBefore(ctx context.Context, arg database.PaginatedUsersBeforeParams) ([]database.User, error) { + return nil, xerrors.Errorf("not implemented") +} diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 073113a2451df..0970ec7de776e 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -72,6 +72,8 @@ type querier interface { InsertWorkspaceAgent(ctx context.Context, arg InsertWorkspaceAgentParams) (WorkspaceAgent, error) InsertWorkspaceBuild(ctx context.Context, arg InsertWorkspaceBuildParams) (WorkspaceBuild, error) InsertWorkspaceResource(ctx context.Context, arg InsertWorkspaceResourceParams) (WorkspaceResource, error) + PaginatedUsersAfter(ctx context.Context, arg PaginatedUsersAfterParams) ([]User, error) + PaginatedUsersBefore(ctx context.Context, arg PaginatedUsersBeforeParams) ([]User, error) UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error UpdateGitSSHKey(ctx context.Context, arg UpdateGitSSHKeyParams) error UpdateProvisionerDaemonByID(ctx context.Context, arg UpdateProvisionerDaemonByIDParams) error diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index e816b90c57a17..c2b298a37c286 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -1905,6 +1905,138 @@ func (q *sqlQuerier) InsertUser(ctx context.Context, arg InsertUserParams) (User return i, err } +const paginatedUsersAfter = `-- name: PaginatedUsersAfter :many +SELECT + id, email, name, revoked, login_type, hashed_password, created_at, updated_at, username +FROM + users +WHERE + CASE + WHEN $1::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + created_at > (SELECT created_at FROM users WHERE id = $1) + -- If the after field is not provided, just return the first page + ELSE true + END + AND + CASE + WHEN $2::text != '' THEN + email LIKE '%' || $2 || '%' + ELSE true + END +ORDER BY + created_at ASC +LIMIT + $3 +` + +type PaginatedUsersAfterParams struct { + After uuid.UUID `db:"after" json:"after"` + Email string `db:"email" json:"email"` + LimitOpt int32 `db:"limit_opt" json:"limit_opt"` +} + +func (q *sqlQuerier) PaginatedUsersAfter(ctx context.Context, arg PaginatedUsersAfterParams) ([]User, error) { + rows, err := q.db.QueryContext(ctx, paginatedUsersAfter, arg.After, arg.Email, arg.LimitOpt) + if err != nil { + return nil, err + } + defer rows.Close() + var items []User + for rows.Next() { + var i User + if err := rows.Scan( + &i.ID, + &i.Email, + &i.Name, + &i.Revoked, + &i.LoginType, + &i.HashedPassword, + &i.CreatedAt, + &i.UpdatedAt, + &i.Username, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const paginatedUsersBefore = `-- name: PaginatedUsersBefore :many +SELECT users_before.id, users_before.email, users_before.name, users_before.revoked, users_before.login_type, users_before.hashed_password, users_before.created_at, users_before.updated_at, users_before.username FROM + (SELECT + id, email, name, revoked, login_type, hashed_password, created_at, updated_at, username + FROM + users + WHERE + CASE + WHEN $1::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + created_at < (SELECT created_at FROM users WHERE id = $1) + -- If the 'before' field is not provided, this will return the last page. + -- Kinda odd, it's just a consequence of spliting the pagination queries into 2 + -- functions. + ELSE true + END + AND + CASE + WHEN $2::text != '' THEN + email LIKE '%' || $2 || '%' + ELSE true + END + ORDER BY + created_at DESC + LIMIT + $3) AS users_before +ORDER BY users_before.created_at ASC +` + +type PaginatedUsersBeforeParams struct { + Before uuid.UUID `db:"before" json:"before"` + Email string `db:"email" json:"email"` + LimitOpt int32 `db:"limit_opt" json:"limit_opt"` +} + +// Maintain the original ordering of the rows so the pages are the same order +// as PaginatedUsersAfter. +func (q *sqlQuerier) PaginatedUsersBefore(ctx context.Context, arg PaginatedUsersBeforeParams) ([]User, error) { + rows, err := q.db.QueryContext(ctx, paginatedUsersBefore, arg.Before, arg.Email, arg.LimitOpt) + if err != nil { + return nil, err + } + defer rows.Close() + var items []User + for rows.Next() { + var i User + if err := rows.Scan( + &i.ID, + &i.Email, + &i.Name, + &i.Revoked, + &i.LoginType, + &i.HashedPassword, + &i.CreatedAt, + &i.UpdatedAt, + &i.Username, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const updateUserProfile = `-- name: UpdateUserProfile :one UPDATE users diff --git a/coderd/database/queries/users.sql b/coderd/database/queries/users.sql index f41a96877d4d6..7cd11ccf12623 100644 --- a/coderd/database/queries/users.sql +++ b/coderd/database/queries/users.sql @@ -51,3 +51,56 @@ SET updated_at = $5 WHERE id = $1 RETURNING *; + + +-- name: PaginatedUsersAfter :many +SELECT + * +FROM + users +WHERE + CASE + WHEN @after::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + created_at > (SELECT created_at FROM users WHERE id = @after) + -- If the after field is not provided, just return the first page + ELSE true + END + AND + CASE + WHEN @email::text != '' THEN + email LIKE '%' || @email || '%' + ELSE true + END +ORDER BY + created_at ASC +LIMIT + @limit_opt; + +-- name: PaginatedUsersBefore :many +SELECT users_before.* FROM + (SELECT + * + FROM + users + WHERE + CASE + WHEN @before::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + created_at < (SELECT created_at FROM users WHERE id = @before) + -- If the 'before' field is not provided, this will return the last page. + -- Kinda odd, it's just a consequence of spliting the pagination queries into 2 + -- functions. + ELSE true + END + AND + CASE + WHEN @email::text != '' THEN + email LIKE '%' || @email || '%' + ELSE true + END + ORDER BY + created_at DESC + LIMIT + @limit_opt) AS users_before +-- Maintain the original ordering of the rows so the pages are the same order +-- as PaginatedUsersAfter. +ORDER BY users_before.created_at ASC; diff --git a/coderd/httpmw/ratelimit.go b/coderd/httpmw/ratelimit.go index 293d71efae3e1..41a8c030f19ee 100644 --- a/coderd/httpmw/ratelimit.go +++ b/coderd/httpmw/ratelimit.go @@ -14,6 +14,12 @@ import ( // RateLimitPerMinute returns a handler that limits requests per-minute based // on IP, endpoint, and user ID (if available). func RateLimitPerMinute(count int) func(http.Handler) http.Handler { + // -1 is no rate limit + if count == -1 { + return func(handler http.Handler) http.Handler { + return handler + } + } return httprate.Limit( count, 1*time.Minute, diff --git a/coderd/users.go b/coderd/users.go index e9a93d0f2a50e..53866e31058d3 100644 --- a/coderd/users.go +++ b/coderd/users.go @@ -7,6 +7,8 @@ import ( "errors" "fmt" "net/http" + "net/url" + "strconv" "time" "github.com/go-chi/chi/v5" @@ -145,6 +147,116 @@ func (api *api) postFirstUser(rw http.ResponseWriter, r *http.Request) { }) } +func (api *api) getPaginatedUsers(rw http.ResponseWriter, r *http.Request) { + var ( + beforeArg = r.URL.Query().Get("before") + afterArg = r.URL.Query().Get("after") + limitArg = r.URL.Query().Get("limit") + searchEmail = r.URL.Query().Get("email") + ) + + limit, err := strconv.Atoi(limitArg) + if limitArg != "" && err != nil { + httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ + Message: fmt.Sprintf("limit must be an integer: %s", err.Error()), + }) + return + } + if limit <= 0 { + // Default + limit = 10 + } + + if beforeArg != "" && afterArg != "" { + httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ + Message: fmt.Sprintf("cannot provide both 'after' and 'before'"), + }) + return + } + + pagerFields := codersdk.PagerFields{ + Limit: limit, + } + + var useBefore bool + var cursor uuid.UUID + var cursorErr error + if beforeArg != "" && afterArg != "" { + httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ + Message: fmt.Sprintf("cannot provide both 'after' and 'before'"), + }) + return + } else if beforeArg != "" { + // Last is a special word to indicate the last page + if beforeArg != "last" { + cursor, cursorErr = uuid.Parse(beforeArg) + } + useBefore = true + } else if afterArg != "" { + cursor, cursorErr = uuid.Parse(afterArg) + // Special keyword just incase this is the last page + pagerFields.EndingBefore = "last" + } + if cursorErr != nil { + httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ + Message: fmt.Sprintf("cursor must be a uuid: %s", cursorErr.Error()), + }) + return + } + + var users []database.User + if useBefore { + users, err = api.Database.PaginatedUsersBefore(r.Context(), database.PaginatedUsersBeforeParams{ + Before: cursor, + Email: searchEmail, + LimitOpt: int32(limit), + }) + } else { + users, err = api.Database.PaginatedUsersAfter(r.Context(), database.PaginatedUsersAfterParams{ + After: cursor, + Email: searchEmail, + LimitOpt: int32(limit), + }) + } + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: err.Error(), + }) + return + } + + // If we have users, we can build the pager fields + if len(users) != 0 { + first := users[0].ID + last := users[len(users)-1].ID + vals := r.URL.Query() + vals.Del("before") + vals.Del("after") + + prev := make(url.Values) + for k, v := range vals { + prev[k] = v + } + next := vals + prev.Set("before", first.String()) + next.Set("after", last.String()) + + pagerFields = codersdk.PagerFields{ + EndingBefore: first.String(), + StartingAfter: last.String(), + NextURI: r.URL.Path + "?" + next.Encode(), + PreviousURI: r.URL.Path + "?" + prev.Encode(), + Limit: limit, + } + } + + render.Status(r, http.StatusOK) + render.JSON(rw, r, codersdk.PaginatedUsers{ + Pager: pagerFields, + Page: convertUsers(users), + }) +} + // Creates a new user. func (api *api) postUsers(rw http.ResponseWriter, r *http.Request) { apiKey := httpmw.APIKey(r) @@ -939,3 +1051,17 @@ func convertUser(user database.User) codersdk.User { Name: user.Name, } } + +func convertUsers(users []database.User) []codersdk.User { + converted := make([]codersdk.User, 0, len(users)) + for _, u := range users { + converted = append(converted, codersdk.User{ + ID: u.ID, + Email: u.Email, + CreatedAt: u.CreatedAt, + Username: u.Username, + Name: u.Name, + }) + } + return converted +} diff --git a/coderd/users_test.go b/coderd/users_test.go index d733f022ae560..464edaf11e692 100644 --- a/coderd/users_test.go +++ b/coderd/users_test.go @@ -2,6 +2,7 @@ package coderd_test import ( "context" + "fmt" "net/http" "testing" @@ -531,3 +532,105 @@ func TestWorkspaceByUserAndName(t *testing.T) { require.NoError(t, err) }) } + +func TestPaginatedUsers(t *testing.T) { + t.Parallel() + ctx := context.Background() + client := coderdtest.New(t, &coderdtest.Options{APIRateLimit: -1}) + coderdtest.CreateFirstUser(t, client) + me, err := client.User(context.Background(), codersdk.Me) + require.NoError(t, err) + + allUsers := make([]codersdk.User, 0) + allUsers = append(allUsers, me) + gmailUsers := make([]codersdk.User, 0) + + org, err := client.CreateOrganization(ctx, me.ID, codersdk.CreateOrganizationRequest{ + Name: "default", + }) + require.NoError(t, err) + + // When 100 users exist + total := 100 + // Create users + for i := 0; i < total; i++ { + email := fmt.Sprintf("%d@coder.com", i) + if i%2 == 0 { + email = fmt.Sprintf("%d@gmail.com", i) + } + newUser, err := client.CreateUser(context.Background(), codersdk.CreateUserRequest{ + Email: email, + Username: fmt.Sprintf("user%d", i), + Password: "password", + OrganizationID: org.ID, + }) + require.NoError(t, err) + allUsers = append(allUsers, newUser) + if i%2 == 0 { + gmailUsers = append(gmailUsers, newUser) + } + } + + assertPagination(t, ctx, client, 10, allUsers, nil) + assertPagination(t, ctx, client, 5, allUsers, nil) + assertPagination(t, ctx, client, 3, allUsers, nil) + assertPagination(t, ctx, client, 1, allUsers, nil) + + // Try a search + gmailSearch := func(request codersdk.PaginatedUsersRequest) codersdk.PaginatedUsersRequest { + request.SearchEmail = "gmail" + return request + } + assertPagination(t, ctx, client, 3, gmailUsers, gmailSearch) + assertPagination(t, ctx, client, 7, gmailUsers, gmailSearch) + assertPagination(t, ctx, client, 1, gmailUsers, gmailSearch) +} + +func assertPagination(t *testing.T, ctx context.Context, client *codersdk.Client, limit int, allUsers []codersdk.User, + opt func(request codersdk.PaginatedUsersRequest) codersdk.PaginatedUsersRequest) { + var count int + if opt == nil { + opt = func(request codersdk.PaginatedUsersRequest) codersdk.PaginatedUsersRequest { + return request + } + } + + // Check the first page + page, err := client.PaginatedUsers(ctx, opt(codersdk.PaginatedUsersRequest{ + Limit: limit, + })) + require.NoError(t, err, "first page") + require.Equal(t, page.Page, allUsers[:limit]) + require.Equal(t, page.Pager.Limit, limit, "expected limit") + count += len(page.Page) + + for { + if page.Pager.StartingAfter == "" { + break + } + // Assert each page is the next expected page + page, err = client.PaginatedUsers(ctx, opt(codersdk.PaginatedUsersRequest{ + Limit: limit, + After: page.Pager.StartingAfter, + })) + require.NoError(t, err, "next page") + + var expected []codersdk.User + if count+limit > len(allUsers) { + expected = allUsers[count:] + } else { + expected = allUsers[count : count+limit] + } + require.Equal(t, page.Page, expected, "next users") + + // Also check the before + prevPage, err := client.PaginatedUsers(ctx, opt(codersdk.PaginatedUsersRequest{ + After: "", + Before: page.Pager.EndingBefore, + Limit: limit, + })) + require.NoError(t, err, "prev page") + require.Equal(t, allUsers[count-limit:count], prevPage.Page, "prev users") + count += len(page.Page) + } +} diff --git a/codersdk/users.go b/codersdk/users.go index d6a920a4c4bdc..8f42992a8a740 100644 --- a/codersdk/users.go +++ b/codersdk/users.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "net/http" + "strconv" "time" "github.com/google/uuid" @@ -13,6 +14,26 @@ import ( // Me is used as a replacement for your own ID. var Me = uuid.Nil +type PaginatedUsersRequest struct { + After string `json:"after"` + Before string `json:"before"` + Limit int `json:"limit"` + SearchEmail string `json:"search_email"` +} + +type PagerFields struct { + EndingBefore string `json:"ending_before"` + StartingAfter string `json:"starting_after"` + NextURI string `json:"next_uri"` + PreviousURI string `json:"previous_uri"` + Limit int `json:"limit"` +} + +type PaginatedUsers struct { + Pager PagerFields `json:"pager"` + Page []User `json:"page"` +} + // User represents a user in Coder. type User struct { ID uuid.UUID `json:"id" validate:"required"` @@ -197,6 +218,32 @@ func (c *Client) User(ctx context.Context, id uuid.UUID) (User, error) { return user, json.NewDecoder(res.Body).Decode(&user) } +func (c *Client) PaginatedUsers(ctx context.Context, req PaginatedUsersRequest) (PaginatedUsers, error) { + res, err := c.request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/users"), nil, func(r *http.Request) { + q := r.URL.Query() + if req.Before != "" { + q.Set("before", req.Before) + } + if req.After != "" { + q.Set("after", req.After) + } + q.Set("email", req.SearchEmail) + q.Set("limit", strconv.Itoa(req.Limit)) + r.URL.RawQuery = q.Encode() + }) + if err != nil { + return PaginatedUsers{}, err + } + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + return PaginatedUsers{}, readBodyAsError(res) + } + + var users PaginatedUsers + return users, json.NewDecoder(res.Body).Decode(&users) +} + // OrganizationsByUser returns all organizations the user is a member of. func (c *Client) OrganizationsByUser(ctx context.Context, userID uuid.UUID) ([]Organization, error) { res, err := c.request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/users/%s/organizations", uuidOrMe(userID)), nil)