diff --git a/cli/ssh.go b/cli/ssh.go
index 0b0a81f2ad83e..f2844b3c55498 100644
--- a/cli/ssh.go
+++ b/cli/ssh.go
@@ -22,6 +22,7 @@ import (
 	gosshagent "golang.org/x/crypto/ssh/agent"
 	"golang.org/x/term"
 	"golang.org/x/xerrors"
+	"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
 
 	"cdr.dev/slog"
 	"cdr.dev/slog/sloggers/sloghuman"
@@ -129,6 +130,8 @@ func (r *RootCmd) ssh() *clibase.Cmd {
 				// log HTTP requests
 				client.SetLogger(logger)
 			}
+			stack := newCloserStack(ctx, logger)
+			defer stack.close(nil)
 
 			if remoteForward != "" {
 				isValid := validateRemoteForward(remoteForward)
@@ -212,7 +215,9 @@ func (r *RootCmd) ssh() *clibase.Cmd {
 			if err != nil {
 				return xerrors.Errorf("dial agent: %w", err)
 			}
-			defer conn.Close()
+			if err = stack.push("agent conn", conn); err != nil {
+				return err
+			}
 			conn.AwaitReachable(ctx)
 
 			stopPolling := tryPollWorkspaceAutostop(ctx, client, workspace)
@@ -223,36 +228,20 @@ func (r *RootCmd) ssh() *clibase.Cmd {
 				if err != nil {
 					return xerrors.Errorf("connect SSH: %w", err)
 				}
-				defer rawSSH.Close()
+				copier := &rawSSHCopier{conn: rawSSH, r: inv.Stdin, w: inv.Stdout}
+				if err = stack.push("rawSSHCopier", copier); err != nil {
+					return err
+				}
 
 				wg.Add(1)
 				go func() {
 					defer wg.Done()
 					watchAndClose(ctx, func() error {
-						return rawSSH.Close()
+						stack.close(xerrors.New("watchAndClose"))
+						return nil
 					}, logger, client, workspace)
 				}()
-
-				wg.Add(1)
-				go func() {
-					defer wg.Done()
-					// Ensure stdout copy closes incase stdin is closed
-					// unexpectedly.
-					defer rawSSH.Close()
-
-					_, err := io.Copy(rawSSH, inv.Stdin)
-					if err != nil {
-						logger.Error(ctx, "copy stdin error", slog.Error(err))
-					} else {
-						logger.Debug(ctx, "copy stdin complete")
-					}
-				}()
-				_, err = io.Copy(inv.Stdout, rawSSH)
-				if err != nil {
-					logger.Error(ctx, "copy stdout error", slog.Error(err))
-				} else {
-					logger.Debug(ctx, "copy stdout complete")
-				}
+				copier.copy(&wg)
 				return nil
 			}
 
@@ -260,13 +249,17 @@ func (r *RootCmd) ssh() *clibase.Cmd {
 			if err != nil {
 				return xerrors.Errorf("ssh client: %w", err)
 			}
-			defer sshClient.Close()
+			if err = stack.push("ssh client", sshClient); err != nil {
+				return err
+			}
 
 			sshSession, err := sshClient.NewSession()
 			if err != nil {
 				return xerrors.Errorf("ssh session: %w", err)
 			}
-			defer sshSession.Close()
+			if err = stack.push("sshSession", sshSession); err != nil {
+				return err
+			}
 
 			wg.Add(1)
 			go func() {
@@ -274,10 +267,7 @@ func (r *RootCmd) ssh() *clibase.Cmd {
 				watchAndClose(
 					ctx,
 					func() error {
-						err := sshSession.Close()
-						logger.Debug(ctx, "session close", slog.Error(err))
-						err = sshClient.Close()
-						logger.Debug(ctx, "client close", slog.Error(err))
+						stack.close(xerrors.New("watchAndClose"))
 						return nil
 					},
 					logger,
@@ -313,7 +303,9 @@ func (r *RootCmd) ssh() *clibase.Cmd {
 				if err != nil {
 					return xerrors.Errorf("forward GPG socket: %w", err)
 				}
-				defer closer.Close()
+				if err = stack.push("forwardGPGAgent", closer); err != nil {
+					return err
+				}
 			}
 
 			if remoteForward != "" {
@@ -326,7 +318,9 @@ func (r *RootCmd) ssh() *clibase.Cmd {
 				if err != nil {
 					return xerrors.Errorf("ssh remote forward: %w", err)
 				}
-				defer closer.Close()
+				if err = stack.push("sshRemoteForward", closer); err != nil {
+					return err
+				}
 			}
 
 			stdoutFile, validOut := inv.Stdout.(*os.File)
@@ -795,3 +789,106 @@ func remoteGPGAgentSocket(sshClient *gossh.Client) (string, error) {
 
 	return string(bytes.TrimSpace(remoteSocket)), nil
 }
+
+type closerWithName struct {
+	name   string
+	closer io.Closer
+}
+
+type closerStack struct {
+	sync.Mutex
+	closers []closerWithName
+	closed  bool
+	logger  slog.Logger
+	err     error
+}
+
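+// newCloserStack returns a closerStack, which closes the closers pushed to it
+// in reverse (LIFO) order when close() is called or when the given ctx is done.
+// Pushing onto an already-closed stack closes the pushed closer immediately.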
+func newCloserStack(ctx context.Context, logger slog.Logger) *closerStack {
+	cs := &closerStack{logger: logger}
+	go cs.closeAfterContext(ctx)
+	return cs
+}
+
+func (c *closerStack) closeAfterContext(ctx context.Context) {
+	<-ctx.Done()
+	c.close(ctx.Err())
+}
+
+func (c *closerStack) close(err error) {
+	c.Lock()
+	if c.closed {
+		c.Unlock()
+		return
+	}
+	c.closed = true
+	c.err = err
+	c.Unlock()
+
+	for i := len(c.closers) - 1; i >= 0; i-- {
+		cwn := c.closers[i]
+		cErr := cwn.closer.Close()
+		c.logger.Debug(context.Background(),
+			"closed item from stack", slog.F("name", cwn.name), slog.Error(cErr))
+	}
+}
+
+func (c *closerStack) push(name string, closer io.Closer) error {
+	c.Lock()
+	if c.closed {
+		c.Unlock()
+		// since we're refusing to push it on the stack, close it now
+		err := closer.Close()
+		c.logger.Error(context.Background(),
+			"closed item rejected push", slog.F("name", name), slog.Error(err))
+		return xerrors.Errorf("already closed: %w", c.err)
+	}
+	c.closers = append(c.closers, closerWithName{name: name, closer: closer})
+	c.Unlock()
+	return nil
+}
+
+// rawSSHCopier handles copying raw SSH data between the conn and the pair (r, w).
+type rawSSHCopier struct {
+	conn   *gonet.TCPConn
+	logger slog.Logger
+	r      io.Reader
+	w      io.Writer
+}
+
+func (c *rawSSHCopier) copy(wg *sync.WaitGroup) {
+	logCtx := context.Background()
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		// We close connections using CloseWrite instead of Close, so that the SSH server sees the
+		// closed connection while reading, and shuts down cleanly. This will trigger the io.Copy
+		// in the server-to-client direction to also be closed and the copy() routine will exit.
+		// This ensures that we don't leave any state in the server, like forwarded ports if
+		// copy() were to return and the underlying tailnet connection torn down before the TCP
+		// session exits. This is a bit of a hack to block shut down at the application layer, since
+		// we can't serialize the TCP and tailnet layers shutting down.
+		//
+		// Of course, if the underlying transport is broken, io.Copy will still return.
+		defer func() {
+			cwErr := c.conn.CloseWrite()
+			c.logger.Debug(logCtx, "closed raw SSH connection for writing", slog.Error(cwErr))
+		}()
+
+		_, err := io.Copy(c.conn, c.r)
+		if err != nil {
+			c.logger.Error(logCtx, "copy stdin error", slog.Error(err))
+		} else {
+			c.logger.Debug(logCtx, "copy stdin complete")
+		}
+	}()
+	_, err := io.Copy(c.w, c.conn)
+	if err != nil {
+		c.logger.Error(logCtx, "copy stdout error", slog.Error(err))
+	} else {
+		c.logger.Debug(logCtx, "copy stdout complete")
+	}
+}
+
+func (c *rawSSHCopier) Close() error {
+	return c.conn.CloseWrite()
+}
diff --git a/cli/ssh_internal_test.go b/cli/ssh_internal_test.go
index 07a6a3c5802f2..3e3e116e95e5d 100644
--- a/cli/ssh_internal_test.go
+++ b/cli/ssh_internal_test.go
@@ -1,9 +1,16 @@
 package cli
 
 import (
+	"context"
 	"net/url"
 	"testing"
 
+	"golang.org/x/xerrors"
+
+	"cdr.dev/slog"
+	"cdr.dev/slog/sloggers/slogtest"
+	"github.com/coder/coder/v2/testutil"
+
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 
@@ -56,3 +63,77 @@ func TestBuildWorkspaceLink(t *testing.T) {
 
 	assert.Equal(t, workspaceLink.String(), fakeServerURL+"/@"+fakeOwnerName+"/"+fakeWorkspaceName)
 }
+
+func TestCloserStack_Mainline(t *testing.T) {
+	t.Parallel()
+	ctx := testutil.Context(t, testutil.WaitShort)
+	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
+	uut := newCloserStack(ctx, logger)
+	closes := new([]*fakeCloser)
+	fc0 := &fakeCloser{closes: closes}
+	fc1 := &fakeCloser{closes: closes}
+
+	func() {
+		defer uut.close(nil)
+		err := uut.push("fc0", fc0)
+		require.NoError(t, err)
+		err = uut.push("fc1", fc1)
+		require.NoError(t, err)
+	}()
+	// order reversed
+	require.Equal(t, []*fakeCloser{fc1, fc0}, *closes)
+}
+
+func TestCloserStack_Context(t *testing.T) {
+	t.Parallel()
+	ctx := testutil.Context(t, testutil.WaitShort)
+	ctx, cancel := context.WithCancel(ctx)
+	defer cancel()
+	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
+	uut := newCloserStack(ctx, logger)
+	closes := new([]*fakeCloser)
+	fc0 := &fakeCloser{closes: closes}
+	fc1 := &fakeCloser{closes: closes}
+
+	err := uut.push("fc0", fc0)
+	require.NoError(t, err)
+	err = uut.push("fc1", fc1)
+	require.NoError(t, err)
+	cancel()
+	require.Eventually(t, func() bool {
+		uut.Lock()
+		defer uut.Unlock()
+		return uut.closed
+	}, testutil.WaitShort, testutil.IntervalFast)
+}
+
+func TestCloserStack_PushAfterClose(t *testing.T) {
+	t.Parallel()
+	ctx := testutil.Context(t, testutil.WaitShort)
+	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
+	uut := newCloserStack(ctx, logger)
+	closes := new([]*fakeCloser)
+	fc0 := &fakeCloser{closes: closes}
+	fc1 := &fakeCloser{closes: closes}
+
+	err := uut.push("fc0", fc0)
+	require.NoError(t, err)
+
+	exErr := xerrors.New("test")
+	uut.close(exErr)
+	require.Equal(t, []*fakeCloser{fc0}, *closes)
+
+	err = uut.push("fc1", fc1)
+	require.ErrorIs(t, err, exErr)
+	require.Equal(t, []*fakeCloser{fc1, fc0}, *closes, "should close fc1")
+}
+
+type fakeCloser struct {
+	closes *[]*fakeCloser
+	err    error
+}
+
+func (c *fakeCloser) Close() error {
+	*c.closes = append(*c.closes, c)
+	return c.err
+}
diff --git a/cli/ssh_test.go b/cli/ssh_test.go
index 0abdfb4583ac2..fe911e6e7bde7 100644
--- a/cli/ssh_test.go
+++ b/cli/ssh_test.go
@@ -249,10 +249,125 @@ func TestSSH(t *testing.T) {
 		<-cmdDone
 	})
 
+	t.Run("Stdio_RemoteForward_Signal", func(t *testing.T) {
+		t.Parallel()
+		client, workspace, agentToken := setupWorkspaceForAgent(t, nil)
+		_, _ = tGoContext(t, func(ctx context.Context) {
+			// Run this async so the SSH command has to wait for
+			// the build and agent to connect!
+			_ = agenttest.New(t, client.URL, agentToken)
+			<-ctx.Done()
+		})
+
+		clientOutput, clientInput := io.Pipe()
+		serverOutput, serverInput := io.Pipe()
+		defer func() {
+			for _, c := range []io.Closer{clientOutput, clientInput, serverOutput, serverInput} {
+				_ = c.Close()
+			}
+		}()
+
+		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
+		defer cancel()
+
+		inv, root := clitest.New(t, "ssh", "--stdio", workspace.Name)
+		fsn := clitest.NewFakeSignalNotifier(t)
+		inv = inv.WithTestSignalNotifyContext(t, fsn.NotifyContext)
+		clitest.SetupConfig(t, client, root)
+		inv.Stdin = clientOutput
+		inv.Stdout = serverInput
+		inv.Stderr = io.Discard
+
+		cmdDone := tGo(t, func() {
+			err := inv.WithContext(ctx).Run()
+			assert.NoError(t, err)
+		})
+
+		conn, channels, requests, err := ssh.NewClientConn(&stdioConn{
+			Reader: serverOutput,
+			Writer: clientInput,
+		}, "", &ssh.ClientConfig{
+			// #nosec
+			HostKeyCallback: ssh.InsecureIgnoreHostKey(),
+		})
+		require.NoError(t, err)
+		defer conn.Close()
+
+		sshClient := ssh.NewClient(conn, channels, requests)
+
+		tmpdir := tempDirUnixSocket(t)
+
+		remoteSock := path.Join(tmpdir, "remote.sock")
+		_, err = sshClient.ListenUnix(remoteSock)
+		require.NoError(t, err)
+
+		fsn.Notify()
+		<-cmdDone
+		fsn.AssertStopped()
+		require.Eventually(t, func() bool {
+			_, err = os.Stat(remoteSock)
+			return xerrors.Is(err, os.ErrNotExist)
+		}, testutil.WaitShort, testutil.IntervalFast)
+	})
+
+	t.Run("Stdio_BrokenConn", func(t *testing.T) {
+		t.Parallel()
+		client, workspace, agentToken := setupWorkspaceForAgent(t, nil)
+		_, _ = tGoContext(t, func(ctx context.Context) {
+			// Run this async so the SSH command has to wait for
+			// the build and agent to connect!
+			_ = agenttest.New(t, client.URL, agentToken)
+			<-ctx.Done()
+		})
+
+		clientOutput, clientInput := io.Pipe()
+		serverOutput, serverInput := io.Pipe()
+		defer func() {
+			for _, c := range []io.Closer{clientOutput, clientInput, serverOutput, serverInput} {
+				_ = c.Close()
+			}
+		}()
+
+		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
+		defer cancel()
+
+		inv, root := clitest.New(t, "ssh", "--stdio", workspace.Name)
+		clitest.SetupConfig(t, client, root)
+		inv.Stdin = clientOutput
+		inv.Stdout = serverInput
+		inv.Stderr = io.Discard
+
+		cmdDone := tGo(t, func() {
+			err := inv.WithContext(ctx).Run()
+			assert.NoError(t, err)
+		})
+
+		conn, channels, requests, err := ssh.NewClientConn(&stdioConn{
+			Reader: serverOutput,
+			Writer: clientInput,
+		}, "", &ssh.ClientConfig{
+			// #nosec
+			HostKeyCallback: ssh.InsecureIgnoreHostKey(),
+		})
+		require.NoError(t, err)
+		defer conn.Close()
+
+		sshClient := ssh.NewClient(conn, channels, requests)
+		_ = serverOutput.Close()
+		_ = clientInput.Close()
+		select {
+		case <-cmdDone:
+			// OK
+		case <-time.After(testutil.WaitShort):
+			t.Error("timeout waiting for command to exit")
+		}
+
+		_ = sshClient.Close()
+	})
+
 	// Test that we handle OS signals properly while remote forwarding, and don't just leave the TCP
 	// socket hanging.
t.Run("RemoteForward_Unix_Signal", func(t *testing.T) { - t.Skip("still flaky") if runtime.GOOS == "windows" { t.Skip("No unix sockets on windows") } @@ -578,12 +693,13 @@ func TestSSH(t *testing.T) { l, err := net.Listen("unix", agentSock) require.NoError(t, err) defer l.Close() + remoteSock := filepath.Join(tmpdir, "remote.sock") inv, root := clitest.New(t, "ssh", workspace.Name, "--remote-forward", - "/tmp/test.sock:"+agentSock, + fmt.Sprintf("%s:%s", remoteSock, agentSock), ) clitest.SetupConfig(t, client, root) pty := ptytest.New(t).Attach(inv) @@ -598,7 +714,7 @@ func TestSSH(t *testing.T) { _ = pty.Peek(ctx, 1) // Download the test page - pty.WriteLine("ss -xl state listening src /tmp/test.sock | wc -l") + pty.WriteLine(fmt.Sprintf("ss -xl state listening src %s | wc -l", remoteSock)) pty.ExpectMatch("2") // And we're done. diff --git a/codersdk/workspaceagentconn.go b/codersdk/workspaceagentconn.go index e38b4f2a47f06..bbf2bdb8beab2 100644 --- a/codersdk/workspaceagentconn.go +++ b/codersdk/workspaceagentconn.go @@ -18,6 +18,7 @@ import ( "github.com/hashicorp/go-multierror" "golang.org/x/crypto/ssh" "golang.org/x/xerrors" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "tailscale.com/ipn/ipnstate" "tailscale.com/net/speedtest" @@ -249,7 +250,7 @@ func (c *WorkspaceAgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, // SSH pipes the SSH protocol over the returned net.Conn. // This connects to the built-in SSH server in the workspace agent. -func (c *WorkspaceAgentConn) SSH(ctx context.Context) (net.Conn, error) { +func (c *WorkspaceAgentConn) SSH(ctx context.Context) (*gonet.TCPConn, error) { ctx, span := tracing.StartSpan(ctx) defer span.End()