diff --git a/cli/server_test.go b/cli/server_test.go index 64ad535ea34f3..d9019391114f3 100644 --- a/cli/server_test.go +++ b/cli/server_test.go @@ -25,7 +25,6 @@ import ( "runtime" "strconv" "strings" - "sync" "sync/atomic" "testing" "time" @@ -253,10 +252,8 @@ func TestServer(t *testing.T) { "--access-url", "http://localhost:3000/", "--cache-dir", t.TempDir(), ) - stdoutRW := syncReaderWriter{} - stderrRW := syncReaderWriter{} - inv.Stdout = io.MultiWriter(os.Stdout, &stdoutRW) - inv.Stderr = io.MultiWriter(os.Stderr, &stderrRW) + pty := ptytest.New(t).Attach(inv) + require.NoError(t, pty.Resize(20, 80)) clitest.Start(t, inv) // Wait for startup @@ -270,8 +267,9 @@ func TestServer(t *testing.T) { // normally shown to the user, so we'll ignore them. ignoreLines := []string{ "isn't externally reachable", - "install.sh will be unavailable", + "open install.sh: file does not exist", "telemetry disabled, unable to notify of security issues", + "installed terraform version newer than expected", } countLines := func(fullOutput string) int { @@ -282,9 +280,11 @@ func TestServer(t *testing.T) { for _, line := range linesByNewline { for _, ignoreLine := range ignoreLines { if strings.Contains(line, ignoreLine) { + t.Logf("Ignoring: %q", line) continue lineLoop } } + t.Logf("Counting: %q", line) if line == "" { // Empty lines take up one line. countByWidth++ @@ -295,17 +295,10 @@ func TestServer(t *testing.T) { return countByWidth } - stdout, err := io.ReadAll(&stdoutRW) - if err != nil { - t.Fatalf("failed to read stdout: %v", err) - } - stderr, err := io.ReadAll(&stderrRW) - if err != nil { - t.Fatalf("failed to read stderr: %v", err) - } - - numLines := countLines(string(stdout)) + countLines(string(stderr)) - require.Less(t, numLines, 20) + out := pty.ReadAll() + numLines := countLines(string(out)) + t.Logf("numLines: %d", numLines) + require.Less(t, numLines, 12, "expected less than 12 lines of output (terminal width 80), got %d", numLines) }) t.Run("OAuth2GitHubDefaultProvider", func(t *testing.T) { @@ -2355,22 +2348,3 @@ func mockTelemetryServer(t *testing.T) (*url.URL, chan *telemetry.Deployment, ch return serverURL, deployment, snapshot } - -// syncWriter provides a thread-safe io.ReadWriter implementation -type syncReaderWriter struct { - buf bytes.Buffer - mu sync.Mutex -} - -func (w *syncReaderWriter) Write(p []byte) (n int, err error) { - w.mu.Lock() - defer w.mu.Unlock() - return w.buf.Write(p) -} - -func (w *syncReaderWriter) Read(p []byte) (n int, err error) { - w.mu.Lock() - defer w.mu.Unlock() - - return w.buf.Read(p) -} diff --git a/pty/ptytest/ptytest.go b/pty/ptytest/ptytest.go index 3c86970ec0006..42d9f34a7bae0 100644 --- a/pty/ptytest/ptytest.go +++ b/pty/ptytest/ptytest.go @@ -319,6 +319,11 @@ func (e *outExpecter) ReadLine(ctx context.Context) string { return buffer.String() } +func (e *outExpecter) ReadAll() []byte { + e.t.Helper() + return e.out.ReadAll() +} + func (e *outExpecter) doMatchWithDeadline(ctx context.Context, name string, fn func(*bufio.Reader) error) error { e.t.Helper() @@ -460,6 +465,18 @@ func newStdbuf() *stdbuf { return &stdbuf{more: make(chan struct{}, 1)} } +func (b *stdbuf) ReadAll() []byte { + b.mu.Lock() + defer b.mu.Unlock() + + if b.err != nil { + return nil + } + p := append([]byte(nil), b.b...) + b.b = b.b[len(b.b):] + return p +} + func (b *stdbuf) Read(p []byte) (int, error) { if b.r == nil { return b.readOrWaitForMore(p)