diff --git a/README.md b/README.md index b40974e20..b2b8cfafc 100644 --- a/README.md +++ b/README.md @@ -458,7 +458,7 @@ The following sets of tools are available (all are on by default): - **list_discussion_categories** - List discussion categories - `owner`: Repository owner (string, required) - - `repo`: Repository name (string, required) + - `repo`: Repository name. If not provided, discussion categories will be queried at the organisation level. (string, optional) - **list_discussions** - List discussions - `after`: Cursor for pagination. Use the endCursor from the previous page's PageInfo for GraphQL APIs. (string, optional) diff --git a/pkg/github/discussions.go b/pkg/github/discussions.go index 91487f7aa..fc7d94b06 100644 --- a/pkg/github/discussions.go +++ b/pkg/github/discussions.go @@ -443,7 +443,7 @@ func GetDiscussionComments(getGQLClient GetGQLClientFn, t translations.Translati func ListDiscussionCategories(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("list_discussion_categories", - mcp.WithDescription(t("TOOL_LIST_DISCUSSION_CATEGORIES_DESCRIPTION", "List discussion categories with their id and name, for a repository")), + mcp.WithDescription(t("TOOL_LIST_DISCUSSION_CATEGORIES_DESCRIPTION", "List discussion categories with their id and name, for a repository or organisation.")), mcp.WithToolAnnotation(mcp.ToolAnnotation{ Title: t("TOOL_LIST_DISCUSSION_CATEGORIES_USER_TITLE", "List discussion categories"), ReadOnlyHint: ToBoolPtr(true), @@ -453,19 +453,23 @@ func ListDiscussionCategories(getGQLClient GetGQLClientFn, t translations.Transl mcp.Description("Repository owner"), ), mcp.WithString("repo", - mcp.Required(), - mcp.Description("Repository name"), + mcp.Description("Repository name. If not provided, discussion categories will be queried at the organisation level."), ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - // Decode params - var params struct { - Owner string - Repo string + owner, err := RequiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - if err := mapstructure.Decode(request.Params.Arguments, ¶ms); err != nil { + repo, err := OptionalParam[string](request, "repo") + if err != nil { return mcp.NewToolResultError(err.Error()), nil } + // when not provided, default to the .github repository + // this will query discussion categories at the organisation level + if repo == "" { + repo = ".github" + } client, err := getGQLClient(ctx) if err != nil { @@ -490,8 +494,8 @@ func ListDiscussionCategories(getGQLClient GetGQLClientFn, t translations.Transl } `graphql:"repository(owner: $owner, name: $repo)"` } vars := map[string]interface{}{ - "owner": githubv4.String(params.Owner), - "repo": githubv4.String(params.Repo), + "owner": githubv4.String(owner), + "repo": githubv4.String(repo), "first": githubv4.Int(25), } if err := client.Query(ctx, &q, vars); err != nil { diff --git a/pkg/github/discussions_test.go b/pkg/github/discussions_test.go index 9458dfce0..b9bc7f611 100644 --- a/pkg/github/discussions_test.go +++ b/pkg/github/discussions_test.go @@ -484,7 +484,7 @@ func Test_GetDiscussion(t *testing.T) { assert.ElementsMatch(t, toolDef.InputSchema.Required, []string{"owner", "repo", "discussionNumber"}) // Use exact string query that matches implementation output - qGetDiscussion := "query($discussionNumber:Int!$owner:String!$repo:String!){repository(owner: $owner, name: $repo){discussion(number: $discussionNumber){number,title,body,createdAt,url,category{name}}}}" + qGetDiscussion := "query($discussionNumber:Int!$owner:String!$repo:String!){repository(owner: $owner, name: $repo){discussion(number: $discussionNumber){number,title,body,createdAt,url,category{name}}}}" vars := map[string]interface{}{ "owner": "owner", @@ -638,17 +638,33 @@ func Test_GetDiscussionComments(t *testing.T) { } func Test_ListDiscussionCategories(t *testing.T) { + mockClient := githubv4.NewClient(nil) + toolDef, _ := ListDiscussionCategories(stubGetGQLClientFn(mockClient), translations.NullTranslationHelper) + assert.Equal(t, "list_discussion_categories", toolDef.Name) + assert.NotEmpty(t, toolDef.Description) + assert.Contains(t, toolDef.Description, "or organisation") + assert.Contains(t, toolDef.InputSchema.Properties, "owner") + assert.Contains(t, toolDef.InputSchema.Properties, "repo") + assert.ElementsMatch(t, toolDef.InputSchema.Required, []string{"owner"}) + // Use exact string query that matches implementation output qListCategories := "query($first:Int!$owner:String!$repo:String!){repository(owner: $owner, name: $repo){discussionCategories(first: $first){nodes{id,name},pageInfo{hasNextPage,hasPreviousPage,startCursor,endCursor},totalCount}}}" - // Variables matching what GraphQL receives after JSON marshaling/unmarshaling - vars := map[string]interface{}{ + // Variables for repository-level categories + varsRepo := map[string]interface{}{ "owner": "owner", "repo": "repo", "first": float64(25), } - mockResp := githubv4mock.DataResponse(map[string]any{ + // Variables for organization-level categories (using .github repo) + varsOrg := map[string]interface{}{ + "owner": "owner", + "repo": ".github", + "first": float64(25), + } + + mockRespRepo := githubv4mock.DataResponse(map[string]any{ "repository": map[string]any{ "discussionCategories": map[string]any{ "nodes": []map[string]any{ @@ -665,37 +681,98 @@ func Test_ListDiscussionCategories(t *testing.T) { }, }, }) - matcher := githubv4mock.NewQueryMatcher(qListCategories, vars, mockResp) - httpClient := githubv4mock.NewMockedHTTPClient(matcher) - gqlClient := githubv4.NewClient(httpClient) - tool, handler := ListDiscussionCategories(stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper) - assert.Equal(t, "list_discussion_categories", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.Properties, "owner") - assert.Contains(t, tool.InputSchema.Properties, "repo") - assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo"}) + mockRespOrg := githubv4mock.DataResponse(map[string]any{ + "repository": map[string]any{ + "discussionCategories": map[string]any{ + "nodes": []map[string]any{ + {"id": "789", "name": "Announcements"}, + {"id": "101", "name": "General"}, + {"id": "112", "name": "Ideas"}, + }, + "pageInfo": map[string]any{ + "hasNextPage": false, + "hasPreviousPage": false, + "startCursor": "", + "endCursor": "", + }, + "totalCount": 3, + }, + }, + }) - request := createMCPRequest(map[string]interface{}{"owner": "owner", "repo": "repo"}) - result, err := handler(context.Background(), request) - require.NoError(t, err) + tests := []struct { + name string + reqParams map[string]interface{} + vars map[string]interface{} + mockResponse githubv4mock.GQLResponse + expectError bool + expectedCount int + expectedCategories []map[string]string + }{ + { + name: "list repository-level discussion categories", + reqParams: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + }, + vars: varsRepo, + mockResponse: mockRespRepo, + expectError: false, + expectedCount: 2, + expectedCategories: []map[string]string{ + {"id": "123", "name": "CategoryOne"}, + {"id": "456", "name": "CategoryTwo"}, + }, + }, + { + name: "list org-level discussion categories (no repo provided)", + reqParams: map[string]interface{}{ + "owner": "owner", + // repo is not provided, it will default to ".github" + }, + vars: varsOrg, + mockResponse: mockRespOrg, + expectError: false, + expectedCount: 3, + expectedCategories: []map[string]string{ + {"id": "789", "name": "Announcements"}, + {"id": "101", "name": "General"}, + {"id": "112", "name": "Ideas"}, + }, + }, + } - text := getTextResult(t, result).Text + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + matcher := githubv4mock.NewQueryMatcher(qListCategories, tc.vars, tc.mockResponse) + httpClient := githubv4mock.NewMockedHTTPClient(matcher) + gqlClient := githubv4.NewClient(httpClient) - var response struct { - Categories []map[string]string `json:"categories"` - PageInfo struct { - HasNextPage bool `json:"hasNextPage"` - HasPreviousPage bool `json:"hasPreviousPage"` - StartCursor string `json:"startCursor"` - EndCursor string `json:"endCursor"` - } `json:"pageInfo"` - TotalCount int `json:"totalCount"` + _, handler := ListDiscussionCategories(stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper) + + req := createMCPRequest(tc.reqParams) + res, err := handler(context.Background(), req) + text := getTextResult(t, res).Text + + if tc.expectError { + require.True(t, res.IsError) + return + } + require.NoError(t, err) + + var response struct { + Categories []map[string]string `json:"categories"` + PageInfo struct { + HasNextPage bool `json:"hasNextPage"` + HasPreviousPage bool `json:"hasPreviousPage"` + StartCursor string `json:"startCursor"` + EndCursor string `json:"endCursor"` + } `json:"pageInfo"` + TotalCount int `json:"totalCount"` + } + require.NoError(t, json.Unmarshal([]byte(text), &response)) + assert.Equal(t, tc.expectedCategories, response.Categories) + }) } - require.NoError(t, json.Unmarshal([]byte(text), &response)) - assert.Len(t, response.Categories, 2) - assert.Equal(t, "123", response.Categories[0]["id"]) - assert.Equal(t, "CategoryOne", response.Categories[0]["name"]) - assert.Equal(t, "456", response.Categories[1]["id"]) - assert.Equal(t, "CategoryTwo", response.Categories[1]["name"]) }