From ffda8cde08e668ac5a5caf9b84ee2091b32846bd Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Mon, 25 Jul 2022 16:40:35 +0300 Subject: [PATCH 1/7] fix: Rewrite ptytest to buffer stdout Fixes #2122 --- pty/ptytest/ptytest.go | 209 ++++++++++++++++++++++----- pty/ptytest/ptytest_internal_test.go | 31 ++++ pty/ptytest/ptytest_test.go | 13 +- 3 files changed, 206 insertions(+), 47 deletions(-) create mode 100644 pty/ptytest/ptytest_internal_test.go diff --git a/pty/ptytest/ptytest.go b/pty/ptytest/ptytest.go index 178998b5a21f9..6cc419d2396de 100644 --- a/pty/ptytest/ptytest.go +++ b/pty/ptytest/ptytest.go @@ -7,80 +7,91 @@ import ( "io" "os" "os/exec" - "regexp" "runtime" "strings" + "sync" "testing" "time" "unicode/utf8" "github.com/stretchr/testify/require" + "golang.org/x/xerrors" "github.com/coder/coder/pty" ) -var ( - // Used to ensure terminal output doesn't have anything crazy! - // See: https://stackoverflow.com/a/29497680 - stripAnsi = regexp.MustCompile("[\u001B\u009B][[\\]()#;?]*(?:(?:(?:[a-zA-Z\\d]*(?:;[a-zA-Z\\d]*)*)?\u0007)|(?:(?:\\d{1,4}(?:;\\d{0,4})*)?[\\dA-PRZcf-ntqry=><~]))") -) - func New(t *testing.T) *PTY { ptty, err := pty.New() require.NoError(t, err) - return create(t, ptty) + return create(t, ptty, "cmd") } func Start(t *testing.T, cmd *exec.Cmd) (*PTY, *os.Process) { ptty, ps, err := pty.Start(cmd) require.NoError(t, err) - return create(t, ptty), ps + return create(t, ptty, cmd.Args[0]), ps } -func create(t *testing.T, ptty pty.PTY) *PTY { - reader, writer := io.Pipe() - scanner := bufio.NewScanner(reader) +func create(t *testing.T, ptty pty.PTY, name string) *PTY { + // Use pipe for logging. + logDone := make(chan struct{}) + logr, logw := io.Pipe() t.Cleanup(func() { - _ = reader.Close() - _ = writer.Close() + _ = logw.Close() + _ = logr.Close() + <-logDone // Guard against logging after test. }) go func() { - for scanner.Scan() { - if scanner.Err() != nil { - return - } - t.Log(stripAnsi.ReplaceAllString(scanner.Text(), "")) + defer close(logDone) + s := bufio.NewScanner(logr) + for s.Scan() { + // Quote output to avoid terminal escape codes, e.g. bell. + t.Logf("%s: stdout: %q", name, s.Text()) } }() + // Write to log and output buffer. + copyDone := make(chan struct{}) + out := newStdbuf() + w := io.MultiWriter(logw, out) + go func() { + defer close(copyDone) + _, err := io.Copy(w, ptty.Output()) + _ = out.closeErr(err) + }() t.Cleanup(func() { + _ = out.Close _ = ptty.Close() + <-copyDone }) + return &PTY{ t: t, PTY: ptty, + out: out, - outputWriter: writer, - runeReader: bufio.NewReaderSize(ptty.Output(), utf8.UTFMax), + runeReader: bufio.NewReaderSize(out, utf8.UTFMax), } } type PTY struct { t *testing.T pty.PTY + out *stdbuf - outputWriter io.Writer - runeReader *bufio.Reader + runeReader *bufio.Reader } func (p *PTY) ExpectMatch(str string) string { - var buffer bytes.Buffer - multiWriter := io.MultiWriter(&buffer, p.outputWriter) - runeWriter := bufio.NewWriterSize(multiWriter, utf8.UTFMax) + p.t.Helper() + complete, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() + + timeout := make(chan error, 1) go func() { + defer close(timeout) timer := time.NewTimer(10 * time.Second) defer timer.Stop() select { @@ -88,31 +99,54 @@ func (p *PTY) ExpectMatch(str string) string { return case <-timer.C: } - _ = p.Close() - p.t.Errorf("%s match exceeded deadline: wanted %q; got %q", time.Now(), str, buffer.String()) + timeout <- xerrors.Errorf("%s match exceeded deadline", time.Now()) }() - for { - var r rune - r, _, err := p.runeReader.ReadRune() - require.NoError(p.t, err) - _, err = runeWriter.WriteRune(r) - require.NoError(p.t, err) - err = runeWriter.Flush() - require.NoError(p.t, err) - if strings.Contains(buffer.String(), str) { - break + + var buffer bytes.Buffer + match := make(chan error, 1) + go func() { + defer close(match) + for { + r, _, err := p.runeReader.ReadRune() + if err != nil { + match <- err + return + } + _, err = buffer.WriteRune(r) + if err != nil { + match <- err + return + } + if strings.Contains(buffer.String(), str) { + match <- nil + return + } + } + }() + + select { + case err := <-match: + if err != nil { + p.t.Fatalf("read error: %v (wanted %q; got %q)", err, str, buffer.String()) } + p.t.Logf("matched %q = %q", str, buffer.String()) + case err := <-timeout: + _ = p.out.closeErr(p.Close()) + p.t.Fatalf("%s: wanted %q; got %q", err, str, buffer.String()) } - p.t.Logf("matched %q = %q", str, stripAnsi.ReplaceAllString(buffer.String(), "")) return buffer.String() } func (p *PTY) Write(r rune) { + p.t.Helper() + _, err := p.Input().Write([]byte{byte(r)}) require.NoError(p.t, err) } func (p *PTY) WriteLine(str string) { + p.t.Helper() + newline := []byte{'\r'} if runtime.GOOS == "windows" { newline = append(newline, '\n') @@ -120,3 +154,100 @@ func (p *PTY) WriteLine(str string) { _, err := p.Input().Write(append([]byte(str), newline...)) require.NoError(p.t, err) } + +// stdbuf is like a buffered stdout, it buffers writes until read. +type stdbuf struct { + r io.Reader + + mu sync.Mutex // Protects following. + b []byte + more chan struct{} + err error +} + +func newStdbuf() *stdbuf { + return &stdbuf{more: make(chan struct{}, 1)} +} + +func (b *stdbuf) Read(p []byte) (int, error) { + if b.r == nil { + return b.read(p) + } + + n, err := b.r.Read(p) + if xerrors.Is(err, io.EOF) { + b.r = nil + err = nil + if n == 0 { + return b.read(p) + } + } + return n, err +} + +func (b *stdbuf) read(p []byte) (int, error) { + b.mu.Lock() + defer b.mu.Unlock() + + // Deplete channel so that more check + // is for future input into buffer. + select { + case <-b.more: + default: + } + + if len(b.b) == 0 { + if b.err != nil { + return 0, b.err + } + + b.mu.Unlock() + <-b.more + b.mu.Lock() + } + + b.r = bytes.NewReader(b.b) + b.b = b.b[len(b.b):] + + return b.r.Read(p) +} + +func (b *stdbuf) Write(p []byte) (int, error) { + if len(p) == 0 { + return 0, nil + } + + b.mu.Lock() + defer b.mu.Unlock() + + if b.err != nil { + return 0, b.err + } + + b.b = append(b.b, p...) + + select { + case b.more <- struct{}{}: + default: + } + + return len(p), nil +} + +func (b *stdbuf) Close() error { + return b.closeErr(nil) +} + +func (b *stdbuf) closeErr(err error) error { + b.mu.Lock() + defer b.mu.Unlock() + if b.err != nil { + return err + } + if err == nil { + err = io.EOF + } + b.err = err + close(b.more) + return b.err +} diff --git a/pty/ptytest/ptytest_internal_test.go b/pty/ptytest/ptytest_internal_test.go new file mode 100644 index 0000000000000..0d8657732e802 --- /dev/null +++ b/pty/ptytest/ptytest_internal_test.go @@ -0,0 +1,31 @@ +package ptytest + +import ( + "bytes" + "io" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestStdbuf(t *testing.T) { + t.Parallel() + + var got bytes.Buffer + + b := newStdbuf() + done := make(chan struct{}) + go func() { + defer close(done) + io.Copy(&got, b) + }() + + b.Write([]byte("hello ")) + b.Write([]byte("world\n")) + b.Write([]byte("bye\n")) + + b.Close() + <-done + + assert.Equal(t, "hello world\nbye\n", got.String()) +} diff --git a/pty/ptytest/ptytest_test.go b/pty/ptytest/ptytest_test.go index 7dfba01f04478..764ede12aec2c 100644 --- a/pty/ptytest/ptytest_test.go +++ b/pty/ptytest/ptytest_test.go @@ -2,7 +2,6 @@ package ptytest_test import ( "fmt" - "runtime" "strings" "testing" @@ -22,26 +21,24 @@ func TestPtytest(t *testing.T) { pty.WriteLine("read") }) + // See https://github.com/coder/coder/issues/2122 for the motivation + // behind this test. t.Run("Cobra ptytest should not hang when output is not consumed", func(t *testing.T) { t.Parallel() tests := []struct { name string output string - isPlatformBug bool // See https://github.com/coder/coder/issues/2122 for more info. + isPlatformBug bool }{ {name: "1024 is safe (does not exceed macOS buffer)", output: strings.Repeat(".", 1024)}, - {name: "1025 exceeds macOS buffer (must not hang)", output: strings.Repeat(".", 1025), isPlatformBug: true}, - {name: "10241 large output", output: strings.Repeat(".", 10241), isPlatformBug: true}, // 1024 * 10 + 1 + {name: "1025 exceeds macOS buffer (must not hang)", output: strings.Repeat(".", 1025)}, + {name: "10241 large output", output: strings.Repeat(".", 10241)}, // 1024 * 10 + 1 } for _, tt := range tests { tt := tt // nolint:paralleltest // Avoid parallel test to more easily identify the issue. t.Run(tt.name, func(t *testing.T) { - if tt.isPlatformBug && (runtime.GOOS == "darwin" || runtime.GOOS == "windows") { - t.Skip("This test hangs on macOS and Windows, see https://github.com/coder/coder/issues/2122") - } - cmd := cobra.Command{ Use: "test", RunE: func(cmd *cobra.Command, args []string) error { From c1112a59b91507b6040c008b30dfeedf37a1216f Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Mon, 25 Jul 2022 16:54:27 +0300 Subject: [PATCH 2/7] Lint fixes --- pty/ptytest/ptytest.go | 6 +++--- pty/ptytest/ptytest_internal_test.go | 16 +++++++++++----- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/pty/ptytest/ptytest.go b/pty/ptytest/ptytest.go index 6cc419d2396de..2413591c02c41 100644 --- a/pty/ptytest/ptytest.go +++ b/pty/ptytest/ptytest.go @@ -171,7 +171,7 @@ func newStdbuf() *stdbuf { func (b *stdbuf) Read(p []byte) (int, error) { if b.r == nil { - return b.read(p) + return b.readOrWaitForMore(p) } n, err := b.r.Read(p) @@ -179,13 +179,13 @@ func (b *stdbuf) Read(p []byte) (int, error) { b.r = nil err = nil if n == 0 { - return b.read(p) + return b.readOrWaitForMore(p) } } return n, err } -func (b *stdbuf) read(p []byte) (int, error) { +func (b *stdbuf) readOrWaitForMore(p []byte) (int, error) { b.mu.Lock() defer b.mu.Unlock() diff --git a/pty/ptytest/ptytest_internal_test.go b/pty/ptytest/ptytest_internal_test.go index 0d8657732e802..29154178636f6 100644 --- a/pty/ptytest/ptytest_internal_test.go +++ b/pty/ptytest/ptytest_internal_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestStdbuf(t *testing.T) { @@ -17,14 +18,19 @@ func TestStdbuf(t *testing.T) { done := make(chan struct{}) go func() { defer close(done) - io.Copy(&got, b) + _, err := io.Copy(&got, b) + assert.NoError(t, err) }() - b.Write([]byte("hello ")) - b.Write([]byte("world\n")) - b.Write([]byte("bye\n")) + _, err := b.Write([]byte("hello ")) + require.NoError(t, err) + _, err = b.Write([]byte("world\n")) + require.NoError(t, err) + _, err = b.Write([]byte("bye\n")) + require.NoError(t, err) - b.Close() + err = b.Close() + require.NoError(t, err) <-done assert.Equal(t, "hello world\nbye\n", got.String()) From f974351cd572a250d7647d736b540cda5c545947 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Mon, 25 Jul 2022 17:06:07 +0300 Subject: [PATCH 3/7] Fix eof --- pty/ptytest/ptytest.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pty/ptytest/ptytest.go b/pty/ptytest/ptytest.go index 2413591c02c41..ff4451495b1ca 100644 --- a/pty/ptytest/ptytest.go +++ b/pty/ptytest/ptytest.go @@ -245,9 +245,10 @@ func (b *stdbuf) closeErr(err error) error { return err } if err == nil { - err = io.EOF + b.err = io.EOF + } else { + b.err = err } - b.err = err close(b.more) - return b.err + return err } From 7b331cd3c4cea992ac811060088e50bd77d8d89f Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Mon, 25 Jul 2022 19:00:40 +0300 Subject: [PATCH 4/7] chore: Simplify timeout in ExpectMatch --- pty/ptytest/ptytest.go | 29 +++++++++-------------------- 1 file changed, 9 insertions(+), 20 deletions(-) diff --git a/pty/ptytest/ptytest.go b/pty/ptytest/ptytest.go index ff4451495b1ca..2c5a8c3f24b64 100644 --- a/pty/ptytest/ptytest.go +++ b/pty/ptytest/ptytest.go @@ -86,21 +86,8 @@ type PTY struct { func (p *PTY) ExpectMatch(str string) string { p.t.Helper() - complete, cancelFunc := context.WithCancel(context.Background()) - defer cancelFunc() - - timeout := make(chan error, 1) - go func() { - defer close(timeout) - timer := time.NewTimer(10 * time.Second) - defer timer.Stop() - select { - case <-complete.Done(): - return - case <-timer.C: - } - timeout <- xerrors.Errorf("%s match exceeded deadline", time.Now()) - }() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() var buffer bytes.Buffer match := make(chan error, 1) @@ -127,14 +114,16 @@ func (p *PTY) ExpectMatch(str string) string { select { case err := <-match: if err != nil { - p.t.Fatalf("read error: %v (wanted %q; got %q)", err, str, buffer.String()) + p.t.Fatalf("%s: read error: %v (wanted %q; got %q)", time.Now(), err, str, buffer.String()) + return "" } - p.t.Logf("matched %q = %q", str, buffer.String()) - case err := <-timeout: + p.t.Logf("%s: matched %q = %q", time.Now(), str, buffer.String()) + return buffer.String() + case <-ctx.Done(): _ = p.out.closeErr(p.Close()) - p.t.Fatalf("%s: wanted %q; got %q", err, str, buffer.String()) + p.t.Fatalf("%s: match exceeded deadline: wanted %q; got %q", time.Now(), str, buffer.String()) + return "" } - return buffer.String() } func (p *PTY) Write(r rune) { From d747356c2664894341bf7156dc1f16e08ba8d7f3 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Mon, 25 Jul 2022 19:26:22 +0300 Subject: [PATCH 5/7] chore: Rename ctx to timeout --- pty/ptytest/ptytest.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pty/ptytest/ptytest.go b/pty/ptytest/ptytest.go index 2c5a8c3f24b64..1cc29eb68941f 100644 --- a/pty/ptytest/ptytest.go +++ b/pty/ptytest/ptytest.go @@ -86,7 +86,7 @@ type PTY struct { func (p *PTY) ExpectMatch(str string) string { p.t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + timeout, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() var buffer bytes.Buffer @@ -119,7 +119,7 @@ func (p *PTY) ExpectMatch(str string) string { } p.t.Logf("%s: matched %q = %q", time.Now(), str, buffer.String()) return buffer.String() - case <-ctx.Done(): + case <-timeout.Done(): _ = p.out.closeErr(p.Close()) p.t.Fatalf("%s: match exceeded deadline: wanted %q; got %q", time.Now(), str, buffer.String()) return "" From 372f60cb09c0e6e743df59fe6c1523d78f3a3162 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Mon, 25 Jul 2022 19:28:11 +0300 Subject: [PATCH 6/7] fix: Ensure goroutine cleanup --- pty/ptytest/ptytest.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pty/ptytest/ptytest.go b/pty/ptytest/ptytest.go index 1cc29eb68941f..10fb705068535 100644 --- a/pty/ptytest/ptytest.go +++ b/pty/ptytest/ptytest.go @@ -120,7 +120,10 @@ func (p *PTY) ExpectMatch(str string) string { p.t.Logf("%s: matched %q = %q", time.Now(), str, buffer.String()) return buffer.String() case <-timeout.Done(): + // Ensure goroutine is cleaned up before test exit. _ = p.out.closeErr(p.Close()) + <-match + p.t.Fatalf("%s: match exceeded deadline: wanted %q; got %q", time.Now(), str, buffer.String()) return "" } From 3cd1bc85e5cbfe8cdb12d487d93b4a98df493ab0 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Mon, 25 Jul 2022 19:37:22 +0300 Subject: [PATCH 7/7] chore: Cleanup goroutine via inline function --- pty/ptytest/ptytest.go | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/pty/ptytest/ptytest.go b/pty/ptytest/ptytest.go index 10fb705068535..43fbaec1109e2 100644 --- a/pty/ptytest/ptytest.go +++ b/pty/ptytest/ptytest.go @@ -93,22 +93,21 @@ func (p *PTY) ExpectMatch(str string) string { match := make(chan error, 1) go func() { defer close(match) - for { - r, _, err := p.runeReader.ReadRune() - if err != nil { - match <- err - return + match <- func() error { + for { + r, _, err := p.runeReader.ReadRune() + if err != nil { + return err + } + _, err = buffer.WriteRune(r) + if err != nil { + return err + } + if strings.Contains(buffer.String(), str) { + return nil + } } - _, err = buffer.WriteRune(r) - if err != nil { - match <- err - return - } - if strings.Contains(buffer.String(), str) { - match <- nil - return - } - } + }() }() select {