diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 67086433c..523417b3b 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -12,6 +12,7 @@ jobs: uses: actions/checkout@v3 - name: Create GitHub Release + id: create_release uses: actions/create-release@v1 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} @@ -20,3 +21,38 @@ jobs: release_name: Release ${{ github.ref }} draft: false prerelease: false + + - name: Send Discord Notification + if: success() + env: + DISCORD_WEBHOOK: ${{ secrets.RELEASES_WEBHOOK }} + TAG_NAME: ${{ github.ref_name }} + RELEASE_URL: ${{ steps.create_release.outputs.html_url }} + run: | + curl -H "Content-Type: application/json" \ + -X POST \ + -d "{ + \"embeds\": [{ + \"title\": \"🚀 New Release: $TAG_NAME\", + \"description\": \"A new version of mcp-go has been released!\", + \"color\": 5814783, + \"fields\": [ + { + \"name\": \"Version\", + \"value\": \"$TAG_NAME\", + \"inline\": true + }, + { + \"name\": \"Repository\", + \"value\": \"[mcp-go](https://github.com/${{ github.repository }})\", + \"inline\": true + } + ], + \"footer\": { + \"text\": \"Released via GitHub Actions\" + }, + \"timestamp\": \"$(date -u +%Y-%m-%dT%H:%M:%S.000Z)\", + \"url\": \"$RELEASE_URL\" + }] + }" \ + $DISCORD_WEBHOOK diff --git a/examples/structured_output/README.md b/examples/structured_input_and_output/README.md similarity index 79% rename from examples/structured_output/README.md rename to examples/structured_input_and_output/README.md index e2de01fcf..76de8e1df 100644 --- a/examples/structured_output/README.md +++ b/examples/structured_input_and_output/README.md @@ -6,6 +6,15 @@ Defined in the MCP spec here: https://modelcontextprotocol.io/specification/2025 ## Usage +Define a struct for your input: + +```go +type WeatherRequest struct { + Location string `json:"location,required" jsonschema_description:"City or location"` + Units string `json:"units,omitempty" jsonschema_description:"celsius or fahrenheit" jsonschema:"enum=celsius,enum=fahrenheit"` +} +``` + Define a struct for your output: ```go @@ -21,8 +30,8 @@ Add it to your tool: ```go tool := mcp.NewTool("get_weather", mcp.WithDescription("Get weather information"), + mcp.WithInputSchema[WeatherRequest](), mcp.WithOutputSchema[WeatherResponse](), - mcp.WithString("location", mcp.Required()), ) ``` diff --git a/examples/structured_output/main.go b/examples/structured_input_and_output/main.go similarity index 90% rename from examples/structured_output/main.go rename to examples/structured_input_and_output/main.go index e7df04021..f932def08 100644 --- a/examples/structured_output/main.go +++ b/examples/structured_input_and_output/main.go @@ -12,8 +12,8 @@ import ( // Note: The jsonschema_description tag is added to the JSON schema as description // Ideally use better descriptions, this is just an example type WeatherRequest struct { - Location string `json:"location" jsonschema_description:"City or location"` - Units string `json:"units,omitempty" jsonschema_description:"celsius or fahrenheit"` + Location string `json:"location" jsonschema_description:"City or location" jsonschema:"required"` + Units string `json:"units,omitempty" jsonschema_description:"celsius or fahrenheit" jsonschema:"enum=celsius,enum=fahrenheit"` } type WeatherResponse struct { @@ -32,7 +32,7 @@ type UserProfile struct { } type UserRequest struct { - UserID string `json:"userId" jsonschema_description:"User ID"` + UserID string `json:"userId" jsonschema_description:"User ID" jsonschema:"required"` } type Asset struct { @@ -43,12 +43,12 @@ type Asset struct { } type AssetListRequest struct { - Limit int `json:"limit,omitempty" jsonschema_description:"Number of assets to return"` + Limit int `json:"limit,omitempty" jsonschema_description:"Number of assets to return" jsonschema:"minimum=1,maximum=100,default=10"` } func main() { s := server.NewMCPServer( - "Structured Output Example", + "Structured Input/Output Example", "1.0.0", server.WithToolCapabilities(false), ) @@ -56,33 +56,32 @@ func main() { // Example 1: Auto-generated schema from struct weatherTool := mcp.NewTool("get_weather", mcp.WithDescription("Get weather with structured output"), + mcp.WithInputSchema[WeatherRequest](), mcp.WithOutputSchema[WeatherResponse](), - mcp.WithString("location", mcp.Required()), - mcp.WithString("units", mcp.Enum("celsius", "fahrenheit"), mcp.DefaultString("celsius")), ) s.AddTool(weatherTool, mcp.NewStructuredToolHandler(getWeatherHandler)) // Example 2: Nested struct schema userTool := mcp.NewTool("get_user_profile", mcp.WithDescription("Get user profile"), + mcp.WithInputSchema[UserRequest](), mcp.WithOutputSchema[UserProfile](), - mcp.WithString("userId", mcp.Required()), ) s.AddTool(userTool, mcp.NewStructuredToolHandler(getUserProfileHandler)) // Example 3: Array output - direct array of objects assetsTool := mcp.NewTool("get_assets", mcp.WithDescription("Get list of assets as array"), + mcp.WithInputSchema[AssetListRequest](), mcp.WithOutputSchema[[]Asset](), - mcp.WithNumber("limit", mcp.Min(1), mcp.Max(100), mcp.DefaultNumber(10)), ) s.AddTool(assetsTool, mcp.NewStructuredToolHandler(getAssetsHandler)) // Example 4: Manual result creation manualTool := mcp.NewTool("manual_structured", mcp.WithDescription("Manual structured result"), + mcp.WithInputSchema[WeatherRequest](), mcp.WithOutputSchema[WeatherResponse](), - mcp.WithString("location", mcp.Required()), ) s.AddTool(manualTool, mcp.NewTypedToolHandler(manualWeatherHandler)) diff --git a/mcp/consts.go b/mcp/consts.go new file mode 100644 index 000000000..66eb3803b --- /dev/null +++ b/mcp/consts.go @@ -0,0 +1,9 @@ +package mcp + +const ( + ContentTypeText = "text" + ContentTypeImage = "image" + ContentTypeAudio = "audio" + ContentTypeLink = "resource_link" + ContentTypeResource = "resource" +) diff --git a/mcp/tools.go b/mcp/tools.go index 500503e2a..3f3674923 100644 --- a/mcp/tools.go +++ b/mcp/tools.go @@ -486,6 +486,11 @@ func (r CallToolResult) MarshalJSON() ([]byte, error) { } m["content"] = content + // Marshal StructuredContent if present + if r.StructuredContent != nil { + m["structuredContent"] = r.StructuredContent + } + // Marshal IsError if true if r.IsError { m["isError"] = r.IsError @@ -526,6 +531,11 @@ func (r *CallToolResult) UnmarshalJSON(data []byte) error { } } + // Unmarshal StructuredContent if present + if structured, ok := raw["structuredContent"]; ok { + r.StructuredContent = structured + } + // Unmarshal IsError if isError, ok := raw["isError"]; ok { if isErrorBool, ok := isError.(bool); ok { @@ -704,6 +714,47 @@ func WithDescription(description string) ToolOption { } } +// WithInputSchema creates a ToolOption that sets the input schema for a tool. +// It accepts any Go type, usually a struct, and automatically generates a JSON schema from it. +func WithInputSchema[T any]() ToolOption { + return func(t *Tool) { + var zero T + + // Generate schema using invopop/jsonschema library + // Configure reflector to generate clean, MCP-compatible schemas + reflector := jsonschema.Reflector{ + DoNotReference: true, // Removes $defs map, outputs entire structure inline + Anonymous: true, // Hides auto-generated Schema IDs + AllowAdditionalProperties: true, // Removes additionalProperties: false + } + schema := reflector.Reflect(zero) + + // Clean up schema for MCP compliance + schema.Version = "" // Remove $schema field + + // Convert to raw JSON for MCP + mcpSchema, err := json.Marshal(schema) + if err != nil { + // Skip and maintain backward compatibility + return + } + + t.InputSchema.Type = "" + t.RawInputSchema = json.RawMessage(mcpSchema) + } +} + +// WithRawInputSchema sets a raw JSON schema for the tool's input. +// Use this when you need full control over the schema or when working with +// complex schemas that can't be generated from Go types. The jsonschema library +// can handle complex schemas and provides nice extension points, so be sure to +// check that out before using this. +func WithRawInputSchema(schema json.RawMessage) ToolOption { + return func(t *Tool) { + t.RawInputSchema = schema + } +} + // WithOutputSchema creates a ToolOption that sets the output schema for a tool. // It accepts any Go type, usually a struct, and automatically generates a JSON schema from it. func WithOutputSchema[T any]() ToolOption { diff --git a/mcp/tools_test.go b/mcp/tools_test.go index 7beec31dd..13c0f5643 100644 --- a/mcp/tools_test.go +++ b/mcp/tools_test.go @@ -306,7 +306,6 @@ func TestParseToolCallToolRequest(t *testing.T) { param15 := ParseInt64(request, "string_value", 1) assert.Equal(t, fmt.Sprintf("%T", param15), "int64") t.Logf("param15 type: %T,value:%v", param15, param15) - } func TestCallToolRequestBindArguments(t *testing.T) { @@ -529,6 +528,55 @@ func TestFlexibleArgumentsJSONMarshalUnmarshal(t *testing.T) { assert.Equal(t, float64(123), args["key2"]) // JSON numbers are unmarshaled as float64 } +// TestToolWithInputSchema tests that the WithInputSchema function +// generates an MCP-compatible JSON output schema for a tool +func TestToolWithInputSchema(t *testing.T) { + type TestInput struct { + Name string `json:"name" jsonschema_description:"Person's name" jsonschema:"required"` + Age int `json:"age" jsonschema_description:"Person's age"` + Email string `json:"email,omitempty" jsonschema_description:"Email address" jsonschema:"required"` + } + + tool := NewTool("test_tool", + WithDescription("Test tool with output schema"), + WithInputSchema[TestInput](), + ) + + // Check that RawOutputSchema was set + assert.NotNil(t, tool.RawInputSchema) + + // Marshal and verify structure + data, err := json.Marshal(tool) + assert.NoError(t, err) + + var toolData map[string]any + err = json.Unmarshal(data, &toolData) + assert.NoError(t, err) + + // Verify inputSchema exists + inputSchema, exists := toolData["inputSchema"] + assert.True(t, exists) + assert.NotNil(t, inputSchema) + + // Verify required list exists + schemaMap, ok := inputSchema.(map[string]interface{}) + assert.True(t, ok) + requiredList, exists := schemaMap["required"] + assert.True(t, exists) + assert.NotNil(t, requiredList) + + // Verify properties exist + properties, exists := schemaMap["properties"] + assert.True(t, exists) + propertiesMap, ok := properties.(map[string]interface{}) + assert.True(t, ok) + + // Verify specific properties + assert.Contains(t, propertiesMap, "name") + assert.Contains(t, propertiesMap, "age") + assert.Contains(t, propertiesMap, "email") +} + // TestToolWithOutputSchema tests that the WithOutputSchema function // generates an MCP-compatible JSON output schema for a tool func TestToolWithOutputSchema(t *testing.T) { @@ -580,6 +628,567 @@ func TestNewToolResultStructured(t *testing.T) { assert.NotNil(t, result.StructuredContent) } +// TestCallToolResultMarshalJSON tests the custom JSON marshaling of CallToolResult +func TestCallToolResultMarshalJSON(t *testing.T) { + tests := []struct { + name string + result CallToolResult + expected map[string]any + }{ + { + name: "basic result with text content", + result: CallToolResult{ + Result: Result{ + Meta: NewMetaFromMap(map[string]any{"key": "value"}), + }, + Content: []Content{ + TextContent{Type: "text", Text: "Hello, world!"}, + }, + IsError: false, + }, + expected: map[string]any{ + "_meta": map[string]any{"key": "value"}, + "content": []any{ + map[string]any{ + "type": "text", + "text": "Hello, world!", + }, + }, + }, + }, + { + name: "result with structured content", + result: CallToolResult{ + Result: Result{ + Meta: NewMetaFromMap(map[string]any{"key": "value"}), + }, + Content: []Content{ + TextContent{Type: "text", Text: "Operation completed"}, + }, + StructuredContent: map[string]any{ + "status": "success", + "count": 42, + "message": "Data processed successfully", + }, + IsError: false, + }, + expected: map[string]any{ + "_meta": map[string]any{"key": "value"}, + "content": []any{ + map[string]any{ + "type": "text", + "text": "Operation completed", + }, + }, + "structuredContent": map[string]any{ + "status": "success", + "count": float64(42), // JSON numbers are unmarshaled as float64 + "message": "Data processed successfully", + }, + }, + }, + { + name: "error result", + result: CallToolResult{ + Result: Result{ + Meta: NewMetaFromMap(map[string]any{"error_code": "E001"}), + }, + Content: []Content{ + TextContent{Type: "text", Text: "An error occurred"}, + }, + IsError: true, + }, + expected: map[string]any{ + "_meta": map[string]any{"error_code": "E001"}, + "content": []any{ + map[string]any{ + "type": "text", + "text": "An error occurred", + }, + }, + "isError": true, + }, + }, + { + name: "result with multiple content types", + result: CallToolResult{ + Result: Result{ + Meta: NewMetaFromMap(map[string]any{"session_id": "12345"}), + }, + Content: []Content{ + TextContent{Type: "text", Text: "Processing complete"}, + ImageContent{Type: "image", Data: "base64-encoded-image-data", MIMEType: "image/jpeg"}, + }, + StructuredContent: map[string]any{ + "processed_items": 100, + "errors": 0, + }, + IsError: false, + }, + expected: map[string]any{ + "_meta": map[string]any{"session_id": "12345"}, + "content": []any{ + map[string]any{ + "type": "text", + "text": "Processing complete", + }, + map[string]any{ + "type": "image", + "data": "base64-encoded-image-data", + "mimeType": "image/jpeg", + }, + }, + "structuredContent": map[string]any{ + "processed_items": float64(100), + "errors": float64(0), + }, + }, + }, + { + name: "result with nil structured content", + result: CallToolResult{ + Result: Result{ + Meta: NewMetaFromMap(map[string]any{"key": "value"}), + }, + Content: []Content{ + TextContent{Type: "text", Text: "Simple result"}, + }, + StructuredContent: nil, + IsError: false, + }, + expected: map[string]any{ + "_meta": map[string]any{"key": "value"}, + "content": []any{ + map[string]any{ + "type": "text", + "text": "Simple result", + }, + }, + }, + }, + { + name: "result with empty content array", + result: CallToolResult{ + Result: Result{ + Meta: NewMetaFromMap(map[string]any{"key": "value"}), + }, + Content: []Content{}, + StructuredContent: map[string]any{ + "data": "structured only", + }, + IsError: false, + }, + expected: map[string]any{ + "_meta": map[string]any{"key": "value"}, + "content": []any{}, + "structuredContent": map[string]any{ + "data": "structured only", + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Marshal the result + data, err := json.Marshal(tt.result) + assert.NoError(t, err) + + // Unmarshal to map for comparison + var result map[string]any + err = json.Unmarshal(data, &result) + assert.NoError(t, err) + + // Compare expected fields + for key, expectedValue := range tt.expected { + assert.Contains(t, result, key, "Result should contain key: %s", key) + assert.Equal(t, expectedValue, result[key], "Value for key %s should match", key) + } + + // Verify that unexpected fields are not present + for key := range result { + if key != "_meta" && key != "content" && key != "structuredContent" && key != "isError" { + t.Errorf("Unexpected field in result: %s", key) + } + } + }) + } +} + +// TestCallToolResultUnmarshalJSON tests the custom JSON unmarshaling of CallToolResult +func TestCallToolResultUnmarshalJSON(t *testing.T) { + tests := []struct { + name string + jsonData string + expected CallToolResult + wantErr bool + }{ + { + name: "basic result with text content", + jsonData: `{ + "_meta": {"key": "value"}, + "content": [ + {"type": "text", "text": "Hello, world!"} + ] + }`, + expected: CallToolResult{ + Result: Result{ + Meta: NewMetaFromMap(map[string]any{"key": "value"}), + }, + Content: []Content{ + TextContent{Type: "text", Text: "Hello, world!"}, + }, + IsError: false, + }, + wantErr: false, + }, + { + name: "result with structured content", + jsonData: `{ + "_meta": {"key": "value"}, + "content": [ + {"type": "text", "text": "Operation completed"} + ], + "structuredContent": { + "status": "success", + "count": 42, + "message": "Data processed successfully" + } + }`, + expected: CallToolResult{ + Result: Result{ + Meta: NewMetaFromMap(map[string]any{"key": "value"}), + }, + Content: []Content{ + TextContent{Type: "text", Text: "Operation completed"}, + }, + StructuredContent: map[string]any{ + "status": "success", + "count": float64(42), + "message": "Data processed successfully", + }, + IsError: false, + }, + wantErr: false, + }, + { + name: "error result", + jsonData: `{ + "_meta": {"error_code": "E001"}, + "content": [ + {"type": "text", "text": "An error occurred"} + ], + "isError": true + }`, + expected: CallToolResult{ + Result: Result{ + Meta: NewMetaFromMap(map[string]any{"error_code": "E001"}), + }, + Content: []Content{ + TextContent{Type: "text", Text: "An error occurred"}, + }, + IsError: true, + }, + wantErr: false, + }, + { + name: "result with multiple content types", + jsonData: `{ + "_meta": {"session_id": "12345"}, + "content": [ + {"type": "text", "text": "Processing complete"}, + {"type": "image", "data": "base64-encoded-image-data", "mimeType": "image/jpeg"} + ], + "structuredContent": { + "processed_items": 100, + "errors": 0 + } + }`, + expected: CallToolResult{ + Result: Result{ + Meta: NewMetaFromMap(map[string]any{"session_id": "12345"}), + }, + Content: []Content{ + TextContent{Type: "text", Text: "Processing complete"}, + ImageContent{Type: "image", Data: "base64-encoded-image-data", MIMEType: "image/jpeg"}, + }, + StructuredContent: map[string]any{ + "processed_items": float64(100), + "errors": float64(0), + }, + IsError: false, + }, + wantErr: false, + }, + { + name: "result with nil structured content", + jsonData: `{ + "_meta": {"key": "value"}, + "content": [ + {"type": "text", "text": "Simple result"} + ] + }`, + expected: CallToolResult{ + Result: Result{ + Meta: NewMetaFromMap(map[string]any{"key": "value"}), + }, + Content: []Content{ + TextContent{Type: "text", Text: "Simple result"}, + }, + StructuredContent: nil, + IsError: false, + }, + wantErr: false, + }, + { + name: "result with empty content array", + jsonData: `{ + "_meta": {"key": "value"}, + "content": [], + "structuredContent": { + "data": "structured only" + } + }`, + expected: CallToolResult{ + Result: Result{ + Meta: NewMetaFromMap(map[string]any{"key": "value"}), + }, + Content: []Content{}, + StructuredContent: map[string]any{ + "data": "structured only", + }, + IsError: false, + }, + wantErr: false, + }, + { + name: "invalid JSON", + jsonData: `{invalid json}`, + wantErr: true, + }, + { + name: "result with missing content field", + jsonData: `{ + "_meta": {"key": "value"}, + "structuredContent": {"data": "no content"} + }`, + expected: CallToolResult{ + Result: Result{ + Meta: NewMetaFromMap(map[string]any{"key": "value"}), + }, + Content: nil, + StructuredContent: map[string]any{ + "data": "no content", + }, + IsError: false, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var result CallToolResult + err := json.Unmarshal([]byte(tt.jsonData), &result) + + if tt.wantErr { + assert.Error(t, err) + return + } + + assert.NoError(t, err) + + // Compare Meta + if tt.expected.Meta != nil { + assert.Equal(t, tt.expected.Meta, result.Meta) + } + + // Compare Content + assert.Len(t, result.Content, len(tt.expected.Content)) + for i, expectedContent := range tt.expected.Content { + if i < len(result.Content) { + // Compare content types and values + switch expected := expectedContent.(type) { + case TextContent: + if actual, ok := result.Content[i].(TextContent); ok { + assert.Equal(t, expected.Text, actual.Text) + } else { + t.Errorf("Expected TextContent at index %d, got %T", i, result.Content[i]) + } + case ImageContent: + if actual, ok := result.Content[i].(ImageContent); ok { + assert.Equal(t, expected.Data, actual.Data) + assert.Equal(t, expected.MIMEType, actual.MIMEType) + } else { + t.Errorf("Expected ImageContent at index %d, got %T", i, result.Content[i]) + } + } + } + } + + // Compare StructuredContent + assert.Equal(t, tt.expected.StructuredContent, result.StructuredContent) + + // Compare IsError + assert.Equal(t, tt.expected.IsError, result.IsError) + }) + } +} + +// TestCallToolResultRoundTrip tests that marshaling and unmarshaling preserves all data +func TestCallToolResultRoundTrip(t *testing.T) { + original := CallToolResult{ + Result: Result{ + Meta: NewMetaFromMap(map[string]any{ + "session_id": "12345", + "user_id": "user123", + "timestamp": "2024-01-01T00:00:00Z", + }), + }, + Content: []Content{ + TextContent{Type: "text", Text: "Operation started"}, + ImageContent{Type: "image", Data: "base64-encoded-chart-data", MIMEType: "image/png"}, + TextContent{Type: "text", Text: "Operation completed successfully"}, + }, + StructuredContent: map[string]any{ + "status": "success", + "processed_count": float64(150.0), + "error_count": float64(0.0), + "warnings": []any{"Minor issue detected"}, + "metadata": map[string]any{ + "version": "1.0.0", + "build": "2024-01-01", + }, + }, + IsError: false, + } + + // Marshal to JSON + data, err := json.Marshal(original) + assert.NoError(t, err) + + // Unmarshal back + var unmarshaled CallToolResult + err = json.Unmarshal(data, &unmarshaled) + assert.NoError(t, err) + + // Verify all fields are preserved + assert.Equal(t, original.Meta, unmarshaled.Meta) + assert.Equal(t, original.IsError, unmarshaled.IsError) + assert.Equal(t, original.StructuredContent, unmarshaled.StructuredContent) + + // Verify content array + assert.Len(t, unmarshaled.Content, len(original.Content)) + for i, expectedContent := range original.Content { + if i < len(unmarshaled.Content) { + switch expected := expectedContent.(type) { + case TextContent: + if actual, ok := unmarshaled.Content[i].(TextContent); ok { + assert.Equal(t, expected.Text, actual.Text) + } else { + t.Errorf("Expected TextContent at index %d, got %T", i, unmarshaled.Content[i]) + } + case ImageContent: + if actual, ok := unmarshaled.Content[i].(ImageContent); ok { + assert.Equal(t, expected.Data, actual.Data) + assert.Equal(t, expected.MIMEType, actual.MIMEType) + } else { + t.Errorf("Expected ImageContent at index %d, got %T", i, unmarshaled.Content[i]) + } + } + } + } +} + +// TestCallToolResultEdgeCases tests edge cases for CallToolResult marshaling/unmarshaling +func TestCallToolResultEdgeCases(t *testing.T) { + tests := []struct { + name string + result CallToolResult + jsonData string + }{ + { + name: "result with complex structured content", + result: CallToolResult{ + Content: []Content{ + TextContent{Type: "text", Text: "Complex data returned"}, + }, + StructuredContent: map[string]any{ + "nested": map[string]any{ + "array": []any{1, 2, 3, "string", true, nil}, + "object": map[string]any{ + "deep": map[string]any{ + "value": "very deep", + }, + }, + }, + "mixed_types": []any{ + map[string]any{"type": "object"}, + "string", + 42.5, + true, + nil, + }, + }, + }, + }, + { + name: "result with empty structured content object", + result: CallToolResult{ + Content: []Content{ + TextContent{Type: "text", Text: "Empty structured content"}, + }, + StructuredContent: map[string]any{}, + }, + }, + { + name: "result with null structured content in JSON", + jsonData: `{ + "content": [{"type": "text", "text": "Null structured content"}], + "structuredContent": null + }`, + }, + { + name: "result with missing isError field", + jsonData: `{ + "content": [{"type": "text", "text": "No error field"}] + }`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var data []byte + var err error + + if tt.jsonData != "" { + // Test unmarshaling from JSON + var result CallToolResult + err = json.Unmarshal([]byte(tt.jsonData), &result) + assert.NoError(t, err) + + // Verify the result can be marshaled back + data, err = json.Marshal(result) + assert.NoError(t, err) + } else { + // Test marshaling the result + data, err = json.Marshal(tt.result) + assert.NoError(t, err) + + // Verify it can be unmarshaled back + var result CallToolResult + err = json.Unmarshal(data, &result) + assert.NoError(t, err) + } + + // Verify the JSON is valid + var jsonMap map[string]any + err = json.Unmarshal(data, &jsonMap) + assert.NoError(t, err) + }) + } +} + // TestNewItemsAPICompatibility tests that the new Items API functions // generate the same schema as the original Items() function with manual schema objects func TestNewItemsAPICompatibility(t *testing.T) { diff --git a/mcp/types.go b/mcp/types.go index 344924992..f871b7d9d 100644 --- a/mcp/types.go +++ b/mcp/types.go @@ -8,8 +8,9 @@ import ( "maps" "strconv" - "github.com/yosida95/uritemplate/v3" "net/http" + + "github.com/yosida95/uritemplate/v3" ) type MCPMethod string @@ -1146,23 +1147,23 @@ func UnmarshalContent(data []byte) (Content, error) { } switch contentType { - case "text": + case ContentTypeText: var content TextContent err := json.Unmarshal(data, &content) return content, err - case "image": + case ContentTypeImage: var content ImageContent err := json.Unmarshal(data, &content) return content, err - case "audio": + case ContentTypeAudio: var content AudioContent err := json.Unmarshal(data, &content) return content, err - case "resource_link": + case ContentTypeLink: var content ResourceLink err := json.Unmarshal(data, &content) return content, err - case "resource": + case ContentTypeResource: var content EmbeddedResource err := json.Unmarshal(data, &content) return content, err diff --git a/mcp/utils.go b/mcp/utils.go index 4d2b170b4..b8deeae9c 100644 --- a/mcp/utils.go +++ b/mcp/utils.go @@ -198,7 +198,7 @@ func NewPromptMessage(role Role, content Content) PromptMessage { // Helper function to create a new TextContent func NewTextContent(text string) TextContent { return TextContent{ - Type: "text", + Type: ContentTypeText, Text: text, } } @@ -207,7 +207,7 @@ func NewTextContent(text string) TextContent { // Helper function to create a new ImageContent func NewImageContent(data, mimeType string) ImageContent { return ImageContent{ - Type: "image", + Type: ContentTypeImage, Data: data, MIMEType: mimeType, } @@ -216,7 +216,7 @@ func NewImageContent(data, mimeType string) ImageContent { // Helper function to create a new AudioContent func NewAudioContent(data, mimeType string) AudioContent { return AudioContent{ - Type: "audio", + Type: ContentTypeAudio, Data: data, MIMEType: mimeType, } @@ -225,7 +225,7 @@ func NewAudioContent(data, mimeType string) AudioContent { // Helper function to create a new ResourceLink func NewResourceLink(uri, name, description, mimeType string) ResourceLink { return ResourceLink{ - Type: "resource_link", + Type: ContentTypeLink, URI: uri, Name: name, Description: description, @@ -236,7 +236,7 @@ func NewResourceLink(uri, name, description, mimeType string) ResourceLink { // Helper function to create a new EmbeddedResource func NewEmbeddedResource(resource ResourceContents) EmbeddedResource { return EmbeddedResource{ - Type: "resource", + Type: ContentTypeResource, Resource: resource, } } @@ -246,7 +246,7 @@ func NewToolResultText(text string) *CallToolResult { return &CallToolResult{ Content: []Content{ TextContent{ - Type: "text", + Type: ContentTypeText, Text: text, }, }, @@ -296,11 +296,11 @@ func NewToolResultImage(text, imageData, mimeType string) *CallToolResult { return &CallToolResult{ Content: []Content{ TextContent{ - Type: "text", + Type: ContentTypeText, Text: text, }, ImageContent{ - Type: "image", + Type: ContentTypeImage, Data: imageData, MIMEType: mimeType, }, @@ -313,11 +313,11 @@ func NewToolResultAudio(text, imageData, mimeType string) *CallToolResult { return &CallToolResult{ Content: []Content{ TextContent{ - Type: "text", + Type: ContentTypeText, Text: text, }, AudioContent{ - Type: "audio", + Type: ContentTypeAudio, Data: imageData, MIMEType: mimeType, }, @@ -333,11 +333,11 @@ func NewToolResultResource( return &CallToolResult{ Content: []Content{ TextContent{ - Type: "text", + Type: ContentTypeText, Text: text, }, EmbeddedResource{ - Type: "resource", + Type: ContentTypeResource, Resource: resource, }, }, @@ -350,7 +350,7 @@ func NewToolResultError(text string) *CallToolResult { return &CallToolResult{ Content: []Content{ TextContent{ - Type: "text", + Type: ContentTypeText, Text: text, }, }, @@ -368,7 +368,7 @@ func NewToolResultErrorFromErr(text string, err error) *CallToolResult { return &CallToolResult{ Content: []Content{ TextContent{ - Type: "text", + Type: ContentTypeText, Text: text, }, }, @@ -383,7 +383,7 @@ func NewToolResultErrorf(format string, a ...any) *CallToolResult { return &CallToolResult{ Content: []Content{ TextContent{ - Type: "text", + Type: ContentTypeText, Text: fmt.Sprintf(format, a...), }, }, @@ -505,11 +505,11 @@ func ParseContent(contentMap map[string]any) (Content, error) { contentType := ExtractString(contentMap, "type") switch contentType { - case "text": + case ContentTypeText: text := ExtractString(contentMap, "text") return NewTextContent(text), nil - case "image": + case ContentTypeImage: data := ExtractString(contentMap, "data") mimeType := ExtractString(contentMap, "mimeType") if data == "" || mimeType == "" { @@ -517,7 +517,7 @@ func ParseContent(contentMap map[string]any) (Content, error) { } return NewImageContent(data, mimeType), nil - case "audio": + case ContentTypeAudio: data := ExtractString(contentMap, "data") mimeType := ExtractString(contentMap, "mimeType") if data == "" || mimeType == "" { @@ -525,7 +525,7 @@ func ParseContent(contentMap map[string]any) (Content, error) { } return NewAudioContent(data, mimeType), nil - case "resource_link": + case ContentTypeLink: uri := ExtractString(contentMap, "uri") name := ExtractString(contentMap, "name") description := ExtractString(contentMap, "description") @@ -535,7 +535,7 @@ func ParseContent(contentMap map[string]any) (Content, error) { } return NewResourceLink(uri, name, description, mimeType), nil - case "resource": + case ContentTypeResource: resourceMap := ExtractMap(contentMap, "resource") if resourceMap == nil { return nil, fmt.Errorf("resource is missing") @@ -670,6 +670,12 @@ func ParseCallToolResult(rawMessage *json.RawMessage) (*CallToolResult, error) { result.Content = append(result.Content, content) } + // Handle structured content + structuredContent, ok := jsonContent["structuredContent"] + if ok { + result.StructuredContent = structuredContent + } + return &result, nil } diff --git a/mcptest/mcptest_test.go b/mcptest/mcptest_test.go index 18922cb84..3e4be38e3 100644 --- a/mcptest/mcptest_test.go +++ b/mcptest/mcptest_test.go @@ -78,6 +78,95 @@ func resultToString(result *mcp.CallToolResult) (string, error) { return b.String(), nil } +func TestServerWithToolStructuredContent(t *testing.T) { + ctx := context.Background() + + srv, err := mcptest.NewServer(t, server.ServerTool{ + Tool: mcp.NewTool("get_user", + mcp.WithDescription("Gets user information with structured data."), + mcp.WithString("user_id", mcp.Description("The user ID to look up.")), + ), + Handler: structuredContentHandler, + }) + if err != nil { + t.Fatal(err) + } + defer srv.Close() + + client := srv.Client() + + var req mcp.CallToolRequest + req.Params.Name = "get_user" + req.Params.Arguments = map[string]any{ + "user_id": "123", + } + + result, err := client.CallTool(ctx, req) + if err != nil { + t.Fatal("CallTool:", err) + } + + if result.IsError { + t.Fatalf("unexpected error result: %+v", result) + } + + if len(result.Content) != 1 { + t.Fatalf("Expected 1 content item, got %d", len(result.Content)) + } + + // Check text content (fallback) + textContent, ok := result.Content[0].(mcp.TextContent) + if !ok { + t.Fatalf("Expected content to be TextContent, got %T", result.Content[0]) + } + expectedText := "User found" + if textContent.Text != expectedText { + t.Errorf("Expected text %q, got %q", expectedText, textContent.Text) + } + + // Check structured content + if result.StructuredContent == nil { + t.Fatal("Expected StructuredContent to be present") + } + + structuredData, ok := result.StructuredContent.(map[string]any) + if !ok { + t.Fatalf("Expected StructuredContent to be map[string]any, got %T", result.StructuredContent) + } + + // Verify structured data + if structuredData["id"] != "123" { + t.Errorf("Expected id '123', got %v", structuredData["id"]) + } + if structuredData["name"] != "John Doe" { + t.Errorf("Expected name 'John Doe', got %v", structuredData["name"]) + } + if structuredData["email"] != "john@example.com" { + t.Errorf("Expected email 'john@example.com', got %v", structuredData["email"]) + } + if structuredData["active"] != true { + t.Errorf("Expected active true, got %v", structuredData["active"]) + } +} + +func structuredContentHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + userID, ok := request.GetArguments()["user_id"].(string) + if !ok { + return mcp.NewToolResultError("user_id parameter is required"), nil + } + + // Create structured data + userData := map[string]any{ + "id": userID, + "name": "John Doe", + "email": "john@example.com", + "active": true, + } + + // Use NewToolResultStructured to create result with both text fallback and structured content + return mcp.NewToolResultStructured(userData, "User found"), nil +} + func TestServerWithPrompt(t *testing.T) { ctx := context.Background() diff --git a/server/server.go b/server/server.go index 9f04e9478..366bf6611 100644 --- a/server/server.go +++ b/server/server.go @@ -365,6 +365,24 @@ func (s *MCPServer) AddResource( s.AddResources(ServerResource{Resource: resource, Handler: handler}) } +// DeleteResources removes resources from the server +func (s *MCPServer) DeleteResources(uris ...string) { + s.resourcesMu.Lock() + var exists bool + for _, uri := range uris { + if _, ok := s.resources[uri]; ok { + delete(s.resources, uri) + exists = true + } + } + s.resourcesMu.Unlock() + + // Send notification to all initialized sessions if listChanged capability is enabled and we actually remove a resource + if exists && s.capabilities.resources != nil && s.capabilities.resources.listChanged { + s.SendNotificationToAllClients(mcp.MethodNotificationResourcesListChanged, nil) + } +} + // RemoveResource removes a resource from the server func (s *MCPServer) RemoveResource(uri string) { s.resourcesMu.Lock() diff --git a/server/server_test.go b/server/server_test.go index aca99ef60..c4d6d4c00 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -963,6 +963,50 @@ func TestMCPServer_Prompts(t *testing.T) { assert.Equal(t, "test-prompt-2", prompts[1].Name) }, }, + { + name: "SetPrompts sends single notifications/prompts/list_changed with one active session", + action: func(t *testing.T, server *MCPServer, notificationChannel chan mcp.JSONRPCNotification) { + err := server.RegisterSession(context.TODO(), &fakeSession{ + sessionID: "test", + notificationChannel: notificationChannel, + initialized: true, + }) + require.NoError(t, err) + server.SetPrompts(ServerPrompt{ + Prompt: mcp.Prompt{ + Name: "test-prompt-1", + Description: "A test prompt", + Arguments: []mcp.PromptArgument{ + { + Name: "arg1", + Description: "First argument", + }, + }, + }, + Handler: nil, + }, ServerPrompt{ + Prompt: mcp.Prompt{ + Name: "test-prompt-2", + Description: "Another test prompt", + Arguments: []mcp.PromptArgument{ + { + Name: "arg2", + Description: "Second argument", + }, + }, + }, + Handler: nil, + }) + }, + expectedNotifications: 1, + validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, promptsList mcp.JSONRPCMessage) { + assert.Equal(t, mcp.MethodNotificationPromptsListChanged, notifications[0].Method) + prompts := promptsList.(mcp.JSONRPCResponse).Result.(mcp.ListPromptsResult).Prompts + assert.Len(t, prompts, 2) + assert.Equal(t, "test-prompt-1", prompts[0].Name) + assert.Equal(t, "test-prompt-2", prompts[1].Name) + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -998,6 +1042,211 @@ func TestMCPServer_Prompts(t *testing.T) { } } +func TestMCPServer_Resources(t *testing.T) { + tests := []struct { + name string + action func(*testing.T, *MCPServer, chan mcp.JSONRPCNotification) + expectedNotifications int + validate func(*testing.T, []mcp.JSONRPCNotification, mcp.JSONRPCMessage) + }{ + { + name: "DeleteResources sends single notifications/resources/list_changed", + action: func(t *testing.T, server *MCPServer, notificationChannel chan mcp.JSONRPCNotification) { + err := server.RegisterSession(context.TODO(), &fakeSession{ + sessionID: "test", + notificationChannel: notificationChannel, + initialized: true, + }) + require.NoError(t, err) + server.AddResource( + mcp.Resource{ + URI: "test://test-resource-1", + Name: "Test Resource 1", + }, + func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + return []mcp.ResourceContents{}, nil + }, + ) + server.DeleteResources("test://test-resource-1") + }, + expectedNotifications: 2, + validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, resourcesList mcp.JSONRPCMessage) { + // One for AddResource + assert.Equal(t, mcp.MethodNotificationResourcesListChanged, notifications[0].Method) + // One for DeleteResources + assert.Equal(t, mcp.MethodNotificationResourcesListChanged, notifications[1].Method) + + // Expect a successful response with an empty list of resources + resp, ok := resourcesList.(mcp.JSONRPCResponse) + assert.True(t, ok, "Expected JSONRPCResponse, got %T", resourcesList) + + result, ok := resp.Result.(mcp.ListResourcesResult) + assert.True(t, ok, "Expected ListResourcesResult, got %T", resp.Result) + + assert.Empty(t, result.Resources, "Expected empty resources list") + }, + }, + { + name: "DeleteResources removes the first resource and retains the other", + action: func(t *testing.T, server *MCPServer, notificationChannel chan mcp.JSONRPCNotification) { + err := server.RegisterSession(context.TODO(), &fakeSession{ + sessionID: "test", + notificationChannel: notificationChannel, + initialized: true, + }) + require.NoError(t, err) + server.AddResource( + mcp.Resource{ + URI: "test://test-resource-1", + Name: "Test Resource 1", + }, + func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + return []mcp.ResourceContents{}, nil + }, + ) + server.AddResource( + mcp.Resource{ + URI: "test://test-resource-2", + Name: "Test Resource 2", + }, + func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + return []mcp.ResourceContents{}, nil + }, + ) + server.DeleteResources("test://test-resource-1") + }, + expectedNotifications: 3, + validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, resourcesList mcp.JSONRPCMessage) { + // first notification expected for AddResource test-resource-1 + assert.Equal(t, mcp.MethodNotificationResourcesListChanged, notifications[0].Method) + // second notification expected for AddResource test-resource-2 + assert.Equal(t, mcp.MethodNotificationResourcesListChanged, notifications[1].Method) + // third notification expected for DeleteResources test-resource-1 + assert.Equal(t, mcp.MethodNotificationResourcesListChanged, notifications[2].Method) + + // Confirm the resource list contains only test-resource-2 + resources := resourcesList.(mcp.JSONRPCResponse).Result.(mcp.ListResourcesResult).Resources + assert.Len(t, resources, 1) + assert.Equal(t, "test://test-resource-2", resources[0].URI) + }, + }, + { + name: "DeleteResources with non-existent resources does nothing and not receives notifications from MCPServer", + action: func(t *testing.T, server *MCPServer, notificationChannel chan mcp.JSONRPCNotification) { + err := server.RegisterSession(context.TODO(), &fakeSession{ + sessionID: "test", + notificationChannel: notificationChannel, + initialized: true, + }) + require.NoError(t, err) + server.AddResource( + mcp.Resource{ + URI: "test://test-resource-1", + Name: "Test Resource 1", + }, + func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + return []mcp.ResourceContents{}, nil + }, + ) + server.AddResource( + mcp.Resource{ + URI: "test://test-resource-2", + Name: "Test Resource 2", + }, + func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + return []mcp.ResourceContents{}, nil + }, + ) + // Remove non-existing resources + server.DeleteResources("test://test-resource-3", "test://test-resource-4") + }, + expectedNotifications: 2, + validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, resourcesList mcp.JSONRPCMessage) { + // first notification expected for AddResource test-resource-1 + assert.Equal(t, mcp.MethodNotificationResourcesListChanged, notifications[0].Method) + // second notification expected for AddResource test-resource-2 + assert.Equal(t, mcp.MethodNotificationResourcesListChanged, notifications[1].Method) + + // Confirm the resource list does not change + resources := resourcesList.(mcp.JSONRPCResponse).Result.(mcp.ListResourcesResult).Resources + assert.Len(t, resources, 2) + // Resources are sorted by name + assert.Equal(t, "test://test-resource-1", resources[0].URI) + assert.Equal(t, "test://test-resource-2", resources[1].URI) + }, + }, + { + name: "SetResources sends single notifications/resources/list_changed with one active session", + action: func(t *testing.T, server *MCPServer, notificationChannel chan mcp.JSONRPCNotification) { + err := server.RegisterSession(context.TODO(), &fakeSession{ + sessionID: "test", + notificationChannel: notificationChannel, + initialized: true, + }) + require.NoError(t, err) + server.SetResources(ServerResource{ + Resource: mcp.Resource{ + URI: "test://test-resource-1", + Name: "Test Resource 1", + }, + Handler: func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + return []mcp.ResourceContents{}, nil + }, + }, ServerResource{ + Resource: mcp.Resource{ + URI: "test://test-resource-2", + Name: "Test Resource 2", + }, + Handler: func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + return []mcp.ResourceContents{}, nil + }, + }) + }, + expectedNotifications: 1, + validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, resourcesList mcp.JSONRPCMessage) { + assert.Equal(t, mcp.MethodNotificationResourcesListChanged, notifications[0].Method) + resources := resourcesList.(mcp.JSONRPCResponse).Result.(mcp.ListResourcesResult).Resources + assert.Len(t, resources, 2) + // Resources are sorted by name + assert.Equal(t, "test://test-resource-1", resources[0].URI) + assert.Equal(t, "test://test-resource-2", resources[1].URI) + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + server := NewMCPServer("test-server", "1.0.0", WithResourceCapabilities(true, true)) + _ = server.HandleMessage(ctx, []byte(`{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize" + }`)) + notificationChannel := make(chan mcp.JSONRPCNotification, 100) + notifications := make([]mcp.JSONRPCNotification, 0) + tt.action(t, server, notificationChannel) + for done := false; !done; { + select { + case serverNotification := <-notificationChannel: + notifications = append(notifications, serverNotification) + if len(notifications) == tt.expectedNotifications { + done = true + } + case <-time.After(1 * time.Second): + done = true + } + } + assert.Len(t, notifications, tt.expectedNotifications) + resourcesList := server.HandleMessage(ctx, []byte(`{ + "jsonrpc": "2.0", + "id": 1, + "method": "resources/list" + }`)) + tt.validate(t, notifications, resourcesList) + }) + } +} + func TestMCPServer_HandleInvalidMessages(t *testing.T) { var errs []error hooks := &Hooks{} diff --git a/server/stdio.go b/server/stdio.go index 4d567d8cb..8c270e18b 100644 --- a/server/stdio.go +++ b/server/stdio.go @@ -29,6 +29,20 @@ type StdioServer struct { server *MCPServer errLogger *log.Logger contextFunc StdioContextFunc + + // Thread-safe tool call processing + toolCallQueue chan *toolCallWork + workerWg sync.WaitGroup + workerPoolSize int + queueSize int + writeMu sync.Mutex // Protects concurrent writes +} + +// toolCallWork represents a queued tool call request +type toolCallWork struct { + ctx context.Context + message json.RawMessage + writer io.Writer } // StdioOption defines a function type for configuring StdioServer @@ -50,6 +64,32 @@ func WithStdioContextFunc(fn StdioContextFunc) StdioOption { } } +// WithWorkerPoolSize sets the number of workers for processing tool calls +func WithWorkerPoolSize(size int) StdioOption { + return func(s *StdioServer) { + const maxWorkerPoolSize = 100 + if size > 0 && size <= maxWorkerPoolSize { + s.workerPoolSize = size + } else if size > maxWorkerPoolSize { + s.errLogger.Printf("Worker pool size %d exceeds maximum (%d), using maximum", size, maxWorkerPoolSize) + s.workerPoolSize = maxWorkerPoolSize + } + } +} + +// WithQueueSize sets the size of the tool call queue +func WithQueueSize(size int) StdioOption { + return func(s *StdioServer) { + const maxQueueSize = 10000 + if size > 0 && size <= maxQueueSize { + s.queueSize = size + } else if size > maxQueueSize { + s.errLogger.Printf("Queue size %d exceeds maximum (%d), using maximum", size, maxQueueSize) + s.queueSize = maxQueueSize + } + } +} + // stdioSession is a static client session, since stdio has only one client. type stdioSession struct { notifications chan mcp.JSONRPCNotification @@ -218,6 +258,8 @@ func NewStdioServer(server *MCPServer) *StdioServer { "", log.LstdFlags, ), // Default to discarding logs + workerPoolSize: 5, // Default worker pool size + queueSize: 100, // Default queue size } } @@ -281,6 +323,30 @@ func (s *StdioServer) processInputStream(ctx context.Context, reader *bufio.Read } } +// toolCallWorker processes tool calls from the queue +func (s *StdioServer) toolCallWorker(ctx context.Context) { + defer s.workerWg.Done() + + for { + select { + case work, ok := <-s.toolCallQueue: + if !ok { + // Channel closed, exit worker + return + } + // Process the tool call + response := s.server.HandleMessage(work.ctx, work.message) + if response != nil { + if err := s.writeResponse(response, work.writer); err != nil { + s.errLogger.Printf("Error writing tool response: %v", err) + } + } + case <-ctx.Done(): + return + } + } +} + // readNextLine reads a single line from the input reader in a context-aware manner. // It uses channels to make the read operation cancellable via context. // Returns the read line and any error encountered. If the context is cancelled, @@ -315,6 +381,9 @@ func (s *StdioServer) Listen( stdin io.Reader, stdout io.Writer, ) error { + // Initialize the tool call queue + s.toolCallQueue = make(chan *toolCallWork, s.queueSize) + // Set a static client context since stdio only has one client if err := s.server.RegisterSession(ctx, &stdioSessionInstance); err != nil { return fmt.Errorf("register session: %w", err) @@ -332,9 +401,23 @@ func (s *StdioServer) Listen( reader := bufio.NewReader(stdin) + // Start worker pool for tool calls + for i := 0; i < s.workerPoolSize; i++ { + s.workerWg.Add(1) + go s.toolCallWorker(ctx) + } + // Start notification handler go s.handleNotifications(ctx, stdout) - return s.processInputStream(ctx, reader, stdout) + + // Process input stream + err := s.processInputStream(ctx, reader, stdout) + + // Shutdown workers gracefully + close(s.toolCallQueue) + s.workerWg.Wait() + + return err } // processMessage handles a single JSON-RPC message and writes the response. @@ -367,16 +450,25 @@ func (s *StdioServer) processMessage( Method string `json:"method"` } if json.Unmarshal(rawMessage, &baseMessage) == nil && baseMessage.Method == "tools/call" { - // Process tool calls concurrently to avoid blocking on sampling requests - go func() { + // Queue tool calls for processing by workers + select { + case s.toolCallQueue <- &toolCallWork{ + ctx: ctx, + message: rawMessage, + writer: writer, + }: + return nil + case <-ctx.Done(): + return ctx.Err() + default: + // Queue is full, process synchronously as fallback + s.errLogger.Printf("Tool call queue full, processing synchronously") response := s.server.HandleMessage(ctx, rawMessage) if response != nil { - if err := s.writeResponse(response, writer); err != nil { - s.errLogger.Printf("Error writing tool response: %v", err) - } + return s.writeResponse(response, writer) } - }() - return nil + return nil + } } // Handle other messages synchronously @@ -462,6 +554,10 @@ func (s *StdioServer) writeResponse( return err } + // Protect concurrent writes + s.writeMu.Lock() + defer s.writeMu.Unlock() + // Write response followed by newline if _, err := fmt.Fprintf(writer, "%s\n", responseBytes); err != nil { return err diff --git a/server/stdio_test.go b/server/stdio_test.go index 4ec725927..8fb542a80 100644 --- a/server/stdio_test.go +++ b/server/stdio_test.go @@ -4,10 +4,13 @@ import ( "bufio" "context" "encoding/json" + "fmt" "io" "log" "os" + "sync" "testing" + "time" "github.com/mark3labs/mcp-go/mcp" ) @@ -267,4 +270,217 @@ func TestStdioServer(t *testing.T) { t.Errorf("unexpected server error: %v", err) } }) + + t.Run("Can handle concurrent tool calls", func(t *testing.T) { + // Create pipes for stdin and stdout + stdinReader, stdinWriter := io.Pipe() + stdoutReader, stdoutWriter := io.Pipe() + + // Track tool call executions (sync.Map is already thread-safe) + var callCount sync.Map + + // Create server with test tools + mcpServer := NewMCPServer("test", "1.0.0") + + // Add multiple tools that simulate work and track concurrent execution + for i := 0; i < 5; i++ { + toolName := fmt.Sprintf("test_tool_%d", i) + mcpServer.AddTool( + mcp.NewTool(toolName), + func(name string) func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Track concurrent executions + count, _ := callCount.LoadOrStore(name, 0) + callCount.Store(name, count.(int)+1) + + // Simulate some work + time.Sleep(10 * time.Millisecond) + + return mcp.NewToolResultText(fmt.Sprintf("Result from %s", name)), nil + } + }(toolName), + ) + } + + stdioServer := NewStdioServer(mcpServer) + stdioServer.SetErrorLogger(log.New(io.Discard, "", 0)) + + // Create context with cancel + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Start server + serverErrCh := make(chan error, 1) + go func() { + err := stdioServer.Listen(ctx, stdinReader, stdoutWriter) + if err != nil && err != io.EOF && err != context.Canceled { + serverErrCh <- err + } + stdoutWriter.Close() + close(serverErrCh) + }() + + // Initialize the session + initRequest := map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": map[string]any{ + "protocolVersion": "2024-11-05", + "clientInfo": map[string]any{ + "name": "test-client", + "version": "1.0.0", + }, + }, + } + + requestBytes, _ := json.Marshal(initRequest) + if _, err := stdinWriter.Write(append(requestBytes, '\n')); err != nil { + t.Fatalf("Failed to write init request: %v", err) + } + + // Read init response + scanner := bufio.NewScanner(stdoutReader) + scanner.Scan() + + // Send multiple concurrent tool calls + var wg sync.WaitGroup + responseChan := make(chan string, 10) + + // Send 10 concurrent tool calls + for i := 0; i < 10; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + + toolRequest := map[string]any{ + "jsonrpc": "2.0", + "id": id + 2, + "method": "tools/call", + "params": map[string]any{ + "name": fmt.Sprintf("test_tool_%d", id%5), + }, + } + + requestBytes, _ := json.Marshal(toolRequest) + if _, err := stdinWriter.Write(append(requestBytes, '\n')); err != nil { + t.Errorf("Failed to write tool request %d: %v", id, err) + } + }(i) + } + + // Read all responses + go func() { + for i := 0; i < 10; i++ { + if scanner.Scan() { + responseChan <- scanner.Text() + } + } + close(responseChan) + }() + + // Wait for all requests to be sent + wg.Wait() + + // Collect responses + responses := make([]string, 0, 10) + timeout := time.After(2 * time.Second) + + collectLoop: + for len(responses) < 10 { + select { + case resp, ok := <-responseChan: + if !ok { + break collectLoop + } + responses = append(responses, resp) + case <-timeout: + t.Fatal("Timeout waiting for responses") + } + } + // Verify we got all responses + if len(responses) != 10 { + t.Errorf("Expected 10 responses, got %d", len(responses)) + } + + // Verify no errors in responses + for _, resp := range responses { + var response map[string]any + if err := json.Unmarshal([]byte(resp), &response); err != nil { + t.Errorf("Failed to unmarshal response: %v", err) + continue + } + + if response["error"] != nil { + t.Errorf("Unexpected error in response: %v", response["error"]) + } + + // Verify response has expected structure + if response["result"] == nil { + t.Error("Expected result in response") + } + } + + // Verify tools were called + callCount.Range(func(key, value interface{}) bool { + toolName := key.(string) + count := value.(int) + if count == 0 { + t.Errorf("Tool %s was not called", toolName) + } + return true + }) + + // Clean up + cancel() + stdinWriter.Close() + + // Check for server errors + if err := <-serverErrCh; err != nil { + t.Errorf("Server error: %v", err) + } + }) + + t.Run("Configuration options respect bounds", func(t *testing.T) { + mcpServer := NewMCPServer("test", "1.0.0") + + // Test worker pool size bounds + stdioServer := NewStdioServer(mcpServer) + WithWorkerPoolSize(150)(stdioServer) + if stdioServer.workerPoolSize != 100 { // Should use maximum + t.Errorf("Expected maximum worker pool size 100, got %d", stdioServer.workerPoolSize) + } + + // Test valid worker pool size + stdioServer = NewStdioServer(mcpServer) + WithWorkerPoolSize(50)(stdioServer) + if stdioServer.workerPoolSize != 50 { + t.Errorf("Expected worker pool size 50, got %d", stdioServer.workerPoolSize) + } + + // Test queue size bounds + stdioServer = NewStdioServer(mcpServer) + WithQueueSize(20000)(stdioServer) + if stdioServer.queueSize != 10000 { // Should use maximum + t.Errorf("Expected maximum queue size 10000, got %d", stdioServer.queueSize) + } + + // Test valid queue size + stdioServer = NewStdioServer(mcpServer) + WithQueueSize(500)(stdioServer) + if stdioServer.queueSize != 500 { + t.Errorf("Expected queue size 500, got %d", stdioServer.queueSize) + } + + // Test zero and negative values + stdioServer = NewStdioServer(mcpServer) + WithWorkerPoolSize(0)(stdioServer) + WithQueueSize(-10)(stdioServer) + if stdioServer.workerPoolSize != 5 { + t.Errorf("Expected default worker pool size 5 for zero input, got %d", stdioServer.workerPoolSize) + } + if stdioServer.queueSize != 100 { + t.Errorf("Expected default queue size 100 for negative input, got %d", stdioServer.queueSize) + } + }) } diff --git a/www/docs/pages/servers/tools.mdx b/www/docs/pages/servers/tools.mdx index e329bd1c1..7bd5bf75a 100644 --- a/www/docs/pages/servers/tools.mdx +++ b/www/docs/pages/servers/tools.mdx @@ -101,6 +101,310 @@ mcp.WithNumber("price", ) ``` +## Struct-Based Schema Definition + +MCP-Go supports defining input and output schemas using Go structs with automatic JSON schema generation. This provides a type-safe alternative to manual parameter definition, especially useful for complex tools with structured inputs and outputs. + +### Input Schema with Go Structs + +Define your input parameters as a Go struct and use `WithInputSchema`: + +```go +// Define input struct with JSON schema tags +type SearchRequest struct { + Query string `json:"query" jsonschema_description:"Search query" jsonschema:"required"` + Limit int `json:"limit,omitempty" jsonschema_description:"Maximum results" jsonschema:"minimum=1,maximum=100,default=10"` + Categories []string `json:"categories,omitempty" jsonschema_description:"Filter by categories"` + SortBy string `json:"sortBy,omitempty" jsonschema_description:"Sort field" jsonschema:"enum=relevance,enum=date,enum=popularity"` +} + +// Create tool with struct-based input schema +searchTool := mcp.NewTool("search_products", + mcp.WithDescription("Search product catalog"), + mcp.WithInputSchema[SearchRequest](), +) +``` + +### Output Schema with Go Structs + +Define structured output for predictable tool responses: + +```go +// Define output struct +type SearchResponse struct { + Query string `json:"query" jsonschema_description:"Original search query"` + TotalCount int `json:"totalCount" jsonschema_description:"Total matching products"` + Products []Product `json:"products" jsonschema_description:"Search results"` + ProcessedAt time.Time `json:"processedAt" jsonschema_description:"When search was performed"` +} + +type Product struct { + ID string `json:"id" jsonschema_description:"Product ID"` + Name string `json:"name" jsonschema_description:"Product name"` + Price float64 `json:"price" jsonschema_description:"Price in USD"` + InStock bool `json:"inStock" jsonschema_description:"Availability"` +} + +// Create tool with both input and output schemas +searchTool := mcp.NewTool("search_products", + mcp.WithDescription("Search product catalog with structured output"), + mcp.WithInputSchema[SearchRequest](), + mcp.WithOutputSchema[SearchResponse](), +) +``` + +### Structured Tool Handlers + +Use `NewStructuredToolHandler` for type-safe handler implementation: + +```go +func main() { + s := server.NewMCPServer("Product Search", "1.0.0", + server.WithToolCapabilities(false), + ) + + // Define tool with input and output schemas + searchTool := mcp.NewTool("search_products", + mcp.WithDescription("Search product catalog"), + mcp.WithInputSchema[SearchRequest](), + mcp.WithOutputSchema[SearchResponse](), + ) + + // Add tool with structured handler + s.AddTool(searchTool, mcp.NewStructuredToolHandler(searchProductsHandler)) + + server.ServeStdio(s) +} + +// Handler receives typed input and returns typed output +func searchProductsHandler(ctx context.Context, req mcp.CallToolRequest, args SearchRequest) (SearchResponse, error) { + // Input is already validated and bound to SearchRequest struct + limit := args.Limit + if limit <= 0 { + limit = 10 + } + + // Perform search logic + products := searchDatabase(args.Query, args.Categories, limit) + + // Return structured response + return SearchResponse{ + Query: args.Query, + TotalCount: len(products), + Products: products, + ProcessedAt: time.Now(), + }, nil +} +``` + +### Array Output Schema + +Tools can return arrays of structured data: + +```go +// Define asset struct +type Asset struct { + ID string `json:"id" jsonschema_description:"Asset identifier"` + Name string `json:"name" jsonschema_description:"Asset name"` + Value float64 `json:"value" jsonschema_description:"Current value"` + Currency string `json:"currency" jsonschema_description:"Currency code"` +} + +// Tool that returns array of assets +assetsTool := mcp.NewTool("list_assets", + mcp.WithDescription("List portfolio assets"), + mcp.WithInputSchema[struct { + Portfolio string `json:"portfolio" jsonschema_description:"Portfolio ID" jsonschema:"required"` + }](), + mcp.WithOutputSchema[[]Asset](), // Array output schema +) + +func listAssetsHandler(ctx context.Context, req mcp.CallToolRequest, args struct{ Portfolio string }) ([]Asset, error) { + // Return array of assets + return []Asset{ + {ID: "btc", Name: "Bitcoin", Value: 45000.50, Currency: "USD"}, + {ID: "eth", Name: "Ethereum", Value: 3200.75, Currency: "USD"}, + }, nil +} +``` + +### Schema Tags Reference + +MCP-Go uses the `jsonschema` struct tags for schema generation: + +```go +type ExampleStruct struct { + // Required field + Name string `json:"name" jsonschema:"required"` + + // Field with description + Age int `json:"age" jsonschema_description:"User age in years"` + + // Field with constraints + Score float64 `json:"score" jsonschema:"minimum=0,maximum=100"` + + // Enum field + Status string `json:"status" jsonschema:"enum=active,enum=inactive,enum=pending"` + + // Optional field with default + PageSize int `json:"pageSize,omitempty" jsonschema:"default=20"` + + // Array with constraints + Tags []string `json:"tags" jsonschema:"minItems=1,maxItems=10"` +} +``` + +### Manual Structured Results + +For more control over the response, use `NewTypedToolHandler` with manual result creation: + +```go +manualTool := mcp.NewTool("process_data", + mcp.WithDescription("Process data with custom result"), + mcp.WithInputSchema[ProcessRequest](), + mcp.WithOutputSchema[ProcessResponse](), +) + +s.AddTool(manualTool, mcp.NewTypedToolHandler(manualProcessHandler)) + +func manualProcessHandler(ctx context.Context, req mcp.CallToolRequest, args ProcessRequest) (*mcp.CallToolResult, error) { + // Process the data + response := ProcessResponse{ + Status: "completed", + ProcessedAt: time.Now(), + ItemCount: 42, + } + + // Create custom fallback text for backward compatibility + fallbackText := fmt.Sprintf("Processed %d items successfully", response.ItemCount) + + // Return structured result with custom text + return mcp.NewToolResultStructured(response, fallbackText), nil +} +``` + +### Complete Example: File Operations with Structured I/O + +Here's a complete example using the file operations pattern from earlier, enhanced with structured schemas: + +```go +// Define structured input for file operations +type FileOperationRequest struct { + Path string `json:"path" jsonschema_description:"File path" jsonschema:"required"` + Content string `json:"content,omitempty" jsonschema_description:"File content (for write operations)"` + Encoding string `json:"encoding,omitempty" jsonschema_description:"File encoding" jsonschema:"enum=utf-8,enum=ascii,enum=base64,default=utf-8"` +} + +// Define structured output +type FileOperationResponse struct { + Success bool `json:"success" jsonschema_description:"Operation success status"` + Path string `json:"path" jsonschema_description:"File path"` + Message string `json:"message" jsonschema_description:"Result message"` + Content string `json:"content,omitempty" jsonschema_description:"File content (for read operations)"` + Size int64 `json:"size,omitempty" jsonschema_description:"File size in bytes"` + Modified time.Time `json:"modified,omitempty" jsonschema_description:"Last modified time"` +} + +func main() { + s := server.NewMCPServer("File Manager", "1.0.0", + server.WithToolCapabilities(true), + ) + + // Create file tool with structured I/O + createFileTool := mcp.NewTool("create_file", + mcp.WithDescription("Create a new file with content"), + mcp.WithInputSchema[FileOperationRequest](), + mcp.WithOutputSchema[FileOperationResponse](), + ) + + // Read file tool + readFileTool := mcp.NewTool("read_file", + mcp.WithDescription("Read file contents"), + mcp.WithInputSchema[struct { + Path string `json:"path" jsonschema_description:"File path to read" jsonschema:"required"` + }](), + mcp.WithOutputSchema[FileOperationResponse](), + ) + + s.AddTool(createFileTool, mcp.NewStructuredToolHandler(handleCreateFile)) + s.AddTool(readFileTool, mcp.NewStructuredToolHandler(handleReadFile)) + + server.ServeStdio(s) +} + +func handleCreateFile(ctx context.Context, req mcp.CallToolRequest, args FileOperationRequest) (FileOperationResponse, error) { + // Validate path for security + if strings.Contains(args.Path, "..") { + return FileOperationResponse{ + Success: false, + Path: args.Path, + Message: "Invalid path: directory traversal not allowed", + }, nil + } + + // Handle different encodings + var data []byte + switch args.Encoding { + case "base64": + var err error + data, err = base64.StdEncoding.DecodeString(args.Content) + if err != nil { + return FileOperationResponse{ + Success: false, + Path: args.Path, + Message: fmt.Sprintf("Invalid base64 content: %v", err), + }, nil + } + default: + data = []byte(args.Content) + } + + // Create file + if err := os.WriteFile(args.Path, data, 0644); err != nil { + return FileOperationResponse{ + Success: false, + Path: args.Path, + Message: fmt.Sprintf("Failed to create file: %v", err), + }, nil + } + + // Get file info + info, _ := os.Stat(args.Path) + + return FileOperationResponse{ + Success: true, + Path: args.Path, + Message: "File created successfully", + Size: info.Size(), + Modified: info.ModTime(), + }, nil +} + +func handleReadFile(ctx context.Context, req mcp.CallToolRequest, args struct{ Path string }) (FileOperationResponse, error) { + // Read file + data, err := os.ReadFile(args.Path) + if err != nil { + return FileOperationResponse{ + Success: false, + Path: args.Path, + Message: fmt.Sprintf("Failed to read file: %v", err), + }, nil + } + + // Get file info + info, _ := os.Stat(args.Path) + + return FileOperationResponse{ + Success: true, + Path: args.Path, + Message: "File read successfully", + Content: string(data), + Size: info.Size(), + Modified: info.ModTime(), + }, nil +} +``` + ## Tool Handlers Tool handlers process the actual function calls from LLMs. MCP-Go provides convenient helper methods for safe parameter extraction.