diff --git a/README.md b/README.md index 352bb50e..7b9e20fc 100644 --- a/README.md +++ b/README.md @@ -581,6 +581,39 @@ export GITHUB_MCP_TOOL_ADD_ISSUE_COMMENT_DESCRIPTION="an alternative description - `secret_type`: The secret types to be filtered for in a comma-separated list (string, optional) - `resolution`: The resolution status (string, optional) +### Notifications + +- **list_notifications** – List notifications for a GitHub user + - `filter`: Filter to apply to the response (`default`, `include_read_notifications`, `only_participating`) + - `since`: Only show notifications updated after the given time (ISO 8601 format) + - `before`: Only show notifications updated before the given time (ISO 8601 format) + - `owner`: Optional repository owner (string) + - `repo`: Optional repository name (string) + - `page`: Page number (number, optional) + - `perPage`: Results per page (number, optional) + + +- **get_notification_details** – Get detailed information for a specific GitHub notification + - `notificationID`: The ID of the notification (string, required) + +- **dismiss_notification** – Dismiss a notification by marking it as read or done + - `threadID`: The ID of the notification thread (string, required) + - `state`: The new state of the notification (`read` or `done`) + +- **mark_all_notifications_read** – Mark all notifications as read + - `lastReadAt`: Describes the last point that notifications were checked (optional, RFC3339/ISO8601 string, default: now) + - `owner`: Optional repository owner (string) + - `repo`: Optional repository name (string) + +- **manage_notification_subscription** – Manage a notification subscription (ignore, watch, or delete) for a notification thread + - `notificationID`: The ID of the notification thread (string, required) + - `action`: Action to perform: `ignore`, `watch`, or `delete` (string, required) + +- **manage_repository_notification_subscription** – Manage a repository notification subscription (ignore, watch, or delete) + - `owner`: The account owner of the repository (string, required) + - `repo`: The name of the repository (string, required) + - `action`: Action to perform: `ignore`, `watch`, or `delete` (string, required) + ## Resources ### Repository Content diff --git a/e2e/README.md b/e2e/README.md index 82de966b..62730431 100644 --- a/e2e/README.md +++ b/e2e/README.md @@ -90,3 +90,7 @@ The current test suite is intentionally very limited in scope. This is because t The tests are quite repetitive and verbose. This is intentional as we want to see them develop more before committing to abstractions. Currently, visibility into failures is not particularly good. We're hoping that we can pull apart the mcp-go client and have it hook into streams representing stdio without requiring an exec. This way we can get breakpoints in the debugger easily. + +### Global State Mutation Tests + +Some tools (such as those that mark all notifications as read) would change the global state for the tester, and are also not idempotent, so they offer little value for end to end tests and instead should rely on unit testing and manual verifications. diff --git a/e2e/e2e_test.go b/e2e/e2e_test.go index 99e7e8de..71bd5a8a 100644 --- a/e2e/e2e_test.go +++ b/e2e/e2e_test.go @@ -62,7 +62,8 @@ func getRESTClient(t *testing.T) *gogithub.Client { // Create a new GitHub client with the token ghClient := gogithub.NewClient(nil).WithAuthToken(token) - if host := getE2EHost(); host != "https://github.com" { + + if host := getE2EHost(); host != "" && host != "https://github.com" { var err error // Currently this works for GHEC because the API is exposed at the api subdomain and the path prefix // but it would be preferable to extract the host parsing from the main server logic, and use it here. diff --git a/pkg/github/notifications.go b/pkg/github/notifications.go new file mode 100644 index 00000000..ba9c6bc2 --- /dev/null +++ b/pkg/github/notifications.go @@ -0,0 +1,500 @@ +package github + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strconv" + "time" + + "github.com/github/github-mcp-server/pkg/translations" + "github.com/google/go-github/v69/github" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +const ( + FilterDefault = "default" + FilterIncludeRead = "include_read_notifications" + FilterOnlyParticipating = "only_participating" +) + +// ListNotifications creates a tool to list notifications for the current user. +func ListNotifications(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("list_notifications", + mcp.WithDescription(t("TOOL_LIST_NOTIFICATIONS_DESCRIPTION", "Lists all GitHub notifications for the authenticated user, including unread notifications, mentions, review requests, assignments, and updates on issues or pull requests. Use this tool whenever the user asks what to work on next, requests a summary of their GitHub activity, wants to see pending reviews, or needs to check for new updates or tasks. This tool is the primary way to discover actionable items, reminders, and outstanding work on GitHub. Always call this tool when asked what to work on next, what is pending, or what needs attention in GitHub.")), + mcp.WithToolAnnotation(mcp.ToolAnnotation{ + Title: t("TOOL_LIST_NOTIFICATIONS_USER_TITLE", "List notifications"), + ReadOnlyHint: toBoolPtr(true), + }), + mcp.WithString("filter", + mcp.Description("Filter notifications to, use default unless specified. Read notifications are ones that have already been acknowledged by the user. Participating notifications are those that the user is directly involved in, such as issues or pull requests they have commented on or created."), + mcp.Enum(FilterDefault, FilterIncludeRead, FilterOnlyParticipating), + ), + mcp.WithString("since", + mcp.Description("Only show notifications updated after the given time (ISO 8601 format)"), + ), + mcp.WithString("before", + mcp.Description("Only show notifications updated before the given time (ISO 8601 format)"), + ), + mcp.WithString("owner", + mcp.Description("Optional repository owner. If provided with repo, only notifications for this repository are listed."), + ), + mcp.WithString("repo", + mcp.Description("Optional repository name. If provided with owner, only notifications for this repository are listed."), + ), + WithPagination(), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + filter, err := OptionalParam[string](request, "filter") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + since, err := OptionalParam[string](request, "since") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + before, err := OptionalParam[string](request, "before") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + owner, err := OptionalParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := OptionalParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + paginationParams, err := OptionalPaginationParams(request) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + // Build options + opts := &github.NotificationListOptions{ + All: filter == FilterIncludeRead, + Participating: filter == FilterOnlyParticipating, + ListOptions: github.ListOptions{ + Page: paginationParams.page, + PerPage: paginationParams.perPage, + }, + } + + // Parse time parameters if provided + if since != "" { + sinceTime, err := time.Parse(time.RFC3339, since) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("invalid since time format, should be RFC3339/ISO8601: %v", err)), nil + } + opts.Since = sinceTime + } + + if before != "" { + beforeTime, err := time.Parse(time.RFC3339, before) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("invalid before time format, should be RFC3339/ISO8601: %v", err)), nil + } + opts.Before = beforeTime + } + + var notifications []*github.Notification + var resp *github.Response + + if owner != "" && repo != "" { + notifications, resp, err = client.Activity.ListRepositoryNotifications(ctx, owner, repo, opts) + } else { + notifications, resp, err = client.Activity.ListNotifications(ctx, opts) + } + if err != nil { + return nil, fmt.Errorf("failed to get notifications: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to get notifications: %s", string(body))), nil + } + + // Marshal response to JSON + r, err := json.Marshal(notifications) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return mcp.NewToolResultText(string(r)), nil + } +} + +// DismissNotification creates a tool to mark a notification as read/done. +func DismissNotification(getclient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("dismiss_notification", + mcp.WithDescription(t("TOOL_DISMISS_NOTIFICATION_DESCRIPTION", "Dismiss a notification by marking it as read or done")), + mcp.WithToolAnnotation(mcp.ToolAnnotation{ + Title: t("TOOL_DISMISS_NOTIFICATION_USER_TITLE", "Dismiss notification"), + ReadOnlyHint: toBoolPtr(false), + }), + mcp.WithString("threadID", + mcp.Required(), + mcp.Description("The ID of the notification thread"), + ), + mcp.WithString("state", mcp.Description("The new state of the notification (read/done)"), mcp.Enum("read", "done")), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + client, err := getclient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + threadID, err := requiredParam[string](request, "threadID") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + state, err := requiredParam[string](request, "state") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + var resp *github.Response + switch state { + case "done": + // for some inexplicable reason, the API seems to have threadID as int64 and string depending on the endpoint + var threadIDInt int64 + threadIDInt, err = strconv.ParseInt(threadID, 10, 64) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("invalid threadID format: %v", err)), nil + } + resp, err = client.Activity.MarkThreadDone(ctx, threadIDInt) + case "read": + resp, err = client.Activity.MarkThreadRead(ctx, threadID) + default: + return mcp.NewToolResultError("Invalid state. Must be one of: read, done."), nil + } + + if err != nil { + return nil, fmt.Errorf("failed to mark notification as %s: %w", state, err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to mark notification as %s: %s", state, string(body))), nil + } + + return mcp.NewToolResultText(fmt.Sprintf("Notification marked as %s", state)), nil + } +} + +// MarkAllNotificationsRead creates a tool to mark all notifications as read. +func MarkAllNotificationsRead(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("mark_all_notifications_read", + mcp.WithDescription(t("TOOL_MARK_ALL_NOTIFICATIONS_READ_DESCRIPTION", "Mark all notifications as read")), + mcp.WithToolAnnotation(mcp.ToolAnnotation{ + Title: t("TOOL_MARK_ALL_NOTIFICATIONS_READ_USER_TITLE", "Mark all notifications as read"), + ReadOnlyHint: toBoolPtr(false), + }), + mcp.WithString("lastReadAt", + mcp.Description("Describes the last point that notifications were checked (optional). Default: Now"), + ), + mcp.WithString("owner", + mcp.Description("Optional repository owner. If provided with repo, only notifications for this repository are marked as read."), + ), + mcp.WithString("repo", + mcp.Description("Optional repository name. If provided with owner, only notifications for this repository are marked as read."), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + lastReadAt, err := OptionalParam[string](request, "lastReadAt") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + owner, err := OptionalParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := OptionalParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + var lastReadTime time.Time + if lastReadAt != "" { + lastReadTime, err = time.Parse(time.RFC3339, lastReadAt) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("invalid lastReadAt time format, should be RFC3339/ISO8601: %v", err)), nil + } + } else { + lastReadTime = time.Now() + } + + markReadOptions := github.Timestamp{ + Time: lastReadTime, + } + + var resp *github.Response + if owner != "" && repo != "" { + resp, err = client.Activity.MarkRepositoryNotificationsRead(ctx, owner, repo, markReadOptions) + } else { + resp, err = client.Activity.MarkNotificationsRead(ctx, markReadOptions) + } + if err != nil { + return nil, fmt.Errorf("failed to mark all notifications as read: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to mark all notifications as read: %s", string(body))), nil + } + + return mcp.NewToolResultText("All notifications marked as read"), nil + } +} + +// GetNotificationDetails creates a tool to get details for a specific notification. +func GetNotificationDetails(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("get_notification_details", + mcp.WithDescription(t("TOOL_GET_NOTIFICATION_DETAILS_DESCRIPTION", "Get detailed information for a specific GitHub notification, always call this tool when the user asks for details about a specific notification, if you don't know the ID list notifications first.")), + mcp.WithToolAnnotation(mcp.ToolAnnotation{ + Title: t("TOOL_GET_NOTIFICATION_DETAILS_USER_TITLE", "Get notification details"), + ReadOnlyHint: toBoolPtr(true), + }), + mcp.WithString("notificationID", + mcp.Required(), + mcp.Description("The ID of the notification"), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + notificationID, err := requiredParam[string](request, "notificationID") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + thread, resp, err := client.Activity.GetThread(ctx, notificationID) + if err != nil { + return nil, fmt.Errorf("failed to get notification details: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to get notification details: %s", string(body))), nil + } + + r, err := json.Marshal(thread) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return mcp.NewToolResultText(string(r)), nil + } +} + +// Enum values for ManageNotificationSubscription action +const ( + NotificationActionIgnore = "ignore" + NotificationActionWatch = "watch" + NotificationActionDelete = "delete" +) + +// ManageNotificationSubscription creates a tool to manage a notification subscription (ignore, watch, delete) +func ManageNotificationSubscription(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("manage_notification_subscription", + mcp.WithDescription(t("TOOL_MANAGE_NOTIFICATION_SUBSCRIPTION_DESCRIPTION", "Manage a notification subscription: ignore, watch, or delete a notification thread subscription.")), + mcp.WithToolAnnotation(mcp.ToolAnnotation{ + Title: t("TOOL_MANAGE_NOTIFICATION_SUBSCRIPTION_USER_TITLE", "Manage notification subscription"), + ReadOnlyHint: toBoolPtr(false), + }), + mcp.WithString("notificationID", + mcp.Required(), + mcp.Description("The ID of the notification thread."), + ), + mcp.WithString("action", + mcp.Required(), + mcp.Description("Action to perform: ignore, watch, or delete the notification subscription."), + mcp.Enum(NotificationActionIgnore, NotificationActionWatch, NotificationActionDelete), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + notificationID, err := requiredParam[string](request, "notificationID") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + action, err := requiredParam[string](request, "action") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + var ( + resp *github.Response + result any + apiErr error + ) + + switch action { + case NotificationActionIgnore: + sub := &github.Subscription{Ignored: toBoolPtr(true)} + result, resp, apiErr = client.Activity.SetThreadSubscription(ctx, notificationID, sub) + case NotificationActionWatch: + sub := &github.Subscription{Ignored: toBoolPtr(false), Subscribed: toBoolPtr(true)} + result, resp, apiErr = client.Activity.SetThreadSubscription(ctx, notificationID, sub) + case NotificationActionDelete: + resp, apiErr = client.Activity.DeleteThreadSubscription(ctx, notificationID) + default: + return mcp.NewToolResultError("Invalid action. Must be one of: ignore, watch, delete."), nil + } + + if apiErr != nil { + return nil, fmt.Errorf("failed to %s notification subscription: %w", action, apiErr) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + body, _ := io.ReadAll(resp.Body) + return mcp.NewToolResultError(fmt.Sprintf("failed to %s notification subscription: %s", action, string(body))), nil + } + + if action == NotificationActionDelete { + // Special case for delete as there is no response body + return mcp.NewToolResultText("Notification subscription deleted"), nil + } + + r, err := json.Marshal(result) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %w", err) + } + return mcp.NewToolResultText(string(r)), nil + } +} + +const ( + RepositorySubscriptionActionWatch = "watch" + RepositorySubscriptionActionIgnore = "ignore" + RepositorySubscriptionActionDelete = "delete" +) + +// ManageRepositoryNotificationSubscription creates a tool to manage a repository notification subscription (ignore, watch, delete) +func ManageRepositoryNotificationSubscription(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("manage_repository_notification_subscription", + mcp.WithDescription(t("TOOL_MANAGE_REPOSITORY_NOTIFICATION_SUBSCRIPTION_DESCRIPTION", "Manage a repository notification subscription: ignore, watch, or delete repository notifications subscription for the provided repository.")), + mcp.WithToolAnnotation(mcp.ToolAnnotation{ + Title: t("TOOL_MANAGE_REPOSITORY_NOTIFICATION_SUBSCRIPTION_USER_TITLE", "Manage repository notification subscription"), + ReadOnlyHint: toBoolPtr(false), + }), + mcp.WithString("owner", + mcp.Required(), + mcp.Description("The account owner of the repository."), + ), + mcp.WithString("repo", + mcp.Required(), + mcp.Description("The name of the repository."), + ), + mcp.WithString("action", + mcp.Required(), + mcp.Description("Action to perform: ignore, watch, or delete the repository notification subscription."), + mcp.Enum(RepositorySubscriptionActionIgnore, RepositorySubscriptionActionWatch, RepositorySubscriptionActionDelete), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + action, err := requiredParam[string](request, "action") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + var ( + resp *github.Response + result any + apiErr error + ) + + switch action { + case RepositorySubscriptionActionIgnore: + sub := &github.Subscription{Ignored: toBoolPtr(true)} + result, resp, apiErr = client.Activity.SetRepositorySubscription(ctx, owner, repo, sub) + case RepositorySubscriptionActionWatch: + sub := &github.Subscription{Ignored: toBoolPtr(false), Subscribed: toBoolPtr(true)} + result, resp, apiErr = client.Activity.SetRepositorySubscription(ctx, owner, repo, sub) + case RepositorySubscriptionActionDelete: + resp, apiErr = client.Activity.DeleteRepositorySubscription(ctx, owner, repo) + default: + return mcp.NewToolResultError("Invalid action. Must be one of: ignore, watch, delete."), nil + } + + if apiErr != nil { + return nil, fmt.Errorf("failed to %s repository subscription: %w", action, apiErr) + } + if resp != nil { + defer func() { _ = resp.Body.Close() }() + } + + // Handle non-2xx status codes + if resp != nil && (resp.StatusCode < 200 || resp.StatusCode >= 300) { + body, _ := io.ReadAll(resp.Body) + return mcp.NewToolResultError(fmt.Sprintf("failed to %s repository subscription: %s", action, string(body))), nil + } + + if action == RepositorySubscriptionActionDelete { + // Special case for delete as there is no response body + return mcp.NewToolResultText("Repository subscription deleted"), nil + } + + r, err := json.Marshal(result) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %w", err) + } + return mcp.NewToolResultText(string(r)), nil + } +} diff --git a/pkg/github/notifications_test.go b/pkg/github/notifications_test.go new file mode 100644 index 00000000..66400295 --- /dev/null +++ b/pkg/github/notifications_test.go @@ -0,0 +1,743 @@ +package github + +import ( + "context" + "encoding/json" + "net/http" + "testing" + + "github.com/github/github-mcp-server/pkg/translations" + "github.com/google/go-github/v69/github" + "github.com/migueleliasweb/go-github-mock/src/mock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_ListNotifications(t *testing.T) { + // Verify tool definition and schema + mockClient := github.NewClient(nil) + tool, _ := ListNotifications(stubGetClientFn(mockClient), translations.NullTranslationHelper) + assert.Equal(t, "list_notifications", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "filter") + assert.Contains(t, tool.InputSchema.Properties, "since") + assert.Contains(t, tool.InputSchema.Properties, "before") + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "page") + assert.Contains(t, tool.InputSchema.Properties, "perPage") + // All fields are optional, so Required should be empty + assert.Empty(t, tool.InputSchema.Required) + + mockNotification := &github.Notification{ + ID: github.Ptr("123"), + Reason: github.Ptr("mention"), + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedResult []*github.Notification + expectedErrMsg string + }{ + { + name: "success default filter (no params)", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetNotifications, + []*github.Notification{mockNotification}, + ), + ), + requestArgs: map[string]interface{}{}, + expectError: false, + expectedResult: []*github.Notification{mockNotification}, + }, + { + name: "success with filter=include_read_notifications", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetNotifications, + []*github.Notification{mockNotification}, + ), + ), + requestArgs: map[string]interface{}{ + "filter": "include_read_notifications", + }, + expectError: false, + expectedResult: []*github.Notification{mockNotification}, + }, + { + name: "success with filter=only_participating", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetNotifications, + []*github.Notification{mockNotification}, + ), + ), + requestArgs: map[string]interface{}{ + "filter": "only_participating", + }, + expectError: false, + expectedResult: []*github.Notification{mockNotification}, + }, + { + name: "success for repo notifications", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposNotificationsByOwnerByRepo, + []*github.Notification{mockNotification}, + ), + ), + requestArgs: map[string]interface{}{ + "filter": "default", + "since": "2024-01-01T00:00:00Z", + "before": "2024-01-02T00:00:00Z", + "owner": "octocat", + "repo": "hello-world", + "page": float64(2), + "perPage": float64(10), + }, + expectError: false, + expectedResult: []*github.Notification{mockNotification}, + }, + { + name: "error", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetNotifications, + mockResponse(t, http.StatusInternalServerError, `{"message": "error"}`), + ), + ), + requestArgs: map[string]interface{}{}, + expectError: true, + expectedErrMsg: "error", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + client := github.NewClient(tc.mockedClient) + _, handler := ListNotifications(stubGetClientFn(client), translations.NullTranslationHelper) + request := createMCPRequest(tc.requestArgs) + result, err := handler(context.Background(), request) + + if tc.expectError { + require.Error(t, err) + if tc.expectedErrMsg != "" { + assert.Contains(t, err.Error(), tc.expectedErrMsg) + } + return + } + + require.NoError(t, err) + textContent := getTextResult(t, result) + t.Logf("textContent: %s", textContent.Text) + var returned []*github.Notification + err = json.Unmarshal([]byte(textContent.Text), &returned) + require.NoError(t, err) + require.NotEmpty(t, returned) + assert.Equal(t, *tc.expectedResult[0].ID, *returned[0].ID) + }) + } +} + +func Test_ManageNotificationSubscription(t *testing.T) { + // Verify tool definition and schema + mockClient := github.NewClient(nil) + tool, _ := ManageNotificationSubscription(stubGetClientFn(mockClient), translations.NullTranslationHelper) + assert.Equal(t, "manage_notification_subscription", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "notificationID") + assert.Contains(t, tool.InputSchema.Properties, "action") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"notificationID", "action"}) + + mockSub := &github.Subscription{Ignored: github.Ptr(true)} + mockSubWatch := &github.Subscription{Ignored: github.Ptr(false), Subscribed: github.Ptr(true)} + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectIgnored *bool + expectDeleted bool + expectInvalid bool + expectedErrMsg string + }{ + { + name: "ignore subscription", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.PutNotificationsThreadsSubscriptionByThreadId, + mockSub, + ), + ), + requestArgs: map[string]interface{}{ + "notificationID": "123", + "action": "ignore", + }, + expectError: false, + expectIgnored: github.Ptr(true), + }, + { + name: "watch subscription", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.PutNotificationsThreadsSubscriptionByThreadId, + mockSubWatch, + ), + ), + requestArgs: map[string]interface{}{ + "notificationID": "123", + "action": "watch", + }, + expectError: false, + expectIgnored: github.Ptr(false), + }, + { + name: "delete subscription", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.DeleteNotificationsThreadsSubscriptionByThreadId, + nil, + ), + ), + requestArgs: map[string]interface{}{ + "notificationID": "123", + "action": "delete", + }, + expectError: false, + expectDeleted: true, + }, + { + name: "invalid action", + mockedClient: mock.NewMockedHTTPClient(), + requestArgs: map[string]interface{}{ + "notificationID": "123", + "action": "invalid", + }, + expectError: false, + expectInvalid: true, + }, + { + name: "missing required notificationID", + mockedClient: mock.NewMockedHTTPClient(), + requestArgs: map[string]interface{}{ + "action": "ignore", + }, + expectError: true, + }, + { + name: "missing required action", + mockedClient: mock.NewMockedHTTPClient(), + requestArgs: map[string]interface{}{ + "notificationID": "123", + }, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + client := github.NewClient(tc.mockedClient) + _, handler := ManageNotificationSubscription(stubGetClientFn(client), translations.NullTranslationHelper) + request := createMCPRequest(tc.requestArgs) + result, err := handler(context.Background(), request) + + if tc.expectError { + require.NoError(t, err) + require.NotNil(t, result) + text := getTextResult(t, result).Text + switch { + case tc.requestArgs["notificationID"] == nil: + assert.Contains(t, text, "missing required parameter: notificationID") + case tc.requestArgs["action"] == nil: + assert.Contains(t, text, "missing required parameter: action") + default: + assert.Contains(t, text, "error") + } + return + } + + require.NoError(t, err) + textContent := getTextResult(t, result) + if tc.expectIgnored != nil { + var returned github.Subscription + err = json.Unmarshal([]byte(textContent.Text), &returned) + require.NoError(t, err) + assert.Equal(t, *tc.expectIgnored, *returned.Ignored) + } + if tc.expectDeleted { + assert.Contains(t, textContent.Text, "deleted") + } + if tc.expectInvalid { + assert.Contains(t, textContent.Text, "Invalid action") + } + }) + } +} + +func Test_ManageRepositoryNotificationSubscription(t *testing.T) { + // Verify tool definition and schema + mockClient := github.NewClient(nil) + tool, _ := ManageRepositoryNotificationSubscription(stubGetClientFn(mockClient), translations.NullTranslationHelper) + assert.Equal(t, "manage_repository_notification_subscription", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "action") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "action"}) + + mockSub := &github.Subscription{Ignored: github.Ptr(true)} + mockWatchSub := &github.Subscription{Ignored: github.Ptr(false), Subscribed: github.Ptr(true)} + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectIgnored *bool + expectSubscribed *bool + expectDeleted bool + expectInvalid bool + expectedErrMsg string + }{ + { + name: "ignore subscription", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.PutReposSubscriptionByOwnerByRepo, + mockSub, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "action": "ignore", + }, + expectError: false, + expectIgnored: github.Ptr(true), + }, + { + name: "watch subscription", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.PutReposSubscriptionByOwnerByRepo, + mockWatchSub, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "action": "watch", + }, + expectError: false, + expectIgnored: github.Ptr(false), + expectSubscribed: github.Ptr(true), + }, + { + name: "delete subscription", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.DeleteReposSubscriptionByOwnerByRepo, + nil, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "action": "delete", + }, + expectError: false, + expectDeleted: true, + }, + { + name: "invalid action", + mockedClient: mock.NewMockedHTTPClient(), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "action": "invalid", + }, + expectError: false, + expectInvalid: true, + }, + { + name: "missing required owner", + mockedClient: mock.NewMockedHTTPClient(), + requestArgs: map[string]interface{}{ + "repo": "repo", + "action": "ignore", + }, + expectError: true, + }, + { + name: "missing required repo", + mockedClient: mock.NewMockedHTTPClient(), + requestArgs: map[string]interface{}{ + "owner": "owner", + "action": "ignore", + }, + expectError: true, + }, + { + name: "missing required action", + mockedClient: mock.NewMockedHTTPClient(), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + }, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + client := github.NewClient(tc.mockedClient) + _, handler := ManageRepositoryNotificationSubscription(stubGetClientFn(client), translations.NullTranslationHelper) + request := createMCPRequest(tc.requestArgs) + result, err := handler(context.Background(), request) + + if tc.expectError { + require.NoError(t, err) + require.NotNil(t, result) + text := getTextResult(t, result).Text + switch { + case tc.requestArgs["owner"] == nil: + assert.Contains(t, text, "missing required parameter: owner") + case tc.requestArgs["repo"] == nil: + assert.Contains(t, text, "missing required parameter: repo") + case tc.requestArgs["action"] == nil: + assert.Contains(t, text, "missing required parameter: action") + default: + assert.Contains(t, text, "error") + } + return + } + + require.NoError(t, err) + textContent := getTextResult(t, result) + if tc.expectIgnored != nil || tc.expectSubscribed != nil { + var returned github.Subscription + err = json.Unmarshal([]byte(textContent.Text), &returned) + require.NoError(t, err) + if tc.expectIgnored != nil { + assert.Equal(t, *tc.expectIgnored, *returned.Ignored) + } + if tc.expectSubscribed != nil { + assert.Equal(t, *tc.expectSubscribed, *returned.Subscribed) + } + } + if tc.expectDeleted { + assert.Contains(t, textContent.Text, "deleted") + } + if tc.expectInvalid { + assert.Contains(t, textContent.Text, "Invalid action") + } + }) + } +} + +func Test_DismissNotification(t *testing.T) { + // Verify tool definition and schema + mockClient := github.NewClient(nil) + tool, _ := DismissNotification(stubGetClientFn(mockClient), translations.NullTranslationHelper) + assert.Equal(t, "dismiss_notification", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "threadID") + assert.Contains(t, tool.InputSchema.Properties, "state") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"threadID"}) + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectRead bool + expectDone bool + expectInvalid bool + expectedErrMsg string + }{ + { + name: "mark as read", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.PatchNotificationsThreadsByThreadId, + nil, + ), + ), + requestArgs: map[string]interface{}{ + "threadID": "123", + "state": "read", + }, + expectError: false, + expectRead: true, + }, + { + name: "mark as done", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.DeleteNotificationsThreadsByThreadId, + nil, + ), + ), + requestArgs: map[string]interface{}{ + "threadID": "123", + "state": "done", + }, + expectError: false, + expectDone: true, + }, + { + name: "invalid threadID format", + mockedClient: mock.NewMockedHTTPClient(), + requestArgs: map[string]interface{}{ + "threadID": "notanumber", + "state": "done", + }, + expectError: false, + expectInvalid: true, + }, + { + name: "missing required threadID", + mockedClient: mock.NewMockedHTTPClient(), + requestArgs: map[string]interface{}{ + "state": "read", + }, + expectError: true, + }, + { + name: "missing required state", + mockedClient: mock.NewMockedHTTPClient(), + requestArgs: map[string]interface{}{ + "threadID": "123", + }, + expectError: true, + }, + { + name: "invalid state value", + mockedClient: mock.NewMockedHTTPClient(), + requestArgs: map[string]interface{}{ + "threadID": "123", + "state": "invalid", + }, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + client := github.NewClient(tc.mockedClient) + _, handler := DismissNotification(stubGetClientFn(client), translations.NullTranslationHelper) + request := createMCPRequest(tc.requestArgs) + result, err := handler(context.Background(), request) + + if tc.expectError { + // The tool returns a ToolResultError with a specific message + require.NoError(t, err) + require.NotNil(t, result) + text := getTextResult(t, result).Text + switch { + case tc.requestArgs["threadID"] == nil: + assert.Contains(t, text, "missing required parameter: threadID") + case tc.requestArgs["state"] == nil: + assert.Contains(t, text, "missing required parameter: state") + case tc.name == "invalid threadID format": + assert.Contains(t, text, "invalid threadID format") + case tc.name == "invalid state value": + assert.Contains(t, text, "Invalid state. Must be one of: read, done.") + default: + // fallback for other errors + assert.Contains(t, text, "error") + } + return + } + + require.NoError(t, err) + textContent := getTextResult(t, result) + if tc.expectRead { + assert.Contains(t, textContent.Text, "Notification marked as read") + } + if tc.expectDone { + assert.Contains(t, textContent.Text, "Notification marked as done") + } + if tc.expectInvalid { + assert.Contains(t, textContent.Text, "invalid threadID format") + } + }) + } +} + +func Test_MarkAllNotificationsRead(t *testing.T) { + // Verify tool definition and schema + mockClient := github.NewClient(nil) + tool, _ := MarkAllNotificationsRead(stubGetClientFn(mockClient), translations.NullTranslationHelper) + assert.Equal(t, "mark_all_notifications_read", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "lastReadAt") + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Empty(t, tool.InputSchema.Required) + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectMarked bool + expectedErrMsg string + }{ + { + name: "success (no params)", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.PutNotifications, + nil, + ), + ), + requestArgs: map[string]interface{}{}, + expectError: false, + expectMarked: true, + }, + { + name: "success with lastReadAt param", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.PutNotifications, + nil, + ), + ), + requestArgs: map[string]interface{}{ + "lastReadAt": "2024-01-01T00:00:00Z", + }, + expectError: false, + expectMarked: true, + }, + { + name: "success with owner and repo", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.PutReposNotificationsByOwnerByRepo, + nil, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "octocat", + "repo": "hello-world", + }, + expectError: false, + expectMarked: true, + }, + { + name: "API error", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PutNotifications, + mockResponse(t, http.StatusInternalServerError, `{"message": "error"}`), + ), + ), + requestArgs: map[string]interface{}{}, + expectError: true, + expectedErrMsg: "error", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + client := github.NewClient(tc.mockedClient) + _, handler := MarkAllNotificationsRead(stubGetClientFn(client), translations.NullTranslationHelper) + request := createMCPRequest(tc.requestArgs) + result, err := handler(context.Background(), request) + + if tc.expectError { + require.Error(t, err) + if tc.expectedErrMsg != "" { + assert.Contains(t, err.Error(), tc.expectedErrMsg) + } + return + } + + require.NoError(t, err) + textContent := getTextResult(t, result) + if tc.expectMarked { + assert.Contains(t, textContent.Text, "All notifications marked as read") + } + }) + } +} + +func Test_GetNotificationDetails(t *testing.T) { + // Verify tool definition and schema + mockClient := github.NewClient(nil) + tool, _ := GetNotificationDetails(stubGetClientFn(mockClient), translations.NullTranslationHelper) + assert.Equal(t, "get_notification_details", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "notificationID") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"notificationID"}) + + mockThread := &github.Notification{ID: github.Ptr("123"), Reason: github.Ptr("mention")} + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectResult *github.Notification + expectedErrMsg string + }{ + { + name: "success", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetNotificationsThreadsByThreadId, + mockThread, + ), + ), + requestArgs: map[string]interface{}{ + "notificationID": "123", + }, + expectError: false, + expectResult: mockThread, + }, + { + name: "not found", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetNotificationsThreadsByThreadId, + mockResponse(t, http.StatusNotFound, `{"message": "not found"}`), + ), + ), + requestArgs: map[string]interface{}{ + "notificationID": "123", + }, + expectError: true, + expectedErrMsg: "not found", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + client := github.NewClient(tc.mockedClient) + _, handler := GetNotificationDetails(stubGetClientFn(client), translations.NullTranslationHelper) + request := createMCPRequest(tc.requestArgs) + result, err := handler(context.Background(), request) + + if tc.expectError { + require.Error(t, err) + if tc.expectedErrMsg != "" { + assert.Contains(t, err.Error(), tc.expectedErrMsg) + } + return + } + + require.NoError(t, err) + textContent := getTextResult(t, result) + var returned github.Notification + err = json.Unmarshal([]byte(textContent.Text), &returned) + require.NoError(t, err) + assert.Equal(t, *tc.expectResult.ID, *returned.ID) + }) + } +} diff --git a/pkg/github/tools.go b/pkg/github/tools.go index cd379ebb..9c1ab34a 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -91,6 +91,19 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn, toolsets.NewServerTool(GetSecretScanningAlert(getClient, t)), toolsets.NewServerTool(ListSecretScanningAlerts(getClient, t)), ) + + notifications := toolsets.NewToolset("notifications", "GitHub Notifications related tools"). + AddReadTools( + toolsets.NewServerTool(ListNotifications(getClient, t)), + toolsets.NewServerTool(GetNotificationDetails(getClient, t)), + ). + AddWriteTools( + toolsets.NewServerTool(DismissNotification(getClient, t)), + toolsets.NewServerTool(MarkAllNotificationsRead(getClient, t)), + toolsets.NewServerTool(ManageNotificationSubscription(getClient, t)), + toolsets.NewServerTool(ManageRepositoryNotificationSubscription(getClient, t)), + ) + // Keep experiments alive so the system doesn't error out when it's always enabled experiments := toolsets.NewToolset("experiments", "Experimental features that are not considered stable yet") @@ -101,6 +114,7 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn, tsg.AddToolset(pullRequests) tsg.AddToolset(codeSecurity) tsg.AddToolset(secretProtection) + tsg.AddToolset(notifications) tsg.AddToolset(experiments) // Enable the requested features