From 3ba0c9170c46d33db02e6741e34dd8f7830c9c7f Mon Sep 17 00:00:00 2001 From: Robert Jackson Date: Thu, 1 May 2025 09:51:56 -0400 Subject: [PATCH 01/10] feat(sse): Add `SessionWithTools` support to SSEServer (#232) Implement SessionWithTools interface for sseSession to support session-specific tools: - Add tools field to sseSession struct - Implement GetSessionTools and SetSessionTools methods --- server/sse.go | 30 +++++++++++- server/sse_test.go | 113 ++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 141 insertions(+), 2 deletions(-) diff --git a/server/sse.go b/server/sse.go index e380d20a..7994c606 100644 --- a/server/sse.go +++ b/server/sse.go @@ -28,6 +28,7 @@ type sseSession struct { requestID atomic.Int64 notificationChannel chan mcp.JSONRPCNotification initialized atomic.Bool + tools sync.Map // stores session-specific tools } // SSEContextFunc is a function that takes an existing context and the current @@ -58,7 +59,34 @@ func (s *sseSession) Initialized() bool { return s.initialized.Load() } -var _ ClientSession = (*sseSession)(nil) +func (s *sseSession) GetSessionTools() map[string]ServerTool { + tools := make(map[string]ServerTool) + s.tools.Range(func(key, value interface{}) bool { + if tool, ok := value.(ServerTool); ok { + tools[key.(string)] = tool + } + return true + }) + return tools +} + +func (s *sseSession) SetSessionTools(tools map[string]ServerTool) { + // Clear existing tools + s.tools.Range(func(key, _ interface{}) bool { + s.tools.Delete(key) + return true + }) + + // Set new tools + for name, tool := range tools { + s.tools.Store(name, tool) + } +} + +var ( + _ ClientSession = (*sseSession)(nil) + _ SessionWithTools = (*sseSession)(nil) +) // SSEServer implements a Server-Sent Events (SSE) based MCP server. // It provides real-time communication capabilities over HTTP using the SSE protocol. diff --git a/server/sse_test.go b/server/sse_test.go index a121581a..75da1eac 100644 --- a/server/sse_test.go +++ b/server/sse_test.go @@ -666,7 +666,8 @@ func TestSSEServer(t *testing.T) { t.Fatalf("Failed to marshal tool request: %v", err) } - req, err := http.NewRequest(http.MethodPost, messageURL, bytes.NewBuffer(requestBody)) + var req *http.Request + req, err = http.NewRequest(http.MethodPost, messageURL, bytes.NewBuffer(requestBody)) if err != nil { t.Fatalf("Failed to create tool request: %v", err) } @@ -1129,6 +1130,116 @@ func TestSSEServer(t *testing.T) { }) } }) + + t.Run("SessionWithTools implementation", func(t *testing.T) { + // Create hooks to track sessions + hooks := &Hooks{} + var registeredSession *sseSession + hooks.AddOnRegisterSession(func(ctx context.Context, session ClientSession) { + if s, ok := session.(*sseSession); ok { + registeredSession = s + } + }) + + mcpServer := NewMCPServer("test", "1.0.0", WithHooks(hooks)) + testServer := NewTestServer(mcpServer) + defer testServer.Close() + + // Connect to SSE endpoint + sseResp, err := http.Get(fmt.Sprintf("%s/sse", testServer.URL)) + if err != nil { + t.Fatalf("Failed to connect to SSE endpoint: %v", err) + } + defer sseResp.Body.Close() + + // Read the endpoint event to ensure session is established + _, err = readSeeEvent(sseResp) + if err != nil { + t.Fatalf("Failed to read SSE response: %v", err) + } + + // Verify we got a session + if registeredSession == nil { + t.Fatal("Session was not registered via hook") + } + + // Test setting and getting tools + tools := map[string]ServerTool{ + "test_tool": { + Tool: mcp.Tool{ + Name: "test_tool", + Description: "A test tool", + Annotations: mcp.ToolAnnotation{ + Title: "Test Tool", + }, + }, + Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return mcp.NewToolResultText("test"), nil + }, + }, + } + + // Test SetSessionTools + registeredSession.SetSessionTools(tools) + + // Test GetSessionTools + retrievedTools := registeredSession.GetSessionTools() + if len(retrievedTools) != 1 { + t.Errorf("Expected 1 tool, got %d", len(retrievedTools)) + } + if tool, exists := retrievedTools["test_tool"]; !exists { + t.Error("Expected test_tool to exist") + } else if tool.Tool.Name != "test_tool" { + t.Errorf("Expected tool name test_tool, got %s", tool.Tool.Name) + } + + // Test concurrent access + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(2) + go func(i int) { + defer wg.Done() + tools := map[string]ServerTool{ + fmt.Sprintf("tool_%d", i): { + Tool: mcp.Tool{ + Name: fmt.Sprintf("tool_%d", i), + Description: fmt.Sprintf("Tool %d", i), + Annotations: mcp.ToolAnnotation{ + Title: fmt.Sprintf("Tool %d", i), + }, + }, + }, + } + registeredSession.SetSessionTools(tools) + }(i) + go func() { + defer wg.Done() + _ = registeredSession.GetSessionTools() + }() + } + wg.Wait() + + // Verify we can still get and set tools after concurrent access + finalTools := map[string]ServerTool{ + "final_tool": { + Tool: mcp.Tool{ + Name: "final_tool", + Description: "Final Tool", + Annotations: mcp.ToolAnnotation{ + Title: "Final Tool", + }, + }, + }, + } + registeredSession.SetSessionTools(finalTools) + retrievedTools = registeredSession.GetSessionTools() + if len(retrievedTools) != 1 { + t.Errorf("Expected 1 tool, got %d", len(retrievedTools)) + } + if _, exists := retrievedTools["final_tool"]; !exists { + t.Error("Expected final_tool to exist") + } + }) } func readSeeEvent(sseResp *http.Response) (string, error) { From f3fef81032fde6519525abf64a9f67afdb0b3e38 Mon Sep 17 00:00:00 2001 From: Roman Gelembjuk Date: Fri, 2 May 2025 08:55:46 +0100 Subject: [PATCH 02/10] Fix bug with MarshalJSON for NotificationParams (#233) Co-authored-by: Roman Gelembjuk --- mcp/types.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mcp/types.go b/mcp/types.go index c79baae1..516f90b4 100644 --- a/mcp/types.go +++ b/mcp/types.go @@ -132,7 +132,7 @@ type NotificationParams struct { } // MarshalJSON implements custom JSON marshaling -func (p *NotificationParams) MarshalJSON() ([]byte, error) { +func (p NotificationParams) MarshalJSON() ([]byte, error) { // Create a map to hold all fields m := make(map[string]interface{}) From 90bd8779a6895dd5ef1eb6ac9084653d783287ca Mon Sep 17 00:00:00 2001 From: QihengZhou Date: Fri, 2 May 2025 18:34:37 +0800 Subject: [PATCH 03/10] fix: write back error message if the response marshal failed (#235) --- server/sse.go | 2 - server/sse_test.go | 102 +++++++++++++++++++++++++++++++++++++++------ 2 files changed, 90 insertions(+), 14 deletions(-) diff --git a/server/sse.go b/server/sse.go index 7994c606..94dee192 100644 --- a/server/sse.go +++ b/server/sse.go @@ -457,7 +457,6 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) { go func() { // Process message through MCPServer response := s.server.HandleMessage(ctx, rawMessage) - // Only send response if there is one (not for notifications) if response != nil { var message string @@ -465,7 +464,6 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) { // If there is an error marshalling the response, send a generic error response log.Printf("failed to marshal response: %v", err) message = fmt.Sprintf("event: message\ndata: {\"error\": \"internal error\",\"jsonrpc\": \"2.0\", \"id\": null}\n\n") - return } else { message = fmt.Sprintf("event: message\ndata: %s\n\n", eventData) } diff --git a/server/sse_test.go b/server/sse_test.go index 75da1eac..937dc274 100644 --- a/server/sse_test.go +++ b/server/sse_test.go @@ -62,7 +62,7 @@ func TestSSEServer(t *testing.T) { defer sseResp.Body.Close() // Read the endpoint event - endpointEvent, err := readSeeEvent(sseResp) + endpointEvent, err := readSSEEvent(sseResp) if err != nil { t.Fatalf("Failed to read SSE response: %v", err) } @@ -195,7 +195,7 @@ func TestSSEServer(t *testing.T) { } defer resp.Body.Close() - endpointEvent, err = readSeeEvent(sseResp) + endpointEvent, err = readSSEEvent(sseResp) if err != nil { t.Fatalf("Failed to read SSE response: %v", err) } @@ -590,7 +590,7 @@ func TestSSEServer(t *testing.T) { defer sseResp.Body.Close() // Read the endpoint event - endpointEvent, err := readSeeEvent(sseResp) + endpointEvent, err := readSSEEvent(sseResp) if err != nil { t.Fatalf("Failed to read SSE response: %v", err) } @@ -632,16 +632,16 @@ func TestSSEServer(t *testing.T) { } // Verify response - endpointEvent, err = readSeeEvent(sseResp) + endpointEvent, err = readSSEEvent(sseResp) if err != nil { t.Fatalf("Failed to read SSE response: %v", err) } - respFromSee := strings.TrimSpace( + respFromSSE := strings.TrimSpace( strings.Split(strings.Split(endpointEvent, "data: ")[1], "\n")[0], ) var response map[string]interface{} - if err := json.NewDecoder(strings.NewReader(respFromSee)).Decode(&response); err != nil { + if err := json.NewDecoder(strings.NewReader(respFromSSE)).Decode(&response); err != nil { t.Fatalf("Failed to decode response: %v", err) } @@ -680,17 +680,17 @@ func TestSSEServer(t *testing.T) { } defer resp.Body.Close() - endpointEvent, err = readSeeEvent(sseResp) + endpointEvent, err = readSSEEvent(sseResp) if err != nil { t.Fatalf("Failed to read SSE response: %v", err) } - respFromSee = strings.TrimSpace( + respFromSSE = strings.TrimSpace( strings.Split(strings.Split(endpointEvent, "data: ")[1], "\n")[0], ) response = make(map[string]interface{}) - if err := json.NewDecoder(strings.NewReader(respFromSee)).Decode(&response); err != nil { + if err := json.NewDecoder(strings.NewReader(respFromSSE)).Decode(&response); err != nil { t.Fatalf("Failed to decode response: %v", err) } @@ -1140,7 +1140,7 @@ func TestSSEServer(t *testing.T) { registeredSession = s } }) - + mcpServer := NewMCPServer("test", "1.0.0", WithHooks(hooks)) testServer := NewTestServer(mcpServer) defer testServer.Close() @@ -1153,7 +1153,7 @@ func TestSSEServer(t *testing.T) { defer sseResp.Body.Close() // Read the endpoint event to ensure session is established - _, err = readSeeEvent(sseResp) + _, err = readSSEEvent(sseResp) if err != nil { t.Fatalf("Failed to read SSE response: %v", err) } @@ -1240,9 +1240,87 @@ func TestSSEServer(t *testing.T) { t.Error("Expected final_tool to exist") } }) + + t.Run("TestServerResponseMarshalError", func(t *testing.T) { + mcpServer := NewMCPServer("test", "1.0.0", + WithResourceCapabilities(true, true), + WithHooks(&Hooks{ + OnAfterInitialize: []OnAfterInitializeFunc{ + func(ctx context.Context, id any, message *mcp.InitializeRequest, result *mcp.InitializeResult) { + result.Result.Meta = map[string]interface{}{"invalid": func() {}} // marshal will fail + }, + }, + }), + ) + testServer := NewTestServer(mcpServer) + defer testServer.Close() + + // Connect to SSE endpoint + sseResp, err := http.Get(fmt.Sprintf("%s/sse", testServer.URL)) + if err != nil { + t.Fatalf("Failed to connect to SSE endpoint: %v", err) + } + defer sseResp.Body.Close() + + // Read the endpoint event + endpointEvent, err := readSSEEvent(sseResp) + if err != nil { + t.Fatalf("Failed to read SSE response: %v", err) + } + if !strings.Contains(endpointEvent, "event: endpoint") { + t.Fatalf("Expected endpoint event, got: %s", endpointEvent) + } + + // Extract message endpoint URL + messageURL := strings.TrimSpace( + strings.Split(strings.Split(endpointEvent, "data: ")[1], "\n")[0], + ) + + // Send initialize request + initRequest := map[string]interface{}{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": map[string]interface{}{ + "protocolVersion": "2024-11-05", + "clientInfo": map[string]interface{}{ + "name": "test-client", + "version": "1.0.0", + }, + }, + } + + requestBody, err := json.Marshal(initRequest) + if err != nil { + t.Fatalf("Failed to marshal request: %v", err) + } + + resp, err := http.Post( + messageURL, + "application/json", + bytes.NewBuffer(requestBody), + ) + if err != nil { + t.Fatalf("Failed to send message: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusAccepted { + t.Errorf("Expected status 202, got %d", resp.StatusCode) + } + + endpointEvent, err = readSSEEvent(sseResp) + if err != nil { + t.Fatalf("Failed to read SSE response: %v", err) + } + + if !strings.Contains(endpointEvent, "\"id\": null") { + t.Errorf("Expected id to be null") + } + }) } -func readSeeEvent(sseResp *http.Response) (string, error) { +func readSSEEvent(sseResp *http.Response) (string, error) { buf := make([]byte, 1024) n, err := sseResp.Body.Read(buf) if err != nil { From 6d55e4eb867e1911cda3326c7c863e178df4d645 Mon Sep 17 00:00:00 2001 From: cryo Date: Sun, 4 May 2025 00:59:01 +0800 Subject: [PATCH 04/10] fix(server/sse): potential goroutine leak in Heartbeat sender (#236) --- server/sse.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/server/sse.go b/server/sse.go index 94dee192..90fde667 100644 --- a/server/sse.go +++ b/server/sse.go @@ -367,7 +367,12 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) { } messageBytes, _ := json.Marshal(message) pingMsg := fmt.Sprintf("event: message\ndata:%s\n\n", messageBytes) - session.eventQueue <- pingMsg + select { + case session.eventQueue <- pingMsg: + // Message sent successfully + case <-session.done: + return + } case <-session.done: return case <-r.Context().Done(): From 524448985ec3ddcbcfe6680365f81c532cc1fd07 Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Sun, 4 May 2025 13:59:15 +0300 Subject: [PATCH 05/10] Fix stdio test compilation issues in CI (#240) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR fixes the test failures in CI by: 1. Using -buildmode=pie flag when compiling test binaries 2. Using os.CreateTemp() for more reliable temporary file creation 3. Verifying binary existence after compilation 4. Fixing variable shadowing issues 🤖 Generated with opencode Co-Authored-By: opencode --- client/stdio_test.go | 23 ++++++++--- client/transport/stdio_test.go | 75 +++++++++++++++++++++++----------- 2 files changed, 70 insertions(+), 28 deletions(-) diff --git a/client/stdio_test.go b/client/stdio_test.go index 7bffa3b2..8c9ff299 100644 --- a/client/stdio_test.go +++ b/client/stdio_test.go @@ -7,7 +7,7 @@ import ( "log/slog" "os" "os/exec" - "path/filepath" + "runtime" "sync" "testing" "time" @@ -19,6 +19,7 @@ func compileTestServer(outputPath string) error { cmd := exec.Command( "go", "build", + "-buildmode=pie", "-o", outputPath, "../testdata/mockstdio_server.go", @@ -33,10 +34,22 @@ func compileTestServer(outputPath string) error { } func TestStdioMCPClient(t *testing.T) { - // Compile mock server - mockServerPath := filepath.Join(os.TempDir(), "mockstdio_server") - if err := compileTestServer(mockServerPath); err != nil { - t.Fatalf("Failed to compile mock server: %v", err) + // Create a temporary file for the mock server + tempFile, err := os.CreateTemp("", "mockstdio_server") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + tempFile.Close() + mockServerPath := tempFile.Name() + + // Add .exe suffix on Windows + if runtime.GOOS == "windows" { + os.Remove(mockServerPath) // Remove the empty file first + mockServerPath += ".exe" + } + + if compileErr := compileTestServer(mockServerPath); compileErr != nil { + t.Fatalf("Failed to compile mock server: %v", compileErr) } defer os.Remove(mockServerPath) diff --git a/client/transport/stdio_test.go b/client/transport/stdio_test.go index aa728ec6..53db7a0f 100644 --- a/client/transport/stdio_test.go +++ b/client/transport/stdio_test.go @@ -6,7 +6,6 @@ import ( "fmt" "os" "os/exec" - "path/filepath" "runtime" "sync" "testing" @@ -19,6 +18,7 @@ func compileTestServer(outputPath string) error { cmd := exec.Command( "go", "build", + "-buildmode=pie", "-o", outputPath, "../../testdata/mockstdio_server.go", @@ -26,18 +26,30 @@ func compileTestServer(outputPath string) error { if output, err := cmd.CombinedOutput(); err != nil { return fmt.Errorf("compilation failed: %v\nOutput: %s", err, output) } + // Verify the binary was actually created + if _, err := os.Stat(outputPath); os.IsNotExist(err) { + return fmt.Errorf("mock server binary not found at %s after compilation", outputPath) + } return nil } func TestStdio(t *testing.T) { - // Compile mock server - mockServerPath := filepath.Join(os.TempDir(), "mockstdio_server") + // Create a temporary file for the mock server + tempFile, err := os.CreateTemp("", "mockstdio_server") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + tempFile.Close() + mockServerPath := tempFile.Name() + // Add .exe suffix on Windows if runtime.GOOS == "windows" { + os.Remove(mockServerPath) // Remove the empty file first mockServerPath += ".exe" } - if err := compileTestServer(mockServerPath); err != nil { - t.Fatalf("Failed to compile mock server: %v", err) + + if compileErr := compileTestServer(mockServerPath); compileErr != nil { + t.Fatalf("Failed to compile mock server: %v", compileErr) } defer os.Remove(mockServerPath) @@ -48,9 +60,9 @@ func TestStdio(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - err := stdio.Start(ctx) - if err != nil { - t.Fatalf("Failed to start Stdio transport: %v", err) + startErr := stdio.Start(ctx) + if startErr != nil { + t.Fatalf("Failed to start Stdio transport: %v", startErr) } defer stdio.Close() @@ -307,13 +319,22 @@ func TestStdioErrors(t *testing.T) { }) t.Run("RequestBeforeStart", func(t *testing.T) { - mockServerPath := filepath.Join(os.TempDir(), "mockstdio_server") + // Create a temporary file for the mock server + tempFile, err := os.CreateTemp("", "mockstdio_server") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + tempFile.Close() + mockServerPath := tempFile.Name() + // Add .exe suffix on Windows if runtime.GOOS == "windows" { + os.Remove(mockServerPath) // Remove the empty file first mockServerPath += ".exe" } - if err := compileTestServer(mockServerPath); err != nil { - t.Fatalf("Failed to compile mock server: %v", err) + + if compileErr := compileTestServer(mockServerPath); compileErr != nil { + t.Fatalf("Failed to compile mock server: %v", compileErr) } defer os.Remove(mockServerPath) @@ -328,23 +349,31 @@ func TestStdioErrors(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) defer cancel() - _, err := uninitiatedStdio.SendRequest(ctx, request) - if err == nil { + _, reqErr := uninitiatedStdio.SendRequest(ctx, request) + if reqErr == nil { t.Errorf("Expected SendRequest to panic before Start(), but it didn't") - } else if err.Error() != "stdio client not started" { - t.Errorf("Expected error 'stdio client not started', got: %v", err) + } else if reqErr.Error() != "stdio client not started" { + t.Errorf("Expected error 'stdio client not started', got: %v", reqErr) } }) t.Run("RequestAfterClose", func(t *testing.T) { - // Compile mock server - mockServerPath := filepath.Join(os.TempDir(), "mockstdio_server") + // Create a temporary file for the mock server + tempFile, err := os.CreateTemp("", "mockstdio_server") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + tempFile.Close() + mockServerPath := tempFile.Name() + // Add .exe suffix on Windows if runtime.GOOS == "windows" { + os.Remove(mockServerPath) // Remove the empty file first mockServerPath += ".exe" } - if err := compileTestServer(mockServerPath); err != nil { - t.Fatalf("Failed to compile mock server: %v", err) + + if compileErr := compileTestServer(mockServerPath); compileErr != nil { + t.Fatalf("Failed to compile mock server: %v", compileErr) } defer os.Remove(mockServerPath) @@ -353,8 +382,8 @@ func TestStdioErrors(t *testing.T) { // Start the transport ctx := context.Background() - if err := stdio.Start(ctx); err != nil { - t.Fatalf("Failed to start Stdio transport: %v", err) + if startErr := stdio.Start(ctx); startErr != nil { + t.Fatalf("Failed to start Stdio transport: %v", startErr) } // Close the transport - ignore errors like "broken pipe" since the process might exit already @@ -370,8 +399,8 @@ func TestStdioErrors(t *testing.T) { Method: "ping", } - _, err := stdio.SendRequest(ctx, request) - if err == nil { + _, sendErr := stdio.SendRequest(ctx, request) + if sendErr == nil { t.Errorf("Expected error when sending request after close, got nil") } }) From 2f24f3f146cd006eee8e75112ce631168a3efb9c Mon Sep 17 00:00:00 2001 From: Robert Jackson Date: Sun, 4 May 2025 11:37:10 -0400 Subject: [PATCH 06/10] refactor(server/sse): rename WithBasePath to WithStaticBasePath for clarity (#238) The new name makes its relationship to `WithDynamicBasePath` clearer. The implementation preserves the original functionality with a build time warning (in go 1.21+). --- server/sse.go | 13 +++++++++++-- server/sse_test.go | 6 +++--- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/server/sse.go b/server/sse.go index 90fde667..8467b02f 100644 --- a/server/sse.go +++ b/server/sse.go @@ -135,13 +135,22 @@ func WithBaseURL(baseURL string) SSEOption { } } -// WithBasePath adds a new option for setting a static base path -func WithBasePath(basePath string) SSEOption { +// WithStaticBasePath adds a new option for setting a static base path +func WithStaticBasePath(basePath string) SSEOption { return func(s *SSEServer) { s.basePath = normalizeURLPath(basePath) } } +// WithBasePath adds a new option for setting a static base path. +// +// Deprecated: Use WithStaticBasePath instead. This will be removed in a future version. +// +//go:deprecated +func WithBasePath(basePath string) SSEOption { + return WithStaticBasePath(basePath) +} + // WithDynamicBasePath accepts a function for generating the base path. This is // useful for cases where the base path is not known at the time of SSE server // creation, such as when using a reverse proxy or when the server is mounted diff --git a/server/sse_test.go b/server/sse_test.go index 937dc274..9196c8fe 100644 --- a/server/sse_test.go +++ b/server/sse_test.go @@ -24,7 +24,7 @@ func TestSSEServer(t *testing.T) { mcpServer := NewMCPServer("test", "1.0.0") sseServer := NewSSEServer(mcpServer, WithBaseURL("http://localhost:8080"), - WithBasePath("/mcp"), + WithStaticBasePath("/mcp"), ) if sseServer == nil { @@ -499,7 +499,7 @@ func TestSSEServer(t *testing.T) { t.Run("works as http.Handler with custom basePath", func(t *testing.T) { mcpServer := NewMCPServer("test", "1.0.0") - sseServer := NewSSEServer(mcpServer, WithBasePath("/mcp")) + sseServer := NewSSEServer(mcpServer, WithStaticBasePath("/mcp")) ts := httptest.NewServer(sseServer) defer ts.Close() @@ -717,7 +717,7 @@ func TestSSEServer(t *testing.T) { useFullURLForMessageEndpoint := false srv := &http.Server{} rands := []SSEOption{ - WithBasePath(basePath), + WithStaticBasePath(basePath), WithBaseURL(baseURL), WithMessageEndpoint(messageEndpoint), WithUseFullURLForMessageEndpoint(useFullURLForMessageEndpoint), From e40e7a79aebe51278fac8ebe9cd5faeb051d2e0a Mon Sep 17 00:00:00 2001 From: cryo Date: Sun, 4 May 2025 23:38:00 +0800 Subject: [PATCH 07/10] fix(MCPServer): Session tool handler not used due to variable shadowing (#242) --- server/server.go | 2 +- server/session_test.go | 59 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 1 deletion(-) diff --git a/server/server.go b/server/server.go index 95831ebd..8aac05ca 100644 --- a/server/server.go +++ b/server/server.go @@ -856,7 +856,7 @@ func (s *MCPServer) handleToolCall( session := ClientSessionFromContext(ctx) if session != nil { - if sessionWithTools, ok := session.(SessionWithTools); ok { + if sessionWithTools, typeAssertOk := session.(SessionWithTools); typeAssertOk { if sessionTools := sessionWithTools.GetSessionTools(); sessionTools != nil { var sessionOk bool tool, sessionOk = sessionTools[request.Params.Name] diff --git a/server/session_test.go b/server/session_test.go index d1d0bc79..42def221 100644 --- a/server/session_test.go +++ b/server/session_test.go @@ -2,6 +2,7 @@ package server import ( "context" + "encoding/json" "errors" "sync" "testing" @@ -295,6 +296,64 @@ func TestMCPServer_AddSessionTool(t *testing.T) { assert.Contains(t, session.GetSessionTools(), "session-tool-helper") } +func TestMCPServer_CallSessionTool(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", WithToolCapabilities(true)) + + // Add global tool + server.AddTool(mcp.NewTool("test_tool"), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return mcp.NewToolResultText("global result"), nil + }) + + // Create a session + sessionChan := make(chan mcp.JSONRPCNotification, 10) + session := &sessionTestClientWithTools{ + sessionID: "session-1", + notificationChannel: sessionChan, + initialized: true, + } + + // Register the session + err := server.RegisterSession(context.Background(), session) + require.NoError(t, err) + + // Add session-specific tool with the same name to override the global tool + err = server.AddSessionTool( + session.SessionID(), + mcp.NewTool("test_tool"), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return mcp.NewToolResultText("session result"), nil + }, + ) + require.NoError(t, err) + + // Call the tool using session context + sessionCtx := server.WithContext(context.Background(), session) + toolRequest := map[string]interface{}{ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": map[string]interface{}{ + "name": "test_tool", + }, + } + requestBytes, err := json.Marshal(toolRequest) + if err != nil { + t.Fatalf("Failed to marshal tool request: %v", err) + } + + response := server.HandleMessage(sessionCtx, requestBytes) + resp, ok := response.(mcp.JSONRPCResponse) + assert.True(t, ok) + + callToolResult, ok := resp.Result.(mcp.CallToolResult) + assert.True(t, ok) + + // Since we specify a tool with the same name for current session, the expected text should be "session result" + if text := callToolResult.Content[0].(mcp.TextContent).Text; text != "session result" { + t.Errorf("Expected result 'session result', got %q", text) + } +} + func TestMCPServer_DeleteSessionTools(t *testing.T) { server := NewMCPServer("test-server", "1.0.0", WithToolCapabilities(true)) ctx := context.Background() From a999079f3650486677e7698963fb2505b7bdda87 Mon Sep 17 00:00:00 2001 From: Robert Jackson Date: Sun, 4 May 2025 11:38:40 -0400 Subject: [PATCH 08/10] test: build mockstdio_server with isolated cache to prevent flaky CI (#241) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit CI occasionally failed with the linker error: /link: cannot open file DO NOT USE - main build pseudo-cache built This is most likely because several parallel `go build` invocations shared the same `$GOCACHE`, letting one job evict the object file another job had promised the linker. The placeholder path then leaked through and the build aborted. This gives each compile its own cache by setting `GOCACHE=$(mktemp -d)` for the helper’s `go build` call. After these changes `go test ./... -race` passed 100/100 consecutive runs locally. --- client/stdio_test.go | 4 ++++ client/transport/stdio_test.go | 3 +++ 2 files changed, 7 insertions(+) diff --git a/client/stdio_test.go b/client/stdio_test.go index 8c9ff299..fe4e3b5a 100644 --- a/client/stdio_test.go +++ b/client/stdio_test.go @@ -24,9 +24,13 @@ func compileTestServer(outputPath string) error { outputPath, "../testdata/mockstdio_server.go", ) + tmpCache, _ := os.MkdirTemp("", "gocache") + cmd.Env = append(os.Environ(), "GOCACHE="+tmpCache) + if output, err := cmd.CombinedOutput(); err != nil { return fmt.Errorf("compilation failed: %v\nOutput: %s", err, output) } + // Verify the binary was actually created if _, err := os.Stat(outputPath); os.IsNotExist(err) { return fmt.Errorf("mock server binary not found at %s after compilation", outputPath) } diff --git a/client/transport/stdio_test.go b/client/transport/stdio_test.go index 53db7a0f..6d87cdbd 100644 --- a/client/transport/stdio_test.go +++ b/client/transport/stdio_test.go @@ -23,6 +23,9 @@ func compileTestServer(outputPath string) error { outputPath, "../../testdata/mockstdio_server.go", ) + tmpCache, _ := os.MkdirTemp("", "gocache") + cmd.Env = append(os.Environ(), "GOCACHE="+tmpCache) + if output, err := cmd.CombinedOutput(); err != nil { return fmt.Errorf("compilation failed: %v\nOutput: %s", err, output) } From f47e2bce1b69c409bf262f18f6ecba8039ee7fb3 Mon Sep 17 00:00:00 2001 From: Yashwanth <53632453+yash025@users.noreply.github.com> Date: Mon, 5 May 2025 14:42:48 +0530 Subject: [PATCH 09/10] fix: Use detached context for SSE message handling (#244) * fix: Use detached context for SSE message handling Prevents premature cancellation of message processing when HTTP request ends. * test for message processing when we return early to the client * rename variable --------- Co-authored-by: Yashwanth H L --- server/sse.go | 13 +++++-- server/sse_test.go | 84 +++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 94 insertions(+), 3 deletions(-) diff --git a/server/sse.go b/server/sse.go index 8467b02f..018657e6 100644 --- a/server/sse.go +++ b/server/sse.go @@ -465,10 +465,19 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) { return } + // Create a context that preserves all values from parent ctx but won't be canceled when the parent is canceled. + // this is required because the http ctx will be canceled when the client disconnects + detachedCtx := context.WithoutCancel(ctx) + // quick return request, send 202 Accepted with no body, then deal the message and sent response via SSE w.WriteHeader(http.StatusAccepted) - go func() { + // Create a new context for handling the message that will be canceled when the message handling is done + messageCtx, cancel := context.WithCancel(detachedCtx) + + go func(ctx context.Context) { + defer cancel() + // Use the context that will be canceled when session is done // Process message through MCPServer response := s.server.HandleMessage(ctx, rawMessage) // Only send response if there is one (not for notifications) @@ -493,7 +502,7 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) { log.Printf("Event queue full for session %s", sessionID) } } - }() + }(messageCtx) } // writeJSONRPCError writes a JSON-RPC error response with the given error details. diff --git a/server/sse_test.go b/server/sse_test.go index 9196c8fe..393a70cf 100644 --- a/server/sse_test.go +++ b/server/sse_test.go @@ -203,7 +203,6 @@ func TestSSEServer(t *testing.T) { strings.Split(strings.Split(endpointEvent, "data: ")[1], "\n")[0], ) - fmt.Printf("========> %v", respFromSee) var response map[string]interface{} if err := json.NewDecoder(strings.NewReader(respFromSee)).Decode(&response); err != nil { t.Errorf( @@ -1318,6 +1317,89 @@ func TestSSEServer(t *testing.T) { t.Errorf("Expected id to be null") } }) + + t.Run("Message processing continues after we return back result to client", func(t *testing.T) { + mcpServer := NewMCPServer("test", "1.0.0") + + processingCompleted := make(chan struct{}) + processingStarted := make(chan struct{}) + + mcpServer.AddTool(mcp.NewTool("slowMethod"), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + close(processingStarted) // signal for processing started + + select { + case <-ctx.Done(): // If this happens, the test will fail because processingCompleted won't be closed + return nil, fmt.Errorf("context was canceled") + case <-time.After(1 * time.Second): // Simulate processing time + // Successfully completed processing, now close the completed channel to signal completion + close(processingCompleted) + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: "success", + }, + }, + }, nil + } + }) + + testServer := NewTestServer(mcpServer) + defer testServer.Close() + + sseResp, err := http.Get(fmt.Sprintf("%s/sse", testServer.URL)) + require.NoError(t, err, "Failed to connect to SSE endpoint") + defer sseResp.Body.Close() + + endpointEvent, err := readSSEEvent(sseResp) + require.NoError(t, err, "Failed to read SSE response") + require.Contains(t, endpointEvent, "event: endpoint", "Expected endpoint event") + + messageURL := strings.TrimSpace( + strings.Split(strings.Split(endpointEvent, "data: ")[1], "\n")[0], + ) + + messageRequest := map[string]interface{}{ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": map[string]interface{}{ + "name": "slowMethod", + "parameters": map[string]interface{}{}, + }, + } + + requestBody, err := json.Marshal(messageRequest) + require.NoError(t, err, "Failed to marshal request") + + ctx, cancel := context.WithCancel(context.Background()) + req, err := http.NewRequestWithContext(ctx, "POST", messageURL, bytes.NewBuffer(requestBody)) + require.NoError(t, err, "Failed to create request") + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err, "Failed to send message") + defer resp.Body.Close() + + require.Equal(t, http.StatusAccepted, resp.StatusCode, "Expected status 202 Accepted") + + // Wait for processing to start + select { + case <-processingStarted: // Processing has started, now cancel the client context to simulate disconnection + case <-time.After(2 * time.Second): + t.Fatal("Timed out waiting for processing to start") + } + + cancel() // cancel the client context to simulate disconnection + + // wait for processing to complete, if the test passes, it means the processing continued despite client disconnection + select { + case <-processingCompleted: + case <-time.After(2 * time.Second): + t.Fatal("Processing did not complete after client disconnection") + } + }) } func readSSEEvent(sseResp *http.Response) (string, error) { From 9d6b793133b9b56a25152083a0d5fcfd92d59882 Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Tue, 6 May 2025 19:36:54 +0300 Subject: [PATCH 10/10] Format --- client/stdio_test.go | 4 ++-- client/transport/stdio_test.go | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/client/stdio_test.go b/client/stdio_test.go index fe4e3b5a..48514d91 100644 --- a/client/stdio_test.go +++ b/client/stdio_test.go @@ -45,13 +45,13 @@ func TestStdioMCPClient(t *testing.T) { } tempFile.Close() mockServerPath := tempFile.Name() - + // Add .exe suffix on Windows if runtime.GOOS == "windows" { os.Remove(mockServerPath) // Remove the empty file first mockServerPath += ".exe" } - + if compileErr := compileTestServer(mockServerPath); compileErr != nil { t.Fatalf("Failed to compile mock server: %v", compileErr) } diff --git a/client/transport/stdio_test.go b/client/transport/stdio_test.go index 6d87cdbd..cb25bf79 100644 --- a/client/transport/stdio_test.go +++ b/client/transport/stdio_test.go @@ -44,13 +44,13 @@ func TestStdio(t *testing.T) { } tempFile.Close() mockServerPath := tempFile.Name() - + // Add .exe suffix on Windows if runtime.GOOS == "windows" { os.Remove(mockServerPath) // Remove the empty file first mockServerPath += ".exe" } - + if compileErr := compileTestServer(mockServerPath); compileErr != nil { t.Fatalf("Failed to compile mock server: %v", compileErr) } @@ -329,13 +329,13 @@ func TestStdioErrors(t *testing.T) { } tempFile.Close() mockServerPath := tempFile.Name() - + // Add .exe suffix on Windows if runtime.GOOS == "windows" { os.Remove(mockServerPath) // Remove the empty file first mockServerPath += ".exe" } - + if compileErr := compileTestServer(mockServerPath); compileErr != nil { t.Fatalf("Failed to compile mock server: %v", compileErr) } @@ -368,13 +368,13 @@ func TestStdioErrors(t *testing.T) { } tempFile.Close() mockServerPath := tempFile.Name() - + // Add .exe suffix on Windows if runtime.GOOS == "windows" { os.Remove(mockServerPath) // Remove the empty file first mockServerPath += ".exe" } - + if compileErr := compileTestServer(mockServerPath); compileErr != nil { t.Fatalf("Failed to compile mock server: %v", compileErr) }