Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
audit oauth and write tests
  • Loading branch information
Kira-Pilot committed Feb 2, 2023
commit 809dc7bfda771a883eae7fffc393f337c935b509
70 changes: 61 additions & 9 deletions coderd/userauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,6 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) {
aReq.New = database.APIKey{UserID: uuid.New()}
return
}

aReq.New = key

http.SetCookie(rw, cookie)
Expand Down Expand Up @@ -276,16 +275,28 @@ type OIDCConfig struct {
// @Router /users/oidc/callback [get]
func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
var (
ctx = r.Context()
state = httpmw.OAuth2(r)
ctx = r.Context()
state = httpmw.OAuth2(r)
auditor = api.Auditor.Load()
aReq, commitAudit = audit.InitRequest[database.APIKey](rw, &audit.RequestParams{
Audit: *auditor,
Log: api.Logger,
Request: r,
Action: database.AuditActionLogin,
})
)
aReq.Old = database.APIKey{}
defer commitAudit()

// See the example here: https://github.com/coreos/go-oidc
rawIDToken, ok := state.Token.Extra("id_token").(string)
if !ok {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "id_token not found in response payload. Ensure your OIDC callback is configured correctly!",
})
// We pass a disposable user ID just to force an audit diff
// and generate a log for a failed login
aReq.New = database.APIKey{UserID: uuid.New()}
return
}

Expand All @@ -295,6 +306,9 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
Message: "Failed to verify OIDC token.",
Detail: err.Error(),
})
// We pass a disposable user ID just to force an audit diff
// and generate a log for a failed login
aReq.New = database.APIKey{UserID: uuid.New()}
return
}

Expand All @@ -308,6 +322,9 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
Message: "Failed to extract OIDC claims.",
Detail: err.Error(),
})
// We pass a disposable user ID just to force an audit diff
// and generate a log for a failed login
aReq.New = database.APIKey{UserID: uuid.New()}
return
}

Expand All @@ -326,6 +343,9 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
Message: "Failed to unmarshal user info claims.",
Detail: err.Error(),
})
// We pass a disposable user ID just to force an audit diff
// and generate a log for a failed login
aReq.New = database.APIKey{UserID: uuid.New()}
return
}
for k, v := range userInfoClaims {
Expand All @@ -336,6 +356,9 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
Message: "Failed to obtain user information claims.",
Detail: "The OIDC provider returned no claims as part of the `id_token`. The attempt to fetch claims via the UserInfo endpoint failed: " + err.Error(),
})
// We pass a disposable user ID just to force an audit diff
// and generate a log for a failed login
aReq.New = database.APIKey{UserID: uuid.New()}
return
}

Expand All @@ -355,6 +378,9 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "No email found in OIDC payload!",
})
// We pass a disposable user ID just to force an audit diff
// and generate a log for a failed login
aReq.New = database.APIKey{UserID: uuid.New()}
return
}
emailRaw = username
Expand All @@ -364,6 +390,9 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: fmt.Sprintf("Email in OIDC payload isn't a string. Got: %t", emailRaw),
})
// We pass a disposable user ID just to force an audit diff
// and generate a log for a failed login
aReq.New = database.APIKey{UserID: uuid.New()}
return
}
verifiedRaw, ok := claims["email_verified"]
Expand All @@ -374,6 +403,9 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
Message: fmt.Sprintf("Verify the %q email address on your OIDC provider to authenticate!", email),
})
// We pass a disposable user ID just to force an audit diff
// and generate a log for a failed login
aReq.New = database.APIKey{UserID: uuid.New()}
return
}
api.Logger.Warn(ctx, "allowing unverified oidc email %q")
Expand Down Expand Up @@ -404,6 +436,9 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
Message: fmt.Sprintf("Your email %q is not in domains %q !", email, api.OIDCConfig.EmailDomain),
})
// We pass a disposable user ID just to force an audit diff
// and generate a log for a failed login
aReq.New = database.APIKey{UserID: uuid.New()}
return
}
}
Expand All @@ -413,7 +448,22 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
picture, _ = pictureRaw.(string)
}

cookie, _, err := api.oauthLogin(r, oauthLoginParams{
user, link, err := findLinkedUser(ctx, api.Database, oidcLinkedID(idToken), email)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to find linked user.",
Detail: err.Error(),
})
// We pass a disposable user ID just to force an audit diff
// and generate a log for a failed login
aReq.New = database.APIKey{UserID: uuid.New()}
return
}
aReq.UserID = user.ID

cookie, key, err := api.oauthLogin(r, oauthLoginParams{
User: user,
Link: link,
State: state,
LinkedID: oidcLinkedID(idToken),
LoginType: database.LoginTypeOIDC,
Expand All @@ -428,15 +478,22 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
Message: httpErr.msg,
Detail: httpErr.detail,
})
// We pass a disposable user ID just to force an audit diff
// and generate a log for a failed login
aReq.New = database.APIKey{UserID: uuid.New()}
return
}
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to process OAuth login.",
Detail: err.Error(),
})
// We pass a disposable user ID just to force an audit diff
// and generate a log for a failed login
aReq.New = database.APIKey{UserID: uuid.New()}
return
}
aReq.New = key

http.SetCookie(rw, cookie)

Expand Down Expand Up @@ -490,11 +547,6 @@ func (api *API) oauthLogin(r *http.Request, params oauthLoginParams) (*http.Cook

user = params.User
link = params.Link
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought it was safe to lift finding the user outside of the transaction and into the callers so we can access User.ID in those places. If, in the callers, we don't find a user, we bail early, never entering oauthLogin.

//
// user, link, err = findLinkedUser(ctx, tx, params.LinkedID, params.Email)
// if err != nil {
// return xerrors.Errorf("find linked user: %w", err)
// }

if user.ID == uuid.Nil && !params.AllowSignups {
return httpError{
Expand Down
38 changes: 37 additions & 1 deletion coderd/userauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,7 @@ func TestUserOIDC(t *testing.T) {
tc := tc
t.Run(tc.Name, func(t *testing.T) {
t.Parallel()
auditor := audit.NewMock()
conf := coderdtest.NewOIDCConfig(t, "")

config := conf.OIDCConfig(t, tc.UserInfoClaims)
Expand All @@ -726,9 +727,13 @@ func TestUserOIDC(t *testing.T) {
config.IgnoreEmailVerified = tc.IgnoreEmailVerified

client := coderdtest.New(t, &coderdtest.Options{
Auditor: auditor,
OIDCConfig: config,
})
numLogs := len(auditor.AuditLogs)

resp := oidcCallback(t, client, conf.EncodeClaims(t, tc.IDTokenClaims))
numLogs++ // add an audit log for login
assert.Equal(t, tc.StatusCode, resp.StatusCode)

ctx, _ := testutil.Context(t)
Expand All @@ -738,33 +743,43 @@ func TestUserOIDC(t *testing.T) {
user, err := client.User(ctx, "me")
require.NoError(t, err)
require.Equal(t, tc.Username, user.Username)

require.Len(t, auditor.AuditLogs, numLogs)
require.Equal(t, database.AuditActionLogin, auditor.AuditLogs[numLogs-1].Action)
}

if tc.AvatarURL != "" {
client.SetSessionToken(authCookieValue(resp.Cookies()))
user, err := client.User(ctx, "me")
require.NoError(t, err)
require.Equal(t, tc.AvatarURL, user.AvatarURL)

require.Len(t, auditor.AuditLogs, numLogs)
require.Equal(t, database.AuditActionLogin, auditor.AuditLogs[numLogs-1].Action)
}
})
}

t.Run("AlternateUsername", func(t *testing.T) {
t.Parallel()

auditor := audit.NewMock()
conf := coderdtest.NewOIDCConfig(t, "")

config := conf.OIDCConfig(t, nil)
config.AllowSignups = true

client := coderdtest.New(t, &coderdtest.Options{
Auditor: auditor,
OIDCConfig: config,
})
numLogs := len(auditor.AuditLogs)

code := conf.EncodeClaims(t, jwt.MapClaims{
"email": "jon@coder.com",
})
resp := oidcCallback(t, client, code)
numLogs++ // add an audit log for login

assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)

ctx, _ := testutil.Context(t)
Expand All @@ -781,12 +796,17 @@ func TestUserOIDC(t *testing.T) {
"sub": "diff",
})
resp = oidcCallback(t, client, code)
numLogs++ // add an audit log for login

assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)

client.SetSessionToken(authCookieValue(resp.Cookies()))
user, err = client.User(ctx, "me")
require.NoError(t, err)
require.True(t, strings.HasPrefix(user.Username, "jon-"), "username %q should have prefix %q", user.Username, "jon-")

require.Len(t, auditor.AuditLogs, numLogs)
require.Equal(t, database.AuditActionLogin, auditor.AuditLogs[numLogs-1].Action)
})

t.Run("Disabled", func(t *testing.T) {
Expand All @@ -798,23 +818,33 @@ func TestUserOIDC(t *testing.T) {

t.Run("NoIDToken", func(t *testing.T) {
t.Parallel()
auditor := audit.NewMock()
client := coderdtest.New(t, &coderdtest.Options{
Auditor: auditor,
OIDCConfig: &coderd.OIDCConfig{
OAuth2Config: &oauth2Config{},
},
})
numLogs := len(auditor.AuditLogs)

resp := oidcCallback(t, client, "asdf")
numLogs++ // add an audit log for login

require.Equal(t, http.StatusBadRequest, resp.StatusCode)
require.Len(t, auditor.AuditLogs, numLogs)
require.Equal(t, database.AuditActionLogin, auditor.AuditLogs[numLogs-1].Action)
})

t.Run("BadVerify", func(t *testing.T) {
t.Parallel()
auditor := audit.NewMock()
verifier := oidc.NewVerifier("", &oidc.StaticKeySet{
PublicKeys: []crypto.PublicKey{},
}, &oidc.Config{})
provider := &oidc.Provider{}

client := coderdtest.New(t, &coderdtest.Options{
Auditor: auditor,
OIDCConfig: &coderd.OIDCConfig{
OAuth2Config: &oauth2Config{
token: (&oauth2.Token{
Expand All @@ -827,8 +857,14 @@ func TestUserOIDC(t *testing.T) {
Verifier: verifier,
},
})
numLogs := len(auditor.AuditLogs)

resp := oidcCallback(t, client, "asdf")
numLogs++ // add an audit log for login

require.Equal(t, http.StatusBadRequest, resp.StatusCode)
require.Len(t, auditor.AuditLogs, numLogs)
require.Equal(t, database.AuditActionLogin, auditor.AuditLogs[numLogs-1].Action)
})
}

Expand Down