diff --git a/README.md b/README.md index a17bf435..b3ad2d71 100644 --- a/README.md +++ b/README.md @@ -149,6 +149,7 @@ The following sets of tools are available (all are on by default): | ----------------------- | ------------------------------------------------------------- | | `repos` | Repository-related tools (file operations, branches, commits) | | `issues` | Issue-related tools (create, read, update, comment) | +| `discussions` | GitHub Discussions tools (list, get, comments, categories) | | `users` | Anything relating to GitHub Users | | `pull_requests` | Pull request operations (create, merge, review) | | `code_security` | Code scanning alerts and security features | @@ -615,6 +616,39 @@ export GITHUB_MCP_TOOL_ADD_ISSUE_COMMENT_DESCRIPTION="an alternative description - `repo`: The name of the repository (string, required) - `action`: Action to perform: `ignore`, `watch`, or `delete` (string, required) +### Discussions + +- **list_discussions** - List discussions for a repository + - `owner`: Repository owner (string, required) + - `repo`: Repository name (string, required) + - `category`: Filter by category name (string, optional) + - `since`: Filter by date (ISO 8601 timestamp) (string, optional) + - `first`: Pagination - Number of records to retrieve (number, optional) + - `last`: Pagination - Number of records to retrieve from the end (number, optional) + - `after`: Pagination - Cursor to start with (string, optional) + - `before`: Pagination - Cursor to end with (string, optional) + - `sort`: Sort by ('CREATED_AT', 'UPDATED_AT') (string, optional) + - `direction`: Sort direction ('ASC', 'DESC') (string, optional) + - `answered`: Filter by whether discussions have been answered or not (boolean, optional) + +- **get_discussion** - Get a specific discussion by ID + - `owner`: Repository owner (string, required) + - `repo`: Repository name (string, required) + - `discussionNumber`: Discussion number (required) + +- **get_discussion_comments** - Get comments from a discussion + - `owner`: Repository owner (string, required) + - `repo`: Repository name (string, required) + - `discussionNumber`: Discussion number (required) + +- **list_discussion_categories** - List discussion categories for a repository, with their IDs and names + - `owner`: Repository owner (string, required) + - `repo`: Repository name (string, required) + - `first`: Pagination - Number of categories to return per page (number, optional, min 1, max 100) + - `last`: Pagination - Number of categories to return from the end (number, optional, min 1, max 100) + - `after`: Pagination - Cursor to start with (string, optional) + - `before`: Pagination - Cursor to end with (string, optional) + ## Resources ### Repository Content diff --git a/pkg/github/discussions.go b/pkg/github/discussions.go new file mode 100644 index 00000000..54ae6d04 --- /dev/null +++ b/pkg/github/discussions.go @@ -0,0 +1,459 @@ +package github + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/github/github-mcp-server/pkg/translations" + "github.com/go-viper/mapstructure/v2" + "github.com/google/go-github/v72/github" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "github.com/shurcooL/githubv4" +) + +// GetAllDiscussionCategories retrieves all discussion categories for a repository +// by paginating through all pages and returns them as a map where the key is the +// category name and the value is the category ID. +func GetAllDiscussionCategories(ctx context.Context, client *githubv4.Client, owner, repo string) (map[string]string, error) { + categories := make(map[string]string) + var after string + hasNextPage := true + + for hasNextPage { + // Prepare GraphQL query with pagination + var q struct { + Repository struct { + DiscussionCategories struct { + Nodes []struct { + ID githubv4.ID + Name githubv4.String + } + PageInfo struct { + HasNextPage githubv4.Boolean + EndCursor githubv4.String + } + } `graphql:"discussionCategories(first: 100, after: $after)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + } + + vars := map[string]interface{}{ + "owner": githubv4.String(owner), + "repo": githubv4.String(repo), + "after": githubv4.String(after), + } + + if err := client.Query(ctx, &q, vars); err != nil { + return nil, fmt.Errorf("failed to query discussion categories: %w", err) + } + + // Add categories to the map + for _, category := range q.Repository.DiscussionCategories.Nodes { + categories[string(category.Name)] = fmt.Sprint(category.ID) + } + + // Check if there are more pages + hasNextPage = bool(q.Repository.DiscussionCategories.PageInfo.HasNextPage) + if hasNextPage { + after = string(q.Repository.DiscussionCategories.PageInfo.EndCursor) + } + } + + return categories, nil +} + +func ListDiscussions(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("list_discussions", + mcp.WithDescription(t("TOOL_LIST_DISCUSSIONS_DESCRIPTION", "List discussions for a repository")), + mcp.WithToolAnnotation(mcp.ToolAnnotation{ + Title: t("TOOL_LIST_DISCUSSIONS_USER_TITLE", "List discussions"), + ReadOnlyHint: toBoolPtr(true), + }), + mcp.WithString("owner", + mcp.Required(), + mcp.Description("Repository owner"), + ), + mcp.WithString("repo", + mcp.Required(), + mcp.Description("Repository name"), + ), + mcp.WithString("category", + mcp.Description("Category filter (name)"), + ), + mcp.WithString("since", + mcp.Description("Filter by date (ISO 8601 timestamp)"), + ), + mcp.WithString("sort", + mcp.Description("Sort field"), + mcp.DefaultString("CREATED_AT"), + mcp.Enum("CREATED_AT", "UPDATED_AT"), + ), + mcp.WithString("direction", + mcp.Description("Sort direction"), + mcp.DefaultString("DESC"), + mcp.Enum("ASC", "DESC"), + ), + mcp.WithNumber("first", + mcp.Description("Number of discussions to return per page (min 1, max 100)"), + mcp.Min(1), + mcp.Max(100), + ), + mcp.WithNumber("last", + mcp.Description("Number of discussions to return from the end (min 1, max 100)"), + mcp.Min(1), + mcp.Max(100), + ), + mcp.WithString("after", + mcp.Description("Cursor for pagination, use the 'after' field from the previous response"), + ), + mcp.WithString("before", + mcp.Description("Cursor for pagination, use the 'before' field from the previous response"), + ), + mcp.WithBoolean("answered", + mcp.Description("Filter by whether discussions have been answered or not"), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Decode params + var params struct { + Owner string + Repo string + Category string + Since string + Sort string + Direction string + First int32 + Last int32 + After string + Before string + Answered bool + } + if err := mapstructure.Decode(request.Params.Arguments, ¶ms); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if params.First != 0 && params.Last != 0 { + return mcp.NewToolResultError("only one of 'first' or 'last' may be specified"), nil + } + if params.After != "" && params.Before != "" { + return mcp.NewToolResultError("only one of 'after' or 'before' may be specified"), nil + } + if params.After != "" && params.Last != 0 { + return mcp.NewToolResultError("'after' cannot be used with 'last'. Did you mean to use 'before' instead?"), nil + } + if params.Before != "" && params.First != 0 { + return mcp.NewToolResultError("'before' cannot be used with 'first'. Did you mean to use 'after' instead?"), nil + } + // Get GraphQL client + client, err := getGQLClient(ctx) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to get GitHub GQL client: %v", err)), nil + } + // Prepare GraphQL query + var q struct { + Repository struct { + Discussions struct { + Nodes []struct { + Number githubv4.Int + Title githubv4.String + CreatedAt githubv4.DateTime + Category struct { + Name githubv4.String + } `graphql:"category"` + URL githubv4.String `graphql:"url"` + } + } `graphql:"discussions(categoryId: $categoryId, orderBy: {field: $sort, direction: $direction}, first: $first, after: $after, last: $last, before: $before, answered: $answered)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + } + categories, err := GetAllDiscussionCategories(ctx, client, params.Owner, params.Repo) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to get discussion categories: %v", err)), nil + } + var categoryID githubv4.ID = categories[params.Category] + if categoryID == "" && params.Category != "" { + return mcp.NewToolResultError(fmt.Sprintf("category '%s' not found", params.Category)), nil + } + // Build query variables + vars := map[string]interface{}{ + "owner": githubv4.String(params.Owner), + "repo": githubv4.String(params.Repo), + "categoryId": categoryID, + "sort": githubv4.DiscussionOrderField(params.Sort), + "direction": githubv4.OrderDirection(params.Direction), + "first": githubv4.Int(params.First), + "last": githubv4.Int(params.Last), + "after": githubv4.String(params.After), + "before": githubv4.String(params.Before), + "answered": githubv4.Boolean(params.Answered), + } + // Execute query + if err := client.Query(ctx, &q, vars); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + // Map nodes to GitHub Issue objects - there is no discussion type in the GitHub API, so we use Issue to benefit from existing code + var discussions []*github.Issue + for _, n := range q.Repository.Discussions.Nodes { + di := &github.Issue{ + Number: github.Ptr(int(n.Number)), + Title: github.Ptr(string(n.Title)), + HTMLURL: github.Ptr(string(n.URL)), + CreatedAt: &github.Timestamp{Time: n.CreatedAt.Time}, + } + discussions = append(discussions, di) + } + + // Post filtering discussions based on 'since' parameter + if params.Since != "" { + sinceTime, err := time.Parse(time.RFC3339, params.Since) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("invalid 'since' timestamp: %v", err)), nil + } + var filteredDiscussions []*github.Issue + for _, d := range discussions { + if d.CreatedAt.Time.After(sinceTime) { + filteredDiscussions = append(filteredDiscussions, d) + } + } + discussions = filteredDiscussions + } + + // Marshal and return + out, err := json.Marshal(discussions) + if err != nil { + return nil, fmt.Errorf("failed to marshal discussions: %w", err) + } + return mcp.NewToolResultText(string(out)), nil + } +} + +func GetDiscussion(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("get_discussion", + mcp.WithDescription(t("TOOL_GET_DISCUSSION_DESCRIPTION", "Get a specific discussion by ID")), + mcp.WithToolAnnotation(mcp.ToolAnnotation{ + Title: t("TOOL_GET_DISCUSSION_USER_TITLE", "Get discussion"), + ReadOnlyHint: toBoolPtr(true), + }), + mcp.WithString("owner", + mcp.Required(), + mcp.Description("Repository owner"), + ), + mcp.WithString("repo", + mcp.Required(), + mcp.Description("Repository name"), + ), + mcp.WithNumber("discussionNumber", + mcp.Required(), + mcp.Description("Discussion Number"), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Decode params + var params struct { + Owner string + Repo string + DiscussionNumber int32 + } + if err := mapstructure.Decode(request.Params.Arguments, ¶ms); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + client, err := getGQLClient(ctx) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to get GitHub GQL client: %v", err)), nil + } + + var q struct { + Repository struct { + Discussion struct { + Number githubv4.Int + Body githubv4.String + State githubv4.String + CreatedAt githubv4.DateTime + URL githubv4.String `graphql:"url"` + } `graphql:"discussion(number: $discussionNumber)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + } + vars := map[string]interface{}{ + "owner": githubv4.String(params.Owner), + "repo": githubv4.String(params.Repo), + "discussionNumber": githubv4.Int(params.DiscussionNumber), + } + if err := client.Query(ctx, &q, vars); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + d := q.Repository.Discussion + discussion := &github.Issue{ + Number: github.Ptr(int(d.Number)), + Body: github.Ptr(string(d.Body)), + State: github.Ptr(string(d.State)), + HTMLURL: github.Ptr(string(d.URL)), + CreatedAt: &github.Timestamp{Time: d.CreatedAt.Time}, + } + out, err := json.Marshal(discussion) + if err != nil { + return nil, fmt.Errorf("failed to marshal discussion: %w", err) + } + + return mcp.NewToolResultText(string(out)), nil + } +} + +func GetDiscussionComments(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("get_discussion_comments", + mcp.WithDescription(t("TOOL_GET_DISCUSSION_COMMENTS_DESCRIPTION", "Get comments from a discussion")), + mcp.WithToolAnnotation(mcp.ToolAnnotation{ + Title: t("TOOL_GET_DISCUSSION_COMMENTS_USER_TITLE", "Get discussion comments"), + ReadOnlyHint: toBoolPtr(true), + }), + mcp.WithString("owner", mcp.Required(), mcp.Description("Repository owner")), + mcp.WithString("repo", mcp.Required(), mcp.Description("Repository name")), + mcp.WithNumber("discussionNumber", mcp.Required(), mcp.Description("Discussion Number")), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Decode params + var params struct { + Owner string + Repo string + DiscussionNumber int32 + } + if err := mapstructure.Decode(request.Params.Arguments, ¶ms); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + client, err := getGQLClient(ctx) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to get GitHub GQL client: %v", err)), nil + } + + var q struct { + Repository struct { + Discussion struct { + Comments struct { + Nodes []struct { + Body githubv4.String + } + } `graphql:"comments(first:100)"` + } `graphql:"discussion(number: $discussionNumber)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + } + vars := map[string]interface{}{ + "owner": githubv4.String(params.Owner), + "repo": githubv4.String(params.Repo), + "discussionNumber": githubv4.Int(params.DiscussionNumber), + } + if err := client.Query(ctx, &q, vars); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + var comments []*github.IssueComment + for _, c := range q.Repository.Discussion.Comments.Nodes { + comments = append(comments, &github.IssueComment{Body: github.Ptr(string(c.Body))}) + } + + out, err := json.Marshal(comments) + if err != nil { + return nil, fmt.Errorf("failed to marshal comments: %w", err) + } + + return mcp.NewToolResultText(string(out)), nil + } +} + +func ListDiscussionCategories(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("list_discussion_categories", + mcp.WithDescription(t("TOOL_LIST_DISCUSSION_CATEGORIES_DESCRIPTION", "List discussion categories with their id and name, for a repository")), + mcp.WithToolAnnotation(mcp.ToolAnnotation{ + Title: t("TOOL_LIST_DISCUSSION_CATEGORIES_USER_TITLE", "List discussion categories"), + ReadOnlyHint: toBoolPtr(true), + }), + mcp.WithString("owner", + mcp.Required(), + mcp.Description("Repository owner"), + ), + mcp.WithString("repo", + mcp.Required(), + mcp.Description("Repository name"), + ), + mcp.WithNumber("first", + mcp.Description("Number of categories to return per page (min 1, max 100)"), + mcp.Min(1), + mcp.Max(100), + ), + mcp.WithNumber("last", + mcp.Description("Number of categories to return from the end (min 1, max 100)"), + mcp.Min(1), + mcp.Max(100), + ), + mcp.WithString("after", + mcp.Description("Cursor for pagination, use the 'after' field from the previous response"), + ), + mcp.WithString("before", + mcp.Description("Cursor for pagination, use the 'before' field from the previous response"), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Decode params + var params struct { + Owner string + Repo string + First int32 + Last int32 + After string + Before string + } + if err := mapstructure.Decode(request.Params.Arguments, ¶ms); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + // Validate pagination parameters + if params.First != 0 && params.Last != 0 { + return mcp.NewToolResultError("only one of 'first' or 'last' may be specified"), nil + } + if params.After != "" && params.Before != "" { + return mcp.NewToolResultError("only one of 'after' or 'before' may be specified"), nil + } + if params.After != "" && params.Last != 0 { + return mcp.NewToolResultError("'after' cannot be used with 'last'. Did you mean to use 'before' instead?"), nil + } + if params.Before != "" && params.First != 0 { + return mcp.NewToolResultError("'before' cannot be used with 'first'. Did you mean to use 'after' instead?"), nil + } + + client, err := getGQLClient(ctx) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to get GitHub GQL client: %v", err)), nil + } + var q struct { + Repository struct { + DiscussionCategories struct { + Nodes []struct { + ID githubv4.ID + Name githubv4.String + } + } `graphql:"discussionCategories(first: $first, last: $last, after: $after, before: $before)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + } + vars := map[string]interface{}{ + "owner": githubv4.String(params.Owner), + "repo": githubv4.String(params.Repo), + "first": githubv4.Int(params.First), + "last": githubv4.Int(params.Last), + "after": githubv4.String(params.After), + "before": githubv4.String(params.Before), + } + if err := client.Query(ctx, &q, vars); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + var categories []map[string]string + for _, c := range q.Repository.DiscussionCategories.Nodes { + categories = append(categories, map[string]string{ + "id": fmt.Sprint(c.ID), + "name": string(c.Name), + }) + } + out, err := json.Marshal(categories) + if err != nil { + return nil, fmt.Errorf("failed to marshal discussion categories: %w", err) + } + return mcp.NewToolResultText(string(out)), nil + } +} diff --git a/pkg/github/discussions_test.go b/pkg/github/discussions_test.go new file mode 100644 index 00000000..8b10d0c9 --- /dev/null +++ b/pkg/github/discussions_test.go @@ -0,0 +1,510 @@ +package github + +import ( + "context" + "encoding/json" + "testing" + "time" + + "github.com/github/github-mcp-server/internal/githubv4mock" + "github.com/github/github-mcp-server/pkg/translations" + "github.com/google/go-github/v72/github" + "github.com/shurcooL/githubv4" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var ( + discussionsAll = []map[string]any{ + {"number": 1, "title": "Discussion 1 title", "createdAt": "2023-01-01T00:00:00Z", "category": map[string]any{"name": "news"}, "url": "https://github.com/owner/repo/discussions/1"}, + {"number": 2, "title": "Discussion 2 title", "createdAt": "2023-02-01T00:00:00Z", "category": map[string]any{"name": "updates"}, "url": "https://github.com/owner/repo/discussions/2"}, + {"number": 3, "title": "Discussion 3 title", "createdAt": "2023-03-01T00:00:00Z", "category": map[string]any{"name": "questions"}, "url": "https://github.com/owner/repo/discussions/3"}, + } + mockResponseListAll = githubv4mock.DataResponse(map[string]any{ + "repository": map[string]any{ + "discussions": map[string]any{"nodes": discussionsAll}, + }, + }) + mockResponseCategory = githubv4mock.DataResponse(map[string]any{ + "repository": map[string]any{ + "discussions": map[string]any{"nodes": discussionsAll[:1]}, // Only return the first discussion for category test + }, + }) + mockErrorRepoNotFound = githubv4mock.ErrorResponse("repository not found") +) + +func Test_ListDiscussions(t *testing.T) { + // Verify tool definition and schema + toolDef, _ := ListDiscussions(nil, translations.NullTranslationHelper) + assert.Equal(t, "list_discussions", toolDef.Name) + assert.NotEmpty(t, toolDef.Description) + assert.Contains(t, toolDef.InputSchema.Properties, "owner") + assert.Contains(t, toolDef.InputSchema.Properties, "repo") + assert.ElementsMatch(t, toolDef.InputSchema.Required, []string{"owner", "repo"}) + + // mock for the call to list all categories: query struct, variables, response + var qCat struct { + Repository struct { + DiscussionCategories struct { + Nodes []struct { + ID githubv4.ID + Name githubv4.String + } + PageInfo struct { + HasNextPage githubv4.Boolean + EndCursor githubv4.String + } + } `graphql:"discussionCategories(first: 100, after: $after)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + } + + varsCat := map[string]interface{}{ + "owner": githubv4.String("owner"), + "repo": githubv4.String("repo"), + "after": githubv4.String(""), + } + + varsCatInvalid := map[string]interface{}{ + "owner": githubv4.String("invalid"), + "repo": githubv4.String("repo"), + "after": githubv4.String(""), + } + + mockRespCat := githubv4mock.DataResponse(map[string]any{ + "repository": map[string]any{ + "discussionCategories": map[string]any{ + "nodes": []map[string]any{ + {"id": "123", "name": "CategoryOne"}, + {"id": "456", "name": "CategoryTwo"}, + }, + }, + }, + }) + + mockRespCatInvalid := githubv4mock.ErrorResponse("repository not found") + + // mock for the call to ListDiscussions: query struct, variables, response + var q struct { + Repository struct { + Discussions struct { + Nodes []struct { + Number githubv4.Int + Title githubv4.String + CreatedAt githubv4.DateTime + Category struct { + Name githubv4.String + } `graphql:"category"` + URL githubv4.String `graphql:"url"` + } + } `graphql:"discussions(categoryId: $categoryId, orderBy: {field: $sort, direction: $direction}, first: $first, after: $after, last: $last, before: $before, answered: $answered)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + } + + varsListAll := map[string]interface{}{ + "owner": githubv4.String("owner"), + "repo": githubv4.String("repo"), + "categoryId": githubv4.ID(""), + "sort": githubv4.DiscussionOrderField(""), + "direction": githubv4.OrderDirection(""), + "first": githubv4.Int(0), + "last": githubv4.Int(0), + "after": githubv4.String(""), + "before": githubv4.String(""), + "answered": githubv4.Boolean(false), + } + + varsListInvalid := map[string]interface{}{ + "owner": githubv4.String("invalid"), + "repo": githubv4.String("repo"), + "categoryId": githubv4.ID(""), + "sort": githubv4.DiscussionOrderField(""), + "direction": githubv4.OrderDirection(""), + "first": githubv4.Int(0), + "last": githubv4.Int(0), + "after": githubv4.String(""), + "before": githubv4.String(""), + "answered": githubv4.Boolean(false), + } + + varsListWithCategory := map[string]interface{}{ + "owner": githubv4.String("owner"), + "repo": githubv4.String("repo"), + "categoryId": githubv4.ID("123"), + "sort": githubv4.DiscussionOrderField(""), + "direction": githubv4.OrderDirection(""), + "first": githubv4.Int(0), + "last": githubv4.Int(0), + "after": githubv4.String(""), + "before": githubv4.String(""), + "answered": githubv4.Boolean(false), + } + + catMatcher := githubv4mock.NewQueryMatcher(qCat, varsCat, mockRespCat) + catMatcherInvalid := githubv4mock.NewQueryMatcher(qCat, varsCatInvalid, mockRespCatInvalid) + + tests := []struct { + name string + vars map[string]interface{} + reqParams map[string]interface{} + response githubv4mock.GQLResponse + expectError bool + expectedIds []int64 + errContains string + catMatcher githubv4mock.Matcher + }{ + { + name: "list all discussions", + vars: varsListAll, + reqParams: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + }, + response: mockResponseListAll, + expectError: false, + expectedIds: []int64{1, 2, 3}, + catMatcher: catMatcher, + }, + { + name: "invalid owner or repo", + vars: varsListInvalid, + reqParams: map[string]interface{}{ + "owner": "invalid", + "repo": "repo", + }, + response: mockErrorRepoNotFound, + expectError: true, + errContains: "repository not found", + catMatcher: catMatcherInvalid, + }, + { + name: "list discussions with category", + vars: varsListWithCategory, + reqParams: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "category": "CategoryOne", // This should match the ID "123" in the mock response + }, + response: mockResponseCategory, + expectError: false, + expectedIds: []int64{1}, + catMatcher: catMatcher, + }, + { + name: "list discussions with since date", + vars: varsListAll, + reqParams: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "since": "2023-01-10T00:00:00Z", + }, + response: mockResponseListAll, + expectError: false, + expectedIds: []int64{2, 3}, + catMatcher: catMatcher, + }, + { + name: "both first and last parameters provided", + vars: varsListAll, // vars don't matter since error occurs before GraphQL call + reqParams: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "first": int32(10), + "last": int32(5), + }, + response: mockResponseListAll, // response doesn't matter since error occurs before GraphQL call + expectError: true, + errContains: "only one of 'first' or 'last' may be specified", + catMatcher: catMatcher, + }, + { + name: "after with last parameters provided", + vars: varsListAll, // vars don't matter since error occurs before GraphQL call + reqParams: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "after": "cursor123", + "last": int32(5), + }, + response: mockResponseListAll, // response doesn't matter since error occurs before GraphQL call + expectError: true, + errContains: "'after' cannot be used with 'last'. Did you mean to use 'before' instead?", + catMatcher: catMatcher, + }, + { + name: "before with first parameters provided", + vars: varsListAll, // vars don't matter since error occurs before GraphQL call + reqParams: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "before": "cursor456", + "first": int32(10), + }, + response: mockResponseListAll, // response doesn't matter since error occurs before GraphQL call + expectError: true, + errContains: "'before' cannot be used with 'first'. Did you mean to use 'after' instead?", + catMatcher: catMatcher, + }, + { + name: "both after and before parameters provided", + vars: varsListAll, // vars don't matter since error occurs before GraphQL call + reqParams: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "after": "cursor123", + "before": "cursor456", + }, + response: mockResponseListAll, // response doesn't matter since error occurs before GraphQL call + expectError: true, + errContains: "only one of 'after' or 'before' may be specified", + catMatcher: catMatcher, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + matcher := githubv4mock.NewQueryMatcher(q, tc.vars, tc.response) + httpClient := githubv4mock.NewMockedHTTPClient(matcher, tc.catMatcher) + gqlClient := githubv4.NewClient(httpClient) + _, handler := ListDiscussions(stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper) + + req := createMCPRequest(tc.reqParams) + res, err := handler(context.Background(), req) + text := getTextResult(t, res).Text + + if tc.expectError { + require.True(t, res.IsError) + assert.Contains(t, text, tc.errContains) + return + } + require.NoError(t, err) + + var returnedDiscussions []*github.Issue + err = json.Unmarshal([]byte(text), &returnedDiscussions) + require.NoError(t, err) + + assert.Len(t, returnedDiscussions, len(tc.expectedIds), "Expected %d discussions, got %d", len(tc.expectedIds), len(returnedDiscussions)) + + // If no discussions are expected, skip further checks + if len(tc.expectedIds) == 0 { + return + } + + // Create a map of expected IDs for easier checking + expectedIDMap := make(map[int64]bool) + for _, id := range tc.expectedIds { + expectedIDMap[id] = true + } + + for _, discussion := range returnedDiscussions { + // Check if the discussion Number is in the expected list + assert.True(t, expectedIDMap[int64(*discussion.Number)], "Unexpected discussion Number: %d", *discussion.Number) + } + }) + } +} + +func Test_GetDiscussion(t *testing.T) { + // Verify tool definition and schema + toolDef, _ := GetDiscussion(nil, translations.NullTranslationHelper) + assert.Equal(t, "get_discussion", toolDef.Name) + assert.NotEmpty(t, toolDef.Description) + assert.Contains(t, toolDef.InputSchema.Properties, "owner") + assert.Contains(t, toolDef.InputSchema.Properties, "repo") + assert.Contains(t, toolDef.InputSchema.Properties, "discussionNumber") + assert.ElementsMatch(t, toolDef.InputSchema.Required, []string{"owner", "repo", "discussionNumber"}) + + var q struct { + Repository struct { + Discussion struct { + Number githubv4.Int + Body githubv4.String + State githubv4.String + CreatedAt githubv4.DateTime + URL githubv4.String `graphql:"url"` + } `graphql:"discussion(number: $discussionNumber)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + } + vars := map[string]interface{}{ + "owner": githubv4.String("owner"), + "repo": githubv4.String("repo"), + "discussionNumber": githubv4.Int(1), + } + tests := []struct { + name string + response githubv4mock.GQLResponse + expectError bool + expected *github.Issue + errContains string + }{ + { + name: "successful retrieval", + response: githubv4mock.DataResponse(map[string]any{ + "repository": map[string]any{"discussion": map[string]any{ + "number": 1, + "body": "This is a test discussion", + "state": "open", + "url": "https://github.com/owner/repo/discussions/1", + "createdAt": "2025-04-25T12:00:00Z", + }}, + }), + expectError: false, + expected: &github.Issue{ + HTMLURL: github.Ptr("https://github.com/owner/repo/discussions/1"), + Number: github.Ptr(1), + Body: github.Ptr("This is a test discussion"), + State: github.Ptr("open"), + CreatedAt: &github.Timestamp{Time: time.Date(2025, 4, 25, 12, 0, 0, 0, time.UTC)}, + }, + }, + { + name: "discussion not found", + response: githubv4mock.ErrorResponse("discussion not found"), + expectError: true, + errContains: "discussion not found", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + matcher := githubv4mock.NewQueryMatcher(q, vars, tc.response) + httpClient := githubv4mock.NewMockedHTTPClient(matcher) + gqlClient := githubv4.NewClient(httpClient) + _, handler := GetDiscussion(stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper) + + req := createMCPRequest(map[string]interface{}{"owner": "owner", "repo": "repo", "discussionNumber": int32(1)}) + res, err := handler(context.Background(), req) + text := getTextResult(t, res).Text + + if tc.expectError { + require.True(t, res.IsError) + assert.Contains(t, text, tc.errContains) + return + } + + require.NoError(t, err) + var out github.Issue + require.NoError(t, json.Unmarshal([]byte(text), &out)) + assert.Equal(t, *tc.expected.HTMLURL, *out.HTMLURL) + assert.Equal(t, *tc.expected.Number, *out.Number) + assert.Equal(t, *tc.expected.Body, *out.Body) + assert.Equal(t, *tc.expected.State, *out.State) + }) + } +} + +func Test_GetDiscussionComments(t *testing.T) { + // Verify tool definition and schema + toolDef, _ := GetDiscussionComments(nil, translations.NullTranslationHelper) + assert.Equal(t, "get_discussion_comments", toolDef.Name) + assert.NotEmpty(t, toolDef.Description) + assert.Contains(t, toolDef.InputSchema.Properties, "owner") + assert.Contains(t, toolDef.InputSchema.Properties, "repo") + assert.Contains(t, toolDef.InputSchema.Properties, "discussionNumber") + assert.ElementsMatch(t, toolDef.InputSchema.Required, []string{"owner", "repo", "discussionNumber"}) + + var q struct { + Repository struct { + Discussion struct { + Comments struct { + Nodes []struct { + Body githubv4.String + } + } `graphql:"comments(first:100)"` + } `graphql:"discussion(number: $discussionNumber)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + } + vars := map[string]interface{}{ + "owner": githubv4.String("owner"), + "repo": githubv4.String("repo"), + "discussionNumber": githubv4.Int(1), + } + mockResponse := githubv4mock.DataResponse(map[string]any{ + "repository": map[string]any{ + "discussion": map[string]any{ + "comments": map[string]any{ + "nodes": []map[string]any{ + {"body": "This is the first comment"}, + {"body": "This is the second comment"}, + }, + }, + }, + }, + }) + matcher := githubv4mock.NewQueryMatcher(q, vars, mockResponse) + httpClient := githubv4mock.NewMockedHTTPClient(matcher) + gqlClient := githubv4.NewClient(httpClient) + _, handler := GetDiscussionComments(stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper) + + request := createMCPRequest(map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "discussionNumber": int32(1), + }) + + result, err := handler(context.Background(), request) + require.NoError(t, err) + + textContent := getTextResult(t, result) + + var returnedComments []*github.IssueComment + err = json.Unmarshal([]byte(textContent.Text), &returnedComments) + require.NoError(t, err) + assert.Len(t, returnedComments, 2) + expectedBodies := []string{"This is the first comment", "This is the second comment"} + for i, comment := range returnedComments { + assert.Equal(t, expectedBodies[i], *comment.Body) + } +} + +func Test_ListDiscussionCategories(t *testing.T) { + var q struct { + Repository struct { + DiscussionCategories struct { + Nodes []struct { + ID githubv4.ID + Name githubv4.String + } + } `graphql:"discussionCategories(first: $first, last: $last, after: $after, before: $before)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + } + vars := map[string]interface{}{ + "owner": githubv4.String("owner"), + "repo": githubv4.String("repo"), + "first": githubv4.Int(0), // Default to 100 categories + "last": githubv4.Int(0), // Not used, but required by schema + "after": githubv4.String(""), // Not used, but required by schema + "before": githubv4.String(""), // Not used, but required by schema + } + mockResp := githubv4mock.DataResponse(map[string]any{ + "repository": map[string]any{ + "discussionCategories": map[string]any{ + "nodes": []map[string]any{ + {"id": "123", "name": "CategoryOne"}, + {"id": "456", "name": "CategoryTwo"}, + }, + }, + }, + }) + matcher := githubv4mock.NewQueryMatcher(q, vars, mockResp) + httpClient := githubv4mock.NewMockedHTTPClient(matcher) + gqlClient := githubv4.NewClient(httpClient) + + tool, handler := ListDiscussionCategories(stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper) + assert.Equal(t, "list_discussion_categories", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo"}) + + request := createMCPRequest(map[string]interface{}{"owner": "owner", "repo": "repo"}) + result, err := handler(context.Background(), request) + require.NoError(t, err) + + text := getTextResult(t, result).Text + var categories []map[string]string + require.NoError(t, json.Unmarshal([]byte(text), &categories)) + assert.Len(t, categories, 2) + assert.Equal(t, "123", categories[0]["id"]) + assert.Equal(t, "CategoryOne", categories[0]["name"]) + assert.Equal(t, "456", categories[1]["id"]) + assert.Equal(t, "CategoryTwo", categories[1]["name"]) +} diff --git a/pkg/github/tools.go b/pkg/github/tools.go index ab052817..bf5beee1 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -104,6 +104,14 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn, toolsets.NewServerTool(ManageRepositoryNotificationSubscription(getClient, t)), ) + discussions := toolsets.NewToolset("discussions", "GitHub Discussions related tools"). + AddReadTools( + toolsets.NewServerTool(ListDiscussions(getGQLClient, t)), + toolsets.NewServerTool(GetDiscussion(getGQLClient, t)), + toolsets.NewServerTool(GetDiscussionComments(getGQLClient, t)), + toolsets.NewServerTool(ListDiscussionCategories(getGQLClient, t)), + ) + // Keep experiments alive so the system doesn't error out when it's always enabled experiments := toolsets.NewToolset("experiments", "Experimental features that are not considered stable yet") @@ -116,6 +124,8 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn, tsg.AddToolset(secretProtection) tsg.AddToolset(notifications) tsg.AddToolset(experiments) + tsg.AddToolset(discussions) + // Enable the requested features if err := tsg.EnableToolsets(passedToolsets); err != nil { diff --git a/script/get-discussions b/script/get-discussions new file mode 100755 index 00000000..3e68abf2 --- /dev/null +++ b/script/get-discussions @@ -0,0 +1,5 @@ +#!/bin/bash + +# echo '{"jsonrpc":"2.0","id":3,"params":{"name":"list_discussions","arguments": {"owner": "github", "repo": "securitylab", "first": 10, "since": "2025-04-01T00:00:00Z"}},"method":"tools/call"}' | go run cmd/github-mcp-server/main.go stdio | jq . +echo '{"jsonrpc":"2.0","id":3,"params":{"name":"list_discussions","arguments": {"owner": "github", "repo": "securitylab", "first": 10, "since": "2025-04-01T00:00:00Z", "sort": "CREATED_AT", "direction": "DESC"}},"method":"tools/call"}' | go run cmd/github-mcp-server/main.go stdio | jq . +