diff --git a/pty/ptytest/ptytest.go b/pty/ptytest/ptytest.go index 178998b5a21f9..43fbaec1109e2 100644 --- a/pty/ptytest/ptytest.go +++ b/pty/ptytest/ptytest.go @@ -7,112 +7,137 @@ import ( "io" "os" "os/exec" - "regexp" "runtime" "strings" + "sync" "testing" "time" "unicode/utf8" "github.com/stretchr/testify/require" + "golang.org/x/xerrors" "github.com/coder/coder/pty" ) -var ( - // Used to ensure terminal output doesn't have anything crazy! - // See: https://stackoverflow.com/a/29497680 - stripAnsi = regexp.MustCompile("[\u001B\u009B][[\\]()#;?]*(?:(?:(?:[a-zA-Z\\d]*(?:;[a-zA-Z\\d]*)*)?\u0007)|(?:(?:\\d{1,4}(?:;\\d{0,4})*)?[\\dA-PRZcf-ntqry=><~]))") -) - func New(t *testing.T) *PTY { ptty, err := pty.New() require.NoError(t, err) - return create(t, ptty) + return create(t, ptty, "cmd") } func Start(t *testing.T, cmd *exec.Cmd) (*PTY, *os.Process) { ptty, ps, err := pty.Start(cmd) require.NoError(t, err) - return create(t, ptty), ps + return create(t, ptty, cmd.Args[0]), ps } -func create(t *testing.T, ptty pty.PTY) *PTY { - reader, writer := io.Pipe() - scanner := bufio.NewScanner(reader) +func create(t *testing.T, ptty pty.PTY, name string) *PTY { + // Use pipe for logging. + logDone := make(chan struct{}) + logr, logw := io.Pipe() t.Cleanup(func() { - _ = reader.Close() - _ = writer.Close() + _ = logw.Close() + _ = logr.Close() + <-logDone // Guard against logging after test. }) go func() { - for scanner.Scan() { - if scanner.Err() != nil { - return - } - t.Log(stripAnsi.ReplaceAllString(scanner.Text(), "")) + defer close(logDone) + s := bufio.NewScanner(logr) + for s.Scan() { + // Quote output to avoid terminal escape codes, e.g. bell. + t.Logf("%s: stdout: %q", name, s.Text()) } }() + // Write to log and output buffer. + copyDone := make(chan struct{}) + out := newStdbuf() + w := io.MultiWriter(logw, out) + go func() { + defer close(copyDone) + _, err := io.Copy(w, ptty.Output()) + _ = out.closeErr(err) + }() t.Cleanup(func() { + _ = out.Close _ = ptty.Close() + <-copyDone }) + return &PTY{ t: t, PTY: ptty, + out: out, - outputWriter: writer, - runeReader: bufio.NewReaderSize(ptty.Output(), utf8.UTFMax), + runeReader: bufio.NewReaderSize(out, utf8.UTFMax), } } type PTY struct { t *testing.T pty.PTY + out *stdbuf - outputWriter io.Writer - runeReader *bufio.Reader + runeReader *bufio.Reader } func (p *PTY) ExpectMatch(str string) string { + p.t.Helper() + + timeout, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + var buffer bytes.Buffer - multiWriter := io.MultiWriter(&buffer, p.outputWriter) - runeWriter := bufio.NewWriterSize(multiWriter, utf8.UTFMax) - complete, cancelFunc := context.WithCancel(context.Background()) - defer cancelFunc() + match := make(chan error, 1) go func() { - timer := time.NewTimer(10 * time.Second) - defer timer.Stop() - select { - case <-complete.Done(): - return - case <-timer.C: - } - _ = p.Close() - p.t.Errorf("%s match exceeded deadline: wanted %q; got %q", time.Now(), str, buffer.String()) + defer close(match) + match <- func() error { + for { + r, _, err := p.runeReader.ReadRune() + if err != nil { + return err + } + _, err = buffer.WriteRune(r) + if err != nil { + return err + } + if strings.Contains(buffer.String(), str) { + return nil + } + } + }() }() - for { - var r rune - r, _, err := p.runeReader.ReadRune() - require.NoError(p.t, err) - _, err = runeWriter.WriteRune(r) - require.NoError(p.t, err) - err = runeWriter.Flush() - require.NoError(p.t, err) - if strings.Contains(buffer.String(), str) { - break + + select { + case err := <-match: + if err != nil { + p.t.Fatalf("%s: read error: %v (wanted %q; got %q)", time.Now(), err, str, buffer.String()) + return "" } + p.t.Logf("%s: matched %q = %q", time.Now(), str, buffer.String()) + return buffer.String() + case <-timeout.Done(): + // Ensure goroutine is cleaned up before test exit. + _ = p.out.closeErr(p.Close()) + <-match + + p.t.Fatalf("%s: match exceeded deadline: wanted %q; got %q", time.Now(), str, buffer.String()) + return "" } - p.t.Logf("matched %q = %q", str, stripAnsi.ReplaceAllString(buffer.String(), "")) - return buffer.String() } func (p *PTY) Write(r rune) { + p.t.Helper() + _, err := p.Input().Write([]byte{byte(r)}) require.NoError(p.t, err) } func (p *PTY) WriteLine(str string) { + p.t.Helper() + newline := []byte{'\r'} if runtime.GOOS == "windows" { newline = append(newline, '\n') @@ -120,3 +145,101 @@ func (p *PTY) WriteLine(str string) { _, err := p.Input().Write(append([]byte(str), newline...)) require.NoError(p.t, err) } + +// stdbuf is like a buffered stdout, it buffers writes until read. +type stdbuf struct { + r io.Reader + + mu sync.Mutex // Protects following. + b []byte + more chan struct{} + err error +} + +func newStdbuf() *stdbuf { + return &stdbuf{more: make(chan struct{}, 1)} +} + +func (b *stdbuf) Read(p []byte) (int, error) { + if b.r == nil { + return b.readOrWaitForMore(p) + } + + n, err := b.r.Read(p) + if xerrors.Is(err, io.EOF) { + b.r = nil + err = nil + if n == 0 { + return b.readOrWaitForMore(p) + } + } + return n, err +} + +func (b *stdbuf) readOrWaitForMore(p []byte) (int, error) { + b.mu.Lock() + defer b.mu.Unlock() + + // Deplete channel so that more check + // is for future input into buffer. + select { + case <-b.more: + default: + } + + if len(b.b) == 0 { + if b.err != nil { + return 0, b.err + } + + b.mu.Unlock() + <-b.more + b.mu.Lock() + } + + b.r = bytes.NewReader(b.b) + b.b = b.b[len(b.b):] + + return b.r.Read(p) +} + +func (b *stdbuf) Write(p []byte) (int, error) { + if len(p) == 0 { + return 0, nil + } + + b.mu.Lock() + defer b.mu.Unlock() + + if b.err != nil { + return 0, b.err + } + + b.b = append(b.b, p...) + + select { + case b.more <- struct{}{}: + default: + } + + return len(p), nil +} + +func (b *stdbuf) Close() error { + return b.closeErr(nil) +} + +func (b *stdbuf) closeErr(err error) error { + b.mu.Lock() + defer b.mu.Unlock() + if b.err != nil { + return err + } + if err == nil { + b.err = io.EOF + } else { + b.err = err + } + close(b.more) + return err +} diff --git a/pty/ptytest/ptytest_internal_test.go b/pty/ptytest/ptytest_internal_test.go new file mode 100644 index 0000000000000..29154178636f6 --- /dev/null +++ b/pty/ptytest/ptytest_internal_test.go @@ -0,0 +1,37 @@ +package ptytest + +import ( + "bytes" + "io" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestStdbuf(t *testing.T) { + t.Parallel() + + var got bytes.Buffer + + b := newStdbuf() + done := make(chan struct{}) + go func() { + defer close(done) + _, err := io.Copy(&got, b) + assert.NoError(t, err) + }() + + _, err := b.Write([]byte("hello ")) + require.NoError(t, err) + _, err = b.Write([]byte("world\n")) + require.NoError(t, err) + _, err = b.Write([]byte("bye\n")) + require.NoError(t, err) + + err = b.Close() + require.NoError(t, err) + <-done + + assert.Equal(t, "hello world\nbye\n", got.String()) +} diff --git a/pty/ptytest/ptytest_test.go b/pty/ptytest/ptytest_test.go index 7dfba01f04478..764ede12aec2c 100644 --- a/pty/ptytest/ptytest_test.go +++ b/pty/ptytest/ptytest_test.go @@ -2,7 +2,6 @@ package ptytest_test import ( "fmt" - "runtime" "strings" "testing" @@ -22,26 +21,24 @@ func TestPtytest(t *testing.T) { pty.WriteLine("read") }) + // See https://github.com/coder/coder/issues/2122 for the motivation + // behind this test. t.Run("Cobra ptytest should not hang when output is not consumed", func(t *testing.T) { t.Parallel() tests := []struct { name string output string - isPlatformBug bool // See https://github.com/coder/coder/issues/2122 for more info. + isPlatformBug bool }{ {name: "1024 is safe (does not exceed macOS buffer)", output: strings.Repeat(".", 1024)}, - {name: "1025 exceeds macOS buffer (must not hang)", output: strings.Repeat(".", 1025), isPlatformBug: true}, - {name: "10241 large output", output: strings.Repeat(".", 10241), isPlatformBug: true}, // 1024 * 10 + 1 + {name: "1025 exceeds macOS buffer (must not hang)", output: strings.Repeat(".", 1025)}, + {name: "10241 large output", output: strings.Repeat(".", 10241)}, // 1024 * 10 + 1 } for _, tt := range tests { tt := tt // nolint:paralleltest // Avoid parallel test to more easily identify the issue. t.Run(tt.name, func(t *testing.T) { - if tt.isPlatformBug && (runtime.GOOS == "darwin" || runtime.GOOS == "windows") { - t.Skip("This test hangs on macOS and Windows, see https://github.com/coder/coder/issues/2122") - } - cmd := cobra.Command{ Use: "test", RunE: func(cmd *cobra.Command, args []string) error {