Skip to content

Commit f56113e

Browse files
committed
sync missing groups
Also added a regex filter to filter out groups that are not important
1 parent 604a79f commit f56113e

File tree

13 files changed

+410
-14
lines changed

13 files changed

+410
-14
lines changed

cli/clibase/values.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"net"
88
"net/url"
99
"reflect"
10+
"regexp"
1011
"strconv"
1112
"strings"
1213
"time"
@@ -461,6 +462,36 @@ func (e *Enum) String() string {
461462
return *e.Value
462463
}
463464

465+
type Regexp regexp.Regexp
466+
467+
func RegexpOf(s *regexp.Regexp) *Regexp {
468+
return (*Regexp)(s)
469+
}
470+
471+
func (s *Regexp) Set(v string) error {
472+
exp, err := regexp.Compile(v)
473+
if err != nil {
474+
return xerrors.Errorf("invalid regexp %q: %w", v, err)
475+
}
476+
*s = Regexp(*exp)
477+
return nil
478+
}
479+
480+
func (s Regexp) String() string {
481+
return s.Value().String()
482+
}
483+
484+
func (s *Regexp) Value() *regexp.Regexp {
485+
if s == nil {
486+
return nil
487+
}
488+
return (*regexp.Regexp)(s)
489+
}
490+
491+
func (Regexp) Type() string {
492+
return "regexp"
493+
}
494+
464495
var _ pflag.Value = (*YAMLConfigPath)(nil)
465496

466497
// YAMLConfigPath is a special value type that encodes a path to a YAML

cli/server.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,6 +596,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
596596
AuthURLParams: cfg.OIDC.AuthURLParams.Value,
597597
IgnoreUserInfo: cfg.OIDC.IgnoreUserInfo.Value(),
598598
GroupField: cfg.OIDC.GroupField.String(),
599+
GroupFilter: cfg.OIDC.GroupRegexFilter.Value(),
599600
CreateMissingGroups: cfg.OIDC.GroupAutoCreate.Value(),
600601
GroupMapping: cfg.OIDC.GroupMapping.Value,
601602
UserRoleField: cfg.OIDC.UserRoleField.String(),

coderd/coderd.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,8 @@ type Options struct {
126126
BaseDERPMap *tailcfg.DERPMap
127127
DERPMapUpdateFrequency time.Duration
128128
SwaggerEndpoint bool
129-
SetUserGroups func(ctx context.Context, tx database.Store, userID uuid.UUID, groupNames []string, createMissingGroups bool) error
130-
SetUserSiteRoles func(ctx context.Context, tx database.Store, userID uuid.UUID, roles []string) error
129+
SetUserGroups func(ctx context.Context, logger slog.Logger, tx database.Store, userID uuid.UUID, groupNames []string, createMissingGroups bool) error
130+
SetUserSiteRoles func(ctx context.Context, logger slog.Logger, tx database.Store, userID uuid.UUID, roles []string) error
131131
TemplateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore]
132132
UserQuietHoursScheduleStore *atomic.Pointer[schedule.UserQuietHoursScheduleStore]
133133
// AppSecurityKey is the crypto key used to sign and encrypt tokens related to
@@ -258,16 +258,16 @@ func New(options *Options) *API {
258258
options.TracerProvider = trace.NewNoopTracerProvider()
259259
}
260260
if options.SetUserGroups == nil {
261-
options.SetUserGroups = func(ctx context.Context, _ database.Store, userID uuid.UUID, groups []string, createMissingGroups bool) error {
262-
options.Logger.Warn(ctx, "attempted to assign OIDC groups without enterprise license",
261+
options.SetUserGroups = func(ctx context.Context, logger slog.Logger, _ database.Store, userID uuid.UUID, groups []string, createMissingGroups bool) error {
262+
logger.Warn(ctx, "attempted to assign OIDC groups without enterprise license",
263263
slog.F("user_id", userID), slog.F("groups", groups), slog.F("create_missing_groups", createMissingGroups),
264264
)
265265
return nil
266266
}
267267
}
268268
if options.SetUserSiteRoles == nil {
269-
options.SetUserSiteRoles = func(ctx context.Context, _ database.Store, userID uuid.UUID, roles []string) error {
270-
options.Logger.Warn(ctx, "attempted to assign OIDC user roles without enterprise license",
269+
options.SetUserSiteRoles = func(ctx context.Context, logger slog.Logger, _ database.Store, userID uuid.UUID, roles []string) error {
270+
logger.Warn(ctx, "attempted to assign OIDC user roles without enterprise license",
271271
slog.F("user_id", userID), slog.F("roles", roles),
272272
)
273273
return nil

coderd/database/dbauthz/dbauthz.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1834,6 +1834,13 @@ func (q *querier) InsertLicense(ctx context.Context, arg database.InsertLicenseP
18341834
return q.db.InsertLicense(ctx, arg)
18351835
}
18361836

1837+
func (q *querier) InsertMissingGroups(ctx context.Context, arg database.InsertMissingGroupsParams) ([]database.Group, error) {
1838+
if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceSystem); err != nil {
1839+
return nil, err
1840+
}
1841+
return q.db.InsertMissingGroups(ctx, arg)
1842+
}
1843+
18371844
func (q *querier) InsertOrganization(ctx context.Context, arg database.InsertOrganizationParams) (database.Organization, error) {
18381845
return insert(q.log, q.auth, rbac.ResourceOrganization, q.db.InsertOrganization)(ctx, arg)
18391846
}

coderd/database/dbfake/dbfake.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3578,6 +3578,44 @@ func (q *FakeQuerier) InsertLicense(
35783578
return l, nil
35793579
}
35803580

3581+
func (q *FakeQuerier) InsertMissingGroups(ctx context.Context, arg database.InsertMissingGroupsParams) ([]database.Group, error) {
3582+
err := validateDatabaseType(arg)
3583+
if err != nil {
3584+
return nil, err
3585+
}
3586+
3587+
groupNameMap := make(map[string]struct{})
3588+
for _, g := range arg.GroupNames {
3589+
groupNameMap[g] = struct{}{}
3590+
}
3591+
3592+
q.mutex.Lock()
3593+
defer q.mutex.Unlock()
3594+
3595+
for _, g := range q.groups {
3596+
if g.OrganizationID != arg.OrganizationID {
3597+
continue
3598+
}
3599+
delete(groupNameMap, g.Name)
3600+
}
3601+
3602+
newGroups := make([]database.Group, 0, len(groupNameMap))
3603+
for k := range groupNameMap {
3604+
g := database.Group{
3605+
ID: uuid.New(),
3606+
Name: k,
3607+
OrganizationID: arg.OrganizationID,
3608+
AvatarURL: "",
3609+
QuotaAllowance: 0,
3610+
DisplayName: "",
3611+
}
3612+
q.groups = append(q.groups, g)
3613+
newGroups = append(newGroups, g)
3614+
}
3615+
3616+
return newGroups, nil
3617+
}
3618+
35813619
func (q *FakeQuerier) InsertOrganization(_ context.Context, arg database.InsertOrganizationParams) (database.Organization, error) {
35823620
if err := validateDatabaseType(arg); err != nil {
35833621
return database.Organization{}, err

coderd/database/dbmetrics/dbmetrics.go

Lines changed: 7 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/querier.go

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries.sql.go

Lines changed: 52 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries/groups.sql

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,23 @@ INSERT INTO groups (
4242
VALUES
4343
($1, $2, $3, $4, $5, $6) RETURNING *;
4444

45+
-- name: InsertMissingGroups :many
46+
INSERT INTO groups (
47+
id,
48+
name,
49+
organization_id
50+
)
51+
SELECT
52+
gen_random_uuid(),
53+
group_name,
54+
@organization_id
55+
FROM
56+
UNNEST(@group_names :: text[]) AS group_name
57+
-- If the name conflicts, do nothing.
58+
ON CONFLICT DO NOTHING
59+
RETURNING *;
60+
61+
4562
-- We use the organization_id as the id
4663
-- for simplicity since all users is
4764
-- every member of the org.

coderd/userauth.go

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"fmt"
88
"net/http"
99
"net/mail"
10+
"regexp"
1011
"sort"
1112
"strconv"
1213
"strings"
@@ -691,6 +692,10 @@ type OIDCConfig struct {
691692
// CreateMissingGroups controls whether groups returned by the OIDC provider
692693
// are automatically created in Coder if they are missing.
693694
CreateMissingGroups bool
695+
// GroupFilter is a regular expression that filters the groups returned by
696+
// the OIDC provider. Any group not matched by this regex will be ignored.
697+
// If the group filter is nil, then no group filtering will occur.
698+
GroupFilter *regexp.Regexp
694699
// GroupMapping controls how groups returned by the OIDC provider get mapped
695700
// to groups within Coder.
696701
// map[oidcGroupName]coderGroupName
@@ -1046,6 +1051,7 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
10461051
Roles: roles,
10471052
Groups: groups,
10481053
CreateMissingGroups: api.OIDCConfig.CreateMissingGroups,
1054+
GroupFilter: api.OIDCConfig.GroupFilter,
10491055
}).SetInitAuditRequest(func(params *audit.RequestParams) (*audit.Request[database.User], func()) {
10501056
return audit.InitRequest[database.User](rw, params)
10511057
})
@@ -1132,6 +1138,7 @@ type oauthLoginParams struct {
11321138
UsingGroups bool
11331139
CreateMissingGroups bool
11341140
Groups []string
1141+
GroupFilter *regexp.Regexp
11351142
// Is UsingRoles is true, then the user will be assigned
11361143
// the roles provided.
11371144
UsingRoles bool
@@ -1347,8 +1354,18 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C
13471354

13481355
// Ensure groups are correct.
13491356
if params.UsingGroups {
1357+
filtered := params.Groups
1358+
if params.GroupFilter != nil {
1359+
filtered = make([]string, 0, len(params.Groups))
1360+
for _, group := range params.Groups {
1361+
if params.GroupFilter.MatchString(group) {
1362+
filtered = append(filtered, group)
1363+
}
1364+
}
1365+
}
1366+
13501367
//nolint:gocritic
1351-
err := api.Options.SetUserGroups(dbauthz.AsSystemRestricted(ctx), tx, user.ID, params.Groups)
1368+
err := api.Options.SetUserGroups(dbauthz.AsSystemRestricted(ctx), logger, tx, user.ID, filtered, params.CreateMissingGroups)
13521369
if err != nil {
13531370
return xerrors.Errorf("set user groups: %w", err)
13541371
}
@@ -1367,7 +1384,7 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C
13671384
}
13681385

13691386
//nolint:gocritic
1370-
err := api.Options.SetUserSiteRoles(dbauthz.AsSystemRestricted(ctx), tx, user.ID, filtered)
1387+
err := api.Options.SetUserSiteRoles(dbauthz.AsSystemRestricted(ctx), logger, tx, user.ID, filtered)
13711388
if err != nil {
13721389
return httpError{
13731390
code: http.StatusBadRequest,

codersdk/deployment.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,7 @@ type OIDCConfig struct {
272272
AuthURLParams clibase.Struct[map[string]string] `json:"auth_url_params" typescript:",notnull"`
273273
IgnoreUserInfo clibase.Bool `json:"ignore_user_info" typescript:",notnull"`
274274
GroupAutoCreate clibase.Bool `json:"group_auto_create" typescript:",notnull"`
275+
GroupRegexFilter clibase.Regexp `json:"group_regex_filter" typescript:",notnull"`
275276
GroupField clibase.String `json:"groups_field" typescript:",notnull"`
276277
GroupMapping clibase.Struct[map[string]string] `json:"group_mapping" typescript:",notnull"`
277278
UserRoleField clibase.String `json:"user_role_field" typescript:",notnull"`
@@ -1076,6 +1077,16 @@ when required by your organization's security policy.`,
10761077
Group: &deploymentGroupOIDC,
10771078
YAML: "enableGroupAutoCreate",
10781079
},
1080+
{
1081+
Name: "OIDC Regex Group Filter",
1082+
Description: "If provided any group name not matching the regex is ignored. This allows for filtering out groups that are not needed.",
1083+
Flag: "oidc-group-regex-filter",
1084+
Env: "CODER_OIDC_GROUP_REGEX_FILTER",
1085+
Default: "",
1086+
Value: &c.OIDC.GroupRegexFilter,
1087+
Group: &deploymentGroupOIDC,
1088+
YAML: "groupRegexFilter",
1089+
},
10791090
{
10801091
Name: "OIDC User Role Field",
10811092
Description: "This field must be set if using the user roles sync feature. Set this to the name of the claim used to store the user's role. The roles should be sent as an array of strings.",

enterprise/coderd/userauth.go

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,11 @@ import (
99
"cdr.dev/slog"
1010
"github.com/coder/coder/coderd"
1111
"github.com/coder/coder/coderd/database"
12+
"github.com/coder/coder/coderd/database/dbauthz"
1213
"github.com/coder/coder/codersdk"
1314
)
1415

15-
func (api *API) setUserGroups(ctx context.Context, db database.Store, userID uuid.UUID, groupNames []string, createMissingGroups bool) error {
16+
func (api *API) setUserGroups(ctx context.Context, logger slog.Logger, db database.Store, userID uuid.UUID, groupNames []string, createMissingGroups bool) error {
1617
api.entitlementsMu.RLock()
1718
enabled := api.entitlements.Features[codersdk.FeatureTemplateRBAC].Enabled
1819
api.entitlementsMu.RUnlock()
@@ -39,7 +40,23 @@ func (api *API) setUserGroups(ctx context.Context, db database.Store, userID uui
3940
return xerrors.Errorf("delete user groups: %w", err)
4041
}
4142

42-
// TODO: Create missing groups if createMissingGroups is true.
43+
if createMissingGroups {
44+
// This is the system creating these additional groups, so we use the system restricted context.
45+
// nolint:gocritic
46+
created, err := tx.InsertMissingGroups(dbauthz.AsSystemRestricted(ctx), database.InsertMissingGroupsParams{
47+
OrganizationID: orgs[0].ID,
48+
GroupNames: groupNames,
49+
})
50+
if err != nil {
51+
return xerrors.Errorf("insert missing groups: %w", err)
52+
}
53+
if len(created) > 0 {
54+
logger.Debug(ctx, "auto created missing groups",
55+
slog.F("org_id", orgs[0].ID),
56+
slog.F("created", created),
57+
)
58+
}
59+
}
4360

4461
// Re-add the user to all groups returned by the auth provider.
4562
err = tx.InsertUserGroupsByName(ctx, database.InsertUserGroupsByNameParams{
@@ -55,13 +72,13 @@ func (api *API) setUserGroups(ctx context.Context, db database.Store, userID uui
5572
}, nil)
5673
}
5774

58-
func (api *API) setUserSiteRoles(ctx context.Context, db database.Store, userID uuid.UUID, roles []string) error {
75+
func (api *API) setUserSiteRoles(ctx context.Context, logger slog.Logger, db database.Store, userID uuid.UUID, roles []string) error {
5976
api.entitlementsMu.RLock()
6077
enabled := api.entitlements.Features[codersdk.FeatureUserRoleManagement].Enabled
6178
api.entitlementsMu.RUnlock()
6279

6380
if !enabled {
64-
api.Logger.Warn(ctx, "attempted to assign OIDC user roles without enterprise entitlement, roles left unchanged",
81+
logger.Warn(ctx, "attempted to assign OIDC user roles without enterprise entitlement, roles left unchanged",
6582
slog.F("user_id", userID), slog.F("roles", roles),
6683
)
6784
return nil

0 commit comments

Comments
 (0)