Skip to content

refactor: PTY & SSH #7100

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 22 commits into from
Apr 24, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Add ssh tests for longoutput, orphan
Signed-off-by: Spike Curtis <spike@coder.com>
  • Loading branch information
spikecurtis committed Apr 11, 2023
commit 0075d7da3108610bb2fd2c2a55394187baf4a2da
207 changes: 207 additions & 0 deletions agent/agentssh/agentssh_internal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
package agentssh

import (
"bufio"
"context"
"net"
"strconv"
"testing"
"time"

"golang.org/x/crypto/ssh"

"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/codersdk/agentsdk"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/atomic"
)

const countingScript = `
i=0
while [ $i -ne 20000 ]
do
i=$(($i+1))
echo "$i"
done
`

func TestServer_sessionStart_longoutput(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
logger := slogtest.Make(t, nil)
s, err := NewServer(ctx, logger, 0)
require.NoError(t, err)

// The assumption is that these are set before serving SSH connections.
s.AgentToken = func() string { return "" }
s.Manifest = atomic.NewPointer(&agentsdk.Manifest{})

ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)

done := make(chan struct{})
go func() {
defer close(done)
err := s.Serve(ln)
assert.Error(t, err) // Server is closed.
}()

c := SSHTestClient(t, ln.Addr().String())
sess, err := c.NewSession()
require.NoError(t, err)

stdout, err := sess.StdoutPipe()
require.NoError(t, err)
readDone := make(chan struct{})
go func() {
w := 0
defer close(readDone)
s := bufio.NewScanner(stdout)
for s.Scan() {
w++
ns := s.Text()
n, err := strconv.Atoi(ns)
require.NoError(t, err)
require.Equal(t, w, n, "output corrupted")
}
assert.Equal(t, w, 20000, "output truncated")
assert.NoError(t, s.Err())
}()

err = sess.Start(countingScript)
require.NoError(t, err)

select {
case <-readDone:
// OK
case <-ctx.Done():
t.Fatal("read timeout")
}

sessionDone := make(chan struct{})
go func() {
defer close(sessionDone)
err := sess.Wait()
assert.NoError(t, err)
}()

select {
case <-sessionDone:
// OK!
case <-ctx.Done():
t.Fatal("session timeout")
}
}

const longScript = `
echo "started"
sleep 30
echo "done"
`

func TestServer_sessionStart_orphan(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
logger := slogtest.Make(t, nil)
s, err := NewServer(ctx, logger, 0)
require.NoError(t, err)

// The assumption is that these are set before serving SSH connections.
s.AgentToken = func() string { return "" }
s.Manifest = atomic.NewPointer(&agentsdk.Manifest{})

ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)

done := make(chan struct{})
go func() {
defer close(done)
err := s.Serve(ln)
assert.Error(t, err) // Server is closed.
}()

c := SSHTestClient(t, ln.Addr().String())
sess, err := c.NewSession()
require.NoError(t, err)

stdout, err := sess.StdoutPipe()
require.NoError(t, err)
readDone := make(chan struct{})
go func() {
defer close(readDone)
s := bufio.NewScanner(stdout)
require.True(t, s.Scan())
txt := s.Text()
assert.Equal(t, "started", txt, "output corrupted")
}()

err = sess.Start(longScript)
require.NoError(t, err)

select {
case <-readDone:
// OK
case <-ctx.Done():
t.Fatal("read timeout")
}

// process is started, and should be sleeping for ~30 seconds
// close the session
err = sess.Close()
require.NoError(t, err)

// now, we wait for the handler to complete. If it does so before the
// main test timeout, we consider this a pass. If not, it indicates
// that the server isn't properly shutting down sessions when they are
// disconnected client side, which could lead to processes hanging around
// indefinitely.
handlerDone := make(chan struct{})
go func() {
defer close(handlerDone)
for {
select {
case <-time.After(time.Millisecond * 10):
s.mu.Lock()
n := len(s.sessions)
s.mu.Unlock()
if n == 0 {
return
}
}
}
}()

select {
case <-handlerDone:
// OK!
case <-ctx.Done():
t.Fatal("handler timeout")
}
}

// SSHTestClient creates an ssh.Client for testing
func SSHTestClient(t *testing.T, addr string) *ssh.Client {
conn, err := net.Dial("tcp", addr)
require.NoError(t, err)
t.Cleanup(func() {
_ = conn.Close()
})

sshConn, channels, requests, err := ssh.NewClientConn(conn, "localhost:22", &ssh.ClientConfig{
HostKeyCallback: ssh.InsecureIgnoreHostKey(), //nolint:gosec // This is a test.
})
require.NoError(t, err)
t.Cleanup(func() {
_ = sshConn.Close()
})
c := ssh.NewClient(sshConn, channels, requests)
t.Cleanup(func() {
_ = c.Close()
})
return c
}
29 changes: 3 additions & 26 deletions agent/agentssh/agentssh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,11 @@ import (
"sync"
"testing"

"cdr.dev/slog/sloggers/slogtest"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/atomic"
"go.uber.org/goleak"
"golang.org/x/crypto/ssh"

"cdr.dev/slog/sloggers/slogtest"

"github.com/coder/coder/agent/agentssh"
"github.com/coder/coder/codersdk/agentsdk"
Expand Down Expand Up @@ -49,7 +47,7 @@ func TestNewServer_ServeClient(t *testing.T) {
assert.Error(t, err) // Server is closed.
}()

c := sshClient(t, ln.Addr().String())
c := agentssh.SSHTestClient(t, ln.Addr().String())
var b bytes.Buffer
sess, err := c.NewSession()
sess.Stdout = &b
Expand Down Expand Up @@ -95,7 +93,7 @@ func TestNewServer_CloseActiveConnections(t *testing.T) {
doClose := make(chan struct{})
go func() {
defer wg.Done()
c := sshClient(t, ln.Addr().String())
c := agentssh.SSHTestClient(t, ln.Addr().String())
sess, err := c.NewSession()
sess.Stdin = pty.Input()
sess.Stdout = pty.Output()
Expand All @@ -116,24 +114,3 @@ func TestNewServer_CloseActiveConnections(t *testing.T) {

wg.Wait()
}

func sshClient(t *testing.T, addr string) *ssh.Client {
conn, err := net.Dial("tcp", addr)
require.NoError(t, err)
t.Cleanup(func() {
_ = conn.Close()
})

sshConn, channels, requests, err := ssh.NewClientConn(conn, "localhost:22", &ssh.ClientConfig{
HostKeyCallback: ssh.InsecureIgnoreHostKey(), //nolint:gosec // This is a test.
})
require.NoError(t, err)
t.Cleanup(func() {
_ = sshConn.Close()
})
c := ssh.NewClient(sshConn, channels, requests)
t.Cleanup(func() {
_ = c.Close()
})
return c
}