1
1
package coderd
2
2
3
3
import (
4
+ "context"
4
5
"database/sql"
5
6
"encoding/json"
6
7
"fmt"
@@ -16,6 +17,7 @@ import (
16
17
"nhooyr.io/websocket"
17
18
18
19
"cdr.dev/slog"
20
+
19
21
"github.com/coder/coder/agent"
20
22
"github.com/coder/coder/coderd/database"
21
23
"github.com/coder/coder/coderd/httpapi"
@@ -69,17 +71,18 @@ func (api *API) workspaceAgentDial(rw http.ResponseWriter, r *http.Request) {
69
71
})
70
72
return
71
73
}
72
- defer func () {
73
- _ = conn .Close (websocket .StatusNormalClosure , "" )
74
- }()
74
+
75
+ ctx , wsNetConn := websocketNetConn (r .Context (), conn , websocket .MessageBinary )
76
+ defer wsNetConn .Close () // Also closes conn.
77
+
75
78
config := yamux .DefaultConfig ()
76
79
config .LogOutput = io .Discard
77
- session , err := yamux .Server (websocket . NetConn ( r . Context (), conn , websocket . MessageBinary ) , config )
80
+ session , err := yamux .Server (wsNetConn , config )
78
81
if err != nil {
79
82
_ = conn .Close (websocket .StatusAbnormalClosure , err .Error ())
80
83
return
81
84
}
82
- err = peerbroker .ProxyListen (r . Context () , session , peerbroker.ProxyOptions {
85
+ err = peerbroker .ProxyListen (ctx , session , peerbroker.ProxyOptions {
83
86
ChannelID : workspaceAgent .ID .String (),
84
87
Logger : api .Logger .Named ("peerbroker-proxy-dial" ),
85
88
Pubsub : api .Pubsub ,
@@ -193,13 +196,12 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
193
196
return
194
197
}
195
198
196
- defer func () {
197
- _ = conn .Close (websocket .StatusNormalClosure , "" )
198
- }()
199
+ ctx , wsNetConn := websocketNetConn (r .Context (), conn , websocket .MessageBinary )
200
+ defer wsNetConn .Close () // Also closes conn.
199
201
200
202
config := yamux .DefaultConfig ()
201
203
config .LogOutput = io .Discard
202
- session , err := yamux .Server (websocket . NetConn ( r . Context (), conn , websocket . MessageBinary ) , config )
204
+ session , err := yamux .Server (wsNetConn , config )
203
205
if err != nil {
204
206
_ = conn .Close (websocket .StatusAbnormalClosure , err .Error ())
205
207
return
@@ -229,7 +231,7 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
229
231
}
230
232
disconnectedAt := workspaceAgent .DisconnectedAt
231
233
updateConnectionTimes := func () error {
232
- err = api .Database .UpdateWorkspaceAgentConnectionByID (r . Context () , database.UpdateWorkspaceAgentConnectionByIDParams {
234
+ err = api .Database .UpdateWorkspaceAgentConnectionByID (ctx , database.UpdateWorkspaceAgentConnectionByIDParams {
233
235
ID : workspaceAgent .ID ,
234
236
FirstConnectedAt : firstConnectedAt ,
235
237
LastConnectedAt : lastConnectedAt ,
@@ -255,7 +257,7 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
255
257
return
256
258
}
257
259
258
- api .Logger .Info (r . Context () , "accepting agent" , slog .F ("resource" , resource ), slog .F ("agent" , workspaceAgent ))
260
+ api .Logger .Info (ctx , "accepting agent" , slog .F ("resource" , resource ), slog .F ("agent" , workspaceAgent ))
259
261
260
262
ticker := time .NewTicker (api .AgentConnectionUpdateFrequency )
261
263
defer ticker .Stop ()
@@ -324,16 +326,16 @@ func (api *API) workspaceAgentTurn(rw http.ResponseWriter, r *http.Request) {
324
326
})
325
327
return
326
328
}
327
- defer func () {
328
- _ = wsConn . Close ( websocket . StatusNormalClosure , "" )
329
- }()
330
- netConn := websocket . NetConn ( r . Context (), wsConn , websocket . MessageBinary )
331
- api .Logger .Debug (r . Context () , "accepting turn connection" , slog .F ("remote-address" , r .RemoteAddr ), slog .F ("local-address" , localAddress ))
329
+
330
+ ctx , wsNetConn := websocketNetConn ( r . Context (), wsConn , websocket . MessageBinary )
331
+ defer wsNetConn . Close () // Also closes conn.
332
+
333
+ api .Logger .Debug (ctx , "accepting turn connection" , slog .F ("remote-address" , r .RemoteAddr ), slog .F ("local-address" , localAddress ))
332
334
select {
333
- case <- api .TURNServer .Accept (netConn , remoteAddress , localAddress ).Closed ():
334
- case <- r . Context () .Done ():
335
+ case <- api .TURNServer .Accept (wsNetConn , remoteAddress , localAddress ).Closed ():
336
+ case <- ctx .Done ():
335
337
}
336
- api .Logger .Debug (r . Context () , "completed turn connection" , slog .F ("remote-address" , r .RemoteAddr ), slog .F ("local-address" , localAddress ))
338
+ api .Logger .Debug (ctx , "completed turn connection" , slog .F ("remote-address" , r .RemoteAddr ), slog .F ("local-address" , localAddress ))
337
339
}
338
340
339
341
// workspaceAgentPTY spawns a PTY and pipes it over a WebSocket.
@@ -384,12 +386,11 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
384
386
})
385
387
return
386
388
}
387
- defer func () {
388
- _ = conn .Close (websocket .StatusNormalClosure , "ended" )
389
- }()
390
- // Accept text connections, because it's more developer friendly.
391
- wsNetConn := websocket .NetConn (r .Context (), conn , websocket .MessageBinary )
392
- agentConn , err := api .dialWorkspaceAgent (r , workspaceAgent .ID )
389
+
390
+ ctx , wsNetConn := websocketNetConn (r .Context (), conn , websocket .MessageBinary )
391
+ defer wsNetConn .Close () // Also closes conn.
392
+
393
+ agentConn , err := api .dialWorkspaceAgent (ctx , r , workspaceAgent .ID )
393
394
if err != nil {
394
395
_ = conn .Close (websocket .StatusInternalError , httpapi .WebsocketCloseSprintf ("dial workspace agent: %s" , err ))
395
396
return
@@ -408,11 +409,13 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
408
409
_ , _ = io .Copy (ptNetConn , wsNetConn )
409
410
}
410
411
411
- // dialWorkspaceAgent connects to a workspace agent by ID.
412
- func (api * API ) dialWorkspaceAgent (r * http.Request , agentID uuid.UUID ) (* agent.Conn , error ) {
412
+ // dialWorkspaceAgent connects to a workspace agent by ID. Only rely on
413
+ // r.Context() for cancellation if it's use is safe or r.Hijack() has
414
+ // not been performed.
415
+ func (api * API ) dialWorkspaceAgent (ctx context.Context , r * http.Request , agentID uuid.UUID ) (* agent.Conn , error ) {
413
416
client , server := provisionersdk .TransportPipe ()
414
417
go func () {
415
- _ = peerbroker .ProxyListen (r . Context () , server , peerbroker.ProxyOptions {
418
+ _ = peerbroker .ProxyListen (ctx , server , peerbroker.ProxyOptions {
416
419
ChannelID : agentID .String (),
417
420
Logger : api .Logger .Named ("peerbroker-proxy-dial" ),
418
421
Pubsub : api .Pubsub ,
@@ -422,7 +425,7 @@ func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (*agent.C
422
425
}()
423
426
424
427
peerClient := proto .NewDRPCPeerBrokerClient (provisionersdk .Conn (client ))
425
- stream , err := peerClient .NegotiateConnection (r . Context () )
428
+ stream , err := peerClient .NegotiateConnection (ctx )
426
429
if err != nil {
427
430
return nil , xerrors .Errorf ("negotiate: %w" , err )
428
431
}
@@ -434,7 +437,7 @@ func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (*agent.C
434
437
options .SettingEngine .SetICEProxyDialer (turnconn .ProxyDialer (func () (c net.Conn , err error ) {
435
438
clientPipe , serverPipe := net .Pipe ()
436
439
go func () {
437
- <- r . Context () .Done ()
440
+ <- ctx .Done ()
438
441
_ = clientPipe .Close ()
439
442
_ = serverPipe .Close ()
440
443
}()
@@ -515,3 +518,44 @@ func convertWorkspaceAgent(dbAgent database.WorkspaceAgent, agentUpdateFrequency
515
518
516
519
return workspaceAgent , nil
517
520
}
521
+
522
+ // wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
523
+ // is called if a read or write error is encountered.
524
+ type wsNetConn struct {
525
+ cancel context.CancelFunc
526
+ net.Conn
527
+ }
528
+
529
+ func (c * wsNetConn ) Read (b []byte ) (n int , err error ) {
530
+ n , err = c .Conn .Read (b )
531
+ if err != nil {
532
+ c .cancel ()
533
+ }
534
+ return n , err
535
+ }
536
+
537
+ func (c * wsNetConn ) Write (b []byte ) (n int , err error ) {
538
+ n , err = c .Conn .Write (b )
539
+ if err != nil {
540
+ c .cancel ()
541
+ }
542
+ return n , err
543
+ }
544
+
545
+ func (c * wsNetConn ) Close () error {
546
+ defer c .cancel ()
547
+ return c .Conn .Close ()
548
+ }
549
+
550
+ // websocketNetConn wraps websocket.NetConn and returns a context that
551
+ // is tied to the parent context and the lifetime of the conn. Any error
552
+ // during read or write will cancel the context, but not close the
553
+ // conn. Close should be called to release context resources.
554
+ func websocketNetConn (ctx context.Context , conn * websocket.Conn , msgType websocket.MessageType ) (context.Context , net.Conn ) {
555
+ ctx , cancel := context .WithCancel (ctx )
556
+ nc := websocket .NetConn (ctx , conn , msgType )
557
+ return ctx , & wsNetConn {
558
+ cancel : cancel ,
559
+ Conn : nc ,
560
+ }
561
+ }
0 commit comments