Skip to content

Commit d599753

Browse files
committed
rbac: add IsUnauthorizedError, return 404 if UnauthorizedError in organizationByUserAndName
1 parent fc992cd commit d599753

File tree

3 files changed

+50
-2
lines changed

3 files changed

+50
-2
lines changed

coderd/rbac/error.go

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
package rbac
22

3-
import "github.com/open-policy-agent/opa/rego"
3+
import (
4+
"errors"
5+
6+
"github.com/open-policy-agent/opa/rego"
7+
)
48

59
const (
610
// errUnauthorized is the error message that should be returned to
@@ -18,6 +22,12 @@ type UnauthorizedError struct {
1822
output rego.ResultSet
1923
}
2024

25+
// IsUnauthorizedError is a convenience function to check if err is UnauthorizedError.
26+
// It is equivalent to errors.As(err, &UnauthorizedError{}).
27+
func IsUnauthorizedError(err error) bool {
28+
return errors.As(err, &UnauthorizedError{})
29+
}
30+
2131
// ForbiddenWithInternal creates a new error that will return a simple
2232
// "forbidden" to the client, logging internally the more detailed message
2333
// provided.
@@ -50,3 +60,11 @@ func (e *UnauthorizedError) Input() map[string]interface{} {
5060
func (e *UnauthorizedError) Output() rego.ResultSet {
5161
return e.output
5262
}
63+
64+
// As implements the errors.As interface.
65+
func (*UnauthorizedError) As(target interface{}) bool {
66+
if _, ok := target.(*UnauthorizedError); ok {
67+
return true
68+
}
69+
return false
70+
}

coderd/rbac/error_test.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
package rbac
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/require"
7+
"golang.org/x/xerrors"
8+
)
9+
10+
func TestIsUnauthorizedError(t *testing.T) {
11+
t.Parallel()
12+
t.Run("NotWrapped", func(t *testing.T) {
13+
t.Parallel()
14+
errFunc := func() error {
15+
return UnauthorizedError{}
16+
}
17+
18+
err := errFunc()
19+
require.True(t, IsUnauthorizedError(err))
20+
})
21+
22+
t.Run("Wrapped", func(t *testing.T) {
23+
t.Parallel()
24+
errFunc := func() error {
25+
return xerrors.Errorf("test error: %w", UnauthorizedError{})
26+
}
27+
err := errFunc()
28+
require.True(t, IsUnauthorizedError(err))
29+
})
30+
}

coderd/users.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -966,7 +966,7 @@ func (api *API) organizationByUserAndName(rw http.ResponseWriter, r *http.Reques
966966
ctx := r.Context()
967967
organizationName := chi.URLParam(r, "organizationname")
968968
organization, err := api.Database.GetOrganizationByName(ctx, organizationName)
969-
if errors.Is(err, sql.ErrNoRows) {
969+
if errors.Is(err, sql.ErrNoRows) || rbac.IsUnauthorizedError(err) {
970970
httpapi.ResourceNotFound(rw)
971971
return
972972
}

0 commit comments

Comments
 (0)