@@ -3,12 +3,10 @@ package agent
3
3
import (
4
4
"bytes"
5
5
"context"
6
- "encoding/binary"
7
6
"encoding/json"
8
7
"errors"
9
8
"fmt"
10
9
"io"
11
- "net"
12
10
"net/http"
13
11
"net/netip"
14
12
"os"
@@ -216,8 +214,8 @@ type agent struct {
216
214
portCacheDuration time.Duration
217
215
subsystems []codersdk.AgentSubsystem
218
216
219
- reconnectingPTYs sync.Map
220
217
reconnectingPTYTimeout time.Duration
218
+ reconnectingPTYServer * reconnectingpty.Server
221
219
222
220
// we track 2 contexts and associated cancel functions: "graceful" which is Done when it is time
223
221
// to start gracefully shutting down and "hard" which is Done when it is time to close
@@ -252,8 +250,6 @@ type agent struct {
252
250
statsReporter * statsReporter
253
251
logSender * agentsdk.LogSender
254
252
255
- connCountReconnectingPTY atomic.Int64
256
-
257
253
prometheusRegistry * prometheus.Registry
258
254
// metrics are prometheus registered metrics that will be collected and
259
255
// labeled in Coder with the agent + workspace.
@@ -297,6 +293,13 @@ func (a *agent) init() {
297
293
// Register runner metrics. If the prom registry is nil, the metrics
298
294
// will not report anywhere.
299
295
a .scriptRunner .RegisterMetrics (a .prometheusRegistry )
296
+
297
+ a .reconnectingPTYServer = reconnectingpty .NewServer (
298
+ a .logger .Named ("reconnecting-pty" ),
299
+ a .sshServer ,
300
+ a .metrics .connectionsTotal , a .metrics .reconnectingPTYErrors ,
301
+ a .reconnectingPTYTimeout ,
302
+ )
300
303
go a .runLoop ()
301
304
}
302
305
@@ -1181,55 +1184,12 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t
1181
1184
}
1182
1185
}()
1183
1186
if err = a .trackGoroutine (func () {
1184
- logger := a .logger .Named ("reconnecting-pty" )
1185
- var wg sync.WaitGroup
1186
- for {
1187
- conn , err := reconnectingPTYListener .Accept ()
1188
- if err != nil {
1189
- if ! a .isClosed () {
1190
- logger .Debug (ctx , "accept pty failed" , slog .Error (err ))
1191
- }
1192
- break
1193
- }
1194
- clog := logger .With (
1195
- slog .F ("remote" , conn .RemoteAddr ().String ()),
1196
- slog .F ("local" , conn .LocalAddr ().String ()))
1197
- clog .Info (ctx , "accepted conn" )
1198
- wg .Add (1 )
1199
- closed := make (chan struct {})
1200
- go func () {
1201
- select {
1202
- case <- closed :
1203
- case <- a .hardCtx .Done ():
1204
- _ = conn .Close ()
1205
- }
1206
- wg .Done ()
1207
- }()
1208
- go func () {
1209
- defer close (closed )
1210
- // This cannot use a JSON decoder, since that can
1211
- // buffer additional data that is required for the PTY.
1212
- rawLen := make ([]byte , 2 )
1213
- _ , err = conn .Read (rawLen )
1214
- if err != nil {
1215
- return
1216
- }
1217
- length := binary .LittleEndian .Uint16 (rawLen )
1218
- data := make ([]byte , length )
1219
- _ , err = conn .Read (data )
1220
- if err != nil {
1221
- return
1222
- }
1223
- var msg workspacesdk.AgentReconnectingPTYInit
1224
- err = json .Unmarshal (data , & msg )
1225
- if err != nil {
1226
- logger .Warn (ctx , "failed to unmarshal init" , slog .F ("raw" , data ))
1227
- return
1228
- }
1229
- _ = a .handleReconnectingPTY (ctx , clog , msg , conn )
1230
- }()
1187
+ rPTYServeErr := a .reconnectingPTYServer .Serve (a .gracefulCtx , a .hardCtx , reconnectingPTYListener )
1188
+ if rPTYServeErr != nil &&
1189
+ a .gracefulCtx .Err () == nil &&
1190
+ ! strings .Contains (rPTYServeErr .Error (), "use of closed network connection" ) {
1191
+ a .logger .Error (ctx , "error serving reconnecting PTY" , slog .Error (rPTYServeErr ))
1231
1192
}
1232
- wg .Wait ()
1233
1193
}); err != nil {
1234
1194
return nil , err
1235
1195
}
@@ -1308,9 +1268,9 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t
1308
1268
_ = server .Close ()
1309
1269
}()
1310
1270
1311
- err := server .Serve (apiListener )
1312
- if err != nil && ! xerrors .Is (err , http .ErrServerClosed ) && ! strings .Contains (err .Error (), "use of closed network connection" ) {
1313
- a .logger .Critical (ctx , "serve HTTP API server" , slog .Error (err ))
1271
+ apiServErr := server .Serve (apiListener )
1272
+ if apiServErr != nil && ! xerrors .Is (apiServErr , http .ErrServerClosed ) && ! strings .Contains (apiServErr .Error (), "use of closed network connection" ) {
1273
+ a .logger .Critical (ctx , "serve HTTP API server" , slog .Error (apiServErr ))
1314
1274
}
1315
1275
}); err != nil {
1316
1276
return nil , err
@@ -1394,87 +1354,6 @@ func (a *agent) runDERPMapSubscriber(ctx context.Context, tClient tailnetproto.D
1394
1354
}
1395
1355
}
1396
1356
1397
- func (a * agent ) handleReconnectingPTY (ctx context.Context , logger slog.Logger , msg workspacesdk.AgentReconnectingPTYInit , conn net.Conn ) (retErr error ) {
1398
- defer conn .Close ()
1399
- a .metrics .connectionsTotal .Add (1 )
1400
-
1401
- a .connCountReconnectingPTY .Add (1 )
1402
- defer a .connCountReconnectingPTY .Add (- 1 )
1403
-
1404
- connectionID := uuid .NewString ()
1405
- connLogger := logger .With (slog .F ("message_id" , msg .ID ), slog .F ("connection_id" , connectionID ))
1406
- connLogger .Debug (ctx , "starting handler" )
1407
-
1408
- defer func () {
1409
- if err := retErr ; err != nil {
1410
- a .closeMutex .Lock ()
1411
- closed := a .isClosed ()
1412
- a .closeMutex .Unlock ()
1413
-
1414
- // If the agent is closed, we don't want to
1415
- // log this as an error since it's expected.
1416
- if closed {
1417
- connLogger .Info (ctx , "reconnecting pty failed with attach error (agent closed)" , slog .Error (err ))
1418
- } else {
1419
- connLogger .Error (ctx , "reconnecting pty failed with attach error" , slog .Error (err ))
1420
- }
1421
- }
1422
- connLogger .Info (ctx , "reconnecting pty connection closed" )
1423
- }()
1424
-
1425
- var rpty reconnectingpty.ReconnectingPTY
1426
- sendConnected := make (chan reconnectingpty.ReconnectingPTY , 1 )
1427
- // On store, reserve this ID to prevent multiple concurrent new connections.
1428
- waitReady , ok := a .reconnectingPTYs .LoadOrStore (msg .ID , sendConnected )
1429
- if ok {
1430
- close (sendConnected ) // Unused.
1431
- connLogger .Debug (ctx , "connecting to existing reconnecting pty" )
1432
- c , ok := waitReady .(chan reconnectingpty.ReconnectingPTY )
1433
- if ! ok {
1434
- return xerrors .Errorf ("found invalid type in reconnecting pty map: %T" , waitReady )
1435
- }
1436
- rpty , ok = <- c
1437
- if ! ok || rpty == nil {
1438
- return xerrors .Errorf ("reconnecting pty closed before connection" )
1439
- }
1440
- c <- rpty // Put it back for the next reconnect.
1441
- } else {
1442
- connLogger .Debug (ctx , "creating new reconnecting pty" )
1443
-
1444
- connected := false
1445
- defer func () {
1446
- if ! connected && retErr != nil {
1447
- a .reconnectingPTYs .Delete (msg .ID )
1448
- close (sendConnected )
1449
- }
1450
- }()
1451
-
1452
- // Empty command will default to the users shell!
1453
- cmd , err := a .sshServer .CreateCommand (ctx , msg .Command , nil )
1454
- if err != nil {
1455
- a .metrics .reconnectingPTYErrors .WithLabelValues ("create_command" ).Add (1 )
1456
- return xerrors .Errorf ("create command: %w" , err )
1457
- }
1458
-
1459
- rpty = reconnectingpty .New (ctx , cmd , & reconnectingpty.Options {
1460
- Timeout : a .reconnectingPTYTimeout ,
1461
- Metrics : a .metrics .reconnectingPTYErrors ,
1462
- }, logger .With (slog .F ("message_id" , msg .ID )))
1463
-
1464
- if err = a .trackGoroutine (func () {
1465
- rpty .Wait ()
1466
- a .reconnectingPTYs .Delete (msg .ID )
1467
- }); err != nil {
1468
- rpty .Close (err )
1469
- return xerrors .Errorf ("start routine: %w" , err )
1470
- }
1471
-
1472
- connected = true
1473
- sendConnected <- rpty
1474
- }
1475
- return rpty .Attach (ctx , connectionID , conn , msg .Height , msg .Width , connLogger )
1476
- }
1477
-
1478
1357
// Collect collects additional stats from the agent
1479
1358
func (a * agent ) Collect (ctx context.Context , networkStats map [netlogtype.Connection ]netlogtype.Counts ) * proto.Stats {
1480
1359
a .logger .Debug (context .Background (), "computing stats report" )
@@ -1496,7 +1375,7 @@ func (a *agent) Collect(ctx context.Context, networkStats map[netlogtype.Connect
1496
1375
stats .SessionCountVscode = sshStats .VSCode
1497
1376
stats .SessionCountJetbrains = sshStats .JetBrains
1498
1377
1499
- stats .SessionCountReconnectingPty = a .connCountReconnectingPTY . Load ()
1378
+ stats .SessionCountReconnectingPty = a .reconnectingPTYServer . ConnCount ()
1500
1379
1501
1380
// Compute the median connection latency!
1502
1381
a .logger .Debug (ctx , "starting peer latency measurement for stats" )
0 commit comments