diff --git a/README.md b/README.md index a55d174b..332ab3dd 100644 --- a/README.md +++ b/README.md @@ -163,7 +163,7 @@ func main() { result = x * y case "divide": if y == 0 { - return nil, errors.New("Cannot divide by zero") + return mcp.NewToolResultError("cannot divide by zero"), nil } result = x / y } @@ -325,7 +325,7 @@ s.AddTool(calculatorTool, func(ctx context.Context, request mcp.CallToolRequest) result = x * y case "divide": if y == 0 { - return nil, errors.New("Division by zero is not allowed") + return mcp.NewToolResultError("cannot divide by zero"), nil } result = x / y } @@ -370,20 +370,20 @@ s.AddTool(httpTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp req, err = http.NewRequest(method, url, nil) } if err != nil { - return nil, fmt.Errorf("Failed to create request: %v", err) + return mcp.NewToolResultErrorFromErr("unable to create request", err), nil } client := &http.Client{} resp, err := client.Do(req) if err != nil { - return nil, fmt.Errorf("Request failed: %v", err) + return mcp.NewToolResultErrorFromErr("unable to execute request", err), nil } defer resp.Body.Close() // Return response respBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, fmt.Errorf("Failed to read response: %v", err) + return mcp.NewToolResultErrorFromErr("unable to read request response", err), nil } return mcp.NewToolResultText(fmt.Sprintf("Status: %d\nBody: %s", resp.StatusCode, string(respBody))), nil diff --git a/client/client.go b/client/client.go index ba03b582..60fe0cbf 100644 --- a/client/client.go +++ b/client/client.go @@ -1,124 +1,390 @@ -// Package client provides MCP (Model Control Protocol) client implementations. package client import ( "context" "encoding/json" + "errors" "fmt" + "sync" + "sync/atomic" + "github.com/mark3labs/mcp-go/client/transport" "github.com/mark3labs/mcp-go/mcp" ) -// MCPClient represents an MCP client interface -type MCPClient interface { - // Initialize sends the initial connection request to the server - Initialize( - ctx context.Context, - request mcp.InitializeRequest, - ) (*mcp.InitializeResult, error) - - // Ping checks if the server is alive - Ping(ctx context.Context) error - - // ListResourcesByPage manually list resources by page. - ListResourcesByPage( - ctx context.Context, - request mcp.ListResourcesRequest, - ) (*mcp.ListResourcesResult, error) - - // ListResources requests a list of available resources from the server - ListResources( - ctx context.Context, - request mcp.ListResourcesRequest, - ) (*mcp.ListResourcesResult, error) - - // ListResourceTemplatesByPage manually list resource templates by page. - ListResourceTemplatesByPage( - ctx context.Context, - request mcp.ListResourceTemplatesRequest, - ) (*mcp.ListResourceTemplatesResult, - error) - - // ListResourceTemplates requests a list of available resource templates from the server - ListResourceTemplates( - ctx context.Context, - request mcp.ListResourceTemplatesRequest, - ) (*mcp.ListResourceTemplatesResult, - error) - - // ReadResource reads a specific resource from the server - ReadResource( - ctx context.Context, - request mcp.ReadResourceRequest, - ) (*mcp.ReadResourceResult, error) - - // Subscribe requests notifications for changes to a specific resource - Subscribe(ctx context.Context, request mcp.SubscribeRequest) error - - // Unsubscribe cancels notifications for a specific resource - Unsubscribe(ctx context.Context, request mcp.UnsubscribeRequest) error - - // ListPromptsByPage manually list prompts by page. - ListPromptsByPage( - ctx context.Context, - request mcp.ListPromptsRequest, - ) (*mcp.ListPromptsResult, error) - - // ListPrompts requests a list of available prompts from the server - ListPrompts( - ctx context.Context, - request mcp.ListPromptsRequest, - ) (*mcp.ListPromptsResult, error) - - // GetPrompt retrieves a specific prompt from the server - GetPrompt( - ctx context.Context, - request mcp.GetPromptRequest, - ) (*mcp.GetPromptResult, error) - - // ListToolsByPage manually list tools by page. - ListToolsByPage( - ctx context.Context, - request mcp.ListToolsRequest, - ) (*mcp.ListToolsResult, error) - - // ListTools requests a list of available tools from the server - ListTools( - ctx context.Context, - request mcp.ListToolsRequest, - ) (*mcp.ListToolsResult, error) - - // CallTool invokes a specific tool on the server - CallTool( - ctx context.Context, - request mcp.CallToolRequest, - ) (*mcp.CallToolResult, error) - - // SetLevel sets the logging level for the server - SetLevel(ctx context.Context, request mcp.SetLevelRequest) error - - // Complete requests completion options for a given argument - Complete( - ctx context.Context, - request mcp.CompleteRequest, - ) (*mcp.CompleteResult, error) - - // Close client connection and cleanup resources - Close() error - - // OnNotification registers a handler for notifications - OnNotification(handler func(notification mcp.JSONRPCNotification)) -} - -type mcpClient interface { - MCPClient - - sendRequest(ctx context.Context, method string, params interface{}) (*json.RawMessage, error) +// Client implements the MCP client. +type Client struct { + transport transport.Interface + + initialized bool + notifications []func(mcp.JSONRPCNotification) + notifyMu sync.RWMutex + requestID atomic.Int64 + capabilities mcp.ServerCapabilities +} + +// NewClient creates a new MCP client with the given transport. +// Usage: +// +// stdio := transport.NewStdio("./mcp_server", nil, "xxx") +// client, err := NewClient(stdio) +// if err != nil { +// log.Fatalf("Failed to create client: %v", err) +// } +func NewClient(transport transport.Interface) *Client { + return &Client{ + transport: transport, + } +} + +// Start initiates the connection to the server. +// Must be called before using the client. +func (c *Client) Start(ctx context.Context) error { + if c.transport == nil { + return fmt.Errorf("transport is nil") + } + err := c.transport.Start(ctx) + if err != nil { + return err + } + + c.transport.SetNotificationHandler(func(notification mcp.JSONRPCNotification) { + c.notifyMu.RLock() + defer c.notifyMu.RUnlock() + for _, handler := range c.notifications { + handler(notification) + } + }) + return nil +} + +// Close shuts down the client and closes the transport. +func (c *Client) Close() error { + return c.transport.Close() +} + +// OnNotification registers a handler function to be called when notifications are received. +// Multiple handlers can be registered and will be called in the order they were added. +func (c *Client) OnNotification( + handler func(notification mcp.JSONRPCNotification), +) { + c.notifyMu.Lock() + defer c.notifyMu.Unlock() + c.notifications = append(c.notifications, handler) +} + +// sendRequest sends a JSON-RPC request to the server and waits for a response. +// Returns the raw JSON response message or an error if the request fails. +func (c *Client) sendRequest( + ctx context.Context, + method string, + params interface{}, +) (*json.RawMessage, error) { + if !c.initialized && method != "initialize" { + return nil, fmt.Errorf("client not initialized") + } + + id := c.requestID.Add(1) + + request := transport.JSONRPCRequest{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: id, + Method: method, + Params: params, + } + + response, err := c.transport.SendRequest(ctx, request) + if err != nil { + return nil, fmt.Errorf("transport error: %w", err) + } + + if response.Error != nil { + return nil, errors.New(response.Error.Message) + } + + return &response.Result, nil +} + +// Initialize negotiates with the server. +// Must be called after Start, and before any request methods. +func (c *Client) Initialize( + ctx context.Context, + request mcp.InitializeRequest, +) (*mcp.InitializeResult, error) { + // Ensure we send a params object with all required fields + params := struct { + ProtocolVersion string `json:"protocolVersion"` + ClientInfo mcp.Implementation `json:"clientInfo"` + Capabilities mcp.ClientCapabilities `json:"capabilities"` + }{ + ProtocolVersion: request.Params.ProtocolVersion, + ClientInfo: request.Params.ClientInfo, + Capabilities: request.Params.Capabilities, // Will be empty struct if not set + } + + response, err := c.sendRequest(ctx, "initialize", params) + if err != nil { + return nil, err + } + + var result mcp.InitializeResult + if err := json.Unmarshal(*response, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + // Store capabilities + c.capabilities = result.Capabilities + + // Send initialized notification + notification := mcp.JSONRPCNotification{ + JSONRPC: mcp.JSONRPC_VERSION, + Notification: mcp.Notification{ + Method: "notifications/initialized", + }, + } + + err = c.transport.SendNotification(ctx, notification) + if err != nil { + return nil, fmt.Errorf( + "failed to send initialized notification: %w", + err, + ) + } + + c.initialized = true + return &result, nil +} + +func (c *Client) Ping(ctx context.Context) error { + _, err := c.sendRequest(ctx, "ping", nil) + return err +} + +// ListResourcesByPage manually list resources by page. +func (c *Client) ListResourcesByPage( + ctx context.Context, + request mcp.ListResourcesRequest, +) (*mcp.ListResourcesResult, error) { + result, err := listByPage[mcp.ListResourcesResult](ctx, c, request.PaginatedRequest, "resources/list") + if err != nil { + return nil, err + } + return result, nil +} + +func (c *Client) ListResources( + ctx context.Context, + request mcp.ListResourcesRequest, +) (*mcp.ListResourcesResult, error) { + result, err := c.ListResourcesByPage(ctx, request) + if err != nil { + return nil, err + } + for result.NextCursor != "" { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + request.Params.Cursor = result.NextCursor + newPageRes, err := c.ListResourcesByPage(ctx, request) + if err != nil { + return nil, err + } + result.Resources = append(result.Resources, newPageRes.Resources...) + result.NextCursor = newPageRes.NextCursor + } + } + return result, nil +} + +func (c *Client) ListResourceTemplatesByPage( + ctx context.Context, + request mcp.ListResourceTemplatesRequest, +) (*mcp.ListResourceTemplatesResult, error) { + result, err := listByPage[mcp.ListResourceTemplatesResult](ctx, c, request.PaginatedRequest, "resources/templates/list") + if err != nil { + return nil, err + } + return result, nil +} + +func (c *Client) ListResourceTemplates( + ctx context.Context, + request mcp.ListResourceTemplatesRequest, +) (*mcp.ListResourceTemplatesResult, error) { + result, err := c.ListResourceTemplatesByPage(ctx, request) + if err != nil { + return nil, err + } + for result.NextCursor != "" { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + request.Params.Cursor = result.NextCursor + newPageRes, err := c.ListResourceTemplatesByPage(ctx, request) + if err != nil { + return nil, err + } + result.ResourceTemplates = append(result.ResourceTemplates, newPageRes.ResourceTemplates...) + result.NextCursor = newPageRes.NextCursor + } + } + return result, nil +} + +func (c *Client) ReadResource( + ctx context.Context, + request mcp.ReadResourceRequest, +) (*mcp.ReadResourceResult, error) { + response, err := c.sendRequest(ctx, "resources/read", request.Params) + if err != nil { + return nil, err + } + + return mcp.ParseReadResourceResult(response) +} + +func (c *Client) Subscribe( + ctx context.Context, + request mcp.SubscribeRequest, +) error { + _, err := c.sendRequest(ctx, "resources/subscribe", request.Params) + return err +} + +func (c *Client) Unsubscribe( + ctx context.Context, + request mcp.UnsubscribeRequest, +) error { + _, err := c.sendRequest(ctx, "resources/unsubscribe", request.Params) + return err +} + +func (c *Client) ListPromptsByPage( + ctx context.Context, + request mcp.ListPromptsRequest, +) (*mcp.ListPromptsResult, error) { + result, err := listByPage[mcp.ListPromptsResult](ctx, c, request.PaginatedRequest, "prompts/list") + if err != nil { + return nil, err + } + return result, nil +} + +func (c *Client) ListPrompts( + ctx context.Context, + request mcp.ListPromptsRequest, +) (*mcp.ListPromptsResult, error) { + result, err := c.ListPromptsByPage(ctx, request) + if err != nil { + return nil, err + } + for result.NextCursor != "" { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + request.Params.Cursor = result.NextCursor + newPageRes, err := c.ListPromptsByPage(ctx, request) + if err != nil { + return nil, err + } + result.Prompts = append(result.Prompts, newPageRes.Prompts...) + result.NextCursor = newPageRes.NextCursor + } + } + return result, nil +} + +func (c *Client) GetPrompt( + ctx context.Context, + request mcp.GetPromptRequest, +) (*mcp.GetPromptResult, error) { + response, err := c.sendRequest(ctx, "prompts/get", request.Params) + if err != nil { + return nil, err + } + + return mcp.ParseGetPromptResult(response) +} + +func (c *Client) ListToolsByPage( + ctx context.Context, + request mcp.ListToolsRequest, +) (*mcp.ListToolsResult, error) { + result, err := listByPage[mcp.ListToolsResult](ctx, c, request.PaginatedRequest, "tools/list") + if err != nil { + return nil, err + } + return result, nil +} + +func (c *Client) ListTools( + ctx context.Context, + request mcp.ListToolsRequest, +) (*mcp.ListToolsResult, error) { + result, err := c.ListToolsByPage(ctx, request) + if err != nil { + return nil, err + } + for result.NextCursor != "" { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + request.Params.Cursor = result.NextCursor + newPageRes, err := c.ListToolsByPage(ctx, request) + if err != nil { + return nil, err + } + result.Tools = append(result.Tools, newPageRes.Tools...) + result.NextCursor = newPageRes.NextCursor + } + } + return result, nil +} + +func (c *Client) CallTool( + ctx context.Context, + request mcp.CallToolRequest, +) (*mcp.CallToolResult, error) { + response, err := c.sendRequest(ctx, "tools/call", request.Params) + if err != nil { + return nil, err + } + + return mcp.ParseCallToolResult(response) +} + +func (c *Client) SetLevel( + ctx context.Context, + request mcp.SetLevelRequest, +) error { + _, err := c.sendRequest(ctx, "logging/setLevel", request.Params) + return err +} + +func (c *Client) Complete( + ctx context.Context, + request mcp.CompleteRequest, +) (*mcp.CompleteResult, error) { + response, err := c.sendRequest(ctx, "completion/complete", request.Params) + if err != nil { + return nil, err + } + + var result mcp.CompleteResult + if err := json.Unmarshal(*response, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return &result, nil } func listByPage[T any]( ctx context.Context, - client mcpClient, + client *Client, request mcp.PaginatedRequest, method string, ) (*T, error) { @@ -132,3 +398,11 @@ func listByPage[T any]( } return &result, nil } + +// Helper methods + +// GetTransport gives access to the underlying transport layer. +// Cast it to the specific transport type and obtain the other helper methods. +func (c *Client) GetTransport() transport.Interface { + return c.transport +} diff --git a/client/http.go b/client/http.go new file mode 100644 index 00000000..cb3be35d --- /dev/null +++ b/client/http.go @@ -0,0 +1,17 @@ +package client + +import ( + "fmt" + + "github.com/mark3labs/mcp-go/client/transport" +) + +// NewStreamableHttpClient is a convenience method that creates a new streamable-http-based MCP client +// with the given base URL. Returns an error if the URL is invalid. +func NewStreamableHttpClient(baseURL string, options ...transport.StreamableHTTPCOption) (*Client, error) { + trans, err := transport.NewStreamableHTTP(baseURL, options...) + if err != nil { + return nil, fmt.Errorf("failed to create SSE transport: %w", err) + } + return NewClient(trans), nil +} diff --git a/client/interface.go b/client/interface.go new file mode 100644 index 00000000..ea7f4d1f --- /dev/null +++ b/client/interface.go @@ -0,0 +1,109 @@ +// Package client provides MCP (Model Control Protocol) client implementations. +package client + +import ( + "context" + + "github.com/mark3labs/mcp-go/mcp" +) + +// MCPClient represents an MCP client interface +type MCPClient interface { + // Initialize sends the initial connection request to the server + Initialize( + ctx context.Context, + request mcp.InitializeRequest, + ) (*mcp.InitializeResult, error) + + // Ping checks if the server is alive + Ping(ctx context.Context) error + + // ListResourcesByPage manually list resources by page. + ListResourcesByPage( + ctx context.Context, + request mcp.ListResourcesRequest, + ) (*mcp.ListResourcesResult, error) + + // ListResources requests a list of available resources from the server + ListResources( + ctx context.Context, + request mcp.ListResourcesRequest, + ) (*mcp.ListResourcesResult, error) + + // ListResourceTemplatesByPage manually list resource templates by page. + ListResourceTemplatesByPage( + ctx context.Context, + request mcp.ListResourceTemplatesRequest, + ) (*mcp.ListResourceTemplatesResult, + error) + + // ListResourceTemplates requests a list of available resource templates from the server + ListResourceTemplates( + ctx context.Context, + request mcp.ListResourceTemplatesRequest, + ) (*mcp.ListResourceTemplatesResult, + error) + + // ReadResource reads a specific resource from the server + ReadResource( + ctx context.Context, + request mcp.ReadResourceRequest, + ) (*mcp.ReadResourceResult, error) + + // Subscribe requests notifications for changes to a specific resource + Subscribe(ctx context.Context, request mcp.SubscribeRequest) error + + // Unsubscribe cancels notifications for a specific resource + Unsubscribe(ctx context.Context, request mcp.UnsubscribeRequest) error + + // ListPromptsByPage manually list prompts by page. + ListPromptsByPage( + ctx context.Context, + request mcp.ListPromptsRequest, + ) (*mcp.ListPromptsResult, error) + + // ListPrompts requests a list of available prompts from the server + ListPrompts( + ctx context.Context, + request mcp.ListPromptsRequest, + ) (*mcp.ListPromptsResult, error) + + // GetPrompt retrieves a specific prompt from the server + GetPrompt( + ctx context.Context, + request mcp.GetPromptRequest, + ) (*mcp.GetPromptResult, error) + + // ListToolsByPage manually list tools by page. + ListToolsByPage( + ctx context.Context, + request mcp.ListToolsRequest, + ) (*mcp.ListToolsResult, error) + + // ListTools requests a list of available tools from the server + ListTools( + ctx context.Context, + request mcp.ListToolsRequest, + ) (*mcp.ListToolsResult, error) + + // CallTool invokes a specific tool on the server + CallTool( + ctx context.Context, + request mcp.CallToolRequest, + ) (*mcp.CallToolResult, error) + + // SetLevel sets the logging level for the server + SetLevel(ctx context.Context, request mcp.SetLevelRequest) error + + // Complete requests completion options for a given argument + Complete( + ctx context.Context, + request mcp.CompleteRequest, + ) (*mcp.CompleteResult, error) + + // Close client connection and cleanup resources + Close() error + + // OnNotification registers a handler for notifications + OnNotification(handler func(notification mcp.JSONRPCNotification)) +} diff --git a/client/sse.go b/client/sse.go index e7aaaa49..c26744a3 100644 --- a/client/sse.go +++ b/client/sse.go @@ -1,653 +1,32 @@ package client import ( - "bufio" - "bytes" - "context" - "encoding/json" - "errors" "fmt" - "io" - "net/http" + "github.com/mark3labs/mcp-go/client/transport" "net/url" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/mark3labs/mcp-go/mcp" ) -// SSEMCPClient implements the MCPClient interface using Server-Sent Events (SSE). -// It maintains a persistent HTTP connection to receive server-pushed events -// while sending requests over regular HTTP POST calls. The client handles -// automatic reconnection and message routing between requests and responses. -type SSEMCPClient struct { - baseURL *url.URL - endpoint *url.URL - httpClient *http.Client - requestID atomic.Int64 - responses map[int64]chan RPCResponse - mu sync.RWMutex - done chan struct{} - initialized bool - notifications []func(mcp.JSONRPCNotification) - notifyMu sync.RWMutex - endpointChan chan struct{} - capabilities mcp.ServerCapabilities - headers map[string]string -} - -type ClientOption func(*SSEMCPClient) - -func WithHeaders(headers map[string]string) ClientOption { - return func(sc *SSEMCPClient) { - sc.headers = headers - } +func WithHeaders(headers map[string]string) transport.ClientOption { + return transport.WithHeaders(headers) } // NewSSEMCPClient creates a new SSE-based MCP client with the given base URL. // Returns an error if the URL is invalid. -func NewSSEMCPClient(baseURL string, options ...ClientOption) (*SSEMCPClient, error) { - parsedURL, err := url.Parse(baseURL) - if err != nil { - return nil, fmt.Errorf("invalid URL: %w", err) - } - - smc := &SSEMCPClient{ - baseURL: parsedURL, - httpClient: &http.Client{}, - responses: make(map[int64]chan RPCResponse), - done: make(chan struct{}), - endpointChan: make(chan struct{}), - headers: make(map[string]string), - } - - for _, opt := range options { - opt(smc) - } - - return smc, nil -} - -// Start initiates the SSE connection to the server and waits for the endpoint information. -// Returns an error if the connection fails or times out waiting for the endpoint. -func (c *SSEMCPClient) Start(ctx context.Context) error { - - req, err := http.NewRequestWithContext(ctx, "GET", c.baseURL.String(), nil) - - if err != nil { - - return fmt.Errorf("failed to create request: %w", err) - - } - - req.Header.Set("Accept", "text/event-stream") - req.Header.Set("Cache-Control", "no-cache") - req.Header.Set("Connection", "keep-alive") - for k, v := range c.headers { - req.Header.Set(k, v) - } - - resp, err := c.httpClient.Do(req) - if err != nil { - return fmt.Errorf("failed to connect to SSE stream: %w", err) - } - - if resp.StatusCode != http.StatusOK { - resp.Body.Close() - return fmt.Errorf("unexpected status code: %d", resp.StatusCode) - } - - go c.readSSE(resp.Body) - - // Wait for the endpoint to be received - - select { - case <-c.endpointChan: - // Endpoint received, proceed - case <-ctx.Done(): - return fmt.Errorf("context cancelled while waiting for endpoint") - case <-time.After(30 * time.Second): // Add a timeout - return fmt.Errorf("timeout waiting for endpoint") - } - - return nil -} - -// readSSE continuously reads the SSE stream and processes events. -// It runs until the connection is closed or an error occurs. -func (c *SSEMCPClient) readSSE(reader io.ReadCloser) { - defer reader.Close() - - br := bufio.NewReader(reader) - var event, data string - - for { - select { - case <-c.done: - return - default: - line, err := br.ReadString('\n') - if err != nil { - if err == io.EOF { - // Process any pending event before exit - if event != "" && data != "" { - c.handleSSEEvent(event, data) - } - break - } - select { - case <-c.done: - return - default: - fmt.Printf("SSE stream error: %v\n", err) - return - } - } - - // Remove only newline markers - line = strings.TrimRight(line, "\r\n") - if line == "" { - // Empty line means end of event - if event != "" && data != "" { - c.handleSSEEvent(event, data) - event = "" - data = "" - } - continue - } - - if strings.HasPrefix(line, "event:") { - event = strings.TrimSpace(strings.TrimPrefix(line, "event:")) - } else if strings.HasPrefix(line, "data:") { - data = strings.TrimSpace(strings.TrimPrefix(line, "data:")) - } - } - } -} - -// handleSSEEvent processes SSE events based on their type. -// Handles 'endpoint' events for connection setup and 'message' events for JSON-RPC communication. -func (c *SSEMCPClient) handleSSEEvent(event, data string) { - switch event { - case "endpoint": - endpoint, err := c.baseURL.Parse(data) - if err != nil { - fmt.Printf("Error parsing endpoint URL: %v\n", err) - return - } - if endpoint.Host != c.baseURL.Host { - fmt.Printf("Endpoint origin does not match connection origin\n") - return - } - c.endpoint = endpoint - close(c.endpointChan) - - case "message": - var baseMessage struct { - JSONRPC string `json:"jsonrpc"` - ID *int64 `json:"id,omitempty"` - Method string `json:"method,omitempty"` - Result json.RawMessage `json:"result,omitempty"` - Error *struct { - Code int `json:"code"` - Message string `json:"message"` - } `json:"error,omitempty"` - } - - if err := json.Unmarshal([]byte(data), &baseMessage); err != nil { - fmt.Printf("Error unmarshaling message: %v\n", err) - return - } - - // Handle notification - if baseMessage.ID == nil { - var notification mcp.JSONRPCNotification - if err := json.Unmarshal([]byte(data), ¬ification); err != nil { - return - } - c.notifyMu.RLock() - for _, handler := range c.notifications { - handler(notification) - } - c.notifyMu.RUnlock() - return - } - - c.mu.RLock() - ch, ok := c.responses[*baseMessage.ID] - c.mu.RUnlock() - - if ok { - if baseMessage.Error != nil { - ch <- RPCResponse{ - Error: &baseMessage.Error.Message, - } - } else { - ch <- RPCResponse{ - Response: &baseMessage.Result, - } - } - c.mu.Lock() - delete(c.responses, *baseMessage.ID) - c.mu.Unlock() - } - } -} - -// OnNotification registers a handler function to be called when notifications are received. -// Multiple handlers can be registered and will be called in the order they were added. -func (c *SSEMCPClient) OnNotification( - handler func(notification mcp.JSONRPCNotification), -) { - c.notifyMu.Lock() - defer c.notifyMu.Unlock() - c.notifications = append(c.notifications, handler) -} - -// sendRequest sends a JSON-RPC request to the server and waits for a response. -// Returns the raw JSON response message or an error if the request fails. -func (c *SSEMCPClient) sendRequest( - ctx context.Context, - method string, - params interface{}, -) (*json.RawMessage, error) { - if !c.initialized && method != "initialize" { - return nil, fmt.Errorf("client not initialized") - } - - if c.endpoint == nil { - return nil, fmt.Errorf("endpoint not received") - } - - id := c.requestID.Add(1) - - request := mcp.JSONRPCRequest{ - JSONRPC: mcp.JSONRPC_VERSION, - ID: id, - Request: mcp.Request{ - Method: method, - }, - Params: params, - } - - requestBytes, err := json.Marshal(request) - if err != nil { - return nil, fmt.Errorf("failed to marshal request: %w", err) - } - - responseChan := make(chan RPCResponse, 1) - c.mu.Lock() - c.responses[id] = responseChan - c.mu.Unlock() - - req, err := http.NewRequestWithContext( - ctx, - "POST", - c.endpoint.String(), - bytes.NewReader(requestBytes), - ) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - // set custom http headers - for k, v := range c.headers { - req.Header.Set(k, v) - } - - resp, err := c.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("failed to send request: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK && - resp.StatusCode != http.StatusAccepted { - body, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf( - "request failed with status %d: %s", - resp.StatusCode, - body, - ) - } - - select { - case <-ctx.Done(): - c.mu.Lock() - delete(c.responses, id) - c.mu.Unlock() - return nil, ctx.Err() - case response := <-responseChan: - if response.Error != nil { - return nil, errors.New(*response.Error) - } - return response.Response, nil - } -} - -func (c *SSEMCPClient) Initialize( - ctx context.Context, - request mcp.InitializeRequest, -) (*mcp.InitializeResult, error) { - // Ensure we send a params object with all required fields - params := struct { - ProtocolVersion string `json:"protocolVersion"` - ClientInfo mcp.Implementation `json:"clientInfo"` - Capabilities mcp.ClientCapabilities `json:"capabilities"` - }{ - ProtocolVersion: request.Params.ProtocolVersion, - ClientInfo: request.Params.ClientInfo, - Capabilities: request.Params.Capabilities, // Will be empty struct if not set - } - - response, err := c.sendRequest(ctx, "initialize", params) - if err != nil { - return nil, err - } - - var result mcp.InitializeResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - // Store capabilities - c.capabilities = result.Capabilities +func NewSSEMCPClient(baseURL string, options ...transport.ClientOption) (*Client, error) { - // Send initialized notification - notification := mcp.JSONRPCNotification{ - JSONRPC: mcp.JSONRPC_VERSION, - Notification: mcp.Notification{ - Method: "notifications/initialized", - }, - } - - notificationBytes, err := json.Marshal(notification) - if err != nil { - return nil, fmt.Errorf( - "failed to marshal initialized notification: %w", - err, - ) - } - - req, err := http.NewRequestWithContext( - ctx, - "POST", - c.endpoint.String(), - bytes.NewReader(notificationBytes), - ) - if err != nil { - return nil, fmt.Errorf("failed to create notification request: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - - resp, err := c.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf( - "failed to send initialized notification: %w", - err, - ) - } - defer resp.Body.Close() - - c.initialized = true - return &result, nil -} - -func (c *SSEMCPClient) Ping(ctx context.Context) error { - _, err := c.sendRequest(ctx, "ping", nil) - return err -} - -// ListResourcesByPage manually list resources by page. -func (c *SSEMCPClient) ListResourcesByPage( - ctx context.Context, - request mcp.ListResourcesRequest, -) (*mcp.ListResourcesResult, error) { - result, err := listByPage[mcp.ListResourcesResult](ctx, c, request.PaginatedRequest, "resources/list") - if err != nil { - return nil, err - } - return result, nil -} - -func (c *SSEMCPClient) ListResources( - ctx context.Context, - request mcp.ListResourcesRequest, -) (*mcp.ListResourcesResult, error) { - result, err := c.ListResourcesByPage(ctx, request) - if err != nil { - return nil, err - } - for result.NextCursor != "" { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - request.Params.Cursor = result.NextCursor - newPageRes, err := c.ListResourcesByPage(ctx, request) - if err != nil { - return nil, err - } - result.Resources = append(result.Resources, newPageRes.Resources...) - result.NextCursor = newPageRes.NextCursor - } - } - return result, nil -} - -func (c *SSEMCPClient) ListResourceTemplatesByPage( - ctx context.Context, - request mcp.ListResourceTemplatesRequest, -) (*mcp.ListResourceTemplatesResult, error) { - result, err := listByPage[mcp.ListResourceTemplatesResult](ctx, c, request.PaginatedRequest, "resources/templates/list") - if err != nil { - return nil, err - } - return result, nil -} - -func (c *SSEMCPClient) ListResourceTemplates( - ctx context.Context, - request mcp.ListResourceTemplatesRequest, -) (*mcp.ListResourceTemplatesResult, error) { - result, err := c.ListResourceTemplatesByPage(ctx, request) - if err != nil { - return nil, err - } - for result.NextCursor != "" { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - request.Params.Cursor = result.NextCursor - newPageRes, err := c.ListResourceTemplatesByPage(ctx, request) - if err != nil { - return nil, err - } - result.ResourceTemplates = append(result.ResourceTemplates, newPageRes.ResourceTemplates...) - result.NextCursor = newPageRes.NextCursor - } - } - return result, nil -} - -func (c *SSEMCPClient) ReadResource( - ctx context.Context, - request mcp.ReadResourceRequest, -) (*mcp.ReadResourceResult, error) { - response, err := c.sendRequest(ctx, "resources/read", request.Params) - if err != nil { - return nil, err - } - - return mcp.ParseReadResourceResult(response) -} - -func (c *SSEMCPClient) Subscribe( - ctx context.Context, - request mcp.SubscribeRequest, -) error { - _, err := c.sendRequest(ctx, "resources/subscribe", request.Params) - return err -} - -func (c *SSEMCPClient) Unsubscribe( - ctx context.Context, - request mcp.UnsubscribeRequest, -) error { - _, err := c.sendRequest(ctx, "resources/unsubscribe", request.Params) - return err -} - -func (c *SSEMCPClient) ListPromptsByPage( - ctx context.Context, - request mcp.ListPromptsRequest, -) (*mcp.ListPromptsResult, error) { - result, err := listByPage[mcp.ListPromptsResult](ctx, c, request.PaginatedRequest, "prompts/list") - if err != nil { - return nil, err - } - return result, nil -} - -func (c *SSEMCPClient) ListPrompts( - ctx context.Context, - request mcp.ListPromptsRequest, -) (*mcp.ListPromptsResult, error) { - result, err := c.ListPromptsByPage(ctx, request) - if err != nil { - return nil, err - } - for result.NextCursor != "" { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - request.Params.Cursor = result.NextCursor - newPageRes, err := c.ListPromptsByPage(ctx, request) - if err != nil { - return nil, err - } - result.Prompts = append(result.Prompts, newPageRes.Prompts...) - result.NextCursor = newPageRes.NextCursor - } - } - return result, nil -} - -func (c *SSEMCPClient) GetPrompt( - ctx context.Context, - request mcp.GetPromptRequest, -) (*mcp.GetPromptResult, error) { - response, err := c.sendRequest(ctx, "prompts/get", request.Params) + sseTransport, err := transport.NewSSE(baseURL, options...) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to create SSE transport: %w", err) } - return mcp.ParseGetPromptResult(response) + return NewClient(sseTransport), nil } -func (c *SSEMCPClient) ListToolsByPage( - ctx context.Context, - request mcp.ListToolsRequest, -) (*mcp.ListToolsResult, error) { - result, err := listByPage[mcp.ListToolsResult](ctx, c, request.PaginatedRequest, "tools/list") - if err != nil { - return nil, err - } - return result, nil -} - -func (c *SSEMCPClient) ListTools( - ctx context.Context, - request mcp.ListToolsRequest, -) (*mcp.ListToolsResult, error) { - result, err := c.ListToolsByPage(ctx, request) - if err != nil { - return nil, err - } - for result.NextCursor != "" { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - request.Params.Cursor = result.NextCursor - newPageRes, err := c.ListToolsByPage(ctx, request) - if err != nil { - return nil, err - } - result.Tools = append(result.Tools, newPageRes.Tools...) - result.NextCursor = newPageRes.NextCursor - } - } - return result, nil -} - -func (c *SSEMCPClient) CallTool( - ctx context.Context, - request mcp.CallToolRequest, -) (*mcp.CallToolResult, error) { - response, err := c.sendRequest(ctx, "tools/call", request.Params) - if err != nil { - return nil, err - } - - return mcp.ParseCallToolResult(response) -} - -func (c *SSEMCPClient) SetLevel( - ctx context.Context, - request mcp.SetLevelRequest, -) error { - _, err := c.sendRequest(ctx, "logging/setLevel", request.Params) - return err -} - -func (c *SSEMCPClient) Complete( - ctx context.Context, - request mcp.CompleteRequest, -) (*mcp.CompleteResult, error) { - response, err := c.sendRequest(ctx, "completion/complete", request.Params) - if err != nil { - return nil, err - } - - var result mcp.CompleteResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - return &result, nil -} - -// Helper methods - // GetEndpoint returns the current endpoint URL for the SSE connection. -func (c *SSEMCPClient) GetEndpoint() *url.URL { - return c.endpoint -} - -// Close shuts down the SSE client connection and cleans up any pending responses. -// Returns an error if the shutdown process fails. -func (c *SSEMCPClient) Close() error { - select { - case <-c.done: - return nil // Already closed - default: - close(c.done) - } - - // Clean up any pending responses - c.mu.Lock() - for _, ch := range c.responses { - close(ch) - } - c.responses = make(map[int64]chan RPCResponse) - c.mu.Unlock() - - return nil +// +// Note: This method only works with SSE transport, or it will panic. +func GetEndpoint(c *Client) *url.URL { + t := c.GetTransport() + sse := t.(*transport.SSE) + return sse.GetEndpoint() } diff --git a/client/sse_test.go b/client/sse_test.go index 366fbc51..8e3607f6 100644 --- a/client/sse_test.go +++ b/client/sse_test.go @@ -2,6 +2,7 @@ package client import ( "context" + "github.com/mark3labs/mcp-go/client/transport" "testing" "time" @@ -24,6 +25,13 @@ func TestSSEMCPClient(t *testing.T) { "test-tool", mcp.WithDescription("Test tool"), mcp.WithString("parameter-1", mcp.Description("A string tool parameter")), + mcp.WithToolAnnotation(mcp.ToolAnnotation{ + Title: "Test Tool Annotation Title", + ReadOnlyHint: true, + DestructiveHint: false, + IdempotentHint: true, + OpenWorldHint: false, + }), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{ Content: []mcp.Content{ @@ -46,7 +54,8 @@ func TestSSEMCPClient(t *testing.T) { } defer client.Close() - if client.baseURL == nil { + sseTransport := client.GetTransport().(*transport.SSE) + if sseTransport.GetBaseURL() == nil { t.Error("Base URL should not be nil") } }) @@ -93,10 +102,21 @@ func TestSSEMCPClient(t *testing.T) { // Test ListTools toolsRequest := mcp.ListToolsRequest{} - _, err = client.ListTools(ctx, toolsRequest) + toolListResult, err := client.ListTools(ctx, toolsRequest) if err != nil { t.Errorf("ListTools failed: %v", err) } + if toolListResult == nil || len((*toolListResult).Tools) == 0 { + t.Errorf("Expected one tool") + } + testToolAnnotations := (*toolListResult).Tools[0].Annotations + if testToolAnnotations.Title != "Test Tool Annotation Title" || + testToolAnnotations.ReadOnlyHint != true || + testToolAnnotations.DestructiveHint != false || + testToolAnnotations.IdempotentHint != true || + testToolAnnotations.OpenWorldHint != false { + t.Errorf("The annotations of the tools are invalid") + } }) // t.Run("Can handle notifications", func(t *testing.T) { diff --git a/client/stdio.go b/client/stdio.go index c9233492..a25f6d19 100644 --- a/client/stdio.go +++ b/client/stdio.go @@ -1,524 +1,40 @@ package client import ( - "bufio" "context" - "encoding/json" - "errors" "fmt" "io" - "os" - "os/exec" - "sync" - "sync/atomic" - "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/client/transport" ) -// StdioMCPClient implements the MCPClient interface using stdio communication. -// It launches a subprocess and communicates with it via standard input/output streams -// using JSON-RPC messages. The client handles message routing between requests and -// responses, and supports asynchronous notifications. -type StdioMCPClient struct { - cmd *exec.Cmd - stdin io.WriteCloser - stdout *bufio.Reader - stderr io.ReadCloser - requestID atomic.Int64 - responses map[int64]chan RPCResponse - mu sync.RWMutex - done chan struct{} - initialized bool - notifications []func(mcp.JSONRPCNotification) - notifyMu sync.RWMutex - capabilities mcp.ServerCapabilities -} - // NewStdioMCPClient creates a new stdio-based MCP client that communicates with a subprocess. // It launches the specified command with given arguments and sets up stdin/stdout pipes for communication. // Returns an error if the subprocess cannot be started or the pipes cannot be created. +// +// NOTICE: NewStdioMCPClient will start the connection automatically. Don't call the Start method manually. +// This is for backward compatibility. func NewStdioMCPClient( command string, env []string, args ...string, -) (*StdioMCPClient, error) { - cmd := exec.Command(command, args...) - - mergedEnv := os.Environ() - mergedEnv = append(mergedEnv, env...) - - cmd.Env = mergedEnv - - stdin, err := cmd.StdinPipe() - if err != nil { - return nil, fmt.Errorf("failed to create stdin pipe: %w", err) - } - - stdout, err := cmd.StdoutPipe() - if err != nil { - return nil, fmt.Errorf("failed to create stdout pipe: %w", err) - } +) (*Client, error) { - stderr, err := cmd.StderrPipe() + stdioTransport := transport.NewStdio(command, env, args...) + err := stdioTransport.Start(context.Background()) if err != nil { - return nil, fmt.Errorf("failed to create stderr pipe: %w", err) - } - - client := &StdioMCPClient{ - cmd: cmd, - stdin: stdin, - stderr: stderr, - stdout: bufio.NewReader(stdout), - responses: make(map[int64]chan RPCResponse), - done: make(chan struct{}), + return nil, fmt.Errorf("failed to start stdio transport: %w", err) } - if err := cmd.Start(); err != nil { - return nil, fmt.Errorf("failed to start command: %w", err) - } - - // Start reading responses in a goroutine and wait for it to be ready - ready := make(chan struct{}) - go func() { - close(ready) - client.readResponses() - }() - <-ready - - return client, nil -} - -// Close shuts down the stdio client, closing the stdin pipe and waiting for the subprocess to exit. -// Returns an error if there are issues closing stdin or waiting for the subprocess to terminate. -func (c *StdioMCPClient) Close() error { - close(c.done) - if err := c.stdin.Close(); err != nil { - return fmt.Errorf("failed to close stdin: %w", err) - } - if err := c.stderr.Close(); err != nil { - return fmt.Errorf("failed to close stderr: %w", err) - } - return c.cmd.Wait() + return NewClient(stdioTransport), nil } -// Stderr returns a reader for the stderr output of the subprocess. +// GetStderr returns a reader for the stderr output of the subprocess. // This can be used to capture error messages or logs from the subprocess. -func (c *StdioMCPClient) Stderr() io.Reader { - return c.stderr -} - -// OnNotification registers a handler function to be called when notifications are received. -// Multiple handlers can be registered and will be called in the order they were added. -func (c *StdioMCPClient) OnNotification( - handler func(notification mcp.JSONRPCNotification), -) { - c.notifyMu.Lock() - defer c.notifyMu.Unlock() - c.notifications = append(c.notifications, handler) -} - -// readResponses continuously reads and processes responses from the server's stdout. -// It handles both responses to requests and notifications, routing them appropriately. -// Runs until the done channel is closed or an error occurs reading from stdout. -func (c *StdioMCPClient) readResponses() { - for { - select { - case <-c.done: - return - default: - line, err := c.stdout.ReadString('\n') - if err != nil { - if err != io.EOF { - fmt.Printf("Error reading response: %v\n", err) - } - return - } - - var baseMessage struct { - JSONRPC string `json:"jsonrpc"` - ID *int64 `json:"id,omitempty"` - Method string `json:"method,omitempty"` - Result json.RawMessage `json:"result,omitempty"` - Error *struct { - Code int `json:"code"` - Message string `json:"message"` - } `json:"error,omitempty"` - } - - if err := json.Unmarshal([]byte(line), &baseMessage); err != nil { - continue - } - - // Handle notification - if baseMessage.ID == nil { - var notification mcp.JSONRPCNotification - if err := json.Unmarshal([]byte(line), ¬ification); err != nil { - continue - } - c.notifyMu.RLock() - for _, handler := range c.notifications { - handler(notification) - } - c.notifyMu.RUnlock() - continue - } - - c.mu.RLock() - ch, ok := c.responses[*baseMessage.ID] - c.mu.RUnlock() - - if ok { - if baseMessage.Error != nil { - ch <- RPCResponse{ - Error: &baseMessage.Error.Message, - } - } else { - ch <- RPCResponse{ - Response: &baseMessage.Result, - } - } - c.mu.Lock() - delete(c.responses, *baseMessage.ID) - c.mu.Unlock() - } - } - } -} - -// sendRequest sends a JSON-RPC request to the server and waits for a response. -// It creates a unique request ID, sends the request over stdin, and waits for -// the corresponding response or context cancellation. -// Returns the raw JSON response message or an error if the request fails. -func (c *StdioMCPClient) sendRequest( - ctx context.Context, - method string, - params interface{}, -) (*json.RawMessage, error) { - if !c.initialized && method != "initialize" { - return nil, fmt.Errorf("client not initialized") - } - - id := c.requestID.Add(1) - - // Create the complete request structure - request := mcp.JSONRPCRequest{ - JSONRPC: mcp.JSONRPC_VERSION, - ID: id, - Request: mcp.Request{ - Method: method, - }, - Params: params, - } - - responseChan := make(chan RPCResponse, 1) - c.mu.Lock() - c.responses[id] = responseChan - c.mu.Unlock() - - requestBytes, err := json.Marshal(request) - if err != nil { - return nil, fmt.Errorf("failed to marshal request: %w", err) - } - requestBytes = append(requestBytes, '\n') - - if _, err := c.stdin.Write(requestBytes); err != nil { - return nil, fmt.Errorf("failed to write request: %w", err) - } - - select { - case <-ctx.Done(): - c.mu.Lock() - delete(c.responses, id) - c.mu.Unlock() - return nil, ctx.Err() - case response := <-responseChan: - if response.Error != nil { - return nil, errors.New(*response.Error) - } - return response.Response, nil - } -} - -func (c *StdioMCPClient) Ping(ctx context.Context) error { - _, err := c.sendRequest(ctx, "ping", nil) - return err -} - -func (c *StdioMCPClient) Initialize( - ctx context.Context, - request mcp.InitializeRequest, -) (*mcp.InitializeResult, error) { - // This structure ensures Capabilities is always included in JSON - params := struct { - ProtocolVersion string `json:"protocolVersion"` - ClientInfo mcp.Implementation `json:"clientInfo"` - Capabilities mcp.ClientCapabilities `json:"capabilities"` - }{ - ProtocolVersion: request.Params.ProtocolVersion, - ClientInfo: request.Params.ClientInfo, - Capabilities: request.Params.Capabilities, // Will be empty struct if not set - } - - response, err := c.sendRequest(ctx, "initialize", params) - if err != nil { - return nil, err - } - - var result mcp.InitializeResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - // Store capabilities - c.capabilities = result.Capabilities - - // Send initialized notification - notification := mcp.JSONRPCNotification{ - JSONRPC: mcp.JSONRPC_VERSION, - Notification: mcp.Notification{ - Method: "notifications/initialized", - }, - } - - notificationBytes, err := json.Marshal(notification) - if err != nil { - return nil, fmt.Errorf( - "failed to marshal initialized notification: %w", - err, - ) - } - notificationBytes = append(notificationBytes, '\n') - - if _, err := c.stdin.Write(notificationBytes); err != nil { - return nil, fmt.Errorf( - "failed to send initialized notification: %w", - err, - ) - } - - c.initialized = true - return &result, nil -} - -// ListResourcesByPage manually list resources by page. -func (c *StdioMCPClient) ListResourcesByPage( - ctx context.Context, - request mcp.ListResourcesRequest, -) (*mcp.ListResourcesResult, error) { - result, err := listByPage[mcp.ListResourcesResult](ctx, c, request.PaginatedRequest, "resources/list") - if err != nil { - return nil, err - } - return result, nil -} - -func (c *StdioMCPClient) ListResources( - ctx context.Context, - request mcp.ListResourcesRequest, -) (*mcp.ListResourcesResult, error) { - result, err := c.ListResourcesByPage(ctx, request) - if err != nil { - return nil, err - } - for result.NextCursor != "" { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - request.Params.Cursor = result.NextCursor - newPageRes, err := c.ListResourcesByPage(ctx, request) - if err != nil { - return nil, err - } - result.Resources = append(result.Resources, newPageRes.Resources...) - result.NextCursor = newPageRes.NextCursor - } - } - return result, nil -} - -func (c *StdioMCPClient) ListResourceTemplatesByPage( - ctx context.Context, - request mcp.ListResourceTemplatesRequest, -) (*mcp.ListResourceTemplatesResult, error) { - result, err := listByPage[mcp.ListResourceTemplatesResult](ctx, c, request.PaginatedRequest, "resources/templates/list") - if err != nil { - return nil, err - } - return result, nil -} - -func (c *StdioMCPClient) ListResourceTemplates( - ctx context.Context, - request mcp.ListResourceTemplatesRequest, -) (*mcp.ListResourceTemplatesResult, error) { - result, err := c.ListResourceTemplatesByPage(ctx, request) - if err != nil { - return nil, err - } - for result.NextCursor != "" { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - request.Params.Cursor = result.NextCursor - newPageRes, err := c.ListResourceTemplatesByPage(ctx, request) - if err != nil { - return nil, err - } - result.ResourceTemplates = append(result.ResourceTemplates, newPageRes.ResourceTemplates...) - result.NextCursor = newPageRes.NextCursor - } - } - return result, nil -} - -func (c *StdioMCPClient) ReadResource( - ctx context.Context, - request mcp.ReadResourceRequest, -) (*mcp.ReadResourceResult, - error) { - response, err := c.sendRequest(ctx, "resources/read", request.Params) - if err != nil { - return nil, err - } - - return mcp.ParseReadResourceResult(response) -} - -func (c *StdioMCPClient) Subscribe( - ctx context.Context, - request mcp.SubscribeRequest, -) error { - _, err := c.sendRequest(ctx, "resources/subscribe", request.Params) - return err -} - -func (c *StdioMCPClient) Unsubscribe( - ctx context.Context, - request mcp.UnsubscribeRequest, -) error { - _, err := c.sendRequest(ctx, "resources/unsubscribe", request.Params) - return err -} - -func (c *StdioMCPClient) ListPromptsByPage( - ctx context.Context, - request mcp.ListPromptsRequest, -) (*mcp.ListPromptsResult, error) { - result, err := listByPage[mcp.ListPromptsResult](ctx, c, request.PaginatedRequest, "prompts/list") - if err != nil { - return nil, err - } - return result, nil -} - -func (c *StdioMCPClient) ListPrompts( - ctx context.Context, - request mcp.ListPromptsRequest, -) (*mcp.ListPromptsResult, error) { - result, err := c.ListPromptsByPage(ctx, request) - if err != nil { - return nil, err - } - for result.NextCursor != "" { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - request.Params.Cursor = result.NextCursor - newPageRes, err := c.ListPromptsByPage(ctx, request) - if err != nil { - return nil, err - } - result.Prompts = append(result.Prompts, newPageRes.Prompts...) - result.NextCursor = newPageRes.NextCursor - } - } - return result, nil -} - -func (c *StdioMCPClient) GetPrompt( - ctx context.Context, - request mcp.GetPromptRequest, -) (*mcp.GetPromptResult, error) { - response, err := c.sendRequest(ctx, "prompts/get", request.Params) - if err != nil { - return nil, err - } - - return mcp.ParseGetPromptResult(response) -} - -func (c *StdioMCPClient) ListToolsByPage( - ctx context.Context, - request mcp.ListToolsRequest, -) (*mcp.ListToolsResult, error) { - result, err := listByPage[mcp.ListToolsResult](ctx, c, request.PaginatedRequest, "tools/list") - if err != nil { - return nil, err - } - return result, nil -} - -func (c *StdioMCPClient) ListTools( - ctx context.Context, - request mcp.ListToolsRequest, -) (*mcp.ListToolsResult, error) { - result, err := c.ListToolsByPage(ctx, request) - if err != nil { - return nil, err - } - for result.NextCursor != "" { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - request.Params.Cursor = result.NextCursor - newPageRes, err := c.ListToolsByPage(ctx, request) - if err != nil { - return nil, err - } - result.Tools = append(result.Tools, newPageRes.Tools...) - result.NextCursor = newPageRes.NextCursor - } - } - return result, nil -} - -func (c *StdioMCPClient) CallTool( - ctx context.Context, - request mcp.CallToolRequest, -) (*mcp.CallToolResult, error) { - response, err := c.sendRequest(ctx, "tools/call", request.Params) - if err != nil { - return nil, err - } - - return mcp.ParseCallToolResult(response) -} - -func (c *StdioMCPClient) SetLevel( - ctx context.Context, - request mcp.SetLevelRequest, -) error { - _, err := c.sendRequest(ctx, "logging/setLevel", request.Params) - return err -} - -func (c *StdioMCPClient) Complete( - ctx context.Context, - request mcp.CompleteRequest, -) (*mcp.CompleteResult, error) { - response, err := c.sendRequest(ctx, "completion/complete", request.Params) - if err != nil { - return nil, err - } - - var result mcp.CompleteResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - return &result, nil +// +// Note: This method only works with stdio transport, or it will panic. +func GetStderr(c *Client) io.Reader { + t := c.GetTransport() + stdio := t.(*transport.Stdio) + return stdio.Stderr() } diff --git a/client/stdio_test.go b/client/stdio_test.go index df69b46a..94da0b54 100644 --- a/client/stdio_test.go +++ b/client/stdio_test.go @@ -47,7 +47,7 @@ func TestStdioMCPClient(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - dec := json.NewDecoder(client.Stderr()) + dec := json.NewDecoder(GetStderr(client)) for { var record map[string]any if err := dec.Decode(&record); err != nil { diff --git a/client/transport/interface.go b/client/transport/interface.go new file mode 100644 index 00000000..8ac75d74 --- /dev/null +++ b/client/transport/interface.go @@ -0,0 +1,45 @@ +package transport + +import ( + "context" + "encoding/json" + + "github.com/mark3labs/mcp-go/mcp" +) + +// Interface for the transport layer. +type Interface interface { + // Start the connection. Start should only be called once. + Start(ctx context.Context) error + + // SendRequest sends a json RPC request and returns the response synchronously. + SendRequest(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) + + // SendNotification sends a json RPC Notification to the server. + SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error + + // SetNotificationHandler sets the handler for notifications. + // Any notification before the handler is set will be discarded. + SetNotificationHandler(handler func(notification mcp.JSONRPCNotification)) + + // Close the connection. + Close() error +} + +type JSONRPCRequest struct { + JSONRPC string `json:"jsonrpc"` + ID int64 `json:"id"` + Method string `json:"method"` + Params any `json:"params,omitempty"` +} + +type JSONRPCResponse struct { + JSONRPC string `json:"jsonrpc"` + ID *int64 `json:"id"` + Result json.RawMessage `json:"result"` + Error *struct { + Code int `json:"code"` + Message string `json:"message"` + Data json.RawMessage `json:"data"` + } `json:"error"` +} diff --git a/client/transport/sse.go b/client/transport/sse.go new file mode 100644 index 00000000..a515ae76 --- /dev/null +++ b/client/transport/sse.go @@ -0,0 +1,376 @@ +package transport + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/mark3labs/mcp-go/mcp" +) + +// SSE implements the transport layer of the MCP protocol using Server-Sent Events (SSE). +// It maintains a persistent HTTP connection to receive server-pushed events +// while sending requests over regular HTTP POST calls. The client handles +// automatic reconnection and message routing between requests and responses. +type SSE struct { + baseURL *url.URL + endpoint *url.URL + httpClient *http.Client + responses map[int64]chan *JSONRPCResponse + mu sync.RWMutex + onNotification func(mcp.JSONRPCNotification) + notifyMu sync.RWMutex + endpointChan chan struct{} + headers map[string]string + + started atomic.Bool + closed atomic.Bool + cancelSSEStream context.CancelFunc +} + +type ClientOption func(*SSE) + +func WithHeaders(headers map[string]string) ClientOption { + return func(sc *SSE) { + sc.headers = headers + } +} + +// NewSSE creates a new SSE-based MCP client with the given base URL. +// Returns an error if the URL is invalid. +func NewSSE(baseURL string, options ...ClientOption) (*SSE, error) { + parsedURL, err := url.Parse(baseURL) + if err != nil { + return nil, fmt.Errorf("invalid URL: %w", err) + } + + smc := &SSE{ + baseURL: parsedURL, + httpClient: &http.Client{}, + responses: make(map[int64]chan *JSONRPCResponse), + endpointChan: make(chan struct{}), + headers: make(map[string]string), + } + + for _, opt := range options { + opt(smc) + } + + return smc, nil +} + +// Start initiates the SSE connection to the server and waits for the endpoint information. +// Returns an error if the connection fails or times out waiting for the endpoint. +func (c *SSE) Start(ctx context.Context) error { + + if c.started.Load() { + return fmt.Errorf("has already started") + } + + ctx, cancel := context.WithCancel(ctx) + c.cancelSSEStream = cancel + + req, err := http.NewRequestWithContext(ctx, "GET", c.baseURL.String(), nil) + + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("Cache-Control", "no-cache") + req.Header.Set("Connection", "keep-alive") + + // set custom http headers + for k, v := range c.headers { + req.Header.Set(k, v) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to connect to SSE stream: %w", err) + } + + if resp.StatusCode != http.StatusOK { + resp.Body.Close() + return fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + go c.readSSE(resp.Body) + + // Wait for the endpoint to be received + timeout := time.NewTimer(30 * time.Second) + defer timeout.Stop() + select { + case <-c.endpointChan: + // Endpoint received, proceed + case <-ctx.Done(): + return fmt.Errorf("context cancelled while waiting for endpoint") + case <-timeout.C: // Add a timeout + cancel() + return fmt.Errorf("timeout waiting for endpoint") + } + + c.started.Store(true) + return nil +} + +// readSSE continuously reads the SSE stream and processes events. +// It runs until the connection is closed or an error occurs. +func (c *SSE) readSSE(reader io.ReadCloser) { + defer reader.Close() + + br := bufio.NewReader(reader) + var event, data string + + for { + // when close or start's ctx cancel, the reader will be closed + // and the for loop will break. + line, err := br.ReadString('\n') + if err != nil { + if err == io.EOF { + // Process any pending event before exit + if event != "" && data != "" { + c.handleSSEEvent(event, data) + } + break + } + if !c.closed.Load() { + fmt.Printf("SSE stream error: %v\n", err) + } + return + } + + // Remove only newline markers + line = strings.TrimRight(line, "\r\n") + if line == "" { + // Empty line means end of event + if event != "" && data != "" { + c.handleSSEEvent(event, data) + event = "" + data = "" + } + continue + } + + if strings.HasPrefix(line, "event:") { + event = strings.TrimSpace(strings.TrimPrefix(line, "event:")) + } else if strings.HasPrefix(line, "data:") { + data = strings.TrimSpace(strings.TrimPrefix(line, "data:")) + } + } +} + +// handleSSEEvent processes SSE events based on their type. +// Handles 'endpoint' events for connection setup and 'message' events for JSON-RPC communication. +func (c *SSE) handleSSEEvent(event, data string) { + switch event { + case "endpoint": + endpoint, err := c.baseURL.Parse(data) + if err != nil { + fmt.Printf("Error parsing endpoint URL: %v\n", err) + return + } + if endpoint.Host != c.baseURL.Host { + fmt.Printf("Endpoint origin does not match connection origin\n") + return + } + c.endpoint = endpoint + close(c.endpointChan) + + case "message": + var baseMessage JSONRPCResponse + if err := json.Unmarshal([]byte(data), &baseMessage); err != nil { + fmt.Printf("Error unmarshaling message: %v\n", err) + return + } + + // Handle notification + if baseMessage.ID == nil { + var notification mcp.JSONRPCNotification + if err := json.Unmarshal([]byte(data), ¬ification); err != nil { + return + } + c.notifyMu.RLock() + if c.onNotification != nil { + c.onNotification(notification) + } + c.notifyMu.RUnlock() + return + } + + c.mu.RLock() + ch, ok := c.responses[*baseMessage.ID] + c.mu.RUnlock() + + if ok { + ch <- &baseMessage + c.mu.Lock() + delete(c.responses, *baseMessage.ID) + c.mu.Unlock() + } + } +} + +func (c *SSE) SetNotificationHandler(handler func(notification mcp.JSONRPCNotification)) { + c.notifyMu.Lock() + defer c.notifyMu.Unlock() + c.onNotification = handler +} + +// sendRequest sends a JSON-RPC request to the server and waits for a response. +// Returns the raw JSON response message or an error if the request fails. +func (c *SSE) SendRequest( + ctx context.Context, + request JSONRPCRequest, +) (*JSONRPCResponse, error) { + + if !c.started.Load() { + return nil, fmt.Errorf("transport not started yet") + } + if c.closed.Load() { + return nil, fmt.Errorf("transport has been closed") + } + if c.endpoint == nil { + return nil, fmt.Errorf("endpoint not received") + } + + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + responseChan := make(chan *JSONRPCResponse, 1) + c.mu.Lock() + c.responses[request.ID] = responseChan + c.mu.Unlock() + + req, err := http.NewRequestWithContext( + ctx, + "POST", + c.endpoint.String(), + bytes.NewReader(requestBytes), + ) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + // set custom http headers + for k, v := range c.headers { + req.Header.Set(k, v) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK && + resp.StatusCode != http.StatusAccepted { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf( + "request failed with status %d: %s", + resp.StatusCode, + body, + ) + } + + select { + case <-ctx.Done(): + c.mu.Lock() + delete(c.responses, request.ID) + c.mu.Unlock() + return nil, ctx.Err() + case response := <-responseChan: + return response, nil + } +} + +// Close shuts down the SSE client connection and cleans up any pending responses. +// Returns an error if the shutdown process fails. +func (c *SSE) Close() error { + if !c.closed.CompareAndSwap(false, true) { + return nil // Already closed + } + + if c.cancelSSEStream != nil { + // It could stop the sse stream body, to quit the readSSE loop immediately + // Also, it could quit start() immediately if not receiving the endpoint + c.cancelSSEStream() + } + + // Clean up any pending responses + c.mu.Lock() + for _, ch := range c.responses { + close(ch) + } + c.responses = make(map[int64]chan *JSONRPCResponse) + c.mu.Unlock() + + return nil +} + +// SendNotification sends a JSON-RPC notification to the server without expecting a response. +func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error { + if c.endpoint == nil { + return fmt.Errorf("endpoint not received") + } + + notificationBytes, err := json.Marshal(notification) + if err != nil { + return fmt.Errorf("failed to marshal notification: %w", err) + } + + req, err := http.NewRequestWithContext( + ctx, + "POST", + c.endpoint.String(), + bytes.NewReader(notificationBytes), + ) + if err != nil { + return fmt.Errorf("failed to create notification request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + // Set custom HTTP headers + for k, v := range c.headers { + req.Header.Set(k, v) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send notification: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf( + "notification failed with status %d: %s", + resp.StatusCode, + body, + ) + } + + return nil +} + +// GetEndpoint returns the current endpoint URL for the SSE connection. +func (c *SSE) GetEndpoint() *url.URL { + return c.endpoint +} + +// GetBaseURL returns the base URL set in the SSE constructor. +func (c *SSE) GetBaseURL() *url.URL { + return c.baseURL +} diff --git a/client/transport/sse_test.go b/client/transport/sse_test.go new file mode 100644 index 00000000..0c4dff6a --- /dev/null +++ b/client/transport/sse_test.go @@ -0,0 +1,480 @@ +package transport + +import ( + "context" + "encoding/json" + "errors" + "sync" + "testing" + "time" + + "fmt" + "net/http" + "net/http/httptest" + + "github.com/mark3labs/mcp-go/mcp" +) + +// startMockSSEEchoServer starts a test HTTP server that implements +// a minimal SSE-based echo server for testing purposes. +// It returns the server URL and a function to close the server. +func startMockSSEEchoServer() (string, func()) { + // Create handler for SSE endpoint + var sseWriter http.ResponseWriter + var flush func() + var mu sync.Mutex + sseHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Setup SSE headers + defer func() { + mu.Lock() // for passing race test + sseWriter = nil + flush = nil + mu.Unlock() + fmt.Printf("SSEHandler ends: %v\n", r.Context().Err()) + }() + + w.Header().Set("Content-Type", "text/event-stream") + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "Streaming unsupported", http.StatusInternalServerError) + return + } + + mu.Lock() + sseWriter = w + flush = flusher.Flush + mu.Unlock() + + // Send initial endpoint event with message endpoint URL + mu.Lock() + fmt.Fprintf(w, "event: endpoint\ndata: %s\n\n", "/message") + flusher.Flush() + mu.Unlock() + + // Keep connection open + <-r.Context().Done() + }) + + // Create handler for message endpoint + messageHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Handle only POST requests + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Parse incoming JSON-RPC request + var request map[string]interface{} + decoder := json.NewDecoder(r.Body) + if err := decoder.Decode(&request); err != nil { + http.Error(w, fmt.Sprintf("Invalid JSON: %v", err), http.StatusBadRequest) + return + } + + // Echo back the request as the response result + response := map[string]interface{}{ + "jsonrpc": "2.0", + "id": request["id"], + "result": request, + } + + method := request["method"] + switch method { + case "debug/echo": + response["result"] = request + case "debug/echo_notification": + response["result"] = request + // send notification to client + responseBytes, _ := json.Marshal(map[string]any{ + "jsonrpc": "2.0", + "method": "debug/test", + "params": request, + }) + mu.Lock() + fmt.Fprintf(sseWriter, "event: message\ndata: %s\n\n", responseBytes) + flush() + mu.Unlock() + case "debug/echo_error_string": + data, _ := json.Marshal(request) + response["error"] = map[string]interface{}{ + "code": -1, + "message": string(data), + } + } + + // Set response headers + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusAccepted) + + go func() { + data, _ := json.Marshal(response) + mu.Lock() + defer mu.Unlock() + if sseWriter != nil && flush != nil { + fmt.Fprintf(sseWriter, "event: message\ndata: %s\n\n", data) + flush() + } + }() + + }) + + // Create a router to handle different endpoints + mux := http.NewServeMux() + mux.Handle("/", sseHandler) + mux.Handle("/message", messageHandler) + + // Start test server + testServer := httptest.NewServer(mux) + + return testServer.URL, testServer.Close +} + +func TestSSE(t *testing.T) { + // Compile mock server + url, closeF := startMockSSEEchoServer() + defer closeF() + + trans, err := NewSSE(url) + if err != nil { + t.Fatal(err) + } + + // Start the transport + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + err = trans.Start(ctx) + if err != nil { + t.Fatalf("Failed to start transport: %v", err) + } + defer trans.Close() + + t.Run("SendRequest", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + params := map[string]interface{}{ + "string": "hello world", + "array": []interface{}{1, 2, 3}, + } + + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 1, + Method: "debug/echo", + Params: params, + } + + // Send the request + response, err := trans.SendRequest(ctx, request) + if err != nil { + t.Fatalf("SendRequest failed: %v", err) + } + + // Parse the result to verify echo + var result struct { + JSONRPC string `json:"jsonrpc"` + ID int64 `json:"id"` + Method string `json:"method"` + Params map[string]interface{} `json:"params"` + } + + if err := json.Unmarshal(response.Result, &result); err != nil { + t.Fatalf("Failed to unmarshal result: %v", err) + } + + // Verify response data matches what was sent + if result.JSONRPC != "2.0" { + t.Errorf("Expected JSONRPC value '2.0', got '%s'", result.JSONRPC) + } + if result.ID != 1 { + t.Errorf("Expected ID 1, got %d", result.ID) + } + if result.Method != "debug/echo" { + t.Errorf("Expected method 'debug/echo', got '%s'", result.Method) + } + + if str, ok := result.Params["string"].(string); !ok || str != "hello world" { + t.Errorf("Expected string 'hello world', got %v", result.Params["string"]) + } + + if arr, ok := result.Params["array"].([]interface{}); !ok || len(arr) != 3 { + t.Errorf("Expected array with 3 items, got %v", result.Params["array"]) + } + }) + + t.Run("SendRequestWithTimeout", func(t *testing.T) { + // Create a context that's already canceled + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel the context immediately + + // Prepare a request + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 3, + Method: "debug/echo", + } + + // The request should fail because the context is canceled + _, err := trans.SendRequest(ctx, request) + if err == nil { + t.Errorf("Expected context canceled error, got nil") + } else if !errors.Is(err, context.Canceled) { + t.Errorf("Expected context.Canceled error, got: %v", err) + } + }) + + t.Run("SendNotification & NotificationHandler", func(t *testing.T) { + + var wg sync.WaitGroup + notificationChan := make(chan mcp.JSONRPCNotification, 1) + + // Set notification handler + trans.SetNotificationHandler(func(notification mcp.JSONRPCNotification) { + notificationChan <- notification + }) + + // Send a notification + // This would trigger a notification from the server + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + notification := mcp.JSONRPCNotification{ + JSONRPC: "2.0", + Notification: mcp.Notification{ + Method: "debug/echo_notification", + Params: mcp.NotificationParams{ + AdditionalFields: map[string]interface{}{"test": "value"}, + }, + }, + } + err := trans.SendNotification(ctx, notification) + if err != nil { + t.Fatalf("SendNotification failed: %v", err) + } + + wg.Add(1) + go func() { + defer wg.Done() + select { + case nt := <-notificationChan: + // We received a notification + responseJson, _ := json.Marshal(nt.Params.AdditionalFields) + requestJson, _ := json.Marshal(notification) + if string(responseJson) != string(requestJson) { + t.Errorf("Notification handler did not send the expected notification: \ngot %s\nexpect %s", responseJson, requestJson) + } + + case <-time.After(1 * time.Second): + t.Errorf("Expected notification, got none") + } + }() + + wg.Wait() + }) + + t.Run("MultipleRequests", func(t *testing.T) { + var wg sync.WaitGroup + const numRequests = 5 + + // Send multiple requests concurrently + mu := sync.Mutex{} + responses := make([]*JSONRPCResponse, numRequests) + errors := make([]error, numRequests) + + for i := 0; i < numRequests; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Each request has a unique ID and payload + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: int64(100 + idx), + Method: "debug/echo", + Params: map[string]interface{}{ + "requestIndex": idx, + "timestamp": time.Now().UnixNano(), + }, + } + + resp, err := trans.SendRequest(ctx, request) + mu.Lock() + responses[idx] = resp + errors[idx] = err + mu.Unlock() + }(i) + } + + wg.Wait() + + // Check results + for i := 0; i < numRequests; i++ { + if errors[i] != nil { + t.Errorf("Request %d failed: %v", i, errors[i]) + continue + } + + if responses[i] == nil || responses[i].ID == nil || *responses[i].ID != int64(100+i) { + t.Errorf("Request %d: Expected ID %d, got %v", i, 100+i, responses[i]) + continue + } + + // Parse the result to verify echo + var result struct { + JSONRPC string `json:"jsonrpc"` + ID int64 `json:"id"` + Method string `json:"method"` + Params map[string]interface{} `json:"params"` + } + + if err := json.Unmarshal(responses[i].Result, &result); err != nil { + t.Errorf("Request %d: Failed to unmarshal result: %v", i, err) + continue + } + + // Verify data matches what was sent + if result.ID != int64(100+i) { + t.Errorf("Request %d: Expected echoed ID %d, got %d", i, 100+i, result.ID) + } + + if result.Method != "debug/echo" { + t.Errorf("Request %d: Expected method 'debug/echo', got '%s'", i, result.Method) + } + + // Verify the requestIndex parameter + if idx, ok := result.Params["requestIndex"].(float64); !ok || int(idx) != i { + t.Errorf("Request %d: Expected requestIndex %d, got %v", i, i, result.Params["requestIndex"]) + } + } + }) + + t.Run("ResponseError", func(t *testing.T) { + + // Prepare a request + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 100, + Method: "debug/echo_error_string", + } + + // The request should fail because the context is canceled + reps, err := trans.SendRequest(ctx, request) + if err != nil { + t.Errorf("SendRequest failed: %v", err) + } + + if reps.Error == nil { + t.Errorf("Expected error, got nil") + } + + var responseError JSONRPCRequest + if err := json.Unmarshal([]byte(reps.Error.Message), &responseError); err != nil { + t.Errorf("Failed to unmarshal result: %v", err) + } + + if responseError.Method != "debug/echo_error_string" { + t.Errorf("Expected method 'debug/echo_error_string', got '%s'", responseError.Method) + } + if responseError.ID != 100 { + t.Errorf("Expected ID 100, got %d", responseError.ID) + } + if responseError.JSONRPC != "2.0" { + t.Errorf("Expected JSONRPC '2.0', got '%s'", responseError.JSONRPC) + } + }) + +} + +func TestSSEErrors(t *testing.T) { + t.Run("InvalidURL", func(t *testing.T) { + // Create a new SSE transport with an invalid URL + _, err := NewSSE("://invalid-url") + if err == nil { + t.Errorf("Expected error when creating with invalid URL, got nil") + } + }) + + t.Run("NonExistentURL", func(t *testing.T) { + // Create a new SSE transport with a non-existent URL + sse, err := NewSSE("http://localhost:1") + if err != nil { + t.Fatalf("Failed to create SSE transport: %v", err) + } + + // Start should fail + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + err = sse.Start(ctx) + if err == nil { + t.Errorf("Expected error when starting with non-existent URL, got nil") + sse.Close() + } + }) + + t.Run("RequestBeforeStart", func(t *testing.T) { + url, closeF := startMockSSEEchoServer() + defer closeF() + + // Create a new SSE instance without calling Start method + sse, err := NewSSE(url) + if err != nil { + t.Fatalf("Failed to create SSE transport: %v", err) + } + + // Prepare a request + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 99, + Method: "ping", + } + + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + + _, err = sse.SendRequest(ctx, request) + if err == nil { + t.Errorf("Expected SendRequest to fail before Start(), but it didn't") + } + }) + + t.Run("RequestAfterClose", func(t *testing.T) { + // Start a mock server + url, closeF := startMockSSEEchoServer() + defer closeF() + + // Create a new SSE transport + sse, err := NewSSE(url) + if err != nil { + t.Fatalf("Failed to create SSE transport: %v", err) + } + + // Start the transport + ctx := context.Background() + if err := sse.Start(ctx); err != nil { + t.Fatalf("Failed to start SSE transport: %v", err) + } + + // Close the transport + sse.Close() + + // Wait a bit to ensure connection has closed + time.Sleep(100 * time.Millisecond) + + // Try to send a request after close + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 1, + Method: "ping", + } + + _, err = sse.SendRequest(ctx, request) + if err == nil { + t.Errorf("Expected error when sending request after close, got nil") + } + }) + +} diff --git a/client/transport/stdio.go b/client/transport/stdio.go new file mode 100644 index 00000000..85a300a1 --- /dev/null +++ b/client/transport/stdio.go @@ -0,0 +1,234 @@ +package transport + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "io" + "os" + "os/exec" + "sync" + + "github.com/mark3labs/mcp-go/mcp" +) + +// Stdio implements the transport layer of the MCP protocol using stdio communication. +// It launches a subprocess and communicates with it via standard input/output streams +// using JSON-RPC messages. The client handles message routing between requests and +// responses, and supports asynchronous notifications. +type Stdio struct { + command string + args []string + env []string + + cmd *exec.Cmd + stdin io.WriteCloser + stdout *bufio.Reader + stderr io.ReadCloser + responses map[int64]chan *JSONRPCResponse + mu sync.RWMutex + done chan struct{} + onNotification func(mcp.JSONRPCNotification) + notifyMu sync.RWMutex +} + +// NewStdio creates a new stdio transport to communicate with a subprocess. +// It launches the specified command with given arguments and sets up stdin/stdout pipes for communication. +// Returns an error if the subprocess cannot be started or the pipes cannot be created. +func NewStdio( + command string, + env []string, + args ...string, +) *Stdio { + + client := &Stdio{ + command: command, + args: args, + env: env, + + responses: make(map[int64]chan *JSONRPCResponse), + done: make(chan struct{}), + } + + return client +} + +func (c *Stdio) Start(ctx context.Context) error { + cmd := exec.CommandContext(ctx, c.command, c.args...) + + mergedEnv := os.Environ() + mergedEnv = append(mergedEnv, c.env...) + + cmd.Env = mergedEnv + + stdin, err := cmd.StdinPipe() + if err != nil { + return fmt.Errorf("failed to create stdin pipe: %w", err) + } + + stdout, err := cmd.StdoutPipe() + if err != nil { + return fmt.Errorf("failed to create stdout pipe: %w", err) + } + + stderr, err := cmd.StderrPipe() + if err != nil { + return fmt.Errorf("failed to create stderr pipe: %w", err) + } + + c.cmd = cmd + c.stdin = stdin + c.stderr = stderr + c.stdout = bufio.NewReader(stdout) + + if err := cmd.Start(); err != nil { + return fmt.Errorf("failed to start command: %w", err) + } + + // Start reading responses in a goroutine and wait for it to be ready + ready := make(chan struct{}) + go func() { + close(ready) + c.readResponses() + }() + <-ready + + return nil +} + +// Close shuts down the stdio client, closing the stdin pipe and waiting for the subprocess to exit. +// Returns an error if there are issues closing stdin or waiting for the subprocess to terminate. +func (c *Stdio) Close() error { + close(c.done) + if err := c.stdin.Close(); err != nil { + return fmt.Errorf("failed to close stdin: %w", err) + } + if err := c.stderr.Close(); err != nil { + return fmt.Errorf("failed to close stderr: %w", err) + } + return c.cmd.Wait() +} + +// OnNotification registers a handler function to be called when notifications are received. +// Multiple handlers can be registered and will be called in the order they were added. +func (c *Stdio) SetNotificationHandler( + handler func(notification mcp.JSONRPCNotification), +) { + c.notifyMu.Lock() + defer c.notifyMu.Unlock() + c.onNotification = handler +} + +// readResponses continuously reads and processes responses from the server's stdout. +// It handles both responses to requests and notifications, routing them appropriately. +// Runs until the done channel is closed or an error occurs reading from stdout. +func (c *Stdio) readResponses() { + for { + select { + case <-c.done: + return + default: + line, err := c.stdout.ReadString('\n') + if err != nil { + if err != io.EOF { + fmt.Printf("Error reading response: %v\n", err) + } + return + } + + var baseMessage JSONRPCResponse + if err := json.Unmarshal([]byte(line), &baseMessage); err != nil { + continue + } + + // Handle notification + if baseMessage.ID == nil { + var notification mcp.JSONRPCNotification + if err := json.Unmarshal([]byte(line), ¬ification); err != nil { + continue + } + c.notifyMu.RLock() + if c.onNotification != nil { + c.onNotification(notification) + } + c.notifyMu.RUnlock() + continue + } + + c.mu.RLock() + ch, ok := c.responses[*baseMessage.ID] + c.mu.RUnlock() + + if ok { + ch <- &baseMessage + c.mu.Lock() + delete(c.responses, *baseMessage.ID) + c.mu.Unlock() + } + } + } +} + +// sendRequest sends a JSON-RPC request to the server and waits for a response. +// It creates a unique request ID, sends the request over stdin, and waits for +// the corresponding response or context cancellation. +// Returns the raw JSON response message or an error if the request fails. +func (c *Stdio) SendRequest( + ctx context.Context, + request JSONRPCRequest, +) (*JSONRPCResponse, error) { + if c.stdin == nil { + return nil, fmt.Errorf("stdio client not started") + } + + // Create the complete request structure + responseChan := make(chan *JSONRPCResponse, 1) + c.mu.Lock() + c.responses[request.ID] = responseChan + c.mu.Unlock() + + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + requestBytes = append(requestBytes, '\n') + + if _, err := c.stdin.Write(requestBytes); err != nil { + return nil, fmt.Errorf("failed to write request: %w", err) + } + + select { + case <-ctx.Done(): + c.mu.Lock() + delete(c.responses, request.ID) + c.mu.Unlock() + return nil, ctx.Err() + case response := <-responseChan: + return response, nil + } +} + +// SendNotification sends a json RPC Notification to the server. +func (c *Stdio) SendNotification( + ctx context.Context, + notification mcp.JSONRPCNotification, +) error { + notificationBytes, err := json.Marshal(notification) + if err != nil { + return fmt.Errorf("failed to marshal notification: %w", err) + } + notificationBytes = append(notificationBytes, '\n') + + if _, err := c.stdin.Write(notificationBytes); err != nil { + return fmt.Errorf("failed to write notification: %w", err) + } + + return nil +} + +// Stderr returns a reader for the stderr output of the subprocess. +// This can be used to capture error messages or logs from the subprocess. +func (c *Stdio) Stderr() io.Reader { + return c.stderr +} diff --git a/client/transport/stdio_test.go b/client/transport/stdio_test.go new file mode 100644 index 00000000..aa728ec6 --- /dev/null +++ b/client/transport/stdio_test.go @@ -0,0 +1,379 @@ +package transport + +import ( + "context" + "encoding/json" + "fmt" + "os" + "os/exec" + "path/filepath" + "runtime" + "sync" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" +) + +func compileTestServer(outputPath string) error { + cmd := exec.Command( + "go", + "build", + "-o", + outputPath, + "../../testdata/mockstdio_server.go", + ) + if output, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("compilation failed: %v\nOutput: %s", err, output) + } + return nil +} + +func TestStdio(t *testing.T) { + // Compile mock server + mockServerPath := filepath.Join(os.TempDir(), "mockstdio_server") + // Add .exe suffix on Windows + if runtime.GOOS == "windows" { + mockServerPath += ".exe" + } + if err := compileTestServer(mockServerPath); err != nil { + t.Fatalf("Failed to compile mock server: %v", err) + } + defer os.Remove(mockServerPath) + + // Create a new Stdio transport + stdio := NewStdio(mockServerPath, nil) + + // Start the transport + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + err := stdio.Start(ctx) + if err != nil { + t.Fatalf("Failed to start Stdio transport: %v", err) + } + defer stdio.Close() + + t.Run("SendRequest", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5000000000*time.Second) + defer cancel() + + params := map[string]interface{}{ + "string": "hello world", + "array": []interface{}{1, 2, 3}, + } + + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 1, + Method: "debug/echo", + Params: params, + } + + // Send the request + response, err := stdio.SendRequest(ctx, request) + if err != nil { + t.Fatalf("SendRequest failed: %v", err) + } + + // Parse the result to verify echo + var result struct { + JSONRPC string `json:"jsonrpc"` + ID int64 `json:"id"` + Method string `json:"method"` + Params map[string]interface{} `json:"params"` + } + + if err := json.Unmarshal(response.Result, &result); err != nil { + t.Fatalf("Failed to unmarshal result: %v", err) + } + + // Verify response data matches what was sent + if result.JSONRPC != "2.0" { + t.Errorf("Expected JSONRPC value '2.0', got '%s'", result.JSONRPC) + } + if result.ID != 1 { + t.Errorf("Expected ID 1, got %d", result.ID) + } + if result.Method != "debug/echo" { + t.Errorf("Expected method 'debug/echo', got '%s'", result.Method) + } + + if str, ok := result.Params["string"].(string); !ok || str != "hello world" { + t.Errorf("Expected string 'hello world', got %v", result.Params["string"]) + } + + if arr, ok := result.Params["array"].([]interface{}); !ok || len(arr) != 3 { + t.Errorf("Expected array with 3 items, got %v", result.Params["array"]) + } + }) + + t.Run("SendRequestWithTimeout", func(t *testing.T) { + // Create a context that's already canceled + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel the context immediately + + // Prepare a request + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 3, + Method: "debug/echo", + } + + // The request should fail because the context is canceled + _, err := stdio.SendRequest(ctx, request) + if err == nil { + t.Errorf("Expected context canceled error, got nil") + } else if err != context.Canceled { + t.Errorf("Expected context.Canceled error, got: %v", err) + } + }) + + t.Run("SendNotification & NotificationHandler", func(t *testing.T) { + + var wg sync.WaitGroup + notificationChan := make(chan mcp.JSONRPCNotification, 1) + + // Set notification handler + stdio.SetNotificationHandler(func(notification mcp.JSONRPCNotification) { + notificationChan <- notification + }) + + // Send a notification + // This would trigger a notification from the server + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + notification := mcp.JSONRPCNotification{ + JSONRPC: "2.0", + Notification: mcp.Notification{ + Method: "debug/echo_notification", + Params: mcp.NotificationParams{ + AdditionalFields: map[string]interface{}{"test": "value"}, + }, + }, + } + err := stdio.SendNotification(ctx, notification) + if err != nil { + t.Fatalf("SendNotification failed: %v", err) + } + + wg.Add(1) + go func() { + defer wg.Done() + select { + case nt := <-notificationChan: + // We received a notification + responseJson, _ := json.Marshal(nt.Params.AdditionalFields) + requestJson, _ := json.Marshal(notification) + if string(responseJson) != string(requestJson) { + t.Errorf("Notification handler did not send the expected notification: \ngot %s\nexpect %s", responseJson, requestJson) + } + + case <-time.After(1 * time.Second): + t.Errorf("Expected notification, got none") + } + }() + + wg.Wait() + }) + + t.Run("MultipleRequests", func(t *testing.T) { + var wg sync.WaitGroup + const numRequests = 5 + + // Send multiple requests concurrently + responses := make([]*JSONRPCResponse, numRequests) + errors := make([]error, numRequests) + mu := sync.Mutex{} + for i := 0; i < numRequests; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Each request has a unique ID and payload + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: int64(100 + idx), + Method: "debug/echo", + Params: map[string]interface{}{ + "requestIndex": idx, + "timestamp": time.Now().UnixNano(), + }, + } + + resp, err := stdio.SendRequest(ctx, request) + mu.Lock() + responses[idx] = resp + errors[idx] = err + mu.Unlock() + }(i) + } + + wg.Wait() + + // Check results + for i := 0; i < numRequests; i++ { + if errors[i] != nil { + t.Errorf("Request %d failed: %v", i, errors[i]) + continue + } + + if responses[i] == nil || responses[i].ID == nil || *responses[i].ID != int64(100+i) { + t.Errorf("Request %d: Expected ID %d, got %v", i, 100+i, responses[i]) + continue + } + + // Parse the result to verify echo + var result struct { + JSONRPC string `json:"jsonrpc"` + ID int64 `json:"id"` + Method string `json:"method"` + Params map[string]interface{} `json:"params"` + } + + if err := json.Unmarshal(responses[i].Result, &result); err != nil { + t.Errorf("Request %d: Failed to unmarshal result: %v", i, err) + continue + } + + // Verify data matches what was sent + if result.ID != int64(100+i) { + t.Errorf("Request %d: Expected echoed ID %d, got %d", i, 100+i, result.ID) + } + + if result.Method != "debug/echo" { + t.Errorf("Request %d: Expected method 'debug/echo', got '%s'", i, result.Method) + } + + // Verify the requestIndex parameter + if idx, ok := result.Params["requestIndex"].(float64); !ok || int(idx) != i { + t.Errorf("Request %d: Expected requestIndex %d, got %v", i, i, result.Params["requestIndex"]) + } + } + }) + + t.Run("ResponseError", func(t *testing.T) { + + // Prepare a request + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 100, + Method: "debug/echo_error_string", + } + + // The request should fail because the context is canceled + reps, err := stdio.SendRequest(ctx, request) + if err != nil { + t.Errorf("SendRequest failed: %v", err) + } + + if reps.Error == nil { + t.Errorf("Expected error, got nil") + } + + var responseError JSONRPCRequest + if err := json.Unmarshal([]byte(reps.Error.Message), &responseError); err != nil { + t.Errorf("Failed to unmarshal result: %v", err) + } + + if responseError.Method != "debug/echo_error_string" { + t.Errorf("Expected method 'debug/echo_error_string', got '%s'", responseError.Method) + } + if responseError.ID != 100 { + t.Errorf("Expected ID 100, got %d", responseError.ID) + } + if responseError.JSONRPC != "2.0" { + t.Errorf("Expected JSONRPC '2.0', got '%s'", responseError.JSONRPC) + } + }) + +} + +func TestStdioErrors(t *testing.T) { + t.Run("InvalidCommand", func(t *testing.T) { + // Create a new Stdio transport with a non-existent command + stdio := NewStdio("non_existent_command", nil) + + // Start should fail + ctx := context.Background() + err := stdio.Start(ctx) + if err == nil { + t.Errorf("Expected error when starting with invalid command, got nil") + stdio.Close() + } + }) + + t.Run("RequestBeforeStart", func(t *testing.T) { + mockServerPath := filepath.Join(os.TempDir(), "mockstdio_server") + // Add .exe suffix on Windows + if runtime.GOOS == "windows" { + mockServerPath += ".exe" + } + if err := compileTestServer(mockServerPath); err != nil { + t.Fatalf("Failed to compile mock server: %v", err) + } + defer os.Remove(mockServerPath) + + uninitiatedStdio := NewStdio(mockServerPath, nil) + + // Prepare a request + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 99, + Method: "ping", + } + + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + _, err := uninitiatedStdio.SendRequest(ctx, request) + if err == nil { + t.Errorf("Expected SendRequest to panic before Start(), but it didn't") + } else if err.Error() != "stdio client not started" { + t.Errorf("Expected error 'stdio client not started', got: %v", err) + } + }) + + t.Run("RequestAfterClose", func(t *testing.T) { + // Compile mock server + mockServerPath := filepath.Join(os.TempDir(), "mockstdio_server") + // Add .exe suffix on Windows + if runtime.GOOS == "windows" { + mockServerPath += ".exe" + } + if err := compileTestServer(mockServerPath); err != nil { + t.Fatalf("Failed to compile mock server: %v", err) + } + defer os.Remove(mockServerPath) + + // Create a new Stdio transport + stdio := NewStdio(mockServerPath, nil) + + // Start the transport + ctx := context.Background() + if err := stdio.Start(ctx); err != nil { + t.Fatalf("Failed to start Stdio transport: %v", err) + } + + // Close the transport - ignore errors like "broken pipe" since the process might exit already + stdio.Close() + + // Wait a bit to ensure process has exited + time.Sleep(100 * time.Millisecond) + + // Try to send a request after close + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 1, + Method: "ping", + } + + _, err := stdio.SendRequest(ctx, request) + if err == nil { + t.Errorf("Expected error when sending request after close, got nil") + } + }) + +} diff --git a/client/transport/streamable_http.go b/client/transport/streamable_http.go new file mode 100644 index 00000000..4bc60a25 --- /dev/null +++ b/client/transport/streamable_http.go @@ -0,0 +1,387 @@ +package transport + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/mark3labs/mcp-go/mcp" +) + +type StreamableHTTPCOption func(*StreamableHTTP) + +func WithHTTPHeaders(headers map[string]string) StreamableHTTPCOption { + return func(sc *StreamableHTTP) { + sc.headers = headers + } +} + +// WithHTTPTimeout sets the timeout for a HTTP request and stream. +func WithHTTPTimeout(timeout time.Duration) StreamableHTTPCOption { + return func(sc *StreamableHTTP) { + sc.httpClient.Timeout = timeout + } +} + +// StreamableHTTP implements Streamable HTTP transport. +// +// It transmits JSON-RPC messages over individual HTTP requests. One message per request. +// The HTTP response body can either be a single JSON-RPC response, +// or an upgraded SSE stream that concludes with a JSON-RPC response for the same request. +// +// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports +// +// The current implementation does not support the following features: +// - batching +// - continuously listening for server notifications when no request is in flight +// (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server) +// - resuming stream +// (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#resumability-and-redelivery) +// - server -> client request +type StreamableHTTP struct { + baseURL *url.URL + httpClient *http.Client + headers map[string]string + + sessionID atomic.Value // string + + notificationHandler func(mcp.JSONRPCNotification) + notifyMu sync.RWMutex + + closed chan struct{} +} + +// NewStreamableHTTP creates a new Streamable HTTP transport with the given base URL. +// Returns an error if the URL is invalid. +func NewStreamableHTTP(baseURL string, options ...StreamableHTTPCOption) (*StreamableHTTP, error) { + parsedURL, err := url.Parse(baseURL) + if err != nil { + return nil, fmt.Errorf("invalid URL: %w", err) + } + + smc := &StreamableHTTP{ + baseURL: parsedURL, + httpClient: &http.Client{}, + headers: make(map[string]string), + closed: make(chan struct{}), + } + smc.sessionID.Store("") // set initial value to simplify later usage + + for _, opt := range options { + opt(smc) + } + + return smc, nil +} + +// Start initiates the HTTP connection to the server. +func (c *StreamableHTTP) Start(ctx context.Context) error { + // For Streamable HTTP, we don't need to establish a persistent connection + return nil +} + +// Close closes the all the HTTP connections to the server. +func (c *StreamableHTTP) Close() error { + select { + case <-c.closed: + return nil + default: + } + // Cancel all in-flight requests + close(c.closed) + + sessionId := c.sessionID.Load().(string) + if sessionId != "" { + c.sessionID.Store("") + + // notify server session closed + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + req, err := http.NewRequestWithContext(ctx, http.MethodDelete, c.baseURL.String(), nil) + if err != nil { + fmt.Printf("failed to create close request\n: %v", err) + return + } + req.Header.Set(headerKeySessionID, sessionId) + res, err := c.httpClient.Do(req) + if err != nil { + fmt.Printf("failed to send close request\n: %v", err) + return + } + res.Body.Close() + }() + } + + return nil +} + +const ( + initializeMethod = "initialize" + headerKeySessionID = "Mcp-Session-Id" +) + +// sendRequest sends a JSON-RPC request to the server and waits for a response. +// Returns the raw JSON response message or an error if the request fails. +func (c *StreamableHTTP) SendRequest( + ctx context.Context, + request JSONRPCRequest, +) (*JSONRPCResponse, error) { + + // Create a combined context that could be canceled when the client is closed + newCtx, cancel := context.WithCancel(ctx) + defer cancel() + go func() { + select { + case <-c.closed: + cancel() + case <-newCtx.Done(): + // The original context was canceled, no need to do anything + } + }() + ctx = newCtx + + // Marshal request + requestBody, err := json.Marshal(request) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + // Create HTTP request + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL.String(), bytes.NewReader(requestBody)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + // Set headers + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") + sessionID := c.sessionID.Load() + if sessionID != "" { + req.Header.Set(headerKeySessionID, sessionID.(string)) + } + for k, v := range c.headers { + req.Header.Set(k, v) + } + + // Send request + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + // Check if we got an error response + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted { + // handle session closed + if resp.StatusCode == http.StatusNotFound { + c.sessionID.CompareAndSwap(sessionID, "") + return nil, fmt.Errorf("session terminated (404). need to re-initialize") + } + + // handle error response + var errResponse JSONRPCResponse + body, _ := io.ReadAll(resp.Body) + if err := json.Unmarshal(body, &errResponse); err == nil { + return &errResponse, nil + } + return nil, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, body) + } + + if request.Method == initializeMethod { + // saved the received session ID in the response + // empty session ID is allowed + if sessionID := resp.Header.Get(headerKeySessionID); sessionID != "" { + c.sessionID.Store(sessionID) + } + } + + // Handle different response types + switch resp.Header.Get("Content-Type") { + case "application/json": + // Single response + var response JSONRPCResponse + if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + // should not be a notification + if response.ID == nil { + return nil, fmt.Errorf("response should contain RPC id: %v", response) + } + + return &response, nil + + case "text/event-stream": + // Server is using SSE for streaming responses + return c.handleSSEResponse(ctx, resp.Body) + + default: + return nil, fmt.Errorf("unexpected content type: %s", resp.Header.Get("Content-Type")) + } +} + +// handleSSEResponse processes an SSE stream for a specific request. +// It returns the final result for the request once received, or an error. +func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCloser) (*JSONRPCResponse, error) { + + // Create a channel for this specific request + responseChan := make(chan *JSONRPCResponse, 1) + defer close(responseChan) + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + // Start a goroutine to process the SSE stream + go c.readSSE(ctx, reader, func(event, data string) { + + // (unsupported: batching) + + var message JSONRPCResponse + if err := json.Unmarshal([]byte(data), &message); err != nil { + fmt.Printf("failed to unmarshal message: %v\n", err) + return + } + + // Handle notification + if message.ID == nil { + var notification mcp.JSONRPCNotification + if err := json.Unmarshal([]byte(data), ¬ification); err != nil { + fmt.Printf("failed to unmarshal notification: %v\n", err) + return + } + c.notifyMu.RLock() + if c.notificationHandler != nil { + c.notificationHandler(notification) + } + c.notifyMu.RUnlock() + return + } + + responseChan <- &message + }) + + // Wait for the response or context cancellation + select { + case response := <-responseChan: + if response == nil { + return nil, fmt.Errorf("unexpected nil response") + } + return response, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +// readSSE reads the SSE stream(reader) and calls the handler for each event and data pair. +// It will end when the reader is closed (or the context is done). +func (c *StreamableHTTP) readSSE(ctx context.Context, reader io.ReadCloser, handler func(event, data string)) { + defer reader.Close() + + br := bufio.NewReader(reader) + var event, data string + + for { + select { + case <-ctx.Done(): + return + default: + line, err := br.ReadString('\n') + if err != nil { + if err == io.EOF { + // Process any pending event before exit + if event != "" && data != "" { + handler(event, data) + } + return + } + select { + case <-ctx.Done(): + return + default: + fmt.Printf("SSE stream error: %v\n", err) + return + } + } + + // Remove only newline markers + line = strings.TrimRight(line, "\r\n") + if line == "" { + // Empty line means end of event + if event != "" && data != "" { + handler(event, data) + event = "" + data = "" + } + continue + } + + if strings.HasPrefix(line, "event:") { + event = strings.TrimSpace(strings.TrimPrefix(line, "event:")) + } else if strings.HasPrefix(line, "data:") { + data = strings.TrimSpace(strings.TrimPrefix(line, "data:")) + } + } + } +} + +func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error { + + // Marshal request + requestBody, err := json.Marshal(notification) + if err != nil { + return fmt.Errorf("failed to marshal notification: %w", err) + } + + // Create HTTP request + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL.String(), bytes.NewReader(requestBody)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + // Set headers + req.Header.Set("Content-Type", "application/json") + if sessionID := c.sessionID.Load(); sessionID != "" { + req.Header.Set(headerKeySessionID, sessionID.(string)) + } + for k, v := range c.headers { + req.Header.Set(k, v) + } + + // Send request + resp, err := c.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf( + "notification failed with status %d: %s", + resp.StatusCode, + body, + ) + } + + return nil +} + +func (c *StreamableHTTP) SetNotificationHandler(handler func(mcp.JSONRPCNotification)) { + c.notifyMu.Lock() + defer c.notifyMu.Unlock() + c.notificationHandler = handler +} + +func (c *StreamableHTTP) GetSessionId() string { + return c.sessionID.Load().(string) +} diff --git a/client/transport/streamable_http_test.go b/client/transport/streamable_http_test.go new file mode 100644 index 00000000..b7b76b96 --- /dev/null +++ b/client/transport/streamable_http_test.go @@ -0,0 +1,425 @@ +package transport + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" +) + +// startMockStreamableHTTPServer starts a test HTTP server that implements +// a minimal Streamable HTTP server for testing purposes. +// It returns the server URL and a function to close the server. +func startMockStreamableHTTPServer() (string, func()) { + var sessionID string + var mu sync.Mutex + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Handle only POST requests + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Parse incoming JSON-RPC request + var request map[string]any + decoder := json.NewDecoder(r.Body) + if err := decoder.Decode(&request); err != nil { + http.Error(w, fmt.Sprintf("Invalid JSON: %v", err), http.StatusBadRequest) + return + } + + method := request["method"] + switch method { + case "initialize": + // Generate a new session ID + mu.Lock() + sessionID = fmt.Sprintf("test-session-%d", time.Now().UnixNano()) + mu.Unlock() + w.Header().Set("Mcp-Session-Id", sessionID) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusAccepted) + json.NewEncoder(w).Encode(map[string]interface{}{ + "jsonrpc": "2.0", + "id": request["id"], + "result": "initialized", + }) + + case "debug/echo": + // Check session ID + if r.Header.Get("Mcp-Session-Id") != sessionID { + http.Error(w, "Invalid session ID", http.StatusNotFound) + return + } + + // Echo back the request as the response result + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]interface{}{ + "jsonrpc": "2.0", + "id": request["id"], + "result": request, + }) + + case "debug/echo_notification": + // Check session ID + if r.Header.Get("Mcp-Session-Id") != sessionID { + http.Error(w, "Invalid session ID", http.StatusNotFound) + return + } + + // Send response and notification + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + notification := map[string]any{ + "jsonrpc": "2.0", + "method": "debug/test", + "params": request, + } + notificationData, _ := json.Marshal(notification) + fmt.Fprintf(w, "event: message\ndata: %s\n\n", notificationData) + response := map[string]any{ + "jsonrpc": "2.0", + "id": request["id"], + "result": request, + } + responseData, _ := json.Marshal(response) + fmt.Fprintf(w, "event: message\ndata: %s\n\n", responseData) + + case "debug/echo_error_string": + // Check session ID + if r.Header.Get("Mcp-Session-Id") != sessionID { + http.Error(w, "Invalid session ID", http.StatusNotFound) + return + } + + // Return an error response + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(request) + json.NewEncoder(w).Encode(map[string]interface{}{ + "jsonrpc": "2.0", + "id": request["id"], + "error": map[string]interface{}{ + "code": -1, + "message": string(data), + }, + }) + } + }) + + // Start test server + testServer := httptest.NewServer(handler) + return testServer.URL, testServer.Close +} + +func TestStreamableHTTP(t *testing.T) { + // Start mock server + url, closeF := startMockStreamableHTTPServer() + defer closeF() + + // Create transport + trans, err := NewStreamableHTTP(url) + if err != nil { + t.Fatal(err) + } + defer trans.Close() + + // Initialize the transport first + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + initRequest := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 1, + Method: "initialize", + } + + _, err = trans.SendRequest(ctx, initRequest) + if err != nil { + t.Fatal(err) + } + + // Now run the tests + t.Run("SendRequest", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + params := map[string]interface{}{ + "string": "hello world", + "array": []interface{}{1, 2, 3}, + } + + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 1, + Method: "debug/echo", + Params: params, + } + + // Send the request + response, err := trans.SendRequest(ctx, request) + if err != nil { + t.Fatalf("SendRequest failed: %v", err) + } + + // Parse the result to verify echo + var result struct { + JSONRPC string `json:"jsonrpc"` + ID int64 `json:"id"` + Method string `json:"method"` + Params map[string]interface{} `json:"params"` + } + + if err := json.Unmarshal(response.Result, &result); err != nil { + t.Fatalf("Failed to unmarshal result: %v", err) + } + + // Verify response data matches what was sent + if result.JSONRPC != "2.0" { + t.Errorf("Expected JSONRPC value '2.0', got '%s'", result.JSONRPC) + } + if result.ID != 1 { + t.Errorf("Expected ID 1, got %d", result.ID) + } + if result.Method != "debug/echo" { + t.Errorf("Expected method 'debug/echo', got '%s'", result.Method) + } + + if str, ok := result.Params["string"].(string); !ok || str != "hello world" { + t.Errorf("Expected string 'hello world', got %v", result.Params["string"]) + } + + if arr, ok := result.Params["array"].([]interface{}); !ok || len(arr) != 3 { + t.Errorf("Expected array with 3 items, got %v", result.Params["array"]) + } + }) + + t.Run("SendRequestWithTimeout", func(t *testing.T) { + // Create a context that's already canceled + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel the context immediately + + // Prepare a request + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 3, + Method: "debug/echo", + } + + // The request should fail because the context is canceled + _, err := trans.SendRequest(ctx, request) + if err == nil { + t.Errorf("Expected context canceled error, got nil") + } else if !errors.Is(err, context.Canceled) { + t.Errorf("Expected context.Canceled error, got: %v", err) + } + }) + + t.Run("SendNotification & NotificationHandler", func(t *testing.T) { + var wg sync.WaitGroup + notificationChan := make(chan mcp.JSONRPCNotification, 1) + + // Set notification handler + trans.SetNotificationHandler(func(notification mcp.JSONRPCNotification) { + notificationChan <- notification + }) + + // Send a request that triggers a notification + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 1, + Method: "debug/echo_notification", + } + + _, err := trans.SendRequest(ctx, request) + if err != nil { + t.Fatalf("SendRequest failed: %v", err) + } + + wg.Add(1) + go func() { + defer wg.Done() + select { + case notification := <-notificationChan: + // We received a notification + got := notification.Params.AdditionalFields + if got == nil { + t.Errorf("Notification handler did not send the expected notification: got nil") + } + if int64(got["id"].(float64)) != request.ID || + got["jsonrpc"] != request.JSONRPC || + got["method"] != request.Method { + + responseJson, _ := json.Marshal(got) + requestJson, _ := json.Marshal(request) + t.Errorf("Notification handler did not send the expected notification: \ngot %s\nexpect %s", responseJson, requestJson) + } + + case <-time.After(1 * time.Second): + t.Errorf("Expected notification, got none") + } + }() + + wg.Wait() + }) + + t.Run("MultipleRequests", func(t *testing.T) { + var wg sync.WaitGroup + const numRequests = 5 + + // Send multiple requests concurrently + mu := sync.Mutex{} + responses := make([]*JSONRPCResponse, numRequests) + errors := make([]error, numRequests) + + for i := 0; i < numRequests; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Each request has a unique ID and payload + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: int64(100 + idx), + Method: "debug/echo", + Params: map[string]interface{}{ + "requestIndex": idx, + "timestamp": time.Now().UnixNano(), + }, + } + + resp, err := trans.SendRequest(ctx, request) + mu.Lock() + responses[idx] = resp + errors[idx] = err + mu.Unlock() + }(i) + } + + wg.Wait() + + // Check results + for i := 0; i < numRequests; i++ { + if errors[i] != nil { + t.Errorf("Request %d failed: %v", i, errors[i]) + continue + } + + if responses[i] == nil || responses[i].ID == nil || *responses[i].ID != int64(100+i) { + t.Errorf("Request %d: Expected ID %d, got %v", i, 100+i, responses[i]) + continue + } + + // Parse the result to verify echo + var result struct { + JSONRPC string `json:"jsonrpc"` + ID int64 `json:"id"` + Method string `json:"method"` + Params map[string]interface{} `json:"params"` + } + + if err := json.Unmarshal(responses[i].Result, &result); err != nil { + t.Errorf("Request %d: Failed to unmarshal result: %v", i, err) + continue + } + + // Verify data matches what was sent + if result.ID != int64(100+i) { + t.Errorf("Request %d: Expected echoed ID %d, got %d", i, 100+i, result.ID) + } + + if result.Method != "debug/echo" { + t.Errorf("Request %d: Expected method 'debug/echo', got '%s'", i, result.Method) + } + + // Verify the requestIndex parameter + if idx, ok := result.Params["requestIndex"].(float64); !ok || int(idx) != i { + t.Errorf("Request %d: Expected requestIndex %d, got %v", i, i, result.Params["requestIndex"]) + } + } + }) + + t.Run("ResponseError", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Prepare a request + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 100, + Method: "debug/echo_error_string", + } + + reps, err := trans.SendRequest(ctx, request) + if err != nil { + t.Errorf("SendRequest failed: %v", err) + } + + if reps.Error == nil { + t.Errorf("Expected error, got nil") + } + + var responseError JSONRPCRequest + if err := json.Unmarshal([]byte(reps.Error.Message), &responseError); err != nil { + t.Errorf("Failed to unmarshal result: %v", err) + return + } + + if responseError.Method != "debug/echo_error_string" { + t.Errorf("Expected method 'debug/echo_error_string', got '%s'", responseError.Method) + } + if responseError.ID != 100 { + t.Errorf("Expected ID 100, got %d", responseError.ID) + } + if responseError.JSONRPC != "2.0" { + t.Errorf("Expected JSONRPC '2.0', got '%s'", responseError.JSONRPC) + } + }) +} + +func TestStreamableHTTPErrors(t *testing.T) { + t.Run("InvalidURL", func(t *testing.T) { + // Create a new StreamableHTTP transport with an invalid URL + _, err := NewStreamableHTTP("://invalid-url") + if err == nil { + t.Errorf("Expected error when creating with invalid URL, got nil") + } + }) + + t.Run("NonExistentURL", func(t *testing.T) { + // Create a new StreamableHTTP transport with a non-existent URL + trans, err := NewStreamableHTTP("http://localhost:1") + if err != nil { + t.Fatalf("Failed to create StreamableHTTP transport: %v", err) + } + + // Send request should fail + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 1, + Method: "initialize", + } + + _, err = trans.SendRequest(ctx, request) + if err == nil { + t.Errorf("Expected error when sending request to non-existent URL, got nil") + } + }) + +} diff --git a/client/types.go b/client/types.go deleted file mode 100644 index 4402bd02..00000000 --- a/client/types.go +++ /dev/null @@ -1,8 +0,0 @@ -package client - -import "encoding/json" - -type RPCResponse struct { - Error *string - Response *json.RawMessage -} diff --git a/examples/everything/main.go b/examples/everything/main.go index 3e5fd5d9..2701f89d 100644 --- a/examples/everything/main.go +++ b/examples/everything/main.go @@ -369,6 +369,7 @@ func handleLongRunningOperationTool( "progress": i, "total": int(steps), "progressToken": progressToken, + "message": fmt.Sprintf("Server progress %v%%", int(float64(i)*100/steps)), }, ) } diff --git a/go.mod b/go.mod index 940f05de..9b9fe2d4 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.23 require ( github.com/google/uuid v1.6.0 + github.com/spf13/cast v1.7.1 github.com/stretchr/testify v1.9.0 github.com/yosida95/uritemplate/v3 v3.0.2 ) diff --git a/go.sum b/go.sum index 14437f70..31ed86d1 100644 --- a/go.sum +++ b/go.sum @@ -1,9 +1,21 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= +github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= diff --git a/mcp/tools.go b/mcp/tools.go index c4c1b1de..b62cdd5b 100644 --- a/mcp/tools.go +++ b/mcp/tools.go @@ -75,6 +75,8 @@ type Tool struct { InputSchema ToolInputSchema `json:"inputSchema"` // Alternative to InputSchema - allows arbitrary JSON Schema to be provided RawInputSchema json.RawMessage `json:"-"` // Hide this from JSON marshaling + // Optional properties describing tool behavior + Annotations ToolAnnotation `json:"annotations"` } // MarshalJSON implements the json.Marshaler interface for Tool. @@ -100,15 +102,30 @@ func (t Tool) MarshalJSON() ([]byte, error) { m["inputSchema"] = t.InputSchema } + m["annotations"] = t.Annotations + return json.Marshal(m) } type ToolInputSchema struct { Type string `json:"type"` - Properties map[string]interface{} `json:"properties"` + Properties map[string]interface{} `json:"properties,omitempty"` Required []string `json:"required,omitempty"` } +type ToolAnnotation struct { + // Human-readable title for the tool + Title string `json:"title,omitempty"` + // If true, the tool does not modify its environment + ReadOnlyHint bool `json:"readOnlyHint,omitempty"` + // If true, the tool may perform destructive updates + DestructiveHint bool `json:"destructiveHint,omitempty"` + // If true, repeated calls with same args have no additional effect + IdempotentHint bool `json:"idempotentHint,omitempty"` + // If true, tool interacts with external entities + OpenWorldHint bool `json:"openWorldHint,omitempty"` +} + // ToolOption is a function that configures a Tool. // It provides a flexible way to set various properties of a Tool using the functional options pattern. type ToolOption func(*Tool) @@ -132,6 +149,13 @@ func NewTool(name string, opts ...ToolOption) Tool { Properties: make(map[string]interface{}), Required: nil, // Will be omitted from JSON if empty }, + Annotations: ToolAnnotation{ + Title: "", + ReadOnlyHint: false, + DestructiveHint: true, + IdempotentHint: false, + OpenWorldHint: true, + }, } for _, opt := range opts { @@ -166,6 +190,12 @@ func WithDescription(description string) ToolOption { } } +func WithToolAnnotation(annotation ToolAnnotation) ToolOption { + return func(t *Tool) { + t.Annotations = annotation + } +} + // // Common Property Options // @@ -286,6 +316,18 @@ func DefaultBool(value bool) PropertyOption { } } +// +// Array Property Options +// + +// DefaultArray sets the default value for an array property. +// This value will be used if the property is not explicitly provided. +func DefaultArray[T any](value []T) PropertyOption { + return func(schema map[string]interface{}) { + schema["default"] = value + } +} + // // Property Type Helpers // diff --git a/mcp/tools_test.go b/mcp/tools_test.go index 31a5b93e..872749e1 100644 --- a/mcp/tools_test.go +++ b/mcp/tools_test.go @@ -2,6 +2,7 @@ package mcp import ( "encoding/json" + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -240,3 +241,70 @@ func TestToolWithObjectAndArray(t *testing.T) { assert.True(t, ok) assert.Contains(t, required, "books") } + +func TestParseToolCallToolRequest(t *testing.T) { + request := CallToolRequest{} + request.Params.Name = "test-tool" + request.Params.Arguments = map[string]interface{}{ + "bool_value": "true", + "int64_value": "123456789", + "int32_value": "123456789", + "int16_value": "123456789", + "int8_value": "123456789", + "int_value": "123456789", + "uint_value": "123456789", + "uint64_value": "123456789", + "uint32_value": "123456789", + "uint16_value": "123456789", + "uint8_value": "123456789", + "float32_value": "3.14", + "float64_value": "3.1415926", + "string_value": "hello", + } + param1 := ParseBoolean(request, "bool_value", false) + assert.Equal(t, fmt.Sprintf("%T", param1), "bool") + + param2 := ParseInt64(request, "int64_value", 1) + assert.Equal(t, fmt.Sprintf("%T", param2), "int64") + + param3 := ParseInt32(request, "int32_value", 1) + assert.Equal(t, fmt.Sprintf("%T", param3), "int32") + + param4 := ParseInt16(request, "int16_value", 1) + assert.Equal(t, fmt.Sprintf("%T", param4), "int16") + + param5 := ParseInt8(request, "int8_value", 1) + assert.Equal(t, fmt.Sprintf("%T", param5), "int8") + + param6 := ParseInt(request, "int_value", 1) + assert.Equal(t, fmt.Sprintf("%T", param6), "int") + + param7 := ParseUInt(request, "uint_value", 1) + assert.Equal(t, fmt.Sprintf("%T", param7), "uint") + + param8 := ParseUInt64(request, "uint64_value", 1) + assert.Equal(t, fmt.Sprintf("%T", param8), "uint64") + + param9 := ParseUInt32(request, "uint32_value", 1) + assert.Equal(t, fmt.Sprintf("%T", param9), "uint32") + + param10 := ParseUInt16(request, "uint16_value", 1) + assert.Equal(t, fmt.Sprintf("%T", param10), "uint16") + + param11 := ParseUInt8(request, "uint8_value", 1) + assert.Equal(t, fmt.Sprintf("%T", param11), "uint8") + + param12 := ParseFloat32(request, "float32_value", 1.0) + assert.Equal(t, fmt.Sprintf("%T", param12), "float32") + + param13 := ParseFloat64(request, "float64_value", 1.0) + assert.Equal(t, fmt.Sprintf("%T", param13), "float64") + + param14 := ParseString(request, "string_value", "") + assert.Equal(t, fmt.Sprintf("%T", param14), "string") + + param15 := ParseInt64(request, "string_value", 1) + assert.Equal(t, fmt.Sprintf("%T", param15), "int64") + t.Logf("param15 type: %T,value:%v", param15, param15) + +} diff --git a/mcp/types.go b/mcp/types.go index a3ad8174..c940a460 100644 --- a/mcp/types.go +++ b/mcp/types.go @@ -371,6 +371,9 @@ type ProgressNotification struct { Progress float64 `json:"progress"` // Total number of items to process (or total progress required), if known. Total float64 `json:"total,omitempty"` + // Message related to progress. This should provide relevant human-readable + // progress information. + Message string `json:"message,omitempty"` } `json:"params"` } diff --git a/mcp/utils.go b/mcp/utils.go index 236164cb..fc97208c 100644 --- a/mcp/utils.go +++ b/mcp/utils.go @@ -3,6 +3,7 @@ package mcp import ( "encoding/json" "fmt" + "github.com/spf13/cast" ) // ClientRequest types @@ -129,6 +130,7 @@ func NewProgressNotification( token ProgressToken, progress float64, total *float64, + message *string, ) ProgressNotification { notification := ProgressNotification{ Notification: Notification{ @@ -138,6 +140,7 @@ func NewProgressNotification( ProgressToken ProgressToken `json:"progressToken"` Progress float64 `json:"progress"` Total float64 `json:"total,omitempty"` + Message string `json:"message,omitempty"` }{ ProgressToken: token, Progress: progress, @@ -146,6 +149,9 @@ func NewProgressNotification( if total != nil { notification.Params.Total = *total } + if message != nil { + notification.Params.Message = *message + } return notification } @@ -266,6 +272,24 @@ func NewToolResultError(text string) *CallToolResult { } } +// NewToolResultErrorFromErr creates a new CallToolResult with an error message. +// If an error is provided, its details will be appended to the text message. +// Any errors that originate from the tool SHOULD be reported inside the result object. +func NewToolResultErrorFromErr(text string, err error) *CallToolResult { + if err != nil { + text = fmt.Sprintf("%s: %v", text, err) + } + return &CallToolResult{ + Content: []Content{ + TextContent{ + Type: "text", + Text: text, + }, + }, + IsError: true, + } +} + // NewListResourcesResult creates a new ListResourcesResult func NewListResourcesResult( resources []Resource, @@ -594,3 +618,105 @@ func ParseReadResourceResult(rawMessage *json.RawMessage) (*ReadResourceResult, return &result, nil } + +func ParseArgument(request CallToolRequest, key string, defaultVal any) any { + if _, ok := request.Params.Arguments[key]; !ok { + return defaultVal + } else { + return request.Params.Arguments[key] + } +} + +// ParseBoolean extracts and converts a boolean parameter from a CallToolRequest. +// If the key is not found in the Arguments map, the defaultValue is returned. +// The function uses cast.ToBool for conversion which handles various string representations +// such as "true", "yes", "1", etc. +func ParseBoolean(request CallToolRequest, key string, defaultValue bool) bool { + v := ParseArgument(request, key, defaultValue) + return cast.ToBool(v) +} + +// ParseInt64 extracts and converts an int64 parameter from a CallToolRequest. +// If the key is not found in the Arguments map, the defaultValue is returned. +func ParseInt64(request CallToolRequest, key string, defaultValue int64) int64 { + v := ParseArgument(request, key, defaultValue) + return cast.ToInt64(v) +} + +// ParseInt32 extracts and converts an int32 parameter from a CallToolRequest. +func ParseInt32(request CallToolRequest, key string, defaultValue int32) int32 { + v := ParseArgument(request, key, defaultValue) + return cast.ToInt32(v) +} + +// ParseInt16 extracts and converts an int16 parameter from a CallToolRequest. +func ParseInt16(request CallToolRequest, key string, defaultValue int16) int16 { + v := ParseArgument(request, key, defaultValue) + return cast.ToInt16(v) +} + +// ParseInt8 extracts and converts an int8 parameter from a CallToolRequest. +func ParseInt8(request CallToolRequest, key string, defaultValue int8) int8 { + v := ParseArgument(request, key, defaultValue) + return cast.ToInt8(v) +} + +// ParseInt extracts and converts an int parameter from a CallToolRequest. +func ParseInt(request CallToolRequest, key string, defaultValue int) int { + v := ParseArgument(request, key, defaultValue) + return cast.ToInt(v) +} + +// ParseUInt extracts and converts an uint parameter from a CallToolRequest. +func ParseUInt(request CallToolRequest, key string, defaultValue uint) uint { + v := ParseArgument(request, key, defaultValue) + return cast.ToUint(v) +} + +// ParseUInt64 extracts and converts an uint64 parameter from a CallToolRequest. +func ParseUInt64(request CallToolRequest, key string, defaultValue uint64) uint64 { + v := ParseArgument(request, key, defaultValue) + return cast.ToUint64(v) +} + +// ParseUInt32 extracts and converts an uint32 parameter from a CallToolRequest. +func ParseUInt32(request CallToolRequest, key string, defaultValue uint32) uint32 { + v := ParseArgument(request, key, defaultValue) + return cast.ToUint32(v) +} + +// ParseUInt16 extracts and converts an uint16 parameter from a CallToolRequest. +func ParseUInt16(request CallToolRequest, key string, defaultValue uint16) uint16 { + v := ParseArgument(request, key, defaultValue) + return cast.ToUint16(v) +} + +// ParseUInt8 extracts and converts an uint8 parameter from a CallToolRequest. +func ParseUInt8(request CallToolRequest, key string, defaultValue uint8) uint8 { + v := ParseArgument(request, key, defaultValue) + return cast.ToUint8(v) +} + +// ParseFloat32 extracts and converts a float32 parameter from a CallToolRequest. +func ParseFloat32(request CallToolRequest, key string, defaultValue float32) float32 { + v := ParseArgument(request, key, defaultValue) + return cast.ToFloat32(v) +} + +// ParseFloat64 extracts and converts a float64 parameter from a CallToolRequest. +func ParseFloat64(request CallToolRequest, key string, defaultValue float64) float64 { + v := ParseArgument(request, key, defaultValue) + return cast.ToFloat64(v) +} + +// ParseString extracts and converts a string parameter from a CallToolRequest. +func ParseString(request CallToolRequest, key string, defaultValue string) string { + v := ParseArgument(request, key, defaultValue) + return cast.ToString(v) +} + +// ParseStringMap extracts and converts a string map parameter from a CallToolRequest. +func ParseStringMap(request CallToolRequest, key string, defaultValue map[string]any) map[string]any { + v := ParseArgument(request, key, defaultValue) + return cast.ToStringMap(v) +} diff --git a/server/internal/gen/request_handler.go.tmpl b/server/internal/gen/request_handler.go.tmpl index 4e139e17..5c69f5fa 100644 --- a/server/internal/gen/request_handler.go.tmpl +++ b/server/internal/gen/request_handler.go.tmpl @@ -24,6 +24,7 @@ func (s *MCPServer) HandleMessage( JSONRPC string `json:"jsonrpc"` Method mcp.MCPMethod `json:"method"` ID any `json:"id,omitempty"` + Result any `json:"result,omitempty"` } if err := json.Unmarshal(message, &baseMessage); err != nil { @@ -56,6 +57,12 @@ func (s *MCPServer) HandleMessage( return nil // Return nil for notifications } + if baseMessage.Result != nil { + // this is a response to a request sent by the server (e.g. from a ping + // sent due to WithKeepAlive option) + return nil + } + switch baseMessage.Method { {{- range .}} case mcp.{{.MethodName}}: diff --git a/server/request_handler.go b/server/request_handler.go index 946ca7ab..55d2d19e 100644 --- a/server/request_handler.go +++ b/server/request_handler.go @@ -23,6 +23,7 @@ func (s *MCPServer) HandleMessage( JSONRPC string `json:"jsonrpc"` Method mcp.MCPMethod `json:"method"` ID any `json:"id,omitempty"` + Result any `json:"result,omitempty"` } if err := json.Unmarshal(message, &baseMessage); err != nil { @@ -55,6 +56,12 @@ func (s *MCPServer) HandleMessage( return nil // Return nil for notifications } + if baseMessage.Result != nil { + // this is a response to a request sent by the server (e.g. from a ping + // sent due to WithKeepAlive option) + return nil + } + switch baseMessage.Method { case mcp.MethodInitialize: var request mcp.InitializeRequest diff --git a/server/resource_test.go b/server/resource_test.go new file mode 100644 index 00000000..94b35a3d --- /dev/null +++ b/server/resource_test.go @@ -0,0 +1,253 @@ +package server + +import ( + "context" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMCPServer_RemoveResource(t *testing.T) { + tests := []struct { + name string + action func(*testing.T, *MCPServer, chan mcp.JSONRPCNotification) + expectedNotifications int + validate func(*testing.T, []mcp.JSONRPCNotification, mcp.JSONRPCMessage) + }{ + { + name: "RemoveResource removes the resource from the server", + action: func(t *testing.T, server *MCPServer, notificationChannel chan mcp.JSONRPCNotification) { + // Add a test resource + server.AddResource( + mcp.NewResource( + "test://resource1", + "Resource 1", + mcp.WithResourceDescription("Test resource 1"), + mcp.WithMIMEType("text/plain"), + ), + func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + return []mcp.ResourceContents{ + mcp.TextResourceContents{ + URI: "test://resource1", + MIMEType: "text/plain", + Text: "test content 1", + }, + }, nil + }, + ) + + // Add a second resource + server.AddResource( + mcp.NewResource( + "test://resource2", + "Resource 2", + mcp.WithResourceDescription("Test resource 2"), + mcp.WithMIMEType("text/plain"), + ), + func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + return []mcp.ResourceContents{ + mcp.TextResourceContents{ + URI: "test://resource2", + MIMEType: "text/plain", + Text: "test content 2", + }, + }, nil + }, + ) + + // First, verify we have two resources + response := server.HandleMessage(context.Background(), []byte(`{ + "jsonrpc": "2.0", + "id": 1, + "method": "resources/list" + }`)) + resp, ok := response.(mcp.JSONRPCResponse) + assert.True(t, ok) + result, ok := resp.Result.(mcp.ListResourcesResult) + assert.True(t, ok) + assert.Len(t, result.Resources, 2) + + // Now register session to receive notifications + err := server.RegisterSession(context.TODO(), &fakeSession{ + sessionID: "test", + notificationChannel: notificationChannel, + initialized: true, + }) + require.NoError(t, err) + + // Now remove one resource + server.RemoveResource("test://resource1") + }, + expectedNotifications: 1, + validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, resourcesList mcp.JSONRPCMessage) { + // Check that we received a list_changed notification + assert.Equal(t, "resources/list_changed", notifications[0].Method) + + // Verify we now have only one resource + resp, ok := resourcesList.(mcp.JSONRPCResponse) + assert.True(t, ok, "Expected JSONRPCResponse, got %T", resourcesList) + + result, ok := resp.Result.(mcp.ListResourcesResult) + assert.True(t, ok, "Expected ListResourcesResult, got %T", resp.Result) + + assert.Len(t, result.Resources, 1) + assert.Equal(t, "Resource 2", result.Resources[0].Name) + }, + }, + { + name: "RemoveResource with non-existent resource does nothing", + action: func(t *testing.T, server *MCPServer, notificationChannel chan mcp.JSONRPCNotification) { + // Add a test resource + server.AddResource( + mcp.NewResource( + "test://resource1", + "Resource 1", + mcp.WithResourceDescription("Test resource 1"), + mcp.WithMIMEType("text/plain"), + ), + func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + return []mcp.ResourceContents{ + mcp.TextResourceContents{ + URI: "test://resource1", + MIMEType: "text/plain", + Text: "test content 1", + }, + }, nil + }, + ) + + // Register session to receive notifications + err := server.RegisterSession(context.TODO(), &fakeSession{ + sessionID: "test", + notificationChannel: notificationChannel, + initialized: true, + }) + require.NoError(t, err) + + // Remove a non-existent resource + server.RemoveResource("test://nonexistent") + }, + expectedNotifications: 1, // Still sends a notification + validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, resourcesList mcp.JSONRPCMessage) { + // Check that we received a list_changed notification + assert.Equal(t, "resources/list_changed", notifications[0].Method) + + // The original resource should still be there + resp, ok := resourcesList.(mcp.JSONRPCResponse) + assert.True(t, ok) + + result, ok := resp.Result.(mcp.ListResourcesResult) + assert.True(t, ok) + + assert.Len(t, result.Resources, 1) + assert.Equal(t, "Resource 1", result.Resources[0].Name) + }, + }, + { + name: "RemoveResource with no listChanged capability doesn't send notification", + action: func(t *testing.T, server *MCPServer, notificationChannel chan mcp.JSONRPCNotification) { + // Create a new server without listChanged capability + noListChangedServer := NewMCPServer( + "test-server", + "1.0.0", + WithResourceCapabilities(true, false), // Subscribe but not listChanged + ) + + // Add a resource + noListChangedServer.AddResource( + mcp.NewResource( + "test://resource1", + "Resource 1", + mcp.WithResourceDescription("Test resource 1"), + mcp.WithMIMEType("text/plain"), + ), + func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + return []mcp.ResourceContents{ + mcp.TextResourceContents{ + URI: "test://resource1", + MIMEType: "text/plain", + Text: "test content 1", + }, + }, nil + }, + ) + + // Register session to receive notifications + err := noListChangedServer.RegisterSession(context.TODO(), &fakeSession{ + sessionID: "test", + notificationChannel: notificationChannel, + initialized: true, + }) + require.NoError(t, err) + + // Remove the resource + noListChangedServer.RemoveResource("test://resource1") + + // The test can now proceed without waiting for notifications + // since we don't expect any + }, + expectedNotifications: 0, // No notifications expected + validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, resourcesList mcp.JSONRPCMessage) { + // Nothing to do here, we're just verifying that no notifications were sent + assert.Empty(t, notifications) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + server := NewMCPServer( + "test-server", + "1.0.0", + WithResourceCapabilities(true, true), + ) + + // Initialize the server + _ = server.HandleMessage(ctx, []byte(`{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize" + }`)) + + notificationChannel := make(chan mcp.JSONRPCNotification, 100) + notifications := make([]mcp.JSONRPCNotification, 0) + + tt.action(t, server, notificationChannel) + + // Collect notifications with a timeout + if tt.expectedNotifications > 0 { + for i := 0; i < tt.expectedNotifications; i++ { + select { + case notification := <-notificationChannel: + notifications = append(notifications, notification) + case <-time.After(1 * time.Second): + t.Fatalf("Expected %d notifications but only received %d", tt.expectedNotifications, len(notifications)) + } + } + } else { + // If no notifications expected, wait a brief period to ensure none are sent + select { + case notification := <-notificationChannel: + notifications = append(notifications, notification) + case <-time.After(100 * time.Millisecond): + // This is the expected path - no notifications + } + } + + // Get final resources list + listMessage := `{ + "jsonrpc": "2.0", + "id": 1, + "method": "resources/list" + }` + resourcesList := server.HandleMessage(ctx, []byte(listMessage)) + + // Validate the results + tt.validate(t, notifications, resourcesList) + }) + } +} diff --git a/server/server.go b/server/server.go index 5b2d739d..8ebd40bd 100644 --- a/server/server.go +++ b/server/server.go @@ -419,6 +419,18 @@ func (s *MCPServer) AddResource( } } +// RemoveResource removes a resource from the server +func (s *MCPServer) RemoveResource(uri string) { + s.resourcesMu.Lock() + delete(s.resources, uri) + s.resourcesMu.Unlock() + + // Send notification to all initialized sessions if listChanged capability is enabled + if s.capabilities.resources != nil && s.capabilities.resources.listChanged { + s.sendNotificationToAllClients("resources/list_changed", nil) + } +} + // AddResourceTemplate registers a new resource template and its handler func (s *MCPServer) AddResourceTemplate( template mcp.ResourceTemplate, diff --git a/server/server_test.go b/server/server_test.go index 8ee31b30..e55008f1 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -800,6 +800,13 @@ func TestMCPServer_HandleUndefinedHandlers(t *testing.T) { Type: "object", Properties: map[string]interface{}{}, }, + Annotations: mcp.ToolAnnotation{ + Title: "test-tool", + ReadOnlyHint: true, + DestructiveHint: false, + IdempotentHint: false, + OpenWorldHint: false, + }, }, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }) diff --git a/server/sse.go b/server/sse.go index f69451c6..b6ae2144 100644 --- a/server/sse.go +++ b/server/sse.go @@ -23,6 +23,7 @@ type sseSession struct { done chan struct{} eventQueue chan string // Channel for queuing events sessionID string + requestID atomic.Int64 notificationChannel chan mcp.JSONRPCNotification initialized atomic.Bool } @@ -53,18 +54,20 @@ var _ ClientSession = (*sseSession)(nil) // SSEServer implements a Server-Sent Events (SSE) based MCP server. // It provides real-time communication capabilities over HTTP using the SSE protocol. type SSEServer struct { - server *MCPServer - baseURL string - basePath string - useFullURLForMessageEndpoint bool - messageEndpoint string - sseEndpoint string - sessions sync.Map - srv *http.Server - contextFunc SSEContextFunc + server *MCPServer + baseURL string + basePath string + useFullURLForMessageEndpoint bool + messageEndpoint string + sseEndpoint string + sessions sync.Map + srv *http.Server + contextFunc SSEContextFunc keepAlive bool keepAliveInterval time.Duration + + mu sync.RWMutex } // SSEOption defines a function type for configuring SSEServer @@ -158,12 +161,12 @@ func WithSSEContextFunc(fn SSEContextFunc) SSEOption { // NewSSEServer creates a new SSE server instance with the given MCP server and options. func NewSSEServer(server *MCPServer, opts ...SSEOption) *SSEServer { s := &SSEServer{ - server: server, - sseEndpoint: "/sse", - messageEndpoint: "/message", - useFullURLForMessageEndpoint: true, - keepAlive: false, - keepAliveInterval: 10 * time.Second, + server: server, + sseEndpoint: "/sse", + messageEndpoint: "/message", + useFullURLForMessageEndpoint: true, + keepAlive: false, + keepAliveInterval: 10 * time.Second, } // Apply all options @@ -189,10 +192,12 @@ func NewTestServer(server *MCPServer, opts ...SSEOption) *httptest.Server { // Start begins serving SSE connections on the specified address. // It sets up HTTP handlers for SSE and message endpoints. func (s *SSEServer) Start(addr string) error { + s.mu.Lock() s.srv = &http.Server{ Addr: addr, Handler: s, } + s.mu.Unlock() return s.srv.ListenAndServe() } @@ -200,7 +205,11 @@ func (s *SSEServer) Start(addr string) error { // Shutdown gracefully stops the SSE server, closing all active sessions // and shutting down the HTTP server. func (s *SSEServer) Shutdown(ctx context.Context) error { - if s.srv != nil { + s.mu.RLock() + srv := s.srv + s.mu.RUnlock() + + if srv != nil { s.sessions.Range(func(key, value interface{}) bool { if session, ok := value.(*sseSession); ok { close(session.done) @@ -209,7 +218,7 @@ func (s *SSEServer) Shutdown(ctx context.Context) error { return true }) - return s.srv.Shutdown(ctx) + return srv.Shutdown(ctx) } return nil } @@ -282,8 +291,16 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) { for { select { case <-ticker.C: - //: ping - 2025-03-27 07:44:38.682659+00:00 - session.eventQueue <- fmt.Sprintf(":ping - %s\n\n", time.Now().Format(time.RFC3339)) + message := mcp.JSONRPCRequest{ + JSONRPC: "2.0", + ID: session.requestID.Add(1), + Request: mcp.Request{ + Method: "ping", + }, + } + messageBytes, _ := json.Marshal(message) + pingMsg := fmt.Sprintf("event: message\ndata:%s\n\n", messageBytes) + session.eventQueue <- pingMsg case <-session.done: return case <-r.Context().Done(): @@ -293,7 +310,6 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) { }() } - // Send the initial endpoint event fmt.Fprintf(w, "event: endpoint\ndata: %s\r\n\r\n", s.GetMessageEndpointForClient(sessionID)) flusher.Flush() @@ -335,7 +351,6 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) { s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Missing sessionId") return } - sessionI, ok := s.sessions.Load(sessionID) if !ok { s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Invalid session ID") diff --git a/server/sse_test.go b/server/sse_test.go index 111c5845..93474c6b 100644 --- a/server/sse_test.go +++ b/server/sse_test.go @@ -1,10 +1,12 @@ package server import ( + "bufio" "bytes" "context" "encoding/json" "fmt" + "io" "math/rand" "net/http" "net/http/httptest" @@ -739,4 +741,117 @@ func TestSSEServer(t *testing.T) { } } }) + + t.Run("Client receives and can respond to ping messages", func(t *testing.T) { + mcpServer := NewMCPServer("test", "1.0.0") + testServer := NewTestServer(mcpServer, + WithKeepAlive(true), + WithKeepAliveInterval(50*time.Millisecond), + ) + defer testServer.Close() + + sseResp, err := http.Get(fmt.Sprintf("%s/sse", testServer.URL)) + if err != nil { + t.Fatalf("Failed to connect to SSE endpoint: %v", err) + } + defer sseResp.Body.Close() + + reader := bufio.NewReader(sseResp.Body) + + var messageURL string + var pingID float64 + + for { + line, err := reader.ReadString('\n') + if err != nil { + t.Fatalf("Failed to read SSE event: %v", err) + } + + if strings.HasPrefix(line, "event: endpoint") { + dataLine, err := reader.ReadString('\n') + if err != nil { + t.Fatalf("Failed to read endpoint data: %v", err) + } + messageURL = strings.TrimSpace(strings.TrimPrefix(dataLine, "data: ")) + + _, err = reader.ReadString('\n') + if err != nil { + t.Fatalf("Failed to read blank line: %v", err) + } + } + + if strings.HasPrefix(line, "event: message") { + dataLine, err := reader.ReadString('\n') + if err != nil { + t.Fatalf("Failed to read message data: %v", err) + } + + pingData := strings.TrimSpace(strings.TrimPrefix(dataLine, "data:")) + var pingMsg mcp.JSONRPCRequest + if err := json.Unmarshal([]byte(pingData), &pingMsg); err != nil { + t.Fatalf("Failed to parse ping message: %v", err) + } + + if pingMsg.Method == "ping" { + pingID = pingMsg.ID.(float64) + t.Logf("Received ping with ID: %f", pingID) + break // We got the ping, exit the loop + } + + _, err = reader.ReadString('\n') + if err != nil { + t.Fatalf("Failed to read blank line: %v", err) + } + } + + if messageURL != "" && pingID != 0 { + break + } + } + + if messageURL == "" { + t.Fatal("Did not receive message endpoint URL") + } + + pingResponse := map[string]any{ + "jsonrpc": "2.0", + "id": pingID, + "result": map[string]any{}, + } + + requestBody, err := json.Marshal(pingResponse) + if err != nil { + t.Fatalf("Failed to marshal ping response: %v", err) + } + + resp, err := http.Post( + messageURL, + "application/json", + bytes.NewBuffer(requestBody), + ) + if err != nil { + t.Fatalf("Failed to send ping response: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusAccepted { + t.Errorf("Expected status 202 for ping response, got %d", resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Failed to read response body: %v", err) + } + + if len(body) > 0 { + var response map[string]any + if err := json.Unmarshal(body, &response); err != nil { + t.Fatalf("Failed to parse response body: %v", err) + } + + if response["error"] != nil { + t.Errorf("Expected no error in response, got %v", response["error"]) + } + } + }) } diff --git a/testdata/mockstdio_server.go b/testdata/mockstdio_server.go index 3100c5a2..9f13d554 100644 --- a/testdata/mockstdio_server.go +++ b/testdata/mockstdio_server.go @@ -10,14 +10,14 @@ import ( type JSONRPCRequest struct { JSONRPC string `json:"jsonrpc"` - ID int64 `json:"id"` + ID *int64 `json:"id,omitempty"` Method string `json:"method"` Params json.RawMessage `json:"params"` } type JSONRPCResponse struct { JSONRPC string `json:"jsonrpc"` - ID int64 `json:"id"` + ID *int64 `json:"id,omitempty"` Result interface{} `json:"result,omitempty"` Error *struct { Code int `json:"code"` @@ -138,6 +138,30 @@ func handleRequest(request JSONRPCRequest) JSONRPCResponse { "values": []string{"test completion"}, }, } + + // Debug methods for testing transport. + case "debug/echo": + response.Result = request + case "debug/echo_notification": + response.Result = request + + // send notification to client + responseBytes, _ := json.Marshal(map[string]any{ + "jsonrpc": "2.0", + "method": "debug/test", + "params": request, + }) + fmt.Fprintf(os.Stdout, "%s\n", responseBytes) + + case "debug/echo_error_string": + all, _ := json.Marshal(request) + response.Error = &struct { + Code int `json:"code"` + Message string `json:"message"` + }{ + Code: -32601, + Message: string(all), + } default: response.Error = &struct { Code int `json:"code"`