From 943991f92c51fe7b325c6ea576a887f9e9572e7c Mon Sep 17 00:00:00 2001
From: Spike Curtis
Date: Thu, 8 Feb 2024 15:47:35 +0400
Subject: [PATCH] feat: use agent v2 API to send agent logs

---
 agent/agent.go                          | 27 +++++++-
 agent/agent_test.go                     | 82 +++++++++++++++++++++--
 agent/agentscripts/agentscripts.go      | 26 +++++---
 agent/agentscripts/agentscripts_test.go | 87 ++++++++++++++++---------
 agent/agenttest/client.go               | 42 +++++++-----
 codersdk/agentsdk/agentsdk.go           |  2 +
 6 files changed, 203 insertions(+), 63 deletions(-)

diff --git a/agent/agent.go b/agent/agent.go
index e0256d2e22987..9b4beca64a32e 100644
--- a/agent/agent.go
+++ b/agent/agent.go
@@ -92,7 +92,6 @@ type Client interface {
 	ConnectRPC(ctx context.Context) (drpc.Conn, error)
 	PostLifecycle(ctx context.Context, state agentsdk.PostLifecycleRequest) error
 	PostMetadata(ctx context.Context, req agentsdk.PostMetadataRequest) error
-	PatchLogs(ctx context.Context, req agentsdk.PatchLogs) error
 	RewriteDERPMap(derpMap *tailcfg.DERPMap)
 }
 
@@ -181,6 +180,7 @@ func New(options Options) Agent {
 		syscaller:             options.Syscaller,
 		modifiedProcs:         options.ModifiedProcesses,
 		processManagementTick: options.ProcessManagementTick,
+		logSender:             agentsdk.NewLogSender(options.Logger),
 
 		prometheusRegistry: prometheusRegistry,
 		metrics:            newAgentMetrics(prometheusRegistry),
@@ -245,6 +245,7 @@ type agent struct {
 	network       *tailnet.Conn
 	addresses     []netip.Prefix
 	statsReporter *statsReporter
+	logSender     *agentsdk.LogSender
 
 	connCountReconnectingPTY atomic.Int64
 
@@ -283,7 +284,9 @@ func (a *agent) init() {
 		Logger:      a.logger,
 		SSHServer:   sshSrv,
 		Filesystem:  a.filesystem,
-		PatchLogs:   a.client.PatchLogs,
+		GetScriptLogger: func(logSourceID uuid.UUID) agentscripts.ScriptLogger {
+			return a.logSender.GetScriptLogger(logSourceID)
+		},
 	})
 	// Register runner metrics. If the prom registry is nil, the metrics
 	// will not report anywhere.
@@ -763,6 +766,20 @@ func (a *agent) run() (retErr error) {
 		},
 	)
 
+	// sending logs gets gracefulShutdownBehaviorRemain because we want to send logs generated by
+	// shutdown scripts.
+	connMan.start("send logs", gracefulShutdownBehaviorRemain,
+		func(ctx context.Context, conn drpc.Conn) error {
+			err := a.logSender.SendLoop(ctx, proto.NewDRPCAgentClient(conn))
+			if xerrors.Is(err, agentsdk.LogLimitExceededError) {
+				// we don't want this error to tear down the API connection and propagate to the
+				// other routines that use the API. The LogSender has already dropped a warning
+				// log, so just return nil here.
+				return nil
+			}
+			return err
+		})
+
 	// channels to sync goroutines below
 	//          handle manifest
 	//          |
@@ -1769,6 +1786,12 @@ lifecycleWaitLoop:
 		a.logger.Debug(context.Background(), "coordinator RPC disconnected")
 	}
 
+	// Wait for logs to be sent
+	err = a.logSender.WaitUntilEmpty(a.hardCtx)
+	if err != nil {
+		a.logger.Warn(context.Background(), "timed out waiting for all logs to be sent", slog.Error(err))
+	}
+
	a.hardCancel()
 	if a.network != nil {
 		_ = a.network.Close()
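
Taken together, the agent.go hunks describe a single lifecycle for the new agentsdk.LogSender: script runners enqueue output through GetScriptLogger, one connection-manager routine per API connection drives SendLoop and treats LogLimitExceededError as benign, and shutdown drains whatever is still queued with WaitUntilEmpty under the hard context. The sketch below restates that lifecycle in isolation. It is illustrative only: the package name, function name, and argument list are invented for the example, and the import paths are assumed to match the rest of the repository; everything else is the API used by the hunks above.

package example

import (
	"context"

	"cdr.dev/slog"
	"github.com/google/uuid"
	"golang.org/x/xerrors"
	"storj.io/drpc"

	"github.com/coder/coder/v2/agent/proto"
	"github.com/coder/coder/v2/codersdk"
	"github.com/coder/coder/v2/codersdk/agentsdk"
)

// logSenderLifecycle shows the three stages the agent wires up: enqueue via a
// ScriptLogger, send over the v2 RPC API, and drain on shutdown.
func logSenderLifecycle(ctx, hardCtx context.Context, logger slog.Logger, conn drpc.Conn, sourceID uuid.UUID) error {
	ls := agentsdk.NewLogSender(logger)

	// Producers (e.g. the script runner) enqueue logs without touching the network.
	sl := ls.GetScriptLogger(sourceID)
	_ = sl.Send(ctx, agentsdk.Log{Level: codersdk.LogLevelInfo, Output: "hello"})
	_ = sl.Flush(ctx)

	// One routine per API connection pushes queued logs to the server.
	err := ls.SendLoop(ctx, proto.NewDRPCAgentClient(conn))
	if xerrors.Is(err, agentsdk.LogLimitExceededError) {
		// The server will accept no more logs for this build; not fatal.
		err = nil
	}

	// On shutdown, wait (bounded by the hard context) for queued logs to flush.
	if waitErr := ls.WaitUntilEmpty(hardCtx); waitErr != nil {
		logger.Warn(hardCtx, "timed out waiting for logs to send", slog.Error(waitErr))
	}
	return err
}
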
diff --git a/agent/agent_test.go b/agent/agent_test.go
index 573440769806c..d3c400173a96b 100644
--- a/agent/agent_test.go
+++ b/agent/agent_test.go
@@ -2062,6 +2062,80 @@ func TestAgent_DebugServer(t *testing.T) {
 	})
 }
 
+func TestAgent_ScriptLogging(t *testing.T) {
+	if runtime.GOOS == "windows" {
+		t.Skip("bash scripts only")
+	}
+	t.Parallel()
+	ctx := testutil.Context(t, testutil.WaitMedium)
+
+	derpMap, _ := tailnettest.RunDERPAndSTUN(t)
+	logsCh := make(chan *proto.BatchCreateLogsRequest, 100)
+	lsStart := uuid.UUID{0x11}
+	lsStop := uuid.UUID{0x22}
+	//nolint:dogsled
+	_, _, _, _, agnt := setupAgent(
+		t,
+		agentsdk.Manifest{
+			DERPMap: derpMap,
+			Scripts: []codersdk.WorkspaceAgentScript{
+				{
+					LogSourceID: lsStart,
+					RunOnStart:  true,
+					Script: `#!/bin/sh
+i=0
+while [ $i -ne 5 ]
+do
+	i=$(($i+1))
+	echo "start $i"
+done
+`,
+				},
+				{
+					LogSourceID: lsStop,
+					RunOnStop:   true,
+					Script: `#!/bin/sh
+i=0
+while [ $i -ne 3000 ]
+do
+	i=$(($i+1))
+	echo "stop $i"
+done
+`, // send a lot of stop logs to make sure we don't truncate shutdown logs before closing the API conn
+				},
+			},
+		},
+		0,
+		func(cl *agenttest.Client, _ *agent.Options) {
+			cl.SetLogsChannel(logsCh)
+		},
+	)
+
+	n := 1
+	for n <= 5 {
+		logs := testutil.RequireRecvCtx(ctx, t, logsCh)
+		require.NotNil(t, logs)
+		for _, l := range logs.GetLogs() {
+			require.Equal(t, fmt.Sprintf("start %d", n), l.GetOutput())
+			n++
+		}
+	}
+
+	err := agnt.Close()
+	require.NoError(t, err)
+
+	n = 1
+	for n <= 3000 {
+		logs := testutil.RequireRecvCtx(ctx, t, logsCh)
+		require.NotNil(t, logs)
+		for _, l := range logs.GetLogs() {
+			require.Equal(t, fmt.Sprintf("stop %d", n), l.GetOutput())
+			n++
+		}
+		t.Logf("got %d stop logs", n-1)
+	}
+}
+
 // setupAgentSSHClient creates an agent, dials it, and sets up an ssh.Client for it
 func setupAgentSSHClient(ctx context.Context, t *testing.T) *ssh.Client {
 	//nolint: dogsled
@@ -2137,7 +2211,7 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati
 	})
 	statsCh := make(chan *proto.Stats, 50)
 	fs := afero.NewMemMapFs()
-	c := agenttest.NewClient(t, logger.Named("agent"), metadata.AgentID, metadata, statsCh, coordinator)
+	c := agenttest.NewClient(t, logger.Named("agenttest"), metadata.AgentID, metadata, statsCh, coordinator)
 	t.Cleanup(c.Close)
 
 	options := agent.Options{
@@ -2152,9 +2226,9 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati
 		opt(c, &options)
 	}
 
-	closer := agent.New(options)
+	agnt := agent.New(options)
 	t.Cleanup(func() {
-		_ = closer.Close()
+		_ = agnt.Close()
 	})
 	conn, err := tailnet.NewConn(&tailnet.Options{
 		Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
@@ -2191,7 +2265,7 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati
 	if !agentConn.AwaitReachable(ctx) {
 		t.Fatal("agent not reachable")
 	}
-	return agentConn, c, statsCh, fs, closer
+	return agentConn, c, statsCh, fs, agnt
 }
 
 var dialTestPayload = []byte("dean-was-here123")
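
TestAgent_ScriptLogging iterates over logs.GetLogs() inside every received request because the agent is free to split the 3000 stop lines across any number of BatchCreateLogs batches. If more tests consume the channel this way, a helper along these lines could flatten batches before asserting. collectOutputs is hypothetical, not part of the patch, and uses only identifiers agent_test.go already imports:

// collectOutputs drains batched log requests from ch until n lines have been
// seen, so assertions do not depend on how the agent batches its logs.
func collectOutputs(ctx context.Context, t *testing.T, ch chan *proto.BatchCreateLogsRequest, n int) []string {
	t.Helper()
	outputs := make([]string, 0, n)
	for len(outputs) < n {
		batch := testutil.RequireRecvCtx(ctx, t, ch)
		for _, l := range batch.GetLogs() {
			outputs = append(outputs, l.GetOutput())
		}
	}
	return outputs
}
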
diff --git a/agent/agentscripts/agentscripts.go b/agent/agentscripts/agentscripts.go
index e7169f9fdb699..dea9413b8e2a8 100644
--- a/agent/agentscripts/agentscripts.go
+++ b/agent/agentscripts/agentscripts.go
@@ -13,6 +13,7 @@ import (
 	"sync/atomic"
 	"time"
 
+	"github.com/google/uuid"
 	"github.com/prometheus/client_golang/prometheus"
 	"github.com/robfig/cron/v3"
 	"github.com/spf13/afero"
@@ -41,14 +42,19 @@ var (
 	parser = cron.NewParser(cron.Second | cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.DowOptional)
 )
 
+type ScriptLogger interface {
+	Send(ctx context.Context, log ...agentsdk.Log) error
+	Flush(context.Context) error
+}
+
 // Options are a set of options for the runner.
 type Options struct {
-	DataDirBase string
-	LogDir      string
-	Logger      slog.Logger
-	SSHServer   *agentssh.Server
-	Filesystem  afero.Fs
-	PatchLogs   func(ctx context.Context, req agentsdk.PatchLogs) error
+	DataDirBase     string
+	LogDir          string
+	Logger          slog.Logger
+	SSHServer       *agentssh.Server
+	Filesystem      afero.Fs
+	GetScriptLogger func(logSourceID uuid.UUID) ScriptLogger
 }
 
 // New creates a runner for the provided scripts.
@@ -275,20 +281,20 @@ func (r *Runner) run(ctx context.Context, script codersdk.WorkspaceAgentScript)
 	cmd.Env = append(cmd.Env, "CODER_SCRIPT_DATA_DIR="+scriptDataDir)
 	cmd.Env = append(cmd.Env, "CODER_SCRIPT_BIN_DIR="+r.ScriptBinDir())
 
-	send, flushAndClose := agentsdk.LogsSender(script.LogSourceID, r.PatchLogs, logger)
+	scriptLogger := r.GetScriptLogger(script.LogSourceID)
 	// If ctx is canceled here (or in a writer below), we may be
 	// discarding logs, but that's okay because we're shutting down
 	// anyway. We could consider creating a new context here if we
 	// want better control over flush during shutdown.
 	defer func() {
-		if err := flushAndClose(ctx); err != nil {
+		if err := scriptLogger.Flush(ctx); err != nil {
 			logger.Warn(ctx, "flush startup logs failed", slog.Error(err))
 		}
 	}()
 
-	infoW := agentsdk.LogsWriter(ctx, send, script.LogSourceID, codersdk.LogLevelInfo)
+	infoW := agentsdk.LogsWriter(ctx, scriptLogger.Send, script.LogSourceID, codersdk.LogLevelInfo)
 	defer infoW.Close()
-	errW := agentsdk.LogsWriter(ctx, send, script.LogSourceID, codersdk.LogLevelError)
+	errW := agentsdk.LogsWriter(ctx, scriptLogger.Send, script.LogSourceID, codersdk.LogLevelError)
 	defer errW.Close()
 	cmd.Stdout = io.MultiWriter(fileWriter, infoW)
 	cmd.Stderr = io.MultiWriter(fileWriter, errW)
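
Because Options now takes a GetScriptLogger factory instead of a PatchLogs HTTP callback, anything that satisfies ScriptLogger can receive script output. A minimal sketch under that assumption; slogScriptLogger and newRunnerWithLogger are illustrative names, not part of the patch, and the imports (context, cdr.dev/slog, github.com/google/uuid, agentsdk, agentscripts) are the ones this package already uses:

// slogScriptLogger satisfies agentscripts.ScriptLogger by writing script
// output to a slog.Logger rather than sending it anywhere over the network.
type slogScriptLogger struct {
	log slog.Logger
}

func (s slogScriptLogger) Send(ctx context.Context, logs ...agentsdk.Log) error {
	for _, l := range logs {
		s.log.Info(ctx, "script output", slog.F("level", l.Level), slog.F("output", l.Output))
	}
	return nil
}

func (slogScriptLogger) Flush(context.Context) error { return nil }

// newRunnerWithLogger fills in GetScriptLogger on an otherwise-prepared
// Options value and builds the runner.
func newRunnerWithLogger(opts agentscripts.Options, logger slog.Logger) *agentscripts.Runner {
	opts.GetScriptLogger = func(uuid.UUID) agentscripts.ScriptLogger {
		return slogScriptLogger{log: logger}
	}
	return agentscripts.New(opts)
}
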
diff --git a/agent/agentscripts/agentscripts_test.go b/agent/agentscripts/agentscripts_test.go
index d7fce25fda1fa..b9c8ae9f04c19 100644
--- a/agent/agentscripts/agentscripts_test.go
+++ b/agent/agentscripts/agentscripts_test.go
@@ -28,13 +28,10 @@ func TestMain(m *testing.M) {
 
 func TestExecuteBasic(t *testing.T) {
 	t.Parallel()
-	logs := make(chan agentsdk.PatchLogs, 1)
-	runner := setup(t, func(ctx context.Context, req agentsdk.PatchLogs) error {
-		select {
-		case <-ctx.Done():
-		case logs <- req:
-		}
-		return nil
+	ctx := testutil.Context(t, testutil.WaitShort)
+	fLogger := newFakeScriptLogger()
+	runner := setup(t, func(uuid2 uuid.UUID) agentscripts.ScriptLogger {
+		return fLogger
 	})
 	defer runner.Close()
 	err := runner.Init([]codersdk.WorkspaceAgentScript{{
@@ -45,19 +42,15 @@ func TestExecuteBasic(t *testing.T) {
 	require.NoError(t, runner.Execute(context.Background(), func(script codersdk.WorkspaceAgentScript) bool {
 		return true
 	}))
-	log := <-logs
-	require.Equal(t, "hello", log.Logs[0].Output)
+	log := testutil.RequireRecvCtx(ctx, t, fLogger.logs)
+	require.Equal(t, "hello", log.Output)
 }
 
 func TestEnv(t *testing.T) {
 	t.Parallel()
-	logs := make(chan agentsdk.PatchLogs, 2)
-	runner := setup(t, func(ctx context.Context, req agentsdk.PatchLogs) error {
-		select {
-		case <-ctx.Done():
-		case logs <- req:
-		}
-		return nil
+	fLogger := newFakeScriptLogger()
+	runner := setup(t, func(uuid2 uuid.UUID) agentscripts.ScriptLogger {
+		return fLogger
 	})
 	defer runner.Close()
 	id := uuid.New()
@@ -88,11 +81,9 @@ func TestEnv(t *testing.T) {
 		select {
 		case <-ctx.Done():
 			require.Fail(t, "timed out waiting for logs")
-		case l := <-logs:
-			for _, l := range l.Logs {
-				t.Logf("log: %s", l.Output)
-			}
-			log = append(log, l.Logs...)
+		case l := <-fLogger.logs:
+			t.Logf("log: %s", l.Output)
+			log = append(log, l)
 		}
 		if len(log) >= 2 {
 			break
@@ -124,12 +115,12 @@ func TestCronClose(t *testing.T) {
 	require.NoError(t, runner.Close(), "close runner")
 }
 
-func setup(t *testing.T, patchLogs func(ctx context.Context, req agentsdk.PatchLogs) error) *agentscripts.Runner {
+func setup(t *testing.T, getScriptLogger func(logSourceID uuid.UUID) agentscripts.ScriptLogger) *agentscripts.Runner {
 	t.Helper()
-	if patchLogs == nil {
+	if getScriptLogger == nil {
 		// noop
-		patchLogs = func(ctx context.Context, req agentsdk.PatchLogs) error {
-			return nil
+		getScriptLogger = func(uuid uuid.UUID) agentscripts.ScriptLogger {
+			return noopScriptLogger{}
 		}
 	}
 	fs := afero.NewMemMapFs()
@@ -140,11 +131,45 @@ func setup(t *testing.T, patchLogs func(ctx context.Context, req agentsdk.PatchL
 		_ = s.Close()
 	})
 	return agentscripts.New(agentscripts.Options{
-		LogDir:      t.TempDir(),
-		DataDirBase: t.TempDir(),
-		Logger:      logger,
-		SSHServer:   s,
-		Filesystem:  fs,
-		PatchLogs:   patchLogs,
+		LogDir:          t.TempDir(),
+		DataDirBase:     t.TempDir(),
+		Logger:          logger,
+		SSHServer:       s,
+		Filesystem:      fs,
+		GetScriptLogger: getScriptLogger,
 	})
 }
+
+type noopScriptLogger struct{}
+
+func (noopScriptLogger) Send(context.Context, ...agentsdk.Log) error {
+	return nil
+}
+
+func (noopScriptLogger) Flush(context.Context) error {
+	return nil
+}
+
+type fakeScriptLogger struct {
+	logs chan agentsdk.Log
+}
+
+func (f *fakeScriptLogger) Send(ctx context.Context, logs ...agentsdk.Log) error {
+	for _, log := range logs {
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case f.logs <- log:
+			// OK!
+		}
+	}
+	return nil
+}
+
+func (*fakeScriptLogger) Flush(context.Context) error {
+	return nil
+}
+
+func newFakeScriptLogger() *fakeScriptLogger {
+	return &fakeScriptLogger{make(chan agentsdk.Log, 100)}
+}
diff --git a/agent/agenttest/client.go b/agent/agenttest/client.go
index 004b3521e150d..b4bbd4feb7a32 100644
--- a/agent/agenttest/client.go
+++ b/agent/agenttest/client.go
@@ -46,7 +46,7 @@ func NewClient(t testing.TB,
 	derpMapUpdates := make(chan *tailcfg.DERPMap)
 	drpcService := &tailnet.DRPCService{
 		CoordPtr:               &coordPtr,
-		Logger:                 logger,
+		Logger:                 logger.Named("tailnetsvc"),
 		DerpMapUpdateFrequency: time.Microsecond,
 		DerpMapFn:              func() *tailcfg.DERPMap { return <-derpMapUpdates },
 	}
@@ -85,7 +85,6 @@ type Client struct {
 	server             *drpcserver.Server
 	fakeAgentAPI       *FakeAgentAPI
 	LastWorkspaceAgent func()
-	PatchWorkspaceLogs func() error
 
 	mu              sync.Mutex // Protects following.
 	lifecycleStates []codersdk.WorkspaceAgentLifecycle
@@ -165,17 +164,6 @@ func (c *Client) GetStartupLogs() []agentsdk.Log {
 	return c.logs
 }
 
-func (c *Client) PatchLogs(ctx context.Context, logs agentsdk.PatchLogs) error {
-	c.mu.Lock()
-	defer c.mu.Unlock()
-	if c.PatchWorkspaceLogs != nil {
-		return c.PatchWorkspaceLogs()
-	}
-	c.logs = append(c.logs, logs.Logs...)
-	c.logger.Debug(ctx, "patch startup logs", slog.F("req", logs))
-	return nil
-}
-
 func (c *Client) SetServiceBannerFunc(f func() (codersdk.ServiceBannerConfig, error)) {
 	c.fakeAgentAPI.SetServiceBannerFunc(f)
 }
@@ -192,6 +180,10 @@ func (c *Client) PushDERPMapUpdate(update *tailcfg.DERPMap) error {
 	return nil
 }
 
+func (c *Client) SetLogsChannel(ch chan<- *agentproto.BatchCreateLogsRequest) {
+	c.fakeAgentAPI.SetLogsChannel(ch)
+}
+
 type FakeAgentAPI struct {
 	sync.Mutex
 	t      testing.TB
@@ -201,6 +193,7 @@ type FakeAgentAPI struct {
 	startupCh    chan *agentproto.Startup
 	statsCh      chan *agentproto.Stats
 	appHealthCh  chan *agentproto.BatchUpdateAppHealthRequest
+	logsCh       chan<- *agentproto.BatchCreateLogsRequest
 
 	getServiceBannerFunc func() (codersdk.ServiceBannerConfig, error)
 }
@@ -263,9 +256,26 @@ func (*FakeAgentAPI) BatchUpdateMetadata(context.Context, *agentproto.BatchUpdat
 	panic("implement me")
 }
 
-func (*FakeAgentAPI) BatchCreateLogs(context.Context, *agentproto.BatchCreateLogsRequest) (*agentproto.BatchCreateLogsResponse, error) {
-	// TODO implement me
-	panic("implement me")
+func (f *FakeAgentAPI) SetLogsChannel(ch chan<- *agentproto.BatchCreateLogsRequest) {
+	f.Lock()
+	defer f.Unlock()
+	f.logsCh = ch
+}
+
+func (f *FakeAgentAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCreateLogsRequest) (*agentproto.BatchCreateLogsResponse, error) {
+	f.logger.Info(ctx, "batch create logs called", slog.F("req", req))
+	f.Lock()
+	ch := f.logsCh
+	f.Unlock()
+	if ch != nil {
+		select {
+		case <-ctx.Done():
+			return nil, ctx.Err()
+		case ch <- req:
+			// ok
+		}
+	}
+	return &agentproto.BatchCreateLogsResponse{}, nil
 }
 
 func NewFakeAgentAPI(t testing.TB, logger slog.Logger, manifest *agentproto.Manifest, statsCh chan *agentproto.Stats) *FakeAgentAPI {
diff --git a/codersdk/agentsdk/agentsdk.go b/codersdk/agentsdk/agentsdk.go
index d980847389644..6d225dbfae29c 100644
--- a/codersdk/agentsdk/agentsdk.go
+++ b/codersdk/agentsdk/agentsdk.go
@@ -517,6 +517,8 @@ type PatchLogs struct {
 
 // PatchLogs writes log messages to the agent startup script.
 // Log messages are limited to 1MB in total.
+//
+// Deprecated: use DRPCAgentClient.BatchCreateLogs instead.
 func (c *Client) PatchLogs(ctx context.Context, req PatchLogs) error {
 	res, err := c.SDK.Request(ctx, http.MethodPatch, "/api/v2/workspaceagents/me/logs", req)
 	if err != nil {
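
For callers outside the agent that still use the deprecated Client.PatchLogs, the deprecation comment above points at the v2 RPC. A rough migration sketch follows; it assumes an already-established DRPC connection and the BatchCreateLogsRequest field names from agent/proto (LogSourceId as the source UUID's bytes, Logs as the batch, as suggested by the GetLogs accessor used in the tests) — verify those names against the generated package before relying on this.

// sendLogsV2 sends a batch of logs over the agent v2 API instead of the
// deprecated HTTP PATCH endpoint.
func sendLogsV2(ctx context.Context, conn drpc.Conn, sourceID uuid.UUID, logs ...*proto.Log) error {
	client := proto.NewDRPCAgentClient(conn)
	_, err := client.BatchCreateLogs(ctx, &proto.BatchCreateLogsRequest{
		LogSourceId: sourceID[:], // assumed field name for the source UUID bytes
		Logs:        logs,
	})
	return err
}
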