diff --git a/agent/agent.go b/agent/agent.go
index 9b4beca64a32e..b1bd45d50d0d4 100644
--- a/agent/agent.go
+++ b/agent/agent.go
@@ -90,7 +90,6 @@ type Options struct {
 
 type Client interface {
 	ConnectRPC(ctx context.Context) (drpc.Conn, error)
-	PostLifecycle(ctx context.Context, state agentsdk.PostLifecycleRequest) error
 	PostMetadata(ctx context.Context, req agentsdk.PostMetadataRequest) error
 	RewriteDERPMap(derpMap *tailcfg.DERPMap)
 }
@@ -299,7 +298,6 @@ func (a *agent) init() {
 // may be happening, but regardless after the intermittent
 // failure, you'll want the agent to reconnect.
 func (a *agent) runLoop() {
-	go a.reportLifecycleUntilClose()
 	go a.reportMetadataUntilGracefulShutdown()
 	go a.manageProcessPriorityUntilGracefulShutdown()
 
@@ -618,21 +616,23 @@ func (a *agent) reportMetadataUntilGracefulShutdown() {
 	}
 }
 
-// reportLifecycleUntilClose reports the current lifecycle state once. All state
+// reportLifecycle reports the current lifecycle state once. All state
 // changes are reported in order.
+//
+// States are sent with UpdateLifecycle on the dRPC agent API over conn. It
+// returns ctx.Err() once ctx is done while waiting for an update, and a
+// wrapped error as soon as an UpdateLifecycle call fails.
-func (a *agent) reportLifecycleUntilClose() {
-	// part of graceful shut down is reporting the final lifecycle states, e.g "ShuttingDown" so the
-	// lifecycle reporting has to be via the "hard" context.
-	ctx := a.hardCtx
+func (a *agent) reportLifecycle(ctx context.Context, conn drpc.Conn) error {
+	aAPI := proto.NewDRPCAgentClient(conn)
 	lastReportedIndex := 0 // Start off with the created state without reporting it.
 	for {
 		select {
 		case <-a.lifecycleUpdate:
 		case <-ctx.Done():
-			return
+			return ctx.Err()
 		}
 
-		for r := retry.New(time.Second, 15*time.Second); r.Wait(ctx); {
+		for {
 			a.lifecycleMu.RLock()
 			lastIndex := len(a.lifecycleStates) - 1
 			report := a.lifecycleStates[lastReportedIndex]
@@ -644,33 +644,36 @@ func (a *agent) reportLifecycleUntilClose() {
 			if lastIndex == lastReportedIndex {
 				break
 			}
+			l, err := agentsdk.ProtoFromLifecycle(report)
+			if err != nil {
+				a.logger.Critical(ctx, "failed to convert lifecycle state", slog.F("report", report), slog.Error(err))
+				// Skip this report; there is no point retrying. Maybe we can successfully convert the next one?
+				lastReportedIndex++
+				continue
+			}
+			payload := &proto.UpdateLifecycleRequest{Lifecycle: l}
+			logger := a.logger.With(slog.F("payload", payload))
+			logger.Debug(ctx, "reporting lifecycle state")
 
-			a.logger.Debug(ctx, "reporting lifecycle state", slog.F("payload", report))
+			_, err = aAPI.UpdateLifecycle(ctx, payload)
+			if err != nil {
+				return xerrors.Errorf("failed to update lifecycle: %w", err)
+			}
 
-			err := a.client.PostLifecycle(ctx, report)
-			if err == nil {
-				a.logger.Debug(ctx, "successfully reported lifecycle state", slog.F("payload", report))
-				r.Reset() // don't back off when we are successful
-				lastReportedIndex++
-				select {
-				case a.lifecycleReported <- report.State:
-				case <-a.lifecycleReported:
-					a.lifecycleReported <- report.State
-				}
-				if lastReportedIndex < lastIndex {
-					// Keep reporting until we've sent all messages, we can't
-					// rely on the channel triggering us before the backlog is
-					// consumed.
-					continue
-				}
-				break
+			logger.Debug(ctx, "successfully reported lifecycle state")
+			lastReportedIndex++
+			select {
+			case a.lifecycleReported <- report.State:
+			case <-a.lifecycleReported:
+				a.lifecycleReported <- report.State
 			}
-			if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
-				a.logger.Debug(ctx, "canceled reporting lifecycle state", slog.F("payload", report))
-				return
+			if lastReportedIndex < lastIndex {
+				// Keep reporting until we've sent all messages, we can't
+				// rely on the channel triggering us before the backlog is
+				// consumed.
+				continue
 			}
-			// If we fail to report the state we probably shouldn't exit, log only.
-			a.logger.Error(ctx, "agent failed to report the lifecycle state", slog.Error(err))
+			break
 		}
 	}
 }
@@ -780,6 +783,10 @@ func (a *agent) run() (retErr error) {
 			return err
 		})
 
+	// Part of graceful shutdown is reporting the final lifecycle states, e.g. "ShuttingDown", so
+	// lifecycle reporting has to use gracefulShutdownBehaviorRemain.
+	connMan.start("report lifecycle", gracefulShutdownBehaviorRemain, a.reportLifecycle)
+
 	// channels to sync goroutines below
 	// handle manifest
 	// |
diff --git a/agent/agenttest/client.go b/agent/agenttest/client.go
index b4bbd4feb7a32..19dc19372b36f 100644
--- a/agent/agenttest/client.go
+++ b/agent/agenttest/client.go
@@ -9,8 +9,10 @@ import (
 	"time"
 
 	"github.com/google/uuid"
+	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 	"golang.org/x/exp/maps"
+	"golang.org/x/exp/slices"
 	"golang.org/x/xerrors"
 	"google.golang.org/protobuf/types/known/durationpb"
 	"storj.io/drpc"
@@ -86,11 +88,10 @@ type Client struct {
 	fakeAgentAPI       *FakeAgentAPI
 	LastWorkspaceAgent func()
 
-	mu              sync.Mutex // Protects following.
-	lifecycleStates []codersdk.WorkspaceAgentLifecycle
-	logs            []agentsdk.Log
-	derpMapUpdates  chan *tailcfg.DERPMap
-	derpMapOnce     sync.Once
+	mu             sync.Mutex // Protects following.
+	logs           []agentsdk.Log
+	derpMapUpdates chan *tailcfg.DERPMap
+	derpMapOnce    sync.Once
 }
 
 func (*Client) RewriteDERPMap(*tailcfg.DERPMap) {}
@@ -122,17 +123,7 @@ func (c *Client) ConnectRPC(ctx context.Context) (drpc.Conn, error) {
 }
 
 func (c *Client) GetLifecycleStates() []codersdk.WorkspaceAgentLifecycle {
-	c.mu.Lock()
-	defer c.mu.Unlock()
-	return c.lifecycleStates
-}
-
-func (c *Client) PostLifecycle(ctx context.Context, req agentsdk.PostLifecycleRequest) error {
-	c.mu.Lock()
-	defer c.mu.Unlock()
-	c.lifecycleStates = append(c.lifecycleStates, req.State)
-	c.logger.Debug(ctx, "post lifecycle", slog.F("req", req))
-	return nil
+	return c.fakeAgentAPI.GetLifecycleStates()
 }
 
 func (c *Client) GetStartup() <-chan *agentproto.Startup {
@@ -189,11 +180,12 @@ type FakeAgentAPI struct {
 	t      testing.TB
 	logger slog.Logger
 
-	manifest    *agentproto.Manifest
-	startupCh   chan *agentproto.Startup
-	statsCh     chan *agentproto.Stats
-	appHealthCh chan *agentproto.BatchUpdateAppHealthRequest
-	logsCh      chan<- *agentproto.BatchCreateLogsRequest
+	manifest        *agentproto.Manifest
+	startupCh       chan *agentproto.Startup
+	statsCh         chan *agentproto.Stats
+	appHealthCh     chan *agentproto.BatchUpdateAppHealthRequest
+	logsCh          chan<- *agentproto.BatchCreateLogsRequest
+	lifecycleStates []codersdk.WorkspaceAgentLifecycle
 
 	getServiceBannerFunc func() (codersdk.ServiceBannerConfig, error)
 }
@@ -231,9 +223,22 @@ func (f *FakeAgentAPI) UpdateStats(ctx context.Context, req *agentproto.UpdateSt
 	return &agentproto.UpdateStatsResponse{ReportInterval: durationpb.New(statsInterval)}, nil
 }
 
-func (*FakeAgentAPI) UpdateLifecycle(context.Context, *agentproto.UpdateLifecycleRequest) (*agentproto.Lifecycle, error) {
-	// TODO implement me
-	panic("implement me")
+// GetLifecycleStates returns a copy of the lifecycle states recorded by UpdateLifecycle.
+func (f *FakeAgentAPI) GetLifecycleStates() []codersdk.WorkspaceAgentLifecycle {
+	f.Lock()
+	defer f.Unlock()
+	return slices.Clone(f.lifecycleStates)
+}
+
+// UpdateLifecycle records the reported state, failing the test if it cannot be converted, and echoes the lifecycle back.
+func (f *FakeAgentAPI) UpdateLifecycle(_ context.Context, req *agentproto.UpdateLifecycleRequest) (*agentproto.Lifecycle, error) {
+	f.Lock()
+	defer f.Unlock()
+	s, err := agentsdk.LifecycleStateFromProto(req.GetLifecycle().GetState())
+	if assert.NoError(f.t, err) {
+		f.lifecycleStates = append(f.lifecycleStates, s)
+	}
+	return req.GetLifecycle(), nil
 }
 
 func (f *FakeAgentAPI) BatchUpdateAppHealths(ctx context.Context, req *agentproto.BatchUpdateAppHealthRequest) (*agentproto.BatchUpdateAppHealthResponse, error) {
diff --git a/codersdk/agentsdk/agentsdk.go b/codersdk/agentsdk/agentsdk.go
index 6d225dbfae29c..bde518cb762d7 100644
--- a/codersdk/agentsdk/agentsdk.go
+++ b/codersdk/agentsdk/agentsdk.go
@@ -485,6 +485,9 @@ type PostLifecycleRequest struct {
 	ChangedAt time.Time `json:"changed_at"`
 }
 
+// PostLifecycle posts the agent's lifecycle to the Coder server.
+//
+// Deprecated: Use UpdateLifecycle on the dRPC API instead.
 func (c *Client) PostLifecycle(ctx context.Context, req PostLifecycleRequest) error {
 	res, err := c.SDK.Request(ctx, http.MethodPost, "/api/v2/workspaceagents/me/report-lifecycle", req)
 	if err != nil {
diff --git a/codersdk/agentsdk/convert.go b/codersdk/agentsdk/convert.go
index 9628f1d05eb49..c872a81b1d2ed 100644
--- a/codersdk/agentsdk/convert.go
+++ b/codersdk/agentsdk/convert.go
@@ -311,3 +311,28 @@ func ProtoFromLog(log Log) (*proto.Log, error) {
 		Level: proto.Log_Level(lvl),
 	}, nil
 }
+
+// ProtoFromLifecycle converts a PostLifecycleRequest into its protobuf
+// Lifecycle form. It returns an error if the upper-cased req.State has no
+// entry in proto.Lifecycle_State_value.
+func ProtoFromLifecycle(req PostLifecycleRequest) (*proto.Lifecycle, error) {
+	s, ok := proto.Lifecycle_State_value[strings.ToUpper(string(req.State))]
+	if !ok {
+		return nil, xerrors.Errorf("unknown lifecycle state: %s", req.State)
+	}
+	return &proto.Lifecycle{
+		State:     proto.Lifecycle_State(s),
+		ChangedAt: timestamppb.New(req.ChangedAt),
+	}, nil
+}
+
+// LifecycleStateFromProto converts a protobuf lifecycle state into a
+// codersdk.WorkspaceAgentLifecycle by lower-casing the enum name. It returns
+// an error if s has no entry in proto.Lifecycle_State_name.
+func LifecycleStateFromProto(s proto.Lifecycle_State) (codersdk.WorkspaceAgentLifecycle, error) {
+	caps, ok := proto.Lifecycle_State_name[int32(s)]
+	if !ok {
+		return "", xerrors.Errorf("unknown lifecycle state: %d", s)
+	}
+	return codersdk.WorkspaceAgentLifecycle(strings.ToLower(caps)), nil
+}
diff --git a/codersdk/agentsdk/convert_test.go b/codersdk/agentsdk/convert_test.go
index 3417416b43fe1..3519408b6f872 100644
--- a/codersdk/agentsdk/convert_test.go
+++ b/codersdk/agentsdk/convert_test.go
@@ -9,6 +9,7 @@ import (
 	"tailscale.com/tailcfg"
 
 	"github.com/coder/coder/v2/agent/proto"
+	"github.com/coder/coder/v2/coderd/database/dbtime"
 	"github.com/coder/coder/v2/codersdk"
 	"github.com/coder/coder/v2/codersdk/agentsdk"
 	"github.com/coder/coder/v2/tailnet"
@@ -161,3 +162,21 @@ func TestSubsystems(t *testing.T) {
 		proto.Startup_EXECTRACE,
 	})
 }
+
+// TestProtoFromLifecycle checks that every state in
+// codersdk.WorkspaceAgentLifecycleOrder survives a round trip through
+// ProtoFromLifecycle and LifecycleStateFromProto.
+func TestProtoFromLifecycle(t *testing.T) {
+	t.Parallel()
+	now := dbtime.Now()
+	for _, s := range codersdk.WorkspaceAgentLifecycleOrder {
+		sr := agentsdk.PostLifecycleRequest{State: s, ChangedAt: now}
+		pr, err := agentsdk.ProtoFromLifecycle(sr)
+		require.NoError(t, err)
+		// Compare instants with Equal; struct equality also compares Location.
+		require.True(t, now.Equal(pr.ChangedAt.AsTime()), "changed_at mismatch for state %q", s)
+		state, err := agentsdk.LifecycleStateFromProto(pr.State)
+		require.NoError(t, err)
+		require.Equal(t, s, state)
+	}
+}