Skip to content

Commit 9d3f404

Browse files
committed
feat(agent/agentssh): use tcp for X11 forwarding
Fixes #14198
1 parent 5366f25 commit 9d3f404

File tree

3 files changed

+105
-58
lines changed

3 files changed

+105
-58
lines changed

agent/agentssh/agentssh.go

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,9 @@ type Config struct {
7979
// where users will land when they connect via SSH. Default is the home
8080
// directory of the user.
8181
WorkingDirectory func() string
82-
// X11SocketDir is the directory where X11 sockets are created. Default is
83-
// /tmp/.X11-unix.
84-
X11SocketDir string
82+
// X11DisplayOffset is the offset to add to the X11 display number.
83+
// Default is 10.
84+
X11DisplayOffset *int
8585
// BlockFileTransfer restricts use of file transfer applications.
8686
BlockFileTransfer bool
8787
}
@@ -124,8 +124,9 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
124124
if config == nil {
125125
config = &Config{}
126126
}
127-
if config.X11SocketDir == "" {
128-
config.X11SocketDir = filepath.Join(os.TempDir(), ".X11-unix")
127+
if config.X11DisplayOffset == nil {
128+
offset := X11DefaultDisplayOffset
129+
config.X11DisplayOffset = &offset
129130
}
130131
if config.UpdateEnv == nil {
131132
config.UpdateEnv = func(current []string) ([]string, error) { return current, nil }
@@ -273,13 +274,13 @@ func (s *Server) sessionHandler(session ssh.Session) {
273274
extraEnv := make([]string, 0)
274275
x11, hasX11 := session.X11()
275276
if hasX11 {
276-
handled := s.x11Handler(session.Context(), x11)
277+
display, handled := s.x11Handler(session.Context(), x11)
277278
if !handled {
278279
_ = session.Exit(1)
279280
logger.Error(ctx, "x11 handler failed")
280281
return
281282
}
282-
extraEnv = append(extraEnv, fmt.Sprintf("DISPLAY=:%d.0", x11.ScreenNumber))
283+
extraEnv = append(extraEnv, fmt.Sprintf("DISPLAY=localhost:%d.%d", display, x11.ScreenNumber))
283284
}
284285

285286
if s.fileTransferBlocked(session) {

agent/agentssh/x11.go

Lines changed: 64 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"errors"
88
"fmt"
99
"io"
10+
"math"
1011
"net"
1112
"os"
1213
"path/filepath"
@@ -22,61 +23,76 @@ import (
2223
"cdr.dev/slog"
2324
)
2425

25-
// x11Callback is called when the client requests X11 forwarding.
26-
// It adds an Xauthority entry to the Xauthority file.
27-
func (s *Server) x11Callback(ctx ssh.Context, x11 ssh.X11) bool {
28-
hostname, err := os.Hostname()
29-
if err != nil {
30-
s.logger.Warn(ctx, "failed to get hostname", slog.Error(err))
31-
s.metrics.x11HandlerErrors.WithLabelValues("hostname").Add(1)
32-
return false
33-
}
34-
35-
err = s.fs.MkdirAll(s.config.X11SocketDir, 0o700)
36-
if err != nil {
37-
s.logger.Warn(ctx, "failed to make the x11 socket dir", slog.F("dir", s.config.X11SocketDir), slog.Error(err))
38-
s.metrics.x11HandlerErrors.WithLabelValues("socker_dir").Add(1)
39-
return false
40-
}
26+
const (
27+
// X11StartPort is the starting port for X11 forwarding, this is the
28+
// port used for "DISPLAY=localhost:0".
29+
X11StartPort = 6000
30+
// X11DefaultDisplayOffset is the default offset for X11 forwarding.
31+
X11DefaultDisplayOffset = 10
32+
)
4133

42-
err = addXauthEntry(ctx, s.fs, hostname, strconv.Itoa(int(x11.ScreenNumber)), x11.AuthProtocol, x11.AuthCookie)
43-
if err != nil {
44-
s.logger.Warn(ctx, "failed to add Xauthority entry", slog.Error(err))
45-
s.metrics.x11HandlerErrors.WithLabelValues("xauthority").Add(1)
46-
return false
47-
}
34+
// x11Callback is called when the client requests X11 forwarding.
35+
func (*Server) x11Callback(_ ssh.Context, _ ssh.X11) bool {
36+
// Always allow.
4837
return true
4938
}
5039

5140
// x11Handler is called when a session has requested X11 forwarding.
5241
// It listens for X11 connections and forwards them to the client.
53-
func (s *Server) x11Handler(ctx ssh.Context, x11 ssh.X11) bool {
42+
func (s *Server) x11Handler(ctx ssh.Context, x11 ssh.X11) (display int, handled bool) {
5443
serverConn, valid := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn)
5544
if !valid {
5645
s.logger.Warn(ctx, "failed to get server connection")
57-
return false
46+
return -1, false
5847
}
59-
// We want to overwrite the socket so that subsequent connections will succeed.
60-
socketPath := filepath.Join(s.config.X11SocketDir, fmt.Sprintf("X%d", x11.ScreenNumber))
61-
err := os.Remove(socketPath)
62-
if err != nil && !errors.Is(err, os.ErrNotExist) {
63-
s.logger.Warn(ctx, "failed to remove existing X11 socket", slog.Error(err))
64-
return false
65-
}
66-
listener, err := net.Listen("unix", socketPath)
48+
49+
hostname, err := os.Hostname()
6750
if err != nil {
51+
s.logger.Warn(ctx, "failed to get hostname", slog.Error(err))
52+
s.metrics.x11HandlerErrors.WithLabelValues("hostname").Add(1)
53+
return -1, false
54+
}
55+
56+
var (
57+
lc net.ListenConfig
58+
ln net.Listener
59+
port = X11StartPort + *s.config.X11DisplayOffset
60+
)
61+
// Look for an open port to listen on..
62+
for ; port >= X11StartPort && port < math.MaxUint16; port++ {
63+
ln, err = lc.Listen(ctx, "tcp", fmt.Sprintf("localhost:%d", port))
64+
if err == nil {
65+
display = port - X11StartPort
66+
break
67+
}
68+
}
69+
if ln == nil {
6870
s.logger.Warn(ctx, "failed to listen for X11", slog.Error(err))
69-
return false
71+
s.metrics.x11HandlerErrors.WithLabelValues("listen").Add(1)
72+
return -1, false
73+
}
74+
s.trackListener(ln, true)
75+
defer func() {
76+
if !handled {
77+
s.trackListener(ln, false)
78+
_ = ln.Close()
79+
}
80+
}()
81+
82+
err = addXauthEntry(ctx, s.fs, hostname, strconv.Itoa(port), x11.AuthProtocol, x11.AuthCookie)
83+
if err != nil {
84+
s.logger.Warn(ctx, "failed to add Xauthority entry", slog.Error(err))
85+
s.metrics.x11HandlerErrors.WithLabelValues("xauthority").Add(1)
86+
return -1, false
7087
}
71-
s.trackListener(listener, true)
7288

7389
go func() {
74-
defer listener.Close()
75-
defer s.trackListener(listener, false)
90+
defer ln.Close()
91+
defer s.trackListener(ln, false)
7692
handledFirstConnection := false
7793

7894
for {
79-
conn, err := listener.Accept()
95+
conn, err := ln.Accept()
8096
if err != nil {
8197
if errors.Is(err, net.ErrClosed) {
8298
return
@@ -91,33 +107,37 @@ func (s *Server) x11Handler(ctx ssh.Context, x11 ssh.X11) bool {
91107
}
92108
handledFirstConnection = true
93109

94-
unixConn, ok := conn.(*net.UnixConn)
110+
tcpConn, ok := conn.(*net.TCPConn)
95111
if !ok {
96-
s.logger.Warn(ctx, fmt.Sprintf("failed to cast connection to UnixConn. got: %T", conn))
112+
s.logger.Warn(ctx, fmt.Sprintf("failed to cast connection to TCPConn. got: %T", conn))
113+
_ = conn.Close()
97114
return
98115
}
99-
unixAddr, ok := unixConn.LocalAddr().(*net.UnixAddr)
116+
tcpAddr, ok := tcpConn.LocalAddr().(*net.TCPAddr)
100117
if !ok {
101-
s.logger.Warn(ctx, fmt.Sprintf("failed to cast local address to UnixAddr. got: %T", unixConn.LocalAddr()))
118+
s.logger.Warn(ctx, fmt.Sprintf("failed to cast local address to TCPAddr. got: %T", tcpConn.LocalAddr()))
119+
_ = conn.Close()
102120
return
103121
}
104122

105123
channel, reqs, err := serverConn.OpenChannel("x11", gossh.Marshal(struct {
106124
OriginatorAddress string
107125
OriginatorPort uint32
108126
}{
109-
OriginatorAddress: unixAddr.Name,
110-
OriginatorPort: 0,
127+
OriginatorAddress: tcpAddr.IP.String(),
128+
OriginatorPort: uint32(tcpAddr.Port),
111129
}))
112130
if err != nil {
113131
s.logger.Warn(ctx, "failed to open X11 channel", slog.Error(err))
132+
_ = conn.Close()
114133
return
115134
}
116135
go gossh.DiscardRequests(reqs)
117136
go Bicopy(ctx, conn, channel)
118137
}
119138
}()
120-
return true
139+
140+
return display, true
121141
}
122142

123143
// addXauthEntry adds an Xauthority entry to the Xauthority file.

agent/agentssh/x11_test.go

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
package agentssh_test
22

33
import (
4+
"bufio"
5+
"bytes"
46
"context"
57
"encoding/hex"
8+
"fmt"
69
"net"
710
"os"
811
"path/filepath"
912
"runtime"
13+
"strconv"
14+
"strings"
1015
"testing"
1116

1217
"github.com/gliderlabs/ssh"
@@ -31,10 +36,7 @@ func TestServer_X11(t *testing.T) {
3136
ctx := context.Background()
3237
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
3338
fs := afero.NewOsFs()
34-
dir := t.TempDir()
35-
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, &agentssh.Config{
36-
X11SocketDir: dir,
37-
})
39+
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, &agentssh.Config{})
3840
require.NoError(t, err)
3941
defer s.Close()
4042

@@ -53,21 +55,45 @@ func TestServer_X11(t *testing.T) {
5355
sess, err := c.NewSession()
5456
require.NoError(t, err)
5557

58+
wantScreenNumber := 1
5659
reply, err := sess.SendRequest("x11-req", true, gossh.Marshal(ssh.X11{
5760
AuthProtocol: "MIT-MAGIC-COOKIE-1",
5861
AuthCookie: hex.EncodeToString([]byte("cookie")),
59-
ScreenNumber: 0,
62+
ScreenNumber: uint32(wantScreenNumber),
6063
}))
6164
require.NoError(t, err)
6265
assert.True(t, reply)
6366

64-
err = sess.Shell()
67+
// Want: ~DISPLAY=localhost:10.1
68+
out, err := sess.Output("echo DISPLAY=$DISPLAY")
6569
require.NoError(t, err)
6670

71+
sc := bufio.NewScanner(bytes.NewReader(out))
72+
displayNumber := -1
73+
for sc.Scan() {
74+
line := strings.TrimSpace(sc.Text())
75+
t.Log(line)
76+
if strings.HasPrefix(line, "DISPLAY=") {
77+
parts := strings.SplitN(line, "=", 2)
78+
display := parts[1]
79+
parts = strings.SplitN(display, ":", 2)
80+
parts = strings.SplitN(parts[1], ".", 2)
81+
displayNumber, err = strconv.Atoi(parts[0])
82+
require.NoError(t, err)
83+
assert.GreaterOrEqual(t, displayNumber, 10, "display number should be >= 10")
84+
gotScreenNumber, err := strconv.Atoi(parts[1])
85+
require.NoError(t, err)
86+
assert.Equal(t, wantScreenNumber, gotScreenNumber, "screen number should match")
87+
break
88+
}
89+
}
90+
require.NoError(t, sc.Err())
91+
require.NotEqual(t, -1, displayNumber)
92+
6793
x11Chans := c.HandleChannelOpen("x11")
6894
payload := "hello world"
6995
require.Eventually(t, func() bool {
70-
conn, err := net.Dial("unix", filepath.Join(dir, "X0"))
96+
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", agentssh.X11StartPort+displayNumber))
7197
if err == nil {
7298
_, err = conn.Write([]byte(payload))
7399
assert.NoError(t, err)

0 commit comments

Comments
 (0)