diff --git a/client/http.go b/client/http.go new file mode 100644 index 00000000..cb3be35d --- /dev/null +++ b/client/http.go @@ -0,0 +1,17 @@ +package client + +import ( + "fmt" + + "github.com/mark3labs/mcp-go/client/transport" +) + +// NewStreamableHttpClient is a convenience method that creates a new streamable-http-based MCP client +// with the given base URL. Returns an error if the URL is invalid. +func NewStreamableHttpClient(baseURL string, options ...transport.StreamableHTTPCOption) (*Client, error) { + trans, err := transport.NewStreamableHTTP(baseURL, options...) + if err != nil { + return nil, fmt.Errorf("failed to create SSE transport: %w", err) + } + return NewClient(trans), nil +} diff --git a/client/transport/stdio_test.go b/client/transport/stdio_test.go index 445ba07e..aa728ec6 100644 --- a/client/transport/stdio_test.go +++ b/client/transport/stdio_test.go @@ -7,6 +7,7 @@ import ( "os" "os/exec" "path/filepath" + "runtime" "sync" "testing" "time" @@ -31,6 +32,10 @@ func compileTestServer(outputPath string) error { func TestStdio(t *testing.T) { // Compile mock server mockServerPath := filepath.Join(os.TempDir(), "mockstdio_server") + // Add .exe suffix on Windows + if runtime.GOOS == "windows" { + mockServerPath += ".exe" + } if err := compileTestServer(mockServerPath); err != nil { t.Fatalf("Failed to compile mock server: %v", err) } @@ -302,8 +307,11 @@ func TestStdioErrors(t *testing.T) { }) t.Run("RequestBeforeStart", func(t *testing.T) { - // 创建一个新的 Stdio 实例但不调用 Start 方法 mockServerPath := filepath.Join(os.TempDir(), "mockstdio_server") + // Add .exe suffix on Windows + if runtime.GOOS == "windows" { + mockServerPath += ".exe" + } if err := compileTestServer(mockServerPath); err != nil { t.Fatalf("Failed to compile mock server: %v", err) } @@ -311,7 +319,7 @@ func TestStdioErrors(t *testing.T) { uninitiatedStdio := NewStdio(mockServerPath, nil) - // 准备一个请求 + // Prepare a request request := JSONRPCRequest{ JSONRPC: "2.0", ID: 99, @@ -331,6 +339,10 @@ func TestStdioErrors(t *testing.T) { t.Run("RequestAfterClose", func(t *testing.T) { // Compile mock server mockServerPath := filepath.Join(os.TempDir(), "mockstdio_server") + // Add .exe suffix on Windows + if runtime.GOOS == "windows" { + mockServerPath += ".exe" + } if err := compileTestServer(mockServerPath); err != nil { t.Fatalf("Failed to compile mock server: %v", err) } diff --git a/client/transport/streamable_http.go b/client/transport/streamable_http.go new file mode 100644 index 00000000..4bc60a25 --- /dev/null +++ b/client/transport/streamable_http.go @@ -0,0 +1,387 @@ +package transport + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/mark3labs/mcp-go/mcp" +) + +type StreamableHTTPCOption func(*StreamableHTTP) + +func WithHTTPHeaders(headers map[string]string) StreamableHTTPCOption { + return func(sc *StreamableHTTP) { + sc.headers = headers + } +} + +// WithHTTPTimeout sets the timeout for a HTTP request and stream. +func WithHTTPTimeout(timeout time.Duration) StreamableHTTPCOption { + return func(sc *StreamableHTTP) { + sc.httpClient.Timeout = timeout + } +} + +// StreamableHTTP implements Streamable HTTP transport. +// +// It transmits JSON-RPC messages over individual HTTP requests. One message per request. +// The HTTP response body can either be a single JSON-RPC response, +// or an upgraded SSE stream that concludes with a JSON-RPC response for the same request. +// +// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports +// +// The current implementation does not support the following features: +// - batching +// - continuously listening for server notifications when no request is in flight +// (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server) +// - resuming stream +// (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#resumability-and-redelivery) +// - server -> client request +type StreamableHTTP struct { + baseURL *url.URL + httpClient *http.Client + headers map[string]string + + sessionID atomic.Value // string + + notificationHandler func(mcp.JSONRPCNotification) + notifyMu sync.RWMutex + + closed chan struct{} +} + +// NewStreamableHTTP creates a new Streamable HTTP transport with the given base URL. +// Returns an error if the URL is invalid. +func NewStreamableHTTP(baseURL string, options ...StreamableHTTPCOption) (*StreamableHTTP, error) { + parsedURL, err := url.Parse(baseURL) + if err != nil { + return nil, fmt.Errorf("invalid URL: %w", err) + } + + smc := &StreamableHTTP{ + baseURL: parsedURL, + httpClient: &http.Client{}, + headers: make(map[string]string), + closed: make(chan struct{}), + } + smc.sessionID.Store("") // set initial value to simplify later usage + + for _, opt := range options { + opt(smc) + } + + return smc, nil +} + +// Start initiates the HTTP connection to the server. +func (c *StreamableHTTP) Start(ctx context.Context) error { + // For Streamable HTTP, we don't need to establish a persistent connection + return nil +} + +// Close closes the all the HTTP connections to the server. +func (c *StreamableHTTP) Close() error { + select { + case <-c.closed: + return nil + default: + } + // Cancel all in-flight requests + close(c.closed) + + sessionId := c.sessionID.Load().(string) + if sessionId != "" { + c.sessionID.Store("") + + // notify server session closed + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + req, err := http.NewRequestWithContext(ctx, http.MethodDelete, c.baseURL.String(), nil) + if err != nil { + fmt.Printf("failed to create close request\n: %v", err) + return + } + req.Header.Set(headerKeySessionID, sessionId) + res, err := c.httpClient.Do(req) + if err != nil { + fmt.Printf("failed to send close request\n: %v", err) + return + } + res.Body.Close() + }() + } + + return nil +} + +const ( + initializeMethod = "initialize" + headerKeySessionID = "Mcp-Session-Id" +) + +// sendRequest sends a JSON-RPC request to the server and waits for a response. +// Returns the raw JSON response message or an error if the request fails. +func (c *StreamableHTTP) SendRequest( + ctx context.Context, + request JSONRPCRequest, +) (*JSONRPCResponse, error) { + + // Create a combined context that could be canceled when the client is closed + newCtx, cancel := context.WithCancel(ctx) + defer cancel() + go func() { + select { + case <-c.closed: + cancel() + case <-newCtx.Done(): + // The original context was canceled, no need to do anything + } + }() + ctx = newCtx + + // Marshal request + requestBody, err := json.Marshal(request) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + // Create HTTP request + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL.String(), bytes.NewReader(requestBody)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + // Set headers + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") + sessionID := c.sessionID.Load() + if sessionID != "" { + req.Header.Set(headerKeySessionID, sessionID.(string)) + } + for k, v := range c.headers { + req.Header.Set(k, v) + } + + // Send request + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + // Check if we got an error response + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted { + // handle session closed + if resp.StatusCode == http.StatusNotFound { + c.sessionID.CompareAndSwap(sessionID, "") + return nil, fmt.Errorf("session terminated (404). need to re-initialize") + } + + // handle error response + var errResponse JSONRPCResponse + body, _ := io.ReadAll(resp.Body) + if err := json.Unmarshal(body, &errResponse); err == nil { + return &errResponse, nil + } + return nil, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, body) + } + + if request.Method == initializeMethod { + // saved the received session ID in the response + // empty session ID is allowed + if sessionID := resp.Header.Get(headerKeySessionID); sessionID != "" { + c.sessionID.Store(sessionID) + } + } + + // Handle different response types + switch resp.Header.Get("Content-Type") { + case "application/json": + // Single response + var response JSONRPCResponse + if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + // should not be a notification + if response.ID == nil { + return nil, fmt.Errorf("response should contain RPC id: %v", response) + } + + return &response, nil + + case "text/event-stream": + // Server is using SSE for streaming responses + return c.handleSSEResponse(ctx, resp.Body) + + default: + return nil, fmt.Errorf("unexpected content type: %s", resp.Header.Get("Content-Type")) + } +} + +// handleSSEResponse processes an SSE stream for a specific request. +// It returns the final result for the request once received, or an error. +func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCloser) (*JSONRPCResponse, error) { + + // Create a channel for this specific request + responseChan := make(chan *JSONRPCResponse, 1) + defer close(responseChan) + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + // Start a goroutine to process the SSE stream + go c.readSSE(ctx, reader, func(event, data string) { + + // (unsupported: batching) + + var message JSONRPCResponse + if err := json.Unmarshal([]byte(data), &message); err != nil { + fmt.Printf("failed to unmarshal message: %v\n", err) + return + } + + // Handle notification + if message.ID == nil { + var notification mcp.JSONRPCNotification + if err := json.Unmarshal([]byte(data), ¬ification); err != nil { + fmt.Printf("failed to unmarshal notification: %v\n", err) + return + } + c.notifyMu.RLock() + if c.notificationHandler != nil { + c.notificationHandler(notification) + } + c.notifyMu.RUnlock() + return + } + + responseChan <- &message + }) + + // Wait for the response or context cancellation + select { + case response := <-responseChan: + if response == nil { + return nil, fmt.Errorf("unexpected nil response") + } + return response, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +// readSSE reads the SSE stream(reader) and calls the handler for each event and data pair. +// It will end when the reader is closed (or the context is done). +func (c *StreamableHTTP) readSSE(ctx context.Context, reader io.ReadCloser, handler func(event, data string)) { + defer reader.Close() + + br := bufio.NewReader(reader) + var event, data string + + for { + select { + case <-ctx.Done(): + return + default: + line, err := br.ReadString('\n') + if err != nil { + if err == io.EOF { + // Process any pending event before exit + if event != "" && data != "" { + handler(event, data) + } + return + } + select { + case <-ctx.Done(): + return + default: + fmt.Printf("SSE stream error: %v\n", err) + return + } + } + + // Remove only newline markers + line = strings.TrimRight(line, "\r\n") + if line == "" { + // Empty line means end of event + if event != "" && data != "" { + handler(event, data) + event = "" + data = "" + } + continue + } + + if strings.HasPrefix(line, "event:") { + event = strings.TrimSpace(strings.TrimPrefix(line, "event:")) + } else if strings.HasPrefix(line, "data:") { + data = strings.TrimSpace(strings.TrimPrefix(line, "data:")) + } + } + } +} + +func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error { + + // Marshal request + requestBody, err := json.Marshal(notification) + if err != nil { + return fmt.Errorf("failed to marshal notification: %w", err) + } + + // Create HTTP request + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL.String(), bytes.NewReader(requestBody)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + // Set headers + req.Header.Set("Content-Type", "application/json") + if sessionID := c.sessionID.Load(); sessionID != "" { + req.Header.Set(headerKeySessionID, sessionID.(string)) + } + for k, v := range c.headers { + req.Header.Set(k, v) + } + + // Send request + resp, err := c.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf( + "notification failed with status %d: %s", + resp.StatusCode, + body, + ) + } + + return nil +} + +func (c *StreamableHTTP) SetNotificationHandler(handler func(mcp.JSONRPCNotification)) { + c.notifyMu.Lock() + defer c.notifyMu.Unlock() + c.notificationHandler = handler +} + +func (c *StreamableHTTP) GetSessionId() string { + return c.sessionID.Load().(string) +} diff --git a/client/transport/streamable_http_test.go b/client/transport/streamable_http_test.go new file mode 100644 index 00000000..b7b76b96 --- /dev/null +++ b/client/transport/streamable_http_test.go @@ -0,0 +1,425 @@ +package transport + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" +) + +// startMockStreamableHTTPServer starts a test HTTP server that implements +// a minimal Streamable HTTP server for testing purposes. +// It returns the server URL and a function to close the server. +func startMockStreamableHTTPServer() (string, func()) { + var sessionID string + var mu sync.Mutex + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Handle only POST requests + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Parse incoming JSON-RPC request + var request map[string]any + decoder := json.NewDecoder(r.Body) + if err := decoder.Decode(&request); err != nil { + http.Error(w, fmt.Sprintf("Invalid JSON: %v", err), http.StatusBadRequest) + return + } + + method := request["method"] + switch method { + case "initialize": + // Generate a new session ID + mu.Lock() + sessionID = fmt.Sprintf("test-session-%d", time.Now().UnixNano()) + mu.Unlock() + w.Header().Set("Mcp-Session-Id", sessionID) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusAccepted) + json.NewEncoder(w).Encode(map[string]interface{}{ + "jsonrpc": "2.0", + "id": request["id"], + "result": "initialized", + }) + + case "debug/echo": + // Check session ID + if r.Header.Get("Mcp-Session-Id") != sessionID { + http.Error(w, "Invalid session ID", http.StatusNotFound) + return + } + + // Echo back the request as the response result + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]interface{}{ + "jsonrpc": "2.0", + "id": request["id"], + "result": request, + }) + + case "debug/echo_notification": + // Check session ID + if r.Header.Get("Mcp-Session-Id") != sessionID { + http.Error(w, "Invalid session ID", http.StatusNotFound) + return + } + + // Send response and notification + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + notification := map[string]any{ + "jsonrpc": "2.0", + "method": "debug/test", + "params": request, + } + notificationData, _ := json.Marshal(notification) + fmt.Fprintf(w, "event: message\ndata: %s\n\n", notificationData) + response := map[string]any{ + "jsonrpc": "2.0", + "id": request["id"], + "result": request, + } + responseData, _ := json.Marshal(response) + fmt.Fprintf(w, "event: message\ndata: %s\n\n", responseData) + + case "debug/echo_error_string": + // Check session ID + if r.Header.Get("Mcp-Session-Id") != sessionID { + http.Error(w, "Invalid session ID", http.StatusNotFound) + return + } + + // Return an error response + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(request) + json.NewEncoder(w).Encode(map[string]interface{}{ + "jsonrpc": "2.0", + "id": request["id"], + "error": map[string]interface{}{ + "code": -1, + "message": string(data), + }, + }) + } + }) + + // Start test server + testServer := httptest.NewServer(handler) + return testServer.URL, testServer.Close +} + +func TestStreamableHTTP(t *testing.T) { + // Start mock server + url, closeF := startMockStreamableHTTPServer() + defer closeF() + + // Create transport + trans, err := NewStreamableHTTP(url) + if err != nil { + t.Fatal(err) + } + defer trans.Close() + + // Initialize the transport first + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + initRequest := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 1, + Method: "initialize", + } + + _, err = trans.SendRequest(ctx, initRequest) + if err != nil { + t.Fatal(err) + } + + // Now run the tests + t.Run("SendRequest", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + params := map[string]interface{}{ + "string": "hello world", + "array": []interface{}{1, 2, 3}, + } + + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 1, + Method: "debug/echo", + Params: params, + } + + // Send the request + response, err := trans.SendRequest(ctx, request) + if err != nil { + t.Fatalf("SendRequest failed: %v", err) + } + + // Parse the result to verify echo + var result struct { + JSONRPC string `json:"jsonrpc"` + ID int64 `json:"id"` + Method string `json:"method"` + Params map[string]interface{} `json:"params"` + } + + if err := json.Unmarshal(response.Result, &result); err != nil { + t.Fatalf("Failed to unmarshal result: %v", err) + } + + // Verify response data matches what was sent + if result.JSONRPC != "2.0" { + t.Errorf("Expected JSONRPC value '2.0', got '%s'", result.JSONRPC) + } + if result.ID != 1 { + t.Errorf("Expected ID 1, got %d", result.ID) + } + if result.Method != "debug/echo" { + t.Errorf("Expected method 'debug/echo', got '%s'", result.Method) + } + + if str, ok := result.Params["string"].(string); !ok || str != "hello world" { + t.Errorf("Expected string 'hello world', got %v", result.Params["string"]) + } + + if arr, ok := result.Params["array"].([]interface{}); !ok || len(arr) != 3 { + t.Errorf("Expected array with 3 items, got %v", result.Params["array"]) + } + }) + + t.Run("SendRequestWithTimeout", func(t *testing.T) { + // Create a context that's already canceled + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel the context immediately + + // Prepare a request + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 3, + Method: "debug/echo", + } + + // The request should fail because the context is canceled + _, err := trans.SendRequest(ctx, request) + if err == nil { + t.Errorf("Expected context canceled error, got nil") + } else if !errors.Is(err, context.Canceled) { + t.Errorf("Expected context.Canceled error, got: %v", err) + } + }) + + t.Run("SendNotification & NotificationHandler", func(t *testing.T) { + var wg sync.WaitGroup + notificationChan := make(chan mcp.JSONRPCNotification, 1) + + // Set notification handler + trans.SetNotificationHandler(func(notification mcp.JSONRPCNotification) { + notificationChan <- notification + }) + + // Send a request that triggers a notification + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 1, + Method: "debug/echo_notification", + } + + _, err := trans.SendRequest(ctx, request) + if err != nil { + t.Fatalf("SendRequest failed: %v", err) + } + + wg.Add(1) + go func() { + defer wg.Done() + select { + case notification := <-notificationChan: + // We received a notification + got := notification.Params.AdditionalFields + if got == nil { + t.Errorf("Notification handler did not send the expected notification: got nil") + } + if int64(got["id"].(float64)) != request.ID || + got["jsonrpc"] != request.JSONRPC || + got["method"] != request.Method { + + responseJson, _ := json.Marshal(got) + requestJson, _ := json.Marshal(request) + t.Errorf("Notification handler did not send the expected notification: \ngot %s\nexpect %s", responseJson, requestJson) + } + + case <-time.After(1 * time.Second): + t.Errorf("Expected notification, got none") + } + }() + + wg.Wait() + }) + + t.Run("MultipleRequests", func(t *testing.T) { + var wg sync.WaitGroup + const numRequests = 5 + + // Send multiple requests concurrently + mu := sync.Mutex{} + responses := make([]*JSONRPCResponse, numRequests) + errors := make([]error, numRequests) + + for i := 0; i < numRequests; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Each request has a unique ID and payload + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: int64(100 + idx), + Method: "debug/echo", + Params: map[string]interface{}{ + "requestIndex": idx, + "timestamp": time.Now().UnixNano(), + }, + } + + resp, err := trans.SendRequest(ctx, request) + mu.Lock() + responses[idx] = resp + errors[idx] = err + mu.Unlock() + }(i) + } + + wg.Wait() + + // Check results + for i := 0; i < numRequests; i++ { + if errors[i] != nil { + t.Errorf("Request %d failed: %v", i, errors[i]) + continue + } + + if responses[i] == nil || responses[i].ID == nil || *responses[i].ID != int64(100+i) { + t.Errorf("Request %d: Expected ID %d, got %v", i, 100+i, responses[i]) + continue + } + + // Parse the result to verify echo + var result struct { + JSONRPC string `json:"jsonrpc"` + ID int64 `json:"id"` + Method string `json:"method"` + Params map[string]interface{} `json:"params"` + } + + if err := json.Unmarshal(responses[i].Result, &result); err != nil { + t.Errorf("Request %d: Failed to unmarshal result: %v", i, err) + continue + } + + // Verify data matches what was sent + if result.ID != int64(100+i) { + t.Errorf("Request %d: Expected echoed ID %d, got %d", i, 100+i, result.ID) + } + + if result.Method != "debug/echo" { + t.Errorf("Request %d: Expected method 'debug/echo', got '%s'", i, result.Method) + } + + // Verify the requestIndex parameter + if idx, ok := result.Params["requestIndex"].(float64); !ok || int(idx) != i { + t.Errorf("Request %d: Expected requestIndex %d, got %v", i, i, result.Params["requestIndex"]) + } + } + }) + + t.Run("ResponseError", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Prepare a request + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 100, + Method: "debug/echo_error_string", + } + + reps, err := trans.SendRequest(ctx, request) + if err != nil { + t.Errorf("SendRequest failed: %v", err) + } + + if reps.Error == nil { + t.Errorf("Expected error, got nil") + } + + var responseError JSONRPCRequest + if err := json.Unmarshal([]byte(reps.Error.Message), &responseError); err != nil { + t.Errorf("Failed to unmarshal result: %v", err) + return + } + + if responseError.Method != "debug/echo_error_string" { + t.Errorf("Expected method 'debug/echo_error_string', got '%s'", responseError.Method) + } + if responseError.ID != 100 { + t.Errorf("Expected ID 100, got %d", responseError.ID) + } + if responseError.JSONRPC != "2.0" { + t.Errorf("Expected JSONRPC '2.0', got '%s'", responseError.JSONRPC) + } + }) +} + +func TestStreamableHTTPErrors(t *testing.T) { + t.Run("InvalidURL", func(t *testing.T) { + // Create a new StreamableHTTP transport with an invalid URL + _, err := NewStreamableHTTP("://invalid-url") + if err == nil { + t.Errorf("Expected error when creating with invalid URL, got nil") + } + }) + + t.Run("NonExistentURL", func(t *testing.T) { + // Create a new StreamableHTTP transport with a non-existent URL + trans, err := NewStreamableHTTP("http://localhost:1") + if err != nil { + t.Fatalf("Failed to create StreamableHTTP transport: %v", err) + } + + // Send request should fail + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 1, + Method: "initialize", + } + + _, err = trans.SendRequest(ctx, request) + if err == nil { + t.Errorf("Expected error when sending request to non-existent URL, got nil") + } + }) + +} diff --git a/examples/everything/main.go b/examples/everything/main.go index 3e5fd5d9..2701f89d 100644 --- a/examples/everything/main.go +++ b/examples/everything/main.go @@ -369,6 +369,7 @@ func handleLongRunningOperationTool( "progress": i, "total": int(steps), "progressToken": progressToken, + "message": fmt.Sprintf("Server progress %v%%", int(float64(i)*100/steps)), }, ) } diff --git a/go.mod b/go.mod index 940f05de..9b9fe2d4 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.23 require ( github.com/google/uuid v1.6.0 + github.com/spf13/cast v1.7.1 github.com/stretchr/testify v1.9.0 github.com/yosida95/uritemplate/v3 v3.0.2 ) diff --git a/go.sum b/go.sum index 14437f70..31ed86d1 100644 --- a/go.sum +++ b/go.sum @@ -1,9 +1,21 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= +github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= diff --git a/mcp/tools_test.go b/mcp/tools_test.go index 31a5b93e..872749e1 100644 --- a/mcp/tools_test.go +++ b/mcp/tools_test.go @@ -2,6 +2,7 @@ package mcp import ( "encoding/json" + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -240,3 +241,70 @@ func TestToolWithObjectAndArray(t *testing.T) { assert.True(t, ok) assert.Contains(t, required, "books") } + +func TestParseToolCallToolRequest(t *testing.T) { + request := CallToolRequest{} + request.Params.Name = "test-tool" + request.Params.Arguments = map[string]interface{}{ + "bool_value": "true", + "int64_value": "123456789", + "int32_value": "123456789", + "int16_value": "123456789", + "int8_value": "123456789", + "int_value": "123456789", + "uint_value": "123456789", + "uint64_value": "123456789", + "uint32_value": "123456789", + "uint16_value": "123456789", + "uint8_value": "123456789", + "float32_value": "3.14", + "float64_value": "3.1415926", + "string_value": "hello", + } + param1 := ParseBoolean(request, "bool_value", false) + assert.Equal(t, fmt.Sprintf("%T", param1), "bool") + + param2 := ParseInt64(request, "int64_value", 1) + assert.Equal(t, fmt.Sprintf("%T", param2), "int64") + + param3 := ParseInt32(request, "int32_value", 1) + assert.Equal(t, fmt.Sprintf("%T", param3), "int32") + + param4 := ParseInt16(request, "int16_value", 1) + assert.Equal(t, fmt.Sprintf("%T", param4), "int16") + + param5 := ParseInt8(request, "int8_value", 1) + assert.Equal(t, fmt.Sprintf("%T", param5), "int8") + + param6 := ParseInt(request, "int_value", 1) + assert.Equal(t, fmt.Sprintf("%T", param6), "int") + + param7 := ParseUInt(request, "uint_value", 1) + assert.Equal(t, fmt.Sprintf("%T", param7), "uint") + + param8 := ParseUInt64(request, "uint64_value", 1) + assert.Equal(t, fmt.Sprintf("%T", param8), "uint64") + + param9 := ParseUInt32(request, "uint32_value", 1) + assert.Equal(t, fmt.Sprintf("%T", param9), "uint32") + + param10 := ParseUInt16(request, "uint16_value", 1) + assert.Equal(t, fmt.Sprintf("%T", param10), "uint16") + + param11 := ParseUInt8(request, "uint8_value", 1) + assert.Equal(t, fmt.Sprintf("%T", param11), "uint8") + + param12 := ParseFloat32(request, "float32_value", 1.0) + assert.Equal(t, fmt.Sprintf("%T", param12), "float32") + + param13 := ParseFloat64(request, "float64_value", 1.0) + assert.Equal(t, fmt.Sprintf("%T", param13), "float64") + + param14 := ParseString(request, "string_value", "") + assert.Equal(t, fmt.Sprintf("%T", param14), "string") + + param15 := ParseInt64(request, "string_value", 1) + assert.Equal(t, fmt.Sprintf("%T", param15), "int64") + t.Logf("param15 type: %T,value:%v", param15, param15) + +} diff --git a/mcp/types.go b/mcp/types.go index a3ad8174..c940a460 100644 --- a/mcp/types.go +++ b/mcp/types.go @@ -371,6 +371,9 @@ type ProgressNotification struct { Progress float64 `json:"progress"` // Total number of items to process (or total progress required), if known. Total float64 `json:"total,omitempty"` + // Message related to progress. This should provide relevant human-readable + // progress information. + Message string `json:"message,omitempty"` } `json:"params"` } diff --git a/mcp/utils.go b/mcp/utils.go index 55a708e9..fc97208c 100644 --- a/mcp/utils.go +++ b/mcp/utils.go @@ -3,6 +3,7 @@ package mcp import ( "encoding/json" "fmt" + "github.com/spf13/cast" ) // ClientRequest types @@ -129,6 +130,7 @@ func NewProgressNotification( token ProgressToken, progress float64, total *float64, + message *string, ) ProgressNotification { notification := ProgressNotification{ Notification: Notification{ @@ -138,6 +140,7 @@ func NewProgressNotification( ProgressToken ProgressToken `json:"progressToken"` Progress float64 `json:"progress"` Total float64 `json:"total,omitempty"` + Message string `json:"message,omitempty"` }{ ProgressToken: token, Progress: progress, @@ -146,6 +149,9 @@ func NewProgressNotification( if total != nil { notification.Params.Total = *total } + if message != nil { + notification.Params.Message = *message + } return notification } @@ -612,3 +618,105 @@ func ParseReadResourceResult(rawMessage *json.RawMessage) (*ReadResourceResult, return &result, nil } + +func ParseArgument(request CallToolRequest, key string, defaultVal any) any { + if _, ok := request.Params.Arguments[key]; !ok { + return defaultVal + } else { + return request.Params.Arguments[key] + } +} + +// ParseBoolean extracts and converts a boolean parameter from a CallToolRequest. +// If the key is not found in the Arguments map, the defaultValue is returned. +// The function uses cast.ToBool for conversion which handles various string representations +// such as "true", "yes", "1", etc. +func ParseBoolean(request CallToolRequest, key string, defaultValue bool) bool { + v := ParseArgument(request, key, defaultValue) + return cast.ToBool(v) +} + +// ParseInt64 extracts and converts an int64 parameter from a CallToolRequest. +// If the key is not found in the Arguments map, the defaultValue is returned. +func ParseInt64(request CallToolRequest, key string, defaultValue int64) int64 { + v := ParseArgument(request, key, defaultValue) + return cast.ToInt64(v) +} + +// ParseInt32 extracts and converts an int32 parameter from a CallToolRequest. +func ParseInt32(request CallToolRequest, key string, defaultValue int32) int32 { + v := ParseArgument(request, key, defaultValue) + return cast.ToInt32(v) +} + +// ParseInt16 extracts and converts an int16 parameter from a CallToolRequest. +func ParseInt16(request CallToolRequest, key string, defaultValue int16) int16 { + v := ParseArgument(request, key, defaultValue) + return cast.ToInt16(v) +} + +// ParseInt8 extracts and converts an int8 parameter from a CallToolRequest. +func ParseInt8(request CallToolRequest, key string, defaultValue int8) int8 { + v := ParseArgument(request, key, defaultValue) + return cast.ToInt8(v) +} + +// ParseInt extracts and converts an int parameter from a CallToolRequest. +func ParseInt(request CallToolRequest, key string, defaultValue int) int { + v := ParseArgument(request, key, defaultValue) + return cast.ToInt(v) +} + +// ParseUInt extracts and converts an uint parameter from a CallToolRequest. +func ParseUInt(request CallToolRequest, key string, defaultValue uint) uint { + v := ParseArgument(request, key, defaultValue) + return cast.ToUint(v) +} + +// ParseUInt64 extracts and converts an uint64 parameter from a CallToolRequest. +func ParseUInt64(request CallToolRequest, key string, defaultValue uint64) uint64 { + v := ParseArgument(request, key, defaultValue) + return cast.ToUint64(v) +} + +// ParseUInt32 extracts and converts an uint32 parameter from a CallToolRequest. +func ParseUInt32(request CallToolRequest, key string, defaultValue uint32) uint32 { + v := ParseArgument(request, key, defaultValue) + return cast.ToUint32(v) +} + +// ParseUInt16 extracts and converts an uint16 parameter from a CallToolRequest. +func ParseUInt16(request CallToolRequest, key string, defaultValue uint16) uint16 { + v := ParseArgument(request, key, defaultValue) + return cast.ToUint16(v) +} + +// ParseUInt8 extracts and converts an uint8 parameter from a CallToolRequest. +func ParseUInt8(request CallToolRequest, key string, defaultValue uint8) uint8 { + v := ParseArgument(request, key, defaultValue) + return cast.ToUint8(v) +} + +// ParseFloat32 extracts and converts a float32 parameter from a CallToolRequest. +func ParseFloat32(request CallToolRequest, key string, defaultValue float32) float32 { + v := ParseArgument(request, key, defaultValue) + return cast.ToFloat32(v) +} + +// ParseFloat64 extracts and converts a float64 parameter from a CallToolRequest. +func ParseFloat64(request CallToolRequest, key string, defaultValue float64) float64 { + v := ParseArgument(request, key, defaultValue) + return cast.ToFloat64(v) +} + +// ParseString extracts and converts a string parameter from a CallToolRequest. +func ParseString(request CallToolRequest, key string, defaultValue string) string { + v := ParseArgument(request, key, defaultValue) + return cast.ToString(v) +} + +// ParseStringMap extracts and converts a string map parameter from a CallToolRequest. +func ParseStringMap(request CallToolRequest, key string, defaultValue map[string]any) map[string]any { + v := ParseArgument(request, key, defaultValue) + return cast.ToStringMap(v) +} diff --git a/server/internal/gen/request_handler.go.tmpl b/server/internal/gen/request_handler.go.tmpl index 4e139e17..5c69f5fa 100644 --- a/server/internal/gen/request_handler.go.tmpl +++ b/server/internal/gen/request_handler.go.tmpl @@ -24,6 +24,7 @@ func (s *MCPServer) HandleMessage( JSONRPC string `json:"jsonrpc"` Method mcp.MCPMethod `json:"method"` ID any `json:"id,omitempty"` + Result any `json:"result,omitempty"` } if err := json.Unmarshal(message, &baseMessage); err != nil { @@ -56,6 +57,12 @@ func (s *MCPServer) HandleMessage( return nil // Return nil for notifications } + if baseMessage.Result != nil { + // this is a response to a request sent by the server (e.g. from a ping + // sent due to WithKeepAlive option) + return nil + } + switch baseMessage.Method { {{- range .}} case mcp.{{.MethodName}}: diff --git a/server/request_handler.go b/server/request_handler.go index 946ca7ab..55d2d19e 100644 --- a/server/request_handler.go +++ b/server/request_handler.go @@ -23,6 +23,7 @@ func (s *MCPServer) HandleMessage( JSONRPC string `json:"jsonrpc"` Method mcp.MCPMethod `json:"method"` ID any `json:"id,omitempty"` + Result any `json:"result,omitempty"` } if err := json.Unmarshal(message, &baseMessage); err != nil { @@ -55,6 +56,12 @@ func (s *MCPServer) HandleMessage( return nil // Return nil for notifications } + if baseMessage.Result != nil { + // this is a response to a request sent by the server (e.g. from a ping + // sent due to WithKeepAlive option) + return nil + } + switch baseMessage.Method { case mcp.MethodInitialize: var request mcp.InitializeRequest diff --git a/server/resource_test.go b/server/resource_test.go new file mode 100644 index 00000000..94b35a3d --- /dev/null +++ b/server/resource_test.go @@ -0,0 +1,253 @@ +package server + +import ( + "context" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMCPServer_RemoveResource(t *testing.T) { + tests := []struct { + name string + action func(*testing.T, *MCPServer, chan mcp.JSONRPCNotification) + expectedNotifications int + validate func(*testing.T, []mcp.JSONRPCNotification, mcp.JSONRPCMessage) + }{ + { + name: "RemoveResource removes the resource from the server", + action: func(t *testing.T, server *MCPServer, notificationChannel chan mcp.JSONRPCNotification) { + // Add a test resource + server.AddResource( + mcp.NewResource( + "test://resource1", + "Resource 1", + mcp.WithResourceDescription("Test resource 1"), + mcp.WithMIMEType("text/plain"), + ), + func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + return []mcp.ResourceContents{ + mcp.TextResourceContents{ + URI: "test://resource1", + MIMEType: "text/plain", + Text: "test content 1", + }, + }, nil + }, + ) + + // Add a second resource + server.AddResource( + mcp.NewResource( + "test://resource2", + "Resource 2", + mcp.WithResourceDescription("Test resource 2"), + mcp.WithMIMEType("text/plain"), + ), + func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + return []mcp.ResourceContents{ + mcp.TextResourceContents{ + URI: "test://resource2", + MIMEType: "text/plain", + Text: "test content 2", + }, + }, nil + }, + ) + + // First, verify we have two resources + response := server.HandleMessage(context.Background(), []byte(`{ + "jsonrpc": "2.0", + "id": 1, + "method": "resources/list" + }`)) + resp, ok := response.(mcp.JSONRPCResponse) + assert.True(t, ok) + result, ok := resp.Result.(mcp.ListResourcesResult) + assert.True(t, ok) + assert.Len(t, result.Resources, 2) + + // Now register session to receive notifications + err := server.RegisterSession(context.TODO(), &fakeSession{ + sessionID: "test", + notificationChannel: notificationChannel, + initialized: true, + }) + require.NoError(t, err) + + // Now remove one resource + server.RemoveResource("test://resource1") + }, + expectedNotifications: 1, + validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, resourcesList mcp.JSONRPCMessage) { + // Check that we received a list_changed notification + assert.Equal(t, "resources/list_changed", notifications[0].Method) + + // Verify we now have only one resource + resp, ok := resourcesList.(mcp.JSONRPCResponse) + assert.True(t, ok, "Expected JSONRPCResponse, got %T", resourcesList) + + result, ok := resp.Result.(mcp.ListResourcesResult) + assert.True(t, ok, "Expected ListResourcesResult, got %T", resp.Result) + + assert.Len(t, result.Resources, 1) + assert.Equal(t, "Resource 2", result.Resources[0].Name) + }, + }, + { + name: "RemoveResource with non-existent resource does nothing", + action: func(t *testing.T, server *MCPServer, notificationChannel chan mcp.JSONRPCNotification) { + // Add a test resource + server.AddResource( + mcp.NewResource( + "test://resource1", + "Resource 1", + mcp.WithResourceDescription("Test resource 1"), + mcp.WithMIMEType("text/plain"), + ), + func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + return []mcp.ResourceContents{ + mcp.TextResourceContents{ + URI: "test://resource1", + MIMEType: "text/plain", + Text: "test content 1", + }, + }, nil + }, + ) + + // Register session to receive notifications + err := server.RegisterSession(context.TODO(), &fakeSession{ + sessionID: "test", + notificationChannel: notificationChannel, + initialized: true, + }) + require.NoError(t, err) + + // Remove a non-existent resource + server.RemoveResource("test://nonexistent") + }, + expectedNotifications: 1, // Still sends a notification + validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, resourcesList mcp.JSONRPCMessage) { + // Check that we received a list_changed notification + assert.Equal(t, "resources/list_changed", notifications[0].Method) + + // The original resource should still be there + resp, ok := resourcesList.(mcp.JSONRPCResponse) + assert.True(t, ok) + + result, ok := resp.Result.(mcp.ListResourcesResult) + assert.True(t, ok) + + assert.Len(t, result.Resources, 1) + assert.Equal(t, "Resource 1", result.Resources[0].Name) + }, + }, + { + name: "RemoveResource with no listChanged capability doesn't send notification", + action: func(t *testing.T, server *MCPServer, notificationChannel chan mcp.JSONRPCNotification) { + // Create a new server without listChanged capability + noListChangedServer := NewMCPServer( + "test-server", + "1.0.0", + WithResourceCapabilities(true, false), // Subscribe but not listChanged + ) + + // Add a resource + noListChangedServer.AddResource( + mcp.NewResource( + "test://resource1", + "Resource 1", + mcp.WithResourceDescription("Test resource 1"), + mcp.WithMIMEType("text/plain"), + ), + func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + return []mcp.ResourceContents{ + mcp.TextResourceContents{ + URI: "test://resource1", + MIMEType: "text/plain", + Text: "test content 1", + }, + }, nil + }, + ) + + // Register session to receive notifications + err := noListChangedServer.RegisterSession(context.TODO(), &fakeSession{ + sessionID: "test", + notificationChannel: notificationChannel, + initialized: true, + }) + require.NoError(t, err) + + // Remove the resource + noListChangedServer.RemoveResource("test://resource1") + + // The test can now proceed without waiting for notifications + // since we don't expect any + }, + expectedNotifications: 0, // No notifications expected + validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, resourcesList mcp.JSONRPCMessage) { + // Nothing to do here, we're just verifying that no notifications were sent + assert.Empty(t, notifications) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + server := NewMCPServer( + "test-server", + "1.0.0", + WithResourceCapabilities(true, true), + ) + + // Initialize the server + _ = server.HandleMessage(ctx, []byte(`{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize" + }`)) + + notificationChannel := make(chan mcp.JSONRPCNotification, 100) + notifications := make([]mcp.JSONRPCNotification, 0) + + tt.action(t, server, notificationChannel) + + // Collect notifications with a timeout + if tt.expectedNotifications > 0 { + for i := 0; i < tt.expectedNotifications; i++ { + select { + case notification := <-notificationChannel: + notifications = append(notifications, notification) + case <-time.After(1 * time.Second): + t.Fatalf("Expected %d notifications but only received %d", tt.expectedNotifications, len(notifications)) + } + } + } else { + // If no notifications expected, wait a brief period to ensure none are sent + select { + case notification := <-notificationChannel: + notifications = append(notifications, notification) + case <-time.After(100 * time.Millisecond): + // This is the expected path - no notifications + } + } + + // Get final resources list + listMessage := `{ + "jsonrpc": "2.0", + "id": 1, + "method": "resources/list" + }` + resourcesList := server.HandleMessage(ctx, []byte(listMessage)) + + // Validate the results + tt.validate(t, notifications, resourcesList) + }) + } +} diff --git a/server/server.go b/server/server.go index 5b2d739d..8ebd40bd 100644 --- a/server/server.go +++ b/server/server.go @@ -419,6 +419,18 @@ func (s *MCPServer) AddResource( } } +// RemoveResource removes a resource from the server +func (s *MCPServer) RemoveResource(uri string) { + s.resourcesMu.Lock() + delete(s.resources, uri) + s.resourcesMu.Unlock() + + // Send notification to all initialized sessions if listChanged capability is enabled + if s.capabilities.resources != nil && s.capabilities.resources.listChanged { + s.sendNotificationToAllClients("resources/list_changed", nil) + } +} + // AddResourceTemplate registers a new resource template and its handler func (s *MCPServer) AddResourceTemplate( template mcp.ResourceTemplate, diff --git a/server/sse.go b/server/sse.go index f69451c6..b6ae2144 100644 --- a/server/sse.go +++ b/server/sse.go @@ -23,6 +23,7 @@ type sseSession struct { done chan struct{} eventQueue chan string // Channel for queuing events sessionID string + requestID atomic.Int64 notificationChannel chan mcp.JSONRPCNotification initialized atomic.Bool } @@ -53,18 +54,20 @@ var _ ClientSession = (*sseSession)(nil) // SSEServer implements a Server-Sent Events (SSE) based MCP server. // It provides real-time communication capabilities over HTTP using the SSE protocol. type SSEServer struct { - server *MCPServer - baseURL string - basePath string - useFullURLForMessageEndpoint bool - messageEndpoint string - sseEndpoint string - sessions sync.Map - srv *http.Server - contextFunc SSEContextFunc + server *MCPServer + baseURL string + basePath string + useFullURLForMessageEndpoint bool + messageEndpoint string + sseEndpoint string + sessions sync.Map + srv *http.Server + contextFunc SSEContextFunc keepAlive bool keepAliveInterval time.Duration + + mu sync.RWMutex } // SSEOption defines a function type for configuring SSEServer @@ -158,12 +161,12 @@ func WithSSEContextFunc(fn SSEContextFunc) SSEOption { // NewSSEServer creates a new SSE server instance with the given MCP server and options. func NewSSEServer(server *MCPServer, opts ...SSEOption) *SSEServer { s := &SSEServer{ - server: server, - sseEndpoint: "/sse", - messageEndpoint: "/message", - useFullURLForMessageEndpoint: true, - keepAlive: false, - keepAliveInterval: 10 * time.Second, + server: server, + sseEndpoint: "/sse", + messageEndpoint: "/message", + useFullURLForMessageEndpoint: true, + keepAlive: false, + keepAliveInterval: 10 * time.Second, } // Apply all options @@ -189,10 +192,12 @@ func NewTestServer(server *MCPServer, opts ...SSEOption) *httptest.Server { // Start begins serving SSE connections on the specified address. // It sets up HTTP handlers for SSE and message endpoints. func (s *SSEServer) Start(addr string) error { + s.mu.Lock() s.srv = &http.Server{ Addr: addr, Handler: s, } + s.mu.Unlock() return s.srv.ListenAndServe() } @@ -200,7 +205,11 @@ func (s *SSEServer) Start(addr string) error { // Shutdown gracefully stops the SSE server, closing all active sessions // and shutting down the HTTP server. func (s *SSEServer) Shutdown(ctx context.Context) error { - if s.srv != nil { + s.mu.RLock() + srv := s.srv + s.mu.RUnlock() + + if srv != nil { s.sessions.Range(func(key, value interface{}) bool { if session, ok := value.(*sseSession); ok { close(session.done) @@ -209,7 +218,7 @@ func (s *SSEServer) Shutdown(ctx context.Context) error { return true }) - return s.srv.Shutdown(ctx) + return srv.Shutdown(ctx) } return nil } @@ -282,8 +291,16 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) { for { select { case <-ticker.C: - //: ping - 2025-03-27 07:44:38.682659+00:00 - session.eventQueue <- fmt.Sprintf(":ping - %s\n\n", time.Now().Format(time.RFC3339)) + message := mcp.JSONRPCRequest{ + JSONRPC: "2.0", + ID: session.requestID.Add(1), + Request: mcp.Request{ + Method: "ping", + }, + } + messageBytes, _ := json.Marshal(message) + pingMsg := fmt.Sprintf("event: message\ndata:%s\n\n", messageBytes) + session.eventQueue <- pingMsg case <-session.done: return case <-r.Context().Done(): @@ -293,7 +310,6 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) { }() } - // Send the initial endpoint event fmt.Fprintf(w, "event: endpoint\ndata: %s\r\n\r\n", s.GetMessageEndpointForClient(sessionID)) flusher.Flush() @@ -335,7 +351,6 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) { s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Missing sessionId") return } - sessionI, ok := s.sessions.Load(sessionID) if !ok { s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Invalid session ID") diff --git a/server/sse_test.go b/server/sse_test.go index 111c5845..93474c6b 100644 --- a/server/sse_test.go +++ b/server/sse_test.go @@ -1,10 +1,12 @@ package server import ( + "bufio" "bytes" "context" "encoding/json" "fmt" + "io" "math/rand" "net/http" "net/http/httptest" @@ -739,4 +741,117 @@ func TestSSEServer(t *testing.T) { } } }) + + t.Run("Client receives and can respond to ping messages", func(t *testing.T) { + mcpServer := NewMCPServer("test", "1.0.0") + testServer := NewTestServer(mcpServer, + WithKeepAlive(true), + WithKeepAliveInterval(50*time.Millisecond), + ) + defer testServer.Close() + + sseResp, err := http.Get(fmt.Sprintf("%s/sse", testServer.URL)) + if err != nil { + t.Fatalf("Failed to connect to SSE endpoint: %v", err) + } + defer sseResp.Body.Close() + + reader := bufio.NewReader(sseResp.Body) + + var messageURL string + var pingID float64 + + for { + line, err := reader.ReadString('\n') + if err != nil { + t.Fatalf("Failed to read SSE event: %v", err) + } + + if strings.HasPrefix(line, "event: endpoint") { + dataLine, err := reader.ReadString('\n') + if err != nil { + t.Fatalf("Failed to read endpoint data: %v", err) + } + messageURL = strings.TrimSpace(strings.TrimPrefix(dataLine, "data: ")) + + _, err = reader.ReadString('\n') + if err != nil { + t.Fatalf("Failed to read blank line: %v", err) + } + } + + if strings.HasPrefix(line, "event: message") { + dataLine, err := reader.ReadString('\n') + if err != nil { + t.Fatalf("Failed to read message data: %v", err) + } + + pingData := strings.TrimSpace(strings.TrimPrefix(dataLine, "data:")) + var pingMsg mcp.JSONRPCRequest + if err := json.Unmarshal([]byte(pingData), &pingMsg); err != nil { + t.Fatalf("Failed to parse ping message: %v", err) + } + + if pingMsg.Method == "ping" { + pingID = pingMsg.ID.(float64) + t.Logf("Received ping with ID: %f", pingID) + break // We got the ping, exit the loop + } + + _, err = reader.ReadString('\n') + if err != nil { + t.Fatalf("Failed to read blank line: %v", err) + } + } + + if messageURL != "" && pingID != 0 { + break + } + } + + if messageURL == "" { + t.Fatal("Did not receive message endpoint URL") + } + + pingResponse := map[string]any{ + "jsonrpc": "2.0", + "id": pingID, + "result": map[string]any{}, + } + + requestBody, err := json.Marshal(pingResponse) + if err != nil { + t.Fatalf("Failed to marshal ping response: %v", err) + } + + resp, err := http.Post( + messageURL, + "application/json", + bytes.NewBuffer(requestBody), + ) + if err != nil { + t.Fatalf("Failed to send ping response: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusAccepted { + t.Errorf("Expected status 202 for ping response, got %d", resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Failed to read response body: %v", err) + } + + if len(body) > 0 { + var response map[string]any + if err := json.Unmarshal(body, &response); err != nil { + t.Fatalf("Failed to parse response body: %v", err) + } + + if response["error"] != nil { + t.Errorf("Expected no error in response, got %v", response["error"]) + } + } + }) }