diff --git a/README.md b/README.md index 86e516ce..b78be380 100644 --- a/README.md +++ b/README.md @@ -287,6 +287,17 @@ export GITHUB_MCP_TOOL_ADD_ISSUE_COMMENT_DESCRIPTION="an alternative description - `draft`: Create as draft PR (boolean, optional) - `maintainer_can_modify`: Allow maintainer edits (boolean, optional) +- **update_pull_request** - Update an existing pull request in a GitHub repository + + - `owner`: Repository owner (string, required) + - `repo`: Repository name (string, required) + - `pullNumber`: Pull request number to update (number, required) + - `title`: New title (string, optional) + - `body`: New description (string, optional) + - `state`: New state ('open' or 'closed') (string, optional) + - `base`: New base branch name (string, optional) + - `maintainer_can_modify`: Allow maintainer edits (boolean, optional) + ### Repositories - **create_or_update_file** - Create or update a single file in a repository diff --git a/pkg/github/helper_test.go b/pkg/github/helper_test.go index 9dcffa42..40fc0b94 100644 --- a/pkg/github/helper_test.go +++ b/pkg/github/helper_test.go @@ -93,3 +93,115 @@ func getTextResult(t *testing.T, result *mcp.CallToolResult) mcp.TextContent { assert.Equal(t, "text", textContent.Type) return textContent } + +func TestOptionalParamOK(t *testing.T) { + tests := []struct { + name string + args map[string]interface{} + paramName string + expectedVal interface{} + expectedOk bool + expectError bool + errorMsg string + }{ + { + name: "present and correct type (string)", + args: map[string]interface{}{"myParam": "hello"}, + paramName: "myParam", + expectedVal: "hello", + expectedOk: true, + expectError: false, + }, + { + name: "present and correct type (bool)", + args: map[string]interface{}{"myParam": true}, + paramName: "myParam", + expectedVal: true, + expectedOk: true, + expectError: false, + }, + { + name: "present and correct type (number)", + args: map[string]interface{}{"myParam": float64(123)}, + paramName: "myParam", + expectedVal: float64(123), + expectedOk: true, + expectError: false, + }, + { + name: "present but wrong type (string expected, got bool)", + args: map[string]interface{}{"myParam": true}, + paramName: "myParam", + expectedVal: "", // Zero value for string + expectedOk: true, // ok is true because param exists + expectError: true, + errorMsg: "parameter myParam is not of type string, is bool", + }, + { + name: "present but wrong type (bool expected, got string)", + args: map[string]interface{}{"myParam": "true"}, + paramName: "myParam", + expectedVal: false, // Zero value for bool + expectedOk: true, // ok is true because param exists + expectError: true, + errorMsg: "parameter myParam is not of type bool, is string", + }, + { + name: "parameter not present", + args: map[string]interface{}{"anotherParam": "value"}, + paramName: "myParam", + expectedVal: "", // Zero value for string + expectedOk: false, + expectError: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + request := createMCPRequest(tc.args) + + // Test with string type assertion + if _, isString := tc.expectedVal.(string); isString || tc.errorMsg == "parameter myParam is not of type string, is bool" { + val, ok, err := OptionalParamOK[string](request, tc.paramName) + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.errorMsg) + assert.Equal(t, tc.expectedOk, ok) // Check ok even on error + assert.Equal(t, tc.expectedVal, val) // Check zero value on error + } else { + require.NoError(t, err) + assert.Equal(t, tc.expectedOk, ok) + assert.Equal(t, tc.expectedVal, val) + } + } + + // Test with bool type assertion + if _, isBool := tc.expectedVal.(bool); isBool || tc.errorMsg == "parameter myParam is not of type bool, is string" { + val, ok, err := OptionalParamOK[bool](request, tc.paramName) + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.errorMsg) + assert.Equal(t, tc.expectedOk, ok) // Check ok even on error + assert.Equal(t, tc.expectedVal, val) // Check zero value on error + } else { + require.NoError(t, err) + assert.Equal(t, tc.expectedOk, ok) + assert.Equal(t, tc.expectedVal, val) + } + } + + // Test with float64 type assertion (for number case) + if _, isFloat := tc.expectedVal.(float64); isFloat { + val, ok, err := OptionalParamOK[float64](request, tc.paramName) + if tc.expectError { + // This case shouldn't happen for float64 in the defined tests + require.Fail(t, "Unexpected error case for float64") + } else { + require.NoError(t, err) + assert.Equal(t, tc.expectedOk, ok) + assert.Equal(t, tc.expectedVal, val) + } + } + }) + } +} diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index 65b87154..c5f9d9fa 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -67,6 +67,119 @@ func GetPullRequest(client *github.Client, t translations.TranslationHelperFunc) } } +// UpdatePullRequest creates a tool to update an existing pull request. +func UpdatePullRequest(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("update_pull_request", + mcp.WithDescription(t("TOOL_UPDATE_PULL_REQUEST_DESCRIPTION", "Update an existing pull request in a GitHub repository")), + mcp.WithString("owner", + mcp.Required(), + mcp.Description("Repository owner"), + ), + mcp.WithString("repo", + mcp.Required(), + mcp.Description("Repository name"), + ), + mcp.WithNumber("pullNumber", + mcp.Required(), + mcp.Description("Pull request number to update"), + ), + mcp.WithString("title", + mcp.Description("New title"), + ), + mcp.WithString("body", + mcp.Description("New description"), + ), + mcp.WithString("state", + mcp.Description("New state ('open' or 'closed')"), + mcp.Enum("open", "closed"), + ), + mcp.WithString("base", + mcp.Description("New base branch name"), + ), + mcp.WithBoolean("maintainer_can_modify", + mcp.Description("Allow maintainer edits"), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pullNumber, err := RequiredInt(request, "pullNumber") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + // Build the update struct only with provided fields + update := &github.PullRequest{} + updateNeeded := false + + if title, ok, err := OptionalParamOK[string](request, "title"); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } else if ok { + update.Title = github.Ptr(title) + updateNeeded = true + } + + if body, ok, err := OptionalParamOK[string](request, "body"); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } else if ok { + update.Body = github.Ptr(body) + updateNeeded = true + } + + if state, ok, err := OptionalParamOK[string](request, "state"); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } else if ok { + update.State = github.Ptr(state) + updateNeeded = true + } + + if base, ok, err := OptionalParamOK[string](request, "base"); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } else if ok { + update.Base = &github.PullRequestBranch{Ref: github.Ptr(base)} + updateNeeded = true + } + + if maintainerCanModify, ok, err := OptionalParamOK[bool](request, "maintainer_can_modify"); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } else if ok { + update.MaintainerCanModify = github.Ptr(maintainerCanModify) + updateNeeded = true + } + + if !updateNeeded { + return mcp.NewToolResultError("No update parameters provided."), nil + } + + pr, resp, err := client.PullRequests.Edit(ctx, owner, repo, pullNumber, update) + if err != nil { + return nil, fmt.Errorf("failed to update pull request: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to update pull request: %s", string(body))), nil + } + + r, err := json.Marshal(pr) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return mcp.NewToolResultText(string(r)), nil + } +} + // ListPullRequests creates a tool to list and filter repository pull requests. func ListPullRequests(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("list_pull_requests", diff --git a/pkg/github/pullrequests_test.go b/pkg/github/pullrequests_test.go index 34c41cc7..e9647029 100644 --- a/pkg/github/pullrequests_test.go +++ b/pkg/github/pullrequests_test.go @@ -126,6 +126,188 @@ func Test_GetPullRequest(t *testing.T) { } } +func Test_UpdatePullRequest(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := UpdatePullRequest(mockClient, translations.NullTranslationHelper) + + assert.Equal(t, "update_pull_request", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "pullNumber") + assert.Contains(t, tool.InputSchema.Properties, "title") + assert.Contains(t, tool.InputSchema.Properties, "body") + assert.Contains(t, tool.InputSchema.Properties, "state") + assert.Contains(t, tool.InputSchema.Properties, "base") + assert.Contains(t, tool.InputSchema.Properties, "maintainer_can_modify") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "pullNumber"}) + + // Setup mock PR for success case + mockUpdatedPR := &github.PullRequest{ + Number: github.Ptr(42), + Title: github.Ptr("Updated Test PR Title"), + State: github.Ptr("open"), + HTMLURL: github.Ptr("https://github.com/owner/repo/pull/42"), + Body: github.Ptr("Updated test PR body."), + MaintainerCanModify: github.Ptr(false), + Base: &github.PullRequestBranch{ + Ref: github.Ptr("develop"), + }, + } + + mockClosedPR := &github.PullRequest{ + Number: github.Ptr(42), + Title: github.Ptr("Test PR"), + State: github.Ptr("closed"), // State updated + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedPR *github.PullRequest + expectedErrMsg string + }{ + { + name: "successful PR update (title, body, base, maintainer_can_modify)", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PatchReposPullsByOwnerByRepoByPullNumber, + // Expect the flat string based on previous test failure output and API docs + expectRequestBody(t, map[string]interface{}{ + "title": "Updated Test PR Title", + "body": "Updated test PR body.", + "base": "develop", + "maintainer_can_modify": false, + }).andThen( + mockResponse(t, http.StatusOK, mockUpdatedPR), + ), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pullNumber": float64(42), + "title": "Updated Test PR Title", + "body": "Updated test PR body.", + "base": "develop", + "maintainer_can_modify": false, + }, + expectError: false, + expectedPR: mockUpdatedPR, + }, + { + name: "successful PR update (state)", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PatchReposPullsByOwnerByRepoByPullNumber, + expectRequestBody(t, map[string]interface{}{ + "state": "closed", + }).andThen( + mockResponse(t, http.StatusOK, mockClosedPR), + ), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pullNumber": float64(42), + "state": "closed", + }, + expectError: false, + expectedPR: mockClosedPR, + }, + { + name: "no update parameters provided", + mockedClient: mock.NewMockedHTTPClient(), // No API call expected + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pullNumber": float64(42), + // No update fields + }, + expectError: false, // Error is returned in the result, not as Go error + expectedErrMsg: "No update parameters provided", + }, + { + name: "PR update fails (API error)", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PatchReposPullsByOwnerByRepoByPullNumber, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnprocessableEntity) + _, _ = w.Write([]byte(`{"message": "Validation Failed"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pullNumber": float64(42), + "title": "Invalid Title Causing Error", + }, + expectError: true, + expectedErrMsg: "failed to update pull request", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := UpdatePullRequest(client, translations.NullTranslationHelper) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content + textContent := getTextResult(t, result) + + // Check for expected error message within the result text + if tc.expectedErrMsg != "" { + assert.Contains(t, textContent.Text, tc.expectedErrMsg) + return + } + + // Unmarshal and verify the successful result + var returnedPR github.PullRequest + err = json.Unmarshal([]byte(textContent.Text), &returnedPR) + require.NoError(t, err) + assert.Equal(t, *tc.expectedPR.Number, *returnedPR.Number) + if tc.expectedPR.Title != nil { + assert.Equal(t, *tc.expectedPR.Title, *returnedPR.Title) + } + if tc.expectedPR.Body != nil { + assert.Equal(t, *tc.expectedPR.Body, *returnedPR.Body) + } + if tc.expectedPR.State != nil { + assert.Equal(t, *tc.expectedPR.State, *returnedPR.State) + } + if tc.expectedPR.Base != nil && tc.expectedPR.Base.Ref != nil { + assert.NotNil(t, returnedPR.Base) + assert.Equal(t, *tc.expectedPR.Base.Ref, *returnedPR.Base.Ref) + } + if tc.expectedPR.MaintainerCanModify != nil { + assert.Equal(t, *tc.expectedPR.MaintainerCanModify, *returnedPR.MaintainerCanModify) + } + }) + } +} + func Test_ListPullRequests(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) diff --git a/pkg/github/server.go b/pkg/github/server.go index 5852d581..84c15f50 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -53,6 +53,7 @@ func NewServer(client *github.Client, version string, readOnly bool, t translati s.AddTool(UpdatePullRequestBranch(client, t)) s.AddTool(CreatePullRequestReview(client, t)) s.AddTool(CreatePullRequest(client, t)) + s.AddTool(UpdatePullRequest(client, t)) } // Add GitHub tools - Repositories @@ -112,6 +113,30 @@ func GetMe(client *github.Client, t translations.TranslationHelperFunc) (tool mc } } +// OptionalParamOK is a helper function that can be used to fetch a requested parameter from the request. +// It returns the value, a boolean indicating if the parameter was present, and an error if the type is wrong. +func OptionalParamOK[T any](r mcp.CallToolRequest, p string) (value T, ok bool, err error) { + // Check if the parameter is present in the request + val, exists := r.Params.Arguments[p] + if !exists { + // Not present, return zero value, false, no error + return + } + + // Check if the parameter is of the expected type + value, ok = val.(T) + if !ok { + // Present but wrong type + err = fmt.Errorf("parameter %s is not of type %T, is %T", p, value, val) + ok = true // Set ok to true because the parameter *was* present, even if wrong type + return + } + + // Present and correct type + ok = true + return +} + // isAcceptedError checks if the error is an accepted error. func isAcceptedError(err error) bool { var acceptedError *github.AcceptedError