7
7
"errors"
8
8
"fmt"
9
9
"io"
10
+ "math"
10
11
"net"
11
12
"os"
12
13
"path/filepath"
@@ -22,102 +23,136 @@ import (
22
23
"cdr.dev/slog"
23
24
)
24
25
26
+ const (
27
+ // X11StartPort is the starting port for X11 forwarding, this is the
28
+ // port used for "DISPLAY=localhost:0".
29
+ X11StartPort = 6000
30
+ // X11DefaultDisplayOffset is the default offset for X11 forwarding.
31
+ X11DefaultDisplayOffset = 10
32
+ )
33
+
25
34
// x11Callback is called when the client requests X11 forwarding.
26
- // It adds an Xauthority entry to the Xauthority file.
27
- func (s * Server ) x11Callback (ctx ssh.Context , x11 ssh.X11 ) bool {
35
+ func (* Server ) x11Callback (_ ssh.Context , _ ssh.X11 ) bool {
36
+ // Always allow.
37
+ return true
38
+ }
39
+
40
+ // x11Handler is called when a session has requested X11 forwarding.
41
+ // It listens for X11 connections and forwards them to the client.
42
+ func (s * Server ) x11Handler (ctx ssh.Context , x11 ssh.X11 ) (displayNumber int , handled bool ) {
43
+ serverConn , valid := ctx .Value (ssh .ContextKeyConn ).(* gossh.ServerConn )
44
+ if ! valid {
45
+ s .logger .Warn (ctx , "failed to get server connection" )
46
+ return - 1 , false
47
+ }
48
+
28
49
hostname , err := os .Hostname ()
29
50
if err != nil {
30
51
s .logger .Warn (ctx , "failed to get hostname" , slog .Error (err ))
31
52
s .metrics .x11HandlerErrors .WithLabelValues ("hostname" ).Add (1 )
32
- return false
53
+ return - 1 , false
33
54
}
34
55
35
- err = s . fs . MkdirAll ( s .config .X11SocketDir , 0o700 )
56
+ ln , display , err := createX11Listener ( ctx , * s .config .X11DisplayOffset )
36
57
if err != nil {
37
- s .logger .Warn (ctx , "failed to make the x11 socket dir" , slog .F ("dir" , s .config .X11SocketDir ), slog .Error (err ))
38
- s .metrics .x11HandlerErrors .WithLabelValues ("socker_dir" ).Add (1 )
39
- return false
40
- }
58
+ s .logger .Warn (ctx , "failed to create X11 listener" , slog .Error (err ))
59
+ s .metrics .x11HandlerErrors .WithLabelValues ("listen" ).Add (1 )
60
+ return - 1 , false
61
+ }
62
+ s .trackListener (ln , true )
63
+ defer func () {
64
+ if ! handled {
65
+ s .trackListener (ln , false )
66
+ _ = ln .Close ()
67
+ }
68
+ }()
41
69
42
- err = addXauthEntry (ctx , s .fs , hostname , strconv .Itoa (int ( x11 . ScreenNumber ) ), x11 .AuthProtocol , x11 .AuthCookie )
70
+ err = addXauthEntry (ctx , s .fs , hostname , strconv .Itoa (display ), x11 .AuthProtocol , x11 .AuthCookie )
43
71
if err != nil {
44
72
s .logger .Warn (ctx , "failed to add Xauthority entry" , slog .Error (err ))
45
73
s .metrics .x11HandlerErrors .WithLabelValues ("xauthority" ).Add (1 )
46
- return false
74
+ return - 1 , false
47
75
}
48
- return true
49
- }
50
76
51
- // x11Handler is called when a session has requested X11 forwarding.
52
- // It listens for X11 connections and forwards them to the client.
53
- func (s * Server ) x11Handler (ctx ssh.Context , x11 ssh.X11 ) bool {
54
- serverConn , valid := ctx .Value (ssh .ContextKeyConn ).(* gossh.ServerConn )
55
- if ! valid {
56
- s .logger .Warn (ctx , "failed to get server connection" )
57
- return false
58
- }
59
- // We want to overwrite the socket so that subsequent connections will succeed.
60
- socketPath := filepath .Join (s .config .X11SocketDir , fmt .Sprintf ("X%d" , x11 .ScreenNumber ))
61
- err := os .Remove (socketPath )
62
- if err != nil && ! errors .Is (err , os .ErrNotExist ) {
63
- s .logger .Warn (ctx , "failed to remove existing X11 socket" , slog .Error (err ))
64
- return false
65
- }
66
- listener , err := net .Listen ("unix" , socketPath )
67
- if err != nil {
68
- s .logger .Warn (ctx , "failed to listen for X11" , slog .Error (err ))
69
- return false
70
- }
71
- s .trackListener (listener , true )
77
+ go func () {
78
+ // Don't leave the listener open after the session is gone.
79
+ <- ctx .Done ()
80
+ _ = ln .Close ()
81
+ }()
72
82
73
83
go func () {
74
- defer listener .Close ()
75
- defer s .trackListener (listener , false )
76
- handledFirstConnection := false
84
+ defer ln .Close ()
85
+ defer s .trackListener (ln , false )
77
86
78
87
for {
79
- conn , err := listener .Accept ()
88
+ conn , err := ln .Accept ()
80
89
if err != nil {
81
90
if errors .Is (err , net .ErrClosed ) {
82
91
return
83
92
}
84
93
s .logger .Warn (ctx , "failed to accept X11 connection" , slog .Error (err ))
85
94
return
86
95
}
87
- if x11 .SingleConnection && handledFirstConnection {
88
- s .logger .Warn (ctx , "X11 connection rejected because single connection is enabled" )
89
- _ = conn .Close ()
90
- continue
96
+ if x11 .SingleConnection {
97
+ s .logger .Debug (ctx , "single connection requested, closing X11 listener" )
98
+ _ = ln .Close ()
91
99
}
92
- handledFirstConnection = true
93
100
94
- unixConn , ok := conn .(* net.UnixConn )
101
+ tcpConn , ok := conn .(* net.TCPConn )
95
102
if ! ok {
96
- s .logger .Warn (ctx , fmt .Sprintf ("failed to cast connection to UnixConn. got: %T" , conn ))
97
- return
103
+ s .logger .Warn (ctx , fmt .Sprintf ("failed to cast connection to TCPConn. got: %T" , conn ))
104
+ _ = conn .Close ()
105
+ continue
98
106
}
99
- unixAddr , ok := unixConn .LocalAddr ().(* net.UnixAddr )
107
+ tcpAddr , ok := tcpConn .LocalAddr ().(* net.TCPAddr )
100
108
if ! ok {
101
- s .logger .Warn (ctx , fmt .Sprintf ("failed to cast local address to UnixAddr. got: %T" , unixConn .LocalAddr ()))
102
- return
109
+ s .logger .Warn (ctx , fmt .Sprintf ("failed to cast local address to TCPAddr. got: %T" , tcpConn .LocalAddr ()))
110
+ _ = conn .Close ()
111
+ continue
103
112
}
104
113
105
114
channel , reqs , err := serverConn .OpenChannel ("x11" , gossh .Marshal (struct {
106
115
OriginatorAddress string
107
116
OriginatorPort uint32
108
117
}{
109
- OriginatorAddress : unixAddr . Name ,
110
- OriginatorPort : 0 ,
118
+ OriginatorAddress : tcpAddr . IP . String () ,
119
+ OriginatorPort : uint32 ( tcpAddr . Port ) ,
111
120
}))
112
121
if err != nil {
113
122
s .logger .Warn (ctx , "failed to open X11 channel" , slog .Error (err ))
114
- return
123
+ _ = conn .Close ()
124
+ continue
115
125
}
116
126
go gossh .DiscardRequests (reqs )
117
- go Bicopy (ctx , conn , channel )
127
+
128
+ if ! s .trackConn (ln , conn , true ) {
129
+ s .logger .Warn (ctx , "failed to track X11 connection" )
130
+ _ = conn .Close ()
131
+ continue
132
+ }
133
+ go func () {
134
+ defer s .trackConn (ln , conn , false )
135
+ Bicopy (ctx , conn , channel )
136
+ }()
118
137
}
119
138
}()
120
- return true
139
+
140
+ return display , true
141
+ }
142
+
143
+ // createX11Listener creates a listener for X11 forwarding, it will use
144
+ // the next available port starting from X11StartPort and displayOffset.
145
+ func createX11Listener (ctx context.Context , displayOffset int ) (ln net.Listener , display int , err error ) {
146
+ var lc net.ListenConfig
147
+ // Look for an open port to listen on.
148
+ for port := X11StartPort + displayOffset ; port < math .MaxUint16 ; port ++ {
149
+ ln , err = lc .Listen (ctx , "tcp" , fmt .Sprintf ("localhost:%d" , port ))
150
+ if err == nil {
151
+ display = port - X11StartPort
152
+ return ln , display , nil
153
+ }
154
+ }
155
+ return nil , - 1 , xerrors .Errorf ("failed to find open port for X11 listener: %w" , err )
121
156
}
122
157
123
158
// addXauthEntry adds an Xauthority entry to the Xauthority file.
0 commit comments