diff --git a/client/client.go b/client/client.go index 60fe0cbf..63986328 100644 --- a/client/client.go +++ b/client/client.go @@ -16,11 +16,21 @@ import ( type Client struct { transport transport.Interface - initialized bool - notifications []func(mcp.JSONRPCNotification) - notifyMu sync.RWMutex - requestID atomic.Int64 - capabilities mcp.ServerCapabilities + initialized bool + notifications []func(mcp.JSONRPCNotification) + notifyMu sync.RWMutex + requestID atomic.Int64 + clientCapabilities mcp.ClientCapabilities + serverCapabilities mcp.ServerCapabilities +} + +type ClientOption func(*Client) + +// WithClientCapabilities sets the client capabilities for the client. +func WithClientCapabilities(capabilities mcp.ClientCapabilities) ClientOption { + return func(c *Client) { + c.clientCapabilities = capabilities + } } // NewClient creates a new MCP client with the given transport. @@ -31,10 +41,16 @@ type Client struct { // if err != nil { // log.Fatalf("Failed to create client: %v", err) // } -func NewClient(transport transport.Interface) *Client { - return &Client{ +func NewClient(transport transport.Interface, options ...ClientOption) *Client { + client := &Client{ transport: transport, } + + for _, opt := range options { + opt(client) + } + + return client } // Start initiates the connection to the server. @@ -115,7 +131,7 @@ func (c *Client) Initialize( params := struct { ProtocolVersion string `json:"protocolVersion"` ClientInfo mcp.Implementation `json:"clientInfo"` - Capabilities mcp.ClientCapabilities `json:"capabilities"` + Capabilities mcp.ClientCapabilities `json:"serverCapabilities"` }{ ProtocolVersion: request.Params.ProtocolVersion, ClientInfo: request.Params.ClientInfo, @@ -132,8 +148,8 @@ func (c *Client) Initialize( return nil, fmt.Errorf("failed to unmarshal response: %w", err) } - // Store capabilities - c.capabilities = result.Capabilities + // Store serverCapabilities + c.serverCapabilities = result.Capabilities // Send initialized notification notification := mcp.JSONRPCNotification{ @@ -406,3 +422,13 @@ func listByPage[T any]( func (c *Client) GetTransport() transport.Interface { return c.transport } + +// GetServerCapabilities returns the server capabilities. +func (c *Client) GetServerCapabilities() mcp.ServerCapabilities { + return c.serverCapabilities +} + +// GetClientCapabilities returns the client capabilities. +func (c *Client) GetClientCapabilities() mcp.ClientCapabilities { + return c.clientCapabilities +} diff --git a/client/inprocess.go b/client/inprocess.go new file mode 100644 index 00000000..5d8559de --- /dev/null +++ b/client/inprocess.go @@ -0,0 +1,12 @@ +package client + +import ( + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/server" +) + +// NewInProcessClient connect directly to a mcp server object in the same process +func NewInProcessClient(server *server.MCPServer) (*Client, error) { + inProcessTransport := transport.NewInProcessTransport(server) + return NewClient(inProcessTransport), nil +} diff --git a/client/inprocess_test.go b/client/inprocess_test.go new file mode 100644 index 00000000..de447602 --- /dev/null +++ b/client/inprocess_test.go @@ -0,0 +1,407 @@ +package client + +import ( + "context" + "testing" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +func TestInProcessMCPClient(t *testing.T) { + mcpServer := server.NewMCPServer( + "test-server", + "1.0.0", + server.WithResourceCapabilities(true, true), + server.WithPromptCapabilities(true), + server.WithToolCapabilities(true), + ) + + // Add a test tool + mcpServer.AddTool(mcp.NewTool( + "test-tool", + mcp.WithDescription("Test tool"), + mcp.WithString("parameter-1", mcp.Description("A string tool parameter")), + mcp.WithToolAnnotation(mcp.ToolAnnotation{ + Title: "Test Tool Annotation Title", + ReadOnlyHint: true, + DestructiveHint: false, + IdempotentHint: true, + OpenWorldHint: false, + }), + ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: "Input parameter: " + request.Params.Arguments["parameter-1"].(string), + }, + }, + }, nil + }) + + mcpServer.AddResource( + mcp.Resource{ + URI: "resource://testresource", + Name: "My Resource", + }, + func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + return []mcp.ResourceContents{ + mcp.TextResourceContents{ + URI: "resource://testresource", + MIMEType: "text/plain", + Text: "test content", + }, + }, nil + }, + ) + + mcpServer.AddPrompt( + mcp.Prompt{ + Name: "test-prompt", + Description: "A test prompt", + Arguments: []mcp.PromptArgument{ + { + Name: "arg1", + Description: "First argument", + }, + }, + }, + func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + return &mcp.GetPromptResult{ + Messages: []mcp.PromptMessage{ + { + Role: mcp.RoleAssistant, + Content: mcp.TextContent{ + Type: "text", + Text: "Test prompt with arg1: " + request.Params.Arguments["arg1"], + }, + }, + }, + }, nil + }, + ) + + t.Run("Can initialize and make requests", func(t *testing.T) { + client, err := NewInProcessClient(mcpServer) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + defer client.Close() + + // Start the client + if err := client.Start(context.Background()); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // Initialize + initRequest := mcp.InitializeRequest{} + initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initRequest.Params.ClientInfo = mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + } + + result, err := client.Initialize(context.Background(), initRequest) + if err != nil { + t.Fatalf("Failed to initialize: %v", err) + } + + if result.ServerInfo.Name != "test-server" { + t.Errorf( + "Expected server name 'test-server', got '%s'", + result.ServerInfo.Name, + ) + } + + // Test Ping + if err := client.Ping(context.Background()); err != nil { + t.Errorf("Ping failed: %v", err) + } + + // Test ListTools + toolsRequest := mcp.ListToolsRequest{} + toolListResult, err := client.ListTools(context.Background(), toolsRequest) + if err != nil { + t.Errorf("ListTools failed: %v", err) + } + if toolListResult == nil || len((*toolListResult).Tools) == 0 { + t.Errorf("Expected one tool") + } + testToolAnnotations := (*toolListResult).Tools[0].Annotations + if testToolAnnotations.Title != "Test Tool Annotation Title" || + testToolAnnotations.ReadOnlyHint != true || + testToolAnnotations.DestructiveHint != false || + testToolAnnotations.IdempotentHint != true || + testToolAnnotations.OpenWorldHint != false { + t.Errorf("The annotations of the tools are invalid") + } + }) + + t.Run("Handles errors properly", func(t *testing.T) { + client, err := NewInProcessClient(mcpServer) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + defer client.Close() + + if err := client.Start(context.Background()); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // Try to make a request without initializing + toolsRequest := mcp.ListToolsRequest{} + _, err = client.ListTools(context.Background(), toolsRequest) + if err == nil { + t.Error("Expected error when making request before initialization") + } + }) + + t.Run("CallTool", func(t *testing.T) { + client, err := NewInProcessClient(mcpServer) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + defer client.Close() + + if err := client.Start(context.Background()); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // Initialize + initRequest := mcp.InitializeRequest{} + initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initRequest.Params.ClientInfo = mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + } + + _, err = client.Initialize(context.Background(), initRequest) + if err != nil { + t.Fatalf("Failed to initialize: %v", err) + } + + request := mcp.CallToolRequest{} + request.Params.Name = "test-tool" + request.Params.Arguments = map[string]interface{}{ + "parameter-1": "value1", + } + + result, err := client.CallTool(context.Background(), request) + if err != nil { + t.Fatalf("CallTool failed: %v", err) + } + + if len(result.Content) != 1 { + t.Errorf("Expected 1 content item, got %d", len(result.Content)) + } + }) + + t.Run("Ping", func(t *testing.T) { + client, err := NewInProcessClient(mcpServer) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + defer client.Close() + + if err := client.Start(context.Background()); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // Initialize + initRequest := mcp.InitializeRequest{} + initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initRequest.Params.ClientInfo = mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + } + + _, err = client.Initialize(context.Background(), initRequest) + if err != nil { + t.Fatalf("Failed to initialize: %v", err) + } + + err = client.Ping(context.Background()) + if err != nil { + t.Errorf("Ping failed: %v", err) + } + }) + + t.Run("ListResources", func(t *testing.T) { + client, err := NewInProcessClient(mcpServer) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + defer client.Close() + + if err := client.Start(context.Background()); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // Initialize + initRequest := mcp.InitializeRequest{} + initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initRequest.Params.ClientInfo = mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + } + + _, err = client.Initialize(context.Background(), initRequest) + if err != nil { + t.Fatalf("Failed to initialize: %v", err) + } + + request := mcp.ListResourcesRequest{} + result, err := client.ListResources(context.Background(), request) + if err != nil { + t.Errorf("ListResources failed: %v", err) + } + + if len(result.Resources) != 1 { + t.Errorf("Expected 1 resource, got %d", len(result.Resources)) + } + }) + + t.Run("ReadResource", func(t *testing.T) { + client, err := NewInProcessClient(mcpServer) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + defer client.Close() + + if err := client.Start(context.Background()); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // Initialize + initRequest := mcp.InitializeRequest{} + initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initRequest.Params.ClientInfo = mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + } + + _, err = client.Initialize(context.Background(), initRequest) + if err != nil { + t.Fatalf("Failed to initialize: %v", err) + } + + request := mcp.ReadResourceRequest{} + request.Params.URI = "resource://testresource" + + result, err := client.ReadResource(context.Background(), request) + if err != nil { + t.Errorf("ReadResource failed: %v", err) + } + + if len(result.Contents) != 1 { + t.Errorf("Expected 1 content item, got %d", len(result.Contents)) + } + }) + + t.Run("ListPrompts", func(t *testing.T) { + client, err := NewInProcessClient(mcpServer) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + defer client.Close() + + if err := client.Start(context.Background()); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // Initialize + initRequest := mcp.InitializeRequest{} + initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initRequest.Params.ClientInfo = mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + } + + _, err = client.Initialize(context.Background(), initRequest) + if err != nil { + t.Fatalf("Failed to initialize: %v", err) + } + request := mcp.ListPromptsRequest{} + result, err := client.ListPrompts(context.Background(), request) + if err != nil { + t.Errorf("ListPrompts failed: %v", err) + } + + if len(result.Prompts) != 1 { + t.Errorf("Expected 1 prompt, got %d", len(result.Prompts)) + } + }) + + t.Run("GetPrompt", func(t *testing.T) { + client, err := NewInProcessClient(mcpServer) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + defer client.Close() + + if err := client.Start(context.Background()); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // Initialize + initRequest := mcp.InitializeRequest{} + initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initRequest.Params.ClientInfo = mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + } + + _, err = client.Initialize(context.Background(), initRequest) + if err != nil { + t.Fatalf("Failed to initialize: %v", err) + } + + request := mcp.GetPromptRequest{} + request.Params.Name = "test-prompt" + + result, err := client.GetPrompt(context.Background(), request) + if err != nil { + t.Errorf("GetPrompt failed: %v", err) + } + + if len(result.Messages) != 1 { + t.Errorf("Expected 1 message, got %d", len(result.Messages)) + } + }) + + t.Run("ListTools", func(t *testing.T) { + client, err := NewInProcessClient(mcpServer) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + defer client.Close() + + if err := client.Start(context.Background()); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // Initialize + initRequest := mcp.InitializeRequest{} + initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initRequest.Params.ClientInfo = mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + } + + _, err = client.Initialize(context.Background(), initRequest) + if err != nil { + t.Fatalf("Failed to initialize: %v", err) + } + + request := mcp.ListToolsRequest{} + result, err := client.ListTools(context.Background(), request) + if err != nil { + t.Errorf("ListTools failed: %v", err) + } + + if len(result.Tools) != 1 { + t.Errorf("Expected 1 tool, got %d", len(result.Tools)) + } + }) +} diff --git a/client/transport/inprocess.go b/client/transport/inprocess.go new file mode 100644 index 00000000..90fc2fae --- /dev/null +++ b/client/transport/inprocess.go @@ -0,0 +1,70 @@ +package transport + +import ( + "context" + "encoding/json" + "fmt" + "sync" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +type InProcessTransport struct { + server *server.MCPServer + + onNotification func(mcp.JSONRPCNotification) + notifyMu sync.RWMutex +} + +func NewInProcessTransport(server *server.MCPServer) *InProcessTransport { + return &InProcessTransport{ + server: server, + } +} + +func (c *InProcessTransport) Start(ctx context.Context) error { + return nil +} + +func (c *InProcessTransport) SendRequest(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) { + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + requestBytes = append(requestBytes, '\n') + + respMessage := c.server.HandleMessage(ctx, requestBytes) + respByte, err := json.Marshal(respMessage) + if err != nil { + return nil, fmt.Errorf("failed to marshal response message: %w", err) + } + rpcResp := JSONRPCResponse{} + err = json.Unmarshal(respByte, &rpcResp) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal response message: %w", err) + } + + return &rpcResp, nil +} + +func (c *InProcessTransport) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error { + notificationBytes, err := json.Marshal(notification) + if err != nil { + return fmt.Errorf("failed to marshal notification: %w", err) + } + notificationBytes = append(notificationBytes, '\n') + c.server.HandleMessage(ctx, notificationBytes) + + return nil +} + +func (c *InProcessTransport) SetNotificationHandler(handler func(notification mcp.JSONRPCNotification)) { + c.notifyMu.Lock() + defer c.notifyMu.Unlock() + c.onNotification = handler +} + +func (*InProcessTransport) Close() error { + return nil +} diff --git a/mcp/types.go b/mcp/types.go index c940a460..2b2c6f00 100644 --- a/mcp/types.go +++ b/mcp/types.go @@ -12,40 +12,54 @@ type MCPMethod string const ( // Initiates connection and negotiates protocol capabilities. - // https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/lifecycle/#initialization + // https://modelcontextprotocol.io/specification/2024-11-05/basic/lifecycle/#initialization MethodInitialize MCPMethod = "initialize" // Verifies connection liveness between client and server. - // https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/utilities/ping/ + // https://modelcontextprotocol.io/specification/2024-11-05/basic/utilities/ping/ MethodPing MCPMethod = "ping" // Lists all available server resources. - // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/resources/ + // https://modelcontextprotocol.io/specification/2024-11-05/server/resources/ MethodResourcesList MCPMethod = "resources/list" // Provides URI templates for constructing resource URIs. - // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/resources/ + // https://modelcontextprotocol.io/specification/2024-11-05/server/resources/ MethodResourcesTemplatesList MCPMethod = "resources/templates/list" // Retrieves content of a specific resource by URI. - // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/resources/ + // https://modelcontextprotocol.io/specification/2024-11-05/server/resources/ MethodResourcesRead MCPMethod = "resources/read" // Lists all available prompt templates. - // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/prompts/ + // https://modelcontextprotocol.io/specification/2024-11-05/server/prompts/ MethodPromptsList MCPMethod = "prompts/list" // Retrieves a specific prompt template with filled parameters. - // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/prompts/ + // https://modelcontextprotocol.io/specification/2024-11-05/server/prompts/ MethodPromptsGet MCPMethod = "prompts/get" // Lists all available executable tools. - // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/ + // https://modelcontextprotocol.io/specification/2024-11-05/server/tools/ MethodToolsList MCPMethod = "tools/list" // Invokes a specific tool with provided parameters. - // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/ + // https://modelcontextprotocol.io/specification/2024-11-05/server/tools/ MethodToolsCall MCPMethod = "tools/call" + + // Notifies when the list of available resources changes. + // https://modelcontextprotocol.io/specification/2025-03-26/server/resources#list-changed-notification + MethodNotificationResourcesListChanged = "notifications/resources/list_changed" + + MethodNotificationResourceUpdated = "notifications/resources/updated" + + // Notifies when the list of available prompt templates changes. + // https://modelcontextprotocol.io/specification/2025-03-26/server/prompts#list-changed-notification + MethodNotificationPromptsListChanged = "notifications/prompts/list_changed" + + // Notifies when the list of available tools changes. + // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/list_changed/ + MethodNotificationToolsListChanged = "notifications/tools/list_changed" ) type URITemplate struct { @@ -226,6 +240,11 @@ const ( INTERNAL_ERROR = -32603 ) +// MCP error codes +const ( + RESOURCE_NOT_FOUND = -32002 +) + /* Empty result */ // EmptyResult represents a response that indicates success but carries no data. diff --git a/server/hooks.go b/server/hooks.go index ce976a6c..30519d4c 100644 --- a/server/hooks.go +++ b/server/hooks.go @@ -11,6 +11,9 @@ import ( // OnRegisterSessionHookFunc is a hook that will be called when a new session is registered. type OnRegisterSessionHookFunc func(ctx context.Context, session ClientSession) +// OnUnregisterSessionHookFunc is a hook that will be called when a session is being unregistered. +type OnUnregisterSessionHookFunc func(ctx context.Context, session ClientSession) + // BeforeAnyHookFunc is a function that is called after the request is // parsed but before the method is called. type BeforeAnyHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, message any) @@ -33,7 +36,7 @@ type OnSuccessHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, m // } // // // Use errors.As to get specific error types -// var parseErr = &UnparseableMessageError{} +// var parseErr = &UnparsableMessageError{} // if errors.As(err, &parseErr) { // // Access specific methods/fields of the error type // log.Printf("Failed to parse message for method %s: %v", @@ -83,6 +86,7 @@ type OnAfterCallToolFunc func(ctx context.Context, id any, message *mcp.CallTool type Hooks struct { OnRegisterSession []OnRegisterSessionHookFunc + OnUnregisterSession []OnUnregisterSessionHookFunc OnBeforeAny []BeforeAnyHookFunc OnSuccess []OnSuccessHookFunc OnError []OnErrorHookFunc @@ -135,9 +139,9 @@ func (c *Hooks) AddOnSuccess(hook OnSuccessHookFunc) { // } // // // For parsing errors -// var parseErr = &UnparseableMessageError{} +// var parseErr = &UnparsableMessageError{} // if errors.As(err, &parseErr) { -// // Handle unparseable message errors +// // Handle unparsable message errors // fmt.Printf("Failed to parse %s request: %v\n", // parseErr.GetMethod(), parseErr.Unwrap()) // errChan <- parseErr @@ -191,7 +195,7 @@ func (c *Hooks) onSuccess(ctx context.Context, id any, method mcp.MCPMethod, mes // // Common error types include: // - ErrUnsupported: When a capability is not enabled -// - UnparseableMessageError: When request parsing fails +// - UnparsableMessageError: When request parsing fails // - ErrResourceNotFound: When a resource is not found // - ErrPromptNotFound: When a prompt is not found // - ErrToolNotFound: When a tool is not found @@ -216,6 +220,19 @@ func (c *Hooks) RegisterSession(ctx context.Context, session ClientSession) { hook(ctx, session) } } + +func (c *Hooks) AddOnUnregisterSession(hook OnUnregisterSessionHookFunc) { + c.OnUnregisterSession = append(c.OnUnregisterSession, hook) +} + +func (c *Hooks) UnregisterSession(ctx context.Context, session ClientSession) { + if c == nil { + return + } + for _, hook := range c.OnUnregisterSession { + hook(ctx, session) + } +} func (c *Hooks) AddBeforeInitialize(hook OnBeforeInitializeFunc) { c.OnBeforeInitialize = append(c.OnBeforeInitialize, hook) } diff --git a/server/internal/gen/hooks.go.tmpl b/server/internal/gen/hooks.go.tmpl index 4a8dcf1b..9451589d 100644 --- a/server/internal/gen/hooks.go.tmpl +++ b/server/internal/gen/hooks.go.tmpl @@ -14,6 +14,8 @@ import ( // OnRegisterSessionHookFunc is a hook that will be called when a new session is registered. type OnRegisterSessionHookFunc func(ctx context.Context, session ClientSession) +// OnUnregisterSessionHookFunc is a hook that will be called when a session is being unregistered. +type OnUnregisterSessionHookFunc func(ctx context.Context, session ClientSession) // BeforeAnyHookFunc is a function that is called after the request is // parsed but before the method is called. @@ -36,7 +38,7 @@ type OnSuccessHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, m // } // // // Use errors.As to get specific error types -// var parseErr = &UnparseableMessageError{} +// var parseErr = &UnparsableMessageError{} // if errors.As(err, &parseErr) { // // Access specific methods/fields of the error type // log.Printf("Failed to parse message for method %s: %v", @@ -63,7 +65,8 @@ type OnAfter{{.HookName}}Func func(ctx context.Context, id any, message *mcp.{{. {{end}} type Hooks struct { - OnRegisterSession []OnRegisterSessionHookFunc + OnRegisterSession []OnRegisterSessionHookFunc + OnUnregisterSession []OnUnregisterSessionHookFunc OnBeforeAny []BeforeAnyHookFunc OnSuccess []OnSuccessHookFunc OnError []OnErrorHookFunc @@ -101,9 +104,9 @@ func (c *Hooks) AddOnSuccess(hook OnSuccessHookFunc) { // } // // // For parsing errors -// var parseErr = &UnparseableMessageError{} +// var parseErr = &UnparsableMessageError{} // if errors.As(err, &parseErr) { -// // Handle unparseable message errors +// // Handle unparsable message errors // fmt.Printf("Failed to parse %s request: %v\n", // parseErr.GetMethod(), parseErr.Unwrap()) // errChan <- parseErr @@ -157,7 +160,7 @@ func (c *Hooks) onSuccess(ctx context.Context, id any, method mcp.MCPMethod, mes // // Common error types include: // - ErrUnsupported: When a capability is not enabled -// - UnparseableMessageError: When request parsing fails +// - UnparsableMessageError: When request parsing fails // - ErrResourceNotFound: When a resource is not found // - ErrPromptNotFound: When a prompt is not found // - ErrToolNotFound: When a tool is not found @@ -183,6 +186,19 @@ func (c *Hooks) RegisterSession(ctx context.Context, session ClientSession) { } } +func (c *Hooks) AddOnUnregisterSession(hook OnUnregisterSessionHookFunc) { + c.OnUnregisterSession = append(c.OnUnregisterSession, hook) +} + +func (c *Hooks) UnregisterSession(ctx context.Context, session ClientSession) { + if c == nil { + return + } + for _, hook := range c.OnUnregisterSession { + hook(ctx, session) + } +} + {{- range .}} func (c *Hooks) AddBefore{{.HookName}}(hook OnBefore{{.HookName}}Func) { c.OnBefore{{.HookName}} = append(c.OnBefore{{.HookName}}, hook) diff --git a/server/internal/gen/request_handler.go.tmpl b/server/internal/gen/request_handler.go.tmpl index 5c69f5fa..e78f2799 100644 --- a/server/internal/gen/request_handler.go.tmpl +++ b/server/internal/gen/request_handler.go.tmpl @@ -78,7 +78,7 @@ func (s *MCPServer) HandleMessage( err = &requestError{ id: baseMessage.ID, code: mcp.INVALID_REQUEST, - err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { s.hooks.before{{.HookName}}(ctx, baseMessage.ID, &request) diff --git a/server/request_handler.go b/server/request_handler.go index 55d2d19e..0d0e68e8 100644 --- a/server/request_handler.go +++ b/server/request_handler.go @@ -70,7 +70,7 @@ func (s *MCPServer) HandleMessage( err = &requestError{ id: baseMessage.ID, code: mcp.INVALID_REQUEST, - err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { s.hooks.beforeInitialize(ctx, baseMessage.ID, &request) @@ -89,7 +89,7 @@ func (s *MCPServer) HandleMessage( err = &requestError{ id: baseMessage.ID, code: mcp.INVALID_REQUEST, - err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { s.hooks.beforePing(ctx, baseMessage.ID, &request) @@ -114,7 +114,7 @@ func (s *MCPServer) HandleMessage( err = &requestError{ id: baseMessage.ID, code: mcp.INVALID_REQUEST, - err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { s.hooks.beforeListResources(ctx, baseMessage.ID, &request) @@ -139,7 +139,7 @@ func (s *MCPServer) HandleMessage( err = &requestError{ id: baseMessage.ID, code: mcp.INVALID_REQUEST, - err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { s.hooks.beforeListResourceTemplates(ctx, baseMessage.ID, &request) @@ -164,7 +164,7 @@ func (s *MCPServer) HandleMessage( err = &requestError{ id: baseMessage.ID, code: mcp.INVALID_REQUEST, - err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { s.hooks.beforeReadResource(ctx, baseMessage.ID, &request) @@ -189,7 +189,7 @@ func (s *MCPServer) HandleMessage( err = &requestError{ id: baseMessage.ID, code: mcp.INVALID_REQUEST, - err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { s.hooks.beforeListPrompts(ctx, baseMessage.ID, &request) @@ -214,7 +214,7 @@ func (s *MCPServer) HandleMessage( err = &requestError{ id: baseMessage.ID, code: mcp.INVALID_REQUEST, - err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { s.hooks.beforeGetPrompt(ctx, baseMessage.ID, &request) @@ -239,7 +239,7 @@ func (s *MCPServer) HandleMessage( err = &requestError{ id: baseMessage.ID, code: mcp.INVALID_REQUEST, - err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { s.hooks.beforeListTools(ctx, baseMessage.ID, &request) @@ -264,7 +264,7 @@ func (s *MCPServer) HandleMessage( err = &requestError{ id: baseMessage.ID, code: mcp.INVALID_REQUEST, - err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { s.hooks.beforeCallTool(ctx, baseMessage.ID, &request) diff --git a/server/server.go b/server/server.go index 8ebd40bd..430f8d53 100644 --- a/server/server.go +++ b/server/server.go @@ -73,27 +73,27 @@ func ClientSessionFromContext(ctx context.Context) ClientSession { return nil } -// UnparseableMessageError is attached to the RequestError when json.Unmarshal +// UnparsableMessageError is attached to the RequestError when json.Unmarshal // fails on the request. -type UnparseableMessageError struct { +type UnparsableMessageError struct { message json.RawMessage method mcp.MCPMethod err error } -func (e *UnparseableMessageError) Error() string { - return fmt.Sprintf("unparseable %s request: %s", e.method, e.err) +func (e *UnparsableMessageError) Error() string { + return fmt.Sprintf("unparsable %s request: %s", e.method, e.err) } -func (e *UnparseableMessageError) Unwrap() error { +func (e *UnparsableMessageError) Unwrap() error { return e.err } -func (e *UnparseableMessageError) GetMessage() json.RawMessage { +func (e *UnparsableMessageError) GetMessage() json.RawMessage { return e.message } -func (e *UnparseableMessageError) GetMethod() mcp.MCPMethod { +func (e *UnparsableMessageError) GetMethod() mcp.MCPMethod { return e.method } @@ -206,13 +206,15 @@ func (s *MCPServer) RegisterSession( // UnregisterSession removes from storage session that is shut down. func (s *MCPServer) UnregisterSession( + ctx context.Context, sessionID string, ) { - s.sessions.Delete(sessionID) + session, _ := s.sessions.LoadAndDelete(sessionID) + s.hooks.UnregisterSession(ctx, session.(ClientSession)) } -// sendNotificationToAllClients sends a notification to all the currently active clients. -func (s *MCPServer) sendNotificationToAllClients( +// SendNotificationToAllClients sends a notification to all the currently active clients. +func (s *MCPServer) SendNotificationToAllClients( method string, params map[string]any, ) { @@ -417,6 +419,12 @@ func (s *MCPServer) AddResource( resource: resource, handler: handler, } + + // When the list of available resources changes, servers that declared the listChanged capability SHOULD send a notification + if s.capabilities.resources.listChanged { + // Send notification to all initialized sessions + s.SendNotificationToAllClients(mcp.MethodNotificationResourcesListChanged, nil) + } } // RemoveResource removes a resource from the server @@ -427,7 +435,7 @@ func (s *MCPServer) RemoveResource(uri string) { // Send notification to all initialized sessions if listChanged capability is enabled if s.capabilities.resources != nil && s.capabilities.resources.listChanged { - s.sendNotificationToAllClients("resources/list_changed", nil) + s.SendNotificationToAllClients("resources/list_changed", nil) } } @@ -448,6 +456,12 @@ func (s *MCPServer) AddResourceTemplate( template: template, handler: handler, } + + // When the list of available resources changes, servers that declared the listChanged capability SHOULD send a notification + if s.capabilities.resources.listChanged { + // Send notification to all initialized sessions + s.SendNotificationToAllClients(mcp.MethodNotificationResourcesListChanged, nil) + } } // AddPrompt registers a new prompt handler with the given name @@ -462,6 +476,12 @@ func (s *MCPServer) AddPrompt(prompt mcp.Prompt, handler PromptHandlerFunc) { defer s.promptsMu.Unlock() s.prompts[prompt.Name] = prompt s.promptHandlers[prompt.Name] = handler + + // When the list of available resources changes, servers that declared the listChanged capability SHOULD send a notification. + if s.capabilities.prompts.listChanged { + // Send notification to all initialized sessions + s.SendNotificationToAllClients(mcp.MethodNotificationPromptsListChanged, nil) + } } // AddTool registers a new tool and its handler @@ -483,14 +503,17 @@ func (s *MCPServer) AddTools(tools ...ServerTool) { } s.toolsMu.Unlock() - // Send notification to all initialized sessions - s.sendNotificationToAllClients("notifications/tools/list_changed", nil) + // When the list of available tools changes, servers that declared the listChanged capability SHOULD send a notification. + if s.capabilities.tools.listChanged { + // Send notification to all initialized sessions + s.SendNotificationToAllClients(mcp.MethodNotificationToolsListChanged, nil) + } } // SetTools replaces all existing tools with the provided list func (s *MCPServer) SetTools(tools ...ServerTool) { s.toolsMu.Lock() - s.tools = make(map[string]ServerTool) + s.tools = make(map[string]ServerTool, len(tools)) s.toolsMu.Unlock() s.AddTools(tools...) } @@ -503,8 +526,11 @@ func (s *MCPServer) DeleteTools(names ...string) { } s.toolsMu.Unlock() - // Send notification to all initialized sessions - s.sendNotificationToAllClients("notifications/tools/list_changed", nil) + // When the list of available tools changes, servers that declared the listChanged capability SHOULD send a notification. + if s.capabilities.tools.listChanged { + // Send notification to all initialized sessions + s.SendNotificationToAllClients(mcp.MethodNotificationToolsListChanged, nil) + } } // AddNotificationHandler registers a new handler for incoming notifications @@ -712,7 +738,7 @@ func (s *MCPServer) handleReadResource( matched = true matchedVars := template.URITemplate.Match(request.Params.URI) // Convert matched variables to a map - request.Params.Arguments = make(map[string]interface{}) + request.Params.Arguments = make(map[string]interface{}, len(matchedVars)) for name, value := range matchedVars { request.Params.Arguments[name] = value.V } @@ -735,7 +761,7 @@ func (s *MCPServer) handleReadResource( return nil, &requestError{ id: id, - code: mcp.INVALID_PARAMS, + code: mcp.RESOURCE_NOT_FOUND, err: fmt.Errorf("handler not found for resource URI '%s': %w", request.Params.URI, ErrResourceNotFound), } } diff --git a/server/server_test.go b/server/server_test.go index e55008f1..641a3c88 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -199,7 +199,7 @@ func TestMCPServer_Tools(t *testing.T) { }, expectedNotifications: 1, validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, toolsList mcp.JSONRPCMessage) { - assert.Equal(t, "notifications/tools/list_changed", notifications[0].Method) + assert.Equal(t, mcp.MethodNotificationToolsListChanged, notifications[0].Method) tools := toolsList.(mcp.JSONRPCResponse).Result.(mcp.ListToolsResult).Tools assert.Len(t, tools, 2) assert.Equal(t, "test-tool-1", tools[0].Name) @@ -241,7 +241,7 @@ func TestMCPServer_Tools(t *testing.T) { expectedNotifications: 5, validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, toolsList mcp.JSONRPCMessage) { for _, notification := range notifications { - assert.Equal(t, "notifications/tools/list_changed", notification.Method) + assert.Equal(t, mcp.MethodNotificationToolsListChanged, notification.Method) } tools := toolsList.(mcp.JSONRPCResponse).Result.(mcp.ListToolsResult).Tools assert.Len(t, tools, 2) @@ -269,8 +269,8 @@ func TestMCPServer_Tools(t *testing.T) { }, expectedNotifications: 2, validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, toolsList mcp.JSONRPCMessage) { - assert.Equal(t, "notifications/tools/list_changed", notifications[0].Method) - assert.Equal(t, "notifications/tools/list_changed", notifications[1].Method) + assert.Equal(t, mcp.MethodNotificationToolsListChanged, notifications[0].Method) + assert.Equal(t, mcp.MethodNotificationToolsListChanged, notifications[1].Method) tools := toolsList.(mcp.JSONRPCResponse).Result.(mcp.ListToolsResult).Tools assert.Len(t, tools, 2) assert.Equal(t, "test-tool-1", tools[0].Name) @@ -294,9 +294,9 @@ func TestMCPServer_Tools(t *testing.T) { expectedNotifications: 2, validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, toolsList mcp.JSONRPCMessage) { // One for SetTools - assert.Equal(t, "notifications/tools/list_changed", notifications[0].Method) + assert.Equal(t, mcp.MethodNotificationToolsListChanged, notifications[0].Method) // One for DeleteTools - assert.Equal(t, "notifications/tools/list_changed", notifications[1].Method) + assert.Equal(t, mcp.MethodNotificationToolsListChanged, notifications[1].Method) // Expect a successful response with an empty list of tools resp, ok := toolsList.(mcp.JSONRPCResponse) @@ -312,7 +312,7 @@ func TestMCPServer_Tools(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := context.Background() - server := NewMCPServer("test-server", "1.0.0") + server := NewMCPServer("test-server", "1.0.0", WithToolCapabilities(true)) _ = server.HandleMessage(ctx, []byte(`{ "jsonrpc": "2.0", "id": 1, @@ -340,7 +340,6 @@ func TestMCPServer_Tools(t *testing.T) { }`)) tt.validate(t, notifications, toolsList.(mcp.JSONRPCMessage)) }) - } } @@ -573,6 +572,75 @@ func TestMCPServer_SendNotificationToClient(t *testing.T) { } } +func TestMCPServer_SendNotificationToAllClients(t *testing.T) { + + contextPrepare := func(ctx context.Context, srv *MCPServer) context.Context { + // Create 5 active sessions + for i := 0; i < 5; i++ { + err := srv.RegisterSession(ctx, &fakeSession{ + sessionID: fmt.Sprintf("test%d", i), + notificationChannel: make(chan mcp.JSONRPCNotification, 10), + initialized: true, + }) + require.NoError(t, err) + } + return ctx + } + + validate := func(t *testing.T, ctx context.Context, srv *MCPServer) { + // Send 10 notifications to all sessions + for i := 0; i < 10; i++ { + srv.SendNotificationToAllClients("method", map[string]any{ + "count": i, + }) + } + + // Verify each session received all 10 notifications + srv.sessions.Range(func(k, v any) bool { + session := v.(ClientSession) + fakeSess := session.(*fakeSession) + notificationCount := 0 + + // Read all notifications from the channel + for notificationCount < 10 { + select { + case notification := <-fakeSess.notificationChannel: + // Verify notification method + assert.Equal(t, "method", notification.Method) + // Verify count parameter + count, ok := notification.Params.AdditionalFields["count"] + assert.True(t, ok, "count parameter not found") + assert.Equal(t, notificationCount, count.(int), "count should match notification count") + notificationCount++ + case <-time.After(100 * time.Millisecond): + t.Errorf("timeout waiting for notification %d for session %s", notificationCount, session.SessionID()) + return false + } + } + + // Verify no more notifications + select { + case notification := <-fakeSess.notificationChannel: + t.Errorf("unexpected notification received: %v", notification) + default: + // Channel empty as expected + } + return true + }) + } + + t.Run("all sessions", func(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0") + ctx := contextPrepare(context.Background(), server) + _ = server.HandleMessage(ctx, []byte(`{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize" + }`)) + validate(t, ctx, server) + }) +} + func TestMCPServer_PromptHandling(t *testing.T) { server := NewMCPServer("test-server", "1.0.0", WithPromptCapabilities(true), @@ -725,11 +793,11 @@ func TestMCPServer_HandleInvalidMessages(t *testing.T) { message: `{"jsonrpc": "2.0", "id": 1, "method": "initialize", "params": "invalid"}`, expectedErr: mcp.INVALID_REQUEST, validateErr: func(t *testing.T, err error) { - var unparseableErr = &UnparseableMessageError{} - var ok = errors.As(err, &unparseableErr) - assert.True(t, ok, "Error should be UnparseableMessageError") - assert.Equal(t, mcp.MethodInitialize, unparseableErr.GetMethod()) - assert.Equal(t, json.RawMessage(`{"jsonrpc": "2.0", "id": 1, "method": "initialize", "params": "invalid"}`), unparseableErr.GetMessage()) + unparsableErr := &UnparsableMessageError{} + ok := errors.As(err, &unparsableErr) + assert.True(t, ok, "Error should be UnparsableMessageError") + assert.Equal(t, mcp.MethodInitialize, unparsableErr.GetMethod()) + assert.Equal(t, json.RawMessage(`{"jsonrpc": "2.0", "id": 1, "method": "initialize", "params": "invalid"}`), unparsableErr.GetMessage()) }, }, { @@ -861,7 +929,7 @@ func TestMCPServer_HandleUndefinedHandlers(t *testing.T) { "uri": "undefined-resource" } }`, - expectedErr: mcp.INVALID_PARAMS, + expectedErr: mcp.RESOURCE_NOT_FOUND, validateCallbacks: func(t *testing.T, err error, beforeResults beforeResult) { assert.Equal(t, mcp.MethodResourcesRead, beforeResults.method) assert.True(t, errors.Is(err, ErrResourceNotFound)) @@ -1125,7 +1193,6 @@ func TestMCPServer_ResourceTemplates(t *testing.T) { assert.Equal(t, "test://something/test-resource/a/b/c", resultContent.URI) assert.Equal(t, "text/plain", resultContent.MIMEType) assert.Equal(t, "test content: something", resultContent.Text) - }) } @@ -1353,6 +1420,76 @@ func TestMCPServer_WithHooks(t *testing.T) { assert.IsType(t, afterPingData[0].res, onSuccessData[0].res, "OnSuccess result should be same type as AfterPing result") } +func TestMCPServer_SessionHooks(t *testing.T) { + var ( + registerCalled bool + unregisterCalled bool + + registeredContext context.Context + unregisteredContext context.Context + + registeredSession ClientSession + unregisteredSession ClientSession + ) + + hooks := &Hooks{} + hooks.AddOnRegisterSession(func(ctx context.Context, session ClientSession) { + registerCalled = true + registeredContext = ctx + registeredSession = session + }) + hooks.AddOnUnregisterSession(func(ctx context.Context, session ClientSession) { + unregisterCalled = true + unregisteredContext = ctx + unregisteredSession = session + }) + + server := NewMCPServer( + "test-server", + "1.0.0", + WithHooks(hooks), + ) + + testSession := &fakeSession{ + sessionID: "test-session-id", + notificationChannel: make(chan mcp.JSONRPCNotification, 5), + initialized: false, + } + + ctx := context.WithoutCancel(context.Background()) + err := server.RegisterSession(ctx, testSession) + require.NoError(t, err) + + assert.True(t, registerCalled, "Register session hook was not called") + assert.Equal(t, testSession.SessionID(), registeredSession.SessionID(), + "Register hook received wrong session") + + server.UnregisterSession(ctx, testSession.SessionID()) + + assert.True(t, unregisterCalled, "Unregister session hook was not called") + assert.Equal(t, testSession.SessionID(), unregisteredSession.SessionID(), + "Unregister hook received wrong session") + + assert.Equal(t, ctx, unregisteredContext, "Unregister hook received wrong context") + assert.Equal(t, ctx, registeredContext, "Register hook received wrong context") +} + +func TestMCPServer_SessionHooks_NilHooks(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0") + + testSession := &fakeSession{ + sessionID: "test-session-id", + notificationChannel: make(chan mcp.JSONRPCNotification, 5), + initialized: false, + } + + ctx := context.WithoutCancel(context.Background()) + err := server.RegisterSession(ctx, testSession) + require.NoError(t, err) + + server.UnregisterSession(ctx, testSession.SessionID()) +} + func TestMCPServer_WithRecover(t *testing.T) { panicToolHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { panic("test panic") diff --git a/server/sse.go b/server/sse.go index b6ae2144..9a419150 100644 --- a/server/sse.go +++ b/server/sse.go @@ -179,10 +179,7 @@ func NewSSEServer(server *MCPServer, opts ...SSEOption) *SSEServer { // NewTestServer creates a test server for testing purposes func NewTestServer(server *MCPServer, opts ...SSEOption) *httptest.Server { - sseServer := NewSSEServer(server) - for _, opt := range opts { - opt(sseServer) - } + sseServer := NewSSEServer(server, opts...) testServer := httptest.NewServer(sseServer) sseServer.baseURL = testServer.URL @@ -259,7 +256,7 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) { http.Error(w, fmt.Sprintf("Session registration failed: %v", err), http.StatusInternalServerError) return } - defer s.server.UnregisterSession(sessionID) + defer s.server.UnregisterSession(r.Context(), sessionID) // Start notification handler for this session go func() { @@ -324,6 +321,8 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) { case <-r.Context().Done(): close(session.done) return + case <-session.done: + return } } } @@ -438,6 +437,7 @@ func (s *SSEServer) SendEventToSession( return fmt.Errorf("event queue full") } } + func (s *SSEServer) GetUrlPath(input string) (string, error) { parse, err := url.Parse(input) if err != nil { @@ -449,6 +449,7 @@ func (s *SSEServer) GetUrlPath(input string) (string, error) { func (s *SSEServer) CompleteSseEndpoint() string { return s.baseURL + s.basePath + s.sseEndpoint } + func (s *SSEServer) CompleteSsePath() string { path, err := s.GetUrlPath(s.CompleteSseEndpoint()) if err != nil { @@ -460,6 +461,7 @@ func (s *SSEServer) CompleteSsePath() string { func (s *SSEServer) CompleteMessageEndpoint() string { return s.baseURL + s.basePath + s.messageEndpoint } + func (s *SSEServer) CompleteMessagePath() string { path, err := s.GetUrlPath(s.CompleteMessageEndpoint()) if err != nil { diff --git a/server/stdio.go b/server/stdio.go index 43d9570c..0de9f347 100644 --- a/server/stdio.go +++ b/server/stdio.go @@ -204,7 +204,7 @@ func (s *StdioServer) Listen( if err := s.server.RegisterSession(ctx, &stdioSessionInstance); err != nil { return fmt.Errorf("register session: %w", err) } - defer s.server.UnregisterSession(stdioSessionInstance.SessionID()) + defer s.server.UnregisterSession(ctx, stdioSessionInstance.SessionID()) ctx = s.server.WithContext(ctx, &stdioSessionInstance) // Add in any custom context.