7
7
"crypto/x509"
8
8
"encoding/json"
9
9
"encoding/pem"
10
+ "errors"
10
11
"fmt"
11
12
"io"
12
13
"net"
@@ -66,7 +67,7 @@ type FakeIDP struct {
66
67
// IDP -> Application. Almost all IDPs have the concept of
67
68
// "Authorized Redirect URLs". This can be used to emulate that.
68
69
hookValidRedirectURL func (redirectURL string ) error
69
- hookUserInfo func (email string ) jwt.MapClaims
70
+ hookUserInfo func (email string ) ( jwt.MapClaims , error )
70
71
fakeCoderd func (req * http.Request ) (* http.Response , error )
71
72
hookOnRefresh func (email string ) error
72
73
// Custom authentication for the client. This is useful if you want
@@ -75,6 +76,26 @@ type FakeIDP struct {
75
76
serve bool
76
77
}
77
78
79
+ func StatusError (code int , err error ) error {
80
+ return statusHookError {
81
+ Err : err ,
82
+ HTTPStatusCode : code ,
83
+ }
84
+ }
85
+
86
+ // statusHookError allows a hook to change the returned http status code.
87
+ type statusHookError struct {
88
+ Err error
89
+ HTTPStatusCode int
90
+ }
91
+
92
+ func (s statusHookError ) Error () string {
93
+ if s .Err == nil {
94
+ return ""
95
+ }
96
+ return s .Err .Error ()
97
+ }
98
+
78
99
type FakeIDPOpt func (idp * FakeIDP )
79
100
80
101
func WithAuthorizedRedirectURL (hook func (redirectURL string ) error ) func (* FakeIDP ) {
@@ -108,13 +129,13 @@ func WithLogging(t testing.TB, options *slogtest.Options) func(*FakeIDP) {
108
129
// every user on the /userinfo endpoint.
109
130
func WithStaticUserInfo (info jwt.MapClaims ) func (* FakeIDP ) {
110
131
return func (f * FakeIDP ) {
111
- f .hookUserInfo = func (_ string ) jwt.MapClaims {
112
- return info
132
+ f .hookUserInfo = func (_ string ) ( jwt.MapClaims , error ) {
133
+ return info , nil
113
134
}
114
135
}
115
136
}
116
137
117
- func WithDynamicUserInfo (userInfoFunc func (email string ) jwt.MapClaims ) func (* FakeIDP ) {
138
+ func WithDynamicUserInfo (userInfoFunc func (email string ) ( jwt.MapClaims , error ) ) func (* FakeIDP ) {
118
139
return func (f * FakeIDP ) {
119
140
f .hookUserInfo = userInfoFunc
120
141
}
@@ -160,7 +181,7 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
160
181
stateToIDTokenClaims : syncmap .New [string , jwt.MapClaims ](),
161
182
refreshIDTokenClaims : syncmap .New [string , jwt.MapClaims ](),
162
183
hookOnRefresh : func (_ string ) error { return nil },
163
- hookUserInfo : func (email string ) jwt.MapClaims { return jwt.MapClaims {} },
184
+ hookUserInfo : func (email string ) ( jwt.MapClaims , error ) { return jwt.MapClaims {}, nil },
164
185
hookValidRedirectURL : func (redirectURL string ) error { return nil },
165
186
}
166
187
@@ -489,7 +510,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
489
510
err := f .hookValidRedirectURL (redirectURI )
490
511
if err != nil {
491
512
t .Errorf ("not authorized redirect_uri by custom hook %q: %s" , redirectURI , err .Error ())
492
- http .Error (rw , fmt .Sprintf ("invalid redirect_uri: %s" , err .Error ()), http .StatusBadRequest )
513
+ http .Error (rw , fmt .Sprintf ("invalid redirect_uri: %s" , err .Error ()), httpErrorCode ( http .StatusBadRequest , err ) )
493
514
return
494
515
}
495
516
@@ -515,7 +536,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
515
536
slog .F ("values" , values .Encode ()),
516
537
)
517
538
if err != nil {
518
- http .Error (rw , fmt .Sprintf ("invalid token request: %s" , err .Error ()), http .StatusBadRequest )
539
+ http .Error (rw , fmt .Sprintf ("invalid token request: %s" , err .Error ()), httpErrorCode ( http .StatusBadRequest , err ) )
519
540
return
520
541
}
521
542
getEmail := func (claims jwt.MapClaims ) string {
@@ -576,7 +597,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
576
597
claims = idTokenClaims
577
598
err := f .hookOnRefresh (getEmail (claims ))
578
599
if err != nil {
579
- http .Error (rw , fmt .Sprintf ("refresh hook blocked refresh: %s" , err .Error ()), http .StatusBadRequest )
600
+ http .Error (rw , fmt .Sprintf ("refresh hook blocked refresh: %s" , err .Error ()), httpErrorCode ( http .StatusBadRequest , err ) )
580
601
return
581
602
}
582
603
@@ -624,7 +645,12 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
624
645
http .Error (rw , "invalid access token, missing user info" , http .StatusBadRequest )
625
646
return
626
647
}
627
- _ = json .NewEncoder (rw ).Encode (f .hookUserInfo (email ))
648
+ claims , err := f .hookUserInfo (email )
649
+ if err != nil {
650
+ http .Error (rw , fmt .Sprintf ("user info hook returned error: %s" , err .Error ()), httpErrorCode (http .StatusBadRequest , err ))
651
+ return
652
+ }
653
+ _ = json .NewEncoder (rw ).Encode (claims )
628
654
}))
629
655
630
656
mux .Handle (keysPath , http .HandlerFunc (func (rw http.ResponseWriter , r * http.Request ) {
@@ -782,6 +808,15 @@ func (f *FakeIDP) OIDCConfig(t testing.TB, scopes []string, opts ...func(cfg *co
782
808
return cfg
783
809
}
784
810
811
+ func httpErrorCode (defaultCode int , err error ) int {
812
+ var stautsErr statusHookError
813
+ var status = defaultCode
814
+ if errors .As (err , & stautsErr ) {
815
+ status = stautsErr .HTTPStatusCode
816
+ }
817
+ return status
818
+ }
819
+
785
820
type fakeRoundTripper struct {
786
821
roundTrip func (req * http.Request ) (* http.Response , error )
787
822
}
0 commit comments