diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index a0da90eb52f23..a6c6b34f2dafa 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -1016,6 +1016,12 @@ func (q *querier) GetDERPMeshKey(ctx context.Context) (string, error) { return q.db.GetDERPMeshKey(ctx) } +func (q *querier) GetDefaultOrganization(ctx context.Context) (database.Organization, error) { + return fetch(q.log, q.auth, func(ctx context.Context, _ any) (database.Organization, error) { + return q.db.GetDefaultOrganization(ctx) + })(ctx, nil) +} + func (q *querier) GetDefaultProxyConfig(ctx context.Context) (database.GetDefaultProxyConfigRow, error) { // No authz checks return q.db.GetDefaultProxyConfig(ctx) diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 56b6012ba2193..c55b55a3d164d 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -570,6 +570,10 @@ func (s *MethodTestSuite) TestOrganization() { o := dbgen.Organization(s.T(), db, database.Organization{}) check.Args(o.ID).Asserts(o, rbac.ActionRead).Returns(o) })) + s.Run("GetDefaultOrganization", s.Subtest(func(db database.Store, check *expects) { + o := dbgen.Organization(s.T(), db, database.Organization{}) + check.Args().Asserts(o, rbac.ActionRead).Returns(o) + })) s.Run("GetOrganizationByName", s.Subtest(func(db database.Store, check *expects) { o := dbgen.Organization(s.T(), db, database.Organization{}) check.Args(o.Name).Asserts(o, rbac.ActionRead).Returns(o) diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index 6f7dba0999b92..ae0a0d7e48d33 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -1657,6 +1657,18 @@ func (q *FakeQuerier) GetDERPMeshKey(_ context.Context) (string, error) { return q.derpMeshKey, nil } +func (q *FakeQuerier) GetDefaultOrganization(_ context.Context) (database.Organization, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + for _, org := range q.organizations { + if org.IsDefault { + return org, nil + } + } + return database.Organization{}, sql.ErrNoRows +} + func (q *FakeQuerier) GetDefaultProxyConfig(_ context.Context) (database.GetDefaultProxyConfigRow, error) { return database.GetDefaultProxyConfigRow{ DisplayName: q.defaultProxyDisplayName, diff --git a/coderd/database/dbmetrics/dbmetrics.go b/coderd/database/dbmetrics/dbmetrics.go index 948de3c763277..b07b7b0305d9c 100644 --- a/coderd/database/dbmetrics/dbmetrics.go +++ b/coderd/database/dbmetrics/dbmetrics.go @@ -433,6 +433,13 @@ func (m metricsStore) GetDERPMeshKey(ctx context.Context) (string, error) { return key, err } +func (m metricsStore) GetDefaultOrganization(ctx context.Context) (database.Organization, error) { + start := time.Now() + r0, r1 := m.s.GetDefaultOrganization(ctx) + m.queryLatencies.WithLabelValues("GetDefaultOrganization").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m metricsStore) GetDefaultProxyConfig(ctx context.Context) (database.GetDefaultProxyConfigRow, error) { start := time.Now() resp, err := m.s.GetDefaultProxyConfig(ctx) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index d767fd7cf5bd7..cbe91468c2a6d 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -828,6 +828,21 @@ func (mr *MockStoreMockRecorder) GetDERPMeshKey(arg0 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDERPMeshKey", reflect.TypeOf((*MockStore)(nil).GetDERPMeshKey), arg0) } +// GetDefaultOrganization mocks base method. +func (m *MockStore) GetDefaultOrganization(arg0 context.Context) (database.Organization, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetDefaultOrganization", arg0) + ret0, _ := ret[0].(database.Organization) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetDefaultOrganization indicates an expected call of GetDefaultOrganization. +func (mr *MockStoreMockRecorder) GetDefaultOrganization(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDefaultOrganization", reflect.TypeOf((*MockStore)(nil).GetDefaultOrganization), arg0) +} + // GetDefaultProxyConfig mocks base method. func (m *MockStore) GetDefaultProxyConfig(arg0 context.Context) (database.GetDefaultProxyConfigRow, error) { m.ctrl.T.Helper() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index cbeb5b1caf746..4b459e3141216 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -102,6 +102,7 @@ type sqlcQuerier interface { GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUID) (GetAuthorizationUserRolesRow, error) GetDBCryptKeys(ctx context.Context) ([]DBCryptKey, error) GetDERPMeshKey(ctx context.Context) (string, error) + GetDefaultOrganization(ctx context.Context) (Organization, error) GetDefaultProxyConfig(ctx context.Context) (GetDefaultProxyConfigRow, error) GetDeploymentDAUs(ctx context.Context, tzOffset int32) ([]GetDeploymentDAUsRow, error) GetDeploymentID(ctx context.Context) (string, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 28d0a34f5ea80..002da316cbccf 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -3142,6 +3142,31 @@ func (q *sqlQuerier) UpdateMemberRoles(ctx context.Context, arg UpdateMemberRole return i, err } +const getDefaultOrganization = `-- name: GetDefaultOrganization :one +SELECT + id, name, description, created_at, updated_at, is_default +FROM + organizations +WHERE + is_default = true +LIMIT + 1 +` + +func (q *sqlQuerier) GetDefaultOrganization(ctx context.Context) (Organization, error) { + row := q.db.QueryRowContext(ctx, getDefaultOrganization) + var i Organization + err := row.Scan( + &i.ID, + &i.Name, + &i.Description, + &i.CreatedAt, + &i.UpdatedAt, + &i.IsDefault, + ) + return i, err +} + const getOrganizationByID = `-- name: GetOrganizationByID :one SELECT id, name, description, created_at, updated_at, is_default diff --git a/coderd/database/queries/organizations.sql b/coderd/database/queries/organizations.sql index 05185e9c90dec..7f901b003bf44 100644 --- a/coderd/database/queries/organizations.sql +++ b/coderd/database/queries/organizations.sql @@ -1,3 +1,13 @@ +-- name: GetDefaultOrganization :one +SELECT + * +FROM + organizations +WHERE + is_default = true +LIMIT + 1; + -- name: GetOrganizations :many SELECT * diff --git a/coderd/users.go b/coderd/users.go index ca757ed80436f..ba6587ecac5fc 100644 --- a/coderd/users.go +++ b/coderd/users.go @@ -401,10 +401,18 @@ func (api *API) postUser(rw http.ResponseWriter, r *http.Request) { return } } else { - // If no organization is provided, add the user to the first - // organization. - organizations, err := api.Database.GetOrganizations(ctx) + // If no organization is provided, add the user to the default + defaultOrg, err := api.Database.GetDefaultOrganization(ctx) if err != nil { + if httpapi.Is404Error(err) { + httpapi.Write(ctx, rw, http.StatusNotFound, + codersdk.Response{ + Message: "Resource not found or you do not have access to this resource", + Detail: "Organization not found", + }, + ) + return + } httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching orgs.", Detail: err.Error(), @@ -412,12 +420,7 @@ func (api *API) postUser(rw http.ResponseWriter, r *http.Request) { return } - if len(organizations) > 0 { - // Add the user to the first organization. Once multi-organization - // support is added, we should enable a configuration map of user - // email to organization. - req.OrganizationID = organizations[0].ID - } + req.OrganizationID = defaultOrg.ID } var loginType database.LoginType diff --git a/coderd/users_test.go b/coderd/users_test.go index a9b4fa8de72b3..0d7f5a7bb21b9 100644 --- a/coderd/users_test.go +++ b/coderd/users_test.go @@ -493,21 +493,26 @@ func TestPostUsers(t *testing.T) { t.Parallel() auditor := audit.NewMock() client := coderdtest.New(t, &coderdtest.Options{Auditor: auditor}) - numLogs := len(auditor.AuditLogs()) - firstUser := coderdtest.CreateFirstUser(t, client) - numLogs++ // add an audit log for user create - numLogs++ // add an audit log for login ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() + // Add an extra org to try and confuse user creation + _, err := client.CreateOrganization(ctx, codersdk.CreateOrganizationRequest{ + Name: "foobar", + }) + require.NoError(t, err) + + numLogs := len(auditor.AuditLogs()) + user, err := client.CreateUser(ctx, codersdk.CreateUserRequest{ Email: "another@user.org", Username: "someone-else", Password: "SomeSecurePassword!", }) require.NoError(t, err) + numLogs++ // add an audit log for user create require.Len(t, auditor.AuditLogs(), numLogs) require.Equal(t, database.AuditActionCreate, auditor.AuditLogs()[numLogs-1].Action)