Skip to content

Commit 21ee4de

Browse files
committed
fix: close SSH sessions bottom-up if top-down fails
1 parent ff1eabe commit 21ee4de

File tree

2 files changed

+129
-29
lines changed

2 files changed

+129
-29
lines changed

cli/ssh.go

Lines changed: 57 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ import (
3737
"github.com/coder/coder/v2/codersdk/workspacesdk"
3838
"github.com/coder/coder/v2/cryptorand"
3939
"github.com/coder/coder/v2/pty"
40+
"github.com/coder/quartz"
4041
"github.com/coder/retry"
4142
"github.com/coder/serpent"
4243
)
@@ -48,6 +49,8 @@ const (
4849
var (
4950
workspacePollInterval = time.Minute
5051
autostopNotifyCountdown = []time.Duration{30 * time.Minute}
52+
// gracefulShutdownTimeout is the timeout, per item in the stack of things to close
53+
gracefulShutdownTimeout = 2 * time.Second
5154
)
5255

5356
func (r *RootCmd) ssh() *serpent.Command {
@@ -153,7 +156,7 @@ func (r *RootCmd) ssh() *serpent.Command {
153156
// log HTTP requests
154157
client.SetLogger(logger)
155158
}
156-
stack := newCloserStack(ctx, logger)
159+
stack := newCloserStack(ctx, logger, quartz.NewReal())
157160
defer stack.close(nil)
158161

159162
for _, remoteForward := range remoteForwards {
@@ -936,11 +939,18 @@ type closerStack struct {
936939
closed bool
937940
logger slog.Logger
938941
err error
939-
wg sync.WaitGroup
942+
allDone chan struct{}
943+
944+
// for testing
945+
clock quartz.Clock
940946
}
941947

942-
func newCloserStack(ctx context.Context, logger slog.Logger) *closerStack {
943-
cs := &closerStack{logger: logger}
948+
func newCloserStack(ctx context.Context, logger slog.Logger, clock quartz.Clock) *closerStack {
949+
cs := &closerStack{
950+
logger: logger,
951+
allDone: make(chan struct{}),
952+
clock: clock,
953+
}
944954
go cs.closeAfterContext(ctx)
945955
return cs
946956
}
@@ -954,20 +964,55 @@ func (c *closerStack) close(err error) {
954964
c.Lock()
955965
if c.closed {
956966
c.Unlock()
957-
c.wg.Wait()
967+
<-c.allDone
958968
return
959969
}
960970
c.closed = true
961971
c.err = err
962-
c.wg.Add(1)
963-
defer c.wg.Done()
964972
c.Unlock()
973+
defer close(c.allDone)
974+
975+
// We are going to work down the stack in order. If things close quickly, we trigger the
976+
// closers serially, in order. `done` is a channel that indicates the nth closer is done
977+
// closing, and we should trigger the (n-1) closer. However, if things take too long we don't
978+
// want to wait, so we also start a ticker that works down the stack and sends on `done` as
979+
// well.
980+
next := len(c.closers) - 1
981+
// here we make the buffer 2x the number of closers because we could write once for it being
982+
// actually done and once via the countdown for each closer
983+
done := make(chan int, len(c.closers)*2)
984+
startNext := func() {
985+
go func(i int) {
986+
defer func() { done <- i }()
987+
cwn := c.closers[i]
988+
cErr := cwn.closer.Close()
989+
c.logger.Debug(context.Background(),
990+
"closed item from stack", slog.F("name", cwn.name), slog.Error(cErr))
991+
}(next)
992+
next--
993+
}
994+
done <- len(c.closers) // kick us off right away
995+
996+
// start a ticking countdown in case we hang/don't close quickly
997+
countdown := len(c.closers) - 1
998+
ctx, cancel := context.WithCancel(context.Background())
999+
defer cancel()
1000+
c.clock.TickerFunc(ctx, gracefulShutdownTimeout, func() error {
1001+
select {
1002+
case done <- countdown:
1003+
countdown--
1004+
case <-ctx.Done():
1005+
}
1006+
return nil
1007+
}, "closerStack")
9651008

966-
for i := len(c.closers) - 1; i >= 0; i-- {
967-
cwn := c.closers[i]
968-
cErr := cwn.closer.Close()
969-
c.logger.Debug(context.Background(),
970-
"closed item from stack", slog.F("name", cwn.name), slog.Error(cErr))
1009+
for n := range done { // the nth closer is done
1010+
if n == 0 {
1011+
return
1012+
}
1013+
if n-1 == next {
1014+
startNext()
1015+
}
9711016
}
9721017
}
9731018

cli/ssh_internal_test.go

Lines changed: 72 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ package cli
22

33
import (
44
"context"
5+
"fmt"
56
"net/url"
7+
"sync"
68
"testing"
79
"time"
810

@@ -12,6 +14,7 @@ import (
1214

1315
"cdr.dev/slog"
1416
"cdr.dev/slog/sloggers/slogtest"
17+
"github.com/coder/quartz"
1518

1619
"github.com/coder/coder/v2/codersdk"
1720
"github.com/coder/coder/v2/testutil"
@@ -68,7 +71,7 @@ func TestCloserStack_Mainline(t *testing.T) {
6871
t.Parallel()
6972
ctx := testutil.Context(t, testutil.WaitShort)
7073
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
71-
uut := newCloserStack(ctx, logger)
74+
uut := newCloserStack(ctx, logger, quartz.NewMock(t))
7275
closes := new([]*fakeCloser)
7376
fc0 := &fakeCloser{closes: closes}
7477
fc1 := &fakeCloser{closes: closes}
@@ -90,7 +93,7 @@ func TestCloserStack_Context(t *testing.T) {
9093
ctx, cancel := context.WithCancel(ctx)
9194
defer cancel()
9295
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
93-
uut := newCloserStack(ctx, logger)
96+
uut := newCloserStack(ctx, logger, quartz.NewMock(t))
9497
closes := new([]*fakeCloser)
9598
fc0 := &fakeCloser{closes: closes}
9699
fc1 := &fakeCloser{closes: closes}
@@ -111,7 +114,7 @@ func TestCloserStack_PushAfterClose(t *testing.T) {
111114
t.Parallel()
112115
ctx := testutil.Context(t, testutil.WaitShort)
113116
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
114-
uut := newCloserStack(ctx, logger)
117+
uut := newCloserStack(ctx, logger, quartz.NewMock(t))
115118
closes := new([]*fakeCloser)
116119
fc0 := &fakeCloser{closes: closes}
117120
fc1 := &fakeCloser{closes: closes}
@@ -134,13 +137,9 @@ func TestCloserStack_CloseAfterContext(t *testing.T) {
134137
ctx, cancel := context.WithCancel(testCtx)
135138
defer cancel()
136139
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
137-
uut := newCloserStack(ctx, logger)
138-
ac := &asyncCloser{
139-
t: t,
140-
ctx: testCtx,
141-
complete: make(chan struct{}),
142-
started: make(chan struct{}),
143-
}
140+
uut := newCloserStack(ctx, logger, quartz.NewMock(t))
141+
ac := newAsyncCloser(testCtx, t)
142+
defer ac.complete()
144143
err := uut.push("async", ac)
145144
require.NoError(t, err)
146145
cancel()
@@ -160,11 +159,53 @@ func TestCloserStack_CloseAfterContext(t *testing.T) {
160159
t.Fatal("closed before stack was finished")
161160
}
162161

163-
// complete the asyncCloser
164-
close(ac.complete)
162+
ac.complete()
165163
testutil.RequireRecvCtx(testCtx, t, closed)
166164
}
167165

166+
func TestCloserStack_Timeout(t *testing.T) {
167+
t.Parallel()
168+
ctx := testutil.Context(t, testutil.WaitShort)
169+
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
170+
mClock := quartz.NewMock(t)
171+
trap := mClock.Trap().TickerFunc("closerStack")
172+
defer trap.Close()
173+
uut := newCloserStack(ctx, logger, mClock)
174+
var ac [3]*asyncCloser
175+
for i := range ac {
176+
ac[i] = newAsyncCloser(ctx, t)
177+
err := uut.push(fmt.Sprintf("async %d", i), ac[i])
178+
require.NoError(t, err)
179+
}
180+
defer func() {
181+
for _, a := range ac {
182+
a.complete()
183+
}
184+
}()
185+
186+
closed := make(chan struct{})
187+
go func() {
188+
defer close(closed)
189+
uut.close(nil)
190+
}()
191+
trap.MustWait(ctx).Release()
192+
// top starts right away, but it hangs
193+
testutil.RequireRecvCtx(ctx, t, ac[2].started)
194+
// timer pops and we start the middle one
195+
mClock.Advance(gracefulShutdownTimeout).MustWait(ctx)
196+
testutil.RequireRecvCtx(ctx, t, ac[1].started)
197+
198+
// middle one finishes
199+
ac[1].complete()
200+
// bottom starts, but also hangs
201+
testutil.RequireRecvCtx(ctx, t, ac[0].started)
202+
203+
// timer has to pop twice to time out.
204+
mClock.Advance(gracefulShutdownTimeout).MustWait(ctx)
205+
mClock.Advance(gracefulShutdownTimeout).MustWait(ctx)
206+
testutil.RequireRecvCtx(ctx, t, closed)
207+
}
208+
168209
type fakeCloser struct {
169210
closes *[]*fakeCloser
170211
err error
@@ -176,10 +217,11 @@ func (c *fakeCloser) Close() error {
176217
}
177218

178219
type asyncCloser struct {
179-
t *testing.T
180-
ctx context.Context
181-
started chan struct{}
182-
complete chan struct{}
220+
t *testing.T
221+
ctx context.Context
222+
started chan struct{}
223+
isComplete chan struct{}
224+
comepleteOnce sync.Once
183225
}
184226

185227
func (c *asyncCloser) Close() error {
@@ -188,7 +230,20 @@ func (c *asyncCloser) Close() error {
188230
case <-c.ctx.Done():
189231
c.t.Error("timed out")
190232
return c.ctx.Err()
191-
case <-c.complete:
233+
case <-c.isComplete:
192234
return nil
193235
}
194236
}
237+
238+
func (c *asyncCloser) complete() {
239+
c.comepleteOnce.Do(func() { close(c.isComplete) })
240+
}
241+
242+
func newAsyncCloser(ctx context.Context, t *testing.T) *asyncCloser {
243+
return &asyncCloser{
244+
t: t,
245+
ctx: ctx,
246+
isComplete: make(chan struct{}),
247+
started: make(chan struct{}),
248+
}
249+
}

0 commit comments

Comments
 (0)