From 3dddc2abcd67a95916edbcab572768ff1079afbb Mon Sep 17 00:00:00 2001 From: Javier Uruen Val Date: Sun, 16 Mar 2025 20:18:47 +0100 Subject: [PATCH] add initial tests --- go.mod | 9 +- go.sum | 12 +- pkg/github/code_scanning.go | 5 +- pkg/github/code_scanning_test.go | 230 +++++++ pkg/github/helper_test.go | 49 ++ pkg/github/issues.go | 7 +- pkg/github/issues_test.go | 371 ++++++++++++ pkg/github/pullrequests.go | 24 +- pkg/github/pullrequests_test.go | 990 +++++++++++++++++++++++++++++++ pkg/github/repositories.go | 10 +- pkg/github/repositories_test.go | 909 ++++++++++++++++++++++++++++ pkg/github/search_test.go | 429 ++++++++++++++ pkg/github/server.go | 10 +- pkg/github/server_test.go | 168 ++++++ 14 files changed, 3203 insertions(+), 20 deletions(-) create mode 100644 pkg/github/code_scanning_test.go create mode 100644 pkg/github/helper_test.go create mode 100644 pkg/github/issues_test.go create mode 100644 pkg/github/pullrequests_test.go create mode 100644 pkg/github/repositories_test.go create mode 100644 pkg/github/search_test.go create mode 100644 pkg/github/server_test.go diff --git a/go.mod b/go.mod index e53b8b6b..4338a69d 100644 --- a/go.mod +++ b/go.mod @@ -6,21 +6,27 @@ require ( github.com/aws/smithy-go v1.22.3 github.com/google/go-github/v69 v69.2.0 github.com/mark3labs/mcp-go v0.11.2 + github.com/migueleliasweb/go-github-mock v1.1.0 github.com/sirupsen/logrus v1.9.3 github.com/spf13/cobra v1.9.1 github.com/spf13/viper v1.19.0 + github.com/stretchr/testify v1.9.0 golang.org/x/exp v0.0.0-20230905200255-921286631fa9 ) require ( + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect + github.com/google/go-github/v64 v64.0.0 // indirect github.com/google/go-querystring v1.1.0 // indirect github.com/google/uuid v1.6.0 // indirect + github.com/gorilla/mux v1.8.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/sourcegraph/conc v0.3.0 // indirect @@ -31,7 +37,8 @@ require ( go.uber.org/atomic v1.9.0 // indirect go.uber.org/multierr v1.9.0 // indirect golang.org/x/sys v0.18.0 // indirect - golang.org/x/text v0.14.0 // indirect + golang.org/x/text v0.19.0 // indirect + golang.org/x/time v0.5.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 42c7171f..5aa0482d 100644 --- a/go.sum +++ b/go.sum @@ -12,12 +12,16 @@ github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyT github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-github/v64 v64.0.0 h1:4G61sozmY3eiPAjjoOHponXDBONm+utovTKbyUb2Qdg= +github.com/google/go-github/v64 v64.0.0/go.mod h1:xB3vqMQNdHzilXBiO2I+M7iEFtHf+DP/omBOv6tQzVo= github.com/google/go-github/v69 v69.2.0 h1:wR+Wi/fN2zdUx9YxSmYE0ktiX9IAR/BeePzeaUUbEHE= github.com/google/go-github/v69 v69.2.0/go.mod h1:xne4jymxLR6Uj9b7J7PyTpkMYstEMMwGZa0Aehh1azM= github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= +github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= @@ -30,6 +34,8 @@ github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0V github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= github.com/mark3labs/mcp-go v0.11.2 h1:mCxWFUTrcXOtJIn9t7F8bxAL8rpE/ZZTTnx3PU/VNdA= github.com/mark3labs/mcp-go v0.11.2/go.mod h1:cjMlBU0cv/cj9kjlgmRhoJ5JREdS7YX83xeIG9Ko/jE= +github.com/migueleliasweb/go-github-mock v1.1.0 h1:GKaOBPsrPGkAKgtfuWY8MclS1xR6MInkx1SexJucMwE= +github.com/migueleliasweb/go-github-mock v1.1.0/go.mod h1:pYe/XlGs4BGMfRY4vmeixVsODHnVDDhJ9zoi0qzSMHc= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= @@ -80,8 +86,10 @@ golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqR golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= +golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= +golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= diff --git a/pkg/github/code_scanning.go b/pkg/github/code_scanning.go index da714744..0d9547eb 100644 --- a/pkg/github/code_scanning.go +++ b/pkg/github/code_scanning.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "io" + "net/http" "github.com/google/go-github/v69/github" "github.com/mark3labs/mcp-go/mcp" @@ -38,7 +39,7 @@ func getCodeScanningAlert(client *github.Client) (tool mcp.Tool, handler server. } defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 200 { + if resp.StatusCode != http.StatusOK { body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) @@ -90,7 +91,7 @@ func listCodeScanningAlerts(client *github.Client) (tool mcp.Tool, handler serve } defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 200 { + if resp.StatusCode != http.StatusOK { body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) diff --git a/pkg/github/code_scanning_test.go b/pkg/github/code_scanning_test.go new file mode 100644 index 00000000..149c8b03 --- /dev/null +++ b/pkg/github/code_scanning_test.go @@ -0,0 +1,230 @@ +package github + +import ( + "context" + "encoding/json" + "net/http" + "testing" + + "github.com/google/go-github/v69/github" + "github.com/migueleliasweb/go-github-mock/src/mock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_GetCodeScanningAlert(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := getCodeScanningAlert(mockClient) + + assert.Equal(t, "get_code_scanning_alert", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "alert_number") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "alert_number"}) + + // Setup mock alert for success case + mockAlert := &github.Alert{ + Number: github.Ptr(42), + State: github.Ptr("open"), + Rule: &github.Rule{ID: github.Ptr("test-rule"), Description: github.Ptr("Test Rule Description")}, + HTMLURL: github.Ptr("https://github.com/owner/repo/security/code-scanning/42"), + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedAlert *github.Alert + expectedErrMsg string + }{ + { + name: "successful alert fetch", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposCodeScanningAlertsByOwnerByRepoByAlertNumber, + mockAlert, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "alert_number": float64(42), + }, + expectError: false, + expectedAlert: mockAlert, + }, + { + name: "alert fetch fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposCodeScanningAlertsByOwnerByRepoByAlertNumber, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Not Found"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "alert_number": float64(9999), + }, + expectError: true, + expectedErrMsg: "failed to get alert", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := getCodeScanningAlert(client) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedAlert github.Alert + err = json.Unmarshal([]byte(textContent.Text), &returnedAlert) + assert.NoError(t, err) + assert.Equal(t, *tc.expectedAlert.Number, *returnedAlert.Number) + assert.Equal(t, *tc.expectedAlert.State, *returnedAlert.State) + assert.Equal(t, *tc.expectedAlert.Rule.ID, *returnedAlert.Rule.ID) + assert.Equal(t, *tc.expectedAlert.HTMLURL, *returnedAlert.HTMLURL) + + }) + } +} + +func Test_ListCodeScanningAlerts(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := listCodeScanningAlerts(mockClient) + + assert.Equal(t, "list_code_scanning_alerts", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "ref") + assert.Contains(t, tool.InputSchema.Properties, "state") + assert.Contains(t, tool.InputSchema.Properties, "severity") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo"}) + + // Setup mock alerts for success case + mockAlerts := []*github.Alert{ + { + Number: github.Ptr(42), + State: github.Ptr("open"), + Rule: &github.Rule{ID: github.Ptr("test-rule-1"), Description: github.Ptr("Test Rule 1")}, + HTMLURL: github.Ptr("https://github.com/owner/repo/security/code-scanning/42"), + }, + { + Number: github.Ptr(43), + State: github.Ptr("fixed"), + Rule: &github.Rule{ID: github.Ptr("test-rule-2"), Description: github.Ptr("Test Rule 2")}, + HTMLURL: github.Ptr("https://github.com/owner/repo/security/code-scanning/43"), + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedAlerts []*github.Alert + expectedErrMsg string + }{ + { + name: "successful alerts listing", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposCodeScanningAlertsByOwnerByRepo, + mockAlerts, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "ref": "main", + "state": "open", + "severity": "high", + }, + expectError: false, + expectedAlerts: mockAlerts, + }, + { + name: "alerts listing fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposCodeScanningAlertsByOwnerByRepo, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"message": "Unauthorized access"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + }, + expectError: true, + expectedErrMsg: "failed to list alerts", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := listCodeScanningAlerts(client) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedAlerts []*github.Alert + err = json.Unmarshal([]byte(textContent.Text), &returnedAlerts) + assert.NoError(t, err) + assert.Len(t, returnedAlerts, len(tc.expectedAlerts)) + for i, alert := range returnedAlerts { + assert.Equal(t, *tc.expectedAlerts[i].Number, *alert.Number) + assert.Equal(t, *tc.expectedAlerts[i].State, *alert.State) + assert.Equal(t, *tc.expectedAlerts[i].Rule.ID, *alert.Rule.ID) + assert.Equal(t, *tc.expectedAlerts[i].HTMLURL, *alert.HTMLURL) + } + }) + } +} diff --git a/pkg/github/helper_test.go b/pkg/github/helper_test.go new file mode 100644 index 00000000..5e71f418 --- /dev/null +++ b/pkg/github/helper_test.go @@ -0,0 +1,49 @@ +package github + +import ( + "encoding/json" + "github.com/stretchr/testify/assert" + "net/http" + "testing" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/require" +) + +// mockResponse is a helper function to create a mock HTTP response handler +// that returns a specified status code and marshalled body. +func mockResponse(t *testing.T, code int, body interface{}) http.HandlerFunc { + t.Helper() + return func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(code) + b, err := json.Marshal(body) + require.NoError(t, err) + _, _ = w.Write(b) + } +} + +// createMCPRequest is a helper function to create a MCP request with the given arguments. +func createMCPRequest(args map[string]interface{}) mcp.CallToolRequest { + return mcp.CallToolRequest{ + Params: struct { + Name string `json:"name"` + Arguments map[string]interface{} `json:"arguments,omitempty"` + Meta *struct { + ProgressToken mcp.ProgressToken `json:"progressToken,omitempty"` + } `json:"_meta,omitempty"` + }{ + Arguments: args, + }, + } +} + +// getTextResult is a helper function that returns a text result from a tool call. +func getTextResult(t *testing.T, result *mcp.CallToolResult) mcp.TextContent { + t.Helper() + assert.NotNil(t, result) + require.Len(t, result.Content, 1) + require.IsType(t, mcp.TextContent{}, result.Content[0]) + textContent := result.Content[0].(mcp.TextContent) + assert.Equal(t, "text", textContent.Type) + return textContent +} diff --git a/pkg/github/issues.go b/pkg/github/issues.go index c7c17289..6a43e59d 100644 --- a/pkg/github/issues.go +++ b/pkg/github/issues.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "io" + "net/http" "github.com/google/go-github/v69/github" "github.com/mark3labs/mcp-go/mcp" @@ -39,7 +40,7 @@ func getIssue(client *github.Client) (tool mcp.Tool, handler server.ToolHandlerF } defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 200 { + if resp.StatusCode != http.StatusOK { body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) @@ -93,7 +94,7 @@ func addIssueComment(client *github.Client) (tool mcp.Tool, handler server.ToolH } defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 201 { + if resp.StatusCode != http.StatusCreated { body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) @@ -165,7 +166,7 @@ func searchIssues(client *github.Client) (tool mcp.Tool, handler server.ToolHand } defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 200 { + if resp.StatusCode != http.StatusOK { body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) diff --git a/pkg/github/issues_test.go b/pkg/github/issues_test.go new file mode 100644 index 00000000..7e9944b3 --- /dev/null +++ b/pkg/github/issues_test.go @@ -0,0 +1,371 @@ +package github + +import ( + "context" + "encoding/json" + "net/http" + "testing" + + "github.com/google/go-github/v69/github" + "github.com/mark3labs/mcp-go/mcp" + "github.com/migueleliasweb/go-github-mock/src/mock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_GetIssue(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := getIssue(mockClient) + + assert.Equal(t, "get_issue", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "issue_number") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "issue_number"}) + + // Setup mock issue for success case + mockIssue := &github.Issue{ + Number: github.Ptr(42), + Title: github.Ptr("Test Issue"), + Body: github.Ptr("This is a test issue"), + State: github.Ptr("open"), + HTMLURL: github.Ptr("https://github.com/owner/repo/issues/42"), + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedIssue *github.Issue + expectedErrMsg string + }{ + { + name: "successful issue retrieval", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposIssuesByOwnerByRepoByIssueNumber, + mockIssue, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "issue_number": float64(42), + }, + expectError: false, + expectedIssue: mockIssue, + }, + { + name: "issue not found", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposIssuesByOwnerByRepoByIssueNumber, + mockResponse(t, http.StatusNotFound, `{"message": "Issue not found"}`), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "issue_number": float64(999), + }, + expectError: true, + expectedErrMsg: "failed to get issue", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := getIssue(client) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedIssue github.Issue + err = json.Unmarshal([]byte(textContent.Text), &returnedIssue) + require.NoError(t, err) + assert.Equal(t, *tc.expectedIssue.Number, *returnedIssue.Number) + assert.Equal(t, *tc.expectedIssue.Title, *returnedIssue.Title) + assert.Equal(t, *tc.expectedIssue.Body, *returnedIssue.Body) + }) + } +} + +func Test_AddIssueComment(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := addIssueComment(mockClient) + + assert.Equal(t, "add_issue_comment", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "issue_number") + assert.Contains(t, tool.InputSchema.Properties, "body") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "issue_number", "body"}) + + // Setup mock comment for success case + mockComment := &github.IssueComment{ + ID: github.Ptr(int64(123)), + Body: github.Ptr("This is a test comment"), + User: &github.User{ + Login: github.Ptr("testuser"), + }, + HTMLURL: github.Ptr("https://github.com/owner/repo/issues/42#issuecomment-123"), + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedComment *github.IssueComment + expectedErrMsg string + }{ + { + name: "successful comment creation", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PostReposIssuesCommentsByOwnerByRepoByIssueNumber, + mockResponse(t, http.StatusCreated, mockComment), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "issue_number": float64(42), + "body": "This is a test comment", + }, + expectError: false, + expectedComment: mockComment, + }, + { + name: "comment creation fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PostReposIssuesCommentsByOwnerByRepoByIssueNumber, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnprocessableEntity) + _, _ = w.Write([]byte(`{"message": "Invalid request"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "issue_number": float64(42), + "body": "", + }, + expectError: true, + expectedErrMsg: "failed to create comment", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := addIssueComment(client) + + // Create call request + request := mcp.CallToolRequest{ + Params: struct { + Name string `json:"name"` + Arguments map[string]interface{} `json:"arguments,omitempty"` + Meta *struct { + ProgressToken mcp.ProgressToken `json:"progressToken,omitempty"` + } `json:"_meta,omitempty"` + }{ + Arguments: tc.requestArgs, + }, + } + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedComment github.IssueComment + err = json.Unmarshal([]byte(textContent.Text), &returnedComment) + require.NoError(t, err) + assert.Equal(t, *tc.expectedComment.ID, *returnedComment.ID) + assert.Equal(t, *tc.expectedComment.Body, *returnedComment.Body) + assert.Equal(t, *tc.expectedComment.User.Login, *returnedComment.User.Login) + + }) + } +} + +func Test_SearchIssues(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := searchIssues(mockClient) + + assert.Equal(t, "search_issues", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "q") + assert.Contains(t, tool.InputSchema.Properties, "sort") + assert.Contains(t, tool.InputSchema.Properties, "order") + assert.Contains(t, tool.InputSchema.Properties, "per_page") + assert.Contains(t, tool.InputSchema.Properties, "page") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"q"}) + + // Setup mock search results + mockSearchResult := &github.IssuesSearchResult{ + Total: github.Ptr(2), + IncompleteResults: github.Ptr(false), + Issues: []*github.Issue{ + { + Number: github.Ptr(42), + Title: github.Ptr("Bug: Something is broken"), + Body: github.Ptr("This is a bug report"), + State: github.Ptr("open"), + HTMLURL: github.Ptr("https://github.com/owner/repo/issues/42"), + Comments: github.Ptr(5), + User: &github.User{ + Login: github.Ptr("user1"), + }, + }, + { + Number: github.Ptr(43), + Title: github.Ptr("Feature: Add new functionality"), + Body: github.Ptr("This is a feature request"), + State: github.Ptr("open"), + HTMLURL: github.Ptr("https://github.com/owner/repo/issues/43"), + Comments: github.Ptr(3), + User: &github.User{ + Login: github.Ptr("user2"), + }, + }, + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedResult *github.IssuesSearchResult + expectedErrMsg string + }{ + { + name: "successful issues search with all parameters", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetSearchIssues, + mockSearchResult, + ), + ), + requestArgs: map[string]interface{}{ + "q": "repo:owner/repo is:issue is:open", + "sort": "created", + "order": "desc", + "page": float64(1), + "per_page": float64(30), + }, + expectError: false, + expectedResult: mockSearchResult, + }, + { + name: "issues search with minimal parameters", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetSearchIssues, + mockSearchResult, + ), + ), + requestArgs: map[string]interface{}{ + "q": "repo:owner/repo is:issue is:open", + }, + expectError: false, + expectedResult: mockSearchResult, + }, + { + name: "search issues fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetSearchIssues, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"message": "Validation Failed"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "q": "invalid:query", + }, + expectError: true, + expectedErrMsg: "failed to search issues", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := searchIssues(client) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedResult github.IssuesSearchResult + err = json.Unmarshal([]byte(textContent.Text), &returnedResult) + require.NoError(t, err) + assert.Equal(t, *tc.expectedResult.Total, *returnedResult.Total) + assert.Equal(t, *tc.expectedResult.IncompleteResults, *returnedResult.IncompleteResults) + assert.Len(t, returnedResult.Issues, len(tc.expectedResult.Issues)) + for i, issue := range returnedResult.Issues { + assert.Equal(t, *tc.expectedResult.Issues[i].Number, *issue.Number) + assert.Equal(t, *tc.expectedResult.Issues[i].Title, *issue.Title) + assert.Equal(t, *tc.expectedResult.Issues[i].State, *issue.State) + assert.Equal(t, *tc.expectedResult.Issues[i].HTMLURL, *issue.HTMLURL) + assert.Equal(t, *tc.expectedResult.Issues[i].User.Login, *issue.User.Login) + } + }) + } +} diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index 1cf5f724..b2f191b4 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "io" + "net/http" "github.com/google/go-github/v69/github" "github.com/mark3labs/mcp-go/mcp" @@ -39,7 +40,7 @@ func getPullRequest(client *github.Client) (tool mcp.Tool, handler server.ToolHa } defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 200 { + if resp.StatusCode != http.StatusOK { body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) @@ -140,7 +141,7 @@ func listPullRequests(client *github.Client) (tool mcp.Tool, handler server.Tool } defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 200 { + if resp.StatusCode != http.StatusOK { body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) @@ -211,7 +212,7 @@ func mergePullRequest(client *github.Client) (tool mcp.Tool, handler server.Tool } defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 200 { + if resp.StatusCode != http.StatusOK { body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) @@ -257,7 +258,7 @@ func getPullRequestFiles(client *github.Client) (tool mcp.Tool, handler server.T } defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 200 { + if resp.StatusCode != http.StatusOK { body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) @@ -303,7 +304,7 @@ func getPullRequestStatus(client *github.Client) (tool mcp.Tool, handler server. } defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 200 { + if resp.StatusCode != http.StatusOK { body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) @@ -318,7 +319,7 @@ func getPullRequestStatus(client *github.Client) (tool mcp.Tool, handler server. } defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 200 { + if resp.StatusCode != http.StatusOK { body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) @@ -371,11 +372,16 @@ func updatePullRequestBranch(client *github.Client) (tool mcp.Tool, handler serv result, resp, err := client.PullRequests.UpdateBranch(ctx, owner, repo, pullNumber, opts) if err != nil { + // Check if it's an acceptedError. An acceptedError indicates that the update is in progress, + // and it's not a real error. + if resp != nil && resp.StatusCode == http.StatusAccepted && isAcceptedError(err) { + return mcp.NewToolResultText("Pull request branch update is in progress"), nil + } return nil, fmt.Errorf("failed to update pull request branch: %w", err) } defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 202 { + if resp.StatusCode != http.StatusAccepted { body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) @@ -426,7 +432,7 @@ func getPullRequestComments(client *github.Client) (tool mcp.Tool, handler serve } defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 200 { + if resp.StatusCode != http.StatusOK { body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) @@ -471,7 +477,7 @@ func getPullRequestReviews(client *github.Client) (tool mcp.Tool, handler server } defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 200 { + if resp.StatusCode != http.StatusOK { body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) diff --git a/pkg/github/pullrequests_test.go b/pkg/github/pullrequests_test.go new file mode 100644 index 00000000..bbafc921 --- /dev/null +++ b/pkg/github/pullrequests_test.go @@ -0,0 +1,990 @@ +package github + +import ( + "context" + "encoding/json" + "net/http" + "testing" + "time" + + "github.com/google/go-github/v69/github" + "github.com/migueleliasweb/go-github-mock/src/mock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_GetPullRequest(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := getPullRequest(mockClient) + + assert.Equal(t, "get_pull_request", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "pull_number") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "pull_number"}) + + // Setup mock PR for success case + mockPR := &github.PullRequest{ + Number: github.Ptr(42), + Title: github.Ptr("Test PR"), + State: github.Ptr("open"), + HTMLURL: github.Ptr("https://github.com/owner/repo/pull/42"), + Head: &github.PullRequestBranch{ + SHA: github.Ptr("abcd1234"), + Ref: github.Ptr("feature-branch"), + }, + Base: &github.PullRequestBranch{ + Ref: github.Ptr("main"), + }, + Body: github.Ptr("This is a test PR"), + User: &github.User{ + Login: github.Ptr("testuser"), + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedPR *github.PullRequest + expectedErrMsg string + }{ + { + name: "successful PR fetch", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposPullsByOwnerByRepoByPullNumber, + mockPR, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pull_number": float64(42), + }, + expectError: false, + expectedPR: mockPR, + }, + { + name: "PR fetch fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposPullsByOwnerByRepoByPullNumber, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Not Found"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pull_number": float64(999), + }, + expectError: true, + expectedErrMsg: "failed to get pull request", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := getPullRequest(client) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedPR github.PullRequest + err = json.Unmarshal([]byte(textContent.Text), &returnedPR) + require.NoError(t, err) + assert.Equal(t, *tc.expectedPR.Number, *returnedPR.Number) + assert.Equal(t, *tc.expectedPR.Title, *returnedPR.Title) + assert.Equal(t, *tc.expectedPR.State, *returnedPR.State) + assert.Equal(t, *tc.expectedPR.HTMLURL, *returnedPR.HTMLURL) + }) + } +} + +func Test_ListPullRequests(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := listPullRequests(mockClient) + + assert.Equal(t, "list_pull_requests", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "state") + assert.Contains(t, tool.InputSchema.Properties, "head") + assert.Contains(t, tool.InputSchema.Properties, "base") + assert.Contains(t, tool.InputSchema.Properties, "sort") + assert.Contains(t, tool.InputSchema.Properties, "direction") + assert.Contains(t, tool.InputSchema.Properties, "per_page") + assert.Contains(t, tool.InputSchema.Properties, "page") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo"}) + + // Setup mock PRs for success case + mockPRs := []*github.PullRequest{ + { + Number: github.Ptr(42), + Title: github.Ptr("First PR"), + State: github.Ptr("open"), + HTMLURL: github.Ptr("https://github.com/owner/repo/pull/42"), + }, + { + Number: github.Ptr(43), + Title: github.Ptr("Second PR"), + State: github.Ptr("closed"), + HTMLURL: github.Ptr("https://github.com/owner/repo/pull/43"), + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedPRs []*github.PullRequest + expectedErrMsg string + }{ + { + name: "successful PRs listing", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposPullsByOwnerByRepo, + mockPRs, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "state": "all", + "sort": "created", + "direction": "desc", + "per_page": float64(30), + "page": float64(1), + }, + expectError: false, + expectedPRs: mockPRs, + }, + { + name: "PRs listing fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposPullsByOwnerByRepo, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"message": "Invalid request"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "state": "invalid", + }, + expectError: true, + expectedErrMsg: "failed to list pull requests", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := listPullRequests(client) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedPRs []*github.PullRequest + err = json.Unmarshal([]byte(textContent.Text), &returnedPRs) + require.NoError(t, err) + assert.Len(t, returnedPRs, 2) + assert.Equal(t, *tc.expectedPRs[0].Number, *returnedPRs[0].Number) + assert.Equal(t, *tc.expectedPRs[0].Title, *returnedPRs[0].Title) + assert.Equal(t, *tc.expectedPRs[0].State, *returnedPRs[0].State) + assert.Equal(t, *tc.expectedPRs[1].Number, *returnedPRs[1].Number) + assert.Equal(t, *tc.expectedPRs[1].Title, *returnedPRs[1].Title) + assert.Equal(t, *tc.expectedPRs[1].State, *returnedPRs[1].State) + }) + } +} + +func Test_MergePullRequest(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := mergePullRequest(mockClient) + + assert.Equal(t, "merge_pull_request", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "pull_number") + assert.Contains(t, tool.InputSchema.Properties, "commit_title") + assert.Contains(t, tool.InputSchema.Properties, "commit_message") + assert.Contains(t, tool.InputSchema.Properties, "merge_method") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "pull_number"}) + + // Setup mock merge result for success case + mockMergeResult := &github.PullRequestMergeResult{ + Merged: github.Ptr(true), + Message: github.Ptr("Pull Request successfully merged"), + SHA: github.Ptr("abcd1234efgh5678"), + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedMergeResult *github.PullRequestMergeResult + expectedErrMsg string + }{ + { + name: "successful merge", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.PutReposPullsMergeByOwnerByRepoByPullNumber, + mockMergeResult, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pull_number": float64(42), + "commit_title": "Merge PR #42", + "commit_message": "Merging awesome feature", + "merge_method": "squash", + }, + expectError: false, + expectedMergeResult: mockMergeResult, + }, + { + name: "merge fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PutReposPullsMergeByOwnerByRepoByPullNumber, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusMethodNotAllowed) + _, _ = w.Write([]byte(`{"message": "Pull request cannot be merged"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pull_number": float64(42), + }, + expectError: true, + expectedErrMsg: "failed to merge pull request", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := mergePullRequest(client) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedResult github.PullRequestMergeResult + err = json.Unmarshal([]byte(textContent.Text), &returnedResult) + require.NoError(t, err) + assert.Equal(t, *tc.expectedMergeResult.Merged, *returnedResult.Merged) + assert.Equal(t, *tc.expectedMergeResult.Message, *returnedResult.Message) + assert.Equal(t, *tc.expectedMergeResult.SHA, *returnedResult.SHA) + }) + } +} + +func Test_GetPullRequestFiles(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := getPullRequestFiles(mockClient) + + assert.Equal(t, "get_pull_request_files", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "pull_number") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "pull_number"}) + + // Setup mock PR files for success case + mockFiles := []*github.CommitFile{ + { + Filename: github.Ptr("file1.go"), + Status: github.Ptr("modified"), + Additions: github.Ptr(10), + Deletions: github.Ptr(5), + Changes: github.Ptr(15), + Patch: github.Ptr("@@ -1,5 +1,10 @@"), + }, + { + Filename: github.Ptr("file2.go"), + Status: github.Ptr("added"), + Additions: github.Ptr(20), + Deletions: github.Ptr(0), + Changes: github.Ptr(20), + Patch: github.Ptr("@@ -0,0 +1,20 @@"), + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedFiles []*github.CommitFile + expectedErrMsg string + }{ + { + name: "successful files fetch", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposPullsFilesByOwnerByRepoByPullNumber, + mockFiles, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pull_number": float64(42), + }, + expectError: false, + expectedFiles: mockFiles, + }, + { + name: "files fetch fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposPullsFilesByOwnerByRepoByPullNumber, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Not Found"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pull_number": float64(999), + }, + expectError: true, + expectedErrMsg: "failed to get pull request files", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := getPullRequestFiles(client) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedFiles []*github.CommitFile + err = json.Unmarshal([]byte(textContent.Text), &returnedFiles) + require.NoError(t, err) + assert.Len(t, returnedFiles, len(tc.expectedFiles)) + for i, file := range returnedFiles { + assert.Equal(t, *tc.expectedFiles[i].Filename, *file.Filename) + assert.Equal(t, *tc.expectedFiles[i].Status, *file.Status) + assert.Equal(t, *tc.expectedFiles[i].Additions, *file.Additions) + assert.Equal(t, *tc.expectedFiles[i].Deletions, *file.Deletions) + } + }) + } +} + +func Test_GetPullRequestStatus(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := getPullRequestStatus(mockClient) + + assert.Equal(t, "get_pull_request_status", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "pull_number") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "pull_number"}) + + // Setup mock PR for successful PR fetch + mockPR := &github.PullRequest{ + Number: github.Ptr(42), + Title: github.Ptr("Test PR"), + HTMLURL: github.Ptr("https://github.com/owner/repo/pull/42"), + Head: &github.PullRequestBranch{ + SHA: github.Ptr("abcd1234"), + Ref: github.Ptr("feature-branch"), + }, + } + + // Setup mock status for success case + mockStatus := &github.CombinedStatus{ + State: github.Ptr("success"), + TotalCount: github.Ptr(3), + Statuses: []*github.RepoStatus{ + { + State: github.Ptr("success"), + Context: github.Ptr("continuous-integration/travis-ci"), + Description: github.Ptr("Build succeeded"), + TargetURL: github.Ptr("https://travis-ci.org/owner/repo/builds/123"), + }, + { + State: github.Ptr("success"), + Context: github.Ptr("codecov/patch"), + Description: github.Ptr("Coverage increased"), + TargetURL: github.Ptr("https://codecov.io/gh/owner/repo/pull/42"), + }, + { + State: github.Ptr("success"), + Context: github.Ptr("lint/golangci-lint"), + Description: github.Ptr("No issues found"), + TargetURL: github.Ptr("https://golangci.com/r/owner/repo/pull/42"), + }, + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedStatus *github.CombinedStatus + expectedErrMsg string + }{ + { + name: "successful status fetch", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposPullsByOwnerByRepoByPullNumber, + mockPR, + ), + mock.WithRequestMatch( + mock.GetReposCommitsStatusByOwnerByRepoByRef, + mockStatus, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pull_number": float64(42), + }, + expectError: false, + expectedStatus: mockStatus, + }, + { + name: "PR fetch fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposPullsByOwnerByRepoByPullNumber, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Not Found"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pull_number": float64(999), + }, + expectError: true, + expectedErrMsg: "failed to get pull request", + }, + { + name: "status fetch fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposPullsByOwnerByRepoByPullNumber, + mockPR, + ), + mock.WithRequestMatchHandler( + mock.GetReposCommitsStatusesByOwnerByRepoByRef, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Not Found"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pull_number": float64(42), + }, + expectError: true, + expectedErrMsg: "failed to get combined status", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := getPullRequestStatus(client) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedStatus github.CombinedStatus + err = json.Unmarshal([]byte(textContent.Text), &returnedStatus) + require.NoError(t, err) + assert.Equal(t, *tc.expectedStatus.State, *returnedStatus.State) + assert.Equal(t, *tc.expectedStatus.TotalCount, *returnedStatus.TotalCount) + assert.Len(t, returnedStatus.Statuses, len(tc.expectedStatus.Statuses)) + for i, status := range returnedStatus.Statuses { + assert.Equal(t, *tc.expectedStatus.Statuses[i].State, *status.State) + assert.Equal(t, *tc.expectedStatus.Statuses[i].Context, *status.Context) + assert.Equal(t, *tc.expectedStatus.Statuses[i].Description, *status.Description) + } + }) + } +} + +func Test_UpdatePullRequestBranch(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := updatePullRequestBranch(mockClient) + + assert.Equal(t, "update_pull_request_branch", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "pull_number") + assert.Contains(t, tool.InputSchema.Properties, "expected_head_sha") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "pull_number"}) + + // Setup mock update result for success case + mockUpdateResult := &github.PullRequestBranchUpdateResponse{ + Message: github.Ptr("Branch was updated successfully"), + URL: github.Ptr("https://api.github.com/repos/owner/repo/pulls/42"), + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedUpdateResult *github.PullRequestBranchUpdateResponse + expectedErrMsg string + }{ + { + name: "successful branch update", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PutReposPullsUpdateBranchByOwnerByRepoByPullNumber, + mockResponse(t, http.StatusAccepted, mockUpdateResult), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pull_number": float64(42), + "expected_head_sha": "abcd1234", + }, + expectError: false, + expectedUpdateResult: mockUpdateResult, + }, + { + name: "branch update without expected SHA", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PutReposPullsUpdateBranchByOwnerByRepoByPullNumber, + mockResponse(t, http.StatusAccepted, mockUpdateResult), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pull_number": float64(42), + }, + expectError: false, + expectedUpdateResult: mockUpdateResult, + }, + { + name: "branch update fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PutReposPullsUpdateBranchByOwnerByRepoByPullNumber, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusConflict) + _, _ = w.Write([]byte(`{"message": "Merge conflict"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pull_number": float64(42), + }, + expectError: true, + expectedErrMsg: "failed to update pull request branch", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := updatePullRequestBranch(client) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + assert.Contains(t, textContent.Text, "is in progress") + }) + } +} + +func Test_GetPullRequestComments(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := getPullRequestComments(mockClient) + + assert.Equal(t, "get_pull_request_comments", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "pull_number") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "pull_number"}) + + // Setup mock PR comments for success case + mockComments := []*github.PullRequestComment{ + { + ID: github.Ptr(int64(101)), + Body: github.Ptr("This looks good"), + HTMLURL: github.Ptr("https://github.com/owner/repo/pull/42#discussion_r101"), + User: &github.User{ + Login: github.Ptr("reviewer1"), + }, + Path: github.Ptr("file1.go"), + Position: github.Ptr(5), + CommitID: github.Ptr("abcdef123456"), + CreatedAt: &github.Timestamp{Time: time.Now().Add(-24 * time.Hour)}, + UpdatedAt: &github.Timestamp{Time: time.Now().Add(-24 * time.Hour)}, + }, + { + ID: github.Ptr(int64(102)), + Body: github.Ptr("Please fix this"), + HTMLURL: github.Ptr("https://github.com/owner/repo/pull/42#discussion_r102"), + User: &github.User{ + Login: github.Ptr("reviewer2"), + }, + Path: github.Ptr("file2.go"), + Position: github.Ptr(10), + CommitID: github.Ptr("abcdef123456"), + CreatedAt: &github.Timestamp{Time: time.Now().Add(-12 * time.Hour)}, + UpdatedAt: &github.Timestamp{Time: time.Now().Add(-12 * time.Hour)}, + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedComments []*github.PullRequestComment + expectedErrMsg string + }{ + { + name: "successful comments fetch", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposPullsCommentsByOwnerByRepoByPullNumber, + mockComments, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pull_number": float64(42), + }, + expectError: false, + expectedComments: mockComments, + }, + { + name: "comments fetch fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposPullsCommentsByOwnerByRepoByPullNumber, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Not Found"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pull_number": float64(999), + }, + expectError: true, + expectedErrMsg: "failed to get pull request comments", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := getPullRequestComments(client) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedComments []*github.PullRequestComment + err = json.Unmarshal([]byte(textContent.Text), &returnedComments) + require.NoError(t, err) + assert.Len(t, returnedComments, len(tc.expectedComments)) + for i, comment := range returnedComments { + assert.Equal(t, *tc.expectedComments[i].ID, *comment.ID) + assert.Equal(t, *tc.expectedComments[i].Body, *comment.Body) + assert.Equal(t, *tc.expectedComments[i].User.Login, *comment.User.Login) + assert.Equal(t, *tc.expectedComments[i].Path, *comment.Path) + assert.Equal(t, *tc.expectedComments[i].HTMLURL, *comment.HTMLURL) + } + }) + } +} + +func Test_GetPullRequestReviews(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := getPullRequestReviews(mockClient) + + assert.Equal(t, "get_pull_request_reviews", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "pull_number") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "pull_number"}) + + // Setup mock PR reviews for success case + mockReviews := []*github.PullRequestReview{ + { + ID: github.Ptr(int64(201)), + State: github.Ptr("APPROVED"), + Body: github.Ptr("LGTM"), + HTMLURL: github.Ptr("https://github.com/owner/repo/pull/42#pullrequestreview-201"), + User: &github.User{ + Login: github.Ptr("approver"), + }, + CommitID: github.Ptr("abcdef123456"), + SubmittedAt: &github.Timestamp{Time: time.Now().Add(-24 * time.Hour)}, + }, + { + ID: github.Ptr(int64(202)), + State: github.Ptr("CHANGES_REQUESTED"), + Body: github.Ptr("Please address the following issues"), + HTMLURL: github.Ptr("https://github.com/owner/repo/pull/42#pullrequestreview-202"), + User: &github.User{ + Login: github.Ptr("reviewer"), + }, + CommitID: github.Ptr("abcdef123456"), + SubmittedAt: &github.Timestamp{Time: time.Now().Add(-12 * time.Hour)}, + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedReviews []*github.PullRequestReview + expectedErrMsg string + }{ + { + name: "successful reviews fetch", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposPullsReviewsByOwnerByRepoByPullNumber, + mockReviews, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pull_number": float64(42), + }, + expectError: false, + expectedReviews: mockReviews, + }, + { + name: "reviews fetch fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposPullsReviewsByOwnerByRepoByPullNumber, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Not Found"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pull_number": float64(999), + }, + expectError: true, + expectedErrMsg: "failed to get pull request reviews", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := getPullRequestReviews(client) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedReviews []*github.PullRequestReview + err = json.Unmarshal([]byte(textContent.Text), &returnedReviews) + require.NoError(t, err) + assert.Len(t, returnedReviews, len(tc.expectedReviews)) + for i, review := range returnedReviews { + assert.Equal(t, *tc.expectedReviews[i].ID, *review.ID) + assert.Equal(t, *tc.expectedReviews[i].State, *review.State) + assert.Equal(t, *tc.expectedReviews[i].Body, *review.Body) + assert.Equal(t, *tc.expectedReviews[i].User.Login, *review.User.Login) + assert.Equal(t, *tc.expectedReviews[i].HTMLURL, *review.HTMLURL) + } + }) + } +} diff --git a/pkg/github/repositories.go b/pkg/github/repositories.go index 37e07597..607f9d92 100644 --- a/pkg/github/repositories.go +++ b/pkg/github/repositories.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "io" + "net/http" "github.com/aws/smithy-go/ptr" "github.com/google/go-github/v69/github" @@ -206,7 +207,7 @@ func createRepository(client *github.Client) (tool mcp.Tool, handler server.Tool } defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 201 { + if resp.StatusCode != http.StatusCreated { body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) @@ -314,11 +315,16 @@ func forkRepository(client *github.Client) (tool mcp.Tool, handler server.ToolHa forkedRepo, resp, err := client.Repositories.CreateFork(ctx, owner, repo, opts) if err != nil { + // Check if it's an acceptedError. An acceptedError indicates that the update is in progress, + // and it's not a real error. + if resp != nil && resp.StatusCode == http.StatusAccepted && isAcceptedError(err) { + return mcp.NewToolResultText("Fork is in progress"), nil + } return nil, fmt.Errorf("failed to fork repository: %w", err) } defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 202 { + if resp.StatusCode != http.StatusAccepted { body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) diff --git a/pkg/github/repositories_test.go b/pkg/github/repositories_test.go new file mode 100644 index 00000000..4e39b47f --- /dev/null +++ b/pkg/github/repositories_test.go @@ -0,0 +1,909 @@ +package github + +import ( + "context" + "encoding/json" + "net/http" + "testing" + "time" + + "github.com/google/go-github/v69/github" + "github.com/mark3labs/mcp-go/mcp" + "github.com/migueleliasweb/go-github-mock/src/mock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_GetFileContents(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := getFileContents(mockClient) + + assert.Equal(t, "get_file_contents", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "path") + assert.Contains(t, tool.InputSchema.Properties, "branch") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "path"}) + + // Setup mock file content for success case + mockFileContent := &github.RepositoryContent{ + Type: github.Ptr("file"), + Name: github.Ptr("README.md"), + Path: github.Ptr("README.md"), + Content: github.Ptr("IyBUZXN0IFJlcG9zaXRvcnkKClRoaXMgaXMgYSB0ZXN0IHJlcG9zaXRvcnku"), // Base64 encoded "# Test Repository\n\nThis is a test repository." + SHA: github.Ptr("abc123"), + Size: github.Ptr(42), + HTMLURL: github.Ptr("https://github.com/owner/repo/blob/main/README.md"), + DownloadURL: github.Ptr("https://raw.githubusercontent.com/owner/repo/main/README.md"), + } + + // Setup mock directory content for success case + mockDirContent := []*github.RepositoryContent{ + { + Type: github.Ptr("file"), + Name: github.Ptr("README.md"), + Path: github.Ptr("README.md"), + SHA: github.Ptr("abc123"), + Size: github.Ptr(42), + HTMLURL: github.Ptr("https://github.com/owner/repo/blob/main/README.md"), + }, + { + Type: github.Ptr("dir"), + Name: github.Ptr("src"), + Path: github.Ptr("src"), + SHA: github.Ptr("def456"), + HTMLURL: github.Ptr("https://github.com/owner/repo/tree/main/src"), + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedResult interface{} + expectedErrMsg string + }{ + { + name: "successful file content fetch", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposContentsByOwnerByRepoByPath, + mockFileContent, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "path": "README.md", + "branch": "main", + }, + expectError: false, + expectedResult: mockFileContent, + }, + { + name: "successful directory content fetch", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposContentsByOwnerByRepoByPath, + mockDirContent, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "path": "src", + }, + expectError: false, + expectedResult: mockDirContent, + }, + { + name: "content fetch fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposContentsByOwnerByRepoByPath, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Not Found"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "path": "nonexistent.md", + "branch": "main", + }, + expectError: true, + expectedErrMsg: "failed to get file contents", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := getFileContents(client) + + // Create call request + request := mcp.CallToolRequest{ + Params: struct { + Name string `json:"name"` + Arguments map[string]interface{} `json:"arguments,omitempty"` + Meta *struct { + ProgressToken mcp.ProgressToken `json:"progressToken,omitempty"` + } `json:"_meta,omitempty"` + }{ + Arguments: tc.requestArgs, + }, + } + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Verify based on expected type + switch expected := tc.expectedResult.(type) { + case *github.RepositoryContent: + var returnedContent github.RepositoryContent + err = json.Unmarshal([]byte(textContent.Text), &returnedContent) + require.NoError(t, err) + assert.Equal(t, *expected.Name, *returnedContent.Name) + assert.Equal(t, *expected.Path, *returnedContent.Path) + assert.Equal(t, *expected.Type, *returnedContent.Type) + case []*github.RepositoryContent: + var returnedContents []*github.RepositoryContent + err = json.Unmarshal([]byte(textContent.Text), &returnedContents) + require.NoError(t, err) + assert.Len(t, returnedContents, len(expected)) + for i, content := range returnedContents { + assert.Equal(t, *expected[i].Name, *content.Name) + assert.Equal(t, *expected[i].Path, *content.Path) + assert.Equal(t, *expected[i].Type, *content.Type) + } + } + }) + } +} + +func Test_ForkRepository(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := forkRepository(mockClient) + + assert.Equal(t, "fork_repository", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "organization") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo"}) + + // Setup mock forked repo for success case + mockForkedRepo := &github.Repository{ + ID: github.Ptr(int64(123456)), + Name: github.Ptr("repo"), + FullName: github.Ptr("new-owner/repo"), + Owner: &github.User{ + Login: github.Ptr("new-owner"), + }, + HTMLURL: github.Ptr("https://github.com/new-owner/repo"), + DefaultBranch: github.Ptr("main"), + Fork: github.Ptr(true), + ForksCount: github.Ptr(0), + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedRepo *github.Repository + expectedErrMsg string + }{ + { + name: "successful repository fork", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PostReposForksByOwnerByRepo, + mockResponse(t, http.StatusAccepted, mockForkedRepo), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + }, + expectError: false, + expectedRepo: mockForkedRepo, + }, + { + name: "repository fork fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PostReposForksByOwnerByRepo, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte(`{"message": "Forbidden"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + }, + expectError: true, + expectedErrMsg: "failed to fork repository", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := forkRepository(client) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + assert.Contains(t, textContent.Text, "Fork is in progress") + }) + } +} + +func Test_CreateBranch(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := createBranch(mockClient) + + assert.Equal(t, "create_branch", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "branch") + assert.Contains(t, tool.InputSchema.Properties, "from_branch") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "branch"}) + + // Setup mock repository for default branch test + mockRepo := &github.Repository{ + DefaultBranch: github.Ptr("main"), + } + + // Setup mock reference for from_branch tests + mockSourceRef := &github.Reference{ + Ref: github.Ptr("refs/heads/main"), + Object: &github.GitObject{ + SHA: github.Ptr("abc123def456"), + }, + } + + // Setup mock created reference + mockCreatedRef := &github.Reference{ + Ref: github.Ptr("refs/heads/new-feature"), + Object: &github.GitObject{ + SHA: github.Ptr("abc123def456"), + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedRef *github.Reference + expectedErrMsg string + }{ + { + name: "successful branch creation with from_branch", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposGitRefByOwnerByRepoByRef, + mockSourceRef, + ), + mock.WithRequestMatch( + mock.PostReposGitRefsByOwnerByRepo, + mockCreatedRef, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "branch": "new-feature", + "from_branch": "main", + }, + expectError: false, + expectedRef: mockCreatedRef, + }, + { + name: "successful branch creation with default branch", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposByOwnerByRepo, + mockRepo, + ), + mock.WithRequestMatch( + mock.GetReposGitRefByOwnerByRepoByRef, + mockSourceRef, + ), + mock.WithRequestMatch( + mock.PostReposGitRefsByOwnerByRepo, + mockCreatedRef, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "branch": "new-feature", + }, + expectError: false, + expectedRef: mockCreatedRef, + }, + { + name: "fail to get repository", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposByOwnerByRepo, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Repository not found"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "nonexistent-repo", + "branch": "new-feature", + }, + expectError: true, + expectedErrMsg: "failed to get repository", + }, + { + name: "fail to get reference", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposGitRefByOwnerByRepoByRef, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Reference not found"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "branch": "new-feature", + "from_branch": "nonexistent-branch", + }, + expectError: true, + expectedErrMsg: "failed to get reference", + }, + { + name: "fail to create branch", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposGitRefByOwnerByRepoByRef, + mockSourceRef, + ), + mock.WithRequestMatchHandler( + mock.PostReposGitRefsByOwnerByRepo, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnprocessableEntity) + _, _ = w.Write([]byte(`{"message": "Reference already exists"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "branch": "existing-branch", + "from_branch": "main", + }, + expectError: true, + expectedErrMsg: "failed to create branch", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := createBranch(client) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedRef github.Reference + err = json.Unmarshal([]byte(textContent.Text), &returnedRef) + require.NoError(t, err) + assert.Equal(t, *tc.expectedRef.Ref, *returnedRef.Ref) + assert.Equal(t, *tc.expectedRef.Object.SHA, *returnedRef.Object.SHA) + }) + } +} + +func Test_ListCommits(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := listCommits(mockClient) + + assert.Equal(t, "list_commits", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "sha") + assert.Contains(t, tool.InputSchema.Properties, "page") + assert.Contains(t, tool.InputSchema.Properties, "per_page") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo"}) + + // Setup mock commits for success case + mockCommits := []*github.RepositoryCommit{ + { + SHA: github.Ptr("abc123def456"), + Commit: &github.Commit{ + Message: github.Ptr("First commit"), + Author: &github.CommitAuthor{ + Name: github.Ptr("Test User"), + Email: github.Ptr("test@example.com"), + Date: &github.Timestamp{Time: time.Now().Add(-48 * time.Hour)}, + }, + }, + Author: &github.User{ + Login: github.Ptr("testuser"), + }, + HTMLURL: github.Ptr("https://github.com/owner/repo/commit/abc123def456"), + }, + { + SHA: github.Ptr("def456abc789"), + Commit: &github.Commit{ + Message: github.Ptr("Second commit"), + Author: &github.CommitAuthor{ + Name: github.Ptr("Another User"), + Email: github.Ptr("another@example.com"), + Date: &github.Timestamp{Time: time.Now().Add(-24 * time.Hour)}, + }, + }, + Author: &github.User{ + Login: github.Ptr("anotheruser"), + }, + HTMLURL: github.Ptr("https://github.com/owner/repo/commit/def456abc789"), + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedCommits []*github.RepositoryCommit + expectedErrMsg string + }{ + { + name: "successful commits fetch with default params", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposCommitsByOwnerByRepo, + mockCommits, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + }, + expectError: false, + expectedCommits: mockCommits, + }, + { + name: "successful commits fetch with branch", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposCommitsByOwnerByRepo, + mockCommits, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "sha": "main", + }, + expectError: false, + expectedCommits: mockCommits, + }, + { + name: "successful commits fetch with pagination", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposCommitsByOwnerByRepo, + mockCommits, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "page": float64(2), + "per_page": float64(10), + }, + expectError: false, + expectedCommits: mockCommits, + }, + { + name: "commits fetch fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposCommitsByOwnerByRepo, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Not Found"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "nonexistent-repo", + }, + expectError: true, + expectedErrMsg: "failed to list commits", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := listCommits(client) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedCommits []*github.RepositoryCommit + err = json.Unmarshal([]byte(textContent.Text), &returnedCommits) + require.NoError(t, err) + assert.Len(t, returnedCommits, len(tc.expectedCommits)) + for i, commit := range returnedCommits { + assert.Equal(t, *tc.expectedCommits[i].SHA, *commit.SHA) + assert.Equal(t, *tc.expectedCommits[i].Commit.Message, *commit.Commit.Message) + assert.Equal(t, *tc.expectedCommits[i].Author.Login, *commit.Author.Login) + assert.Equal(t, *tc.expectedCommits[i].HTMLURL, *commit.HTMLURL) + } + }) + } +} + +func Test_CreateOrUpdateFile(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := createOrUpdateFile(mockClient) + + assert.Equal(t, "create_or_update_file", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "path") + assert.Contains(t, tool.InputSchema.Properties, "content") + assert.Contains(t, tool.InputSchema.Properties, "message") + assert.Contains(t, tool.InputSchema.Properties, "branch") + assert.Contains(t, tool.InputSchema.Properties, "sha") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "path", "content", "message", "branch"}) + + // Setup mock file content response + mockFileResponse := &github.RepositoryContentResponse{ + Content: &github.RepositoryContent{ + Name: github.Ptr("example.md"), + Path: github.Ptr("docs/example.md"), + SHA: github.Ptr("abc123def456"), + Size: github.Ptr(42), + HTMLURL: github.Ptr("https://github.com/owner/repo/blob/main/docs/example.md"), + DownloadURL: github.Ptr("https://raw.githubusercontent.com/owner/repo/main/docs/example.md"), + }, + Commit: github.Commit{ + SHA: github.Ptr("def456abc789"), + Message: github.Ptr("Add example file"), + Author: &github.CommitAuthor{ + Name: github.Ptr("Test User"), + Email: github.Ptr("test@example.com"), + Date: &github.Timestamp{Time: time.Now()}, + }, + HTMLURL: github.Ptr("https://github.com/owner/repo/commit/def456abc789"), + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedContent *github.RepositoryContentResponse + expectedErrMsg string + }{ + { + name: "successful file creation", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.PutReposContentsByOwnerByRepoByPath, + mockFileResponse, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "path": "docs/example.md", + "content": "# Example\n\nThis is an example file.", + "message": "Add example file", + "branch": "main", + }, + expectError: false, + expectedContent: mockFileResponse, + }, + { + name: "successful file update with SHA", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.PutReposContentsByOwnerByRepoByPath, + mockFileResponse, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "path": "docs/example.md", + "content": "# Updated Example\n\nThis file has been updated.", + "message": "Update example file", + "branch": "main", + "sha": "abc123def456", + }, + expectError: false, + expectedContent: mockFileResponse, + }, + { + name: "file creation fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PutReposContentsByOwnerByRepoByPath, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnprocessableEntity) + _, _ = w.Write([]byte(`{"message": "Invalid request"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "path": "docs/example.md", + "content": "#Invalid Content", + "message": "Invalid request", + "branch": "nonexistent-branch", + }, + expectError: true, + expectedErrMsg: "failed to create/update file", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := createOrUpdateFile(client) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedContent github.RepositoryContentResponse + err = json.Unmarshal([]byte(textContent.Text), &returnedContent) + require.NoError(t, err) + + // Verify content + assert.Equal(t, *tc.expectedContent.Content.Name, *returnedContent.Content.Name) + assert.Equal(t, *tc.expectedContent.Content.Path, *returnedContent.Content.Path) + assert.Equal(t, *tc.expectedContent.Content.SHA, *returnedContent.Content.SHA) + + // Verify commit + assert.Equal(t, *tc.expectedContent.Commit.SHA, *returnedContent.Commit.SHA) + assert.Equal(t, *tc.expectedContent.Commit.Message, *returnedContent.Commit.Message) + }) + } +} + +func Test_CreateRepository(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := createRepository(mockClient) + + assert.Equal(t, "create_repository", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "name") + assert.Contains(t, tool.InputSchema.Properties, "description") + assert.Contains(t, tool.InputSchema.Properties, "private") + assert.Contains(t, tool.InputSchema.Properties, "auto_init") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"name"}) + + // Setup mock repository response + mockRepo := &github.Repository{ + Name: github.Ptr("test-repo"), + Description: github.Ptr("Test repository"), + Private: github.Ptr(true), + HTMLURL: github.Ptr("https://github.com/testuser/test-repo"), + CloneURL: github.Ptr("https://github.com/testuser/test-repo.git"), + CreatedAt: &github.Timestamp{Time: time.Now()}, + Owner: &github.User{ + Login: github.Ptr("testuser"), + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedRepo *github.Repository + expectedErrMsg string + }{ + { + name: "successful repository creation with all parameters", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.EndpointPattern{ + Pattern: "/user/repos", + Method: "POST", + }, + mockResponse(t, http.StatusCreated, mockRepo), + ), + ), + requestArgs: map[string]interface{}{ + "name": "test-repo", + "description": "Test repository", + "private": true, + "auto_init": true, + }, + expectError: false, + expectedRepo: mockRepo, + }, + { + name: "successful repository creation with minimal parameters", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.EndpointPattern{ + Pattern: "/user/repos", + Method: "POST", + }, + mockResponse(t, http.StatusCreated, mockRepo), + ), + ), + requestArgs: map[string]interface{}{ + "name": "test-repo", + }, + expectError: false, + expectedRepo: mockRepo, + }, + { + name: "repository creation fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.EndpointPattern{ + Pattern: "/user/repos", + Method: "POST", + }, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnprocessableEntity) + _, _ = w.Write([]byte(`{"message": "Repository creation failed"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "name": "invalid-repo", + }, + expectError: true, + expectedErrMsg: "failed to create repository", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := createRepository(client) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedRepo github.Repository + err = json.Unmarshal([]byte(textContent.Text), &returnedRepo) + assert.NoError(t, err) + + // Verify repository details + assert.Equal(t, *tc.expectedRepo.Name, *returnedRepo.Name) + assert.Equal(t, *tc.expectedRepo.Description, *returnedRepo.Description) + assert.Equal(t, *tc.expectedRepo.Private, *returnedRepo.Private) + assert.Equal(t, *tc.expectedRepo.HTMLURL, *returnedRepo.HTMLURL) + assert.Equal(t, *tc.expectedRepo.Owner.Login, *returnedRepo.Owner.Login) + }) + } +} diff --git a/pkg/github/search_test.go b/pkg/github/search_test.go new file mode 100644 index 00000000..d43fd843 --- /dev/null +++ b/pkg/github/search_test.go @@ -0,0 +1,429 @@ +package github + +import ( + "context" + "encoding/json" + "net/http" + "testing" + + "github.com/google/go-github/v69/github" + "github.com/migueleliasweb/go-github-mock/src/mock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_SearchRepositories(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := searchRepositories(mockClient) + + assert.Equal(t, "search_repositories", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "query") + assert.Contains(t, tool.InputSchema.Properties, "page") + assert.Contains(t, tool.InputSchema.Properties, "per_page") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"query"}) + + // Setup mock search results + mockSearchResult := &github.RepositoriesSearchResult{ + Total: github.Ptr(2), + IncompleteResults: github.Ptr(false), + Repositories: []*github.Repository{ + { + ID: github.Ptr(int64(12345)), + Name: github.Ptr("repo-1"), + FullName: github.Ptr("owner/repo-1"), + HTMLURL: github.Ptr("https://github.com/owner/repo-1"), + Description: github.Ptr("Test repository 1"), + StargazersCount: github.Ptr(100), + }, + { + ID: github.Ptr(int64(67890)), + Name: github.Ptr("repo-2"), + FullName: github.Ptr("owner/repo-2"), + HTMLURL: github.Ptr("https://github.com/owner/repo-2"), + Description: github.Ptr("Test repository 2"), + StargazersCount: github.Ptr(50), + }, + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedResult *github.RepositoriesSearchResult + expectedErrMsg string + }{ + { + name: "successful repository search", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetSearchRepositories, + mockSearchResult, + ), + ), + requestArgs: map[string]interface{}{ + "query": "golang test", + "page": float64(1), + "per_page": float64(30), + }, + expectError: false, + expectedResult: mockSearchResult, + }, + { + name: "repository search with default pagination", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetSearchRepositories, + mockSearchResult, + ), + ), + requestArgs: map[string]interface{}{ + "query": "golang test", + }, + expectError: false, + expectedResult: mockSearchResult, + }, + { + name: "search fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetSearchRepositories, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"message": "Invalid query"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "query": "invalid:query", + }, + expectError: true, + expectedErrMsg: "failed to search repositories", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := searchRepositories(client) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedResult github.RepositoriesSearchResult + err = json.Unmarshal([]byte(textContent.Text), &returnedResult) + require.NoError(t, err) + assert.Equal(t, *tc.expectedResult.Total, *returnedResult.Total) + assert.Equal(t, *tc.expectedResult.IncompleteResults, *returnedResult.IncompleteResults) + assert.Len(t, returnedResult.Repositories, len(tc.expectedResult.Repositories)) + for i, repo := range returnedResult.Repositories { + assert.Equal(t, *tc.expectedResult.Repositories[i].ID, *repo.ID) + assert.Equal(t, *tc.expectedResult.Repositories[i].Name, *repo.Name) + assert.Equal(t, *tc.expectedResult.Repositories[i].FullName, *repo.FullName) + assert.Equal(t, *tc.expectedResult.Repositories[i].HTMLURL, *repo.HTMLURL) + } + + }) + } +} + +func Test_SearchCode(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := searchCode(mockClient) + + assert.Equal(t, "search_code", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "q") + assert.Contains(t, tool.InputSchema.Properties, "sort") + assert.Contains(t, tool.InputSchema.Properties, "order") + assert.Contains(t, tool.InputSchema.Properties, "per_page") + assert.Contains(t, tool.InputSchema.Properties, "page") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"q"}) + + // Setup mock search results + mockSearchResult := &github.CodeSearchResult{ + Total: github.Ptr(2), + IncompleteResults: github.Ptr(false), + CodeResults: []*github.CodeResult{ + { + Name: github.Ptr("file1.go"), + Path: github.Ptr("path/to/file1.go"), + SHA: github.Ptr("abc123def456"), + HTMLURL: github.Ptr("https://github.com/owner/repo/blob/main/path/to/file1.go"), + Repository: &github.Repository{Name: github.Ptr("repo"), FullName: github.Ptr("owner/repo")}, + }, + { + Name: github.Ptr("file2.go"), + Path: github.Ptr("path/to/file2.go"), + SHA: github.Ptr("def456abc123"), + HTMLURL: github.Ptr("https://github.com/owner/repo/blob/main/path/to/file2.go"), + Repository: &github.Repository{Name: github.Ptr("repo"), FullName: github.Ptr("owner/repo")}, + }, + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedResult *github.CodeSearchResult + expectedErrMsg string + }{ + { + name: "successful code search with all parameters", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetSearchCode, + mockSearchResult, + ), + ), + requestArgs: map[string]interface{}{ + "q": "fmt.Println language:go", + "sort": "indexed", + "order": "desc", + "page": float64(1), + "per_page": float64(30), + }, + expectError: false, + expectedResult: mockSearchResult, + }, + { + name: "code search with minimal parameters", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetSearchCode, + mockSearchResult, + ), + ), + requestArgs: map[string]interface{}{ + "q": "fmt.Println language:go", + }, + expectError: false, + expectedResult: mockSearchResult, + }, + { + name: "search code fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetSearchCode, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"message": "Validation Failed"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "q": "invalid:query", + }, + expectError: true, + expectedErrMsg: "failed to search code", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := searchCode(client) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedResult github.CodeSearchResult + err = json.Unmarshal([]byte(textContent.Text), &returnedResult) + require.NoError(t, err) + assert.Equal(t, *tc.expectedResult.Total, *returnedResult.Total) + assert.Equal(t, *tc.expectedResult.IncompleteResults, *returnedResult.IncompleteResults) + assert.Len(t, returnedResult.CodeResults, len(tc.expectedResult.CodeResults)) + for i, code := range returnedResult.CodeResults { + assert.Equal(t, *tc.expectedResult.CodeResults[i].Name, *code.Name) + assert.Equal(t, *tc.expectedResult.CodeResults[i].Path, *code.Path) + assert.Equal(t, *tc.expectedResult.CodeResults[i].SHA, *code.SHA) + assert.Equal(t, *tc.expectedResult.CodeResults[i].HTMLURL, *code.HTMLURL) + assert.Equal(t, *tc.expectedResult.CodeResults[i].Repository.FullName, *code.Repository.FullName) + } + }) + } +} + +func Test_SearchUsers(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := searchUsers(mockClient) + + assert.Equal(t, "search_users", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "q") + assert.Contains(t, tool.InputSchema.Properties, "sort") + assert.Contains(t, tool.InputSchema.Properties, "order") + assert.Contains(t, tool.InputSchema.Properties, "per_page") + assert.Contains(t, tool.InputSchema.Properties, "page") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"q"}) + + // Setup mock search results + mockSearchResult := &github.UsersSearchResult{ + Total: github.Ptr(2), + IncompleteResults: github.Ptr(false), + Users: []*github.User{ + { + Login: github.Ptr("user1"), + ID: github.Ptr(int64(1001)), + HTMLURL: github.Ptr("https://github.com/user1"), + AvatarURL: github.Ptr("https://avatars.githubusercontent.com/u/1001"), + Type: github.Ptr("User"), + Followers: github.Ptr(100), + Following: github.Ptr(50), + }, + { + Login: github.Ptr("user2"), + ID: github.Ptr(int64(1002)), + HTMLURL: github.Ptr("https://github.com/user2"), + AvatarURL: github.Ptr("https://avatars.githubusercontent.com/u/1002"), + Type: github.Ptr("User"), + Followers: github.Ptr(200), + Following: github.Ptr(75), + }, + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedResult *github.UsersSearchResult + expectedErrMsg string + }{ + { + name: "successful users search with all parameters", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetSearchUsers, + mockSearchResult, + ), + ), + requestArgs: map[string]interface{}{ + "q": "location:finland language:go", + "sort": "followers", + "order": "desc", + "page": float64(1), + "per_page": float64(30), + }, + expectError: false, + expectedResult: mockSearchResult, + }, + { + name: "users search with minimal parameters", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetSearchUsers, + mockSearchResult, + ), + ), + requestArgs: map[string]interface{}{ + "q": "location:finland language:go", + }, + expectError: false, + expectedResult: mockSearchResult, + }, + { + name: "search users fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetSearchUsers, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"message": "Validation Failed"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "q": "invalid:query", + }, + expectError: true, + expectedErrMsg: "failed to search users", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := searchUsers(client) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + require.NotNil(t, result) + + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedResult github.UsersSearchResult + err = json.Unmarshal([]byte(textContent.Text), &returnedResult) + require.NoError(t, err) + assert.Equal(t, *tc.expectedResult.Total, *returnedResult.Total) + assert.Equal(t, *tc.expectedResult.IncompleteResults, *returnedResult.IncompleteResults) + assert.Len(t, returnedResult.Users, len(tc.expectedResult.Users)) + for i, user := range returnedResult.Users { + assert.Equal(t, *tc.expectedResult.Users[i].Login, *user.Login) + assert.Equal(t, *tc.expectedResult.Users[i].ID, *user.ID) + assert.Equal(t, *tc.expectedResult.Users[i].HTMLURL, *user.HTMLURL) + assert.Equal(t, *tc.expectedResult.Users[i].AvatarURL, *user.AvatarURL) + assert.Equal(t, *tc.expectedResult.Users[i].Type, *user.Type) + assert.Equal(t, *tc.expectedResult.Users[i].Followers, *user.Followers) + } + }) + } +} diff --git a/pkg/github/server.go b/pkg/github/server.go index b3ef7016..0a90b4d1 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -3,8 +3,10 @@ package github import ( "context" "encoding/json" + "errors" "fmt" "io" + "net/http" "github.com/google/go-github/v69/github" "github.com/mark3labs/mcp-go/mcp" @@ -73,7 +75,7 @@ func getMe(client *github.Client) (tool mcp.Tool, handler server.ToolHandlerFunc } defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 200 { + if resp.StatusCode != http.StatusOK { body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) @@ -89,3 +91,9 @@ func getMe(client *github.Client) (tool mcp.Tool, handler server.ToolHandlerFunc return mcp.NewToolResultText(string(r)), nil } } + +// isAcceptedError checks if the error is an accepted error. +func isAcceptedError(err error) bool { + var acceptedError *github.AcceptedError + return errors.As(err, &acceptedError) +} diff --git a/pkg/github/server_test.go b/pkg/github/server_test.go new file mode 100644 index 00000000..d56993de --- /dev/null +++ b/pkg/github/server_test.go @@ -0,0 +1,168 @@ +package github + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "testing" + "time" + + "github.com/google/go-github/v69/github" + "github.com/migueleliasweb/go-github-mock/src/mock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_GetMe(t *testing.T) { + // Verify tool definition + mockClient := github.NewClient(nil) + tool, _ := getMe(mockClient) + + assert.Equal(t, "get_me", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "reason") + assert.Empty(t, tool.InputSchema.Required) // No required parameters + + // Setup mock user response + mockUser := &github.User{ + Login: github.Ptr("testuser"), + Name: github.Ptr("Test User"), + Email: github.Ptr("test@example.com"), + Bio: github.Ptr("GitHub user for testing"), + Company: github.Ptr("Test Company"), + Location: github.Ptr("Test Location"), + HTMLURL: github.Ptr("https://github.com/testuser"), + CreatedAt: &github.Timestamp{Time: time.Now().Add(-365 * 24 * time.Hour)}, + Type: github.Ptr("User"), + Plan: &github.Plan{ + Name: github.Ptr("pro"), + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedUser *github.User + expectedErrMsg string + }{ + { + name: "successful get user", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetUser, + mockUser, + ), + ), + requestArgs: map[string]interface{}{}, + expectError: false, + expectedUser: mockUser, + }, + { + name: "successful get user with reason", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetUser, + mockUser, + ), + ), + requestArgs: map[string]interface{}{ + "reason": "Testing API", + }, + expectError: false, + expectedUser: mockUser, + }, + { + name: "get user fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetUser, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"message": "Unauthorized"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{}, + expectError: true, + expectedErrMsg: "failed to get user", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := getMe(client) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse result and get text content if no error + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedUser github.User + err = json.Unmarshal([]byte(textContent.Text), &returnedUser) + require.NoError(t, err) + + // Verify user details + assert.Equal(t, *tc.expectedUser.Login, *returnedUser.Login) + assert.Equal(t, *tc.expectedUser.Name, *returnedUser.Name) + assert.Equal(t, *tc.expectedUser.Email, *returnedUser.Email) + assert.Equal(t, *tc.expectedUser.Bio, *returnedUser.Bio) + assert.Equal(t, *tc.expectedUser.HTMLURL, *returnedUser.HTMLURL) + assert.Equal(t, *tc.expectedUser.Type, *returnedUser.Type) + }) + } +} + +func Test_IsAcceptedError(t *testing.T) { + tests := []struct { + name string + err error + expectAccepted bool + }{ + { + name: "github AcceptedError", + err: &github.AcceptedError{}, + expectAccepted: true, + }, + { + name: "regular error", + err: fmt.Errorf("some other error"), + expectAccepted: false, + }, + { + name: "nil error", + err: nil, + expectAccepted: false, + }, + { + name: "wrapped AcceptedError", + err: fmt.Errorf("wrapped: %w", &github.AcceptedError{}), + expectAccepted: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := isAcceptedError(tc.err) + assert.Equal(t, tc.expectAccepted, result) + }) + } +}