From 743f5f30410eb94db92e8ce6078c4f5afdb7c4ff Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Wed, 7 Feb 2024 16:47:55 +0400 Subject: [PATCH] feat: ensure that log batches don't exceed 1MiB in logSender --- agent/logs.go | 30 ++++++++++-- agent/logs_internal_test.go | 97 +++++++++++++++++++++++++++++++++++++ 2 files changed, 122 insertions(+), 5 deletions(-) diff --git a/agent/logs.go b/agent/logs.go index 75ff343a67212..590a46cd0d431 100644 --- a/agent/logs.go +++ b/agent/logs.go @@ -13,7 +13,11 @@ import ( "github.com/coder/coder/v2/codersdk/agentsdk" ) -const flushInterval = time.Second +const ( + flushInterval = time.Second + logOutputMaxBytes = 1 << 20 // 1MiB + overheadPerLog = 21 // found by testing +) type logQueue struct { logs []*proto.Log @@ -131,14 +135,30 @@ func (l *logSender) sendLoop(ctx context.Context, dest logDest) error { return nil } src, q := l.getPendingWorkLocked() + logger := l.logger.With(slog.F("log_source_id", src)) q.flushRequested = false // clear flag since we're now flushing req := &proto.BatchCreateLogsRequest{ LogSourceId: src[:], - Logs: q.logs[:], + } + o := 0 + n := 0 + for n < len(q.logs) { + log := q.logs[n] + if len(log.Output) > logOutputMaxBytes { + logger.Warn(ctx, "dropping log line that exceeds our limit") + n++ + continue + } + o += len(log.Output) + overheadPerLog + if o > logOutputMaxBytes { + break + } + req.Logs = append(req.Logs, log) + n++ } l.L.Unlock() - l.logger.Debug(ctx, "sending logs to agent API", slog.F("log_source_id", src), slog.F("num_logs", len(req.Logs))) + logger.Debug(ctx, "sending logs to agent API", slog.F("num_logs", len(req.Logs))) resp, err := dest.BatchCreateLogs(ctx, req) l.L.Lock() if err != nil { @@ -157,10 +177,10 @@ func (l *logSender) sendLoop(ctx context.Context, dest logDest) error { // Since elsewhere we only append to the logs, here we can remove them // since we successfully sent them. First we nil the pointers though, // so that they can be gc'd. - for i := 0; i < len(req.Logs); i++ { + for i := 0; i < n; i++ { q.logs[i] = nil } - q.logs = q.logs[len(req.Logs):] + q.logs = q.logs[n:] if len(q.logs) == 0 { // no empty queues delete(l.queues, src) diff --git a/agent/logs_internal_test.go b/agent/logs_internal_test.go index 22bd6b632ded9..146777d5882e6 100644 --- a/agent/logs_internal_test.go +++ b/agent/logs_internal_test.go @@ -9,6 +9,7 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/require" + protobuf "google.golang.org/protobuf/proto" "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" @@ -165,6 +166,102 @@ func TestLogSender_LogLimitExceeded(t *testing.T) { require.Len(t, uut.queues, 0) } +func TestLogSender_SkipHugeLog(t *testing.T) { + t.Parallel() + testCtx := testutil.Context(t, testutil.WaitShort) + ctx, cancel := context.WithCancel(testCtx) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + fDest := newFakeLogDest() + uut := newLogSender(logger) + + t0 := dbtime.Now() + ls1 := uuid.UUID{0x11} + hugeLog := make([]byte, logOutputMaxBytes+1) + for i := range hugeLog { + hugeLog[i] = 'q' + } + err := uut.enqueue(ls1, + agentsdk.Log{ + CreatedAt: t0, + Output: string(hugeLog), + Level: codersdk.LogLevelInfo, + }, + agentsdk.Log{ + CreatedAt: t0, + Output: "test log 1, src 1", + Level: codersdk.LogLevelInfo, + }) + require.NoError(t, err) + + loopErr := make(chan error, 1) + go func() { + err := uut.sendLoop(ctx, fDest) + loopErr <- err + }() + + req := testutil.RequireRecvCtx(ctx, t, fDest.reqs) + require.NotNil(t, req) + require.Len(t, req.Logs, 1, "it should skip the huge log") + require.Equal(t, "test log 1, src 1", req.Logs[0].GetOutput()) + require.Equal(t, proto.Log_INFO, req.Logs[0].GetLevel()) + testutil.RequireSendCtx(ctx, t, fDest.resps, &proto.BatchCreateLogsResponse{}) + + cancel() + err = testutil.RequireRecvCtx(testCtx, t, loopErr) + require.NoError(t, err) +} + +func TestLogSender_Batch(t *testing.T) { + t.Parallel() + testCtx := testutil.Context(t, testutil.WaitShort) + ctx, cancel := context.WithCancel(testCtx) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + fDest := newFakeLogDest() + uut := newLogSender(logger) + + t0 := dbtime.Now() + ls1 := uuid.UUID{0x11} + var logs []agentsdk.Log + for i := 0; i < 60000; i++ { + logs = append(logs, agentsdk.Log{ + CreatedAt: t0, + Output: "r", + Level: codersdk.LogLevelInfo, + }) + } + err := uut.enqueue(ls1, logs...) + require.NoError(t, err) + + loopErr := make(chan error, 1) + go func() { + err := uut.sendLoop(ctx, fDest) + loopErr <- err + }() + + // with 60k logs, we should split into two updates to avoid going over 1MiB, since each log + // is about 21 bytes. + gotLogs := 0 + req := testutil.RequireRecvCtx(ctx, t, fDest.reqs) + require.NotNil(t, req) + gotLogs += len(req.Logs) + wire, err := protobuf.Marshal(req) + require.NoError(t, err) + require.Less(t, len(wire), logOutputMaxBytes, "wire should not exceed 1MiB") + testutil.RequireSendCtx(ctx, t, fDest.resps, &proto.BatchCreateLogsResponse{}) + req = testutil.RequireRecvCtx(ctx, t, fDest.reqs) + require.NotNil(t, req) + gotLogs += len(req.Logs) + wire, err = protobuf.Marshal(req) + require.NoError(t, err) + require.Less(t, len(wire), logOutputMaxBytes, "wire should not exceed 1MiB") + require.Equal(t, 60000, gotLogs) + testutil.RequireSendCtx(ctx, t, fDest.resps, &proto.BatchCreateLogsResponse{}) + + cancel() + err = testutil.RequireRecvCtx(testCtx, t, loopErr) + require.NoError(t, err) +} + type fakeLogDest struct { reqs chan *proto.BatchCreateLogsRequest resps chan *proto.BatchCreateLogsResponse