Skip to content

feat: add azure oidc PKI auth instead of client secret #9054

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Aug 14, 2023
Merged
Prev Previous commit
Next Next commit
Add unit test for mock e2e
  • Loading branch information
Emyrk committed Aug 11, 2023
commit 9c73346b62c3db0f94464da417a1411686d9295c
215 changes: 174 additions & 41 deletions coderd/oauthpki/okidcpki_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,19 @@ import (
"encoding/base64"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"

"github.com/coreos/go-oidc/v3/oidc"
"github.com/golang-jwt/jwt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
"golang.org/x/xerrors"

"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/coderd/oauthpki"
"github.com/coder/coder/testutil"
)
Expand Down Expand Up @@ -77,25 +80,9 @@ B1B7CpkMU55hPP+7nsofCszNrMDXT8Z5w2a3zLKM
`
)

type exchangeAssert struct {
httpmw.OAuth2Config
assert func(ctx context.Context, code string, opts ...oauth2.AuthCodeOption)
}

func (a *exchangeAssert) Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
a.assert(ctx, code, opts...)
return a.OAuth2Config.Exchange(ctx, code, opts...)
}

type fakeRoundTripper struct {
roundTrip func(req *http.Request) (*http.Response, error)
}

func (f fakeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return f.roundTrip(req)
}

// TestAzureADPKIOIDC ensures we do not break Azure AD compatibility.
// It runs an oauth2.Exchange method and hijacks the request to only check the
// request side of the transaction.
func TestAzureADPKIOIDC(t *testing.T) {
oauthCfg := &oauth2.Config{
ClientID: "random-client-id",
Expand Down Expand Up @@ -123,28 +110,7 @@ func TestAzureADPKIOIDC(t *testing.T) {
// This is the easiest way to hijack the request and check
// the params. The oauth2 package uses unexported types and
// options, so we need to view the actual request created.
data, err := io.ReadAll(req.Body)
if !assert.NoError(t, err, "failed to read request body") {
return resp, nil
}
vals, err := url.ParseQuery(string(data))
if !assert.NoError(t, err, "failed to parse values") {
return resp, nil
}
assert.Equal(t, "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", vals.Get("client_assertion_type"))

jwtToken := vals.Get("client_assertion")

// No need to actually verify the jwt is signed right.
parsedToken, _, err := (&jwt.Parser{}).ParseUnverified(jwtToken, jwt.MapClaims{})
if !assert.NoError(t, err, "failed to parse jwt token") {
return resp, nil
}

// Azure requirements
assert.NotEmpty(t, parsedToken.Header["x5t"], "hashed cert missing")
assert.Equal(t, "RS256", parsedToken.Header["alg"], "azure only accepts RS256")
assert.Equal(t, "JWT", parsedToken.Header["typ"], "azure only accepts JWT")
assertJWTAuth(t, req)
return resp, nil
},
},
Expand All @@ -153,3 +119,170 @@ func TestAzureADPKIOIDC(t *testing.T) {
// We hijack the request and return an error intentionally
require.Error(t, err, "error expected")
}

// TestSavedAzureADPKIOIDC was created by capturing actual responses from an Azure
// AD instance and saving them to replay, removing some details.
// The reason this is done is that this is the only way to assert values
// passed to the oauth2 provider via http requests.
// It is not feasible to run against an actual Azure AD instance, so this attempts
// to prevent some regressions by running a full "e2e" oauth and asserting some
// of the request values.
func TestSavedAzureADPKIOIDC(t *testing.T) {
var (
stateString = "random-state"
oauth2Code = base64.StdEncoding.EncodeToString([]byte("random-code"))
)

// Real oauth config. We will hijack all http requests so some of these values
// are fake.
cfg := &oauth2.Config{
ClientID: "fake_app",
ClientSecret: "",
Endpoint: oauth2.Endpoint{
AuthURL: "https://login.microsoftonline.com/fake_app/oauth2/v2.0/authorize",
TokenURL: "https://login.microsoftonline.com/fake_app/oauth2/v2.0/token",
AuthStyle: 0,
},
RedirectURL: "http://localhost/api/v2/users/oidc/callback",
Scopes: []string{"openid", "profile", "email", "offline_access"},
}

initialExchange := false
tokenRefreshed := false

// Create the oauthpki config
pki, err := oauthpki.NewOauth2PKIConfig(oauthpki.ConfigParams{
ClientID: cfg.ClientID,
TokenURL: cfg.Endpoint.TokenURL,
Scopes: []string{"openid", "email", "profile", "offline_access"},
PemEncodedKey: []byte(testClientKey),
PemEncodedCert: []byte(testClientCert),
Config: cfg,
})
require.NoError(t, err)

var fakeCtx context.Context
fakeClient := &http.Client{
Transport: fakeRoundTripper{
roundTrip: func(req *http.Request) (*http.Response, error) {
if strings.Contains(req.URL.String(), "authorize") {
// This is the user hitting the browser endpoint to begin the OIDC
// auth flow.

// Authorize should redirect the user back to the app after authentication on
// the IDP.
resp := httptest.NewRecorder()
v := url.Values{
"code": {oauth2Code},
"state": {stateString},
"session_state": {"a18cf797-1e2b-4bc3-baf9-66b41a4997cf"},
}

// This url doesn't really matter since the fake client will hiject this actual request.
http.Redirect(resp, req, "http://localhost:3000/api/v2/users/oidc/callback?"+v.Encode(), http.StatusTemporaryRedirect)
return resp.Result(), nil
}
if strings.Contains(req.URL.String(), "v2.0/token") {
vals := assertJWTAuth(t, req)
switch vals.Get("grant_type") {
case "authorization_code":
// Initial token
initialExchange = true
assert.Equal(t, oauth2Code, vals.Get("code"), "initial exchange code mismatch")
case "refresh_token":
// refreshed token
tokenRefreshed = true
assert.Equal(t, "<refresh_token_JWT>", vals.Get("refresh_token"), "refresh token required")
}

resp := httptest.NewRecorder()
// Taken from an actual response
// Just always return a token no matter what.
resp.Header().Set("Content-Type", "application/json")
_, _ = resp.Write([]byte(`{
"token_type":"Bearer",
"scope":"email openid profile AccessReview.ReadWrite.Membership Group.Read.All Group.ReadWrite.All User.Read",
"expires_in":4009,
"ext_expires_in":4009,
"access_token":"<access_token_JWT>",
"refresh_token":"<refresh_token_JWT>",
"id_token":"eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6Ii1LSTNROW5OUjdiUm9meG1lWm9YcWJIWkdldyJ9.eyJhdWQiOiIxZjAxODMyYS1mZWViLTQyZGMtODFkOS01ZjBhYjZhMDQxZTAiLCJpc3MiOiJodHRwczovL2xvZ2luLm1pY3Jvc29mdG9ubGluZS5jb20vMTEwZjBjMGYtY2Q3Ni00NzE3LWE2ZjgtNGVlYTNkMGY4MTA5L3YyLjAiLCJpYXQiOjE2OTE3OTI2MzQsIm5iZiI6MTY5MTc5MjYzNCwiZXhwIjoxNjkxNzk2NTM0LCJhaW8iOiJBWVFBZS84VUFBQUE1eEtqMmVTdWFXVmZsRlhCeGJJTnMvSkVyVHFvUGlaQW5ENmJIZWF3a2RRcisyRVRwM3RGNGY3akxicnh3ODhhVm9QOThrY0xMNjhON1hVV3FCN1I1N2JQRU9EclRlSUI1S0lyUHBjbCtIeXR0a1ljOVdWQklVVEErSllQbzl1a0ZjbGNWZ1krWUc3eHlmdi90K3Q1ZEczblNuZEdEQ1FYRVIxbDlTNko1T2c9IiwiZW1haWwiOiJzdGV2ZW5AY29kZXIuY29tIiwiZ3JvdXBzIjpbImM4MDQ4ZTkxLWY1YzMtNDdlNS05NjkzLTgzNGRlODQwMzRhZCIsIjcwYjQ4MTc1LTEwN2ItNGFkOC1iNDA1LTRkODg4YTFjNDY2ZiJdLCJpZHAiOiJtYWlsIiwibmFtZSI6IlN0ZXZlbiBNIiwib2lkIjoiN2JhNDYzNjAtZTAyNy00OTVhLTlhZTUtM2FlYWZlMzY3MGEyIiwicHJlZmVycmVkX3VzZXJuYW1lIjoic3RldmVuQGNvZGVyLmNvbSIsInByb3ZfZGF0YSI6W3siQXQiOnRydWUsIlByb3YiOiJnaXRodWIuY29tIiwiQWx0c2VjaWQiOiI1NDQ2Mjk4IiwiQWNjZXNzVG9rZW4iOm51bGx9XSwicmgiOiIwLkFUZ0FEd3dQRVhiTkYwZW0tRTdxUFEtQkNTcURBUl9yX3R4Q2dkbGZDcmFnUWVBNEFPRS4iLCJyb2xlcyI6WyJUZW1wbGF0ZUF1dGhvcnMiXSwic3ViIjoib0JTN3FjUERKdWlDMEYyQ19XdDJycVlvanhpT0o3S3JFWjlkQ1RkTGVYNCIsInRpZCI6IjExMGYwYzBmLWNkNzYtNDcxNy1hNmY4LTRlZWEzZDBmODEwOSIsInV0aSI6IktReGlIWGtaZUVxcC1tQWlVdTlyQUEiLCJ2ZXIiOiIyLjAiLCJyb2xlczIiOiJUZW1wbGF0ZUF1dGhvcnMifQ.JevFI4Xm9dW7kQq4xEgZnUaU0SqbeOAFtT0YIKQNefR9Db4sjxCaKRmX0pPt-CM9j45d6fAiAkLFDAqjlSbi4Zi0GbEomT3yegmuxKgEgjPpJlGjF2TBUpsNNyn5gJ9Wkct9BfwALJhX2ePJFzIlkvx9opNNbNK1qHKMMjOSRFG6AGExKRDiQAME0a4hVgCwrAdUs4JrCcj4LqB84dODN-eoh-jx2-1wDvf6fovfwLHDQwjY4lfBxaYdNavKM369hrhU-U067rSnCzvDD26f4VLhPF52hiQIbTVN5t7p_1XmcduUiaNnmr9AZiZxZ-94mctSRRR8xG0pNwO2yv84iA"
}`))
return resp.Result(), nil
}
// This is the "Coder" half of things. We can keep this in the fake
// client, essentially being the fake client on both sides of the OIDC
// flow.
if strings.Contains(req.URL.String(), "v2/users/oidc/callback") {
// This is the callback from the IDP.
code := req.URL.Query().Get("code")
require.Equal(t, oauth2Code, code, "code mismatch")
state := req.URL.Query().Get("state")
require.Equal(t, stateString, state, "state mismatch")

// Exchange for token should work
token, err := pki.Exchange(fakeCtx, code)
if !assert.NoError(t, err) {
return httptest.NewRecorder().Result(), nil
}

// Also try a refresh
cpy := token
cpy.Expiry = time.Now().Add(time.Minute * -1)
src := pki.TokenSource(fakeCtx, cpy)
_, err = src.Token()
tokenRefreshed = true
assert.NoError(t, err, "token refreshed")
return httptest.NewRecorder().Result(), nil
}

return nil, xerrors.Errorf("not implemented")
}},
}
fakeCtx = oidc.ClientContext(context.Background(), fakeClient)
var _ = fakeCtx

// This simulates a client logging into the browser. The 307 redirect will
// make sure this goes through the full flow.
_, err = fakeClient.Get(pki.AuthCodeURL("state", oauth2.AccessTypeOffline))
require.NoError(t, err)

require.True(t, initialExchange, "initial token exchange complete")
require.True(t, tokenRefreshed, "token was refreshed")
}

type fakeRoundTripper struct {
roundTrip func(req *http.Request) (*http.Response, error)
}

func (f fakeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return f.roundTrip(req)
}

// assertJWTAuth will assert the basic JWT auth assertions. It will return the
// url.Values from the request body for any additional assertions to be made.
func assertJWTAuth(t *testing.T, r *http.Request) url.Values {
body, err := io.ReadAll(r.Body)
if !assert.NoError(t, err) {
return nil
}
vals, err := url.ParseQuery(string(body))
if !assert.NoError(t, err) {
return nil
}

assert.Equal(t, "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", vals.Get("client_assertion_type"))
jwtToken := vals.Get("client_assertion")
// No need to actually verify the jwt is signed right.
parsedToken, _, err := (&jwt.Parser{}).ParseUnverified(jwtToken, jwt.MapClaims{})
if !assert.NoError(t, err, "failed to parse jwt token") {
return nil
}

// Azure requirements
assert.NotEmpty(t, parsedToken.Header["x5t"], "hashed cert missing")
assert.Equal(t, "RS256", parsedToken.Header["alg"], "azure only accepts RS256")
assert.Equal(t, "JWT", parsedToken.Header["typ"], "azure only accepts JWT")

return vals
}