Skip to content

Commit 2c89e07

Browse files
authored
fix: Redirect to login when unauthenticated and requesting a workspace app (coder#2903)
Fixes coder#2884.
1 parent 08d90f7 commit 2c89e07

12 files changed

+94
-46
lines changed

coderd/coderd.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,10 @@ func New(options *Options) *API {
103103
siteHandler: site.Handler(site.FS(), binFS),
104104
}
105105
api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgent, 0)
106-
107-
apiKeyMiddleware := httpmw.ExtractAPIKey(options.Database, &httpmw.OAuth2Configs{
106+
oauthConfigs := &httpmw.OAuth2Configs{
108107
Github: options.GithubOAuth2Config,
109-
})
108+
}
109+
apiKeyMiddleware := httpmw.ExtractAPIKey(options.Database, oauthConfigs, false)
110110

111111
r.Use(
112112
func(next http.Handler) http.Handler {
@@ -121,7 +121,7 @@ func New(options *Options) *API {
121121
apps := func(r chi.Router) {
122122
r.Use(
123123
httpmw.RateLimitPerMinute(options.APIRateLimit),
124-
apiKeyMiddleware,
124+
httpmw.ExtractAPIKey(options.Database, oauthConfigs, true),
125125
httpmw.ExtractUserParam(api.Database),
126126
)
127127
r.HandleFunc("/*", api.workspaceAppsProxyPath)

coderd/httpmw/apikey.go

+31-14
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,26 @@ type OAuth2Configs struct {
5656
// ExtractAPIKey requires authentication using a valid API key.
5757
// It handles extending an API key if it comes close to expiry,
5858
// updating the last used time in the database.
59-
func ExtractAPIKey(db database.Store, oauth *OAuth2Configs) func(http.Handler) http.Handler {
59+
// nolint:revive
60+
func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool) func(http.Handler) http.Handler {
6061
return func(next http.Handler) http.Handler {
6162
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
63+
// Write wraps writing a response to redirect if the handler
64+
// specified it should. This redirect is used for user-facing
65+
// pages like workspace applications.
66+
write := func(code int, response httpapi.Response) {
67+
if redirectToLogin {
68+
q := r.URL.Query()
69+
q.Add("message", response.Message)
70+
q.Add("redirect", r.URL.Path+"?"+r.URL.RawQuery)
71+
r.URL.RawQuery = q.Encode()
72+
r.URL.Path = "/login"
73+
http.Redirect(rw, r, r.URL.String(), http.StatusTemporaryRedirect)
74+
return
75+
}
76+
httpapi.Write(rw, code, response)
77+
}
78+
6279
var cookieValue string
6380
cookie, err := r.Cookie(SessionTokenKey)
6481
if err != nil {
@@ -67,15 +84,15 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs) func(http.Handler) h
6784
cookieValue = cookie.Value
6885
}
6986
if cookieValue == "" {
70-
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
87+
write(http.StatusUnauthorized, httpapi.Response{
7188
Message: fmt.Sprintf("Cookie %q or query parameter must be provided.", SessionTokenKey),
7289
})
7390
return
7491
}
7592
parts := strings.Split(cookieValue, "-")
7693
// APIKeys are formatted: ID-SECRET
7794
if len(parts) != 2 {
78-
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
95+
write(http.StatusUnauthorized, httpapi.Response{
7996
Message: fmt.Sprintf("Invalid %q cookie API key format.", SessionTokenKey),
8097
})
8198
return
@@ -84,26 +101,26 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs) func(http.Handler) h
84101
keySecret := parts[1]
85102
// Ensuring key lengths are valid.
86103
if len(keyID) != 10 {
87-
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
104+
write(http.StatusUnauthorized, httpapi.Response{
88105
Message: fmt.Sprintf("Invalid %q cookie API key id.", SessionTokenKey),
89106
})
90107
return
91108
}
92109
if len(keySecret) != 22 {
93-
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
110+
write(http.StatusUnauthorized, httpapi.Response{
94111
Message: fmt.Sprintf("Invalid %q cookie API key secret.", SessionTokenKey),
95112
})
96113
return
97114
}
98115
key, err := db.GetAPIKeyByID(r.Context(), keyID)
99116
if err != nil {
100117
if errors.Is(err, sql.ErrNoRows) {
101-
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
118+
write(http.StatusUnauthorized, httpapi.Response{
102119
Message: "API key is invalid.",
103120
})
104121
return
105122
}
106-
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
123+
write(http.StatusInternalServerError, httpapi.Response{
107124
Message: "Internal error fetching API key by id.",
108125
Detail: err.Error(),
109126
})
@@ -113,7 +130,7 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs) func(http.Handler) h
113130

114131
// Checking to see if the secret is valid.
115132
if subtle.ConstantTimeCompare(key.HashedSecret, hashed[:]) != 1 {
116-
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
133+
write(http.StatusUnauthorized, httpapi.Response{
117134
Message: "API key secret is invalid.",
118135
})
119136
return
@@ -130,7 +147,7 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs) func(http.Handler) h
130147
case database.LoginTypeGithub:
131148
oauthConfig = oauth.Github
132149
default:
133-
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
150+
write(http.StatusInternalServerError, httpapi.Response{
134151
Message: fmt.Sprintf("Unexpected authentication type %q.", key.LoginType),
135152
})
136153
return
@@ -142,7 +159,7 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs) func(http.Handler) h
142159
Expiry: key.OAuthExpiry,
143160
}).Token()
144161
if err != nil {
145-
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
162+
write(http.StatusUnauthorized, httpapi.Response{
146163
Message: "Could not refresh expired Oauth token.",
147164
Detail: err.Error(),
148165
})
@@ -158,7 +175,7 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs) func(http.Handler) h
158175

159176
// Checking if the key is expired.
160177
if key.ExpiresAt.Before(now) {
161-
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
178+
write(http.StatusUnauthorized, httpapi.Response{
162179
Message: fmt.Sprintf("API key expired at %q.", key.ExpiresAt.String()),
163180
})
164181
return
@@ -200,7 +217,7 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs) func(http.Handler) h
200217
OAuthExpiry: key.OAuthExpiry,
201218
})
202219
if err != nil {
203-
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
220+
write(http.StatusInternalServerError, httpapi.Response{
204221
Message: fmt.Sprintf("API key couldn't update: %s.", err.Error()),
205222
})
206223
return
@@ -212,15 +229,15 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs) func(http.Handler) h
212229
// is to block 'suspended' users from accessing the platform.
213230
roles, err := db.GetAuthorizationUserRoles(r.Context(), key.UserID)
214231
if err != nil {
215-
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
232+
write(http.StatusUnauthorized, httpapi.Response{
216233
Message: "Internal error fetching user's roles.",
217234
Detail: err.Error(),
218235
})
219236
return
220237
}
221238

222239
if roles.Status != database.UserStatusActive {
223-
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
240+
write(http.StatusUnauthorized, httpapi.Response{
224241
Message: fmt.Sprintf("User is not active (status = %q). Contact an admin to reactivate your account.", roles.Status),
225242
})
226243
return

coderd/httpmw/apikey_test.go

+30-14
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,28 @@ func TestAPIKey(t *testing.T) {
4444
r = httptest.NewRequest("GET", "/", nil)
4545
rw = httptest.NewRecorder()
4646
)
47-
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
47+
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
4848
res := rw.Result()
4949
defer res.Body.Close()
5050
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
5151
})
5252

53+
t.Run("NoCookieRedirects", func(t *testing.T) {
54+
t.Parallel()
55+
var (
56+
db = databasefake.New()
57+
r = httptest.NewRequest("GET", "/", nil)
58+
rw = httptest.NewRecorder()
59+
)
60+
httpmw.ExtractAPIKey(db, nil, true)(successHandler).ServeHTTP(rw, r)
61+
res := rw.Result()
62+
defer res.Body.Close()
63+
location, err := res.Location()
64+
require.NoError(t, err)
65+
require.NotEmpty(t, location.Query().Get("message"))
66+
require.Equal(t, http.StatusTemporaryRedirect, res.StatusCode)
67+
})
68+
5369
t.Run("InvalidFormat", func(t *testing.T) {
5470
t.Parallel()
5571
var (
@@ -62,7 +78,7 @@ func TestAPIKey(t *testing.T) {
6278
Value: "test-wow-hello",
6379
})
6480

65-
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
81+
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
6682
res := rw.Result()
6783
defer res.Body.Close()
6884
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
@@ -80,7 +96,7 @@ func TestAPIKey(t *testing.T) {
8096
Value: "test-wow",
8197
})
8298

83-
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
99+
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
84100
res := rw.Result()
85101
defer res.Body.Close()
86102
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
@@ -98,7 +114,7 @@ func TestAPIKey(t *testing.T) {
98114
Value: "testtestid-wow",
99115
})
100116

101-
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
117+
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
102118
res := rw.Result()
103119
defer res.Body.Close()
104120
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
@@ -117,7 +133,7 @@ func TestAPIKey(t *testing.T) {
117133
Value: fmt.Sprintf("%s-%s", id, secret),
118134
})
119135

120-
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
136+
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
121137
res := rw.Result()
122138
defer res.Body.Close()
123139
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
@@ -145,7 +161,7 @@ func TestAPIKey(t *testing.T) {
145161
UserID: user.ID,
146162
})
147163
require.NoError(t, err)
148-
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
164+
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
149165
res := rw.Result()
150166
defer res.Body.Close()
151167
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
@@ -172,7 +188,7 @@ func TestAPIKey(t *testing.T) {
172188
UserID: user.ID,
173189
})
174190
require.NoError(t, err)
175-
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
191+
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
176192
res := rw.Result()
177193
defer res.Body.Close()
178194
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
@@ -200,7 +216,7 @@ func TestAPIKey(t *testing.T) {
200216
UserID: user.ID,
201217
})
202218
require.NoError(t, err)
203-
httpmw.ExtractAPIKey(db, nil)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
219+
httpmw.ExtractAPIKey(db, nil, false)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
204220
// Checks that it exists on the context!
205221
_ = httpmw.APIKey(r)
206222
httpapi.Write(rw, http.StatusOK, httpapi.Response{
@@ -238,7 +254,7 @@ func TestAPIKey(t *testing.T) {
238254
UserID: user.ID,
239255
})
240256
require.NoError(t, err)
241-
httpmw.ExtractAPIKey(db, nil)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
257+
httpmw.ExtractAPIKey(db, nil, false)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
242258
// Checks that it exists on the context!
243259
_ = httpmw.APIKey(r)
244260
httpapi.Write(rw, http.StatusOK, httpapi.Response{
@@ -273,7 +289,7 @@ func TestAPIKey(t *testing.T) {
273289
UserID: user.ID,
274290
})
275291
require.NoError(t, err)
276-
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
292+
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
277293
res := rw.Result()
278294
defer res.Body.Close()
279295
require.Equal(t, http.StatusOK, res.StatusCode)
@@ -308,7 +324,7 @@ func TestAPIKey(t *testing.T) {
308324
UserID: user.ID,
309325
})
310326
require.NoError(t, err)
311-
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
327+
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
312328
res := rw.Result()
313329
defer res.Body.Close()
314330
require.Equal(t, http.StatusOK, res.StatusCode)
@@ -344,7 +360,7 @@ func TestAPIKey(t *testing.T) {
344360
UserID: user.ID,
345361
})
346362
require.NoError(t, err)
347-
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
363+
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
348364
res := rw.Result()
349365
defer res.Body.Close()
350366
require.Equal(t, http.StatusOK, res.StatusCode)
@@ -391,7 +407,7 @@ func TestAPIKey(t *testing.T) {
391407
return token, nil
392408
}),
393409
},
394-
})(successHandler).ServeHTTP(rw, r)
410+
}, false)(successHandler).ServeHTTP(rw, r)
395411
res := rw.Result()
396412
defer res.Body.Close()
397413
require.Equal(t, http.StatusOK, res.StatusCode)
@@ -428,7 +444,7 @@ func TestAPIKey(t *testing.T) {
428444
UserID: user.ID,
429445
})
430446
require.NoError(t, err)
431-
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
447+
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
432448
res := rw.Result()
433449
defer res.Body.Close()
434450
require.Equal(t, http.StatusOK, res.StatusCode)

coderd/httpmw/authorize_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ func TestExtractUserRoles(t *testing.T) {
8383
rtr = chi.NewRouter()
8484
)
8585
rtr.Use(
86-
httpmw.ExtractAPIKey(db, &httpmw.OAuth2Configs{}),
86+
httpmw.ExtractAPIKey(db, &httpmw.OAuth2Configs{}, false),
8787
)
8888
rtr.Get("/", func(_ http.ResponseWriter, r *http.Request) {
8989
roles := httpmw.AuthorizationUserRoles(r)

coderd/httpmw/organizationparam_test.go

+5-5
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ func TestOrganizationParam(t *testing.T) {
6767
rtr = chi.NewRouter()
6868
)
6969
rtr.Use(
70-
httpmw.ExtractAPIKey(db, nil),
70+
httpmw.ExtractAPIKey(db, nil, false),
7171
httpmw.ExtractOrganizationParam(db),
7272
)
7373
rtr.Get("/", nil)
@@ -87,7 +87,7 @@ func TestOrganizationParam(t *testing.T) {
8787
)
8888
chi.RouteContext(r.Context()).URLParams.Add("organization", uuid.NewString())
8989
rtr.Use(
90-
httpmw.ExtractAPIKey(db, nil),
90+
httpmw.ExtractAPIKey(db, nil, false),
9191
httpmw.ExtractOrganizationParam(db),
9292
)
9393
rtr.Get("/", nil)
@@ -107,7 +107,7 @@ func TestOrganizationParam(t *testing.T) {
107107
)
108108
chi.RouteContext(r.Context()).URLParams.Add("organization", "not-a-uuid")
109109
rtr.Use(
110-
httpmw.ExtractAPIKey(db, nil),
110+
httpmw.ExtractAPIKey(db, nil, false),
111111
httpmw.ExtractOrganizationParam(db),
112112
)
113113
rtr.Get("/", nil)
@@ -135,7 +135,7 @@ func TestOrganizationParam(t *testing.T) {
135135
chi.RouteContext(r.Context()).URLParams.Add("organization", organization.ID.String())
136136
chi.RouteContext(r.Context()).URLParams.Add("user", u.ID.String())
137137
rtr.Use(
138-
httpmw.ExtractAPIKey(db, nil),
138+
httpmw.ExtractAPIKey(db, nil, false),
139139
httpmw.ExtractUserParam(db),
140140
httpmw.ExtractOrganizationParam(db),
141141
httpmw.ExtractOrganizationMemberParam(db),
@@ -172,7 +172,7 @@ func TestOrganizationParam(t *testing.T) {
172172
chi.RouteContext(r.Context()).URLParams.Add("organization", organization.ID.String())
173173
chi.RouteContext(r.Context()).URLParams.Add("user", user.ID.String())
174174
rtr.Use(
175-
httpmw.ExtractAPIKey(db, nil),
175+
httpmw.ExtractAPIKey(db, nil, false),
176176
httpmw.ExtractOrganizationParam(db),
177177
httpmw.ExtractUserParam(db),
178178
httpmw.ExtractOrganizationMemberParam(db),

coderd/httpmw/templateparam_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ func TestTemplateParam(t *testing.T) {
132132
db := databasefake.New()
133133
rtr := chi.NewRouter()
134134
rtr.Use(
135-
httpmw.ExtractAPIKey(db, nil),
135+
httpmw.ExtractAPIKey(db, nil, false),
136136
httpmw.ExtractTemplateParam(db),
137137
httpmw.ExtractOrganizationParam(db),
138138
)

coderd/httpmw/templateversionparam_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ func TestTemplateVersionParam(t *testing.T) {
124124
db := databasefake.New()
125125
rtr := chi.NewRouter()
126126
rtr.Use(
127-
httpmw.ExtractAPIKey(db, nil),
127+
httpmw.ExtractAPIKey(db, nil, false),
128128
httpmw.ExtractTemplateVersionParam(db),
129129
httpmw.ExtractOrganizationParam(db),
130130
)

0 commit comments

Comments
 (0)