Skip to content

fix: remove refresh oauth logic on OIDC login #8950

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions coderd/coderd.go
Original file line number Diff line number Diff line change
Expand Up @@ -693,15 +693,13 @@ func New(options *Options) *API {
r.Route("/github", func(r chi.Router) {
r.Use(
httpmw.ExtractOAuth2(options.GithubOAuth2Config, options.HTTPClient, nil),
apiKeyMiddlewareOptional,
)
r.Get("/callback", api.userOAuth2Github)
})
})
r.Route("/oidc/callback", func(r chi.Router) {
r.Use(
httpmw.ExtractOAuth2(options.OIDCConfig, options.HTTPClient, oidcAuthURLParams),
apiKeyMiddlewareOptional,
)
r.Get("/", api.userOIDC)
})
Expand Down
79 changes: 63 additions & 16 deletions coderd/coderdtest/coderdtest.go
Original file line number Diff line number Diff line change
Expand Up @@ -1022,9 +1022,31 @@ func NewAWSInstanceIdentity(t *testing.T, instanceID string) (awsidentity.Certif
type OIDCConfig struct {
key *rsa.PrivateKey
issuer string
// These are optional
refreshToken string
oidcTokenExpires func() time.Time
tokenSource func() (*oauth2.Token, error)
}

func NewOIDCConfig(t *testing.T, issuer string) *OIDCConfig {
func WithRefreshToken(token string) func(cfg *OIDCConfig) {
return func(cfg *OIDCConfig) {
cfg.refreshToken = token
}
}

func WithTokenExpires(expFunc func() time.Time) func(cfg *OIDCConfig) {
return func(cfg *OIDCConfig) {
cfg.oidcTokenExpires = expFunc
}
}

func WithTokenSource(src func() (*oauth2.Token, error)) func(cfg *OIDCConfig) {
return func(cfg *OIDCConfig) {
cfg.tokenSource = src
}
}

func NewOIDCConfig(t *testing.T, issuer string, opts ...func(cfg *OIDCConfig)) *OIDCConfig {
t.Helper()

block, _ := pem.Decode([]byte(testRSAPrivateKey))
Expand All @@ -1035,54 +1057,79 @@ func NewOIDCConfig(t *testing.T, issuer string) *OIDCConfig {
issuer = "https://coder.com"
}

return &OIDCConfig{
cfg := &OIDCConfig{
key: pkey,
issuer: issuer,
}
for _, opt := range opts {
opt(cfg)
}
return cfg
}

func (*OIDCConfig) AuthCodeURL(state string, _ ...oauth2.AuthCodeOption) string {
return "/?state=" + url.QueryEscape(state)
}

func (*OIDCConfig) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource {
return nil
type tokenSource struct {
src func() (*oauth2.Token, error)
}

func (s tokenSource) Token() (*oauth2.Token, error) {
return s.src()
}

func (cfg *OIDCConfig) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource {
if cfg.tokenSource == nil {
return nil
}
return tokenSource{
src: cfg.tokenSource,
}
}

func (*OIDCConfig) Exchange(_ context.Context, code string, _ ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
func (cfg *OIDCConfig) Exchange(_ context.Context, code string, _ ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
token, err := base64.StdEncoding.DecodeString(code)
if err != nil {
return nil, xerrors.Errorf("decode code: %w", err)
}

var exp time.Time
if cfg.oidcTokenExpires != nil {
exp = cfg.oidcTokenExpires()
}

return (&oauth2.Token{
AccessToken: "token",
AccessToken: "token",
RefreshToken: cfg.refreshToken,
Expiry: exp,
}).WithExtra(map[string]interface{}{
"id_token": string(token),
}), nil
}

func (o *OIDCConfig) EncodeClaims(t *testing.T, claims jwt.MapClaims) string {
func (cfg *OIDCConfig) EncodeClaims(t *testing.T, claims jwt.MapClaims) string {
t.Helper()

if _, ok := claims["exp"]; !ok {
claims["exp"] = time.Now().Add(time.Hour).UnixMilli()
}

if _, ok := claims["iss"]; !ok {
claims["iss"] = o.issuer
claims["iss"] = cfg.issuer
}

if _, ok := claims["sub"]; !ok {
claims["sub"] = "testme"
}

signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(o.key)
signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(cfg.key)
require.NoError(t, err)

return base64.StdEncoding.EncodeToString([]byte(signed))
}

func (o *OIDCConfig) OIDCConfig(t *testing.T, userInfoClaims jwt.MapClaims, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig {
func (cfg *OIDCConfig) OIDCConfig(t *testing.T, userInfoClaims jwt.MapClaims, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig {
// By default, the provider can be empty.
// This means it won't support any endpoints!
provider := &oidc.Provider{}
Expand All @@ -1099,10 +1146,10 @@ func (o *OIDCConfig) OIDCConfig(t *testing.T, userInfoClaims jwt.MapClaims, opts
}
provider = cfg.NewProvider(context.Background())
}
cfg := &coderd.OIDCConfig{
OAuth2Config: o,
Verifier: oidc.NewVerifier(o.issuer, &oidc.StaticKeySet{
PublicKeys: []crypto.PublicKey{o.key.Public()},
newCFG := &coderd.OIDCConfig{
OAuth2Config: cfg,
Verifier: oidc.NewVerifier(cfg.issuer, &oidc.StaticKeySet{
PublicKeys: []crypto.PublicKey{cfg.key.Public()},
}, &oidc.Config{
SkipClientIDCheck: true,
}),
Expand All @@ -1113,9 +1160,9 @@ func (o *OIDCConfig) OIDCConfig(t *testing.T, userInfoClaims jwt.MapClaims, opts
GroupField: "groups",
}
for _, opt := range opts {
opt(cfg)
opt(newCFG)
}
return cfg
return newCFG
}

// NewAzureInstanceIdentity returns a metadata client and ID token validator for faking
Expand Down
100 changes: 55 additions & 45 deletions coderd/httpmw/apikey.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,56 @@ func ExtractAPIKeyMW(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler {
}
}

func APIKeyFromRequest(ctx context.Context, db database.Store, sessionTokenFunc func(r *http.Request) string, r *http.Request) (*database.APIKey, codersdk.Response, bool) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of this could you have just added a field to the opts struct above "NoRefreshToken" or something?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could yea, but I didn't want to keep using that middleware for just the key extraction. It could evolve more features overtime and the use case in this situation doesn't ever change/evovle.

tokenFunc := APITokenFromRequest
if sessionTokenFunc != nil {
tokenFunc = sessionTokenFunc
}

token := tokenFunc(r)
if token == "" {
return nil, codersdk.Response{
Message: SignedOutErrorMessage,
Detail: fmt.Sprintf("Cookie %q or query parameter must be provided.", codersdk.SessionTokenCookie),
}, false
}

keyID, keySecret, err := SplitAPIToken(token)
if err != nil {
return nil, codersdk.Response{
Message: SignedOutErrorMessage,
Detail: "Invalid API key format: " + err.Error(),
}, false
}

//nolint:gocritic // System needs to fetch API key to check if it's valid.
key, err := db.GetAPIKeyByID(dbauthz.AsSystemRestricted(ctx), keyID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, codersdk.Response{
Message: SignedOutErrorMessage,
Detail: "API key is invalid.",
}, false
}

return nil, codersdk.Response{
Message: internalErrorMessage,
Detail: fmt.Sprintf("Internal error fetching API key by id. %s", err.Error()),
}, false
}

// Checking to see if the secret is valid.
hashedSecret := sha256.Sum256([]byte(keySecret))
if subtle.ConstantTimeCompare(key.HashedSecret, hashedSecret[:]) != 1 {
return nil, codersdk.Response{
Message: SignedOutErrorMessage,
Detail: "API key secret is invalid.",
}, false
}

return &key, codersdk.Response{}, true
}

// ExtractAPIKey requires authentication using a valid API key. It handles
// extending an API key if it comes close to expiry, updating the last used time
// in the database.
Expand Down Expand Up @@ -179,49 +229,9 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
return nil, nil, false
}

tokenFunc := APITokenFromRequest
if cfg.SessionTokenFunc != nil {
tokenFunc = cfg.SessionTokenFunc
}
token := tokenFunc(r)
if token == "" {
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
Message: SignedOutErrorMessage,
Detail: fmt.Sprintf("Cookie %q or query parameter must be provided.", codersdk.SessionTokenCookie),
})
}

keyID, keySecret, err := SplitAPIToken(token)
if err != nil {
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
Message: SignedOutErrorMessage,
Detail: "Invalid API key format: " + err.Error(),
})
}

//nolint:gocritic // System needs to fetch API key to check if it's valid.
key, err := cfg.DB.GetAPIKeyByID(dbauthz.AsSystemRestricted(ctx), keyID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
Message: SignedOutErrorMessage,
Detail: "API key is invalid.",
})
}

return write(http.StatusInternalServerError, codersdk.Response{
Message: internalErrorMessage,
Detail: fmt.Sprintf("Internal error fetching API key by id. %s", err.Error()),
})
}

// Checking to see if the secret is valid.
hashedSecret := sha256.Sum256([]byte(keySecret))
if subtle.ConstantTimeCompare(key.HashedSecret, hashedSecret[:]) != 1 {
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
Message: SignedOutErrorMessage,
Detail: "API key secret is invalid.",
})
key, resp, ok := APIKeyFromRequest(ctx, cfg.DB, cfg.SessionTokenFunc, r)
if !ok {
return optionalWrite(http.StatusUnauthorized, resp)
}

var (
Expand All @@ -232,7 +242,7 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
)
if key.LoginType == database.LoginTypeGithub || key.LoginType == database.LoginTypeOIDC {
//nolint:gocritic // System needs to fetch UserLink to check if it's valid.
link, err = cfg.DB.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{
link, err := cfg.DB.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{
UserID: key.UserID,
LoginType: key.LoginType,
})
Expand Down Expand Up @@ -427,7 +437,7 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
}.WithCachedASTValue(),
}

return &key, &authz, true
return key, &authz, true
}

// APITokenFromRequest returns the api token from the request.
Expand Down
7 changes: 5 additions & 2 deletions coderd/userauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -1427,7 +1427,8 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C
}

var key database.APIKey
if oldKey, ok := httpmw.APIKeyOptional(r); ok && isConvertLoginType {
oldKey, _, ok := httpmw.APIKeyFromRequest(ctx, api.Database, nil, r)
if ok && oldKey != nil && isConvertLoginType {
// If this is a convert login type, and it succeeds, then delete the old
// session. Force the user to log back in.
err := api.Database.DeleteAPIKeyByID(r.Context(), oldKey.ID)
Expand All @@ -1447,7 +1448,9 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C
Secure: api.SecureAuthCookie,
HttpOnly: true,
})
key = oldKey
// This is intentional setting the key to the deleted old key,
// as the user needs to be forced to log back in.
key = *oldKey
} else {
//nolint:gocritic
cookie, newKey, err := api.createAPIKey(dbauthz.AsSystemRestricted(ctx), apikey.CreateParams{
Expand Down
Loading