Skip to content

Commit 3a163aa

Browse files
committed
feat(agent/agentssh): use tcp for X11 forwarding
Fixes #14198
1 parent 5366f25 commit 3a163aa

File tree

3 files changed

+101
-58
lines changed

3 files changed

+101
-58
lines changed

agent/agentssh/agentssh.go

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,9 @@ type Config struct {
7979
// where users will land when they connect via SSH. Default is the home
8080
// directory of the user.
8181
WorkingDirectory func() string
82-
// X11SocketDir is the directory where X11 sockets are created. Default is
83-
// /tmp/.X11-unix.
84-
X11SocketDir string
82+
// X11DisplayOffset is the offset to add to the X11 display number.
83+
// Default is 10.
84+
X11DisplayOffset *int
8585
// BlockFileTransfer restricts use of file transfer applications.
8686
BlockFileTransfer bool
8787
}
@@ -124,8 +124,9 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
124124
if config == nil {
125125
config = &Config{}
126126
}
127-
if config.X11SocketDir == "" {
128-
config.X11SocketDir = filepath.Join(os.TempDir(), ".X11-unix")
127+
if config.X11DisplayOffset == nil {
128+
offset := X11DefaultDisplayOffset
129+
config.X11DisplayOffset = &offset
129130
}
130131
if config.UpdateEnv == nil {
131132
config.UpdateEnv = func(current []string) ([]string, error) { return current, nil }
@@ -273,13 +274,13 @@ func (s *Server) sessionHandler(session ssh.Session) {
273274
extraEnv := make([]string, 0)
274275
x11, hasX11 := session.X11()
275276
if hasX11 {
276-
handled := s.x11Handler(session.Context(), x11)
277+
display, handled := s.x11Handler(session.Context(), x11)
277278
if !handled {
278279
_ = session.Exit(1)
279280
logger.Error(ctx, "x11 handler failed")
280281
return
281282
}
282-
extraEnv = append(extraEnv, fmt.Sprintf("DISPLAY=:%d.0", x11.ScreenNumber))
283+
extraEnv = append(extraEnv, fmt.Sprintf("DISPLAY=localhost:%d.%d", display, x11.ScreenNumber))
283284
}
284285

285286
if s.fileTransferBlocked(session) {

agent/agentssh/x11.go

Lines changed: 60 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"errors"
88
"fmt"
99
"io"
10+
"math"
1011
"net"
1112
"os"
1213
"path/filepath"
@@ -22,61 +23,72 @@ import (
2223
"cdr.dev/slog"
2324
)
2425

25-
// x11Callback is called when the client requests X11 forwarding.
26-
// It adds an Xauthority entry to the Xauthority file.
27-
func (s *Server) x11Callback(ctx ssh.Context, x11 ssh.X11) bool {
28-
hostname, err := os.Hostname()
29-
if err != nil {
30-
s.logger.Warn(ctx, "failed to get hostname", slog.Error(err))
31-
s.metrics.x11HandlerErrors.WithLabelValues("hostname").Add(1)
32-
return false
33-
}
34-
35-
err = s.fs.MkdirAll(s.config.X11SocketDir, 0o700)
36-
if err != nil {
37-
s.logger.Warn(ctx, "failed to make the x11 socket dir", slog.F("dir", s.config.X11SocketDir), slog.Error(err))
38-
s.metrics.x11HandlerErrors.WithLabelValues("socker_dir").Add(1)
39-
return false
40-
}
26+
const (
27+
X11StartPort = 6000
28+
X11DefaultDisplayOffset = 10
29+
)
4130

42-
err = addXauthEntry(ctx, s.fs, hostname, strconv.Itoa(int(x11.ScreenNumber)), x11.AuthProtocol, x11.AuthCookie)
43-
if err != nil {
44-
s.logger.Warn(ctx, "failed to add Xauthority entry", slog.Error(err))
45-
s.metrics.x11HandlerErrors.WithLabelValues("xauthority").Add(1)
46-
return false
47-
}
31+
// x11Callback is called when the client requests X11 forwarding.
32+
func (*Server) x11Callback(_ ssh.Context, _ ssh.X11) bool {
33+
// Always allow.
4834
return true
4935
}
5036

5137
// x11Handler is called when a session has requested X11 forwarding.
5238
// It listens for X11 connections and forwards them to the client.
53-
func (s *Server) x11Handler(ctx ssh.Context, x11 ssh.X11) bool {
39+
func (s *Server) x11Handler(ctx ssh.Context, x11 ssh.X11) (display int, handled bool) {
5440
serverConn, valid := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn)
5541
if !valid {
5642
s.logger.Warn(ctx, "failed to get server connection")
57-
return false
43+
return -1, false
5844
}
59-
// We want to overwrite the socket so that subsequent connections will succeed.
60-
socketPath := filepath.Join(s.config.X11SocketDir, fmt.Sprintf("X%d", x11.ScreenNumber))
61-
err := os.Remove(socketPath)
62-
if err != nil && !errors.Is(err, os.ErrNotExist) {
63-
s.logger.Warn(ctx, "failed to remove existing X11 socket", slog.Error(err))
64-
return false
65-
}
66-
listener, err := net.Listen("unix", socketPath)
45+
46+
hostname, err := os.Hostname()
6747
if err != nil {
48+
s.logger.Warn(ctx, "failed to get hostname", slog.Error(err))
49+
s.metrics.x11HandlerErrors.WithLabelValues("hostname").Add(1)
50+
return -1, false
51+
}
52+
53+
var (
54+
lc net.ListenConfig
55+
ln net.Listener
56+
port = X11StartPort + *s.config.X11DisplayOffset
57+
)
58+
for ; port >= 6000 && port < math.MaxUint16; port++ {
59+
ln, err = lc.Listen(ctx, "tcp", fmt.Sprintf("localhost:%d", port))
60+
if err == nil {
61+
display = port - X11StartPort
62+
break
63+
}
64+
}
65+
if ln == nil {
6866
s.logger.Warn(ctx, "failed to listen for X11", slog.Error(err))
69-
return false
67+
s.metrics.x11HandlerErrors.WithLabelValues("listen").Add(1)
68+
return -1, false
69+
}
70+
s.trackListener(ln, true)
71+
defer func() {
72+
if !handled {
73+
s.trackListener(ln, false)
74+
_ = ln.Close()
75+
}
76+
}()
77+
78+
err = addXauthEntry(ctx, s.fs, hostname, strconv.Itoa(port), x11.AuthProtocol, x11.AuthCookie)
79+
if err != nil {
80+
s.logger.Warn(ctx, "failed to add Xauthority entry", slog.Error(err))
81+
s.metrics.x11HandlerErrors.WithLabelValues("xauthority").Add(1)
82+
return -1, false
7083
}
71-
s.trackListener(listener, true)
7284

7385
go func() {
74-
defer listener.Close()
75-
defer s.trackListener(listener, false)
86+
defer ln.Close()
87+
defer s.trackListener(ln, false)
7688
handledFirstConnection := false
7789

7890
for {
79-
conn, err := listener.Accept()
91+
conn, err := ln.Accept()
8092
if err != nil {
8193
if errors.Is(err, net.ErrClosed) {
8294
return
@@ -91,33 +103,37 @@ func (s *Server) x11Handler(ctx ssh.Context, x11 ssh.X11) bool {
91103
}
92104
handledFirstConnection = true
93105

94-
unixConn, ok := conn.(*net.UnixConn)
106+
tcpConn, ok := conn.(*net.TCPConn)
95107
if !ok {
96-
s.logger.Warn(ctx, fmt.Sprintf("failed to cast connection to UnixConn. got: %T", conn))
108+
s.logger.Warn(ctx, fmt.Sprintf("failed to cast connection to TCPConn. got: %T", conn))
109+
_ = conn.Close()
97110
return
98111
}
99-
unixAddr, ok := unixConn.LocalAddr().(*net.UnixAddr)
112+
tcpAddr, ok := tcpConn.LocalAddr().(*net.TCPAddr)
100113
if !ok {
101-
s.logger.Warn(ctx, fmt.Sprintf("failed to cast local address to UnixAddr. got: %T", unixConn.LocalAddr()))
114+
s.logger.Warn(ctx, fmt.Sprintf("failed to cast local address to TCPAddr. got: %T", tcpConn.LocalAddr()))
115+
_ = conn.Close()
102116
return
103117
}
104118

105119
channel, reqs, err := serverConn.OpenChannel("x11", gossh.Marshal(struct {
106120
OriginatorAddress string
107121
OriginatorPort uint32
108122
}{
109-
OriginatorAddress: unixAddr.Name,
110-
OriginatorPort: 0,
123+
OriginatorAddress: tcpAddr.IP.String(),
124+
OriginatorPort: uint32(tcpAddr.Port),
111125
}))
112126
if err != nil {
113127
s.logger.Warn(ctx, "failed to open X11 channel", slog.Error(err))
128+
_ = conn.Close()
114129
return
115130
}
116131
go gossh.DiscardRequests(reqs)
117132
go Bicopy(ctx, conn, channel)
118133
}
119134
}()
120-
return true
135+
136+
return display, true
121137
}
122138

123139
// addXauthEntry adds an Xauthority entry to the Xauthority file.

agent/agentssh/x11_test.go

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
package agentssh_test
22

33
import (
4+
"bufio"
5+
"bytes"
46
"context"
57
"encoding/hex"
8+
"fmt"
69
"net"
710
"os"
811
"path/filepath"
912
"runtime"
13+
"strconv"
14+
"strings"
1015
"testing"
1116

1217
"github.com/gliderlabs/ssh"
@@ -31,10 +36,7 @@ func TestServer_X11(t *testing.T) {
3136
ctx := context.Background()
3237
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
3338
fs := afero.NewOsFs()
34-
dir := t.TempDir()
35-
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, &agentssh.Config{
36-
X11SocketDir: dir,
37-
})
39+
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, &agentssh.Config{})
3840
require.NoError(t, err)
3941
defer s.Close()
4042

@@ -53,21 +55,45 @@ func TestServer_X11(t *testing.T) {
5355
sess, err := c.NewSession()
5456
require.NoError(t, err)
5557

58+
wantScreenNumber := 1
5659
reply, err := sess.SendRequest("x11-req", true, gossh.Marshal(ssh.X11{
5760
AuthProtocol: "MIT-MAGIC-COOKIE-1",
5861
AuthCookie: hex.EncodeToString([]byte("cookie")),
59-
ScreenNumber: 0,
62+
ScreenNumber: uint32(wantScreenNumber),
6063
}))
6164
require.NoError(t, err)
6265
assert.True(t, reply)
6366

64-
err = sess.Shell()
67+
// Want: ~DISPLAY=localhost:10.1
68+
out, err := sess.Output("echo DISPLAY=$DISPLAY")
6569
require.NoError(t, err)
6670

71+
sc := bufio.NewScanner(bytes.NewReader(out))
72+
displayNumber := -1
73+
for sc.Scan() {
74+
line := strings.TrimSpace(sc.Text())
75+
t.Log(line)
76+
if strings.HasPrefix(line, "DISPLAY=") {
77+
parts := strings.SplitN(line, "=", 2)
78+
display := parts[1]
79+
parts = strings.SplitN(display, ":", 2)
80+
parts = strings.SplitN(parts[1], ".", 2)
81+
displayNumber, err = strconv.Atoi(parts[0])
82+
require.NoError(t, err)
83+
assert.GreaterOrEqual(t, displayNumber, 10, "display number should be >= 10")
84+
gotScreenNumber, err := strconv.Atoi(parts[1])
85+
require.NoError(t, err)
86+
assert.Equal(t, wantScreenNumber, gotScreenNumber, "screen number should match")
87+
break
88+
}
89+
}
90+
require.NoError(t, sc.Err())
91+
require.NotEqual(t, -1, displayNumber)
92+
6793
x11Chans := c.HandleChannelOpen("x11")
6894
payload := "hello world"
6995
require.Eventually(t, func() bool {
70-
conn, err := net.Dial("unix", filepath.Join(dir, "X0"))
96+
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", agentssh.X11StartPort+displayNumber))
7197
if err == nil {
7298
_, err = conn.Write([]byte(payload))
7399
assert.NoError(t, err)

0 commit comments

Comments
 (0)