Skip to content

Commit e7ca1b2

Browse files
committed
Add agent shutdown lifecycle states
1 parent 5e41bcb commit e7ca1b2

17 files changed

+500
-170
lines changed

agent/agent.go

Lines changed: 81 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ func New(options Options) io.Closer {
121121
logDir: options.LogDir,
122122
tempDir: options.TempDir,
123123
lifecycleUpdate: make(chan struct{}, 1),
124+
lifecycleReported: make(chan codersdk.WorkspaceAgentLifecycle, 1),
124125
connStatsChan: make(chan *agentsdk.Stats, 1),
125126
}
126127
a.init(ctx)
@@ -149,9 +150,10 @@ type agent struct {
149150
sessionToken atomic.Pointer[string]
150151
sshServer *ssh.Server
151152

152-
lifecycleUpdate chan struct{}
153-
lifecycleMu sync.Mutex // Protects following.
154-
lifecycleState codersdk.WorkspaceAgentLifecycle
153+
lifecycleUpdate chan struct{}
154+
lifecycleReported chan codersdk.WorkspaceAgentLifecycle
155+
lifecycleMu sync.RWMutex // Protects following.
156+
lifecycleState codersdk.WorkspaceAgentLifecycle
155157

156158
network *tailnet.Conn
157159
connStatsChan chan *agentsdk.Stats
@@ -207,9 +209,9 @@ func (a *agent) reportLifecycleLoop(ctx context.Context) {
207209
}
208210

209211
for r := retry.New(time.Second, 15*time.Second); r.Wait(ctx); {
210-
a.lifecycleMu.Lock()
212+
a.lifecycleMu.RLock()
211213
state := a.lifecycleState
212-
a.lifecycleMu.Unlock()
214+
a.lifecycleMu.RUnlock()
213215

214216
if state == lastReported {
215217
break
@@ -222,6 +224,11 @@ func (a *agent) reportLifecycleLoop(ctx context.Context) {
222224
})
223225
if err == nil {
224226
lastReported = state
227+
select {
228+
case a.lifecycleReported <- state:
229+
case <-a.lifecycleReported:
230+
a.lifecycleReported <- state
231+
}
225232
break
226233
}
227234
if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
@@ -233,13 +240,20 @@ func (a *agent) reportLifecycleLoop(ctx context.Context) {
233240
}
234241
}
235242

243+
// setLifecycle sets the lifecycle state and notifies the lifecycle loop.
244+
// The state is only updated if it does not move backwards in codersdk.WorkspaceAgentLifecycleOrder.
236245
func (a *agent) setLifecycle(ctx context.Context, state codersdk.WorkspaceAgentLifecycle) {
237246
a.lifecycleMu.Lock()
238-
defer a.lifecycleMu.Unlock()
239-
240-
a.logger.Debug(ctx, "set lifecycle state", slog.F("state", state), slog.F("previous", a.lifecycleState))
241-
247+
lastState := a.lifecycleState
248+
if slices.Index(codersdk.WorkspaceAgentLifecycleOrder, lastState) > slices.Index(codersdk.WorkspaceAgentLifecycleOrder, state) {
249+
a.logger.Warn(ctx, "attempted to set lifecycle state to a previous state", slog.F("last", lastState), slog.F("state", state))
250+
a.lifecycleMu.Unlock()
251+
return
252+
}
242253
a.lifecycleState = state
254+
a.logger.Debug(ctx, "set lifecycle state", slog.F("state", state), slog.F("last", lastState))
255+
a.lifecycleMu.Unlock()
256+
243257
select {
244258
case a.lifecycleUpdate <- struct{}{}:
245259
default:
@@ -330,15 +344,15 @@ func (a *agent) run(ctx context.Context) error {
330344
return
331345
}
332346
execTime := time.Since(scriptStart)
333-
lifecycleStatus := codersdk.WorkspaceAgentLifecycleReady
347+
lifecycleState := codersdk.WorkspaceAgentLifecycleReady
334348
if err != nil {
335349
a.logger.Warn(ctx, "startup script failed", slog.F("execution_time", execTime), slog.Error(err))
336-
lifecycleStatus = codersdk.WorkspaceAgentLifecycleStartError
350+
lifecycleState = codersdk.WorkspaceAgentLifecycleStartError
337351
} else {
338352
a.logger.Info(ctx, "startup script completed", slog.F("execution_time", execTime))
339353
}
340354

341-
a.setLifecycle(ctx, lifecycleStatus)
355+
a.setLifecycle(ctx, lifecycleState)
342356
}()
343357
}
344358

@@ -1298,25 +1312,72 @@ func (a *agent) Close() error {
12981312
if a.isClosed() {
12991313
return nil
13001314
}
1301-
close(a.closed)
1302-
a.closeCancel()
13031315

1316+
ctx := context.Background()
1317+
a.setLifecycle(ctx, codersdk.WorkspaceAgentLifecycleShuttingDown)
1318+
1319+
// Close services before running shutdown script.
1320+
// TODO(mafredri): Gracefully shut down:
1321+
// - Close active SSH server connections
1322+
// - Close processes (send HUP, wait, etc.)
1323+
1324+
lifecycleState := codersdk.WorkspaceAgentLifecycleOff
13041325
if metadata, ok := a.metadata.Load().(agentsdk.Metadata); ok {
1305-
ctx := context.Background()
1306-
err := a.runShutdownScript(ctx, metadata.ShutdownScript)
1326+
scriptDone := make(chan error, 1)
1327+
scriptStart := time.Now()
1328+
go func() {
1329+
defer close(scriptDone)
1330+
scriptDone <- a.runShutdownScript(ctx, metadata.ShutdownScript)
1331+
}()
1332+
1333+
var timeout <-chan time.Time
1334+
// If timeout is zero, an older version of the coder provider was used
1335+
// and the script is awaited without a deadline. Otherwise a timeout is always > 0.
1336+
if metadata.ShutdownScriptTimeout > 0 {
1337+
t := time.NewTimer(metadata.ShutdownScriptTimeout)
1338+
defer t.Stop()
1339+
timeout = t.C
1340+
}
1341+
1342+
var err error
1343+
select {
1344+
case err = <-scriptDone:
1345+
case <-timeout:
1346+
a.logger.Warn(ctx, "shutdown script timed out")
1347+
a.setLifecycle(ctx, codersdk.WorkspaceAgentLifecycleShutdownTimeout)
1348+
err = <-scriptDone // The script can still complete after a timeout.
1349+
}
1350+
execTime := time.Since(scriptStart)
13071351
if err != nil {
1308-
a.logger.Error(ctx, "shutdown script failed", slog.Error(err))
1352+
a.logger.Warn(ctx, "shutdown script failed", slog.F("execution_time", execTime), slog.Error(err))
1353+
lifecycleState = codersdk.WorkspaceAgentLifecycleShutdownError
1354+
} else {
1355+
a.logger.Info(ctx, "shutdown script completed", slog.F("execution_time", execTime))
13091356
}
1310-
} else {
1311-
// No metadata.. halt?
1357+
}
1358+
1359+
// Set final state and wait for it to be reported because context
1360+
// cancellation will stop the report loop.
1361+
a.setLifecycle(ctx, lifecycleState)
1362+
for s := range a.lifecycleReported {
1363+
if s == lifecycleState {
1364+
break
1365+
}
1366+
}
1367+
1368+
if lifecycleState != codersdk.WorkspaceAgentLifecycleOff {
1369+
// TODO(mafredri): Delay shutdown, ensure debugging is possible.
13121370
_ = false
13131371
}
13141372

1373+
close(a.closed)
1374+
a.closeCancel()
1375+
_ = a.sshServer.Close()
13151376
if a.network != nil {
13161377
_ = a.network.Close()
13171378
}
1318-
_ = a.sshServer.Close()
13191379
a.connCloseWait.Wait()
1380+
13201381
return nil
13211382
}
13221383

0 commit comments

Comments
 (0)