Skip to content

Commit 7dd56e0

Browse files
committed
fix: close SSH sessions bottom-up if top-down fails
1 parent ff1eabe commit 7dd56e0

File tree

2 files changed

+146
-29
lines changed

2 files changed

+146
-29
lines changed

cli/ssh.go

Lines changed: 60 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ import (
3737
"github.com/coder/coder/v2/codersdk/workspacesdk"
3838
"github.com/coder/coder/v2/cryptorand"
3939
"github.com/coder/coder/v2/pty"
40+
"github.com/coder/quartz"
4041
"github.com/coder/retry"
4142
"github.com/coder/serpent"
4243
)
@@ -48,6 +49,8 @@ const (
4849
var (
4950
workspacePollInterval = time.Minute
5051
autostopNotifyCountdown = []time.Duration{30 * time.Minute}
52+
// gracefulShutdownTimeout is how long each item in the closer stack is given to close before we stop waiting and start closing the next item down
53+
gracefulShutdownTimeout = 2 * time.Second
5154
)
5255

5356
func (r *RootCmd) ssh() *serpent.Command {
@@ -153,7 +156,7 @@ func (r *RootCmd) ssh() *serpent.Command {
153156
// log HTTP requests
154157
client.SetLogger(logger)
155158
}
156-
stack := newCloserStack(ctx, logger)
159+
stack := newCloserStack(ctx, logger, quartz.NewReal())
157160
defer stack.close(nil)
158161

159162
for _, remoteForward := range remoteForwards {
@@ -936,11 +939,18 @@ type closerStack struct {
936939
closed bool
937940
logger slog.Logger
938941
err error
939-
wg sync.WaitGroup
942+
allDone chan struct{}
943+
944+
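// clock is injected so tests can drive the shutdown timeout with a mock clock; production code passes quartz.NewReal()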
945+
clock quartz.Clock
940946
}
941947

942-
func newCloserStack(ctx context.Context, logger slog.Logger) *closerStack {
943-
cs := &closerStack{logger: logger}
948+
func newCloserStack(ctx context.Context, logger slog.Logger, clock quartz.Clock) *closerStack {
949+
cs := &closerStack{
950+
logger: logger,
951+
allDone: make(chan struct{}),
952+
clock: clock,
953+
}
944954
go cs.closeAfterContext(ctx)
945955
return cs
946956
}
@@ -954,20 +964,58 @@ func (c *closerStack) close(err error) {
954964
c.Lock()
955965
if c.closed {
956966
c.Unlock()
957-
c.wg.Wait()
967+
<-c.allDone
958968
return
959969
}
960970
c.closed = true
961971
c.err = err
962-
c.wg.Add(1)
963-
defer c.wg.Done()
964972
c.Unlock()
973+
defer close(c.allDone)
974+
if len(c.closers) == 0 {
975+
return
976+
}
965977

966-
for i := len(c.closers) - 1; i >= 0; i-- {
967-
cwn := c.closers[i]
968-
cErr := cwn.closer.Close()
969-
c.logger.Debug(context.Background(),
970-
"closed item from stack", slog.F("name", cwn.name), slog.Error(cErr))
978+
// We are going to work down the stack in order. If things close quickly, we trigger the
979+
// closers serially, in order. `done` is a channel that indicates the nth closer is done
980+
// closing, and we should trigger the (n-1) closer. However, if things take too long we don't
981+
// want to wait, so we also start a ticker that works down the stack and sends on `done` as
982+
// well.
983+
next := len(c.closers) - 1
984+
// The buffer is 2x the number of closers because each closer can cause two writes: one when it
985+
// actually finishes and one from the countdown ticker. The kick-off write below is read first.
986+
done := make(chan int, len(c.closers)*2)
987+
startNext := func() {
988+
go func(i int) {
989+
defer func() { done <- i }()
990+
cwn := c.closers[i]
991+
cErr := cwn.closer.Close()
992+
c.logger.Debug(context.Background(),
993+
"closed item from stack", slog.F("name", cwn.name), slog.Error(cErr))
994+
}(next)
995+
next--
996+
}
997+
done <- len(c.closers) // kick us off right away
998+
999+
// start a ticking countdown in case we hang/don't close quickly
1000+
countdown := len(c.closers) - 1
1001+
ctx, cancel := context.WithCancel(context.Background())
1002+
defer cancel()
1003+
c.clock.TickerFunc(ctx, gracefulShutdownTimeout, func() error {
1004+
if countdown < 0 {
1005+
return nil
1006+
}
1007+
done <- countdown
1008+
countdown--
1009+
return nil
1010+
}, "closerStack")
1011+
1012+
for n := range done { // the nth closer is done
1013+
if n == 0 {
1014+
return
1015+
}
1016+
if n-1 == next {
1017+
startNext()
1018+
}
9711019
}
9721020
}
9731021

cli/ssh_internal_test.go

Lines changed: 86 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ package cli
22

33
import (
44
"context"
5+
"fmt"
56
"net/url"
7+
"sync"
68
"testing"
79
"time"
810

@@ -12,6 +14,7 @@ import (
1214

1315
"cdr.dev/slog"
1416
"cdr.dev/slog/sloggers/slogtest"
17+
"github.com/coder/quartz"
1518

1619
"github.com/coder/coder/v2/codersdk"
1720
"github.com/coder/coder/v2/testutil"
@@ -68,7 +71,7 @@ func TestCloserStack_Mainline(t *testing.T) {
6871
t.Parallel()
6972
ctx := testutil.Context(t, testutil.WaitShort)
7073
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
71-
uut := newCloserStack(ctx, logger)
74+
uut := newCloserStack(ctx, logger, quartz.NewMock(t))
7275
closes := new([]*fakeCloser)
7376
fc0 := &fakeCloser{closes: closes}
7477
fc1 := &fakeCloser{closes: closes}
@@ -84,13 +87,27 @@ func TestCloserStack_Mainline(t *testing.T) {
8487
require.Equal(t, []*fakeCloser{fc1, fc0}, *closes)
8588
}
8689

90+
func TestCloserStack_Empty(t *testing.T) {
91+
t.Parallel()
92+
ctx := testutil.Context(t, testutil.WaitShort)
93+
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
94+
uut := newCloserStack(ctx, logger, quartz.NewMock(t))
95+
96+
closed := make(chan struct{})
97+
go func() {
98+
defer close(closed)
99+
uut.close(nil)
100+
}()
101+
testutil.RequireRecvCtx(ctx, t, closed)
102+
}
103+
87104
func TestCloserStack_Context(t *testing.T) {
88105
t.Parallel()
89106
ctx := testutil.Context(t, testutil.WaitShort)
90107
ctx, cancel := context.WithCancel(ctx)
91108
defer cancel()
92109
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
93-
uut := newCloserStack(ctx, logger)
110+
uut := newCloserStack(ctx, logger, quartz.NewMock(t))
94111
closes := new([]*fakeCloser)
95112
fc0 := &fakeCloser{closes: closes}
96113
fc1 := &fakeCloser{closes: closes}
@@ -111,7 +128,7 @@ func TestCloserStack_PushAfterClose(t *testing.T) {
111128
t.Parallel()
112129
ctx := testutil.Context(t, testutil.WaitShort)
113130
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
114-
uut := newCloserStack(ctx, logger)
131+
uut := newCloserStack(ctx, logger, quartz.NewMock(t))
115132
closes := new([]*fakeCloser)
116133
fc0 := &fakeCloser{closes: closes}
117134
fc1 := &fakeCloser{closes: closes}
@@ -134,13 +151,9 @@ func TestCloserStack_CloseAfterContext(t *testing.T) {
134151
ctx, cancel := context.WithCancel(testCtx)
135152
defer cancel()
136153
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
137-
uut := newCloserStack(ctx, logger)
138-
ac := &asyncCloser{
139-
t: t,
140-
ctx: testCtx,
141-
complete: make(chan struct{}),
142-
started: make(chan struct{}),
143-
}
154+
uut := newCloserStack(ctx, logger, quartz.NewMock(t))
155+
ac := newAsyncCloser(testCtx, t)
156+
defer ac.complete()
144157
err := uut.push("async", ac)
145158
require.NoError(t, err)
146159
cancel()
@@ -160,11 +173,53 @@ func TestCloserStack_CloseAfterContext(t *testing.T) {
160173
t.Fatal("closed before stack was finished")
161174
}
162175

163-
// complete the asyncCloser
164-
close(ac.complete)
176+
ac.complete()
165177
testutil.RequireRecvCtx(testCtx, t, closed)
166178
}
167179

180+
func TestCloserStack_Timeout(t *testing.T) {
181+
t.Parallel()
182+
ctx := testutil.Context(t, testutil.WaitShort)
183+
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
184+
mClock := quartz.NewMock(t)
185+
trap := mClock.Trap().TickerFunc("closerStack")
186+
defer trap.Close()
187+
uut := newCloserStack(ctx, logger, mClock)
188+
var ac [3]*asyncCloser
189+
for i := range ac {
190+
ac[i] = newAsyncCloser(ctx, t)
191+
err := uut.push(fmt.Sprintf("async %d", i), ac[i])
192+
require.NoError(t, err)
193+
}
194+
defer func() {
195+
for _, a := range ac {
196+
a.complete()
197+
}
198+
}()
199+
200+
closed := make(chan struct{})
201+
go func() {
202+
defer close(closed)
203+
uut.close(nil)
204+
}()
205+
trap.MustWait(ctx).Release()
206+
// top starts right away, but it hangs
207+
testutil.RequireRecvCtx(ctx, t, ac[2].started)
208+
// timer pops and we start the middle one
209+
mClock.Advance(gracefulShutdownTimeout).MustWait(ctx)
210+
testutil.RequireRecvCtx(ctx, t, ac[1].started)
211+
212+
// middle one finishes
213+
ac[1].complete()
214+
// bottom starts, but also hangs
215+
testutil.RequireRecvCtx(ctx, t, ac[0].started)
216+
217+
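// timer has to pop twice to time out: the next tick is for the middle closer (already done), and only the tick after that signals the bottom closer.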
218+
mClock.Advance(gracefulShutdownTimeout).MustWait(ctx)
219+
mClock.Advance(gracefulShutdownTimeout).MustWait(ctx)
220+
testutil.RequireRecvCtx(ctx, t, closed)
221+
}
222+
168223
type fakeCloser struct {
169224
closes *[]*fakeCloser
170225
err error
@@ -176,10 +231,11 @@ func (c *fakeCloser) Close() error {
176231
}
177232

178233
type asyncCloser struct {
179-
t *testing.T
180-
ctx context.Context
181-
started chan struct{}
182-
complete chan struct{}
234+
t *testing.T
235+
ctx context.Context
236+
started chan struct{}
237+
isComplete chan struct{}
238+
comepleteOnce sync.Once
183239
}
184240

185241
func (c *asyncCloser) Close() error {
@@ -188,7 +244,20 @@ func (c *asyncCloser) Close() error {
188244
case <-c.ctx.Done():
189245
c.t.Error("timed out")
190246
return c.ctx.Err()
191-
case <-c.complete:
247+
case <-c.isComplete:
192248
return nil
193249
}
194250
}
251+
252+
func (c *asyncCloser) complete() {
253+
c.comepleteOnce.Do(func() { close(c.isComplete) })
254+
}
255+
256+
func newAsyncCloser(ctx context.Context, t *testing.T) *asyncCloser {
257+
return &asyncCloser{
258+
t: t,
259+
ctx: ctx,
260+
isComplete: make(chan struct{}),
261+
started: make(chan struct{}),
262+
}
263+
}

0 commit comments

Comments
 (0)