@@ -47,10 +47,16 @@ const (
47
47
)
48
48
49
49
type Server struct {
50
- serveWg sync.WaitGroup
51
- logger slog.Logger
50
+ mu sync.RWMutex // Protects following.
51
+ listeners map [net.Listener ]struct {}
52
+ conns map [net.Conn ]struct {}
53
+ closing chan struct {}
54
+ // Wait for goroutines to exit, waited without
55
+ // a lock on mu but protected by closing.
56
+ wg sync.WaitGroup
52
57
53
- srv * ssh.Server
58
+ logger slog.Logger
59
+ srv * ssh.Server
54
60
55
61
Env map [string ]string
56
62
AgentToken func () string
@@ -78,7 +84,9 @@ func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration
78
84
unixForwardHandler := & forwardedUnixHandler {log : logger }
79
85
80
86
s := & Server {
81
- logger : logger ,
87
+ listeners : make (map [net.Listener ]struct {}),
88
+ conns : make (map [net.Conn ]struct {}),
89
+ logger : logger ,
82
90
}
83
91
84
92
s .srv = & ssh.Server {
@@ -472,14 +480,118 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string)
472
480
}
473
481
474
482
func (s * Server ) Serve (l net.Listener ) error {
475
- s .serveWg .Add (1 )
476
- defer s .serveWg .Done ()
477
- return s .srv .Serve (l )
483
+ defer l .Close ()
484
+
485
+ s .trackListener (l , true )
486
+ defer s .trackListener (l , false )
487
+ for {
488
+ conn , err := l .Accept ()
489
+ if err != nil {
490
+ return err
491
+ }
492
+ go s .handleConn (l , conn )
493
+ }
478
494
}
479
495
496
+ func (s * Server ) handleConn (l net.Listener , c net.Conn ) {
497
+ defer c .Close ()
498
+
499
+ if ! s .trackConn (l , c , true ) {
500
+ // Server is closed or we no longer want
501
+ // connections from this listener.
502
+ s .logger .Debug (context .Background (), "received connection after server closed" )
503
+ return
504
+ }
505
+ defer s .trackConn (l , c , false )
506
+
507
+ s .srv .HandleConn (c )
508
+ }
509
+
510
+ // trackListener registers the listener with the server. If the server is
511
+ // closing, the function will block until the server is closed.
512
+ //
513
+ //nolint:revive
514
+ func (s * Server ) trackListener (l net.Listener , add bool ) {
515
+ s .mu .Lock ()
516
+ defer s .mu .Unlock ()
517
+ if add {
518
+ closing := s .closing
519
+ if closing != nil {
520
+ // Wait until close is complete before
521
+ // serving a new listener.
522
+ s .mu .Unlock ()
523
+ <- closing
524
+ s .mu .Lock ()
525
+ }
526
+ s .wg .Add (1 )
527
+ s .listeners [l ] = struct {}{}
528
+ return
529
+ }
530
+ s .wg .Done ()
531
+ delete (s .listeners , l )
532
+ }
533
+
534
+ // trackConn registers the connection with the server. If the server is
535
+ // closed or the listener is closed, the connection is not registered
536
+ // and should be closed.
537
+ //
538
+ //nolint:revive
539
+ func (s * Server ) trackConn (l net.Listener , c net.Conn , add bool ) (ok bool ) {
540
+ s .mu .Lock ()
541
+ defer s .mu .Unlock ()
542
+ if add {
543
+ found := false
544
+ for ll := range s .listeners {
545
+ if l == ll {
546
+ found = true
547
+ break
548
+ }
549
+ }
550
+ if s .closing != nil || ! found {
551
+ // Server or listener closed.
552
+ return false
553
+ }
554
+ s .wg .Add (1 )
555
+ s .conns [c ] = struct {}{}
556
+ return true
557
+ }
558
+ s .wg .Done ()
559
+ delete (s .conns , c )
560
+ return true
561
+ }
562
+
563
+ // Close the server and all active connections. Server can be re-used
564
+ // after Close is done.
480
565
func (s * Server ) Close () error {
566
+ s .mu .Lock ()
567
+
568
+ // Guard against multiple calls to Close and
569
+ // accepting new connections during close.
570
+ if s .closing != nil {
571
+ s .mu .Unlock ()
572
+ return xerrors .New ("server is closing" )
573
+ }
574
+ s .closing = make (chan struct {})
575
+
576
+ // Close all active listeners and connections.
577
+ for l := range s .listeners {
578
+ _ = l .Close ()
579
+ }
580
+ for c := range s .conns {
581
+ _ = c .Close ()
582
+ }
583
+
584
+ // Close the underlying SSH server.
481
585
err := s .srv .Close ()
482
- s .serveWg .Wait ()
586
+
587
+ s .mu .Unlock ()
588
+ s .wg .Wait () // Wait for all goroutines to exit.
589
+
590
+ s .mu .Lock ()
591
+ close (s .closing )
592
+ s .closing = nil
593
+ s .mu .Unlock ()
594
+
483
595
return err
484
596
}
485
597
0 commit comments