From 1d83838dc886e8cbba7112a9255101fb583890cd Mon Sep 17 00:00:00 2001 From: Jurre Stender Date: Wed, 4 Jun 2025 18:00:51 +0000 Subject: [PATCH] Add Global Security Advisories Toolset --- README.md | 23 +++ pkg/github/security_advisories.go | 198 ++++++++++++++++++++ pkg/github/security_advisories_test.go | 241 +++++++++++++++++++++++++ pkg/github/tools.go | 6 + 4 files changed, 468 insertions(+) create mode 100644 pkg/github/security_advisories.go create mode 100644 pkg/github/security_advisories_test.go diff --git a/README.md b/README.md index 7b9e20fc..471e1214 100644 --- a/README.md +++ b/README.md @@ -666,6 +666,29 @@ export GITHUB_MCP_TOOL_ADD_ISSUE_COMMENT_DESCRIPTION="an alternative description - `repo`: Repository name (string, required) - `prNumber`: Pull request number (string, required) - `path`: File or directory path (string, optional) + +## Security Advisories + +- **`list_global_security_advisories`** + List global security advisories + + - **Parameters**: + - * `ghsaId`: Filter by GitHub Security Advisory ID (string, optional – format: `GHSA-xxxx-xxxx-xxxx`) + - * `type`: Advisory type (string, optional – one of `reviewed`, `malware`, `unreviewed`) + - * `cveId`: Filter by CVE ID (string, optional) + - * `ecosystem`: Filter by package ecosystem (string, optional – one of `actions`, `composer`, `erlang`, `go`, `maven`, `npm`, `nuget`, `other`, `pip`, `pub`, `rubygems`, `rust`) + - * `severity`: Filter by severity (string, optional – one of `unknown`, `low`, `medium`, `high`, `critical`) + - * `cwes`: Filter by Common Weakness Enumeration IDs (array of strings, optional – e.g. `["79", "284", "22"]`) + - * `isWithdrawn`: Whether to only return withdrawn advisories (boolean, optional) + - * `affects`: Filter advisories by affected package or version (string, optional – e.g. `"package1,package2@1.0.0"`) + - * `published`: Filter by publish date or date range (string, optional – ISO 8601 date or range) + - * `updated`: Filter by update date or date range (string, optional – ISO 8601 date or range) + - * `modified`: Filter by publish or update date or date range (string, optional – ISO 8601 date or range) + +- **`get_global_security_advisory`** + Get a global security advisory + + - **Template**: `advisories/{ghsaId}` ## Library Usage diff --git a/pkg/github/security_advisories.go b/pkg/github/security_advisories.go new file mode 100644 index 00000000..6c01ce7d --- /dev/null +++ b/pkg/github/security_advisories.go @@ -0,0 +1,198 @@ +package github + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/github/github-mcp-server/pkg/translations" + "github.com/google/go-github/v72/github" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +func ListGlobalSecurityAdvisories(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("list_global_security_advisories", + mcp.WithDescription(t("TOOL_LIST_GLOBAL_SECURITY_ADVISORIES_DESCRIPTION", "List global security advisories from GitHub.")), + mcp.WithToolAnnotation(mcp.ToolAnnotation{ + Title: t("TOOL_LIST_GLOBAL_SECURITY_ADVISORIES_USER_TITLE", "List global security advisories"), + ReadOnlyHint: toBoolPtr(true), + }), + mcp.WithString("ghsaId", + mcp.Description("Filter by GitHub Security Advisory ID (format: GHSA-xxxx-xxxx-xxxx)."), + ), + mcp.WithString("type", + mcp.Description("Advisory type."), + mcp.Enum("reviewed", "malware", "unreviewed"), + ), + mcp.WithString("cveId", + mcp.Description("Filter by CVE ID."), + ), + mcp.WithString("ecosystem", + mcp.Description("Filter by package ecosystem."), + mcp.Enum("actions", "composer", "erlang", "go", "maven", "npm", "nuget", "other", "pip", "pub", "rubygems", "rust"), + ), + mcp.WithString("severity", + mcp.Description("Filter by severity."), + mcp.Enum("unknown", "low", "medium", "high", "critical"), + ), + mcp.WithArray("cwes", + mcp.Description("Filter by Common Weakness Enumeration IDs (e.g. [\"79\", \"284\", \"22\"])."), + ), + mcp.WithBoolean("isWithdrawn", + mcp.Description("Whether to only return withdrawn advisories."), + ), + mcp.WithString("affects", + mcp.Description("Filter advisories by affected package or version (e.g. \"package1,package2@1.0.0\")."), + ), + mcp.WithString("published", + mcp.Description("Filter by publish date or date range (ISO 8601 date or range)."), + ), + mcp.WithString("updated", + mcp.Description("Filter by update date or date range (ISO 8601 date or range)."), + ), + mcp.WithString("modified", + mcp.Description("Filter by publish or update date or date range (ISO 8601 date or range)."), + ), + ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + ghsaID, err := OptionalParam[string](request, "ghsaId") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("invalid ghsaId: %v", err)), nil + } + + typ, err := OptionalParam[string](request, "type") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("invalid type: %v", err)), nil + } + + cveID, err := OptionalParam[string](request, "cveId") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("invalid cveId: %v", err)), nil + } + + eco, err := OptionalParam[string](request, "ecosystem") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("invalid ecosystem: %v", err)), nil + } + + sev, err := OptionalParam[string](request, "severity") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("invalid severity: %v", err)), nil + } + + cwes, err := OptionalParam[[]string](request, "cwes") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("invalid cwes: %v", err)), nil + } + + isWithdrawn, err := OptionalParam[bool](request, "isWithdrawn") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("invalid isWithdrawn: %v", err)), nil + } + + affects, err := OptionalParam[string](request, "affects") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("invalid affects: %v", err)), nil + } + + published, err := OptionalParam[string](request, "published") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("invalid published: %v", err)), nil + } + + updated, err := OptionalParam[string](request, "updated") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("invalid updated: %v", err)), nil + } + + modified, err := OptionalParam[string](request, "modified") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("invalid modified: %v", err)), nil + } + + advisories, resp, err := client.SecurityAdvisories.ListGlobalSecurityAdvisories(ctx, &github.ListGlobalSecurityAdvisoriesOptions{ + GHSAID: &ghsaID, + Type: &typ, + CVEID: &cveID, + Ecosystem: &eco, + Severity: &sev, + CWEs: cwes, + IsWithdrawn: &isWithdrawn, + Affects: &affects, + Published: &published, + Updated: &updated, + Modified: &modified, + }) + if err != nil { + return nil, fmt.Errorf("failed to list global security advisories: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to list advisories: %s", string(body))), nil + } + + r, err := json.Marshal(advisories) + if err != nil { + return nil, fmt.Errorf("failed to marshal advisories: %w", err) + } + + return mcp.NewToolResultText(string(r)), nil + } +} + +func GetGlobalSecurityAdvisory(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("get_global_security_advisory", + mcp.WithDescription(t("TOOL_GET_GLOBAL_SECURITY_ADVISORY_DESCRIPTION", "Get a global security advisory")), + mcp.WithToolAnnotation(mcp.ToolAnnotation{ + Title: t("TOOL_GET_GLOBAL_SECURITY_ADVISORY_USER_TITLE", "Get a global security advisory"), + ReadOnlyHint: toBoolPtr(true), + }), + mcp.WithString("ghsaId", + mcp.Description("GitHub Security Advisory ID (format: GHSA-xxxx-xxxx-xxxx)."), + mcp.Required(), + ), + ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + ghsaID, err := requiredParam[string](request, "ghsaId") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("invalid ghsaId: %v", err)), nil + } + + advisory, resp, err := client.SecurityAdvisories.GetGlobalSecurityAdvisories(ctx, ghsaID) + if err != nil { + return nil, fmt.Errorf("failed to get advisory: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to get advisory: %s", string(body))), nil + } + + r, err := json.Marshal(advisory) + if err != nil { + return nil, fmt.Errorf("failed to marshal advisory: %w", err) + } + + return mcp.NewToolResultText(string(r)), nil + } +} diff --git a/pkg/github/security_advisories_test.go b/pkg/github/security_advisories_test.go new file mode 100644 index 00000000..fc5acda7 --- /dev/null +++ b/pkg/github/security_advisories_test.go @@ -0,0 +1,241 @@ +package github + +import ( + "context" + "encoding/json" + "net/http" + "testing" + + "github.com/github/github-mcp-server/pkg/translations" + "github.com/google/go-github/v72/github" + "github.com/migueleliasweb/go-github-mock/src/mock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_ListGlobalSecurityAdvisories(t *testing.T) { + mockClient := github.NewClient(nil) + tool, _ := ListGlobalSecurityAdvisories(stubGetClientFn(mockClient), translations.NullTranslationHelper) + + assert.Equal(t, "list_global_security_advisories", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "ecosystem") + assert.Contains(t, tool.InputSchema.Properties, "severity") + assert.Contains(t, tool.InputSchema.Properties, "ghsaId") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{}) + + // Setup mock advisory for success case + mockAdvisory := &github.GlobalSecurityAdvisory{ + SecurityAdvisory: github.SecurityAdvisory{ + GHSAID: github.Ptr("GHSA-xxxx-xxxx-xxxx"), + Summary: github.Ptr("Test advisory"), + Description: github.Ptr("This is a test advisory."), + Severity: github.Ptr("high"), + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedAdvisories []*github.GlobalSecurityAdvisory + expectedErrMsg string + }{ + { + name: "successful advisory fetch", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetAdvisories, + []*github.GlobalSecurityAdvisory{mockAdvisory}, + ), + ), + requestArgs: map[string]interface{}{ + "ecosystem": "npm", + "severity": "high", + }, + expectError: false, + expectedAdvisories: []*github.GlobalSecurityAdvisory{mockAdvisory}, + }, + { + name: "invalid severity value", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetAdvisories, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"message": "Bad Request"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "severity": "extreme", + }, + expectError: true, + expectedErrMsg: "failed to list global security advisories", + }, + { + name: "API error handling", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetAdvisories, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`{"message": "Internal Server Error"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{}, + expectError: true, + expectedErrMsg: "failed to list global security advisories", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := ListGlobalSecurityAdvisories(stubGetClientFn(client), translations.NullTranslationHelper) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedAdvisories []*github.GlobalSecurityAdvisory + err = json.Unmarshal([]byte(textContent.Text), &returnedAdvisories) + assert.NoError(t, err) + assert.Len(t, returnedAdvisories, len(tc.expectedAdvisories)) + for i, advisory := range returnedAdvisories { + assert.Equal(t, *tc.expectedAdvisories[i].GHSAID, *advisory.GHSAID) + assert.Equal(t, *tc.expectedAdvisories[i].Summary, *advisory.Summary) + assert.Equal(t, *tc.expectedAdvisories[i].Description, *advisory.Description) + assert.Equal(t, *tc.expectedAdvisories[i].Severity, *advisory.Severity) + } + }) + } +} + +func Test_GetGlobalSecurityAdvisory(t *testing.T) { + mockClient := github.NewClient(nil) + tool, _ := GetGlobalSecurityAdvisory(stubGetClientFn(mockClient), translations.NullTranslationHelper) + + assert.Equal(t, "get_global_security_advisory", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "ghsaId") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"ghsaId"}) + + // Setup mock advisory for success case + mockAdvisory := &github.GlobalSecurityAdvisory{ + SecurityAdvisory: github.SecurityAdvisory{ + GHSAID: github.Ptr("GHSA-xxxx-xxxx-xxxx"), + Summary: github.Ptr("Test advisory"), + Description: github.Ptr("This is a test advisory."), + Severity: github.Ptr("high"), + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedAdvisory *github.GlobalSecurityAdvisory + expectedErrMsg string + }{ + { + name: "successful advisory fetch", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetAdvisoriesByGhsaId, + mockAdvisory, + ), + ), + requestArgs: map[string]interface{}{ + "ghsaId": "GHSA-xxxx-xxxx-xxxx", + }, + expectError: false, + expectedAdvisory: mockAdvisory, + }, + { + name: "invalid ghsaId format", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetAdvisoriesByGhsaId, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"message": "Bad Request"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "ghsaId": "invalid-ghsa-id", + }, + expectError: true, + expectedErrMsg: "failed to get advisory", + }, + { + name: "advisory not found", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetAdvisoriesByGhsaId, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Not Found"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "ghsaId": "GHSA-xxxx-xxxx-xxxx", + }, + expectError: true, + expectedErrMsg: "failed to get advisory", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := GetGlobalSecurityAdvisory(stubGetClientFn(client), translations.NullTranslationHelper) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Verify the result + assert.Contains(t, textContent.Text, *tc.expectedAdvisory.GHSAID) + assert.Contains(t, textContent.Text, *tc.expectedAdvisory.Summary) + assert.Contains(t, textContent.Text, *tc.expectedAdvisory.Description) + assert.Contains(t, textContent.Text, *tc.expectedAdvisory.Severity) + }) + } +} diff --git a/pkg/github/tools.go b/pkg/github/tools.go index ab052817..b0eab8df 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -104,6 +104,11 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn, toolsets.NewServerTool(ManageRepositoryNotificationSubscription(getClient, t)), ) + securityAdvisories := toolsets.NewToolset("security_advisories", "Security advisories related tools"). + AddReadTools( + toolsets.NewServerTool(ListGlobalSecurityAdvisories(getClient, t)), + toolsets.NewServerTool(GetGlobalSecurityAdvisory(getClient, t)), + ) // Keep experiments alive so the system doesn't error out when it's always enabled experiments := toolsets.NewToolset("experiments", "Experimental features that are not considered stable yet") @@ -116,6 +121,7 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn, tsg.AddToolset(secretProtection) tsg.AddToolset(notifications) tsg.AddToolset(experiments) + tsg.AddToolset(securityAdvisories) // Enable the requested features if err := tsg.EnableToolsets(passedToolsets); err != nil {