Skip to content

Commit 8a3b820

Browse files
committed
chore: fixup tests
1 parent 4f6e753 commit 8a3b820

File tree

3 files changed

+19
-14
lines changed

3 files changed

+19
-14
lines changed

coderd/coderd.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -868,7 +868,7 @@ func New(options *Options) *API {
868868
r.Route(fmt.Sprintf("/%s/callback", externalAuthConfig.ID), func(r chi.Router) {
869869
r.Use(
870870
apiKeyMiddlewareRedirect,
871-
httpmw.ExtractOAuth2(externalAuthConfig, options.HTTPClient, nil),
871+
httpmw.ExtractOAuth2(externalAuthConfig, options.HTTPClient, codersdk.HTTPCookieConfig{}, nil),
872872
)
873873
r.Get("/", api.externalAuthCallback(externalAuthConfig))
874874
})
@@ -1123,14 +1123,14 @@ func New(options *Options) *API {
11231123
r.Get("/github/device", api.userOAuth2GithubDevice)
11241124
r.Route("/github", func(r chi.Router) {
11251125
r.Use(
1126-
httpmw.ExtractOAuth2(options.GithubOAuth2Config, options.HTTPClient, nil),
1126+
httpmw.ExtractOAuth2(options.GithubOAuth2Config, options.HTTPClient, codersdk.HTTPCookieConfig{}, nil),
11271127
)
11281128
r.Get("/callback", api.userOAuth2Github)
11291129
})
11301130
})
11311131
r.Route("/oidc/callback", func(r chi.Router) {
11321132
r.Use(
1133-
httpmw.ExtractOAuth2(options.OIDCConfig, options.HTTPClient, oidcAuthURLParams),
1133+
httpmw.ExtractOAuth2(options.OIDCConfig, options.HTTPClient, codersdk.HTTPCookieConfig{}, oidcAuthURLParams),
11341134
)
11351135
r.Get("/", api.userOIDC)
11361136
})

coderd/httpmw/csrf_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ func TestCSRFExemptList(t *testing.T) {
5353
},
5454
}
5555

56-
mw := httpmw.CSRF(false)
56+
mw := httpmw.CSRF(codersdk.HTTPCookieConfig{})
5757
csrfmw := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})).(*nosurf.CSRFHandler)
5858

5959
for _, c := range cases {
@@ -87,7 +87,7 @@ func TestCSRFError(t *testing.T) {
8787
var handler http.Handler = http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
8888
writer.WriteHeader(http.StatusOK)
8989
})
90-
handler = httpmw.CSRF(false)(handler)
90+
handler = httpmw.CSRF(codersdk.HTTPCookieConfig{})(handler)
9191

9292
// Not testing the error case, just providing the example of things working
9393
// to base the failure tests off of.

coderd/httpmw/oauth2_test.go

+14-9
Original file line numberDiff line numberDiff line change
@@ -50,15 +50,15 @@ func TestOAuth2(t *testing.T) {
5050
t.Parallel()
5151
req := httptest.NewRequest("GET", "/", nil)
5252
res := httptest.NewRecorder()
53-
httpmw.ExtractOAuth2(nil, nil, nil)(nil).ServeHTTP(res, req)
53+
httpmw.ExtractOAuth2(nil, nil, codersdk.HTTPCookieConfig{}, nil)(nil).ServeHTTP(res, req)
5454
require.Equal(t, http.StatusBadRequest, res.Result().StatusCode)
5555
})
5656
t.Run("RedirectWithoutCode", func(t *testing.T) {
5757
t.Parallel()
5858
req := httptest.NewRequest("GET", "/?redirect="+url.QueryEscape("/dashboard"), nil)
5959
res := httptest.NewRecorder()
6060
tp := newTestOAuth2Provider(t, oauth2.AccessTypeOffline)
61-
httpmw.ExtractOAuth2(tp, nil, nil)(nil).ServeHTTP(res, req)
61+
httpmw.ExtractOAuth2(tp, nil, codersdk.HTTPCookieConfig{}, nil)(nil).ServeHTTP(res, req)
6262
location := res.Header().Get("Location")
6363
if !assert.NotEmpty(t, location) {
6464
return
@@ -82,7 +82,7 @@ func TestOAuth2(t *testing.T) {
8282
req := httptest.NewRequest("GET", "/?redirect="+url.QueryEscape(uri.String()), nil)
8383
res := httptest.NewRecorder()
8484
tp := newTestOAuth2Provider(t, oauth2.AccessTypeOffline)
85-
httpmw.ExtractOAuth2(tp, nil, nil)(nil).ServeHTTP(res, req)
85+
httpmw.ExtractOAuth2(tp, nil, codersdk.HTTPCookieConfig{}, nil)(nil).ServeHTTP(res, req)
8686
location := res.Header().Get("Location")
8787
if !assert.NotEmpty(t, location) {
8888
return
@@ -97,15 +97,15 @@ func TestOAuth2(t *testing.T) {
9797
req := httptest.NewRequest("GET", "/?code=something", nil)
9898
res := httptest.NewRecorder()
9999
tp := newTestOAuth2Provider(t, oauth2.AccessTypeOffline)
100-
httpmw.ExtractOAuth2(tp, nil, nil)(nil).ServeHTTP(res, req)
100+
httpmw.ExtractOAuth2(tp, nil, codersdk.HTTPCookieConfig{}, nil)(nil).ServeHTTP(res, req)
101101
require.Equal(t, http.StatusBadRequest, res.Result().StatusCode)
102102
})
103103
t.Run("NoStateCookie", func(t *testing.T) {
104104
t.Parallel()
105105
req := httptest.NewRequest("GET", "/?code=something&state=test", nil)
106106
res := httptest.NewRecorder()
107107
tp := newTestOAuth2Provider(t, oauth2.AccessTypeOffline)
108-
httpmw.ExtractOAuth2(tp, nil, nil)(nil).ServeHTTP(res, req)
108+
httpmw.ExtractOAuth2(tp, nil, codersdk.HTTPCookieConfig{}, nil)(nil).ServeHTTP(res, req)
109109
require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode)
110110
})
111111
t.Run("MismatchedState", func(t *testing.T) {
@@ -117,7 +117,7 @@ func TestOAuth2(t *testing.T) {
117117
})
118118
res := httptest.NewRecorder()
119119
tp := newTestOAuth2Provider(t, oauth2.AccessTypeOffline)
120-
httpmw.ExtractOAuth2(tp, nil, nil)(nil).ServeHTTP(res, req)
120+
httpmw.ExtractOAuth2(tp, nil, codersdk.HTTPCookieConfig{}, nil)(nil).ServeHTTP(res, req)
121121
require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode)
122122
})
123123
t.Run("ExchangeCodeAndState", func(t *testing.T) {
@@ -133,7 +133,7 @@ func TestOAuth2(t *testing.T) {
133133
})
134134
res := httptest.NewRecorder()
135135
tp := newTestOAuth2Provider(t, oauth2.AccessTypeOffline)
136-
httpmw.ExtractOAuth2(tp, nil, nil)(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
136+
httpmw.ExtractOAuth2(tp, nil, codersdk.HTTPCookieConfig{}, nil)(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
137137
state := httpmw.OAuth2(r)
138138
require.Equal(t, "/dashboard", state.Redirect)
139139
})).ServeHTTP(res, req)
@@ -144,7 +144,7 @@ func TestOAuth2(t *testing.T) {
144144
res := httptest.NewRecorder()
145145
tp := newTestOAuth2Provider(t, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("foo", "bar"))
146146
authOpts := map[string]string{"foo": "bar"}
147-
httpmw.ExtractOAuth2(tp, nil, authOpts)(nil).ServeHTTP(res, req)
147+
httpmw.ExtractOAuth2(tp, nil, codersdk.HTTPCookieConfig{}, authOpts)(nil).ServeHTTP(res, req)
148148
location := res.Header().Get("Location")
149149
// Ideally we would also assert that the location contains the query params
150150
// we set in the auth URL but this would essentially be testing the oauth2 package.
@@ -157,12 +157,17 @@ func TestOAuth2(t *testing.T) {
157157
req := httptest.NewRequest("GET", "/?oidc_merge_state="+customState+"&redirect="+url.QueryEscape("/dashboard"), nil)
158158
res := httptest.NewRecorder()
159159
tp := newTestOAuth2Provider(t, oauth2.AccessTypeOffline)
160-
httpmw.ExtractOAuth2(tp, nil, nil)(nil).ServeHTTP(res, req)
160+
httpmw.ExtractOAuth2(tp, nil, codersdk.HTTPCookieConfig{
161+
Secure: true,
162+
SameSite: "none",
163+
}, nil)(nil).ServeHTTP(res, req)
161164

162165
found := false
163166
for _, cookie := range res.Result().Cookies() {
164167
if cookie.Name == codersdk.OAuth2StateCookie {
165168
require.Equal(t, cookie.Value, customState, "expected state")
169+
require.Equal(t, true, cookie.Secure, "cookie set to secure")
170+
require.Equal(t, http.SameSiteNoneMode, cookie.SameSite, "same-site = none")
166171
found = true
167172
}
168173
}

0 commit comments

Comments
 (0)