@@ -7,8 +7,13 @@ import (
7
7
"testing"
8
8
9
9
"github.com/go-chi/chi/v5"
10
+ "github.com/google/uuid"
10
11
"github.com/stretchr/testify/require"
12
+ "go.uber.org/mock/gomock"
13
+ "golang.org/x/xerrors"
11
14
15
+ "github.com/coder/coder/v2/coderd/database"
16
+ "github.com/coder/coder/v2/coderd/database/dbmock"
12
17
"github.com/coder/coder/v2/coderd/httpmw"
13
18
"github.com/coder/coder/v2/codersdk"
14
19
"github.com/coder/coder/v2/enterprise/coderd/coderdenttest"
@@ -17,6 +22,10 @@ import (
17
22
)
18
23
19
24
func TestExtractProvisionerDaemonAuthenticated (t * testing.T ) {
25
+ const (
26
+ //nolint:gosec // test key generated by test
27
+ functionalKey = "5Hl2Qw9kX3nM7vB4jR8pY6tA1cF0eD5uI2oL9gN3mZ4"
28
+ )
20
29
t .Parallel ()
21
30
22
31
tests := []struct {
@@ -33,8 +42,7 @@ func TestExtractProvisionerDaemonAuthenticated(t *testing.T) {
33
42
DB : nil ,
34
43
Optional : true ,
35
44
},
36
- expectedStatusCode : http .StatusOK ,
37
- expectedResponseMessage : "" ,
45
+ expectedStatusCode : http .StatusOK ,
38
46
},
39
47
{
40
48
name : "NoKeyProvided_NotOptional" ,
@@ -62,9 +70,8 @@ func TestExtractProvisionerDaemonAuthenticated(t *testing.T) {
62
70
DB : nil ,
63
71
Optional : true ,
64
72
},
65
- provisionerKey : "key" ,
66
- expectedStatusCode : http .StatusOK ,
67
- expectedResponseMessage : "" ,
73
+ provisionerKey : "key" ,
74
+ expectedStatusCode : http .StatusOK ,
68
75
},
69
76
{
70
77
name : "InvalidProvisionerKey_NotOptional" ,
@@ -82,9 +89,7 @@ func TestExtractProvisionerDaemonAuthenticated(t *testing.T) {
82
89
DB : nil ,
83
90
Optional : true ,
84
91
},
85
- provisionerKey : "invalid" ,
86
- expectedStatusCode : http .StatusOK ,
87
- expectedResponseMessage : "" ,
92
+ provisionerKey : "invalid" ,
88
93
},
89
94
{
90
95
name : "InvalidProvisionerPSK_NotOptional" ,
@@ -104,9 +109,8 @@ func TestExtractProvisionerDaemonAuthenticated(t *testing.T) {
104
109
Optional : true ,
105
110
PSK : "psk" ,
106
111
},
107
- provisionerPSK : "invalid" ,
108
- expectedStatusCode : http .StatusOK ,
109
- expectedResponseMessage : "" ,
112
+ provisionerPSK : "invalid" ,
113
+ expectedStatusCode : http .StatusOK ,
110
114
},
111
115
{
112
116
name : "ValidProvisionerPSK_NotOptional" ,
@@ -115,9 +119,8 @@ func TestExtractProvisionerDaemonAuthenticated(t *testing.T) {
115
119
Optional : false ,
116
120
PSK : "ThisIsAValidPSK" ,
117
121
},
118
- provisionerPSK : "ThisIsAValidPSK" ,
119
- expectedStatusCode : http .StatusOK ,
120
- expectedResponseMessage : "" ,
122
+ provisionerPSK : "ThisIsAValidPSK" ,
123
+ expectedStatusCode : http .StatusOK ,
121
124
},
122
125
}
123
126
@@ -152,8 +155,7 @@ func TestExtractProvisionerDaemonAuthenticated(t *testing.T) {
152
155
t .Run ("ProvisionerKey" , func (t * testing.T ) {
153
156
t .Parallel ()
154
157
155
- ctx , cancel := context .WithTimeout (context .Background (), testutil .WaitShort )
156
- defer cancel ()
158
+ ctx := testutil .Context (t , testutil .WaitShort )
157
159
client , db , user := coderdenttest .NewWithDatabase (t , & coderdenttest.Options {
158
160
LicenseOptions : & coderdenttest.LicenseOptions {
159
161
Features : license.Features {
@@ -188,8 +190,7 @@ func TestExtractProvisionerDaemonAuthenticated(t *testing.T) {
188
190
t .Run ("ProvisionerKey_NotFound" , func (t * testing.T ) {
189
191
t .Parallel ()
190
192
191
- ctx , cancel := context .WithTimeout (context .Background (), testutil .WaitShort )
192
- defer cancel ()
193
+ ctx := testutil .Context (t , testutil .WaitShort )
193
194
client , db , user := coderdenttest .NewWithDatabase (t , & coderdenttest.Options {
194
195
LicenseOptions : & coderdenttest.LicenseOptions {
195
196
Features : license.Features {
@@ -208,7 +209,9 @@ func TestExtractProvisionerDaemonAuthenticated(t *testing.T) {
208
209
r = r .WithContext (context .WithValue (r .Context (), chi .RouteCtxKey , routeCtx ))
209
210
res := httptest .NewRecorder ()
210
211
211
- r .Header .Set (codersdk .ProvisionerDaemonKey , "5Hl2Qw9kX3nM7vB4jR8pY6tA1cF0eD5uI2oL9gN3mZ4" )
212
+ //nolint:gosec // test key generated by test
213
+ pkey := "5Hl2Qw9kX3nM7vB4jR8pY6tA1cF0eD5uI2oL9gN3mZ4"
214
+ r .Header .Set (codersdk .ProvisionerDaemonKey , pkey )
212
215
213
216
httpmw .ExtractProvisionerDaemonAuthenticated (httpmw.ExtractProvisionerAuthConfig {
214
217
DB : db ,
@@ -221,4 +224,66 @@ func TestExtractProvisionerDaemonAuthenticated(t *testing.T) {
221
224
require .Equal (t , http .StatusUnauthorized , res .Result ().StatusCode )
222
225
require .Contains (t , res .Body .String (), "provisioner daemon key invalid" )
223
226
})
227
+
228
+ t .Run ("ProvisionerKey_CompareFail" , func (t * testing.T ) {
229
+ t .Parallel ()
230
+
231
+ ctrl := gomock .NewController (t )
232
+ mockDB := dbmock .NewMockStore (ctrl )
233
+
234
+ gomock .InOrder (
235
+ mockDB .EXPECT ().GetProvisionerKeyByHashedSecret (gomock .Any (), gomock .Any ()).Times (1 ).Return (database.ProvisionerKey {
236
+ ID : uuid .New (),
237
+ HashedSecret : []byte ("hashedSecret" ),
238
+ }, nil ),
239
+ )
240
+
241
+ routeCtx := chi .NewRouteContext ()
242
+ r := httptest .NewRequest (http .MethodGet , "/" , nil )
243
+ r = r .WithContext (context .WithValue (r .Context (), chi .RouteCtxKey , routeCtx ))
244
+ res := httptest .NewRecorder ()
245
+
246
+ r .Header .Set (codersdk .ProvisionerDaemonKey , functionalKey )
247
+
248
+ httpmw .ExtractProvisionerDaemonAuthenticated (httpmw.ExtractProvisionerAuthConfig {
249
+ DB : mockDB ,
250
+ Optional : false ,
251
+ })(http .HandlerFunc (func (w http.ResponseWriter , _ * http.Request ) {
252
+ w .WriteHeader (http .StatusOK )
253
+ })).ServeHTTP (res , r )
254
+
255
+ //nolint:bodyclose
256
+ require .Equal (t , http .StatusUnauthorized , res .Result ().StatusCode )
257
+ require .Contains (t , res .Body .String (), "provisioner daemon key invalid" )
258
+ })
259
+
260
+ t .Run ("ProvisionerKey_DBError" , func (t * testing.T ) {
261
+ t .Parallel ()
262
+
263
+ ctrl := gomock .NewController (t )
264
+ mockDB := dbmock .NewMockStore (ctrl )
265
+
266
+ gomock .InOrder (
267
+ mockDB .EXPECT ().GetProvisionerKeyByHashedSecret (gomock .Any (), gomock .Any ()).Times (1 ).Return (database.ProvisionerKey {}, xerrors .New ("error" )),
268
+ )
269
+
270
+ routeCtx := chi .NewRouteContext ()
271
+ r := httptest .NewRequest (http .MethodGet , "/" , nil )
272
+ r = r .WithContext (context .WithValue (r .Context (), chi .RouteCtxKey , routeCtx ))
273
+ res := httptest .NewRecorder ()
274
+
275
+ //nolint:gosec // test key generated by test
276
+ r .Header .Set (codersdk .ProvisionerDaemonKey , functionalKey )
277
+
278
+ httpmw .ExtractProvisionerDaemonAuthenticated (httpmw.ExtractProvisionerAuthConfig {
279
+ DB : mockDB ,
280
+ Optional : false ,
281
+ })(http .HandlerFunc (func (w http.ResponseWriter , _ * http.Request ) {
282
+ w .WriteHeader (http .StatusOK )
283
+ })).ServeHTTP (res , r )
284
+
285
+ //nolint:bodyclose
286
+ require .Equal (t , http .StatusUnauthorized , res .Result ().StatusCode )
287
+ require .Contains (t , res .Body .String (), "get provisioner daemon key" )
288
+ })
224
289
}
0 commit comments