diff --git a/e2e/e2e_test.go b/e2e/e2e_test.go index 71bd5a8a..6e7d1d51 100644 --- a/e2e/e2e_test.go +++ b/e2e/e2e_test.go @@ -182,7 +182,7 @@ func setupMCPClient(t *testing.T, options ...clientOption) *mcpClient.Client { require.NoError(t, err, "expected to construct MCP server successfully") t.Log("Starting In Process MCP client...") - client, err = mcpClient.NewInProcessClient(ghServer) + client, err = mcpClient.NewInProcessClient(ghServer.GetMCPServer()) require.NoError(t, err, "expected to create in-process client successfully") } diff --git a/go.mod b/go.mod index 684ce8f2..bf9d2055 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.23.7 require ( github.com/google/go-github/v69 v69.2.0 github.com/josephburnett/jd v1.9.2 - github.com/mark3labs/mcp-go v0.30.0 + github.com/mark3labs/mcp-go v0.30.1 github.com/migueleliasweb/go-github-mock v1.3.0 github.com/sirupsen/logrus v1.9.3 github.com/spf13/cobra v1.9.1 diff --git a/go.sum b/go.sum index c2da59f6..e0820051 100644 --- a/go.sum +++ b/go.sum @@ -49,6 +49,8 @@ github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0 github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/mark3labs/mcp-go v0.30.0 h1:Taz7fiefkxY/l8jz1nA90V+WdM2eoMtlvwfWforVYbo= github.com/mark3labs/mcp-go v0.30.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= +github.com/mark3labs/mcp-go v0.30.1 h1:3R1BPvNT/rC1iPpLx+EMXFy+gvux/Mz/Nio3c6XEU9E= +github.com/mark3labs/mcp-go v0.30.1/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= github.com/migueleliasweb/go-github-mock v1.3.0 h1:2sVP9JEMB2ubQw1IKto3/fzF51oFC6eVWOOFDgQoq88= github.com/migueleliasweb/go-github-mock v1.3.0/go.mod h1:ipQhV8fTcj/G6m7BKzin08GaJ/3B5/SonRAkgrk0zCY= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index a75a9e0c..bd1c9e88 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -47,7 +47,7 @@ type MCPServerConfig struct { Translator translations.TranslationHelperFunc } -func NewMCPServer(cfg MCPServerConfig) (*server.MCPServer, error) { +func NewMCPServer(cfg MCPServerConfig) (*github.GitHubMCPServer, error) { apiHost, err := parseAPIHost(cfg.Host) if err != nil { return nil, fmt.Errorf("failed to parse API host: %w", err) @@ -91,8 +91,6 @@ func NewMCPServer(cfg MCPServerConfig) (*server.MCPServer, error) { OnBeforeInitialize: []server.OnBeforeInitializeFunc{beforeInit}, } - ghServer := github.NewServer(cfg.Version, server.WithHooks(hooks)) - enabledToolsets := cfg.EnabledToolsets if cfg.DynamicToolsets { // filter "all" from the enabled toolsets @@ -112,6 +110,8 @@ func NewMCPServer(cfg MCPServerConfig) (*server.MCPServer, error) { return gqlClient, nil // closing over client } + ghServer := github.NewGitHubServer(cfg.Version, getClient, server.WithHooks(hooks)) + // Create default toolsets toolsets, err := github.InitToolsets( enabledToolsets, @@ -125,15 +125,15 @@ func NewMCPServer(cfg MCPServerConfig) (*server.MCPServer, error) { } context := github.InitContextToolset(getClient, cfg.Translator) - github.RegisterResources(ghServer, getClient, cfg.Translator) + github.RegisterResources(ghServer.MCPServer, getClient, cfg.Translator) // Register the tools with the server - toolsets.RegisterTools(ghServer) - context.RegisterTools(ghServer) + toolsets.RegisterTools(ghServer.MCPServer) + context.RegisterTools(ghServer.MCPServer) if cfg.DynamicToolsets { - dynamic := github.InitDynamicToolset(ghServer, toolsets, cfg.Translator) - dynamic.RegisterTools(ghServer) + dynamic := github.InitDynamicToolset(ghServer.MCPServer, toolsets, cfg.Translator) + dynamic.RegisterTools(ghServer.MCPServer) } return ghServer, nil @@ -192,7 +192,7 @@ func RunStdioServer(cfg StdioServerConfig) error { return fmt.Errorf("failed to create MCP server: %w", err) } - stdioServer := server.NewStdioServer(ghServer) + stdioServer := github.NewCompletionAwareStdioServer(ghServer.GetMCPServer(), ghServer.GetCompletionHandler()) logrusLogger := logrus.New() if cfg.LogFilePath != "" { diff --git a/pkg/github/completion_integration_test.go b/pkg/github/completion_integration_test.go new file mode 100644 index 00000000..558a61f8 --- /dev/null +++ b/pkg/github/completion_integration_test.go @@ -0,0 +1,142 @@ +package github + +import ( + "context" + "testing" + + "github.com/google/go-github/v69/github" + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGitHubMCPServerCompletionIntegration(t *testing.T) { + // Mock client function + getClient := func(_ context.Context) (*github.Client, error) { + // Return a nil client - this will cause API calls to fail gracefully + // which is fine for testing the completion request handling flow + return nil, nil + } + + // Create a GitHub MCP server with completion support + ghServer := NewGitHubServer("test", getClient) + require.NotNil(t, ghServer) + + // Create an in-process client with our custom GitHubMCPServer transport + mcpClient, err := NewInProcessClientWithGitHubServer(ghServer) + require.NoError(t, err) + + // Initialize the client + ctx := context.Background() + request := mcp.InitializeRequest{} + request.Params.ProtocolVersion = "2025-03-26" + request.Params.ClientInfo = mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + } + + result, err := mcpClient.Initialize(ctx, request) + require.NoError(t, err) + assert.Equal(t, "github-mcp-server", result.ServerInfo.Name) + + // Test completion request - this should work even with a nil GitHub client + // because non-repo URIs return empty completions without calling GitHub APIs + completionRequest := mcp.CompleteRequest{ + Params: struct { + Ref any `json:"ref"` + Argument struct { + Name string `json:"name"` + Value string `json:"value"` + } `json:"argument"` + }{ + Ref: map[string]interface{}{ + "type": "ref/resource", + "uri": "file:///some/non-repo/path", + }, + Argument: struct { + Name string `json:"name"` + Value string `json:"value"` + }{ + Name: "param", + Value: "test", + }, + }, + } + + completionResult, err := mcpClient.Complete(ctx, completionRequest) + require.NoError(t, err) + require.NotNil(t, completionResult) + + // Should return empty completion for non-repo URIs + assert.Equal(t, []string{}, completionResult.Completion.Values) + assert.Equal(t, 0, completionResult.Completion.Total) + + // Test repo URI completion with unsupported argument + repoCompletionRequest := mcp.CompleteRequest{ + Params: struct { + Ref any `json:"ref"` + Argument struct { + Name string `json:"name"` + Value string `json:"value"` + } `json:"argument"` + }{ + Ref: map[string]interface{}{ + "type": "ref/resource", + "uri": "repo://{owner}/{repo}/contents{/path*}", + }, + Argument: struct { + Name string `json:"name"` + Value string `json:"value"` + }{ + Name: "unsupported", + Value: "test", + }, + }, + } + + repoCompletionResult, err := mcpClient.Complete(ctx, repoCompletionRequest) + require.NoError(t, err) + require.NotNil(t, repoCompletionResult) + + // Should return empty completion for unsupported arguments + assert.Equal(t, []string{}, repoCompletionResult.Completion.Values) + assert.Equal(t, 0, repoCompletionResult.Completion.Total) + + // Clean up + err = mcpClient.Close() + assert.NoError(t, err) +} + +func TestGitHubMCPServerCompletionCapabilities(t *testing.T) { + // Mock client function + getClient := func(_ context.Context) (*github.Client, error) { + return nil, nil + } + + // Create a GitHub MCP server with completion support + ghServer := NewGitHubServer("test", getClient) + require.NotNil(t, ghServer) + + // Create an in-process client with our custom GitHubMCPServer transport + mcpClient, err := NewInProcessClientWithGitHubServer(ghServer) + require.NoError(t, err) + + // Initialize the client + ctx := context.Background() + request := mcp.InitializeRequest{} + request.Params.ProtocolVersion = "2025-03-26" + request.Params.ClientInfo = mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + } + + result, err := mcpClient.Initialize(ctx, request) + require.NoError(t, err) + + // Check basic server info + assert.Equal(t, "github-mcp-server", result.ServerInfo.Name) + + // Clean up + err = mcpClient.Close() + assert.NoError(t, err) +} \ No newline at end of file diff --git a/pkg/github/completion_stdio_server.go b/pkg/github/completion_stdio_server.go new file mode 100644 index 00000000..27d13429 --- /dev/null +++ b/pkg/github/completion_stdio_server.go @@ -0,0 +1,141 @@ +package github + +import ( + "context" + "encoding/json" + "io" + "log" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// CompletionAwareStdioServer wraps the MCP stdio server to add completion support +type CompletionAwareStdioServer struct { + baseServer *server.MCPServer + completionHandler CompletionHandlerFunc + errLogger *log.Logger +} + +// NewCompletionAwareStdioServer creates a new stdio server with completion support +func NewCompletionAwareStdioServer(mcpServer *server.MCPServer, completionHandler CompletionHandlerFunc) *CompletionAwareStdioServer { + return &CompletionAwareStdioServer{ + baseServer: mcpServer, + completionHandler: completionHandler, + errLogger: log.New(io.Discard, "", 0), // Default to discarding errors + } +} + +// SetErrorLogger sets the error logger for the server +func (s *CompletionAwareStdioServer) SetErrorLogger(logger *log.Logger) { + s.errLogger = logger +} + +// Listen starts the completion-aware stdio server +func (s *CompletionAwareStdioServer) Listen(ctx context.Context, stdin io.Reader, stdout io.Writer) error { + // Use the simplified approach: create a custom stdio server that mimics the real one + // but intercepts completion requests + + // We'll use the real stdio server from the mcp-go library and intercept the raw messages + realStdioServer := server.NewStdioServer(s.baseServer) + realStdioServer.SetErrorLogger(s.errLogger) + + // Create pipes to intercept messages + stdinPipe := &completionInterceptReader{ + original: stdin, + completionHandler: s.completionHandler, + baseServer: s.baseServer, + stdout: stdout, + ctx: ctx, + errLogger: s.errLogger, + } + + return realStdioServer.Listen(ctx, stdinPipe, stdout) +} + +// completionInterceptReader intercepts stdin to handle completion requests +type completionInterceptReader struct { + original io.Reader + completionHandler CompletionHandlerFunc + baseServer *server.MCPServer + stdout io.Writer + ctx context.Context + errLogger *log.Logger + buffer []byte + bufferPos int +} + +func (r *completionInterceptReader) Read(p []byte) (n int, err error) { + // If we have buffered data, return that first + if r.bufferPos < len(r.buffer) { + n = copy(p, r.buffer[r.bufferPos:]) + r.bufferPos += n + if r.bufferPos >= len(r.buffer) { + r.buffer = nil + r.bufferPos = 0 + } + return n, nil + } + + // Read from original source + n, err = r.original.Read(p) + if err != nil { + return n, err + } + + // Check if this contains a completion request + data := p[:n] + if r.isCompletionRequest(data) { + // Handle completion request directly + response := r.handleCompletionRequest(data) + if response != nil { + // Write response to stdout + encoder := json.NewEncoder(r.stdout) + if encErr := encoder.Encode(response); encErr != nil { + r.errLogger.Printf("Error writing completion response: %v", encErr) + } + } + // Return EOF to the real server so it doesn't process this message + return 0, io.EOF + } + + return n, err +} + +// isCompletionRequest checks if the data contains a completion request +func (r *completionInterceptReader) isCompletionRequest(data []byte) bool { + var baseMessage struct { + Method string `json:"method"` + } + + if err := json.Unmarshal(data, &baseMessage); err != nil { + return false + } + + return baseMessage.Method == "completion/complete" +} + +// handleCompletionRequest processes completion requests +func (r *completionInterceptReader) handleCompletionRequest(data []byte) mcp.JSONRPCMessage { + var baseMessage struct { + JSONRPC string `json:"jsonrpc"` + ID any `json:"id"` + Method string `json:"method"` + } + + if err := json.Unmarshal(data, &baseMessage); err != nil { + return createErrorResponse(baseMessage.ID, mcp.PARSE_ERROR, "Failed to parse completion request") + } + + var request mcp.CompleteRequest + if err := json.Unmarshal(data, &request); err != nil { + return createErrorResponse(baseMessage.ID, mcp.INVALID_REQUEST, "Failed to parse completion request") + } + + result, err := r.completionHandler(r.ctx, request) + if err != nil { + return createErrorResponse(baseMessage.ID, mcp.INTERNAL_ERROR, err.Error()) + } + + return createResponse(baseMessage.ID, *result) +} \ No newline at end of file diff --git a/pkg/github/github_inprocess_client.go b/pkg/github/github_inprocess_client.go new file mode 100644 index 00000000..220bace5 --- /dev/null +++ b/pkg/github/github_inprocess_client.go @@ -0,0 +1,74 @@ +package github + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" +) + +// GitHubInProcessTransport creates an in-process transport that uses our GitHubMCPServer +// This ensures that completion requests go through our HandleMessage override +type GitHubInProcessTransport struct { + server *GitHubMCPServer + notificationHandler func(mcp.JSONRPCNotification) +} + +// NewGitHubInProcessTransport creates a new in-process transport for GitHubMCPServer +func NewGitHubInProcessTransport(server *GitHubMCPServer) *GitHubInProcessTransport { + return &GitHubInProcessTransport{ + server: server, + } +} + +func (c *GitHubInProcessTransport) Start(ctx context.Context) error { + return nil +} + +func (c *GitHubInProcessTransport) SendRequest(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) { + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + requestBytes = append(requestBytes, '\n') + + // This is the key part: call HandleMessage on our GitHubMCPServer + // which will route completion requests to our handler + respMessage := c.server.HandleMessage(ctx, requestBytes) + respByte, err := json.Marshal(respMessage) + if err != nil { + return nil, fmt.Errorf("failed to marshal response message: %w", err) + } + rpcResp := transport.JSONRPCResponse{} + err = json.Unmarshal(respByte, &rpcResp) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal response message: %w", err) + } + + return &rpcResp, nil +} + +func (c *GitHubInProcessTransport) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error { + // For in-process transport, we can just forward notifications to the handler + if c.notificationHandler != nil { + c.notificationHandler(notification) + } + return nil +} + +func (c *GitHubInProcessTransport) SetNotificationHandler(handler func(mcp.JSONRPCNotification)) { + c.notificationHandler = handler +} + +func (c *GitHubInProcessTransport) Close() error { + return nil +} + +// NewInProcessClientWithGitHubServer creates a client that works with GitHubMCPServer +func NewInProcessClientWithGitHubServer(server *GitHubMCPServer) (*client.Client, error) { + ghTransport := NewGitHubInProcessTransport(server) + return client.NewClient(ghTransport), nil +} \ No newline at end of file diff --git a/pkg/github/repository_completions.go b/pkg/github/repository_completions.go new file mode 100644 index 00000000..6d2903e9 --- /dev/null +++ b/pkg/github/repository_completions.go @@ -0,0 +1,621 @@ +package github + +import ( + "context" + "fmt" + "strings" + + "github.com/google/go-github/v69/github" + "github.com/mark3labs/mcp-go/mcp" +) + +// CompletionHandlerFunc is a function that handles completion requests. +type CompletionHandlerFunc func(ctx context.Context, request mcp.CompleteRequest) (*mcp.CompleteResult, error) + +// RepositoryCompletionHandler handles completion requests for repository resources. +func RepositoryCompletionHandler(getClient GetClientFn) CompletionHandlerFunc { + return func(ctx context.Context, request mcp.CompleteRequest) (*mcp.CompleteResult, error) { + // Extract the resource reference + ref, ok := request.Params.Ref.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid ref type") + } + + refType, ok := ref["type"].(string) + if !ok || refType != "ref/resource" { + return nil, fmt.Errorf("unsupported ref type: %s", refType) + } + + uri, ok := ref["uri"].(string) + if !ok { + return nil, fmt.Errorf("missing uri in resource reference") + } + + // Only handle repo:// URIs + if !strings.HasPrefix(uri, "repo://") { + return &mcp.CompleteResult{ + Completion: struct { + Values []string `json:"values"` + Total int `json:"total,omitempty"` + HasMore bool `json:"hasMore,omitempty"` + }{ + Values: []string{}, + Total: 0, + }, + }, nil + } + + argumentName := request.Params.Argument.Name + argumentValue := request.Params.Argument.Value + + switch argumentName { + case "owner": + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + return completeOwner(ctx, client, argumentValue) + case "repo": + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + return completeRepo(ctx, client, argumentValue, uri) + case "branch": + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + return completeBranch(ctx, client, argumentValue, uri) + case "sha": + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + return completeCommit(ctx, client, argumentValue, uri) + case "tag": + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + return completeTag(ctx, client, argumentValue, uri) + case "prNumber": + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + return completePullRequest(ctx, client, argumentValue, uri) + case "path": + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + return completePath(ctx, client, argumentValue, uri) + default: + // Return empty completion for unsupported arguments + return &mcp.CompleteResult{ + Completion: struct { + Values []string `json:"values"` + Total int `json:"total,omitempty"` + HasMore bool `json:"hasMore,omitempty"` + }{ + Values: []string{}, + Total: 0, + }, + }, nil + } + } +} + +// completeOwner provides completions for repository owners (users/orgs) +func completeOwner(ctx context.Context, client *github.Client, value string) (*mcp.CompleteResult, error) { + if value == "" { + // Return empty completion for now - could add popular orgs/users here + return &mcp.CompleteResult{ + Completion: struct { + Values []string `json:"values"` + Total int `json:"total,omitempty"` + HasMore bool `json:"hasMore,omitempty"` + }{ + Values: []string{}, + Total: 0, + }, + }, nil + } + + // Search for users/organizations + query := fmt.Sprintf("%s in:login", value) + opts := &github.SearchOptions{ + ListOptions: github.ListOptions{ + Page: 1, + PerPage: 10, + }, + } + + result, _, err := client.Search.Users(ctx, query, opts) + if err != nil { + return nil, fmt.Errorf("failed to search users: %w", err) + } + + var values []string + for _, user := range result.Users { + if user.Login != nil && strings.HasPrefix(strings.ToLower(*user.Login), strings.ToLower(value)) { + values = append(values, *user.Login) + } + } + + total := 0 + hasMore := false + if result.Total != nil { + total = *result.Total + hasMore = total > len(values) + } + + return &mcp.CompleteResult{ + Completion: struct { + Values []string `json:"values"` + Total int `json:"total,omitempty"` + HasMore bool `json:"hasMore,omitempty"` + }{ + Values: values, + Total: total, + HasMore: hasMore, + }, + }, nil +} + +// completeRepo provides completions for repository names +func completeRepo(ctx context.Context, client *github.Client, value string, uri string) (*mcp.CompleteResult, error) { + // Extract owner from URI + owner := extractOwnerFromURI(uri) + if owner == "" { + return &mcp.CompleteResult{ + Completion: struct { + Values []string `json:"values"` + Total int `json:"total,omitempty"` + HasMore bool `json:"hasMore,omitempty"` + }{ + Values: []string{}, + Total: 0, + }, + }, nil + } + + // Search for repositories + query := fmt.Sprintf("user:%s %s in:name", owner, value) + if value == "" { + query = fmt.Sprintf("user:%s", owner) + } + + opts := &github.SearchOptions{ + ListOptions: github.ListOptions{ + Page: 1, + PerPage: 10, + }, + } + + result, _, err := client.Search.Repositories(ctx, query, opts) + if err != nil { + return nil, fmt.Errorf("failed to search repositories: %w", err) + } + + var values []string + for _, repo := range result.Repositories { + if repo.Name != nil && strings.HasPrefix(strings.ToLower(*repo.Name), strings.ToLower(value)) { + values = append(values, *repo.Name) + } + } + + total := 0 + hasMore := false + if result.Total != nil { + total = *result.Total + hasMore = total > len(values) + } + + return &mcp.CompleteResult{ + Completion: struct { + Values []string `json:"values"` + Total int `json:"total,omitempty"` + HasMore bool `json:"hasMore,omitempty"` + }{ + Values: values, + Total: total, + HasMore: hasMore, + }, + }, nil +} + +// completeBranch provides completions for branch names +func completeBranch(ctx context.Context, client *github.Client, value string, uri string) (*mcp.CompleteResult, error) { + owner, repo := extractOwnerRepoFromURI(uri) + if owner == "" || repo == "" { + return &mcp.CompleteResult{ + Completion: struct { + Values []string `json:"values"` + Total int `json:"total,omitempty"` + HasMore bool `json:"hasMore,omitempty"` + }{ + Values: []string{}, + Total: 0, + }, + }, nil + } + + // List branches + opts := &github.BranchListOptions{ + ListOptions: github.ListOptions{ + Page: 1, + PerPage: 30, + }, + } + + branches, _, err := client.Repositories.ListBranches(ctx, owner, repo, opts) + if err != nil { + return nil, fmt.Errorf("failed to list branches: %w", err) + } + + var values []string + for _, branch := range branches { + if branch.Name != nil && strings.HasPrefix(strings.ToLower(*branch.Name), strings.ToLower(value)) { + values = append(values, *branch.Name) + } + } + + return &mcp.CompleteResult{ + Completion: struct { + Values []string `json:"values"` + Total int `json:"total,omitempty"` + HasMore bool `json:"hasMore,omitempty"` + }{ + Values: values, + Total: len(values), + HasMore: len(branches) >= 30, // Might have more + }, + }, nil +} + +// completeCommit provides completions for commit SHAs +func completeCommit(ctx context.Context, client *github.Client, value string, uri string) (*mcp.CompleteResult, error) { + owner, repo := extractOwnerRepoFromURI(uri) + if owner == "" || repo == "" { + return &mcp.CompleteResult{ + Completion: struct { + Values []string `json:"values"` + Total int `json:"total,omitempty"` + HasMore bool `json:"hasMore,omitempty"` + }{ + Values: []string{}, + Total: 0, + }, + }, nil + } + + // If user has typed some characters, search for commits + if len(value) >= 3 { + // List recent commits + opts := &github.CommitsListOptions{ + ListOptions: github.ListOptions{ + Page: 1, + PerPage: 10, + }, + } + + commits, _, err := client.Repositories.ListCommits(ctx, owner, repo, opts) + if err != nil { + return nil, fmt.Errorf("failed to list commits: %w", err) + } + + var values []string + for _, commit := range commits { + if commit.SHA != nil && strings.HasPrefix(strings.ToLower(*commit.SHA), strings.ToLower(value)) { + values = append(values, *commit.SHA) + } + } + + return &mcp.CompleteResult{ + Completion: struct { + Values []string `json:"values"` + Total int `json:"total,omitempty"` + HasMore bool `json:"hasMore,omitempty"` + }{ + Values: values, + Total: len(values), + HasMore: len(commits) >= 10, + }, + }, nil + } + + // For short prefixes, return empty completion + return &mcp.CompleteResult{ + Completion: struct { + Values []string `json:"values"` + Total int `json:"total,omitempty"` + HasMore bool `json:"hasMore,omitempty"` + }{ + Values: []string{}, + Total: 0, + }, + }, nil +} + +// completeTag provides completions for tag names +func completeTag(ctx context.Context, client *github.Client, value string, uri string) (*mcp.CompleteResult, error) { + owner, repo := extractOwnerRepoFromURI(uri) + if owner == "" || repo == "" { + return &mcp.CompleteResult{ + Completion: struct { + Values []string `json:"values"` + Total int `json:"total,omitempty"` + HasMore bool `json:"hasMore,omitempty"` + }{ + Values: []string{}, + Total: 0, + }, + }, nil + } + + // List tags + opts := &github.ListOptions{ + Page: 1, + PerPage: 30, + } + + tags, _, err := client.Repositories.ListTags(ctx, owner, repo, opts) + if err != nil { + return nil, fmt.Errorf("failed to list tags: %w", err) + } + + var values []string + for _, tag := range tags { + if tag.Name != nil && strings.HasPrefix(strings.ToLower(*tag.Name), strings.ToLower(value)) { + values = append(values, *tag.Name) + } + } + + return &mcp.CompleteResult{ + Completion: struct { + Values []string `json:"values"` + Total int `json:"total,omitempty"` + HasMore bool `json:"hasMore,omitempty"` + }{ + Values: values, + Total: len(values), + HasMore: len(tags) >= 30, + }, + }, nil +} + +// completePullRequest provides completions for pull request numbers +func completePullRequest(ctx context.Context, client *github.Client, value string, uri string) (*mcp.CompleteResult, error) { + owner, repo := extractOwnerRepoFromURI(uri) + if owner == "" || repo == "" { + return &mcp.CompleteResult{ + Completion: struct { + Values []string `json:"values"` + Total int `json:"total,omitempty"` + HasMore bool `json:"hasMore,omitempty"` + }{ + Values: []string{}, + Total: 0, + }, + }, nil + } + + // List pull requests + opts := &github.PullRequestListOptions{ + State: "all", + ListOptions: github.ListOptions{ + Page: 1, + PerPage: 20, + }, + } + + prs, _, err := client.PullRequests.List(ctx, owner, repo, opts) + if err != nil { + return nil, fmt.Errorf("failed to list pull requests: %w", err) + } + + var values []string + for _, pr := range prs { + if pr.Number != nil { + prNumber := fmt.Sprintf("%d", *pr.Number) + if strings.HasPrefix(prNumber, value) { + values = append(values, prNumber) + } + } + } + + return &mcp.CompleteResult{ + Completion: struct { + Values []string `json:"values"` + Total int `json:"total,omitempty"` + HasMore bool `json:"hasMore,omitempty"` + }{ + Values: values, + Total: len(values), + HasMore: len(prs) >= 20, + }, + }, nil +} + +// completePath provides completions for file/directory paths +func completePath(ctx context.Context, client *github.Client, value string, uri string) (*mcp.CompleteResult, error) { + owner, repo := extractOwnerRepoFromURI(uri) + if owner == "" || repo == "" { + return &mcp.CompleteResult{ + Completion: struct { + Values []string `json:"values"` + Total int `json:"total,omitempty"` + HasMore bool `json:"hasMore,omitempty"` + }{ + Values: []string{}, + Total: 0, + }, + }, nil + } + + // Determine the directory to list based on the current path value + var dirPath string + var prefix string + + if value == "" { + dirPath = "" + prefix = "" + } else if strings.HasSuffix(value, "/") { + dirPath = strings.TrimSuffix(value, "/") + prefix = "" + } else { + lastSlash := strings.LastIndex(value, "/") + if lastSlash == -1 { + dirPath = "" + prefix = value + } else { + dirPath = value[:lastSlash] + prefix = value[lastSlash+1:] + } + } + + // Get repository contents for the directory + opts := &github.RepositoryContentGetOptions{} + + // Extract ref if present in URI (branch, commit, etc.) + if ref := extractRefFromURI(uri); ref != "" { + opts.Ref = ref + } + + _, directoryContent, _, err := client.Repositories.GetContents(ctx, owner, repo, dirPath, opts) + if err != nil { + // If directory doesn't exist, return empty completion + return &mcp.CompleteResult{ + Completion: struct { + Values []string `json:"values"` + Total int `json:"total,omitempty"` + HasMore bool `json:"hasMore,omitempty"` + }{ + Values: []string{}, + Total: 0, + }, + }, nil + } + + var values []string + for _, entry := range directoryContent { + if entry.Name != nil && strings.HasPrefix(strings.ToLower(*entry.Name), strings.ToLower(prefix)) { + entryPath := *entry.Name + if dirPath != "" { + entryPath = dirPath + "/" + entryPath + } + + // Add trailing slash for directories + if entry.Type != nil && *entry.Type == "dir" { + entryPath += "/" + } + + values = append(values, entryPath) + } + } + + return &mcp.CompleteResult{ + Completion: struct { + Values []string `json:"values"` + Total int `json:"total,omitempty"` + HasMore bool `json:"hasMore,omitempty"` + }{ + Values: values, + Total: len(values), + HasMore: false, + }, + }, nil +} + +// Helper functions to extract information from URI + +// extractOwnerFromURI extracts the owner from a repo:// URI +func extractOwnerFromURI(uri string) string { + // Parse URI like repo://{owner}/{repo}/... + if !strings.HasPrefix(uri, "repo://") { + return "" + } + + path := strings.TrimPrefix(uri, "repo://") + parts := strings.Split(path, "/") + if len(parts) > 0 && strings.Contains(parts[0], "{owner}") { + return "" // Template not filled + } + if len(parts) > 0 { + return parts[0] + } + + return "" +} + +// extractOwnerRepoFromURI extracts owner and repo from a repo:// URI +func extractOwnerRepoFromURI(uri string) (string, string) { + if !strings.HasPrefix(uri, "repo://") { + return "", "" + } + + path := strings.TrimPrefix(uri, "repo://") + parts := strings.Split(path, "/") + + if len(parts) >= 2 { + owner := parts[0] + repo := parts[1] + + // Skip if still templates + if strings.Contains(owner, "{") || strings.Contains(repo, "{") { + return "", "" + } + + return owner, repo + } + + return "", "" +} + +// extractRefFromURI extracts the ref (branch, commit, tag) from a repo:// URI +func extractRefFromURI(uri string) string { + if !strings.HasPrefix(uri, "repo://") { + return "" + } + + path := strings.TrimPrefix(uri, "repo://") + + // Look for patterns like /refs/heads/{branch}, /sha/{sha}, /refs/tags/{tag}, /refs/pull/{prNumber}/head + if strings.Contains(path, "/refs/heads/") { + parts := strings.Split(path, "/refs/heads/") + if len(parts) > 1 { + branchPart := strings.Split(parts[1], "/")[0] + if !strings.Contains(branchPart, "{") { + return "refs/heads/" + branchPart + } + } + } else if strings.Contains(path, "/sha/") { + parts := strings.Split(path, "/sha/") + if len(parts) > 1 { + shaPart := strings.Split(parts[1], "/")[0] + if !strings.Contains(shaPart, "{") { + return shaPart + } + } + } else if strings.Contains(path, "/refs/tags/") { + parts := strings.Split(path, "/refs/tags/") + if len(parts) > 1 { + tagPart := strings.Split(parts[1], "/")[0] + if !strings.Contains(tagPart, "{") { + return "refs/tags/" + tagPart + } + } + } else if strings.Contains(path, "/refs/pull/") && strings.Contains(path, "/head") { + parts := strings.Split(path, "/refs/pull/") + if len(parts) > 1 { + prPart := strings.Split(parts[1], "/head")[0] + if !strings.Contains(prPart, "{") { + return "refs/pull/" + prPart + "/head" + } + } + } + + return "" +} \ No newline at end of file diff --git a/pkg/github/repository_completions_test.go b/pkg/github/repository_completions_test.go new file mode 100644 index 00000000..fa2292e2 --- /dev/null +++ b/pkg/github/repository_completions_test.go @@ -0,0 +1,344 @@ +package github + +import ( + "context" + "testing" + + "github.com/google/go-github/v69/github" + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRepositoryCompletionHandler(t *testing.T) { + // Mock client function that returns an error to test error handling + errorGetClient := func(_ context.Context) (*github.Client, error) { + return nil, assert.AnError + } + + tests := []struct { + name string + request mcp.CompleteRequest + getClient GetClientFn + expectedError bool + }{ + { + name: "invalid ref type - should return error", + request: mcp.CompleteRequest{ + Params: struct { + Ref any `json:"ref"` + Argument struct { + Name string `json:"name"` + Value string `json:"value"` + } `json:"argument"` + }{ + Ref: "invalid", + Argument: struct { + Name string `json:"name"` + Value string `json:"value"` + }{ + Name: "owner", + Value: "test", + }, + }, + }, + getClient: errorGetClient, + expectedError: true, + }, + { + name: "unsupported ref type - should return error", + request: mcp.CompleteRequest{ + Params: struct { + Ref any `json:"ref"` + Argument struct { + Name string `json:"name"` + Value string `json:"value"` + } `json:"argument"` + }{ + Ref: map[string]interface{}{ + "type": "ref/prompt", + "name": "some_prompt", + }, + Argument: struct { + Name string `json:"name"` + Value string `json:"value"` + }{ + Name: "param", + Value: "test", + }, + }, + }, + getClient: errorGetClient, + expectedError: true, + }, + { + name: "missing uri in resource reference - should return error", + request: mcp.CompleteRequest{ + Params: struct { + Ref any `json:"ref"` + Argument struct { + Name string `json:"name"` + Value string `json:"value"` + } `json:"argument"` + }{ + Ref: map[string]interface{}{ + "type": "ref/resource", + }, + Argument: struct { + Name string `json:"name"` + Value string `json:"value"` + }{ + Name: "owner", + Value: "test", + }, + }, + }, + getClient: errorGetClient, + expectedError: true, + }, + { + name: "non-repo URI - should return empty completion", + request: mcp.CompleteRequest{ + Params: struct { + Ref any `json:"ref"` + Argument struct { + Name string `json:"name"` + Value string `json:"value"` + } `json:"argument"` + }{ + Ref: map[string]interface{}{ + "type": "ref/resource", + "uri": "file:///some/path", + }, + Argument: struct { + Name string `json:"name"` + Value string `json:"value"` + }{ + Name: "param", + Value: "test", + }, + }, + }, + getClient: errorGetClient, + expectedError: false, + }, + { + name: "unsupported argument - should return empty completion", + request: mcp.CompleteRequest{ + Params: struct { + Ref any `json:"ref"` + Argument struct { + Name string `json:"name"` + Value string `json:"value"` + } `json:"argument"` + }{ + Ref: map[string]interface{}{ + "type": "ref/resource", + "uri": "repo://{owner}/{repo}/contents{/path*}", + }, + Argument: struct { + Name string `json:"name"` + Value string `json:"value"` + }{ + Name: "unsupported", + Value: "test", + }, + }, + }, + getClient: errorGetClient, + expectedError: false, + }, + { + name: "client error - should return error", + request: mcp.CompleteRequest{ + Params: struct { + Ref any `json:"ref"` + Argument struct { + Name string `json:"name"` + Value string `json:"value"` + } `json:"argument"` + }{ + Ref: map[string]interface{}{ + "type": "ref/resource", + "uri": "repo://{owner}/{repo}/contents{/path*}", + }, + Argument: struct { + Name string `json:"name"` + Value string `json:"value"` + }{ + Name: "owner", + Value: "test", + }, + }, + }, + getClient: errorGetClient, + expectedError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + handler := RepositoryCompletionHandler(tt.getClient) + result, err := handler(context.Background(), tt.request) + + if tt.expectedError { + assert.Error(t, err) + return + } + + require.NoError(t, err) + require.NotNil(t, result) + // For non-repo URIs and unsupported arguments, we should get empty completion + assert.Equal(t, []string{}, result.Completion.Values) + }) + } +} + +func TestUtilityFunctions(t *testing.T) { + tests := []struct { + name string + uri string + expected string + }{ + { + name: "extract owner from basic repo URI", + uri: "repo://octocat/Hello-World/contents", + expected: "octocat", + }, + { + name: "extract owner from template URI", + uri: "repo://{owner}/{repo}/contents{/path*}", + expected: "", + }, + { + name: "extract owner from non-repo URI", + uri: "file:///some/path", + expected: "", + }, + { + name: "empty URI", + uri: "", + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := extractOwnerFromURI(tt.uri) + assert.Equal(t, tt.expected, result) + }) + } + + ownerRepoTests := []struct { + name string + uri string + expectedOwner string + expectedRepo string + }{ + { + name: "extract owner and repo from basic URI", + uri: "repo://octocat/Hello-World/contents", + expectedOwner: "octocat", + expectedRepo: "Hello-World", + }, + { + name: "extract from template URI", + uri: "repo://{owner}/{repo}/contents{/path*}", + expectedOwner: "", + expectedRepo: "", + }, + { + name: "extract from branch URI", + uri: "repo://octocat/Hello-World/refs/heads/main/contents", + expectedOwner: "octocat", + expectedRepo: "Hello-World", + }, + { + name: "extract from commit URI", + uri: "repo://octocat/Hello-World/sha/abc123/contents", + expectedOwner: "octocat", + expectedRepo: "Hello-World", + }, + { + name: "extract from tag URI", + uri: "repo://octocat/Hello-World/refs/tags/v1.0/contents", + expectedOwner: "octocat", + expectedRepo: "Hello-World", + }, + { + name: "extract from PR URI", + uri: "repo://octocat/Hello-World/refs/pull/123/head/contents", + expectedOwner: "octocat", + expectedRepo: "Hello-World", + }, + { + name: "non-repo URI", + uri: "file:///some/path", + expectedOwner: "", + expectedRepo: "", + }, + { + name: "empty URI", + uri: "", + expectedOwner: "", + expectedRepo: "", + }, + } + + for _, tt := range ownerRepoTests { + t.Run(tt.name, func(t *testing.T) { + owner, repo := extractOwnerRepoFromURI(tt.uri) + assert.Equal(t, tt.expectedOwner, owner) + assert.Equal(t, tt.expectedRepo, repo) + }) + } + + refTests := []struct { + name string + uri string + expectedRef string + }{ + { + name: "extract branch ref", + uri: "repo://octocat/Hello-World/refs/heads/main/contents", + expectedRef: "refs/heads/main", + }, + { + name: "extract commit ref", + uri: "repo://octocat/Hello-World/sha/abc123/contents", + expectedRef: "abc123", + }, + { + name: "extract tag ref", + uri: "repo://octocat/Hello-World/refs/tags/v1.0/contents", + expectedRef: "refs/tags/v1.0", + }, + { + name: "extract PR ref", + uri: "repo://octocat/Hello-World/refs/pull/123/head/contents", + expectedRef: "refs/pull/123/head", + }, + { + name: "basic repo URI - no ref", + uri: "repo://octocat/Hello-World/contents", + expectedRef: "", + }, + { + name: "template URI - no ref", + uri: "repo://{owner}/{repo}/contents{/path*}", + expectedRef: "", + }, + { + name: "template branch URI - no ref", + uri: "repo://{owner}/{repo}/refs/heads/{branch}/contents{/path*}", + expectedRef: "", + }, + } + + for _, tt := range refTests { + t.Run(tt.name, func(t *testing.T) { + result := extractRefFromURI(tt.uri) + assert.Equal(t, tt.expectedRef, result) + }) + } +} \ No newline at end of file diff --git a/pkg/github/server.go b/pkg/github/server.go index b182b8ca..070286b8 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -1,6 +1,7 @@ package github import ( + "context" "encoding/json" "errors" "fmt" @@ -30,6 +31,93 @@ func NewServer(version string, opts ...server.ServerOption) *server.MCPServer { return s } +// GitHubMCPServer wraps the base MCP server to add completion support +type GitHubMCPServer struct { + *server.MCPServer + completionHandler CompletionHandlerFunc +} + +// GetMCPServer returns the underlying MCP server for compatibility +func (s *GitHubMCPServer) GetMCPServer() *server.MCPServer { + return s.MCPServer +} + +// GetCompletionHandler returns the completion handler +func (s *GitHubMCPServer) GetCompletionHandler() CompletionHandlerFunc { + return s.completionHandler +} + +// NewGitHubServer creates a new GitHub MCP server with completion support +func NewGitHubServer(version string, getClient GetClientFn, opts ...server.ServerOption) *GitHubMCPServer { + baseServer := NewServer(version, opts...) + + return &GitHubMCPServer{ + MCPServer: baseServer, + completionHandler: RepositoryCompletionHandler(getClient), + } +} + +// HandleMessage overrides the base HandleMessage to add completion support +func (s *GitHubMCPServer) HandleMessage(ctx context.Context, message json.RawMessage) mcp.JSONRPCMessage { + // Parse the message to check for completion requests + var baseMessage struct { + JSONRPC string `json:"jsonrpc"` + Method mcp.MCPMethod `json:"method"` + ID any `json:"id,omitempty"` + } + + if err := json.Unmarshal(message, &baseMessage); err != nil { + return s.MCPServer.HandleMessage(ctx, message) + } + + // Handle completion requests + if string(baseMessage.Method) == "completion/complete" { + return s.handleCompletion(ctx, baseMessage.ID, message) + } + + // Delegate to base server for all other requests + return s.MCPServer.HandleMessage(ctx, message) +} + +// handleCompletion processes completion requests +func (s *GitHubMCPServer) handleCompletion(ctx context.Context, id any, message json.RawMessage) mcp.JSONRPCMessage { + var request mcp.CompleteRequest + if err := json.Unmarshal(message, &request); err != nil { + return createErrorResponse(id, mcp.INVALID_REQUEST, "Failed to parse completion request") + } + + result, err := s.completionHandler(ctx, request) + if err != nil { + return createErrorResponse(id, mcp.INTERNAL_ERROR, fmt.Sprintf("Completion failed: %v", err)) + } + + return createResponse(id, *result) +} + +// Helper functions for JSON-RPC responses +func createErrorResponse(id any, code int, message string) mcp.JSONRPCMessage { + return mcp.JSONRPCError{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: mcp.NewRequestId(id), + Error: struct { + Code int `json:"code"` + Message string `json:"message"` + Data any `json:"data,omitempty"` + }{ + Code: code, + Message: message, + }, + } +} + +func createResponse(id any, result any) mcp.JSONRPCMessage { + return mcp.JSONRPCResponse{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: mcp.NewRequestId(id), + Result: result, + } +} + // OptionalParamOK is a helper function that can be used to fetch a requested parameter from the request. // It returns the value, a boolean indicating if the parameter was present, and an error if the type is wrong. func OptionalParamOK[T any](r mcp.CallToolRequest, p string) (value T, ok bool, err error) { diff --git a/testdata/branches.json b/testdata/branches.json new file mode 100644 index 00000000..b113a99a --- /dev/null +++ b/testdata/branches.json @@ -0,0 +1,18 @@ +[ + { + "name": "main", + "commit": { + "sha": "6dcb09b5b57875f334f61aebed695e2e4193db5e", + "url": "https://api.github.com/repos/octocat/Hello-World/commits/6dcb09b5b57875f334f61aebed695e2e4193db5e" + }, + "protected": false + }, + { + "name": "dev", + "commit": { + "sha": "1234567890abcdef1234567890abcdef12345678", + "url": "https://api.github.com/repos/octocat/Hello-World/commits/1234567890abcdef1234567890abcdef12345678" + }, + "protected": false + } +] \ No newline at end of file diff --git a/testdata/commits.json b/testdata/commits.json new file mode 100644 index 00000000..3b833541 --- /dev/null +++ b/testdata/commits.json @@ -0,0 +1,14 @@ +[ + { + "sha": "6dcb09b5b57875f334f61aebed695e2e4193db5e", + "commit": { + "message": "Initial commit" + } + }, + { + "sha": "1234567890abcdef1234567890abcdef12345678", + "commit": { + "message": "Add new feature" + } + } +] \ No newline at end of file diff --git a/testdata/pullrequests.json b/testdata/pullrequests.json new file mode 100644 index 00000000..d28e1282 --- /dev/null +++ b/testdata/pullrequests.json @@ -0,0 +1,12 @@ +[ + { + "number": 1, + "title": "Test PR", + "state": "open" + }, + { + "number": 2, + "title": "Another PR", + "state": "closed" + } +] \ No newline at end of file diff --git a/testdata/repository_content.json b/testdata/repository_content.json new file mode 100644 index 00000000..68abd8d9 --- /dev/null +++ b/testdata/repository_content.json @@ -0,0 +1,17 @@ +[ + { + "name": "README.md", + "type": "file", + "path": "README.md" + }, + { + "name": "src", + "type": "dir", + "path": "src" + }, + { + "name": "package.json", + "type": "file", + "path": "package.json" + } +] \ No newline at end of file diff --git a/testdata/tags.json b/testdata/tags.json new file mode 100644 index 00000000..6e48e2ad --- /dev/null +++ b/testdata/tags.json @@ -0,0 +1,14 @@ +[ + { + "name": "v1.0.0", + "commit": { + "sha": "6dcb09b5b57875f334f61aebed695e2e4193db5e" + } + }, + { + "name": "v0.1.0", + "commit": { + "sha": "1234567890abcdef1234567890abcdef12345678" + } + } +] \ No newline at end of file