Skip to content

Commit 148a5a3

Browse files
authored
fix: fix goroutine leak in log streaming over websocket (coder#15709)
fixes coder#14881 Our handlers for streaming logs don't read from the websocket. We don't allow the client to send us any data, but the websocket library we use requires reading from the websocket to properly handle pings and closing. Not doing so can [can cause the websocket to hang on write](coder/websocket#405), leaking go routines which were noticed in coder#14881. This fixes the issue, and in process refactors our log streaming to a encoder/decoder package which provides generic types for sending JSON over websocket. I'd also like for us to upgrade to the latest https://github.com/coder/websocket but we should also upgrade our tailscale fork before doing so to avoid including two copies of the websocket library.
1 parent e4f6c9a commit 148a5a3

File tree

6 files changed

+134
-78
lines changed

6 files changed

+134
-78
lines changed

coderd/provisionerjobs.go

+4-5
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"nhooyr.io/websocket"
1616

1717
"cdr.dev/slog"
18+
"github.com/coder/coder/v2/codersdk/wsjson"
1819

1920
"github.com/coder/coder/v2/coderd/database"
2021
"github.com/coder/coder/v2/coderd/database/db2sdk"
@@ -312,6 +313,7 @@ type logFollower struct {
312313
r *http.Request
313314
rw http.ResponseWriter
314315
conn *websocket.Conn
316+
enc *wsjson.Encoder[codersdk.ProvisionerJobLog]
315317

316318
jobID uuid.UUID
317319
after int64
@@ -391,6 +393,7 @@ func (f *logFollower) follow() {
391393
}
392394
defer f.conn.Close(websocket.StatusNormalClosure, "done")
393395
go httpapi.Heartbeat(f.ctx, f.conn)
396+
f.enc = wsjson.NewEncoder[codersdk.ProvisionerJobLog](f.conn, websocket.MessageText)
394397

395398
// query for logs once right away, so we can get historical data from before
396399
// subscription
@@ -488,11 +491,7 @@ func (f *logFollower) query() error {
488491
return xerrors.Errorf("error fetching logs: %w", err)
489492
}
490493
for _, log := range logs {
491-
logB, err := json.Marshal(convertProvisionerJobLog(log))
492-
if err != nil {
493-
return xerrors.Errorf("error marshaling log: %w", err)
494-
}
495-
err = f.conn.Write(f.ctx, websocket.MessageText, logB)
494+
err := f.enc.Encode(convertProvisionerJobLog(log))
496495
if err != nil {
497496
return xerrors.Errorf("error writing to websocket: %w", err)
498497
}

coderd/workspaceagents.go

+7-17
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ import (
3939
"github.com/coder/coder/v2/codersdk"
4040
"github.com/coder/coder/v2/codersdk/agentsdk"
4141
"github.com/coder/coder/v2/codersdk/workspacesdk"
42+
"github.com/coder/coder/v2/codersdk/wsjson"
4243
"github.com/coder/coder/v2/tailnet"
4344
"github.com/coder/coder/v2/tailnet/proto"
4445
)
@@ -396,11 +397,9 @@ func (api *API) workspaceAgentLogs(rw http.ResponseWriter, r *http.Request) {
396397
}
397398
go httpapi.Heartbeat(ctx, conn)
398399

399-
ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageText)
400-
defer wsNetConn.Close() // Also closes conn.
400+
encoder := wsjson.NewEncoder[[]codersdk.WorkspaceAgentLog](conn, websocket.MessageText)
401+
defer encoder.Close(websocket.StatusNormalClosure)
401402

402-
// The Go stdlib JSON encoder appends a newline character after message write.
403-
encoder := json.NewEncoder(wsNetConn)
404403
err = encoder.Encode(convertWorkspaceAgentLogs(logs))
405404
if err != nil {
406405
return
@@ -740,16 +739,8 @@ func (api *API) derpMapUpdates(rw http.ResponseWriter, r *http.Request) {
740739
})
741740
return
742741
}
743-
ctx, nconn := codersdk.WebsocketNetConn(ctx, ws, websocket.MessageBinary)
744-
defer nconn.Close()
745-
746-
// Slurp all packets from the connection into io.Discard so pongs get sent
747-
// by the websocket package. We don't do any reads ourselves so this is
748-
// necessary.
749-
go func() {
750-
_, _ = io.Copy(io.Discard, nconn)
751-
_ = nconn.Close()
752-
}()
742+
encoder := wsjson.NewEncoder[*tailcfg.DERPMap](ws, websocket.MessageBinary)
743+
defer encoder.Close(websocket.StatusGoingAway)
753744

754745
go func(ctx context.Context) {
755746
// TODO(mafredri): Is this too frequent? Use separate ping disconnect timeout?
@@ -767,7 +758,7 @@ func (api *API) derpMapUpdates(rw http.ResponseWriter, r *http.Request) {
767758
err := ws.Ping(ctx)
768759
cancel()
769760
if err != nil {
770-
_ = nconn.Close()
761+
_ = ws.Close(websocket.StatusGoingAway, "ping failed")
771762
return
772763
}
773764
}
@@ -780,9 +771,8 @@ func (api *API) derpMapUpdates(rw http.ResponseWriter, r *http.Request) {
780771
for {
781772
derpMap := api.DERPMap()
782773
if lastDERPMap == nil || !tailnet.CompareDERPMaps(lastDERPMap, derpMap) {
783-
err := json.NewEncoder(nconn).Encode(derpMap)
774+
err := encoder.Encode(derpMap)
784775
if err != nil {
785-
_ = nconn.Close()
786776
return
787777
}
788778
lastDERPMap = derpMap

codersdk/provisionerdaemons.go

+3-30
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919

2020
"github.com/coder/coder/v2/buildinfo"
2121
"github.com/coder/coder/v2/codersdk/drpc"
22+
"github.com/coder/coder/v2/codersdk/wsjson"
2223
"github.com/coder/coder/v2/provisionerd/proto"
2324
"github.com/coder/coder/v2/provisionerd/runner"
2425
)
@@ -162,36 +163,8 @@ func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after
162163
}
163164
return nil, nil, ReadBodyAsError(res)
164165
}
165-
logs := make(chan ProvisionerJobLog)
166-
closed := make(chan struct{})
167-
go func() {
168-
defer close(closed)
169-
defer close(logs)
170-
defer conn.Close(websocket.StatusGoingAway, "")
171-
var log ProvisionerJobLog
172-
for {
173-
msgType, msg, err := conn.Read(ctx)
174-
if err != nil {
175-
return
176-
}
177-
if msgType != websocket.MessageText {
178-
return
179-
}
180-
err = json.Unmarshal(msg, &log)
181-
if err != nil {
182-
return
183-
}
184-
select {
185-
case <-ctx.Done():
186-
return
187-
case logs <- log:
188-
}
189-
}
190-
}()
191-
return logs, closeFunc(func() error {
192-
<-closed
193-
return nil
194-
}), nil
166+
d := wsjson.NewDecoder[ProvisionerJobLog](conn, websocket.MessageText, c.logger)
167+
return d.Chan(), d, nil
195168
}
196169

197170
// ServeProvisionerDaemonRequest are the parameters to call ServeProvisionerDaemon with

codersdk/workspaceagents.go

+3-26
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"nhooyr.io/websocket"
1616

1717
"github.com/coder/coder/v2/coderd/tracing"
18+
"github.com/coder/coder/v2/codersdk/wsjson"
1819
)
1920

2021
type WorkspaceAgentStatus string
@@ -454,30 +455,6 @@ func (c *Client) WorkspaceAgentLogsAfter(ctx context.Context, agentID uuid.UUID,
454455
}
455456
return nil, nil, ReadBodyAsError(res)
456457
}
457-
logChunks := make(chan []WorkspaceAgentLog, 1)
458-
closed := make(chan struct{})
459-
ctx, wsNetConn := WebsocketNetConn(ctx, conn, websocket.MessageText)
460-
decoder := json.NewDecoder(wsNetConn)
461-
go func() {
462-
defer close(closed)
463-
defer close(logChunks)
464-
defer conn.Close(websocket.StatusGoingAway, "")
465-
for {
466-
var logs []WorkspaceAgentLog
467-
err = decoder.Decode(&logs)
468-
if err != nil {
469-
return
470-
}
471-
select {
472-
case <-ctx.Done():
473-
return
474-
case logChunks <- logs:
475-
}
476-
}
477-
}()
478-
return logChunks, closeFunc(func() error {
479-
_ = wsNetConn.Close()
480-
<-closed
481-
return nil
482-
}), nil
458+
d := wsjson.NewDecoder[[]WorkspaceAgentLog](conn, websocket.MessageText, c.logger)
459+
return d.Chan(), d, nil
483460
}

codersdk/wsjson/decoder.go

+75
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
package wsjson
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"sync/atomic"
7+
8+
"nhooyr.io/websocket"
9+
10+
"cdr.dev/slog"
11+
)
12+
13+
type Decoder[T any] struct {
14+
conn *websocket.Conn
15+
typ websocket.MessageType
16+
ctx context.Context
17+
cancel context.CancelFunc
18+
chanCalled atomic.Bool
19+
logger slog.Logger
20+
}
21+
22+
// Chan starts the decoder reading from the websocket and returns a channel for reading the
23+
// resulting values. The chan T is closed if the underlying websocket is closed, or we encounter an
24+
// error. We also close the underlying websocket if we encounter an error reading or decoding.
25+
func (d *Decoder[T]) Chan() <-chan T {
26+
if !d.chanCalled.CompareAndSwap(false, true) {
27+
panic("chan called more than once")
28+
}
29+
values := make(chan T, 1)
30+
go func() {
31+
defer close(values)
32+
defer d.conn.Close(websocket.StatusGoingAway, "")
33+
for {
34+
// we don't use d.ctx here because it only gets canceled after closing the connection
35+
// and a "connection closed" type error is more clear than context canceled.
36+
typ, b, err := d.conn.Read(context.Background())
37+
if err != nil {
38+
// might be benign like EOF, so just log at debug
39+
d.logger.Debug(d.ctx, "error reading from websocket", slog.Error(err))
40+
return
41+
}
42+
if typ != d.typ {
43+
d.logger.Error(d.ctx, "websocket type mismatch while decoding")
44+
return
45+
}
46+
var value T
47+
err = json.Unmarshal(b, &value)
48+
if err != nil {
49+
d.logger.Error(d.ctx, "error unmarshalling", slog.Error(err))
50+
return
51+
}
52+
select {
53+
case values <- value:
54+
// OK
55+
case <-d.ctx.Done():
56+
return
57+
}
58+
}
59+
}()
60+
return values
61+
}
62+
63+
// nolint: revive // complains that Encoder has the same function name
64+
func (d *Decoder[T]) Close() error {
65+
err := d.conn.Close(websocket.StatusNormalClosure, "")
66+
d.cancel()
67+
return err
68+
}
69+
70+
// NewDecoder creates a JSON-over-websocket decoder for type T, which must be deserializable from
71+
// JSON.
72+
func NewDecoder[T any](conn *websocket.Conn, typ websocket.MessageType, logger slog.Logger) *Decoder[T] {
73+
ctx, cancel := context.WithCancel(context.Background())
74+
return &Decoder[T]{conn: conn, ctx: ctx, cancel: cancel, typ: typ, logger: logger}
75+
}

codersdk/wsjson/encoder.go

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
package wsjson
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
7+
"golang.org/x/xerrors"
8+
"nhooyr.io/websocket"
9+
)
10+
11+
type Encoder[T any] struct {
12+
conn *websocket.Conn
13+
typ websocket.MessageType
14+
}
15+
16+
func (e *Encoder[T]) Encode(v T) error {
17+
w, err := e.conn.Writer(context.Background(), e.typ)
18+
if err != nil {
19+
return xerrors.Errorf("get websocket writer: %w", err)
20+
}
21+
defer w.Close()
22+
j := json.NewEncoder(w)
23+
err = j.Encode(v)
24+
if err != nil {
25+
return xerrors.Errorf("encode json: %w", err)
26+
}
27+
return nil
28+
}
29+
30+
func (e *Encoder[T]) Close(c websocket.StatusCode) error {
31+
return e.conn.Close(c, "")
32+
}
33+
34+
// NewEncoder creates a JSON-over websocket encoder for the type T, which must be JSON-serializable.
35+
// You may then call Encode() to send objects over the websocket. Creating an Encoder closes the
36+
// websocket for reading, turning it into a unidirectional write stream of JSON-encoded objects.
37+
func NewEncoder[T any](conn *websocket.Conn, typ websocket.MessageType) *Encoder[T] {
38+
// Here we close the websocket for reading, so that the websocket library will handle pings and
39+
// close frames.
40+
_ = conn.CloseRead(context.Background())
41+
return &Encoder[T]{conn: conn, typ: typ}
42+
}

0 commit comments

Comments
 (0)