Skip to content

Commit 5771908

Browse files
sridharavinashwilliammartin
authored andcommitted
feat: add GitHub notifications tools for managing user notifications
1 parent b9a06d0 commit 5771908

File tree

3 files changed

+356
-0
lines changed

3 files changed

+356
-0
lines changed

pkg/github/notifications.go

Lines changed: 300 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,300 @@
1+
package github
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"io"
8+
"net/http"
9+
"strconv"
10+
"time"
11+
12+
"github.com/github/github-mcp-server/pkg/translations"
13+
"github.com/google/go-github/v69/github"
14+
"github.com/mark3labs/mcp-go/mcp"
15+
"github.com/mark3labs/mcp-go/server"
16+
)
17+
18+
// getNotifications creates a tool to list notifications for the current user.
19+
func GetNotifications(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
20+
return mcp.NewTool("get_notifications",
21+
mcp.WithDescription(t("TOOL_GET_NOTIFICATIONS_DESCRIPTION", "Get notifications for the authenticated GitHub user")),
22+
mcp.WithBoolean("all",
23+
mcp.Description("If true, show notifications marked as read. Default: false"),
24+
),
25+
mcp.WithBoolean("participating",
26+
mcp.Description("If true, only shows notifications in which the user is directly participating or mentioned. Default: false"),
27+
),
28+
mcp.WithString("since",
29+
mcp.Description("Only show notifications updated after the given time (ISO 8601 format)"),
30+
),
31+
mcp.WithString("before",
32+
mcp.Description("Only show notifications updated before the given time (ISO 8601 format)"),
33+
),
34+
mcp.WithNumber("per_page",
35+
mcp.Description("Results per page (max 100). Default: 30"),
36+
),
37+
mcp.WithNumber("page",
38+
mcp.Description("Page number of the results to fetch. Default: 1"),
39+
),
40+
),
41+
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
42+
client, err := getClient(ctx)
43+
if err != nil {
44+
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
45+
}
46+
47+
// Extract optional parameters with defaults
48+
all, err := OptionalParamWithDefault[bool](request, "all", false)
49+
if err != nil {
50+
return mcp.NewToolResultError(err.Error()), nil
51+
}
52+
53+
participating, err := OptionalParamWithDefault[bool](request, "participating", false)
54+
if err != nil {
55+
return mcp.NewToolResultError(err.Error()), nil
56+
}
57+
58+
since, err := OptionalParam[string](request, "since")
59+
if err != nil {
60+
return mcp.NewToolResultError(err.Error()), nil
61+
}
62+
63+
before, err := OptionalParam[string](request, "before")
64+
if err != nil {
65+
return mcp.NewToolResultError(err.Error()), nil
66+
}
67+
68+
perPage, err := OptionalIntParamWithDefault(request, "per_page", 30)
69+
if err != nil {
70+
return mcp.NewToolResultError(err.Error()), nil
71+
}
72+
73+
page, err := OptionalIntParamWithDefault(request, "page", 1)
74+
if err != nil {
75+
return mcp.NewToolResultError(err.Error()), nil
76+
}
77+
78+
// Build options
79+
opts := &github.NotificationListOptions{
80+
All: all,
81+
Participating: participating,
82+
ListOptions: github.ListOptions{
83+
Page: page,
84+
PerPage: perPage,
85+
},
86+
}
87+
88+
// Parse time parameters if provided
89+
if since != "" {
90+
sinceTime, err := time.Parse(time.RFC3339, since)
91+
if err != nil {
92+
return mcp.NewToolResultError(fmt.Sprintf("invalid since time format, should be RFC3339/ISO8601: %v", err)), nil
93+
}
94+
opts.Since = sinceTime
95+
}
96+
97+
if before != "" {
98+
beforeTime, err := time.Parse(time.RFC3339, before)
99+
if err != nil {
100+
return mcp.NewToolResultError(fmt.Sprintf("invalid before time format, should be RFC3339/ISO8601: %v", err)), nil
101+
}
102+
opts.Before = beforeTime
103+
}
104+
105+
// Call GitHub API
106+
notifications, resp, err := client.Activity.ListNotifications(ctx, opts)
107+
if err != nil {
108+
return nil, fmt.Errorf("failed to get notifications: %w", err)
109+
}
110+
defer func() { _ = resp.Body.Close() }()
111+
112+
if resp.StatusCode != http.StatusOK {
113+
body, err := io.ReadAll(resp.Body)
114+
if err != nil {
115+
return nil, fmt.Errorf("failed to read response body: %w", err)
116+
}
117+
return mcp.NewToolResultError(fmt.Sprintf("failed to get notifications: %s", string(body))), nil
118+
}
119+
120+
// Marshal response to JSON
121+
r, err := json.Marshal(notifications)
122+
if err != nil {
123+
return nil, fmt.Errorf("failed to marshal response: %w", err)
124+
}
125+
126+
return mcp.NewToolResultText(string(r)), nil
127+
}
128+
}
129+
130+
// markNotificationRead creates a tool to mark a notification as read.
131+
func MarkNotificationRead(getclient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
132+
return mcp.NewTool("mark_notification_read",
133+
mcp.WithDescription(t("TOOL_MARK_NOTIFICATION_READ_DESCRIPTION", "Mark a notification as read")),
134+
mcp.WithString("threadID",
135+
mcp.Required(),
136+
mcp.Description("The ID of the notification thread"),
137+
),
138+
),
139+
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
140+
client, err := getclient(ctx)
141+
if err != nil {
142+
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
143+
}
144+
145+
threadID, err := requiredParam[string](request, "threadID")
146+
if err != nil {
147+
return mcp.NewToolResultError(err.Error()), nil
148+
}
149+
150+
resp, err := client.Activity.MarkThreadRead(ctx, threadID)
151+
if err != nil {
152+
return nil, fmt.Errorf("failed to mark notification as read: %w", err)
153+
}
154+
defer func() { _ = resp.Body.Close() }()
155+
156+
if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK {
157+
body, err := io.ReadAll(resp.Body)
158+
if err != nil {
159+
return nil, fmt.Errorf("failed to read response body: %w", err)
160+
}
161+
return mcp.NewToolResultError(fmt.Sprintf("failed to mark notification as read: %s", string(body))), nil
162+
}
163+
164+
return mcp.NewToolResultText("Notification marked as read"), nil
165+
}
166+
}
167+
168+
// MarkAllNotificationsRead creates a tool to mark all notifications as read.
169+
func MarkAllNotificationsRead(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
170+
return mcp.NewTool("mark_all_notifications_read",
171+
mcp.WithDescription(t("TOOL_MARK_ALL_NOTIFICATIONS_READ_DESCRIPTION", "Mark all notifications as read")),
172+
mcp.WithString("lastReadAt",
173+
mcp.Description("Describes the last point that notifications were checked (optional). Default: Now"),
174+
),
175+
),
176+
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
177+
client, err := getClient(ctx)
178+
if err != nil {
179+
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
180+
}
181+
182+
lastReadAt, err := OptionalParam(request, "lastReadAt")
183+
if err != nil {
184+
return mcp.NewToolResultError(err.Error()), nil
185+
}
186+
187+
var markReadOptions github.Timestamp
188+
if lastReadAt != "" {
189+
lastReadTime, err := time.Parse(time.RFC3339, lastReadAt)
190+
if err != nil {
191+
return mcp.NewToolResultError(fmt.Sprintf("invalid lastReadAt time format, should be RFC3339/ISO8601: %v", err)), nil
192+
}
193+
markReadOptions = github.Timestamp{
194+
Time: lastReadTime,
195+
}
196+
}
197+
198+
resp, err := client.Activity.MarkNotificationsRead(ctx, markReadOptions)
199+
if err != nil {
200+
return nil, fmt.Errorf("failed to mark all notifications as read: %w", err)
201+
}
202+
defer func() { _ = resp.Body.Close() }()
203+
204+
if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK {
205+
body, err := io.ReadAll(resp.Body)
206+
if err != nil {
207+
return nil, fmt.Errorf("failed to read response body: %w", err)
208+
}
209+
return mcp.NewToolResultError(fmt.Sprintf("failed to mark all notifications as read: %s", string(body))), nil
210+
}
211+
212+
return mcp.NewToolResultText("All notifications marked as read"), nil
213+
}
214+
}
215+
216+
// GetNotificationThread creates a tool to get a specific notification thread.
217+
func GetNotificationThread(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
218+
return mcp.NewTool("get_notification_thread",
219+
mcp.WithDescription(t("TOOL_GET_NOTIFICATION_THREAD_DESCRIPTION", "Get a specific notification thread")),
220+
mcp.WithString("threadID",
221+
mcp.Required(),
222+
mcp.Description("The ID of the notification thread"),
223+
),
224+
),
225+
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
226+
client, err := getClient(ctx)
227+
if err != nil {
228+
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
229+
}
230+
231+
threadID, err := requiredParam[string](request, "threadID")
232+
if err != nil {
233+
return mcp.NewToolResultError(err.Error()), nil
234+
}
235+
236+
thread, resp, err := client.Activity.GetThread(ctx, threadID)
237+
if err != nil {
238+
return nil, fmt.Errorf("failed to get notification thread: %w", err)
239+
}
240+
defer func() { _ = resp.Body.Close() }()
241+
242+
if resp.StatusCode != http.StatusOK {
243+
body, err := io.ReadAll(resp.Body)
244+
if err != nil {
245+
return nil, fmt.Errorf("failed to read response body: %w", err)
246+
}
247+
return mcp.NewToolResultError(fmt.Sprintf("failed to get notification thread: %s", string(body))), nil
248+
}
249+
250+
r, err := json.Marshal(thread)
251+
if err != nil {
252+
return nil, fmt.Errorf("failed to marshal response: %w", err)
253+
}
254+
255+
return mcp.NewToolResultText(string(r)), nil
256+
}
257+
}
258+
259+
// markNotificationDone creates a tool to mark a notification as done.
260+
func MarkNotificationDone(getclient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
261+
return mcp.NewTool("mark_notification_done",
262+
mcp.WithDescription(t("TOOL_MARK_NOTIFICATION_DONE_DESCRIPTION", "Mark a notification as done")),
263+
mcp.WithString("threadID",
264+
mcp.Required(),
265+
mcp.Description("The ID of the notification thread"),
266+
),
267+
),
268+
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
269+
client, err := getclient(ctx)
270+
if err != nil {
271+
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
272+
}
273+
274+
threadIDStr, err := requiredParam[string](request, "threadID")
275+
if err != nil {
276+
return mcp.NewToolResultError(err.Error()), nil
277+
}
278+
279+
threadID, err := strconv.ParseInt(threadIDStr, 10, 64)
280+
if err != nil {
281+
return mcp.NewToolResultError("Invalid threadID: must be a numeric value"), nil
282+
}
283+
284+
resp, err := client.Activity.MarkThreadDone(ctx, threadID)
285+
if err != nil {
286+
return nil, fmt.Errorf("failed to mark notification as done: %w", err)
287+
}
288+
defer func() { _ = resp.Body.Close() }()
289+
290+
if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK {
291+
body, err := io.ReadAll(resp.Body)
292+
if err != nil {
293+
return nil, fmt.Errorf("failed to read response body: %w", err)
294+
}
295+
return mcp.NewToolResultError(fmt.Sprintf("failed to mark notification as done: %s", string(body))), nil
296+
}
297+
298+
return mcp.NewToolResultText("Notification marked as done"), nil
299+
}
300+
}

pkg/github/server.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ func NewServer(version string, opts ...server.ServerOption) *server.MCPServer {
2626
version,
2727
opts...,
2828
)
29+
2930
return s
3031
}
3132

@@ -143,6 +144,47 @@ func OptionalIntParamWithDefault(r mcp.CallToolRequest, p string, d int) (int, e
143144
return v, nil
144145
}
145146

147+
// OptionalBoolParamWithDefault is a helper function that can be used to fetch a requested parameter from the request
148+
// similar to optionalParam, but it also takes a default value.
149+
func OptionalBoolParamWithDefault(r mcp.CallToolRequest, p string, d bool) (bool, error) {
150+
v, err := OptionalParam[bool](r, p)
151+
if err != nil {
152+
return false, err
153+
}
154+
if !v {
155+
return d, nil
156+
}
157+
return v, nil
158+
}
159+
160+
// OptionalStringParam is a helper function that can be used to fetch a requested parameter from the request.
161+
// It does the following checks:
162+
// 1. Checks if the parameter is present in the request, if not, it returns its zero-value
163+
// 2. If it is present, it checks if the parameter is of the expected type and returns it
164+
func OptionalStringParam(r mcp.CallToolRequest, p string) (string, error) {
165+
v, err := OptionalParam[string](r, p)
166+
if err != nil {
167+
return "", err
168+
}
169+
if v == "" {
170+
return "", nil
171+
}
172+
return v, nil
173+
}
174+
175+
// OptionalStringParamWithDefault is a helper function that can be used to fetch a requested parameter from the request
176+
// similar to optionalParam, but it also takes a default value.
177+
func OptionalStringParamWithDefault(r mcp.CallToolRequest, p string, d string) (string, error) {
178+
v, err := OptionalParam[string](r, p)
179+
if err != nil {
180+
return "", err
181+
}
182+
if v == "" {
183+
return d, nil
184+
}
185+
return v, nil
186+
}
187+
146188
// OptionalStringArrayParam is a helper function that can be used to fetch a requested parameter from the request.
147189
// It does the following checks:
148190
// 1. Checks if the parameter is present in the request, if not, it returns its zero-value

pkg/github/tools.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,19 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn,
9191
toolsets.NewServerTool(GetSecretScanningAlert(getClient, t)),
9292
toolsets.NewServerTool(ListSecretScanningAlerts(getClient, t)),
9393
)
94+
95+
notifications := toolsets.NewToolset("notifications", "GitHub Notifications related tools").
96+
AddReadTools(
97+
98+
toolsets.NewServerTool(MarkNotificationRead(getClient, t)),
99+
toolsets.NewServerTool(MarkAllNotificationsRead(getClient, t)),
100+
toolsets.NewServerTool(MarkNotificationDone(getClient, t)),
101+
).
102+
AddWriteTools(
103+
toolsets.NewServerTool(GetNotifications(getClient, t)),
104+
toolsets.NewServerTool(GetNotificationThread(getClient, t)),
105+
)
106+
94107
// Keep experiments alive so the system doesn't error out when it's always enabled
95108
experiments := toolsets.NewToolset("experiments", "Experimental features that are not considered stable yet")
96109

@@ -101,6 +114,7 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn,
101114
tsg.AddToolset(pullRequests)
102115
tsg.AddToolset(codeSecurity)
103116
tsg.AddToolset(secretProtection)
117+
tsg.AddToolset(notifications)
104118
tsg.AddToolset(experiments)
105119
// Enable the requested features
106120

0 commit comments

Comments
 (0)