feat: switch agent to use v2 API for sending logs #12068

Merged · 1 commit · Feb 23, 2024
27 changes: 25 additions & 2 deletions agent/agent.go
@@ -92,7 +92,6 @@ type Client interface {
ConnectRPC(ctx context.Context) (drpc.Conn, error)
PostLifecycle(ctx context.Context, state agentsdk.PostLifecycleRequest) error
PostMetadata(ctx context.Context, req agentsdk.PostMetadataRequest) error
PatchLogs(ctx context.Context, req agentsdk.PatchLogs) error
RewriteDERPMap(derpMap *tailcfg.DERPMap)
}

@@ -181,6 +180,7 @@ func New(options Options) Agent {
syscaller: options.Syscaller,
modifiedProcs: options.ModifiedProcesses,
processManagementTick: options.ProcessManagementTick,
logSender: agentsdk.NewLogSender(options.Logger),

prometheusRegistry: prometheusRegistry,
metrics: newAgentMetrics(prometheusRegistry),
@@ -245,6 +245,7 @@ type agent struct {
network *tailnet.Conn
addresses []netip.Prefix
statsReporter *statsReporter
logSender *agentsdk.LogSender

connCountReconnectingPTY atomic.Int64

@@ -283,7 +284,9 @@ func (a *agent) init() {
Logger: a.logger,
SSHServer: sshSrv,
Filesystem: a.filesystem,
PatchLogs: a.client.PatchLogs,
GetScriptLogger: func(logSourceID uuid.UUID) agentscripts.ScriptLogger {
return a.logSender.GetScriptLogger(logSourceID)
},
})
// Register runner metrics. If the prom registry is nil, the metrics
// will not report anywhere.
@@ -763,6 +766,20 @@ func (a *agent) run() (retErr error) {
},
)

// sending logs gets gracefulShutdownBehaviorRemain because we want to send logs generated by
// shutdown scripts.
connMan.start("send logs", gracefulShutdownBehaviorRemain,
func(ctx context.Context, conn drpc.Conn) error {
err := a.logSender.SendLoop(ctx, proto.NewDRPCAgentClient(conn))
Member:
Is it good practice to use multiple clients over the same conn, or should we define the client in the parent scope?

Contributor Author:
The client contains no state other than the conn itself. It just maps RPC methods to Invoke or NewStream calls on the conn.

So, I think it's totally reasonable to make a new one per routine. That allows me to keep the connMan generic with a drpc.Conn. There are actually 2 different proto APIs (tailnet and agent) with their own client types. (See the sketch after this hunk.)

if xerrors.Is(err, agentsdk.LogLimitExceededError) {
// we don't want this error to tear down the API connection and propagate to the
// other routines that use the API. The LogSender has already dropped a warning
// log, so just return nil here.
return nil
}
return err
})

// channels to sync goroutines below
// handle manifest
// |
@@ -1769,6 +1786,12 @@ lifecycleWaitLoop:
a.logger.Debug(context.Background(), "coordinator RPC disconnected")
}

// Wait for logs to be sent
err = a.logSender.WaitUntilEmpty(a.hardCtx)
if err != nil {
a.logger.Warn(context.Background(), "timed out waiting for all logs to be sent", slog.Error(err))
}

a.hardCancel()
if a.network != nil {
_ = a.network.Close()
82 changes: 78 additions & 4 deletions agent/agent_test.go
@@ -2062,6 +2062,80 @@ func TestAgent_DebugServer(t *testing.T) {
})
}

func TestAgent_ScriptLogging(t *testing.T) {
Member:
❤️

if runtime.GOOS == "windows" {
t.Skip("bash scripts only")
}
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)

derpMap, _ := tailnettest.RunDERPAndSTUN(t)
logsCh := make(chan *proto.BatchCreateLogsRequest, 100)
lsStart := uuid.UUID{0x11}
lsStop := uuid.UUID{0x22}
//nolint:dogsled
_, _, _, _, agnt := setupAgent(
t,
agentsdk.Manifest{
DERPMap: derpMap,
Scripts: []codersdk.WorkspaceAgentScript{
{
LogSourceID: lsStart,
RunOnStart: true,
Script: `#!/bin/sh
i=0
while [ $i -ne 5 ]
do
i=$(($i+1))
echo "start $i"
done
`,
},
{
LogSourceID: lsStop,
RunOnStop: true,
Script: `#!/bin/sh
i=0
while [ $i -ne 3000 ]
Member:
I wonder if 3000 is flake-risky, considering we're using WaitMedium?

Contributor Author:
Running on its own, the script part completes in ~100ms, including queueing all the logs, sending them, and asserting them in the test. If it flakes, something else is very wrong.

Member:
Sure 👍🏻. I was mostly just worried about Windows, it can be unfathomably slow 😄

do
i=$(($i+1))
echo "stop $i"
done
`, // send a lot of stop logs to make sure we don't truncate shutdown logs before closing the API conn
},
},
},
0,
func(cl *agenttest.Client, _ *agent.Options) {
cl.SetLogsChannel(logsCh)
},
)

n := 1
for n <= 5 {
logs := testutil.RequireRecvCtx(ctx, t, logsCh)
require.NotNil(t, logs)
for _, l := range logs.GetLogs() {
require.Equal(t, fmt.Sprintf("start %d", n), l.GetOutput())
n++
}
}

err := agnt.Close()
require.NoError(t, err)

n = 1
for n <= 3000 {
logs := testutil.RequireRecvCtx(ctx, t, logsCh)
require.NotNil(t, logs)
for _, l := range logs.GetLogs() {
require.Equal(t, fmt.Sprintf("stop %d", n), l.GetOutput())
n++
}
t.Logf("got %d stop logs", n-1)
}
}

// setupAgentSSHClient creates an agent, dials it, and sets up an ssh.Client for it
func setupAgentSSHClient(ctx context.Context, t *testing.T) *ssh.Client {
//nolint: dogsled
@@ -2137,7 +2211,7 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati
})
statsCh := make(chan *proto.Stats, 50)
fs := afero.NewMemMapFs()
c := agenttest.NewClient(t, logger.Named("agent"), metadata.AgentID, metadata, statsCh, coordinator)
c := agenttest.NewClient(t, logger.Named("agenttest"), metadata.AgentID, metadata, statsCh, coordinator)
t.Cleanup(c.Close)

options := agent.Options{
@@ -2152,9 +2226,9 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati
opt(c, &options)
}

closer := agent.New(options)
agnt := agent.New(options)
t.Cleanup(func() {
_ = closer.Close()
_ = agnt.Close()
})
conn, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
@@ -2191,7 +2265,7 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati
if !agentConn.AwaitReachable(ctx) {
t.Fatal("agent not reachable")
}
return agentConn, c, statsCh, fs, closer
return agentConn, c, statsCh, fs, agnt
}

var dialTestPayload = []byte("dean-was-here123")
26 changes: 16 additions & 10 deletions agent/agentscripts/agentscripts.go
@@ -13,6 +13,7 @@ import (
"sync/atomic"
"time"

"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"
"github.com/robfig/cron/v3"
"github.com/spf13/afero"
@@ -41,14 +42,19 @@
parser = cron.NewParser(cron.Second | cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.DowOptional)
)

type ScriptLogger interface {
Send(ctx context.Context, log ...agentsdk.Log) error
Flush(context.Context) error
}

// Options are a set of options for the runner.
type Options struct {
DataDirBase string
LogDir string
Logger slog.Logger
SSHServer *agentssh.Server
Filesystem afero.Fs
PatchLogs func(ctx context.Context, req agentsdk.PatchLogs) error
DataDirBase string
LogDir string
Logger slog.Logger
SSHServer *agentssh.Server
Filesystem afero.Fs
GetScriptLogger func(logSourceID uuid.UUID) ScriptLogger
}

// New creates a runner for the provided scripts.
@@ -275,20 +281,20 @@ func (r *Runner) run(ctx context.Context, script codersdk.WorkspaceAgentScript)
cmd.Env = append(cmd.Env, "CODER_SCRIPT_DATA_DIR="+scriptDataDir)
cmd.Env = append(cmd.Env, "CODER_SCRIPT_BIN_DIR="+r.ScriptBinDir())

send, flushAndClose := agentsdk.LogsSender(script.LogSourceID, r.PatchLogs, logger)
scriptLogger := r.GetScriptLogger(script.LogSourceID)
// If ctx is canceled here (or in a writer below), we may be
// discarding logs, but that's okay because we're shutting down
// anyway. We could consider creating a new context here if we
// want better control over flush during shutdown.
defer func() {
if err := flushAndClose(ctx); err != nil {
if err := scriptLogger.Flush(ctx); err != nil {
logger.Warn(ctx, "flush startup logs failed", slog.Error(err))
}
}()

infoW := agentsdk.LogsWriter(ctx, send, script.LogSourceID, codersdk.LogLevelInfo)
infoW := agentsdk.LogsWriter(ctx, scriptLogger.Send, script.LogSourceID, codersdk.LogLevelInfo)
defer infoW.Close()
errW := agentsdk.LogsWriter(ctx, send, script.LogSourceID, codersdk.LogLevelError)
errW := agentsdk.LogsWriter(ctx, scriptLogger.Send, script.LogSourceID, codersdk.LogLevelError)
defer errW.Close()
cmd.Stdout = io.MultiWriter(fileWriter, infoW)
cmd.Stderr = io.MultiWriter(fileWriter, errW)
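
As a usage note on the new hook — a minimal sketch, not taken from this PR; the slogScriptLogger type and its slog-based sink are assumptions for illustration — any type implementing Send and Flush satisfies ScriptLogger, so a caller can supply its own sink through Options.GetScriptLogger:

```go
// Minimal sketch, not from this PR: a ScriptLogger that writes script output
// to a slog.Logger instead of queueing it on the agent's LogSender.
package example

import (
	"context"

	"cdr.dev/slog"

	"github.com/coder/coder/v2/codersdk/agentsdk"
)

type slogScriptLogger struct {
	log slog.Logger
}

// Send logs each script output line locally. A production implementation
// (agentsdk.LogSender in this PR) instead queues logs for the v2 agent API.
func (s slogScriptLogger) Send(ctx context.Context, logs ...agentsdk.Log) error {
	for _, l := range logs {
		s.log.Info(ctx, "script output", slog.F("output", l.Output))
	}
	return nil
}

// Flush is a no-op because Send never buffers.
func (slogScriptLogger) Flush(context.Context) error { return nil }
```

Wiring it in would mirror what agent.go does above, e.g. GetScriptLogger: func(uuid.UUID) agentscripts.ScriptLogger { return slogScriptLogger{log: logger} }.
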
87 changes: 56 additions & 31 deletions agent/agentscripts/agentscripts_test.go
@@ -28,13 +28,10 @@ func TestMain(m *testing.M) {

func TestExecuteBasic(t *testing.T) {
t.Parallel()
logs := make(chan agentsdk.PatchLogs, 1)
runner := setup(t, func(ctx context.Context, req agentsdk.PatchLogs) error {
select {
case <-ctx.Done():
case logs <- req:
}
return nil
ctx := testutil.Context(t, testutil.WaitShort)
fLogger := newFakeScriptLogger()
runner := setup(t, func(uuid2 uuid.UUID) agentscripts.ScriptLogger {
return fLogger
})
defer runner.Close()
err := runner.Init([]codersdk.WorkspaceAgentScript{{
@@ -45,19 +42,15 @@ func TestExecuteBasic(t *testing.T) {
require.NoError(t, runner.Execute(context.Background(), func(script codersdk.WorkspaceAgentScript) bool {
return true
}))
log := <-logs
require.Equal(t, "hello", log.Logs[0].Output)
log := testutil.RequireRecvCtx(ctx, t, fLogger.logs)
require.Equal(t, "hello", log.Output)
}

func TestEnv(t *testing.T) {
t.Parallel()
logs := make(chan agentsdk.PatchLogs, 2)
runner := setup(t, func(ctx context.Context, req agentsdk.PatchLogs) error {
select {
case <-ctx.Done():
case logs <- req:
}
return nil
fLogger := newFakeScriptLogger()
runner := setup(t, func(uuid2 uuid.UUID) agentscripts.ScriptLogger {
return fLogger
})
defer runner.Close()
id := uuid.New()
@@ -88,11 +81,9 @@
select {
case <-ctx.Done():
require.Fail(t, "timed out waiting for logs")
case l := <-logs:
for _, l := range l.Logs {
t.Logf("log: %s", l.Output)
}
log = append(log, l.Logs...)
case l := <-fLogger.logs:
t.Logf("log: %s", l.Output)
log = append(log, l)
}
if len(log) >= 2 {
break
@@ -124,12 +115,12 @@ func TestCronClose(t *testing.T) {
require.NoError(t, runner.Close(), "close runner")
}

func setup(t *testing.T, patchLogs func(ctx context.Context, req agentsdk.PatchLogs) error) *agentscripts.Runner {
func setup(t *testing.T, getScriptLogger func(logSourceID uuid.UUID) agentscripts.ScriptLogger) *agentscripts.Runner {
t.Helper()
if patchLogs == nil {
if getScriptLogger == nil {
// noop
patchLogs = func(ctx context.Context, req agentsdk.PatchLogs) error {
return nil
getScriptLogger = func(uuid uuid.UUID) agentscripts.ScriptLogger {
return noopScriptLogger{}
}
}
fs := afero.NewMemMapFs()
@@ -140,11 +131,45 @@ func setup(t *testing.T, patchLogs func(ctx context.Context, req agentsdk.PatchL
_ = s.Close()
})
return agentscripts.New(agentscripts.Options{
LogDir: t.TempDir(),
DataDirBase: t.TempDir(),
Logger: logger,
SSHServer: s,
Filesystem: fs,
PatchLogs: patchLogs,
LogDir: t.TempDir(),
DataDirBase: t.TempDir(),
Logger: logger,
SSHServer: s,
Filesystem: fs,
GetScriptLogger: getScriptLogger,
})
}

type noopScriptLogger struct{}

func (noopScriptLogger) Send(context.Context, ...agentsdk.Log) error {
return nil
}

func (noopScriptLogger) Flush(context.Context) error {
return nil
}

type fakeScriptLogger struct {
logs chan agentsdk.Log
}

func (f *fakeScriptLogger) Send(ctx context.Context, logs ...agentsdk.Log) error {
for _, log := range logs {
select {
case <-ctx.Done():
return ctx.Err()
case f.logs <- log:
// OK!
}
}
return nil
}

func (*fakeScriptLogger) Flush(context.Context) error {
return nil
}

func newFakeScriptLogger() *fakeScriptLogger {
return &fakeScriptLogger{make(chan agentsdk.Log, 100)}
}