diff --git a/agent/agent.go b/agent/agent.go index 82ff9442bde3b..2daba701b4e89 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -33,6 +33,7 @@ import ( "tailscale.com/util/clientmetric" "cdr.dev/slog" + "github.com/coder/coder/v2/agent/agentexec" "github.com/coder/coder/v2/agent/agentscripts" "github.com/coder/coder/v2/agent/agentssh" "github.com/coder/coder/v2/agent/proto" @@ -80,6 +81,7 @@ type Options struct { ReportMetadataInterval time.Duration ServiceBannerRefreshInterval time.Duration BlockFileTransfer bool + Execer agentexec.Execer } type Client interface { @@ -139,6 +141,10 @@ func New(options Options) Agent { prometheusRegistry = prometheus.NewRegistry() } + if options.Execer == nil { + options.Execer = agentexec.DefaultExecer + } + hardCtx, hardCancel := context.WithCancel(context.Background()) gracefulCtx, gracefulCancel := context.WithCancel(hardCtx) a := &agent{ @@ -171,6 +177,7 @@ func New(options Options) Agent { prometheusRegistry: prometheusRegistry, metrics: newAgentMetrics(prometheusRegistry), + execer: options.Execer, } // Initially, we have a closed channel, reflecting the fact that we are not initially connected. // Each time we connect we replace the channel (while holding the closeMutex) with a new one @@ -239,6 +246,7 @@ type agent struct { // metrics are prometheus registered metrics that will be collected and // labeled in Coder with the agent + workspace. metrics *agentMetrics + execer agentexec.Execer } func (a *agent) TailnetConn() *tailnet.Conn { @@ -247,7 +255,7 @@ func (a *agent) TailnetConn() *tailnet.Conn { func (a *agent) init() { // pass the "hard" context because we explicitly close the SSH server as part of graceful shutdown. - sshSrv, err := agentssh.NewServer(a.hardCtx, a.logger.Named("ssh-server"), a.prometheusRegistry, a.filesystem, &agentssh.Config{ + sshSrv, err := agentssh.NewServer(a.hardCtx, a.logger.Named("ssh-server"), a.prometheusRegistry, a.filesystem, a.execer, &agentssh.Config{ MaxTimeout: a.sshMaxTimeout, MOTDFile: func() string { return a.manifest.Load().MOTDFile }, AnnouncementBanners: func() *[]codersdk.BannerConfig { return a.announcementBanners.Load() }, diff --git a/agent/agentexec/cli_linux.go b/agent/agentexec/cli_linux.go index 8c4acb9060a2e..9c6568c81811b 100644 --- a/agent/agentexec/cli_linux.go +++ b/agent/agentexec/cli_linux.go @@ -17,9 +17,6 @@ import ( "golang.org/x/xerrors" ) -// unset is set to an invalid value for nice and oom scores. -const unset = -2000 - // CLI runs the agent-exec command. It should only be called by the cli package. func CLI() error { // We lock the OS thread here to avoid a race condition where the nice priority diff --git a/agent/agentexec/exec.go b/agent/agentexec/exec.go index fdb75b8ee4d13..3c2d60c7a43ef 100644 --- a/agent/agentexec/exec.go +++ b/agent/agentexec/exec.go @@ -20,60 +20,101 @@ const ( EnvProcPrioMgmt = "CODER_PROC_PRIO_MGMT" EnvProcOOMScore = "CODER_PROC_OOM_SCORE" EnvProcNiceScore = "CODER_PROC_NICE_SCORE" -) -// CommandContext returns an exec.Cmd that calls "coder agent-exec" prior to exec'ing -// the provided command if CODER_PROC_PRIO_MGMT is set, otherwise a normal exec.Cmd -// is returned. All instances of exec.Cmd should flow through this function to ensure -// proper resource constraints are applied to the child process. -func CommandContext(ctx context.Context, cmd string, args ...string) (*exec.Cmd, error) { - cmd, args, err := agentExecCmd(cmd, args...) - if err != nil { - return nil, xerrors.Errorf("agent exec cmd: %w", err) - } - return exec.CommandContext(ctx, cmd, args...), nil -} + // unset is set to an invalid value for nice and oom scores. + unset = -2000 +) -// PTYCommandContext returns an pty.Cmd that calls "coder agent-exec" prior to exec'ing -// the provided command if CODER_PROC_PRIO_MGMT is set, otherwise a normal pty.Cmd -// is returned. All instances of pty.Cmd should flow through this function to ensure -// proper resource constraints are applied to the child process. -func PTYCommandContext(ctx context.Context, cmd string, args ...string) (*pty.Cmd, error) { - cmd, args, err := agentExecCmd(cmd, args...) - if err != nil { - return nil, xerrors.Errorf("agent exec cmd: %w", err) - } - return pty.CommandContext(ctx, cmd, args...), nil +var DefaultExecer Execer = execer{} + +// Execer defines an abstraction for creating exec.Cmd variants. It's unfortunately +// necessary because we need to be able to wrap child processes with "coder agent-exec" +// for templates that expect the agent to manage process priority. +type Execer interface { + // CommandContext returns an exec.Cmd that calls "coder agent-exec" prior to exec'ing + // the provided command if CODER_PROC_PRIO_MGMT is set, otherwise a normal exec.Cmd + // is returned. All instances of exec.Cmd should flow through this function to ensure + // proper resource constraints are applied to the child process. + CommandContext(ctx context.Context, cmd string, args ...string) *exec.Cmd + // PTYCommandContext returns an pty.Cmd that calls "coder agent-exec" prior to exec'ing + // the provided command if CODER_PROC_PRIO_MGMT is set, otherwise a normal pty.Cmd + // is returned. All instances of pty.Cmd should flow through this function to ensure + // proper resource constraints are applied to the child process. + PTYCommandContext(ctx context.Context, cmd string, args ...string) *pty.Cmd } -func agentExecCmd(cmd string, args ...string) (string, []string, error) { +func NewExecer() (Execer, error) { _, enabled := os.LookupEnv(EnvProcPrioMgmt) if runtime.GOOS != "linux" || !enabled { - return cmd, args, nil + return DefaultExecer, nil } executable, err := os.Executable() if err != nil { - return "", nil, xerrors.Errorf("get executable: %w", err) + return nil, xerrors.Errorf("get executable: %w", err) } bin, err := filepath.EvalSymlinks(executable) if err != nil { - return "", nil, xerrors.Errorf("eval symlinks: %w", err) + return nil, xerrors.Errorf("eval symlinks: %w", err) + } + + oomScore, ok := envValInt(EnvProcOOMScore) + if !ok { + oomScore = unset + } + + niceScore, ok := envValInt(EnvProcNiceScore) + if !ok { + niceScore = unset } + return priorityExecer{ + binPath: bin, + oomScore: oomScore, + niceScore: niceScore, + }, nil +} + +type execer struct{} + +func (execer) CommandContext(ctx context.Context, cmd string, args ...string) *exec.Cmd { + return exec.CommandContext(ctx, cmd, args...) +} + +func (execer) PTYCommandContext(ctx context.Context, cmd string, args ...string) *pty.Cmd { + return pty.CommandContext(ctx, cmd, args...) +} + +type priorityExecer struct { + binPath string + oomScore int + niceScore int +} + +func (e priorityExecer) CommandContext(ctx context.Context, cmd string, args ...string) *exec.Cmd { + cmd, args = e.agentExecCmd(cmd, args...) + return exec.CommandContext(ctx, cmd, args...) +} + +func (e priorityExecer) PTYCommandContext(ctx context.Context, cmd string, args ...string) *pty.Cmd { + cmd, args = e.agentExecCmd(cmd, args...) + return pty.CommandContext(ctx, cmd, args...) +} + +func (e priorityExecer) agentExecCmd(cmd string, args ...string) (string, []string) { execArgs := []string{"agent-exec"} - if score, ok := envValInt(EnvProcOOMScore); ok { - execArgs = append(execArgs, oomScoreArg(score)) + if e.oomScore != unset { + execArgs = append(execArgs, oomScoreArg(e.oomScore)) } - if score, ok := envValInt(EnvProcNiceScore); ok { - execArgs = append(execArgs, niceScoreArg(score)) + if e.niceScore != unset { + execArgs = append(execArgs, niceScoreArg(e.niceScore)) } execArgs = append(execArgs, "--", cmd) execArgs = append(execArgs, args...) - return bin, execArgs, nil + return e.binPath, execArgs } // envValInt searches for a key in a list of environment variables and parses it to an int. diff --git a/agent/agentexec/exec_internal_test.go b/agent/agentexec/exec_internal_test.go new file mode 100644 index 0000000000000..c7d991902fab1 --- /dev/null +++ b/agent/agentexec/exec_internal_test.go @@ -0,0 +1,84 @@ +package agentexec + +import ( + "context" + "os/exec" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestExecer(t *testing.T) { + t.Parallel() + + t.Run("Default", func(t *testing.T) { + t.Parallel() + + cmd := DefaultExecer.CommandContext(context.Background(), "sh", "-c", "sleep") + + path, err := exec.LookPath("sh") + require.NoError(t, err) + require.Equal(t, path, cmd.Path) + require.Equal(t, []string{"sh", "-c", "sleep"}, cmd.Args) + }) + + t.Run("Priority", func(t *testing.T) { + t.Parallel() + + t.Run("OK", func(t *testing.T) { + t.Parallel() + + e := priorityExecer{ + binPath: "/foo/bar/baz", + oomScore: unset, + niceScore: unset, + } + + cmd := e.CommandContext(context.Background(), "sh", "-c", "sleep") + require.Equal(t, e.binPath, cmd.Path) + require.Equal(t, []string{e.binPath, "agent-exec", "--", "sh", "-c", "sleep"}, cmd.Args) + }) + + t.Run("Nice", func(t *testing.T) { + t.Parallel() + + e := priorityExecer{ + binPath: "/foo/bar/baz", + oomScore: unset, + niceScore: 10, + } + + cmd := e.CommandContext(context.Background(), "sh", "-c", "sleep") + require.Equal(t, e.binPath, cmd.Path) + require.Equal(t, []string{e.binPath, "agent-exec", "--coder-nice=10", "--", "sh", "-c", "sleep"}, cmd.Args) + }) + + t.Run("OOM", func(t *testing.T) { + t.Parallel() + + e := priorityExecer{ + binPath: "/foo/bar/baz", + oomScore: 123, + niceScore: unset, + } + + cmd := e.CommandContext(context.Background(), "sh", "-c", "sleep") + require.Equal(t, e.binPath, cmd.Path) + require.Equal(t, []string{e.binPath, "agent-exec", "--coder-oom=123", "--", "sh", "-c", "sleep"}, cmd.Args) + }) + + t.Run("Both", func(t *testing.T) { + t.Parallel() + + e := priorityExecer{ + binPath: "/foo/bar/baz", + oomScore: 432, + niceScore: 14, + } + + cmd := e.CommandContext(context.Background(), "sh", "-c", "sleep") + require.Equal(t, e.binPath, cmd.Path) + require.Equal(t, []string{e.binPath, "agent-exec", "--coder-oom=432", "--coder-nice=14", "--", "sh", "-c", "sleep"}, cmd.Args) + }) + }) +} diff --git a/agent/agentexec/exec_test.go b/agent/agentexec/exec_test.go deleted file mode 100644 index cf25cca473fe9..0000000000000 --- a/agent/agentexec/exec_test.go +++ /dev/null @@ -1,119 +0,0 @@ -package agentexec_test - -import ( - "context" - "os" - "os/exec" - "runtime" - "testing" - - "github.com/stretchr/testify/require" - - "github.com/coder/coder/v2/agent/agentexec" -) - -//nolint:paralleltest // we need to test environment variables -func TestExec(t *testing.T) { - //nolint:paralleltest // we need to test environment variables - t.Run("NonLinux", func(t *testing.T) { - t.Setenv(agentexec.EnvProcPrioMgmt, "true") - - if runtime.GOOS == "linux" { - t.Skip("skipping on linux") - } - - cmd, err := agentexec.CommandContext(context.Background(), "sh", "-c", "sleep") - require.NoError(t, err) - - path, err := exec.LookPath("sh") - require.NoError(t, err) - require.Equal(t, path, cmd.Path) - require.Equal(t, []string{"sh", "-c", "sleep"}, cmd.Args) - }) - - //nolint:paralleltest // we need to test environment variables - t.Run("Linux", func(t *testing.T) { - //nolint:paralleltest // we need to test environment variables - t.Run("Disabled", func(t *testing.T) { - if runtime.GOOS != "linux" { - t.Skip("skipping on non-linux") - } - - cmd, err := agentexec.CommandContext(context.Background(), "sh", "-c", "sleep") - require.NoError(t, err) - path, err := exec.LookPath("sh") - require.NoError(t, err) - require.Equal(t, path, cmd.Path) - require.Equal(t, []string{"sh", "-c", "sleep"}, cmd.Args) - }) - - //nolint:paralleltest // we need to test environment variables - t.Run("Enabled", func(t *testing.T) { - t.Setenv(agentexec.EnvProcPrioMgmt, "hello") - - if runtime.GOOS != "linux" { - t.Skip("skipping on non-linux") - } - - executable, err := os.Executable() - require.NoError(t, err) - - cmd, err := agentexec.CommandContext(context.Background(), "sh", "-c", "sleep") - require.NoError(t, err) - require.Equal(t, executable, cmd.Path) - require.Equal(t, []string{executable, "agent-exec", "--", "sh", "-c", "sleep"}, cmd.Args) - }) - - t.Run("Nice", func(t *testing.T) { - t.Setenv(agentexec.EnvProcPrioMgmt, "hello") - t.Setenv(agentexec.EnvProcNiceScore, "10") - - if runtime.GOOS != "linux" { - t.Skip("skipping on non-linux") - } - - executable, err := os.Executable() - require.NoError(t, err) - - cmd, err := agentexec.CommandContext(context.Background(), "sh", "-c", "sleep") - require.NoError(t, err) - require.Equal(t, executable, cmd.Path) - require.Equal(t, []string{executable, "agent-exec", "--coder-nice=10", "--", "sh", "-c", "sleep"}, cmd.Args) - }) - - t.Run("OOM", func(t *testing.T) { - t.Setenv(agentexec.EnvProcPrioMgmt, "hello") - t.Setenv(agentexec.EnvProcOOMScore, "123") - - if runtime.GOOS != "linux" { - t.Skip("skipping on non-linux") - } - - executable, err := os.Executable() - require.NoError(t, err) - - cmd, err := agentexec.CommandContext(context.Background(), "sh", "-c", "sleep") - require.NoError(t, err) - require.Equal(t, executable, cmd.Path) - require.Equal(t, []string{executable, "agent-exec", "--coder-oom=123", "--", "sh", "-c", "sleep"}, cmd.Args) - }) - - t.Run("Both", func(t *testing.T) { - t.Setenv(agentexec.EnvProcPrioMgmt, "hello") - t.Setenv(agentexec.EnvProcOOMScore, "432") - t.Setenv(agentexec.EnvProcNiceScore, "14") - - if runtime.GOOS != "linux" { - t.Skip("skipping on non-linux") - } - - executable, err := os.Executable() - require.NoError(t, err) - - cmd, err := agentexec.CommandContext(context.Background(), "sh", "-c", "sleep") - require.NoError(t, err) - require.Equal(t, executable, cmd.Path) - require.Equal(t, []string{executable, "agent-exec", "--coder-oom=432", "--coder-nice=14", "--", "sh", "-c", "sleep"}, cmd.Args) - }) - }) -} diff --git a/agent/agentscripts/agentscripts_test.go b/agent/agentscripts/agentscripts_test.go index 9435d3e046058..572f7b509d4d2 100644 --- a/agent/agentscripts/agentscripts_test.go +++ b/agent/agentscripts/agentscripts_test.go @@ -14,6 +14,7 @@ import ( "github.com/stretchr/testify/require" "go.uber.org/goleak" + "github.com/coder/coder/v2/agent/agentexec" "github.com/coder/coder/v2/agent/agentscripts" "github.com/coder/coder/v2/agent/agentssh" "github.com/coder/coder/v2/agent/agenttest" @@ -160,7 +161,7 @@ func setup(t *testing.T, getScriptLogger func(logSourceID uuid.UUID) agentscript } fs := afero.NewMemMapFs() logger := testutil.Logger(t) - s, err := agentssh.NewServer(context.Background(), logger, prometheus.NewRegistry(), fs, nil) + s, err := agentssh.NewServer(context.Background(), logger, prometheus.NewRegistry(), fs, agentexec.DefaultExecer, nil) require.NoError(t, err) t.Cleanup(func() { _ = s.Close() diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index 415674c9e2e95..dae1b73b2de6c 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -98,6 +98,7 @@ type Server struct { // a lock on mu but protected by closing. wg sync.WaitGroup + Execer agentexec.Execer logger slog.Logger srv *ssh.Server @@ -110,7 +111,7 @@ type Server struct { metrics *sshServerMetrics } -func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prometheus.Registry, fs afero.Fs, config *Config) (*Server, error) { +func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prometheus.Registry, fs afero.Fs, execer agentexec.Execer, config *Config) (*Server, error) { // Clients' should ignore the host key when connecting. // The agent needs to authenticate with coderd to SSH, // so SSH authentication doesn't improve security. @@ -153,6 +154,7 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom metrics := newSSHServerMetrics(prometheusRegistry) s := &Server{ + Execer: execer, listeners: make(map[net.Listener]struct{}), fs: fs, conns: make(map[net.Conn]struct{}), @@ -726,10 +728,7 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string) } } - cmd, err := agentexec.PTYCommandContext(ctx, name, args...) - if err != nil { - return nil, xerrors.Errorf("pty command context: %w", err) - } + cmd := s.Execer.PTYCommandContext(ctx, name, args...) cmd.Dir = s.config.WorkingDirectory() // If the metadata directory doesn't exist, we run the command diff --git a/agent/agentssh/agentssh_internal_test.go b/agent/agentssh/agentssh_internal_test.go index fd1958848306b..0ffa45df19b0d 100644 --- a/agent/agentssh/agentssh_internal_test.go +++ b/agent/agentssh/agentssh_internal_test.go @@ -15,6 +15,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/coder/coder/v2/agent/agentexec" "github.com/coder/coder/v2/pty" "github.com/coder/coder/v2/testutil" ) @@ -35,7 +36,7 @@ func Test_sessionStart_orphan(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) defer cancel() logger := testutil.Logger(t) - s, err := NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil) + s, err := NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil) require.NoError(t, err) defer s.Close() diff --git a/agent/agentssh/agentssh_test.go b/agent/agentssh/agentssh_test.go index cb76e3ee2582a..dfe67290c358b 100644 --- a/agent/agentssh/agentssh_test.go +++ b/agent/agentssh/agentssh_test.go @@ -22,6 +22,7 @@ import ( "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/v2/agent/agentexec" "github.com/coder/coder/v2/agent/agentssh" "github.com/coder/coder/v2/pty/ptytest" "github.com/coder/coder/v2/testutil" @@ -36,7 +37,7 @@ func TestNewServer_ServeClient(t *testing.T) { ctx := context.Background() logger := testutil.Logger(t) - s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil) + s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil) require.NoError(t, err) defer s.Close() @@ -77,7 +78,7 @@ func TestNewServer_ExecuteShebang(t *testing.T) { ctx := context.Background() logger := testutil.Logger(t) - s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil) + s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil) require.NoError(t, err) t.Cleanup(func() { _ = s.Close() @@ -108,7 +109,7 @@ func TestNewServer_CloseActiveConnections(t *testing.T) { ctx := context.Background() logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) - s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil) + s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil) require.NoError(t, err) defer s.Close() @@ -159,7 +160,7 @@ func TestNewServer_Signal(t *testing.T) { ctx := context.Background() logger := testutil.Logger(t) - s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil) + s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil) require.NoError(t, err) defer s.Close() @@ -224,7 +225,7 @@ func TestNewServer_Signal(t *testing.T) { ctx := context.Background() logger := testutil.Logger(t) - s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil) + s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil) require.NoError(t, err) defer s.Close() diff --git a/agent/agentssh/x11_test.go b/agent/agentssh/x11_test.go index bba801e176042..057da9a21e642 100644 --- a/agent/agentssh/x11_test.go +++ b/agent/agentssh/x11_test.go @@ -21,6 +21,7 @@ import ( "github.com/stretchr/testify/require" gossh "golang.org/x/crypto/ssh" + "github.com/coder/coder/v2/agent/agentexec" "github.com/coder/coder/v2/agent/agentssh" "github.com/coder/coder/v2/testutil" ) @@ -34,7 +35,7 @@ func TestServer_X11(t *testing.T) { ctx := context.Background() logger := testutil.Logger(t) fs := afero.NewOsFs() - s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, &agentssh.Config{}) + s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, agentexec.DefaultExecer, &agentssh.Config{}) require.NoError(t, err) defer s.Close() diff --git a/agent/reconnectingpty/buffered.go b/agent/reconnectingpty/buffered.go index cde41fb227c9e..6f314333a725e 100644 --- a/agent/reconnectingpty/buffered.go +++ b/agent/reconnectingpty/buffered.go @@ -40,7 +40,7 @@ type bufferedReconnectingPTY struct { // newBuffered starts the buffered pty. If the context ends the process will be // killed. -func newBuffered(ctx context.Context, cmd *pty.Cmd, options *Options, logger slog.Logger) *bufferedReconnectingPTY { +func newBuffered(ctx context.Context, logger slog.Logger, execer agentexec.Execer, cmd *pty.Cmd, options *Options) *bufferedReconnectingPTY { rpty := &bufferedReconnectingPTY{ activeConns: map[string]net.Conn{}, command: cmd, @@ -59,11 +59,7 @@ func newBuffered(ctx context.Context, cmd *pty.Cmd, options *Options, logger slo // Add TERM then start the command with a pty. pty.Cmd duplicates Path as the // first argument so remove it. - cmdWithEnv, err := agentexec.PTYCommandContext(ctx, cmd.Path, cmd.Args[1:]...) - if err != nil { - rpty.state.setState(StateDone, xerrors.Errorf("pty command context: %w", err)) - return rpty - } + cmdWithEnv := execer.PTYCommandContext(ctx, cmd.Path, cmd.Args[1:]...) cmdWithEnv.Env = append(rpty.command.Env, "TERM=xterm-256color") cmdWithEnv.Dir = rpty.command.Dir ptty, process, err := pty.Start(cmdWithEnv) diff --git a/agent/reconnectingpty/reconnectingpty.go b/agent/reconnectingpty/reconnectingpty.go index fffe199f59b54..b5c4e0aaa0b39 100644 --- a/agent/reconnectingpty/reconnectingpty.go +++ b/agent/reconnectingpty/reconnectingpty.go @@ -14,6 +14,7 @@ import ( "golang.org/x/xerrors" "cdr.dev/slog" + "github.com/coder/coder/v2/agent/agentexec" "github.com/coder/coder/v2/codersdk/workspacesdk" "github.com/coder/coder/v2/pty" ) @@ -55,7 +56,7 @@ type ReconnectingPTY interface { // close itself (and all connections to it) if nothing is attached for the // duration of the timeout, if the context ends, or the process exits (buffered // backend only). -func New(ctx context.Context, cmd *pty.Cmd, options *Options, logger slog.Logger) ReconnectingPTY { +func New(ctx context.Context, logger slog.Logger, execer agentexec.Execer, cmd *pty.Cmd, options *Options) ReconnectingPTY { if options.Timeout == 0 { options.Timeout = 5 * time.Minute } @@ -75,9 +76,9 @@ func New(ctx context.Context, cmd *pty.Cmd, options *Options, logger slog.Logger switch backendType { case "screen": - return newScreen(ctx, cmd, options, logger) + return newScreen(ctx, logger, execer, cmd, options) default: - return newBuffered(ctx, cmd, options, logger) + return newBuffered(ctx, logger, execer, cmd, options) } } diff --git a/agent/reconnectingpty/screen.go b/agent/reconnectingpty/screen.go index 122ef1fffc792..98d21c5959d7b 100644 --- a/agent/reconnectingpty/screen.go +++ b/agent/reconnectingpty/screen.go @@ -25,6 +25,7 @@ import ( // screenReconnectingPTY provides a reconnectable PTY via `screen`. type screenReconnectingPTY struct { + execer agentexec.Execer command *pty.Cmd // id holds the id of the session for both creating and attaching. This will @@ -59,8 +60,9 @@ type screenReconnectingPTY struct { // spawns the daemon with a hardcoded 24x80 size it is not a very good user // experience. Instead we will let the attach command spawn the daemon on its // own which causes it to spawn with the specified size. -func newScreen(ctx context.Context, cmd *pty.Cmd, options *Options, logger slog.Logger) *screenReconnectingPTY { +func newScreen(ctx context.Context, logger slog.Logger, execer agentexec.Execer, cmd *pty.Cmd, options *Options) *screenReconnectingPTY { rpty := &screenReconnectingPTY{ + execer: execer, command: cmd, metrics: options.Metrics, state: newState(), @@ -210,7 +212,7 @@ func (rpty *screenReconnectingPTY) doAttach(ctx context.Context, conn net.Conn, logger.Debug(ctx, "spawning screen client", slog.F("screen_id", rpty.id)) // Wrap the command with screen and tie it to the connection's context. - cmd, err := agentexec.PTYCommandContext(ctx, "screen", append([]string{ + cmd := rpty.execer.PTYCommandContext(ctx, "screen", append([]string{ // -S is for setting the session's name. "-S", rpty.id, // -U tells screen to use UTF-8 encoding. @@ -223,9 +225,6 @@ func (rpty *screenReconnectingPTY) doAttach(ctx context.Context, conn net.Conn, rpty.command.Path, // pty.Cmd duplicates Path as the first argument so remove it. }, rpty.command.Args[1:]...)...) - if err != nil { - return nil, nil, xerrors.Errorf("pty command context: %w", err) - } cmd.Env = append(rpty.command.Env, "TERM=xterm-256color") cmd.Dir = rpty.command.Dir ptty, process, err := pty.Start(cmd, pty.WithPTYOption( @@ -333,7 +332,7 @@ func (rpty *screenReconnectingPTY) sendCommand(ctx context.Context, command stri run := func() (bool, error) { var stdout bytes.Buffer //nolint:gosec - cmd, err := agentexec.CommandContext(ctx, "screen", + cmd := rpty.execer.CommandContext(ctx, "screen", // -x targets an attached session. "-x", rpty.id, // -c is the flag for the config file. @@ -341,13 +340,10 @@ func (rpty *screenReconnectingPTY) sendCommand(ctx context.Context, command stri // -X runs a command in the matching session. "-X", command, ) - if err != nil { - return false, xerrors.Errorf("command context: %w", err) - } cmd.Env = append(rpty.command.Env, "TERM=xterm-256color") cmd.Dir = rpty.command.Dir cmd.Stdout = &stdout - err = cmd.Run() + err := cmd.Run() if err == nil { return true, nil } diff --git a/agent/reconnectingpty/server.go b/agent/reconnectingpty/server.go index 052a88e52b0b4..d48c7abec9353 100644 --- a/agent/reconnectingpty/server.go +++ b/agent/reconnectingpty/server.go @@ -165,10 +165,15 @@ func (s *Server) handleConn(ctx context.Context, logger slog.Logger, conn net.Co return xerrors.Errorf("create command: %w", err) } - rpty = New(ctx, cmd, &Options{ - Timeout: s.timeout, - Metrics: s.errorsTotal, - }, logger.With(slog.F("message_id", msg.ID))) + rpty = New(ctx, + logger.With(slog.F("message_id", msg.ID)), + s.commandCreator.Execer, + cmd, + &Options{ + Timeout: s.timeout, + Metrics: s.errorsTotal, + }, + ) done := make(chan struct{}) go func() { diff --git a/cli/agent.go b/cli/agent.go index f76f222d3053b..fc96aa6d323c3 100644 --- a/cli/agent.go +++ b/cli/agent.go @@ -309,6 +309,11 @@ func (r *RootCmd) workspaceAgent() *serpent.Command { ) } + execer, err := agentexec.NewExecer() + if err != nil { + return xerrors.Errorf("create agent execer: %w", err) + } + agnt := agent.New(agent.Options{ Client: client, Logger: logger, @@ -333,6 +338,7 @@ func (r *RootCmd) workspaceAgent() *serpent.Command { PrometheusRegistry: prometheusRegistry, BlockFileTransfer: blockFileTransfer, + Execer: execer, }) promHandler := agent.PrometheusMetricsHandler(prometheusRegistry, logger) diff --git a/scripts/rules.go b/scripts/rules.go index 57d067aeb5019..4e16adad06a87 100644 --- a/scripts/rules.go +++ b/scripts/rules.go @@ -503,7 +503,7 @@ func noExecInAgent(m dsl.Matcher) { !m.File().PkgPath.Matches("/agentexec") && !m.File().Name.Matches(`_test\.go$`), ). - Report("The agent and its subpackages should not use exec.Command or exec.CommandContext directly. Consider using agentexec.CommandContext instead.") + Report("The agent and its subpackages should not use exec.Command or exec.CommandContext directly. Consider using an agentexec.Execer instead.") } // noPTYInAgent ensures that packages under agent/ don't use pty.Command or @@ -521,5 +521,5 @@ func noPTYInAgent(m dsl.Matcher) { !m.File().PkgPath.Matches(`/agentexec`) && !m.File().Name.Matches(`_test\.go$`), ). - Report("The agent and its subpackages should not use pty.Command or pty.CommandContext directly. Consider using agentexec.PTYCommandContext instead.") + Report("The agent and its subpackages should not use pty.Command or pty.CommandContext directly. Consider using an agentexec.Execer instead.") }