Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add custom expirary
  • Loading branch information
Emyrk committed Jan 12, 2024
commit 82c81697a7eedf323c3e91ea9b183406de6fd4f1
42 changes: 38 additions & 4 deletions coderd/coderdtest/oidctest/idp.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ type FakeIDP struct {
hookAuthenticateClient func(t testing.TB, req *http.Request) (url.Values, error)
serve bool
// optional middlewares
middlewares chi.Middlewares
middlewares chi.Middlewares
defaultExpire time.Duration
}

func StatusError(code int, err error) error {
Expand Down Expand Up @@ -134,6 +135,23 @@ func WithRefresh(hook func(email string) error) func(*FakeIDP) {
}
}

func WithDefaultExpire(d time.Duration) func(*FakeIDP) {
return func(f *FakeIDP) {
f.defaultExpire = d
}
}

func WithStaticCredentials(id, secret string) func(*FakeIDP) {
return func(f *FakeIDP) {
if id != "" {
f.clientID = id
}
if secret != "" {
f.clientSecret = secret
}
}
}

// WithExtra returns extra fields that be accessed on the returned Oauth Token.
// These extra fields can override the default fields (id_token, access_token, etc).
func WithMutateToken(mutateToken func(token map[string]interface{})) func(*FakeIDP) {
Expand Down Expand Up @@ -219,6 +237,7 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
hookOnRefresh: func(_ string) error { return nil },
hookUserInfo: func(email string) (jwt.MapClaims, error) { return jwt.MapClaims{}, nil },
hookValidRedirectURL: func(redirectURL string) error { return nil },
defaultExpire: time.Minute * 5,
}

for _, opt := range opts {
Expand Down Expand Up @@ -272,8 +291,23 @@ func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) {
func (f *FakeIDP) realServer(t testing.TB) *httptest.Server {
t.Helper()

srvURL := "localhost:0"
issURL, err := url.Parse(f.issuer)
if err == nil {
if issURL.Hostname() == "localhost" || issURL.Hostname() == "127.0.0.1" {
srvURL = issURL.Host
}
}

l, err := net.Listen("tcp", srvURL)
require.NoError(t, err, "failed to create listener")

ctx, cancel := context.WithCancel(context.Background())
srv := httptest.NewUnstartedServer(f.handler)
srv := &httptest.Server{
Listener: l,
Config: &http.Server{Handler: f.handler},
}

srv.Config.BaseContext = func(_ net.Listener) context.Context {
return ctx
}
Expand Down Expand Up @@ -731,15 +765,15 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
return
}

exp := time.Now().Add(time.Minute * 5)
exp := time.Now().Add(f.defaultExpire)
claims["exp"] = exp.UnixMilli()
email := getEmail(claims)
refreshToken := f.newRefreshTokens(email)
token := map[string]interface{}{
"access_token": f.newToken(email),
"refresh_token": refreshToken,
"token_type": "Bearer",
"expires_in": int64((time.Minute * 5).Seconds()),
"expires_in": int64((f.defaultExpire).Seconds()),
"id_token": f.encodeClaims(t, claims),
}
if f.hookMutateToken != nil {
Expand Down