Skip to content

Commit 8f0098e

Browse files
committed
test: add test that we close stdin on SSH session close
1 parent 0b82f41 commit 8f0098e

File tree

2 files changed

+80
-1
lines changed

2 files changed

+80
-1
lines changed

agent/agentssh/agentssh.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -609,7 +609,9 @@ func (s *Server) startNonPTYSession(logger slog.Logger, session ssh.Session, mag
609609
// and SSH server close may be delayed.
610610
cmd.SysProcAttr = cmdSysProcAttr()
611611

612-
// to match OpenSSH, we don't actually tear a non-TTY command down, even if the session ends.
612+
// to match OpenSSH, we don't actually tear a non-TTY command down, even if the session ends. OpenSSH closes the
613+
// pipes to the process when the sesion ends; which is what happens here since we wire the command up to the
614+
// session for I/O.
613615
// c.f. https://github.com/coder/coder/issues/18519#issuecomment-3019118271
614616
cmd.Cancel = nil
615617

agent/agentssh/agentssh_test.go

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ import (
88
"context"
99
"fmt"
1010
"net"
11+
"os"
1112
"os/user"
13+
"path/filepath"
1214
"runtime"
1315
"strings"
1416
"sync"
@@ -403,6 +405,81 @@ func TestNewServer_Signal(t *testing.T) {
403405
})
404406
}
405407

408+
func TestSSHServer_ClosesStdin(t *testing.T) {
409+
t.Parallel()
410+
if runtime.GOOS == "windows" {
411+
t.Skip("bash doesn't exist on Windows")
412+
}
413+
414+
ctx := testutil.Context(t, testutil.WaitMedium)
415+
logger := testutil.Logger(t)
416+
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
417+
require.NoError(t, err)
418+
defer s.Close()
419+
err = s.UpdateHostSigner(42)
420+
assert.NoError(t, err)
421+
422+
ln, err := net.Listen("tcp", "127.0.0.1:0")
423+
require.NoError(t, err)
424+
425+
done := make(chan struct{})
426+
go func() {
427+
defer close(done)
428+
err := s.Serve(ln)
429+
assert.Error(t, err) // Server is closed.
430+
}()
431+
defer func() {
432+
err := s.Close()
433+
require.NoError(t, err)
434+
<-done
435+
}()
436+
437+
c := sshClient(t, ln.Addr().String())
438+
439+
sess, err := c.NewSession()
440+
require.NoError(t, err)
441+
stdout, err := sess.StdoutPipe()
442+
require.NoError(t, err)
443+
stdin, err := sess.StdinPipe()
444+
require.NoError(t, err)
445+
defer stdin.Close()
446+
447+
dir := t.TempDir()
448+
err = os.MkdirAll(dir, 0o755)
449+
require.NoError(t, err)
450+
filePath := filepath.Join(dir, "result.txt")
451+
452+
// the shell command `read` will block until data is written to stdin, or closed. It will return
453+
// exit code 1 if it hits EOF, which is what we want to test.
454+
cmdErrCh := make(chan error, 1)
455+
go func() {
456+
cmdErrCh <- sess.Start(fmt.Sprintf("echo started; read; echo \"read exit code: $?\" > %s", filePath))
457+
}()
458+
459+
cmdErr := testutil.RequireReceive(ctx, t, cmdErrCh)
460+
require.NoError(t, cmdErr)
461+
462+
readCh := make(chan error, 1)
463+
go func() {
464+
buf := make([]byte, 8)
465+
_, err := stdout.Read(buf)
466+
assert.Equal(t, "started\n", string(buf))
467+
readCh <- err
468+
}()
469+
err = testutil.RequireReceive(ctx, t, readCh)
470+
require.NoError(t, err)
471+
472+
sess.Close()
473+
474+
var content []byte
475+
require.Eventually(t, func() bool {
476+
content, err = os.ReadFile(filePath)
477+
return err == nil
478+
}, testutil.WaitMedium, testutil.IntervalFast)
479+
require.NoError(t, err)
480+
require.Equal(t, "read exit code: 1\n", string(content))
481+
}
482+
406483
func sshClient(t *testing.T, addr string) *ssh.Client {
407484
conn, err := net.Dial("tcp", addr)
408485
require.NoError(t, err)

0 commit comments

Comments
 (0)