Skip to content

chore: improve fake IDP script #11602

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jan 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 0 additions & 58 deletions cmd/testidp/main.go

This file was deleted.

113 changes: 92 additions & 21 deletions coderd/coderdtest/oidctest/idp.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ import (
"github.com/coder/coder/v2/codersdk"
)

type token struct {
issued time.Time
email string
exp time.Time
}

// FakeIDP is a functional OIDC provider.
// It only supports 1 OIDC client.
type FakeIDP struct {
Expand All @@ -65,7 +71,7 @@ type FakeIDP struct {
// That is the various access tokens, refresh tokens, states, etc.
codeToStateMap *syncmap.Map[string, string]
// Token -> Email
accessTokens *syncmap.Map[string, string]
accessTokens *syncmap.Map[string, token]
// Refresh Token -> Email
refreshTokensUsed *syncmap.Map[string, bool]
refreshTokens *syncmap.Map[string, string]
Expand All @@ -89,7 +95,8 @@ type FakeIDP struct {
hookAuthenticateClient func(t testing.TB, req *http.Request) (url.Values, error)
serve bool
// optional middlewares
middlewares chi.Middlewares
middlewares chi.Middlewares
defaultExpire time.Duration
}

func StatusError(code int, err error) error {
Expand Down Expand Up @@ -134,6 +141,23 @@ func WithRefresh(hook func(email string) error) func(*FakeIDP) {
}
}

func WithDefaultExpire(d time.Duration) func(*FakeIDP) {
return func(f *FakeIDP) {
f.defaultExpire = d
}
}

func WithStaticCredentials(id, secret string) func(*FakeIDP) {
return func(f *FakeIDP) {
if id != "" {
f.clientID = id
}
if secret != "" {
f.clientSecret = secret
}
}
}

// WithExtra returns extra fields that be accessed on the returned Oauth Token.
// These extra fields can override the default fields (id_token, access_token, etc).
func WithMutateToken(mutateToken func(token map[string]interface{})) func(*FakeIDP) {
Expand All @@ -155,6 +179,12 @@ func WithLogging(t testing.TB, options *slogtest.Options) func(*FakeIDP) {
}
}

func WithLogger(logger slog.Logger) func(*FakeIDP) {
return func(f *FakeIDP) {
f.logger = logger
}
}

// WithStaticUserInfo is optional, but will return the same user info for
// every user on the /userinfo endpoint.
func WithStaticUserInfo(info jwt.MapClaims) func(*FakeIDP) {
Expand Down Expand Up @@ -211,14 +241,15 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
clientSecret: uuid.NewString(),
logger: slog.Make(),
codeToStateMap: syncmap.New[string, string](),
accessTokens: syncmap.New[string, string](),
accessTokens: syncmap.New[string, token](),
refreshTokens: syncmap.New[string, string](),
refreshTokensUsed: syncmap.New[string, bool](),
stateToIDTokenClaims: syncmap.New[string, jwt.MapClaims](),
refreshIDTokenClaims: syncmap.New[string, jwt.MapClaims](),
hookOnRefresh: func(_ string) error { return nil },
hookUserInfo: func(email string) (jwt.MapClaims, error) { return jwt.MapClaims{}, nil },
hookValidRedirectURL: func(redirectURL string) error { return nil },
defaultExpire: time.Minute * 5,
}

for _, opt := range opts {
Expand Down Expand Up @@ -265,15 +296,31 @@ func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) {
Algorithms: []string{
"RS256",
},
ExternalAuthURL: u.ResolveReference(&url.URL{Path: "/external-auth-validate/user"}).String(),
}
}

// realServer turns the FakeIDP into a real http server.
func (f *FakeIDP) realServer(t testing.TB) *httptest.Server {
t.Helper()

srvURL := "localhost:0"
issURL, err := url.Parse(f.issuer)
if err == nil {
if issURL.Hostname() == "localhost" || issURL.Hostname() == "127.0.0.1" {
srvURL = issURL.Host
}
}

l, err := net.Listen("tcp", srvURL)
require.NoError(t, err, "failed to create listener")

ctx, cancel := context.WithCancel(context.Background())
srv := httptest.NewUnstartedServer(f.handler)
srv := &httptest.Server{
Listener: l,
Config: &http.Server{Handler: f.handler, ReadHeaderTimeout: time.Second * 5},
}

srv.Config.BaseContext = func(_ net.Listener) context.Context {
return ctx
}
Expand Down Expand Up @@ -495,6 +542,8 @@ type ProviderJSON struct {
JWKSURL string `json:"jwks_uri"`
UserInfoURL string `json:"userinfo_endpoint"`
Algorithms []string `json:"id_token_signing_alg_values_supported"`
// This is custom
ExternalAuthURL string `json:"external_auth_url"`
}

// newCode enforces the code exchanged is actually a valid code
Expand All @@ -507,9 +556,13 @@ func (f *FakeIDP) newCode(state string) string {

// newToken enforces the access token exchanged is actually a valid access token
// created by the IDP.
func (f *FakeIDP) newToken(email string) string {
func (f *FakeIDP) newToken(email string, expires time.Time) string {
accessToken := uuid.NewString()
f.accessTokens.Store(accessToken, email)
f.accessTokens.Store(accessToken, token{
issued: time.Now(),
email: email,
exp: expires,
})
return accessToken
}

Expand All @@ -525,10 +578,15 @@ func (f *FakeIDP) authenticateBearerTokenRequest(t testing.TB, req *http.Request

auth := req.Header.Get("Authorization")
token := strings.TrimPrefix(auth, "Bearer ")
_, ok := f.accessTokens.Load(token)
authToken, ok := f.accessTokens.Load(token)
if !ok {
return "", xerrors.New("invalid access token")
}

if !authToken.exp.IsZero() && authToken.exp.Before(time.Now()) {
return "", xerrors.New("access token expired")
}

return token, nil
}

Expand Down Expand Up @@ -653,7 +711,8 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
mux.Handle(tokenPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
values, err := f.authenticateOIDCClientRequest(t, r)
f.logger.Info(r.Context(), "http idp call token",
slog.Error(err),
slog.F("valid", err == nil),
slog.F("grant_type", values.Get("grant_type")),
slog.F("values", values.Encode()),
)
if err != nil {
Expand Down Expand Up @@ -731,15 +790,15 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
return
}

exp := time.Now().Add(time.Minute * 5)
exp := time.Now().Add(f.defaultExpire)
claims["exp"] = exp.UnixMilli()
email := getEmail(claims)
refreshToken := f.newRefreshTokens(email)
token := map[string]interface{}{
"access_token": f.newToken(email),
"access_token": f.newToken(email, exp),
"refresh_token": refreshToken,
"token_type": "Bearer",
"expires_in": int64((time.Minute * 5).Seconds()),
"expires_in": int64((f.defaultExpire).Seconds()),
"id_token": f.encodeClaims(t, claims),
}
if f.hookMutateToken != nil {
Expand All @@ -754,25 +813,31 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {

validateMW := func(rw http.ResponseWriter, r *http.Request) (email string, ok bool) {
token, err := f.authenticateBearerTokenRequest(t, r)
f.logger.Info(r.Context(), "http call idp user info",
slog.Error(err),
slog.F("url", r.URL.String()),
)
if err != nil {
http.Error(rw, fmt.Sprintf("invalid user info request: %s", err.Error()), http.StatusBadRequest)
http.Error(rw, fmt.Sprintf("invalid user info request: %s", err.Error()), http.StatusUnauthorized)
return "", false
}

email, ok = f.accessTokens.Load(token)
authToken, ok := f.accessTokens.Load(token)
if !ok {
t.Errorf("access token user for user_info has no email to indicate which user")
http.Error(rw, "invalid access token, missing user info", http.StatusBadRequest)
http.Error(rw, "invalid access token, missing user info", http.StatusUnauthorized)
return "", false
}

if !authToken.exp.IsZero() && authToken.exp.Before(time.Now()) {
http.Error(rw, "auth token expired", http.StatusUnauthorized)
return "", false
}
return email, true

return authToken.email, true
}
mux.Handle(userInfoPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
email, ok := validateMW(rw, r)
f.logger.Info(r.Context(), "http userinfo endpoint",
slog.F("valid", ok),
slog.F("email", email),
)
if !ok {
return
}
Expand All @@ -790,6 +855,10 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
// should be strict, and this one needs to handle sub routes.
mux.Mount("/external-auth-validate/", http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
email, ok := validateMW(rw, r)
f.logger.Info(r.Context(), "http external auth validate",
slog.F("valid", ok),
slog.F("email", email),
)
if !ok {
return
}
Expand Down Expand Up @@ -941,7 +1010,7 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu
}
f.externalProviderID = id
f.externalAuthValidate = func(email string, rw http.ResponseWriter, r *http.Request) {
newPath := strings.TrimPrefix(r.URL.Path, fmt.Sprintf("/external-auth-validate/%s", id))
newPath := strings.TrimPrefix(r.URL.Path, "/external-auth-validate")
switch newPath {
// /user is ALWAYS supported under the `/` path too.
case "/user", "/", "":
Expand All @@ -965,18 +1034,20 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu
}
instrumentF := promoauth.NewFactory(prometheus.NewRegistry())
cfg := &externalauth.Config{
DisplayName: id,
InstrumentedOAuth2Config: instrumentF.New(f.clientID, f.OIDCConfig(t, nil)),
ID: id,
// No defaults for these fields by omitting the type
Type: "",
DisplayIcon: f.WellknownConfig().UserInfoURL,
// Omit the /user for the validate so we can easily append to it when modifying
// the cfg for advanced tests.
ValidateURL: f.issuerURL.ResolveReference(&url.URL{Path: fmt.Sprintf("/external-auth-validate/%s", id)}).String(),
ValidateURL: f.issuerURL.ResolveReference(&url.URL{Path: "/external-auth-validate/"}).String(),
}
for _, opt := range opts {
opt(cfg)
}
f.updateIssuerURL(t, f.issuer)
return cfg
}

Expand Down
2 changes: 1 addition & 1 deletion coderd/externalauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ func TestExternalAuthByID(t *testing.T) {
client := coderdtest.New(t, &coderdtest.Options{
ExternalAuthConfigs: []*externalauth.Config{
fake.ExternalAuthConfig(t, providerID, routes, func(cfg *externalauth.Config) {
cfg.AppInstallationsURL = cfg.ValidateURL + "/installs"
cfg.AppInstallationsURL = strings.TrimSuffix(cfg.ValidateURL, "/") + "/installs"
cfg.Type = codersdk.EnhancedExternalAuthProviderGitHub.String()
}),
},
Expand Down
File renamed without changes.
Loading