Skip to content

fix: close SSH sessions bottom-up if top-down fails #14678

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
fix: close SSH sessions bottom-up if top-down fails
  • Loading branch information
spikecurtis committed Sep 17, 2024
commit 7dd56e075a291cd3931decc6018b2c1e82d01902
72 changes: 60 additions & 12 deletions cli/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import (
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/cryptorand"
"github.com/coder/coder/v2/pty"
"github.com/coder/quartz"
"github.com/coder/retry"
"github.com/coder/serpent"
)
Expand All @@ -48,6 +49,8 @@ const (
var (
workspacePollInterval = time.Minute
autostopNotifyCountdown = []time.Duration{30 * time.Minute}
// gracefulShutdownTimeout is the timeout, per item in the stack of things to close
gracefulShutdownTimeout = 2 * time.Second
)

func (r *RootCmd) ssh() *serpent.Command {
Expand Down Expand Up @@ -153,7 +156,7 @@ func (r *RootCmd) ssh() *serpent.Command {
// log HTTP requests
client.SetLogger(logger)
}
stack := newCloserStack(ctx, logger)
stack := newCloserStack(ctx, logger, quartz.NewReal())
defer stack.close(nil)

for _, remoteForward := range remoteForwards {
Expand Down Expand Up @@ -936,11 +939,18 @@ type closerStack struct {
closed bool
logger slog.Logger
err error
wg sync.WaitGroup
allDone chan struct{}

// for testing
clock quartz.Clock
}

func newCloserStack(ctx context.Context, logger slog.Logger) *closerStack {
cs := &closerStack{logger: logger}
func newCloserStack(ctx context.Context, logger slog.Logger, clock quartz.Clock) *closerStack {
cs := &closerStack{
logger: logger,
allDone: make(chan struct{}),
clock: clock,
}
go cs.closeAfterContext(ctx)
return cs
}
Expand All @@ -954,20 +964,58 @@ func (c *closerStack) close(err error) {
c.Lock()
if c.closed {
c.Unlock()
c.wg.Wait()
<-c.allDone
return
}
c.closed = true
c.err = err
c.wg.Add(1)
defer c.wg.Done()
c.Unlock()
defer close(c.allDone)
if len(c.closers) == 0 {
return
}

for i := len(c.closers) - 1; i >= 0; i-- {
cwn := c.closers[i]
cErr := cwn.closer.Close()
c.logger.Debug(context.Background(),
"closed item from stack", slog.F("name", cwn.name), slog.Error(cErr))
// We are going to work down the stack in order. If things close quickly, we trigger the
// closers serially, in order. `done` is a channel that indicates the nth closer is done
// closing, and we should trigger the (n-1) closer. However, if things take too long we don't
// want to wait, so we also start a ticker that works down the stack and sends on `done` as
// well.
next := len(c.closers) - 1
// here we make the buffer 2x the number of closers because we could write once for it being
// actually done and once via the countdown for each closer
done := make(chan int, len(c.closers)*2)
startNext := func() {
go func(i int) {
defer func() { done <- i }()
cwn := c.closers[i]
cErr := cwn.closer.Close()
c.logger.Debug(context.Background(),
"closed item from stack", slog.F("name", cwn.name), slog.Error(cErr))
}(next)
next--
}
done <- len(c.closers) // kick us off right away

// start a ticking countdown in case we hang/don't close quickly
countdown := len(c.closers) - 1
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
c.clock.TickerFunc(ctx, gracefulShutdownTimeout, func() error {
if countdown < 0 {
return nil
}
done <- countdown
countdown--
return nil
}, "closerStack")

for n := range done { // the nth closer is done
if n == 0 {
return
}
if n-1 == next {
startNext()
}
}
}

Expand Down
103 changes: 86 additions & 17 deletions cli/ssh_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package cli

import (
"context"
"fmt"
"net/url"
"sync"
"testing"
"time"

Expand All @@ -12,6 +14,7 @@ import (

"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/quartz"

"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
Expand Down Expand Up @@ -68,7 +71,7 @@ func TestCloserStack_Mainline(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
uut := newCloserStack(ctx, logger)
uut := newCloserStack(ctx, logger, quartz.NewMock(t))
closes := new([]*fakeCloser)
fc0 := &fakeCloser{closes: closes}
fc1 := &fakeCloser{closes: closes}
Expand All @@ -84,13 +87,27 @@ func TestCloserStack_Mainline(t *testing.T) {
require.Equal(t, []*fakeCloser{fc1, fc0}, *closes)
}

func TestCloserStack_Empty(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
uut := newCloserStack(ctx, logger, quartz.NewMock(t))

closed := make(chan struct{})
go func() {
defer close(closed)
uut.close(nil)
}()
testutil.RequireRecvCtx(ctx, t, closed)
}

func TestCloserStack_Context(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
ctx, cancel := context.WithCancel(ctx)
defer cancel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
uut := newCloserStack(ctx, logger)
uut := newCloserStack(ctx, logger, quartz.NewMock(t))
closes := new([]*fakeCloser)
fc0 := &fakeCloser{closes: closes}
fc1 := &fakeCloser{closes: closes}
Expand All @@ -111,7 +128,7 @@ func TestCloserStack_PushAfterClose(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
uut := newCloserStack(ctx, logger)
uut := newCloserStack(ctx, logger, quartz.NewMock(t))
closes := new([]*fakeCloser)
fc0 := &fakeCloser{closes: closes}
fc1 := &fakeCloser{closes: closes}
Expand All @@ -134,13 +151,9 @@ func TestCloserStack_CloseAfterContext(t *testing.T) {
ctx, cancel := context.WithCancel(testCtx)
defer cancel()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
uut := newCloserStack(ctx, logger)
ac := &asyncCloser{
t: t,
ctx: testCtx,
complete: make(chan struct{}),
started: make(chan struct{}),
}
uut := newCloserStack(ctx, logger, quartz.NewMock(t))
ac := newAsyncCloser(testCtx, t)
defer ac.complete()
err := uut.push("async", ac)
require.NoError(t, err)
cancel()
Expand All @@ -160,11 +173,53 @@ func TestCloserStack_CloseAfterContext(t *testing.T) {
t.Fatal("closed before stack was finished")
}

// complete the asyncCloser
close(ac.complete)
ac.complete()
testutil.RequireRecvCtx(testCtx, t, closed)
}

func TestCloserStack_Timeout(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
mClock := quartz.NewMock(t)
trap := mClock.Trap().TickerFunc("closerStack")
defer trap.Close()
uut := newCloserStack(ctx, logger, mClock)
var ac [3]*asyncCloser
for i := range ac {
ac[i] = newAsyncCloser(ctx, t)
err := uut.push(fmt.Sprintf("async %d", i), ac[i])
require.NoError(t, err)
}
defer func() {
for _, a := range ac {
a.complete()
}
}()

closed := make(chan struct{})
go func() {
defer close(closed)
uut.close(nil)
}()
trap.MustWait(ctx).Release()
// top starts right away, but it hangs
testutil.RequireRecvCtx(ctx, t, ac[2].started)
// timer pops and we start the middle one
mClock.Advance(gracefulShutdownTimeout).MustWait(ctx)
testutil.RequireRecvCtx(ctx, t, ac[1].started)

// middle one finishes
ac[1].complete()
// bottom starts, but also hangs
testutil.RequireRecvCtx(ctx, t, ac[0].started)

// timer has to pop twice to time out.
mClock.Advance(gracefulShutdownTimeout).MustWait(ctx)
mClock.Advance(gracefulShutdownTimeout).MustWait(ctx)
testutil.RequireRecvCtx(ctx, t, closed)
}

type fakeCloser struct {
closes *[]*fakeCloser
err error
Expand All @@ -176,10 +231,11 @@ func (c *fakeCloser) Close() error {
}

type asyncCloser struct {
t *testing.T
ctx context.Context
started chan struct{}
complete chan struct{}
t *testing.T
ctx context.Context
started chan struct{}
isComplete chan struct{}
comepleteOnce sync.Once
}

func (c *asyncCloser) Close() error {
Expand All @@ -188,7 +244,20 @@ func (c *asyncCloser) Close() error {
case <-c.ctx.Done():
c.t.Error("timed out")
return c.ctx.Err()
case <-c.complete:
case <-c.isComplete:
return nil
}
}

func (c *asyncCloser) complete() {
c.comepleteOnce.Do(func() { close(c.isComplete) })
}

func newAsyncCloser(ctx context.Context, t *testing.T) *asyncCloser {
return &asyncCloser{
t: t,
ctx: ctx,
isComplete: make(chan struct{}),
started: make(chan struct{}),
}
}
Loading