@@ -22,6 +22,7 @@ import (
22
22
gosshagent "golang.org/x/crypto/ssh/agent"
23
23
"golang.org/x/term"
24
24
"golang.org/x/xerrors"
25
+ "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
25
26
26
27
"cdr.dev/slog"
27
28
"cdr.dev/slog/sloggers/sloghuman"
@@ -129,6 +130,8 @@ func (r *RootCmd) ssh() *clibase.Cmd {
129
130
// log HTTP requests
130
131
client .SetLogger (logger )
131
132
}
133
+ stack := newCloserStack (ctx , logger )
134
+ defer stack .close (nil )
132
135
133
136
if remoteForward != "" {
134
137
isValid := validateRemoteForward (remoteForward )
@@ -212,7 +215,9 @@ func (r *RootCmd) ssh() *clibase.Cmd {
212
215
if err != nil {
213
216
return xerrors .Errorf ("dial agent: %w" , err )
214
217
}
215
- defer conn .Close ()
218
+ if err = stack .push ("agent conn" , conn ); err != nil {
219
+ return err
220
+ }
216
221
conn .AwaitReachable (ctx )
217
222
218
223
stopPolling := tryPollWorkspaceAutostop (ctx , client , workspace )
@@ -223,61 +228,46 @@ func (r *RootCmd) ssh() *clibase.Cmd {
223
228
if err != nil {
224
229
return xerrors .Errorf ("connect SSH: %w" , err )
225
230
}
226
- defer rawSSH .Close ()
231
+ copier := & rawSSHCopier {conn : rawSSH , r : inv .Stdin , w : inv .Stdout }
232
+ if err = stack .push ("rawSSHCopier" , copier ); err != nil {
233
+ return err
234
+ }
227
235
228
236
wg .Add (1 )
229
237
go func () {
230
238
defer wg .Done ()
231
239
watchAndClose (ctx , func () error {
232
- return rawSSH .Close ()
240
+ stack .close (xerrors .New ("watchAndClose" ))
241
+ return nil
233
242
}, logger , client , workspace )
234
243
}()
235
-
236
- wg .Add (1 )
237
- go func () {
238
- defer wg .Done ()
239
- // Ensure stdout copy closes incase stdin is closed
240
- // unexpectedly.
241
- defer rawSSH .Close ()
242
-
243
- _ , err := io .Copy (rawSSH , inv .Stdin )
244
- if err != nil {
245
- logger .Error (ctx , "copy stdin error" , slog .Error (err ))
246
- } else {
247
- logger .Debug (ctx , "copy stdin complete" )
248
- }
249
- }()
250
- _ , err = io .Copy (inv .Stdout , rawSSH )
251
- if err != nil {
252
- logger .Error (ctx , "copy stdout error" , slog .Error (err ))
253
- } else {
254
- logger .Debug (ctx , "copy stdout complete" )
255
- }
244
+ copier .copy (& wg )
256
245
return nil
257
246
}
258
247
259
248
sshClient , err := conn .SSHClient (ctx )
260
249
if err != nil {
261
250
return xerrors .Errorf ("ssh client: %w" , err )
262
251
}
263
- defer sshClient .Close ()
252
+ if err = stack .push ("ssh client" , sshClient ); err != nil {
253
+ return err
254
+ }
264
255
265
256
sshSession , err := sshClient .NewSession ()
266
257
if err != nil {
267
258
return xerrors .Errorf ("ssh session: %w" , err )
268
259
}
269
- defer sshSession .Close ()
260
+ if err = stack .push ("sshSession" , sshSession ); err != nil {
261
+ return err
262
+ }
270
263
271
264
wg .Add (1 )
272
265
go func () {
273
266
defer wg .Done ()
274
267
watchAndClose (
275
268
ctx ,
276
269
func () error {
277
- err := sshSession .Close ()
278
- logger .Debug (ctx , "session close" , slog .Error (err ))
279
- err = sshClient .Close ()
280
- logger .Debug (ctx , "client close" , slog .Error (err ))
270
+ stack .close (xerrors .New ("watchAndClose" ))
281
271
return nil
282
272
},
283
273
logger ,
@@ -313,7 +303,9 @@ func (r *RootCmd) ssh() *clibase.Cmd {
313
303
if err != nil {
314
304
return xerrors .Errorf ("forward GPG socket: %w" , err )
315
305
}
316
- defer closer .Close ()
306
+ if err = stack .push ("forwardGPGAgent" , closer ); err != nil {
307
+ return err
308
+ }
317
309
}
318
310
319
311
if remoteForward != "" {
@@ -326,7 +318,9 @@ func (r *RootCmd) ssh() *clibase.Cmd {
326
318
if err != nil {
327
319
return xerrors .Errorf ("ssh remote forward: %w" , err )
328
320
}
329
- defer closer .Close ()
321
+ if err = stack .push ("sshRemoteForward" , closer ); err != nil {
322
+ return err
323
+ }
330
324
}
331
325
332
326
stdoutFile , validOut := inv .Stdout .(* os.File )
@@ -795,3 +789,106 @@ func remoteGPGAgentSocket(sshClient *gossh.Client) (string, error) {
795
789
796
790
return string (bytes .TrimSpace (remoteSocket )), nil
797
791
}
792
+
793
+ type closerWithName struct {
794
+ name string
795
+ closer io.Closer
796
+ }
797
+
798
+ type closerStack struct {
799
+ sync.Mutex
800
+ closers []closerWithName
801
+ closed bool
802
+ logger slog.Logger
803
+ err error
804
+ }
805
+
806
+ func newCloserStack (ctx context.Context , logger slog.Logger ) * closerStack {
807
+ cs := & closerStack {logger : logger }
808
+ go cs .closeAfterContext (ctx )
809
+ return cs
810
+ }
811
+
812
+ func (c * closerStack ) closeAfterContext (ctx context.Context ) {
813
+ <- ctx .Done ()
814
+ c .close (ctx .Err ())
815
+ }
816
+
817
+ func (c * closerStack ) close (err error ) {
818
+ c .Lock ()
819
+ if c .closed {
820
+ c .Unlock ()
821
+ return
822
+ }
823
+ c .closed = true
824
+ c .err = err
825
+ c .Unlock ()
826
+
827
+ for i := len (c .closers ) - 1 ; i >= 0 ; i -- {
828
+ cwn := c .closers [i ]
829
+ cErr := cwn .closer .Close ()
830
+ c .logger .Debug (context .Background (),
831
+ "closed item from stack" , slog .F ("name" , cwn .name ), slog .Error (cErr ))
832
+ }
833
+ }
834
+
835
+ func (c * closerStack ) push (name string , closer io.Closer ) error {
836
+ c .Lock ()
837
+ if c .closed {
838
+ c .Unlock ()
839
+ // since we're refusing to push it on the stack, close it now
840
+ err := closer .Close ()
841
+ c .logger .Error (context .Background (),
842
+ "closed item rejected push" , slog .F ("name" , name ), slog .Error (err ))
843
+ return xerrors .Errorf ("already closed: %w" , c .err )
844
+ }
845
+ c .closers = append (c .closers , closerWithName {name : name , closer : closer })
846
+ c .Unlock ()
847
+ return nil
848
+ }
849
+
850
+ // rawSSHCopier handles copying raw SSH data between the conn and the pair (r, w).
851
+ type rawSSHCopier struct {
852
+ conn * gonet.TCPConn
853
+ logger slog.Logger
854
+ r io.Reader
855
+ w io.Writer
856
+ }
857
+
858
+ func (c * rawSSHCopier ) copy (wg * sync.WaitGroup ) {
859
+ logCtx := context .Background ()
860
+ wg .Add (1 )
861
+ go func () {
862
+ defer wg .Done ()
863
+ // We close connections using CloseWrite instead of Close, so that the SSH server sees the
864
+ // closed connection while reading, and shuts down cleanly. This will trigger the io.Copy
865
+ // in the server-to-client direction to also be closed and the copy() routine will exit.
866
+ // This ensures that we don't leave any state in the server, like forwarded ports if
867
+ // copy() were to return and the underlying tailnet connection torn down before the TCP
868
+ // session exits. This is a bit of a hack to block shut down at the application layer, since
869
+ // we can't serialize the TCP and tailnet layers shutting down.
870
+ //
871
+ // Of course, if the underlying transport is broken, io.Copy will still return.
872
+ defer func () {
873
+ cwErr := c .conn .CloseWrite ()
874
+ c .logger .Debug (logCtx , "closed raw SSH connection for writing" , slog .Error (cwErr ))
875
+ }()
876
+
877
+ _ , err := io .Copy (c .conn , c .r )
878
+ if err != nil {
879
+ c .logger .Error (logCtx , "copy stdin error" , slog .Error (err ))
880
+ } else {
881
+ c .logger .Debug (logCtx , "copy stdin complete" )
882
+ }
883
+ }()
884
+ _ , err := io .Copy (c .w , c .conn )
885
+ if err != nil {
886
+ c .logger .Error (logCtx , "copy stdout error" , slog .Error (err ))
887
+ } else {
888
+ c .logger .Debug (logCtx , "copy stdout complete" )
889
+ }
890
+ }
891
+
892
+ func (c * rawSSHCopier ) Close () error {
893
+ return c .conn .CloseWrite ()
894
+ }
0 commit comments