Skip to content

chore: instrument external oauth2 requests #11519

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Jan 10, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add a unit test
  • Loading branch information
Emyrk committed Jan 10, 2024
commit fd1e012a67e961594f0611bd4607e299f584e89e
53 changes: 47 additions & 6 deletions coderd/coderdtest/oidctest/idp.go
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,43 @@ func (f *FakeIDP) ExternalLogin(t testing.TB, client *codersdk.Client, opts ...f
_ = res.Body.Close()
}

// CreateAuthCode emulates a user clicking "allow" on the IDP page. When doing
// unit tests, it's easier to skip this step sometimes. It does make an actual
// request to the IDP, so it should be equivalent to doing this "manually" with
// actual requests.
func (f *FakeIDP) CreateAuthCode(t testing.TB, state string, opts ...func(r *http.Request)) string {
// We need to store some claims, because this is also an OIDC provider, and
// it expects some claims to be present.
f.stateToIDTokenClaims.Store(state, jwt.MapClaims{})

u := f.cfg.AuthCodeURL(state)
r, err := http.NewRequestWithContext(context.Background(), http.MethodPost, u, nil)
require.NoError(t, err, "failed to create auth request")

for _, opt := range opts {
opt(r)
}

rw := httptest.NewRecorder()
f.handler.ServeHTTP(rw, r)
resp := rw.Result()

require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode, "expected redirect")
to := resp.Header.Get("Location")
require.NotEmpty(t, to, "expected redirect location")

toUrl, err := url.Parse(to)
require.NoError(t, err, "failed to parse redirect location")

code := toUrl.Query().Get("code")
require.NotEmpty(t, code, "expected code in redirect location")

newState := toUrl.Query().Get("state")
require.Equal(t, state, newState, "expected state to match")

return code
}

// OIDCCallback will emulate the IDP redirecting back to the Coder callback.
// This is helpful if no Coderd exists because the IDP needs to redirect to
// something.
Expand Down Expand Up @@ -917,13 +954,10 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu
return cfg
}

// OIDCConfig returns the OIDC config to use for Coderd.
func (f *FakeIDP) OIDCConfig(t testing.TB, scopes []string, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig {
t.Helper()
func (f *FakeIDP) OAuthConfig(scopes ...string) *oauth2.Config {
if len(scopes) == 0 {
scopes = []string{"openid", "email", "profile"}
}

oauthCfg := &oauth2.Config{
ClientID: f.clientID,
ClientSecret: f.clientSecret,
Expand All @@ -937,6 +971,15 @@ func (f *FakeIDP) OIDCConfig(t testing.TB, scopes []string, opts ...func(cfg *co
RedirectURL: "https://redirect.com",
Scopes: scopes,
}
f.cfg = oauthCfg
return oauthCfg
}

// OIDCConfig returns the OIDC config to use for Coderd.
func (f *FakeIDP) OIDCConfig(t testing.TB, scopes []string, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig {
t.Helper()

oauthCfg := f.OAuthConfig(scopes...)

ctx := oidc.ClientContext(context.Background(), f.HTTPClient(nil))
p, err := oidc.NewProvider(ctx, f.provider.Issuer)
Expand Down Expand Up @@ -965,8 +1008,6 @@ func (f *FakeIDP) OIDCConfig(t testing.TB, scopes []string, opts ...func(cfg *co
opt(cfg)
}

f.cfg = oauthCfg

return cfg
}

Expand Down
25 changes: 17 additions & 8 deletions coderd/promoauth/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ func NewFactory(registry prometheus.Registerer) *Factory {
Help: "The total number of api calls made to external oauth2 providers. 'status_code' will be 0 if the request failed with no response.",
}, []string{
"name",
"source",
"status_code",
"domain",
}),
},
}
Expand Down Expand Up @@ -71,11 +71,11 @@ func (c *Config) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string
}

func (c *Config) Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
return c.underlying.Exchange(c.wrapClient(ctx), code, opts...)
return c.underlying.Exchange(c.wrapClient(ctx, "Exchange"), code, opts...)
}

func (c *Config) TokenSource(ctx context.Context, token *oauth2.Token) oauth2.TokenSource {
return c.underlying.TokenSource(c.wrapClient(ctx), token)
return c.underlying.TokenSource(c.wrapClient(ctx, "TokenSource"), token)
}

// wrapClient is the only way we can accurately instrument the oauth2 client.
Expand All @@ -85,30 +85,39 @@ func (c *Config) TokenSource(ctx context.Context, token *oauth2.Token) oauth2.To
// For example, the 'TokenSource' method will return a token
// source that will make a network request when the 'Token' method is called on
// it if the token is expired.
func (c *Config) wrapClient(ctx context.Context) context.Context {
cli := http.DefaultClient
func (c *Config) wrapClient(ctx context.Context, source string) context.Context {
cli := &http.Client{}

// Check if the context has an http client already.
if hc, ok := ctx.Value(oauth2.HTTPClient).(*http.Client); ok {
cli = hc
}

// The new tripper will instrument every request made by the oauth2 client.
cli.Transport = newInstrumentedTripper(c, cli.Transport)
cli.Transport = newInstrumentedTripper(c, source, cli.Transport)
return context.WithValue(ctx, oauth2.HTTPClient, cli)
}

type instrumentedTripper struct {
c *Config
source string
underlying http.RoundTripper
}

func newInstrumentedTripper(c *Config, under http.RoundTripper) *instrumentedTripper {
func newInstrumentedTripper(c *Config, source string, under http.RoundTripper) *instrumentedTripper {
if under == nil {
under = http.DefaultTransport
}

// If the underlying transport is the default, we need to clone it.
// We should also clone it if it supports cloning.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about testing for under.(interface{ Clone() ... }) instead? That would cover the "supports cloning" case.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wish I could but the method signature is:

func (t *Transport) Clone() *Transport

So the interface of Clone() http.RoundTripper does not match, and the default transport would not implement it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, of course. That's unfortunate :(.

if tr, ok := under.(*http.Transport); ok {
under = tr.Clone()
}

return &instrumentedTripper{
c: c,
source: source,
underlying: under,
}
}
Expand All @@ -121,8 +130,8 @@ func (i *instrumentedTripper) RoundTrip(r *http.Request) (*http.Response, error)
}
i.c.metrics.externalRequestCount.With(prometheus.Labels{
"name": i.c.name,
"source": i.source,
"status_code": fmt.Sprintf("%d", statusCode),
"domain": r.URL.Host,
}).Inc()
return resp, err
}
53 changes: 53 additions & 0 deletions coderd/promoauth/oauth2_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package promoauth_test

import (
"net/http"
"testing"
"time"

"github.com/prometheus/client_golang/prometheus"
ptestutil "github.com/prometheus/client_golang/prometheus/testutil"
"github.com/stretchr/testify/require"

"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
"github.com/coder/coder/v2/coderd/promoauth"
"github.com/coder/coder/v2/testutil"
)

func TestMaintainDefault(t *testing.T) {
t.Parallel()

ctx := testutil.Context(t, testutil.WaitShort)
idp := oidctest.NewFakeIDP(t, oidctest.WithServing())
reg := prometheus.NewRegistry()
count := func() int {
return ptestutil.CollectAndCount(reg, "coderd_oauth2_external_requests_total")
}

factory := promoauth.NewFactory(reg)
cfg := factory.New("test", idp.OAuthConfig())

// 0 Requests before we start
require.Equal(t, count(), 0)

// Exchange should trigger a request
code := idp.CreateAuthCode(t, "foo")
token, err := cfg.Exchange(ctx, code)
require.NoError(t, err)
require.Equal(t, count(), 1)

// Force a refresh
token.Expiry = time.Now().Add(time.Hour * -1)
src := cfg.TokenSource(ctx, token)
refreshed, err := src.Token()
require.NoError(t, err)
require.NotEqual(t, token.AccessToken, refreshed.AccessToken, "token refreshed")
require.Equal(t, count(), 2)

// Verify the default client was not broken. This check is added because we
// extend the http.DefaultTransport. If a `.Clone()` is not done, this can be
// mis-used. It is cheap to run this quick check.
_, err = http.DefaultClient.Get("https://coder.com")
require.NoError(t, err)
require.Equal(t, count(), 2)
}