diff --git a/README.md b/README.md index 7b9e20fc..f4ac4502 100644 --- a/README.md +++ b/README.md @@ -290,6 +290,16 @@ export GITHUB_MCP_TOOL_ADD_ISSUE_COMMENT_DESCRIPTION="an alternative description - **get_me** - Get details of the authenticated user - No parameters required +- **list_users_public_ssh_keys** - "Lists the public SSH keys for the authenticated user's GitHub account + - No parameters required + +- **get_users_public_ssh_key** - View extended details for a single public SSH key + - `key_id`: Key Id (number, required) + +- **add_users_public_ssh_key** - Adds a public SSH key to the authenticated user's GitHub account + - `title`: Title of the key (string, optional) + - `key`: Public key contents (string, required) + ### Issues - **get_issue** - Gets the contents of an issue within a repository diff --git a/pkg/github/tools.go b/pkg/github/tools.go index ab052817..7448c7b5 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -56,6 +56,9 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn, users := toolsets.NewToolset("users", "GitHub User related tools"). AddReadTools( toolsets.NewServerTool(SearchUsers(getClient, t)), + toolsets.NewServerTool(ListUsersPublicSSHKeys(getClient, t)), + toolsets.NewServerTool(GetUsersPublicSSHKey(getClient, t)), + toolsets.NewServerTool(AddUsersPublicSSHKey(getClient, t)), ) pullRequests := toolsets.NewToolset("pull_requests", "GitHub Pull Request related tools"). AddReadTools( diff --git a/pkg/github/user.go b/pkg/github/user.go new file mode 100644 index 00000000..c2a13a9c --- /dev/null +++ b/pkg/github/user.go @@ -0,0 +1,163 @@ +package github + +import ( + "context" + "encoding/json" + "fmt" + "io" + + "github.com/github/github-mcp-server/pkg/translations" + "github.com/google/go-github/v72/github" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// ListUsersPublicSSHKeys creates a tool to list public ssh keys for user +func ListUsersPublicSSHKeys(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("list_users_public_ssh_keys", + mcp.WithDescription(t("TOOL_LIST_USERS_PUBLIC_SSH_KEYS", "Lists the public SSH keys for the authenticated user's GitHub account")), + mcp.WithToolAnnotation(mcp.ToolAnnotation{ + Title: t("TOOL_LIST_USERS_PUBLIC_SSH_KEYS_USER_TITLE", "List users public ssh keys"), + ReadOnlyHint: toBoolPtr(true), + }), + WithPagination(), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + pagination, err := OptionalPaginationParams(request) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + opts := &github.ListOptions{ + Page: pagination.page, + PerPage: pagination.perPage, + } + + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + result, resp, err := client.Users.ListKeys(ctx, "", opts) + if err != nil { + return nil, fmt.Errorf("failed to list users ssh keys: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != 200 { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to list users ssh keys: %s", string(body))), nil + } + + r, err := json.Marshal(result) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return mcp.NewToolResultText(string(r)), nil + } +} + +// GetUsersPublicSSHKey creates a tool to get public ssh key for user +func GetUsersPublicSSHKey(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("get_users_public_ssh_key", + mcp.WithDescription(t("TOOL_GET_USERS_PUBLIC_SSH_KEY", "View extended details for a single public SSH key")), + mcp.WithToolAnnotation(mcp.ToolAnnotation{ + Title: t("TOOL_GET_USERS_PUBLIC_SSH_KEY_USER_TITLE", "Get public ssh key details"), + ReadOnlyHint: toBoolPtr(true), + }), + mcp.WithNumber("key_id", + mcp.Required(), + mcp.Description("The unique identifier of the key"), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + keyId, err := RequiredInt(request, "key_id") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + result, resp, err := client.Users.GetKey(ctx, int64(keyId)) + if err != nil { + return nil, fmt.Errorf("failed to get ssh key details: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != 200 { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to get ssh key details: %s", string(body))), nil + } + + r, err := json.Marshal(result) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return mcp.NewToolResultText(string(r)), nil + } +} + +// AddPublicSSHKey Adds a public SSH key to the authenticated user's GitHub account +func AddUsersPublicSSHKey(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("add_users_public_ssh_key", + mcp.WithDescription(t("TOOL_ADD_USERS_PUBLIC_SSH_KEY", "Adds a public SSH key to the authenticated user's GitHub account")), + mcp.WithToolAnnotation(mcp.ToolAnnotation{ + Title: t("TOOL_ADD_USERS_PUBLIC_SSH_KEY_USER_TITLE", "Add users public ssh key"), + ReadOnlyHint: toBoolPtr(true), + }), + mcp.WithString("title", + mcp.Description("A descriptive name for the new key"), + ), + mcp.WithString("key", + mcp.Required(), + mcp.Description("The public SSH key to add to your GitHub account"), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + title, err := OptionalParam[string](request, "title") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + key, err := requiredParam[string](request, "key") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + githubKey := &github.Key{ + Title: &title, + Key: &key, + } + result, resp, err := client.Users.CreateKey(ctx, githubKey) + if err != nil { + return nil, fmt.Errorf("failed to add ssh key: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != 201 { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to add ssh key: %s", string(body))), nil + } + + r, err := json.Marshal(result) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return mcp.NewToolResultText(string(r)), nil + } +} diff --git a/pkg/github/user_test.go b/pkg/github/user_test.go new file mode 100644 index 00000000..ecc1cdb0 --- /dev/null +++ b/pkg/github/user_test.go @@ -0,0 +1,343 @@ +package github + +import ( + "context" + "encoding/json" + "net/http" + "testing" + + "github.com/github/github-mcp-server/pkg/translations" + "github.com/google/go-github/v72/github" + "github.com/migueleliasweb/go-github-mock/src/mock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_ListUsersPublicSSHKeys(t *testing.T) { + mockClient := github.NewClient(nil) + tool, _ := ListUsersPublicSSHKeys(stubGetClientFn(mockClient), translations.NullTranslationHelper) + assert.Equal(t, "list_users_public_ssh_keys", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "page") + assert.Contains(t, tool.InputSchema.Properties, "perPage") + + // Setup mock results + mockListSSHKeyResult := []*github.Key{ + { + ID: github.Ptr(int64(1)), + Key: github.Ptr("ssh test key"), + URL: github.Ptr("test url"), + Title: github.Ptr("test key 1"), + ReadOnly: github.Ptr(true), + Verified: github.Ptr(true), + }, + { + ID: github.Ptr(int64(2)), + Key: github.Ptr("ssh test key"), + URL: github.Ptr("test url"), + Title: github.Ptr("test key 2"), + ReadOnly: github.Ptr(true), + Verified: github.Ptr(true), + }, + } + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]any + expectError bool + expectedResult []*github.Key + expectedErrMsg string + }{ + { + name: "list public ssh keys", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetUserKeys, + expectQueryParams(t, map[string]string{ + "page": "2", + "per_page": "10", + }).andThen( + mockResponse(t, http.StatusOK, mockListSSHKeyResult), + ), + ), + ), + requestArgs: map[string]any{ + "page": float64(2), + "perPage": float64(10), + }, + expectError: false, + expectedResult: mockListSSHKeyResult, + }, + { + name: "list public ssh keys with default pagination", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetUserKeys, + expectQueryParams(t, map[string]string{ + "page": "1", + "per_page": "30", + }).andThen( + mockResponse(t, http.StatusOK, mockListSSHKeyResult), + ), + ), + ), + expectError: false, + expectedResult: mockListSSHKeyResult, + }, + { + name: "list ssh key fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetUserKeys, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"message": "bad permission"}`)) + }), + ), + ), + expectError: true, + expectedErrMsg: "failed to list users ssh keys", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := ListUsersPublicSSHKeys(stubGetClientFn(client), translations.NullTranslationHelper) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedResult []*github.Key + err = json.Unmarshal([]byte(textContent.Text), &returnedResult) + require.NoError(t, err) + assert.Equal(t, len(tc.expectedResult), len(returnedResult)) + for i, keyData := range returnedResult { + assert.Equal(t, tc.expectedResult[i].ID, keyData.ID) + assert.Equal(t, tc.expectedResult[i].Title, keyData.Title) + assert.Equal(t, tc.expectedResult[i].URL, keyData.URL) + assert.Equal(t, tc.expectedResult[i].Key, keyData.Key) + assert.Equal(t, tc.expectedResult[i].Verified, keyData.Verified) + assert.Equal(t, tc.expectedResult[i].ReadOnly, keyData.ReadOnly) + } + }) + } +} + +func Test_GetUsersPublicSSHKey(t *testing.T) { + mockClient := github.NewClient(nil) + tool, _ := GetUsersPublicSSHKey(stubGetClientFn(mockClient), translations.NullTranslationHelper) + assert.Equal(t, "get_users_public_ssh_key", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "key_id") + assert.NotContains(t, tool.InputSchema.Properties, "page") + assert.NotContains(t, tool.InputSchema.Properties, "perPage") + + // Setup mock results + mockGetSSHKeyResult := &github.Key{ + ID: github.Ptr(int64(1)), + Key: github.Ptr("ssh test key"), + URL: github.Ptr("test url"), + Title: github.Ptr("test key 1"), + ReadOnly: github.Ptr(true), + Verified: github.Ptr(true), + } + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]any + expectError bool + expectedResult *github.Key + expectedErrMsg string + }{ + { + name: "get public ssh key", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetUserKeysByKeyId, + expectPath(t, "/user/keys/1"). + andThen( + mockResponse(t, http.StatusOK, mockGetSSHKeyResult), + ), + ), + ), + requestArgs: map[string]any{ + "key_id": float64(1), + }, + expectError: false, + expectedResult: mockGetSSHKeyResult, + }, + { + name: "get public ssh key with bad wrong key", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetUserKeysByKeyId, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "key not found"}`)) + }), + ), + ), + requestArgs: map[string]any{ + "key_id": float64(2), + }, + expectError: true, + expectedErrMsg: "failed to get ssh key details", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := GetUsersPublicSSHKey(stubGetClientFn(client), translations.NullTranslationHelper) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + // Unmarshal and verify the result + var returnedResult *github.Key + err = json.Unmarshal([]byte(textContent.Text), &returnedResult) + require.NoError(t, err) + assert.Equal(t, tc.expectedResult.ID, returnedResult.ID) + assert.Equal(t, tc.expectedResult.Key, returnedResult.Key) + assert.Equal(t, tc.expectedResult.Title, returnedResult.Title) + assert.Equal(t, tc.expectedResult.URL, returnedResult.URL) + assert.Equal(t, tc.expectedResult.Verified, returnedResult.Verified) + assert.Equal(t, tc.expectedResult.ReadOnly, returnedResult.ReadOnly) + }) + } +} + +func Test_AddUsersPublicSSHKey(t *testing.T) { + mockClient := github.NewClient(nil) + tool, _ := AddUsersPublicSSHKey(stubGetClientFn(mockClient), translations.NullTranslationHelper) + assert.Equal(t, "add_users_public_ssh_key", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "title") + assert.Contains(t, tool.InputSchema.Properties, "key") + assert.NotContains(t, tool.InputSchema.Properties, "page") + assert.NotContains(t, tool.InputSchema.Properties, "perPage") + + // Setup mock results + mockAddKeyResult := &github.Key{ + ID: github.Ptr(int64(1)), + Key: github.Ptr("ssh test key"), + URL: github.Ptr("test url"), + Title: github.Ptr("test key 1"), + ReadOnly: github.Ptr(true), + Verified: github.Ptr(true), + } + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]any + expectError bool + expectedResult *github.Key + expectedErrMsg string + }{ + { + name: "add public ssh key", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PostUserKeys, + expectRequestBody(t, map[string]any{ + "title": "test key 1", + "key": "ssh test key", + }). + andThen( + mockResponse(t, http.StatusCreated, mockAddKeyResult), + ), + ), + ), + requestArgs: map[string]any{ + "title": "test key 1", + "key": "ssh test key", + }, + expectError: false, + expectedResult: mockAddKeyResult, + }, + { + name: "add public ssh key fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PostUserKeys, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`{"message": "something bad happened"}`)) + }), + ), + ), + requestArgs: map[string]any{ + "title": "test key 1", + "key": "ssh test key", + }, + expectError: true, + expectedErrMsg: "failed to add ssh key", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := AddUsersPublicSSHKey(stubGetClientFn(client), translations.NullTranslationHelper) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + // Unmarshal and verify the result + var returnedResult *github.Key + err = json.Unmarshal([]byte(textContent.Text), &returnedResult) + require.NoError(t, err) + assert.Equal(t, tc.expectedResult.ID, returnedResult.ID) + assert.Equal(t, tc.expectedResult.Key, returnedResult.Key) + assert.Equal(t, tc.expectedResult.Title, returnedResult.Title) + assert.Equal(t, tc.expectedResult.URL, returnedResult.URL) + assert.Equal(t, tc.expectedResult.Verified, returnedResult.Verified) + assert.Equal(t, tc.expectedResult.ReadOnly, returnedResult.ReadOnly) + }) + } +}