Skip to content

Commit 0e933f0

Browse files
authored
chore: refactor user -> rbac.subject into a function (#13624)
* chore: refactor user subject logic to be in 1 place * test: implement test to assert deleted custom roles are omitted * add unit test for deleted role
1 parent 3ef12ac commit 0e933f0

File tree

7 files changed

+320
-86
lines changed

7 files changed

+320
-86
lines changed

coderd/httpmw/apikey.go

+27-25
Original file line numberDiff line numberDiff line change
@@ -406,16 +406,15 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
406406
// If the key is valid, we also fetch the user roles and status.
407407
// The roles are used for RBAC authorize checks, and the status
408408
// is to block 'suspended' users from accessing the platform.
409-
//nolint:gocritic // system needs to update user roles
410-
roles, err := cfg.DB.GetAuthorizationUserRoles(dbauthz.AsSystemRestricted(ctx), key.UserID)
409+
actor, userStatus, err := UserRBACSubject(ctx, cfg.DB, key.UserID, rbac.ScopeName(key.Scope))
411410
if err != nil {
412411
return write(http.StatusUnauthorized, codersdk.Response{
413412
Message: internalErrorMessage,
414413
Detail: fmt.Sprintf("Internal error fetching user's roles. %s", err.Error()),
415414
})
416415
}
417416

418-
if roles.Status == database.UserStatusDormant {
417+
if userStatus == database.UserStatusDormant {
419418
// If coder confirms that the dormant user is valid, it can switch their account to active.
420419
// nolint:gocritic
421420
u, err := cfg.DB.UpdateUserStatus(dbauthz.AsSystemRestricted(ctx), database.UpdateUserStatusParams{
@@ -429,47 +428,50 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
429428
Detail: fmt.Sprintf("can't activate a dormant user: %s", err.Error()),
430429
})
431430
}
432-
roles.Status = u.Status
431+
userStatus = u.Status
433432
}
434433

435-
if roles.Status != database.UserStatusActive {
434+
if userStatus != database.UserStatusActive {
436435
return write(http.StatusUnauthorized, codersdk.Response{
437-
Message: fmt.Sprintf("User is not active (status = %q). Contact an admin to reactivate your account.", roles.Status),
436+
Message: fmt.Sprintf("User is not active (status = %q). Contact an admin to reactivate your account.", userStatus),
438437
})
439438
}
440439

440+
if cfg.PostAuthAdditionalHeadersFunc != nil {
441+
cfg.PostAuthAdditionalHeadersFunc(actor, rw.Header())
442+
}
443+
444+
return key, &actor, true
445+
}
446+
447+
// UserRBACSubject fetches a user's rbac.Subject from the database. It pulls all roles from both
448+
// site and organization scopes. It also pulls the groups, and the user's status.
449+
func UserRBACSubject(ctx context.Context, db database.Store, userID uuid.UUID, scope rbac.ExpandableScope) (rbac.Subject, database.UserStatus, error) {
450+
//nolint:gocritic // system needs to update user roles
451+
roles, err := db.GetAuthorizationUserRoles(dbauthz.AsSystemRestricted(ctx), userID)
452+
if err != nil {
453+
return rbac.Subject{}, "", xerrors.Errorf("get authorization user roles: %w", err)
454+
}
455+
441456
roleNames, err := roles.RoleNames()
442457
if err != nil {
443-
return write(http.StatusInternalServerError, codersdk.Response{
444-
Message: "Internal Server Error",
445-
Detail: err.Error(),
446-
})
458+
return rbac.Subject{}, "", xerrors.Errorf("expand role names: %w", err)
447459
}
448460

449461
//nolint:gocritic // Permission to lookup custom roles the user has assigned.
450-
rbacRoles, err := rolestore.Expand(dbauthz.AsSystemRestricted(ctx), cfg.DB, roleNames)
462+
rbacRoles, err := rolestore.Expand(dbauthz.AsSystemRestricted(ctx), db, roleNames)
451463
if err != nil {
452-
return write(http.StatusInternalServerError, codersdk.Response{
453-
Message: "Failed to expand authenticated user roles",
454-
Detail: err.Error(),
455-
Validations: nil,
456-
})
464+
return rbac.Subject{}, "", xerrors.Errorf("expand role names: %w", err)
457465
}
458466

459-
// Actor is the user's authorization context.
460467
actor := rbac.Subject{
461468
FriendlyName: roles.Username,
462-
ID: key.UserID.String(),
469+
ID: userID.String(),
463470
Roles: rbacRoles,
464471
Groups: roles.Groups,
465-
Scope: rbac.ScopeName(key.Scope),
472+
Scope: scope,
466473
}.WithCachedASTValue()
467-
468-
if cfg.PostAuthAdditionalHeadersFunc != nil {
469-
cfg.PostAuthAdditionalHeadersFunc(actor, rw.Header())
470-
}
471-
472-
return key, &actor, true
474+
return actor, roles.Status, nil
473475
}
474476

475477
// APITokenFromRequest returns the api token from the request.

coderd/httpmw/apikey_test.go

+169-1
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,20 @@ import (
1414
"testing"
1515
"time"
1616

17+
"github.com/google/uuid"
1718
"github.com/stretchr/testify/assert"
1819
"github.com/stretchr/testify/require"
20+
"golang.org/x/exp/slices"
1921
"golang.org/x/oauth2"
2022

2123
"github.com/coder/coder/v2/coderd/database"
24+
"github.com/coder/coder/v2/coderd/database/dbauthz"
2225
"github.com/coder/coder/v2/coderd/database/dbgen"
2326
"github.com/coder/coder/v2/coderd/database/dbmem"
2427
"github.com/coder/coder/v2/coderd/database/dbtime"
2528
"github.com/coder/coder/v2/coderd/httpapi"
2629
"github.com/coder/coder/v2/coderd/httpmw"
30+
"github.com/coder/coder/v2/coderd/rbac"
2731
"github.com/coder/coder/v2/codersdk"
2832
"github.com/coder/coder/v2/cryptorand"
2933
"github.com/coder/coder/v2/testutil"
@@ -38,6 +42,37 @@ func randomAPIKeyParts() (id string, secret string) {
3842
func TestAPIKey(t *testing.T) {
3943
t.Parallel()
4044

45+
// assertActorOk asserts all the properties of the user auth are ok.
46+
assertActorOk := func(t *testing.T, r *http.Request) {
47+
t.Helper()
48+
49+
actor, ok := dbauthz.ActorFromContext(r.Context())
50+
assert.True(t, ok, "dbauthz actor ok")
51+
if ok {
52+
_, err := actor.Roles.Expand()
53+
assert.NoError(t, err, "actor roles ok")
54+
55+
_, err = actor.Scope.Expand()
56+
assert.NoError(t, err, "actor scope ok")
57+
58+
err = actor.RegoValueOk()
59+
assert.NoError(t, err, "actor rego ok")
60+
}
61+
62+
auth, ok := httpmw.UserAuthorizationOptional(r)
63+
assert.True(t, ok, "httpmw auth ok")
64+
if ok {
65+
_, err := auth.Roles.Expand()
66+
assert.NoError(t, err, "auth roles ok")
67+
68+
_, err = auth.Scope.Expand()
69+
assert.NoError(t, err, "auth scope ok")
70+
71+
err = auth.RegoValueOk()
72+
assert.NoError(t, err, "auth rego ok")
73+
}
74+
}
75+
4176
successHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
4277
// Only called if the API key passes through the handler.
4378
httpapi.Write(context.Background(), rw, http.StatusOK, codersdk.Response{
@@ -256,6 +291,7 @@ func TestAPIKey(t *testing.T) {
256291
})(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
257292
// Checks that it exists on the context!
258293
_ = httpmw.APIKey(r)
294+
assertActorOk(t, r)
259295
httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.Response{
260296
Message: "It worked!",
261297
})
@@ -296,6 +332,7 @@ func TestAPIKey(t *testing.T) {
296332
// Checks that it exists on the context!
297333
apiKey := httpmw.APIKey(r)
298334
assert.Equal(t, database.APIKeyScopeApplicationConnect, apiKey.Scope)
335+
assertActorOk(t, r)
299336

300337
httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.Response{
301338
Message: "it worked!",
@@ -330,6 +367,8 @@ func TestAPIKey(t *testing.T) {
330367
})(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
331368
// Checks that it exists on the context!
332369
_ = httpmw.APIKey(r)
370+
assertActorOk(t, r)
371+
333372
httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.Response{
334373
Message: "It worked!",
335374
})
@@ -633,7 +672,7 @@ func TestAPIKey(t *testing.T) {
633672
require.Equal(t, sentAPIKey.LoginType, gotAPIKey.LoginType)
634673
})
635674

636-
t.Run("MissongConfig", func(t *testing.T) {
675+
t.Run("MissingConfig", func(t *testing.T) {
637676
t.Parallel()
638677
var (
639678
db = dbmem.New()
@@ -667,4 +706,133 @@ func TestAPIKey(t *testing.T) {
667706
out, _ := io.ReadAll(res.Body)
668707
require.Contains(t, string(out), "Unable to refresh")
669708
})
709+
710+
t.Run("CustomRoles", func(t *testing.T) {
711+
t.Parallel()
712+
var (
713+
db = dbmem.New()
714+
org = dbgen.Organization(t, db, database.Organization{})
715+
customRole = dbgen.CustomRole(t, db, database.CustomRole{
716+
Name: "custom-role",
717+
OrgPermissions: []database.CustomRolePermission{},
718+
OrganizationID: uuid.NullUUID{
719+
UUID: org.ID,
720+
Valid: true,
721+
},
722+
})
723+
user = dbgen.User(t, db, database.User{
724+
RBACRoles: []string{},
725+
})
726+
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{
727+
UserID: user.ID,
728+
OrganizationID: org.ID,
729+
CreatedAt: time.Time{},
730+
UpdatedAt: time.Time{},
731+
Roles: []string{
732+
rbac.RoleOrgAdmin(),
733+
customRole.Name,
734+
},
735+
})
736+
_, token = dbgen.APIKey(t, db, database.APIKey{
737+
UserID: user.ID,
738+
ExpiresAt: dbtime.Now().AddDate(0, 0, 1),
739+
})
740+
741+
r = httptest.NewRequest("GET", "/", nil)
742+
rw = httptest.NewRecorder()
743+
)
744+
r.Header.Set(codersdk.SessionTokenHeader, token)
745+
746+
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
747+
DB: db,
748+
RedirectToLogin: false,
749+
})(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
750+
assertActorOk(t, r)
751+
752+
auth := httpmw.UserAuthorization(r)
753+
754+
roles, err := auth.Roles.Expand()
755+
assert.NoError(t, err, "expand user roles")
756+
// Assert built in org role
757+
assert.True(t, slices.ContainsFunc(roles, func(role rbac.Role) bool {
758+
return role.Identifier.Name == rbac.RoleOrgAdmin() && role.Identifier.OrganizationID == org.ID
759+
}), "org admin role")
760+
// Assert custom role
761+
assert.True(t, slices.ContainsFunc(roles, func(role rbac.Role) bool {
762+
return role.Identifier.Name == customRole.Name && role.Identifier.OrganizationID == org.ID
763+
}), "custom org role")
764+
765+
httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.Response{
766+
Message: "It worked!",
767+
})
768+
})).ServeHTTP(rw, r)
769+
res := rw.Result()
770+
defer res.Body.Close()
771+
require.Equal(t, http.StatusOK, res.StatusCode)
772+
})
773+
774+
// There is no sql foreign key constraint to require all assigned roles
775+
// still exist in the database. We need to handle deleted roles.
776+
t.Run("RoleNotExists", func(t *testing.T) {
777+
t.Parallel()
778+
var (
779+
roleNotExistsName = "role-not-exists"
780+
db = dbmem.New()
781+
org = dbgen.Organization(t, db, database.Organization{})
782+
user = dbgen.User(t, db, database.User{
783+
RBACRoles: []string{
784+
// Also provide an org not exists. In practice this makes no sense
785+
// to store org roles in the user table, but there is no org to
786+
// store it in. So just throw this here for even more unexpected
787+
// behavior handling!
788+
rbac.RoleIdentifier{Name: roleNotExistsName, OrganizationID: uuid.New()}.String(),
789+
},
790+
})
791+
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{
792+
UserID: user.ID,
793+
OrganizationID: org.ID,
794+
CreatedAt: time.Time{},
795+
UpdatedAt: time.Time{},
796+
Roles: []string{
797+
rbac.RoleOrgAdmin(),
798+
roleNotExistsName,
799+
},
800+
})
801+
_, token = dbgen.APIKey(t, db, database.APIKey{
802+
UserID: user.ID,
803+
ExpiresAt: dbtime.Now().AddDate(0, 0, 1),
804+
})
805+
806+
r = httptest.NewRequest("GET", "/", nil)
807+
rw = httptest.NewRecorder()
808+
)
809+
r.Header.Set(codersdk.SessionTokenHeader, token)
810+
811+
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
812+
DB: db,
813+
RedirectToLogin: false,
814+
})(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
815+
assertActorOk(t, r)
816+
auth := httpmw.UserAuthorization(r)
817+
818+
roles, err := auth.Roles.Expand()
819+
assert.NoError(t, err, "expand user roles")
820+
// Assert built in org role
821+
assert.True(t, slices.ContainsFunc(roles, func(role rbac.Role) bool {
822+
return role.Identifier.Name == rbac.RoleOrgAdmin() && role.Identifier.OrganizationID == org.ID
823+
}), "org admin role")
824+
825+
// Assert the role-not-exists is not returned
826+
assert.False(t, slices.ContainsFunc(roles, func(role rbac.Role) bool {
827+
return role.Identifier.Name == roleNotExistsName
828+
}), "role should not exist")
829+
830+
httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.Response{
831+
Message: "It worked!",
832+
})
833+
})).ServeHTTP(rw, r)
834+
res := rw.Result()
835+
defer res.Body.Close()
836+
require.Equal(t, http.StatusOK, res.StatusCode)
837+
})
670838
}

coderd/identityprovider/tokens.go

+6-31
Original file line numberDiff line numberDiff line change
@@ -209,27 +209,14 @@ func authorizationCodeGrant(ctx context.Context, db database.Store, app database
209209
}
210210

211211
// Grab the user roles so we can perform the exchange as the user.
212-
//nolint:gocritic // In the token exchange, there is no user actor.
213-
roles, err := db.GetAuthorizationUserRoles(dbauthz.AsSystemRestricted(ctx), dbCode.UserID)
212+
actor, _, err := httpmw.UserRBACSubject(ctx, db, dbCode.UserID, rbac.ScopeAll)
214213
if err != nil {
215-
return oauth2.Token{}, err
216-
}
217-
218-
roleNames, err := roles.RoleNames()
219-
if err != nil {
220-
return oauth2.Token{}, xerrors.Errorf("role names: %w", err)
221-
}
222-
223-
userSubj := rbac.Subject{
224-
ID: dbCode.UserID.String(),
225-
Roles: rbac.RoleIdentifiers(roleNames),
226-
Groups: roles.Groups,
227-
Scope: rbac.ScopeAll,
214+
return oauth2.Token{}, xerrors.Errorf("fetch user actor: %w", err)
228215
}
229216

230217
// Do the actual token exchange in the database.
231218
err = db.InTx(func(tx database.Store) error {
232-
ctx := dbauthz.As(ctx, userSubj)
219+
ctx := dbauthz.As(ctx, actor)
233220
err = tx.DeleteOAuth2ProviderAppCodeByID(ctx, dbCode.ID)
234221
if err != nil {
235222
return xerrors.Errorf("delete oauth2 app code: %w", err)
@@ -311,22 +298,10 @@ func refreshTokenGrant(ctx context.Context, db database.Store, app database.OAut
311298
if err != nil {
312299
return oauth2.Token{}, err
313300
}
314-
//nolint:gocritic // There is no user yet so we must use the system.
315-
roles, err := db.GetAuthorizationUserRoles(dbauthz.AsSystemRestricted(ctx), prevKey.UserID)
316-
if err != nil {
317-
return oauth2.Token{}, err
318-
}
319301

320-
roleNames, err := roles.RoleNames()
302+
actor, _, err := httpmw.UserRBACSubject(ctx, db, prevKey.UserID, rbac.ScopeAll)
321303
if err != nil {
322-
return oauth2.Token{}, xerrors.Errorf("role names: %w", err)
323-
}
324-
325-
userSubj := rbac.Subject{
326-
ID: prevKey.UserID.String(),
327-
Roles: rbac.RoleIdentifiers(roleNames),
328-
Groups: roles.Groups,
329-
Scope: rbac.ScopeAll,
304+
return oauth2.Token{}, xerrors.Errorf("fetch user actor: %w", err)
330305
}
331306

332307
// Generate a new refresh token.
@@ -351,7 +326,7 @@ func refreshTokenGrant(ctx context.Context, db database.Store, app database.OAut
351326

352327
// Replace the token.
353328
err = db.InTx(func(tx database.Store) error {
354-
ctx := dbauthz.As(ctx, userSubj)
329+
ctx := dbauthz.As(ctx, actor)
355330
err = tx.DeleteAPIKeyByID(ctx, prevKey.ID) // This cascades to the token.
356331
if err != nil {
357332
return xerrors.Errorf("delete oauth2 app token: %w", err)

0 commit comments

Comments
 (0)