@@ -50,15 +50,15 @@ func TestOAuth2(t *testing.T) {
50
50
t .Parallel ()
51
51
req := httptest .NewRequest ("GET" , "/" , nil )
52
52
res := httptest .NewRecorder ()
53
- httpmw .ExtractOAuth2 (nil , nil , nil )(nil ).ServeHTTP (res , req )
53
+ httpmw .ExtractOAuth2 (nil , nil , codersdk. HTTPCookieConfig {}, nil )(nil ).ServeHTTP (res , req )
54
54
require .Equal (t , http .StatusBadRequest , res .Result ().StatusCode )
55
55
})
56
56
t .Run ("RedirectWithoutCode" , func (t * testing.T ) {
57
57
t .Parallel ()
58
58
req := httptest .NewRequest ("GET" , "/?redirect=" + url .QueryEscape ("/dashboard" ), nil )
59
59
res := httptest .NewRecorder ()
60
60
tp := newTestOAuth2Provider (t , oauth2 .AccessTypeOffline )
61
- httpmw .ExtractOAuth2 (tp , nil , nil )(nil ).ServeHTTP (res , req )
61
+ httpmw .ExtractOAuth2 (tp , nil , codersdk. HTTPCookieConfig {}, nil )(nil ).ServeHTTP (res , req )
62
62
location := res .Header ().Get ("Location" )
63
63
if ! assert .NotEmpty (t , location ) {
64
64
return
@@ -82,7 +82,7 @@ func TestOAuth2(t *testing.T) {
82
82
req := httptest .NewRequest ("GET" , "/?redirect=" + url .QueryEscape (uri .String ()), nil )
83
83
res := httptest .NewRecorder ()
84
84
tp := newTestOAuth2Provider (t , oauth2 .AccessTypeOffline )
85
- httpmw .ExtractOAuth2 (tp , nil , nil )(nil ).ServeHTTP (res , req )
85
+ httpmw .ExtractOAuth2 (tp , nil , codersdk. HTTPCookieConfig {}, nil )(nil ).ServeHTTP (res , req )
86
86
location := res .Header ().Get ("Location" )
87
87
if ! assert .NotEmpty (t , location ) {
88
88
return
@@ -97,15 +97,15 @@ func TestOAuth2(t *testing.T) {
97
97
req := httptest .NewRequest ("GET" , "/?code=something" , nil )
98
98
res := httptest .NewRecorder ()
99
99
tp := newTestOAuth2Provider (t , oauth2 .AccessTypeOffline )
100
- httpmw .ExtractOAuth2 (tp , nil , nil )(nil ).ServeHTTP (res , req )
100
+ httpmw .ExtractOAuth2 (tp , nil , codersdk. HTTPCookieConfig {}, nil )(nil ).ServeHTTP (res , req )
101
101
require .Equal (t , http .StatusBadRequest , res .Result ().StatusCode )
102
102
})
103
103
t .Run ("NoStateCookie" , func (t * testing.T ) {
104
104
t .Parallel ()
105
105
req := httptest .NewRequest ("GET" , "/?code=something&state=test" , nil )
106
106
res := httptest .NewRecorder ()
107
107
tp := newTestOAuth2Provider (t , oauth2 .AccessTypeOffline )
108
- httpmw .ExtractOAuth2 (tp , nil , nil )(nil ).ServeHTTP (res , req )
108
+ httpmw .ExtractOAuth2 (tp , nil , codersdk. HTTPCookieConfig {}, nil )(nil ).ServeHTTP (res , req )
109
109
require .Equal (t , http .StatusUnauthorized , res .Result ().StatusCode )
110
110
})
111
111
t .Run ("MismatchedState" , func (t * testing.T ) {
@@ -117,7 +117,7 @@ func TestOAuth2(t *testing.T) {
117
117
})
118
118
res := httptest .NewRecorder ()
119
119
tp := newTestOAuth2Provider (t , oauth2 .AccessTypeOffline )
120
- httpmw .ExtractOAuth2 (tp , nil , nil )(nil ).ServeHTTP (res , req )
120
+ httpmw .ExtractOAuth2 (tp , nil , codersdk. HTTPCookieConfig {}, nil )(nil ).ServeHTTP (res , req )
121
121
require .Equal (t , http .StatusUnauthorized , res .Result ().StatusCode )
122
122
})
123
123
t .Run ("ExchangeCodeAndState" , func (t * testing.T ) {
@@ -133,7 +133,7 @@ func TestOAuth2(t *testing.T) {
133
133
})
134
134
res := httptest .NewRecorder ()
135
135
tp := newTestOAuth2Provider (t , oauth2 .AccessTypeOffline )
136
- httpmw .ExtractOAuth2 (tp , nil , nil )(http .HandlerFunc (func (_ http.ResponseWriter , r * http.Request ) {
136
+ httpmw .ExtractOAuth2 (tp , nil , codersdk. HTTPCookieConfig {}, nil )(http .HandlerFunc (func (_ http.ResponseWriter , r * http.Request ) {
137
137
state := httpmw .OAuth2 (r )
138
138
require .Equal (t , "/dashboard" , state .Redirect )
139
139
})).ServeHTTP (res , req )
@@ -144,7 +144,7 @@ func TestOAuth2(t *testing.T) {
144
144
res := httptest .NewRecorder ()
145
145
tp := newTestOAuth2Provider (t , oauth2 .AccessTypeOffline , oauth2 .SetAuthURLParam ("foo" , "bar" ))
146
146
authOpts := map [string ]string {"foo" : "bar" }
147
- httpmw .ExtractOAuth2 (tp , nil , authOpts )(nil ).ServeHTTP (res , req )
147
+ httpmw .ExtractOAuth2 (tp , nil , codersdk. HTTPCookieConfig {}, authOpts )(nil ).ServeHTTP (res , req )
148
148
location := res .Header ().Get ("Location" )
149
149
// Ideally we would also assert that the location contains the query params
150
150
// we set in the auth URL but this would essentially be testing the oauth2 package.
@@ -157,12 +157,17 @@ func TestOAuth2(t *testing.T) {
157
157
req := httptest .NewRequest ("GET" , "/?oidc_merge_state=" + customState + "&redirect=" + url .QueryEscape ("/dashboard" ), nil )
158
158
res := httptest .NewRecorder ()
159
159
tp := newTestOAuth2Provider (t , oauth2 .AccessTypeOffline )
160
- httpmw .ExtractOAuth2 (tp , nil , nil )(nil ).ServeHTTP (res , req )
160
+ httpmw .ExtractOAuth2 (tp , nil , codersdk.HTTPCookieConfig {
161
+ Secure : true ,
162
+ SameSite : "none" ,
163
+ }, nil )(nil ).ServeHTTP (res , req )
161
164
162
165
found := false
163
166
for _ , cookie := range res .Result ().Cookies () {
164
167
if cookie .Name == codersdk .OAuth2StateCookie {
165
168
require .Equal (t , cookie .Value , customState , "expected state" )
169
+ require .Equal (t , true , cookie .Secure , "cookie set to secure" )
170
+ require .Equal (t , http .SameSiteNoneMode , cookie .SameSite , "same-site = none" )
166
171
found = true
167
172
}
168
173
}
0 commit comments