Skip to content

Commit 5ad8633

Browse files
committed
Only run X forwarding on Linux
1 parent 271135b commit 5ad8633

File tree

3 files changed

+40
-17
lines changed

3 files changed

+40
-17
lines changed

agent/agentssh/agentssh.go

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ const (
4949

5050
type Server struct {
5151
mu sync.RWMutex // Protects following.
52+
fs afero.Fs
5253
listeners map[net.Listener]struct{}
5354
conns map[net.Conn]struct{}
5455
sessions map[ssh.Session]struct{}
@@ -85,16 +86,13 @@ func NewServer(ctx context.Context, logger slog.Logger, fs afero.Fs, maxTimeout
8586
if x11SocketDir == "" {
8687
x11SocketDir = filepath.Join(os.TempDir(), ".X11-unix")
8788
}
88-
err = fs.MkdirAll(x11SocketDir, 0700)
89-
if err != nil {
90-
return nil, err
91-
}
9289

9390
forwardHandler := &ssh.ForwardedTCPHandler{}
9491
unixForwardHandler := &forwardedUnixHandler{log: logger}
9592

9693
s := &Server{
9794
listeners: make(map[net.Listener]struct{}),
95+
fs: fs,
9896
conns: make(map[net.Conn]struct{}),
9997
sessions: make(map[ssh.Session]struct{}),
10098
logger: logger,
@@ -135,9 +133,7 @@ func NewServer(ctx context.Context, logger slog.Logger, fs afero.Fs, maxTimeout
135133
"streamlocal-forward@openssh.com": unixForwardHandler.HandleSSHRequest,
136134
"cancel-streamlocal-forward@openssh.com": unixForwardHandler.HandleSSHRequest,
137135
},
138-
X11Callback: func(ctx ssh.Context, x11 ssh.X11) bool {
139-
return x11Callback(logger, fs, ctx, x11)
140-
},
136+
X11Callback: s.x11Callback,
141137
ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig {
142138
return &gossh.ServerConfig{
143139
NoClientAuth: true,

agent/agentssh/x11.go

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package agentssh
22

33
import (
4+
"context"
45
"encoding/binary"
56
"encoding/hex"
67
"errors"
@@ -9,8 +10,10 @@ import (
910
"os"
1011
"path/filepath"
1112
"strconv"
13+
"time"
1214

1315
"github.com/gliderlabs/ssh"
16+
"github.com/gofrs/flock"
1417
"github.com/spf13/afero"
1518
gossh "golang.org/x/crypto/ssh"
1619
"golang.org/x/xerrors"
@@ -20,16 +23,22 @@ import (
2023

2124
// x11Callback is called when the client requests X11 forwarding.
2225
// It adds an Xauthority entry to the Xauthority file.
23-
func x11Callback(logger slog.Logger, fs afero.Fs, ctx ssh.Context, x11 ssh.X11) bool {
26+
func (s *Server) x11Callback(ctx ssh.Context, x11 ssh.X11) bool {
2427
hostname, err := os.Hostname()
2528
if err != nil {
26-
logger.Warn(ctx, "failed to get hostname", slog.Error(err))
29+
s.logger.Warn(ctx, "failed to get hostname", slog.Error(err))
2730
return false
2831
}
2932

30-
err = addXauthEntry(fs, hostname, strconv.Itoa(int(x11.ScreenNumber)), x11.AuthProtocol, x11.AuthCookie)
33+
err = s.fs.MkdirAll(s.x11SocketDir, 0o700)
3134
if err != nil {
32-
logger.Warn(ctx, "failed to add Xauthority entry", slog.Error(err))
35+
s.logger.Warn(ctx, "failed to make the x11 socket dir", slog.F("dir", s.x11SocketDir), slog.Error(err))
36+
return false
37+
}
38+
39+
err = addXauthEntry(ctx, s.fs, hostname, strconv.Itoa(int(x11.ScreenNumber)), x11.AuthProtocol, x11.AuthCookie)
40+
if err != nil {
41+
s.logger.Warn(ctx, "failed to add Xauthority entry", slog.Error(err))
3342
return false
3443
}
3544
return true
@@ -64,16 +73,16 @@ func (s *Server) x11Handler(ctx ssh.Context, x11 ssh.X11) bool {
6473
}
6574
unixConn, ok := conn.(*net.UnixConn)
6675
if !ok {
67-
s.logger.Warn(ctx, "failed to cast connection to UnixConn")
76+
s.logger.Warn(ctx, fmt.Sprintf("failed to cast connection to UnixConn. got: %T", conn))
6877
return
6978
}
7079
unixAddr, ok := unixConn.LocalAddr().(*net.UnixAddr)
7180
if !ok {
72-
s.logger.Warn(ctx, "failed to cast local address to UnixAddr")
81+
s.logger.Warn(ctx, fmt.Sprintf("failed to cast local address to UnixAddr. got: %T", unixConn.LocalAddr()))
7382
return
7483
}
7584

76-
channel, _, err := serverConn.OpenChannel("x11", gossh.Marshal(struct {
85+
channel, reqs, err := serverConn.OpenChannel("x11", gossh.Marshal(struct {
7786
OriginatorAddress string
7887
OriginatorPort uint32
7988
}{
@@ -84,7 +93,7 @@ func (s *Server) x11Handler(ctx ssh.Context, x11 ssh.X11) bool {
8493
s.logger.Warn(ctx, "failed to open X11 channel", slog.Error(err))
8594
return
8695
}
87-
96+
go gossh.DiscardRequests(reqs)
8897
go Bicopy(ctx, conn, channel)
8998
}
9099
}()
@@ -93,7 +102,7 @@ func (s *Server) x11Handler(ctx ssh.Context, x11 ssh.X11) bool {
93102

94103
// addXauthEntry adds an Xauthority entry to the Xauthority file.
95104
// The Xauthority file is located at ~/.Xauthority.
96-
func addXauthEntry(fs afero.Fs, host string, display string, authProtocol string, authCookie string) error {
105+
func addXauthEntry(ctx context.Context, fs afero.Fs, host string, display string, authProtocol string, authCookie string) error {
97106
// Get the Xauthority file path
98107
homeDir, err := os.UserHomeDir()
99108
if err != nil {
@@ -102,8 +111,15 @@ func addXauthEntry(fs afero.Fs, host string, display string, authProtocol string
102111

103112
xauthPath := filepath.Join(homeDir, ".Xauthority")
104113

114+
lock := flock.New(xauthPath)
115+
ok, err := lock.TryLockContext(ctx, 100*time.Millisecond)
116+
if !ok {
117+
return xerrors.Errorf("failed to lock Xauthority file: %w", err)
118+
}
119+
defer lock.Close()
120+
105121
// Open or create the Xauthority file
106-
file, err := fs.OpenFile(xauthPath, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600)
122+
file, err := fs.OpenFile(xauthPath, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0o600)
107123
if err != nil {
108124
return xerrors.Errorf("failed to open Xauthority file: %w", err)
109125
}

agent/agentssh/x11_test.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@ import (
44
"context"
55
"encoding/hex"
66
"net"
7+
"os"
78
"path/filepath"
9+
"runtime"
810
"testing"
911

1012
"github.com/gliderlabs/ssh"
@@ -23,6 +25,9 @@ import (
2325

2426
func TestServer_X11(t *testing.T) {
2527
t.Parallel()
28+
if runtime.GOOS != "linux" {
29+
t.Skip("X11 forwarding is only supported on Linux")
30+
}
2631

2732
ctx := context.Background()
2833
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
@@ -79,4 +84,10 @@ func TestServer_X11(t *testing.T) {
7984
require.NoError(t, err)
8085
s.Close()
8186
<-done
87+
88+
// Ensure the Xauthority file was written!
89+
home, err := os.UserHomeDir()
90+
require.NoError(t, err)
91+
_, err = fs.Stat(filepath.Join(home, ".Xauthority"))
92+
require.NoError(t, err)
8293
}

0 commit comments

Comments
 (0)