Skip to content

Commit 13e42f0

Browse files
committed
Add coder_workspace_write_file MCP tool
1 parent 4bf63b4 commit 13e42f0

File tree

7 files changed

+322
-0
lines changed

7 files changed

+322
-0
lines changed

agent/api.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ func (a *agent) apiHandler() http.Handler {
6161
r.Get("/api/v0/netcheck", a.HandleNetcheck)
6262
r.Post("/api/v0/list-directory", a.HandleLS)
6363
r.Get("/api/v0/read-file", a.HandleReadFile)
64+
r.Post("/api/v0/write-file", a.HandleWriteFile)
6465
r.Get("/debug/logs", a.HandleHTTPDebugLogs)
6566
r.Get("/debug/magicsock", a.HandleHTTPDebugMagicsock)
6667
r.Get("/debug/magicsock/debug-logging/{state}", a.HandleHTTPMagicsockDebugLoggingState)

agent/files.go

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@ package agent
33
import (
44
"context"
55
"errors"
6+
"fmt"
67
"io"
78
"mime"
89
"net/http"
910
"os"
1011
"path/filepath"
1112
"strconv"
13+
"strings"
1214

1315
"golang.org/x/xerrors"
1416

@@ -94,3 +96,71 @@ func (a *agent) streamFile(ctx context.Context, rw http.ResponseWriter, path str
9496

9597
return 0, nil
9698
}
99+
100+
func (a *agent) HandleWriteFile(rw http.ResponseWriter, r *http.Request) {
101+
ctx := r.Context()
102+
103+
query := r.URL.Query()
104+
parser := httpapi.NewQueryParamParser().RequiredNotEmpty("path")
105+
path := parser.String(query, "", "path")
106+
parser.ErrorExcessParams(query)
107+
if len(parser.Errors) > 0 {
108+
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
109+
Message: "Query parameters have invalid values.",
110+
Validations: parser.Errors,
111+
})
112+
return
113+
}
114+
115+
status, err := a.writeFile(ctx, r, path)
116+
if err != nil {
117+
httpapi.Write(ctx, rw, status, codersdk.Response{
118+
Message: err.Error(),
119+
})
120+
return
121+
}
122+
123+
httpapi.Write(ctx, rw, http.StatusOK, codersdk.Response{
124+
Message: fmt.Sprintf("Successfully wrote to %q", path),
125+
})
126+
}
127+
128+
func (a *agent) writeFile(ctx context.Context, r *http.Request, path string) (int, error) {
129+
if !filepath.IsAbs(path) {
130+
return http.StatusBadRequest, xerrors.Errorf("file path must be absolute: %q", path)
131+
}
132+
133+
dir := filepath.Dir(path)
134+
err := a.filesystem.MkdirAll(dir, 0o755)
135+
if err != nil {
136+
status := http.StatusInternalServerError
137+
switch {
138+
case errors.Is(err, os.ErrPermission):
139+
status = http.StatusForbidden
140+
case strings.Contains(err.Error(), "not a directory"):
141+
status = http.StatusBadRequest
142+
}
143+
return status, err
144+
}
145+
146+
f, err := a.filesystem.Create(path)
147+
if err != nil {
148+
status := http.StatusInternalServerError
149+
switch {
150+
case errors.Is(err, os.ErrPermission):
151+
status = http.StatusForbidden
152+
case strings.Contains(err.Error(), "is a directory") ||
153+
strings.Contains(err.Error(), "not a directory"):
154+
status = http.StatusBadRequest
155+
}
156+
return status, err
157+
}
158+
defer f.Close()
159+
160+
_, err = io.Copy(f, r.Body)
161+
if err != nil && !errors.Is(err, io.EOF) && ctx.Err() == nil {
162+
a.logger.Error(ctx, "workspace agent write file", slog.Error(err))
163+
}
164+
165+
return 0, nil
166+
}

agent/files_test.go

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package agent_test
22

33
import (
4+
"bytes"
45
"context"
56
"io"
67
"net/http"
@@ -10,6 +11,7 @@ import (
1011

1112
"github.com/spf13/afero"
1213
"github.com/stretchr/testify/require"
14+
"golang.org/x/xerrors"
1315

1416
"github.com/coder/coder/v2/agent"
1517
"github.com/coder/coder/v2/agent/agenttest"
@@ -38,6 +40,40 @@ func (fs *testFs) Open(name string) (afero.File, error) {
3840
return fs.Fs.Open(name)
3941
}
4042

43+
func (fs *testFs) Create(name string) (afero.File, error) {
44+
if err := fs.intercept("create", name); err != nil {
45+
return nil, err
46+
}
47+
// Unlike os, afero lets you create files where directories already exist and
48+
// lets you nest them underneath files, somehow.
49+
stat, err := fs.Fs.Stat(name)
50+
if err == nil && stat.IsDir() {
51+
return nil, xerrors.New("is a directory")
52+
}
53+
stat, err = fs.Fs.Stat(filepath.Dir(name))
54+
if err == nil && !stat.IsDir() {
55+
return nil, xerrors.New("not a directory")
56+
}
57+
return fs.Fs.Create(name)
58+
}
59+
60+
func (fs *testFs) MkdirAll(name string, mode os.FileMode) error {
61+
if err := fs.intercept("mkdirall", name); err != nil {
62+
return err
63+
}
64+
// Unlike os, afero lets you create directories where files already exist and
65+
// lets you nest them underneath files somehow.
66+
stat, err := fs.Fs.Stat(filepath.Dir(name))
67+
if err == nil && !stat.IsDir() {
68+
return xerrors.New("not a directory")
69+
}
70+
stat, err = fs.Fs.Stat(name)
71+
if err == nil && !stat.IsDir() {
72+
return xerrors.New("not a directory")
73+
}
74+
return fs.Fs.MkdirAll(name, mode)
75+
}
76+
4177
func TestReadFile(t *testing.T) {
4278
t.Parallel()
4379

@@ -214,3 +250,107 @@ func TestReadFile(t *testing.T) {
214250
})
215251
}
216252
}
253+
254+
func TestWriteFile(t *testing.T) {
255+
t.Parallel()
256+
257+
tmpdir := os.TempDir()
258+
noPermsFilePath := filepath.Join(tmpdir, "no-perms-file")
259+
noPermsDirPath := filepath.Join(tmpdir, "no-perms-dir")
260+
//nolint:dogsled
261+
conn, _, _, fs, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, opts *agent.Options) {
262+
opts.Filesystem = newTestFs(opts.Filesystem, func(call, file string) error {
263+
if file == noPermsFilePath || file == noPermsDirPath {
264+
return os.ErrPermission
265+
}
266+
return nil
267+
})
268+
})
269+
270+
dirPath := filepath.Join(tmpdir, "directory")
271+
err := fs.MkdirAll(dirPath, 0o755)
272+
require.NoError(t, err)
273+
274+
filePath := filepath.Join(tmpdir, "file")
275+
err = afero.WriteFile(fs, filePath, []byte("content"), 0o644)
276+
require.NoError(t, err)
277+
278+
tests := []struct {
279+
name string
280+
path string
281+
bytes []byte
282+
errCode int
283+
error string
284+
}{
285+
{
286+
name: "NoPath",
287+
path: "",
288+
errCode: http.StatusBadRequest,
289+
error: "\"path\" is required",
290+
},
291+
{
292+
name: "RelativePath",
293+
path: "./relative",
294+
errCode: http.StatusBadRequest,
295+
error: "file path must be absolute",
296+
},
297+
{
298+
name: "RelativePath",
299+
path: "also-relative",
300+
errCode: http.StatusBadRequest,
301+
error: "file path must be absolute",
302+
},
303+
{
304+
name: "NonExistent",
305+
path: filepath.Join(tmpdir, "/nested/does-not-exist"),
306+
bytes: []byte("now it does exist"),
307+
},
308+
{
309+
name: "IsDir",
310+
path: dirPath,
311+
errCode: http.StatusBadRequest,
312+
error: "is a directory",
313+
},
314+
{
315+
name: "IsNotDir",
316+
path: filepath.Join(filePath, "file2"),
317+
errCode: http.StatusBadRequest,
318+
error: "not a directory",
319+
},
320+
{
321+
name: "NoPermissionsFile",
322+
path: noPermsFilePath,
323+
errCode: http.StatusForbidden,
324+
error: "permission denied",
325+
},
326+
{
327+
name: "NoPermissionsDir",
328+
path: filepath.Join(noPermsDirPath, "within-no-perm-dir"),
329+
errCode: http.StatusForbidden,
330+
error: "permission denied",
331+
},
332+
}
333+
334+
for _, tt := range tests {
335+
t.Run(tt.name, func(t *testing.T) {
336+
t.Parallel()
337+
338+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
339+
defer cancel()
340+
341+
reader := bytes.NewReader(tt.bytes)
342+
err := conn.WriteFile(ctx, tt.path, reader)
343+
if tt.errCode != 0 {
344+
require.Error(t, err)
345+
cerr := coderdtest.SDKError(t, err)
346+
require.Contains(t, cerr.Error(), tt.error)
347+
require.Equal(t, tt.errCode, cerr.StatusCode())
348+
} else {
349+
require.NoError(t, err)
350+
b, err := afero.ReadFile(fs, tt.path)
351+
require.NoError(t, err)
352+
require.Equal(t, tt.bytes, b)
353+
}
354+
})
355+
}
356+
}

codersdk/toolsdk/toolsdk.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ const (
4343
ToolNameChatGPTSearch = "search"
4444
ToolNameChatGPTFetch = "fetch"
4545
ToolNameWorkspaceReadFile = "coder_workspace_read_file"
46+
ToolNameWorkspaceWriteFile = "coder_workspace_write_file"
4647
)
4748

4849
func NewDeps(client *codersdk.Client, opts ...func(*Deps)) (Deps, error) {
@@ -211,6 +212,7 @@ var All = []GenericTool{
211212
ChatGPTSearch.Generic(),
212213
ChatGPTFetch.Generic(),
213214
WorkspaceReadFile.Generic(),
215+
WorkspaceWriteFile.Generic(),
214216
}
215217

216218
type ReportTaskArgs struct {
@@ -1441,6 +1443,54 @@ var WorkspaceReadFile = Tool[WorkspaceReadFileArgs, WorkspaceReadFileResponse]{
14411443
},
14421444
}
14431445

1446+
type WorkspaceWriteFileArgs struct {
1447+
Workspace string `json:"workspace"`
1448+
Path string `json:"path"`
1449+
Content []byte `json:"content"`
1450+
}
1451+
1452+
var WorkspaceWriteFile = Tool[WorkspaceWriteFileArgs, codersdk.Response]{
1453+
Tool: aisdk.Tool{
1454+
Name: ToolNameWorkspaceWriteFile,
1455+
Description: `Write a file in a workspace.`,
1456+
Schema: aisdk.Schema{
1457+
Properties: map[string]any{
1458+
"workspace": map[string]any{
1459+
"type": "string",
1460+
"description": "The workspace name in the format [owner/]workspace[.agent]. If an owner is not specified, the authenticated user is used.",
1461+
},
1462+
"path": map[string]any{
1463+
"type": "string",
1464+
"description": "The absolute path of the file to write in the workspace.",
1465+
},
1466+
"content": map[string]any{
1467+
"type": "string",
1468+
"description": "The base64-encoded bytes to write to the file.",
1469+
},
1470+
},
1471+
Required: []string{"path", "workspace", "content"},
1472+
},
1473+
},
1474+
UserClientOptional: true,
1475+
Handler: func(ctx context.Context, deps Deps, args WorkspaceWriteFileArgs) (codersdk.Response, error) {
1476+
conn, err := newAgentConn(ctx, deps.coderClient, args.Workspace)
1477+
if err != nil {
1478+
return codersdk.Response{}, err
1479+
}
1480+
defer conn.Close()
1481+
1482+
reader := bytes.NewReader(args.Content)
1483+
err = conn.WriteFile(ctx, args.Path, reader)
1484+
if err != nil {
1485+
return codersdk.Response{}, err
1486+
}
1487+
1488+
return codersdk.Response{
1489+
Message: "File written successfully.",
1490+
}, nil
1491+
},
1492+
}
1493+
14441494
// NormalizeWorkspaceInput converts workspace name input to standard format.
14451495
// Handles the following input formats:
14461496
// - workspace → workspace

codersdk/toolsdk/toolsdk_test.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,30 @@ func TestTools(t *testing.T) {
555555
})
556556
}
557557
})
558+
559+
t.Run("WorkspaceWriteFile", func(t *testing.T) {
560+
t.Parallel()
561+
562+
client, workspace, agentToken := setupWorkspaceForAgent(t)
563+
fs := afero.NewMemMapFs()
564+
_ = agenttest.New(t, client.URL, agentToken, func(opts *agent.Options) {
565+
opts.Filesystem = fs
566+
})
567+
coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).Wait()
568+
tb, err := toolsdk.NewDeps(client)
569+
require.NoError(t, err)
570+
571+
_, err = testTool(t, toolsdk.WorkspaceWriteFile, tb, toolsdk.WorkspaceWriteFileArgs{
572+
Workspace: workspace.Name,
573+
Path: "/test/some/path",
574+
Content: []byte("content"),
575+
})
576+
require.NoError(t, err)
577+
578+
b, err := afero.ReadFile(fs, "/test/some/path")
579+
require.NoError(t, err)
580+
require.Equal(t, []byte("content"), b)
581+
})
558582
}
559583

560584
// TestedTools keeps track of which tools have been tested.

0 commit comments

Comments
 (0)