Skip to content

Commit 13f1c9f

Browse files
committed
remove queryWithRelated
1 parent 4fe26e9 commit 13f1c9f

File tree

3 files changed

+62
-117
lines changed

3 files changed

+62
-117
lines changed

coderd/authzquery/authz.go

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -245,48 +245,6 @@ func fetchWithPostFilter[ArgumentType any, ObjectType rbac.Objecter,
245245
}
246246
}
247247

248-
// queryWithRelated performs the same function as authorizedQuery, except that
249-
// RBAC checks are performed on the result of relatedFunc() instead of the result of fetch().
250-
// This is useful for cases where ObjectType does not implement RBACObjecter.
251-
// For example, a TemplateVersion object does not implement RBACObjecter, but it is
252-
// related to a Template object, which does. Thus, any operations on a TemplateVersion
253-
// are predicated on the RBAC permissions of the related Template object.
254-
func queryWithRelated[ObjectType any, ArgumentType any, Related rbac.Objecter](
255-
// Arguments
256-
logger slog.Logger,
257-
authorizer rbac.Authorizer,
258-
action rbac.Action,
259-
relatedFunc func(ObjectType, ArgumentType) (Related, error),
260-
fetch func(ctx context.Context, arg ArgumentType) (ObjectType, error)) func(ctx context.Context, arg ArgumentType) (ObjectType, error) {
261-
return func(ctx context.Context, arg ArgumentType) (empty ObjectType, err error) {
262-
// Fetch the rbac subject
263-
act, ok := ActorFromContext(ctx)
264-
if !ok {
265-
return empty, NoActorError
266-
}
267-
268-
// Fetch the rbac object
269-
obj, err := fetch(ctx, arg)
270-
if err != nil {
271-
return empty, xerrors.Errorf("fetch object: %w", err)
272-
}
273-
274-
// Fetch the related object on which we actually do RBAC
275-
rel, err := relatedFunc(obj, arg)
276-
if err != nil {
277-
return empty, xerrors.Errorf("fetch related object: %w", err)
278-
}
279-
280-
// Authorize the action
281-
err = authorizer.Authorize(ctx, act, action, rel.RBACObject())
282-
if err != nil {
283-
return empty, LogNotAuthorizedError(ctx, logger, err)
284-
}
285-
286-
return obj, nil
287-
}
288-
}
289-
290248
// prepareSQLFilter is a helper function that prepares a SQL filter using the
291249
// given authorization context.
292250
func prepareSQLFilter(ctx context.Context, authorizer rbac.Authorizer, action rbac.Action, resourceType string) (rbac.PreparedAuthorized, error) {

coderd/authzquery/group.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,10 @@ func (q *AuthzQuerier) GetGroupByOrgAndName(ctx context.Context, arg database.Ge
5050
}
5151

5252
func (q *AuthzQuerier) GetGroupMembers(ctx context.Context, groupID uuid.UUID) ([]database.User, error) {
53-
relatedFunc := func(_ []database.User, groupID uuid.UUID) (database.Group, error) {
54-
return q.db.GetGroupByID(ctx, groupID)
53+
if _, err := q.GetGroupByID(ctx, groupID); err != nil { // AuthZ check
54+
return nil, err
5555
}
56-
return queryWithRelated(q.log, q.auth, rbac.ActionRead, relatedFunc, q.db.GetGroupMembers)(ctx, groupID)
56+
return q.db.GetGroupMembers(ctx, groupID)
5757
}
5858

5959
func (q *AuthzQuerier) InsertAllUsersGroup(ctx context.Context, organizationID uuid.UUID) (database.Group, error) {

coderd/authzquery/template.go

Lines changed: 59 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -17,27 +17,26 @@ import (
1717

1818
func (q *AuthzQuerier) GetPreviousTemplateVersion(ctx context.Context, arg database.GetPreviousTemplateVersionParams) (database.TemplateVersion, error) {
1919
// An actor can read the previous template version if they can read the related template.
20-
fetchRelated := func(_ database.TemplateVersion, _ database.GetPreviousTemplateVersionParams) (rbac.Objecter, error) {
21-
if !arg.TemplateID.Valid {
22-
// If no linked template exists, check if the actor can read the template in the organization.
23-
return rbac.ResourceTemplate.InOrg(arg.OrganizationID), nil
20+
// If no linked template exists, we check if the actor can read *a* template.
21+
if !arg.TemplateID.Valid {
22+
if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTemplate.InOrg(arg.OrganizationID)); err != nil {
23+
return database.TemplateVersion{}, err
2424
}
25-
return q.db.GetTemplateByID(ctx, arg.TemplateID.UUID)
2625
}
27-
return queryWithRelated(q.log, q.auth, rbac.ActionRead, fetchRelated, q.db.GetPreviousTemplateVersion)(ctx, arg)
26+
if _, err := q.GetTemplateByID(ctx, arg.TemplateID.UUID); err != nil {
27+
return database.TemplateVersion{}, err
28+
}
29+
return q.db.GetPreviousTemplateVersion(ctx, arg)
2830
}
2931

3032
func (q *AuthzQuerier) GetTemplateAverageBuildTime(ctx context.Context, arg database.GetTemplateAverageBuildTimeParams) (database.GetTemplateAverageBuildTimeRow, error) {
3133
// An actor can read the average build time if they can read the related template.
32-
fetchRelated := func(database.GetTemplateAverageBuildTimeRow, database.GetTemplateAverageBuildTimeParams) (rbac.Objecter, error) {
33-
if !arg.TemplateID.Valid {
34-
// If no linked template exists, check if the actor can read *a* template.
35-
// We don't know the organization ID.
36-
return rbac.ResourceTemplate, nil
37-
}
38-
return q.db.GetTemplateByID(ctx, arg.TemplateID.UUID)
34+
// It doesn't make any sense to get the average build time for a template that doesn't
35+
// exist, so omitting this check here.
36+
if _, err := q.GetTemplateByID(ctx, arg.TemplateID.UUID); err != nil {
37+
return database.GetTemplateAverageBuildTimeRow{}, err
3938
}
40-
return queryWithRelated(q.log, q.auth, rbac.ActionRead, fetchRelated, q.db.GetTemplateAverageBuildTime)(ctx, arg)
39+
return q.db.GetTemplateAverageBuildTime(ctx, arg)
4140
}
4241

4342
func (q *AuthzQuerier) GetTemplateByID(ctx context.Context, id uuid.UUID) (database.Template, error) {
@@ -50,68 +49,62 @@ func (q *AuthzQuerier) GetTemplateByOrganizationAndName(ctx context.Context, arg
5049

5150
func (q *AuthzQuerier) GetTemplateDAUs(ctx context.Context, templateID uuid.UUID) ([]database.GetTemplateDAUsRow, error) {
5251
// An actor can read the DAUs if they can read the related template.
53-
fetchRelated := func(_ []database.GetTemplateDAUsRow, _ uuid.UUID) (rbac.Objecter, error) {
54-
return q.db.GetTemplateByID(ctx, templateID)
52+
// Again, it doesn't make sense to get DAUs for a template that doesn't exist.
53+
if _, err := q.GetTemplateByID(ctx, templateID); err != nil {
54+
return nil, err
5555
}
56-
return queryWithRelated(q.log, q.auth, rbac.ActionRead, fetchRelated, q.db.GetTemplateDAUs)(ctx, templateID)
56+
return q.db.GetTemplateDAUs(ctx, templateID)
5757
}
5858

5959
func (q *AuthzQuerier) GetTemplateVersionByID(ctx context.Context, tvid uuid.UUID) (database.TemplateVersion, error) {
60-
// An actor can read the template version if they can read the related template.
61-
fetchRelated := func(tv database.TemplateVersion, _ uuid.UUID) (rbac.Objecter, error) {
62-
if !tv.TemplateID.Valid {
63-
// If no linked template exists, check if the actor can read a template
64-
// in the organization.
65-
return rbac.ResourceTemplate.InOrg(tv.OrganizationID), nil
60+
tv, err := q.db.GetTemplateVersionByID(ctx, tvid)
61+
if err != nil {
62+
return database.TemplateVersion{}, err
63+
}
64+
if !tv.TemplateID.Valid {
65+
// If no linked template exists, check if the actor can read a template in the organization.
66+
if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTemplate.InOrg(tv.OrganizationID)); err != nil {
67+
return database.TemplateVersion{}, err
6668
}
67-
return q.db.GetTemplateByID(ctx, tv.TemplateID.UUID)
68-
}
69-
return queryWithRelated(
70-
q.log,
71-
q.auth,
72-
rbac.ActionRead,
73-
fetchRelated,
74-
q.db.GetTemplateVersionByID,
75-
)(ctx, tvid)
69+
} else if _, err := q.GetTemplateByID(ctx, tv.TemplateID.UUID); err != nil {
70+
// An actor can read the template version if they can read the related template.
71+
return database.TemplateVersion{}, err
72+
}
73+
return tv, nil
7674
}
7775

7876
func (q *AuthzQuerier) GetTemplateVersionByJobID(ctx context.Context, jobID uuid.UUID) (database.TemplateVersion, error) {
79-
// An actor can read the template version if they can read the related template.
80-
fetchRelated := func(tv database.TemplateVersion, _ uuid.UUID) (rbac.Objecter, error) {
81-
if !tv.TemplateID.Valid {
82-
// If no linked template exists, check if the actor can read a
83-
// template in the organization.
84-
return rbac.ResourceTemplate.InOrg(tv.OrganizationID), nil
77+
tv, err := q.db.GetTemplateVersionByJobID(ctx, jobID)
78+
if err != nil {
79+
return database.TemplateVersion{}, err
80+
}
81+
if !tv.TemplateID.Valid {
82+
// If no linked template exists, check if the actor can read a template in the organization.
83+
if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTemplate.InOrg(tv.OrganizationID)); err != nil {
84+
return database.TemplateVersion{}, err
8585
}
86-
return q.db.GetTemplateByID(ctx, tv.TemplateID.UUID)
87-
}
88-
return queryWithRelated(
89-
q.log,
90-
q.auth,
91-
rbac.ActionRead,
92-
fetchRelated,
93-
q.db.GetTemplateVersionByJobID,
94-
)(ctx, jobID)
86+
} else if _, err := q.GetTemplateByID(ctx, tv.TemplateID.UUID); err != nil {
87+
// An actor can read the template version if they can read the related template.
88+
return database.TemplateVersion{}, err
89+
}
90+
return tv, nil
9591
}
9692

9793
func (q *AuthzQuerier) GetTemplateVersionByTemplateIDAndName(ctx context.Context, arg database.GetTemplateVersionByTemplateIDAndNameParams) (database.TemplateVersion, error) {
98-
// An actor can read the template version if they can read the related template.
99-
fetchRelated := func(tv database.TemplateVersion, p database.GetTemplateVersionByTemplateIDAndNameParams) (rbac.Objecter, error) {
100-
if !tv.TemplateID.Valid {
101-
// If no linked template exists, check if the actor can read *a* template.
102-
// We don't know the organization ID.
103-
return rbac.ResourceTemplate, nil
94+
tv, err := q.db.GetTemplateVersionByTemplateIDAndName(ctx, arg)
95+
if err != nil {
96+
return database.TemplateVersion{}, err
97+
}
98+
if !tv.TemplateID.Valid {
99+
// If no linked template exists, check if the actor can read a template in the organization.
100+
if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTemplate.InOrg(tv.OrganizationID)); err != nil {
101+
return database.TemplateVersion{}, err
104102
}
105-
return q.db.GetTemplateByID(ctx, tv.TemplateID.UUID)
103+
} else if _, err := q.GetTemplateByID(ctx, tv.TemplateID.UUID); err != nil {
104+
// An actor can read the template version if they can read the related template.
105+
return database.TemplateVersion{}, err
106106
}
107-
108-
return queryWithRelated(
109-
q.log,
110-
q.auth,
111-
rbac.ActionRead,
112-
fetchRelated,
113-
q.db.GetTemplateVersionByTemplateIDAndName,
114-
)(ctx, arg)
107+
return tv, nil
115108
}
116109

117110
func (q *AuthzQuerier) GetTemplateVersionParameters(ctx context.Context, templateVersionID uuid.UUID) ([]database.TemplateVersionParameter, error) {
@@ -183,16 +176,10 @@ func (q *AuthzQuerier) GetTemplateVersionsByTemplateID(ctx context.Context, arg
183176

184177
func (q *AuthzQuerier) GetTemplateVersionsCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.TemplateVersion, error) {
185178
// An actor can read execute this query if they can read all templates.
186-
fetchRelated := func(tvs []database.TemplateVersion, _ time.Time) (rbac.Objecter, error) {
187-
return rbac.ResourceTemplate.All(), nil
188-
}
189-
return queryWithRelated(
190-
q.log,
191-
q.auth,
192-
rbac.ActionRead,
193-
fetchRelated,
194-
q.db.GetTemplateVersionsCreatedAfter,
195-
)(ctx, createdAt)
179+
if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTemplate); err != nil {
180+
return nil, err
181+
}
182+
return q.db.GetTemplateVersionsCreatedAfter(ctx, createdAt)
196183
}
197184

198185
func (q *AuthzQuerier) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, _ rbac.PreparedAuthorized) ([]database.Template, error) {

0 commit comments

Comments
 (0)