From e1a3c1e9d23dcb1c139b9fc29fb3e57ac3f63e60 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Tue, 21 Nov 2023 16:17:33 +0400 Subject: [PATCH] fix: give SSH stdio sessions a chance to close before closing netstack --- cli/ssh.go | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/cli/ssh.go b/cli/ssh.go index f2844b3c55498..0c4b537949806 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -228,7 +228,7 @@ func (r *RootCmd) ssh() *clibase.Cmd { if err != nil { return xerrors.Errorf("connect SSH: %w", err) } - copier := &rawSSHCopier{conn: rawSSH, r: inv.Stdin, w: inv.Stdout} + copier := newRawSSHCopier(logger, rawSSH, inv.Stdin, inv.Stdout) if err = stack.push("rawSSHCopier", copier); err != nil { return err } @@ -853,9 +853,16 @@ type rawSSHCopier struct { logger slog.Logger r io.Reader w io.Writer + + done chan struct{} +} + +func newRawSSHCopier(logger slog.Logger, conn *gonet.TCPConn, r io.Reader, w io.Writer) *rawSSHCopier { + return &rawSSHCopier{conn: conn, logger: logger, r: r, w: w, done: make(chan struct{})} } func (c *rawSSHCopier) copy(wg *sync.WaitGroup) { + defer close(c.done) logCtx := context.Background() wg.Add(1) go func() { @@ -890,5 +897,16 @@ func (c *rawSSHCopier) copy(wg *sync.WaitGroup) { } func (c *rawSSHCopier) Close() error { - return c.conn.CloseWrite() + err := c.conn.CloseWrite() + + // give the copy() call a chance to return on a timeout, so that we don't + // continue tearing down and close the underlying netstack before the SSH + // session has a chance to gracefully shut down. + t := time.NewTimer(5 * time.Second) + defer t.Stop() + select { + case <-c.done: + case <-t.C: + } + return err }