@@ -2,11 +2,14 @@ package agentssh
2
2
3
3
import (
4
4
"context"
5
+ "errors"
5
6
"fmt"
7
+ "io/fs"
6
8
"net"
7
9
"os"
8
10
"path/filepath"
9
11
"sync"
12
+ "syscall"
10
13
11
14
"github.com/gliderlabs/ssh"
12
15
gossh "golang.org/x/crypto/ssh"
@@ -33,22 +36,29 @@ type forwardedStreamLocalPayload struct {
33
36
type forwardedUnixHandler struct {
34
37
sync.Mutex
35
38
log slog.Logger
36
- forwards map [string ]net.Listener
39
+ forwards map [forwardKey ]net.Listener
40
+ }
41
+
42
+ type forwardKey struct {
43
+ sessionID string
44
+ addr string
45
+ }
46
+
47
+ func newForwardedUnixHandler (log slog.Logger ) * forwardedUnixHandler {
48
+ return & forwardedUnixHandler {
49
+ log : log ,
50
+ forwards : make (map [forwardKey ]net.Listener ),
51
+ }
37
52
}
38
53
39
54
func (h * forwardedUnixHandler ) HandleSSHRequest (ctx ssh.Context , _ * ssh.Server , req * gossh.Request ) (bool , []byte ) {
40
55
h .log .Debug (ctx , "handling SSH unix forward" )
41
- h .Lock ()
42
- if h .forwards == nil {
43
- h .forwards = make (map [string ]net.Listener )
44
- }
45
- h .Unlock ()
46
56
conn , ok := ctx .Value (ssh .ContextKeyConn ).(* gossh.ServerConn )
47
57
if ! ok {
48
58
h .log .Warn (ctx , "SSH unix forward request from client with no gossh connection" )
49
59
return false , nil
50
60
}
51
- log := h .log .With (slog .F ("remote_addr" , conn .RemoteAddr ()))
61
+ log := h .log .With (slog .F ("session_id" , ctx . SessionID ()), slog . F ( " remote_addr" , conn .RemoteAddr ()))
52
62
53
63
switch req .Type {
54
64
case "streamlocal-forward@openssh.com" :
@@ -62,14 +72,22 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
62
72
addr := reqPayload .SocketPath
63
73
log = log .With (slog .F ("socket_path" , addr ))
64
74
log .Debug (ctx , "request begin SSH unix forward" )
75
+
76
+ key := forwardKey {
77
+ sessionID : ctx .SessionID (),
78
+ addr : addr ,
79
+ }
80
+
65
81
h .Lock ()
66
- _ , ok := h .forwards [addr ]
82
+ _ , ok := h .forwards [key ]
67
83
h .Unlock ()
68
84
if ok {
69
- log .Warn (ctx , "SSH unix forward request for socket path that is already being forwarded (maybe to another client?)" ,
70
- slog .F ("socket_path" , addr ),
71
- )
72
- return false , nil
85
+ // In cases where `ExitOnForwardFailure=yes` is set, returning false
86
+ // here will cause the connection to be closed. To avoid this, and
87
+ // to match OpenSSH behavior, we silently ignore the second forward
88
+ // request.
89
+ log .Warn (ctx , "SSH unix forward request for socket path that is already being forwarded on this session, ignoring" )
90
+ return true , nil
73
91
}
74
92
75
93
// Create socket parent dir if not exists.
@@ -83,12 +101,20 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
83
101
return false , nil
84
102
}
85
103
86
- ln , err := net .Listen ("unix" , addr )
104
+ // Remove existing socket if it exists. We do not use os.Remove() here
105
+ // so that directories are kept. Note that it's possible that we will
106
+ // overwrite a regular file here. Both of these behaviors match OpenSSH,
107
+ // however, which is why we unlink.
108
+ err = unlink (addr )
109
+ if err != nil && ! errors .Is (err , fs .ErrNotExist ) {
110
+ log .Warn (ctx , "remove existing socket for SSH unix forward request" , slog .Error (err ))
111
+ return false , nil
112
+ }
113
+
114
+ lc := & net.ListenConfig {}
115
+ ln , err := lc .Listen (ctx , "unix" , addr )
87
116
if err != nil {
88
- log .Warn (ctx , "listen on Unix socket for SSH unix forward request" ,
89
- slog .F ("socket_path" , addr ),
90
- slog .Error (err ),
91
- )
117
+ log .Warn (ctx , "listen on Unix socket for SSH unix forward request" , slog .Error (err ))
92
118
return false , nil
93
119
}
94
120
log .Debug (ctx , "SSH unix forward listening on socket" )
@@ -99,7 +125,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
99
125
//
100
126
// This is also what the upstream TCP version of this code does.
101
127
h .Lock ()
102
- h .forwards [addr ] = ln
128
+ h .forwards [key ] = ln
103
129
h .Unlock ()
104
130
log .Debug (ctx , "SSH unix forward added to cache" )
105
131
@@ -115,9 +141,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
115
141
c , err := ln .Accept ()
116
142
if err != nil {
117
143
if ! xerrors .Is (err , net .ErrClosed ) {
118
- log .Warn (ctx , "accept on local Unix socket for SSH unix forward request" ,
119
- slog .Error (err ),
120
- )
144
+ log .Warn (ctx , "accept on local Unix socket for SSH unix forward request" , slog .Error (err ))
121
145
}
122
146
// closed below
123
147
log .Debug (ctx , "SSH unix forward listener closed" )
@@ -131,10 +155,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
131
155
go func () {
132
156
ch , reqs , err := conn .OpenChannel ("forwarded-streamlocal@openssh.com" , payload )
133
157
if err != nil {
134
- h .log .Warn (ctx , "open SSH unix forward channel to client" ,
135
- slog .F ("socket_path" , addr ),
136
- slog .Error (err ),
137
- )
158
+ h .log .Warn (ctx , "open SSH unix forward channel to client" , slog .Error (err ))
138
159
_ = c .Close ()
139
160
return
140
161
}
@@ -144,12 +165,11 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
144
165
}
145
166
146
167
h .Lock ()
147
- ln2 , ok := h .forwards [addr ]
148
- if ok && ln2 == ln {
149
- delete (h .forwards , addr )
168
+ if ln2 , ok := h .forwards [key ]; ok && ln2 == ln {
169
+ delete (h .forwards , key )
150
170
}
151
171
h .Unlock ()
152
- log .Debug (ctx , "SSH unix forward listener removed from cache" , slog . F ( "path" , addr ) )
172
+ log .Debug (ctx , "SSH unix forward listener removed from cache" )
153
173
_ = ln .Close ()
154
174
}()
155
175
@@ -162,13 +182,22 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
162
182
h .log .Warn (ctx , "parse cancel-streamlocal-forward@openssh.com (SSH unix forward) request payload from client" , slog .Error (err ))
163
183
return false , nil
164
184
}
165
- log .Debug (ctx , "request to cancel SSH unix forward" , slog .F ("path" , reqPayload .SocketPath ))
185
+ log .Debug (ctx , "request to cancel SSH unix forward" , slog .F ("socket_path" , reqPayload .SocketPath ))
186
+
187
+ key := forwardKey {
188
+ sessionID : ctx .SessionID (),
189
+ addr : reqPayload .SocketPath ,
190
+ }
191
+
166
192
h .Lock ()
167
- ln , ok := h .forwards [reqPayload .SocketPath ]
193
+ ln , ok := h .forwards [key ]
194
+ delete (h .forwards , key )
168
195
h .Unlock ()
169
- if ok {
170
- _ = ln .Close ()
196
+ if ! ok {
197
+ log .Warn (ctx , "SSH unix forward not found in cache" )
198
+ return true , nil
171
199
}
200
+ _ = ln .Close ()
172
201
return true , nil
173
202
174
203
default :
@@ -209,3 +238,15 @@ func directStreamLocalHandler(_ *ssh.Server, _ *gossh.ServerConn, newChan gossh.
209
238
210
239
Bicopy (ctx , ch , dconn )
211
240
}
241
+
242
+ // unlink removes files and unlike os.Remove, directories are kept.
243
+ func unlink (path string ) error {
244
+ // Ignore EINTR like os.Remove, see ignoringEINTR in os/file_posix.go
245
+ // for more details.
246
+ for {
247
+ err := syscall .Unlink (path )
248
+ if ! errors .Is (err , syscall .EINTR ) {
249
+ return err
250
+ }
251
+ }
252
+ }
0 commit comments