@@ -163,31 +163,32 @@ func TestWebsocketCloseMsg(t *testing.T) {
163
163
164
164
// Our WebSocket library accepts any arbitrary ResponseWriter at the type level,
165
165
// but the writer must also implement http.Hijacker for long-lived connections.
166
- // The SSE version only requires http.Flusher (no need for the Hijack method).
167
- type mockEventSenderResponseWriter struct {
166
+ type mockOneWaySocketWriter struct {
168
167
serverRecorder * httptest.ResponseRecorder
169
168
serverConn net.Conn
170
169
clientConn net.Conn
171
170
serverReadWriter * bufio.ReadWriter
171
+ testContext * testing.T
172
172
}
173
173
174
- func (m mockEventSenderResponseWriter ) Hijack () (net.Conn , * bufio.ReadWriter , error ) {
174
+ func (m mockOneWaySocketWriter ) Hijack () (net.Conn , * bufio.ReadWriter , error ) {
175
175
return m .serverConn , m .serverReadWriter , nil
176
176
}
177
177
178
- func (m mockEventSenderResponseWriter ) Flush () {
179
- _ = m .serverReadWriter .Flush ()
178
+ func (m mockOneWaySocketWriter ) Flush () {
179
+ err := m .serverReadWriter .Flush ()
180
+ require .NoError (m .testContext , err )
180
181
}
181
182
182
- func (m mockEventSenderResponseWriter ) Header () http.Header {
183
+ func (m mockOneWaySocketWriter ) Header () http.Header {
183
184
return m .serverRecorder .Header ()
184
185
}
185
186
186
- func (m mockEventSenderResponseWriter ) Write (b []byte ) (int , error ) {
187
+ func (m mockOneWaySocketWriter ) Write (b []byte ) (int , error ) {
187
188
return m .serverReadWriter .Write (b )
188
189
}
189
190
190
- func (m mockEventSenderResponseWriter ) WriteHeader (code int ) {
191
+ func (m mockOneWaySocketWriter ) WriteHeader (code int ) {
191
192
m .serverRecorder .WriteHeader (code )
192
193
}
193
194
@@ -197,30 +198,6 @@ func (w mockEventSenderWrite) Write(b []byte) (int, error) {
197
198
return w (b )
198
199
}
199
200
200
- func newMockEventSenderWriter () mockEventSenderResponseWriter {
201
- mockServer , mockClient := net .Pipe ()
202
- recorder := httptest .NewRecorder ()
203
-
204
- var write mockEventSenderWrite = func (b []byte ) (int , error ) {
205
- serverCount , err := mockServer .Write (b )
206
- if err != nil {
207
- return serverCount , err
208
- }
209
- recorderCount , err := recorder .Write (b )
210
- return min (serverCount , recorderCount ), err
211
- }
212
-
213
- return mockEventSenderResponseWriter {
214
- serverConn : mockServer ,
215
- clientConn : mockClient ,
216
- serverRecorder : recorder ,
217
- serverReadWriter : bufio .NewReadWriter (
218
- bufio .NewReader (mockServer ),
219
- bufio .NewWriter (write ),
220
- ),
221
- }
222
- }
223
-
224
201
func TestOneWayWebSocketEventSender (t * testing.T ) {
225
202
t .Parallel ()
226
203
@@ -238,6 +215,34 @@ func TestOneWayWebSocketEventSender(t *testing.T) {
238
215
return req
239
216
}
240
217
218
+ newOneWayWriter := func (t * testing.T ) mockOneWaySocketWriter {
219
+ mockServer , mockClient := net .Pipe ()
220
+ recorder := httptest .NewRecorder ()
221
+
222
+ var write mockEventSenderWrite = func (b []byte ) (int , error ) {
223
+ serverCount , err := mockServer .Write (b )
224
+ if err != nil {
225
+ return 0 , err
226
+ }
227
+ recorderCount , err := recorder .Write (b )
228
+ if err != nil {
229
+ return 0 , err
230
+ }
231
+ return min (serverCount , recorderCount ), nil
232
+ }
233
+
234
+ return mockOneWaySocketWriter {
235
+ testContext : t ,
236
+ serverConn : mockServer ,
237
+ clientConn : mockClient ,
238
+ serverRecorder : recorder ,
239
+ serverReadWriter : bufio .NewReadWriter (
240
+ bufio .NewReader (mockServer ),
241
+ bufio .NewWriter (write ),
242
+ ),
243
+ }
244
+ }
245
+
241
246
t .Run ("Produces error if the socket connection could not be established" , func (t * testing.T ) {
242
247
t .Parallel ()
243
248
@@ -256,7 +261,7 @@ func TestOneWayWebSocketEventSender(t *testing.T) {
256
261
req .ProtoMinor = p .minor
257
262
req .Proto = p .proto
258
263
259
- writer := newMockEventSenderWriter ( )
264
+ writer := newOneWayWriter ( t )
260
265
_ , _ , err := httpapi .OneWayWebSocketEventSender (writer , req )
261
266
require .ErrorContains (t , err , p .proto )
262
267
}
@@ -267,7 +272,7 @@ func TestOneWayWebSocketEventSender(t *testing.T) {
267
272
268
273
ctx := testutil .Context (t , testutil .WaitShort )
269
274
req := newBaseRequest (ctx )
270
- writer := newMockEventSenderWriter ( )
275
+ writer := newOneWayWriter ( t )
271
276
send , _ , err := httpapi .OneWayWebSocketEventSender (writer , req )
272
277
require .NoError (t , err )
273
278
@@ -293,7 +298,7 @@ func TestOneWayWebSocketEventSender(t *testing.T) {
293
298
294
299
ctx , cancel := context .WithCancel (testutil .Context (t , testutil .WaitShort ))
295
300
req := newBaseRequest (ctx )
296
- writer := newMockEventSenderWriter ( )
301
+ writer := newOneWayWriter ( t )
297
302
_ , done , err := httpapi .OneWayWebSocketEventSender (writer , req )
298
303
require .NoError (t , err )
299
304
@@ -317,7 +322,7 @@ func TestOneWayWebSocketEventSender(t *testing.T) {
317
322
318
323
ctx := testutil .Context (t , testutil .WaitShort )
319
324
req := newBaseRequest (ctx )
320
- writer := newMockEventSenderWriter ( )
325
+ writer := newOneWayWriter ( t )
321
326
_ , done , err := httpapi .OneWayWebSocketEventSender (writer , req )
322
327
require .NoError (t , err )
323
328
@@ -347,7 +352,7 @@ func TestOneWayWebSocketEventSender(t *testing.T) {
347
352
348
353
ctx , cancel := context .WithCancel (testutil .Context (t , testutil .WaitShort ))
349
354
req := newBaseRequest (ctx )
350
- writer := newMockEventSenderWriter ( )
355
+ writer := newOneWayWriter ( t )
351
356
send , done , err := httpapi .OneWayWebSocketEventSender (writer , req )
352
357
require .NoError (t , err )
353
358
@@ -388,7 +393,7 @@ func TestOneWayWebSocketEventSender(t *testing.T) {
388
393
389
394
ctx := testutil .Context (t , timeout )
390
395
req := newBaseRequest (ctx )
391
- writer := newMockEventSenderWriter ( )
396
+ writer := newOneWayWriter ( t )
392
397
_ , _ , err := httpapi .OneWayWebSocketEventSender (writer , req )
393
398
require .NoError (t , err )
394
399
@@ -422,6 +427,42 @@ func TestOneWayWebSocketEventSender(t *testing.T) {
422
427
})
423
428
}
424
429
430
+ // ServerSentEventSender accepts any arbitrary ResponseWriter at the type level,
431
+ // but the writer must also implement http.Flusher for long-lived connections
432
+ type mockServerSentWriter struct {
433
+ serverRecorder * httptest.ResponseRecorder
434
+ serverConn net.Conn
435
+ clientConn net.Conn
436
+ buffer * bytes.Buffer
437
+ testContext * testing.T
438
+ }
439
+
440
+ func (m mockServerSentWriter ) Flush () {
441
+ b := m .buffer .Bytes ()
442
+ _ , err := m .serverConn .Write (b )
443
+ require .NoError (m .testContext , err )
444
+ m .buffer .Reset ()
445
+
446
+ // Must close server connection to indicate EOF for any reads from the
447
+ // client connection; otherwise reads block forever. This is a testing
448
+ // limitation compared to the one-way websockets, since we have no way to
449
+ // frame the data and auto-indicate EOF for each message
450
+ err = m .serverConn .Close ()
451
+ require .NoError (m .testContext , err )
452
+ }
453
+
454
+ func (m mockServerSentWriter ) Header () http.Header {
455
+ return m .serverRecorder .Header ()
456
+ }
457
+
458
+ func (m mockServerSentWriter ) Write (b []byte ) (int , error ) {
459
+ return m .buffer .Write (b )
460
+ }
461
+
462
+ func (m mockServerSentWriter ) WriteHeader (code int ) {
463
+ m .serverRecorder .WriteHeader (code )
464
+ }
465
+
425
466
func TestServerSentEventSender (t * testing.T ) {
426
467
t .Parallel ()
427
468
@@ -432,12 +473,23 @@ func TestServerSentEventSender(t *testing.T) {
432
473
return req
433
474
}
434
475
476
+ newServerSentWriter := func (t * testing.T ) mockServerSentWriter {
477
+ mockServer , mockClient := net .Pipe ()
478
+ return mockServerSentWriter {
479
+ testContext : t ,
480
+ serverRecorder : httptest .NewRecorder (),
481
+ clientConn : mockClient ,
482
+ serverConn : mockServer ,
483
+ buffer : & bytes.Buffer {},
484
+ }
485
+ }
486
+
435
487
t .Run ("Mutates response headers to support SSE connections" , func (t * testing.T ) {
436
488
t .Parallel ()
437
489
438
490
ctx := testutil .Context (t , testutil .WaitShort )
439
491
req := newBaseRequest (ctx )
440
- writer := newMockEventSenderWriter ( )
492
+ writer := newServerSentWriter ( t )
441
493
_ , _ , err := httpapi .ServerSentEventSender (writer , req )
442
494
require .NoError (t , err )
443
495
@@ -453,7 +505,7 @@ func TestServerSentEventSender(t *testing.T) {
453
505
454
506
ctx := testutil .Context (t , testutil .WaitShort )
455
507
req := newBaseRequest (ctx )
456
- writer := newMockEventSenderWriter ( )
508
+ writer := newServerSentWriter ( t )
457
509
send , _ , err := httpapi .ServerSentEventSender (writer , req )
458
510
require .NoError (t , err )
459
511
@@ -464,30 +516,46 @@ func TestServerSentEventSender(t *testing.T) {
464
516
err = send (serverPayload )
465
517
require .NoError (t , err )
466
518
467
- // The client connection will receive a little bit of additional data on
468
- // top of the main payload. Have to make sure check has tolerance for
469
- // extra data being present
470
- serverBytes , err := json .Marshal (serverPayload )
471
- require .NoError (t , err )
472
-
473
- // This is the part that's breaking
474
519
clientBytes , err := io .ReadAll (writer .clientConn )
475
520
require .NoError (t , err )
476
- require .True (t , bytes .Contains (clientBytes , serverBytes ))
521
+ require .Equal (
522
+ t ,
523
+ string (clientBytes ),
524
+ "event: data\n data: \" Blah\" \n \n " ,
525
+ )
477
526
})
478
527
479
528
t .Run ("Signals to outside consumer when connection has been closed" , func (t * testing.T ) {
480
529
t .Parallel ()
481
- t .FailNow ()
530
+
531
+ ctx , cancel := context .WithCancel (testutil .Context (t , testutil .WaitShort ))
532
+ req := newBaseRequest (ctx )
533
+ writer := newServerSentWriter (t )
534
+ _ , done , err := httpapi .ServerSentEventSender (writer , req )
535
+ require .NoError (t , err )
536
+
537
+ successC := make (chan bool )
538
+ ticker := time .NewTicker (testutil .WaitShort )
539
+ go func () {
540
+ select {
541
+ case <- done :
542
+ successC <- true
543
+ case <- ticker .C :
544
+ successC <- false
545
+ }
546
+ }()
547
+
548
+ cancel ()
549
+ require .True (t , <- successC )
482
550
})
483
551
484
552
t .Run ("Cancels the entire connection if the request context cancels" , func (t * testing.T ) {
485
- t .Parallel ()
486
553
t .FailNow ()
554
+ t .Parallel ()
487
555
})
488
556
489
557
t .Run ("Sends a heartbeat to the client on a fixed internal of time to keep connections alive" , func (t * testing.T ) {
490
- t .Parallel ()
491
558
t .FailNow ()
559
+ t .Parallel ()
492
560
})
493
561
}
0 commit comments