Skip to content

Commit 8f07d33

Browse files
authored
feat(agent/agentssh): use tcp for X11 forwarding (coder#14560)
Fixes coder#14198
1 parent e6d8f67 commit 8f07d33

File tree

3 files changed

+129
-67
lines changed

3 files changed

+129
-67
lines changed

agent/agentssh/agentssh.go

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,9 @@ type Config struct {
7979
// where users will land when they connect via SSH. Default is the home
8080
// directory of the user.
8181
WorkingDirectory func() string
82-
// X11SocketDir is the directory where X11 sockets are created. Default is
83-
// /tmp/.X11-unix.
84-
X11SocketDir string
82+
// X11DisplayOffset is the offset to add to the X11 display number.
83+
// Default is 10.
84+
X11DisplayOffset *int
8585
// BlockFileTransfer restricts use of file transfer applications.
8686
BlockFileTransfer bool
8787
}
@@ -124,8 +124,9 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
124124
if config == nil {
125125
config = &Config{}
126126
}
127-
if config.X11SocketDir == "" {
128-
config.X11SocketDir = filepath.Join(os.TempDir(), ".X11-unix")
127+
if config.X11DisplayOffset == nil {
128+
offset := X11DefaultDisplayOffset
129+
config.X11DisplayOffset = &offset
129130
}
130131
if config.UpdateEnv == nil {
131132
config.UpdateEnv = func(current []string) ([]string, error) { return current, nil }
@@ -273,13 +274,13 @@ func (s *Server) sessionHandler(session ssh.Session) {
273274
extraEnv := make([]string, 0)
274275
x11, hasX11 := session.X11()
275276
if hasX11 {
276-
handled := s.x11Handler(session.Context(), x11)
277+
display, handled := s.x11Handler(session.Context(), x11)
277278
if !handled {
278279
_ = session.Exit(1)
279280
logger.Error(ctx, "x11 handler failed")
280281
return
281282
}
282-
extraEnv = append(extraEnv, fmt.Sprintf("DISPLAY=:%d.0", x11.ScreenNumber))
283+
extraEnv = append(extraEnv, fmt.Sprintf("DISPLAY=localhost:%d.%d", display, x11.ScreenNumber))
283284
}
284285

285286
if s.fileTransferBlocked(session) {

agent/agentssh/x11.go

Lines changed: 88 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"errors"
88
"fmt"
99
"io"
10+
"math"
1011
"net"
1112
"os"
1213
"path/filepath"
@@ -22,102 +23,136 @@ import (
2223
"cdr.dev/slog"
2324
)
2425

26+
const (
27+
// X11StartPort is the starting port for X11 forwarding, this is the
28+
// port used for "DISPLAY=localhost:0".
29+
X11StartPort = 6000
30+
// X11DefaultDisplayOffset is the default offset for X11 forwarding.
31+
X11DefaultDisplayOffset = 10
32+
)
33+
2534
// x11Callback is called when the client requests X11 forwarding.
26-
// It adds an Xauthority entry to the Xauthority file.
27-
func (s *Server) x11Callback(ctx ssh.Context, x11 ssh.X11) bool {
35+
func (*Server) x11Callback(_ ssh.Context, _ ssh.X11) bool {
36+
// Always allow.
37+
return true
38+
}
39+
40+
// x11Handler is called when a session has requested X11 forwarding.
41+
// It listens for X11 connections and forwards them to the client.
42+
func (s *Server) x11Handler(ctx ssh.Context, x11 ssh.X11) (displayNumber int, handled bool) {
43+
serverConn, valid := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn)
44+
if !valid {
45+
s.logger.Warn(ctx, "failed to get server connection")
46+
return -1, false
47+
}
48+
2849
hostname, err := os.Hostname()
2950
if err != nil {
3051
s.logger.Warn(ctx, "failed to get hostname", slog.Error(err))
3152
s.metrics.x11HandlerErrors.WithLabelValues("hostname").Add(1)
32-
return false
53+
return -1, false
3354
}
3455

35-
err = s.fs.MkdirAll(s.config.X11SocketDir, 0o700)
56+
ln, display, err := createX11Listener(ctx, *s.config.X11DisplayOffset)
3657
if err != nil {
37-
s.logger.Warn(ctx, "failed to make the x11 socket dir", slog.F("dir", s.config.X11SocketDir), slog.Error(err))
38-
s.metrics.x11HandlerErrors.WithLabelValues("socker_dir").Add(1)
39-
return false
40-
}
58+
s.logger.Warn(ctx, "failed to create X11 listener", slog.Error(err))
59+
s.metrics.x11HandlerErrors.WithLabelValues("listen").Add(1)
60+
return -1, false
61+
}
62+
s.trackListener(ln, true)
63+
defer func() {
64+
if !handled {
65+
s.trackListener(ln, false)
66+
_ = ln.Close()
67+
}
68+
}()
4169

42-
err = addXauthEntry(ctx, s.fs, hostname, strconv.Itoa(int(x11.ScreenNumber)), x11.AuthProtocol, x11.AuthCookie)
70+
err = addXauthEntry(ctx, s.fs, hostname, strconv.Itoa(display), x11.AuthProtocol, x11.AuthCookie)
4371
if err != nil {
4472
s.logger.Warn(ctx, "failed to add Xauthority entry", slog.Error(err))
4573
s.metrics.x11HandlerErrors.WithLabelValues("xauthority").Add(1)
46-
return false
74+
return -1, false
4775
}
48-
return true
49-
}
5076

51-
// x11Handler is called when a session has requested X11 forwarding.
52-
// It listens for X11 connections and forwards them to the client.
53-
func (s *Server) x11Handler(ctx ssh.Context, x11 ssh.X11) bool {
54-
serverConn, valid := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn)
55-
if !valid {
56-
s.logger.Warn(ctx, "failed to get server connection")
57-
return false
58-
}
59-
// We want to overwrite the socket so that subsequent connections will succeed.
60-
socketPath := filepath.Join(s.config.X11SocketDir, fmt.Sprintf("X%d", x11.ScreenNumber))
61-
err := os.Remove(socketPath)
62-
if err != nil && !errors.Is(err, os.ErrNotExist) {
63-
s.logger.Warn(ctx, "failed to remove existing X11 socket", slog.Error(err))
64-
return false
65-
}
66-
listener, err := net.Listen("unix", socketPath)
67-
if err != nil {
68-
s.logger.Warn(ctx, "failed to listen for X11", slog.Error(err))
69-
return false
70-
}
71-
s.trackListener(listener, true)
77+
go func() {
78+
// Don't leave the listener open after the session is gone.
79+
<-ctx.Done()
80+
_ = ln.Close()
81+
}()
7282

7383
go func() {
74-
defer listener.Close()
75-
defer s.trackListener(listener, false)
76-
handledFirstConnection := false
84+
defer ln.Close()
85+
defer s.trackListener(ln, false)
7786

7887
for {
79-
conn, err := listener.Accept()
88+
conn, err := ln.Accept()
8089
if err != nil {
8190
if errors.Is(err, net.ErrClosed) {
8291
return
8392
}
8493
s.logger.Warn(ctx, "failed to accept X11 connection", slog.Error(err))
8594
return
8695
}
87-
if x11.SingleConnection && handledFirstConnection {
88-
s.logger.Warn(ctx, "X11 connection rejected because single connection is enabled")
89-
_ = conn.Close()
90-
continue
96+
if x11.SingleConnection {
97+
s.logger.Debug(ctx, "single connection requested, closing X11 listener")
98+
_ = ln.Close()
9199
}
92-
handledFirstConnection = true
93100

94-
unixConn, ok := conn.(*net.UnixConn)
101+
tcpConn, ok := conn.(*net.TCPConn)
95102
if !ok {
96-
s.logger.Warn(ctx, fmt.Sprintf("failed to cast connection to UnixConn. got: %T", conn))
97-
return
103+
s.logger.Warn(ctx, fmt.Sprintf("failed to cast connection to TCPConn. got: %T", conn))
104+
_ = conn.Close()
105+
continue
98106
}
99-
unixAddr, ok := unixConn.LocalAddr().(*net.UnixAddr)
107+
tcpAddr, ok := tcpConn.LocalAddr().(*net.TCPAddr)
100108
if !ok {
101-
s.logger.Warn(ctx, fmt.Sprintf("failed to cast local address to UnixAddr. got: %T", unixConn.LocalAddr()))
102-
return
109+
s.logger.Warn(ctx, fmt.Sprintf("failed to cast local address to TCPAddr. got: %T", tcpConn.LocalAddr()))
110+
_ = conn.Close()
111+
continue
103112
}
104113

105114
channel, reqs, err := serverConn.OpenChannel("x11", gossh.Marshal(struct {
106115
OriginatorAddress string
107116
OriginatorPort uint32
108117
}{
109-
OriginatorAddress: unixAddr.Name,
110-
OriginatorPort: 0,
118+
OriginatorAddress: tcpAddr.IP.String(),
119+
OriginatorPort: uint32(tcpAddr.Port),
111120
}))
112121
if err != nil {
113122
s.logger.Warn(ctx, "failed to open X11 channel", slog.Error(err))
114-
return
123+
_ = conn.Close()
124+
continue
115125
}
116126
go gossh.DiscardRequests(reqs)
117-
go Bicopy(ctx, conn, channel)
127+
128+
if !s.trackConn(ln, conn, true) {
129+
s.logger.Warn(ctx, "failed to track X11 connection")
130+
_ = conn.Close()
131+
continue
132+
}
133+
go func() {
134+
defer s.trackConn(ln, conn, false)
135+
Bicopy(ctx, conn, channel)
136+
}()
118137
}
119138
}()
120-
return true
139+
140+
return display, true
141+
}
142+
143+
// createX11Listener creates a listener for X11 forwarding, it will use
144+
// the next available port starting from X11StartPort and displayOffset.
145+
func createX11Listener(ctx context.Context, displayOffset int) (ln net.Listener, display int, err error) {
146+
var lc net.ListenConfig
147+
// Look for an open port to listen on.
148+
for port := X11StartPort + displayOffset; port < math.MaxUint16; port++ {
149+
ln, err = lc.Listen(ctx, "tcp", fmt.Sprintf("localhost:%d", port))
150+
if err == nil {
151+
display = port - X11StartPort
152+
return ln, display, nil
153+
}
154+
}
155+
return nil, -1, xerrors.Errorf("failed to find open port for X11 listener: %w", err)
121156
}
122157

123158
// addXauthEntry adds an Xauthority entry to the Xauthority file.

agent/agentssh/x11_test.go

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
package agentssh_test
22

33
import (
4+
"bufio"
5+
"bytes"
46
"context"
57
"encoding/hex"
8+
"fmt"
69
"net"
710
"os"
811
"path/filepath"
912
"runtime"
13+
"strconv"
14+
"strings"
1015
"testing"
1116

1217
"github.com/gliderlabs/ssh"
@@ -31,10 +36,7 @@ func TestServer_X11(t *testing.T) {
3136
ctx := context.Background()
3237
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
3338
fs := afero.NewOsFs()
34-
dir := t.TempDir()
35-
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, &agentssh.Config{
36-
X11SocketDir: dir,
37-
})
39+
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, &agentssh.Config{})
3840
require.NoError(t, err)
3941
defer s.Close()
4042

@@ -53,21 +55,45 @@ func TestServer_X11(t *testing.T) {
5355
sess, err := c.NewSession()
5456
require.NoError(t, err)
5557

58+
wantScreenNumber := 1
5659
reply, err := sess.SendRequest("x11-req", true, gossh.Marshal(ssh.X11{
5760
AuthProtocol: "MIT-MAGIC-COOKIE-1",
5861
AuthCookie: hex.EncodeToString([]byte("cookie")),
59-
ScreenNumber: 0,
62+
ScreenNumber: uint32(wantScreenNumber),
6063
}))
6164
require.NoError(t, err)
6265
assert.True(t, reply)
6366

64-
err = sess.Shell()
67+
// Want: ~DISPLAY=localhost:10.1
68+
out, err := sess.Output("echo DISPLAY=$DISPLAY")
6569
require.NoError(t, err)
6670

71+
sc := bufio.NewScanner(bytes.NewReader(out))
72+
displayNumber := -1
73+
for sc.Scan() {
74+
line := strings.TrimSpace(sc.Text())
75+
t.Log(line)
76+
if strings.HasPrefix(line, "DISPLAY=") {
77+
parts := strings.SplitN(line, "=", 2)
78+
display := parts[1]
79+
parts = strings.SplitN(display, ":", 2)
80+
parts = strings.SplitN(parts[1], ".", 2)
81+
displayNumber, err = strconv.Atoi(parts[0])
82+
require.NoError(t, err)
83+
assert.GreaterOrEqual(t, displayNumber, 10, "display number should be >= 10")
84+
gotScreenNumber, err := strconv.Atoi(parts[1])
85+
require.NoError(t, err)
86+
assert.Equal(t, wantScreenNumber, gotScreenNumber, "screen number should match")
87+
break
88+
}
89+
}
90+
require.NoError(t, sc.Err())
91+
require.NotEqual(t, -1, displayNumber)
92+
6793
x11Chans := c.HandleChannelOpen("x11")
6894
payload := "hello world"
6995
require.Eventually(t, func() bool {
70-
conn, err := net.Dial("unix", filepath.Join(dir, "X0"))
96+
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", agentssh.X11StartPort+displayNumber))
7197
if err == nil {
7298
_, err = conn.Write([]byte(payload))
7399
assert.NoError(t, err)

0 commit comments

Comments
 (0)