Skip to content

fix: use terminal emulator that keeps state in ReconnectingPTY tests #9765

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 15 additions & 11 deletions agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1669,13 +1669,15 @@ func TestAgent_ReconnectingPTY(t *testing.T) {
}

// Once for typing the command...
require.NoError(t, testutil.ReadUntil(ctx, t, netConn1, matchEchoCommand), "find echo command")
tr1 := testutil.NewTerminalReader(t, netConn1)
require.NoError(t, tr1.ReadUntil(ctx, matchEchoCommand), "find echo command")
// And another time for the actual output.
require.NoError(t, testutil.ReadUntil(ctx, t, netConn1, matchEchoOutput), "find echo output")
require.NoError(t, tr1.ReadUntil(ctx, matchEchoOutput), "find echo output")

// Same for the other connection.
require.NoError(t, testutil.ReadUntil(ctx, t, netConn2, matchEchoCommand), "find echo command")
require.NoError(t, testutil.ReadUntil(ctx, t, netConn2, matchEchoOutput), "find echo output")
tr2 := testutil.NewTerminalReader(t, netConn2)
require.NoError(t, tr2.ReadUntil(ctx, matchEchoCommand), "find echo command")
require.NoError(t, tr2.ReadUntil(ctx, matchEchoOutput), "find echo output")

_ = netConn1.Close()
_ = netConn2.Close()
Expand All @@ -1684,8 +1686,9 @@ func TestAgent_ReconnectingPTY(t *testing.T) {
defer netConn3.Close()

// Same output again!
require.NoError(t, testutil.ReadUntil(ctx, t, netConn3, matchEchoCommand), "find echo command")
require.NoError(t, testutil.ReadUntil(ctx, t, netConn3, matchEchoOutput), "find echo output")
tr3 := testutil.NewTerminalReader(t, netConn3)
require.NoError(t, tr3.ReadUntil(ctx, matchEchoCommand), "find echo command")
require.NoError(t, tr3.ReadUntil(ctx, matchEchoOutput), "find echo output")

// Exit should cause the connection to close.
data, err = json.Marshal(codersdk.ReconnectingPTYRequest{
Expand All @@ -1696,19 +1699,20 @@ func TestAgent_ReconnectingPTY(t *testing.T) {
require.NoError(t, err)

// Once for the input and again for the output.
require.NoError(t, testutil.ReadUntil(ctx, t, netConn3, matchExitCommand), "find exit command")
require.NoError(t, testutil.ReadUntil(ctx, t, netConn3, matchExitOutput), "find exit output")
require.NoError(t, tr3.ReadUntil(ctx, matchExitCommand), "find exit command")
require.NoError(t, tr3.ReadUntil(ctx, matchExitOutput), "find exit output")

// Wait for the connection to close.
require.ErrorIs(t, testutil.ReadUntil(ctx, t, netConn3, nil), io.EOF)
require.ErrorIs(t, tr3.ReadUntil(ctx, nil), io.EOF)

// Try a non-shell command. It should output then immediately exit.
netConn4, err := conn.ReconnectingPTY(ctx, uuid.New(), 80, 80, "echo test")
require.NoError(t, err)
defer netConn4.Close()

require.NoError(t, testutil.ReadUntil(ctx, t, netConn4, matchEchoOutput), "find echo output")
require.ErrorIs(t, testutil.ReadUntil(ctx, t, netConn3, nil), io.EOF)
tr4 := testutil.NewTerminalReader(t, netConn4)
require.NoError(t, tr4.ReadUntil(ctx, matchEchoOutput), "find echo output")
require.ErrorIs(t, tr4.ReadUntil(ctx, nil), io.EOF)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: original was netConn3 but I think that's copypasta

})
}
}
Expand Down
11 changes: 6 additions & 5 deletions coderd/workspaceapps/apptest/apptest.go
Original file line number Diff line number Diff line change
Expand Up @@ -1428,8 +1428,9 @@ func testReconnectingPTY(ctx context.Context, t *testing.T, client *codersdk.Cli
_, err = conn.Write(data)
require.NoError(t, err)

require.NoError(t, testutil.ReadUntil(ctx, t, conn, matchEchoCommand), "find echo command")
require.NoError(t, testutil.ReadUntil(ctx, t, conn, matchEchoOutput), "find echo output")
tr := testutil.NewTerminalReader(t, conn)
require.NoError(t, tr.ReadUntil(ctx, matchEchoCommand), "find echo command")
require.NoError(t, tr.ReadUntil(ctx, matchEchoOutput), "find echo output")

// Exit should cause the connection to close.
data, err = json.Marshal(codersdk.ReconnectingPTYRequest{
Expand All @@ -1440,9 +1441,9 @@ func testReconnectingPTY(ctx context.Context, t *testing.T, client *codersdk.Cli
require.NoError(t, err)

// Once for the input and again for the output.
require.NoError(t, testutil.ReadUntil(ctx, t, conn, matchExitCommand), "find exit command")
require.NoError(t, testutil.ReadUntil(ctx, t, conn, matchExitOutput), "find exit output")
require.NoError(t, tr.ReadUntil(ctx, matchExitCommand), "find exit command")
require.NoError(t, tr.ReadUntil(ctx, matchExitOutput), "find exit output")

// Ensure the connection closes.
require.ErrorIs(t, testutil.ReadUntil(ctx, t, conn, nil), io.EOF)
require.ErrorIs(t, tr.ReadUntil(ctx, nil), io.EOF)
}
4 changes: 2 additions & 2 deletions pty/start_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,11 @@ func Test_Start_truncation(t *testing.T) {
readDone := make(chan struct{})
go func() {
defer close(readDone)
// avoid buffered IO so that we can precisely control how many bytes to read.
terminalReader := testutil.NewTerminalReader(t, pc.OutputReader())
n := 1
for n <= countEnd {
want := fmt.Sprintf("%d", n)
err := testutil.ReadUntilString(ctx, t, want, pc.OutputReader())
err := terminalReader.ReadUntilString(ctx, want)
assert.NoError(t, err, "want: %s", want)
if err != nil {
return
Expand Down
56 changes: 42 additions & 14 deletions testutil/pty.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,46 +9,74 @@ import (
"github.com/hinshun/vt10x"
)

// TerminalReader emulates a terminal and allows matching output. It's important in cases where we
// can get control sequences to parse them correctly, and keep the state of the terminal across the
// lifespan of the PTY, since some control sequences are relative to the current cursor position.
type TerminalReader struct {
t *testing.T
r io.Reader
term vt10x.Terminal
}

func NewTerminalReader(t *testing.T, r io.Reader) *TerminalReader {
return &TerminalReader{
t: t,
r: r,
term: vt10x.New(vt10x.WithSize(80, 80)),
}
}

// ReadUntilString emulates a terminal and reads one byte at a time until we
// either see the string we want, or the context expires. The PTY must be sized
// to 80x80 or there could be unexpected results.
func ReadUntilString(ctx context.Context, t *testing.T, want string, r io.Reader) error {
return ReadUntil(ctx, t, r, func(line string) bool {
func (tr *TerminalReader) ReadUntilString(ctx context.Context, want string) error {
return tr.ReadUntil(ctx, func(line string) bool {
return strings.TrimSpace(line) == want
})
}

// ReadUntil emulates a terminal and reads one byte at a time until the matcher
// returns true or the context expires. If the matcher is nil, read until EOF.
// The PTY must be sized to 80x80 or there could be unexpected results.
func ReadUntil(ctx context.Context, t *testing.T, r io.Reader, matcher func(line string) bool) error {
// output can contain virtual terminal sequences, so we need to parse these
// to correctly interpret getting what we want.
term := vt10x.New(vt10x.WithSize(80, 80))
func (tr *TerminalReader) ReadUntil(ctx context.Context, matcher func(line string) bool) (retErr error) {
readBytes := make([]byte, 0)
readErrs := make(chan error, 1)
defer func() {
// Dump the terminal contents since they can be helpful for debugging, but
// skip empty lines since much of the terminal will usually be blank.
got := term.String()
// trim empty lines since much of the terminal will usually be blank.
got := tr.term.String()
lines := strings.Split(got, "\n")
for _, line := range lines {
if strings.TrimSpace(line) != "" {
t.Logf("got: %v", line)
for i := range lines {
if strings.TrimSpace(lines[i]) != "" {
lines = lines[i:]
break
}
}
for i := len(lines) - 1; i >= 0; i-- {
if strings.TrimSpace(lines[i]) != "" {
lines = lines[:i+1]
break
}
}
gotTrimmed := strings.Join(lines, "\n")
tr.t.Logf("Terminal contents:\n%s", gotTrimmed)
if retErr != nil {
tr.t.Logf("Bytes Read: %q", string(readBytes))
}
}()
for {
b := make([]byte, 1)
go func() {
_, err := r.Read(b)
_, err := tr.r.Read(b)
readErrs <- err
}()
select {
case err := <-readErrs:
if err != nil {
return err
}
_, err = term.Write(b)
readBytes = append(readBytes, b...)
_, err = tr.term.Write(b)
if err != nil {
return err
}
Expand All @@ -59,7 +87,7 @@ func ReadUntil(ctx context.Context, t *testing.T, r io.Reader, matcher func(line
// A nil matcher means to read until EOF.
continue
}
got := term.String()
got := tr.term.String()
lines := strings.Split(got, "\n")
for _, line := range lines {
if matcher(line) {
Expand Down