Skip to content

chore: add support for one-way websockets to backend #16853

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Mar 28, 2025
Prev Previous commit
Next Next commit
wip: commit progress on tests
  • Loading branch information
Parkreiner committed Mar 19, 2025
commit c7d95d97be2067643ac1f70e94eab60243ce726f
21 changes: 7 additions & 14 deletions coderd/httpapi/httpapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,16 +326,13 @@ func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) (
// Synchronized handling of events (no guarantee of order).
go func() {
defer close(closed)

// Send a heartbeat every 15 seconds to avoid the connection being killed.
ticker := time.NewTicker(time.Second * 15)
ticker := time.NewTicker(HeartbeatInterval)
defer ticker.Stop()

for {
var event sseEvent

select {
case <-r.Context().Done():
case <-ctx.Done():
return
case event = <-eventC:
case <-ticker.C:
Expand All @@ -357,8 +354,6 @@ func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) (

sendEvent := func(newEvent codersdk.ServerSentEvent) error {
buf := &bytes.Buffer{}
enc := json.NewEncoder(buf)

_, err := buf.WriteString(fmt.Sprintf("event: %s\n", newEvent.Type))
if err != nil {
return err
Expand All @@ -369,6 +364,8 @@ func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) (
if err != nil {
return err
}

enc := json.NewEncoder(buf)
err = enc.Encode(newEvent.Data)
if err != nil {
return err
Expand All @@ -386,8 +383,6 @@ func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) (
}

select {
case <-r.Context().Done():
return r.Context().Err()
case <-ctx.Done():
return ctx.Err()
case <-closed:
Expand All @@ -397,8 +392,6 @@ func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) (
// for early exit. We don't check closed here because it
// can't happen while processing the event.
select {
case <-r.Context().Done():
return r.Context().Err()
case <-ctx.Done():
return ctx.Err()
case err := <-event.errC:
Expand All @@ -410,8 +403,8 @@ func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) (
return sendEvent, closed, nil
}

// WebSocketEventSender establishes a new WebSocket connection that enforces
// one-way communication from the server to the client.
// OneWayWebSocketEventSender establishes a new WebSocket connection that
// enforces one-way communication from the server to the client.
//
// The function returned allows you to send a single message to the client,
// while the channel lets you listen for when the connection closes.
Expand All @@ -422,7 +415,7 @@ func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) (
// open a workspace in multiple tabs, the entire UI can start to lock up.
// WebSockets have no such limitation, no matter what HTTP protocol was used to
// establish the connection.
func WebSocketEventSender(rw http.ResponseWriter, r *http.Request) (
func OneWayWebSocketEventSender(rw http.ResponseWriter, r *http.Request) (
func(event codersdk.ServerSentEvent) error,
<-chan struct{},
error,
Expand Down
163 changes: 117 additions & 46 deletions coderd/httpapi/httpapi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,41 +162,66 @@ func TestWebsocketCloseMsg(t *testing.T) {
}

// Our WebSocket library accepts any arbitrary ResponseWriter at the type level,
// but the writer must also implement http.Hijacker for long-lived connections
type mockWsResponseWriter struct {
// but the writer must also implement http.Hijacker for long-lived connections.
// The SSE version only requires http.Flusher (no need for the Hijack method).
type mockEventSenderResponseWriter struct {
serverRecorder *httptest.ResponseRecorder
serverConn net.Conn
clientConn net.Conn
serverReadWriter *bufio.ReadWriter
}

func (m mockWsResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
func (m mockEventSenderResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return m.serverConn, m.serverReadWriter, nil
}

func (m mockWsResponseWriter) Flush() {
func (m mockEventSenderResponseWriter) Flush() {
_ = m.serverReadWriter.Flush()
}

func (m mockWsResponseWriter) Header() http.Header {
func (m mockEventSenderResponseWriter) Header() http.Header {
return m.serverRecorder.Header()
}

func (m mockWsResponseWriter) Write(b []byte) (int, error) {
func (m mockEventSenderResponseWriter) Write(b []byte) (int, error) {
return m.serverReadWriter.Write(b)
}

func (m mockWsResponseWriter) WriteHeader(code int) {
func (m mockEventSenderResponseWriter) WriteHeader(code int) {
m.serverRecorder.WriteHeader(code)
}

type mockWsWrite func(b []byte) (int, error)
type mockEventSenderWrite func(b []byte) (int, error)

func (w mockWsWrite) Write(b []byte) (int, error) {
func (w mockEventSenderWrite) Write(b []byte) (int, error) {
return w(b)
}

func TestWebSocketEventSender(t *testing.T) {
func newMockEventSenderWriter() mockEventSenderResponseWriter {
mockServer, mockClient := net.Pipe()
recorder := httptest.NewRecorder()

var write mockEventSenderWrite = func(b []byte) (int, error) {
serverCount, err := mockServer.Write(b)
if err != nil {
return serverCount, err
}
recorderCount, err := recorder.Write(b)
return min(serverCount, recorderCount), err
}

return mockEventSenderResponseWriter{
serverConn: mockServer,
clientConn: mockClient,
serverRecorder: recorder,
serverReadWriter: bufio.NewReadWriter(
bufio.NewReader(mockServer),
bufio.NewWriter(write),
),
}
}

func TestOneWayWebSocketEventSender(t *testing.T) {
t.Parallel()

newBaseRequest := func(ctx context.Context) *http.Request {
Expand All @@ -213,30 +238,6 @@ func TestWebSocketEventSender(t *testing.T) {
return req
}

newWebsocketWriter := func() mockWsResponseWriter {
mockServer, mockClient := net.Pipe()
recorder := httptest.NewRecorder()

var write mockWsWrite = func(b []byte) (int, error) {
serverCount, err := mockServer.Write(b)
if err != nil {
return serverCount, err
}
recorderCount, err := recorder.Write(b)
return min(serverCount, recorderCount), err
}

return mockWsResponseWriter{
serverConn: mockServer,
clientConn: mockClient,
serverRecorder: recorder,
serverReadWriter: bufio.NewReadWriter(
bufio.NewReader(mockServer),
bufio.NewWriter(write),
),
}
}

t.Run("Produces error if the socket connection could not be established", func(t *testing.T) {
t.Parallel()

Expand All @@ -255,8 +256,8 @@ func TestWebSocketEventSender(t *testing.T) {
req.ProtoMinor = p.minor
req.Proto = p.proto

writer := newWebsocketWriter()
_, _, err := httpapi.WebSocketEventSender(writer, req)
writer := newMockEventSenderWriter()
_, _, err := httpapi.OneWayWebSocketEventSender(writer, req)
require.ErrorContains(t, err, p.proto)
}
})
Expand All @@ -266,8 +267,8 @@ func TestWebSocketEventSender(t *testing.T) {

ctx := testutil.Context(t, testutil.WaitShort)
req := newBaseRequest(ctx)
writer := newWebsocketWriter()
send, _, err := httpapi.WebSocketEventSender(writer, req)
writer := newMockEventSenderWriter()
send, _, err := httpapi.OneWayWebSocketEventSender(writer, req)
require.NoError(t, err)

serverPayload := codersdk.ServerSentEvent{
Expand All @@ -292,8 +293,8 @@ func TestWebSocketEventSender(t *testing.T) {

ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitShort))
req := newBaseRequest(ctx)
writer := newWebsocketWriter()
_, done, err := httpapi.WebSocketEventSender(writer, req)
writer := newMockEventSenderWriter()
_, done, err := httpapi.OneWayWebSocketEventSender(writer, req)
require.NoError(t, err)

successC := make(chan bool)
Expand All @@ -316,8 +317,8 @@ func TestWebSocketEventSender(t *testing.T) {

ctx := testutil.Context(t, testutil.WaitShort)
req := newBaseRequest(ctx)
writer := newWebsocketWriter()
_, done, err := httpapi.WebSocketEventSender(writer, req)
writer := newMockEventSenderWriter()
_, done, err := httpapi.OneWayWebSocketEventSender(writer, req)
require.NoError(t, err)

successC := make(chan bool)
Expand Down Expand Up @@ -346,8 +347,8 @@ func TestWebSocketEventSender(t *testing.T) {

ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitShort))
req := newBaseRequest(ctx)
writer := newWebsocketWriter()
send, done, err := httpapi.WebSocketEventSender(writer, req)
writer := newMockEventSenderWriter()
send, done, err := httpapi.OneWayWebSocketEventSender(writer, req)
require.NoError(t, err)

successC := make(chan bool)
Expand Down Expand Up @@ -387,8 +388,8 @@ func TestWebSocketEventSender(t *testing.T) {

ctx := testutil.Context(t, timeout)
req := newBaseRequest(ctx)
writer := newWebsocketWriter()
_, _, err := httpapi.WebSocketEventSender(writer, req)
writer := newMockEventSenderWriter()
_, _, err := httpapi.OneWayWebSocketEventSender(writer, req)
require.NoError(t, err)

type Result struct {
Expand Down Expand Up @@ -420,3 +421,73 @@ func TestWebSocketEventSender(t *testing.T) {
require.True(t, result.Success)
})
}

func TestServerSentEventSender(t *testing.T) {
t.Parallel()

newBaseRequest := func(ctx context.Context) *http.Request {
url := "ws://www.fake-website.com/logs"
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
require.NoError(t, err)
return req
}

t.Run("Mutates response headers to support SSE connections", func(t *testing.T) {
t.Parallel()

ctx := testutil.Context(t, testutil.WaitShort)
req := newBaseRequest(ctx)
writer := newMockEventSenderWriter()
_, _, err := httpapi.ServerSentEventSender(writer, req)
require.NoError(t, err)

h := writer.Header()
require.Equal(t, h.Get("Content-Type"), "text/event-stream")
require.Equal(t, h.Get("Cache-Control"), "no-cache")
require.Equal(t, h.Get("Connection"), "keep-alive")
require.Equal(t, h.Get("X-Accel-Buffering"), "no")
})

t.Run("Returned callback can publish new event to SSE connection", func(t *testing.T) {
t.Parallel()

ctx := testutil.Context(t, testutil.WaitShort)
req := newBaseRequest(ctx)
writer := newMockEventSenderWriter()
send, _, err := httpapi.ServerSentEventSender(writer, req)
require.NoError(t, err)

serverPayload := codersdk.ServerSentEvent{
Type: codersdk.ServerSentEventTypeData,
Data: "Blah",
}
err = send(serverPayload)
require.NoError(t, err)

// The client connection will receive a little bit of additional data on
// top of the main payload. Have to make sure check has tolerance for
// extra data being present
serverBytes, err := json.Marshal(serverPayload)
require.NoError(t, err)

// This is the part that's breaking
clientBytes, err := io.ReadAll(writer.clientConn)
require.NoError(t, err)
require.True(t, bytes.Contains(clientBytes, serverBytes))
})

t.Run("Signals to outside consumer when connection has been closed", func(t *testing.T) {
t.Parallel()
t.FailNow()
})

t.Run("Cancels the entire connection if the request context cancels", func(t *testing.T) {
t.Parallel()
t.FailNow()
})

t.Run("Sends a heartbeat to the client on a fixed internal of time to keep connections alive", func(t *testing.T) {
t.Parallel()
t.FailNow()
})
}
2 changes: 1 addition & 1 deletion coderd/workspaceagents.go
Original file line number Diff line number Diff line change
Expand Up @@ -1109,7 +1109,7 @@ func (api *API) watchWorkspaceAgentMetadataSSE(rw http.ResponseWriter, r *http.R
// @Router /workspaceagents/{workspaceagent}/watch-metadata-ws [get]
// @x-apidocgen {"skip": true}
func (api *API) watchWorkspaceAgentMetadataWS(rw http.ResponseWriter, r *http.Request) {
api.watchWorkspaceAgentMetadata(rw, r, httpapi.WebSocketEventSender)
api.watchWorkspaceAgentMetadata(rw, r, httpapi.OneWayWebSocketEventSender)
}

func (api *API) watchWorkspaceAgentMetadata(
Expand Down
2 changes: 1 addition & 1 deletion coderd/workspaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -1732,7 +1732,7 @@ func (api *API) watchWorkspaceSSE(rw http.ResponseWriter, r *http.Request) {
// @Success 200 {object} codersdk.ServerSentEvent
// @Router /workspaces/{workspace}/watch-ws [get]
func (api *API) watchWorkspaceWS(rw http.ResponseWriter, r *http.Request) {
api.watchWorkspace(rw, r, httpapi.WebSocketEventSender)
api.watchWorkspace(rw, r, httpapi.OneWayWebSocketEventSender)
}

func (api *API) watchWorkspace(
Expand Down