diff --git a/cli/logout.go b/cli/logout.go index 15d57b37c0f5b..90b7f8f3d653c 100644 --- a/cli/logout.go +++ b/cli/logout.go @@ -3,6 +3,7 @@ package cli import ( "fmt" "os" + "strings" "github.com/spf13/cobra" "golang.org/x/xerrors" @@ -15,11 +16,16 @@ func logout() *cobra.Command { Use: "logout", Short: "Remove the local authenticated session", RunE: func(cmd *cobra.Command, args []string) error { - var isLoggedOut bool + client, err := createClient(cmd) + if err != nil { + return err + } + + var errors []error config := createConfig(cmd) - _, err := cliui.Prompt(cmd, cliui.PromptOptions{ + _, err = cliui.Prompt(cmd, cliui.PromptOptions{ Text: "Are you sure you want to logout?", IsConfirm: true, Default: "yes", @@ -28,38 +34,40 @@ func logout() *cobra.Command { return err } - err = config.URL().Delete() + err = client.Logout(cmd.Context()) if err != nil { - // Only throw error if the URL configuration file is present, - // otherwise the user is already logged out, and we proceed - if !os.IsNotExist(err) { - return xerrors.Errorf("remove URL file: %w", err) - } - isLoggedOut = true + errors = append(errors, xerrors.Errorf("logout api: %w", err)) + } + + err = config.URL().Delete() + // Only throw error if the URL configuration file is present, + // otherwise the user is already logged out, and we proceed + if err != nil && !os.IsNotExist(err) { + errors = append(errors, xerrors.Errorf("remove URL file: %w", err)) } err = config.Session().Delete() - if err != nil { - // Only throw error if the session configuration file is present, - // otherwise the user is already logged out, and we proceed - if !os.IsNotExist(err) { - return xerrors.Errorf("remove session file: %w", err) - } - isLoggedOut = true + // Only throw error if the session configuration file is present, + // otherwise the user is already logged out, and we proceed + if err != nil && !os.IsNotExist(err) { + errors = append(errors, xerrors.Errorf("remove session file: %w", err)) } err = config.Organization().Delete() // If the organization configuration file is absent, we still proceed if err != nil && !os.IsNotExist(err) { - return xerrors.Errorf("remove organization file: %w", err) + errors = append(errors, xerrors.Errorf("remove organization file: %w", err)) } - // If the user was already logged out, we show them a different message - if isLoggedOut { - _, _ = fmt.Fprintf(cmd.OutOrStdout(), notLoggedInMessage+"\n") - } else { - _, _ = fmt.Fprintf(cmd.OutOrStdout(), caret+"Successfully logged out.\n") + if len(errors) > 0 { + var errorStringBuilder strings.Builder + for _, err := range errors { + _, _ = fmt.Fprint(&errorStringBuilder, "\t"+err.Error()+"\n") + } + errorString := strings.TrimRight(errorStringBuilder.String(), "\n") + return xerrors.New("Failed to log out.\n" + errorString) } + _, _ = fmt.Fprintf(cmd.OutOrStdout(), caret+"You are no longer logged in. You can log in using 'coder login '.\n") return nil }, } diff --git a/cli/logout_test.go b/cli/logout_test.go index bae15ef204ca4..753dcba234bfe 100644 --- a/cli/logout_test.go +++ b/cli/logout_test.go @@ -1,7 +1,10 @@ package cli_test import ( + "fmt" "os" + "regexp" + "runtime" "testing" "github.com/stretchr/testify/assert" @@ -21,7 +24,7 @@ func TestLogout(t *testing.T) { pty := ptytest.New(t) config := login(t, pty) - // ensure session files exist + // Ensure session files exist. require.FileExists(t, string(config.URL())) require.FileExists(t, string(config.Session())) @@ -40,7 +43,7 @@ func TestLogout(t *testing.T) { pty.ExpectMatch("Are you sure you want to logout?") pty.WriteLine("yes") - pty.ExpectMatch("Successfully logged out") + pty.ExpectMatch("You are no longer logged in. You can log in using 'coder login '.") <-logoutChan }) t.Run("SkipPrompt", func(t *testing.T) { @@ -49,7 +52,7 @@ func TestLogout(t *testing.T) { pty := ptytest.New(t) config := login(t, pty) - // ensure session files exist + // Ensure session files exist. require.FileExists(t, string(config.URL())) require.FileExists(t, string(config.Session())) @@ -66,7 +69,7 @@ func TestLogout(t *testing.T) { assert.NoFileExists(t, string(config.Session())) }() - pty.ExpectMatch("Successfully logged out") + pty.ExpectMatch("You are no longer logged in. You can log in using 'coder login '.") <-logoutChan }) t.Run("NoURLFile", func(t *testing.T) { @@ -75,7 +78,7 @@ func TestLogout(t *testing.T) { pty := ptytest.New(t) config := login(t, pty) - // ensure session files exist + // Ensure session files exist. require.FileExists(t, string(config.URL())) require.FileExists(t, string(config.Session())) @@ -91,14 +94,9 @@ func TestLogout(t *testing.T) { go func() { defer close(logoutChan) err := logout.Execute() - assert.NoError(t, err) - assert.NoFileExists(t, string(config.URL())) - assert.NoFileExists(t, string(config.Session())) + assert.EqualError(t, err, "You are not logged in. Try logging in using 'coder login '.") }() - pty.ExpectMatch("Are you sure you want to logout?") - pty.WriteLine("yes") - pty.ExpectMatch("You are not logged in. Try logging in using 'coder login '.") <-logoutChan }) t.Run("NoSessionFile", func(t *testing.T) { @@ -107,7 +105,7 @@ func TestLogout(t *testing.T) { pty := ptytest.New(t) config := login(t, pty) - // ensure session files exist + // Ensure session files exist. require.FileExists(t, string(config.URL())) require.FileExists(t, string(config.Session())) @@ -123,14 +121,73 @@ func TestLogout(t *testing.T) { go func() { defer close(logoutChan) err = logout.Execute() - assert.NoError(t, err) - assert.NoFileExists(t, string(config.URL())) - assert.NoFileExists(t, string(config.Session())) + assert.EqualError(t, err, "You are not logged in. Try logging in using 'coder login '.") + }() + + <-logoutChan + }) + t.Run("CannotDeleteFiles", func(t *testing.T) { + t.Parallel() + + pty := ptytest.New(t) + config := login(t, pty) + + // Ensure session files exist. + require.FileExists(t, string(config.URL())) + require.FileExists(t, string(config.Session())) + + var ( + err error + urlFile *os.File + sessionFile *os.File + ) + if runtime.GOOS == "windows" { + // Opening the files so Windows does not allow deleting them. + urlFile, err = os.Open(string(config.URL())) + require.NoError(t, err) + sessionFile, err = os.Open(string(config.Session())) + require.NoError(t, err) + } else { + // Changing the permissions to throw error during deletion. + err = os.Chmod(string(config), 0500) + require.NoError(t, err) + } + t.Cleanup(func() { + if runtime.GOOS == "windows" { + // Closing the opened files for cleanup. + err = urlFile.Close() + require.NoError(t, err) + err = sessionFile.Close() + require.NoError(t, err) + } else { + // Setting the permissions back for cleanup. + err = os.Chmod(string(config), 0700) + require.NoError(t, err) + } + }) + + logoutChan := make(chan struct{}) + logout, _ := clitest.New(t, "logout", "--global-config", string(config)) + + logout.SetIn(pty.Input()) + logout.SetOut(pty.Output()) + + go func() { + defer close(logoutChan) + err := logout.Execute() + assert.NotNil(t, err) + var errorMessage string + if runtime.GOOS == "windows" { + errorMessage = "The process cannot access the file because it is being used by another process." + } else { + errorMessage = "permission denied" + } + errRegex := regexp.MustCompile(fmt.Sprintf("Failed to log out.\n\tremove URL file: .+: %s\n\tremove session file: .+: %s", errorMessage, errorMessage)) + assert.Regexp(t, errRegex, err.Error()) }() pty.ExpectMatch("Are you sure you want to logout?") pty.WriteLine("yes") - pty.ExpectMatch("You are not logged in. Try logging in using 'coder login '.") <-logoutChan }) } diff --git a/coderd/coderd.go b/coderd/coderd.go index c8c292a8a35b1..4513b2c86360a 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -219,7 +219,6 @@ func New(options *Options) *API { r.Get("/first", api.firstUser) r.Post("/first", api.postFirstUser) r.Post("/login", api.postLogin) - r.Post("/logout", api.postLogout) r.Get("/authmethods", api.userAuthMethods) r.Route("/oauth2", func(r chi.Router) { r.Route("/github", func(r chi.Router) { @@ -234,6 +233,7 @@ func New(options *Options) *API { ) r.Post("/", api.postUser) r.Get("/", api.users) + r.Post("/logout", api.postLogout) // These routes query information about site wide roles. r.Route("/roles", func(r chi.Router) { r.Get("/", api.assignableSiteRoles) diff --git a/coderd/coderd_test.go b/coderd/coderd_test.go index 4ca374ef7ed6a..d0aabb7a74758 100644 --- a/coderd/coderd_test.go +++ b/coderd/coderd_test.go @@ -104,6 +104,10 @@ func TestAuthorizeAllEndpoints(t *testing.T) { workspaceRBACObj := rbac.ResourceWorkspace.InOrg(organization.ID).WithID(workspace.ID.String()).WithOwner(workspace.OwnerID.String()) // skipRoutes allows skipping routes from being checked. + skipRoutes := map[string]string{ + "POST:/api/v2/users/logout": "Logging out deletes the API Key for other routes", + } + type routeCheck struct { NoAuthorize bool AssertAction rbac.Action @@ -117,7 +121,6 @@ func TestAuthorizeAllEndpoints(t *testing.T) { "GET:/api/v2/users/first": {NoAuthorize: true}, "POST:/api/v2/users/first": {NoAuthorize: true}, "POST:/api/v2/users/login": {NoAuthorize: true}, - "POST:/api/v2/users/logout": {NoAuthorize: true}, "GET:/api/v2/users/authmethods": {NoAuthorize: true}, "POST:/api/v2/csp/reports": {NoAuthorize: true}, @@ -310,8 +313,20 @@ func TestAuthorizeAllEndpoints(t *testing.T) { assertRoute[noTrailSlash] = v } + for k, v := range skipRoutes { + noTrailSlash := strings.TrimRight(k, "/") + if _, ok := skipRoutes[noTrailSlash]; ok && noTrailSlash != k { + t.Errorf("route %q & %q is declared twice", noTrailSlash, k) + t.FailNow() + } + skipRoutes[noTrailSlash] = v + } + err = chi.Walk(api.Handler, func(method string, route string, handler http.Handler, middlewares ...func(http.Handler) http.Handler) error { name := method + ":" + route + if _, ok := skipRoutes[strings.TrimRight(name, "/")]; ok { + return nil + } t.Run(name, func(t *testing.T) { authorizer.reset() routeAssertions, ok := assertRoute[strings.TrimRight(name, "/")] diff --git a/coderd/database/databasefake/databasefake.go b/coderd/database/databasefake/databasefake.go index 6acc5d156bdac..8d9edc6f02523 100644 --- a/coderd/database/databasefake/databasefake.go +++ b/coderd/database/databasefake/databasefake.go @@ -126,6 +126,21 @@ func (q *fakeQuerier) GetAPIKeyByID(_ context.Context, id string) (database.APIK return database.APIKey{}, sql.ErrNoRows } +func (q *fakeQuerier) DeleteAPIKeyByID(_ context.Context, id string) error { + q.mutex.Lock() + defer q.mutex.Unlock() + + for index, apiKey := range q.apiKeys { + if apiKey.ID != id { + continue + } + q.apiKeys[index] = q.apiKeys[len(q.apiKeys)-1] + q.apiKeys = q.apiKeys[:len(q.apiKeys)-1] + return nil + } + return sql.ErrNoRows +} + func (q *fakeQuerier) GetFileByHash(_ context.Context, hash string) (database.File, error) { q.mutex.RLock() defer q.mutex.RUnlock() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index db061cfcb36a1..7d4c361cea05b 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -18,6 +18,7 @@ type querier interface { // multiple provisioners from acquiring the same jobs. See: // https://www.postgresql.org/docs/9.5/sql-select.html#SQL-FOR-UPDATE-SHARE AcquireProvisionerJob(ctx context.Context, arg AcquireProvisionerJobParams) (ProvisionerJob, error) + DeleteAPIKeyByID(ctx context.Context, id string) error DeleteGitSSHKey(ctx context.Context, userID uuid.UUID) error DeleteParameterValueByID(ctx context.Context, id uuid.UUID) error GetAPIKeyByID(ctx context.Context, id string) (APIKey, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 137fe36b4456b..8724fbbc8e474 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -15,6 +15,19 @@ import ( "github.com/tabbed/pqtype" ) +const deleteAPIKeyByID = `-- name: DeleteAPIKeyByID :exec +DELETE +FROM + api_keys +WHERE + id = $1 +` + +func (q *sqlQuerier) DeleteAPIKeyByID(ctx context.Context, id string) error { + _, err := q.db.ExecContext(ctx, deleteAPIKeyByID, id) + return err +} + const getAPIKeyByID = `-- name: GetAPIKeyByID :one SELECT id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, oauth_access_token, oauth_refresh_token, oauth_id_token, oauth_expiry diff --git a/coderd/database/queries/apikeys.sql b/coderd/database/queries/apikeys.sql index 1af2016f491bf..38ac145ce465e 100644 --- a/coderd/database/queries/apikeys.sql +++ b/coderd/database/queries/apikeys.sql @@ -38,3 +38,10 @@ SET oauth_expiry = $6 WHERE id = $1; + +-- name: DeleteAPIKeyByID :exec +DELETE +FROM + api_keys +WHERE + id = $1; diff --git a/coderd/users.go b/coderd/users.go index 02d46655d4f54..f4b87e6a98ebc 100644 --- a/coderd/users.go +++ b/coderd/users.go @@ -85,7 +85,7 @@ func (api *API) postFirstUser(rw http.ResponseWriter, r *http.Request) { // TODO: @emyrk this currently happens outside the database tx used to create // the user. Maybe I add this ability to grant roles in the createUser api // and add some rbac bypass when calling api functions this way?? - // Add the admin role to this first user + // Add the admin role to this first user. _, err = api.Database.UpdateUserRoles(r.Context(), database.UpdateUserRolesParams{ GrantedRoles: []string{rbac.RoleAdmin(), rbac.RoleMember()}, ID: user.ID, @@ -109,7 +109,7 @@ func (api *API) users(rw http.ResponseWriter, r *http.Request) { statusFilter = r.URL.Query().Get("status") ) - // Reading all users across the site + // Reading all users across the site. if !api.Authorize(rw, r, rbac.ActionRead, rbac.ResourceUser) { return } @@ -162,7 +162,7 @@ func (api *API) users(rw http.ResponseWriter, r *http.Request) { // Creates a new user. func (api *API) postUser(rw http.ResponseWriter, r *http.Request) { - // Create the user on the site + // Create the user on the site. if !api.Authorize(rw, r, rbac.ActionCreate, rbac.ResourceUser) { return } @@ -408,11 +408,11 @@ func (api *API) userRoles(rw http.ResponseWriter, r *http.Request) { return } - // Only include ones we can read from RBAC + // Only include ones we can read from RBAC. memberships = AuthorizeFilter(api, r, rbac.ActionRead, memberships) for _, mem := range memberships { - // If we can read the org member, include the roles + // If we can read the org member, include the roles. if err == nil { resp.OrganizationRoles[mem.OrganizationID] = mem.Roles } @@ -422,7 +422,7 @@ func (api *API) userRoles(rw http.ResponseWriter, r *http.Request) { } func (api *API) putUserRoles(rw http.ResponseWriter, r *http.Request) { - // User is the user to modify + // User is the user to modify. user := httpmw.UserParam(r) roles := httpmw.UserRoles(r) @@ -470,7 +470,7 @@ func (api *API) putUserRoles(rw http.ResponseWriter, r *http.Request) { // updateSiteUserRoles will ensure only site wide roles are passed in as arguments. // If an organization role is included, an error is returned. func (api *API) updateSiteUserRoles(ctx context.Context, args database.UpdateUserRolesParams) (database.User, error) { - // Enforce only site wide roles + // Enforce only site wide roles. for _, r := range args.GrantedRoles { if _, ok := rbac.IsOrgRole(r); ok { return database.User{}, xerrors.Errorf("must only update site wide roles") @@ -504,7 +504,7 @@ func (api *API) organizationsByUser(rw http.ResponseWriter, r *http.Request) { return } - // Only return orgs the user can read + // Only return orgs the user can read. organizations = AuthorizeFilter(api, r, rbac.ActionRead, organizations) publicOrganizations := make([]codersdk.Organization, 0, len(organizations)) @@ -585,7 +585,7 @@ func (api *API) postLogin(rw http.ResponseWriter, r *http.Request) { }) } -// Creates a new session key, used for logging in via the CLI +// Creates a new session key, used for logging in via the CLI. func (api *API) postAPIKey(rw http.ResponseWriter, r *http.Request) { user := httpmw.UserParam(r) @@ -604,17 +604,28 @@ func (api *API) postAPIKey(rw http.ResponseWriter, r *http.Request) { httpapi.Write(rw, http.StatusCreated, codersdk.GenerateAPIKeyResponse{Key: sessionToken}) } -// Clear the user's session cookie -func (*API) postLogout(rw http.ResponseWriter, _ *http.Request) { - // Get a blank token cookie +// Clear the user's session cookie. +func (api *API) postLogout(rw http.ResponseWriter, r *http.Request) { + // Get a blank token cookie. cookie := &http.Cookie{ - // MaxAge < 0 means to delete the cookie now + // MaxAge < 0 means to delete the cookie now. MaxAge: -1, Name: httpmw.SessionTokenKey, Path: "/", } http.SetCookie(rw, cookie) + + // Delete the session token from database. + apiKey := httpmw.APIKey(r) + err := api.Database.DeleteAPIKeyByID(r.Context(), apiKey.ID) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("delete api key: %s", err.Error()), + }) + return + } + httpapi.Write(rw, http.StatusOK, httpapi.Response{ Message: "Logged out!", }) @@ -696,7 +707,7 @@ func (api *API) createUser(ctx context.Context, req codersdk.CreateUserRequest) req.OrganizationID = organization.ID orgRoles = append(orgRoles, rbac.RoleOrgAdmin(req.OrganizationID)) } - // Always also be a member + // Always also be a member. orgRoles = append(orgRoles, rbac.RoleOrgMember(req.OrganizationID)) params := database.InsertUserParams{ @@ -742,7 +753,7 @@ func (api *API) createUser(ctx context.Context, req codersdk.CreateUserRequest) UserID: user.ID, CreatedAt: database.Now(), UpdatedAt: database.Now(), - // By default give them membership to the organization + // By default give them membership to the organization. Roles: orgRoles, }) if err != nil { diff --git a/coderd/users_test.go b/coderd/users_test.go index d8d05542df092..0c31d5c9259aa 100644 --- a/coderd/users_test.go +++ b/coderd/users_test.go @@ -2,15 +2,18 @@ package coderd_test import ( "context" + "database/sql" "fmt" "net/http" "sort" + "strings" "testing" "github.com/google/uuid" "github.com/stretchr/testify/require" "github.com/coder/coder/coderd/coderdtest" + "github.com/coder/coder/coderd/database/databasefake" "github.com/coder/coder/coderd/httpmw" "github.com/coder/coder/coderd/rbac" "github.com/coder/coder/codersdk" @@ -103,27 +106,71 @@ func TestPostLogin(t *testing.T) { func TestPostLogout(t *testing.T) { t.Parallel() - t.Run("ClearCookie", func(t *testing.T) { + // Checks that the cookie is cleared and the API Key is deleted from the database. + t.Run("Logout", func(t *testing.T) { t.Parallel() - client := coderdtest.New(t, nil) + ctx := context.Background() + client, api := coderdtest.NewWithAPI(t, nil) + _ = coderdtest.CreateFirstUser(t, client) + keyID := strings.Split(client.SessionToken, "-")[0] + + apiKey, err := api.Database.GetAPIKeyByID(ctx, keyID) + require.NoError(t, err) + require.Equal(t, keyID, apiKey.ID, "API key should exist in the database") + fullURL, err := client.URL.Parse("/api/v2/users/logout") require.NoError(t, err, "Server URL should parse successfully") - req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, fullURL.String(), nil) - require.NoError(t, err, "/logout request construction should succeed") + res, err := client.Request(ctx, http.MethodPost, fullURL.String(), nil) + require.NoError(t, err, "/logout request should succeed") + res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + + cookies := res.Cookies() + require.Len(t, cookies, 1, "Exactly one cookie should be returned") - httpClient := &http.Client{} + require.Equal(t, httpmw.SessionTokenKey, cookies[0].Name, "Cookie should be the auth cookie") + require.Equal(t, -1, cookies[0].MaxAge, "Cookie should be set to delete") - response, err := httpClient.Do(req) + apiKey, err = api.Database.GetAPIKeyByID(ctx, keyID) + require.ErrorIs(t, err, sql.ErrNoRows, "API key should not exist in the database") + }) + + t.Run("LogoutWithoutKey", func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + client, api := coderdtest.NewWithAPI(t, nil) + _ = coderdtest.CreateFirstUser(t, client) + keyID := strings.Split(client.SessionToken, "-")[0] + + apiKey, err := api.Database.GetAPIKeyByID(ctx, keyID) + require.NoError(t, err) + require.Equal(t, keyID, apiKey.ID, "API key should exist in the database") + + // Setting a fake database without the API Key to be used by the API. + // The middleware that extracts the API key is already set to read + // from the original database. + dbWithoutKey := databasefake.New() + api.Database = dbWithoutKey + + fullURL, err := client.URL.Parse("/api/v2/users/logout") + require.NoError(t, err, "Server URL should parse successfully") + + res, err := client.Request(ctx, http.MethodPost, fullURL.String(), nil) require.NoError(t, err, "/logout request should succeed") - response.Body.Close() + res.Body.Close() + require.Equal(t, http.StatusInternalServerError, res.StatusCode) - cookies := response.Cookies() + cookies := res.Cookies() require.Len(t, cookies, 1, "Exactly one cookie should be returned") - require.Equal(t, cookies[0].Name, httpmw.SessionTokenKey, "Cookie should be the auth cookie") - require.Equal(t, cookies[0].MaxAge, -1, "Cookie should be set to delete") + require.Equal(t, httpmw.SessionTokenKey, cookies[0].Name, "Cookie should be the auth cookie") + require.Equal(t, -1, cookies[0].MaxAge, "Cookie should be set to delete") + + apiKey, err = api.Database.GetAPIKeyByID(ctx, keyID) + require.ErrorIs(t, err, sql.ErrNoRows, "API key should not exist in the database") }) }