@@ -14,6 +14,8 @@ import (
14
14
"sync"
15
15
"time"
16
16
17
+ "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
18
+
17
19
"github.com/gen2brain/beeep"
18
20
"github.com/gofrs/flock"
19
21
"github.com/google/uuid"
@@ -129,6 +131,8 @@ func (r *RootCmd) ssh() *clibase.Cmd {
129
131
// log HTTP requests
130
132
client .SetLogger (logger )
131
133
}
134
+ stack := newCloserStack (ctx , logger )
135
+ defer stack .close (nil )
132
136
133
137
if remoteForward != "" {
134
138
isValid := validateRemoteForward (remoteForward )
@@ -212,7 +216,9 @@ func (r *RootCmd) ssh() *clibase.Cmd {
212
216
if err != nil {
213
217
return xerrors .Errorf ("dial agent: %w" , err )
214
218
}
215
- defer conn .Close ()
219
+ if err = stack .push ("agent conn" , conn ); err != nil {
220
+ return err
221
+ }
216
222
conn .AwaitReachable (ctx )
217
223
218
224
stopPolling := tryPollWorkspaceAutostop (ctx , client , workspace )
@@ -223,61 +229,61 @@ func (r *RootCmd) ssh() *clibase.Cmd {
223
229
if err != nil {
224
230
return xerrors .Errorf ("connect SSH: %w" , err )
225
231
}
226
- defer rawSSH .Close ()
232
+ copier := & rawSSHCopier {conn : rawSSH , r : inv .Stdin , w : inv .Stdout }
233
+ if err = stack .push ("rawSSHCopier" , copier ); err != nil {
234
+ return err
235
+ }
227
236
228
237
wg .Add (1 )
229
238
go func () {
230
239
defer wg .Done ()
231
240
watchAndClose (ctx , func () error {
232
- return rawSSH .Close ()
241
+ stack .close (xerrors .New ("watchAndClose" ))
242
+ return nil
233
243
}, logger , client , workspace )
234
244
}()
235
-
236
- wg .Add (1 )
237
- go func () {
238
- defer wg .Done ()
239
- // Ensure stdout copy closes incase stdin is closed
240
- // unexpectedly.
241
- defer rawSSH .Close ()
242
-
243
- _ , err := io .Copy (rawSSH , inv .Stdin )
244
- if err != nil {
245
- logger .Error (ctx , "copy stdin error" , slog .Error (err ))
246
- } else {
247
- logger .Debug (ctx , "copy stdin complete" )
248
- }
249
- }()
250
- _ , err = io .Copy (inv .Stdout , rawSSH )
251
- if err != nil {
252
- logger .Error (ctx , "copy stdout error" , slog .Error (err ))
253
- } else {
254
- logger .Debug (ctx , "copy stdout complete" )
255
- }
245
+ copier .copy (& wg )
256
246
return nil
257
247
}
258
248
249
+ //rawSSH, err := conn.SSH(ctx)
250
+ //if err != nil {
251
+ // return xerrors.Errorf("connect SSH: %w", err)
252
+ //}
253
+ //defer rawSSH.CloseWrite()
259
254
sshClient , err := conn .SSHClient (ctx )
260
255
if err != nil {
261
256
return xerrors .Errorf ("ssh client: %w" , err )
262
257
}
263
- defer sshClient .Close ()
258
+ if err = stack .push ("ssh client" , sshClient ); err != nil {
259
+ return err
260
+ }
261
+ //sshConn, channels, requests, err := gossh.NewClientConn(rawSSH, "localhost:22", &gossh.ClientConfig{
262
+ // // SSH host validation isn't helpful, because obtaining a peer
263
+ // // connection already signifies user-intent to dial a workspace.
264
+ // // #nosec
265
+ // HostKeyCallback: gossh.InsecureIgnoreHostKey(),
266
+ //})
267
+ //if err != nil {
268
+ // return xerrors.Errorf("ssh conn: %w", err)
269
+ //}
270
+ //sshClient := gossh.NewClient(sshConn, channels, requests)
264
271
265
272
sshSession , err := sshClient .NewSession ()
266
273
if err != nil {
267
274
return xerrors .Errorf ("ssh session: %w" , err )
268
275
}
269
- defer sshSession .Close ()
276
+ if err = stack .push ("sshSession" , sshSession ); err != nil {
277
+ return err
278
+ }
270
279
271
280
wg .Add (1 )
272
281
go func () {
273
282
defer wg .Done ()
274
283
watchAndClose (
275
284
ctx ,
276
285
func () error {
277
- err := sshSession .Close ()
278
- logger .Debug (ctx , "session close" , slog .Error (err ))
279
- err = sshClient .Close ()
280
- logger .Debug (ctx , "client close" , slog .Error (err ))
286
+ stack .close (xerrors .New ("watchAndClose" ))
281
287
return nil
282
288
},
283
289
logger ,
@@ -313,7 +319,9 @@ func (r *RootCmd) ssh() *clibase.Cmd {
313
319
if err != nil {
314
320
return xerrors .Errorf ("forward GPG socket: %w" , err )
315
321
}
316
- defer closer .Close ()
322
+ if err = stack .push ("forwardGPGAgent" , closer ); err != nil {
323
+ return err
324
+ }
317
325
}
318
326
319
327
if remoteForward != "" {
@@ -326,7 +334,9 @@ func (r *RootCmd) ssh() *clibase.Cmd {
326
334
if err != nil {
327
335
return xerrors .Errorf ("ssh remote forward: %w" , err )
328
336
}
329
- defer closer .Close ()
337
+ if err = stack .push ("sshRemoteForward" , closer ); err != nil {
338
+ return err
339
+ }
330
340
}
331
341
332
342
stdoutFile , validOut := inv .Stdout .(* os.File )
@@ -795,3 +805,104 @@ func remoteGPGAgentSocket(sshClient *gossh.Client) (string, error) {
795
805
796
806
return string (bytes .TrimSpace (remoteSocket )), nil
797
807
}
808
+
809
+ type closerWithName struct {
810
+ name string
811
+ closer io.Closer
812
+ }
813
+
814
+ type closerStack struct {
815
+ sync.Mutex
816
+ closers []closerWithName
817
+ closed bool
818
+ logger slog.Logger
819
+ err error
820
+ }
821
+
822
+ func newCloserStack (ctx context.Context , logger slog.Logger ) * closerStack {
823
+ cs := & closerStack {logger : logger }
824
+ go cs .closeAfterContext (ctx )
825
+ return cs
826
+ }
827
+
828
+ func (c * closerStack ) closeAfterContext (ctx context.Context ) {
829
+ <- ctx .Done ()
830
+ c .close (ctx .Err ())
831
+ }
832
+
833
+ func (c * closerStack ) close (err error ) {
834
+ c .Lock ()
835
+ if c .closed {
836
+ c .Unlock ()
837
+ return
838
+ }
839
+ c .closed = true
840
+ c .err = err
841
+ c .Unlock ()
842
+
843
+ for i := len (c .closers ) - 1 ; i >= 0 ; i -- {
844
+ cwn := c .closers [i ]
845
+ cErr := cwn .closer .Close ()
846
+ c .logger .Debug (context .Background (),
847
+ "closed item from stack" , slog .F ("name" , cwn .name ), slog .Error (cErr ))
848
+ }
849
+ }
850
+
851
+ func (c * closerStack ) push (name string , closer io.Closer ) error {
852
+ c .Lock ()
853
+ if c .closed {
854
+ c .Unlock ()
855
+ // since we're refusing to push it on the stack, close it now
856
+ err := closer .Close ()
857
+ c .logger .Error (context .Background (),
858
+ "closed item rejected push" , slog .F ("name" , name ), slog .Error (err ))
859
+ return xerrors .Errorf ("already closed: %w" , c .err )
860
+ }
861
+ c .closers = append (c .closers , closerWithName {name : name , closer : closer })
862
+ c .Unlock ()
863
+ return nil
864
+ }
865
+
866
+ // rawSSHCopier handles copying raw SSH data between the conn and the pair (r, w).
867
+ type rawSSHCopier struct {
868
+ conn * gonet.TCPConn
869
+ logger slog.Logger
870
+ r io.Reader
871
+ w io.Writer
872
+ }
873
+
874
+ func (c * rawSSHCopier ) copy (wg * sync.WaitGroup ) {
875
+ logCtx := context .Background ()
876
+ wg .Add (1 )
877
+ go func () {
878
+ defer wg .Done ()
879
+ // We close connections using CloseWrite instead of Close, so that the SSH server sees the
880
+ // closed connection while reading, and shuts down cleanly. This will trigger the io.Copy
881
+ // in the server-to-client direction to also be closed and the copy() routine will exit.
882
+ // This ensures that we don't leave any state in the server, like forwarded ports if
883
+ // copy() were to return and the underlying tailnet connection torn down before the TCP
884
+ // session exits. This is a bit of a hack to block shut down at the application layer, since
885
+ // we can't serialize the TCP and tailnet layers shutting down.
886
+ //
887
+ // Of course, if the underlying transport is broken, io.Copy will still return.
888
+ defer c .conn .CloseWrite ()
889
+
890
+ _ , err := io .Copy (c .conn , c .r )
891
+ if err != nil {
892
+ c .logger .Error (logCtx , "copy stdin error" , slog .Error (err ))
893
+ } else {
894
+ c .logger .Debug (logCtx , "copy stdin complete" )
895
+ }
896
+ }()
897
+ _ , err := io .Copy (c .w , c .conn )
898
+ if err != nil {
899
+ c .logger .Error (logCtx , "copy stdout error" , slog .Error (err ))
900
+ } else {
901
+ c .logger .Debug (logCtx , "copy stdout complete" )
902
+ }
903
+ return
904
+ }
905
+
906
+ func (c * rawSSHCopier ) Close () error {
907
+ return c .conn .CloseWrite ()
908
+ }
0 commit comments