Skip to content

Commit c7d95d9

Browse files
committed
wip: commit progress on tests
1 parent 43b1676 commit c7d95d9

File tree

4 files changed

+126
-62
lines changed

4 files changed

+126
-62
lines changed

coderd/httpapi/httpapi.go

+7-14
Original file line numberDiff line numberDiff line change
@@ -326,16 +326,13 @@ func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) (
326326
// Synchronized handling of events (no guarantee of order).
327327
go func() {
328328
defer close(closed)
329-
330-
// Send a heartbeat every 15 seconds to avoid the connection being killed.
331-
ticker := time.NewTicker(time.Second * 15)
329+
ticker := time.NewTicker(HeartbeatInterval)
332330
defer ticker.Stop()
333331

334332
for {
335333
var event sseEvent
336-
337334
select {
338-
case <-r.Context().Done():
335+
case <-ctx.Done():
339336
return
340337
case event = <-eventC:
341338
case <-ticker.C:
@@ -357,8 +354,6 @@ func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) (
357354

358355
sendEvent := func(newEvent codersdk.ServerSentEvent) error {
359356
buf := &bytes.Buffer{}
360-
enc := json.NewEncoder(buf)
361-
362357
_, err := buf.WriteString(fmt.Sprintf("event: %s\n", newEvent.Type))
363358
if err != nil {
364359
return err
@@ -369,6 +364,8 @@ func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) (
369364
if err != nil {
370365
return err
371366
}
367+
368+
enc := json.NewEncoder(buf)
372369
err = enc.Encode(newEvent.Data)
373370
if err != nil {
374371
return err
@@ -386,8 +383,6 @@ func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) (
386383
}
387384

388385
select {
389-
case <-r.Context().Done():
390-
return r.Context().Err()
391386
case <-ctx.Done():
392387
return ctx.Err()
393388
case <-closed:
@@ -397,8 +392,6 @@ func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) (
397392
// for early exit. We don't check closed here because it
398393
// can't happen while processing the event.
399394
select {
400-
case <-r.Context().Done():
401-
return r.Context().Err()
402395
case <-ctx.Done():
403396
return ctx.Err()
404397
case err := <-event.errC:
@@ -410,8 +403,8 @@ func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) (
410403
return sendEvent, closed, nil
411404
}
412405

413-
// WebSocketEventSender establishes a new WebSocket connection that enforces
414-
// one-way communication from the server to the client.
406+
// OneWayWebSocketEventSender establishes a new WebSocket connection that
407+
// enforces one-way communication from the server to the client.
415408
//
416409
// The function returned allows you to send a single message to the client,
417410
// while the channel lets you listen for when the connection closes.
@@ -422,7 +415,7 @@ func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) (
422415
// open a workspace in multiple tabs, the entire UI can start to lock up.
423416
// WebSockets have no such limitation, no matter what HTTP protocol was used to
424417
// establish the connection.
425-
func WebSocketEventSender(rw http.ResponseWriter, r *http.Request) (
418+
func OneWayWebSocketEventSender(rw http.ResponseWriter, r *http.Request) (
426419
func(event codersdk.ServerSentEvent) error,
427420
<-chan struct{},
428421
error,

coderd/httpapi/httpapi_test.go

+117-46
Original file line numberDiff line numberDiff line change
@@ -162,41 +162,66 @@ func TestWebsocketCloseMsg(t *testing.T) {
162162
}
163163

164164
// Our WebSocket library accepts any arbitrary ResponseWriter at the type level,
165-
// but the writer must also implement http.Hijacker for long-lived connections
166-
type mockWsResponseWriter struct {
165+
// but the writer must also implement http.Hijacker for long-lived connections.
166+
// The SSE version only requires http.Flusher (no need for the Hijack method).
167+
type mockEventSenderResponseWriter struct {
167168
serverRecorder *httptest.ResponseRecorder
168169
serverConn net.Conn
169170
clientConn net.Conn
170171
serverReadWriter *bufio.ReadWriter
171172
}
172173

173-
func (m mockWsResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
174+
func (m mockEventSenderResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
174175
return m.serverConn, m.serverReadWriter, nil
175176
}
176177

177-
func (m mockWsResponseWriter) Flush() {
178+
func (m mockEventSenderResponseWriter) Flush() {
178179
_ = m.serverReadWriter.Flush()
179180
}
180181

181-
func (m mockWsResponseWriter) Header() http.Header {
182+
func (m mockEventSenderResponseWriter) Header() http.Header {
182183
return m.serverRecorder.Header()
183184
}
184185

185-
func (m mockWsResponseWriter) Write(b []byte) (int, error) {
186+
func (m mockEventSenderResponseWriter) Write(b []byte) (int, error) {
186187
return m.serverReadWriter.Write(b)
187188
}
188189

189-
func (m mockWsResponseWriter) WriteHeader(code int) {
190+
func (m mockEventSenderResponseWriter) WriteHeader(code int) {
190191
m.serverRecorder.WriteHeader(code)
191192
}
192193

193-
type mockWsWrite func(b []byte) (int, error)
194+
type mockEventSenderWrite func(b []byte) (int, error)
194195

195-
func (w mockWsWrite) Write(b []byte) (int, error) {
196+
func (w mockEventSenderWrite) Write(b []byte) (int, error) {
196197
return w(b)
197198
}
198199

199-
func TestWebSocketEventSender(t *testing.T) {
200+
func newMockEventSenderWriter() mockEventSenderResponseWriter {
201+
mockServer, mockClient := net.Pipe()
202+
recorder := httptest.NewRecorder()
203+
204+
var write mockEventSenderWrite = func(b []byte) (int, error) {
205+
serverCount, err := mockServer.Write(b)
206+
if err != nil {
207+
return serverCount, err
208+
}
209+
recorderCount, err := recorder.Write(b)
210+
return min(serverCount, recorderCount), err
211+
}
212+
213+
return mockEventSenderResponseWriter{
214+
serverConn: mockServer,
215+
clientConn: mockClient,
216+
serverRecorder: recorder,
217+
serverReadWriter: bufio.NewReadWriter(
218+
bufio.NewReader(mockServer),
219+
bufio.NewWriter(write),
220+
),
221+
}
222+
}
223+
224+
func TestOneWayWebSocketEventSender(t *testing.T) {
200225
t.Parallel()
201226

202227
newBaseRequest := func(ctx context.Context) *http.Request {
@@ -213,30 +238,6 @@ func TestWebSocketEventSender(t *testing.T) {
213238
return req
214239
}
215240

216-
newWebsocketWriter := func() mockWsResponseWriter {
217-
mockServer, mockClient := net.Pipe()
218-
recorder := httptest.NewRecorder()
219-
220-
var write mockWsWrite = func(b []byte) (int, error) {
221-
serverCount, err := mockServer.Write(b)
222-
if err != nil {
223-
return serverCount, err
224-
}
225-
recorderCount, err := recorder.Write(b)
226-
return min(serverCount, recorderCount), err
227-
}
228-
229-
return mockWsResponseWriter{
230-
serverConn: mockServer,
231-
clientConn: mockClient,
232-
serverRecorder: recorder,
233-
serverReadWriter: bufio.NewReadWriter(
234-
bufio.NewReader(mockServer),
235-
bufio.NewWriter(write),
236-
),
237-
}
238-
}
239-
240241
t.Run("Produces error if the socket connection could not be established", func(t *testing.T) {
241242
t.Parallel()
242243

@@ -255,8 +256,8 @@ func TestWebSocketEventSender(t *testing.T) {
255256
req.ProtoMinor = p.minor
256257
req.Proto = p.proto
257258

258-
writer := newWebsocketWriter()
259-
_, _, err := httpapi.WebSocketEventSender(writer, req)
259+
writer := newMockEventSenderWriter()
260+
_, _, err := httpapi.OneWayWebSocketEventSender(writer, req)
260261
require.ErrorContains(t, err, p.proto)
261262
}
262263
})
@@ -266,8 +267,8 @@ func TestWebSocketEventSender(t *testing.T) {
266267

267268
ctx := testutil.Context(t, testutil.WaitShort)
268269
req := newBaseRequest(ctx)
269-
writer := newWebsocketWriter()
270-
send, _, err := httpapi.WebSocketEventSender(writer, req)
270+
writer := newMockEventSenderWriter()
271+
send, _, err := httpapi.OneWayWebSocketEventSender(writer, req)
271272
require.NoError(t, err)
272273

273274
serverPayload := codersdk.ServerSentEvent{
@@ -292,8 +293,8 @@ func TestWebSocketEventSender(t *testing.T) {
292293

293294
ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitShort))
294295
req := newBaseRequest(ctx)
295-
writer := newWebsocketWriter()
296-
_, done, err := httpapi.WebSocketEventSender(writer, req)
296+
writer := newMockEventSenderWriter()
297+
_, done, err := httpapi.OneWayWebSocketEventSender(writer, req)
297298
require.NoError(t, err)
298299

299300
successC := make(chan bool)
@@ -316,8 +317,8 @@ func TestWebSocketEventSender(t *testing.T) {
316317

317318
ctx := testutil.Context(t, testutil.WaitShort)
318319
req := newBaseRequest(ctx)
319-
writer := newWebsocketWriter()
320-
_, done, err := httpapi.WebSocketEventSender(writer, req)
320+
writer := newMockEventSenderWriter()
321+
_, done, err := httpapi.OneWayWebSocketEventSender(writer, req)
321322
require.NoError(t, err)
322323

323324
successC := make(chan bool)
@@ -346,8 +347,8 @@ func TestWebSocketEventSender(t *testing.T) {
346347

347348
ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitShort))
348349
req := newBaseRequest(ctx)
349-
writer := newWebsocketWriter()
350-
send, done, err := httpapi.WebSocketEventSender(writer, req)
350+
writer := newMockEventSenderWriter()
351+
send, done, err := httpapi.OneWayWebSocketEventSender(writer, req)
351352
require.NoError(t, err)
352353

353354
successC := make(chan bool)
@@ -387,8 +388,8 @@ func TestWebSocketEventSender(t *testing.T) {
387388

388389
ctx := testutil.Context(t, timeout)
389390
req := newBaseRequest(ctx)
390-
writer := newWebsocketWriter()
391-
_, _, err := httpapi.WebSocketEventSender(writer, req)
391+
writer := newMockEventSenderWriter()
392+
_, _, err := httpapi.OneWayWebSocketEventSender(writer, req)
392393
require.NoError(t, err)
393394

394395
type Result struct {
@@ -420,3 +421,73 @@ func TestWebSocketEventSender(t *testing.T) {
420421
require.True(t, result.Success)
421422
})
422423
}
424+
425+
func TestServerSentEventSender(t *testing.T) {
426+
t.Parallel()
427+
428+
newBaseRequest := func(ctx context.Context) *http.Request {
429+
url := "ws://www.fake-website.com/logs"
430+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
431+
require.NoError(t, err)
432+
return req
433+
}
434+
435+
t.Run("Mutates response headers to support SSE connections", func(t *testing.T) {
436+
t.Parallel()
437+
438+
ctx := testutil.Context(t, testutil.WaitShort)
439+
req := newBaseRequest(ctx)
440+
writer := newMockEventSenderWriter()
441+
_, _, err := httpapi.ServerSentEventSender(writer, req)
442+
require.NoError(t, err)
443+
444+
h := writer.Header()
445+
require.Equal(t, h.Get("Content-Type"), "text/event-stream")
446+
require.Equal(t, h.Get("Cache-Control"), "no-cache")
447+
require.Equal(t, h.Get("Connection"), "keep-alive")
448+
require.Equal(t, h.Get("X-Accel-Buffering"), "no")
449+
})
450+
451+
t.Run("Returned callback can publish new event to SSE connection", func(t *testing.T) {
452+
t.Parallel()
453+
454+
ctx := testutil.Context(t, testutil.WaitShort)
455+
req := newBaseRequest(ctx)
456+
writer := newMockEventSenderWriter()
457+
send, _, err := httpapi.ServerSentEventSender(writer, req)
458+
require.NoError(t, err)
459+
460+
serverPayload := codersdk.ServerSentEvent{
461+
Type: codersdk.ServerSentEventTypeData,
462+
Data: "Blah",
463+
}
464+
err = send(serverPayload)
465+
require.NoError(t, err)
466+
467+
// The client connection will receive a little bit of additional data on
468+
// top of the main payload. Have to make sure check has tolerance for
469+
// extra data being present
470+
serverBytes, err := json.Marshal(serverPayload)
471+
require.NoError(t, err)
472+
473+
// This is the part that's breaking
474+
clientBytes, err := io.ReadAll(writer.clientConn)
475+
require.NoError(t, err)
476+
require.True(t, bytes.Contains(clientBytes, serverBytes))
477+
})
478+
479+
t.Run("Signals to outside consumer when connection has been closed", func(t *testing.T) {
480+
t.Parallel()
481+
t.FailNow()
482+
})
483+
484+
t.Run("Cancels the entire connection if the request context cancels", func(t *testing.T) {
485+
t.Parallel()
486+
t.FailNow()
487+
})
488+
489+
t.Run("Sends a heartbeat to the client on a fixed internal of time to keep connections alive", func(t *testing.T) {
490+
t.Parallel()
491+
t.FailNow()
492+
})
493+
}

coderd/workspaceagents.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -1109,7 +1109,7 @@ func (api *API) watchWorkspaceAgentMetadataSSE(rw http.ResponseWriter, r *http.R
11091109
// @Router /workspaceagents/{workspaceagent}/watch-metadata-ws [get]
11101110
// @x-apidocgen {"skip": true}
11111111
func (api *API) watchWorkspaceAgentMetadataWS(rw http.ResponseWriter, r *http.Request) {
1112-
api.watchWorkspaceAgentMetadata(rw, r, httpapi.WebSocketEventSender)
1112+
api.watchWorkspaceAgentMetadata(rw, r, httpapi.OneWayWebSocketEventSender)
11131113
}
11141114

11151115
func (api *API) watchWorkspaceAgentMetadata(

coderd/workspaces.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -1732,7 +1732,7 @@ func (api *API) watchWorkspaceSSE(rw http.ResponseWriter, r *http.Request) {
17321732
// @Success 200 {object} codersdk.ServerSentEvent
17331733
// @Router /workspaces/{workspace}/watch-ws [get]
17341734
func (api *API) watchWorkspaceWS(rw http.ResponseWriter, r *http.Request) {
1735-
api.watchWorkspace(rw, r, httpapi.WebSocketEventSender)
1735+
api.watchWorkspace(rw, r, httpapi.OneWayWebSocketEventSender)
17361736
}
17371737

17381738
func (api *API) watchWorkspace(

0 commit comments

Comments
 (0)