@@ -5,12 +5,16 @@ import (
5
5
"fmt"
6
6
"strings"
7
7
8
+ "github.com/google/uuid"
8
9
"github.com/lib/pq"
10
+ "golang.org/x/xerrors"
9
11
10
12
"github.com/coder/coder/coderd/rbac"
13
+ "github.com/coder/coder/coderd/rbac/regosql"
14
+ )
11
15
12
- "github.com/google/uuid"
13
- "golang.org/x/xerrors "
16
+ const (
17
+ authorizedQueryPlaceholder = "-- @authorize_filter "
14
18
)
15
19
16
20
// customQuerier encompasses all non-generated queries.
@@ -23,10 +27,70 @@ type customQuerier interface {
23
27
}
24
28
25
29
type templateQuerier interface {
30
+ GetAuthorizedTemplates (ctx context.Context , arg GetTemplatesWithFilterParams , prepared rbac.PreparedAuthorized ) ([]Template , error )
26
31
GetTemplateGroupRoles (ctx context.Context , id uuid.UUID ) ([]TemplateGroup , error )
27
32
GetTemplateUserRoles (ctx context.Context , id uuid.UUID ) ([]TemplateUser , error )
28
33
}
29
34
35
+ func (q * sqlQuerier ) GetAuthorizedTemplates (ctx context.Context , arg GetTemplatesWithFilterParams , prepared rbac.PreparedAuthorized ) ([]Template , error ) {
36
+ authorizedFilter , err := prepared .CompileToSQL (regosql.ConvertConfig {
37
+ VariableConverter : regosql .TemplateConverter (),
38
+ })
39
+ if err != nil {
40
+ return nil , xerrors .Errorf ("compile authorized filter: %w" , err )
41
+ }
42
+
43
+ filtered , err := insertAuthorizedFilter (getTemplatesWithFilter , fmt .Sprintf (" AND %s" , authorizedFilter ))
44
+ if err != nil {
45
+ return nil , xerrors .Errorf ("insert authorized filter: %w" , err )
46
+ }
47
+
48
+ // The name comment is for metric tracking
49
+ query := fmt .Sprintf ("-- name: GetAuthorizedTemplates :many\n %s" , filtered )
50
+ rows , err := q .db .QueryContext (ctx , query ,
51
+ arg .Deleted ,
52
+ arg .OrganizationID ,
53
+ arg .ExactName ,
54
+ pq .Array (arg .IDs ),
55
+ )
56
+ if err != nil {
57
+ return nil , err
58
+ }
59
+ defer rows .Close ()
60
+ var items []Template
61
+ for rows .Next () {
62
+ var i Template
63
+ if err := rows .Scan (
64
+ & i .ID ,
65
+ & i .CreatedAt ,
66
+ & i .UpdatedAt ,
67
+ & i .OrganizationID ,
68
+ & i .Deleted ,
69
+ & i .Name ,
70
+ & i .Provisioner ,
71
+ & i .ActiveVersionID ,
72
+ & i .Description ,
73
+ & i .DefaultTTL ,
74
+ & i .CreatedBy ,
75
+ & i .Icon ,
76
+ & i .UserACL ,
77
+ & i .GroupACL ,
78
+ & i .DisplayName ,
79
+ & i .AllowUserCancelWorkspaceJobs ,
80
+ ); err != nil {
81
+ return nil , err
82
+ }
83
+ items = append (items , i )
84
+ }
85
+ if err := rows .Close (); err != nil {
86
+ return nil , err
87
+ }
88
+ if err := rows .Err (); err != nil {
89
+ return nil , err
90
+ }
91
+ return items , nil
92
+ }
93
+
30
94
type TemplateUser struct {
31
95
User
32
96
Actions Actions `db:"actions"`
@@ -112,18 +176,27 @@ func (q *sqlQuerier) GetTemplateGroupRoles(ctx context.Context, id uuid.UUID) ([
112
176
}
113
177
114
178
type workspaceQuerier interface {
115
- GetAuthorizedWorkspaces (ctx context.Context , arg GetWorkspacesParams , authorizedFilter rbac.AuthorizeFilter ) ([]GetWorkspacesRow , error )
179
+ GetAuthorizedWorkspaces (ctx context.Context , arg GetWorkspacesParams , prepared rbac.PreparedAuthorized ) ([]GetWorkspacesRow , error )
116
180
}
117
181
118
182
// GetAuthorizedWorkspaces returns all workspaces that the user is authorized to access.
119
183
// This code is copied from `GetWorkspaces` and adds the authorized filter WHERE
120
184
// clause.
121
- func (q * sqlQuerier ) GetAuthorizedWorkspaces (ctx context.Context , arg GetWorkspacesParams , authorizedFilter rbac.AuthorizeFilter ) ([]GetWorkspacesRow , error ) {
185
+ func (q * sqlQuerier ) GetAuthorizedWorkspaces (ctx context.Context , arg GetWorkspacesParams , prepared rbac.PreparedAuthorized ) ([]GetWorkspacesRow , error ) {
186
+ authorizedFilter , err := prepared .CompileToSQL (rbac .ConfigWithoutACL ())
187
+ if err != nil {
188
+ return nil , xerrors .Errorf ("compile authorized filter: %w" , err )
189
+ }
190
+
122
191
// In order to properly use ORDER BY, OFFSET, and LIMIT, we need to inject the
123
192
// authorizedFilter between the end of the where clause and those statements.
124
- filter := strings .Replace (getWorkspaces , "-- @authorize_filter" , fmt .Sprintf (" AND %s" , authorizedFilter .SQLString (rbac .NoACLConfig ())), 1 )
193
+ filtered , err := insertAuthorizedFilter (getWorkspaces , fmt .Sprintf (" AND %s" , authorizedFilter ))
194
+ if err != nil {
195
+ return nil , xerrors .Errorf ("insert authorized filter: %w" , err )
196
+ }
197
+
125
198
// The name comment is for metric tracking
126
- query := fmt .Sprintf ("-- name: GetAuthorizedWorkspaces :many\n %s" , filter )
199
+ query := fmt .Sprintf ("-- name: GetAuthorizedWorkspaces :many\n %s" , filtered )
127
200
rows , err := q .db .QueryContext (ctx , query ,
128
201
arg .Deleted ,
129
202
arg .Status ,
@@ -172,19 +245,36 @@ func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspa
172
245
}
173
246
174
247
type userQuerier interface {
175
- GetAuthorizedUserCount (ctx context.Context , arg GetFilteredUserCountParams , authorizedFilter rbac.AuthorizeFilter ) (int64 , error )
248
+ GetAuthorizedUserCount (ctx context.Context , arg GetFilteredUserCountParams , prepared rbac.PreparedAuthorized ) (int64 , error )
176
249
}
177
250
178
- func (q * sqlQuerier ) GetAuthorizedUserCount (ctx context.Context , arg GetFilteredUserCountParams , authorizedFilter rbac.AuthorizeFilter ) (int64 , error ) {
179
- filter := strings .Replace (getFilteredUserCount , "-- @authorize_filter" , fmt .Sprintf (" AND %s" , authorizedFilter .SQLString (rbac .NoACLConfig ())), 1 )
180
- query := fmt .Sprintf ("-- name: GetAuthorizedUserCount :one\n %s" , filter )
251
+ func (q * sqlQuerier ) GetAuthorizedUserCount (ctx context.Context , arg GetFilteredUserCountParams , prepared rbac.PreparedAuthorized ) (int64 , error ) {
252
+ authorizedFilter , err := prepared .CompileToSQL (rbac .ConfigWithoutACL ())
253
+ if err != nil {
254
+ return - 1 , xerrors .Errorf ("compile authorized filter: %w" , err )
255
+ }
256
+
257
+ filtered , err := insertAuthorizedFilter (getFilteredUserCount , fmt .Sprintf (" AND %s" , authorizedFilter ))
258
+ if err != nil {
259
+ return - 1 , xerrors .Errorf ("insert authorized filter: %w" , err )
260
+ }
261
+
262
+ query := fmt .Sprintf ("-- name: GetAuthorizedUserCount :one\n %s" , filtered )
181
263
row := q .db .QueryRowContext (ctx , query ,
182
264
arg .Deleted ,
183
265
arg .Search ,
184
266
pq .Array (arg .Status ),
185
267
pq .Array (arg .RbacRole ),
186
268
)
187
269
var count int64
188
- err : = row .Scan (& count )
270
+ err = row .Scan (& count )
189
271
return count , err
190
272
}
273
+
274
+ func insertAuthorizedFilter (query string , replaceWith string ) (string , error ) {
275
+ if ! strings .Contains (query , authorizedQueryPlaceholder ) {
276
+ return "" , xerrors .Errorf ("query does not contain authorized replace string, this is not an authorized query" )
277
+ }
278
+ filtered := strings .Replace (query , authorizedQueryPlaceholder , replaceWith , 1 )
279
+ return filtered , nil
280
+ }
0 commit comments