Skip to content

Commit f39e6a7

Browse files
authored
feat: add support for X11 forwarding (#7205)
* feat: add support for X11 forwarding * Only run X forwarding on Linux * Fix piping * Fix comments
1 parent 6f06f8d commit f39e6a7

File tree

7 files changed

+324
-11
lines changed

7 files changed

+324
-11
lines changed

agent/agent.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ type agent struct {
161161
}
162162

163163
func (a *agent) init(ctx context.Context) {
164-
sshSrv, err := agentssh.NewServer(ctx, a.logger.Named("ssh-server"), a.sshMaxTimeout)
164+
sshSrv, err := agentssh.NewServer(ctx, a.logger.Named("ssh-server"), a.filesystem, a.sshMaxTimeout, "")
165165
if err != nil {
166166
panic(err)
167167
}

agent/agentssh/agentssh.go

+25-7
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020

2121
"github.com/gliderlabs/ssh"
2222
"github.com/pkg/sftp"
23+
"github.com/spf13/afero"
2324
"go.uber.org/atomic"
2425
gossh "golang.org/x/crypto/ssh"
2526
"golang.org/x/xerrors"
@@ -48,6 +49,7 @@ const (
4849

4950
type Server struct {
5051
mu sync.RWMutex // Protects following.
52+
fs afero.Fs
5153
listeners map[net.Listener]struct{}
5254
conns map[net.Conn]struct{}
5355
sessions map[ssh.Session]struct{}
@@ -56,8 +58,9 @@ type Server struct {
5658
// a lock on mu but protected by closing.
5759
wg sync.WaitGroup
5860

59-
logger slog.Logger
60-
srv *ssh.Server
61+
logger slog.Logger
62+
srv *ssh.Server
63+
x11SocketDir string
6164

6265
Env map[string]string
6366
AgentToken func() string
@@ -68,7 +71,7 @@ type Server struct {
6871
connCountSSHSession atomic.Int64
6972
}
7073

71-
func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration) (*Server, error) {
74+
func NewServer(ctx context.Context, logger slog.Logger, fs afero.Fs, maxTimeout time.Duration, x11SocketDir string) (*Server, error) {
7275
// Clients' should ignore the host key when connecting.
7376
// The agent needs to authenticate with coderd to SSH,
7477
// so SSH authentication doesn't improve security.
@@ -80,15 +83,20 @@ func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration
8083
if err != nil {
8184
return nil, err
8285
}
86+
if x11SocketDir == "" {
87+
x11SocketDir = filepath.Join(os.TempDir(), ".X11-unix")
88+
}
8389

8490
forwardHandler := &ssh.ForwardedTCPHandler{}
8591
unixForwardHandler := &forwardedUnixHandler{log: logger}
8692

8793
s := &Server{
88-
listeners: make(map[net.Listener]struct{}),
89-
conns: make(map[net.Conn]struct{}),
90-
sessions: make(map[ssh.Session]struct{}),
91-
logger: logger,
94+
listeners: make(map[net.Listener]struct{}),
95+
fs: fs,
96+
conns: make(map[net.Conn]struct{}),
97+
sessions: make(map[ssh.Session]struct{}),
98+
logger: logger,
99+
x11SocketDir: x11SocketDir,
92100
}
93101

94102
s.srv = &ssh.Server{
@@ -125,6 +133,7 @@ func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration
125133
"streamlocal-forward@openssh.com": unixForwardHandler.HandleSSHRequest,
126134
"cancel-streamlocal-forward@openssh.com": unixForwardHandler.HandleSSHRequest,
127135
},
136+
X11Callback: s.x11Callback,
128137
ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig {
129138
return &gossh.ServerConfig{
130139
NoClientAuth: true,
@@ -163,6 +172,15 @@ func (s *Server) sessionHandler(session ssh.Session) {
163172

164173
ctx := session.Context()
165174

175+
x11, hasX11 := session.X11()
176+
if hasX11 {
177+
handled := s.x11Handler(session.Context(), x11)
178+
if !handled {
179+
_ = session.Exit(1)
180+
return
181+
}
182+
}
183+
166184
switch ss := session.Subsystem(); ss {
167185
case "":
168186
case "sftp":

agent/agentssh/agentssh_test.go

+4-2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"sync"
1111
"testing"
1212

13+
"github.com/spf13/afero"
1314
"github.com/stretchr/testify/assert"
1415
"github.com/stretchr/testify/require"
1516
"go.uber.org/atomic"
@@ -32,7 +33,7 @@ func TestNewServer_ServeClient(t *testing.T) {
3233

3334
ctx := context.Background()
3435
logger := slogtest.Make(t, nil)
35-
s, err := agentssh.NewServer(ctx, logger, 0)
36+
s, err := agentssh.NewServer(ctx, logger, afero.NewMemMapFs(), 0, "")
3637
require.NoError(t, err)
3738

3839
// The assumption is that these are set before serving SSH connections.
@@ -50,6 +51,7 @@ func TestNewServer_ServeClient(t *testing.T) {
5051
}()
5152

5253
c := sshClient(t, ln.Addr().String())
54+
5355
var b bytes.Buffer
5456
sess, err := c.NewSession()
5557
sess.Stdout = &b
@@ -72,7 +74,7 @@ func TestNewServer_CloseActiveConnections(t *testing.T) {
7274

7375
ctx := context.Background()
7476
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
75-
s, err := agentssh.NewServer(ctx, logger, 0)
77+
s, err := agentssh.NewServer(ctx, logger, afero.NewMemMapFs(), 0, "")
7678
require.NoError(t, err)
7779

7880
// The assumption is that these are set before serving SSH connections.

agent/agentssh/x11.go

+190
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
package agentssh
2+
3+
import (
4+
"context"
5+
"encoding/binary"
6+
"encoding/hex"
7+
"errors"
8+
"fmt"
9+
"net"
10+
"os"
11+
"path/filepath"
12+
"strconv"
13+
"time"
14+
15+
"github.com/gliderlabs/ssh"
16+
"github.com/gofrs/flock"
17+
"github.com/spf13/afero"
18+
gossh "golang.org/x/crypto/ssh"
19+
"golang.org/x/xerrors"
20+
21+
"cdr.dev/slog"
22+
)
23+
24+
// x11Callback is called when the client requests X11 forwarding.
25+
// It adds an Xauthority entry to the Xauthority file.
26+
func (s *Server) x11Callback(ctx ssh.Context, x11 ssh.X11) bool {
27+
hostname, err := os.Hostname()
28+
if err != nil {
29+
s.logger.Warn(ctx, "failed to get hostname", slog.Error(err))
30+
return false
31+
}
32+
33+
err = s.fs.MkdirAll(s.x11SocketDir, 0o700)
34+
if err != nil {
35+
s.logger.Warn(ctx, "failed to make the x11 socket dir", slog.F("dir", s.x11SocketDir), slog.Error(err))
36+
return false
37+
}
38+
39+
err = addXauthEntry(ctx, s.fs, hostname, strconv.Itoa(int(x11.ScreenNumber)), x11.AuthProtocol, x11.AuthCookie)
40+
if err != nil {
41+
s.logger.Warn(ctx, "failed to add Xauthority entry", slog.Error(err))
42+
return false
43+
}
44+
return true
45+
}
46+
47+
// x11Handler is called when a session has requested X11 forwarding.
48+
// It listens for X11 connections and forwards them to the client.
49+
func (s *Server) x11Handler(ctx ssh.Context, x11 ssh.X11) bool {
50+
serverConn, valid := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn)
51+
if !valid {
52+
s.logger.Warn(ctx, "failed to get server connection")
53+
return false
54+
}
55+
listener, err := net.Listen("unix", filepath.Join(s.x11SocketDir, fmt.Sprintf("X%d", x11.ScreenNumber)))
56+
if err != nil {
57+
s.logger.Warn(ctx, "failed to listen for X11", slog.Error(err))
58+
return false
59+
}
60+
s.trackListener(listener, true)
61+
62+
go func() {
63+
defer listener.Close()
64+
defer s.trackListener(listener, false)
65+
handledFirstConnection := false
66+
67+
for {
68+
conn, err := listener.Accept()
69+
if err != nil {
70+
if errors.Is(err, net.ErrClosed) {
71+
return
72+
}
73+
s.logger.Warn(ctx, "failed to accept X11 connection", slog.Error(err))
74+
return
75+
}
76+
if x11.SingleConnection && handledFirstConnection {
77+
s.logger.Warn(ctx, "X11 connection rejected because single connection is enabled")
78+
_ = conn.Close()
79+
continue
80+
}
81+
handledFirstConnection = true
82+
83+
unixConn, ok := conn.(*net.UnixConn)
84+
if !ok {
85+
s.logger.Warn(ctx, fmt.Sprintf("failed to cast connection to UnixConn. got: %T", conn))
86+
return
87+
}
88+
unixAddr, ok := unixConn.LocalAddr().(*net.UnixAddr)
89+
if !ok {
90+
s.logger.Warn(ctx, fmt.Sprintf("failed to cast local address to UnixAddr. got: %T", unixConn.LocalAddr()))
91+
return
92+
}
93+
94+
channel, reqs, err := serverConn.OpenChannel("x11", gossh.Marshal(struct {
95+
OriginatorAddress string
96+
OriginatorPort uint32
97+
}{
98+
OriginatorAddress: unixAddr.Name,
99+
OriginatorPort: 0,
100+
}))
101+
if err != nil {
102+
s.logger.Warn(ctx, "failed to open X11 channel", slog.Error(err))
103+
return
104+
}
105+
go gossh.DiscardRequests(reqs)
106+
go Bicopy(ctx, conn, channel)
107+
}
108+
}()
109+
return true
110+
}
111+
112+
// addXauthEntry adds an Xauthority entry to the Xauthority file.
113+
// The Xauthority file is located at ~/.Xauthority.
114+
func addXauthEntry(ctx context.Context, fs afero.Fs, host string, display string, authProtocol string, authCookie string) error {
115+
// Get the Xauthority file path
116+
homeDir, err := os.UserHomeDir()
117+
if err != nil {
118+
return xerrors.Errorf("failed to get user home directory: %w", err)
119+
}
120+
121+
xauthPath := filepath.Join(homeDir, ".Xauthority")
122+
123+
lock := flock.New(xauthPath)
124+
defer lock.Close()
125+
ok, err := lock.TryLockContext(ctx, 100*time.Millisecond)
126+
if !ok {
127+
return xerrors.Errorf("failed to lock Xauthority file: %w", err)
128+
}
129+
if err != nil {
130+
return xerrors.Errorf("failed to lock Xauthority file: %w", err)
131+
}
132+
133+
// Open or create the Xauthority file
134+
file, err := fs.OpenFile(xauthPath, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0o600)
135+
if err != nil {
136+
return xerrors.Errorf("failed to open Xauthority file: %w", err)
137+
}
138+
defer file.Close()
139+
140+
// Convert the authCookie from hex string to byte slice
141+
authCookieBytes, err := hex.DecodeString(authCookie)
142+
if err != nil {
143+
return xerrors.Errorf("failed to decode auth cookie: %w", err)
144+
}
145+
146+
// Write Xauthority entry
147+
family := uint16(0x0100) // FamilyLocal
148+
err = binary.Write(file, binary.BigEndian, family)
149+
if err != nil {
150+
return xerrors.Errorf("failed to write family: %w", err)
151+
}
152+
153+
err = binary.Write(file, binary.BigEndian, uint16(len(host)))
154+
if err != nil {
155+
return xerrors.Errorf("failed to write host length: %w", err)
156+
}
157+
_, err = file.WriteString(host)
158+
if err != nil {
159+
return xerrors.Errorf("failed to write host: %w", err)
160+
}
161+
162+
err = binary.Write(file, binary.BigEndian, uint16(len(display)))
163+
if err != nil {
164+
return xerrors.Errorf("failed to write display length: %w", err)
165+
}
166+
_, err = file.WriteString(display)
167+
if err != nil {
168+
return xerrors.Errorf("failed to write display: %w", err)
169+
}
170+
171+
err = binary.Write(file, binary.BigEndian, uint16(len(authProtocol)))
172+
if err != nil {
173+
return xerrors.Errorf("failed to write auth protocol length: %w", err)
174+
}
175+
_, err = file.WriteString(authProtocol)
176+
if err != nil {
177+
return xerrors.Errorf("failed to write auth protocol: %w", err)
178+
}
179+
180+
err = binary.Write(file, binary.BigEndian, uint16(len(authCookieBytes)))
181+
if err != nil {
182+
return xerrors.Errorf("failed to write auth cookie length: %w", err)
183+
}
184+
_, err = file.Write(authCookieBytes)
185+
if err != nil {
186+
return xerrors.Errorf("failed to write auth cookie: %w", err)
187+
}
188+
189+
return nil
190+
}

0 commit comments

Comments
 (0)