@@ -194,37 +194,10 @@ func (w mockWsResponseWrite) Write(b []byte) (int, error) {
194
194
return w (b )
195
195
}
196
196
197
- func newMockWebsocketWriter () mockWsResponseWriter {
198
- server , client := net .Pipe ()
199
- recorder := httptest .NewRecorder ()
200
-
201
- var write mockWsResponseWrite = func (b []byte ) (int , error ) {
202
- serverCount , err := server .Write (b )
203
- if err != nil {
204
- return serverCount , err
205
- }
206
- recorderCount , err := recorder .Write (b )
207
- if serverCount < recorderCount {
208
- return serverCount , err
209
- }
210
- return recorderCount , err
211
- }
212
-
213
- return mockWsResponseWriter {
214
- serverConn : server ,
215
- clientConn : client ,
216
- recorder : recorder ,
217
- serverReadWriter : bufio .NewReadWriter (
218
- bufio .NewReader (server ),
219
- bufio .NewWriter (write ),
220
- ),
221
- }
222
- }
223
-
224
197
func TestOneWayWebSocket (t * testing.T ) {
225
198
t .Parallel ()
226
199
227
- createBaseRequest := func (t * testing.T ) * http.Request {
200
+ newBaseRequest := func (t * testing.T ) * http.Request {
228
201
url := "ws://www.fake-website.com/logs"
229
202
ctx := testutil .Context (t , testutil .WaitShort )
230
203
req , err := http .NewRequestWithContext (ctx , http .MethodGet , url , nil )
@@ -234,11 +207,35 @@ func TestOneWayWebSocket(t *testing.T) {
234
207
h .Add ("Connection" , "Upgrade" )
235
208
h .Add ("Upgrade" , "websocket" )
236
209
h .Add ("Sec-WebSocket-Version" , "13" )
237
- h .Add ("Sec-WebSocket-Key" , "dGhlIHNhbXBsZSBub25jZQ==" )
210
+ h .Add ("Sec-WebSocket-Key" , "dGhlIHNhbXBsZSBub25jZQ==" ) // Just need any string
238
211
239
212
return req
240
213
}
241
214
215
+ newWebsocketWriter := func () mockWsResponseWriter {
216
+ server , client := net .Pipe ()
217
+ recorder := httptest .NewRecorder ()
218
+
219
+ var write mockWsResponseWrite = func (b []byte ) (int , error ) {
220
+ serverCount , err := server .Write (b )
221
+ if err != nil {
222
+ return serverCount , err
223
+ }
224
+ recorderCount , err := recorder .Write (b )
225
+ return min (serverCount , recorderCount ), err
226
+ }
227
+
228
+ return mockWsResponseWriter {
229
+ serverConn : server ,
230
+ clientConn : client ,
231
+ recorder : recorder ,
232
+ serverReadWriter : bufio .NewReadWriter (
233
+ bufio .NewReader (server ),
234
+ bufio .NewWriter (write ),
235
+ ),
236
+ }
237
+ }
238
+
242
239
t .Run ("Produces an error if the socket connection could not be established" , func (t * testing.T ) {
243
240
t .Parallel ()
244
241
@@ -251,12 +248,12 @@ func TestOneWayWebSocket(t *testing.T) {
251
248
{1 , 0 , "HTTP/1.0" },
252
249
}
253
250
for _ , p := range incorrectProtocols {
254
- req := createBaseRequest (t )
251
+ req := newBaseRequest (t )
255
252
req .ProtoMajor = p .major
256
253
req .ProtoMinor = p .minor
257
254
req .Proto = p .proto
258
255
259
- writer := newMockWebsocketWriter ()
256
+ writer := newWebsocketWriter ()
260
257
_ , _ , err := httpapi .OneWayWebSocket [any ](writer , req )
261
258
require .ErrorContains (t , err , p .proto )
262
259
}
@@ -265,15 +262,16 @@ func TestOneWayWebSocket(t *testing.T) {
265
262
t .Run ("Returned callback can publish a new event to the WebSocket connection" , func (t * testing.T ) {
266
263
t .Parallel ()
267
264
268
- req := createBaseRequest (t )
269
- writer := newMockWebsocketWriter ()
265
+ req := newBaseRequest (t )
266
+ writer := newWebsocketWriter ()
270
267
send , _ , err := httpapi .OneWayWebSocket [codersdk.ServerSentEvent ](writer , req )
271
268
require .NoError (t , err )
272
269
273
- err = send ( codersdk.ServerSentEvent {
270
+ payload := codersdk.ServerSentEvent {
274
271
Type : codersdk .ServerSentEventTypeData ,
275
272
Data : "Blah" ,
276
- })
273
+ }
274
+ err = send (payload )
277
275
require .NoError (t , err )
278
276
})
279
277
0 commit comments