From f66947a8bdbeae48a16fa0f14baa694281846dcc Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Tue, 25 Jul 2023 22:37:44 +0000 Subject: [PATCH] fix(enterprise/scim): ensure creating a user is idempotent --- enterprise/coderd/scim.go | 21 ++++++++++++++++++++- enterprise/coderd/scim_test.go | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/enterprise/coderd/scim.go b/enterprise/coderd/scim.go index c46ff8f5dd3d7..efba55b932684 100644 --- a/enterprise/coderd/scim.go +++ b/enterprise/coderd/scim.go @@ -2,6 +2,7 @@ package coderd import ( "crypto/subtle" + "database/sql" "encoding/json" "net/http" @@ -11,6 +12,7 @@ import ( scimjson "github.com/imulab/go-scim/pkg/v2/json" "github.com/imulab/go-scim/pkg/v2/service" "github.com/imulab/go-scim/pkg/v2/spec" + "golang.org/x/xerrors" agpl "github.com/coder/coder/coderd" "github.com/coder/coder/coderd/database" @@ -152,6 +154,23 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) { return } + //nolint:gocritic + user, err := api.Database.GetUserByEmailOrUsername(dbauthz.AsSystemRestricted(ctx), database.GetUserByEmailOrUsernameParams{ + Email: email, + Username: sUser.UserName, + }) + if err != nil && !xerrors.Is(err, sql.ErrNoRows) { + _ = handlerutil.WriteError(rw, err) + return + } + if err == nil { + sUser.ID = user.ID.String() + sUser.UserName = user.Username + + httpapi.Write(ctx, rw, http.StatusOK, sUser) + return + } + // The username is a required property in Coder. We make a best-effort // attempt at using what the claims provide, but if that fails we will // generate a random username. @@ -182,7 +201,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) { } //nolint:gocritic // needed for SCIM - user, _, err := api.AGPL.CreateUser(dbauthz.AsSystemRestricted(ctx), api.Database, agpl.CreateUserRequest{ + user, _, err = api.AGPL.CreateUser(dbauthz.AsSystemRestricted(ctx), api.Database, agpl.CreateUserRequest{ CreateUserRequest: codersdk.CreateUserRequest{ Username: sUser.UserName, Email: email, diff --git a/enterprise/coderd/scim_test.go b/enterprise/coderd/scim_test.go index a72a1d227424b..f0778c26b51d0 100644 --- a/enterprise/coderd/scim_test.go +++ b/enterprise/coderd/scim_test.go @@ -131,6 +131,39 @@ func TestScim(t *testing.T) { assert.Equal(t, sUser.UserName, userRes.Users[0].Username) }) + t.Run("Duplicate", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + scimAPIKey := []byte("hi") + client, _ := coderdenttest.New(t, &coderdenttest.Options{ + SCIMAPIKey: scimAPIKey, + LicenseOptions: &coderdenttest.LicenseOptions{ + AccountID: "coolin", + Features: license.Features{ + codersdk.FeatureSCIM: 1, + }, + }, + }) + + sUser := makeScimUser(t) + for i := 0; i < 3; i++ { + res, err := client.Request(ctx, "POST", "/scim/v2/Users", sUser, setScimAuth(scimAPIKey)) + require.NoError(t, err) + _ = res.Body.Close() + assert.Equal(t, http.StatusOK, res.StatusCode) + } + + userRes, err := client.Users(ctx, codersdk.UsersRequest{Search: sUser.Emails[0].Value}) + require.NoError(t, err) + require.Len(t, userRes.Users, 1) + + assert.Equal(t, sUser.Emails[0].Value, userRes.Users[0].Email) + assert.Equal(t, sUser.UserName, userRes.Users[0].Username) + }) + t.Run("DomainStrips", func(t *testing.T) { t.Parallel()