4
4
"context"
5
5
"crypto/rand"
6
6
"crypto/rsa"
7
+ "encoding/json"
7
8
"errors"
8
9
"fmt"
9
10
"io"
@@ -12,10 +13,14 @@ import (
12
13
"os/exec"
13
14
"os/user"
14
15
"runtime"
16
+ "strconv"
15
17
"strings"
16
18
"sync"
17
19
"time"
18
20
21
+ "github.com/armon/circbuf"
22
+ "github.com/google/uuid"
23
+
19
24
gsyslog "github.com/hashicorp/go-syslog"
20
25
"go.uber.org/atomic"
21
26
@@ -33,6 +38,11 @@ import (
33
38
"golang.org/x/xerrors"
34
39
)
35
40
41
+ type Options struct {
42
+ ReconnectingPTYTimeout time.Duration
43
+ Logger slog.Logger
44
+ }
45
+
36
46
type Metadata struct {
37
47
OwnerEmail string `json:"owner_email"`
38
48
OwnerUsername string `json:"owner_username"`
@@ -42,13 +52,20 @@ type Metadata struct {
42
52
43
53
type Dialer func (ctx context.Context , logger slog.Logger ) (Metadata , * peerbroker.Listener , error )
44
54
45
- func New (dialer Dialer , logger slog.Logger ) io.Closer {
55
+ func New (dialer Dialer , options * Options ) io.Closer {
56
+ if options == nil {
57
+ options = & Options {}
58
+ }
59
+ if options .ReconnectingPTYTimeout == 0 {
60
+ options .ReconnectingPTYTimeout = 5 * time .Minute
61
+ }
46
62
ctx , cancelFunc := context .WithCancel (context .Background ())
47
63
server := & agent {
48
- dialer : dialer ,
49
- logger : logger ,
50
- closeCancel : cancelFunc ,
51
- closed : make (chan struct {}),
64
+ dialer : dialer ,
65
+ reconnectingPTYTimeout : options .ReconnectingPTYTimeout ,
66
+ logger : options .Logger ,
67
+ closeCancel : cancelFunc ,
68
+ closed : make (chan struct {}),
52
69
}
53
70
server .init (ctx )
54
71
return server
@@ -58,6 +75,9 @@ type agent struct {
58
75
dialer Dialer
59
76
logger slog.Logger
60
77
78
+ reconnectingPTYs sync.Map
79
+ reconnectingPTYTimeout time.Duration
80
+
61
81
connCloseWait sync.WaitGroup
62
82
closeCancel context.CancelFunc
63
83
closeMutex sync.Mutex
@@ -196,6 +216,8 @@ func (a *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) {
196
216
switch channel .Protocol () {
197
217
case "ssh" :
198
218
go a .sshServer .HandleConn (channel .NetConn ())
219
+ case "reconnecting-pty" :
220
+ go a .handleReconnectingPTY (ctx , channel .Label (), channel .NetConn ())
199
221
default :
200
222
a .logger .Warn (ctx , "unhandled protocol from channel" ,
201
223
slog .F ("protocol" , channel .Protocol ()),
@@ -282,22 +304,25 @@ func (a *agent) init(ctx context.Context) {
282
304
go a .run (ctx )
283
305
}
284
306
285
- func (a * agent ) handleSSHSession (session ssh.Session ) error {
307
+ // createCommand processes raw command input with OpenSSH-like behavior.
308
+ // If the rawCommand provided is empty, it will default to the users shell.
309
+ // This injects environment variables specified by the user at launch too.
310
+ func (a * agent ) createCommand (ctx context.Context , rawCommand string , env []string ) (* exec.Cmd , error ) {
286
311
currentUser , err := user .Current ()
287
312
if err != nil {
288
- return xerrors .Errorf ("get current user: %w" , err )
313
+ return nil , xerrors .Errorf ("get current user: %w" , err )
289
314
}
290
315
username := currentUser .Username
291
316
292
317
shell , err := usershell .Get (username )
293
318
if err != nil {
294
- return xerrors .Errorf ("get user shell: %w" , err )
319
+ return nil , xerrors .Errorf ("get user shell: %w" , err )
295
320
}
296
321
297
322
// gliderlabs/ssh returns a command slice of zero
298
323
// when a shell is requested.
299
- command := session . RawCommand ()
300
- if len (session . Command () ) == 0 {
324
+ command := rawCommand
325
+ if len (command ) == 0 {
301
326
command = shell
302
327
}
303
328
@@ -307,11 +332,11 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
307
332
if runtime .GOOS == "windows" {
308
333
caller = "/c"
309
334
}
310
- cmd := exec .CommandContext (session . Context () , shell , caller , command )
311
- cmd .Env = append (os .Environ (), session . Environ () ... )
335
+ cmd := exec .CommandContext (ctx , shell , caller , command )
336
+ cmd .Env = append (os .Environ (), env ... )
312
337
executablePath , err := os .Executable ()
313
338
if err != nil {
314
- return xerrors .Errorf ("getting os executable: %w" , err )
339
+ return nil , xerrors .Errorf ("getting os executable: %w" , err )
315
340
}
316
341
// Git on Windows resolves with UNIX-style paths.
317
342
// If using backslashes, it's unable to find the executable.
@@ -332,6 +357,14 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
332
357
}
333
358
}
334
359
}
360
+ return cmd , nil
361
+ }
362
+
363
+ func (a * agent ) handleSSHSession (session ssh.Session ) error {
364
+ cmd , err := a .createCommand (session .Context (), session .RawCommand (), session .Environ ())
365
+ if err != nil {
366
+ return err
367
+ }
335
368
336
369
sshPty , windowSize , isPty := session .Pty ()
337
370
if isPty {
@@ -381,6 +414,194 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
381
414
return cmd .Wait ()
382
415
}
383
416
417
+ func (a * agent ) handleReconnectingPTY (ctx context.Context , rawID string , conn net.Conn ) {
418
+ defer conn .Close ()
419
+
420
+ // The ID format is referenced in conn.go.
421
+ // <uuid>:<height>:<width>
422
+ idParts := strings .Split (rawID , ":" )
423
+ if len (idParts ) != 3 {
424
+ a .logger .Warn (ctx , "client sent invalid id format" , slog .F ("raw-id" , rawID ))
425
+ return
426
+ }
427
+ id := idParts [0 ]
428
+ // Enforce a consistent format for IDs.
429
+ _ , err := uuid .Parse (id )
430
+ if err != nil {
431
+ a .logger .Warn (ctx , "client sent reconnection token that isn't a uuid" , slog .F ("id" , id ), slog .Error (err ))
432
+ return
433
+ }
434
+ // Parse the initial terminal dimensions.
435
+ height , err := strconv .Atoi (idParts [1 ])
436
+ if err != nil {
437
+ a .logger .Warn (ctx , "client sent invalid height" , slog .F ("id" , id ), slog .F ("height" , idParts [1 ]))
438
+ return
439
+ }
440
+ width , err := strconv .Atoi (idParts [2 ])
441
+ if err != nil {
442
+ a .logger .Warn (ctx , "client sent invalid width" , slog .F ("id" , id ), slog .F ("width" , idParts [2 ]))
443
+ return
444
+ }
445
+
446
+ var rpty * reconnectingPTY
447
+ rawRPTY , ok := a .reconnectingPTYs .Load (id )
448
+ if ok {
449
+ rpty , ok = rawRPTY .(* reconnectingPTY )
450
+ if ! ok {
451
+ a .logger .Warn (ctx , "found invalid type in reconnecting pty map" , slog .F ("id" , id ))
452
+ }
453
+ } else {
454
+ // Empty command will default to the users shell!
455
+ cmd , err := a .createCommand (ctx , "" , nil )
456
+ if err != nil {
457
+ a .logger .Warn (ctx , "create reconnecting pty command" , slog .Error (err ))
458
+ return
459
+ }
460
+ cmd .Env = append (cmd .Env , "TERM=xterm-256color" )
461
+
462
+ ptty , process , err := pty .Start (cmd )
463
+ if err != nil {
464
+ a .logger .Warn (ctx , "start reconnecting pty command" , slog .F ("id" , id ))
465
+ }
466
+
467
+ // Default to buffer 64KB.
468
+ circularBuffer , err := circbuf .NewBuffer (64 * 1024 )
469
+ if err != nil {
470
+ a .logger .Warn (ctx , "create circular buffer" , slog .Error (err ))
471
+ return
472
+ }
473
+
474
+ a .closeMutex .Lock ()
475
+ a .connCloseWait .Add (1 )
476
+ a .closeMutex .Unlock ()
477
+ ctx , cancelFunc := context .WithCancel (ctx )
478
+ rpty = & reconnectingPTY {
479
+ activeConns : make (map [string ]net.Conn ),
480
+ ptty : ptty ,
481
+ // Timeouts created with an after func can be reset!
482
+ timeout : time .AfterFunc (a .reconnectingPTYTimeout , cancelFunc ),
483
+ circularBuffer : circularBuffer ,
484
+ }
485
+ a .reconnectingPTYs .Store (id , rpty )
486
+ go func () {
487
+ // CommandContext isn't respected for Windows PTYs right now,
488
+ // so we need to manually track the lifecycle.
489
+ // When the context has been completed either:
490
+ // 1. The timeout completed.
491
+ // 2. The parent context was canceled.
492
+ <- ctx .Done ()
493
+ _ = process .Kill ()
494
+ }()
495
+ go func () {
496
+ // If the process dies randomly, we should
497
+ // close the pty.
498
+ _ , _ = process .Wait ()
499
+ rpty .Close ()
500
+ }()
501
+ go func () {
502
+ buffer := make ([]byte , 1024 )
503
+ for {
504
+ read , err := rpty .ptty .Output ().Read (buffer )
505
+ if err != nil {
506
+ // When the PTY is closed, this is triggered.
507
+ break
508
+ }
509
+ part := buffer [:read ]
510
+ _ , err = rpty .circularBuffer .Write (part )
511
+ if err != nil {
512
+ a .logger .Error (ctx , "reconnecting pty write buffer" , slog .Error (err ), slog .F ("id" , id ))
513
+ break
514
+ }
515
+ rpty .activeConnsMutex .Lock ()
516
+ for _ , conn := range rpty .activeConns {
517
+ _ , _ = conn .Write (part )
518
+ }
519
+ rpty .activeConnsMutex .Unlock ()
520
+ }
521
+
522
+ // Cleanup the process, PTY, and delete it's
523
+ // ID from memory.
524
+ _ = process .Kill ()
525
+ rpty .Close ()
526
+ a .reconnectingPTYs .Delete (id )
527
+ a .connCloseWait .Done ()
528
+ }()
529
+ }
530
+ // Resize the PTY to initial height + width.
531
+ err = rpty .ptty .Resize (uint16 (height ), uint16 (width ))
532
+ if err != nil {
533
+ // We can continue after this, it's not fatal!
534
+ a .logger .Error (ctx , "resize reconnecting pty" , slog .F ("id" , id ), slog .Error (err ))
535
+ }
536
+ // Write any previously stored data for the TTY.
537
+ _ , err = conn .Write (rpty .circularBuffer .Bytes ())
538
+ if err != nil {
539
+ a .logger .Warn (ctx , "write reconnecting pty buffer" , slog .F ("id" , id ), slog .Error (err ))
540
+ return
541
+ }
542
+ connectionID := uuid .NewString ()
543
+ // Multiple connections to the same TTY are permitted.
544
+ // This could easily be used for terminal sharing, but
545
+ // we do it because it's a nice user experience to
546
+ // copy/paste a terminal URL and have it _just work_.
547
+ rpty .activeConnsMutex .Lock ()
548
+ rpty .activeConns [connectionID ] = conn
549
+ rpty .activeConnsMutex .Unlock ()
550
+ // Resetting this timeout prevents the PTY from exiting.
551
+ rpty .timeout .Reset (a .reconnectingPTYTimeout )
552
+
553
+ ctx , cancelFunc := context .WithCancel (ctx )
554
+ defer cancelFunc ()
555
+ heartbeat := time .NewTicker (a .reconnectingPTYTimeout / 2 )
556
+ defer heartbeat .Stop ()
557
+ go func () {
558
+ // Keep updating the activity while this
559
+ // connection is alive!
560
+ for {
561
+ select {
562
+ case <- ctx .Done ():
563
+ return
564
+ case <- heartbeat .C :
565
+ }
566
+ rpty .timeout .Reset (a .reconnectingPTYTimeout )
567
+ }
568
+ }()
569
+ defer func () {
570
+ // After this connection ends, remove it from
571
+ // the PTYs active connections. If it isn't
572
+ // removed, all PTY data will be sent to it.
573
+ rpty .activeConnsMutex .Lock ()
574
+ delete (rpty .activeConns , connectionID )
575
+ rpty .activeConnsMutex .Unlock ()
576
+ }()
577
+ decoder := json .NewDecoder (conn )
578
+ var req ReconnectingPTYRequest
579
+ for {
580
+ err = decoder .Decode (& req )
581
+ if xerrors .Is (err , io .EOF ) {
582
+ return
583
+ }
584
+ if err != nil {
585
+ a .logger .Warn (ctx , "reconnecting pty buffer read error" , slog .F ("id" , id ), slog .Error (err ))
586
+ return
587
+ }
588
+ _ , err = rpty .ptty .Input ().Write ([]byte (req .Data ))
589
+ if err != nil {
590
+ a .logger .Warn (ctx , "write to reconnecting pty" , slog .F ("id" , id ), slog .Error (err ))
591
+ return
592
+ }
593
+ // Check if a resize needs to happen!
594
+ if req .Height == 0 || req .Width == 0 {
595
+ continue
596
+ }
597
+ err = rpty .ptty .Resize (req .Height , req .Width )
598
+ if err != nil {
599
+ // We can continue after this, it's not fatal!
600
+ a .logger .Error (ctx , "resize reconnecting pty" , slog .F ("id" , id ), slog .Error (err ))
601
+ }
602
+ }
603
+ }
604
+
384
605
// isClosed returns whether the API is closed or not.
385
606
func (a * agent ) isClosed () bool {
386
607
select {
@@ -403,3 +624,25 @@ func (a *agent) Close() error {
403
624
a .connCloseWait .Wait ()
404
625
return nil
405
626
}
627
+
628
+ type reconnectingPTY struct {
629
+ activeConnsMutex sync.Mutex
630
+ activeConns map [string ]net.Conn
631
+
632
+ circularBuffer * circbuf.Buffer
633
+ timeout * time.Timer
634
+ ptty pty.PTY
635
+ }
636
+
637
+ // Close ends all connections to the reconnecting
638
+ // PTY and clear the circular buffer.
639
+ func (r * reconnectingPTY ) Close () {
640
+ r .activeConnsMutex .Lock ()
641
+ defer r .activeConnsMutex .Unlock ()
642
+ for _ , conn := range r .activeConns {
643
+ _ = conn .Close ()
644
+ }
645
+ _ = r .ptty .Close ()
646
+ r .circularBuffer .Reset ()
647
+ r .timeout .Stop ()
648
+ }
0 commit comments