From 634f4ca14881f9c16a66b98209f4a6912da408b4 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 13 Jul 2023 10:04:49 -0400 Subject: [PATCH 1/7] feat: push GetUsers filter to SQL --- coderd/database/dbauthz/dbauthz.go | 26 +++++------ coderd/database/dbfake/dbfake.go | 36 ++++++++++++++ coderd/database/dbmetrics/dbmetrics.go | 7 +++ coderd/database/dbmock/dbmock.go | 15 ++++++ coderd/database/modelqueries.go | 65 +++++++++++++++++++++++++- coderd/database/queries.sql.go | 3 ++ coderd/database/queries/users.sql | 3 ++ coderd/rbac/regosql/compile_test.go | 20 ++++++++ coderd/rbac/regosql/configs.go | 17 +++++++ 9 files changed, 178 insertions(+), 14 deletions(-) diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 41fa20392fadf..00919179c0c26 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -620,7 +620,6 @@ func (q *querier) GetAuthorizedUserCount(ctx context.Context, arg database.GetFi } func (q *querier) GetUsersWithCount(ctx context.Context, arg database.GetUsersParams) ([]database.User, int64, error) { - // TODO Implement this with a SQL filter. The count is incorrect without it. rowUsers, err := q.db.GetUsers(ctx, arg) if err != nil { return nil, -1, err @@ -630,18 +629,8 @@ func (q *querier) GetUsersWithCount(ctx context.Context, arg database.GetUsersPa return []database.User{}, 0, nil } - act, ok := ActorFromContext(ctx) - if !ok { - return nil, -1, NoActorError - } - // TODO: Is this correct? Should we return a restricted user? users := database.ConvertUserRows(rowUsers) - users, err = rbac.Filter(ctx, q.auth, act, rbac.ActionRead, users) - if err != nil { - return nil, -1, err - } - return users, rowUsers[0].Count, nil } @@ -699,6 +688,13 @@ func authorizedTemplateVersionFromJob(ctx context.Context, q *querier, job datab } } +// GetAuthorizedUsers is not required for dbauthz since GetUsers is already +// authenticated. +func (q *querier) GetAuthorizedUsers(ctx context.Context, arg database.GetUsersParams, _ rbac.PreparedAuthorized) ([]database.GetUsersRow, error) { + // GetUsers is authenticated. + return q.GetUsers(ctx, arg) +} + func (q *querier) AcquireLock(ctx context.Context, id int64) error { return q.db.AcquireLock(ctx, id) } @@ -1427,8 +1423,12 @@ func (q *querier) GetUserLinkByUserIDLoginType(ctx context.Context, arg database } func (q *querier) GetUsers(ctx context.Context, arg database.GetUsersParams) ([]database.GetUsersRow, error) { - // TODO: We should use GetUsersWithCount with a better method signature. - return fetchWithPostFilter(q.auth, q.db.GetUsers)(ctx, arg) + // This does the filtering in SQL. + prep, err := prepareSQLFilter(ctx, q.auth, rbac.ActionRead, rbac.ResourceUser.Type) + if err != nil { + return nil, xerrors.Errorf("(dev error) prepare sql filter: %w", err) + } + return q.db.GetAuthorizedUsers(ctx, arg, prep) } // GetUsersByIDs is only used for usernames on workspace return data. diff --git a/coderd/database/dbfake/dbfake.go b/coderd/database/dbfake/dbfake.go index bdcfd9366dabb..6c66b5ee6a614 100644 --- a/coderd/database/dbfake/dbfake.go +++ b/coderd/database/dbfake/dbfake.go @@ -23,6 +23,7 @@ import ( "github.com/coder/coder/coderd/database/db2sdk" "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/coderd/rbac" + "github.com/coder/coder/coderd/rbac/regosql" "github.com/coder/coder/coderd/util/slice" "github.com/coder/coder/codersdk" ) @@ -5437,3 +5438,38 @@ func (*FakeQuerier) UpsertTailnetClient(context.Context, database.UpsertTailnetC func (*FakeQuerier) UpsertTailnetCoordinator(context.Context, uuid.UUID) (database.TailnetCoordinator, error) { return database.TailnetCoordinator{}, ErrUnimplemented } + +func (q *FakeQuerier) GetAuthorizedUsers(ctx context.Context, arg database.GetUsersParams, prepared rbac.PreparedAuthorized) ([]database.GetUsersRow, error) { + if err := validateDatabaseType(arg); err != nil { + return nil, err + } + + // Call this to match the same function calls as the SQL implementation. + if prepared != nil { + _, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{ + VariableConverter: regosql.UserConverter(), + }) + if err != nil { + return nil, err + } + } + + users, err := q.GetUsers(ctx, arg) + if err != nil { + return nil, err + } + + q.mutex.RLock() + defer q.mutex.RUnlock() + + filteredUsers := make([]database.GetUsersRow, 0, len(users)) + for _, user := range users { + // If the filter exists, ensure the object is authorized. + if prepared != nil && prepared.Authorize(ctx, user.RBACObject()) != nil { + continue + } + + filteredUsers = append(filteredUsers, user) + } + return filteredUsers, nil +} diff --git a/coderd/database/dbmetrics/dbmetrics.go b/coderd/database/dbmetrics/dbmetrics.go index ec28fd428a102..17da69ebfcbda 100644 --- a/coderd/database/dbmetrics/dbmetrics.go +++ b/coderd/database/dbmetrics/dbmetrics.go @@ -1639,3 +1639,10 @@ func (m metricsStore) UpsertTailnetCoordinator(ctx context.Context, id uuid.UUID defer m.queryLatencies.WithLabelValues("UpsertTailnetCoordinator").Observe(time.Since(start).Seconds()) return m.s.UpsertTailnetCoordinator(ctx, id) } + +func (m metricsStore) GetAuthorizedUsers(ctx context.Context, arg database.GetUsersParams, prepared rbac.PreparedAuthorized) ([]database.GetUsersRow, error) { + start := time.Now() + r0, r1 := m.s.GetAuthorizedUsers(ctx, arg, prepared) + m.queryLatencies.WithLabelValues("GetAuthorizedUsers").Observe(time.Since(start).Seconds()) + return r0, r1 +} diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index deab31927154f..9d635bd77e0e4 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -446,6 +446,21 @@ func (mr *MockStoreMockRecorder) GetAuthorizedUserCount(arg0, arg1, arg2 interfa return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedUserCount", reflect.TypeOf((*MockStore)(nil).GetAuthorizedUserCount), arg0, arg1, arg2) } +// GetAuthorizedUsers mocks base method. +func (m *MockStore) GetAuthorizedUsers(arg0 context.Context, arg1 database.GetUsersParams, arg2 rbac.PreparedAuthorized) ([]database.GetUsersRow, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAuthorizedUsers", arg0, arg1, arg2) + ret0, _ := ret[0].([]database.GetUsersRow) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAuthorizedUsers indicates an expected call of GetAuthorizedUsers. +func (mr *MockStoreMockRecorder) GetAuthorizedUsers(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedUsers", reflect.TypeOf((*MockStore)(nil).GetAuthorizedUsers), arg0, arg1, arg2) +} + // GetAuthorizedWorkspaces mocks base method. func (m *MockStore) GetAuthorizedWorkspaces(arg0 context.Context, arg1 database.GetWorkspacesParams, arg2 rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) { m.ctrl.T.Helper() diff --git a/coderd/database/modelqueries.go b/coderd/database/modelqueries.go index 28a56b825f34e..8f5a392a23f79 100644 --- a/coderd/database/modelqueries.go +++ b/coderd/database/modelqueries.go @@ -256,10 +256,73 @@ func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspa type userQuerier interface { GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) + GetAuthorizedUsers(ctx context.Context, arg GetUsersParams, prepared rbac.PreparedAuthorized) ([]GetUsersRow, error) +} + +func (q *sqlQuerier) GetAuthorizedUsers(ctx context.Context, arg GetUsersParams, prepared rbac.PreparedAuthorized) ([]GetUsersRow, error) { + authorizedFilter, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{ + VariableConverter: regosql.UserConverter(), + }) + + if err != nil { + return nil, xerrors.Errorf("compile authorized filter: %w", err) + } + + filtered, err := insertAuthorizedFilter(getUsers, fmt.Sprintf(" AND %s", authorizedFilter)) + if err != nil { + return nil, xerrors.Errorf("insert authorized filter: %w", err) + } + + query := fmt.Sprintf("-- name: GetAuthorizedUsers :many\n%s", filtered) + rows, err := q.db.QueryContext(ctx, query, + arg.AfterID, + arg.Search, + pq.Array(arg.Status), + pq.Array(arg.RbacRole), + arg.LastSeenBefore, + arg.LastSeenAfter, + arg.OffsetOpt, + arg.LimitOpt, + ) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetUsersRow + for rows.Next() { + var i GetUsersRow + if err := rows.Scan( + &i.ID, + &i.Email, + &i.Username, + &i.HashedPassword, + &i.CreatedAt, + &i.UpdatedAt, + &i.Status, + &i.RBACRoles, + &i.LoginType, + &i.AvatarURL, + &i.Deleted, + &i.LastSeenAt, + &i.Count, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil } func (q *sqlQuerier) GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) { - authorizedFilter, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL()) + authorizedFilter, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{ + VariableConverter: regosql.UserConverter(), + }) if err != nil { return -1, xerrors.Errorf("compile authorized filter: %w", err) } diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 78537c65ff5df..e51891e7168f0 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -5301,6 +5301,9 @@ WHERE ELSE true END -- End of filters + + -- Authorize Filter clause will be injected below in GetAuthorizedUsers + -- @authorize_filter ORDER BY -- Deterministic and consistent ordering of all users. This is to ensure consistent pagination. LOWER(username) ASC OFFSET $7 diff --git a/coderd/database/queries/users.sql b/coderd/database/queries/users.sql index 75cc85cdf90de..2115b2eda332e 100644 --- a/coderd/database/queries/users.sql +++ b/coderd/database/queries/users.sql @@ -208,6 +208,9 @@ WHERE ELSE true END -- End of filters + + -- Authorize Filter clause will be injected below in GetAuthorizedUsers + -- @authorize_filter ORDER BY -- Deterministic and consistent ordering of all users. This is to ensure consistent pagination. LOWER(username) ASC OFFSET @offset_opt diff --git a/coderd/rbac/regosql/compile_test.go b/coderd/rbac/regosql/compile_test.go index 6c350b7834639..1997b279d6808 100644 --- a/coderd/rbac/regosql/compile_test.go +++ b/coderd/rbac/regosql/compile_test.go @@ -242,6 +242,26 @@ neq(input.object.owner, ""); p("false")), VariableConverter: regosql.TemplateConverter(), }, + { + Name: "UserNoOrgOwner", + Queries: []string{ + `input.object.org_owner != ""`, + }, + ExpectedSQL: p("'' != ''"), + VariableConverter: regosql.UserConverter(), + }, + { + Name: "UserOwnsSelf", + Queries: []string{ + `"10d03e62-7703-4df5-a358-4f76577d4e2f" = input.object.owner; + input.object.owner != ""; + input.object.org_owner = ""`, + }, + VariableConverter: regosql.UserConverter(), + ExpectedSQL: p( + p("'10d03e62-7703-4df5-a358-4f76577d4e2f' = ''") + " AND " + p("'' != ''") + " AND " + p("'' = ''"), + ), + }, } for _, tc := range testCases { diff --git a/coderd/rbac/regosql/configs.go b/coderd/rbac/regosql/configs.go index 475d317cd53ab..9f27809bf017c 100644 --- a/coderd/rbac/regosql/configs.go +++ b/coderd/rbac/regosql/configs.go @@ -36,6 +36,23 @@ func TemplateConverter() *sqltypes.VariableConverter { return matcher } +func UserConverter() *sqltypes.VariableConverter { + matcher := sqltypes.NewVariableConverter().RegisterMatcher( + resourceIDMatcher(), + // Users are never owned by an organization, so always return the empty string + // for the org owner. + sqltypes.StringVarMatcher("''", []string{"input", "object", "org_owner"}), + // Users never have an owner, and are only owned site wide. + sqltypes.StringVarMatcher("''", []string{"input", "object", "owner"}), + ) + matcher.RegisterMatcher( + // No ACLs on the user type + sqltypes.AlwaysFalse(groupACLMatcher(matcher)), + sqltypes.AlwaysFalse(userACLMatcher(matcher)), + ) + return matcher +} + // NoACLConverter should be used when the target SQL table does not contain // group or user ACL columns. func NoACLConverter() *sqltypes.VariableConverter { From 49e893b589738bf747da8e6de643a877ccc21472 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 13 Jul 2023 10:38:26 -0400 Subject: [PATCH 2/7] Remove GetAuthorizedUserFilter --- coderd/database/dbauthz/dbauthz.go | 15 +-- coderd/database/dbauthz/dbauthz_test.go | 14 ++- coderd/database/dbfake/dbfake.go | 128 +++++++++--------------- coderd/database/dbmetrics/dbmetrics.go | 7 -- coderd/database/dbmock/dbmock.go | 15 --- coderd/database/modelqueries.go | 25 ----- coderd/database/queries.sql.go | 2 - 7 files changed, 61 insertions(+), 145 deletions(-) diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 00919179c0c26..afd6e7ee0babd 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -615,12 +615,9 @@ func (q *querier) GetTemplateUserRoles(ctx context.Context, id uuid.UUID) ([]dat return q.db.GetTemplateUserRoles(ctx, id) } -func (q *querier) GetAuthorizedUserCount(ctx context.Context, arg database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) { - return q.db.GetAuthorizedUserCount(ctx, arg, prepared) -} - func (q *querier) GetUsersWithCount(ctx context.Context, arg database.GetUsersParams) ([]database.User, int64, error) { - rowUsers, err := q.db.GetUsers(ctx, arg) + // q.GetUsers only returns authorized users + rowUsers, err := q.GetUsers(ctx, arg) if err != nil { return nil, -1, err } @@ -939,12 +936,10 @@ func (q *querier) GetFileTemplates(ctx context.Context, fileID uuid.UUID) ([]dat } func (q *querier) GetFilteredUserCount(ctx context.Context, arg database.GetFilteredUserCountParams) (int64, error) { - prep, err := prepareSQLFilter(ctx, q.auth, rbac.ActionRead, rbac.ResourceUser.Type) - if err != nil { - return -1, xerrors.Errorf("(dev error) prepare sql filter: %w", err) + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { + return -1, err } - // TODO: This should be the only implementation. - return q.GetAuthorizedUserCount(ctx, arg, prep) + return q.db.GetFilteredUserCount(ctx, arg) } func (q *querier) GetGitAuthLink(ctx context.Context, arg database.GetGitAuthLinkParams) (database.GitAuthLink, error) { diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index bde4a1dfd5ef4..2083ab65a371c 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -869,19 +869,17 @@ func (s *MethodTestSuite) TestUser() { Asserts(a, rbac.ActionRead, b, rbac.ActionRead). Returns(slice.New(a, b)) })) - s.Run("GetAuthorizedUserCount", s.Subtest(func(db database.Store, check *expects) { - _ = dbgen.User(s.T(), db, database.User{}) - check.Args(database.GetFilteredUserCountParams{}, emptyPreparedAuthorized{}).Asserts().Returns(int64(1)) - })) s.Run("GetFilteredUserCount", s.Subtest(func(db database.Store, check *expects) { _ = dbgen.User(s.T(), db, database.User{}) - check.Args(database.GetFilteredUserCountParams{}).Asserts().Returns(int64(1)) + check.Args(database.GetFilteredUserCountParams{}).Asserts( + rbac.ResourceSystem, rbac.ActionRead).Returns(int64(1)) })) s.Run("GetUsers", s.Subtest(func(db database.Store, check *expects) { - a := dbgen.User(s.T(), db, database.User{Username: "GetUsers-a-user"}) - b := dbgen.User(s.T(), db, database.User{Username: "GetUsers-b-user"}) + dbgen.User(s.T(), db, database.User{Username: "GetUsers-a-user"}) + dbgen.User(s.T(), db, database.User{Username: "GetUsers-b-user"}) check.Args(database.GetUsersParams{}). - Asserts(a, rbac.ActionRead, b, rbac.ActionRead) + // Asserts are done in a SQL filter + Asserts() })) s.Run("GetUsersWithCount", s.Subtest(func(db database.Store, check *expects) { a := dbgen.User(s.T(), db, database.User{Username: "GetUsersWithCount-a-user"}) diff --git a/coderd/database/dbfake/dbfake.go b/coderd/database/dbfake/dbfake.go index 6c66b5ee6a614..03969f7b7a909 100644 --- a/coderd/database/dbfake/dbfake.go +++ b/coderd/database/dbfake/dbfake.go @@ -267,80 +267,6 @@ func (q *FakeQuerier) getUserByIDNoLock(id uuid.UUID) (database.User, error) { return database.User{}, sql.ErrNoRows } -func (q *FakeQuerier) GetAuthorizedUserCount(ctx context.Context, params database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) { - if err := validateDatabaseType(params); err != nil { - return 0, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - // Call this to match the same function calls as the SQL implementation. - if prepared != nil { - _, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL()) - if err != nil { - return -1, err - } - } - - users := make([]database.User, 0, len(q.users)) - - for _, user := range q.users { - // If the filter exists, ensure the object is authorized. - if prepared != nil && prepared.Authorize(ctx, user.RBACObject()) != nil { - continue - } - - users = append(users, user) - } - - // Filter out deleted since they should never be returned.. - tmp := make([]database.User, 0, len(users)) - for _, user := range users { - if !user.Deleted { - tmp = append(tmp, user) - } - } - users = tmp - - if params.Search != "" { - tmp := make([]database.User, 0, len(users)) - for i, user := range users { - if strings.Contains(strings.ToLower(user.Email), strings.ToLower(params.Search)) { - tmp = append(tmp, users[i]) - } else if strings.Contains(strings.ToLower(user.Username), strings.ToLower(params.Search)) { - tmp = append(tmp, users[i]) - } - } - users = tmp - } - - if len(params.Status) > 0 { - usersFilteredByStatus := make([]database.User, 0, len(users)) - for i, user := range users { - if slice.ContainsCompare(params.Status, user.Status, func(a, b database.UserStatus) bool { - return strings.EqualFold(string(a), string(b)) - }) { - usersFilteredByStatus = append(usersFilteredByStatus, users[i]) - } - } - users = usersFilteredByStatus - } - - if len(params.RbacRole) > 0 && !slice.Contains(params.RbacRole, rbac.RoleMember()) { - usersFilteredByRole := make([]database.User, 0, len(users)) - for i, user := range users { - if slice.OverlapCompare(params.RbacRole, user.RBACRoles, strings.EqualFold) { - usersFilteredByRole = append(usersFilteredByRole, users[i]) - } - } - - users = usersFilteredByRole - } - - return int64(len(users)), nil -} - func convertUsers(users []database.User, count int64) []database.GetUsersRow { rows := make([]database.GetUsersRow, len(users)) for i, u := range users { @@ -1673,12 +1599,58 @@ func (q *FakeQuerier) GetFileTemplates(_ context.Context, id uuid.UUID) ([]datab return rows, nil } -func (q *FakeQuerier) GetFilteredUserCount(ctx context.Context, arg database.GetFilteredUserCountParams) (int64, error) { - if err := validateDatabaseType(arg); err != nil { +func (q *FakeQuerier) GetFilteredUserCount(ctx context.Context, params database.GetFilteredUserCountParams) (int64, error) { + if err := validateDatabaseType(params); err != nil { return 0, err } - count, err := q.GetAuthorizedUserCount(ctx, arg, nil) - return count, err + + q.mutex.RLock() + defer q.mutex.RUnlock() + + // Filter out deleted since they should never be returned.. + users := make([]database.User, 0, len(q.users)) + for _, user := range q.users { + if !user.Deleted { + users = append(users, user) + } + } + + if params.Search != "" { + tmp := make([]database.User, 0, len(users)) + for i, user := range users { + if strings.Contains(strings.ToLower(user.Email), strings.ToLower(params.Search)) { + tmp = append(tmp, users[i]) + } else if strings.Contains(strings.ToLower(user.Username), strings.ToLower(params.Search)) { + tmp = append(tmp, users[i]) + } + } + users = tmp + } + + if len(params.Status) > 0 { + usersFilteredByStatus := make([]database.User, 0, len(users)) + for i, user := range users { + if slice.ContainsCompare(params.Status, user.Status, func(a, b database.UserStatus) bool { + return strings.EqualFold(string(a), string(b)) + }) { + usersFilteredByStatus = append(usersFilteredByStatus, users[i]) + } + } + users = usersFilteredByStatus + } + + if len(params.RbacRole) > 0 && !slice.Contains(params.RbacRole, rbac.RoleMember()) { + usersFilteredByRole := make([]database.User, 0, len(users)) + for i, user := range users { + if slice.OverlapCompare(params.RbacRole, user.RBACRoles, strings.EqualFold) { + usersFilteredByRole = append(usersFilteredByRole, users[i]) + } + } + + users = usersFilteredByRole + } + + return int64(len(users)), nil } func (q *FakeQuerier) GetGitAuthLink(_ context.Context, arg database.GetGitAuthLinkParams) (database.GitAuthLink, error) { diff --git a/coderd/database/dbmetrics/dbmetrics.go b/coderd/database/dbmetrics/dbmetrics.go index 17da69ebfcbda..76dcab05d91f9 100644 --- a/coderd/database/dbmetrics/dbmetrics.go +++ b/coderd/database/dbmetrics/dbmetrics.go @@ -101,13 +101,6 @@ func (m metricsStore) GetAuthorizedWorkspaces(ctx context.Context, arg database. return workspaces, err } -func (m metricsStore) GetAuthorizedUserCount(ctx context.Context, arg database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) { - start := time.Now() - count, err := m.s.GetAuthorizedUserCount(ctx, arg, prepared) - m.queryLatencies.WithLabelValues("GetAuthorizedUserCount").Observe(time.Since(start).Seconds()) - return count, err -} - func (m metricsStore) AcquireLock(ctx context.Context, pgAdvisoryXactLock int64) error { start := time.Now() err := m.s.AcquireLock(ctx, pgAdvisoryXactLock) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 9d635bd77e0e4..f4d9c2f296c68 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -431,21 +431,6 @@ func (mr *MockStoreMockRecorder) GetAuthorizedTemplates(arg0, arg1, arg2 interfa return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedTemplates", reflect.TypeOf((*MockStore)(nil).GetAuthorizedTemplates), arg0, arg1, arg2) } -// GetAuthorizedUserCount mocks base method. -func (m *MockStore) GetAuthorizedUserCount(arg0 context.Context, arg1 database.GetFilteredUserCountParams, arg2 rbac.PreparedAuthorized) (int64, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetAuthorizedUserCount", arg0, arg1, arg2) - ret0, _ := ret[0].(int64) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetAuthorizedUserCount indicates an expected call of GetAuthorizedUserCount. -func (mr *MockStoreMockRecorder) GetAuthorizedUserCount(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedUserCount", reflect.TypeOf((*MockStore)(nil).GetAuthorizedUserCount), arg0, arg1, arg2) -} - // GetAuthorizedUsers mocks base method. func (m *MockStore) GetAuthorizedUsers(arg0 context.Context, arg1 database.GetUsersParams, arg2 rbac.PreparedAuthorized) ([]database.GetUsersRow, error) { m.ctrl.T.Helper() diff --git a/coderd/database/modelqueries.go b/coderd/database/modelqueries.go index 8f5a392a23f79..1eea90f56319b 100644 --- a/coderd/database/modelqueries.go +++ b/coderd/database/modelqueries.go @@ -255,7 +255,6 @@ func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspa } type userQuerier interface { - GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) GetAuthorizedUsers(ctx context.Context, arg GetUsersParams, prepared rbac.PreparedAuthorized) ([]GetUsersRow, error) } @@ -319,30 +318,6 @@ func (q *sqlQuerier) GetAuthorizedUsers(ctx context.Context, arg GetUsersParams, return items, nil } -func (q *sqlQuerier) GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) { - authorizedFilter, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{ - VariableConverter: regosql.UserConverter(), - }) - if err != nil { - return -1, xerrors.Errorf("compile authorized filter: %w", err) - } - - filtered, err := insertAuthorizedFilter(getFilteredUserCount, fmt.Sprintf(" AND %s", authorizedFilter)) - if err != nil { - return -1, xerrors.Errorf("insert authorized filter: %w", err) - } - - query := fmt.Sprintf("-- name: GetAuthorizedUserCount :one\n%s", filtered) - row := q.db.QueryRowContext(ctx, query, - arg.Search, - pq.Array(arg.Status), - pq.Array(arg.RbacRole), - ) - var count int64 - err = row.Scan(&count) - return count, err -} - func insertAuthorizedFilter(query string, replaceWith string) (string, error) { if !strings.Contains(query, authorizedQueryPlaceholder) { return "", xerrors.Errorf("query does not contain authorized replace string, this is not an authorized query") diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index e51891e7168f0..99068242b6487 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -5136,8 +5136,6 @@ WHERE THEN rbac_roles && $3 :: text[] ELSE true END - -- Authorize Filter clause will be injected below in GetAuthorizedUserCount - -- @authorize_filter ` type GetFilteredUserCountParams struct { From 60c70197b8d6543850cadfbb6d2b5eab52e47e9e Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 13 Jul 2023 10:47:32 -0400 Subject: [PATCH 3/7] Linting --- coderd/database/dbfake/dbfake.go | 2 +- coderd/database/modelqueries.go | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/coderd/database/dbfake/dbfake.go b/coderd/database/dbfake/dbfake.go index 03969f7b7a909..f6373a3a26abd 100644 --- a/coderd/database/dbfake/dbfake.go +++ b/coderd/database/dbfake/dbfake.go @@ -1599,7 +1599,7 @@ func (q *FakeQuerier) GetFileTemplates(_ context.Context, id uuid.UUID) ([]datab return rows, nil } -func (q *FakeQuerier) GetFilteredUserCount(ctx context.Context, params database.GetFilteredUserCountParams) (int64, error) { +func (q *FakeQuerier) GetFilteredUserCount(_ context.Context, params database.GetFilteredUserCountParams) (int64, error) { if err := validateDatabaseType(params); err != nil { return 0, err } diff --git a/coderd/database/modelqueries.go b/coderd/database/modelqueries.go index 1eea90f56319b..a7f186b668b0a 100644 --- a/coderd/database/modelqueries.go +++ b/coderd/database/modelqueries.go @@ -262,7 +262,6 @@ func (q *sqlQuerier) GetAuthorizedUsers(ctx context.Context, arg GetUsersParams, authorizedFilter, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{ VariableConverter: regosql.UserConverter(), }) - if err != nil { return nil, xerrors.Errorf("compile authorized filter: %w", err) } From 400916f5be1431fdc8f728d0ca9a3c890da46f67 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 13 Jul 2023 10:48:48 -0400 Subject: [PATCH 4/7] dbauthz fix assert in unit test --- coderd/database/dbauthz/dbauthz_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 2083ab65a371c..99bb7659996dc 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -882,9 +882,9 @@ func (s *MethodTestSuite) TestUser() { Asserts() })) s.Run("GetUsersWithCount", s.Subtest(func(db database.Store, check *expects) { - a := dbgen.User(s.T(), db, database.User{Username: "GetUsersWithCount-a-user"}) - b := dbgen.User(s.T(), db, database.User{Username: "GetUsersWithCount-b-user"}) - check.Args(database.GetUsersParams{}).Asserts(a, rbac.ActionRead, b, rbac.ActionRead) + _ = dbgen.User(s.T(), db, database.User{Username: "GetUsersWithCount-a-user"}) + _ = dbgen.User(s.T(), db, database.User{Username: "GetUsersWithCount-b-user"}) + check.Args(database.GetUsersParams{}).Asserts() })) s.Run("InsertUser", s.Subtest(func(db database.Store, check *expects) { check.Args(database.InsertUserParams{ From 946ae676681b283668a60be319f17b1a7779cfc6 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 13 Jul 2023 10:54:19 -0400 Subject: [PATCH 5/7] Make gen --- coderd/database/dbauthz/dbauthz.go | 1 - coderd/database/dbfake/dbfake.go | 70 +++++++++++++------------- coderd/database/dbmetrics/dbmetrics.go | 14 +++--- coderd/database/queries.sql.go | 2 + 4 files changed, 44 insertions(+), 43 deletions(-) diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index afd6e7ee0babd..b06f75eb05401 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -626,7 +626,6 @@ func (q *querier) GetUsersWithCount(ctx context.Context, arg database.GetUsersPa return []database.User{}, 0, nil } - // TODO: Is this correct? Should we return a restricted user? users := database.ConvertUserRows(rowUsers) return users, rowUsers[0].Count, nil } diff --git a/coderd/database/dbfake/dbfake.go b/coderd/database/dbfake/dbfake.go index f6373a3a26abd..1d448855a1083 100644 --- a/coderd/database/dbfake/dbfake.go +++ b/coderd/database/dbfake/dbfake.go @@ -926,6 +926,41 @@ func isNotNull(v interface{}) bool { // these methods remain unimplemented in the FakeQuerier. var ErrUnimplemented = xerrors.New("unimplemented") +func (q *FakeQuerier) GetAuthorizedUsers(ctx context.Context, arg database.GetUsersParams, prepared rbac.PreparedAuthorized) ([]database.GetUsersRow, error) { + if err := validateDatabaseType(arg); err != nil { + return nil, err + } + + // Call this to match the same function calls as the SQL implementation. + if prepared != nil { + _, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{ + VariableConverter: regosql.UserConverter(), + }) + if err != nil { + return nil, err + } + } + + users, err := q.GetUsers(ctx, arg) + if err != nil { + return nil, err + } + + q.mutex.RLock() + defer q.mutex.RUnlock() + + filteredUsers := make([]database.GetUsersRow, 0, len(users)) + for _, user := range users { + // If the filter exists, ensure the object is authorized. + if prepared != nil && prepared.Authorize(ctx, user.RBACObject()) != nil { + continue + } + + filteredUsers = append(filteredUsers, user) + } + return filteredUsers, nil +} + func (*FakeQuerier) AcquireLock(_ context.Context, _ int64) error { return xerrors.New("AcquireLock must only be called within a transaction") } @@ -5410,38 +5445,3 @@ func (*FakeQuerier) UpsertTailnetClient(context.Context, database.UpsertTailnetC func (*FakeQuerier) UpsertTailnetCoordinator(context.Context, uuid.UUID) (database.TailnetCoordinator, error) { return database.TailnetCoordinator{}, ErrUnimplemented } - -func (q *FakeQuerier) GetAuthorizedUsers(ctx context.Context, arg database.GetUsersParams, prepared rbac.PreparedAuthorized) ([]database.GetUsersRow, error) { - if err := validateDatabaseType(arg); err != nil { - return nil, err - } - - // Call this to match the same function calls as the SQL implementation. - if prepared != nil { - _, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{ - VariableConverter: regosql.UserConverter(), - }) - if err != nil { - return nil, err - } - } - - users, err := q.GetUsers(ctx, arg) - if err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - filteredUsers := make([]database.GetUsersRow, 0, len(users)) - for _, user := range users { - // If the filter exists, ensure the object is authorized. - if prepared != nil && prepared.Authorize(ctx, user.RBACObject()) != nil { - continue - } - - filteredUsers = append(filteredUsers, user) - } - return filteredUsers, nil -} diff --git a/coderd/database/dbmetrics/dbmetrics.go b/coderd/database/dbmetrics/dbmetrics.go index 76dcab05d91f9..7ea35bf0b8d20 100644 --- a/coderd/database/dbmetrics/dbmetrics.go +++ b/coderd/database/dbmetrics/dbmetrics.go @@ -101,6 +101,13 @@ func (m metricsStore) GetAuthorizedWorkspaces(ctx context.Context, arg database. return workspaces, err } +func (m metricsStore) GetAuthorizedUsers(ctx context.Context, arg database.GetUsersParams, prepared rbac.PreparedAuthorized) ([]database.GetUsersRow, error) { + start := time.Now() + r0, r1 := m.s.GetAuthorizedUsers(ctx, arg, prepared) + m.queryLatencies.WithLabelValues("GetAuthorizedUsers").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m metricsStore) AcquireLock(ctx context.Context, pgAdvisoryXactLock int64) error { start := time.Now() err := m.s.AcquireLock(ctx, pgAdvisoryXactLock) @@ -1632,10 +1639,3 @@ func (m metricsStore) UpsertTailnetCoordinator(ctx context.Context, id uuid.UUID defer m.queryLatencies.WithLabelValues("UpsertTailnetCoordinator").Observe(time.Since(start).Seconds()) return m.s.UpsertTailnetCoordinator(ctx, id) } - -func (m metricsStore) GetAuthorizedUsers(ctx context.Context, arg database.GetUsersParams, prepared rbac.PreparedAuthorized) ([]database.GetUsersRow, error) { - start := time.Now() - r0, r1 := m.s.GetAuthorizedUsers(ctx, arg, prepared) - m.queryLatencies.WithLabelValues("GetAuthorizedUsers").Observe(time.Since(start).Seconds()) - return r0, r1 -} diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 99068242b6487..e51891e7168f0 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -5136,6 +5136,8 @@ WHERE THEN rbac_roles && $3 :: text[] ELSE true END + -- Authorize Filter clause will be injected below in GetAuthorizedUserCount + -- @authorize_filter ` type GetFilteredUserCountParams struct { From 3b974b3ea7bab4a86412fe8e90bdcb3ed750699d Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 13 Jul 2023 12:55:17 -0400 Subject: [PATCH 6/7] Remove GetFilteredUserCount --- coderd/database/dbauthz/dbauthz.go | 7 ---- coderd/database/dbauthz/dbauthz_test.go | 5 --- coderd/database/dbfake/dbfake.go | 54 ------------------------- coderd/database/dbmetrics/dbmetrics.go | 7 ---- coderd/database/dbmock/dbmock.go | 15 ------- coderd/database/querier.go | 2 - coderd/database/queries.sql.go | 49 ---------------------- coderd/database/queries/users.sql | 36 ----------------- 8 files changed, 175 deletions(-) diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index b06f75eb05401..e8da8919692f5 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -934,13 +934,6 @@ func (q *querier) GetFileTemplates(ctx context.Context, fileID uuid.UUID) ([]dat return q.db.GetFileTemplates(ctx, fileID) } -func (q *querier) GetFilteredUserCount(ctx context.Context, arg database.GetFilteredUserCountParams) (int64, error) { - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { - return -1, err - } - return q.db.GetFilteredUserCount(ctx, arg) -} - func (q *querier) GetGitAuthLink(ctx context.Context, arg database.GetGitAuthLinkParams) (database.GitAuthLink, error) { return fetch(q.log, q.auth, q.db.GetGitAuthLink)(ctx, arg) } diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 99bb7659996dc..c38a168093da5 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -869,11 +869,6 @@ func (s *MethodTestSuite) TestUser() { Asserts(a, rbac.ActionRead, b, rbac.ActionRead). Returns(slice.New(a, b)) })) - s.Run("GetFilteredUserCount", s.Subtest(func(db database.Store, check *expects) { - _ = dbgen.User(s.T(), db, database.User{}) - check.Args(database.GetFilteredUserCountParams{}).Asserts( - rbac.ResourceSystem, rbac.ActionRead).Returns(int64(1)) - })) s.Run("GetUsers", s.Subtest(func(db database.Store, check *expects) { dbgen.User(s.T(), db, database.User{Username: "GetUsers-a-user"}) dbgen.User(s.T(), db, database.User{Username: "GetUsers-b-user"}) diff --git a/coderd/database/dbfake/dbfake.go b/coderd/database/dbfake/dbfake.go index 1d448855a1083..184768e83bd6a 100644 --- a/coderd/database/dbfake/dbfake.go +++ b/coderd/database/dbfake/dbfake.go @@ -1634,60 +1634,6 @@ func (q *FakeQuerier) GetFileTemplates(_ context.Context, id uuid.UUID) ([]datab return rows, nil } -func (q *FakeQuerier) GetFilteredUserCount(_ context.Context, params database.GetFilteredUserCountParams) (int64, error) { - if err := validateDatabaseType(params); err != nil { - return 0, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - // Filter out deleted since they should never be returned.. - users := make([]database.User, 0, len(q.users)) - for _, user := range q.users { - if !user.Deleted { - users = append(users, user) - } - } - - if params.Search != "" { - tmp := make([]database.User, 0, len(users)) - for i, user := range users { - if strings.Contains(strings.ToLower(user.Email), strings.ToLower(params.Search)) { - tmp = append(tmp, users[i]) - } else if strings.Contains(strings.ToLower(user.Username), strings.ToLower(params.Search)) { - tmp = append(tmp, users[i]) - } - } - users = tmp - } - - if len(params.Status) > 0 { - usersFilteredByStatus := make([]database.User, 0, len(users)) - for i, user := range users { - if slice.ContainsCompare(params.Status, user.Status, func(a, b database.UserStatus) bool { - return strings.EqualFold(string(a), string(b)) - }) { - usersFilteredByStatus = append(usersFilteredByStatus, users[i]) - } - } - users = usersFilteredByStatus - } - - if len(params.RbacRole) > 0 && !slice.Contains(params.RbacRole, rbac.RoleMember()) { - usersFilteredByRole := make([]database.User, 0, len(users)) - for i, user := range users { - if slice.OverlapCompare(params.RbacRole, user.RBACRoles, strings.EqualFold) { - usersFilteredByRole = append(usersFilteredByRole, users[i]) - } - } - - users = usersFilteredByRole - } - - return int64(len(users)), nil -} - func (q *FakeQuerier) GetGitAuthLink(_ context.Context, arg database.GetGitAuthLinkParams) (database.GitAuthLink, error) { if err := validateDatabaseType(arg); err != nil { return database.GitAuthLink{}, err diff --git a/coderd/database/dbmetrics/dbmetrics.go b/coderd/database/dbmetrics/dbmetrics.go index 7ea35bf0b8d20..32e37ea557110 100644 --- a/coderd/database/dbmetrics/dbmetrics.go +++ b/coderd/database/dbmetrics/dbmetrics.go @@ -350,13 +350,6 @@ func (m metricsStore) GetFileTemplates(ctx context.Context, fileID uuid.UUID) ([ return rows, err } -func (m metricsStore) GetFilteredUserCount(ctx context.Context, arg database.GetFilteredUserCountParams) (int64, error) { - start := time.Now() - count, err := m.s.GetFilteredUserCount(ctx, arg) - m.queryLatencies.WithLabelValues("GetFilteredUserCount").Observe(time.Since(start).Seconds()) - return count, err -} - func (m metricsStore) GetGitAuthLink(ctx context.Context, arg database.GetGitAuthLinkParams) (database.GitAuthLink, error) { start := time.Now() link, err := m.s.GetGitAuthLink(ctx, arg) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index f4d9c2f296c68..f672a5e5dfc61 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -596,21 +596,6 @@ func (mr *MockStoreMockRecorder) GetFileTemplates(arg0, arg1 interface{}) *gomoc return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetFileTemplates", reflect.TypeOf((*MockStore)(nil).GetFileTemplates), arg0, arg1) } -// GetFilteredUserCount mocks base method. -func (m *MockStore) GetFilteredUserCount(arg0 context.Context, arg1 database.GetFilteredUserCountParams) (int64, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetFilteredUserCount", arg0, arg1) - ret0, _ := ret[0].(int64) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetFilteredUserCount indicates an expected call of GetFilteredUserCount. -func (mr *MockStoreMockRecorder) GetFilteredUserCount(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetFilteredUserCount", reflect.TypeOf((*MockStore)(nil).GetFilteredUserCount), arg0, arg1) -} - // GetGitAuthLink mocks base method. func (m *MockStore) GetGitAuthLink(arg0 context.Context, arg1 database.GetGitAuthLinkParams) (database.GitAuthLink, error) { m.ctrl.T.Helper() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 6959823a01945..e0824b2b6396a 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -65,8 +65,6 @@ type sqlcQuerier interface { GetFileByID(ctx context.Context, id uuid.UUID) (File, error) // Get all templates that use a file. GetFileTemplates(ctx context.Context, fileID uuid.UUID) ([]GetFileTemplatesRow, error) - // This will never count deleted users. - GetFilteredUserCount(ctx context.Context, arg GetFilteredUserCountParams) (int64, error) GetGitAuthLink(ctx context.Context, arg GetGitAuthLinkParams) (GitAuthLink, error) GetGitSSHKey(ctx context.Context, userID uuid.UUID) (GitSSHKey, error) GetGroupByID(ctx context.Context, id uuid.UUID) (Group, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index e51891e7168f0..f7e89cab7c43f 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -5105,55 +5105,6 @@ func (q *sqlQuerier) GetAuthorizationUserRoles(ctx context.Context, userID uuid. return i, err } -const getFilteredUserCount = `-- name: GetFilteredUserCount :one -SELECT - COUNT(*) -FROM - users -WHERE - users.deleted = false - -- Start filters - -- Filter by name, email or username - AND CASE - WHEN $1 :: text != '' THEN ( - email ILIKE concat('%', $1, '%') - OR username ILIKE concat('%', $1, '%') - ) - ELSE true - END - -- Filter by status - AND CASE - -- @status needs to be a text because it can be empty, If it was - -- user_status enum, it would not. - WHEN cardinality($2 :: user_status[]) > 0 THEN - status = ANY($2 :: user_status[]) - ELSE true - END - -- Filter by rbac_roles - AND CASE - -- @rbac_role allows filtering by rbac roles. If 'member' is included, show everyone, as everyone is a member. - WHEN cardinality($3 :: text[]) > 0 AND 'member' != ANY($3 :: text[]) - THEN rbac_roles && $3 :: text[] - ELSE true - END - -- Authorize Filter clause will be injected below in GetAuthorizedUserCount - -- @authorize_filter -` - -type GetFilteredUserCountParams struct { - Search string `db:"search" json:"search"` - Status []UserStatus `db:"status" json:"status"` - RbacRole []string `db:"rbac_role" json:"rbac_role"` -} - -// This will never count deleted users. -func (q *sqlQuerier) GetFilteredUserCount(ctx context.Context, arg GetFilteredUserCountParams) (int64, error) { - row := q.db.QueryRowContext(ctx, getFilteredUserCount, arg.Search, pq.Array(arg.Status), pq.Array(arg.RbacRole)) - var count int64 - err := row.Scan(&count) - return count, err -} - const getUserByEmailOrUsername = `-- name: GetUserByEmailOrUsername :one SELECT id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type, avatar_url, deleted, last_seen_at diff --git a/coderd/database/queries/users.sql b/coderd/database/queries/users.sql index 2115b2eda332e..cd5426e576ca5 100644 --- a/coderd/database/queries/users.sql +++ b/coderd/database/queries/users.sql @@ -56,42 +56,6 @@ FROM WHERE status = 'active'::user_status AND deleted = false; --- name: GetFilteredUserCount :one --- This will never count deleted users. -SELECT - COUNT(*) -FROM - users -WHERE - users.deleted = false - -- Start filters - -- Filter by name, email or username - AND CASE - WHEN @search :: text != '' THEN ( - email ILIKE concat('%', @search, '%') - OR username ILIKE concat('%', @search, '%') - ) - ELSE true - END - -- Filter by status - AND CASE - -- @status needs to be a text because it can be empty, If it was - -- user_status enum, it would not. - WHEN cardinality(@status :: user_status[]) > 0 THEN - status = ANY(@status :: user_status[]) - ELSE true - END - -- Filter by rbac_roles - AND CASE - -- @rbac_role allows filtering by rbac roles. If 'member' is included, show everyone, as everyone is a member. - WHEN cardinality(@rbac_role :: text[]) > 0 AND 'member' != ANY(@rbac_role :: text[]) - THEN rbac_roles && @rbac_role :: text[] - ELSE true - END - -- Authorize Filter clause will be injected below in GetAuthorizedUserCount - -- @authorize_filter -; - -- name: InsertUser :one INSERT INTO users ( From 37330977224f2685d6791eac12fe4b0e1f65c1ea Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Fri, 14 Jul 2023 09:47:42 -0400 Subject: [PATCH 7/7] remove GetUsersWithCount --- coderd/database/dbauthz/dbauthz.go | 15 --------------- coderd/database/dbauthz/dbauthz_test.go | 5 ----- 2 files changed, 20 deletions(-) diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 98fda94398a58..ef2b7b9d18b91 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -586,21 +586,6 @@ func (q *querier) SoftDeleteTemplateByID(ctx context.Context, id uuid.UUID) erro return deleteQ(q.log, q.auth, q.db.GetTemplateByID, deleteF)(ctx, id) } -func (q *querier) GetUsersWithCount(ctx context.Context, arg database.GetUsersParams) ([]database.User, int64, error) { - // q.GetUsers only returns authorized users - rowUsers, err := q.GetUsers(ctx, arg) - if err != nil { - return nil, -1, err - } - - if len(rowUsers) == 0 { - return []database.User{}, 0, nil - } - - users := database.ConvertUserRows(rowUsers) - return users, rowUsers[0].Count, nil -} - func (q *querier) SoftDeleteUserByID(ctx context.Context, id uuid.UUID) error { deleteF := func(ctx context.Context, id uuid.UUID) error { return q.db.UpdateUserDeletedByID(ctx, database.UpdateUserDeletedByIDParams{ diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index c38a168093da5..b9e42f632bd6c 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -876,11 +876,6 @@ func (s *MethodTestSuite) TestUser() { // Asserts are done in a SQL filter Asserts() })) - s.Run("GetUsersWithCount", s.Subtest(func(db database.Store, check *expects) { - _ = dbgen.User(s.T(), db, database.User{Username: "GetUsersWithCount-a-user"}) - _ = dbgen.User(s.T(), db, database.User{Username: "GetUsersWithCount-b-user"}) - check.Args(database.GetUsersParams{}).Asserts() - })) s.Run("InsertUser", s.Subtest(func(db database.Store, check *expects) { check.Args(database.InsertUserParams{ ID: uuid.New(),