diff --git a/README.md b/README.md index b40974e20..f4d38f3fe 100644 --- a/README.md +++ b/README.md @@ -526,6 +526,18 @@ The following sets of tools are available (all are on by default): - `repo`: Repository name (string, required) - `title`: Issue title (string, required) +- **find_closing_pull_requests** - Find closing pull requests + - `after`: Cursor for forward pagination (use with first/limit) (string, optional) + - `before`: Cursor for backward pagination (use with last) (string, optional) + - `includeClosedPrs`: Include closed/merged pull requests in results (default: false) (boolean, optional) + - `issue_numbers`: Array of issue numbers within the specified repository (number[], required) + - `last`: Number of results from end for backward pagination (max: 250) (number, optional) + - `limit`: Maximum number of closing PRs to return per issue (default: 100, max: 250) (number, optional) + - `orderByState`: Order results by pull request state (default: false) (boolean, optional) + - `owner`: The owner of the repository (string, required) + - `repo`: The name of the repository (string, required) + - `userLinkedOnly`: Return only manually linked pull requests (default: false) (boolean, optional) + - **get_issue** - Get issue details - `issue_number`: The number of the issue (number, required) - `owner`: The owner of the repository (string, required) diff --git a/e2e/e2e_test.go b/e2e/e2e_test.go index 64c5729ba..de6ee5fa9 100644 --- a/e2e/e2e_test.go +++ b/e2e/e2e_test.go @@ -21,6 +21,7 @@ import ( gogithub "github.com/google/go-github/v73/github" mcpClient "github.com/mark3labs/mcp-go/client" "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -1624,3 +1625,626 @@ func TestPullRequestReviewDeletion(t *testing.T) { require.NoError(t, err, "expected to unmarshal text content successfully") require.Len(t, noReviews, 0, "expected to find no reviews") } + +func TestFindClosingPullRequests(t *testing.T) { + t.Parallel() + + mcpClient := setupMCPClient(t, withToolsets([]string{"issues"})) + + ctx := context.Background() + + // Test with well-known GitHub repositories and issues + testCases := []struct { + name string + owner string + repo string + issueNumbers []int + limit int + expectError bool + expectedResults int + expectSomeWithClosingPR bool + }{ + { + name: "Single issue test - should handle gracefully even if no closing PRs", + owner: "octocat", + repo: "Hello-World", + issueNumbers: []int{1}, + limit: 5, + expectError: false, + expectedResults: 1, + }, + { + name: "Multiple issues test", + owner: "github", + repo: "docs", + issueNumbers: []int{1, 2}, + limit: 3, + expectError: false, + expectedResults: 2, + }, + { + name: "Empty issue_numbers array should return error", + owner: "octocat", + repo: "Hello-World", + issueNumbers: []int{}, + expectError: true, + }, + { + name: "Limit too high should return error", + owner: "octocat", + repo: "Hello-World", + issueNumbers: []int{1}, + limit: 251, + expectError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Prepare the request + findClosingPRsRequest := mcp.CallToolRequest{} + findClosingPRsRequest.Params.Name = "find_closing_pull_requests" + + // Build arguments map + args := map[string]any{ + "owner": tc.owner, + "repo": tc.repo, + "issue_numbers": tc.issueNumbers, + } + if tc.limit > 0 { + args["limit"] = tc.limit + } + findClosingPRsRequest.Params.Arguments = args + + t.Logf("Calling find_closing_pull_requests with owner: %s, repo: %s, issue_numbers: %v", tc.owner, tc.repo, tc.issueNumbers) + resp, err := mcpClient.CallTool(ctx, findClosingPRsRequest) + + if tc.expectError { + // We expect either an error or an error response + if err != nil { + t.Logf("Expected error occurred: %v", err) + return + } + require.True(t, resp.IsError, "Expected error response") + t.Logf("Expected error in response: %+v", resp) + return + } + + require.NoError(t, err, "expected to call 'find_closing_pull_requests' tool successfully") + require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp)) + + // Verify we got content + require.NotEmpty(t, resp.Content, "Expected response content") + + textContent, ok := resp.Content[0].(mcp.TextContent) + require.True(t, ok, "expected content to be of type TextContent") + + t.Logf("Response: %s", textContent.Text) + + // Parse the JSON response + var response struct { + Results []struct { + Owner string `json:"owner"` + Repo string `json:"repo"` + IssueNumber int `json:"issue_number"` + ClosingPullRequests []struct { + Number int `json:"number"` + Title string `json:"title"` + Body string `json:"body"` + State string `json:"state"` + URL string `json:"url"` + Merged bool `json:"merged"` + } `json:"closingPullRequests"` + TotalCount int `json:"totalCount"` + Error string `json:"error,omitempty"` + } `json:"results"` + } + + err = json.Unmarshal([]byte(textContent.Text), &response) + require.NoError(t, err, "expected to unmarshal response successfully") + + // Verify the response structure + require.Len(t, response.Results, tc.expectedResults, "Expected specific number of results") + + // Log and verify each result + for i, result := range response.Results { + t.Logf("Result %d:", i+1) + t.Logf(" Owner: %s, Repo: %s, Number: %d", result.Owner, result.Repo, result.IssueNumber) + t.Logf(" Total closing PRs: %d", result.TotalCount) + if result.Error != "" { + t.Logf(" Error: %s", result.Error) + } + + // Verify basic structure + assert.NotEmpty(t, result.Owner, "Owner should not be empty") + assert.NotEmpty(t, result.Repo, "Repo should not be empty") + assert.Greater(t, result.IssueNumber, 0, "Issue number should be positive") + + // Log details of any closing PRs found + for j, pr := range result.ClosingPullRequests { + t.Logf(" Closing PR %d:", j+1) + t.Logf(" Number: %d", pr.Number) + t.Logf(" Title: %s", pr.Title) + t.Logf(" State: %s, Merged: %t", pr.State, pr.Merged) + t.Logf(" URL: %s", pr.URL) + + // Verify PR structure + assert.Greater(t, pr.Number, 0, "PR number should be positive") + assert.NotEmpty(t, pr.Title, "PR title should not be empty") + assert.NotEmpty(t, pr.State, "PR state should not be empty") + assert.NotEmpty(t, pr.URL, "PR URL should not be empty") + } + + // The number of closing PRs in this page should be less than or equal to the total count + assert.LessOrEqual(t, len(result.ClosingPullRequests), result.TotalCount, "ClosingPullRequests length should not exceed TotalCount") + } + }) + } +} + +// TestFindClosingPullRequestsGraphQLParameters tests the enhanced GraphQL parameters +func TestFindClosingPullRequestsGraphQLParameters(t *testing.T) { + t.Parallel() + + mcpClient := setupMCPClient(t, withToolsets([]string{"issues"})) + ctx := context.Background() + + t.Run("Boolean Parameters", func(t *testing.T) { + // Test cases for boolean parameters + booleanTestCases := []struct { + name string + owner string + repo string + issueNumbers []int + includeClosedPrs *bool + orderByState *bool + userLinkedOnly *bool + expectError bool + description string + }{ + { + name: "includeClosedPrs=true - should include closed/merged PRs", + owner: "microsoft", + repo: "vscode", + issueNumbers: []int{1}, + includeClosedPrs: boolPtr(true), + description: "Test includeClosedPrs parameter with popular repository", + }, + { + name: "includeClosedPrs=false - should exclude closed/merged PRs", + owner: "microsoft", + repo: "vscode", + issueNumbers: []int{2}, + includeClosedPrs: boolPtr(false), + description: "Test includeClosedPrs=false parameter", + }, + { + name: "orderByState=true - should order results by PR state", + owner: "microsoft", + repo: "vscode", + issueNumbers: []int{1, 2}, // Use low numbers for older issues + orderByState: boolPtr(true), + description: "Test orderByState parameter with larger repository", + }, + { + name: "userLinkedOnly=true - should return only manually linked PRs", + owner: "facebook", + repo: "react", + issueNumbers: []int{1}, // First issue in React repo + userLinkedOnly: boolPtr(true), + description: "Test userLinkedOnly parameter", + }, + { + name: "Combined boolean parameters", + owner: "facebook", + repo: "react", + issueNumbers: []int{1, 2}, + includeClosedPrs: boolPtr(true), + orderByState: boolPtr(true), + userLinkedOnly: boolPtr(false), + description: "Test multiple boolean parameters together", + }, + } + + for _, tc := range booleanTestCases { + t.Run(tc.name, func(t *testing.T) { + // Build arguments map + args := map[string]any{ + "owner": tc.owner, + "repo": tc.repo, + "issue_numbers": tc.issueNumbers, + "limit": 5, // Keep limit reasonable + } + + // Add boolean parameters if specified + if tc.includeClosedPrs != nil { + args["includeClosedPrs"] = *tc.includeClosedPrs + } + if tc.orderByState != nil { + args["orderByState"] = *tc.orderByState + } + if tc.userLinkedOnly != nil { + args["userLinkedOnly"] = *tc.userLinkedOnly + } + + // Create request + request := mcp.CallToolRequest{} + request.Params.Name = "find_closing_pull_requests" + request.Params.Arguments = args + + t.Logf("Testing %s: %s", tc.name, tc.description) + t.Logf("Arguments: %+v", args) + + // Call the tool + resp, err := mcpClient.CallTool(ctx, request) + + if tc.expectError { + if err != nil { + t.Logf("Expected error occurred: %v", err) + return + } + require.True(t, resp.IsError, "Expected error response") + return + } + + require.NoError(t, err, "expected successful tool call") + require.False(t, resp.IsError, fmt.Sprintf("expected non-error response: %+v", resp)) + require.NotEmpty(t, resp.Content, "Expected response content") + + // Parse response + textContent, ok := resp.Content[0].(mcp.TextContent) + require.True(t, ok, "expected TextContent") + + var response struct { + Results []struct { + Owner string `json:"owner"` + Repo string `json:"repo"` + IssueNumber int `json:"issue_number"` + TotalCount int `json:"total_count"` + ClosingPullRequests []struct { + Number int `json:"number"` + Title string `json:"title"` + State string `json:"state"` + Merged bool `json:"merged"` + URL string `json:"url"` + } `json:"closing_pull_requests"` + Error string `json:"error,omitempty"` + } `json:"results"` + } + + err = json.Unmarshal([]byte(textContent.Text), &response) + require.NoError(t, err, "expected successful JSON parsing") + + // Verify response structure + require.NotEmpty(t, response.Results, "Expected at least one result") + + for i, result := range response.Results { + t.Logf("Result %d: Owner=%s, Repo=%s, Issue=%d, TotalCount=%d", + i+1, result.Owner, result.Repo, result.IssueNumber, result.TotalCount) + + // Log PRs found + for j, pr := range result.ClosingPullRequests { + t.Logf(" PR %d: #%d - %s (State: %s, Merged: %t)", + j+1, pr.Number, pr.Title, pr.State, pr.Merged) + } + + // Basic validations + assert.Equal(t, tc.owner, result.Owner, "Owner should match request") + assert.Equal(t, tc.repo, result.Repo, "Repo should match request") + assert.Contains(t, tc.issueNumbers, result.IssueNumber, "Issue number should be in request") + assert.LessOrEqual(t, len(result.ClosingPullRequests), result.TotalCount, "ClosingPullRequests length should not exceed TotalCount") + + // Parameter-specific validations + if tc.includeClosedPrs != nil && *tc.includeClosedPrs == false { + // When includeClosedPrs=false, should not include closed/merged PRs + for _, pr := range result.ClosingPullRequests { + assert.NotEqual(t, "CLOSED", pr.State, "Should not include closed PRs when includeClosedPrs=false") + if pr.State == "MERGED" { + assert.False(t, pr.Merged, "Should not include merged PRs when includeClosedPrs=false") + } + } + } + + if tc.orderByState != nil && *tc.orderByState == true && len(result.ClosingPullRequests) > 1 { + // When orderByState=true, verify some ordering (exact ordering depends on GitHub's implementation) + t.Logf("OrderByState=true: PRs should be ordered by state") + // Note: We can't assert exact ordering without knowing GitHub's algorithm + // but we can verify the parameter was processed (no errors) + } + } + }) + } + }) + + t.Run("Pagination Parameters", func(t *testing.T) { + // Test cases for pagination parameters + paginationTestCases := []struct { + name string + owner string + repo string + issueNumbers []int + first *int + last *int + after *string + before *string + expectError bool + description string + }{ + { + name: "Forward pagination with first parameter", + owner: "microsoft", + repo: "vscode", + issueNumbers: []int{1, 2}, + first: intPtr(1), + description: "Test forward pagination using first parameter", + }, + { + name: "Backward pagination with last parameter", + owner: "microsoft", + repo: "vscode", + issueNumbers: []int{1, 2}, + last: intPtr(1), + description: "Test backward pagination using last parameter", + }, + { + name: "Large repository pagination test", + owner: "microsoft", + repo: "vscode", + issueNumbers: []int{1, 2, 3}, + first: intPtr(1), // Small page size + description: "Test pagination with larger repository", + }, + } + + for _, tc := range paginationTestCases { + t.Run(tc.name, func(t *testing.T) { + // Build arguments map + args := map[string]any{ + "owner": tc.owner, + "repo": tc.repo, + "issue_numbers": tc.issueNumbers, + } + + // Add pagination parameters if specified + if tc.first != nil { + args["limit"] = *tc.first // first maps to limit in our tool + } + if tc.last != nil { + args["last"] = *tc.last + } + if tc.after != nil { + args["after"] = *tc.after + } + if tc.before != nil { + args["before"] = *tc.before + } + + // Create request + request := mcp.CallToolRequest{} + request.Params.Name = "find_closing_pull_requests" + request.Params.Arguments = args + + t.Logf("Testing %s: %s", tc.name, tc.description) + t.Logf("Arguments: %+v", args) + + // Call the tool + resp, err := mcpClient.CallTool(ctx, request) + + if tc.expectError { + if err != nil { + t.Logf("Expected error occurred: %v", err) + return + } + require.True(t, resp.IsError, "Expected error response") + return + } + + require.NoError(t, err, "expected successful tool call") + require.False(t, resp.IsError, fmt.Sprintf("expected non-error response: %+v", resp)) + require.NotEmpty(t, resp.Content, "Expected response content") + + // Parse response + textContent, ok := resp.Content[0].(mcp.TextContent) + require.True(t, ok, "expected TextContent") + + var response struct { + Results []struct { + Owner string `json:"owner"` + Repo string `json:"repo"` + IssueNumber int `json:"issue_number"` + TotalCount int `json:"total_count"` + ClosingPullRequests []struct { + Number int `json:"number"` + Title string `json:"title"` + State string `json:"state"` + } `json:"closing_pull_requests"` + PageInfo *struct { + HasNextPage bool `json:"hasNextPage"` + HasPreviousPage bool `json:"hasPreviousPage"` + StartCursor string `json:"startCursor"` + EndCursor string `json:"endCursor"` + } `json:"pageInfo,omitempty"` + } `json:"results"` + } + + err = json.Unmarshal([]byte(textContent.Text), &response) + require.NoError(t, err, "expected successful JSON parsing") + + // Verify response structure + require.NotEmpty(t, response.Results, "Expected at least one result") + + for i, result := range response.Results { + t.Logf("Result %d: Owner=%s, Repo=%s, Issue=%d, TotalCount=%d", + i+1, result.Owner, result.Repo, result.IssueNumber, result.TotalCount) + + // Verify pagination parameter effects + if tc.first != nil { + assert.LessOrEqual(t, len(result.ClosingPullRequests), *tc.first, + "Result count should not exceed 'first' parameter") + } + if tc.last != nil { + assert.LessOrEqual(t, len(result.ClosingPullRequests), *tc.last, + "Result count should not exceed 'last' parameter") + } + + // Log pagination info if present + if result.PageInfo != nil { + t.Logf(" PageInfo: HasNext=%t, HasPrev=%t", + result.PageInfo.HasNextPage, result.PageInfo.HasPreviousPage) + if result.PageInfo.StartCursor != "" { + t.Logf(" StartCursor: %s", result.PageInfo.StartCursor) + } + if result.PageInfo.EndCursor != "" { + t.Logf(" EndCursor: %s", result.PageInfo.EndCursor) + } + } + } + }) + } + }) + + t.Run("Error Validation", func(t *testing.T) { + // Test cases for parameter validation and error handling + errorTestCases := []struct { + name string + owner string + repo string + issueNumbers []int + args map[string]any + expectError bool + description string + }{ + { + name: "Conflicting limit and last parameters", + owner: "microsoft", + repo: "vscode", + issueNumbers: []int{1}, + args: map[string]any{ + "limit": 5, + "last": 3, + }, + expectError: true, + description: "Should reject conflicting limit and last parameters", + }, + { + name: "Conflicting after and before cursors", + owner: "microsoft", + repo: "vscode", + issueNumbers: []int{1}, + args: map[string]any{ + "after": "cursor1", + "before": "cursor2", + }, + expectError: true, + description: "Should reject conflicting after and before cursors", + }, + { + name: "Before cursor without last parameter", + owner: "microsoft", + repo: "vscode", + issueNumbers: []int{1}, + args: map[string]any{ + "before": "cursor1", + }, + expectError: true, + description: "Should reject before cursor without last parameter", + }, + { + name: "After cursor with last parameter", + owner: "microsoft", + repo: "vscode", + issueNumbers: []int{1}, + args: map[string]any{ + "after": "cursor1", + "last": 3, + }, + expectError: true, + description: "Should reject after cursor with last parameter", + }, + { + name: "Invalid limit range - too high", + owner: "microsoft", + repo: "vscode", + issueNumbers: []int{1}, + args: map[string]any{ + "limit": 251, + }, + expectError: true, + description: "Should reject limit greater than 250", + }, + { + name: "Invalid last range - too high", + owner: "microsoft", + repo: "vscode", + issueNumbers: []int{1}, + args: map[string]any{ + "last": 251, + }, + expectError: true, + description: "Should reject last greater than 250", + }, + } + + for _, tc := range errorTestCases { + t.Run(tc.name, func(t *testing.T) { + // Build base arguments + args := map[string]any{ + "owner": tc.owner, + "repo": tc.repo, + "issue_numbers": tc.issueNumbers, + } + + // Add test-specific arguments + for key, value := range tc.args { + args[key] = value + } + + // Create request + request := mcp.CallToolRequest{} + request.Params.Name = "find_closing_pull_requests" + request.Params.Arguments = args + + t.Logf("Testing %s: %s", tc.name, tc.description) + t.Logf("Arguments: %+v", args) + + // Call the tool + resp, err := mcpClient.CallTool(ctx, request) + + if tc.expectError { + // We expect either an error or an error response + if err != nil { + t.Logf("Expected error occurred: %v", err) + return + } + require.True(t, resp.IsError, "Expected error response") + t.Logf("Expected error in response: %+v", resp) + + // Verify error content contains helpful information + if len(resp.Content) > 0 { + if textContent, ok := resp.Content[0].(mcp.TextContent); ok { + assert.NotEmpty(t, textContent.Text, "Error message should not be empty") + t.Logf("Error message: %s", textContent.Text) + } + } + return + } + + require.NoError(t, err, "expected successful tool call") + require.False(t, resp.IsError, "expected non-error response") + }) + } + }) +} + +// Helper functions for pointer creation +func boolPtr(b bool) *bool { + return &b +} + +func intPtr(i int) *int { + return &i +} + +func stringPtr(s string) *string { + return &s +} diff --git a/pkg/github/discussions_test.go b/pkg/github/discussions_test.go index 9458dfce0..d1dcf064a 100644 --- a/pkg/github/discussions_test.go +++ b/pkg/github/discussions_test.go @@ -484,7 +484,7 @@ func Test_GetDiscussion(t *testing.T) { assert.ElementsMatch(t, toolDef.InputSchema.Required, []string{"owner", "repo", "discussionNumber"}) // Use exact string query that matches implementation output - qGetDiscussion := "query($discussionNumber:Int!$owner:String!$repo:String!){repository(owner: $owner, name: $repo){discussion(number: $discussionNumber){number,title,body,createdAt,url,category{name}}}}" + qGetDiscussion := "query($discussionNumber:Int!$owner:String!$repo:String!){repository(owner: $owner, name: $repo){discussion(number: $discussionNumber){number,title,body,createdAt,url,category{name}}}}" vars := map[string]interface{}{ "owner": "owner", diff --git a/pkg/github/find_closing_prs_integration_test.go b/pkg/github/find_closing_prs_integration_test.go new file mode 100644 index 000000000..eb5c6b015 --- /dev/null +++ b/pkg/github/find_closing_prs_integration_test.go @@ -0,0 +1,167 @@ +//go:build e2e + +package github + +import ( + "context" + "encoding/json" + "os" + "testing" + + "github.com/github/github-mcp-server/pkg/translations" + "github.com/google/go-github/v73/github" + "github.com/shurcooL/githubv4" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestFindClosingPullRequestsIntegration tests the FindClosingPullRequests tool with real GitHub API calls +func TestFindClosingPullRequestsIntegration(t *testing.T) { + // This test requires a GitHub token + token := os.Getenv("GITHUB_MCP_SERVER_E2E_TOKEN") + if token == "" { + t.Skip("GITHUB_MCP_SERVER_E2E_TOKEN environment variable is not set") + } + + // Create GitHub clients + httpClient := github.NewClient(nil).WithAuthToken(token).Client() + gqlClient := githubv4.NewClient(httpClient) + + getGQLClient := func(ctx context.Context) (*githubv4.Client, error) { + return gqlClient, nil + } + + // Create the tool + tool, handler := FindClosingPullRequests(getGQLClient, translations.NullTranslationHelper) + + // Test cases with known GitHub issues that were closed by PRs + testCases := []struct { + name string + owner string + repo string + issueNumbers []int + expectedResults int + expectSomeClosingPRs bool + expectSpecificIssue string + expectSpecificPRNumber int + }{ + { + name: "Single issue using issue_numbers - VS Code well-known closed issue", + owner: "microsoft", + repo: "vscode", + issueNumbers: []int{123456}, // This is a made-up issue for testing + expectedResults: 1, + expectSomeClosingPRs: false, // We expect this to not exist or have no closing PRs + }, + { + name: "Multiple issues using issue_numbers with mixed results", + owner: "microsoft", + repo: "vscode", + issueNumbers: []int{1, 999999}, + expectedResults: 2, + expectSomeClosingPRs: false, // These are likely non-existent or have no closing PRs + }, + { + name: "Issue from a popular repo using issue_numbers - React", + owner: "facebook", + repo: "react", + issueNumbers: []int{1}, // Very first issue in React repo + expectedResults: 1, + }, + } + + ctx := context.Background() + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create request arguments + args := map[string]interface{}{ + "limit": 5, + "owner": tc.owner, + "repo": tc.repo, + "issue_numbers": tc.issueNumbers, + } + + // Create mock request + request := mockCallToolRequest{ + arguments: args, + } + + // Call the handler + result, err := handler(ctx, request) + + if err != nil { + t.Logf("Error calling tool: %v", err) + // For integration tests, we might expect some errors for non-existent issues + // Let's check if it's a reasonable error + assert.Contains(t, err.Error(), "failed to") + return + } + + require.NotNil(t, result) + assert.False(t, result.IsError, "Expected successful result") + + // Parse the response + textContent, ok := result.Content[0].(map[string]interface{}) + if !ok { + // Try to get as text content + if len(result.Content) > 0 { + if textResult, ok := result.Content[0].(string); ok { + t.Logf("Response: %s", textResult) + + // Parse JSON response + var response struct { + Results []map[string]interface{} `json:"results"` + } + err := json.Unmarshal([]byte(textResult), &response) + require.NoError(t, err, "Failed to parse JSON response") + + // Verify structure + assert.Len(t, response.Results, tc.expectedResults, "Expected specific number of results") + + for i, result := range response.Results { + t.Logf("Issue %d:", i+1) + t.Logf(" Owner: %v, Repo: %v, Number: %v", result["owner"], result["repo"], result["issue_number"]) + t.Logf(" Total closing PRs: %v", result["total_count"]) + + if errorMsg, hasError := result["error"]; hasError { + t.Logf(" Error: %v", errorMsg) + } + + // Verify basic structure + assert.NotEmpty(t, result["owner"], "Owner should not be empty") + assert.NotEmpty(t, result["repo"], "Repo should not be empty") + assert.NotNil(t, result["issue_number"], "Issue number should not be nil") + + // Check closing PRs if any + if closingPRs, ok := result["closing_pull_requests"].([]interface{}); ok { + t.Logf(" Found %d closing PRs", len(closingPRs)) + for j, pr := range closingPRs { + if prMap, ok := pr.(map[string]interface{}); ok { + t.Logf(" PR %d: #%v - %v", j+1, prMap["number"], prMap["title"]) + t.Logf(" State: %v, Merged: %v", prMap["state"], prMap["merged"]) + t.Logf(" URL: %v", prMap["url"]) + } + } + } + } + + return + } + } + t.Fatalf("Unexpected content type: %T", result.Content[0]) + } + + t.Logf("Response content: %+v", textContent) + }) + } +} + +// mockCallToolRequest implements the mcp.CallToolRequest interface for testing +type mockCallToolRequest struct { + arguments map[string]interface{} +} + +func (m mockCallToolRequest) GetArguments() map[string]interface{} { + return m.arguments +} diff --git a/pkg/github/issues.go b/pkg/github/issues.go index f718c37cb..69e4dea27 100644 --- a/pkg/github/issues.go +++ b/pkg/github/issues.go @@ -1321,3 +1321,336 @@ func AssignCodingAgentPrompt(t translations.TranslationHelperFunc) (tool mcp.Pro }, nil } } + +// ClosingPRNode represents a pull request that closed an issue +type ClosingPRNode struct { + Number githubv4.Int + Title githubv4.String + Body githubv4.String + State githubv4.String + URL githubv4.String + Merged githubv4.Boolean +} + +// ClosingPRsFragment represents the closedByPullRequestsReferences field with pagination info +type ClosingPRsFragment struct { + TotalCount githubv4.Int + Nodes []ClosingPRNode + PageInfo struct { + HasNextPage githubv4.Boolean + HasPreviousPage githubv4.Boolean + StartCursor githubv4.String + EndCursor githubv4.String + } +} + +// ClosingPullRequest represents a pull request in the response format +type ClosingPullRequest struct { + Number int `json:"number"` + Title string `json:"title"` + Body string `json:"body"` + State string `json:"state"` + URL string `json:"url"` + Merged bool `json:"merged"` +} + +const ( + // DefaultClosingPRsLimit is the default number of closing PRs to return per issue + // Aligned with GitHub GraphQL API default of 100 items per page + DefaultClosingPRsLimit = 100 + MaxGraphQLPageSize = 250 // Maximum page size for GitHub GraphQL API +) + +// FindClosingPullRequests creates a tool to find pull requests that closed specific issues +func FindClosingPullRequests(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) { + return mcp.NewTool("find_closing_pull_requests", + mcp.WithDescription(t("TOOL_FIND_CLOSING_PULL_REQUESTS_DESCRIPTION", "Find pull requests that closed specific issues using closing references within a repository.")), + mcp.WithToolAnnotation(mcp.ToolAnnotation{ + Title: t("TOOL_FIND_CLOSING_PULL_REQUESTS_USER_TITLE", "Find closing pull requests"), + ReadOnlyHint: ToBoolPtr(true), + }), + mcp.WithString("owner", + mcp.Description("The owner of the repository"), + mcp.Required(), + ), + mcp.WithString("repo", + mcp.Description("The name of the repository"), + mcp.Required(), + ), + mcp.WithArray("issue_numbers", + mcp.Description("Array of issue numbers within the specified repository"), + mcp.Required(), + mcp.Items( + map[string]any{ + "type": "number", + }, + ), + ), + mcp.WithNumber("limit", + mcp.Description(fmt.Sprintf( + "Maximum number of closing PRs to return per issue (default: %d, max: %d)", + DefaultClosingPRsLimit, + MaxGraphQLPageSize, + )), + ), + mcp.WithBoolean("includeClosedPrs", + mcp.Description("Include closed/merged pull requests in results (default: false)"), + ), + mcp.WithBoolean("orderByState", + mcp.Description("Order results by pull request state (default: false)"), + ), + mcp.WithBoolean("userLinkedOnly", + mcp.Description("Return only manually linked pull requests (default: false)"), + ), + mcp.WithString("after", + mcp.Description("Cursor for forward pagination (use with first/limit)"), + ), + mcp.WithString("before", + mcp.Description("Cursor for backward pagination (use with last)"), + ), + mcp.WithNumber("last", + mcp.Description(fmt.Sprintf( + "Number of results from end for backward pagination (max: %d)", + MaxGraphQLPageSize, + )), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Parse pagination parameters + limit := DefaultClosingPRsLimit // default + limitExplicitlySet := false + if limitParam, exists := request.GetArguments()["limit"]; exists { + limitExplicitlySet = true + if limitFloat, ok := limitParam.(float64); ok { + limit = int(limitFloat) + if limit <= 0 || limit > MaxGraphQLPageSize { + return mcp.NewToolResultError(fmt.Sprintf("limit must be between 1 and %d inclusive (GitHub GraphQL API maximum)", MaxGraphQLPageSize)), nil + } + } else { + return mcp.NewToolResultError(fmt.Sprintf("limit must be a number between 1 and %d (GitHub GraphQL API maximum)", MaxGraphQLPageSize)), nil + } + } + + // Parse last parameter for backward pagination + last, err := OptionalIntParam(request, "last") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("last parameter error: %s", err.Error())), nil + } + if last != 0 && (last <= 0 || last > MaxGraphQLPageSize) { + return mcp.NewToolResultError(fmt.Sprintf("last must be between 1 and %d inclusive for backward pagination (GitHub GraphQL API maximum)", MaxGraphQLPageSize)), nil + } + + // Parse cursor parameters + after, err := OptionalParam[string](request, "after") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("after parameter error: %s", err.Error())), nil + } + before, err := OptionalParam[string](request, "before") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("before parameter error: %s", err.Error())), nil + } + + // Validate pagination parameter combinations + if last != 0 && limitExplicitlySet { + return mcp.NewToolResultError("cannot use both 'limit' and 'last' parameters together - use 'limit' for forward pagination or 'last' for backward pagination"), nil + } + if after != "" && before != "" { + return mcp.NewToolResultError("cannot use both 'after' and 'before' cursors together - use 'after' for forward pagination or 'before' for backward pagination"), nil + } + if before != "" && last == 0 { + return mcp.NewToolResultError("'before' cursor requires 'last' parameter for backward pagination"), nil + } + if after != "" && last != 0 { + return mcp.NewToolResultError("'after' cursor cannot be used with 'last' parameter - use 'after' with 'limit' for forward pagination"), nil + } + + // Parse filtering parameters + includeClosedPrs, err := OptionalParam[bool](request, "includeClosedPrs") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("includeClosedPrs parameter error: %s", err.Error())), nil + } + orderByState, err := OptionalParam[bool](request, "orderByState") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("orderByState parameter error: %s", err.Error())), nil + } + userLinkedOnly, err := OptionalParam[bool](request, "userLinkedOnly") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("userLinkedOnly parameter error: %s", err.Error())), nil + } + + // Get required parameters + owner, err := RequiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("owner parameter error: %s", err.Error())), nil + } + repo, err := RequiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("repo parameter error: %s", err.Error())), nil + } + issueNumbers, err := RequiredIntArrayParam(request, "issue_numbers") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("issue_numbers parameter error: %s", err.Error())), nil + } + + // Get GraphQL client + client, err := getGQLClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GraphQL client: %w", err) + } + + // Create pagination parameters struct + paginationParams := struct { + First int + Last int + After string + Before string + IncludeClosedPrs bool + OrderByState bool + UserLinkedOnly bool + }{ + First: limit, + Last: last, + After: after, + Before: before, + IncludeClosedPrs: includeClosedPrs, + OrderByState: orderByState, + UserLinkedOnly: userLinkedOnly, + } + + // Process each issue number + var results []map[string]interface{} + for _, issueNum := range issueNumbers { + result, err := queryClosingPRsForIssueEnhanced(ctx, client, owner, repo, issueNum, paginationParams) + if err != nil { + // Add error result for this issue + results = append(results, map[string]interface{}{ + "owner": owner, + "repo": repo, + "issue_number": issueNum, + "error": err.Error(), + "total_count": 0, + "closing_pull_requests": []ClosingPullRequest{}, + }) + continue + } + results = append(results, result) + } + + // Return results + response := map[string]interface{}{ + "results": results, + } + + responseJSON, err := json.Marshal(response) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return mcp.NewToolResultText(string(responseJSON)), nil + } +} + +// queryClosingPRsForIssueEnhanced queries closing PRs for a single issue with enhanced parameters +func queryClosingPRsForIssueEnhanced(ctx context.Context, client *githubv4.Client, owner, repo string, issueNumber int, params struct { + First int + Last int + After string + Before string + IncludeClosedPrs bool + OrderByState bool + UserLinkedOnly bool +}) (map[string]interface{}, error) { + // Define the GraphQL query for this specific issue + var query struct { + Repository struct { + Issue struct { + ClosedByPullRequestsReferences ClosingPRsFragment `graphql:"closedByPullRequestsReferences(first: $first, last: $last, after: $after, before: $before, includeClosedPrs: $includeClosedPrs, orderByState: $orderByState, userLinkedOnly: $userLinkedOnly)"` + } `graphql:"issue(number: $number)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + } + + // Validate issue number (basic bounds check) + if issueNumber < 0 { + return nil, fmt.Errorf("issue number %d is out of valid range", issueNumber) + } + + // Validate pagination parameters (basic bounds check) + if params.Last < 0 { + return nil, fmt.Errorf("last parameter %d is out of valid range", params.Last) + } + if params.First < 0 { + return nil, fmt.Errorf("first parameter %d is out of valid range", params.First) + } + + // Build variables map + variables := map[string]any{ + "owner": githubv4.String(owner), + "repo": githubv4.String(repo), + "number": githubv4.Int(issueNumber), // #nosec G115 - issueNumber validated to be positive + } + + if params.Last != 0 { + variables["last"] = githubv4.Int(params.Last) // #nosec G115 - params.Last validated to be positive + variables["first"] = (*githubv4.Int)(nil) + } else { + variables["first"] = githubv4.Int(params.First) // #nosec G115 - params.First validated to be positive + variables["last"] = (*githubv4.Int)(nil) + } + + if params.After != "" { + variables["after"] = githubv4.String(params.After) + } else { + variables["after"] = (*githubv4.String)(nil) + } + + if params.Before != "" { + variables["before"] = githubv4.String(params.Before) + } else { + variables["before"] = (*githubv4.String)(nil) + } + + // Add filtering parameters + variables["includeClosedPrs"] = githubv4.Boolean(params.IncludeClosedPrs) + variables["orderByState"] = githubv4.Boolean(params.OrderByState) + variables["userLinkedOnly"] = githubv4.Boolean(params.UserLinkedOnly) + + err := client.Query(ctx, &query, variables) + if err != nil { + return nil, fmt.Errorf("failed to query issue: %w", err) + } + + // Convert GraphQL response to JSON format + var closingPullRequests []ClosingPullRequest + for _, node := range query.Repository.Issue.ClosedByPullRequestsReferences.Nodes { + closingPullRequests = append(closingPullRequests, ClosingPullRequest{ + Number: int(node.Number), + Title: string(node.Title), + Body: string(node.Body), + State: string(node.State), + URL: string(node.URL), + Merged: bool(node.Merged), + }) + } + + // Build response with pagination info + result := map[string]interface{}{ + "owner": owner, + "repo": repo, + "issue_number": issueNumber, + "total_count": int(query.Repository.Issue.ClosedByPullRequestsReferences.TotalCount), + "closing_pull_requests": closingPullRequests, + } + + // Add pagination information if cursors are being used + if params.After != "" || params.Before != "" || params.Last != 0 { + pageInfo := query.Repository.Issue.ClosedByPullRequestsReferences.PageInfo + result["page_info"] = map[string]interface{}{ + "has_next_page": bool(pageInfo.HasNextPage), + "has_previous_page": bool(pageInfo.HasPreviousPage), + "start_cursor": string(pageInfo.StartCursor), + "end_cursor": string(pageInfo.EndCursor), + } + } + + return result, nil +} diff --git a/pkg/github/server.go b/pkg/github/server.go index 193336b75..a06083627 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -99,6 +99,50 @@ func RequiredInt(r mcp.CallToolRequest, p string) (int, error) { return int(v), nil } +// RequiredIntArrayParam is a helper function that can be used to fetch a required integer array parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request +// 2. Checks if the parameter is an array and each element can be converted to int +// 3. Checks if the array is not empty +func RequiredIntArrayParam(r mcp.CallToolRequest, p string) ([]int, error) { + // Check if the parameter is present in the request + if _, ok := r.GetArguments()[p]; !ok { + return nil, fmt.Errorf("missing required parameter: %s", p) + } + + switch v := r.GetArguments()[p].(type) { + case nil: + return nil, fmt.Errorf("missing required parameter: %s", p) + case []int: + if len(v) == 0 { + return nil, fmt.Errorf("parameter %s cannot be empty", p) + } + return v, nil + case []any: + if len(v) == 0 { + return nil, fmt.Errorf("parameter %s cannot be empty", p) + } + intSlice := make([]int, len(v)) + for i, elem := range v { + switch num := elem.(type) { + case float64: + intSlice[i] = int(num) + case int: + intSlice[i] = num + case int32: + intSlice[i] = int(num) + case int64: + intSlice[i] = int(num) + default: + return nil, fmt.Errorf("parameter %s contains non-numeric value, element %d is %T", p, i, elem) + } + } + return intSlice, nil + default: + return nil, fmt.Errorf("parameter %s is not an array, is %T", p, r.GetArguments()[p]) + } +} + // OptionalParam is a helper function that can be used to fetch a requested parameter from the request. // It does the following checks: // 1. Checks if the parameter is present in the request, if not, it returns its zero-value @@ -174,6 +218,43 @@ func OptionalStringArrayParam(r mcp.CallToolRequest, p string) ([]string, error) } } +// OptionalIntArrayParam is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request, if not, it returns its zero-value +// 2. If it is present, iterates the elements and checks each is a number that can be converted to int +func OptionalIntArrayParam(r mcp.CallToolRequest, p string) ([]int, error) { + // Check if the parameter is present in the request + if _, ok := r.GetArguments()[p]; !ok { + return []int{}, nil + } + + switch v := r.GetArguments()[p].(type) { + case nil: + return []int{}, nil + case []int: + return v, nil + case []any: + intSlice := make([]int, len(v)) + for i, v := range v { + switch num := v.(type) { + case float64: + intSlice[i] = int(num) + case int: + intSlice[i] = num + case int32: + intSlice[i] = int(num) + case int64: + intSlice[i] = int(num) + default: + return []int{}, fmt.Errorf("parameter %s array element at index %d is not of type number, is %T", p, i, v) + } + } + return intSlice, nil + default: + return []int{}, fmt.Errorf("parameter %s could not be coerced to []int, is %T", p, r.GetArguments()[p]) + } +} + // WithPagination adds REST API pagination parameters to a tool. // https://docs.github.com/en/rest/using-the-rest-api/using-pagination-in-the-rest-api func WithPagination() mcp.ToolOption { diff --git a/pkg/github/server_test.go b/pkg/github/server_test.go index 7f8f29c0d..f96f934bf 100644 --- a/pkg/github/server_test.go +++ b/pkg/github/server_test.go @@ -478,6 +478,196 @@ func TestOptionalStringArrayParam(t *testing.T) { } } +func TestOptionalIntArrayParam(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + expected []int + expectError bool + }{ + { + name: "parameter not in request", + params: map[string]any{}, + paramName: "numbers", + expected: []int{}, + expectError: false, + }, + { + name: "valid any array parameter with float64", + params: map[string]any{ + "numbers": []any{float64(1), float64(2), float64(3)}, + }, + paramName: "numbers", + expected: []int{1, 2, 3}, + expectError: false, + }, + { + name: "valid int array parameter", + params: map[string]any{ + "numbers": []int{1, 2, 3}, + }, + paramName: "numbers", + expected: []int{1, 2, 3}, + expectError: false, + }, + { + name: "mixed numeric types", + params: map[string]any{ + "numbers": []any{float64(1), int(2), int32(3), int64(4)}, + }, + paramName: "numbers", + expected: []int{1, 2, 3, 4}, + expectError: false, + }, + { + name: "invalid type in array", + params: map[string]any{ + "numbers": []any{float64(1), "not a number"}, + }, + paramName: "numbers", + expected: []int{}, + expectError: true, + }, + { + name: "nil value", + params: map[string]any{ + "numbers": nil, + }, + paramName: "numbers", + expected: []int{}, + expectError: false, + }, + { + name: "wrong parameter type", + params: map[string]any{ + "numbers": "not an array", + }, + paramName: "numbers", + expected: []int{}, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + request := createMCPRequest(tc.params) + result, err := OptionalIntArrayParam(request, tc.paramName) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} + +func TestRequiredIntArrayParam(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + expected []int + expectError bool + }{ + { + name: "parameter not in request", + params: map[string]any{}, + paramName: "numbers", + expected: nil, + expectError: true, + }, + { + name: "valid any array parameter with float64", + params: map[string]any{ + "numbers": []any{float64(1), float64(2), float64(3)}, + }, + paramName: "numbers", + expected: []int{1, 2, 3}, + expectError: false, + }, + { + name: "valid int array parameter", + params: map[string]any{ + "numbers": []int{1, 2, 3}, + }, + paramName: "numbers", + expected: []int{1, 2, 3}, + expectError: false, + }, + { + name: "mixed numeric types", + params: map[string]any{ + "numbers": []any{float64(1), int(2), int32(3), int64(4)}, + }, + paramName: "numbers", + expected: []int{1, 2, 3, 4}, + expectError: false, + }, + { + name: "invalid type in array", + params: map[string]any{ + "numbers": []any{float64(1), "not a number"}, + }, + paramName: "numbers", + expected: nil, + expectError: true, + }, + { + name: "nil value", + params: map[string]any{ + "numbers": nil, + }, + paramName: "numbers", + expected: nil, + expectError: true, + }, + { + name: "empty array", + params: map[string]any{ + "numbers": []any{}, + }, + paramName: "numbers", + expected: nil, + expectError: true, + }, + { + name: "empty int array", + params: map[string]any{ + "numbers": []int{}, + }, + paramName: "numbers", + expected: nil, + expectError: true, + }, + { + name: "wrong parameter type", + params: map[string]any{ + "numbers": "not an array", + }, + paramName: "numbers", + expected: nil, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + request := createMCPRequest(tc.params) + result, err := RequiredIntArrayParam(request, tc.paramName) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} + func TestOptionalPaginationParams(t *testing.T) { tests := []struct { name string diff --git a/pkg/github/tools.go b/pkg/github/tools.go index 7fb1d39c0..455c436cc 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -54,6 +54,7 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG toolsets.NewServerTool(ListIssues(getClient, t)), toolsets.NewServerTool(GetIssueComments(getClient, t)), toolsets.NewServerTool(ListSubIssues(getClient, t)), + toolsets.NewServerTool(FindClosingPullRequests(getGQLClient, t)), ). AddWriteTools( toolsets.NewServerTool(CreateIssue(getClient, t)),