Skip to content

Commit daee91c

Browse files
authored
refactor: PTY & SSH (#7100)
* Add ssh tests for longoutput, orphan
  Signed-off-by: Spike Curtis <spike@coder.com>
* PTY/SSH tests & improvements
  Signed-off-by: Spike Curtis <spike@coder.com>
* Fix some tests
  Signed-off-by: Spike Curtis <spike@coder.com>
* Fix linting
  Signed-off-by: Spike Curtis <spike@coder.com>
* fmt
  Signed-off-by: Spike Curtis <spike@coder.com>
* Fix windows test
  Signed-off-by: Spike Curtis <spike@coder.com>
* Windows copy test
  Signed-off-by: Spike Curtis <spike@coder.com>
* WIP Windows pty handling
  Signed-off-by: Spike Curtis <spike@coder.com>
* Fix truncation tests
  Signed-off-by: Spike Curtis <spike@coder.com>
* Appease linter/fmt
  Signed-off-by: Spike Curtis <spike@coder.com>
* Fix typo
  Signed-off-by: Spike Curtis <spike@coder.com>
* Rework truncation test to not assume OS buffers
  Signed-off-by: Spike Curtis <spike@coder.com>
* Disable orphan test on Windows --- uses sh
  Signed-off-by: Spike Curtis <spike@coder.com>
* agent_test running SSH in pty use ptytest.Start
  Signed-off-by: Spike Curtis <spike@coder.com>
* More detail about closing pseudoconsole on windows
  Signed-off-by: Spike Curtis <spike@coder.com>
* Code review fixes
  Signed-off-by: Spike Curtis <spike@coder.com>
* Rearrange ptytest method order
  Signed-off-by: Spike Curtis <spike@coder.com>
* Protect pty.Resize on windows from races
  Signed-off-by: Spike Curtis <spike@coder.com>
* Fix windows bugs
  Signed-off-by: Spike Curtis <spike@coder.com>
* PTY doesn't extend PTYCmd
  Signed-off-by: Spike Curtis <spike@coder.com>
* Fix windows types
  Signed-off-by: Spike Curtis <spike@coder.com>

---------
Signed-off-by: Spike Curtis <spike@coder.com>
1 parent c000f2e commit daee91c

16 files changed

+803
-288
lines changed

agent/agent.go

+3-3
Original file line number | Diff line number | Diff line change
@@ -1045,7 +1045,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
10451045
if err = a.trackConnGoroutine(func() {
10461046
buffer := make([]byte, 1024)
10471047
for {
1048-
read, err := rpty.ptty.Output().Read(buffer)
1048+
read, err := rpty.ptty.OutputReader().Read(buffer)
10491049
if err != nil {
10501050
// When the PTY is closed, this is triggered.
10511051
break
@@ -1138,7 +1138,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
11381138
logger.Warn(ctx, "read conn", slog.Error(err))
11391139
return nil
11401140
}
1141-
_, err = rpty.ptty.Input().Write([]byte(req.Data))
1141+
_, err = rpty.ptty.InputWriter().Write([]byte(req.Data))
11421142
if err != nil {
11431143
logger.Warn(ctx, "write to pty", slog.Error(err))
11441144
return nil
@@ -1358,7 +1358,7 @@ type reconnectingPTY struct {
13581358
circularBuffer *circbuf.Buffer
13591359
circularBufferMutex sync.RWMutex
13601360
timeout *time.Timer
1361-
ptty pty.PTY
1361+
ptty pty.PTYCmd
13621362
}
13631363

13641364
// Close ends all connections to the reconnecting

agent/agent_test.go

+16-42
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,7 @@ import (
4545
"github.com/coder/coder/coderd/httpapi"
4646
"github.com/coder/coder/codersdk"
4747
"github.com/coder/coder/codersdk/agentsdk"
48+
"github.com/coder/coder/pty"
4849
"github.com/coder/coder/pty/ptytest"
4950
"github.com/coder/coder/tailnet"
5051
"github.com/coder/coder/tailnet/tailnettest"
@@ -481,17 +482,10 @@ func TestAgent_TCPLocalForwarding(t *testing.T) {
481482
}
482483
}()
483484

484-
pty := ptytest.New(t)
485-
486-
cmd := setupSSHCommand(t, []string{"-L", fmt.Sprintf("%d:127.0.0.1:%d", randomPort, remotePort)}, []string{"sleep", "5"})
487-
cmd.Stdin = pty.Input()
488-
cmd.Stdout = pty.Output()
489-
cmd.Stderr = pty.Output()
490-
err = cmd.Start()
491-
require.NoError(t, err)
485+
_, proc := setupSSHCommand(t, []string{"-L", fmt.Sprintf("%d:127.0.0.1:%d", randomPort, remotePort)}, []string{"sleep", "5"})
492486

493487
go func() {
494-
err := cmd.Wait()
488+
err := proc.Wait()
495489
select {
496490
case <-done:
497491
default:
@@ -523,7 +517,7 @@ func TestAgent_TCPLocalForwarding(t *testing.T) {
523517

524518
<-done
525519

526-
_ = cmd.Process.Kill()
520+
_ = proc.Kill()
527521
}
528522

529523
//nolint:paralleltest // This test reserves a port.
@@ -562,17 +556,10 @@ func TestAgent_TCPRemoteForwarding(t *testing.T) {
562556
}
563557
}()
564558

565-
pty := ptytest.New(t)
566-
567-
cmd := setupSSHCommand(t, []string{"-R", fmt.Sprintf("127.0.0.1:%d:127.0.0.1:%d", randomPort, localPort)}, []string{"sleep", "5"})
568-
cmd.Stdin = pty.Input()
569-
cmd.Stdout = pty.Output()
570-
cmd.Stderr = pty.Output()
571-
err = cmd.Start()
572-
require.NoError(t, err)
559+
_, proc := setupSSHCommand(t, []string{"-R", fmt.Sprintf("127.0.0.1:%d:127.0.0.1:%d", randomPort, localPort)}, []string{"sleep", "5"})
573560

574561
go func() {
575-
err := cmd.Wait()
562+
err := proc.Wait()
576563
select {
577564
case <-done:
578565
default:
@@ -604,7 +591,7 @@ func TestAgent_TCPRemoteForwarding(t *testing.T) {
604591

605592
<-done
606593

607-
_ = cmd.Process.Kill()
594+
_ = proc.Kill()
608595
}
609596

610597
func TestAgent_UnixLocalForwarding(t *testing.T) {
@@ -641,17 +628,10 @@ func TestAgent_UnixLocalForwarding(t *testing.T) {
641628
}
642629
}()
643630

644-
pty := ptytest.New(t)
645-
646-
cmd := setupSSHCommand(t, []string{"-L", fmt.Sprintf("%s:%s", localSocketPath, remoteSocketPath)}, []string{"sleep", "5"})
647-
cmd.Stdin = pty.Input()
648-
cmd.Stdout = pty.Output()
649-
cmd.Stderr = pty.Output()
650-
err = cmd.Start()
651-
require.NoError(t, err)
631+
_, proc := setupSSHCommand(t, []string{"-L", fmt.Sprintf("%s:%s", localSocketPath, remoteSocketPath)}, []string{"sleep", "5"})
652632

653633
go func() {
654-
err := cmd.Wait()
634+
err := proc.Wait()
655635
select {
656636
case <-done:
657637
default:
@@ -676,7 +656,7 @@ func TestAgent_UnixLocalForwarding(t *testing.T) {
676656
_ = conn.Close()
677657
<-done
678658

679-
_ = cmd.Process.Kill()
659+
_ = proc.Kill()
680660
}
681661

682662
func TestAgent_UnixRemoteForwarding(t *testing.T) {
@@ -713,17 +693,10 @@ func TestAgent_UnixRemoteForwarding(t *testing.T) {
713693
}
714694
}()
715695

716-
pty := ptytest.New(t)
717-
718-
cmd := setupSSHCommand(t, []string{"-R", fmt.Sprintf("%s:%s", remoteSocketPath, localSocketPath)}, []string{"sleep", "5"})
719-
cmd.Stdin = pty.Input()
720-
cmd.Stdout = pty.Output()
721-
cmd.Stderr = pty.Output()
722-
err = cmd.Start()
723-
require.NoError(t, err)
696+
_, proc := setupSSHCommand(t, []string{"-R", fmt.Sprintf("%s:%s", remoteSocketPath, localSocketPath)}, []string{"sleep", "5"})
724697

725698
go func() {
726-
err := cmd.Wait()
699+
err := proc.Wait()
727700
select {
728701
case <-done:
729702
default:
@@ -753,7 +726,7 @@ func TestAgent_UnixRemoteForwarding(t *testing.T) {
753726

754727
<-done
755728

756-
_ = cmd.Process.Kill()
729+
_ = proc.Kill()
757730
}
758731

759732
func TestAgent_SFTP(t *testing.T) {
@@ -1648,7 +1621,7 @@ func TestAgent_WriteVSCodeConfigs(t *testing.T) {
16481621
}, testutil.WaitShort, testutil.IntervalFast)
16491622
}
16501623

1651-
func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exec.Cmd {
1624+
func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) (*ptytest.PTYCmd, pty.Process) {
16521625
//nolint:dogsled
16531626
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
16541627
listener, err := net.Listen("tcp", "127.0.0.1:0")
@@ -1690,7 +1663,8 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exe
16901663
"host",
16911664
)
16921665
args = append(args, afterArgs...)
1693-
return exec.Command("ssh", args...)
1666+
cmd := exec.Command("ssh", args...)
1667+
return ptytest.Start(t, cmd)
16941668
}
16951669

16961670
func setupSSHSession(t *testing.T, options agentsdk.Manifest) *ssh.Session {

agent/agentssh/agentssh.go

+91-97
Original file line number | Diff line number | Diff line change
@@ -253,102 +253,12 @@ func (s *Server) sessionStart(session ssh.Session, extraEnv []string) (retErr er
253253

254254
sshPty, windowSize, isPty := session.Pty()
255255
if isPty {
256-
// Disable minimal PTY emulation set by gliderlabs/ssh (NL-to-CRNL).
257-
// See https://github.com/coder/coder/issues/3371.
258-
session.DisablePTYEmulation()
259-
260-
if !isQuietLogin(session.RawCommand()) {
261-
manifest := s.Manifest.Load()
262-
if manifest != nil {
263-
err = showMOTD(session, manifest.MOTDFile)
264-
if err != nil {
265-
s.logger.Error(ctx, "show MOTD", slog.Error(err))
266-
}
267-
} else {
268-
s.logger.Warn(ctx, "metadata lookup failed, unable to show MOTD")
269-
}
270-
}
271-
272-
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term))
273-
274-
// The pty package sets `SSH_TTY` on supported platforms.
275-
ptty, process, err := pty.Start(cmd, pty.WithPTYOption(
276-
pty.WithSSHRequest(sshPty),
277-
pty.WithLogger(slog.Stdlib(ctx, s.logger, slog.LevelInfo)),
278-
))
279-
if err != nil {
280-
return xerrors.Errorf("start command: %w", err)
281-
}
282-
var wg sync.WaitGroup
283-
defer func() {
284-
defer wg.Wait()
285-
closeErr := ptty.Close()
286-
if closeErr != nil {
287-
s.logger.Warn(ctx, "failed to close tty", slog.Error(closeErr))
288-
if retErr == nil {
289-
retErr = closeErr
290-
}
291-
}
292-
}()
293-
go func() {
294-
for win := range windowSize {
295-
resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width))
296-
// If the pty is closed, then command has exited, no need to log.
297-
if resizeErr != nil && !errors.Is(resizeErr, pty.ErrClosed) {
298-
s.logger.Warn(ctx, "failed to resize tty", slog.Error(resizeErr))
299-
}
300-
}
301-
}()
302-
// We don't add input copy to wait group because
303-
// it won't return until the session is closed.
304-
go func() {
305-
_, _ = io.Copy(ptty.Input(), session)
306-
}()
307-
308-
// In low parallelism scenarios, the command may exit and we may close
309-
// the pty before the output copy has started. This can result in the
310-
// output being lost. To avoid this, we wait for the output copy to
311-
// start before waiting for the command to exit. This ensures that the
312-
// output copy goroutine will be scheduled before calling close on the
313-
// pty. This shouldn't be needed because of `pty.Dup()` below, but it
314-
// may not be supported on all platforms.
315-
outputCopyStarted := make(chan struct{})
316-
ptyOutput := func() io.ReadCloser {
317-
defer close(outputCopyStarted)
318-
// Try to dup so we can separate stdin and stdout closure.
319-
// Once the original pty is closed, the dup will return
320-
// input/output error once the buffered data has been read.
321-
stdout, err := ptty.Dup()
322-
if err == nil {
323-
return stdout
324-
}
325-
// If we can't dup, we shouldn't close
326-
// the fd since it's tied to stdin.
327-
return readNopCloser{ptty.Output()}
328-
}
329-
wg.Add(1)
330-
go func() {
331-
// Ensure data is flushed to session on command exit, if we
332-
// close the session too soon, we might lose data.
333-
defer wg.Done()
334-
335-
stdout := ptyOutput()
336-
defer stdout.Close()
337-
338-
_, _ = io.Copy(session, stdout)
339-
}()
340-
<-outputCopyStarted
341-
342-
err = process.Wait()
343-
var exitErr *exec.ExitError
344-
// ExitErrors just mean the command we run returned a non-zero exit code, which is normal
345-
// and not something to be concerned about. But, if it's something else, we should log it.
346-
if err != nil && !xerrors.As(err, &exitErr) {
347-
s.logger.Warn(ctx, "wait error", slog.Error(err))
348-
}
349-
return err
256+
return s.startPTYSession(session, cmd, sshPty, windowSize)
350257
}
258+
return startNonPTYSession(session, cmd)
259+
}
351260

261+
func startNonPTYSession(session ssh.Session, cmd *exec.Cmd) error {
352262
cmd.Stdout = session
353263
cmd.Stderr = session.Stderr()
354264
// This blocks forever until stdin is received if we don't
@@ -368,10 +278,94 @@ func (s *Server) sessionStart(session ssh.Session, extraEnv []string) (retErr er
368278
return cmd.Wait()
369279
}
370280

371-
type readNopCloser struct{ io.Reader }
281+
// ptySession is the interface to the ssh.Session that startPTYSession uses;
282+
// we use an interface here so that we can fake it in tests.
283+
type ptySession interface {
284+
io.ReadWriter
285+
Context() ssh.Context
286+
DisablePTYEmulation()
287+
RawCommand() string
288+
}
289+
290+
func (s *Server) startPTYSession(session ptySession, cmd *exec.Cmd, sshPty ssh.Pty, windowSize <-chan ssh.Window) (retErr error) {
291+
ctx := session.Context()
292+
// Disable minimal PTY emulation set by gliderlabs/ssh (NL-to-CRNL).
293+
// See https://github.com/coder/coder/issues/3371.
294+
session.DisablePTYEmulation()
295+
296+
if !isQuietLogin(session.RawCommand()) {
297+
manifest := s.Manifest.Load()
298+
if manifest != nil {
299+
err := showMOTD(session, manifest.MOTDFile)
300+
if err != nil {
301+
s.logger.Error(ctx, "show MOTD", slog.Error(err))
302+
}
303+
} else {
304+
s.logger.Warn(ctx, "metadata lookup failed, unable to show MOTD")
305+
}
306+
}
307+
308+
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term))
309+
310+
// The pty package sets `SSH_TTY` on supported platforms.
311+
ptty, process, err := pty.Start(cmd, pty.WithPTYOption(
312+
pty.WithSSHRequest(sshPty),
313+
pty.WithLogger(slog.Stdlib(ctx, s.logger, slog.LevelInfo)),
314+
))
315+
if err != nil {
316+
return xerrors.Errorf("start command: %w", err)
317+
}
318+
defer func() {
319+
closeErr := ptty.Close()
320+
if closeErr != nil {
321+
s.logger.Warn(ctx, "failed to close tty", slog.Error(closeErr))
322+
if retErr == nil {
323+
retErr = closeErr
324+
}
325+
}
326+
}()
327+
go func() {
328+
for win := range windowSize {
329+
resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width))
330+
// If the pty is closed, then command has exited, no need to log.
331+
if resizeErr != nil && !errors.Is(resizeErr, pty.ErrClosed) {
332+
s.logger.Warn(ctx, "failed to resize tty", slog.Error(resizeErr))
333+
}
334+
}
335+
}()
336+
337+
go func() {
338+
_, _ = io.Copy(ptty.InputWriter(), session)
339+
}()
372340

373-
// Close implements io.Closer.
374-
func (readNopCloser) Close() error { return nil }
341+
// We need to wait for the command output to finish copying. It's safe to
342+
// just do this copy on the main handler goroutine because one of two things
343+
// will happen:
344+
//
345+
// 1. The command completes & closes the TTY, which then triggers an error
346+
// after we've Read() all the buffered data from the PTY.
347+
// 2. The client hangs up, which cancels the command's Context, and Go will
348+
// kill the command's process. This then has the same effect as (1).
349+
n, err := io.Copy(session, ptty.OutputReader())
350+
s.logger.Debug(ctx, "copy output done", slog.F("bytes", n), slog.Error(err))
351+
if err != nil {
352+
return xerrors.Errorf("copy error: %w", err)
353+
}
354+
// We've gotten all the output, but we need to wait for the process to
355+
// complete so that we can get the exit code. This returns
356+
// immediately if the TTY was closed as part of the command exiting.
357+
err = process.Wait()
358+
var exitErr *exec.ExitError
359+
// ExitErrors just mean the command we run returned a non-zero exit code, which is normal
360+
// and not something to be concerned about. But, if it's something else, we should log it.
361+
if err != nil && !xerrors.As(err, &exitErr) {
362+
s.logger.Warn(ctx, "wait error", slog.Error(err))
363+
}
364+
if err != nil {
365+
return xerrors.Errorf("process wait: %w", err)
366+
}
367+
return nil
368+
}
375369

376370
func (s *Server) sftpHandler(session ssh.Session) {
377371
ctx := session.Context()

0 commit comments

Comments
 (0)