From e9bf1f194fce4887bf13dd71debf4c50685895ba Mon Sep 17 00:00:00 2001 From: William Martin Date: Thu, 3 Apr 2025 21:48:22 +0200 Subject: [PATCH 1/3] Use arrays rather than comma separated lists --- README.md | 10 ++-- pkg/github/helper_test.go | 46 +++++++++++++++ pkg/github/issues.go | 66 ++++++++++++++++------ pkg/github/issues_test.go | 44 ++++++++++++--- pkg/github/server.go | 37 ------------- pkg/github/server_test.go | 114 -------------------------------------- 6 files changed, 134 insertions(+), 183 deletions(-) diff --git a/README.md b/README.md index e4cd178a..b80513a9 100644 --- a/README.md +++ b/README.md @@ -129,8 +129,8 @@ export GITHUB_MCP_TOOL_ADD_ISSUE_COMMENT_DESCRIPTION="an alternative description - `repo`: Repository name (string, required) - `title`: Issue title (string, required) - `body`: Issue body content (string, optional) - - `assignees`: Comma-separated list of usernames to assign to this issue (string, optional) - - `labels`: Comma-separated list of labels to apply to this issue (string, optional) + - `assignees`: Usernames to assign to this issue (string[], optional) + - `labels`: Labels to apply to this issue (string[], optional) - **add_issue_comment** - Add a comment to an issue @@ -144,7 +144,7 @@ export GITHUB_MCP_TOOL_ADD_ISSUE_COMMENT_DESCRIPTION="an alternative description - `owner`: Repository owner (string, required) - `repo`: Repository name (string, required) - `state`: Filter by state ('open', 'closed', 'all') (string, optional) - - `labels`: Comma-separated list of labels to filter by (string, optional) + - `labels`: Labels to filter by (string[], optional) - `sort`: Sort by ('created', 'updated', 'comments') (string, optional) - `direction`: Sort direction ('asc', 'desc') (string, optional) - `since`: Filter by date (ISO 8601 timestamp) (string, optional) @@ -159,8 +159,8 @@ export GITHUB_MCP_TOOL_ADD_ISSUE_COMMENT_DESCRIPTION="an alternative description - `title`: New title (string, optional) - `body`: New description (string, optional) - `state`: New state ('open' or 'closed') (string, optional) - - `labels`: Comma-separated list of new labels (string, optional) - - `assignees`: Comma-separated list of new assignees (string, optional) + - `labels`: New labels (string[], optional) + - `assignees`: New assignees (string[], optional) - `milestone`: New milestone number (number, optional) - **search_issues** - Search for issues and pull requests diff --git a/pkg/github/helper_test.go b/pkg/github/helper_test.go index 72241623..9dcffa42 100644 --- a/pkg/github/helper_test.go +++ b/pkg/github/helper_test.go @@ -10,6 +10,52 @@ import ( "github.com/stretchr/testify/require" ) +// expectQueryParams is a helper function to create a partial mock that expects a +// request with the given query parameters, with the ability to chain a response handler. +func expectQueryParams(t *testing.T, expectedQueryParams map[string]string) *partialMock { + return &partialMock{ + t: t, + expectedQueryParams: expectedQueryParams, + } +} + +// expectRequestBody is a helper function to create a partial mock that expects a +// request with the given body, with the ability to chain a response handler. +func expectRequestBody(t *testing.T, expectedRequestBody any) *partialMock { + return &partialMock{ + t: t, + expectedRequestBody: expectedRequestBody, + } +} + +type partialMock struct { + t *testing.T + expectedQueryParams map[string]string + expectedRequestBody any +} + +func (p *partialMock) andThen(responseHandler http.HandlerFunc) http.HandlerFunc { + p.t.Helper() + return func(w http.ResponseWriter, r *http.Request) { + if p.expectedRequestBody != nil { + var unmarshaledRequestBody any + err := json.NewDecoder(r.Body).Decode(&unmarshaledRequestBody) + require.NoError(p.t, err) + + require.Equal(p.t, p.expectedRequestBody, unmarshaledRequestBody) + } + + if p.expectedQueryParams != nil { + require.Equal(p.t, len(p.expectedQueryParams), len(r.URL.Query())) + for k, v := range p.expectedQueryParams { + require.Equal(p.t, v, r.URL.Query().Get(k)) + } + } + + responseHandler(w, r) + } +} + // mockResponse is a helper function to create a mock HTTP response handler // that returns a specified status code and marshaled body. func mockResponse(t *testing.T, code int, body interface{}) http.HandlerFunc { diff --git a/pkg/github/issues.go b/pkg/github/issues.go index a62213ea..9c4a0ec2 100644 --- a/pkg/github/issues.go +++ b/pkg/github/issues.go @@ -228,11 +228,21 @@ func createIssue(client *github.Client, t translations.TranslationHelperFunc) (t mcp.WithString("body", mcp.Description("Issue body content"), ), - mcp.WithString("assignees", - mcp.Description("Comma-separate list of usernames to assign to this issue"), - ), - mcp.WithString("labels", - mcp.Description("Comma-separate list of labels to apply to this issue"), + mcp.WithArray("assignees", + mcp.Description("Usernames to assign to this issue"), + mcp.Items( + map[string]interface{}{ + "type": "string", + }, + ), + ), + mcp.WithArray("labels", + mcp.Description("Labels to apply to this issue"), + mcp.Items( + map[string]interface{}{ + "type": "string", + }, + ), ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { @@ -256,12 +266,13 @@ func createIssue(client *github.Client, t translations.TranslationHelperFunc) (t } // Get assignees - assignees, err := optionalCommaSeparatedListParam(request, "assignees") + assignees, err := optionalParam[[]string](request, "assignees") if err != nil { return mcp.NewToolResultError(err.Error()), nil } + // Get labels - labels, err := optionalCommaSeparatedListParam(request, "labels") + labels, err := optionalParam[[]string](request, "labels") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -312,8 +323,13 @@ func listIssues(client *github.Client, t translations.TranslationHelperFunc) (to mcp.WithString("state", mcp.Description("Filter by state ('open', 'closed', 'all')"), ), - mcp.WithString("labels", - mcp.Description("Comma-separated list of labels to filter by"), + mcp.WithArray("labels", + mcp.Description("Filter by labels"), + mcp.Items( + map[string]interface{}{ + "type": "string", + }, + ), ), mcp.WithString("sort", mcp.Description("Sort by ('created', 'updated', 'comments')"), @@ -349,7 +365,8 @@ func listIssues(client *github.Client, t translations.TranslationHelperFunc) (to return mcp.NewToolResultError(err.Error()), nil } - opts.Labels, err = optionalCommaSeparatedListParam(request, "labels") + // Get labels + opts.Labels, err = optionalParam[[]string](request, "labels") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -431,12 +448,23 @@ func updateIssue(client *github.Client, t translations.TranslationHelperFunc) (t ), mcp.WithString("state", mcp.Description("New state ('open' or 'closed')"), - ), - mcp.WithString("labels", - mcp.Description("Comma-separated list of new labels"), - ), - mcp.WithString("assignees", - mcp.Description("Comma-separated list of new assignees"), + mcp.Enum("open", "closed"), + ), + mcp.WithArray("labels", + mcp.Description("New labels"), + mcp.Items( + map[string]interface{}{ + "type": "string", + }, + ), + ), + mcp.WithArray("assignees", + mcp.Description("New assignees"), + mcp.Items( + map[string]interface{}{ + "type": "string", + }, + ), ), mcp.WithNumber("milestone", mcp.Description("New milestone number"), @@ -484,7 +512,8 @@ func updateIssue(client *github.Client, t translations.TranslationHelperFunc) (t issueRequest.State = github.Ptr(state) } - labels, err := optionalCommaSeparatedListParam(request, "labels") + // Get labels + labels, err := optionalParam[[]string](request, "labels") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -492,7 +521,8 @@ func updateIssue(client *github.Client, t translations.TranslationHelperFunc) (t issueRequest.Labels = &labels } - assignees, err := optionalCommaSeparatedListParam(request, "assignees") + // Get assignees + assignees, err := optionalParam[[]string](request, "assignees") if err != nil { return mcp.NewToolResultError(err.Error()), nil } diff --git a/pkg/github/issues_test.go b/pkg/github/issues_test.go index edc531ae..d9fdeb54 100644 --- a/pkg/github/issues_test.go +++ b/pkg/github/issues_test.go @@ -418,7 +418,14 @@ func Test_CreateIssue(t *testing.T) { mockedClient: mock.NewMockedHTTPClient( mock.WithRequestMatchHandler( mock.PostReposIssuesByOwnerByRepo, - mockResponse(t, http.StatusCreated, mockIssue), + expectRequestBody(t, map[string]any{ + "title": "Test Issue", + "body": "This is a test issue", + "labels": []any{"bug", "help wanted"}, + "assignees": []any{"user1", "user2"}, + }).andThen( + mockResponse(t, http.StatusCreated, mockIssue), + ), ), ), requestArgs: map[string]interface{}{ @@ -426,8 +433,8 @@ func Test_CreateIssue(t *testing.T) { "repo": "repo", "title": "Test Issue", "body": "This is a test issue", - "assignees": "user1, user2", - "labels": "bug, help wanted", + "assignees": []string{"user1", "user2"}, + "labels": []string{"bug", "help wanted"}, }, expectError: false, expectedIssue: mockIssue, @@ -606,16 +613,26 @@ func Test_ListIssues(t *testing.T) { { name: "list issues with all parameters", mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatch( + mock.WithRequestMatchHandler( mock.GetReposIssuesByOwnerByRepo, - mockIssues, + expectQueryParams(t, map[string]string{ + "state": "open", + "labels": "bug,enhancement", + "sort": "created", + "direction": "desc", + "since": "2023-01-01T00:00:00Z", + "page": "1", + "per_page": "30", + }).andThen( + mockResponse(t, http.StatusOK, mockIssues), + ), ), ), requestArgs: map[string]interface{}{ "owner": "owner", "repo": "repo", "state": "open", - "labels": "bug,enhancement", + "labels": []string{"bug", "enhancement"}, "sort": "created", "direction": "desc", "since": "2023-01-01T00:00:00Z", @@ -750,7 +767,16 @@ func Test_UpdateIssue(t *testing.T) { mockedClient: mock.NewMockedHTTPClient( mock.WithRequestMatchHandler( mock.PatchReposIssuesByOwnerByRepoByIssueNumber, - mockResponse(t, http.StatusOK, mockIssue), + expectRequestBody(t, map[string]any{ + "title": "Updated Issue Title", + "body": "Updated issue description", + "state": "closed", + "labels": []any{"bug", "priority"}, + "assignees": []any{"assignee1", "assignee2"}, + "milestone": float64(5), + }).andThen( + mockResponse(t, http.StatusOK, mockIssue), + ), ), ), requestArgs: map[string]interface{}{ @@ -760,8 +786,8 @@ func Test_UpdateIssue(t *testing.T) { "title": "Updated Issue Title", "body": "Updated issue description", "state": "closed", - "labels": "bug,priority", - "assignees": "assignee1,assignee2", + "labels": []string{"bug", "priority"}, + "assignees": []string{"assignee1", "assignee2"}, "milestone": float64(5), }, expectError: false, diff --git a/pkg/github/server.go b/pkg/github/server.go index d652dde0..f93ca37f 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -7,7 +7,6 @@ import ( "fmt" "io" "net/http" - "strings" "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v69/github" @@ -119,25 +118,6 @@ func isAcceptedError(err error) bool { return errors.As(err, &acceptedError) } -// parseCommaSeparatedList is a helper function that parses a comma-separated list of strings from the input string. -func parseCommaSeparatedList(input string) []string { - if input == "" { - return nil - } - - parts := strings.Split(input, ",") - result := make([]string, 0, len(parts)) - - for _, part := range parts { - trimmed := strings.TrimSpace(part) - if trimmed != "" { - result = append(result, trimmed) - } - } - - return result -} - // requiredParam is a helper function that can be used to fetch a requested parameter from the request. // It does the following checks: // 1. Checks if the parameter is present in the request. @@ -221,20 +201,3 @@ func optionalIntParamWithDefault(r mcp.CallToolRequest, p string, d int) (int, e } return v, nil } - -// optionalCommaSeparatedListParam is a helper function that can be used to fetch a requested parameter from the request. -// It does the following: -// 1. Checks if the parameter is present in the request, if not, it returns an empty list -// 2. If it is present, it checks if the parameter is of the expected type and uses parseCommaSeparatedList to parse it -// and return the list of strings -func optionalCommaSeparatedListParam(r mcp.CallToolRequest, p string) ([]string, error) { - v, err := optionalParam[string](r, p) - if err != nil { - return []string{}, err - } - l := parseCommaSeparatedList(v) - if len(l) == 0 { - return []string{}, nil - } - return l, nil -} diff --git a/pkg/github/server_test.go b/pkg/github/server_test.go index a4d819f7..ffaa4dd8 100644 --- a/pkg/github/server_test.go +++ b/pkg/github/server_test.go @@ -168,67 +168,6 @@ func Test_IsAcceptedError(t *testing.T) { } } -func Test_ParseCommaSeparatedList(t *testing.T) { - tests := []struct { - name string - input string - expected []string - }{ - { - name: "simple comma separated values", - input: "one,two,three", - expected: []string{"one", "two", "three"}, - }, - { - name: "values with spaces", - input: "one, two, three", - expected: []string{"one", "two", "three"}, - }, - { - name: "values with extra spaces", - input: " one , two , three ", - expected: []string{"one", "two", "three"}, - }, - { - name: "empty values in between", - input: "one,,three", - expected: []string{"one", "three"}, - }, - { - name: "only spaces", - input: " , , ", - expected: []string{}, - }, - { - name: "empty string", - input: "", - expected: nil, - }, - { - name: "single value", - input: "one", - expected: []string{"one"}, - }, - { - name: "trailing comma", - input: "one,two,", - expected: []string{"one", "two"}, - }, - { - name: "leading comma", - input: ",one,two", - expected: []string{"one", "two"}, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - result := parseCommaSeparatedList(tc.input) - assert.Equal(t, tc.expected, result) - }) - } -} - func Test_RequiredStringParam(t *testing.T) { tests := []struct { name string @@ -492,59 +431,6 @@ func Test_OptionalNumberParamWithDefault(t *testing.T) { } } -func Test_OptionalCommaSeparatedListParam(t *testing.T) { - tests := []struct { - name string - params map[string]interface{} - paramName string - expected []string - expectError bool - }{ - { - name: "valid comma-separated list", - params: map[string]interface{}{"tags": "one,two,three"}, - paramName: "tags", - expected: []string{"one", "two", "three"}, - expectError: false, - }, - { - name: "empty list", - params: map[string]interface{}{"tags": ""}, - paramName: "tags", - expected: []string{}, - expectError: false, - }, - { - name: "missing parameter", - params: map[string]interface{}{}, - paramName: "tags", - expected: []string{}, - expectError: false, - }, - { - name: "wrong type parameter", - params: map[string]interface{}{"tags": 123}, - paramName: "tags", - expected: nil, - expectError: true, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - request := createMCPRequest(tc.params) - result, err := optionalCommaSeparatedListParam(request, tc.paramName) - - if tc.expectError { - assert.Error(t, err) - } else { - assert.NoError(t, err) - assert.Equal(t, tc.expected, result) - } - }) - } -} - func Test_OptionalBooleanParam(t *testing.T) { tests := []struct { name string From 0c453207379294288f6a7df2ddba5041a7a13571 Mon Sep 17 00:00:00 2001 From: William Martin Date: Thu, 3 Apr 2025 20:50:00 +0200 Subject: [PATCH 2/3] Enumerate strings in schema --- pkg/github/issues.go | 17 +++++++++++++++++ pkg/github/search.go | 3 +++ 2 files changed, 20 insertions(+) diff --git a/pkg/github/issues.go b/pkg/github/issues.go index 9c4a0ec2..53ce61bf 100644 --- a/pkg/github/issues.go +++ b/pkg/github/issues.go @@ -144,9 +144,23 @@ func searchIssues(client *github.Client, t translations.TranslationHelperFunc) ( ), mcp.WithString("sort", mcp.Description("Sort field (comments, reactions, created, etc.)"), + mcp.Enum( + "comments", + "reactions", + "reactions-+1", + "reactions--1", + "reactions-smile", + "reactions-thinking_face", + "reactions-heart", + "reactions-tada", + "interactions", + "created", + "updated", + ), ), mcp.WithString("order", mcp.Description("Sort order ('asc' or 'desc')"), + mcp.Enum("asc", "desc"), ), mcp.WithNumber("per_page", mcp.Description("Results per page (max 100)"), @@ -322,6 +336,7 @@ func listIssues(client *github.Client, t translations.TranslationHelperFunc) (to ), mcp.WithString("state", mcp.Description("Filter by state ('open', 'closed', 'all')"), + mcp.Enum("open", "closed", "all"), ), mcp.WithArray("labels", mcp.Description("Filter by labels"), @@ -333,9 +348,11 @@ func listIssues(client *github.Client, t translations.TranslationHelperFunc) (to ), mcp.WithString("sort", mcp.Description("Sort by ('created', 'updated', 'comments')"), + mcp.Enum("created", "updated", "comments"), ), mcp.WithString("direction", mcp.Description("Sort direction ('asc', 'desc')"), + mcp.Enum("asc", "desc"), ), mcp.WithString("since", mcp.Description("Filter by date (ISO 8601 timestamp)"), diff --git a/pkg/github/search.go b/pkg/github/search.go index 904dc737..fc81432d 100644 --- a/pkg/github/search.go +++ b/pkg/github/search.go @@ -84,6 +84,7 @@ func searchCode(client *github.Client, t translations.TranslationHelperFunc) (to ), mcp.WithString("order", mcp.Description("Sort order ('asc' or 'desc')"), + mcp.Enum("asc", "desc"), ), mcp.WithNumber("per_page", mcp.Description("Results per page (max 100)"), @@ -156,9 +157,11 @@ func searchUsers(client *github.Client, t translations.TranslationHelperFunc) (t ), mcp.WithString("sort", mcp.Description("Sort field (followers, repositories, joined)"), + mcp.Enum("followers", "repositories", "joined"), ), mcp.WithString("order", mcp.Description("Sort order ('asc' or 'desc')"), + mcp.Enum("asc", "desc"), ), mcp.WithNumber("per_page", mcp.Description("Results per page (max 100)"), From a7d6411ea6b7becf91ca775bdec258d9bebba3c2 Mon Sep 17 00:00:00 2001 From: William Martin Date: Thu, 3 Apr 2025 21:58:50 +0200 Subject: [PATCH 3/3] Add boundaries to pagination --- pkg/github/issues.go | 3 +++ pkg/github/search.go | 6 ++++++ 2 files changed, 9 insertions(+) diff --git a/pkg/github/issues.go b/pkg/github/issues.go index 53ce61bf..e27215ce 100644 --- a/pkg/github/issues.go +++ b/pkg/github/issues.go @@ -164,9 +164,12 @@ func searchIssues(client *github.Client, t translations.TranslationHelperFunc) ( ), mcp.WithNumber("per_page", mcp.Description("Results per page (max 100)"), + mcp.Min(1), + mcp.Max(100), ), mcp.WithNumber("page", mcp.Description("Page number"), + mcp.Min(1), ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { diff --git a/pkg/github/search.go b/pkg/github/search.go index fc81432d..e02c3d0c 100644 --- a/pkg/github/search.go +++ b/pkg/github/search.go @@ -88,9 +88,12 @@ func searchCode(client *github.Client, t translations.TranslationHelperFunc) (to ), mcp.WithNumber("per_page", mcp.Description("Results per page (max 100)"), + mcp.Min(1), + mcp.Max(100), ), mcp.WithNumber("page", mcp.Description("Page number"), + mcp.Min(1), ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { @@ -165,9 +168,12 @@ func searchUsers(client *github.Client, t translations.TranslationHelperFunc) (t ), mcp.WithNumber("per_page", mcp.Description("Results per page (max 100)"), + mcp.Min(1), + mcp.Max(100), ), mcp.WithNumber("page", mcp.Description("Page number"), + mcp.Min(1), ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {