diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 8bfede13e8c72..ef2b7b9d18b91 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -586,32 +586,6 @@ func (q *querier) SoftDeleteTemplateByID(ctx context.Context, id uuid.UUID) erro return deleteQ(q.log, q.auth, q.db.GetTemplateByID, deleteF)(ctx, id) } -func (q *querier) GetUsersWithCount(ctx context.Context, arg database.GetUsersParams) ([]database.User, int64, error) { - // TODO Implement this with a SQL filter. The count is incorrect without it. - rowUsers, err := q.db.GetUsers(ctx, arg) - if err != nil { - return nil, -1, err - } - - if len(rowUsers) == 0 { - return []database.User{}, 0, nil - } - - act, ok := ActorFromContext(ctx) - if !ok { - return nil, -1, NoActorError - } - - // TODO: Is this correct? Should we return a restricted user? - users := database.ConvertUserRows(rowUsers) - users, err = rbac.Filter(ctx, q.auth, act, rbac.ActionRead, users) - if err != nil { - return nil, -1, err - } - - return users, rowUsers[0].Count, nil -} - func (q *querier) SoftDeleteUserByID(ctx context.Context, id uuid.UUID) error { deleteF := func(ctx context.Context, id uuid.UUID) error { return q.db.UpdateUserDeletedByID(ctx, database.UpdateUserDeletedByIDParams{ @@ -904,15 +878,6 @@ func (q *querier) GetFileTemplates(ctx context.Context, fileID uuid.UUID) ([]dat return q.db.GetFileTemplates(ctx, fileID) } -func (q *querier) GetFilteredUserCount(ctx context.Context, arg database.GetFilteredUserCountParams) (int64, error) { - prep, err := prepareSQLFilter(ctx, q.auth, rbac.ActionRead, rbac.ResourceUser.Type) - if err != nil { - return -1, xerrors.Errorf("(dev error) prepare sql filter: %w", err) - } - // TODO: This should be the only implementation. - return q.GetAuthorizedUserCount(ctx, arg, prep) -} - func (q *querier) GetGitAuthLink(ctx context.Context, arg database.GetGitAuthLinkParams) (database.GitAuthLink, error) { return fetch(q.log, q.auth, q.db.GetGitAuthLink)(ctx, arg) } @@ -1389,8 +1354,12 @@ func (q *querier) GetUserLinkByUserIDLoginType(ctx context.Context, arg database } func (q *querier) GetUsers(ctx context.Context, arg database.GetUsersParams) ([]database.GetUsersRow, error) { - // TODO: We should use GetUsersWithCount with a better method signature. - return fetchWithPostFilter(q.auth, q.db.GetUsers)(ctx, arg) + // This does the filtering in SQL. + prep, err := prepareSQLFilter(ctx, q.auth, rbac.ActionRead, rbac.ResourceUser.Type) + if err != nil { + return nil, xerrors.Errorf("(dev error) prepare sql filter: %w", err) + } + return q.db.GetAuthorizedUsers(ctx, arg, prep) } // GetUsersByIDs is only used for usernames on workspace return data. @@ -2639,6 +2608,9 @@ func (q *querier) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetW return q.GetWorkspaces(ctx, arg) } -func (q *querier) GetAuthorizedUserCount(ctx context.Context, arg database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) { - return q.db.GetAuthorizedUserCount(ctx, arg, prepared) +// GetAuthorizedUsers is not required for dbauthz since GetUsers is already +// authenticated. +func (q *querier) GetAuthorizedUsers(ctx context.Context, arg database.GetUsersParams, _ rbac.PreparedAuthorized) ([]database.GetUsersRow, error) { + // GetUsers is authenticated. + return q.GetUsers(ctx, arg) } diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index bde4a1dfd5ef4..b9e42f632bd6c 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -869,24 +869,12 @@ func (s *MethodTestSuite) TestUser() { Asserts(a, rbac.ActionRead, b, rbac.ActionRead). Returns(slice.New(a, b)) })) - s.Run("GetAuthorizedUserCount", s.Subtest(func(db database.Store, check *expects) { - _ = dbgen.User(s.T(), db, database.User{}) - check.Args(database.GetFilteredUserCountParams{}, emptyPreparedAuthorized{}).Asserts().Returns(int64(1)) - })) - s.Run("GetFilteredUserCount", s.Subtest(func(db database.Store, check *expects) { - _ = dbgen.User(s.T(), db, database.User{}) - check.Args(database.GetFilteredUserCountParams{}).Asserts().Returns(int64(1)) - })) s.Run("GetUsers", s.Subtest(func(db database.Store, check *expects) { - a := dbgen.User(s.T(), db, database.User{Username: "GetUsers-a-user"}) - b := dbgen.User(s.T(), db, database.User{Username: "GetUsers-b-user"}) + dbgen.User(s.T(), db, database.User{Username: "GetUsers-a-user"}) + dbgen.User(s.T(), db, database.User{Username: "GetUsers-b-user"}) check.Args(database.GetUsersParams{}). - Asserts(a, rbac.ActionRead, b, rbac.ActionRead) - })) - s.Run("GetUsersWithCount", s.Subtest(func(db database.Store, check *expects) { - a := dbgen.User(s.T(), db, database.User{Username: "GetUsersWithCount-a-user"}) - b := dbgen.User(s.T(), db, database.User{Username: "GetUsersWithCount-b-user"}) - check.Args(database.GetUsersParams{}).Asserts(a, rbac.ActionRead, b, rbac.ActionRead) + // Asserts are done in a SQL filter + Asserts() })) s.Run("InsertUser", s.Subtest(func(db database.Store, check *expects) { check.Args(database.InsertUserParams{ diff --git a/coderd/database/dbfake/dbfake.go b/coderd/database/dbfake/dbfake.go index 02e70639cd161..d68d0d08af199 100644 --- a/coderd/database/dbfake/dbfake.go +++ b/coderd/database/dbfake/dbfake.go @@ -23,6 +23,7 @@ import ( "github.com/coder/coder/coderd/database/db2sdk" "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/coderd/rbac" + "github.com/coder/coder/coderd/rbac/regosql" "github.com/coder/coder/coderd/util/slice" "github.com/coder/coder/codersdk" ) @@ -1207,14 +1208,6 @@ func (q *FakeQuerier) GetFileTemplates(_ context.Context, id uuid.UUID) ([]datab return rows, nil } -func (q *FakeQuerier) GetFilteredUserCount(ctx context.Context, arg database.GetFilteredUserCountParams) (int64, error) { - if err := validateDatabaseType(arg); err != nil { - return 0, err - } - count, err := q.GetAuthorizedUserCount(ctx, arg, nil) - return count, err -} - func (q *FakeQuerier) GetGitAuthLink(_ context.Context, arg database.GetGitAuthLinkParams) (database.GitAuthLink, error) { if err := validateDatabaseType(arg); err != nil { return database.GitAuthLink{}, err @@ -5365,76 +5358,37 @@ func (q *FakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database. return q.convertToWorkspaceRowsNoLock(ctx, workspaces, int64(beforePageCount)), nil } -func (q *FakeQuerier) GetAuthorizedUserCount(ctx context.Context, params database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) { - if err := validateDatabaseType(params); err != nil { - return 0, err +func (q *FakeQuerier) GetAuthorizedUsers(ctx context.Context, arg database.GetUsersParams, prepared rbac.PreparedAuthorized) ([]database.GetUsersRow, error) { + if err := validateDatabaseType(arg); err != nil { + return nil, err } - q.mutex.RLock() - defer q.mutex.RUnlock() - // Call this to match the same function calls as the SQL implementation. if prepared != nil { - _, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL()) + _, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{ + VariableConverter: regosql.UserConverter(), + }) if err != nil { - return -1, err + return nil, err } } - users := make([]database.User, 0, len(q.users)) + users, err := q.GetUsers(ctx, arg) + if err != nil { + return nil, err + } - for _, user := range q.users { + q.mutex.RLock() + defer q.mutex.RUnlock() + + filteredUsers := make([]database.GetUsersRow, 0, len(users)) + for _, user := range users { // If the filter exists, ensure the object is authorized. if prepared != nil && prepared.Authorize(ctx, user.RBACObject()) != nil { continue } - users = append(users, user) - } - - // Filter out deleted since they should never be returned.. - tmp := make([]database.User, 0, len(users)) - for _, user := range users { - if !user.Deleted { - tmp = append(tmp, user) - } - } - users = tmp - - if params.Search != "" { - tmp := make([]database.User, 0, len(users)) - for i, user := range users { - if strings.Contains(strings.ToLower(user.Email), strings.ToLower(params.Search)) { - tmp = append(tmp, users[i]) - } else if strings.Contains(strings.ToLower(user.Username), strings.ToLower(params.Search)) { - tmp = append(tmp, users[i]) - } - } - users = tmp - } - - if len(params.Status) > 0 { - usersFilteredByStatus := make([]database.User, 0, len(users)) - for i, user := range users { - if slice.ContainsCompare(params.Status, user.Status, func(a, b database.UserStatus) bool { - return strings.EqualFold(string(a), string(b)) - }) { - usersFilteredByStatus = append(usersFilteredByStatus, users[i]) - } - } - users = usersFilteredByStatus - } - - if len(params.RbacRole) > 0 && !slice.Contains(params.RbacRole, rbac.RoleMember()) { - usersFilteredByRole := make([]database.User, 0, len(users)) - for i, user := range users { - if slice.OverlapCompare(params.RbacRole, user.RBACRoles, strings.EqualFold) { - usersFilteredByRole = append(usersFilteredByRole, users[i]) - } - } - - users = usersFilteredByRole + filteredUsers = append(filteredUsers, user) } - - return int64(len(users)), nil + return filteredUsers, nil } diff --git a/coderd/database/dbmetrics/dbmetrics.go b/coderd/database/dbmetrics/dbmetrics.go index 11d857a30262b..12f03ea8c75fd 100644 --- a/coderd/database/dbmetrics/dbmetrics.go +++ b/coderd/database/dbmetrics/dbmetrics.go @@ -321,13 +321,6 @@ func (m metricsStore) GetFileTemplates(ctx context.Context, fileID uuid.UUID) ([ return rows, err } -func (m metricsStore) GetFilteredUserCount(ctx context.Context, arg database.GetFilteredUserCountParams) (int64, error) { - start := time.Now() - count, err := m.s.GetFilteredUserCount(ctx, arg) - m.queryLatencies.WithLabelValues("GetFilteredUserCount").Observe(time.Since(start).Seconds()) - return count, err -} - func (m metricsStore) GetGitAuthLink(ctx context.Context, arg database.GetGitAuthLinkParams) (database.GitAuthLink, error) { start := time.Now() link, err := m.s.GetGitAuthLink(ctx, arg) @@ -1639,9 +1632,9 @@ func (m metricsStore) GetAuthorizedWorkspaces(ctx context.Context, arg database. return workspaces, err } -func (m metricsStore) GetAuthorizedUserCount(ctx context.Context, arg database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) { +func (m metricsStore) GetAuthorizedUsers(ctx context.Context, arg database.GetUsersParams, prepared rbac.PreparedAuthorized) ([]database.GetUsersRow, error) { start := time.Now() - count, err := m.s.GetAuthorizedUserCount(ctx, arg, prepared) - m.queryLatencies.WithLabelValues("GetAuthorizedUserCount").Observe(time.Since(start).Seconds()) - return count, err + r0, r1 := m.s.GetAuthorizedUsers(ctx, arg, prepared) + m.queryLatencies.WithLabelValues("GetAuthorizedUsers").Observe(time.Since(start).Seconds()) + return r0, r1 } diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index deab31927154f..f672a5e5dfc61 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -431,19 +431,19 @@ func (mr *MockStoreMockRecorder) GetAuthorizedTemplates(arg0, arg1, arg2 interfa return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedTemplates", reflect.TypeOf((*MockStore)(nil).GetAuthorizedTemplates), arg0, arg1, arg2) } -// GetAuthorizedUserCount mocks base method. -func (m *MockStore) GetAuthorizedUserCount(arg0 context.Context, arg1 database.GetFilteredUserCountParams, arg2 rbac.PreparedAuthorized) (int64, error) { +// GetAuthorizedUsers mocks base method. +func (m *MockStore) GetAuthorizedUsers(arg0 context.Context, arg1 database.GetUsersParams, arg2 rbac.PreparedAuthorized) ([]database.GetUsersRow, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetAuthorizedUserCount", arg0, arg1, arg2) - ret0, _ := ret[0].(int64) + ret := m.ctrl.Call(m, "GetAuthorizedUsers", arg0, arg1, arg2) + ret0, _ := ret[0].([]database.GetUsersRow) ret1, _ := ret[1].(error) return ret0, ret1 } -// GetAuthorizedUserCount indicates an expected call of GetAuthorizedUserCount. -func (mr *MockStoreMockRecorder) GetAuthorizedUserCount(arg0, arg1, arg2 interface{}) *gomock.Call { +// GetAuthorizedUsers indicates an expected call of GetAuthorizedUsers. +func (mr *MockStoreMockRecorder) GetAuthorizedUsers(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedUserCount", reflect.TypeOf((*MockStore)(nil).GetAuthorizedUserCount), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedUsers", reflect.TypeOf((*MockStore)(nil).GetAuthorizedUsers), arg0, arg1, arg2) } // GetAuthorizedWorkspaces mocks base method. @@ -596,21 +596,6 @@ func (mr *MockStoreMockRecorder) GetFileTemplates(arg0, arg1 interface{}) *gomoc return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetFileTemplates", reflect.TypeOf((*MockStore)(nil).GetFileTemplates), arg0, arg1) } -// GetFilteredUserCount mocks base method. -func (m *MockStore) GetFilteredUserCount(arg0 context.Context, arg1 database.GetFilteredUserCountParams) (int64, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetFilteredUserCount", arg0, arg1) - ret0, _ := ret[0].(int64) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetFilteredUserCount indicates an expected call of GetFilteredUserCount. -func (mr *MockStoreMockRecorder) GetFilteredUserCount(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetFilteredUserCount", reflect.TypeOf((*MockStore)(nil).GetFilteredUserCount), arg0, arg1) -} - // GetGitAuthLink mocks base method. func (m *MockStore) GetGitAuthLink(arg0 context.Context, arg1 database.GetGitAuthLinkParams) (database.GitAuthLink, error) { m.ctrl.T.Helper() diff --git a/coderd/database/modelqueries.go b/coderd/database/modelqueries.go index 28a56b825f34e..a7f186b668b0a 100644 --- a/coderd/database/modelqueries.go +++ b/coderd/database/modelqueries.go @@ -255,29 +255,66 @@ func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspa } type userQuerier interface { - GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) + GetAuthorizedUsers(ctx context.Context, arg GetUsersParams, prepared rbac.PreparedAuthorized) ([]GetUsersRow, error) } -func (q *sqlQuerier) GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) { - authorizedFilter, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL()) +func (q *sqlQuerier) GetAuthorizedUsers(ctx context.Context, arg GetUsersParams, prepared rbac.PreparedAuthorized) ([]GetUsersRow, error) { + authorizedFilter, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{ + VariableConverter: regosql.UserConverter(), + }) if err != nil { - return -1, xerrors.Errorf("compile authorized filter: %w", err) + return nil, xerrors.Errorf("compile authorized filter: %w", err) } - filtered, err := insertAuthorizedFilter(getFilteredUserCount, fmt.Sprintf(" AND %s", authorizedFilter)) + filtered, err := insertAuthorizedFilter(getUsers, fmt.Sprintf(" AND %s", authorizedFilter)) if err != nil { - return -1, xerrors.Errorf("insert authorized filter: %w", err) + return nil, xerrors.Errorf("insert authorized filter: %w", err) } - query := fmt.Sprintf("-- name: GetAuthorizedUserCount :one\n%s", filtered) - row := q.db.QueryRowContext(ctx, query, + query := fmt.Sprintf("-- name: GetAuthorizedUsers :many\n%s", filtered) + rows, err := q.db.QueryContext(ctx, query, + arg.AfterID, arg.Search, pq.Array(arg.Status), pq.Array(arg.RbacRole), + arg.LastSeenBefore, + arg.LastSeenAfter, + arg.OffsetOpt, + arg.LimitOpt, ) - var count int64 - err = row.Scan(&count) - return count, err + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetUsersRow + for rows.Next() { + var i GetUsersRow + if err := rows.Scan( + &i.ID, + &i.Email, + &i.Username, + &i.HashedPassword, + &i.CreatedAt, + &i.UpdatedAt, + &i.Status, + &i.RBACRoles, + &i.LoginType, + &i.AvatarURL, + &i.Deleted, + &i.LastSeenAt, + &i.Count, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil } func insertAuthorizedFilter(query string, replaceWith string) (string, error) { diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 6959823a01945..e0824b2b6396a 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -65,8 +65,6 @@ type sqlcQuerier interface { GetFileByID(ctx context.Context, id uuid.UUID) (File, error) // Get all templates that use a file. GetFileTemplates(ctx context.Context, fileID uuid.UUID) ([]GetFileTemplatesRow, error) - // This will never count deleted users. - GetFilteredUserCount(ctx context.Context, arg GetFilteredUserCountParams) (int64, error) GetGitAuthLink(ctx context.Context, arg GetGitAuthLinkParams) (GitAuthLink, error) GetGitSSHKey(ctx context.Context, userID uuid.UUID) (GitSSHKey, error) GetGroupByID(ctx context.Context, id uuid.UUID) (Group, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 1cb61174b2913..b369372635792 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -5108,55 +5108,6 @@ func (q *sqlQuerier) GetAuthorizationUserRoles(ctx context.Context, userID uuid. return i, err } -const getFilteredUserCount = `-- name: GetFilteredUserCount :one -SELECT - COUNT(*) -FROM - users -WHERE - users.deleted = false - -- Start filters - -- Filter by name, email or username - AND CASE - WHEN $1 :: text != '' THEN ( - email ILIKE concat('%', $1, '%') - OR username ILIKE concat('%', $1, '%') - ) - ELSE true - END - -- Filter by status - AND CASE - -- @status needs to be a text because it can be empty, If it was - -- user_status enum, it would not. - WHEN cardinality($2 :: user_status[]) > 0 THEN - status = ANY($2 :: user_status[]) - ELSE true - END - -- Filter by rbac_roles - AND CASE - -- @rbac_role allows filtering by rbac roles. If 'member' is included, show everyone, as everyone is a member. - WHEN cardinality($3 :: text[]) > 0 AND 'member' != ANY($3 :: text[]) - THEN rbac_roles && $3 :: text[] - ELSE true - END - -- Authorize Filter clause will be injected below in GetAuthorizedUserCount - -- @authorize_filter -` - -type GetFilteredUserCountParams struct { - Search string `db:"search" json:"search"` - Status []UserStatus `db:"status" json:"status"` - RbacRole []string `db:"rbac_role" json:"rbac_role"` -} - -// This will never count deleted users. -func (q *sqlQuerier) GetFilteredUserCount(ctx context.Context, arg GetFilteredUserCountParams) (int64, error) { - row := q.db.QueryRowContext(ctx, getFilteredUserCount, arg.Search, pq.Array(arg.Status), pq.Array(arg.RbacRole)) - var count int64 - err := row.Scan(&count) - return count, err -} - const getUserByEmailOrUsername = `-- name: GetUserByEmailOrUsername :one SELECT id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type, avatar_url, deleted, last_seen_at @@ -5304,6 +5255,9 @@ WHERE ELSE true END -- End of filters + + -- Authorize Filter clause will be injected below in GetAuthorizedUsers + -- @authorize_filter ORDER BY -- Deterministic and consistent ordering of all users. This is to ensure consistent pagination. LOWER(username) ASC OFFSET $7 diff --git a/coderd/database/queries/users.sql b/coderd/database/queries/users.sql index 75cc85cdf90de..cd5426e576ca5 100644 --- a/coderd/database/queries/users.sql +++ b/coderd/database/queries/users.sql @@ -56,42 +56,6 @@ FROM WHERE status = 'active'::user_status AND deleted = false; --- name: GetFilteredUserCount :one --- This will never count deleted users. -SELECT - COUNT(*) -FROM - users -WHERE - users.deleted = false - -- Start filters - -- Filter by name, email or username - AND CASE - WHEN @search :: text != '' THEN ( - email ILIKE concat('%', @search, '%') - OR username ILIKE concat('%', @search, '%') - ) - ELSE true - END - -- Filter by status - AND CASE - -- @status needs to be a text because it can be empty, If it was - -- user_status enum, it would not. - WHEN cardinality(@status :: user_status[]) > 0 THEN - status = ANY(@status :: user_status[]) - ELSE true - END - -- Filter by rbac_roles - AND CASE - -- @rbac_role allows filtering by rbac roles. If 'member' is included, show everyone, as everyone is a member. - WHEN cardinality(@rbac_role :: text[]) > 0 AND 'member' != ANY(@rbac_role :: text[]) - THEN rbac_roles && @rbac_role :: text[] - ELSE true - END - -- Authorize Filter clause will be injected below in GetAuthorizedUserCount - -- @authorize_filter -; - -- name: InsertUser :one INSERT INTO users ( @@ -208,6 +172,9 @@ WHERE ELSE true END -- End of filters + + -- Authorize Filter clause will be injected below in GetAuthorizedUsers + -- @authorize_filter ORDER BY -- Deterministic and consistent ordering of all users. This is to ensure consistent pagination. LOWER(username) ASC OFFSET @offset_opt diff --git a/coderd/rbac/regosql/compile_test.go b/coderd/rbac/regosql/compile_test.go index 6c350b7834639..1997b279d6808 100644 --- a/coderd/rbac/regosql/compile_test.go +++ b/coderd/rbac/regosql/compile_test.go @@ -242,6 +242,26 @@ neq(input.object.owner, ""); p("false")), VariableConverter: regosql.TemplateConverter(), }, + { + Name: "UserNoOrgOwner", + Queries: []string{ + `input.object.org_owner != ""`, + }, + ExpectedSQL: p("'' != ''"), + VariableConverter: regosql.UserConverter(), + }, + { + Name: "UserOwnsSelf", + Queries: []string{ + `"10d03e62-7703-4df5-a358-4f76577d4e2f" = input.object.owner; + input.object.owner != ""; + input.object.org_owner = ""`, + }, + VariableConverter: regosql.UserConverter(), + ExpectedSQL: p( + p("'10d03e62-7703-4df5-a358-4f76577d4e2f' = ''") + " AND " + p("'' != ''") + " AND " + p("'' = ''"), + ), + }, } for _, tc := range testCases { diff --git a/coderd/rbac/regosql/configs.go b/coderd/rbac/regosql/configs.go index 475d317cd53ab..9f27809bf017c 100644 --- a/coderd/rbac/regosql/configs.go +++ b/coderd/rbac/regosql/configs.go @@ -36,6 +36,23 @@ func TemplateConverter() *sqltypes.VariableConverter { return matcher } +func UserConverter() *sqltypes.VariableConverter { + matcher := sqltypes.NewVariableConverter().RegisterMatcher( + resourceIDMatcher(), + // Users are never owned by an organization, so always return the empty string + // for the org owner. + sqltypes.StringVarMatcher("''", []string{"input", "object", "org_owner"}), + // Users never have an owner, and are only owned site wide. + sqltypes.StringVarMatcher("''", []string{"input", "object", "owner"}), + ) + matcher.RegisterMatcher( + // No ACLs on the user type + sqltypes.AlwaysFalse(groupACLMatcher(matcher)), + sqltypes.AlwaysFalse(userACLMatcher(matcher)), + ) + return matcher +} + // NoACLConverter should be used when the target SQL table does not contain // group or user ACL columns. func NoACLConverter() *sqltypes.VariableConverter {