Skip to content

refactor: PTY & SSH #7100

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 22 commits into from
Apr 24, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix truncation tests
Signed-off-by: Spike Curtis <spike@coder.com>
  • Loading branch information
spikecurtis committed Apr 18, 2023
commit 90bfe94d9df6d90de99a6e000c950fa86b24c8f4
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,8 @@ require (
tailscale.com v1.32.2
)

require github.com/hinshun/vt10x v0.0.0-20220301184237-5011da428d02 // indirect

require (
cloud.google.com/go/compute v1.18.0 // indirect
cloud.google.com/go/logging v1.6.1 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -1088,6 +1088,8 @@ github.com/hdevalence/ed25519consensus v0.0.0-20220222234857-c00d1f31bab3 h1:aSV
github.com/hdevalence/ed25519consensus v0.0.0-20220222234857-c00d1f31bab3/go.mod h1:5PC6ZNPde8bBqU/ewGZig35+UIZtw9Ytxez8/q5ZyFE=
github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec h1:qv2VnGeEQHchGaZ/u7lxST/RaJw+cv273q79D81Xbog=
github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec/go.mod h1:Q48J4R4DvxnHolD5P8pOtXigYlRuPLGl6moFx3ulM68=
github.com/hinshun/vt10x v0.0.0-20220301184237-5011da428d02 h1:AgcIVYPa6XJnU3phs104wLj8l5GEththEw6+F79YsIY=
github.com/hinshun/vt10x v0.0.0-20220301184237-5011da428d02/go.mod h1:Q48J4R4DvxnHolD5P8pOtXigYlRuPLGl6moFx3ulM68=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/huandu/xstrings v1.0.0/go.mod h1:4qWG/gcEcfX4z/mBDHJ++3ReCw9ibxbsNJbcucJdbSo=
github.com/huandu/xstrings v1.2.0/go.mod h1:DvyZB1rfVYsBIigL8HwpZgxHwXozlTgGqn63UyNX5k4=
Expand Down
12 changes: 9 additions & 3 deletions pty/pty_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,10 @@ func (p *ptyWindows) Close() error {
}
p.closed = true

// if we are running a command in the PTY, the corresponding *windowsProcess
// may have already closed the PseudoConsole when the command exited, so that
// output reads can get to EOF. In that case, we don't need to close it
// again here.
if p.console != windows.InvalidHandle {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Assigned both here and in waitInternal, probably needs protection.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Protected by the closeMutex

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah right, the p.pw threw me off. I noticed Resize is not protected though!

ret, _, err := procClosePseudoConsole.Call(uintptr(p.console))
if ret < 0 {
Expand Down Expand Up @@ -169,9 +173,11 @@ func (p *windowsProcess) waitInternal() {
defer p.pw.closeMutex.Unlock()
if p.pw.console != windows.InvalidHandle {
ret, _, err := procClosePseudoConsole.Call(uintptr(p.pw.console))
if ret < 0 {
// not much we can do here...
panic(err)
if ret < 0 && p.cmdErr == nil {
// if we already have an error from the command, prefer that error
// but if the command succeeded and closing the PseudoConsole fails
// then record that error so that we have a chance to see it
p.cmdErr = err
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It'd be nice to log the error whether it's kept or discarded

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice, but I don't have a logger, and I'd have to change all the init signatures to get one, so I decided on this.

Another alternative I considered is panicking---something pretty messed up at the OS level is happening if we get an error here. We're pretty much guaranteed to crash the agent if we panic, so it's likely we'd get told about it.

}
p.pw.console = windows.InvalidHandle
}
Expand Down
20 changes: 20 additions & 0 deletions pty/start_other_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,23 @@ func TestStart(t *testing.T) {
require.NoError(t, err)
})
}

// these constants/vars are used by Test_Start_copy

const cmdEcho = "echo"

var argEcho = []string{"test"}

// these constants/vars are used by Test_Start_truncate

const countEnd = 1000
const cmdCount = "sh"

var argCount = []string{"-c", `
i=0
while [ $i -ne 1000 ]
do
i=$(($i+1))
echo "$i"
done
`}
158 changes: 158 additions & 0 deletions pty/start_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
package pty_test

import (
"bytes"
"context"
"fmt"
"io"
"os/exec"
"strings"
"testing"
"time"

"github.com/coder/coder/pty"
"github.com/hinshun/vt10x"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// Test_Start_copy tests that we can use io.Copy() on command output
// without deadlocking.
//
// It starts cmdEcho in a PTY, copies the PTY output until io.Copy returns,
// checks that the output contains "test", and then waits for the command to
// exit. Both steps are bounded by a 10 second context.
func Test_Start_copy(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
	defer cancel()

	pc, cmd, err := pty.Start(exec.CommandContext(ctx, cmdEcho, argEcho...))
	require.NoError(t, err)

	b := &bytes.Buffer{}
	// Buffered so the copy goroutine can deliver its result and exit even if
	// we have already given up waiting for it.
	readDone := make(chan error, 1)
	go func() {
		_, err := io.Copy(b, pc.OutputReader())
		readDone <- err
	}()

	select {
	case err := <-readDone:
		require.NoError(t, err)
	case <-ctx.Done():
		// Stop here: the copy goroutine may still be writing to b, so it is
		// not safe to inspect the buffer below.
		t.Fatal("read timed out")
	}
	assert.Contains(t, b.String(), "test")

	// Buffered for the same reason as readDone.
	cmdDone := make(chan error, 1)
	go func() {
		cmdDone <- cmd.Wait()
	}()

	select {
	case err := <-cmdDone:
		require.NoError(t, err)
	case <-ctx.Done():
		t.Fatal("cmd.Wait() timed out")
	}
}

// Test_Start_trucation tests that we can read command output without truncation
// even after the command has exited.
//
// The command prints the numbers 1 through countEnd. The test reads all but
// the last 25 or so numbers, waits for the command to exit, and then checks
// that the remaining numbers and EOF can still be read from the PTY output.
func Test_Start_trucation(t *testing.T) {
	t.Parallel()

	// NOTE(review): 1000 seconds exceeds the default `go test` timeout of 10
	// minutes, so this deadline will normally never fire first — confirm
	// whether a shorter value was intended.
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*1000)
	defer cancel()

	pc, cmd, err := pty.Start(exec.CommandContext(ctx, cmdCount, argCount...))
	require.NoError(t, err)

	// readCounts reads the numbers in [from, to) from the PTY output, in
	// order, and reports whether all of them were seen. It runs on helper
	// goroutines, so it must use assert rather than require: require calls
	// t.FailNow, which is only valid on the goroutine running the test.
	readCounts := func(from, to int) bool {
		for n := from; n < to; n++ {
			want := fmt.Sprintf("%d", n)
			// readUntil avoids buffered IO so that we can precisely control
			// how many bytes are read.
			err := readUntil(ctx, t, want, pc.OutputReader())
			if !assert.NoError(t, err, "want: %s", want) {
				return false
			}
		}
		return true
	}

	readDone := make(chan struct{})
	go func() {
		defer close(readDone)
		readCounts(1, countEnd-25)
	}()

	select {
	case <-readDone:
		// OK!
	case <-ctx.Done():
		// Stop here so a second reader is never started while the first one
		// may still be reading from the PTY.
		t.Fatal("read timed out")
	}

	// Buffered so the goroutine can exit even if we stop waiting for it.
	cmdDone := make(chan error, 1)
	go func() {
		cmdDone <- cmd.Wait()
	}()

	select {
	case err := <-cmdDone:
		require.NoError(t, err)
	case <-ctx.Done():
		t.Fatal("cmd.Wait() timed out")
	}

	// do our final 25 reads, to make sure the output wasn't lost
	readDone = make(chan struct{})
	go func() {
		defer close(readDone)
		if !readCounts(countEnd-25, countEnd+1) {
			return
		}
		// ensure we still get to EOF
		_, err := io.Copy(io.Discard, pc.OutputReader())
		assert.NoError(t, err)
	}()

	select {
	case <-readDone:
		// OK!
	case <-ctx.Done():
		t.Fatal("read timed out")
	}
}

// readUntil reads from r one byte per Read call, feeding each byte into a
// fresh 80x80 vt10x virtual terminal, and returns nil as soon as any line of
// the rendered terminal, with surrounding whitespace trimmed, equals want.
// It returns the error from r.Read if a read fails, or ctx.Err() if the
// context is done before the next read completes. Terminal state is not
// shared between calls, so only bytes read during this call are considered.
func readUntil(ctx context.Context, t *testing.T, want string, r io.Reader) error {
// output can contain virtual terminal sequences, so we need to parse these
// to correctly interpret getting what we want.
term := vt10x.New(vt10x.WithSize(80, 80))
// Buffered with capacity 1 so that the read goroutine below can always
// deliver its result and exit, even if this function has already returned
// because the context was done.
readErrs := make(chan error, 1)
for {
b := make([]byte, 1)
// Read in a goroutine so that the select below can stop waiting when ctx
// is done; at most one read goroutine is outstanding at a time.
go func() {
_, err := r.Read(b)
readErrs <- err
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
readErrs <- err
select {
case readErrs <- err:
default:
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is needed.

The whole purpose of this little goroutine is to allow me to call Read() but still be able to timeout in this function if the context expires. So I read in a goroutine, then select whether the read completed before the context expires. The chan error is buffered, so even if the context expires, the goroutine will be able to finish once Read() returns. Note that an expired context calls return so there can only be one read goroutine started at a time.

}()
select {
case err := <-readErrs:
if err != nil {
t.Logf("err: %v\ngot: %v", err, term)
return err
}
term.Write(b)
case <-ctx.Done():
return ctx.Err()
}
got := term.String()
lines := strings.Split(got, "\n")
for _, line := range lines {
if strings.TrimSpace(line) == want {
t.Logf("want: %v\n got:%v", want, line)
return nil
}
}
}
}
138 changes: 6 additions & 132 deletions pty/start_windows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,10 @@
package pty_test

import (
"bytes"
"context"
"fmt"
"io"
"os/exec"
"strings"
"testing"
"time"

"github.com/coder/coder/pty"
"github.com/coder/coder/pty/ptytest"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -58,135 +52,15 @@ func TestStart(t *testing.T) {
})
}

func Test_Start_copy(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
// these constants/vars are used by Test_Start_copy

pc, cmd, err := pty.Start(exec.CommandContext(ctx, "cmd.exe", "/c", "echo", "test"))
require.NoError(t, err)
b := &bytes.Buffer{}
readDone := make(chan error)
go func() {
_, err := io.Copy(b, pc.OutputReader())
readDone <- err
}()
const cmdEcho = "cmd.exe"

select {
case err := <-readDone:
require.NoError(t, err)
case <-ctx.Done():
t.Error("read timed out")
}
assert.Contains(t, b.String(), "test")
var argEcho = []string{"/c", "echo", "test"}

cmdDone := make(chan error)
go func() {
cmdDone <- cmd.Wait()
}()

select {
case err := <-cmdDone:
require.NoError(t, err)
case <-ctx.Done():
t.Error("cmd.Wait() timed out")
}
}
// these constants/vars are used by Test_Start_trucation

const countEnd = 1000
const cmdCount = "cmd.exe"

func Test_Start_trucation(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(context.Background(), time.Second*1000)
defer cancel()

pc, cmd, err := pty.Start(exec.CommandContext(ctx,
"cmd.exe", "/c",
fmt.Sprintf("for /L %%n in (1,1,%d) do @echo %%n", countEnd)))
require.NoError(t, err)
readDone := make(chan struct{})
go func() {
defer close(readDone)
// avoid buffered IO so that we can precisely control how many bytes to read.
n := 1
for n < countEnd-25 {
want := fmt.Sprintf("%d\r\n", n)
// the output also contains virtual terminal sequences
// so just read until we see the number we want.
err := readUntil(ctx, want, pc.OutputReader())
require.NoError(t, err, "want: %s", want)
n++
}
}()

select {
case <-readDone:
// OK!
case <-ctx.Done():
t.Error("read timed out")
}

cmdDone := make(chan error)
go func() {
cmdDone <- cmd.Wait()
}()

select {
case err := <-cmdDone:
require.NoError(t, err)
case <-ctx.Done():
t.Error("cmd.Wait() timed out")
}

// do our final 25 reads, to make sure the output wasn't lost
readDone = make(chan struct{})
go func() {
defer close(readDone)
// avoid buffered IO so that we can precisely control how many bytes to read.
n := countEnd - 25
for n <= countEnd {
want := fmt.Sprintf("%d\r\n", n)
err := readUntil(ctx, want, pc.OutputReader())
if n < countEnd {
require.NoError(t, err, "want: %s", want)
} else {
require.ErrorIs(t, err, io.EOF)
}
n++
}
}()

select {
case <-readDone:
// OK!
case <-ctx.Done():
t.Error("read timed out")
}
}

// readUntil reads one byte at a time until we either see the string we want, or the context expires
func readUntil(ctx context.Context, want string, r io.Reader) error {
got := ""
readErrs := make(chan error, 1)
for {
b := make([]byte, 1)
go func() {
_, err := r.Read(b)
readErrs <- err
}()
select {
case err := <-readErrs:
if err != nil {
return err
}
got = got + string(b)
case <-ctx.Done():
return ctx.Err()
}
if strings.Contains(got, want) {
return nil
}
}
}
var argCount = []string{"/c", fmt.Sprintf("for /L %%n in (1,1,%d) do @echo %%n", countEnd)}