Skip to content

Commit 5df7872

Browse files
authored
fix: Improve use of context in websocket.NetConn code paths (#6198)
1 parent 6fb8aff commit 5df7872

File tree

5 files changed

+162
-19
lines changed

5 files changed

+162
-19
lines changed

coderd/workspaceagents.go

+4-1
Original file line numberDiff line numberDiff line change
@@ -748,10 +748,13 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R
748748
})
749749
return
750750
}
751+
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
752+
defer wsNetConn.Close()
753+
751754
go httpapi.Heartbeat(ctx, conn)
752755

753756
defer conn.Close(websocket.StatusNormalClosure, "")
754-
err = (*api.TailnetCoordinator.Load()).ServeClient(websocket.NetConn(ctx, conn, websocket.MessageBinary), uuid.New(), workspaceAgent.ID)
757+
err = (*api.TailnetCoordinator.Load()).ServeClient(wsNetConn, uuid.New(), workspaceAgent.ID)
755758
if err != nil {
756759
_ = conn.Close(websocket.StatusInternalError, err.Error())
757760
return

codersdk/agentsdk/agentsdk.go

+44-1
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,8 @@ func (c *Client) Listen(ctx context.Context) (net.Conn, error) {
159159
return nil, codersdk.ReadBodyAsError(res)
160160
}
161161

162+
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
163+
162164
// Ping once every 30 seconds to ensure that the websocket is alive. If we
163165
// don't get a response within 30s we kill the websocket and reconnect.
164166
// See: https://github.com/coder/coder/pull/5824
@@ -195,7 +197,7 @@ func (c *Client) Listen(ctx context.Context) (net.Conn, error) {
195197
}
196198
}()
197199

198-
return websocket.NetConn(ctx, conn, websocket.MessageBinary), nil
200+
return wsNetConn, nil
199201
}
200202

201203
type PostAppHealthsRequest struct {
@@ -529,3 +531,44 @@ type closeFunc func() error
529531
func (c closeFunc) Close() error {
530532
return c()
531533
}
534+
535+
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
536+
// is called if a read or write error is encountered.
537+
type wsNetConn struct {
538+
cancel context.CancelFunc
539+
net.Conn
540+
}
541+
542+
func (c *wsNetConn) Read(b []byte) (n int, err error) {
543+
n, err = c.Conn.Read(b)
544+
if err != nil {
545+
c.cancel()
546+
}
547+
return n, err
548+
}
549+
550+
func (c *wsNetConn) Write(b []byte) (n int, err error) {
551+
n, err = c.Conn.Write(b)
552+
if err != nil {
553+
c.cancel()
554+
}
555+
return n, err
556+
}
557+
558+
func (c *wsNetConn) Close() error {
559+
defer c.cancel()
560+
return c.Conn.Close()
561+
}
562+
563+
// websocketNetConn wraps websocket.NetConn and returns a context that
564+
// is tied to the parent context and the lifetime of the conn. Any error
565+
// during read or write will cancel the context, but not close the
566+
// conn. Close should be called to release context resources.
567+
func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
568+
ctx, cancel := context.WithCancel(ctx)
569+
nc := websocket.NetConn(ctx, conn, msgType)
570+
return ctx, &wsNetConn{
571+
cancel: cancel,
572+
Conn: nc,
573+
}
574+
}

codersdk/provisionerdaemons.go

+54-4
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"encoding/json"
77
"fmt"
88
"io"
9+
"net"
910
"net/http"
1011
"net/http/cookiejar"
1112
"net/url"
@@ -143,8 +144,9 @@ func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after
143144
return nil, nil, ReadBodyAsError(res)
144145
}
145146
logs := make(chan ProvisionerJobLog)
146-
decoder := json.NewDecoder(websocket.NetConn(ctx, conn, websocket.MessageText))
147147
closed := make(chan struct{})
148+
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageText)
149+
decoder := json.NewDecoder(wsNetConn)
148150
go func() {
149151
defer close(closed)
150152
defer close(logs)
@@ -163,13 +165,15 @@ func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after
163165
}
164166
}()
165167
return logs, closeFunc(func() error {
166-
_ = conn.Close(websocket.StatusNormalClosure, "")
168+
_ = wsNetConn.Close()
167169
<-closed
168170
return nil
169171
}), nil
170172
}
171173

172-
// ListenProvisionerDaemon returns the gRPC service for a provisioner daemon implementation.
174+
// ListenProvisionerDaemon returns the gRPC service for a provisioner daemon
175+
// implementation. The context is during dial, not during the lifetime of the
176+
// client. Client should be closed after use.
173177
func (c *Client) ServeProvisionerDaemon(ctx context.Context, organization uuid.UUID, provisioners []ProvisionerType, tags map[string]string) (proto.DRPCProvisionerDaemonClient, error) {
174178
serverURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/organizations/%s/provisionerdaemons/serve", organization))
175179
if err != nil {
@@ -210,9 +214,55 @@ func (c *Client) ServeProvisionerDaemon(ctx context.Context, organization uuid.U
210214

211215
config := yamux.DefaultConfig()
212216
config.LogOutput = io.Discard
213-
session, err := yamux.Client(websocket.NetConn(ctx, conn, websocket.MessageBinary), config)
217+
// Use background context because caller should close the client.
218+
_, wsNetConn := websocketNetConn(context.Background(), conn, websocket.MessageBinary)
219+
session, err := yamux.Client(wsNetConn, config)
214220
if err != nil {
221+
_ = conn.Close(websocket.StatusGoingAway, "")
222+
_ = wsNetConn.Close()
215223
return nil, xerrors.Errorf("multiplex client: %w", err)
216224
}
217225
return proto.NewDRPCProvisionerDaemonClient(provisionersdk.MultiplexedConn(session)), nil
218226
}
227+
228+
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
229+
// is called if a read or write error is encountered.
230+
// @typescript-ignore wsNetConn
231+
type wsNetConn struct {
232+
cancel context.CancelFunc
233+
net.Conn
234+
}
235+
236+
func (c *wsNetConn) Read(b []byte) (n int, err error) {
237+
n, err = c.Conn.Read(b)
238+
if err != nil {
239+
c.cancel()
240+
}
241+
return n, err
242+
}
243+
244+
func (c *wsNetConn) Write(b []byte) (n int, err error) {
245+
n, err = c.Conn.Write(b)
246+
if err != nil {
247+
c.cancel()
248+
}
249+
return n, err
250+
}
251+
252+
func (c *wsNetConn) Close() error {
253+
defer c.cancel()
254+
return c.Conn.Close()
255+
}
256+
257+
// websocketNetConn wraps websocket.NetConn and returns a context that
258+
// is tied to the parent context and the lifetime of the conn. Any error
259+
// during read or write will cancel the context, but not close the
260+
// conn. Close should be called to release context resources.
261+
func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
262+
ctx, cancel := context.WithCancel(ctx)
263+
nc := websocket.NetConn(ctx, conn, msgType)
264+
return ctx, &wsNetConn{
265+
cancel: cancel,
266+
Conn: nc,
267+
}
268+
}

codersdk/workspaceagents.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ func (c *Client) WorkspaceAgentReconnectingPTY(ctx context.Context, agentID, rec
257257
}
258258
return nil, ReadBodyAsError(res)
259259
}
260-
return websocket.NetConn(ctx, conn, websocket.MessageBinary), nil
260+
return websocket.NetConn(context.Background(), conn, websocket.MessageBinary), nil
261261
}
262262

263263
// WorkspaceAgentListeningPorts returns a list of ports that are currently being

enterprise/coderd/provisionerdaemons.go

+59-12
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
package coderd
22

33
import (
4+
"context"
45
"database/sql"
56
"encoding/json"
67
"errors"
78
"fmt"
89
"io"
10+
"net"
911
"net/http"
1012
"strings"
1113

@@ -94,12 +96,14 @@ func (api *API) provisionerDaemons(rw http.ResponseWriter, r *http.Request) {
9496
// @Success 101
9597
// @Router /organizations/{organization}/provisionerdaemons/serve [get]
9698
func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request) {
99+
ctx := r.Context()
100+
97101
tags := map[string]string{}
98102
if r.URL.Query().Has("tag") {
99103
for _, tag := range r.URL.Query()["tag"] {
100104
parts := strings.SplitN(tag, "=", 2)
101105
if len(parts) < 2 {
102-
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
106+
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
103107
Message: fmt.Sprintf("Invalid format for tag %q. Key and value must be separated with =.", tag),
104108
})
105109
return
@@ -108,7 +112,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
108112
}
109113
}
110114
if !r.URL.Query().Has("provisioner") {
111-
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
115+
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
112116
Message: "The provisioner query parameter must be specified.",
113117
})
114118
return
@@ -122,7 +126,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
122126
case string(codersdk.ProvisionerTypeTerraform):
123127
provisionersMap[codersdk.ProvisionerTypeTerraform] = struct{}{}
124128
default:
125-
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
129+
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
126130
Message: fmt.Sprintf("Unknown provisioner type %q", provisioner),
127131
})
128132
return
@@ -137,7 +141,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
137141

138142
if tags[provisionerdserver.TagScope] == provisionerdserver.ScopeOrganization {
139143
if !api.AGPL.Authorize(r, rbac.ActionCreate, rbac.ResourceProvisionerDaemon) {
140-
httpapi.Write(r.Context(), rw, http.StatusForbidden, codersdk.Response{
144+
httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
141145
Message: "You aren't allowed to create provisioner daemons for the organization.",
142146
})
143147
return
@@ -155,15 +159,15 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
155159
}
156160

157161
name := namesgenerator.GetRandomName(1)
158-
daemon, err := api.Database.InsertProvisionerDaemon(r.Context(), database.InsertProvisionerDaemonParams{
162+
daemon, err := api.Database.InsertProvisionerDaemon(ctx, database.InsertProvisionerDaemonParams{
159163
ID: uuid.New(),
160164
CreatedAt: database.Now(),
161165
Name: name,
162166
Provisioners: provisioners,
163167
Tags: tags,
164168
})
165169
if err != nil {
166-
httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{
170+
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
167171
Message: "Internal error writing provisioner daemon.",
168172
Detail: err.Error(),
169173
})
@@ -172,7 +176,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
172176

173177
rawTags, err := json.Marshal(daemon.Tags)
174178
if err != nil {
175-
httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{
179+
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
176180
Message: "Internal error marshaling daemon tags.",
177181
Detail: err.Error(),
178182
})
@@ -189,7 +193,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
189193
CompressionMode: websocket.CompressionDisabled,
190194
})
191195
if err != nil {
192-
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
196+
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
193197
Message: "Internal error accepting websocket connection.",
194198
Detail: err.Error(),
195199
})
@@ -203,7 +207,9 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
203207
// the same connection.
204208
config := yamux.DefaultConfig()
205209
config.LogOutput = io.Discard
206-
session, err := yamux.Server(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), config)
210+
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
211+
defer wsNetConn.Close()
212+
session, err := yamux.Server(wsNetConn, config)
207213
if err != nil {
208214
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("multiplex server: %s", err))
209215
return
@@ -229,12 +235,12 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
229235
if xerrors.Is(err, io.EOF) {
230236
return
231237
}
232-
api.Logger.Debug(r.Context(), "drpc server error", slog.Error(err))
238+
api.Logger.Debug(ctx, "drpc server error", slog.Error(err))
233239
},
234240
})
235-
err = server.Serve(r.Context(), session)
241+
err = server.Serve(ctx, session)
236242
if err != nil && !xerrors.Is(err, io.EOF) {
237-
api.Logger.Debug(r.Context(), "provisioner daemon disconnected", slog.Error(err))
243+
api.Logger.Debug(ctx, "provisioner daemon disconnected", slog.Error(err))
238244
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("serve: %s", err))
239245
return
240246
}
@@ -254,3 +260,44 @@ func convertProvisionerDaemon(daemon database.ProvisionerDaemon) codersdk.Provis
254260
}
255261
return result
256262
}
263+
264+
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
265+
// is called if a read or write error is encountered.
266+
type wsNetConn struct {
267+
cancel context.CancelFunc
268+
net.Conn
269+
}
270+
271+
func (c *wsNetConn) Read(b []byte) (n int, err error) {
272+
n, err = c.Conn.Read(b)
273+
if err != nil {
274+
c.cancel()
275+
}
276+
return n, err
277+
}
278+
279+
func (c *wsNetConn) Write(b []byte) (n int, err error) {
280+
n, err = c.Conn.Write(b)
281+
if err != nil {
282+
c.cancel()
283+
}
284+
return n, err
285+
}
286+
287+
func (c *wsNetConn) Close() error {
288+
defer c.cancel()
289+
return c.Conn.Close()
290+
}
291+
292+
// websocketNetConn wraps websocket.NetConn and returns a context that
293+
// is tied to the parent context and the lifetime of the conn. Any error
294+
// during read or write will cancel the context, but not close the
295+
// conn. Close should be called to release context resources.
296+
func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
297+
ctx, cancel := context.WithCancel(ctx)
298+
nc := websocket.NetConn(ctx, conn, msgType)
299+
return ctx, &wsNetConn{
300+
cancel: cancel,
301+
Conn: nc,
302+
}
303+
}

0 commit comments

Comments
 (0)