diff --git a/server.go b/server.go index be4355e..b7e648b 100644 --- a/server.go +++ b/server.go @@ -73,6 +73,7 @@ type Server struct { conns map[*gossh.ServerConn]struct{} connWg sync.WaitGroup doneChan chan struct{} + closedAt time.Time } func (srv *Server) ensureHostSigner() error { @@ -180,6 +181,8 @@ func (srv *Server) Close() error { srv.mu.Lock() defer srv.mu.Unlock() + srv.closedAt = time.Now() + srv.closeDoneChanLocked() err := srv.closeListenersLocked() for c := range srv.conns { @@ -260,6 +263,17 @@ func (srv *Server) Serve(l net.Listener) error { } func (srv *Server) HandleConn(newConn net.Conn) { + // Best effort, track start time in case the server is closed + // between here and trackConn. If trackConn is called after the + // server is closed, this connection would leak. + // + // A better approach would be to have a clearer signal between the + // server being "started" and "closed", but that would require a + // larger refactor or change in logic. + srv.mu.RLock() + start := time.Now() + srv.mu.RUnlock() + ctx, cancel := newContext(srv) if srv.ConnCallback != nil { cbConn := srv.ConnCallback(ctx, newConn) @@ -289,6 +303,14 @@ func (srv *Server) HandleConn(newConn net.Conn) { srv.trackConn(sshConn, true) defer srv.trackConn(sshConn, false) + srv.mu.RLock() + if srv.closedAt.After(start) { + srv.mu.RUnlock() + sshConn.Close() + return + } + srv.mu.RUnlock() + ctx.SetValue(ContextKeyConn, sshConn) applyConnMetadata(ctx, sshConn) //go gossh.DiscardRequests(reqs) diff --git a/server_test.go b/server_test.go index 8028a3a..e7ae5fa 100644 --- a/server_test.go +++ b/server_test.go @@ -6,6 +6,8 @@ import ( "io" "testing" "time" + + gossh "golang.org/x/crypto/ssh" ) func TestAddHostKey(t *testing.T) { @@ -124,3 +126,104 @@ func TestServerClose(t *testing.T) { return } } + +func TestServerClose_ConnectionLeak(t *testing.T) { + l := newLocalListener() + s := &Server{ + Handler: func(s Session) { + time.Sleep(5 * time.Second) + }, + } + go func() { + err := s.Serve(l) + if err != nil && err != ErrServerClosed { + t.Error(err) + } + }() + + clientDoneChan := make(chan struct{}) + closeDoneChan := make(chan struct{}) + + num := 3 + ch := make(chan struct{}, num) + go func() { + for i := 0; i < num; i++ { + <-ch + } + close(clientDoneChan) + }() + prepare := make(chan struct{}, num) + go func() { + for i := 0; i < num; i++ { + go func() { + defer func() { + ch <- struct{}{} + }() + sess, _, cleanup, err := newClientSession2(t, l.Addr().String(), nil) + prepare <- struct{}{} + if err != nil { + t.Log(err) + return + } + defer cleanup() + if err := sess.Run(""); err != nil && err != io.EOF { + t.Log(err) + } + }() + } + }() + + go func() { + for i := 0; i < num-1; i++ { + <-prepare + } + err := s.Close() + if err != nil { + t.Error(err) + } + close(closeDoneChan) + }() + + timeout := time.After(1000 * time.Millisecond) + select { + case <-timeout: + t.Error("timeout") + return + case <-closeDoneChan: + } + select { + case <-timeout: + t.Error("timeout") + return + case <-clientDoneChan: + } +} + +func newClientSession2(t *testing.T, addr string, config *gossh.ClientConfig) (*gossh.Session, *gossh.Client, func(), error) { + t.Helper() + + if config == nil { + config = &gossh.ClientConfig{ + User: "testuser", + Auth: []gossh.AuthMethod{ + gossh.Password("testpass"), + }, + } + } + if config.HostKeyCallback == nil { + config.HostKeyCallback = gossh.InsecureIgnoreHostKey() + } + client, err := gossh.Dial("tcp", addr, config) + if err != nil { + return nil, nil, nil, err + } + session, err := client.NewSession() + if err != nil { + client.Close() + return nil, nil, nil, err + } + return session, client, func() { + session.Close() + client.Close() + }, nil +} diff --git a/session_test.go b/session_test.go index c6ce617..3fdb828 100644 --- a/session_test.go +++ b/session_test.go @@ -230,9 +230,9 @@ func TestPty(t *testing.T) { func TestPtyResize(t *testing.T) { t.Parallel() - winch0 := Window{40, 80} - winch1 := Window{80, 160} - winch2 := Window{20, 40} + winch0 := Window{Width: 40, Height: 80} + winch1 := Window{Width: 80, Height: 160} + winch2 := Window{Width: 20, Height: 40} winches := make(chan Window) done := make(chan bool) session, _, cleanup := newTestSession(t, &Server{