@@ -50,6 +50,7 @@ type Server struct {
50
50
mu sync.RWMutex // Protects following.
51
51
listeners map [net.Listener ]struct {}
52
52
conns map [net.Conn ]struct {}
53
+ sessions map [ssh.Session ]struct {}
53
54
closing chan struct {}
54
55
// Wait for goroutines to exit, waited without
55
56
// a lock on mu but protected by closing.
@@ -86,6 +87,7 @@ func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration
86
87
s := & Server {
87
88
listeners : make (map [net.Listener ]struct {}),
88
89
conns : make (map [net.Conn ]struct {}),
90
+ sessions : make (map [ssh.Session ]struct {}),
89
91
logger : logger ,
90
92
}
91
93
@@ -129,7 +131,7 @@ func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration
129
131
}
130
132
},
131
133
SubsystemHandlers : map [string ]ssh.SubsystemHandler {
132
- "sftp" : s .sftpHandler ,
134
+ "sftp" : s .sessionHandler ,
133
135
},
134
136
MaxTimeout : maxTimeout ,
135
137
}
@@ -152,7 +154,26 @@ func (s *Server) ConnStats() ConnStats {
152
154
}
153
155
154
156
func (s * Server ) sessionHandler (session ssh.Session ) {
157
+ if ! s .trackSession (session , true ) {
158
+ // See (*Server).Close() for why we call Close instead of Exit.
159
+ _ = session .Close ()
160
+ return
161
+ }
162
+ defer s .trackSession (session , false )
163
+
155
164
ctx := session .Context ()
165
+
166
+ switch ss := session .Subsystem (); ss {
167
+ case "" :
168
+ case "sftp" :
169
+ s .sftpHandler (session )
170
+ return
171
+ default :
172
+ s .logger .Debug (ctx , "unsupported subsystem" , slog .F ("subsystem" , ss ))
173
+ _ = session .Exit (1 )
174
+ return
175
+ }
176
+
156
177
err := s .sessionStart (session )
157
178
var exitError * exec.ExitError
158
179
if xerrors .As (err , & exitError ) {
@@ -560,6 +581,25 @@ func (s *Server) trackConn(l net.Listener, c net.Conn, add bool) (ok bool) {
560
581
return true
561
582
}
562
583
584
+ // trackSession registers the session with the server. If the server is
585
+ // closing, the session is not registered and should be closed.
586
+ //
587
+ //nolint:revive
588
+ func (s * Server ) trackSession (ss ssh.Session , add bool ) (ok bool ) {
589
+ s .mu .Lock ()
590
+ defer s .mu .Unlock ()
591
+ if add {
592
+ if s .closing != nil {
593
+ // Server closed.
594
+ return false
595
+ }
596
+ s .sessions [ss ] = struct {}{}
597
+ return true
598
+ }
599
+ delete (s .sessions , ss )
600
+ return true
601
+ }
602
+
563
603
// Close the server and all active connections. Server can be re-used
564
604
// after Close is done.
565
605
func (s * Server ) Close () error {
@@ -573,6 +613,15 @@ func (s *Server) Close() error {
573
613
}
574
614
s .closing = make (chan struct {})
575
615
616
+ // Close all active sessions to gracefully
617
+ // terminate client connections.
618
+ for ss := range s .sessions {
619
+ // We call Close on the underlying channel here because we don't
620
+ // want to send an exit status to the client (via Exit()).
621
+ // Typically OpenSSH clients will return 255 as the exit status.
622
+ _ = ss .Close ()
623
+ }
624
+
576
625
// Close all active listeners and connections.
577
626
for l := range s .listeners {
578
627
_ = l .Close ()
0 commit comments