Skip to content

Commit 425ee6f

Browse files
feat: reinitialize agents when a prebuilt workspace is claimed (#17475)
This pull request allows coder workspace agents to be reinitialized when a prebuilt workspace is claimed by a user. This facilitates the transfer of ownership between the anonymous prebuilds system user and the new owner of the workspace. Only a single agent per prebuilt workspace is supported for now, but plumbing has already been done to facilitate the seamless transition to multi-agent support. --------- Signed-off-by: Danny Kopping <dannykopping@gmail.com> Co-authored-by: Danny Kopping <dannykopping@gmail.com>
1 parent fcbdd1a commit 425ee6f

38 files changed

+2187
-452
lines changed

agent/agent.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,9 +368,11 @@ func (a *agent) runLoop() {
368368
if ctx.Err() != nil {
369369
// Context canceled errors may come from websocket pings, so we
370370
// don't want to use `errors.Is(err, context.Canceled)` here.
371+
a.logger.Warn(ctx, "runLoop exited with error", slog.Error(ctx.Err()))
371372
return
372373
}
373374
if a.isClosed() {
375+
a.logger.Warn(ctx, "runLoop exited because agent is closed")
374376
return
375377
}
376378
if errors.Is(err, io.EOF) {
@@ -1051,7 +1053,11 @@ func (a *agent) run() (retErr error) {
10511053
return a.statsReporter.reportLoop(ctx, aAPI)
10521054
})
10531055

1054-
return connMan.wait()
1056+
err = connMan.wait()
1057+
if err != nil {
1058+
a.logger.Info(context.Background(), "connection manager errored", slog.Error(err))
1059+
}
1060+
return err
10551061
}
10561062

10571063
// handleManifest returns a function that fetches and processes the manifest

cli/agent.go

Lines changed: 70 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ import (
2525
"cdr.dev/slog/sloggers/sloghuman"
2626
"cdr.dev/slog/sloggers/slogjson"
2727
"cdr.dev/slog/sloggers/slogstackdriver"
28+
"github.com/coder/serpent"
29+
2830
"github.com/coder/coder/v2/agent"
2931
"github.com/coder/coder/v2/agent/agentexec"
3032
"github.com/coder/coder/v2/agent/agentssh"
@@ -33,7 +35,6 @@ import (
3335
"github.com/coder/coder/v2/cli/clilog"
3436
"github.com/coder/coder/v2/codersdk"
3537
"github.com/coder/coder/v2/codersdk/agentsdk"
36-
"github.com/coder/serpent"
3738
)
3839

3940
func (r *RootCmd) workspaceAgent() *serpent.Command {
@@ -63,8 +64,10 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
6364
// This command isn't useful to manually execute.
6465
Hidden: true,
6566
Handler: func(inv *serpent.Invocation) error {
66-
ctx, cancel := context.WithCancel(inv.Context())
67-
defer cancel()
67+
ctx, cancel := context.WithCancelCause(inv.Context())
68+
defer func() {
69+
cancel(xerrors.New("agent exited"))
70+
}()
6871

6972
var (
7073
ignorePorts = map[int]string{}
@@ -281,7 +284,6 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
281284
return xerrors.Errorf("add executable to $PATH: %w", err)
282285
}
283286

284-
prometheusRegistry := prometheus.NewRegistry()
285287
subsystemsRaw := inv.Environ.Get(agent.EnvAgentSubsystem)
286288
subsystems := []codersdk.AgentSubsystem{}
287289
for _, s := range strings.Split(subsystemsRaw, ",") {
@@ -325,46 +327,70 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
325327
logger.Info(ctx, "agent devcontainer detection not enabled")
326328
}
327329

328-
agnt := agent.New(agent.Options{
329-
Client: client,
330-
Logger: logger,
331-
LogDir: logDir,
332-
ScriptDataDir: scriptDataDir,
333-
// #nosec G115 - Safe conversion as tailnet listen port is within uint16 range (0-65535)
334-
TailnetListenPort: uint16(tailnetListenPort),
335-
ExchangeToken: func(ctx context.Context) (string, error) {
336-
if exchangeToken == nil {
337-
return client.SDK.SessionToken(), nil
338-
}
339-
resp, err := exchangeToken(ctx)
340-
if err != nil {
341-
return "", err
342-
}
343-
client.SetSessionToken(resp.SessionToken)
344-
return resp.SessionToken, nil
345-
},
346-
EnvironmentVariables: environmentVariables,
347-
IgnorePorts: ignorePorts,
348-
SSHMaxTimeout: sshMaxTimeout,
349-
Subsystems: subsystems,
350-
351-
PrometheusRegistry: prometheusRegistry,
352-
BlockFileTransfer: blockFileTransfer,
353-
Execer: execer,
354-
SubAgent: subAgent,
355-
356-
ExperimentalDevcontainersEnabled: experimentalDevcontainersEnabled,
357-
})
358-
359-
promHandler := agent.PrometheusMetricsHandler(prometheusRegistry, logger)
360-
prometheusSrvClose := ServeHandler(ctx, logger, promHandler, prometheusAddress, "prometheus")
361-
defer prometheusSrvClose()
362-
363-
debugSrvClose := ServeHandler(ctx, logger, agnt.HTTPDebug(), debugAddress, "debug")
364-
defer debugSrvClose()
365-
366-
<-ctx.Done()
367-
return agnt.Close()
330+
reinitEvents := agentsdk.WaitForReinitLoop(ctx, logger, client)
331+
332+
var (
333+
lastErr error
334+
mustExit bool
335+
)
336+
for {
337+
prometheusRegistry := prometheus.NewRegistry()
338+
339+
agnt := agent.New(agent.Options{
340+
Client: client,
341+
Logger: logger,
342+
LogDir: logDir,
343+
ScriptDataDir: scriptDataDir,
344+
// #nosec G115 - Safe conversion as tailnet listen port is within uint16 range (0-65535)
345+
TailnetListenPort: uint16(tailnetListenPort),
346+
ExchangeToken: func(ctx context.Context) (string, error) {
347+
if exchangeToken == nil {
348+
return client.SDK.SessionToken(), nil
349+
}
350+
resp, err := exchangeToken(ctx)
351+
if err != nil {
352+
return "", err
353+
}
354+
client.SetSessionToken(resp.SessionToken)
355+
return resp.SessionToken, nil
356+
},
357+
EnvironmentVariables: environmentVariables,
358+
IgnorePorts: ignorePorts,
359+
SSHMaxTimeout: sshMaxTimeout,
360+
Subsystems: subsystems,
361+
362+
PrometheusRegistry: prometheusRegistry,
363+
BlockFileTransfer: blockFileTransfer,
364+
Execer: execer,
365+
SubAgent: subAgent,
366+
ExperimentalDevcontainersEnabled: experimentalDevcontainersEnabled,
367+
})
368+
369+
promHandler := agent.PrometheusMetricsHandler(prometheusRegistry, logger)
370+
prometheusSrvClose := ServeHandler(ctx, logger, promHandler, prometheusAddress, "prometheus")
371+
372+
debugSrvClose := ServeHandler(ctx, logger, agnt.HTTPDebug(), debugAddress, "debug")
373+
374+
select {
375+
case <-ctx.Done():
376+
logger.Info(ctx, "agent shutting down", slog.Error(context.Cause(ctx)))
377+
mustExit = true
378+
case event := <-reinitEvents:
379+
logger.Info(ctx, "agent received instruction to reinitialize",
380+
slog.F("workspace_id", event.WorkspaceID), slog.F("reason", event.Reason))
381+
}
382+
383+
lastErr = agnt.Close()
384+
debugSrvClose()
385+
prometheusSrvClose()
386+
387+
if mustExit {
388+
break
389+
}
390+
391+
logger.Info(ctx, "agent reinitializing")
392+
}
393+
return lastErr
368394
},
369395
}
370396

coderd/apidoc/docs.go

Lines changed: 45 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/apidoc/swagger.json

Lines changed: 37 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/coderd.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ import (
1919
"sync/atomic"
2020
"time"
2121

22+
"github.com/coder/coder/v2/coderd/prebuilds"
23+
2224
"github.com/andybalholm/brotli"
2325
"github.com/go-chi/chi/v5"
2426
"github.com/go-chi/chi/v5/middleware"
@@ -47,7 +49,6 @@ import (
4749
"github.com/coder/coder/v2/coderd/entitlements"
4850
"github.com/coder/coder/v2/coderd/files"
4951
"github.com/coder/coder/v2/coderd/idpsync"
50-
"github.com/coder/coder/v2/coderd/prebuilds"
5152
"github.com/coder/coder/v2/coderd/runtimeconfig"
5253
"github.com/coder/coder/v2/coderd/webpush"
5354

@@ -1299,6 +1300,7 @@ func New(options *Options) *API {
12991300
r.Get("/external-auth", api.workspaceAgentsExternalAuth)
13001301
r.Get("/gitsshkey", api.agentGitSSHKey)
13011302
r.Post("/log-source", api.workspaceAgentPostLogSource)
1303+
r.Get("/reinit", api.workspaceAgentReinit)
13021304
})
13031305
r.Route("/{workspaceagent}", func(r chi.Router) {
13041306
r.Use(

coderd/coderdtest/coderdtest.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1105,6 +1105,69 @@ func (w WorkspaceAgentWaiter) MatchResources(m func([]codersdk.WorkspaceResource
11051105
return w
11061106
}
11071107

1108+
// WaitForAgentFn represents a boolean assertion to be made against each agent
1109+
// that a given WorkspaceAgentWaiter knows about. Each WaitForAgentFn should apply
1110+
// the check to a single agent, but it should have a plural name, because `func (w WorkspaceAgentWaiter) WaitFor`
1111+
// applies the check to all agents that it is aware of. This ensures that the public API of the waiter
1112+
// reads correctly. For example:
1113+
//
1114+
// waiter := coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID)
1115+
// waiter.WaitFor(coderdtest.AgentsReady)
1116+
type WaitForAgentFn func(agent codersdk.WorkspaceAgent) bool
1117+
1118+
// AgentsReady checks that the latest lifecycle state of an agent is "Ready".
1119+
func AgentsReady(agent codersdk.WorkspaceAgent) bool {
1120+
return agent.LifecycleState == codersdk.WorkspaceAgentLifecycleReady
1121+
}
1122+
1123+
// AgentsNotReady checks that the latest lifecycle state of an agent is anything except "Ready".
1124+
func AgentsNotReady(agent codersdk.WorkspaceAgent) bool {
1125+
return !AgentsReady(agent)
1126+
}
1127+
1128+
func (w WorkspaceAgentWaiter) WaitFor(criteria ...WaitForAgentFn) {
1129+
w.t.Helper()
1130+
1131+
agentNamesMap := make(map[string]struct{}, len(w.agentNames))
1132+
for _, name := range w.agentNames {
1133+
agentNamesMap[name] = struct{}{}
1134+
}
1135+
1136+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
1137+
defer cancel()
1138+
1139+
w.t.Logf("waiting for workspace agents (workspace %s)", w.workspaceID)
1140+
require.Eventually(w.t, func() bool {
1141+
var err error
1142+
workspace, err := w.client.Workspace(ctx, w.workspaceID)
1143+
if err != nil {
1144+
return false
1145+
}
1146+
if workspace.LatestBuild.Job.CompletedAt == nil {
1147+
return false
1148+
}
1149+
if workspace.LatestBuild.Job.CompletedAt.IsZero() {
1150+
return false
1151+
}
1152+
1153+
for _, resource := range workspace.LatestBuild.Resources {
1154+
for _, agent := range resource.Agents {
1155+
if len(w.agentNames) > 0 {
1156+
if _, ok := agentNamesMap[agent.Name]; !ok {
1157+
continue
1158+
}
1159+
}
1160+
for _, criterium := range criteria {
1161+
if !criterium(agent) {
1162+
return false
1163+
}
1164+
}
1165+
}
1166+
}
1167+
return true
1168+
}, testutil.WaitLong, testutil.IntervalMedium)
1169+
}
1170+
11081171
// Wait waits for the agent(s) to connect and fails the test if they do not within testutil.WaitLong
11091172
func (w WorkspaceAgentWaiter) Wait() []codersdk.WorkspaceResource {
11101173
w.t.Helper()

coderd/database/dbauthz/dbauthz.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3020,6 +3020,15 @@ func (q *querier) GetWorkspaceAgentsByResourceIDs(ctx context.Context, ids []uui
30203020
return q.db.GetWorkspaceAgentsByResourceIDs(ctx, ids)
30213021
}
30223022

3023+
func (q *querier) GetWorkspaceAgentsByWorkspaceAndBuildNumber(ctx context.Context, arg database.GetWorkspaceAgentsByWorkspaceAndBuildNumberParams) ([]database.WorkspaceAgent, error) {
3024+
_, err := q.GetWorkspaceByID(ctx, arg.WorkspaceID)
3025+
if err != nil {
3026+
return nil, err
3027+
}
3028+
3029+
return q.db.GetWorkspaceAgentsByWorkspaceAndBuildNumber(ctx, arg)
3030+
}
3031+
30233032
func (q *querier) GetWorkspaceAgentsCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.WorkspaceAgent, error) {
30243033
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
30253034
return nil, err

0 commit comments

Comments
 (0)