diff --git a/pkg/github/code_scanning.go b/pkg/github/code_scanning.go index 6fc0936a..380dc02c 100644 --- a/pkg/github/code_scanning.go +++ b/pkg/github/code_scanning.go @@ -30,9 +30,18 @@ func getCodeScanningAlert(client *github.Client, t translations.TranslationHelpe ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, _ := request.Params.Arguments["owner"].(string) - repo, _ := request.Params.Arguments["repo"].(string) - alertNumber, _ := request.Params.Arguments["alert_number"].(float64) + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + alertNumber, err := requiredInt(request, "alert_number") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } alert, resp, err := client.CodeScanning.GetAlert(ctx, owner, repo, int64(alertNumber)) if err != nil { @@ -80,11 +89,26 @@ func listCodeScanningAlerts(client *github.Client, t translations.TranslationHel ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, _ := request.Params.Arguments["owner"].(string) - repo, _ := request.Params.Arguments["repo"].(string) - ref, _ := request.Params.Arguments["ref"].(string) - state, _ := request.Params.Arguments["state"].(string) - severity, _ := request.Params.Arguments["severity"].(string) + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + ref, err := optionalParam[string](request, "ref") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + state, err := optionalParam[string](request, "state") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + severity, err := optionalParam[string](request, "severity") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } alerts, resp, err := client.CodeScanning.ListAlertsForRepo(ctx, owner, repo, &github.AlertListOptions{Ref: ref, State: state, Severity: severity}) if err != nil { diff --git a/pkg/github/issues.go b/pkg/github/issues.go index 36130b98..a62213ea 100644 --- a/pkg/github/issues.go +++ b/pkg/github/issues.go @@ -32,9 +32,18 @@ func getIssue(client *github.Client, t translations.TranslationHelperFunc) (tool ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - issueNumber := int(request.Params.Arguments["issue_number"].(float64)) + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + issueNumber, err := requiredInt(request, "issue_number") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } issue, resp, err := client.Issues.Get(ctx, owner, repo, issueNumber) if err != nil { @@ -81,10 +90,22 @@ func addIssueComment(client *github.Client, t translations.TranslationHelperFunc ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - issueNumber := int(request.Params.Arguments["issue_number"].(float64)) - body := request.Params.Arguments["body"].(string) + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + issueNumber, err := requiredInt(request, "issue_number") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + body, err := requiredParam[string](request, "body") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } comment := &github.IssueComment{ Body: github.Ptr(body), @@ -135,22 +156,25 @@ func searchIssues(client *github.Client, t translations.TranslationHelperFunc) ( ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - query := request.Params.Arguments["q"].(string) - sort := "" - if s, ok := request.Params.Arguments["sort"].(string); ok { - sort = s + query, err := requiredParam[string](request, "q") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - order := "" - if o, ok := request.Params.Arguments["order"].(string); ok { - order = o + sort, err := optionalParam[string](request, "sort") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - perPage := 30 - if pp, ok := request.Params.Arguments["per_page"].(float64); ok { - perPage = int(pp) + order, err := optionalParam[string](request, "order") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - page := 1 - if p, ok := request.Params.Arguments["page"].(float64); ok { - page = int(p) + perPage, err := optionalIntParamWithDefault(request, "per_page", 30) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + page, err := optionalIntParamWithDefault(request, "page", 1) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } opts := &github.SearchOptions{ @@ -212,26 +236,34 @@ func createIssue(client *github.Client, t translations.TranslationHelperFunc) (t ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - title := request.Params.Arguments["title"].(string) + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + title, err := requiredParam[string](request, "title") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } // Optional parameters - var body string - if b, ok := request.Params.Arguments["body"].(string); ok { - body = b + body, err := optionalParam[string](request, "body") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - // Parse assignees if present - assignees := []string{} // default to empty slice, can't be nil - if a, ok := request.Params.Arguments["assignees"].(string); ok && a != "" { - assignees = parseCommaSeparatedList(a) + // Get assignees + assignees, err := optionalCommaSeparatedListParam(request, "assignees") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - - // Parse labels if present - labels := []string{} // default to empty slice, can't be nil - if l, ok := request.Params.Arguments["labels"].(string); ok && l != "" { - labels = parseCommaSeparatedList(l) + // Get labels + labels, err := optionalCommaSeparatedListParam(request, "labels") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } // Create the issue request @@ -300,29 +332,43 @@ func listIssues(client *github.Client, t translations.TranslationHelperFunc) (to ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } opts := &github.IssueListByRepoOptions{} // Set optional parameters if provided - if state, ok := request.Params.Arguments["state"].(string); ok && state != "" { - opts.State = state + opts.State, err = optionalParam[string](request, "state") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - if labels, ok := request.Params.Arguments["labels"].(string); ok && labels != "" { - opts.Labels = parseCommaSeparatedList(labels) + opts.Labels, err = optionalCommaSeparatedListParam(request, "labels") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - if sort, ok := request.Params.Arguments["sort"].(string); ok && sort != "" { - opts.Sort = sort + opts.Sort, err = optionalParam[string](request, "sort") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - if direction, ok := request.Params.Arguments["direction"].(string); ok && direction != "" { - opts.Direction = direction + opts.Direction, err = optionalParam[string](request, "direction") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - if since, ok := request.Params.Arguments["since"].(string); ok && since != "" { + since, err := optionalParam[string](request, "since") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if since != "" { timestamp, err := parseISOTimestamp(since) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("failed to list issues: %s", err.Error())), nil @@ -397,38 +443,69 @@ func updateIssue(client *github.Client, t translations.TranslationHelperFunc) (t ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - issueNumber := int(request.Params.Arguments["issue_number"].(float64)) + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + issueNumber, err := requiredInt(request, "issue_number") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } // Create the issue request with only provided fields issueRequest := &github.IssueRequest{} // Set optional parameters if provided - if title, ok := request.Params.Arguments["title"].(string); ok && title != "" { + title, err := optionalParam[string](request, "title") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if title != "" { issueRequest.Title = github.Ptr(title) } - if body, ok := request.Params.Arguments["body"].(string); ok && body != "" { + body, err := optionalParam[string](request, "body") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if body != "" { issueRequest.Body = github.Ptr(body) } - if state, ok := request.Params.Arguments["state"].(string); ok && state != "" { + state, err := optionalParam[string](request, "state") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if state != "" { issueRequest.State = github.Ptr(state) } - if labels, ok := request.Params.Arguments["labels"].(string); ok && labels != "" { - labelsList := parseCommaSeparatedList(labels) - issueRequest.Labels = &labelsList + labels, err := optionalCommaSeparatedListParam(request, "labels") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if len(labels) > 0 { + issueRequest.Labels = &labels } - if assignees, ok := request.Params.Arguments["assignees"].(string); ok && assignees != "" { - assigneesList := parseCommaSeparatedList(assignees) - issueRequest.Assignees = &assigneesList + assignees, err := optionalCommaSeparatedListParam(request, "assignees") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if len(assignees) > 0 { + issueRequest.Assignees = &assignees } - if milestone, ok := request.Params.Arguments["milestone"].(float64); ok { - milestoneNum := int(milestone) + milestone, err := optionalIntParam(request, "milestone") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if milestone != 0 { + milestoneNum := milestone issueRequest.Milestone = &milestoneNum } diff --git a/pkg/github/issues_test.go b/pkg/github/issues_test.go index 4e8250fd..c2de6579 100644 --- a/pkg/github/issues_test.go +++ b/pkg/github/issues_test.go @@ -176,8 +176,8 @@ func Test_AddIssueComment(t *testing.T) { "issue_number": float64(42), "body": "", }, - expectError: true, - expectedErrMsg: "failed to create comment", + expectError: false, + expectedErrMsg: "missing required parameter: body", }, } @@ -210,6 +210,13 @@ func Test_AddIssueComment(t *testing.T) { return } + if tc.expectedErrMsg != "" { + require.NotNil(t, result) + textContent := getTextResult(t, result) + assert.Contains(t, textContent.Text, tc.expectedErrMsg) + return + } + require.NoError(t, err) // Parse the result and get the text content if no error @@ -419,8 +426,8 @@ func Test_CreateIssue(t *testing.T) { "repo": "repo", "title": "Test Issue", "body": "This is a test issue", - "assignees": []interface{}{"user1", "user2"}, - "labels": []interface{}{"bug", "help wanted"}, + "assignees": "user1, user2", + "labels": "bug, help wanted", }, expectError: false, expectedIssue: mockIssue, @@ -467,8 +474,8 @@ func Test_CreateIssue(t *testing.T) { "repo": "repo", "title": "", }, - expectError: true, - expectedErrMsg: "failed to create issue", + expectError: false, + expectedErrMsg: "missing required parameter: title", }, } @@ -491,6 +498,13 @@ func Test_CreateIssue(t *testing.T) { return } + if tc.expectedErrMsg != "" { + require.NotNil(t, result) + textContent := getTextResult(t, result) + assert.Contains(t, textContent.Text, tc.expectedErrMsg) + return + } + require.NoError(t, err) textContent := getTextResult(t, result) diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index e0414394..dc8b6481 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -31,9 +31,18 @@ func getPullRequest(client *github.Client, t translations.TranslationHelperFunc) ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - pullNumber := int(request.Params.Arguments["pull_number"].(float64)) + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pullNumber, err := requiredInt(request, "pull_number") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } pr, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber) if err != nil { @@ -93,35 +102,41 @@ func listPullRequests(client *github.Client, t translations.TranslationHelperFun ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - state := "" - if s, ok := request.Params.Arguments["state"].(string); ok { - state = s + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - head := "" - if h, ok := request.Params.Arguments["head"].(string); ok { - head = h + state, err := optionalParam[string](request, "state") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - base := "" - if b, ok := request.Params.Arguments["base"].(string); ok { - base = b + head, err := optionalParam[string](request, "head") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - sort := "" - if s, ok := request.Params.Arguments["sort"].(string); ok { - sort = s + base, err := optionalParam[string](request, "base") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - direction := "" - if d, ok := request.Params.Arguments["direction"].(string); ok { - direction = d + sort, err := optionalParam[string](request, "sort") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - perPage := 30 - if pp, ok := request.Params.Arguments["per_page"].(float64); ok { - perPage = int(pp) + direction, err := optionalParam[string](request, "direction") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - page := 1 - if p, ok := request.Params.Arguments["page"].(float64); ok { - page = int(p) + perPage, err := optionalIntParamWithDefault(request, "per_page", 30) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + page, err := optionalIntParamWithDefault(request, "page", 1) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } opts := &github.PullRequestListOptions{ @@ -186,20 +201,29 @@ func mergePullRequest(client *github.Client, t translations.TranslationHelperFun ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - pullNumber := int(request.Params.Arguments["pull_number"].(float64)) - commitTitle := "" - if ct, ok := request.Params.Arguments["commit_title"].(string); ok { - commitTitle = ct + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - commitMessage := "" - if cm, ok := request.Params.Arguments["commit_message"].(string); ok { - commitMessage = cm + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - mergeMethod := "" - if mm, ok := request.Params.Arguments["merge_method"].(string); ok { - mergeMethod = mm + pullNumber, err := requiredInt(request, "pull_number") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + commitTitle, err := optionalParam[string](request, "commit_title") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + commitMessage, err := optionalParam[string](request, "commit_message") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + mergeMethod, err := optionalParam[string](request, "merge_method") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } options := &github.PullRequestOptions{ @@ -248,9 +272,18 @@ func getPullRequestFiles(client *github.Client, t translations.TranslationHelper ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - pullNumber := int(request.Params.Arguments["pull_number"].(float64)) + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pullNumber, err := requiredInt(request, "pull_number") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } opts := &github.ListOptions{} files, resp, err := client.PullRequests.ListFiles(ctx, owner, repo, pullNumber, opts) @@ -294,10 +327,18 @@ func getPullRequestStatus(client *github.Client, t translations.TranslationHelpe ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - pullNumber := int(request.Params.Arguments["pull_number"].(float64)) - + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pullNumber, err := requiredInt(request, "pull_number") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } // First get the PR to find the head SHA pr, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber) if err != nil { @@ -358,14 +399,22 @@ func updatePullRequestBranch(client *github.Client, t translations.TranslationHe ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - pullNumber := int(request.Params.Arguments["pull_number"].(float64)) - expectedHeadSHA := "" - if sha, ok := request.Params.Arguments["expected_head_sha"].(string); ok { - expectedHeadSHA = sha + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pullNumber, err := requiredInt(request, "pull_number") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + expectedHeadSHA, err := optionalParam[string](request, "expected_head_sha") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - opts := &github.PullRequestBranchUpdateOptions{} if expectedHeadSHA != "" { opts.ExpectedHeadSHA = github.Ptr(expectedHeadSHA) @@ -417,9 +466,18 @@ func getPullRequestComments(client *github.Client, t translations.TranslationHel ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - pullNumber := int(request.Params.Arguments["pull_number"].(float64)) + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pullNumber, err := requiredInt(request, "pull_number") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } opts := &github.PullRequestListCommentsOptions{ ListOptions: github.ListOptions{ @@ -468,9 +526,18 @@ func getPullRequestReviews(client *github.Client, t translations.TranslationHelp ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - pullNumber := int(request.Params.Arguments["pull_number"].(float64)) + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pullNumber, err := requiredInt(request, "pull_number") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } reviews, resp, err := client.PullRequests.ListReviews(ctx, owner, repo, pullNumber, nil) if err != nil { @@ -526,10 +593,22 @@ func createPullRequestReview(client *github.Client, t translations.TranslationHe ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - pullNumber := int(request.Params.Arguments["pull_number"].(float64)) - event := request.Params.Arguments["event"].(string) + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pullNumber, err := requiredInt(request, "pull_number") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + event, err := requiredParam[string](request, "event") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } // Create review request reviewRequest := &github.PullRequestReviewRequest{ @@ -537,12 +616,20 @@ func createPullRequestReview(client *github.Client, t translations.TranslationHe } // Add body if provided - if body, ok := request.Params.Arguments["body"].(string); ok && body != "" { + body, err := optionalParam[string](request, "body") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if body != "" { reviewRequest.Body = github.Ptr(body) } // Add commit ID if provided - if commitID, ok := request.Params.Arguments["commit_id"].(string); ok && commitID != "" { + commitID, err := optionalParam[string](request, "commit_id") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if commitID != "" { reviewRequest.CommitID = github.Ptr(commitID) } diff --git a/pkg/github/repositories.go b/pkg/github/repositories.go index 6e3b176d..f507b897 100644 --- a/pkg/github/repositories.go +++ b/pkg/github/repositories.go @@ -37,19 +37,25 @@ func listCommits(client *github.Client, t translations.TranslationHelperFunc) (t ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - sha := "" - if s, ok := request.Params.Arguments["sha"].(string); ok { - sha = s + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - page := 1 - if p, ok := request.Params.Arguments["page"].(float64); ok { - page = int(p) + sha, err := optionalParam[string](request, "sha") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - perPage := 30 - if pp, ok := request.Params.Arguments["per_page"].(float64); ok { - perPage = int(pp) + page, err := optionalIntParamWithDefault(request, "page", 1) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + perPage, err := optionalIntParamWithDefault(request, "per_page", 30) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } opts := &github.CommitsListOptions{ @@ -116,12 +122,30 @@ func createOrUpdateFile(client *github.Client, t translations.TranslationHelperF ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - path := request.Params.Arguments["path"].(string) - content := request.Params.Arguments["content"].(string) - message := request.Params.Arguments["message"].(string) - branch := request.Params.Arguments["branch"].(string) + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + path, err := requiredParam[string](request, "path") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + content, err := requiredParam[string](request, "content") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + message, err := requiredParam[string](request, "message") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + branch, err := requiredParam[string](request, "branch") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } // Convert content to base64 contentBytes := []byte(content) @@ -134,7 +158,11 @@ func createOrUpdateFile(client *github.Client, t translations.TranslationHelperF } // If SHA is provided, set it (for updates) - if sha, ok := request.Params.Arguments["sha"].(string); ok && sha != "" { + sha, err := optionalParam[string](request, "sha") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if sha != "" { opts.SHA = ptr.String(sha) } @@ -181,25 +209,28 @@ func createRepository(client *github.Client, t translations.TranslationHelperFun ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - name := request.Params.Arguments["name"].(string) - description := "" - if desc, ok := request.Params.Arguments["description"].(string); ok { - description = desc + name, err := requiredParam[string](request, "name") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + description, err := optionalParam[string](request, "description") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - private := false - if priv, ok := request.Params.Arguments["private"].(bool); ok { - private = priv + private, err := optionalParam[bool](request, "private") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - autoInit := false - if init, ok := request.Params.Arguments["auto_init"].(bool); ok { - autoInit = init + autoInit, err := optionalParam[bool](request, "auto_init") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } repo := &github.Repository{ - Name: github.String(name), - Description: github.String(description), - Private: github.Bool(private), - AutoInit: github.Bool(autoInit), + Name: github.Ptr(name), + Description: github.Ptr(description), + Private: github.Ptr(private), + AutoInit: github.Ptr(autoInit), } createdRepo, resp, err := client.Repositories.Create(ctx, "", repo) @@ -246,12 +277,21 @@ func getFileContents(client *github.Client, t translations.TranslationHelperFunc ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - path := request.Params.Arguments["path"].(string) - branch := "" - if b, ok := request.Params.Arguments["branch"].(string); ok { - branch = b + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + path, err := requiredParam[string](request, "path") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + branch, err := optionalParam[string](request, "branch") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } opts := &github.RepositoryContentGetOptions{Ref: branch} @@ -302,11 +342,17 @@ func forkRepository(client *github.Client, t translations.TranslationHelperFunc) ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - org := "" - if o, ok := request.Params.Arguments["organization"].(string); ok { - org = o + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + org, err := optionalParam[string](request, "organization") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } opts := &github.RepositoryCreateForkOptions{} @@ -363,17 +409,25 @@ func createBranch(client *github.Client, t translations.TranslationHelperFunc) ( ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - branch := request.Params.Arguments["branch"].(string) - fromBranch := "" - if fb, ok := request.Params.Arguments["from_branch"].(string); ok { - fromBranch = fb + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + branch, err := requiredParam[string](request, "branch") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + fromBranch, err := optionalParam[string](request, "from_branch") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } // Get the source branch SHA var ref *github.Reference - var err error if fromBranch == "" { // Get default branch if from_branch not specified @@ -440,10 +494,22 @@ func pushFiles(client *github.Client, t translations.TranslationHelperFunc) (too ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner := request.Params.Arguments["owner"].(string) - repo := request.Params.Arguments["repo"].(string) - branch := request.Params.Arguments["branch"].(string) - message := request.Params.Arguments["message"].(string) + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + branch, err := requiredParam[string](request, "branch") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + message, err := requiredParam[string](request, "message") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } // Parse files parameter - this should be an array of objects with path and content filesObj, ok := request.Params.Arguments["files"].([]interface{}) diff --git a/pkg/github/search.go b/pkg/github/search.go index 353c6fb2..904dc737 100644 --- a/pkg/github/search.go +++ b/pkg/github/search.go @@ -28,14 +28,17 @@ func searchRepositories(client *github.Client, t translations.TranslationHelperF ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - query := request.Params.Arguments["query"].(string) - page := 1 - if p, ok := request.Params.Arguments["page"].(float64); ok { - page = int(p) + query, err := requiredParam[string](request, "query") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - perPage := 30 - if pp, ok := request.Params.Arguments["per_page"].(float64); ok { - perPage = int(pp) + page, err := optionalIntParamWithDefault(request, "page", 1) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + perPage, err := optionalIntParamWithDefault(request, "per_page", 30) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } opts := &github.SearchOptions{ @@ -90,22 +93,25 @@ func searchCode(client *github.Client, t translations.TranslationHelperFunc) (to ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - query := request.Params.Arguments["q"].(string) - sort := "" - if s, ok := request.Params.Arguments["sort"].(string); ok { - sort = s + query, err := requiredParam[string](request, "q") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - order := "" - if o, ok := request.Params.Arguments["order"].(string); ok { - order = o + sort, err := optionalParam[string](request, "sort") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - perPage := 30 - if pp, ok := request.Params.Arguments["per_page"].(float64); ok { - perPage = int(pp) + order, err := optionalParam[string](request, "order") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - page := 1 - if p, ok := request.Params.Arguments["page"].(float64); ok { - page = int(p) + perPage, err := optionalIntParamWithDefault(request, "per_page", 30) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + page, err := optionalIntParamWithDefault(request, "page", 1) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } opts := &github.SearchOptions{ @@ -162,22 +168,25 @@ func searchUsers(client *github.Client, t translations.TranslationHelperFunc) (t ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - query := request.Params.Arguments["q"].(string) - sort := "" - if s, ok := request.Params.Arguments["sort"].(string); ok { - sort = s - } - order := "" - if o, ok := request.Params.Arguments["order"].(string); ok { - order = o - } - perPage := 30 - if pp, ok := request.Params.Arguments["per_page"].(float64); ok { - perPage = int(pp) - } - page := 1 - if p, ok := request.Params.Arguments["page"].(float64); ok { - page = int(p) + query, err := requiredParam[string](request, "q") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + sort, err := optionalParam[string](request, "sort") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + order, err := optionalParam[string](request, "order") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + perPage, err := optionalIntParamWithDefault(request, "per_page", 30) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + page, err := optionalIntParamWithDefault(request, "page", 1) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } opts := &github.SearchOptions{ diff --git a/pkg/github/server.go b/pkg/github/server.go index a0993e2f..829994f1 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -138,3 +138,104 @@ func parseCommaSeparatedList(input string) []string { return result } + +// requiredParam is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request. +// 2. Checks if the parameter is of the expected type. +// 3. Checks if the parameter is not empty, i.e: non-zero value +func requiredParam[T comparable](r mcp.CallToolRequest, p string) (T, error) { + var zero T + + // Check if the parameter is present in the request + if _, ok := r.Params.Arguments[p]; !ok { + return zero, fmt.Errorf("missing required parameter: %s", p) + } + + // Check if the parameter is of the expected type + if _, ok := r.Params.Arguments[p].(T); !ok { + return zero, fmt.Errorf("parameter %s is not of type %T", p, zero) + } + + if r.Params.Arguments[p].(T) == zero { + return zero, fmt.Errorf("missing required parameter: %s", p) + + } + + return r.Params.Arguments[p].(T), nil +} + +// requiredInt is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request. +// 2. Checks if the parameter is of the expected type. +// 3. Checks if the parameter is not empty, i.e: non-zero value +func requiredInt(r mcp.CallToolRequest, p string) (int, error) { + v, err := requiredParam[float64](r, p) + if err != nil { + return 0, err + } + return int(v), nil +} + +// optionalParam is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request, if not, it returns its zero-value +// 2. If it is present, it checks if the parameter is of the expected type and returns it +func optionalParam[T any](r mcp.CallToolRequest, p string) (T, error) { + var zero T + + // Check if the parameter is present in the request + if _, ok := r.Params.Arguments[p]; !ok { + return zero, nil + } + + // Check if the parameter is of the expected type + if _, ok := r.Params.Arguments[p].(T); !ok { + return zero, fmt.Errorf("parameter %s is not of type %T", p, zero) + } + + return r.Params.Arguments[p].(T), nil +} + +// optionalIntParam is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request, if not, it returns its zero-value +// 2. If it is present, it checks if the parameter is of the expected type and returns it +func optionalIntParam(r mcp.CallToolRequest, p string) (int, error) { + v, err := optionalParam[float64](r, p) + if err != nil { + return 0, err + } + return int(v), nil +} + +// optionalIntParamWithDefault is a helper function that can be used to fetch a requested parameter from the request +// similar to optionalIntParam, but it also takes a default value. +func optionalIntParamWithDefault(r mcp.CallToolRequest, p string, d int) (int, error) { + v, err := optionalIntParam(r, p) + if err != nil { + return 0, err + } + if v == 0 { + return d, nil + } + return v, nil +} + +// optionalCommaSeparatedListParam is a helper function that can be used to fetch a requested parameter from the request. +// It does the following: +// 1. Checks if the parameter is present in the request, if not, it returns an empty list +// 2. If it is present, it checks if the parameter is of the expected type and uses parseCommaSeparatedList to parse it +// and return the list of strings +func optionalCommaSeparatedListParam(r mcp.CallToolRequest, p string) ([]string, error) { + v, err := optionalParam[string](r, p) + if err != nil { + return []string{}, err + } + l := parseCommaSeparatedList(v) + if len(l) == 0 { + return []string{}, nil + } + return l, nil +} diff --git a/pkg/github/server_test.go b/pkg/github/server_test.go index 316a0efa..5e7ac9d4 100644 --- a/pkg/github/server_test.go +++ b/pkg/github/server_test.go @@ -228,3 +228,372 @@ func Test_ParseCommaSeparatedList(t *testing.T) { }) } } + +func Test_RequiredStringParam(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + expected string + expectError bool + }{ + { + name: "valid string parameter", + params: map[string]interface{}{"name": "test-value"}, + paramName: "name", + expected: "test-value", + expectError: false, + }, + { + name: "missing parameter", + params: map[string]interface{}{}, + paramName: "name", + expected: "", + expectError: true, + }, + { + name: "empty string parameter", + params: map[string]interface{}{"name": ""}, + paramName: "name", + expected: "", + expectError: true, + }, + { + name: "wrong type parameter", + params: map[string]interface{}{"name": 123}, + paramName: "name", + expected: "", + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + request := createMCPRequest(tc.params) + result, err := requiredParam[string](request, tc.paramName) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} + +func Test_OptionalStringParam(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + expected string + expectError bool + }{ + { + name: "valid string parameter", + params: map[string]interface{}{"name": "test-value"}, + paramName: "name", + expected: "test-value", + expectError: false, + }, + { + name: "missing parameter", + params: map[string]interface{}{}, + paramName: "name", + expected: "", + expectError: false, + }, + { + name: "empty string parameter", + params: map[string]interface{}{"name": ""}, + paramName: "name", + expected: "", + expectError: false, + }, + { + name: "wrong type parameter", + params: map[string]interface{}{"name": 123}, + paramName: "name", + expected: "", + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + request := createMCPRequest(tc.params) + result, err := optionalParam[string](request, tc.paramName) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} + +func Test_RequiredNumberParam(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + expected int + expectError bool + }{ + { + name: "valid number parameter", + params: map[string]interface{}{"count": float64(42)}, + paramName: "count", + expected: 42, + expectError: false, + }, + { + name: "missing parameter", + params: map[string]interface{}{}, + paramName: "count", + expected: 0, + expectError: true, + }, + { + name: "wrong type parameter", + params: map[string]interface{}{"count": "not-a-number"}, + paramName: "count", + expected: 0, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + request := createMCPRequest(tc.params) + result, err := requiredInt(request, tc.paramName) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} + +func Test_OptionalNumberParam(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + expected int + expectError bool + }{ + { + name: "valid number parameter", + params: map[string]interface{}{"count": float64(42)}, + paramName: "count", + expected: 42, + expectError: false, + }, + { + name: "missing parameter", + params: map[string]interface{}{}, + paramName: "count", + expected: 0, + expectError: false, + }, + { + name: "zero value", + params: map[string]interface{}{"count": float64(0)}, + paramName: "count", + expected: 0, + expectError: false, + }, + { + name: "wrong type parameter", + params: map[string]interface{}{"count": "not-a-number"}, + paramName: "count", + expected: 0, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + request := createMCPRequest(tc.params) + result, err := optionalIntParam(request, tc.paramName) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} + +func Test_OptionalNumberParamWithDefault(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + defaultVal int + expected int + expectError bool + }{ + { + name: "valid number parameter", + params: map[string]interface{}{"count": float64(42)}, + paramName: "count", + defaultVal: 10, + expected: 42, + expectError: false, + }, + { + name: "missing parameter", + params: map[string]interface{}{}, + paramName: "count", + defaultVal: 10, + expected: 10, + expectError: false, + }, + { + name: "zero value", + params: map[string]interface{}{"count": float64(0)}, + paramName: "count", + defaultVal: 10, + expected: 10, + expectError: false, + }, + { + name: "wrong type parameter", + params: map[string]interface{}{"count": "not-a-number"}, + paramName: "count", + defaultVal: 10, + expected: 0, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + request := createMCPRequest(tc.params) + result, err := optionalIntParamWithDefault(request, tc.paramName, tc.defaultVal) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} + +func Test_OptionalCommaSeparatedListParam(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + expected []string + expectError bool + }{ + { + name: "valid comma-separated list", + params: map[string]interface{}{"tags": "one,two,three"}, + paramName: "tags", + expected: []string{"one", "two", "three"}, + expectError: false, + }, + { + name: "empty list", + params: map[string]interface{}{"tags": ""}, + paramName: "tags", + expected: []string{}, + expectError: false, + }, + { + name: "missing parameter", + params: map[string]interface{}{}, + paramName: "tags", + expected: []string{}, + expectError: false, + }, + { + name: "wrong type parameter", + params: map[string]interface{}{"tags": 123}, + paramName: "tags", + expected: nil, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + request := createMCPRequest(tc.params) + result, err := optionalCommaSeparatedListParam(request, tc.paramName) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} + +func Test_OptionalBooleanParam(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + expected bool + expectError bool + }{ + { + name: "true value", + params: map[string]interface{}{"flag": true}, + paramName: "flag", + expected: true, + expectError: false, + }, + { + name: "false value", + params: map[string]interface{}{"flag": false}, + paramName: "flag", + expected: false, + expectError: false, + }, + { + name: "missing parameter", + params: map[string]interface{}{}, + paramName: "flag", + expected: false, + expectError: false, + }, + { + name: "wrong type parameter", + params: map[string]interface{}{"flag": "not-a-boolean"}, + paramName: "flag", + expected: false, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + request := createMCPRequest(tc.params) + result, err := optionalParam[bool](request, tc.paramName) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +}