Skip to content

Commit ad2ffdd

Browse files
committed
fix: close ssh sessions gracefully
1 parent c130f8d commit ad2ffdd

File tree

4 files changed

+343
-36
lines changed

4 files changed

+343
-36
lines changed

cli/ssh.go

Lines changed: 143 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ import (
1414
"sync"
1515
"time"
1616

17+
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
18+
1719
"github.com/gen2brain/beeep"
1820
"github.com/gofrs/flock"
1921
"github.com/google/uuid"
@@ -129,6 +131,8 @@ func (r *RootCmd) ssh() *clibase.Cmd {
129131
// log HTTP requests
130132
client.SetLogger(logger)
131133
}
134+
stack := newCloserStack(ctx, logger)
135+
defer stack.close(nil)
132136

133137
if remoteForward != "" {
134138
isValid := validateRemoteForward(remoteForward)
@@ -212,7 +216,9 @@ func (r *RootCmd) ssh() *clibase.Cmd {
212216
if err != nil {
213217
return xerrors.Errorf("dial agent: %w", err)
214218
}
215-
defer conn.Close()
219+
if err = stack.push("agent conn", conn); err != nil {
220+
return err
221+
}
216222
conn.AwaitReachable(ctx)
217223

218224
stopPolling := tryPollWorkspaceAutostop(ctx, client, workspace)
@@ -223,61 +229,61 @@ func (r *RootCmd) ssh() *clibase.Cmd {
223229
if err != nil {
224230
return xerrors.Errorf("connect SSH: %w", err)
225231
}
226-
defer rawSSH.Close()
232+
copier := &rawSSHCopier{conn: rawSSH, r: inv.Stdin, w: inv.Stdout}
233+
if err = stack.push("rawSSHCopier", copier); err != nil {
234+
return err
235+
}
227236

228237
wg.Add(1)
229238
go func() {
230239
defer wg.Done()
231240
watchAndClose(ctx, func() error {
232-
return rawSSH.Close()
241+
stack.close(xerrors.New("watchAndClose"))
242+
return nil
233243
}, logger, client, workspace)
234244
}()
235-
236-
wg.Add(1)
237-
go func() {
238-
defer wg.Done()
239-
// Ensure stdout copy closes incase stdin is closed
240-
// unexpectedly.
241-
defer rawSSH.Close()
242-
243-
_, err := io.Copy(rawSSH, inv.Stdin)
244-
if err != nil {
245-
logger.Error(ctx, "copy stdin error", slog.Error(err))
246-
} else {
247-
logger.Debug(ctx, "copy stdin complete")
248-
}
249-
}()
250-
_, err = io.Copy(inv.Stdout, rawSSH)
251-
if err != nil {
252-
logger.Error(ctx, "copy stdout error", slog.Error(err))
253-
} else {
254-
logger.Debug(ctx, "copy stdout complete")
255-
}
245+
copier.copy(&wg)
256246
return nil
257247
}
258248

249+
//rawSSH, err := conn.SSH(ctx)
250+
//if err != nil {
251+
// return xerrors.Errorf("connect SSH: %w", err)
252+
//}
253+
//defer rawSSH.CloseWrite()
259254
sshClient, err := conn.SSHClient(ctx)
260255
if err != nil {
261256
return xerrors.Errorf("ssh client: %w", err)
262257
}
263-
defer sshClient.Close()
258+
if err = stack.push("ssh client", sshClient); err != nil {
259+
return err
260+
}
261+
//sshConn, channels, requests, err := gossh.NewClientConn(rawSSH, "localhost:22", &gossh.ClientConfig{
262+
// // SSH host validation isn't helpful, because obtaining a peer
263+
// // connection already signifies user-intent to dial a workspace.
264+
// // #nosec
265+
// HostKeyCallback: gossh.InsecureIgnoreHostKey(),
266+
//})
267+
//if err != nil {
268+
// return xerrors.Errorf("ssh conn: %w", err)
269+
//}
270+
//sshClient := gossh.NewClient(sshConn, channels, requests)
264271

265272
sshSession, err := sshClient.NewSession()
266273
if err != nil {
267274
return xerrors.Errorf("ssh session: %w", err)
268275
}
269-
defer sshSession.Close()
276+
if err = stack.push("sshSession", sshSession); err != nil {
277+
return err
278+
}
270279

271280
wg.Add(1)
272281
go func() {
273282
defer wg.Done()
274283
watchAndClose(
275284
ctx,
276285
func() error {
277-
err := sshSession.Close()
278-
logger.Debug(ctx, "session close", slog.Error(err))
279-
err = sshClient.Close()
280-
logger.Debug(ctx, "client close", slog.Error(err))
286+
stack.close(xerrors.New("watchAndClose"))
281287
return nil
282288
},
283289
logger,
@@ -313,7 +319,9 @@ func (r *RootCmd) ssh() *clibase.Cmd {
313319
if err != nil {
314320
return xerrors.Errorf("forward GPG socket: %w", err)
315321
}
316-
defer closer.Close()
322+
if err = stack.push("forwardGPGAgent", closer); err != nil {
323+
return err
324+
}
317325
}
318326

319327
if remoteForward != "" {
@@ -326,7 +334,9 @@ func (r *RootCmd) ssh() *clibase.Cmd {
326334
if err != nil {
327335
return xerrors.Errorf("ssh remote forward: %w", err)
328336
}
329-
defer closer.Close()
337+
if err = stack.push("sshRemoteForward", closer); err != nil {
338+
return err
339+
}
330340
}
331341

332342
stdoutFile, validOut := inv.Stdout.(*os.File)
@@ -795,3 +805,104 @@ func remoteGPGAgentSocket(sshClient *gossh.Client) (string, error) {
795805

796806
return string(bytes.TrimSpace(remoteSocket)), nil
797807
}
808+
809+
type closerWithName struct {
810+
name string
811+
closer io.Closer
812+
}
813+
814+
type closerStack struct {
815+
sync.Mutex
816+
closers []closerWithName
817+
closed bool
818+
logger slog.Logger
819+
err error
820+
}
821+
822+
func newCloserStack(ctx context.Context, logger slog.Logger) *closerStack {
823+
cs := &closerStack{logger: logger}
824+
go cs.closeAfterContext(ctx)
825+
return cs
826+
}
827+
828+
func (c *closerStack) closeAfterContext(ctx context.Context) {
829+
<-ctx.Done()
830+
c.close(ctx.Err())
831+
}
832+
833+
func (c *closerStack) close(err error) {
834+
c.Lock()
835+
if c.closed {
836+
c.Unlock()
837+
return
838+
}
839+
c.closed = true
840+
c.err = err
841+
c.Unlock()
842+
843+
for i := len(c.closers) - 1; i >= 0; i-- {
844+
cwn := c.closers[i]
845+
cErr := cwn.closer.Close()
846+
c.logger.Debug(context.Background(),
847+
"closed item from stack", slog.F("name", cwn.name), slog.Error(cErr))
848+
}
849+
}
850+
851+
func (c *closerStack) push(name string, closer io.Closer) error {
852+
c.Lock()
853+
if c.closed {
854+
c.Unlock()
855+
// since we're refusing to push it on the stack, close it now
856+
err := closer.Close()
857+
c.logger.Error(context.Background(),
858+
"closed item rejected push", slog.F("name", name), slog.Error(err))
859+
return xerrors.Errorf("already closed: %w", c.err)
860+
}
861+
c.closers = append(c.closers, closerWithName{name: name, closer: closer})
862+
c.Unlock()
863+
return nil
864+
}
865+
866+
// rawSSHCopier handles copying raw SSH data between the conn and the pair (r, w).
867+
type rawSSHCopier struct {
868+
conn *gonet.TCPConn
869+
logger slog.Logger
870+
r io.Reader
871+
w io.Writer
872+
}
873+
874+
func (c *rawSSHCopier) copy(wg *sync.WaitGroup) {
875+
logCtx := context.Background()
876+
wg.Add(1)
877+
go func() {
878+
defer wg.Done()
879+
// We close connections using CloseWrite instead of Close, so that the SSH server sees the
880+
// closed connection while reading, and shuts down cleanly. This will trigger the io.Copy
881+
// in the server-to-client direction to also be closed and the copy() routine will exit.
882+
// This ensures that we don't leave any state in the server, like forwarded ports if
883+
// copy() were to return and the underlying tailnet connection torn down before the TCP
884+
// session exits. This is a bit of a hack to block shut down at the application layer, since
885+
// we can't serialize the TCP and tailnet layers shutting down.
886+
//
887+
// Of course, if the underlying transport is broken, io.Copy will still return.
888+
defer c.conn.CloseWrite()
889+
890+
_, err := io.Copy(c.conn, c.r)
891+
if err != nil {
892+
c.logger.Error(logCtx, "copy stdin error", slog.Error(err))
893+
} else {
894+
c.logger.Debug(logCtx, "copy stdin complete")
895+
}
896+
}()
897+
_, err := io.Copy(c.w, c.conn)
898+
if err != nil {
899+
c.logger.Error(logCtx, "copy stdout error", slog.Error(err))
900+
} else {
901+
c.logger.Debug(logCtx, "copy stdout complete")
902+
}
903+
return
904+
}
905+
906+
func (c *rawSSHCopier) Close() error {
907+
return c.conn.CloseWrite()
908+
}

cli/ssh_internal_test.go

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
11
package cli
22

33
import (
4+
"context"
45
"net/url"
56
"testing"
67

8+
"golang.org/x/xerrors"
9+
10+
"cdr.dev/slog"
11+
"cdr.dev/slog/sloggers/slogtest"
12+
"github.com/coder/coder/v2/testutil"
13+
714
"github.com/stretchr/testify/assert"
815
"github.com/stretchr/testify/require"
916

@@ -56,3 +63,74 @@ func TestBuildWorkspaceLink(t *testing.T) {
5663

5764
assert.Equal(t, workspaceLink.String(), fakeServerURL+"/@"+fakeOwnerName+"/"+fakeWorkspaceName)
5865
}
66+
67+
func TestCloserStack_Mainline(t *testing.T) {
68+
ctx := testutil.Context(t, testutil.WaitShort)
69+
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
70+
uut := newCloserStack(ctx, logger)
71+
closes := new([]*fakeCloser)
72+
fc0 := &fakeCloser{closes: closes}
73+
fc1 := &fakeCloser{closes: closes}
74+
75+
func() {
76+
defer uut.close(nil)
77+
err := uut.push("fc0", fc0)
78+
require.NoError(t, err)
79+
err = uut.push("fc1", fc1)
80+
require.NoError(t, err)
81+
}()
82+
// order reversed
83+
require.Equal(t, []*fakeCloser{fc1, fc0}, *closes)
84+
}
85+
86+
func TestCloserStack_Context(t *testing.T) {
87+
ctx := testutil.Context(t, testutil.WaitShort)
88+
ctx, cancel := context.WithCancel(ctx)
89+
defer cancel()
90+
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
91+
uut := newCloserStack(ctx, logger)
92+
closes := new([]*fakeCloser)
93+
fc0 := &fakeCloser{closes: closes}
94+
fc1 := &fakeCloser{closes: closes}
95+
96+
err := uut.push("fc0", fc0)
97+
require.NoError(t, err)
98+
err = uut.push("fc1", fc1)
99+
require.NoError(t, err)
100+
cancel()
101+
require.Eventually(t, func() bool {
102+
uut.Lock()
103+
defer uut.Unlock()
104+
return uut.closed
105+
}, testutil.WaitShort, testutil.IntervalFast)
106+
}
107+
108+
func TestCloserStack_PushAfterClose(t *testing.T) {
109+
ctx := testutil.Context(t, testutil.WaitShort)
110+
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
111+
uut := newCloserStack(ctx, logger)
112+
closes := new([]*fakeCloser)
113+
fc0 := &fakeCloser{closes: closes}
114+
fc1 := &fakeCloser{closes: closes}
115+
116+
err := uut.push("fc0", fc0)
117+
require.NoError(t, err)
118+
119+
exErr := xerrors.New("test")
120+
uut.close(exErr)
121+
require.Equal(t, []*fakeCloser{fc0}, *closes)
122+
123+
err = uut.push("fc1", fc1)
124+
require.ErrorIs(t, err, exErr)
125+
require.Equal(t, []*fakeCloser{fc1, fc0}, *closes, "should close fc1")
126+
}
127+
128+
type fakeCloser struct {
129+
closes *[]*fakeCloser
130+
err error
131+
}
132+
133+
func (c *fakeCloser) Close() error {
134+
*c.closes = append(*c.closes, c)
135+
return c.err
136+
}

0 commit comments

Comments
 (0)