Skip to content

Commit 5f3926d

Browse files
juruentoby
andauthored
add support for create_pull_request (#63)
Co-authored-by: Toby Padilla <toby@toby.sh>
1 parent 671f824 commit 5f3926d

File tree

4 files changed

+268
-0
lines changed

4 files changed

+268
-0
lines changed

README.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,17 @@ export GITHUB_MCP_TOOL_ADD_ISSUE_COMMENT_DESCRIPTION="an alternative description
238238
- `commit_id`: SHA of commit to review (string, optional)
239239
- `comments`: Line-specific comments array of objects, each object with path (string), position (number), and body (string) (array, optional)
240240

241+
- **create_pull_request** - Create a new pull request
242+
243+
- `owner`: Repository owner (string, required)
244+
- `repo`: Repository name (string, required)
245+
- `title`: PR title (string, required)
246+
- `body`: PR description (string, optional)
247+
- `head`: Branch containing changes (string, required)
248+
- `base`: Branch to merge into (string, required)
249+
- `draft`: Create as draft PR (boolean, optional)
250+
- `maintainer_can_modify`: Allow maintainer edits (boolean, optional)
251+
241252
### Repositories
242253

243254
- **create_or_update_file** - Create or update a single file in a repository

pkg/github/pullrequests.go

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -712,3 +712,110 @@ func createPullRequestReview(client *github.Client, t translations.TranslationHe
712712
return mcp.NewToolResultText(string(r)), nil
713713
}
714714
}
715+
716+
// createPullRequest creates a tool to create a new pull request.
717+
func createPullRequest(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
718+
return mcp.NewTool("create_pull_request",
719+
mcp.WithDescription(t("TOOL_CREATE_PULL_REQUEST_DESCRIPTION", "Create a new pull request in a GitHub repository")),
720+
mcp.WithString("owner",
721+
mcp.Required(),
722+
mcp.Description("Repository owner"),
723+
),
724+
mcp.WithString("repo",
725+
mcp.Required(),
726+
mcp.Description("Repository name"),
727+
),
728+
mcp.WithString("title",
729+
mcp.Required(),
730+
mcp.Description("PR title"),
731+
),
732+
mcp.WithString("body",
733+
mcp.Description("PR description"),
734+
),
735+
mcp.WithString("head",
736+
mcp.Required(),
737+
mcp.Description("Branch containing changes"),
738+
),
739+
mcp.WithString("base",
740+
mcp.Required(),
741+
mcp.Description("Branch to merge into"),
742+
),
743+
mcp.WithBoolean("draft",
744+
mcp.Description("Create as draft PR"),
745+
),
746+
mcp.WithBoolean("maintainer_can_modify",
747+
mcp.Description("Allow maintainer edits"),
748+
),
749+
),
750+
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
751+
owner, err := requiredParam[string](request, "owner")
752+
if err != nil {
753+
return mcp.NewToolResultError(err.Error()), nil
754+
}
755+
repo, err := requiredParam[string](request, "repo")
756+
if err != nil {
757+
return mcp.NewToolResultError(err.Error()), nil
758+
}
759+
title, err := requiredParam[string](request, "title")
760+
if err != nil {
761+
return mcp.NewToolResultError(err.Error()), nil
762+
}
763+
head, err := requiredParam[string](request, "head")
764+
if err != nil {
765+
return mcp.NewToolResultError(err.Error()), nil
766+
}
767+
base, err := requiredParam[string](request, "base")
768+
if err != nil {
769+
return mcp.NewToolResultError(err.Error()), nil
770+
}
771+
772+
body, err := optionalParam[string](request, "body")
773+
if err != nil {
774+
return mcp.NewToolResultError(err.Error()), nil
775+
}
776+
777+
draft, err := optionalParam[bool](request, "draft")
778+
if err != nil {
779+
return mcp.NewToolResultError(err.Error()), nil
780+
}
781+
782+
maintainerCanModify, err := optionalParam[bool](request, "maintainer_can_modify")
783+
if err != nil {
784+
return mcp.NewToolResultError(err.Error()), nil
785+
}
786+
787+
newPR := &github.NewPullRequest{
788+
Title: github.Ptr(title),
789+
Head: github.Ptr(head),
790+
Base: github.Ptr(base),
791+
}
792+
793+
if body != "" {
794+
newPR.Body = github.Ptr(body)
795+
}
796+
797+
newPR.Draft = github.Ptr(draft)
798+
newPR.MaintainerCanModify = github.Ptr(maintainerCanModify)
799+
800+
pr, resp, err := client.PullRequests.Create(ctx, owner, repo, newPR)
801+
if err != nil {
802+
return nil, fmt.Errorf("failed to create pull request: %w", err)
803+
}
804+
defer func() { _ = resp.Body.Close() }()
805+
806+
if resp.StatusCode != http.StatusCreated {
807+
body, err := io.ReadAll(resp.Body)
808+
if err != nil {
809+
return nil, fmt.Errorf("failed to read response body: %w", err)
810+
}
811+
return mcp.NewToolResultError(fmt.Sprintf("failed to create pull request: %s", string(body))), nil
812+
}
813+
814+
r, err := json.Marshal(pr)
815+
if err != nil {
816+
return nil, fmt.Errorf("failed to marshal response: %w", err)
817+
}
818+
819+
return mcp.NewToolResultText(string(r)), nil
820+
}
821+
}

pkg/github/pullrequests_test.go

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1187,3 +1187,152 @@ func Test_CreatePullRequestReview(t *testing.T) {
11871187
})
11881188
}
11891189
}
1190+
1191+
func Test_CreatePullRequest(t *testing.T) {
1192+
// Verify tool definition once
1193+
mockClient := github.NewClient(nil)
1194+
tool, _ := createPullRequest(mockClient, translations.NullTranslationHelper)
1195+
1196+
assert.Equal(t, "create_pull_request", tool.Name)
1197+
assert.NotEmpty(t, tool.Description)
1198+
assert.Contains(t, tool.InputSchema.Properties, "owner")
1199+
assert.Contains(t, tool.InputSchema.Properties, "repo")
1200+
assert.Contains(t, tool.InputSchema.Properties, "title")
1201+
assert.Contains(t, tool.InputSchema.Properties, "body")
1202+
assert.Contains(t, tool.InputSchema.Properties, "head")
1203+
assert.Contains(t, tool.InputSchema.Properties, "base")
1204+
assert.Contains(t, tool.InputSchema.Properties, "draft")
1205+
assert.Contains(t, tool.InputSchema.Properties, "maintainer_can_modify")
1206+
assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "title", "head", "base"})
1207+
1208+
// Setup mock PR for success case
1209+
mockPR := &github.PullRequest{
1210+
Number: github.Ptr(42),
1211+
Title: github.Ptr("Test PR"),
1212+
State: github.Ptr("open"),
1213+
HTMLURL: github.Ptr("https://github.com/owner/repo/pull/42"),
1214+
Head: &github.PullRequestBranch{
1215+
SHA: github.Ptr("abcd1234"),
1216+
Ref: github.Ptr("feature-branch"),
1217+
},
1218+
Base: &github.PullRequestBranch{
1219+
SHA: github.Ptr("efgh5678"),
1220+
Ref: github.Ptr("main"),
1221+
},
1222+
Body: github.Ptr("This is a test PR"),
1223+
Draft: github.Ptr(false),
1224+
MaintainerCanModify: github.Ptr(true),
1225+
User: &github.User{
1226+
Login: github.Ptr("testuser"),
1227+
},
1228+
}
1229+
1230+
tests := []struct {
1231+
name string
1232+
mockedClient *http.Client
1233+
requestArgs map[string]interface{}
1234+
expectError bool
1235+
expectedPR *github.PullRequest
1236+
expectedErrMsg string
1237+
}{
1238+
{
1239+
name: "successful PR creation",
1240+
mockedClient: mock.NewMockedHTTPClient(
1241+
mock.WithRequestMatchHandler(
1242+
mock.PostReposPullsByOwnerByRepo,
1243+
mockResponse(t, http.StatusCreated, mockPR),
1244+
),
1245+
),
1246+
1247+
requestArgs: map[string]interface{}{
1248+
"owner": "owner",
1249+
"repo": "repo",
1250+
"title": "Test PR",
1251+
"body": "This is a test PR",
1252+
"head": "feature-branch",
1253+
"base": "main",
1254+
"draft": false,
1255+
"maintainer_can_modify": true,
1256+
},
1257+
expectError: false,
1258+
expectedPR: mockPR,
1259+
},
1260+
{
1261+
name: "missing required parameter",
1262+
mockedClient: mock.NewMockedHTTPClient(),
1263+
requestArgs: map[string]interface{}{
1264+
"owner": "owner",
1265+
"repo": "repo",
1266+
// missing title, head, base
1267+
},
1268+
expectError: true,
1269+
expectedErrMsg: "missing required parameter: title",
1270+
},
1271+
{
1272+
name: "PR creation fails",
1273+
mockedClient: mock.NewMockedHTTPClient(
1274+
mock.WithRequestMatchHandler(
1275+
mock.PostReposPullsByOwnerByRepo,
1276+
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
1277+
w.WriteHeader(http.StatusUnprocessableEntity)
1278+
_, _ = w.Write([]byte(`{"message":"Validation failed","errors":[{"resource":"PullRequest","code":"invalid"}]}`))
1279+
}),
1280+
),
1281+
),
1282+
requestArgs: map[string]interface{}{
1283+
"owner": "owner",
1284+
"repo": "repo",
1285+
"title": "Test PR",
1286+
"head": "feature-branch",
1287+
"base": "main",
1288+
},
1289+
expectError: true,
1290+
expectedErrMsg: "failed to create pull request",
1291+
},
1292+
}
1293+
1294+
for _, tc := range tests {
1295+
t.Run(tc.name, func(t *testing.T) {
1296+
// Setup client with mock
1297+
client := github.NewClient(tc.mockedClient)
1298+
_, handler := createPullRequest(client, translations.NullTranslationHelper)
1299+
1300+
// Create call request
1301+
request := createMCPRequest(tc.requestArgs)
1302+
1303+
// Call handler
1304+
result, err := handler(context.Background(), request)
1305+
1306+
// Verify results
1307+
if tc.expectError {
1308+
if err != nil {
1309+
assert.Contains(t, err.Error(), tc.expectedErrMsg)
1310+
return
1311+
}
1312+
1313+
// If no error returned but in the result
1314+
textContent := getTextResult(t, result)
1315+
assert.Contains(t, textContent.Text, tc.expectedErrMsg)
1316+
return
1317+
}
1318+
1319+
require.NoError(t, err)
1320+
1321+
// Parse the result and get the text content if no error
1322+
textContent := getTextResult(t, result)
1323+
1324+
// Unmarshal and verify the result
1325+
var returnedPR github.PullRequest
1326+
err = json.Unmarshal([]byte(textContent.Text), &returnedPR)
1327+
require.NoError(t, err)
1328+
assert.Equal(t, *tc.expectedPR.Number, *returnedPR.Number)
1329+
assert.Equal(t, *tc.expectedPR.Title, *returnedPR.Title)
1330+
assert.Equal(t, *tc.expectedPR.State, *returnedPR.State)
1331+
assert.Equal(t, *tc.expectedPR.HTMLURL, *returnedPR.HTMLURL)
1332+
assert.Equal(t, *tc.expectedPR.Head.SHA, *returnedPR.Head.SHA)
1333+
assert.Equal(t, *tc.expectedPR.Base.Ref, *returnedPR.Base.Ref)
1334+
assert.Equal(t, *tc.expectedPR.Body, *returnedPR.Body)
1335+
assert.Equal(t, *tc.expectedPR.User.Login, *returnedPR.User.Login)
1336+
})
1337+
}
1338+
}

pkg/github/server.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ func NewServer(client *github.Client, readOnly bool, t translations.TranslationH
5555
s.AddTool(mergePullRequest(client, t))
5656
s.AddTool(updatePullRequestBranch(client, t))
5757
s.AddTool(createPullRequestReview(client, t))
58+
s.AddTool(createPullRequest(client, t))
5859
}
5960

6061
// Add GitHub tools - Repositories

0 commit comments

Comments
 (0)