Skip to content

Commit 3dd35e0

Browse files
authored
fix: close ssh sessions gracefully (coder#10732)
Re-enables TestSSH/RemoteForward_Unix_Signal and addresses the underlying race: we were not closing the remote forward on context expiry, only the session and connection. However, there is still a more fundamental issue: we have no way to ensure that TCP sessions are properly terminated before tearing down the Tailnet conn. This is due to the assumption in the sockets API that the underlying IP interface is long-lived compared with the TCP socket, and thus closing a socket returns immediately and does not wait for the TCP termination handshake — that is handled asynchronously in the tcpip stack. This assumption does not hold for us and tailnet, since on shutdown we also tear down the tailnet connection, and this can race with the TCP termination. Closing the remote forward explicitly should prevent forward state from accumulating, since the Close() function waits for a reply from the remote SSH server. I've also attempted to work around the TCP/tailnet issue for `--stdio` by using `CloseWrite()` instead of `Close()`. By closing the write side of the connection, we half-close the TCP connection; the server detects this and closes the other direction, which then triggers our read loop to exit only after the server has had a chance to process the close. A TODO for a stacked PR is to implement this logic for `vscodessh` as well.
1 parent ba955f4 commit 3dd35e0

File tree

4 files changed

+331
-36
lines changed

4 files changed

+331
-36
lines changed

cli/ssh.go

Lines changed: 129 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
gosshagent "golang.org/x/crypto/ssh/agent"
2323
"golang.org/x/term"
2424
"golang.org/x/xerrors"
25+
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
2526

2627
"cdr.dev/slog"
2728
"cdr.dev/slog/sloggers/sloghuman"
@@ -129,6 +130,8 @@ func (r *RootCmd) ssh() *clibase.Cmd {
129130
// log HTTP requests
130131
client.SetLogger(logger)
131132
}
133+
stack := newCloserStack(ctx, logger)
134+
defer stack.close(nil)
132135

133136
if remoteForward != "" {
134137
isValid := validateRemoteForward(remoteForward)
@@ -212,7 +215,9 @@ func (r *RootCmd) ssh() *clibase.Cmd {
212215
if err != nil {
213216
return xerrors.Errorf("dial agent: %w", err)
214217
}
215-
defer conn.Close()
218+
if err = stack.push("agent conn", conn); err != nil {
219+
return err
220+
}
216221
conn.AwaitReachable(ctx)
217222

218223
stopPolling := tryPollWorkspaceAutostop(ctx, client, workspace)
@@ -223,61 +228,46 @@ func (r *RootCmd) ssh() *clibase.Cmd {
223228
if err != nil {
224229
return xerrors.Errorf("connect SSH: %w", err)
225230
}
226-
defer rawSSH.Close()
231+
copier := &rawSSHCopier{conn: rawSSH, r: inv.Stdin, w: inv.Stdout}
232+
if err = stack.push("rawSSHCopier", copier); err != nil {
233+
return err
234+
}
227235

228236
wg.Add(1)
229237
go func() {
230238
defer wg.Done()
231239
watchAndClose(ctx, func() error {
232-
return rawSSH.Close()
240+
stack.close(xerrors.New("watchAndClose"))
241+
return nil
233242
}, logger, client, workspace)
234243
}()
235-
236-
wg.Add(1)
237-
go func() {
238-
defer wg.Done()
239-
// Ensure stdout copy closes incase stdin is closed
240-
// unexpectedly.
241-
defer rawSSH.Close()
242-
243-
_, err := io.Copy(rawSSH, inv.Stdin)
244-
if err != nil {
245-
logger.Error(ctx, "copy stdin error", slog.Error(err))
246-
} else {
247-
logger.Debug(ctx, "copy stdin complete")
248-
}
249-
}()
250-
_, err = io.Copy(inv.Stdout, rawSSH)
251-
if err != nil {
252-
logger.Error(ctx, "copy stdout error", slog.Error(err))
253-
} else {
254-
logger.Debug(ctx, "copy stdout complete")
255-
}
244+
copier.copy(&wg)
256245
return nil
257246
}
258247

259248
sshClient, err := conn.SSHClient(ctx)
260249
if err != nil {
261250
return xerrors.Errorf("ssh client: %w", err)
262251
}
263-
defer sshClient.Close()
252+
if err = stack.push("ssh client", sshClient); err != nil {
253+
return err
254+
}
264255

265256
sshSession, err := sshClient.NewSession()
266257
if err != nil {
267258
return xerrors.Errorf("ssh session: %w", err)
268259
}
269-
defer sshSession.Close()
260+
if err = stack.push("sshSession", sshSession); err != nil {
261+
return err
262+
}
270263

271264
wg.Add(1)
272265
go func() {
273266
defer wg.Done()
274267
watchAndClose(
275268
ctx,
276269
func() error {
277-
err := sshSession.Close()
278-
logger.Debug(ctx, "session close", slog.Error(err))
279-
err = sshClient.Close()
280-
logger.Debug(ctx, "client close", slog.Error(err))
270+
stack.close(xerrors.New("watchAndClose"))
281271
return nil
282272
},
283273
logger,
@@ -313,7 +303,9 @@ func (r *RootCmd) ssh() *clibase.Cmd {
313303
if err != nil {
314304
return xerrors.Errorf("forward GPG socket: %w", err)
315305
}
316-
defer closer.Close()
306+
if err = stack.push("forwardGPGAgent", closer); err != nil {
307+
return err
308+
}
317309
}
318310

319311
if remoteForward != "" {
@@ -326,7 +318,9 @@ func (r *RootCmd) ssh() *clibase.Cmd {
326318
if err != nil {
327319
return xerrors.Errorf("ssh remote forward: %w", err)
328320
}
329-
defer closer.Close()
321+
if err = stack.push("sshRemoteForward", closer); err != nil {
322+
return err
323+
}
330324
}
331325

332326
stdoutFile, validOut := inv.Stdout.(*os.File)
@@ -795,3 +789,106 @@ func remoteGPGAgentSocket(sshClient *gossh.Client) (string, error) {
795789

796790
return string(bytes.TrimSpace(remoteSocket)), nil
797791
}
792+
793+
type closerWithName struct {
794+
name string
795+
closer io.Closer
796+
}
797+
798+
type closerStack struct {
799+
sync.Mutex
800+
closers []closerWithName
801+
closed bool
802+
logger slog.Logger
803+
err error
804+
}
805+
806+
func newCloserStack(ctx context.Context, logger slog.Logger) *closerStack {
807+
cs := &closerStack{logger: logger}
808+
go cs.closeAfterContext(ctx)
809+
return cs
810+
}
811+
812+
func (c *closerStack) closeAfterContext(ctx context.Context) {
813+
<-ctx.Done()
814+
c.close(ctx.Err())
815+
}
816+
817+
func (c *closerStack) close(err error) {
818+
c.Lock()
819+
if c.closed {
820+
c.Unlock()
821+
return
822+
}
823+
c.closed = true
824+
c.err = err
825+
c.Unlock()
826+
827+
for i := len(c.closers) - 1; i >= 0; i-- {
828+
cwn := c.closers[i]
829+
cErr := cwn.closer.Close()
830+
c.logger.Debug(context.Background(),
831+
"closed item from stack", slog.F("name", cwn.name), slog.Error(cErr))
832+
}
833+
}
834+
835+
func (c *closerStack) push(name string, closer io.Closer) error {
836+
c.Lock()
837+
if c.closed {
838+
c.Unlock()
839+
// since we're refusing to push it on the stack, close it now
840+
err := closer.Close()
841+
c.logger.Error(context.Background(),
842+
"closed item rejected push", slog.F("name", name), slog.Error(err))
843+
return xerrors.Errorf("already closed: %w", c.err)
844+
}
845+
c.closers = append(c.closers, closerWithName{name: name, closer: closer})
846+
c.Unlock()
847+
return nil
848+
}
849+
850+
// rawSSHCopier handles copying raw SSH data between the conn and the pair (r, w).
851+
type rawSSHCopier struct {
852+
conn *gonet.TCPConn
853+
logger slog.Logger
854+
r io.Reader
855+
w io.Writer
856+
}
857+
858+
func (c *rawSSHCopier) copy(wg *sync.WaitGroup) {
859+
logCtx := context.Background()
860+
wg.Add(1)
861+
go func() {
862+
defer wg.Done()
863+
// We close connections using CloseWrite instead of Close, so that the SSH server sees the
864+
// closed connection while reading, and shuts down cleanly. This will trigger the io.Copy
865+
// in the server-to-client direction to also be closed and the copy() routine will exit.
866+
// This ensures that we don't leave any state in the server, like forwarded ports if
867+
// copy() were to return and the underlying tailnet connection torn down before the TCP
868+
// session exits. This is a bit of a hack to block shut down at the application layer, since
869+
// we can't serialize the TCP and tailnet layers shutting down.
870+
//
871+
// Of course, if the underlying transport is broken, io.Copy will still return.
872+
defer func() {
873+
cwErr := c.conn.CloseWrite()
874+
c.logger.Debug(logCtx, "closed raw SSH connection for writing", slog.Error(cwErr))
875+
}()
876+
877+
_, err := io.Copy(c.conn, c.r)
878+
if err != nil {
879+
c.logger.Error(logCtx, "copy stdin error", slog.Error(err))
880+
} else {
881+
c.logger.Debug(logCtx, "copy stdin complete")
882+
}
883+
}()
884+
_, err := io.Copy(c.w, c.conn)
885+
if err != nil {
886+
c.logger.Error(logCtx, "copy stdout error", slog.Error(err))
887+
} else {
888+
c.logger.Debug(logCtx, "copy stdout complete")
889+
}
890+
}
891+
892+
func (c *rawSSHCopier) Close() error {
893+
return c.conn.CloseWrite()
894+
}

cli/ssh_internal_test.go

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
11
package cli
22

33
import (
4+
"context"
45
"net/url"
56
"testing"
67

8+
"golang.org/x/xerrors"
9+
10+
"cdr.dev/slog"
11+
"cdr.dev/slog/sloggers/slogtest"
12+
"github.com/coder/coder/v2/testutil"
13+
714
"github.com/stretchr/testify/assert"
815
"github.com/stretchr/testify/require"
916

@@ -56,3 +63,77 @@ func TestBuildWorkspaceLink(t *testing.T) {
5663

5764
assert.Equal(t, workspaceLink.String(), fakeServerURL+"/@"+fakeOwnerName+"/"+fakeWorkspaceName)
5865
}
66+
67+
func TestCloserStack_Mainline(t *testing.T) {
68+
t.Parallel()
69+
ctx := testutil.Context(t, testutil.WaitShort)
70+
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
71+
uut := newCloserStack(ctx, logger)
72+
closes := new([]*fakeCloser)
73+
fc0 := &fakeCloser{closes: closes}
74+
fc1 := &fakeCloser{closes: closes}
75+
76+
func() {
77+
defer uut.close(nil)
78+
err := uut.push("fc0", fc0)
79+
require.NoError(t, err)
80+
err = uut.push("fc1", fc1)
81+
require.NoError(t, err)
82+
}()
83+
// order reversed
84+
require.Equal(t, []*fakeCloser{fc1, fc0}, *closes)
85+
}
86+
87+
func TestCloserStack_Context(t *testing.T) {
88+
t.Parallel()
89+
ctx := testutil.Context(t, testutil.WaitShort)
90+
ctx, cancel := context.WithCancel(ctx)
91+
defer cancel()
92+
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
93+
uut := newCloserStack(ctx, logger)
94+
closes := new([]*fakeCloser)
95+
fc0 := &fakeCloser{closes: closes}
96+
fc1 := &fakeCloser{closes: closes}
97+
98+
err := uut.push("fc0", fc0)
99+
require.NoError(t, err)
100+
err = uut.push("fc1", fc1)
101+
require.NoError(t, err)
102+
cancel()
103+
require.Eventually(t, func() bool {
104+
uut.Lock()
105+
defer uut.Unlock()
106+
return uut.closed
107+
}, testutil.WaitShort, testutil.IntervalFast)
108+
}
109+
110+
func TestCloserStack_PushAfterClose(t *testing.T) {
111+
t.Parallel()
112+
ctx := testutil.Context(t, testutil.WaitShort)
113+
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
114+
uut := newCloserStack(ctx, logger)
115+
closes := new([]*fakeCloser)
116+
fc0 := &fakeCloser{closes: closes}
117+
fc1 := &fakeCloser{closes: closes}
118+
119+
err := uut.push("fc0", fc0)
120+
require.NoError(t, err)
121+
122+
exErr := xerrors.New("test")
123+
uut.close(exErr)
124+
require.Equal(t, []*fakeCloser{fc0}, *closes)
125+
126+
err = uut.push("fc1", fc1)
127+
require.ErrorIs(t, err, exErr)
128+
require.Equal(t, []*fakeCloser{fc1, fc0}, *closes, "should close fc1")
129+
}
130+
131+
// fakeCloser is a test double for io.Closer. Each Close call appends the
// receiver to the shared closes slice, so tests can inspect which fakes were
// closed and in what order, and then returns the configured err.
type fakeCloser struct {
	closes *[]*fakeCloser
	err    error
}

// Close records this fakeCloser in the shared slice and returns its err field.
func (f *fakeCloser) Close() error {
	recorded := append(*f.closes, f)
	*f.closes = recorded
	return f.err
}

0 commit comments

Comments
 (0)