From 359f978a3b7efcfbb3d2cb37f502d8e38c2cb09f Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Sat, 22 Jan 2022 20:27:38 +0000 Subject: [PATCH 1/7] feat: Add organizations endpoint for users This moves the /user endpoint to /users/me instead. This will reduce code duplication. This adds /users//organizations to list organizations a user has access to. It doesn't contain the permissions a user has over the organizations, but that will come in a future contribution. --- codecov.yml | 2 + coderd/coderd.go | 19 ++-- coderd/coderdtest/coderdtest.go | 51 ++++++---- coderd/coderdtest/coderdtest_test.go | 3 +- coderd/organizations.go | 25 +++++ coderd/users.go | 98 ++++++++++++------- coderd/users_test.go | 38 ++++++-- codersdk/users.go | 55 +++++++---- codersdk/users_test.go | 9 +- database/databasefake/databasefake.go | 63 ++++++++++++- database/querier.go | 4 + database/query.sql | 14 +++ database/query.sql.go | 129 ++++++++++++++++++++++++++ httpmw/user.go | 51 ---------- httpmw/user_test.go | 89 ------------------ httpmw/userparam.go | 53 +++++++++++ httpmw/userparam_test.go | 95 +++++++++++++++++++ peer/conn_test.go | 2 +- site/components/SignIn/SignInForm.tsx | 2 +- site/contexts/UserContext.tsx | 2 +- 20 files changed, 569 insertions(+), 235 deletions(-) create mode 100644 coderd/organizations.go delete mode 100644 httpmw/user.go delete mode 100644 httpmw/user_test.go create mode 100644 httpmw/userparam.go create mode 100644 httpmw/userparam_test.go diff --git a/codecov.yml b/codecov.yml index 472f75ef41099..faa4e2a91ec30 100644 --- a/codecov.yml +++ b/codecov.yml @@ -23,5 +23,7 @@ ignore: # This is generated code. - database/models.go - database/query.sql.go + # All coderd tests fail if this doesn't work. + - database/databasefake - peerbroker/proto - provisionersdk/proto diff --git a/coderd/coderd.go b/coderd/coderd.go index 16a69a918d683..a9de5880d23b0 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -31,15 +31,18 @@ func New(options *Options) http.Handler { Message: "👋", }) }) - r.Post("/user", users.createInitialUser) r.Post("/login", users.loginWithPassword) - // Require an API key and authenticated user for this group. - r.Group(func(r chi.Router) { - r.Use( - httpmw.ExtractAPIKey(options.Database, nil), - httpmw.ExtractUser(options.Database), - ) - r.Get("/user", users.authenticatedUser) + r.Route("/users", func(r chi.Router) { + r.Post("/", users.createInitialUser) + + r.Group(func(r chi.Router) { + r.Use( + httpmw.ExtractAPIKey(options.Database, nil), + httpmw.ExtractUserParam(options.Database), + ) + r.Get("/{user}", users.getUser) + r.Get("/{user}/organizations", users.getUserOrganizations) + }) }) }) r.NotFound(site.Handler().ServeHTTP) diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index d3e880133e2c6..969bb5dfbb3f4 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -11,6 +11,7 @@ import ( "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/coderd" "github.com/coder/coder/codersdk" + "github.com/coder/coder/cryptorand" "github.com/coder/coder/database/databasefake" ) @@ -23,7 +24,37 @@ type Server struct { URL *url.URL } -// New constructs a new coderd test instance. +// RandomInitialUser generates a random initial user and authenticates +// it with the client on the Server struct. +func (s *Server) RandomInitialUser(t *testing.T) coderd.CreateInitialUserRequest { + username, err := cryptorand.String(12) + require.NoError(t, err) + password, err := cryptorand.String(12) + require.NoError(t, err) + organization, err := cryptorand.String(12) + require.NoError(t, err) + + req := coderd.CreateInitialUserRequest{ + Email: "testuser@coder.com", + Username: username, + Password: password, + Organization: organization, + } + _, err = s.Client.CreateInitialUser(context.Background(), req) + require.NoError(t, err) + + login, err := s.Client.LoginWithPassword(context.Background(), coderd.LoginWithPasswordRequest{ + Email: "testuser@coder.com", + Password: password, + }) + require.NoError(t, err) + err = s.Client.SetSessionToken(login.SessionToken) + require.NoError(t, err) + return req +} + +// New constructs a new coderd test instance. This returned Server +// should contain no side-effects. func New(t *testing.T) Server { // This can be hotswapped for a live database instance. db := databasefake.New() @@ -36,24 +67,8 @@ func New(t *testing.T) Server { require.NoError(t, err) t.Cleanup(srv.Close) - client := codersdk.New(serverURL) - _, err = client.CreateInitialUser(context.Background(), coderd.CreateUserRequest{ - Email: "testuser@coder.com", - Username: "testuser", - Password: "testpassword", - }) - require.NoError(t, err) - - login, err := client.LoginWithPassword(context.Background(), coderd.LoginWithPasswordRequest{ - Email: "testuser@coder.com", - Password: "testpassword", - }) - require.NoError(t, err) - err = client.SetSessionToken(login.SessionToken) - require.NoError(t, err) - return Server{ - Client: client, + Client: codersdk.New(serverURL), URL: serverURL, } } diff --git a/coderd/coderdtest/coderdtest_test.go b/coderd/coderdtest/coderdtest_test.go index 1e0a4ae9c4e72..127b47994189e 100644 --- a/coderd/coderdtest/coderdtest_test.go +++ b/coderd/coderdtest/coderdtest_test.go @@ -13,5 +13,6 @@ func TestMain(m *testing.M) { } func TestNew(t *testing.T) { - _ = coderdtest.New(t) + server := coderdtest.New(t) + _ = server.RandomInitialUser(t) } diff --git a/coderd/organizations.go b/coderd/organizations.go new file mode 100644 index 0000000000000..984598baf0cae --- /dev/null +++ b/coderd/organizations.go @@ -0,0 +1,25 @@ +package coderd + +import ( + "time" + + "github.com/coder/coder/database" +) + +// Organization is the JSON representation of a Coder organization. +type Organization struct { + ID string `json:"id" validate:"required"` + Username string `json:"username" validate:"required"` + CreatedAt time.Time `json:"created_at" validate:"required"` + UpdatedAt time.Time `json:"updated_at" validate:"required"` +} + +// convertOrganization consumes the database representation and outputs API friendly. +func convertOrganization(organization database.Organization) Organization { + return Organization{ + ID: organization.ID, + Username: organization.Name, + CreatedAt: organization.CreatedAt, + UpdatedAt: organization.UpdatedAt, + } +} diff --git a/coderd/users.go b/coderd/users.go index b5f4004d70f1e..14252e6fab72a 100644 --- a/coderd/users.go +++ b/coderd/users.go @@ -1,7 +1,6 @@ package coderd import ( - "context" "crypto/sha256" "database/sql" "errors" @@ -11,6 +10,7 @@ import ( "github.com/go-chi/render" "github.com/google/uuid" + "golang.org/x/xerrors" "github.com/coder/coder/coderd/userpassword" "github.com/coder/coder/cryptorand" @@ -27,11 +27,12 @@ type User struct { Username string `json:"username" validate:"required"` } -// CreateUserRequest enables callers to create a new user. -type CreateUserRequest struct { - Email string `json:"email" validate:"required,email"` - Username string `json:"username" validate:"required,username"` - Password string `json:"password" validate:"required"` +// CreateInitialUserRequest enables callers to create a new user. +type CreateInitialUserRequest struct { + Email string `json:"email" validate:"required,email"` + Username string `json:"username" validate:"required,username"` + Password string `json:"password" validate:"required"` + Organization string `json:"organization" validate:"required,username"` } // LoginWithPasswordRequest enables callers to authenticate with email and password. @@ -51,7 +52,7 @@ type users struct { // Creates the initial user for a Coder deployment. func (users *users) createInitialUser(rw http.ResponseWriter, r *http.Request) { - var createUser CreateUserRequest + var createUser CreateInitialUserRequest if !httpapi.Read(rw, r, &createUser) { return } @@ -70,19 +71,6 @@ func (users *users) createInitialUser(rw http.ResponseWriter, r *http.Request) { }) return } - _, err = users.Database.GetUserByEmailOrUsername(r.Context(), database.GetUserByEmailOrUsernameParams{ - Email: createUser.Email, - Username: createUser.Username, - }) - if errors.Is(err, sql.ErrNoRows) { - err = nil - } - if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ - Message: fmt.Sprintf("get user: %s", err.Error()), - }) - return - } hashedPassword, err := userpassword.Hash(createUser.Password) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ @@ -91,28 +79,53 @@ func (users *users) createInitialUser(rw http.ResponseWriter, r *http.Request) { return } - user, err := users.Database.InsertUser(context.Background(), database.InsertUserParams{ - ID: uuid.NewString(), - Email: createUser.Email, - HashedPassword: []byte(hashedPassword), - Username: createUser.Username, - LoginType: database.LoginTypeBuiltIn, - CreatedAt: database.Now(), - UpdatedAt: database.Now(), + // Create the user, organization, and membership to the user. + var user database.User + err = users.Database.InTx(func(s database.Store) error { + user, err = users.Database.InsertUser(r.Context(), database.InsertUserParams{ + ID: uuid.NewString(), + Email: createUser.Email, + HashedPassword: []byte(hashedPassword), + Username: createUser.Username, + LoginType: database.LoginTypeBuiltIn, + CreatedAt: database.Now(), + UpdatedAt: database.Now(), + }) + if err != nil { + return xerrors.Errorf("create user: %w", err) + } + organization, err := users.Database.InsertOrganization(r.Context(), database.InsertOrganizationParams{ + ID: uuid.NewString(), + Name: createUser.Organization, + CreatedAt: database.Now(), + UpdatedAt: database.Now(), + }) + if err != nil { + return xerrors.Errorf("create organization: %w", err) + } + _, err = users.Database.InsertOrganizationMember(r.Context(), database.InsertOrganizationMemberParams{ + OrganizationID: organization.ID, + UserID: user.ID, + CreatedAt: database.Now(), + UpdatedAt: database.Now(), + Roles: []string{"organization-admin"}, + }) + return nil }) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ - Message: fmt.Sprintf("create user: %s", err.Error()), + Message: err.Error(), }) return } + render.Status(r, http.StatusCreated) render.JSON(rw, r, user) } // Returns the currently authenticated user. -func (*users) authenticatedUser(rw http.ResponseWriter, r *http.Request) { - user := httpmw.User(r) +func (*users) getUser(rw http.ResponseWriter, r *http.Request) { + user := httpmw.UserParam(r) render.JSON(rw, r, User{ ID: user.ID, @@ -122,6 +135,29 @@ func (*users) authenticatedUser(rw http.ResponseWriter, r *http.Request) { }) } +func (u *users) getUserOrganizations(rw http.ResponseWriter, r *http.Request) { + user := httpmw.UserParam(r) + + organizations, err := u.Database.GetOrganizationsByUserID(r.Context(), user.ID) + if errors.Is(err, sql.ErrNoRows) { + err = nil + } + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get organizations: %s", err.Error()), + }) + return + } + + publicOrganizations := make([]Organization, 0, len(organizations)) + for _, organization := range organizations { + publicOrganizations = append(publicOrganizations, convertOrganization(organization)) + } + + render.Status(r, http.StatusOK) + render.JSON(rw, r, publicOrganizations) +} + // Authenticates the user with an email and password. func (users *users) loginWithPassword(rw http.ResponseWriter, r *http.Request) { var loginWithPassword LoginWithPasswordRequest diff --git a/coderd/users_test.go b/coderd/users_test.go index 0aa8e4f023d56..cd6deda103011 100644 --- a/coderd/users_test.go +++ b/coderd/users_test.go @@ -16,6 +16,7 @@ func TestUsers(t *testing.T) { t.Run("Authenticated", func(t *testing.T) { t.Parallel() server := coderdtest.New(t) + _ = server.RandomInitialUser(t) _, err := server.Client.User(context.Background(), "") require.NoError(t, err) }) @@ -23,15 +24,27 @@ func TestUsers(t *testing.T) { t.Run("CreateMultipleInitial", func(t *testing.T) { t.Parallel() server := coderdtest.New(t) - _, err := server.Client.CreateInitialUser(context.Background(), coderd.CreateUserRequest{ - Email: "dummy@coder.com", - Username: "fake", - Password: "password", + _ = server.RandomInitialUser(t) + _, err := server.Client.CreateInitialUser(context.Background(), coderd.CreateInitialUserRequest{ + Email: "dummy@coder.com", + Organization: "bananas", + Username: "fake", + Password: "password", }) require.Error(t, err) }) - t.Run("LoginNoEmail", func(t *testing.T) { + t.Run("Login", func(t *testing.T) { + server := coderdtest.New(t) + user := server.RandomInitialUser(t) + _, err := server.Client.LoginWithPassword(context.Background(), coderd.LoginWithPasswordRequest{ + Email: user.Email, + Password: user.Password, + }) + require.NoError(t, err) + }) + + t.Run("LoginInvalidUser", func(t *testing.T) { t.Parallel() server := coderdtest.New(t) _, err := server.Client.LoginWithPassword(context.Background(), coderd.LoginWithPasswordRequest{ @@ -44,13 +57,20 @@ func TestUsers(t *testing.T) { t.Run("LoginBadPassword", func(t *testing.T) { t.Parallel() server := coderdtest.New(t) - user, err := server.Client.User(context.Background(), "") - require.NoError(t, err) - - _, err = server.Client.LoginWithPassword(context.Background(), coderd.LoginWithPasswordRequest{ + user := server.RandomInitialUser(t) + _, err := server.Client.LoginWithPassword(context.Background(), coderd.LoginWithPasswordRequest{ Email: user.Email, Password: "bananas", }) require.Error(t, err) }) + + t.Run("ListOrganizations", func(t *testing.T) { + t.Parallel() + server := coderdtest.New(t) + _ = server.RandomInitialUser(t) + orgs, err := server.Client.UserOrganizations(context.Background(), "") + require.NoError(t, err) + require.Len(t, orgs, 1) + }) } diff --git a/codersdk/users.go b/codersdk/users.go index e7caa36b1bf15..abe62107b90f6 100644 --- a/codersdk/users.go +++ b/codersdk/users.go @@ -3,6 +3,7 @@ package codersdk import ( "context" "encoding/json" + "fmt" "net/http" "github.com/coder/coder/coderd" @@ -11,8 +12,8 @@ import ( // CreateInitialUser attempts to create the first user on a Coder deployment. // This initial user has superadmin privileges. If >0 users exist, this request // will fail. -func (c *Client) CreateInitialUser(ctx context.Context, req coderd.CreateUserRequest) (coderd.User, error) { - res, err := c.request(ctx, http.MethodPost, "/api/v2/user", req) +func (c *Client) CreateInitialUser(ctx context.Context, req coderd.CreateInitialUserRequest) (coderd.User, error) { + res, err := c.request(ctx, http.MethodPost, "/api/v2/users", req) if err != nil { return coderd.User{}, err } @@ -24,21 +25,6 @@ func (c *Client) CreateInitialUser(ctx context.Context, req coderd.CreateUserReq return user, json.NewDecoder(res.Body).Decode(&user) } -// User returns a user for the ID provided. -// If the ID string is empty, the current user will be returned. -func (c *Client) User(ctx context.Context, _ string) (coderd.User, error) { - res, err := c.request(ctx, http.MethodGet, "/api/v2/user", nil) - if err != nil { - return coderd.User{}, err - } - defer res.Body.Close() - if res.StatusCode > http.StatusOK { - return coderd.User{}, readBodyAsError(res) - } - var user coderd.User - return user, json.NewDecoder(res.Body).Decode(&user) -} - // LoginWithPassword creates a session token authenticating with an email and password. // Call `SetSessionToken()` to apply the newly acquired token to the client. func (c *Client) LoginWithPassword(ctx context.Context, req coderd.LoginWithPasswordRequest) (coderd.LoginWithPasswordResponse, error) { @@ -57,3 +43,38 @@ func (c *Client) LoginWithPassword(ctx context.Context, req coderd.LoginWithPass } return resp, nil } + +// User returns a user for the ID provided. +// If the ID string is empty, the current user will be returned. +func (c *Client) User(ctx context.Context, id string) (coderd.User, error) { + if id == "" { + id = "me" + } + res, err := c.request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/users/%s", id), nil) + if err != nil { + return coderd.User{}, err + } + defer res.Body.Close() + if res.StatusCode > http.StatusOK { + return coderd.User{}, readBodyAsError(res) + } + var user coderd.User + return user, json.NewDecoder(res.Body).Decode(&user) +} + +// UserOrganizations fetches organizations a user is part of. +func (c *Client) UserOrganizations(ctx context.Context, id string) ([]coderd.Organization, error) { + if id == "" { + id = "me" + } + res, err := c.request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/users/%s/organizations", id), nil) + if err != nil { + return nil, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return nil, readBodyAsError(res) + } + var orgs []coderd.Organization + return orgs, json.NewDecoder(res.Body).Decode(&orgs) +} diff --git a/codersdk/users_test.go b/codersdk/users_test.go index a6740ff6a220a..bf03c63405de8 100644 --- a/codersdk/users_test.go +++ b/codersdk/users_test.go @@ -15,10 +15,11 @@ import ( func TestUsers(t *testing.T) { t.Run("MultipleInitial", func(t *testing.T) { server := coderdtest.New(t) - _, err := server.Client.CreateInitialUser(context.Background(), coderd.CreateUserRequest{ - Email: "wowie@coder.com", - Username: "tester", - Password: "moo", + _, err := server.Client.CreateInitialUser(context.Background(), coderd.CreateInitialUserRequest{ + Email: "wowie@coder.com", + Organization: "somethin", + Username: "tester", + Password: "moo", }) var cerr *codersdk.Error require.ErrorAs(t, err, &cerr) diff --git a/database/databasefake/databasefake.go b/database/databasefake/databasefake.go index 96b39fc99ca17..fdf292757b2bb 100644 --- a/database/databasefake/databasefake.go +++ b/database/databasefake/databasefake.go @@ -10,15 +10,19 @@ import ( // New returns an in-memory fake of the database. func New() database.Store { return &fakeQuerier{ - apiKeys: make([]database.APIKey, 0), - users: make([]database.User, 0), + apiKeys: make([]database.APIKey, 0), + organizations: make([]database.Organization, 0), + organizationMembers: make([]database.OrganizationMember, 0), + users: make([]database.User, 0), } } // fakeQuerier replicates database functionality to enable quick testing. type fakeQuerier struct { - apiKeys []database.APIKey - users []database.User + apiKeys []database.APIKey + organizations []database.Organization + organizationMembers []database.OrganizationMember + users []database.User } // InTx doesn't rollback data properly for in-memory yet. @@ -57,6 +61,34 @@ func (q *fakeQuerier) GetUserCount(_ context.Context) (int64, error) { return int64(len(q.users)), nil } +func (q *fakeQuerier) GetOrganizationByName(ctx context.Context, name string) (database.Organization, error) { + for _, organization := range q.organizations { + if organization.Name == name { + return organization, nil + } + } + return database.Organization{}, sql.ErrNoRows +} + +func (q *fakeQuerier) GetOrganizationsByUserID(ctx context.Context, userID string) ([]database.Organization, error) { + organizations := make([]database.Organization, 0) + for _, organizationMember := range q.organizationMembers { + if organizationMember.UserID != userID { + continue + } + for _, organization := range q.organizations { + if organization.ID != organizationMember.OrganizationID { + continue + } + organizations = append(organizations, organization) + } + } + if len(organizations) == 0 { + return nil, sql.ErrNoRows + } + return organizations, nil +} + func (q *fakeQuerier) InsertAPIKey(_ context.Context, arg database.InsertAPIKeyParams) (database.APIKey, error) { //nolint:gosimple key := database.APIKey{ @@ -80,6 +112,29 @@ func (q *fakeQuerier) InsertAPIKey(_ context.Context, arg database.InsertAPIKeyP return key, nil } +func (q *fakeQuerier) InsertOrganization(ctx context.Context, arg database.InsertOrganizationParams) (database.Organization, error) { + organization := database.Organization{ + ID: arg.ID, + Name: arg.Name, + CreatedAt: arg.CreatedAt, + UpdatedAt: arg.UpdatedAt, + } + q.organizations = append(q.organizations, organization) + return organization, nil +} + +func (q *fakeQuerier) InsertOrganizationMember(ctx context.Context, arg database.InsertOrganizationMemberParams) (database.OrganizationMember, error) { + organizationMember := database.OrganizationMember{ + OrganizationID: arg.OrganizationID, + UserID: arg.UserID, + CreatedAt: arg.CreatedAt, + UpdatedAt: arg.UpdatedAt, + Roles: arg.Roles, + } + q.organizationMembers = append(q.organizationMembers, organizationMember) + return organizationMember, nil +} + func (q *fakeQuerier) InsertUser(_ context.Context, arg database.InsertUserParams) (database.User, error) { user := database.User{ ID: arg.ID, diff --git a/database/querier.go b/database/querier.go index cd15099990afa..ce7658c10e718 100644 --- a/database/querier.go +++ b/database/querier.go @@ -8,10 +8,14 @@ import ( type querier interface { GetAPIKeyByID(ctx context.Context, id string) (APIKey, error) + GetOrganizationByName(ctx context.Context, name string) (Organization, error) + GetOrganizationsByUserID(ctx context.Context, userID string) ([]Organization, error) GetUserByEmailOrUsername(ctx context.Context, arg GetUserByEmailOrUsernameParams) (User, error) GetUserByID(ctx context.Context, id string) (User, error) GetUserCount(ctx context.Context) (int64, error) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (APIKey, error) + InsertOrganization(ctx context.Context, arg InsertOrganizationParams) (Organization, error) + InsertOrganizationMember(ctx context.Context, arg InsertOrganizationMemberParams) (OrganizationMember, error) InsertUser(ctx context.Context, arg InsertUserParams) (User, error) UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error } diff --git a/database/query.sql b/database/query.sql index 3382f86c5e801..5d541b07579f9 100644 --- a/database/query.sql +++ b/database/query.sql @@ -41,6 +41,14 @@ SELECT FROM users; +-- name: GetOrganizationByName :one +SELECT * FROM organizations WHERE name = $1 LIMIT 1; + +-- name: GetOrganizationsByUserID :many +SELECT * FROM organizations WHERE id = ( + SELECT organization_id FROM organization_members WHERE user_id = $1 +); + -- name: InsertAPIKey :one INSERT INTO api_keys ( @@ -79,6 +87,12 @@ VALUES $15 ) RETURNING *; +-- name: InsertOrganization :one +INSERT INTO organizations (id, name, description, created_at, updated_at) VALUES ($1, $2, $3, $4, $5) RETURNING *; + +-- name: InsertOrganizationMember :one +INSERT INTO organization_members (organization_id, user_id, created_at, updated_at, roles) VALUES ($1, $2, $3, $4, $5) RETURNING *; + -- name: InsertUser :one INSERT INTO users ( diff --git a/database/query.sql.go b/database/query.sql.go index 7d1003b8b2094..3382e8e75b69a 100644 --- a/database/query.sql.go +++ b/database/query.sql.go @@ -44,6 +44,68 @@ func (q *sqlQuerier) GetAPIKeyByID(ctx context.Context, id string) (APIKey, erro return i, err } +const getOrganizationByName = `-- name: GetOrganizationByName :one +SELECT id, name, description, created_at, updated_at, "default", auto_off_threshold, cpu_provisioning_rate, memory_provisioning_rate, workspace_auto_off FROM organizations WHERE name = $1 LIMIT 1 +` + +func (q *sqlQuerier) GetOrganizationByName(ctx context.Context, name string) (Organization, error) { + row := q.db.QueryRowContext(ctx, getOrganizationByName, name) + var i Organization + err := row.Scan( + &i.ID, + &i.Name, + &i.Description, + &i.CreatedAt, + &i.UpdatedAt, + &i.Default, + &i.AutoOffThreshold, + &i.CpuProvisioningRate, + &i.MemoryProvisioningRate, + &i.WorkspaceAutoOff, + ) + return i, err +} + +const getOrganizationsByUserID = `-- name: GetOrganizationsByUserID :many +SELECT id, name, description, created_at, updated_at, "default", auto_off_threshold, cpu_provisioning_rate, memory_provisioning_rate, workspace_auto_off FROM organizations WHERE id = ( + SELECT organization_id FROM organization_members WHERE user_id = $1 +) +` + +func (q *sqlQuerier) GetOrganizationsByUserID(ctx context.Context, userID string) ([]Organization, error) { + rows, err := q.db.QueryContext(ctx, getOrganizationsByUserID, userID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Organization + for rows.Next() { + var i Organization + if err := rows.Scan( + &i.ID, + &i.Name, + &i.Description, + &i.CreatedAt, + &i.UpdatedAt, + &i.Default, + &i.AutoOffThreshold, + &i.CpuProvisioningRate, + &i.MemoryProvisioningRate, + &i.WorkspaceAutoOff, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getUserByEmailOrUsername = `-- name: GetUserByEmailOrUsername :one SELECT id, email, name, revoked, login_type, hashed_password, created_at, updated_at, temporary_password, avatar_hash, ssh_key_regenerated_at, username, dotfiles_git_uri, roles, status, relatime, gpg_key_regenerated_at, _decomissioned, shell @@ -236,6 +298,73 @@ func (q *sqlQuerier) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) ( return i, err } +const insertOrganization = `-- name: InsertOrganization :one +INSERT INTO organizations (id, name, description, created_at, updated_at) VALUES ($1, $2, $3, $4, $5) RETURNING id, name, description, created_at, updated_at, "default", auto_off_threshold, cpu_provisioning_rate, memory_provisioning_rate, workspace_auto_off +` + +type InsertOrganizationParams struct { + ID string `db:"id" json:"id"` + Name string `db:"name" json:"name"` + Description string `db:"description" json:"description"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` +} + +func (q *sqlQuerier) InsertOrganization(ctx context.Context, arg InsertOrganizationParams) (Organization, error) { + row := q.db.QueryRowContext(ctx, insertOrganization, + arg.ID, + arg.Name, + arg.Description, + arg.CreatedAt, + arg.UpdatedAt, + ) + var i Organization + err := row.Scan( + &i.ID, + &i.Name, + &i.Description, + &i.CreatedAt, + &i.UpdatedAt, + &i.Default, + &i.AutoOffThreshold, + &i.CpuProvisioningRate, + &i.MemoryProvisioningRate, + &i.WorkspaceAutoOff, + ) + return i, err +} + +const insertOrganizationMember = `-- name: InsertOrganizationMember :one +INSERT INTO organization_members (organization_id, user_id, created_at, updated_at, roles) VALUES ($1, $2, $3, $4, $5) RETURNING organization_id, user_id, created_at, updated_at, roles +` + +type InsertOrganizationMemberParams struct { + OrganizationID string `db:"organization_id" json:"organization_id"` + UserID string `db:"user_id" json:"user_id"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + Roles []string `db:"roles" json:"roles"` +} + +func (q *sqlQuerier) InsertOrganizationMember(ctx context.Context, arg InsertOrganizationMemberParams) (OrganizationMember, error) { + row := q.db.QueryRowContext(ctx, insertOrganizationMember, + arg.OrganizationID, + arg.UserID, + arg.CreatedAt, + arg.UpdatedAt, + pq.Array(arg.Roles), + ) + var i OrganizationMember + err := row.Scan( + &i.OrganizationID, + &i.UserID, + &i.CreatedAt, + &i.UpdatedAt, + pq.Array(&i.Roles), + ) + return i, err +} + const insertUser = `-- name: InsertUser :one INSERT INTO users ( diff --git a/httpmw/user.go b/httpmw/user.go deleted file mode 100644 index 21e8442cdbdaa..0000000000000 --- a/httpmw/user.go +++ /dev/null @@ -1,51 +0,0 @@ -package httpmw - -import ( - "context" - "database/sql" - "errors" - "fmt" - "net/http" - - "github.com/coder/coder/database" - "github.com/coder/coder/httpapi" -) - -type userContextKey struct{} - -// User returns the user from the ExtractUser handler. -func User(r *http.Request) database.User { - user, ok := r.Context().Value(userContextKey{}).(database.User) - if !ok { - panic("developer error: user middleware not provided") - } - return user -} - -// ExtractUser consumes an API key and queries the user attached to it. -// It attaches the user to the request context. -func ExtractUser(db database.Store) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - // The user handler depends on API Key to get the authenticated user. - apiKey := APIKey(r) - - user, err := db.GetUserByID(r.Context(), apiKey.UserID) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ - Message: "user not found for api key", - }) - return - } - httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ - Message: fmt.Sprintf("couldn't fetch user for api key: %s", err.Error()), - }) - return - } - - ctx := context.WithValue(r.Context(), userContextKey{}, user) - next.ServeHTTP(rw, r.WithContext(ctx)) - }) - } -} diff --git a/httpmw/user_test.go b/httpmw/user_test.go deleted file mode 100644 index b060e90cc87d4..0000000000000 --- a/httpmw/user_test.go +++ /dev/null @@ -1,89 +0,0 @@ -package httpmw_test - -import ( - "crypto/sha256" - "fmt" - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/stretchr/testify/require" - - "github.com/coder/coder/database" - "github.com/coder/coder/database/databasefake" - "github.com/coder/coder/httpmw" -) - -func TestUser(t *testing.T) { - t.Run("NoUser", func(t *testing.T) { - var ( - db = databasefake.New() - id, secret = randomAPIKeyParts() - hashed = sha256.Sum256([]byte(secret)) - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() - ) - r.AddCookie(&http.Cookie{ - Name: httpmw.AuthCookie, - Value: fmt.Sprintf("%s-%s", id, secret), - }) - - _, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ - ID: id, - UserID: "bananas", - HashedSecret: hashed[:], - LastUsed: database.Now(), - ExpiresAt: database.Now().Add(time.Minute), - }) - require.NoError(t, err) - - httpmw.ExtractAPIKey(db, nil)(http.HandlerFunc(func(rw http.ResponseWriter, returnedRequest *http.Request) { - r = returnedRequest - })).ServeHTTP(rw, r) - - httpmw.ExtractUser(db)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - rw.WriteHeader(http.StatusOK) - })).ServeHTTP(rw, r) - }) - - t.Run("User", func(t *testing.T) { - var ( - db = databasefake.New() - id, secret = randomAPIKeyParts() - hashed = sha256.Sum256([]byte(secret)) - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() - ) - r.AddCookie(&http.Cookie{ - Name: httpmw.AuthCookie, - Value: fmt.Sprintf("%s-%s", id, secret), - }) - - user, err := db.InsertUser(r.Context(), database.InsertUserParams{ - ID: "testing", - CreatedAt: database.Now(), - UpdatedAt: database.Now(), - }) - require.NoError(t, err) - - _, err = db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ - ID: id, - UserID: user.ID, - HashedSecret: hashed[:], - LastUsed: database.Now(), - ExpiresAt: database.Now().Add(time.Minute), - }) - require.NoError(t, err) - - httpmw.ExtractAPIKey(db, nil)(http.HandlerFunc(func(rw http.ResponseWriter, returnedRequest *http.Request) { - r = returnedRequest - })).ServeHTTP(rw, r) - - httpmw.ExtractUser(db)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - // Makes sure the context properly adds the User! - _ = httpmw.User(r) - rw.WriteHeader(http.StatusOK) - })).ServeHTTP(rw, r) - }) -} diff --git a/httpmw/userparam.go b/httpmw/userparam.go new file mode 100644 index 0000000000000..45aec11b8eb0b --- /dev/null +++ b/httpmw/userparam.go @@ -0,0 +1,53 @@ +package httpmw + +import ( + "context" + "fmt" + "net/http" + + "github.com/coder/coder/database" + "github.com/coder/coder/httpapi" + "github.com/go-chi/chi" +) + +type userParamContextKey struct{} + +// UserParam returns the user from the ExtractUserParam handler. +func UserParam(r *http.Request) database.User { + user, ok := r.Context().Value(userParamContextKey{}).(database.User) + if !ok { + panic("developer error: user parameter middleware not provided") + } + return user +} + +// ExtractUserParam extracts a user from an ID/username in the {user} URL parameter. +func ExtractUserParam(db database.Store) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + userID := chi.URLParam(r, "user") + if userID == "" { + httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ + Message: "user id or name must be provided", + }) + return + } + if userID != "me" { + httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ + Message: "getting non-personal users isn't supported yet", + }) + return + } + apiKey := APIKey(r) + user, err := db.GetUserByID(r.Context(), apiKey.UserID) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get user: %s", err.Error()), + }) + } + + ctx := context.WithValue(r.Context(), userParamContextKey{}, user) + next.ServeHTTP(rw, r.WithContext(ctx)) + }) + } +} diff --git a/httpmw/userparam_test.go b/httpmw/userparam_test.go new file mode 100644 index 0000000000000..48f0df6e09026 --- /dev/null +++ b/httpmw/userparam_test.go @@ -0,0 +1,95 @@ +package httpmw_test + +import ( + "context" + "crypto/sha256" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/coder/coder/database" + "github.com/coder/coder/database/databasefake" + "github.com/coder/coder/httpmw" + "github.com/go-chi/chi" + "github.com/stretchr/testify/require" +) + +func TestUserParam(t *testing.T) { + setup := func(t *testing.T) (database.Store, *httptest.ResponseRecorder, *http.Request) { + var ( + db = databasefake.New() + id, secret = randomAPIKeyParts() + hashed = sha256.Sum256([]byte(secret)) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ) + r.AddCookie(&http.Cookie{ + Name: httpmw.AuthCookie, + Value: fmt.Sprintf("%s-%s", id, secret), + }) + + _, err := db.InsertUser(r.Context(), database.InsertUserParams{ + ID: "bananas", + }) + require.NoError(t, err) + _, err = db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ + ID: id, + UserID: "bananas", + HashedSecret: hashed[:], + LastUsed: database.Now(), + ExpiresAt: database.Now().Add(time.Minute), + }) + require.NoError(t, err) + return db, rw, r + } + + t.Run("None", func(t *testing.T) { + db, rw, r := setup(t) + + httpmw.ExtractAPIKey(db, nil)(http.HandlerFunc(func(rw http.ResponseWriter, returnedRequest *http.Request) { + r = returnedRequest + })).ServeHTTP(rw, r) + + httpmw.ExtractUserParam(db)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + rw.WriteHeader(http.StatusOK) + })).ServeHTTP(rw, r) + + require.Equal(t, http.StatusBadRequest, rw.Result().StatusCode) + }) + + t.Run("NotMe", func(t *testing.T) { + db, rw, r := setup(t) + + httpmw.ExtractAPIKey(db, nil)(http.HandlerFunc(func(rw http.ResponseWriter, returnedRequest *http.Request) { + r = returnedRequest + })).ServeHTTP(rw, r) + + routeContext := chi.NewRouteContext() + routeContext.URLParams.Add("user", "ben") + r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeContext)) + httpmw.ExtractUserParam(db)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + rw.WriteHeader(http.StatusOK) + })).ServeHTTP(rw, r) + + require.Equal(t, http.StatusBadRequest, rw.Result().StatusCode) + }) + + t.Run("Me", func(t *testing.T) { + db, rw, r := setup(t) + + httpmw.ExtractAPIKey(db, nil)(http.HandlerFunc(func(rw http.ResponseWriter, returnedRequest *http.Request) { + r = returnedRequest + })).ServeHTTP(rw, r) + + routeContext := chi.NewRouteContext() + routeContext.URLParams.Add("user", "me") + r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeContext)) + httpmw.ExtractUserParam(db)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + rw.WriteHeader(http.StatusOK) + })).ServeHTTP(rw, r) + + require.Equal(t, http.StatusOK, rw.Result().StatusCode) + }) +} diff --git a/peer/conn_test.go b/peer/conn_test.go index 7e64bf7747ea2..03e3eeca05363 100644 --- a/peer/conn_test.go +++ b/peer/conn_test.go @@ -23,7 +23,7 @@ import ( ) const ( - disconnectedTimeout = time.Second + disconnectedTimeout = 20 * time.Millisecond failedTimeout = disconnectedTimeout * 5 keepAliveInterval = time.Millisecond * 2 ) diff --git a/site/components/SignIn/SignInForm.tsx b/site/components/SignIn/SignInForm.tsx index b66735c5dafae..add23d2f8139e 100644 --- a/site/components/SignIn/SignInForm.tsx +++ b/site/components/SignIn/SignInForm.tsx @@ -87,7 +87,7 @@ export const SignInForm: React.FC = ({ try { await loginHandler(email, password) // Tell SWR to invalidate the cache for the user endpoint - mutate("/api/v2/user") + mutate("/api/v2/users/me") router.push("/") } catch (err) { helpers.setFieldError("password", "The username or password is incorrect.") diff --git a/site/contexts/UserContext.tsx b/site/contexts/UserContext.tsx index 544d5dd835701..3635425f45589 100644 --- a/site/contexts/UserContext.tsx +++ b/site/contexts/UserContext.tsx @@ -36,7 +36,7 @@ export const useUser = (redirectOnError = false): UserContext => { } export const UserProvider: React.FC = (props) => { - const { data, error } = useSWR("/api/v2/user") + const { data, error } = useSWR("/api/v2/users/me") return ( Date: Sat, 22 Jan 2022 20:58:12 +0000 Subject: [PATCH 2/7] Fix requested changes --- coderd/coderd.go | 4 ++-- coderd/organizations.go | 6 +++--- coderd/users.go | 13 +++++++++---- database/databasefake/databasefake.go | 9 +++++---- httpmw/userparam.go | 3 ++- httpmw/userparam_test.go | 20 ++++++++++++-------- 6 files changed, 33 insertions(+), 22 deletions(-) diff --git a/coderd/coderd.go b/coderd/coderd.go index a9de5880d23b0..4ad9463c73dcd 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -40,8 +40,8 @@ func New(options *Options) http.Handler { httpmw.ExtractAPIKey(options.Database, nil), httpmw.ExtractUserParam(options.Database), ) - r.Get("/{user}", users.getUser) - r.Get("/{user}/organizations", users.getUserOrganizations) + r.Get("/{user}", users.user) + r.Get("/{user}/organizations", users.userOrganizations) }) }) }) diff --git a/coderd/organizations.go b/coderd/organizations.go index 984598baf0cae..0f438274598b3 100644 --- a/coderd/organizations.go +++ b/coderd/organizations.go @@ -9,16 +9,16 @@ import ( // Organization is the JSON representation of a Coder organization. type Organization struct { ID string `json:"id" validate:"required"` - Username string `json:"username" validate:"required"` + Name string `json:"name" validate:"required"` CreatedAt time.Time `json:"created_at" validate:"required"` UpdatedAt time.Time `json:"updated_at" validate:"required"` } -// convertOrganization consumes the database representation and outputs API friendly. +// convertOrganization consumes the database representation and outputs an API friendly representation. func convertOrganization(organization database.Organization) Organization { return Organization{ ID: organization.ID, - Username: organization.Name, + Name: organization.Name, CreatedAt: organization.CreatedAt, UpdatedAt: organization.UpdatedAt, } diff --git a/coderd/users.go b/coderd/users.go index 14252e6fab72a..37309e0f3dcf8 100644 --- a/coderd/users.go +++ b/coderd/users.go @@ -110,6 +110,9 @@ func (users *users) createInitialUser(rw http.ResponseWriter, r *http.Request) { UpdatedAt: database.Now(), Roles: []string{"organization-admin"}, }) + if err != nil { + return xerrors.Errorf("create organization member: %w", err) + } return nil }) if err != nil { @@ -123,8 +126,9 @@ func (users *users) createInitialUser(rw http.ResponseWriter, r *http.Request) { render.JSON(rw, r, user) } -// Returns the currently authenticated user. -func (*users) getUser(rw http.ResponseWriter, r *http.Request) { +// Returns the parameterized user requested. All validation +// is completed in the middleware for this route. +func (*users) user(rw http.ResponseWriter, r *http.Request) { user := httpmw.UserParam(r) render.JSON(rw, r, User{ @@ -135,10 +139,11 @@ func (*users) getUser(rw http.ResponseWriter, r *http.Request) { }) } -func (u *users) getUserOrganizations(rw http.ResponseWriter, r *http.Request) { +// Returns organizations the parameterized user has access to. +func (users *users) userOrganizations(rw http.ResponseWriter, r *http.Request) { user := httpmw.UserParam(r) - organizations, err := u.Database.GetOrganizationsByUserID(r.Context(), user.ID) + organizations, err := users.Database.GetOrganizationsByUserID(r.Context(), user.ID) if errors.Is(err, sql.ErrNoRows) { err = nil } diff --git a/database/databasefake/databasefake.go b/database/databasefake/databasefake.go index fdf292757b2bb..07c7e11ac7dca 100644 --- a/database/databasefake/databasefake.go +++ b/database/databasefake/databasefake.go @@ -61,7 +61,7 @@ func (q *fakeQuerier) GetUserCount(_ context.Context) (int64, error) { return int64(len(q.users)), nil } -func (q *fakeQuerier) GetOrganizationByName(ctx context.Context, name string) (database.Organization, error) { +func (q *fakeQuerier) GetOrganizationByName(_ context.Context, name string) (database.Organization, error) { for _, organization := range q.organizations { if organization.Name == name { return organization, nil @@ -70,7 +70,7 @@ func (q *fakeQuerier) GetOrganizationByName(ctx context.Context, name string) (d return database.Organization{}, sql.ErrNoRows } -func (q *fakeQuerier) GetOrganizationsByUserID(ctx context.Context, userID string) ([]database.Organization, error) { +func (q *fakeQuerier) GetOrganizationsByUserID(_ context.Context, userID string) ([]database.Organization, error) { organizations := make([]database.Organization, 0) for _, organizationMember := range q.organizationMembers { if organizationMember.UserID != userID { @@ -112,7 +112,7 @@ func (q *fakeQuerier) InsertAPIKey(_ context.Context, arg database.InsertAPIKeyP return key, nil } -func (q *fakeQuerier) InsertOrganization(ctx context.Context, arg database.InsertOrganizationParams) (database.Organization, error) { +func (q *fakeQuerier) InsertOrganization(_ context.Context, arg database.InsertOrganizationParams) (database.Organization, error) { organization := database.Organization{ ID: arg.ID, Name: arg.Name, @@ -123,7 +123,8 @@ func (q *fakeQuerier) InsertOrganization(ctx context.Context, arg database.Inser return organization, nil } -func (q *fakeQuerier) InsertOrganizationMember(ctx context.Context, arg database.InsertOrganizationMemberParams) (database.OrganizationMember, error) { +func (q *fakeQuerier) InsertOrganizationMember(_ context.Context, arg database.InsertOrganizationMemberParams) (database.OrganizationMember, error) { + //nolint:gosimple organizationMember := database.OrganizationMember{ OrganizationID: arg.OrganizationID, UserID: arg.UserID, diff --git a/httpmw/userparam.go b/httpmw/userparam.go index 45aec11b8eb0b..4112f5843a222 100644 --- a/httpmw/userparam.go +++ b/httpmw/userparam.go @@ -5,9 +5,10 @@ import ( "fmt" "net/http" + "github.com/go-chi/chi" + "github.com/coder/coder/database" "github.com/coder/coder/httpapi" - "github.com/go-chi/chi" ) type userParamContextKey struct{} diff --git a/httpmw/userparam_test.go b/httpmw/userparam_test.go index 48f0df6e09026..fc091f4b005f5 100644 --- a/httpmw/userparam_test.go +++ b/httpmw/userparam_test.go @@ -9,11 +9,12 @@ import ( "testing" "time" + "github.com/go-chi/chi" + "github.com/stretchr/testify/require" + "github.com/coder/coder/database" "github.com/coder/coder/database/databasefake" "github.com/coder/coder/httpmw" - "github.com/go-chi/chi" - "github.com/stretchr/testify/require" ) func TestUserParam(t *testing.T) { @@ -55,8 +56,9 @@ func TestUserParam(t *testing.T) { httpmw.ExtractUserParam(db)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { rw.WriteHeader(http.StatusOK) })).ServeHTTP(rw, r) - - require.Equal(t, http.StatusBadRequest, rw.Result().StatusCode) + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusBadRequest, res.StatusCode) }) t.Run("NotMe", func(t *testing.T) { @@ -72,8 +74,9 @@ func TestUserParam(t *testing.T) { httpmw.ExtractUserParam(db)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { rw.WriteHeader(http.StatusOK) })).ServeHTTP(rw, r) - - require.Equal(t, http.StatusBadRequest, rw.Result().StatusCode) + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusBadRequest, res.StatusCode) }) t.Run("Me", func(t *testing.T) { @@ -89,7 +92,8 @@ func TestUserParam(t *testing.T) { httpmw.ExtractUserParam(db)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { rw.WriteHeader(http.StatusOK) })).ServeHTTP(rw, r) - - require.Equal(t, http.StatusOK, rw.Result().StatusCode) + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) }) } From 47f70a99ee237f0bbbac97672b4f30989a35e695 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Sat, 22 Jan 2022 21:59:44 +0000 Subject: [PATCH 3/7] Fix tests --- codersdk/users_test.go | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/codersdk/users_test.go b/codersdk/users_test.go index bf03c63405de8..a2c5a0876c0c3 100644 --- a/codersdk/users_test.go +++ b/codersdk/users_test.go @@ -2,18 +2,16 @@ package codersdk_test import ( "context" - "net/http" "testing" "github.com/stretchr/testify/require" "github.com/coder/coder/coderd" "github.com/coder/coder/coderd/coderdtest" - "github.com/coder/coder/codersdk" ) func TestUsers(t *testing.T) { - t.Run("MultipleInitial", func(t *testing.T) { + t.Run("CreateInitial", func(t *testing.T) { server := coderdtest.New(t) _, err := server.Client.CreateInitialUser(context.Background(), coderd.CreateInitialUserRequest{ Email: "wowie@coder.com", @@ -21,14 +19,12 @@ func TestUsers(t *testing.T) { Username: "tester", Password: "moo", }) - var cerr *codersdk.Error - require.ErrorAs(t, err, &cerr) - require.Equal(t, cerr.StatusCode(), http.StatusConflict) - require.Greater(t, len(cerr.Error()), 0) + require.NoError(t, err) }) t.Run("Get", func(t *testing.T) { server := coderdtest.New(t) + _ = server.RandomInitialUser(t) _, err := server.Client.User(context.Background(), "") require.NoError(t, err) }) From f419971276672ff3333635af64a7360ac6eee51f Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Sun, 23 Jan 2022 00:22:10 +0000 Subject: [PATCH 4/7] Fix timeout --- peer/conn_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/peer/conn_test.go b/peer/conn_test.go index 03e3eeca05363..ab6cd9044cf6a 100644 --- a/peer/conn_test.go +++ b/peer/conn_test.go @@ -23,7 +23,7 @@ import ( ) const ( - disconnectedTimeout = 20 * time.Millisecond + disconnectedTimeout = 5 * time.Second failedTimeout = disconnectedTimeout * 5 keepAliveInterval = time.Millisecond * 2 ) From 5302128e7592d4351ba692854c3184df2140a9d7 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Sun, 23 Jan 2022 00:42:46 +0000 Subject: [PATCH 5/7] Add test for UserOrgs --- codersdk/users_test.go | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/codersdk/users_test.go b/codersdk/users_test.go index a2c5a0876c0c3..8752807ec9862 100644 --- a/codersdk/users_test.go +++ b/codersdk/users_test.go @@ -22,10 +22,18 @@ func TestUsers(t *testing.T) { require.NoError(t, err) }) - t.Run("Get", func(t *testing.T) { + t.Run("User", func(t *testing.T) { server := coderdtest.New(t) _ = server.RandomInitialUser(t) _, err := server.Client.User(context.Background(), "") require.NoError(t, err) }) + + t.Run("UserOrganizations", func(t *testing.T) { + server := coderdtest.New(t) + _ = server.RandomInitialUser(t) + orgs, err := server.Client.UserOrganizations(context.Background(), "") + require.NoError(t, err) + require.Len(t, orgs, 1) + }) } From fac16a54519c7c2ec75c8ed03468eeabd0cefcf3 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Sun, 23 Jan 2022 05:11:01 +0000 Subject: [PATCH 6/7] Add test for userparam getting --- coderd/users.go | 3 --- httpmw/userparam_test.go | 1 + 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/coderd/users.go b/coderd/users.go index 37309e0f3dcf8..bac130a53f801 100644 --- a/coderd/users.go +++ b/coderd/users.go @@ -144,9 +144,6 @@ func (users *users) userOrganizations(rw http.ResponseWriter, r *http.Request) { user := httpmw.UserParam(r) organizations, err := users.Database.GetOrganizationsByUserID(r.Context(), user.ID) - if errors.Is(err, sql.ErrNoRows) { - err = nil - } if err != nil { httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ Message: fmt.Sprintf("get organizations: %s", err.Error()), diff --git a/httpmw/userparam_test.go b/httpmw/userparam_test.go index fc091f4b005f5..58204d899fdbb 100644 --- a/httpmw/userparam_test.go +++ b/httpmw/userparam_test.go @@ -90,6 +90,7 @@ func TestUserParam(t *testing.T) { routeContext.URLParams.Add("user", "me") r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeContext)) httpmw.ExtractUserParam(db)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + _ = httpmw.UserParam(r) rw.WriteHeader(http.StatusOK) })).ServeHTTP(rw, r) res := rw.Result() From 41ec8ba545a3db1cbbc856b8a99cdf8ee10a2e07 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Sun, 23 Jan 2022 05:43:49 +0000 Subject: [PATCH 7/7] Add test for NoUser --- codersdk/users_test.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/codersdk/users_test.go b/codersdk/users_test.go index 8752807ec9862..2304bc9398359 100644 --- a/codersdk/users_test.go +++ b/codersdk/users_test.go @@ -22,6 +22,12 @@ func TestUsers(t *testing.T) { require.NoError(t, err) }) + t.Run("NoUser", func(t *testing.T) { + server := coderdtest.New(t) + _, err := server.Client.User(context.Background(), "") + require.Error(t, err) + }) + t.Run("User", func(t *testing.T) { server := coderdtest.New(t) _ = server.RandomInitialUser(t)