From 76911bbab68f191794448c7662a90254c8d84929 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Mon, 2 Dec 2024 10:23:40 +0400 Subject: [PATCH] fix: fix goroutine leak in log streaming over websocket --- coderd/provisionerjobs.go | 9 ++-- coderd/workspaceagents.go | 24 ++++------- codersdk/provisionerdaemons.go | 33 ++------------- codersdk/workspaceagents.go | 29 ++----------- codersdk/wsjson/decoder.go | 75 ++++++++++++++++++++++++++++++++++ codersdk/wsjson/encoder.go | 42 +++++++++++++++++++ 6 files changed, 134 insertions(+), 78 deletions(-) create mode 100644 codersdk/wsjson/decoder.go create mode 100644 codersdk/wsjson/encoder.go diff --git a/coderd/provisionerjobs.go b/coderd/provisionerjobs.go index df832b810e696..3db5d7c20a4bf 100644 --- a/coderd/provisionerjobs.go +++ b/coderd/provisionerjobs.go @@ -15,6 +15,7 @@ import ( "nhooyr.io/websocket" "cdr.dev/slog" + "github.com/coder/coder/v2/codersdk/wsjson" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/db2sdk" @@ -312,6 +313,7 @@ type logFollower struct { r *http.Request rw http.ResponseWriter conn *websocket.Conn + enc *wsjson.Encoder[codersdk.ProvisionerJobLog] jobID uuid.UUID after int64 @@ -391,6 +393,7 @@ func (f *logFollower) follow() { } defer f.conn.Close(websocket.StatusNormalClosure, "done") go httpapi.Heartbeat(f.ctx, f.conn) + f.enc = wsjson.NewEncoder[codersdk.ProvisionerJobLog](f.conn, websocket.MessageText) // query for logs once right away, so we can get historical data from before // subscription @@ -488,11 +491,7 @@ func (f *logFollower) query() error { return xerrors.Errorf("error fetching logs: %w", err) } for _, log := range logs { - logB, err := json.Marshal(convertProvisionerJobLog(log)) - if err != nil { - return xerrors.Errorf("error marshaling log: %w", err) - } - err = f.conn.Write(f.ctx, websocket.MessageText, logB) + err := f.enc.Encode(convertProvisionerJobLog(log)) if err != nil { return xerrors.Errorf("error writing to websocket: %w", 
err) } diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 922d80f0e8085..6bc09e0e770f6 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -39,6 +39,7 @@ import ( "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/agentsdk" "github.com/coder/coder/v2/codersdk/workspacesdk" + "github.com/coder/coder/v2/codersdk/wsjson" "github.com/coder/coder/v2/tailnet" "github.com/coder/coder/v2/tailnet/proto" ) @@ -396,11 +397,9 @@ func (api *API) workspaceAgentLogs(rw http.ResponseWriter, r *http.Request) { } go httpapi.Heartbeat(ctx, conn) - ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageText) - defer wsNetConn.Close() // Also closes conn. + encoder := wsjson.NewEncoder[[]codersdk.WorkspaceAgentLog](conn, websocket.MessageText) + defer encoder.Close(websocket.StatusNormalClosure) - // The Go stdlib JSON encoder appends a newline character after message write. - encoder := json.NewEncoder(wsNetConn) err = encoder.Encode(convertWorkspaceAgentLogs(logs)) if err != nil { return @@ -740,16 +739,8 @@ func (api *API) derpMapUpdates(rw http.ResponseWriter, r *http.Request) { }) return } - ctx, nconn := codersdk.WebsocketNetConn(ctx, ws, websocket.MessageBinary) - defer nconn.Close() - - // Slurp all packets from the connection into io.Discard so pongs get sent - // by the websocket package. We don't do any reads ourselves so this is - // necessary. - go func() { - _, _ = io.Copy(io.Discard, nconn) - _ = nconn.Close() - }() + encoder := wsjson.NewEncoder[*tailcfg.DERPMap](ws, websocket.MessageBinary) + defer encoder.Close(websocket.StatusGoingAway) go func(ctx context.Context) { // TODO(mafredri): Is this too frequent? Use separate ping disconnect timeout? 
@@ -767,7 +758,7 @@ func (api *API) derpMapUpdates(rw http.ResponseWriter, r *http.Request) { err := ws.Ping(ctx) cancel() if err != nil { - _ = nconn.Close() + _ = ws.Close(websocket.StatusGoingAway, "ping failed") return } } @@ -780,9 +771,8 @@ func (api *API) derpMapUpdates(rw http.ResponseWriter, r *http.Request) { for { derpMap := api.DERPMap() if lastDERPMap == nil || !tailnet.CompareDERPMaps(lastDERPMap, derpMap) { - err := json.NewEncoder(nconn).Encode(derpMap) + err := encoder.Encode(derpMap) if err != nil { - _ = nconn.Close() return } lastDERPMap = derpMap diff --git a/codersdk/provisionerdaemons.go b/codersdk/provisionerdaemons.go index 27d2766a7cd13..fb588ef8ba468 100644 --- a/codersdk/provisionerdaemons.go +++ b/codersdk/provisionerdaemons.go @@ -19,6 +19,7 @@ import ( "github.com/coder/coder/v2/buildinfo" "github.com/coder/coder/v2/codersdk/drpc" + "github.com/coder/coder/v2/codersdk/wsjson" "github.com/coder/coder/v2/provisionerd/proto" "github.com/coder/coder/v2/provisionerd/runner" ) @@ -162,36 +163,8 @@ func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after } return nil, nil, ReadBodyAsError(res) } - logs := make(chan ProvisionerJobLog) - closed := make(chan struct{}) - go func() { - defer close(closed) - defer close(logs) - defer conn.Close(websocket.StatusGoingAway, "") - var log ProvisionerJobLog - for { - msgType, msg, err := conn.Read(ctx) - if err != nil { - return - } - if msgType != websocket.MessageText { - return - } - err = json.Unmarshal(msg, &log) - if err != nil { - return - } - select { - case <-ctx.Done(): - return - case logs <- log: - } - } - }() - return logs, closeFunc(func() error { - <-closed - return nil - }), nil + d := wsjson.NewDecoder[ProvisionerJobLog](conn, websocket.MessageText, c.logger) + return d.Chan(), d, nil } // ServeProvisionerDaemonRequest are the parameters to call ServeProvisionerDaemon with diff --git a/codersdk/workspaceagents.go b/codersdk/workspaceagents.go index 
eeb335b130cdd..b4aec16a83190 100644 --- a/codersdk/workspaceagents.go +++ b/codersdk/workspaceagents.go @@ -15,6 +15,7 @@ import ( "nhooyr.io/websocket" "github.com/coder/coder/v2/coderd/tracing" + "github.com/coder/coder/v2/codersdk/wsjson" ) type WorkspaceAgentStatus string @@ -454,30 +455,6 @@ func (c *Client) WorkspaceAgentLogsAfter(ctx context.Context, agentID uuid.UUID, } return nil, nil, ReadBodyAsError(res) } - logChunks := make(chan []WorkspaceAgentLog, 1) - closed := make(chan struct{}) - ctx, wsNetConn := WebsocketNetConn(ctx, conn, websocket.MessageText) - decoder := json.NewDecoder(wsNetConn) - go func() { - defer close(closed) - defer close(logChunks) - defer conn.Close(websocket.StatusGoingAway, "") - for { - var logs []WorkspaceAgentLog - err = decoder.Decode(&logs) - if err != nil { - return - } - select { - case <-ctx.Done(): - return - case logChunks <- logs: - } - } - }() - return logChunks, closeFunc(func() error { - _ = wsNetConn.Close() - <-closed - return nil - }), nil + d := wsjson.NewDecoder[[]WorkspaceAgentLog](conn, websocket.MessageText, c.logger) + return d.Chan(), d, nil } diff --git a/codersdk/wsjson/decoder.go b/codersdk/wsjson/decoder.go new file mode 100644 index 0000000000000..4cc7ff380a73a --- /dev/null +++ b/codersdk/wsjson/decoder.go @@ -0,0 +1,75 @@ +package wsjson + +import ( + "context" + "encoding/json" + "sync/atomic" + + "nhooyr.io/websocket" + + "cdr.dev/slog" +) + +type Decoder[T any] struct { + conn *websocket.Conn + typ websocket.MessageType + ctx context.Context + cancel context.CancelFunc + chanCalled atomic.Bool + logger slog.Logger +} + +// Chan starts the decoder reading from the websocket and returns a channel for reading the +// resulting values. The chan T is closed if the underlying websocket is closed, or we encounter an +// error. We also close the underlying websocket if we encounter an error reading or decoding. 
+func (d *Decoder[T]) Chan() <-chan T {
+	if alreadyStarted := d.chanCalled.Swap(true); alreadyStarted {
+		panic("chan called more than once")
+	}
+	out := make(chan T, 1)
+	go func() {
+		// Deferred calls run last-in-first-out: the websocket is closed before the channel.
+		defer close(out)
+		defer d.conn.Close(websocket.StatusGoingAway, "")
+		for {
+			// Read with a background context instead of d.ctx: d.ctx is only canceled after the
+			// connection is closed, and a "connection closed" error is clearer than context canceled.
+			msgType, payload, err := d.conn.Read(context.Background())
+			if err != nil {
+				// The error may be benign (e.g. EOF), so it is only logged at debug level.
+				d.logger.Debug(d.ctx, "error reading from websocket", slog.Error(err))
+				return
+			}
+			if msgType != d.typ {
+				d.logger.Error(d.ctx, "websocket type mismatch while decoding")
+				return
+			}
+			var decoded T
+			if err := json.Unmarshal(payload, &decoded); err != nil {
+				d.logger.Error(d.ctx, "error unmarshalling", slog.Error(err))
+				return
+			}
+			select {
+			case <-d.ctx.Done():
+				return
+			case out <- decoded:
+			}
+		}
+	}()
+	return out
+}
+
+// Close closes the underlying websocket with a normal-closure status and then cancels the
+// decoder's context, which releases the Chan goroutine if it is blocked delivering a value.
+// nolint: revive // complains that Encoder has the same function name
+func (d *Decoder[T]) Close() error {
+	defer d.cancel()
+	return d.conn.Close(websocket.StatusNormalClosure, "")
+}
+
+// NewDecoder creates a JSON-over-websocket decoder for type T, which must be deserializable from
+// JSON.
+func NewDecoder[T any](conn *websocket.Conn, typ websocket.MessageType, logger slog.Logger) *Decoder[T] {
+	ctx, cancel := context.WithCancel(context.Background())
+	return &Decoder[T]{conn: conn, ctx: ctx, cancel: cancel, typ: typ, logger: logger}
+}
diff --git a/codersdk/wsjson/encoder.go b/codersdk/wsjson/encoder.go
new file mode 100644
index 0000000000000..4cde05984e690
--- /dev/null
+++ b/codersdk/wsjson/encoder.go
@@ -0,0 +1,51 @@
+package wsjson
+
+import (
+	"context"
+	"encoding/json"
+
+	"golang.org/x/xerrors"
+	"nhooyr.io/websocket"
+)
+
+// Encoder sends values of type T over a websocket as JSON, writing each value as one websocket
+// message of the configured type.
+type Encoder[T any] struct {
+	conn *websocket.Conn
+	typ  websocket.MessageType
+}
+
+// Encode serializes v to JSON and writes it as a single websocket message. It returns an error if
+// the message writer cannot be obtained, if v cannot be encoded, or if closing the writer fails.
+func (e *Encoder[T]) Encode(v T) error {
+	w, err := e.conn.Writer(context.Background(), e.typ)
+	if err != nil {
+		return xerrors.Errorf("get websocket writer: %w", err)
+	}
+	if err := json.NewEncoder(w).Encode(v); err != nil {
+		// Still close the writer so it is not left open; the encode error is the one reported.
+		_ = w.Close()
+		return xerrors.Errorf("encode json: %w", err)
+	}
+	// The websocket library requires the writer to be closed once the whole message is written, so
+	// a failure to close is reported to the caller rather than silently dropped.
+	if err := w.Close(); err != nil {
+		return xerrors.Errorf("close websocket writer: %w", err)
+	}
+	return nil
+}
+
+// Close closes the underlying websocket with status code c and an empty reason.
+func (e *Encoder[T]) Close(c websocket.StatusCode) error {
+	return e.conn.Close(c, "")
+}
+
+// NewEncoder creates a JSON-over-websocket encoder for the type T, which must be JSON-serializable.
+// You may then call Encode() to send objects over the websocket. Creating an Encoder closes the
+// websocket for reading, turning it into a unidirectional write stream of JSON-encoded objects.
+func NewEncoder[T any](conn *websocket.Conn, typ websocket.MessageType) *Encoder[T] {
+	// Here we close the websocket for reading, so that the websocket library will handle pings and
+	// close frames.
+	_ = conn.CloseRead(context.Background())
+	return &Encoder[T]{conn: conn, typ: typ}
+}