Skip to content

refactor: PTY & SSH #7100

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 22 commits into from
Apr 24, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix linting
Signed-off-by: Spike Curtis <spike@coder.com>
  • Loading branch information
spikecurtis committed Apr 12, 2023
commit 28c0646a9490131930bdfdc8d03f9ef417a625c0
12 changes: 3 additions & 9 deletions agent/agentssh/agentssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,11 +235,10 @@ func (s *Server) sessionStart(session ssh.Session) error {
if isPty {
return s.startPTYSession(session, cmd, sshPty, windowSize)
}

return s.startNonPTYSession(session, cmd)
return startNonPTYSession(session, cmd)
}

func (s *Server) startNonPTYSession(session ssh.Session, cmd *exec.Cmd) error {
func startNonPTYSession(session ssh.Session, cmd *exec.Cmd) error {
cmd.Stdout = session
cmd.Stderr = session.Stderr()
// This blocks forever until stdin is received if we don't
Expand Down Expand Up @@ -330,7 +329,7 @@ func (s *Server) startPTYSession(session ptySession, cmd *exec.Cmd, sshPty ssh.P
n, err := io.Copy(session, ptty.OutputReader())
s.logger.Debug(ctx, "copy output done", slog.F("bytes", n), slog.Error(err))
if err != nil {
return xerrors.Errorf("copy error: %w", err, err, err)
return xerrors.Errorf("copy error: %w", err)
}
// We've gotten all the output, but we need to wait for the process to
// complete so that we can get the exit code. This returns
Expand All @@ -348,11 +347,6 @@ func (s *Server) startPTYSession(session ptySession, cmd *exec.Cmd, sshPty ssh.P
return nil
}

type readNopCloser struct{ io.Reader }

// Close implements io.Closer.
func (readNopCloser) Close() error { return nil }

func (s *Server) sftpHandler(session ssh.Session) {
ctx := session.Context()

Expand Down
40 changes: 21 additions & 19 deletions agent/agentssh/agentssh_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@ import (
"net"
"os/exec"
"testing"
"time"

"cdr.dev/slog/sloggers/slogtest"
gliderssh "github.com/gliderlabs/ssh"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/coder/coder/testutil"

"cdr.dev/slog/sloggers/slogtest"
)

const longScript = `
Expand All @@ -28,7 +30,7 @@ echo "done"
func Test_sessionStart_orphan(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
logger := slogtest.Make(t, nil)
s, err := NewServer(ctx, logger, 0)
Expand Down Expand Up @@ -62,12 +64,12 @@ func Test_sessionStart_orphan(t *testing.T) {
go func() {
defer close(readDone)
s := bufio.NewScanner(toClient)
require.True(t, s.Scan())
assert.True(t, s.Scan())
txt := s.Text()
assert.Equal(t, "started", txt, "output corrupted")
}()

waitForChan(t, readDone, ctx, "read timeout")
waitForChan(ctx, t, readDone, "read timeout")
// process is started, and should be sleeping for ~30 seconds

sessionCancel()
Expand All @@ -77,13 +79,13 @@ func Test_sessionStart_orphan(t *testing.T) {
// that the server isn't properly shutting down sessions when they are
// disconnected client side, which could lead to processes hanging around
// indefinitely.
waitForChan(t, done, ctx, "handler timeout")
waitForChan(ctx, t, done, "handler timeout")

err = fromClient.Close()
require.NoError(t, err)
}

func waitForChan(t *testing.T, c <-chan struct{}, ctx context.Context, msg string) {
func waitForChan(ctx context.Context, t *testing.T, c <-chan struct{}, msg string) {
t.Helper()
select {
case <-c:
Expand Down Expand Up @@ -121,12 +123,12 @@ func (s *testSession) Context() gliderssh.Context {
return s.ctx
}

func (s *testSession) DisablePTYEmulation() {}
func (*testSession) DisablePTYEmulation() {}

// RawCommand returns "quiet logon" so that the PTY handler doesn't attempt to
// write the message of the day, which will interfere with our tests. It writes
// the message of the day if it's a shell login (zero length RawCommand()).
func (s *testSession) RawCommand() string { return "quiet logon" }
func (*testSession) RawCommand() string { return "quiet logon" }

func (s *testSession) Read(p []byte) (n int, err error) {
return s.toPty.Read(p)
Expand All @@ -136,49 +138,49 @@ func (s *testSession) Write(p []byte) (n int, err error) {
return s.fromPty.Write(p)
}

func (c testSSHContext) Lock() {
func (testSSHContext) Lock() {
panic("not implemented")
}
func (c testSSHContext) Unlock() {
func (testSSHContext) Unlock() {
panic("not implemented")
}

// User returns the username used when establishing the SSH connection.
func (c testSSHContext) User() string {
func (testSSHContext) User() string {
panic("not implemented")
}

// SessionID returns the session hash.
func (c testSSHContext) SessionID() string {
func (testSSHContext) SessionID() string {
panic("not implemented")
}

// ClientVersion returns the version reported by the client.
func (c testSSHContext) ClientVersion() string {
func (testSSHContext) ClientVersion() string {
panic("not implemented")
}

// ServerVersion returns the version reported by the server.
func (c testSSHContext) ServerVersion() string {
func (testSSHContext) ServerVersion() string {
panic("not implemented")
}

// RemoteAddr returns the remote address for this connection.
func (c testSSHContext) RemoteAddr() net.Addr {
func (testSSHContext) RemoteAddr() net.Addr {
panic("not implemented")
}

// LocalAddr returns the local address for this connection.
func (c testSSHContext) LocalAddr() net.Addr {
func (testSSHContext) LocalAddr() net.Addr {
panic("not implemented")
}

// Permissions returns the Permissions object used for this connection.
func (c testSSHContext) Permissions() *gliderssh.Permissions {
func (testSSHContext) Permissions() *gliderssh.Permissions {
panic("not implemented")
}

// SetValue allows you to easily write new values into the underlying context.
func (c testSSHContext) SetValue(key, value interface{}) {
func (testSSHContext) SetValue(_, _ interface{}) {
panic("not implemented")
}
3 changes: 2 additions & 1 deletion agent/agentssh/agentssh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@ import (
"sync"
"testing"

"cdr.dev/slog/sloggers/slogtest"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/atomic"
"go.uber.org/goleak"
"golang.org/x/crypto/ssh"

"cdr.dev/slog/sloggers/slogtest"

"github.com/coder/coder/agent/agentssh"
"github.com/coder/coder/codersdk/agentsdk"
"github.com/coder/coder/pty/ptytest"
Expand Down
1 change: 1 addition & 0 deletions pty/pty.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ var ErrClosed = xerrors.New("pty: closed")

// PTYCmd is an interface for interacting with a pseudo-TTY where we control
// only one end, and the other end has been passed to a running os.Process.
// nolint:revive
type PTYCmd interface {
io.Closer

Expand Down