Skip to content

Commit a4da4c6

Browse files
committed
make query return group over name
1 parent cf0fd5a commit a4da4c6

File tree

9 files changed

+104
-95
lines changed

9 files changed

+104
-95
lines changed

coderd/database/dbauthz/dbauthz.go

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1142,6 +1142,10 @@ func (q *querier) GetGroupMembers(ctx context.Context, id uuid.UUID) ([]database
11421142
return q.db.GetGroupMembers(ctx, id)
11431143
}
11441144

1145+
func (q *querier) GetGroupsByOrganizationAndUserID(ctx context.Context, arg database.GetGroupsByOrganizationAndUserIDParams) ([]database.Group, error) {
1146+
return fetchWithPostFilter(q.auth, q.db.GetGroupsByOrganizationAndUserID)(ctx, arg)
1147+
}
1148+
11451149
func (q *querier) GetGroupsByOrganizationID(ctx context.Context, organizationID uuid.UUID) ([]database.Group, error) {
11461150
return fetchWithPostFilter(q.auth, q.db.GetGroupsByOrganizationID)(ctx, organizationID)
11471151
}
@@ -1833,13 +1837,6 @@ func (q *querier) GetUserCount(ctx context.Context) (int64, error) {
18331837
return q.db.GetUserCount(ctx)
18341838
}
18351839

1836-
func (q *querier) GetUserGroupNames(ctx context.Context, arg database.GetUserGroupNamesParams) ([]string, error) {
1837-
if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceGroup.InOrg(arg.OrganizationID)); err != nil {
1838-
return nil, err
1839-
}
1840-
return q.db.GetUserGroupNames(ctx, arg)
1841-
}
1842-
18431840
func (q *querier) GetUserLatencyInsights(ctx context.Context, arg database.GetUserLatencyInsightsParams) ([]database.GetUserLatencyInsightsRow, error) {
18441841
// Used by insights endpoints. Need to check both for auditors and for regular users with template acl perms.
18451842
if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTemplateInsights); err != nil {

coderd/database/dbauthz/dbauthz_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -314,13 +314,13 @@ func (s *MethodTestSuite) TestGroup() {
314314
_ = dbgen.GroupMember(s.T(), db, database.GroupMember{})
315315
check.Args(g.ID).Asserts(g, rbac.ActionRead)
316316
}))
317-
s.Run("GetUserGroupNames", s.Subtest(func(db database.Store, check *expects) {
317+
s.Run("GetGroupsByOrganizationAndUserID", s.Subtest(func(db database.Store, check *expects) {
318318
g := dbgen.Group(s.T(), db, database.Group{})
319319
gm := dbgen.GroupMember(s.T(), db, database.GroupMember{GroupID: g.ID})
320-
check.Args(database.GetUserGroupNamesParams{
320+
check.Args(database.GetGroupsByOrganizationAndUserIDParams{
321321
OrganizationID: g.OrganizationID,
322322
UserID: gm.UserID,
323-
}).Asserts(rbac.ResourceGroup.InOrg(g.OrganizationID), rbac.ActionRead)
323+
}).Asserts(g, rbac.ActionRead)
324324
}))
325325
s.Run("InsertAllUsersGroup", s.Subtest(func(db database.Store, check *expects) {
326326
o := dbgen.Organization(s.T(), db, database.Organization{})

coderd/database/dbmem/dbmem.go

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2250,6 +2250,30 @@ func (q *FakeQuerier) GetGroupMembers(_ context.Context, id uuid.UUID) ([]databa
22502250
return users, nil
22512251
}
22522252

2253+
func (q *FakeQuerier) GetGroupsByOrganizationAndUserID(ctx context.Context, arg database.GetGroupsByOrganizationAndUserIDParams) ([]database.Group, error) {
2254+
err := validateDatabaseType(arg)
2255+
if err != nil {
2256+
return nil, err
2257+
}
2258+
2259+
q.mutex.RLock()
2260+
defer q.mutex.RUnlock()
2261+
var groupIds []uuid.UUID
2262+
for _, member := range q.groupMembers {
2263+
if member.UserID == arg.UserID {
2264+
groupIds = append(groupIds, member.GroupID)
2265+
}
2266+
}
2267+
groups := []database.Group{}
2268+
for _, group := range q.groups {
2269+
if slices.Contains(groupIds, group.ID) {
2270+
groups = append(groups, group)
2271+
}
2272+
}
2273+
2274+
return groups, nil
2275+
}
2276+
22532277
func (q *FakeQuerier) GetGroupsByOrganizationID(_ context.Context, id uuid.UUID) ([]database.Group, error) {
22542278
q.mutex.RLock()
22552279
defer q.mutex.RUnlock()
@@ -4334,30 +4358,6 @@ func (q *FakeQuerier) GetUserCount(_ context.Context) (int64, error) {
43344358
return existing, nil
43354359
}
43364360

4337-
func (q *FakeQuerier) GetUserGroupNames(_ context.Context, arg database.GetUserGroupNamesParams) ([]string, error) {
4338-
err := validateDatabaseType(arg)
4339-
if err != nil {
4340-
return nil, err
4341-
}
4342-
4343-
q.mutex.RLock()
4344-
defer q.mutex.RUnlock()
4345-
var groupIds []uuid.UUID
4346-
for _, member := range q.groupMembers {
4347-
if member.UserID == arg.UserID {
4348-
groupIds = append(groupIds, member.GroupID)
4349-
}
4350-
}
4351-
groupNames := []string{}
4352-
for _, group := range q.groups {
4353-
if slices.Contains(groupIds, group.ID) {
4354-
groupNames = append(groupNames, group.Name)
4355-
}
4356-
}
4357-
4358-
return groupNames, nil
4359-
}
4360-
43614361
func (q *FakeQuerier) GetUserLatencyInsights(_ context.Context, arg database.GetUserLatencyInsightsParams) ([]database.GetUserLatencyInsightsRow, error) {
43624362
err := validateDatabaseType(arg)
43634363
if err != nil {

coderd/database/dbmetrics/dbmetrics.go

Lines changed: 7 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/dbmock/dbmock.go

Lines changed: 15 additions & 15 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/querier.go

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries.sql.go

Lines changed: 43 additions & 35 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries/groups.sql

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@ FROM
2828
WHERE
2929
organization_id = $1;
3030

31-
-- name: GetUserGroupNames :many
31+
-- name: GetGroupsByOrganizationAndUserID :many
3232
SELECT
33-
groups.name
33+
groups.*
3434
FROM
3535
groups
3636
LEFT JOIN

coderd/provisionerdserver/provisionerdserver.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -467,13 +467,17 @@ func (s *server) acquireProtoJob(ctx context.Context, job database.ProvisionerJo
467467
if err != nil {
468468
return nil, failJob(fmt.Sprintf("get owner: %s", err))
469469
}
470-
ownerGroupNames, err := s.Database.GetUserGroupNames(ctx, database.GetUserGroupNamesParams{
470+
ownerGroups, err := s.Database.GetGroupsByOrganizationAndUserID(ctx, database.GetGroupsByOrganizationAndUserIDParams{
471471
UserID: owner.ID,
472472
OrganizationID: s.OrganizationID,
473473
})
474474
if err != nil {
475475
return nil, failJob(fmt.Sprintf("get owner group names: %s", err))
476476
}
477+
ownerGroupNames := []string{}
478+
for _, group := range ownerGroups {
479+
ownerGroupNames = append(ownerGroupNames, group.Name)
480+
}
477481
err = s.Pubsub.Publish(codersdk.WorkspaceNotifyChannel(workspace.ID), []byte{})
478482
if err != nil {
479483
return nil, failJob(fmt.Sprintf("publish workspace update: %s", err))

0 commit comments

Comments
 (0)