Skip to content

Commit d231994

Browse files
committed
Add coder_workspace_write_file MCP tool
1 parent 4bf63b4 commit d231994

File tree

7 files changed

+321
-0
lines changed

7 files changed

+321
-0
lines changed

agent/api.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ func (a *agent) apiHandler() http.Handler {
6161
r.Get("/api/v0/netcheck", a.HandleNetcheck)
6262
r.Post("/api/v0/list-directory", a.HandleLS)
6363
r.Get("/api/v0/read-file", a.HandleReadFile)
64+
r.Post("/api/v0/write-file", a.HandleWriteFile)
6465
r.Get("/debug/logs", a.HandleHTTPDebugLogs)
6566
r.Get("/debug/magicsock", a.HandleHTTPDebugMagicsock)
6667
r.Get("/debug/magicsock/debug-logging/{state}", a.HandleHTTPMagicsockDebugLoggingState)

agent/files.go

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@ package agent
33
import (
44
"context"
55
"errors"
6+
"fmt"
67
"io"
78
"mime"
89
"net/http"
910
"os"
1011
"path/filepath"
1112
"strconv"
13+
"strings"
1214

1315
"golang.org/x/xerrors"
1416

@@ -94,3 +96,71 @@ func (a *agent) streamFile(ctx context.Context, rw http.ResponseWriter, path str
9496

9597
return 0, nil
9698
}
99+
100+
func (a *agent) HandleWriteFile(rw http.ResponseWriter, r *http.Request) {
101+
ctx := r.Context()
102+
103+
query := r.URL.Query()
104+
parser := httpapi.NewQueryParamParser().RequiredNotEmpty("path")
105+
path := parser.String(query, "", "path")
106+
parser.ErrorExcessParams(query)
107+
if len(parser.Errors) > 0 {
108+
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
109+
Message: "Query parameters have invalid values.",
110+
Validations: parser.Errors,
111+
})
112+
return
113+
}
114+
115+
status, err := a.writeFile(ctx, r, path)
116+
if err != nil {
117+
httpapi.Write(ctx, rw, status, codersdk.Response{
118+
Message: err.Error(),
119+
})
120+
return
121+
}
122+
123+
httpapi.Write(ctx, rw, http.StatusOK, codersdk.Response{
124+
Message: fmt.Sprintf("Successfully wrote to %q", path),
125+
})
126+
}
127+
128+
func (a *agent) writeFile(ctx context.Context, r *http.Request, path string) (int, error) {
129+
if !filepath.IsAbs(path) {
130+
return http.StatusBadRequest, xerrors.Errorf("file path must be absolute: %q", path)
131+
}
132+
133+
dir := filepath.Dir(path)
134+
err := a.filesystem.MkdirAll(dir, 0o755)
135+
if err != nil {
136+
status := http.StatusInternalServerError
137+
switch {
138+
case errors.Is(err, os.ErrPermission):
139+
status = http.StatusForbidden
140+
case strings.Contains(err.Error(), "not a directory"):
141+
status = http.StatusBadRequest
142+
}
143+
return status, err
144+
}
145+
146+
f, err := a.filesystem.Create(path)
147+
if err != nil {
148+
status := http.StatusInternalServerError
149+
switch {
150+
case errors.Is(err, os.ErrPermission):
151+
status = http.StatusForbidden
152+
case strings.Contains(err.Error(), "is a directory") ||
153+
strings.Contains(err.Error(), "not a directory"):
154+
status = http.StatusBadRequest
155+
}
156+
return status, err
157+
}
158+
defer f.Close()
159+
160+
_, err = io.Copy(f, r.Body)
161+
if err != nil && !errors.Is(err, io.EOF) && ctx.Err() == nil {
162+
a.logger.Error(ctx, "workspace agent write file", slog.Error(err))
163+
}
164+
165+
return 0, nil
166+
}

agent/files_test.go

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package agent_test
22

33
import (
4+
"bytes"
45
"context"
56
"io"
67
"net/http"
@@ -10,6 +11,7 @@ import (
1011

1112
"github.com/spf13/afero"
1213
"github.com/stretchr/testify/require"
14+
"golang.org/x/xerrors"
1315

1416
"github.com/coder/coder/v2/agent"
1517
"github.com/coder/coder/v2/agent/agenttest"
@@ -38,6 +40,40 @@ func (fs *testFs) Open(name string) (afero.File, error) {
3840
return fs.Fs.Open(name)
3941
}
4042

43+
func (fs *testFs) Create(name string) (afero.File, error) {
44+
if err := fs.intercept("create", name); err != nil {
45+
return nil, err
46+
}
47+
// Unlike os, afero lets you create files where directories already exist and
48+
// lets you nest them underneath files, somehow.
49+
stat, err := fs.Fs.Stat(name)
50+
if err == nil && stat.IsDir() {
51+
return nil, xerrors.New("is a directory")
52+
}
53+
stat, err = fs.Fs.Stat(filepath.Dir(name))
54+
if err == nil && !stat.IsDir() {
55+
return nil, xerrors.New("not a directory")
56+
}
57+
return fs.Fs.Create(name)
58+
}
59+
60+
func (fs *testFs) MkdirAll(name string, mode os.FileMode) error {
61+
if err := fs.intercept("mkdirall", name); err != nil {
62+
return err
63+
}
64+
// Unlike os, afero lets you create directories where files already exist and
65+
// lets you nest them underneath files somehow.
66+
stat, err := fs.Fs.Stat(filepath.Dir(name))
67+
if err == nil && !stat.IsDir() {
68+
return xerrors.New("not a directory")
69+
}
70+
stat, err = fs.Fs.Stat(name)
71+
if err == nil && !stat.IsDir() {
72+
return xerrors.New("not a directory")
73+
}
74+
return fs.Fs.MkdirAll(name, mode)
75+
}
76+
4177
func TestReadFile(t *testing.T) {
4278
t.Parallel()
4379

@@ -214,3 +250,107 @@ func TestReadFile(t *testing.T) {
214250
})
215251
}
216252
}
253+
254+
func TestWriteFile(t *testing.T) {
255+
t.Parallel()
256+
257+
tmpdir := os.TempDir()
258+
noPermsFilePath := filepath.Join(tmpdir, "no-perms-file")
259+
noPermsDirPath := filepath.Join(tmpdir, "no-perms-dir")
260+
//nolint:dogsled
261+
conn, _, _, fs, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, opts *agent.Options) {
262+
opts.Filesystem = newTestFs(opts.Filesystem, func(call, file string) error {
263+
if file == noPermsFilePath || file == noPermsDirPath {
264+
return os.ErrPermission
265+
}
266+
return nil
267+
})
268+
})
269+
270+
dirPath := filepath.Join(tmpdir, "directory")
271+
err := fs.MkdirAll(dirPath, 0o755)
272+
require.NoError(t, err)
273+
274+
filePath := filepath.Join(tmpdir, "file")
275+
err = afero.WriteFile(fs, filePath, []byte("content"), 0o644)
276+
require.NoError(t, err)
277+
278+
tests := []struct {
279+
name string
280+
path string
281+
bytes []byte
282+
errCode int
283+
error string
284+
}{
285+
{
286+
name: "NoPath",
287+
path: "",
288+
errCode: http.StatusBadRequest,
289+
error: "\"path\" is required",
290+
},
291+
{
292+
name: "RelativePath",
293+
path: "./relative",
294+
errCode: http.StatusBadRequest,
295+
error: "file path must be absolute",
296+
},
297+
{
298+
name: "RelativePath",
299+
path: "also-relative",
300+
errCode: http.StatusBadRequest,
301+
error: "file path must be absolute",
302+
},
303+
{
304+
name: "NonExistent",
305+
path: filepath.Join(tmpdir, "/nested/does-not-exist"),
306+
bytes: []byte("now it does exist"),
307+
},
308+
{
309+
name: "IsDir",
310+
path: dirPath,
311+
errCode: http.StatusBadRequest,
312+
error: "is a directory",
313+
},
314+
{
315+
name: "IsNotDir",
316+
path: filepath.Join(filePath, "file2"),
317+
errCode: http.StatusBadRequest,
318+
error: "not a directory",
319+
},
320+
{
321+
name: "NoPermissionsFile",
322+
path: noPermsFilePath,
323+
errCode: http.StatusForbidden,
324+
error: "permission denied",
325+
},
326+
{
327+
name: "NoPermissionsDir",
328+
path: filepath.Join(noPermsDirPath, "within-no-perm-dir"),
329+
errCode: http.StatusForbidden,
330+
error: "permission denied",
331+
},
332+
}
333+
334+
for _, tt := range tests {
335+
t.Run(tt.name, func(t *testing.T) {
336+
t.Parallel()
337+
338+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
339+
defer cancel()
340+
341+
reader := bytes.NewReader(tt.bytes)
342+
err := conn.WriteFile(ctx, tt.path, reader)
343+
if tt.errCode != 0 {
344+
require.Error(t, err)
345+
cerr := coderdtest.SDKError(t, err)
346+
require.Contains(t, cerr.Error(), tt.error)
347+
require.Equal(t, tt.errCode, cerr.StatusCode())
348+
} else {
349+
require.NoError(t, err)
350+
b, err := afero.ReadFile(fs, tt.path)
351+
require.NoError(t, err)
352+
require.Equal(t, tt.bytes, b)
353+
}
354+
})
355+
}
356+
}

codersdk/toolsdk/toolsdk.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ const (
4343
ToolNameChatGPTSearch = "search"
4444
ToolNameChatGPTFetch = "fetch"
4545
ToolNameWorkspaceReadFile = "coder_workspace_read_file"
46+
ToolNameWorkspaceWriteFile = "coder_workspace_write_file"
4647
)
4748

4849
func NewDeps(client *codersdk.Client, opts ...func(*Deps)) (Deps, error) {
@@ -211,6 +212,7 @@ var All = []GenericTool{
211212
ChatGPTSearch.Generic(),
212213
ChatGPTFetch.Generic(),
213214
WorkspaceReadFile.Generic(),
215+
WorkspaceWriteFile.Generic(),
214216
}
215217

216218
type ReportTaskArgs struct {
@@ -1441,6 +1443,54 @@ var WorkspaceReadFile = Tool[WorkspaceReadFileArgs, WorkspaceReadFileResponse]{
14411443
},
14421444
}
14431445

1446+
type WorkspaceWriteFileArgs struct {
1447+
Workspace string `json:"workspace"`
1448+
Path string `json:"path"`
1449+
Content []byte `json:"content"`
1450+
}
1451+
1452+
var WorkspaceWriteFile = Tool[WorkspaceWriteFileArgs, codersdk.Response]{
1453+
Tool: aisdk.Tool{
1454+
Name: ToolNameWorkspaceWriteFile,
1455+
Description: `Write a file in a workspace.`,
1456+
Schema: aisdk.Schema{
1457+
Properties: map[string]any{
1458+
"workspace": map[string]any{
1459+
"type": "string",
1460+
"description": "The workspace name in the format [owner/]workspace[.agent]. If an owner is not specified, the authenticated user is used.",
1461+
},
1462+
"path": map[string]any{
1463+
"type": "string",
1464+
"description": "The absolute path of the file to write in the workspace.",
1465+
},
1466+
"content": map[string]any{
1467+
"type": "string",
1468+
"description": "The base64-encoded bytes to write to the file.",
1469+
},
1470+
},
1471+
Required: []string{"path", "workspace", "content"},
1472+
},
1473+
},
1474+
UserClientOptional: true,
1475+
Handler: func(ctx context.Context, deps Deps, args WorkspaceWriteFileArgs) (codersdk.Response, error) {
1476+
conn, err := newAgentConn(ctx, deps.coderClient, args.Workspace)
1477+
if err != nil {
1478+
return codersdk.Response{}, err
1479+
}
1480+
defer conn.Close()
1481+
1482+
reader := bytes.NewReader(args.Content)
1483+
err = conn.WriteFile(ctx, args.Path, reader)
1484+
if err != nil {
1485+
return codersdk.Response{}, err
1486+
}
1487+
1488+
return codersdk.Response{
1489+
Message: "File written successfully.",
1490+
}, nil
1491+
},
1492+
}
1493+
14441494
// NormalizeWorkspaceInput converts workspace name input to standard format.
14451495
// Handles the following input formats:
14461496
// - workspace → workspace

codersdk/toolsdk/toolsdk_test.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,30 @@ func TestTools(t *testing.T) {
555555
})
556556
}
557557
})
558+
559+
t.Run("WorkspaceWriteFile", func(t *testing.T) {
560+
t.Parallel()
561+
562+
client, workspace, agentToken := setupWorkspaceForAgent(t)
563+
fs := afero.NewMemMapFs()
564+
_ = agenttest.New(t, client.URL, agentToken, func(opts *agent.Options) {
565+
opts.Filesystem = fs
566+
})
567+
coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).Wait()
568+
tb, err := toolsdk.NewDeps(client)
569+
require.NoError(t, err)
570+
571+
_, err = testTool(t, toolsdk.WorkspaceWriteFile, tb, toolsdk.WorkspaceWriteFileArgs{
572+
Workspace: workspace.Name,
573+
Path: "/test/some/path",
574+
Content: []byte("content"),
575+
})
576+
require.NoError(t, err)
577+
578+
b, err := afero.ReadFile(fs, "/test/some/path")
579+
require.NoError(t, err)
580+
require.Equal(t, []byte("content"), b)
581+
})
558582
}
559583

560584
// TestedTools keeps track of which tools have been tested.

codersdk/workspacesdk/agentconn.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ type AgentConn interface {
6161
ReconnectingPTY(ctx context.Context, id uuid.UUID, height uint16, width uint16, command string, initOpts ...AgentReconnectingPTYInitOption) (net.Conn, error)
6262
RecreateDevcontainer(ctx context.Context, devcontainerID string) (codersdk.Response, error)
6363
ReadFile(ctx context.Context, path string, offset, limit int64) (io.ReadCloser, string, error)
64+
WriteFile(ctx context.Context, path string, reader io.Reader) error
6465
SSH(ctx context.Context) (*gonet.TCPConn, error)
6566
SSHClient(ctx context.Context) (*ssh.Client, error)
6667
SSHClientOnPort(ctx context.Context, port uint16) (*ssh.Client, error)
@@ -501,6 +502,27 @@ func (c *agentConn) ReadFile(ctx context.Context, path string, offset, limit int
501502
return res.Body, mimeType, nil
502503
}
503504

505+
// WriteFile writes to a file in the workspace.
506+
func (c *agentConn) WriteFile(ctx context.Context, path string, reader io.Reader) error {
507+
ctx, span := tracing.StartSpan(ctx)
508+
defer span.End()
509+
510+
res, err := c.apiRequest(ctx, http.MethodPost, fmt.Sprintf("/api/v0/write-file?path=%s", path), reader)
511+
if err != nil {
512+
return xerrors.Errorf("do request: %w", err)
513+
}
514+
defer res.Body.Close()
515+
if res.StatusCode != http.StatusOK {
516+
return codersdk.ReadBodyAsError(res)
517+
}
518+
519+
var m codersdk.Response
520+
if err := json.NewDecoder(res.Body).Decode(&m); err != nil {
521+
return xerrors.Errorf("decode response body: %w", err)
522+
}
523+
return nil
524+
}
525+
504526
// apiRequest makes a request to the workspace agent's HTTP API server.
505527
func (c *agentConn) apiRequest(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) {
506528
ctx, span := tracing.StartSpan(ctx)

0 commit comments

Comments
 (0)