Skip to content

Commit 884c71b

Browse files
committed
chore: Refactor accepting websocket connections to track for close
1 parent 29e9b9e commit 884c71b

File tree

5 files changed

+97
-34
lines changed

5 files changed

+97
-34
lines changed

coderd/coderd.go

+1-3
Original file line numberDiff line numberDiff line change
@@ -784,8 +784,8 @@ type API struct {
784784

785785
siteHandler http.Handler
786786

787-
WebsocketWaitMutex sync.Mutex
788787
WebsocketWaitGroup sync.WaitGroup
788+
WebsocketWatch *ActiveWebsockets
789789
derpCloseFunc func()
790790

791791
metricsCache *metricscache.Cache
@@ -805,9 +805,7 @@ func (api *API) Close() error {
805805
api.cancel()
806806
api.derpCloseFunc()
807807

808-
api.WebsocketWaitMutex.Lock()
809808
api.WebsocketWaitGroup.Wait()
810-
api.WebsocketWaitMutex.Unlock()
811809

812810
api.metricsCache.Close()
813811
if api.updateChecker != nil {

coderd/provisionerjobs.go

-2
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,7 @@ func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job
113113
logs = []database.ProvisionerJobLog{}
114114
}
115115

116-
api.WebsocketWaitMutex.Lock()
117116
api.WebsocketWaitGroup.Add(1)
118-
api.WebsocketWaitMutex.Unlock()
119117
defer api.WebsocketWaitGroup.Done()
120118
conn, err := websocket.Accept(rw, r, nil)
121119
if err != nil {

coderd/sockets.go

+81
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
package coderd
2+
3+
import (
4+
"context"
5+
"net/http"
6+
"sync"
7+
8+
"nhooyr.io/websocket"
9+
10+
"github.com/coder/coder/coderd/httpapi"
11+
"github.com/coder/coder/codersdk"
12+
)
13+
14+
// ActiveWebsockets is a helper struct that can be used to track active
15+
// websocket connections.
16+
type ActiveWebsockets struct {
17+
ctx context.Context
18+
cancel func()
19+
20+
wg sync.WaitGroup
21+
}
22+
23+
func NewActiveWebsockets(ctx context.Context, cancel func()) *ActiveWebsockets {
24+
return &ActiveWebsockets{
25+
ctx: ctx,
26+
cancel: cancel,
27+
}
28+
}
29+
30+
// Accept accepts a websocket connection and calls f with the connection.
31+
// The function will be tracked by the ActiveWebsockets struct and will be
32+
// closed when the parent context is canceled.
33+
func (a *ActiveWebsockets) Accept(rw http.ResponseWriter, r *http.Request, options *websocket.AcceptOptions, f func(conn *websocket.Conn)) {
34+
// Ensure we are still accepting websocket connections, and not shutting down.
35+
if err := a.ctx.Err(); err != nil {
36+
httpapi.Write(context.Background(), rw, http.StatusBadRequest, codersdk.Response{
37+
Message: "No longer accepting websocket requests.",
38+
Detail: err.Error(),
39+
})
40+
return
41+
}
42+
// Ensure we decrement the wait group when we are done.
43+
a.wg.Add(1)
44+
defer a.wg.Done()
45+
46+
// Accept the websocket connection
47+
conn, err := websocket.Accept(rw, r, options)
48+
if err != nil {
49+
httpapi.Write(context.Background(), rw, http.StatusBadRequest, codersdk.Response{
50+
Message: "Failed to accept websocket.",
51+
Detail: err.Error(),
52+
})
53+
return
54+
}
55+
// Always track the connection before allowing the caller to handle it.
56+
// This ensures the connection is closed when the parent context is canceled.
57+
// This new context will end if the parent context is cancelled or if
58+
// the connection is closed.
59+
ctx, cancel := context.WithCancel(a.ctx)
60+
defer cancel()
61+
a.track(ctx, conn)
62+
63+
// Handle the websocket connection
64+
f(conn)
65+
}
66+
67+
// Track runs a go routine that will close a given websocket connection when
68+
// the parent context is canceled.
69+
func (a *ActiveWebsockets) track(ctx context.Context, conn *websocket.Conn) {
70+
go func() {
71+
select {
72+
case <-ctx.Done():
73+
_ = conn.Close(websocket.StatusNormalClosure, "")
74+
}
75+
}()
76+
}
77+
78+
func (a *ActiveWebsockets) Close() {
79+
a.cancel()
80+
a.wg.Wait()
81+
}

coderd/workspaceagents.go

+15-27
Original file line numberDiff line numberDiff line change
@@ -435,9 +435,7 @@ func (api *API) workspaceAgentStartupLogs(rw http.ResponseWriter, r *http.Reques
435435
return
436436
}
437437

438-
api.WebsocketWaitMutex.Lock()
439438
api.WebsocketWaitGroup.Add(1)
440-
api.WebsocketWaitMutex.Unlock()
441439
defer api.WebsocketWaitGroup.Done()
442440
conn, err := websocket.Accept(rw, r, nil)
443441
if err != nil {
@@ -559,9 +557,7 @@ func (api *API) workspaceAgentStartupLogs(rw http.ResponseWriter, r *http.Reques
559557
func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
560558
ctx := r.Context()
561559

562-
api.WebsocketWaitMutex.Lock()
563560
api.WebsocketWaitGroup.Add(1)
564-
api.WebsocketWaitMutex.Unlock()
565561
defer api.WebsocketWaitGroup.Done()
566562

567563
appToken, ok := workspaceapps.ResolveRequest(api.Logger, api.AccessURL, api.WorkspaceAppsProvider, rw, r, workspaceapps.Request{
@@ -816,9 +812,7 @@ func (api *API) workspaceAgentConnection(rw http.ResponseWriter, r *http.Request
816812
func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request) {
817813
ctx := r.Context()
818814

819-
api.WebsocketWaitMutex.Lock()
820815
api.WebsocketWaitGroup.Add(1)
821-
api.WebsocketWaitMutex.Unlock()
822816
defer api.WebsocketWaitGroup.Done()
823817
workspaceAgent := httpmw.WorkspaceAgent(r)
824818
resource, err := api.Database.GetWorkspaceResourceByID(ctx, workspaceAgent.ResourceID)
@@ -1096,31 +1090,25 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R
10961090
}
10971091
}
10981092

1099-
api.WebsocketWaitMutex.Lock()
1100-
api.WebsocketWaitGroup.Add(1)
1101-
api.WebsocketWaitMutex.Unlock()
1102-
defer api.WebsocketWaitGroup.Done()
11031093
workspaceAgent := httpmw.WorkspaceAgentParam(r)
11041094

1105-
conn, err := websocket.Accept(rw, r, nil)
1106-
if err != nil {
1107-
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
1108-
Message: "Failed to accept websocket.",
1109-
Detail: err.Error(),
1110-
})
1111-
return
1112-
}
1113-
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
1114-
defer wsNetConn.Close()
1095+
api.WebsocketWatch.Accept(rw, r, nil, func(conn *websocket.Conn) {
1096+
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
1097+
defer wsNetConn.Close()
11151098

1116-
go httpapi.Heartbeat(ctx, conn)
1099+
// Track for graceful shutdown.
1100+
api.WebsocketWatch.Add(wsNetConn)
1101+
defer api.WebsocketWatch.Done()
11171102

1118-
defer conn.Close(websocket.StatusNormalClosure, "")
1119-
err = (*api.TailnetCoordinator.Load()).ServeClient(wsNetConn, uuid.New(), workspaceAgent.ID)
1120-
if err != nil {
1121-
_ = conn.Close(websocket.StatusInternalError, err.Error())
1122-
return
1123-
}
1103+
go httpapi.Heartbeat(ctx, conn)
1104+
1105+
defer conn.Close(websocket.StatusNormalClosure, "")
1106+
err := (*api.TailnetCoordinator.Load()).ServeClient(wsNetConn, uuid.New(), workspaceAgent.ID)
1107+
if err != nil {
1108+
_ = conn.Close(websocket.StatusInternalError, err.Error())
1109+
return
1110+
}
1111+
})
11241112
}
11251113

11261114
func convertApps(dbApps []database.WorkspaceApp) []codersdk.WorkspaceApp {

enterprise/coderd/provisionerdaemons.go

-2
Original file line numberDiff line numberDiff line change
@@ -185,9 +185,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
185185
return
186186
}
187187

188-
api.AGPL.WebsocketWaitMutex.Lock()
189188
api.AGPL.WebsocketWaitGroup.Add(1)
190-
api.AGPL.WebsocketWaitMutex.Unlock()
191189
defer api.AGPL.WebsocketWaitGroup.Done()
192190

193191
conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{

0 commit comments

Comments
 (0)