Skip to content

Commit 792aa2d

Browse files
committed
wip: commit more test progress
1 parent c7d95d9 commit 792aa2d

File tree

1 file changed

+119
-51
lines changed

1 file changed

+119
-51
lines changed

coderd/httpapi/httpapi_test.go

+119-51
Original file line numberDiff line numberDiff line change
@@ -163,31 +163,32 @@ func TestWebsocketCloseMsg(t *testing.T) {
163163

164164
// Our WebSocket library accepts any arbitrary ResponseWriter at the type level,
165165
// but the writer must also implement http.Hijacker for long-lived connections.
166-
// The SSE version only requires http.Flusher (no need for the Hijack method).
167-
type mockEventSenderResponseWriter struct {
166+
type mockOneWaySocketWriter struct {
168167
serverRecorder *httptest.ResponseRecorder
169168
serverConn net.Conn
170169
clientConn net.Conn
171170
serverReadWriter *bufio.ReadWriter
171+
testContext *testing.T
172172
}
173173

174-
func (m mockEventSenderResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
174+
func (m mockOneWaySocketWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
175175
return m.serverConn, m.serverReadWriter, nil
176176
}
177177

178-
func (m mockEventSenderResponseWriter) Flush() {
179-
_ = m.serverReadWriter.Flush()
178+
func (m mockOneWaySocketWriter) Flush() {
179+
err := m.serverReadWriter.Flush()
180+
require.NoError(m.testContext, err)
180181
}
181182

182-
func (m mockEventSenderResponseWriter) Header() http.Header {
183+
func (m mockOneWaySocketWriter) Header() http.Header {
183184
return m.serverRecorder.Header()
184185
}
185186

186-
func (m mockEventSenderResponseWriter) Write(b []byte) (int, error) {
187+
func (m mockOneWaySocketWriter) Write(b []byte) (int, error) {
187188
return m.serverReadWriter.Write(b)
188189
}
189190

190-
func (m mockEventSenderResponseWriter) WriteHeader(code int) {
191+
func (m mockOneWaySocketWriter) WriteHeader(code int) {
191192
m.serverRecorder.WriteHeader(code)
192193
}
193194

@@ -197,30 +198,6 @@ func (w mockEventSenderWrite) Write(b []byte) (int, error) {
197198
return w(b)
198199
}
199200

200-
func newMockEventSenderWriter() mockEventSenderResponseWriter {
201-
mockServer, mockClient := net.Pipe()
202-
recorder := httptest.NewRecorder()
203-
204-
var write mockEventSenderWrite = func(b []byte) (int, error) {
205-
serverCount, err := mockServer.Write(b)
206-
if err != nil {
207-
return serverCount, err
208-
}
209-
recorderCount, err := recorder.Write(b)
210-
return min(serverCount, recorderCount), err
211-
}
212-
213-
return mockEventSenderResponseWriter{
214-
serverConn: mockServer,
215-
clientConn: mockClient,
216-
serverRecorder: recorder,
217-
serverReadWriter: bufio.NewReadWriter(
218-
bufio.NewReader(mockServer),
219-
bufio.NewWriter(write),
220-
),
221-
}
222-
}
223-
224201
func TestOneWayWebSocketEventSender(t *testing.T) {
225202
t.Parallel()
226203

@@ -238,6 +215,34 @@ func TestOneWayWebSocketEventSender(t *testing.T) {
238215
return req
239216
}
240217

218+
newOneWayWriter := func(t *testing.T) mockOneWaySocketWriter {
219+
mockServer, mockClient := net.Pipe()
220+
recorder := httptest.NewRecorder()
221+
222+
var write mockEventSenderWrite = func(b []byte) (int, error) {
223+
serverCount, err := mockServer.Write(b)
224+
if err != nil {
225+
return 0, err
226+
}
227+
recorderCount, err := recorder.Write(b)
228+
if err != nil {
229+
return 0, err
230+
}
231+
return min(serverCount, recorderCount), nil
232+
}
233+
234+
return mockOneWaySocketWriter{
235+
testContext: t,
236+
serverConn: mockServer,
237+
clientConn: mockClient,
238+
serverRecorder: recorder,
239+
serverReadWriter: bufio.NewReadWriter(
240+
bufio.NewReader(mockServer),
241+
bufio.NewWriter(write),
242+
),
243+
}
244+
}
245+
241246
t.Run("Produces error if the socket connection could not be established", func(t *testing.T) {
242247
t.Parallel()
243248

@@ -256,7 +261,7 @@ func TestOneWayWebSocketEventSender(t *testing.T) {
256261
req.ProtoMinor = p.minor
257262
req.Proto = p.proto
258263

259-
writer := newMockEventSenderWriter()
264+
writer := newOneWayWriter(t)
260265
_, _, err := httpapi.OneWayWebSocketEventSender(writer, req)
261266
require.ErrorContains(t, err, p.proto)
262267
}
@@ -267,7 +272,7 @@ func TestOneWayWebSocketEventSender(t *testing.T) {
267272

268273
ctx := testutil.Context(t, testutil.WaitShort)
269274
req := newBaseRequest(ctx)
270-
writer := newMockEventSenderWriter()
275+
writer := newOneWayWriter(t)
271276
send, _, err := httpapi.OneWayWebSocketEventSender(writer, req)
272277
require.NoError(t, err)
273278

@@ -293,7 +298,7 @@ func TestOneWayWebSocketEventSender(t *testing.T) {
293298

294299
ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitShort))
295300
req := newBaseRequest(ctx)
296-
writer := newMockEventSenderWriter()
301+
writer := newOneWayWriter(t)
297302
_, done, err := httpapi.OneWayWebSocketEventSender(writer, req)
298303
require.NoError(t, err)
299304

@@ -317,7 +322,7 @@ func TestOneWayWebSocketEventSender(t *testing.T) {
317322

318323
ctx := testutil.Context(t, testutil.WaitShort)
319324
req := newBaseRequest(ctx)
320-
writer := newMockEventSenderWriter()
325+
writer := newOneWayWriter(t)
321326
_, done, err := httpapi.OneWayWebSocketEventSender(writer, req)
322327
require.NoError(t, err)
323328

@@ -347,7 +352,7 @@ func TestOneWayWebSocketEventSender(t *testing.T) {
347352

348353
ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitShort))
349354
req := newBaseRequest(ctx)
350-
writer := newMockEventSenderWriter()
355+
writer := newOneWayWriter(t)
351356
send, done, err := httpapi.OneWayWebSocketEventSender(writer, req)
352357
require.NoError(t, err)
353358

@@ -388,7 +393,7 @@ func TestOneWayWebSocketEventSender(t *testing.T) {
388393

389394
ctx := testutil.Context(t, timeout)
390395
req := newBaseRequest(ctx)
391-
writer := newMockEventSenderWriter()
396+
writer := newOneWayWriter(t)
392397
_, _, err := httpapi.OneWayWebSocketEventSender(writer, req)
393398
require.NoError(t, err)
394399

@@ -422,6 +427,42 @@ func TestOneWayWebSocketEventSender(t *testing.T) {
422427
})
423428
}
424429

430+
// ServerSentEventSender accepts any arbitrary ResponseWriter at the type level,
431+
// but the writer must also implement http.Flusher for long-lived connections
432+
type mockServerSentWriter struct {
433+
serverRecorder *httptest.ResponseRecorder
434+
serverConn net.Conn
435+
clientConn net.Conn
436+
buffer *bytes.Buffer
437+
testContext *testing.T
438+
}
439+
440+
func (m mockServerSentWriter) Flush() {
441+
b := m.buffer.Bytes()
442+
_, err := m.serverConn.Write(b)
443+
require.NoError(m.testContext, err)
444+
m.buffer.Reset()
445+
446+
// Must close server connection to indicate EOF for any reads from the
447+
// client connection; otherwise reads block forever. This is a testing
448+
// limitation compared to the one-way websockets, since we have no way to
449+
// frame the data and auto-indicate EOF for each message
450+
err = m.serverConn.Close()
451+
require.NoError(m.testContext, err)
452+
}
453+
454+
func (m mockServerSentWriter) Header() http.Header {
455+
return m.serverRecorder.Header()
456+
}
457+
458+
func (m mockServerSentWriter) Write(b []byte) (int, error) {
459+
return m.buffer.Write(b)
460+
}
461+
462+
func (m mockServerSentWriter) WriteHeader(code int) {
463+
m.serverRecorder.WriteHeader(code)
464+
}
465+
425466
func TestServerSentEventSender(t *testing.T) {
426467
t.Parallel()
427468

@@ -432,12 +473,23 @@ func TestServerSentEventSender(t *testing.T) {
432473
return req
433474
}
434475

476+
newServerSentWriter := func(t *testing.T) mockServerSentWriter {
477+
mockServer, mockClient := net.Pipe()
478+
return mockServerSentWriter{
479+
testContext: t,
480+
serverRecorder: httptest.NewRecorder(),
481+
clientConn: mockClient,
482+
serverConn: mockServer,
483+
buffer: &bytes.Buffer{},
484+
}
485+
}
486+
435487
t.Run("Mutates response headers to support SSE connections", func(t *testing.T) {
436488
t.Parallel()
437489

438490
ctx := testutil.Context(t, testutil.WaitShort)
439491
req := newBaseRequest(ctx)
440-
writer := newMockEventSenderWriter()
492+
writer := newServerSentWriter(t)
441493
_, _, err := httpapi.ServerSentEventSender(writer, req)
442494
require.NoError(t, err)
443495

@@ -453,7 +505,7 @@ func TestServerSentEventSender(t *testing.T) {
453505

454506
ctx := testutil.Context(t, testutil.WaitShort)
455507
req := newBaseRequest(ctx)
456-
writer := newMockEventSenderWriter()
508+
writer := newServerSentWriter(t)
457509
send, _, err := httpapi.ServerSentEventSender(writer, req)
458510
require.NoError(t, err)
459511

@@ -464,30 +516,46 @@ func TestServerSentEventSender(t *testing.T) {
464516
err = send(serverPayload)
465517
require.NoError(t, err)
466518

467-
// The client connection will receive a little bit of additional data on
468-
// top of the main payload. Have to make sure check has tolerance for
469-
// extra data being present
470-
serverBytes, err := json.Marshal(serverPayload)
471-
require.NoError(t, err)
472-
473-
// This is the part that's breaking
474519
clientBytes, err := io.ReadAll(writer.clientConn)
475520
require.NoError(t, err)
476-
require.True(t, bytes.Contains(clientBytes, serverBytes))
521+
require.Equal(
522+
t,
523+
string(clientBytes),
524+
"event: data\ndata: \"Blah\"\n\n",
525+
)
477526
})
478527

479528
t.Run("Signals to outside consumer when connection has been closed", func(t *testing.T) {
480529
t.Parallel()
481-
t.FailNow()
530+
531+
ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitShort))
532+
req := newBaseRequest(ctx)
533+
writer := newServerSentWriter(t)
534+
_, done, err := httpapi.ServerSentEventSender(writer, req)
535+
require.NoError(t, err)
536+
537+
successC := make(chan bool)
538+
ticker := time.NewTicker(testutil.WaitShort)
539+
go func() {
540+
select {
541+
case <-done:
542+
successC <- true
543+
case <-ticker.C:
544+
successC <- false
545+
}
546+
}()
547+
548+
cancel()
549+
require.True(t, <-successC)
482550
})
483551

484552
t.Run("Cancels the entire connection if the request context cancels", func(t *testing.T) {
485-
t.Parallel()
486553
t.FailNow()
554+
t.Parallel()
487555
})
488556

489557
t.Run("Sends a heartbeat to the client on a fixed internal of time to keep connections alive", func(t *testing.T) {
490-
t.Parallel()
491558
t.FailNow()
559+
t.Parallel()
492560
})
493561
}

0 commit comments

Comments
 (0)