Skip to content

fix(agent): ensure SSH server shutdown with process groups #17227

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 3, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -1773,15 +1773,20 @@ func (a *agent) Close() error {
a.setLifecycle(codersdk.WorkspaceAgentLifecycleShuttingDown)

// Attempt to gracefully shut down all active SSH connections and
// stop accepting new ones.
err := a.sshServer.Shutdown(a.hardCtx)
// stop accepting new ones. If all processes have not exited after
// 10 seconds, we just log it and move on as it's more important
// to run the shutdown scripts.
sshShutdownCtx, sshShutdownCancel := context.WithTimeout(a.hardCtx, 10*time.Second)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can foresee someone wanting to adjust this timeout. Maybe add a CLI option for it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that if we do get requests to make this configurable, it will help drive the implementation. I'm not certain this approach is the end-goal we want and I would like to avoid exposing functionality that may limit future implementations.

When we start solving #6175, changing this behavior may be relevant.

I'll go ahead and reduce this to 5 seconds as well as the default docker grace timeout is 10s I believe.

The original motivation behind closing SSH before running shutdown scripts is to ensure that commands doing work exit before potentially doing backups of state or something along those lines. I think giving those commands a grace of 5 seconds is plenty in most cases.

defer sshShutdownCancel()
err := a.sshServer.Shutdown(sshShutdownCtx)
if err != nil {
a.logger.Error(a.hardCtx, "ssh server shutdown", slog.Error(err))
}
err = a.sshServer.Close()
if err != nil {
a.logger.Error(a.hardCtx, "ssh server close", slog.Error(err))
if errors.Is(err, context.DeadlineExceeded) {
a.logger.Warn(sshShutdownCtx, "ssh server shutdown timeout", slog.Error(err))
} else {
a.logger.Error(sshShutdownCtx, "ssh server shutdown", slog.Error(err))
}
}

// wait for SSH to shut down before the general graceful cancel, because
// this triggers a disconnect in the tailnet layer, telling all clients to
// shut down their wireguard tunnels to us. If SSH sessions are still up,
Expand Down
64 changes: 53 additions & 11 deletions agent/agentssh/agentssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,12 @@ func (s *Server) sessionStart(logger slog.Logger, session ssh.Session, env []str
func (s *Server) startNonPTYSession(logger slog.Logger, session ssh.Session, magicTypeLabel string, cmd *exec.Cmd) error {
s.metrics.sessionsTotal.WithLabelValues(magicTypeLabel, "no").Add(1)

// Create a process group and send SIGHUP to child processes,
// otherwise context cancellation will not propagate properly
// and SSH server close may be delayed.
cmd.SysProcAttr = cmdSysProcAttr()
cmd.Cancel = cmdCancel(session.Context(), logger, cmd)

cmd.Stdout = session
cmd.Stderr = session.Stderr()
// This blocks forever until stdin is received if we don't
Expand Down Expand Up @@ -926,7 +932,12 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string,
// Serve starts the server to handle incoming connections on the provided listener.
// It returns an error if no host keys are set or if there is an issue accepting connections.
func (s *Server) Serve(l net.Listener) (retErr error) {
if len(s.srv.HostSigners) == 0 {
// Ensure we're not mutating HostSigners as we're reading it.
s.mu.RLock()
noHostKeys := len(s.srv.HostSigners) == 0
s.mu.RUnlock()

if noHostKeys {
return xerrors.New("no host keys set")
}

Expand Down Expand Up @@ -1044,6 +1055,11 @@ func (s *Server) trackSession(ss ssh.Session, add bool) (ok bool) {
// Close the server and all active connections. Server can be re-used
// after Close is done.
func (s *Server) Close() error {
return s.close(context.Background())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this extra method with the context argument? I don't see it used elsewhere.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Originally I intended to use s.close in Shutdown too, but I oopsied that change. I'll revert this though since ctx is only used for logger and we don't even utilize that feature often.

}

//nolint:revive // Ignore the similarity of close and Close.
func (s *Server) close(ctx context.Context) error {
s.mu.Lock()

// Guard against multiple calls to Close and
Expand All @@ -1054,24 +1070,29 @@ func (s *Server) Close() error {
}
s.closing = make(chan struct{})

s.logger.Debug(ctx, "closing server")

// Stop accepting new connections.
s.logger.Debug(ctx, "closing all active listeners")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: log the number of active listeners

for l := range s.listeners {
_ = l.Close()
}

// Close all active sessions to gracefully
// terminate client connections.
s.logger.Debug(ctx, "closing all active sessions")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: log the number of active sessions

for ss := range s.sessions {
// We call Close on the underlying channel here because we don't
// want to send an exit status to the client (via Exit()).
// Typically OpenSSH clients will return 255 as the exit status.
_ = ss.Close()
}

// Close all active listeners and connections.
for l := range s.listeners {
_ = l.Close()
}
s.logger.Debug(ctx, "closing all active connections")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: log the number of active conns

for c := range s.conns {
_ = c.Close()
}

// Close the underlying SSH server.
s.logger.Debug(ctx, "closing SSH server")
err := s.srv.Close()

s.mu.Unlock()
Expand All @@ -1082,15 +1103,36 @@ func (s *Server) Close() error {
s.closing = nil
s.mu.Unlock()

s.logger.Debug(ctx, "closing server done")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: log elapsed time for closing


return err
}

// Shutdown gracefully closes all active SSH connections and stops
// Shutdown ~~gracefully~~ closes all active SSH connections and stops
// accepting new connections.
//
// Shutdown is not implemented.
func (*Server) Shutdown(_ context.Context) error {
// TODO(mafredri): Implement shutdown, SIGHUP running commands, etc.
// For now, simply calls Close and allows early return via context
// cancellation.
func (s *Server) Shutdown(ctx context.Context) error {
ch := make(chan error, 1)
go func() {
// TODO(mafredri): Implement shutdown, SIGHUP running commands, etc.
// For now we just close the server.
ch <- s.Close()
}()
var err error
select {
case <-ctx.Done():
err = ctx.Err()
case err = <-ch:
}
// Re-check for context cancellation precedence.
if ctx.Err() != nil {
err = ctx.Err()
}
if err != nil {
return xerrors.Errorf("close server: %w", err)
}
return nil
}

Expand Down
116 changes: 79 additions & 37 deletions agent/agentssh/agentssh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"go.uber.org/goleak"
"golang.org/x/crypto/ssh"

"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"

"github.com/coder/coder/v2/agent/agentexec"
Expand Down Expand Up @@ -147,51 +148,92 @@ func (*fakeEnvInfoer) ModifyCommand(cmd string, args ...string) (string, []strin
func TestNewServer_CloseActiveConnections(t *testing.T) {
t.Parallel()

ctx := context.Background()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
require.NoError(t, err)
defer s.Close()
err = s.UpdateHostSigner(42)
assert.NoError(t, err)
prepare := func(ctx context.Context, t *testing.T) (*agentssh.Server, func()) {
t.Helper()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
require.NoError(t, err)
defer s.Close()
err = s.UpdateHostSigner(42)
assert.NoError(t, err)

ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)

var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
err := s.Serve(ln)
assert.Error(t, err) // Server is closed.
}()
waitConns := make([]chan struct{}, 4)

pty := ptytest.New(t)
var wg sync.WaitGroup
wg.Add(1 + len(waitConns))

doClose := make(chan struct{})
go func() {
defer wg.Done()
c := sshClient(t, ln.Addr().String())
sess, err := c.NewSession()
assert.NoError(t, err)
sess.Stdin = pty.Input()
sess.Stdout = pty.Output()
sess.Stderr = pty.Output()
go func() {
defer wg.Done()
err := s.Serve(ln)
assert.Error(t, err) // Server is closed.
}()

assert.NoError(t, err)
err = sess.Start("")
assert.NoError(t, err)
for i := 0; i < len(waitConns); i++ {
waitConns[i] = make(chan struct{})
go func(ch chan struct{}) {
defer wg.Done()
c := sshClient(t, ln.Addr().String())
sess, err := c.NewSession()
assert.NoError(t, err)
pty := ptytest.New(t)
sess.Stdin = pty.Input()
sess.Stdout = pty.Output()
sess.Stderr = pty.Output()

// Every other session will request a PTY.
if i%2 == 0 {
err = sess.RequestPty("xterm", 80, 80, nil)
assert.NoError(t, err)
}
// The 60 seconds here is intended to be longer than the
// test. The shutdown should propagate.
err = sess.Start("/bin/bash -c 'trap \"sleep 60\" SIGTERM; sleep 60'")
assert.NoError(t, err)

close(ch)
err = sess.Wait()
assert.Error(t, err)
}(waitConns[i])
}

close(doClose)
err = sess.Wait()
assert.Error(t, err)
}()
for _, ch := range waitConns {
<-ch
}

<-doClose
err = s.Close()
require.NoError(t, err)
return s, wg.Wait
}

t.Run("Close", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
s, wait := prepare(ctx, t)
err := s.Close()
require.NoError(t, err)
wait()
})

wg.Wait()
t.Run("Shutdown", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
s, wait := prepare(ctx, t)
err := s.Shutdown(ctx)
require.NoError(t, err)
wait()
})

t.Run("Shutdown Early", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
s, wait := prepare(ctx, t)
ctx, cancel := context.WithCancel(ctx)
cancel()
err := s.Shutdown(ctx)
require.ErrorIs(t, err, context.Canceled)
wait()
})
}

func TestNewServer_Signal(t *testing.T) {
Expand Down
24 changes: 24 additions & 0 deletions agent/agentssh/exec_other.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
//go:build !windows
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Review: These are copy-pasta from agentscripts package.


package agentssh

import (
"context"
"os/exec"
"syscall"

"cdr.dev/slog"
)

// cmdSysProcAttr returns process attributes that start the command in
// its own session (and therefore its own process group) via Setsid,
// allowing cmdCancel to signal the entire group at once.
func cmdSysProcAttr() *syscall.SysProcAttr {
	attr := syscall.SysProcAttr{Setsid: true}
	return &attr
}

// cmdCancel returns a cancel function for cmd that sends SIGHUP to the
// command's entire process group (created via Setsid in cmdSysProcAttr),
// ensuring child processes are signaled as well.
func cmdCancel(ctx context.Context, logger slog.Logger, cmd *exec.Cmd) func() error {
	return func() error {
		// Guard against a nil Process: exec.Cmd only invokes Cancel
		// after a successful Start, but this func could be called
		// directly before the process was started.
		if cmd.Process == nil {
			logger.Debug(ctx, "cmdCancel: no process to signal")
			return nil
		}
		logger.Debug(ctx, "cmdCancel: sending SIGHUP to process and children", slog.F("pid", cmd.Process.Pid))
		// A negative PID signals the whole process group.
		return syscall.Kill(-cmd.Process.Pid, syscall.SIGHUP)
	}
}
21 changes: 21 additions & 0 deletions agent/agentssh/exec_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package agentssh

import (
"context"
"os"
"os/exec"
"syscall"

"cdr.dev/slog"
)

// cmdSysProcAttr returns the platform-specific attributes for starting
// commands. Windows has no session/process-group equivalent of Setsid,
// so no special attributes are set here.
func cmdSysProcAttr() *syscall.SysProcAttr {
	return &syscall.SysProcAttr{}
}

func cmdCancel(ctx context.Context, logger slog.Logger, cmd *exec.Cmd) func() error {
return func() error {
logger.Debug(ctx, "cmdCancel: sending interrupt to process", slog.F("pid", cmd.Process.Pid))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

📋🍝 suggestion: check that cmd.Process != nil as it could have exited.

return cmd.Process.Signal(os.Interrupt)
}
}
Loading