Skip to content

chore: merge organization member db queries #13542

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions cli/server_createadminuser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,12 @@ func TestServerCreateAdminUser(t *testing.T) {
orgIDs[org.ID] = struct{}{}
}

orgMemberships, err := db.GetOrganizationMembershipsByUserID(ctx, user.ID)
orgMemberships, err := db.OrganizationMembers(ctx, database.OrganizationMembersParams{UserID: user.ID})
require.NoError(t, err)
orgIDs2 := make(map[uuid.UUID]struct{}, len(orgMemberships))
for _, membership := range orgMemberships {
orgIDs2[membership.OrganizationID] = struct{}{}
assert.Equal(t, []string{rbac.RoleOrgAdmin()}, membership.Roles, "user is not org admin")
orgIDs2[membership.OrganizationMember.OrganizationID] = struct{}{}
assert.Equal(t, []string{rbac.RoleOrgAdmin()}, membership.OrganizationMember.Roles, "user is not org admin")
}

require.Equal(t, orgIDs, orgIDs2, "user is not in all orgs")
Expand Down
18 changes: 7 additions & 11 deletions coderd/database/dbauthz/dbauthz.go
Original file line number Diff line number Diff line change
Expand Up @@ -1476,14 +1476,6 @@ func (q *querier) GetOrganizationIDsByMemberIDs(ctx context.Context, ids []uuid.
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetOrganizationIDsByMemberIDs)(ctx, ids)
}

func (q *querier) GetOrganizationMemberByUserID(ctx context.Context, arg database.GetOrganizationMemberByUserIDParams) (database.OrganizationMember, error) {
return fetch(q.log, q.auth, q.db.GetOrganizationMemberByUserID)(ctx, arg)
}

func (q *querier) GetOrganizationMembershipsByUserID(ctx context.Context, userID uuid.UUID) ([]database.OrganizationMember, error) {
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetOrganizationMembershipsByUserID)(ctx, userID)
}
Comment on lines -1479 to -1485
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed these, replaced with 1 OrganizationMembers.


func (q *querier) GetOrganizations(ctx context.Context) ([]database.Organization, error) {
fetch := func(ctx context.Context, _ interface{}) ([]database.Organization, error) {
return q.db.GetOrganizations(ctx)
Expand Down Expand Up @@ -2771,6 +2763,10 @@ func (q *querier) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID
return q.db.ListWorkspaceAgentPortShares(ctx, workspaceID)
}

func (q *querier) OrganizationMembers(ctx context.Context, arg database.OrganizationMembersParams) ([]database.OrganizationMembersRow, error) {
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.OrganizationMembers)(ctx, arg)
}

func (q *querier) ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error {
template, err := q.db.GetTemplateByID(ctx, templateID)
if err != nil {
Expand Down Expand Up @@ -2870,15 +2866,15 @@ func (q *querier) UpdateInactiveUsersToDormant(ctx context.Context, lastSeenAfte

func (q *querier) UpdateMemberRoles(ctx context.Context, arg database.UpdateMemberRolesParams) (database.OrganizationMember, error) {
// Authorized fetch will check that the actor has read access to the org member since the org member is returned.
member, err := q.GetOrganizationMemberByUserID(ctx, database.GetOrganizationMemberByUserIDParams{
member, err := database.ExpectOne(q.OrganizationMembers(ctx, database.OrganizationMembersParams{
OrganizationID: arg.OrgID,
UserID: arg.UserID,
})
}))
if err != nil {
return database.OrganizationMember{}, err
}

originalRoles, err := q.convertToOrganizationRoles(member.OrganizationID, member.Roles)
originalRoles, err := q.convertToOrganizationRoles(member.OrganizationMember.OrganizationID, member.OrganizationMember.Roles)
if err != nil {
return database.OrganizationMember{}, xerrors.Errorf("convert original roles: %w", err)
}
Expand Down
42 changes: 24 additions & 18 deletions coderd/database/dbauthz/dbauthz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -596,19 +596,6 @@ func (s *MethodTestSuite) TestOrganization() {
check.Args([]uuid.UUID{ma.UserID, mb.UserID}).
Asserts(rbac.ResourceUserObject(ma.UserID), policy.ActionRead, rbac.ResourceUserObject(mb.UserID), policy.ActionRead)
}))
s.Run("GetOrganizationMemberByUserID", s.Subtest(func(db database.Store, check *expects) {
mem := dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{})
check.Args(database.GetOrganizationMemberByUserIDParams{
OrganizationID: mem.OrganizationID,
UserID: mem.UserID,
}).Asserts(mem, policy.ActionRead).Returns(mem)
}))
s.Run("GetOrganizationMembershipsByUserID", s.Subtest(func(db database.Store, check *expects) {
u := dbgen.User(s.T(), db, database.User{})
a := dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{UserID: u.ID})
b := dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{UserID: u.ID})
check.Args(u.ID).Asserts(a, policy.ActionRead, b, policy.ActionRead).Returns(slice.New(a, b))
}))
s.Run("GetOrganizations", s.Subtest(func(db database.Store, check *expects) {
def, _ := db.GetDefaultOrganization(context.Background())
a := dbgen.Organization(s.T(), db, database.Organization{})
Expand Down Expand Up @@ -658,6 +645,22 @@ func (s *MethodTestSuite) TestOrganization() {
o.ID,
).Asserts(o, policy.ActionDelete)
}))
s.Run("OrganizationMembers", s.Subtest(func(db database.Store, check *expects) {
o := dbgen.Organization(s.T(), db, database.Organization{})
u := dbgen.User(s.T(), db, database.User{})
mem := dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{
OrganizationID: o.ID,
UserID: u.ID,
Roles: []string{rbac.RoleOrgAdmin()},
})

check.Args(database.OrganizationMembersParams{
OrganizationID: uuid.UUID{},
UserID: uuid.UUID{},
}).Asserts(
mem, policy.ActionRead,
)
}))
s.Run("UpdateMemberRoles", s.Subtest(func(db database.Store, check *expects) {
o := dbgen.Organization(s.T(), db, database.Organization{})
u := dbgen.User(s.T(), db, database.User{})
Expand All @@ -673,11 +676,14 @@ func (s *MethodTestSuite) TestOrganization() {
GrantedRoles: []string{},
UserID: u.ID,
OrgID: o.ID,
}).Asserts(
mem, policy.ActionRead,
rbac.ResourceAssignRole.InOrg(o.ID), policy.ActionAssign, // org-mem
rbac.ResourceAssignRole.InOrg(o.ID), policy.ActionDelete, // org-admin
).Returns(out)
}).
WithNotAuthorized(sql.ErrNoRows.Error()).
WithCancelled(sql.ErrNoRows.Error()).
Asserts(
mem, policy.ActionRead,
rbac.ResourceAssignRole.InOrg(o.ID), policy.ActionAssign, // org-mem
rbac.ResourceAssignRole.InOrg(o.ID), policy.ActionDelete, // org-admin
).Returns(out)
}))
}

Expand Down
40 changes: 33 additions & 7 deletions coderd/database/dbauthz/setup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ func (s *MethodTestSuite) Subtest(testCaseF func(db database.Store, check *expec
if len(testCase.assertions) > 0 {
// Only run these tests if we know the underlying call makes
// rbac assertions.
s.NotAuthorizedErrorTest(ctx, fakeAuthorizer, callMethod)
s.NotAuthorizedErrorTest(ctx, fakeAuthorizer, testCase, callMethod)
}

if len(testCase.assertions) > 0 ||
Expand Down Expand Up @@ -230,7 +230,7 @@ func (s *MethodTestSuite) NoActorErrorTest(callMethod func(ctx context.Context)

// NotAuthorizedErrorTest runs the given method with an authorizer that will fail authz.
// Asserts that the error returned is a NotAuthorizedError.
func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderdtest.FakeAuthorizer, callMethod func(ctx context.Context) ([]reflect.Value, error)) {
func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderdtest.FakeAuthorizer, testCase expects, callMethod func(ctx context.Context) ([]reflect.Value, error)) {
s.Run("NotAuthorized", func() {
az.AlwaysReturn = rbac.ForbiddenWithInternal(xerrors.New("Always fail authz"), rbac.Subject{}, "", rbac.Object{}, nil)

Expand All @@ -242,9 +242,14 @@ func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderd
// This is unfortunate, but if we are using `Filter` the error returned will be nil. So filter out
// any case where the error is nil and the response is an empty slice.
if err != nil || !hasEmptySliceResponse(resp) {
s.ErrorContainsf(err, "unauthorized", "error string should have a good message")
s.Errorf(err, "method should an error with disallow authz")
s.ErrorAs(err, &dbauthz.NotAuthorizedError{}, "error should be NotAuthorizedError")
// Expect the default error
if testCase.notAuthorizedExpect == "" {
s.ErrorContainsf(err, "unauthorized", "error string should have a good message")
s.Errorf(err, "method should an error with disallow authz")
s.ErrorAs(err, &dbauthz.NotAuthorizedError{}, "error should be NotAuthorizedError")
} else {
s.ErrorContains(err, testCase.notAuthorizedExpect)
}
}
})

Expand All @@ -263,8 +268,12 @@ func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderd
// This is unfortunate, but if we are using `Filter` the error returned will be nil. So filter out
// any case where the error is nil and the response is an empty slice.
if err != nil || !hasEmptySliceResponse(resp) {
s.Errorf(err, "method should an error with cancellation")
s.ErrorIsf(err, context.Canceled, "error should match context.Canceled")
if testCase.cancelledCtxExpect == "" {
s.Errorf(err, "method should an error with cancellation")
s.ErrorIsf(err, context.Canceled, "error should match context.Canceled")
} else {
s.ErrorContains(err, testCase.cancelledCtxExpect)
}
}
})
}
Expand Down Expand Up @@ -308,6 +317,13 @@ type expects struct {
// outputs is optional. Can assert non-error return values.
outputs []reflect.Value
err error

// Optional override of the default error checks.
// By default, we search for the expected error strings.
// If these strings are present, these strings will be searched
// instead.
notAuthorizedExpect string
cancelledCtxExpect string
Comment on lines +320 to +326
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This used to be done with a static list:

slice.Contains([]string{
"GetAuthorizedWorkspaces",
"GetAuthorizedTemplates",
}, methodName) {

We should move to this new approach, which is much closer to the actual writing of the tests.

}

// Asserts is required. Asserts the RBAC authorize calls that should be made.
Expand Down Expand Up @@ -338,6 +354,16 @@ func (m *expects) Errors(err error) *expects {
return m
}

func (m *expects) WithNotAuthorized(contains string) *expects {
m.notAuthorizedExpect = contains
return m
}

func (m *expects) WithCancelled(contains string) *expects {
m.cancelledCtxExpect = contains
return m
}

// AssertRBAC contains the object and actions to be asserted.
type AssertRBAC struct {
Object rbac.Object
Expand Down
4 changes: 2 additions & 2 deletions coderd/database/dbgen/dbgen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,10 @@ func TestGenerator(t *testing.T) {
t.Parallel()
db := dbmem.New()
exp := dbgen.OrganizationMember(t, db, database.OrganizationMember{})
require.Equal(t, exp, must(db.GetOrganizationMemberByUserID(context.Background(), database.GetOrganizationMemberByUserIDParams{
require.Equal(t, exp, must(database.ExpectOne(db.OrganizationMembers(context.Background(), database.OrganizationMembersParams{
OrganizationID: exp.OrganizationID,
UserID: exp.UserID,
})))
}))).OrganizationMember)
})

t.Run("Workspace", func(t *testing.T) {
Expand Down
63 changes: 28 additions & 35 deletions coderd/database/dbmem/dbmem.go
Original file line number Diff line number Diff line change
Expand Up @@ -2759,41 +2759,6 @@ func (q *FakeQuerier) GetOrganizationIDsByMemberIDs(_ context.Context, ids []uui
return getOrganizationIDsByMemberIDRows, nil
}

func (q *FakeQuerier) GetOrganizationMemberByUserID(_ context.Context, arg database.GetOrganizationMemberByUserIDParams) (database.OrganizationMember, error) {
if err := validateDatabaseType(arg); err != nil {
return database.OrganizationMember{}, err
}

q.mutex.RLock()
defer q.mutex.RUnlock()

for _, organizationMember := range q.organizationMembers {
if organizationMember.OrganizationID != arg.OrganizationID {
continue
}
if organizationMember.UserID != arg.UserID {
continue
}
return organizationMember, nil
}
return database.OrganizationMember{}, sql.ErrNoRows
}

func (q *FakeQuerier) GetOrganizationMembershipsByUserID(_ context.Context, userID uuid.UUID) ([]database.OrganizationMember, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()

var memberships []database.OrganizationMember
for _, organizationMember := range q.organizationMembers {
mem := organizationMember
if mem.UserID != userID {
continue
}
memberships = append(memberships, mem)
}
return memberships, nil
}

func (q *FakeQuerier) GetOrganizations(_ context.Context) ([]database.Organization, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
Expand Down Expand Up @@ -6963,6 +6928,34 @@ func (q *FakeQuerier) ListWorkspaceAgentPortShares(_ context.Context, workspaceI
return shares, nil
}

func (q *FakeQuerier) OrganizationMembers(_ context.Context, arg database.OrganizationMembersParams) ([]database.OrganizationMembersRow, error) {
if err := validateDatabaseType(arg); err != nil {
return []database.OrganizationMembersRow{}, err
}

q.mutex.RLock()
defer q.mutex.RUnlock()

tmp := make([]database.OrganizationMembersRow, 0)
for _, organizationMember := range q.organizationMembers {
if arg.OrganizationID != uuid.Nil && organizationMember.OrganizationID != arg.OrganizationID {
continue
}

if arg.UserID != uuid.Nil && organizationMember.UserID != arg.UserID {
continue
}

organizationMember := organizationMember
user, _ := q.getUserByIDNoLock(organizationMember.UserID)
tmp = append(tmp, database.OrganizationMembersRow{
OrganizationMember: organizationMember,
Username: user.Username,
})
}
return tmp, nil
}

func (q *FakeQuerier) ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(_ context.Context, templateID uuid.UUID) error {
err := validateDatabaseType(templateID)
if err != nil {
Expand Down
21 changes: 7 additions & 14 deletions coderd/database/dbmetrics/dbmetrics.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

45 changes: 15 additions & 30 deletions coderd/database/dbmock/dbmock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading