Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions coderd/coderd.go
Original file line number Diff line number Diff line change
Expand Up @@ -693,15 +693,13 @@ func New(options *Options) *API {
r.Route("/github", func(r chi.Router) {
r.Use(
httpmw.ExtractOAuth2(options.GithubOAuth2Config, options.HTTPClient, nil),
apiKeyMiddlewareOptional,
)
r.Get("/callback", api.userOAuth2Github)
})
})
r.Route("/oidc/callback", func(r chi.Router) {
r.Use(
httpmw.ExtractOAuth2(options.OIDCConfig, options.HTTPClient, oidcAuthURLParams),
apiKeyMiddlewareOptional,
)
r.Get("/", api.userOIDC)
})
Expand Down
59 changes: 53 additions & 6 deletions coderd/coderdtest/coderdtest.go
Original file line number Diff line number Diff line change
Expand Up @@ -1022,9 +1022,31 @@ func NewAWSInstanceIdentity(t *testing.T, instanceID string) (awsidentity.Certif
type OIDCConfig struct {
key *rsa.PrivateKey
issuer string
// These are optional
refreshToken string
oidcTokenExpires func() time.Time
tokenSource func() (*oauth2.Token, error)
}

func NewOIDCConfig(t *testing.T, issuer string) *OIDCConfig {
func WithRefreshToken(token string) func(cfg *OIDCConfig) {
return func(cfg *OIDCConfig) {
cfg.refreshToken = token
}
}

func WithTokenExpires(expFunc func() time.Time) func(cfg *OIDCConfig) {
return func(cfg *OIDCConfig) {
cfg.oidcTokenExpires = expFunc
}
}

func WithTokenSource(src func() (*oauth2.Token, error)) func(cfg *OIDCConfig) {
return func(cfg *OIDCConfig) {
cfg.tokenSource = src
}
}

func NewOIDCConfig(t *testing.T, issuer string, opts ...func(cfg *OIDCConfig)) *OIDCConfig {
t.Helper()

block, _ := pem.Decode([]byte(testRSAPrivateKey))
Expand All @@ -1035,27 +1057,52 @@ func NewOIDCConfig(t *testing.T, issuer string) *OIDCConfig {
issuer = "https://coder.com"
}

return &OIDCConfig{
cfg := &OIDCConfig{
key: pkey,
issuer: issuer,
}
for _, opt := range opts {
opt(cfg)
}
return cfg
}

func (*OIDCConfig) AuthCodeURL(state string, _ ...oauth2.AuthCodeOption) string {
return "/?state=" + url.QueryEscape(state)
}

func (*OIDCConfig) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource {
return nil
type tokenSource struct {
src func() (*oauth2.Token, error)
}

func (s tokenSource) Token() (*oauth2.Token, error) {
return s.src()
}

func (*OIDCConfig) Exchange(_ context.Context, code string, _ ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
func (cfg *OIDCConfig) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource {
if cfg.tokenSource == nil {
return nil
}
return tokenSource{
src: cfg.tokenSource,
}
}

func (cfg *OIDCConfig) Exchange(_ context.Context, code string, _ ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
token, err := base64.StdEncoding.DecodeString(code)
if err != nil {
return nil, xerrors.Errorf("decode code: %w", err)
}

var exp time.Time
if cfg.oidcTokenExpires != nil {
exp = cfg.oidcTokenExpires()
}

return (&oauth2.Token{
AccessToken: "token",
AccessToken: "token",
RefreshToken: cfg.refreshToken,
Expiry: exp,
}).WithExtra(map[string]interface{}{
"id_token": string(token),
}), nil
Expand Down
100 changes: 55 additions & 45 deletions coderd/httpmw/apikey.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,56 @@ func ExtractAPIKeyMW(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler {
}
}

func APIKeyFromRequest(ctx context.Context, db database.Store, sessionTokenFunc func(r *http.Request) string, r *http.Request) (*database.APIKey, codersdk.Response, bool) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of this could you have just added a field to the opts struct above "NoRefreshToken" or something?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could yea, but I didn't want to keep using that middleware for just the key extraction. It could evolve more features overtime and the use case in this situation doesn't ever change/evovle.

tokenFunc := APITokenFromRequest
if sessionTokenFunc != nil {
tokenFunc = sessionTokenFunc
}

token := tokenFunc(r)
if token == "" {
return nil, codersdk.Response{
Message: SignedOutErrorMessage,
Detail: fmt.Sprintf("Cookie %q or query parameter must be provided.", codersdk.SessionTokenCookie),
}, false
}

keyID, keySecret, err := SplitAPIToken(token)
if err != nil {
return nil, codersdk.Response{
Message: SignedOutErrorMessage,
Detail: "Invalid API key format: " + err.Error(),
}, false
}

//nolint:gocritic // System needs to fetch API key to check if it's valid.
key, err := db.GetAPIKeyByID(dbauthz.AsSystemRestricted(ctx), keyID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, codersdk.Response{
Message: SignedOutErrorMessage,
Detail: "API key is invalid.",
}, false
}

return nil, codersdk.Response{
Message: internalErrorMessage,
Detail: fmt.Sprintf("Internal error fetching API key by id. %s", err.Error()),
}, false
}

// Checking to see if the secret is valid.
hashedSecret := sha256.Sum256([]byte(keySecret))
if subtle.ConstantTimeCompare(key.HashedSecret, hashedSecret[:]) != 1 {
return nil, codersdk.Response{
Message: SignedOutErrorMessage,
Detail: "API key secret is invalid.",
}, false
}

return &key, codersdk.Response{}, true
}

// ExtractAPIKey requires authentication using a valid API key. It handles
// extending an API key if it comes close to expiry, updating the last used time
// in the database.
Expand Down Expand Up @@ -179,49 +229,9 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
return nil, nil, false
}

tokenFunc := APITokenFromRequest
if cfg.SessionTokenFunc != nil {
tokenFunc = cfg.SessionTokenFunc
}
token := tokenFunc(r)
if token == "" {
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
Message: SignedOutErrorMessage,
Detail: fmt.Sprintf("Cookie %q or query parameter must be provided.", codersdk.SessionTokenCookie),
})
}

keyID, keySecret, err := SplitAPIToken(token)
if err != nil {
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
Message: SignedOutErrorMessage,
Detail: "Invalid API key format: " + err.Error(),
})
}

//nolint:gocritic // System needs to fetch API key to check if it's valid.
key, err := cfg.DB.GetAPIKeyByID(dbauthz.AsSystemRestricted(ctx), keyID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
Message: SignedOutErrorMessage,
Detail: "API key is invalid.",
})
}

return write(http.StatusInternalServerError, codersdk.Response{
Message: internalErrorMessage,
Detail: fmt.Sprintf("Internal error fetching API key by id. %s", err.Error()),
})
}

// Checking to see if the secret is valid.
hashedSecret := sha256.Sum256([]byte(keySecret))
if subtle.ConstantTimeCompare(key.HashedSecret, hashedSecret[:]) != 1 {
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
Message: SignedOutErrorMessage,
Detail: "API key secret is invalid.",
})
key, resp, ok := APIKeyFromRequest(ctx, cfg.DB, cfg.SessionTokenFunc, r)
if !ok {
return optionalWrite(http.StatusUnauthorized, resp)
}

var (
Expand All @@ -232,7 +242,7 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
)
if key.LoginType == database.LoginTypeGithub || key.LoginType == database.LoginTypeOIDC {
//nolint:gocritic // System needs to fetch UserLink to check if it's valid.
link, err = cfg.DB.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{
link, err := cfg.DB.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{
UserID: key.UserID,
LoginType: key.LoginType,
})
Expand Down Expand Up @@ -427,7 +437,7 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
}.WithCachedASTValue(),
}

return &key, &authz, true
return key, &authz, true
}

// APITokenFromRequest returns the api token from the request.
Expand Down
7 changes: 5 additions & 2 deletions coderd/userauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -1427,7 +1427,8 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C
}

var key database.APIKey
if oldKey, ok := httpmw.APIKeyOptional(r); ok && isConvertLoginType {
oldKey, _, ok := httpmw.APIKeyFromRequest(ctx, api.Database, nil, r)
if ok && oldKey != nil && isConvertLoginType {
// If this is a convert login type, and it succeeds, then delete the old
// session. Force the user to log back in.
err := api.Database.DeleteAPIKeyByID(r.Context(), oldKey.ID)
Expand All @@ -1447,7 +1448,9 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C
Secure: api.SecureAuthCookie,
HttpOnly: true,
})
key = oldKey
// This is intentional setting the key to the deleted old key,
// as the user needs to be forced to log back in.
key = *oldKey
} else {
//nolint:gocritic
cookie, newKey, err := api.createAPIKey(dbauthz.AsSystemRestricted(ctx), apikey.CreateParams{
Expand Down
92 changes: 89 additions & 3 deletions coderd/userauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"net/http/cookiejar"
"strings"
"testing"
"time"

"github.com/coreos/go-oidc/v3/oidc"
"github.com/golang-jwt/jwt"
Expand All @@ -24,12 +25,94 @@ import (
"github.com/coder/coder/coderd/audit"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbauthz"
"github.com/coder/coder/coderd/database/dbgen"
"github.com/coder/coder/coderd/database/dbtestutil"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/testutil"
)

// This test specifically tests logging in with OIDC when an expired
// OIDC session token exists.
// The token refreshing should not happen since we are reauthenticating.
func TestOIDCOauthLoginWithExisting(t *testing.T) {
conf := coderdtest.NewOIDCConfig(t, "",
// Provide a refresh token so we use the refresh token flow
coderdtest.WithRefreshToken("refresh_token"),
// We need to set the expire in the future for the first api calls.
coderdtest.WithTokenExpires(func() time.Time {
return time.Now().Add(time.Hour).UTC()
}),
// No refresh should actually happen in this test.
coderdtest.WithTokenSource(func() (*oauth2.Token, error) {
return nil, xerrors.New("token should not require refresh")
}),
)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
auditor := audit.NewMock()
const username = "alice"
claims := jwt.MapClaims{
"email": "alice@coder.com",
"email_verified": true,
"preferred_username": username,
}
config := conf.OIDCConfig(t, claims)

config.AllowSignups = true
config.IgnoreUserInfo = true
client, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
Auditor: auditor,
OIDCConfig: config,
Logger: &logger,
})

// Signup alice
resp := oidcCallback(t, client, conf.EncodeClaims(t, claims))
// Set the client to use this OIDC context
authCookie := authCookieValue(resp.Cookies())
client.SetSessionToken(authCookie)
_ = resp.Body.Close()

ctx := testutil.Context(t, testutil.WaitLong)
// Verify the user and oauth link
user, err := client.User(ctx, "me")
require.NoError(t, err)
require.Equal(t, username, user.Username)

// nolint:gocritic
link, err := api.Database.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{
UserID: user.ID,
LoginType: database.LoginType(user.LoginType),
})
require.NoError(t, err, "failed to get user link")

// Expire the link
// nolint:gocritic
_, err = api.Database.UpdateUserLink(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLinkParams{
OAuthAccessToken: link.OAuthAccessToken,
OAuthRefreshToken: link.OAuthRefreshToken,
OAuthExpiry: time.Now().Add(time.Hour * -1).UTC(),
UserID: link.UserID,
LoginType: link.LoginType,
})
require.NoError(t, err, "failed to update user link")

// Log in again with OIDC
loginAgain := oidcCallbackWithState(t, client, conf.EncodeClaims(t, claims), "seconds_login", func(req *http.Request) {
req.AddCookie(&http.Cookie{
Name: codersdk.SessionTokenCookie,
Value: authCookie,
Path: "/",
})
})
require.Equal(t, http.StatusTemporaryRedirect, loginAgain.StatusCode)

// Try to use new login
client.SetSessionToken(authCookieValue(resp.Cookies()))
_, err = client.User(ctx, "me")
require.NoError(t, err, "use new session")
}

func TestUserLogin(t *testing.T) {
t.Parallel()
t.Run("OK", func(t *testing.T) {
Expand Down Expand Up @@ -819,7 +902,7 @@ func TestUserOIDC(t *testing.T) {
})
require.NoError(t, err)

resp := oidcCallbackWithState(t, user, code, convertResponse.StateString)
resp := oidcCallbackWithState(t, user, code, convertResponse.StateString, nil)
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
})

Expand Down Expand Up @@ -1045,10 +1128,10 @@ func oauth2Callback(t *testing.T, client *codersdk.Client) *http.Response {
}

func oidcCallback(t *testing.T, client *codersdk.Client, code string) *http.Response {
return oidcCallbackWithState(t, client, code, "somestate")
return oidcCallbackWithState(t, client, code, "somestate", nil)
}

func oidcCallbackWithState(t *testing.T, client *codersdk.Client, code, state string) *http.Response {
func oidcCallbackWithState(t *testing.T, client *codersdk.Client, code, state string, modify func(r *http.Request)) *http.Response {
t.Helper()

client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
Expand All @@ -1062,6 +1145,9 @@ func oidcCallbackWithState(t *testing.T, client *codersdk.Client, code, state st
Name: codersdk.OAuth2StateCookie,
Value: state,
})
if modify != nil {
modify(req)
}
res, err := client.HTTPClient.Do(req)
require.NoError(t, err)
defer res.Body.Close()
Expand Down