Skip to content

Commit 5339a31

Browse files
authored
fix: remove refresh oauth logic on OIDC login (#8950)
* fix: do not do oauth refresh logic on oidc login
1 parent 1d4a72f commit 5339a31

File tree

6 files changed

+217
-68
lines changed

6 files changed

+217
-68
lines changed

coderd/coderd.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -693,15 +693,13 @@ func New(options *Options) *API {
693693
r.Route("/github", func(r chi.Router) {
694694
r.Use(
695695
httpmw.ExtractOAuth2(options.GithubOAuth2Config, options.HTTPClient, nil),
696-
apiKeyMiddlewareOptional,
697696
)
698697
r.Get("/callback", api.userOAuth2Github)
699698
})
700699
})
701700
r.Route("/oidc/callback", func(r chi.Router) {
702701
r.Use(
703702
httpmw.ExtractOAuth2(options.OIDCConfig, options.HTTPClient, oidcAuthURLParams),
704-
apiKeyMiddlewareOptional,
705703
)
706704
r.Get("/", api.userOIDC)
707705
})

coderd/coderdtest/coderdtest.go

Lines changed: 63 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1022,9 +1022,31 @@ func NewAWSInstanceIdentity(t *testing.T, instanceID string) (awsidentity.Certif
10221022
type OIDCConfig struct {
10231023
key *rsa.PrivateKey
10241024
issuer string
1025+
// These are optional
1026+
refreshToken string
1027+
oidcTokenExpires func() time.Time
1028+
tokenSource func() (*oauth2.Token, error)
10251029
}
10261030

1027-
func NewOIDCConfig(t *testing.T, issuer string) *OIDCConfig {
1031+
func WithRefreshToken(token string) func(cfg *OIDCConfig) {
1032+
return func(cfg *OIDCConfig) {
1033+
cfg.refreshToken = token
1034+
}
1035+
}
1036+
1037+
func WithTokenExpires(expFunc func() time.Time) func(cfg *OIDCConfig) {
1038+
return func(cfg *OIDCConfig) {
1039+
cfg.oidcTokenExpires = expFunc
1040+
}
1041+
}
1042+
1043+
func WithTokenSource(src func() (*oauth2.Token, error)) func(cfg *OIDCConfig) {
1044+
return func(cfg *OIDCConfig) {
1045+
cfg.tokenSource = src
1046+
}
1047+
}
1048+
1049+
func NewOIDCConfig(t *testing.T, issuer string, opts ...func(cfg *OIDCConfig)) *OIDCConfig {
10281050
t.Helper()
10291051

10301052
block, _ := pem.Decode([]byte(testRSAPrivateKey))
@@ -1035,54 +1057,79 @@ func NewOIDCConfig(t *testing.T, issuer string) *OIDCConfig {
10351057
issuer = "https://coder.com"
10361058
}
10371059

1038-
return &OIDCConfig{
1060+
cfg := &OIDCConfig{
10391061
key: pkey,
10401062
issuer: issuer,
10411063
}
1064+
for _, opt := range opts {
1065+
opt(cfg)
1066+
}
1067+
return cfg
10421068
}
10431069

10441070
func (*OIDCConfig) AuthCodeURL(state string, _ ...oauth2.AuthCodeOption) string {
10451071
return "/?state=" + url.QueryEscape(state)
10461072
}
10471073

1048-
func (*OIDCConfig) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource {
1049-
return nil
1074+
type tokenSource struct {
1075+
src func() (*oauth2.Token, error)
1076+
}
1077+
1078+
func (s tokenSource) Token() (*oauth2.Token, error) {
1079+
return s.src()
1080+
}
1081+
1082+
func (cfg *OIDCConfig) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource {
1083+
if cfg.tokenSource == nil {
1084+
return nil
1085+
}
1086+
return tokenSource{
1087+
src: cfg.tokenSource,
1088+
}
10501089
}
10511090

1052-
func (*OIDCConfig) Exchange(_ context.Context, code string, _ ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
1091+
func (cfg *OIDCConfig) Exchange(_ context.Context, code string, _ ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
10531092
token, err := base64.StdEncoding.DecodeString(code)
10541093
if err != nil {
10551094
return nil, xerrors.Errorf("decode code: %w", err)
10561095
}
1096+
1097+
var exp time.Time
1098+
if cfg.oidcTokenExpires != nil {
1099+
exp = cfg.oidcTokenExpires()
1100+
}
1101+
10571102
return (&oauth2.Token{
1058-
AccessToken: "token",
1103+
AccessToken: "token",
1104+
RefreshToken: cfg.refreshToken,
1105+
Expiry: exp,
10591106
}).WithExtra(map[string]interface{}{
10601107
"id_token": string(token),
10611108
}), nil
10621109
}
10631110

1064-
func (o *OIDCConfig) EncodeClaims(t *testing.T, claims jwt.MapClaims) string {
1111+
func (cfg *OIDCConfig) EncodeClaims(t *testing.T, claims jwt.MapClaims) string {
10651112
t.Helper()
10661113

10671114
if _, ok := claims["exp"]; !ok {
10681115
claims["exp"] = time.Now().Add(time.Hour).UnixMilli()
10691116
}
10701117

10711118
if _, ok := claims["iss"]; !ok {
1072-
claims["iss"] = o.issuer
1119+
claims["iss"] = cfg.issuer
10731120
}
10741121

10751122
if _, ok := claims["sub"]; !ok {
10761123
claims["sub"] = "testme"
10771124
}
10781125

1079-
signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(o.key)
1126+
signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(cfg.key)
10801127
require.NoError(t, err)
10811128

10821129
return base64.StdEncoding.EncodeToString([]byte(signed))
10831130
}
10841131

1085-
func (o *OIDCConfig) OIDCConfig(t *testing.T, userInfoClaims jwt.MapClaims, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig {
1132+
func (cfg *OIDCConfig) OIDCConfig(t *testing.T, userInfoClaims jwt.MapClaims, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig {
10861133
// By default, the provider can be empty.
10871134
// This means it won't support any endpoints!
10881135
provider := &oidc.Provider{}
@@ -1099,10 +1146,10 @@ func (o *OIDCConfig) OIDCConfig(t *testing.T, userInfoClaims jwt.MapClaims, opts
10991146
}
11001147
provider = cfg.NewProvider(context.Background())
11011148
}
1102-
cfg := &coderd.OIDCConfig{
1103-
OAuth2Config: o,
1104-
Verifier: oidc.NewVerifier(o.issuer, &oidc.StaticKeySet{
1105-
PublicKeys: []crypto.PublicKey{o.key.Public()},
1149+
newCFG := &coderd.OIDCConfig{
1150+
OAuth2Config: cfg,
1151+
Verifier: oidc.NewVerifier(cfg.issuer, &oidc.StaticKeySet{
1152+
PublicKeys: []crypto.PublicKey{cfg.key.Public()},
11061153
}, &oidc.Config{
11071154
SkipClientIDCheck: true,
11081155
}),
@@ -1113,9 +1160,9 @@ func (o *OIDCConfig) OIDCConfig(t *testing.T, userInfoClaims jwt.MapClaims, opts
11131160
GroupField: "groups",
11141161
}
11151162
for _, opt := range opts {
1116-
opt(cfg)
1163+
opt(newCFG)
11171164
}
1118-
return cfg
1165+
return newCFG
11191166
}
11201167

11211168
// NewAzureInstanceIdentity returns a metadata client and ID token validator for faking

coderd/httpmw/apikey.go

Lines changed: 55 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,56 @@ func ExtractAPIKeyMW(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler {
142142
}
143143
}
144144

145+
func APIKeyFromRequest(ctx context.Context, db database.Store, sessionTokenFunc func(r *http.Request) string, r *http.Request) (*database.APIKey, codersdk.Response, bool) {
146+
tokenFunc := APITokenFromRequest
147+
if sessionTokenFunc != nil {
148+
tokenFunc = sessionTokenFunc
149+
}
150+
151+
token := tokenFunc(r)
152+
if token == "" {
153+
return nil, codersdk.Response{
154+
Message: SignedOutErrorMessage,
155+
Detail: fmt.Sprintf("Cookie %q or query parameter must be provided.", codersdk.SessionTokenCookie),
156+
}, false
157+
}
158+
159+
keyID, keySecret, err := SplitAPIToken(token)
160+
if err != nil {
161+
return nil, codersdk.Response{
162+
Message: SignedOutErrorMessage,
163+
Detail: "Invalid API key format: " + err.Error(),
164+
}, false
165+
}
166+
167+
//nolint:gocritic // System needs to fetch API key to check if it's valid.
168+
key, err := db.GetAPIKeyByID(dbauthz.AsSystemRestricted(ctx), keyID)
169+
if err != nil {
170+
if errors.Is(err, sql.ErrNoRows) {
171+
return nil, codersdk.Response{
172+
Message: SignedOutErrorMessage,
173+
Detail: "API key is invalid.",
174+
}, false
175+
}
176+
177+
return nil, codersdk.Response{
178+
Message: internalErrorMessage,
179+
Detail: fmt.Sprintf("Internal error fetching API key by id. %s", err.Error()),
180+
}, false
181+
}
182+
183+
// Checking to see if the secret is valid.
184+
hashedSecret := sha256.Sum256([]byte(keySecret))
185+
if subtle.ConstantTimeCompare(key.HashedSecret, hashedSecret[:]) != 1 {
186+
return nil, codersdk.Response{
187+
Message: SignedOutErrorMessage,
188+
Detail: "API key secret is invalid.",
189+
}, false
190+
}
191+
192+
return &key, codersdk.Response{}, true
193+
}
194+
145195
// ExtractAPIKey requires authentication using a valid API key. It handles
146196
// extending an API key if it comes close to expiry, updating the last used time
147197
// in the database.
@@ -179,49 +229,9 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
179229
return nil, nil, false
180230
}
181231

182-
tokenFunc := APITokenFromRequest
183-
if cfg.SessionTokenFunc != nil {
184-
tokenFunc = cfg.SessionTokenFunc
185-
}
186-
token := tokenFunc(r)
187-
if token == "" {
188-
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
189-
Message: SignedOutErrorMessage,
190-
Detail: fmt.Sprintf("Cookie %q or query parameter must be provided.", codersdk.SessionTokenCookie),
191-
})
192-
}
193-
194-
keyID, keySecret, err := SplitAPIToken(token)
195-
if err != nil {
196-
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
197-
Message: SignedOutErrorMessage,
198-
Detail: "Invalid API key format: " + err.Error(),
199-
})
200-
}
201-
202-
//nolint:gocritic // System needs to fetch API key to check if it's valid.
203-
key, err := cfg.DB.GetAPIKeyByID(dbauthz.AsSystemRestricted(ctx), keyID)
204-
if err != nil {
205-
if errors.Is(err, sql.ErrNoRows) {
206-
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
207-
Message: SignedOutErrorMessage,
208-
Detail: "API key is invalid.",
209-
})
210-
}
211-
212-
return write(http.StatusInternalServerError, codersdk.Response{
213-
Message: internalErrorMessage,
214-
Detail: fmt.Sprintf("Internal error fetching API key by id. %s", err.Error()),
215-
})
216-
}
217-
218-
// Checking to see if the secret is valid.
219-
hashedSecret := sha256.Sum256([]byte(keySecret))
220-
if subtle.ConstantTimeCompare(key.HashedSecret, hashedSecret[:]) != 1 {
221-
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
222-
Message: SignedOutErrorMessage,
223-
Detail: "API key secret is invalid.",
224-
})
232+
key, resp, ok := APIKeyFromRequest(ctx, cfg.DB, cfg.SessionTokenFunc, r)
233+
if !ok {
234+
return optionalWrite(http.StatusUnauthorized, resp)
225235
}
226236

227237
var (
@@ -232,7 +242,7 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
232242
)
233243
if key.LoginType == database.LoginTypeGithub || key.LoginType == database.LoginTypeOIDC {
234244
//nolint:gocritic // System needs to fetch UserLink to check if it's valid.
235-
link, err = cfg.DB.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{
245+
link, err := cfg.DB.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{
236246
UserID: key.UserID,
237247
LoginType: key.LoginType,
238248
})
@@ -427,7 +437,7 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
427437
}.WithCachedASTValue(),
428438
}
429439

430-
return &key, &authz, true
440+
return key, &authz, true
431441
}
432442

433443
// APITokenFromRequest returns the api token from the request.

coderd/userauth.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1427,7 +1427,8 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C
14271427
}
14281428

14291429
var key database.APIKey
1430-
if oldKey, ok := httpmw.APIKeyOptional(r); ok && isConvertLoginType {
1430+
oldKey, _, ok := httpmw.APIKeyFromRequest(ctx, api.Database, nil, r)
1431+
if ok && oldKey != nil && isConvertLoginType {
14311432
// If this is a convert login type, and it succeeds, then delete the old
14321433
// session. Force the user to log back in.
14331434
err := api.Database.DeleteAPIKeyByID(r.Context(), oldKey.ID)
@@ -1447,7 +1448,9 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C
14471448
Secure: api.SecureAuthCookie,
14481449
HttpOnly: true,
14491450
})
1450-
key = oldKey
1451+
// This is intentional setting the key to the deleted old key,
1452+
// as the user needs to be forced to log back in.
1453+
key = *oldKey
14511454
} else {
14521455
//nolint:gocritic
14531456
cookie, newKey, err := api.createAPIKey(dbauthz.AsSystemRestricted(ctx), apikey.CreateParams{

0 commit comments

Comments
 (0)