Skip to content

Commit e659957

Browse files
authored
fix(cli/ssh): prevent reads/writes to stdin/stdout in stdio mode (coder#12045)
Fixes coder#11530
1 parent 151aaad commit e659957

File tree

2 files changed

+170
-1
lines changed

2 files changed

+170
-1
lines changed

cli/ssh.go

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,14 @@ func (r *RootCmd) ssh() *clibase.Cmd {
8787
}
8888
}()
8989

90+
// In stdio mode, we can't allow any writes to stdin or stdout
91+
// because they are used by the SSH protocol.
92+
stdioReader, stdioWriter := inv.Stdin, inv.Stdout
93+
if stdio {
94+
inv.Stdin = stdioErrLogReader{inv.Logger}
95+
inv.Stdout = inv.Stderr
96+
}
97+
9098
// This WaitGroup solves for a race condition where we were logging
9199
// while closing the log file in a defer. It probably solves
92100
// others too.
@@ -234,7 +242,7 @@ func (r *RootCmd) ssh() *clibase.Cmd {
234242
if err != nil {
235243
return xerrors.Errorf("connect SSH: %w", err)
236244
}
237-
copier := newRawSSHCopier(logger, rawSSH, inv.Stdin, inv.Stdout)
245+
copier := newRawSSHCopier(logger, rawSSH, stdioReader, stdioWriter)
238246
if err = stack.push("rawSSHCopier", copier); err != nil {
239247
return err
240248
}
@@ -987,3 +995,12 @@ func sshDisableAutostartOption(src *clibase.Bool) clibase.Option {
987995
Default: "false",
988996
}
989997
}
998+
999+
type stdioErrLogReader struct {
1000+
l slog.Logger
1001+
}
1002+
1003+
func (r stdioErrLogReader) Read(_ []byte) (int, error) {
1004+
r.l.Error(context.Background(), "reading from stdin in stdio mode is not allowed")
1005+
return 0, io.EOF
1006+
}

cli/ssh_test.go

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package cli_test
22

33
import (
4+
"bufio"
45
"bytes"
56
"context"
67
"crypto/ecdsa"
@@ -338,6 +339,157 @@ func TestSSH(t *testing.T) {
338339
<-cmdDone
339340
})
340341

342+
t.Run("Stdio_StartStoppedWorkspace_CleanStdout", func(t *testing.T) {
343+
t.Parallel()
344+
345+
authToken := uuid.NewString()
346+
ownerClient := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
347+
owner := coderdtest.CreateFirstUser(t, ownerClient)
348+
client, _ := coderdtest.CreateAnotherUser(t, ownerClient, owner.OrganizationID, rbac.RoleTemplateAdmin())
349+
version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, &echo.Responses{
350+
Parse: echo.ParseComplete,
351+
ProvisionPlan: echo.PlanComplete,
352+
ProvisionApply: echo.ProvisionApplyWithAgent(authToken),
353+
})
354+
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
355+
template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID)
356+
workspace := coderdtest.CreateWorkspace(t, client, owner.OrganizationID, template.ID)
357+
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
358+
// Stop the workspace
359+
workspaceBuild := coderdtest.CreateWorkspaceBuild(t, client, workspace, database.WorkspaceTransitionStop)
360+
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspaceBuild.ID)
361+
362+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
363+
defer cancel()
364+
365+
clientStdinR, clientStdinW := io.Pipe()
366+
// Here's a simple flowchart for how these pipes are used:
367+
//
368+
// flowchart LR
369+
// A[ProxyCommand] --> B[captureProxyCommandStdoutW]
370+
// B --> C[captureProxyCommandStdoutR]
371+
// C --> VA[Validate output]
372+
// C --> D[proxyCommandStdoutW]
373+
// D --> E[proxyCommandStdoutR]
374+
// E --> F[SSH Client]
375+
proxyCommandStdoutR, proxyCommandStdoutW := io.Pipe()
376+
captureProxyCommandStdoutR, captureProxyCommandStdoutW := io.Pipe()
377+
closePipes := func() {
378+
for _, c := range []io.Closer{clientStdinR, clientStdinW, proxyCommandStdoutR, proxyCommandStdoutW, captureProxyCommandStdoutR, captureProxyCommandStdoutW} {
379+
_ = c.Close()
380+
}
381+
}
382+
defer closePipes()
383+
tGo(t, func() {
384+
<-ctx.Done()
385+
closePipes()
386+
})
387+
388+
// Here we start a monitor for the output produced by the proxy command,
389+
// which is read by the SSH client. This is done to validate that the
390+
// output is clean.
391+
proxyCommandOutputBuf := make(chan byte, 4096)
392+
tGo(t, func() {
393+
defer close(proxyCommandOutputBuf)
394+
395+
gotHeader := false
396+
buf := bytes.Buffer{}
397+
r := bufio.NewReader(captureProxyCommandStdoutR)
398+
for {
399+
b, err := r.ReadByte()
400+
if err != nil {
401+
if errors.Is(err, io.ErrClosedPipe) {
402+
return
403+
}
404+
assert.NoError(t, err, "read byte failed")
405+
return
406+
}
407+
if b == '\n' || b == '\r' {
408+
out := buf.Bytes()
409+
t.Logf("monitorServerOutput: %q (%#x)", out, out)
410+
buf.Reset()
411+
412+
// Ideally we would do further verification, but that would
413+
// involve parsing the SSH protocol to look for output that
414+
// doesn't belong. This at least ensures that no garbage is
415+
// being sent to the SSH client before trying to connect.
416+
if !gotHeader {
417+
gotHeader = true
418+
assert.Equal(t, "SSH-2.0-Go", string(out), "invalid header")
419+
}
420+
} else {
421+
_ = buf.WriteByte(b)
422+
}
423+
select {
424+
case proxyCommandOutputBuf <- b:
425+
case <-ctx.Done():
426+
return
427+
}
428+
}
429+
})
430+
tGo(t, func() {
431+
defer proxyCommandStdoutW.Close()
432+
433+
// Range closed by above goroutine.
434+
for b := range proxyCommandOutputBuf {
435+
_, err := proxyCommandStdoutW.Write([]byte{b})
436+
if err != nil {
437+
if errors.Is(err, io.ErrClosedPipe) {
438+
return
439+
}
440+
assert.NoError(t, err, "write byte failed")
441+
return
442+
}
443+
}
444+
})
445+
446+
// Start the SSH stdio command.
447+
inv, root := clitest.New(t, "ssh", "--stdio", workspace.Name)
448+
clitest.SetupConfig(t, client, root)
449+
inv.Stdin = clientStdinR
450+
inv.Stdout = captureProxyCommandStdoutW
451+
inv.Stderr = io.Discard
452+
453+
cmdDone := tGo(t, func() {
454+
err := inv.WithContext(ctx).Run()
455+
assert.NoError(t, err)
456+
})
457+
458+
tGo(t, func() {
459+
// When the agent connects, the workspace was started, and we should
460+
// have access to the shell.
461+
_ = agenttest.New(t, client.URL, authToken)
462+
coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).Wait()
463+
})
464+
465+
conn, channels, requests, err := ssh.NewClientConn(&stdioConn{
466+
Reader: proxyCommandStdoutR,
467+
Writer: clientStdinW,
468+
}, "", &ssh.ClientConfig{
469+
// #nosec
470+
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
471+
})
472+
require.NoError(t, err)
473+
defer conn.Close()
474+
475+
sshClient := ssh.NewClient(conn, channels, requests)
476+
session, err := sshClient.NewSession()
477+
require.NoError(t, err)
478+
defer session.Close()
479+
480+
command := "sh -c exit"
481+
if runtime.GOOS == "windows" {
482+
command = "cmd.exe /c exit"
483+
}
484+
err = session.Run(command)
485+
require.NoError(t, err)
486+
err = sshClient.Close()
487+
require.NoError(t, err)
488+
_ = clientStdinR.Close()
489+
490+
<-cmdDone
491+
})
492+
341493
t.Run("Stdio_RemoteForward_Signal", func(t *testing.T) {
342494
t.Parallel()
343495
client, workspace, agentToken := setupWorkspaceForAgent(t)

0 commit comments

Comments
 (0)