Skip to content

Commit 1e0040b

Browse files
committed
add command filtering
1 parent c54620f commit 1e0040b

File tree

7 files changed

+234
-15
lines changed

7 files changed

+234
-15
lines changed

cli/exp_mcp.go

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,15 @@ import (
1414

1515
func (r *RootCmd) mcpCommand() *serpent.Command {
1616
var (
17-
client = new(codersdk.Client)
18-
instructions string
19-
allowedTools []string
17+
client = new(codersdk.Client)
18+
instructions string
19+
allowedTools []string
20+
allowedExecCommands []string
2021
)
2122
return &serpent.Command{
2223
Use: "mcp",
2324
Handler: func(inv *serpent.Invocation) error {
24-
return mcpHandler(inv, client, instructions, allowedTools)
25+
return mcpHandler(inv, client, instructions, allowedTools, allowedExecCommands)
2526
},
2627
Short: "Start an MCP server that can be used to interact with a Coder depoyment.",
2728
Middleware: serpent.Chain(
@@ -40,11 +41,17 @@ func (r *RootCmd) mcpCommand() *serpent.Command {
4041
Flag: "allowed-tools",
4142
Value: serpent.StringArrayOf(&allowedTools),
4243
},
44+
{
45+
Name: "allowed-exec-commands",
46+
Description: "Comma-separated list of allowed commands for workspace execution. If not specified, all commands are allowed.",
47+
Flag: "allowed-exec-commands",
48+
Value: serpent.StringArrayOf(&allowedExecCommands),
49+
},
4350
},
4451
}
4552
}
4653

47-
func mcpHandler(inv *serpent.Invocation, client *codersdk.Client, instructions string, allowedTools []string) error {
54+
func mcpHandler(inv *serpent.Invocation, client *codersdk.Client, instructions string, allowedTools []string, allowedExecCommands []string) error {
4855
ctx, cancel := context.WithCancel(inv.Context())
4956
defer cancel()
5057

@@ -64,6 +71,9 @@ func mcpHandler(inv *serpent.Invocation, client *codersdk.Client, instructions s
6471
if len(allowedTools) > 0 {
6572
cliui.Infof(inv.Stderr, "Allowed Tools : %v", allowedTools)
6673
}
74+
if len(allowedExecCommands) > 0 {
75+
cliui.Infof(inv.Stderr, "Allowed Exec Commands : %v", allowedExecCommands)
76+
}
6777
cliui.Infof(inv.Stderr, "Press Ctrl+C to stop the server")
6878

6979
// Capture the original stdin, stdout, and stderr.
@@ -88,6 +98,11 @@ func mcpHandler(inv *serpent.Invocation, client *codersdk.Client, instructions s
8898
options = append(options, codermcp.WithAllowedTools(allowedTools))
8999
}
90100

101+
// Add allowed exec commands option if specified
102+
if len(allowedExecCommands) > 0 {
103+
options = append(options, codermcp.WithAllowedExecCommands(allowedExecCommands))
104+
}
105+
91106
closer := codermcp.New(ctx, client, options...)
92107

93108
<-ctx.Done()
@@ -98,4 +113,4 @@ func mcpHandler(inv *serpent.Invocation, client *codersdk.Client, instructions s
98113
}
99114
}
100115
return nil
101-
}
116+
}

mcp/mcp.go

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@ import (
1717
)
1818

1919
type mcpOptions struct {
20-
in io.Reader
21-
out io.Writer
22-
instructions string
23-
logger *slog.Logger
24-
allowedTools []string
20+
in io.Reader
21+
out io.Writer
22+
instructions string
23+
logger *slog.Logger
24+
allowedTools []string
25+
allowedExecCommands []string
2526
}
2627

2728
// Option is a function that configures the MCP server.
@@ -62,6 +63,13 @@ func WithAllowedTools(tools []string) Option {
6263
}
6364
}
6465

66+
// WithAllowedExecCommands sets the allowed commands for workspace execution.
67+
func WithAllowedExecCommands(commands []string) Option {
68+
return func(o *mcpOptions) {
69+
o.allowedExecCommands = commands
70+
}
71+
}
72+
6573
// New creates a new MCP server with the given client and options.
6674
func New(ctx context.Context, client *codersdk.Client, opts ...Option) io.Closer {
6775
options := &mcpOptions{
@@ -88,8 +96,9 @@ func New(ctx context.Context, client *codersdk.Client, opts ...Option) io.Closer
8896
reg = reg.WithOnlyAllowed(options.allowedTools...)
8997
}
9098
reg.Register(mcpSrv, mcptools.ToolDeps{
91-
Client: client,
92-
Logger: &logger,
99+
Client: client,
100+
Logger: &logger,
101+
AllowedExecCommands: options.allowedExecCommands,
93102
})
94103

95104
srv := server.NewStdioServer(mcpSrv)

mcp/tools/command_validator.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
package mcptools
2+
3+
import (
4+
"strings"
5+
6+
"github.com/google/shlex"
7+
"golang.org/x/xerrors"
8+
)
9+
10+
// IsCommandAllowed checks if a command is in the allowed list.
11+
// It parses the command using shlex to correctly handle quoted arguments
12+
// and only checks the executable name (first part of the command).
13+
func IsCommandAllowed(command string, allowedCommands []string) (bool, error) {
14+
if len(allowedCommands) == 0 {
15+
// If no allowed commands are specified, all commands are allowed
16+
return true, nil
17+
}
18+
19+
// Parse the command to extract the executable name
20+
parts, err := shlex.Split(command)
21+
if err != nil {
22+
return false, xerrors.Errorf("failed to parse command: %w", err)
23+
}
24+
25+
if len(parts) == 0 {
26+
return false, xerrors.New("empty command")
27+
}
28+
29+
// The first part is the executable name
30+
executable := parts[0]
31+
32+
// Check if the executable is in the allowed list
33+
for _, allowed := range allowedCommands {
34+
if allowed == executable {
35+
return true, nil
36+
}
37+
}
38+
39+
// Build a helpful error message
40+
return false, xerrors.Errorf("command %q is not allowed. Allowed commands: %s",
41+
executable, strings.Join(allowedCommands, ", "))
42+
}

mcp/tools/command_validator_test.go

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
package mcptools_test
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/require"
7+
8+
mcptools "github.com/coder/coder/v2/mcp/tools"
9+
)
10+
11+
func TestIsCommandAllowed(t *testing.T) {
12+
tests := []struct {
13+
name string
14+
command string
15+
allowedCommands []string
16+
want bool
17+
wantErr bool
18+
errorMessage string
19+
}{
20+
{
21+
name: "empty allowed commands allows all",
22+
command: "ls -la",
23+
allowedCommands: []string{},
24+
want: true,
25+
wantErr: false,
26+
},
27+
{
28+
name: "allowed command",
29+
command: "ls -la",
30+
allowedCommands: []string{"ls", "cat", "grep"},
31+
want: true,
32+
wantErr: false,
33+
},
34+
{
35+
name: "disallowed command",
36+
command: "rm -rf /",
37+
allowedCommands: []string{"ls", "cat", "grep"},
38+
want: false,
39+
wantErr: true,
40+
errorMessage: "not allowed",
41+
},
42+
{
43+
name: "command with quotes",
44+
command: "echo \"hello world\"",
45+
allowedCommands: []string{"echo", "cat", "grep"},
46+
want: true,
47+
wantErr: false,
48+
},
49+
{
50+
name: "command with path",
51+
command: "/bin/ls -la",
52+
allowedCommands: []string{"/bin/ls", "cat", "grep"},
53+
want: true,
54+
wantErr: false,
55+
},
56+
{
57+
name: "empty command",
58+
command: "",
59+
allowedCommands: []string{"ls", "cat", "grep"},
60+
want: false,
61+
wantErr: true,
62+
errorMessage: "empty command",
63+
},
64+
}
65+
66+
for _, tt := range tests {
67+
t.Run(tt.name, func(t *testing.T) {
68+
got, err := mcptools.IsCommandAllowed(tt.command, tt.allowedCommands)
69+
if tt.wantErr {
70+
require.Error(t, err)
71+
if tt.errorMessage != "" {
72+
require.Contains(t, err.Error(), tt.errorMessage)
73+
}
74+
} else {
75+
require.NoError(t, err)
76+
}
77+
require.Equal(t, tt.want, got)
78+
})
79+
}
80+
}

mcp/tools/tools_coder.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,15 @@ func handleCoderWorkspaceExec(deps ToolDeps) mcpserver.ToolHandlerFunc {
194194
return nil, xerrors.New("command is required")
195195
}
196196

197+
// Validate the command if allowed commands are specified
198+
allowed, err := IsCommandAllowed(command, deps.AllowedExecCommands)
199+
if err != nil {
200+
return nil, err
201+
}
202+
if !allowed {
203+
return nil, xerrors.Errorf("command not allowed: %s", command)
204+
}
205+
197206
// Attempt to fetch the workspace. We may get a UUID or a name, so try to
198207
// handle both.
199208
ws, err := getWorkspaceByIDOrOwnerName(ctx, deps.Client, wsArg)

mcp/tools/tools_coder_test.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,69 @@ func TestCoderTools(t *testing.T) {
198198
testutil.RequireJSONEq(t, expected, actual)
199199
})
200200

201+
t.Run("tool_and_command_restrictions", func(t *testing.T) {
202+
// Given: a restricted MCP server with only allowed tools and commands
203+
restrictedPty := ptytest.New(t)
204+
allowedTools := []string{"coder_workspace_exec"}
205+
allowedCommands := []string{"echo", "ls"}
206+
restrictedMCPSrv, closeRestrictedSrv := startTestMCPServer(ctx, t, restrictedPty.Input(), restrictedPty.Output())
207+
t.Cleanup(func() {
208+
_ = closeRestrictedSrv()
209+
})
210+
mcptools.AllTools().
211+
WithOnlyAllowed(allowedTools...).
212+
Register(restrictedMCPSrv, mcptools.ToolDeps{
213+
Client: memberClient,
214+
Logger: &logger,
215+
AllowedExecCommands: allowedCommands,
216+
})
217+
218+
// When: the tools/list command is called
219+
toolsListCmd := makeJSONRPCRequest(t, "tools/list", "", nil)
220+
restrictedPty.WriteLine(toolsListCmd)
221+
_ = restrictedPty.ReadLine(ctx) // skip the echo
222+
223+
// Then: the response is a list of only the allowed tools.
224+
toolsListResponse := restrictedPty.ReadLine(ctx)
225+
require.Contains(t, toolsListResponse, "coder_workspace_exec")
226+
require.NotContains(t, toolsListResponse, "coder_whoami")
227+
228+
// When: a disallowed tool is called
229+
disallowedToolCmd := makeJSONRPCRequest(t, "tools/call", "coder_whoami", map[string]any{})
230+
restrictedPty.WriteLine(disallowedToolCmd)
231+
_ = restrictedPty.ReadLine(ctx) // skip the echo
232+
233+
// Then: the response is an error indicating the tool is not available.
234+
disallowedToolResponse := restrictedPty.ReadLine(ctx)
235+
require.Contains(t, disallowedToolResponse, "error")
236+
require.Contains(t, disallowedToolResponse, "not found")
237+
238+
// When: an allowed exec command is called
239+
randString := testutil.GetRandomName(t)
240+
allowedCmd := makeJSONRPCRequest(t, "tools/call", "coder_workspace_exec", map[string]any{
241+
"workspace": r.Workspace.ID.String(),
242+
"command": "echo " + randString,
243+
})
244+
245+
// Then: the response is the output of the command.
246+
restrictedPty.WriteLine(allowedCmd)
247+
_ = restrictedPty.ReadLine(ctx) // skip the echo
248+
actual := restrictedPty.ReadLine(ctx)
249+
require.Contains(t, actual, randString)
250+
251+
// When: a disallowed exec command is called
252+
disallowedCmd := makeJSONRPCRequest(t, "tools/call", "coder_workspace_exec", map[string]any{
253+
"workspace": r.Workspace.ID.String(),
254+
"command": "evil --hax",
255+
})
256+
257+
// Then: the response is an error indicating the command is not allowed.
258+
restrictedPty.WriteLine(disallowedCmd)
259+
_ = restrictedPty.ReadLine(ctx) // skip the echo
260+
errorResponse := restrictedPty.ReadLine(ctx)
261+
require.Contains(t, errorResponse, `command \"evil\" is not allowed`)
262+
})
263+
201264
t.Run("coder_start_workspace", func(t *testing.T) {
202265
// Given: a separate workspace in the stopped state
203266
stopWs := dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{

mcp/tools/tools_registry.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,9 @@ var _ ToolAdder = (*server.MCPServer)(nil)
9191

9292
// ToolDeps contains all dependencies needed by tool handlers
9393
type ToolDeps struct {
94-
Client *codersdk.Client
95-
Logger *slog.Logger
94+
Client *codersdk.Client
95+
Logger *slog.Logger
96+
AllowedExecCommands []string
9697
}
9798

9899
// ToolHandler associates a tool with its handler creation function

0 commit comments

Comments
 (0)