Skip to content

Commit b768d35

Browse files
committed
Improve handling of serve/close
1 parent e71ba85 commit b768d35

File tree

1 file changed

+120
-8
lines changed

1 file changed

+120
-8
lines changed

agent/agentssh/agentssh.go

+120-8
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,16 @@ const (
4747
)
4848

4949
type Server struct {
50-
serveWg sync.WaitGroup
51-
logger slog.Logger
50+
mu sync.RWMutex // Protects following.
51+
listeners map[net.Listener]struct{}
52+
conns map[net.Conn]struct{}
53+
closing chan struct{}
54+
// Wait for goroutines to exit, waited without
55+
// a lock on mu but protected by closing.
56+
wg sync.WaitGroup
5257

53-
srv *ssh.Server
58+
logger slog.Logger
59+
srv *ssh.Server
5460

5561
Env map[string]string
5662
AgentToken func() string
@@ -78,7 +84,9 @@ func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration
7884
unixForwardHandler := &forwardedUnixHandler{log: logger}
7985

8086
s := &Server{
81-
logger: logger,
87+
listeners: make(map[net.Listener]struct{}),
88+
conns: make(map[net.Conn]struct{}),
89+
logger: logger,
8290
}
8391

8492
s.srv = &ssh.Server{
@@ -472,14 +480,118 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string)
472480
}
473481

474482
func (s *Server) Serve(l net.Listener) error {
475-
s.serveWg.Add(1)
476-
defer s.serveWg.Done()
477-
return s.srv.Serve(l)
483+
defer l.Close()
484+
485+
s.trackListener(l, true)
486+
defer s.trackListener(l, false)
487+
for {
488+
conn, err := l.Accept()
489+
if err != nil {
490+
return err
491+
}
492+
go s.handleConn(l, conn)
493+
}
478494
}
479495

496+
func (s *Server) handleConn(l net.Listener, c net.Conn) {
497+
defer c.Close()
498+
499+
if !s.trackConn(l, c, true) {
500+
// Server is closed or we no longer want
501+
// connections from this listener.
502+
s.logger.Debug(context.Background(), "received connection after server closed")
503+
return
504+
}
505+
defer s.trackConn(l, c, false)
506+
507+
s.srv.HandleConn(c)
508+
}
509+
510+
// trackListener registers the listener with the server. If the server is
511+
// closing, the function will block until the server is closed.
512+
//
513+
//nolint:revive
514+
func (s *Server) trackListener(l net.Listener, add bool) {
515+
s.mu.Lock()
516+
defer s.mu.Unlock()
517+
if add {
518+
closing := s.closing
519+
if closing != nil {
520+
// Wait until close is complete before
521+
// serving a new listener.
522+
s.mu.Unlock()
523+
<-closing
524+
s.mu.Lock()
525+
}
526+
s.wg.Add(1)
527+
s.listeners[l] = struct{}{}
528+
return
529+
}
530+
s.wg.Done()
531+
delete(s.listeners, l)
532+
}
533+
534+
// trackConn registers the connection with the server. If the server is
535+
// closed or the listener is closed, the connection is not registered
536+
// and should be closed.
537+
//
538+
//nolint:revive
539+
func (s *Server) trackConn(l net.Listener, c net.Conn, add bool) (ok bool) {
540+
s.mu.Lock()
541+
defer s.mu.Unlock()
542+
if add {
543+
found := false
544+
for ll := range s.listeners {
545+
if l == ll {
546+
found = true
547+
break
548+
}
549+
}
550+
if s.closing != nil || !found {
551+
// Server or listener closed.
552+
return false
553+
}
554+
s.wg.Add(1)
555+
s.conns[c] = struct{}{}
556+
return true
557+
}
558+
s.wg.Done()
559+
delete(s.conns, c)
560+
return true
561+
}
562+
563+
// Close the server and all active connections. Server can be re-used
564+
// after Close is done.
480565
func (s *Server) Close() error {
566+
s.mu.Lock()
567+
568+
// Guard against multiple calls to Close and
569+
// accepting new connections during close.
570+
if s.closing != nil {
571+
s.mu.Unlock()
572+
return xerrors.New("server is closing")
573+
}
574+
s.closing = make(chan struct{})
575+
576+
// Close all active listeners and connections.
577+
for l := range s.listeners {
578+
_ = l.Close()
579+
}
580+
for c := range s.conns {
581+
_ = c.Close()
582+
}
583+
584+
// Close the underlying SSH server.
481585
err := s.srv.Close()
482-
s.serveWg.Wait()
586+
587+
s.mu.Unlock()
588+
s.wg.Wait() // Wait for all goroutines to exit.
589+
590+
s.mu.Lock()
591+
close(s.closing)
592+
s.closing = nil
593+
s.mu.Unlock()
594+
483595
return err
484596
}
485597

0 commit comments

Comments
 (0)