diff --git a/enterprise/coderd/scim.go b/enterprise/coderd/scim.go index efba55b932684..801ca61349ae3 100644 --- a/enterprise/coderd/scim.go +++ b/enterprise/coderd/scim.go @@ -155,7 +155,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) { } //nolint:gocritic - user, err := api.Database.GetUserByEmailOrUsername(dbauthz.AsSystemRestricted(ctx), database.GetUserByEmailOrUsernameParams{ + dbUser, err := api.Database.GetUserByEmailOrUsername(dbauthz.AsSystemRestricted(ctx), database.GetUserByEmailOrUsernameParams{ Email: email, Username: sUser.UserName, }) @@ -164,8 +164,22 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) { return } if err == nil { - sUser.ID = user.ID.String() - sUser.UserName = user.Username + sUser.ID = dbUser.ID.String() + sUser.UserName = dbUser.Username + + if sUser.Active && dbUser.Status == database.UserStatusSuspended { + //nolint:gocritic + _, err = api.Database.UpdateUserStatus(dbauthz.AsSystemRestricted(r.Context()), database.UpdateUserStatusParams{ + ID: dbUser.ID, + // The user will get transitioned to Active after logging in. + Status: database.UserStatusDormant, + UpdatedAt: database.Now(), + }) + if err != nil { + _ = handlerutil.WriteError(rw, err) + return + } + } httpapi.Write(ctx, rw, http.StatusOK, sUser) return @@ -201,7 +215,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) { } //nolint:gocritic // needed for SCIM - user, _, err = api.AGPL.CreateUser(dbauthz.AsSystemRestricted(ctx), api.Database, agpl.CreateUserRequest{ + dbUser, _, err = api.AGPL.CreateUser(dbauthz.AsSystemRestricted(ctx), api.Database, agpl.CreateUserRequest{ CreateUserRequest: codersdk.CreateUserRequest{ Username: sUser.UserName, Email: email, @@ -214,8 +228,8 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) { return } - sUser.ID = user.ID.String() - sUser.UserName = user.Username + sUser.ID = dbUser.ID.String() + sUser.UserName = dbUser.Username httpapi.Write(ctx, rw, http.StatusOK, sUser) } @@ -263,7 +277,8 @@ func (api *API) scimPatchUser(rw http.ResponseWriter, r *http.Request) { var status database.UserStatus if sUser.Active { - status = database.UserStatusActive + // The user will get transitioned to Active after logging in. + status = database.UserStatusDormant } else { status = database.UserStatusSuspended } diff --git a/enterprise/coderd/scim_test.go b/enterprise/coderd/scim_test.go index f0778c26b51d0..a74dc9bf3452b 100644 --- a/enterprise/coderd/scim_test.go +++ b/enterprise/coderd/scim_test.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "io" "net/http" "testing" @@ -164,6 +165,54 @@ func TestScim(t *testing.T) { assert.Equal(t, sUser.UserName, userRes.Users[0].Username) }) + t.Run("Unsuspend", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + scimAPIKey := []byte("hi") + client, _ := coderdenttest.New(t, &coderdenttest.Options{ + SCIMAPIKey: scimAPIKey, + LicenseOptions: &coderdenttest.LicenseOptions{ + AccountID: "coolin", + Features: license.Features{ + codersdk.FeatureSCIM: 1, + }, + }, + }) + + sUser := makeScimUser(t) + res, err := client.Request(ctx, "POST", "/scim/v2/Users", sUser, setScimAuth(scimAPIKey)) + require.NoError(t, err) + defer res.Body.Close() + assert.Equal(t, http.StatusOK, res.StatusCode) + err = json.NewDecoder(res.Body).Decode(&sUser) + require.NoError(t, err) + + sUser.Active = false + res, err = client.Request(ctx, "PATCH", "/scim/v2/Users/"+sUser.ID, sUser, setScimAuth(scimAPIKey)) + require.NoError(t, err) + _, _ = io.Copy(io.Discard, res.Body) + _ = res.Body.Close() + assert.Equal(t, http.StatusOK, res.StatusCode) + + sUser.Active = true + res, err = client.Request(ctx, "POST", "/scim/v2/Users", sUser, setScimAuth(scimAPIKey)) + require.NoError(t, err) + _, _ = io.Copy(io.Discard, res.Body) + _ = res.Body.Close() + assert.Equal(t, http.StatusOK, res.StatusCode) + + userRes, err := client.Users(ctx, codersdk.UsersRequest{Search: sUser.Emails[0].Value}) + require.NoError(t, err) + require.Len(t, userRes.Users, 1) + + assert.Equal(t, sUser.Emails[0].Value, userRes.Users[0].Email) + assert.Equal(t, sUser.UserName, userRes.Users[0].Username) + assert.Equal(t, codersdk.UserStatusDormant, userRes.Users[0].Status) + }) + t.Run("DomainStrips", func(t *testing.T) { t.Parallel() @@ -185,7 +234,8 @@ func TestScim(t *testing.T) { sUser.UserName = sUser.UserName + "@coder.com" res, err := client.Request(ctx, "POST", "/scim/v2/Users", sUser, setScimAuth(scimAPIKey)) require.NoError(t, err) - defer res.Body.Close() + _, _ = io.Copy(io.Discard, res.Body) + _ = res.Body.Close() assert.Equal(t, http.StatusOK, res.StatusCode) userRes, err := client.Users(ctx, codersdk.UsersRequest{Search: sUser.Emails[0].Value}) @@ -220,7 +270,8 @@ func TestScim(t *testing.T) { res, err := client.Request(ctx, "PATCH", "/scim/v2/Users/bob", struct{}{}) require.NoError(t, err) - defer res.Body.Close() + _, _ = io.Copy(io.Discard, res.Body) + _ = res.Body.Close() assert.Equal(t, http.StatusNotFound, res.StatusCode) }) @@ -242,7 +293,8 @@ func TestScim(t *testing.T) { res, err := client.Request(ctx, "PATCH", "/scim/v2/Users/bob", struct{}{}) require.NoError(t, err) - defer res.Body.Close() + _, _ = io.Copy(io.Discard, res.Body) + _ = res.Body.Close() assert.Equal(t, http.StatusInternalServerError, res.StatusCode) }) @@ -276,7 +328,8 @@ func TestScim(t *testing.T) { res, err = client.Request(ctx, "PATCH", "/scim/v2/Users/"+sUser.ID, sUser, setScimAuth(scimAPIKey)) require.NoError(t, err) - defer res.Body.Close() + _, _ = io.Copy(io.Discard, res.Body) + _ = res.Body.Close() assert.Equal(t, http.StatusOK, res.StatusCode) userRes, err := client.Users(ctx, codersdk.UsersRequest{Search: sUser.Emails[0].Value})