diff --git a/coderd/httpmw/csrf.go b/coderd/httpmw/csrf.go index e868019bac23b..8cd043146c082 100644 --- a/coderd/httpmw/csrf.go +++ b/coderd/httpmw/csrf.go @@ -22,7 +22,9 @@ func CSRF(secureCookie bool) func(next http.Handler) http.Handler { mw.SetBaseCookie(http.Cookie{Path: "/", HttpOnly: true, SameSite: http.SameSiteLaxMode, Secure: secureCookie}) mw.SetFailureHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { sessCookie, err := r.Cookie(codersdk.SessionTokenCookie) - if err == nil && r.Header.Get(codersdk.SessionTokenHeader) != sessCookie.Value { + if err == nil && + r.Header.Get(codersdk.SessionTokenHeader) != "" && + r.Header.Get(codersdk.SessionTokenHeader) != sessCookie.Value { // If a user is using header authentication and cookie auth, but the values // do not match, the cookie value takes priority. // At the very least, return a more helpful error to the user. diff --git a/coderd/httpmw/csrf_test.go b/coderd/httpmw/csrf_test.go index 12c6afe825f75..03f2babb2961a 100644 --- a/coderd/httpmw/csrf_test.go +++ b/coderd/httpmw/csrf_test.go @@ -3,6 +3,7 @@ package httpmw_test import ( "context" "net/http" + "net/http/httptest" "testing" "github.com/justinas/nosurf" @@ -69,3 +70,77 @@ func TestCSRFExemptList(t *testing.T) { }) } } + +// TestCSRFError verifies the error message returned to a user when CSRF +// checks fail. +// +//nolint:bodyclose // Using httptest.Recorders +func TestCSRFError(t *testing.T) { + t.Parallel() + + // Hard coded matching CSRF values + const csrfCookieValue = "JXm9hOUdZctWt0ZZGAy9xiS/gxMKYOThdxjjMnMUyn4=" + const csrfHeaderValue = "KNKvagCBEHZK7ihe2t7fj6VeJ0UyTDco1yVUJE8N06oNqxLu5Zx1vRxZbgfC0mJJgeGkVjgs08mgPbcWPBkZ1A==" + // Use a url with "/api" as the root, other routes bypass CSRF. + const urlPath = "https://coder.com/api/v2/hello" + + var handler http.Handler = http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + writer.WriteHeader(http.StatusOK) + }) + handler = httpmw.CSRF(false)(handler) + + // Not testing the error case, just providing the example of things working + // to base the failure tests off of. + t.Run("ValidCSRF", func(t *testing.T) { + t.Parallel() + + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, urlPath, nil) + require.NoError(t, err) + + req.AddCookie(&http.Cookie{Name: codersdk.SessionTokenCookie, Value: "session_token_value"}) + req.AddCookie(&http.Cookie{Name: nosurf.CookieName, Value: csrfCookieValue}) + req.Header.Add(nosurf.HeaderName, csrfHeaderValue) + + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + resp := rec.Result() + require.Equal(t, http.StatusOK, resp.StatusCode) + }) + + // The classic CSRF failure returns the generic error. + t.Run("MissingCSRFHeader", func(t *testing.T) { + t.Parallel() + + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, urlPath, nil) + require.NoError(t, err) + + req.AddCookie(&http.Cookie{Name: codersdk.SessionTokenCookie, Value: "session_token_value"}) + req.AddCookie(&http.Cookie{Name: nosurf.CookieName, Value: csrfCookieValue}) + + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + resp := rec.Result() + require.Equal(t, http.StatusBadRequest, resp.StatusCode) + require.Contains(t, rec.Body.String(), "Something is wrong with your CSRF token.") + }) + + // Include the CSRF cookie, but not the CSRF header value. + // Including the 'codersdk.SessionTokenHeader' will bypass CSRF only if + // it matches the cookie. If it does not, we expect a more helpful error. + t.Run("MismatchedHeaderAndCookie", func(t *testing.T) { + t.Parallel() + + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, urlPath, nil) + require.NoError(t, err) + + req.AddCookie(&http.Cookie{Name: codersdk.SessionTokenCookie, Value: "session_token_value"}) + req.AddCookie(&http.Cookie{Name: nosurf.CookieName, Value: csrfCookieValue}) + req.Header.Add(codersdk.SessionTokenHeader, "mismatched_value") + + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + resp := rec.Result() + require.Equal(t, http.StatusBadRequest, resp.StatusCode) + require.Contains(t, rec.Body.String(), "CSRF error encountered. Authentication via") + }) +}