diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index f32e176754fa1..b0b4adcf844c9 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -1321,11 +1321,25 @@ func (q *querier) GetGroupByOrgAndName(ctx context.Context, arg database.GetGrou return fetch(q.log, q.auth, q.db.GetGroupByOrgAndName)(ctx, arg) } -func (q *querier) GetGroupMembers(ctx context.Context, id uuid.UUID) ([]database.User, error) { +func (q *querier) GetGroupMembers(ctx context.Context) ([]database.GroupMember, error) { + if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil { + return nil, err + } + return q.db.GetGroupMembers(ctx) +} + +func (q *querier) GetGroupMembersByGroupID(ctx context.Context, id uuid.UUID) ([]database.User, error) { if _, err := q.GetGroupByID(ctx, id); err != nil { // AuthZ check return nil, err } - return q.db.GetGroupMembers(ctx, id) + return q.db.GetGroupMembersByGroupID(ctx, id) +} + +func (q *querier) GetGroups(ctx context.Context) ([]database.Group, error) { + if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil { + return nil, err + } + return q.db.GetGroups(ctx) } func (q *querier) GetGroupsByOrganizationAndUserID(ctx context.Context, arg database.GetGroupsByOrganizationAndUserIDParams) ([]database.Group, error) { diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 17c0d76c4ef31..3ccafb33262cc 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -314,11 +314,19 @@ func (s *MethodTestSuite) TestGroup() { Name: g.Name, }).Asserts(g, policy.ActionRead).Returns(g) })) - s.Run("GetGroupMembers", s.Subtest(func(db database.Store, check *expects) { + s.Run("GetGroupMembersByGroupID", s.Subtest(func(db database.Store, check *expects) { g := dbgen.Group(s.T(), db, database.Group{}) _ = dbgen.GroupMember(s.T(), db, database.GroupMember{}) check.Args(g.ID).Asserts(g, policy.ActionRead) })) + s.Run("GetGroupMembers", s.Subtest(func(db database.Store, check *expects) { + _ = dbgen.GroupMember(s.T(), db, database.GroupMember{}) + check.Asserts(rbac.ResourceSystem, policy.ActionRead) + })) + s.Run("GetGroups", s.Subtest(func(db database.Store, check *expects) { + _ = dbgen.Group(s.T(), db, database.Group{}) + check.Asserts(rbac.ResourceSystem, policy.ActionRead) + })) s.Run("GetGroupsByOrganizationAndUserID", s.Subtest(func(db database.Store, check *expects) { g := dbgen.Group(s.T(), db, database.Group{}) gm := dbgen.GroupMember(s.T(), db, database.GroupMember{GroupID: g.ID}) diff --git a/coderd/database/dbgen/dbgen_test.go b/coderd/database/dbgen/dbgen_test.go index 2681f6eb1fece..0690808eb590d 100644 --- a/coderd/database/dbgen/dbgen_test.go +++ b/coderd/database/dbgen/dbgen_test.go @@ -105,7 +105,7 @@ func TestGenerator(t *testing.T) { exp := []database.User{u} dbgen.GroupMember(t, db, database.GroupMember{GroupID: g.ID, UserID: u.ID}) - require.Equal(t, exp, must(db.GetGroupMembers(context.Background(), g.ID))) + require.Equal(t, exp, must(db.GetGroupMembersByGroupID(context.Background(), g.ID))) }) t.Run("Organization", func(t *testing.T) { diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index 04eecd5d86355..573279db3ad39 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -2370,7 +2370,16 @@ func (q *FakeQuerier) GetGroupByOrgAndName(_ context.Context, arg database.GetGr return database.Group{}, sql.ErrNoRows } -func (q *FakeQuerier) GetGroupMembers(_ context.Context, id uuid.UUID) ([]database.User, error) { +func (q *FakeQuerier) GetGroupMembers(_ context.Context) ([]database.GroupMember, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + out := make([]database.GroupMember, len(q.groupMembers)) + copy(out, q.groupMembers) + return out, nil +} + +func (q *FakeQuerier) GetGroupMembersByGroupID(_ context.Context, id uuid.UUID) ([]database.User, error) { q.mutex.RLock() defer q.mutex.RUnlock() @@ -2399,6 +2408,15 @@ func (q *FakeQuerier) GetGroupMembers(_ context.Context, id uuid.UUID) ([]databa return users, nil } +func (q *FakeQuerier) GetGroups(_ context.Context) ([]database.Group, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + out := make([]database.Group, len(q.groups)) + copy(out, q.groups) + return out, nil +} + func (q *FakeQuerier) GetGroupsByOrganizationAndUserID(_ context.Context, arg database.GetGroupsByOrganizationAndUserIDParams) ([]database.Group, error) { err := validateDatabaseType(arg) if err != nil { diff --git a/coderd/database/dbmetrics/dbmetrics.go b/coderd/database/dbmetrics/dbmetrics.go index dc4c52a31faaf..e45072cd71cdb 100644 --- a/coderd/database/dbmetrics/dbmetrics.go +++ b/coderd/database/dbmetrics/dbmetrics.go @@ -585,13 +585,27 @@ func (m metricsStore) GetGroupByOrgAndName(ctx context.Context, arg database.Get return group, err } -func (m metricsStore) GetGroupMembers(ctx context.Context, groupID uuid.UUID) ([]database.User, error) { +func (m metricsStore) GetGroupMembers(ctx context.Context) ([]database.GroupMember, error) { start := time.Now() - users, err := m.s.GetGroupMembers(ctx, groupID) + r0, r1 := m.s.GetGroupMembers(ctx) m.queryLatencies.WithLabelValues("GetGroupMembers").Observe(time.Since(start).Seconds()) + return r0, r1 +} + +func (m metricsStore) GetGroupMembersByGroupID(ctx context.Context, groupID uuid.UUID) ([]database.User, error) { + start := time.Now() + users, err := m.s.GetGroupMembersByGroupID(ctx, groupID) + m.queryLatencies.WithLabelValues("GetGroupMembersByGroupID").Observe(time.Since(start).Seconds()) return users, err } +func (m metricsStore) GetGroups(ctx context.Context) ([]database.Group, error) { + start := time.Now() + r0, r1 := m.s.GetGroups(ctx) + m.queryLatencies.WithLabelValues("GetGroups").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m metricsStore) GetGroupsByOrganizationAndUserID(ctx context.Context, arg database.GetGroupsByOrganizationAndUserIDParams) ([]database.Group, error) { start := time.Now() r0, r1 := m.s.GetGroupsByOrganizationAndUserID(ctx, arg) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 35f0312b7b20f..4b952461c44a6 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -1139,18 +1139,48 @@ func (mr *MockStoreMockRecorder) GetGroupByOrgAndName(arg0, arg1 any) *gomock.Ca } // GetGroupMembers mocks base method. -func (m *MockStore) GetGroupMembers(arg0 context.Context, arg1 uuid.UUID) ([]database.User, error) { +func (m *MockStore) GetGroupMembers(arg0 context.Context) ([]database.GroupMember, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetGroupMembers", arg0, arg1) - ret0, _ := ret[0].([]database.User) + ret := m.ctrl.Call(m, "GetGroupMembers", arg0) + ret0, _ := ret[0].([]database.GroupMember) ret1, _ := ret[1].(error) return ret0, ret1 } // GetGroupMembers indicates an expected call of GetGroupMembers. -func (mr *MockStoreMockRecorder) GetGroupMembers(arg0, arg1 any) *gomock.Call { +func (mr *MockStoreMockRecorder) GetGroupMembers(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupMembers", reflect.TypeOf((*MockStore)(nil).GetGroupMembers), arg0) +} + +// GetGroupMembersByGroupID mocks base method. +func (m *MockStore) GetGroupMembersByGroupID(arg0 context.Context, arg1 uuid.UUID) ([]database.User, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetGroupMembersByGroupID", arg0, arg1) + ret0, _ := ret[0].([]database.User) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetGroupMembersByGroupID indicates an expected call of GetGroupMembersByGroupID. +func (mr *MockStoreMockRecorder) GetGroupMembersByGroupID(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupMembersByGroupID", reflect.TypeOf((*MockStore)(nil).GetGroupMembersByGroupID), arg0, arg1) +} + +// GetGroups mocks base method. +func (m *MockStore) GetGroups(arg0 context.Context) ([]database.Group, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetGroups", arg0) + ret0, _ := ret[0].([]database.Group) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetGroups indicates an expected call of GetGroups. +func (mr *MockStoreMockRecorder) GetGroups(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupMembers", reflect.TypeOf((*MockStore)(nil).GetGroupMembers), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroups", reflect.TypeOf((*MockStore)(nil).GetGroups), arg0) } // GetGroupsByOrganizationAndUserID mocks base method. diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 50645ab1e5eb5..f244f52026eb0 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -124,9 +124,11 @@ type sqlcQuerier interface { GetGitSSHKey(ctx context.Context, userID uuid.UUID) (GitSSHKey, error) GetGroupByID(ctx context.Context, id uuid.UUID) (Group, error) GetGroupByOrgAndName(ctx context.Context, arg GetGroupByOrgAndNameParams) (Group, error) + GetGroupMembers(ctx context.Context) ([]GroupMember, error) // If the group is a user made group, then we need to check the group_members table. // If it is the "Everyone" group, then we need to check the organization_members table. - GetGroupMembers(ctx context.Context, groupID uuid.UUID) ([]User, error) + GetGroupMembersByGroupID(ctx context.Context, groupID uuid.UUID) ([]User, error) + GetGroups(ctx context.Context) ([]Group, error) GetGroupsByOrganizationAndUserID(ctx context.Context, arg GetGroupsByOrganizationAndUserIDParams) ([]Group, error) GetGroupsByOrganizationID(ctx context.Context, organizationID uuid.UUID) ([]Group, error) GetHealthSettings(ctx context.Context) (string, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 4f113323a024f..b25d99eb9e03b 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -1312,6 +1312,33 @@ func (q *sqlQuerier) DeleteGroupMemberFromGroup(ctx context.Context, arg DeleteG } const getGroupMembers = `-- name: GetGroupMembers :many +SELECT user_id, group_id FROM group_members +` + +func (q *sqlQuerier) GetGroupMembers(ctx context.Context) ([]GroupMember, error) { + rows, err := q.db.QueryContext(ctx, getGroupMembers) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GroupMember + for rows.Next() { + var i GroupMember + if err := rows.Scan(&i.UserID, &i.GroupID); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getGroupMembersByGroupID = `-- name: GetGroupMembersByGroupID :many SELECT users.id, users.email, users.username, users.hashed_password, users.created_at, users.updated_at, users.status, users.rbac_roles, users.login_type, users.avatar_url, users.deleted, users.last_seen_at, users.quiet_hours_schedule, users.theme_preference, users.name FROM @@ -1337,8 +1364,8 @@ AND // If the group is a user made group, then we need to check the group_members table. // If it is the "Everyone" group, then we need to check the organization_members table. -func (q *sqlQuerier) GetGroupMembers(ctx context.Context, groupID uuid.UUID) ([]User, error) { - rows, err := q.db.QueryContext(ctx, getGroupMembers, groupID) +func (q *sqlQuerier) GetGroupMembersByGroupID(ctx context.Context, groupID uuid.UUID) ([]User, error) { + rows, err := q.db.QueryContext(ctx, getGroupMembersByGroupID, groupID) if err != nil { return nil, err } @@ -1507,6 +1534,41 @@ func (q *sqlQuerier) GetGroupByOrgAndName(ctx context.Context, arg GetGroupByOrg return i, err } +const getGroups = `-- name: GetGroups :many +SELECT id, name, organization_id, avatar_url, quota_allowance, display_name, source FROM groups +` + +func (q *sqlQuerier) GetGroups(ctx context.Context) ([]Group, error) { + rows, err := q.db.QueryContext(ctx, getGroups) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Group + for rows.Next() { + var i Group + if err := rows.Scan( + &i.ID, + &i.Name, + &i.OrganizationID, + &i.AvatarURL, + &i.QuotaAllowance, + &i.DisplayName, + &i.Source, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getGroupsByOrganizationAndUserID = `-- name: GetGroupsByOrganizationAndUserID :many SELECT groups.id, groups.name, groups.organization_id, groups.avatar_url, groups.quota_allowance, groups.display_name, groups.source diff --git a/coderd/database/queries/groupmembers.sql b/coderd/database/queries/groupmembers.sql index d755212132383..8f4770eff112e 100644 --- a/coderd/database/queries/groupmembers.sql +++ b/coderd/database/queries/groupmembers.sql @@ -1,4 +1,7 @@ -- name: GetGroupMembers :many +SELECT * FROM group_members; + +-- name: GetGroupMembersByGroupID :many SELECT users.* FROM diff --git a/coderd/database/queries/groups.sql b/coderd/database/queries/groups.sql index 53d0b25874987..9dea20f0fa6e6 100644 --- a/coderd/database/queries/groups.sql +++ b/coderd/database/queries/groups.sql @@ -1,3 +1,6 @@ +-- name: GetGroups :many +SELECT * FROM groups; + -- name: GetGroupByID :one SELECT * diff --git a/coderd/telemetry/telemetry.go b/coderd/telemetry/telemetry.go index 9d16ba7922098..91251053663f5 100644 --- a/coderd/telemetry/telemetry.go +++ b/coderd/telemetry/telemetry.go @@ -344,9 +344,6 @@ func (r *remoteReporter) createSnapshot() (*Snapshot, error) { users := database.ConvertUserRows(userRows) var firstUser database.User for _, dbUser := range users { - if dbUser.Status != database.UserStatusActive { - continue - } if firstUser.CreatedAt.IsZero() { firstUser = dbUser } @@ -366,6 +363,28 @@ func (r *remoteReporter) createSnapshot() (*Snapshot, error) { } return nil }) + eg.Go(func() error { + groups, err := r.options.Database.GetGroups(ctx) + if err != nil { + return xerrors.Errorf("get groups: %w", err) + } + snapshot.Groups = make([]Group, 0, len(groups)) + for _, group := range groups { + snapshot.Groups = append(snapshot.Groups, ConvertGroup(group)) + } + return nil + }) + eg.Go(func() error { + groupMembers, err := r.options.Database.GetGroupMembers(ctx) + if err != nil { + return xerrors.Errorf("get groups: %w", err) + } + snapshot.GroupMembers = make([]GroupMember, 0, len(groupMembers)) + for _, member := range groupMembers { + snapshot.GroupMembers = append(snapshot.GroupMembers, ConvertGroupMember(member)) + } + return nil + }) eg.Go(func() error { workspaceRows, err := r.options.Database.GetWorkspaces(ctx, database.GetWorkspacesParams{}) if err != nil { @@ -642,6 +661,26 @@ func ConvertUser(dbUser database.User) User { EmailHashed: emailHashed, RBACRoles: dbUser.RBACRoles, CreatedAt: dbUser.CreatedAt, + Status: dbUser.Status, + } +} + +func ConvertGroup(group database.Group) Group { + return Group{ + ID: group.ID, + Name: group.Name, + OrganizationID: group.OrganizationID, + AvatarURL: group.AvatarURL, + QuotaAllowance: group.QuotaAllowance, + DisplayName: group.DisplayName, + Source: group.Source, + } +} + +func ConvertGroupMember(member database.GroupMember) GroupMember { + return GroupMember{ + GroupID: member.GroupID, + UserID: member.UserID, } } @@ -746,6 +785,8 @@ type Snapshot struct { TemplateVersions []TemplateVersion `json:"template_versions"` Templates []Template `json:"templates"` Users []User `json:"users"` + Groups []Group `json:"groups"` + GroupMembers []GroupMember `json:"group_members"` WorkspaceAgentStats []WorkspaceAgentStat `json:"workspace_agent_stats"` WorkspaceAgents []WorkspaceAgent `json:"workspace_agents"` WorkspaceApps []WorkspaceApp `json:"workspace_apps"` @@ -797,6 +838,21 @@ type User struct { Status database.UserStatus `json:"status"` } +type Group struct { + ID uuid.UUID `json:"id"` + Name string `json:"name"` + OrganizationID uuid.UUID `json:"organization_id"` + AvatarURL string `json:"avatar_url"` + QuotaAllowance int32 `json:"quota_allowance"` + DisplayName string `json:"display_name"` + Source database.GroupSource `json:"source"` +} + +type GroupMember struct { + UserID uuid.UUID `json:"user_id"` + GroupID uuid.UUID `json:"group_id"` +} + type WorkspaceResource struct { ID uuid.UUID `json:"id"` CreatedAt time.Time `json:"created_at"` diff --git a/coderd/telemetry/telemetry_test.go b/coderd/telemetry/telemetry_test.go index fa8650de3f3d5..2eff919ddc63d 100644 --- a/coderd/telemetry/telemetry_test.go +++ b/coderd/telemetry/telemetry_test.go @@ -55,6 +55,8 @@ func TestTelemetry(t *testing.T) { SharingLevel: database.AppSharingLevelOwner, Health: database.WorkspaceAppHealthDisabled, }) + _ = dbgen.Group(t, db, database.Group{}) + _ = dbgen.GroupMember(t, db, database.GroupMember{}) wsagent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{}) // Update the workspace agent to have a valid subsystem. err = db.UpdateWorkspaceAgentStartupByID(ctx, database.UpdateWorkspaceAgentStartupByIDParams{ @@ -91,6 +93,8 @@ func TestTelemetry(t *testing.T) { require.Len(t, snapshot.Templates, 1) require.Len(t, snapshot.TemplateVersions, 1) require.Len(t, snapshot.Users, 1) + require.Len(t, snapshot.Groups, 2) + require.Len(t, snapshot.GroupMembers, 1) require.Len(t, snapshot.Workspaces, 1) require.Len(t, snapshot.WorkspaceApps, 1) require.Len(t, snapshot.WorkspaceAgents, 1) diff --git a/enterprise/coderd/groups.go b/enterprise/coderd/groups.go index 65220e5cbabf7..dffbc5200c767 100644 --- a/enterprise/coderd/groups.go +++ b/enterprise/coderd/groups.go @@ -150,7 +150,7 @@ func (api *API) patchGroup(rw http.ResponseWriter, r *http.Request) { return } - currentMembers, err := api.Database.GetGroupMembers(ctx, group.ID) + currentMembers, err := api.Database.GetGroupMembersByGroupID(ctx, group.ID) if err != nil { httpapi.InternalServerError(rw, err) return @@ -276,7 +276,7 @@ func (api *API) patchGroup(rw http.ResponseWriter, r *http.Request) { return } - patchedMembers, err := api.Database.GetGroupMembers(ctx, group.ID) + patchedMembers, err := api.Database.GetGroupMembersByGroupID(ctx, group.ID) if err != nil { httpapi.InternalServerError(rw, err) return @@ -317,7 +317,7 @@ func (api *API) deleteGroup(rw http.ResponseWriter, r *http.Request) { return } - groupMembers, getMembersErr := api.Database.GetGroupMembers(ctx, group.ID) + groupMembers, getMembersErr := api.Database.GetGroupMembersByGroupID(ctx, group.ID) if getMembersErr != nil { httpapi.InternalServerError(rw, getMembersErr) return @@ -363,7 +363,7 @@ func (api *API) group(rw http.ResponseWriter, r *http.Request) { group = httpmw.GroupParam(r) ) - users, err := api.Database.GetGroupMembers(ctx, group.ID) + users, err := api.Database.GetGroupMembersByGroupID(ctx, group.ID) if err != nil && !xerrors.Is(err, sql.ErrNoRows) { httpapi.InternalServerError(rw, err) return @@ -408,7 +408,7 @@ func (api *API) groups(rw http.ResponseWriter, r *http.Request) { resp := make([]codersdk.Group, 0, len(groups)) for _, group := range groups { - members, err := api.Database.GetGroupMembers(ctx, group.ID) + members, err := api.Database.GetGroupMembersByGroupID(ctx, group.ID) if err != nil { httpapi.InternalServerError(rw, err) return diff --git a/enterprise/coderd/templates.go b/enterprise/coderd/templates.go index 6caf882192816..d9d5f245fcb41 100644 --- a/enterprise/coderd/templates.go +++ b/enterprise/coderd/templates.go @@ -59,7 +59,7 @@ func (api *API) templateAvailablePermissions(rw http.ResponseWriter, r *http.Req sdkGroups := make([]codersdk.Group, 0, len(groups)) for _, group := range groups { // nolint:gocritic - members, err := api.Database.GetGroupMembers(dbauthz.AsSystemRestricted(ctx), group.ID) + members, err := api.Database.GetGroupMembersByGroupID(dbauthz.AsSystemRestricted(ctx), group.ID) if err != nil { httpapi.InternalServerError(rw, err) return @@ -128,7 +128,7 @@ func (api *API) templateACL(rw http.ResponseWriter, r *http.Request) { // them read the group members. // We should probably at least return more truncated user data here. // nolint:gocritic - members, err = api.Database.GetGroupMembers(dbauthz.AsSystemRestricted(ctx), group.ID) + members, err = api.Database.GetGroupMembersByGroupID(dbauthz.AsSystemRestricted(ctx), group.ID) if err != nil { httpapi.InternalServerError(rw, err) return