Skip to content

Commit 4b426c5

Browse files
mafredrikylecarbs
authored andcommitted
fix: Avoid use of r.Context() after r.Hijack() (#1978)
1 parent e8a9358 commit 4b426c5

File tree

1 file changed

+74
-30
lines changed

1 file changed

+74
-30
lines changed

coderd/workspaceagents.go

+74-30
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package coderd
22

33
import (
4+
"context"
45
"database/sql"
56
"encoding/json"
67
"fmt"
@@ -16,6 +17,7 @@ import (
1617
"nhooyr.io/websocket"
1718

1819
"cdr.dev/slog"
20+
1921
"github.com/coder/coder/agent"
2022
"github.com/coder/coder/coderd/database"
2123
"github.com/coder/coder/coderd/httpapi"
@@ -69,17 +71,18 @@ func (api *API) workspaceAgentDial(rw http.ResponseWriter, r *http.Request) {
6971
})
7072
return
7173
}
72-
defer func() {
73-
_ = conn.Close(websocket.StatusNormalClosure, "")
74-
}()
74+
75+
ctx, wsNetConn := websocketNetConn(r.Context(), conn, websocket.MessageBinary)
76+
defer wsNetConn.Close() // Also closes conn.
77+
7578
config := yamux.DefaultConfig()
7679
config.LogOutput = io.Discard
77-
session, err := yamux.Server(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), config)
80+
session, err := yamux.Server(wsNetConn, config)
7881
if err != nil {
7982
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
8083
return
8184
}
82-
err = peerbroker.ProxyListen(r.Context(), session, peerbroker.ProxyOptions{
85+
err = peerbroker.ProxyListen(ctx, session, peerbroker.ProxyOptions{
8386
ChannelID: workspaceAgent.ID.String(),
8487
Logger: api.Logger.Named("peerbroker-proxy-dial"),
8588
Pubsub: api.Pubsub,
@@ -193,13 +196,12 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
193196
return
194197
}
195198

196-
defer func() {
197-
_ = conn.Close(websocket.StatusNormalClosure, "")
198-
}()
199+
ctx, wsNetConn := websocketNetConn(r.Context(), conn, websocket.MessageBinary)
200+
defer wsNetConn.Close() // Also closes conn.
199201

200202
config := yamux.DefaultConfig()
201203
config.LogOutput = io.Discard
202-
session, err := yamux.Server(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), config)
204+
session, err := yamux.Server(wsNetConn, config)
203205
if err != nil {
204206
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
205207
return
@@ -229,7 +231,7 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
229231
}
230232
disconnectedAt := workspaceAgent.DisconnectedAt
231233
updateConnectionTimes := func() error {
232-
err = api.Database.UpdateWorkspaceAgentConnectionByID(r.Context(), database.UpdateWorkspaceAgentConnectionByIDParams{
234+
err = api.Database.UpdateWorkspaceAgentConnectionByID(ctx, database.UpdateWorkspaceAgentConnectionByIDParams{
233235
ID: workspaceAgent.ID,
234236
FirstConnectedAt: firstConnectedAt,
235237
LastConnectedAt: lastConnectedAt,
@@ -255,7 +257,7 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
255257
return
256258
}
257259

258-
api.Logger.Info(r.Context(), "accepting agent", slog.F("resource", resource), slog.F("agent", workspaceAgent))
260+
api.Logger.Info(ctx, "accepting agent", slog.F("resource", resource), slog.F("agent", workspaceAgent))
259261

260262
ticker := time.NewTicker(api.AgentConnectionUpdateFrequency)
261263
defer ticker.Stop()
@@ -324,16 +326,16 @@ func (api *API) workspaceAgentTurn(rw http.ResponseWriter, r *http.Request) {
324326
})
325327
return
326328
}
327-
defer func() {
328-
_ = wsConn.Close(websocket.StatusNormalClosure, "")
329-
}()
330-
netConn := websocket.NetConn(r.Context(), wsConn, websocket.MessageBinary)
331-
api.Logger.Debug(r.Context(), "accepting turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress))
329+
330+
ctx, wsNetConn := websocketNetConn(r.Context(), wsConn, websocket.MessageBinary)
331+
defer wsNetConn.Close() // Also closes conn.
332+
333+
api.Logger.Debug(ctx, "accepting turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress))
332334
select {
333-
case <-api.TURNServer.Accept(netConn, remoteAddress, localAddress).Closed():
334-
case <-r.Context().Done():
335+
case <-api.TURNServer.Accept(wsNetConn, remoteAddress, localAddress).Closed():
336+
case <-ctx.Done():
335337
}
336-
api.Logger.Debug(r.Context(), "completed turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress))
338+
api.Logger.Debug(ctx, "completed turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress))
337339
}
338340

339341
// workspaceAgentPTY spawns a PTY and pipes it over a WebSocket.
@@ -384,12 +386,11 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
384386
})
385387
return
386388
}
387-
defer func() {
388-
_ = conn.Close(websocket.StatusNormalClosure, "ended")
389-
}()
390-
// Accept text connections, because it's more developer friendly.
391-
wsNetConn := websocket.NetConn(r.Context(), conn, websocket.MessageBinary)
392-
agentConn, err := api.dialWorkspaceAgent(r, workspaceAgent.ID)
389+
390+
ctx, wsNetConn := websocketNetConn(r.Context(), conn, websocket.MessageBinary)
391+
defer wsNetConn.Close() // Also closes conn.
392+
393+
agentConn, err := api.dialWorkspaceAgent(ctx, r, workspaceAgent.ID)
393394
if err != nil {
394395
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial workspace agent: %s", err))
395396
return
@@ -408,11 +409,13 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
408409
_, _ = io.Copy(ptNetConn, wsNetConn)
409410
}
410411

411-
// dialWorkspaceAgent connects to a workspace agent by ID.
412-
func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (*agent.Conn, error) {
412+
// dialWorkspaceAgent connects to a workspace agent by ID. Only rely on
413+
// r.Context() for cancellation if it's use is safe or r.Hijack() has
414+
// not been performed.
415+
func (api *API) dialWorkspaceAgent(ctx context.Context, r *http.Request, agentID uuid.UUID) (*agent.Conn, error) {
413416
client, server := provisionersdk.TransportPipe()
414417
go func() {
415-
_ = peerbroker.ProxyListen(r.Context(), server, peerbroker.ProxyOptions{
418+
_ = peerbroker.ProxyListen(ctx, server, peerbroker.ProxyOptions{
416419
ChannelID: agentID.String(),
417420
Logger: api.Logger.Named("peerbroker-proxy-dial"),
418421
Pubsub: api.Pubsub,
@@ -422,7 +425,7 @@ func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (*agent.C
422425
}()
423426

424427
peerClient := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client))
425-
stream, err := peerClient.NegotiateConnection(r.Context())
428+
stream, err := peerClient.NegotiateConnection(ctx)
426429
if err != nil {
427430
return nil, xerrors.Errorf("negotiate: %w", err)
428431
}
@@ -434,7 +437,7 @@ func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (*agent.C
434437
options.SettingEngine.SetICEProxyDialer(turnconn.ProxyDialer(func() (c net.Conn, err error) {
435438
clientPipe, serverPipe := net.Pipe()
436439
go func() {
437-
<-r.Context().Done()
440+
<-ctx.Done()
438441
_ = clientPipe.Close()
439442
_ = serverPipe.Close()
440443
}()
@@ -515,3 +518,44 @@ func convertWorkspaceAgent(dbAgent database.WorkspaceAgent, agentUpdateFrequency
515518

516519
return workspaceAgent, nil
517520
}
521+
522+
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
523+
// is called if a read or write error is encountered.
524+
type wsNetConn struct {
525+
cancel context.CancelFunc
526+
net.Conn
527+
}
528+
529+
func (c *wsNetConn) Read(b []byte) (n int, err error) {
530+
n, err = c.Conn.Read(b)
531+
if err != nil {
532+
c.cancel()
533+
}
534+
return n, err
535+
}
536+
537+
func (c *wsNetConn) Write(b []byte) (n int, err error) {
538+
n, err = c.Conn.Write(b)
539+
if err != nil {
540+
c.cancel()
541+
}
542+
return n, err
543+
}
544+
545+
func (c *wsNetConn) Close() error {
546+
defer c.cancel()
547+
return c.Conn.Close()
548+
}
549+
550+
// websocketNetConn wraps websocket.NetConn and returns a context that
551+
// is tied to the parent context and the lifetime of the conn. Any error
552+
// during read or write will cancel the context, but not close the
553+
// conn. Close should be called to release context resources.
554+
func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
555+
ctx, cancel := context.WithCancel(ctx)
556+
nc := websocket.NetConn(ctx, conn, msgType)
557+
return ctx, &wsNetConn{
558+
cancel: cancel,
559+
Conn: nc,
560+
}
561+
}

0 commit comments

Comments
 (0)