7
7
"errors"
8
8
"fmt"
9
9
"io"
10
+ "math"
10
11
"net"
11
12
"os"
12
13
"path/filepath"
@@ -22,61 +23,72 @@ import (
22
23
"cdr.dev/slog"
23
24
)
24
25
25
- // x11Callback is called when the client requests X11 forwarding.
26
- // It adds an Xauthority entry to the Xauthority file.
27
- func (s * Server ) x11Callback (ctx ssh.Context , x11 ssh.X11 ) bool {
28
- hostname , err := os .Hostname ()
29
- if err != nil {
30
- s .logger .Warn (ctx , "failed to get hostname" , slog .Error (err ))
31
- s .metrics .x11HandlerErrors .WithLabelValues ("hostname" ).Add (1 )
32
- return false
33
- }
34
-
35
- err = s .fs .MkdirAll (s .config .X11SocketDir , 0o700 )
36
- if err != nil {
37
- s .logger .Warn (ctx , "failed to make the x11 socket dir" , slog .F ("dir" , s .config .X11SocketDir ), slog .Error (err ))
38
- s .metrics .x11HandlerErrors .WithLabelValues ("socker_dir" ).Add (1 )
39
- return false
40
- }
26
+ const (
27
+ X11StartPort = 6000
28
+ X11DefaultDisplayOffset = 10
29
+ )
41
30
42
- err = addXauthEntry (ctx , s .fs , hostname , strconv .Itoa (int (x11 .ScreenNumber )), x11 .AuthProtocol , x11 .AuthCookie )
43
- if err != nil {
44
- s .logger .Warn (ctx , "failed to add Xauthority entry" , slog .Error (err ))
45
- s .metrics .x11HandlerErrors .WithLabelValues ("xauthority" ).Add (1 )
46
- return false
47
- }
31
+ // x11Callback is called when the client requests X11 forwarding.
32
+ func (* Server ) x11Callback (_ ssh.Context , _ ssh.X11 ) bool {
33
+ // Always allow.
48
34
return true
49
35
}
50
36
51
37
// x11Handler is called when a session has requested X11 forwarding.
52
38
// It listens for X11 connections and forwards them to the client.
53
- func (s * Server ) x11Handler (ctx ssh.Context , x11 ssh.X11 ) bool {
39
+ func (s * Server ) x11Handler (ctx ssh.Context , x11 ssh.X11 ) ( display int , handled bool ) {
54
40
serverConn , valid := ctx .Value (ssh .ContextKeyConn ).(* gossh.ServerConn )
55
41
if ! valid {
56
42
s .logger .Warn (ctx , "failed to get server connection" )
57
- return false
43
+ return - 1 , false
58
44
}
59
- // We want to overwrite the socket so that subsequent connections will succeed.
60
- socketPath := filepath .Join (s .config .X11SocketDir , fmt .Sprintf ("X%d" , x11 .ScreenNumber ))
61
- err := os .Remove (socketPath )
62
- if err != nil && ! errors .Is (err , os .ErrNotExist ) {
63
- s .logger .Warn (ctx , "failed to remove existing X11 socket" , slog .Error (err ))
64
- return false
65
- }
66
- listener , err := net .Listen ("unix" , socketPath )
45
+
46
+ hostname , err := os .Hostname ()
67
47
if err != nil {
48
+ s .logger .Warn (ctx , "failed to get hostname" , slog .Error (err ))
49
+ s .metrics .x11HandlerErrors .WithLabelValues ("hostname" ).Add (1 )
50
+ return - 1 , false
51
+ }
52
+
53
+ var (
54
+ lc net.ListenConfig
55
+ ln net.Listener
56
+ port = X11StartPort + * s .config .X11DisplayOffset
57
+ )
58
+ for ; port >= 6000 && port < math .MaxUint16 ; port ++ {
59
+ ln , err = lc .Listen (ctx , "tcp" , fmt .Sprintf ("localhost:%d" , port ))
60
+ if err == nil {
61
+ display = port - X11StartPort
62
+ break
63
+ }
64
+ }
65
+ if ln == nil {
68
66
s .logger .Warn (ctx , "failed to listen for X11" , slog .Error (err ))
69
- return false
67
+ s .metrics .x11HandlerErrors .WithLabelValues ("listen" ).Add (1 )
68
+ return - 1 , false
69
+ }
70
+ s .trackListener (ln , true )
71
+ defer func () {
72
+ if ! handled {
73
+ s .trackListener (ln , false )
74
+ _ = ln .Close ()
75
+ }
76
+ }()
77
+
78
+ err = addXauthEntry (ctx , s .fs , hostname , strconv .Itoa (port ), x11 .AuthProtocol , x11 .AuthCookie )
79
+ if err != nil {
80
+ s .logger .Warn (ctx , "failed to add Xauthority entry" , slog .Error (err ))
81
+ s .metrics .x11HandlerErrors .WithLabelValues ("xauthority" ).Add (1 )
82
+ return - 1 , false
70
83
}
71
- s .trackListener (listener , true )
72
84
73
85
go func () {
74
- defer listener .Close ()
75
- defer s .trackListener (listener , false )
86
+ defer ln .Close ()
87
+ defer s .trackListener (ln , false )
76
88
handledFirstConnection := false
77
89
78
90
for {
79
- conn , err := listener .Accept ()
91
+ conn , err := ln .Accept ()
80
92
if err != nil {
81
93
if errors .Is (err , net .ErrClosed ) {
82
94
return
@@ -91,33 +103,37 @@ func (s *Server) x11Handler(ctx ssh.Context, x11 ssh.X11) bool {
91
103
}
92
104
handledFirstConnection = true
93
105
94
- unixConn , ok := conn .(* net.UnixConn )
106
+ tcpConn , ok := conn .(* net.TCPConn )
95
107
if ! ok {
96
- s .logger .Warn (ctx , fmt .Sprintf ("failed to cast connection to UnixConn. got: %T" , conn ))
108
+ s .logger .Warn (ctx , fmt .Sprintf ("failed to cast connection to TCPConn. got: %T" , conn ))
109
+ _ = conn .Close ()
97
110
return
98
111
}
99
- unixAddr , ok := unixConn .LocalAddr ().(* net.UnixAddr )
112
+ tcpAddr , ok := tcpConn .LocalAddr ().(* net.TCPAddr )
100
113
if ! ok {
101
- s .logger .Warn (ctx , fmt .Sprintf ("failed to cast local address to UnixAddr. got: %T" , unixConn .LocalAddr ()))
114
+ s .logger .Warn (ctx , fmt .Sprintf ("failed to cast local address to TCPAddr. got: %T" , tcpConn .LocalAddr ()))
115
+ _ = conn .Close ()
102
116
return
103
117
}
104
118
105
119
channel , reqs , err := serverConn .OpenChannel ("x11" , gossh .Marshal (struct {
106
120
OriginatorAddress string
107
121
OriginatorPort uint32
108
122
}{
109
- OriginatorAddress : unixAddr . Name ,
110
- OriginatorPort : 0 ,
123
+ OriginatorAddress : tcpAddr . IP . String () ,
124
+ OriginatorPort : uint32 ( tcpAddr . Port ) ,
111
125
}))
112
126
if err != nil {
113
127
s .logger .Warn (ctx , "failed to open X11 channel" , slog .Error (err ))
128
+ _ = conn .Close ()
114
129
return
115
130
}
116
131
go gossh .DiscardRequests (reqs )
117
132
go Bicopy (ctx , conn , channel )
118
133
}
119
134
}()
120
- return true
135
+
136
+ return display , true
121
137
}
122
138
123
139
// addXauthEntry adds an Xauthority entry to the Xauthority file.
0 commit comments