From 4e1d469ae7ed8785d6beae8c79797b5adbf261f8 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 11 Jun 2024 09:47:26 -0500 Subject: [PATCH 1/4] chore: merge organization member db queries Merge into 1 that also joins in the user table for username. Required to list organization members on UI/cli --- cli/server_createadminuser_test.go | 6 +- coderd/database/dbauthz/dbauthz.go | 18 +-- coderd/database/dbauthz/dbauthz_test.go | 36 +++-- coderd/database/dbauthz/setup_test.go | 40 +++++- coderd/database/dbgen/dbgen_test.go | 4 +- coderd/database/dbmem/dbmem.go | 63 ++++----- coderd/database/dbmetrics/dbmetrics.go | 21 +-- coderd/database/dbmock/dbmock.go | 45 ++---- coderd/database/modelmethods.go | 4 + coderd/database/modelqueries.go | 24 ++++ coderd/database/models.go | 2 +- coderd/database/querier.go | 5 +- coderd/database/querier_test.go | 36 +++++ coderd/database/queries.sql.go | 130 +++++++++--------- .../database/queries/organizationmembers.sql | 32 +++-- coderd/httpmw/organizationparam.go | 6 +- coderd/userauth.go | 8 +- coderd/users.go | 7 +- enterprise/coderd/groups.go | 4 +- 19 files changed, 282 insertions(+), 209 deletions(-) diff --git a/cli/server_createadminuser_test.go b/cli/server_createadminuser_test.go index 9bc6add2ecbd2..6e3939ea298d6 100644 --- a/cli/server_createadminuser_test.go +++ b/cli/server_createadminuser_test.go @@ -67,12 +67,12 @@ func TestServerCreateAdminUser(t *testing.T) { orgIDs[org.ID] = struct{}{} } - orgMemberships, err := db.GetOrganizationMembershipsByUserID(ctx, user.ID) + orgMemberships, err := db.OrganizationMembers(ctx, database.OrganizationMembersParams{UserID: user.ID}) require.NoError(t, err) orgIDs2 := make(map[uuid.UUID]struct{}, len(orgMemberships)) for _, membership := range orgMemberships { - orgIDs2[membership.OrganizationID] = struct{}{} - assert.Equal(t, []string{rbac.RoleOrgAdmin()}, membership.Roles, "user is not org admin") + orgIDs2[membership.OrganizationMember.OrganizationID] = struct{}{} + assert.Equal(t, []string{rbac.RoleOrgAdmin()}, membership.OrganizationMember.Roles, "user is not org admin") } require.Equal(t, orgIDs, orgIDs2, "user is not in all orgs") diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index bc8bf19763c73..85659751a9107 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -1476,14 +1476,6 @@ func (q *querier) GetOrganizationIDsByMemberIDs(ctx context.Context, ids []uuid. return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetOrganizationIDsByMemberIDs)(ctx, ids) } -func (q *querier) GetOrganizationMemberByUserID(ctx context.Context, arg database.GetOrganizationMemberByUserIDParams) (database.OrganizationMember, error) { - return fetch(q.log, q.auth, q.db.GetOrganizationMemberByUserID)(ctx, arg) -} - -func (q *querier) GetOrganizationMembershipsByUserID(ctx context.Context, userID uuid.UUID) ([]database.OrganizationMember, error) { - return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetOrganizationMembershipsByUserID)(ctx, userID) -} - func (q *querier) GetOrganizations(ctx context.Context) ([]database.Organization, error) { fetch := func(ctx context.Context, _ interface{}) ([]database.Organization, error) { return q.db.GetOrganizations(ctx) @@ -2771,6 +2763,10 @@ func (q *querier) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID return q.db.ListWorkspaceAgentPortShares(ctx, workspaceID) } +func (q *querier) OrganizationMembers(ctx context.Context, arg database.OrganizationMembersParams) ([]database.OrganizationMembersRow, error) { + return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.OrganizationMembers)(ctx, arg) +} + func (q *querier) ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error { template, err := q.db.GetTemplateByID(ctx, templateID) if err != nil { @@ -2870,15 +2866,15 @@ func (q *querier) UpdateInactiveUsersToDormant(ctx context.Context, lastSeenAfte func (q *querier) UpdateMemberRoles(ctx context.Context, arg database.UpdateMemberRolesParams) (database.OrganizationMember, error) { // Authorized fetch will check that the actor has read access to the org member since the org member is returned. - member, err := q.GetOrganizationMemberByUserID(ctx, database.GetOrganizationMemberByUserIDParams{ + member, err := database.ExpectOne(q.OrganizationMembers(ctx, database.OrganizationMembersParams{ OrganizationID: arg.OrgID, UserID: arg.UserID, - }) + })) if err != nil { return database.OrganizationMember{}, err } - originalRoles, err := q.convertToOrganizationRoles(member.OrganizationID, member.Roles) + originalRoles, err := q.convertToOrganizationRoles(member.OrganizationMember.OrganizationID, member.OrganizationMember.Roles) if err != nil { return database.OrganizationMember{}, xerrors.Errorf("convert original roles: %w", err) } diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 9d90a4d44114a..ce04f78ac16f8 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -596,13 +596,6 @@ func (s *MethodTestSuite) TestOrganization() { check.Args([]uuid.UUID{ma.UserID, mb.UserID}). Asserts(rbac.ResourceUserObject(ma.UserID), policy.ActionRead, rbac.ResourceUserObject(mb.UserID), policy.ActionRead) })) - s.Run("GetOrganizationMemberByUserID", s.Subtest(func(db database.Store, check *expects) { - mem := dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{}) - check.Args(database.GetOrganizationMemberByUserIDParams{ - OrganizationID: mem.OrganizationID, - UserID: mem.UserID, - }).Asserts(mem, policy.ActionRead).Returns(mem) - })) s.Run("GetOrganizationMembershipsByUserID", s.Subtest(func(db database.Store, check *expects) { u := dbgen.User(s.T(), db, database.User{}) a := dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{UserID: u.ID}) @@ -658,6 +651,22 @@ func (s *MethodTestSuite) TestOrganization() { o.ID, ).Asserts(o, policy.ActionDelete) })) + s.Run("OrganizationMembers", s.Subtest(func(db database.Store, check *expects) { + o := dbgen.Organization(s.T(), db, database.Organization{}) + u := dbgen.User(s.T(), db, database.User{}) + mem := dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{ + OrganizationID: o.ID, + UserID: u.ID, + Roles: []string{rbac.RoleOrgAdmin()}, + }) + + check.Args(database.OrganizationMembersParams{ + OrganizationID: uuid.UUID{}, + UserID: uuid.UUID{}, + }).Asserts( + mem, policy.ActionRead, + ) + })) s.Run("UpdateMemberRoles", s.Subtest(func(db database.Store, check *expects) { o := dbgen.Organization(s.T(), db, database.Organization{}) u := dbgen.User(s.T(), db, database.User{}) @@ -673,11 +682,14 @@ func (s *MethodTestSuite) TestOrganization() { GrantedRoles: []string{}, UserID: u.ID, OrgID: o.ID, - }).Asserts( - mem, policy.ActionRead, - rbac.ResourceAssignRole.InOrg(o.ID), policy.ActionAssign, // org-mem - rbac.ResourceAssignRole.InOrg(o.ID), policy.ActionDelete, // org-admin - ).Returns(out) + }). + WithNotAuthorized(sql.ErrNoRows.Error()). + WithCancelled(sql.ErrNoRows.Error()). + Asserts( + mem, policy.ActionRead, + rbac.ResourceAssignRole.InOrg(o.ID), policy.ActionAssign, // org-mem + rbac.ResourceAssignRole.InOrg(o.ID), policy.ActionDelete, // org-admin + ).Returns(out) })) } diff --git a/coderd/database/dbauthz/setup_test.go b/coderd/database/dbauthz/setup_test.go index e391b9e2ef3c6..4df38a3ca4b98 100644 --- a/coderd/database/dbauthz/setup_test.go +++ b/coderd/database/dbauthz/setup_test.go @@ -157,7 +157,7 @@ func (s *MethodTestSuite) Subtest(testCaseF func(db database.Store, check *expec if len(testCase.assertions) > 0 { // Only run these tests if we know the underlying call makes // rbac assertions. - s.NotAuthorizedErrorTest(ctx, fakeAuthorizer, callMethod) + s.NotAuthorizedErrorTest(ctx, fakeAuthorizer, testCase, callMethod) } if len(testCase.assertions) > 0 || @@ -230,7 +230,7 @@ func (s *MethodTestSuite) NoActorErrorTest(callMethod func(ctx context.Context) // NotAuthorizedErrorTest runs the given method with an authorizer that will fail authz. // Asserts that the error returned is a NotAuthorizedError. -func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderdtest.FakeAuthorizer, callMethod func(ctx context.Context) ([]reflect.Value, error)) { +func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderdtest.FakeAuthorizer, testCase expects, callMethod func(ctx context.Context) ([]reflect.Value, error)) { s.Run("NotAuthorized", func() { az.AlwaysReturn = rbac.ForbiddenWithInternal(xerrors.New("Always fail authz"), rbac.Subject{}, "", rbac.Object{}, nil) @@ -242,9 +242,14 @@ func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderd // This is unfortunate, but if we are using `Filter` the error returned will be nil. So filter out // any case where the error is nil and the response is an empty slice. if err != nil || !hasEmptySliceResponse(resp) { - s.ErrorContainsf(err, "unauthorized", "error string should have a good message") - s.Errorf(err, "method should an error with disallow authz") - s.ErrorAs(err, &dbauthz.NotAuthorizedError{}, "error should be NotAuthorizedError") + // Expect the default error + if testCase.notAuthorizedExpect == "" { + s.ErrorContainsf(err, "unauthorized", "error string should have a good message") + s.Errorf(err, "method should an error with disallow authz") + s.ErrorAs(err, &dbauthz.NotAuthorizedError{}, "error should be NotAuthorizedError") + } else { + s.ErrorContains(err, testCase.notAuthorizedExpect) + } } }) @@ -263,8 +268,12 @@ func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderd // This is unfortunate, but if we are using `Filter` the error returned will be nil. So filter out // any case where the error is nil and the response is an empty slice. if err != nil || !hasEmptySliceResponse(resp) { - s.Errorf(err, "method should an error with cancellation") - s.ErrorIsf(err, context.Canceled, "error should match context.Canceled") + if testCase.cancelledCtxExpect == "" { + s.Errorf(err, "method should an error with cancellation") + s.ErrorIsf(err, context.Canceled, "error should match context.Canceled") + } else { + s.ErrorContains(err, testCase.cancelledCtxExpect) + } } }) } @@ -308,6 +317,13 @@ type expects struct { // outputs is optional. Can assert non-error return values. outputs []reflect.Value err error + + // Optional override of the default error checks. + // By default, we search for the expected error strings. + // If these strings are present, these strings will be searched + // instead. + notAuthorizedExpect string + cancelledCtxExpect string } // Asserts is required. Asserts the RBAC authorize calls that should be made. @@ -338,6 +354,16 @@ func (m *expects) Errors(err error) *expects { return m } +func (m *expects) WithNotAuthorized(contains string) *expects { + m.notAuthorizedExpect = contains + return m +} + +func (m *expects) WithCancelled(contains string) *expects { + m.cancelledCtxExpect = contains + return m +} + // AssertRBAC contains the object and actions to be asserted. type AssertRBAC struct { Object rbac.Object diff --git a/coderd/database/dbgen/dbgen_test.go b/coderd/database/dbgen/dbgen_test.go index eaf5a0e764482..2681f6eb1fece 100644 --- a/coderd/database/dbgen/dbgen_test.go +++ b/coderd/database/dbgen/dbgen_test.go @@ -119,10 +119,10 @@ func TestGenerator(t *testing.T) { t.Parallel() db := dbmem.New() exp := dbgen.OrganizationMember(t, db, database.OrganizationMember{}) - require.Equal(t, exp, must(db.GetOrganizationMemberByUserID(context.Background(), database.GetOrganizationMemberByUserIDParams{ + require.Equal(t, exp, must(database.ExpectOne(db.OrganizationMembers(context.Background(), database.OrganizationMembersParams{ OrganizationID: exp.OrganizationID, UserID: exp.UserID, - }))) + }))).OrganizationMember) }) t.Run("Workspace", func(t *testing.T) { diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index 55251f71227ca..f0d23fc9db829 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -2759,41 +2759,6 @@ func (q *FakeQuerier) GetOrganizationIDsByMemberIDs(_ context.Context, ids []uui return getOrganizationIDsByMemberIDRows, nil } -func (q *FakeQuerier) GetOrganizationMemberByUserID(_ context.Context, arg database.GetOrganizationMemberByUserIDParams) (database.OrganizationMember, error) { - if err := validateDatabaseType(arg); err != nil { - return database.OrganizationMember{}, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, organizationMember := range q.organizationMembers { - if organizationMember.OrganizationID != arg.OrganizationID { - continue - } - if organizationMember.UserID != arg.UserID { - continue - } - return organizationMember, nil - } - return database.OrganizationMember{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetOrganizationMembershipsByUserID(_ context.Context, userID uuid.UUID) ([]database.OrganizationMember, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - var memberships []database.OrganizationMember - for _, organizationMember := range q.organizationMembers { - mem := organizationMember - if mem.UserID != userID { - continue - } - memberships = append(memberships, mem) - } - return memberships, nil -} - func (q *FakeQuerier) GetOrganizations(_ context.Context) ([]database.Organization, error) { q.mutex.RLock() defer q.mutex.RUnlock() @@ -6963,6 +6928,34 @@ func (q *FakeQuerier) ListWorkspaceAgentPortShares(_ context.Context, workspaceI return shares, nil } +func (q *FakeQuerier) OrganizationMembers(ctx context.Context, arg database.OrganizationMembersParams) ([]database.OrganizationMembersRow, error) { + if err := validateDatabaseType(arg); err != nil { + return []database.OrganizationMembersRow{}, err + } + + q.mutex.RLock() + defer q.mutex.RUnlock() + + tmp := make([]database.OrganizationMembersRow, 0) + for _, organizationMember := range q.organizationMembers { + if arg.OrganizationID != uuid.Nil && organizationMember.OrganizationID != arg.OrganizationID { + continue + } + + if arg.UserID != uuid.Nil && organizationMember.UserID != arg.UserID { + continue + } + + organizationMember := organizationMember + user, _ := q.getUserByIDNoLock(organizationMember.UserID) + tmp = append(tmp, database.OrganizationMembersRow{ + OrganizationMember: organizationMember, + Username: user.Username, + }) + } + return tmp, nil +} + func (q *FakeQuerier) ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(_ context.Context, templateID uuid.UUID) error { err := validateDatabaseType(templateID) if err != nil { diff --git a/coderd/database/dbmetrics/dbmetrics.go b/coderd/database/dbmetrics/dbmetrics.go index aff562fcdb89f..1891fe6f999e9 100644 --- a/coderd/database/dbmetrics/dbmetrics.go +++ b/coderd/database/dbmetrics/dbmetrics.go @@ -760,20 +760,6 @@ func (m metricsStore) GetOrganizationIDsByMemberIDs(ctx context.Context, ids []u return organizations, err } -func (m metricsStore) GetOrganizationMemberByUserID(ctx context.Context, arg database.GetOrganizationMemberByUserIDParams) (database.OrganizationMember, error) { - start := time.Now() - member, err := m.s.GetOrganizationMemberByUserID(ctx, arg) - m.queryLatencies.WithLabelValues("GetOrganizationMemberByUserID").Observe(time.Since(start).Seconds()) - return member, err -} - -func (m metricsStore) GetOrganizationMembershipsByUserID(ctx context.Context, userID uuid.UUID) ([]database.OrganizationMember, error) { - start := time.Now() - memberships, err := m.s.GetOrganizationMembershipsByUserID(ctx, userID) - m.queryLatencies.WithLabelValues("GetOrganizationMembershipsByUserID").Observe(time.Since(start).Seconds()) - return memberships, err -} - func (m metricsStore) GetOrganizations(ctx context.Context) ([]database.Organization, error) { start := time.Now() organizations, err := m.s.GetOrganizations(ctx) @@ -1747,6 +1733,13 @@ func (m metricsStore) ListWorkspaceAgentPortShares(ctx context.Context, workspac return r0, r1 } +func (m metricsStore) OrganizationMembers(ctx context.Context, arg database.OrganizationMembersParams) ([]database.OrganizationMembersRow, error) { + start := time.Now() + r0, r1 := m.s.OrganizationMembers(ctx, arg) + m.queryLatencies.WithLabelValues("OrganizationMembers").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m metricsStore) ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error { start := time.Now() r0 := m.s.ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx, templateID) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 3ef96d13f8b33..b49d3e7f06c76 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -1514,36 +1514,6 @@ func (mr *MockStoreMockRecorder) GetOrganizationIDsByMemberIDs(arg0, arg1 any) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrganizationIDsByMemberIDs", reflect.TypeOf((*MockStore)(nil).GetOrganizationIDsByMemberIDs), arg0, arg1) } -// GetOrganizationMemberByUserID mocks base method. -func (m *MockStore) GetOrganizationMemberByUserID(arg0 context.Context, arg1 database.GetOrganizationMemberByUserIDParams) (database.OrganizationMember, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetOrganizationMemberByUserID", arg0, arg1) - ret0, _ := ret[0].(database.OrganizationMember) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetOrganizationMemberByUserID indicates an expected call of GetOrganizationMemberByUserID. -func (mr *MockStoreMockRecorder) GetOrganizationMemberByUserID(arg0, arg1 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrganizationMemberByUserID", reflect.TypeOf((*MockStore)(nil).GetOrganizationMemberByUserID), arg0, arg1) -} - -// GetOrganizationMembershipsByUserID mocks base method. -func (m *MockStore) GetOrganizationMembershipsByUserID(arg0 context.Context, arg1 uuid.UUID) ([]database.OrganizationMember, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetOrganizationMembershipsByUserID", arg0, arg1) - ret0, _ := ret[0].([]database.OrganizationMember) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetOrganizationMembershipsByUserID indicates an expected call of GetOrganizationMembershipsByUserID. -func (mr *MockStoreMockRecorder) GetOrganizationMembershipsByUserID(arg0, arg1 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrganizationMembershipsByUserID", reflect.TypeOf((*MockStore)(nil).GetOrganizationMembershipsByUserID), arg0, arg1) -} - // GetOrganizations mocks base method. func (m *MockStore) GetOrganizations(arg0 context.Context) ([]database.Organization, error) { m.ctrl.T.Helper() @@ -3661,6 +3631,21 @@ func (mr *MockStoreMockRecorder) ListWorkspaceAgentPortShares(arg0, arg1 any) *g return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListWorkspaceAgentPortShares", reflect.TypeOf((*MockStore)(nil).ListWorkspaceAgentPortShares), arg0, arg1) } +// OrganizationMembers mocks base method. +func (m *MockStore) OrganizationMembers(arg0 context.Context, arg1 database.OrganizationMembersParams) ([]database.OrganizationMembersRow, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OrganizationMembers", arg0, arg1) + ret0, _ := ret[0].([]database.OrganizationMembersRow) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// OrganizationMembers indicates an expected call of OrganizationMembers. +func (mr *MockStoreMockRecorder) OrganizationMembers(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OrganizationMembers", reflect.TypeOf((*MockStore)(nil).OrganizationMembers), arg0, arg1) +} + // Ping mocks base method. func (m *MockStore) Ping(arg0 context.Context) (time.Duration, error) { m.ctrl.T.Helper() diff --git a/coderd/database/modelmethods.go b/coderd/database/modelmethods.go index e5fd1db60337f..ee22ae1ad42ba 100644 --- a/coderd/database/modelmethods.go +++ b/coderd/database/modelmethods.go @@ -179,6 +179,10 @@ func (m OrganizationMember) RBACObject() rbac.Object { WithOwner(m.UserID.String()) } +func (m OrganizationMembersRow) RBACObject() rbac.Object { + return m.OrganizationMember.RBACObject() +} + func (m GetOrganizationIDsByMemberIDsRow) RBACObject() rbac.Object { // TODO: This feels incorrect as we are really returning a list of orgmembers. // This return type should be refactored to return a list of orgmembers, not this diff --git a/coderd/database/modelqueries.go b/coderd/database/modelqueries.go index ca38505b28ef0..9cc5d7792101c 100644 --- a/coderd/database/modelqueries.go +++ b/coderd/database/modelqueries.go @@ -2,6 +2,7 @@ package database import ( "context" + "database/sql" "fmt" "strings" @@ -17,6 +18,29 @@ const ( authorizedQueryPlaceholder = "-- @authorize_filter" ) +// ExpectOne can be used to convert a ':many:' query into a ':one' +// query. To reduce the quantity of SQL queries, a :many with a filter is used. +// These filters sometimes are expected to return just 1 row. +// +// A :many query will never return a sql.ErrNoRows, but a :one does. +// This function will correct the error for the empty set. +func ExpectOne[T any](ret []T, err error) (T, error) { + var empty T + if err != nil { + return empty, err + } + + if len(ret) == 0 { + return empty, sql.ErrNoRows + } + + if len(ret) > 1 { + return empty, xerrors.Errorf("too many rows returned, expected 1") + } + + return ret[0], nil +} + // customQuerier encompasses all non-generated queries. // It provides a flexible way to write queries for cases // where sqlc proves inadequate. diff --git a/coderd/database/models.go b/coderd/database/models.go index 8a558f5beeb0b..3cdb3b1a63c01 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.25.0 +// sqlc v1.26.0 package database diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 6e2b1ff60cfdf..a25f2dcf9b006 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.25.0 +// sqlc v1.26.0 package database @@ -151,8 +151,6 @@ type sqlcQuerier interface { GetOrganizationByID(ctx context.Context, id uuid.UUID) (Organization, error) GetOrganizationByName(ctx context.Context, name string) (Organization, error) GetOrganizationIDsByMemberIDs(ctx context.Context, ids []uuid.UUID) ([]GetOrganizationIDsByMemberIDsRow, error) - GetOrganizationMemberByUserID(ctx context.Context, arg GetOrganizationMemberByUserIDParams) (OrganizationMember, error) - GetOrganizationMembershipsByUserID(ctx context.Context, userID uuid.UUID) ([]OrganizationMember, error) GetOrganizations(ctx context.Context) ([]Organization, error) GetOrganizationsByUserID(ctx context.Context, userID uuid.UUID) ([]Organization, error) GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]ParameterSchema, error) @@ -349,6 +347,7 @@ type sqlcQuerier interface { InsertWorkspaceResource(ctx context.Context, arg InsertWorkspaceResourceParams) (WorkspaceResource, error) InsertWorkspaceResourceMetadata(ctx context.Context, arg InsertWorkspaceResourceMetadataParams) ([]WorkspaceResourceMetadatum, error) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]WorkspaceAgentPortShare, error) + OrganizationMembers(ctx context.Context, arg OrganizationMembersParams) ([]OrganizationMembersRow, error) ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error RegisterWorkspaceProxy(ctx context.Context, arg RegisterWorkspaceProxyParams) (WorkspaceProxy, error) RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) error diff --git a/coderd/database/querier_test.go b/coderd/database/querier_test.go index 0d523c25290e2..22004e6fab71c 100644 --- a/coderd/database/querier_test.go +++ b/coderd/database/querier_test.go @@ -903,6 +903,42 @@ func TestArchiveVersions(t *testing.T) { }) } +func TestExpectOne(t *testing.T) { + t.Parallel() + if testing.Short() { + t.SkipNow() + } + + t.Run("ErrNoRows", func(t *testing.T) { + t.Parallel() + sqlDB := testSQLDB(t) + err := migrations.Up(sqlDB) + require.NoError(t, err) + db := database.New(sqlDB) + ctx := context.Background() + + _, err = database.ExpectOne(db.GetUsers(ctx, database.GetUsersParams{})) + require.ErrorIs(t, err, sql.ErrNoRows) + }) + + t.Run("TooMany", func(t *testing.T) { + t.Parallel() + sqlDB := testSQLDB(t) + err := migrations.Up(sqlDB) + require.NoError(t, err) + db := database.New(sqlDB) + ctx := context.Background() + + // Create 2 organizations so the query returns >1 + dbgen.Organization(t, db, database.Organization{}) + dbgen.Organization(t, db, database.Organization{}) + + // Organizations is an easy table without foreign key dependencies + _, err = database.ExpectOne(db.GetOrganizations(ctx)) + require.ErrorContains(t, err, "too many rows returned") + }) +} + func requireUsersMatch(t testing.TB, expected []database.User, found []database.GetUsersRow, msg string) { t.Helper() require.ElementsMatch(t, expected, database.ConvertUserRows(found), msg) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 823cf2cc45796..773f6ab613113 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.25.0 +// sqlc v1.26.0 package database @@ -3795,25 +3795,35 @@ func (q *sqlQuerier) GetOrganizationIDsByMemberIDs(ctx context.Context, ids []uu return items, nil } -const getOrganizationMemberByUserID = `-- name: GetOrganizationMemberByUserID :one -SELECT - user_id, organization_id, created_at, updated_at, roles -FROM - organization_members -WHERE - organization_id = $1 - AND user_id = $2 -LIMIT - 1 +const insertOrganizationMember = `-- name: InsertOrganizationMember :one +INSERT INTO + organization_members ( + organization_id, + user_id, + created_at, + updated_at, + roles + ) +VALUES + ($1, $2, $3, $4, $5) RETURNING user_id, organization_id, created_at, updated_at, roles ` -type GetOrganizationMemberByUserIDParams struct { +type InsertOrganizationMemberParams struct { OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"` UserID uuid.UUID `db:"user_id" json:"user_id"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + Roles []string `db:"roles" json:"roles"` } -func (q *sqlQuerier) GetOrganizationMemberByUserID(ctx context.Context, arg GetOrganizationMemberByUserIDParams) (OrganizationMember, error) { - row := q.db.QueryRowContext(ctx, getOrganizationMemberByUserID, arg.OrganizationID, arg.UserID) +func (q *sqlQuerier) InsertOrganizationMember(ctx context.Context, arg InsertOrganizationMemberParams) (OrganizationMember, error) { + row := q.db.QueryRowContext(ctx, insertOrganizationMember, + arg.OrganizationID, + arg.UserID, + arg.CreatedAt, + arg.UpdatedAt, + pq.Array(arg.Roles), + ) var i OrganizationMember err := row.Scan( &i.UserID, @@ -3825,30 +3835,56 @@ func (q *sqlQuerier) GetOrganizationMemberByUserID(ctx context.Context, arg GetO return i, err } -const getOrganizationMembershipsByUserID = `-- name: GetOrganizationMembershipsByUserID :many +const organizationMembers = `-- name: OrganizationMembers :many SELECT - user_id, organization_id, created_at, updated_at, roles + organization_members.user_id, organization_members.organization_id, organization_members.created_at, organization_members.updated_at, organization_members.roles, + users.username FROM organization_members -WHERE - user_id = $1 + INNER JOIN + users ON organization_members.user_id = users.id +WHERE + true + -- Filter by organization id + AND CASE + WHEN $1 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + organization_id = $1 + ELSE true + END + -- Filter by user id + AND CASE + WHEN $2 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + user_id = $2 + ELSE true + END ` -func (q *sqlQuerier) GetOrganizationMembershipsByUserID(ctx context.Context, userID uuid.UUID) ([]OrganizationMember, error) { - rows, err := q.db.QueryContext(ctx, getOrganizationMembershipsByUserID, userID) +type OrganizationMembersParams struct { + OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"` + UserID uuid.UUID `db:"user_id" json:"user_id"` +} + +type OrganizationMembersRow struct { + OrganizationMember OrganizationMember `db:"organization_member" json:"organization_member"` + Username string `db:"username" json:"username"` +} + +func (q *sqlQuerier) OrganizationMembers(ctx context.Context, arg OrganizationMembersParams) ([]OrganizationMembersRow, error) { + rows, err := q.db.QueryContext(ctx, organizationMembers, arg.OrganizationID, arg.UserID) if err != nil { return nil, err } defer rows.Close() - var items []OrganizationMember + var items []OrganizationMembersRow for rows.Next() { - var i OrganizationMember + var i OrganizationMembersRow if err := rows.Scan( - &i.UserID, - &i.OrganizationID, - &i.CreatedAt, - &i.UpdatedAt, - pq.Array(&i.Roles), + &i.OrganizationMember.UserID, + &i.OrganizationMember.OrganizationID, + &i.OrganizationMember.CreatedAt, + &i.OrganizationMember.UpdatedAt, + pq.Array(&i.OrganizationMember.Roles), + &i.Username, ); err != nil { return nil, err } @@ -3863,46 +3899,6 @@ func (q *sqlQuerier) GetOrganizationMembershipsByUserID(ctx context.Context, use return items, nil } -const insertOrganizationMember = `-- name: InsertOrganizationMember :one -INSERT INTO - organization_members ( - organization_id, - user_id, - created_at, - updated_at, - roles - ) -VALUES - ($1, $2, $3, $4, $5) RETURNING user_id, organization_id, created_at, updated_at, roles -` - -type InsertOrganizationMemberParams struct { - OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"` - UserID uuid.UUID `db:"user_id" json:"user_id"` - CreatedAt time.Time `db:"created_at" json:"created_at"` - UpdatedAt time.Time `db:"updated_at" json:"updated_at"` - Roles []string `db:"roles" json:"roles"` -} - -func (q *sqlQuerier) InsertOrganizationMember(ctx context.Context, arg InsertOrganizationMemberParams) (OrganizationMember, error) { - row := q.db.QueryRowContext(ctx, insertOrganizationMember, - arg.OrganizationID, - arg.UserID, - arg.CreatedAt, - arg.UpdatedAt, - pq.Array(arg.Roles), - ) - var i OrganizationMember - err := row.Scan( - &i.UserID, - &i.OrganizationID, - &i.CreatedAt, - &i.UpdatedAt, - pq.Array(&i.Roles), - ) - return i, err -} - const updateMemberRoles = `-- name: UpdateMemberRoles :one UPDATE organization_members diff --git a/coderd/database/queries/organizationmembers.sql b/coderd/database/queries/organizationmembers.sql index 10a45d25eb2c5..d2a269cf4d8d5 100644 --- a/coderd/database/queries/organizationmembers.sql +++ b/coderd/database/queries/organizationmembers.sql @@ -1,13 +1,25 @@ --- name: GetOrganizationMemberByUserID :one +-- name: OrganizationMembers :many SELECT - * + sqlc.embed(organization_members), + users.username FROM organization_members + INNER JOIN + users ON organization_members.user_id = users.id WHERE - organization_id = $1 - AND user_id = $2 -LIMIT - 1; + true + -- Filter by organization id + AND CASE + WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + organization_id = @organization_id + ELSE true + END + -- Filter by user id + AND CASE + WHEN @user_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + user_id = @user_id + ELSE true + END; -- name: InsertOrganizationMember :one INSERT INTO @@ -22,14 +34,6 @@ VALUES ($1, $2, $3, $4, $5) RETURNING *; --- name: GetOrganizationMembershipsByUserID :many -SELECT - * -FROM - organization_members -WHERE - user_id = $1; - -- name: GetOrganizationIDsByMemberIDs :many SELECT user_id, array_agg(organization_id) :: uuid [ ] AS "organization_IDs" diff --git a/coderd/httpmw/organizationparam.go b/coderd/httpmw/organizationparam.go index 0c8ccae96c519..9ec0af6a460cf 100644 --- a/coderd/httpmw/organizationparam.go +++ b/coderd/httpmw/organizationparam.go @@ -124,10 +124,10 @@ func ExtractOrganizationMemberParam(db database.Store) func(http.Handler) http.H } organization := OrganizationParam(r) - organizationMember, err := db.GetOrganizationMemberByUserID(ctx, database.GetOrganizationMemberByUserIDParams{ + organizationMember, err := database.ExpectOne(db.OrganizationMembers(ctx, database.OrganizationMembersParams{ OrganizationID: organization.ID, UserID: user.ID, - }) + })) if httpapi.Is404Error(err) { httpapi.ResourceNotFound(rw) return @@ -141,7 +141,7 @@ func ExtractOrganizationMemberParam(db database.Store) func(http.Handler) http.H } ctx = context.WithValue(ctx, organizationMemberParamContextKey{}, OrganizationMember{ - OrganizationMember: organizationMember, + OrganizationMember: organizationMember.OrganizationMember, // Here we're making two exceptions to the rule about not leaking data about the user // to the API handler, which is to include the username and avatar URL. // If the caller has permission to read the OrganizationMember, then we're explicitly diff --git a/coderd/userauth.go b/coderd/userauth.go index 306982b29c9ab..7bf243f3b708b 100644 --- a/coderd/userauth.go +++ b/coderd/userauth.go @@ -1518,15 +1518,17 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C } //nolint:gocritic // No user present in the context. - memberships, err := tx.GetOrganizationMembershipsByUserID(dbauthz.AsSystemRestricted(ctx), user.ID) + memberships, err := tx.OrganizationMembers(dbauthz.AsSystemRestricted(ctx), database.OrganizationMembersParams{ + UserID: user.ID, + }) if err != nil { return xerrors.Errorf("get organization memberships: %w", err) } // If the user is not in the default organization, then we can't assign groups. // A user cannot be in groups to an org they are not a member of. - if !slices.ContainsFunc(memberships, func(member database.OrganizationMember) bool { - return member.OrganizationID == defaultOrganization.ID + if !slices.ContainsFunc(memberships, func(member database.OrganizationMembersRow) bool { + return member.OrganizationMember.OrganizationID == defaultOrganization.ID }) { return xerrors.Errorf("user %s is not a member of the default organization, cannot assign to groups in the org", user.ID) } diff --git a/coderd/users.go b/coderd/users.go index 1e375232b48e7..7d3b6181e171a 100644 --- a/coderd/users.go +++ b/coderd/users.go @@ -1027,12 +1027,15 @@ func (api *API) userRoles(rw http.ResponseWriter, r *http.Request) { return } + // TODO: Replace this with "GetAuthorizationUserRoles" resp := codersdk.UserRoles{ Roles: user.RBACRoles, OrganizationRoles: make(map[uuid.UUID][]string), } - memberships, err := api.Database.GetOrganizationMembershipsByUserID(ctx, user.ID) + memberships, err := api.Database.OrganizationMembers(ctx, database.OrganizationMembersParams{ + UserID: user.ID, + }) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching user's organization memberships.", @@ -1042,7 +1045,7 @@ func (api *API) userRoles(rw http.ResponseWriter, r *http.Request) { } for _, mem := range memberships { - resp.OrganizationRoles[mem.OrganizationID] = mem.Roles + resp.OrganizationRoles[mem.OrganizationMember.OrganizationID] = mem.OrganizationMember.Roles } httpapi.Write(ctx, rw, http.StatusOK, resp) diff --git a/enterprise/coderd/groups.go b/enterprise/coderd/groups.go index dea135f683fb8..65220e5cbabf7 100644 --- a/enterprise/coderd/groups.go +++ b/enterprise/coderd/groups.go @@ -166,10 +166,10 @@ func (api *API) patchGroup(rw http.ResponseWriter, r *http.Request) { } // TODO: It would be nice to enforce this at the schema level // but unfortunately our org_members table does not have an ID. - _, err := api.Database.GetOrganizationMemberByUserID(ctx, database.GetOrganizationMemberByUserIDParams{ + _, err := database.ExpectOne(api.Database.OrganizationMembers(ctx, database.OrganizationMembersParams{ OrganizationID: group.OrganizationID, UserID: uuid.MustParse(id), - }) + })) if xerrors.Is(err, sql.ErrNoRows) { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ Message: fmt.Sprintf("User %q must be a member of organization %q", id, group.ID), From 43634925420b93daa8da8b2d279fd144b593055e Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 11 Jun 2024 10:27:25 -0500 Subject: [PATCH 2/4] fixup dbauthz test --- coderd/database/dbauthz/dbauthz_test.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index ce04f78ac16f8..44d45118ce1ea 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -596,12 +596,6 @@ func (s *MethodTestSuite) TestOrganization() { check.Args([]uuid.UUID{ma.UserID, mb.UserID}). Asserts(rbac.ResourceUserObject(ma.UserID), policy.ActionRead, rbac.ResourceUserObject(mb.UserID), policy.ActionRead) })) - s.Run("GetOrganizationMembershipsByUserID", s.Subtest(func(db database.Store, check *expects) { - u := dbgen.User(s.T(), db, database.User{}) - a := dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{UserID: u.ID}) - b := dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{UserID: u.ID}) - check.Args(u.ID).Asserts(a, policy.ActionRead, b, policy.ActionRead).Returns(slice.New(a, b)) - })) s.Run("GetOrganizations", s.Subtest(func(db database.Store, check *expects) { def, _ := db.GetDefaultOrganization(context.Background()) a := dbgen.Organization(s.T(), db, database.Organization{}) From 8b1bc7fe53c013114531e4a8aa887661dcf27093 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 11 Jun 2024 10:32:17 -0500 Subject: [PATCH 3/4] linting --- coderd/database/dbmem/dbmem.go | 2 +- coderd/database/models.go | 2 +- coderd/database/querier.go | 2 +- coderd/database/queries.sql.go | 2 +- coderd/userauth.go | 3 ++- coderd/users.go | 3 ++- 6 files changed, 8 insertions(+), 6 deletions(-) diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index f0d23fc9db829..3e8a3012a177e 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -6928,7 +6928,7 @@ func (q *FakeQuerier) ListWorkspaceAgentPortShares(_ context.Context, workspaceI return shares, nil } -func (q *FakeQuerier) OrganizationMembers(ctx context.Context, arg database.OrganizationMembersParams) ([]database.OrganizationMembersRow, error) { +func (q *FakeQuerier) OrganizationMembers(_ context.Context, arg database.OrganizationMembersParams) ([]database.OrganizationMembersRow, error) { if err := validateDatabaseType(arg); err != nil { return []database.OrganizationMembersRow{}, err } diff --git a/coderd/database/models.go b/coderd/database/models.go index 3cdb3b1a63c01..8a558f5beeb0b 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.26.0 +// sqlc v1.25.0 package database diff --git a/coderd/database/querier.go b/coderd/database/querier.go index a25f2dcf9b006..c4388236134ae 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.26.0 +// sqlc v1.25.0 package database diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 773f6ab613113..675c8199c441b 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.26.0 +// sqlc v1.25.0 package database diff --git a/coderd/userauth.go b/coderd/userauth.go index 7bf243f3b708b..b9d163a6afdac 100644 --- a/coderd/userauth.go +++ b/coderd/userauth.go @@ -1519,7 +1519,8 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C //nolint:gocritic // No user present in the context. memberships, err := tx.OrganizationMembers(dbauthz.AsSystemRestricted(ctx), database.OrganizationMembersParams{ - UserID: user.ID, + UserID: user.ID, + OrganizationID: uuid.Nil, }) if err != nil { return xerrors.Errorf("get organization memberships: %w", err) diff --git a/coderd/users.go b/coderd/users.go index 7d3b6181e171a..b8a3306b12121 100644 --- a/coderd/users.go +++ b/coderd/users.go @@ -1034,7 +1034,8 @@ func (api *API) userRoles(rw http.ResponseWriter, r *http.Request) { } memberships, err := api.Database.OrganizationMembers(ctx, database.OrganizationMembersParams{ - UserID: user.ID, + UserID: user.ID, + OrganizationID: uuid.Nil, }) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ From 2542dcdad468af692cc19f37835ed5f111d05238 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 12 Jun 2024 09:25:09 -0500 Subject: [PATCH 4/4] add comment --- coderd/database/querier.go | 4 +++ coderd/database/queries.sql.go | 25 +++++++++++-------- .../database/queries/organizationmembers.sql | 25 +++++++++++-------- 3 files changed, 32 insertions(+), 22 deletions(-) diff --git a/coderd/database/querier.go b/coderd/database/querier.go index c4388236134ae..f87e6015b517e 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -347,6 +347,10 @@ type sqlcQuerier interface { InsertWorkspaceResource(ctx context.Context, arg InsertWorkspaceResourceParams) (WorkspaceResource, error) InsertWorkspaceResourceMetadata(ctx context.Context, arg InsertWorkspaceResourceMetadataParams) ([]WorkspaceResourceMetadatum, error) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]WorkspaceAgentPortShare, error) + // Arguments are optional with uuid.Nil to ignore. + // - Use just 'organization_id' to get all members of an org + // - Use just 'user_id' to get all orgs a user is a member of + // - Use both to get a specific org member row OrganizationMembers(ctx context.Context, arg OrganizationMembersParams) ([]OrganizationMembersRow, error) ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error RegisterWorkspaceProxy(ctx context.Context, arg RegisterWorkspaceProxyParams) (WorkspaceProxy, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 675c8199c441b..d9d13e4598c04 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -3844,18 +3844,17 @@ FROM INNER JOIN users ON organization_members.user_id = users.id WHERE - true - -- Filter by organization id - AND CASE - WHEN $1 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN - organization_id = $1 - ELSE true + -- Filter by organization id + CASE + WHEN $1 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + organization_id = $1 + ELSE true END - -- Filter by user id - AND CASE - WHEN $2 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN - user_id = $2 - ELSE true + -- Filter by user id + AND CASE + WHEN $2 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + user_id = $2 + ELSE true END ` @@ -3869,6 +3868,10 @@ type OrganizationMembersRow struct { Username string `db:"username" json:"username"` } +// Arguments are optional with uuid.Nil to ignore. +// - Use just 'organization_id' to get all members of an org +// - Use just 'user_id' to get all orgs a user is a member of +// - Use both to get a specific org member row func (q *sqlQuerier) OrganizationMembers(ctx context.Context, arg OrganizationMembersParams) ([]OrganizationMembersRow, error) { rows, err := q.db.QueryContext(ctx, organizationMembers, arg.OrganizationID, arg.UserID) if err != nil { diff --git a/coderd/database/queries/organizationmembers.sql b/coderd/database/queries/organizationmembers.sql index d2a269cf4d8d5..d32d9a8e8abc8 100644 --- a/coderd/database/queries/organizationmembers.sql +++ b/coderd/database/queries/organizationmembers.sql @@ -1,4 +1,8 @@ -- name: OrganizationMembers :many +-- Arguments are optional with uuid.Nil to ignore. +-- - Use just 'organization_id' to get all members of an org +-- - Use just 'user_id' to get all orgs a user is a member of +-- - Use both to get a specific org member row SELECT sqlc.embed(organization_members), users.username @@ -7,18 +11,17 @@ FROM INNER JOIN users ON organization_members.user_id = users.id WHERE - true - -- Filter by organization id - AND CASE - WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN - organization_id = @organization_id - ELSE true + -- Filter by organization id + CASE + WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + organization_id = @organization_id + ELSE true END - -- Filter by user id - AND CASE - WHEN @user_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN - user_id = @user_id - ELSE true + -- Filter by user id + AND CASE + WHEN @user_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + user_id = @user_id + ELSE true END; -- name: InsertOrganizationMember :one