Skip to content

Commit b42ea1e

Browse files
committed
Add a unit test
1 parent 970ec25 commit b42ea1e

File tree

3 files changed

+117
-14
lines changed

3 files changed

+117
-14
lines changed

coderd/coderdtest/oidctest/idp.go

+47-6
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,43 @@ func (f *FakeIDP) ExternalLogin(t testing.TB, client *codersdk.Client, opts ...f
397397
_ = res.Body.Close()
398398
}
399399

400+
// CreateAuthCode emulates a user clicking "allow" on the IDP page. When doing
401+
// unit tests, it's easier to skip this step sometimes. It does make an actual
402+
// request to the IDP, so it should be equivalent to doing this "manually" with
403+
// actual requests.
404+
func (f *FakeIDP) CreateAuthCode(t testing.TB, state string, opts ...func(r *http.Request)) string {
405+
// We need to store some claims, because this is also an OIDC provider, and
406+
// it expects some claims to be present.
407+
f.stateToIDTokenClaims.Store(state, jwt.MapClaims{})
408+
409+
u := f.cfg.AuthCodeURL(state)
410+
r, err := http.NewRequestWithContext(context.Background(), http.MethodPost, u, nil)
411+
require.NoError(t, err, "failed to create auth request")
412+
413+
for _, opt := range opts {
414+
opt(r)
415+
}
416+
417+
rw := httptest.NewRecorder()
418+
f.handler.ServeHTTP(rw, r)
419+
resp := rw.Result()
420+
421+
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode, "expected redirect")
422+
to := resp.Header.Get("Location")
423+
require.NotEmpty(t, to, "expected redirect location")
424+
425+
toUrl, err := url.Parse(to)
426+
require.NoError(t, err, "failed to parse redirect location")
427+
428+
code := toUrl.Query().Get("code")
429+
require.NotEmpty(t, code, "expected code in redirect location")
430+
431+
newState := toUrl.Query().Get("state")
432+
require.Equal(t, state, newState, "expected state to match")
433+
434+
return code
435+
}
436+
400437
// OIDCCallback will emulate the IDP redirecting back to the Coder callback.
401438
// This is helpful if no Coderd exists because the IDP needs to redirect to
402439
// something.
@@ -917,13 +954,10 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu
917954
return cfg
918955
}
919956

920-
// OIDCConfig returns the OIDC config to use for Coderd.
921-
func (f *FakeIDP) OIDCConfig(t testing.TB, scopes []string, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig {
922-
t.Helper()
957+
func (f *FakeIDP) OAuthConfig(scopes ...string) *oauth2.Config {
923958
if len(scopes) == 0 {
924959
scopes = []string{"openid", "email", "profile"}
925960
}
926-
927961
oauthCfg := &oauth2.Config{
928962
ClientID: f.clientID,
929963
ClientSecret: f.clientSecret,
@@ -937,6 +971,15 @@ func (f *FakeIDP) OIDCConfig(t testing.TB, scopes []string, opts ...func(cfg *co
937971
RedirectURL: "https://redirect.com",
938972
Scopes: scopes,
939973
}
974+
f.cfg = oauthCfg
975+
return oauthCfg
976+
}
977+
978+
// OIDCConfig returns the OIDC config to use for Coderd.
979+
func (f *FakeIDP) OIDCConfig(t testing.TB, scopes []string, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig {
980+
t.Helper()
981+
982+
oauthCfg := f.OAuthConfig(scopes...)
940983

941984
ctx := oidc.ClientContext(context.Background(), f.HTTPClient(nil))
942985
p, err := oidc.NewProvider(ctx, f.provider.Issuer)
@@ -965,8 +1008,6 @@ func (f *FakeIDP) OIDCConfig(t testing.TB, scopes []string, opts ...func(cfg *co
9651008
opt(cfg)
9661009
}
9671010

968-
f.cfg = oauthCfg
969-
9701011
return cfg
9711012
}
9721013

coderd/promoauth/oauth2.go

+17-8
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ func NewFactory(registry prometheus.Registerer) *Factory {
4141
Help: "The total number of api calls made to external oauth2 providers. 'status_code' will be 0 if the request failed with no response.",
4242
}, []string{
4343
"name",
44+
"source",
4445
"status_code",
45-
"domain",
4646
}),
4747
},
4848
}
@@ -71,11 +71,11 @@ func (c *Config) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string
7171
}
7272

7373
func (c *Config) Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
74-
return c.underlying.Exchange(c.wrapClient(ctx), code, opts...)
74+
return c.underlying.Exchange(c.wrapClient(ctx, "Exchange"), code, opts...)
7575
}
7676

7777
func (c *Config) TokenSource(ctx context.Context, token *oauth2.Token) oauth2.TokenSource {
78-
return c.underlying.TokenSource(c.wrapClient(ctx), token)
78+
return c.underlying.TokenSource(c.wrapClient(ctx, "TokenSource"), token)
7979
}
8080

8181
// wrapClient is the only way we can accurately instrument the oauth2 client.
@@ -85,30 +85,39 @@ func (c *Config) TokenSource(ctx context.Context, token *oauth2.Token) oauth2.To
8585
// For example, the 'TokenSource' method will return a token
8686
// source that will make a network request when the 'Token' method is called on
8787
// it if the token is expired.
88-
func (c *Config) wrapClient(ctx context.Context) context.Context {
89-
cli := http.DefaultClient
88+
func (c *Config) wrapClient(ctx context.Context, source string) context.Context {
89+
cli := &http.Client{}
9090

9191
// Check if the context has an http client already.
9292
if hc, ok := ctx.Value(oauth2.HTTPClient).(*http.Client); ok {
9393
cli = hc
9494
}
9595

9696
// The new tripper will instrument every request made by the oauth2 client.
97-
cli.Transport = newInstrumentedTripper(c, cli.Transport)
97+
cli.Transport = newInstrumentedTripper(c, source, cli.Transport)
9898
return context.WithValue(ctx, oauth2.HTTPClient, cli)
9999
}
100100

101101
type instrumentedTripper struct {
102102
c *Config
103+
source string
103104
underlying http.RoundTripper
104105
}
105106

106-
func newInstrumentedTripper(c *Config, under http.RoundTripper) *instrumentedTripper {
107+
func newInstrumentedTripper(c *Config, source string, under http.RoundTripper) *instrumentedTripper {
107108
if under == nil {
108109
under = http.DefaultTransport
109110
}
111+
112+
// If the underlying transport is the default, we need to clone it.
113+
// We should also clone it if it supports cloning.
114+
if tr, ok := under.(*http.Transport); ok {
115+
under = tr.Clone()
116+
}
117+
110118
return &instrumentedTripper{
111119
c: c,
120+
source: source,
112121
underlying: under,
113122
}
114123
}
@@ -121,8 +130,8 @@ func (i *instrumentedTripper) RoundTrip(r *http.Request) (*http.Response, error)
121130
}
122131
i.c.metrics.externalRequestCount.With(prometheus.Labels{
123132
"name": i.c.name,
133+
"source": i.source,
124134
"status_code": fmt.Sprintf("%d", statusCode),
125-
"domain": r.URL.Host,
126135
}).Inc()
127136
return resp, err
128137
}

coderd/promoauth/oauth2_test.go

+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
package promoauth_test
2+
3+
import (
4+
"net/http"
5+
"testing"
6+
"time"
7+
8+
"github.com/prometheus/client_golang/prometheus"
9+
ptestutil "github.com/prometheus/client_golang/prometheus/testutil"
10+
"github.com/stretchr/testify/require"
11+
12+
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
13+
"github.com/coder/coder/v2/coderd/promoauth"
14+
"github.com/coder/coder/v2/testutil"
15+
)
16+
17+
func TestMaintainDefault(t *testing.T) {
18+
t.Parallel()
19+
20+
ctx := testutil.Context(t, testutil.WaitShort)
21+
idp := oidctest.NewFakeIDP(t, oidctest.WithServing())
22+
reg := prometheus.NewRegistry()
23+
count := func() int {
24+
return ptestutil.CollectAndCount(reg, "coderd_oauth2_external_requests_total")
25+
}
26+
27+
factory := promoauth.NewFactory(reg)
28+
cfg := factory.New("test", idp.OAuthConfig())
29+
30+
// 0 Requests before we start
31+
require.Equal(t, count(), 0)
32+
33+
// Exchange should trigger a request
34+
code := idp.CreateAuthCode(t, "foo")
35+
token, err := cfg.Exchange(ctx, code)
36+
require.NoError(t, err)
37+
require.Equal(t, count(), 1)
38+
39+
// Force a refresh
40+
token.Expiry = time.Now().Add(time.Hour * -1)
41+
src := cfg.TokenSource(ctx, token)
42+
refreshed, err := src.Token()
43+
require.NoError(t, err)
44+
require.NotEqual(t, token.AccessToken, refreshed.AccessToken, "token refreshed")
45+
require.Equal(t, count(), 2)
46+
47+
// Verify the default client was not broken. This check is added because we
48+
// extend the http.DefaultTransport. If a `.Clone()` is not done, this can be
49+
// mis-used. It is cheap to run this quick check.
50+
_, err = http.DefaultClient.Get("https://coder.com")
51+
require.NoError(t, err)
52+
require.Equal(t, count(), 2)
53+
}

0 commit comments

Comments
 (0)