Skip to content

feat: Return more 404s vs 403s #2194

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Jun 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cli/autostart_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func TestAutostart(t *testing.T) {
clitest.SetupConfig(t, client, root)

err := cmd.Execute()
require.ErrorContains(t, err, "status code 403: Forbidden", "unexpected error")
require.ErrorContains(t, err, "status code 404", "unexpected error")
})

t.Run("Disable_NotFound", func(t *testing.T) {
Expand All @@ -126,7 +126,7 @@ func TestAutostart(t *testing.T) {
clitest.SetupConfig(t, client, root)

err := cmd.Execute()
require.ErrorContains(t, err, "status code 403: Forbidden", "unexpected error")
require.ErrorContains(t, err, "status code 404:", "unexpected error")
})

t.Run("Enable_DefaultSchedule", func(t *testing.T) {
Expand Down
4 changes: 2 additions & 2 deletions cli/ttl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ func TestTTL(t *testing.T) {
clitest.SetupConfig(t, client, root)

err := cmd.Execute()
require.ErrorContains(t, err, "status code 403: Forbidden", "unexpected error")
require.ErrorContains(t, err, "status code 404:", "unexpected error")
})

t.Run("Unset_NotFound", func(t *testing.T) {
Expand All @@ -195,7 +195,7 @@ func TestTTL(t *testing.T) {
clitest.SetupConfig(t, client, root)

err := cmd.Execute()
require.ErrorContains(t, err, "status code 403: Forbidden", "unexpected error")
require.ErrorContains(t, err, "status code 404:", "unexpected error")
})

t.Run("TemplateMaxTTL", func(t *testing.T) {
Expand Down
13 changes: 9 additions & 4 deletions coderd/authorize.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (

"cdr.dev/slog"

"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/coderd/rbac"
)
Expand All @@ -17,12 +16,18 @@ func AuthorizeFilter[O rbac.Objecter](api *API, r *http.Request, action rbac.Act
return rbac.Filter(r.Context(), api.Authorizer, roles.ID.String(), roles.Roles, action, objects)
}

func (api *API) Authorize(rw http.ResponseWriter, r *http.Request, action rbac.Action, object rbac.Objecter) bool {
// Authorize will return false if the user is not authorized to do the action.
// This function will log appropriately, but the caller must return an
// error to the api client.
// Eg:
// if !api.Authorize(...) {
// httpapi.Forbidden(rw)
// return
// }
func (api *API) Authorize(r *http.Request, action rbac.Action, object rbac.Objecter) bool {
roles := httpmw.AuthorizationUserRoles(r)
err := api.Authorizer.ByRoleName(r.Context(), roles.ID.String(), roles.Roles, action, object.RBACObject())
if err != nil {
httpapi.Forbidden(rw)

// Log the errors for debugging
internalError := new(rbac.UnauthorizedError)
logger := api.Logger
Expand Down
12 changes: 8 additions & 4 deletions coderd/coderd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -380,9 +380,6 @@ func TestAuthorizeAllEndpoints(t *testing.T) {
// By default, all omitted routes check for just "authorize" called
routeAssertions = routeCheck{}
}
if routeAssertions.StatusCode == 0 {
routeAssertions.StatusCode = http.StatusForbidden
}

// Replace all url params with known values
route = strings.ReplaceAll(route, "{organization}", admin.OrganizationID.String())
Expand Down Expand Up @@ -413,7 +410,14 @@ func TestAuthorizeAllEndpoints(t *testing.T) {

if !routeAssertions.NoAuthorize {
assert.NotNil(t, authorizer.Called, "authorizer expected")
assert.Equal(t, routeAssertions.StatusCode, resp.StatusCode, "expect unauthorized")
if routeAssertions.StatusCode != 0 {
assert.Equal(t, routeAssertions.StatusCode, resp.StatusCode, "expect unauthorized")
} else {
// It's either a 404 or 403.
if resp.StatusCode != http.StatusNotFound {
assert.Equal(t, http.StatusForbidden, resp.StatusCode, "expect unauthorized")
}
}
if authorizer.Called != nil {
if routeAssertions.AssertAction != "" {
assert.Equal(t, routeAssertions.AssertAction, authorizer.Called.Action, "resource action")
Expand Down
9 changes: 6 additions & 3 deletions coderd/files.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ func (api *API) postFile(rw http.ResponseWriter, r *http.Request) {
apiKey := httpmw.APIKey(r)
// This requires the site wide action to create files.
// Once created, a user can read their own files uploaded
if !api.Authorize(rw, r, rbac.ActionCreate, rbac.ResourceFile) {
if !api.Authorize(r, rbac.ActionCreate, rbac.ResourceFile) {
httpapi.Forbidden(rw)
return
}

Expand Down Expand Up @@ -86,7 +87,7 @@ func (api *API) fileByHash(rw http.ResponseWriter, r *http.Request) {
}
file, err := api.Database.GetFileByHash(r.Context(), hash)
if errors.Is(err, sql.ErrNoRows) {
httpapi.Forbidden(rw)
httpapi.ResourceNotFound(rw)
return
}
if err != nil {
Expand All @@ -97,8 +98,10 @@ func (api *API) fileByHash(rw http.ResponseWriter, r *http.Request) {
return
}

if !api.Authorize(rw, r, rbac.ActionRead,
if !api.Authorize(r, rbac.ActionRead,
rbac.ResourceFile.WithOwner(file.CreatedBy.String()).WithID(file.Hash)) {
// Return 404 to not leak the file exists
httpapi.ResourceNotFound(rw)
return
}

Expand Down
2 changes: 1 addition & 1 deletion coderd/files_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func TestDownload(t *testing.T) {
_, _, err := client.Download(context.Background(), "something")
var apiErr *codersdk.Error
require.ErrorAs(t, err, &apiErr)
require.Equal(t, http.StatusForbidden, apiErr.StatusCode())
require.Equal(t, http.StatusNotFound, apiErr.StatusCode())
})

t.Run("Insert", func(t *testing.T) {
Expand Down
6 changes: 4 additions & 2 deletions coderd/gitsshkey.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ import (
func (api *API) regenerateGitSSHKey(rw http.ResponseWriter, r *http.Request) {
user := httpmw.UserParam(r)

if !api.Authorize(rw, r, rbac.ActionUpdate, rbac.ResourceUserData.WithOwner(user.ID.String())) {
if !api.Authorize(r, rbac.ActionUpdate, rbac.ResourceUserData.WithOwner(user.ID.String())) {
httpapi.ResourceNotFound(rw)
return
}

Expand Down Expand Up @@ -62,7 +63,8 @@ func (api *API) regenerateGitSSHKey(rw http.ResponseWriter, r *http.Request) {
func (api *API) gitSSHKey(rw http.ResponseWriter, r *http.Request) {
user := httpmw.UserParam(r)

if !api.Authorize(rw, r, rbac.ActionRead, rbac.ResourceUserData.WithOwner(user.ID.String())) {
if !api.Authorize(r, rbac.ActionRead, rbac.ResourceUserData.WithOwner(user.ID.String())) {
httpapi.ResourceNotFound(rw)
return
}

Expand Down
8 changes: 8 additions & 0 deletions coderd/httpapi/httpapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,14 @@ type Error struct {
Detail string `json:"detail" validate:"required"`
}

// ResourceNotFound is intentionally vague. All 404 responses should be identical
// to prevent leaking existence of resources.
func ResourceNotFound(rw http.ResponseWriter) {
Write(rw, http.StatusNotFound, Response{
Message: "Resource not found or you do not have access to this resource",
})
}

func Forbidden(rw http.ResponseWriter) {
Write(rw, http.StatusForbidden, Response{
Message: "Forbidden.",
Expand Down
9 changes: 2 additions & 7 deletions coderd/httpmw/organizationparam.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"database/sql"
"errors"
"fmt"
"net/http"

"github.com/coder/coder/coderd/database"
Expand Down Expand Up @@ -45,9 +44,7 @@ func ExtractOrganizationParam(db database.Store) func(http.Handler) http.Handler

organization, err := db.GetOrganizationByID(r.Context(), orgID)
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusNotFound, httpapi.Response{
Message: fmt.Sprintf("Organization %q does not exist.", orgID),
})
httpapi.ResourceNotFound(rw)
return
}
if err != nil {
Expand Down Expand Up @@ -76,9 +73,7 @@ func ExtractOrganizationMemberParam(db database.Store) func(http.Handler) http.H
UserID: user.ID,
})
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusForbidden, httpapi.Response{
Message: "Not a member of the organization.",
})
httpapi.ResourceNotFound(rw)
return
}
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion coderd/httpmw/organizationparam_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ func TestOrganizationParam(t *testing.T) {
rtr.ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusForbidden, res.StatusCode)
require.Equal(t, http.StatusNotFound, res.StatusCode)
})

t.Run("Success", func(t *testing.T) {
Expand Down
14 changes: 2 additions & 12 deletions coderd/httpmw/templateparam.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"database/sql"
"errors"
"fmt"
"net/http"

"github.com/go-chi/chi/v5"
Expand Down Expand Up @@ -33,10 +32,8 @@ func ExtractTemplateParam(db database.Store) func(http.Handler) http.Handler {
return
}
template, err := db.GetTemplateByID(r.Context(), templateID)
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusNotFound, httpapi.Response{
Message: fmt.Sprintf("Template %q does not exist.", templateID),
})
if errors.Is(err, sql.ErrNoRows) || (err == nil && template.Deleted) {
httpapi.ResourceNotFound(rw)
return
}
if err != nil {
Expand All @@ -47,13 +44,6 @@ func ExtractTemplateParam(db database.Store) func(http.Handler) http.Handler {
return
}

if template.Deleted {
httpapi.Write(rw, http.StatusNotFound, httpapi.Response{
Message: fmt.Sprintf("Template %q does not exist.", templateID),
})
return
}

ctx := context.WithValue(r.Context(), templateParamContextKey{}, template)
chi.RouteContext(ctx).URLParams.Add("organization", template.OrganizationID.String())
next.ServeHTTP(rw, r.WithContext(ctx))
Expand Down
5 changes: 1 addition & 4 deletions coderd/httpmw/templateversionparam.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"database/sql"
"errors"
"fmt"
"net/http"

"github.com/go-chi/chi/v5"
Expand Down Expand Up @@ -34,9 +33,7 @@ func ExtractTemplateVersionParam(db database.Store) func(http.Handler) http.Hand
}
templateVersion, err := db.GetTemplateVersionByID(r.Context(), templateVersionID)
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusNotFound, httpapi.Response{
Message: fmt.Sprintf("Template version %q does not exist.", templateVersionID),
})
httpapi.ResourceNotFound(rw)
return
}
if err != nil {
Expand Down
7 changes: 7 additions & 0 deletions coderd/httpmw/userparam.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ package httpmw

import (
"context"
"database/sql"
"net/http"

"golang.org/x/xerrors"

"github.com/go-chi/chi/v5"
"github.com/google/uuid"

Expand Down Expand Up @@ -47,6 +50,10 @@ func ExtractUserParam(db database.Store) func(http.Handler) http.Handler {

if userQuery == "me" {
user, err = db.GetUserByID(r.Context(), APIKey(r).UserID)
if xerrors.Is(err, sql.ErrNoRows) {
httpapi.ResourceNotFound(rw)
return
}
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: "Internal error fetching user.",
Expand Down
5 changes: 1 addition & 4 deletions coderd/httpmw/workspacebuildparam.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"database/sql"
"errors"
"fmt"
"net/http"

"github.com/go-chi/chi/v5"
Expand Down Expand Up @@ -34,9 +33,7 @@ func ExtractWorkspaceBuildParam(db database.Store) func(http.Handler) http.Handl
}
workspaceBuild, err := db.GetWorkspaceBuildByID(r.Context(), workspaceBuildID)
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusNotFound, httpapi.Response{
Message: fmt.Sprintf("Workspace build %q does not exist.", workspaceBuildID),
})
httpapi.ResourceNotFound(rw)
return
}
if err != nil {
Expand Down
5 changes: 1 addition & 4 deletions coderd/httpmw/workspaceparam.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"database/sql"
"errors"
"fmt"
"net/http"

"github.com/coder/coder/coderd/database"
Expand Down Expand Up @@ -32,9 +31,7 @@ func ExtractWorkspaceParam(db database.Store) func(http.Handler) http.Handler {
}
workspace, err := db.GetWorkspaceByID(r.Context(), workspaceID)
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusNotFound, httpapi.Response{
Message: fmt.Sprintf("Workspace %q does not exist.", workspaceID),
})
httpapi.ResourceNotFound(rw)
return
}
if err != nil {
Expand Down
6 changes: 4 additions & 2 deletions coderd/members.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,15 @@ func (api *API) putMemberRoles(rw http.ResponseWriter, r *http.Request) {
added, removed := rbac.ChangeRoleSet(member.Roles, impliedTypes)
for _, roleName := range added {
// Assigning a role requires the create permission.
if !api.Authorize(rw, r, rbac.ActionCreate, rbac.ResourceOrgRoleAssignment.WithID(roleName).InOrg(organization.ID)) {
if !api.Authorize(r, rbac.ActionCreate, rbac.ResourceOrgRoleAssignment.WithID(roleName).InOrg(organization.ID)) {
httpapi.Forbidden(rw)
return
}
}
for _, roleName := range removed {
// Removing a role requires the delete permission.
if !api.Authorize(rw, r, rbac.ActionDelete, rbac.ResourceOrgRoleAssignment.WithID(roleName).InOrg(organization.ID)) {
if !api.Authorize(r, rbac.ActionDelete, rbac.ResourceOrgRoleAssignment.WithID(roleName).InOrg(organization.ID)) {
httpapi.Forbidden(rw)
return
}
}
Expand Down
7 changes: 4 additions & 3 deletions coderd/organizations.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ import (
func (api *API) organization(rw http.ResponseWriter, r *http.Request) {
organization := httpmw.OrganizationParam(r)

if !api.Authorize(rw, r, rbac.ActionRead, rbac.ResourceOrganization.
if !api.Authorize(r, rbac.ActionRead, rbac.ResourceOrganization.
InOrg(organization.ID).
WithID(organization.ID.String())) {
httpapi.ResourceNotFound(rw)
return
}

Expand All @@ -32,8 +33,8 @@ func (api *API) postOrganizations(rw http.ResponseWriter, r *http.Request) {
apiKey := httpmw.APIKey(r)
// Create organization uses the organization resource without an OrgID.
// This means you need the site wide permission to make a new organization.
if !api.Authorize(rw, r, rbac.ActionCreate,
rbac.ResourceOrganization) {
if !api.Authorize(r, rbac.ActionCreate, rbac.ResourceOrganization) {
httpapi.Forbidden(rw)
return
}

Expand Down
4 changes: 2 additions & 2 deletions coderd/organizations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func TestOrganizationByUserAndName(t *testing.T) {
_, err := client.OrganizationByName(context.Background(), codersdk.Me, "nothing")
var apiErr *codersdk.Error
require.ErrorAs(t, err, &apiErr)
require.Equal(t, http.StatusForbidden, apiErr.StatusCode())
require.Equal(t, http.StatusNotFound, apiErr.StatusCode())
})

t.Run("NoMember", func(t *testing.T) {
Expand All @@ -45,7 +45,7 @@ func TestOrganizationByUserAndName(t *testing.T) {
_, err = other.OrganizationByName(context.Background(), codersdk.Me, org.Name)
var apiErr *codersdk.Error
require.ErrorAs(t, err, &apiErr)
require.Equal(t, http.StatusForbidden, apiErr.StatusCode())
require.Equal(t, http.StatusNotFound, apiErr.StatusCode())
})

t.Run("Valid", func(t *testing.T) {
Expand Down
Loading