Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add expires on idp
  • Loading branch information
Emyrk committed Jan 12, 2024
commit 2fb4b0c948492d5208a4034417e7f82e78cd60ad
70 changes: 51 additions & 19 deletions coderd/coderdtest/oidctest/idp.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ import (
"github.com/coder/coder/v2/codersdk"
)

type token struct {
issued time.Time
email string
exp time.Time
}

// FakeIDP is a functional OIDC provider.
// It only supports 1 OIDC client.
type FakeIDP struct {
Expand All @@ -65,7 +71,7 @@ type FakeIDP struct {
// That is the various access tokens, refresh tokens, states, etc.
codeToStateMap *syncmap.Map[string, string]
// Token -> Email
accessTokens *syncmap.Map[string, string]
accessTokens *syncmap.Map[string, token]
// Refresh Token -> Email
refreshTokensUsed *syncmap.Map[string, bool]
refreshTokens *syncmap.Map[string, string]
Expand Down Expand Up @@ -173,6 +179,12 @@ func WithLogging(t testing.TB, options *slogtest.Options) func(*FakeIDP) {
}
}

func WithLogger(logger slog.Logger) func(*FakeIDP) {
return func(f *FakeIDP) {
f.logger = logger
}
}

// WithStaticUserInfo is optional, but will return the same user info for
// every user on the /userinfo endpoint.
func WithStaticUserInfo(info jwt.MapClaims) func(*FakeIDP) {
Expand Down Expand Up @@ -229,7 +241,7 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
clientSecret: uuid.NewString(),
logger: slog.Make(),
codeToStateMap: syncmap.New[string, string](),
accessTokens: syncmap.New[string, string](),
accessTokens: syncmap.New[string, token](),
refreshTokens: syncmap.New[string, string](),
refreshTokensUsed: syncmap.New[string, bool](),
stateToIDTokenClaims: syncmap.New[string, jwt.MapClaims](),
Expand Down Expand Up @@ -284,7 +296,7 @@ func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) {
Algorithms: []string{
"RS256",
},
ExternalAuthURL: u.ResolveReference(&url.URL{Path: fmt.Sprintf("/external-auth-validate/%s", f.externalProviderID)}).String(),
ExternalAuthURL: u.ResolveReference(&url.URL{Path: "/external-auth-validate/user"}).String(),
}
}

Expand Down Expand Up @@ -417,7 +429,7 @@ func (f *FakeIDP) LoginWithClient(t testing.TB, client *codersdk.Client, idToken
// ExternalLogin does the oauth2 flow for external auth providers. This requires
// an authenticated coder client.
func (f *FakeIDP) ExternalLogin(t testing.TB, client *codersdk.Client, opts ...func(r *http.Request)) {
coderOauthURL, err := client.URL.Parse(fmt.Sprintf("/external-auth/%s/callback", f.externalProviderID))
coderOauthURL, err := client.URL.Parse("/external-auth/callback")
require.NoError(t, err)
f.SetRedirect(t, coderOauthURL.String())

Expand Down Expand Up @@ -544,9 +556,13 @@ func (f *FakeIDP) newCode(state string) string {

// newToken enforces the access token exchanged is actually a valid access token
// created by the IDP.
func (f *FakeIDP) newToken(email string) string {
func (f *FakeIDP) newToken(email string, expires time.Time) string {
accessToken := uuid.NewString()
f.accessTokens.Store(accessToken, email)
f.accessTokens.Store(accessToken, token{
issued: time.Now(),
email: email,
exp: expires,
})
return accessToken
}

Expand All @@ -562,10 +578,15 @@ func (f *FakeIDP) authenticateBearerTokenRequest(t testing.TB, req *http.Request

auth := req.Header.Get("Authorization")
token := strings.TrimPrefix(auth, "Bearer ")
_, ok := f.accessTokens.Load(token)
authToken, ok := f.accessTokens.Load(token)
if !ok {
return "", xerrors.New("invalid access token")
}

if !authToken.exp.IsZero() && authToken.exp.Before(time.Now()) {
return "", xerrors.New("access token expired")
}

return token, nil
}

Expand Down Expand Up @@ -690,7 +711,8 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
mux.Handle(tokenPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
values, err := f.authenticateOIDCClientRequest(t, r)
f.logger.Info(r.Context(), "http idp call token",
slog.Error(err),
slog.F("valid", err == nil),
slog.F("grant_type", values.Get("grant_type")),
slog.F("values", values.Encode()),
)
if err != nil {
Expand Down Expand Up @@ -773,7 +795,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
email := getEmail(claims)
refreshToken := f.newRefreshTokens(email)
token := map[string]interface{}{
"access_token": f.newToken(email),
"access_token": f.newToken(email, exp),
"refresh_token": refreshToken,
"token_type": "Bearer",
"expires_in": int64((f.defaultExpire).Seconds()),
Expand All @@ -791,25 +813,31 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {

validateMW := func(rw http.ResponseWriter, r *http.Request) (email string, ok bool) {
token, err := f.authenticateBearerTokenRequest(t, r)
f.logger.Info(r.Context(), "http call idp user info",
slog.Error(err),
slog.F("url", r.URL.String()),
)
if err != nil {
http.Error(rw, fmt.Sprintf("invalid user info request: %s", err.Error()), http.StatusBadRequest)
http.Error(rw, fmt.Sprintf("invalid user info request: %s", err.Error()), http.StatusUnauthorized)
return "", false
}

email, ok = f.accessTokens.Load(token)
authToken, ok := f.accessTokens.Load(token)
if !ok {
t.Errorf("access token user for user_info has no email to indicate which user")
http.Error(rw, "invalid access token, missing user info", http.StatusBadRequest)
http.Error(rw, "invalid access token, missing user info", http.StatusUnauthorized)
return "", false
}

if !authToken.exp.IsZero() && authToken.exp.Before(time.Now()) {
http.Error(rw, "auth token expired", http.StatusUnauthorized)
return "", false
}
return email, true

return authToken.email, true
}
mux.Handle(userInfoPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
email, ok := validateMW(rw, r)
f.logger.Info(r.Context(), "http userinfo",
slog.F("valid", ok),
slog.F("email", email),
)
if !ok {
return
}
Expand All @@ -827,6 +855,10 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
// should be strict, and this one needs to handle sub routes.
mux.Mount("/external-auth-validate/", http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
email, ok := validateMW(rw, r)
f.logger.Info(r.Context(), "http external auth validate",
slog.F("valid", ok),
slog.F("email", email),
)
if !ok {
return
}
Expand Down Expand Up @@ -978,7 +1010,7 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu
}
f.externalProviderID = id
f.externalAuthValidate = func(email string, rw http.ResponseWriter, r *http.Request) {
newPath := strings.TrimPrefix(r.URL.Path, fmt.Sprintf("/external-auth-validate/%s", id))
newPath := strings.TrimPrefix(r.URL.Path, "/external-auth-validate")
switch newPath {
// /user is ALWAYS supported under the `/` path too.
case "/user", "/", "":
Expand Down Expand Up @@ -1010,7 +1042,7 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu
DisplayIcon: f.WellknownConfig().UserInfoURL,
// Omit the /user for the validate so we can easily append to it when modifying
// the cfg for advanced tests.
ValidateURL: f.issuerURL.ResolveReference(&url.URL{Path: fmt.Sprintf("/external-auth-validate/%s", id)}).String(),
ValidateURL: f.issuerURL.ResolveReference(&url.URL{Path: "/external-auth-validate/user"}).String(),
}
for _, opt := range opts {
opt(cfg)
Expand Down