Skip to content

Commit e379c84

Browse files
committed
fix: fix goroutine leak in log streaming over websocket
1 parent 3014713 commit e379c84

File tree

6 files changed

+132
-78
lines changed

6 files changed

+132
-78
lines changed

coderd/provisionerjobs.go

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"nhooyr.io/websocket"
1616

1717
"cdr.dev/slog"
18+
"github.com/coder/coder/v2/codersdk/wsjson"
1819

1920
"github.com/coder/coder/v2/coderd/database"
2021
"github.com/coder/coder/v2/coderd/database/db2sdk"
@@ -312,6 +313,7 @@ type logFollower struct {
312313
r *http.Request
313314
rw http.ResponseWriter
314315
conn *websocket.Conn
316+
enc *wsjson.Encoder[codersdk.ProvisionerJobLog]
315317

316318
jobID uuid.UUID
317319
after int64
@@ -391,6 +393,7 @@ func (f *logFollower) follow() {
391393
}
392394
defer f.conn.Close(websocket.StatusNormalClosure, "done")
393395
go httpapi.Heartbeat(f.ctx, f.conn)
396+
f.enc = wsjson.NewEncoder[codersdk.ProvisionerJobLog](f.conn, websocket.MessageText)
394397

395398
// query for logs once right away, so we can get historical data from before
396399
// subscription
@@ -488,11 +491,7 @@ func (f *logFollower) query() error {
488491
return xerrors.Errorf("error fetching logs: %w", err)
489492
}
490493
for _, log := range logs {
491-
logB, err := json.Marshal(convertProvisionerJobLog(log))
492-
if err != nil {
493-
return xerrors.Errorf("error marshaling log: %w", err)
494-
}
495-
err = f.conn.Write(f.ctx, websocket.MessageText, logB)
494+
err := f.enc.Encode(convertProvisionerJobLog(log))
496495
if err != nil {
497496
return xerrors.Errorf("error writing to websocket: %w", err)
498497
}

coderd/workspaceagents.go

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ import (
3939
"github.com/coder/coder/v2/codersdk"
4040
"github.com/coder/coder/v2/codersdk/agentsdk"
4141
"github.com/coder/coder/v2/codersdk/workspacesdk"
42+
"github.com/coder/coder/v2/codersdk/wsjson"
4243
"github.com/coder/coder/v2/tailnet"
4344
"github.com/coder/coder/v2/tailnet/proto"
4445
)
@@ -396,11 +397,9 @@ func (api *API) workspaceAgentLogs(rw http.ResponseWriter, r *http.Request) {
396397
}
397398
go httpapi.Heartbeat(ctx, conn)
398399

399-
ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageText)
400-
defer wsNetConn.Close() // Also closes conn.
400+
encoder := wsjson.NewEncoder[[]codersdk.WorkspaceAgentLog](conn, websocket.MessageText)
401+
defer encoder.Close(websocket.StatusGoingAway)
401402

402-
// The Go stdlib JSON encoder appends a newline character after message write.
403-
encoder := json.NewEncoder(wsNetConn)
404403
err = encoder.Encode(convertWorkspaceAgentLogs(logs))
405404
if err != nil {
406405
return
@@ -740,16 +739,8 @@ func (api *API) derpMapUpdates(rw http.ResponseWriter, r *http.Request) {
740739
})
741740
return
742741
}
743-
ctx, nconn := codersdk.WebsocketNetConn(ctx, ws, websocket.MessageBinary)
744-
defer nconn.Close()
745-
746-
// Slurp all packets from the connection into io.Discard so pongs get sent
747-
// by the websocket package. We don't do any reads ourselves so this is
748-
// necessary.
749-
go func() {
750-
_, _ = io.Copy(io.Discard, nconn)
751-
_ = nconn.Close()
752-
}()
742+
encoder := wsjson.NewEncoder[*tailcfg.DERPMap](ws, websocket.MessageBinary)
743+
defer encoder.Close(websocket.StatusGoingAway)
753744

754745
go func(ctx context.Context) {
755746
// TODO(mafredri): Is this too frequent? Use separate ping disconnect timeout?
@@ -767,7 +758,7 @@ func (api *API) derpMapUpdates(rw http.ResponseWriter, r *http.Request) {
767758
err := ws.Ping(ctx)
768759
cancel()
769760
if err != nil {
770-
_ = nconn.Close()
761+
_ = ws.Close(websocket.StatusGoingAway, "ping failed")
771762
return
772763
}
773764
}
@@ -780,9 +771,8 @@ func (api *API) derpMapUpdates(rw http.ResponseWriter, r *http.Request) {
780771
for {
781772
derpMap := api.DERPMap()
782773
if lastDERPMap == nil || !tailnet.CompareDERPMaps(lastDERPMap, derpMap) {
783-
err := json.NewEncoder(nconn).Encode(derpMap)
774+
err := encoder.Encode(derpMap)
784775
if err != nil {
785-
_ = nconn.Close()
786776
return
787777
}
788778
lastDERPMap = derpMap

codersdk/provisionerdaemons.go

Lines changed: 3 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919

2020
"github.com/coder/coder/v2/buildinfo"
2121
"github.com/coder/coder/v2/codersdk/drpc"
22+
"github.com/coder/coder/v2/codersdk/wsjson"
2223
"github.com/coder/coder/v2/provisionerd/proto"
2324
"github.com/coder/coder/v2/provisionerd/runner"
2425
)
@@ -162,36 +163,8 @@ func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after
162163
}
163164
return nil, nil, ReadBodyAsError(res)
164165
}
165-
logs := make(chan ProvisionerJobLog)
166-
closed := make(chan struct{})
167-
go func() {
168-
defer close(closed)
169-
defer close(logs)
170-
defer conn.Close(websocket.StatusGoingAway, "")
171-
var log ProvisionerJobLog
172-
for {
173-
msgType, msg, err := conn.Read(ctx)
174-
if err != nil {
175-
return
176-
}
177-
if msgType != websocket.MessageText {
178-
return
179-
}
180-
err = json.Unmarshal(msg, &log)
181-
if err != nil {
182-
return
183-
}
184-
select {
185-
case <-ctx.Done():
186-
return
187-
case logs <- log:
188-
}
189-
}
190-
}()
191-
return logs, closeFunc(func() error {
192-
<-closed
193-
return nil
194-
}), nil
166+
d := wsjson.NewDecoder[ProvisionerJobLog](conn, websocket.MessageText, c.logger)
167+
return d.Chan(), d, nil
195168
}
196169

197170
// ServeProvisionerDaemonRequest are the parameters to call ServeProvisionerDaemon with

codersdk/workspaceagents.go

Lines changed: 3 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"nhooyr.io/websocket"
1616

1717
"github.com/coder/coder/v2/coderd/tracing"
18+
"github.com/coder/coder/v2/codersdk/wsjson"
1819
)
1920

2021
type WorkspaceAgentStatus string
@@ -454,30 +455,6 @@ func (c *Client) WorkspaceAgentLogsAfter(ctx context.Context, agentID uuid.UUID,
454455
}
455456
return nil, nil, ReadBodyAsError(res)
456457
}
457-
logChunks := make(chan []WorkspaceAgentLog, 1)
458-
closed := make(chan struct{})
459-
ctx, wsNetConn := WebsocketNetConn(ctx, conn, websocket.MessageText)
460-
decoder := json.NewDecoder(wsNetConn)
461-
go func() {
462-
defer close(closed)
463-
defer close(logChunks)
464-
defer conn.Close(websocket.StatusGoingAway, "")
465-
for {
466-
var logs []WorkspaceAgentLog
467-
err = decoder.Decode(&logs)
468-
if err != nil {
469-
return
470-
}
471-
select {
472-
case <-ctx.Done():
473-
return
474-
case logChunks <- logs:
475-
}
476-
}
477-
}()
478-
return logChunks, closeFunc(func() error {
479-
_ = wsNetConn.Close()
480-
<-closed
481-
return nil
482-
}), nil
458+
d := wsjson.NewDecoder[[]WorkspaceAgentLog](conn, websocket.MessageText, c.logger)
459+
return d.Chan(), d, nil
483460
}

codersdk/wsjson/decoder.go

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
package wsjson
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"sync/atomic"
7+
8+
"nhooyr.io/websocket"
9+
10+
"cdr.dev/slog"
11+
)
12+
13+
type Decoder[T any] struct {
14+
conn *websocket.Conn
15+
typ websocket.MessageType
16+
ctx context.Context
17+
cancel context.CancelFunc
18+
chanCalled atomic.Bool
19+
logger slog.Logger
20+
}
21+
22+
// Chan starts the decoder reading from the websocket and returns a channel for reading the
23+
// resulting values. The chan T is closed if the underlying websocket is closed, or we encounter an
24+
// error. We also close the underlying websocket if we encounter an error reading or decoding.
25+
func (d *Decoder[T]) Chan() <-chan T {
26+
if d.chanCalled.Load() {
27+
panic("chan called more than once")
28+
}
29+
d.chanCalled.Store(true)
30+
values := make(chan T)
31+
go func() {
32+
defer close(values)
33+
defer d.conn.Close(websocket.StatusGoingAway, "")
34+
for {
35+
typ, b, err := d.conn.Read(context.Background())
36+
if err != nil {
37+
// might be benign like EOF, so just log at debug
38+
d.logger.Debug(d.ctx, "error reading from websocket", slog.Error(err))
39+
return
40+
}
41+
if typ != d.typ {
42+
d.logger.Error(d.ctx, "websocket type mismatch while decoding")
43+
return
44+
}
45+
var value T
46+
err = json.Unmarshal(b, &value)
47+
if err != nil {
48+
d.logger.Error(d.ctx, "error unmarshalling", slog.Error(err))
49+
return
50+
}
51+
select {
52+
case values <- value:
53+
// OK
54+
case <-d.ctx.Done():
55+
return
56+
}
57+
}
58+
}()
59+
return values
60+
}
61+
62+
func (d *Decoder[T]) Close() error {
63+
err := d.conn.Close(websocket.StatusGoingAway, "")
64+
d.cancel()
65+
return err
66+
}
67+
68+
// NewDecoder creates a JSON-over-websocket decoder for type T, which must be deserializable from
69+
// JSON.
70+
func NewDecoder[T any](conn *websocket.Conn, typ websocket.MessageType, logger slog.Logger) *Decoder[T] {
71+
ctx, cancel := context.WithCancel(context.Background())
72+
return &Decoder[T]{conn: conn, ctx: ctx, cancel: cancel, typ: typ, logger: logger}
73+
}

codersdk/wsjson/encoder.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
package wsjson
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
7+
"golang.org/x/xerrors"
8+
"nhooyr.io/websocket"
9+
)
10+
11+
type Encoder[T any] struct {
12+
conn *websocket.Conn
13+
typ websocket.MessageType
14+
}
15+
16+
func (e *Encoder[T]) Encode(v T) error {
17+
w, err := e.conn.Writer(context.Background(), e.typ)
18+
if err != nil {
19+
return xerrors.Errorf("get websocket writer: %w", err)
20+
}
21+
defer w.Close()
22+
j := json.NewEncoder(w)
23+
err = j.Encode(v)
24+
if err != nil {
25+
return xerrors.Errorf("encode json: %w", err)
26+
}
27+
return nil
28+
}
29+
30+
func (e *Encoder[T]) Close(c websocket.StatusCode) error {
31+
return e.conn.Close(c, "")
32+
}
33+
34+
// NewEncoder creates a JSON-over websocket encoder for the type T, which must be JSON-serializable.
35+
// You may then call Encode() to send objects over the websocket. Creating an Encoder closes the
36+
// websocket for reading, turning it into a unidirectional write stream of JSON-encoded objects.
37+
func NewEncoder[T any](conn *websocket.Conn, typ websocket.MessageType) *Encoder[T] {
38+
// Here we close the websocket for reading, so that the websocket library will handle pings and
39+
// close frames.
40+
_ = conn.CloseRead(context.Background())
41+
return &Encoder[T]{conn: conn, typ: typ}
42+
}

0 commit comments

Comments
 (0)