diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index 5903220975b8c..081056b4f4ebd 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -79,9 +79,9 @@ type Config struct { // where users will land when they connect via SSH. Default is the home // directory of the user. WorkingDirectory func() string - // X11SocketDir is the directory where X11 sockets are created. Default is - // /tmp/.X11-unix. - X11SocketDir string + // X11DisplayOffset is the offset to add to the X11 display number. + // Default is 10. + X11DisplayOffset *int // BlockFileTransfer restricts use of file transfer applications. BlockFileTransfer bool } @@ -124,8 +124,9 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom if config == nil { config = &Config{} } - if config.X11SocketDir == "" { - config.X11SocketDir = filepath.Join(os.TempDir(), ".X11-unix") + if config.X11DisplayOffset == nil { + offset := X11DefaultDisplayOffset + config.X11DisplayOffset = &offset } if config.UpdateEnv == nil { config.UpdateEnv = func(current []string) ([]string, error) { return current, nil } @@ -273,13 +274,13 @@ func (s *Server) sessionHandler(session ssh.Session) { extraEnv := make([]string, 0) x11, hasX11 := session.X11() if hasX11 { - handled := s.x11Handler(session.Context(), x11) + display, handled := s.x11Handler(session.Context(), x11) if !handled { _ = session.Exit(1) logger.Error(ctx, "x11 handler failed") return } - extraEnv = append(extraEnv, fmt.Sprintf("DISPLAY=:%d.0", x11.ScreenNumber)) + extraEnv = append(extraEnv, fmt.Sprintf("DISPLAY=localhost:%d.%d", display, x11.ScreenNumber)) } if s.fileTransferBlocked(session) { diff --git a/agent/agentssh/x11.go b/agent/agentssh/x11.go index 2b083fbf049b7..90ec34201bbd0 100644 --- a/agent/agentssh/x11.go +++ b/agent/agentssh/x11.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "math" "net" "os" "path/filepath" @@ -22,61 +23,69 @@ import ( "cdr.dev/slog" ) +const ( + // X11StartPort is the starting port for X11 forwarding, this is the + // port used for "DISPLAY=localhost:0". + X11StartPort = 6000 + // X11DefaultDisplayOffset is the default offset for X11 forwarding. + X11DefaultDisplayOffset = 10 +) + // x11Callback is called when the client requests X11 forwarding. -// It adds an Xauthority entry to the Xauthority file. -func (s *Server) x11Callback(ctx ssh.Context, x11 ssh.X11) bool { +func (*Server) x11Callback(_ ssh.Context, _ ssh.X11) bool { + // Always allow. + return true +} + +// x11Handler is called when a session has requested X11 forwarding. +// It listens for X11 connections and forwards them to the client. +func (s *Server) x11Handler(ctx ssh.Context, x11 ssh.X11) (displayNumber int, handled bool) { + serverConn, valid := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn) + if !valid { + s.logger.Warn(ctx, "failed to get server connection") + return -1, false + } + hostname, err := os.Hostname() if err != nil { s.logger.Warn(ctx, "failed to get hostname", slog.Error(err)) s.metrics.x11HandlerErrors.WithLabelValues("hostname").Add(1) - return false + return -1, false } - err = s.fs.MkdirAll(s.config.X11SocketDir, 0o700) + ln, display, err := createX11Listener(ctx, *s.config.X11DisplayOffset) if err != nil { - s.logger.Warn(ctx, "failed to make the x11 socket dir", slog.F("dir", s.config.X11SocketDir), slog.Error(err)) - s.metrics.x11HandlerErrors.WithLabelValues("socker_dir").Add(1) - return false - } + s.logger.Warn(ctx, "failed to create X11 listener", slog.Error(err)) + s.metrics.x11HandlerErrors.WithLabelValues("listen").Add(1) + return -1, false + } + s.trackListener(ln, true) + defer func() { + if !handled { + s.trackListener(ln, false) + _ = ln.Close() + } + }() - err = addXauthEntry(ctx, s.fs, hostname, strconv.Itoa(int(x11.ScreenNumber)), x11.AuthProtocol, x11.AuthCookie) + err = addXauthEntry(ctx, s.fs, hostname, strconv.Itoa(display), x11.AuthProtocol, x11.AuthCookie) if err != nil { s.logger.Warn(ctx, "failed to add Xauthority entry", slog.Error(err)) s.metrics.x11HandlerErrors.WithLabelValues("xauthority").Add(1) - return false + return -1, false } - return true -} -// x11Handler is called when a session has requested X11 forwarding. -// It listens for X11 connections and forwards them to the client. -func (s *Server) x11Handler(ctx ssh.Context, x11 ssh.X11) bool { - serverConn, valid := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn) - if !valid { - s.logger.Warn(ctx, "failed to get server connection") - return false - } - // We want to overwrite the socket so that subsequent connections will succeed. - socketPath := filepath.Join(s.config.X11SocketDir, fmt.Sprintf("X%d", x11.ScreenNumber)) - err := os.Remove(socketPath) - if err != nil && !errors.Is(err, os.ErrNotExist) { - s.logger.Warn(ctx, "failed to remove existing X11 socket", slog.Error(err)) - return false - } - listener, err := net.Listen("unix", socketPath) - if err != nil { - s.logger.Warn(ctx, "failed to listen for X11", slog.Error(err)) - return false - } - s.trackListener(listener, true) + go func() { + // Don't leave the listener open after the session is gone. + <-ctx.Done() + _ = ln.Close() + }() go func() { - defer listener.Close() - defer s.trackListener(listener, false) - handledFirstConnection := false + defer ln.Close() + defer s.trackListener(ln, false) for { - conn, err := listener.Accept() + conn, err := ln.Accept() if err != nil { if errors.Is(err, net.ErrClosed) { return @@ -84,40 +93,66 @@ func (s *Server) x11Handler(ctx ssh.Context, x11 ssh.X11) bool { s.logger.Warn(ctx, "failed to accept X11 connection", slog.Error(err)) return } - if x11.SingleConnection && handledFirstConnection { - s.logger.Warn(ctx, "X11 connection rejected because single connection is enabled") - _ = conn.Close() - continue + if x11.SingleConnection { + s.logger.Debug(ctx, "single connection requested, closing X11 listener") + _ = ln.Close() } - handledFirstConnection = true - unixConn, ok := conn.(*net.UnixConn) + tcpConn, ok := conn.(*net.TCPConn) if !ok { - s.logger.Warn(ctx, fmt.Sprintf("failed to cast connection to UnixConn. got: %T", conn)) - return + s.logger.Warn(ctx, fmt.Sprintf("failed to cast connection to TCPConn. got: %T", conn)) + _ = conn.Close() + continue } - unixAddr, ok := unixConn.LocalAddr().(*net.UnixAddr) + tcpAddr, ok := tcpConn.LocalAddr().(*net.TCPAddr) if !ok { - s.logger.Warn(ctx, fmt.Sprintf("failed to cast local address to UnixAddr. got: %T", unixConn.LocalAddr())) - return + s.logger.Warn(ctx, fmt.Sprintf("failed to cast local address to TCPAddr. got: %T", tcpConn.LocalAddr())) + _ = conn.Close() + continue } channel, reqs, err := serverConn.OpenChannel("x11", gossh.Marshal(struct { OriginatorAddress string OriginatorPort uint32 }{ - OriginatorAddress: unixAddr.Name, - OriginatorPort: 0, + OriginatorAddress: tcpAddr.IP.String(), + OriginatorPort: uint32(tcpAddr.Port), })) if err != nil { s.logger.Warn(ctx, "failed to open X11 channel", slog.Error(err)) - return + _ = conn.Close() + continue } go gossh.DiscardRequests(reqs) - go Bicopy(ctx, conn, channel) + + if !s.trackConn(ln, conn, true) { + s.logger.Warn(ctx, "failed to track X11 connection") + _ = conn.Close() + continue + } + go func() { + defer s.trackConn(ln, conn, false) + Bicopy(ctx, conn, channel) + }() } }() - return true + + return display, true +} + +// createX11Listener creates a listener for X11 forwarding, it will use +// the next available port starting from X11StartPort and displayOffset. +func createX11Listener(ctx context.Context, displayOffset int) (ln net.Listener, display int, err error) { + var lc net.ListenConfig + // Look for an open port to listen on. + for port := X11StartPort + displayOffset; port < math.MaxUint16; port++ { + ln, err = lc.Listen(ctx, "tcp", fmt.Sprintf("localhost:%d", port)) + if err == nil { + display = port - X11StartPort + return ln, display, nil + } + } + return nil, -1, xerrors.Errorf("failed to find open port for X11 listener: %w", err) } // addXauthEntry adds an Xauthority entry to the Xauthority file. diff --git a/agent/agentssh/x11_test.go b/agent/agentssh/x11_test.go index da3c68c3e5d5b..932caeba596e7 100644 --- a/agent/agentssh/x11_test.go +++ b/agent/agentssh/x11_test.go @@ -1,12 +1,17 @@ package agentssh_test import ( + "bufio" + "bytes" "context" "encoding/hex" + "fmt" "net" "os" "path/filepath" "runtime" + "strconv" + "strings" "testing" "github.com/gliderlabs/ssh" @@ -31,10 +36,7 @@ func TestServer_X11(t *testing.T) { ctx := context.Background() logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) fs := afero.NewOsFs() - dir := t.TempDir() - s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, &agentssh.Config{ - X11SocketDir: dir, - }) + s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, &agentssh.Config{}) require.NoError(t, err) defer s.Close() @@ -53,21 +55,45 @@ func TestServer_X11(t *testing.T) { sess, err := c.NewSession() require.NoError(t, err) + wantScreenNumber := 1 reply, err := sess.SendRequest("x11-req", true, gossh.Marshal(ssh.X11{ AuthProtocol: "MIT-MAGIC-COOKIE-1", AuthCookie: hex.EncodeToString([]byte("cookie")), - ScreenNumber: 0, + ScreenNumber: uint32(wantScreenNumber), })) require.NoError(t, err) assert.True(t, reply) - err = sess.Shell() + // Want: ~DISPLAY=localhost:10.1 + out, err := sess.Output("echo DISPLAY=$DISPLAY") require.NoError(t, err) + sc := bufio.NewScanner(bytes.NewReader(out)) + displayNumber := -1 + for sc.Scan() { + line := strings.TrimSpace(sc.Text()) + t.Log(line) + if strings.HasPrefix(line, "DISPLAY=") { + parts := strings.SplitN(line, "=", 2) + display := parts[1] + parts = strings.SplitN(display, ":", 2) + parts = strings.SplitN(parts[1], ".", 2) + displayNumber, err = strconv.Atoi(parts[0]) + require.NoError(t, err) + assert.GreaterOrEqual(t, displayNumber, 10, "display number should be >= 10") + gotScreenNumber, err := strconv.Atoi(parts[1]) + require.NoError(t, err) + assert.Equal(t, wantScreenNumber, gotScreenNumber, "screen number should match") + break + } + } + require.NoError(t, sc.Err()) + require.NotEqual(t, -1, displayNumber) + x11Chans := c.HandleChannelOpen("x11") payload := "hello world" require.Eventually(t, func() bool { - conn, err := net.Dial("unix", filepath.Join(dir, "X0")) + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", agentssh.X11StartPort+displayNumber)) if err == nil { _, err = conn.Write([]byte(payload)) assert.NoError(t, err)