Skip to content

fix: make 'NoRefresh' honor unlimited tokens in gitauth #9472

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Sep 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 64 additions & 15 deletions coderd/coderdtest/oidctest/idp.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"crypto/x509"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -41,7 +42,7 @@ import (
type FakeIDP struct {
issuer string
key *rsa.PrivateKey
provider providerJSON
provider ProviderJSON
handler http.Handler
cfg *oauth2.Config

Expand All @@ -66,7 +67,7 @@ type FakeIDP struct {
// IDP -> Application. Almost all IDPs have the concept of
// "Authorized Redirect URLs". This can be used to emulate that.
hookValidRedirectURL func(redirectURL string) error
hookUserInfo func(email string) jwt.MapClaims
hookUserInfo func(email string) (jwt.MapClaims, error)
fakeCoderd func(req *http.Request) (*http.Response, error)
hookOnRefresh func(email string) error
// Custom authentication for the client. This is useful if you want
Expand All @@ -75,6 +76,26 @@ type FakeIDP struct {
serve bool
}

func StatusError(code int, err error) error {
return statusHookError{
Err: err,
HTTPStatusCode: code,
}
}

// statusHookError allows a hook to change the returned http status code.
type statusHookError struct {
Err error
HTTPStatusCode int
}

func (s statusHookError) Error() string {
if s.Err == nil {
return ""
}
return s.Err.Error()
}

type FakeIDPOpt func(idp *FakeIDP)

func WithAuthorizedRedirectURL(hook func(redirectURL string) error) func(*FakeIDP) {
Expand All @@ -83,9 +104,9 @@ func WithAuthorizedRedirectURL(hook func(redirectURL string) error) func(*FakeID
}
}

// WithRefreshHook is called when a refresh token is used. The email is
// WithRefresh is called when a refresh token is used. The email is
// the email of the user that is being refreshed assuming the claims are correct.
func WithRefreshHook(hook func(email string) error) func(*FakeIDP) {
func WithRefresh(hook func(email string) error) func(*FakeIDP) {
return func(f *FakeIDP) {
f.hookOnRefresh = hook
}
Expand All @@ -108,13 +129,13 @@ func WithLogging(t testing.TB, options *slogtest.Options) func(*FakeIDP) {
// every user on the /userinfo endpoint.
func WithStaticUserInfo(info jwt.MapClaims) func(*FakeIDP) {
return func(f *FakeIDP) {
f.hookUserInfo = func(_ string) jwt.MapClaims {
return info
f.hookUserInfo = func(_ string) (jwt.MapClaims, error) {
return info, nil
}
}
}

func WithDynamicUserInfo(userInfoFunc func(email string) jwt.MapClaims) func(*FakeIDP) {
func WithDynamicUserInfo(userInfoFunc func(email string) (jwt.MapClaims, error)) func(*FakeIDP) {
return func(f *FakeIDP) {
f.hookUserInfo = userInfoFunc
}
Expand Down Expand Up @@ -160,7 +181,7 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
stateToIDTokenClaims: syncmap.New[string, jwt.MapClaims](),
refreshIDTokenClaims: syncmap.New[string, jwt.MapClaims](),
hookOnRefresh: func(_ string) error { return nil },
hookUserInfo: func(email string) jwt.MapClaims { return jwt.MapClaims{} },
hookUserInfo: func(email string) (jwt.MapClaims, error) { return jwt.MapClaims{}, nil },
hookValidRedirectURL: func(redirectURL string) error { return nil },
}

Expand All @@ -181,16 +202,20 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
return idp
}

func (f *FakeIDP) WellknownConfig() ProviderJSON {
return f.provider
}

func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) {
t.Helper()

u, err := url.Parse(issuer)
require.NoError(t, err, "invalid issuer URL")

f.issuer = issuer
// providerJSON is the JSON representation of the OpenID Connect provider
// ProviderJSON is the JSON representation of the OpenID Connect provider
// These are all the urls that the IDP will respond to.
f.provider = providerJSON{
f.provider = ProviderJSON{
Issuer: issuer,
AuthURL: u.ResolveReference(&url.URL{Path: authorizePath}).String(),
TokenURL: u.ResolveReference(&url.URL{Path: tokenPath}).String(),
Expand Down Expand Up @@ -220,6 +245,15 @@ func (f *FakeIDP) realServer(t testing.TB) *httptest.Server {
return srv
}

// GenerateAuthenticatedToken skips all oauth2 flows, and just generates a
// valid token for some given claims.
func (f *FakeIDP) GenerateAuthenticatedToken(claims jwt.MapClaims) (*oauth2.Token, error) {
state := uuid.NewString()
f.stateToIDTokenClaims.Store(state, claims)
code := f.newCode(state)
return f.cfg.Exchange(oidc.ClientContext(context.Background(), f.HTTPClient(nil)), code)
}

// Login does the full OIDC flow starting at the "LoginButton".
// The client argument is just to get the URL of the Coder instance.
//
Expand Down Expand Up @@ -333,7 +367,8 @@ func (f *FakeIDP) OIDCCallback(t testing.TB, state string, idTokenClaims jwt.Map
return resp, nil
}

type providerJSON struct {
// ProviderJSON is the .well-known/configuration JSON
type ProviderJSON struct {
Issuer string `json:"issuer"`
AuthURL string `json:"authorization_endpoint"`
TokenURL string `json:"token_endpoint"`
Expand Down Expand Up @@ -475,7 +510,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
err := f.hookValidRedirectURL(redirectURI)
if err != nil {
t.Errorf("not authorized redirect_uri by custom hook %q: %s", redirectURI, err.Error())
http.Error(rw, fmt.Sprintf("invalid redirect_uri: %s", err.Error()), http.StatusBadRequest)
http.Error(rw, fmt.Sprintf("invalid redirect_uri: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err))
return
}

Expand All @@ -501,7 +536,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
slog.F("values", values.Encode()),
)
if err != nil {
http.Error(rw, fmt.Sprintf("invalid token request: %s", err.Error()), http.StatusBadRequest)
http.Error(rw, fmt.Sprintf("invalid token request: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err))
return
}
getEmail := func(claims jwt.MapClaims) string {
Expand Down Expand Up @@ -562,7 +597,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
claims = idTokenClaims
err := f.hookOnRefresh(getEmail(claims))
if err != nil {
http.Error(rw, fmt.Sprintf("refresh hook blocked refresh: %s", err.Error()), http.StatusBadRequest)
http.Error(rw, fmt.Sprintf("refresh hook blocked refresh: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err))
return
}

Expand Down Expand Up @@ -610,7 +645,12 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
http.Error(rw, "invalid access token, missing user info", http.StatusBadRequest)
return
}
_ = json.NewEncoder(rw).Encode(f.hookUserInfo(email))
claims, err := f.hookUserInfo(email)
if err != nil {
http.Error(rw, fmt.Sprintf("user info hook returned error: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err))
return
}
_ = json.NewEncoder(rw).Encode(claims)
}))

mux.Handle(keysPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -768,6 +808,15 @@ func (f *FakeIDP) OIDCConfig(t testing.TB, scopes []string, opts ...func(cfg *co
return cfg
}

func httpErrorCode(defaultCode int, err error) int {
var stautsErr statusHookError
status := defaultCode
if errors.As(err, &stautsErr) {
status = stautsErr.HTTPStatusCode
}
return status
}

type fakeRoundTripper struct {
roundTrip func(req *http.Request) (*http.Response, error)
}
Expand Down
26 changes: 22 additions & 4 deletions coderd/gitauth/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,17 +60,30 @@ type Config struct {
}

// RefreshToken automatically refreshes the token if expired and permitted.
// It returns the token and a bool indicating if the token was refreshed.
// It returns the token and a bool indicating if the token is valid.
func (c *Config) RefreshToken(ctx context.Context, db database.Store, gitAuthLink database.GitAuthLink) (database.GitAuthLink, bool, error) {
// If the token is expired and refresh is disabled, we prompt
// the user to authenticate again.
if c.NoRefresh && gitAuthLink.OAuthExpiry.Before(dbtime.Now()) {
if c.NoRefresh &&
// If the time is set to 0, then it should never expire.
// This is true for github, which has no expiry.
!gitAuthLink.OAuthExpiry.IsZero() &&
gitAuthLink.OAuthExpiry.Before(dbtime.Now()) {
return gitAuthLink, false, nil
}

// This is additional defensive programming. Because TokenSource is an interface,
// we cannot be sure that the implementation will treat an 'IsZero' time
// as "not-expired". The default implementation does, but a custom implementation
// might not. Removing the refreshToken will guarantee a refresh will fail.
refreshToken := gitAuthLink.OAuthRefreshToken
if c.NoRefresh {
refreshToken = ""
}

token, err := c.TokenSource(ctx, &oauth2.Token{
AccessToken: gitAuthLink.OAuthAccessToken,
RefreshToken: gitAuthLink.OAuthRefreshToken,
RefreshToken: refreshToken,
Expiry: gitAuthLink.OAuthExpiry,
}).Token()
if err != nil {
Expand Down Expand Up @@ -130,8 +143,13 @@ func (c *Config) ValidateToken(ctx context.Context, token string) (bool, *coders
if err != nil {
return false, nil, err
}

cli := http.DefaultClient
if v, ok := ctx.Value(oauth2.HTTPClient).(*http.Client); ok {
cli = v
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
res, err := http.DefaultClient.Do(req)
res, err := cli.Do(req)
if err != nil {
return false, nil, err
}
Expand Down
Loading