Skip to content

feat: add auto group create from OIDC #8884

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 22 commits into from
Aug 8, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
sync missing groups
Also added a regex filter to filter out groups that are not
important
  • Loading branch information
Emyrk committed Aug 3, 2023
commit f56113e7787fd1b32c8c4dfb6ae6c9d8d0b2d26f
31 changes: 31 additions & 0 deletions cli/clibase/values.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net"
"net/url"
"reflect"
"regexp"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -461,6 +462,36 @@ func (e *Enum) String() string {
return *e.Value
}

type Regexp regexp.Regexp

func RegexpOf(s *regexp.Regexp) *Regexp {
return (*Regexp)(s)
}

func (s *Regexp) Set(v string) error {
exp, err := regexp.Compile(v)
if err != nil {
return xerrors.Errorf("invalid regexp %q: %w", v, err)
}
*s = Regexp(*exp)
return nil
}

func (s Regexp) String() string {
return s.Value().String()
}

func (s *Regexp) Value() *regexp.Regexp {
if s == nil {
return nil
}
return (*regexp.Regexp)(s)
}

func (Regexp) Type() string {
return "regexp"
}

var _ pflag.Value = (*YAMLConfigPath)(nil)

// YAMLConfigPath is a special value type that encodes a path to a YAML
Expand Down
1 change: 1 addition & 0 deletions cli/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
AuthURLParams: cfg.OIDC.AuthURLParams.Value,
IgnoreUserInfo: cfg.OIDC.IgnoreUserInfo.Value(),
GroupField: cfg.OIDC.GroupField.String(),
GroupFilter: cfg.OIDC.GroupRegexFilter.Value(),
CreateMissingGroups: cfg.OIDC.GroupAutoCreate.Value(),
GroupMapping: cfg.OIDC.GroupMapping.Value,
UserRoleField: cfg.OIDC.UserRoleField.String(),
Expand Down
12 changes: 6 additions & 6 deletions coderd/coderd.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ type Options struct {
BaseDERPMap *tailcfg.DERPMap
DERPMapUpdateFrequency time.Duration
SwaggerEndpoint bool
SetUserGroups func(ctx context.Context, tx database.Store, userID uuid.UUID, groupNames []string, createMissingGroups bool) error
SetUserSiteRoles func(ctx context.Context, tx database.Store, userID uuid.UUID, roles []string) error
SetUserGroups func(ctx context.Context, logger slog.Logger, tx database.Store, userID uuid.UUID, groupNames []string, createMissingGroups bool) error
SetUserSiteRoles func(ctx context.Context, logger slog.Logger, tx database.Store, userID uuid.UUID, roles []string) error
TemplateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore]
UserQuietHoursScheduleStore *atomic.Pointer[schedule.UserQuietHoursScheduleStore]
// AppSecurityKey is the crypto key used to sign and encrypt tokens related to
Expand Down Expand Up @@ -258,16 +258,16 @@ func New(options *Options) *API {
options.TracerProvider = trace.NewNoopTracerProvider()
}
if options.SetUserGroups == nil {
options.SetUserGroups = func(ctx context.Context, _ database.Store, userID uuid.UUID, groups []string, createMissingGroups bool) error {
options.Logger.Warn(ctx, "attempted to assign OIDC groups without enterprise license",
options.SetUserGroups = func(ctx context.Context, logger slog.Logger, _ database.Store, userID uuid.UUID, groups []string, createMissingGroups bool) error {
logger.Warn(ctx, "attempted to assign OIDC groups without enterprise license",
slog.F("user_id", userID), slog.F("groups", groups), slog.F("create_missing_groups", createMissingGroups),
)
return nil
}
}
if options.SetUserSiteRoles == nil {
options.SetUserSiteRoles = func(ctx context.Context, _ database.Store, userID uuid.UUID, roles []string) error {
options.Logger.Warn(ctx, "attempted to assign OIDC user roles without enterprise license",
options.SetUserSiteRoles = func(ctx context.Context, logger slog.Logger, _ database.Store, userID uuid.UUID, roles []string) error {
logger.Warn(ctx, "attempted to assign OIDC user roles without enterprise license",
slog.F("user_id", userID), slog.F("roles", roles),
)
return nil
Expand Down
7 changes: 7 additions & 0 deletions coderd/database/dbauthz/dbauthz.go
Original file line number Diff line number Diff line change
Expand Up @@ -1834,6 +1834,13 @@ func (q *querier) InsertLicense(ctx context.Context, arg database.InsertLicenseP
return q.db.InsertLicense(ctx, arg)
}

func (q *querier) InsertMissingGroups(ctx context.Context, arg database.InsertMissingGroupsParams) ([]database.Group, error) {
if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceSystem); err != nil {
return nil, err
}
return q.db.InsertMissingGroups(ctx, arg)
}

func (q *querier) InsertOrganization(ctx context.Context, arg database.InsertOrganizationParams) (database.Organization, error) {
return insert(q.log, q.auth, rbac.ResourceOrganization, q.db.InsertOrganization)(ctx, arg)
}
Expand Down
38 changes: 38 additions & 0 deletions coderd/database/dbfake/dbfake.go
Original file line number Diff line number Diff line change
Expand Up @@ -3578,6 +3578,44 @@ func (q *FakeQuerier) InsertLicense(
return l, nil
}

func (q *FakeQuerier) InsertMissingGroups(ctx context.Context, arg database.InsertMissingGroupsParams) ([]database.Group, error) {
err := validateDatabaseType(arg)
if err != nil {
return nil, err
}

groupNameMap := make(map[string]struct{})
for _, g := range arg.GroupNames {
groupNameMap[g] = struct{}{}
}

q.mutex.Lock()
defer q.mutex.Unlock()

for _, g := range q.groups {
if g.OrganizationID != arg.OrganizationID {
continue
}
delete(groupNameMap, g.Name)
}

newGroups := make([]database.Group, 0, len(groupNameMap))
for k := range groupNameMap {
g := database.Group{
ID: uuid.New(),
Name: k,
OrganizationID: arg.OrganizationID,
AvatarURL: "",
QuotaAllowance: 0,
DisplayName: "",
}
q.groups = append(q.groups, g)
newGroups = append(newGroups, g)
}

return newGroups, nil
}

func (q *FakeQuerier) InsertOrganization(_ context.Context, arg database.InsertOrganizationParams) (database.Organization, error) {
if err := validateDatabaseType(arg); err != nil {
return database.Organization{}, err
Expand Down
7 changes: 7 additions & 0 deletions coderd/database/dbmetrics/dbmetrics.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions coderd/database/querier.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

52 changes: 52 additions & 0 deletions coderd/database/queries.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 17 additions & 0 deletions coderd/database/queries/groups.sql
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,23 @@ INSERT INTO groups (
VALUES
($1, $2, $3, $4, $5, $6) RETURNING *;

-- name: InsertMissingGroups :many
INSERT INTO groups (
id,
name,
organization_id
)
SELECT
gen_random_uuid(),
group_name,
@organization_id
FROM
UNNEST(@group_names :: text[]) AS group_name
-- If the name conflicts, do nothing.
ON CONFLICT DO NOTHING
RETURNING *;


-- We use the organization_id as the id
-- for simplicity since all users is
-- every member of the org.
Expand Down
21 changes: 19 additions & 2 deletions coderd/userauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"net/http"
"net/mail"
"regexp"
"sort"
"strconv"
"strings"
Expand Down Expand Up @@ -691,6 +692,10 @@ type OIDCConfig struct {
// CreateMissingGroups controls whether groups returned by the OIDC provider
// are automatically created in Coder if they are missing.
CreateMissingGroups bool
// GroupFilter is a regular expression that filters the groups returned by
// the OIDC provider. Any group not matched by this regex will be ignored.
// If the group filter is nil, then no group filtering will occur.
GroupFilter *regexp.Regexp
// GroupMapping controls how groups returned by the OIDC provider get mapped
// to groups within Coder.
// map[oidcGroupName]coderGroupName
Expand Down Expand Up @@ -1046,6 +1051,7 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
Roles: roles,
Groups: groups,
CreateMissingGroups: api.OIDCConfig.CreateMissingGroups,
GroupFilter: api.OIDCConfig.GroupFilter,
}).SetInitAuditRequest(func(params *audit.RequestParams) (*audit.Request[database.User], func()) {
return audit.InitRequest[database.User](rw, params)
})
Expand Down Expand Up @@ -1132,6 +1138,7 @@ type oauthLoginParams struct {
UsingGroups bool
CreateMissingGroups bool
Groups []string
GroupFilter *regexp.Regexp
// Is UsingRoles is true, then the user will be assigned
// the roles provided.
UsingRoles bool
Expand Down Expand Up @@ -1347,8 +1354,18 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C

// Ensure groups are correct.
if params.UsingGroups {
filtered := params.Groups
if params.GroupFilter != nil {
filtered = make([]string, 0, len(params.Groups))
for _, group := range params.Groups {
if params.GroupFilter.MatchString(group) {
filtered = append(filtered, group)
}
}
}

//nolint:gocritic
err := api.Options.SetUserGroups(dbauthz.AsSystemRestricted(ctx), tx, user.ID, params.Groups)
err := api.Options.SetUserGroups(dbauthz.AsSystemRestricted(ctx), logger, tx, user.ID, filtered, params.CreateMissingGroups)
if err != nil {
return xerrors.Errorf("set user groups: %w", err)
}
Expand All @@ -1367,7 +1384,7 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C
}

//nolint:gocritic
err := api.Options.SetUserSiteRoles(dbauthz.AsSystemRestricted(ctx), tx, user.ID, filtered)
err := api.Options.SetUserSiteRoles(dbauthz.AsSystemRestricted(ctx), logger, tx, user.ID, filtered)
if err != nil {
return httpError{
code: http.StatusBadRequest,
Expand Down
11 changes: 11 additions & 0 deletions codersdk/deployment.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ type OIDCConfig struct {
AuthURLParams clibase.Struct[map[string]string] `json:"auth_url_params" typescript:",notnull"`
IgnoreUserInfo clibase.Bool `json:"ignore_user_info" typescript:",notnull"`
GroupAutoCreate clibase.Bool `json:"group_auto_create" typescript:",notnull"`
GroupRegexFilter clibase.Regexp `json:"group_regex_filter" typescript:",notnull"`
GroupField clibase.String `json:"groups_field" typescript:",notnull"`
GroupMapping clibase.Struct[map[string]string] `json:"group_mapping" typescript:",notnull"`
UserRoleField clibase.String `json:"user_role_field" typescript:",notnull"`
Expand Down Expand Up @@ -1076,6 +1077,16 @@ when required by your organization's security policy.`,
Group: &deploymentGroupOIDC,
YAML: "enableGroupAutoCreate",
},
{
Name: "OIDC Regex Group Filter",
Description: "If provided any group name not matching the regex is ignored. This allows for filtering out groups that are not needed.",
Flag: "oidc-group-regex-filter",
Env: "CODER_OIDC_GROUP_REGEX_FILTER",
Default: "",
Value: &c.OIDC.GroupRegexFilter,
Group: &deploymentGroupOIDC,
YAML: "groupRegexFilter",
},
{
Name: "OIDC User Role Field",
Description: "This field must be set if using the user roles sync feature. Set this to the name of the claim used to store the user's role. The roles should be sent as an array of strings.",
Expand Down
25 changes: 21 additions & 4 deletions enterprise/coderd/userauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ import (
"cdr.dev/slog"
"github.com/coder/coder/coderd"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbauthz"
"github.com/coder/coder/codersdk"
)

func (api *API) setUserGroups(ctx context.Context, db database.Store, userID uuid.UUID, groupNames []string, createMissingGroups bool) error {
func (api *API) setUserGroups(ctx context.Context, logger slog.Logger, db database.Store, userID uuid.UUID, groupNames []string, createMissingGroups bool) error {
api.entitlementsMu.RLock()
enabled := api.entitlements.Features[codersdk.FeatureTemplateRBAC].Enabled
api.entitlementsMu.RUnlock()
Expand All @@ -39,7 +40,23 @@ func (api *API) setUserGroups(ctx context.Context, db database.Store, userID uui
return xerrors.Errorf("delete user groups: %w", err)
}

// TODO: Create missing groups if createMissingGroups is true.
if createMissingGroups {
// This is the system creating these additional groups, so we use the system restricted context.
// nolint:gocritic
created, err := tx.InsertMissingGroups(dbauthz.AsSystemRestricted(ctx), database.InsertMissingGroupsParams{
OrganizationID: orgs[0].ID,
GroupNames: groupNames,
})
if err != nil {
return xerrors.Errorf("insert missing groups: %w", err)
}
if len(created) > 0 {
logger.Debug(ctx, "auto created missing groups",
slog.F("org_id", orgs[0].ID),
slog.F("created", created),
)
}
}

// Re-add the user to all groups returned by the auth provider.
err = tx.InsertUserGroupsByName(ctx, database.InsertUserGroupsByNameParams{
Expand All @@ -55,13 +72,13 @@ func (api *API) setUserGroups(ctx context.Context, db database.Store, userID uui
}, nil)
}

func (api *API) setUserSiteRoles(ctx context.Context, db database.Store, userID uuid.UUID, roles []string) error {
func (api *API) setUserSiteRoles(ctx context.Context, logger slog.Logger, db database.Store, userID uuid.UUID, roles []string) error {
api.entitlementsMu.RLock()
enabled := api.entitlements.Features[codersdk.FeatureUserRoleManagement].Enabled
api.entitlementsMu.RUnlock()

if !enabled {
api.Logger.Warn(ctx, "attempted to assign OIDC user roles without enterprise entitlement, roles left unchanged",
logger.Warn(ctx, "attempted to assign OIDC user roles without enterprise entitlement, roles left unchanged",
slog.F("user_id", userID), slog.F("roles", roles),
)
return nil
Expand Down
Loading