Skip to content

chore: add support for one-way websockets to backend #16853

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Mar 28, 2025
Prev Previous commit
Next Next commit
wip: commit more test progress
  • Loading branch information
Parkreiner committed Mar 19, 2025
commit 792aa2dc6e3799e1a0e703103fd74eba9448fbde
170 changes: 119 additions & 51 deletions coderd/httpapi/httpapi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,31 +163,32 @@ func TestWebsocketCloseMsg(t *testing.T) {

// Our WebSocket library accepts any arbitrary ResponseWriter at the type level,
// but the writer must also implement http.Hijacker for long-lived connections.
// The SSE version only requires http.Flusher (no need for the Hijack method).
type mockEventSenderResponseWriter struct {
type mockOneWaySocketWriter struct {
serverRecorder *httptest.ResponseRecorder
serverConn net.Conn
clientConn net.Conn
serverReadWriter *bufio.ReadWriter
testContext *testing.T
}

func (m mockEventSenderResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
func (m mockOneWaySocketWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return m.serverConn, m.serverReadWriter, nil
}

func (m mockEventSenderResponseWriter) Flush() {
_ = m.serverReadWriter.Flush()
func (m mockOneWaySocketWriter) Flush() {
err := m.serverReadWriter.Flush()
require.NoError(m.testContext, err)
}

func (m mockEventSenderResponseWriter) Header() http.Header {
func (m mockOneWaySocketWriter) Header() http.Header {
return m.serverRecorder.Header()
}

func (m mockEventSenderResponseWriter) Write(b []byte) (int, error) {
func (m mockOneWaySocketWriter) Write(b []byte) (int, error) {
return m.serverReadWriter.Write(b)
}

func (m mockEventSenderResponseWriter) WriteHeader(code int) {
func (m mockOneWaySocketWriter) WriteHeader(code int) {
m.serverRecorder.WriteHeader(code)
}

Expand All @@ -197,30 +198,6 @@ func (w mockEventSenderWrite) Write(b []byte) (int, error) {
return w(b)
}

func newMockEventSenderWriter() mockEventSenderResponseWriter {
mockServer, mockClient := net.Pipe()
recorder := httptest.NewRecorder()

var write mockEventSenderWrite = func(b []byte) (int, error) {
serverCount, err := mockServer.Write(b)
if err != nil {
return serverCount, err
}
recorderCount, err := recorder.Write(b)
return min(serverCount, recorderCount), err
}

return mockEventSenderResponseWriter{
serverConn: mockServer,
clientConn: mockClient,
serverRecorder: recorder,
serverReadWriter: bufio.NewReadWriter(
bufio.NewReader(mockServer),
bufio.NewWriter(write),
),
}
}

func TestOneWayWebSocketEventSender(t *testing.T) {
t.Parallel()

Expand All @@ -238,6 +215,34 @@ func TestOneWayWebSocketEventSender(t *testing.T) {
return req
}

newOneWayWriter := func(t *testing.T) mockOneWaySocketWriter {
mockServer, mockClient := net.Pipe()
recorder := httptest.NewRecorder()

var write mockEventSenderWrite = func(b []byte) (int, error) {
serverCount, err := mockServer.Write(b)
if err != nil {
return 0, err
}
recorderCount, err := recorder.Write(b)
if err != nil {
return 0, err
}
return min(serverCount, recorderCount), nil
}

return mockOneWaySocketWriter{
testContext: t,
serverConn: mockServer,
clientConn: mockClient,
serverRecorder: recorder,
serverReadWriter: bufio.NewReadWriter(
bufio.NewReader(mockServer),
bufio.NewWriter(write),
),
}
}

t.Run("Produces error if the socket connection could not be established", func(t *testing.T) {
t.Parallel()

Expand All @@ -256,7 +261,7 @@ func TestOneWayWebSocketEventSender(t *testing.T) {
req.ProtoMinor = p.minor
req.Proto = p.proto

writer := newMockEventSenderWriter()
writer := newOneWayWriter(t)
_, _, err := httpapi.OneWayWebSocketEventSender(writer, req)
require.ErrorContains(t, err, p.proto)
}
Expand All @@ -267,7 +272,7 @@ func TestOneWayWebSocketEventSender(t *testing.T) {

ctx := testutil.Context(t, testutil.WaitShort)
req := newBaseRequest(ctx)
writer := newMockEventSenderWriter()
writer := newOneWayWriter(t)
send, _, err := httpapi.OneWayWebSocketEventSender(writer, req)
require.NoError(t, err)

Expand All @@ -293,7 +298,7 @@ func TestOneWayWebSocketEventSender(t *testing.T) {

ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitShort))
req := newBaseRequest(ctx)
writer := newMockEventSenderWriter()
writer := newOneWayWriter(t)
_, done, err := httpapi.OneWayWebSocketEventSender(writer, req)
require.NoError(t, err)

Expand All @@ -317,7 +322,7 @@ func TestOneWayWebSocketEventSender(t *testing.T) {

ctx := testutil.Context(t, testutil.WaitShort)
req := newBaseRequest(ctx)
writer := newMockEventSenderWriter()
writer := newOneWayWriter(t)
_, done, err := httpapi.OneWayWebSocketEventSender(writer, req)
require.NoError(t, err)

Expand Down Expand Up @@ -347,7 +352,7 @@ func TestOneWayWebSocketEventSender(t *testing.T) {

ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitShort))
req := newBaseRequest(ctx)
writer := newMockEventSenderWriter()
writer := newOneWayWriter(t)
send, done, err := httpapi.OneWayWebSocketEventSender(writer, req)
require.NoError(t, err)

Expand Down Expand Up @@ -388,7 +393,7 @@ func TestOneWayWebSocketEventSender(t *testing.T) {

ctx := testutil.Context(t, timeout)
req := newBaseRequest(ctx)
writer := newMockEventSenderWriter()
writer := newOneWayWriter(t)
_, _, err := httpapi.OneWayWebSocketEventSender(writer, req)
require.NoError(t, err)

Expand Down Expand Up @@ -422,6 +427,42 @@ func TestOneWayWebSocketEventSender(t *testing.T) {
})
}

// ServerSentEventSender accepts any arbitrary ResponseWriter at the type level,
// but the writer must also implement http.Flusher for long-lived connections
type mockServerSentWriter struct {
serverRecorder *httptest.ResponseRecorder
serverConn net.Conn
clientConn net.Conn
buffer *bytes.Buffer
testContext *testing.T
}

func (m mockServerSentWriter) Flush() {
b := m.buffer.Bytes()
_, err := m.serverConn.Write(b)
require.NoError(m.testContext, err)
m.buffer.Reset()

// Must close server connection to indicate EOF for any reads from the
// client connection; otherwise reads block forever. This is a testing
// limitation compared to the one-way websockets, since we have no way to
// frame the data and auto-indicate EOF for each message
err = m.serverConn.Close()
require.NoError(m.testContext, err)
}

func (m mockServerSentWriter) Header() http.Header {
return m.serverRecorder.Header()
}

func (m mockServerSentWriter) Write(b []byte) (int, error) {
return m.buffer.Write(b)
}

func (m mockServerSentWriter) WriteHeader(code int) {
m.serverRecorder.WriteHeader(code)
}

func TestServerSentEventSender(t *testing.T) {
t.Parallel()

Expand All @@ -432,12 +473,23 @@ func TestServerSentEventSender(t *testing.T) {
return req
}

newServerSentWriter := func(t *testing.T) mockServerSentWriter {
mockServer, mockClient := net.Pipe()
return mockServerSentWriter{
testContext: t,
serverRecorder: httptest.NewRecorder(),
clientConn: mockClient,
serverConn: mockServer,
buffer: &bytes.Buffer{},
}
}

t.Run("Mutates response headers to support SSE connections", func(t *testing.T) {
t.Parallel()

ctx := testutil.Context(t, testutil.WaitShort)
req := newBaseRequest(ctx)
writer := newMockEventSenderWriter()
writer := newServerSentWriter(t)
_, _, err := httpapi.ServerSentEventSender(writer, req)
require.NoError(t, err)

Expand All @@ -453,7 +505,7 @@ func TestServerSentEventSender(t *testing.T) {

ctx := testutil.Context(t, testutil.WaitShort)
req := newBaseRequest(ctx)
writer := newMockEventSenderWriter()
writer := newServerSentWriter(t)
send, _, err := httpapi.ServerSentEventSender(writer, req)
require.NoError(t, err)

Expand All @@ -464,30 +516,46 @@ func TestServerSentEventSender(t *testing.T) {
err = send(serverPayload)
require.NoError(t, err)

// The client connection will receive a little bit of additional data on
// top of the main payload. Have to make sure check has tolerance for
// extra data being present
serverBytes, err := json.Marshal(serverPayload)
require.NoError(t, err)

// This is the part that's breaking
clientBytes, err := io.ReadAll(writer.clientConn)
require.NoError(t, err)
require.True(t, bytes.Contains(clientBytes, serverBytes))
require.Equal(
t,
string(clientBytes),
"event: data\ndata: \"Blah\"\n\n",
)
})

t.Run("Signals to outside consumer when connection has been closed", func(t *testing.T) {
t.Parallel()
t.FailNow()

ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitShort))
req := newBaseRequest(ctx)
writer := newServerSentWriter(t)
_, done, err := httpapi.ServerSentEventSender(writer, req)
require.NoError(t, err)

successC := make(chan bool)
ticker := time.NewTicker(testutil.WaitShort)
go func() {
select {
case <-done:
successC <- true
case <-ticker.C:
successC <- false
}
}()

cancel()
require.True(t, <-successC)
})

t.Run("Cancels the entire connection if the request context cancels", func(t *testing.T) {
t.Parallel()
t.FailNow()
t.Parallel()
})

t.Run("Sends a heartbeat to the client on a fixed internal of time to keep connections alive", func(t *testing.T) {
t.Parallel()
t.FailNow()
t.Parallel()
})
}