Skip to content

Commit 2cf357a

Browse files
committed
Fix some tests
Signed-off-by: Spike Curtis <spike@coder.com>
1 parent a491d4f commit 2cf357a

File tree

8 files changed

+57
-115
lines changed

8 files changed

+57
-115
lines changed

agent/agent.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1045,7 +1045,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
10451045
if err = a.trackConnGoroutine(func() {
10461046
buffer := make([]byte, 1024)
10471047
for {
1048-
read, err := rpty.ptty.Output().Read(buffer)
1048+
read, err := rpty.ptty.OutputReader().Read(buffer)
10491049
if err != nil {
10501050
// When the PTY is closed, this is triggered.
10511051
break
@@ -1138,7 +1138,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
11381138
logger.Warn(ctx, "read conn", slog.Error(err))
11391139
return nil
11401140
}
1141-
_, err = rpty.ptty.Input().Write([]byte(req.Data))
1141+
_, err = rpty.ptty.InputWriter().Write([]byte(req.Data))
11421142
if err != nil {
11431143
logger.Warn(ctx, "write to pty", slog.Error(err))
11441144
return nil
@@ -1358,7 +1358,7 @@ type reconnectingPTY struct {
13581358
circularBuffer *circbuf.Buffer
13591359
circularBufferMutex sync.RWMutex
13601360
timeout *time.Timer
1361-
ptty pty.PTY
1361+
ptty pty.PTYCmd
13621362
}
13631363

13641364
// Close ends all connections to the reconnecting

agent/agentssh/agentssh.go

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import (
88
"errors"
99
"fmt"
1010
"io"
11-
"io/fs"
1211
"net"
1312
"os"
1413
"os/exec"
@@ -330,12 +329,7 @@ func (s *Server) startPTYSession(session ptySession, cmd *exec.Cmd, sshPty ssh.P
330329
// kill the command's process. This then has the same effect as (1).
331330
n, err := io.Copy(session, ptty.OutputReader())
332331
s.logger.Debug(ctx, "copy output done", slog.F("bytes", n), slog.Error(err))
333-
334-
// output from the ptty will hit a PathErr on the PTY when the process
335-
// hangs up the other side (typically when the process exits, but could
336-
// be earlier)
337-
pathErr := &fs.PathError{}
338-
if err != nil && !xerrors.As(err, &pathErr) {
332+
if err != nil {
339333
return xerrors.Errorf("copy error: %w", err, err, err)
340334
}
341335
// We've gotten all the output, but we need to wait for the process to

agent/agentssh/agentssh_test.go

Lines changed: 0 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,12 @@
33
package agentssh_test
44

55
import (
6-
"bufio"
76
"bytes"
87
"context"
98
"net"
10-
"strconv"
119
"strings"
1210
"sync"
1311
"testing"
14-
"time"
1512

1613
"cdr.dev/slog/sloggers/slogtest"
1714
"github.com/stretchr/testify/assert"
@@ -93,7 +90,6 @@ func TestNewServer_CloseActiveConnections(t *testing.T) {
9390
}()
9491

9592
pty := ptytest.New(t)
96-
defer pty.Close()
9793

9894
doClose := make(chan struct{})
9995
go func() {
@@ -120,92 +116,6 @@ func TestNewServer_CloseActiveConnections(t *testing.T) {
120116
wg.Wait()
121117
}
122118

123-
const countingScript = `
124-
i=0
125-
while [ $i -ne 20000 ]
126-
do
127-
i=$(($i+1))
128-
echo "$i"
129-
done
130-
`
131-
132-
// TestServer_sessionStart_longoutput is designed to test running a command that
133-
// produces a lot of output and ensure we don't truncate the output returned
134-
// over SSH.
135-
func TestServer_sessionStart_longoutput(t *testing.T) {
136-
t.Parallel()
137-
138-
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
139-
defer cancel()
140-
logger := slogtest.Make(t, nil)
141-
s, err := agentssh.NewServer(ctx, logger, 0)
142-
require.NoError(t, err)
143-
144-
// The assumption is that these are set before serving SSH connections.
145-
s.AgentToken = func() string { return "" }
146-
s.Manifest = atomic.NewPointer(&agentsdk.Manifest{})
147-
148-
ln, err := net.Listen("tcp", "127.0.0.1:0")
149-
require.NoError(t, err)
150-
151-
done := make(chan struct{})
152-
go func() {
153-
defer close(done)
154-
err := s.Serve(ln)
155-
assert.Error(t, err) // Server is closed.
156-
}()
157-
158-
c := sshClient(t, ln.Addr().String())
159-
sess, err := c.NewSession()
160-
require.NoError(t, err)
161-
162-
stdout, err := sess.StdoutPipe()
163-
require.NoError(t, err)
164-
readDone := make(chan struct{})
165-
go func() {
166-
w := 0
167-
defer close(readDone)
168-
s := bufio.NewScanner(stdout)
169-
for s.Scan() {
170-
w++
171-
ns := s.Text()
172-
n, err := strconv.Atoi(ns)
173-
require.NoError(t, err)
174-
require.Equal(t, w, n, "output corrupted")
175-
}
176-
assert.Equal(t, w, 20000, "output truncated")
177-
assert.NoError(t, s.Err())
178-
}()
179-
180-
err = sess.Start(countingScript)
181-
require.NoError(t, err)
182-
183-
waitForChan(t, readDone, ctx, "read timeout")
184-
185-
sessionDone := make(chan struct{})
186-
go func() {
187-
defer close(sessionDone)
188-
err := sess.Wait()
189-
assert.NoError(t, err)
190-
}()
191-
192-
waitForChan(t, sessionDone, ctx, "session timeout")
193-
err = s.Close()
194-
require.NoError(t, err)
195-
waitForChan(t, done, ctx, "timeout closing server")
196-
}
197-
198-
func waitForChan(t *testing.T, c <-chan struct{}, ctx context.Context, msg string) {
199-
t.Helper()
200-
select {
201-
case <-c:
202-
// OK!
203-
case <-ctx.Done():
204-
t.Fatal(msg)
205-
}
206-
}
207-
208-
// sshClient creates an ssh.Client for testing
209119
func sshClient(t *testing.T, addr string) *ssh.Client {
210120
conn, err := net.Dial("tcp", addr)
211121
require.NoError(t, err)

pty/pty.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package pty
33
import (
44
"io"
55
"log"
6-
"os"
76

87
"github.com/gliderlabs/ssh"
98
"golang.org/x/xerrors"
@@ -119,8 +118,8 @@ func New(opts ...Option) (PTY, error) {
119118
// underlying file descriptors, one for reading and one for writing, and allows
120119
// them to be accessed separately.
121120
type ReadWriter struct {
122-
Reader *os.File
123-
Writer *os.File
121+
Reader io.Reader
122+
Writer io.Writer
124123
}
125124

126125
func (rw ReadWriter) Read(p []byte) (int, error) {

pty/pty_other.go

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,14 @@ package pty
44

55
import (
66
"io"
7+
"io/fs"
78
"os"
89
"os/exec"
910
"runtime"
1011
"sync"
1112

13+
"golang.org/x/xerrors"
14+
1215
"github.com/creack/pty"
1316
"github.com/u-root/u-root/pkg/termios"
1417
"golang.org/x/sys/unix"
@@ -103,13 +106,13 @@ func (p *otherPty) InputWriter() io.Writer {
103106

104107
func (p *otherPty) Output() ReadWriter {
105108
return ReadWriter{
106-
Reader: p.pty,
109+
Reader: &ptmReader{p.pty},
107110
Writer: p.tty,
108111
}
109112
}
110113

111114
func (p *otherPty) OutputReader() io.Reader {
112-
return p.pty
115+
return &ptmReader{p.pty}
113116
}
114117

115118
func (p *otherPty) Resize(height uint16, width uint16) error {
@@ -176,3 +179,21 @@ func (p *otherProcess) waitInternal() {
176179
runtime.KeepAlive(p.pty)
177180
close(p.cmdDone)
178181
}
182+
183+
// ptmReader wraps a reference to the ptm side of a pseudo-TTY for portability
184+
type ptmReader struct {
185+
ptm io.Reader
186+
}
187+
188+
func (r *ptmReader) Read(p []byte) (n int, err error) {
189+
n, err = r.ptm.Read(p)
190+
// output from the ptm will hit a PathErr when the process hangs up the
191+
// other side (typically when the process exits, but could be earlier). For
192+
// portability, and to fit with our use of io.Copy() to copy from the PTY,
193+
// we want to translate this error into io.EOF
194+
pathErr := &fs.PathError{}
195+
if xerrors.As(err, &pathErr) {
196+
return n, io.EOF
197+
}
198+
return n, err
199+
}

pty/ptytest/ptytest.go

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,20 +29,21 @@ func New(t *testing.T, opts ...pty.Option) *PTY {
2929

3030
ptty, err := pty.New(opts...)
3131
require.NoError(t, err)
32-
// Ensure pty is cleaned up at the end of test.
33-
t.Cleanup(func() {
34-
_ = ptty.Close()
35-
})
3632

3733
e := newExpecter(t, ptty.Output(), "cmd")
38-
return &PTY{
34+
r := &PTY{
3935
outExpecter: *e,
4036
PTY: ptty,
4137
}
38+
// Ensure pty is cleaned up at the end of test.
39+
t.Cleanup(func() {
40+
_ = r.Close()
41+
})
42+
return r
4243
}
4344

44-
// Start starts a new process asynchronously and returns a PTY and Process.
45-
// It kills the process upon cleanup.
45+
// Start starts a new process asynchronously and returns a PTYCmd and Process.
46+
// It kills the process and PTYCmd upon cleanup
4647
func Start(t *testing.T, cmd *exec.Cmd, opts ...pty.StartOption) (*PTYCmd, pty.Process) {
4748
t.Helper()
4849

@@ -54,10 +55,14 @@ func Start(t *testing.T, cmd *exec.Cmd, opts ...pty.StartOption) (*PTYCmd, pty.P
5455
})
5556
ex := newExpecter(t, ptty.OutputReader(), cmd.Args[0])
5657

57-
return &PTYCmd{
58+
r := &PTYCmd{
5859
outExpecter: *ex,
5960
PTYCmd: ptty,
60-
}, ps
61+
}
62+
t.Cleanup(func() {
63+
_ = r.Close()
64+
})
65+
return r, ps
6166
}
6267

6368
func newExpecter(t *testing.T, r io.Reader, name string) *outExpecter {

pty/start_other_test.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,20 +25,25 @@ func TestStart(t *testing.T) {
2525
t.Run("Echo", func(t *testing.T) {
2626
t.Parallel()
2727
pty, ps := ptytest.Start(t, exec.Command("echo", "test"))
28+
2829
pty.ExpectMatch("test")
2930
err := ps.Wait()
3031
require.NoError(t, err)
32+
err = pty.Close()
33+
require.NoError(t, err)
3134
})
3235

3336
t.Run("Kill", func(t *testing.T) {
3437
t.Parallel()
35-
_, ps := ptytest.Start(t, exec.Command("sleep", "30"))
38+
pty, ps := ptytest.Start(t, exec.Command("sleep", "30"))
3639
err := ps.Kill()
3740
assert.NoError(t, err)
3841
err = ps.Wait()
3942
var exitErr *exec.ExitError
4043
require.True(t, xerrors.As(err, &exitErr))
4144
assert.NotEqual(t, 0, exitErr.ExitCode())
45+
err = pty.Close()
46+
require.NoError(t, err)
4247
})
4348

4449
t.Run("SSH_TTY", func(t *testing.T) {
@@ -53,5 +58,7 @@ func TestStart(t *testing.T) {
5358
pty.ExpectMatch("SSH_TTY=/dev/")
5459
err := ps.Wait()
5560
require.NoError(t, err)
61+
err = pty.Close()
62+
require.NoError(t, err)
5663
})
5764
}

pty/start_windows_test.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,16 @@ func TestStart(t *testing.T) {
2626
pty.ExpectMatch("test")
2727
err := ps.Wait()
2828
require.NoError(t, err)
29+
err = pty.Close()
30+
require.NoError(t, err)
2931
})
3032
t.Run("Resize", func(t *testing.T) {
3133
t.Parallel()
3234
pty, _ := ptytest.Start(t, exec.Command("cmd.exe"))
3335
err := pty.Resize(100, 50)
3436
require.NoError(t, err)
37+
err = pty.Close()
38+
require.NoError(t, err)
3539
})
3640
t.Run("Kill", func(t *testing.T) {
3741
t.Parallel()
@@ -42,5 +46,7 @@ func TestStart(t *testing.T) {
4246
var exitErr *exec.ExitError
4347
require.True(t, xerrors.As(err, &exitErr))
4448
assert.NotEqual(t, 0, exitErr.ExitCode())
49+
err = pty.Close()
50+
require.NoError(t, err)
4551
})
4652
}

0 commit comments

Comments
 (0)