Skip to content

fix: limit OAuth redirects to local paths #14585

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion coderd/coderdtest/coderdtest.go
Original file line number Diff line number Diff line change
Expand Up @@ -1140,7 +1140,7 @@ func MustWorkspace(t testing.TB, client *codersdk.Client, workspaceID uuid.UUID)

// RequestExternalAuthCallback makes a request with the proper OAuth2 state cookie
// to the external auth callback endpoint.
func RequestExternalAuthCallback(t testing.TB, providerID string, client *codersdk.Client) *http.Response {
func RequestExternalAuthCallback(t testing.TB, providerID string, client *codersdk.Client, opts ...func(*http.Request)) *http.Response {
client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
Expand All @@ -1157,6 +1157,9 @@ func RequestExternalAuthCallback(t testing.TB, providerID string, client *coders
Name: codersdk.SessionTokenCookie,
Value: client.SessionToken(),
})
for _, opt := range opts {
opt(req)
}
res, err := client.HTTPClient.Do(req)
require.NoError(t, err)
t.Cleanup(func() {
Expand Down
17 changes: 13 additions & 4 deletions coderd/coderdtest/oidctest/idp.go
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,6 @@ func (f *FakeIDP) AttemptLogin(t testing.TB, client *codersdk.Client, idTokenCla
// This is a niche case, but it is needed for testing ConvertLoginType.
func (f *FakeIDP) LoginWithClient(t testing.TB, client *codersdk.Client, idTokenClaims jwt.MapClaims, opts ...func(r *http.Request)) (*codersdk.Client, *http.Response) {
t.Helper()

path := "/api/v2/users/oidc/callback"
if f.callbackPath != "" {
path = f.callbackPath
Expand All @@ -489,13 +488,23 @@ func (f *FakeIDP) LoginWithClient(t testing.TB, client *codersdk.Client, idToken
f.SetRedirect(t, coderOauthURL.String())

cli := f.HTTPClient(client.HTTPClient)
cli.CheckRedirect = func(req *http.Request, via []*http.Request) error {
redirectFn := cli.CheckRedirect
checkRedirect := func(req *http.Request, via []*http.Request) error {
// Store the idTokenClaims to the specific state request. This ties
// the claims 1:1 with a given authentication flow.
state := req.URL.Query().Get("state")
f.stateToIDTokenClaims.Store(state, idTokenClaims)
if state := req.URL.Query().Get("state"); state != "" {
f.stateToIDTokenClaims.Store(state, idTokenClaims)
return nil
}
// This is mainly intended to prevent the _last_ redirect
// The one involving the state param is a core part of the
// OIDC flow and shouldn't be redirected.
if redirectFn != nil {
return redirectFn(req, via)
}
return nil
}
cli.CheckRedirect = checkRedirect

req, err := http.NewRequestWithContext(context.Background(), "GET", coderOauthURL.String(), nil)
require.NoError(t, err)
Expand Down
11 changes: 11 additions & 0 deletions coderd/externalauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"net/http"
"net/url"

"github.com/sqlc-dev/pqtype"
"golang.org/x/sync/errgroup"
Expand Down Expand Up @@ -306,6 +307,7 @@ func (api *API) externalAuthCallback(externalAuthConfig *externalauth.Config) ht
// FE know not to enter the authentication loop again, and instead display an error.
redirect = fmt.Sprintf("/external-auth/%s?redirected=true", externalAuthConfig.ID)
}
redirect = uriFromURL(redirect)
http.Redirect(rw, r, redirect, http.StatusTemporaryRedirect)
}
}
Expand Down Expand Up @@ -401,3 +403,12 @@ func ExternalAuthConfig(cfg *externalauth.Config) codersdk.ExternalAuthLinkProvi
AllowValidate: cfg.ValidateURL != "",
}
}

func uriFromURL(u string) string {
uri, err := url.Parse(u)
if err != nil {
return "/"
}

return uri.RequestURI()
}
33 changes: 31 additions & 2 deletions coderd/externalauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,12 +207,12 @@ func TestExternalAuthManagement(t *testing.T) {
const gitlabID = "fake-gitlab"

githubCalled := false
githubApp := oidctest.NewFakeIDP(t, oidctest.WithServing(), oidctest.WithRefresh(func(email string) error {
githubApp := oidctest.NewFakeIDP(t, oidctest.WithServing(), oidctest.WithRefresh(func(_ string) error {
githubCalled = true
return nil
}))
gitlabCalled := false
gitlab := oidctest.NewFakeIDP(t, oidctest.WithServing(), oidctest.WithRefresh(func(email string) error {
gitlab := oidctest.NewFakeIDP(t, oidctest.WithServing(), oidctest.WithRefresh(func(_ string) error {
gitlabCalled = true
return nil
}))
Expand Down Expand Up @@ -508,6 +508,35 @@ func TestExternalAuthCallback(t *testing.T) {
resp = coderdtest.RequestExternalAuthCallback(t, "github", client)
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
})

t.Run("CustomRedirect", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{
IncludeProvisionerDaemon: true,
ExternalAuthConfigs: []*externalauth.Config{{
InstrumentedOAuth2Config: &testutil.OAuth2Config{},
ID: "github",
Regex: regexp.MustCompile(`github\.com`),
Type: codersdk.EnhancedExternalAuthProviderGitHub.String(),
}},
})
maliciousHost := "https://malicious.com"
expectedURI := "/some/path?param=1"
_ = coderdtest.CreateFirstUser(t, client)
resp := coderdtest.RequestExternalAuthCallback(t, "github", client, func(req *http.Request) {
req.AddCookie(&http.Cookie{
Name: codersdk.OAuth2RedirectCookie,
Value: maliciousHost + expectedURI,
})
})
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
location, err := resp.Location()
require.NoError(t, err)
require.Equal(t, expectedURI, location.RequestURI())
require.Equal(t, client.URL.Host, location.Host)
require.NotContains(t, location.String(), maliciousHost)
})

t.Run("ValidateURL", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
Expand Down
22 changes: 20 additions & 2 deletions coderd/httpmw/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net/http"
"net/url"
"reflect"

"github.com/go-chi/chi/v5"
Expand Down Expand Up @@ -85,6 +86,15 @@ func ExtractOAuth2(config promoauth.OAuth2Config, client *http.Client, authURLOp

code := r.URL.Query().Get("code")
state := r.URL.Query().Get("state")
redirect := r.URL.Query().Get("redirect")
if redirect != "" {
// We want to ensure that we're only ever redirecting to the application.
// We could be more strict here and check to see if the host matches
// the host of the AccessURL but ultimately as long as our redirect
// url omits a host we're ensuring that we're routing to a path
// local to the application.
redirect = uriFromURL(redirect)
}

if code == "" {
// If the code isn't provided, we'll redirect!
Expand Down Expand Up @@ -119,7 +129,7 @@ func ExtractOAuth2(config promoauth.OAuth2Config, client *http.Client, authURLOp
// an old redirect could apply!
http.SetCookie(rw, &http.Cookie{
Name: codersdk.OAuth2RedirectCookie,
Value: r.URL.Query().Get("redirect"),
Value: redirect,
Path: "/",
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
Expand Down Expand Up @@ -150,7 +160,6 @@ func ExtractOAuth2(config promoauth.OAuth2Config, client *http.Client, authURLOp
return
}

var redirect string
stateRedirect, err := r.Cookie(codersdk.OAuth2RedirectCookie)
if err == nil {
redirect = stateRedirect.Value
Expand Down Expand Up @@ -302,3 +311,12 @@ func ExtractOAuth2ProviderAppSecret(db database.Store) func(http.Handler) http.H
})
}
}

func uriFromURL(u string) string {
uri, err := url.Parse(u)
if err != nil {
return "/"
}

return uri.RequestURI()
}
27 changes: 26 additions & 1 deletion coderd/httpmw/oauth2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,31 @@ func TestOAuth2(t *testing.T) {
cookie := res.Result().Cookies()[1]
require.Equal(t, "/dashboard", cookie.Value)
})
t.Run("OnlyPathBaseRedirect", func(t *testing.T) {
t.Parallel()
// Construct a URI to a potentially malicious
// site and assert that we omit the host
// when redirecting the request.
uri := &url.URL{
Scheme: "https",
Host: "some.bad.domain.com",
Path: "/sadf/asdfasdf",
RawQuery: "foo=hello&bar=world",
}
expectedValue := uri.Path + "?" + uri.RawQuery
req := httptest.NewRequest("GET", "/?redirect="+url.QueryEscape(uri.String()), nil)
res := httptest.NewRecorder()
tp := newTestOAuth2Provider(t, oauth2.AccessTypeOffline)
httpmw.ExtractOAuth2(tp, nil, nil)(nil).ServeHTTP(res, req)
location := res.Header().Get("Location")
if !assert.NotEmpty(t, location) {
return
}
require.Len(t, res.Result().Cookies(), 2)
cookie := res.Result().Cookies()[1]
require.Equal(t, expectedValue, cookie.Value)
})

t.Run("NoState", func(t *testing.T) {
t.Parallel()
req := httptest.NewRequest("GET", "/?code=something", nil)
Expand Down Expand Up @@ -108,7 +133,7 @@ func TestOAuth2(t *testing.T) {
})
res := httptest.NewRecorder()
tp := newTestOAuth2Provider(t, oauth2.AccessTypeOffline)
httpmw.ExtractOAuth2(tp, nil, nil)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
httpmw.ExtractOAuth2(tp, nil, nil)(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
state := httpmw.OAuth2(r)
require.Equal(t, "/dashboard", state.Redirect)
})).ServeHTTP(res, req)
Expand Down
12 changes: 5 additions & 7 deletions coderd/userauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -707,9 +707,7 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) {
http.SetCookie(rw, cookie)
}

if redirect == "" {
redirect = "/"
}
redirect = uriFromURL(redirect)
http.Redirect(rw, r, redirect, http.StatusTemporaryRedirect)
}

Expand Down Expand Up @@ -1085,9 +1083,9 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
}

redirect := state.Redirect
if redirect == "" {
redirect = "/"
}
// Strip the host if it exists on the URL to prevent
// any nefarious redirects.
redirect = uriFromURL(redirect)
http.Redirect(rw, r, redirect, http.StatusTemporaryRedirect)
}

Expand Down Expand Up @@ -1687,7 +1685,7 @@ func (api *API) convertUserToOauth(ctx context.Context, r *http.Request, db data
}
}
var claims OAuthConvertStateClaims
token, err := jwt.ParseWithClaims(jwtCookie.Value, &claims, func(token *jwt.Token) (interface{}, error) {
token, err := jwt.ParseWithClaims(jwtCookie.Value, &claims, func(_ *jwt.Token) (interface{}, error) {
return api.OAuthSigningKey[:], nil
})
if xerrors.Is(err, jwt.ErrSignatureInvalid) || !token.Valid {
Expand Down
76 changes: 73 additions & 3 deletions coderd/userauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,11 +354,25 @@ func TestUserOAuth2Github(t *testing.T) {
})
numLogs := len(auditor.AuditLogs())

resp := oauth2Callback(t, client)
// Validate that attempting to redirect away from the
// site does not work.
maliciousHost := "https://malicious.com"
expectedPath := "/my/path"
resp := oauth2Callback(t, client, func(req *http.Request) {
// Add the cookie to bypass the parsing in httpmw/oauth2.go
req.AddCookie(&http.Cookie{
Name: codersdk.OAuth2RedirectCookie,
Value: maliciousHost + expectedPath,
})
})
numLogs++ // add an audit log for login

require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)

redirect, err := resp.Location()
require.NoError(t, err)
require.Equal(t, expectedPath, redirect.Path)
require.Equal(t, client.URL.Host, redirect.Host)
require.NotContains(t, redirect.String(), maliciousHost)
client.SetSessionToken(authCookieValue(resp.Cookies()))
user, err := client.User(context.Background(), "me")
require.NoError(t, err)
Expand Down Expand Up @@ -1436,6 +1450,59 @@ func TestUserOIDC(t *testing.T) {
_, resp := fake.AttemptLogin(t, client, jwt.MapClaims{})
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
})

t.Run("StripRedirectHost", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)

expectedRedirect := "/foo/bar?hello=world&bar=baz"
redirectURL := "https://malicious" + expectedRedirect

callbackPath := fmt.Sprintf("/api/v2/users/oidc/callback?redirect=%s", url.QueryEscape(redirectURL))
fake := oidctest.NewFakeIDP(t,
oidctest.WithRefresh(func(_ string) error {
return xerrors.New("refreshing token should never occur")
}),
oidctest.WithServing(),
oidctest.WithCallbackPath(callbackPath),
)
cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) {
cfg.AllowSignups = true
})

client := coderdtest.New(t, &coderdtest.Options{
OIDCConfig: cfg,
})

client.HTTPClient.Transport = http.DefaultTransport

client.HTTPClient.CheckRedirect = func(*http.Request, []*http.Request) error {
return http.ErrUseLastResponse
}

claims := jwt.MapClaims{
"email": "user@example.com",
"email_verified": true,
}

// Perform the login
loginClient, resp := fake.LoginWithClient(t, client, claims)
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)

// Get the location from the response
location, err := resp.Location()
require.NoError(t, err)

// Check that the redirect URL has been stripped of its malicious host
require.Equal(t, expectedRedirect, location.RequestURI())
require.Equal(t, client.URL.Host, location.Host)
require.NotContains(t, location.String(), "malicious")

// Verify the user was created
user, err := loginClient.User(ctx, "me")
require.NoError(t, err)
require.Equal(t, "user@example.com", user.Email)
})
}

func TestUserLogout(t *testing.T) {
Expand Down Expand Up @@ -1587,7 +1654,7 @@ func TestOIDCSkipIssuer(t *testing.T) {
require.Equal(t, found.LoginType, codersdk.LoginTypeOIDC)
}

func oauth2Callback(t *testing.T, client *codersdk.Client) *http.Response {
func oauth2Callback(t *testing.T, client *codersdk.Client, opts ...func(*http.Request)) *http.Response {
client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
Expand All @@ -1597,6 +1664,9 @@ func oauth2Callback(t *testing.T, client *codersdk.Client) *http.Response {
require.NoError(t, err)
req, err := http.NewRequestWithContext(context.Background(), "GET", oauthURL.String(), nil)
require.NoError(t, err)
for _, opt := range opts {
opt(req)
}
req.AddCookie(&http.Cookie{
Name: codersdk.OAuth2StateCookie,
Value: state,
Expand Down
Loading