Skip to content

Commit 5eafd6d

Browse files
committed
improve testing coverage
1 parent 66bd65e commit 5eafd6d

File tree

3 files changed

+92
-19
lines changed

3 files changed

+92
-19
lines changed

coderd/httpmw/provisionerdaemon.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ type ExtractProvisionerAuthConfig struct {
2525
PSK string
2626
}
2727

28+
// ExtractProvisionerDaemonAuthenticated authenticates a request as a provisioner daemon.
29+
// If the request is not authenticated, the next handler is called unless Optional is true.
30+
// This function currently is tested inside the enterprise package.
2831
func ExtractProvisionerDaemonAuthenticated(opts ExtractProvisionerAuthConfig) func(next http.Handler) http.Handler {
2932
return func(next http.Handler) http.Handler {
3033
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

enterprise/coderd/httpmw/doc.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
// Package httpmw contains middleware for HTTP handlers.
2+
// Currently, the tested middleware is inside the AGPL package.
3+
// As the middleware also contains enterprise-only logic, tests had to be
4+
// moved here.
5+
package httpmw

enterprise/coderd/httpmw/provisionerdaemon_test.go

Lines changed: 84 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,13 @@ import (
77
"testing"
88

99
"github.com/go-chi/chi/v5"
10+
"github.com/google/uuid"
1011
"github.com/stretchr/testify/require"
12+
"go.uber.org/mock/gomock"
13+
"golang.org/x/xerrors"
1114

15+
"github.com/coder/coder/v2/coderd/database"
16+
"github.com/coder/coder/v2/coderd/database/dbmock"
1217
"github.com/coder/coder/v2/coderd/httpmw"
1318
"github.com/coder/coder/v2/codersdk"
1419
"github.com/coder/coder/v2/enterprise/coderd/coderdenttest"
@@ -17,6 +22,10 @@ import (
1722
)
1823

1924
func TestExtractProvisionerDaemonAuthenticated(t *testing.T) {
25+
const (
26+
//nolint:gosec // test key generated by test
27+
functionalKey = "5Hl2Qw9kX3nM7vB4jR8pY6tA1cF0eD5uI2oL9gN3mZ4"
28+
)
2029
t.Parallel()
2130

2231
tests := []struct {
@@ -33,8 +42,7 @@ func TestExtractProvisionerDaemonAuthenticated(t *testing.T) {
3342
DB: nil,
3443
Optional: true,
3544
},
36-
expectedStatusCode: http.StatusOK,
37-
expectedResponseMessage: "",
45+
expectedStatusCode: http.StatusOK,
3846
},
3947
{
4048
name: "NoKeyProvided_NotOptional",
@@ -62,9 +70,8 @@ func TestExtractProvisionerDaemonAuthenticated(t *testing.T) {
6270
DB: nil,
6371
Optional: true,
6472
},
65-
provisionerKey: "key",
66-
expectedStatusCode: http.StatusOK,
67-
expectedResponseMessage: "",
73+
provisionerKey: "key",
74+
expectedStatusCode: http.StatusOK,
6875
},
6976
{
7077
name: "InvalidProvisionerKey_NotOptional",
@@ -82,9 +89,7 @@ func TestExtractProvisionerDaemonAuthenticated(t *testing.T) {
8289
DB: nil,
8390
Optional: true,
8491
},
85-
provisionerKey: "invalid",
86-
expectedStatusCode: http.StatusOK,
87-
expectedResponseMessage: "",
92+
provisionerKey: "invalid",
8893
},
8994
{
9095
name: "InvalidProvisionerPSK_NotOptional",
@@ -104,9 +109,8 @@ func TestExtractProvisionerDaemonAuthenticated(t *testing.T) {
104109
Optional: true,
105110
PSK: "psk",
106111
},
107-
provisionerPSK: "invalid",
108-
expectedStatusCode: http.StatusOK,
109-
expectedResponseMessage: "",
112+
provisionerPSK: "invalid",
113+
expectedStatusCode: http.StatusOK,
110114
},
111115
{
112116
name: "ValidProvisionerPSK_NotOptional",
@@ -115,9 +119,8 @@ func TestExtractProvisionerDaemonAuthenticated(t *testing.T) {
115119
Optional: false,
116120
PSK: "ThisIsAValidPSK",
117121
},
118-
provisionerPSK: "ThisIsAValidPSK",
119-
expectedStatusCode: http.StatusOK,
120-
expectedResponseMessage: "",
122+
provisionerPSK: "ThisIsAValidPSK",
123+
expectedStatusCode: http.StatusOK,
121124
},
122125
}
123126

@@ -152,8 +155,7 @@ func TestExtractProvisionerDaemonAuthenticated(t *testing.T) {
152155
t.Run("ProvisionerKey", func(t *testing.T) {
153156
t.Parallel()
154157

155-
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
156-
defer cancel()
158+
ctx := testutil.Context(t, testutil.WaitShort)
157159
client, db, user := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{
158160
LicenseOptions: &coderdenttest.LicenseOptions{
159161
Features: license.Features{
@@ -188,8 +190,7 @@ func TestExtractProvisionerDaemonAuthenticated(t *testing.T) {
188190
t.Run("ProvisionerKey_NotFound", func(t *testing.T) {
189191
t.Parallel()
190192

191-
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
192-
defer cancel()
193+
ctx := testutil.Context(t, testutil.WaitShort)
193194
client, db, user := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{
194195
LicenseOptions: &coderdenttest.LicenseOptions{
195196
Features: license.Features{
@@ -208,7 +209,9 @@ func TestExtractProvisionerDaemonAuthenticated(t *testing.T) {
208209
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeCtx))
209210
res := httptest.NewRecorder()
210211

211-
r.Header.Set(codersdk.ProvisionerDaemonKey, "5Hl2Qw9kX3nM7vB4jR8pY6tA1cF0eD5uI2oL9gN3mZ4")
212+
//nolint:gosec // test key generated by test
213+
pkey := "5Hl2Qw9kX3nM7vB4jR8pY6tA1cF0eD5uI2oL9gN3mZ4"
214+
r.Header.Set(codersdk.ProvisionerDaemonKey, pkey)
212215

213216
httpmw.ExtractProvisionerDaemonAuthenticated(httpmw.ExtractProvisionerAuthConfig{
214217
DB: db,
@@ -221,4 +224,66 @@ func TestExtractProvisionerDaemonAuthenticated(t *testing.T) {
221224
require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode)
222225
require.Contains(t, res.Body.String(), "provisioner daemon key invalid")
223226
})
227+
228+
t.Run("ProvisionerKey_CompareFail", func(t *testing.T) {
229+
t.Parallel()
230+
231+
ctrl := gomock.NewController(t)
232+
mockDB := dbmock.NewMockStore(ctrl)
233+
234+
gomock.InOrder(
235+
mockDB.EXPECT().GetProvisionerKeyByHashedSecret(gomock.Any(), gomock.Any()).Times(1).Return(database.ProvisionerKey{
236+
ID: uuid.New(),
237+
HashedSecret: []byte("hashedSecret"),
238+
}, nil),
239+
)
240+
241+
routeCtx := chi.NewRouteContext()
242+
r := httptest.NewRequest(http.MethodGet, "/", nil)
243+
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeCtx))
244+
res := httptest.NewRecorder()
245+
246+
r.Header.Set(codersdk.ProvisionerDaemonKey, functionalKey)
247+
248+
httpmw.ExtractProvisionerDaemonAuthenticated(httpmw.ExtractProvisionerAuthConfig{
249+
DB: mockDB,
250+
Optional: false,
251+
})(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
252+
w.WriteHeader(http.StatusOK)
253+
})).ServeHTTP(res, r)
254+
255+
//nolint:bodyclose
256+
require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode)
257+
require.Contains(t, res.Body.String(), "provisioner daemon key invalid")
258+
})
259+
260+
t.Run("ProvisionerKey_DBError", func(t *testing.T) {
261+
t.Parallel()
262+
263+
ctrl := gomock.NewController(t)
264+
mockDB := dbmock.NewMockStore(ctrl)
265+
266+
gomock.InOrder(
267+
mockDB.EXPECT().GetProvisionerKeyByHashedSecret(gomock.Any(), gomock.Any()).Times(1).Return(database.ProvisionerKey{}, xerrors.New("error")),
268+
)
269+
270+
routeCtx := chi.NewRouteContext()
271+
r := httptest.NewRequest(http.MethodGet, "/", nil)
272+
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeCtx))
273+
res := httptest.NewRecorder()
274+
275+
//nolint:gosec // test key generated by test
276+
r.Header.Set(codersdk.ProvisionerDaemonKey, functionalKey)
277+
278+
httpmw.ExtractProvisionerDaemonAuthenticated(httpmw.ExtractProvisionerAuthConfig{
279+
DB: mockDB,
280+
Optional: false,
281+
})(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
282+
w.WriteHeader(http.StatusOK)
283+
})).ServeHTTP(res, r)
284+
285+
//nolint:bodyclose
286+
require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode)
287+
require.Contains(t, res.Body.String(), "get provisioner daemon key")
288+
})
224289
}

0 commit comments

Comments
 (0)