Skip to content

fix: Redirect to login when unauthenticated and requesting a workspace app #2903

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions coderd/coderd.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,10 @@ func New(options *Options) *API {
siteHandler: site.Handler(site.FS(), binFS),
}
api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgent, 0)

apiKeyMiddleware := httpmw.ExtractAPIKey(options.Database, &httpmw.OAuth2Configs{
oauthConfigs := &httpmw.OAuth2Configs{
Github: options.GithubOAuth2Config,
})
}
apiKeyMiddleware := httpmw.ExtractAPIKey(options.Database, oauthConfigs, false)

r.Use(
func(next http.Handler) http.Handler {
Expand All @@ -121,7 +121,7 @@ func New(options *Options) *API {
apps := func(r chi.Router) {
r.Use(
httpmw.RateLimitPerMinute(options.APIRateLimit),
apiKeyMiddleware,
httpmw.ExtractAPIKey(options.Database, oauthConfigs, true),
httpmw.ExtractUserParam(api.Database),
)
r.HandleFunc("/*", api.workspaceAppsProxyPath)
Expand Down
45 changes: 31 additions & 14 deletions coderd/httpmw/apikey.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,26 @@ type OAuth2Configs struct {
// ExtractAPIKey requires authentication using a valid API key.
// It handles extending an API key if it comes close to expiry,
// updating the last used time in the database.
func ExtractAPIKey(db database.Store, oauth *OAuth2Configs) func(http.Handler) http.Handler {
// nolint:revive
func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
// Write wraps writing a response to redirect if the handler
// specified it should. This redirect is used for user-facing
// pages like workspace applications.
write := func(code int, response httpapi.Response) {
if redirectToLogin {
q := r.URL.Query()
q.Add("message", response.Message)
q.Add("redirect", r.URL.Path+"?"+r.URL.RawQuery)
r.URL.RawQuery = q.Encode()
r.URL.Path = "/login"
http.Redirect(rw, r, r.URL.String(), http.StatusTemporaryRedirect)
return
}
httpapi.Write(rw, code, response)
}

var cookieValue string
cookie, err := r.Cookie(SessionTokenKey)
if err != nil {
Expand All @@ -67,15 +84,15 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs) func(http.Handler) h
cookieValue = cookie.Value
}
if cookieValue == "" {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
write(http.StatusUnauthorized, httpapi.Response{
Message: fmt.Sprintf("Cookie %q or query parameter must be provided.", SessionTokenKey),
})
return
}
parts := strings.Split(cookieValue, "-")
// APIKeys are formatted: ID-SECRET
if len(parts) != 2 {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
write(http.StatusUnauthorized, httpapi.Response{
Message: fmt.Sprintf("Invalid %q cookie API key format.", SessionTokenKey),
})
return
Expand All @@ -84,26 +101,26 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs) func(http.Handler) h
keySecret := parts[1]
// Ensuring key lengths are valid.
if len(keyID) != 10 {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
write(http.StatusUnauthorized, httpapi.Response{
Message: fmt.Sprintf("Invalid %q cookie API key id.", SessionTokenKey),
})
return
}
if len(keySecret) != 22 {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
write(http.StatusUnauthorized, httpapi.Response{
Message: fmt.Sprintf("Invalid %q cookie API key secret.", SessionTokenKey),
})
return
}
key, err := db.GetAPIKeyByID(r.Context(), keyID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
write(http.StatusUnauthorized, httpapi.Response{
Message: "API key is invalid.",
})
return
}
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
write(http.StatusInternalServerError, httpapi.Response{
Message: "Internal error fetching API key by id.",
Detail: err.Error(),
})
Expand All @@ -113,7 +130,7 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs) func(http.Handler) h

// Checking to see if the secret is valid.
if subtle.ConstantTimeCompare(key.HashedSecret, hashed[:]) != 1 {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
write(http.StatusUnauthorized, httpapi.Response{
Message: "API key secret is invalid.",
})
return
Expand All @@ -130,7 +147,7 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs) func(http.Handler) h
case database.LoginTypeGithub:
oauthConfig = oauth.Github
default:
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
write(http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("Unexpected authentication type %q.", key.LoginType),
})
return
Expand All @@ -142,7 +159,7 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs) func(http.Handler) h
Expiry: key.OAuthExpiry,
}).Token()
if err != nil {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
write(http.StatusUnauthorized, httpapi.Response{
Message: "Could not refresh expired Oauth token.",
Detail: err.Error(),
})
Expand All @@ -158,7 +175,7 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs) func(http.Handler) h

// Checking if the key is expired.
if key.ExpiresAt.Before(now) {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
write(http.StatusUnauthorized, httpapi.Response{
Message: fmt.Sprintf("API key expired at %q.", key.ExpiresAt.String()),
})
return
Expand Down Expand Up @@ -200,7 +217,7 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs) func(http.Handler) h
OAuthExpiry: key.OAuthExpiry,
})
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
write(http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("API key couldn't update: %s.", err.Error()),
})
return
Expand All @@ -212,15 +229,15 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs) func(http.Handler) h
// is to block 'suspended' users from accessing the platform.
roles, err := db.GetAuthorizationUserRoles(r.Context(), key.UserID)
if err != nil {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
write(http.StatusUnauthorized, httpapi.Response{
Message: "Internal error fetching user's roles.",
Detail: err.Error(),
})
return
}

if roles.Status != database.UserStatusActive {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
write(http.StatusUnauthorized, httpapi.Response{
Message: fmt.Sprintf("User is not active (status = %q). Contact an admin to reactivate your account.", roles.Status),
})
return
Expand Down
44 changes: 30 additions & 14 deletions coderd/httpmw/apikey_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,28 @@ func TestAPIKey(t *testing.T) {
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
})

t.Run("NoCookieRedirects", func(t *testing.T) {
t.Parallel()
var (
db = databasefake.New()
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
httpmw.ExtractAPIKey(db, nil, true)(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
location, err := res.Location()
require.NoError(t, err)
require.NotEmpty(t, location.Query().Get("message"))
require.Equal(t, http.StatusTemporaryRedirect, res.StatusCode)
})

t.Run("InvalidFormat", func(t *testing.T) {
t.Parallel()
var (
Expand All @@ -62,7 +78,7 @@ func TestAPIKey(t *testing.T) {
Value: "test-wow-hello",
})

httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
Expand All @@ -80,7 +96,7 @@ func TestAPIKey(t *testing.T) {
Value: "test-wow",
})

httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
Expand All @@ -98,7 +114,7 @@ func TestAPIKey(t *testing.T) {
Value: "testtestid-wow",
})

httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
Expand All @@ -117,7 +133,7 @@ func TestAPIKey(t *testing.T) {
Value: fmt.Sprintf("%s-%s", id, secret),
})

httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
Expand Down Expand Up @@ -145,7 +161,7 @@ func TestAPIKey(t *testing.T) {
UserID: user.ID,
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
Expand All @@ -172,7 +188,7 @@ func TestAPIKey(t *testing.T) {
UserID: user.ID,
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
Expand Down Expand Up @@ -200,7 +216,7 @@ func TestAPIKey(t *testing.T) {
UserID: user.ID,
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
httpmw.ExtractAPIKey(db, nil, false)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
// Checks that it exists on the context!
_ = httpmw.APIKey(r)
httpapi.Write(rw, http.StatusOK, httpapi.Response{
Expand Down Expand Up @@ -238,7 +254,7 @@ func TestAPIKey(t *testing.T) {
UserID: user.ID,
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
httpmw.ExtractAPIKey(db, nil, false)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
// Checks that it exists on the context!
_ = httpmw.APIKey(r)
httpapi.Write(rw, http.StatusOK, httpapi.Response{
Expand Down Expand Up @@ -273,7 +289,7 @@ func TestAPIKey(t *testing.T) {
UserID: user.ID,
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode)
Expand Down Expand Up @@ -308,7 +324,7 @@ func TestAPIKey(t *testing.T) {
UserID: user.ID,
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode)
Expand Down Expand Up @@ -344,7 +360,7 @@ func TestAPIKey(t *testing.T) {
UserID: user.ID,
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode)
Expand Down Expand Up @@ -391,7 +407,7 @@ func TestAPIKey(t *testing.T) {
return token, nil
}),
},
})(successHandler).ServeHTTP(rw, r)
}, false)(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode)
Expand Down Expand Up @@ -428,7 +444,7 @@ func TestAPIKey(t *testing.T) {
UserID: user.ID,
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode)
Expand Down
2 changes: 1 addition & 1 deletion coderd/httpmw/authorize_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func TestExtractUserRoles(t *testing.T) {
rtr = chi.NewRouter()
)
rtr.Use(
httpmw.ExtractAPIKey(db, &httpmw.OAuth2Configs{}),
httpmw.ExtractAPIKey(db, &httpmw.OAuth2Configs{}, false),
)
rtr.Get("/", func(_ http.ResponseWriter, r *http.Request) {
roles := httpmw.AuthorizationUserRoles(r)
Expand Down
10 changes: 5 additions & 5 deletions coderd/httpmw/organizationparam_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func TestOrganizationParam(t *testing.T) {
rtr = chi.NewRouter()
)
rtr.Use(
httpmw.ExtractAPIKey(db, nil),
httpmw.ExtractAPIKey(db, nil, false),
httpmw.ExtractOrganizationParam(db),
)
rtr.Get("/", nil)
Expand All @@ -87,7 +87,7 @@ func TestOrganizationParam(t *testing.T) {
)
chi.RouteContext(r.Context()).URLParams.Add("organization", uuid.NewString())
rtr.Use(
httpmw.ExtractAPIKey(db, nil),
httpmw.ExtractAPIKey(db, nil, false),
httpmw.ExtractOrganizationParam(db),
)
rtr.Get("/", nil)
Expand All @@ -107,7 +107,7 @@ func TestOrganizationParam(t *testing.T) {
)
chi.RouteContext(r.Context()).URLParams.Add("organization", "not-a-uuid")
rtr.Use(
httpmw.ExtractAPIKey(db, nil),
httpmw.ExtractAPIKey(db, nil, false),
httpmw.ExtractOrganizationParam(db),
)
rtr.Get("/", nil)
Expand Down Expand Up @@ -135,7 +135,7 @@ func TestOrganizationParam(t *testing.T) {
chi.RouteContext(r.Context()).URLParams.Add("organization", organization.ID.String())
chi.RouteContext(r.Context()).URLParams.Add("user", u.ID.String())
rtr.Use(
httpmw.ExtractAPIKey(db, nil),
httpmw.ExtractAPIKey(db, nil, false),
httpmw.ExtractUserParam(db),
httpmw.ExtractOrganizationParam(db),
httpmw.ExtractOrganizationMemberParam(db),
Expand Down Expand Up @@ -172,7 +172,7 @@ func TestOrganizationParam(t *testing.T) {
chi.RouteContext(r.Context()).URLParams.Add("organization", organization.ID.String())
chi.RouteContext(r.Context()).URLParams.Add("user", user.ID.String())
rtr.Use(
httpmw.ExtractAPIKey(db, nil),
httpmw.ExtractAPIKey(db, nil, false),
httpmw.ExtractOrganizationParam(db),
httpmw.ExtractUserParam(db),
httpmw.ExtractOrganizationMemberParam(db),
Expand Down
2 changes: 1 addition & 1 deletion coderd/httpmw/templateparam_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ func TestTemplateParam(t *testing.T) {
db := databasefake.New()
rtr := chi.NewRouter()
rtr.Use(
httpmw.ExtractAPIKey(db, nil),
httpmw.ExtractAPIKey(db, nil, false),
httpmw.ExtractTemplateParam(db),
httpmw.ExtractOrganizationParam(db),
)
Expand Down
2 changes: 1 addition & 1 deletion coderd/httpmw/templateversionparam_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ func TestTemplateVersionParam(t *testing.T) {
db := databasefake.New()
rtr := chi.NewRouter()
rtr.Use(
httpmw.ExtractAPIKey(db, nil),
httpmw.ExtractAPIKey(db, nil, false),
httpmw.ExtractTemplateVersionParam(db),
httpmw.ExtractOrganizationParam(db),
)
Expand Down
Loading