From 7dd56e075a291cd3931decc6018b2c1e82d01902 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Fri, 13 Sep 2024 16:30:09 +0400 Subject: [PATCH] fix: close SSH sessions bottom-up if top-down fails --- cli/ssh.go | 72 ++++++++++++++++++++++----- cli/ssh_internal_test.go | 103 ++++++++++++++++++++++++++++++++------- 2 files changed, 146 insertions(+), 29 deletions(-) diff --git a/cli/ssh.go b/cli/ssh.go index 7d9d2368de2f9..e63c857fade8e 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -37,6 +37,7 @@ import ( "github.com/coder/coder/v2/codersdk/workspacesdk" "github.com/coder/coder/v2/cryptorand" "github.com/coder/coder/v2/pty" + "github.com/coder/quartz" "github.com/coder/retry" "github.com/coder/serpent" ) @@ -48,6 +49,8 @@ const ( var ( workspacePollInterval = time.Minute autostopNotifyCountdown = []time.Duration{30 * time.Minute} + // gracefulShutdownTimeout is the timeout, per item in the stack of things to close + gracefulShutdownTimeout = 2 * time.Second ) func (r *RootCmd) ssh() *serpent.Command { @@ -153,7 +156,7 @@ func (r *RootCmd) ssh() *serpent.Command { // log HTTP requests client.SetLogger(logger) } - stack := newCloserStack(ctx, logger) + stack := newCloserStack(ctx, logger, quartz.NewReal()) defer stack.close(nil) for _, remoteForward := range remoteForwards { @@ -936,11 +939,18 @@ type closerStack struct { closed bool logger slog.Logger err error - wg sync.WaitGroup + allDone chan struct{} + + // for testing + clock quartz.Clock } -func newCloserStack(ctx context.Context, logger slog.Logger) *closerStack { - cs := &closerStack{logger: logger} +func newCloserStack(ctx context.Context, logger slog.Logger, clock quartz.Clock) *closerStack { + cs := &closerStack{ + logger: logger, + allDone: make(chan struct{}), + clock: clock, + } go cs.closeAfterContext(ctx) return cs } @@ -954,20 +964,58 @@ func (c *closerStack) close(err error) { c.Lock() if c.closed { c.Unlock() - c.wg.Wait() + <-c.allDone return } c.closed = true c.err = err - c.wg.Add(1) - defer c.wg.Done() c.Unlock() + defer close(c.allDone) + if len(c.closers) == 0 { + return + } - for i := len(c.closers) - 1; i >= 0; i-- { - cwn := c.closers[i] - cErr := cwn.closer.Close() - c.logger.Debug(context.Background(), - "closed item from stack", slog.F("name", cwn.name), slog.Error(cErr)) + // We are going to work down the stack in order. If things close quickly, we trigger the + // closers serially, in order. `done` is a channel that indicates the nth closer is done + // closing, and we should trigger the (n-1) closer. However, if things take too long we don't + // want to wait, so we also start a ticker that works down the stack and sends on `done` as + // well. + next := len(c.closers) - 1 + // here we make the buffer 2x the number of closers because we could write once for it being + // actually done and once via the countdown for each closer + done := make(chan int, len(c.closers)*2) + startNext := func() { + go func(i int) { + defer func() { done <- i }() + cwn := c.closers[i] + cErr := cwn.closer.Close() + c.logger.Debug(context.Background(), + "closed item from stack", slog.F("name", cwn.name), slog.Error(cErr)) + }(next) + next-- + } + done <- len(c.closers) // kick us off right away + + // start a ticking countdown in case we hang/don't close quickly + countdown := len(c.closers) - 1 + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + c.clock.TickerFunc(ctx, gracefulShutdownTimeout, func() error { + if countdown < 0 { + return nil + } + done <- countdown + countdown-- + return nil + }, "closerStack") + + for n := range done { // the nth closer is done + if n == 0 { + return + } + if n-1 == next { + startNext() + } } } diff --git a/cli/ssh_internal_test.go b/cli/ssh_internal_test.go index b612dd5ef9a32..eacfb384e6797 100644 --- a/cli/ssh_internal_test.go +++ b/cli/ssh_internal_test.go @@ -2,7 +2,9 @@ package cli import ( "context" + "fmt" "net/url" + "sync" "testing" "time" @@ -12,6 +14,7 @@ import ( "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/quartz" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/testutil" @@ -68,7 +71,7 @@ func TestCloserStack_Mainline(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitShort) logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - uut := newCloserStack(ctx, logger) + uut := newCloserStack(ctx, logger, quartz.NewMock(t)) closes := new([]*fakeCloser) fc0 := &fakeCloser{closes: closes} fc1 := &fakeCloser{closes: closes} @@ -84,13 +87,27 @@ func TestCloserStack_Mainline(t *testing.T) { require.Equal(t, []*fakeCloser{fc1, fc0}, *closes) } +func TestCloserStack_Empty(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + uut := newCloserStack(ctx, logger, quartz.NewMock(t)) + + closed := make(chan struct{}) + go func() { + defer close(closed) + uut.close(nil) + }() + testutil.RequireRecvCtx(ctx, t, closed) +} + func TestCloserStack_Context(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitShort) ctx, cancel := context.WithCancel(ctx) defer cancel() logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - uut := newCloserStack(ctx, logger) + uut := newCloserStack(ctx, logger, quartz.NewMock(t)) closes := new([]*fakeCloser) fc0 := &fakeCloser{closes: closes} fc1 := &fakeCloser{closes: closes} @@ -111,7 +128,7 @@ func TestCloserStack_PushAfterClose(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitShort) logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) - uut := newCloserStack(ctx, logger) + uut := newCloserStack(ctx, logger, quartz.NewMock(t)) closes := new([]*fakeCloser) fc0 := &fakeCloser{closes: closes} fc1 := &fakeCloser{closes: closes} @@ -134,13 +151,9 @@ func TestCloserStack_CloseAfterContext(t *testing.T) { ctx, cancel := context.WithCancel(testCtx) defer cancel() logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) - uut := newCloserStack(ctx, logger) - ac := &asyncCloser{ - t: t, - ctx: testCtx, - complete: make(chan struct{}), - started: make(chan struct{}), - } + uut := newCloserStack(ctx, logger, quartz.NewMock(t)) + ac := newAsyncCloser(testCtx, t) + defer ac.complete() err := uut.push("async", ac) require.NoError(t, err) cancel() @@ -160,11 +173,53 @@ func TestCloserStack_CloseAfterContext(t *testing.T) { t.Fatal("closed before stack was finished") } - // complete the asyncCloser - close(ac.complete) + ac.complete() testutil.RequireRecvCtx(testCtx, t, closed) } +func TestCloserStack_Timeout(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + mClock := quartz.NewMock(t) + trap := mClock.Trap().TickerFunc("closerStack") + defer trap.Close() + uut := newCloserStack(ctx, logger, mClock) + var ac [3]*asyncCloser + for i := range ac { + ac[i] = newAsyncCloser(ctx, t) + err := uut.push(fmt.Sprintf("async %d", i), ac[i]) + require.NoError(t, err) + } + defer func() { + for _, a := range ac { + a.complete() + } + }() + + closed := make(chan struct{}) + go func() { + defer close(closed) + uut.close(nil) + }() + trap.MustWait(ctx).Release() + // top starts right away, but it hangs + testutil.RequireRecvCtx(ctx, t, ac[2].started) + // timer pops and we start the middle one + mClock.Advance(gracefulShutdownTimeout).MustWait(ctx) + testutil.RequireRecvCtx(ctx, t, ac[1].started) + + // middle one finishes + ac[1].complete() + // bottom starts, but also hangs + testutil.RequireRecvCtx(ctx, t, ac[0].started) + + // timer has to pop twice to time out. + mClock.Advance(gracefulShutdownTimeout).MustWait(ctx) + mClock.Advance(gracefulShutdownTimeout).MustWait(ctx) + testutil.RequireRecvCtx(ctx, t, closed) +} + type fakeCloser struct { closes *[]*fakeCloser err error @@ -176,10 +231,11 @@ func (c *fakeCloser) Close() error { } type asyncCloser struct { - t *testing.T - ctx context.Context - started chan struct{} - complete chan struct{} + t *testing.T + ctx context.Context + started chan struct{} + isComplete chan struct{} + comepleteOnce sync.Once } func (c *asyncCloser) Close() error { @@ -188,7 +244,20 @@ func (c *asyncCloser) Close() error { case <-c.ctx.Done(): c.t.Error("timed out") return c.ctx.Err() - case <-c.complete: + case <-c.isComplete: return nil } } + +func (c *asyncCloser) complete() { + c.comepleteOnce.Do(func() { close(c.isComplete) }) +} + +func newAsyncCloser(ctx context.Context, t *testing.T) *asyncCloser { + return &asyncCloser{ + t: t, + ctx: ctx, + isComplete: make(chan struct{}), + started: make(chan struct{}), + } +}