@@ -5,16 +5,19 @@ import (
5
5
"encoding/base64"
6
6
"io"
7
7
"net/http"
8
+ "net/http/httptest"
8
9
"net/url"
10
+ "strings"
9
11
"testing"
12
+ "time"
10
13
11
14
"github.com/coreos/go-oidc/v3/oidc"
12
15
"github.com/golang-jwt/jwt"
13
16
"github.com/stretchr/testify/assert"
14
17
"github.com/stretchr/testify/require"
15
18
"golang.org/x/oauth2"
19
+ "golang.org/x/xerrors"
16
20
17
- "github.com/coder/coder/coderd/httpmw"
18
21
"github.com/coder/coder/coderd/oauthpki"
19
22
"github.com/coder/coder/testutil"
20
23
)
@@ -77,25 +80,9 @@ B1B7CpkMU55hPP+7nsofCszNrMDXT8Z5w2a3zLKM
77
80
`
78
81
)
79
82
80
- type exchangeAssert struct {
81
- httpmw.OAuth2Config
82
- assert func (ctx context.Context , code string , opts ... oauth2.AuthCodeOption )
83
- }
84
-
85
- func (a * exchangeAssert ) Exchange (ctx context.Context , code string , opts ... oauth2.AuthCodeOption ) (* oauth2.Token , error ) {
86
- a .assert (ctx , code , opts ... )
87
- return a .OAuth2Config .Exchange (ctx , code , opts ... )
88
- }
89
-
90
- type fakeRoundTripper struct {
91
- roundTrip func (req * http.Request ) (* http.Response , error )
92
- }
93
-
94
- func (f fakeRoundTripper ) RoundTrip (req * http.Request ) (* http.Response , error ) {
95
- return f .roundTrip (req )
96
- }
97
-
98
83
// TestAzureADPKIOIDC ensures we do not break Azure AD compatibility.
84
+ // It runs an oauth2.Exchange method and hijacks the request to only check the
85
+ // request side of the transaction.
99
86
func TestAzureADPKIOIDC (t * testing.T ) {
100
87
oauthCfg := & oauth2.Config {
101
88
ClientID : "random-client-id" ,
@@ -123,28 +110,7 @@ func TestAzureADPKIOIDC(t *testing.T) {
123
110
// This is the easiest way to hijack the request and check
124
111
// the params. The oauth2 package uses unexported types and
125
112
// options, so we need to view the actual request created.
126
- data , err := io .ReadAll (req .Body )
127
- if ! assert .NoError (t , err , "failed to read request body" ) {
128
- return resp , nil
129
- }
130
- vals , err := url .ParseQuery (string (data ))
131
- if ! assert .NoError (t , err , "failed to parse values" ) {
132
- return resp , nil
133
- }
134
- assert .Equal (t , "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" , vals .Get ("client_assertion_type" ))
135
-
136
- jwtToken := vals .Get ("client_assertion" )
137
-
138
- // No need to actually verify the jwt is signed right.
139
- parsedToken , _ , err := (& jwt.Parser {}).ParseUnverified (jwtToken , jwt.MapClaims {})
140
- if ! assert .NoError (t , err , "failed to parse jwt token" ) {
141
- return resp , nil
142
- }
143
-
144
- // Azure requirements
145
- assert .NotEmpty (t , parsedToken .Header ["x5t" ], "hashed cert missing" )
146
- assert .Equal (t , "RS256" , parsedToken .Header ["alg" ], "azure only accepts RS256" )
147
- assert .Equal (t , "JWT" , parsedToken .Header ["typ" ], "azure only accepts JWT" )
113
+ assertJWTAuth (t , req )
148
114
return resp , nil
149
115
},
150
116
},
@@ -153,3 +119,170 @@ func TestAzureADPKIOIDC(t *testing.T) {
153
119
// We hijack the request and return an error intentionally
154
120
require .Error (t , err , "error expected" )
155
121
}
122
+
123
+ // TestSavedAzureADPKIOIDC was created by capturing actual responses from an Azure
124
+ // AD instance and saving them to replay, removing some details.
125
+ // The reason this is done is that this is the only way to assert values
126
+ // passed to the oauth2 provider via http requests.
127
+ // It is not feasible to run against an actual Azure AD instance, so this attempts
128
+ // to prevent some regressions by running a full "e2e" oauth and asserting some
129
+ // of the request values.
130
+ func TestSavedAzureADPKIOIDC (t * testing.T ) {
131
+ var (
132
+ stateString = "random-state"
133
+ oauth2Code = base64 .StdEncoding .EncodeToString ([]byte ("random-code" ))
134
+ )
135
+
136
+ // Real oauth config. We will hijack all http requests so some of these values
137
+ // are fake.
138
+ cfg := & oauth2.Config {
139
+ ClientID : "fake_app" ,
140
+ ClientSecret : "" ,
141
+ Endpoint : oauth2.Endpoint {
142
+ AuthURL : "https://login.microsoftonline.com/fake_app/oauth2/v2.0/authorize" ,
143
+ TokenURL : "https://login.microsoftonline.com/fake_app/oauth2/v2.0/token" ,
144
+ AuthStyle : 0 ,
145
+ },
146
+ RedirectURL : "http://localhost/api/v2/users/oidc/callback" ,
147
+ Scopes : []string {"openid" , "profile" , "email" , "offline_access" },
148
+ }
149
+
150
+ initialExchange := false
151
+ tokenRefreshed := false
152
+
153
+ // Create the oauthpki config
154
+ pki , err := oauthpki .NewOauth2PKIConfig (oauthpki.ConfigParams {
155
+ ClientID : cfg .ClientID ,
156
+ TokenURL : cfg .Endpoint .TokenURL ,
157
+ Scopes : []string {"openid" , "email" , "profile" , "offline_access" },
158
+ PemEncodedKey : []byte (testClientKey ),
159
+ PemEncodedCert : []byte (testClientCert ),
160
+ Config : cfg ,
161
+ })
162
+ require .NoError (t , err )
163
+
164
+ var fakeCtx context.Context
165
+ fakeClient := & http.Client {
166
+ Transport : fakeRoundTripper {
167
+ roundTrip : func (req * http.Request ) (* http.Response , error ) {
168
+ if strings .Contains (req .URL .String (), "authorize" ) {
169
+ // This is the user hitting the browser endpoint to begin the OIDC
170
+ // auth flow.
171
+
172
+ // Authorize should redirect the user back to the app after authentication on
173
+ // the IDP.
174
+ resp := httptest .NewRecorder ()
175
+ v := url.Values {
176
+ "code" : {oauth2Code },
177
+ "state" : {stateString },
178
+ "session_state" : {"a18cf797-1e2b-4bc3-baf9-66b41a4997cf" },
179
+ }
180
+
181
+ // This url doesn't really matter since the fake client will hiject this actual request.
182
+ http .Redirect (resp , req , "http://localhost:3000/api/v2/users/oidc/callback?" + v .Encode (), http .StatusTemporaryRedirect )
183
+ return resp .Result (), nil
184
+ }
185
+ if strings .Contains (req .URL .String (), "v2.0/token" ) {
186
+ vals := assertJWTAuth (t , req )
187
+ switch vals .Get ("grant_type" ) {
188
+ case "authorization_code" :
189
+ // Initial token
190
+ initialExchange = true
191
+ assert .Equal (t , oauth2Code , vals .Get ("code" ), "initial exchange code mismatch" )
192
+ case "refresh_token" :
193
+ // refreshed token
194
+ tokenRefreshed = true
195
+ assert .Equal (t , "<refresh_token_JWT>" , vals .Get ("refresh_token" ), "refresh token required" )
196
+ }
197
+
198
+ resp := httptest .NewRecorder ()
199
+ // Taken from an actual response
200
+ // Just always return a token no matter what.
201
+ resp .Header ().Set ("Content-Type" , "application/json" )
202
+ _ , _ = resp .Write ([]byte (`{
203
+ "token_type":"Bearer",
204
+ "scope":"email openid profile AccessReview.ReadWrite.Membership Group.Read.All Group.ReadWrite.All User.Read",
205
+ "expires_in":4009,
206
+ "ext_expires_in":4009,
207
+ "access_token":"<access_token_JWT>",
208
+ "refresh_token":"<refresh_token_JWT>",
209
+ "id_token":"eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6Ii1LSTNROW5OUjdiUm9meG1lWm9YcWJIWkdldyJ9.eyJhdWQiOiIxZjAxODMyYS1mZWViLTQyZGMtODFkOS01ZjBhYjZhMDQxZTAiLCJpc3MiOiJodHRwczovL2xvZ2luLm1pY3Jvc29mdG9ubGluZS5jb20vMTEwZjBjMGYtY2Q3Ni00NzE3LWE2ZjgtNGVlYTNkMGY4MTA5L3YyLjAiLCJpYXQiOjE2OTE3OTI2MzQsIm5iZiI6MTY5MTc5MjYzNCwiZXhwIjoxNjkxNzk2NTM0LCJhaW8iOiJBWVFBZS84VUFBQUE1eEtqMmVTdWFXVmZsRlhCeGJJTnMvSkVyVHFvUGlaQW5ENmJIZWF3a2RRcisyRVRwM3RGNGY3akxicnh3ODhhVm9QOThrY0xMNjhON1hVV3FCN1I1N2JQRU9EclRlSUI1S0lyUHBjbCtIeXR0a1ljOVdWQklVVEErSllQbzl1a0ZjbGNWZ1krWUc3eHlmdi90K3Q1ZEczblNuZEdEQ1FYRVIxbDlTNko1T2c9IiwiZW1haWwiOiJzdGV2ZW5AY29kZXIuY29tIiwiZ3JvdXBzIjpbImM4MDQ4ZTkxLWY1YzMtNDdlNS05NjkzLTgzNGRlODQwMzRhZCIsIjcwYjQ4MTc1LTEwN2ItNGFkOC1iNDA1LTRkODg4YTFjNDY2ZiJdLCJpZHAiOiJtYWlsIiwibmFtZSI6IlN0ZXZlbiBNIiwib2lkIjoiN2JhNDYzNjAtZTAyNy00OTVhLTlhZTUtM2FlYWZlMzY3MGEyIiwicHJlZmVycmVkX3VzZXJuYW1lIjoic3RldmVuQGNvZGVyLmNvbSIsInByb3ZfZGF0YSI6W3siQXQiOnRydWUsIlByb3YiOiJnaXRodWIuY29tIiwiQWx0c2VjaWQiOiI1NDQ2Mjk4IiwiQWNjZXNzVG9rZW4iOm51bGx9XSwicmgiOiIwLkFUZ0FEd3dQRVhiTkYwZW0tRTdxUFEtQkNTcURBUl9yX3R4Q2dkbGZDcmFnUWVBNEFPRS4iLCJyb2xlcyI6WyJUZW1wbGF0ZUF1dGhvcnMiXSwic3ViIjoib0JTN3FjUERKdWlDMEYyQ19XdDJycVlvanhpT0o3S3JFWjlkQ1RkTGVYNCIsInRpZCI6IjExMGYwYzBmLWNkNzYtNDcxNy1hNmY4LTRlZWEzZDBmODEwOSIsInV0aSI6IktReGlIWGtaZUVxcC1tQWlVdTlyQUEiLCJ2ZXIiOiIyLjAiLCJyb2xlczIiOiJUZW1wbGF0ZUF1dGhvcnMifQ.JevFI4Xm9dW7kQq4xEgZnUaU0SqbeOAFtT0YIKQNefR9Db4sjxCaKRmX0pPt-CM9j45d6fAiAkLFDAqjlSbi4Zi0GbEomT3yegmuxKgEgjPpJlGjF2TBUpsNNyn5gJ9Wkct9BfwALJhX2ePJFzIlkvx9opNNbNK1qHKMMjOSRFG6AGExKRDiQAME0a4hVgCwrAdUs4JrCcj4LqB84dODN-eoh-jx2-1wDvf6fovfwLHDQwjY4lfBxaYdNavKM369hrhU-U067rSnCzvDD26f4VLhPF52hiQIbTVN5t7p_1XmcduUiaNnmr9AZiZxZ-94mctSRRR8xG0pNwO2yv84iA"
210
+ }` ))
211
+ return resp .Result (), nil
212
+ }
213
+ // This is the "Coder" half of things. We can keep this in the fake
214
+ // client, essentially being the fake client on both sides of the OIDC
215
+ // flow.
216
+ if strings .Contains (req .URL .String (), "v2/users/oidc/callback" ) {
217
+ // This is the callback from the IDP.
218
+ code := req .URL .Query ().Get ("code" )
219
+ require .Equal (t , oauth2Code , code , "code mismatch" )
220
+ state := req .URL .Query ().Get ("state" )
221
+ require .Equal (t , stateString , state , "state mismatch" )
222
+
223
+ // Exchange for token should work
224
+ token , err := pki .Exchange (fakeCtx , code )
225
+ if ! assert .NoError (t , err ) {
226
+ return httptest .NewRecorder ().Result (), nil
227
+ }
228
+
229
+ // Also try a refresh
230
+ cpy := token
231
+ cpy .Expiry = time .Now ().Add (time .Minute * - 1 )
232
+ src := pki .TokenSource (fakeCtx , cpy )
233
+ _ , err = src .Token ()
234
+ tokenRefreshed = true
235
+ assert .NoError (t , err , "token refreshed" )
236
+ return httptest .NewRecorder ().Result (), nil
237
+ }
238
+
239
+ return nil , xerrors .Errorf ("not implemented" )
240
+ }},
241
+ }
242
+ fakeCtx = oidc .ClientContext (context .Background (), fakeClient )
243
+ var _ = fakeCtx
244
+
245
+ // This simulates a client logging into the browser. The 307 redirect will
246
+ // make sure this goes through the full flow.
247
+ _ , err = fakeClient .Get (pki .AuthCodeURL ("state" , oauth2 .AccessTypeOffline ))
248
+ require .NoError (t , err )
249
+
250
+ require .True (t , initialExchange , "initial token exchange complete" )
251
+ require .True (t , tokenRefreshed , "token was refreshed" )
252
+ }
253
+
254
+ type fakeRoundTripper struct {
255
+ roundTrip func (req * http.Request ) (* http.Response , error )
256
+ }
257
+
258
+ func (f fakeRoundTripper ) RoundTrip (req * http.Request ) (* http.Response , error ) {
259
+ return f .roundTrip (req )
260
+ }
261
+
262
+ // assertJWTAuth will assert the basic JWT auth assertions. It will return the
263
+ // url.Values from the request body for any additional assertions to be made.
264
+ func assertJWTAuth (t * testing.T , r * http.Request ) url.Values {
265
+ body , err := io .ReadAll (r .Body )
266
+ if ! assert .NoError (t , err ) {
267
+ return nil
268
+ }
269
+ vals , err := url .ParseQuery (string (body ))
270
+ if ! assert .NoError (t , err ) {
271
+ return nil
272
+ }
273
+
274
+ assert .Equal (t , "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" , vals .Get ("client_assertion_type" ))
275
+ jwtToken := vals .Get ("client_assertion" )
276
+ // No need to actually verify the jwt is signed right.
277
+ parsedToken , _ , err := (& jwt.Parser {}).ParseUnverified (jwtToken , jwt.MapClaims {})
278
+ if ! assert .NoError (t , err , "failed to parse jwt token" ) {
279
+ return nil
280
+ }
281
+
282
+ // Azure requirements
283
+ assert .NotEmpty (t , parsedToken .Header ["x5t" ], "hashed cert missing" )
284
+ assert .Equal (t , "RS256" , parsedToken .Header ["alg" ], "azure only accepts RS256" )
285
+ assert .Equal (t , "JWT" , parsedToken .Header ["typ" ], "azure only accepts JWT" )
286
+
287
+ return vals
288
+ }
0 commit comments