diff --git a/coderd/rbac/roles.go b/coderd/rbac/roles.go index 14f797ca0b4ee..db62bbd6e6d0d 100644 --- a/coderd/rbac/roles.go +++ b/coderd/rbac/roles.go @@ -764,29 +764,9 @@ func SiteRoles() []Role { // RBAC checks can be applied using "ActionCreate" and "ActionDelete" for // "added" and "removed" roles respectively. func ChangeRoleSet(from []RoleIdentifier, to []RoleIdentifier) (added []RoleIdentifier, removed []RoleIdentifier) { - has := make(map[RoleIdentifier]struct{}) - for _, exists := range from { - has[exists] = struct{}{} - } - - for _, roleName := range to { - // If the user already has the role assigned, we don't need to check the permission - // to reassign it. Only run permission checks on the difference in the set of - // roles. - if _, ok := has[roleName]; ok { - delete(has, roleName) - continue - } - - added = append(added, roleName) - } - - // Remaining roles are the ones removed/deleted. - for roleName := range has { - removed = append(removed, roleName) - } - - return added, removed + return slice.SymmetricDifferenceFunc(from, to, func(a, b RoleIdentifier) bool { + return a.Name == b.Name && a.OrganizationID == b.OrganizationID + }) } // Permissions is just a helper function to make building roles that list out resources diff --git a/coderd/util/slice/example_test.go b/coderd/util/slice/example_test.go new file mode 100644 index 0000000000000..f17d9c4aab0ff --- /dev/null +++ b/coderd/util/slice/example_test.go @@ -0,0 +1,21 @@ +package slice_test + +import ( + "fmt" + + "github.com/coder/coder/v2/coderd/util/slice" +) + +//nolint:revive // They want me to error check my Printlns +func ExampleSymmetricDifference() { + // The goal of this function is to find the elements to add & remove from + // set 'a' to make it equal to set 'b'. + a := []int{1, 2, 5, 6} + b := []int{2, 3, 4, 5} + add, remove := slice.SymmetricDifference(a, b) + fmt.Println("Elements to add:", add) + fmt.Println("Elements to remove:", remove) + // Output: + // Elements to add: [3 4] + // Elements to remove: [1 6] +} diff --git a/coderd/util/slice/slice.go b/coderd/util/slice/slice.go index 9bb1da930ff45..e186e0975de70 100644 --- a/coderd/util/slice/slice.go +++ b/coderd/util/slice/slice.go @@ -107,3 +107,35 @@ func Ascending[T constraints.Ordered](a, b T) int { func Descending[T constraints.Ordered](a, b T) int { return -Ascending[T](a, b) } + +// SymmetricDifference returns the elements that need to be added and removed +// to get from set 'a' to set 'b'. +// In classical set theory notation, SymmetricDifference returns +// all elements of {add} and {remove} together. It is more useful to +// return them as their own slices. +// Notation: A Δ B = (A\B) ∪ (B\A) +// Example: +// +// a := []int{1, 3, 4} +// b := []int{1, 2} +// add, remove := SymmetricDifference(a, b) +// fmt.Println(add) // [2] +// fmt.Println(remove) // [3, 4] +func SymmetricDifference[T comparable](a, b []T) (add []T, remove []T) { + f := func(a, b T) bool { return a == b } + return SymmetricDifferenceFunc(a, b, f) +} + +func SymmetricDifferenceFunc[T any](a, b []T, equal func(a, b T) bool) (add []T, remove []T) { + return DifferenceFunc(b, a, equal), DifferenceFunc(a, b, equal) +} + +func DifferenceFunc[T any](a []T, b []T, equal func(a, b T) bool) []T { + tmp := make([]T, 0, len(a)) + for _, v := range a { + if !ContainsCompare(b, v, equal) { + tmp = append(tmp, v) + } + } + return tmp +} diff --git a/coderd/util/slice/slice_test.go b/coderd/util/slice/slice_test.go index ef947a13e7659..5ab61f83ddbc1 100644 --- a/coderd/util/slice/slice_test.go +++ b/coderd/util/slice/slice_test.go @@ -131,3 +131,103 @@ func TestOmit(t *testing.T) { slice.Omit([]string{"a", "b", "c", "d", "e", "f"}, "c", "d", "e"), ) } + +func TestSymmetricDifference(t *testing.T) { + t.Parallel() + + t.Run("Simple", func(t *testing.T) { + t.Parallel() + + add, remove := slice.SymmetricDifference([]int{1, 3, 4}, []int{1, 2}) + require.ElementsMatch(t, []int{2}, add) + require.ElementsMatch(t, []int{3, 4}, remove) + }) + + t.Run("Large", func(t *testing.T) { + t.Parallel() + + add, remove := slice.SymmetricDifference( + []int{1, 2, 3, 4, 5, 10, 11, 12, 13, 14, 15}, + []int{1, 3, 7, 9, 11, 13, 17}, + ) + require.ElementsMatch(t, []int{7, 9, 17}, add) + require.ElementsMatch(t, []int{2, 4, 5, 10, 12, 14, 15}, remove) + }) + + t.Run("AddOnly", func(t *testing.T) { + t.Parallel() + + add, remove := slice.SymmetricDifference( + []int{1, 2}, + []int{1, 2, 3, 4, 5, 6, 7, 8, 9}, + ) + require.ElementsMatch(t, []int{3, 4, 5, 6, 7, 8, 9}, add) + require.ElementsMatch(t, []int{}, remove) + }) + + t.Run("RemoveOnly", func(t *testing.T) { + t.Parallel() + + add, remove := slice.SymmetricDifference( + []int{1, 2, 3, 4, 5, 6, 7, 8, 9}, + []int{1, 2}, + ) + require.ElementsMatch(t, []int{}, add) + require.ElementsMatch(t, []int{3, 4, 5, 6, 7, 8, 9}, remove) + }) + + t.Run("Equal", func(t *testing.T) { + t.Parallel() + + add, remove := slice.SymmetricDifference( + []int{1, 2, 3, 4, 5, 6, 7, 8, 9}, + []int{1, 2, 3, 4, 5, 6, 7, 8, 9}, + ) + require.ElementsMatch(t, []int{}, add) + require.ElementsMatch(t, []int{}, remove) + }) + + t.Run("ToEmpty", func(t *testing.T) { + t.Parallel() + + add, remove := slice.SymmetricDifference( + []int{1, 2, 3}, + []int{}, + ) + require.ElementsMatch(t, []int{}, add) + require.ElementsMatch(t, []int{1, 2, 3}, remove) + }) + + t.Run("ToNil", func(t *testing.T) { + t.Parallel() + + add, remove := slice.SymmetricDifference( + []int{1, 2, 3}, + nil, + ) + require.ElementsMatch(t, []int{}, add) + require.ElementsMatch(t, []int{1, 2, 3}, remove) + }) + + t.Run("FromEmpty", func(t *testing.T) { + t.Parallel() + + add, remove := slice.SymmetricDifference( + []int{}, + []int{1, 2, 3}, + ) + require.ElementsMatch(t, []int{1, 2, 3}, add) + require.ElementsMatch(t, []int{}, remove) + }) + + t.Run("FromNil", func(t *testing.T) { + t.Parallel() + + add, remove := slice.SymmetricDifference( + nil, + []int{1, 2, 3}, + ) + require.ElementsMatch(t, []int{1, 2, 3}, add) + require.ElementsMatch(t, []int{}, remove) + }) +}