Skip to content

Commit 911efb7

Browse files
committed
Add coder_workspace_write_file MCP tool
1 parent a042936 commit 911efb7

File tree

7 files changed

+341
-2
lines changed

7 files changed

+341
-2
lines changed

agent/api.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ func (a *agent) apiHandler() http.Handler {
6161
r.Get("/api/v0/netcheck", a.HandleNetcheck)
6262
r.Post("/api/v0/list-directory", a.HandleLS)
6363
r.Get("/api/v0/read-file", a.HandleReadFile)
64+
r.Post("/api/v0/write-file", a.HandleWriteFile)
6465
r.Get("/debug/logs", a.HandleHTTPDebugLogs)
6566
r.Get("/debug/magicsock", a.HandleHTTPDebugMagicsock)
6667
r.Get("/debug/magicsock/debug-logging/{state}", a.HandleHTTPMagicsockDebugLoggingState)

agent/files.go

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@ package agent
33
import (
44
"context"
55
"errors"
6+
"fmt"
67
"io"
78
"mime"
89
"net/http"
910
"os"
1011
"path/filepath"
1112
"strconv"
13+
"strings"
1214

1315
"golang.org/x/xerrors"
1416

@@ -94,3 +96,71 @@ func (a *agent) streamFile(ctx context.Context, rw http.ResponseWriter, path str
9496

9597
return 0, nil
9698
}
99+
100+
func (a *agent) HandleWriteFile(rw http.ResponseWriter, r *http.Request) {
101+
ctx := r.Context()
102+
103+
query := r.URL.Query()
104+
parser := httpapi.NewQueryParamParser().RequiredNotEmpty("path")
105+
path := parser.String(query, "", "path")
106+
parser.ErrorExcessParams(query)
107+
if len(parser.Errors) > 0 {
108+
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
109+
Message: "Query parameters have invalid values.",
110+
Validations: parser.Errors,
111+
})
112+
return
113+
}
114+
115+
status, err := a.writeFile(ctx, r, path)
116+
if err != nil {
117+
httpapi.Write(ctx, rw, status, codersdk.Response{
118+
Message: err.Error(),
119+
})
120+
return
121+
}
122+
123+
httpapi.Write(ctx, rw, http.StatusOK, codersdk.Response{
124+
Message: fmt.Sprintf("Successfully wrote to %q", path),
125+
})
126+
}
127+
128+
func (a *agent) writeFile(ctx context.Context, r *http.Request, path string) (int, error) {
129+
if !filepath.IsAbs(path) {
130+
return http.StatusBadRequest, xerrors.Errorf("file path must be absolute: %q", path)
131+
}
132+
133+
dir := filepath.Dir(path)
134+
err := a.filesystem.MkdirAll(dir, 0o755)
135+
if err != nil {
136+
status := http.StatusInternalServerError
137+
switch {
138+
case errors.Is(err, os.ErrPermission):
139+
status = http.StatusForbidden
140+
case strings.Contains(err.Error(), "not a directory"):
141+
status = http.StatusBadRequest
142+
}
143+
return status, err
144+
}
145+
146+
f, err := a.filesystem.Create(path)
147+
if err != nil {
148+
status := http.StatusInternalServerError
149+
switch {
150+
case errors.Is(err, os.ErrPermission):
151+
status = http.StatusForbidden
152+
case strings.Contains(err.Error(), "is a directory") ||
153+
strings.Contains(err.Error(), "not a directory"):
154+
status = http.StatusBadRequest
155+
}
156+
return status, err
157+
}
158+
defer f.Close()
159+
160+
_, err = io.Copy(f, r.Body)
161+
if err != nil && !errors.Is(err, io.EOF) && ctx.Err() == nil {
162+
a.logger.Error(ctx, "workspace agent write file", slog.Error(err))
163+
}
164+
165+
return 0, nil
166+
}

agent/files_test.go

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99

1010
"github.com/spf13/afero"
1111
"github.com/stretchr/testify/require"
12+
"golang.org/x/xerrors"
1213

1314
"github.com/coder/coder/v2/agent"
1415
"github.com/coder/coder/v2/agent/agenttest"
@@ -37,6 +38,40 @@ func (fs *testFs) Open(name string) (afero.File, error) {
3738
return fs.Fs.Open(name)
3839
}
3940

41+
func (fs *testFs) Create(name string) (afero.File, error) {
42+
if err := fs.intercept("create", name); err != nil {
43+
return nil, err
44+
}
45+
// Unlike os, afero lets you create files where directories already exist and
46+
// lets you nest them underneath files, somehow.
47+
stat, err := fs.Fs.Stat(name)
48+
if err == nil && stat.IsDir() {
49+
return nil, xerrors.New("is a directory")
50+
}
51+
stat, err = fs.Fs.Stat(filepath.Dir(name))
52+
if err == nil && !stat.IsDir() {
53+
return nil, xerrors.New("not a directory")
54+
}
55+
return fs.Fs.Create(name)
56+
}
57+
58+
func (fs *testFs) MkdirAll(name string, mode os.FileMode) error {
59+
if err := fs.intercept("mkdirall", name); err != nil {
60+
return err
61+
}
62+
// Unlike os, afero lets you create directories where files already exist and
63+
// lets you nest them underneath files somehow.
64+
stat, err := fs.Fs.Stat(filepath.Dir(name))
65+
if err == nil && !stat.IsDir() {
66+
return xerrors.New("not a directory")
67+
}
68+
stat, err = fs.Fs.Stat(name)
69+
if err == nil && !stat.IsDir() {
70+
return xerrors.New("not a directory")
71+
}
72+
return fs.Fs.MkdirAll(name, mode)
73+
}
74+
4075
func TestReadFile(t *testing.T) {
4176
t.Parallel()
4277

@@ -206,3 +241,106 @@ func TestReadFile(t *testing.T) {
206241
})
207242
}
208243
}
244+
245+
func TestWriteFile(t *testing.T) {
246+
t.Parallel()
247+
248+
tmpdir := os.TempDir()
249+
noPermsFilePath := filepath.Join(tmpdir, "no-perms-file")
250+
noPermsDirPath := filepath.Join(tmpdir, "no-perms-dir")
251+
//nolint:dogsled
252+
conn, _, _, fs, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, opts *agent.Options) {
253+
opts.Filesystem = newTestFs(opts.Filesystem, func(call, file string) error {
254+
if file == noPermsFilePath || file == noPermsDirPath {
255+
return os.ErrPermission
256+
}
257+
return nil
258+
})
259+
})
260+
261+
dirPath := filepath.Join(tmpdir, "directory")
262+
err := fs.MkdirAll(dirPath, 0o755)
263+
require.NoError(t, err)
264+
265+
filePath := filepath.Join(tmpdir, "file")
266+
err = afero.WriteFile(fs, filePath, []byte("content"), 0o644)
267+
require.NoError(t, err)
268+
269+
tests := []struct {
270+
name string
271+
path string
272+
bytes []byte
273+
errCode int
274+
error string
275+
}{
276+
{
277+
name: "NoPath",
278+
path: "",
279+
errCode: http.StatusBadRequest,
280+
error: "\"path\" is required",
281+
},
282+
{
283+
name: "RelativePath",
284+
path: "./relative",
285+
errCode: http.StatusBadRequest,
286+
error: "file path must be absolute",
287+
},
288+
{
289+
name: "RelativePath",
290+
path: "also-relative",
291+
errCode: http.StatusBadRequest,
292+
error: "file path must be absolute",
293+
},
294+
{
295+
name: "NonExistent",
296+
path: filepath.Join(tmpdir, "/nested/does-not-exist"),
297+
bytes: []byte("now it does exist"),
298+
},
299+
{
300+
name: "IsDir",
301+
path: dirPath,
302+
errCode: http.StatusBadRequest,
303+
error: "is a directory",
304+
},
305+
{
306+
name: "IsNotDir",
307+
path: filepath.Join(filePath, "file2"),
308+
errCode: http.StatusBadRequest,
309+
error: "not a directory",
310+
},
311+
{
312+
name: "NoPermissionsFile",
313+
path: noPermsFilePath,
314+
errCode: http.StatusForbidden,
315+
error: "permission denied",
316+
},
317+
{
318+
name: "NoPermissionsDir",
319+
path: filepath.Join(noPermsDirPath, "within-no-perm-dir"),
320+
errCode: http.StatusForbidden,
321+
error: "permission denied",
322+
},
323+
}
324+
325+
for _, tt := range tests {
326+
t.Run(tt.name, func(t *testing.T) {
327+
t.Parallel()
328+
329+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
330+
defer cancel()
331+
332+
err := conn.WriteFile(ctx, tt.path, tt.bytes)
333+
if tt.errCode != 0 {
334+
require.Error(t, err)
335+
cerr := coderdtest.SDKError(t, err)
336+
require.Contains(t, cerr.Error(), tt.error)
337+
require.Equal(t, tt.errCode, cerr.StatusCode())
338+
} else {
339+
require.NoError(t, err)
340+
b, err := afero.ReadFile(fs, tt.path)
341+
require.NoError(t, err)
342+
require.Equal(t, tt.bytes, b)
343+
}
344+
})
345+
}
346+
}

codersdk/toolsdk/toolsdk.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ const (
4343
ToolNameChatGPTSearch = "search"
4444
ToolNameChatGPTFetch = "fetch"
4545
ToolNameWorkspaceReadFile = "coder_workspace_read_file"
46+
ToolNameWorkspaceWriteFile = "coder_workspace_write_file"
4647
)
4748

4849
func NewDeps(client *codersdk.Client, opts ...func(*Deps)) (Deps, error) {
@@ -211,6 +212,7 @@ var All = []GenericTool{
211212
ChatGPTSearch.Generic(),
212213
ChatGPTFetch.Generic(),
213214
WorkspaceReadFile.Generic(),
215+
WorkspaceWriteFile.Generic(),
214216
}
215217

216218
type ReportTaskArgs struct {
@@ -1423,6 +1425,53 @@ var WorkspaceReadFile = Tool[WorkspaceReadFileArgs, WorkspaceReadFileResponse]{
14231425
},
14241426
}
14251427

1428+
type WorkspaceWriteFileArgs struct {
1429+
Workspace string `json:"workspace"`
1430+
Path string `json:"path"`
1431+
Content []byte `json:"content"`
1432+
}
1433+
1434+
var WorkspaceWriteFile = Tool[WorkspaceWriteFileArgs, codersdk.Response]{
1435+
Tool: aisdk.Tool{
1436+
Name: ToolNameWorkspaceWriteFile,
1437+
Description: `Write a file in a workspace.`,
1438+
Schema: aisdk.Schema{
1439+
Properties: map[string]any{
1440+
"workspace": map[string]any{
1441+
"type": "string",
1442+
"description": "The workspace name in the format [owner/]workspace[.agent]. If an owner is not specified, the authenticated user is used.",
1443+
},
1444+
"path": map[string]any{
1445+
"type": "string",
1446+
"description": "The absolute path of the file to write in the workspace.",
1447+
},
1448+
"content": map[string]any{
1449+
"type": "string",
1450+
"description": "The base64-encoded bytes to write to the file.",
1451+
},
1452+
},
1453+
Required: []string{"path", "workspace", "content"},
1454+
},
1455+
},
1456+
UserClientOptional: true,
1457+
Handler: func(ctx context.Context, deps Deps, args WorkspaceWriteFileArgs) (codersdk.Response, error) {
1458+
conn, err := newAgentConn(ctx, deps.coderClient, args.Workspace)
1459+
if err != nil {
1460+
return codersdk.Response{}, err
1461+
}
1462+
defer conn.Close()
1463+
1464+
err = conn.WriteFile(ctx, args.Path, args.Content)
1465+
if err != nil {
1466+
return codersdk.Response{}, err
1467+
}
1468+
1469+
return codersdk.Response{
1470+
Message: "File written successfully.",
1471+
}, nil
1472+
},
1473+
}
1474+
14261475
// NormalizeWorkspaceInput converts workspace name input to standard format.
14271476
// Handles the following input formats:
14281477
// - workspace → workspace

codersdk/toolsdk/toolsdk_test.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -538,6 +538,30 @@ func TestTools(t *testing.T) {
538538
})
539539
}
540540
})
541+
542+
t.Run("WorkspaceWriteFile", func(t *testing.T) {
543+
t.Parallel()
544+
545+
client, workspace, agentToken := setupWorkspaceForAgent(t)
546+
fs := afero.NewMemMapFs()
547+
_ = agenttest.New(t, client.URL, agentToken, func(opts *agent.Options) {
548+
opts.Filesystem = fs
549+
})
550+
coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).Wait()
551+
tb, err := toolsdk.NewDeps(client)
552+
require.NoError(t, err)
553+
554+
_, err = testTool(t, toolsdk.WorkspaceWriteFile, tb, toolsdk.WorkspaceWriteFileArgs{
555+
Workspace: workspace.Name,
556+
Path: "/test/some/path",
557+
Content: []byte("content"),
558+
})
559+
require.NoError(t, err)
560+
561+
b, err := afero.ReadFile(fs, "/test/some/path")
562+
require.NoError(t, err)
563+
require.Equal(t, []byte("content"), b)
564+
})
541565
}
542566

543567
// TestedTools keeps track of which tools have been tested.

0 commit comments

Comments
 (0)