diff --git a/agent/agent.go b/agent/agent.go
index a369432c0390f..c0a61fa97fe98 100644
--- a/agent/agent.go
+++ b/agent/agent.go
@@ -66,6 +66,7 @@ type Options struct {
 	Filesystem             afero.Fs
 	LogDir                 string
 	TempDir                string
+	ScriptDataDir          string
 	ExchangeToken          func(ctx context.Context) (string, error)
 	Client                 Client
 	ReconnectingPTYTimeout time.Duration
@@ -112,9 +113,19 @@ func New(options Options) Agent {
 	if options.LogDir == "" {
 		if options.TempDir != os.TempDir() {
 			options.Logger.Debug(context.Background(), "log dir not set, using temp dir", slog.F("temp_dir", options.TempDir))
+		} else { // LogDir is still empty here, so report the temp dir it falls back to.
+			options.Logger.Debug(context.Background(), "using log dir", slog.F("log_dir", options.TempDir))
 		}
 		options.LogDir = options.TempDir
 	}
+	if options.ScriptDataDir == "" {
+		if options.TempDir != os.TempDir() {
+			options.Logger.Debug(context.Background(), "script data dir not set, using temp dir", slog.F("temp_dir", options.TempDir))
+		} else { // ScriptDataDir is still empty here, so report the temp dir it falls back to.
+			options.Logger.Debug(context.Background(), "using script data dir", slog.F("script_data_dir", options.TempDir))
+		}
+		options.ScriptDataDir = options.TempDir
+	}
 	if options.ExchangeToken == nil {
 		options.ExchangeToken = func(ctx context.Context) (string, error) {
 			return "", nil
@@ -152,6 +163,7 @@ func New(options Options) Agent {
 		filesystem:        options.Filesystem,
 		logDir:            options.LogDir,
 		tempDir:           options.TempDir,
+		scriptDataDir:     options.ScriptDataDir,
 		lifecycleUpdate:   make(chan struct{}, 1),
 		lifecycleReported: make(chan codersdk.WorkspaceAgentLifecycle, 1),
 		lifecycleStates:   []agentsdk.PostLifecycleRequest{{State: codersdk.WorkspaceAgentLifecycleCreated}},
@@ -183,6 +195,7 @@ type agent struct {
 	filesystem    afero.Fs
 	logDir        string
 	tempDir       string
+	scriptDataDir string
 	// ignorePorts tells the api handler which ports to ignore when
 	// listing all listening ports. This is helpful to hide ports that
 	// are used by the agent, that the user does not care about.
@@ -249,11 +262,12 @@ func (a *agent) init(ctx context.Context) {
 	}
 	a.sshServer = sshSrv
 	a.scriptRunner = agentscripts.New(agentscripts.Options{
-		LogDir:     a.logDir,
-		Logger:     a.logger,
-		SSHServer:  sshSrv,
-		Filesystem: a.filesystem,
-		PatchLogs:  a.client.PatchLogs,
+		LogDir:      a.logDir,
+		DataDirBase: a.scriptDataDir,
+		Logger:      a.logger,
+		SSHServer:   sshSrv,
+		Filesystem:  a.filesystem,
+		PatchLogs:   a.client.PatchLogs,
 	})
 	// Register runner metrics. If the prom registry is nil, the metrics
 	// will not report anywhere.
@@ -954,6 +968,13 @@ func (a *agent) updateCommandEnv(current []string) (updated []string, err error)
 		envs[k] = v
 	}
 
+	// Prepend the agent script bin directory (where Coder modules place
+	// their binaries) to PATH, seeding PATH from the agent's own if unset.
+	if _, ok := envs["PATH"]; !ok {
+		envs["PATH"] = os.Getenv("PATH")
+	}
+	envs["PATH"] = fmt.Sprintf("%s%c%s", a.scriptRunner.ScriptBinDir(), filepath.ListSeparator, envs["PATH"])
+
 	for k, v := range envs {
 		updated = append(updated, fmt.Sprintf("%s=%s", k, v))
 	}
diff --git a/agent/agent_test.go b/agent/agent_test.go
index b894beeca905f..573440769806c 100644
--- a/agent/agent_test.go
+++ b/agent/agent_test.go
@@ -286,6 +286,12 @@ func TestAgent_SessionExec(t *testing.T) {
 func TestAgent_Session_EnvironmentVariables(t *testing.T) {
 	t.Parallel()
 
+	tmpdir := t.TempDir()
+
+	// Mirrors agentscripts.Runner.ScriptBinDir(); hardcoded here since
+	// this test has no reference to the script runner.
+	scriptBinDir := filepath.Join(tmpdir, "coder-script-data", "bin")
+
 	manifest := agentsdk.Manifest{
 		EnvironmentVariables: map[string]string{
 			"MY_MANIFEST": "true",
@@ -295,6 +301,7 @@ func TestAgent_Session_EnvironmentVariables(t *testing.T) {
 	}
 	banner := codersdk.ServiceBannerConfig{}
 	session := setupSSHSession(t, manifest, banner, nil, func(_ *agenttest.Client, opts *agent.Options) {
+		opts.ScriptDataDir = tmpdir // Pin the data dir so the expected PATH prefix is known.
 		opts.EnvironmentVariables["MY_OVERRIDE"] = "true"
 	})
 
@@ -341,6 +348,7 @@ func TestAgent_Session_EnvironmentVariables(t *testing.T) {
 		"MY_OVERRIDE":         "true",  // From the agent environment variables option, overrides manifest.
 		"MY_SESSION_MANIFEST": "false", // From the manifest, overrides session env.
 		"MY_SESSION":          "true",  // From the session.
+		"PATH":                scriptBinDir + string(filepath.ListSeparator),
 	} {
 		t.Run(k, func(t *testing.T) {
 			echoEnv(t, stdin, k)
diff --git a/agent/agentscripts/agentscripts.go b/agent/agentscripts/agentscripts.go
index f6052d605432e..e7169f9fdb699 100644
--- a/agent/agentscripts/agentscripts.go
+++ b/agent/agentscripts/agentscripts.go
@@ -43,11 +43,12 @@ var (
 
 // Options are a set of options for the runner.
 type Options struct {
-	LogDir     string
-	Logger     slog.Logger
-	SSHServer  *agentssh.Server
-	Filesystem afero.Fs
-	PatchLogs  func(ctx context.Context, req agentsdk.PatchLogs) error
+	DataDirBase string // Parent directory of the runner's "coder-script-data" dir.
+	LogDir      string
+	Logger      slog.Logger
+	SSHServer   *agentssh.Server
+	Filesystem  afero.Fs
+	PatchLogs   func(ctx context.Context, req agentsdk.PatchLogs) error
 }
 
 // New creates a runner for the provided scripts.
@@ -59,6 +60,7 @@ func New(opts Options) *Runner {
 		cronCtxCancel: cronCtxCancel,
 		cron:          cron.New(cron.WithParser(parser)),
 		closed:        make(chan struct{}),
+		dataDir:       filepath.Join(opts.DataDirBase, "coder-script-data"),
 		scriptsExecuted: prometheus.NewCounterVec(prometheus.CounterOpts{
 			Namespace: "agent",
 			Subsystem: "scripts",
@@ -78,6 +80,7 @@ type Runner struct {
 	cron        *cron.Cron
 	initialized atomic.Bool
 	scripts     []codersdk.WorkspaceAgentScript
+	dataDir     string // Options.DataDirBase joined with "coder-script-data".
 
 	// scriptsExecuted includes all scripts executed by the workspace agent. Agents
 	// execute startup scripts, and scripts on a cron schedule. Both will increment
@@ -85,6 +88,17 @@ type Runner struct {
 	scriptsExecuted *prometheus.CounterVec
 }
 
+// DataDir returns the directory in which script data is stored.
+func (r *Runner) DataDir() string {
+	return r.dataDir
+}
+
+// ScriptBinDir returns the "bin" directory inside DataDir, where scripts
+// can store executable binaries.
+func (r *Runner) ScriptBinDir() string {
+	return filepath.Join(r.dataDir, "bin")
+}
+
 func (r *Runner) RegisterMetrics(reg prometheus.Registerer) {
 	if reg == nil {
 		// If no registry, do nothing.
@@ -104,6 +118,11 @@ func (r *Runner) Init(scripts []codersdk.WorkspaceAgentScript) error {
 	r.scripts = scripts
 	r.Logger.Info(r.cronCtx, "initializing agent scripts", slog.F("script_count", len(scripts)), slog.F("log_dir", r.LogDir))
 
+	err := r.Filesystem.MkdirAll(r.ScriptBinDir(), 0o700)
+	if err != nil {
+		return xerrors.Errorf("create script bin dir: %w", err)
+	}
+
 	for _, script := range scripts {
 		if script.Cron == "" {
 			continue
@@ -208,7 +227,18 @@ func (r *Runner) run(ctx context.Context, script codersdk.WorkspaceAgentScript)
 	if !filepath.IsAbs(logPath) {
 		logPath = filepath.Join(r.LogDir, logPath)
 	}
-	logger := r.Logger.With(slog.F("log_path", logPath))
+
+	scriptDataDir := filepath.Join(r.DataDir(), script.LogSourceID.String())
+	err := r.Filesystem.MkdirAll(scriptDataDir, 0o700)
+	if err != nil {
+		return xerrors.Errorf("create script data dir %s: %w", scriptDataDir, err)
+	}
+
+	logger := r.Logger.With(
+		slog.F("log_source_id", script.LogSourceID),
+		slog.F("log_path", logPath),
+		slog.F("script_data_dir", scriptDataDir),
+	)
 	logger.Info(ctx, "running agent script", slog.F("script", script.Script))
 
 	fileWriter, err := r.Filesystem.OpenFile(logPath, os.O_CREATE|os.O_RDWR, 0o600)
@@ -238,6 +268,13 @@ func (r *Runner) run(ctx context.Context, script codersdk.WorkspaceAgentScript)
 	cmd.WaitDelay = 10 * time.Second
 	cmd.Cancel = cmdCancel(cmd)
 
+	// Expose env vars that the script can use for storing data and
+	// binaries: CODER_SCRIPT_DATA_DIR is this script's own data
+	// directory (named after its log source ID) and
+	// CODER_SCRIPT_BIN_DIR is the runner-wide bin directory.
+	cmd.Env = append(cmd.Env, "CODER_SCRIPT_DATA_DIR="+scriptDataDir)
+	cmd.Env = append(cmd.Env, "CODER_SCRIPT_BIN_DIR="+r.ScriptBinDir())
+
 	send, flushAndClose := agentsdk.LogsSender(script.LogSourceID, r.PatchLogs, logger)
 	// If ctx is canceled here (or in a writer below), we may be
 	// discarding logs, but that's okay because we're shutting down
diff --git a/agent/agentscripts/agentscripts_test.go b/agent/agentscripts/agentscripts_test.go
index bb3f842a45962..d7fce25fda1fa 100644
--- a/agent/agentscripts/agentscripts_test.go
+++ b/agent/agentscripts/agentscripts_test.go
@@ -2,11 +2,15 @@ package agentscripts_test
 
 import (
 	"context"
+	"path/filepath"
+	"runtime"
 	"testing"
 	"time"
 
+	"github.com/google/uuid"
 	"github.com/prometheus/client_golang/prometheus"
 	"github.com/spf13/afero"
+	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 	"go.uber.org/goleak"
 
@@ -15,6 +19,7 @@ import (
 	"github.com/coder/coder/v2/agent/agentssh"
 	"github.com/coder/coder/v2/codersdk"
 	"github.com/coder/coder/v2/codersdk/agentsdk"
+	"github.com/coder/coder/v2/testutil"
 )
 
 func TestMain(m *testing.M) {
@@ -25,12 +30,16 @@ func TestExecuteBasic(t *testing.T) {
 	t.Parallel()
 	logs := make(chan agentsdk.PatchLogs, 1)
 	runner := setup(t, func(ctx context.Context, req agentsdk.PatchLogs) error {
-		logs <- req
+		select {
+		case <-ctx.Done(): // Don't block on a full channel once ctx is done.
+		case logs <- req:
+		}
 		return nil
 	})
 	defer runner.Close()
 	err := runner.Init([]codersdk.WorkspaceAgentScript{{
-		Script: "echo hello",
+		LogSourceID: uuid.New(), // Names this script's data dir.
+		Script:      "echo hello",
 	}})
 	require.NoError(t, err)
 	require.NoError(t, runner.Execute(context.Background(), func(script codersdk.WorkspaceAgentScript) bool {
@@ -40,13 +49,67 @@ func TestExecuteBasic(t *testing.T) {
 	require.Equal(t, "hello", log.Logs[0].Output)
 }
 
+func TestEnv(t *testing.T) {
+	t.Parallel()
+	logs := make(chan agentsdk.PatchLogs, 2)
+	runner := setup(t, func(ctx context.Context, req agentsdk.PatchLogs) error {
+		select {
+		case <-ctx.Done(): // Don't block on a full channel once ctx is done.
+		case logs <- req:
+		}
+		return nil
+	})
+	defer runner.Close()
+	id := uuid.New()
+	script := "echo $CODER_SCRIPT_DATA_DIR\necho $CODER_SCRIPT_BIN_DIR\n"
+	if runtime.GOOS == "windows" {
+		script = `
+			cmd.exe /c echo %CODER_SCRIPT_DATA_DIR%
+			cmd.exe /c echo %CODER_SCRIPT_BIN_DIR%
+		`
+	}
+	err := runner.Init([]codersdk.WorkspaceAgentScript{{
+		LogSourceID: id,
+		Script:      script,
+	}})
+	require.NoError(t, err)
+
+	ctx := testutil.Context(t, testutil.WaitLong)
+
+	testutil.Go(t, func() {
+		err := runner.Execute(ctx, func(script codersdk.WorkspaceAgentScript) bool {
+			return true
+		})
+		assert.NoError(t, err)
+	})
+
+	var log []agentsdk.Log
+	for {
+		select {
+		case <-ctx.Done():
+			require.Fail(t, "timed out waiting for logs")
+		case l := <-logs:
+			for _, l := range l.Logs {
+				t.Logf("log: %s", l.Output)
+			}
+			log = append(log, l.Logs...)
+		}
+		if len(log) >= 2 { // One line per echo: data dir first, then bin dir.
+			break
+		}
+	}
+	require.Contains(t, log[0].Output, filepath.Join(runner.DataDir(), id.String()))
+	require.Contains(t, log[1].Output, runner.ScriptBinDir())
+}
+
 func TestTimeout(t *testing.T) {
 	t.Parallel()
 	runner := setup(t, nil)
 	defer runner.Close()
 	err := runner.Init([]codersdk.WorkspaceAgentScript{{
-		Script:  "sleep infinity",
-		Timeout: time.Millisecond,
+		LogSourceID: uuid.New(),
+		Script:      "sleep infinity",
+		Timeout:     time.Millisecond,
 	}})
 	require.NoError(t, err)
 	require.ErrorIs(t, runner.Execute(context.Background(), nil), agentscripts.ErrTimeout)
@@ -77,10 +140,11 @@ func setup(t *testing.T, patchLogs func(ctx context.Context, req agentsdk.PatchL
 		_ = s.Close()
 	})
 	return agentscripts.New(agentscripts.Options{
-		LogDir:     t.TempDir(),
-		Logger:     logger,
-		SSHServer:  s,
-		Filesystem: fs,
-		PatchLogs:  patchLogs,
+		LogDir:      t.TempDir(),
+		DataDirBase: t.TempDir(),
+		Logger:      logger,
+		SSHServer:   s,
+		Filesystem:  fs,
+		PatchLogs:   patchLogs,
 	})
 }
diff --git a/cli/agent.go b/cli/agent.go
index c951ec7509901..23473022abea7 100644
--- a/cli/agent.go
+++ b/cli/agent.go
@@ -40,6 +40,7 @@ func (r *RootCmd) workspaceAgent() *clibase.Cmd {
 	var (
 		auth          string
 		logDir        string
+		scriptDataDir string // Bound to --script-data-dir / CODER_AGENT_SCRIPT_DATA_DIR.
 		pprofAddress  string
 		noReap        bool
 		sshMaxTimeout time.Duration
@@ -289,6 +290,7 @@ func (r *RootCmd) workspaceAgent() *clibase.Cmd {
 				Client:            client,
 				Logger:            logger,
 				LogDir:            logDir,
+				ScriptDataDir:     scriptDataDir,
 				TailnetListenPort: uint16(tailnetListenPort),
 				ExchangeToken: func(ctx context.Context) (string, error) {
 					if exchangeToken == nil {
@@ -339,6 +341,13 @@ func (r *RootCmd) workspaceAgent() *clibase.Cmd {
 			Env:         "CODER_AGENT_LOG_DIR",
 			Value:       clibase.StringOf(&logDir),
 		},
+		{
+			Flag:        "script-data-dir",
+			Default:     os.TempDir(),
+			Description: "Specify the location for storing script data.",
+			Env:         "CODER_AGENT_SCRIPT_DATA_DIR",
+			Value:       clibase.StringOf(&scriptDataDir),
+		},
 		{
 			Flag:        "pprof-address",
 			Default:     "127.0.0.1:6060",
diff --git a/cli/testdata/coder_agent_--help.golden b/cli/testdata/coder_agent_--help.golden
index 08dab47a21e14..372395c4ba5fe 100644
--- a/cli/testdata/coder_agent_--help.golden
+++ b/cli/testdata/coder_agent_--help.golden
@@ -33,6 +33,9 @@ OPTIONS:
       --prometheus-address string, $CODER_AGENT_PROMETHEUS_ADDRESS (default: 127.0.0.1:2112)
           The bind address to serve Prometheus metrics.
 
+      --script-data-dir string, $CODER_AGENT_SCRIPT_DATA_DIR (default: /tmp)
+          Specify the location for storing script data.
+
       --ssh-max-timeout duration, $CODER_AGENT_SSH_MAX_TIMEOUT (default: 72h)
           Specify the max timeout for a SSH connection, it is advisable to set
           it to a minimum of 60s, but no more than 72h.