Skip to content

Commit 4e1d469

Browse files
committed
chore: merge organization member db queries
Merge into 1 that also joins in the user table for username. Required to list organization members on UI/cli
1 parent 5ccf508 commit 4e1d469

19 files changed

+282
-209
lines changed

cli/server_createadminuser_test.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,12 @@ func TestServerCreateAdminUser(t *testing.T) {
6767
orgIDs[org.ID] = struct{}{}
6868
}
6969

70-
orgMemberships, err := db.GetOrganizationMembershipsByUserID(ctx, user.ID)
70+
orgMemberships, err := db.OrganizationMembers(ctx, database.OrganizationMembersParams{UserID: user.ID})
7171
require.NoError(t, err)
7272
orgIDs2 := make(map[uuid.UUID]struct{}, len(orgMemberships))
7373
for _, membership := range orgMemberships {
74-
orgIDs2[membership.OrganizationID] = struct{}{}
75-
assert.Equal(t, []string{rbac.RoleOrgAdmin()}, membership.Roles, "user is not org admin")
74+
orgIDs2[membership.OrganizationMember.OrganizationID] = struct{}{}
75+
assert.Equal(t, []string{rbac.RoleOrgAdmin()}, membership.OrganizationMember.Roles, "user is not org admin")
7676
}
7777

7878
require.Equal(t, orgIDs, orgIDs2, "user is not in all orgs")

coderd/database/dbauthz/dbauthz.go

+7-11
Original file line numberDiff line numberDiff line change
@@ -1476,14 +1476,6 @@ func (q *querier) GetOrganizationIDsByMemberIDs(ctx context.Context, ids []uuid.
14761476
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetOrganizationIDsByMemberIDs)(ctx, ids)
14771477
}
14781478

1479-
func (q *querier) GetOrganizationMemberByUserID(ctx context.Context, arg database.GetOrganizationMemberByUserIDParams) (database.OrganizationMember, error) {
1480-
return fetch(q.log, q.auth, q.db.GetOrganizationMemberByUserID)(ctx, arg)
1481-
}
1482-
1483-
func (q *querier) GetOrganizationMembershipsByUserID(ctx context.Context, userID uuid.UUID) ([]database.OrganizationMember, error) {
1484-
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetOrganizationMembershipsByUserID)(ctx, userID)
1485-
}
1486-
14871479
func (q *querier) GetOrganizations(ctx context.Context) ([]database.Organization, error) {
14881480
fetch := func(ctx context.Context, _ interface{}) ([]database.Organization, error) {
14891481
return q.db.GetOrganizations(ctx)
@@ -2771,6 +2763,10 @@ func (q *querier) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID
27712763
return q.db.ListWorkspaceAgentPortShares(ctx, workspaceID)
27722764
}
27732765

2766+
func (q *querier) OrganizationMembers(ctx context.Context, arg database.OrganizationMembersParams) ([]database.OrganizationMembersRow, error) {
2767+
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.OrganizationMembers)(ctx, arg)
2768+
}
2769+
27742770
func (q *querier) ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error {
27752771
template, err := q.db.GetTemplateByID(ctx, templateID)
27762772
if err != nil {
@@ -2870,15 +2866,15 @@ func (q *querier) UpdateInactiveUsersToDormant(ctx context.Context, lastSeenAfte
28702866

28712867
func (q *querier) UpdateMemberRoles(ctx context.Context, arg database.UpdateMemberRolesParams) (database.OrganizationMember, error) {
28722868
// Authorized fetch will check that the actor has read access to the org member since the org member is returned.
2873-
member, err := q.GetOrganizationMemberByUserID(ctx, database.GetOrganizationMemberByUserIDParams{
2869+
member, err := database.ExpectOne(q.OrganizationMembers(ctx, database.OrganizationMembersParams{
28742870
OrganizationID: arg.OrgID,
28752871
UserID: arg.UserID,
2876-
})
2872+
}))
28772873
if err != nil {
28782874
return database.OrganizationMember{}, err
28792875
}
28802876

2881-
originalRoles, err := q.convertToOrganizationRoles(member.OrganizationID, member.Roles)
2877+
originalRoles, err := q.convertToOrganizationRoles(member.OrganizationMember.OrganizationID, member.OrganizationMember.Roles)
28822878
if err != nil {
28832879
return database.OrganizationMember{}, xerrors.Errorf("convert original roles: %w", err)
28842880
}

coderd/database/dbauthz/dbauthz_test.go

+24-12
Original file line numberDiff line numberDiff line change
@@ -596,13 +596,6 @@ func (s *MethodTestSuite) TestOrganization() {
596596
check.Args([]uuid.UUID{ma.UserID, mb.UserID}).
597597
Asserts(rbac.ResourceUserObject(ma.UserID), policy.ActionRead, rbac.ResourceUserObject(mb.UserID), policy.ActionRead)
598598
}))
599-
s.Run("GetOrganizationMemberByUserID", s.Subtest(func(db database.Store, check *expects) {
600-
mem := dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{})
601-
check.Args(database.GetOrganizationMemberByUserIDParams{
602-
OrganizationID: mem.OrganizationID,
603-
UserID: mem.UserID,
604-
}).Asserts(mem, policy.ActionRead).Returns(mem)
605-
}))
606599
s.Run("GetOrganizationMembershipsByUserID", s.Subtest(func(db database.Store, check *expects) {
607600
u := dbgen.User(s.T(), db, database.User{})
608601
a := dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{UserID: u.ID})
@@ -658,6 +651,22 @@ func (s *MethodTestSuite) TestOrganization() {
658651
o.ID,
659652
).Asserts(o, policy.ActionDelete)
660653
}))
654+
s.Run("OrganizationMembers", s.Subtest(func(db database.Store, check *expects) {
655+
o := dbgen.Organization(s.T(), db, database.Organization{})
656+
u := dbgen.User(s.T(), db, database.User{})
657+
mem := dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{
658+
OrganizationID: o.ID,
659+
UserID: u.ID,
660+
Roles: []string{rbac.RoleOrgAdmin()},
661+
})
662+
663+
check.Args(database.OrganizationMembersParams{
664+
OrganizationID: uuid.UUID{},
665+
UserID: uuid.UUID{},
666+
}).Asserts(
667+
mem, policy.ActionRead,
668+
)
669+
}))
661670
s.Run("UpdateMemberRoles", s.Subtest(func(db database.Store, check *expects) {
662671
o := dbgen.Organization(s.T(), db, database.Organization{})
663672
u := dbgen.User(s.T(), db, database.User{})
@@ -673,11 +682,14 @@ func (s *MethodTestSuite) TestOrganization() {
673682
GrantedRoles: []string{},
674683
UserID: u.ID,
675684
OrgID: o.ID,
676-
}).Asserts(
677-
mem, policy.ActionRead,
678-
rbac.ResourceAssignRole.InOrg(o.ID), policy.ActionAssign, // org-mem
679-
rbac.ResourceAssignRole.InOrg(o.ID), policy.ActionDelete, // org-admin
680-
).Returns(out)
685+
}).
686+
WithNotAuthorized(sql.ErrNoRows.Error()).
687+
WithCancelled(sql.ErrNoRows.Error()).
688+
Asserts(
689+
mem, policy.ActionRead,
690+
rbac.ResourceAssignRole.InOrg(o.ID), policy.ActionAssign, // org-mem
691+
rbac.ResourceAssignRole.InOrg(o.ID), policy.ActionDelete, // org-admin
692+
).Returns(out)
681693
}))
682694
}
683695

coderd/database/dbauthz/setup_test.go

+33-7
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ func (s *MethodTestSuite) Subtest(testCaseF func(db database.Store, check *expec
157157
if len(testCase.assertions) > 0 {
158158
// Only run these tests if we know the underlying call makes
159159
// rbac assertions.
160-
s.NotAuthorizedErrorTest(ctx, fakeAuthorizer, callMethod)
160+
s.NotAuthorizedErrorTest(ctx, fakeAuthorizer, testCase, callMethod)
161161
}
162162

163163
if len(testCase.assertions) > 0 ||
@@ -230,7 +230,7 @@ func (s *MethodTestSuite) NoActorErrorTest(callMethod func(ctx context.Context)
230230

231231
// NotAuthorizedErrorTest runs the given method with an authorizer that will fail authz.
232232
// Asserts that the error returned is a NotAuthorizedError.
233-
func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderdtest.FakeAuthorizer, callMethod func(ctx context.Context) ([]reflect.Value, error)) {
233+
func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderdtest.FakeAuthorizer, testCase expects, callMethod func(ctx context.Context) ([]reflect.Value, error)) {
234234
s.Run("NotAuthorized", func() {
235235
az.AlwaysReturn = rbac.ForbiddenWithInternal(xerrors.New("Always fail authz"), rbac.Subject{}, "", rbac.Object{}, nil)
236236

@@ -242,9 +242,14 @@ func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderd
242242
// This is unfortunate, but if we are using `Filter` the error returned will be nil. So filter out
243243
// any case where the error is nil and the response is an empty slice.
244244
if err != nil || !hasEmptySliceResponse(resp) {
245-
s.ErrorContainsf(err, "unauthorized", "error string should have a good message")
246-
s.Errorf(err, "method should an error with disallow authz")
247-
s.ErrorAs(err, &dbauthz.NotAuthorizedError{}, "error should be NotAuthorizedError")
245+
// Expect the default error
246+
if testCase.notAuthorizedExpect == "" {
247+
s.ErrorContainsf(err, "unauthorized", "error string should have a good message")
248+
s.Errorf(err, "method should an error with disallow authz")
249+
s.ErrorAs(err, &dbauthz.NotAuthorizedError{}, "error should be NotAuthorizedError")
250+
} else {
251+
s.ErrorContains(err, testCase.notAuthorizedExpect)
252+
}
248253
}
249254
})
250255

@@ -263,8 +268,12 @@ func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderd
263268
// This is unfortunate, but if we are using `Filter` the error returned will be nil. So filter out
264269
// any case where the error is nil and the response is an empty slice.
265270
if err != nil || !hasEmptySliceResponse(resp) {
266-
s.Errorf(err, "method should an error with cancellation")
267-
s.ErrorIsf(err, context.Canceled, "error should match context.Canceled")
271+
if testCase.cancelledCtxExpect == "" {
272+
s.Errorf(err, "method should an error with cancellation")
273+
s.ErrorIsf(err, context.Canceled, "error should match context.Canceled")
274+
} else {
275+
s.ErrorContains(err, testCase.cancelledCtxExpect)
276+
}
268277
}
269278
})
270279
}
@@ -308,6 +317,13 @@ type expects struct {
308317
// outputs is optional. Can assert non-error return values.
309318
outputs []reflect.Value
310319
err error
320+
321+
// Optional override of the default error checks.
322+
// By default, we search for the expected error strings.
323+
// If these strings are present, these strings will be searched
324+
// instead.
325+
notAuthorizedExpect string
326+
cancelledCtxExpect string
311327
}
312328

313329
// Asserts is required. Asserts the RBAC authorize calls that should be made.
@@ -338,6 +354,16 @@ func (m *expects) Errors(err error) *expects {
338354
return m
339355
}
340356

357+
func (m *expects) WithNotAuthorized(contains string) *expects {
358+
m.notAuthorizedExpect = contains
359+
return m
360+
}
361+
362+
func (m *expects) WithCancelled(contains string) *expects {
363+
m.cancelledCtxExpect = contains
364+
return m
365+
}
366+
341367
// AssertRBAC contains the object and actions to be asserted.
342368
type AssertRBAC struct {
343369
Object rbac.Object

coderd/database/dbgen/dbgen_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -119,10 +119,10 @@ func TestGenerator(t *testing.T) {
119119
t.Parallel()
120120
db := dbmem.New()
121121
exp := dbgen.OrganizationMember(t, db, database.OrganizationMember{})
122-
require.Equal(t, exp, must(db.GetOrganizationMemberByUserID(context.Background(), database.GetOrganizationMemberByUserIDParams{
122+
require.Equal(t, exp, must(database.ExpectOne(db.OrganizationMembers(context.Background(), database.OrganizationMembersParams{
123123
OrganizationID: exp.OrganizationID,
124124
UserID: exp.UserID,
125-
})))
125+
}))).OrganizationMember)
126126
})
127127

128128
t.Run("Workspace", func(t *testing.T) {

coderd/database/dbmem/dbmem.go

+28-35
Original file line numberDiff line numberDiff line change
@@ -2759,41 +2759,6 @@ func (q *FakeQuerier) GetOrganizationIDsByMemberIDs(_ context.Context, ids []uui
27592759
return getOrganizationIDsByMemberIDRows, nil
27602760
}
27612761

2762-
func (q *FakeQuerier) GetOrganizationMemberByUserID(_ context.Context, arg database.GetOrganizationMemberByUserIDParams) (database.OrganizationMember, error) {
2763-
if err := validateDatabaseType(arg); err != nil {
2764-
return database.OrganizationMember{}, err
2765-
}
2766-
2767-
q.mutex.RLock()
2768-
defer q.mutex.RUnlock()
2769-
2770-
for _, organizationMember := range q.organizationMembers {
2771-
if organizationMember.OrganizationID != arg.OrganizationID {
2772-
continue
2773-
}
2774-
if organizationMember.UserID != arg.UserID {
2775-
continue
2776-
}
2777-
return organizationMember, nil
2778-
}
2779-
return database.OrganizationMember{}, sql.ErrNoRows
2780-
}
2781-
2782-
func (q *FakeQuerier) GetOrganizationMembershipsByUserID(_ context.Context, userID uuid.UUID) ([]database.OrganizationMember, error) {
2783-
q.mutex.RLock()
2784-
defer q.mutex.RUnlock()
2785-
2786-
var memberships []database.OrganizationMember
2787-
for _, organizationMember := range q.organizationMembers {
2788-
mem := organizationMember
2789-
if mem.UserID != userID {
2790-
continue
2791-
}
2792-
memberships = append(memberships, mem)
2793-
}
2794-
return memberships, nil
2795-
}
2796-
27972762
func (q *FakeQuerier) GetOrganizations(_ context.Context) ([]database.Organization, error) {
27982763
q.mutex.RLock()
27992764
defer q.mutex.RUnlock()
@@ -6963,6 +6928,34 @@ func (q *FakeQuerier) ListWorkspaceAgentPortShares(_ context.Context, workspaceI
69636928
return shares, nil
69646929
}
69656930

6931+
func (q *FakeQuerier) OrganizationMembers(ctx context.Context, arg database.OrganizationMembersParams) ([]database.OrganizationMembersRow, error) {
6932+
if err := validateDatabaseType(arg); err != nil {
6933+
return []database.OrganizationMembersRow{}, err
6934+
}
6935+
6936+
q.mutex.RLock()
6937+
defer q.mutex.RUnlock()
6938+
6939+
tmp := make([]database.OrganizationMembersRow, 0)
6940+
for _, organizationMember := range q.organizationMembers {
6941+
if arg.OrganizationID != uuid.Nil && organizationMember.OrganizationID != arg.OrganizationID {
6942+
continue
6943+
}
6944+
6945+
if arg.UserID != uuid.Nil && organizationMember.UserID != arg.UserID {
6946+
continue
6947+
}
6948+
6949+
organizationMember := organizationMember
6950+
user, _ := q.getUserByIDNoLock(organizationMember.UserID)
6951+
tmp = append(tmp, database.OrganizationMembersRow{
6952+
OrganizationMember: organizationMember,
6953+
Username: user.Username,
6954+
})
6955+
}
6956+
return tmp, nil
6957+
}
6958+
69666959
func (q *FakeQuerier) ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(_ context.Context, templateID uuid.UUID) error {
69676960
err := validateDatabaseType(templateID)
69686961
if err != nil {

coderd/database/dbmetrics/dbmetrics.go

+7-14
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/dbmock/dbmock.go

+15-30
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/modelmethods.go

+4
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,10 @@ func (m OrganizationMember) RBACObject() rbac.Object {
179179
WithOwner(m.UserID.String())
180180
}
181181

182+
func (m OrganizationMembersRow) RBACObject() rbac.Object {
183+
return m.OrganizationMember.RBACObject()
184+
}
185+
182186
func (m GetOrganizationIDsByMemberIDsRow) RBACObject() rbac.Object {
183187
// TODO: This feels incorrect as we are really returning a list of orgmembers.
184188
// This return type should be refactored to return a list of orgmembers, not this

0 commit comments

Comments
 (0)