diff --git a/agent/agent.go b/agent/agent.go index efd57e5db29db..9b70506b49936 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -1045,7 +1045,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m if err = a.trackConnGoroutine(func() { buffer := make([]byte, 1024) for { - read, err := rpty.ptty.Output().Read(buffer) + read, err := rpty.ptty.OutputReader().Read(buffer) if err != nil { // When the PTY is closed, this is triggered. break @@ -1138,7 +1138,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m logger.Warn(ctx, "read conn", slog.Error(err)) return nil } - _, err = rpty.ptty.Input().Write([]byte(req.Data)) + _, err = rpty.ptty.InputWriter().Write([]byte(req.Data)) if err != nil { logger.Warn(ctx, "write to pty", slog.Error(err)) return nil @@ -1358,7 +1358,7 @@ type reconnectingPTY struct { circularBuffer *circbuf.Buffer circularBufferMutex sync.RWMutex timeout *time.Timer - ptty pty.PTY + ptty pty.PTYCmd } // Close ends all connections to the reconnecting diff --git a/agent/agent_test.go b/agent/agent_test.go index 49fbb9f2e7468..6527e82031f13 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -45,6 +45,7 @@ import ( "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk/agentsdk" + "github.com/coder/coder/pty" "github.com/coder/coder/pty/ptytest" "github.com/coder/coder/tailnet" "github.com/coder/coder/tailnet/tailnettest" @@ -481,17 +482,10 @@ func TestAgent_TCPLocalForwarding(t *testing.T) { } }() - pty := ptytest.New(t) - - cmd := setupSSHCommand(t, []string{"-L", fmt.Sprintf("%d:127.0.0.1:%d", randomPort, remotePort)}, []string{"sleep", "5"}) - cmd.Stdin = pty.Input() - cmd.Stdout = pty.Output() - cmd.Stderr = pty.Output() - err = cmd.Start() - require.NoError(t, err) + _, proc := setupSSHCommand(t, []string{"-L", fmt.Sprintf("%d:127.0.0.1:%d", randomPort, remotePort)}, []string{"sleep", "5"}) go func() { - err := cmd.Wait() + err := proc.Wait() select { case <-done: default: @@ -523,7 +517,7 @@ func TestAgent_TCPLocalForwarding(t *testing.T) { <-done - _ = cmd.Process.Kill() + _ = proc.Kill() } //nolint:paralleltest // This test reserves a port. @@ -562,17 +556,10 @@ func TestAgent_TCPRemoteForwarding(t *testing.T) { } }() - pty := ptytest.New(t) - - cmd := setupSSHCommand(t, []string{"-R", fmt.Sprintf("127.0.0.1:%d:127.0.0.1:%d", randomPort, localPort)}, []string{"sleep", "5"}) - cmd.Stdin = pty.Input() - cmd.Stdout = pty.Output() - cmd.Stderr = pty.Output() - err = cmd.Start() - require.NoError(t, err) + _, proc := setupSSHCommand(t, []string{"-R", fmt.Sprintf("127.0.0.1:%d:127.0.0.1:%d", randomPort, localPort)}, []string{"sleep", "5"}) go func() { - err := cmd.Wait() + err := proc.Wait() select { case <-done: default: @@ -604,7 +591,7 @@ func TestAgent_TCPRemoteForwarding(t *testing.T) { <-done - _ = cmd.Process.Kill() + _ = proc.Kill() } func TestAgent_UnixLocalForwarding(t *testing.T) { @@ -641,17 +628,10 @@ func TestAgent_UnixLocalForwarding(t *testing.T) { } }() - pty := ptytest.New(t) - - cmd := setupSSHCommand(t, []string{"-L", fmt.Sprintf("%s:%s", localSocketPath, remoteSocketPath)}, []string{"sleep", "5"}) - cmd.Stdin = pty.Input() - cmd.Stdout = pty.Output() - cmd.Stderr = pty.Output() - err = cmd.Start() - require.NoError(t, err) + _, proc := setupSSHCommand(t, []string{"-L", fmt.Sprintf("%s:%s", localSocketPath, remoteSocketPath)}, []string{"sleep", "5"}) go func() { - err := cmd.Wait() + err := proc.Wait() select { case <-done: default: @@ -676,7 +656,7 @@ func TestAgent_UnixLocalForwarding(t *testing.T) { _ = conn.Close() <-done - _ = cmd.Process.Kill() + _ = proc.Kill() } func TestAgent_UnixRemoteForwarding(t *testing.T) { @@ -713,17 +693,10 @@ func TestAgent_UnixRemoteForwarding(t *testing.T) { } }() - pty := ptytest.New(t) - - cmd := setupSSHCommand(t, []string{"-R", fmt.Sprintf("%s:%s", remoteSocketPath, localSocketPath)}, []string{"sleep", "5"}) - cmd.Stdin = pty.Input() - cmd.Stdout = pty.Output() - cmd.Stderr = pty.Output() - err = cmd.Start() - require.NoError(t, err) + _, proc := setupSSHCommand(t, []string{"-R", fmt.Sprintf("%s:%s", remoteSocketPath, localSocketPath)}, []string{"sleep", "5"}) go func() { - err := cmd.Wait() + err := proc.Wait() select { case <-done: default: @@ -753,7 +726,7 @@ func TestAgent_UnixRemoteForwarding(t *testing.T) { <-done - _ = cmd.Process.Kill() + _ = proc.Kill() } func TestAgent_SFTP(t *testing.T) { @@ -1648,7 +1621,7 @@ func TestAgent_WriteVSCodeConfigs(t *testing.T) { }, testutil.WaitShort, testutil.IntervalFast) } -func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exec.Cmd { +func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) (*ptytest.PTYCmd, pty.Process) { //nolint:dogsled agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) listener, err := net.Listen("tcp", "127.0.0.1:0") @@ -1690,7 +1663,8 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exe "host", ) args = append(args, afterArgs...) - return exec.Command("ssh", args...) + cmd := exec.Command("ssh", args...) + return ptytest.Start(t, cmd) } func setupSSHSession(t *testing.T, options agentsdk.Manifest) *ssh.Session { diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index a22f86836d147..c9bd17362b156 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -253,102 +253,12 @@ func (s *Server) sessionStart(session ssh.Session, extraEnv []string) (retErr er sshPty, windowSize, isPty := session.Pty() if isPty { - // Disable minimal PTY emulation set by gliderlabs/ssh (NL-to-CRNL). - // See https://github.com/coder/coder/issues/3371. - session.DisablePTYEmulation() - - if !isQuietLogin(session.RawCommand()) { - manifest := s.Manifest.Load() - if manifest != nil { - err = showMOTD(session, manifest.MOTDFile) - if err != nil { - s.logger.Error(ctx, "show MOTD", slog.Error(err)) - } - } else { - s.logger.Warn(ctx, "metadata lookup failed, unable to show MOTD") - } - } - - cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term)) - - // The pty package sets `SSH_TTY` on supported platforms. - ptty, process, err := pty.Start(cmd, pty.WithPTYOption( - pty.WithSSHRequest(sshPty), - pty.WithLogger(slog.Stdlib(ctx, s.logger, slog.LevelInfo)), - )) - if err != nil { - return xerrors.Errorf("start command: %w", err) - } - var wg sync.WaitGroup - defer func() { - defer wg.Wait() - closeErr := ptty.Close() - if closeErr != nil { - s.logger.Warn(ctx, "failed to close tty", slog.Error(closeErr)) - if retErr == nil { - retErr = closeErr - } - } - }() - go func() { - for win := range windowSize { - resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width)) - // If the pty is closed, then command has exited, no need to log. - if resizeErr != nil && !errors.Is(resizeErr, pty.ErrClosed) { - s.logger.Warn(ctx, "failed to resize tty", slog.Error(resizeErr)) - } - } - }() - // We don't add input copy to wait group because - // it won't return until the session is closed. - go func() { - _, _ = io.Copy(ptty.Input(), session) - }() - - // In low parallelism scenarios, the command may exit and we may close - // the pty before the output copy has started. This can result in the - // output being lost. To avoid this, we wait for the output copy to - // start before waiting for the command to exit. This ensures that the - // output copy goroutine will be scheduled before calling close on the - // pty. This shouldn't be needed because of `pty.Dup()` below, but it - // may not be supported on all platforms. - outputCopyStarted := make(chan struct{}) - ptyOutput := func() io.ReadCloser { - defer close(outputCopyStarted) - // Try to dup so we can separate stdin and stdout closure. - // Once the original pty is closed, the dup will return - // input/output error once the buffered data has been read. - stdout, err := ptty.Dup() - if err == nil { - return stdout - } - // If we can't dup, we shouldn't close - // the fd since it's tied to stdin. - return readNopCloser{ptty.Output()} - } - wg.Add(1) - go func() { - // Ensure data is flushed to session on command exit, if we - // close the session too soon, we might lose data. - defer wg.Done() - - stdout := ptyOutput() - defer stdout.Close() - - _, _ = io.Copy(session, stdout) - }() - <-outputCopyStarted - - err = process.Wait() - var exitErr *exec.ExitError - // ExitErrors just mean the command we run returned a non-zero exit code, which is normal - // and not something to be concerned about. But, if it's something else, we should log it. - if err != nil && !xerrors.As(err, &exitErr) { - s.logger.Warn(ctx, "wait error", slog.Error(err)) - } - return err + return s.startPTYSession(session, cmd, sshPty, windowSize) } + return startNonPTYSession(session, cmd) +} +func startNonPTYSession(session ssh.Session, cmd *exec.Cmd) error { cmd.Stdout = session cmd.Stderr = session.Stderr() // This blocks forever until stdin is received if we don't @@ -368,10 +278,94 @@ func (s *Server) sessionStart(session ssh.Session, extraEnv []string) (retErr er return cmd.Wait() } -type readNopCloser struct{ io.Reader } +// ptySession is the interface to the ssh.Session that startPTYSession uses +// we use an interface here so that we can fake it in tests. +type ptySession interface { + io.ReadWriter + Context() ssh.Context + DisablePTYEmulation() + RawCommand() string +} + +func (s *Server) startPTYSession(session ptySession, cmd *exec.Cmd, sshPty ssh.Pty, windowSize <-chan ssh.Window) (retErr error) { + ctx := session.Context() + // Disable minimal PTY emulation set by gliderlabs/ssh (NL-to-CRNL). + // See https://github.com/coder/coder/issues/3371. + session.DisablePTYEmulation() + + if !isQuietLogin(session.RawCommand()) { + manifest := s.Manifest.Load() + if manifest != nil { + err := showMOTD(session, manifest.MOTDFile) + if err != nil { + s.logger.Error(ctx, "show MOTD", slog.Error(err)) + } + } else { + s.logger.Warn(ctx, "metadata lookup failed, unable to show MOTD") + } + } + + cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term)) + + // The pty package sets `SSH_TTY` on supported platforms. + ptty, process, err := pty.Start(cmd, pty.WithPTYOption( + pty.WithSSHRequest(sshPty), + pty.WithLogger(slog.Stdlib(ctx, s.logger, slog.LevelInfo)), + )) + if err != nil { + return xerrors.Errorf("start command: %w", err) + } + defer func() { + closeErr := ptty.Close() + if closeErr != nil { + s.logger.Warn(ctx, "failed to close tty", slog.Error(closeErr)) + if retErr == nil { + retErr = closeErr + } + } + }() + go func() { + for win := range windowSize { + resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width)) + // If the pty is closed, then command has exited, no need to log. + if resizeErr != nil && !errors.Is(resizeErr, pty.ErrClosed) { + s.logger.Warn(ctx, "failed to resize tty", slog.Error(resizeErr)) + } + } + }() + + go func() { + _, _ = io.Copy(ptty.InputWriter(), session) + }() -// Close implements io.Closer. -func (readNopCloser) Close() error { return nil } + // We need to wait for the command output to finish copying. It's safe to + // just do this copy on the main handler goroutine because one of two things + // will happen: + // + // 1. The command completes & closes the TTY, which then triggers an error + // after we've Read() all the buffered data from the PTY. + // 2. The client hangs up, which cancels the command's Context, and go will + // kill the command's process. This then has the same effect as (1). + n, err := io.Copy(session, ptty.OutputReader()) + s.logger.Debug(ctx, "copy output done", slog.F("bytes", n), slog.Error(err)) + if err != nil { + return xerrors.Errorf("copy error: %w", err) + } + // We've gotten all the output, but we need to wait for the process to + // complete so that we can get the exit code. This returns + // immediately if the TTY was closed as part of the command exiting. + err = process.Wait() + var exitErr *exec.ExitError + // ExitErrors just mean the command we run returned a non-zero exit code, which is normal + // and not something to be concerned about. But, if it's something else, we should log it. + if err != nil && !xerrors.As(err, &exitErr) { + s.logger.Warn(ctx, "wait error", slog.Error(err)) + } + if err != nil { + return xerrors.Errorf("process wait: %w", err) + } + return nil +} func (s *Server) sftpHandler(session ssh.Session) { ctx := session.Context() diff --git a/agent/agentssh/agentssh_internal_test.go b/agent/agentssh/agentssh_internal_test.go new file mode 100644 index 0000000000000..33f41dd15a452 --- /dev/null +++ b/agent/agentssh/agentssh_internal_test.go @@ -0,0 +1,190 @@ +//go:build !windows + +package agentssh + +import ( + "bufio" + "context" + "io" + "net" + "os/exec" + "testing" + + gliderssh "github.com/gliderlabs/ssh" + "github.com/spf13/afero" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/testutil" + + "cdr.dev/slog/sloggers/slogtest" +) + +const longScript = ` +echo "started" +sleep 30 +echo "done" +` + +// Test_sessionStart_orphan tests running a command that takes a long time to +// exit normally, and terminate the SSH session context early to verify that we +// return quickly and don't leave the command running as an "orphan" with no +// active SSH session. +func Test_sessionStart_orphan(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) + defer cancel() + logger := slogtest.Make(t, nil) + s, err := NewServer(ctx, logger, afero.NewMemMapFs(), 0, "") + require.NoError(t, err) + + // Here we're going to call the handler directly with a faked SSH session + // that just uses io.Pipes instead of a network socket. There is a large + // variation in the time between closing the socket from the client side and + // the SSH server canceling the session Context, which would lead to a flaky + // test if we did it that way. So instead, we directly cancel the context + // in this test. + sessionCtx, sessionCancel := context.WithCancel(ctx) + toClient, fromClient, sess := newTestSession(sessionCtx) + ptyInfo := gliderssh.Pty{} + windowSize := make(chan gliderssh.Window) + close(windowSize) + // the command gets the session context so that Go will terminate it when + // the session expires. + cmd := exec.CommandContext(sessionCtx, "sh", "-c", longScript) + + done := make(chan struct{}) + go func() { + defer close(done) + // we don't really care what the error is here. In the larger scenario, + // the client has disconnected, so we can't return any error information + // to them. + _ = s.startPTYSession(sess, cmd, ptyInfo, windowSize) + }() + + readDone := make(chan struct{}) + go func() { + defer close(readDone) + s := bufio.NewScanner(toClient) + assert.True(t, s.Scan()) + txt := s.Text() + assert.Equal(t, "started", txt, "output corrupted") + }() + + waitForChan(ctx, t, readDone, "read timeout") + // process is started, and should be sleeping for ~30 seconds + + sessionCancel() + + // now, we wait for the handler to complete. If it does so before the + // main test timeout, we consider this a pass. If not, it indicates + // that the server isn't properly shutting down sessions when they are + // disconnected client side, which could lead to processes hanging around + // indefinitely. + waitForChan(ctx, t, done, "handler timeout") + + err = fromClient.Close() + require.NoError(t, err) +} + +func waitForChan(ctx context.Context, t *testing.T, c <-chan struct{}, msg string) { + t.Helper() + select { + case <-c: + // OK! + case <-ctx.Done(): + t.Fatal(msg) + } +} + +type testSession struct { + ctx testSSHContext + + // c2p is the client -> pty buffer + toPty *io.PipeReader + // p2c is the pty -> client buffer + fromPty *io.PipeWriter +} + +type testSSHContext struct { + context.Context +} + +func newTestSession(ctx context.Context) (toClient *io.PipeReader, fromClient *io.PipeWriter, s ptySession) { + toClient, fromPty := io.Pipe() + toPty, fromClient := io.Pipe() + + return toClient, fromClient, &testSession{ + ctx: testSSHContext{ctx}, + toPty: toPty, + fromPty: fromPty, + } +} + +func (s *testSession) Context() gliderssh.Context { + return s.ctx +} + +func (*testSession) DisablePTYEmulation() {} + +// RawCommand returns "quiet logon" so that the PTY handler doesn't attempt to +// write the message of the day, which will interfere with our tests. It writes +// the message of the day if it's a shell login (zero length RawCommand()). +func (*testSession) RawCommand() string { return "quiet logon" } + +func (s *testSession) Read(p []byte) (n int, err error) { + return s.toPty.Read(p) +} + +func (s *testSession) Write(p []byte) (n int, err error) { + return s.fromPty.Write(p) +} + +func (testSSHContext) Lock() { + panic("not implemented") +} + +func (testSSHContext) Unlock() { + panic("not implemented") +} + +// User returns the username used when establishing the SSH connection. +func (testSSHContext) User() string { + panic("not implemented") +} + +// SessionID returns the session hash. +func (testSSHContext) SessionID() string { + panic("not implemented") +} + +// ClientVersion returns the version reported by the client. +func (testSSHContext) ClientVersion() string { + panic("not implemented") +} + +// ServerVersion returns the version reported by the server. +func (testSSHContext) ServerVersion() string { + panic("not implemented") +} + +// RemoteAddr returns the remote address for this connection. +func (testSSHContext) RemoteAddr() net.Addr { + panic("not implemented") +} + +// LocalAddr returns the local address for this connection. +func (testSSHContext) LocalAddr() net.Addr { + panic("not implemented") +} + +// Permissions returns the Permissions object used for this connection. +func (testSSHContext) Permissions() *gliderssh.Permissions { + panic("not implemented") +} + +// SetValue allows you to easily write new values into the underlying context. +func (testSSHContext) SetValue(_, _ interface{}) { + panic("not implemented") +} diff --git a/go.mod b/go.mod index 2da8466c53707..91cbf22aebab0 100644 --- a/go.mod +++ b/go.mod @@ -108,6 +108,7 @@ require ( github.com/hashicorp/terraform-config-inspect v0.0.0-20211115214459-90acf1ca460f github.com/hashicorp/terraform-json v0.14.0 github.com/hashicorp/yamux v0.0.0-20220718163420-dd80a7ee44ce + github.com/hinshun/vt10x v0.0.0-20220301184237-5011da428d02 github.com/imulab/go-scim/pkg/v2 v2.2.0 github.com/jedib0t/go-pretty/v6 v6.4.0 github.com/jmoiron/sqlx v1.3.5 diff --git a/go.sum b/go.sum index f6ad4771f8f79..1ec2c9fd20669 100644 --- a/go.sum +++ b/go.sum @@ -972,8 +972,9 @@ github.com/hashicorp/yamux v0.0.0-20220718163420-dd80a7ee44ce h1:7FO+LmZwiG/eDsB github.com/hashicorp/yamux v0.0.0-20220718163420-dd80a7ee44ce/go.mod h1:CtWFDAQgb7dxtzFs4tWbplKIe2jSi3+5vKbgIO0SLnQ= github.com/hdevalence/ed25519consensus v0.0.0-20220222234857-c00d1f31bab3 h1:aSVUgRRRtOrZOC1fYmY9gV0e9z/Iu+xNVSASWjsuyGU= github.com/hdevalence/ed25519consensus v0.0.0-20220222234857-c00d1f31bab3/go.mod h1:5PC6ZNPde8bBqU/ewGZig35+UIZtw9Ytxez8/q5ZyFE= -github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec h1:qv2VnGeEQHchGaZ/u7lxST/RaJw+cv273q79D81Xbog= github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec/go.mod h1:Q48J4R4DvxnHolD5P8pOtXigYlRuPLGl6moFx3ulM68= +github.com/hinshun/vt10x v0.0.0-20220301184237-5011da428d02 h1:AgcIVYPa6XJnU3phs104wLj8l5GEththEw6+F79YsIY= +github.com/hinshun/vt10x v0.0.0-20220301184237-5011da428d02/go.mod h1:Q48J4R4DvxnHolD5P8pOtXigYlRuPLGl6moFx3ulM68= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/huandu/xstrings v1.3.1/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= github.com/huandu/xstrings v1.3.2/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= diff --git a/pty/pty.go b/pty/pty.go index 4156e74caadee..507e9468e2084 100644 --- a/pty/pty.go +++ b/pty/pty.go @@ -3,7 +3,6 @@ package pty import ( "io" "log" - "os" "github.com/gliderlabs/ssh" "golang.org/x/xerrors" @@ -12,10 +11,33 @@ import ( // ErrClosed is returned when a PTY is used after it has been closed. var ErrClosed = xerrors.New("pty: closed") -// PTY is a minimal interface for interacting with a TTY. +// PTYCmd is an interface for interacting with a pseudo-TTY where we control +// only one end, and the other end has been passed to a running os.Process. +// nolint:revive +type PTYCmd interface { + io.Closer + + // Resize sets the size of the PTY. + Resize(height uint16, width uint16) error + + // OutputReader returns an io.Reader for reading the output from the process + // controlled by the pseudo-TTY + OutputReader() io.Reader + + // InputWriter returns an io.Writer for writing into to the process + // controlled by the pseudo-TTY + InputWriter() io.Writer +} + +// PTY is a minimal interface for interacting with pseudo-TTY where this +// process retains access to _both_ ends of the pseudo-TTY (i.e. `ptm` & `pts` +// on Linux). type PTY interface { io.Closer + // Resize sets the size of the PTY. + Resize(height uint16, width uint16) error + // Name of the TTY. Example on Linux would be "/dev/pts/1". Name() string @@ -34,14 +56,6 @@ type PTY interface { // // The same stream would be used to provide user input: pty.Input().Write(...) Input() ReadWriter - - // Dup returns a new file descriptor for the PTY. - // - // This is useful for closing stdin and stdout separately. - Dup() (*os.File, error) - - // Resize sets the size of the PTY. - Resize(height uint16, width uint16) error } // Process represents a process running in a PTY. We need to trigger special processing on the PTY @@ -108,8 +122,8 @@ func New(opts ...Option) (PTY, error) { // underlying file descriptors, one for reading and one for writing, and allows // them to be accessed separately. type ReadWriter struct { - Reader *os.File - Writer *os.File + Reader io.Reader + Writer io.Writer } func (rw ReadWriter) Read(p []byte) (int, error) { diff --git a/pty/pty_other.go b/pty/pty_other.go index f0a49184c80b9..a5fa9d555d545 100644 --- a/pty/pty_other.go +++ b/pty/pty_other.go @@ -3,15 +3,17 @@ package pty import ( + "io" + "io/fs" "os" "os/exec" "runtime" "sync" - "syscall" "github.com/creack/pty" "github.com/u-root/u-root/pkg/termios" "golang.org/x/sys/unix" + "golang.org/x/xerrors" ) func newPty(opt ...Option) (retPTY *otherPty, err error) { @@ -28,6 +30,7 @@ func newPty(opt ...Option) (retPTY *otherPty, err error) { pty: ptyFile, tty: ttyFile, opts: opts, + name: ttyFile.Name(), } defer func() { if err != nil { @@ -53,6 +56,7 @@ type otherPty struct { err error pty, tty *os.File opts ptyOptions + name string } func (p *otherPty) control(tty *os.File, fn func(fd uintptr) error) (err error) { @@ -85,7 +89,7 @@ func (p *otherPty) control(tty *os.File, fn func(fd uintptr) error) (err error) } func (p *otherPty) Name() string { - return p.tty.Name() + return p.name } func (p *otherPty) Input() ReadWriter { @@ -95,13 +99,21 @@ func (p *otherPty) Input() ReadWriter { } } +func (p *otherPty) InputWriter() io.Writer { + return p.pty +} + func (p *otherPty) Output() ReadWriter { return ReadWriter{ - Reader: p.pty, + Reader: &ptmReader{p.pty}, Writer: p.tty, } } +func (p *otherPty) OutputReader() io.Reader { + return &ptmReader{p.pty} +} + func (p *otherPty) Resize(height uint16, width uint16) error { return p.control(p.pty, func(fd uintptr) error { return termios.SetWinSize(fd, &termios.Winsize{ @@ -113,20 +125,6 @@ func (p *otherPty) Resize(height uint16, width uint16) error { }) } -func (p *otherPty) Dup() (*os.File, error) { - var newfd int - err := p.control(p.pty, func(fd uintptr) error { - var err error - newfd, err = syscall.Dup(int(fd)) - return err - }) - if err != nil { - return nil, err - } - - return os.NewFile(uintptr(newfd), p.pty.Name()), nil -} - func (p *otherPty) Close() error { p.mutex.Lock() defer p.mutex.Unlock() @@ -137,9 +135,12 @@ func (p *otherPty) Close() error { p.closed = true err := p.pty.Close() - err2 := p.tty.Close() - if err == nil { - err = err2 + // tty is closed & unset if we Start() a new process + if p.tty != nil { + err2 := p.tty.Close() + if err == nil { + err = err2 + } } if err != nil { @@ -177,3 +178,21 @@ func (p *otherProcess) waitInternal() { runtime.KeepAlive(p.pty) close(p.cmdDone) } + +// ptmReader wraps a reference to the ptm side of a pseudo-TTY for portability +type ptmReader struct { + ptm io.Reader +} + +func (r *ptmReader) Read(p []byte) (n int, err error) { + n, err = r.ptm.Read(p) + // output from the ptm will hit a PathErr when the process hangs up the + // other side (typically when the process exits, but could be earlier). For + // portability, and to fit with our use of io.Copy() to copy from the PTY, + // we want to translate this error into io.EOF + pathErr := &fs.PathError{} + if xerrors.As(err, &pathErr) { + return n, io.EOF + } + return n, err +} diff --git a/pty/pty_windows.go b/pty/pty_windows.go index b1afec6778be3..80f6b74f436e9 100644 --- a/pty/pty_windows.go +++ b/pty/pty_windows.go @@ -3,6 +3,7 @@ package pty import ( + "io" "os" "os/exec" "sync" @@ -21,7 +22,7 @@ var ( ) // See: https://docs.microsoft.com/en-us/windows/console/creating-a-pseudoconsole-session -func newPty(opt ...Option) (PTY, error) { +func newPty(opt ...Option) (*ptyWindows, error) { var opts ptyOptions for _, o := range opt { o(&opts) @@ -88,6 +89,7 @@ type windowsProcess struct { cmdDone chan any cmdErr error proc *os.Process + pw *ptyWindows } // Name returns the TTY name on Windows. @@ -104,6 +106,10 @@ func (p *ptyWindows) Output() ReadWriter { } } +func (p *ptyWindows) OutputReader() io.Reader { + return p.outputRead +} + func (p *ptyWindows) Input() ReadWriter { return ReadWriter{ Reader: p.inputRead, @@ -111,7 +117,17 @@ func (p *ptyWindows) Input() ReadWriter { } } +func (p *ptyWindows) InputWriter() io.Writer { + return p.inputWrite +} + func (p *ptyWindows) Resize(height uint16, width uint16) error { + // hold the lock, so we don't race with anyone trying to close the console + p.closeMutex.Lock() + defer p.closeMutex.Unlock() + if p.closed || p.console == windows.InvalidHandle { + return ErrClosed + } // Taken from: https://github.com/microsoft/hcsshim/blob/54a5ad86808d761e3e396aff3e2022840f39f9a8/internal/winapi/zsyscall_windows.go#L144 ret, _, err := procResizePseudoConsole.Call(uintptr(p.console), uintptr(*((*uint32)(unsafe.Pointer(&windows.Coord{ Y: int16(height), @@ -123,10 +139,6 @@ func (p *ptyWindows) Resize(height uint16, width uint16) error { return nil } -func (p *ptyWindows) Dup() (*os.File, error) { - return nil, xerrors.Errorf("not implemented") -} - func (p *ptyWindows) Close() error { p.closeMutex.Lock() defer p.closeMutex.Unlock() @@ -135,20 +147,54 @@ func (p *ptyWindows) Close() error { } p.closed = true - ret, _, err := procClosePseudoConsole.Call(uintptr(p.console)) - if ret < 0 { - return xerrors.Errorf("close pseudo console: %w", err) + // if we are running a command in the PTY, the corresponding *windowsProcess + // may have already closed the PseudoConsole when the command exited, so that + // output reads can get to EOF. In that case, we don't need to close it + // again here. + if p.console != windows.InvalidHandle { + ret, _, err := procClosePseudoConsole.Call(uintptr(p.console)) + if ret < 0 { + return xerrors.Errorf("close pseudo console: %w", err) + } + p.console = windows.InvalidHandle } - _ = p.outputWrite.Close() + // We always have these files _ = p.outputRead.Close() _ = p.inputWrite.Close() - _ = p.inputRead.Close() + // These get closed & unset if we Start() a new process. + if p.outputWrite != nil { + _ = p.outputWrite.Close() + } + if p.inputRead != nil { + _ = p.inputRead.Close() + } return nil } func (p *windowsProcess) waitInternal() { + // put this on the bottom of the defer stack since the next defer can write to p.cmdErr defer close(p.cmdDone) + defer func() { + // close the pseudoconsole handle when the process exits, if it hasn't already been closed. + // this is important because the PseudoConsole (conhost.exe) holds the write-end + // of the output pipe. If it is not closed, reads on that pipe will block, even though + // the command has exited. + // c.f. https://devblogs.microsoft.com/commandline/windows-command-line-introducing-the-windows-pseudo-console-conpty/ + p.pw.closeMutex.Lock() + defer p.pw.closeMutex.Unlock() + if p.pw.console != windows.InvalidHandle { + ret, _, err := procClosePseudoConsole.Call(uintptr(p.pw.console)) + if ret < 0 && p.cmdErr == nil { + // if we already have an error from the command, prefer that error + // but if the command succeeded and closing the PseudoConsole fails + // then record that error so that we have a chance to see it + p.cmdErr = err + } + p.pw.console = windows.InvalidHandle + } + }() + state, err := p.proc.Wait() if err != nil { p.cmdErr = err diff --git a/pty/ptytest/ptytest.go b/pty/ptytest/ptytest.go index 74331fbfaa1c5..69eb81026efbe 100644 --- a/pty/ptytest/ptytest.go +++ b/pty/ptytest/ptytest.go @@ -30,12 +30,21 @@ func New(t *testing.T, opts ...pty.Option) *PTY { ptty, err := pty.New(opts...) require.NoError(t, err) - return create(t, ptty, "cmd") + e := newExpecter(t, ptty.Output(), "cmd") + r := &PTY{ + outExpecter: e, + PTY: ptty, + } + // Ensure pty is cleaned up at the end of test. + t.Cleanup(func() { + _ = r.Close() + }) + return r } -// Start starts a new process asynchronously and returns a PTY and Process. -// It kills the process upon cleanup. -func Start(t *testing.T, cmd *exec.Cmd, opts ...pty.StartOption) (*PTY, pty.Process) { +// Start starts a new process asynchronously and returns a PTYCmd and Process. +// It kills the process and PTYCmd upon cleanup +func Start(t *testing.T, cmd *exec.Cmd, opts ...pty.StartOption) (*PTYCmd, pty.Process) { t.Helper() ptty, ps, err := pty.Start(cmd, opts...) @@ -44,10 +53,19 @@ func Start(t *testing.T, cmd *exec.Cmd, opts ...pty.StartOption) (*PTY, pty.Proc _ = ps.Kill() _ = ps.Wait() }) - return create(t, ptty, cmd.Args[0]), ps + ex := newExpecter(t, ptty.OutputReader(), cmd.Args[0]) + + r := &PTYCmd{ + outExpecter: ex, + PTYCmd: ptty, + } + t.Cleanup(func() { + _ = r.Close() + }) + return r, ps } -func create(t *testing.T, ptty pty.PTY, name string) *PTY { +func newExpecter(t *testing.T, r io.Reader, name string) outExpecter { // Use pipe for logging. logDone := make(chan struct{}) logr, logw := io.Pipe() @@ -57,37 +75,30 @@ func create(t *testing.T, ptty pty.PTY, name string) *PTY { out := newStdbuf() w := io.MultiWriter(logw, out) - tpty := &PTY{ + ex := outExpecter{ t: t, - PTY: ptty, out: out, name: name, runeReader: bufio.NewReaderSize(out, utf8.UTFMax), } - // Ensure pty is cleaned up at the end of test. - t.Cleanup(func() { - _ = tpty.Close() - }) logClose := func(name string, c io.Closer) { - tpty.logf("closing %s", name) + ex.logf("closing %s", name) err := c.Close() - tpty.logf("closed %s: %v", name, err) + ex.logf("closed %s: %v", name, err) } - // Set the actual close function for the tpty. - tpty.close = func(reason string) error { + // Set the actual close function for the outExpecter. + ex.close = func(reason string) error { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) defer cancel() - tpty.logf("closing tpty: %s", reason) + ex.logf("closing expecter: %s", reason) - // Close pty only so that the copy goroutine can consume the - // remainder of it's buffer and then exit. - logClose("pty", ptty) + // Caller needs to have closed the PTY so that copying can complete select { case <-ctx.Done(): - tpty.fatalf("close", "copy did not close in time") + ex.fatalf("close", "copy did not close in time") case <-copyDone: } @@ -95,22 +106,22 @@ func create(t *testing.T, ptty pty.PTY, name string) *PTY { logClose("logr", logr) select { case <-ctx.Done(): - tpty.fatalf("close", "log pipe did not close in time") + ex.fatalf("close", "log pipe did not close in time") case <-logDone: } - tpty.logf("closed tpty") + ex.logf("closed expecter") return nil } go func() { defer close(copyDone) - _, err := io.Copy(w, ptty.Output()) - tpty.logf("copy done: %v", err) - tpty.logf("closing out") + _, err := io.Copy(w, r) + ex.logf("copy done: %v", err) + ex.logf("closing out") err = out.closeErr(err) - tpty.logf("closed out: %v", err) + ex.logf("closed out: %v", err) }() // Log all output as part of test for easier debugging on errors. @@ -118,15 +129,14 @@ func create(t *testing.T, ptty pty.PTY, name string) *PTY { defer close(logDone) s := bufio.NewScanner(logr) for s.Scan() { - tpty.logf("%q", stripansi.Strip(s.Text())) + ex.logf("%q", stripansi.Strip(s.Text())) } }() - return tpty + return ex } -type PTY struct { - pty.PTY +type outExpecter struct { t *testing.T close func(reason string) error out *stdbuf @@ -135,38 +145,23 @@ type PTY struct { runeReader *bufio.Reader } -func (p *PTY) Close() error { - p.t.Helper() - - return p.close("close") -} - -func (p *PTY) Attach(inv *clibase.Invocation) *PTY { - p.t.Helper() - - inv.Stdout = p.Output() - inv.Stderr = p.Output() - inv.Stdin = p.Input() - return p -} - -func (p *PTY) ExpectMatch(str string) string { - p.t.Helper() +func (e *outExpecter) ExpectMatch(str string) string { + e.t.Helper() timeout, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) defer cancel() - return p.ExpectMatchContext(timeout, str) + return e.ExpectMatchContext(timeout, str) } // TODO(mafredri): Rename this to ExpectMatch when refactoring. -func (p *PTY) ExpectMatchContext(ctx context.Context, str string) string { - p.t.Helper() +func (e *outExpecter) ExpectMatchContext(ctx context.Context, str string) string { + e.t.Helper() var buffer bytes.Buffer - err := p.doMatchWithDeadline(ctx, "ExpectMatchContext", func() error { + err := e.doMatchWithDeadline(ctx, "ExpectMatchContext", func() error { for { - r, _, err := p.runeReader.ReadRune() + r, _, err := e.runeReader.ReadRune() if err != nil { return err } @@ -180,54 +175,54 @@ func (p *PTY) ExpectMatchContext(ctx context.Context, str string) string { } }) if err != nil { - p.fatalf("read error", "%v (wanted %q; got %q)", err, str, buffer.String()) + e.fatalf("read error", "%v (wanted %q; got %q)", err, str, buffer.String()) return "" } - p.logf("matched %q = %q", str, stripansi.Strip(buffer.String())) + e.logf("matched %q = %q", str, stripansi.Strip(buffer.String())) return buffer.String() } -func (p *PTY) Peek(ctx context.Context, n int) []byte { - p.t.Helper() +func (e *outExpecter) Peek(ctx context.Context, n int) []byte { + e.t.Helper() var out []byte - err := p.doMatchWithDeadline(ctx, "Peek", func() error { + err := e.doMatchWithDeadline(ctx, "Peek", func() error { var err error - out, err = p.runeReader.Peek(n) + out, err = e.runeReader.Peek(n) return err }) if err != nil { - p.fatalf("read error", "%v (wanted %d bytes; got %d: %q)", err, n, len(out), out) + e.fatalf("read error", "%v (wanted %d bytes; got %d: %q)", err, n, len(out), out) return nil } - p.logf("peeked %d/%d bytes = %q", len(out), n, out) + e.logf("peeked %d/%d bytes = %q", len(out), n, out) return slices.Clone(out) } -func (p *PTY) ReadRune(ctx context.Context) rune { - p.t.Helper() +func (e *outExpecter) ReadRune(ctx context.Context) rune { + e.t.Helper() var r rune - err := p.doMatchWithDeadline(ctx, "ReadRune", func() error { + err := e.doMatchWithDeadline(ctx, "ReadRune", func() error { var err error - r, _, err = p.runeReader.ReadRune() + r, _, err = e.runeReader.ReadRune() return err }) if err != nil { - p.fatalf("read error", "%v (wanted rune; got %q)", err, r) + e.fatalf("read error", "%v (wanted rune; got %q)", err, r) return 0 } - p.logf("matched rune = %q", r) + e.logf("matched rune = %q", r) return r } -func (p *PTY) ReadLine(ctx context.Context) string { - p.t.Helper() +func (e *outExpecter) ReadLine(ctx context.Context) string { + e.t.Helper() var buffer bytes.Buffer - err := p.doMatchWithDeadline(ctx, "ReadLine", func() error { + err := e.doMatchWithDeadline(ctx, "ReadLine", func() error { for { - r, _, err := p.runeReader.ReadRune() + r, _, err := e.runeReader.ReadRune() if err != nil { return err } @@ -240,14 +235,14 @@ func (p *PTY) ReadLine(ctx context.Context) string { // Unicode code points can be up to 4 bytes, but the // ones we're looking for are only 1 byte. - b, _ := p.runeReader.Peek(1) + b, _ := e.runeReader.Peek(1) if len(b) == 0 { return nil } r, _ = utf8.DecodeRune(b) if r == '\n' { - _, _, err = p.runeReader.ReadRune() + _, _, err = e.runeReader.ReadRune() if err != nil { return err } @@ -263,21 +258,21 @@ func (p *PTY) ReadLine(ctx context.Context) string { } }) if err != nil { - p.fatalf("read error", "%v (wanted newline; got %q)", err, buffer.String()) + e.fatalf("read error", "%v (wanted newline; got %q)", err, buffer.String()) return "" } - p.logf("matched newline = %q", buffer.String()) + e.logf("matched newline = %q", buffer.String()) return buffer.String() } -func (p *PTY) doMatchWithDeadline(ctx context.Context, name string, fn func() error) error { - p.t.Helper() +func (e *outExpecter) doMatchWithDeadline(ctx context.Context, name string, fn func() error) error { + e.t.Helper() // A timeout is mandatory, caller can decide by passing a context // that times out. if _, ok := ctx.Deadline(); !ok { timeout := testutil.WaitMedium - p.logf("%s ctx has no deadline, using %s", name, timeout) + e.logf("%s ctx has no deadline, using %s", name, timeout) var cancel context.CancelFunc //nolint:gocritic // Rule guard doesn't detect that we're using testutil.Wait*. ctx, cancel = context.WithTimeout(ctx, timeout) @@ -294,13 +289,55 @@ func (p *PTY) doMatchWithDeadline(ctx context.Context, name string, fn func() er return err case <-ctx.Done(): // Ensure goroutine is cleaned up before test exit. - _ = p.close("match deadline exceeded") + _ = e.close("match deadline exceeded") <-match return xerrors.Errorf("match deadline exceeded: %w", ctx.Err()) } } +func (e *outExpecter) logf(format string, args ...interface{}) { + e.t.Helper() + + // Match regular logger timestamp format, we seem to be logging in + // UTC in other places as well, so match here. + e.t.Logf("%s: %s: %s", time.Now().UTC().Format("2006-01-02 15:04:05.000"), e.name, fmt.Sprintf(format, args...)) +} + +func (e *outExpecter) fatalf(reason string, format string, args ...interface{}) { + e.t.Helper() + + // Ensure the message is part of the normal log stream before + // failing the test. + e.logf("%s: %s", reason, fmt.Sprintf(format, args...)) + + require.FailNowf(e.t, reason, format, args...) +} + +type PTY struct { + outExpecter + pty.PTY +} + +func (p *PTY) Close() error { + p.t.Helper() + pErr := p.PTY.Close() + eErr := p.outExpecter.close("close") + if pErr != nil { + return pErr + } + return eErr +} + +func (p *PTY) Attach(inv *clibase.Invocation) *PTY { + p.t.Helper() + + inv.Stdout = p.Output() + inv.Stderr = p.Output() + inv.Stdin = p.Input() + return p +} + func (p *PTY) Write(r rune) { p.t.Helper() @@ -321,22 +358,19 @@ func (p *PTY) WriteLine(str string) { require.NoError(p.t, err, "write line failed") } -func (p *PTY) logf(format string, args ...interface{}) { - p.t.Helper() - - // Match regular logger timestamp format, we seem to be logging in - // UTC in other places as well, so match here. - p.t.Logf("%s: %s: %s", time.Now().UTC().Format("2006-01-02 15:04:05.000"), p.name, fmt.Sprintf(format, args...)) +type PTYCmd struct { + outExpecter + pty.PTYCmd } -func (p *PTY) fatalf(reason string, format string, args ...interface{}) { +func (p *PTYCmd) Close() error { p.t.Helper() - - // Ensure the message is part of the normal log stream before - // failing the test. - p.logf("%s: %s", reason, fmt.Sprintf(format, args...)) - - require.FailNowf(p.t, reason, format, args...) + pErr := p.PTYCmd.Close() + eErr := p.outExpecter.close("close") + if pErr != nil { + return pErr + } + return eErr } // stdbuf is like a buffered stdout, it buffers writes until read. diff --git a/pty/start.go b/pty/start.go index ea09cbb251767..565edaca43d80 100644 --- a/pty/start.go +++ b/pty/start.go @@ -20,6 +20,6 @@ func WithPTYOption(opts ...Option) StartOption { // Start the command in a TTY. The calling code must not use cmd after passing it to the PTY, and // instead rely on the returned Process to manage the command/process. -func Start(cmd *exec.Cmd, opt ...StartOption) (PTY, Process, error) { +func Start(cmd *exec.Cmd, opt ...StartOption) (PTYCmd, Process, error) { return startPty(cmd, opt...) } diff --git a/pty/start_other.go b/pty/start_other.go index c38b6dcf8faee..a6353a138d9aa 100644 --- a/pty/start_other.go +++ b/pty/start_other.go @@ -50,6 +50,17 @@ func startPty(cmd *exec.Cmd, opt ...StartOption) (retPTY *otherPty, proc Process } return nil, nil, xerrors.Errorf("start: %w", err) } + // Now that we've started the command, and passed the TTY to it, close our + // file so that the other process has the only open file to the TTY. Once + // the process closes the TTY (usually on exit), there will be no open + // references and the OS kernel returns an error when trying to read or + // write to our PTY end. Without this, reading from the process output + // will block until we close our TTY. + if err := opty.tty.Close(); err != nil { + _ = cmd.Process.Kill() + return nil, nil, xerrors.Errorf("close tty: %w", err) + } + opty.tty = nil // remove so we don't attempt to close it again. oProcess := &otherProcess{ pty: opty.pty, cmd: cmd, diff --git a/pty/start_other_test.go b/pty/start_other_test.go index d1f11a419e48f..264f7912a89cc 100644 --- a/pty/start_other_test.go +++ b/pty/start_other_test.go @@ -25,20 +25,25 @@ func TestStart(t *testing.T) { t.Run("Echo", func(t *testing.T) { t.Parallel() pty, ps := ptytest.Start(t, exec.Command("echo", "test")) + pty.ExpectMatch("test") err := ps.Wait() require.NoError(t, err) + err = pty.Close() + require.NoError(t, err) }) t.Run("Kill", func(t *testing.T) { t.Parallel() - _, ps := ptytest.Start(t, exec.Command("sleep", "30")) + pty, ps := ptytest.Start(t, exec.Command("sleep", "30")) err := ps.Kill() assert.NoError(t, err) err = ps.Wait() var exitErr *exec.ExitError require.True(t, xerrors.As(err, &exitErr)) assert.NotEqual(t, 0, exitErr.ExitCode()) + err = pty.Close() + require.NoError(t, err) }) t.Run("SSH_TTY", func(t *testing.T) { @@ -53,5 +58,29 @@ func TestStart(t *testing.T) { pty.ExpectMatch("SSH_TTY=/dev/") err := ps.Wait() require.NoError(t, err) + err = pty.Close() + require.NoError(t, err) }) } + +// these constants/vars are used by Test_Start_copy + +const cmdEcho = "echo" + +var argEcho = []string{"test"} + +// these constants/vars are used by Test_Start_truncate + +const ( + countEnd = 1000 + cmdCount = "sh" +) + +var argCount = []string{"-c", ` +i=0 +while [ $i -ne 1000 ] +do + i=$(($i+1)) + echo "$i" +done +`} diff --git a/pty/start_test.go b/pty/start_test.go new file mode 100644 index 0000000000000..d8711cb99c0a4 --- /dev/null +++ b/pty/start_test.go @@ -0,0 +1,148 @@ +package pty_test + +import ( + "bytes" + "context" + "fmt" + "io" + "os/exec" + "strings" + "testing" + "time" + + "github.com/hinshun/vt10x" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/pty" + "github.com/coder/coder/testutil" +) + +// Test_Start_copy tests that we can use io.Copy() on command output +// without deadlocking. +func Test_Start_copy(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + + pc, cmd, err := pty.Start(exec.CommandContext(ctx, cmdEcho, argEcho...)) + require.NoError(t, err) + b := &bytes.Buffer{} + readDone := make(chan error, 1) + go func() { + _, err := io.Copy(b, pc.OutputReader()) + readDone <- err + }() + + select { + case err := <-readDone: + require.NoError(t, err) + case <-ctx.Done(): + t.Error("read timed out") + } + assert.Contains(t, b.String(), "test") + + cmdDone := make(chan error, 1) + go func() { + cmdDone <- cmd.Wait() + }() + + select { + case err := <-cmdDone: + require.NoError(t, err) + case <-ctx.Done(): + t.Error("cmd.Wait() timed out") + } +} + +// Test_Start_truncation tests that we can read command output without truncation +// even after the command has exited. +func Test_Start_truncation(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong) + defer cancel() + + pc, cmd, err := pty.Start(exec.CommandContext(ctx, cmdCount, argCount...)) + + require.NoError(t, err) + readDone := make(chan struct{}) + go func() { + defer close(readDone) + // avoid buffered IO so that we can precisely control how many bytes to read. + n := 1 + for n <= countEnd { + want := fmt.Sprintf("%d", n) + err := readUntil(ctx, t, want, pc.OutputReader()) + assert.NoError(t, err, "want: %s", want) + if err != nil { + return + } + n++ + if (countEnd - n) < 100 { + // If the OS buffers the output, the process can exit even if + // we're not done reading. We want to slow our reads so that + // if there is a race between reading the data and it being + // truncated, we will lose and fail the test. + time.Sleep(testutil.IntervalFast) + } + } + // ensure we still get to EOF + endB := &bytes.Buffer{} + _, err := io.Copy(endB, pc.OutputReader()) + assert.NoError(t, err) + }() + + cmdDone := make(chan error, 1) + go func() { + cmdDone <- cmd.Wait() + }() + + select { + case err := <-cmdDone: + require.NoError(t, err) + case <-ctx.Done(): + t.Fatal("cmd.Wait() timed out") + } + + select { + case <-readDone: + // OK! + case <-ctx.Done(): + t.Fatal("read timed out") + } +} + +// readUntil reads one byte at a time until we either see the string we want, or the context expires +func readUntil(ctx context.Context, t *testing.T, want string, r io.Reader) error { + // output can contain virtual terminal sequences, so we need to parse these + // to correctly interpret getting what we want. + term := vt10x.New(vt10x.WithSize(80, 80)) + readErrs := make(chan error, 1) + for { + b := make([]byte, 1) + go func() { + _, err := r.Read(b) + readErrs <- err + }() + select { + case err := <-readErrs: + if err != nil { + t.Logf("err: %v\ngot: %v", err, term) + return err + } + term.Write(b) + case <-ctx.Done(): + return ctx.Err() + } + got := term.String() + lines := strings.Split(got, "\n") + for _, line := range lines { + if strings.TrimSpace(line) == want { + t.Logf("want: %v\n got:%v", want, line) + return nil + } + } + } +} diff --git a/pty/start_windows.go b/pty/start_windows.go index f9307cd364b84..2811900ffc361 100644 --- a/pty/start_windows.go +++ b/pty/start_windows.go @@ -17,7 +17,7 @@ import ( // Allocates a PTY and starts the specified command attached to it. // See: https://docs.microsoft.com/en-us/windows/console/creating-a-pseudoconsole-session#creating-the-hosted-process -func startPty(cmd *exec.Cmd, opt ...StartOption) (PTY, Process, error) { +func startPty(cmd *exec.Cmd, opt ...StartOption) (_ PTYCmd, _ Process, retErr error) { var opts startOptions for _, o := range opt { o(&opts) @@ -45,11 +45,18 @@ func startPty(cmd *exec.Cmd, opt ...StartOption) (PTY, Process, error) { if err != nil { return nil, nil, err } - pty, err := newPty(opts.ptyOpts...) + + winPty, err := newPty(opts.ptyOpts...) if err != nil { return nil, nil, err } - winPty := pty.(*ptyWindows) + defer func() { + if retErr != nil { + // we hit some error finishing setup; close pty, so + // we don't leak the kernel resources associated with it + _ = winPty.Close() + } + }() if winPty.opts.sshReq != nil { cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_TTY=%s", winPty.Name())) } @@ -95,9 +102,34 @@ func startPty(cmd *exec.Cmd, opt ...StartOption) (PTY, Process, error) { wp := &windowsProcess{ cmdDone: make(chan any), proc: process, + pw: winPty, + } + defer func() { + if retErr != nil { + // if we later error out, kill the process since + // the caller will have no way to interact with it + _ = process.Kill() + } + }() + + // Now that we've started the command, and passed the pseudoconsole to it, + // close the output write and input read files, so that the other process + // has the only handles to them. Once the process closes the console, there + // will be no open references and the OS kernel returns an error when trying + // to read or write to our end. Without this, reading from the process + // output will block until they are closed. + errO := winPty.outputWrite.Close() + winPty.outputWrite = nil + errI := winPty.inputRead.Close() + winPty.inputRead = nil + if errO != nil { + return nil, nil, errO + } + if errI != nil { + return nil, nil, errI } go wp.waitInternal() - return pty, wp, nil + return winPty, wp, nil } // Taken from: https://github.com/microsoft/hcsshim/blob/7fbdca16f91de8792371ba22b7305bf4ca84170a/internal/exec/exec.go#L476 diff --git a/pty/start_windows_test.go b/pty/start_windows_test.go index edbbd5dd99c3b..a8e287e1ed40a 100644 --- a/pty/start_windows_test.go +++ b/pty/start_windows_test.go @@ -4,6 +4,7 @@ package pty_test import ( + "fmt" "os/exec" "testing" @@ -22,25 +23,46 @@ func TestStart(t *testing.T) { t.Parallel() t.Run("Echo", func(t *testing.T) { t.Parallel() - pty, ps := ptytest.Start(t, exec.Command("cmd.exe", "/c", "echo", "test")) - pty.ExpectMatch("test") + ptty, ps := ptytest.Start(t, exec.Command("cmd.exe", "/c", "echo", "test")) + ptty.ExpectMatch("test") err := ps.Wait() require.NoError(t, err) + err = ptty.Close() + require.NoError(t, err) }) t.Run("Resize", func(t *testing.T) { t.Parallel() - pty, _ := ptytest.Start(t, exec.Command("cmd.exe")) - err := pty.Resize(100, 50) + ptty, _ := ptytest.Start(t, exec.Command("cmd.exe")) + err := ptty.Resize(100, 50) + require.NoError(t, err) + err = ptty.Close() require.NoError(t, err) }) t.Run("Kill", func(t *testing.T) { t.Parallel() - _, ps := ptytest.Start(t, exec.Command("cmd.exe")) + ptty, ps := ptytest.Start(t, exec.Command("cmd.exe")) err := ps.Kill() assert.NoError(t, err) err = ps.Wait() var exitErr *exec.ExitError require.True(t, xerrors.As(err, &exitErr)) assert.NotEqual(t, 0, exitErr.ExitCode()) + err = ptty.Close() + require.NoError(t, err) }) } + +// these constants/vars are used by Test_Start_copy + +const cmdEcho = "cmd.exe" + +var argEcho = []string{"/c", "echo", "test"} + +// these constants/vars are used by Test_Start_truncate + +const ( + countEnd = 1000 + cmdCount = "cmd.exe" +) + +var argCount = []string{"/c", fmt.Sprintf("for /L %%n in (1,1,%d) do @echo %%n", countEnd)}