Skip to content

Commit 271135b

Browse files
committed
feat: add support for X11 forwarding
1 parent 2b9d128 commit 271135b

File tree

7 files changed

+283
-11
lines changed

7 files changed

+283
-11
lines changed

agent/agent.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ type agent struct {
161161
}
162162

163163
func (a *agent) init(ctx context.Context) {
164-
sshSrv, err := agentssh.NewServer(ctx, a.logger.Named("ssh-server"), a.sshMaxTimeout)
164+
sshSrv, err := agentssh.NewServer(ctx, a.logger.Named("ssh-server"), a.filesystem, a.sshMaxTimeout, "")
165165
if err != nil {
166166
panic(err)
167167
}

agent/agentssh/agentssh.go

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020

2121
"github.com/gliderlabs/ssh"
2222
"github.com/pkg/sftp"
23+
"github.com/spf13/afero"
2324
"go.uber.org/atomic"
2425
gossh "golang.org/x/crypto/ssh"
2526
"golang.org/x/xerrors"
@@ -56,8 +57,9 @@ type Server struct {
5657
// a lock on mu but protected by closing.
5758
wg sync.WaitGroup
5859

59-
logger slog.Logger
60-
srv *ssh.Server
60+
logger slog.Logger
61+
srv *ssh.Server
62+
x11SocketDir string
6163

6264
Env map[string]string
6365
AgentToken func() string
@@ -68,7 +70,7 @@ type Server struct {
6870
connCountSSHSession atomic.Int64
6971
}
7072

71-
func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration) (*Server, error) {
73+
func NewServer(ctx context.Context, logger slog.Logger, fs afero.Fs, maxTimeout time.Duration, x11SocketDir string) (*Server, error) {
7274
// Clients' should ignore the host key when connecting.
7375
// The agent needs to authenticate with coderd to SSH,
7476
// so SSH authentication doesn't improve security.
@@ -80,15 +82,23 @@ func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration
8082
if err != nil {
8183
return nil, err
8284
}
85+
if x11SocketDir == "" {
86+
x11SocketDir = filepath.Join(os.TempDir(), ".X11-unix")
87+
}
88+
err = fs.MkdirAll(x11SocketDir, 0700)
89+
if err != nil {
90+
return nil, err
91+
}
8392

8493
forwardHandler := &ssh.ForwardedTCPHandler{}
8594
unixForwardHandler := &forwardedUnixHandler{log: logger}
8695

8796
s := &Server{
88-
listeners: make(map[net.Listener]struct{}),
89-
conns: make(map[net.Conn]struct{}),
90-
sessions: make(map[ssh.Session]struct{}),
91-
logger: logger,
97+
listeners: make(map[net.Listener]struct{}),
98+
conns: make(map[net.Conn]struct{}),
99+
sessions: make(map[ssh.Session]struct{}),
100+
logger: logger,
101+
x11SocketDir: x11SocketDir,
92102
}
93103

94104
s.srv = &ssh.Server{
@@ -125,6 +135,9 @@ func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration
125135
"streamlocal-forward@openssh.com": unixForwardHandler.HandleSSHRequest,
126136
"cancel-streamlocal-forward@openssh.com": unixForwardHandler.HandleSSHRequest,
127137
},
138+
X11Callback: func(ctx ssh.Context, x11 ssh.X11) bool {
139+
return x11Callback(logger, fs, ctx, x11)
140+
},
128141
ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig {
129142
return &gossh.ServerConfig{
130143
NoClientAuth: true,
@@ -163,6 +176,15 @@ func (s *Server) sessionHandler(session ssh.Session) {
163176

164177
ctx := session.Context()
165178

179+
x11, hasX11 := session.X11()
180+
if hasX11 {
181+
handled := s.x11Handler(session.Context(), x11)
182+
if !handled {
183+
_ = session.Exit(1)
184+
return
185+
}
186+
}
187+
166188
switch ss := session.Subsystem(); ss {
167189
case "":
168190
case "sftp":

agent/agentssh/agentssh_test.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"sync"
1111
"testing"
1212

13+
"github.com/spf13/afero"
1314
"github.com/stretchr/testify/assert"
1415
"github.com/stretchr/testify/require"
1516
"go.uber.org/atomic"
@@ -32,7 +33,7 @@ func TestNewServer_ServeClient(t *testing.T) {
3233

3334
ctx := context.Background()
3435
logger := slogtest.Make(t, nil)
35-
s, err := agentssh.NewServer(ctx, logger, 0)
36+
s, err := agentssh.NewServer(ctx, logger, afero.NewMemMapFs(), 0, "")
3637
require.NoError(t, err)
3738

3839
// The assumption is that these are set before serving SSH connections.
@@ -50,6 +51,7 @@ func TestNewServer_ServeClient(t *testing.T) {
5051
}()
5152

5253
c := sshClient(t, ln.Addr().String())
54+
5355
var b bytes.Buffer
5456
sess, err := c.NewSession()
5557
sess.Stdout = &b
@@ -72,7 +74,7 @@ func TestNewServer_CloseActiveConnections(t *testing.T) {
7274

7375
ctx := context.Background()
7476
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
75-
s, err := agentssh.NewServer(ctx, logger, 0)
77+
s, err := agentssh.NewServer(ctx, logger, afero.NewMemMapFs(), 0, "")
7678
require.NoError(t, err)
7779

7880
// The assumption is that these are set before serving SSH connections.

agent/agentssh/x11.go

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
package agentssh
2+
3+
import (
4+
"encoding/binary"
5+
"encoding/hex"
6+
"errors"
7+
"fmt"
8+
"net"
9+
"os"
10+
"path/filepath"
11+
"strconv"
12+
13+
"github.com/gliderlabs/ssh"
14+
"github.com/spf13/afero"
15+
gossh "golang.org/x/crypto/ssh"
16+
"golang.org/x/xerrors"
17+
18+
"cdr.dev/slog"
19+
)
20+
21+
// x11Callback is called when the client requests X11 forwarding.
22+
// It adds an Xauthority entry to the Xauthority file.
23+
func x11Callback(logger slog.Logger, fs afero.Fs, ctx ssh.Context, x11 ssh.X11) bool {
24+
hostname, err := os.Hostname()
25+
if err != nil {
26+
logger.Warn(ctx, "failed to get hostname", slog.Error(err))
27+
return false
28+
}
29+
30+
err = addXauthEntry(fs, hostname, strconv.Itoa(int(x11.ScreenNumber)), x11.AuthProtocol, x11.AuthCookie)
31+
if err != nil {
32+
logger.Warn(ctx, "failed to add Xauthority entry", slog.Error(err))
33+
return false
34+
}
35+
return true
36+
}
37+
38+
// x11Handler is called when a session has requested X11 forwarding.
39+
// It listens for X11 connections and forwards them to the client.
40+
func (s *Server) x11Handler(ctx ssh.Context, x11 ssh.X11) bool {
41+
serverConn, valid := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn)
42+
if !valid {
43+
s.logger.Warn(ctx, "failed to get server connection")
44+
return false
45+
}
46+
listener, err := net.Listen("unix", filepath.Join(s.x11SocketDir, fmt.Sprintf("X%d", x11.ScreenNumber)))
47+
if err != nil {
48+
s.logger.Warn(ctx, "failed to listen for X11", slog.Error(err))
49+
return false
50+
}
51+
s.trackListener(listener, true)
52+
53+
go func() {
54+
defer s.trackListener(listener, false)
55+
56+
for {
57+
conn, err := listener.Accept()
58+
if err != nil {
59+
if errors.Is(err, net.ErrClosed) {
60+
return
61+
}
62+
s.logger.Warn(ctx, "failed to accept X11 connection", slog.Error(err))
63+
return
64+
}
65+
unixConn, ok := conn.(*net.UnixConn)
66+
if !ok {
67+
s.logger.Warn(ctx, "failed to cast connection to UnixConn")
68+
return
69+
}
70+
unixAddr, ok := unixConn.LocalAddr().(*net.UnixAddr)
71+
if !ok {
72+
s.logger.Warn(ctx, "failed to cast local address to UnixAddr")
73+
return
74+
}
75+
76+
channel, _, err := serverConn.OpenChannel("x11", gossh.Marshal(struct {
77+
OriginatorAddress string
78+
OriginatorPort uint32
79+
}{
80+
OriginatorAddress: unixAddr.Name,
81+
OriginatorPort: 0,
82+
}))
83+
if err != nil {
84+
s.logger.Warn(ctx, "failed to open X11 channel", slog.Error(err))
85+
return
86+
}
87+
88+
go Bicopy(ctx, conn, channel)
89+
}
90+
}()
91+
return true
92+
}
93+
94+
// addXauthEntry adds an Xauthority entry to the Xauthority file.
95+
// The Xauthority file is located at ~/.Xauthority.
96+
func addXauthEntry(fs afero.Fs, host string, display string, authProtocol string, authCookie string) error {
97+
// Get the Xauthority file path
98+
homeDir, err := os.UserHomeDir()
99+
if err != nil {
100+
return xerrors.Errorf("failed to get user home directory: %w", err)
101+
}
102+
103+
xauthPath := filepath.Join(homeDir, ".Xauthority")
104+
105+
// Open or create the Xauthority file
106+
file, err := fs.OpenFile(xauthPath, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600)
107+
if err != nil {
108+
return xerrors.Errorf("failed to open Xauthority file: %w", err)
109+
}
110+
defer file.Close()
111+
112+
// Convert the authCookie from hex string to byte slice
113+
authCookieBytes, err := hex.DecodeString(authCookie)
114+
if err != nil {
115+
return xerrors.Errorf("failed to decode auth cookie: %w", err)
116+
}
117+
118+
// Write Xauthority entry
119+
family := uint16(0x0100) // FamilyLocal
120+
err = binary.Write(file, binary.BigEndian, family)
121+
if err != nil {
122+
return xerrors.Errorf("failed to write family: %w", err)
123+
}
124+
125+
err = binary.Write(file, binary.BigEndian, uint16(len(host)))
126+
if err != nil {
127+
return xerrors.Errorf("failed to write host length: %w", err)
128+
}
129+
_, err = file.WriteString(host)
130+
if err != nil {
131+
return xerrors.Errorf("failed to write host: %w", err)
132+
}
133+
134+
err = binary.Write(file, binary.BigEndian, uint16(len(display)))
135+
if err != nil {
136+
return xerrors.Errorf("failed to write display length: %w", err)
137+
}
138+
_, err = file.WriteString(display)
139+
if err != nil {
140+
return xerrors.Errorf("failed to write display: %w", err)
141+
}
142+
143+
err = binary.Write(file, binary.BigEndian, uint16(len(authProtocol)))
144+
if err != nil {
145+
return xerrors.Errorf("failed to write auth protocol length: %w", err)
146+
}
147+
_, err = file.WriteString(authProtocol)
148+
if err != nil {
149+
return xerrors.Errorf("failed to write auth protocol: %w", err)
150+
}
151+
152+
err = binary.Write(file, binary.BigEndian, uint16(len(authCookieBytes)))
153+
if err != nil {
154+
return xerrors.Errorf("failed to write auth cookie length: %w", err)
155+
}
156+
_, err = file.Write(authCookieBytes)
157+
if err != nil {
158+
return xerrors.Errorf("failed to write auth cookie: %w", err)
159+
}
160+
161+
return nil
162+
}

agent/agentssh/x11_test.go

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
package agentssh_test
2+
3+
import (
4+
"context"
5+
"encoding/hex"
6+
"net"
7+
"path/filepath"
8+
"testing"
9+
10+
"github.com/gliderlabs/ssh"
11+
"github.com/spf13/afero"
12+
"github.com/stretchr/testify/assert"
13+
"github.com/stretchr/testify/require"
14+
"go.uber.org/atomic"
15+
gossh "golang.org/x/crypto/ssh"
16+
17+
"cdr.dev/slog"
18+
"cdr.dev/slog/sloggers/slogtest"
19+
"github.com/coder/coder/agent/agentssh"
20+
"github.com/coder/coder/codersdk/agentsdk"
21+
"github.com/coder/coder/testutil"
22+
)
23+
24+
func TestServer_X11(t *testing.T) {
25+
t.Parallel()
26+
27+
ctx := context.Background()
28+
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
29+
fs := afero.NewOsFs()
30+
dir := t.TempDir()
31+
s, err := agentssh.NewServer(ctx, logger, fs, 0, dir)
32+
require.NoError(t, err)
33+
defer s.Close()
34+
35+
// The assumption is that these are set before serving SSH connections.
36+
s.AgentToken = func() string { return "" }
37+
s.Manifest = atomic.NewPointer(&agentsdk.Manifest{})
38+
39+
ln, err := net.Listen("tcp", "127.0.0.1:0")
40+
require.NoError(t, err)
41+
42+
done := make(chan struct{})
43+
go func() {
44+
defer close(done)
45+
err := s.Serve(ln)
46+
assert.Error(t, err) // Server is closed.
47+
}()
48+
49+
c := sshClient(t, ln.Addr().String())
50+
51+
sess, err := c.NewSession()
52+
require.NoError(t, err)
53+
54+
reply, err := sess.SendRequest("x11-req", true, gossh.Marshal(ssh.X11{
55+
AuthProtocol: "MIT-MAGIC-COOKIE-1",
56+
AuthCookie: hex.EncodeToString([]byte("cookie")),
57+
ScreenNumber: 0,
58+
}))
59+
require.NoError(t, err)
60+
assert.True(t, reply)
61+
62+
err = sess.Shell()
63+
require.NoError(t, err)
64+
65+
x11Chans := c.HandleChannelOpen("x11")
66+
require.Eventually(t, func() bool {
67+
conn, err := net.Dial("unix", filepath.Join(dir, "X0"))
68+
if err == nil {
69+
_ = conn.Close()
70+
}
71+
return err == nil
72+
}, testutil.WaitShort, testutil.IntervalFast)
73+
74+
x11 := <-x11Chans
75+
ch, reqs, err := x11.Accept()
76+
require.NoError(t, err)
77+
go gossh.DiscardRequests(reqs)
78+
err = ch.Close()
79+
require.NoError(t, err)
80+
s.Close()
81+
<-done
82+
}

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ replace tailscale.com => github.com/coder/tailscale v1.1.1-0.20230418202606-ed93
4545
// repo as tailscale.com/tempfork/gliderlabs/ssh, however, we can't replace the
4646
// subpath and it includes changes to golang.org/x/crypto/ssh as well which
4747
// makes importing it directly a bit messy.
48-
replace github.com/gliderlabs/ssh => github.com/coder/ssh v0.0.0-20220811105153-fcea99919338
48+
replace github.com/gliderlabs/ssh => github.com/coder/ssh v0.0.0-20230419180646-49c741437b53
4949

5050
// Waiting on https://github.com/imulab/go-scim/pull/95 to merge.
5151
replace github.com/imulab/go-scim/pkg/v2 => github.com/coder/go-scim/pkg/v2 v2.0.0-20230221055123-1d63c1222136

go.sum

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,10 @@ github.com/coder/retry v1.3.1-0.20230210155434-e90a2e1e091d h1:09JG37IgTB6n3ouX9
380380
github.com/coder/retry v1.3.1-0.20230210155434-e90a2e1e091d/go.mod h1:r+1J5i/989wt6CUeNSuvFKKA9hHuKKPMxdzDbTuvwwk=
381381
github.com/coder/ssh v0.0.0-20220811105153-fcea99919338 h1:tN5GKFT68YLVzJoA8AHuiMNJ0qlhoD3pGN3JY9gxSko=
382382
github.com/coder/ssh v0.0.0-20220811105153-fcea99919338/go.mod h1:ZSS+CUoKHDrqVakTfTWUlKSr9MtMFkC4UvtQKD7O914=
383+
github.com/coder/ssh v0.0.0-20230419175457-0612ba535202 h1:1I/Im5ZUan1Y9ypAr6VuAKQ4NbvEy/frR3cV86pKQk8=
384+
github.com/coder/ssh v0.0.0-20230419175457-0612ba535202/go.mod h1:ZSS+CUoKHDrqVakTfTWUlKSr9MtMFkC4UvtQKD7O914=
385+
github.com/coder/ssh v0.0.0-20230419180646-49c741437b53 h1:kaLOp3tlVnbOJIjmAvXuBTgeWWoZZlJJJ4QGeSMjOnA=
386+
github.com/coder/ssh v0.0.0-20230419180646-49c741437b53/go.mod h1:ZSS+CUoKHDrqVakTfTWUlKSr9MtMFkC4UvtQKD7O914=
383387
github.com/coder/tailscale v1.1.1-0.20230418202606-ed9307cf1b22 h1:bvGOqnI0ITbwOZFQ0SZ4MBw/8LLUEjxmNu57XEujrfQ=
384388
github.com/coder/tailscale v1.1.1-0.20230418202606-ed9307cf1b22/go.mod h1:jpg+77g19FpXL43U1VoIqoSg1K/Vh5CVxycGldQ8KhA=
385389
github.com/coder/terraform-provider-coder v0.6.23 h1:O2Rcj0umez4DfVdGnKZi63z1Xzxd0IQOn9VQDB8YU8g=

0 commit comments

Comments
 (0)