diff --git a/agent/agent.go b/agent/agent.go index 4f07eec69db95..eddebc5d6b26d 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -1773,15 +1773,22 @@ func (a *agent) Close() error { a.setLifecycle(codersdk.WorkspaceAgentLifecycleShuttingDown) // Attempt to gracefully shut down all active SSH connections and - // stop accepting new ones. - err := a.sshServer.Shutdown(a.hardCtx) + // stop accepting new ones. If all processes have not exited after 5 + // seconds, we just log it and move on as it's more important to run + // the shutdown scripts. A typical shutdown time for containers is + // 10 seconds, so this still leaves a bit of time to run the + // shutdown scripts in the worst-case. + sshShutdownCtx, sshShutdownCancel := context.WithTimeout(a.hardCtx, 5*time.Second) + defer sshShutdownCancel() + err := a.sshServer.Shutdown(sshShutdownCtx) if err != nil { - a.logger.Error(a.hardCtx, "ssh server shutdown", slog.Error(err)) - } - err = a.sshServer.Close() - if err != nil { - a.logger.Error(a.hardCtx, "ssh server close", slog.Error(err)) + if errors.Is(err, context.DeadlineExceeded) { + a.logger.Warn(sshShutdownCtx, "ssh server shutdown timeout", slog.Error(err)) + } else { + a.logger.Error(sshShutdownCtx, "ssh server shutdown", slog.Error(err)) + } } + // wait for SSH to shut down before the general graceful cancel, because // this triggers a disconnect in the tailnet layer, telling all clients to // shut down their wireguard tunnels to us. If SSH sessions are still up, diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index 473f38c26d64c..f56497d149499 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -582,6 +582,12 @@ func (s *Server) sessionStart(logger slog.Logger, session ssh.Session, env []str func (s *Server) startNonPTYSession(logger slog.Logger, session ssh.Session, magicTypeLabel string, cmd *exec.Cmd) error { s.metrics.sessionsTotal.WithLabelValues(magicTypeLabel, "no").Add(1) + // Create a process group and send SIGHUP to child processes, + // otherwise context cancellation will not propagate properly + // and SSH server close may be delayed. + cmd.SysProcAttr = cmdSysProcAttr() + cmd.Cancel = cmdCancel(session.Context(), logger, cmd) + cmd.Stdout = session cmd.Stderr = session.Stderr() // This blocks forever until stdin is received if we don't @@ -926,7 +932,12 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string, // Serve starts the server to handle incoming connections on the provided listener. // It returns an error if no host keys are set or if there is an issue accepting connections. func (s *Server) Serve(l net.Listener) (retErr error) { - if len(s.srv.HostSigners) == 0 { + // Ensure we're not mutating HostSigners as we're reading it. + s.mu.RLock() + noHostKeys := len(s.srv.HostSigners) == 0 + s.mu.RUnlock() + + if noHostKeys { return xerrors.New("no host keys set") } @@ -1054,27 +1065,36 @@ func (s *Server) Close() error { } s.closing = make(chan struct{}) + ctx := context.Background() + + s.logger.Debug(ctx, "closing server") + + // Stop accepting new connections. + s.logger.Debug(ctx, "closing all active listeners", slog.F("count", len(s.listeners))) + for l := range s.listeners { + _ = l.Close() + } + // Close all active sessions to gracefully // terminate client connections. + s.logger.Debug(ctx, "closing all active sessions", slog.F("count", len(s.sessions))) for ss := range s.sessions { // We call Close on the underlying channel here because we don't // want to send an exit status to the client (via Exit()). // Typically OpenSSH clients will return 255 as the exit status. _ = ss.Close() } - - // Close all active listeners and connections. - for l := range s.listeners { - _ = l.Close() - } + s.logger.Debug(ctx, "closing all active connections", slog.F("count", len(s.conns))) for c := range s.conns { _ = c.Close() } - // Close the underlying SSH server. + s.logger.Debug(ctx, "closing SSH server") err := s.srv.Close() s.mu.Unlock() + + s.logger.Debug(ctx, "waiting for all goroutines to exit") s.wg.Wait() // Wait for all goroutines to exit. s.mu.Lock() @@ -1082,15 +1102,35 @@ func (s *Server) Close() error { s.closing = nil s.mu.Unlock() + s.logger.Debug(ctx, "closing server done") + return err } -// Shutdown gracefully closes all active SSH connections and stops -// accepting new connections. -// -// Shutdown is not implemented. -func (*Server) Shutdown(_ context.Context) error { - // TODO(mafredri): Implement shutdown, SIGHUP running commands, etc. +// Shutdown stops accepting new connections. The current implementation +// calls Close() for simplicity instead of waiting for existing +// connections to close. If the context times out, Shutdown will return +// but Close() may not have completed. +func (s *Server) Shutdown(ctx context.Context) error { + ch := make(chan error, 1) + go func() { + // TODO(mafredri): Implement shutdown, SIGHUP running commands, etc. + // For now we just close the server. + ch <- s.Close() + }() + var err error + select { + case <-ctx.Done(): + err = ctx.Err() + case err = <-ch: + } + // Re-check for context cancellation precedence. + if ctx.Err() != nil { + err = ctx.Err() + } + if err != nil { + return xerrors.Errorf("close server: %w", err) + } return nil } diff --git a/agent/agentssh/agentssh_test.go b/agent/agentssh/agentssh_test.go index 6b0706e95db44..9a427fdd7d91e 100644 --- a/agent/agentssh/agentssh_test.go +++ b/agent/agentssh/agentssh_test.go @@ -21,6 +21,7 @@ import ( "go.uber.org/goleak" "golang.org/x/crypto/ssh" + "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/agent/agentexec" @@ -147,51 +148,92 @@ func (*fakeEnvInfoer) ModifyCommand(cmd string, args ...string) (string, []strin func TestNewServer_CloseActiveConnections(t *testing.T) { t.Parallel() - ctx := context.Background() - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) - s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil) - require.NoError(t, err) - defer s.Close() - err = s.UpdateHostSigner(42) - assert.NoError(t, err) + prepare := func(ctx context.Context, t *testing.T) (*agentssh.Server, func()) { + t.Helper() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil) + require.NoError(t, err) + defer s.Close() + err = s.UpdateHostSigner(42) + assert.NoError(t, err) - ln, err := net.Listen("tcp", "127.0.0.1:0") - require.NoError(t, err) + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) - var wg sync.WaitGroup - wg.Add(2) - go func() { - defer wg.Done() - err := s.Serve(ln) - assert.Error(t, err) // Server is closed. - }() + waitConns := make([]chan struct{}, 4) - pty := ptytest.New(t) + var wg sync.WaitGroup + wg.Add(1 + len(waitConns)) - doClose := make(chan struct{}) - go func() { - defer wg.Done() - c := sshClient(t, ln.Addr().String()) - sess, err := c.NewSession() - assert.NoError(t, err) - sess.Stdin = pty.Input() - sess.Stdout = pty.Output() - sess.Stderr = pty.Output() + go func() { + defer wg.Done() + err := s.Serve(ln) + assert.Error(t, err) // Server is closed. + }() - assert.NoError(t, err) - err = sess.Start("") - assert.NoError(t, err) + for i := 0; i < len(waitConns); i++ { + waitConns[i] = make(chan struct{}) + go func(ch chan struct{}) { + defer wg.Done() + c := sshClient(t, ln.Addr().String()) + sess, err := c.NewSession() + assert.NoError(t, err) + pty := ptytest.New(t) + sess.Stdin = pty.Input() + sess.Stdout = pty.Output() + sess.Stderr = pty.Output() + + // Every other session will request a PTY. + if i%2 == 0 { + err = sess.RequestPty("xterm", 80, 80, nil) + assert.NoError(t, err) + } + // The 60 seconds here is intended to be longer than the + // test. The shutdown should propagate. + err = sess.Start("/bin/bash -c 'trap \"sleep 60\" SIGTERM; sleep 60'") + assert.NoError(t, err) + + close(ch) + err = sess.Wait() + assert.Error(t, err) + }(waitConns[i]) + } - close(doClose) - err = sess.Wait() - assert.Error(t, err) - }() + for _, ch := range waitConns { + <-ch + } - <-doClose - err = s.Close() - require.NoError(t, err) + return s, wg.Wait + } + + t.Run("Close", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitMedium) + s, wait := prepare(ctx, t) + err := s.Close() + require.NoError(t, err) + wait() + }) - wg.Wait() + t.Run("Shutdown", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitMedium) + s, wait := prepare(ctx, t) + err := s.Shutdown(ctx) + require.NoError(t, err) + wait() + }) + + t.Run("Shutdown Early", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitMedium) + s, wait := prepare(ctx, t) + ctx, cancel := context.WithCancel(ctx) + cancel() + err := s.Shutdown(ctx) + require.ErrorIs(t, err, context.Canceled) + wait() + }) } func TestNewServer_Signal(t *testing.T) { diff --git a/agent/agentssh/exec_other.go b/agent/agentssh/exec_other.go new file mode 100644 index 0000000000000..54dfd50899412 --- /dev/null +++ b/agent/agentssh/exec_other.go @@ -0,0 +1,24 @@ +//go:build !windows + +package agentssh + +import ( + "context" + "os/exec" + "syscall" + + "cdr.dev/slog" +) + +func cmdSysProcAttr() *syscall.SysProcAttr { + return &syscall.SysProcAttr{ + Setsid: true, + } +} + +func cmdCancel(ctx context.Context, logger slog.Logger, cmd *exec.Cmd) func() error { + return func() error { + logger.Debug(ctx, "cmdCancel: sending SIGHUP to process and children", slog.F("pid", cmd.Process.Pid)) + return syscall.Kill(-cmd.Process.Pid, syscall.SIGHUP) + } +} diff --git a/agent/agentssh/exec_windows.go b/agent/agentssh/exec_windows.go new file mode 100644 index 0000000000000..0345ddd85e52e --- /dev/null +++ b/agent/agentssh/exec_windows.go @@ -0,0 +1,21 @@ +package agentssh + +import ( + "context" + "os" + "os/exec" + "syscall" + + "cdr.dev/slog" +) + +func cmdSysProcAttr() *syscall.SysProcAttr { + return &syscall.SysProcAttr{} +} + +func cmdCancel(ctx context.Context, logger slog.Logger, cmd *exec.Cmd) func() error { + return func() error { + logger.Debug(ctx, "cmdCancel: sending interrupt to process", slog.F("pid", cmd.Process.Pid)) + return cmd.Process.Signal(os.Interrupt) + } +}