Skip to content

chore: add prebuilds system user #16916

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 28 commits into from
Mar 25, 2025
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
300e80f
add prebuilds system user database changes and associated changes
SasSwart Mar 12, 2025
b788237
optionally prevent system users from counting to user count
dannykopping Mar 13, 2025
8122595
appease the linter
dannykopping Mar 13, 2025
bfb7c28
add unit test for system user behaviour
dannykopping Mar 13, 2025
6639167
reverting RBAC changes; not relevant here
dannykopping Mar 13, 2025
769ae1d
removing unnecessary changes
dannykopping Mar 13, 2025
e7e9c27
exclude system user db tests from non-linux OSs
dannykopping Mar 13, 2025
3936047
Rename prebuild system user reference
SasSwart Mar 17, 2025
8bdcafb
ensure that users.IsSystem is not nullable
SasSwart Mar 17, 2025
324fde2
Fixes
dannykopping Mar 17, 2025
81d9dfa
Merge remote-tracking branch 'origin/main' into prebuilds-system-user
SasSwart Mar 18, 2025
896c881
renumber migrations
SasSwart Mar 18, 2025
de4fb8a
ensure that system users are filtered and returned consistently
SasSwart Mar 19, 2025
2751d5b
make -B lint
SasSwart Mar 19, 2025
1042c39
rewrite prebuilds system user tests in our usual style
SasSwart Mar 19, 2025
f9e9d11
add support for prebuilds user to dbmem
SasSwart Mar 19, 2025
7492965
appease the linter
SasSwart Mar 19, 2025
29e2020
add support for the prebuilds system user to dbmem
SasSwart Mar 19, 2025
8c51585
linter
SasSwart Mar 19, 2025
cdc5c71
fix dbmem tests
SasSwart Mar 19, 2025
0d4813a
remove restriction on modifying system users for now
SasSwart Mar 19, 2025
95d70a3
remove system user index
SasSwart Mar 20, 2025
8f1d71c
Merge remote-tracking branch 'origin/main' into prebuilds-system-user
SasSwart Mar 24, 2025
7e009e5
invert tests that check for system user update protection
SasSwart Mar 24, 2025
addd7c6
lint
SasSwart Mar 24, 2025
7a4ef24
Allow TestUpdateSystemUser to run against dbmem
SasSwart Mar 24, 2025
f30ce72
Merge remote-tracking branch 'origin/main' into prebuilds-system-user
SasSwart Mar 25, 2025
5f0ae5e
Renumber migrations
SasSwart Mar 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cli/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1894,7 +1894,7 @@ func getGithubOAuth2ConfigParams(ctx context.Context, db database.Store, vals *c

if defaultEligibleNotSet {
// nolint:gocritic // User count requires system privileges
userCount, err := db.GetUserCount(dbauthz.AsSystemRestricted(ctx))
userCount, err := db.GetUserCount(dbauthz.AsSystemRestricted(ctx), false)
if err != nil {
return nil, xerrors.Errorf("get user count: %w", err)
}
Expand Down
33 changes: 19 additions & 14 deletions coderd/database/dbauthz/dbauthz.go
Original file line number Diff line number Diff line change
Expand Up @@ -1057,13 +1057,13 @@ func (q *querier) ActivityBumpWorkspace(ctx context.Context, arg database.Activi
return update(q.log, q.auth, fetch, q.db.ActivityBumpWorkspace)(ctx, arg)
}

func (q *querier) AllUserIDs(ctx context.Context) ([]uuid.UUID, error) {
func (q *querier) AllUserIDs(ctx context.Context, includeSystem bool) ([]uuid.UUID, error) {
// Although this technically only reads users, only system-related functions should be
// allowed to call this.
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
return nil, err
}
return q.db.AllUserIDs(ctx)
return q.db.AllUserIDs(ctx, includeSystem)
}

func (q *querier) ArchiveUnusedTemplateVersions(ctx context.Context, arg database.ArchiveUnusedTemplateVersionsParams) ([]uuid.UUID, error) {
Expand Down Expand Up @@ -1316,7 +1316,11 @@ func (q *querier) DeleteOldWorkspaceAgentStats(ctx context.Context) error {

func (q *querier) DeleteOrganizationMember(ctx context.Context, arg database.DeleteOrganizationMemberParams) error {
return deleteQ[database.OrganizationMember](q.log, q.auth, func(ctx context.Context, arg database.DeleteOrganizationMemberParams) (database.OrganizationMember, error) {
member, err := database.ExpectOne(q.OrganizationMembers(ctx, database.OrganizationMembersParams(arg)))
member, err := database.ExpectOne(q.OrganizationMembers(ctx, database.OrganizationMembersParams{
OrganizationID: arg.OrganizationID,
UserID: arg.UserID,
IncludeSystem: false,
}))
if err != nil {
return database.OrganizationMember{}, err
}
Expand Down Expand Up @@ -1502,11 +1506,11 @@ func (q *querier) GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Tim
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetAPIKeysLastUsedAfter)(ctx, lastUsed)
}

func (q *querier) GetActiveUserCount(ctx context.Context) (int64, error) {
func (q *querier) GetActiveUserCount(ctx context.Context, includeSystem bool) (int64, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
return 0, err
}
return q.db.GetActiveUserCount(ctx)
return q.db.GetActiveUserCount(ctx, includeSystem)
}

func (q *querier) GetActiveWorkspaceBuildsByTemplateID(ctx context.Context, templateID uuid.UUID) ([]database.WorkspaceBuild, error) {
Expand Down Expand Up @@ -1737,22 +1741,22 @@ func (q *querier) GetGroupByOrgAndName(ctx context.Context, arg database.GetGrou
return fetch(q.log, q.auth, q.db.GetGroupByOrgAndName)(ctx, arg)
}

func (q *querier) GetGroupMembers(ctx context.Context) ([]database.GroupMember, error) {
func (q *querier) GetGroupMembers(ctx context.Context, includeSystem bool) ([]database.GroupMember, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
return nil, err
}
return q.db.GetGroupMembers(ctx)
return q.db.GetGroupMembers(ctx, includeSystem)
}

func (q *querier) GetGroupMembersByGroupID(ctx context.Context, id uuid.UUID) ([]database.GroupMember, error) {
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetGroupMembersByGroupID)(ctx, id)
func (q *querier) GetGroupMembersByGroupID(ctx context.Context, arg database.GetGroupMembersByGroupIDParams) ([]database.GroupMember, error) {
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetGroupMembersByGroupID)(ctx, arg)
}

func (q *querier) GetGroupMembersCountByGroupID(ctx context.Context, groupID uuid.UUID) (int64, error) {
if _, err := q.GetGroupByID(ctx, groupID); err != nil { // AuthZ check
func (q *querier) GetGroupMembersCountByGroupID(ctx context.Context, arg database.GetGroupMembersCountByGroupIDParams) (int64, error) {
if _, err := q.GetGroupByID(ctx, arg.GroupID); err != nil { // AuthZ check
return 0, err
}
memberCount, err := q.db.GetGroupMembersCountByGroupID(ctx, groupID)
memberCount, err := q.db.GetGroupMembersCountByGroupID(ctx, arg)
if err != nil {
return 0, err
}
Expand Down Expand Up @@ -2530,11 +2534,11 @@ func (q *querier) GetUserByID(ctx context.Context, id uuid.UUID) (database.User,
return fetch(q.log, q.auth, q.db.GetUserByID)(ctx, id)
}

func (q *querier) GetUserCount(ctx context.Context) (int64, error) {
func (q *querier) GetUserCount(ctx context.Context, includeSystem bool) (int64, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
return 0, err
}
return q.db.GetUserCount(ctx)
return q.db.GetUserCount(ctx, includeSystem)
}

func (q *querier) GetUserLatencyInsights(ctx context.Context, arg database.GetUserLatencyInsightsParams) ([]database.GetUserLatencyInsightsRow, error) {
Expand Down Expand Up @@ -3771,6 +3775,7 @@ func (q *querier) UpdateMemberRoles(ctx context.Context, arg database.UpdateMemb
member, err := database.ExpectOne(q.OrganizationMembers(ctx, database.OrganizationMembersParams{
OrganizationID: arg.OrgID,
UserID: arg.UserID,
IncludeSystem: false,
}))
if err != nil {
return database.OrganizationMember{}, err
Expand Down
18 changes: 12 additions & 6 deletions coderd/database/dbauthz/dbauthz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -387,19 +387,25 @@ func (s *MethodTestSuite) TestGroup() {
g := dbgen.Group(s.T(), db, database.Group{})
u := dbgen.User(s.T(), db, database.User{})
gm := dbgen.GroupMember(s.T(), db, database.GroupMemberTable{GroupID: g.ID, UserID: u.ID})
check.Args(g.ID).Asserts(gm, policy.ActionRead)
check.Args(database.GetGroupMembersByGroupIDParams{
GroupID: g.ID,
IncludeSystem: false,
}).Asserts(gm, policy.ActionRead)
}))
s.Run("GetGroupMembersCountByGroupID", s.Subtest(func(db database.Store, check *expects) {
dbtestutil.DisableForeignKeysAndTriggers(s.T(), db)
g := dbgen.Group(s.T(), db, database.Group{})
check.Args(g.ID).Asserts(g, policy.ActionRead)
check.Args(database.GetGroupMembersCountByGroupIDParams{
GroupID: g.ID,
IncludeSystem: false,
}).Asserts(g, policy.ActionRead)
}))
s.Run("GetGroupMembers", s.Subtest(func(db database.Store, check *expects) {
dbtestutil.DisableForeignKeysAndTriggers(s.T(), db)
g := dbgen.Group(s.T(), db, database.Group{})
u := dbgen.User(s.T(), db, database.User{})
dbgen.GroupMember(s.T(), db, database.GroupMemberTable{GroupID: g.ID, UserID: u.ID})
check.Asserts(rbac.ResourceSystem, policy.ActionRead)
check.Args(false).Asserts(rbac.ResourceSystem, policy.ActionRead)
}))
s.Run("System/GetGroups", s.Subtest(func(db database.Store, check *expects) {
dbtestutil.DisableForeignKeysAndTriggers(s.T(), db)
Expand Down Expand Up @@ -1664,7 +1670,7 @@ func (s *MethodTestSuite) TestUser() {
s.Run("AllUserIDs", s.Subtest(func(db database.Store, check *expects) {
a := dbgen.User(s.T(), db, database.User{})
b := dbgen.User(s.T(), db, database.User{})
check.Args().Asserts(rbac.ResourceSystem, policy.ActionRead).Returns(slice.New(a.ID, b.ID))
check.Args(false).Asserts(rbac.ResourceSystem, policy.ActionRead).Returns(slice.New(a.ID, b.ID))
}))
s.Run("CustomRoles", s.Subtest(func(db database.Store, check *expects) {
check.Args(database.CustomRolesParams{}).Asserts(rbac.ResourceAssignRole, policy.ActionRead).Returns([]database.CustomRole{})
Expand Down Expand Up @@ -3679,7 +3685,7 @@ func (s *MethodTestSuite) TestSystemFunctions() {
check.Args().Asserts(rbac.ResourceSystem, policy.ActionRead)
}))
s.Run("GetActiveUserCount", s.Subtest(func(db database.Store, check *expects) {
check.Args().Asserts(rbac.ResourceSystem, policy.ActionRead).Returns(int64(0))
check.Args(false).Asserts(rbac.ResourceSystem, policy.ActionRead).Returns(int64(0))
}))
s.Run("GetUnexpiredLicenses", s.Subtest(func(db database.Store, check *expects) {
check.Args().Asserts(rbac.ResourceSystem, policy.ActionRead)
Expand Down Expand Up @@ -3722,7 +3728,7 @@ func (s *MethodTestSuite) TestSystemFunctions() {
check.Args(time.Now().Add(time.Hour*-1)).Asserts(rbac.ResourceSystem, policy.ActionRead)
}))
s.Run("GetUserCount", s.Subtest(func(db database.Store, check *expects) {
check.Args().Asserts(rbac.ResourceSystem, policy.ActionRead).Returns(int64(0))
check.Args(false).Asserts(rbac.ResourceSystem, policy.ActionRead).Returns(int64(0))
}))
s.Run("GetTemplates", s.Subtest(func(db database.Store, check *expects) {
dbtestutil.DisableForeignKeysAndTriggers(s.T(), db)
Expand Down
5 changes: 4 additions & 1 deletion coderd/database/dbauthz/groupsauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,10 @@ func TestGroupsAuth(t *testing.T) {
require.Error(t, err, "group read")
}

members, err := db.GetGroupMembersByGroupID(actorCtx, group.ID)
members, err := db.GetGroupMembersByGroupID(actorCtx, database.GetGroupMembersByGroupIDParams{
GroupID: group.ID,
IncludeSystem: false,
})
if tc.ReadMembers {
require.NoError(t, err, "member read")
require.Len(t, members, tc.MembersExpected, "member count found does not match")
Expand Down
5 changes: 4 additions & 1 deletion coderd/database/dbgen/dbgen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,10 @@ func TestGenerator(t *testing.T) {
gm := dbgen.GroupMember(t, db, database.GroupMemberTable{GroupID: g.ID, UserID: u.ID})
exp := []database.GroupMember{gm}

require.Equal(t, exp, must(db.GetGroupMembersByGroupID(context.Background(), g.ID)))
require.Equal(t, exp, must(db.GetGroupMembersByGroupID(context.Background(), database.GetGroupMembersByGroupIDParams{
GroupID: g.ID,
IncludeSystem: false,
})))
})

t.Run("Organization", func(t *testing.T) {
Expand Down
65 changes: 54 additions & 11 deletions coderd/database/dbmem/dbmem.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"golang.org/x/xerrors"

"github.com/coder/coder/v2/coderd/notifications/types"
"github.com/coder/coder/v2/coderd/prebuilds"

"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbtime"
Expand Down Expand Up @@ -153,6 +154,22 @@ func New() database.Store {
panic(xerrors.Errorf("failed to create psk provisioner key: %w", err))
}

q.mutex.Lock()
// We can't insert this user using the interface, because it's a system user.
q.data.users = append(q.data.users, database.User{
ID: prebuilds.SystemUserID,
Email: "prebuilds@coder.com",
Username: "prebuilds",
CreatedAt: dbtime.Now(),
UpdatedAt: dbtime.Now(),
Status: "active",
LoginType: "none",
HashedPassword: []byte{},
IsSystem: true,
Deleted: false,
})
q.mutex.Unlock()

return q
}

Expand Down Expand Up @@ -440,6 +457,7 @@ func convertUsers(users []database.User, count int64) []database.GetUsersRow {
Deleted: u.Deleted,
LastSeenAt: u.LastSeenAt,
Count: count,
IsSystem: u.IsSystem,
}
}

Expand Down Expand Up @@ -1552,11 +1570,16 @@ func (q *FakeQuerier) ActivityBumpWorkspace(ctx context.Context, arg database.Ac
return sql.ErrNoRows
}

func (q *FakeQuerier) AllUserIDs(_ context.Context) ([]uuid.UUID, error) {
// nolint:revive // It's not a control flag, it's a filter.
func (q *FakeQuerier) AllUserIDs(_ context.Context, includeSystem bool) ([]uuid.UUID, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
userIDs := make([]uuid.UUID, 0, len(q.users))
for idx := range q.users {
if !includeSystem && q.users[idx].IsSystem {
continue
}

userIDs = append(userIDs, q.users[idx].ID)
}
return userIDs, nil
Expand Down Expand Up @@ -2647,12 +2670,17 @@ func (q *FakeQuerier) GetAPIKeysLastUsedAfter(_ context.Context, after time.Time
return apiKeys, nil
}

func (q *FakeQuerier) GetActiveUserCount(_ context.Context) (int64, error) {
// nolint:revive // It's not a control flag, it's a filter.
func (q *FakeQuerier) GetActiveUserCount(_ context.Context, includeSystem bool) (int64, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()

active := int64(0)
for _, u := range q.users {
if !includeSystem && u.IsSystem {
continue
}

if u.Status == database.UserStatusActive && !u.Deleted {
active++
}
Expand Down Expand Up @@ -3388,14 +3416,18 @@ func (q *FakeQuerier) GetGroupByOrgAndName(_ context.Context, arg database.GetGr
return database.Group{}, sql.ErrNoRows
}

func (q *FakeQuerier) GetGroupMembers(ctx context.Context) ([]database.GroupMember, error) {
//nolint:revive // It's not a control flag, its a filter
func (q *FakeQuerier) GetGroupMembers(ctx context.Context, includeSystem bool) ([]database.GroupMember, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()

members := make([]database.GroupMemberTable, 0, len(q.groupMembers))
members = append(members, q.groupMembers...)
for _, org := range q.organizations {
for _, user := range q.users {
if !includeSystem && user.IsSystem {
continue
}
members = append(members, database.GroupMemberTable{
UserID: user.ID,
GroupID: org.ID,
Expand All @@ -3418,17 +3450,17 @@ func (q *FakeQuerier) GetGroupMembers(ctx context.Context) ([]database.GroupMemb
return groupMembers, nil
}

func (q *FakeQuerier) GetGroupMembersByGroupID(ctx context.Context, id uuid.UUID) ([]database.GroupMember, error) {
func (q *FakeQuerier) GetGroupMembersByGroupID(ctx context.Context, arg database.GetGroupMembersByGroupIDParams) ([]database.GroupMember, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()

if q.isEveryoneGroup(id) {
return q.getEveryoneGroupMembersNoLock(ctx, id), nil
if q.isEveryoneGroup(arg.GroupID) {
return q.getEveryoneGroupMembersNoLock(ctx, arg.GroupID), nil
}

var groupMembers []database.GroupMember
for _, member := range q.groupMembers {
if member.GroupID == id {
if member.GroupID == arg.GroupID {
groupMember, err := q.getGroupMemberNoLock(ctx, member.UserID, member.GroupID)
if errors.Is(err, errUserDeleted) {
continue
Expand All @@ -3443,8 +3475,8 @@ func (q *FakeQuerier) GetGroupMembersByGroupID(ctx context.Context, id uuid.UUID
return groupMembers, nil
}

func (q *FakeQuerier) GetGroupMembersCountByGroupID(ctx context.Context, groupID uuid.UUID) (int64, error) {
users, err := q.GetGroupMembersByGroupID(ctx, groupID)
func (q *FakeQuerier) GetGroupMembersCountByGroupID(ctx context.Context, arg database.GetGroupMembersCountByGroupIDParams) (int64, error) {
users, err := q.GetGroupMembersByGroupID(ctx, database.GetGroupMembersByGroupIDParams(arg))
if err != nil {
return 0, err
}
Expand Down Expand Up @@ -6221,12 +6253,16 @@ func (q *FakeQuerier) GetUserByID(_ context.Context, id uuid.UUID) (database.Use
return q.getUserByIDNoLock(id)
}

func (q *FakeQuerier) GetUserCount(_ context.Context) (int64, error) {
// nolint:revive // It's not a control flag, it's a filter.
func (q *FakeQuerier) GetUserCount(_ context.Context, includeSystem bool) (int64, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()

existing := int64(0)
for _, u := range q.users {
if !includeSystem && u.IsSystem {
continue
}
if !u.Deleted {
existing++
}
Expand Down Expand Up @@ -6578,6 +6614,12 @@ func (q *FakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams
users = usersFilteredByLastSeen
}

if !params.IncludeSystem {
users = slices.DeleteFunc(users, func(u database.User) bool {
return u.IsSystem
})
}

if params.GithubComUserID != 0 {
usersFilteredByGithubComUserID := make([]database.User, 0, len(users))
for i, user := range users {
Expand Down Expand Up @@ -8900,6 +8942,7 @@ func (q *FakeQuerier) InsertUser(_ context.Context, arg database.InsertUserParam
Status: status,
RBACRoles: arg.RBACRoles,
LoginType: arg.LoginType,
IsSystem: false,
}
q.users = append(q.users, user)
sort.Slice(q.users, func(i, j int) bool {
Expand Down Expand Up @@ -10058,7 +10101,7 @@ func (q *FakeQuerier) UpdateInactiveUsersToDormant(_ context.Context, params dat

var updated []database.UpdateInactiveUsersToDormantRow
for index, user := range q.users {
if user.Status == database.UserStatusActive && user.LastSeenAt.Before(params.LastSeenAfter) {
if user.Status == database.UserStatusActive && user.LastSeenAt.Before(params.LastSeenAfter) && !user.IsSystem {
q.users[index].Status = database.UserStatusDormant
q.users[index].UpdatedAt = params.UpdatedAt
updated = append(updated, database.UpdateInactiveUsersToDormantRow{
Expand Down
Loading
Loading