Skip to content

fix: close ssh sessions gracefully #10732

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 129 additions & 32 deletions cli/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
gosshagent "golang.org/x/crypto/ssh/agent"
"golang.org/x/term"
"golang.org/x/xerrors"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"

"cdr.dev/slog"
"cdr.dev/slog/sloggers/sloghuman"
Expand Down Expand Up @@ -129,6 +130,8 @@ func (r *RootCmd) ssh() *clibase.Cmd {
// log HTTP requests
client.SetLogger(logger)
}
stack := newCloserStack(ctx, logger)
defer stack.close(nil)

if remoteForward != "" {
isValid := validateRemoteForward(remoteForward)
Expand Down Expand Up @@ -212,7 +215,9 @@ func (r *RootCmd) ssh() *clibase.Cmd {
if err != nil {
return xerrors.Errorf("dial agent: %w", err)
}
defer conn.Close()
if err = stack.push("agent conn", conn); err != nil {
return err
}
conn.AwaitReachable(ctx)

stopPolling := tryPollWorkspaceAutostop(ctx, client, workspace)
Expand All @@ -223,61 +228,46 @@ func (r *RootCmd) ssh() *clibase.Cmd {
if err != nil {
return xerrors.Errorf("connect SSH: %w", err)
}
defer rawSSH.Close()
copier := &rawSSHCopier{conn: rawSSH, r: inv.Stdin, w: inv.Stdout}
if err = stack.push("rawSSHCopier", copier); err != nil {
return err
}

wg.Add(1)
go func() {
defer wg.Done()
watchAndClose(ctx, func() error {
return rawSSH.Close()
stack.close(xerrors.New("watchAndClose"))
return nil
}, logger, client, workspace)
}()

wg.Add(1)
go func() {
defer wg.Done()
// Ensure stdout copy closes incase stdin is closed
// unexpectedly.
defer rawSSH.Close()

_, err := io.Copy(rawSSH, inv.Stdin)
if err != nil {
logger.Error(ctx, "copy stdin error", slog.Error(err))
} else {
logger.Debug(ctx, "copy stdin complete")
}
}()
_, err = io.Copy(inv.Stdout, rawSSH)
if err != nil {
logger.Error(ctx, "copy stdout error", slog.Error(err))
} else {
logger.Debug(ctx, "copy stdout complete")
}
copier.copy(&wg)
return nil
}

sshClient, err := conn.SSHClient(ctx)
if err != nil {
return xerrors.Errorf("ssh client: %w", err)
}
defer sshClient.Close()
if err = stack.push("ssh client", sshClient); err != nil {
return err
}

sshSession, err := sshClient.NewSession()
if err != nil {
return xerrors.Errorf("ssh session: %w", err)
}
defer sshSession.Close()
if err = stack.push("sshSession", sshSession); err != nil {
return err
}

wg.Add(1)
go func() {
defer wg.Done()
watchAndClose(
ctx,
func() error {
err := sshSession.Close()
logger.Debug(ctx, "session close", slog.Error(err))
err = sshClient.Close()
logger.Debug(ctx, "client close", slog.Error(err))
stack.close(xerrors.New("watchAndClose"))
return nil
},
logger,
Expand Down Expand Up @@ -313,7 +303,9 @@ func (r *RootCmd) ssh() *clibase.Cmd {
if err != nil {
return xerrors.Errorf("forward GPG socket: %w", err)
}
defer closer.Close()
if err = stack.push("forwardGPGAgent", closer); err != nil {
return err
}
}

if remoteForward != "" {
Expand All @@ -326,7 +318,9 @@ func (r *RootCmd) ssh() *clibase.Cmd {
if err != nil {
return xerrors.Errorf("ssh remote forward: %w", err)
}
defer closer.Close()
if err = stack.push("sshRemoteForward", closer); err != nil {
return err
}
}

stdoutFile, validOut := inv.Stdout.(*os.File)
Expand Down Expand Up @@ -795,3 +789,106 @@ func remoteGPGAgentSocket(sshClient *gossh.Client) (string, error) {

return string(bytes.TrimSpace(remoteSocket)), nil
}

// closerWithName pairs an io.Closer with a human-readable label so that log
// entries emitted while tearing things down can say which item was closed.
type closerWithName struct {
	// closer is the resource to be closed.
	closer io.Closer
	// name identifies the resource in log output.
	name string
}

type closerStack struct {
sync.Mutex
closers []closerWithName
closed bool
logger slog.Logger
err error
}

// newCloserStack returns an empty closerStack that logs to logger. It also
// starts a goroutine that closes the stack once ctx is done.
func newCloserStack(ctx context.Context, logger slog.Logger) *closerStack {
	stack := new(closerStack)
	stack.logger = logger
	go stack.closeAfterContext(ctx)
	return stack
}

// closeAfterContext blocks until ctx is done, then closes the stack using the
// context's error as the reason.
func (c *closerStack) closeAfterContext(ctx context.Context) {
	done := ctx.Done()
	<-done
	reason := ctx.Err()
	c.close(reason)
}

// close marks the stack as closed, records err as the reason, and closes every
// pushed item in reverse (last-in, first-out) order, logging the result of
// each Close at debug level.
//
// Only the first call performs the teardown. The mutex is held for the entire
// teardown so that any later caller (for example a deferred close racing with
// one triggered from another goroutine) blocks until every item has actually
// been closed, instead of returning while the teardown is still in flight.
//
// Because the lock is held while the items are closed, an item's Close method
// must not call push or close on the same stack.
func (c *closerStack) close(err error) {
	c.Lock()
	defer c.Unlock()
	if c.closed {
		// Teardown already ran to completion under the lock.
		return
	}
	c.closed = true
	c.err = err

	for i := len(c.closers) - 1; i >= 0; i-- {
		cwn := c.closers[i]
		cErr := cwn.closer.Close()
		c.logger.Debug(context.Background(),
			"closed item from stack", slog.F("name", cwn.name), slog.Error(cErr))
	}
}

// push registers closer under name so that it is closed when the stack is
// closed. If the stack has already been closed the item is not stored;
// instead it is closed immediately, the outcome is logged, and an error
// wrapping the stack's close reason is returned.
func (c *closerStack) push(name string, closer io.Closer) error {
	c.Lock()
	rejected := c.closed
	if !rejected {
		c.closers = append(c.closers, closerWithName{name: name, closer: closer})
	}
	c.Unlock()

	if !rejected {
		return nil
	}

	// The stack will never close this item for us, so do it right away.
	closeErr := closer.Close()
	c.logger.Error(context.Background(),
		"closed item rejected push", slog.F("name", name), slog.Error(closeErr))
	return xerrors.Errorf("already closed: %w", c.err)
}

// rawSSHCopier handles copying raw SSH data between the conn and the pair (r, w).
type rawSSHCopier struct {
conn *gonet.TCPConn
logger slog.Logger
r io.Reader
w io.Writer
}

// copy pumps data in both directions: r -> conn on a goroutine tracked by wg,
// and conn -> w on the calling goroutine. It returns once the conn -> w copy
// has ended.
func (c *rawSSHCopier) copy(wg *sync.WaitGroup) {
	bg := context.Background()

	pumpStdin := func() {
		defer wg.Done()
		// When the input side ends we only half-close the connection
		// (CloseWrite rather than Close). The SSH server then observes the
		// closed stream while reading and can shut down cleanly, which in turn
		// ends the server-to-client copy below and lets copy() return. That
		// avoids leaving server-side state (such as forwarded ports) behind if
		// copy() returned and the underlying tailnet connection were torn down
		// before the TCP session had exited. It is admittedly a hack: shutdown
		// is held back at the application layer because the TCP and tailnet
		// layers cannot be shut down in sequence.
		//
		// If the underlying transport is broken, io.Copy still returns.
		defer func() {
			closeErr := c.conn.CloseWrite()
			c.logger.Debug(bg, "closed raw SSH connection for writing", slog.Error(closeErr))
		}()

		if _, inErr := io.Copy(c.conn, c.r); inErr != nil {
			c.logger.Error(bg, "copy stdin error", slog.Error(inErr))
			return
		}
		c.logger.Debug(bg, "copy stdin complete")
	}

	wg.Add(1)
	go pumpStdin()

	if _, outErr := io.Copy(c.w, c.conn); outErr != nil {
		c.logger.Error(bg, "copy stdout error", slog.Error(outErr))
		return
	}
	c.logger.Debug(bg, "copy stdout complete")
}

// Close shuts down only the write side of the connection (via CloseWrite) and
// returns the resulting error.
func (c *rawSSHCopier) Close() error {
	err := c.conn.CloseWrite()
	return err
}
81 changes: 81 additions & 0 deletions cli/ssh_internal_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
package cli

import (
"context"
"net/url"
"testing"

"golang.org/x/xerrors"

"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/testutil"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -56,3 +63,77 @@ func TestBuildWorkspaceLink(t *testing.T) {

assert.Equal(t, workspaceLink.String(), fakeServerURL+"/@"+fakeOwnerName+"/"+fakeWorkspaceName)
}

// TestCloserStack_Mainline verifies that closing the stack closes every pushed
// item, in reverse order of insertion.
func TestCloserStack_Mainline(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitShort)
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	uut := newCloserStack(ctx, logger)
	closes := new([]*fakeCloser)
	fc0 := &fakeCloser{closes: closes}
	fc1 := &fakeCloser{closes: closes}

	func() {
		defer uut.close(nil)
		err := uut.push("fc0", fc0)
		require.NoError(t, err)
		err = uut.push("fc1", fc1)
		require.NoError(t, err)
	}()

	// fc0 and fc1 hold identical field values, so a deep-equality comparison
	// of the slices cannot distinguish their order. Compare pointer identity
	// element by element instead.
	require.Len(t, *closes, 2)
	// order reversed
	require.Same(t, fc1, (*closes)[0])
	require.Same(t, fc0, (*closes)[1])
}

// TestCloserStack_Context verifies that canceling the context passed to
// newCloserStack causes the stack to be marked closed.
func TestCloserStack_Context(t *testing.T) {
	t.Parallel()
	parent := testutil.Context(t, testutil.WaitShort)
	ctx, cancel := context.WithCancel(parent)
	defer cancel()
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	uut := newCloserStack(ctx, logger)
	closes := new([]*fakeCloser)
	fc0 := &fakeCloser{closes: closes}
	fc1 := &fakeCloser{closes: closes}

	require.NoError(t, uut.push("fc0", fc0))
	require.NoError(t, uut.push("fc1", fc1))

	cancel()

	// The closed flag is guarded by the stack's mutex, so read it under lock.
	isClosed := func() bool {
		uut.Lock()
		defer uut.Unlock()
		return uut.closed
	}
	require.Eventually(t, isClosed, testutil.WaitShort, testutil.IntervalFast)
}

// TestCloserStack_PushAfterClose verifies that pushing onto an already-closed
// stack is rejected with an error wrapping the close reason, and that the
// rejected item is closed immediately.
func TestCloserStack_PushAfterClose(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitShort)
	// push logs at error level when it rejects an item; that is expected here.
	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
	uut := newCloserStack(ctx, logger)
	closes := new([]*fakeCloser)
	fc0 := &fakeCloser{closes: closes}
	fc1 := &fakeCloser{closes: closes}

	err := uut.push("fc0", fc0)
	require.NoError(t, err)

	exErr := xerrors.New("test")
	uut.close(exErr)
	// fc0 and fc1 hold identical field values, so deep equality cannot tell
	// them apart; assert on pointer identity instead.
	require.Len(t, *closes, 1)
	require.Same(t, fc0, (*closes)[0])

	err = uut.push("fc1", fc1)
	require.ErrorIs(t, err, exErr)
	// fc0 was closed first (by close), then fc1 (by the rejected push).
	require.Len(t, *closes, 2, "should close fc1")
	require.Same(t, fc0, (*closes)[0])
	require.Same(t, fc1, (*closes)[1], "should close fc1")
}

// fakeCloser is a test double for io.Closer. Each Close call appends the
// receiver to a shared slice so tests can inspect which closers ran and in
// what order.
type fakeCloser struct {
	// closes points at the shared record of closers that have been closed.
	closes *[]*fakeCloser
	// err is returned from Close.
	err error
}

// Close records the receiver in the shared slice and returns the configured
// error.
func (f *fakeCloser) Close() error {
	recorded := append(*f.closes, f)
	*f.closes = recorded
	return f.err
}
Loading