diff --git a/pkg/buffer/buffer.go b/pkg/buffer/buffer.go
new file mode 100644
index 000000000..950690104
--- /dev/null
+++ b/pkg/buffer/buffer.go
@@ -0,0 +1,74 @@
+package buffer
+
+import (
+	"bufio"
+	"fmt"
+	"net/http"
+	"strings"
+)
+
+// ProcessResponseAsRingBufferToEnd reads the body of an HTTP response line by line,
+// storing only the last maxJobLogLines lines using a ring buffer (sliding window).
+// This efficiently retains the most recent lines, overwriting older ones as needed.
+//
+// Parameters:
+//	httpResp: The HTTP response whose body will be read.
+//	maxJobLogLines: The maximum number of log lines to retain; must be greater than zero.
+//
+// Returns:
+//	string: The concatenated log lines (up to maxJobLogLines), separated by newlines.
+//	int: The total number of lines read from the response.
+//	*http.Response: The original HTTP response.
+//	error: Any error encountered during reading.
+//
+// If the response contains more lines than maxJobLogLines, only the most recent lines are kept.
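+//
+// Example (illustrative):
+//
+//	content, total, _, err := ProcessResponseAsRingBufferToEnd(httpResp, 500)
+//	if err != nil {
+//		// handle the read error
+//	}
+//	// content holds at most the last 500 lines; total is the full number of lines read.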
+func ProcessResponseAsRingBufferToEnd(httpResp *http.Response, maxJobLogLines int) (string, int, *http.Response, error) {
+	lines := make([]string, maxJobLogLines)
+	validLines := make([]bool, maxJobLogLines)
+	totalLines := 0
+	writeIndex := 0
+
+	scanner := bufio.NewScanner(httpResp.Body)
+	scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
+
+	for scanner.Scan() {
+		line := scanner.Text()
+		totalLines++
+
+		lines[writeIndex] = line
+		validLines[writeIndex] = true
+		writeIndex = (writeIndex + 1) % maxJobLogLines
+	}
+
+	if err := scanner.Err(); err != nil {
+		return "", 0, httpResp, fmt.Errorf("failed to read log content: %w", err)
+	}
+
+	var result []string
+	linesInBuffer := totalLines
+	if linesInBuffer > maxJobLogLines {
+		linesInBuffer = maxJobLogLines
+	}
+
+	startIndex := 0
+	if totalLines > maxJobLogLines {
+		startIndex = writeIndex
+	}
+
+	for i := 0; i < linesInBuffer; i++ {
+		idx := (startIndex + i) % maxJobLogLines
+		if validLines[idx] {
+			result = append(result, lines[idx])
+		}
+	}
+
+	return strings.Join(result, "\n"), totalLines, httpResp, nil
+}
diff --git a/pkg/github/actions.go b/pkg/github/actions.go
index 38719f155..855a01b53 100644
--- a/pkg/github/actions.go
+++ b/pkg/github/actions.go
@@ -4,12 +4,13 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
-	"io"
 	"net/http"
 	"strconv"
 	"strings"
 
+	buffer "github.com/github/github-mcp-server/pkg/buffer"
 	ghErrors "github.com/github/github-mcp-server/pkg/errors"
+	"github.com/github/github-mcp-server/pkg/profiler"
 	"github.com/github/github-mcp-server/pkg/translations"
 	"github.com/google/go-github/v74/github"
 	"github.com/mark3labs/mcp-go/mcp"
@@ -19,6 +20,7 @@ import (
 const (
 	DescriptionRepositoryOwner = "Repository owner"
 	DescriptionRepositoryName  = "Repository name"
+	maxJobLogLines             = 50000
 )
 
 // ListWorkflows creates a tool to list workflows in a repository
@@ -721,7 +723,7 @@ func getJobLogData(ctx context.Context, client *github.Client, owner, repo strin
 
 	if returnContent {
 		// Download and return the actual log content
-		content, originalLength, httpResp, err := downloadLogContent(url.String(), tailLines) //nolint:bodyclose // Response body is closed in downloadLogContent, but we need to return httpResp
+		content, originalLength, httpResp, err := downloadLogContent(ctx, url.String(), tailLines) //nolint:bodyclose // Response body is closed in downloadLogContent, but we need to return httpResp
 		if err != nil {
 			// To keep the return value consistent wrap the response as a GitHub Response
 			ghRes := &github.Response{
@@ -742,9 +744,13 @@ func getJobLogData(ctx context.Context, client *github.Client, owner, repo strin
 	return result, resp, nil
 }
 
-// downloadLogContent downloads the actual log content from a GitHub logs URL
-func downloadLogContent(logURL string, tailLines int) (string, int, *http.Response, error) {
-	httpResp, err := http.Get(logURL) //nolint:gosec // URLs are provided by GitHub API and are safe
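+// downloadLogContent downloads the log content from a GitHub logs URL, streaming it
+// through a ring buffer so only the most recent tailLines lines (capped at maxJobLogLines) are kept in memory.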
+func downloadLogContent(ctx context.Context, logURL string, tailLines int) (string, int, *http.Response, error) {
+	prof := profiler.New(nil, profiler.IsProfilingEnabled())
+	finish := prof.Start(ctx, "log_buffer_processing")
+
+	httpResp, err := http.Get(logURL) //nolint:gosec // URLs are provided by GitHub API and are safe
 	if err != nil {
 		return "", 0, httpResp, fmt.Errorf("failed to download logs: %w", err)
 	}
@@ -754,36 +760,29 @@ func downloadLogContent(logURL string, tailLines int) (string, int, *http.Respon
 		return "", 0, httpResp, fmt.Errorf("failed to download logs: HTTP %d", httpResp.StatusCode)
 	}
 
-	content, err := io.ReadAll(httpResp.Body)
-	if err != nil {
-		return "", 0, httpResp, fmt.Errorf("failed to read log content: %w", err)
+	if tailLines <= 0 {
+		tailLines = 1000
 	}
 
-	// Clean up and format the log content for better readability
-	logContent := strings.TrimSpace(string(content))
+	bufferSize := tailLines
+	if bufferSize > maxJobLogLines {
+		bufferSize = maxJobLogLines
+	}
 
-	trimmedContent, lineCount := trimContent(logContent, tailLines)
-	return trimmedContent, lineCount, httpResp, nil
-}
+	processedInput, totalLines, httpResp, err := buffer.ProcessResponseAsRingBufferToEnd(httpResp, bufferSize)
+	if err != nil {
+		return "", 0, httpResp, fmt.Errorf("failed to process log content: %w", err)
+	}
 
-// trimContent trims the content to a maximum length and returns the trimmed content and an original length
-func trimContent(content string, tailLines int) (string, int) {
-	// Truncate to tail_lines if specified
-	lineCount := 0
-	if tailLines > 0 {
-
-		// Count backwards to find the nth newline from the end and a total number of lines
-		for i := len(content) - 1; i >= 0 && lineCount < tailLines; i-- {
-			if content[i] == '\n' {
-				lineCount++
-				// If we have reached the tailLines, trim the content
-				if lineCount == tailLines {
-					content = content[i+1:]
-				}
-			}
-		}
+	lines := strings.Split(processedInput, "\n")
+	if len(lines) > tailLines {
+		lines = lines[len(lines)-tailLines:]
 	}
-	return content, lineCount
+	finalResult := strings.Join(lines, "\n")
+
+	_ = finish(len(lines), int64(len(finalResult)))
+
+	return finalResult, totalLines, httpResp, nil
 }
 
 // RerunWorkflowRun creates a tool to re-run an entire workflow run
diff --git a/pkg/github/actions_test.go b/pkg/github/actions_test.go
index 3d7521125..33549aad9 100644
--- a/pkg/github/actions_test.go
+++ b/pkg/github/actions_test.go
@@ -3,10 +3,17 @@ package github
 import (
 	"context"
 	"encoding/json"
+	"io"
 	"net/http"
 	"net/http/httptest"
+	"os"
+	"runtime"
+	"runtime/debug"
+	"strings"
 	"testing"
 
+	buffer "github.com/github/github-mcp-server/pkg/buffer"
+	"github.com/github/github-mcp-server/pkg/profiler"
 	"github.com/github/github-mcp-server/pkg/translations"
 	"github.com/google/go-github/v74/github"
 	"github.com/migueleliasweb/go-github-mock/src/mock"
@@ -1162,8 +1169,119 @@ func Test_GetJobLogs_WithContentReturnAndTailLines(t *testing.T) {
 	require.NoError(t, err)
 
 	assert.Equal(t, float64(123), response["job_id"])
-	assert.Equal(t, float64(1), response["original_length"])
+	assert.Equal(t, float64(3), response["original_length"])
 	assert.Equal(t, expectedLogContent, response["logs_content"])
 	assert.Equal(t, "Job logs content retrieved successfully", response["message"])
 	assert.NotContains(t, response, "logs_url") // Should not have URL when returning content
 }
+
+func Test_GetJobLogs_WithContentReturnAndLargeTailLines(t *testing.T) {
+	logContent := "Line 1\nLine 2\nLine 3"
+	expectedLogContent := "Line 1\nLine 2\nLine 3"
+
+	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		w.WriteHeader(http.StatusOK)
+		_, _ = w.Write([]byte(logContent))
+	}))
+	defer testServer.Close()
+
+	mockedClient := mock.NewMockedHTTPClient(
+		mock.WithRequestMatchHandler(
+			mock.GetReposActionsJobsLogsByOwnerByRepoByJobId,
+			http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+				w.Header().Set("Location", testServer.URL)
+				w.WriteHeader(http.StatusFound)
+			}),
+		),
+	)
+
+	client := github.NewClient(mockedClient)
+	_, handler := GetJobLogs(stubGetClientFn(client), translations.NullTranslationHelper)
+
+	request := createMCPRequest(map[string]any{
+		"owner":          "owner",
+		"repo":           "repo",
+		"job_id":         float64(123),
+		"return_content": true,
+		"tail_lines":     float64(100),
+	})
+
+	result, err := handler(context.Background(), request)
+	require.NoError(t, err)
+	require.False(t, result.IsError)
+
+	textContent := getTextResult(t, result)
+	var response map[string]any
+	err = json.Unmarshal([]byte(textContent.Text), &response)
+	require.NoError(t, err)
+
+	assert.Equal(t, float64(123), response["job_id"])
+	assert.Equal(t, float64(3), response["original_length"])
+	assert.Equal(t, expectedLogContent, response["logs_content"])
+	assert.Equal(t, "Job logs content retrieved successfully", response["message"])
+	assert.NotContains(t, response, "logs_url")
+}
+
+func Test_MemoryUsage_SlidingWindow_vs_NoWindow(t *testing.T) {
+	if testing.Short() {
+		t.Skip("Skipping memory profiling test in short mode")
+	}
+
+	const logLines = 100000
+	const bufferSize = 1000
+	largeLogContent := strings.Repeat("log line with some content\n", logLines)
+
+	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		w.WriteHeader(http.StatusOK)
+		_, _ = w.Write([]byte(largeLogContent))
+	}))
+	defer testServer.Close()
+
+	os.Setenv("GITHUB_MCP_PROFILING_ENABLED", "true")
+	defer os.Unsetenv("GITHUB_MCP_PROFILING_ENABLED")
+
+	// Initialize the global profiler
+	profiler.InitFromEnv(nil)
+
+	ctx := context.Background()
+
+	debug.SetGCPercent(-1)
+	profile1, err1 := profiler.ProfileFuncWithMetrics(ctx, "sliding_window", func() (int, int64, error) {
+		resp1, err := http.Get(testServer.URL)
+		if err != nil {
+			return 0, 0, err
+		}
+		defer resp1.Body.Close() //nolint:bodyclose // body is closed by this defer
+		content, totalLines, _, err := buffer.ProcessResponseAsRingBufferToEnd(resp1, bufferSize) //nolint:bodyclose
+		return totalLines, int64(len(content)), err
+	})
+	require.NoError(t, err1)
+
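+	// Force a collection between the two runs so the second measurement starts from a comparable heap (automatic GC stays disabled until SetGCPercent(100) below).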
+	runtime.GC()
+	profile2, err2 := profiler.ProfileFuncWithMetrics(ctx, "no_window", func() (int, int64, error) {
+		resp2, err := http.Get(testServer.URL)
+		if err != nil {
+			return 0, 0, err
+		}
+		defer resp2.Body.Close() //nolint:bodyclose // body is closed by this defer
+		content, err := io.ReadAll(resp2.Body)
+		if err != nil {
+			return 0, 0, err
+		}
+		lines := strings.Split(string(content), "\n")
+		if len(lines) > bufferSize {
+			lines = lines[len(lines)-bufferSize:]
+		}
+		result := strings.Join(lines, "\n")
+		return len(strings.Split(string(content), "\n")), int64(len(result)), nil
+	})
+	require.NoError(t, err2)
+	debug.SetGCPercent(100)
+
+	assert.Greater(t, profile2.MemoryDelta, profile1.MemoryDelta,
+		"Sliding window should use less memory than reading all into memory")
+
+	t.Logf("Sliding window: %s", profile1.String())
+	t.Logf("No window: %s", profile2.String())
+}
diff --git a/pkg/profiler/profiler.go b/pkg/profiler/profiler.go
new file mode 100644
index 000000000..1cfb7ffae
--- /dev/null
+++ b/pkg/profiler/profiler.go
@@ -0,0 +1,221 @@
+package profiler
+
+import (
+	"context"
+	"fmt"
+	"os"
+	"runtime"
+	"strconv"
+	"time"
+
+	"log/slog"
+	"math"
+)
+
+// Profile represents performance metrics for an operation
+type Profile struct {
+	Operation    string        `json:"operation"`
+	Duration     time.Duration `json:"duration_ns"`
+	MemoryBefore uint64        `json:"memory_before_bytes"`
+	MemoryAfter  uint64        `json:"memory_after_bytes"`
+	MemoryDelta  int64         `json:"memory_delta_bytes"`
+	LinesCount   int           `json:"lines_count,omitempty"`
+	BytesCount   int64         `json:"bytes_count,omitempty"`
+	Timestamp    time.Time     `json:"timestamp"`
+}
+
+// String returns a human-readable representation of the profile
+func (p *Profile) String() string {
+	return fmt.Sprintf("[%s] %s: duration=%v, memory_delta=%+dB, lines=%d, bytes=%d",
+		p.Timestamp.Format("15:04:05.000"),
+		p.Operation,
+		p.Duration,
+		p.MemoryDelta,
+		p.LinesCount,
+		p.BytesCount,
+	)
+}
+
+func safeMemoryDelta(after, before uint64) int64 {
+	if after > math.MaxInt64 || before > math.MaxInt64 {
+		if after >= before {
+			diff := after - before
+			if diff > math.MaxInt64 {
+				return math.MaxInt64
+			}
+			return int64(diff)
+		}
+		diff := before - after
+		if diff > math.MaxInt64 {
+			return -math.MaxInt64
+		}
+		return -int64(diff)
+	}
+
+	return int64(after) - int64(before)
+}
+
+// Profiler provides minimal performance profiling capabilities
+type Profiler struct {
+	logger  *slog.Logger
+	enabled bool
+}
+
+// New creates a new Profiler instance
+func New(logger *slog.Logger, enabled bool) *Profiler {
+	return &Profiler{
+		logger:  logger,
+		enabled: enabled,
+	}
+}
+
+// ProfileFunc profiles a function execution
+func (p *Profiler) ProfileFunc(ctx context.Context, operation string, fn func() error) (*Profile, error) {
+	if !p.enabled {
+		return nil, fn()
+	}
+
+	profile := &Profile{
+		Operation: operation,
+		Timestamp: time.Now(),
+	}
+
+	var memBefore runtime.MemStats
+	runtime.ReadMemStats(&memBefore)
+	profile.MemoryBefore = memBefore.Alloc
+
+	start := time.Now()
+	err := fn()
+	profile.Duration = time.Since(start)
+
+	var memAfter runtime.MemStats
+	runtime.ReadMemStats(&memAfter)
+	profile.MemoryAfter = memAfter.Alloc
+	profile.MemoryDelta = safeMemoryDelta(memAfter.Alloc, memBefore.Alloc)
+
+	if p.logger != nil {
+		p.logger.InfoContext(ctx, "Performance profile", "profile", profile.String())
+	}
+
+	return profile, err
+}
+
+// ProfileFuncWithMetrics profiles a function execution and captures additional metrics
+func (p *Profiler) ProfileFuncWithMetrics(ctx context.Context, operation string, fn func() (int, int64, error)) (*Profile, error) {
+	if !p.enabled {
+		_, _, err := fn()
+		return nil, err
+	}
+
+	profile := &Profile{
+		Operation: operation,
+		Timestamp: time.Now(),
+	}
+
+	var memBefore runtime.MemStats
+	runtime.ReadMemStats(&memBefore)
+	profile.MemoryBefore = memBefore.Alloc
+
+	start := time.Now()
+	lines, bytes, err := fn()
+	profile.Duration = time.Since(start)
+	profile.LinesCount = lines
+	profile.BytesCount = bytes
+
+	var memAfter runtime.MemStats
+	runtime.ReadMemStats(&memAfter)
+	profile.MemoryAfter = memAfter.Alloc
+	profile.MemoryDelta = safeMemoryDelta(memAfter.Alloc, memBefore.Alloc)
+
+	if p.logger != nil {
+		p.logger.InfoContext(ctx, "Performance profile", "profile", profile.String())
+	}
+
+	return profile, err
+}
+
+// Start begins timing an operation and returns a function to complete the profiling
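+//
+// Typical use (illustrative):
+//
+//	finish := prof.Start(ctx, "operation_name")
+//	// ... do the measured work ...
+//	profile := finish(linesProcessed, bytesProcessed)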
+func (p *Profiler) Start(ctx context.Context, operation string) func(lines int, bytes int64) *Profile {
+	if !p.enabled {
+		return func(int, int64) *Profile { return nil }
+	}
+
+	profile := &Profile{
+		Operation: operation,
+		Timestamp: time.Now(),
+	}
+
+	var memBefore runtime.MemStats
+	runtime.ReadMemStats(&memBefore)
+	profile.MemoryBefore = memBefore.Alloc
+
+	start := time.Now()
+
+	return func(lines int, bytes int64) *Profile {
+		profile.Duration = time.Since(start)
+		profile.LinesCount = lines
+		profile.BytesCount = bytes
+
+		var memAfter runtime.MemStats
+		runtime.ReadMemStats(&memAfter)
+		profile.MemoryAfter = memAfter.Alloc
+		profile.MemoryDelta = safeMemoryDelta(memAfter.Alloc, memBefore.Alloc)
+
+		if p.logger != nil {
+			p.logger.InfoContext(ctx, "Performance profile", "profile", profile.String())
+		}
+
+		return profile
+	}
+}
+
+var globalProfiler *Profiler
+
+// IsProfilingEnabled checks if profiling is enabled via environment variables
+func IsProfilingEnabled() bool {
+	if enabled, err := strconv.ParseBool(os.Getenv("GITHUB_MCP_PROFILING_ENABLED")); err == nil {
+		return enabled
+	}
+	return false
+}
+
+// Init initializes the global profiler
+func Init(logger *slog.Logger, enabled bool) {
+	globalProfiler = New(logger, enabled)
+}
+
+// InitFromEnv initializes the global profiler using environment variables
+func InitFromEnv(logger *slog.Logger) {
+	globalProfiler = New(logger, IsProfilingEnabled())
+}
+
+// ProfileFunc profiles a function using the global profiler
+func ProfileFunc(ctx context.Context, operation string, fn func() error) (*Profile, error) {
+	if globalProfiler == nil {
+		return nil, fn()
+	}
+	return globalProfiler.ProfileFunc(ctx, operation, fn)
+}
+
+// ProfileFuncWithMetrics profiles a function with metrics using the global profiler
+func ProfileFuncWithMetrics(ctx context.Context, operation string, fn func() (int, int64, error)) (*Profile, error) {
+	if globalProfiler == nil {
+		_, _, err := fn()
+		return nil, err
+	}
+	return globalProfiler.ProfileFuncWithMetrics(ctx, operation, fn)
+}
+
+// Start begins timing using the global profiler
+func Start(ctx context.Context, operation string) func(int, int64) *Profile {
+	if globalProfiler == nil {
+		return func(int, int64) *Profile { return nil }
+	}
+	return globalProfiler.Start(ctx, operation)
+}