Skip to content

Commit f2f76e9

Browse files
committed
add login_type to users table
1 parent b5dc95b commit f2f76e9

File tree

9 files changed

+92
-38
lines changed

9 files changed

+92
-38
lines changed

coderd/database/dump.sql

Lines changed: 2 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/migrations/000034_linked_user_id.up.sql

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,19 @@ ALTER TABLE api_keys
5151
DROP COLUMN oauth_id_token,
5252
DROP COLUMN oauth_expiry;
5353

54+
ALTER TABLE users ADD COLUMN login_type login_type NOT NULL DEFAULT 'password';
55+
56+
UPDATE
57+
users
58+
SET
59+
login_type = (
60+
SELECT
61+
login_type
62+
FROM
63+
user_links
64+
WHERE
65+
user_links.user_id = users.id
66+
LIMIT 1
67+
);
68+
5469
COMMIT;

coderd/database/models.go

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/postgres/postgres.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,6 @@ func Open() (string, func(), error) {
136136
}
137137
err = database.MigrateUp(db)
138138
if err != nil {
139-
fmt.Printf("err: %v\n", err)
140139
return xerrors.Errorf("migrate db: %w", err)
141140
}
142141

coderd/database/queries.sql.go

Lines changed: 20 additions & 9 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries/users.sql

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,11 @@ INSERT INTO
3737
hashed_password,
3838
created_at,
3939
updated_at,
40-
rbac_roles
40+
rbac_roles,
41+
login_type
4142
)
4243
VALUES
43-
($1, $2, $3, $4, $5, $6, $7) RETURNING *;
44+
($1, $2, $3, $4, $5, $6, $7, $8) RETURNING *;
4445

4546
-- name: UpdateUserProfile :one
4647
UPDATE

coderd/userauth.go

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) {
137137
return
138138
}
139139

140-
user, link, err := findLinkedUser(ctx, api.Database, database.LoginTypeGithub, githubLinkedID(ghUser), verifiedEmails...)
140+
user, link, err := findLinkedUser(ctx, api.Database, githubLinkedID(ghUser), verifiedEmails...)
141141
if err != nil {
142142
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
143143
Message: "An internal error occurred.",
@@ -146,9 +146,9 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) {
146146
return
147147
}
148148

149-
if link.UserID != uuid.Nil && link.LoginType != database.LoginTypeGithub {
150-
httpapi.Write(rw, http.StatusConflict, codersdk.Response{
151-
Message: fmt.Sprintf("Incorrect login type, attempting to use %q but user is of login type %q", database.LoginTypeOIDC, link.LoginType),
149+
if user.ID != uuid.Nil && user.LoginType != database.LoginTypeGithub {
150+
httpapi.Write(rw, http.StatusForbidden, codersdk.Response{
151+
Message: fmt.Sprintf("Incorrect login type, attempting to use %q but user is of login type %q", database.LoginTypeOIDC, user.LoginType),
152152
})
153153
return
154154
}
@@ -184,10 +184,13 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) {
184184
})
185185
return
186186
}
187-
user, _, err = api.createUser(r.Context(), codersdk.CreateUserRequest{
188-
Email: *verifiedEmail.Email,
189-
Username: *ghUser.Login,
190-
OrganizationID: organizationID,
187+
user, _, err = api.createUser(ctx, createUserRequest{
188+
CreateUserRequest: codersdk.CreateUserRequest{
189+
Email: *verifiedEmail.Email,
190+
Username: *ghUser.Login,
191+
OrganizationID: organizationID,
192+
},
193+
LoginType: database.LoginTypeGithub,
191194
})
192195
if err != nil {
193196
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
@@ -332,7 +335,7 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
332335
}
333336
}
334337

335-
user, link, err := findLinkedUser(ctx, api.Database, database.LoginTypeOIDC, oidcLinkedID(idToken), claims.Email)
338+
user, link, err := findLinkedUser(ctx, api.Database, oidcLinkedID(idToken), claims.Email)
336339
if err != nil {
337340
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
338341
Message: "Failed to find user.",
@@ -348,9 +351,9 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
348351
return
349352
}
350353

351-
if link.UserID != uuid.Nil && link.LoginType != database.LoginTypeOIDC {
352-
httpapi.Write(rw, http.StatusConflict, codersdk.Response{
353-
Message: fmt.Sprintf("Incorrect login type, attempting to use %q but user is of login type %q", database.LoginTypeOIDC, link.LoginType),
354+
if user.ID != uuid.Nil && user.LoginType != database.LoginTypeOIDC {
355+
httpapi.Write(rw, http.StatusForbidden, codersdk.Response{
356+
Message: fmt.Sprintf("Incorrect login type, attempting to use %q but user is of login type %q", database.LoginTypeOIDC, user.LoginType),
354357
})
355358
return
356359
}
@@ -365,10 +368,13 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
365368
organizationID = organizations[0].ID
366369
}
367370

368-
user, _, err = api.createUser(ctx, codersdk.CreateUserRequest{
369-
Email: claims.Email,
370-
Username: claims.Username,
371-
OrganizationID: organizationID,
371+
user, _, err = api.createUser(ctx, createUserRequest{
372+
CreateUserRequest: codersdk.CreateUserRequest{
373+
Email: claims.Email,
374+
Username: claims.Username,
375+
OrganizationID: organizationID,
376+
},
377+
LoginType: database.LoginTypeOIDC,
372378
})
373379
if err != nil {
374380
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
@@ -477,7 +483,7 @@ func oidcLinkedID(tok *oidc.IDToken) string {
477483

478484
// findLinkedUser tries to find a user by their unique OAuth-linked ID.
479485
// If it doesn't not find it, it returns the user by their email.
480-
func findLinkedUser(ctx context.Context, db database.Store, loginType database.LoginType, linkedID string, emails ...string) (database.User, database.UserLink, error) {
486+
func findLinkedUser(ctx context.Context, db database.Store, linkedID string, emails ...string) (database.User, database.UserLink, error) {
481487
var (
482488
user database.User
483489
link database.UserLink
@@ -518,7 +524,7 @@ func findLinkedUser(ctx context.Context, db database.Store, loginType database.L
518524
// possible that a user_link exists without a populated 'linked_id'.
519525
link, err = db.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{
520526
UserID: user.ID,
521-
LoginType: loginType,
527+
LoginType: user.LoginType,
522528
})
523529
if err != nil && !errors.Is(err, sql.ErrNoRows) {
524530
return database.User{}, database.UserLink{}, xerrors.Errorf("get user link by user id and login type: %w", err)

coderd/userauth_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ func TestUserOAuth2Github(t *testing.T) {
175175
resp := oauth2Callback(t, client)
176176
require.Equal(t, http.StatusForbidden, resp.StatusCode)
177177
})
178-
t.Run("Login", func(t *testing.T) {
178+
t.Run("MultiLoginNotAllowed", func(t *testing.T) {
179179
t.Parallel()
180180
client := coderdtest.New(t, &coderdtest.Options{
181181
GithubOAuth2Config: &coderd.GithubOAuth2Config{
@@ -199,6 +199,7 @@ func TestUserOAuth2Github(t *testing.T) {
199199
},
200200
},
201201
})
202+
// Creates the first user with login_type 'password'.
202203
_ = coderdtest.CreateFirstUser(t, client)
203204
resp := oauth2Callback(t, client)
204205
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)

coderd/users.go

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,13 @@ func (api *API) postFirstUser(rw http.ResponseWriter, r *http.Request) {
7777
return
7878
}
7979

80-
user, organizationID, err := api.createUser(r.Context(), codersdk.CreateUserRequest{
81-
Email: createUser.Email,
82-
Username: createUser.Username,
83-
Password: createUser.Password,
80+
user, organizationID, err := api.createUser(r.Context(), createUserRequest{
81+
CreateUserRequest: codersdk.CreateUserRequest{
82+
Email: createUser.Email,
83+
Username: createUser.Username,
84+
Password: createUser.Password,
85+
},
86+
LoginType: database.LoginTypePassword,
8487
})
8588
if err != nil {
8689
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
@@ -243,7 +246,10 @@ func (api *API) postUser(rw http.ResponseWriter, r *http.Request) {
243246
return
244247
}
245248

246-
user, _, err := api.createUser(r.Context(), req)
249+
user, _, err := api.createUser(r.Context(), createUserRequest{
250+
CreateUserRequest: req,
251+
LoginType: database.LoginTypePassword,
252+
})
247253
if err != nil {
248254
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
249255
Message: "Internal error creating user.",
@@ -684,6 +690,13 @@ func (api *API) postLogin(rw http.ResponseWriter, r *http.Request) {
684690
return
685691
}
686692

693+
if user.LoginType != database.LoginTypePassword {
694+
httpapi.Write(rw, http.StatusForbidden, codersdk.Response{
695+
Message: fmt.Sprintf("Incorrect login type, attempting to use %q but user is of login type %q", database.LoginTypeOIDC, user.LoginType),
696+
})
697+
return
698+
}
699+
687700
// If the user doesn't exist, it will be a default struct.
688701
equal, err := userpassword.Compare(string(user.HashedPassword), loginWithPassword.Password)
689702
if err != nil {
@@ -896,7 +909,12 @@ func (api *API) createAPIKey(rw http.ResponseWriter, r *http.Request, params cre
896909
return sessionToken, true
897910
}
898911

899-
func (api *API) createUser(ctx context.Context, req codersdk.CreateUserRequest) (database.User, uuid.UUID, error) {
912+
type createUserRequest struct {
913+
codersdk.CreateUserRequest
914+
LoginType database.LoginType
915+
}
916+
917+
func (api *API) createUser(ctx context.Context, req createUserRequest) (database.User, uuid.UUID, error) {
900918
var user database.User
901919
return user, req.OrganizationID, api.Database.InTx(func(db database.Store) error {
902920
orgRoles := make([]string, 0)
@@ -923,6 +941,7 @@ func (api *API) createUser(ctx context.Context, req codersdk.CreateUserRequest)
923941
UpdatedAt: database.Now(),
924942
// All new users are defaulted to members of the site.
925943
RBACRoles: []string{},
944+
LoginType: req.LoginType,
926945
}
927946
// If a user signs up with OAuth, they can have no password!
928947
if req.Password != "" {

0 commit comments

Comments
 (0)