Skip to content

Commit 9c73346

Browse files
committed
Add unit test for mock e2e
1 parent 7346fbf commit 9c73346

File tree

1 file changed

+174
-41
lines changed

1 file changed

+174
-41
lines changed

coderd/oauthpki/okidcpki_test.go

+174-41
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,19 @@ import (
55
"encoding/base64"
66
"io"
77
"net/http"
8+
"net/http/httptest"
89
"net/url"
10+
"strings"
911
"testing"
12+
"time"
1013

1114
"github.com/coreos/go-oidc/v3/oidc"
1215
"github.com/golang-jwt/jwt"
1316
"github.com/stretchr/testify/assert"
1417
"github.com/stretchr/testify/require"
1518
"golang.org/x/oauth2"
19+
"golang.org/x/xerrors"
1620

17-
"github.com/coder/coder/coderd/httpmw"
1821
"github.com/coder/coder/coderd/oauthpki"
1922
"github.com/coder/coder/testutil"
2023
)
@@ -77,25 +80,9 @@ B1B7CpkMU55hPP+7nsofCszNrMDXT8Z5w2a3zLKM
7780
`
7881
)
7982

80-
type exchangeAssert struct {
81-
httpmw.OAuth2Config
82-
assert func(ctx context.Context, code string, opts ...oauth2.AuthCodeOption)
83-
}
84-
85-
func (a *exchangeAssert) Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
86-
a.assert(ctx, code, opts...)
87-
return a.OAuth2Config.Exchange(ctx, code, opts...)
88-
}
89-
90-
type fakeRoundTripper struct {
91-
roundTrip func(req *http.Request) (*http.Response, error)
92-
}
93-
94-
func (f fakeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
95-
return f.roundTrip(req)
96-
}
97-
9883
// TestAzureADPKIOIDC ensures we do not break Azure AD compatibility.
84+
// It runs an oauth2.Exchange method and hijacks the request to only check the
85+
// request side of the transaction.
9986
func TestAzureADPKIOIDC(t *testing.T) {
10087
oauthCfg := &oauth2.Config{
10188
ClientID: "random-client-id",
@@ -123,28 +110,7 @@ func TestAzureADPKIOIDC(t *testing.T) {
123110
// This is the easiest way to hijack the request and check
124111
// the params. The oauth2 package uses unexported types and
125112
// options, so we need to view the actual request created.
126-
data, err := io.ReadAll(req.Body)
127-
if !assert.NoError(t, err, "failed to read request body") {
128-
return resp, nil
129-
}
130-
vals, err := url.ParseQuery(string(data))
131-
if !assert.NoError(t, err, "failed to parse values") {
132-
return resp, nil
133-
}
134-
assert.Equal(t, "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", vals.Get("client_assertion_type"))
135-
136-
jwtToken := vals.Get("client_assertion")
137-
138-
// No need to actually verify the jwt is signed right.
139-
parsedToken, _, err := (&jwt.Parser{}).ParseUnverified(jwtToken, jwt.MapClaims{})
140-
if !assert.NoError(t, err, "failed to parse jwt token") {
141-
return resp, nil
142-
}
143-
144-
// Azure requirements
145-
assert.NotEmpty(t, parsedToken.Header["x5t"], "hashed cert missing")
146-
assert.Equal(t, "RS256", parsedToken.Header["alg"], "azure only accepts RS256")
147-
assert.Equal(t, "JWT", parsedToken.Header["typ"], "azure only accepts JWT")
113+
assertJWTAuth(t, req)
148114
return resp, nil
149115
},
150116
},
@@ -153,3 +119,170 @@ func TestAzureADPKIOIDC(t *testing.T) {
153119
// We hijack the request and return an error intentionally
154120
require.Error(t, err, "error expected")
155121
}
122+
123+
// TestSavedAzureADPKIOIDC was created by capturing actual responses from an Azure
124+
// AD instance and saving them to replay, removing some details.
125+
// The reason this is done is that this is the only way to assert values
126+
// passed to the oauth2 provider via http requests.
127+
// It is not feasible to run against an actual Azure AD instance, so this attempts
128+
// to prevent some regressions by running a full "e2e" oauth and asserting some
129+
// of the request values.
130+
func TestSavedAzureADPKIOIDC(t *testing.T) {
131+
var (
132+
stateString = "random-state"
133+
oauth2Code = base64.StdEncoding.EncodeToString([]byte("random-code"))
134+
)
135+
136+
// Real oauth config. We will hijack all http requests so some of these values
137+
// are fake.
138+
cfg := &oauth2.Config{
139+
ClientID: "fake_app",
140+
ClientSecret: "",
141+
Endpoint: oauth2.Endpoint{
142+
AuthURL: "https://login.microsoftonline.com/fake_app/oauth2/v2.0/authorize",
143+
TokenURL: "https://login.microsoftonline.com/fake_app/oauth2/v2.0/token",
144+
AuthStyle: 0,
145+
},
146+
RedirectURL: "http://localhost/api/v2/users/oidc/callback",
147+
Scopes: []string{"openid", "profile", "email", "offline_access"},
148+
}
149+
150+
initialExchange := false
151+
tokenRefreshed := false
152+
153+
// Create the oauthpki config
154+
pki, err := oauthpki.NewOauth2PKIConfig(oauthpki.ConfigParams{
155+
ClientID: cfg.ClientID,
156+
TokenURL: cfg.Endpoint.TokenURL,
157+
Scopes: []string{"openid", "email", "profile", "offline_access"},
158+
PemEncodedKey: []byte(testClientKey),
159+
PemEncodedCert: []byte(testClientCert),
160+
Config: cfg,
161+
})
162+
require.NoError(t, err)
163+
164+
var fakeCtx context.Context
165+
fakeClient := &http.Client{
166+
Transport: fakeRoundTripper{
167+
roundTrip: func(req *http.Request) (*http.Response, error) {
168+
if strings.Contains(req.URL.String(), "authorize") {
169+
// This is the user hitting the browser endpoint to begin the OIDC
170+
// auth flow.
171+
172+
// Authorize should redirect the user back to the app after authentication on
173+
// the IDP.
174+
resp := httptest.NewRecorder()
175+
v := url.Values{
176+
"code": {oauth2Code},
177+
"state": {stateString},
178+
"session_state": {"a18cf797-1e2b-4bc3-baf9-66b41a4997cf"},
179+
}
180+
181+
// This url doesn't really matter since the fake client will hiject this actual request.
182+
http.Redirect(resp, req, "http://localhost:3000/api/v2/users/oidc/callback?"+v.Encode(), http.StatusTemporaryRedirect)
183+
return resp.Result(), nil
184+
}
185+
if strings.Contains(req.URL.String(), "v2.0/token") {
186+
vals := assertJWTAuth(t, req)
187+
switch vals.Get("grant_type") {
188+
case "authorization_code":
189+
// Initial token
190+
initialExchange = true
191+
assert.Equal(t, oauth2Code, vals.Get("code"), "initial exchange code mismatch")
192+
case "refresh_token":
193+
// refreshed token
194+
tokenRefreshed = true
195+
assert.Equal(t, "<refresh_token_JWT>", vals.Get("refresh_token"), "refresh token required")
196+
}
197+
198+
resp := httptest.NewRecorder()
199+
// Taken from an actual response
200+
// Just always return a token no matter what.
201+
resp.Header().Set("Content-Type", "application/json")
202+
_, _ = resp.Write([]byte(`{
203+
"token_type":"Bearer",
204+
"scope":"email openid profile AccessReview.ReadWrite.Membership Group.Read.All Group.ReadWrite.All User.Read",
205+
"expires_in":4009,
206+
"ext_expires_in":4009,
207+
"access_token":"<access_token_JWT>",
208+
"refresh_token":"<refresh_token_JWT>",
209+
"id_token":"eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6Ii1LSTNROW5OUjdiUm9meG1lWm9YcWJIWkdldyJ9.eyJhdWQiOiIxZjAxODMyYS1mZWViLTQyZGMtODFkOS01ZjBhYjZhMDQxZTAiLCJpc3MiOiJodHRwczovL2xvZ2luLm1pY3Jvc29mdG9ubGluZS5jb20vMTEwZjBjMGYtY2Q3Ni00NzE3LWE2ZjgtNGVlYTNkMGY4MTA5L3YyLjAiLCJpYXQiOjE2OTE3OTI2MzQsIm5iZiI6MTY5MTc5MjYzNCwiZXhwIjoxNjkxNzk2NTM0LCJhaW8iOiJBWVFBZS84VUFBQUE1eEtqMmVTdWFXVmZsRlhCeGJJTnMvSkVyVHFvUGlaQW5ENmJIZWF3a2RRcisyRVRwM3RGNGY3akxicnh3ODhhVm9QOThrY0xMNjhON1hVV3FCN1I1N2JQRU9EclRlSUI1S0lyUHBjbCtIeXR0a1ljOVdWQklVVEErSllQbzl1a0ZjbGNWZ1krWUc3eHlmdi90K3Q1ZEczblNuZEdEQ1FYRVIxbDlTNko1T2c9IiwiZW1haWwiOiJzdGV2ZW5AY29kZXIuY29tIiwiZ3JvdXBzIjpbImM4MDQ4ZTkxLWY1YzMtNDdlNS05NjkzLTgzNGRlODQwMzRhZCIsIjcwYjQ4MTc1LTEwN2ItNGFkOC1iNDA1LTRkODg4YTFjNDY2ZiJdLCJpZHAiOiJtYWlsIiwibmFtZSI6IlN0ZXZlbiBNIiwib2lkIjoiN2JhNDYzNjAtZTAyNy00OTVhLTlhZTUtM2FlYWZlMzY3MGEyIiwicHJlZmVycmVkX3VzZXJuYW1lIjoic3RldmVuQGNvZGVyLmNvbSIsInByb3ZfZGF0YSI6W3siQXQiOnRydWUsIlByb3YiOiJnaXRodWIuY29tIiwiQWx0c2VjaWQiOiI1NDQ2Mjk4IiwiQWNjZXNzVG9rZW4iOm51bGx9XSwicmgiOiIwLkFUZ0FEd3dQRVhiTkYwZW0tRTdxUFEtQkNTcURBUl9yX3R4Q2dkbGZDcmFnUWVBNEFPRS4iLCJyb2xlcyI6WyJUZW1wbGF0ZUF1dGhvcnMiXSwic3ViIjoib0JTN3FjUERKdWlDMEYyQ19XdDJycVlvanhpT0o3S3JFWjlkQ1RkTGVYNCIsInRpZCI6IjExMGYwYzBmLWNkNzYtNDcxNy1hNmY4LTRlZWEzZDBmODEwOSIsInV0aSI6IktReGlIWGtaZUVxcC1tQWlVdTlyQUEiLCJ2ZXIiOiIyLjAiLCJyb2xlczIiOiJUZW1wbGF0ZUF1dGhvcnMifQ.JevFI4Xm9dW7kQq4xEgZnUaU0SqbeOAFtT0YIKQNefR9Db4sjxCaKRmX0pPt-CM9j45d6fAiAkLFDAqjlSbi4Zi0GbEomT3yegmuxKgEgjPpJlGjF2TBUpsNNyn5gJ9Wkct9BfwALJhX2ePJFzIlkvx9opNNbNK1qHKMMjOSRFG6AGExKRDiQAME0a4hVgCwrAdUs4JrCcj4LqB84dODN-eoh-jx2-1wDvf6fovfwLHDQwjY4lfBxaYdNavKM369hrhU-U067rSnCzvDD26f4VLhPF52hiQIbTVN5t7p_1XmcduUiaNnmr9AZiZxZ-94mctSRRR8xG0pNwO2yv84iA"
210+
}`))
211+
return resp.Result(), nil
212+
}
213+
// This is the "Coder" half of things. We can keep this in the fake
214+
// client, essentially being the fake client on both sides of the OIDC
215+
// flow.
216+
if strings.Contains(req.URL.String(), "v2/users/oidc/callback") {
217+
// This is the callback from the IDP.
218+
code := req.URL.Query().Get("code")
219+
require.Equal(t, oauth2Code, code, "code mismatch")
220+
state := req.URL.Query().Get("state")
221+
require.Equal(t, stateString, state, "state mismatch")
222+
223+
// Exchange for token should work
224+
token, err := pki.Exchange(fakeCtx, code)
225+
if !assert.NoError(t, err) {
226+
return httptest.NewRecorder().Result(), nil
227+
}
228+
229+
// Also try a refresh
230+
cpy := token
231+
cpy.Expiry = time.Now().Add(time.Minute * -1)
232+
src := pki.TokenSource(fakeCtx, cpy)
233+
_, err = src.Token()
234+
tokenRefreshed = true
235+
assert.NoError(t, err, "token refreshed")
236+
return httptest.NewRecorder().Result(), nil
237+
}
238+
239+
return nil, xerrors.Errorf("not implemented")
240+
}},
241+
}
242+
fakeCtx = oidc.ClientContext(context.Background(), fakeClient)
243+
var _ = fakeCtx
244+
245+
// This simulates a client logging into the browser. The 307 redirect will
246+
// make sure this goes through the full flow.
247+
_, err = fakeClient.Get(pki.AuthCodeURL("state", oauth2.AccessTypeOffline))
248+
require.NoError(t, err)
249+
250+
require.True(t, initialExchange, "initial token exchange complete")
251+
require.True(t, tokenRefreshed, "token was refreshed")
252+
}
253+
254+
type fakeRoundTripper struct {
255+
roundTrip func(req *http.Request) (*http.Response, error)
256+
}
257+
258+
func (f fakeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
259+
return f.roundTrip(req)
260+
}
261+
262+
// assertJWTAuth will assert the basic JWT auth assertions. It will return the
263+
// url.Values from the request body for any additional assertions to be made.
264+
func assertJWTAuth(t *testing.T, r *http.Request) url.Values {
265+
body, err := io.ReadAll(r.Body)
266+
if !assert.NoError(t, err) {
267+
return nil
268+
}
269+
vals, err := url.ParseQuery(string(body))
270+
if !assert.NoError(t, err) {
271+
return nil
272+
}
273+
274+
assert.Equal(t, "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", vals.Get("client_assertion_type"))
275+
jwtToken := vals.Get("client_assertion")
276+
// No need to actually verify the jwt is signed right.
277+
parsedToken, _, err := (&jwt.Parser{}).ParseUnverified(jwtToken, jwt.MapClaims{})
278+
if !assert.NoError(t, err, "failed to parse jwt token") {
279+
return nil
280+
}
281+
282+
// Azure requirements
283+
assert.NotEmpty(t, parsedToken.Header["x5t"], "hashed cert missing")
284+
assert.Equal(t, "RS256", parsedToken.Header["alg"], "azure only accepts RS256")
285+
assert.Equal(t, "JWT", parsedToken.Header["typ"], "azure only accepts JWT")
286+
287+
return vals
288+
}

0 commit comments

Comments
 (0)