From 4d7388ae1938f03348ae742bd008872e2a9a2a8c Mon Sep 17 00:00:00 2001 From: Javier Uruen Val Date: Mon, 24 Mar 2025 07:39:01 +0100 Subject: [PATCH 1/2] validate tools params --- pkg/github/code_scanning.go | 40 +++- pkg/github/issues.go | 197 +++++++++++++------ pkg/github/issues_test.go | 26 ++- pkg/github/pullrequests.go | 215 ++++++++++++++------- pkg/github/repositories.go | 174 +++++++++++------ pkg/github/search.go | 81 ++++---- pkg/github/server.go | 115 +++++++++++ pkg/github/server_test.go | 369 ++++++++++++++++++++++++++++++++++++ 8 files changed, 989 insertions(+), 228 deletions(-) diff --git a/pkg/github/code_scanning.go b/pkg/github/code_scanning.go index 6fc0936a..e7c8a4e2 100644 --- a/pkg/github/code_scanning.go +++ b/pkg/github/code_scanning.go @@ -30,9 +30,18 @@ func getCodeScanningAlert(client *github.Client, t translations.TranslationHelpe ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, _ := request.Params.Arguments["owner"].(string) - repo, _ := request.Params.Arguments["repo"].(string) - alertNumber, _ := request.Params.Arguments["alert_number"].(float64) + owner, err := requiredStringParam(request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredStringParam(request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + alertNumber, err := requiredNumberParam(request, "alert_number") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } alert, resp, err := client.CodeScanning.GetAlert(ctx, owner, repo, int64(alertNumber)) if err != nil { @@ -80,11 +89,26 @@ func listCodeScanningAlerts(client *github.Client, t translations.TranslationHel ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, _ := request.Params.Arguments["owner"].(string) - repo, _ := request.Params.Arguments["repo"].(string) - ref, _ := request.Params.Arguments["ref"].(string) - state, _ := request.Params.Arguments["state"].(string) - severity, _ := request.Params.Arguments["severity"].(string) + owner, err := requiredStringParam(request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredStringParam(request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + ref, err := optionalStringParam(request, "ref") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + state, err := optionalStringParam(request, "state") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + severity, err := optionalStringParam(request, "severity") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } alerts, resp, err := client.CodeScanning.ListAlertsForRepo(ctx, owner, repo, &github.AlertListOptions{Ref: ref, State: state, Severity: severity}) if err != nil { diff --git a/pkg/github/issues.go b/pkg/github/issues.go index 36130b98..521acfda 100644 --- a/pkg/github/issues.go +++ b/pkg/github/issues.go @@ -32,9 +32,18 @@ func getIssue(client *github.Client, t translations.TranslationHelperFunc) (tool ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - issueNumber := int(request.Params.Arguments["issue_number"].(float64)) + owner, err := requiredStringParam(request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredStringParam(request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + issueNumber, err := requiredNumberParam(request, "issue_number") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } issue, resp, err := client.Issues.Get(ctx, owner, repo, issueNumber) if err != nil { @@ -81,10 +90,22 @@ func addIssueComment(client *github.Client, t translations.TranslationHelperFunc ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - issueNumber := int(request.Params.Arguments["issue_number"].(float64)) - body := request.Params.Arguments["body"].(string) + owner, err := requiredStringParam(request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredStringParam(request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + issueNumber, err := requiredNumberParam(request, "issue_number") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + body, err := requiredStringParam(request, "body") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } comment := &github.IssueComment{ Body: github.Ptr(body), @@ -135,22 +156,25 @@ func searchIssues(client *github.Client, t translations.TranslationHelperFunc) ( ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - query := request.Params.Arguments["q"].(string) - sort := "" - if s, ok := request.Params.Arguments["sort"].(string); ok { - sort = s + query, err := requiredStringParam(request, "q") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - order := "" - if o, ok := request.Params.Arguments["order"].(string); ok { - order = o + sort, err := optionalStringParam(request, "sort") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - perPage := 30 - if pp, ok := request.Params.Arguments["per_page"].(float64); ok { - perPage = int(pp) + order, err := optionalStringParam(request, "order") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - page := 1 - if p, ok := request.Params.Arguments["page"].(float64); ok { - page = int(p) + perPage, err := optionalNumberParamWithDefault(request, "per_page", 30) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + page, err := optionalNumberParamWithDefault(request, "page", 1) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } opts := &github.SearchOptions{ @@ -212,26 +236,34 @@ func createIssue(client *github.Client, t translations.TranslationHelperFunc) (t ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - title := request.Params.Arguments["title"].(string) + owner, err := requiredStringParam(request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredStringParam(request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + title, err := requiredStringParam(request, "title") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } // Optional parameters - var body string - if b, ok := request.Params.Arguments["body"].(string); ok { - body = b + body, err := optionalStringParam(request, "body") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - // Parse assignees if present - assignees := []string{} // default to empty slice, can't be nil - if a, ok := request.Params.Arguments["assignees"].(string); ok && a != "" { - assignees = parseCommaSeparatedList(a) + // Get assignees + assignees, err := optionalCommaSeparatedListParam(request, "assignees") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - - // Parse labels if present - labels := []string{} // default to empty slice, can't be nil - if l, ok := request.Params.Arguments["labels"].(string); ok && l != "" { - labels = parseCommaSeparatedList(l) + // Get labels + labels, err := optionalCommaSeparatedListParam(request, "labels") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } // Create the issue request @@ -300,29 +332,43 @@ func listIssues(client *github.Client, t translations.TranslationHelperFunc) (to ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) + owner, err := requiredStringParam(request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredStringParam(request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } opts := &github.IssueListByRepoOptions{} // Set optional parameters if provided - if state, ok := request.Params.Arguments["state"].(string); ok && state != "" { - opts.State = state + opts.State, err = optionalStringParam(request, "state") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - if labels, ok := request.Params.Arguments["labels"].(string); ok && labels != "" { - opts.Labels = parseCommaSeparatedList(labels) + opts.Labels, err = optionalCommaSeparatedListParam(request, "labels") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - if sort, ok := request.Params.Arguments["sort"].(string); ok && sort != "" { - opts.Sort = sort + opts.Sort, err = optionalStringParam(request, "sort") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - if direction, ok := request.Params.Arguments["direction"].(string); ok && direction != "" { - opts.Direction = direction + opts.Direction, err = optionalStringParam(request, "direction") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - if since, ok := request.Params.Arguments["since"].(string); ok && since != "" { + since, err := optionalStringParam(request, "since") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if since != "" { timestamp, err := parseISOTimestamp(since) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("failed to list issues: %s", err.Error())), nil @@ -397,38 +443,69 @@ func updateIssue(client *github.Client, t translations.TranslationHelperFunc) (t ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - issueNumber := int(request.Params.Arguments["issue_number"].(float64)) + owner, err := requiredStringParam(request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredStringParam(request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + issueNumber, err := requiredNumberParam(request, "issue_number") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } // Create the issue request with only provided fields issueRequest := &github.IssueRequest{} // Set optional parameters if provided - if title, ok := request.Params.Arguments["title"].(string); ok && title != "" { + title, err := optionalStringParam(request, "title") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if title != "" { issueRequest.Title = github.Ptr(title) } - if body, ok := request.Params.Arguments["body"].(string); ok && body != "" { + body, err := optionalStringParam(request, "body") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if body != "" { issueRequest.Body = github.Ptr(body) } - if state, ok := request.Params.Arguments["state"].(string); ok && state != "" { + state, err := optionalStringParam(request, "state") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if state != "" { issueRequest.State = github.Ptr(state) } - if labels, ok := request.Params.Arguments["labels"].(string); ok && labels != "" { - labelsList := parseCommaSeparatedList(labels) - issueRequest.Labels = &labelsList + labels, err := optionalCommaSeparatedListParam(request, "labels") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if len(labels) > 0 { + issueRequest.Labels = &labels } - if assignees, ok := request.Params.Arguments["assignees"].(string); ok && assignees != "" { - assigneesList := parseCommaSeparatedList(assignees) - issueRequest.Assignees = &assigneesList + assignees, err := optionalCommaSeparatedListParam(request, "assignees") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if len(assignees) > 0 { + issueRequest.Assignees = &assignees } - if milestone, ok := request.Params.Arguments["milestone"].(float64); ok { - milestoneNum := int(milestone) + milestone, err := optionalNumberParam(request, "milestone") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if milestone != 0 { + milestoneNum := milestone issueRequest.Milestone = &milestoneNum } diff --git a/pkg/github/issues_test.go b/pkg/github/issues_test.go index 4e8250fd..c2de6579 100644 --- a/pkg/github/issues_test.go +++ b/pkg/github/issues_test.go @@ -176,8 +176,8 @@ func Test_AddIssueComment(t *testing.T) { "issue_number": float64(42), "body": "", }, - expectError: true, - expectedErrMsg: "failed to create comment", + expectError: false, + expectedErrMsg: "missing required parameter: body", }, } @@ -210,6 +210,13 @@ func Test_AddIssueComment(t *testing.T) { return } + if tc.expectedErrMsg != "" { + require.NotNil(t, result) + textContent := getTextResult(t, result) + assert.Contains(t, textContent.Text, tc.expectedErrMsg) + return + } + require.NoError(t, err) // Parse the result and get the text content if no error @@ -419,8 +426,8 @@ func Test_CreateIssue(t *testing.T) { "repo": "repo", "title": "Test Issue", "body": "This is a test issue", - "assignees": []interface{}{"user1", "user2"}, - "labels": []interface{}{"bug", "help wanted"}, + "assignees": "user1, user2", + "labels": "bug, help wanted", }, expectError: false, expectedIssue: mockIssue, @@ -467,8 +474,8 @@ func Test_CreateIssue(t *testing.T) { "repo": "repo", "title": "", }, - expectError: true, - expectedErrMsg: "failed to create issue", + expectError: false, + expectedErrMsg: "missing required parameter: title", }, } @@ -491,6 +498,13 @@ func Test_CreateIssue(t *testing.T) { return } + if tc.expectedErrMsg != "" { + require.NotNil(t, result) + textContent := getTextResult(t, result) + assert.Contains(t, textContent.Text, tc.expectedErrMsg) + return + } + require.NoError(t, err) textContent := getTextResult(t, result) diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index e0414394..d5caab97 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -31,9 +31,18 @@ func getPullRequest(client *github.Client, t translations.TranslationHelperFunc) ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - pullNumber := int(request.Params.Arguments["pull_number"].(float64)) + owner, err := requiredStringParam(request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredStringParam(request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pullNumber, err := requiredNumberParam(request, "pull_number") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } pr, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber) if err != nil { @@ -93,35 +102,41 @@ func listPullRequests(client *github.Client, t translations.TranslationHelperFun ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - state := "" - if s, ok := request.Params.Arguments["state"].(string); ok { - state = s + owner, err := requiredStringParam(request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredStringParam(request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - head := "" - if h, ok := request.Params.Arguments["head"].(string); ok { - head = h + state, err := optionalStringParam(request, "state") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - base := "" - if b, ok := request.Params.Arguments["base"].(string); ok { - base = b + head, err := optionalStringParam(request, "head") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - sort := "" - if s, ok := request.Params.Arguments["sort"].(string); ok { - sort = s + base, err := optionalStringParam(request, "base") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - direction := "" - if d, ok := request.Params.Arguments["direction"].(string); ok { - direction = d + sort, err := optionalStringParam(request, "sort") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - perPage := 30 - if pp, ok := request.Params.Arguments["per_page"].(float64); ok { - perPage = int(pp) + direction, err := optionalStringParam(request, "direction") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - page := 1 - if p, ok := request.Params.Arguments["page"].(float64); ok { - page = int(p) + perPage, err := optionalNumberParamWithDefault(request, "per_page", 30) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + page, err := optionalNumberParamWithDefault(request, "page", 1) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } opts := &github.PullRequestListOptions{ @@ -186,20 +201,29 @@ func mergePullRequest(client *github.Client, t translations.TranslationHelperFun ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - pullNumber := int(request.Params.Arguments["pull_number"].(float64)) - commitTitle := "" - if ct, ok := request.Params.Arguments["commit_title"].(string); ok { - commitTitle = ct + owner, err := requiredStringParam(request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - commitMessage := "" - if cm, ok := request.Params.Arguments["commit_message"].(string); ok { - commitMessage = cm + repo, err := requiredStringParam(request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - mergeMethod := "" - if mm, ok := request.Params.Arguments["merge_method"].(string); ok { - mergeMethod = mm + pullNumber, err := requiredNumberParam(request, "pull_number") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + commitTitle, err := optionalStringParam(request, "commit_title") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + commitMessage, err := optionalStringParam(request, "commit_message") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + mergeMethod, err := optionalStringParam(request, "merge_method") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } options := &github.PullRequestOptions{ @@ -248,9 +272,18 @@ func getPullRequestFiles(client *github.Client, t translations.TranslationHelper ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - pullNumber := int(request.Params.Arguments["pull_number"].(float64)) + owner, err := requiredStringParam(request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredStringParam(request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pullNumber, err := requiredNumberParam(request, "pull_number") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } opts := &github.ListOptions{} files, resp, err := client.PullRequests.ListFiles(ctx, owner, repo, pullNumber, opts) @@ -294,10 +327,18 @@ func getPullRequestStatus(client *github.Client, t translations.TranslationHelpe ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - pullNumber := int(request.Params.Arguments["pull_number"].(float64)) - + owner, err := requiredStringParam(request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredStringParam(request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pullNumber, err := requiredNumberParam(request, "pull_number") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } // First get the PR to find the head SHA pr, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber) if err != nil { @@ -358,14 +399,22 @@ func updatePullRequestBranch(client *github.Client, t translations.TranslationHe ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - pullNumber := int(request.Params.Arguments["pull_number"].(float64)) - expectedHeadSHA := "" - if sha, ok := request.Params.Arguments["expected_head_sha"].(string); ok { - expectedHeadSHA = sha + owner, err := requiredStringParam(request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredStringParam(request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pullNumber, err := requiredNumberParam(request, "pull_number") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + expectedHeadSHA, err := optionalStringParam(request, "expected_head_sha") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - opts := &github.PullRequestBranchUpdateOptions{} if expectedHeadSHA != "" { opts.ExpectedHeadSHA = github.Ptr(expectedHeadSHA) @@ -417,9 +466,18 @@ func getPullRequestComments(client *github.Client, t translations.TranslationHel ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - pullNumber := int(request.Params.Arguments["pull_number"].(float64)) + owner, err := requiredStringParam(request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredStringParam(request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pullNumber, err := requiredNumberParam(request, "pull_number") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } opts := &github.PullRequestListCommentsOptions{ ListOptions: github.ListOptions{ @@ -468,9 +526,18 @@ func getPullRequestReviews(client *github.Client, t translations.TranslationHelp ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - pullNumber := int(request.Params.Arguments["pull_number"].(float64)) + owner, err := requiredStringParam(request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredStringParam(request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pullNumber, err := requiredNumberParam(request, "pull_number") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } reviews, resp, err := client.PullRequests.ListReviews(ctx, owner, repo, pullNumber, nil) if err != nil { @@ -526,10 +593,22 @@ func createPullRequestReview(client *github.Client, t translations.TranslationHe ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - pullNumber := int(request.Params.Arguments["pull_number"].(float64)) - event := request.Params.Arguments["event"].(string) + owner, err := requiredStringParam(request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredStringParam(request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pullNumber, err := requiredNumberParam(request, "pull_number") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + event, err := requiredStringParam(request, "event") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } // Create review request reviewRequest := &github.PullRequestReviewRequest{ @@ -537,12 +616,20 @@ func createPullRequestReview(client *github.Client, t translations.TranslationHe } // Add body if provided - if body, ok := request.Params.Arguments["body"].(string); ok && body != "" { + body, err := optionalStringParam(request, "body") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if body != "" { reviewRequest.Body = github.Ptr(body) } // Add commit ID if provided - if commitID, ok := request.Params.Arguments["commit_id"].(string); ok && commitID != "" { + commitID, err := optionalStringParam(request, "commit_id") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if commitID != "" { reviewRequest.CommitID = github.Ptr(commitID) } diff --git a/pkg/github/repositories.go b/pkg/github/repositories.go index 6e3b176d..f222b1f8 100644 --- a/pkg/github/repositories.go +++ b/pkg/github/repositories.go @@ -37,19 +37,25 @@ func listCommits(client *github.Client, t translations.TranslationHelperFunc) (t ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - sha := "" - if s, ok := request.Params.Arguments["sha"].(string); ok { - sha = s + owner, err := requiredStringParam(request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredStringParam(request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - page := 1 - if p, ok := request.Params.Arguments["page"].(float64); ok { - page = int(p) + sha, err := optionalStringParam(request, "sha") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - perPage := 30 - if pp, ok := request.Params.Arguments["per_page"].(float64); ok { - perPage = int(pp) + page, err := optionalNumberParamWithDefault(request, "page", 1) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + perPage, err := optionalNumberParamWithDefault(request, "per_page", 30) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } opts := &github.CommitsListOptions{ @@ -116,12 +122,30 @@ func createOrUpdateFile(client *github.Client, t translations.TranslationHelperF ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - path := request.Params.Arguments["path"].(string) - content := request.Params.Arguments["content"].(string) - message := request.Params.Arguments["message"].(string) - branch := request.Params.Arguments["branch"].(string) + owner, err := requiredStringParam(request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredStringParam(request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + path, err := requiredStringParam(request, "path") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + content, err := requiredStringParam(request, "content") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + message, err := requiredStringParam(request, "message") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + branch, err := requiredStringParam(request, "branch") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } // Convert content to base64 contentBytes := []byte(content) @@ -134,7 +158,11 @@ func createOrUpdateFile(client *github.Client, t translations.TranslationHelperF } // If SHA is provided, set it (for updates) - if sha, ok := request.Params.Arguments["sha"].(string); ok && sha != "" { + sha, err := optionalStringParam(request, "sha") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if sha != "" { opts.SHA = ptr.String(sha) } @@ -181,25 +209,28 @@ func createRepository(client *github.Client, t translations.TranslationHelperFun ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - name := request.Params.Arguments["name"].(string) - description := "" - if desc, ok := request.Params.Arguments["description"].(string); ok { - description = desc + name, err := requiredStringParam(request, "name") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + description, err := optionalStringParam(request, "description") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - private := false - if priv, ok := request.Params.Arguments["private"].(bool); ok { - private = priv + private, err := optionalBooleanParam(request, "private") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - autoInit := false - if init, ok := request.Params.Arguments["auto_init"].(bool); ok { - autoInit = init + autoInit, err := optionalBooleanParam(request, "auto_init") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } repo := &github.Repository{ - Name: github.String(name), - Description: github.String(description), - Private: github.Bool(private), - AutoInit: github.Bool(autoInit), + Name: github.Ptr(name), + Description: github.Ptr(description), + Private: github.Ptr(private), + AutoInit: github.Ptr(autoInit), } createdRepo, resp, err := client.Repositories.Create(ctx, "", repo) @@ -246,12 +277,21 @@ func getFileContents(client *github.Client, t translations.TranslationHelperFunc ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - path := request.Params.Arguments["path"].(string) - branch := "" - if b, ok := request.Params.Arguments["branch"].(string); ok { - branch = b + owner, err := requiredStringParam(request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredStringParam(request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + path, err := requiredStringParam(request, "path") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + branch, err := optionalStringParam(request, "branch") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } opts := &github.RepositoryContentGetOptions{Ref: branch} @@ -302,11 +342,17 @@ func forkRepository(client *github.Client, t translations.TranslationHelperFunc) ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - org := "" - if o, ok := request.Params.Arguments["organization"].(string); ok { - org = o + owner, err := requiredStringParam(request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredStringParam(request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + org, err := optionalStringParam(request, "organization") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } opts := &github.RepositoryCreateForkOptions{} @@ -363,17 +409,25 @@ func createBranch(client *github.Client, t translations.TranslationHelperFunc) ( ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - branch := request.Params.Arguments["branch"].(string) - fromBranch := "" - if fb, ok := request.Params.Arguments["from_branch"].(string); ok { - fromBranch = fb + owner, err := requiredStringParam(request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredStringParam(request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + branch, err := requiredStringParam(request, "branch") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + fromBranch, err := optionalStringParam(request, "from_branch") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } // Get the source branch SHA var ref *github.Reference - var err error if fromBranch == "" { // Get default branch if from_branch not specified @@ -440,10 +494,22 @@ func pushFiles(client *github.Client, t translations.TranslationHelperFunc) (too ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - branch := request.Params.Arguments["branch"].(string) - message := request.Params.Arguments["message"].(string) + owner, err := requiredStringParam(request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredStringParam(request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + branch, err := requiredStringParam(request, "branch") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + message, err := requiredStringParam(request, "message") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } // Parse files parameter - this should be an array of objects with path and content filesObj, ok := request.Params.Arguments["files"].([]interface{}) diff --git a/pkg/github/search.go b/pkg/github/search.go index 353c6fb2..d7ea4904 100644 --- a/pkg/github/search.go +++ b/pkg/github/search.go @@ -28,14 +28,17 @@ func searchRepositories(client *github.Client, t translations.TranslationHelperF ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - query := request.Params.Arguments["query"].(string) - page := 1 - if p, ok := request.Params.Arguments["page"].(float64); ok { - page = int(p) + query, err := requiredStringParam(request, "query") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - perPage := 30 - if pp, ok := request.Params.Arguments["per_page"].(float64); ok { - perPage = int(pp) + page, err := optionalNumberParamWithDefault(request, "page", 1) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + perPage, err := optionalNumberParamWithDefault(request, "per_page", 30) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } opts := &github.SearchOptions{ @@ -90,22 +93,25 @@ func searchCode(client *github.Client, t translations.TranslationHelperFunc) (to ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - query := request.Params.Arguments["q"].(string) - sort := "" - if s, ok := request.Params.Arguments["sort"].(string); ok { - sort = s + query, err := requiredStringParam(request, "q") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - order := "" - if o, ok := request.Params.Arguments["order"].(string); ok { - order = o + sort, err := optionalStringParam(request, "sort") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - perPage := 30 - if pp, ok := request.Params.Arguments["per_page"].(float64); ok { - perPage = int(pp) + order, err := optionalStringParam(request, "order") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - page := 1 - if p, ok := request.Params.Arguments["page"].(float64); ok { - page = int(p) + perPage, err := optionalNumberParamWithDefault(request, "per_page", 30) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + page, err := optionalNumberParamWithDefault(request, "page", 1) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } opts := &github.SearchOptions{ @@ -162,22 +168,25 @@ func searchUsers(client *github.Client, t translations.TranslationHelperFunc) (t ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - query := request.Params.Arguments["q"].(string) - sort := "" - if s, ok := request.Params.Arguments["sort"].(string); ok { - sort = s - } - order := "" - if o, ok := request.Params.Arguments["order"].(string); ok { - order = o - } - perPage := 30 - if pp, ok := request.Params.Arguments["per_page"].(float64); ok { - perPage = int(pp) - } - page := 1 - if p, ok := request.Params.Arguments["page"].(float64); ok { - page = int(p) + query, err := requiredStringParam(request, "q") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + sort, err := optionalStringParam(request, "sort") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + order, err := optionalStringParam(request, "order") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + perPage, err := optionalNumberParamWithDefault(request, "per_page", 30) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + page, err := optionalNumberParamWithDefault(request, "page", 1) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } opts := &github.SearchOptions{ diff --git a/pkg/github/server.go b/pkg/github/server.go index a0993e2f..42e23083 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -138,3 +138,118 @@ func parseCommaSeparatedList(input string) []string { return result } + +// requiredStringParam checks if the parameter is present in the request and is of type string. +func requiredStringParam(r mcp.CallToolRequest, p string) (string, error) { + // Check if the parameter is present in the request + if _, ok := r.Params.Arguments[p]; !ok { + return "", fmt.Errorf("missing required parameter: %s", p) + } + + // Check if the parameter is of the expected type + if _, ok := r.Params.Arguments[p].(string); !ok { + return "", fmt.Errorf("parameter %s is not of type string", p) + } + + // Check if the parameter is not the zero value + v := r.Params.Arguments[p].(string) + if v == "" { + return v, fmt.Errorf("missing required parameter: %s", p) + } + + return v, nil +} + +// requiredNumberParam checks if the parameter is present in the request and is of type number. +func requiredNumberParam(r mcp.CallToolRequest, p string) (int, error) { + // Check if the parameter is present in the request + if _, ok := r.Params.Arguments[p]; !ok { + return 0, fmt.Errorf("missing required parameter: %s", p) + } + + // Check if the parameter is of the expected type + if _, ok := r.Params.Arguments[p].(float64); !ok { + return 0, fmt.Errorf("parameter %s is not of type number", p) + } + + return int(r.Params.Arguments[p].(float64)), nil +} + +// optionalStringParam checks if an optional parameter is present in the request and is of type string. +func optionalStringParam(r mcp.CallToolRequest, p string) (value string, err error) { + // Check if the parameter is present in the request + if _, ok := r.Params.Arguments[p]; !ok { + return "", nil + } + + // Check if the parameter is of the expected type + if _, ok := r.Params.Arguments[p].(string); !ok { + return "", fmt.Errorf("parameter %s is not of type string", p) + } + + return r.Params.Arguments[p].(string), nil +} + +// optionalNumberParam checks if an optional parameter is present in the request and is of type number. +func optionalNumberParam(r mcp.CallToolRequest, p string) (int, error) { + // Check if the parameter is present in the request + if _, ok := r.Params.Arguments[p]; !ok { + return 0, nil + } + + // Check if the parameter is of the expected type + if _, ok := r.Params.Arguments[p].(float64); !ok { + return 0, fmt.Errorf("parameter %s is not of type number", p) + } + + return int(r.Params.Arguments[p].(float64)), nil +} + +// optionalNumberParamWithDefault checks if an optional parameter is present in the request and is of type number. +// If the parameter is not present or is zero, it returns the default value. +func optionalNumberParamWithDefault(r mcp.CallToolRequest, p string, d int) (int, error) { + v, err := optionalNumberParam(r, p) + if err != nil { + return 0, err + } + if v == 0 { + return d, nil + } + return v, nil +} + +// optionalCommaSeparatedListParam checks if an optional parameter is present in the request and is of type string. +// If the parameter is presents, it uses parseCommaSeparatedList to parse the string into a list of strings. +// If the parameter is not present or is empty, it returns an empty list. +func optionalCommaSeparatedListParam(r mcp.CallToolRequest, p string) ([]string, error) { + // Check if the parameter is present in the request + if _, ok := r.Params.Arguments[p]; !ok { + return []string{}, nil //default to empty list, not nil + } + + // Check if the parameter is of the expected type + if _, ok := r.Params.Arguments[p].(string); !ok { + return nil, fmt.Errorf("parameter %s is not of type string", p) + } + + l := parseCommaSeparatedList(r.Params.Arguments[p].(string)) + if len(l) == 0 { + return []string{}, nil // default to empty list, not nil + } + return l, nil +} + +// optionalBooleanParam checks if an optional parameter is present in the request and is of type boolean. +func optionalBooleanParam(r mcp.CallToolRequest, p string) (bool, error) { + // Check if the parameter is present in the request + if _, ok := r.Params.Arguments[p]; !ok { + return false, nil + } + + // Check if the parameter is of the expected type + if _, ok := r.Params.Arguments[p].(bool); !ok { + return false, fmt.Errorf("parameter %s is not of type bool", p) + } + + return r.Params.Arguments[p].(bool), nil +} diff --git a/pkg/github/server_test.go b/pkg/github/server_test.go index 316a0efa..a081d31d 100644 --- a/pkg/github/server_test.go +++ b/pkg/github/server_test.go @@ -228,3 +228,372 @@ func Test_ParseCommaSeparatedList(t *testing.T) { }) } } + +func Test_RequiredStringParam(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + expected string + expectError bool + }{ + { + name: "valid string parameter", + params: map[string]interface{}{"name": "test-value"}, + paramName: "name", + expected: "test-value", + expectError: false, + }, + { + name: "missing parameter", + params: map[string]interface{}{}, + paramName: "name", + expected: "", + expectError: true, + }, + { + name: "empty string parameter", + params: map[string]interface{}{"name": ""}, + paramName: "name", + expected: "", + expectError: true, + }, + { + name: "wrong type parameter", + params: map[string]interface{}{"name": 123}, + paramName: "name", + expected: "", + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + request := createMCPRequest(tc.params) + result, err := requiredStringParam(request, tc.paramName) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} + +func Test_OptionalStringParam(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + expected string + expectError bool + }{ + { + name: "valid string parameter", + params: map[string]interface{}{"name": "test-value"}, + paramName: "name", + expected: "test-value", + expectError: false, + }, + { + name: "missing parameter", + params: map[string]interface{}{}, + paramName: "name", + expected: "", + expectError: false, + }, + { + name: "empty string parameter", + params: map[string]interface{}{"name": ""}, + paramName: "name", + expected: "", + expectError: false, + }, + { + name: "wrong type parameter", + params: map[string]interface{}{"name": 123}, + paramName: "name", + expected: "", + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + request := createMCPRequest(tc.params) + result, err := optionalStringParam(request, tc.paramName) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} + +func Test_RequiredNumberParam(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + expected int + expectError bool + }{ + { + name: "valid number parameter", + params: map[string]interface{}{"count": float64(42)}, + paramName: "count", + expected: 42, + expectError: false, + }, + { + name: "missing parameter", + params: map[string]interface{}{}, + paramName: "count", + expected: 0, + expectError: true, + }, + { + name: "wrong type parameter", + params: map[string]interface{}{"count": "not-a-number"}, + paramName: "count", + expected: 0, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + request := createMCPRequest(tc.params) + result, err := requiredNumberParam(request, tc.paramName) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} + +func Test_OptionalNumberParam(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + expected int + expectError bool + }{ + { + name: "valid number parameter", + params: map[string]interface{}{"count": float64(42)}, + paramName: "count", + expected: 42, + expectError: false, + }, + { + name: "missing parameter", + params: map[string]interface{}{}, + paramName: "count", + expected: 0, + expectError: false, + }, + { + name: "zero value", + params: map[string]interface{}{"count": float64(0)}, + paramName: "count", + expected: 0, + expectError: false, + }, + { + name: "wrong type parameter", + params: map[string]interface{}{"count": "not-a-number"}, + paramName: "count", + expected: 0, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + request := createMCPRequest(tc.params) + result, err := optionalNumberParam(request, tc.paramName) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} + +func Test_OptionalNumberParamWithDefault(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + defaultVal int + expected int + expectError bool + }{ + { + name: "valid number parameter", + params: map[string]interface{}{"count": float64(42)}, + paramName: "count", + defaultVal: 10, + expected: 42, + expectError: false, + }, + { + name: "missing parameter", + params: map[string]interface{}{}, + paramName: "count", + defaultVal: 10, + expected: 10, + expectError: false, + }, + { + name: "zero value", + params: map[string]interface{}{"count": float64(0)}, + paramName: "count", + defaultVal: 10, + expected: 10, + expectError: false, + }, + { + name: "wrong type parameter", + params: map[string]interface{}{"count": "not-a-number"}, + paramName: "count", + defaultVal: 10, + expected: 0, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + request := createMCPRequest(tc.params) + result, err := optionalNumberParamWithDefault(request, tc.paramName, tc.defaultVal) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} + +func Test_OptionalCommaSeparatedListParam(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + expected []string + expectError bool + }{ + { + name: "valid comma-separated list", + params: map[string]interface{}{"tags": "one,two,three"}, + paramName: "tags", + expected: []string{"one", "two", "three"}, + expectError: false, + }, + { + name: "empty list", + params: map[string]interface{}{"tags": ""}, + paramName: "tags", + expected: []string{}, + expectError: false, + }, + { + name: "missing parameter", + params: map[string]interface{}{}, + paramName: "tags", + expected: []string{}, + expectError: false, + }, + { + name: "wrong type parameter", + params: map[string]interface{}{"tags": 123}, + paramName: "tags", + expected: nil, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + request := createMCPRequest(tc.params) + result, err := optionalCommaSeparatedListParam(request, tc.paramName) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} + +func Test_OptionalBooleanParam(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + expected bool + expectError bool + }{ + { + name: "true value", + params: map[string]interface{}{"flag": true}, + paramName: "flag", + expected: true, + expectError: false, + }, + { + name: "false value", + params: map[string]interface{}{"flag": false}, + paramName: "flag", + expected: false, + expectError: false, + }, + { + name: "missing parameter", + params: map[string]interface{}{}, + paramName: "flag", + expected: false, + expectError: false, + }, + { + name: "wrong type parameter", + params: map[string]interface{}{"flag": "not-a-boolean"}, + paramName: "flag", + expected: false, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + request := createMCPRequest(tc.params) + result, err := optionalBooleanParam(request, tc.paramName) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} From 3c6467ffef6b2493d69533d34776d5cdd97b65ce Mon Sep 17 00:00:00 2001 From: Javier Uruen Val Date: Mon, 24 Mar 2025 14:47:54 +0100 Subject: [PATCH 2/2] use generic for helper functions --- pkg/github/code_scanning.go | 16 ++--- pkg/github/issues.go | 58 ++++++++-------- pkg/github/pullrequests.go | 80 +++++++++++----------- pkg/github/repositories.go | 62 ++++++++--------- pkg/github/search.go | 26 +++---- pkg/github/server.go | 132 ++++++++++++++++-------------------- pkg/github/server_test.go | 12 ++-- 7 files changed, 186 insertions(+), 200 deletions(-) diff --git a/pkg/github/code_scanning.go b/pkg/github/code_scanning.go index e7c8a4e2..380dc02c 100644 --- a/pkg/github/code_scanning.go +++ b/pkg/github/code_scanning.go @@ -30,15 +30,15 @@ func getCodeScanningAlert(client *github.Client, t translations.TranslationHelpe ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredStringParam(request, "owner") + owner, err := requiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := requiredStringParam(request, "repo") + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - alertNumber, err := requiredNumberParam(request, "alert_number") + alertNumber, err := requiredInt(request, "alert_number") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -89,23 +89,23 @@ func listCodeScanningAlerts(client *github.Client, t translations.TranslationHel ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredStringParam(request, "owner") + owner, err := requiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := requiredStringParam(request, "repo") + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - ref, err := optionalStringParam(request, "ref") + ref, err := optionalParam[string](request, "ref") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - state, err := optionalStringParam(request, "state") + state, err := optionalParam[string](request, "state") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - severity, err := optionalStringParam(request, "severity") + severity, err := optionalParam[string](request, "severity") if err != nil { return mcp.NewToolResultError(err.Error()), nil } diff --git a/pkg/github/issues.go b/pkg/github/issues.go index 521acfda..a62213ea 100644 --- a/pkg/github/issues.go +++ b/pkg/github/issues.go @@ -32,15 +32,15 @@ func getIssue(client *github.Client, t translations.TranslationHelperFunc) (tool ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredStringParam(request, "owner") + owner, err := requiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := requiredStringParam(request, "repo") + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - issueNumber, err := requiredNumberParam(request, "issue_number") + issueNumber, err := requiredInt(request, "issue_number") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -90,19 +90,19 @@ func addIssueComment(client *github.Client, t translations.TranslationHelperFunc ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredStringParam(request, "owner") + owner, err := requiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := requiredStringParam(request, "repo") + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - issueNumber, err := requiredNumberParam(request, "issue_number") + issueNumber, err := requiredInt(request, "issue_number") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - body, err := requiredStringParam(request, "body") + body, err := requiredParam[string](request, "body") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -156,23 +156,23 @@ func searchIssues(client *github.Client, t translations.TranslationHelperFunc) ( ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - query, err := requiredStringParam(request, "q") + query, err := requiredParam[string](request, "q") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - sort, err := optionalStringParam(request, "sort") + sort, err := optionalParam[string](request, "sort") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - order, err := optionalStringParam(request, "order") + order, err := optionalParam[string](request, "order") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - perPage, err := optionalNumberParamWithDefault(request, "per_page", 30) + perPage, err := optionalIntParamWithDefault(request, "per_page", 30) if err != nil { return mcp.NewToolResultError(err.Error()), nil } - page, err := optionalNumberParamWithDefault(request, "page", 1) + page, err := optionalIntParamWithDefault(request, "page", 1) if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -236,21 +236,21 @@ func createIssue(client *github.Client, t translations.TranslationHelperFunc) (t ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredStringParam(request, "owner") + owner, err := requiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := requiredStringParam(request, "repo") + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - title, err := requiredStringParam(request, "title") + title, err := requiredParam[string](request, "title") if err != nil { return mcp.NewToolResultError(err.Error()), nil } // Optional parameters - body, err := optionalStringParam(request, "body") + body, err := optionalParam[string](request, "body") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -332,11 +332,11 @@ func listIssues(client *github.Client, t translations.TranslationHelperFunc) (to ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredStringParam(request, "owner") + owner, err := requiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := requiredStringParam(request, "repo") + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -344,7 +344,7 @@ func listIssues(client *github.Client, t translations.TranslationHelperFunc) (to opts := &github.IssueListByRepoOptions{} // Set optional parameters if provided - opts.State, err = optionalStringParam(request, "state") + opts.State, err = optionalParam[string](request, "state") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -354,17 +354,17 @@ func listIssues(client *github.Client, t translations.TranslationHelperFunc) (to return mcp.NewToolResultError(err.Error()), nil } - opts.Sort, err = optionalStringParam(request, "sort") + opts.Sort, err = optionalParam[string](request, "sort") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - opts.Direction, err = optionalStringParam(request, "direction") + opts.Direction, err = optionalParam[string](request, "direction") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - since, err := optionalStringParam(request, "since") + since, err := optionalParam[string](request, "since") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -443,15 +443,15 @@ func updateIssue(client *github.Client, t translations.TranslationHelperFunc) (t ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredStringParam(request, "owner") + owner, err := requiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := requiredStringParam(request, "repo") + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - issueNumber, err := requiredNumberParam(request, "issue_number") + issueNumber, err := requiredInt(request, "issue_number") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -460,7 +460,7 @@ func updateIssue(client *github.Client, t translations.TranslationHelperFunc) (t issueRequest := &github.IssueRequest{} // Set optional parameters if provided - title, err := optionalStringParam(request, "title") + title, err := optionalParam[string](request, "title") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -468,7 +468,7 @@ func updateIssue(client *github.Client, t translations.TranslationHelperFunc) (t issueRequest.Title = github.Ptr(title) } - body, err := optionalStringParam(request, "body") + body, err := optionalParam[string](request, "body") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -476,7 +476,7 @@ func updateIssue(client *github.Client, t translations.TranslationHelperFunc) (t issueRequest.Body = github.Ptr(body) } - state, err := optionalStringParam(request, "state") + state, err := optionalParam[string](request, "state") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -500,7 +500,7 @@ func updateIssue(client *github.Client, t translations.TranslationHelperFunc) (t issueRequest.Assignees = &assignees } - milestone, err := optionalNumberParam(request, "milestone") + milestone, err := optionalIntParam(request, "milestone") if err != nil { return mcp.NewToolResultError(err.Error()), nil } diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index d5caab97..dc8b6481 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -31,15 +31,15 @@ func getPullRequest(client *github.Client, t translations.TranslationHelperFunc) ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredStringParam(request, "owner") + owner, err := requiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := requiredStringParam(request, "repo") + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - pullNumber, err := requiredNumberParam(request, "pull_number") + pullNumber, err := requiredInt(request, "pull_number") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -102,39 +102,39 @@ func listPullRequests(client *github.Client, t translations.TranslationHelperFun ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredStringParam(request, "owner") + owner, err := requiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := requiredStringParam(request, "repo") + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - state, err := optionalStringParam(request, "state") + state, err := optionalParam[string](request, "state") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - head, err := optionalStringParam(request, "head") + head, err := optionalParam[string](request, "head") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - base, err := optionalStringParam(request, "base") + base, err := optionalParam[string](request, "base") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - sort, err := optionalStringParam(request, "sort") + sort, err := optionalParam[string](request, "sort") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - direction, err := optionalStringParam(request, "direction") + direction, err := optionalParam[string](request, "direction") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - perPage, err := optionalNumberParamWithDefault(request, "per_page", 30) + perPage, err := optionalIntParamWithDefault(request, "per_page", 30) if err != nil { return mcp.NewToolResultError(err.Error()), nil } - page, err := optionalNumberParamWithDefault(request, "page", 1) + page, err := optionalIntParamWithDefault(request, "page", 1) if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -201,27 +201,27 @@ func mergePullRequest(client *github.Client, t translations.TranslationHelperFun ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredStringParam(request, "owner") + owner, err := requiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := requiredStringParam(request, "repo") + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - pullNumber, err := requiredNumberParam(request, "pull_number") + pullNumber, err := requiredInt(request, "pull_number") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - commitTitle, err := optionalStringParam(request, "commit_title") + commitTitle, err := optionalParam[string](request, "commit_title") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - commitMessage, err := optionalStringParam(request, "commit_message") + commitMessage, err := optionalParam[string](request, "commit_message") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - mergeMethod, err := optionalStringParam(request, "merge_method") + mergeMethod, err := optionalParam[string](request, "merge_method") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -272,15 +272,15 @@ func getPullRequestFiles(client *github.Client, t translations.TranslationHelper ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredStringParam(request, "owner") + owner, err := requiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := requiredStringParam(request, "repo") + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - pullNumber, err := requiredNumberParam(request, "pull_number") + pullNumber, err := requiredInt(request, "pull_number") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -327,15 +327,15 @@ func getPullRequestStatus(client *github.Client, t translations.TranslationHelpe ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredStringParam(request, "owner") + owner, err := requiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := requiredStringParam(request, "repo") + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - pullNumber, err := requiredNumberParam(request, "pull_number") + pullNumber, err := requiredInt(request, "pull_number") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -399,19 +399,19 @@ func updatePullRequestBranch(client *github.Client, t translations.TranslationHe ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredStringParam(request, "owner") + owner, err := requiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := requiredStringParam(request, "repo") + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - pullNumber, err := requiredNumberParam(request, "pull_number") + pullNumber, err := requiredInt(request, "pull_number") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - expectedHeadSHA, err := optionalStringParam(request, "expected_head_sha") + expectedHeadSHA, err := optionalParam[string](request, "expected_head_sha") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -466,15 +466,15 @@ func getPullRequestComments(client *github.Client, t translations.TranslationHel ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredStringParam(request, "owner") + owner, err := requiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := requiredStringParam(request, "repo") + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - pullNumber, err := requiredNumberParam(request, "pull_number") + pullNumber, err := requiredInt(request, "pull_number") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -526,15 +526,15 @@ func getPullRequestReviews(client *github.Client, t translations.TranslationHelp ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredStringParam(request, "owner") + owner, err := requiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := requiredStringParam(request, "repo") + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - pullNumber, err := requiredNumberParam(request, "pull_number") + pullNumber, err := requiredInt(request, "pull_number") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -593,19 +593,19 @@ func createPullRequestReview(client *github.Client, t translations.TranslationHe ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredStringParam(request, "owner") + owner, err := requiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := requiredStringParam(request, "repo") + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - pullNumber, err := requiredNumberParam(request, "pull_number") + pullNumber, err := requiredInt(request, "pull_number") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - event, err := requiredStringParam(request, "event") + event, err := requiredParam[string](request, "event") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -616,7 +616,7 @@ func createPullRequestReview(client *github.Client, t translations.TranslationHe } // Add body if provided - body, err := optionalStringParam(request, "body") + body, err := optionalParam[string](request, "body") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -625,7 +625,7 @@ func createPullRequestReview(client *github.Client, t translations.TranslationHe } // Add commit ID if provided - commitID, err := optionalStringParam(request, "commit_id") + commitID, err := optionalParam[string](request, "commit_id") if err != nil { return mcp.NewToolResultError(err.Error()), nil } diff --git a/pkg/github/repositories.go b/pkg/github/repositories.go index f222b1f8..f507b897 100644 --- a/pkg/github/repositories.go +++ b/pkg/github/repositories.go @@ -37,23 +37,23 @@ func listCommits(client *github.Client, t translations.TranslationHelperFunc) (t ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredStringParam(request, "owner") + owner, err := requiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := requiredStringParam(request, "repo") + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - sha, err := optionalStringParam(request, "sha") + sha, err := optionalParam[string](request, "sha") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - page, err := optionalNumberParamWithDefault(request, "page", 1) + page, err := optionalIntParamWithDefault(request, "page", 1) if err != nil { return mcp.NewToolResultError(err.Error()), nil } - perPage, err := optionalNumberParamWithDefault(request, "per_page", 30) + perPage, err := optionalIntParamWithDefault(request, "per_page", 30) if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -122,27 +122,27 @@ func createOrUpdateFile(client *github.Client, t translations.TranslationHelperF ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredStringParam(request, "owner") + owner, err := requiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := requiredStringParam(request, "repo") + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - path, err := requiredStringParam(request, "path") + path, err := requiredParam[string](request, "path") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - content, err := requiredStringParam(request, "content") + content, err := requiredParam[string](request, "content") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - message, err := requiredStringParam(request, "message") + message, err := requiredParam[string](request, "message") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - branch, err := requiredStringParam(request, "branch") + branch, err := requiredParam[string](request, "branch") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -158,7 +158,7 @@ func createOrUpdateFile(client *github.Client, t translations.TranslationHelperF } // If SHA is provided, set it (for updates) - sha, err := optionalStringParam(request, "sha") + sha, err := optionalParam[string](request, "sha") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -209,19 +209,19 @@ func createRepository(client *github.Client, t translations.TranslationHelperFun ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - name, err := requiredStringParam(request, "name") + name, err := requiredParam[string](request, "name") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - description, err := optionalStringParam(request, "description") + description, err := optionalParam[string](request, "description") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - private, err := optionalBooleanParam(request, "private") + private, err := optionalParam[bool](request, "private") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - autoInit, err := optionalBooleanParam(request, "auto_init") + autoInit, err := optionalParam[bool](request, "auto_init") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -277,19 +277,19 @@ func getFileContents(client *github.Client, t translations.TranslationHelperFunc ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredStringParam(request, "owner") + owner, err := requiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := requiredStringParam(request, "repo") + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - path, err := requiredStringParam(request, "path") + path, err := requiredParam[string](request, "path") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - branch, err := optionalStringParam(request, "branch") + branch, err := optionalParam[string](request, "branch") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -342,15 +342,15 @@ func forkRepository(client *github.Client, t translations.TranslationHelperFunc) ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredStringParam(request, "owner") + owner, err := requiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := requiredStringParam(request, "repo") + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - org, err := optionalStringParam(request, "organization") + org, err := optionalParam[string](request, "organization") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -409,19 +409,19 @@ func createBranch(client *github.Client, t translations.TranslationHelperFunc) ( ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredStringParam(request, "owner") + owner, err := requiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := requiredStringParam(request, "repo") + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - branch, err := requiredStringParam(request, "branch") + branch, err := requiredParam[string](request, "branch") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - fromBranch, err := optionalStringParam(request, "from_branch") + fromBranch, err := optionalParam[string](request, "from_branch") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -494,19 +494,19 @@ func pushFiles(client *github.Client, t translations.TranslationHelperFunc) (too ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredStringParam(request, "owner") + owner, err := requiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := requiredStringParam(request, "repo") + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - branch, err := requiredStringParam(request, "branch") + branch, err := requiredParam[string](request, "branch") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - message, err := requiredStringParam(request, "message") + message, err := requiredParam[string](request, "message") if err != nil { return mcp.NewToolResultError(err.Error()), nil } diff --git a/pkg/github/search.go b/pkg/github/search.go index d7ea4904..904dc737 100644 --- a/pkg/github/search.go +++ b/pkg/github/search.go @@ -28,15 +28,15 @@ func searchRepositories(client *github.Client, t translations.TranslationHelperF ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - query, err := requiredStringParam(request, "query") + query, err := requiredParam[string](request, "query") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - page, err := optionalNumberParamWithDefault(request, "page", 1) + page, err := optionalIntParamWithDefault(request, "page", 1) if err != nil { return mcp.NewToolResultError(err.Error()), nil } - perPage, err := optionalNumberParamWithDefault(request, "per_page", 30) + perPage, err := optionalIntParamWithDefault(request, "per_page", 30) if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -93,23 +93,23 @@ func searchCode(client *github.Client, t translations.TranslationHelperFunc) (to ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - query, err := requiredStringParam(request, "q") + query, err := requiredParam[string](request, "q") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - sort, err := optionalStringParam(request, "sort") + sort, err := optionalParam[string](request, "sort") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - order, err := optionalStringParam(request, "order") + order, err := optionalParam[string](request, "order") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - perPage, err := optionalNumberParamWithDefault(request, "per_page", 30) + perPage, err := optionalIntParamWithDefault(request, "per_page", 30) if err != nil { return mcp.NewToolResultError(err.Error()), nil } - page, err := optionalNumberParamWithDefault(request, "page", 1) + page, err := optionalIntParamWithDefault(request, "page", 1) if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -168,23 +168,23 @@ func searchUsers(client *github.Client, t translations.TranslationHelperFunc) (t ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - query, err := requiredStringParam(request, "q") + query, err := requiredParam[string](request, "q") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - sort, err := optionalStringParam(request, "sort") + sort, err := optionalParam[string](request, "sort") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - order, err := optionalStringParam(request, "order") + order, err := optionalParam[string](request, "order") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - perPage, err := optionalNumberParamWithDefault(request, "per_page", 30) + perPage, err := optionalIntParamWithDefault(request, "per_page", 30) if err != nil { return mcp.NewToolResultError(err.Error()), nil } - page, err := optionalNumberParamWithDefault(request, "page", 1) + page, err := optionalIntParamWithDefault(request, "page", 1) if err != nil { return mcp.NewToolResultError(err.Error()), nil } diff --git a/pkg/github/server.go b/pkg/github/server.go index 42e23083..829994f1 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -139,76 +139,81 @@ func parseCommaSeparatedList(input string) []string { return result } -// requiredStringParam checks if the parameter is present in the request and is of type string. -func requiredStringParam(r mcp.CallToolRequest, p string) (string, error) { +// requiredParam is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request. +// 2. Checks if the parameter is of the expected type. +// 3. Checks if the parameter is not empty, i.e: non-zero value +func requiredParam[T comparable](r mcp.CallToolRequest, p string) (T, error) { + var zero T + // Check if the parameter is present in the request if _, ok := r.Params.Arguments[p]; !ok { - return "", fmt.Errorf("missing required parameter: %s", p) + return zero, fmt.Errorf("missing required parameter: %s", p) } // Check if the parameter is of the expected type - if _, ok := r.Params.Arguments[p].(string); !ok { - return "", fmt.Errorf("parameter %s is not of type string", p) + if _, ok := r.Params.Arguments[p].(T); !ok { + return zero, fmt.Errorf("parameter %s is not of type %T", p, zero) } - // Check if the parameter is not the zero value - v := r.Params.Arguments[p].(string) - if v == "" { - return v, fmt.Errorf("missing required parameter: %s", p) + if r.Params.Arguments[p].(T) == zero { + return zero, fmt.Errorf("missing required parameter: %s", p) + } - return v, nil + return r.Params.Arguments[p].(T), nil } -// requiredNumberParam checks if the parameter is present in the request and is of type number. -func requiredNumberParam(r mcp.CallToolRequest, p string) (int, error) { - // Check if the parameter is present in the request - if _, ok := r.Params.Arguments[p]; !ok { - return 0, fmt.Errorf("missing required parameter: %s", p) - } - - // Check if the parameter is of the expected type - if _, ok := r.Params.Arguments[p].(float64); !ok { - return 0, fmt.Errorf("parameter %s is not of type number", p) +// requiredInt is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request. +// 2. Checks if the parameter is of the expected type. +// 3. Checks if the parameter is not empty, i.e: non-zero value +func requiredInt(r mcp.CallToolRequest, p string) (int, error) { + v, err := requiredParam[float64](r, p) + if err != nil { + return 0, err } - - return int(r.Params.Arguments[p].(float64)), nil + return int(v), nil } -// optionalStringParam checks if an optional parameter is present in the request and is of type string. -func optionalStringParam(r mcp.CallToolRequest, p string) (value string, err error) { +// optionalParam is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request, if not, it returns its zero-value +// 2. If it is present, it checks if the parameter is of the expected type and returns it +func optionalParam[T any](r mcp.CallToolRequest, p string) (T, error) { + var zero T + // Check if the parameter is present in the request if _, ok := r.Params.Arguments[p]; !ok { - return "", nil + return zero, nil } // Check if the parameter is of the expected type - if _, ok := r.Params.Arguments[p].(string); !ok { - return "", fmt.Errorf("parameter %s is not of type string", p) + if _, ok := r.Params.Arguments[p].(T); !ok { + return zero, fmt.Errorf("parameter %s is not of type %T", p, zero) } - return r.Params.Arguments[p].(string), nil + return r.Params.Arguments[p].(T), nil } -// optionalNumberParam checks if an optional parameter is present in the request and is of type number. -func optionalNumberParam(r mcp.CallToolRequest, p string) (int, error) { - // Check if the parameter is present in the request - if _, ok := r.Params.Arguments[p]; !ok { - return 0, nil - } - - // Check if the parameter is of the expected type - if _, ok := r.Params.Arguments[p].(float64); !ok { - return 0, fmt.Errorf("parameter %s is not of type number", p) +// optionalIntParam is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request, if not, it returns its zero-value +// 2. If it is present, it checks if the parameter is of the expected type and returns it +func optionalIntParam(r mcp.CallToolRequest, p string) (int, error) { + v, err := optionalParam[float64](r, p) + if err != nil { + return 0, err } - - return int(r.Params.Arguments[p].(float64)), nil + return int(v), nil } -// optionalNumberParamWithDefault checks if an optional parameter is present in the request and is of type number. -// If the parameter is not present or is zero, it returns the default value. -func optionalNumberParamWithDefault(r mcp.CallToolRequest, p string, d int) (int, error) { - v, err := optionalNumberParam(r, p) +// optionalIntParamWithDefault is a helper function that can be used to fetch a requested parameter from the request +// similar to optionalIntParam, but it also takes a default value. +func optionalIntParamWithDefault(r mcp.CallToolRequest, p string, d int) (int, error) { + v, err := optionalIntParam(r, p) if err != nil { return 0, err } @@ -218,38 +223,19 @@ func optionalNumberParamWithDefault(r mcp.CallToolRequest, p string, d int) (int return v, nil } -// optionalCommaSeparatedListParam checks if an optional parameter is present in the request and is of type string. -// If the parameter is presents, it uses parseCommaSeparatedList to parse the string into a list of strings. -// If the parameter is not present or is empty, it returns an empty list. +// optionalCommaSeparatedListParam is a helper function that can be used to fetch a requested parameter from the request. +// It does the following: +// 1. Checks if the parameter is present in the request, if not, it returns an empty list +// 2. If it is present, it checks if the parameter is of the expected type and uses parseCommaSeparatedList to parse it +// and return the list of strings func optionalCommaSeparatedListParam(r mcp.CallToolRequest, p string) ([]string, error) { - // Check if the parameter is present in the request - if _, ok := r.Params.Arguments[p]; !ok { - return []string{}, nil //default to empty list, not nil - } - - // Check if the parameter is of the expected type - if _, ok := r.Params.Arguments[p].(string); !ok { - return nil, fmt.Errorf("parameter %s is not of type string", p) + v, err := optionalParam[string](r, p) + if err != nil { + return []string{}, err } - - l := parseCommaSeparatedList(r.Params.Arguments[p].(string)) + l := parseCommaSeparatedList(v) if len(l) == 0 { - return []string{}, nil // default to empty list, not nil + return []string{}, nil } return l, nil } - -// optionalBooleanParam checks if an optional parameter is present in the request and is of type boolean. -func optionalBooleanParam(r mcp.CallToolRequest, p string) (bool, error) { - // Check if the parameter is present in the request - if _, ok := r.Params.Arguments[p]; !ok { - return false, nil - } - - // Check if the parameter is of the expected type - if _, ok := r.Params.Arguments[p].(bool); !ok { - return false, fmt.Errorf("parameter %s is not of type bool", p) - } - - return r.Params.Arguments[p].(bool), nil -} diff --git a/pkg/github/server_test.go b/pkg/github/server_test.go index a081d31d..5e7ac9d4 100644 --- a/pkg/github/server_test.go +++ b/pkg/github/server_test.go @@ -270,7 +270,7 @@ func Test_RequiredStringParam(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { request := createMCPRequest(tc.params) - result, err := requiredStringParam(request, tc.paramName) + result, err := requiredParam[string](request, tc.paramName) if tc.expectError { assert.Error(t, err) @@ -323,7 +323,7 @@ func Test_OptionalStringParam(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { request := createMCPRequest(tc.params) - result, err := optionalStringParam(request, tc.paramName) + result, err := optionalParam[string](request, tc.paramName) if tc.expectError { assert.Error(t, err) @@ -369,7 +369,7 @@ func Test_RequiredNumberParam(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { request := createMCPRequest(tc.params) - result, err := requiredNumberParam(request, tc.paramName) + result, err := requiredInt(request, tc.paramName) if tc.expectError { assert.Error(t, err) @@ -422,7 +422,7 @@ func Test_OptionalNumberParam(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { request := createMCPRequest(tc.params) - result, err := optionalNumberParam(request, tc.paramName) + result, err := optionalIntParam(request, tc.paramName) if tc.expectError { assert.Error(t, err) @@ -480,7 +480,7 @@ func Test_OptionalNumberParamWithDefault(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { request := createMCPRequest(tc.params) - result, err := optionalNumberParamWithDefault(request, tc.paramName, tc.defaultVal) + result, err := optionalIntParamWithDefault(request, tc.paramName, tc.defaultVal) if tc.expectError { assert.Error(t, err) @@ -586,7 +586,7 @@ func Test_OptionalBooleanParam(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { request := createMCPRequest(tc.params) - result, err := optionalBooleanParam(request, tc.paramName) + result, err := optionalParam[bool](request, tc.paramName) if tc.expectError { assert.Error(t, err)