From e7d2547fdc103cc64125097694e68a158beaeccb Mon Sep 17 00:00:00 2001 From: "David J. Hamilton" Date: Thu, 15 May 2025 10:18:39 -0700 Subject: [PATCH 1/5] feat(tools): implicitly register capabilities (#292) When users add tools via AddTool or AddSessionTool, implicitly set the tools capability. If the user has not already called WithToolCapabilities, then default listChanged to true, but honor any existing value. This mimics the behavior of the official typescript sdk, which registers `tools.listChanged: true` when the user adds a tool to the MCP server. --- server/server.go | 15 ++++++++--- server/server_test.go | 43 ++++++++++++++++++++++++++++++++ server/session.go | 2 ++ server/session_test.go | 56 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 113 insertions(+), 3 deletions(-) diff --git a/server/server.go b/server/server.go index 33ab4c38..b31b4865 100644 --- a/server/server.go +++ b/server/server.go @@ -411,20 +411,29 @@ func (s *MCPServer) AddTool(tool mcp.Tool, handler ToolHandlerFunc) { s.AddTools(ServerTool{Tool: tool, Handler: handler}) } -// AddTools registers multiple tools at once -func (s *MCPServer) AddTools(tools ...ServerTool) { +// Register tool capabilities due to a tool being added. Default to +// listChanged: true, but don't change the value if we've already explicitly +// registered tools.listChanged false. +func (s *MCPServer) implicitlyRegisterToolCapabilities() { s.capabilitiesMu.RLock() if s.capabilities.tools == nil { s.capabilitiesMu.RUnlock() s.capabilitiesMu.Lock() if s.capabilities.tools == nil { - s.capabilities.tools = &toolCapabilities{} + s.capabilities.tools = &toolCapabilities{ + listChanged: true, + } } s.capabilitiesMu.Unlock() } else { s.capabilitiesMu.RUnlock() } +} + +// AddTools registers multiple tools at once +func (s *MCPServer) AddTools(tools ...ServerTool) { + s.implicitlyRegisterToolCapabilities() s.toolsMu.Lock() for _, entry := range tools { diff --git a/server/server_test.go b/server/server_test.go index 6a97e5d3..4615b0fb 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1624,3 +1624,46 @@ func BenchmarkMCPServer_PaginationForReflect(b *testing.B) { _, _, _ = listByPaginationForReflect[mcp.Tool](ctx, server, "dG9vbDY1NA==", list) } } + +func TestMCPServer_ToolCapabilitiesBehavior(t *testing.T) { + tests := []struct { + name string + serverOptions []ServerOption + validateServer func(t *testing.T, s *MCPServer) + }{ + { + name: "no tool capabilities provided", + serverOptions: []ServerOption{ + // No WithToolCapabilities + }, + validateServer: func(t *testing.T, s *MCPServer) { + s.capabilitiesMu.RLock() + defer s.capabilitiesMu.RUnlock() + + require.NotNil(t, s.capabilities.tools, "tools capability should be initialized") + assert.True(t, s.capabilities.tools.listChanged, "listChanged should be true when no capabilities were provided") + }, + }, + { + name: "tools.listChanged set to false", + serverOptions: []ServerOption{ + WithToolCapabilities(false), + }, + validateServer: func(t *testing.T, s *MCPServer) { + s.capabilitiesMu.RLock() + defer s.capabilitiesMu.RUnlock() + + require.NotNil(t, s.capabilities.tools, "tools capability should be initialized") + assert.False(t, s.capabilities.tools.listChanged, "listChanged should remain false when explicitly set to false") + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", tt.serverOptions...) + server.AddTool(mcp.NewTool("test-tool"), nil) + tt.validateServer(t, server) + }) + } +} diff --git a/server/session.go b/server/session.go index ab13e057..20bd44ad 100644 --- a/server/session.go +++ b/server/session.go @@ -224,6 +224,8 @@ func (s *MCPServer) AddSessionTools(sessionID string, tools ...ServerTool) error return ErrSessionDoesNotSupportTools } + s.implicitlyRegisterToolCapabilities() + // Get existing tools (this should return a thread-safe copy) sessionTools := session.GetSessionTools() diff --git a/server/session_test.go b/server/session_test.go index 00152c03..9cf646f5 100644 --- a/server/session_test.go +++ b/server/session_test.go @@ -802,3 +802,59 @@ func TestMCPServer_NotificationChannelBlocked(t *testing.T) { assert.Equal(t, "blocked-session", localErrorSessionID, "Session ID should be captured in the error hook") assert.Equal(t, "broadcast-message", localErrorMethod, "Method should be captured in the error hook") } + +func TestMCPServer_SessionToolCapabilitiesBehavior(t *testing.T) { + tests := []struct { + name string + serverOptions []ServerOption + validateServer func(t *testing.T, s *MCPServer, session *sessionTestClientWithTools) + }{ + { + name: "no tool capabilities provided", + serverOptions: []ServerOption{ + // No WithToolCapabilities + }, + validateServer: func(t *testing.T, s *MCPServer, session *sessionTestClientWithTools) { + s.capabilitiesMu.RLock() + defer s.capabilitiesMu.RUnlock() + + require.NotNil(t, s.capabilities.tools, "tools capability should be initialized") + assert.True(t, s.capabilities.tools.listChanged, "listChanged should be true when no capabilities were provided") + }, + }, + { + name: "tools.listChanged set to false", + serverOptions: []ServerOption{ + WithToolCapabilities(false), + }, + validateServer: func(t *testing.T, s *MCPServer, session *sessionTestClientWithTools) { + s.capabilitiesMu.RLock() + defer s.capabilitiesMu.RUnlock() + + require.NotNil(t, s.capabilities.tools, "tools capability should be initialized") + assert.False(t, s.capabilities.tools.listChanged, "listChanged should remain false when explicitly set to false") + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", tt.serverOptions...) + + // Create and register a session + session := &sessionTestClientWithTools{ + sessionID: "test-session", + notificationChannel: make(chan mcp.JSONRPCNotification, 10), + initialized: true, + } + err := server.RegisterSession(context.Background(), session) + require.NoError(t, err) + + // Add a session tool and verify listChanged remains false + err = server.AddSessionTool(session.SessionID(), mcp.NewTool("test-tool"), nil) + require.NoError(t, err) + + tt.validateServer(t, server, session) + }) + } +} From eb835b903dbf9e9f6c594b2344a4e80d98cd0712 Mon Sep 17 00:00:00 2001 From: "David J. Hamilton" Date: Thu, 15 May 2025 12:45:28 -0700 Subject: [PATCH 2/5] fix: Gate notifications on capabilities (#290) Servers may report their [tools.listChanged][] capability as false, in which case they indicate that they will not send notifications when available tools change. Honor the spec by not sending notifications/tools/list_changed notifications when capabilities.tools.listChanged is false. [tools.listChanged]: https://modelcontextprotocol.io/specification/2025-03-26/server/tools#capabilities --- server/session.go | 16 +++++++++--- server/session_test.go | 59 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 71 insertions(+), 4 deletions(-) diff --git a/server/session.go b/server/session.go index 20bd44ad..3a4206a7 100644 --- a/server/session.go +++ b/server/session.go @@ -247,8 +247,12 @@ func (s *MCPServer) AddSessionTools(sessionID string, tools ...ServerTool) error // It only makes sense to send tool notifications to initialized sessions -- // if we're not initialized yet the client can't possibly have sent their - // initial tools/list message - if session.Initialized() { + // initial tools/list message. + // + // For initialized sessions, honor tools.listChanged, which is specifically + // about whether notifications will be sent or not. + // see + if session.Initialized() && s.capabilities.tools != nil && s.capabilities.tools.listChanged { // Send notification only to this session if err := s.SendNotificationToSpecificClient(sessionID, "notifications/tools/list_changed", nil); err != nil { // Log the error but don't fail the operation @@ -305,8 +309,12 @@ func (s *MCPServer) DeleteSessionTools(sessionID string, names ...string) error // It only makes sense to send tool notifications to initialized sessions -- // if we're not initialized yet the client can't possibly have sent their - // initial tools/list message - if session.Initialized() { + // initial tools/list message. + // + // For initialized sessions, honor tools.listChanged, which is specifically + // about whether notifications will be sent or not. + // see + if session.Initialized() && s.capabilities.tools != nil && s.capabilities.tools.listChanged { // Send notification only to this session if err := s.SendNotificationToSpecificClient(sessionID, "notifications/tools/list_changed", nil); err != nil { // Log the error but don't fail the operation diff --git a/server/session_test.go b/server/session_test.go index 9cf646f5..54a78170 100644 --- a/server/session_test.go +++ b/server/session_test.go @@ -858,3 +858,62 @@ func TestMCPServer_SessionToolCapabilitiesBehavior(t *testing.T) { }) } } + +func TestMCPServer_ToolNotificationsDisabled(t *testing.T) { + // This test verifies that when tool capabilities are disabled, we still + // add/delete tools correctly but don't send notifications about it. + // + // This is important because: + // 1. Tools should still work even if notifications are disabled + // 2. We shouldn't waste resources sending notifications that won't be used + // 3. The client might not be ready to handle tool notifications yet + + // Create a server WITHOUT tool capabilities + server := NewMCPServer("test-server", "1.0.0", WithToolCapabilities(false)) + ctx := context.Background() + + // Create an initialized session + sessionChan := make(chan mcp.JSONRPCNotification, 1) + session := &sessionTestClientWithTools{ + sessionID: "session-1", + notificationChannel: sessionChan, + initialized: true, + } + + // Register the session + err := server.RegisterSession(ctx, session) + require.NoError(t, err) + + // Add a tool + err = server.AddSessionTools(session.SessionID(), + ServerTool{Tool: mcp.NewTool("test-tool")}, + ) + require.NoError(t, err) + + // Verify no notification was sent + select { + case <-sessionChan: + t.Error("Expected no notification to be sent when capabilities.tools.listChanged is false") + default: + // This is the expected case - no notification should be sent + } + + // Verify tool was added to session + assert.Len(t, session.GetSessionTools(), 1) + assert.Contains(t, session.GetSessionTools(), "test-tool") + + // Delete the tool + err = server.DeleteSessionTools(session.SessionID(), "test-tool") + require.NoError(t, err) + + // Verify no notification was sent + select { + case <-sessionChan: + t.Error("Expected no notification to be sent when capabilities.tools.listChanged is false") + default: + // This is the expected case - no notification should be sent + } + + // Verify tool was deleted from session + assert.Len(t, session.GetSessionTools(), 0) +} From 91ddba5f0b9cef6fd6b89cae1009b0ab55eeb1c0 Mon Sep 17 00:00:00 2001 From: "Anuraag (Rag) Agrawal" Date: Sat, 17 May 2025 02:32:14 +0900 Subject: [PATCH 3/5] feat(protocol): allow additional fields in meta (#293) --- mcp/tools.go | 10 +------ mcp/types.go | 48 ++++++++++++++++++++++++++------ mcp/types_test.go | 70 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 110 insertions(+), 18 deletions(-) create mode 100644 mcp/types_test.go diff --git a/mcp/tools.go b/mcp/tools.go index 392b837e..79d66e3f 100644 --- a/mcp/tools.go +++ b/mcp/tools.go @@ -46,15 +46,7 @@ type CallToolRequest struct { Params struct { Name string `json:"name"` Arguments map[string]any `json:"arguments,omitempty"` - Meta *struct { - // If specified, the caller is requesting out-of-band progress - // notifications for this request (as represented by - // notifications/progress). The value of this parameter is an - // opaque token that will be attached to any subsequent - // notifications. The receiver is not obligated to provide these - // notifications. - ProgressToken ProgressToken `json:"progressToken,omitempty"` - } `json:"_meta,omitempty"` + Meta *Meta `json:"_meta,omitempty"` } `json:"params"` } diff --git a/mcp/types.go b/mcp/types.go index 17a00618..e7fdb6f0 100644 --- a/mcp/types.go +++ b/mcp/types.go @@ -4,6 +4,7 @@ package mcp import ( "encoding/json" + "maps" "github.com/yosida95/uritemplate/v3" ) @@ -100,18 +101,47 @@ type ProgressToken any // Cursor is an opaque token used to represent a cursor for pagination. type Cursor string +// Meta is metadata attached to a request's parameters. This can include fields +// formally defined by the protocol or other arbitrary data. +type Meta struct { + // If specified, the caller is requesting out-of-band progress + // notifications for this request (as represented by + // notifications/progress). The value of this parameter is an + // opaque token that will be attached to any subsequent + // notifications. The receiver is not obligated to provide these + // notifications. + ProgressToken ProgressToken + + // AdditionalFields are any fields present in the Meta that are not + // otherwise defined in the protocol. + AdditionalFields map[string]any +} + +func (m *Meta) MarshalJSON() ([]byte, error) { + raw := make(map[string]any) + if m.ProgressToken != nil { + raw["progressToken"] = m.ProgressToken + } + maps.Copy(raw, m.AdditionalFields) + + return json.Marshal(raw) +} + +func (m *Meta) UnmarshalJSON(data []byte) error { + raw := make(map[string]any) + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + m.ProgressToken = raw["progressToken"] + delete(raw, "progressToken") + m.AdditionalFields = raw + return nil +} + type Request struct { Method string `json:"method"` Params struct { - Meta *struct { - // If specified, the caller is requesting out-of-band progress - // notifications for this request (as represented by - // notifications/progress). The value of this parameter is an - // opaque token that will be attached to any subsequent - // notifications. The receiver is not obligated to provide these - // notifications. - ProgressToken ProgressToken `json:"progressToken,omitempty"` - } `json:"_meta,omitempty"` + Meta *Meta `json:"_meta,omitempty"` } `json:"params,omitempty"` } diff --git a/mcp/types_test.go b/mcp/types_test.go new file mode 100644 index 00000000..526e1ac1 --- /dev/null +++ b/mcp/types_test.go @@ -0,0 +1,70 @@ +package mcp + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMetaMarshalling(t *testing.T) { + tests := []struct { + name string + json string + meta *Meta + expMeta *Meta + }{ + { + name: "empty", + json: "{}", + meta: &Meta{}, + expMeta: &Meta{AdditionalFields: map[string]any{}}, + }, + { + name: "empty additional fields", + json: "{}", + meta: &Meta{AdditionalFields: map[string]any{}}, + expMeta: &Meta{AdditionalFields: map[string]any{}}, + }, + { + name: "string token only", + json: `{"progressToken":"123"}`, + meta: &Meta{ProgressToken: "123"}, + expMeta: &Meta{ProgressToken: "123", AdditionalFields: map[string]any{}}, + }, + { + name: "string token only, empty additional fields", + json: `{"progressToken":"123"}`, + meta: &Meta{ProgressToken: "123", AdditionalFields: map[string]any{}}, + expMeta: &Meta{ProgressToken: "123", AdditionalFields: map[string]any{}}, + }, + { + name: "additional fields only", + json: `{"a":2,"b":"1"}`, + meta: &Meta{AdditionalFields: map[string]any{"a": 2, "b": "1"}}, + // For untyped map, numbers are always float64 + expMeta: &Meta{AdditionalFields: map[string]any{"a": float64(2), "b": "1"}}, + }, + { + name: "progress token and additional fields", + json: `{"a":2,"b":"1","progressToken":"123"}`, + meta: &Meta{ProgressToken: "123", AdditionalFields: map[string]any{"a": 2, "b": "1"}}, + // For untyped map, numbers are always float64 + expMeta: &Meta{ProgressToken: "123", AdditionalFields: map[string]any{"a": float64(2), "b": "1"}}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + data, err := json.Marshal(tc.meta) + require.NoError(t, err) + assert.Equal(t, tc.json, string(data)) + + meta := &Meta{} + err = json.Unmarshal([]byte(tc.json), meta) + require.NoError(t, err) + assert.Equal(t, tc.expMeta, meta) + }) + } +} From 09c23b5fec768432e3362bea05e69f57a3bc7c92 Mon Sep 17 00:00:00 2001 From: Navendu Pottekkat Date: Fri, 16 May 2025 23:19:42 +0530 Subject: [PATCH 4/5] fix: type mismatch for request/response ID (#291) * fix: type mismatch for request/response ID * fix: make suggested changes --- client/client.go | 2 +- client/transport/interface.go | 10 +- client/transport/sse.go | 24 +++-- client/transport/sse_test.go | 51 +++++++---- client/transport/stdio.go | 24 +++-- client/transport/stdio_test.go | 112 +++++++++++++++++++---- client/transport/streamable_http.go | 4 +- client/transport/streamable_http_test.go | 52 +++++++---- mcp/types.go | 72 ++++++++++++++- server/server.go | 6 +- server/server_test.go | 10 +- server/sse.go | 2 +- server/sse_test.go | 11 ++- testdata/mockstdio_server.go | 10 +- 14 files changed, 297 insertions(+), 93 deletions(-) diff --git a/client/client.go b/client/client.go index 7689633c..dd0e31a0 100644 --- a/client/client.go +++ b/client/client.go @@ -104,7 +104,7 @@ func (c *Client) sendRequest( request := transport.JSONRPCRequest{ JSONRPC: mcp.JSONRPC_VERSION, - ID: id, + ID: mcp.NewRequestId(id), Method: method, Params: params, } diff --git a/client/transport/interface.go b/client/transport/interface.go index 8ac75d74..2fba4abf 100644 --- a/client/transport/interface.go +++ b/client/transport/interface.go @@ -27,15 +27,15 @@ type Interface interface { } type JSONRPCRequest struct { - JSONRPC string `json:"jsonrpc"` - ID int64 `json:"id"` - Method string `json:"method"` - Params any `json:"params,omitempty"` + JSONRPC string `json:"jsonrpc"` + ID mcp.RequestId `json:"id"` + Method string `json:"method"` + Params any `json:"params,omitempty"` } type JSONRPCResponse struct { JSONRPC string `json:"jsonrpc"` - ID *int64 `json:"id"` + ID mcp.RequestId `json:"id"` Result json.RawMessage `json:"result"` Error *struct { Code int `json:"code"` diff --git a/client/transport/sse.go b/client/transport/sse.go index eda9446e..24c5ce35 100644 --- a/client/transport/sse.go +++ b/client/transport/sse.go @@ -25,7 +25,7 @@ type SSE struct { baseURL *url.URL endpoint *url.URL httpClient *http.Client - responses map[int64]chan *JSONRPCResponse + responses map[string]chan *JSONRPCResponse mu sync.RWMutex onNotification func(mcp.JSONRPCNotification) notifyMu sync.RWMutex @@ -62,7 +62,7 @@ func NewSSE(baseURL string, options ...ClientOption) (*SSE, error) { smc := &SSE{ baseURL: parsedURL, httpClient: &http.Client{}, - responses: make(map[int64]chan *JSONRPCResponse), + responses: make(map[string]chan *JSONRPCResponse), endpointChan: make(chan struct{}), headers: make(map[string]string), } @@ -200,7 +200,7 @@ func (c *SSE) handleSSEEvent(event, data string) { } // Handle notification - if baseMessage.ID == nil { + if baseMessage.ID.IsNil() { var notification mcp.JSONRPCNotification if err := json.Unmarshal([]byte(data), ¬ification); err != nil { return @@ -213,14 +213,17 @@ func (c *SSE) handleSSEEvent(event, data string) { return } + // Create string key for map lookup + idKey := baseMessage.ID.String() + c.mu.RLock() - ch, ok := c.responses[*baseMessage.ID] + ch, exists := c.responses[idKey] c.mu.RUnlock() - if ok { + if exists { ch <- &baseMessage c.mu.Lock() - delete(c.responses, *baseMessage.ID) + delete(c.responses, idKey) c.mu.Unlock() } } @@ -267,14 +270,17 @@ func (c *SSE) SendRequest( req.Header.Set(k, v) } + // Create string key for map lookup + idKey := request.ID.String() + // Register response channel responseChan := make(chan *JSONRPCResponse, 1) c.mu.Lock() - c.responses[request.ID] = responseChan + c.responses[idKey] = responseChan c.mu.Unlock() deleteResponseChan := func() { c.mu.Lock() - delete(c.responses, request.ID) + delete(c.responses, idKey) c.mu.Unlock() } @@ -327,7 +333,7 @@ func (c *SSE) Close() error { for _, ch := range c.responses { close(ch) } - c.responses = make(map[int64]chan *JSONRPCResponse) + c.responses = make(map[string]chan *JSONRPCResponse) c.mu.Unlock() return nil diff --git a/client/transport/sse_test.go b/client/transport/sse_test.go index 230157d2..82074b11 100644 --- a/client/transport/sse_test.go +++ b/client/transport/sse_test.go @@ -160,7 +160,7 @@ func TestSSE(t *testing.T) { request := JSONRPCRequest{ JSONRPC: "2.0", - ID: 1, + ID: mcp.NewRequestId(int64(1)), Method: "debug/echo", Params: params, } @@ -174,7 +174,7 @@ func TestSSE(t *testing.T) { // Parse the result to verify echo var result struct { JSONRPC string `json:"jsonrpc"` - ID int64 `json:"id"` + ID mcp.RequestId `json:"id"` Method string `json:"method"` Params map[string]any `json:"params"` } @@ -187,8 +187,11 @@ func TestSSE(t *testing.T) { if result.JSONRPC != "2.0" { t.Errorf("Expected JSONRPC value '2.0', got '%s'", result.JSONRPC) } - if result.ID != 1 { - t.Errorf("Expected ID 1, got %d", result.ID) + idValue, ok := result.ID.Value().(int64) + if !ok { + t.Errorf("Expected ID to be int64, got %T", result.ID.Value()) + } else if idValue != 1 { + t.Errorf("Expected ID 1, got %d", idValue) } if result.Method != "debug/echo" { t.Errorf("Expected method 'debug/echo', got '%s'", result.Method) @@ -211,7 +214,7 @@ func TestSSE(t *testing.T) { // Prepare a request request := JSONRPCRequest{ JSONRPC: "2.0", - ID: 3, + ID: mcp.NewRequestId(int64(3)), Method: "debug/echo", } @@ -292,7 +295,7 @@ func TestSSE(t *testing.T) { // Each request has a unique ID and payload request := JSONRPCRequest{ JSONRPC: "2.0", - ID: int64(100 + idx), + ID: mcp.NewRequestId(int64(100 + idx)), Method: "debug/echo", Params: map[string]any{ "requestIndex": idx, @@ -317,15 +320,25 @@ func TestSSE(t *testing.T) { continue } - if responses[i] == nil || responses[i].ID == nil || *responses[i].ID != int64(100+i) { - t.Errorf("Request %d: Expected ID %d, got %v", i, 100+i, responses[i]) + if responses[i] == nil { + t.Errorf("Request %d: Response is nil", i) + continue + } + + expectedId := int64(100 + i) + idValue, ok := responses[i].ID.Value().(int64) + if !ok { + t.Errorf("Request %d: Expected ID to be int64, got %T", i, responses[i].ID.Value()) + continue + } else if idValue != expectedId { + t.Errorf("Request %d: Expected ID %d, got %d", i, expectedId, idValue) continue } // Parse the result to verify echo var result struct { JSONRPC string `json:"jsonrpc"` - ID int64 `json:"id"` + ID mcp.RequestId `json:"id"` Method string `json:"method"` Params map[string]any `json:"params"` } @@ -336,8 +349,11 @@ func TestSSE(t *testing.T) { } // Verify data matches what was sent - if result.ID != int64(100+i) { - t.Errorf("Request %d: Expected echoed ID %d, got %d", i, 100+i, result.ID) + idValue, ok = result.ID.Value().(int64) + if !ok { + t.Errorf("Request %d: Expected ID to be int64, got %T", i, result.ID.Value()) + } else if idValue != int64(100+i) { + t.Errorf("Request %d: Expected echoed ID %d, got %d", i, 100+i, idValue) } if result.Method != "debug/echo" { @@ -356,7 +372,7 @@ func TestSSE(t *testing.T) { // Prepare a request request := JSONRPCRequest{ JSONRPC: "2.0", - ID: 100, + ID: mcp.NewRequestId(int64(100)), Method: "debug/echo_error_string", } @@ -378,8 +394,11 @@ func TestSSE(t *testing.T) { if responseError.Method != "debug/echo_error_string" { t.Errorf("Expected method 'debug/echo_error_string', got '%s'", responseError.Method) } - if responseError.ID != 100 { - t.Errorf("Expected ID 100, got %d", responseError.ID) + idValue, ok := responseError.ID.Value().(int64) + if !ok { + t.Errorf("Expected ID to be int64, got %T", responseError.ID.Value()) + } else if idValue != 100 { + t.Errorf("Expected ID 100, got %d", idValue) } if responseError.JSONRPC != "2.0" { t.Errorf("Expected JSONRPC '2.0', got '%s'", responseError.JSONRPC) @@ -453,7 +472,7 @@ func TestSSEErrors(t *testing.T) { // Prepare a request request := JSONRPCRequest{ JSONRPC: "2.0", - ID: 99, + ID: mcp.NewRequestId(int64(99)), Method: "ping", } @@ -492,7 +511,7 @@ func TestSSEErrors(t *testing.T) { // Try to send a request after close request := JSONRPCRequest{ JSONRPC: "2.0", - ID: 1, + ID: mcp.NewRequestId(int64(1)), Method: "ping", } diff --git a/client/transport/stdio.go b/client/transport/stdio.go index 3d9d832a..c300c405 100644 --- a/client/transport/stdio.go +++ b/client/transport/stdio.go @@ -26,7 +26,7 @@ type Stdio struct { stdin io.WriteCloser stdout *bufio.Reader stderr io.ReadCloser - responses map[int64]chan *JSONRPCResponse + responses map[string]chan *JSONRPCResponse mu sync.RWMutex done chan struct{} onNotification func(mcp.JSONRPCNotification) @@ -42,7 +42,7 @@ func NewIO(input io.Reader, output io.WriteCloser, logging io.ReadCloser) *Stdio stdout: bufio.NewReader(input), stderr: logging, - responses: make(map[int64]chan *JSONRPCResponse), + responses: make(map[string]chan *JSONRPCResponse), done: make(chan struct{}), } } @@ -61,7 +61,7 @@ func NewStdio( args: args, env: env, - responses: make(map[int64]chan *JSONRPCResponse), + responses: make(map[string]chan *JSONRPCResponse), done: make(chan struct{}), } @@ -181,7 +181,7 @@ func (c *Stdio) readResponses() { } // Handle notification - if baseMessage.ID == nil { + if baseMessage.ID.IsNil() { var notification mcp.JSONRPCNotification if err := json.Unmarshal([]byte(line), ¬ification); err != nil { continue @@ -194,14 +194,17 @@ func (c *Stdio) readResponses() { continue } + // Create string key for map lookup + idKey := baseMessage.ID.String() + c.mu.RLock() - ch, ok := c.responses[*baseMessage.ID] + ch, exists := c.responses[idKey] c.mu.RUnlock() - if ok { + if exists { ch <- &baseMessage c.mu.Lock() - delete(c.responses, *baseMessage.ID) + delete(c.responses, idKey) c.mu.Unlock() } } @@ -227,14 +230,17 @@ func (c *Stdio) SendRequest( } requestBytes = append(requestBytes, '\n') + // Create string key for map lookup + idKey := request.ID.String() + // Register response channel responseChan := make(chan *JSONRPCResponse, 1) c.mu.Lock() - c.responses[request.ID] = responseChan + c.responses[idKey] = responseChan c.mu.Unlock() deleteResponseChan := func() { c.mu.Lock() - delete(c.responses, request.ID) + delete(c.responses, idKey) c.mu.Unlock() } diff --git a/client/transport/stdio_test.go b/client/transport/stdio_test.go index 155859e1..3eea5b23 100644 --- a/client/transport/stdio_test.go +++ b/client/transport/stdio_test.go @@ -70,7 +70,7 @@ func TestStdio(t *testing.T) { defer stdio.Close() t.Run("SendRequest", func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5000000000*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() params := map[string]any{ @@ -80,7 +80,7 @@ func TestStdio(t *testing.T) { request := JSONRPCRequest{ JSONRPC: "2.0", - ID: 1, + ID: mcp.NewRequestId(int64(1)), Method: "debug/echo", Params: params, } @@ -94,7 +94,7 @@ func TestStdio(t *testing.T) { // Parse the result to verify echo var result struct { JSONRPC string `json:"jsonrpc"` - ID int64 `json:"id"` + ID mcp.RequestId `json:"id"` Method string `json:"method"` Params map[string]any `json:"params"` } @@ -107,8 +107,11 @@ func TestStdio(t *testing.T) { if result.JSONRPC != "2.0" { t.Errorf("Expected JSONRPC value '2.0', got '%s'", result.JSONRPC) } - if result.ID != 1 { - t.Errorf("Expected ID 1, got %d", result.ID) + idValue, ok := result.ID.Value().(int64) + if !ok { + t.Errorf("Expected ID to be int64, got %T", result.ID.Value()) + } else if idValue != 1 { + t.Errorf("Expected ID 1, got %d", idValue) } if result.Method != "debug/echo" { t.Errorf("Expected method 'debug/echo', got '%s'", result.Method) @@ -131,7 +134,7 @@ func TestStdio(t *testing.T) { // Prepare a request request := JSONRPCRequest{ JSONRPC: "2.0", - ID: 3, + ID: mcp.NewRequestId(int64(3)), Method: "debug/echo", } @@ -211,7 +214,7 @@ func TestStdio(t *testing.T) { // Each request has a unique ID and payload request := JSONRPCRequest{ JSONRPC: "2.0", - ID: int64(100 + idx), + ID: mcp.NewRequestId(int64(100 + idx)), Method: "debug/echo", Params: map[string]any{ "requestIndex": idx, @@ -236,15 +239,25 @@ func TestStdio(t *testing.T) { continue } - if responses[i] == nil || responses[i].ID == nil || *responses[i].ID != int64(100+i) { - t.Errorf("Request %d: Expected ID %d, got %v", i, 100+i, responses[i]) + if responses[i] == nil { + t.Errorf("Request %d: Response is nil", i) + continue + } + + expectedId := int64(100 + i) + idValue, ok := responses[i].ID.Value().(int64) + if !ok { + t.Errorf("Request %d: Expected ID to be int64, got %T", i, responses[i].ID.Value()) + continue + } else if idValue != expectedId { + t.Errorf("Request %d: Expected ID %d, got %d", i, expectedId, idValue) continue } // Parse the result to verify echo var result struct { JSONRPC string `json:"jsonrpc"` - ID int64 `json:"id"` + ID mcp.RequestId `json:"id"` Method string `json:"method"` Params map[string]any `json:"params"` } @@ -255,8 +268,11 @@ func TestStdio(t *testing.T) { } // Verify data matches what was sent - if result.ID != int64(100+i) { - t.Errorf("Request %d: Expected echoed ID %d, got %d", i, 100+i, result.ID) + idValue, ok = result.ID.Value().(int64) + if !ok { + t.Errorf("Request %d: Expected ID to be int64, got %T", i, result.ID.Value()) + } else if idValue != int64(100+i) { + t.Errorf("Request %d: Expected echoed ID %d, got %d", i, 100+i, idValue) } if result.Method != "debug/echo" { @@ -271,11 +287,10 @@ func TestStdio(t *testing.T) { }) t.Run("ResponseError", func(t *testing.T) { - // Prepare a request request := JSONRPCRequest{ JSONRPC: "2.0", - ID: 100, + ID: mcp.NewRequestId(int64(100)), Method: "debug/echo_error_string", } @@ -297,14 +312,75 @@ func TestStdio(t *testing.T) { if responseError.Method != "debug/echo_error_string" { t.Errorf("Expected method 'debug/echo_error_string', got '%s'", responseError.Method) } - if responseError.ID != 100 { - t.Errorf("Expected ID 100, got %d", responseError.ID) + idValue, ok := responseError.ID.Value().(int64) + if !ok { + t.Errorf("Expected ID to be int64, got %T", responseError.ID.Value()) + } else if idValue != 100 { + t.Errorf("Expected ID 100, got %d", idValue) } if responseError.JSONRPC != "2.0" { t.Errorf("Expected JSONRPC '2.0', got '%s'", responseError.JSONRPC) } }) + t.Run("SendRequestWithStringID", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + params := map[string]any{ + "string": "string id test", + "array": []any{4, 5, 6}, + } + + // Use a string ID instead of an integer + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId("request-123"), + Method: "debug/echo", + Params: params, + } + + response, err := stdio.SendRequest(ctx, request) + if err != nil { + t.Fatalf("SendRequest failed: %v", err) + } + + var result struct { + JSONRPC string `json:"jsonrpc"` + ID mcp.RequestId `json:"id"` + Method string `json:"method"` + Params map[string]any `json:"params"` + } + + if err := json.Unmarshal(response.Result, &result); err != nil { + t.Fatalf("Failed to unmarshal result: %v", err) + } + + if result.JSONRPC != "2.0" { + t.Errorf("Expected JSONRPC value '2.0', got '%s'", result.JSONRPC) + } + + // Verify the ID is a string and has the expected value + idValue, ok := result.ID.Value().(string) + if !ok { + t.Errorf("Expected ID to be string, got %T", result.ID.Value()) + } else if idValue != "request-123" { + t.Errorf("Expected ID 'request-123', got '%s'", idValue) + } + + if result.Method != "debug/echo" { + t.Errorf("Expected method 'debug/echo', got '%s'", result.Method) + } + + if str, ok := result.Params["string"].(string); !ok || str != "string id test" { + t.Errorf("Expected string 'string id test', got %v", result.Params["string"]) + } + + if arr, ok := result.Params["array"].([]any); !ok || len(arr) != 3 { + t.Errorf("Expected array with 3 items, got %v", result.Params["array"]) + } + }) + } func TestStdioErrors(t *testing.T) { @@ -346,7 +422,7 @@ func TestStdioErrors(t *testing.T) { // Prepare a request request := JSONRPCRequest{ JSONRPC: "2.0", - ID: 99, + ID: mcp.NewRequestId(int64(99)), Method: "ping", } @@ -398,7 +474,7 @@ func TestStdioErrors(t *testing.T) { // Try to send a request after close request := JSONRPCRequest{ JSONRPC: "2.0", - ID: 1, + ID: mcp.NewRequestId(int64(1)), Method: "ping", } diff --git a/client/transport/streamable_http.go b/client/transport/streamable_http.go index 98719bd0..34677031 100644 --- a/client/transport/streamable_http.go +++ b/client/transport/streamable_http.go @@ -217,7 +217,7 @@ func (c *StreamableHTTP) SendRequest( } // should not be a notification - if response.ID == nil { + if response.ID.IsNil() { return nil, fmt.Errorf("response should contain RPC id: %v", response) } @@ -258,7 +258,7 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl } // Handle notification - if message.ID == nil { + if message.ID.IsNil() { var notification mcp.JSONRPCNotification if err := json.Unmarshal([]byte(data), ¬ification); err != nil { fmt.Printf("failed to unmarshal notification: %v\n", err) diff --git a/client/transport/streamable_http_test.go b/client/transport/streamable_http_test.go index addddd20..de3cddff 100644 --- a/client/transport/streamable_http_test.go +++ b/client/transport/streamable_http_test.go @@ -147,7 +147,7 @@ func TestStreamableHTTP(t *testing.T) { initRequest := JSONRPCRequest{ JSONRPC: "2.0", - ID: 1, + ID: mcp.NewRequestId(int64(0)), Method: "initialize", } @@ -168,7 +168,7 @@ func TestStreamableHTTP(t *testing.T) { request := JSONRPCRequest{ JSONRPC: "2.0", - ID: 1, + ID: mcp.NewRequestId(int64(1)), Method: "debug/echo", Params: params, } @@ -182,7 +182,7 @@ func TestStreamableHTTP(t *testing.T) { // Parse the result to verify echo var result struct { JSONRPC string `json:"jsonrpc"` - ID int64 `json:"id"` + ID mcp.RequestId `json:"id"` Method string `json:"method"` Params map[string]any `json:"params"` } @@ -195,8 +195,11 @@ func TestStreamableHTTP(t *testing.T) { if result.JSONRPC != "2.0" { t.Errorf("Expected JSONRPC value '2.0', got '%s'", result.JSONRPC) } - if result.ID != 1 { - t.Errorf("Expected ID 1, got %d", result.ID) + idValue, ok := result.ID.Value().(int64) + if !ok { + t.Errorf("Expected ID to be int64, got %T", result.ID.Value()) + } else if idValue != 1 { + t.Errorf("Expected ID 1, got %d", idValue) } if result.Method != "debug/echo" { t.Errorf("Expected method 'debug/echo', got '%s'", result.Method) @@ -219,7 +222,7 @@ func TestStreamableHTTP(t *testing.T) { // Prepare a request request := JSONRPCRequest{ JSONRPC: "2.0", - ID: 3, + ID: mcp.NewRequestId(int64(3)), Method: "debug/echo", } @@ -247,7 +250,7 @@ func TestStreamableHTTP(t *testing.T) { request := JSONRPCRequest{ JSONRPC: "2.0", - ID: 1, + ID: mcp.NewRequestId(int64(1)), Method: "debug/echo_notification", } @@ -266,7 +269,7 @@ func TestStreamableHTTP(t *testing.T) { if got == nil { t.Errorf("Notification handler did not send the expected notification: got nil") } - if int64(got["id"].(float64)) != request.ID || + if int64(got["id"].(float64)) != request.ID.Value().(int64) || got["jsonrpc"] != request.JSONRPC || got["method"] != request.Method { @@ -302,7 +305,7 @@ func TestStreamableHTTP(t *testing.T) { // Each request has a unique ID and payload request := JSONRPCRequest{ JSONRPC: "2.0", - ID: int64(100 + idx), + ID: mcp.NewRequestId(int64(100 + idx)), Method: "debug/echo", Params: map[string]any{ "requestIndex": idx, @@ -327,15 +330,25 @@ func TestStreamableHTTP(t *testing.T) { continue } - if responses[i] == nil || responses[i].ID == nil || *responses[i].ID != int64(100+i) { - t.Errorf("Request %d: Expected ID %d, got %v", i, 100+i, responses[i]) + if responses[i] == nil { + t.Errorf("Request %d: Response is nil", i) + continue + } + + expectedId := int64(100 + i) + idValue, ok := responses[i].ID.Value().(int64) + if !ok { + t.Errorf("Request %d: Expected ID to be int64, got %T", i, responses[i].ID.Value()) + continue + } else if idValue != expectedId { + t.Errorf("Request %d: Expected ID %d, got %d", i, expectedId, idValue) continue } // Parse the result to verify echo var result struct { JSONRPC string `json:"jsonrpc"` - ID int64 `json:"id"` + ID mcp.RequestId `json:"id"` Method string `json:"method"` Params map[string]any `json:"params"` } @@ -346,8 +359,8 @@ func TestStreamableHTTP(t *testing.T) { } // Verify data matches what was sent - if result.ID != int64(100+i) { - t.Errorf("Request %d: Expected echoed ID %d, got %d", i, 100+i, result.ID) + if result.ID.Value().(int64) != expectedId { + t.Errorf("Request %d: Expected echoed ID %d, got %d", i, expectedId, result.ID.Value().(int64)) } if result.Method != "debug/echo" { @@ -368,7 +381,7 @@ func TestStreamableHTTP(t *testing.T) { // Prepare a request request := JSONRPCRequest{ JSONRPC: "2.0", - ID: 100, + ID: mcp.NewRequestId(int64(100)), Method: "debug/echo_error_string", } @@ -390,8 +403,11 @@ func TestStreamableHTTP(t *testing.T) { if responseError.Method != "debug/echo_error_string" { t.Errorf("Expected method 'debug/echo_error_string', got '%s'", responseError.Method) } - if responseError.ID != 100 { - t.Errorf("Expected ID 100, got %d", responseError.ID) + idValue, ok := responseError.ID.Value().(int64) + if !ok { + t.Errorf("Expected ID to be int64, got %T", responseError.ID.Value()) + } else if idValue != 100 { + t.Errorf("Expected ID 100, got %d", idValue) } if responseError.JSONRPC != "2.0" { t.Errorf("Expected JSONRPC '2.0', got '%s'", responseError.JSONRPC) @@ -421,7 +437,7 @@ func TestStreamableHTTPErrors(t *testing.T) { request := JSONRPCRequest{ JSONRPC: "2.0", - ID: 1, + ID: mcp.NewRequestId(int64(1)), Method: "initialize", } diff --git a/mcp/types.go b/mcp/types.go index e7fdb6f0..d086ac90 100644 --- a/mcp/types.go +++ b/mcp/types.go @@ -4,6 +4,8 @@ package mcp import ( "encoding/json" + "fmt" + "strconv" "maps" "github.com/yosida95/uritemplate/v3" @@ -222,7 +224,75 @@ type Result struct { // RequestId is a uniquely identifying ID for a request in JSON-RPC. // It can be any JSON-serializable value, typically a number or string. -type RequestId any +type RequestId struct { + value any +} + +// NewRequestId creates a new RequestId with the given value +func NewRequestId(value any) RequestId { + return RequestId{value: value} +} + +// Value returns the underlying value of the RequestId +func (r RequestId) Value() any { + return r.value +} + +// String returns a string representation of the RequestId +func (r RequestId) String() string { + switch v := r.value.(type) { + case string: + return "string:" + v + case int64: + return "int64:" + strconv.FormatInt(v, 10) + case float64: + if v == float64(int64(v)) { + return "int64:" + strconv.FormatInt(int64(v), 10) + } + return "float64:" + strconv.FormatFloat(v, 'f', -1, 64) + case nil: + return "" + default: + return "unknown:" + fmt.Sprintf("%v", v) + } +} + +// IsNil returns true if the RequestId is nil +func (r RequestId) IsNil() bool { + return r.value == nil +} + +func (r RequestId) MarshalJSON() ([]byte, error) { + return json.Marshal(r.value) +} + +func (r *RequestId) UnmarshalJSON(data []byte) error { + + if string(data) == "null" { + r.value = nil + return nil + } + + // Try unmarshaling as string first + var s string + if err := json.Unmarshal(data, &s); err == nil { + r.value = s + return nil + } + + // JSON numbers are unmarshaled as float64 in Go + var f float64 + if err := json.Unmarshal(data, &f); err == nil { + if f == float64(int64(f)) { + r.value = int64(f) + } else { + r.value = f + } + return nil + } + + return fmt.Errorf("invalid request id: %s", string(data)) +} // JSONRPCRequest represents a request that expects a response. type JSONRPCRequest struct { diff --git a/server/server.go b/server/server.go index b31b4865..e5b48a5e 100644 --- a/server/server.go +++ b/server/server.go @@ -101,7 +101,7 @@ func (e *requestError) Error() string { func (e *requestError) ToJSONRPCError() mcp.JSONRPCError { return mcp.JSONRPCError{ JSONRPC: mcp.JSONRPC_VERSION, - ID: e.id, + ID: mcp.NewRequestId(e.id), Error: struct { Code int `json:"code"` Message string `json:"message"` @@ -937,7 +937,7 @@ func (s *MCPServer) handleNotification( func createResponse(id any, result any) mcp.JSONRPCMessage { return mcp.JSONRPCResponse{ JSONRPC: mcp.JSONRPC_VERSION, - ID: id, + ID: mcp.NewRequestId(id), Result: result, } } @@ -949,7 +949,7 @@ func createErrorResponse( ) mcp.JSONRPCMessage { return mcp.JSONRPCError{ JSONRPC: mcp.JSONRPC_VERSION, - ID: id, + ID: mcp.NewRequestId(id), Error: struct { Code int `json:"code"` Message string `json:"message"` diff --git a/server/server_test.go b/server/server_test.go index 4615b0fb..5c2bff4e 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -134,7 +134,7 @@ func TestMCPServer_Capabilities(t *testing.T) { server := NewMCPServer("test-server", "1.0.0", tt.options...) message := mcp.JSONRPCRequest{ JSONRPC: "2.0", - ID: 1, + ID: mcp.NewRequestId(int64(1)), Request: mcp.Request{ Method: "initialize", }, @@ -388,7 +388,7 @@ func TestMCPServer_HandleValidMessages(t *testing.T) { name: "Initialize request", message: mcp.JSONRPCRequest{ JSONRPC: "2.0", - ID: 1, + ID: mcp.NewRequestId(int64(1)), Request: mcp.Request{ Method: "initialize", }, @@ -413,7 +413,7 @@ func TestMCPServer_HandleValidMessages(t *testing.T) { name: "Ping request", message: mcp.JSONRPCRequest{ JSONRPC: "2.0", - ID: 1, + ID: mcp.NewRequestId(int64(1)), Request: mcp.Request{ Method: "ping", }, @@ -430,7 +430,7 @@ func TestMCPServer_HandleValidMessages(t *testing.T) { name: "List resources", message: mcp.JSONRPCRequest{ JSONRPC: "2.0", - ID: 1, + ID: mcp.NewRequestId(int64(1)), Request: mcp.Request{ Method: "resources/list", }, @@ -1127,7 +1127,7 @@ func TestMCPServer_Instructions(t *testing.T) { message := mcp.JSONRPCRequest{ JSONRPC: "2.0", - ID: 1, + ID: mcp.NewRequestId(int64(1)), Request: mcp.Request{ Method: "initialize", }, diff --git a/server/sse.go b/server/sse.go index 630927d1..d51a8979 100644 --- a/server/sse.go +++ b/server/sse.go @@ -338,7 +338,7 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) { case <-ticker.C: message := mcp.JSONRPCRequest{ JSONRPC: "2.0", - ID: session.requestID.Add(1), + ID: mcp.NewRequestId(session.requestID.Add(1)), Request: mcp.Request{ Method: "ping", }, diff --git a/server/sse_test.go b/server/sse_test.go index aebf69d0..62bd616b 100644 --- a/server/sse_test.go +++ b/server/sse_test.go @@ -810,7 +810,16 @@ func TestSSEServer(t *testing.T) { } if pingMsg.Method == "ping" { - pingID = pingMsg.ID.(float64) + idValue, ok := pingMsg.ID.Value().(int64) + if ok { + pingID = float64(idValue) + } else { + floatValue, ok := pingMsg.ID.Value().(float64) + if !ok { + t.Fatalf("Expected ping ID to be number, got %T: %v", pingMsg.ID.Value(), pingMsg.ID.Value()) + } + pingID = floatValue + } t.Logf("Received ping with ID: %f", pingID) break // We got the ping, exit the loop } diff --git a/testdata/mockstdio_server.go b/testdata/mockstdio_server.go index 63f7835d..f561285e 100644 --- a/testdata/mockstdio_server.go +++ b/testdata/mockstdio_server.go @@ -6,19 +6,21 @@ import ( "fmt" "log/slog" "os" + + "github.com/mark3labs/mcp-go/mcp" ) type JSONRPCRequest struct { JSONRPC string `json:"jsonrpc"` - ID *int64 `json:"id,omitempty"` + ID *mcp.RequestId `json:"id,omitempty"` Method string `json:"method"` Params json.RawMessage `json:"params"` } type JSONRPCResponse struct { - JSONRPC string `json:"jsonrpc"` - ID *int64 `json:"id,omitempty"` - Result any `json:"result,omitempty"` + JSONRPC string `json:"jsonrpc"` + ID *mcp.RequestId `json:"id,omitempty"` + Result any `json:"result,omitempty"` Error *struct { Code int `json:"code"` Message string `json:"message"` From 077f546c180dcd6ba9ad3f8cdb30643ddd153297 Mon Sep 17 00:00:00 2001 From: cryo Date: Sat, 17 May 2025 01:51:02 +0800 Subject: [PATCH 5/5] feat(MCPServer): support `logging/setlevel` request (#276) * feat(MCPServer): support logging/setlevel request * update template file and adopt coderabbitai suggestion --- mcp/types.go | 4 ++ server/errors.go | 9 +-- server/hooks.go | 32 +++++++++ server/internal/gen/data.go | 10 +++ server/request_handler.go | 25 +++++++ server/server.go | 51 ++++++++++++-- server/session.go | 9 +++ server/session_test.go | 131 +++++++++++++++++++++++++++++++++++- server/sse.go | 20 +++++- server/stdio.go | 24 ++++++- 10 files changed, 299 insertions(+), 16 deletions(-) diff --git a/mcp/types.go b/mcp/types.go index d086ac90..4dea23d2 100644 --- a/mcp/types.go +++ b/mcp/types.go @@ -50,6 +50,10 @@ const ( // https://modelcontextprotocol.io/specification/2024-11-05/server/tools/ MethodToolsCall MCPMethod = "tools/call" + // MethodSetLogLevel configures the minimum log level for client + // https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/logging + MethodSetLogLevel MCPMethod = "logging/setLevel" + // MethodNotificationResourcesListChanged notifies when the list of available resources changes. // https://modelcontextprotocol.io/specification/2025-03-26/server/resources#list-changed-notification MethodNotificationResourcesListChanged = "notifications/resources/list_changed" diff --git a/server/errors.go b/server/errors.go index b984a28c..68b7f787 100644 --- a/server/errors.go +++ b/server/errors.go @@ -13,10 +13,11 @@ var ( ErrToolNotFound = errors.New("tool not found") // Session-related errors - ErrSessionNotFound = errors.New("session not found") - ErrSessionExists = errors.New("session already exists") - ErrSessionNotInitialized = errors.New("session not properly initialized") - ErrSessionDoesNotSupportTools = errors.New("session does not support per-session tools") + ErrSessionNotFound = errors.New("session not found") + ErrSessionExists = errors.New("session already exists") + ErrSessionNotInitialized = errors.New("session not properly initialized") + ErrSessionDoesNotSupportTools = errors.New("session does not support per-session tools") + ErrSessionDoesNotSupportLogging = errors.New("session does not support setting logging level") // Notification-related errors ErrNotificationNotInitialized = errors.New("notification channel not initialized") diff --git a/server/hooks.go b/server/hooks.go index 18e53857..4baa1c4e 100644 --- a/server/hooks.go +++ b/server/hooks.go @@ -67,6 +67,9 @@ type OnAfterInitializeFunc func(ctx context.Context, id any, message *mcp.Initia type OnBeforePingFunc func(ctx context.Context, id any, message *mcp.PingRequest) type OnAfterPingFunc func(ctx context.Context, id any, message *mcp.PingRequest, result *mcp.EmptyResult) +type OnBeforeSetLevelFunc func(ctx context.Context, id any, message *mcp.SetLevelRequest) +type OnAfterSetLevelFunc func(ctx context.Context, id any, message *mcp.SetLevelRequest, result *mcp.EmptyResult) + type OnBeforeListResourcesFunc func(ctx context.Context, id any, message *mcp.ListResourcesRequest) type OnAfterListResourcesFunc func(ctx context.Context, id any, message *mcp.ListResourcesRequest, result *mcp.ListResourcesResult) @@ -99,6 +102,8 @@ type Hooks struct { OnAfterInitialize []OnAfterInitializeFunc OnBeforePing []OnBeforePingFunc OnAfterPing []OnAfterPingFunc + OnBeforeSetLevel []OnBeforeSetLevelFunc + OnAfterSetLevel []OnAfterSetLevelFunc OnBeforeListResources []OnBeforeListResourcesFunc OnAfterListResources []OnAfterListResourcesFunc OnBeforeListResourceTemplates []OnBeforeListResourceTemplatesFunc @@ -309,6 +314,33 @@ func (c *Hooks) afterPing(ctx context.Context, id any, message *mcp.PingRequest, hook(ctx, id, message, result) } } +func (c *Hooks) AddBeforeSetLevel(hook OnBeforeSetLevelFunc) { + c.OnBeforeSetLevel = append(c.OnBeforeSetLevel, hook) +} + +func (c *Hooks) AddAfterSetLevel(hook OnAfterSetLevelFunc) { + c.OnAfterSetLevel = append(c.OnAfterSetLevel, hook) +} + +func (c *Hooks) beforeSetLevel(ctx context.Context, id any, message *mcp.SetLevelRequest) { + c.beforeAny(ctx, id, mcp.MethodSetLogLevel, message) + if c == nil { + return + } + for _, hook := range c.OnBeforeSetLevel { + hook(ctx, id, message) + } +} + +func (c *Hooks) afterSetLevel(ctx context.Context, id any, message *mcp.SetLevelRequest, result *mcp.EmptyResult) { + c.onSuccess(ctx, id, mcp.MethodSetLogLevel, message, result) + if c == nil { + return + } + for _, hook := range c.OnAfterSetLevel { + hook(ctx, id, message, result) + } +} func (c *Hooks) AddBeforeListResources(hook OnBeforeListResourcesFunc) { c.OnBeforeListResources = append(c.OnBeforeListResources, hook) } diff --git a/server/internal/gen/data.go b/server/internal/gen/data.go index 50fd70ca..a468f460 100644 --- a/server/internal/gen/data.go +++ b/server/internal/gen/data.go @@ -27,6 +27,16 @@ var MCPRequestTypes = []MCPRequestType{ HookName: "Ping", UnmarshalError: "invalid ping request", HandlerFunc: "handlePing", + }, { + MethodName: "MethodSetLogLevel", + ParamType: "SetLevelRequest", + ResultType: "EmptyResult", + Group: "logging", + GroupName: "Logging", + GroupHookName: "Logging", + HookName: "SetLevel", + UnmarshalError: "invalid set level request", + HandlerFunc: "handleSetLevel", }, { MethodName: "MethodResourcesList", ParamType: "ListResourcesRequest", diff --git a/server/request_handler.go b/server/request_handler.go index 0bc16b41..25f6ef14 100644 --- a/server/request_handler.go +++ b/server/request_handler.go @@ -110,6 +110,31 @@ func (s *MCPServer) HandleMessage( } s.hooks.afterPing(ctx, baseMessage.ID, &request, result) return createResponse(baseMessage.ID, *result) + case mcp.MethodSetLogLevel: + var request mcp.SetLevelRequest + var result *mcp.EmptyResult + if s.capabilities.logging == nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.METHOD_NOT_FOUND, + err: fmt.Errorf("logging %w", ErrUnsupported), + } + } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.INVALID_REQUEST, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + } + } else { + s.hooks.beforeSetLevel(ctx, baseMessage.ID, &request) + result, err = s.handleSetLevel(ctx, baseMessage.ID, request) + } + if err != nil { + s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) + return err.ToJSONRPCError() + } + s.hooks.afterSetLevel(ctx, baseMessage.ID, &request, result) + return createResponse(baseMessage.ID, *result) case mcp.MethodResourcesList: var request mcp.ListResourcesRequest var result *mcp.ListResourcesResult diff --git a/server/server.go b/server/server.go index e5b48a5e..f0fa78e7 100644 --- a/server/server.go +++ b/server/server.go @@ -161,7 +161,7 @@ type serverCapabilities struct { tools *toolCapabilities resources *resourceCapabilities prompts *promptCapabilities - logging bool + logging *bool } // resourceCapabilities defines the supported resource-related features @@ -260,7 +260,7 @@ func WithToolCapabilities(listChanged bool) ServerOption { // WithLogging enables logging capabilities for the server func WithLogging() ServerOption { return func(s *MCPServer) { - s.capabilities.logging = true + s.capabilities.logging = mcp.ToBoolPtr(true) } } @@ -289,7 +289,7 @@ func NewMCPServer( tools: nil, resources: nil, prompts: nil, - logging: false, + logging: nil, }, } @@ -521,7 +521,7 @@ func (s *MCPServer) handleInitialize( } } - if s.capabilities.logging { + if s.capabilities.logging != nil && *s.capabilities.logging { capabilities.Logging = &struct{}{} } @@ -549,6 +549,49 @@ func (s *MCPServer) handlePing( return &mcp.EmptyResult{}, nil } +func (s *MCPServer) handleSetLevel( + ctx context.Context, + id any, + request mcp.SetLevelRequest, +) (*mcp.EmptyResult, *requestError) { + clientSession := ClientSessionFromContext(ctx) + if clientSession == nil || !clientSession.Initialized() { + return nil, &requestError{ + id: id, + code: mcp.INTERNAL_ERROR, + err: ErrSessionNotInitialized, + } + } + + sessionLogging, ok := clientSession.(SessionWithLogging) + if !ok { + return nil, &requestError{ + id: id, + code: mcp.INTERNAL_ERROR, + err: ErrSessionDoesNotSupportLogging, + } + } + + level := request.Params.Level + // Validate logging level + switch level { + case mcp.LoggingLevelDebug, mcp.LoggingLevelInfo, mcp.LoggingLevelNotice, + mcp.LoggingLevelWarning, mcp.LoggingLevelError, mcp.LoggingLevelCritical, + mcp.LoggingLevelAlert, mcp.LoggingLevelEmergency: + // Valid level + default: + return nil, &requestError{ + id: id, + code: mcp.INVALID_PARAMS, + err: fmt.Errorf("invalid logging level '%s'", level), + } + } + + sessionLogging.SetLogLevel(level) + + return &mcp.EmptyResult{}, nil +} + func listByPagination[T mcp.Named]( ctx context.Context, s *MCPServer, diff --git a/server/session.go b/server/session.go index 3a4206a7..0c50a260 100644 --- a/server/session.go +++ b/server/session.go @@ -19,6 +19,15 @@ type ClientSession interface { SessionID() string } +// SessionWithLogging is an extension of ClientSession that can receive log message notifications and set log level +type SessionWithLogging interface { + ClientSession + // SetLogLevel sets the minimum log level + SetLogLevel(level mcp.LoggingLevel) + // GetLogLevel retrieves the minimum log level + GetLogLevel() mcp.LoggingLevel +} + // SessionWithTools is an extension of ClientSession that can store session-specific tool data type SessionWithTools interface { ClientSession diff --git a/server/session_test.go b/server/session_test.go index 54a78170..8f2cfa76 100644 --- a/server/session_test.go +++ b/server/session_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "sync" + "sync/atomic" "testing" "time" @@ -98,9 +99,47 @@ func (f *sessionTestClientWithTools) SetSessionTools(tools map[string]ServerTool f.sessionTools = toolsCopy } -// Verify that both implementations satisfy their respective interfaces -var _ ClientSession = &sessionTestClient{} -var _ SessionWithTools = &sessionTestClientWithTools{} +// sessionTestClientWithTools implements the SessionWithLogging interface for testing +type sessionTestClientWithLogging struct { + sessionID string + notificationChannel chan mcp.JSONRPCNotification + initialized bool + loggingLevel atomic.Value +} + +func (f *sessionTestClientWithLogging) SessionID() string { + return f.sessionID +} + +func (f *sessionTestClientWithLogging) NotificationChannel() chan<- mcp.JSONRPCNotification { + return f.notificationChannel +} + +func (f *sessionTestClientWithLogging) Initialize() { + // set default logging level + f.loggingLevel.Store(mcp.LoggingLevelError) + f.initialized = true +} + +func (f *sessionTestClientWithLogging) Initialized() bool { + return f.initialized +} + +func (f *sessionTestClientWithLogging) SetLogLevel(level mcp.LoggingLevel) { + f.loggingLevel.Store(level) +} + +func (f *sessionTestClientWithLogging) GetLogLevel() mcp.LoggingLevel { + level := f.loggingLevel.Load() + return level.(mcp.LoggingLevel) +} + +// Verify that all implementations satisfy their respective interfaces +var ( + _ ClientSession = (*sessionTestClient)(nil) + _ SessionWithTools = (*sessionTestClientWithTools)(nil) + _ SessionWithLogging = (*sessionTestClientWithLogging)(nil) +) func TestSessionWithTools_Integration(t *testing.T) { server := NewMCPServer("test-server", "1.0.0", WithToolCapabilities(true)) @@ -917,3 +956,89 @@ func TestMCPServer_ToolNotificationsDisabled(t *testing.T) { // Verify tool was deleted from session assert.Len(t, session.GetSessionTools(), 0) } + +func TestMCPServer_SetLevelNotEnabled(t *testing.T) { + // Create server without logging capability + server := NewMCPServer("test-server", "1.0.0") + + // Create and initialize a session + sessionChan := make(chan mcp.JSONRPCNotification, 10) + session := &sessionTestClientWithLogging{ + sessionID: "session-1", + notificationChannel: sessionChan, + } + session.Initialize() + + // Register the session + err := server.RegisterSession(context.Background(), session) + require.NoError(t, err) + + // Try to set logging level when capability is disabled + sessionCtx := server.WithContext(context.Background(), session) + setRequest := map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": "logging/setLevel", + "params": map[string]any{ + "level": mcp.LoggingLevelCritical, + }, + } + requestBytes, err := json.Marshal(setRequest) + require.NoError(t, err) + + response := server.HandleMessage(sessionCtx, requestBytes) + errorResponse, ok := response.(mcp.JSONRPCError) + assert.True(t, ok) + + // Verify we get a METHOD_NOT_FOUND error + assert.NotNil(t, errorResponse.Error) + assert.Equal(t, mcp.METHOD_NOT_FOUND, errorResponse.Error.Code) +} + +func TestMCPServer_SetLevel(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", WithLogging()) + + // Create and initicalize a session + sessionChan := make(chan mcp.JSONRPCNotification, 10) + session := &sessionTestClientWithLogging{ + sessionID: "session-1", + notificationChannel: sessionChan, + } + session.Initialize() + + // Check default logging level + if session.GetLogLevel() != mcp.LoggingLevelError { + t.Errorf("Expected error level, got %v", session.GetLogLevel()) + } + + // Register the session + err := server.RegisterSession(context.Background(), session) + require.NoError(t, err) + + // Set Logging level to critical + sessionCtx := server.WithContext(context.Background(), session) + setRequest := map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": "logging/setLevel", + "params": map[string]any{ + "level": mcp.LoggingLevelCritical, + }, + } + requestBytes, err := json.Marshal(setRequest) + if err != nil { + t.Fatalf("Failed to marshal tool request: %v", err) + } + + response := server.HandleMessage(sessionCtx, requestBytes) + resp, ok := response.(mcp.JSONRPCResponse) + assert.True(t, ok) + + _, ok = resp.Result.(mcp.EmptyResult) + assert.True(t, ok) + + // Check logging level + if session.GetLogLevel() != mcp.LoggingLevelCritical { + t.Errorf("Expected critical level, got %v", session.GetLogLevel()) + } +} \ No newline at end of file diff --git a/server/sse.go b/server/sse.go index d51a8979..22162dd3 100644 --- a/server/sse.go +++ b/server/sse.go @@ -28,6 +28,7 @@ type sseSession struct { requestID atomic.Int64 notificationChannel chan mcp.JSONRPCNotification initialized atomic.Bool + loggingLevel atomic.Value tools sync.Map // stores session-specific tools } @@ -45,6 +46,8 @@ func (s *sseSession) NotificationChannel() chan<- mcp.JSONRPCNotification { } func (s *sseSession) Initialize() { + // set default logging level + s.loggingLevel.Store(mcp.LoggingLevelError) s.initialized.Store(true) } @@ -52,6 +55,18 @@ func (s *sseSession) Initialized() bool { return s.initialized.Load() } +func(s *sseSession) SetLogLevel(level mcp.LoggingLevel) { + s.loggingLevel.Store(level) +} + +func(s *sseSession) GetLogLevel() mcp.LoggingLevel { + level := s.loggingLevel.Load() + if level == nil { + return mcp.LoggingLevelError + } + return level.(mcp.LoggingLevel) +} + func (s *sseSession) GetSessionTools() map[string]ServerTool { tools := make(map[string]ServerTool) s.tools.Range(func(key, value any) bool { @@ -77,8 +92,9 @@ func (s *sseSession) SetSessionTools(tools map[string]ServerTool) { } var ( - _ ClientSession = (*sseSession)(nil) - _ SessionWithTools = (*sseSession)(nil) + _ ClientSession = (*sseSession)(nil) + _ SessionWithTools = (*sseSession)(nil) + _ SessionWithLogging = (*sseSession)(nil) ) // SSEServer implements a Server-Sent Events (SSE) based MCP server. diff --git a/server/stdio.go b/server/stdio.go index c4fe1bf6..6f51e996 100644 --- a/server/stdio.go +++ b/server/stdio.go @@ -51,8 +51,9 @@ func WithStdioContextFunc(fn StdioContextFunc) StdioOption { // stdioSession is a static client session, since stdio has only one client. type stdioSession struct { - notifications chan mcp.JSONRPCNotification - initialized atomic.Bool + notifications chan mcp.JSONRPCNotification + initialized atomic.Bool + loggingLevel atomic.Value } func (s *stdioSession) SessionID() string { @@ -64,6 +65,8 @@ func (s *stdioSession) NotificationChannel() chan<- mcp.JSONRPCNotification { } func (s *stdioSession) Initialize() { + // set default logging level + s.loggingLevel.Store(mcp.LoggingLevelError) s.initialized.Store(true) } @@ -71,7 +74,22 @@ func (s *stdioSession) Initialized() bool { return s.initialized.Load() } -var _ ClientSession = (*stdioSession)(nil) +func(s *stdioSession) SetLogLevel(level mcp.LoggingLevel) { + s.loggingLevel.Store(level) +} + +func(s *stdioSession) GetLogLevel() mcp.LoggingLevel { + level := s.loggingLevel.Load() + if level == nil { + return mcp.LoggingLevelError + } + return level.(mcp.LoggingLevel) +} + +var ( + _ ClientSession = (*stdioSession)(nil) + _ SessionWithLogging = (*stdioSession)(nil) +) var stdioSessionInstance = stdioSession{ notifications: make(chan mcp.JSONRPCNotification, 100),