Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions agent/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ func (a *agent) apiHandler() http.Handler {
r.Get("/api/v0/netcheck", a.HandleNetcheck)
r.Post("/api/v0/list-directory", a.HandleLS)
r.Get("/api/v0/read-file", a.HandleReadFile)
r.Post("/api/v0/write-file", a.HandleWriteFile)
r.Get("/debug/logs", a.HandleHTTPDebugLogs)
r.Get("/debug/magicsock", a.HandleHTTPDebugMagicsock)
r.Get("/debug/magicsock/debug-logging/{state}", a.HandleHTTPMagicsockDebugLoggingState)
Expand Down
73 changes: 72 additions & 1 deletion agent/files.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@ package agent
import (
"context"
"errors"
"fmt"
"io"
"mime"
"net/http"
"os"
"path/filepath"
"strconv"
"syscall"

"golang.org/x/xerrors"

Expand All @@ -17,6 +19,8 @@ import (
"github.com/coder/coder/v2/codersdk"
)

type HTTPResponseCode = int

func (a *agent) HandleReadFile(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()

Expand All @@ -43,7 +47,7 @@ func (a *agent) HandleReadFile(rw http.ResponseWriter, r *http.Request) {
}
}

func (a *agent) streamFile(ctx context.Context, rw http.ResponseWriter, path string, offset, limit int64) (int, error) {
func (a *agent) streamFile(ctx context.Context, rw http.ResponseWriter, path string, offset, limit int64) (HTTPResponseCode, error) {
if !filepath.IsAbs(path) {
return http.StatusBadRequest, xerrors.Errorf("file path must be absolute: %q", path)
}
Expand Down Expand Up @@ -94,3 +98,70 @@ func (a *agent) streamFile(ctx context.Context, rw http.ResponseWriter, path str

return 0, nil
}

func (a *agent) HandleWriteFile(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()

query := r.URL.Query()
parser := httpapi.NewQueryParamParser().RequiredNotEmpty("path")
path := parser.String(query, "", "path")
parser.ErrorExcessParams(query)
if len(parser.Errors) > 0 {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Query parameters have invalid values.",
Validations: parser.Errors,
})
return
}

status, err := a.writeFile(ctx, r, path)
if err != nil {
httpapi.Write(ctx, rw, status, codersdk.Response{
Message: err.Error(),
})
return
}

httpapi.Write(ctx, rw, http.StatusOK, codersdk.Response{
Message: fmt.Sprintf("Successfully wrote to %q", path),
})
}

func (a *agent) writeFile(ctx context.Context, r *http.Request, path string) (HTTPResponseCode, error) {
if !filepath.IsAbs(path) {
return http.StatusBadRequest, xerrors.Errorf("file path must be absolute: %q", path)
}

dir := filepath.Dir(path)
err := a.filesystem.MkdirAll(dir, 0o755)
if err != nil {
status := http.StatusInternalServerError
switch {
case errors.Is(err, os.ErrPermission):
status = http.StatusForbidden
case errors.Is(err, syscall.ENOTDIR):
status = http.StatusBadRequest
}
return status, err
}

f, err := a.filesystem.Create(path)
if err != nil {
status := http.StatusInternalServerError
switch {
case errors.Is(err, os.ErrPermission):
status = http.StatusForbidden
case errors.Is(err, syscall.EISDIR):
status = http.StatusBadRequest
}
return status, err
}
defer f.Close()

_, err = io.Copy(f, r.Body)
if err != nil && !errors.Is(err, io.EOF) && ctx.Err() == nil {
a.logger.Error(ctx, "workspace agent write file", slog.Error(err))
}

return 0, nil
}
164 changes: 163 additions & 1 deletion agent/files_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package agent_test

import (
"bytes"
"context"
"io"
"net/http"
"os"
"path/filepath"
"runtime"
"syscall"
"testing"

"github.com/spf13/afero"
Expand Down Expand Up @@ -38,6 +41,56 @@ func (fs *testFs) Open(name string) (afero.File, error) {
return fs.Fs.Open(name)
}

func (fs *testFs) Create(name string) (afero.File, error) {
if err := fs.intercept("create", name); err != nil {
return nil, err
}
// Unlike os, afero lets you create files where directories already exist and
// lets you nest them underneath files, somehow.
stat, err := fs.Fs.Stat(name)
if err == nil && stat.IsDir() {
return nil, &os.PathError{
Op: "open",
Path: name,
Err: syscall.EISDIR,
}
}
stat, err = fs.Fs.Stat(filepath.Dir(name))
if err == nil && !stat.IsDir() {
return nil, &os.PathError{
Op: "open",
Path: name,
Err: syscall.ENOTDIR,
}
}
return fs.Fs.Create(name)
}

func (fs *testFs) MkdirAll(name string, mode os.FileMode) error {
if err := fs.intercept("mkdirall", name); err != nil {
return err
}
// Unlike os, afero lets you create directories where files already exist and
// lets you nest them underneath files somehow.
stat, err := fs.Fs.Stat(filepath.Dir(name))
if err == nil && !stat.IsDir() {
return &os.PathError{
Op: "mkdir",
Path: name,
Err: syscall.ENOTDIR,
}
}
stat, err = fs.Fs.Stat(name)
if err == nil && !stat.IsDir() {
return &os.PathError{
Op: "mkdir",
Path: name,
Err: syscall.ENOTDIR,
}
}
return fs.Fs.MkdirAll(name, mode)
}

func TestReadFile(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -82,7 +135,7 @@ func TestReadFile(t *testing.T) {
error: "\"path\" is required",
},
{
name: "RelativePath",
name: "RelativePathDotSlash",
path: "./relative",
errCode: http.StatusBadRequest,
error: "file path must be absolute",
Expand Down Expand Up @@ -214,3 +267,112 @@ func TestReadFile(t *testing.T) {
})
}
}

func TestWriteFile(t *testing.T) {
t.Parallel()

tmpdir := os.TempDir()
noPermsFilePath := filepath.Join(tmpdir, "no-perms-file")
noPermsDirPath := filepath.Join(tmpdir, "no-perms-dir")
//nolint:dogsled
conn, _, _, fs, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, opts *agent.Options) {
opts.Filesystem = newTestFs(opts.Filesystem, func(call, file string) error {
if file == noPermsFilePath || file == noPermsDirPath {
return os.ErrPermission
}
return nil
})
})

dirPath := filepath.Join(tmpdir, "directory")
err := fs.MkdirAll(dirPath, 0o755)
require.NoError(t, err)

filePath := filepath.Join(tmpdir, "file")
err = afero.WriteFile(fs, filePath, []byte("content"), 0o644)
require.NoError(t, err)

notDirErr := "not a directory"
if runtime.GOOS == "windows" {
notDirErr = "cannot find the path"
}

tests := []struct {
name string
path string
bytes []byte
errCode int
error string
}{
{
name: "NoPath",
path: "",
errCode: http.StatusBadRequest,
error: "\"path\" is required",
},
{
name: "RelativePathDotSlash",
path: "./relative",
errCode: http.StatusBadRequest,
error: "file path must be absolute",
},
{
name: "RelativePath",
path: "also-relative",
errCode: http.StatusBadRequest,
error: "file path must be absolute",
},
{
name: "NonExistent",
path: filepath.Join(tmpdir, "/nested/does-not-exist"),
bytes: []byte("now it does exist"),
},
{
name: "IsDir",
path: dirPath,
errCode: http.StatusBadRequest,
error: "is a directory",
},
{
name: "IsNotDir",
path: filepath.Join(filePath, "file2"),
errCode: http.StatusBadRequest,
error: notDirErr,
},
{
name: "NoPermissionsFile",
path: noPermsFilePath,
errCode: http.StatusForbidden,
error: "permission denied",
},
{
name: "NoPermissionsDir",
path: filepath.Join(noPermsDirPath, "within-no-perm-dir"),
errCode: http.StatusForbidden,
error: "permission denied",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()

reader := bytes.NewReader(tt.bytes)
err := conn.WriteFile(ctx, tt.path, reader)
if tt.errCode != 0 {
require.Error(t, err)
cerr := coderdtest.SDKError(t, err)
require.Contains(t, cerr.Error(), tt.error)
require.Equal(t, tt.errCode, cerr.StatusCode())
} else {
require.NoError(t, err)
b, err := afero.ReadFile(fs, tt.path)
require.NoError(t, err)
require.Equal(t, tt.bytes, b)
}
})
}
}
50 changes: 50 additions & 0 deletions codersdk/toolsdk/toolsdk.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ const (
ToolNameChatGPTSearch = "search"
ToolNameChatGPTFetch = "fetch"
ToolNameWorkspaceReadFile = "coder_workspace_read_file"
ToolNameWorkspaceWriteFile = "coder_workspace_write_file"
)

func NewDeps(client *codersdk.Client, opts ...func(*Deps)) (Deps, error) {
Expand Down Expand Up @@ -211,6 +212,7 @@ var All = []GenericTool{
ChatGPTSearch.Generic(),
ChatGPTFetch.Generic(),
WorkspaceReadFile.Generic(),
WorkspaceWriteFile.Generic(),
}

type ReportTaskArgs struct {
Expand Down Expand Up @@ -1441,6 +1443,54 @@ var WorkspaceReadFile = Tool[WorkspaceReadFileArgs, WorkspaceReadFileResponse]{
},
}

type WorkspaceWriteFileArgs struct {
Workspace string `json:"workspace"`
Path string `json:"path"`
Content []byte `json:"content"`
}

var WorkspaceWriteFile = Tool[WorkspaceWriteFileArgs, codersdk.Response]{
Tool: aisdk.Tool{
Name: ToolNameWorkspaceWriteFile,
Description: `Write a file in a workspace.`,
Schema: aisdk.Schema{
Properties: map[string]any{
"workspace": map[string]any{
"type": "string",
"description": "The workspace name in the format [owner/]workspace[.agent]. If an owner is not specified, the authenticated user is used.",
},
"path": map[string]any{
"type": "string",
"description": "The absolute path of the file to write in the workspace.",
},
"content": map[string]any{
"type": "string",
"description": "The base64-encoded bytes to write to the file.",
Comment on lines +1466 to +1468
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we maybe want to limit the amount of file writing? Reads are 1MB, I think, so it might make sense to introduce a restriction here, too.

Copy link
Member Author

@code-asher code-asher Sep 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had that thought too but I am not sure how to make that restriction. The max is useful to make sure we are not going to run out of memory, but by the time we are here, anything in memory is already in memory and we might as well write it out since we have it.

I think instead of the tool handler, the restriction would need to be in the MCP library at the point where it reads in the data (and maybe it already has limits, I have not tested). Not sure if this is configurable though, or if we can hook in somewhere to enforce the restriction.

Copy link
Member Author

@code-asher code-asher Sep 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm maybe WithHooks will let me do it. Ohhh wait there is a MaxLength that I could pass through, which is probably what you were already thinking of 😅 Gonna implement this. I will need an offset as well, or an append boolean.

edit: er wait does not seem the server is enforcing the max length property...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, gotcha. In that case, I think we're good here.

},
},
Required: []string{"path", "workspace", "content"},
},
},
UserClientOptional: true,
Handler: func(ctx context.Context, deps Deps, args WorkspaceWriteFileArgs) (codersdk.Response, error) {
conn, err := newAgentConn(ctx, deps.coderClient, args.Workspace)
if err != nil {
return codersdk.Response{}, err
}
defer conn.Close()

reader := bytes.NewReader(args.Content)
err = conn.WriteFile(ctx, args.Path, reader)
if err != nil {
return codersdk.Response{}, err
}

return codersdk.Response{
Message: "File written successfully.",
}, nil
},
}

// NormalizeWorkspaceInput converts workspace name input to standard format.
// Handles the following input formats:
// - workspace → workspace
Expand Down
Loading
Loading