Skip to content

Commit 26458cd

Browse files
authored
refactor: consolidate template and workspace acl validation (#19192)
1 parent 02de067 commit 26458cd

File tree

14 files changed

+535
-103
lines changed

14 files changed

+535
-103
lines changed

coderd/coderd.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1415,7 +1415,8 @@ func New(options *Options) *API {
14151415
r.Get("/timings", api.workspaceTimings)
14161416
r.Route("/acl", func(r chi.Router) {
14171417
r.Use(
1418-
httpmw.RequireExperiment(api.Experiments, codersdk.ExperimentWorkspaceSharing))
1418+
httpmw.RequireExperiment(api.Experiments, codersdk.ExperimentWorkspaceSharing),
1419+
)
14191420

14201421
r.Patch("/", api.patchWorkspaceACL)
14211422
})

coderd/database/dbauthz/dbauthz.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5376,6 +5376,26 @@ func (q *querier) UpsertWorkspaceAppAuditSession(ctx context.Context, arg databa
53765376
return q.db.UpsertWorkspaceAppAuditSession(ctx, arg)
53775377
}
53785378

5379+
func (q *querier) ValidateGroupIDs(ctx context.Context, groupIDs []uuid.UUID) (database.ValidateGroupIDsRow, error) {
5380+
// This check is probably overly restrictive, but the "correct" check isn't
5381+
// necessarily obvious. It's only used as a verification check for ACLs right
5382+
// now, which are performed as system.
5383+
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
5384+
return database.ValidateGroupIDsRow{}, err
5385+
}
5386+
return q.db.ValidateGroupIDs(ctx, groupIDs)
5387+
}
5388+
5389+
func (q *querier) ValidateUserIDs(ctx context.Context, userIDs []uuid.UUID) (database.ValidateUserIDsRow, error) {
5390+
// This check is probably overly restrictive, but the "correct" check isn't
5391+
// necessarily obvious. It's only used as a verification check for ACLs right
5392+
// now, which are performed as system.
5393+
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
5394+
return database.ValidateUserIDsRow{}, err
5395+
}
5396+
return q.db.ValidateUserIDs(ctx, userIDs)
5397+
}
5398+
53795399
func (q *querier) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, _ rbac.PreparedAuthorized) ([]database.Template, error) {
53805400
// TODO Delete this function, all GetTemplates should be authorized. For now just call getTemplates on the authz querier.
53815401
return q.GetTemplatesWithFilter(ctx, arg)

coderd/database/dbauthz/dbauthz_test.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -623,6 +623,11 @@ func (s *MethodTestSuite) TestGroup() {
623623
ID: g.ID,
624624
}).Asserts(g, policy.ActionUpdate)
625625
}))
626+
s.Run("ValidateGroupIDs", s.Subtest(func(db database.Store, check *expects) {
627+
o := dbgen.Organization(s.T(), db, database.Organization{})
628+
g := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID})
629+
check.Args([]uuid.UUID{g.ID}).Asserts(rbac.ResourceSystem, policy.ActionRead)
630+
}))
626631
}
627632

628633
func (s *MethodTestSuite) TestProvisionerJob() {
@@ -2077,6 +2082,10 @@ func (s *MethodTestSuite) TestUser() {
20772082
Interval: int32((time.Hour * 24).Seconds()),
20782083
}).Asserts(rbac.ResourceUser, policy.ActionRead)
20792084
}))
2085+
s.Run("ValidateUserIDs", s.Subtest(func(db database.Store, check *expects) {
2086+
u := dbgen.User(s.T(), db, database.User{})
2087+
check.Args([]uuid.UUID{u.ID}).Asserts(rbac.ResourceSystem, policy.ActionRead)
2088+
}))
20802089
}
20812090

20822091
func (s *MethodTestSuite) TestWorkspace() {

coderd/database/dbmetrics/querymetrics.go

Lines changed: 14 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/dbmock/dbmock.go

Lines changed: 30 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/querier.go

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries.sql.go

Lines changed: 64 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries/groups.sql

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,24 @@ WHERE
88
LIMIT
99
1;
1010

11+
-- name: ValidateGroupIDs :one
12+
WITH input AS (
13+
SELECT
14+
unnest(@group_ids::uuid[]) AS id
15+
)
16+
SELECT
17+
array_agg(input.id)::uuid[] as invalid_group_ids,
18+
COUNT(*) = 0 as ok
19+
FROM
20+
-- Preserve rows where there is not a matching left (groups) row for each
21+
-- right (input) row...
22+
groups
23+
RIGHT JOIN input ON groups.id = input.id
24+
WHERE
25+
-- ...so that we can retain exactly those rows where an input ID does not
26+
-- match an existing group.
27+
groups.id IS NULL;
28+
1129
-- name: GetGroupByOrgAndName :one
1230
SELECT
1331
*

coderd/database/queries/users.sql

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,26 @@ WHERE
2525
LIMIT
2626
1;
2727

28+
-- name: ValidateUserIDs :one
29+
WITH input AS (
30+
SELECT
31+
unnest(@user_ids::uuid[]) AS id
32+
)
33+
SELECT
34+
array_agg(input.id)::uuid[] as invalid_user_ids,
35+
COUNT(*) = 0 as ok
36+
FROM
37+
-- Preserve rows where there is not a matching left (users) row for each
38+
-- right (input) row...
39+
users
40+
RIGHT JOIN input ON users.id = input.id
41+
WHERE
42+
-- ...so that we can retain exactly those rows where an input ID does not
43+
-- match an existing user...
44+
users.id IS NULL OR
45+
-- ...or that only matches a user that was deleted.
46+
users.deleted = true;
47+
2848
-- name: GetUsersByIDs :many
2949
-- This shouldn't check for deleted, because it's frequently used
3050
-- to look up references to actions. eg. a user could build a workspace

coderd/rbac/acl/updatevalidator.go

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
package acl
2+
3+
import (
4+
"context"
5+
"fmt"
6+
7+
"github.com/google/uuid"
8+
9+
"github.com/coder/coder/v2/coderd/database"
10+
"github.com/coder/coder/v2/coderd/database/dbauthz"
11+
"github.com/coder/coder/v2/codersdk"
12+
)
13+
14+
type UpdateValidator[Role codersdk.WorkspaceRole | codersdk.TemplateRole] interface {
15+
// Users should return a map from user UUIDs (as strings) to the role they
16+
// are being assigned. Additionally, it should return a string that will be
17+
// used as the field name for the ValidationErrors returned from Validate.
18+
Users() (map[string]Role, string)
19+
// Groups should return a map from group UUIDs (as strings) to the role they
20+
// are being assigned. Additionally, it should return a string that will be
21+
// used as the field name for the ValidationErrors returned from Validate.
22+
Groups() (map[string]Role, string)
23+
// ValidateRole should return an error that will be used in the
24+
// ValidationError if the role is invalid for the corresponding resource type.
25+
ValidateRole(role Role) error
26+
}
27+
28+
func Validate[Role codersdk.WorkspaceRole | codersdk.TemplateRole](
29+
ctx context.Context,
30+
db database.Store,
31+
v UpdateValidator[Role],
32+
) []codersdk.ValidationError {
33+
// nolint:gocritic // Validate requires full read access to users and groups
34+
ctx = dbauthz.AsSystemRestricted(ctx)
35+
var validErrs []codersdk.ValidationError
36+
37+
groupRoles, groupsField := v.Groups()
38+
groupIDs := make([]uuid.UUID, 0, len(groupRoles))
39+
for idStr, role := range groupRoles {
40+
// Validate the provided role names
41+
if err := v.ValidateRole(role); err != nil {
42+
validErrs = append(validErrs, codersdk.ValidationError{
43+
Field: groupsField,
44+
Detail: err.Error(),
45+
})
46+
}
47+
// Validate that the IDs are UUIDs
48+
id, err := uuid.Parse(idStr)
49+
if err != nil {
50+
validErrs = append(validErrs, codersdk.ValidationError{
51+
Field: groupsField,
52+
Detail: fmt.Sprintf("%v is not a valid UUID.", idStr),
53+
})
54+
continue
55+
}
56+
// Don't check if the ID exists when setting the role to
57+
// WorkspaceRoleDeleted or TemplateRoleDeleted. They might've existing at
58+
// some point and got deleted. If we report that as an error here then they
59+
// can't be removed.
60+
if string(role) == "" {
61+
continue
62+
}
63+
groupIDs = append(groupIDs, id)
64+
}
65+
66+
// Validate that the groups exist
67+
groupValidation, err := db.ValidateGroupIDs(ctx, groupIDs)
68+
if err != nil {
69+
validErrs = append(validErrs, codersdk.ValidationError{
70+
Field: groupsField,
71+
Detail: fmt.Sprintf("failed to validate group IDs: %v", err.Error()),
72+
})
73+
}
74+
if !groupValidation.Ok {
75+
for _, id := range groupValidation.InvalidGroupIds {
76+
validErrs = append(validErrs, codersdk.ValidationError{
77+
Field: groupsField,
78+
Detail: fmt.Sprintf("group with ID %v does not exist", id),
79+
})
80+
}
81+
}
82+
83+
userRoles, usersField := v.Users()
84+
userIDs := make([]uuid.UUID, 0, len(userRoles))
85+
for idStr, role := range userRoles {
86+
// Validate the provided role names
87+
if err := v.ValidateRole(role); err != nil {
88+
validErrs = append(validErrs, codersdk.ValidationError{
89+
Field: usersField,
90+
Detail: err.Error(),
91+
})
92+
}
93+
// Validate that the IDs are UUIDs
94+
id, err := uuid.Parse(idStr)
95+
if err != nil {
96+
validErrs = append(validErrs, codersdk.ValidationError{
97+
Field: usersField,
98+
Detail: fmt.Sprintf("%v is not a valid UUID.", idStr),
99+
})
100+
continue
101+
}
102+
// Don't check if the ID exists when setting the role to
103+
// WorkspaceRoleDeleted or TemplateRoleDeleted. They might've existing at
104+
// some point and got deleted. If we report that as an error here then they
105+
// can't be removed.
106+
if string(role) == "" {
107+
continue
108+
}
109+
userIDs = append(userIDs, id)
110+
}
111+
112+
// Validate that the groups exist
113+
userValidation, err := db.ValidateUserIDs(ctx, userIDs)
114+
if err != nil {
115+
validErrs = append(validErrs, codersdk.ValidationError{
116+
Field: usersField,
117+
Detail: fmt.Sprintf("failed to validate user IDs: %v", err.Error()),
118+
})
119+
}
120+
if !userValidation.Ok {
121+
for _, id := range userValidation.InvalidUserIds {
122+
validErrs = append(validErrs, codersdk.ValidationError{
123+
Field: usersField,
124+
Detail: fmt.Sprintf("user with ID %v does not exist", id),
125+
})
126+
}
127+
}
128+
129+
return validErrs
130+
}

0 commit comments

Comments
 (0)