Skip to content

Commit 797ffd1

Browse files
committed
Add agent shutdown lifecycle states
1 parent 56be22f commit 797ffd1

17 files changed

+499
-169
lines changed

agent/agent.go

Lines changed: 81 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -112,6 +112,7 @@ func New(options Options) io.Closer {
112112
logDir: options.LogDir,
113113
tempDir: options.TempDir,
114114
lifecycleUpdate: make(chan struct{}, 1),
115+
lifecycleReported: make(chan codersdk.WorkspaceAgentLifecycle, 1),
115116
connStatsChan: make(chan *agentsdk.Stats, 1),
116117
}
117118
a.init(ctx)
@@ -140,9 +141,10 @@ type agent struct {
140141
sessionToken atomic.Pointer[string]
141142
sshServer *ssh.Server
142143

143-
lifecycleUpdate chan struct{}
144-
lifecycleMu sync.Mutex // Protects following.
145-
lifecycleState codersdk.WorkspaceAgentLifecycle
144+
lifecycleUpdate chan struct{}
145+
lifecycleReported chan codersdk.WorkspaceAgentLifecycle
146+
lifecycleMu sync.RWMutex // Protects following.
147+
lifecycleState codersdk.WorkspaceAgentLifecycle
146148

147149
network *tailnet.Conn
148150
connStatsChan chan *agentsdk.Stats
@@ -189,9 +191,9 @@ func (a *agent) reportLifecycleLoop(ctx context.Context) {
189191
}
190192

191193
for r := retry.New(time.Second, 15*time.Second); r.Wait(ctx); {
192-
a.lifecycleMu.Lock()
194+
a.lifecycleMu.RLock()
193195
state := a.lifecycleState
194-
a.lifecycleMu.Unlock()
196+
a.lifecycleMu.RUnlock()
195197

196198
if state == lastReported {
197199
break
@@ -204,6 +206,11 @@ func (a *agent) reportLifecycleLoop(ctx context.Context) {
204206
})
205207
if err == nil {
206208
lastReported = state
209+
select {
210+
case a.lifecycleReported <- state:
211+
case <-a.lifecycleReported:
212+
a.lifecycleReported <- state
213+
}
207214
break
208215
}
209216
if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
@@ -215,13 +222,20 @@ func (a *agent) reportLifecycleLoop(ctx context.Context) {
215222
}
216223
}
217224

225+
// setLifecycle sets the lifecycle state and notifies the lifecycle loop.
226+
// The state is only updated if it's a valid state transition.
218227
func (a *agent) setLifecycle(ctx context.Context, state codersdk.WorkspaceAgentLifecycle) {
219228
a.lifecycleMu.Lock()
220-
defer a.lifecycleMu.Unlock()
221-
222-
a.logger.Debug(ctx, "set lifecycle state", slog.F("state", state), slog.F("previous", a.lifecycleState))
223-
229+
lastState := a.lifecycleState
230+
if slices.Index(codersdk.WorkspaceAgentLifecycleOrder, lastState) > slices.Index(codersdk.WorkspaceAgentLifecycleOrder, state) {
231+
a.logger.Warn(ctx, "attempted to set lifecycle state to a previous state", slog.F("last", lastState), slog.F("state", state))
232+
a.lifecycleMu.Unlock()
233+
return
234+
}
224235
a.lifecycleState = state
236+
a.logger.Debug(ctx, "set lifecycle state", slog.F("state", state), slog.F("last", lastState))
237+
a.lifecycleMu.Unlock()
238+
225239
select {
226240
case a.lifecycleUpdate <- struct{}{}:
227241
default:
@@ -312,15 +326,15 @@ func (a *agent) run(ctx context.Context) error {
312326
return
313327
}
314328
execTime := time.Since(scriptStart)
315-
lifecycleStatus := codersdk.WorkspaceAgentLifecycleReady
329+
lifecycleState := codersdk.WorkspaceAgentLifecycleReady
316330
if err != nil {
317331
a.logger.Warn(ctx, "startup script failed", slog.F("execution_time", execTime), slog.Error(err))
318-
lifecycleStatus = codersdk.WorkspaceAgentLifecycleStartError
332+
lifecycleState = codersdk.WorkspaceAgentLifecycleStartError
319333
} else {
320334
a.logger.Info(ctx, "startup script completed", slog.F("execution_time", execTime))
321335
}
322336

323-
a.setLifecycle(ctx, lifecycleStatus)
337+
a.setLifecycle(ctx, lifecycleState)
324338
}()
325339
}
326340

@@ -1203,25 +1217,72 @@ func (a *agent) Close() error {
12031217
if a.isClosed() {
12041218
return nil
12051219
}
1206-
close(a.closed)
1207-
a.closeCancel()
12081220

1221+
ctx := context.Background()
1222+
a.setLifecycle(ctx, codersdk.WorkspaceAgentLifecycleShuttingDown)
1223+
1224+
// Close services before running shutdown script.
1225+
// TODO(mafredri): Gracefully shutdown:
1226+
// - Close active SSH server connections
1227+
// - Close processes (send HUP, wait, etc.)
1228+
1229+
lifecycleState := codersdk.WorkspaceAgentLifecycleOff
12091230
if metadata, ok := a.metadata.Load().(agentsdk.Metadata); ok {
1210-
ctx := context.Background()
1211-
err := a.runShutdownScript(ctx, metadata.ShutdownScript)
1231+
scriptDone := make(chan error, 1)
1232+
scriptStart := time.Now()
1233+
go func() {
1234+
defer close(scriptDone)
1235+
scriptDone <- a.runShutdownScript(ctx, metadata.ShutdownScript)
1236+
}()
1237+
1238+
var timeout <-chan time.Time
1239+
// If timeout is zero, an older version of the coder
1240+
// provider was used. Otherwise a timeout is always > 0.
1241+
if metadata.ShutdownScriptTimeout > 0 {
1242+
t := time.NewTimer(metadata.ShutdownScriptTimeout)
1243+
defer t.Stop()
1244+
timeout = t.C
1245+
}
1246+
1247+
var err error
1248+
select {
1249+
case err = <-scriptDone:
1250+
case <-timeout:
1251+
a.logger.Warn(ctx, "shutdown script timed out")
1252+
a.setLifecycle(ctx, codersdk.WorkspaceAgentLifecycleShutdownTimeout)
1253+
err = <-scriptDone // The script can still complete after a timeout.
1254+
}
1255+
execTime := time.Since(scriptStart)
12121256
if err != nil {
1213-
a.logger.Error(ctx, "shutdown script failed", slog.Error(err))
1257+
a.logger.Warn(ctx, "shutdown script failed", slog.F("execution_time", execTime), slog.Error(err))
1258+
lifecycleState = codersdk.WorkspaceAgentLifecycleShutdownError
1259+
} else {
1260+
a.logger.Info(ctx, "shutdown script completed", slog.F("execution_time", execTime))
12141261
}
1215-
} else {
1216-
// No metadata.. halt?
1262+
}
1263+
1264+
// Set final state and wait for it to be reported because context
1265+
// cancellation will stop the report loop.
1266+
a.setLifecycle(ctx, lifecycleState)
1267+
for s := range a.lifecycleReported {
1268+
if s == lifecycleState {
1269+
break
1270+
}
1271+
}
1272+
1273+
if lifecycleState != codersdk.WorkspaceAgentLifecycleOff {
1274+
// TODO(mafredri): Delay shutdown, ensure debugging is possible.
12171275
_ = false
12181276
}
12191277

1278+
close(a.closed)
1279+
a.closeCancel()
1280+
_ = a.sshServer.Close()
12201281
if a.network != nil {
12211282
_ = a.network.Close()
12221283
}
1223-
_ = a.sshServer.Close()
12241284
a.connCloseWait.Wait()
1285+
12251286
return nil
12261287
}
12271288

0 commit comments

Comments
 (0)