From 06246a70340b7365cdfd6c87e9045f5d541a1f11 Mon Sep 17 00:00:00 2001 From: Jake Shorty Date: Thu, 10 Apr 2025 17:54:21 +0000 Subject: [PATCH 1/2] Add tool for getting a commit --- README.md | 9 +- pkg/github/repositories.go | 64 +++++++++++++ pkg/github/repositories_test.go | 164 +++++++++++++++++++++++++------- pkg/github/server.go | 1 + 4 files changed, 205 insertions(+), 33 deletions(-) diff --git a/README.md b/README.md index cf53ce2a..b7a6af35 100644 --- a/README.md +++ b/README.md @@ -342,7 +342,7 @@ export GITHUB_MCP_TOOL_ADD_ISSUE_COMMENT_DESCRIPTION="an alternative description - `branch`: New branch name (string, required) - `sha`: SHA to create branch from (string, required) -- **list_commits** - Gets commits of a branch in a repository +- **list_commits** - Get a list of commits of a branch in a repository - `owner`: Repository owner (string, required) - `repo`: Repository name (string, required) - `sha`: Branch name, tag, or commit SHA (string, optional) @@ -350,6 +350,13 @@ export GITHUB_MCP_TOOL_ADD_ISSUE_COMMENT_DESCRIPTION="an alternative description - `page`: Page number (number, optional) - `perPage`: Results per page (number, optional) +- **get_commit** - Get details for a commit from a repository + - `owner`: Repository owner (string, required) + - `repo`: Repository name (string, required) + - `sha`: Commit SHA, branch name, or tag name (string, required) + - `page`: Page number, for files in the commit (number, optional) + - `perPage`: Results per page, for files in the commit (number, optional) + ### Search - **search_code** - Search for code across GitHub repositories diff --git a/pkg/github/repositories.go b/pkg/github/repositories.go index 5b8725d1..c9a869e1 100644 --- a/pkg/github/repositories.go +++ b/pkg/github/repositories.go @@ -13,6 +13,70 @@ import ( "github.com/mark3labs/mcp-go/server" ) +func getCommit(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("get_commit", + mcp.WithDescription(t("TOOL_GET_COMMITS_DESCRIPTION", "Get details for a commit from a GitHub repository")), + mcp.WithString("owner", + mcp.Required(), + mcp.Description("Repository owner"), + ), + mcp.WithString("repo", + mcp.Required(), + mcp.Description("Repository name"), + ), + mcp.WithString("sha", + mcp.Required(), + mcp.Description("Commit SHA, branch name, or tag name"), + ), + withPagination(), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + sha, err := requiredParam[string](request, "sha") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pagination, err := optionalPaginationParams(request) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + opts := &github.ListOptions{ + Page: pagination.page, + PerPage: pagination.perPage, + } + + commit, resp, err := client.Repositories.GetCommit(ctx, owner, repo, sha, opts) + if err != nil { + return nil, fmt.Errorf("failed to get commit: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != 200 { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to get commit: %s", string(body))), nil + } + + r, err := json.Marshal(commit) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return mcp.NewToolResultText(string(r)), nil + } +} + + // listCommits creates a tool to get commits of a branch in a repository. func listCommits(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("list_commits", diff --git a/pkg/github/repositories_test.go b/pkg/github/repositories_test.go index f7ed8e71..c8bcb7b4 100644 --- a/pkg/github/repositories_test.go +++ b/pkg/github/repositories_test.go @@ -15,6 +15,39 @@ import ( "github.com/stretchr/testify/require" ) +var mockCommits = []*github.RepositoryCommit{ + { + SHA: github.Ptr("abc123def456"), + Commit: &github.Commit{ + Message: github.Ptr("First commit"), + Author: &github.CommitAuthor{ + Name: github.Ptr("Test User"), + Email: github.Ptr("test@example.com"), + Date: &github.Timestamp{Time: time.Now().Add(-48 * time.Hour)}, + }, + }, + Author: &github.User{ + Login: github.Ptr("testuser"), + }, + HTMLURL: github.Ptr("https://github.com/owner/repo/commit/abc123def456"), + }, + { + SHA: github.Ptr("def456abc789"), + Commit: &github.Commit{ + Message: github.Ptr("Second commit"), + Author: &github.CommitAuthor{ + Name: github.Ptr("Another User"), + Email: github.Ptr("another@example.com"), + Date: &github.Timestamp{Time: time.Now().Add(-24 * time.Hour)}, + }, + }, + Author: &github.User{ + Login: github.Ptr("anotheruser"), + }, + HTMLURL: github.Ptr("https://github.com/owner/repo/commit/def456abc789"), + }, +} + func Test_GetFileContents(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) @@ -475,54 +508,121 @@ func Test_CreateBranch(t *testing.T) { } } -func Test_ListCommits(t *testing.T) { +func Test_GetCommit(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := listCommits(mockClient, translations.NullTranslationHelper) + tool, _ := getCommit(mockClient, translations.NullTranslationHelper) - assert.Equal(t, "list_commits", tool.Name) + assert.Equal(t, "get_commit", tool.Name) assert.NotEmpty(t, tool.Description) assert.Contains(t, tool.InputSchema.Properties, "owner") assert.Contains(t, tool.InputSchema.Properties, "repo") assert.Contains(t, tool.InputSchema.Properties, "sha") - assert.Contains(t, tool.InputSchema.Properties, "page") - assert.Contains(t, tool.InputSchema.Properties, "perPage") - assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo"}) + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "sha"}) + + mockCommit := mockCommits[0] + // This one currently isn't defined in the mock package we're using. + var mockEndpointPattern = mock.EndpointPattern{ + Pattern: "/repos/{owner}/{repo}/commits/{sha}", + Method: "GET", + } - // Setup mock commits for success case - mockCommits := []*github.RepositoryCommit{ + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedCommit *github.RepositoryCommit + expectedErrMsg string + }{ { - SHA: github.Ptr("abc123def456"), - Commit: &github.Commit{ - Message: github.Ptr("First commit"), - Author: &github.CommitAuthor{ - Name: github.Ptr("Test User"), - Email: github.Ptr("test@example.com"), - Date: &github.Timestamp{Time: time.Now().Add(-48 * time.Hour)}, - }, - }, - Author: &github.User{ - Login: github.Ptr("testuser"), + name: "successful commit fetch", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mockEndpointPattern, + mockResponse(t, http.StatusOK, mockCommit), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "sha": "abc123def456", }, - HTMLURL: github.Ptr("https://github.com/owner/repo/commit/abc123def456"), + expectError: false, + expectedCommit: mockCommit, }, { - SHA: github.Ptr("def456abc789"), - Commit: &github.Commit{ - Message: github.Ptr("Second commit"), - Author: &github.CommitAuthor{ - Name: github.Ptr("Another User"), - Email: github.Ptr("another@example.com"), - Date: &github.Timestamp{Time: time.Now().Add(-24 * time.Hour)}, - }, - }, - Author: &github.User{ - Login: github.Ptr("anotheruser"), + name: "commit fetch fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mockEndpointPattern, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Not Found"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "sha": "nonexistent-sha", }, - HTMLURL: github.Ptr("https://github.com/owner/repo/commit/def456abc789"), + expectError: true, + expectedErrMsg: "failed to get commit", }, } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := getCommit(client, translations.NullTranslationHelper) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedCommit github.RepositoryCommit + err = json.Unmarshal([]byte(textContent.Text), &returnedCommit) + require.NoError(t, err) + + assert.Equal(t, *tc.expectedCommit.SHA, *returnedCommit.SHA) + assert.Equal(t, *tc.expectedCommit.Commit.Message, *returnedCommit.Commit.Message) + assert.Equal(t, *tc.expectedCommit.Author.Login, *returnedCommit.Author.Login) + assert.Equal(t, *tc.expectedCommit.HTMLURL, *returnedCommit.HTMLURL) + }) + } +} + +func Test_ListCommits(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := listCommits(mockClient, translations.NullTranslationHelper) + + assert.Equal(t, "list_commits", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "sha") + assert.Contains(t, tool.InputSchema.Properties, "page") + assert.Contains(t, tool.InputSchema.Properties, "perPage") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo"}) + tests := []struct { name string mockedClient *http.Client diff --git a/pkg/github/server.go b/pkg/github/server.go index bf3583b9..ce06999a 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -58,6 +58,7 @@ func NewServer(client *github.Client, version string, readOnly bool, t translati // Add GitHub tools - Repositories s.AddTool(searchRepositories(client, t)) s.AddTool(getFileContents(client, t)) + s.AddTool(getCommit(client, t)) s.AddTool(listCommits(client, t)) if !readOnly { s.AddTool(createOrUpdateFile(client, t)) From 334a981c6c91f16cb12ac5cc2d93e937c4863a83 Mon Sep 17 00:00:00 2001 From: Jake Shorty Date: Thu, 10 Apr 2025 18:07:54 +0000 Subject: [PATCH 2/2] Split mock back out, use RepositoryCommit with Files/Stats --- .github/CODEOWNERS | 2 +- .github/workflows/docker-publish.yml | 12 ++ .github/workflows/goreleaser.yml | 10 + README.md | 18 +- cmd/github-mcp-server/main.go | 28 +-- cmd/mcpcurl/README.md | 4 +- pkg/github/code_scanning.go | 21 ++- pkg/github/code_scanning_test.go | 8 +- pkg/github/helper_test.go | 112 +++++++++++ pkg/github/issues.go | 116 +++++++----- pkg/github/issues_test.go | 35 ++-- pkg/github/pullrequests.go | 245 ++++++++++++++++++++----- pkg/github/pullrequests_test.go | 222 ++++++++++++++++++++-- pkg/github/repositories.go | 89 ++++++--- pkg/github/repositories_test.go | 134 ++++++++------ pkg/github/repository_resource.go | 37 ++-- pkg/github/repository_resource_test.go | 22 +-- pkg/github/search.go | 46 +++-- pkg/github/search_test.go | 12 +- pkg/github/server.go | 155 ++++++++++------ pkg/github/server_test.go | 38 ++-- script/licenses | 2 +- script/licenses-check | 2 +- 23 files changed, 999 insertions(+), 371 deletions(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 998d87b2..954bc41c 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1 +1 @@ -@juruen @sammorrowdrums @williammartin @toby +* @juruen @sammorrowdrums @williammartin @toby diff --git a/.github/workflows/docker-publish.yml b/.github/workflows/docker-publish.yml index 4c370ebe..35ffc47d 100644 --- a/.github/workflows/docker-publish.yml +++ b/.github/workflows/docker-publish.yml @@ -66,6 +66,18 @@ jobs: uses: docker/metadata-action@96383f45573cb7f253c731d3b3ab81c87ef81934 # v5.0.0 with: images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} + tags: | + type=schedule + type=ref,event=branch + type=ref,event=tag + type=ref,event=pr + type=semver,pattern={{version}} + type=semver,pattern={{major}}.{{minor}} + type=semver,pattern={{major}} + type=sha + type=edge + # Custom rule to prevent pre-releases from getting latest tag + type=raw,value=latest,enable=${{ github.ref_type == 'tag' && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref, '-') }} - name: Go Build Cache for Docker uses: actions/cache@v4 diff --git a/.github/workflows/goreleaser.yml b/.github/workflows/goreleaser.yml index a25a3469..263607ee 100644 --- a/.github/workflows/goreleaser.yml +++ b/.github/workflows/goreleaser.yml @@ -5,6 +5,8 @@ on: - "v*" permissions: contents: write + id-token: write + attestations: write jobs: release: @@ -33,3 +35,11 @@ jobs: workdir: . env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Generate signed build provenance attestations for workflow artifacts + uses: actions/attest-build-provenance@v2 + with: + subject-path: | + dist/*.tar.gz + dist/*.zip + dist/*.txt diff --git a/README.md b/README.md index b7a6af35..f85663ab 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,8 @@ automation and interaction capabilities for developers and tools. ## Prerequisites 1. To run the server in a container, you will need to have [Docker](https://www.docker.com/) installed. -2. [Create a GitHub Personal Access Token](https://github.com/settings/personal-access-tokens/new). +2. Once Docker is installed, you will also need to ensure Docker is running. +3. Lastly you will need to [Create a GitHub Personal Access Token](https://github.com/settings/personal-access-tokens/new). The MCP server can use many of the GitHub APIs, so enable the permissions that you feel comfortable granting your AI tools (to learn more about access tokens, please check out the [documentation](https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens)). @@ -287,6 +288,17 @@ export GITHUB_MCP_TOOL_ADD_ISSUE_COMMENT_DESCRIPTION="an alternative description - `draft`: Create as draft PR (boolean, optional) - `maintainer_can_modify`: Allow maintainer edits (boolean, optional) +- **update_pull_request** - Update an existing pull request in a GitHub repository + + - `owner`: Repository owner (string, required) + - `repo`: Repository name (string, required) + - `pullNumber`: Pull request number to update (number, required) + - `title`: New title (string, optional) + - `body`: New description (string, optional) + - `state`: New state ('open' or 'closed') (string, optional) + - `base`: New base branch name (string, optional) + - `maintainer_can_modify`: Allow maintainer edits (boolean, optional) + ### Repositories - **create_or_update_file** - Create or update a single file in a repository @@ -442,6 +454,10 @@ export GITHUB_MCP_TOOL_ADD_ISSUE_COMMENT_DESCRIPTION="an alternative description - `prNumber`: Pull request number (string, required) - `path`: File or directory path (string, optional) +## Library Usage + +The exported Go API of this module should currently be considered unstable, and subject to breaking changes. In the future, we may offer stability; please file an issue if there is a use case where this would be valuable. + ## License This project is licensed under the terms of the MIT open source license. Please refer to [MIT](./LICENSE) for the full terms. diff --git a/cmd/github-mcp-server/main.go b/cmd/github-mcp-server/main.go index dd4d41a7..f5539529 100644 --- a/cmd/github-mcp-server/main.go +++ b/cmd/github-mcp-server/main.go @@ -1,9 +1,7 @@ package main import ( - "bytes" "context" - "encoding/json" "fmt" "io" stdlog "log" @@ -41,7 +39,6 @@ var ( logFile := viper.GetString("log-file") readOnly := viper.GetBool("read-only") exportTranslations := viper.GetBool("export-translations") - prettyPrintJSON := viper.GetBool("pretty-print-json") logger, err := initLogger(logFile) if err != nil { stdlog.Fatal("Failed to initialize logger:", err) @@ -52,7 +49,6 @@ var ( logger: logger, logCommands: logCommands, exportTranslations: exportTranslations, - prettyPrintJSON: prettyPrintJSON, } if err := runStdioServer(cfg); err != nil { stdlog.Fatal("failed to run stdio server:", err) @@ -70,7 +66,6 @@ func init() { rootCmd.PersistentFlags().Bool("enable-command-logging", false, "When enabled, the server will log all command requests and responses to the log file") rootCmd.PersistentFlags().Bool("export-translations", false, "Save translations to a JSON file") rootCmd.PersistentFlags().String("gh-host", "", "Specify the GitHub hostname (for GitHub Enterprise etc.)") - rootCmd.PersistentFlags().Bool("pretty-print-json", false, "Pretty print JSON output") // Bind flag to viper _ = viper.BindPFlag("read-only", rootCmd.PersistentFlags().Lookup("read-only")) @@ -78,7 +73,6 @@ func init() { _ = viper.BindPFlag("enable-command-logging", rootCmd.PersistentFlags().Lookup("enable-command-logging")) _ = viper.BindPFlag("export-translations", rootCmd.PersistentFlags().Lookup("export-translations")) _ = viper.BindPFlag("gh-host", rootCmd.PersistentFlags().Lookup("gh-host")) - _ = viper.BindPFlag("pretty-print-json", rootCmd.PersistentFlags().Lookup("pretty-print-json")) // Add subcommands rootCmd.AddCommand(stdioCmd) @@ -112,20 +106,6 @@ type runConfig struct { logger *log.Logger logCommands bool exportTranslations bool - prettyPrintJSON bool -} - -// JSONPrettyPrintWriter is a Writer that pretty prints input to indented JSON -type JSONPrettyPrintWriter struct { - writer io.Writer -} - -func (j JSONPrettyPrintWriter) Write(p []byte) (n int, err error) { - var prettyJSON bytes.Buffer - if err := json.Indent(&prettyJSON, p, "", "\t"); err != nil { - return 0, err - } - return j.writer.Write(prettyJSON.Bytes()) } func runStdioServer(cfg runConfig) error { @@ -157,8 +137,11 @@ func runStdioServer(cfg runConfig) error { t, dumpTranslations := translations.TranslationHelper() + getClient := func(_ context.Context) (*gogithub.Client, error) { + return ghClient, nil // closing over client + } // Create - ghServer := github.NewServer(ghClient, version, cfg.readOnly, t) + ghServer := github.NewServer(getClient, version, cfg.readOnly, t) stdioServer := server.NewStdioServer(ghServer) stdLogger := stdlog.New(cfg.logger.Writer(), "stdioserver", 0) @@ -179,9 +162,6 @@ func runStdioServer(cfg runConfig) error { in, out = loggedIO, loggedIO } - if cfg.prettyPrintJSON { - out = JSONPrettyPrintWriter{writer: out} - } errC <- stdioServer.Listen(ctx, in, out) }() diff --git a/cmd/mcpcurl/README.md b/cmd/mcpcurl/README.md index 95e1339a..0104a1b3 100644 --- a/cmd/mcpcurl/README.md +++ b/cmd/mcpcurl/README.md @@ -49,7 +49,7 @@ Available Commands: create_repository Create a new GitHub repository in your account fork_repository Fork a GitHub repository to your account or specified organization get_file_contents Get the contents of a file or directory from a GitHub repository - get_issue Get details of a specific issue in a GitHub repository. + get_issue Get details of a specific issue in a GitHub repository get_issue_comments Get comments for a GitHub issue list_commits Get list of commits of a branch in a GitHub repository list_issues List issues in a GitHub repository with filtering options @@ -74,7 +74,7 @@ Get help for a specific tool: ```bash % ./mcpcurl --stdio-server-cmd "docker run -i --rm -e GITHUB_PERSONAL_ACCESS_TOKEN mcp/github" tools get_issue --help -Get details of a specific issue in a GitHub repository. +Get details of a specific issue in a GitHub repository Usage: mcpcurl tools get_issue [flags] diff --git a/pkg/github/code_scanning.go b/pkg/github/code_scanning.go index 81ee2c31..4fc029bf 100644 --- a/pkg/github/code_scanning.go +++ b/pkg/github/code_scanning.go @@ -13,7 +13,7 @@ import ( "github.com/mark3labs/mcp-go/server" ) -func getCodeScanningAlert(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func GetCodeScanningAlert(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("get_code_scanning_alert", mcp.WithDescription(t("TOOL_GET_CODE_SCANNING_ALERT_DESCRIPTION", "Get details of a specific code scanning alert in a GitHub repository.")), mcp.WithString("owner", @@ -38,11 +38,16 @@ func getCodeScanningAlert(client *github.Client, t translations.TranslationHelpe if err != nil { return mcp.NewToolResultError(err.Error()), nil } - alertNumber, err := requiredInt(request, "alertNumber") + alertNumber, err := RequiredInt(request, "alertNumber") if err != nil { return mcp.NewToolResultError(err.Error()), nil } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + alert, resp, err := client.CodeScanning.GetAlert(ctx, owner, repo, int64(alertNumber)) if err != nil { return nil, fmt.Errorf("failed to get alert: %w", err) @@ -66,7 +71,7 @@ func getCodeScanningAlert(client *github.Client, t translations.TranslationHelpe } } -func listCodeScanningAlerts(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func ListCodeScanningAlerts(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("list_code_scanning_alerts", mcp.WithDescription(t("TOOL_LIST_CODE_SCANNING_ALERTS_DESCRIPTION", "List code scanning alerts in a GitHub repository.")), mcp.WithString("owner", @@ -97,19 +102,23 @@ func listCodeScanningAlerts(client *github.Client, t translations.TranslationHel if err != nil { return mcp.NewToolResultError(err.Error()), nil } - ref, err := optionalParam[string](request, "ref") + ref, err := OptionalParam[string](request, "ref") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - state, err := optionalParam[string](request, "state") + state, err := OptionalParam[string](request, "state") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - severity, err := optionalParam[string](request, "severity") + severity, err := OptionalParam[string](request, "severity") if err != nil { return mcp.NewToolResultError(err.Error()), nil } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } alerts, resp, err := client.CodeScanning.ListAlertsForRepo(ctx, owner, repo, &github.AlertListOptions{Ref: ref, State: state, Severity: severity}) if err != nil { return nil, fmt.Errorf("failed to list alerts: %w", err) diff --git a/pkg/github/code_scanning_test.go b/pkg/github/code_scanning_test.go index ec4d671e..c9895e26 100644 --- a/pkg/github/code_scanning_test.go +++ b/pkg/github/code_scanning_test.go @@ -16,7 +16,7 @@ import ( func Test_GetCodeScanningAlert(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := getCodeScanningAlert(mockClient, translations.NullTranslationHelper) + tool, _ := GetCodeScanningAlert(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "get_code_scanning_alert", tool.Name) assert.NotEmpty(t, tool.Description) @@ -82,7 +82,7 @@ func Test_GetCodeScanningAlert(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := getCodeScanningAlert(client, translations.NullTranslationHelper) + _, handler := GetCodeScanningAlert(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -118,7 +118,7 @@ func Test_GetCodeScanningAlert(t *testing.T) { func Test_ListCodeScanningAlerts(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := listCodeScanningAlerts(mockClient, translations.NullTranslationHelper) + tool, _ := ListCodeScanningAlerts(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "list_code_scanning_alerts", tool.Name) assert.NotEmpty(t, tool.Description) @@ -201,7 +201,7 @@ func Test_ListCodeScanningAlerts(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := listCodeScanningAlerts(client, translations.NullTranslationHelper) + _, handler := ListCodeScanningAlerts(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) diff --git a/pkg/github/helper_test.go b/pkg/github/helper_test.go index 9dcffa42..40fc0b94 100644 --- a/pkg/github/helper_test.go +++ b/pkg/github/helper_test.go @@ -93,3 +93,115 @@ func getTextResult(t *testing.T, result *mcp.CallToolResult) mcp.TextContent { assert.Equal(t, "text", textContent.Type) return textContent } + +func TestOptionalParamOK(t *testing.T) { + tests := []struct { + name string + args map[string]interface{} + paramName string + expectedVal interface{} + expectedOk bool + expectError bool + errorMsg string + }{ + { + name: "present and correct type (string)", + args: map[string]interface{}{"myParam": "hello"}, + paramName: "myParam", + expectedVal: "hello", + expectedOk: true, + expectError: false, + }, + { + name: "present and correct type (bool)", + args: map[string]interface{}{"myParam": true}, + paramName: "myParam", + expectedVal: true, + expectedOk: true, + expectError: false, + }, + { + name: "present and correct type (number)", + args: map[string]interface{}{"myParam": float64(123)}, + paramName: "myParam", + expectedVal: float64(123), + expectedOk: true, + expectError: false, + }, + { + name: "present but wrong type (string expected, got bool)", + args: map[string]interface{}{"myParam": true}, + paramName: "myParam", + expectedVal: "", // Zero value for string + expectedOk: true, // ok is true because param exists + expectError: true, + errorMsg: "parameter myParam is not of type string, is bool", + }, + { + name: "present but wrong type (bool expected, got string)", + args: map[string]interface{}{"myParam": "true"}, + paramName: "myParam", + expectedVal: false, // Zero value for bool + expectedOk: true, // ok is true because param exists + expectError: true, + errorMsg: "parameter myParam is not of type bool, is string", + }, + { + name: "parameter not present", + args: map[string]interface{}{"anotherParam": "value"}, + paramName: "myParam", + expectedVal: "", // Zero value for string + expectedOk: false, + expectError: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + request := createMCPRequest(tc.args) + + // Test with string type assertion + if _, isString := tc.expectedVal.(string); isString || tc.errorMsg == "parameter myParam is not of type string, is bool" { + val, ok, err := OptionalParamOK[string](request, tc.paramName) + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.errorMsg) + assert.Equal(t, tc.expectedOk, ok) // Check ok even on error + assert.Equal(t, tc.expectedVal, val) // Check zero value on error + } else { + require.NoError(t, err) + assert.Equal(t, tc.expectedOk, ok) + assert.Equal(t, tc.expectedVal, val) + } + } + + // Test with bool type assertion + if _, isBool := tc.expectedVal.(bool); isBool || tc.errorMsg == "parameter myParam is not of type bool, is string" { + val, ok, err := OptionalParamOK[bool](request, tc.paramName) + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.errorMsg) + assert.Equal(t, tc.expectedOk, ok) // Check ok even on error + assert.Equal(t, tc.expectedVal, val) // Check zero value on error + } else { + require.NoError(t, err) + assert.Equal(t, tc.expectedOk, ok) + assert.Equal(t, tc.expectedVal, val) + } + } + + // Test with float64 type assertion (for number case) + if _, isFloat := tc.expectedVal.(float64); isFloat { + val, ok, err := OptionalParamOK[float64](request, tc.paramName) + if tc.expectError { + // This case shouldn't happen for float64 in the defined tests + require.Fail(t, "Unexpected error case for float64") + } else { + require.NoError(t, err) + assert.Equal(t, tc.expectedOk, ok) + assert.Equal(t, tc.expectedVal, val) + } + } + }) + } +} diff --git a/pkg/github/issues.go b/pkg/github/issues.go index 1632e9e8..16c34141 100644 --- a/pkg/github/issues.go +++ b/pkg/github/issues.go @@ -14,21 +14,21 @@ import ( "github.com/mark3labs/mcp-go/server" ) -// getIssue creates a tool to get details of a specific issue in a GitHub repository. -func getIssue(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +// GetIssue creates a tool to get details of a specific issue in a GitHub repository. +func GetIssue(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("get_issue", - mcp.WithDescription(t("TOOL_GET_ISSUE_DESCRIPTION", "Get details of a specific issue in a GitHub repository.")), + mcp.WithDescription(t("TOOL_GET_ISSUE_DESCRIPTION", "Get details of a specific issue in a GitHub repository")), mcp.WithString("owner", mcp.Required(), - mcp.Description("The owner of the repository."), + mcp.Description("The owner of the repository"), ), mcp.WithString("repo", mcp.Required(), - mcp.Description("The name of the repository."), + mcp.Description("The name of the repository"), ), mcp.WithNumber("issue_number", mcp.Required(), - mcp.Description("The number of the issue."), + mcp.Description("The number of the issue"), ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { @@ -40,11 +40,15 @@ func getIssue(client *github.Client, t translations.TranslationHelperFunc) (tool if err != nil { return mcp.NewToolResultError(err.Error()), nil } - issueNumber, err := requiredInt(request, "issue_number") + issueNumber, err := RequiredInt(request, "issue_number") if err != nil { return mcp.NewToolResultError(err.Error()), nil } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } issue, resp, err := client.Issues.Get(ctx, owner, repo, issueNumber) if err != nil { return nil, fmt.Errorf("failed to get issue: %w", err) @@ -68,8 +72,8 @@ func getIssue(client *github.Client, t translations.TranslationHelperFunc) (tool } } -// addIssueComment creates a tool to add a comment to an issue. -func addIssueComment(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +// AddIssueComment creates a tool to add a comment to an issue. +func AddIssueComment(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("add_issue_comment", mcp.WithDescription(t("TOOL_ADD_ISSUE_COMMENT_DESCRIPTION", "Add a comment to an existing issue")), mcp.WithString("owner", @@ -98,7 +102,7 @@ func addIssueComment(client *github.Client, t translations.TranslationHelperFunc if err != nil { return mcp.NewToolResultError(err.Error()), nil } - issueNumber, err := requiredInt(request, "issue_number") + issueNumber, err := RequiredInt(request, "issue_number") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -111,6 +115,10 @@ func addIssueComment(client *github.Client, t translations.TranslationHelperFunc Body: github.Ptr(body), } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } createdComment, resp, err := client.Issues.CreateComment(ctx, owner, repo, issueNumber, comment) if err != nil { return nil, fmt.Errorf("failed to create comment: %w", err) @@ -134,8 +142,8 @@ func addIssueComment(client *github.Client, t translations.TranslationHelperFunc } } -// searchIssues creates a tool to search for issues and pull requests. -func searchIssues(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +// SearchIssues creates a tool to search for issues and pull requests. +func SearchIssues(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("search_issues", mcp.WithDescription(t("TOOL_SEARCH_ISSUES_DESCRIPTION", "Search for issues and pull requests across GitHub repositories")), mcp.WithString("q", @@ -162,22 +170,22 @@ func searchIssues(client *github.Client, t translations.TranslationHelperFunc) ( mcp.Description("Sort order ('asc' or 'desc')"), mcp.Enum("asc", "desc"), ), - withPagination(), + WithPagination(), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { query, err := requiredParam[string](request, "q") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - sort, err := optionalParam[string](request, "sort") + sort, err := OptionalParam[string](request, "sort") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - order, err := optionalParam[string](request, "order") + order, err := OptionalParam[string](request, "order") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - pagination, err := optionalPaginationParams(request) + pagination, err := OptionalPaginationParams(request) if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -191,6 +199,10 @@ func searchIssues(client *github.Client, t translations.TranslationHelperFunc) ( }, } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } result, resp, err := client.Search.Issues(ctx, query, opts) if err != nil { return nil, fmt.Errorf("failed to search issues: %w", err) @@ -214,8 +226,8 @@ func searchIssues(client *github.Client, t translations.TranslationHelperFunc) ( } } -// createIssue creates a tool to create a new issue in a GitHub repository. -func createIssue(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +// CreateIssue creates a tool to create a new issue in a GitHub repository. +func CreateIssue(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("create_issue", mcp.WithDescription(t("TOOL_CREATE_ISSUE_DESCRIPTION", "Create a new issue in a GitHub repository")), mcp.WithString("owner", @@ -268,25 +280,25 @@ func createIssue(client *github.Client, t translations.TranslationHelperFunc) (t } // Optional parameters - body, err := optionalParam[string](request, "body") + body, err := OptionalParam[string](request, "body") if err != nil { return mcp.NewToolResultError(err.Error()), nil } // Get assignees - assignees, err := optionalStringArrayParam(request, "assignees") + assignees, err := OptionalStringArrayParam(request, "assignees") if err != nil { return mcp.NewToolResultError(err.Error()), nil } // Get labels - labels, err := optionalStringArrayParam(request, "labels") + labels, err := OptionalStringArrayParam(request, "labels") if err != nil { return mcp.NewToolResultError(err.Error()), nil } // Get optional milestone - milestone, err := optionalIntParam(request, "milestone") + milestone, err := OptionalIntParam(request, "milestone") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -305,6 +317,10 @@ func createIssue(client *github.Client, t translations.TranslationHelperFunc) (t Milestone: milestoneNum, } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } issue, resp, err := client.Issues.Create(ctx, owner, repo, issueRequest) if err != nil { return nil, fmt.Errorf("failed to create issue: %w", err) @@ -328,8 +344,8 @@ func createIssue(client *github.Client, t translations.TranslationHelperFunc) (t } } -// listIssues creates a tool to list and filter repository issues -func listIssues(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +// ListIssues creates a tool to list and filter repository issues +func ListIssues(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("list_issues", mcp.WithDescription(t("TOOL_LIST_ISSUES_DESCRIPTION", "List issues in a GitHub repository with filtering options")), mcp.WithString("owner", @@ -363,7 +379,7 @@ func listIssues(client *github.Client, t translations.TranslationHelperFunc) (to mcp.WithString("since", mcp.Description("Filter by date (ISO 8601 timestamp)"), ), - withPagination(), + WithPagination(), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { owner, err := requiredParam[string](request, "owner") @@ -378,28 +394,28 @@ func listIssues(client *github.Client, t translations.TranslationHelperFunc) (to opts := &github.IssueListByRepoOptions{} // Set optional parameters if provided - opts.State, err = optionalParam[string](request, "state") + opts.State, err = OptionalParam[string](request, "state") if err != nil { return mcp.NewToolResultError(err.Error()), nil } // Get labels - opts.Labels, err = optionalStringArrayParam(request, "labels") + opts.Labels, err = OptionalStringArrayParam(request, "labels") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - opts.Sort, err = optionalParam[string](request, "sort") + opts.Sort, err = OptionalParam[string](request, "sort") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - opts.Direction, err = optionalParam[string](request, "direction") + opts.Direction, err = OptionalParam[string](request, "direction") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - since, err := optionalParam[string](request, "since") + since, err := OptionalParam[string](request, "since") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -419,6 +435,10 @@ func listIssues(client *github.Client, t translations.TranslationHelperFunc) (to opts.PerPage = int(perPage) } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } issues, resp, err := client.Issues.ListByRepo(ctx, owner, repo, opts) if err != nil { return nil, fmt.Errorf("failed to list issues: %w", err) @@ -442,8 +462,8 @@ func listIssues(client *github.Client, t translations.TranslationHelperFunc) (to } } -// updateIssue creates a tool to update an existing issue in a GitHub repository. -func updateIssue(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +// UpdateIssue creates a tool to update an existing issue in a GitHub repository. +func UpdateIssue(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("update_issue", mcp.WithDescription(t("TOOL_UPDATE_ISSUE_DESCRIPTION", "Update an existing issue in a GitHub repository")), mcp.WithString("owner", @@ -497,7 +517,7 @@ func updateIssue(client *github.Client, t translations.TranslationHelperFunc) (t if err != nil { return mcp.NewToolResultError(err.Error()), nil } - issueNumber, err := requiredInt(request, "issue_number") + issueNumber, err := RequiredInt(request, "issue_number") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -506,7 +526,7 @@ func updateIssue(client *github.Client, t translations.TranslationHelperFunc) (t issueRequest := &github.IssueRequest{} // Set optional parameters if provided - title, err := optionalParam[string](request, "title") + title, err := OptionalParam[string](request, "title") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -514,7 +534,7 @@ func updateIssue(client *github.Client, t translations.TranslationHelperFunc) (t issueRequest.Title = github.Ptr(title) } - body, err := optionalParam[string](request, "body") + body, err := OptionalParam[string](request, "body") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -522,7 +542,7 @@ func updateIssue(client *github.Client, t translations.TranslationHelperFunc) (t issueRequest.Body = github.Ptr(body) } - state, err := optionalParam[string](request, "state") + state, err := OptionalParam[string](request, "state") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -531,7 +551,7 @@ func updateIssue(client *github.Client, t translations.TranslationHelperFunc) (t } // Get labels - labels, err := optionalStringArrayParam(request, "labels") + labels, err := OptionalStringArrayParam(request, "labels") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -540,7 +560,7 @@ func updateIssue(client *github.Client, t translations.TranslationHelperFunc) (t } // Get assignees - assignees, err := optionalStringArrayParam(request, "assignees") + assignees, err := OptionalStringArrayParam(request, "assignees") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -548,7 +568,7 @@ func updateIssue(client *github.Client, t translations.TranslationHelperFunc) (t issueRequest.Assignees = &assignees } - milestone, err := optionalIntParam(request, "milestone") + milestone, err := OptionalIntParam(request, "milestone") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -557,6 +577,10 @@ func updateIssue(client *github.Client, t translations.TranslationHelperFunc) (t issueRequest.Milestone = &milestoneNum } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } updatedIssue, resp, err := client.Issues.Edit(ctx, owner, repo, issueNumber, issueRequest) if err != nil { return nil, fmt.Errorf("failed to update issue: %w", err) @@ -580,8 +604,8 @@ func updateIssue(client *github.Client, t translations.TranslationHelperFunc) (t } } -// getIssueComments creates a tool to get comments for a GitHub issue. -func getIssueComments(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +// GetIssueComments creates a tool to get comments for a GitHub issue. +func GetIssueComments(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("get_issue_comments", mcp.WithDescription(t("TOOL_GET_ISSUE_COMMENTS_DESCRIPTION", "Get comments for a GitHub issue")), mcp.WithString("owner", @@ -612,15 +636,15 @@ func getIssueComments(client *github.Client, t translations.TranslationHelperFun if err != nil { return mcp.NewToolResultError(err.Error()), nil } - issueNumber, err := requiredInt(request, "issue_number") + issueNumber, err := RequiredInt(request, "issue_number") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - page, err := optionalIntParamWithDefault(request, "page", 1) + page, err := OptionalIntParamWithDefault(request, "page", 1) if err != nil { return mcp.NewToolResultError(err.Error()), nil } - perPage, err := optionalIntParamWithDefault(request, "per_page", 30) + perPage, err := OptionalIntParamWithDefault(request, "per_page", 30) if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -632,6 +656,10 @@ func getIssueComments(client *github.Client, t translations.TranslationHelperFun }, } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } comments, resp, err := client.Issues.ListComments(ctx, owner, repo, issueNumber, opts) if err != nil { return nil, fmt.Errorf("failed to get issue comments: %w", err) diff --git a/pkg/github/issues_test.go b/pkg/github/issues_test.go index 485169fd..61ca0ae7 100644 --- a/pkg/github/issues_test.go +++ b/pkg/github/issues_test.go @@ -18,7 +18,7 @@ import ( func Test_GetIssue(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := getIssue(mockClient, translations.NullTranslationHelper) + tool, _ := GetIssue(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "get_issue", tool.Name) assert.NotEmpty(t, tool.Description) @@ -82,7 +82,7 @@ func Test_GetIssue(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := getIssue(client, translations.NullTranslationHelper) + _, handler := GetIssue(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -114,7 +114,7 @@ func Test_GetIssue(t *testing.T) { func Test_AddIssueComment(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := addIssueComment(mockClient, translations.NullTranslationHelper) + tool, _ := AddIssueComment(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "add_issue_comment", tool.Name) assert.NotEmpty(t, tool.Description) @@ -185,7 +185,7 @@ func Test_AddIssueComment(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := addIssueComment(client, translations.NullTranslationHelper) + _, handler := AddIssueComment(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := mcp.CallToolRequest{ @@ -237,7 +237,7 @@ func Test_AddIssueComment(t *testing.T) { func Test_SearchIssues(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := searchIssues(mockClient, translations.NullTranslationHelper) + tool, _ := SearchIssues(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "search_issues", tool.Name) assert.NotEmpty(t, tool.Description) @@ -352,7 +352,7 @@ func Test_SearchIssues(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := searchIssues(client, translations.NullTranslationHelper) + _, handler := SearchIssues(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -393,7 +393,7 @@ func Test_SearchIssues(t *testing.T) { func Test_CreateIssue(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := createIssue(mockClient, translations.NullTranslationHelper) + tool, _ := CreateIssue(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "create_issue", tool.Name) assert.NotEmpty(t, tool.Description) @@ -468,9 +468,10 @@ func Test_CreateIssue(t *testing.T) { ), ), requestArgs: map[string]interface{}{ - "owner": "owner", - "repo": "repo", - "title": "Minimal Issue", + "owner": "owner", + "repo": "repo", + "title": "Minimal Issue", + "assignees": nil, // Expect no failure with nil optional value. }, expectError: false, expectedIssue: &github.Issue{ @@ -505,7 +506,7 @@ func Test_CreateIssue(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := createIssue(client, translations.NullTranslationHelper) + _, handler := CreateIssue(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -566,7 +567,7 @@ func Test_CreateIssue(t *testing.T) { func Test_ListIssues(t *testing.T) { // Verify tool definition mockClient := github.NewClient(nil) - tool, _ := listIssues(mockClient, translations.NullTranslationHelper) + tool, _ := ListIssues(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "list_issues", tool.Name) assert.NotEmpty(t, tool.Description) @@ -697,7 +698,7 @@ func Test_ListIssues(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := listIssues(client, translations.NullTranslationHelper) + _, handler := ListIssues(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -742,7 +743,7 @@ func Test_ListIssues(t *testing.T) { func Test_UpdateIssue(t *testing.T) { // Verify tool definition mockClient := github.NewClient(nil) - tool, _ := updateIssue(mockClient, translations.NullTranslationHelper) + tool, _ := UpdateIssue(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "update_issue", tool.Name) assert.NotEmpty(t, tool.Description) @@ -881,7 +882,7 @@ func Test_UpdateIssue(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := updateIssue(client, translations.NullTranslationHelper) + _, handler := UpdateIssue(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -999,7 +1000,7 @@ func Test_ParseISOTimestamp(t *testing.T) { func Test_GetIssueComments(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := getIssueComments(mockClient, translations.NullTranslationHelper) + tool, _ := GetIssueComments(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "get_issue_comments", tool.Name) assert.NotEmpty(t, tool.Description) @@ -1099,7 +1100,7 @@ func Test_GetIssueComments(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := getIssueComments(client, translations.NullTranslationHelper) + _, handler := GetIssueComments(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index 25090cb7..14aeb918 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -13,8 +13,8 @@ import ( "github.com/mark3labs/mcp-go/server" ) -// getPullRequest creates a tool to get details of a specific pull request. -func getPullRequest(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +// GetPullRequest creates a tool to get details of a specific pull request. +func GetPullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("get_pull_request", mcp.WithDescription(t("TOOL_GET_PULL_REQUEST_DESCRIPTION", "Get details of a specific pull request")), mcp.WithString("owner", @@ -39,11 +39,15 @@ func getPullRequest(client *github.Client, t translations.TranslationHelperFunc) if err != nil { return mcp.NewToolResultError(err.Error()), nil } - pullNumber, err := requiredInt(request, "pullNumber") + pullNumber, err := RequiredInt(request, "pullNumber") if err != nil { return mcp.NewToolResultError(err.Error()), nil } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } pr, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber) if err != nil { return nil, fmt.Errorf("failed to get pull request: %w", err) @@ -67,8 +71,125 @@ func getPullRequest(client *github.Client, t translations.TranslationHelperFunc) } } -// listPullRequests creates a tool to list and filter repository pull requests. -func listPullRequests(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +// UpdatePullRequest creates a tool to update an existing pull request. +func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("update_pull_request", + mcp.WithDescription(t("TOOL_UPDATE_PULL_REQUEST_DESCRIPTION", "Update an existing pull request in a GitHub repository")), + mcp.WithString("owner", + mcp.Required(), + mcp.Description("Repository owner"), + ), + mcp.WithString("repo", + mcp.Required(), + mcp.Description("Repository name"), + ), + mcp.WithNumber("pullNumber", + mcp.Required(), + mcp.Description("Pull request number to update"), + ), + mcp.WithString("title", + mcp.Description("New title"), + ), + mcp.WithString("body", + mcp.Description("New description"), + ), + mcp.WithString("state", + mcp.Description("New state ('open' or 'closed')"), + mcp.Enum("open", "closed"), + ), + mcp.WithString("base", + mcp.Description("New base branch name"), + ), + mcp.WithBoolean("maintainer_can_modify", + mcp.Description("Allow maintainer edits"), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pullNumber, err := RequiredInt(request, "pullNumber") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + // Build the update struct only with provided fields + update := &github.PullRequest{} + updateNeeded := false + + if title, ok, err := OptionalParamOK[string](request, "title"); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } else if ok { + update.Title = github.Ptr(title) + updateNeeded = true + } + + if body, ok, err := OptionalParamOK[string](request, "body"); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } else if ok { + update.Body = github.Ptr(body) + updateNeeded = true + } + + if state, ok, err := OptionalParamOK[string](request, "state"); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } else if ok { + update.State = github.Ptr(state) + updateNeeded = true + } + + if base, ok, err := OptionalParamOK[string](request, "base"); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } else if ok { + update.Base = &github.PullRequestBranch{Ref: github.Ptr(base)} + updateNeeded = true + } + + if maintainerCanModify, ok, err := OptionalParamOK[bool](request, "maintainer_can_modify"); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } else if ok { + update.MaintainerCanModify = github.Ptr(maintainerCanModify) + updateNeeded = true + } + + if !updateNeeded { + return mcp.NewToolResultError("No update parameters provided."), nil + } + + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + pr, resp, err := client.PullRequests.Edit(ctx, owner, repo, pullNumber, update) + if err != nil { + return nil, fmt.Errorf("failed to update pull request: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to update pull request: %s", string(body))), nil + } + + r, err := json.Marshal(pr) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return mcp.NewToolResultText(string(r)), nil + } +} + +// ListPullRequests creates a tool to list and filter repository pull requests. +func ListPullRequests(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("list_pull_requests", mcp.WithDescription(t("TOOL_LIST_PULL_REQUESTS_DESCRIPTION", "List and filter repository pull requests")), mcp.WithString("owner", @@ -94,7 +215,7 @@ func listPullRequests(client *github.Client, t translations.TranslationHelperFun mcp.WithString("direction", mcp.Description("Sort direction ('asc', 'desc')"), ), - withPagination(), + WithPagination(), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { owner, err := requiredParam[string](request, "owner") @@ -105,27 +226,27 @@ func listPullRequests(client *github.Client, t translations.TranslationHelperFun if err != nil { return mcp.NewToolResultError(err.Error()), nil } - state, err := optionalParam[string](request, "state") + state, err := OptionalParam[string](request, "state") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - head, err := optionalParam[string](request, "head") + head, err := OptionalParam[string](request, "head") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - base, err := optionalParam[string](request, "base") + base, err := OptionalParam[string](request, "base") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - sort, err := optionalParam[string](request, "sort") + sort, err := OptionalParam[string](request, "sort") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - direction, err := optionalParam[string](request, "direction") + direction, err := OptionalParam[string](request, "direction") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - pagination, err := optionalPaginationParams(request) + pagination, err := OptionalPaginationParams(request) if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -142,6 +263,10 @@ func listPullRequests(client *github.Client, t translations.TranslationHelperFun }, } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } prs, resp, err := client.PullRequests.List(ctx, owner, repo, opts) if err != nil { return nil, fmt.Errorf("failed to list pull requests: %w", err) @@ -165,8 +290,8 @@ func listPullRequests(client *github.Client, t translations.TranslationHelperFun } } -// mergePullRequest creates a tool to merge a pull request. -func mergePullRequest(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +// MergePullRequest creates a tool to merge a pull request. +func MergePullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("merge_pull_request", mcp.WithDescription(t("TOOL_MERGE_PULL_REQUEST_DESCRIPTION", "Merge a pull request")), mcp.WithString("owner", @@ -200,19 +325,19 @@ func mergePullRequest(client *github.Client, t translations.TranslationHelperFun if err != nil { return mcp.NewToolResultError(err.Error()), nil } - pullNumber, err := requiredInt(request, "pullNumber") + pullNumber, err := RequiredInt(request, "pullNumber") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - commitTitle, err := optionalParam[string](request, "commit_title") + commitTitle, err := OptionalParam[string](request, "commit_title") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - commitMessage, err := optionalParam[string](request, "commit_message") + commitMessage, err := OptionalParam[string](request, "commit_message") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - mergeMethod, err := optionalParam[string](request, "merge_method") + mergeMethod, err := OptionalParam[string](request, "merge_method") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -222,6 +347,10 @@ func mergePullRequest(client *github.Client, t translations.TranslationHelperFun MergeMethod: mergeMethod, } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } result, resp, err := client.PullRequests.Merge(ctx, owner, repo, pullNumber, commitMessage, options) if err != nil { return nil, fmt.Errorf("failed to merge pull request: %w", err) @@ -245,8 +374,8 @@ func mergePullRequest(client *github.Client, t translations.TranslationHelperFun } } -// getPullRequestFiles creates a tool to get the list of files changed in a pull request. -func getPullRequestFiles(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +// GetPullRequestFiles creates a tool to get the list of files changed in a pull request. +func GetPullRequestFiles(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("get_pull_request_files", mcp.WithDescription(t("TOOL_GET_PULL_REQUEST_FILES_DESCRIPTION", "Get the list of files changed in a pull request")), mcp.WithString("owner", @@ -271,11 +400,15 @@ func getPullRequestFiles(client *github.Client, t translations.TranslationHelper if err != nil { return mcp.NewToolResultError(err.Error()), nil } - pullNumber, err := requiredInt(request, "pullNumber") + pullNumber, err := RequiredInt(request, "pullNumber") if err != nil { return mcp.NewToolResultError(err.Error()), nil } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } opts := &github.ListOptions{} files, resp, err := client.PullRequests.ListFiles(ctx, owner, repo, pullNumber, opts) if err != nil { @@ -300,8 +433,8 @@ func getPullRequestFiles(client *github.Client, t translations.TranslationHelper } } -// getPullRequestStatus creates a tool to get the combined status of all status checks for a pull request. -func getPullRequestStatus(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +// GetPullRequestStatus creates a tool to get the combined status of all status checks for a pull request. +func GetPullRequestStatus(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("get_pull_request_status", mcp.WithDescription(t("TOOL_GET_PULL_REQUEST_STATUS_DESCRIPTION", "Get the combined status of all status checks for a pull request")), mcp.WithString("owner", @@ -326,11 +459,15 @@ func getPullRequestStatus(client *github.Client, t translations.TranslationHelpe if err != nil { return mcp.NewToolResultError(err.Error()), nil } - pullNumber, err := requiredInt(request, "pullNumber") + pullNumber, err := RequiredInt(request, "pullNumber") if err != nil { return mcp.NewToolResultError(err.Error()), nil } // First get the PR to find the head SHA + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } pr, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber) if err != nil { return nil, fmt.Errorf("failed to get pull request: %w", err) @@ -369,8 +506,8 @@ func getPullRequestStatus(client *github.Client, t translations.TranslationHelpe } } -// updatePullRequestBranch creates a tool to update a pull request branch with the latest changes from the base branch. -func updatePullRequestBranch(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +// UpdatePullRequestBranch creates a tool to update a pull request branch with the latest changes from the base branch. +func UpdatePullRequestBranch(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("update_pull_request_branch", mcp.WithDescription(t("TOOL_UPDATE_PULL_REQUEST_BRANCH_DESCRIPTION", "Update a pull request branch with the latest changes from the base branch")), mcp.WithString("owner", @@ -398,11 +535,11 @@ func updatePullRequestBranch(client *github.Client, t translations.TranslationHe if err != nil { return mcp.NewToolResultError(err.Error()), nil } - pullNumber, err := requiredInt(request, "pullNumber") + pullNumber, err := RequiredInt(request, "pullNumber") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - expectedHeadSHA, err := optionalParam[string](request, "expectedHeadSha") + expectedHeadSHA, err := OptionalParam[string](request, "expectedHeadSha") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -411,6 +548,10 @@ func updatePullRequestBranch(client *github.Client, t translations.TranslationHe opts.ExpectedHeadSHA = github.Ptr(expectedHeadSHA) } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } result, resp, err := client.PullRequests.UpdateBranch(ctx, owner, repo, pullNumber, opts) if err != nil { // Check if it's an acceptedError. An acceptedError indicates that the update is in progress, @@ -439,8 +580,8 @@ func updatePullRequestBranch(client *github.Client, t translations.TranslationHe } } -// getPullRequestComments creates a tool to get the review comments on a pull request. -func getPullRequestComments(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +// GetPullRequestComments creates a tool to get the review comments on a pull request. +func GetPullRequestComments(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("get_pull_request_comments", mcp.WithDescription(t("TOOL_GET_PULL_REQUEST_COMMENTS_DESCRIPTION", "Get the review comments on a pull request")), mcp.WithString("owner", @@ -465,7 +606,7 @@ func getPullRequestComments(client *github.Client, t translations.TranslationHel if err != nil { return mcp.NewToolResultError(err.Error()), nil } - pullNumber, err := requiredInt(request, "pullNumber") + pullNumber, err := RequiredInt(request, "pullNumber") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -476,6 +617,10 @@ func getPullRequestComments(client *github.Client, t translations.TranslationHel }, } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } comments, resp, err := client.PullRequests.ListComments(ctx, owner, repo, pullNumber, opts) if err != nil { return nil, fmt.Errorf("failed to get pull request comments: %w", err) @@ -499,8 +644,8 @@ func getPullRequestComments(client *github.Client, t translations.TranslationHel } } -// getPullRequestReviews creates a tool to get the reviews on a pull request. -func getPullRequestReviews(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +// GetPullRequestReviews creates a tool to get the reviews on a pull request. +func GetPullRequestReviews(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("get_pull_request_reviews", mcp.WithDescription(t("TOOL_GET_PULL_REQUEST_REVIEWS_DESCRIPTION", "Get the reviews on a pull request")), mcp.WithString("owner", @@ -525,11 +670,15 @@ func getPullRequestReviews(client *github.Client, t translations.TranslationHelp if err != nil { return mcp.NewToolResultError(err.Error()), nil } - pullNumber, err := requiredInt(request, "pullNumber") + pullNumber, err := RequiredInt(request, "pullNumber") if err != nil { return mcp.NewToolResultError(err.Error()), nil } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } reviews, resp, err := client.PullRequests.ListReviews(ctx, owner, repo, pullNumber, nil) if err != nil { return nil, fmt.Errorf("failed to get pull request reviews: %w", err) @@ -553,8 +702,8 @@ func getPullRequestReviews(client *github.Client, t translations.TranslationHelp } } -// createPullRequestReview creates a tool to submit a review on a pull request. -func createPullRequestReview(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +// CreatePullRequestReview creates a tool to submit a review on a pull request. +func CreatePullRequestReview(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("create_pull_request_review", mcp.WithDescription(t("TOOL_CREATE_PULL_REQUEST_REVIEW_DESCRIPTION", "Create a review on a pull request")), mcp.WithString("owner", @@ -629,7 +778,7 @@ func createPullRequestReview(client *github.Client, t translations.TranslationHe if err != nil { return mcp.NewToolResultError(err.Error()), nil } - pullNumber, err := requiredInt(request, "pullNumber") + pullNumber, err := RequiredInt(request, "pullNumber") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -644,7 +793,7 @@ func createPullRequestReview(client *github.Client, t translations.TranslationHe } // Add body if provided - body, err := optionalParam[string](request, "body") + body, err := OptionalParam[string](request, "body") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -653,7 +802,7 @@ func createPullRequestReview(client *github.Client, t translations.TranslationHe } // Add commit ID if provided - commitID, err := optionalParam[string](request, "commitId") + commitID, err := OptionalParam[string](request, "commitId") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -722,6 +871,10 @@ func createPullRequestReview(client *github.Client, t translations.TranslationHe reviewRequest.Comments = comments } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } review, resp, err := client.PullRequests.CreateReview(ctx, owner, repo, pullNumber, reviewRequest) if err != nil { return nil, fmt.Errorf("failed to create pull request review: %w", err) @@ -745,8 +898,8 @@ func createPullRequestReview(client *github.Client, t translations.TranslationHe } } -// createPullRequest creates a tool to create a new pull request. -func createPullRequest(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +// CreatePullRequest creates a tool to create a new pull request. +func CreatePullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("create_pull_request", mcp.WithDescription(t("TOOL_CREATE_PULL_REQUEST_DESCRIPTION", "Create a new pull request in a GitHub repository")), mcp.WithString("owner", @@ -801,17 +954,17 @@ func createPullRequest(client *github.Client, t translations.TranslationHelperFu return mcp.NewToolResultError(err.Error()), nil } - body, err := optionalParam[string](request, "body") + body, err := OptionalParam[string](request, "body") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - draft, err := optionalParam[bool](request, "draft") + draft, err := OptionalParam[bool](request, "draft") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - maintainerCanModify, err := optionalParam[bool](request, "maintainer_can_modify") + maintainerCanModify, err := OptionalParam[bool](request, "maintainer_can_modify") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -829,6 +982,10 @@ func createPullRequest(client *github.Client, t translations.TranslationHelperFu newPR.Draft = github.Ptr(draft) newPR.MaintainerCanModify = github.Ptr(maintainerCanModify) + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } pr, resp, err := client.PullRequests.Create(ctx, owner, repo, newPR) if err != nil { return nil, fmt.Errorf("failed to create pull request: %w", err) diff --git a/pkg/github/pullrequests_test.go b/pkg/github/pullrequests_test.go index cf1afcdc..3c20dfc2 100644 --- a/pkg/github/pullrequests_test.go +++ b/pkg/github/pullrequests_test.go @@ -17,7 +17,7 @@ import ( func Test_GetPullRequest(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := getPullRequest(mockClient, translations.NullTranslationHelper) + tool, _ := GetPullRequest(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "get_pull_request", tool.Name) assert.NotEmpty(t, tool.Description) @@ -94,7 +94,7 @@ func Test_GetPullRequest(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := getPullRequest(client, translations.NullTranslationHelper) + _, handler := GetPullRequest(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -126,10 +126,192 @@ func Test_GetPullRequest(t *testing.T) { } } +func Test_UpdatePullRequest(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := UpdatePullRequest(stubGetClientFn(mockClient), translations.NullTranslationHelper) + + assert.Equal(t, "update_pull_request", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "pullNumber") + assert.Contains(t, tool.InputSchema.Properties, "title") + assert.Contains(t, tool.InputSchema.Properties, "body") + assert.Contains(t, tool.InputSchema.Properties, "state") + assert.Contains(t, tool.InputSchema.Properties, "base") + assert.Contains(t, tool.InputSchema.Properties, "maintainer_can_modify") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "pullNumber"}) + + // Setup mock PR for success case + mockUpdatedPR := &github.PullRequest{ + Number: github.Ptr(42), + Title: github.Ptr("Updated Test PR Title"), + State: github.Ptr("open"), + HTMLURL: github.Ptr("https://github.com/owner/repo/pull/42"), + Body: github.Ptr("Updated test PR body."), + MaintainerCanModify: github.Ptr(false), + Base: &github.PullRequestBranch{ + Ref: github.Ptr("develop"), + }, + } + + mockClosedPR := &github.PullRequest{ + Number: github.Ptr(42), + Title: github.Ptr("Test PR"), + State: github.Ptr("closed"), // State updated + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedPR *github.PullRequest + expectedErrMsg string + }{ + { + name: "successful PR update (title, body, base, maintainer_can_modify)", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PatchReposPullsByOwnerByRepoByPullNumber, + // Expect the flat string based on previous test failure output and API docs + expectRequestBody(t, map[string]interface{}{ + "title": "Updated Test PR Title", + "body": "Updated test PR body.", + "base": "develop", + "maintainer_can_modify": false, + }).andThen( + mockResponse(t, http.StatusOK, mockUpdatedPR), + ), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pullNumber": float64(42), + "title": "Updated Test PR Title", + "body": "Updated test PR body.", + "base": "develop", + "maintainer_can_modify": false, + }, + expectError: false, + expectedPR: mockUpdatedPR, + }, + { + name: "successful PR update (state)", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PatchReposPullsByOwnerByRepoByPullNumber, + expectRequestBody(t, map[string]interface{}{ + "state": "closed", + }).andThen( + mockResponse(t, http.StatusOK, mockClosedPR), + ), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pullNumber": float64(42), + "state": "closed", + }, + expectError: false, + expectedPR: mockClosedPR, + }, + { + name: "no update parameters provided", + mockedClient: mock.NewMockedHTTPClient(), // No API call expected + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pullNumber": float64(42), + // No update fields + }, + expectError: false, // Error is returned in the result, not as Go error + expectedErrMsg: "No update parameters provided", + }, + { + name: "PR update fails (API error)", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PatchReposPullsByOwnerByRepoByPullNumber, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnprocessableEntity) + _, _ = w.Write([]byte(`{"message": "Validation Failed"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pullNumber": float64(42), + "title": "Invalid Title Causing Error", + }, + expectError: true, + expectedErrMsg: "failed to update pull request", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := UpdatePullRequest(stubGetClientFn(client), translations.NullTranslationHelper) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content + textContent := getTextResult(t, result) + + // Check for expected error message within the result text + if tc.expectedErrMsg != "" { + assert.Contains(t, textContent.Text, tc.expectedErrMsg) + return + } + + // Unmarshal and verify the successful result + var returnedPR github.PullRequest + err = json.Unmarshal([]byte(textContent.Text), &returnedPR) + require.NoError(t, err) + assert.Equal(t, *tc.expectedPR.Number, *returnedPR.Number) + if tc.expectedPR.Title != nil { + assert.Equal(t, *tc.expectedPR.Title, *returnedPR.Title) + } + if tc.expectedPR.Body != nil { + assert.Equal(t, *tc.expectedPR.Body, *returnedPR.Body) + } + if tc.expectedPR.State != nil { + assert.Equal(t, *tc.expectedPR.State, *returnedPR.State) + } + if tc.expectedPR.Base != nil && tc.expectedPR.Base.Ref != nil { + assert.NotNil(t, returnedPR.Base) + assert.Equal(t, *tc.expectedPR.Base.Ref, *returnedPR.Base.Ref) + } + if tc.expectedPR.MaintainerCanModify != nil { + assert.Equal(t, *tc.expectedPR.MaintainerCanModify, *returnedPR.MaintainerCanModify) + } + }) + } +} + func Test_ListPullRequests(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := listPullRequests(mockClient, translations.NullTranslationHelper) + tool, _ := ListPullRequests(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "list_pull_requests", tool.Name) assert.NotEmpty(t, tool.Description) @@ -221,7 +403,7 @@ func Test_ListPullRequests(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := listPullRequests(client, translations.NullTranslationHelper) + _, handler := ListPullRequests(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -259,7 +441,7 @@ func Test_ListPullRequests(t *testing.T) { func Test_MergePullRequest(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := mergePullRequest(mockClient, translations.NullTranslationHelper) + tool, _ := MergePullRequest(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "merge_pull_request", tool.Name) assert.NotEmpty(t, tool.Description) @@ -336,7 +518,7 @@ func Test_MergePullRequest(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := mergePullRequest(client, translations.NullTranslationHelper) + _, handler := MergePullRequest(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -370,7 +552,7 @@ func Test_MergePullRequest(t *testing.T) { func Test_GetPullRequestFiles(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := getPullRequestFiles(mockClient, translations.NullTranslationHelper) + tool, _ := GetPullRequestFiles(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "get_pull_request_files", tool.Name) assert.NotEmpty(t, tool.Description) @@ -448,7 +630,7 @@ func Test_GetPullRequestFiles(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := getPullRequestFiles(client, translations.NullTranslationHelper) + _, handler := GetPullRequestFiles(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -486,7 +668,7 @@ func Test_GetPullRequestFiles(t *testing.T) { func Test_GetPullRequestStatus(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := getPullRequestStatus(mockClient, translations.NullTranslationHelper) + tool, _ := GetPullRequestStatus(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "get_pull_request_status", tool.Name) assert.NotEmpty(t, tool.Description) @@ -608,7 +790,7 @@ func Test_GetPullRequestStatus(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := getPullRequestStatus(client, translations.NullTranslationHelper) + _, handler := GetPullRequestStatus(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -647,7 +829,7 @@ func Test_GetPullRequestStatus(t *testing.T) { func Test_UpdatePullRequestBranch(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := updatePullRequestBranch(mockClient, translations.NullTranslationHelper) + tool, _ := UpdatePullRequestBranch(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "update_pull_request_branch", tool.Name) assert.NotEmpty(t, tool.Description) @@ -735,7 +917,7 @@ func Test_UpdatePullRequestBranch(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := updatePullRequestBranch(client, translations.NullTranslationHelper) + _, handler := UpdatePullRequestBranch(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -763,7 +945,7 @@ func Test_UpdatePullRequestBranch(t *testing.T) { func Test_GetPullRequestComments(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := getPullRequestComments(mockClient, translations.NullTranslationHelper) + tool, _ := GetPullRequestComments(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "get_pull_request_comments", tool.Name) assert.NotEmpty(t, tool.Description) @@ -851,7 +1033,7 @@ func Test_GetPullRequestComments(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := getPullRequestComments(client, translations.NullTranslationHelper) + _, handler := GetPullRequestComments(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -890,7 +1072,7 @@ func Test_GetPullRequestComments(t *testing.T) { func Test_GetPullRequestReviews(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := getPullRequestReviews(mockClient, translations.NullTranslationHelper) + tool, _ := GetPullRequestReviews(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "get_pull_request_reviews", tool.Name) assert.NotEmpty(t, tool.Description) @@ -974,7 +1156,7 @@ func Test_GetPullRequestReviews(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := getPullRequestReviews(client, translations.NullTranslationHelper) + _, handler := GetPullRequestReviews(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -1013,7 +1195,7 @@ func Test_GetPullRequestReviews(t *testing.T) { func Test_CreatePullRequestReview(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := createPullRequestReview(mockClient, translations.NullTranslationHelper) + tool, _ := CreatePullRequestReview(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "create_pull_request_review", tool.Name) assert.NotEmpty(t, tool.Description) @@ -1341,7 +1523,7 @@ func Test_CreatePullRequestReview(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := createPullRequestReview(client, translations.NullTranslationHelper) + _, handler := CreatePullRequestReview(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -1384,7 +1566,7 @@ func Test_CreatePullRequestReview(t *testing.T) { func Test_CreatePullRequest(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := createPullRequest(mockClient, translations.NullTranslationHelper) + tool, _ := CreatePullRequest(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "create_pull_request", tool.Name) assert.NotEmpty(t, tool.Description) @@ -1496,7 +1678,7 @@ func Test_CreatePullRequest(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := createPullRequest(client, translations.NullTranslationHelper) + _, handler := CreatePullRequest(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) diff --git a/pkg/github/repositories.go b/pkg/github/repositories.go index c9a869e1..56500eaf 100644 --- a/pkg/github/repositories.go +++ b/pkg/github/repositories.go @@ -13,7 +13,7 @@ import ( "github.com/mark3labs/mcp-go/server" ) -func getCommit(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func GetCommit(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("get_commit", mcp.WithDescription(t("TOOL_GET_COMMITS_DESCRIPTION", "Get details for a commit from a GitHub repository")), mcp.WithString("owner", @@ -28,7 +28,7 @@ func getCommit(client *github.Client, t translations.TranslationHelperFunc) (too mcp.Required(), mcp.Description("Commit SHA, branch name, or tag name"), ), - withPagination(), + WithPagination(), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { owner, err := requiredParam[string](request, "owner") @@ -43,7 +43,7 @@ func getCommit(client *github.Client, t translations.TranslationHelperFunc) (too if err != nil { return mcp.NewToolResultError(err.Error()), nil } - pagination, err := optionalPaginationParams(request) + pagination, err := OptionalPaginationParams(request) if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -53,6 +53,10 @@ func getCommit(client *github.Client, t translations.TranslationHelperFunc) (too PerPage: pagination.perPage, } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } commit, resp, err := client.Repositories.GetCommit(ctx, owner, repo, sha, opts) if err != nil { return nil, fmt.Errorf("failed to get commit: %w", err) @@ -76,9 +80,8 @@ func getCommit(client *github.Client, t translations.TranslationHelperFunc) (too } } - -// listCommits creates a tool to get commits of a branch in a repository. -func listCommits(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +// ListCommits creates a tool to get commits of a branch in a repository. +func ListCommits(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("list_commits", mcp.WithDescription(t("TOOL_LIST_COMMITS_DESCRIPTION", "Get list of commits of a branch in a GitHub repository")), mcp.WithString("owner", @@ -92,7 +95,7 @@ func listCommits(client *github.Client, t translations.TranslationHelperFunc) (t mcp.WithString("sha", mcp.Description("Branch name"), ), - withPagination(), + WithPagination(), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { owner, err := requiredParam[string](request, "owner") @@ -103,11 +106,11 @@ func listCommits(client *github.Client, t translations.TranslationHelperFunc) (t if err != nil { return mcp.NewToolResultError(err.Error()), nil } - sha, err := optionalParam[string](request, "sha") + sha, err := OptionalParam[string](request, "sha") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - pagination, err := optionalPaginationParams(request) + pagination, err := OptionalPaginationParams(request) if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -120,6 +123,10 @@ func listCommits(client *github.Client, t translations.TranslationHelperFunc) (t }, } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } commits, resp, err := client.Repositories.ListCommits(ctx, owner, repo, opts) if err != nil { return nil, fmt.Errorf("failed to list commits: %w", err) @@ -143,8 +150,8 @@ func listCommits(client *github.Client, t translations.TranslationHelperFunc) (t } } -// createOrUpdateFile creates a tool to create or update a file in a GitHub repository. -func createOrUpdateFile(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +// CreateOrUpdateFile creates a tool to create or update a file in a GitHub repository. +func CreateOrUpdateFile(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("create_or_update_file", mcp.WithDescription(t("TOOL_CREATE_OR_UPDATE_FILE_DESCRIPTION", "Create or update a single file in a GitHub repository")), mcp.WithString("owner", @@ -212,7 +219,7 @@ func createOrUpdateFile(client *github.Client, t translations.TranslationHelperF } // If SHA is provided, set it (for updates) - sha, err := optionalParam[string](request, "sha") + sha, err := OptionalParam[string](request, "sha") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -221,6 +228,10 @@ func createOrUpdateFile(client *github.Client, t translations.TranslationHelperF } // Create or update the file + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } fileContent, resp, err := client.Repositories.CreateFile(ctx, owner, repo, path, opts) if err != nil { return nil, fmt.Errorf("failed to create/update file: %w", err) @@ -244,8 +255,8 @@ func createOrUpdateFile(client *github.Client, t translations.TranslationHelperF } } -// createRepository creates a tool to create a new GitHub repository. -func createRepository(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +// CreateRepository creates a tool to create a new GitHub repository. +func CreateRepository(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("create_repository", mcp.WithDescription(t("TOOL_CREATE_REPOSITORY_DESCRIPTION", "Create a new GitHub repository in your account")), mcp.WithString("name", @@ -267,15 +278,15 @@ func createRepository(client *github.Client, t translations.TranslationHelperFun if err != nil { return mcp.NewToolResultError(err.Error()), nil } - description, err := optionalParam[string](request, "description") + description, err := OptionalParam[string](request, "description") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - private, err := optionalParam[bool](request, "private") + private, err := OptionalParam[bool](request, "private") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - autoInit, err := optionalParam[bool](request, "autoInit") + autoInit, err := OptionalParam[bool](request, "autoInit") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -287,6 +298,10 @@ func createRepository(client *github.Client, t translations.TranslationHelperFun AutoInit: github.Ptr(autoInit), } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } createdRepo, resp, err := client.Repositories.Create(ctx, "", repo) if err != nil { return nil, fmt.Errorf("failed to create repository: %w", err) @@ -310,8 +325,8 @@ func createRepository(client *github.Client, t translations.TranslationHelperFun } } -// getFileContents creates a tool to get the contents of a file or directory from a GitHub repository. -func getFileContents(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +// GetFileContents creates a tool to get the contents of a file or directory from a GitHub repository. +func GetFileContents(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("get_file_contents", mcp.WithDescription(t("TOOL_GET_FILE_CONTENTS_DESCRIPTION", "Get the contents of a file or directory from a GitHub repository")), mcp.WithString("owner", @@ -343,11 +358,15 @@ func getFileContents(client *github.Client, t translations.TranslationHelperFunc if err != nil { return mcp.NewToolResultError(err.Error()), nil } - branch, err := optionalParam[string](request, "branch") + branch, err := OptionalParam[string](request, "branch") if err != nil { return mcp.NewToolResultError(err.Error()), nil } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } opts := &github.RepositoryContentGetOptions{Ref: branch} fileContent, dirContent, resp, err := client.Repositories.GetContents(ctx, owner, repo, path, opts) if err != nil { @@ -379,8 +398,8 @@ func getFileContents(client *github.Client, t translations.TranslationHelperFunc } } -// forkRepository creates a tool to fork a repository. -func forkRepository(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +// ForkRepository creates a tool to fork a repository. +func ForkRepository(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("fork_repository", mcp.WithDescription(t("TOOL_FORK_REPOSITORY_DESCRIPTION", "Fork a GitHub repository to your account or specified organization")), mcp.WithString("owner", @@ -404,7 +423,7 @@ func forkRepository(client *github.Client, t translations.TranslationHelperFunc) if err != nil { return mcp.NewToolResultError(err.Error()), nil } - org, err := optionalParam[string](request, "organization") + org, err := OptionalParam[string](request, "organization") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -414,6 +433,10 @@ func forkRepository(client *github.Client, t translations.TranslationHelperFunc) opts.Organization = org } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } forkedRepo, resp, err := client.Repositories.CreateFork(ctx, owner, repo, opts) if err != nil { // Check if it's an acceptedError. An acceptedError indicates that the update is in progress, @@ -442,8 +465,8 @@ func forkRepository(client *github.Client, t translations.TranslationHelperFunc) } } -// createBranch creates a tool to create a new branch. -func createBranch(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +// CreateBranch creates a tool to create a new branch. +func CreateBranch(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("create_branch", mcp.WithDescription(t("TOOL_CREATE_BRANCH_DESCRIPTION", "Create a new branch in a GitHub repository")), mcp.WithString("owner", @@ -475,11 +498,16 @@ func createBranch(client *github.Client, t translations.TranslationHelperFunc) ( if err != nil { return mcp.NewToolResultError(err.Error()), nil } - fromBranch, err := optionalParam[string](request, "from_branch") + fromBranch, err := OptionalParam[string](request, "from_branch") if err != nil { return mcp.NewToolResultError(err.Error()), nil } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + // Get the source branch SHA var ref *github.Reference @@ -522,8 +550,8 @@ func createBranch(client *github.Client, t translations.TranslationHelperFunc) ( } } -// pushFiles creates a tool to push multiple files in a single commit to a GitHub repository. -func pushFiles(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +// PushFiles creates a tool to push multiple files in a single commit to a GitHub repository. +func PushFiles(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("push_files", mcp.WithDescription(t("TOOL_PUSH_FILES_DESCRIPTION", "Push multiple files to a GitHub repository in a single commit")), mcp.WithString("owner", @@ -587,6 +615,11 @@ func pushFiles(client *github.Client, t translations.TranslationHelperFunc) (too return mcp.NewToolResultError("files parameter must be an array of objects with path and content"), nil } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + // Get the reference for the branch ref, resp, err := client.Git.GetRef(ctx, owner, repo, "refs/heads/"+branch) if err != nil { diff --git a/pkg/github/repositories_test.go b/pkg/github/repositories_test.go index c8bcb7b4..20f96dde 100644 --- a/pkg/github/repositories_test.go +++ b/pkg/github/repositories_test.go @@ -15,43 +15,10 @@ import ( "github.com/stretchr/testify/require" ) -var mockCommits = []*github.RepositoryCommit{ - { - SHA: github.Ptr("abc123def456"), - Commit: &github.Commit{ - Message: github.Ptr("First commit"), - Author: &github.CommitAuthor{ - Name: github.Ptr("Test User"), - Email: github.Ptr("test@example.com"), - Date: &github.Timestamp{Time: time.Now().Add(-48 * time.Hour)}, - }, - }, - Author: &github.User{ - Login: github.Ptr("testuser"), - }, - HTMLURL: github.Ptr("https://github.com/owner/repo/commit/abc123def456"), - }, - { - SHA: github.Ptr("def456abc789"), - Commit: &github.Commit{ - Message: github.Ptr("Second commit"), - Author: &github.CommitAuthor{ - Name: github.Ptr("Another User"), - Email: github.Ptr("another@example.com"), - Date: &github.Timestamp{Time: time.Now().Add(-24 * time.Hour)}, - }, - }, - Author: &github.User{ - Login: github.Ptr("anotheruser"), - }, - HTMLURL: github.Ptr("https://github.com/owner/repo/commit/def456abc789"), - }, -} - func Test_GetFileContents(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := getFileContents(mockClient, translations.NullTranslationHelper) + tool, _ := GetFileContents(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "get_file_contents", tool.Name) assert.NotEmpty(t, tool.Description) @@ -165,7 +132,7 @@ func Test_GetFileContents(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := getFileContents(client, translations.NullTranslationHelper) + _, handler := GetFileContents(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := mcp.CallToolRequest{ @@ -222,7 +189,7 @@ func Test_GetFileContents(t *testing.T) { func Test_ForkRepository(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := forkRepository(mockClient, translations.NullTranslationHelper) + tool, _ := ForkRepository(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "fork_repository", tool.Name) assert.NotEmpty(t, tool.Description) @@ -292,7 +259,7 @@ func Test_ForkRepository(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := forkRepository(client, translations.NullTranslationHelper) + _, handler := ForkRepository(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -320,7 +287,7 @@ func Test_ForkRepository(t *testing.T) { func Test_CreateBranch(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := createBranch(mockClient, translations.NullTranslationHelper) + tool, _ := CreateBranch(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "create_branch", tool.Name) assert.NotEmpty(t, tool.Description) @@ -478,7 +445,7 @@ func Test_CreateBranch(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := createBranch(client, translations.NullTranslationHelper) + _, handler := CreateBranch(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -511,7 +478,7 @@ func Test_CreateBranch(t *testing.T) { func Test_GetCommit(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := getCommit(mockClient, translations.NullTranslationHelper) + tool, _ := GetCommit(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "get_commit", tool.Name) assert.NotEmpty(t, tool.Description) @@ -520,7 +487,36 @@ func Test_GetCommit(t *testing.T) { assert.Contains(t, tool.InputSchema.Properties, "sha") assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "sha"}) - mockCommit := mockCommits[0] + mockCommit := &github.RepositoryCommit{ + SHA: github.Ptr("abc123def456"), + Commit: &github.Commit{ + Message: github.Ptr("First commit"), + Author: &github.CommitAuthor{ + Name: github.Ptr("Test User"), + Email: github.Ptr("test@example.com"), + Date: &github.Timestamp{Time: time.Now().Add(-48 * time.Hour)}, + }, + }, + Author: &github.User{ + Login: github.Ptr("testuser"), + }, + HTMLURL: github.Ptr("https://github.com/owner/repo/commit/abc123def456"), + Stats: &github.CommitStats{ + Additions: github.Ptr(10), + Deletions: github.Ptr(2), + Total: github.Ptr(12), + }, + Files: []*github.CommitFile{ + { + Filename: github.Ptr("file1.go"), + Status: github.Ptr("modified"), + Additions: github.Ptr(10), + Deletions: github.Ptr(2), + Changes: github.Ptr(12), + Patch: github.Ptr("@@ -1,2 +1,10 @@"), + }, + }, + } // This one currently isn't defined in the mock package we're using. var mockEndpointPattern = mock.EndpointPattern{ Pattern: "/repos/{owner}/{repo}/commits/{sha}", @@ -576,7 +572,7 @@ func Test_GetCommit(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := getCommit(client, translations.NullTranslationHelper) + _, handler := GetCommit(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -600,7 +596,7 @@ func Test_GetCommit(t *testing.T) { var returnedCommit github.RepositoryCommit err = json.Unmarshal([]byte(textContent.Text), &returnedCommit) require.NoError(t, err) - + assert.Equal(t, *tc.expectedCommit.SHA, *returnedCommit.SHA) assert.Equal(t, *tc.expectedCommit.Commit.Message, *returnedCommit.Commit.Message) assert.Equal(t, *tc.expectedCommit.Author.Login, *returnedCommit.Author.Login) @@ -612,7 +608,7 @@ func Test_GetCommit(t *testing.T) { func Test_ListCommits(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := listCommits(mockClient, translations.NullTranslationHelper) + tool, _ := ListCommits(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "list_commits", tool.Name) assert.NotEmpty(t, tool.Description) @@ -622,7 +618,41 @@ func Test_ListCommits(t *testing.T) { assert.Contains(t, tool.InputSchema.Properties, "page") assert.Contains(t, tool.InputSchema.Properties, "perPage") assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo"}) - + + // Setup mock commits for success case + mockCommits := []*github.RepositoryCommit{ + { + SHA: github.Ptr("abc123def456"), + Commit: &github.Commit{ + Message: github.Ptr("First commit"), + Author: &github.CommitAuthor{ + Name: github.Ptr("Test User"), + Email: github.Ptr("test@example.com"), + Date: &github.Timestamp{Time: time.Now().Add(-48 * time.Hour)}, + }, + }, + Author: &github.User{ + Login: github.Ptr("testuser"), + }, + HTMLURL: github.Ptr("https://github.com/owner/repo/commit/abc123def456"), + }, + { + SHA: github.Ptr("def456abc789"), + Commit: &github.Commit{ + Message: github.Ptr("Second commit"), + Author: &github.CommitAuthor{ + Name: github.Ptr("Another User"), + Email: github.Ptr("another@example.com"), + Date: &github.Timestamp{Time: time.Now().Add(-24 * time.Hour)}, + }, + }, + Author: &github.User{ + Login: github.Ptr("anotheruser"), + }, + HTMLURL: github.Ptr("https://github.com/owner/repo/commit/def456abc789"), + }, + } + tests := []struct { name string mockedClient *http.Client @@ -714,7 +744,7 @@ func Test_ListCommits(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := listCommits(client, translations.NullTranslationHelper) + _, handler := ListCommits(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -752,7 +782,7 @@ func Test_ListCommits(t *testing.T) { func Test_CreateOrUpdateFile(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := createOrUpdateFile(mockClient, translations.NullTranslationHelper) + tool, _ := CreateOrUpdateFile(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "create_or_update_file", tool.Name) assert.NotEmpty(t, tool.Description) @@ -875,7 +905,7 @@ func Test_CreateOrUpdateFile(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := createOrUpdateFile(client, translations.NullTranslationHelper) + _, handler := CreateOrUpdateFile(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -915,7 +945,7 @@ func Test_CreateOrUpdateFile(t *testing.T) { func Test_CreateRepository(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := createRepository(mockClient, translations.NullTranslationHelper) + tool, _ := CreateRepository(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "create_repository", tool.Name) assert.NotEmpty(t, tool.Description) @@ -1023,7 +1053,7 @@ func Test_CreateRepository(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := createRepository(client, translations.NullTranslationHelper) + _, handler := CreateRepository(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -1061,7 +1091,7 @@ func Test_CreateRepository(t *testing.T) { func Test_PushFiles(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := pushFiles(mockClient, translations.NullTranslationHelper) + tool, _ := PushFiles(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "push_files", tool.Name) assert.NotEmpty(t, tool.Description) @@ -1356,7 +1386,7 @@ func Test_PushFiles(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := pushFiles(client, translations.NullTranslationHelper) + _, handler := PushFiles(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) diff --git a/pkg/github/repository_resource.go b/pkg/github/repository_resource.go index 8b2ba7a7..949157f5 100644 --- a/pkg/github/repository_resource.go +++ b/pkg/github/repository_resource.go @@ -17,52 +17,53 @@ import ( "github.com/mark3labs/mcp-go/server" ) -// getRepositoryResourceContent defines the resource template and handler for getting repository content. -func getRepositoryResourceContent(client *github.Client, t translations.TranslationHelperFunc) (mcp.ResourceTemplate, server.ResourceTemplateHandlerFunc) { +// GetRepositoryResourceContent defines the resource template and handler for getting repository content. +func GetRepositoryResourceContent(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.ResourceTemplate, server.ResourceTemplateHandlerFunc) { return mcp.NewResourceTemplate( "repo://{owner}/{repo}/contents{/path*}", // Resource template t("RESOURCE_REPOSITORY_CONTENT_DESCRIPTION", "Repository Content"), ), - repositoryResourceContentsHandler(client) + RepositoryResourceContentsHandler(getClient) } -// getRepositoryContent defines the resource template and handler for getting repository content for a branch. -func getRepositoryResourceBranchContent(client *github.Client, t translations.TranslationHelperFunc) (mcp.ResourceTemplate, server.ResourceTemplateHandlerFunc) { +// GetRepositoryResourceBranchContent defines the resource template and handler for getting repository content for a branch. +func GetRepositoryResourceBranchContent(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.ResourceTemplate, server.ResourceTemplateHandlerFunc) { return mcp.NewResourceTemplate( "repo://{owner}/{repo}/refs/heads/{branch}/contents{/path*}", // Resource template t("RESOURCE_REPOSITORY_CONTENT_BRANCH_DESCRIPTION", "Repository Content for specific branch"), ), - repositoryResourceContentsHandler(client) + RepositoryResourceContentsHandler(getClient) } -// getRepositoryResourceCommitContent defines the resource template and handler for getting repository content for a commit. -func getRepositoryResourceCommitContent(client *github.Client, t translations.TranslationHelperFunc) (mcp.ResourceTemplate, server.ResourceTemplateHandlerFunc) { +// GetRepositoryResourceCommitContent defines the resource template and handler for getting repository content for a commit. +func GetRepositoryResourceCommitContent(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.ResourceTemplate, server.ResourceTemplateHandlerFunc) { return mcp.NewResourceTemplate( "repo://{owner}/{repo}/sha/{sha}/contents{/path*}", // Resource template t("RESOURCE_REPOSITORY_CONTENT_COMMIT_DESCRIPTION", "Repository Content for specific commit"), ), - repositoryResourceContentsHandler(client) + RepositoryResourceContentsHandler(getClient) } -// getRepositoryResourceTagContent defines the resource template and handler for getting repository content for a tag. -func getRepositoryResourceTagContent(client *github.Client, t translations.TranslationHelperFunc) (mcp.ResourceTemplate, server.ResourceTemplateHandlerFunc) { +// GetRepositoryResourceTagContent defines the resource template and handler for getting repository content for a tag. +func GetRepositoryResourceTagContent(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.ResourceTemplate, server.ResourceTemplateHandlerFunc) { return mcp.NewResourceTemplate( "repo://{owner}/{repo}/refs/tags/{tag}/contents{/path*}", // Resource template t("RESOURCE_REPOSITORY_CONTENT_TAG_DESCRIPTION", "Repository Content for specific tag"), ), - repositoryResourceContentsHandler(client) + RepositoryResourceContentsHandler(getClient) } -// getRepositoryResourcePrContent defines the resource template and handler for getting repository content for a pull request. -func getRepositoryResourcePrContent(client *github.Client, t translations.TranslationHelperFunc) (mcp.ResourceTemplate, server.ResourceTemplateHandlerFunc) { +// GetRepositoryResourcePrContent defines the resource template and handler for getting repository content for a pull request. +func GetRepositoryResourcePrContent(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.ResourceTemplate, server.ResourceTemplateHandlerFunc) { return mcp.NewResourceTemplate( "repo://{owner}/{repo}/refs/pull/{prNumber}/head/contents{/path*}", // Resource template t("RESOURCE_REPOSITORY_CONTENT_PR_DESCRIPTION", "Repository Content for specific pull request"), ), - repositoryResourceContentsHandler(client) + RepositoryResourceContentsHandler(getClient) } -func repositoryResourceContentsHandler(client *github.Client) func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { +// RepositoryResourceContentsHandler returns a handler function for repository content requests. +func RepositoryResourceContentsHandler(getClient GetClientFn) func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { return func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { // the matcher will give []string with one element // https://github.com/mark3labs/mcp-go/pull/54 @@ -106,6 +107,10 @@ func repositoryResourceContentsHandler(client *github.Client) func(ctx context.C opts.Ref = "refs/pull/" + prNumber[0] + "/head" } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } fileContent, directoryContent, _, err := client.Repositories.GetContents(ctx, owner, repo, path, opts) if err != nil { return nil, err diff --git a/pkg/github/repository_resource_test.go b/pkg/github/repository_resource_test.go index adad8744..ffd14be3 100644 --- a/pkg/github/repository_resource_test.go +++ b/pkg/github/repository_resource_test.go @@ -234,7 +234,7 @@ func Test_repositoryResourceContentsHandler(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - handler := repositoryResourceContentsHandler(client) + handler := RepositoryResourceContentsHandler((stubGetClientFn(client))) request := mcp.ReadResourceRequest{ Params: struct { @@ -258,26 +258,26 @@ func Test_repositoryResourceContentsHandler(t *testing.T) { } } -func Test_getRepositoryResourceContent(t *testing.T) { - tmpl, _ := getRepositoryResourceContent(nil, translations.NullTranslationHelper) +func Test_GetRepositoryResourceContent(t *testing.T) { + tmpl, _ := GetRepositoryResourceContent(nil, translations.NullTranslationHelper) require.Equal(t, "repo://{owner}/{repo}/contents{/path*}", tmpl.URITemplate.Raw()) } -func Test_getRepositoryResourceBranchContent(t *testing.T) { - tmpl, _ := getRepositoryResourceBranchContent(nil, translations.NullTranslationHelper) +func Test_GetRepositoryResourceBranchContent(t *testing.T) { + tmpl, _ := GetRepositoryResourceBranchContent(nil, translations.NullTranslationHelper) require.Equal(t, "repo://{owner}/{repo}/refs/heads/{branch}/contents{/path*}", tmpl.URITemplate.Raw()) } -func Test_getRepositoryResourceCommitContent(t *testing.T) { - tmpl, _ := getRepositoryResourceCommitContent(nil, translations.NullTranslationHelper) +func Test_GetRepositoryResourceCommitContent(t *testing.T) { + tmpl, _ := GetRepositoryResourceCommitContent(nil, translations.NullTranslationHelper) require.Equal(t, "repo://{owner}/{repo}/sha/{sha}/contents{/path*}", tmpl.URITemplate.Raw()) } -func Test_getRepositoryResourceTagContent(t *testing.T) { - tmpl, _ := getRepositoryResourceTagContent(nil, translations.NullTranslationHelper) +func Test_GetRepositoryResourceTagContent(t *testing.T) { + tmpl, _ := GetRepositoryResourceTagContent(nil, translations.NullTranslationHelper) require.Equal(t, "repo://{owner}/{repo}/refs/tags/{tag}/contents{/path*}", tmpl.URITemplate.Raw()) } -func Test_getRepositoryResourcePrContent(t *testing.T) { - tmpl, _ := getRepositoryResourcePrContent(nil, translations.NullTranslationHelper) +func Test_GetRepositoryResourcePrContent(t *testing.T) { + tmpl, _ := GetRepositoryResourcePrContent(nil, translations.NullTranslationHelper) require.Equal(t, "repo://{owner}/{repo}/refs/pull/{prNumber}/head/contents{/path*}", tmpl.URITemplate.Raw()) } diff --git a/pkg/github/search.go b/pkg/github/search.go index 117e8298..75810e24 100644 --- a/pkg/github/search.go +++ b/pkg/github/search.go @@ -12,22 +12,22 @@ import ( "github.com/mark3labs/mcp-go/server" ) -// searchRepositories creates a tool to search for GitHub repositories. -func searchRepositories(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +// SearchRepositories creates a tool to search for GitHub repositories. +func SearchRepositories(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("search_repositories", mcp.WithDescription(t("TOOL_SEARCH_REPOSITORIES_DESCRIPTION", "Search for GitHub repositories")), mcp.WithString("query", mcp.Required(), mcp.Description("Search query"), ), - withPagination(), + WithPagination(), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { query, err := requiredParam[string](request, "query") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - pagination, err := optionalPaginationParams(request) + pagination, err := OptionalPaginationParams(request) if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -39,6 +39,10 @@ func searchRepositories(client *github.Client, t translations.TranslationHelperF }, } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } result, resp, err := client.Search.Repositories(ctx, query, opts) if err != nil { return nil, fmt.Errorf("failed to search repositories: %w", err) @@ -62,8 +66,8 @@ func searchRepositories(client *github.Client, t translations.TranslationHelperF } } -// searchCode creates a tool to search for code across GitHub repositories. -func searchCode(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +// SearchCode creates a tool to search for code across GitHub repositories. +func SearchCode(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("search_code", mcp.WithDescription(t("TOOL_SEARCH_CODE_DESCRIPTION", "Search for code across GitHub repositories")), mcp.WithString("q", @@ -77,22 +81,22 @@ func searchCode(client *github.Client, t translations.TranslationHelperFunc) (to mcp.Description("Sort order ('asc' or 'desc')"), mcp.Enum("asc", "desc"), ), - withPagination(), + WithPagination(), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { query, err := requiredParam[string](request, "q") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - sort, err := optionalParam[string](request, "sort") + sort, err := OptionalParam[string](request, "sort") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - order, err := optionalParam[string](request, "order") + order, err := OptionalParam[string](request, "order") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - pagination, err := optionalPaginationParams(request) + pagination, err := OptionalPaginationParams(request) if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -106,6 +110,11 @@ func searchCode(client *github.Client, t translations.TranslationHelperFunc) (to }, } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + result, resp, err := client.Search.Code(ctx, query, opts) if err != nil { return nil, fmt.Errorf("failed to search code: %w", err) @@ -129,8 +138,8 @@ func searchCode(client *github.Client, t translations.TranslationHelperFunc) (to } } -// searchUsers creates a tool to search for GitHub users. -func searchUsers(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +// SearchUsers creates a tool to search for GitHub users. +func SearchUsers(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("search_users", mcp.WithDescription(t("TOOL_SEARCH_USERS_DESCRIPTION", "Search for GitHub users")), mcp.WithString("q", @@ -145,22 +154,22 @@ func searchUsers(client *github.Client, t translations.TranslationHelperFunc) (t mcp.Description("Sort order ('asc' or 'desc')"), mcp.Enum("asc", "desc"), ), - withPagination(), + WithPagination(), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { query, err := requiredParam[string](request, "q") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - sort, err := optionalParam[string](request, "sort") + sort, err := OptionalParam[string](request, "sort") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - order, err := optionalParam[string](request, "order") + order, err := OptionalParam[string](request, "order") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - pagination, err := optionalPaginationParams(request) + pagination, err := OptionalPaginationParams(request) if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -174,6 +183,11 @@ func searchUsers(client *github.Client, t translations.TranslationHelperFunc) (t }, } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + result, resp, err := client.Search.Users(ctx, query, opts) if err != nil { return nil, fmt.Errorf("failed to search users: %w", err) diff --git a/pkg/github/search_test.go b/pkg/github/search_test.go index bf1bff45..b61518e4 100644 --- a/pkg/github/search_test.go +++ b/pkg/github/search_test.go @@ -16,7 +16,7 @@ import ( func Test_SearchRepositories(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := searchRepositories(mockClient, translations.NullTranslationHelper) + tool, _ := SearchRepositories(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "search_repositories", tool.Name) assert.NotEmpty(t, tool.Description) @@ -122,7 +122,7 @@ func Test_SearchRepositories(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := searchRepositories(client, translations.NullTranslationHelper) + _, handler := SearchRepositories(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -163,7 +163,7 @@ func Test_SearchRepositories(t *testing.T) { func Test_SearchCode(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := searchCode(mockClient, translations.NullTranslationHelper) + tool, _ := SearchCode(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "search_code", tool.Name) assert.NotEmpty(t, tool.Description) @@ -273,7 +273,7 @@ func Test_SearchCode(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := searchCode(client, translations.NullTranslationHelper) + _, handler := SearchCode(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -314,7 +314,7 @@ func Test_SearchCode(t *testing.T) { func Test_SearchUsers(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := searchUsers(mockClient, translations.NullTranslationHelper) + tool, _ := SearchUsers(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "search_users", tool.Name) assert.NotEmpty(t, tool.Description) @@ -428,7 +428,7 @@ func Test_SearchUsers(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := searchUsers(client, translations.NullTranslationHelper) + _, handler := SearchUsers(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) diff --git a/pkg/github/server.go b/pkg/github/server.go index ce06999a..2d252b29 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -14,8 +14,10 @@ import ( "github.com/mark3labs/mcp-go/server" ) +type GetClientFn func(context.Context) (*github.Client, error) + // NewServer creates a new GitHub MCP server with the specified GH client and logger. -func NewServer(client *github.Client, version string, readOnly bool, t translations.TranslationHelperFunc) *server.MCPServer { +func NewServer(getClient GetClientFn, version string, readOnly bool, t translations.TranslationHelperFunc) *server.MCPServer { // Create a new MCP server s := server.NewMCPServer( "github-mcp-server", @@ -24,65 +26,66 @@ func NewServer(client *github.Client, version string, readOnly bool, t translati server.WithLogging()) // Add GitHub Resources - s.AddResourceTemplate(getRepositoryResourceContent(client, t)) - s.AddResourceTemplate(getRepositoryResourceBranchContent(client, t)) - s.AddResourceTemplate(getRepositoryResourceCommitContent(client, t)) - s.AddResourceTemplate(getRepositoryResourceTagContent(client, t)) - s.AddResourceTemplate(getRepositoryResourcePrContent(client, t)) + s.AddResourceTemplate(GetRepositoryResourceContent(getClient, t)) + s.AddResourceTemplate(GetRepositoryResourceBranchContent(getClient, t)) + s.AddResourceTemplate(GetRepositoryResourceCommitContent(getClient, t)) + s.AddResourceTemplate(GetRepositoryResourceTagContent(getClient, t)) + s.AddResourceTemplate(GetRepositoryResourcePrContent(getClient, t)) // Add GitHub tools - Issues - s.AddTool(getIssue(client, t)) - s.AddTool(searchIssues(client, t)) - s.AddTool(listIssues(client, t)) - s.AddTool(getIssueComments(client, t)) + s.AddTool(GetIssue(getClient, t)) + s.AddTool(SearchIssues(getClient, t)) + s.AddTool(ListIssues(getClient, t)) + s.AddTool(GetIssueComments(getClient, t)) if !readOnly { - s.AddTool(createIssue(client, t)) - s.AddTool(addIssueComment(client, t)) - s.AddTool(updateIssue(client, t)) + s.AddTool(CreateIssue(getClient, t)) + s.AddTool(AddIssueComment(getClient, t)) + s.AddTool(UpdateIssue(getClient, t)) } // Add GitHub tools - Pull Requests - s.AddTool(getPullRequest(client, t)) - s.AddTool(listPullRequests(client, t)) - s.AddTool(getPullRequestFiles(client, t)) - s.AddTool(getPullRequestStatus(client, t)) - s.AddTool(getPullRequestComments(client, t)) - s.AddTool(getPullRequestReviews(client, t)) + s.AddTool(GetPullRequest(getClient, t)) + s.AddTool(ListPullRequests(getClient, t)) + s.AddTool(GetPullRequestFiles(getClient, t)) + s.AddTool(GetPullRequestStatus(getClient, t)) + s.AddTool(GetPullRequestComments(getClient, t)) + s.AddTool(GetPullRequestReviews(getClient, t)) if !readOnly { - s.AddTool(mergePullRequest(client, t)) - s.AddTool(updatePullRequestBranch(client, t)) - s.AddTool(createPullRequestReview(client, t)) - s.AddTool(createPullRequest(client, t)) + s.AddTool(MergePullRequest(getClient, t)) + s.AddTool(UpdatePullRequestBranch(getClient, t)) + s.AddTool(CreatePullRequestReview(getClient, t)) + s.AddTool(CreatePullRequest(getClient, t)) + s.AddTool(UpdatePullRequest(getClient, t)) } // Add GitHub tools - Repositories - s.AddTool(searchRepositories(client, t)) - s.AddTool(getFileContents(client, t)) - s.AddTool(getCommit(client, t)) - s.AddTool(listCommits(client, t)) + s.AddTool(SearchRepositories(getClient, t)) + s.AddTool(GetFileContents(getClient, t)) + s.AddTool(GetCommit(getClient, t)) + s.AddTool(ListCommits(getClient, t)) if !readOnly { - s.AddTool(createOrUpdateFile(client, t)) - s.AddTool(createRepository(client, t)) - s.AddTool(forkRepository(client, t)) - s.AddTool(createBranch(client, t)) - s.AddTool(pushFiles(client, t)) + s.AddTool(CreateOrUpdateFile(getClient, t)) + s.AddTool(CreateRepository(getClient, t)) + s.AddTool(ForkRepository(getClient, t)) + s.AddTool(CreateBranch(getClient, t)) + s.AddTool(PushFiles(getClient, t)) } // Add GitHub tools - Search - s.AddTool(searchCode(client, t)) - s.AddTool(searchUsers(client, t)) + s.AddTool(SearchCode(getClient, t)) + s.AddTool(SearchUsers(getClient, t)) // Add GitHub tools - Users - s.AddTool(getMe(client, t)) + s.AddTool(GetMe(getClient, t)) // Add GitHub tools - Code Scanning - s.AddTool(getCodeScanningAlert(client, t)) - s.AddTool(listCodeScanningAlerts(client, t)) + s.AddTool(GetCodeScanningAlert(getClient, t)) + s.AddTool(ListCodeScanningAlerts(getClient, t)) return s } -// getMe creates a tool to get details of the authenticated user. -func getMe(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +// GetMe creates a tool to get details of the authenticated user. +func GetMe(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("get_me", mcp.WithDescription(t("TOOL_GET_ME_DESCRIPTION", "Get details of the authenticated GitHub user. Use this when a request include \"me\", \"my\"...")), mcp.WithString("reason", @@ -90,6 +93,10 @@ func getMe(client *github.Client, t translations.TranslationHelperFunc) (tool mc ), ), func(ctx context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) { + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } user, resp, err := client.Users.Get(ctx, "") if err != nil { return nil, fmt.Errorf("failed to get user: %w", err) @@ -113,6 +120,30 @@ func getMe(client *github.Client, t translations.TranslationHelperFunc) (tool mc } } +// OptionalParamOK is a helper function that can be used to fetch a requested parameter from the request. +// It returns the value, a boolean indicating if the parameter was present, and an error if the type is wrong. +func OptionalParamOK[T any](r mcp.CallToolRequest, p string) (value T, ok bool, err error) { + // Check if the parameter is present in the request + val, exists := r.Params.Arguments[p] + if !exists { + // Not present, return zero value, false, no error + return + } + + // Check if the parameter is of the expected type + value, ok = val.(T) + if !ok { + // Present but wrong type + err = fmt.Errorf("parameter %s is not of type %T, is %T", p, value, val) + ok = true // Set ok to true because the parameter *was* present, even if wrong type + return + } + + // Present and correct type + ok = true + return +} + // isAcceptedError checks if the error is an accepted error. func isAcceptedError(err error) bool { var acceptedError *github.AcceptedError @@ -145,12 +176,12 @@ func requiredParam[T comparable](r mcp.CallToolRequest, p string) (T, error) { return r.Params.Arguments[p].(T), nil } -// requiredInt is a helper function that can be used to fetch a requested parameter from the request. +// RequiredInt is a helper function that can be used to fetch a requested parameter from the request. // It does the following checks: // 1. Checks if the parameter is present in the request. // 2. Checks if the parameter is of the expected type. // 3. Checks if the parameter is not empty, i.e: non-zero value -func requiredInt(r mcp.CallToolRequest, p string) (int, error) { +func RequiredInt(r mcp.CallToolRequest, p string) (int, error) { v, err := requiredParam[float64](r, p) if err != nil { return 0, err @@ -158,11 +189,11 @@ func requiredInt(r mcp.CallToolRequest, p string) (int, error) { return int(v), nil } -// optionalParam is a helper function that can be used to fetch a requested parameter from the request. +// OptionalParam is a helper function that can be used to fetch a requested parameter from the request. // It does the following checks: // 1. Checks if the parameter is present in the request, if not, it returns its zero-value // 2. If it is present, it checks if the parameter is of the expected type and returns it -func optionalParam[T any](r mcp.CallToolRequest, p string) (T, error) { +func OptionalParam[T any](r mcp.CallToolRequest, p string) (T, error) { var zero T // Check if the parameter is present in the request @@ -178,22 +209,22 @@ func optionalParam[T any](r mcp.CallToolRequest, p string) (T, error) { return r.Params.Arguments[p].(T), nil } -// optionalIntParam is a helper function that can be used to fetch a requested parameter from the request. +// OptionalIntParam is a helper function that can be used to fetch a requested parameter from the request. // It does the following checks: // 1. Checks if the parameter is present in the request, if not, it returns its zero-value // 2. If it is present, it checks if the parameter is of the expected type and returns it -func optionalIntParam(r mcp.CallToolRequest, p string) (int, error) { - v, err := optionalParam[float64](r, p) +func OptionalIntParam(r mcp.CallToolRequest, p string) (int, error) { + v, err := OptionalParam[float64](r, p) if err != nil { return 0, err } return int(v), nil } -// optionalIntParamWithDefault is a helper function that can be used to fetch a requested parameter from the request +// OptionalIntParamWithDefault is a helper function that can be used to fetch a requested parameter from the request // similar to optionalIntParam, but it also takes a default value. -func optionalIntParamWithDefault(r mcp.CallToolRequest, p string, d int) (int, error) { - v, err := optionalIntParam(r, p) +func OptionalIntParamWithDefault(r mcp.CallToolRequest, p string, d int) (int, error) { + v, err := OptionalIntParam(r, p) if err != nil { return 0, err } @@ -203,17 +234,19 @@ func optionalIntParamWithDefault(r mcp.CallToolRequest, p string, d int) (int, e return v, nil } -// optionalStringArrayParam is a helper function that can be used to fetch a requested parameter from the request. +// OptionalStringArrayParam is a helper function that can be used to fetch a requested parameter from the request. // It does the following checks: // 1. Checks if the parameter is present in the request, if not, it returns its zero-value // 2. If it is present, iterates the elements and checks each is a string -func optionalStringArrayParam(r mcp.CallToolRequest, p string) ([]string, error) { +func OptionalStringArrayParam(r mcp.CallToolRequest, p string) ([]string, error) { // Check if the parameter is present in the request if _, ok := r.Params.Arguments[p]; !ok { return []string{}, nil } switch v := r.Params.Arguments[p].(type) { + case nil: + return []string{}, nil case []string: return v, nil case []any: @@ -231,9 +264,9 @@ func optionalStringArrayParam(r mcp.CallToolRequest, p string) ([]string, error) } } -// withPagination returns a ToolOption that adds "page" and "perPage" parameters to the tool. +// WithPagination returns a ToolOption that adds "page" and "perPage" parameters to the tool. // The "page" parameter is optional, min 1. The "perPage" parameter is optional, min 1, max 100. -func withPagination() mcp.ToolOption { +func WithPagination() mcp.ToolOption { return func(tool *mcp.Tool) { mcp.WithNumber("page", mcp.Description("Page number for pagination (min 1)"), @@ -248,26 +281,26 @@ func withPagination() mcp.ToolOption { } } -type paginationParams struct { +type PaginationParams struct { page int perPage int } -// optionalPaginationParams returns the "page" and "perPage" parameters from the request, +// OptionalPaginationParams returns the "page" and "perPage" parameters from the request, // or their default values if not present, "page" default is 1, "perPage" default is 30. // In future, we may want to make the default values configurable, or even have this // function returned from `withPagination`, where the defaults are provided alongside // the min/max values. -func optionalPaginationParams(r mcp.CallToolRequest) (paginationParams, error) { - page, err := optionalIntParamWithDefault(r, "page", 1) +func OptionalPaginationParams(r mcp.CallToolRequest) (PaginationParams, error) { + page, err := OptionalIntParamWithDefault(r, "page", 1) if err != nil { - return paginationParams{}, err + return PaginationParams{}, err } - perPage, err := optionalIntParamWithDefault(r, "perPage", 30) + perPage, err := OptionalIntParamWithDefault(r, "perPage", 30) if err != nil { - return paginationParams{}, err + return PaginationParams{}, err } - return paginationParams{ + return PaginationParams{ page: page, perPage: perPage, }, nil diff --git a/pkg/github/server_test.go b/pkg/github/server_test.go index 149fb77a..3ee9851a 100644 --- a/pkg/github/server_test.go +++ b/pkg/github/server_test.go @@ -15,10 +15,16 @@ import ( "github.com/stretchr/testify/require" ) +func stubGetClientFn(client *github.Client) GetClientFn { + return func(_ context.Context) (*github.Client, error) { + return client, nil + } +} + func Test_GetMe(t *testing.T) { // Verify tool definition mockClient := github.NewClient(nil) - tool, _ := getMe(mockClient, translations.NullTranslationHelper) + tool, _ := GetMe(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "get_me", tool.Name) assert.NotEmpty(t, tool.Description) @@ -96,7 +102,7 @@ func Test_GetMe(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := getMe(client, translations.NullTranslationHelper) + _, handler := GetMe(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -262,7 +268,7 @@ func Test_OptionalStringParam(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { request := createMCPRequest(tc.params) - result, err := optionalParam[string](request, tc.paramName) + result, err := OptionalParam[string](request, tc.paramName) if tc.expectError { assert.Error(t, err) @@ -308,7 +314,7 @@ func Test_RequiredNumberParam(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { request := createMCPRequest(tc.params) - result, err := requiredInt(request, tc.paramName) + result, err := RequiredInt(request, tc.paramName) if tc.expectError { assert.Error(t, err) @@ -361,7 +367,7 @@ func Test_OptionalNumberParam(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { request := createMCPRequest(tc.params) - result, err := optionalIntParam(request, tc.paramName) + result, err := OptionalIntParam(request, tc.paramName) if tc.expectError { assert.Error(t, err) @@ -419,7 +425,7 @@ func Test_OptionalNumberParamWithDefault(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { request := createMCPRequest(tc.params) - result, err := optionalIntParamWithDefault(request, tc.paramName, tc.defaultVal) + result, err := OptionalIntParamWithDefault(request, tc.paramName, tc.defaultVal) if tc.expectError { assert.Error(t, err) @@ -472,7 +478,7 @@ func Test_OptionalBooleanParam(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { request := createMCPRequest(tc.params) - result, err := optionalParam[bool](request, tc.paramName) + result, err := OptionalParam[bool](request, tc.paramName) if tc.expectError { assert.Error(t, err) @@ -540,7 +546,7 @@ func TestOptionalStringArrayParam(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { request := createMCPRequest(tc.params) - result, err := optionalStringArrayParam(request, tc.paramName) + result, err := OptionalStringArrayParam(request, tc.paramName) if tc.expectError { assert.Error(t, err) @@ -556,13 +562,13 @@ func TestOptionalPaginationParams(t *testing.T) { tests := []struct { name string params map[string]any - expected paginationParams + expected PaginationParams expectError bool }{ { name: "no pagination parameters, default values", params: map[string]any{}, - expected: paginationParams{ + expected: PaginationParams{ page: 1, perPage: 30, }, @@ -573,7 +579,7 @@ func TestOptionalPaginationParams(t *testing.T) { params: map[string]any{ "page": float64(2), }, - expected: paginationParams{ + expected: PaginationParams{ page: 2, perPage: 30, }, @@ -584,7 +590,7 @@ func TestOptionalPaginationParams(t *testing.T) { params: map[string]any{ "perPage": float64(50), }, - expected: paginationParams{ + expected: PaginationParams{ page: 1, perPage: 50, }, @@ -596,7 +602,7 @@ func TestOptionalPaginationParams(t *testing.T) { "page": float64(2), "perPage": float64(50), }, - expected: paginationParams{ + expected: PaginationParams{ page: 2, perPage: 50, }, @@ -607,7 +613,7 @@ func TestOptionalPaginationParams(t *testing.T) { params: map[string]any{ "page": "not-a-number", }, - expected: paginationParams{}, + expected: PaginationParams{}, expectError: true, }, { @@ -615,7 +621,7 @@ func TestOptionalPaginationParams(t *testing.T) { params: map[string]any{ "perPage": "not-a-number", }, - expected: paginationParams{}, + expected: PaginationParams{}, expectError: true, }, } @@ -623,7 +629,7 @@ func TestOptionalPaginationParams(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { request := createMCPRequest(tc.params) - result, err := optionalPaginationParams(request) + result, err := OptionalPaginationParams(request) if tc.expectError { assert.Error(t, err) diff --git a/script/licenses b/script/licenses index f231a458..c7f8ed4c 100755 --- a/script/licenses +++ b/script/licenses @@ -10,7 +10,7 @@ trap "rm -fr ${TEMPDIR}" EXIT for goos in linux darwin windows ; do # Note: we ignore warnings because we want the command to succeed, however the output should be checked - # for any new warnings, and potentially we may need to add licence information. + # for any new warnings, and potentially we may need to add license information. # # Normally these warnings are packages containing non go code, which may or may not require explicit attribution, # depending on the license. diff --git a/script/licenses-check b/script/licenses-check index 369277ca..5ad93027 100755 --- a/script/licenses-check +++ b/script/licenses-check @@ -4,7 +4,7 @@ go install github.com/google/go-licenses@latest for goos in linux darwin windows ; do # Note: we ignore warnings because we want the command to succeed, however the output should be checked - # for any new warnings, and potentially we may need to add licence information. + # for any new warnings, and potentially we may need to add license information. # # Normally these warnings are packages containing non go code, which may or may not require explicit attribution, # depending on the license.