Skip to content

Commit 4637ddb

Browse files
committed
refactor rest of gitauth tests
1 parent 212429b commit 4637ddb

File tree

3 files changed

+284
-156
lines changed

3 files changed

+284
-156
lines changed

coderd/coderdtest/oidctest/idp.go

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"crypto/x509"
88
"encoding/json"
99
"encoding/pem"
10+
"errors"
1011
"fmt"
1112
"io"
1213
"net"
@@ -66,7 +67,7 @@ type FakeIDP struct {
6667
// IDP -> Application. Almost all IDPs have the concept of
6768
// "Authorized Redirect URLs". This can be used to emulate that.
6869
hookValidRedirectURL func(redirectURL string) error
69-
hookUserInfo func(email string) jwt.MapClaims
70+
hookUserInfo func(email string) (jwt.MapClaims, error)
7071
fakeCoderd func(req *http.Request) (*http.Response, error)
7172
hookOnRefresh func(email string) error
7273
// Custom authentication for the client. This is useful if you want
@@ -75,6 +76,26 @@ type FakeIDP struct {
7576
serve bool
7677
}
7778

79+
func StatusError(code int, err error) error {
80+
return statusHookError{
81+
Err: err,
82+
HTTPStatusCode: code,
83+
}
84+
}
85+
86+
// statusHookError allows a hook to change the returned http status code.
87+
type statusHookError struct {
88+
Err error
89+
HTTPStatusCode int
90+
}
91+
92+
func (s statusHookError) Error() string {
93+
if s.Err == nil {
94+
return ""
95+
}
96+
return s.Err.Error()
97+
}
98+
7899
type FakeIDPOpt func(idp *FakeIDP)
79100

80101
func WithAuthorizedRedirectURL(hook func(redirectURL string) error) func(*FakeIDP) {
@@ -108,13 +129,13 @@ func WithLogging(t testing.TB, options *slogtest.Options) func(*FakeIDP) {
108129
// every user on the /userinfo endpoint.
109130
func WithStaticUserInfo(info jwt.MapClaims) func(*FakeIDP) {
110131
return func(f *FakeIDP) {
111-
f.hookUserInfo = func(_ string) jwt.MapClaims {
112-
return info
132+
f.hookUserInfo = func(_ string) (jwt.MapClaims, error) {
133+
return info, nil
113134
}
114135
}
115136
}
116137

117-
func WithDynamicUserInfo(userInfoFunc func(email string) jwt.MapClaims) func(*FakeIDP) {
138+
func WithDynamicUserInfo(userInfoFunc func(email string) (jwt.MapClaims, error)) func(*FakeIDP) {
118139
return func(f *FakeIDP) {
119140
f.hookUserInfo = userInfoFunc
120141
}
@@ -160,7 +181,7 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
160181
stateToIDTokenClaims: syncmap.New[string, jwt.MapClaims](),
161182
refreshIDTokenClaims: syncmap.New[string, jwt.MapClaims](),
162183
hookOnRefresh: func(_ string) error { return nil },
163-
hookUserInfo: func(email string) jwt.MapClaims { return jwt.MapClaims{} },
184+
hookUserInfo: func(email string) (jwt.MapClaims, error) { return jwt.MapClaims{}, nil },
164185
hookValidRedirectURL: func(redirectURL string) error { return nil },
165186
}
166187

@@ -489,7 +510,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
489510
err := f.hookValidRedirectURL(redirectURI)
490511
if err != nil {
491512
t.Errorf("not authorized redirect_uri by custom hook %q: %s", redirectURI, err.Error())
492-
http.Error(rw, fmt.Sprintf("invalid redirect_uri: %s", err.Error()), http.StatusBadRequest)
513+
http.Error(rw, fmt.Sprintf("invalid redirect_uri: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err))
493514
return
494515
}
495516

@@ -515,7 +536,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
515536
slog.F("values", values.Encode()),
516537
)
517538
if err != nil {
518-
http.Error(rw, fmt.Sprintf("invalid token request: %s", err.Error()), http.StatusBadRequest)
539+
http.Error(rw, fmt.Sprintf("invalid token request: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err))
519540
return
520541
}
521542
getEmail := func(claims jwt.MapClaims) string {
@@ -576,7 +597,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
576597
claims = idTokenClaims
577598
err := f.hookOnRefresh(getEmail(claims))
578599
if err != nil {
579-
http.Error(rw, fmt.Sprintf("refresh hook blocked refresh: %s", err.Error()), http.StatusBadRequest)
600+
http.Error(rw, fmt.Sprintf("refresh hook blocked refresh: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err))
580601
return
581602
}
582603

@@ -624,7 +645,12 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
624645
http.Error(rw, "invalid access token, missing user info", http.StatusBadRequest)
625646
return
626647
}
627-
_ = json.NewEncoder(rw).Encode(f.hookUserInfo(email))
648+
claims, err := f.hookUserInfo(email)
649+
if err != nil {
650+
http.Error(rw, fmt.Sprintf("user info hook returned error: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err))
651+
return
652+
}
653+
_ = json.NewEncoder(rw).Encode(claims)
628654
}))
629655

630656
mux.Handle(keysPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
@@ -782,6 +808,15 @@ func (f *FakeIDP) OIDCConfig(t testing.TB, scopes []string, opts ...func(cfg *co
782808
return cfg
783809
}
784810

811+
func httpErrorCode(defaultCode int, err error) int {
812+
var stautsErr statusHookError
813+
var status = defaultCode
814+
if errors.As(err, &stautsErr) {
815+
status = stautsErr.HTTPStatusCode
816+
}
817+
return status
818+
}
819+
785820
type fakeRoundTripper struct {
786821
roundTrip func(req *http.Request) (*http.Response, error)
787822
}

coderd/gitauth/config.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ type Config struct {
5959
}
6060

6161
// RefreshToken automatically refreshes the token if expired and permitted.
62-
// It returns the token and a bool indicating if the token was refreshed.
62+
// It returns the token and a bool indicating if the token is valid.
6363
func (c *Config) RefreshToken(ctx context.Context, db database.Store, gitAuthLink database.GitAuthLink) (database.GitAuthLink, bool, error) {
6464
// If the token is expired and refresh is disabled, we prompt
6565
// the user to authenticate again.

0 commit comments

Comments
 (0)