Skip to content

Commit 0f472eb

Browse files
committed
use query for group names
1 parent da9163e commit 0f472eb

File tree

8 files changed

+133
-15
lines changed

8 files changed

+133
-15
lines changed

coderd/database/dbauthz/dbauthz.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1833,6 +1833,13 @@ func (q *querier) GetUserCount(ctx context.Context) (int64, error) {
18331833
return q.db.GetUserCount(ctx)
18341834
}
18351835

1836+
func (q *querier) GetUserGroupNames(ctx context.Context, arg database.GetUserGroupNamesParams) ([]string, error) {
1837+
if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceGroup); err != nil {
1838+
return nil, err
1839+
}
1840+
return q.db.GetUserGroupNames(ctx, arg)
1841+
}
1842+
18361843
func (q *querier) GetUserLatencyInsights(ctx context.Context, arg database.GetUserLatencyInsightsParams) ([]database.GetUserLatencyInsightsRow, error) {
18371844
// Used by insights endpoints. Need to check both for auditors and for regular users with template acl perms.
18381845
if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTemplateInsights); err != nil {

coderd/database/dbmem/dbmem.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4334,6 +4334,30 @@ func (q *FakeQuerier) GetUserCount(_ context.Context) (int64, error) {
43344334
return existing, nil
43354335
}
43364336

4337+
func (q *FakeQuerier) GetUserGroupNames(ctx context.Context, arg database.GetUserGroupNamesParams) ([]string, error) {
4338+
err := validateDatabaseType(arg)
4339+
if err != nil {
4340+
return nil, err
4341+
}
4342+
4343+
q.mutex.RLock()
4344+
defer q.mutex.RUnlock()
4345+
var groupIds []uuid.UUID
4346+
for _, member := range q.groupMembers {
4347+
if member.UserID == arg.UserID {
4348+
groupIds = append(groupIds, member.GroupID)
4349+
}
4350+
}
4351+
groupNames := []string{}
4352+
for _, group := range q.groups {
4353+
if slices.Contains(groupIds, group.ID) {
4354+
groupNames = append(groupNames, group.Name)
4355+
}
4356+
}
4357+
4358+
return groupNames, nil
4359+
}
4360+
43374361
func (q *FakeQuerier) GetUserLatencyInsights(_ context.Context, arg database.GetUserLatencyInsightsParams) ([]database.GetUserLatencyInsightsRow, error) {
43384362
err := validateDatabaseType(arg)
43394363
if err != nil {

coderd/database/dbmetrics/dbmetrics.go

Lines changed: 7 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/dbmock/dbmock.go

Lines changed: 15 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/querier.go

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries.sql.go

Lines changed: 51 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries/groups.sql

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,29 @@ FROM
2828
WHERE
2929
organization_id = $1;
3030

31+
-- name: GetUserGroupNames :many
32+
SELECT
33+
groups.name
34+
FROM
35+
groups
36+
LEFT JOIN
37+
group_members
38+
ON
39+
group_members.group_id = groups.id AND
40+
group_members.user_id = @user_id
41+
LEFT JOIN
42+
organization_members
43+
ON
44+
organization_members.organization_id = groups.id AND
45+
organization_members.user_id = @user_id
46+
WHERE
47+
-- In either case, the group_id will only match an org or a group.
48+
(group_members.user_id = @user_id OR organization_members.user_id = @user_id)
49+
AND
50+
-- Ensure the group or organization is the specified organization.
51+
groups.organization_id = @organization_id;
52+
53+
3154
-- name: InsertGroup :one
3255
INSERT INTO groups (
3356
id,

coderd/provisionerdserver/provisionerdserver.go

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -467,22 +467,12 @@ func (s *server) acquireProtoJob(ctx context.Context, job database.ProvisionerJo
467467
if err != nil {
468468
return nil, failJob(fmt.Sprintf("get owner: %s", err))
469469
}
470-
orgGroups, err := s.Database.GetGroupsByOrganizationID(ctx, s.OrganizationID)
470+
ownerGroupNames, err := s.Database.GetUserGroupNames(ctx, database.GetUserGroupNamesParams{
471+
UserID: owner.ID,
472+
OrganizationID: s.OrganizationID,
473+
})
471474
if err != nil {
472-
return nil, failJob(fmt.Sprintf("get owner groups: %s", err))
473-
}
474-
ownerGroupNames := []string{}
475-
for _, group := range orgGroups {
476-
members, err := s.Database.GetGroupMembers(ctx, group.ID)
477-
if err != nil {
478-
return nil, failJob(fmt.Sprintf("get group members: %s", err))
479-
}
480-
for _, member := range members {
481-
if member.ID == owner.ID {
482-
ownerGroupNames = append(ownerGroupNames, group.Name)
483-
break
484-
}
485-
}
475+
return nil, failJob(fmt.Sprintf("get owner group names: %s", err))
486476
}
487477
err = s.Pubsub.Publish(codersdk.WorkspaceNotifyChannel(workspace.ID), []byte{})
488478
if err != nil {

0 commit comments

Comments
 (0)