diff --git a/cli/server_createadminuser.go b/cli/server_createadminuser.go index 01eb56a83b7e8..0619688468554 100644 --- a/cli/server_createadminuser.go +++ b/cli/server_createadminuser.go @@ -176,7 +176,7 @@ func (r *RootCmd) newCreateAdminUserCommand() *serpent.Command { // Create the user. var newUser database.User err = db.InTx(func(tx database.Store) error { - orgs, err := tx.GetOrganizations(ctx) + orgs, err := tx.GetOrganizations(ctx, database.GetOrganizationsParams{}) if err != nil { return xerrors.Errorf("get organizations: %w", err) } diff --git a/cli/server_createadminuser_test.go b/cli/server_createadminuser_test.go index 6e3939ea298d6..17c02b6548c09 100644 --- a/cli/server_createadminuser_test.go +++ b/cli/server_createadminuser_test.go @@ -60,7 +60,7 @@ func TestServerCreateAdminUser(t *testing.T) { require.EqualValues(t, []string{codersdk.RoleOwner}, user.RBACRoles, "user does not have owner role") // Check that user is admin in every org. - orgs, err := db.GetOrganizations(ctx) + orgs, err := db.GetOrganizations(ctx, database.GetOrganizationsParams{}) require.NoError(t, err) orgIDs := make(map[uuid.UUID]struct{}, len(orgs)) for _, org := range orgs { diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index c040df06196ec..8dd903a2a9137 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -1700,9 +1700,9 @@ func (q *querier) GetOrganizationIDsByMemberIDs(ctx context.Context, ids []uuid. return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetOrganizationIDsByMemberIDs)(ctx, ids) } -func (q *querier) GetOrganizations(ctx context.Context) ([]database.Organization, error) { +func (q *querier) GetOrganizations(ctx context.Context, args database.GetOrganizationsParams) ([]database.Organization, error) { fetch := func(ctx context.Context, _ interface{}) ([]database.Organization, error) { - return q.db.GetOrganizations(ctx) + return q.db.GetOrganizations(ctx, args) } return fetchWithPostFilter(q.auth, policy.ActionRead, fetch)(ctx, nil) } diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index a270e4d2cb0d1..96cfca85f19d6 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -635,7 +635,7 @@ func (s *MethodTestSuite) TestOrganization() { def, _ := db.GetDefaultOrganization(context.Background()) a := dbgen.Organization(s.T(), db, database.Organization{}) b := dbgen.Organization(s.T(), db, database.Organization{}) - check.Args().Asserts(def, policy.ActionRead, a, policy.ActionRead, b, policy.ActionRead).Returns(slice.New(def, a, b)) + check.Args(database.GetOrganizationsParams{}).Asserts(def, policy.ActionRead, a, policy.ActionRead, b, policy.ActionRead).Returns(slice.New(def, a, b)) })) s.Run("GetOrganizationsByUserID", s.Subtest(func(db database.Store, check *expects) { u := dbgen.User(s.T(), db, database.User{}) diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index 42fdd2b93f63e..161a7d64b5aaa 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -3034,14 +3034,24 @@ func (q *FakeQuerier) GetOrganizationIDsByMemberIDs(_ context.Context, ids []uui return getOrganizationIDsByMemberIDRows, nil } -func (q *FakeQuerier) GetOrganizations(_ context.Context) ([]database.Organization, error) { +func (q *FakeQuerier) GetOrganizations(_ context.Context, args database.GetOrganizationsParams) ([]database.Organization, error) { q.mutex.RLock() defer q.mutex.RUnlock() - if len(q.organizations) == 0 { - return nil, sql.ErrNoRows + tmp := make([]database.Organization, 0) + for _, org := range q.organizations { + if len(args.IDs) > 0 { + if !slices.Contains(args.IDs, org.ID) { + continue + } + } + if args.Name != "" && !strings.EqualFold(org.Name, args.Name) { + continue + } + tmp = append(tmp, org) } - return q.organizations, nil + + return tmp, nil } func (q *FakeQuerier) GetOrganizationsByUserID(_ context.Context, userID uuid.UUID) ([]database.Organization, error) { @@ -3060,9 +3070,7 @@ func (q *FakeQuerier) GetOrganizationsByUserID(_ context.Context, userID uuid.UU organizations = append(organizations, organization) } } - if len(organizations) == 0 { - return nil, sql.ErrNoRows - } + return organizations, nil } diff --git a/coderd/database/dbmem/dbmem_test.go b/coderd/database/dbmem/dbmem_test.go index e7d7bd76bd132..11d30e61a895d 100644 --- a/coderd/database/dbmem/dbmem_test.go +++ b/coderd/database/dbmem/dbmem_test.go @@ -46,7 +46,7 @@ func TestInTx(t *testing.T) { go func() { <-inTx for i := 0; i < 20; i++ { - orgs, err := uut.GetOrganizations(context.Background()) + orgs, err := uut.GetOrganizations(context.Background(), database.GetOrganizationsParams{}) if err != nil { assert.ErrorIs(t, err, sql.ErrNoRows) } diff --git a/coderd/database/dbmetrics/dbmetrics.go b/coderd/database/dbmetrics/dbmetrics.go index 2215a45a6fc4b..ee619e3e7a813 100644 --- a/coderd/database/dbmetrics/dbmetrics.go +++ b/coderd/database/dbmetrics/dbmetrics.go @@ -865,9 +865,9 @@ func (m metricsStore) GetOrganizationIDsByMemberIDs(ctx context.Context, ids []u return organizations, err } -func (m metricsStore) GetOrganizations(ctx context.Context) ([]database.Organization, error) { +func (m metricsStore) GetOrganizations(ctx context.Context, args database.GetOrganizationsParams) ([]database.Organization, error) { start := time.Now() - organizations, err := m.s.GetOrganizations(ctx) + organizations, err := m.s.GetOrganizations(ctx, args) m.queryLatencies.WithLabelValues("GetOrganizations").Observe(time.Since(start).Seconds()) return organizations, err } diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 66f15a5e8c99d..51832321ef742 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -1750,18 +1750,18 @@ func (mr *MockStoreMockRecorder) GetOrganizationIDsByMemberIDs(arg0, arg1 any) * } // GetOrganizations mocks base method. -func (m *MockStore) GetOrganizations(arg0 context.Context) ([]database.Organization, error) { +func (m *MockStore) GetOrganizations(arg0 context.Context, arg1 database.GetOrganizationsParams) ([]database.Organization, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetOrganizations", arg0) + ret := m.ctrl.Call(m, "GetOrganizations", arg0, arg1) ret0, _ := ret[0].([]database.Organization) ret1, _ := ret[1].(error) return ret0, ret1 } // GetOrganizations indicates an expected call of GetOrganizations. -func (mr *MockStoreMockRecorder) GetOrganizations(arg0 any) *gomock.Call { +func (mr *MockStoreMockRecorder) GetOrganizations(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrganizations", reflect.TypeOf((*MockStore)(nil).GetOrganizations), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrganizations", reflect.TypeOf((*MockStore)(nil).GetOrganizations), arg0, arg1) } // GetOrganizationsByUserID mocks base method. diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 5e662815a7780..a159735cbbf69 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -180,7 +180,7 @@ type sqlcQuerier interface { GetOrganizationByID(ctx context.Context, id uuid.UUID) (Organization, error) GetOrganizationByName(ctx context.Context, name string) (Organization, error) GetOrganizationIDsByMemberIDs(ctx context.Context, ids []uuid.UUID) ([]GetOrganizationIDsByMemberIDsRow, error) - GetOrganizations(ctx context.Context) ([]Organization, error) + GetOrganizations(ctx context.Context, arg GetOrganizationsParams) ([]Organization, error) GetOrganizationsByUserID(ctx context.Context, userID uuid.UUID) ([]Organization, error) GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]ParameterSchema, error) GetPreviousTemplateVersion(ctx context.Context, arg GetPreviousTemplateVersionParams) (TemplateVersion, error) diff --git a/coderd/database/querier_test.go b/coderd/database/querier_test.go index 3ae9c89444aa0..7b7fd8b0a2823 100644 --- a/coderd/database/querier_test.go +++ b/coderd/database/querier_test.go @@ -516,7 +516,7 @@ func TestDefaultOrg(t *testing.T) { ctx := context.Background() // Should start with the default org - all, err := db.GetOrganizations(ctx) + all, err := db.GetOrganizations(ctx, database.GetOrganizationsParams{}) require.NoError(t, err) require.Len(t, all, 1) require.True(t, all[0].IsDefault, "first org should always be default") @@ -1211,7 +1211,7 @@ func TestExpectOne(t *testing.T) { dbgen.Organization(t, db, database.Organization{}) // Organizations is an easy table without foreign key dependencies - _, err = database.ExpectOne(db.GetOrganizations(ctx)) + _, err = database.ExpectOne(db.GetOrganizations(ctx, database.GetOrganizationsParams{})) require.ErrorContains(t, err, "too many rows returned") }) } diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 7a25b7f82533b..a8049c36a89f7 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -4596,10 +4596,28 @@ SELECT id, name, description, created_at, updated_at, is_default, display_name, icon FROM organizations +WHERE + true + -- Filter by ids + AND CASE + WHEN array_length($1 :: uuid[], 1) > 0 THEN + id = ANY($1) + ELSE true + END + AND CASE + WHEN $2::text != '' THEN + LOWER("name") = LOWER($2) + ELSE true + END ` -func (q *sqlQuerier) GetOrganizations(ctx context.Context) ([]Organization, error) { - rows, err := q.db.QueryContext(ctx, getOrganizations) +type GetOrganizationsParams struct { + IDs []uuid.UUID `db:"ids" json:"ids"` + Name string `db:"name" json:"name"` +} + +func (q *sqlQuerier) GetOrganizations(ctx context.Context, arg GetOrganizationsParams) ([]Organization, error) { + rows, err := q.db.QueryContext(ctx, getOrganizations, pq.Array(arg.IDs), arg.Name) if err != nil { return nil, err } diff --git a/coderd/database/queries/organizations.sql b/coderd/database/queries/organizations.sql index 787985c3bdbbc..3a74170a913e1 100644 --- a/coderd/database/queries/organizations.sql +++ b/coderd/database/queries/organizations.sql @@ -12,7 +12,21 @@ LIMIT SELECT * FROM - organizations; + organizations +WHERE + true + -- Filter by ids + AND CASE + WHEN array_length(@ids :: uuid[], 1) > 0 THEN + id = ANY(@ids) + ELSE true + END + AND CASE + WHEN @name::text != '' THEN + LOWER("name") = LOWER(@name) + ELSE true + END +; -- name: GetOrganizationByID :one SELECT diff --git a/coderd/organizations.go b/coderd/organizations.go index 2acd3fe401a89..5f05099507b7c 100644 --- a/coderd/organizations.go +++ b/coderd/organizations.go @@ -3,6 +3,7 @@ package coderd import ( "net/http" + "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/db2sdk" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpmw" @@ -18,7 +19,7 @@ import ( // @Router /organizations [get] func (api *API) organizations(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() - organizations, err := api.Database.GetOrganizations(ctx) + organizations, err := api.Database.GetOrganizations(ctx, database.GetOrganizationsParams{}) if httpapi.Is404Error(err) { httpapi.ResourceNotFound(rw) return