Skip to content

chore: implement organization sync and create idpsync package #14432

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
test: start implementing sync tests
  • Loading branch information
Emyrk committed Aug 29, 2024
commit 951a7240622b4791be013bfc86a6588d3138cb44
3 changes: 2 additions & 1 deletion coderd/idpsync/idpsync.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"net/http"
"strings"

"github.com/golang-jwt/jwt/v4"
"github.com/google/uuid"
"golang.org/x/xerrors"

Expand All @@ -29,7 +30,7 @@ var NewSync = func(logger slog.Logger, entitlements *entitlements.Set, settings
type IDPSync interface {
// ParseOrganizationClaims takes claims from an OIDC provider, and returns the
// organization sync params for assigning users into organizations.
ParseOrganizationClaims(ctx context.Context, _ map[string]interface{}) (OrganizationParams, *HttpError)
ParseOrganizationClaims(ctx context.Context, _ jwt.MapClaims) (OrganizationParams, *HttpError)
// SyncOrganizations assigns and removed users from organizations based on the
// provided params.
SyncOrganizations(ctx context.Context, tx database.Store, user database.User, params OrganizationParams) error
Expand Down
3 changes: 2 additions & 1 deletion coderd/idpsync/organization.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"database/sql"

"github.com/golang-jwt/jwt/v4"
"github.com/google/uuid"
"golang.org/x/xerrors"

Expand All @@ -15,7 +16,7 @@ import (
"github.com/coder/coder/v2/coderd/util/slice"
)

func (s AGPLIDPSync) ParseOrganizationClaims(ctx context.Context, _ map[string]interface{}) (OrganizationParams, *HttpError) {
func (s AGPLIDPSync) ParseOrganizationClaims(ctx context.Context, _ jwt.MapClaims) (OrganizationParams, *HttpError) {
// nolint:gocritic // all syncing is done as a system user
ctx = dbauthz.AsSystemRestricted(ctx)

Expand Down
58 changes: 58 additions & 0 deletions coderd/idpsync/organizations_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package idpsync

import (
"testing"

"github.com/golang-jwt/jwt/v4"
"github.com/google/uuid"
"github.com/stretchr/testify/require"

"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/entitlements"
"github.com/coder/coder/v2/testutil"
)

func TestParseOrganizationClaims(t *testing.T) {
t.Parallel()

t.Run("SingleOrgDeployment", func(t *testing.T) {
t.Parallel()

s := NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), entitlements.New(), SyncSettings{
OrganizationField: "",
OrganizationMapping: nil,
OrganizationAssignDefault: true,
})

ctx := testutil.Context(t, testutil.WaitMedium)

params, err := s.ParseOrganizationClaims(ctx, jwt.MapClaims{})
require.Nil(t, err)

require.Empty(t, params.Organizations)
require.True(t, params.IncludeDefault)
require.False(t, params.SyncEnabled)
})

t.Run("AGPL", func(t *testing.T) {
t.Parallel()

// AGPL has limited behavior
s := NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), entitlements.New(), SyncSettings{
OrganizationField: "orgs",
OrganizationMapping: map[string][]uuid.UUID{
"random": {uuid.New()},
},
OrganizationAssignDefault: false,
})

ctx := testutil.Context(t, testutil.WaitMedium)

params, err := s.ParseOrganizationClaims(ctx, jwt.MapClaims{})
require.Nil(t, err)

require.Empty(t, params.Organizations)
require.False(t, params.IncludeDefault)
require.False(t, params.SyncEnabled)
})
}
3 changes: 2 additions & 1 deletion enterprise/coderd/enidpsync/organizations.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"net/http"

"github.com/golang-jwt/jwt/v4"
"github.com/google/uuid"

"cdr.dev/slog"
Expand All @@ -12,7 +13,7 @@ import (
"github.com/coder/coder/v2/codersdk"
)

func (e EnterpriseIDPSync) ParseOrganizationClaims(ctx context.Context, mergedClaims map[string]interface{}) (idpsync.OrganizationParams, *idpsync.HttpError) {
func (e EnterpriseIDPSync) ParseOrganizationClaims(ctx context.Context, mergedClaims jwt.MapClaims) (idpsync.OrganizationParams, *idpsync.HttpError) {
if !e.entitlements.Enabled(codersdk.FeatureMultipleOrganizations) {
// Default to agpl if multi-org is not enabled
return e.AGPLIDPSync.ParseOrganizationClaims(ctx, mergedClaims)
Expand Down
183 changes: 183 additions & 0 deletions enterprise/coderd/enidpsync/organizations_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
package enidpsync

import (
"context"
"testing"

"github.com/golang-jwt/jwt/v4"
"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/require"

"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/entitlements"
"github.com/coder/coder/v2/coderd/idpsync"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)

type ExpectedUser struct {
SyncError bool
Organizations []uuid.UUID
}

type Expectations struct {
Name string
Claims jwt.MapClaims
// Parse
ParseError func(t *testing.T, httpErr *idpsync.HttpError)
ExpectedParams idpsync.OrganizationParams
// Mutate allows mutating the user before syncing
Mutate func(t *testing.T, db database.Store, user database.User)
Sync ExpectedUser
}

type OrganizationSyncTestCase struct {
Settings idpsync.SyncSettings
Entitlements *entitlements.Set
Exps []Expectations
}

func TestOrganizationSync(t *testing.T) {
t.Parallel()

if dbtestutil.WillUsePostgres() {
t.Skip("Skipping test because it populates a lot of db entries, which is slow on postgres")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is OK since we aren't introducing new tables or columns and using existing (and well tested) data structures. 👍

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea, I've done this a few times now when I need to populate a lot of data. Essentially, this test is not designed to test the database, but test the logic in the sync method. There is little upside to making this a full SQL db imo.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Especially since this is in the enidpsync package. I use a real DB in the coderd/userauth_test.go file.

}

requireUserOrgs := func(t *testing.T, db database.Store, user database.User, expected []uuid.UUID) {
t.Helper()

// nolint:gocritic // in testing
members, err := db.OrganizationMembers(dbauthz.AsSystemRestricted(context.Background()), database.OrganizationMembersParams{
UserID: user.ID,
})
require.NoError(t, err)

foundIDs := db2sdk.List(members, func(m database.OrganizationMembersRow) uuid.UUID {
return m.OrganizationMember.OrganizationID
})
require.ElementsMatch(t, expected, foundIDs, "match user organizations")
}

entitled := entitlements.New()
entitled.Update(func(entitlements *codersdk.Entitlements) {
entitlements.Features[codersdk.FeatureMultipleOrganizations] = codersdk.Feature{
Entitlement: codersdk.EntitlementEntitled,
Enabled: true,
Limit: nil,
Actual: nil,
}
})

testCases := []struct {
Name string
Case func(t *testing.T, db database.Store) OrganizationSyncTestCase
}{
{
Name: "SingleOrgDeployment",
Case: func(t *testing.T, db database.Store) OrganizationSyncTestCase {
def, _ := db.GetDefaultOrganization(context.Background())
other := dbgen.Organization(t, db, database.Organization{})
return OrganizationSyncTestCase{
Entitlements: entitled,
Settings: idpsync.SyncSettings{
OrganizationField: "",
OrganizationMapping: nil,
OrganizationAssignDefault: true,
},
Exps: []Expectations{
{
Name: "NoOrganizations",
Claims: jwt.MapClaims{},
ExpectedParams: idpsync.OrganizationParams{
SyncEnabled: false,
IncludeDefault: true,
Organizations: []uuid.UUID{},
},
Sync: ExpectedUser{
Organizations: []uuid.UUID{},
},
},
{
Name: "AlreadyInOrgs",
Claims: jwt.MapClaims{},
ExpectedParams: idpsync.OrganizationParams{
SyncEnabled: false,
IncludeDefault: true,
Organizations: []uuid.UUID{},
},
Mutate: func(t *testing.T, db database.Store, user database.User) {
dbgen.OrganizationMember(t, db, database.OrganizationMember{
UserID: user.ID,
OrganizationID: def.ID,
})
dbgen.OrganizationMember(t, db, database.OrganizationMember{
UserID: user.ID,
OrganizationID: other.ID,
})
},
Sync: ExpectedUser{
Organizations: []uuid.UUID{def.ID, other.ID},
},
},
},
}
},
},
}

for _, tc := range testCases {
tc := tc
t.Run(tc.Name, func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
logger := slogtest.Make(t, &slogtest.Options{})

rdb, _ := dbtestutil.NewDB(t)
db := dbauthz.New(rdb, rbac.NewAuthorizer(prometheus.NewRegistry()), logger, coderdtest.AccessControlStorePointer())
caseData := tc.Case(t, rdb)
if caseData.Entitlements == nil {
caseData.Entitlements = entitlements.New()
}

// Create a new sync object
sync := NewSync(logger, caseData.Entitlements, caseData.Settings)
for _, exp := range caseData.Exps {
t.Run(exp.Name, func(t *testing.T) {
params, httpErr := sync.ParseOrganizationClaims(ctx, exp.Claims)
if exp.ParseError != nil {
exp.ParseError(t, httpErr)
return
}

require.Equal(t, exp.ExpectedParams.SyncEnabled, params.SyncEnabled, "match enabled")
require.Equal(t, exp.ExpectedParams.IncludeDefault, params.IncludeDefault, "match include default")
if exp.ExpectedParams.Organizations == nil {
exp.ExpectedParams.Organizations = []uuid.UUID{}
}
require.ElementsMatch(t, exp.ExpectedParams.Organizations, params.Organizations, "match organizations")

user := dbgen.User(t, db, database.User{})
if exp.Mutate != nil {
exp.Mutate(t, db, user)
}

err := sync.SyncOrganizations(ctx, db, user, params)
if exp.Sync.SyncError {
require.Error(t, err)
return
}
requireUserOrgs(t, db, user, exp.Sync.Organizations)
})
}
})
}
}