Skip to content

Commit dd1f7c4

Browse files
committed
fix(cli/ssh): prevent stdin/stdout reads/writes in stdio mode
Fixes #11530
1 parent f2aef07 commit dd1f7c4

File tree

2 files changed

+160
-1
lines changed

2 files changed

+160
-1
lines changed

cli/ssh.go

+18-1
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,14 @@ func (r *RootCmd) ssh() *clibase.Cmd {
8787
}
8888
}()
8989

90+
// In stdio mode, we can't allow any writes to stdin or stdout
91+
// because they are used by the SSH protocol.
92+
stdioReader, stdioWriter := inv.Stdin, inv.Stdout
93+
if stdio {
94+
inv.Stdin = stdioErrLogReader{inv.Logger}
95+
inv.Stdout = inv.Stderr
96+
}
97+
9098
// This WaitGroup solves for a race condition where we were logging
9199
// while closing the log file in a defer. It probably solves
92100
// others too.
@@ -234,7 +242,7 @@ func (r *RootCmd) ssh() *clibase.Cmd {
234242
if err != nil {
235243
return xerrors.Errorf("connect SSH: %w", err)
236244
}
237-
copier := newRawSSHCopier(logger, rawSSH, inv.Stdin, inv.Stdout)
245+
copier := newRawSSHCopier(logger, rawSSH, stdioReader, stdioWriter)
238246
if err = stack.push("rawSSHCopier", copier); err != nil {
239247
return err
240248
}
@@ -987,3 +995,12 @@ func sshDisableAutostartOption(src *clibase.Bool) clibase.Option {
987995
Default: "false",
988996
}
989997
}
998+
999+
type stdioErrLogReader struct {
1000+
l slog.Logger
1001+
}
1002+
1003+
func (r stdioErrLogReader) Read(_ []byte) (int, error) {
1004+
r.l.Error(context.Background(), "reading from stdin in stdio mode is not allowed")
1005+
return 0, io.EOF
1006+
}

cli/ssh_test.go

+142
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package cli_test
22

33
import (
4+
"bufio"
45
"bytes"
56
"context"
67
"crypto/ecdsa"
@@ -338,6 +339,147 @@ func TestSSH(t *testing.T) {
338339
<-cmdDone
339340
})
340341

342+
t.Run("Stdio_StartStoppedWorkspace_CleanStdout", func(t *testing.T) {
343+
t.Parallel()
344+
345+
authToken := uuid.NewString()
346+
ownerClient := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
347+
owner := coderdtest.CreateFirstUser(t, ownerClient)
348+
client, _ := coderdtest.CreateAnotherUser(t, ownerClient, owner.OrganizationID, rbac.RoleTemplateAdmin())
349+
version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, &echo.Responses{
350+
Parse: echo.ParseComplete,
351+
ProvisionPlan: echo.PlanComplete,
352+
ProvisionApply: echo.ProvisionApplyWithAgent(authToken),
353+
})
354+
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
355+
template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID)
356+
workspace := coderdtest.CreateWorkspace(t, client, owner.OrganizationID, template.ID)
357+
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
358+
// Stop the workspace
359+
workspaceBuild := coderdtest.CreateWorkspaceBuild(t, client, workspace, database.WorkspaceTransitionStop)
360+
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspaceBuild.ID)
361+
362+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
363+
defer cancel()
364+
365+
clientOutput, clientInput := io.Pipe()
366+
serverOutput, serverInput := io.Pipe()
367+
monitorServerOutput, monitorServerInput := io.Pipe()
368+
closePipes := func() {
369+
for _, c := range []io.Closer{clientOutput, clientInput, serverOutput, serverInput, monitorServerOutput, monitorServerInput} {
370+
_ = c.Close()
371+
}
372+
}
373+
defer closePipes()
374+
tGo(t, func() {
375+
<-ctx.Done()
376+
closePipes()
377+
})
378+
379+
// Here we start a monitor for the input going to the server
380+
// (i.e. client stdout) to ensure that the output is clean.
381+
serverInputBuf := make(chan byte, 4096)
382+
tGo(t, func() {
383+
defer close(serverInputBuf)
384+
385+
gotHeader := false
386+
buf := bytes.Buffer{}
387+
r := bufio.NewReader(monitorServerOutput)
388+
for {
389+
b, err := r.ReadByte()
390+
if err != nil {
391+
if errors.Is(err, io.ErrClosedPipe) {
392+
return
393+
}
394+
assert.NoError(t, err, "read byte failed")
395+
return
396+
}
397+
if b == '\n' || b == '\r' {
398+
out := buf.Bytes()
399+
t.Logf("monitorServerOutput: %q (%#x)", out, out)
400+
buf.Reset()
401+
402+
// Ideally we would do further verification, but that would
403+
// involve parsing the SSH protocol to look for output that
404+
// doesn't belong. This at least ensures that no garbage is
405+
// being sent to the server before trying to connect.
406+
if !gotHeader {
407+
gotHeader = true
408+
assert.Equal(t, "SSH-2.0-Go", string(out), "invalid header")
409+
}
410+
} else {
411+
_ = buf.WriteByte(b)
412+
}
413+
select {
414+
case serverInputBuf <- b:
415+
case <-ctx.Done():
416+
return
417+
}
418+
}
419+
})
420+
tGo(t, func() {
421+
defer serverInput.Close()
422+
423+
// Range closed by above goroutine.
424+
for b := range serverInputBuf {
425+
_, err := serverInput.Write([]byte{b})
426+
if err != nil {
427+
if errors.Is(err, io.ErrClosedPipe) {
428+
return
429+
}
430+
assert.NoError(t, err, "write byte failed")
431+
return
432+
}
433+
}
434+
})
435+
436+
// Start the SSH stdio command.
437+
inv, root := clitest.New(t, "ssh", "--stdio", workspace.Name)
438+
clitest.SetupConfig(t, client, root)
439+
inv.Stdin = clientOutput
440+
inv.Stdout = monitorServerInput
441+
inv.Stderr = io.Discard
442+
443+
cmdDone := tGo(t, func() {
444+
err := inv.WithContext(ctx).Run()
445+
assert.NoError(t, err)
446+
})
447+
448+
tGo(t, func() {
449+
// When the agent connects, the workspace was started, and we should
450+
// have access to the shell.
451+
_ = agenttest.New(t, client.URL, authToken)
452+
coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).Wait()
453+
})
454+
455+
conn, channels, requests, err := ssh.NewClientConn(&stdioConn{
456+
Reader: serverOutput,
457+
Writer: clientInput,
458+
}, "", &ssh.ClientConfig{
459+
// #nosec
460+
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
461+
})
462+
require.NoError(t, err)
463+
defer conn.Close()
464+
465+
sshClient := ssh.NewClient(conn, channels, requests)
466+
session, err := sshClient.NewSession()
467+
require.NoError(t, err)
468+
defer session.Close()
469+
470+
command := "sh -c exit"
471+
if runtime.GOOS == "windows" {
472+
command = "cmd.exe /c exit"
473+
}
474+
err = session.Run(command)
475+
require.NoError(t, err)
476+
err = sshClient.Close()
477+
require.NoError(t, err)
478+
_ = clientOutput.Close()
479+
480+
<-cmdDone
481+
})
482+
341483
t.Run("Stdio_RemoteForward_Signal", func(t *testing.T) {
342484
t.Parallel()
343485
client, workspace, agentToken := setupWorkspaceForAgent(t)

0 commit comments

Comments
 (0)