Skip to content

Commit 92a95fb

Browse files
authored
fix: Rewrite ptytest to buffer stdout (#3170)
Fixes #2122
1 parent d7dee2c commit 92a95fb

File tree

3 files changed

+213
-56
lines changed

3 files changed

+213
-56
lines changed

pty/ptytest/ptytest.go

+171-48
Original file line numberDiff line numberDiff line change
@@ -7,116 +7,239 @@ import (
77
"io"
88
"os"
99
"os/exec"
10-
"regexp"
1110
"runtime"
1211
"strings"
12+
"sync"
1313
"testing"
1414
"time"
1515
"unicode/utf8"
1616

1717
"github.com/stretchr/testify/require"
18+
"golang.org/x/xerrors"
1819

1920
"github.com/coder/coder/pty"
2021
)
2122

22-
var (
23-
// Used to ensure terminal output doesn't have anything crazy!
24-
// See: https://stackoverflow.com/a/29497680
25-
stripAnsi = regexp.MustCompile("[\u001B\u009B][[\\]()#;?]*(?:(?:(?:[a-zA-Z\\d]*(?:;[a-zA-Z\\d]*)*)?\u0007)|(?:(?:\\d{1,4}(?:;\\d{0,4})*)?[\\dA-PRZcf-ntqry=><~]))")
26-
)
27-
2823
func New(t *testing.T) *PTY {
2924
ptty, err := pty.New()
3025
require.NoError(t, err)
3126

32-
return create(t, ptty)
27+
return create(t, ptty, "cmd")
3328
}
3429

3530
func Start(t *testing.T, cmd *exec.Cmd) (*PTY, *os.Process) {
3631
ptty, ps, err := pty.Start(cmd)
3732
require.NoError(t, err)
38-
return create(t, ptty), ps
33+
return create(t, ptty, cmd.Args[0]), ps
3934
}
4035

41-
func create(t *testing.T, ptty pty.PTY) *PTY {
42-
reader, writer := io.Pipe()
43-
scanner := bufio.NewScanner(reader)
36+
func create(t *testing.T, ptty pty.PTY, name string) *PTY {
37+
// Use pipe for logging.
38+
logDone := make(chan struct{})
39+
logr, logw := io.Pipe()
4440
t.Cleanup(func() {
45-
_ = reader.Close()
46-
_ = writer.Close()
41+
_ = logw.Close()
42+
_ = logr.Close()
43+
<-logDone // Guard against logging after test.
4744
})
4845
go func() {
49-
for scanner.Scan() {
50-
if scanner.Err() != nil {
51-
return
52-
}
53-
t.Log(stripAnsi.ReplaceAllString(scanner.Text(), ""))
46+
defer close(logDone)
47+
s := bufio.NewScanner(logr)
48+
for s.Scan() {
49+
// Quote output to avoid terminal escape codes, e.g. bell.
50+
t.Logf("%s: stdout: %q", name, s.Text())
5451
}
5552
}()
5653

54+
// Write to log and output buffer.
55+
copyDone := make(chan struct{})
56+
out := newStdbuf()
57+
w := io.MultiWriter(logw, out)
58+
go func() {
59+
defer close(copyDone)
60+
_, err := io.Copy(w, ptty.Output())
61+
_ = out.closeErr(err)
62+
}()
5763
t.Cleanup(func() {
64+
_ = out.Close
5865
_ = ptty.Close()
66+
<-copyDone
5967
})
68+
6069
return &PTY{
6170
t: t,
6271
PTY: ptty,
72+
out: out,
6373

64-
outputWriter: writer,
65-
runeReader: bufio.NewReaderSize(ptty.Output(), utf8.UTFMax),
74+
runeReader: bufio.NewReaderSize(out, utf8.UTFMax),
6675
}
6776
}
6877

6978
type PTY struct {
7079
t *testing.T
7180
pty.PTY
81+
out *stdbuf
7282

73-
outputWriter io.Writer
74-
runeReader *bufio.Reader
83+
runeReader *bufio.Reader
7584
}
7685

7786
func (p *PTY) ExpectMatch(str string) string {
87+
p.t.Helper()
88+
89+
timeout, cancel := context.WithTimeout(context.Background(), 10*time.Second)
90+
defer cancel()
91+
7892
var buffer bytes.Buffer
79-
multiWriter := io.MultiWriter(&buffer, p.outputWriter)
80-
runeWriter := bufio.NewWriterSize(multiWriter, utf8.UTFMax)
81-
complete, cancelFunc := context.WithCancel(context.Background())
82-
defer cancelFunc()
93+
match := make(chan error, 1)
8394
go func() {
84-
timer := time.NewTimer(10 * time.Second)
85-
defer timer.Stop()
86-
select {
87-
case <-complete.Done():
88-
return
89-
case <-timer.C:
90-
}
91-
_ = p.Close()
92-
p.t.Errorf("%s match exceeded deadline: wanted %q; got %q", time.Now(), str, buffer.String())
95+
defer close(match)
96+
match <- func() error {
97+
for {
98+
r, _, err := p.runeReader.ReadRune()
99+
if err != nil {
100+
return err
101+
}
102+
_, err = buffer.WriteRune(r)
103+
if err != nil {
104+
return err
105+
}
106+
if strings.Contains(buffer.String(), str) {
107+
return nil
108+
}
109+
}
110+
}()
93111
}()
94-
for {
95-
var r rune
96-
r, _, err := p.runeReader.ReadRune()
97-
require.NoError(p.t, err)
98-
_, err = runeWriter.WriteRune(r)
99-
require.NoError(p.t, err)
100-
err = runeWriter.Flush()
101-
require.NoError(p.t, err)
102-
if strings.Contains(buffer.String(), str) {
103-
break
112+
113+
select {
114+
case err := <-match:
115+
if err != nil {
116+
p.t.Fatalf("%s: read error: %v (wanted %q; got %q)", time.Now(), err, str, buffer.String())
117+
return ""
104118
}
119+
p.t.Logf("%s: matched %q = %q", time.Now(), str, buffer.String())
120+
return buffer.String()
121+
case <-timeout.Done():
122+
// Ensure goroutine is cleaned up before test exit.
123+
_ = p.out.closeErr(p.Close())
124+
<-match
125+
126+
p.t.Fatalf("%s: match exceeded deadline: wanted %q; got %q", time.Now(), str, buffer.String())
127+
return ""
105128
}
106-
p.t.Logf("matched %q = %q", str, stripAnsi.ReplaceAllString(buffer.String(), ""))
107-
return buffer.String()
108129
}
109130

110131
func (p *PTY) Write(r rune) {
132+
p.t.Helper()
133+
111134
_, err := p.Input().Write([]byte{byte(r)})
112135
require.NoError(p.t, err)
113136
}
114137

115138
func (p *PTY) WriteLine(str string) {
139+
p.t.Helper()
140+
116141
newline := []byte{'\r'}
117142
if runtime.GOOS == "windows" {
118143
newline = append(newline, '\n')
119144
}
120145
_, err := p.Input().Write(append([]byte(str), newline...))
121146
require.NoError(p.t, err)
122147
}
148+
149+
// stdbuf is like a buffered stdout, it buffers writes until read.
150+
type stdbuf struct {
151+
r io.Reader
152+
153+
mu sync.Mutex // Protects following.
154+
b []byte
155+
more chan struct{}
156+
err error
157+
}
158+
159+
func newStdbuf() *stdbuf {
160+
return &stdbuf{more: make(chan struct{}, 1)}
161+
}
162+
163+
func (b *stdbuf) Read(p []byte) (int, error) {
164+
if b.r == nil {
165+
return b.readOrWaitForMore(p)
166+
}
167+
168+
n, err := b.r.Read(p)
169+
if xerrors.Is(err, io.EOF) {
170+
b.r = nil
171+
err = nil
172+
if n == 0 {
173+
return b.readOrWaitForMore(p)
174+
}
175+
}
176+
return n, err
177+
}
178+
179+
func (b *stdbuf) readOrWaitForMore(p []byte) (int, error) {
180+
b.mu.Lock()
181+
defer b.mu.Unlock()
182+
183+
// Deplete channel so that more check
184+
// is for future input into buffer.
185+
select {
186+
case <-b.more:
187+
default:
188+
}
189+
190+
if len(b.b) == 0 {
191+
if b.err != nil {
192+
return 0, b.err
193+
}
194+
195+
b.mu.Unlock()
196+
<-b.more
197+
b.mu.Lock()
198+
}
199+
200+
b.r = bytes.NewReader(b.b)
201+
b.b = b.b[len(b.b):]
202+
203+
return b.r.Read(p)
204+
}
205+
206+
func (b *stdbuf) Write(p []byte) (int, error) {
207+
if len(p) == 0 {
208+
return 0, nil
209+
}
210+
211+
b.mu.Lock()
212+
defer b.mu.Unlock()
213+
214+
if b.err != nil {
215+
return 0, b.err
216+
}
217+
218+
b.b = append(b.b, p...)
219+
220+
select {
221+
case b.more <- struct{}{}:
222+
default:
223+
}
224+
225+
return len(p), nil
226+
}
227+
228+
func (b *stdbuf) Close() error {
229+
return b.closeErr(nil)
230+
}
231+
232+
func (b *stdbuf) closeErr(err error) error {
233+
b.mu.Lock()
234+
defer b.mu.Unlock()
235+
if b.err != nil {
236+
return err
237+
}
238+
if err == nil {
239+
b.err = io.EOF
240+
} else {
241+
b.err = err
242+
}
243+
close(b.more)
244+
return err
245+
}

pty/ptytest/ptytest_internal_test.go

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package ptytest
2+
3+
import (
4+
"bytes"
5+
"io"
6+
"testing"
7+
8+
"github.com/stretchr/testify/assert"
9+
"github.com/stretchr/testify/require"
10+
)
11+
12+
func TestStdbuf(t *testing.T) {
13+
t.Parallel()
14+
15+
var got bytes.Buffer
16+
17+
b := newStdbuf()
18+
done := make(chan struct{})
19+
go func() {
20+
defer close(done)
21+
_, err := io.Copy(&got, b)
22+
assert.NoError(t, err)
23+
}()
24+
25+
_, err := b.Write([]byte("hello "))
26+
require.NoError(t, err)
27+
_, err = b.Write([]byte("world\n"))
28+
require.NoError(t, err)
29+
_, err = b.Write([]byte("bye\n"))
30+
require.NoError(t, err)
31+
32+
err = b.Close()
33+
require.NoError(t, err)
34+
<-done
35+
36+
assert.Equal(t, "hello world\nbye\n", got.String())
37+
}

pty/ptytest/ptytest_test.go

+5-8
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package ptytest_test
22

33
import (
44
"fmt"
5-
"runtime"
65
"strings"
76
"testing"
87

@@ -22,26 +21,24 @@ func TestPtytest(t *testing.T) {
2221
pty.WriteLine("read")
2322
})
2423

24+
// See https://github.com/coder/coder/issues/2122 for the motivation
25+
// behind this test.
2526
t.Run("Cobra ptytest should not hang when output is not consumed", func(t *testing.T) {
2627
t.Parallel()
2728

2829
tests := []struct {
2930
name string
3031
output string
31-
isPlatformBug bool // See https://github.com/coder/coder/issues/2122 for more info.
32+
isPlatformBug bool
3233
}{
3334
{name: "1024 is safe (does not exceed macOS buffer)", output: strings.Repeat(".", 1024)},
34-
{name: "1025 exceeds macOS buffer (must not hang)", output: strings.Repeat(".", 1025), isPlatformBug: true},
35-
{name: "10241 large output", output: strings.Repeat(".", 10241), isPlatformBug: true}, // 1024 * 10 + 1
35+
{name: "1025 exceeds macOS buffer (must not hang)", output: strings.Repeat(".", 1025)},
36+
{name: "10241 large output", output: strings.Repeat(".", 10241)}, // 1024 * 10 + 1
3637
}
3738
for _, tt := range tests {
3839
tt := tt
3940
// nolint:paralleltest // Avoid parallel test to more easily identify the issue.
4041
t.Run(tt.name, func(t *testing.T) {
41-
if tt.isPlatformBug && (runtime.GOOS == "darwin" || runtime.GOOS == "windows") {
42-
t.Skip("This test hangs on macOS and Windows, see https://github.com/coder/coder/issues/2122")
43-
}
44-
4542
cmd := cobra.Command{
4643
Use: "test",
4744
RunE: func(cmd *cobra.Command, args []string) error {

0 commit comments

Comments
 (0)