diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index f56497d149499..293dd4db169ac 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -1060,8 +1060,10 @@ func (s *Server) Close() error { // Guard against multiple calls to Close and // accepting new connections during close. if s.closing != nil { + closing := s.closing s.mu.Unlock() - return xerrors.New("server is closing") + <-closing + return xerrors.New("server is closed") } s.closing = make(chan struct{}) diff --git a/agent/agentssh/agentssh_test.go b/agent/agentssh/agentssh_test.go index 9a427fdd7d91e..69f92e0fd31a0 100644 --- a/agent/agentssh/agentssh_test.go +++ b/agent/agentssh/agentssh_test.go @@ -153,7 +153,9 @@ func TestNewServer_CloseActiveConnections(t *testing.T) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil) require.NoError(t, err) - defer s.Close() + t.Cleanup(func() { + _ = s.Close() + }) err = s.UpdateHostSigner(42) assert.NoError(t, err) @@ -190,10 +192,17 @@ func TestNewServer_CloseActiveConnections(t *testing.T) { } // The 60 seconds here is intended to be longer than the // test. The shutdown should propagate. - err = sess.Start("/bin/bash -c 'trap \"sleep 60\" SIGTERM; sleep 60'") + if runtime.GOOS == "windows" { + // Best effort to at least partially test this in Windows. + err = sess.Start("echo start\"ed\" && sleep 60") + } else { + err = sess.Start("/bin/bash -c 'trap \"sleep 60\" SIGTERM; echo start\"ed\"; sleep 60'") + } assert.NoError(t, err) + pty.ExpectMatchContext(ctx, "started") close(ch) + err = sess.Wait() assert.Error(t, err) }(waitConns[i]) diff --git a/agent/agentssh/exec_windows.go b/agent/agentssh/exec_windows.go index 0345ddd85e52e..39f0f97198479 100644 --- a/agent/agentssh/exec_windows.go +++ b/agent/agentssh/exec_windows.go @@ -2,7 +2,6 @@ package agentssh import ( "context" - "os" "os/exec" "syscall" @@ -15,7 +14,12 @@ func cmdSysProcAttr() *syscall.SysProcAttr { func cmdCancel(ctx context.Context, logger slog.Logger, cmd *exec.Cmd) func() error { return func() error { - logger.Debug(ctx, "cmdCancel: sending interrupt to process", slog.F("pid", cmd.Process.Pid)) - return cmd.Process.Signal(os.Interrupt) + logger.Debug(ctx, "cmdCancel: killing process", slog.F("pid", cmd.Process.Pid)) + // Windows doesn't support sending signals to process groups, so we + // have to kill the process directly. In the future, we may want to + // implement a more sophisticated solution for process groups on + // Windows, but for now, this is a simple way to ensure that the + // process is terminated when the context is cancelled. + return cmd.Process.Kill() } }