Skip to content

Commit 70ebaf3

Browse files
committed
Refactor agent scripts into its own package
1 parent 9ae6e62 commit 70ebaf3

19 files changed

+789
-580
lines changed

agent/agent.go

Lines changed: 68 additions & 102 deletions
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,7 @@ import (
3434
"tailscale.com/types/netlogtype"
3535

3636
"cdr.dev/slog"
37+
"github.com/coder/coder/v2/agent/agentscripts"
3738
"github.com/coder/coder/v2/agent/agentssh"
3839
"github.com/coder/coder/v2/agent/reconnectingpty"
3940
"github.com/coder/coder/v2/buildinfo"
@@ -177,6 +178,7 @@ type agent struct {
177178

178179
manifest atomic.Pointer[agentsdk.Manifest] // manifest is atomic because values can change after reconnection.
179180
reportMetadataInterval time.Duration
181+
scriptRunner *agentscripts.Runner
180182
serviceBanner atomic.Pointer[codersdk.ServiceBannerConfig] // serviceBanner is atomic because it is periodically updated.
181183
serviceBannerRefreshInterval time.Duration
182184
sessionToken atomic.Pointer[string]
@@ -213,7 +215,13 @@ func (a *agent) init(ctx context.Context) {
213215
sshSrv.Manifest = &a.manifest
214216
sshSrv.ServiceBanner = &a.serviceBanner
215217
a.sshServer = sshSrv
216-
218+
a.scriptRunner = agentscripts.New(ctx, agentscripts.Options{
219+
LogDir: a.logDir,
220+
Logger: a.logger,
221+
SSHServer: sshSrv,
222+
Filesystem: a.filesystem,
223+
PatchLogs: a.client.PatchLogs,
224+
})
217225
go a.runLoop(ctx)
218226
}
219227

@@ -631,41 +639,28 @@ func (a *agent) run(ctx context.Context) error {
631639
}
632640
}
633641

634-
lifecycleState := codersdk.WorkspaceAgentLifecycleReady
635-
scriptDone := make(chan error, 1)
636-
err = a.trackConnGoroutine(func() {
637-
defer close(scriptDone)
638-
scriptDone <- a.runStartupScript(ctx, manifest.StartupScript)
639-
})
642+
err = a.scriptRunner.Init(manifest.Scripts)
640643
if err != nil {
641-
return xerrors.Errorf("track startup script: %w", err)
644+
return xerrors.Errorf("init script runner: %w", err)
642645
}
643-
go func() {
644-
var timeout <-chan time.Time
645-
// If timeout is zero, an older version of the coder
646-
// provider was used. Otherwise a timeout is always > 0.
647-
if manifest.StartupScriptTimeout > 0 {
648-
t := time.NewTimer(manifest.StartupScriptTimeout)
649-
defer t.Stop()
650-
timeout = t.C
651-
}
652-
653-
var err error
654-
select {
655-
case err = <-scriptDone:
656-
case <-timeout:
657-
a.logger.Warn(ctx, "script timed out", slog.F("lifecycle", "startup"), slog.F("timeout", manifest.StartupScriptTimeout))
658-
a.setLifecycle(ctx, codersdk.WorkspaceAgentLifecycleStartTimeout)
659-
err = <-scriptDone // The script can still complete after a timeout.
660-
}
646+
err = a.trackConnGoroutine(func() {
647+
err := a.scriptRunner.Execute(func(script codersdk.WorkspaceAgentScript) bool {
648+
return script.RunOnStart
649+
})
661650
if err != nil {
662-
if errors.Is(err, context.Canceled) {
663-
return
651+
a.logger.Warn(ctx, "startup script failed", slog.Error(err))
652+
if errors.Is(err, agentscripts.ErrTimeout) {
653+
a.setLifecycle(ctx, codersdk.WorkspaceAgentLifecycleStartTimeout)
654+
} else {
655+
a.setLifecycle(ctx, codersdk.WorkspaceAgentLifecycleStartError)
664656
}
665-
lifecycleState = codersdk.WorkspaceAgentLifecycleStartError
657+
} else {
658+
a.setLifecycle(ctx, codersdk.WorkspaceAgentLifecycleReady)
666659
}
667-
a.setLifecycle(ctx, lifecycleState)
668-
}()
660+
})
661+
if err != nil {
662+
return xerrors.Errorf("track conn goroutine: %w", err)
663+
}
669664
}
670665

671666
// This automatically closes when the context ends!
@@ -980,63 +975,48 @@ func (a *agent) runDERPMapSubscriber(ctx context.Context, network *tailnet.Conn)
980975
}
981976
}
982977

983-
func (a *agent) runStartupScript(ctx context.Context, script string) error {
984-
return a.runScript(ctx, "startup", script)
985-
}
986-
987-
func (a *agent) runShutdownScript(ctx context.Context, script string) error {
988-
return a.runScript(ctx, "shutdown", script)
989-
}
990-
991-
func (a *agent) runScript(ctx context.Context, lifecycle, script string) (err error) {
992-
if script == "" {
978+
func (a *agent) runScript(ctx context.Context, script codersdk.WorkspaceAgentScript) (err error) {
979+
if script.Script == "" {
993980
return nil
994981
}
995982

996-
logger := a.logger.With(slog.F("lifecycle", lifecycle))
983+
logger := a.logger.With(slog.F("log_source", script.LogSourceDisplayName))
997984

998-
logger.Info(ctx, fmt.Sprintf("running %s script", lifecycle), slog.F("script", script))
999-
fileWriter, err := a.filesystem.OpenFile(filepath.Join(a.logDir, fmt.Sprintf("coder-%s-script.log", lifecycle)), os.O_CREATE|os.O_RDWR, 0o600)
985+
logger.Info(ctx, "running script", slog.F("script", script.Script))
986+
fileWriter, err := a.filesystem.OpenFile(filepath.Join(a.logDir, fmt.Sprintf("coder-%s-script.log", script.LogSourceDisplayName)), os.O_CREATE|os.O_RDWR, 0o600)
1000987
if err != nil {
1001-
return xerrors.Errorf("open %s script log file: %w", lifecycle, err)
988+
return xerrors.Errorf("open %s script log file: %w", script.LogSourceDisplayName, err)
1002989
}
1003990
defer func() {
1004991
err := fileWriter.Close()
1005992
if err != nil {
1006-
logger.Warn(ctx, fmt.Sprintf("close %s script log file", lifecycle), slog.Error(err))
993+
logger.Warn(ctx, fmt.Sprintf("close %s script log file", script.LogSourceDisplayName), slog.Error(err))
1007994
}
1008995
}()
1009996

1010-
cmdPty, err := a.sshServer.CreateCommand(ctx, script, nil)
997+
cmdPty, err := a.sshServer.CreateCommand(ctx, script.Script, nil)
1011998
if err != nil {
1012-
return xerrors.Errorf("%s script: create command: %w", lifecycle, err)
999+
return xerrors.Errorf("%s script: create command: %w", script.LogSourceDisplayName, err)
10131000
}
10141001
cmd := cmdPty.AsExec()
10151002

1016-
var stdout, stderr io.Writer = fileWriter, fileWriter
1017-
if lifecycle == "startup" {
1018-
send, flushAndClose := agentsdk.LogsSender(a.client.PatchLogs, logger)
1019-
// If ctx is canceled here (or in a writer below), we may be
1020-
// discarding logs, but that's okay because we're shutting down
1021-
// anyway. We could consider creating a new context here if we
1022-
// want better control over flush during shutdown.
1023-
defer func() {
1024-
if err := flushAndClose(ctx); err != nil {
1025-
logger.Warn(ctx, "flush startup logs failed", slog.Error(err))
1026-
}
1027-
}()
1028-
1029-
infoW := agentsdk.StartupLogsWriter(ctx, send, codersdk.WorkspaceAgentLogSourceStartupScript, codersdk.LogLevelInfo)
1030-
defer infoW.Close()
1031-
errW := agentsdk.StartupLogsWriter(ctx, send, codersdk.WorkspaceAgentLogSourceStartupScript, codersdk.LogLevelError)
1032-
defer errW.Close()
1033-
1034-
stdout = io.MultiWriter(fileWriter, infoW)
1035-
stderr = io.MultiWriter(fileWriter, errW)
1036-
}
1003+
send, flushAndClose := agentsdk.LogsSender(script.LogSourceID, a.client.PatchLogs, logger)
1004+
// If ctx is canceled here (or in a writer below), we may be
1005+
// discarding logs, but that's okay because we're shutting down
1006+
// anyway. We could consider creating a new context here if we
1007+
// want better control over flush during shutdown.
1008+
defer func() {
1009+
if err := flushAndClose(ctx); err != nil {
1010+
logger.Warn(ctx, "flush startup logs failed", slog.Error(err))
1011+
}
1012+
}()
10371013

1038-
cmd.Stdout = stdout
1039-
cmd.Stderr = stderr
1014+
infoW := agentsdk.StartupLogsWriter(ctx, send, script.LogSourceID, codersdk.LogLevelInfo)
1015+
defer infoW.Close()
1016+
errW := agentsdk.StartupLogsWriter(ctx, send, script.LogSourceID, codersdk.LogLevelError)
1017+
defer errW.Close()
1018+
cmd.Stdout = io.MultiWriter(fileWriter, infoW)
1019+
cmd.Stderr = io.MultiWriter(fileWriter, errW)
10401020

10411021
start := time.Now()
10421022
defer func() {
@@ -1049,9 +1029,9 @@ func (a *agent) runScript(ctx context.Context, lifecycle, script string) (err er
10491029
if xerrors.As(err, &exitError) {
10501030
exitCode = exitError.ExitCode()
10511031
}
1052-
logger.Warn(ctx, fmt.Sprintf("%s script failed", lifecycle), slog.F("execution_time", execTime), slog.F("exit_code", exitCode), slog.Error(err))
1032+
logger.Warn(ctx, fmt.Sprintf("%s script failed", script.LogSourceDisplayName), slog.F("execution_time", execTime), slog.F("exit_code", exitCode), slog.Error(err))
10531033
} else {
1054-
logger.Info(ctx, fmt.Sprintf("%s script completed", lifecycle), slog.F("execution_time", execTime), slog.F("exit_code", exitCode))
1034+
logger.Info(ctx, fmt.Sprintf("%s script completed", script.LogSourceDisplayName), slog.F("execution_time", execTime), slog.F("exit_code", exitCode))
10551035
}
10561036
}()
10571037

@@ -1062,7 +1042,7 @@ func (a *agent) runScript(ctx context.Context, lifecycle, script string) (err er
10621042
return ctx.Err()
10631043
}
10641044

1065-
return xerrors.Errorf("%s script: run: %w", lifecycle, err)
1045+
return xerrors.Errorf("%s script: run: %w", script.LogSourceDisplayName, err)
10661046
}
10671047
return nil
10681048
}
@@ -1336,39 +1316,25 @@ func (a *agent) Close() error {
13361316
}
13371317

13381318
lifecycleState := codersdk.WorkspaceAgentLifecycleOff
1339-
if manifest := a.manifest.Load(); manifest != nil && manifest.ShutdownScript != "" {
1340-
scriptDone := make(chan error, 1)
1341-
go func() {
1342-
defer close(scriptDone)
1343-
scriptDone <- a.runShutdownScript(ctx, manifest.ShutdownScript)
1344-
}()
1345-
1346-
var timeout <-chan time.Time
1347-
// If timeout is zero, an older version of the coder
1348-
// provider was used. Otherwise a timeout is always > 0.
1349-
if manifest.ShutdownScriptTimeout > 0 {
1350-
t := time.NewTimer(manifest.ShutdownScriptTimeout)
1351-
defer t.Stop()
1352-
timeout = t.C
1353-
}
1354-
1355-
var err error
1356-
select {
1357-
case err = <-scriptDone:
1358-
case <-timeout:
1359-
a.logger.Warn(ctx, "script timed out", slog.F("lifecycle", "shutdown"), slog.F("timeout", manifest.ShutdownScriptTimeout))
1360-
a.setLifecycle(ctx, codersdk.WorkspaceAgentLifecycleShutdownTimeout)
1361-
err = <-scriptDone // The script can still complete after a timeout.
1362-
}
1363-
if err != nil {
1319+
err = a.scriptRunner.Execute(func(script codersdk.WorkspaceAgentScript) bool {
1320+
return script.RunOnStop
1321+
})
1322+
if err != nil {
1323+
if errors.Is(err, agentscripts.ErrTimeout) {
1324+
lifecycleState = codersdk.WorkspaceAgentLifecycleShutdownTimeout
1325+
} else {
13641326
lifecycleState = codersdk.WorkspaceAgentLifecycleShutdownError
13651327
}
1328+
} else {
1329+
lifecycleState = codersdk.WorkspaceAgentLifecycleOff
13661330
}
1367-
1368-
// Set final state and wait for it to be reported because context
1369-
// cancellation will stop the report loop.
13701331
a.setLifecycle(ctx, lifecycleState)
13711332

1333+
err = a.scriptRunner.Close()
1334+
if err != nil {
1335+
a.logger.Error(ctx, "script runner close", slog.Error(err))
1336+
}
1337+
13721338
// Wait for the lifecycle to be reported, but don't wait forever so
13731339
// that we don't break user expectations.
13741340
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)

0 commit comments

Comments
 (0)