Skip to content

Commit 30a635a

Browse files
authored
fix(enterprise): ensure scim usernames are validated (#7925)
1 parent a4cc883 commit 30a635a

File tree

3 files changed

+53
-4
lines changed

3 files changed

+53
-4
lines changed

coderd/users.go

+6
Original file line numberDiff line numberDiff line change
@@ -983,6 +983,12 @@ type CreateUserRequest struct {
983983
}
984984

985985
func (api *API) CreateUser(ctx context.Context, store database.Store, req CreateUserRequest) (database.User, uuid.UUID, error) {
986+
// Ensure the username is valid. It's the caller's responsibility to ensure
987+
// the username is valid and unique.
988+
if usernameValid := httpapi.NameValid(req.Username); usernameValid != nil {
989+
return database.User{}, uuid.Nil, xerrors.Errorf("invalid username %q: %w", req.Username, usernameValid)
990+
}
991+
986992
var user database.User
987993
return user, req.OrganizationID, store.InTx(func(tx database.Store) error {
988994
orgRoles := make([]string, 0)

enterprise/coderd/scim.go

+14-4
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,6 @@ func (api *API) scimGetUsers(rw http.ResponseWriter, r *http.Request) {
7171
// This is done to always force Okta to try and create the user, this way we
7272
// don't need to implement fetching users twice.
7373
//
74-
// scimGetUsers intentionally always returns no users. This is done to always force
75-
// Okta to try and create each user individually, this way we don't need to
76-
// implement fetching users twice.
77-
//
7874
// @Summary SCIM 2.0: Get user by ID
7975
// @ID scim-get-user-by-id
8076
// @Security CoderSessionToken
@@ -156,6 +152,20 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) {
156152
return
157153
}
158154

155+
// The username is a required property in Coder. We make a best-effort
156+
// attempt at using what the claims provide, but if that fails we will
157+
// generate a random username.
158+
usernameValid := httpapi.NameValid(sUser.UserName)
159+
if usernameValid != nil {
160+
// If no username is provided, we can default to use the email address.
161+
// This will be converted in the from function below, so it's safe
162+
// to keep the domain.
163+
if sUser.UserName == "" {
164+
sUser.UserName = email
165+
}
166+
sUser.UserName = httpapi.UsernameFrom(sUser.UserName)
167+
}
168+
159169
var organizationID uuid.UUID
160170
//nolint:gocritic
161171
organizations, err := api.Database.GetOrganizations(dbauthz.AsSystemRestricted(ctx))

enterprise/coderd/scim_test.go

+33
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,39 @@ func TestScim(t *testing.T) {
128128
assert.Equal(t, sUser.Emails[0].Value, userRes.Users[0].Email)
129129
assert.Equal(t, sUser.UserName, userRes.Users[0].Username)
130130
})
131+
132+
t.Run("DomainStrips", func(t *testing.T) {
133+
t.Parallel()
134+
135+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
136+
defer cancel()
137+
138+
scimAPIKey := []byte("hi")
139+
client := coderdenttest.New(t, &coderdenttest.Options{SCIMAPIKey: scimAPIKey})
140+
_ = coderdtest.CreateFirstUser(t, client)
141+
coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
142+
AccountID: "coolin",
143+
Features: license.Features{
144+
codersdk.FeatureSCIM: 1,
145+
},
146+
})
147+
148+
sUser := makeScimUser(t)
149+
sUser.UserName = sUser.UserName + "@coder.com"
150+
res, err := client.Request(ctx, "POST", "/scim/v2/Users", sUser, setScimAuth(scimAPIKey))
151+
require.NoError(t, err)
152+
defer res.Body.Close()
153+
assert.Equal(t, http.StatusOK, res.StatusCode)
154+
155+
userRes, err := client.Users(ctx, codersdk.UsersRequest{Search: sUser.Emails[0].Value})
156+
require.NoError(t, err)
157+
require.Len(t, userRes.Users, 1)
158+
159+
assert.Equal(t, sUser.Emails[0].Value, userRes.Users[0].Email)
160+
// Username should be the same as the given name. They all use the
161+
// same string before we modified it above.
162+
assert.Equal(t, sUser.Name.GivenName, userRes.Users[0].Username)
163+
})
131164
})
132165

133166
t.Run("patchUser", func(t *testing.T) {

0 commit comments

Comments
 (0)