From dd1f7c49a60cff3ec1b0d962b0f8afe9c9d520b7 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Wed, 7 Feb 2024 12:27:15 +0200 Subject: [PATCH 1/4] fix(cli/ssh): prevent stdin/stdout reads/writes in stdio mode Fixes #11530 --- cli/ssh.go | 19 ++++++- cli/ssh_test.go | 142 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 160 insertions(+), 1 deletion(-) diff --git a/cli/ssh.go b/cli/ssh.go index aae28b76a03ff..bdc5d98b3c9c0 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -87,6 +87,14 @@ func (r *RootCmd) ssh() *clibase.Cmd { } }() + // In stdio mode, we can't allow any writes to stdin or stdout + // because they are used by the SSH protocol. + stdioReader, stdioWriter := inv.Stdin, inv.Stdout + if stdio { + inv.Stdin = stdioErrLogReader{inv.Logger} + inv.Stdout = inv.Stderr + } + // This WaitGroup solves for a race condition where we were logging // while closing the log file in a defer. It probably solves // others too. @@ -234,7 +242,7 @@ func (r *RootCmd) ssh() *clibase.Cmd { if err != nil { return xerrors.Errorf("connect SSH: %w", err) } - copier := newRawSSHCopier(logger, rawSSH, inv.Stdin, inv.Stdout) + copier := newRawSSHCopier(logger, rawSSH, stdioReader, stdioWriter) if err = stack.push("rawSSHCopier", copier); err != nil { return err } @@ -987,3 +995,12 @@ func sshDisableAutostartOption(src *clibase.Bool) clibase.Option { Default: "false", } } + +type stdioErrLogReader struct { + l slog.Logger +} + +func (r stdioErrLogReader) Read(_ []byte) (int, error) { + r.l.Error(context.Background(), "reading from stdin in stdio mode is not allowed") + return 0, io.EOF +} diff --git a/cli/ssh_test.go b/cli/ssh_test.go index d36df6218ed66..1c695602bdcdd 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -1,6 +1,7 @@ package cli_test import ( + "bufio" "bytes" "context" "crypto/ecdsa" @@ -338,6 +339,147 @@ func TestSSH(t *testing.T) { <-cmdDone }) + t.Run("Stdio_StartStoppedWorkspace_CleanStdout", func(t *testing.T) { + t.Parallel() + + authToken := uuid.NewString() + ownerClient := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) + owner := coderdtest.CreateFirstUser(t, ownerClient) + client, _ := coderdtest.CreateAnotherUser(t, ownerClient, owner.OrganizationID, rbac.RoleTemplateAdmin()) + version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, &echo.Responses{ + Parse: echo.ParseComplete, + ProvisionPlan: echo.PlanComplete, + ProvisionApply: echo.ProvisionApplyWithAgent(authToken), + }) + coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID) + template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID) + workspace := coderdtest.CreateWorkspace(t, client, owner.OrganizationID, template.ID) + coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) + // Stop the workspace + workspaceBuild := coderdtest.CreateWorkspaceBuild(t, client, workspace, database.WorkspaceTransitionStop) + coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspaceBuild.ID) + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + clientOutput, clientInput := io.Pipe() + serverOutput, serverInput := io.Pipe() + monitorServerOutput, monitorServerInput := io.Pipe() + closePipes := func() { + for _, c := range []io.Closer{clientOutput, clientInput, serverOutput, serverInput, monitorServerOutput, monitorServerInput} { + _ = c.Close() + } + } + defer closePipes() + tGo(t, func() { + <-ctx.Done() + closePipes() + }) + + // Here we start a monitor for the input going to the server + // (i.e. client stdout) to ensure that the output is clean. + serverInputBuf := make(chan byte, 4096) + tGo(t, func() { + defer close(serverInputBuf) + + gotHeader := false + buf := bytes.Buffer{} + r := bufio.NewReader(monitorServerOutput) + for { + b, err := r.ReadByte() + if err != nil { + if errors.Is(err, io.ErrClosedPipe) { + return + } + assert.NoError(t, err, "read byte failed") + return + } + if b == '\n' || b == '\r' { + out := buf.Bytes() + t.Logf("monitorServerOutput: %q (%#x)", out, out) + buf.Reset() + + // Ideally we would do further verification, but that would + // involve parsing the SSH protocol to look for output that + // doesn't belong. This at least ensures that no garbage is + // being sent to the server before trying to connect. + if !gotHeader { + gotHeader = true + assert.Equal(t, "SSH-2.0-Go", string(out), "invalid header") + } + } else { + _ = buf.WriteByte(b) + } + select { + case serverInputBuf <- b: + case <-ctx.Done(): + return + } + } + }) + tGo(t, func() { + defer serverInput.Close() + + // Range closed by above goroutine. + for b := range serverInputBuf { + _, err := serverInput.Write([]byte{b}) + if err != nil { + if errors.Is(err, io.ErrClosedPipe) { + return + } + assert.NoError(t, err, "write byte failed") + return + } + } + }) + + // Start the SSH stdio command. + inv, root := clitest.New(t, "ssh", "--stdio", workspace.Name) + clitest.SetupConfig(t, client, root) + inv.Stdin = clientOutput + inv.Stdout = monitorServerInput + inv.Stderr = io.Discard + + cmdDone := tGo(t, func() { + err := inv.WithContext(ctx).Run() + assert.NoError(t, err) + }) + + tGo(t, func() { + // When the agent connects, the workspace was started, and we should + // have access to the shell. + _ = agenttest.New(t, client.URL, authToken) + coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).Wait() + }) + + conn, channels, requests, err := ssh.NewClientConn(&stdioConn{ + Reader: serverOutput, + Writer: clientInput, + }, "", &ssh.ClientConfig{ + // #nosec + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + }) + require.NoError(t, err) + defer conn.Close() + + sshClient := ssh.NewClient(conn, channels, requests) + session, err := sshClient.NewSession() + require.NoError(t, err) + defer session.Close() + + command := "sh -c exit" + if runtime.GOOS == "windows" { + command = "cmd.exe /c exit" + } + err = session.Run(command) + require.NoError(t, err) + err = sshClient.Close() + require.NoError(t, err) + _ = clientOutput.Close() + + <-cmdDone + }) + t.Run("Stdio_RemoteForward_Signal", func(t *testing.T) { t.Parallel() client, workspace, agentToken := setupWorkspaceForAgent(t) From 80f9ebb61dd2411cb0fb0ee46f3184f699cc6325 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Wed, 7 Feb 2024 13:39:42 +0000 Subject: [PATCH 2/4] improve naming and add pipe flowchart --- cli/ssh_test.go | 43 ++++++++++++++++++++++++++----------------- 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/cli/ssh_test.go b/cli/ssh_test.go index 1c695602bdcdd..ed1f14cc856f9 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -362,11 +362,20 @@ func TestSSH(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - clientOutput, clientInput := io.Pipe() - serverOutput, serverInput := io.Pipe() - monitorServerOutput, monitorServerInput := io.Pipe() + clientStdinR, clientStdinW := io.Pipe() + // Here's a simple flowchart for how these pipes are used: + // + // flowchart LR + // A[ProxyCommand] --> B[captureProxyCommandStdoutW] + // B --> C[captureProxyCommandStdoutR] + // C --> VA[Validate output] + // C --> D[proxyCommandOutputW] + // D --> E[proxyCommandOutputR] + // E --> F[SSH Client] + proxyCommandOutputR, proxyCommandOutputW := io.Pipe() + captureProxyCommandStdoutR, captureProxyCommandStdoutW := io.Pipe() closePipes := func() { - for _, c := range []io.Closer{clientOutput, clientInput, serverOutput, serverInput, monitorServerOutput, monitorServerInput} { + for _, c := range []io.Closer{clientStdinR, clientStdinW, proxyCommandOutputR, proxyCommandOutputW, captureProxyCommandStdoutR, captureProxyCommandStdoutW} { _ = c.Close() } } @@ -378,13 +387,13 @@ func TestSSH(t *testing.T) { // Here we start a monitor for the input going to the server // (i.e. client stdout) to ensure that the output is clean. - serverInputBuf := make(chan byte, 4096) + proxyCommandOutputBuf := make(chan byte, 4096) tGo(t, func() { - defer close(serverInputBuf) + defer close(proxyCommandOutputBuf) gotHeader := false buf := bytes.Buffer{} - r := bufio.NewReader(monitorServerOutput) + r := bufio.NewReader(captureProxyCommandStdoutR) for { b, err := r.ReadByte() if err != nil { @@ -402,7 +411,7 @@ func TestSSH(t *testing.T) { // Ideally we would do further verification, but that would // involve parsing the SSH protocol to look for output that // doesn't belong. This at least ensures that no garbage is - // being sent to the server before trying to connect. + // being sent to the SSH client before trying to connect. if !gotHeader { gotHeader = true assert.Equal(t, "SSH-2.0-Go", string(out), "invalid header") @@ -411,18 +420,18 @@ func TestSSH(t *testing.T) { _ = buf.WriteByte(b) } select { - case serverInputBuf <- b: + case proxyCommandOutputBuf <- b: case <-ctx.Done(): return } } }) tGo(t, func() { - defer serverInput.Close() + defer proxyCommandOutputW.Close() // Range closed by above goroutine. - for b := range serverInputBuf { - _, err := serverInput.Write([]byte{b}) + for b := range proxyCommandOutputBuf { + _, err := proxyCommandOutputW.Write([]byte{b}) if err != nil { if errors.Is(err, io.ErrClosedPipe) { return @@ -436,8 +445,8 @@ func TestSSH(t *testing.T) { // Start the SSH stdio command. inv, root := clitest.New(t, "ssh", "--stdio", workspace.Name) clitest.SetupConfig(t, client, root) - inv.Stdin = clientOutput - inv.Stdout = monitorServerInput + inv.Stdin = clientStdinR + inv.Stdout = captureProxyCommandStdoutW inv.Stderr = io.Discard cmdDone := tGo(t, func() { @@ -453,8 +462,8 @@ func TestSSH(t *testing.T) { }) conn, channels, requests, err := ssh.NewClientConn(&stdioConn{ - Reader: serverOutput, - Writer: clientInput, + Reader: proxyCommandOutputR, + Writer: clientStdinW, }, "", &ssh.ClientConfig{ // #nosec HostKeyCallback: ssh.InsecureIgnoreHostKey(), @@ -475,7 +484,7 @@ func TestSSH(t *testing.T) { require.NoError(t, err) err = sshClient.Close() require.NoError(t, err) - _ = clientOutput.Close() + _ = clientStdinR.Close() <-cmdDone }) From 6df02c1bb434e7ffbfd79c0c82f1cdecd6e9d825 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Wed, 7 Feb 2024 13:43:25 +0000 Subject: [PATCH 3/4] missed some names --- cli/ssh_test.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/cli/ssh_test.go b/cli/ssh_test.go index ed1f14cc856f9..0a70cc62bff0f 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -369,13 +369,13 @@ func TestSSH(t *testing.T) { // A[ProxyCommand] --> B[captureProxyCommandStdoutW] // B --> C[captureProxyCommandStdoutR] // C --> VA[Validate output] - // C --> D[proxyCommandOutputW] - // D --> E[proxyCommandOutputR] + // C --> D[proxyCommandStdoutW] + // D --> E[proxyCommandStdoutR] // E --> F[SSH Client] - proxyCommandOutputR, proxyCommandOutputW := io.Pipe() + proxyCommandStdoutR, proxyCommandStdoutW := io.Pipe() captureProxyCommandStdoutR, captureProxyCommandStdoutW := io.Pipe() closePipes := func() { - for _, c := range []io.Closer{clientStdinR, clientStdinW, proxyCommandOutputR, proxyCommandOutputW, captureProxyCommandStdoutR, captureProxyCommandStdoutW} { + for _, c := range []io.Closer{clientStdinR, clientStdinW, proxyCommandStdoutR, proxyCommandStdoutW, captureProxyCommandStdoutR, captureProxyCommandStdoutW} { _ = c.Close() } } @@ -427,11 +427,11 @@ func TestSSH(t *testing.T) { } }) tGo(t, func() { - defer proxyCommandOutputW.Close() + defer proxyCommandStdoutW.Close() // Range closed by above goroutine. for b := range proxyCommandOutputBuf { - _, err := proxyCommandOutputW.Write([]byte{b}) + _, err := proxyCommandStdoutW.Write([]byte{b}) if err != nil { if errors.Is(err, io.ErrClosedPipe) { return @@ -462,7 +462,7 @@ func TestSSH(t *testing.T) { }) conn, channels, requests, err := ssh.NewClientConn(&stdioConn{ - Reader: proxyCommandOutputR, + Reader: proxyCommandStdoutR, Writer: clientStdinW, }, "", &ssh.ClientConfig{ // #nosec From 7def32061d8ab1fb1006308e9cf8ae9b9333efc1 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Wed, 7 Feb 2024 13:45:45 +0000 Subject: [PATCH 4/4] fix comment --- cli/ssh_test.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cli/ssh_test.go b/cli/ssh_test.go index 0a70cc62bff0f..ee7cefbee58d7 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -385,8 +385,9 @@ func TestSSH(t *testing.T) { closePipes() }) - // Here we start a monitor for the input going to the server - // (i.e. client stdout) to ensure that the output is clean. + // Here we start a monitor for the output produced by the proxy command, + // which is read by the SSH client. This is done to validate that the + // output is clean. proxyCommandOutputBuf := make(chan byte, 4096) tGo(t, func() { defer close(proxyCommandOutputBuf)