From 5a3dc7f97faeb6adcf6de6884126fb62342cd2bf Mon Sep 17 00:00:00 2001 From: Eran Cohen Date: Thu, 24 Apr 2025 16:23:25 +0300 Subject: [PATCH 1/4] feat: Add support for git tag operations Add git tag functionality including: - List repository tags - Get tag details - Support for tag-based content access This enables basic read-only tag management through the MCP server API. --- pkg/github/repositories.go | 144 ++++++++++++++++++ pkg/github/repositories_test.go | 254 ++++++++++++++++++++++++++++++++ pkg/github/tools.go | 2 + 3 files changed, 400 insertions(+) diff --git a/pkg/github/repositories.go b/pkg/github/repositories.go index 7c1bc23e..6abdde64 100644 --- a/pkg/github/repositories.go +++ b/pkg/github/repositories.go @@ -796,3 +796,147 @@ func PushFiles(getClient GetClientFn, t translations.TranslationHelperFunc) (too return mcp.NewToolResultText(string(r)), nil } } + +// ListTags creates a tool to list tags in a GitHub repository. +func ListTags(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("list_tags", + mcp.WithDescription(t("TOOL_LIST_TAGS_DESCRIPTION", "List tags in a GitHub repository")), + mcp.WithToolAnnotation(mcp.ToolAnnotation{ + Title: t("TOOL_LIST_TAGS_USER_TITLE", "List tags"), + ReadOnlyHint: true, + }), + mcp.WithString("owner", + mcp.Required(), + mcp.Description("Repository owner"), + ), + mcp.WithString("repo", + mcp.Required(), + mcp.Description("Repository name"), + ), + WithPagination(), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pagination, err := OptionalPaginationParams(request) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + opts := &github.ListOptions{ + Page: pagination.page, + PerPage: pagination.perPage, + } + + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + tags, resp, err := client.Repositories.ListTags(ctx, owner, repo, opts) + if err != nil { + return nil, fmt.Errorf("failed to list tags: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to list tags: %s", string(body))), nil + } + + r, err := json.Marshal(tags) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return mcp.NewToolResultText(string(r)), nil + } +} + +// GetTag creates a tool to get details about a specific tag in a GitHub repository. +func GetTag(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("get_tag", + mcp.WithDescription(t("TOOL_GET_TAG_DESCRIPTION", "Get details about a specific tag in a GitHub repository")), + mcp.WithToolAnnotation(mcp.ToolAnnotation{ + Title: t("TOOL_GET_TAG_USER_TITLE", "Get tag details"), + ReadOnlyHint: true, + }), + mcp.WithString("owner", + mcp.Required(), + mcp.Description("Repository owner"), + ), + mcp.WithString("repo", + mcp.Required(), + mcp.Description("Repository name"), + ), + mcp.WithString("tag", + mcp.Required(), + mcp.Description("Tag name"), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + tag, err := requiredParam[string](request, "tag") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + // First get the tag reference + ref, resp, err := client.Git.GetRef(ctx, owner, repo, "refs/tags/"+tag) + if err != nil { + return nil, fmt.Errorf("failed to get tag reference: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to get tag reference: %s", string(body))), nil + } + + // Then get the tag object + tagObj, resp, err := client.Git.GetTag(ctx, owner, repo, *ref.Object.SHA) + if err != nil { + return nil, fmt.Errorf("failed to get tag object: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to get tag object: %s", string(body))), nil + } + + r, err := json.Marshal(tagObj) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return mcp.NewToolResultText(string(r)), nil + } +} diff --git a/pkg/github/repositories_test.go b/pkg/github/repositories_test.go index 5b8129fe..7fe58fc8 100644 --- a/pkg/github/repositories_test.go +++ b/pkg/github/repositories_test.go @@ -1528,3 +1528,257 @@ func Test_ListBranches(t *testing.T) { }) } } + +func Test_ListTags(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := ListTags(stubGetClientFn(mockClient), translations.NullTranslationHelper) + + assert.Equal(t, "list_tags", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo"}) + + // Setup mock tags for success case + mockTags := []*github.RepositoryTag{ + { + Name: github.Ptr("v1.0.0"), + Commit: &github.Commit{ + SHA: github.Ptr("abc123"), + URL: github.Ptr("https://api.github.com/repos/owner/repo/commits/abc123"), + }, + ZipballURL: github.Ptr("https://github.com/owner/repo/zipball/v1.0.0"), + TarballURL: github.Ptr("https://github.com/owner/repo/tarball/v1.0.0"), + }, + { + Name: github.Ptr("v0.9.0"), + Commit: &github.Commit{ + SHA: github.Ptr("def456"), + URL: github.Ptr("https://api.github.com/repos/owner/repo/commits/def456"), + }, + ZipballURL: github.Ptr("https://github.com/owner/repo/zipball/v0.9.0"), + TarballURL: github.Ptr("https://github.com/owner/repo/tarball/v0.9.0"), + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedTags []*github.RepositoryTag + expectedErrMsg string + }{ + { + name: "successful tags list", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposTagsByOwnerByRepo, + mockTags, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + }, + expectError: false, + expectedTags: mockTags, + }, + { + name: "list tags fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposTagsByOwnerByRepo, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`{"message": "Internal Server Error"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + }, + expectError: true, + expectedErrMsg: "failed to list tags", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := ListTags(stubGetClientFn(client), translations.NullTranslationHelper) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Parse and verify the result + var returnedTags []*github.RepositoryTag + err = json.Unmarshal([]byte(textContent.Text), &returnedTags) + require.NoError(t, err) + + // Verify each tag + require.Equal(t, len(tc.expectedTags), len(returnedTags)) + for i, expectedTag := range tc.expectedTags { + assert.Equal(t, *expectedTag.Name, *returnedTags[i].Name) + assert.Equal(t, *expectedTag.Commit.SHA, *returnedTags[i].Commit.SHA) + } + }) + } +} + +func Test_GetTag(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := GetTag(stubGetClientFn(mockClient), translations.NullTranslationHelper) + + assert.Equal(t, "get_tag", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "tag") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "tag"}) + + mockTagRef := &github.Reference{ + Ref: github.Ptr("refs/tags/v1.0.0"), + Object: &github.GitObject{ + SHA: github.Ptr("tag123"), + }, + } + + mockTagObj := &github.Tag{ + SHA: github.Ptr("tag123"), + Tag: github.Ptr("v1.0.0"), + Message: github.Ptr("Release v1.0.0"), + Object: &github.GitObject{ + Type: github.Ptr("commit"), + SHA: github.Ptr("abc123"), + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedTag *github.Tag + expectedErrMsg string + }{ + { + name: "successful tag retrieval", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposGitRefByOwnerByRepoByRef, + mockTagRef, + ), + mock.WithRequestMatch( + mock.GetReposGitTagsByOwnerByRepoByTagSha, + mockTagObj, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "tag": "v1.0.0", + }, + expectError: false, + expectedTag: mockTagObj, + }, + { + name: "tag reference not found", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposGitRefByOwnerByRepoByRef, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Reference does not exist"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "tag": "v1.0.0", + }, + expectError: true, + expectedErrMsg: "failed to get tag reference", + }, + { + name: "tag object not found", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposGitRefByOwnerByRepoByRef, + mockTagRef, + ), + mock.WithRequestMatchHandler( + mock.GetReposGitTagsByOwnerByRepoByTagSha, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Tag object does not exist"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "tag": "v1.0.0", + }, + expectError: true, + expectedErrMsg: "failed to get tag object", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := GetTag(stubGetClientFn(client), translations.NullTranslationHelper) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Parse and verify the result + var returnedTag github.Tag + err = json.Unmarshal([]byte(textContent.Text), &returnedTag) + require.NoError(t, err) + + assert.Equal(t, *tc.expectedTag.SHA, *returnedTag.SHA) + assert.Equal(t, *tc.expectedTag.Tag, *returnedTag.Tag) + assert.Equal(t, *tc.expectedTag.Message, *returnedTag.Message) + assert.Equal(t, *tc.expectedTag.Object.Type, *returnedTag.Object.Type) + assert.Equal(t, *tc.expectedTag.Object.SHA, *returnedTag.Object.SHA) + }) + } +} diff --git a/pkg/github/tools.go b/pkg/github/tools.go index 1a4a3b4d..3776a129 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -27,6 +27,8 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn, toolsets.NewServerTool(SearchCode(getClient, t)), toolsets.NewServerTool(GetCommit(getClient, t)), toolsets.NewServerTool(ListBranches(getClient, t)), + toolsets.NewServerTool(ListTags(getClient, t)), + toolsets.NewServerTool(GetTag(getClient, t)), ). AddWriteTools( toolsets.NewServerTool(CreateOrUpdateFile(getClient, t)), From 8f1cb69cab14fb8dd99164152548e5e888f2010a Mon Sep 17 00:00:00 2001 From: William Martin Date: Tue, 29 Apr 2025 16:39:19 +0200 Subject: [PATCH 2/4] Test path params for tag tools --- pkg/github/helper_test.go | 29 ++++++++++++++++++++------- pkg/github/repositories_test.go | 35 +++++++++++++++++++++++---------- 2 files changed, 47 insertions(+), 17 deletions(-) diff --git a/pkg/github/helper_test.go b/pkg/github/helper_test.go index 40fc0b94..f241d334 100644 --- a/pkg/github/helper_test.go +++ b/pkg/github/helper_test.go @@ -10,6 +10,15 @@ import ( "github.com/stretchr/testify/require" ) +// expectPath is a helper function to create a partial mock that expects a +// request with the given path, with the ability to chain a response handler. +func expectPath(t *testing.T, expectedPath string) *partialMock { + return &partialMock{ + t: t, + expectedPath: expectedPath, + } +} + // expectQueryParams is a helper function to create a partial mock that expects a // request with the given query parameters, with the ability to chain a response handler. func expectQueryParams(t *testing.T, expectedQueryParams map[string]string) *partialMock { @@ -29,7 +38,9 @@ func expectRequestBody(t *testing.T, expectedRequestBody any) *partialMock { } type partialMock struct { - t *testing.T + t *testing.T + + expectedPath string expectedQueryParams map[string]string expectedRequestBody any } @@ -37,12 +48,8 @@ type partialMock struct { func (p *partialMock) andThen(responseHandler http.HandlerFunc) http.HandlerFunc { p.t.Helper() return func(w http.ResponseWriter, r *http.Request) { - if p.expectedRequestBody != nil { - var unmarshaledRequestBody any - err := json.NewDecoder(r.Body).Decode(&unmarshaledRequestBody) - require.NoError(p.t, err) - - require.Equal(p.t, p.expectedRequestBody, unmarshaledRequestBody) + if p.expectedPath != "" { + require.Equal(p.t, p.expectedPath, r.URL.Path) } if p.expectedQueryParams != nil { @@ -52,6 +59,14 @@ func (p *partialMock) andThen(responseHandler http.HandlerFunc) http.HandlerFunc } } + if p.expectedRequestBody != nil { + var unmarshaledRequestBody any + err := json.NewDecoder(r.Body).Decode(&unmarshaledRequestBody) + require.NoError(p.t, err) + + require.Equal(p.t, p.expectedRequestBody, unmarshaledRequestBody) + } + responseHandler(w, r) } } diff --git a/pkg/github/repositories_test.go b/pkg/github/repositories_test.go index 7fe58fc8..59d19fc4 100644 --- a/pkg/github/repositories_test.go +++ b/pkg/github/repositories_test.go @@ -1545,7 +1545,7 @@ func Test_ListTags(t *testing.T) { { Name: github.Ptr("v1.0.0"), Commit: &github.Commit{ - SHA: github.Ptr("abc123"), + SHA: github.Ptr("v1.0.0-tag-sha"), URL: github.Ptr("https://api.github.com/repos/owner/repo/commits/abc123"), }, ZipballURL: github.Ptr("https://github.com/owner/repo/zipball/v1.0.0"), @@ -1554,7 +1554,7 @@ func Test_ListTags(t *testing.T) { { Name: github.Ptr("v0.9.0"), Commit: &github.Commit{ - SHA: github.Ptr("def456"), + SHA: github.Ptr("v0.9.0-tag-sha"), URL: github.Ptr("https://api.github.com/repos/owner/repo/commits/def456"), }, ZipballURL: github.Ptr("https://github.com/owner/repo/zipball/v0.9.0"), @@ -1573,9 +1573,14 @@ func Test_ListTags(t *testing.T) { { name: "successful tags list", mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatch( + mock.WithRequestMatchHandler( mock.GetReposTagsByOwnerByRepo, - mockTags, + expectPath( + t, + "/repos/owner/repo/tags", + ).andThen( + mockResponse(t, http.StatusOK, mockTags), + ), ), ), requestArgs: map[string]interface{}{ @@ -1659,12 +1664,12 @@ func Test_GetTag(t *testing.T) { mockTagRef := &github.Reference{ Ref: github.Ptr("refs/tags/v1.0.0"), Object: &github.GitObject{ - SHA: github.Ptr("tag123"), + SHA: github.Ptr("v1.0.0-tag-sha"), }, } mockTagObj := &github.Tag{ - SHA: github.Ptr("tag123"), + SHA: github.Ptr("v1.0.0-tag-sha"), Tag: github.Ptr("v1.0.0"), Message: github.Ptr("Release v1.0.0"), Object: &github.GitObject{ @@ -1684,13 +1689,23 @@ func Test_GetTag(t *testing.T) { { name: "successful tag retrieval", mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatch( + mock.WithRequestMatchHandler( mock.GetReposGitRefByOwnerByRepoByRef, - mockTagRef, + expectPath( + t, + "/repos/owner/repo/git/ref/tags/v1.0.0", + ).andThen( + mockResponse(t, http.StatusOK, mockTagRef), + ), ), - mock.WithRequestMatch( + mock.WithRequestMatchHandler( mock.GetReposGitTagsByOwnerByRepoByTagSha, - mockTagObj, + expectPath( + t, + "/repos/owner/repo/git/tags/v1.0.0-tag-sha", + ).andThen( + mockResponse(t, http.StatusOK, mockTagObj), + ), ), ), requestArgs: map[string]interface{}{ From 3a8eeb693e4a696275e4c62d2607f7a757416716 Mon Sep 17 00:00:00 2001 From: William Martin Date: Wed, 30 Apr 2025 13:17:38 +0200 Subject: [PATCH 3/4] Add e2e test for tags --- e2e/README.md | 4 +- e2e/e2e_test.go | 136 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 139 insertions(+), 1 deletion(-) diff --git a/e2e/README.md b/e2e/README.md index 21b65bfa..bb93b32c 100644 --- a/e2e/README.md +++ b/e2e/README.md @@ -81,4 +81,6 @@ FAIL The current test suite is intentionally very limited in scope. This is because the maintenance costs on e2e tests tend to increase significantly over time. To read about some challenges with GitHub integration tests, see [go-github integration tests README](https://github.com/google/go-github/blob/5b75aa86dba5cf4af2923afa0938774f37fa0a67/test/README.md). We will expand this suite circumspectly! -Currently, visibility into failures is not particularly good. +The tests are quite repetitive and verbose. This is intentional as we want to see them develop more before committing to abstractions. + +Currently, visibility into failures is not particularly good. We're hoping that we can pull apart the mcp-go client and have it hook into streams representing stdio without requiring an exec. This way we can get breakpoints in the debugger easily. diff --git a/e2e/e2e_test.go b/e2e/e2e_test.go index 757dd5c2..5da6379c 100644 --- a/e2e/e2e_test.go +++ b/e2e/e2e_test.go @@ -206,3 +206,139 @@ func TestToolsets(t *testing.T) { require.True(t, toolsContains("list_branches"), "expected to find 'list_branches' tool") require.False(t, toolsContains("get_pull_request"), "expected not to find 'get_pull_request' tool") } + +func TestTags(t *testing.T) { + mcpClient := setupMCPClient(t) + + ctx := context.Background() + + // First, who am I + getMeRequest := mcp.CallToolRequest{} + getMeRequest.Params.Name = "get_me" + + t.Log("Getting current user...") + resp, err := mcpClient.CallTool(ctx, getMeRequest) + require.NoError(t, err, "expected to call 'get_me' tool successfully") + require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp)) + + require.False(t, resp.IsError, "expected result not to be an error") + require.Len(t, resp.Content, 1, "expected content to have one item") + + textContent, ok := resp.Content[0].(mcp.TextContent) + require.True(t, ok, "expected content to be of type TextContent") + + var trimmedGetMeText struct { + Login string `json:"login"` + } + err = json.Unmarshal([]byte(textContent.Text), &trimmedGetMeText) + require.NoError(t, err, "expected to unmarshal text content successfully") + + currentOwner := trimmedGetMeText.Login + + // Then create a repository with a README (via autoInit) + repoName := fmt.Sprintf("github-mcp-server-e2e-%s-%d", t.Name(), time.Now().UnixMilli()) + createRepoRequest := mcp.CallToolRequest{} + createRepoRequest.Params.Name = "create_repository" + createRepoRequest.Params.Arguments = map[string]any{ + "name": repoName, + "private": true, + "autoInit": true, + } + + t.Logf("Creating repository %s/%s...", currentOwner, repoName) + _, err = mcpClient.CallTool(ctx, createRepoRequest) + require.NoError(t, err, "expected to call 'get_me' tool successfully") + require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp)) + + // Cleanup the repository after the test + t.Cleanup(func() { + // MCP Server doesn't support deletions, but we can use the GitHub Client + ghClient := github.NewClient(nil).WithAuthToken(getE2EToken(t)) + t.Logf("Deleting repository %s/%s...", currentOwner, repoName) + _, err := ghClient.Repositories.Delete(context.Background(), currentOwner, repoName) + require.NoError(t, err, "expected to delete repository successfully") + }) + + // Then create a tag + // MCP Server doesn't support tag creation, but we can use the GitHub Client + ghClient := github.NewClient(nil).WithAuthToken(getE2EToken(t)) + t.Logf("Creating tag %s/%s:%s...", currentOwner, repoName, "v0.0.1") + ref, _, err := ghClient.Git.GetRef(context.Background(), currentOwner, repoName, "refs/heads/main") + require.NoError(t, err, "expected to get ref successfully") + + tagObj, _, err := ghClient.Git.CreateTag(context.Background(), currentOwner, repoName, &github.Tag{ + Tag: github.Ptr("v0.0.1"), + Message: github.Ptr("v0.0.1"), + Object: &github.GitObject{ + SHA: ref.Object.SHA, + Type: github.Ptr("commit"), + }, + }) + require.NoError(t, err, "expected to create tag object successfully") + + _, _, err = ghClient.Git.CreateRef(context.Background(), currentOwner, repoName, &github.Reference{ + Ref: github.Ptr("refs/tags/v0.0.1"), + Object: &github.GitObject{ + SHA: tagObj.SHA, + }, + }) + require.NoError(t, err, "expected to create tag ref successfully") + + // List the tags + listTagsRequest := mcp.CallToolRequest{} + listTagsRequest.Params.Name = "list_tags" + listTagsRequest.Params.Arguments = map[string]any{ + "owner": currentOwner, + "repo": repoName, + } + + t.Logf("Listing tags for %s/%s...", currentOwner, repoName) + resp, err = mcpClient.CallTool(ctx, listTagsRequest) + require.NoError(t, err, "expected to call 'list_tags' tool successfully") + require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp)) + + require.False(t, resp.IsError, "expected result not to be an error") + require.Len(t, resp.Content, 1, "expected content to have one item") + + textContent, ok = resp.Content[0].(mcp.TextContent) + require.True(t, ok, "expected content to be of type TextContent") + + var trimmedTags []struct { + Name string `json:"name"` + Commit struct { + SHA string `json:"sha"` + } `json:"commit"` + } + err = json.Unmarshal([]byte(textContent.Text), &trimmedTags) + require.NoError(t, err, "expected to unmarshal text content successfully") + + require.Len(t, trimmedTags, 1, "expected to find one tag") + require.Equal(t, "v0.0.1", trimmedTags[0].Name, "expected tag name to match") + require.Equal(t, *ref.Object.SHA, trimmedTags[0].Commit.SHA, "expected tag SHA to match") + + // And fetch an individual tag + getTagRequest := mcp.CallToolRequest{} + getTagRequest.Params.Name = "get_tag" + getTagRequest.Params.Arguments = map[string]any{ + "owner": currentOwner, + "repo": repoName, + "tag": "v0.0.1", + } + + t.Logf("Getting tag %s/%s:%s...", currentOwner, repoName, "v0.0.1") + resp, err = mcpClient.CallTool(ctx, getTagRequest) + require.NoError(t, err, "expected to call 'get_tag' tool successfully") + require.False(t, resp.IsError, "expected result not to be an error") + + var trimmedTag []struct { // don't understand why this is an array + Name string `json:"name"` + Commit struct { + SHA string `json:"sha"` + } `json:"commit"` + } + err = json.Unmarshal([]byte(textContent.Text), &trimmedTag) + require.NoError(t, err, "expected to unmarshal text content successfully") + require.Len(t, trimmedTag, 1, "expected to find one tag") + require.Equal(t, "v0.0.1", trimmedTag[0].Name, "expected tag name to match") + require.Equal(t, *ref.Object.SHA, trimmedTag[0].Commit.SHA, "expected tag SHA to match") +} From 1a23af09619ae3701fe1c9efd9105615087584f7 Mon Sep 17 00:00:00 2001 From: William Martin Date: Wed, 30 Apr 2025 17:01:02 +0200 Subject: [PATCH 4/4] WIP: in process e2e --- e2e/e2e_test.go | 339 ++++++++++++++++++++++++++++++++++++++------ pkg/github/stdio.go | 146 +++++++++++++++++++ 2 files changed, 442 insertions(+), 43 deletions(-) create mode 100644 pkg/github/stdio.go diff --git a/e2e/e2e_test.go b/e2e/e2e_test.go index 5da6379c..04abd347 100644 --- a/e2e/e2e_test.go +++ b/e2e/e2e_test.go @@ -3,9 +3,12 @@ package e2e_test import ( + "bufio" + "bytes" "context" "encoding/json" "fmt" + "io" "os" "os/exec" "slices" @@ -13,9 +16,12 @@ import ( "testing" "time" - "github.com/google/go-github/v69/github" + "github.com/github/github-mcp-server/pkg/github" + gogithub "github.com/google/go-github/v69/github" mcpClient "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" "github.com/mark3labs/mcp-go/mcp" + "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" ) @@ -77,47 +83,296 @@ func WithEnvVars(envVars map[string]string) ClientOption { func setupMCPClient(t *testing.T, options ...ClientOption) *mcpClient.Client { // Get token and ensure Docker image is built token := getE2EToken(t) - ensureDockerImageBuilt(t) - // Create and configure options - opts := &ClientOpts{ - EnvVars: make(map[string]string), + // By default, we run the tests including the Docker image, but with DEBUG + // enabled, we run the server in-process, allowing for easier debugging. + var client *mcpClient.Client + if os.Getenv("GITHUB_MCP_SERVER_E2E_DEBUG") == "" { + ensureDockerImageBuilt(t) + + // Create and configure options + opts := &ClientOpts{ + EnvVars: make(map[string]string), + } + + // Apply all options to configure the opts struct + for _, option := range options { + option(opts) + } + + // Prepare Docker arguments + args := []string{ + "docker", + "run", + "-i", + "--rm", + "-e", + "GITHUB_PERSONAL_ACCESS_TOKEN", // Personal access token is all required + } + + // Add all environment variables to the Docker arguments + for key := range opts.EnvVars { + args = append(args, "-e", key) + } + + // Add the image name + args = append(args, "github/e2e-github-mcp-server") + + // Construct the env vars for the MCP Client to execute docker with + dockerEnvVars := make([]string, 0, len(opts.EnvVars)+1) + dockerEnvVars = append(dockerEnvVars, fmt.Sprintf("GITHUB_PERSONAL_ACCESS_TOKEN=%s", token)) + for key, value := range opts.EnvVars { + dockerEnvVars = append(dockerEnvVars, fmt.Sprintf("%s=%s", key, value)) + } + + // Create the client + t.Log("Starting Stdio MCP client...") + var err error + client, err = mcpClient.NewStdioMCPClient(args[0], dockerEnvVars, args[1:]...) + require.NoError(t, err, "expected to create client successfully") + } else { + // Pipe setup: clientToServer simulates stdin/stdout + clientToServerR, clientToServerW := io.Pipe() + serverToClientR, serverToClientW := io.Pipe() + stderrBuf := &bytes.Buffer{} + + go func() { + require.NoError(t, + github.RunStdioServer(github.RunConfig{ + Stdin: clientToServerR, + Stdout: serverToClientW, + // Version: "", + Token: token, + Logger: &logrus.Logger{}, + LogCommands: false, + ReadOnly: false, + ExportTranslations: false, + EnabledToolsets: []string{"all"}, + }), + "expected to start server successfully", + ) + }() + + transport := NewInProcessStdioTransport(serverToClientR, clientToServerW, stderrBuf) + require.NoError(t, transport.Start(context.Background()), "expected to start client successfully") + + client = mcpClient.NewClient(transport) } - // Apply all options to configure the opts struct - for _, option := range options { - option(opts) + t.Cleanup(func() { + require.NoError(t, client.Close(), "expected to close client successfully") + }) + + // Initialize the client + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + request := mcp.InitializeRequest{} + request.Params.ProtocolVersion = "2025-03-26" + request.Params.ClientInfo = mcp.Implementation{ + Name: "e2e-test-client", + Version: "0.0.1", } - // Prepare Docker arguments - args := []string{ - "docker", - "run", - "-i", - "--rm", - "-e", - "GITHUB_PERSONAL_ACCESS_TOKEN", // Personal access token is all required + result, err := client.Initialize(ctx, request) + require.NoError(t, err, "failed to initialize client") + require.Equal(t, "github-mcp-server", result.ServerInfo.Name, "unexpected server name") + + return client +} + +// Taken from mcp-go client and adjusted to take stdio rather than execing a command. +type InProcessStdioTransport struct { + stdin io.Writer + stdout *bufio.Reader + stderr io.Writer + + responses map[int64]chan *transport.JSONRPCResponse + mu sync.RWMutex + done chan struct{} + onNotification func(mcp.JSONRPCNotification) + notifyMu sync.RWMutex + + startOnce sync.Once + closeOnce sync.Once +} + +var _ transport.Interface = (*InProcessStdioTransport)(nil) + +func NewInProcessStdioTransport(r io.Reader, w io.Writer, stderr io.Writer) *InProcessStdioTransport { + return &InProcessStdioTransport{ + stdin: w, + stdout: bufio.NewReader(r), + stderr: stderr, + responses: make(map[int64]chan *transport.JSONRPCResponse), + done: make(chan struct{}), + } +} + +func (c *InProcessStdioTransport) Start(ctx context.Context) error { + c.startOnce.Do(func() { + go c.readResponses() + }) + return nil +} + +func (c *InProcessStdioTransport) Close() error { + c.closeOnce.Do(func() { + close(c.done) + }) + return nil +} + +func (c *InProcessStdioTransport) SendRequest( + ctx context.Context, + request transport.JSONRPCRequest, +) (*transport.JSONRPCResponse, error) { + if c.stdin == nil { + return nil, fmt.Errorf("in-process stdio not started") + } + + responseChan := make(chan *transport.JSONRPCResponse, 1) + + c.mu.Lock() + c.responses[request.ID] = responseChan + c.mu.Unlock() + + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + requestBytes = append(requestBytes, '\n') + + if _, err := c.stdin.Write(requestBytes); err != nil { + return nil, fmt.Errorf("failed to write request: %w", err) } - // Add all environment variables to the Docker arguments - for key := range opts.EnvVars { - args = append(args, "-e", key) + select { + case <-ctx.Done(): + c.mu.Lock() + delete(c.responses, request.ID) + c.mu.Unlock() + return nil, ctx.Err() + case response := <-responseChan: + return response, nil } +} - // Add the image name - args = append(args, "github/e2e-github-mcp-server") +func (c *InProcessStdioTransport) SendNotification( + ctx context.Context, + notification mcp.JSONRPCNotification, +) error { + notificationBytes, err := json.Marshal(notification) + if err != nil { + return fmt.Errorf("failed to marshal notification: %w", err) + } + notificationBytes = append(notificationBytes, '\n') - // Construct the env vars for the MCP Client to execute docker with - dockerEnvVars := make([]string, 0, len(opts.EnvVars)+1) - dockerEnvVars = append(dockerEnvVars, fmt.Sprintf("GITHUB_PERSONAL_ACCESS_TOKEN=%s", token)) - for key, value := range opts.EnvVars { - dockerEnvVars = append(dockerEnvVars, fmt.Sprintf("%s=%s", key, value)) + if _, err := c.stdin.Write(notificationBytes); err != nil { + return fmt.Errorf("failed to write notification: %w", err) } + return nil +} + +func (c *InProcessStdioTransport) SetNotificationHandler(handler func(notification mcp.JSONRPCNotification)) { + c.notifyMu.Lock() + defer c.notifyMu.Unlock() + c.onNotification = handler +} + +func (c *InProcessStdioTransport) readResponses() { + for { + select { + case <-c.done: + return + default: + line, err := c.stdout.ReadString('\n') + if err != nil { + if err != io.EOF { + fmt.Fprintf(c.stderr, "error reading response: %v\n", err) + } + return + } + + var baseMessage transport.JSONRPCResponse + if err := json.Unmarshal([]byte(line), &baseMessage); err != nil { + continue + } + + if baseMessage.ID == nil { + var notification mcp.JSONRPCNotification + if err := json.Unmarshal([]byte(line), ¬ification); err != nil { + continue + } + c.notifyMu.RLock() + if c.onNotification != nil { + c.onNotification(notification) + } + c.notifyMu.RUnlock() + continue + } + + c.mu.RLock() + ch, ok := c.responses[*baseMessage.ID] + c.mu.RUnlock() + + if ok { + ch <- &baseMessage + c.mu.Lock() + delete(c.responses, *baseMessage.ID) + c.mu.Unlock() + } + } + } +} + +func (c *InProcessStdioTransport) Stderr() io.Writer { + return c.stderr +} + +func TestInProcessMCPClient(t *testing.T) { + // Pipe setup: clientToServer simulates stdin/stdout + clientToServerR, clientToServerW := io.Pipe() + serverToClientR, serverToClientW := io.Pipe() + stderrBuf := &bytes.Buffer{} + + transport := NewInProcessStdioTransport(serverToClientR, clientToServerW, stderrBuf) + + // Simulated in-memory server + go func() { + decoder := json.NewDecoder(clientToServerR) + encoder := json.NewEncoder(serverToClientW) + + for { + var req mcp.JSONRPCRequest + if err := decoder.Decode(&req); err != nil { + return + } + + // Log request to stderr + stderrBuf.WriteString("received method: " + req.Method + "\n") + + // Respond to initialization + resp := mcp.JSONRPCResponse{ + JSONRPC: "2.0", + ID: req.ID, + Result: mcp.InitializeResult{ + ProtocolVersion: "2025-03-26", + Capabilities: mcp.ServerCapabilities{}, + ServerInfo: mcp.Implementation{ + Name: "in-process-server", + Version: "v0.0.1", + }, + }, + } + _ = encoder.Encode(resp) + } + }() // Create the client - t.Log("Starting Stdio MCP client...") - client, err := mcpClient.NewStdioMCPClient(args[0], dockerEnvVars, args[1:]...) - require.NoError(t, err, "expected to create client successfully") + + require.NoError(t, transport.Start(context.Background()), "expected to start client successfully") + client := mcpClient.NewClient(transport) t.Cleanup(func() { require.NoError(t, client.Close(), "expected to close client successfully") }) @@ -135,9 +390,7 @@ func setupMCPClient(t *testing.T, options ...ClientOption) *mcpClient.Client { result, err := client.Initialize(ctx, request) require.NoError(t, err, "failed to initialize client") - require.Equal(t, "github-mcp-server", result.ServerInfo.Name, "unexpected server name") - - return client + require.Equal(t, "in-process-server", result.ServerInfo.Name, "unexpected server name") } func TestGetMe(t *testing.T) { @@ -169,7 +422,7 @@ func TestGetMe(t *testing.T) { // Then the login in the response should match the login obtained via the same // token using the GitHub API. - ghClient := github.NewClient(nil).WithAuthToken(getE2EToken(t)) + ghClient := gogithub.NewClient(nil).WithAuthToken(getE2EToken(t)) user, _, err := ghClient.Users.Get(context.Background(), "") require.NoError(t, err, "expected to get user successfully") require.Equal(t, trimmedContent.Login, *user.Login, "expected login to match") @@ -253,7 +506,7 @@ func TestTags(t *testing.T) { // Cleanup the repository after the test t.Cleanup(func() { // MCP Server doesn't support deletions, but we can use the GitHub Client - ghClient := github.NewClient(nil).WithAuthToken(getE2EToken(t)) + ghClient := gogithub.NewClient(nil).WithAuthToken(getE2EToken(t)) t.Logf("Deleting repository %s/%s...", currentOwner, repoName) _, err := ghClient.Repositories.Delete(context.Background(), currentOwner, repoName) require.NoError(t, err, "expected to delete repository successfully") @@ -261,24 +514,24 @@ func TestTags(t *testing.T) { // Then create a tag // MCP Server doesn't support tag creation, but we can use the GitHub Client - ghClient := github.NewClient(nil).WithAuthToken(getE2EToken(t)) + ghClient := gogithub.NewClient(nil).WithAuthToken(getE2EToken(t)) t.Logf("Creating tag %s/%s:%s...", currentOwner, repoName, "v0.0.1") ref, _, err := ghClient.Git.GetRef(context.Background(), currentOwner, repoName, "refs/heads/main") require.NoError(t, err, "expected to get ref successfully") - tagObj, _, err := ghClient.Git.CreateTag(context.Background(), currentOwner, repoName, &github.Tag{ - Tag: github.Ptr("v0.0.1"), - Message: github.Ptr("v0.0.1"), - Object: &github.GitObject{ + tagObj, _, err := ghClient.Git.CreateTag(context.Background(), currentOwner, repoName, &gogithub.Tag{ + Tag: gogithub.Ptr("v0.0.1"), + Message: gogithub.Ptr("v0.0.1"), + Object: &gogithub.GitObject{ SHA: ref.Object.SHA, - Type: github.Ptr("commit"), + Type: gogithub.Ptr("commit"), }, }) require.NoError(t, err, "expected to create tag object successfully") - _, _, err = ghClient.Git.CreateRef(context.Background(), currentOwner, repoName, &github.Reference{ - Ref: github.Ptr("refs/tags/v0.0.1"), - Object: &github.GitObject{ + _, _, err = ghClient.Git.CreateRef(context.Background(), currentOwner, repoName, &gogithub.Reference{ + Ref: gogithub.Ptr("refs/tags/v0.0.1"), + Object: &gogithub.GitObject{ SHA: tagObj.SHA, }, }) diff --git a/pkg/github/stdio.go b/pkg/github/stdio.go new file mode 100644 index 00000000..6f97509f --- /dev/null +++ b/pkg/github/stdio.go @@ -0,0 +1,146 @@ +package github + +import ( + "context" + "fmt" + "io" + "os" + "os/signal" + "syscall" + + stdlog "log" + + iolog "github.com/github/github-mcp-server/pkg/log" + "github.com/github/github-mcp-server/pkg/translations" + gogithub "github.com/google/go-github/v69/github" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + log "github.com/sirupsen/logrus" + "github.com/spf13/viper" +) + +type RunConfig struct { + Stdin io.Reader + Stdout io.Writer + + Version string + + Token string + + Logger *log.Logger + LogCommands bool + + ReadOnly bool + ExportTranslations bool + EnabledToolsets []string +} + +func RunStdioServer(cfg RunConfig) error { + // Create app context + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + + // Create GH client + ghClient := gogithub.NewClient(nil).WithAuthToken(cfg.Token) + ghClient.UserAgent = fmt.Sprintf("github-mcp-server/%s", cfg.Version) + + host := viper.GetString("host") + + if host != "" { + var err error + ghClient, err = ghClient.WithEnterpriseURLs(host, host) + if err != nil { + return fmt.Errorf("failed to create GitHub client with host: %w", err) + } + } + + t, dumpTranslations := translations.TranslationHelper() + + beforeInit := func(_ context.Context, _ any, message *mcp.InitializeRequest) { + ghClient.UserAgent = fmt.Sprintf( + "github-mcp-server/%s (%s/%s)", + cfg.Version, + message.Params.ClientInfo.Name, + message.Params.ClientInfo.Version, + ) + } + + getClient := func(_ context.Context) (*gogithub.Client, error) { + return ghClient, nil // closing over client + } + + hooks := &server.Hooks{ + OnBeforeInitialize: []server.OnBeforeInitializeFunc{beforeInit}, + } + // Create server + ghServer := NewServer(cfg.Version, server.WithHooks(hooks)) + + enabled := cfg.EnabledToolsets + // TODO: tear this out + dynamic := viper.GetBool("dynamic_toolsets") + if dynamic { + // filter "all" from the enabled toolsets + enabled = make([]string, 0, len(cfg.EnabledToolsets)) + for _, toolset := range cfg.EnabledToolsets { + if toolset != "all" { + enabled = append(enabled, toolset) + } + } + } + + // Create default toolsets + toolsets, err := InitToolsets(enabled, cfg.ReadOnly, getClient, t) + if err != nil { + cfg.Logger.Fatal("Failed to initialize toolsets:", err) + } + + context := InitContextToolset(getClient, t) + + // Register resources with the server + RegisterResources(ghServer, getClient, t) + // Register the tools with the server + toolsets.RegisterTools(ghServer) + context.RegisterTools(ghServer) + + if dynamic { + dynamic := InitDynamicToolset(ghServer, toolsets, t) + dynamic.RegisterTools(ghServer) + } + + stdioServer := server.NewStdioServer(ghServer) + + stdLogger := stdlog.New(cfg.Logger.Writer(), "stdioserver", 0) + stdioServer.SetErrorLogger(stdLogger) + + if cfg.ExportTranslations { + // Once server is initialized, all translations are loaded + dumpTranslations() + } + + // Start listening for messages + errC := make(chan error, 1) + go func() { + in, out := cfg.Stdin, cfg.Stdout + if cfg.LogCommands { + loggedIO := iolog.NewIOLogger(in, out, cfg.Logger) + in, out = loggedIO, loggedIO + } + + errC <- stdioServer.Listen(ctx, in, out) + }() + + // Output github-mcp-server string + _, _ = fmt.Fprintf(os.Stderr, "GitHub MCP Server running on stdio\n") + + // Wait for shutdown signal + select { + case <-ctx.Done(): + cfg.Logger.Infof("shutting down server...") + case err := <-errC: + if err != nil { + return fmt.Errorf("error running server: %w", err) + } + } + + return nil +}