From f6f2964620b2b3a0a2c6fd6ba2c28e318ba9d198 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Wed, 4 Sep 2024 12:50:35 +0000 Subject: [PATCH 1/3] feat(agent/agentssh): use tcp for X11 forwarding Fixes #14198 --- agent/agentssh/agentssh.go | 15 ++-- agent/agentssh/x11.go | 140 +++++++++++++++++++++++-------------- agent/agentssh/x11_test.go | 40 +++++++++-- 3 files changed, 127 insertions(+), 68 deletions(-) diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index 5903220975b8c..081056b4f4ebd 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -79,9 +79,9 @@ type Config struct { // where users will land when they connect via SSH. Default is the home // directory of the user. WorkingDirectory func() string - // X11SocketDir is the directory where X11 sockets are created. Default is - // /tmp/.X11-unix. - X11SocketDir string + // X11DisplayOffset is the offset to add to the X11 display number. + // Default is 10. + X11DisplayOffset *int // BlockFileTransfer restricts use of file transfer applications. BlockFileTransfer bool } @@ -124,8 +124,9 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom if config == nil { config = &Config{} } - if config.X11SocketDir == "" { - config.X11SocketDir = filepath.Join(os.TempDir(), ".X11-unix") + if config.X11DisplayOffset == nil { + offset := X11DefaultDisplayOffset + config.X11DisplayOffset = &offset } if config.UpdateEnv == nil { config.UpdateEnv = func(current []string) ([]string, error) { return current, nil } @@ -273,13 +274,13 @@ func (s *Server) sessionHandler(session ssh.Session) { extraEnv := make([]string, 0) x11, hasX11 := session.X11() if hasX11 { - handled := s.x11Handler(session.Context(), x11) + display, handled := s.x11Handler(session.Context(), x11) if !handled { _ = session.Exit(1) logger.Error(ctx, "x11 handler failed") return } - extraEnv = append(extraEnv, fmt.Sprintf("DISPLAY=:%d.0", x11.ScreenNumber)) + extraEnv = append(extraEnv, fmt.Sprintf("DISPLAY=localhost:%d.%d", display, x11.ScreenNumber)) } if s.fileTransferBlocked(session) { diff --git a/agent/agentssh/x11.go b/agent/agentssh/x11.go index 2b083fbf049b7..a4b494f92e798 100644 --- a/agent/agentssh/x11.go +++ b/agent/agentssh/x11.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "math" "net" "os" "path/filepath" @@ -22,61 +23,81 @@ import ( "cdr.dev/slog" ) -// x11Callback is called when the client requests X11 forwarding. -// It adds an Xauthority entry to the Xauthority file. -func (s *Server) x11Callback(ctx ssh.Context, x11 ssh.X11) bool { - hostname, err := os.Hostname() - if err != nil { - s.logger.Warn(ctx, "failed to get hostname", slog.Error(err)) - s.metrics.x11HandlerErrors.WithLabelValues("hostname").Add(1) - return false - } - - err = s.fs.MkdirAll(s.config.X11SocketDir, 0o700) - if err != nil { - s.logger.Warn(ctx, "failed to make the x11 socket dir", slog.F("dir", s.config.X11SocketDir), slog.Error(err)) - s.metrics.x11HandlerErrors.WithLabelValues("socker_dir").Add(1) - return false - } +const ( + // X11StartPort is the starting port for X11 forwarding, this is the + // port used for "DISPLAY=localhost:0". + X11StartPort = 6000 + // X11DefaultDisplayOffset is the default offset for X11 forwarding. + X11DefaultDisplayOffset = 10 +) - err = addXauthEntry(ctx, s.fs, hostname, strconv.Itoa(int(x11.ScreenNumber)), x11.AuthProtocol, x11.AuthCookie) - if err != nil { - s.logger.Warn(ctx, "failed to add Xauthority entry", slog.Error(err)) - s.metrics.x11HandlerErrors.WithLabelValues("xauthority").Add(1) - return false - } +// x11Callback is called when the client requests X11 forwarding. +func (*Server) x11Callback(_ ssh.Context, _ ssh.X11) bool { + // Always allow. return true } // x11Handler is called when a session has requested X11 forwarding. // It listens for X11 connections and forwards them to the client. -func (s *Server) x11Handler(ctx ssh.Context, x11 ssh.X11) bool { +func (s *Server) x11Handler(ctx ssh.Context, x11 ssh.X11) (display int, handled bool) { serverConn, valid := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn) if !valid { s.logger.Warn(ctx, "failed to get server connection") - return false + return -1, false } - // We want to overwrite the socket so that subsequent connections will succeed. - socketPath := filepath.Join(s.config.X11SocketDir, fmt.Sprintf("X%d", x11.ScreenNumber)) - err := os.Remove(socketPath) - if err != nil && !errors.Is(err, os.ErrNotExist) { - s.logger.Warn(ctx, "failed to remove existing X11 socket", slog.Error(err)) - return false - } - listener, err := net.Listen("unix", socketPath) + + hostname, err := os.Hostname() if err != nil { + s.logger.Warn(ctx, "failed to get hostname", slog.Error(err)) + s.metrics.x11HandlerErrors.WithLabelValues("hostname").Add(1) + return -1, false + } + + var ( + lc net.ListenConfig + ln net.Listener + port = X11StartPort + *s.config.X11DisplayOffset + ) + // Look for an open port to listen on.. + for ; port >= X11StartPort && port < math.MaxUint16; port++ { + ln, err = lc.Listen(ctx, "tcp", fmt.Sprintf("localhost:%d", port)) + if err == nil { + display = port - X11StartPort + break + } + } + if ln == nil { s.logger.Warn(ctx, "failed to listen for X11", slog.Error(err)) - return false + s.metrics.x11HandlerErrors.WithLabelValues("listen").Add(1) + return -1, false + } + s.trackListener(ln, true) + defer func() { + if !handled { + s.trackListener(ln, false) + _ = ln.Close() + } + }() + + err = addXauthEntry(ctx, s.fs, hostname, strconv.Itoa(display), x11.AuthProtocol, x11.AuthCookie) + if err != nil { + s.logger.Warn(ctx, "failed to add Xauthority entry", slog.Error(err)) + s.metrics.x11HandlerErrors.WithLabelValues("xauthority").Add(1) + return -1, false } - s.trackListener(listener, true) go func() { - defer listener.Close() - defer s.trackListener(listener, false) - handledFirstConnection := false + // Don't leave the listener open after the session is gone. + <-ctx.Done() + _ = ln.Close() + }() + + go func() { + defer ln.Close() + defer s.trackListener(ln, false) for { - conn, err := listener.Accept() + conn, err := ln.Accept() if err != nil { if errors.Is(err, net.ErrClosed) { return @@ -84,40 +105,51 @@ func (s *Server) x11Handler(ctx ssh.Context, x11 ssh.X11) bool { s.logger.Warn(ctx, "failed to accept X11 connection", slog.Error(err)) return } - if x11.SingleConnection && handledFirstConnection { - s.logger.Warn(ctx, "X11 connection rejected because single connection is enabled") - _ = conn.Close() - continue + if x11.SingleConnection { + s.logger.Debug(ctx, "single connection requested, closing X11 listener") + _ = ln.Close() } - handledFirstConnection = true - unixConn, ok := conn.(*net.UnixConn) + tcpConn, ok := conn.(*net.TCPConn) if !ok { - s.logger.Warn(ctx, fmt.Sprintf("failed to cast connection to UnixConn. got: %T", conn)) - return + s.logger.Warn(ctx, fmt.Sprintf("failed to cast connection to TCPConn. got: %T", conn)) + _ = conn.Close() + continue } - unixAddr, ok := unixConn.LocalAddr().(*net.UnixAddr) + tcpAddr, ok := tcpConn.LocalAddr().(*net.TCPAddr) if !ok { - s.logger.Warn(ctx, fmt.Sprintf("failed to cast local address to UnixAddr. got: %T", unixConn.LocalAddr())) - return + s.logger.Warn(ctx, fmt.Sprintf("failed to cast local address to TCPAddr. got: %T", tcpConn.LocalAddr())) + _ = conn.Close() + continue } channel, reqs, err := serverConn.OpenChannel("x11", gossh.Marshal(struct { OriginatorAddress string OriginatorPort uint32 }{ - OriginatorAddress: unixAddr.Name, - OriginatorPort: 0, + OriginatorAddress: tcpAddr.IP.String(), + OriginatorPort: uint32(tcpAddr.Port), })) if err != nil { s.logger.Warn(ctx, "failed to open X11 channel", slog.Error(err)) - return + _ = conn.Close() + continue } go gossh.DiscardRequests(reqs) - go Bicopy(ctx, conn, channel) + + if !s.trackConn(ln, conn, true) { + s.logger.Warn(ctx, "failed to track X11 connection") + _ = conn.Close() + continue + } + go func() { + defer s.trackConn(ln, conn, false) + Bicopy(ctx, conn, channel) + }() } }() - return true + + return display, true } // addXauthEntry adds an Xauthority entry to the Xauthority file. diff --git a/agent/agentssh/x11_test.go b/agent/agentssh/x11_test.go index da3c68c3e5d5b..932caeba596e7 100644 --- a/agent/agentssh/x11_test.go +++ b/agent/agentssh/x11_test.go @@ -1,12 +1,17 @@ package agentssh_test import ( + "bufio" + "bytes" "context" "encoding/hex" + "fmt" "net" "os" "path/filepath" "runtime" + "strconv" + "strings" "testing" "github.com/gliderlabs/ssh" @@ -31,10 +36,7 @@ func TestServer_X11(t *testing.T) { ctx := context.Background() logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) fs := afero.NewOsFs() - dir := t.TempDir() - s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, &agentssh.Config{ - X11SocketDir: dir, - }) + s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, &agentssh.Config{}) require.NoError(t, err) defer s.Close() @@ -53,21 +55,45 @@ func TestServer_X11(t *testing.T) { sess, err := c.NewSession() require.NoError(t, err) + wantScreenNumber := 1 reply, err := sess.SendRequest("x11-req", true, gossh.Marshal(ssh.X11{ AuthProtocol: "MIT-MAGIC-COOKIE-1", AuthCookie: hex.EncodeToString([]byte("cookie")), - ScreenNumber: 0, + ScreenNumber: uint32(wantScreenNumber), })) require.NoError(t, err) assert.True(t, reply) - err = sess.Shell() + // Want: ~DISPLAY=localhost:10.1 + out, err := sess.Output("echo DISPLAY=$DISPLAY") require.NoError(t, err) + sc := bufio.NewScanner(bytes.NewReader(out)) + displayNumber := -1 + for sc.Scan() { + line := strings.TrimSpace(sc.Text()) + t.Log(line) + if strings.HasPrefix(line, "DISPLAY=") { + parts := strings.SplitN(line, "=", 2) + display := parts[1] + parts = strings.SplitN(display, ":", 2) + parts = strings.SplitN(parts[1], ".", 2) + displayNumber, err = strconv.Atoi(parts[0]) + require.NoError(t, err) + assert.GreaterOrEqual(t, displayNumber, 10, "display number should be >= 10") + gotScreenNumber, err := strconv.Atoi(parts[1]) + require.NoError(t, err) + assert.Equal(t, wantScreenNumber, gotScreenNumber, "screen number should match") + break + } + } + require.NoError(t, sc.Err()) + require.NotEqual(t, -1, displayNumber) + x11Chans := c.HandleChannelOpen("x11") payload := "hello world" require.Eventually(t, func() bool { - conn, err := net.Dial("unix", filepath.Join(dir, "X0")) + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", agentssh.X11StartPort+displayNumber)) if err == nil { _, err = conn.Write([]byte(payload)) assert.NoError(t, err) From 6fc49a7aeaa39023f979f18702affa24977f223c Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Wed, 4 Sep 2024 16:11:23 +0000 Subject: [PATCH 2/3] split out createX11Listener --- agent/agentssh/x11.go | 36 ++++++++++++++++++++---------------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/agent/agentssh/x11.go b/agent/agentssh/x11.go index a4b494f92e798..dbb2745ba738a 100644 --- a/agent/agentssh/x11.go +++ b/agent/agentssh/x11.go @@ -39,7 +39,7 @@ func (*Server) x11Callback(_ ssh.Context, _ ssh.X11) bool { // x11Handler is called when a session has requested X11 forwarding. // It listens for X11 connections and forwards them to the client. -func (s *Server) x11Handler(ctx ssh.Context, x11 ssh.X11) (display int, handled bool) { +func (s *Server) x11Handler(ctx ssh.Context, x11 ssh.X11) (displayNumber int, handled bool) { serverConn, valid := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn) if !valid { s.logger.Warn(ctx, "failed to get server connection") @@ -53,24 +53,13 @@ func (s *Server) x11Handler(ctx ssh.Context, x11 ssh.X11) (display int, handled return -1, false } - var ( - lc net.ListenConfig - ln net.Listener - port = X11StartPort + *s.config.X11DisplayOffset - ) - // Look for an open port to listen on.. - for ; port >= X11StartPort && port < math.MaxUint16; port++ { - ln, err = lc.Listen(ctx, "tcp", fmt.Sprintf("localhost:%d", port)) - if err == nil { - display = port - X11StartPort - break - } - } - if ln == nil { - s.logger.Warn(ctx, "failed to listen for X11", slog.Error(err)) + ln, display, err := createX11Listener(ctx, *s.config.X11DisplayOffset) + if err != nil { + s.logger.Warn(ctx, "failed to create X11 listener", slog.Error(err)) s.metrics.x11HandlerErrors.WithLabelValues("listen").Add(1) return -1, false } + s.trackListener(ln, true) defer func() { if !handled { @@ -152,6 +141,21 @@ func (s *Server) x11Handler(ctx ssh.Context, x11 ssh.X11) (display int, handled return display, true } +// createX11Listener creates a listener for X11 forwarding, it will use +// the next available port starting from X11StartPort and displayOffset. +func createX11Listener(ctx context.Context, displayOffset int) (ln net.Listener, display int, err error) { + var lc net.ListenConfig + // Look for an open port to listen on. + for port := X11StartPort + displayOffset; port < math.MaxUint16; port++ { + ln, err = lc.Listen(ctx, "tcp", fmt.Sprintf("localhost:%d", port)) + if err == nil { + display = port - X11StartPort + return ln, display, nil + } + } + return nil, -1, xerrors.Errorf("failed to find open port for X11 listener: %w", err) +} + // addXauthEntry adds an Xauthority entry to the Xauthority file. // The Xauthority file is located at ~/.Xauthority. func addXauthEntry(ctx context.Context, fs afero.Fs, host string, display string, authProtocol string, authCookie string) error { From 3b78fd13f9b8dc5df00c57de88c35db292fee909 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Wed, 4 Sep 2024 16:14:04 +0000 Subject: [PATCH 3/3] del nl --- agent/agentssh/x11.go | 1 - 1 file changed, 1 deletion(-) diff --git a/agent/agentssh/x11.go b/agent/agentssh/x11.go index dbb2745ba738a..90ec34201bbd0 100644 --- a/agent/agentssh/x11.go +++ b/agent/agentssh/x11.go @@ -59,7 +59,6 @@ func (s *Server) x11Handler(ctx ssh.Context, x11 ssh.X11) (displayNumber int, ha s.metrics.x11HandlerErrors.WithLabelValues("listen").Add(1) return -1, false } - s.trackListener(ln, true) defer func() { if !handled {