Skip to content

refactor: PTY & SSH #7100

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 22 commits into from
Apr 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -1045,7 +1045,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
if err = a.trackConnGoroutine(func() {
buffer := make([]byte, 1024)
for {
read, err := rpty.ptty.Output().Read(buffer)
read, err := rpty.ptty.OutputReader().Read(buffer)
if err != nil {
// When the PTY is closed, this is triggered.
break
Expand Down Expand Up @@ -1138,7 +1138,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
logger.Warn(ctx, "read conn", slog.Error(err))
return nil
}
_, err = rpty.ptty.Input().Write([]byte(req.Data))
_, err = rpty.ptty.InputWriter().Write([]byte(req.Data))
if err != nil {
logger.Warn(ctx, "write to pty", slog.Error(err))
return nil
Expand Down Expand Up @@ -1358,7 +1358,7 @@ type reconnectingPTY struct {
circularBuffer *circbuf.Buffer
circularBufferMutex sync.RWMutex
timeout *time.Timer
ptty pty.PTY
ptty pty.PTYCmd
}

// Close ends all connections to the reconnecting
Expand Down
58 changes: 16 additions & 42 deletions agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import (
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/codersdk/agentsdk"
"github.com/coder/coder/pty"
"github.com/coder/coder/pty/ptytest"
"github.com/coder/coder/tailnet"
"github.com/coder/coder/tailnet/tailnettest"
Expand Down Expand Up @@ -481,17 +482,10 @@ func TestAgent_TCPLocalForwarding(t *testing.T) {
}
}()

pty := ptytest.New(t)

cmd := setupSSHCommand(t, []string{"-L", fmt.Sprintf("%d:127.0.0.1:%d", randomPort, remotePort)}, []string{"sleep", "5"})
cmd.Stdin = pty.Input()
cmd.Stdout = pty.Output()
cmd.Stderr = pty.Output()
err = cmd.Start()
require.NoError(t, err)
_, proc := setupSSHCommand(t, []string{"-L", fmt.Sprintf("%d:127.0.0.1:%d", randomPort, remotePort)}, []string{"sleep", "5"})

go func() {
err := cmd.Wait()
err := proc.Wait()
select {
case <-done:
default:
Expand Down Expand Up @@ -523,7 +517,7 @@ func TestAgent_TCPLocalForwarding(t *testing.T) {

<-done

_ = cmd.Process.Kill()
_ = proc.Kill()
}

//nolint:paralleltest // This test reserves a port.
Expand Down Expand Up @@ -562,17 +556,10 @@ func TestAgent_TCPRemoteForwarding(t *testing.T) {
}
}()

pty := ptytest.New(t)

cmd := setupSSHCommand(t, []string{"-R", fmt.Sprintf("127.0.0.1:%d:127.0.0.1:%d", randomPort, localPort)}, []string{"sleep", "5"})
cmd.Stdin = pty.Input()
cmd.Stdout = pty.Output()
cmd.Stderr = pty.Output()
err = cmd.Start()
require.NoError(t, err)
_, proc := setupSSHCommand(t, []string{"-R", fmt.Sprintf("127.0.0.1:%d:127.0.0.1:%d", randomPort, localPort)}, []string{"sleep", "5"})

go func() {
err := cmd.Wait()
err := proc.Wait()
select {
case <-done:
default:
Expand Down Expand Up @@ -604,7 +591,7 @@ func TestAgent_TCPRemoteForwarding(t *testing.T) {

<-done

_ = cmd.Process.Kill()
_ = proc.Kill()
}

func TestAgent_UnixLocalForwarding(t *testing.T) {
Expand Down Expand Up @@ -641,17 +628,10 @@ func TestAgent_UnixLocalForwarding(t *testing.T) {
}
}()

pty := ptytest.New(t)

cmd := setupSSHCommand(t, []string{"-L", fmt.Sprintf("%s:%s", localSocketPath, remoteSocketPath)}, []string{"sleep", "5"})
cmd.Stdin = pty.Input()
cmd.Stdout = pty.Output()
cmd.Stderr = pty.Output()
err = cmd.Start()
require.NoError(t, err)
_, proc := setupSSHCommand(t, []string{"-L", fmt.Sprintf("%s:%s", localSocketPath, remoteSocketPath)}, []string{"sleep", "5"})

go func() {
err := cmd.Wait()
err := proc.Wait()
select {
case <-done:
default:
Expand All @@ -676,7 +656,7 @@ func TestAgent_UnixLocalForwarding(t *testing.T) {
_ = conn.Close()
<-done

_ = cmd.Process.Kill()
_ = proc.Kill()
}

func TestAgent_UnixRemoteForwarding(t *testing.T) {
Expand Down Expand Up @@ -713,17 +693,10 @@ func TestAgent_UnixRemoteForwarding(t *testing.T) {
}
}()

pty := ptytest.New(t)

cmd := setupSSHCommand(t, []string{"-R", fmt.Sprintf("%s:%s", remoteSocketPath, localSocketPath)}, []string{"sleep", "5"})
cmd.Stdin = pty.Input()
cmd.Stdout = pty.Output()
cmd.Stderr = pty.Output()
err = cmd.Start()
require.NoError(t, err)
_, proc := setupSSHCommand(t, []string{"-R", fmt.Sprintf("%s:%s", remoteSocketPath, localSocketPath)}, []string{"sleep", "5"})

go func() {
err := cmd.Wait()
err := proc.Wait()
select {
case <-done:
default:
Expand Down Expand Up @@ -753,7 +726,7 @@ func TestAgent_UnixRemoteForwarding(t *testing.T) {

<-done

_ = cmd.Process.Kill()
_ = proc.Kill()
}

func TestAgent_SFTP(t *testing.T) {
Expand Down Expand Up @@ -1648,7 +1621,7 @@ func TestAgent_WriteVSCodeConfigs(t *testing.T) {
}, testutil.WaitShort, testutil.IntervalFast)
}

func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exec.Cmd {
func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) (*ptytest.PTYCmd, pty.Process) {
//nolint:dogsled
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
listener, err := net.Listen("tcp", "127.0.0.1:0")
Expand Down Expand Up @@ -1690,7 +1663,8 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exe
"host",
)
args = append(args, afterArgs...)
return exec.Command("ssh", args...)
cmd := exec.Command("ssh", args...)
return ptytest.Start(t, cmd)
}

func setupSSHSession(t *testing.T, options agentsdk.Manifest) *ssh.Session {
Expand Down
188 changes: 91 additions & 97 deletions agent/agentssh/agentssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,102 +253,12 @@ func (s *Server) sessionStart(session ssh.Session, extraEnv []string) (retErr er

sshPty, windowSize, isPty := session.Pty()
if isPty {
// Disable minimal PTY emulation set by gliderlabs/ssh (NL-to-CRNL).
// See https://github.com/coder/coder/issues/3371.
session.DisablePTYEmulation()

if !isQuietLogin(session.RawCommand()) {
manifest := s.Manifest.Load()
if manifest != nil {
err = showMOTD(session, manifest.MOTDFile)
if err != nil {
s.logger.Error(ctx, "show MOTD", slog.Error(err))
}
} else {
s.logger.Warn(ctx, "metadata lookup failed, unable to show MOTD")
}
}

cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term))

// The pty package sets `SSH_TTY` on supported platforms.
ptty, process, err := pty.Start(cmd, pty.WithPTYOption(
pty.WithSSHRequest(sshPty),
pty.WithLogger(slog.Stdlib(ctx, s.logger, slog.LevelInfo)),
))
if err != nil {
return xerrors.Errorf("start command: %w", err)
}
var wg sync.WaitGroup
defer func() {
defer wg.Wait()
closeErr := ptty.Close()
if closeErr != nil {
s.logger.Warn(ctx, "failed to close tty", slog.Error(closeErr))
if retErr == nil {
retErr = closeErr
}
}
}()
go func() {
for win := range windowSize {
resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width))
// If the pty is closed, then command has exited, no need to log.
if resizeErr != nil && !errors.Is(resizeErr, pty.ErrClosed) {
s.logger.Warn(ctx, "failed to resize tty", slog.Error(resizeErr))
}
}
}()
// We don't add input copy to wait group because
// it won't return until the session is closed.
go func() {
_, _ = io.Copy(ptty.Input(), session)
}()

// In low parallelism scenarios, the command may exit and we may close
// the pty before the output copy has started. This can result in the
// output being lost. To avoid this, we wait for the output copy to
// start before waiting for the command to exit. This ensures that the
// output copy goroutine will be scheduled before calling close on the
// pty. This shouldn't be needed because of `pty.Dup()` below, but it
// may not be supported on all platforms.
outputCopyStarted := make(chan struct{})
ptyOutput := func() io.ReadCloser {
defer close(outputCopyStarted)
// Try to dup so we can separate stdin and stdout closure.
// Once the original pty is closed, the dup will return
// input/output error once the buffered data has been read.
stdout, err := ptty.Dup()
if err == nil {
return stdout
}
// If we can't dup, we shouldn't close
// the fd since it's tied to stdin.
return readNopCloser{ptty.Output()}
}
wg.Add(1)
go func() {
// Ensure data is flushed to session on command exit, if we
// close the session too soon, we might lose data.
defer wg.Done()

stdout := ptyOutput()
defer stdout.Close()

_, _ = io.Copy(session, stdout)
}()
<-outputCopyStarted

err = process.Wait()
var exitErr *exec.ExitError
// ExitErrors just mean the command we run returned a non-zero exit code, which is normal
// and not something to be concerned about. But, if it's something else, we should log it.
if err != nil && !xerrors.As(err, &exitErr) {
s.logger.Warn(ctx, "wait error", slog.Error(err))
}
return err
return s.startPTYSession(session, cmd, sshPty, windowSize)
}
return startNonPTYSession(session, cmd)
}

func startNonPTYSession(session ssh.Session, cmd *exec.Cmd) error {
cmd.Stdout = session
cmd.Stderr = session.Stderr()
// This blocks forever until stdin is received if we don't
Expand All @@ -368,10 +278,94 @@ func (s *Server) sessionStart(session ssh.Session, extraEnv []string) (retErr er
return cmd.Wait()
}

type readNopCloser struct{ io.Reader }
// ptySession is the interface to the ssh.Session that startPTYSession uses.
// We use an interface here so that we can fake it in tests.
type ptySession interface {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason you chose a stripped-down interface vs. using ssh.Session directly? I see the test, but we also have a fake context there with all the methods, so I'm not seeing the benefit per se. I feel this creates a bit of needless indirection. Not critical to change now; I'll see if I make some changes down the line as I do some refactoring for the session handling anyway.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ssh.Session is an even bigger interface!

I couldn't see any way to directly instantiate any of the concrete implementations in, say, gliderlabs.

My first stab at the test I wanted to do created the whole server like the other ssh tests do, but I found that even after calling Close() on the network connection, the session context in the handler wasn't closed quickly. So, this construction narrows down the big ssh.Session interface into just what we need, and allows me to more precisely control when the context expires in tests.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a bit annoying that I have to mock out the context methods, but unfortunately, ssh.Session is defined to have a method Context() that returns ssh.Context rather than context.Context.

Unfortunately, the Go type checker isn't smart enough to understand that Context() ssh.Context must also satisfy Context() context.Context, since context.Context is a strict subset of ssh.Context. Alas.

io.ReadWriter
Context() ssh.Context
DisablePTYEmulation()
RawCommand() string
}

func (s *Server) startPTYSession(session ptySession, cmd *exec.Cmd, sshPty ssh.Pty, windowSize <-chan ssh.Window) (retErr error) {
ctx := session.Context()
// Disable minimal PTY emulation set by gliderlabs/ssh (NL-to-CRNL).
// See https://github.com/coder/coder/issues/3371.
session.DisablePTYEmulation()

if !isQuietLogin(session.RawCommand()) {
manifest := s.Manifest.Load()
if manifest != nil {
err := showMOTD(session, manifest.MOTDFile)
if err != nil {
s.logger.Error(ctx, "show MOTD", slog.Error(err))
}
} else {
s.logger.Warn(ctx, "metadata lookup failed, unable to show MOTD")
}
}

cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term))

// The pty package sets `SSH_TTY` on supported platforms.
ptty, process, err := pty.Start(cmd, pty.WithPTYOption(
pty.WithSSHRequest(sshPty),
pty.WithLogger(slog.Stdlib(ctx, s.logger, slog.LevelInfo)),
))
if err != nil {
return xerrors.Errorf("start command: %w", err)
}
defer func() {
closeErr := ptty.Close()
if closeErr != nil {
s.logger.Warn(ctx, "failed to close tty", slog.Error(closeErr))
if retErr == nil {
retErr = closeErr
}
}
}()
go func() {
for win := range windowSize {
resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width))
// If the pty is closed, then command has exited, no need to log.
if resizeErr != nil && !errors.Is(resizeErr, pty.ErrClosed) {
s.logger.Warn(ctx, "failed to resize tty", slog.Error(resizeErr))
}
}
}()

go func() {
_, _ = io.Copy(ptty.InputWriter(), session)
}()

// Close implements io.Closer.
func (readNopCloser) Close() error { return nil }
// We need to wait for the command output to finish copying. It's safe to
// just do this copy on the main handler goroutine because one of two things
// will happen:
//
// 1. The command completes & closes the TTY, which then triggers an error
// after we've Read() all the buffered data from the PTY.
// 2. The client hangs up, which cancels the command's Context, and go will
// kill the command's process. This then has the same effect as (1).
n, err := io.Copy(session, ptty.OutputReader())
s.logger.Debug(ctx, "copy output done", slog.F("bytes", n), slog.Error(err))
if err != nil {
return xerrors.Errorf("copy error: %w", err)
}
// We've gotten all the output, but we need to wait for the process to
// complete so that we can get the exit code. This returns
// immediately if the TTY was closed as part of the command exiting.
err = process.Wait()
var exitErr *exec.ExitError
// ExitErrors just mean the command we run returned a non-zero exit code, which is normal
// and not something to be concerned about. But, if it's something else, we should log it.
if err != nil && !xerrors.As(err, &exitErr) {
s.logger.Warn(ctx, "wait error", slog.Error(err))
}
if err != nil {
return xerrors.Errorf("process wait: %w", err)
}
return nil
}

func (s *Server) sftpHandler(session ssh.Session) {
ctx := session.Context()
Expand Down
Loading