Skip to content

Commit bf64a43

Browse files
committed
Refactor websockets to be tracked
1 parent 884c71b commit bf64a43

File tree

5 files changed

+369
-406
lines changed

5 files changed

+369
-406
lines changed

coderd/coderd.go

+4-5
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ import (
1212
"path/filepath"
1313
"regexp"
1414
"strings"
15-
"sync"
1615
"sync/atomic"
1716
"time"
1817

@@ -316,6 +315,7 @@ func New(options *Options) *API {
316315
TemplateScheduleStore: options.TemplateScheduleStore,
317316
Experiments: experiments,
318317
healthCheckGroup: &singleflight.Group[string, *healthcheck.Report]{},
318+
WebsocketWatch: NewActiveWebsockets(ctx),
319319
}
320320
if options.UpdateCheckOptions != nil {
321321
api.updateChecker = updatecheck.New(
@@ -784,9 +784,8 @@ type API struct {
784784

785785
siteHandler http.Handler
786786

787-
WebsocketWaitGroup sync.WaitGroup
788-
WebsocketWatch *ActiveWebsockets
789-
derpCloseFunc func()
787+
WebsocketWatch *ActiveWebsockets
788+
derpCloseFunc func()
790789

791790
metricsCache *metricscache.Cache
792791
workspaceAgentCache *wsconncache.Cache
@@ -805,7 +804,7 @@ func (api *API) Close() error {
805804
api.cancel()
806805
api.derpCloseFunc()
807806

808-
api.WebsocketWaitGroup.Wait()
807+
api.WebsocketWatch.Close()
809808

810809
api.metricsCache.Close()
811810
if api.updateChecker != nil {

coderd/provisionerjobs.go

+43-51
Original file line numberDiff line numberDiff line change
@@ -113,69 +113,61 @@ func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job
113113
logs = []database.ProvisionerJobLog{}
114114
}
115115

116-
api.WebsocketWaitGroup.Add(1)
117-
defer api.WebsocketWaitGroup.Done()
118-
conn, err := websocket.Accept(rw, r, nil)
119-
if err != nil {
120-
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
121-
Message: "Failed to accept websocket.",
122-
Detail: err.Error(),
123-
})
124-
return
125-
}
126-
go httpapi.Heartbeat(ctx, conn)
116+
api.WebsocketWatch.Accept(rw, r, nil, func(conn *websocket.Conn) {
117+
go httpapi.Heartbeat(ctx, conn)
127118

128-
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageText)
129-
defer wsNetConn.Close() // Also closes conn.
119+
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageText)
120+
defer wsNetConn.Close() // Also closes conn.
130121

131-
logIdsDone := make(map[int64]bool)
122+
logIdsDone := make(map[int64]bool)
132123

133-
// The Go stdlib JSON encoder appends a newline character after message write.
134-
encoder := json.NewEncoder(wsNetConn)
135-
for _, provisionerJobLog := range logs {
136-
logIdsDone[provisionerJobLog.ID] = true
137-
err = encoder.Encode(convertProvisionerJobLog(provisionerJobLog))
124+
// The Go stdlib JSON encoder appends a newline character after message write.
125+
encoder := json.NewEncoder(wsNetConn)
126+
for _, provisionerJobLog := range logs {
127+
logIdsDone[provisionerJobLog.ID] = true
128+
err = encoder.Encode(convertProvisionerJobLog(provisionerJobLog))
129+
if err != nil {
130+
return
131+
}
132+
}
133+
job, err = api.Database.GetProvisionerJobByID(ctx, job.ID)
138134
if err != nil {
135+
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
136+
Message: "Internal error fetching provisioner job.",
137+
Detail: err.Error(),
138+
})
139139
return
140140
}
141-
}
142-
job, err = api.Database.GetProvisionerJobByID(ctx, job.ID)
143-
if err != nil {
144-
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
145-
Message: "Internal error fetching provisioner job.",
146-
Detail: err.Error(),
147-
})
148-
return
149-
}
150-
if job.CompletedAt.Valid {
151-
// job was complete before we queried the database for historical logs
152-
return
153-
}
154-
155-
for {
156-
select {
157-
case <-ctx.Done():
158-
logger.Debug(context.Background(), "job logs context canceled")
141+
if job.CompletedAt.Valid {
142+
// job was complete before we queried the database for historical logs
159143
return
160-
case log, ok := <-bufferedLogs:
161-
// A nil log is sent when complete!
162-
if !ok || log == nil {
163-
logger.Debug(context.Background(), "reached the end of published logs")
144+
}
145+
146+
for {
147+
select {
148+
case <-ctx.Done():
149+
logger.Debug(context.Background(), "job logs context canceled")
164150
return
165-
}
166-
if logIdsDone[log.ID] {
167-
logger.Debug(ctx, "subscribe duplicated log",
168-
slog.F("stage", log.Stage))
169-
} else {
170-
logger.Debug(ctx, "subscribe encoding log",
171-
slog.F("stage", log.Stage))
172-
err = encoder.Encode(convertProvisionerJobLog(*log))
173-
if err != nil {
151+
case log, ok := <-bufferedLogs:
152+
// A nil log is sent when complete!
153+
if !ok || log == nil {
154+
logger.Debug(context.Background(), "reached the end of published logs")
174155
return
175156
}
157+
if logIdsDone[log.ID] {
158+
logger.Debug(ctx, "subscribe duplicated log",
159+
slog.F("stage", log.Stage))
160+
} else {
161+
logger.Debug(ctx, "subscribe encoding log",
162+
slog.F("stage", log.Stage))
163+
err = encoder.Encode(convertProvisionerJobLog(*log))
164+
if err != nil {
165+
return
166+
}
167+
}
176168
}
177169
}
178-
}
170+
})
179171
}
180172

181173
func (api *API) provisionerJobResources(rw http.ResponseWriter, r *http.Request, job database.ProvisionerJob) {

coderd/sockets.go

+22-8
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package coderd
33
import (
44
"context"
55
"net/http"
6+
"runtime/pprof"
67
"sync"
78

89
"nhooyr.io/websocket"
@@ -12,15 +13,17 @@ import (
1213
)
1314

1415
// ActiveWebsockets is a helper struct that can be used to track active
15-
// websocket connections.
16+
// websocket connections. All connections will be closed when the parent
17+
// context is canceled.
1618
type ActiveWebsockets struct {
1719
ctx context.Context
1820
cancel func()
1921

2022
wg sync.WaitGroup
2123
}
2224

23-
func NewActiveWebsockets(ctx context.Context, cancel func()) *ActiveWebsockets {
25+
func NewActiveWebsockets(ctx context.Context) *ActiveWebsockets {
26+
ctx, cancel := context.WithCancel(ctx)
2427
return &ActiveWebsockets{
2528
ctx: ctx,
2629
cancel: cancel,
@@ -30,6 +33,14 @@ func NewActiveWebsockets(ctx context.Context, cancel func()) *ActiveWebsockets {
3033
// Accept accepts a websocket connection and calls f with the connection.
3134
// The function will be tracked by the ActiveWebsockets struct and will be
3235
// closed when the parent context is canceled.
36+
// Steps:
37+
// 1. Ensure we are still accepting websocket connections, and not shutting down.
38+
// 2. Add 1 to the wait group.
39+
// 3. Ensure we decrement the wait group when we are done (defer).
40+
// 4. Accept the websocket connection.
41+
// 4a. If there is an error, write the error to the response writer and return.
42+
// 5. Launch go routine to kill websocket if the parent context is canceled.
43+
// 6. Call 'f' with the websocket connection.
3344
func (a *ActiveWebsockets) Accept(rw http.ResponseWriter, r *http.Request, options *websocket.AcceptOptions, f func(conn *websocket.Conn)) {
3445
// Ensure we are still accepting websocket connections, and not shutting down.
3546
if err := a.ctx.Err(); err != nil {
@@ -58,23 +69,26 @@ func (a *ActiveWebsockets) Accept(rw http.ResponseWriter, r *http.Request, optio
5869
// the connection is closed.
5970
ctx, cancel := context.WithCancel(a.ctx)
6071
defer cancel()
61-
a.track(ctx, conn)
72+
closeConnOnContext(ctx, conn)
6273

6374
// Handle the websocket connection
6475
f(conn)
6576
}
6677

67-
// Track runs a go routine that will close a given websocket connection when
68-
// the parent context is canceled.
69-
func (a *ActiveWebsockets) track(ctx context.Context, conn *websocket.Conn) {
70-
go func() {
78+
// closeConnOnContext launches a go routine that will watch a given context
79+
// and close a websocket connection if that context is canceled.
80+
func closeConnOnContext(ctx context.Context, conn *websocket.Conn) {
81+
// Labeling the go routine for goroutine dumps/debugging.
82+
go pprof.Do(ctx, pprof.Labels("service", "api-server", "function", "ActiveWebsockets.track"), func(ctx context.Context) {
7183
select {
7284
case <-ctx.Done():
7385
_ = conn.Close(websocket.StatusNormalClosure, "")
7486
}
75-
}()
87+
})
7688
}
7789

90+
// Close will close all active websocket connections and wait for them to
91+
// finish.
7892
func (a *ActiveWebsockets) Close() {
7993
a.cancel()
8094
a.wg.Wait()

0 commit comments

Comments
 (0)