Skip to content

fix: Rewrite ptytest to buffer stdout #3170

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jul 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 171 additions & 48 deletions pty/ptytest/ptytest.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,116 +7,239 @@ import (
"io"
"os"
"os/exec"
"regexp"
"runtime"
"strings"
"sync"
"testing"
"time"
"unicode/utf8"

"github.com/stretchr/testify/require"
"golang.org/x/xerrors"

"github.com/coder/coder/pty"
)

var (
// Used to ensure terminal output doesn't have anything crazy!
// See: https://stackoverflow.com/a/29497680
stripAnsi = regexp.MustCompile("[\u001B\u009B][[\\]()#;?]*(?:(?:(?:[a-zA-Z\\d]*(?:;[a-zA-Z\\d]*)*)?\u0007)|(?:(?:\\d{1,4}(?:;\\d{0,4})*)?[\\dA-PRZcf-ntqry=><~]))")
)

func New(t *testing.T) *PTY {
ptty, err := pty.New()
require.NoError(t, err)

return create(t, ptty)
return create(t, ptty, "cmd")
}

func Start(t *testing.T, cmd *exec.Cmd) (*PTY, *os.Process) {
ptty, ps, err := pty.Start(cmd)
require.NoError(t, err)
return create(t, ptty), ps
return create(t, ptty, cmd.Args[0]), ps
}

func create(t *testing.T, ptty pty.PTY) *PTY {
reader, writer := io.Pipe()
scanner := bufio.NewScanner(reader)
func create(t *testing.T, ptty pty.PTY, name string) *PTY {
// Use pipe for logging.
logDone := make(chan struct{})
logr, logw := io.Pipe()
t.Cleanup(func() {
_ = reader.Close()
_ = writer.Close()
_ = logw.Close()
_ = logr.Close()
<-logDone // Guard against logging after test.
})
go func() {
for scanner.Scan() {
if scanner.Err() != nil {
return
}
t.Log(stripAnsi.ReplaceAllString(scanner.Text(), ""))
defer close(logDone)
s := bufio.NewScanner(logr)
for s.Scan() {
// Quote output to avoid terminal escape codes, e.g. bell.
t.Logf("%s: stdout: %q", name, s.Text())
}
}()

// Write to log and output buffer.
copyDone := make(chan struct{})
out := newStdbuf()
w := io.MultiWriter(logw, out)
go func() {
defer close(copyDone)
_, err := io.Copy(w, ptty.Output())
_ = out.closeErr(err)
}()
t.Cleanup(func() {
_ = out.Close
_ = ptty.Close()
<-copyDone
})

return &PTY{
t: t,
PTY: ptty,
out: out,

outputWriter: writer,
runeReader: bufio.NewReaderSize(ptty.Output(), utf8.UTFMax),
runeReader: bufio.NewReaderSize(out, utf8.UTFMax),
}
}

type PTY struct {
t *testing.T
pty.PTY
out *stdbuf

outputWriter io.Writer
runeReader *bufio.Reader
runeReader *bufio.Reader
}

func (p *PTY) ExpectMatch(str string) string {
p.t.Helper()

timeout, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

var buffer bytes.Buffer
multiWriter := io.MultiWriter(&buffer, p.outputWriter)
runeWriter := bufio.NewWriterSize(multiWriter, utf8.UTFMax)
complete, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
match := make(chan error, 1)
go func() {
timer := time.NewTimer(10 * time.Second)
defer timer.Stop()
select {
case <-complete.Done():
return
case <-timer.C:
}
_ = p.Close()
p.t.Errorf("%s match exceeded deadline: wanted %q; got %q", time.Now(), str, buffer.String())
defer close(match)
match <- func() error {
for {
r, _, err := p.runeReader.ReadRune()
if err != nil {
return err
}
_, err = buffer.WriteRune(r)
if err != nil {
return err
}
if strings.Contains(buffer.String(), str) {
return nil
}
}
}()
}()
for {
var r rune
r, _, err := p.runeReader.ReadRune()
require.NoError(p.t, err)
_, err = runeWriter.WriteRune(r)
require.NoError(p.t, err)
err = runeWriter.Flush()
require.NoError(p.t, err)
if strings.Contains(buffer.String(), str) {
break

select {
case err := <-match:
if err != nil {
p.t.Fatalf("%s: read error: %v (wanted %q; got %q)", time.Now(), err, str, buffer.String())
return ""
}
p.t.Logf("%s: matched %q = %q", time.Now(), str, buffer.String())
return buffer.String()
case <-timeout.Done():
// Ensure goroutine is cleaned up before test exit.
_ = p.out.closeErr(p.Close())
<-match

p.t.Fatalf("%s: match exceeded deadline: wanted %q; got %q", time.Now(), str, buffer.String())
return ""
}
p.t.Logf("matched %q = %q", str, stripAnsi.ReplaceAllString(buffer.String(), ""))
return buffer.String()
}

func (p *PTY) Write(r rune) {
p.t.Helper()

_, err := p.Input().Write([]byte{byte(r)})
require.NoError(p.t, err)
}

func (p *PTY) WriteLine(str string) {
p.t.Helper()

newline := []byte{'\r'}
if runtime.GOOS == "windows" {
newline = append(newline, '\n')
}
_, err := p.Input().Write(append([]byte(str), newline...))
require.NoError(p.t, err)
}

// stdbuf is like a buffered stdout, it buffers writes until read.
type stdbuf struct {
r io.Reader

mu sync.Mutex // Protects following.
b []byte
more chan struct{}
err error
}

func newStdbuf() *stdbuf {
return &stdbuf{more: make(chan struct{}, 1)}
}

func (b *stdbuf) Read(p []byte) (int, error) {
if b.r == nil {
return b.readOrWaitForMore(p)
}

n, err := b.r.Read(p)
if xerrors.Is(err, io.EOF) {
b.r = nil
err = nil
if n == 0 {
return b.readOrWaitForMore(p)
}
}
return n, err
}

func (b *stdbuf) readOrWaitForMore(p []byte) (int, error) {
b.mu.Lock()
defer b.mu.Unlock()

// Deplete channel so that more check
// is for future input into buffer.
select {
case <-b.more:
default:
}

if len(b.b) == 0 {
if b.err != nil {
return 0, b.err
}

b.mu.Unlock()
<-b.more
b.mu.Lock()
}

b.r = bytes.NewReader(b.b)
b.b = b.b[len(b.b):]

return b.r.Read(p)
}

func (b *stdbuf) Write(p []byte) (int, error) {
if len(p) == 0 {
return 0, nil
}

b.mu.Lock()
defer b.mu.Unlock()

if b.err != nil {
return 0, b.err
}

b.b = append(b.b, p...)

select {
case b.more <- struct{}{}:
default:
}

return len(p), nil
}

func (b *stdbuf) Close() error {
return b.closeErr(nil)
}

func (b *stdbuf) closeErr(err error) error {
b.mu.Lock()
defer b.mu.Unlock()
if b.err != nil {
return err
}
if err == nil {
b.err = io.EOF
} else {
b.err = err
}
close(b.more)
return err
}
37 changes: 37 additions & 0 deletions pty/ptytest/ptytest_internal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package ptytest

import (
"bytes"
"io"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestStdbuf(t *testing.T) {
t.Parallel()

var got bytes.Buffer

b := newStdbuf()
done := make(chan struct{})
go func() {
defer close(done)
_, err := io.Copy(&got, b)
assert.NoError(t, err)
}()

_, err := b.Write([]byte("hello "))
require.NoError(t, err)
_, err = b.Write([]byte("world\n"))
require.NoError(t, err)
_, err = b.Write([]byte("bye\n"))
require.NoError(t, err)

err = b.Close()
require.NoError(t, err)
<-done

assert.Equal(t, "hello world\nbye\n", got.String())
}
13 changes: 5 additions & 8 deletions pty/ptytest/ptytest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package ptytest_test

import (
"fmt"
"runtime"
"strings"
"testing"

Expand All @@ -22,26 +21,24 @@ func TestPtytest(t *testing.T) {
pty.WriteLine("read")
})

// See https://github.com/coder/coder/issues/2122 for the motivation
// behind this test.
t.Run("Cobra ptytest should not hang when output is not consumed", func(t *testing.T) {
t.Parallel()

tests := []struct {
name string
output string
isPlatformBug bool // See https://github.com/coder/coder/issues/2122 for more info.
isPlatformBug bool
}{
{name: "1024 is safe (does not exceed macOS buffer)", output: strings.Repeat(".", 1024)},
{name: "1025 exceeds macOS buffer (must not hang)", output: strings.Repeat(".", 1025), isPlatformBug: true},
{name: "10241 large output", output: strings.Repeat(".", 10241), isPlatformBug: true}, // 1024 * 10 + 1
{name: "1025 exceeds macOS buffer (must not hang)", output: strings.Repeat(".", 1025)},
{name: "10241 large output", output: strings.Repeat(".", 10241)}, // 1024 * 10 + 1
}
for _, tt := range tests {
tt := tt
// nolint:paralleltest // Avoid parallel test to more easily identify the issue.
t.Run(tt.name, func(t *testing.T) {
if tt.isPlatformBug && (runtime.GOOS == "darwin" || runtime.GOOS == "windows") {
t.Skip("This test hangs on macOS and Windows, see https://github.com/coder/coder/issues/2122")
}

cmd := cobra.Command{
Use: "test",
RunE: func(cmd *cobra.Command, args []string) error {
Expand Down