4
4
"context"
5
5
"crypto/rand"
6
6
"crypto/rsa"
7
+ "encoding/json"
7
8
"errors"
8
9
"fmt"
9
10
"io"
@@ -12,10 +13,14 @@ import (
12
13
"os/exec"
13
14
"os/user"
14
15
"runtime"
16
+ "strconv"
15
17
"strings"
16
18
"sync"
17
19
"time"
18
20
21
+ "github.com/armon/circbuf"
22
+ "github.com/google/uuid"
23
+
19
24
gsyslog "github.com/hashicorp/go-syslog"
20
25
"go.uber.org/atomic"
21
26
@@ -33,6 +38,11 @@ import (
33
38
"golang.org/x/xerrors"
34
39
)
35
40
41
+ type Options struct {
42
+ ReconnectingPTYTimeout time.Duration
43
+ Logger slog.Logger
44
+ }
45
+
36
46
type Metadata struct {
37
47
OwnerEmail string `json:"owner_email"`
38
48
OwnerUsername string `json:"owner_username"`
@@ -43,13 +53,20 @@ type Metadata struct {
43
53
44
54
type Dialer func (ctx context.Context , logger slog.Logger ) (Metadata , * peerbroker.Listener , error )
45
55
46
- func New (dialer Dialer , logger slog.Logger ) io.Closer {
56
+ func New (dialer Dialer , options * Options ) io.Closer {
57
+ if options == nil {
58
+ options = & Options {}
59
+ }
60
+ if options .ReconnectingPTYTimeout == 0 {
61
+ options .ReconnectingPTYTimeout = 5 * time .Minute
62
+ }
47
63
ctx , cancelFunc := context .WithCancel (context .Background ())
48
64
server := & agent {
49
- dialer : dialer ,
50
- logger : logger ,
51
- closeCancel : cancelFunc ,
52
- closed : make (chan struct {}),
65
+ dialer : dialer ,
66
+ reconnectingPTYTimeout : options .ReconnectingPTYTimeout ,
67
+ logger : options .Logger ,
68
+ closeCancel : cancelFunc ,
69
+ closed : make (chan struct {}),
53
70
}
54
71
server .init (ctx )
55
72
return server
@@ -59,6 +76,9 @@ type agent struct {
59
76
dialer Dialer
60
77
logger slog.Logger
61
78
79
+ reconnectingPTYs sync.Map
80
+ reconnectingPTYTimeout time.Duration
81
+
62
82
connCloseWait sync.WaitGroup
63
83
closeCancel context.CancelFunc
64
84
closeMutex sync.Mutex
@@ -199,6 +219,8 @@ func (a *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) {
199
219
switch channel .Protocol () {
200
220
case "ssh" :
201
221
go a .sshServer .HandleConn (channel .NetConn ())
222
+ case "reconnecting-pty" :
223
+ go a .handleReconnectingPTY (ctx , channel .Label (), channel .NetConn ())
202
224
default :
203
225
a .logger .Warn (ctx , "unhandled protocol from channel" ,
204
226
slog .F ("protocol" , channel .Protocol ()),
@@ -285,22 +307,25 @@ func (a *agent) init(ctx context.Context) {
285
307
go a .run (ctx )
286
308
}
287
309
288
- func (a * agent ) handleSSHSession (session ssh.Session ) error {
310
+ // createCommand processes raw command input with OpenSSH-like behavior.
311
+ // If the rawCommand provided is empty, it will default to the users shell.
312
+ // This injects environment variables specified by the user at launch too.
313
+ func (a * agent ) createCommand (ctx context.Context , rawCommand string , env []string ) (* exec.Cmd , error ) {
289
314
currentUser , err := user .Current ()
290
315
if err != nil {
291
- return xerrors .Errorf ("get current user: %w" , err )
316
+ return nil , xerrors .Errorf ("get current user: %w" , err )
292
317
}
293
318
username := currentUser .Username
294
319
295
320
shell , err := usershell .Get (username )
296
321
if err != nil {
297
- return xerrors .Errorf ("get user shell: %w" , err )
322
+ return nil , xerrors .Errorf ("get user shell: %w" , err )
298
323
}
299
324
300
325
// gliderlabs/ssh returns a command slice of zero
301
326
// when a shell is requested.
302
- command := session . RawCommand ()
303
- if len (session . Command () ) == 0 {
327
+ command := rawCommand
328
+ if len (command ) == 0 {
304
329
command = shell
305
330
}
306
331
@@ -310,16 +335,16 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
310
335
if runtime .GOOS == "windows" {
311
336
caller = "/c"
312
337
}
313
- cmd := exec .CommandContext (session . Context () , shell , caller , command )
338
+ cmd := exec .CommandContext (ctx , shell , caller , command )
314
339
cmd .Dir = a .directory .Load ()
315
340
if cmd .Dir == "" {
316
341
// Default to $HOME if a directory is not set!
317
342
cmd .Dir = os .Getenv ("HOME" )
318
343
}
319
- cmd .Env = append (os .Environ (), session . Environ () ... )
344
+ cmd .Env = append (os .Environ (), env ... )
320
345
executablePath , err := os .Executable ()
321
346
if err != nil {
322
- return xerrors .Errorf ("getting os executable: %w" , err )
347
+ return nil , xerrors .Errorf ("getting os executable: %w" , err )
323
348
}
324
349
// Git on Windows resolves with UNIX-style paths.
325
350
// If using backslashes, it's unable to find the executable.
@@ -340,6 +365,14 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
340
365
}
341
366
}
342
367
}
368
+ return cmd , nil
369
+ }
370
+
371
+ func (a * agent ) handleSSHSession (session ssh.Session ) error {
372
+ cmd , err := a .createCommand (session .Context (), session .RawCommand (), session .Environ ())
373
+ if err != nil {
374
+ return err
375
+ }
343
376
344
377
sshPty , windowSize , isPty := session .Pty ()
345
378
if isPty {
@@ -389,6 +422,194 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
389
422
return cmd .Wait ()
390
423
}
391
424
425
+ func (a * agent ) handleReconnectingPTY (ctx context.Context , rawID string , conn net.Conn ) {
426
+ defer conn .Close ()
427
+
428
+ // The ID format is referenced in conn.go.
429
+ // <uuid>:<height>:<width>
430
+ idParts := strings .Split (rawID , ":" )
431
+ if len (idParts ) != 3 {
432
+ a .logger .Warn (ctx , "client sent invalid id format" , slog .F ("raw-id" , rawID ))
433
+ return
434
+ }
435
+ id := idParts [0 ]
436
+ // Enforce a consistent format for IDs.
437
+ _ , err := uuid .Parse (id )
438
+ if err != nil {
439
+ a .logger .Warn (ctx , "client sent reconnection token that isn't a uuid" , slog .F ("id" , id ), slog .Error (err ))
440
+ return
441
+ }
442
+ // Parse the initial terminal dimensions.
443
+ height , err := strconv .Atoi (idParts [1 ])
444
+ if err != nil {
445
+ a .logger .Warn (ctx , "client sent invalid height" , slog .F ("id" , id ), slog .F ("height" , idParts [1 ]))
446
+ return
447
+ }
448
+ width , err := strconv .Atoi (idParts [2 ])
449
+ if err != nil {
450
+ a .logger .Warn (ctx , "client sent invalid width" , slog .F ("id" , id ), slog .F ("width" , idParts [2 ]))
451
+ return
452
+ }
453
+
454
+ var rpty * reconnectingPTY
455
+ rawRPTY , ok := a .reconnectingPTYs .Load (id )
456
+ if ok {
457
+ rpty , ok = rawRPTY .(* reconnectingPTY )
458
+ if ! ok {
459
+ a .logger .Warn (ctx , "found invalid type in reconnecting pty map" , slog .F ("id" , id ))
460
+ }
461
+ } else {
462
+ // Empty command will default to the users shell!
463
+ cmd , err := a .createCommand (ctx , "" , nil )
464
+ if err != nil {
465
+ a .logger .Warn (ctx , "create reconnecting pty command" , slog .Error (err ))
466
+ return
467
+ }
468
+ cmd .Env = append (cmd .Env , "TERM=xterm-256color" )
469
+
470
+ ptty , process , err := pty .Start (cmd )
471
+ if err != nil {
472
+ a .logger .Warn (ctx , "start reconnecting pty command" , slog .F ("id" , id ))
473
+ }
474
+
475
+ // Default to buffer 64KB.
476
+ circularBuffer , err := circbuf .NewBuffer (64 * 1024 )
477
+ if err != nil {
478
+ a .logger .Warn (ctx , "create circular buffer" , slog .Error (err ))
479
+ return
480
+ }
481
+
482
+ a .closeMutex .Lock ()
483
+ a .connCloseWait .Add (1 )
484
+ a .closeMutex .Unlock ()
485
+ ctx , cancelFunc := context .WithCancel (ctx )
486
+ rpty = & reconnectingPTY {
487
+ activeConns : make (map [string ]net.Conn ),
488
+ ptty : ptty ,
489
+ // Timeouts created with an after func can be reset!
490
+ timeout : time .AfterFunc (a .reconnectingPTYTimeout , cancelFunc ),
491
+ circularBuffer : circularBuffer ,
492
+ }
493
+ a .reconnectingPTYs .Store (id , rpty )
494
+ go func () {
495
+ // CommandContext isn't respected for Windows PTYs right now,
496
+ // so we need to manually track the lifecycle.
497
+ // When the context has been completed either:
498
+ // 1. The timeout completed.
499
+ // 2. The parent context was canceled.
500
+ <- ctx .Done ()
501
+ _ = process .Kill ()
502
+ }()
503
+ go func () {
504
+ // If the process dies randomly, we should
505
+ // close the pty.
506
+ _ , _ = process .Wait ()
507
+ rpty .Close ()
508
+ }()
509
+ go func () {
510
+ buffer := make ([]byte , 1024 )
511
+ for {
512
+ read , err := rpty .ptty .Output ().Read (buffer )
513
+ if err != nil {
514
+ // When the PTY is closed, this is triggered.
515
+ break
516
+ }
517
+ part := buffer [:read ]
518
+ _ , err = rpty .circularBuffer .Write (part )
519
+ if err != nil {
520
+ a .logger .Error (ctx , "reconnecting pty write buffer" , slog .Error (err ), slog .F ("id" , id ))
521
+ break
522
+ }
523
+ rpty .activeConnsMutex .Lock ()
524
+ for _ , conn := range rpty .activeConns {
525
+ _ , _ = conn .Write (part )
526
+ }
527
+ rpty .activeConnsMutex .Unlock ()
528
+ }
529
+
530
+ // Cleanup the process, PTY, and delete it's
531
+ // ID from memory.
532
+ _ = process .Kill ()
533
+ rpty .Close ()
534
+ a .reconnectingPTYs .Delete (id )
535
+ a .connCloseWait .Done ()
536
+ }()
537
+ }
538
+ // Resize the PTY to initial height + width.
539
+ err = rpty .ptty .Resize (uint16 (height ), uint16 (width ))
540
+ if err != nil {
541
+ // We can continue after this, it's not fatal!
542
+ a .logger .Error (ctx , "resize reconnecting pty" , slog .F ("id" , id ), slog .Error (err ))
543
+ }
544
+ // Write any previously stored data for the TTY.
545
+ _ , err = conn .Write (rpty .circularBuffer .Bytes ())
546
+ if err != nil {
547
+ a .logger .Warn (ctx , "write reconnecting pty buffer" , slog .F ("id" , id ), slog .Error (err ))
548
+ return
549
+ }
550
+ connectionID := uuid .NewString ()
551
+ // Multiple connections to the same TTY are permitted.
552
+ // This could easily be used for terminal sharing, but
553
+ // we do it because it's a nice user experience to
554
+ // copy/paste a terminal URL and have it _just work_.
555
+ rpty .activeConnsMutex .Lock ()
556
+ rpty .activeConns [connectionID ] = conn
557
+ rpty .activeConnsMutex .Unlock ()
558
+ // Resetting this timeout prevents the PTY from exiting.
559
+ rpty .timeout .Reset (a .reconnectingPTYTimeout )
560
+
561
+ ctx , cancelFunc := context .WithCancel (ctx )
562
+ defer cancelFunc ()
563
+ heartbeat := time .NewTicker (a .reconnectingPTYTimeout / 2 )
564
+ defer heartbeat .Stop ()
565
+ go func () {
566
+ // Keep updating the activity while this
567
+ // connection is alive!
568
+ for {
569
+ select {
570
+ case <- ctx .Done ():
571
+ return
572
+ case <- heartbeat .C :
573
+ }
574
+ rpty .timeout .Reset (a .reconnectingPTYTimeout )
575
+ }
576
+ }()
577
+ defer func () {
578
+ // After this connection ends, remove it from
579
+ // the PTYs active connections. If it isn't
580
+ // removed, all PTY data will be sent to it.
581
+ rpty .activeConnsMutex .Lock ()
582
+ delete (rpty .activeConns , connectionID )
583
+ rpty .activeConnsMutex .Unlock ()
584
+ }()
585
+ decoder := json .NewDecoder (conn )
586
+ var req ReconnectingPTYRequest
587
+ for {
588
+ err = decoder .Decode (& req )
589
+ if xerrors .Is (err , io .EOF ) {
590
+ return
591
+ }
592
+ if err != nil {
593
+ a .logger .Warn (ctx , "reconnecting pty buffer read error" , slog .F ("id" , id ), slog .Error (err ))
594
+ return
595
+ }
596
+ _ , err = rpty .ptty .Input ().Write ([]byte (req .Data ))
597
+ if err != nil {
598
+ a .logger .Warn (ctx , "write to reconnecting pty" , slog .F ("id" , id ), slog .Error (err ))
599
+ return
600
+ }
601
+ // Check if a resize needs to happen!
602
+ if req .Height == 0 || req .Width == 0 {
603
+ continue
604
+ }
605
+ err = rpty .ptty .Resize (req .Height , req .Width )
606
+ if err != nil {
607
+ // We can continue after this, it's not fatal!
608
+ a .logger .Error (ctx , "resize reconnecting pty" , slog .F ("id" , id ), slog .Error (err ))
609
+ }
610
+ }
611
+ }
612
+
392
613
// isClosed returns whether the API is closed or not.
393
614
func (a * agent ) isClosed () bool {
394
615
select {
@@ -411,3 +632,25 @@ func (a *agent) Close() error {
411
632
a .connCloseWait .Wait ()
412
633
return nil
413
634
}
635
+
636
+ type reconnectingPTY struct {
637
+ activeConnsMutex sync.Mutex
638
+ activeConns map [string ]net.Conn
639
+
640
+ circularBuffer * circbuf.Buffer
641
+ timeout * time.Timer
642
+ ptty pty.PTY
643
+ }
644
+
645
+ // Close ends all connections to the reconnecting
646
+ // PTY and clear the circular buffer.
647
+ func (r * reconnectingPTY ) Close () {
648
+ r .activeConnsMutex .Lock ()
649
+ defer r .activeConnsMutex .Unlock ()
650
+ for _ , conn := range r .activeConns {
651
+ _ = conn .Close ()
652
+ }
653
+ _ = r .ptty .Close ()
654
+ r .circularBuffer .Reset ()
655
+ r .timeout .Stop ()
656
+ }
0 commit comments