Skip to content

Commit aa660e0

Browse files
authored
feat(agentssh): Gracefully close SSH sessions on Close (coder#7027)
By tracking and closing sessions manually before closing the underlying connections, we ensure that the termination is propagated to SSH/SFTP clients and they're not left waiting for a connection timeout. Refs: coder#6177
1 parent f4f40d0 commit aa660e0

File tree

1 file changed

+50
-1
lines changed

1 file changed

+50
-1
lines changed

agent/agentssh/agentssh.go

+50-1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ type Server struct {
5050
mu sync.RWMutex // Protects following.
5151
listeners map[net.Listener]struct{}
5252
conns map[net.Conn]struct{}
53+
sessions map[ssh.Session]struct{}
5354
closing chan struct{}
5455
// Wait for goroutines to exit, waited without
5556
// a lock on mu but protected by closing.
@@ -86,6 +87,7 @@ func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration
8687
s := &Server{
8788
listeners: make(map[net.Listener]struct{}),
8889
conns: make(map[net.Conn]struct{}),
90+
sessions: make(map[ssh.Session]struct{}),
8991
logger: logger,
9092
}
9193

@@ -129,7 +131,7 @@ func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration
129131
}
130132
},
131133
SubsystemHandlers: map[string]ssh.SubsystemHandler{
132-
"sftp": s.sftpHandler,
134+
"sftp": s.sessionHandler,
133135
},
134136
MaxTimeout: maxTimeout,
135137
}
@@ -152,7 +154,26 @@ func (s *Server) ConnStats() ConnStats {
152154
}
153155

154156
func (s *Server) sessionHandler(session ssh.Session) {
157+
if !s.trackSession(session, true) {
158+
// See (*Server).Close() for why we call Close instead of Exit.
159+
_ = session.Close()
160+
return
161+
}
162+
defer s.trackSession(session, false)
163+
155164
ctx := session.Context()
165+
166+
switch ss := session.Subsystem(); ss {
167+
case "":
168+
case "sftp":
169+
s.sftpHandler(session)
170+
return
171+
default:
172+
s.logger.Debug(ctx, "unsupported subsystem", slog.F("subsystem", ss))
173+
_ = session.Exit(1)
174+
return
175+
}
176+
156177
err := s.sessionStart(session)
157178
var exitError *exec.ExitError
158179
if xerrors.As(err, &exitError) {
@@ -560,6 +581,25 @@ func (s *Server) trackConn(l net.Listener, c net.Conn, add bool) (ok bool) {
560581
return true
561582
}
562583

584+
// trackSession registers the session with the server. If the server is
585+
// closing, the session is not registered and should be closed.
586+
//
587+
//nolint:revive
588+
func (s *Server) trackSession(ss ssh.Session, add bool) (ok bool) {
589+
s.mu.Lock()
590+
defer s.mu.Unlock()
591+
if add {
592+
if s.closing != nil {
593+
// Server closed.
594+
return false
595+
}
596+
s.sessions[ss] = struct{}{}
597+
return true
598+
}
599+
delete(s.sessions, ss)
600+
return true
601+
}
602+
563603
// Close the server and all active connections. Server can be re-used
564604
// after Close is done.
565605
func (s *Server) Close() error {
@@ -573,6 +613,15 @@ func (s *Server) Close() error {
573613
}
574614
s.closing = make(chan struct{})
575615

616+
// Close all active sessions to gracefully
617+
// terminate client connections.
618+
for ss := range s.sessions {
619+
// We call Close on the underlying channel here because we don't
620+
// want to send an exit status to the client (via Exit()).
621+
// Typically OpenSSH clients will return 255 as the exit status.
622+
_ = ss.Close()
623+
}
624+
576625
// Close all active listeners and connections.
577626
for l := range s.listeners {
578627
_ = l.Close()

0 commit comments

Comments
 (0)