diff --git a/cli/clitest/golden.go b/cli/clitest/golden.go index e79006ebb58e3..d4401d6c5d5f9 100644 --- a/cli/clitest/golden.go +++ b/cli/clitest/golden.go @@ -11,7 +11,9 @@ import ( "strings" "testing" + "github.com/google/go-cmp/cmp" "github.com/google/uuid" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/coder/coder/v2/cli/config" @@ -117,11 +119,7 @@ func TestGoldenFile(t *testing.T, fileName string, actual []byte, replacements m require.NoError(t, err, "read golden file, run \"make gen/golden-files\" and commit the changes") expected = normalizeGoldenFile(t, expected) - require.Equal( - t, string(expected), string(actual), - "golden file mismatch: %s, run \"make gen/golden-files\", verify and commit the changes", - goldenPath, - ) + assert.Empty(t, cmp.Diff(string(expected), string(actual)), "golden file mismatch (-want +got): %s, run \"make gen/golden-files\", verify and commit the changes", goldenPath) } // normalizeGoldenFile replaces any strings that are system or timing dependent diff --git a/cli/exp.go b/cli/exp.go index 2339da86313a6..dafd85402663e 100644 --- a/cli/exp.go +++ b/cli/exp.go @@ -13,6 +13,7 @@ func (r *RootCmd) expCmd() *serpent.Command { Children: []*serpent.Command{ r.scaletestCmd(), r.errorExample(), + r.mcpCommand(), r.promptExample(), r.rptyCommand(), }, diff --git a/cli/exp_mcp.go b/cli/exp_mcp.go new file mode 100644 index 0000000000000..a5af41d9103a6 --- /dev/null +++ b/cli/exp_mcp.go @@ -0,0 +1,284 @@ +package cli + +import ( + "context" + "encoding/json" + "errors" + "log" + "os" + "path/filepath" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/sloghuman" + "github.com/coder/coder/v2/cli/cliui" + "github.com/coder/coder/v2/codersdk" + codermcp "github.com/coder/coder/v2/mcp" + "github.com/coder/serpent" +) + +func (r *RootCmd) mcpCommand() *serpent.Command { + cmd := &serpent.Command{ + Use: "mcp", + Short: "Run the Coder MCP server and configure it to work with AI tools.", + Long: "The Coder 
MCP server allows you to automatically create workspaces with parameters.", + Handler: func(i *serpent.Invocation) error { + return i.Command.HelpHandler(i) + }, + Children: []*serpent.Command{ + r.mcpConfigure(), + r.mcpServer(), + }, + } + return cmd +} + +func (r *RootCmd) mcpConfigure() *serpent.Command { + cmd := &serpent.Command{ + Use: "configure", + Short: "Automatically configure the MCP server.", + Handler: func(i *serpent.Invocation) error { + return i.Command.HelpHandler(i) + }, + Children: []*serpent.Command{ + r.mcpConfigureClaudeDesktop(), + r.mcpConfigureClaudeCode(), + r.mcpConfigureCursor(), + }, + } + return cmd +} + +func (*RootCmd) mcpConfigureClaudeDesktop() *serpent.Command { + cmd := &serpent.Command{ + Use: "claude-desktop", + Short: "Configure the Claude Desktop server.", + Handler: func(_ *serpent.Invocation) error { + configPath, err := os.UserConfigDir() + if err != nil { + return err + } + configPath = filepath.Join(configPath, "Claude") + err = os.MkdirAll(configPath, 0o755) + if err != nil { + return err + } + configPath = filepath.Join(configPath, "claude_desktop_config.json") + _, err = os.Stat(configPath) + if err != nil { + if !os.IsNotExist(err) { + return err + } + } + contents := map[string]any{} + data, err := os.ReadFile(configPath) + if err != nil { + if !os.IsNotExist(err) { + return err + } + } else { + err = json.Unmarshal(data, &contents) + if err != nil { + return err + } + } + binPath, err := os.Executable() + if err != nil { + return err + } + contents["mcpServers"] = map[string]any{ + "coder": map[string]any{"command": binPath, "args": []string{"exp", "mcp", "server"}}, + } + data, err = json.MarshalIndent(contents, "", " ") + if err != nil { + return err + } + err = os.WriteFile(configPath, data, 0o600) + if err != nil { + return err + } + return nil + }, + } + return cmd +} + +func (*RootCmd) mcpConfigureClaudeCode() *serpent.Command { + cmd := &serpent.Command{ + Use: "claude-code", + Short: "Configure the Claude 
Code server.", + Handler: func(_ *serpent.Invocation) error { + return nil + }, + } + return cmd +} + +func (*RootCmd) mcpConfigureCursor() *serpent.Command { + var project bool + cmd := &serpent.Command{ + Use: "cursor", + Short: "Configure Cursor to use Coder MCP.", + Options: serpent.OptionSet{ + serpent.Option{ + Flag: "project", + Env: "CODER_MCP_CURSOR_PROJECT", + Description: "Use to configure a local project to use the Cursor MCP.", + Value: serpent.BoolOf(&project), + }, + }, + Handler: func(_ *serpent.Invocation) error { + dir, err := os.Getwd() + if err != nil { + return err + } + if !project { + dir, err = os.UserHomeDir() + if err != nil { + return err + } + } + cursorDir := filepath.Join(dir, ".cursor") + err = os.MkdirAll(cursorDir, 0o755) + if err != nil { + return err + } + mcpConfig := filepath.Join(cursorDir, "mcp.json") + _, err = os.Stat(mcpConfig) + contents := map[string]any{} + if err != nil { + if !os.IsNotExist(err) { + return err + } + } else { + data, err := os.ReadFile(mcpConfig) + if err != nil { + return err + } + // The config can be empty, so we don't want to return an error if it is. 
+ if len(data) > 0 { + err = json.Unmarshal(data, &contents) + if err != nil { + return err + } + } + } + mcpServers, ok := contents["mcpServers"].(map[string]any) + if !ok { + mcpServers = map[string]any{} + } + binPath, err := os.Executable() + if err != nil { + return err + } + mcpServers["coder"] = map[string]any{ + "command": binPath, + "args": []string{"exp", "mcp", "server"}, + } + contents["mcpServers"] = mcpServers + data, err := json.MarshalIndent(contents, "", " ") + if err != nil { + return err + } + err = os.WriteFile(mcpConfig, data, 0o600) + if err != nil { + return err + } + return nil + }, + } + return cmd +} + +func (r *RootCmd) mcpServer() *serpent.Command { + var ( + client = new(codersdk.Client) + instructions string + allowedTools []string + ) + return &serpent.Command{ + Use: "server", + Handler: func(inv *serpent.Invocation) error { + return mcpServerHandler(inv, client, instructions, allowedTools) + }, + Short: "Start the Coder MCP server.", + Middleware: serpent.Chain( + r.InitClient(client), + ), + Options: []serpent.Option{ + { + Name: "instructions", + Description: "The instructions to pass to the MCP server.", + Flag: "instructions", + Value: serpent.StringOf(&instructions), + }, + { + Name: "allowed-tools", + Description: "Comma-separated list of allowed tools. 
If not specified, all tools are allowed.", + Flag: "allowed-tools", + Value: serpent.StringArrayOf(&allowedTools), + }, + }, + } +} + +func mcpServerHandler(inv *serpent.Invocation, client *codersdk.Client, instructions string, allowedTools []string) error { + ctx, cancel := context.WithCancel(inv.Context()) + defer cancel() + + logger := slog.Make(sloghuman.Sink(inv.Stdout)) + + me, err := client.User(ctx, codersdk.Me) + if err != nil { + cliui.Errorf(inv.Stderr, "Failed to log in to the Coder deployment.") + cliui.Errorf(inv.Stderr, "Please check your URL and credentials.") + cliui.Errorf(inv.Stderr, "Tip: Run `coder whoami` to check your credentials.") + return err + } + cliui.Infof(inv.Stderr, "Starting MCP server") + cliui.Infof(inv.Stderr, "User : %s", me.Username) + cliui.Infof(inv.Stderr, "URL : %s", client.URL) + cliui.Infof(inv.Stderr, "Instructions : %q", instructions) + if len(allowedTools) > 0 { + cliui.Infof(inv.Stderr, "Allowed Tools : %v", allowedTools) + } + cliui.Infof(inv.Stderr, "Press Ctrl+C to stop the server") + + // Capture the original stdin, stdout, and stderr. + invStdin := inv.Stdin + invStdout := inv.Stdout + invStderr := inv.Stderr + defer func() { + inv.Stdin = invStdin + inv.Stdout = invStdout + inv.Stderr = invStderr + }() + + options := []codermcp.Option{ + codermcp.WithInstructions(instructions), + codermcp.WithLogger(&logger), + } + + // Add allowed tools option if specified + if len(allowedTools) > 0 { + options = append(options, codermcp.WithAllowedTools(allowedTools)) + } + + srv := codermcp.NewStdio(client, options...) 
+ srv.SetErrorLogger(log.New(invStderr, "", log.LstdFlags)) + + done := make(chan error) + go func() { + defer close(done) + srvErr := srv.Listen(ctx, invStdin, invStdout) + done <- srvErr + }() + + if err := <-done; err != nil { + if !errors.Is(err, context.Canceled) { + cliui.Errorf(inv.Stderr, "Failed to start the MCP server: %s", err) + return err + } + } + + return nil +} diff --git a/cli/exp_mcp_test.go b/cli/exp_mcp_test.go new file mode 100644 index 0000000000000..06d7693c86f7d --- /dev/null +++ b/cli/exp_mcp_test.go @@ -0,0 +1,142 @@ +package cli_test + +import ( + "context" + "encoding/json" + "runtime" + "slices" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/cli/clitest" + "github.com/coder/coder/v2/coderd/coderdtest" + "github.com/coder/coder/v2/pty/ptytest" + "github.com/coder/coder/v2/testutil" +) + +func TestExpMcp(t *testing.T) { + t.Parallel() + + // Reading to / writing from the PTY is flaky on non-linux systems. 
+ if runtime.GOOS != "linux" { + t.Skip("skipping on non-linux") + } + + t.Run("AllowedTools", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + cancelCtx, cancel := context.WithCancel(ctx) + t.Cleanup(cancel) + + // Given: a running coder deployment + client := coderdtest.New(t, nil) + _ = coderdtest.CreateFirstUser(t, client) + + // Given: we run the exp mcp command with allowed tools set + inv, root := clitest.New(t, "exp", "mcp", "server", "--allowed-tools=coder_whoami,coder_list_templates") + inv = inv.WithContext(cancelCtx) + + pty := ptytest.New(t) + inv.Stdin = pty.Input() + inv.Stdout = pty.Output() + clitest.SetupConfig(t, client, root) + + cmdDone := make(chan struct{}) + go func() { + defer close(cmdDone) + err := inv.Run() + assert.NoError(t, err) + }() + + // When: we send a tools/list request + toolsPayload := `{"jsonrpc":"2.0","id":2,"method":"tools/list"}` + pty.WriteLine(toolsPayload) + _ = pty.ReadLine(ctx) // ignore echoed output + output := pty.ReadLine(ctx) + + cancel() + <-cmdDone + + // Then: we should only see the allowed tools in the response + var toolsResponse struct { + Result struct { + Tools []struct { + Name string `json:"name"` + } `json:"tools"` + } `json:"result"` + } + err := json.Unmarshal([]byte(output), &toolsResponse) + require.NoError(t, err) + require.Len(t, toolsResponse.Result.Tools, 2, "should have exactly 2 tools") + foundTools := make([]string, 0, 2) + for _, tool := range toolsResponse.Result.Tools { + foundTools = append(foundTools, tool.Name) + } + slices.Sort(foundTools) + require.Equal(t, []string{"coder_list_templates", "coder_whoami"}, foundTools) + }) + + t.Run("OK", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + cancelCtx, cancel := context.WithCancel(ctx) + t.Cleanup(cancel) + + client := coderdtest.New(t, nil) + _ = coderdtest.CreateFirstUser(t, client) + inv, root := clitest.New(t, "exp", "mcp", "server") + inv = 
inv.WithContext(cancelCtx) + + pty := ptytest.New(t) + inv.Stdin = pty.Input() + inv.Stdout = pty.Output() + clitest.SetupConfig(t, client, root) + + cmdDone := make(chan struct{}) + go func() { + defer close(cmdDone) + err := inv.Run() + assert.NoError(t, err) + }() + + payload := `{"jsonrpc":"2.0","id":1,"method":"initialize"}` + pty.WriteLine(payload) + _ = pty.ReadLine(ctx) // ignore echoed output + output := pty.ReadLine(ctx) + cancel() + <-cmdDone + + // Ensure the initialize output is valid JSON + t.Logf("/initialize output: %s", output) + var initializeResponse map[string]interface{} + err := json.Unmarshal([]byte(output), &initializeResponse) + require.NoError(t, err) + require.Equal(t, "2.0", initializeResponse["jsonrpc"]) + require.Equal(t, 1.0, initializeResponse["id"]) + require.NotNil(t, initializeResponse["result"]) + }) + + t.Run("NoCredentials", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + cancelCtx, cancel := context.WithCancel(ctx) + t.Cleanup(cancel) + + client := coderdtest.New(t, nil) + inv, root := clitest.New(t, "exp", "mcp", "server") + inv = inv.WithContext(cancelCtx) + + pty := ptytest.New(t) + inv.Stdin = pty.Input() + inv.Stdout = pty.Output() + clitest.SetupConfig(t, client, root) + + err := inv.Run() + assert.ErrorContains(t, err, "your session has expired") + }) +} diff --git a/go.mod b/go.mod index 56c52a82b6721..3ecb96a3e14f6 100644 --- a/go.mod +++ b/go.mod @@ -480,3 +480,7 @@ require ( github.com/golang-jwt/jwt/v5 v5.2.2 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect ) + +require github.com/mark3labs/mcp-go v0.17.0 + +require github.com/yosida95/uritemplate/v3 v3.0.2 // indirect diff --git a/go.sum b/go.sum index efa6ade52ffb6..70c46ff5266da 100644 --- a/go.sum +++ b/go.sum @@ -658,6 +658,8 @@ github.com/makeworld-the-better-one/dither/v2 v2.4.0 h1:Az/dYXiTcwcRSe59Hzw4RI1r github.com/makeworld-the-better-one/dither/v2 v2.4.0/go.mod 
h1:VBtN8DXO7SNtyGmLiGA7IsFeKrBkQPze1/iAeM95arc= github.com/marekm4/color-extractor v1.2.1 h1:3Zb2tQsn6bITZ8MBVhc33Qn1k5/SEuZ18mrXGUqIwn0= github.com/marekm4/color-extractor v1.2.1/go.mod h1:90VjmiHI6M8ez9eYUaXLdcKnS+BAOp7w+NpwBdkJmpA= +github.com/mark3labs/mcp-go v0.17.0 h1:5Ps6T7qXr7De/2QTqs9h6BKeZ/qdeUeGrgM5lPzi930= +github.com/mark3labs/mcp-go v0.17.0/go.mod h1:KmJndYv7GIgcPVwEKJjNcbhVQ+hJGJhrCCB/9xITzpE= github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= @@ -972,6 +974,8 @@ github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZ github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= github.com/yashtewari/glob-intersection v0.2.0 h1:8iuHdN88yYuCzCdjt0gDe+6bAhUwBeEWqThExu54RFg= github.com/yashtewari/glob-intersection v0.2.0/go.mod h1:LK7pIC3piUjovexikBbJ26Yml7g8xa5bsjfx2v1fwok= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yudai/gojsondiff v1.0.0 h1:27cbfqXLVEJ1o8I6v3y9lg8Ydm53EKqHXAOMxEGlCOA= github.com/yudai/gojsondiff v1.0.0/go.mod h1:AY32+k2cwILAkW1fbgxQ5mUmMiZFgLIV+FBNExI05xg= github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82 h1:BHyfKlQyqbsFN5p3IfnEUduWvb9is428/nNb5L3U01M= diff --git a/mcp/mcp.go b/mcp/mcp.go new file mode 100644 index 0000000000000..80e0f341e16e6 --- /dev/null +++ b/mcp/mcp.go @@ -0,0 +1,643 @@ +package codermcp + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "io" + "os" + "slices" + "strings" + "time" + + "github.com/google/uuid" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "golang.org/x/xerrors" + + "cdr.dev/slog" + 
"cdr.dev/slog/sloggers/sloghuman" + "github.com/coder/coder/v2/buildinfo" + "github.com/coder/coder/v2/coderd/util/ptr" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/codersdk/workspacesdk" +) + +type mcpOptions struct { + instructions string + logger *slog.Logger + allowedTools []string +} + +// Option is a function that configures the MCP server. +type Option func(*mcpOptions) + +// WithInstructions sets the instructions for the MCP server. +func WithInstructions(instructions string) Option { + return func(o *mcpOptions) { + o.instructions = instructions + } +} + +// WithLogger sets the logger for the MCP server. +func WithLogger(logger *slog.Logger) Option { + return func(o *mcpOptions) { + o.logger = logger + } +} + +// WithAllowedTools sets the allowed tools for the MCP server. +func WithAllowedTools(tools []string) Option { + return func(o *mcpOptions) { + o.allowedTools = tools + } +} + +// NewStdio creates a new MCP stdio server with the given client and options. +// It is the responsibility of the caller to start and stop the server. +func NewStdio(client *codersdk.Client, opts ...Option) *server.StdioServer { + options := &mcpOptions{ + instructions: ``, + logger: ptr.Ref(slog.Make(sloghuman.Sink(os.Stdout))), + } + for _, opt := range opts { + opt(options) + } + + mcpSrv := server.NewMCPServer( + "Coder Agent", + buildinfo.Version(), + server.WithInstructions(options.instructions), + ) + + logger := slog.Make(sloghuman.Sink(os.Stdout)) + + // Register tools based on the allowed list (if specified) + reg := AllTools() + if len(options.allowedTools) > 0 { + reg = reg.WithOnlyAllowed(options.allowedTools...) + } + reg.Register(mcpSrv, ToolDeps{ + Client: client, + Logger: &logger, + }) + + srv := server.NewStdioServer(mcpSrv) + return srv +} + +// allTools is the list of all available tools. When adding a new tool, +// make sure to update this list. 
+var allTools = ToolRegistry{ + { + Tool: mcp.NewTool("coder_report_task", + mcp.WithDescription(`Report progress on a user task in Coder. +Use this tool to keep the user informed about your progress with their request. +For long-running operations, call this periodically to provide status updates. +This is especially useful when performing multi-step operations like workspace creation or deployment.`), + mcp.WithString("summary", mcp.Description(`A concise summary of your current progress on the task. + +Good Summaries: +- "Taking a look at the login page..." +- "Found a bug! Fixing it now..." +- "Investigating the GitHub Issue..." +- "Waiting for workspace to start (1/3 resources ready)" +- "Downloading template files from repository"`), mcp.Required()), + mcp.WithString("link", mcp.Description(`A relevant URL related to your work, such as: +- GitHub issue link +- Pull request URL +- Documentation reference +- Workspace URL +Use complete URLs (including https://) when possible.`), mcp.Required()), + mcp.WithString("emoji", mcp.Description(`A relevant emoji that visually represents the current status: +- 🔍 for investigating/searching +- 🚀 for deploying/starting +- 🐛 for debugging +- ✅ for completion +- ⏳ for waiting +Choose an emoji that helps the user understand the current phase at a glance.`), mcp.Required()), + mcp.WithBoolean("done", mcp.Description(`Whether the overall task the user requested is complete. +Set to true only when the entire requested operation is finished successfully. +For multi-step processes, use false until all steps are complete.`), mcp.Required()), + ), + MakeHandler: handleCoderReportTask, + }, + { + Tool: mcp.NewTool("coder_whoami", + mcp.WithDescription(`Get information about the currently logged-in Coder user. +Returns JSON with the user's profile including fields: id, username, email, created_at, status, roles, etc. +Use this to identify the current user context before performing workspace operations. 
+This tool is useful for verifying permissions and checking the user's identity. + +Common errors: +- Authentication failure: The session may have expired +- Server unavailable: The Coder deployment may be unreachable`), + ), + MakeHandler: handleCoderWhoami, + }, + { + Tool: mcp.NewTool("coder_list_templates", + mcp.WithDescription(`List all templates available on the Coder deployment. +Returns JSON with detailed information about each template, including: +- Template name, ID, and description +- Creation/modification timestamps +- Version information +- Associated organization + +Use this tool to discover available templates before creating workspaces. +Templates define the infrastructure and configuration for workspaces. + +Common errors: +- Authentication failure: Check user permissions +- No templates available: The deployment may not have any templates configured`), + ), + MakeHandler: handleCoderListTemplates, + }, + { + Tool: mcp.NewTool("coder_list_workspaces", + mcp.WithDescription(`List workspaces available on the Coder deployment. +Returns JSON with workspace metadata including status, resources, and configurations. +Use this before other workspace operations to find valid workspace names/IDs. +Results are paginated - use offset and limit parameters for large deployments. + +Common errors: +- Authentication failure: Check user permissions +- Invalid owner parameter: Ensure the owner exists`), + mcp.WithString(`owner`, mcp.Description(`The username of the workspace owner to filter by. +Defaults to "me" which represents the currently authenticated user. +Use this to view workspaces belonging to other users (requires appropriate permissions). +Special value: "me" - List workspaces owned by the authenticated user.`), mcp.DefaultString(codersdk.Me)), + mcp.WithNumber(`offset`, mcp.Description(`Pagination offset - the starting index for listing workspaces. +Used with the 'limit' parameter to implement pagination. 
+For example, to get the second page of results with 10 items per page, use offset=10. +Defaults to 0 (first page).`), mcp.DefaultNumber(0)), + mcp.WithNumber(`limit`, mcp.Description(`Maximum number of workspaces to return in a single request. +Used with the 'offset' parameter to implement pagination. +Higher values return more results but may increase response time. +Valid range: 1-100. Defaults to 10.`), mcp.DefaultNumber(10)), + ), + MakeHandler: handleCoderListWorkspaces, + }, + { + Tool: mcp.NewTool("coder_get_workspace", + mcp.WithDescription(`Get detailed information about a specific Coder workspace. +Returns comprehensive JSON with the workspace's configuration, status, and resources. +Use this to check workspace status before performing operations like exec or start/stop. +The response includes the latest build status, agent connectivity, and resource details. + +Common errors: +- Workspace not found: Check the workspace name or ID +- Permission denied: The user may not have access to this workspace`), + mcp.WithString("workspace", mcp.Description(`The workspace ID (UUID) or name to retrieve. +Can be specified as either: +- Full UUID: e.g., "8a0b9c7d-1e2f-3a4b-5c6d-7e8f9a0b1c2d" +- Workspace name: e.g., "dev", "python-project" +Use coder_list_workspaces first if you're not sure about available workspace names.`), mcp.Required()), + ), + MakeHandler: handleCoderGetWorkspace, + }, + { + Tool: mcp.NewTool("coder_workspace_exec", + mcp.WithDescription(`Execute a shell command in a remote Coder workspace. +Runs the specified command and returns the complete output (stdout/stderr). +Use this for file operations, running build commands, or checking workspace state. +The workspace must be running with a connected agent for this to succeed. + +Before using this tool: +1. Verify the workspace is running using coder_get_workspace +2. 
Start the workspace if needed using coder_start_workspace + +Common errors: +- Workspace not running: Start the workspace first +- Command not allowed: Check security restrictions +- Agent not connected: The workspace may still be starting up`), + mcp.WithString("workspace", mcp.Description(`The workspace ID (UUID) or name where the command will execute. +Can be specified as either: +- Full UUID: e.g., "8a0b9c7d-1e2f-3a4b-5c6d-7e8f9a0b1c2d" +- Workspace name: e.g., "dev", "python-project" +The workspace must be running with a connected agent. +Use coder_get_workspace first to check the workspace status.`), mcp.Required()), + mcp.WithString("command", mcp.Description(`The shell command to execute in the workspace. +Commands are executed in the default shell of the workspace. + +Examples: +- "ls -la" - List files with details +- "cd /path/to/directory && command" - Execute in specific directory +- "cat ~/.bashrc" - View a file's contents +- "python -m pip list" - List installed Python packages + +Note: Very long-running commands may time out.`), mcp.Required()), + ), + MakeHandler: handleCoderWorkspaceExec, + }, + { + Tool: mcp.NewTool("coder_workspace_transition", + mcp.WithDescription(`Start or stop a running Coder workspace. +If stopping, initiates the workspace stop transition. +Only works on workspaces that are currently running or failed. + +If starting, initiates the workspace start transition. +Only works on workspaces that are currently stopped or failed. + +Stopping or starting a workspace is an asynchronous operation - it may take several minutes to complete. + +After calling this tool: +1. Use coder_report_task to inform the user that the workspace is stopping or starting +2. 
Use coder_get_workspace periodically to check for completion + +Common errors: +- Workspace already started/starting/stopped/stopping: No action needed +- Cancellation failed: There may be issues with the underlying infrastructure +- User doesn't own workspace: Permission issues`), + mcp.WithString("workspace", mcp.Description(`The workspace ID (UUID) or name to start or stop. +Can be specified as either: +- Full UUID: e.g., "8a0b9c7d-1e2f-3a4b-5c6d-7e8f9a0b1c2d" +- Workspace name: e.g., "dev", "python-project" +The workspace must be in a running state to be stopped, or in a stopped or failed state to be started. +Use coder_get_workspace first to check the current workspace status.`), mcp.Required()), + mcp.WithString("transition", mcp.Description(`The transition to apply to the workspace. +Can be either "start" or "stop".`)), + ), + MakeHandler: handleCoderWorkspaceTransition, + }, +} + +// ToolDeps contains all dependencies needed by tool handlers +type ToolDeps struct { + Client *codersdk.Client + Logger *slog.Logger +} + +// ToolHandler associates a tool with its handler creation function +type ToolHandler struct { + Tool mcp.Tool + MakeHandler func(ToolDeps) server.ToolHandlerFunc +} + +// ToolRegistry is a map of available tools with their handler creation +// functions +type ToolRegistry []ToolHandler + +// WithOnlyAllowed returns a new ToolRegistry containing only the tools +// specified in the allowed list. +func (r ToolRegistry) WithOnlyAllowed(allowed ...string) ToolRegistry { + if len(allowed) == 0 { + return []ToolHandler{} + } + + filtered := make(ToolRegistry, 0, len(r)) + + // The overhead of a map lookup is likely higher than a linear scan + // for a small number of tools. + for _, entry := range r { + if slices.Contains(allowed, entry.Tool.Name) { + filtered = append(filtered, entry) + } + } + return filtered +} + +// Register registers all tools in the registry with the given tool adder +// and dependencies. 
+func (r ToolRegistry) Register(srv *server.MCPServer, deps ToolDeps) { + for _, entry := range r { + srv.AddTool(entry.Tool, entry.MakeHandler(deps)) + } +} + +// AllTools returns all available tools. +func AllTools() ToolRegistry { + // return a copy of allTools to avoid mutating the original + return slices.Clone(allTools) +} + +type handleCoderReportTaskArgs struct { + Summary string `json:"summary"` + Link string `json:"link"` + Emoji string `json:"emoji"` + Done bool `json:"done"` +} + +// Example payload: +// {"jsonrpc":"2.0","id":1,"method":"tools/call", "params": {"name": "coder_report_task", "arguments": {"summary": "I'm working on the login page.", "link": "https://github.com/coder/coder/pull/1234", "emoji": "🔍", "done": false}}} +func handleCoderReportTask(deps ToolDeps) server.ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + if deps.Client == nil { + return nil, xerrors.New("developer error: client is required") + } + + // Convert the request parameters to a json.RawMessage so we can unmarshal + // them into the correct struct. + args, err := unmarshalArgs[handleCoderReportTaskArgs](request.Params.Arguments) + if err != nil { + return nil, xerrors.Errorf("failed to unmarshal arguments: %w", err) + } + + // TODO: Waiting on support for tasks. 
+ deps.Logger.Info(ctx, "report task tool called", slog.F("summary", args.Summary), slog.F("link", args.Link), slog.F("done", args.Done), slog.F("emoji", args.Emoji)) + /* + err := sdk.PostTask(ctx, agentsdk.PostTaskRequest{ + Reporter: "claude", + Summary: summary, + URL: link, + Completion: done, + Icon: emoji, + }) + if err != nil { + return nil, err + } + */ + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.NewTextContent("Thanks for reporting!"), + }, + }, nil + } +} + +// Example payload: +// {"jsonrpc":"2.0","id":1,"method":"tools/call", "params": {"name": "coder_whoami", "arguments": {}}} +func handleCoderWhoami(deps ToolDeps) server.ToolHandlerFunc { + return func(ctx context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) { + if deps.Client == nil { + return nil, xerrors.New("developer error: client is required") + } + me, err := deps.Client.User(ctx, codersdk.Me) + if err != nil { + return nil, xerrors.Errorf("Failed to fetch the current user: %s", err.Error()) + } + + var buf bytes.Buffer + if err := json.NewEncoder(&buf).Encode(me); err != nil { + return nil, xerrors.Errorf("Failed to encode the current user: %s", err.Error()) + } + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.NewTextContent(strings.TrimSpace(buf.String())), + }, + }, nil + } +} + +type handleCoderListWorkspacesArgs struct { + Owner string `json:"owner"` + Offset int `json:"offset"` + Limit int `json:"limit"` +} + +// Example payload: +// {"jsonrpc":"2.0","id":1,"method":"tools/call", "params": {"name": "coder_list_workspaces", "arguments": {"owner": "me", "offset": 0, "limit": 10}}} +func handleCoderListWorkspaces(deps ToolDeps) server.ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + if deps.Client == nil { + return nil, xerrors.New("developer error: client is required") + } + args, err := unmarshalArgs[handleCoderListWorkspacesArgs](request.Params.Arguments) + if err != nil { 
+ return nil, xerrors.Errorf("failed to unmarshal arguments: %w", err) + } + + workspaces, err := deps.Client.Workspaces(ctx, codersdk.WorkspaceFilter{ + Owner: args.Owner, + Offset: args.Offset, + Limit: args.Limit, + }) + if err != nil { + return nil, xerrors.Errorf("failed to fetch workspaces: %w", err) + } + + // Encode it as JSON. TODO: It might be nicer for the agent to have a tabulated response. + data, err := json.Marshal(workspaces) + if err != nil { + return nil, xerrors.Errorf("failed to encode workspaces: %s", err.Error()) + } + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.NewTextContent(string(data)), + }, + }, nil + } +} + +type handleCoderGetWorkspaceArgs struct { + Workspace string `json:"workspace"` +} + +// Example payload: +// {"jsonrpc":"2.0","id":1,"method":"tools/call", "params": {"name": "coder_get_workspace", "arguments": {"workspace": "dev"}}} +func handleCoderGetWorkspace(deps ToolDeps) server.ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + if deps.Client == nil { + return nil, xerrors.New("developer error: client is required") + } + args, err := unmarshalArgs[handleCoderGetWorkspaceArgs](request.Params.Arguments) + if err != nil { + return nil, xerrors.Errorf("failed to unmarshal arguments: %w", err) + } + + workspace, err := getWorkspaceByIDOrOwnerName(ctx, deps.Client, args.Workspace) + if err != nil { + return nil, xerrors.Errorf("failed to fetch workspace: %w", err) + } + + workspaceJSON, err := json.Marshal(workspace) + if err != nil { + return nil, xerrors.Errorf("failed to encode workspace: %w", err) + } + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.NewTextContent(string(workspaceJSON)), + }, + }, nil + } +} + +type handleCoderWorkspaceExecArgs struct { + Workspace string `json:"workspace"` + Command string `json:"command"` +} + +// Example payload: +// {"jsonrpc":"2.0","id":1,"method":"tools/call", "params": {"name": 
"coder_workspace_exec", "arguments": {"workspace": "dev", "command": "ps -ef"}}} +func handleCoderWorkspaceExec(deps ToolDeps) server.ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + if deps.Client == nil { + return nil, xerrors.New("developer error: client is required") + } + args, err := unmarshalArgs[handleCoderWorkspaceExecArgs](request.Params.Arguments) + if err != nil { + return nil, xerrors.Errorf("failed to unmarshal arguments: %w", err) + } + + // Attempt to fetch the workspace. We may get a UUID or a name, so try to + // handle both. + ws, err := getWorkspaceByIDOrOwnerName(ctx, deps.Client, args.Workspace) + if err != nil { + return nil, xerrors.Errorf("failed to fetch workspace: %w", err) + } + + // Ensure the workspace is started. + // Select the first agent of the workspace. + var agt *codersdk.WorkspaceAgent + for _, r := range ws.LatestBuild.Resources { + for _, a := range r.Agents { + if a.Status != codersdk.WorkspaceAgentConnected { + continue + } + agt = ptr.Ref(a) + break + } + } + if agt == nil { + return nil, xerrors.Errorf("no connected agents for workspace %s", ws.ID) + } + + startedAt := time.Now() + conn, err := workspacesdk.New(deps.Client).AgentReconnectingPTY(ctx, workspacesdk.WorkspaceAgentReconnectingPTYOpts{ + AgentID: agt.ID, + Reconnect: uuid.New(), + Width: 80, + Height: 24, + Command: args.Command, + BackendType: "buffered", // the screen backend is annoying to use here. + }) + if err != nil { + return nil, xerrors.Errorf("failed to open reconnecting PTY: %w", err) + } + defer conn.Close() + connectedAt := time.Now() + + var buf bytes.Buffer + if _, err := io.Copy(&buf, conn); err != nil { + // EOF is expected when the connection is closed. + // We can ignore this error. 
+ if !errors.Is(err, io.EOF) { + return nil, xerrors.Errorf("failed to read from reconnecting PTY: %w", err) + } + } + completedAt := time.Now() + connectionTime := connectedAt.Sub(startedAt) + executionTime := completedAt.Sub(connectedAt) + + resp := map[string]string{ + "connection_time": connectionTime.String(), + "execution_time": executionTime.String(), + "output": buf.String(), + } + respJSON, err := json.Marshal(resp) + if err != nil { + return nil, xerrors.Errorf("failed to encode workspace build: %w", err) + } + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.NewTextContent(string(respJSON)), + }, + }, nil + } +} + +// Example payload: +// {"jsonrpc":"2.0","id":1,"method":"tools/call", "params": {"name": "coder_list_templates", "arguments": {}}} +func handleCoderListTemplates(deps ToolDeps) server.ToolHandlerFunc { + return func(ctx context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) { + if deps.Client == nil { + return nil, xerrors.New("developer error: client is required") + } + templates, err := deps.Client.Templates(ctx, codersdk.TemplateFilter{}) + if err != nil { + return nil, xerrors.Errorf("failed to fetch templates: %w", err) + } + + templateJSON, err := json.Marshal(templates) + if err != nil { + return nil, xerrors.Errorf("failed to encode templates: %w", err) + } + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.NewTextContent(string(templateJSON)), + }, + }, nil + } +} + +type handleCoderWorkspaceTransitionArgs struct { + Workspace string `json:"workspace"` + Transition string `json:"transition"` +} + +// Example payload: +// {"jsonrpc":"2.0","id":1,"method":"tools/call", "params": {"name": +// "coder_workspace_transition", "arguments": {"workspace": "dev", "transition": "stop"}}} +func handleCoderWorkspaceTransition(deps ToolDeps) server.ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + if deps.Client == nil { + return nil, 
xerrors.New("developer error: client is required") + } + args, err := unmarshalArgs[handleCoderWorkspaceTransitionArgs](request.Params.Arguments) + if err != nil { + return nil, xerrors.Errorf("failed to unmarshal arguments: %w", err) + } + + workspace, err := getWorkspaceByIDOrOwnerName(ctx, deps.Client, args.Workspace) + if err != nil { + return nil, xerrors.Errorf("failed to fetch workspace: %w", err) + } + + wsTransition := codersdk.WorkspaceTransition(args.Transition) + switch wsTransition { + case codersdk.WorkspaceTransitionStart: + case codersdk.WorkspaceTransitionStop: + default: + return nil, xerrors.New("invalid transition") + } + + // We're not going to check the workspace status here as it is checked on the + // server side. + wb, err := deps.Client.CreateWorkspaceBuild(ctx, workspace.ID, codersdk.CreateWorkspaceBuildRequest{ + Transition: wsTransition, + }) + if err != nil { + return nil, xerrors.Errorf("failed to stop workspace: %w", err) + } + + resp := map[string]any{"status": wb.Status, "transition": wb.Transition} + respJSON, err := json.Marshal(resp) + if err != nil { + return nil, xerrors.Errorf("failed to encode workspace build: %w", err) + } + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.NewTextContent(string(respJSON)), + }, + }, nil + } +} + +func getWorkspaceByIDOrOwnerName(ctx context.Context, client *codersdk.Client, identifier string) (codersdk.Workspace, error) { + if wsid, err := uuid.Parse(identifier); err == nil { + return client.Workspace(ctx, wsid) + } + return client.WorkspaceByOwnerAndName(ctx, codersdk.Me, identifier, codersdk.WorkspaceOptions{}) +} + +// unmarshalArgs is a helper function to convert the map[string]any we get from +// the MCP server into a typed struct. It does this by marshaling and unmarshalling +// the arguments. 
// TestCoderTools drives the registered MCP tools end to end: each subtest
// writes a JSON-RPC request to a pty-backed MCP server and checks the line the
// server writes back.
//
// These tests are dependent on the state of the coder server.
// Running them in parallel is prone to racy behavior.
// nolint:tparallel,paralleltest
func TestCoderTools(t *testing.T) {
	if runtime.GOOS != "linux" {
		t.Skip("skipping on non-linux due to pty issues")
	}
	ctx := testutil.Context(t, testutil.WaitLong)
	// Given: a coder server, workspace, and agent.
	client, store := coderdtest.NewWithDatabase(t, nil)
	owner := coderdtest.CreateFirstUser(t, client)
	// Given: a member user with which to test the tools.
	memberClient, member := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
	// Given: a workspace with an agent.
	r := dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{
		OrganizationID: owner.OrganizationID,
		OwnerID:        member.ID,
	}).WithAgent().Do()

	// Note: we want to test the list_workspaces tool before starting the
	// workspace agent. Starting the workspace agent will modify the workspace
	// state, which will affect the results of the list_workspaces tool.
	//
	// listWorkspacesDone is closed by the coder_list_workspaces subtest;
	// agentStarted is closed once the goroutine below has started the agent
	// and waited for it.
	// NOTE(review): if the coder_list_workspaces subtest does not run (for
	// example when filtered out with -run), listWorkspacesDone is never closed
	// and every subtest receiving from agentStarted blocks.
	listWorkspacesDone := make(chan struct{})
	agentStarted := make(chan struct{})
	go func() {
		defer close(agentStarted)
		<-listWorkspacesDone
		agt := agenttest.New(t, client.URL, r.AgentToken)
		t.Cleanup(func() {
			_ = agt.Close()
		})
		_ = coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).Wait()
	}()

	// Given: an MCP server listening on a pty.
	pty := ptytest.New(t)
	mcpSrv, closeSrv := startTestMCPServer(ctx, t, pty.Input(), pty.Output())
	t.Cleanup(func() {
		_ = closeSrv()
	})

	// Register tools using our registry
	logger := slogtest.Make(t, nil)
	codermcp.AllTools().Register(mcpSrv, codermcp.ToolDeps{
		Client: memberClient,
		Logger: &logger,
	})

	t.Run("coder_list_templates", func(t *testing.T) {
		// When: the coder_list_templates tool is called
		ctr := makeJSONRPCRequest(t, "tools/call", "coder_list_templates", map[string]any{})

		pty.WriteLine(ctr)
		_ = pty.ReadLine(ctx) // skip the echo

		templates, err := memberClient.Templates(ctx, codersdk.TemplateFilter{})
		require.NoError(t, err)
		templatesJSON, err := json.Marshal(templates)
		require.NoError(t, err)

		// Then: the response is a list of templates visible to the user.
		expected := makeJSONRPCTextResponse(t, string(templatesJSON))
		actual := pty.ReadLine(ctx)
		testutil.RequireJSONEq(t, expected, actual)
	})

	t.Run("coder_report_task", func(t *testing.T) {
		// When: the coder_report_task tool is called
		ctr := makeJSONRPCRequest(t, "tools/call", "coder_report_task", map[string]any{
			"summary":             "Test summary",
			"link":                "https://example.com",
			"emoji":               "🔍",
			"done":                false,
			"coder_url":           client.URL.String(),
			"coder_session_token": client.SessionToken(),
		})

		pty.WriteLine(ctr)
		_ = pty.ReadLine(ctx) // skip the echo

		// Then: the response is a success message.
		// TODO: check the task was created. This functionality is not yet implemented.
		expected := makeJSONRPCTextResponse(t, "Thanks for reporting!")
		actual := pty.ReadLine(ctx)
		testutil.RequireJSONEq(t, expected, actual)
	})

	t.Run("coder_whoami", func(t *testing.T) {
		// When: the coder_whoami tool is called
		me, err := memberClient.User(ctx, codersdk.Me)
		require.NoError(t, err)
		meJSON, err := json.Marshal(me)
		require.NoError(t, err)

		ctr := makeJSONRPCRequest(t, "tools/call", "coder_whoami", map[string]any{})

		pty.WriteLine(ctr)
		_ = pty.ReadLine(ctx) // skip the echo

		// Then: the response is a valid JSON representation of the calling user.
		expected := makeJSONRPCTextResponse(t, string(meJSON))
		actual := pty.ReadLine(ctx)
		testutil.RequireJSONEq(t, expected, actual)
	})

	t.Run("coder_list_workspaces", func(t *testing.T) {
		// Closing this channel lets the agent goroutine above proceed.
		defer close(listWorkspacesDone)
		// When: the coder_list_workspaces tool is called
		ctr := makeJSONRPCRequest(t, "tools/call", "coder_list_workspaces", map[string]any{
			"coder_url":           client.URL.String(),
			"coder_session_token": client.SessionToken(),
		})

		pty.WriteLine(ctr)
		_ = pty.ReadLine(ctx) // skip the echo

		ws, err := memberClient.Workspaces(ctx, codersdk.WorkspaceFilter{})
		require.NoError(t, err)
		wsJSON, err := json.Marshal(ws)
		require.NoError(t, err)

		// Then: the response is a valid JSON representation of the calling user's workspaces.
		expected := makeJSONRPCTextResponse(t, string(wsJSON))
		actual := pty.ReadLine(ctx)
		testutil.RequireJSONEq(t, expected, actual)
	})

	t.Run("coder_get_workspace", func(t *testing.T) {
		// Given: the workspace agent is connected.
		// The act of starting the agent will modify the workspace state.
		<-agentStarted
		// When: the coder_get_workspace tool is called
		ctr := makeJSONRPCRequest(t, "tools/call", "coder_get_workspace", map[string]any{
			"workspace": r.Workspace.ID.String(),
		})

		pty.WriteLine(ctr)
		_ = pty.ReadLine(ctx) // skip the echo

		ws, err := memberClient.Workspace(ctx, r.Workspace.ID)
		require.NoError(t, err)
		wsJSON, err := json.Marshal(ws)
		require.NoError(t, err)

		// Then: the response is a valid JSON representation of the workspace.
		expected := makeJSONRPCTextResponse(t, string(wsJSON))
		actual := pty.ReadLine(ctx)
		testutil.RequireJSONEq(t, expected, actual)
	})

	// NOTE: this test runs after the list_workspaces tool is called.
	t.Run("coder_workspace_exec", func(t *testing.T) {
		// Given: the workspace agent is connected
		<-agentStarted

		// When: the coder_workspace_exec tool is called with a command
		randString := testutil.GetRandomName(t)
		ctr := makeJSONRPCRequest(t, "tools/call", "coder_workspace_exec", map[string]any{
			"workspace":           r.Workspace.ID.String(),
			"command":             "echo " + randString,
			"coder_url":           client.URL.String(),
			"coder_session_token": client.SessionToken(),
		})

		pty.WriteLine(ctr)
		_ = pty.ReadLine(ctx) // skip the echo

		// Then: the response is the output of the command.
		actual := pty.ReadLine(ctx)
		require.Contains(t, actual, randString)
	})

	// NOTE: this test runs after the list_workspaces tool is called.
	t.Run("tool_restrictions", func(t *testing.T) {
		// Given: the workspace agent is connected
		<-agentStarted

		// Given: a restricted MCP server with only allowed tools and commands
		// This server gets its own pty, separate from the one used by the
		// other subtests.
		restrictedPty := ptytest.New(t)
		allowedTools := []string{"coder_workspace_exec"}
		restrictedMCPSrv, closeRestrictedSrv := startTestMCPServer(ctx, t, restrictedPty.Input(), restrictedPty.Output())
		t.Cleanup(func() {
			_ = closeRestrictedSrv()
		})
		codermcp.AllTools().
			WithOnlyAllowed(allowedTools...).
			Register(restrictedMCPSrv, codermcp.ToolDeps{
				Client: memberClient,
				Logger: &logger,
			})

		// When: the tools/list command is called
		toolsListCmd := makeJSONRPCRequest(t, "tools/list", "", nil)
		restrictedPty.WriteLine(toolsListCmd)
		_ = restrictedPty.ReadLine(ctx) // skip the echo

		// Then: the response is a list of only the allowed tools.
		toolsListResponse := restrictedPty.ReadLine(ctx)
		require.Contains(t, toolsListResponse, "coder_workspace_exec")
		require.NotContains(t, toolsListResponse, "coder_whoami")

		// When: a disallowed tool is called
		disallowedToolCmd := makeJSONRPCRequest(t, "tools/call", "coder_whoami", map[string]any{})
		restrictedPty.WriteLine(disallowedToolCmd)
		_ = restrictedPty.ReadLine(ctx) // skip the echo

		// Then: the response is an error indicating the tool is not available.
		disallowedToolResponse := restrictedPty.ReadLine(ctx)
		require.Contains(t, disallowedToolResponse, "error")
		require.Contains(t, disallowedToolResponse, "not found")
	})

	t.Run("coder_workspace_transition_stop", func(t *testing.T) {
		// Given: a separate workspace in the running state
		stopWs := dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{
			OrganizationID: owner.OrganizationID,
			OwnerID:        member.ID,
		}).WithAgent().Do()

		// When: the coder_workspace_transition tool is called with a stop transition
		ctr := makeJSONRPCRequest(t, "tools/call", "coder_workspace_transition", map[string]any{
			"workspace":  stopWs.Workspace.ID.String(),
			"transition": "stop",
		})

		pty.WriteLine(ctr)
		_ = pty.ReadLine(ctx) // skip the echo

		// Then: the response is as expected.
		expected := makeJSONRPCTextResponse(t, `{"status":"pending","transition":"stop"}`) // no provisionerd yet
		actual := pty.ReadLine(ctx)
		testutil.RequireJSONEq(t, expected, actual)
	})

	t.Run("coder_workspace_transition_start", func(t *testing.T) {
		// Given: a separate workspace in the stopped state
		stopWs := dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{
			OrganizationID: owner.OrganizationID,
			OwnerID:        member.ID,
		}).Seed(database.WorkspaceBuild{
			Transition: database.WorkspaceTransitionStop,
		}).Do()

		// When: the coder_workspace_transition tool is called with a start transition
		ctr := makeJSONRPCRequest(t, "tools/call", "coder_workspace_transition", map[string]any{
			"workspace":  stopWs.Workspace.ID.String(),
			"transition": "start",
		})

		pty.WriteLine(ctr)
		_ = pty.ReadLine(ctx) // skip the echo

		// Then: the response is as expected
		expected := makeJSONRPCTextResponse(t, `{"status":"pending","transition":"start"}`) // no provisionerd yet
		actual := pty.ReadLine(ctx)
		testutil.RequireJSONEq(t, expected, actual)
	})
}
+ Name string "json:\"name\"" + Arguments map[string]any "json:\"arguments,omitempty\"" + Meta *struct { + ProgressToken mcp.ProgressToken "json:\"progressToken,omitempty\"" + } "json:\"_meta,omitempty\"" + }{ + Name: name, + Arguments: args, + }, + } + bs, err := json.Marshal(req) + require.NoError(t, err, "failed to marshal JSON RPC request") + return string(bs) +} + +// makeJSONRPCTextResponse is a helper function that makes a JSON RPC text response +func makeJSONRPCTextResponse(t *testing.T, text string) string { + t.Helper() + + resp := mcp.JSONRPCResponse{ + ID: "1", + JSONRPC: "2.0", + Result: mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.NewTextContent(text), + }, + }, + } + bs, err := json.Marshal(resp) + require.NoError(t, err, "failed to marshal JSON RPC response") + return string(bs) +} + +// startTestMCPServer is a helper function that starts a MCP server listening on +// a pty. It is the responsibility of the caller to close the server. +func startTestMCPServer(ctx context.Context, t testing.TB, stdin io.Reader, stdout io.Writer) (*server.MCPServer, func() error) { + t.Helper() + + mcpSrv := server.NewMCPServer( + "Test Server", + "0.0.0", + server.WithInstructions(""), + server.WithLogging(), + ) + + stdioSrv := server.NewStdioServer(mcpSrv) + + cancelCtx, cancel := context.WithCancel(ctx) + closeCh := make(chan struct{}) + done := make(chan error) + go func() { + defer close(done) + srvErr := stdioSrv.Listen(cancelCtx, stdin, stdout) + done <- srvErr + }() + + go func() { + select { + case <-closeCh: + cancel() + case <-done: + cancel() + } + }() + + return mcpSrv, func() error { + close(closeCh) + return <-done + } +} diff --git a/testutil/json.go b/testutil/json.go new file mode 100644 index 0000000000000..006617d1ca030 --- /dev/null +++ b/testutil/json.go @@ -0,0 +1,27 @@ +package testutil + +import ( + "encoding/json" + "testing" + + "github.com/google/go-cmp/cmp" +) + +// RequireJSONEq is like assert.RequireJSONEq, but it's actually 
readable. +// Note that this calls t.Fatalf under the hood, so it should never +// be called in a goroutine. +func RequireJSONEq(t *testing.T, expected, actual string) { + t.Helper() + + var expectedJSON, actualJSON any + if err := json.Unmarshal([]byte(expected), &expectedJSON); err != nil { + t.Fatalf("failed to unmarshal expected JSON: %s", err) + } + if err := json.Unmarshal([]byte(actual), &actualJSON); err != nil { + t.Fatalf("failed to unmarshal actual JSON: %s", err) + } + + if diff := cmp.Diff(expectedJSON, actualJSON); diff != "" { + t.Fatalf("JSON diff (-want +got):\n%s", diff) + } +}