@@ -162,41 +162,66 @@ func TestWebsocketCloseMsg(t *testing.T) {
162
162
}
163
163
164
164
// Our WebSocket library accepts any arbitrary ResponseWriter at the type level,
165
- // but the writer must also implement http.Hijacker for long-lived connections
166
- type mockWsResponseWriter struct {
165
+ // but the writer must also implement http.Hijacker for long-lived connections.
166
+ // The SSE version only requires http.Flusher (no need for the Hijack method).
167
+ type mockEventSenderResponseWriter struct {
167
168
serverRecorder * httptest.ResponseRecorder
168
169
serverConn net.Conn
169
170
clientConn net.Conn
170
171
serverReadWriter * bufio.ReadWriter
171
172
}
172
173
173
- func (m mockWsResponseWriter ) Hijack () (net.Conn , * bufio.ReadWriter , error ) {
174
+ func (m mockEventSenderResponseWriter ) Hijack () (net.Conn , * bufio.ReadWriter , error ) {
174
175
return m .serverConn , m .serverReadWriter , nil
175
176
}
176
177
177
- func (m mockWsResponseWriter ) Flush () {
178
+ func (m mockEventSenderResponseWriter ) Flush () {
178
179
_ = m .serverReadWriter .Flush ()
179
180
}
180
181
181
- func (m mockWsResponseWriter ) Header () http.Header {
182
+ func (m mockEventSenderResponseWriter ) Header () http.Header {
182
183
return m .serverRecorder .Header ()
183
184
}
184
185
185
- func (m mockWsResponseWriter ) Write (b []byte ) (int , error ) {
186
+ func (m mockEventSenderResponseWriter ) Write (b []byte ) (int , error ) {
186
187
return m .serverReadWriter .Write (b )
187
188
}
188
189
189
- func (m mockWsResponseWriter ) WriteHeader (code int ) {
190
+ func (m mockEventSenderResponseWriter ) WriteHeader (code int ) {
190
191
m .serverRecorder .WriteHeader (code )
191
192
}
192
193
193
- type mockWsWrite func (b []byte ) (int , error )
194
+ type mockEventSenderWrite func (b []byte ) (int , error )
194
195
195
- func (w mockWsWrite ) Write (b []byte ) (int , error ) {
196
+ func (w mockEventSenderWrite ) Write (b []byte ) (int , error ) {
196
197
return w (b )
197
198
}
198
199
199
- func TestWebSocketEventSender (t * testing.T ) {
200
+ func newMockEventSenderWriter () mockEventSenderResponseWriter {
201
+ mockServer , mockClient := net .Pipe ()
202
+ recorder := httptest .NewRecorder ()
203
+
204
+ var write mockEventSenderWrite = func (b []byte ) (int , error ) {
205
+ serverCount , err := mockServer .Write (b )
206
+ if err != nil {
207
+ return serverCount , err
208
+ }
209
+ recorderCount , err := recorder .Write (b )
210
+ return min (serverCount , recorderCount ), err
211
+ }
212
+
213
+ return mockEventSenderResponseWriter {
214
+ serverConn : mockServer ,
215
+ clientConn : mockClient ,
216
+ serverRecorder : recorder ,
217
+ serverReadWriter : bufio .NewReadWriter (
218
+ bufio .NewReader (mockServer ),
219
+ bufio .NewWriter (write ),
220
+ ),
221
+ }
222
+ }
223
+
224
+ func TestOneWayWebSocketEventSender (t * testing.T ) {
200
225
t .Parallel ()
201
226
202
227
newBaseRequest := func (ctx context.Context ) * http.Request {
@@ -213,30 +238,6 @@ func TestWebSocketEventSender(t *testing.T) {
213
238
return req
214
239
}
215
240
216
- newWebsocketWriter := func () mockWsResponseWriter {
217
- mockServer , mockClient := net .Pipe ()
218
- recorder := httptest .NewRecorder ()
219
-
220
- var write mockWsWrite = func (b []byte ) (int , error ) {
221
- serverCount , err := mockServer .Write (b )
222
- if err != nil {
223
- return serverCount , err
224
- }
225
- recorderCount , err := recorder .Write (b )
226
- return min (serverCount , recorderCount ), err
227
- }
228
-
229
- return mockWsResponseWriter {
230
- serverConn : mockServer ,
231
- clientConn : mockClient ,
232
- serverRecorder : recorder ,
233
- serverReadWriter : bufio .NewReadWriter (
234
- bufio .NewReader (mockServer ),
235
- bufio .NewWriter (write ),
236
- ),
237
- }
238
- }
239
-
240
241
t .Run ("Produces error if the socket connection could not be established" , func (t * testing.T ) {
241
242
t .Parallel ()
242
243
@@ -255,8 +256,8 @@ func TestWebSocketEventSender(t *testing.T) {
255
256
req .ProtoMinor = p .minor
256
257
req .Proto = p .proto
257
258
258
- writer := newWebsocketWriter ()
259
- _ , _ , err := httpapi .WebSocketEventSender (writer , req )
259
+ writer := newMockEventSenderWriter ()
260
+ _ , _ , err := httpapi .OneWayWebSocketEventSender (writer , req )
260
261
require .ErrorContains (t , err , p .proto )
261
262
}
262
263
})
@@ -266,8 +267,8 @@ func TestWebSocketEventSender(t *testing.T) {
266
267
267
268
ctx := testutil .Context (t , testutil .WaitShort )
268
269
req := newBaseRequest (ctx )
269
- writer := newWebsocketWriter ()
270
- send , _ , err := httpapi .WebSocketEventSender (writer , req )
270
+ writer := newMockEventSenderWriter ()
271
+ send , _ , err := httpapi .OneWayWebSocketEventSender (writer , req )
271
272
require .NoError (t , err )
272
273
273
274
serverPayload := codersdk.ServerSentEvent {
@@ -292,8 +293,8 @@ func TestWebSocketEventSender(t *testing.T) {
292
293
293
294
ctx , cancel := context .WithCancel (testutil .Context (t , testutil .WaitShort ))
294
295
req := newBaseRequest (ctx )
295
- writer := newWebsocketWriter ()
296
- _ , done , err := httpapi .WebSocketEventSender (writer , req )
296
+ writer := newMockEventSenderWriter ()
297
+ _ , done , err := httpapi .OneWayWebSocketEventSender (writer , req )
297
298
require .NoError (t , err )
298
299
299
300
successC := make (chan bool )
@@ -316,8 +317,8 @@ func TestWebSocketEventSender(t *testing.T) {
316
317
317
318
ctx := testutil .Context (t , testutil .WaitShort )
318
319
req := newBaseRequest (ctx )
319
- writer := newWebsocketWriter ()
320
- _ , done , err := httpapi .WebSocketEventSender (writer , req )
320
+ writer := newMockEventSenderWriter ()
321
+ _ , done , err := httpapi .OneWayWebSocketEventSender (writer , req )
321
322
require .NoError (t , err )
322
323
323
324
successC := make (chan bool )
@@ -346,8 +347,8 @@ func TestWebSocketEventSender(t *testing.T) {
346
347
347
348
ctx , cancel := context .WithCancel (testutil .Context (t , testutil .WaitShort ))
348
349
req := newBaseRequest (ctx )
349
- writer := newWebsocketWriter ()
350
- send , done , err := httpapi .WebSocketEventSender (writer , req )
350
+ writer := newMockEventSenderWriter ()
351
+ send , done , err := httpapi .OneWayWebSocketEventSender (writer , req )
351
352
require .NoError (t , err )
352
353
353
354
successC := make (chan bool )
@@ -387,8 +388,8 @@ func TestWebSocketEventSender(t *testing.T) {
387
388
388
389
ctx := testutil .Context (t , timeout )
389
390
req := newBaseRequest (ctx )
390
- writer := newWebsocketWriter ()
391
- _ , _ , err := httpapi .WebSocketEventSender (writer , req )
391
+ writer := newMockEventSenderWriter ()
392
+ _ , _ , err := httpapi .OneWayWebSocketEventSender (writer , req )
392
393
require .NoError (t , err )
393
394
394
395
type Result struct {
@@ -420,3 +421,73 @@ func TestWebSocketEventSender(t *testing.T) {
420
421
require .True (t , result .Success )
421
422
})
422
423
}
424
+
425
+ func TestServerSentEventSender (t * testing.T ) {
426
+ t .Parallel ()
427
+
428
+ newBaseRequest := func (ctx context.Context ) * http.Request {
429
+ url := "ws://www.fake-website.com/logs"
430
+ req , err := http .NewRequestWithContext (ctx , http .MethodGet , url , nil )
431
+ require .NoError (t , err )
432
+ return req
433
+ }
434
+
435
+ t .Run ("Mutates response headers to support SSE connections" , func (t * testing.T ) {
436
+ t .Parallel ()
437
+
438
+ ctx := testutil .Context (t , testutil .WaitShort )
439
+ req := newBaseRequest (ctx )
440
+ writer := newMockEventSenderWriter ()
441
+ _ , _ , err := httpapi .ServerSentEventSender (writer , req )
442
+ require .NoError (t , err )
443
+
444
+ h := writer .Header ()
445
+ require .Equal (t , h .Get ("Content-Type" ), "text/event-stream" )
446
+ require .Equal (t , h .Get ("Cache-Control" ), "no-cache" )
447
+ require .Equal (t , h .Get ("Connection" ), "keep-alive" )
448
+ require .Equal (t , h .Get ("X-Accel-Buffering" ), "no" )
449
+ })
450
+
451
+ t .Run ("Returned callback can publish new event to SSE connection" , func (t * testing.T ) {
452
+ t .Parallel ()
453
+
454
+ ctx := testutil .Context (t , testutil .WaitShort )
455
+ req := newBaseRequest (ctx )
456
+ writer := newMockEventSenderWriter ()
457
+ send , _ , err := httpapi .ServerSentEventSender (writer , req )
458
+ require .NoError (t , err )
459
+
460
+ serverPayload := codersdk.ServerSentEvent {
461
+ Type : codersdk .ServerSentEventTypeData ,
462
+ Data : "Blah" ,
463
+ }
464
+ err = send (serverPayload )
465
+ require .NoError (t , err )
466
+
467
+ // The client connection will receive a little bit of additional data on
468
+ // top of the main payload. Have to make sure check has tolerance for
469
+ // extra data being present
470
+ serverBytes , err := json .Marshal (serverPayload )
471
+ require .NoError (t , err )
472
+
473
+ // This is the part that's breaking
474
+ clientBytes , err := io .ReadAll (writer .clientConn )
475
+ require .NoError (t , err )
476
+ require .True (t , bytes .Contains (clientBytes , serverBytes ))
477
+ })
478
+
479
+ t .Run ("Signals to outside consumer when connection has been closed" , func (t * testing.T ) {
480
+ t .Parallel ()
481
+ t .FailNow ()
482
+ })
483
+
484
+ t .Run ("Cancels the entire connection if the request context cancels" , func (t * testing.T ) {
485
+ t .Parallel ()
486
+ t .FailNow ()
487
+ })
488
+
489
+ t .Run ("Sends a heartbeat to the client on a fixed internal of time to keep connections alive" , func (t * testing.T ) {
490
+ t .Parallel ()
491
+ t .FailNow ()
492
+ })
493
+ }
0 commit comments