From 34775353e7a163b7e5d8018f0378e458ed9bd5cd Mon Sep 17 00:00:00 2001 From: Avinash Sridhar Date: Fri, 11 Apr 2025 11:48:42 -0400 Subject: [PATCH 1/2] feat: add GitHub notifications tools for managing user notifications --- pkg/github/notifications.go | 300 ++++++++++++++++++++++++++++++++++++ pkg/github/server.go | 42 +++++ pkg/github/tools.go | 14 ++ 3 files changed, 356 insertions(+) create mode 100644 pkg/github/notifications.go diff --git a/pkg/github/notifications.go b/pkg/github/notifications.go new file mode 100644 index 00000000..ac93081c --- /dev/null +++ b/pkg/github/notifications.go @@ -0,0 +1,300 @@ +package github + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strconv" + "time" + + "github.com/github/github-mcp-server/pkg/translations" + "github.com/google/go-github/v69/github" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// getNotifications creates a tool to list notifications for the current user. +func GetNotifications(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("get_notifications", + mcp.WithDescription(t("TOOL_GET_NOTIFICATIONS_DESCRIPTION", "Get notifications for the authenticated GitHub user")), + mcp.WithBoolean("all", + mcp.Description("If true, show notifications marked as read. Default: false"), + ), + mcp.WithBoolean("participating", + mcp.Description("If true, only shows notifications in which the user is directly participating or mentioned. Default: false"), + ), + mcp.WithString("since", + mcp.Description("Only show notifications updated after the given time (ISO 8601 format)"), + ), + mcp.WithString("before", + mcp.Description("Only show notifications updated before the given time (ISO 8601 format)"), + ), + mcp.WithNumber("per_page", + mcp.Description("Results per page (max 100). Default: 30"), + ), + mcp.WithNumber("page", + mcp.Description("Page number of the results to fetch. Default: 1"), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + // Extract optional parameters with defaults + all, err := OptionalParamWithDefault[bool](request, "all", false) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + participating, err := OptionalParamWithDefault[bool](request, "participating", false) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + since, err := OptionalParam[string](request, "since") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + before, err := OptionalParam[string](request, "before") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + perPage, err := OptionalIntParamWithDefault(request, "per_page", 30) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + page, err := OptionalIntParamWithDefault(request, "page", 1) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + // Build options + opts := &github.NotificationListOptions{ + All: all, + Participating: participating, + ListOptions: github.ListOptions{ + Page: page, + PerPage: perPage, + }, + } + + // Parse time parameters if provided + if since != "" { + sinceTime, err := time.Parse(time.RFC3339, since) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("invalid since time format, should be RFC3339/ISO8601: %v", err)), nil + } + opts.Since = sinceTime + } + + if before != "" { + beforeTime, err := time.Parse(time.RFC3339, before) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("invalid before time format, should be RFC3339/ISO8601: %v", err)), nil + } + opts.Before = beforeTime + } + + // Call GitHub API + notifications, resp, err := client.Activity.ListNotifications(ctx, opts) + if err != nil { + return nil, fmt.Errorf("failed to get notifications: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to get notifications: %s", string(body))), nil + } + + // Marshal response to JSON + r, err := json.Marshal(notifications) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return mcp.NewToolResultText(string(r)), nil + } +} + +// markNotificationRead creates a tool to mark a notification as read. +func MarkNotificationRead(getclient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("mark_notification_read", + mcp.WithDescription(t("TOOL_MARK_NOTIFICATION_READ_DESCRIPTION", "Mark a notification as read")), + mcp.WithString("threadID", + mcp.Required(), + mcp.Description("The ID of the notification thread"), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + client, err := getclient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + threadID, err := requiredParam[string](request, "threadID") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + resp, err := client.Activity.MarkThreadRead(ctx, threadID) + if err != nil { + return nil, fmt.Errorf("failed to mark notification as read: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to mark notification as read: %s", string(body))), nil + } + + return mcp.NewToolResultText("Notification marked as read"), nil + } +} + +// MarkAllNotificationsRead creates a tool to mark all notifications as read. +func MarkAllNotificationsRead(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("mark_all_notifications_read", + mcp.WithDescription(t("TOOL_MARK_ALL_NOTIFICATIONS_READ_DESCRIPTION", "Mark all notifications as read")), + mcp.WithString("lastReadAt", + mcp.Description("Describes the last point that notifications were checked (optional). Default: Now"), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + lastReadAt, err := OptionalParam(request, "lastReadAt") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + var markReadOptions github.Timestamp + if lastReadAt != "" { + lastReadTime, err := time.Parse(time.RFC3339, lastReadAt) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("invalid lastReadAt time format, should be RFC3339/ISO8601: %v", err)), nil + } + markReadOptions = github.Timestamp{ + Time: lastReadTime, + } + } + + resp, err := client.Activity.MarkNotificationsRead(ctx, markReadOptions) + if err != nil { + return nil, fmt.Errorf("failed to mark all notifications as read: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to mark all notifications as read: %s", string(body))), nil + } + + return mcp.NewToolResultText("All notifications marked as read"), nil + } +} + +// GetNotificationThread creates a tool to get a specific notification thread. +func GetNotificationThread(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("get_notification_thread", + mcp.WithDescription(t("TOOL_GET_NOTIFICATION_THREAD_DESCRIPTION", "Get a specific notification thread")), + mcp.WithString("threadID", + mcp.Required(), + mcp.Description("The ID of the notification thread"), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + threadID, err := requiredParam[string](request, "threadID") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + thread, resp, err := client.Activity.GetThread(ctx, threadID) + if err != nil { + return nil, fmt.Errorf("failed to get notification thread: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to get notification thread: %s", string(body))), nil + } + + r, err := json.Marshal(thread) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return mcp.NewToolResultText(string(r)), nil + } +} + +// markNotificationDone creates a tool to mark a notification as done. +func MarkNotificationDone(getclient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("mark_notification_done", + mcp.WithDescription(t("TOOL_MARK_NOTIFICATION_DONE_DESCRIPTION", "Mark a notification as done")), + mcp.WithString("threadID", + mcp.Required(), + mcp.Description("The ID of the notification thread"), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + client, err := getclient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + threadIDStr, err := requiredParam[string](request, "threadID") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + threadID, err := strconv.ParseInt(threadIDStr, 10, 64) + if err != nil { + return mcp.NewToolResultError("Invalid threadID: must be a numeric value"), nil + } + + resp, err := client.Activity.MarkThreadDone(ctx, threadID) + if err != nil { + return nil, fmt.Errorf("failed to mark notification as done: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to mark notification as done: %s", string(body))), nil + } + + return mcp.NewToolResultText("Notification marked as done"), nil + } +} diff --git a/pkg/github/server.go b/pkg/github/server.go index e4c24171..79c146a4 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -26,6 +26,7 @@ func NewServer(version string, opts ...server.ServerOption) *server.MCPServer { version, opts..., ) + return s } @@ -143,6 +144,47 @@ func OptionalIntParamWithDefault(r mcp.CallToolRequest, p string, d int) (int, e return v, nil } +// OptionalBoolParamWithDefault is a helper function that can be used to fetch a requested parameter from the request +// similar to optionalParam, but it also takes a default value. +func OptionalBoolParamWithDefault(r mcp.CallToolRequest, p string, d bool) (bool, error) { + v, err := OptionalParam[bool](r, p) + if err != nil { + return false, err + } + if !v { + return d, nil + } + return v, nil +} + +// OptionalStringParam is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request, if not, it returns its zero-value +// 2. If it is present, it checks if the parameter is of the expected type and returns it +func OptionalStringParam(r mcp.CallToolRequest, p string) (string, error) { + v, err := OptionalParam[string](r, p) + if err != nil { + return "", err + } + if v == "" { + return "", nil + } + return v, nil +} + +// OptionalStringParamWithDefault is a helper function that can be used to fetch a requested parameter from the request +// similar to optionalParam, but it also takes a default value. +func OptionalStringParamWithDefault(r mcp.CallToolRequest, p string, d string) (string, error) { + v, err := OptionalParam[string](r, p) + if err != nil { + return "", err + } + if v == "" { + return d, nil + } + return v, nil +} + // OptionalStringArrayParam is a helper function that can be used to fetch a requested parameter from the request. // It does the following checks: // 1. Checks if the parameter is present in the request, if not, it returns its zero-value diff --git a/pkg/github/tools.go b/pkg/github/tools.go index cd379ebb..72b85b0c 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -91,6 +91,19 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn, toolsets.NewServerTool(GetSecretScanningAlert(getClient, t)), toolsets.NewServerTool(ListSecretScanningAlerts(getClient, t)), ) + + notifications := toolsets.NewToolset("notifications", "GitHub Notifications related tools"). + AddReadTools( + + toolsets.NewServerTool(MarkNotificationRead(getClient, t)), + toolsets.NewServerTool(MarkAllNotificationsRead(getClient, t)), + toolsets.NewServerTool(MarkNotificationDone(getClient, t)), + ). + AddWriteTools( + toolsets.NewServerTool(GetNotifications(getClient, t)), + toolsets.NewServerTool(GetNotificationThread(getClient, t)), + ) + // Keep experiments alive so the system doesn't error out when it's always enabled experiments := toolsets.NewToolset("experiments", "Experimental features that are not considered stable yet") @@ -101,6 +114,7 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn, tsg.AddToolset(pullRequests) tsg.AddToolset(codeSecurity) tsg.AddToolset(secretProtection) + tsg.AddToolset(notifications) tsg.AddToolset(experiments) // Enable the requested features From cd76a6caf060a558bd3412d49709b19317e102a0 Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Tue, 20 May 2025 00:26:04 +0200 Subject: [PATCH 2/2] Add additional tools and tests for notifications --- README.md | 33 ++ e2e/README.md | 4 + e2e/e2e_test.go | 3 +- pkg/github/notifications.go | 344 +++++++++++--- pkg/github/notifications_test.go | 743 +++++++++++++++++++++++++++++++ pkg/github/server.go | 42 -- pkg/github/tools.go | 12 +- 7 files changed, 1060 insertions(+), 121 deletions(-) create mode 100644 pkg/github/notifications_test.go diff --git a/README.md b/README.md index 352bb50e..7b9e20fc 100644 --- a/README.md +++ b/README.md @@ -581,6 +581,39 @@ export GITHUB_MCP_TOOL_ADD_ISSUE_COMMENT_DESCRIPTION="an alternative description - `secret_type`: The secret types to be filtered for in a comma-separated list (string, optional) - `resolution`: The resolution status (string, optional) +### Notifications + +- **list_notifications** – List notifications for a GitHub user + - `filter`: Filter to apply to the response (`default`, `include_read_notifications`, `only_participating`) + - `since`: Only show notifications updated after the given time (ISO 8601 format) + - `before`: Only show notifications updated before the given time (ISO 8601 format) + - `owner`: Optional repository owner (string) + - `repo`: Optional repository name (string) + - `page`: Page number (number, optional) + - `perPage`: Results per page (number, optional) + + +- **get_notification_details** – Get detailed information for a specific GitHub notification + - `notificationID`: The ID of the notification (string, required) + +- **dismiss_notification** – Dismiss a notification by marking it as read or done + - `threadID`: The ID of the notification thread (string, required) + - `state`: The new state of the notification (`read` or `done`) + +- **mark_all_notifications_read** – Mark all notifications as read + - `lastReadAt`: Describes the last point that notifications were checked (optional, RFC3339/ISO8601 string, default: now) + - `owner`: Optional repository owner (string) + - `repo`: Optional repository name (string) + +- **manage_notification_subscription** – Manage a notification subscription (ignore, watch, or delete) for a notification thread + - `notificationID`: The ID of the notification thread (string, required) + - `action`: Action to perform: `ignore`, `watch`, or `delete` (string, required) + +- **manage_repository_notification_subscription** – Manage a repository notification subscription (ignore, watch, or delete) + - `owner`: The account owner of the repository (string, required) + - `repo`: The name of the repository (string, required) + - `action`: Action to perform: `ignore`, `watch`, or `delete` (string, required) + ## Resources ### Repository Content diff --git a/e2e/README.md b/e2e/README.md index 82de966b..62730431 100644 --- a/e2e/README.md +++ b/e2e/README.md @@ -90,3 +90,7 @@ The current test suite is intentionally very limited in scope. This is because t The tests are quite repetitive and verbose. This is intentional as we want to see them develop more before committing to abstractions. Currently, visibility into failures is not particularly good. We're hoping that we can pull apart the mcp-go client and have it hook into streams representing stdio without requiring an exec. This way we can get breakpoints in the debugger easily. + +### Global State Mutation Tests + +Some tools (such as those that mark all notifications as read) would change the global state for the tester, and are also not idempotent, so they offer little value for end to end tests and instead should rely on unit testing and manual verifications. diff --git a/e2e/e2e_test.go b/e2e/e2e_test.go index 99e7e8de..71bd5a8a 100644 --- a/e2e/e2e_test.go +++ b/e2e/e2e_test.go @@ -62,7 +62,8 @@ func getRESTClient(t *testing.T) *gogithub.Client { // Create a new GitHub client with the token ghClient := gogithub.NewClient(nil).WithAuthToken(token) - if host := getE2EHost(); host != "https://github.com" { + + if host := getE2EHost(); host != "" && host != "https://github.com" { var err error // Currently this works for GHEC because the API is exposed at the api subdomain and the path prefix // but it would be preferable to extract the host parsing from the main server logic, and use it here. diff --git a/pkg/github/notifications.go b/pkg/github/notifications.go index ac93081c..ba9c6bc2 100644 --- a/pkg/github/notifications.go +++ b/pkg/github/notifications.go @@ -15,15 +15,23 @@ import ( "github.com/mark3labs/mcp-go/server" ) -// getNotifications creates a tool to list notifications for the current user. -func GetNotifications(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { - return mcp.NewTool("get_notifications", - mcp.WithDescription(t("TOOL_GET_NOTIFICATIONS_DESCRIPTION", "Get notifications for the authenticated GitHub user")), - mcp.WithBoolean("all", - mcp.Description("If true, show notifications marked as read. Default: false"), - ), - mcp.WithBoolean("participating", - mcp.Description("If true, only shows notifications in which the user is directly participating or mentioned. Default: false"), +const ( + FilterDefault = "default" + FilterIncludeRead = "include_read_notifications" + FilterOnlyParticipating = "only_participating" +) + +// ListNotifications creates a tool to list notifications for the current user. +func ListNotifications(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("list_notifications", + mcp.WithDescription(t("TOOL_LIST_NOTIFICATIONS_DESCRIPTION", "Lists all GitHub notifications for the authenticated user, including unread notifications, mentions, review requests, assignments, and updates on issues or pull requests. Use this tool whenever the user asks what to work on next, requests a summary of their GitHub activity, wants to see pending reviews, or needs to check for new updates or tasks. This tool is the primary way to discover actionable items, reminders, and outstanding work on GitHub. Always call this tool when asked what to work on next, what is pending, or what needs attention in GitHub.")), + mcp.WithToolAnnotation(mcp.ToolAnnotation{ + Title: t("TOOL_LIST_NOTIFICATIONS_USER_TITLE", "List notifications"), + ReadOnlyHint: toBoolPtr(true), + }), + mcp.WithString("filter", + mcp.Description("Filter notifications to, use default unless specified. Read notifications are ones that have already been acknowledged by the user. Participating notifications are those that the user is directly involved in, such as issues or pull requests they have commented on or created."), + mcp.Enum(FilterDefault, FilterIncludeRead, FilterOnlyParticipating), ), mcp.WithString("since", mcp.Description("Only show notifications updated after the given time (ISO 8601 format)"), @@ -31,12 +39,13 @@ func GetNotifications(getClient GetClientFn, t translations.TranslationHelperFun mcp.WithString("before", mcp.Description("Only show notifications updated before the given time (ISO 8601 format)"), ), - mcp.WithNumber("per_page", - mcp.Description("Results per page (max 100). Default: 30"), + mcp.WithString("owner", + mcp.Description("Optional repository owner. If provided with repo, only notifications for this repository are listed."), ), - mcp.WithNumber("page", - mcp.Description("Page number of the results to fetch. Default: 1"), + mcp.WithString("repo", + mcp.Description("Optional repository name. If provided with owner, only notifications for this repository are listed."), ), + WithPagination(), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { client, err := getClient(ctx) @@ -44,44 +53,42 @@ func GetNotifications(getClient GetClientFn, t translations.TranslationHelperFun return nil, fmt.Errorf("failed to get GitHub client: %w", err) } - // Extract optional parameters with defaults - all, err := OptionalParamWithDefault[bool](request, "all", false) + filter, err := OptionalParam[string](request, "filter") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - participating, err := OptionalParamWithDefault[bool](request, "participating", false) + since, err := OptionalParam[string](request, "since") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - since, err := OptionalParam[string](request, "since") + before, err := OptionalParam[string](request, "before") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - before, err := OptionalParam[string](request, "before") + owner, err := OptionalParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - - perPage, err := OptionalIntParamWithDefault(request, "per_page", 30) + repo, err := OptionalParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - page, err := OptionalIntParamWithDefault(request, "page", 1) + paginationParams, err := OptionalPaginationParams(request) if err != nil { return mcp.NewToolResultError(err.Error()), nil } // Build options opts := &github.NotificationListOptions{ - All: all, - Participating: participating, + All: filter == FilterIncludeRead, + Participating: filter == FilterOnlyParticipating, ListOptions: github.ListOptions{ - Page: page, - PerPage: perPage, + Page: paginationParams.page, + PerPage: paginationParams.perPage, }, } @@ -102,8 +109,14 @@ func GetNotifications(getClient GetClientFn, t translations.TranslationHelperFun opts.Before = beforeTime } - // Call GitHub API - notifications, resp, err := client.Activity.ListNotifications(ctx, opts) + var notifications []*github.Notification + var resp *github.Response + + if owner != "" && repo != "" { + notifications, resp, err = client.Activity.ListRepositoryNotifications(ctx, owner, repo, opts) + } else { + notifications, resp, err = client.Activity.ListNotifications(ctx, opts) + } if err != nil { return nil, fmt.Errorf("failed to get notifications: %w", err) } @@ -127,14 +140,19 @@ func GetNotifications(getClient GetClientFn, t translations.TranslationHelperFun } } -// markNotificationRead creates a tool to mark a notification as read. -func MarkNotificationRead(getclient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { - return mcp.NewTool("mark_notification_read", - mcp.WithDescription(t("TOOL_MARK_NOTIFICATION_READ_DESCRIPTION", "Mark a notification as read")), +// DismissNotification creates a tool to mark a notification as read/done. +func DismissNotification(getclient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("dismiss_notification", + mcp.WithDescription(t("TOOL_DISMISS_NOTIFICATION_DESCRIPTION", "Dismiss a notification by marking it as read or done")), + mcp.WithToolAnnotation(mcp.ToolAnnotation{ + Title: t("TOOL_DISMISS_NOTIFICATION_USER_TITLE", "Dismiss notification"), + ReadOnlyHint: toBoolPtr(false), + }), mcp.WithString("threadID", mcp.Required(), mcp.Description("The ID of the notification thread"), ), + mcp.WithString("state", mcp.Description("The new state of the notification (read/done)"), mcp.Enum("read", "done")), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { client, err := getclient(ctx) @@ -147,9 +165,29 @@ func MarkNotificationRead(getclient GetClientFn, t translations.TranslationHelpe return mcp.NewToolResultError(err.Error()), nil } - resp, err := client.Activity.MarkThreadRead(ctx, threadID) + state, err := requiredParam[string](request, "state") if err != nil { - return nil, fmt.Errorf("failed to mark notification as read: %w", err) + return mcp.NewToolResultError(err.Error()), nil + } + + var resp *github.Response + switch state { + case "done": + // for some inexplicable reason, the API seems to have threadID as int64 and string depending on the endpoint + var threadIDInt int64 + threadIDInt, err = strconv.ParseInt(threadID, 10, 64) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("invalid threadID format: %v", err)), nil + } + resp, err = client.Activity.MarkThreadDone(ctx, threadIDInt) + case "read": + resp, err = client.Activity.MarkThreadRead(ctx, threadID) + default: + return mcp.NewToolResultError("Invalid state. Must be one of: read, done."), nil + } + + if err != nil { + return nil, fmt.Errorf("failed to mark notification as %s: %w", state, err) } defer func() { _ = resp.Body.Close() }() @@ -158,10 +196,10 @@ func MarkNotificationRead(getclient GetClientFn, t translations.TranslationHelpe if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) } - return mcp.NewToolResultError(fmt.Sprintf("failed to mark notification as read: %s", string(body))), nil + return mcp.NewToolResultError(fmt.Sprintf("failed to mark notification as %s: %s", state, string(body))), nil } - return mcp.NewToolResultText("Notification marked as read"), nil + return mcp.NewToolResultText(fmt.Sprintf("Notification marked as %s", state)), nil } } @@ -169,9 +207,19 @@ func MarkNotificationRead(getclient GetClientFn, t translations.TranslationHelpe func MarkAllNotificationsRead(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("mark_all_notifications_read", mcp.WithDescription(t("TOOL_MARK_ALL_NOTIFICATIONS_READ_DESCRIPTION", "Mark all notifications as read")), + mcp.WithToolAnnotation(mcp.ToolAnnotation{ + Title: t("TOOL_MARK_ALL_NOTIFICATIONS_READ_USER_TITLE", "Mark all notifications as read"), + ReadOnlyHint: toBoolPtr(false), + }), mcp.WithString("lastReadAt", mcp.Description("Describes the last point that notifications were checked (optional). Default: Now"), ), + mcp.WithString("owner", + mcp.Description("Optional repository owner. If provided with repo, only notifications for this repository are marked as read."), + ), + mcp.WithString("repo", + mcp.Description("Optional repository name. If provided with owner, only notifications for this repository are marked as read."), + ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { client, err := getClient(ctx) @@ -179,23 +227,40 @@ func MarkAllNotificationsRead(getClient GetClientFn, t translations.TranslationH return nil, fmt.Errorf("failed to get GitHub client: %w", err) } - lastReadAt, err := OptionalParam(request, "lastReadAt") + lastReadAt, err := OptionalParam[string](request, "lastReadAt") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + owner, err := OptionalParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := OptionalParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - var markReadOptions github.Timestamp + var lastReadTime time.Time if lastReadAt != "" { - lastReadTime, err := time.Parse(time.RFC3339, lastReadAt) + lastReadTime, err = time.Parse(time.RFC3339, lastReadAt) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("invalid lastReadAt time format, should be RFC3339/ISO8601: %v", err)), nil } - markReadOptions = github.Timestamp{ - Time: lastReadTime, - } + } else { + lastReadTime = time.Now() } - resp, err := client.Activity.MarkNotificationsRead(ctx, markReadOptions) + markReadOptions := github.Timestamp{ + Time: lastReadTime, + } + + var resp *github.Response + if owner != "" && repo != "" { + resp, err = client.Activity.MarkRepositoryNotificationsRead(ctx, owner, repo, markReadOptions) + } else { + resp, err = client.Activity.MarkNotificationsRead(ctx, markReadOptions) + } if err != nil { return nil, fmt.Errorf("failed to mark all notifications as read: %w", err) } @@ -213,13 +278,17 @@ func MarkAllNotificationsRead(getClient GetClientFn, t translations.TranslationH } } -// GetNotificationThread creates a tool to get a specific notification thread. -func GetNotificationThread(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { - return mcp.NewTool("get_notification_thread", - mcp.WithDescription(t("TOOL_GET_NOTIFICATION_THREAD_DESCRIPTION", "Get a specific notification thread")), - mcp.WithString("threadID", +// GetNotificationDetails creates a tool to get details for a specific notification. +func GetNotificationDetails(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("get_notification_details", + mcp.WithDescription(t("TOOL_GET_NOTIFICATION_DETAILS_DESCRIPTION", "Get detailed information for a specific GitHub notification, always call this tool when the user asks for details about a specific notification, if you don't know the ID list notifications first.")), + mcp.WithToolAnnotation(mcp.ToolAnnotation{ + Title: t("TOOL_GET_NOTIFICATION_DETAILS_USER_TITLE", "Get notification details"), + ReadOnlyHint: toBoolPtr(true), + }), + mcp.WithString("notificationID", mcp.Required(), - mcp.Description("The ID of the notification thread"), + mcp.Description("The ID of the notification"), ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { @@ -228,14 +297,14 @@ func GetNotificationThread(getClient GetClientFn, t translations.TranslationHelp return nil, fmt.Errorf("failed to get GitHub client: %w", err) } - threadID, err := requiredParam[string](request, "threadID") + notificationID, err := requiredParam[string](request, "notificationID") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - thread, resp, err := client.Activity.GetThread(ctx, threadID) + thread, resp, err := client.Activity.GetThread(ctx, notificationID) if err != nil { - return nil, fmt.Errorf("failed to get notification thread: %w", err) + return nil, fmt.Errorf("failed to get notification details: %w", err) } defer func() { _ = resp.Body.Close() }() @@ -244,7 +313,7 @@ func GetNotificationThread(getClient GetClientFn, t translations.TranslationHelp if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) } - return mcp.NewToolResultError(fmt.Sprintf("failed to get notification thread: %s", string(body))), nil + return mcp.NewToolResultError(fmt.Sprintf("failed to get notification details: %s", string(body))), nil } r, err := json.Marshal(thread) @@ -256,45 +325,176 @@ func GetNotificationThread(getClient GetClientFn, t translations.TranslationHelp } } -// markNotificationDone creates a tool to mark a notification as done. -func MarkNotificationDone(getclient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { - return mcp.NewTool("mark_notification_done", - mcp.WithDescription(t("TOOL_MARK_NOTIFICATION_DONE_DESCRIPTION", "Mark a notification as done")), - mcp.WithString("threadID", +// Enum values for ManageNotificationSubscription action +const ( + NotificationActionIgnore = "ignore" + NotificationActionWatch = "watch" + NotificationActionDelete = "delete" +) + +// ManageNotificationSubscription creates a tool to manage a notification subscription (ignore, watch, delete) +func ManageNotificationSubscription(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("manage_notification_subscription", + mcp.WithDescription(t("TOOL_MANAGE_NOTIFICATION_SUBSCRIPTION_DESCRIPTION", "Manage a notification subscription: ignore, watch, or delete a notification thread subscription.")), + mcp.WithToolAnnotation(mcp.ToolAnnotation{ + Title: t("TOOL_MANAGE_NOTIFICATION_SUBSCRIPTION_USER_TITLE", "Manage notification subscription"), + ReadOnlyHint: toBoolPtr(false), + }), + mcp.WithString("notificationID", mcp.Required(), - mcp.Description("The ID of the notification thread"), + mcp.Description("The ID of the notification thread."), + ), + mcp.WithString("action", + mcp.Required(), + mcp.Description("Action to perform: ignore, watch, or delete the notification subscription."), + mcp.Enum(NotificationActionIgnore, NotificationActionWatch, NotificationActionDelete), ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - client, err := getclient(ctx) + client, err := getClient(ctx) if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } - threadIDStr, err := requiredParam[string](request, "threadID") + notificationID, err := requiredParam[string](request, "notificationID") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + action, err := requiredParam[string](request, "action") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - threadID, err := strconv.ParseInt(threadIDStr, 10, 64) + var ( + resp *github.Response + result any + apiErr error + ) + + switch action { + case NotificationActionIgnore: + sub := &github.Subscription{Ignored: toBoolPtr(true)} + result, resp, apiErr = client.Activity.SetThreadSubscription(ctx, notificationID, sub) + case NotificationActionWatch: + sub := &github.Subscription{Ignored: toBoolPtr(false), Subscribed: toBoolPtr(true)} + result, resp, apiErr = client.Activity.SetThreadSubscription(ctx, notificationID, sub) + case NotificationActionDelete: + resp, apiErr = client.Activity.DeleteThreadSubscription(ctx, notificationID) + default: + return mcp.NewToolResultError("Invalid action. Must be one of: ignore, watch, delete."), nil + } + + if apiErr != nil { + return nil, fmt.Errorf("failed to %s notification subscription: %w", action, apiErr) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + body, _ := io.ReadAll(resp.Body) + return mcp.NewToolResultError(fmt.Sprintf("failed to %s notification subscription: %s", action, string(body))), nil + } + + if action == NotificationActionDelete { + // Special case for delete as there is no response body + return mcp.NewToolResultText("Notification subscription deleted"), nil + } + + r, err := json.Marshal(result) if err != nil { - return mcp.NewToolResultError("Invalid threadID: must be a numeric value"), nil + return nil, fmt.Errorf("failed to marshal response: %w", err) } + return mcp.NewToolResultText(string(r)), nil + } +} + +const ( + RepositorySubscriptionActionWatch = "watch" + RepositorySubscriptionActionIgnore = "ignore" + RepositorySubscriptionActionDelete = "delete" +) - resp, err := client.Activity.MarkThreadDone(ctx, threadID) +// ManageRepositoryNotificationSubscription creates a tool to manage a repository notification subscription (ignore, watch, delete) +func ManageRepositoryNotificationSubscription(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("manage_repository_notification_subscription", + mcp.WithDescription(t("TOOL_MANAGE_REPOSITORY_NOTIFICATION_SUBSCRIPTION_DESCRIPTION", "Manage a repository notification subscription: ignore, watch, or delete repository notifications subscription for the provided repository.")), + mcp.WithToolAnnotation(mcp.ToolAnnotation{ + Title: t("TOOL_MANAGE_REPOSITORY_NOTIFICATION_SUBSCRIPTION_USER_TITLE", "Manage repository notification subscription"), + ReadOnlyHint: toBoolPtr(false), + }), + mcp.WithString("owner", + mcp.Required(), + mcp.Description("The account owner of the repository."), + ), + mcp.WithString("repo", + mcp.Required(), + mcp.Description("The name of the repository."), + ), + mcp.WithString("action", + mcp.Required(), + mcp.Description("Action to perform: ignore, watch, or delete the repository notification subscription."), + mcp.Enum(RepositorySubscriptionActionIgnore, RepositorySubscriptionActionWatch, RepositorySubscriptionActionDelete), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to mark notification as done: %w", err) + return nil, fmt.Errorf("failed to get GitHub client: %w", err) } - defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) - } - return mcp.NewToolResultError(fmt.Sprintf("failed to mark notification as done: %s", string(body))), nil + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + action, err := requiredParam[string](request, "action") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + var ( + resp *github.Response + result any + apiErr error + ) + + switch action { + case RepositorySubscriptionActionIgnore: + sub := &github.Subscription{Ignored: toBoolPtr(true)} + result, resp, apiErr = client.Activity.SetRepositorySubscription(ctx, owner, repo, sub) + case RepositorySubscriptionActionWatch: + sub := &github.Subscription{Ignored: toBoolPtr(false), Subscribed: toBoolPtr(true)} + result, resp, apiErr = client.Activity.SetRepositorySubscription(ctx, owner, repo, sub) + case RepositorySubscriptionActionDelete: + resp, apiErr = client.Activity.DeleteRepositorySubscription(ctx, owner, repo) + default: + return mcp.NewToolResultError("Invalid action. Must be one of: ignore, watch, delete."), nil + } + + if apiErr != nil { + return nil, fmt.Errorf("failed to %s repository subscription: %w", action, apiErr) + } + if resp != nil { + defer func() { _ = resp.Body.Close() }() + } + + // Handle non-2xx status codes + if resp != nil && (resp.StatusCode < 200 || resp.StatusCode >= 300) { + body, _ := io.ReadAll(resp.Body) + return mcp.NewToolResultError(fmt.Sprintf("failed to %s repository subscription: %s", action, string(body))), nil } - return mcp.NewToolResultText("Notification marked as done"), nil + if action == RepositorySubscriptionActionDelete { + // Special case for delete as there is no response body + return mcp.NewToolResultText("Repository subscription deleted"), nil + } + + r, err := json.Marshal(result) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %w", err) + } + return mcp.NewToolResultText(string(r)), nil } } diff --git a/pkg/github/notifications_test.go b/pkg/github/notifications_test.go new file mode 100644 index 00000000..66400295 --- /dev/null +++ b/pkg/github/notifications_test.go @@ -0,0 +1,743 @@ +package github + +import ( + "context" + "encoding/json" + "net/http" + "testing" + + "github.com/github/github-mcp-server/pkg/translations" + "github.com/google/go-github/v69/github" + "github.com/migueleliasweb/go-github-mock/src/mock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_ListNotifications(t *testing.T) { + // Verify tool definition and schema + mockClient := github.NewClient(nil) + tool, _ := ListNotifications(stubGetClientFn(mockClient), translations.NullTranslationHelper) + assert.Equal(t, "list_notifications", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "filter") + assert.Contains(t, tool.InputSchema.Properties, "since") + assert.Contains(t, tool.InputSchema.Properties, "before") + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "page") + assert.Contains(t, tool.InputSchema.Properties, "perPage") + // All fields are optional, so Required should be empty + assert.Empty(t, tool.InputSchema.Required) + + mockNotification := &github.Notification{ + ID: github.Ptr("123"), + Reason: github.Ptr("mention"), + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedResult []*github.Notification + expectedErrMsg string + }{ + { + name: "success default filter (no params)", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetNotifications, + []*github.Notification{mockNotification}, + ), + ), + requestArgs: map[string]interface{}{}, + expectError: false, + expectedResult: []*github.Notification{mockNotification}, + }, + { + name: "success with filter=include_read_notifications", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetNotifications, + []*github.Notification{mockNotification}, + ), + ), + requestArgs: map[string]interface{}{ + "filter": "include_read_notifications", + }, + expectError: false, + expectedResult: []*github.Notification{mockNotification}, + }, + { + name: "success with filter=only_participating", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetNotifications, + []*github.Notification{mockNotification}, + ), + ), + requestArgs: map[string]interface{}{ + "filter": "only_participating", + }, + expectError: false, + expectedResult: []*github.Notification{mockNotification}, + }, + { + name: "success for repo notifications", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposNotificationsByOwnerByRepo, + []*github.Notification{mockNotification}, + ), + ), + requestArgs: map[string]interface{}{ + "filter": "default", + "since": "2024-01-01T00:00:00Z", + "before": "2024-01-02T00:00:00Z", + "owner": "octocat", + "repo": "hello-world", + "page": float64(2), + "perPage": float64(10), + }, + expectError: false, + expectedResult: []*github.Notification{mockNotification}, + }, + { + name: "error", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetNotifications, + mockResponse(t, http.StatusInternalServerError, `{"message": "error"}`), + ), + ), + requestArgs: map[string]interface{}{}, + expectError: true, + expectedErrMsg: "error", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + client := github.NewClient(tc.mockedClient) + _, handler := ListNotifications(stubGetClientFn(client), translations.NullTranslationHelper) + request := createMCPRequest(tc.requestArgs) + result, err := handler(context.Background(), request) + + if tc.expectError { + require.Error(t, err) + if tc.expectedErrMsg != "" { + assert.Contains(t, err.Error(), tc.expectedErrMsg) + } + return + } + + require.NoError(t, err) + textContent := getTextResult(t, result) + t.Logf("textContent: %s", textContent.Text) + var returned []*github.Notification + err = json.Unmarshal([]byte(textContent.Text), &returned) + require.NoError(t, err) + require.NotEmpty(t, returned) + assert.Equal(t, *tc.expectedResult[0].ID, *returned[0].ID) + }) + } +} + +func Test_ManageNotificationSubscription(t *testing.T) { + // Verify tool definition and schema + mockClient := github.NewClient(nil) + tool, _ := ManageNotificationSubscription(stubGetClientFn(mockClient), translations.NullTranslationHelper) + assert.Equal(t, "manage_notification_subscription", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "notificationID") + assert.Contains(t, tool.InputSchema.Properties, "action") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"notificationID", "action"}) + + mockSub := &github.Subscription{Ignored: github.Ptr(true)} + mockSubWatch := &github.Subscription{Ignored: github.Ptr(false), Subscribed: github.Ptr(true)} + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectIgnored *bool + expectDeleted bool + expectInvalid bool + expectedErrMsg string + }{ + { + name: "ignore subscription", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.PutNotificationsThreadsSubscriptionByThreadId, + mockSub, + ), + ), + requestArgs: map[string]interface{}{ + "notificationID": "123", + "action": "ignore", + }, + expectError: false, + expectIgnored: github.Ptr(true), + }, + { + name: "watch subscription", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.PutNotificationsThreadsSubscriptionByThreadId, + mockSubWatch, + ), + ), + requestArgs: map[string]interface{}{ + "notificationID": "123", + "action": "watch", + }, + expectError: false, + expectIgnored: github.Ptr(false), + }, + { + name: "delete subscription", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.DeleteNotificationsThreadsSubscriptionByThreadId, + nil, + ), + ), + requestArgs: map[string]interface{}{ + "notificationID": "123", + "action": "delete", + }, + expectError: false, + expectDeleted: true, + }, + { + name: "invalid action", + mockedClient: mock.NewMockedHTTPClient(), + requestArgs: map[string]interface{}{ + "notificationID": "123", + "action": "invalid", + }, + expectError: false, + expectInvalid: true, + }, + { + name: "missing required notificationID", + mockedClient: mock.NewMockedHTTPClient(), + requestArgs: map[string]interface{}{ + "action": "ignore", + }, + expectError: true, + }, + { + name: "missing required action", + mockedClient: mock.NewMockedHTTPClient(), + requestArgs: map[string]interface{}{ + "notificationID": "123", + }, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + client := github.NewClient(tc.mockedClient) + _, handler := ManageNotificationSubscription(stubGetClientFn(client), translations.NullTranslationHelper) + request := createMCPRequest(tc.requestArgs) + result, err := handler(context.Background(), request) + + if tc.expectError { + require.NoError(t, err) + require.NotNil(t, result) + text := getTextResult(t, result).Text + switch { + case tc.requestArgs["notificationID"] == nil: + assert.Contains(t, text, "missing required parameter: notificationID") + case tc.requestArgs["action"] == nil: + assert.Contains(t, text, "missing required parameter: action") + default: + assert.Contains(t, text, "error") + } + return + } + + require.NoError(t, err) + textContent := getTextResult(t, result) + if tc.expectIgnored != nil { + var returned github.Subscription + err = json.Unmarshal([]byte(textContent.Text), &returned) + require.NoError(t, err) + assert.Equal(t, *tc.expectIgnored, *returned.Ignored) + } + if tc.expectDeleted { + assert.Contains(t, textContent.Text, "deleted") + } + if tc.expectInvalid { + assert.Contains(t, textContent.Text, "Invalid action") + } + }) + } +} + +func Test_ManageRepositoryNotificationSubscription(t *testing.T) { + // Verify tool definition and schema + mockClient := github.NewClient(nil) + tool, _ := ManageRepositoryNotificationSubscription(stubGetClientFn(mockClient), translations.NullTranslationHelper) + assert.Equal(t, "manage_repository_notification_subscription", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "action") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "action"}) + + mockSub := &github.Subscription{Ignored: github.Ptr(true)} + mockWatchSub := &github.Subscription{Ignored: github.Ptr(false), Subscribed: github.Ptr(true)} + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectIgnored *bool + expectSubscribed *bool + expectDeleted bool + expectInvalid bool + expectedErrMsg string + }{ + { + name: "ignore subscription", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.PutReposSubscriptionByOwnerByRepo, + mockSub, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "action": "ignore", + }, + expectError: false, + expectIgnored: github.Ptr(true), + }, + { + name: "watch subscription", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.PutReposSubscriptionByOwnerByRepo, + mockWatchSub, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "action": "watch", + }, + expectError: false, + expectIgnored: github.Ptr(false), + expectSubscribed: github.Ptr(true), + }, + { + name: "delete subscription", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.DeleteReposSubscriptionByOwnerByRepo, + nil, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "action": "delete", + }, + expectError: false, + expectDeleted: true, + }, + { + name: "invalid action", + mockedClient: mock.NewMockedHTTPClient(), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "action": "invalid", + }, + expectError: false, + expectInvalid: true, + }, + { + name: "missing required owner", + mockedClient: mock.NewMockedHTTPClient(), + requestArgs: map[string]interface{}{ + "repo": "repo", + "action": "ignore", + }, + expectError: true, + }, + { + name: "missing required repo", + mockedClient: mock.NewMockedHTTPClient(), + requestArgs: map[string]interface{}{ + "owner": "owner", + "action": "ignore", + }, + expectError: true, + }, + { + name: "missing required action", + mockedClient: mock.NewMockedHTTPClient(), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + }, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + client := github.NewClient(tc.mockedClient) + _, handler := ManageRepositoryNotificationSubscription(stubGetClientFn(client), translations.NullTranslationHelper) + request := createMCPRequest(tc.requestArgs) + result, err := handler(context.Background(), request) + + if tc.expectError { + require.NoError(t, err) + require.NotNil(t, result) + text := getTextResult(t, result).Text + switch { + case tc.requestArgs["owner"] == nil: + assert.Contains(t, text, "missing required parameter: owner") + case tc.requestArgs["repo"] == nil: + assert.Contains(t, text, "missing required parameter: repo") + case tc.requestArgs["action"] == nil: + assert.Contains(t, text, "missing required parameter: action") + default: + assert.Contains(t, text, "error") + } + return + } + + require.NoError(t, err) + textContent := getTextResult(t, result) + if tc.expectIgnored != nil || tc.expectSubscribed != nil { + var returned github.Subscription + err = json.Unmarshal([]byte(textContent.Text), &returned) + require.NoError(t, err) + if tc.expectIgnored != nil { + assert.Equal(t, *tc.expectIgnored, *returned.Ignored) + } + if tc.expectSubscribed != nil { + assert.Equal(t, *tc.expectSubscribed, *returned.Subscribed) + } + } + if tc.expectDeleted { + assert.Contains(t, textContent.Text, "deleted") + } + if tc.expectInvalid { + assert.Contains(t, textContent.Text, "Invalid action") + } + }) + } +} + +func Test_DismissNotification(t *testing.T) { + // Verify tool definition and schema + mockClient := github.NewClient(nil) + tool, _ := DismissNotification(stubGetClientFn(mockClient), translations.NullTranslationHelper) + assert.Equal(t, "dismiss_notification", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "threadID") + assert.Contains(t, tool.InputSchema.Properties, "state") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"threadID"}) + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectRead bool + expectDone bool + expectInvalid bool + expectedErrMsg string + }{ + { + name: "mark as read", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.PatchNotificationsThreadsByThreadId, + nil, + ), + ), + requestArgs: map[string]interface{}{ + "threadID": "123", + "state": "read", + }, + expectError: false, + expectRead: true, + }, + { + name: "mark as done", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.DeleteNotificationsThreadsByThreadId, + nil, + ), + ), + requestArgs: map[string]interface{}{ + "threadID": "123", + "state": "done", + }, + expectError: false, + expectDone: true, + }, + { + name: "invalid threadID format", + mockedClient: mock.NewMockedHTTPClient(), + requestArgs: map[string]interface{}{ + "threadID": "notanumber", + "state": "done", + }, + expectError: false, + expectInvalid: true, + }, + { + name: "missing required threadID", + mockedClient: mock.NewMockedHTTPClient(), + requestArgs: map[string]interface{}{ + "state": "read", + }, + expectError: true, + }, + { + name: "missing required state", + mockedClient: mock.NewMockedHTTPClient(), + requestArgs: map[string]interface{}{ + "threadID": "123", + }, + expectError: true, + }, + { + name: "invalid state value", + mockedClient: mock.NewMockedHTTPClient(), + requestArgs: map[string]interface{}{ + "threadID": "123", + "state": "invalid", + }, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + client := github.NewClient(tc.mockedClient) + _, handler := DismissNotification(stubGetClientFn(client), translations.NullTranslationHelper) + request := createMCPRequest(tc.requestArgs) + result, err := handler(context.Background(), request) + + if tc.expectError { + // The tool returns a ToolResultError with a specific message + require.NoError(t, err) + require.NotNil(t, result) + text := getTextResult(t, result).Text + switch { + case tc.requestArgs["threadID"] == nil: + assert.Contains(t, text, "missing required parameter: threadID") + case tc.requestArgs["state"] == nil: + assert.Contains(t, text, "missing required parameter: state") + case tc.name == "invalid threadID format": + assert.Contains(t, text, "invalid threadID format") + case tc.name == "invalid state value": + assert.Contains(t, text, "Invalid state. Must be one of: read, done.") + default: + // fallback for other errors + assert.Contains(t, text, "error") + } + return + } + + require.NoError(t, err) + textContent := getTextResult(t, result) + if tc.expectRead { + assert.Contains(t, textContent.Text, "Notification marked as read") + } + if tc.expectDone { + assert.Contains(t, textContent.Text, "Notification marked as done") + } + if tc.expectInvalid { + assert.Contains(t, textContent.Text, "invalid threadID format") + } + }) + } +} + +func Test_MarkAllNotificationsRead(t *testing.T) { + // Verify tool definition and schema + mockClient := github.NewClient(nil) + tool, _ := MarkAllNotificationsRead(stubGetClientFn(mockClient), translations.NullTranslationHelper) + assert.Equal(t, "mark_all_notifications_read", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "lastReadAt") + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Empty(t, tool.InputSchema.Required) + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectMarked bool + expectedErrMsg string + }{ + { + name: "success (no params)", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.PutNotifications, + nil, + ), + ), + requestArgs: map[string]interface{}{}, + expectError: false, + expectMarked: true, + }, + { + name: "success with lastReadAt param", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.PutNotifications, + nil, + ), + ), + requestArgs: map[string]interface{}{ + "lastReadAt": "2024-01-01T00:00:00Z", + }, + expectError: false, + expectMarked: true, + }, + { + name: "success with owner and repo", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.PutReposNotificationsByOwnerByRepo, + nil, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "octocat", + "repo": "hello-world", + }, + expectError: false, + expectMarked: true, + }, + { + name: "API error", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PutNotifications, + mockResponse(t, http.StatusInternalServerError, `{"message": "error"}`), + ), + ), + requestArgs: map[string]interface{}{}, + expectError: true, + expectedErrMsg: "error", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + client := github.NewClient(tc.mockedClient) + _, handler := MarkAllNotificationsRead(stubGetClientFn(client), translations.NullTranslationHelper) + request := createMCPRequest(tc.requestArgs) + result, err := handler(context.Background(), request) + + if tc.expectError { + require.Error(t, err) + if tc.expectedErrMsg != "" { + assert.Contains(t, err.Error(), tc.expectedErrMsg) + } + return + } + + require.NoError(t, err) + textContent := getTextResult(t, result) + if tc.expectMarked { + assert.Contains(t, textContent.Text, "All notifications marked as read") + } + }) + } +} + +func Test_GetNotificationDetails(t *testing.T) { + // Verify tool definition and schema + mockClient := github.NewClient(nil) + tool, _ := GetNotificationDetails(stubGetClientFn(mockClient), translations.NullTranslationHelper) + assert.Equal(t, "get_notification_details", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "notificationID") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"notificationID"}) + + mockThread := &github.Notification{ID: github.Ptr("123"), Reason: github.Ptr("mention")} + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectResult *github.Notification + expectedErrMsg string + }{ + { + name: "success", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetNotificationsThreadsByThreadId, + mockThread, + ), + ), + requestArgs: map[string]interface{}{ + "notificationID": "123", + }, + expectError: false, + expectResult: mockThread, + }, + { + name: "not found", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetNotificationsThreadsByThreadId, + mockResponse(t, http.StatusNotFound, `{"message": "not found"}`), + ), + ), + requestArgs: map[string]interface{}{ + "notificationID": "123", + }, + expectError: true, + expectedErrMsg: "not found", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + client := github.NewClient(tc.mockedClient) + _, handler := GetNotificationDetails(stubGetClientFn(client), translations.NullTranslationHelper) + request := createMCPRequest(tc.requestArgs) + result, err := handler(context.Background(), request) + + if tc.expectError { + require.Error(t, err) + if tc.expectedErrMsg != "" { + assert.Contains(t, err.Error(), tc.expectedErrMsg) + } + return + } + + require.NoError(t, err) + textContent := getTextResult(t, result) + var returned github.Notification + err = json.Unmarshal([]byte(textContent.Text), &returned) + require.NoError(t, err) + assert.Equal(t, *tc.expectResult.ID, *returned.ID) + }) + } +} diff --git a/pkg/github/server.go b/pkg/github/server.go index 79c146a4..e4c24171 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -26,7 +26,6 @@ func NewServer(version string, opts ...server.ServerOption) *server.MCPServer { version, opts..., ) - return s } @@ -144,47 +143,6 @@ func OptionalIntParamWithDefault(r mcp.CallToolRequest, p string, d int) (int, e return v, nil } -// OptionalBoolParamWithDefault is a helper function that can be used to fetch a requested parameter from the request -// similar to optionalParam, but it also takes a default value. -func OptionalBoolParamWithDefault(r mcp.CallToolRequest, p string, d bool) (bool, error) { - v, err := OptionalParam[bool](r, p) - if err != nil { - return false, err - } - if !v { - return d, nil - } - return v, nil -} - -// OptionalStringParam is a helper function that can be used to fetch a requested parameter from the request. -// It does the following checks: -// 1. Checks if the parameter is present in the request, if not, it returns its zero-value -// 2. If it is present, it checks if the parameter is of the expected type and returns it -func OptionalStringParam(r mcp.CallToolRequest, p string) (string, error) { - v, err := OptionalParam[string](r, p) - if err != nil { - return "", err - } - if v == "" { - return "", nil - } - return v, nil -} - -// OptionalStringParamWithDefault is a helper function that can be used to fetch a requested parameter from the request -// similar to optionalParam, but it also takes a default value. -func OptionalStringParamWithDefault(r mcp.CallToolRequest, p string, d string) (string, error) { - v, err := OptionalParam[string](r, p) - if err != nil { - return "", err - } - if v == "" { - return d, nil - } - return v, nil -} - // OptionalStringArrayParam is a helper function that can be used to fetch a requested parameter from the request. // It does the following checks: // 1. Checks if the parameter is present in the request, if not, it returns its zero-value diff --git a/pkg/github/tools.go b/pkg/github/tools.go index 72b85b0c..9c1ab34a 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -94,14 +94,14 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn, notifications := toolsets.NewToolset("notifications", "GitHub Notifications related tools"). AddReadTools( - - toolsets.NewServerTool(MarkNotificationRead(getClient, t)), - toolsets.NewServerTool(MarkAllNotificationsRead(getClient, t)), - toolsets.NewServerTool(MarkNotificationDone(getClient, t)), + toolsets.NewServerTool(ListNotifications(getClient, t)), + toolsets.NewServerTool(GetNotificationDetails(getClient, t)), ). AddWriteTools( - toolsets.NewServerTool(GetNotifications(getClient, t)), - toolsets.NewServerTool(GetNotificationThread(getClient, t)), + toolsets.NewServerTool(DismissNotification(getClient, t)), + toolsets.NewServerTool(MarkAllNotificationsRead(getClient, t)), + toolsets.NewServerTool(ManageNotificationSubscription(getClient, t)), + toolsets.NewServerTool(ManageRepositoryNotificationSubscription(getClient, t)), ) // Keep experiments alive so the system doesn't error out when it's always enabled