Skip to content

Commit bfe6835

Browse files
committed
test: add full org sync tests
1 parent d0f31d1 commit bfe6835

File tree

11 files changed

+304
-29
lines changed

11 files changed

+304
-29
lines changed

coderd/database/dbauthz/dbauthz.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ var (
243243
rbac.ResourceAssignOrgRole.Type: rbac.ResourceAssignOrgRole.AvailableActions(),
244244
rbac.ResourceSystem.Type: {policy.WildcardSymbol},
245245
rbac.ResourceOrganization.Type: {policy.ActionCreate, policy.ActionRead},
246-
rbac.ResourceOrganizationMember.Type: {policy.ActionCreate},
246+
rbac.ResourceOrganizationMember.Type: {policy.ActionCreate, policy.ActionDelete, policy.ActionRead},
247247
rbac.ResourceProvisionerDaemon.Type: {policy.ActionCreate, policy.ActionUpdate},
248248
rbac.ResourceProvisionerKeys.Type: {policy.ActionCreate, policy.ActionRead, policy.ActionDelete},
249249
rbac.ResourceUser.Type: rbac.ResourceUser.AvailableActions(),

coderd/idpsync/idpsync.go

+16
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
// claims to the internal representation of a user in Coder.
2323
// TODO: Move group + role sync into this interface.
2424
type IDPSync interface {
25+
OrganizationSyncEnabled() bool
2526
// ParseOrganizationClaims takes claims from an OIDC provider, and returns the
2627
// organization sync params for assigning users into organizations.
2728
ParseOrganizationClaims(ctx context.Context, _ jwt.MapClaims) (OrganizationParams, *HTTPError)
@@ -105,6 +106,21 @@ func ParseStringSliceClaim(claim interface{}) ([]string, error) {
105106
return nil, xerrors.Errorf("invalid claim type. Expected an array of strings, got: %T", claim)
106107
}
107108

109+
// IsHTTPError handles us being inconsistent with returning errors as values or
110+
// pointers.
111+
func IsHTTPError(err error) *HTTPError {
112+
var httpErr HTTPError
113+
if xerrors.As(err, &httpErr) {
114+
return &httpErr
115+
}
116+
117+
var httpErrPtr *HTTPError
118+
if xerrors.As(err, &httpErrPtr) {
119+
return httpErrPtr
120+
}
121+
return nil
122+
}
123+
108124
// HTTPError is a helper struct for returning errors from the IDP sync process.
109125
// A regular error is not sufficient because many of these errors are surfaced
110126
// to a user logging in, and the errors should be descriptive.

coderd/idpsync/idpsync_test.go

+8
Original file line numberDiff line numberDiff line change
@@ -135,3 +135,11 @@ func TestParseStringSliceClaim(t *testing.T) {
135135
})
136136
}
137137
}
138+
139+
func TestIsHTTPError(t *testing.T) {
140+
herr := idpsync.HTTPError{}
141+
require.NotNil(t, idpsync.IsHTTPError(herr))
142+
require.NotNil(t, idpsync.IsHTTPError(&herr))
143+
144+
require.Nil(t, error(nil))
145+
}

coderd/idpsync/organization.go

+6-1
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,15 @@ import (
1616
"github.com/coder/coder/v2/coderd/util/slice"
1717
)
1818

19+
func (s AGPLIDPSync) OrganizationSyncEnabled() bool {
20+
// AGPL does not support syncing organizations.
21+
return false
22+
}
23+
1924
func (s AGPLIDPSync) ParseOrganizationClaims(_ context.Context, _ jwt.MapClaims) (OrganizationParams, *HTTPError) {
2025
// For AGPL we only sync the default organization.
2126
return OrganizationParams{
22-
SyncEnabled: false,
27+
SyncEnabled: s.OrganizationSyncEnabled(),
2328
IncludeDefault: s.OrganizationAssignDefault,
2429
Organizations: []uuid.UUID{},
2530
}, nil

coderd/members.go

+9
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package coderd
22

33
import (
44
"context"
5+
"fmt"
56
"net/http"
67

78
"github.com/google/uuid"
@@ -43,6 +44,14 @@ func (api *API) postOrganizationMember(rw http.ResponseWriter, r *http.Request)
4344
aReq.Old = database.AuditableOrganizationMember{}
4445
defer commitAudit()
4546

47+
if user.LoginType == database.LoginTypeOIDC && api.IDPSync.OrganizationSyncEnabled() {
48+
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
49+
Message: "Organization sync is enabled for OIDC users, meaning manual organization assignment is not allowed for this user.",
50+
Detail: fmt.Sprintf("User %s is an OIDC user and organization sync is enabled. Ask an administrator to resolve this in your external IDP.", user.ID),
51+
})
52+
return
53+
}
54+
4655
member, err := api.Database.InsertOrganizationMember(ctx, database.InsertOrganizationMemberParams{
4756
OrganizationID: organization.ID,
4857
UserID: user.ID,

coderd/userauth.go

+8-10
Original file line numberDiff line numberDiff line change
@@ -669,12 +669,11 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) {
669669
})
670670
cookies, user, key, err := api.oauthLogin(r, params)
671671
defer params.CommitAuditLogs()
672-
var httpErr idpsync.HTTPError
673-
if xerrors.As(err, &httpErr) {
674-
httpErr.Write(rw, r)
675-
return
676-
}
677672
if err != nil {
673+
if httpErr := idpsync.IsHTTPError(err); httpErr != nil {
674+
httpErr.Write(rw, r)
675+
return
676+
}
678677
logger.Error(ctx, "oauth2: login failed", slog.F("user", user.Username), slog.Error(err))
679678
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
680679
Message: "Failed to process OAuth login.",
@@ -1066,12 +1065,11 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
10661065
})
10671066
cookies, user, key, err := api.oauthLogin(r, params)
10681067
defer params.CommitAuditLogs()
1069-
var httpErr idpsync.HTTPError
1070-
if xerrors.As(err, &httpErr) {
1071-
httpErr.Write(rw, r)
1072-
return
1073-
}
10741068
if err != nil {
1069+
if hErr := idpsync.IsHTTPError(err); hErr != nil {
1070+
hErr.Write(rw, r)
1071+
return
1072+
}
10751073
logger.Error(ctx, "oauth2: login failed", slog.F("user", user.Username), slog.Error(err))
10761074
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
10771075
Message: "Failed to process OAuth login.",

coderd/userauth_test.go

+6
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,7 @@ func TestUserOAuth2Github(t *testing.T) {
366366
require.Equal(t, "kyle", user.Username)
367367
require.Equal(t, "Kylium Carbonate", user.Name)
368368
require.Equal(t, "/hello-world", user.AvatarURL)
369+
require.Equal(t, 1, len(user.OrganizationIDs), "in the default org")
369370

370371
require.Len(t, auditor.AuditLogs(), numLogs)
371372
require.NotEqual(t, auditor.AuditLogs()[numLogs-1].UserID, uuid.Nil)
@@ -419,6 +420,7 @@ func TestUserOAuth2Github(t *testing.T) {
419420
require.Equal(t, "kyle", user.Username)
420421
require.Equal(t, strings.Repeat("a", 128), user.Name)
421422
require.Equal(t, "/hello-world", user.AvatarURL)
423+
require.Equal(t, 1, len(user.OrganizationIDs), "in the default org")
422424

423425
require.Len(t, auditor.AuditLogs(), numLogs)
424426
require.NotEqual(t, auditor.AuditLogs()[numLogs-1].UserID, uuid.Nil)
@@ -474,6 +476,7 @@ func TestUserOAuth2Github(t *testing.T) {
474476
require.Equal(t, "kyle", user.Username)
475477
require.Equal(t, "Kylium Carbonate", user.Name)
476478
require.Equal(t, "/hello-world", user.AvatarURL)
479+
require.Equal(t, 1, len(user.OrganizationIDs), "in the default org")
477480

478481
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
479482
require.Len(t, auditor.AuditLogs(), numLogs)
@@ -536,6 +539,7 @@ func TestUserOAuth2Github(t *testing.T) {
536539
require.Equal(t, "mathias@coder.com", user.Email)
537540
require.Equal(t, "mathias", user.Username)
538541
require.Equal(t, "Mathias Mathias", user.Name)
542+
require.Equal(t, 1, len(user.OrganizationIDs), "in the default org")
539543

540544
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
541545
require.Len(t, auditor.AuditLogs(), numLogs)
@@ -598,6 +602,7 @@ func TestUserOAuth2Github(t *testing.T) {
598602
require.Equal(t, "mathias@coder.com", user.Email)
599603
require.Equal(t, "mathias", user.Username)
600604
require.Equal(t, "Mathias Mathias", user.Name)
605+
require.Equal(t, 1, len(user.OrganizationIDs), "in the default org")
601606

602607
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
603608
require.Len(t, auditor.AuditLogs(), numLogs)
@@ -1270,6 +1275,7 @@ func TestUserOIDC(t *testing.T) {
12701275
require.Len(t, auditor.AuditLogs(), numLogs)
12711276
require.NotEqual(t, uuid.Nil, auditor.AuditLogs()[numLogs-1].UserID)
12721277
require.Equal(t, database.AuditActionRegister, auditor.AuditLogs()[numLogs-1].Action)
1278+
require.Equal(t, 1, len(user.OrganizationIDs), "in the default org")
12731279
}
12741280
})
12751281
}

coderd/users.go

-4
Original file line numberDiff line numberDiff line change
@@ -1294,10 +1294,6 @@ func (api *API) CreateUser(ctx context.Context, store database.Store, req Create
12941294
var user database.User
12951295
err := store.InTx(func(tx database.Store) error {
12961296
orgRoles := make([]string, 0)
1297-
// Organization is required to know where to allocate the user.
1298-
if len(req.OrganizationIDs) == 0 {
1299-
return xerrors.Errorf("organization ID must be provided")
1300-
}
13011297

13021298
params := database.InsertUserParams{
13031299
ID: uuid.New(),

enterprise/coderd/enidpsync/enidpsync.go

+1-7
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,6 @@ import (
77
"github.com/coder/coder/v2/coderd/idpsync"
88
)
99

10-
func init() {
11-
idpsync.NewSync = func(logger slog.Logger, entitlements *entitlements.Set, settings idpsync.SyncSettings) idpsync.IDPSync {
12-
return NewSync(logger, entitlements, settings)
13-
}
14-
}
15-
1610
type EnterpriseIDPSync struct {
1711
entitlements *entitlements.Set
1812
*idpsync.AGPLIDPSync
@@ -21,6 +15,6 @@ type EnterpriseIDPSync struct {
2115
func NewSync(logger slog.Logger, set *entitlements.Set, settings idpsync.SyncSettings) *EnterpriseIDPSync {
2216
return &EnterpriseIDPSync{
2317
entitlements: set,
24-
AGPLIDPSync: idpsync.NewAGPLSync(logger.With(slog.F("enterprise_capable", "true")), set, settings),
18+
AGPLIDPSync: idpsync.NewAGPLSync(logger.With(slog.F("enterprise_capable", "true")), settings),
2519
}
2620
}

enterprise/coderd/enidpsync/organizations.go

+5-1
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,12 @@ import (
1414
"github.com/coder/coder/v2/codersdk"
1515
)
1616

17+
func (e EnterpriseIDPSync) OrganizationSyncEnabled() bool {
18+
return e.entitlements.Enabled(codersdk.FeatureMultipleOrganizations) && e.OrganizationField != ""
19+
}
20+
1721
func (e EnterpriseIDPSync) ParseOrganizationClaims(ctx context.Context, mergedClaims jwt.MapClaims) (idpsync.OrganizationParams, *idpsync.HTTPError) {
18-
if !e.entitlements.Enabled(codersdk.FeatureMultipleOrganizations) {
22+
if !e.OrganizationSyncEnabled() {
1923
// Default to agpl if multi-org is not enabled
2024
return e.AGPLIDPSync.ParseOrganizationClaims(ctx, mergedClaims)
2125
}

0 commit comments

Comments
 (0)