@@ -2,11 +2,14 @@ package agentssh
2
2
3
3
import (
4
4
"context"
5
+ "errors"
5
6
"fmt"
7
+ "io/fs"
6
8
"net"
7
9
"os"
8
10
"path/filepath"
9
11
"sync"
12
+ "syscall"
10
13
11
14
"github.com/gliderlabs/ssh"
12
15
gossh "golang.org/x/crypto/ssh"
@@ -33,22 +36,29 @@ type forwardedStreamLocalPayload struct {
33
36
type forwardedUnixHandler struct {
34
37
sync.Mutex
35
38
log slog.Logger
36
- forwards map [string ]net.Listener
39
+ forwards map [forwardKey ]net.Listener
40
+ }
41
+
42
+ type forwardKey struct {
43
+ sessionID string
44
+ addr string
45
+ }
46
+
47
+ func newForwardedUnixHandler (log slog.Logger ) * forwardedUnixHandler {
48
+ return & forwardedUnixHandler {
49
+ log : log ,
50
+ forwards : make (map [forwardKey ]net.Listener ),
51
+ }
37
52
}
38
53
39
54
func (h * forwardedUnixHandler ) HandleSSHRequest (ctx ssh.Context , _ * ssh.Server , req * gossh.Request ) (bool , []byte ) {
40
55
h .log .Debug (ctx , "handling SSH unix forward" )
41
- h .Lock ()
42
- if h .forwards == nil {
43
- h .forwards = make (map [string ]net.Listener )
44
- }
45
- h .Unlock ()
46
56
conn , ok := ctx .Value (ssh .ContextKeyConn ).(* gossh.ServerConn )
47
57
if ! ok {
48
58
h .log .Warn (ctx , "SSH unix forward request from client with no gossh connection" )
49
59
return false , nil
50
60
}
51
- log := h .log .With (slog .F ("remote_addr" , conn .RemoteAddr ()))
61
+ log := h .log .With (slog .F ("session_id" , ctx . SessionID ()), slog . F ( " remote_addr" , conn .RemoteAddr ()))
52
62
53
63
switch req .Type {
54
64
case "streamlocal-forward@openssh.com" :
@@ -62,13 +72,17 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
62
72
addr := reqPayload .SocketPath
63
73
log = log .With (slog .F ("socket_path" , addr ))
64
74
log .Debug (ctx , "request begin SSH unix forward" )
75
+
76
+ key := forwardKey {
77
+ sessionID : ctx .SessionID (),
78
+ addr : addr ,
79
+ }
80
+
65
81
h .Lock ()
66
- _ , ok := h .forwards [addr ]
82
+ _ , ok := h .forwards [key ]
67
83
h .Unlock ()
68
84
if ok {
69
- log .Warn (ctx , "SSH unix forward request for socket path that is already being forwarded (maybe to another client?)" ,
70
- slog .F ("socket_path" , addr ),
71
- )
85
+ log .Warn (ctx , "SSH unix forward request for socket path that is already being forwarded on this session" )
72
86
return false , nil
73
87
}
74
88
@@ -83,12 +97,18 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
83
97
return false , nil
84
98
}
85
99
86
- ln , err := net .Listen ("unix" , addr )
100
+ // Remove existing socket if it exists. It's possible we will overwrite
101
+ // a regular file here, but this matches the behavior of OpenSSH.
102
+ err = unlink (addr )
103
+ if err != nil && ! errors .Is (err , fs .ErrNotExist ) {
104
+ log .Warn (ctx , "remove existing socket for SSH unix forward request" , slog .Error (err ))
105
+ return false , nil
106
+ }
107
+
108
+ lc := & net.ListenConfig {}
109
+ ln , err := lc .Listen (ctx , "unix" , addr )
87
110
if err != nil {
88
- log .Warn (ctx , "listen on Unix socket for SSH unix forward request" ,
89
- slog .F ("socket_path" , addr ),
90
- slog .Error (err ),
91
- )
111
+ log .Warn (ctx , "listen on Unix socket for SSH unix forward request" , slog .Error (err ))
92
112
return false , nil
93
113
}
94
114
log .Debug (ctx , "SSH unix forward listening on socket" )
@@ -99,7 +119,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
99
119
//
100
120
// This is also what the upstream TCP version of this code does.
101
121
h .Lock ()
102
- h .forwards [addr ] = ln
122
+ h .forwards [key ] = ln
103
123
h .Unlock ()
104
124
log .Debug (ctx , "SSH unix forward added to cache" )
105
125
@@ -115,9 +135,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
115
135
c , err := ln .Accept ()
116
136
if err != nil {
117
137
if ! xerrors .Is (err , net .ErrClosed ) {
118
- log .Warn (ctx , "accept on local Unix socket for SSH unix forward request" ,
119
- slog .Error (err ),
120
- )
138
+ log .Warn (ctx , "accept on local Unix socket for SSH unix forward request" , slog .Error (err ))
121
139
}
122
140
// closed below
123
141
log .Debug (ctx , "SSH unix forward listener closed" )
@@ -131,10 +149,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
131
149
go func () {
132
150
ch , reqs , err := conn .OpenChannel ("forwarded-streamlocal@openssh.com" , payload )
133
151
if err != nil {
134
- h .log .Warn (ctx , "open SSH unix forward channel to client" ,
135
- slog .F ("socket_path" , addr ),
136
- slog .Error (err ),
137
- )
152
+ h .log .Warn (ctx , "open SSH unix forward channel to client" , slog .Error (err ))
138
153
_ = c .Close ()
139
154
return
140
155
}
@@ -144,12 +159,11 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
144
159
}
145
160
146
161
h .Lock ()
147
- ln2 , ok := h .forwards [addr ]
148
- if ok && ln2 == ln {
149
- delete (h .forwards , addr )
162
+ if ln2 , ok := h .forwards [key ]; ok && ln2 == ln {
163
+ delete (h .forwards , key )
150
164
}
151
165
h .Unlock ()
152
- log .Debug (ctx , "SSH unix forward listener removed from cache" , slog . F ( "path" , addr ) )
166
+ log .Debug (ctx , "SSH unix forward listener removed from cache" )
153
167
_ = ln .Close ()
154
168
}()
155
169
@@ -162,13 +176,23 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
162
176
h .log .Warn (ctx , "parse cancel-streamlocal-forward@openssh.com (SSH unix forward) request payload from client" , slog .Error (err ))
163
177
return false , nil
164
178
}
165
- log .Debug (ctx , "request to cancel SSH unix forward" , slog .F ("path" , reqPayload .SocketPath ))
179
+ log .Debug (ctx , "request to cancel SSH unix forward" , slog .F ("socket_path" , reqPayload .SocketPath ))
180
+
181
+ key := forwardKey {
182
+ sessionID : ctx .SessionID (),
183
+ addr : reqPayload .SocketPath ,
184
+ }
185
+
166
186
h .Lock ()
167
- ln , ok := h .forwards [reqPayload .SocketPath ]
187
+ ln , ok := h .forwards [key ]
188
+ delete (h .forwards , key )
168
189
h .Unlock ()
169
- if ok {
170
- _ = ln .Close ()
190
+ if ! ok {
191
+ log .Warn (ctx , "SSH unix forward not found in cache" )
192
+ return true , nil
171
193
}
194
+ log .Debug (ctx , "SSH unix forward listener removed from cache" )
195
+ _ = ln .Close ()
172
196
return true , nil
173
197
174
198
default :
@@ -209,3 +233,14 @@ func directStreamLocalHandler(_ *ssh.Server, _ *gossh.ServerConn, newChan gossh.
209
233
210
234
Bicopy (ctx , ch , dconn )
211
235
}
236
+
237
+ // unlink removes files only.
238
+ func unlink (path string ) error {
239
+ // From os/file_posix.go:
240
+ for {
241
+ err := syscall .Unlink (path )
242
+ if ! errors .Is (err , syscall .EINTR ) {
243
+ return err
244
+ }
245
+ }
246
+ }
0 commit comments