Skip to content

feat(agent/agentssh): use tcp for X11 forwarding #14560

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
feat(agent/agentssh): use tcp for X11 forwarding
Fixes #14198
  • Loading branch information
mafredri committed Sep 4, 2024
commit f6f2964620b2b3a0a2c6fd6ba2c28e318ba9d198
15 changes: 8 additions & 7 deletions agent/agentssh/agentssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,9 @@ type Config struct {
// where users will land when they connect via SSH. Default is the home
// directory of the user.
WorkingDirectory func() string
// X11SocketDir is the directory where X11 sockets are created. Default is
// /tmp/.X11-unix.
X11SocketDir string
// X11DisplayOffset is the offset to add to the X11 display number.
// Default is 10.
X11DisplayOffset *int
// BlockFileTransfer restricts use of file transfer applications.
BlockFileTransfer bool
}
Expand Down Expand Up @@ -124,8 +124,9 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
if config == nil {
config = &Config{}
}
if config.X11SocketDir == "" {
config.X11SocketDir = filepath.Join(os.TempDir(), ".X11-unix")
if config.X11DisplayOffset == nil {
offset := X11DefaultDisplayOffset
config.X11DisplayOffset = &offset
}
if config.UpdateEnv == nil {
config.UpdateEnv = func(current []string) ([]string, error) { return current, nil }
Expand Down Expand Up @@ -273,13 +274,13 @@ func (s *Server) sessionHandler(session ssh.Session) {
extraEnv := make([]string, 0)
x11, hasX11 := session.X11()
if hasX11 {
handled := s.x11Handler(session.Context(), x11)
display, handled := s.x11Handler(session.Context(), x11)
if !handled {
_ = session.Exit(1)
logger.Error(ctx, "x11 handler failed")
return
}
extraEnv = append(extraEnv, fmt.Sprintf("DISPLAY=:%d.0", x11.ScreenNumber))
extraEnv = append(extraEnv, fmt.Sprintf("DISPLAY=localhost:%d.%d", display, x11.ScreenNumber))
}

if s.fileTransferBlocked(session) {
Expand Down
140 changes: 86 additions & 54 deletions agent/agentssh/x11.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"io"
"math"
"net"
"os"
"path/filepath"
Expand All @@ -22,102 +23,133 @@ import (
"cdr.dev/slog"
)

// x11Callback is called when the client requests X11 forwarding.
// It adds an Xauthority entry to the Xauthority file.
func (s *Server) x11Callback(ctx ssh.Context, x11 ssh.X11) bool {
hostname, err := os.Hostname()
if err != nil {
s.logger.Warn(ctx, "failed to get hostname", slog.Error(err))
s.metrics.x11HandlerErrors.WithLabelValues("hostname").Add(1)
return false
}

err = s.fs.MkdirAll(s.config.X11SocketDir, 0o700)
if err != nil {
s.logger.Warn(ctx, "failed to make the x11 socket dir", slog.F("dir", s.config.X11SocketDir), slog.Error(err))
s.metrics.x11HandlerErrors.WithLabelValues("socker_dir").Add(1)
return false
}
const (
// X11StartPort is the starting port for X11 forwarding, this is the
// port used for "DISPLAY=localhost:0".
X11StartPort = 6000
// X11DefaultDisplayOffset is the default offset for X11 forwarding.
X11DefaultDisplayOffset = 10
)

err = addXauthEntry(ctx, s.fs, hostname, strconv.Itoa(int(x11.ScreenNumber)), x11.AuthProtocol, x11.AuthCookie)
if err != nil {
s.logger.Warn(ctx, "failed to add Xauthority entry", slog.Error(err))
s.metrics.x11HandlerErrors.WithLabelValues("xauthority").Add(1)
return false
}
// x11Callback is called when the client requests X11 forwarding.
func (*Server) x11Callback(_ ssh.Context, _ ssh.X11) bool {
// Always allow.
return true
}

// x11Handler is called when a session has requested X11 forwarding.
// It listens for X11 connections and forwards them to the client.
func (s *Server) x11Handler(ctx ssh.Context, x11 ssh.X11) bool {
func (s *Server) x11Handler(ctx ssh.Context, x11 ssh.X11) (display int, handled bool) {
serverConn, valid := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn)
if !valid {
s.logger.Warn(ctx, "failed to get server connection")
return false
return -1, false
}
// We want to overwrite the socket so that subsequent connections will succeed.
socketPath := filepath.Join(s.config.X11SocketDir, fmt.Sprintf("X%d", x11.ScreenNumber))
err := os.Remove(socketPath)
if err != nil && !errors.Is(err, os.ErrNotExist) {
s.logger.Warn(ctx, "failed to remove existing X11 socket", slog.Error(err))
return false
}
listener, err := net.Listen("unix", socketPath)

hostname, err := os.Hostname()
if err != nil {
s.logger.Warn(ctx, "failed to get hostname", slog.Error(err))
s.metrics.x11HandlerErrors.WithLabelValues("hostname").Add(1)
return -1, false
}

var (
lc net.ListenConfig
ln net.Listener
port = X11StartPort + *s.config.X11DisplayOffset
)
// Look for an open port to listen on..
for ; port >= X11StartPort && port < math.MaxUint16; port++ {
ln, err = lc.Listen(ctx, "tcp", fmt.Sprintf("localhost:%d", port))
if err == nil {
display = port - X11StartPort
break
}
}
if ln == nil {
s.logger.Warn(ctx, "failed to listen for X11", slog.Error(err))
return false
s.metrics.x11HandlerErrors.WithLabelValues("listen").Add(1)
return -1, false
}
s.trackListener(ln, true)
defer func() {
if !handled {
s.trackListener(ln, false)
_ = ln.Close()
}
}()

err = addXauthEntry(ctx, s.fs, hostname, strconv.Itoa(display), x11.AuthProtocol, x11.AuthCookie)
if err != nil {
s.logger.Warn(ctx, "failed to add Xauthority entry", slog.Error(err))
s.metrics.x11HandlerErrors.WithLabelValues("xauthority").Add(1)
return -1, false
}
s.trackListener(listener, true)

go func() {
defer listener.Close()
defer s.trackListener(listener, false)
handledFirstConnection := false
// Don't leave the listener open after the session is gone.
<-ctx.Done()
_ = ln.Close()
}()

go func() {
defer ln.Close()
defer s.trackListener(ln, false)

for {
conn, err := listener.Accept()
conn, err := ln.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return
}
s.logger.Warn(ctx, "failed to accept X11 connection", slog.Error(err))
return
}
if x11.SingleConnection && handledFirstConnection {
s.logger.Warn(ctx, "X11 connection rejected because single connection is enabled")
_ = conn.Close()
continue
if x11.SingleConnection {
s.logger.Debug(ctx, "single connection requested, closing X11 listener")
_ = ln.Close()
}
handledFirstConnection = true

unixConn, ok := conn.(*net.UnixConn)
tcpConn, ok := conn.(*net.TCPConn)
if !ok {
s.logger.Warn(ctx, fmt.Sprintf("failed to cast connection to UnixConn. got: %T", conn))
return
s.logger.Warn(ctx, fmt.Sprintf("failed to cast connection to TCPConn. got: %T", conn))
_ = conn.Close()
continue
}
unixAddr, ok := unixConn.LocalAddr().(*net.UnixAddr)
tcpAddr, ok := tcpConn.LocalAddr().(*net.TCPAddr)
if !ok {
s.logger.Warn(ctx, fmt.Sprintf("failed to cast local address to UnixAddr. got: %T", unixConn.LocalAddr()))
return
s.logger.Warn(ctx, fmt.Sprintf("failed to cast local address to TCPAddr. got: %T", tcpConn.LocalAddr()))
_ = conn.Close()
continue
}

channel, reqs, err := serverConn.OpenChannel("x11", gossh.Marshal(struct {
OriginatorAddress string
OriginatorPort uint32
}{
OriginatorAddress: unixAddr.Name,
OriginatorPort: 0,
OriginatorAddress: tcpAddr.IP.String(),
OriginatorPort: uint32(tcpAddr.Port),
}))
if err != nil {
s.logger.Warn(ctx, "failed to open X11 channel", slog.Error(err))
return
_ = conn.Close()
continue
}
go gossh.DiscardRequests(reqs)
go Bicopy(ctx, conn, channel)

if !s.trackConn(ln, conn, true) {
s.logger.Warn(ctx, "failed to track X11 connection")
_ = conn.Close()
continue
}
go func() {
defer s.trackConn(ln, conn, false)
Bicopy(ctx, conn, channel)
}()
}
}()
return true

return display, true
}

// addXauthEntry adds an Xauthority entry to the Xauthority file.
Expand Down
40 changes: 33 additions & 7 deletions agent/agentssh/x11_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
package agentssh_test

import (
"bufio"
"bytes"
"context"
"encoding/hex"
"fmt"
"net"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"testing"

"github.com/gliderlabs/ssh"
Expand All @@ -31,10 +36,7 @@ func TestServer_X11(t *testing.T) {
ctx := context.Background()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
fs := afero.NewOsFs()
dir := t.TempDir()
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, &agentssh.Config{
X11SocketDir: dir,
})
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, &agentssh.Config{})
require.NoError(t, err)
defer s.Close()

Expand All @@ -53,21 +55,45 @@ func TestServer_X11(t *testing.T) {
sess, err := c.NewSession()
require.NoError(t, err)

wantScreenNumber := 1
reply, err := sess.SendRequest("x11-req", true, gossh.Marshal(ssh.X11{
AuthProtocol: "MIT-MAGIC-COOKIE-1",
AuthCookie: hex.EncodeToString([]byte("cookie")),
ScreenNumber: 0,
ScreenNumber: uint32(wantScreenNumber),
}))
require.NoError(t, err)
assert.True(t, reply)

err = sess.Shell()
// Want: ~DISPLAY=localhost:10.1
out, err := sess.Output("echo DISPLAY=$DISPLAY")
require.NoError(t, err)

sc := bufio.NewScanner(bytes.NewReader(out))
displayNumber := -1
for sc.Scan() {
line := strings.TrimSpace(sc.Text())
t.Log(line)
if strings.HasPrefix(line, "DISPLAY=") {
parts := strings.SplitN(line, "=", 2)
display := parts[1]
parts = strings.SplitN(display, ":", 2)
parts = strings.SplitN(parts[1], ".", 2)
displayNumber, err = strconv.Atoi(parts[0])
require.NoError(t, err)
assert.GreaterOrEqual(t, displayNumber, 10, "display number should be >= 10")
gotScreenNumber, err := strconv.Atoi(parts[1])
require.NoError(t, err)
assert.Equal(t, wantScreenNumber, gotScreenNumber, "screen number should match")
break
}
}
require.NoError(t, sc.Err())
require.NotEqual(t, -1, displayNumber)

x11Chans := c.HandleChannelOpen("x11")
payload := "hello world"
require.Eventually(t, func() bool {
conn, err := net.Dial("unix", filepath.Join(dir, "X0"))
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", agentssh.X11StartPort+displayNumber))
if err == nil {
_, err = conn.Write([]byte(payload))
assert.NoError(t, err)
Expand Down
Loading