diff --git a/cli/ssh.go b/cli/ssh.go index aae28b76a03ff..bdc5d98b3c9c0 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -87,6 +87,14 @@ func (r *RootCmd) ssh() *clibase.Cmd { } }() + // In stdio mode, we can't allow any writes to stdin or stdout + // because they are used by the SSH protocol. + stdioReader, stdioWriter := inv.Stdin, inv.Stdout + if stdio { + inv.Stdin = stdioErrLogReader{inv.Logger} + inv.Stdout = inv.Stderr + } + // This WaitGroup solves for a race condition where we were logging // while closing the log file in a defer. It probably solves // others too. @@ -234,7 +242,7 @@ func (r *RootCmd) ssh() *clibase.Cmd { if err != nil { return xerrors.Errorf("connect SSH: %w", err) } - copier := newRawSSHCopier(logger, rawSSH, inv.Stdin, inv.Stdout) + copier := newRawSSHCopier(logger, rawSSH, stdioReader, stdioWriter) if err = stack.push("rawSSHCopier", copier); err != nil { return err } @@ -987,3 +995,12 @@ func sshDisableAutostartOption(src *clibase.Bool) clibase.Option { Default: "false", } } + +type stdioErrLogReader struct { + l slog.Logger +} + +func (r stdioErrLogReader) Read(_ []byte) (int, error) { + r.l.Error(context.Background(), "reading from stdin in stdio mode is not allowed") + return 0, io.EOF +} diff --git a/cli/ssh_test.go b/cli/ssh_test.go index d36df6218ed66..ee7cefbee58d7 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -1,6 +1,7 @@ package cli_test import ( + "bufio" "bytes" "context" "crypto/ecdsa" @@ -338,6 +339,157 @@ func TestSSH(t *testing.T) { <-cmdDone }) + t.Run("Stdio_StartStoppedWorkspace_CleanStdout", func(t *testing.T) { + t.Parallel() + + authToken := uuid.NewString() + ownerClient := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) + owner := coderdtest.CreateFirstUser(t, ownerClient) + client, _ := coderdtest.CreateAnotherUser(t, ownerClient, owner.OrganizationID, rbac.RoleTemplateAdmin()) + version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, &echo.Responses{ + Parse: echo.ParseComplete, + ProvisionPlan: echo.PlanComplete, + ProvisionApply: echo.ProvisionApplyWithAgent(authToken), + }) + coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID) + template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID) + workspace := coderdtest.CreateWorkspace(t, client, owner.OrganizationID, template.ID) + coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) + // Stop the workspace + workspaceBuild := coderdtest.CreateWorkspaceBuild(t, client, workspace, database.WorkspaceTransitionStop) + coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspaceBuild.ID) + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + clientStdinR, clientStdinW := io.Pipe() + // Here's a simple flowchart for how these pipes are used: + // + // flowchart LR + // A[ProxyCommand] --> B[captureProxyCommandStdoutW] + // B --> C[captureProxyCommandStdoutR] + // C --> VA[Validate output] + // C --> D[proxyCommandStdoutW] + // D --> E[proxyCommandStdoutR] + // E --> F[SSH Client] + proxyCommandStdoutR, proxyCommandStdoutW := io.Pipe() + captureProxyCommandStdoutR, captureProxyCommandStdoutW := io.Pipe() + closePipes := func() { + for _, c := range []io.Closer{clientStdinR, clientStdinW, proxyCommandStdoutR, proxyCommandStdoutW, captureProxyCommandStdoutR, captureProxyCommandStdoutW} { + _ = c.Close() + } + } + defer closePipes() + tGo(t, func() { + <-ctx.Done() + closePipes() + }) + + // Here we start a monitor for the output produced by the proxy command, + // which is read by the SSH client. This is done to validate that the + // output is clean. + proxyCommandOutputBuf := make(chan byte, 4096) + tGo(t, func() { + defer close(proxyCommandOutputBuf) + + gotHeader := false + buf := bytes.Buffer{} + r := bufio.NewReader(captureProxyCommandStdoutR) + for { + b, err := r.ReadByte() + if err != nil { + if errors.Is(err, io.ErrClosedPipe) { + return + } + assert.NoError(t, err, "read byte failed") + return + } + if b == '\n' || b == '\r' { + out := buf.Bytes() + t.Logf("monitorServerOutput: %q (%#x)", out, out) + buf.Reset() + + // Ideally we would do further verification, but that would + // involve parsing the SSH protocol to look for output that + // doesn't belong. This at least ensures that no garbage is + // being sent to the SSH client before trying to connect. + if !gotHeader { + gotHeader = true + assert.Equal(t, "SSH-2.0-Go", string(out), "invalid header") + } + } else { + _ = buf.WriteByte(b) + } + select { + case proxyCommandOutputBuf <- b: + case <-ctx.Done(): + return + } + } + }) + tGo(t, func() { + defer proxyCommandStdoutW.Close() + + // Range closed by above goroutine. + for b := range proxyCommandOutputBuf { + _, err := proxyCommandStdoutW.Write([]byte{b}) + if err != nil { + if errors.Is(err, io.ErrClosedPipe) { + return + } + assert.NoError(t, err, "write byte failed") + return + } + } + }) + + // Start the SSH stdio command. + inv, root := clitest.New(t, "ssh", "--stdio", workspace.Name) + clitest.SetupConfig(t, client, root) + inv.Stdin = clientStdinR + inv.Stdout = captureProxyCommandStdoutW + inv.Stderr = io.Discard + + cmdDone := tGo(t, func() { + err := inv.WithContext(ctx).Run() + assert.NoError(t, err) + }) + + tGo(t, func() { + // When the agent connects, the workspace was started, and we should + // have access to the shell. + _ = agenttest.New(t, client.URL, authToken) + coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).Wait() + }) + + conn, channels, requests, err := ssh.NewClientConn(&stdioConn{ + Reader: proxyCommandStdoutR, + Writer: clientStdinW, + }, "", &ssh.ClientConfig{ + // #nosec + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + }) + require.NoError(t, err) + defer conn.Close() + + sshClient := ssh.NewClient(conn, channels, requests) + session, err := sshClient.NewSession() + require.NoError(t, err) + defer session.Close() + + command := "sh -c exit" + if runtime.GOOS == "windows" { + command = "cmd.exe /c exit" + } + err = session.Run(command) + require.NoError(t, err) + err = sshClient.Close() + require.NoError(t, err) + _ = clientStdinR.Close() + + <-cmdDone + }) + t.Run("Stdio_RemoteForward_Signal", func(t *testing.T) { t.Parallel() client, workspace, agentToken := setupWorkspaceForAgent(t)