Skip to content

Commit 93ce739

Browse files
feat: reinitialize agents when a prebuilt workspace is claimed (#17475)
This pull request allows coder workspace agents to be reinitialized when a prebuilt workspace is claimed by a user. This facilitates the transfer of ownership between the anonymous prebuilds system user and the new owner of the workspace. Only a single agent per prebuilt workspace is supported for now, but plumbing has already been done to facilitate the seamless transition to multi-agent support. --------- Signed-off-by: Danny Kopping <dannykopping@gmail.com> Co-authored-by: Danny Kopping <dannykopping@gmail.com> Signed-off-by: Danny Kopping <dannykopping@gmail.com>
1 parent ca314ba commit 93ce739

39 files changed

+2164
-423
lines changed

agent/agent.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,9 +363,11 @@ func (a *agent) runLoop() {
363363
if ctx.Err() != nil {
364364
// Context canceled errors may come from websocket pings, so we
365365
// don't want to use `errors.Is(err, context.Canceled)` here.
366+
a.logger.Warn(ctx, "runLoop exited with error", slog.Error(ctx.Err()))
366367
return
367368
}
368369
if a.isClosed() {
370+
a.logger.Warn(ctx, "runLoop exited because agent is closed")
369371
return
370372
}
371373
if errors.Is(err, io.EOF) {
@@ -1046,7 +1048,11 @@ func (a *agent) run() (retErr error) {
10461048
return a.statsReporter.reportLoop(ctx, aAPI)
10471049
})
10481050

1049-
return connMan.wait()
1051+
err = connMan.wait()
1052+
if err != nil {
1053+
a.logger.Info(context.Background(), "connection manager errored", slog.Error(err))
1054+
}
1055+
return err
10501056
}
10511057

10521058
// handleManifest returns a function that fetches and processes the manifest

cli/agent.go

Lines changed: 60 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ import (
2525
"cdr.dev/slog/sloggers/sloghuman"
2626
"cdr.dev/slog/sloggers/slogjson"
2727
"cdr.dev/slog/sloggers/slogstackdriver"
28+
"github.com/coder/serpent"
29+
2830
"github.com/coder/coder/v2/agent"
2931
"github.com/coder/coder/v2/agent/agentexec"
3032
"github.com/coder/coder/v2/agent/agentssh"
@@ -33,7 +35,6 @@ import (
3335
"github.com/coder/coder/v2/cli/clilog"
3436
"github.com/coder/coder/v2/codersdk"
3537
"github.com/coder/coder/v2/codersdk/agentsdk"
36-
"github.com/coder/serpent"
3738
)
3839

3940
func (r *RootCmd) workspaceAgent() *serpent.Command {
@@ -62,8 +63,10 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
6263
// This command isn't useful to manually execute.
6364
Hidden: true,
6465
Handler: func(inv *serpent.Invocation) error {
65-
ctx, cancel := context.WithCancel(inv.Context())
66-
defer cancel()
66+
ctx, cancel := context.WithCancelCause(inv.Context())
67+
defer func() {
68+
cancel(xerrors.New("agent exited"))
69+
}()
6770

6871
var (
6972
ignorePorts = map[int]string{}
@@ -280,7 +283,6 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
280283
return xerrors.Errorf("add executable to $PATH: %w", err)
281284
}
282285

283-
prometheusRegistry := prometheus.NewRegistry()
284286
subsystemsRaw := inv.Environ.Get(agent.EnvAgentSubsystem)
285287
subsystems := []codersdk.AgentSubsystem{}
286288
for _, s := range strings.Split(subsystemsRaw, ",") {
@@ -324,28 +326,37 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
324326
logger.Info(ctx, "agent devcontainer detection not enabled")
325327
}
326328

327-
agnt := agent.New(agent.Options{
328-
Client: client,
329-
Logger: logger,
330-
LogDir: logDir,
331-
ScriptDataDir: scriptDataDir,
332-
// #nosec G115 - Safe conversion as tailnet listen port is within uint16 range (0-65535)
333-
TailnetListenPort: uint16(tailnetListenPort),
334-
ExchangeToken: func(ctx context.Context) (string, error) {
335-
if exchangeToken == nil {
336-
return client.SDK.SessionToken(), nil
337-
}
338-
resp, err := exchangeToken(ctx)
339-
if err != nil {
340-
return "", err
341-
}
342-
client.SetSessionToken(resp.SessionToken)
343-
return resp.SessionToken, nil
344-
},
345-
EnvironmentVariables: environmentVariables,
346-
IgnorePorts: ignorePorts,
347-
SSHMaxTimeout: sshMaxTimeout,
348-
Subsystems: subsystems,
329+
reinitEvents := agentsdk.WaitForReinitLoop(ctx, logger, client)
330+
331+
var (
332+
lastErr error
333+
mustExit bool
334+
)
335+
for {
336+
prometheusRegistry := prometheus.NewRegistry()
337+
338+
agnt := agent.New(agent.Options{
339+
Client: client,
340+
Logger: logger,
341+
LogDir: logDir,
342+
ScriptDataDir: scriptDataDir,
343+
// #nosec G115 - Safe conversion as tailnet listen port is within uint16 range (0-65535)
344+
TailnetListenPort: uint16(tailnetListenPort),
345+
ExchangeToken: func(ctx context.Context) (string, error) {
346+
if exchangeToken == nil {
347+
return client.SDK.SessionToken(), nil
348+
}
349+
resp, err := exchangeToken(ctx)
350+
if err != nil {
351+
return "", err
352+
}
353+
client.SetSessionToken(resp.SessionToken)
354+
return resp.SessionToken, nil
355+
},
356+
EnvironmentVariables: environmentVariables,
357+
IgnorePorts: ignorePorts,
358+
SSHMaxTimeout: sshMaxTimeout,
359+
Subsystems: subsystems,
349360

350361
PrometheusRegistry: prometheusRegistry,
351362
BlockFileTransfer: blockFileTransfer,
@@ -354,15 +365,31 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
354365
ExperimentalDevcontainersEnabled: experimentalDevcontainersEnabled,
355366
})
356367

357-
promHandler := agent.PrometheusMetricsHandler(prometheusRegistry, logger)
358-
prometheusSrvClose := ServeHandler(ctx, logger, promHandler, prometheusAddress, "prometheus")
359-
defer prometheusSrvClose()
368+
promHandler := agent.PrometheusMetricsHandler(prometheusRegistry, logger)
369+
prometheusSrvClose := ServeHandler(ctx, logger, promHandler, prometheusAddress, "prometheus")
370+
371+
debugSrvClose := ServeHandler(ctx, logger, agnt.HTTPDebug(), debugAddress, "debug")
360372

361-
debugSrvClose := ServeHandler(ctx, logger, agnt.HTTPDebug(), debugAddress, "debug")
362-
defer debugSrvClose()
373+
select {
374+
case <-ctx.Done():
375+
logger.Info(ctx, "agent shutting down", slog.Error(context.Cause(ctx)))
376+
mustExit = true
377+
case event := <-reinitEvents:
378+
logger.Info(ctx, "agent received instruction to reinitialize",
379+
slog.F("workspace_id", event.WorkspaceID), slog.F("reason", event.Reason))
380+
}
381+
382+
lastErr = agnt.Close()
383+
debugSrvClose()
384+
prometheusSrvClose()
363385

364-
<-ctx.Done()
365-
return agnt.Close()
386+
if mustExit {
387+
break
388+
}
389+
390+
logger.Info(ctx, "agent reinitializing")
391+
}
392+
return lastErr
366393
},
367394
}
368395

coderd/apidoc/docs.go

Lines changed: 45 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/apidoc/swagger.json

Lines changed: 37 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/coderd.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ import (
1919
"sync/atomic"
2020
"time"
2121

22+
"github.com/coder/coder/v2/coderd/prebuilds"
23+
2224
"github.com/andybalholm/brotli"
2325
"github.com/go-chi/chi/v5"
2426
"github.com/go-chi/chi/v5/middleware"
@@ -45,7 +47,6 @@ import (
4547
"github.com/coder/coder/v2/coderd/entitlements"
4648
"github.com/coder/coder/v2/coderd/files"
4749
"github.com/coder/coder/v2/coderd/idpsync"
48-
"github.com/coder/coder/v2/coderd/prebuilds"
4950
"github.com/coder/coder/v2/coderd/runtimeconfig"
5051
"github.com/coder/coder/v2/coderd/webpush"
5152

@@ -1278,6 +1279,7 @@ func New(options *Options) *API {
12781279
r.Get("/external-auth", api.workspaceAgentsExternalAuth)
12791280
r.Get("/gitsshkey", api.agentGitSSHKey)
12801281
r.Post("/log-source", api.workspaceAgentPostLogSource)
1282+
r.Get("/reinit", api.workspaceAgentReinit)
12811283
})
12821284
r.Route("/{workspaceagent}", func(r chi.Router) {
12831285
r.Use(

coderd/coderdtest/coderdtest.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1105,6 +1105,69 @@ func (w WorkspaceAgentWaiter) MatchResources(m func([]codersdk.WorkspaceResource
11051105
return w
11061106
}
11071107

1108+
// WaitForAgentFn represents a boolean assertion to be made against each agent
1109+
// that a given WorkspaceAgentWaiter knows about. Each WaitForAgentFn should apply
1110+
// the check to a single agent, but it should be named in the plural, because `func (w WorkspaceAgentWaiter) WaitFor`
1111+
// applies the check to all agents that it is aware of. This ensures that the public API of the waiter
1112+
// reads correctly. For example:
1113+
//
1114+
// waiter := coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID)
1115+
// waiter.WaitFor(coderdtest.AgentsReady)
1116+
type WaitForAgentFn func(agent codersdk.WorkspaceAgent) bool
1117+
1118+
// AgentsReady checks that the latest lifecycle state of an agent is "Ready".
1119+
func AgentsReady(agent codersdk.WorkspaceAgent) bool {
1120+
return agent.LifecycleState == codersdk.WorkspaceAgentLifecycleReady
1121+
}
1122+
1123+
// AgentsNotReady checks that the latest lifecycle state of an agent is anything except "Ready".
1124+
func AgentsNotReady(agent codersdk.WorkspaceAgent) bool {
1125+
return !AgentsReady(agent)
1126+
}
1127+
1128+
func (w WorkspaceAgentWaiter) WaitFor(criteria ...WaitForAgentFn) {
1129+
w.t.Helper()
1130+
1131+
agentNamesMap := make(map[string]struct{}, len(w.agentNames))
1132+
for _, name := range w.agentNames {
1133+
agentNamesMap[name] = struct{}{}
1134+
}
1135+
1136+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
1137+
defer cancel()
1138+
1139+
w.t.Logf("waiting for workspace agents (workspace %s)", w.workspaceID)
1140+
require.Eventually(w.t, func() bool {
1141+
var err error
1142+
workspace, err := w.client.Workspace(ctx, w.workspaceID)
1143+
if err != nil {
1144+
return false
1145+
}
1146+
if workspace.LatestBuild.Job.CompletedAt == nil {
1147+
return false
1148+
}
1149+
if workspace.LatestBuild.Job.CompletedAt.IsZero() {
1150+
return false
1151+
}
1152+
1153+
for _, resource := range workspace.LatestBuild.Resources {
1154+
for _, agent := range resource.Agents {
1155+
if len(w.agentNames) > 0 {
1156+
if _, ok := agentNamesMap[agent.Name]; !ok {
1157+
continue
1158+
}
1159+
}
1160+
for _, criterium := range criteria {
1161+
if !criterium(agent) {
1162+
return false
1163+
}
1164+
}
1165+
}
1166+
}
1167+
return true
1168+
}, testutil.WaitLong, testutil.IntervalMedium)
1169+
}
1170+
11081171
// Wait waits for the agent(s) to connect and fails the test if they do not do so within testutil.WaitLong.
11091172
func (w WorkspaceAgentWaiter) Wait() []codersdk.WorkspaceResource {
11101173
w.t.Helper()

coderd/database/dbauthz/dbauthz.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3001,6 +3001,15 @@ func (q *querier) GetWorkspaceAgentsByResourceIDs(ctx context.Context, ids []uui
30013001
return q.db.GetWorkspaceAgentsByResourceIDs(ctx, ids)
30023002
}
30033003

3004+
func (q *querier) GetWorkspaceAgentsByWorkspaceAndBuildNumber(ctx context.Context, arg database.GetWorkspaceAgentsByWorkspaceAndBuildNumberParams) ([]database.WorkspaceAgent, error) {
3005+
_, err := q.GetWorkspaceByID(ctx, arg.WorkspaceID)
3006+
if err != nil {
3007+
return nil, err
3008+
}
3009+
3010+
return q.db.GetWorkspaceAgentsByWorkspaceAndBuildNumber(ctx, arg)
3011+
}
3012+
30043013
func (q *querier) GetWorkspaceAgentsCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.WorkspaceAgent, error) {
30053014
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
30063015
return nil, err

0 commit comments

Comments
 (0)