@@ -12,6 +12,7 @@ import (
12
12
"net/http/httptest"
13
13
"strings"
14
14
"testing"
15
+ "time"
15
16
16
17
"github.com/stretchr/testify/assert"
17
18
"github.com/stretchr/testify/require"
@@ -199,9 +200,8 @@ func (w mockWsResponseWrite) Write(b []byte) (int, error) {
199
200
func TestOneWayWebSocket (t * testing.T ) {
200
201
t .Parallel ()
201
202
202
- newBaseRequest := func (t * testing. T ) * http.Request {
203
+ newBaseRequest := func (ctx context. Context ) * http.Request {
203
204
url := "ws://www.fake-website.com/logs"
204
- ctx := testutil .Context (t , testutil .WaitShort )
205
205
req , err := http .NewRequestWithContext (ctx , http .MethodGet , url , nil )
206
206
require .NoError (t , err )
207
207
@@ -243,7 +243,7 @@ func TestOneWayWebSocket(t *testing.T) {
243
243
}
244
244
}
245
245
246
- t .Run ("Produces an error if the socket connection could not be established" , func (t * testing.T ) {
246
+ t .Run ("Produces error if the socket connection could not be established" , func (t * testing.T ) {
247
247
t .Parallel ()
248
248
249
249
incorrectProtocols := []struct {
@@ -255,24 +255,26 @@ func TestOneWayWebSocket(t *testing.T) {
255
255
{1 , 0 , "HTTP/1.0" },
256
256
}
257
257
for _ , p := range incorrectProtocols {
258
- req := newBaseRequest (t )
258
+ ctx := testutil .Context (t , testutil .WaitShort )
259
+ req := newBaseRequest (ctx )
259
260
req .ProtoMajor = p .major
260
261
req .ProtoMinor = p .minor
261
262
req .Proto = p .proto
262
263
263
264
writer := newWebsocketWriter ()
264
265
_ , _ , err := httpapi .OneWayWebSocket [any ](writer , req )
265
266
require .ErrorContains (t , err , p .proto )
266
- writer .close ( )
267
+ t . Cleanup ( writer .close )
267
268
}
268
269
})
269
270
270
- t .Run ("Returned callback can publish a new event to the WebSocket connection" , func (t * testing.T ) {
271
+ t .Run ("Returned callback can publish new event to WebSocket connection" , func (t * testing.T ) {
271
272
t .Parallel ()
272
273
273
- req := newBaseRequest (t )
274
+ ctx := testutil .Context (t , testutil .WaitShort )
275
+ req := newBaseRequest (ctx )
274
276
writer := newWebsocketWriter ()
275
- defer writer .close ( )
277
+ t . Cleanup ( writer .close )
276
278
send , _ , err := httpapi .OneWayWebSocket [codersdk.ServerSentEvent ](writer , req )
277
279
require .NoError (t , err )
278
280
@@ -285,20 +287,44 @@ func TestOneWayWebSocket(t *testing.T) {
285
287
286
288
b , err := io .ReadAll (writer .clientConn )
287
289
require .NoError (t , err )
290
+ fmt .Printf ("-----------%q\n " , b ) // todo: Figure out why junk characters are added to JSON
288
291
clientPayload := codersdk.ServerSentEvent {}
289
292
err = json .Unmarshal (b , & clientPayload )
290
293
require .NoError (t , err )
291
294
require .Equal (t , serverPayload .Type , clientPayload .Type )
292
- cb , ok := clientPayload .Data .([]byte )
295
+ data , ok := clientPayload .Data .([]byte )
293
296
require .True (t , ok )
294
- require .Equal (t , serverPayload .Data , string (cb ))
297
+ require .Equal (t , serverPayload .Data , string (data ))
295
298
})
296
299
297
- t .Run ("Signals to an outside consumer when the socket has been closed" , func (t * testing.T ) {
300
+ t .Run ("Signals to outside consumer when socket has been closed" , func (t * testing.T ) {
298
301
t .Parallel ()
302
+
303
+ rootCtx := testutil .Context (t , testutil .WaitShort )
304
+ cancelCtx , cancel := context .WithCancel (rootCtx )
305
+
306
+ req := newBaseRequest (cancelCtx )
307
+ writer := newWebsocketWriter ()
308
+ t .Cleanup (writer .close )
309
+ _ , done , err := httpapi .OneWayWebSocket [codersdk.ServerSentEvent ](writer , req )
310
+ require .NoError (t , err )
311
+
312
+ successC := make (chan bool )
313
+ ticker := time .NewTicker (testutil .WaitShort )
314
+ go func () {
315
+ select {
316
+ case <- done :
317
+ successC <- true
318
+ case <- ticker .C :
319
+ successC <- false
320
+ }
321
+ }()
322
+
323
+ cancel ()
324
+ require .True (t , <- successC )
299
325
})
300
326
301
- t .Run ("Socket will automatically close if client sends a single message" , func (t * testing.T ) {
327
+ t .Run ("Socket will immediately close if client sends any message" , func (t * testing.T ) {
302
328
t .Parallel ()
303
329
})
304
330
0 commit comments