Skip to content

Commit 60ddcf5

Browse files
authored
chore: improve testing coverage on ExtractProvisionerDaemonAuthenticated middleware (coder#15622)
This one aims to resolve coder#15604 Created some table tests for the main cases - also preferred to create two isolated cases for the most complicated cases in order to keep table tests simple enough. Give us full coverage on the middleware logic, for both optional and non optional cases - PSK and ProvisionerKey.
1 parent d60b588 commit 60ddcf5

File tree

3 files changed

+298
-0
lines changed

3 files changed

+298
-0
lines changed

coderd/httpmw/provisionerdaemon.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ type ExtractProvisionerAuthConfig struct {
2525
PSK string
2626
}
2727

28+
// ExtractProvisionerDaemonAuthenticated authenticates a request as a provisioner daemon.
29+
// If the request is not authenticated, the next handler is called unless Optional is true.
30+
// This function currently is tested inside the enterprise package.
2831
func ExtractProvisionerDaemonAuthenticated(opts ExtractProvisionerAuthConfig) func(next http.Handler) http.Handler {
2932
return func(next http.Handler) http.Handler {
3033
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

enterprise/coderd/httpmw/doc.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
// Package httpmw contains middleware for HTTP handlers.
2+
// Currently, the tested middleware is inside the AGPL package.
3+
// As the middleware also contains enterprise-only logic, tests had to be
4+
// moved here.
5+
package httpmw
Lines changed: 290 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,290 @@
1+
package httpmw_test
2+
3+
import (
4+
"context"
5+
"net/http"
6+
"net/http/httptest"
7+
"testing"
8+
9+
"github.com/go-chi/chi/v5"
10+
"github.com/google/uuid"
11+
"github.com/stretchr/testify/require"
12+
"go.uber.org/mock/gomock"
13+
"golang.org/x/xerrors"
14+
15+
"github.com/coder/coder/v2/coderd/database"
16+
"github.com/coder/coder/v2/coderd/database/dbmock"
17+
"github.com/coder/coder/v2/coderd/httpmw"
18+
"github.com/coder/coder/v2/codersdk"
19+
"github.com/coder/coder/v2/enterprise/coderd/coderdenttest"
20+
"github.com/coder/coder/v2/enterprise/coderd/license"
21+
"github.com/coder/coder/v2/testutil"
22+
)
23+
24+
func TestExtractProvisionerDaemonAuthenticated(t *testing.T) {
25+
const (
26+
//nolint:gosec // test key generated by test
27+
functionalKey = "5Hl2Qw9kX3nM7vB4jR8pY6tA1cF0eD5uI2oL9gN3mZ4"
28+
)
29+
t.Parallel()
30+
31+
tests := []struct {
32+
name string
33+
opts httpmw.ExtractProvisionerAuthConfig
34+
expectedStatusCode int
35+
expectedResponseMessage string
36+
provisionerKey string
37+
provisionerPSK string
38+
}{
39+
{
40+
name: "NoKeyProvided_Optional",
41+
opts: httpmw.ExtractProvisionerAuthConfig{
42+
DB: nil,
43+
Optional: true,
44+
},
45+
expectedStatusCode: http.StatusOK,
46+
},
47+
{
48+
name: "NoKeyProvided_NotOptional",
49+
opts: httpmw.ExtractProvisionerAuthConfig{
50+
DB: nil,
51+
Optional: false,
52+
},
53+
expectedStatusCode: http.StatusUnauthorized,
54+
expectedResponseMessage: "provisioner daemon key required",
55+
},
56+
{
57+
name: "ProvisionerKeyAndPSKProvided_NotOptional",
58+
opts: httpmw.ExtractProvisionerAuthConfig{
59+
DB: nil,
60+
Optional: false,
61+
},
62+
provisionerKey: "key",
63+
provisionerPSK: "psk",
64+
expectedStatusCode: http.StatusBadRequest,
65+
expectedResponseMessage: "provisioner daemon key and psk provided, but only one is allowed",
66+
},
67+
{
68+
name: "ProvisionerKeyAndPSKProvided_Optional",
69+
opts: httpmw.ExtractProvisionerAuthConfig{
70+
DB: nil,
71+
Optional: true,
72+
},
73+
provisionerKey: "key",
74+
expectedStatusCode: http.StatusOK,
75+
},
76+
{
77+
name: "InvalidProvisionerKey_NotOptional",
78+
opts: httpmw.ExtractProvisionerAuthConfig{
79+
DB: nil,
80+
Optional: false,
81+
},
82+
provisionerKey: "invalid",
83+
expectedStatusCode: http.StatusBadRequest,
84+
expectedResponseMessage: "provisioner daemon key invalid",
85+
},
86+
{
87+
name: "InvalidProvisionerKey_Optional",
88+
opts: httpmw.ExtractProvisionerAuthConfig{
89+
DB: nil,
90+
Optional: true,
91+
},
92+
provisionerKey: "invalid",
93+
expectedStatusCode: http.StatusOK,
94+
},
95+
{
96+
name: "InvalidProvisionerPSK_NotOptional",
97+
opts: httpmw.ExtractProvisionerAuthConfig{
98+
DB: nil,
99+
Optional: false,
100+
PSK: "psk",
101+
},
102+
provisionerPSK: "invalid",
103+
expectedStatusCode: http.StatusUnauthorized,
104+
expectedResponseMessage: "provisioner daemon psk invalid",
105+
},
106+
{
107+
name: "InvalidProvisionerPSK_Optional",
108+
opts: httpmw.ExtractProvisionerAuthConfig{
109+
DB: nil,
110+
Optional: true,
111+
PSK: "psk",
112+
},
113+
provisionerPSK: "invalid",
114+
expectedStatusCode: http.StatusOK,
115+
},
116+
{
117+
name: "ValidProvisionerPSK_NotOptional",
118+
opts: httpmw.ExtractProvisionerAuthConfig{
119+
DB: nil,
120+
Optional: false,
121+
PSK: "ThisIsAValidPSK",
122+
},
123+
provisionerPSK: "ThisIsAValidPSK",
124+
expectedStatusCode: http.StatusOK,
125+
},
126+
}
127+
128+
for _, test := range tests {
129+
test := test
130+
t.Run(test.name, func(t *testing.T) {
131+
t.Parallel()
132+
routeCtx := chi.NewRouteContext()
133+
r := httptest.NewRequest(http.MethodGet, "/", nil)
134+
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeCtx))
135+
res := httptest.NewRecorder()
136+
137+
if test.provisionerKey != "" {
138+
r.Header.Set(codersdk.ProvisionerDaemonKey, test.provisionerKey)
139+
}
140+
if test.provisionerPSK != "" {
141+
r.Header.Set(codersdk.ProvisionerDaemonPSK, test.provisionerPSK)
142+
}
143+
144+
httpmw.ExtractProvisionerDaemonAuthenticated(test.opts)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
145+
w.WriteHeader(http.StatusOK)
146+
})).ServeHTTP(res, r)
147+
148+
//nolint:bodyclose
149+
require.Equal(t, test.expectedStatusCode, res.Result().StatusCode)
150+
if test.expectedResponseMessage != "" {
151+
require.Contains(t, res.Body.String(), test.expectedResponseMessage)
152+
}
153+
})
154+
}
155+
156+
t.Run("ProvisionerKey", func(t *testing.T) {
157+
t.Parallel()
158+
159+
ctx := testutil.Context(t, testutil.WaitShort)
160+
client, db, user := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{
161+
LicenseOptions: &coderdenttest.LicenseOptions{
162+
Features: license.Features{
163+
codersdk.FeatureExternalProvisionerDaemons: 1,
164+
},
165+
},
166+
})
167+
// nolint:gocritic // test
168+
key, err := client.CreateProvisionerKey(ctx, user.OrganizationID, codersdk.CreateProvisionerKeyRequest{
169+
Name: "dont-TEST-me",
170+
})
171+
require.NoError(t, err)
172+
173+
routeCtx := chi.NewRouteContext()
174+
r := httptest.NewRequest(http.MethodGet, "/", nil)
175+
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeCtx))
176+
res := httptest.NewRecorder()
177+
178+
r.Header.Set(codersdk.ProvisionerDaemonKey, key.Key)
179+
180+
httpmw.ExtractProvisionerDaemonAuthenticated(httpmw.ExtractProvisionerAuthConfig{
181+
DB: db,
182+
Optional: false,
183+
})(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
184+
w.WriteHeader(http.StatusOK)
185+
})).ServeHTTP(res, r)
186+
187+
//nolint:bodyclose
188+
require.Equal(t, http.StatusOK, res.Result().StatusCode)
189+
})
190+
191+
t.Run("ProvisionerKey_NotFound", func(t *testing.T) {
192+
t.Parallel()
193+
194+
ctx := testutil.Context(t, testutil.WaitShort)
195+
client, db, user := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{
196+
LicenseOptions: &coderdenttest.LicenseOptions{
197+
Features: license.Features{
198+
codersdk.FeatureExternalProvisionerDaemons: 1,
199+
},
200+
},
201+
})
202+
// nolint:gocritic // test
203+
_, err := client.CreateProvisionerKey(ctx, user.OrganizationID, codersdk.CreateProvisionerKeyRequest{
204+
Name: "dont-TEST-me",
205+
})
206+
require.NoError(t, err)
207+
208+
routeCtx := chi.NewRouteContext()
209+
r := httptest.NewRequest(http.MethodGet, "/", nil)
210+
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeCtx))
211+
res := httptest.NewRecorder()
212+
213+
//nolint:gosec // test key generated by test
214+
pkey := "5Hl2Qw9kX3nM7vB4jR8pY6tA1cF0eD5uI2oL9gN3mZ4"
215+
r.Header.Set(codersdk.ProvisionerDaemonKey, pkey)
216+
217+
httpmw.ExtractProvisionerDaemonAuthenticated(httpmw.ExtractProvisionerAuthConfig{
218+
DB: db,
219+
Optional: false,
220+
})(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
221+
w.WriteHeader(http.StatusOK)
222+
})).ServeHTTP(res, r)
223+
224+
//nolint:bodyclose
225+
require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode)
226+
require.Contains(t, res.Body.String(), "provisioner daemon key invalid")
227+
})
228+
229+
t.Run("ProvisionerKey_CompareFail", func(t *testing.T) {
230+
t.Parallel()
231+
232+
ctrl := gomock.NewController(t)
233+
mockDB := dbmock.NewMockStore(ctrl)
234+
235+
gomock.InOrder(
236+
mockDB.EXPECT().GetProvisionerKeyByHashedSecret(gomock.Any(), gomock.Any()).Times(1).Return(database.ProvisionerKey{
237+
ID: uuid.New(),
238+
HashedSecret: []byte("hashedSecret"),
239+
}, nil),
240+
)
241+
242+
routeCtx := chi.NewRouteContext()
243+
r := httptest.NewRequest(http.MethodGet, "/", nil)
244+
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeCtx))
245+
res := httptest.NewRecorder()
246+
247+
r.Header.Set(codersdk.ProvisionerDaemonKey, functionalKey)
248+
249+
httpmw.ExtractProvisionerDaemonAuthenticated(httpmw.ExtractProvisionerAuthConfig{
250+
DB: mockDB,
251+
Optional: false,
252+
})(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
253+
w.WriteHeader(http.StatusOK)
254+
})).ServeHTTP(res, r)
255+
256+
//nolint:bodyclose
257+
require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode)
258+
require.Contains(t, res.Body.String(), "provisioner daemon key invalid")
259+
})
260+
261+
t.Run("ProvisionerKey_DBError", func(t *testing.T) {
262+
t.Parallel()
263+
264+
ctrl := gomock.NewController(t)
265+
mockDB := dbmock.NewMockStore(ctrl)
266+
267+
gomock.InOrder(
268+
mockDB.EXPECT().GetProvisionerKeyByHashedSecret(gomock.Any(), gomock.Any()).Times(1).Return(database.ProvisionerKey{}, xerrors.New("error")),
269+
)
270+
271+
routeCtx := chi.NewRouteContext()
272+
r := httptest.NewRequest(http.MethodGet, "/", nil)
273+
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeCtx))
274+
res := httptest.NewRecorder()
275+
276+
//nolint:gosec // test key generated by test
277+
r.Header.Set(codersdk.ProvisionerDaemonKey, functionalKey)
278+
279+
httpmw.ExtractProvisionerDaemonAuthenticated(httpmw.ExtractProvisionerAuthConfig{
280+
DB: mockDB,
281+
Optional: false,
282+
})(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
283+
w.WriteHeader(http.StatusOK)
284+
})).ServeHTTP(res, r)
285+
286+
//nolint:bodyclose
287+
require.Equal(t, http.StatusInternalServerError, res.Result().StatusCode)
288+
require.Contains(t, res.Body.String(), "get provisioner daemon key")
289+
})
290+
}

0 commit comments

Comments
 (0)