From 0075d7da3108610bb2fd2c2a55394187baf4a2da Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Tue, 11 Apr 2023 15:00:05 +0400 Subject: [PATCH 01/21] Add ssh tests for longoutput, orphan Signed-off-by: Spike Curtis --- agent/agentssh/agentssh_internal_test.go | 207 +++++++++++++++++++++++ agent/agentssh/agentssh_test.go | 29 +--- 2 files changed, 210 insertions(+), 26 deletions(-) create mode 100644 agent/agentssh/agentssh_internal_test.go diff --git a/agent/agentssh/agentssh_internal_test.go b/agent/agentssh/agentssh_internal_test.go new file mode 100644 index 0000000000000..338ca78628299 --- /dev/null +++ b/agent/agentssh/agentssh_internal_test.go @@ -0,0 +1,207 @@ +package agentssh + +import ( + "bufio" + "context" + "net" + "strconv" + "testing" + "time" + + "golang.org/x/crypto/ssh" + + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/codersdk/agentsdk" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/atomic" +) + +const countingScript = ` +i=0 +while [ $i -ne 20000 ] +do + i=$(($i+1)) + echo "$i" +done +` + +func TestServer_sessionStart_longoutput(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + logger := slogtest.Make(t, nil) + s, err := NewServer(ctx, logger, 0) + require.NoError(t, err) + + // The assumption is that these are set before serving SSH connections. + s.AgentToken = func() string { return "" } + s.Manifest = atomic.NewPointer(&agentsdk.Manifest{}) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + defer close(done) + err := s.Serve(ln) + assert.Error(t, err) // Server is closed. + }() + + c := SSHTestClient(t, ln.Addr().String()) + sess, err := c.NewSession() + require.NoError(t, err) + + stdout, err := sess.StdoutPipe() + require.NoError(t, err) + readDone := make(chan struct{}) + go func() { + w := 0 + defer close(readDone) + s := bufio.NewScanner(stdout) + for s.Scan() { + w++ + ns := s.Text() + n, err := strconv.Atoi(ns) + require.NoError(t, err) + require.Equal(t, w, n, "output corrupted") + } + assert.Equal(t, w, 20000, "output truncated") + assert.NoError(t, s.Err()) + }() + + err = sess.Start(countingScript) + require.NoError(t, err) + + select { + case <-readDone: + // OK + case <-ctx.Done(): + t.Fatal("read timeout") + } + + sessionDone := make(chan struct{}) + go func() { + defer close(sessionDone) + err := sess.Wait() + assert.NoError(t, err) + }() + + select { + case <-sessionDone: + // OK! + case <-ctx.Done(): + t.Fatal("session timeout") + } +} + +const longScript = ` +echo "started" +sleep 30 +echo "done" +` + +func TestServer_sessionStart_orphan(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + logger := slogtest.Make(t, nil) + s, err := NewServer(ctx, logger, 0) + require.NoError(t, err) + + // The assumption is that these are set before serving SSH connections. + s.AgentToken = func() string { return "" } + s.Manifest = atomic.NewPointer(&agentsdk.Manifest{}) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + defer close(done) + err := s.Serve(ln) + assert.Error(t, err) // Server is closed. + }() + + c := SSHTestClient(t, ln.Addr().String()) + sess, err := c.NewSession() + require.NoError(t, err) + + stdout, err := sess.StdoutPipe() + require.NoError(t, err) + readDone := make(chan struct{}) + go func() { + defer close(readDone) + s := bufio.NewScanner(stdout) + require.True(t, s.Scan()) + txt := s.Text() + assert.Equal(t, "started", txt, "output corrupted") + }() + + err = sess.Start(longScript) + require.NoError(t, err) + + select { + case <-readDone: + // OK + case <-ctx.Done(): + t.Fatal("read timeout") + } + + // process is started, and should be sleeping for ~30 seconds + // close the session + err = sess.Close() + require.NoError(t, err) + + // now, we wait for the handler to complete. If it does so before the + // main test timeout, we consider this a pass. If not, it indicates + // that the server isn't properly shutting down sessions when they are + // disconnected client side, which could lead to processes hanging around + // indefinitely. + handlerDone := make(chan struct{}) + go func() { + defer close(handlerDone) + for { + select { + case <-time.After(time.Millisecond * 10): + s.mu.Lock() + n := len(s.sessions) + s.mu.Unlock() + if n == 0 { + return + } + } + } + }() + + select { + case <-handlerDone: + // OK! + case <-ctx.Done(): + t.Fatal("handler timeout") + } +} + +// SSHTestClient creates an ssh.Client for testing +func SSHTestClient(t *testing.T, addr string) *ssh.Client { + conn, err := net.Dial("tcp", addr) + require.NoError(t, err) + t.Cleanup(func() { + _ = conn.Close() + }) + + sshConn, channels, requests, err := ssh.NewClientConn(conn, "localhost:22", &ssh.ClientConfig{ + HostKeyCallback: ssh.InsecureIgnoreHostKey(), //nolint:gosec // This is a test. + }) + require.NoError(t, err) + t.Cleanup(func() { + _ = sshConn.Close() + }) + c := ssh.NewClient(sshConn, channels, requests) + t.Cleanup(func() { + _ = c.Close() + }) + return c +} diff --git a/agent/agentssh/agentssh_test.go b/agent/agentssh/agentssh_test.go index 684c0e36bbb18..6eea6accb9e7c 100644 --- a/agent/agentssh/agentssh_test.go +++ b/agent/agentssh/agentssh_test.go @@ -10,13 +10,11 @@ import ( "sync" "testing" + "cdr.dev/slog/sloggers/slogtest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/atomic" "go.uber.org/goleak" - "golang.org/x/crypto/ssh" - - "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/agent/agentssh" "github.com/coder/coder/codersdk/agentsdk" @@ -49,7 +47,7 @@ func TestNewServer_ServeClient(t *testing.T) { assert.Error(t, err) // Server is closed. }() - c := sshClient(t, ln.Addr().String()) + c := agentssh.SSHTestClient(t, ln.Addr().String()) var b bytes.Buffer sess, err := c.NewSession() sess.Stdout = &b @@ -95,7 +93,7 @@ func TestNewServer_CloseActiveConnections(t *testing.T) { doClose := make(chan struct{}) go func() { defer wg.Done() - c := sshClient(t, ln.Addr().String()) + c := agentssh.SSHTestClient(t, ln.Addr().String()) sess, err := c.NewSession() sess.Stdin = pty.Input() sess.Stdout = pty.Output() @@ -116,24 +114,3 @@ func TestNewServer_CloseActiveConnections(t *testing.T) { wg.Wait() } - -func sshClient(t *testing.T, addr string) *ssh.Client { - conn, err := net.Dial("tcp", addr) - require.NoError(t, err) - t.Cleanup(func() { - _ = conn.Close() - }) - - sshConn, channels, requests, err := ssh.NewClientConn(conn, "localhost:22", &ssh.ClientConfig{ - HostKeyCallback: ssh.InsecureIgnoreHostKey(), //nolint:gosec // This is a test. - }) - require.NoError(t, err) - t.Cleanup(func() { - _ = sshConn.Close() - }) - c := ssh.NewClient(sshConn, channels, requests) - t.Cleanup(func() { - _ = c.Close() - }) - return c -} From a491d4f99e1ce188b3df2fea8e6583f7f5520aa9 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Wed, 12 Apr 2023 13:52:01 +0400 Subject: [PATCH 02/21] PTY/SSH tests & improvements Signed-off-by: Spike Curtis --- agent/agentssh/agentssh.go | 196 ++++++++-------- agent/agentssh/agentssh_internal_test.go | 271 +++++++++++------------ agent/agentssh/agentssh_test.go | 116 +++++++++- pty/pty.go | 31 ++- pty/pty_other.go | 37 ++-- pty/pty_windows.go | 23 +- pty/ptytest/ptytest.go | 167 ++++++++------ pty/start.go | 2 +- pty/start_other.go | 11 + pty/start_windows.go | 16 ++ 10 files changed, 521 insertions(+), 349 deletions(-) diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index 86e1eb9e36af4..5430e891309eb 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "io" + "io/fs" "net" "os" "os/exec" @@ -191,7 +192,7 @@ func (s *Server) sessionHandler(session ssh.Session) { _ = session.Exit(0) } -func (s *Server) sessionStart(session ssh.Session) (retErr error) { +func (s *Server) sessionStart(session ssh.Session) error { ctx := session.Context() env := session.Environ() var magicType string @@ -233,102 +234,13 @@ func (s *Server) sessionStart(session ssh.Session) (retErr error) { sshPty, windowSize, isPty := session.Pty() if isPty { - // Disable minimal PTY emulation set by gliderlabs/ssh (NL-to-CRNL). - // See https://github.com/coder/coder/issues/3371. - session.DisablePTYEmulation() - - if !isQuietLogin(session.RawCommand()) { - manifest := s.Manifest.Load() - if manifest != nil { - err = showMOTD(session, manifest.MOTDFile) - if err != nil { - s.logger.Error(ctx, "show MOTD", slog.Error(err)) - } - } else { - s.logger.Warn(ctx, "metadata lookup failed, unable to show MOTD") - } - } - - cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term)) - - // The pty package sets `SSH_TTY` on supported platforms. - ptty, process, err := pty.Start(cmd, pty.WithPTYOption( - pty.WithSSHRequest(sshPty), - pty.WithLogger(slog.Stdlib(ctx, s.logger, slog.LevelInfo)), - )) - if err != nil { - return xerrors.Errorf("start command: %w", err) - } - var wg sync.WaitGroup - defer func() { - defer wg.Wait() - closeErr := ptty.Close() - if closeErr != nil { - s.logger.Warn(ctx, "failed to close tty", slog.Error(closeErr)) - if retErr == nil { - retErr = closeErr - } - } - }() - go func() { - for win := range windowSize { - resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width)) - // If the pty is closed, then command has exited, no need to log. - if resizeErr != nil && !errors.Is(resizeErr, pty.ErrClosed) { - s.logger.Warn(ctx, "failed to resize tty", slog.Error(resizeErr)) - } - } - }() - // We don't add input copy to wait group because - // it won't return until the session is closed. - go func() { - _, _ = io.Copy(ptty.Input(), session) - }() - - // In low parallelism scenarios, the command may exit and we may close - // the pty before the output copy has started. This can result in the - // output being lost. To avoid this, we wait for the output copy to - // start before waiting for the command to exit. This ensures that the - // output copy goroutine will be scheduled before calling close on the - // pty. This shouldn't be needed because of `pty.Dup()` below, but it - // may not be supported on all platforms. - outputCopyStarted := make(chan struct{}) - ptyOutput := func() io.ReadCloser { - defer close(outputCopyStarted) - // Try to dup so we can separate stdin and stdout closure. - // Once the original pty is closed, the dup will return - // input/output error once the buffered data has been read. - stdout, err := ptty.Dup() - if err == nil { - return stdout - } - // If we can't dup, we shouldn't close - // the fd since it's tied to stdin. - return readNopCloser{ptty.Output()} - } - wg.Add(1) - go func() { - // Ensure data is flushed to session on command exit, if we - // close the session too soon, we might lose data. - defer wg.Done() - - stdout := ptyOutput() - defer stdout.Close() - - _, _ = io.Copy(session, stdout) - }() - <-outputCopyStarted - - err = process.Wait() - var exitErr *exec.ExitError - // ExitErrors just mean the command we run returned a non-zero exit code, which is normal - // and not something to be concerned about. But, if it's something else, we should log it. - if err != nil && !xerrors.As(err, &exitErr) { - s.logger.Warn(ctx, "wait error", slog.Error(err)) - } - return err + return s.startPTYSession(session, cmd, sshPty, windowSize) } + return s.startNonPTYSession(session, cmd) +} + +func (s *Server) startNonPTYSession(session ssh.Session, cmd *exec.Cmd) error { cmd.Stdout = session cmd.Stderr = session.Stderr() // This blocks forever until stdin is received if we don't @@ -348,6 +260,100 @@ func (s *Server) sessionStart(session ssh.Session) (retErr error) { return cmd.Wait() } +// ptySession is the interface to the ssh.Session that startPTYSession uses +// we use an interface here so that we can fake it in tests. +type ptySession interface { + io.ReadWriter + Context() ssh.Context + DisablePTYEmulation() + RawCommand() string +} + +func (s *Server) startPTYSession(session ptySession, cmd *exec.Cmd, sshPty ssh.Pty, windowSize <-chan ssh.Window) (retErr error) { + ctx := session.Context() + // Disable minimal PTY emulation set by gliderlabs/ssh (NL-to-CRNL). + // See https://github.com/coder/coder/issues/3371. + session.DisablePTYEmulation() + + if !isQuietLogin(session.RawCommand()) { + manifest := s.Manifest.Load() + if manifest != nil { + err := showMOTD(session, manifest.MOTDFile) + if err != nil { + s.logger.Error(ctx, "show MOTD", slog.Error(err)) + } + } else { + s.logger.Warn(ctx, "metadata lookup failed, unable to show MOTD") + } + } + + cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term)) + + // The pty package sets `SSH_TTY` on supported platforms. + ptty, process, err := pty.Start(cmd, pty.WithPTYOption( + pty.WithSSHRequest(sshPty), + pty.WithLogger(slog.Stdlib(ctx, s.logger, slog.LevelInfo)), + )) + if err != nil { + return xerrors.Errorf("start command: %w", err) + } + defer func() { + closeErr := ptty.Close() + if closeErr != nil { + s.logger.Warn(ctx, "failed to close tty", slog.Error(closeErr)) + if retErr == nil { + retErr = closeErr + } + } + }() + go func() { + for win := range windowSize { + resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width)) + // If the pty is closed, then command has exited, no need to log. + if resizeErr != nil && !errors.Is(resizeErr, pty.ErrClosed) { + s.logger.Warn(ctx, "failed to resize tty", slog.Error(resizeErr)) + } + } + }() + + go func() { + _, _ = io.Copy(ptty.InputWriter(), session) + }() + + // We need to wait for the command output to finish copying. It's safe to + // just do this copy on the main handler goroutine because one of two things + // will happen: + // + // 1. The command completes & closes the TTY, which then triggers an error + // after we've Read() all the buffered data from the PTY. + // 2. The client hangs up, which cancels the command's Context, and go will + // kill the command's process. This then has the same effect as (1). + n, err := io.Copy(session, ptty.OutputReader()) + s.logger.Debug(ctx, "copy output done", slog.F("bytes", n), slog.Error(err)) + + // output from the ptty will hit a PathErr on the PTY when the process + // hangs up the other side (typically when the process exits, but could + // be earlier) + pathErr := &fs.PathError{} + if err != nil && !xerrors.As(err, &pathErr) { + return xerrors.Errorf("copy error: %w", err, err, err) + } + // We've gotten all the output, but we need to wait for the process to + // complete so that we can get the exit code. This returns + // immediately if the TTY was closed as part of the command exiting. + err = process.Wait() + var exitErr *exec.ExitError + // ExitErrors just mean the command we run returned a non-zero exit code, which is normal + // and not something to be concerned about. But, if it's something else, we should log it. + if err != nil && !xerrors.As(err, &exitErr) { + s.logger.Warn(ctx, "wait error", slog.Error(err)) + } + if err != nil { + return xerrors.Errorf("process wait: %w", err) + } + return nil +} + type readNopCloser struct{ io.Reader } // Close implements io.Closer. diff --git a/agent/agentssh/agentssh_internal_test.go b/agent/agentssh/agentssh_internal_test.go index 338ca78628299..1db767badd8b7 100644 --- a/agent/agentssh/agentssh_internal_test.go +++ b/agent/agentssh/agentssh_internal_test.go @@ -3,30 +3,29 @@ package agentssh import ( "bufio" "context" + "io" "net" - "strconv" + "os/exec" "testing" "time" - "golang.org/x/crypto/ssh" - "cdr.dev/slog/sloggers/slogtest" - "github.com/coder/coder/codersdk/agentsdk" + gliderssh "github.com/gliderlabs/ssh" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "go.uber.org/atomic" ) -const countingScript = ` -i=0 -while [ $i -ne 20000 ] -do - i=$(($i+1)) - echo "$i" -done +const longScript = ` +echo "started" +sleep 30 +echo "done" ` -func TestServer_sessionStart_longoutput(t *testing.T) { +// Test_sessionStart_orphan tests running a command that takes a long time to +// exit normally, and terminate the SSH session context early to verify that we +// return quickly and don't leave the command running as an "orphan" with no +// active SSH session. +func Test_sessionStart_orphan(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -35,173 +34,151 @@ func TestServer_sessionStart_longoutput(t *testing.T) { s, err := NewServer(ctx, logger, 0) require.NoError(t, err) - // The assumption is that these are set before serving SSH connections. - s.AgentToken = func() string { return "" } - s.Manifest = atomic.NewPointer(&agentsdk.Manifest{}) - - ln, err := net.Listen("tcp", "127.0.0.1:0") - require.NoError(t, err) + // Here we're going to call the handler directly with a faked SSH session + // that just uses io.Pipes instead of a network socket. There is a large + // variation in the time between closing the socket from the client side and + // the SSH server canceling the session Context, which would lead to a flaky + // test if we did it that way. So instead, we directly cancel the context + // in this test. + sessionCtx, sessionCancel := context.WithCancel(ctx) + toClient, fromClient, sess := newTestSession(sessionCtx) + ptyInfo := gliderssh.Pty{} + windowSize := make(chan gliderssh.Window) + close(windowSize) + // the command gets the session context so that Go will terminate it when + // the session expires. + cmd := exec.CommandContext(sessionCtx, "sh", "-c", longScript) done := make(chan struct{}) go func() { defer close(done) - err := s.Serve(ln) - assert.Error(t, err) // Server is closed. + // we don't really care what the error is here. In the larger scenario, + // the client has disconnected, so we can't return any error information + // to them. + _ = s.startPTYSession(sess, cmd, ptyInfo, windowSize) }() - c := SSHTestClient(t, ln.Addr().String()) - sess, err := c.NewSession() - require.NoError(t, err) - - stdout, err := sess.StdoutPipe() - require.NoError(t, err) readDone := make(chan struct{}) go func() { - w := 0 defer close(readDone) - s := bufio.NewScanner(stdout) - for s.Scan() { - w++ - ns := s.Text() - n, err := strconv.Atoi(ns) - require.NoError(t, err) - require.Equal(t, w, n, "output corrupted") - } - assert.Equal(t, w, 20000, "output truncated") - assert.NoError(t, s.Err()) + s := bufio.NewScanner(toClient) + require.True(t, s.Scan()) + txt := s.Text() + assert.Equal(t, "started", txt, "output corrupted") }() - err = sess.Start(countingScript) - require.NoError(t, err) + waitForChan(t, readDone, ctx, "read timeout") + // process is started, and should be sleeping for ~30 seconds - select { - case <-readDone: - // OK - case <-ctx.Done(): - t.Fatal("read timeout") - } + sessionCancel() - sessionDone := make(chan struct{}) - go func() { - defer close(sessionDone) - err := sess.Wait() - assert.NoError(t, err) - }() + // now, we wait for the handler to complete. If it does so before the + // main test timeout, we consider this a pass. If not, it indicates + // that the server isn't properly shutting down sessions when they are + // disconnected client side, which could lead to processes hanging around + // indefinitely. + waitForChan(t, done, ctx, "handler timeout") + + err = fromClient.Close() + require.NoError(t, err) +} +func waitForChan(t *testing.T, c <-chan struct{}, ctx context.Context, msg string) { + t.Helper() select { - case <-sessionDone: + case <-c: // OK! case <-ctx.Done(): - t.Fatal("session timeout") + t.Fatal(msg) } } -const longScript = ` -echo "started" -sleep 30 -echo "done" -` +type testSession struct { + ctx testSSHContext -func TestServer_sessionStart_orphan(t *testing.T) { - t.Parallel() + // c2p is the client -> pty buffer + toPty *io.PipeReader + // p2c is the pty -> client buffer + fromPty *io.PipeWriter +} - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - logger := slogtest.Make(t, nil) - s, err := NewServer(ctx, logger, 0) - require.NoError(t, err) +type testSSHContext struct { + context.Context +} - // The assumption is that these are set before serving SSH connections. - s.AgentToken = func() string { return "" } - s.Manifest = atomic.NewPointer(&agentsdk.Manifest{}) +func newTestSession(ctx context.Context) (toClient *io.PipeReader, fromClient *io.PipeWriter, s ptySession) { + toClient, fromPty := io.Pipe() + toPty, fromClient := io.Pipe() - ln, err := net.Listen("tcp", "127.0.0.1:0") - require.NoError(t, err) + return toClient, fromClient, &testSession{ + ctx: testSSHContext{ctx}, + toPty: toPty, + fromPty: fromPty, + } +} - done := make(chan struct{}) - go func() { - defer close(done) - err := s.Serve(ln) - assert.Error(t, err) // Server is closed. - }() +func (s *testSession) Context() gliderssh.Context { + return s.ctx +} - c := SSHTestClient(t, ln.Addr().String()) - sess, err := c.NewSession() - require.NoError(t, err) +func (s *testSession) DisablePTYEmulation() {} - stdout, err := sess.StdoutPipe() - require.NoError(t, err) - readDone := make(chan struct{}) - go func() { - defer close(readDone) - s := bufio.NewScanner(stdout) - require.True(t, s.Scan()) - txt := s.Text() - assert.Equal(t, "started", txt, "output corrupted") - }() +// RawCommand returns "quiet logon" so that the PTY handler doesn't attempt to +// write the message of the day, which will interfere with our tests. It writes +// the message of the day if it's a shell login (zero length RawCommand()). +func (s *testSession) RawCommand() string { return "quiet logon" } - err = sess.Start(longScript) - require.NoError(t, err) +func (s *testSession) Read(p []byte) (n int, err error) { + return s.toPty.Read(p) +} - select { - case <-readDone: - // OK - case <-ctx.Done(): - t.Fatal("read timeout") - } +func (s *testSession) Write(p []byte) (n int, err error) { + return s.fromPty.Write(p) +} - // process is started, and should be sleeping for ~30 seconds - // close the session - err = sess.Close() - require.NoError(t, err) +func (c testSSHContext) Lock() { + panic("not implemented") +} +func (c testSSHContext) Unlock() { + panic("not implemented") +} - // now, we wait for the handler to complete. If it does so before the - // main test timeout, we consider this a pass. If not, it indicates - // that the server isn't properly shutting down sessions when they are - // disconnected client side, which could lead to processes hanging around - // indefinitely. - handlerDone := make(chan struct{}) - go func() { - defer close(handlerDone) - for { - select { - case <-time.After(time.Millisecond * 10): - s.mu.Lock() - n := len(s.sessions) - s.mu.Unlock() - if n == 0 { - return - } - } - } - }() +// User returns the username used when establishing the SSH connection. +func (c testSSHContext) User() string { + panic("not implemented") +} - select { - case <-handlerDone: - // OK! - case <-ctx.Done(): - t.Fatal("handler timeout") - } +// SessionID returns the session hash. +func (c testSSHContext) SessionID() string { + panic("not implemented") } -// SSHTestClient creates an ssh.Client for testing -func SSHTestClient(t *testing.T, addr string) *ssh.Client { - conn, err := net.Dial("tcp", addr) - require.NoError(t, err) - t.Cleanup(func() { - _ = conn.Close() - }) +// ClientVersion returns the version reported by the client. +func (c testSSHContext) ClientVersion() string { + panic("not implemented") +} - sshConn, channels, requests, err := ssh.NewClientConn(conn, "localhost:22", &ssh.ClientConfig{ - HostKeyCallback: ssh.InsecureIgnoreHostKey(), //nolint:gosec // This is a test. - }) - require.NoError(t, err) - t.Cleanup(func() { - _ = sshConn.Close() - }) - c := ssh.NewClient(sshConn, channels, requests) - t.Cleanup(func() { - _ = c.Close() - }) - return c +// ServerVersion returns the version reported by the server. +func (c testSSHContext) ServerVersion() string { + panic("not implemented") +} + +// RemoteAddr returns the remote address for this connection. +func (c testSSHContext) RemoteAddr() net.Addr { + panic("not implemented") +} + +// LocalAddr returns the local address for this connection. +func (c testSSHContext) LocalAddr() net.Addr { + panic("not implemented") +} + +// Permissions returns the Permissions object used for this connection. +func (c testSSHContext) Permissions() *gliderssh.Permissions { + panic("not implemented") +} + +// SetValue allows you to easily write new values into the underlying context. +func (c testSSHContext) SetValue(key, value interface{}) { + panic("not implemented") } diff --git a/agent/agentssh/agentssh_test.go b/agent/agentssh/agentssh_test.go index 6eea6accb9e7c..22878b709d79a 100644 --- a/agent/agentssh/agentssh_test.go +++ b/agent/agentssh/agentssh_test.go @@ -3,18 +3,22 @@ package agentssh_test import ( + "bufio" "bytes" "context" "net" + "strconv" "strings" "sync" "testing" + "time" "cdr.dev/slog/sloggers/slogtest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/atomic" "go.uber.org/goleak" + "golang.org/x/crypto/ssh" "github.com/coder/coder/agent/agentssh" "github.com/coder/coder/codersdk/agentsdk" @@ -47,7 +51,7 @@ func TestNewServer_ServeClient(t *testing.T) { assert.Error(t, err) // Server is closed. }() - c := agentssh.SSHTestClient(t, ln.Addr().String()) + c := sshClient(t, ln.Addr().String()) var b bytes.Buffer sess, err := c.NewSession() sess.Stdout = &b @@ -89,11 +93,12 @@ func TestNewServer_CloseActiveConnections(t *testing.T) { }() pty := ptytest.New(t) + defer pty.Close() doClose := make(chan struct{}) go func() { defer wg.Done() - c := agentssh.SSHTestClient(t, ln.Addr().String()) + c := sshClient(t, ln.Addr().String()) sess, err := c.NewSession() sess.Stdin = pty.Input() sess.Stdout = pty.Output() @@ -114,3 +119,110 @@ func TestNewServer_CloseActiveConnections(t *testing.T) { wg.Wait() } + +const countingScript = ` +i=0 +while [ $i -ne 20000 ] +do + i=$(($i+1)) + echo "$i" +done +` + +// TestServer_sessionStart_longoutput is designed to test running a command that +// produces a lot of output and ensure we don't truncate the output returned +// over SSH. +func TestServer_sessionStart_longoutput(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + logger := slogtest.Make(t, nil) + s, err := agentssh.NewServer(ctx, logger, 0) + require.NoError(t, err) + + // The assumption is that these are set before serving SSH connections. + s.AgentToken = func() string { return "" } + s.Manifest = atomic.NewPointer(&agentsdk.Manifest{}) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + defer close(done) + err := s.Serve(ln) + assert.Error(t, err) // Server is closed. + }() + + c := sshClient(t, ln.Addr().String()) + sess, err := c.NewSession() + require.NoError(t, err) + + stdout, err := sess.StdoutPipe() + require.NoError(t, err) + readDone := make(chan struct{}) + go func() { + w := 0 + defer close(readDone) + s := bufio.NewScanner(stdout) + for s.Scan() { + w++ + ns := s.Text() + n, err := strconv.Atoi(ns) + require.NoError(t, err) + require.Equal(t, w, n, "output corrupted") + } + assert.Equal(t, w, 20000, "output truncated") + assert.NoError(t, s.Err()) + }() + + err = sess.Start(countingScript) + require.NoError(t, err) + + waitForChan(t, readDone, ctx, "read timeout") + + sessionDone := make(chan struct{}) + go func() { + defer close(sessionDone) + err := sess.Wait() + assert.NoError(t, err) + }() + + waitForChan(t, sessionDone, ctx, "session timeout") + err = s.Close() + require.NoError(t, err) + waitForChan(t, done, ctx, "timeout closing server") +} + +func waitForChan(t *testing.T, c <-chan struct{}, ctx context.Context, msg string) { + t.Helper() + select { + case <-c: + // OK! + case <-ctx.Done(): + t.Fatal(msg) + } +} + +// sshClient creates an ssh.Client for testing +func sshClient(t *testing.T, addr string) *ssh.Client { + conn, err := net.Dial("tcp", addr) + require.NoError(t, err) + t.Cleanup(func() { + _ = conn.Close() + }) + + sshConn, channels, requests, err := ssh.NewClientConn(conn, "localhost:22", &ssh.ClientConfig{ + HostKeyCallback: ssh.InsecureIgnoreHostKey(), //nolint:gosec // This is a test. + }) + require.NoError(t, err) + t.Cleanup(func() { + _ = sshConn.Close() + }) + c := ssh.NewClient(sshConn, channels, requests) + t.Cleanup(func() { + _ = c.Close() + }) + return c +} diff --git a/pty/pty.go b/pty/pty.go index 4156e74caadee..cae2daf0b67b6 100644 --- a/pty/pty.go +++ b/pty/pty.go @@ -12,10 +12,29 @@ import ( // ErrClosed is returned when a PTY is used after it has been closed. var ErrClosed = xerrors.New("pty: closed") -// PTY is a minimal interface for interacting with a TTY. -type PTY interface { +// PTYCmd is an interface for interacting with a pseudo-TTY where we control +// only one end, and the other end has been passed to a running os.Process. +type PTYCmd interface { io.Closer + // Resize sets the size of the PTY. + Resize(height uint16, width uint16) error + + // OutputReader returns an io.Reader for reading the output from the process + // controlled by the pseudo-TTY + OutputReader() io.Reader + + // InputWriter returns an io.Writer for writing into to the process + // controlled by the pseudo-TTY + InputWriter() io.Writer +} + +// PTY is a minimal interface for interacting with pseudo-TTY where this +// process retains access to _both_ ends of the pseudo-TTY (i.e. `ptm` & `pts` +// on Linux). +type PTY interface { + PTYCmd + // Name of the TTY. Example on Linux would be "/dev/pts/1". Name() string @@ -34,14 +53,6 @@ type PTY interface { // // The same stream would be used to provide user input: pty.Input().Write(...) Input() ReadWriter - - // Dup returns a new file descriptor for the PTY. - // - // This is useful for closing stdin and stdout separately. - Dup() (*os.File, error) - - // Resize sets the size of the PTY. - Resize(height uint16, width uint16) error } // Process represents a process running in a PTY. We need to trigger special processing on the PTY diff --git a/pty/pty_other.go b/pty/pty_other.go index f0a49184c80b9..efc01e04267f9 100644 --- a/pty/pty_other.go +++ b/pty/pty_other.go @@ -3,11 +3,11 @@ package pty import ( + "io" "os" "os/exec" "runtime" "sync" - "syscall" "github.com/creack/pty" "github.com/u-root/u-root/pkg/termios" @@ -28,6 +28,7 @@ func newPty(opt ...Option) (retPTY *otherPty, err error) { pty: ptyFile, tty: ttyFile, opts: opts, + name: ttyFile.Name(), } defer func() { if err != nil { @@ -53,6 +54,7 @@ type otherPty struct { err error pty, tty *os.File opts ptyOptions + name string } func (p *otherPty) control(tty *os.File, fn func(fd uintptr) error) (err error) { @@ -85,7 +87,7 @@ func (p *otherPty) control(tty *os.File, fn func(fd uintptr) error) (err error) } func (p *otherPty) Name() string { - return p.tty.Name() + return p.name } func (p *otherPty) Input() ReadWriter { @@ -95,6 +97,10 @@ func (p *otherPty) Input() ReadWriter { } } +func (p *otherPty) InputWriter() io.Writer { + return p.pty +} + func (p *otherPty) Output() ReadWriter { return ReadWriter{ Reader: p.pty, @@ -102,6 +108,10 @@ func (p *otherPty) Output() ReadWriter { } } +func (p *otherPty) OutputReader() io.Reader { + return p.pty +} + func (p *otherPty) Resize(height uint16, width uint16) error { return p.control(p.pty, func(fd uintptr) error { return termios.SetWinSize(fd, &termios.Winsize{ @@ -113,20 +123,6 @@ func (p *otherPty) Resize(height uint16, width uint16) error { }) } -func (p *otherPty) Dup() (*os.File, error) { - var newfd int - err := p.control(p.pty, func(fd uintptr) error { - var err error - newfd, err = syscall.Dup(int(fd)) - return err - }) - if err != nil { - return nil, err - } - - return os.NewFile(uintptr(newfd), p.pty.Name()), nil -} - func (p *otherPty) Close() error { p.mutex.Lock() defer p.mutex.Unlock() @@ -137,9 +133,12 @@ func (p *otherPty) Close() error { p.closed = true err := p.pty.Close() - err2 := p.tty.Close() - if err == nil { - err = err2 + // tty is closed & unset if we Start() a new process + if p.tty != nil { + err2 := p.tty.Close() + if err == nil { + err = err2 + } } if err != nil { diff --git a/pty/pty_windows.go b/pty/pty_windows.go index b1afec6778be3..21a7132b68daa 100644 --- a/pty/pty_windows.go +++ b/pty/pty_windows.go @@ -3,6 +3,7 @@ package pty import ( + "io" "os" "os/exec" "sync" @@ -104,6 +105,10 @@ func (p *ptyWindows) Output() ReadWriter { } } +func (p *ptyWindows) OutputReader() io.Reader { + return p.outputRead +} + func (p *ptyWindows) Input() ReadWriter { return ReadWriter{ Reader: p.inputRead, @@ -111,6 +116,10 @@ func (p *ptyWindows) Input() ReadWriter { } } +func (p *ptyWindows) InputWriter() io.Writer { + return p.inputWrite +} + func (p *ptyWindows) Resize(height uint16, width uint16) error { // Taken from: https://github.com/microsoft/hcsshim/blob/54a5ad86808d761e3e396aff3e2022840f39f9a8/internal/winapi/zsyscall_windows.go#L144 ret, _, err := procResizePseudoConsole.Call(uintptr(p.console), uintptr(*((*uint32)(unsafe.Pointer(&windows.Coord{ @@ -123,10 +132,6 @@ func (p *ptyWindows) Resize(height uint16, width uint16) error { return nil } -func (p *ptyWindows) Dup() (*os.File, error) { - return nil, xerrors.Errorf("not implemented") -} - func (p *ptyWindows) Close() error { p.closeMutex.Lock() defer p.closeMutex.Unlock() @@ -140,10 +145,16 @@ func (p *ptyWindows) Close() error { return xerrors.Errorf("close pseudo console: %w", err) } - _ = p.outputWrite.Close() + // We always have these files _ = p.outputRead.Close() _ = p.inputWrite.Close() - _ = p.inputRead.Close() + // These get closed & unset if we Start() a new process. + if p.outputWrite != nil { + _ = p.outputWrite.Close() + } + if p.inputRead.Close() != nil { + _ = p.inputRead.Close() + } return nil } diff --git a/pty/ptytest/ptytest.go b/pty/ptytest/ptytest.go index 74331fbfaa1c5..bd11e78244993 100644 --- a/pty/ptytest/ptytest.go +++ b/pty/ptytest/ptytest.go @@ -29,13 +29,21 @@ func New(t *testing.T, opts ...pty.Option) *PTY { ptty, err := pty.New(opts...) require.NoError(t, err) + // Ensure pty is cleaned up at the end of test. + t.Cleanup(func() { + _ = ptty.Close() + }) - return create(t, ptty, "cmd") + e := newExpecter(t, ptty.Output(), "cmd") + return &PTY{ + outExpecter: *e, + PTY: ptty, + } } // Start starts a new process asynchronously and returns a PTY and Process. // It kills the process upon cleanup. -func Start(t *testing.T, cmd *exec.Cmd, opts ...pty.StartOption) (*PTY, pty.Process) { +func Start(t *testing.T, cmd *exec.Cmd, opts ...pty.StartOption) (*PTYCmd, pty.Process) { t.Helper() ptty, ps, err := pty.Start(cmd, opts...) @@ -44,10 +52,15 @@ func Start(t *testing.T, cmd *exec.Cmd, opts ...pty.StartOption) (*PTY, pty.Proc _ = ps.Kill() _ = ps.Wait() }) - return create(t, ptty, cmd.Args[0]), ps + ex := newExpecter(t, ptty.OutputReader(), cmd.Args[0]) + + return &PTYCmd{ + outExpecter: *ex, + PTYCmd: ptty, + }, ps } -func create(t *testing.T, ptty pty.PTY, name string) *PTY { +func newExpecter(t *testing.T, r io.Reader, name string) *outExpecter { // Use pipe for logging. logDone := make(chan struct{}) logr, logw := io.Pipe() @@ -57,37 +70,30 @@ func create(t *testing.T, ptty pty.PTY, name string) *PTY { out := newStdbuf() w := io.MultiWriter(logw, out) - tpty := &PTY{ + ex := &outExpecter{ t: t, - PTY: ptty, out: out, name: name, runeReader: bufio.NewReaderSize(out, utf8.UTFMax), } - // Ensure pty is cleaned up at the end of test. - t.Cleanup(func() { - _ = tpty.Close() - }) logClose := func(name string, c io.Closer) { - tpty.logf("closing %s", name) + ex.logf("closing %s", name) err := c.Close() - tpty.logf("closed %s: %v", name, err) + ex.logf("closed %s: %v", name, err) } // Set the actual close function for the tpty. - tpty.close = func(reason string) error { + ex.close = func(reason string) error { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) defer cancel() - tpty.logf("closing tpty: %s", reason) + ex.logf("closing expecter: %s", reason) - // Close pty only so that the copy goroutine can consume the - // remainder of it's buffer and then exit. - logClose("pty", ptty) + // Caller needs to have closed the PTY so that copying can complete select { case <-ctx.Done(): - tpty.fatalf("close", "copy did not close in time") + ex.fatalf("close", "copy did not close in time") case <-copyDone: } @@ -95,22 +101,22 @@ func create(t *testing.T, ptty pty.PTY, name string) *PTY { logClose("logr", logr) select { case <-ctx.Done(): - tpty.fatalf("close", "log pipe did not close in time") + ex.fatalf("close", "log pipe did not close in time") case <-logDone: } - tpty.logf("closed tpty") + ex.logf("closed expecter") return nil } go func() { defer close(copyDone) - _, err := io.Copy(w, ptty.Output()) - tpty.logf("copy done: %v", err) - tpty.logf("closing out") + _, err := io.Copy(w, r) + ex.logf("copy done: %v", err) + ex.logf("closing out") err = out.closeErr(err) - tpty.logf("closed out: %v", err) + ex.logf("closed out: %v", err) }() // Log all output as part of test for easier debugging on errors. @@ -118,15 +124,14 @@ func create(t *testing.T, ptty pty.PTY, name string) *PTY { defer close(logDone) s := bufio.NewScanner(logr) for s.Scan() { - tpty.logf("%q", stripansi.Strip(s.Text())) + ex.logf("%q", stripansi.Strip(s.Text())) } }() - return tpty + return ex } -type PTY struct { - pty.PTY +type outExpecter struct { t *testing.T close func(reason string) error out *stdbuf @@ -135,10 +140,34 @@ type PTY struct { runeReader *bufio.Reader } +type PTY struct { + outExpecter + pty.PTY +} + +type PTYCmd struct { + outExpecter + pty.PTYCmd +} + func (p *PTY) Close() error { p.t.Helper() + pErr := p.PTY.Close() + eErr := p.outExpecter.close("close") + if pErr != nil { + return pErr + } + return eErr +} - return p.close("close") +func (p *PTYCmd) Close() error { + p.t.Helper() + pErr := p.PTYCmd.Close() + eErr := p.outExpecter.close("close") + if pErr != nil { + return pErr + } + return eErr } func (p *PTY) Attach(inv *clibase.Invocation) *PTY { @@ -150,23 +179,23 @@ func (p *PTY) Attach(inv *clibase.Invocation) *PTY { return p } -func (p *PTY) ExpectMatch(str string) string { - p.t.Helper() +func (e *outExpecter) ExpectMatch(str string) string { + e.t.Helper() timeout, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) defer cancel() - return p.ExpectMatchContext(timeout, str) + return e.ExpectMatchContext(timeout, str) } // TODO(mafredri): Rename this to ExpectMatch when refactoring. -func (p *PTY) ExpectMatchContext(ctx context.Context, str string) string { - p.t.Helper() +func (e *outExpecter) ExpectMatchContext(ctx context.Context, str string) string { + e.t.Helper() var buffer bytes.Buffer - err := p.doMatchWithDeadline(ctx, "ExpectMatchContext", func() error { + err := e.doMatchWithDeadline(ctx, "ExpectMatchContext", func() error { for { - r, _, err := p.runeReader.ReadRune() + r, _, err := e.runeReader.ReadRune() if err != nil { return err } @@ -180,54 +209,54 @@ func (p *PTY) ExpectMatchContext(ctx context.Context, str string) string { } }) if err != nil { - p.fatalf("read error", "%v (wanted %q; got %q)", err, str, buffer.String()) + e.fatalf("read error", "%v (wanted %q; got %q)", err, str, buffer.String()) return "" } - p.logf("matched %q = %q", str, stripansi.Strip(buffer.String())) + e.logf("matched %q = %q", str, stripansi.Strip(buffer.String())) return buffer.String() } -func (p *PTY) Peek(ctx context.Context, n int) []byte { - p.t.Helper() +func (e *outExpecter) Peek(ctx context.Context, n int) []byte { + e.t.Helper() var out []byte - err := p.doMatchWithDeadline(ctx, "Peek", func() error { + err := e.doMatchWithDeadline(ctx, "Peek", func() error { var err error - out, err = p.runeReader.Peek(n) + out, err = e.runeReader.Peek(n) return err }) if err != nil { - p.fatalf("read error", "%v (wanted %d bytes; got %d: %q)", err, n, len(out), out) + e.fatalf("read error", "%v (wanted %d bytes; got %d: %q)", err, n, len(out), out) return nil } - p.logf("peeked %d/%d bytes = %q", len(out), n, out) + e.logf("peeked %d/%d bytes = %q", len(out), n, out) return slices.Clone(out) } -func (p *PTY) ReadRune(ctx context.Context) rune { - p.t.Helper() +func (e *outExpecter) ReadRune(ctx context.Context) rune { + e.t.Helper() var r rune - err := p.doMatchWithDeadline(ctx, "ReadRune", func() error { + err := e.doMatchWithDeadline(ctx, "ReadRune", func() error { var err error - r, _, err = p.runeReader.ReadRune() + r, _, err = e.runeReader.ReadRune() return err }) if err != nil { - p.fatalf("read error", "%v (wanted rune; got %q)", err, r) + e.fatalf("read error", "%v (wanted rune; got %q)", err, r) return 0 } - p.logf("matched rune = %q", r) + e.logf("matched rune = %q", r) return r } -func (p *PTY) ReadLine(ctx context.Context) string { - p.t.Helper() +func (e *outExpecter) ReadLine(ctx context.Context) string { + e.t.Helper() var buffer bytes.Buffer - err := p.doMatchWithDeadline(ctx, "ReadLine", func() error { + err := e.doMatchWithDeadline(ctx, "ReadLine", func() error { for { - r, _, err := p.runeReader.ReadRune() + r, _, err := e.runeReader.ReadRune() if err != nil { return err } @@ -240,14 +269,14 @@ func (p *PTY) ReadLine(ctx context.Context) string { // Unicode code points can be up to 4 bytes, but the // ones we're looking for are only 1 byte. - b, _ := p.runeReader.Peek(1) + b, _ := e.runeReader.Peek(1) if len(b) == 0 { return nil } r, _ = utf8.DecodeRune(b) if r == '\n' { - _, _, err = p.runeReader.ReadRune() + _, _, err = e.runeReader.ReadRune() if err != nil { return err } @@ -263,21 +292,21 @@ func (p *PTY) ReadLine(ctx context.Context) string { } }) if err != nil { - p.fatalf("read error", "%v (wanted newline; got %q)", err, buffer.String()) + e.fatalf("read error", "%v (wanted newline; got %q)", err, buffer.String()) return "" } - p.logf("matched newline = %q", buffer.String()) + e.logf("matched newline = %q", buffer.String()) return buffer.String() } -func (p *PTY) doMatchWithDeadline(ctx context.Context, name string, fn func() error) error { - p.t.Helper() +func (e *outExpecter) doMatchWithDeadline(ctx context.Context, name string, fn func() error) error { + e.t.Helper() // A timeout is mandatory, caller can decide by passing a context // that times out. if _, ok := ctx.Deadline(); !ok { timeout := testutil.WaitMedium - p.logf("%s ctx has no deadline, using %s", name, timeout) + e.logf("%s ctx has no deadline, using %s", name, timeout) var cancel context.CancelFunc //nolint:gocritic // Rule guard doesn't detect that we're using testutil.Wait*. ctx, cancel = context.WithTimeout(ctx, timeout) @@ -294,7 +323,7 @@ func (p *PTY) doMatchWithDeadline(ctx context.Context, name string, fn func() er return err case <-ctx.Done(): // Ensure goroutine is cleaned up before test exit. - _ = p.close("match deadline exceeded") + _ = e.close("match deadline exceeded") <-match return xerrors.Errorf("match deadline exceeded: %w", ctx.Err()) @@ -321,22 +350,22 @@ func (p *PTY) WriteLine(str string) { require.NoError(p.t, err, "write line failed") } -func (p *PTY) logf(format string, args ...interface{}) { - p.t.Helper() +func (e *outExpecter) logf(format string, args ...interface{}) { + e.t.Helper() // Match regular logger timestamp format, we seem to be logging in // UTC in other places as well, so match here. - p.t.Logf("%s: %s: %s", time.Now().UTC().Format("2006-01-02 15:04:05.000"), p.name, fmt.Sprintf(format, args...)) + e.t.Logf("%s: %s: %s", time.Now().UTC().Format("2006-01-02 15:04:05.000"), e.name, fmt.Sprintf(format, args...)) } -func (p *PTY) fatalf(reason string, format string, args ...interface{}) { - p.t.Helper() +func (e *outExpecter) fatalf(reason string, format string, args ...interface{}) { + e.t.Helper() // Ensure the message is part of the normal log stream before // failing the test. - p.logf("%s: %s", reason, fmt.Sprintf(format, args...)) + e.logf("%s: %s", reason, fmt.Sprintf(format, args...)) - require.FailNowf(p.t, reason, format, args...) + require.FailNowf(e.t, reason, format, args...) } // stdbuf is like a buffered stdout, it buffers writes until read. diff --git a/pty/start.go b/pty/start.go index ea09cbb251767..565edaca43d80 100644 --- a/pty/start.go +++ b/pty/start.go @@ -20,6 +20,6 @@ func WithPTYOption(opts ...Option) StartOption { // Start the command in a TTY. The calling code must not use cmd after passing it to the PTY, and // instead rely on the returned Process to manage the command/process. -func Start(cmd *exec.Cmd, opt ...StartOption) (PTY, Process, error) { +func Start(cmd *exec.Cmd, opt ...StartOption) (PTYCmd, Process, error) { return startPty(cmd, opt...) } diff --git a/pty/start_other.go b/pty/start_other.go index c38b6dcf8faee..a6353a138d9aa 100644 --- a/pty/start_other.go +++ b/pty/start_other.go @@ -50,6 +50,17 @@ func startPty(cmd *exec.Cmd, opt ...StartOption) (retPTY *otherPty, proc Process } return nil, nil, xerrors.Errorf("start: %w", err) } + // Now that we've started the command, and passed the TTY to it, close our + // file so that the other process has the only open file to the TTY. Once + // the process closes the TTY (usually on exit), there will be no open + // references and the OS kernel returns an error when trying to read or + // write to our PTY end. Without this, reading from the process output + // will block until we close our TTY. + if err := opty.tty.Close(); err != nil { + _ = cmd.Process.Kill() + return nil, nil, xerrors.Errorf("close tty: %w", err) + } + opty.tty = nil // remove so we don't attempt to close it again. oProcess := &otherProcess{ pty: opty.pty, cmd: cmd, diff --git a/pty/start_windows.go b/pty/start_windows.go index f9307cd364b84..9b000878fbf5f 100644 --- a/pty/start_windows.go +++ b/pty/start_windows.go @@ -87,6 +87,22 @@ func startPty(cmd *exec.Cmd, opt ...StartOption) (PTY, Process, error) { } defer windows.CloseHandle(processInfo.Thread) defer windows.CloseHandle(processInfo.Process) + // Now that we've started the command, and passed the pseudoconsole to it, + // close the output write and input read files, so that the other process + // has the only handles to them. Once the process closes the console, there + // will be no open references and the OS kernel returns an error when trying + // to read or write to our end. Without this, reading from the process + // output will block until they are closed. + errO := winPty.outputWrite.Close() + winPty.outputWrite = nil + errI := winPty.inputRead.Close() + winPty.inputRead = nil + if errO != nil { + return nil, nil, errO + } + if errI != nil { + return nil, nil, errI + } process, err := os.FindProcess(int(processInfo.ProcessId)) if err != nil { From 2cf357aa753e0faafe1039493d2565404c4a928f Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Wed, 12 Apr 2023 15:14:04 +0400 Subject: [PATCH 03/21] Fix some tests Signed-off-by: Spike Curtis --- agent/agent.go | 6 +-- agent/agentssh/agentssh.go | 8 +-- agent/agentssh/agentssh_test.go | 90 --------------------------------- pty/pty.go | 5 +- pty/pty_other.go | 25 ++++++++- pty/ptytest/ptytest.go | 23 +++++---- pty/start_other_test.go | 9 +++- pty/start_windows_test.go | 6 +++ 8 files changed, 57 insertions(+), 115 deletions(-) diff --git a/agent/agent.go b/agent/agent.go index f538ef93b4af8..52f76ce184b4a 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -1045,7 +1045,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m if err = a.trackConnGoroutine(func() { buffer := make([]byte, 1024) for { - read, err := rpty.ptty.Output().Read(buffer) + read, err := rpty.ptty.OutputReader().Read(buffer) if err != nil { // When the PTY is closed, this is triggered. break @@ -1138,7 +1138,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m logger.Warn(ctx, "read conn", slog.Error(err)) return nil } - _, err = rpty.ptty.Input().Write([]byte(req.Data)) + _, err = rpty.ptty.InputWriter().Write([]byte(req.Data)) if err != nil { logger.Warn(ctx, "write to pty", slog.Error(err)) return nil @@ -1358,7 +1358,7 @@ type reconnectingPTY struct { circularBuffer *circbuf.Buffer circularBufferMutex sync.RWMutex timeout *time.Timer - ptty pty.PTY + ptty pty.PTYCmd } // Close ends all connections to the reconnecting diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index 5430e891309eb..7aa1159c5cf05 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -8,7 +8,6 @@ import ( "errors" "fmt" "io" - "io/fs" "net" "os" "os/exec" @@ -330,12 +329,7 @@ func (s *Server) startPTYSession(session ptySession, cmd *exec.Cmd, sshPty ssh.P // kill the command's process. This then has the same effect as (1). n, err := io.Copy(session, ptty.OutputReader()) s.logger.Debug(ctx, "copy output done", slog.F("bytes", n), slog.Error(err)) - - // output from the ptty will hit a PathErr on the PTY when the process - // hangs up the other side (typically when the process exits, but could - // be earlier) - pathErr := &fs.PathError{} - if err != nil && !xerrors.As(err, &pathErr) { + if err != nil { return xerrors.Errorf("copy error: %w", err, err, err) } // We've gotten all the output, but we need to wait for the process to diff --git a/agent/agentssh/agentssh_test.go b/agent/agentssh/agentssh_test.go index 22878b709d79a..c8d5e7638a127 100644 --- a/agent/agentssh/agentssh_test.go +++ b/agent/agentssh/agentssh_test.go @@ -3,15 +3,12 @@ package agentssh_test import ( - "bufio" "bytes" "context" "net" - "strconv" "strings" "sync" "testing" - "time" "cdr.dev/slog/sloggers/slogtest" "github.com/stretchr/testify/assert" @@ -93,7 +90,6 @@ func TestNewServer_CloseActiveConnections(t *testing.T) { }() pty := ptytest.New(t) - defer pty.Close() doClose := make(chan struct{}) go func() { @@ -120,92 +116,6 @@ func TestNewServer_CloseActiveConnections(t *testing.T) { wg.Wait() } -const countingScript = ` -i=0 -while [ $i -ne 20000 ] -do - i=$(($i+1)) - echo "$i" -done -` - -// TestServer_sessionStart_longoutput is designed to test running a command that -// produces a lot of output and ensure we don't truncate the output returned -// over SSH. -func TestServer_sessionStart_longoutput(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - logger := slogtest.Make(t, nil) - s, err := agentssh.NewServer(ctx, logger, 0) - require.NoError(t, err) - - // The assumption is that these are set before serving SSH connections. - s.AgentToken = func() string { return "" } - s.Manifest = atomic.NewPointer(&agentsdk.Manifest{}) - - ln, err := net.Listen("tcp", "127.0.0.1:0") - require.NoError(t, err) - - done := make(chan struct{}) - go func() { - defer close(done) - err := s.Serve(ln) - assert.Error(t, err) // Server is closed. - }() - - c := sshClient(t, ln.Addr().String()) - sess, err := c.NewSession() - require.NoError(t, err) - - stdout, err := sess.StdoutPipe() - require.NoError(t, err) - readDone := make(chan struct{}) - go func() { - w := 0 - defer close(readDone) - s := bufio.NewScanner(stdout) - for s.Scan() { - w++ - ns := s.Text() - n, err := strconv.Atoi(ns) - require.NoError(t, err) - require.Equal(t, w, n, "output corrupted") - } - assert.Equal(t, w, 20000, "output truncated") - assert.NoError(t, s.Err()) - }() - - err = sess.Start(countingScript) - require.NoError(t, err) - - waitForChan(t, readDone, ctx, "read timeout") - - sessionDone := make(chan struct{}) - go func() { - defer close(sessionDone) - err := sess.Wait() - assert.NoError(t, err) - }() - - waitForChan(t, sessionDone, ctx, "session timeout") - err = s.Close() - require.NoError(t, err) - waitForChan(t, done, ctx, "timeout closing server") -} - -func waitForChan(t *testing.T, c <-chan struct{}, ctx context.Context, msg string) { - t.Helper() - select { - case <-c: - // OK! - case <-ctx.Done(): - t.Fatal(msg) - } -} - -// sshClient creates an ssh.Client for testing func sshClient(t *testing.T, addr string) *ssh.Client { conn, err := net.Dial("tcp", addr) require.NoError(t, err) diff --git a/pty/pty.go b/pty/pty.go index cae2daf0b67b6..7bd4fa76ae020 100644 --- a/pty/pty.go +++ b/pty/pty.go @@ -3,7 +3,6 @@ package pty import ( "io" "log" - "os" "github.com/gliderlabs/ssh" "golang.org/x/xerrors" @@ -119,8 +118,8 @@ func New(opts ...Option) (PTY, error) { // underlying file descriptors, one for reading and one for writing, and allows // them to be accessed separately. type ReadWriter struct { - Reader *os.File - Writer *os.File + Reader io.Reader + Writer io.Writer } func (rw ReadWriter) Read(p []byte) (int, error) { diff --git a/pty/pty_other.go b/pty/pty_other.go index efc01e04267f9..1cc3a28f3ad83 100644 --- a/pty/pty_other.go +++ b/pty/pty_other.go @@ -4,11 +4,14 @@ package pty import ( "io" + "io/fs" "os" "os/exec" "runtime" "sync" + "golang.org/x/xerrors" + "github.com/creack/pty" "github.com/u-root/u-root/pkg/termios" "golang.org/x/sys/unix" @@ -103,13 +106,13 @@ func (p *otherPty) InputWriter() io.Writer { func (p *otherPty) Output() ReadWriter { return ReadWriter{ - Reader: p.pty, + Reader: &ptmReader{p.pty}, Writer: p.tty, } } func (p *otherPty) OutputReader() io.Reader { - return p.pty + return &ptmReader{p.pty} } func (p *otherPty) Resize(height uint16, width uint16) error { @@ -176,3 +179,21 @@ func (p *otherProcess) waitInternal() { runtime.KeepAlive(p.pty) close(p.cmdDone) } + +// ptmReader wraps a reference to the ptm side of a pseudo-TTY for portability +type ptmReader struct { + ptm io.Reader +} + +func (r *ptmReader) Read(p []byte) (n int, err error) { + n, err = r.ptm.Read(p) + // output from the ptm will hit a PathErr when the process hangs up the + // other side (typically when the process exits, but could be earlier). For + // portability, and to fit with our use of io.Copy() to copy from the PTY, + // we want to translate this error into io.EOF + pathErr := &fs.PathError{} + if xerrors.As(err, &pathErr) { + return n, io.EOF + } + return n, err +} diff --git a/pty/ptytest/ptytest.go b/pty/ptytest/ptytest.go index bd11e78244993..25490658b6ce1 100644 --- a/pty/ptytest/ptytest.go +++ b/pty/ptytest/ptytest.go @@ -29,20 +29,21 @@ func New(t *testing.T, opts ...pty.Option) *PTY { ptty, err := pty.New(opts...) require.NoError(t, err) - // Ensure pty is cleaned up at the end of test. - t.Cleanup(func() { - _ = ptty.Close() - }) e := newExpecter(t, ptty.Output(), "cmd") - return &PTY{ + r := &PTY{ outExpecter: *e, PTY: ptty, } + // Ensure pty is cleaned up at the end of test. + t.Cleanup(func() { + _ = r.Close() + }) + return r } -// Start starts a new process asynchronously and returns a PTY and Process. -// It kills the process upon cleanup. +// Start starts a new process asynchronously and returns a PTYCmd and Process. +// It kills the process and PTYCmd upon cleanup func Start(t *testing.T, cmd *exec.Cmd, opts ...pty.StartOption) (*PTYCmd, pty.Process) { t.Helper() @@ -54,10 +55,14 @@ func Start(t *testing.T, cmd *exec.Cmd, opts ...pty.StartOption) (*PTYCmd, pty.P }) ex := newExpecter(t, ptty.OutputReader(), cmd.Args[0]) - return &PTYCmd{ + r := &PTYCmd{ outExpecter: *ex, PTYCmd: ptty, - }, ps + } + t.Cleanup(func() { + _ = r.Close() + }) + return r, ps } func newExpecter(t *testing.T, r io.Reader, name string) *outExpecter { diff --git a/pty/start_other_test.go b/pty/start_other_test.go index d1f11a419e48f..1aba204126e94 100644 --- a/pty/start_other_test.go +++ b/pty/start_other_test.go @@ -25,20 +25,25 @@ func TestStart(t *testing.T) { t.Run("Echo", func(t *testing.T) { t.Parallel() pty, ps := ptytest.Start(t, exec.Command("echo", "test")) + pty.ExpectMatch("test") err := ps.Wait() require.NoError(t, err) + err = pty.Close() + require.NoError(t, err) }) t.Run("Kill", func(t *testing.T) { t.Parallel() - _, ps := ptytest.Start(t, exec.Command("sleep", "30")) + pty, ps := ptytest.Start(t, exec.Command("sleep", "30")) err := ps.Kill() assert.NoError(t, err) err = ps.Wait() var exitErr *exec.ExitError require.True(t, xerrors.As(err, &exitErr)) assert.NotEqual(t, 0, exitErr.ExitCode()) + err = pty.Close() + require.NoError(t, err) }) t.Run("SSH_TTY", func(t *testing.T) { @@ -53,5 +58,7 @@ func TestStart(t *testing.T) { pty.ExpectMatch("SSH_TTY=/dev/") err := ps.Wait() require.NoError(t, err) + err = pty.Close() + require.NoError(t, err) }) } diff --git a/pty/start_windows_test.go b/pty/start_windows_test.go index edbbd5dd99c3b..d24b3aca4f793 100644 --- a/pty/start_windows_test.go +++ b/pty/start_windows_test.go @@ -26,12 +26,16 @@ func TestStart(t *testing.T) { pty.ExpectMatch("test") err := ps.Wait() require.NoError(t, err) + err = pty.Close() + require.NoError(t, err) }) t.Run("Resize", func(t *testing.T) { t.Parallel() pty, _ := ptytest.Start(t, exec.Command("cmd.exe")) err := pty.Resize(100, 50) require.NoError(t, err) + err = pty.Close() + require.NoError(t, err) }) t.Run("Kill", func(t *testing.T) { t.Parallel() @@ -42,5 +46,7 @@ func TestStart(t *testing.T) { var exitErr *exec.ExitError require.True(t, xerrors.As(err, &exitErr)) assert.NotEqual(t, 0, exitErr.ExitCode()) + err = pty.Close() + require.NoError(t, err) }) } From 28c0646a9490131930bdfdc8d03f9ef417a625c0 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Wed, 12 Apr 2023 12:05:27 +0000 Subject: [PATCH 04/21] Fix linting Signed-off-by: Spike Curtis --- agent/agentssh/agentssh.go | 12 ++----- agent/agentssh/agentssh_internal_test.go | 40 +++++++++++++----------- agent/agentssh/agentssh_test.go | 3 +- pty/pty.go | 1 + 4 files changed, 27 insertions(+), 29 deletions(-) diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index 7aa1159c5cf05..8c724e57612cc 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -235,11 +235,10 @@ func (s *Server) sessionStart(session ssh.Session) error { if isPty { return s.startPTYSession(session, cmd, sshPty, windowSize) } - - return s.startNonPTYSession(session, cmd) + return startNonPTYSession(session, cmd) } -func (s *Server) startNonPTYSession(session ssh.Session, cmd *exec.Cmd) error { +func startNonPTYSession(session ssh.Session, cmd *exec.Cmd) error { cmd.Stdout = session cmd.Stderr = session.Stderr() // This blocks forever until stdin is received if we don't @@ -330,7 +329,7 @@ func (s *Server) startPTYSession(session ptySession, cmd *exec.Cmd, sshPty ssh.P n, err := io.Copy(session, ptty.OutputReader()) s.logger.Debug(ctx, "copy output done", slog.F("bytes", n), slog.Error(err)) if err != nil { - return xerrors.Errorf("copy error: %w", err, err, err) + return xerrors.Errorf("copy error: %w", err) } // We've gotten all the output, but we need to wait for the process to // complete so that we can get the exit code. This returns @@ -348,11 +347,6 @@ func (s *Server) startPTYSession(session ptySession, cmd *exec.Cmd, sshPty ssh.P return nil } -type readNopCloser struct{ io.Reader } - -// Close implements io.Closer. -func (readNopCloser) Close() error { return nil } - func (s *Server) sftpHandler(session ssh.Session) { ctx := session.Context() diff --git a/agent/agentssh/agentssh_internal_test.go b/agent/agentssh/agentssh_internal_test.go index 1db767badd8b7..aaec46f526b9c 100644 --- a/agent/agentssh/agentssh_internal_test.go +++ b/agent/agentssh/agentssh_internal_test.go @@ -7,12 +7,14 @@ import ( "net" "os/exec" "testing" - "time" - "cdr.dev/slog/sloggers/slogtest" gliderssh "github.com/gliderlabs/ssh" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/coder/coder/testutil" + + "cdr.dev/slog/sloggers/slogtest" ) const longScript = ` @@ -28,7 +30,7 @@ echo "done" func Test_sessionStart_orphan(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) defer cancel() logger := slogtest.Make(t, nil) s, err := NewServer(ctx, logger, 0) @@ -62,12 +64,12 @@ func Test_sessionStart_orphan(t *testing.T) { go func() { defer close(readDone) s := bufio.NewScanner(toClient) - require.True(t, s.Scan()) + assert.True(t, s.Scan()) txt := s.Text() assert.Equal(t, "started", txt, "output corrupted") }() - waitForChan(t, readDone, ctx, "read timeout") + waitForChan(ctx, t, readDone, "read timeout") // process is started, and should be sleeping for ~30 seconds sessionCancel() @@ -77,13 +79,13 @@ func Test_sessionStart_orphan(t *testing.T) { // that the server isn't properly shutting down sessions when they are // disconnected client side, which could lead to processes hanging around // indefinitely. - waitForChan(t, done, ctx, "handler timeout") + waitForChan(ctx, t, done, "handler timeout") err = fromClient.Close() require.NoError(t, err) } -func waitForChan(t *testing.T, c <-chan struct{}, ctx context.Context, msg string) { +func waitForChan(ctx context.Context, t *testing.T, c <-chan struct{}, msg string) { t.Helper() select { case <-c: @@ -121,12 +123,12 @@ func (s *testSession) Context() gliderssh.Context { return s.ctx } -func (s *testSession) DisablePTYEmulation() {} +func (*testSession) DisablePTYEmulation() {} // RawCommand returns "quiet logon" so that the PTY handler doesn't attempt to // write the message of the day, which will interfere with our tests. It writes // the message of the day if it's a shell login (zero length RawCommand()). -func (s *testSession) RawCommand() string { return "quiet logon" } +func (*testSession) RawCommand() string { return "quiet logon" } func (s *testSession) Read(p []byte) (n int, err error) { return s.toPty.Read(p) @@ -136,49 +138,49 @@ func (s *testSession) Write(p []byte) (n int, err error) { return s.fromPty.Write(p) } -func (c testSSHContext) Lock() { +func (testSSHContext) Lock() { panic("not implemented") } -func (c testSSHContext) Unlock() { +func (testSSHContext) Unlock() { panic("not implemented") } // User returns the username used when establishing the SSH connection. -func (c testSSHContext) User() string { +func (testSSHContext) User() string { panic("not implemented") } // SessionID returns the session hash. -func (c testSSHContext) SessionID() string { +func (testSSHContext) SessionID() string { panic("not implemented") } // ClientVersion returns the version reported by the client. -func (c testSSHContext) ClientVersion() string { +func (testSSHContext) ClientVersion() string { panic("not implemented") } // ServerVersion returns the version reported by the server. -func (c testSSHContext) ServerVersion() string { +func (testSSHContext) ServerVersion() string { panic("not implemented") } // RemoteAddr returns the remote address for this connection. -func (c testSSHContext) RemoteAddr() net.Addr { +func (testSSHContext) RemoteAddr() net.Addr { panic("not implemented") } // LocalAddr returns the local address for this connection. -func (c testSSHContext) LocalAddr() net.Addr { +func (testSSHContext) LocalAddr() net.Addr { panic("not implemented") } // Permissions returns the Permissions object used for this connection. -func (c testSSHContext) Permissions() *gliderssh.Permissions { +func (testSSHContext) Permissions() *gliderssh.Permissions { panic("not implemented") } // SetValue allows you to easily write new values into the underlying context. -func (c testSSHContext) SetValue(key, value interface{}) { +func (testSSHContext) SetValue(_, _ interface{}) { panic("not implemented") } diff --git a/agent/agentssh/agentssh_test.go b/agent/agentssh/agentssh_test.go index c8d5e7638a127..684c0e36bbb18 100644 --- a/agent/agentssh/agentssh_test.go +++ b/agent/agentssh/agentssh_test.go @@ -10,13 +10,14 @@ import ( "sync" "testing" - "cdr.dev/slog/sloggers/slogtest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/atomic" "go.uber.org/goleak" "golang.org/x/crypto/ssh" + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/agent/agentssh" "github.com/coder/coder/codersdk/agentsdk" "github.com/coder/coder/pty/ptytest" diff --git a/pty/pty.go b/pty/pty.go index 7bd4fa76ae020..e93115dc9ef45 100644 --- a/pty/pty.go +++ b/pty/pty.go @@ -13,6 +13,7 @@ var ErrClosed = xerrors.New("pty: closed") // PTYCmd is an interface for interacting with a pseudo-TTY where we control // only one end, and the other end has been passed to a running os.Process. +// nolint:revive type PTYCmd interface { io.Closer From 872e357b171b2d3731f09296b9b3e79b8a081883 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Wed, 12 Apr 2023 12:08:12 +0000 Subject: [PATCH 05/21] fmt Signed-off-by: Spike Curtis --- agent/agentssh/agentssh_internal_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/agent/agentssh/agentssh_internal_test.go b/agent/agentssh/agentssh_internal_test.go index aaec46f526b9c..28d0c95ec3529 100644 --- a/agent/agentssh/agentssh_internal_test.go +++ b/agent/agentssh/agentssh_internal_test.go @@ -141,6 +141,7 @@ func (s *testSession) Write(p []byte) (n int, err error) { func (testSSHContext) Lock() { panic("not implemented") } + func (testSSHContext) Unlock() { panic("not implemented") } From 3f21e3099bfb8d2db026308f4ecdfa064e180257 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Thu, 13 Apr 2023 15:10:16 +0400 Subject: [PATCH 06/21] Fix windows test Signed-off-by: Spike Curtis --- pty/start_windows_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pty/start_windows_test.go b/pty/start_windows_test.go index d24b3aca4f793..8de0ce4ca6211 100644 --- a/pty/start_windows_test.go +++ b/pty/start_windows_test.go @@ -39,7 +39,7 @@ func TestStart(t *testing.T) { }) t.Run("Kill", func(t *testing.T) { t.Parallel() - _, ps := ptytest.Start(t, exec.Command("cmd.exe")) + pty, ps := ptytest.Start(t, exec.Command("cmd.exe")) err := ps.Kill() assert.NoError(t, err) err = ps.Wait() From e83ff6ef00143512e20d7f58418e3eb72d6c8c2e Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Fri, 14 Apr 2023 10:30:06 +0400 Subject: [PATCH 07/21] Windows copy test Signed-off-by: Spike Curtis --- pty/start_windows_test.go | 57 +++++++++++++++++++++++++++++++++------ 1 file changed, 49 insertions(+), 8 deletions(-) diff --git a/pty/start_windows_test.go b/pty/start_windows_test.go index 8de0ce4ca6211..e579da5fb9b70 100644 --- a/pty/start_windows_test.go +++ b/pty/start_windows_test.go @@ -4,9 +4,14 @@ package pty_test import ( + "bytes" + "context" + "io" "os/exec" "testing" + "time" + "github.com/coder/coder/pty" "github.com/coder/coder/pty/ptytest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -22,31 +27,67 @@ func TestStart(t *testing.T) { t.Parallel() t.Run("Echo", func(t *testing.T) { t.Parallel() - pty, ps := ptytest.Start(t, exec.Command("cmd.exe", "/c", "echo", "test")) - pty.ExpectMatch("test") + ptty, ps := ptytest.Start(t, exec.Command("cmd.exe", "/c", "echo", "test")) + ptty.ExpectMatch("test") err := ps.Wait() require.NoError(t, err) - err = pty.Close() + err = ptty.Close() require.NoError(t, err) }) t.Run("Resize", func(t *testing.T) { t.Parallel() - pty, _ := ptytest.Start(t, exec.Command("cmd.exe")) - err := pty.Resize(100, 50) + ptty, _ := ptytest.Start(t, exec.Command("cmd.exe")) + err := ptty.Resize(100, 50) require.NoError(t, err) - err = pty.Close() + err = ptty.Close() require.NoError(t, err) }) t.Run("Kill", func(t *testing.T) { t.Parallel() - pty, ps := ptytest.Start(t, exec.Command("cmd.exe")) + ptty, ps := ptytest.Start(t, exec.Command("cmd.exe")) err := ps.Kill() assert.NoError(t, err) err = ps.Wait() var exitErr *exec.ExitError require.True(t, xerrors.As(err, &exitErr)) assert.NotEqual(t, 0, exitErr.ExitCode()) - err = pty.Close() + err = ptty.Close() require.NoError(t, err) }) } + +func Test_Start_copy(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + pc, cmd, err := pty.Start(exec.CommandContext(ctx, "cmd.exe", "/c", "echo", "test")) + require.NoError(t, err) + b := &bytes.Buffer{} + readDone := make(chan error) + go func() { + _, err := io.Copy(b, pc.OutputReader()) + readDone <- err + }() + + select { + case err := <-readDone: + require.NoError(t, err) + case <-ctx.Done(): + t.Error("read timed out") + } + assert.Equal(t, "test", b.String()) + + cmdDone := make(chan error) + go func() { + cmdDone <- cmd.Wait() + }() + + select { + case err := <-cmdDone: + require.NoError(t, err) + case <-ctx.Done(): + t.Error("cmd.Wait() timed out") + } +} From b6105797201a0355d057886be7ba16f4164574d7 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Fri, 14 Apr 2023 15:36:13 +0400 Subject: [PATCH 08/21] WIP Windows pty handling Signed-off-by: Spike Curtis --- pty/pty_windows.go | 23 +++++++-- pty/start_windows.go | 1 + pty/start_windows_test.go | 101 +++++++++++++++++++++++++++++++++++++- 3 files changed, 121 insertions(+), 4 deletions(-) diff --git a/pty/pty_windows.go b/pty/pty_windows.go index 21a7132b68daa..8c4179be3bc6f 100644 --- a/pty/pty_windows.go +++ b/pty/pty_windows.go @@ -89,6 +89,7 @@ type windowsProcess struct { cmdDone chan any cmdErr error proc *os.Process + pw *ptyWindows } // Name returns the TTY name on Windows. @@ -140,9 +141,12 @@ func (p *ptyWindows) Close() error { } p.closed = true - ret, _, err := procClosePseudoConsole.Call(uintptr(p.console)) - if ret < 0 { - return xerrors.Errorf("close pseudo console: %w", err) + if p.console != windows.InvalidHandle { + ret, _, err := procClosePseudoConsole.Call(uintptr(p.console)) + if ret < 0 { + return xerrors.Errorf("close pseudo console: %w", err) + } + p.console = windows.InvalidHandle } // We always have these files @@ -159,6 +163,19 @@ func (p *ptyWindows) Close() error { } func (p *windowsProcess) waitInternal() { + defer func() { + // close the pseudoconsole handle when the process exits, if it hasn't already been closed. + p.pw.closeMutex.Lock() + defer p.pw.closeMutex.Unlock() + if p.pw.console != windows.InvalidHandle { + ret, _, err := procClosePseudoConsole.Call(uintptr(p.pw.console)) + if ret < 0 { + // not much we can do here... + panic(err) + } + p.pw.console = windows.InvalidHandle + } + }() defer close(p.cmdDone) state, err := p.proc.Wait() if err != nil { diff --git a/pty/start_windows.go b/pty/start_windows.go index 9b000878fbf5f..016822b10613a 100644 --- a/pty/start_windows.go +++ b/pty/start_windows.go @@ -111,6 +111,7 @@ func startPty(cmd *exec.Cmd, opt ...StartOption) (PTY, Process, error) { wp := &windowsProcess{ cmdDone: make(chan any), proc: process, + pw: winPty, } go wp.waitInternal() return pty, wp, nil diff --git a/pty/start_windows_test.go b/pty/start_windows_test.go index e579da5fb9b70..fe683bd16029c 100644 --- a/pty/start_windows_test.go +++ b/pty/start_windows_test.go @@ -6,8 +6,10 @@ package pty_test import ( "bytes" "context" + "fmt" "io" "os/exec" + "strings" "testing" "time" @@ -77,7 +79,7 @@ func Test_Start_copy(t *testing.T) { case <-ctx.Done(): t.Error("read timed out") } - assert.Equal(t, "test", b.String()) + assert.Contains(t, b.String(), "test") cmdDone := make(chan error) go func() { @@ -91,3 +93,100 @@ func Test_Start_copy(t *testing.T) { t.Error("cmd.Wait() timed out") } } + +const countEnd = 1000 + +func Test_Start_trucation(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*1000) + defer cancel() + + pc, cmd, err := pty.Start(exec.CommandContext(ctx, + "cmd.exe", "/c", + fmt.Sprintf("for /L %%n in (1,1,%d) do @echo %%n", countEnd))) + require.NoError(t, err) + readDone := make(chan struct{}) + go func() { + defer close(readDone) + // avoid buffered IO so that we can precisely control how many bytes to read. + n := 1 + for n < countEnd-25 { + want := fmt.Sprintf("%d\r\n", n) + // the output also contains virtual terminal sequences + // so just read until we see the number we want. + err := readUntil(ctx, want, pc.OutputReader()) + require.NoError(t, err, "want: %s", want) + n++ + } + }() + + select { + case <-readDone: + // OK! + case <-ctx.Done(): + t.Error("read timed out") + } + + cmdDone := make(chan error) + go func() { + cmdDone <- cmd.Wait() + }() + + select { + case err := <-cmdDone: + require.NoError(t, err) + case <-ctx.Done(): + t.Error("cmd.Wait() timed out") + } + + // do our final 25 reads, to make sure the output wasn't lost + readDone = make(chan struct{}) + go func() { + defer close(readDone) + // avoid buffered IO so that we can precisely control how many bytes to read. + n := countEnd - 25 + for n <= countEnd { + want := fmt.Sprintf("%d\r\n", n) + err := readUntil(ctx, want, pc.OutputReader()) + if n < countEnd { + require.NoError(t, err, "want: %s", want) + } else { + require.ErrorIs(t, err, io.EOF) + } + n++ + } + }() + + select { + case <-readDone: + // OK! + case <-ctx.Done(): + t.Error("read timed out") + } +} + +// readUntil reads one byte at a time until we either see the string we want, or the context expires +func readUntil(ctx context.Context, want string, r io.Reader) error { + got := "" + readErrs := make(chan error, 1) + for { + b := make([]byte, 1) + go func() { + _, err := r.Read(b) + readErrs <- err + }() + select { + case err := <-readErrs: + if err != nil { + return err + } + got = got + string(b) + case <-ctx.Done(): + return ctx.Err() + } + if strings.Contains(got, want) { + return nil + } + } +} From 90bfe94d9df6d90de99a6e000c950fa86b24c8f4 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Tue, 18 Apr 2023 10:30:23 +0400 Subject: [PATCH 09/21] Fix truncation tests Signed-off-by: Spike Curtis --- go.mod | 2 + go.sum | 2 + pty/pty_windows.go | 12 ++- pty/start_other_test.go | 20 +++++ pty/start_test.go | 158 ++++++++++++++++++++++++++++++++++++++ pty/start_windows_test.go | 138 ++------------------------------- 6 files changed, 197 insertions(+), 135 deletions(-) create mode 100644 pty/start_test.go diff --git a/go.mod b/go.mod index 12b05ec50ab3a..b7f3de4543d85 100644 --- a/go.mod +++ b/go.mod @@ -174,6 +174,8 @@ require ( tailscale.com v1.32.2 ) +require github.com/hinshun/vt10x v0.0.0-20220301184237-5011da428d02 // indirect + require ( cloud.google.com/go/compute v1.18.0 // indirect cloud.google.com/go/logging v1.6.1 // indirect diff --git a/go.sum b/go.sum index 62610990aaf7f..71ec67afb581c 100644 --- a/go.sum +++ b/go.sum @@ -1088,6 +1088,8 @@ github.com/hdevalence/ed25519consensus v0.0.0-20220222234857-c00d1f31bab3 h1:aSV github.com/hdevalence/ed25519consensus v0.0.0-20220222234857-c00d1f31bab3/go.mod h1:5PC6ZNPde8bBqU/ewGZig35+UIZtw9Ytxez8/q5ZyFE= github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec h1:qv2VnGeEQHchGaZ/u7lxST/RaJw+cv273q79D81Xbog= github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec/go.mod h1:Q48J4R4DvxnHolD5P8pOtXigYlRuPLGl6moFx3ulM68= +github.com/hinshun/vt10x v0.0.0-20220301184237-5011da428d02 h1:AgcIVYPa6XJnU3phs104wLj8l5GEththEw6+F79YsIY= +github.com/hinshun/vt10x v0.0.0-20220301184237-5011da428d02/go.mod h1:Q48J4R4DvxnHolD5P8pOtXigYlRuPLGl6moFx3ulM68= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/huandu/xstrings v1.0.0/go.mod h1:4qWG/gcEcfX4z/mBDHJ++3ReCw9ibxbsNJbcucJdbSo= github.com/huandu/xstrings v1.2.0/go.mod h1:DvyZB1rfVYsBIigL8HwpZgxHwXozlTgGqn63UyNX5k4= diff --git a/pty/pty_windows.go b/pty/pty_windows.go index 8c4179be3bc6f..d0a10cb2b5b04 100644 --- a/pty/pty_windows.go +++ b/pty/pty_windows.go @@ -141,6 +141,10 @@ func (p *ptyWindows) Close() error { } p.closed = true + // if we are running a command in the PTY, the corresponding *windowsProcess + // may have already closed the PseudoConsole when the command exited, so that + // output reads can get to EOF. In that case, we don't need to close it + // again here. if p.console != windows.InvalidHandle { ret, _, err := procClosePseudoConsole.Call(uintptr(p.console)) if ret < 0 { @@ -169,9 +173,11 @@ func (p *windowsProcess) waitInternal() { defer p.pw.closeMutex.Unlock() if p.pw.console != windows.InvalidHandle { ret, _, err := procClosePseudoConsole.Call(uintptr(p.pw.console)) - if ret < 0 { - // not much we can do here... - panic(err) + if ret < 0 && p.cmdErr == nil { + // if we already have an error from the command, prefer that error + // but if the command succeeded and closing the PseudoConsole fails + // then record that errror so that we have a chance to see it + p.cmdErr = err } p.pw.console = windows.InvalidHandle } diff --git a/pty/start_other_test.go b/pty/start_other_test.go index 1aba204126e94..de3dd2adbea45 100644 --- a/pty/start_other_test.go +++ b/pty/start_other_test.go @@ -62,3 +62,23 @@ func TestStart(t *testing.T) { require.NoError(t, err) }) } + +// these constants/vars are used by Test_Start_copy + +const cmdEcho = "echo" + +var argEcho = []string{"test"} + +// these constants/vars are used by Test_Start_truncate + +const countEnd = 1000 +const cmdCount = "sh" + +var argCount = []string{"-c", ` +i=0 +while [ $i -ne 1000 ] +do + i=$(($i+1)) + echo "$i" +done +`} diff --git a/pty/start_test.go b/pty/start_test.go new file mode 100644 index 0000000000000..69246388ae415 --- /dev/null +++ b/pty/start_test.go @@ -0,0 +1,158 @@ +package pty_test + +import ( + "bytes" + "context" + "fmt" + "io" + "os/exec" + "strings" + "testing" + "time" + + "github.com/coder/coder/pty" + "github.com/hinshun/vt10x" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Test_Start_copy tests that we can use io.Copy() on command output +// without deadlocking. +func Test_Start_copy(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + pc, cmd, err := pty.Start(exec.CommandContext(ctx, cmdEcho, argEcho...)) + require.NoError(t, err) + b := &bytes.Buffer{} + readDone := make(chan error) + go func() { + _, err := io.Copy(b, pc.OutputReader()) + readDone <- err + }() + + select { + case err := <-readDone: + require.NoError(t, err) + case <-ctx.Done(): + t.Error("read timed out") + } + assert.Contains(t, b.String(), "test") + + cmdDone := make(chan error) + go func() { + cmdDone <- cmd.Wait() + }() + + select { + case err := <-cmdDone: + require.NoError(t, err) + case <-ctx.Done(): + t.Error("cmd.Wait() timed out") + } +} + +// Test_Start_truncation tests that we can read command ouput without truncation +// even after the command has exited. +func Test_Start_trucation(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*1000) + defer cancel() + + pc, cmd, err := pty.Start(exec.CommandContext(ctx, cmdCount, argCount...)) + + require.NoError(t, err) + readDone := make(chan struct{}) + go func() { + defer close(readDone) + // avoid buffered IO so that we can precisely control how many bytes to read. + n := 1 + for n < countEnd-25 { + want := fmt.Sprintf("%d", n) + err := readUntil(ctx, t, want, pc.OutputReader()) + require.NoError(t, err, "want: %s", want) + n++ + } + }() + + select { + case <-readDone: + // OK! + case <-ctx.Done(): + t.Error("read timed out") + } + + cmdDone := make(chan error) + go func() { + cmdDone <- cmd.Wait() + }() + + select { + case err := <-cmdDone: + require.NoError(t, err) + case <-ctx.Done(): + t.Error("cmd.Wait() timed out") + } + + // do our final 25 reads, to make sure the output wasn't lost + readDone = make(chan struct{}) + go func() { + defer close(readDone) + // avoid buffered IO so that we can precisely control how many bytes to read. + n := countEnd - 25 + for n <= countEnd { + want := fmt.Sprintf("%d", n) + err := readUntil(ctx, t, want, pc.OutputReader()) + require.NoError(t, err, "want: %s", want) + n++ + } + // ensure we still get to EOF + endB := &bytes.Buffer{} + _, err := io.Copy(endB, pc.OutputReader()) + require.NoError(t, err) + + }() + + select { + case <-readDone: + // OK! + case <-ctx.Done(): + t.Error("read timed out") + } +} + +// readUntil reads one byte at a time until we either see the string we want, or the context expires +func readUntil(ctx context.Context, t *testing.T, want string, r io.Reader) error { + // output can contain virtual terminal sequences, so we need to parse these + // to correctly interpret getting what we want. + term := vt10x.New(vt10x.WithSize(80, 80)) + readErrs := make(chan error, 1) + for { + b := make([]byte, 1) + go func() { + _, err := r.Read(b) + readErrs <- err + }() + select { + case err := <-readErrs: + if err != nil { + t.Logf("err: %v\ngot: %v", err, term) + return err + } + term.Write(b) + case <-ctx.Done(): + return ctx.Err() + } + got := term.String() + lines := strings.Split(got, "\n") + for _, line := range lines { + if strings.TrimSpace(line) == want { + t.Logf("want: %v\n got:%v", want, line) + return nil + } + } + } +} diff --git a/pty/start_windows_test.go b/pty/start_windows_test.go index fe683bd16029c..0f1ebda092f94 100644 --- a/pty/start_windows_test.go +++ b/pty/start_windows_test.go @@ -4,16 +4,10 @@ package pty_test import ( - "bytes" - "context" "fmt" - "io" "os/exec" - "strings" "testing" - "time" - "github.com/coder/coder/pty" "github.com/coder/coder/pty/ptytest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -58,135 +52,15 @@ func TestStart(t *testing.T) { }) } -func Test_Start_copy(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() +// these constants/vars are used by Test_Start_copy - pc, cmd, err := pty.Start(exec.CommandContext(ctx, "cmd.exe", "/c", "echo", "test")) - require.NoError(t, err) - b := &bytes.Buffer{} - readDone := make(chan error) - go func() { - _, err := io.Copy(b, pc.OutputReader()) - readDone <- err - }() +const cmdEcho = "cmd.exe" - select { - case err := <-readDone: - require.NoError(t, err) - case <-ctx.Done(): - t.Error("read timed out") - } - assert.Contains(t, b.String(), "test") +var argEcho = []string{"/c", "echo", "test"} - cmdDone := make(chan error) - go func() { - cmdDone <- cmd.Wait() - }() - - select { - case err := <-cmdDone: - require.NoError(t, err) - case <-ctx.Done(): - t.Error("cmd.Wait() timed out") - } -} +// these constants/vars are used by Test_Start_truncate const countEnd = 1000 +const cmdCount = "cmd.exe" -func Test_Start_trucation(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*1000) - defer cancel() - - pc, cmd, err := pty.Start(exec.CommandContext(ctx, - "cmd.exe", "/c", - fmt.Sprintf("for /L %%n in (1,1,%d) do @echo %%n", countEnd))) - require.NoError(t, err) - readDone := make(chan struct{}) - go func() { - defer close(readDone) - // avoid buffered IO so that we can precisely control how many bytes to read. - n := 1 - for n < countEnd-25 { - want := fmt.Sprintf("%d\r\n", n) - // the output also contains virtual terminal sequences - // so just read until we see the number we want. - err := readUntil(ctx, want, pc.OutputReader()) - require.NoError(t, err, "want: %s", want) - n++ - } - }() - - select { - case <-readDone: - // OK! - case <-ctx.Done(): - t.Error("read timed out") - } - - cmdDone := make(chan error) - go func() { - cmdDone <- cmd.Wait() - }() - - select { - case err := <-cmdDone: - require.NoError(t, err) - case <-ctx.Done(): - t.Error("cmd.Wait() timed out") - } - - // do our final 25 reads, to make sure the output wasn't lost - readDone = make(chan struct{}) - go func() { - defer close(readDone) - // avoid buffered IO so that we can precisely control how many bytes to read. - n := countEnd - 25 - for n <= countEnd { - want := fmt.Sprintf("%d\r\n", n) - err := readUntil(ctx, want, pc.OutputReader()) - if n < countEnd { - require.NoError(t, err, "want: %s", want) - } else { - require.ErrorIs(t, err, io.EOF) - } - n++ - } - }() - - select { - case <-readDone: - // OK! - case <-ctx.Done(): - t.Error("read timed out") - } -} - -// readUntil reads one byte at a time until we either see the string we want, or the context expires -func readUntil(ctx context.Context, want string, r io.Reader) error { - got := "" - readErrs := make(chan error, 1) - for { - b := make([]byte, 1) - go func() { - _, err := r.Read(b) - readErrs <- err - }() - select { - case err := <-readErrs: - if err != nil { - return err - } - got = got + string(b) - case <-ctx.Done(): - return ctx.Err() - } - if strings.Contains(got, want) { - return nil - } - } -} +var argCount = []string{"/c", fmt.Sprintf("for /L %%n in (1,1,%d) do @echo %%n", countEnd)} From d6e131ccbb03b40285eef62bcc74cfec03fd6eeb Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Tue, 18 Apr 2023 07:53:18 +0000 Subject: [PATCH 10/21] Appease linter/fmt Signed-off-by: Spike Curtis --- pty/start_other_test.go | 6 ++++-- pty/start_test.go | 24 +++++++++++++++--------- pty/start_windows_test.go | 6 ++++-- 3 files changed, 23 insertions(+), 13 deletions(-) diff --git a/pty/start_other_test.go b/pty/start_other_test.go index de3dd2adbea45..264f7912a89cc 100644 --- a/pty/start_other_test.go +++ b/pty/start_other_test.go @@ -71,8 +71,10 @@ var argEcho = []string{"test"} // these constants/vars are used by Test_Start_truncate -const countEnd = 1000 -const cmdCount = "sh" +const ( + countEnd = 1000 + cmdCount = "sh" +) var argCount = []string{"-c", ` i=0 diff --git a/pty/start_test.go b/pty/start_test.go index 69246388ae415..4c8497de37800 100644 --- a/pty/start_test.go +++ b/pty/start_test.go @@ -8,12 +8,13 @@ import ( "os/exec" "strings" "testing" - "time" - "github.com/coder/coder/pty" "github.com/hinshun/vt10x" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/coder/coder/pty" + "github.com/coder/coder/testutil" ) // Test_Start_copy tests that we can use io.Copy() on command output @@ -21,7 +22,7 @@ import ( func Test_Start_copy(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) defer cancel() pc, cmd, err := pty.Start(exec.CommandContext(ctx, cmdEcho, argEcho...)) @@ -54,12 +55,12 @@ func Test_Start_copy(t *testing.T) { } } -// Test_Start_truncation tests that we can read command ouput without truncation +// Test_Start_truncation tests that we can read command output without truncation // even after the command has exited. func Test_Start_trucation(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), time.Second*1000) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() pc, cmd, err := pty.Start(exec.CommandContext(ctx, cmdCount, argCount...)) @@ -73,7 +74,10 @@ func Test_Start_trucation(t *testing.T) { for n < countEnd-25 { want := fmt.Sprintf("%d", n) err := readUntil(ctx, t, want, pc.OutputReader()) - require.NoError(t, err, "want: %s", want) + assert.NoError(t, err, "want: %s", want) + if err != nil { + return + } n++ } }() @@ -106,14 +110,16 @@ func Test_Start_trucation(t *testing.T) { for n <= countEnd { want := fmt.Sprintf("%d", n) err := readUntil(ctx, t, want, pc.OutputReader()) - require.NoError(t, err, "want: %s", want) + assert.NoError(t, err, "want: %s", want) + if err != nil { + return + } n++ } // ensure we still get to EOF endB := &bytes.Buffer{} _, err := io.Copy(endB, pc.OutputReader()) - require.NoError(t, err) - + assert.NoError(t, err) }() select { diff --git a/pty/start_windows_test.go b/pty/start_windows_test.go index 0f1ebda092f94..a8e287e1ed40a 100644 --- a/pty/start_windows_test.go +++ b/pty/start_windows_test.go @@ -60,7 +60,9 @@ var argEcho = []string{"/c", "echo", "test"} // these constants/vars are used by Test_Start_truncate -const countEnd = 1000 -const cmdCount = "cmd.exe" +const ( + countEnd = 1000 + cmdCount = "cmd.exe" +) var argCount = []string{"/c", fmt.Sprintf("for /L %%n in (1,1,%d) do @echo %%n", countEnd)} From 2c9c6efbc10dc2cac4a508518b0c3597e200387d Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Tue, 18 Apr 2023 10:17:02 +0000 Subject: [PATCH 11/21] Fix typo Signed-off-by: Spike Curtis --- pty/pty_windows.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pty/pty_windows.go b/pty/pty_windows.go index d0a10cb2b5b04..3a171c15d939f 100644 --- a/pty/pty_windows.go +++ b/pty/pty_windows.go @@ -176,7 +176,7 @@ func (p *windowsProcess) waitInternal() { if ret < 0 && p.cmdErr == nil { // if we already have an error from the command, prefer that error // but if the command succeeded and closing the PseudoConsole fails - // then record that errror so that we have a chance to see it + // then record that error so that we have a chance to see it p.cmdErr = err } p.pw.console = windows.InvalidHandle From e39e885590f5a4c434bfaeb614b3056de1c5d2a8 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Tue, 18 Apr 2023 14:36:38 +0400 Subject: [PATCH 12/21] Rework truncation test to not assume OS buffers Signed-off-by: Spike Curtis --- pty/start_test.go | 44 ++++++++++++++------------------------------ 1 file changed, 14 insertions(+), 30 deletions(-) diff --git a/pty/start_test.go b/pty/start_test.go index 4c8497de37800..3b6ca9b8a951e 100644 --- a/pty/start_test.go +++ b/pty/start_test.go @@ -8,6 +8,7 @@ import ( "os/exec" "strings" "testing" + "time" "github.com/hinshun/vt10x" "github.com/stretchr/testify/assert" @@ -60,7 +61,7 @@ func Test_Start_copy(t *testing.T) { func Test_Start_trucation(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong) defer cancel() pc, cmd, err := pty.Start(exec.CommandContext(ctx, cmdCount, argCount...)) @@ -71,7 +72,7 @@ func Test_Start_trucation(t *testing.T) { defer close(readDone) // avoid buffered IO so that we can precisely control how many bytes to read. n := 1 - for n < countEnd-25 { + for n <= countEnd { want := fmt.Sprintf("%d", n) err := readUntil(ctx, t, want, pc.OutputReader()) assert.NoError(t, err, "want: %s", want) @@ -79,16 +80,20 @@ func Test_Start_trucation(t *testing.T) { return } n++ + if (countEnd - n) < 100 { + // If the OS buffers the output, the process can exit even if + // we're not done reading. We want to slow our reads so that + // if there is a race between reading the data and it being + // truncated, we will lose and fail the test. + time.Sleep(testutil.IntervalFast) + } } + // ensure we still get to EOF + endB := &bytes.Buffer{} + _, err := io.Copy(endB, pc.OutputReader()) + assert.NoError(t, err) }() - select { - case <-readDone: - // OK! - case <-ctx.Done(): - t.Error("read timed out") - } - cmdDone := make(chan error) go func() { cmdDone <- cmd.Wait() @@ -101,27 +106,6 @@ func Test_Start_trucation(t *testing.T) { t.Error("cmd.Wait() timed out") } - // do our final 25 reads, to make sure the output wasn't lost - readDone = make(chan struct{}) - go func() { - defer close(readDone) - // avoid buffered IO so that we can precisely control how many bytes to read. - n := countEnd - 25 - for n <= countEnd { - want := fmt.Sprintf("%d", n) - err := readUntil(ctx, t, want, pc.OutputReader()) - assert.NoError(t, err, "want: %s", want) - if err != nil { - return - } - n++ - } - // ensure we still get to EOF - endB := &bytes.Buffer{} - _, err := io.Copy(endB, pc.OutputReader()) - assert.NoError(t, err) - }() - select { case <-readDone: // OK! From 8ec3d1fd3a1b9bee1ce40f303fa9df2d52248c5d Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Tue, 18 Apr 2023 15:15:19 +0400 Subject: [PATCH 13/21] Disable orphan test on Windows --- uses sh Signed-off-by: Spike Curtis --- agent/agentssh/agentssh_internal_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/agent/agentssh/agentssh_internal_test.go b/agent/agentssh/agentssh_internal_test.go index 28d0c95ec3529..699495ccd2339 100644 --- a/agent/agentssh/agentssh_internal_test.go +++ b/agent/agentssh/agentssh_internal_test.go @@ -1,3 +1,5 @@ +//go:build !windows + package agentssh import ( From df424e6b461c3cb3e8e251a9c8583473f90980d8 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Wed, 19 Apr 2023 10:17:17 +0400 Subject: [PATCH 14/21] agent_test running SSH in pty use ptytest.Start Signed-off-by: Spike Curtis --- agent/agent_test.go | 59 +++++++++++++-------------------------------- 1 file changed, 17 insertions(+), 42 deletions(-) diff --git a/agent/agent_test.go b/agent/agent_test.go index aa58f22ace474..939e92e6b9b3e 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -24,6 +24,8 @@ import ( "testing" "time" + "github.com/coder/coder/pty" + scp "github.com/bramvdbogaerde/go-scp" "github.com/google/uuid" "github.com/pion/udp" @@ -481,17 +483,10 @@ func TestAgent_TCPLocalForwarding(t *testing.T) { } }() - pty := ptytest.New(t) - - cmd := setupSSHCommand(t, []string{"-L", fmt.Sprintf("%d:127.0.0.1:%d", randomPort, remotePort)}, []string{"sleep", "5"}) - cmd.Stdin = pty.Input() - cmd.Stdout = pty.Output() - cmd.Stderr = pty.Output() - err = cmd.Start() - require.NoError(t, err) + _, proc := setupSSHCommand(t, []string{"-L", fmt.Sprintf("%d:127.0.0.1:%d", randomPort, remotePort)}, []string{"sleep", "5"}) go func() { - err := cmd.Wait() + err := proc.Wait() select { case <-done: default: @@ -523,7 +518,7 @@ func TestAgent_TCPLocalForwarding(t *testing.T) { <-done - _ = cmd.Process.Kill() + _ = proc.Kill() } //nolint:paralleltest // This test reserves a port. @@ -562,17 +557,10 @@ func TestAgent_TCPRemoteForwarding(t *testing.T) { } }() - pty := ptytest.New(t) - - cmd := setupSSHCommand(t, []string{"-R", fmt.Sprintf("127.0.0.1:%d:127.0.0.1:%d", randomPort, localPort)}, []string{"sleep", "5"}) - cmd.Stdin = pty.Input() - cmd.Stdout = pty.Output() - cmd.Stderr = pty.Output() - err = cmd.Start() - require.NoError(t, err) + _, proc := setupSSHCommand(t, []string{"-R", fmt.Sprintf("127.0.0.1:%d:127.0.0.1:%d", randomPort, localPort)}, []string{"sleep", "5"}) go func() { - err := cmd.Wait() + err := proc.Wait() select { case <-done: default: @@ -604,7 +592,7 @@ func TestAgent_TCPRemoteForwarding(t *testing.T) { <-done - _ = cmd.Process.Kill() + _ = proc.Kill() } func TestAgent_UnixLocalForwarding(t *testing.T) { @@ -641,17 +629,10 @@ func TestAgent_UnixLocalForwarding(t *testing.T) { } }() - pty := ptytest.New(t) - - cmd := setupSSHCommand(t, []string{"-L", fmt.Sprintf("%s:%s", localSocketPath, remoteSocketPath)}, []string{"sleep", "5"}) - cmd.Stdin = pty.Input() - cmd.Stdout = pty.Output() - cmd.Stderr = pty.Output() - err = cmd.Start() - require.NoError(t, err) + _, proc := setupSSHCommand(t, []string{"-L", fmt.Sprintf("%s:%s", localSocketPath, remoteSocketPath)}, []string{"sleep", "5"}) go func() { - err := cmd.Wait() + err := proc.Wait() select { case <-done: default: @@ -676,7 +657,7 @@ func TestAgent_UnixLocalForwarding(t *testing.T) { _ = conn.Close() <-done - _ = cmd.Process.Kill() + _ = proc.Kill() } func TestAgent_UnixRemoteForwarding(t *testing.T) { @@ -713,17 +694,10 @@ func TestAgent_UnixRemoteForwarding(t *testing.T) { } }() - pty := ptytest.New(t) - - cmd := setupSSHCommand(t, []string{"-R", fmt.Sprintf("%s:%s", remoteSocketPath, localSocketPath)}, []string{"sleep", "5"}) - cmd.Stdin = pty.Input() - cmd.Stdout = pty.Output() - cmd.Stderr = pty.Output() - err = cmd.Start() - require.NoError(t, err) + _, proc := setupSSHCommand(t, []string{"-R", fmt.Sprintf("%s:%s", remoteSocketPath, localSocketPath)}, []string{"sleep", "5"}) go func() { - err := cmd.Wait() + err := proc.Wait() select { case <-done: default: @@ -750,7 +724,7 @@ func TestAgent_UnixRemoteForwarding(t *testing.T) { <-done - _ = cmd.Process.Kill() + _ = proc.Kill() } func TestAgent_SFTP(t *testing.T) { @@ -1629,7 +1603,7 @@ func TestAgent_WriteVSCodeConfigs(t *testing.T) { }, testutil.WaitShort, testutil.IntervalFast) } -func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exec.Cmd { +func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) (*ptytest.PTYCmd, pty.Process) { //nolint:dogsled agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) listener, err := net.Listen("tcp", "127.0.0.1:0") @@ -1671,7 +1645,8 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exe "host", ) args = append(args, afterArgs...) - return exec.Command("ssh", args...) + cmd := exec.Command("ssh", args...) + return ptytest.Start(t, cmd) } func setupSSHSession(t *testing.T, options agentsdk.Manifest) *ssh.Session { From 50e3fec9f81f118005278b767741c0188857ca04 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Wed, 19 Apr 2023 10:23:27 +0400 Subject: [PATCH 15/21] More detail about closing pseudoconsole on windows Signed-off-by: Spike Curtis --- pty/pty_windows.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pty/pty_windows.go b/pty/pty_windows.go index 3a171c15d939f..49df5d029e080 100644 --- a/pty/pty_windows.go +++ b/pty/pty_windows.go @@ -169,6 +169,10 @@ func (p *ptyWindows) Close() error { func (p *windowsProcess) waitInternal() { defer func() { // close the pseudoconsole handle when the process exits, if it hasn't already been closed. + // this is important because the PseudoConsole (conhost.exe) holds the write-end + // of the output pipe. If it is not closed, reads on that pipe will block, even though + // the command has exited. + // c.f. https://devblogs.microsoft.com/commandline/windows-command-line-introducing-the-windows-pseudo-console-conpty/ p.pw.closeMutex.Lock() defer p.pw.closeMutex.Unlock() if p.pw.console != windows.InvalidHandle { From c09083ef56915c900b93707dcd86b267d8cf40f1 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Thu, 20 Apr 2023 11:07:32 +0000 Subject: [PATCH 16/21] Code review fixes Signed-off-by: Spike Curtis --- agent/agent_test.go | 3 +-- go.mod | 3 +-- go.sum | 1 - pty/pty_other.go | 3 +-- pty/pty_windows.go | 6 ++++-- pty/ptytest/ptytest.go | 8 ++++---- pty/start_test.go | 12 ++++++------ pty/start_windows.go | 38 +++++++++++++++++++++++++++----------- 8 files changed, 44 insertions(+), 30 deletions(-) diff --git a/agent/agent_test.go b/agent/agent_test.go index 939e92e6b9b3e..7d1ab52ff992c 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -24,8 +24,6 @@ import ( "testing" "time" - "github.com/coder/coder/pty" - scp "github.com/bramvdbogaerde/go-scp" "github.com/google/uuid" "github.com/pion/udp" @@ -47,6 +45,7 @@ import ( "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk/agentsdk" + "github.com/coder/coder/pty" "github.com/coder/coder/pty/ptytest" "github.com/coder/coder/tailnet" "github.com/coder/coder/tailnet/tailnettest" diff --git a/go.mod b/go.mod index b7f3de4543d85..636bc3e04d0bc 100644 --- a/go.mod +++ b/go.mod @@ -108,6 +108,7 @@ require ( github.com/hashicorp/terraform-config-inspect v0.0.0-20211115214459-90acf1ca460f github.com/hashicorp/terraform-json v0.14.0 github.com/hashicorp/yamux v0.0.0-20220718163420-dd80a7ee44ce + github.com/hinshun/vt10x v0.0.0-20220301184237-5011da428d02 github.com/imulab/go-scim/pkg/v2 v2.2.0 github.com/jedib0t/go-pretty/v6 v6.4.0 github.com/jmoiron/sqlx v1.3.5 @@ -174,8 +175,6 @@ require ( tailscale.com v1.32.2 ) -require github.com/hinshun/vt10x v0.0.0-20220301184237-5011da428d02 // indirect - require ( cloud.google.com/go/compute v1.18.0 // indirect cloud.google.com/go/logging v1.6.1 // indirect diff --git a/go.sum b/go.sum index 71ec67afb581c..4d4a2efb4dba1 100644 --- a/go.sum +++ b/go.sum @@ -1086,7 +1086,6 @@ github.com/hashicorp/yamux v0.0.0-20220718163420-dd80a7ee44ce h1:7FO+LmZwiG/eDsB github.com/hashicorp/yamux v0.0.0-20220718163420-dd80a7ee44ce/go.mod h1:CtWFDAQgb7dxtzFs4tWbplKIe2jSi3+5vKbgIO0SLnQ= github.com/hdevalence/ed25519consensus v0.0.0-20220222234857-c00d1f31bab3 h1:aSVUgRRRtOrZOC1fYmY9gV0e9z/Iu+xNVSASWjsuyGU= github.com/hdevalence/ed25519consensus v0.0.0-20220222234857-c00d1f31bab3/go.mod h1:5PC6ZNPde8bBqU/ewGZig35+UIZtw9Ytxez8/q5ZyFE= -github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec h1:qv2VnGeEQHchGaZ/u7lxST/RaJw+cv273q79D81Xbog= github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec/go.mod h1:Q48J4R4DvxnHolD5P8pOtXigYlRuPLGl6moFx3ulM68= github.com/hinshun/vt10x v0.0.0-20220301184237-5011da428d02 h1:AgcIVYPa6XJnU3phs104wLj8l5GEththEw6+F79YsIY= github.com/hinshun/vt10x v0.0.0-20220301184237-5011da428d02/go.mod h1:Q48J4R4DvxnHolD5P8pOtXigYlRuPLGl6moFx3ulM68= diff --git a/pty/pty_other.go b/pty/pty_other.go index 1cc3a28f3ad83..a5fa9d555d545 100644 --- a/pty/pty_other.go +++ b/pty/pty_other.go @@ -10,11 +10,10 @@ import ( "runtime" "sync" - "golang.org/x/xerrors" - "github.com/creack/pty" "github.com/u-root/u-root/pkg/termios" "golang.org/x/sys/unix" + "golang.org/x/xerrors" ) func newPty(opt ...Option) (retPTY *otherPty, err error) { diff --git a/pty/pty_windows.go b/pty/pty_windows.go index 49df5d029e080..400b1deecd438 100644 --- a/pty/pty_windows.go +++ b/pty/pty_windows.go @@ -160,13 +160,15 @@ func (p *ptyWindows) Close() error { if p.outputWrite != nil { _ = p.outputWrite.Close() } - if p.inputRead.Close() != nil { + if p.inputRead != nil { _ = p.inputRead.Close() } return nil } func (p *windowsProcess) waitInternal() { + // put this on the bottom of the defer stack since the next defer can write to p.cmdErr + defer close(p.cmdDone) defer func() { // close the pseudoconsole handle when the process exits, if it hasn't already been closed. // this is important because the PseudoConsole (conhost.exe) holds the write-end @@ -186,7 +188,7 @@ func (p *windowsProcess) waitInternal() { p.pw.console = windows.InvalidHandle } }() - defer close(p.cmdDone) + state, err := p.proc.Wait() if err != nil { p.cmdErr = err diff --git a/pty/ptytest/ptytest.go b/pty/ptytest/ptytest.go index 25490658b6ce1..47c2c104444c4 100644 --- a/pty/ptytest/ptytest.go +++ b/pty/ptytest/ptytest.go @@ -32,7 +32,7 @@ func New(t *testing.T, opts ...pty.Option) *PTY { e := newExpecter(t, ptty.Output(), "cmd") r := &PTY{ - outExpecter: *e, + outExpecter: e, PTY: ptty, } // Ensure pty is cleaned up at the end of test. @@ -56,7 +56,7 @@ func Start(t *testing.T, cmd *exec.Cmd, opts ...pty.StartOption) (*PTYCmd, pty.P ex := newExpecter(t, ptty.OutputReader(), cmd.Args[0]) r := &PTYCmd{ - outExpecter: *ex, + outExpecter: ex, PTYCmd: ptty, } t.Cleanup(func() { @@ -65,7 +65,7 @@ func Start(t *testing.T, cmd *exec.Cmd, opts ...pty.StartOption) (*PTYCmd, pty.P return r, ps } -func newExpecter(t *testing.T, r io.Reader, name string) *outExpecter { +func newExpecter(t *testing.T, r io.Reader, name string) outExpecter { // Use pipe for logging. logDone := make(chan struct{}) logr, logw := io.Pipe() @@ -75,7 +75,7 @@ func newExpecter(t *testing.T, r io.Reader, name string) *outExpecter { out := newStdbuf() w := io.MultiWriter(logw, out) - ex := &outExpecter{ + ex := outExpecter{ t: t, out: out, name: name, diff --git a/pty/start_test.go b/pty/start_test.go index 3b6ca9b8a951e..d8711cb99c0a4 100644 --- a/pty/start_test.go +++ b/pty/start_test.go @@ -29,7 +29,7 @@ func Test_Start_copy(t *testing.T) { pc, cmd, err := pty.Start(exec.CommandContext(ctx, cmdEcho, argEcho...)) require.NoError(t, err) b := &bytes.Buffer{} - readDone := make(chan error) + readDone := make(chan error, 1) go func() { _, err := io.Copy(b, pc.OutputReader()) readDone <- err @@ -43,7 +43,7 @@ func Test_Start_copy(t *testing.T) { } assert.Contains(t, b.String(), "test") - cmdDone := make(chan error) + cmdDone := make(chan error, 1) go func() { cmdDone <- cmd.Wait() }() @@ -58,7 +58,7 @@ func Test_Start_copy(t *testing.T) { // Test_Start_truncation tests that we can read command output without truncation // even after the command has exited. -func Test_Start_trucation(t *testing.T) { +func Test_Start_truncation(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong) @@ -94,7 +94,7 @@ func Test_Start_trucation(t *testing.T) { assert.NoError(t, err) }() - cmdDone := make(chan error) + cmdDone := make(chan error, 1) go func() { cmdDone <- cmd.Wait() }() @@ -103,14 +103,14 @@ func Test_Start_trucation(t *testing.T) { case err := <-cmdDone: require.NoError(t, err) case <-ctx.Done(): - t.Error("cmd.Wait() timed out") + t.Fatal("cmd.Wait() timed out") } select { case <-readDone: // OK! case <-ctx.Done(): - t.Error("read timed out") + t.Fatal("read timed out") } } diff --git a/pty/start_windows.go b/pty/start_windows.go index 016822b10613a..8c612e35eec28 100644 --- a/pty/start_windows.go +++ b/pty/start_windows.go @@ -17,7 +17,7 @@ import ( // Allocates a PTY and starts the specified command attached to it. // See: https://docs.microsoft.com/en-us/windows/console/creating-a-pseudoconsole-session#creating-the-hosted-process -func startPty(cmd *exec.Cmd, opt ...StartOption) (PTY, Process, error) { +func startPty(cmd *exec.Cmd, opt ...StartOption) (_ PTY, _ Process, retErr error) { var opts startOptions for _, o := range opt { o(&opts) @@ -45,10 +45,18 @@ func startPty(cmd *exec.Cmd, opt ...StartOption) (PTY, Process, error) { if err != nil { return nil, nil, err } + pty, err := newPty(opts.ptyOpts...) if err != nil { return nil, nil, err } + defer func() { + if retErr != nil { + // we hit some error finishing setup; close pty, so + // we don't leak the kernel resources associated with it + _ := pty.Close() + } + }() winPty := pty.(*ptyWindows) if winPty.opts.sshReq != nil { cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_TTY=%s", winPty.Name())) @@ -87,6 +95,24 @@ func startPty(cmd *exec.Cmd, opt ...StartOption) (PTY, Process, error) { } defer windows.CloseHandle(processInfo.Thread) defer windows.CloseHandle(processInfo.Process) + + process, err := os.FindProcess(int(processInfo.ProcessId)) + if err != nil { + return nil, nil, xerrors.Errorf("find process %d: %w", processInfo.ProcessId, err) + } + wp := &windowsProcess{ + cmdDone: make(chan any), + proc: process, + pw: winPty, + } + defer func() { + if retErr != nil { + // if we later error out, kill the process since + // the caller will have no way to interact with it + _ = process.Kill() + } + }() + // Now that we've started the command, and passed the pseudoconsole to it, // close the output write and input read files, so that the other process // has the only handles to them. Once the process closes the console, there @@ -103,16 +129,6 @@ func startPty(cmd *exec.Cmd, opt ...StartOption) (PTY, Process, error) { if errI != nil { return nil, nil, errI } - - process, err := os.FindProcess(int(processInfo.ProcessId)) - if err != nil { - return nil, nil, xerrors.Errorf("find process %d: %w", processInfo.ProcessId, err) - } - wp := &windowsProcess{ - cmdDone: make(chan any), - proc: process, - pw: winPty, - } go wp.waitInternal() return pty, wp, nil } From 439107df051261b04a4165711a768bb774bee117 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Thu, 20 Apr 2023 11:12:41 +0000 Subject: [PATCH 17/21] Rearrange ptytest method order Signed-off-by: Spike Curtis --- pty/ptytest/ptytest.go | 108 ++++++++++++++++++++--------------------- 1 file changed, 54 insertions(+), 54 deletions(-) diff --git a/pty/ptytest/ptytest.go b/pty/ptytest/ptytest.go index 47c2c104444c4..69eb81026efbe 100644 --- a/pty/ptytest/ptytest.go +++ b/pty/ptytest/ptytest.go @@ -88,7 +88,7 @@ func newExpecter(t *testing.T, r io.Reader, name string) outExpecter { err := c.Close() ex.logf("closed %s: %v", name, err) } - // Set the actual close function for the tpty. + // Set the actual close function for the outExpecter. ex.close = func(reason string) error { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) defer cancel() @@ -145,45 +145,6 @@ type outExpecter struct { runeReader *bufio.Reader } -type PTY struct { - outExpecter - pty.PTY -} - -type PTYCmd struct { - outExpecter - pty.PTYCmd -} - -func (p *PTY) Close() error { - p.t.Helper() - pErr := p.PTY.Close() - eErr := p.outExpecter.close("close") - if pErr != nil { - return pErr - } - return eErr -} - -func (p *PTYCmd) Close() error { - p.t.Helper() - pErr := p.PTYCmd.Close() - eErr := p.outExpecter.close("close") - if pErr != nil { - return pErr - } - return eErr -} - -func (p *PTY) Attach(inv *clibase.Invocation) *PTY { - p.t.Helper() - - inv.Stdout = p.Output() - inv.Stderr = p.Output() - inv.Stdin = p.Input() - return p -} - func (e *outExpecter) ExpectMatch(str string) string { e.t.Helper() @@ -335,6 +296,48 @@ func (e *outExpecter) doMatchWithDeadline(ctx context.Context, name string, fn f } } +func (e *outExpecter) logf(format string, args ...interface{}) { + e.t.Helper() + + // Match regular logger timestamp format, we seem to be logging in + // UTC in other places as well, so match here. + e.t.Logf("%s: %s: %s", time.Now().UTC().Format("2006-01-02 15:04:05.000"), e.name, fmt.Sprintf(format, args...)) +} + +func (e *outExpecter) fatalf(reason string, format string, args ...interface{}) { + e.t.Helper() + + // Ensure the message is part of the normal log stream before + // failing the test. + e.logf("%s: %s", reason, fmt.Sprintf(format, args...)) + + require.FailNowf(e.t, reason, format, args...) +} + +type PTY struct { + outExpecter + pty.PTY +} + +func (p *PTY) Close() error { + p.t.Helper() + pErr := p.PTY.Close() + eErr := p.outExpecter.close("close") + if pErr != nil { + return pErr + } + return eErr +} + +func (p *PTY) Attach(inv *clibase.Invocation) *PTY { + p.t.Helper() + + inv.Stdout = p.Output() + inv.Stderr = p.Output() + inv.Stdin = p.Input() + return p +} + func (p *PTY) Write(r rune) { p.t.Helper() @@ -355,22 +358,19 @@ func (p *PTY) WriteLine(str string) { require.NoError(p.t, err, "write line failed") } -func (e *outExpecter) logf(format string, args ...interface{}) { - e.t.Helper() - - // Match regular logger timestamp format, we seem to be logging in - // UTC in other places as well, so match here. - e.t.Logf("%s: %s: %s", time.Now().UTC().Format("2006-01-02 15:04:05.000"), e.name, fmt.Sprintf(format, args...)) +type PTYCmd struct { + outExpecter + pty.PTYCmd } -func (e *outExpecter) fatalf(reason string, format string, args ...interface{}) { - e.t.Helper() - - // Ensure the message is part of the normal log stream before - // failing the test. - e.logf("%s: %s", reason, fmt.Sprintf(format, args...)) - - require.FailNowf(e.t, reason, format, args...) +func (p *PTYCmd) Close() error { + p.t.Helper() + pErr := p.PTYCmd.Close() + eErr := p.outExpecter.close("close") + if pErr != nil { + return pErr + } + return eErr } // stdbuf is like a buffered stdout, it buffers writes until read. From c6a3229e76acf62bf10b273bd5afac9b50196a91 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Fri, 21 Apr 2023 05:21:31 +0000 Subject: [PATCH 18/21] Protect pty.Resize on windows from races Signed-off-by: Spike Curtis --- pty/pty_windows.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pty/pty_windows.go b/pty/pty_windows.go index 400b1deecd438..c595e0ee48be0 100644 --- a/pty/pty_windows.go +++ b/pty/pty_windows.go @@ -122,6 +122,12 @@ func (p *ptyWindows) InputWriter() io.Writer { } func (p *ptyWindows) Resize(height uint16, width uint16) error { + // hold the lock, so we don't race with anyone trying to close the console + p.closeMutex.Lock() + defer p.closeMutex.Unlock() + if p.closed || p.console == windows.InvalidHandle { + return pty.ErrClosed + } // Taken from: https://github.com/microsoft/hcsshim/blob/54a5ad86808d761e3e396aff3e2022840f39f9a8/internal/winapi/zsyscall_windows.go#L144 ret, _, err := procResizePseudoConsole.Call(uintptr(p.console), uintptr(*((*uint32)(unsafe.Pointer(&windows.Coord{ Y: int16(height), From aa9454695f2d03d6f79f396833a441e6bca6b07d Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Fri, 21 Apr 2023 06:14:43 +0000 Subject: [PATCH 19/21] Fix windows bugs Signed-off-by: Spike Curtis --- pty/pty_windows.go | 2 +- pty/start_windows.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pty/pty_windows.go b/pty/pty_windows.go index c595e0ee48be0..0caf91d799272 100644 --- a/pty/pty_windows.go +++ b/pty/pty_windows.go @@ -126,7 +126,7 @@ func (p *ptyWindows) Resize(height uint16, width uint16) error { p.closeMutex.Lock() defer p.closeMutex.Unlock() if p.closed || p.console == windows.InvalidHandle { - return pty.ErrClosed + return ErrClosed } // Taken from: https://github.com/microsoft/hcsshim/blob/54a5ad86808d761e3e396aff3e2022840f39f9a8/internal/winapi/zsyscall_windows.go#L144 ret, _, err := procResizePseudoConsole.Call(uintptr(p.console), uintptr(*((*uint32)(unsafe.Pointer(&windows.Coord{ diff --git a/pty/start_windows.go b/pty/start_windows.go index 8c612e35eec28..406622f2e2a72 100644 --- a/pty/start_windows.go +++ b/pty/start_windows.go @@ -54,7 +54,7 @@ func startPty(cmd *exec.Cmd, opt ...StartOption) (_ PTY, _ Process, retErr error if retErr != nil { // we hit some error finishing setup; close pty, so // we don't leak the kernel resources associated with it - _ := pty.Close() + _ = pty.Close() } }() winPty := pty.(*ptyWindows) From 0f07cb98548bf6a390d3ebb9e90720c3e79174b9 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Mon, 24 Apr 2023 06:06:06 +0000 Subject: [PATCH 20/21] PTY doesn't extend PTYCmd Signed-off-by: Spike Curtis --- pty/pty.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pty/pty.go b/pty/pty.go index e93115dc9ef45..507e9468e2084 100644 --- a/pty/pty.go +++ b/pty/pty.go @@ -33,7 +33,10 @@ type PTYCmd interface { // process retains access to _both_ ends of the pseudo-TTY (i.e. `ptm` & `pts` // on Linux). type PTY interface { - PTYCmd + io.Closer + + // Resize sets the size of the PTY. + Resize(height uint16, width uint16) error // Name of the TTY. Example on Linux would be "/dev/pts/1". Name() string From 50060d83556eb4e5207986aa72fe2b78bb695e64 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Mon, 24 Apr 2023 08:46:23 +0000 Subject: [PATCH 21/21] Fix windows types Signed-off-by: Spike Curtis --- pty/pty_windows.go | 2 +- pty/start_windows.go | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/pty/pty_windows.go b/pty/pty_windows.go index 0caf91d799272..80f6b74f436e9 100644 --- a/pty/pty_windows.go +++ b/pty/pty_windows.go @@ -22,7 +22,7 @@ var ( ) // See: https://docs.microsoft.com/en-us/windows/console/creating-a-pseudoconsole-session -func newPty(opt ...Option) (PTY, error) { +func newPty(opt ...Option) (*ptyWindows, error) { var opts ptyOptions for _, o := range opt { o(&opts) diff --git a/pty/start_windows.go b/pty/start_windows.go index 406622f2e2a72..2811900ffc361 100644 --- a/pty/start_windows.go +++ b/pty/start_windows.go @@ -17,7 +17,7 @@ import ( // Allocates a PTY and starts the specified command attached to it. // See: https://docs.microsoft.com/en-us/windows/console/creating-a-pseudoconsole-session#creating-the-hosted-process -func startPty(cmd *exec.Cmd, opt ...StartOption) (_ PTY, _ Process, retErr error) { +func startPty(cmd *exec.Cmd, opt ...StartOption) (_ PTYCmd, _ Process, retErr error) { var opts startOptions for _, o := range opt { o(&opts) @@ -46,7 +46,7 @@ func startPty(cmd *exec.Cmd, opt ...StartOption) (_ PTY, _ Process, retErr error return nil, nil, err } - pty, err := newPty(opts.ptyOpts...) + winPty, err := newPty(opts.ptyOpts...) if err != nil { return nil, nil, err } @@ -54,10 +54,9 @@ func startPty(cmd *exec.Cmd, opt ...StartOption) (_ PTY, _ Process, retErr error if retErr != nil { // we hit some error finishing setup; close pty, so // we don't leak the kernel resources associated with it - _ = pty.Close() + _ = winPty.Close() } }() - winPty := pty.(*ptyWindows) if winPty.opts.sshReq != nil { cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_TTY=%s", winPty.Name())) } @@ -130,7 +129,7 @@ func startPty(cmd *exec.Cmd, opt ...StartOption) (_ PTY, _ Process, retErr error return nil, nil, errI } go wp.waitInternal() - return pty, wp, nil + return winPty, wp, nil } // Taken from: https://github.com/microsoft/hcsshim/blob/7fbdca16f91de8792371ba22b7305bf4ca84170a/internal/exec/exec.go#L476