Skip to content

Commit d5a02d5

Browse files
authored
feat: add coder_workspace_write_file MCP tool (#19591)
1 parent eec6c8c commit d5a02d5

File tree

7 files changed

+346
-2
lines changed

7 files changed

+346
-2
lines changed

agent/api.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ func (a *agent) apiHandler() http.Handler {
6161
r.Get("/api/v0/netcheck", a.HandleNetcheck)
6262
r.Post("/api/v0/list-directory", a.HandleLS)
6363
r.Get("/api/v0/read-file", a.HandleReadFile)
64+
r.Post("/api/v0/write-file", a.HandleWriteFile)
6465
r.Get("/debug/logs", a.HandleHTTPDebugLogs)
6566
r.Get("/debug/magicsock", a.HandleHTTPDebugMagicsock)
6667
r.Get("/debug/magicsock/debug-logging/{state}", a.HandleHTTPMagicsockDebugLoggingState)

agent/files.go

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@ package agent
33
import (
44
"context"
55
"errors"
6+
"fmt"
67
"io"
78
"mime"
89
"net/http"
910
"os"
1011
"path/filepath"
1112
"strconv"
13+
"syscall"
1214

1315
"golang.org/x/xerrors"
1416

@@ -17,6 +19,8 @@ import (
1719
"github.com/coder/coder/v2/codersdk"
1820
)
1921

22+
type HTTPResponseCode = int
23+
2024
func (a *agent) HandleReadFile(rw http.ResponseWriter, r *http.Request) {
2125
ctx := r.Context()
2226

@@ -43,7 +47,7 @@ func (a *agent) HandleReadFile(rw http.ResponseWriter, r *http.Request) {
4347
}
4448
}
4549

46-
func (a *agent) streamFile(ctx context.Context, rw http.ResponseWriter, path string, offset, limit int64) (int, error) {
50+
func (a *agent) streamFile(ctx context.Context, rw http.ResponseWriter, path string, offset, limit int64) (HTTPResponseCode, error) {
4751
if !filepath.IsAbs(path) {
4852
return http.StatusBadRequest, xerrors.Errorf("file path must be absolute: %q", path)
4953
}
@@ -94,3 +98,70 @@ func (a *agent) streamFile(ctx context.Context, rw http.ResponseWriter, path str
9498

9599
return 0, nil
96100
}
101+
102+
func (a *agent) HandleWriteFile(rw http.ResponseWriter, r *http.Request) {
103+
ctx := r.Context()
104+
105+
query := r.URL.Query()
106+
parser := httpapi.NewQueryParamParser().RequiredNotEmpty("path")
107+
path := parser.String(query, "", "path")
108+
parser.ErrorExcessParams(query)
109+
if len(parser.Errors) > 0 {
110+
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
111+
Message: "Query parameters have invalid values.",
112+
Validations: parser.Errors,
113+
})
114+
return
115+
}
116+
117+
status, err := a.writeFile(ctx, r, path)
118+
if err != nil {
119+
httpapi.Write(ctx, rw, status, codersdk.Response{
120+
Message: err.Error(),
121+
})
122+
return
123+
}
124+
125+
httpapi.Write(ctx, rw, http.StatusOK, codersdk.Response{
126+
Message: fmt.Sprintf("Successfully wrote to %q", path),
127+
})
128+
}
129+
130+
func (a *agent) writeFile(ctx context.Context, r *http.Request, path string) (HTTPResponseCode, error) {
131+
if !filepath.IsAbs(path) {
132+
return http.StatusBadRequest, xerrors.Errorf("file path must be absolute: %q", path)
133+
}
134+
135+
dir := filepath.Dir(path)
136+
err := a.filesystem.MkdirAll(dir, 0o755)
137+
if err != nil {
138+
status := http.StatusInternalServerError
139+
switch {
140+
case errors.Is(err, os.ErrPermission):
141+
status = http.StatusForbidden
142+
case errors.Is(err, syscall.ENOTDIR):
143+
status = http.StatusBadRequest
144+
}
145+
return status, err
146+
}
147+
148+
f, err := a.filesystem.Create(path)
149+
if err != nil {
150+
status := http.StatusInternalServerError
151+
switch {
152+
case errors.Is(err, os.ErrPermission):
153+
status = http.StatusForbidden
154+
case errors.Is(err, syscall.EISDIR):
155+
status = http.StatusBadRequest
156+
}
157+
return status, err
158+
}
159+
defer f.Close()
160+
161+
_, err = io.Copy(f, r.Body)
162+
if err != nil && !errors.Is(err, io.EOF) && ctx.Err() == nil {
163+
a.logger.Error(ctx, "workspace agent write file", slog.Error(err))
164+
}
165+
166+
return 0, nil
167+
}

agent/files_test.go

Lines changed: 163 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
package agent_test
22

33
import (
4+
"bytes"
45
"context"
56
"io"
67
"net/http"
78
"os"
89
"path/filepath"
10+
"runtime"
11+
"syscall"
912
"testing"
1013

1114
"github.com/spf13/afero"
@@ -38,6 +41,56 @@ func (fs *testFs) Open(name string) (afero.File, error) {
3841
return fs.Fs.Open(name)
3942
}
4043

44+
func (fs *testFs) Create(name string) (afero.File, error) {
45+
if err := fs.intercept("create", name); err != nil {
46+
return nil, err
47+
}
48+
// Unlike os, afero lets you create files where directories already exist and
49+
// lets you nest them underneath files, somehow.
50+
stat, err := fs.Fs.Stat(name)
51+
if err == nil && stat.IsDir() {
52+
return nil, &os.PathError{
53+
Op: "open",
54+
Path: name,
55+
Err: syscall.EISDIR,
56+
}
57+
}
58+
stat, err = fs.Fs.Stat(filepath.Dir(name))
59+
if err == nil && !stat.IsDir() {
60+
return nil, &os.PathError{
61+
Op: "open",
62+
Path: name,
63+
Err: syscall.ENOTDIR,
64+
}
65+
}
66+
return fs.Fs.Create(name)
67+
}
68+
69+
func (fs *testFs) MkdirAll(name string, mode os.FileMode) error {
70+
if err := fs.intercept("mkdirall", name); err != nil {
71+
return err
72+
}
73+
// Unlike os, afero lets you create directories where files already exist and
74+
// lets you nest them underneath files somehow.
75+
stat, err := fs.Fs.Stat(filepath.Dir(name))
76+
if err == nil && !stat.IsDir() {
77+
return &os.PathError{
78+
Op: "mkdir",
79+
Path: name,
80+
Err: syscall.ENOTDIR,
81+
}
82+
}
83+
stat, err = fs.Fs.Stat(name)
84+
if err == nil && !stat.IsDir() {
85+
return &os.PathError{
86+
Op: "mkdir",
87+
Path: name,
88+
Err: syscall.ENOTDIR,
89+
}
90+
}
91+
return fs.Fs.MkdirAll(name, mode)
92+
}
93+
4194
func TestReadFile(t *testing.T) {
4295
t.Parallel()
4396

@@ -82,7 +135,7 @@ func TestReadFile(t *testing.T) {
82135
error: "\"path\" is required",
83136
},
84137
{
85-
name: "RelativePath",
138+
name: "RelativePathDotSlash",
86139
path: "./relative",
87140
errCode: http.StatusBadRequest,
88141
error: "file path must be absolute",
@@ -214,3 +267,112 @@ func TestReadFile(t *testing.T) {
214267
})
215268
}
216269
}
270+
271+
func TestWriteFile(t *testing.T) {
272+
t.Parallel()
273+
274+
tmpdir := os.TempDir()
275+
noPermsFilePath := filepath.Join(tmpdir, "no-perms-file")
276+
noPermsDirPath := filepath.Join(tmpdir, "no-perms-dir")
277+
//nolint:dogsled
278+
conn, _, _, fs, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, opts *agent.Options) {
279+
opts.Filesystem = newTestFs(opts.Filesystem, func(call, file string) error {
280+
if file == noPermsFilePath || file == noPermsDirPath {
281+
return os.ErrPermission
282+
}
283+
return nil
284+
})
285+
})
286+
287+
dirPath := filepath.Join(tmpdir, "directory")
288+
err := fs.MkdirAll(dirPath, 0o755)
289+
require.NoError(t, err)
290+
291+
filePath := filepath.Join(tmpdir, "file")
292+
err = afero.WriteFile(fs, filePath, []byte("content"), 0o644)
293+
require.NoError(t, err)
294+
295+
notDirErr := "not a directory"
296+
if runtime.GOOS == "windows" {
297+
notDirErr = "cannot find the path"
298+
}
299+
300+
tests := []struct {
301+
name string
302+
path string
303+
bytes []byte
304+
errCode int
305+
error string
306+
}{
307+
{
308+
name: "NoPath",
309+
path: "",
310+
errCode: http.StatusBadRequest,
311+
error: "\"path\" is required",
312+
},
313+
{
314+
name: "RelativePathDotSlash",
315+
path: "./relative",
316+
errCode: http.StatusBadRequest,
317+
error: "file path must be absolute",
318+
},
319+
{
320+
name: "RelativePath",
321+
path: "also-relative",
322+
errCode: http.StatusBadRequest,
323+
error: "file path must be absolute",
324+
},
325+
{
326+
name: "NonExistent",
327+
path: filepath.Join(tmpdir, "/nested/does-not-exist"),
328+
bytes: []byte("now it does exist"),
329+
},
330+
{
331+
name: "IsDir",
332+
path: dirPath,
333+
errCode: http.StatusBadRequest,
334+
error: "is a directory",
335+
},
336+
{
337+
name: "IsNotDir",
338+
path: filepath.Join(filePath, "file2"),
339+
errCode: http.StatusBadRequest,
340+
error: notDirErr,
341+
},
342+
{
343+
name: "NoPermissionsFile",
344+
path: noPermsFilePath,
345+
errCode: http.StatusForbidden,
346+
error: "permission denied",
347+
},
348+
{
349+
name: "NoPermissionsDir",
350+
path: filepath.Join(noPermsDirPath, "within-no-perm-dir"),
351+
errCode: http.StatusForbidden,
352+
error: "permission denied",
353+
},
354+
}
355+
356+
for _, tt := range tests {
357+
t.Run(tt.name, func(t *testing.T) {
358+
t.Parallel()
359+
360+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
361+
defer cancel()
362+
363+
reader := bytes.NewReader(tt.bytes)
364+
err := conn.WriteFile(ctx, tt.path, reader)
365+
if tt.errCode != 0 {
366+
require.Error(t, err)
367+
cerr := coderdtest.SDKError(t, err)
368+
require.Contains(t, cerr.Error(), tt.error)
369+
require.Equal(t, tt.errCode, cerr.StatusCode())
370+
} else {
371+
require.NoError(t, err)
372+
b, err := afero.ReadFile(fs, tt.path)
373+
require.NoError(t, err)
374+
require.Equal(t, tt.bytes, b)
375+
}
376+
})
377+
}
378+
}

codersdk/toolsdk/toolsdk.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ const (
4343
ToolNameChatGPTSearch = "search"
4444
ToolNameChatGPTFetch = "fetch"
4545
ToolNameWorkspaceReadFile = "coder_workspace_read_file"
46+
ToolNameWorkspaceWriteFile = "coder_workspace_write_file"
4647
)
4748

4849
func NewDeps(client *codersdk.Client, opts ...func(*Deps)) (Deps, error) {
@@ -211,6 +212,7 @@ var All = []GenericTool{
211212
ChatGPTSearch.Generic(),
212213
ChatGPTFetch.Generic(),
213214
WorkspaceReadFile.Generic(),
215+
WorkspaceWriteFile.Generic(),
214216
}
215217

216218
type ReportTaskArgs struct {
@@ -1441,6 +1443,54 @@ var WorkspaceReadFile = Tool[WorkspaceReadFileArgs, WorkspaceReadFileResponse]{
14411443
},
14421444
}
14431445

1446+
type WorkspaceWriteFileArgs struct {
1447+
Workspace string `json:"workspace"`
1448+
Path string `json:"path"`
1449+
Content []byte `json:"content"`
1450+
}
1451+
1452+
var WorkspaceWriteFile = Tool[WorkspaceWriteFileArgs, codersdk.Response]{
1453+
Tool: aisdk.Tool{
1454+
Name: ToolNameWorkspaceWriteFile,
1455+
Description: `Write a file in a workspace.`,
1456+
Schema: aisdk.Schema{
1457+
Properties: map[string]any{
1458+
"workspace": map[string]any{
1459+
"type": "string",
1460+
"description": "The workspace name in the format [owner/]workspace[.agent]. If an owner is not specified, the authenticated user is used.",
1461+
},
1462+
"path": map[string]any{
1463+
"type": "string",
1464+
"description": "The absolute path of the file to write in the workspace.",
1465+
},
1466+
"content": map[string]any{
1467+
"type": "string",
1468+
"description": "The base64-encoded bytes to write to the file.",
1469+
},
1470+
},
1471+
Required: []string{"path", "workspace", "content"},
1472+
},
1473+
},
1474+
UserClientOptional: true,
1475+
Handler: func(ctx context.Context, deps Deps, args WorkspaceWriteFileArgs) (codersdk.Response, error) {
1476+
conn, err := newAgentConn(ctx, deps.coderClient, args.Workspace)
1477+
if err != nil {
1478+
return codersdk.Response{}, err
1479+
}
1480+
defer conn.Close()
1481+
1482+
reader := bytes.NewReader(args.Content)
1483+
err = conn.WriteFile(ctx, args.Path, reader)
1484+
if err != nil {
1485+
return codersdk.Response{}, err
1486+
}
1487+
1488+
return codersdk.Response{
1489+
Message: "File written successfully.",
1490+
}, nil
1491+
},
1492+
}
1493+
14441494
// NormalizeWorkspaceInput converts workspace name input to standard format.
14451495
// Handles the following input formats:
14461496
// - workspace → workspace

0 commit comments

Comments
 (0)