Skip to content

Commit 1c7e8b4

Browse files
committed
Pr feedback, add authorized redirect urls
1 parent 00b4760 commit 1c7e8b4

File tree

2 files changed

+43
-23
lines changed

2 files changed

+43
-23
lines changed

coderd/coderdtest/oidctest/helper.go

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,20 @@ import (
1919
// It is mainly because refreshing oauth tokens is a bit tricky and requires
2020
// some database manipulation.
2121
type LoginHelper struct {
22-
fake *FakeIDP
23-
owner *codersdk.Client
22+
fake *FakeIDP
23+
client *codersdk.Client
2424
}
2525

26-
func NewLoginHelper(owner *codersdk.Client, fake *FakeIDP) *LoginHelper {
27-
if owner == nil {
28-
panic("owner must not be nil")
26+
func NewLoginHelper(client *codersdk.Client, fake *FakeIDP) *LoginHelper {
27+
if client == nil {
28+
panic("client must not be nil")
2929
}
3030
if fake == nil {
3131
panic("fake must not be nil")
3232
}
3333
return &LoginHelper{
34-
fake: fake,
35-
owner: owner,
34+
fake: fake,
35+
client: client,
3636
}
3737
}
3838

@@ -41,13 +41,13 @@ func NewLoginHelper(owner *codersdk.Client, fake *FakeIDP) *LoginHelper {
4141
// convenience method.
4242
func (h *LoginHelper) Login(t *testing.T, idTokenClaims jwt.MapClaims) (*codersdk.Client, *http.Response) {
4343
t.Helper()
44-
unauthenticatedClient := codersdk.New(h.owner.URL)
44+
unauthenticatedClient := codersdk.New(h.client.URL)
4545

4646
return h.fake.Login(t, unauthenticatedClient, idTokenClaims)
4747
}
4848

4949
// ExpireOauthToken expires the oauth token for the given user.
50-
func (*LoginHelper) ExpireOauthToken(t *testing.T, db database.Store, user *codersdk.Client) (refreshToken string) {
50+
func (*LoginHelper) ExpireOauthToken(t *testing.T, db database.Store, user *codersdk.Client) database.UserLink {
5151
t.Helper()
5252

5353
//nolint:gocritic // Testing
@@ -68,7 +68,7 @@ func (*LoginHelper) ExpireOauthToken(t *testing.T, db database.Store, user *code
6868
require.NoError(t, err, "get user link")
6969

7070
// Expire the oauth link for the given user.
71-
_, err = db.UpdateUserLink(ctx, database.UpdateUserLinkParams{
71+
updated, err := db.UpdateUserLink(ctx, database.UpdateUserLinkParams{
7272
OAuthAccessToken: link.OAuthAccessToken,
7373
OAuthRefreshToken: link.OAuthRefreshToken,
7474
OAuthExpiry: time.Now().Add(time.Hour * -1),
@@ -77,7 +77,7 @@ func (*LoginHelper) ExpireOauthToken(t *testing.T, db database.Store, user *code
7777
})
7878
require.NoError(t, err, "expire user link")
7979

80-
return link.OAuthRefreshToken
80+
return updated
8181
}
8282

8383
// ForceRefresh forces the client to refresh its oauth token. It does this by
@@ -88,13 +88,13 @@ func (*LoginHelper) ExpireOauthToken(t *testing.T, db database.Store, user *code
8888
func (h *LoginHelper) ForceRefresh(t *testing.T, db database.Store, user *codersdk.Client, idToken jwt.MapClaims) {
8989
t.Helper()
9090

91-
refreshToken := h.ExpireOauthToken(t, db, user)
91+
link := h.ExpireOauthToken(t, db, user)
9292
// Updates the claims that the IDP will return. By default, it always
9393
// uses the original claims for the original oauth token.
94-
h.fake.UpdateRefreshClaims(refreshToken, idToken)
94+
h.fake.UpdateRefreshClaims(link.OAuthRefreshToken, idToken)
9595

9696
t.Cleanup(func() {
97-
require.True(t, h.fake.RefreshUsed(refreshToken), "refresh token must be used, but has not. Did you forget to call the returned function from this call?")
97+
require.True(t, h.fake.RefreshUsed(link.OAuthRefreshToken), "refresh token must be used, but has not. Did you forget to call the returned function from this call?")
9898
})
9999

100100
// Do any authenticated call to force the refresh

coderd/coderdtest/oidctest/idp.go

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,13 @@ type FakeIDP struct {
6060
refreshIDTokenClaims *SyncMap[string, jwt.MapClaims]
6161

6262
// hooks
63-
hookUserInfo func(email string) jwt.MapClaims
64-
fakeCoderd func(req *http.Request) (*http.Response, error)
65-
hookOnRefresh func(email string) error
63+
// hookValidRedirectURL can be used to reject a redirect url from the
64+
// IDP -> Application. Almost all IDPs have the concept of
65+
// "Authorized Redirect URLs". This can be used to emulate that.
66+
hookValidRedirectURL func(redirectURL string) error
67+
hookUserInfo func(email string) jwt.MapClaims
68+
fakeCoderd func(req *http.Request) (*http.Response, error)
69+
hookOnRefresh func(email string) error
6670
// Custom authentication for the client. This is useful if you want
6771
// to test something like PKI auth vs a client_secret.
6872
hookAuthenticateClient func(t testing.TB, req *http.Request) (url.Values, error)
@@ -74,6 +78,12 @@ type FakeIDP struct {
7478

7579
type FakeIDPOpt func(idp *FakeIDP)
7680

81+
func WithAuthorizedRedirectURL(hook func(redirectURL string) error) func(*FakeIDP) {
82+
return func(f *FakeIDP) {
83+
f.hookValidRedirectURL = hook
84+
}
85+
}
86+
7787
// WithRefreshHook is called when a refresh token is used. The email is
7888
// the email of the user that is being refreshed assuming the claims are correct.
7989
func WithRefreshHook(hook func(email string) error) func(*FakeIDP) {
@@ -421,6 +431,8 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
421431
// This endpoint is required to initialize the OIDC provider.
422432
// It is used to get the OIDC configuration.
423433
mux.Get("/.well-known/openid-configuration", func(rw http.ResponseWriter, r *http.Request) {
434+
f.logger.Info(r.Context(), "HTTP OIDC Config", slog.F("url", r.URL.String()))
435+
424436
_ = json.NewEncoder(rw).Encode(f.provider)
425437
})
426438

@@ -429,19 +441,19 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
429441
// w/e and clicking "Allow". They will be redirected back to the redirect
430442
// when this is done.
431443
mux.Handle(authorizePath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
432-
f.logger.Info(r.Context(), "HTTP Call Authorize", slog.F("url", string(r.URL.String())))
444+
f.logger.Info(r.Context(), "HTTP Call Authorize", slog.F("url", r.URL.String()))
433445

434446
clientID := r.URL.Query().Get("client_id")
435-
if clientID != f.clientID {
436-
t.Errorf("unexpected client_id %q", clientID)
447+
if !assert.Equal(t, f.clientID, clientID, "unexpected client_id") {
437448
http.Error(rw, "invalid client_id", http.StatusBadRequest)
449+
return
438450
}
439451

440452
redirectURI := r.URL.Query().Get("redirect_uri")
441453
state := r.URL.Query().Get("state")
442454

443455
scope := r.URL.Query().Get("scope")
444-
_ = scope
456+
assert.NotEmpty(t, scope, "scope is empty")
445457

446458
responseType := r.URL.Query().Get("response_type")
447459
switch responseType {
@@ -456,10 +468,17 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
456468
return
457469
}
458470

471+
err := f.hookValidRedirectURL(redirectURI)
472+
if err != nil {
473+
t.Errorf("not authorized redirect_uri by custom hook %q: %s", redirectURI, err.Error())
474+
http.Error(rw, fmt.Sprintf("invalid redirect_uri: %s", err.Error()), http.StatusBadRequest)
475+
return
476+
}
477+
459478
ru, err := url.Parse(redirectURI)
460479
if err != nil {
461-
t.Errorf("invalid redirect_uri %q", redirectURI)
462-
http.Error(rw, "invalid redirect_uri", http.StatusBadRequest)
480+
t.Errorf("invalid redirect_uri %q: %s", redirectURI, err.Error())
481+
http.Error(rw, fmt.Sprintf("invalid redirect_uri: %s", err.Error()), http.StatusBadRequest)
463482
return
464483
}
465484

@@ -573,6 +592,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
573592
token, err := f.authenticateBearerTokenRequest(t, r)
574593
f.logger.Info(r.Context(), "HTTP Call UserInfo",
575594
slog.Error(err),
595+
slog.F("url", r.URL.String()),
576596
)
577597
if err != nil {
578598
http.Error(rw, fmt.Sprintf("invalid user info request: %s", err.Error()), http.StatusBadRequest)

0 commit comments

Comments
 (0)