Skip to content

Commit 9d3e2d9

Browse files
committed
chore: add support for one-way websockets to backend
1 parent ec11f11 commit 9d3e2d9

File tree

6 files changed

+456
-32
lines changed

6 files changed

+456
-32
lines changed

coderd/coderd.go

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -823,7 +823,7 @@ func New(options *Options) *API {
823823
// we do not override subdomain app routes.
824824
r.Get("/latency-check", tracing.StatusWriterMiddleware(prometheusMW(LatencyCheck())).ServeHTTP)
825825

826-
r.Get("/healthz", func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte("OK")) })
826+
r.Get("/healthz", func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("OK")) })
827827

828828
// Attach workspace apps routes.
829829
r.Group(func(r chi.Router) {
@@ -838,7 +838,7 @@ func New(options *Options) *API {
838838
r.Route("/derp", func(r chi.Router) {
839839
r.Get("/", derpHandler.ServeHTTP)
840840
// This is used when UDP is blocked, and latency must be checked via HTTP(s).
841-
r.Get("/latency-check", func(w http.ResponseWriter, r *http.Request) {
841+
r.Get("/latency-check", func(w http.ResponseWriter, _ *http.Request) {
842842
w.WriteHeader(http.StatusOK)
843843
})
844844
})
@@ -895,7 +895,7 @@ func New(options *Options) *API {
895895
r.Route("/api/v2", func(r chi.Router) {
896896
api.APIHandler = r
897897

898-
r.NotFound(func(rw http.ResponseWriter, r *http.Request) { httpapi.RouteNotFound(rw) })
898+
r.NotFound(func(rw http.ResponseWriter, _ *http.Request) { httpapi.RouteNotFound(rw) })
899899
r.Use(
900900
// Specific routes can specify different limits, but every rate
901901
// limit must be configurable by the admin.
@@ -1230,7 +1230,8 @@ func New(options *Options) *API {
12301230
httpmw.ExtractWorkspaceParam(options.Database),
12311231
)
12321232
r.Get("/", api.workspaceAgent)
1233-
r.Get("/watch-metadata", api.watchWorkspaceAgentMetadata)
1233+
r.Get("/watch-metadata", api.watchWorkspaceAgentMetadataSSE)
1234+
r.Get("/watch-metadata-ws", api.watchWorkspaceAgentMetadataWS)
12341235
r.Get("/startup-logs", api.workspaceAgentLogsDeprecated)
12351236
r.Get("/logs", api.workspaceAgentLogs)
12361237
r.Get("/listening-ports", api.workspaceAgentListeningPorts)
@@ -1262,7 +1263,8 @@ func New(options *Options) *API {
12621263
r.Route("/ttl", func(r chi.Router) {
12631264
r.Put("/", api.putWorkspaceTTL)
12641265
})
1265-
r.Get("/watch", api.watchWorkspace)
1266+
r.Get("/watch", api.watchWorkspaceSSE)
1267+
r.Get("/watch-ws", api.watchWorkspaceWS)
12661268
r.Put("/extend", api.putExtendWorkspace)
12671269
r.Post("/usage", api.postWorkspaceUsage)
12681270
r.Put("/dormant", api.putWorkspaceDormant)
@@ -1408,7 +1410,7 @@ func New(options *Options) *API {
14081410
// global variable here.
14091411
r.Get("/swagger/*", globalHTTPSwaggerHandler)
14101412
} else {
1411-
swaggerDisabled := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
1413+
swaggerDisabled := http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
14121414
httpapi.Write(context.Background(), rw, http.StatusNotFound, codersdk.Response{
14131415
Message: "Swagger documentation is disabled.",
14141416
})

coderd/httpapi/httpapi.go

Lines changed: 119 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ import (
1616
"github.com/go-playground/validator/v10"
1717
"golang.org/x/xerrors"
1818

19+
"github.com/coder/websocket"
20+
"github.com/coder/websocket/wsjson"
21+
1922
"github.com/coder/coder/v2/coderd/httpapi/httpapiconstraints"
2023
"github.com/coder/coder/v2/coderd/tracing"
2124
"github.com/coder/coder/v2/codersdk"
@@ -282,7 +285,25 @@ func WebsocketCloseSprintf(format string, vars ...any) string {
282285
return msg
283286
}
284287

285-
func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) (sendEvent func(ctx context.Context, sse codersdk.ServerSentEvent) error, closed chan struct{}, err error) {
288+
type InitializeConnectionCallback func(rw http.ResponseWriter, r *http.Request) (
289+
sendEvent func(sse codersdk.ServerSentEvent) error,
290+
done <-chan struct{},
291+
err error,
292+
)
293+
294+
// ServerSentEventSender establishes a Server-Sent Event connection and allows
295+
// the consumer to send messages to the client.
296+
//
297+
// The function returned allows you to send a single message to the client,
298+
// while the channel lets you listen for when the connection closes.
299+
//
300+
// As much as possible, this function should be avoided in favor of using the
301+
// OneWayWebSocket function. See OneWayWebSocket for more context.
302+
func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) (
303+
func(sse codersdk.ServerSentEvent) error,
304+
<-chan struct{},
305+
error,
306+
) {
286307
h := rw.Header()
287308
h.Set("Content-Type", "text/event-stream")
288309
h.Set("Cache-Control", "no-cache")
@@ -294,7 +315,8 @@ func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) (sendEvent f
294315
panic("http.ResponseWriter is not http.Flusher")
295316
}
296317

297-
closed = make(chan struct{})
318+
ctx := r.Context()
319+
closed := make(chan struct{})
298320
type sseEvent struct {
299321
payload []byte
300322
errC chan error
@@ -333,21 +355,21 @@ func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) (sendEvent f
333355
}
334356
}()
335357

336-
sendEvent = func(ctx context.Context, sse codersdk.ServerSentEvent) error {
358+
sendEvent := func(newEvent codersdk.ServerSentEvent) error {
337359
buf := &bytes.Buffer{}
338360
enc := json.NewEncoder(buf)
339361

340-
_, err := buf.WriteString(fmt.Sprintf("event: %s\n", sse.Type))
362+
_, err := buf.WriteString(fmt.Sprintf("event: %s\n", newEvent.Type))
341363
if err != nil {
342364
return err
343365
}
344366

345-
if sse.Data != nil {
367+
if newEvent.Data != nil {
346368
_, err = buf.WriteString("data: ")
347369
if err != nil {
348370
return err
349371
}
350-
err = enc.Encode(sse.Data)
372+
err = enc.Encode(newEvent.Data)
351373
if err != nil {
352374
return err
353375
}
@@ -387,3 +409,94 @@ func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) (sendEvent f
387409

388410
return sendEvent, closed, nil
389411
}
412+
413+
// OneWayWebSocket establishes a new WebSocket connection that enforces one-way
414+
// communication from the server to the client.
415+
//
416+
// The function returned allows you to send a single message to the client,
417+
// while the channel lets you listen for when the connection closes.
418+
//
419+
// We must use an approach like this instead of Server-Sent Events for the
420+
// browser, because on HTTP/1.1 connections, browsers are locked to no more than
421+
// six HTTP connections for a domain total, across all tabs. If a user were to
422+
// open a workspace in multiple tabs, the entire UI can start to lock up.
423+
// WebSockets have no such limitation, no matter what HTTP protocol was used to
424+
// establish the connection.
425+
func OneWayWebSocket(rw http.ResponseWriter, r *http.Request) (
426+
func(event codersdk.ServerSentEvent) error,
427+
<-chan struct{},
428+
error,
429+
) {
430+
ctx, cancel := context.WithCancel(r.Context())
431+
r = r.WithContext(ctx)
432+
socket, err := websocket.Accept(rw, r, nil)
433+
if err != nil {
434+
cancel()
435+
return nil, nil, xerrors.Errorf("cannot establish connection: %w", err)
436+
}
437+
go Heartbeat(ctx, socket)
438+
439+
type SocketError struct {
440+
Code websocket.StatusCode
441+
Reason string
442+
}
443+
eventC := make(chan codersdk.ServerSentEvent)
444+
socketErrC := make(chan SocketError, 1)
445+
closed := make(chan struct{})
446+
go func() {
447+
defer cancel()
448+
defer close(closed)
449+
450+
for {
451+
select {
452+
case event := <-eventC:
453+
writeCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
454+
err := wsjson.Write(writeCtx, socket, event)
455+
cancel()
456+
if err == nil {
457+
continue
458+
}
459+
_ = socket.Close(websocket.StatusInternalError, "Unable to send newest message")
460+
case err := <-socketErrC:
461+
_ = socket.Close(err.Code, err.Reason)
462+
case <-ctx.Done():
463+
_ = socket.Close(websocket.StatusNormalClosure, "Connection closed")
464+
}
465+
return
466+
}
467+
}()
468+
469+
// We have some tools in the UI code to help enforce one-way WebSocket
470+
// connections, but there's still the possibility that the client could send
471+
// a message when it's not supposed to. If that happens, the client likely
472+
// forgot to use those tools, and communication probably can't be trusted.
473+
// Better to just close the socket and force the UI to fix its mess
474+
go func() {
475+
_, _, err := socket.Read(ctx)
476+
if errors.Is(err, context.Canceled) {
477+
return
478+
}
479+
if err != nil {
480+
socketErrC <- SocketError{
481+
Code: websocket.StatusInternalError,
482+
Reason: "Unable to process invalid message from client",
483+
}
484+
return
485+
}
486+
socketErrC <- SocketError{
487+
Code: websocket.StatusProtocolError,
488+
Reason: "Clients cannot send messages for one-way WebSockets",
489+
}
490+
}()
491+
492+
sendEvent := func(event codersdk.ServerSentEvent) error {
493+
select {
494+
case eventC <- event:
495+
case <-ctx.Done():
496+
return ctx.Err()
497+
}
498+
return nil
499+
}
500+
501+
return sendEvent, closed, nil
502+
}

0 commit comments

Comments
 (0)