@@ -228,7 +228,7 @@ func (r *RootCmd) ssh() *clibase.Cmd {
228
228
if err != nil {
229
229
return xerrors .Errorf ("connect SSH: %w" , err )
230
230
}
231
- copier := & rawSSHCopier { conn : rawSSH , r : inv .Stdin , w : inv .Stdout }
231
+ copier := newRawSSHCopier ( logger , rawSSH , inv .Stdin , inv .Stdout )
232
232
if err = stack .push ("rawSSHCopier" , copier ); err != nil {
233
233
return err
234
234
}
@@ -853,9 +853,16 @@ type rawSSHCopier struct {
853
853
logger slog.Logger
854
854
r io.Reader
855
855
w io.Writer
856
+
857
+ done chan struct {}
858
+ }
859
+
860
+ func newRawSSHCopier (logger slog.Logger , conn * gonet.TCPConn , r io.Reader , w io.Writer ) * rawSSHCopier {
861
+ return & rawSSHCopier {conn : conn , logger : logger , r : r , w : w , done : make (chan struct {})}
856
862
}
857
863
858
864
func (c * rawSSHCopier ) copy (wg * sync.WaitGroup ) {
865
+ defer close (c .done )
859
866
logCtx := context .Background ()
860
867
wg .Add (1 )
861
868
go func () {
@@ -890,5 +897,16 @@ func (c *rawSSHCopier) copy(wg *sync.WaitGroup) {
890
897
}
891
898
892
899
func (c * rawSSHCopier ) Close () error {
893
- return c .conn .CloseWrite ()
900
+ err := c .conn .CloseWrite ()
901
+
902
+ // give the copy() call a chance to return on a timeout, so that we don't
903
+ // continue tearing down and close the underlying netstack before the SSH
904
+ // session has a chance to gracefully shut down.
905
+ t := time .NewTimer (5 * time .Second )
906
+ defer t .Stop ()
907
+ select {
908
+ case <- c .done :
909
+ case <- t .C :
910
+ }
911
+ return err
894
912
}
0 commit comments