Skip to content

fix(agent): ensure SSH server shutdown with process groups #17227

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -1773,15 +1773,22 @@ func (a *agent) Close() error {
a.setLifecycle(codersdk.WorkspaceAgentLifecycleShuttingDown)

// Attempt to gracefully shut down all active SSH connections and
// stop accepting new ones.
err := a.sshServer.Shutdown(a.hardCtx)
// stop accepting new ones. If all processes have not exited after 5
// seconds, we just log it and move on as it's more important to run
// the shutdown scripts. A typical shutdown time for containers is
// 10 seconds, so this still leaves a bit of time to run the
// shutdown scripts in the worst-case.
sshShutdownCtx, sshShutdownCancel := context.WithTimeout(a.hardCtx, 5*time.Second)
defer sshShutdownCancel()
err := a.sshServer.Shutdown(sshShutdownCtx)
if err != nil {
a.logger.Error(a.hardCtx, "ssh server shutdown", slog.Error(err))
}
err = a.sshServer.Close()
if err != nil {
a.logger.Error(a.hardCtx, "ssh server close", slog.Error(err))
if errors.Is(err, context.DeadlineExceeded) {
a.logger.Warn(sshShutdownCtx, "ssh server shutdown timeout", slog.Error(err))
} else {
a.logger.Error(sshShutdownCtx, "ssh server shutdown", slog.Error(err))
}
}

// wait for SSH to shut down before the general graceful cancel, because
// this triggers a disconnect in the tailnet layer, telling all clients to
// shut down their wireguard tunnels to us. If SSH sessions are still up,
Expand Down
66 changes: 53 additions & 13 deletions agent/agentssh/agentssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,12 @@ func (s *Server) sessionStart(logger slog.Logger, session ssh.Session, env []str
func (s *Server) startNonPTYSession(logger slog.Logger, session ssh.Session, magicTypeLabel string, cmd *exec.Cmd) error {
s.metrics.sessionsTotal.WithLabelValues(magicTypeLabel, "no").Add(1)

// Create a process group and send SIGHUP to child processes,
// otherwise context cancellation will not propagate properly
// and SSH server close may be delayed.
cmd.SysProcAttr = cmdSysProcAttr()
cmd.Cancel = cmdCancel(session.Context(), logger, cmd)

cmd.Stdout = session
cmd.Stderr = session.Stderr()
// This blocks forever until stdin is received if we don't
Expand Down Expand Up @@ -926,7 +932,12 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string,
// Serve starts the server to handle incoming connections on the provided listener.
// It returns an error if no host keys are set or if there is an issue accepting connections.
func (s *Server) Serve(l net.Listener) (retErr error) {
if len(s.srv.HostSigners) == 0 {
// Ensure we're not mutating HostSigners as we're reading it.
s.mu.RLock()
noHostKeys := len(s.srv.HostSigners) == 0
s.mu.RUnlock()

if noHostKeys {
return xerrors.New("no host keys set")
}

Expand Down Expand Up @@ -1054,43 +1065,72 @@ func (s *Server) Close() error {
}
s.closing = make(chan struct{})

ctx := context.Background()

s.logger.Debug(ctx, "closing server")

// Stop accepting new connections.
s.logger.Debug(ctx, "closing all active listeners", slog.F("count", len(s.listeners)))
for l := range s.listeners {
_ = l.Close()
}

// Close all active sessions to gracefully
// terminate client connections.
s.logger.Debug(ctx, "closing all active sessions", slog.F("count", len(s.sessions)))
for ss := range s.sessions {
// We call Close on the underlying channel here because we don't
// want to send an exit status to the client (via Exit()).
// Typically OpenSSH clients will return 255 as the exit status.
_ = ss.Close()
}

// Close all active listeners and connections.
for l := range s.listeners {
_ = l.Close()
}
s.logger.Debug(ctx, "closing all active connections", slog.F("count", len(s.conns)))
for c := range s.conns {
_ = c.Close()
}

// Close the underlying SSH server.
s.logger.Debug(ctx, "closing SSH server")
err := s.srv.Close()

s.mu.Unlock()

s.logger.Debug(ctx, "waiting for all goroutines to exit")
s.wg.Wait() // Wait for all goroutines to exit.

s.mu.Lock()
close(s.closing)
s.closing = nil
s.mu.Unlock()

s.logger.Debug(ctx, "closing server done")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: log elapsed time for closing


return err
}

// Shutdown gracefully closes all active SSH connections and stops
// accepting new connections.
//
// Shutdown is not implemented.
func (*Server) Shutdown(_ context.Context) error {
// TODO(mafredri): Implement shutdown, SIGHUP running commands, etc.
// Shutdown stops accepting new connections. The current implementation
// calls Close() for simplicity instead of waiting for existing
// connections to close. If the context times out, Shutdown will return
// but Close() may not have completed.
func (s *Server) Shutdown(ctx context.Context) error {
ch := make(chan error, 1)
go func() {
// TODO(mafredri): Implement shutdown, SIGHUP running commands, etc.
// For now we just close the server.
ch <- s.Close()
}()
var err error
select {
case <-ctx.Done():
err = ctx.Err()
case err = <-ch:
}
// Re-check for context cancellation precedence.
if ctx.Err() != nil {
err = ctx.Err()
}
if err != nil {
return xerrors.Errorf("close server: %w", err)
}
return nil
}

Expand Down
116 changes: 79 additions & 37 deletions agent/agentssh/agentssh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"go.uber.org/goleak"
"golang.org/x/crypto/ssh"

"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"

"github.com/coder/coder/v2/agent/agentexec"
Expand Down Expand Up @@ -147,51 +148,92 @@ func (*fakeEnvInfoer) ModifyCommand(cmd string, args ...string) (string, []strin
func TestNewServer_CloseActiveConnections(t *testing.T) {
t.Parallel()

ctx := context.Background()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
require.NoError(t, err)
defer s.Close()
err = s.UpdateHostSigner(42)
assert.NoError(t, err)
prepare := func(ctx context.Context, t *testing.T) (*agentssh.Server, func()) {
t.Helper()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
require.NoError(t, err)
defer s.Close()
err = s.UpdateHostSigner(42)
assert.NoError(t, err)

ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)

var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
err := s.Serve(ln)
assert.Error(t, err) // Server is closed.
}()
waitConns := make([]chan struct{}, 4)

pty := ptytest.New(t)
var wg sync.WaitGroup
wg.Add(1 + len(waitConns))

doClose := make(chan struct{})
go func() {
defer wg.Done()
c := sshClient(t, ln.Addr().String())
sess, err := c.NewSession()
assert.NoError(t, err)
sess.Stdin = pty.Input()
sess.Stdout = pty.Output()
sess.Stderr = pty.Output()
go func() {
defer wg.Done()
err := s.Serve(ln)
assert.Error(t, err) // Server is closed.
}()

assert.NoError(t, err)
err = sess.Start("")
assert.NoError(t, err)
for i := 0; i < len(waitConns); i++ {
waitConns[i] = make(chan struct{})
go func(ch chan struct{}) {
defer wg.Done()
c := sshClient(t, ln.Addr().String())
sess, err := c.NewSession()
assert.NoError(t, err)
pty := ptytest.New(t)
sess.Stdin = pty.Input()
sess.Stdout = pty.Output()
sess.Stderr = pty.Output()

// Every other session will request a PTY.
if i%2 == 0 {
err = sess.RequestPty("xterm", 80, 80, nil)
assert.NoError(t, err)
}
// The 60 seconds here is intended to be longer than the
// test. The shutdown should propagate.
err = sess.Start("/bin/bash -c 'trap \"sleep 60\" SIGTERM; sleep 60'")
assert.NoError(t, err)

close(ch)
err = sess.Wait()
assert.Error(t, err)
}(waitConns[i])
}

close(doClose)
err = sess.Wait()
assert.Error(t, err)
}()
for _, ch := range waitConns {
<-ch
}

<-doClose
err = s.Close()
require.NoError(t, err)
return s, wg.Wait
}

t.Run("Close", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
s, wait := prepare(ctx, t)
err := s.Close()
require.NoError(t, err)
wait()
})

wg.Wait()
t.Run("Shutdown", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
s, wait := prepare(ctx, t)
err := s.Shutdown(ctx)
require.NoError(t, err)
wait()
})

t.Run("Shutdown Early", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
s, wait := prepare(ctx, t)
ctx, cancel := context.WithCancel(ctx)
cancel()
err := s.Shutdown(ctx)
require.ErrorIs(t, err, context.Canceled)
wait()
})
}

func TestNewServer_Signal(t *testing.T) {
Expand Down
24 changes: 24 additions & 0 deletions agent/agentssh/exec_other.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
//go:build !windows
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Review: These are copy-pasta from agentscripts package.


package agentssh

import (
"context"
"os/exec"
"syscall"

"cdr.dev/slog"
)

func cmdSysProcAttr() *syscall.SysProcAttr {
return &syscall.SysProcAttr{
Setsid: true,
}
}

func cmdCancel(ctx context.Context, logger slog.Logger, cmd *exec.Cmd) func() error {
return func() error {
logger.Debug(ctx, "cmdCancel: sending SIGHUP to process and children", slog.F("pid", cmd.Process.Pid))
return syscall.Kill(-cmd.Process.Pid, syscall.SIGHUP)
}
}
21 changes: 21 additions & 0 deletions agent/agentssh/exec_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package agentssh

import (
"context"
"os"
"os/exec"
"syscall"

"cdr.dev/slog"
)

func cmdSysProcAttr() *syscall.SysProcAttr {
return &syscall.SysProcAttr{}
}

func cmdCancel(ctx context.Context, logger slog.Logger, cmd *exec.Cmd) func() error {
return func() error {
logger.Debug(ctx, "cmdCancel: sending interrupt to process", slog.F("pid", cmd.Process.Pid))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

📋🍝 suggestion: check that cmd.Process != nil as it could have exited.

return cmd.Process.Signal(os.Interrupt)
}
}
Loading