Skip to content

Commit a89f76b

Browse files
committed
chore: implement change set operation for org sync
add/remove users to/from orgs based on expected set
1 parent 07157f9 commit a89f76b

File tree

4 files changed

+157
-23
lines changed

4 files changed

+157
-23
lines changed

coderd/rbac/roles.go

+3-23
Original file line numberDiff line numberDiff line change
@@ -764,29 +764,9 @@ func SiteRoles() []Role {
764764
// RBAC checks can be applied using "ActionCreate" and "ActionDelete" for
765765
// "added" and "removed" roles respectively.
766766
func ChangeRoleSet(from []RoleIdentifier, to []RoleIdentifier) (added []RoleIdentifier, removed []RoleIdentifier) {
767-
has := make(map[RoleIdentifier]struct{})
768-
for _, exists := range from {
769-
has[exists] = struct{}{}
770-
}
771-
772-
for _, roleName := range to {
773-
// If the user already has the role assigned, we don't need to check the permission
774-
// to reassign it. Only run permission checks on the difference in the set of
775-
// roles.
776-
if _, ok := has[roleName]; ok {
777-
delete(has, roleName)
778-
continue
779-
}
780-
781-
added = append(added, roleName)
782-
}
783-
784-
// Remaining roles are the ones removed/deleted.
785-
for roleName := range has {
786-
removed = append(removed, roleName)
787-
}
788-
789-
return added, removed
767+
return slice.SymmetricDifferenceFunc(from, to, func(a, b RoleIdentifier) bool {
768+
return a.Name == b.Name && a.OrganizationID == b.OrganizationID
769+
})
790770
}
791771

792772
// Permissions is just a helper function to make building roles that list out resources

coderd/userauth.go

+71
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ import (
2525
"golang.org/x/xerrors"
2626

2727
"cdr.dev/slog"
28+
"github.com/coder/coder/v2/coderd/database/db2sdk"
29+
"github.com/coder/coder/v2/coderd/util/slice"
2830

2931
"github.com/coder/coder/v2/coderd/apikey"
3032
"github.com/coder/coder/v2/coderd/audit"
@@ -1430,6 +1432,7 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C
14301432
// This can happen if a user is a built-in user but is signing in
14311433
// with OIDC for the first time.
14321434
if user.ID == uuid.Nil {
1435+
// TODO: Remove this, and only use params
14331436
// Until proper multi-org support, all users will be added to the default organization.
14341437
// The default organization should always be present.
14351438
//nolint:gocritic
@@ -1540,6 +1543,74 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C
15401543
}
15411544
}
15421545

1546+
if params.UsingOrganizations {
1547+
var expected []uuid.UUID
1548+
// ignored keeps track of which configured organization syncs were
1549+
// ignored. Ignored options are no-ops.
1550+
ignored := make([]string, 0)
1551+
for _, orgSearch := range params.Organizations {
1552+
orgID, err := uuid.Parse(orgSearch)
1553+
if err == nil {
1554+
expected = append(expected, orgID)
1555+
continue
1556+
}
1557+
//nolint:gocritic // System actor being used to assign orgs
1558+
org, err := tx.GetOrganizationByName(dbauthz.AsSystemRestricted(ctx), orgSearch)
1559+
if err == nil {
1560+
expected = append(expected, org.ID)
1561+
continue
1562+
}
1563+
ignored = append(ignored, orgSearch)
1564+
}
1565+
1566+
if params.AssignDefaultOrganization {
1567+
//nolint:gocritic // System actor being used to assign orgs
1568+
defaultOrg, err := tx.GetDefaultOrganization(dbauthz.AsSystemRestricted(ctx))
1569+
if err != nil {
1570+
return xerrors.Errorf("get default organization: %w", err)
1571+
}
1572+
expected = append(expected, defaultOrg.ID)
1573+
}
1574+
1575+
// Sync the user's organizations to the ones provided.
1576+
//nolint:gocritic // Using to system to ensure all memberships are returned
1577+
existingOrgs, err := tx.GetOrganizationsByUserID(dbauthz.AsSystemRestricted(ctx), user.ID)
1578+
if err != nil {
1579+
return xerrors.Errorf("get user organizations: %w", err)
1580+
}
1581+
1582+
have := db2sdk.List(existingOrgs, func(org database.Organization) uuid.UUID {
1583+
return org.ID
1584+
})
1585+
// Find the difference in the expected and the existing orgs, and
1586+
// correct the set of orgs the user is a member of.
1587+
add, remove := slice.SymmetricDifference(have, expected)
1588+
for _, orgID := range add {
1589+
//nolint:gocritic // System actor being used to assign orgs
1590+
_, err := tx.InsertOrganizationMember(dbauthz.AsSystemRestricted(ctx), database.InsertOrganizationMemberParams{
1591+
OrganizationID: orgID,
1592+
UserID: user.ID,
1593+
CreatedAt: dbtime.Now(),
1594+
UpdatedAt: dbtime.Now(),
1595+
Roles: []string{},
1596+
})
1597+
if err != nil {
1598+
return xerrors.Errorf("add user to organization: %w", err)
1599+
}
1600+
}
1601+
1602+
for _, orgID := range remove {
1603+
//nolint:gocritic // System actor being used to assign orgs
1604+
err := tx.DeleteOrganizationMember(dbauthz.AsSystemRestricted(ctx), database.DeleteOrganizationMemberParams{
1605+
OrganizationID: orgID,
1606+
UserID: user.ID,
1607+
})
1608+
if err != nil {
1609+
return xerrors.Errorf("remove user from organization: %w", err)
1610+
}
1611+
}
1612+
}
1613+
15431614
// Ensure groups are correct.
15441615
// This places all groups into the default organization.
15451616
// To go multi-org, we need to add a mapping feature here to know which

coderd/util/slice/slice.go

+37
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,43 @@ import (
44
"golang.org/x/exp/constraints"
55
)
66

7+
// SymmetricDifference returns the elements that need to be added and removed
8+
// to get from set 'a' to set 'b'.
9+
// In classical set theory notation, SymmetricDifference returns
10+
// all elements of {add} and {remove} together. It is more useful to
11+
// return them as their own slices.
12+
// Example:
13+
//
14+
// a := []int{1, 3, 4}
15+
// b := []int{1, 2}
16+
// add, remove := SymmetricDifference(a, b)
17+
// fmt.Println(add) // [2]
18+
// fmt.Println(remove) // [3, 4]
19+
func SymmetricDifference[T comparable](a, b []T) (add []T, remove []T) {
20+
return Difference(b, a), Difference(a, b)
21+
}
22+
23+
// Difference returns the elements in 'a' that are not in 'b'.
24+
func Difference[T comparable](a []T, b []T) []T {
25+
return DifferenceFunc(a, b, func(a, b T) bool {
26+
return a == b
27+
})
28+
}
29+
30+
func SymmetricDifferenceFunc[T any](a, b []T, equal func(a, b T) bool) (add []T, remove []T) {
31+
return DifferenceFunc(a, b, equal), DifferenceFunc(b, a, equal)
32+
}
33+
34+
func DifferenceFunc[T any](a []T, b []T, equal func(a, b T) bool) []T {
35+
tmp := make([]T, 0, len(a))
36+
for _, v := range a {
37+
if !ContainsCompare(b, v, equal) {
38+
tmp = append(tmp, v)
39+
}
40+
}
41+
return tmp
42+
}
43+
744
// ToStrings works for any type where the base type is a string.
845
func ToStrings[T ~string](a []T) []string {
946
tmp := make([]string, 0, len(a))

coderd/util/slice/slice_test.go

+46
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,52 @@ import (
1111
"github.com/coder/coder/v2/coderd/util/slice"
1212
)
1313

14+
func TestSymmetricDifference(t *testing.T) {
15+
t.Parallel()
16+
17+
t.Run("Simple", func(t *testing.T) {
18+
add, remove := slice.SymmetricDifference([]int{1, 3, 4}, []int{1, 2})
19+
require.ElementsMatch(t, []int{2}, add)
20+
require.ElementsMatch(t, []int{3, 4}, remove)
21+
})
22+
23+
t.Run("Large", func(t *testing.T) {
24+
add, remove := slice.SymmetricDifference(
25+
[]int{1, 2, 3, 4, 5, 10, 11, 12, 13, 14, 15},
26+
[]int{1, 3, 7, 9, 11, 13, 17},
27+
)
28+
require.ElementsMatch(t, []int{7, 9, 17}, add)
29+
require.ElementsMatch(t, []int{2, 4, 5, 10, 12, 14, 15}, remove)
30+
})
31+
32+
t.Run("AddOnly", func(t *testing.T) {
33+
add, remove := slice.SymmetricDifference(
34+
[]int{1, 2},
35+
[]int{1, 2, 3, 4, 5, 6, 7, 8, 9},
36+
)
37+
require.ElementsMatch(t, []int{3, 4, 5, 6, 7, 8, 9}, add)
38+
require.ElementsMatch(t, []int{}, remove)
39+
})
40+
41+
t.Run("RemoveOnly", func(t *testing.T) {
42+
add, remove := slice.SymmetricDifference(
43+
[]int{1, 2, 3, 4, 5, 6, 7, 8, 9},
44+
[]int{1, 2},
45+
)
46+
require.ElementsMatch(t, []int{}, add)
47+
require.ElementsMatch(t, []int{3, 4, 5, 6, 7, 8, 9}, remove)
48+
})
49+
50+
t.Run("Equal", func(t *testing.T) {
51+
add, remove := slice.SymmetricDifference(
52+
[]int{1, 2, 3, 4, 5, 6, 7, 8, 9},
53+
[]int{1, 2, 3, 4, 5, 6, 7, 8, 9},
54+
)
55+
require.ElementsMatch(t, []int{}, add)
56+
require.ElementsMatch(t, []int{}, remove)
57+
})
58+
}
59+
1460
func TestSameElements(t *testing.T) {
1561
t.Parallel()
1662

0 commit comments

Comments
 (0)