@@ -17,27 +17,26 @@ import (
17
17
18
18
func (q * AuthzQuerier ) GetPreviousTemplateVersion (ctx context.Context , arg database.GetPreviousTemplateVersionParams ) (database.TemplateVersion , error ) {
19
19
// An actor can read the previous template version if they can read the related template.
20
- fetchRelated := func ( _ database. TemplateVersion , _ database. GetPreviousTemplateVersionParams ) (rbac. Objecter , error ) {
21
- if ! arg .TemplateID .Valid {
22
- // If no linked template exists, check if the actor can read the template in the organization.
23
- return rbac . ResourceTemplate . InOrg ( arg . OrganizationID ), nil
20
+ // If no linked template exists, we check if the actor can read *a* template.
21
+ if ! arg .TemplateID .Valid {
22
+ if err := q . authorizeContext ( ctx , rbac . ActionRead , rbac . ResourceTemplate . InOrg ( arg . OrganizationID )); err != nil {
23
+ return database. TemplateVersion {}, err
24
24
}
25
- return q .db .GetTemplateByID (ctx , arg .TemplateID .UUID )
26
25
}
27
- return queryWithRelated (q .log , q .auth , rbac .ActionRead , fetchRelated , q .db .GetPreviousTemplateVersion )(ctx , arg )
26
+ if _ , err := q .GetTemplateByID (ctx , arg .TemplateID .UUID ); err != nil {
27
+ return database.TemplateVersion {}, err
28
+ }
29
+ return q .db .GetPreviousTemplateVersion (ctx , arg )
28
30
}
29
31
30
32
func (q * AuthzQuerier ) GetTemplateAverageBuildTime (ctx context.Context , arg database.GetTemplateAverageBuildTimeParams ) (database.GetTemplateAverageBuildTimeRow , error ) {
31
33
// An actor can read the average build time if they can read the related template.
32
- fetchRelated := func (database.GetTemplateAverageBuildTimeRow , database.GetTemplateAverageBuildTimeParams ) (rbac.Objecter , error ) {
33
- if ! arg .TemplateID .Valid {
34
- // If no linked template exists, check if the actor can read *a* template.
35
- // We don't know the organization ID.
36
- return rbac .ResourceTemplate , nil
37
- }
38
- return q .db .GetTemplateByID (ctx , arg .TemplateID .UUID )
34
+ // It doesn't make any sense to get the average build time for a template that doesn't
35
+ // exist, so omitting this check here.
36
+ if _ , err := q .GetTemplateByID (ctx , arg .TemplateID .UUID ); err != nil {
37
+ return database.GetTemplateAverageBuildTimeRow {}, err
39
38
}
40
- return queryWithRelated ( q . log , q . auth , rbac . ActionRead , fetchRelated , q . db .GetTemplateAverageBuildTime ) (ctx , arg )
39
+ return q . db .GetTemplateAverageBuildTime (ctx , arg )
41
40
}
42
41
43
42
func (q * AuthzQuerier ) GetTemplateByID (ctx context.Context , id uuid.UUID ) (database.Template , error ) {
@@ -50,68 +49,62 @@ func (q *AuthzQuerier) GetTemplateByOrganizationAndName(ctx context.Context, arg
50
49
51
50
func (q * AuthzQuerier ) GetTemplateDAUs (ctx context.Context , templateID uuid.UUID ) ([]database.GetTemplateDAUsRow , error ) {
52
51
// An actor can read the DAUs if they can read the related template.
53
- fetchRelated := func (_ []database.GetTemplateDAUsRow , _ uuid.UUID ) (rbac.Objecter , error ) {
54
- return q .db .GetTemplateByID (ctx , templateID )
52
+ // Again, it doesn't make sense to get DAUs for a template that doesn't exist.
53
+ if _ , err := q .GetTemplateByID (ctx , templateID ); err != nil {
54
+ return nil , err
55
55
}
56
- return queryWithRelated ( q . log , q . auth , rbac . ActionRead , fetchRelated , q . db .GetTemplateDAUs ) (ctx , templateID )
56
+ return q . db .GetTemplateDAUs (ctx , templateID )
57
57
}
58
58
59
59
func (q * AuthzQuerier ) GetTemplateVersionByID (ctx context.Context , tvid uuid.UUID ) (database.TemplateVersion , error ) {
60
- // An actor can read the template version if they can read the related template.
61
- fetchRelated := func (tv database.TemplateVersion , _ uuid.UUID ) (rbac.Objecter , error ) {
62
- if ! tv .TemplateID .Valid {
63
- // If no linked template exists, check if the actor can read a template
64
- // in the organization.
65
- return rbac .ResourceTemplate .InOrg (tv .OrganizationID ), nil
60
+ tv , err := q .db .GetTemplateVersionByID (ctx , tvid )
61
+ if err != nil {
62
+ return database.TemplateVersion {}, err
63
+ }
64
+ if ! tv .TemplateID .Valid {
65
+ // If no linked template exists, check if the actor can read a template in the organization.
66
+ if err := q .authorizeContext (ctx , rbac .ActionRead , rbac .ResourceTemplate .InOrg (tv .OrganizationID )); err != nil {
67
+ return database.TemplateVersion {}, err
66
68
}
67
- return q .db .GetTemplateByID (ctx , tv .TemplateID .UUID )
68
- }
69
- return queryWithRelated (
70
- q .log ,
71
- q .auth ,
72
- rbac .ActionRead ,
73
- fetchRelated ,
74
- q .db .GetTemplateVersionByID ,
75
- )(ctx , tvid )
69
+ } else if _ , err := q .GetTemplateByID (ctx , tv .TemplateID .UUID ); err != nil {
70
+ // An actor can read the template version if they can read the related template.
71
+ return database.TemplateVersion {}, err
72
+ }
73
+ return tv , nil
76
74
}
77
75
78
76
func (q * AuthzQuerier ) GetTemplateVersionByJobID (ctx context.Context , jobID uuid.UUID ) (database.TemplateVersion , error ) {
79
- // An actor can read the template version if they can read the related template.
80
- fetchRelated := func (tv database.TemplateVersion , _ uuid.UUID ) (rbac.Objecter , error ) {
81
- if ! tv .TemplateID .Valid {
82
- // If no linked template exists, check if the actor can read a
83
- // template in the organization.
84
- return rbac .ResourceTemplate .InOrg (tv .OrganizationID ), nil
77
+ tv , err := q .db .GetTemplateVersionByJobID (ctx , jobID )
78
+ if err != nil {
79
+ return database.TemplateVersion {}, err
80
+ }
81
+ if ! tv .TemplateID .Valid {
82
+ // If no linked template exists, check if the actor can read a template in the organization.
83
+ if err := q .authorizeContext (ctx , rbac .ActionRead , rbac .ResourceTemplate .InOrg (tv .OrganizationID )); err != nil {
84
+ return database.TemplateVersion {}, err
85
85
}
86
- return q .db .GetTemplateByID (ctx , tv .TemplateID .UUID )
87
- }
88
- return queryWithRelated (
89
- q .log ,
90
- q .auth ,
91
- rbac .ActionRead ,
92
- fetchRelated ,
93
- q .db .GetTemplateVersionByJobID ,
94
- )(ctx , jobID )
86
+ } else if _ , err := q .GetTemplateByID (ctx , tv .TemplateID .UUID ); err != nil {
87
+ // An actor can read the template version if they can read the related template.
88
+ return database.TemplateVersion {}, err
89
+ }
90
+ return tv , nil
95
91
}
96
92
97
93
func (q * AuthzQuerier ) GetTemplateVersionByTemplateIDAndName (ctx context.Context , arg database.GetTemplateVersionByTemplateIDAndNameParams ) (database.TemplateVersion , error ) {
98
- // An actor can read the template version if they can read the related template.
99
- fetchRelated := func (tv database.TemplateVersion , p database.GetTemplateVersionByTemplateIDAndNameParams ) (rbac.Objecter , error ) {
100
- if ! tv .TemplateID .Valid {
101
- // If no linked template exists, check if the actor can read *a* template.
102
- // We don't know the organization ID.
103
- return rbac .ResourceTemplate , nil
94
+ tv , err := q .db .GetTemplateVersionByTemplateIDAndName (ctx , arg )
95
+ if err != nil {
96
+ return database.TemplateVersion {}, err
97
+ }
98
+ if ! tv .TemplateID .Valid {
99
+ // If no linked template exists, check if the actor can read a template in the organization.
100
+ if err := q .authorizeContext (ctx , rbac .ActionRead , rbac .ResourceTemplate .InOrg (tv .OrganizationID )); err != nil {
101
+ return database.TemplateVersion {}, err
104
102
}
105
- return q .db .GetTemplateByID (ctx , tv .TemplateID .UUID )
103
+ } else if _ , err := q .GetTemplateByID (ctx , tv .TemplateID .UUID ); err != nil {
104
+ // An actor can read the template version if they can read the related template.
105
+ return database.TemplateVersion {}, err
106
106
}
107
-
108
- return queryWithRelated (
109
- q .log ,
110
- q .auth ,
111
- rbac .ActionRead ,
112
- fetchRelated ,
113
- q .db .GetTemplateVersionByTemplateIDAndName ,
114
- )(ctx , arg )
107
+ return tv , nil
115
108
}
116
109
117
110
func (q * AuthzQuerier ) GetTemplateVersionParameters (ctx context.Context , templateVersionID uuid.UUID ) ([]database.TemplateVersionParameter , error ) {
@@ -183,16 +176,10 @@ func (q *AuthzQuerier) GetTemplateVersionsByTemplateID(ctx context.Context, arg
183
176
184
177
func (q * AuthzQuerier ) GetTemplateVersionsCreatedAfter (ctx context.Context , createdAt time.Time ) ([]database.TemplateVersion , error ) {
185
178
// An actor can read execute this query if they can read all templates.
186
- fetchRelated := func (tvs []database.TemplateVersion , _ time.Time ) (rbac.Objecter , error ) {
187
- return rbac .ResourceTemplate .All (), nil
188
- }
189
- return queryWithRelated (
190
- q .log ,
191
- q .auth ,
192
- rbac .ActionRead ,
193
- fetchRelated ,
194
- q .db .GetTemplateVersionsCreatedAfter ,
195
- )(ctx , createdAt )
179
+ if err := q .authorizeContext (ctx , rbac .ActionRead , rbac .ResourceTemplate ); err != nil {
180
+ return nil , err
181
+ }
182
+ return q .db .GetTemplateVersionsCreatedAfter (ctx , createdAt )
196
183
}
197
184
198
185
func (q * AuthzQuerier ) GetAuthorizedTemplates (ctx context.Context , arg database.GetTemplatesWithFilterParams , _ rbac.PreparedAuthorized ) ([]database.Template , error ) {
0 commit comments