Skip to content

Commit d9dc47a

Browse files
committed
fix: close SSH sessions bottom-up if top-down fails
1 parent 2df9a3e commit d9dc47a

File tree

2 files changed

+95
-1
lines changed

2 files changed

+95
-1
lines changed

cli/ssh.go

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ import (
3737
"github.com/coder/coder/v2/codersdk/workspacesdk"
3838
"github.com/coder/coder/v2/cryptorand"
3939
"github.com/coder/coder/v2/pty"
40+
"github.com/coder/quartz"
4041
"github.com/coder/retry"
4142
"github.com/coder/serpent"
4243
)
@@ -48,6 +49,7 @@ const (
4849
var (
4950
workspacePollInterval = time.Minute
5051
autostopNotifyCountdown = []time.Duration{30 * time.Minute}
52+
gracefulShutdownTimeout = 5 * time.Second
5153
)
5254

5355
func (r *RootCmd) ssh() *serpent.Command {
@@ -250,7 +252,16 @@ func (r *RootCmd) ssh() *serpent.Command {
250252
if err != nil {
251253
return xerrors.Errorf("dial agent: %w", err)
252254
}
253-
if err = stack.push("agent conn", conn); err != nil {
255+
if err = stack.push(
256+
"agent conn",
257+
// We set a long TCP timeout on SSH connections, which means if the underlying
258+
// network fails, the SSH layer can hang for a really long time trying to send a
259+
// shutdown message for any remote forwards (https://github.com/golang/go/issues/69484)
260+
// Normally, we want to tear stuff down top to bottom, but if we get stuck doing it
261+
// that way, this timeoutCloser will trip and close the underlying connection,
262+
// bottom-up.
263+
newTimeoutCloser(ctx, logger, gracefulShutdownTimeout, conn, quartz.NewReal()),
264+
); err != nil {
254265
return err
255266
}
256267
conn.AwaitReachable(ctx)
@@ -1085,3 +1096,49 @@ func getUsageAppName(usageApp string) codersdk.UsageAppName {
10851096

10861097
return codersdk.UsageAppNameSSH
10871098
}
1099+
1100+
type timeoutCloser struct {
1101+
target io.Closer
1102+
closeCalled chan struct{}
1103+
1104+
// for testing
1105+
clock quartz.Clock
1106+
}
1107+
1108+
func newTimeoutCloser(
1109+
ctx context.Context, logger slog.Logger, timeout time.Duration, target io.Closer, clock quartz.Clock,
1110+
) *timeoutCloser {
1111+
b := &timeoutCloser{
1112+
target: target,
1113+
closeCalled: make(chan struct{}),
1114+
clock: clock,
1115+
}
1116+
go b.waitForCtxOrClose(ctx, logger, timeout)
1117+
return b
1118+
}
1119+
1120+
func (t *timeoutCloser) waitForCtxOrClose(ctx context.Context, logger slog.Logger, timeout time.Duration) {
1121+
select {
1122+
case <-t.closeCalled:
1123+
return
1124+
case <-ctx.Done():
1125+
}
1126+
tmr := t.clock.NewTimer(timeout, "timeoutCloser", "waitForCtxOrClose")
1127+
defer tmr.Stop()
1128+
select {
1129+
case <-t.closeCalled:
1130+
return
1131+
case <-tmr.C:
1132+
logger.Warn(ctx, "timed out waiting for graceful shutdown")
1133+
err := t.target.Close()
1134+
if err != nil {
1135+
logger.Debug(ctx, "error closing target", slog.Error(err))
1136+
}
1137+
}
1138+
}
1139+
1140+
// Close should only be called at most once, e.g. in the closerStack
1141+
func (t *timeoutCloser) Close() error {
1142+
close(t.closeCalled)
1143+
return t.target.Close()
1144+
}

cli/ssh_internal_test.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212

1313
"cdr.dev/slog"
1414
"cdr.dev/slog/sloggers/slogtest"
15+
"github.com/coder/quartz"
1516

1617
"github.com/coder/coder/v2/codersdk"
1718
"github.com/coder/coder/v2/testutil"
@@ -192,3 +193,39 @@ func (c *asyncCloser) Close() error {
192193
return nil
193194
}
194195
}
196+
197+
func TestTimeoutCloser_Close(t *testing.T) {
198+
t.Parallel()
199+
ctx := testutil.Context(t, testutil.WaitShort)
200+
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
201+
mClock := quartz.NewMock(t)
202+
closes := new([]*fakeCloser)
203+
fc0 := &fakeCloser{closes: closes}
204+
uut := newTimeoutCloser(ctx, logger, time.Second, fc0, mClock)
205+
err := uut.Close()
206+
require.NoError(t, err)
207+
require.Equal(t, []*fakeCloser{fc0}, *closes, "should close fc0")
208+
}
209+
210+
func TestTimeoutCloser_Timeout(t *testing.T) {
211+
t.Parallel()
212+
testCtx := testutil.Context(t, testutil.WaitShort)
213+
ctx, cancel := context.WithCancel(testCtx)
214+
defer cancel()
215+
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
216+
mClock := quartz.NewMock(t)
217+
trap := mClock.Trap().NewTimer("timeoutCloser", "waitForCtxOrClose")
218+
defer trap.Close()
219+
ac := &asyncCloser{
220+
t: t,
221+
ctx: testCtx,
222+
complete: make(chan struct{}),
223+
started: make(chan struct{}),
224+
}
225+
_ = newTimeoutCloser(ctx, logger, time.Second, ac, mClock)
226+
cancel()
227+
trap.MustWait(testCtx).Release()
228+
mClock.Advance(time.Second).MustWait(testCtx)
229+
testutil.RequireRecvCtx(testCtx, t, ac.started)
230+
close(ac.complete)
231+
}

0 commit comments

Comments
 (0)