Skip to content

Commit c54620f

Browse files
committed
allow specifying only allowed tools
1 parent df1ca76 commit c54620f

File tree

6 files changed

+274
-135
lines changed

6 files changed

+274
-135
lines changed

cli/exp_mcp.go

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,12 @@ func (r *RootCmd) mcpCommand() *serpent.Command {
1616
var (
1717
client = new(codersdk.Client)
1818
instructions string
19+
allowedTools []string
1920
)
2021
return &serpent.Command{
2122
Use: "mcp",
2223
Handler: func(inv *serpent.Invocation) error {
23-
return mcpHandler(inv, client, instructions)
24+
return mcpHandler(inv, client, instructions, allowedTools)
2425
},
2526
Short: "Start an MCP server that can be used to interact with a Coder depoyment.",
2627
Middleware: serpent.Chain(
@@ -33,11 +34,17 @@ func (r *RootCmd) mcpCommand() *serpent.Command {
3334
Flag: "instructions",
3435
Value: serpent.StringOf(&instructions),
3536
},
37+
{
38+
Name: "allowed-tools",
39+
Description: "Comma-separated list of allowed tools. If not specified, all tools are allowed.",
40+
Flag: "allowed-tools",
41+
Value: serpent.StringArrayOf(&allowedTools),
42+
},
3643
},
3744
}
3845
}
3946

40-
func mcpHandler(inv *serpent.Invocation, client *codersdk.Client, instructions string) error {
47+
func mcpHandler(inv *serpent.Invocation, client *codersdk.Client, instructions string, allowedTools []string) error {
4148
ctx, cancel := context.WithCancel(inv.Context())
4249
defer cancel()
4350

@@ -51,9 +58,12 @@ func mcpHandler(inv *serpent.Invocation, client *codersdk.Client, instructions s
5158
return err
5259
}
5360
cliui.Infof(inv.Stderr, "Starting MCP server")
54-
cliui.Infof(inv.Stderr, "User : %s", me.Username)
55-
cliui.Infof(inv.Stderr, "URL : %s", client.URL)
56-
cliui.Infof(inv.Stderr, "Instructions : %q", instructions)
61+
cliui.Infof(inv.Stderr, "User : %s", me.Username)
62+
cliui.Infof(inv.Stderr, "URL : %s", client.URL)
63+
cliui.Infof(inv.Stderr, "Instructions : %q", instructions)
64+
if len(allowedTools) > 0 {
65+
cliui.Infof(inv.Stderr, "Allowed Tools : %v", allowedTools)
66+
}
5767
cliui.Infof(inv.Stderr, "Press Ctrl+C to stop the server")
5868

5969
// Capture the original stdin, stdout, and stderr.
@@ -66,12 +76,19 @@ func mcpHandler(inv *serpent.Invocation, client *codersdk.Client, instructions s
6676
inv.Stderr = invStderr
6777
}()
6878

69-
closer := codermcp.New(ctx, client,
79+
options := []codermcp.Option{
7080
codermcp.WithInstructions(instructions),
7181
codermcp.WithLogger(&logger),
7282
codermcp.WithStdin(invStdin),
7383
codermcp.WithStdout(invStdout),
74-
)
84+
}
85+
86+
// Add allowed tools option if specified
87+
if len(allowedTools) > 0 {
88+
options = append(options, codermcp.WithAllowedTools(allowedTools))
89+
}
90+
91+
closer := codermcp.New(ctx, client, options...)
7592

7693
<-ctx.Done()
7794
if err := closer.Close(); err != nil {

cli/exp_mcp_test.go

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package cli_test
33
import (
44
"context"
55
"encoding/json"
6+
"slices"
67
"testing"
78

89
"github.com/stretchr/testify/assert"
@@ -17,6 +18,61 @@ import (
1718
func TestExpMcp(t *testing.T) {
1819
t.Parallel()
1920

21+
t.Run("AllowedTools", func(t *testing.T) {
22+
t.Parallel()
23+
24+
ctx := testutil.Context(t, testutil.WaitShort)
25+
cancelCtx, cancel := context.WithCancel(ctx)
26+
t.Cleanup(cancel)
27+
28+
// Given: a running coder deployment
29+
client := coderdtest.New(t, nil)
30+
_ = coderdtest.CreateFirstUser(t, client)
31+
32+
// Given: we run the exp mcp command with allowed tools set
33+
inv, root := clitest.New(t, "exp", "mcp", "--allowed-tools=coder_whoami,coder_list_templates")
34+
inv = inv.WithContext(cancelCtx)
35+
36+
pty := ptytest.New(t)
37+
inv.Stdin = pty.Input()
38+
inv.Stdout = pty.Output()
39+
clitest.SetupConfig(t, client, root)
40+
41+
cmdDone := make(chan struct{})
42+
go func() {
43+
defer close(cmdDone)
44+
err := inv.Run()
45+
assert.NoError(t, err)
46+
}()
47+
48+
// When: we send a tools/list request
49+
toolsPayload := `{"jsonrpc":"2.0","id":2,"method":"tools/list"}`
50+
pty.WriteLine(toolsPayload)
51+
_ = pty.ReadLine(ctx) // ignore echoed output
52+
output := pty.ReadLine(ctx)
53+
54+
cancel()
55+
<-cmdDone
56+
57+
// Then: we should only see the allowed tools in the response
58+
var toolsResponse struct {
59+
Result struct {
60+
Tools []struct {
61+
Name string `json:"name"`
62+
} `json:"tools"`
63+
} `json:"result"`
64+
}
65+
err := json.Unmarshal([]byte(output), &toolsResponse)
66+
require.NoError(t, err)
67+
require.Len(t, toolsResponse.Result.Tools, 2, "should have exactly 2 tools")
68+
foundTools := make([]string, 0, 2)
69+
for _, tool := range toolsResponse.Result.Tools {
70+
foundTools = append(foundTools, tool.Name)
71+
}
72+
slices.Sort(foundTools)
73+
require.Equal(t, []string{"coder_list_templates", "coder_whoami"}, foundTools)
74+
})
75+
2076
t.Run("OK", func(t *testing.T) {
2177
t.Parallel()
2278

@@ -34,10 +90,12 @@ func TestExpMcp(t *testing.T) {
3490
inv.Stdout = pty.Output()
3591
clitest.SetupConfig(t, client, root)
3692

37-
cmdDone := tGo(t, func() {
93+
cmdDone := make(chan struct{})
94+
go func() {
95+
defer close(cmdDone)
3896
err := inv.Run()
3997
assert.NoError(t, err)
40-
})
98+
}()
4199

42100
payload := `{"jsonrpc":"2.0","id":1,"method":"initialize"}`
43101
pty.WriteLine(payload)
@@ -75,4 +133,4 @@ func TestExpMcp(t *testing.T) {
75133
err := inv.Run()
76134
assert.ErrorContains(t, err, "your session has expired")
77135
})
78-
}
136+
}

mcp/mcp.go

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ type mcpOptions struct {
2121
out io.Writer
2222
instructions string
2323
logger *slog.Logger
24+
allowedTools []string
2425
}
2526

2627
// Option is a function that configures the MCP server.
@@ -54,6 +55,13 @@ func WithStdout(out io.Writer) Option {
5455
}
5556
}
5657

58+
// WithAllowedTools sets the allowed tools for the MCP server.
59+
func WithAllowedTools(tools []string) Option {
60+
return func(o *mcpOptions) {
61+
o.allowedTools = tools
62+
}
63+
}
64+
5765
// New creates a new MCP server with the given client and options.
5866
func New(ctx context.Context, client *codersdk.Client, opts ...Option) io.Closer {
5967
options := &mcpOptions{
@@ -74,14 +82,15 @@ func New(ctx context.Context, client *codersdk.Client, opts ...Option) io.Closer
7482

7583
logger := slog.Make(sloghuman.Sink(os.Stdout))
7684

77-
mcptools.RegisterCoderReportTask(mcpSrv, client, logger)
78-
mcptools.RegisterCoderWhoami(mcpSrv, client)
79-
mcptools.RegisterCoderListTemplates(mcpSrv, client)
80-
mcptools.RegisterCoderListWorkspaces(mcpSrv, client)
81-
mcptools.RegisterCoderGetWorkspace(mcpSrv, client)
82-
mcptools.RegisterCoderWorkspaceExec(mcpSrv, client)
83-
mcptools.RegisterCoderStartWorkspace(mcpSrv, client)
84-
mcptools.RegisterCoderStopWorkspace(mcpSrv, client)
85+
// Register tools based on the allowed list (if specified)
86+
reg := mcptools.AllTools()
87+
if len(options.allowedTools) > 0 {
88+
reg = reg.WithOnlyAllowed(options.allowedTools...)
89+
}
90+
reg.Register(mcpSrv, mcptools.ToolDeps{
91+
Client: client,
92+
Logger: &logger,
93+
})
8594

8695
srv := server.NewStdioServer(mcpSrv)
8796
srv.SetErrorLogger(log.New(options.out, "", log.LstdFlags))

0 commit comments

Comments
 (0)