Skip to content

Commit 15720af

Browse files
committed
Add coder_workspace_write_file MCP tool
1 parent 2828b76 commit 15720af

File tree

7 files changed

+342
-3
lines changed

7 files changed

+342
-3
lines changed

agent/api.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ func (a *agent) apiHandler() http.Handler {
6161
r.Get("/api/v0/netcheck", a.HandleNetcheck)
6262
r.Post("/api/v0/list-directory", a.HandleLS)
6363
r.Get("/api/v0/read-file", a.HandleReadFile)
64+
r.Post("/api/v0/write-file", a.HandleWriteFile)
6465
r.Get("/debug/logs", a.HandleHTTPDebugLogs)
6566
r.Get("/debug/magicsock", a.HandleHTTPDebugMagicsock)
6667
r.Get("/debug/magicsock/debug-logging/{state}", a.HandleHTTPMagicsockDebugLoggingState)

agent/files.go

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@ package agent
33
import (
44
"context"
55
"errors"
6+
"fmt"
67
"io"
78
"mime"
89
"net/http"
910
"os"
1011
"path/filepath"
1112
"strconv"
13+
"strings"
1214

1315
"golang.org/x/xerrors"
1416

@@ -94,3 +96,71 @@ func (a *agent) streamFile(ctx context.Context, rw http.ResponseWriter, path str
9496

9597
return 0, nil
9698
}
99+
100+
func (a *agent) HandleWriteFile(rw http.ResponseWriter, r *http.Request) {
101+
ctx := r.Context()
102+
103+
query := r.URL.Query()
104+
parser := httpapi.NewQueryParamParser().RequiredNotEmpty("path")
105+
path := parser.String(query, "", "path")
106+
parser.ErrorExcessParams(query)
107+
if len(parser.Errors) > 0 {
108+
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
109+
Message: "Query parameters have invalid values.",
110+
Validations: parser.Errors,
111+
})
112+
return
113+
}
114+
115+
status, err := a.writeFile(ctx, r, path)
116+
if err != nil {
117+
httpapi.Write(ctx, rw, status, codersdk.Response{
118+
Message: err.Error(),
119+
})
120+
return
121+
}
122+
123+
httpapi.Write(ctx, rw, http.StatusOK, codersdk.Response{
124+
Message: fmt.Sprintf("Successfully wrote to %q", path),
125+
})
126+
}
127+
128+
func (a *agent) writeFile(ctx context.Context, r *http.Request, path string) (int, error) {
129+
if !filepath.IsAbs(path) {
130+
return http.StatusBadRequest, xerrors.Errorf("file path must be absolute: %q", path)
131+
}
132+
133+
dir := filepath.Dir(path)
134+
err := a.filesystem.MkdirAll(dir, 0o755)
135+
if err != nil {
136+
status := http.StatusInternalServerError
137+
switch {
138+
case errors.Is(err, os.ErrPermission):
139+
status = http.StatusForbidden
140+
case strings.Contains(err.Error(), "not a directory"):
141+
status = http.StatusBadRequest
142+
}
143+
return status, xerrors.Errorf("failed to create directory %q: %w", dir, err)
144+
}
145+
146+
f, err := a.filesystem.Create(path)
147+
if err != nil {
148+
status := http.StatusInternalServerError
149+
switch {
150+
case errors.Is(err, os.ErrPermission):
151+
status = http.StatusForbidden
152+
case strings.Contains(err.Error(), "is a directory") ||
153+
strings.Contains(err.Error(), "not a directory"):
154+
status = http.StatusBadRequest
155+
}
156+
return status, xerrors.Errorf("failed to create file %q: %w", path, err)
157+
}
158+
defer f.Close()
159+
160+
_, err = io.Copy(f, r.Body)
161+
if err != nil && !errors.Is(err, io.EOF) && ctx.Err() == nil {
162+
a.logger.Error(ctx, "workspace agent write file", slog.Error(err))
163+
}
164+
165+
return 0, nil
166+
}

agent/files_test.go

Lines changed: 137 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@ import (
55
"net/http"
66
"os"
77
"path/filepath"
8+
"slices"
89
"testing"
910

1011
"github.com/spf13/afero"
1112
"github.com/stretchr/testify/require"
13+
"golang.org/x/xerrors"
1214

1315
"github.com/coder/coder/v2/agent"
1416
"github.com/coder/coder/v2/agent/agenttest"
@@ -19,23 +21,57 @@ import (
1921

2022
type testFs struct {
2123
afero.Fs
22-
deny string
24+
deny []string
2325
}
2426

25-
func newTestFs(base afero.Fs, deny string) *testFs {
27+
func newTestFs(base afero.Fs, deny ...string) *testFs {
2628
return &testFs{
2729
Fs: base,
2830
deny: deny,
2931
}
3032
}
3133

3234
func (fs *testFs) Open(name string) (afero.File, error) {
33-
if name == fs.deny {
35+
if slices.Contains(fs.deny, name) {
3436
return nil, os.ErrPermission
3537
}
3638
return fs.Fs.Open(name)
3739
}
3840

41+
func (fs *testFs) Create(name string) (afero.File, error) {
42+
if slices.Contains(fs.deny, name) {
43+
return nil, os.ErrPermission
44+
}
45+
// Unlike os, afero lets you create files where directories already exist and
46+
// lets you nest them underneath files, somehow.
47+
stat, err := fs.Fs.Stat(name)
48+
if err == nil && stat.IsDir() {
49+
return nil, xerrors.New("is a directory")
50+
}
51+
stat, err = fs.Fs.Stat(filepath.Dir(name))
52+
if err == nil && !stat.IsDir() {
53+
return nil, xerrors.New("not a directory")
54+
}
55+
return fs.Fs.Create(name)
56+
}
57+
58+
func (fs *testFs) MkdirAll(name string, mode os.FileMode) error {
59+
if slices.Contains(fs.deny, name) {
60+
return os.ErrPermission
61+
}
62+
// Unlike os, afero lets you create directories where files already exist and
63+
// lets you nest them underneath files somehow.
64+
stat, err := fs.Fs.Stat(filepath.Dir(name))
65+
if err == nil && !stat.IsDir() {
66+
return xerrors.New("not a directory")
67+
}
68+
stat, err = fs.Fs.Stat(name)
69+
if err == nil && !stat.IsDir() {
70+
return xerrors.New("not a directory")
71+
}
72+
return fs.Fs.MkdirAll(name, mode)
73+
}
74+
3975
func TestReadFile(t *testing.T) {
4076
t.Parallel()
4177

@@ -200,3 +236,101 @@ func TestReadFile(t *testing.T) {
200236
})
201237
}
202238
}
239+
240+
func TestWriteFile(t *testing.T) {
241+
t.Parallel()
242+
243+
tmpdir := os.TempDir()
244+
noPermsFilePath := filepath.Join(tmpdir, "no-perms-file")
245+
noPermsDirPath := filepath.Join(tmpdir, "no-perms-dir")
246+
//nolint:dogsled
247+
conn, _, _, fs, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, opts *agent.Options) {
248+
opts.Filesystem = newTestFs(opts.Filesystem, noPermsFilePath, noPermsDirPath)
249+
})
250+
251+
dirPath := filepath.Join(tmpdir, "directory")
252+
err := fs.MkdirAll(dirPath, 0o755)
253+
require.NoError(t, err)
254+
255+
filePath := filepath.Join(tmpdir, "file")
256+
err = afero.WriteFile(fs, filePath, []byte("content"), 0o644)
257+
require.NoError(t, err)
258+
259+
tests := []struct {
260+
name string
261+
path string
262+
bytes []byte
263+
errCode int
264+
error string
265+
}{
266+
{
267+
name: "NoPath",
268+
path: "",
269+
errCode: http.StatusBadRequest,
270+
error: "\"path\" is required",
271+
},
272+
{
273+
name: "RelativePath",
274+
path: "./relative",
275+
errCode: http.StatusBadRequest,
276+
error: "file path must be absolute",
277+
},
278+
{
279+
name: "RelativePath",
280+
path: "also-relative",
281+
errCode: http.StatusBadRequest,
282+
error: "file path must be absolute",
283+
},
284+
{
285+
name: "NonExistent",
286+
path: filepath.Join(tmpdir, "/nested/does-not-exist"),
287+
bytes: []byte("now it does exist"),
288+
},
289+
{
290+
name: "IsDir",
291+
path: dirPath,
292+
errCode: http.StatusBadRequest,
293+
error: "is a directory",
294+
},
295+
{
296+
name: "IsNotDir",
297+
path: filepath.Join(filePath, "file2"),
298+
errCode: http.StatusBadRequest,
299+
error: "not a directory",
300+
},
301+
{
302+
name: "NoPermissionsFile",
303+
path: noPermsFilePath,
304+
errCode: http.StatusForbidden,
305+
error: "permission denied",
306+
},
307+
{
308+
name: "NoPermissionsDir",
309+
path: filepath.Join(noPermsDirPath, "within-no-perm-dir"),
310+
errCode: http.StatusForbidden,
311+
error: "permission denied",
312+
},
313+
}
314+
315+
for _, tt := range tests {
316+
t.Run(tt.name, func(t *testing.T) {
317+
t.Parallel()
318+
319+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
320+
defer cancel()
321+
322+
err := conn.WriteFile(ctx, tt.path, tt.bytes)
323+
if tt.errCode != 0 {
324+
require.Error(t, err)
325+
cerr := coderdtest.SDKError(t, err)
326+
require.Equal(t, tt.errCode, cerr.StatusCode())
327+
require.Contains(t, cerr.Error(), tt.error)
328+
} else {
329+
require.NoError(t, err)
330+
b, err := afero.ReadFile(fs, tt.path)
331+
require.NoError(t, err)
332+
require.Equal(t, tt.bytes, b)
333+
}
334+
})
335+
}
336+
}

codersdk/toolsdk/toolsdk.go

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ const (
4141
ToolNameChatGPTSearch = "search"
4242
ToolNameChatGPTFetch = "fetch"
4343
ToolNameWorkspaceReadFile = "coder_workspace_read_file"
44+
ToolNameWorkspaceWriteFile = "coder_workspace_write_file"
4445
)
4546

4647
func NewDeps(client *codersdk.Client, opts ...func(*Deps)) (Deps, error) {
@@ -209,6 +210,7 @@ var All = []GenericTool{
209210
ChatGPTSearch.Generic(),
210211
ChatGPTFetch.Generic(),
211212
WorkspaceReadFile.Generic(),
213+
WorkspaceWriteFile.Generic(),
212214
}
213215

214216
type ReportTaskArgs struct {
@@ -1444,3 +1446,74 @@ var WorkspaceReadFile = Tool[WorkspaceReadFileArgs, WorkspaceReadFileResponse]{
14441446
return WorkspaceReadFileResponse{Content: bytes, MimeType: mimeType}, nil
14451447
},
14461448
}
1449+
1450+
type WorkspaceWriteFileArgs struct {
1451+
Workspace string `json:"workspace"`
1452+
Path string `json:"path"`
1453+
Content []byte `json:"content"`
1454+
}
1455+
1456+
var WorkspaceWriteFile = Tool[WorkspaceWriteFileArgs, codersdk.Response]{
1457+
Tool: aisdk.Tool{
1458+
Name: ToolNameWorkspaceWriteFile,
1459+
Description: `Write a file in a workspace.`,
1460+
Schema: aisdk.Schema{
1461+
Properties: map[string]any{
1462+
"workspace": map[string]any{
1463+
"type": "string",
1464+
"description": "The workspace name in the format [owner/]workspace[.agent]. If an owner is not specified, the authenticated user is used.",
1465+
},
1466+
"path": map[string]any{
1467+
"type": "string",
1468+
"description": "The absolute path of the file to write in the workspace.",
1469+
},
1470+
"content": map[string]any{
1471+
"type": "string",
1472+
"description": "The base64-encoded bytes to write to the file.",
1473+
},
1474+
},
1475+
Required: []string{"path", "workspace", "content"},
1476+
},
1477+
},
1478+
UserClientOptional: true,
1479+
Handler: func(ctx context.Context, deps Deps, args WorkspaceWriteFileArgs) (codersdk.Response, error) {
1480+
workspaceName := NormalizeWorkspaceInput(args.Workspace)
1481+
_, workspaceAgent, err := findWorkspaceAndAgent(ctx, deps.coderClient, workspaceName)
1482+
if err != nil {
1483+
return codersdk.Response{}, xerrors.Errorf("failed to find workspace: %w", err)
1484+
}
1485+
1486+
// Wait for agent to be ready.
1487+
if err := cliui.Agent(ctx, io.Discard, workspaceAgent.ID, cliui.AgentOptions{
1488+
FetchInterval: 0,
1489+
Fetch: deps.coderClient.WorkspaceAgent,
1490+
FetchLogs: deps.coderClient.WorkspaceAgentLogsAfter,
1491+
Wait: true, // Always wait for startup scripts
1492+
}); err != nil {
1493+
return codersdk.Response{}, xerrors.Errorf("agent not ready: %w", err)
1494+
}
1495+
1496+
wsClient := workspacesdk.New(deps.coderClient)
1497+
1498+
conn, err := wsClient.DialAgent(ctx, workspaceAgent.ID, &workspacesdk.DialAgentOptions{
1499+
BlockEndpoints: false,
1500+
})
1501+
if err != nil {
1502+
return codersdk.Response{}, xerrors.Errorf("failed to dial agent: %w", err)
1503+
}
1504+
defer conn.Close()
1505+
1506+
if !conn.AwaitReachable(ctx) {
1507+
return codersdk.Response{}, xerrors.New("agent connection not reachable")
1508+
}
1509+
1510+
err = conn.WriteFile(ctx, args.Path, args.Content)
1511+
if err != nil {
1512+
return codersdk.Response{}, err
1513+
}
1514+
1515+
return codersdk.Response{
1516+
Message: "File written successfully.",
1517+
}, nil
1518+
},
1519+
}

codersdk/toolsdk/toolsdk_test.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -540,6 +540,30 @@ func TestTools(t *testing.T) {
540540
})
541541
}
542542
})
543+
544+
t.Run("WorkspaceWriteFile", func(t *testing.T) {
545+
t.Parallel()
546+
547+
client, workspace, agentToken := setupWorkspaceForAgent(t)
548+
fs := afero.NewMemMapFs()
549+
_ = agenttest.New(t, client.URL, agentToken, func(opts *agent.Options) {
550+
opts.Filesystem = fs
551+
})
552+
coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).Wait()
553+
tb, err := toolsdk.NewDeps(client)
554+
require.NoError(t, err)
555+
556+
_, err = testTool(t, toolsdk.WorkspaceWriteFile, tb, toolsdk.WorkspaceWriteFileArgs{
557+
Workspace: workspace.Name,
558+
Path: "/test/some/path",
559+
Content: []byte("content"),
560+
})
561+
require.NoError(t, err)
562+
563+
b, err := afero.ReadFile(fs, "/test/some/path")
564+
require.NoError(t, err)
565+
require.Equal(t, []byte("content"), b)
566+
})
543567
}
544568

545569
// TestedTools keeps track of which tools have been tested.

0 commit comments

Comments
 (0)