Skip to content

Commit 54bc054

Browse files
committed
feat: Implement basic authorize and unit test
1 parent 2161f84 commit 54bc054

File tree

9 files changed

+229
-83
lines changed

9 files changed

+229
-83
lines changed

coderd/coderd.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ func New(options *Options) (http.Handler, func()) {
8181

8282
authRolesMiddleware := httpmw.ExtractUserRoles(options.Database)
8383

84-
authorize := func(f http.HandlerFunc, actions rbac.Action) func(http.Handler) http.Handler {
85-
return httpmw.Authorize(api.Logger, api.Authorizer, actions)
84+
authorize := func(f http.HandlerFunc, actions rbac.Action) http.HandlerFunc {
85+
return httpmw.Authorize(api.Logger, api.Authorizer, actions)(f).ServeHTTP
8686
}
8787

8888
r := chi.NewRouter()
@@ -141,7 +141,7 @@ func New(options *Options) (http.Handler, func()) {
141141
})
142142
r.Route("/members", func(r chi.Router) {
143143
r.Route("/roles", func(r chi.Router) {
144-
r.Use(httpmw.Object(rbac.ResourceUserRole))
144+
r.Use(httpmw.RBACObject(rbac.ResourceUserRole))
145145
r.Get("/", authorize(api.assignableOrgRoles, rbac.ActionCreate))
146146
})
147147
r.Route("/{user}", func(r chi.Router) {
@@ -214,8 +214,8 @@ func New(options *Options) (http.Handler, func()) {
214214
r.Get("/", api.users)
215215
// These routes query information about site wide roles.
216216
r.Route("/roles", func(r chi.Router) {
217-
// Can create/delete all roles to view this endpoint
218-
r.With(httpmw.Object(rbac.ResourceUserRole), authorize(rbac.ActionCreate, rbac.ActionDelete)).Get("/", api.assignableSiteRoles)
217+
r.Use(httpmw.RBACObject(rbac.ResourceUserRole))
218+
r.Get("/", authorize(api.assignableSiteRoles, rbac.ActionCreate))
219219
})
220220
r.Route("/{user}", func(r chi.Router) {
221221
r.Use(httpmw.ExtractUserParam(options.Database))

coderd/database/databasefake/databasefake.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,37 @@ func (q *fakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams
242242
return tmp, nil
243243
}
244244

245+
func (q *fakeQuerier) GetAllUserRoles(ctx context.Context, userID uuid.UUID) (database.GetAllUserRolesRow, error) {
246+
q.mutex.RLock()
247+
defer q.mutex.RUnlock()
248+
249+
var user *database.User
250+
roles := make([]string, 0)
251+
for _, u := range q.users {
252+
if u.ID == userID {
253+
roles = append(roles, u.RBACRoles...)
254+
user = &u
255+
break
256+
}
257+
}
258+
259+
for _, mem := range q.organizationMembers {
260+
if mem.UserID == userID {
261+
roles = append(roles, mem.Roles...)
262+
}
263+
}
264+
265+
if user == nil {
266+
return database.GetAllUserRolesRow{}, sql.ErrNoRows
267+
}
268+
269+
return database.GetAllUserRolesRow{
270+
ID: userID,
271+
Username: user.Username,
272+
Roles: roles,
273+
}, nil
274+
}
275+
245276
func (q *fakeQuerier) GetWorkspacesByTemplateID(_ context.Context, arg database.GetWorkspacesByTemplateIDParams) ([]database.Workspace, error) {
246277
q.mutex.RLock()
247278
defer q.mutex.RUnlock()

coderd/database/querier.go

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries.sql.go

Lines changed: 7 additions & 22 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries/users.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ WHERE
124124
id = $1 RETURNING *;
125125

126126

127-
-- name: GetAllUserRoles :many
127+
-- name: GetAllUserRoles :one
128128
SELECT
129129
-- username is returned just to help for logging purposes
130130
id, username, array_cat(users.rbac_roles, organization_members.roles) :: text[] AS roles

coderd/httpmw/authorize.go

Lines changed: 16 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@ import (
55
"fmt"
66
"net/http"
77

8-
"github.com/google/uuid"
9-
108
"golang.org/x/xerrors"
119

1210
"cdr.dev/slog"
@@ -15,48 +13,13 @@ import (
1513
"github.com/coder/coder/coderd/rbac"
1614
)
1715

16+
// AuthObject wraps the rbac object type for middleware to customize this value
17+
// before being passed to Authorize().
1818
type AuthObject struct {
19-
// WithUser sets the owner of the object to the value returned by the func
20-
WithUser func(r *http.Request) uuid.UUID
21-
22-
// InOrg sets the org owner of the object to the value returned by the func
23-
InOrg func(r *http.Request) uuid.UUID
24-
25-
// WithOwner sets the object id to the value returned by the func
26-
WithOwner func(r *http.Request) uuid.UUID
27-
2819
// Object is that base static object the above functions can modify.
2920
Object rbac.Object
3021
}
3122

32-
func RBACWithOwner(owner func(r *http.Request) database.User) func(http.Handler) http.Handler {
33-
return func(next http.Handler) http.Handler {
34-
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
35-
ao := GetAuthObject(r)
36-
ao.WithOwner = func(r *http.Request) uuid.UUID {
37-
return owner(r).ID
38-
}
39-
40-
ctx := context.WithValue(r.Context(), authObjectKey{}, ao)
41-
next.ServeHTTP(rw, r.WithContext(ctx))
42-
})
43-
}
44-
}
45-
46-
func RBACInOrg(org func(r *http.Request) database.Organization) func(http.Handler) http.Handler {
47-
return func(next http.Handler) http.Handler {
48-
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
49-
ao := GetAuthObject(r)
50-
ao.InOrg = func(r *http.Request) uuid.UUID {
51-
return org(r).ID
52-
}
53-
54-
ctx := context.WithValue(r.Context(), authObjectKey{}, ao)
55-
next.ServeHTTP(rw, r.WithContext(ctx))
56-
})
57-
}
58-
}
59-
6023
// Authorize allows for static object & action authorize checking. If the object is a static object, this is an easy way
6124
// to enforce the route.
6225
func Authorize(logger slog.Logger, auth *rbac.RegoAuthorizer, action rbac.Action) func(http.Handler) http.Handler {
@@ -66,25 +29,27 @@ func Authorize(logger slog.Logger, auth *rbac.RegoAuthorizer, action rbac.Action
6629
args := GetAuthObject(r)
6730

6831
object := args.Object
69-
organization, ok := r.Context().Value(organizationParamContextKey{}).(database.Organization)
70-
if ok {
32+
33+
unknownOrg := r.Context().Value(organizationParamContextKey{})
34+
if organization, castOK := unknownOrg.(database.Organization); unknownOrg != nil {
35+
if !castOK {
36+
panic("developer error: organization param middleware not provided for authorize")
37+
}
7138
object = object.InOrg(organization.ID)
7239
}
7340

74-
if args.InOrg != nil {
75-
object.InOrg(args.InOrg(r))
76-
}
77-
if args.WithUser != nil {
78-
object.WithOwner(args.InOrg(r).String())
79-
}
80-
if args.WithOwner != nil {
81-
object.WithID(args.InOrg(r).String())
41+
unknownOwner := r.Context().Value(userParamContextKey{})
42+
if owner, castOK := unknownOwner.(database.User); unknownOwner != nil {
43+
if !castOK {
44+
panic("developer error: user param middleware not provided for authorize")
45+
}
46+
object = object.WithOwner(owner.ID.String())
8247
}
8348

8449
// Error on the first action that fails
8550
err := auth.AuthorizeByRoleName(r.Context(), roles.ID.String(), roles.Roles, action, object)
8651
if err != nil {
87-
var internalError *rbac.UnauthorizedError
52+
internalError := new(rbac.UnauthorizedError)
8853
if xerrors.As(err, internalError) {
8954
logger = logger.With(slog.F("internal", internalError.Internal()))
9055
}
@@ -117,7 +82,7 @@ func GetAuthObject(r *http.Request) AuthObject {
11782
return obj
11883
}
11984

120-
func Object(object rbac.Object) func(http.Handler) http.Handler {
85+
func RBACObject(object rbac.Object) func(http.Handler) http.Handler {
12186
return func(next http.Handler) http.Handler {
12287
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
12388
ao := GetAuthObject(r)

coderd/httpmw/organizationparam.go

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,6 @@ func ExtractOrganizationParam(db database.Store) func(http.Handler) http.Handler
7878
ctx := context.WithValue(r.Context(), organizationParamContextKey{}, organization)
7979
ctx = context.WithValue(ctx, organizationMemberParamContextKey{}, organizationMember)
8080

81-
next = RBACInOrg(func(r *http.Request) database.Organization {
82-
return organization
83-
})(next)
8481
next.ServeHTTP(rw, r.WithContext(ctx))
8582
})
8683
}

coderd/roles_test.go

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
package coderd_test
2+
3+
import (
4+
"context"
5+
"net/http"
6+
"testing"
7+
8+
"github.com/coder/coder/coderd/rbac"
9+
10+
"github.com/google/uuid"
11+
12+
"github.com/coder/coder/codersdk"
13+
14+
"github.com/stretchr/testify/require"
15+
16+
"github.com/coder/coder/coderd/coderdtest"
17+
)
18+
19+
func TestListRoles(t *testing.T) {
20+
t.Parallel()
21+
22+
requireUnauthorized := func(t *testing.T, err error) {
23+
var apiErr *codersdk.Error
24+
require.ErrorAs(t, err, &apiErr)
25+
require.Equal(t, http.StatusUnauthorized, apiErr.StatusCode())
26+
require.Contains(t, apiErr.Message, "unauthorized")
27+
}
28+
29+
ctx := context.Background()
30+
client := coderdtest.New(t, nil)
31+
// Create admin, member, and org admin
32+
admin := coderdtest.CreateFirstUser(t, client)
33+
member := coderdtest.CreateAnotherUser(t, client, admin.OrganizationID)
34+
35+
orgAdmin := coderdtest.CreateAnotherUser(t, client, admin.OrganizationID)
36+
orgAdminUser, err := orgAdmin.User(ctx, codersdk.Me)
37+
require.NoError(t, err)
38+
39+
// TODO: @emyrk switch this to the admin when getting non-personal users is
40+
// supported. `client.UpdateOrganizationMemberRoles(...)`
41+
_, err = orgAdmin.UpdateOrganizationMemberRoles(ctx, admin.OrganizationID, orgAdminUser.ID,
42+
codersdk.UpdateRoles{
43+
Roles: []string{rbac.RoleOrgMember(admin.OrganizationID), rbac.RoleOrgAdmin(admin.OrganizationID)},
44+
},
45+
)
46+
require.NoError(t, err)
47+
48+
testCases := []struct {
49+
Name string
50+
Client *codersdk.Client
51+
APICall func() ([]string, error)
52+
ExpectedRoles []string
53+
Authorized bool
54+
}{
55+
{
56+
Name: "MemberListSite",
57+
APICall: func() ([]string, error) {
58+
x, err := member.ListSiteRoles(ctx)
59+
return x, err
60+
},
61+
Authorized: false,
62+
},
63+
{
64+
Name: "OrgMemberListOrg",
65+
APICall: func() ([]string, error) {
66+
return member.ListOrgRoles(ctx, admin.OrganizationID)
67+
},
68+
Authorized: false,
69+
},
70+
{
71+
Name: "NonOrgMemberListOrg",
72+
APICall: func() ([]string, error) {
73+
return member.ListOrgRoles(ctx, uuid.New())
74+
},
75+
Authorized: false,
76+
},
77+
// Org admin
78+
{
79+
Name: "OrgAdminListSite",
80+
APICall: func() ([]string, error) {
81+
return orgAdmin.ListSiteRoles(ctx)
82+
},
83+
Authorized: false,
84+
},
85+
{
86+
Name: "OrgAdminListOrg",
87+
APICall: func() ([]string, error) {
88+
return orgAdmin.ListOrgRoles(ctx, admin.OrganizationID)
89+
},
90+
Authorized: true,
91+
ExpectedRoles: rbac.ListOrgRoles(admin.OrganizationID),
92+
},
93+
{
94+
Name: "OrgAdminListOtherOrg",
95+
APICall: func() ([]string, error) {
96+
return orgAdmin.ListOrgRoles(ctx, uuid.New())
97+
},
98+
Authorized: false,
99+
},
100+
// Admin
101+
{
102+
Name: "AdminListSite",
103+
APICall: func() ([]string, error) {
104+
return client.ListSiteRoles(ctx)
105+
},
106+
Authorized: true,
107+
ExpectedRoles: rbac.ListSiteRoles(),
108+
},
109+
{
110+
Name: "AdminListOrg",
111+
APICall: func() ([]string, error) {
112+
return client.ListOrgRoles(ctx, admin.OrganizationID)
113+
},
114+
Authorized: true,
115+
ExpectedRoles: rbac.ListOrgRoles(admin.OrganizationID),
116+
},
117+
}
118+
119+
for _, c := range testCases {
120+
c := c
121+
t.Run(c.Name, func(t *testing.T) {
122+
t.Parallel()
123+
roles, err := c.APICall()
124+
if !c.Authorized {
125+
requireUnauthorized(t, err)
126+
} else {
127+
require.NoError(t, err)
128+
require.Equal(t, c.ExpectedRoles, roles)
129+
}
130+
})
131+
}
132+
}

0 commit comments

Comments
 (0)