Skip to content

Commit 67494a3

Browse files
authored
chore: push GetUsers authorization filter to SQL (coder#8497)
* feat: push GetUsers filter to SQL * Remove GetAuthorizedUserFilter * Remove GetFilteredUserCount * remove GetUsersWithCount
1 parent dfac074 commit 67494a3

File tree

11 files changed

+136
-251
lines changed

11 files changed

+136
-251
lines changed

coderd/database/dbauthz/dbauthz.go

+11-39
Original file line numberDiff line numberDiff line change
@@ -586,32 +586,6 @@ func (q *querier) SoftDeleteTemplateByID(ctx context.Context, id uuid.UUID) erro
586586
return deleteQ(q.log, q.auth, q.db.GetTemplateByID, deleteF)(ctx, id)
587587
}
588588

589-
func (q *querier) GetUsersWithCount(ctx context.Context, arg database.GetUsersParams) ([]database.User, int64, error) {
590-
// TODO Implement this with a SQL filter. The count is incorrect without it.
591-
rowUsers, err := q.db.GetUsers(ctx, arg)
592-
if err != nil {
593-
return nil, -1, err
594-
}
595-
596-
if len(rowUsers) == 0 {
597-
return []database.User{}, 0, nil
598-
}
599-
600-
act, ok := ActorFromContext(ctx)
601-
if !ok {
602-
return nil, -1, NoActorError
603-
}
604-
605-
// TODO: Is this correct? Should we return a restricted user?
606-
users := database.ConvertUserRows(rowUsers)
607-
users, err = rbac.Filter(ctx, q.auth, act, rbac.ActionRead, users)
608-
if err != nil {
609-
return nil, -1, err
610-
}
611-
612-
return users, rowUsers[0].Count, nil
613-
}
614-
615589
func (q *querier) SoftDeleteUserByID(ctx context.Context, id uuid.UUID) error {
616590
deleteF := func(ctx context.Context, id uuid.UUID) error {
617591
return q.db.UpdateUserDeletedByID(ctx, database.UpdateUserDeletedByIDParams{
@@ -904,15 +878,6 @@ func (q *querier) GetFileTemplates(ctx context.Context, fileID uuid.UUID) ([]dat
904878
return q.db.GetFileTemplates(ctx, fileID)
905879
}
906880

907-
func (q *querier) GetFilteredUserCount(ctx context.Context, arg database.GetFilteredUserCountParams) (int64, error) {
908-
prep, err := prepareSQLFilter(ctx, q.auth, rbac.ActionRead, rbac.ResourceUser.Type)
909-
if err != nil {
910-
return -1, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
911-
}
912-
// TODO: This should be the only implementation.
913-
return q.GetAuthorizedUserCount(ctx, arg, prep)
914-
}
915-
916881
func (q *querier) GetGitAuthLink(ctx context.Context, arg database.GetGitAuthLinkParams) (database.GitAuthLink, error) {
917882
return fetch(q.log, q.auth, q.db.GetGitAuthLink)(ctx, arg)
918883
}
@@ -1389,8 +1354,12 @@ func (q *querier) GetUserLinkByUserIDLoginType(ctx context.Context, arg database
13891354
}
13901355

13911356
func (q *querier) GetUsers(ctx context.Context, arg database.GetUsersParams) ([]database.GetUsersRow, error) {
1392-
// TODO: We should use GetUsersWithCount with a better method signature.
1393-
return fetchWithPostFilter(q.auth, q.db.GetUsers)(ctx, arg)
1357+
// This does the filtering in SQL.
1358+
prep, err := prepareSQLFilter(ctx, q.auth, rbac.ActionRead, rbac.ResourceUser.Type)
1359+
if err != nil {
1360+
return nil, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
1361+
}
1362+
return q.db.GetAuthorizedUsers(ctx, arg, prep)
13941363
}
13951364

13961365
// GetUsersByIDs is only used for usernames on workspace return data.
@@ -2639,6 +2608,9 @@ func (q *querier) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetW
26392608
return q.GetWorkspaces(ctx, arg)
26402609
}
26412610

2642-
func (q *querier) GetAuthorizedUserCount(ctx context.Context, arg database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) {
2643-
return q.db.GetAuthorizedUserCount(ctx, arg, prepared)
2611+
// GetAuthorizedUsers is not required for dbauthz since GetUsers is already
2612+
// authenticated.
2613+
func (q *querier) GetAuthorizedUsers(ctx context.Context, arg database.GetUsersParams, _ rbac.PreparedAuthorized) ([]database.GetUsersRow, error) {
2614+
// GetUsers is authenticated.
2615+
return q.GetUsers(ctx, arg)
26442616
}

coderd/database/dbauthz/dbauthz_test.go

+4-16
Original file line numberDiff line numberDiff line change
@@ -869,24 +869,12 @@ func (s *MethodTestSuite) TestUser() {
869869
Asserts(a, rbac.ActionRead, b, rbac.ActionRead).
870870
Returns(slice.New(a, b))
871871
}))
872-
s.Run("GetAuthorizedUserCount", s.Subtest(func(db database.Store, check *expects) {
873-
_ = dbgen.User(s.T(), db, database.User{})
874-
check.Args(database.GetFilteredUserCountParams{}, emptyPreparedAuthorized{}).Asserts().Returns(int64(1))
875-
}))
876-
s.Run("GetFilteredUserCount", s.Subtest(func(db database.Store, check *expects) {
877-
_ = dbgen.User(s.T(), db, database.User{})
878-
check.Args(database.GetFilteredUserCountParams{}).Asserts().Returns(int64(1))
879-
}))
880872
s.Run("GetUsers", s.Subtest(func(db database.Store, check *expects) {
881-
a := dbgen.User(s.T(), db, database.User{Username: "GetUsers-a-user"})
882-
b := dbgen.User(s.T(), db, database.User{Username: "GetUsers-b-user"})
873+
dbgen.User(s.T(), db, database.User{Username: "GetUsers-a-user"})
874+
dbgen.User(s.T(), db, database.User{Username: "GetUsers-b-user"})
883875
check.Args(database.GetUsersParams{}).
884-
Asserts(a, rbac.ActionRead, b, rbac.ActionRead)
885-
}))
886-
s.Run("GetUsersWithCount", s.Subtest(func(db database.Store, check *expects) {
887-
a := dbgen.User(s.T(), db, database.User{Username: "GetUsersWithCount-a-user"})
888-
b := dbgen.User(s.T(), db, database.User{Username: "GetUsersWithCount-b-user"})
889-
check.Args(database.GetUsersParams{}).Asserts(a, rbac.ActionRead, b, rbac.ActionRead)
876+
// Asserts are done in a SQL filter
877+
Asserts()
890878
}))
891879
s.Run("InsertUser", s.Subtest(func(db database.Store, check *expects) {
892880
check.Args(database.InsertUserParams{

coderd/database/dbfake/dbfake.go

+19-65
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"github.com/coder/coder/coderd/database/db2sdk"
2424
"github.com/coder/coder/coderd/httpapi"
2525
"github.com/coder/coder/coderd/rbac"
26+
"github.com/coder/coder/coderd/rbac/regosql"
2627
"github.com/coder/coder/coderd/util/slice"
2728
"github.com/coder/coder/codersdk"
2829
)
@@ -1207,14 +1208,6 @@ func (q *FakeQuerier) GetFileTemplates(_ context.Context, id uuid.UUID) ([]datab
12071208
return rows, nil
12081209
}
12091210

1210-
func (q *FakeQuerier) GetFilteredUserCount(ctx context.Context, arg database.GetFilteredUserCountParams) (int64, error) {
1211-
if err := validateDatabaseType(arg); err != nil {
1212-
return 0, err
1213-
}
1214-
count, err := q.GetAuthorizedUserCount(ctx, arg, nil)
1215-
return count, err
1216-
}
1217-
12181211
func (q *FakeQuerier) GetGitAuthLink(_ context.Context, arg database.GetGitAuthLinkParams) (database.GitAuthLink, error) {
12191212
if err := validateDatabaseType(arg); err != nil {
12201213
return database.GitAuthLink{}, err
@@ -5365,76 +5358,37 @@ func (q *FakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.
53655358
return q.convertToWorkspaceRowsNoLock(ctx, workspaces, int64(beforePageCount)), nil
53665359
}
53675360

5368-
func (q *FakeQuerier) GetAuthorizedUserCount(ctx context.Context, params database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) {
5369-
if err := validateDatabaseType(params); err != nil {
5370-
return 0, err
5361+
func (q *FakeQuerier) GetAuthorizedUsers(ctx context.Context, arg database.GetUsersParams, prepared rbac.PreparedAuthorized) ([]database.GetUsersRow, error) {
5362+
if err := validateDatabaseType(arg); err != nil {
5363+
return nil, err
53715364
}
53725365

5373-
q.mutex.RLock()
5374-
defer q.mutex.RUnlock()
5375-
53765366
// Call this to match the same function calls as the SQL implementation.
53775367
if prepared != nil {
5378-
_, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL())
5368+
_, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{
5369+
VariableConverter: regosql.UserConverter(),
5370+
})
53795371
if err != nil {
5380-
return -1, err
5372+
return nil, err
53815373
}
53825374
}
53835375

5384-
users := make([]database.User, 0, len(q.users))
5376+
users, err := q.GetUsers(ctx, arg)
5377+
if err != nil {
5378+
return nil, err
5379+
}
53855380

5386-
for _, user := range q.users {
5381+
q.mutex.RLock()
5382+
defer q.mutex.RUnlock()
5383+
5384+
filteredUsers := make([]database.GetUsersRow, 0, len(users))
5385+
for _, user := range users {
53875386
// If the filter exists, ensure the object is authorized.
53885387
if prepared != nil && prepared.Authorize(ctx, user.RBACObject()) != nil {
53895388
continue
53905389
}
53915390

5392-
users = append(users, user)
5393-
}
5394-
5395-
// Filter out deleted since they should never be returned..
5396-
tmp := make([]database.User, 0, len(users))
5397-
for _, user := range users {
5398-
if !user.Deleted {
5399-
tmp = append(tmp, user)
5400-
}
5401-
}
5402-
users = tmp
5403-
5404-
if params.Search != "" {
5405-
tmp := make([]database.User, 0, len(users))
5406-
for i, user := range users {
5407-
if strings.Contains(strings.ToLower(user.Email), strings.ToLower(params.Search)) {
5408-
tmp = append(tmp, users[i])
5409-
} else if strings.Contains(strings.ToLower(user.Username), strings.ToLower(params.Search)) {
5410-
tmp = append(tmp, users[i])
5411-
}
5412-
}
5413-
users = tmp
5414-
}
5415-
5416-
if len(params.Status) > 0 {
5417-
usersFilteredByStatus := make([]database.User, 0, len(users))
5418-
for i, user := range users {
5419-
if slice.ContainsCompare(params.Status, user.Status, func(a, b database.UserStatus) bool {
5420-
return strings.EqualFold(string(a), string(b))
5421-
}) {
5422-
usersFilteredByStatus = append(usersFilteredByStatus, users[i])
5423-
}
5424-
}
5425-
users = usersFilteredByStatus
5426-
}
5427-
5428-
if len(params.RbacRole) > 0 && !slice.Contains(params.RbacRole, rbac.RoleMember()) {
5429-
usersFilteredByRole := make([]database.User, 0, len(users))
5430-
for i, user := range users {
5431-
if slice.OverlapCompare(params.RbacRole, user.RBACRoles, strings.EqualFold) {
5432-
usersFilteredByRole = append(usersFilteredByRole, users[i])
5433-
}
5434-
}
5435-
5436-
users = usersFilteredByRole
5391+
filteredUsers = append(filteredUsers, user)
54375392
}
5438-
5439-
return int64(len(users)), nil
5393+
return filteredUsers, nil
54405394
}

coderd/database/dbmetrics/dbmetrics.go

+4-11
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/dbmock/dbmock.go

+7-22
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/modelqueries.go

+48-11
Original file line numberDiff line numberDiff line change
@@ -255,29 +255,66 @@ func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspa
255255
}
256256

257257
type userQuerier interface {
258-
GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error)
258+
GetAuthorizedUsers(ctx context.Context, arg GetUsersParams, prepared rbac.PreparedAuthorized) ([]GetUsersRow, error)
259259
}
260260

261-
func (q *sqlQuerier) GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) {
262-
authorizedFilter, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL())
261+
func (q *sqlQuerier) GetAuthorizedUsers(ctx context.Context, arg GetUsersParams, prepared rbac.PreparedAuthorized) ([]GetUsersRow, error) {
262+
authorizedFilter, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{
263+
VariableConverter: regosql.UserConverter(),
264+
})
263265
if err != nil {
264-
return -1, xerrors.Errorf("compile authorized filter: %w", err)
266+
return nil, xerrors.Errorf("compile authorized filter: %w", err)
265267
}
266268

267-
filtered, err := insertAuthorizedFilter(getFilteredUserCount, fmt.Sprintf(" AND %s", authorizedFilter))
269+
filtered, err := insertAuthorizedFilter(getUsers, fmt.Sprintf(" AND %s", authorizedFilter))
268270
if err != nil {
269-
return -1, xerrors.Errorf("insert authorized filter: %w", err)
271+
return nil, xerrors.Errorf("insert authorized filter: %w", err)
270272
}
271273

272-
query := fmt.Sprintf("-- name: GetAuthorizedUserCount :one\n%s", filtered)
273-
row := q.db.QueryRowContext(ctx, query,
274+
query := fmt.Sprintf("-- name: GetAuthorizedUsers :many\n%s", filtered)
275+
rows, err := q.db.QueryContext(ctx, query,
276+
arg.AfterID,
274277
arg.Search,
275278
pq.Array(arg.Status),
276279
pq.Array(arg.RbacRole),
280+
arg.LastSeenBefore,
281+
arg.LastSeenAfter,
282+
arg.OffsetOpt,
283+
arg.LimitOpt,
277284
)
278-
var count int64
279-
err = row.Scan(&count)
280-
return count, err
285+
if err != nil {
286+
return nil, err
287+
}
288+
defer rows.Close()
289+
var items []GetUsersRow
290+
for rows.Next() {
291+
var i GetUsersRow
292+
if err := rows.Scan(
293+
&i.ID,
294+
&i.Email,
295+
&i.Username,
296+
&i.HashedPassword,
297+
&i.CreatedAt,
298+
&i.UpdatedAt,
299+
&i.Status,
300+
&i.RBACRoles,
301+
&i.LoginType,
302+
&i.AvatarURL,
303+
&i.Deleted,
304+
&i.LastSeenAt,
305+
&i.Count,
306+
); err != nil {
307+
return nil, err
308+
}
309+
items = append(items, i)
310+
}
311+
if err := rows.Close(); err != nil {
312+
return nil, err
313+
}
314+
if err := rows.Err(); err != nil {
315+
return nil, err
316+
}
317+
return items, nil
281318
}
282319

283320
func insertAuthorizedFilter(query string, replaceWith string) (string, error) {

coderd/database/querier.go

-2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)