diff --git a/agent/api.go b/agent/api.go index 809a62bedf4b9..bb3adc9e2457c 100644 --- a/agent/api.go +++ b/agent/api.go @@ -61,6 +61,7 @@ func (a *agent) apiHandler() http.Handler { r.Get("/api/v0/netcheck", a.HandleNetcheck) r.Post("/api/v0/list-directory", a.HandleLS) r.Get("/api/v0/read-file", a.HandleReadFile) + r.Post("/api/v0/write-file", a.HandleWriteFile) r.Get("/debug/logs", a.HandleHTTPDebugLogs) r.Get("/debug/magicsock", a.HandleHTTPDebugMagicsock) r.Get("/debug/magicsock/debug-logging/{state}", a.HandleHTTPMagicsockDebugLoggingState) diff --git a/agent/files.go b/agent/files.go index 0f5db82058d00..2f6a217093640 100644 --- a/agent/files.go +++ b/agent/files.go @@ -3,12 +3,14 @@ package agent import ( "context" "errors" + "fmt" "io" "mime" "net/http" "os" "path/filepath" "strconv" + "syscall" "golang.org/x/xerrors" @@ -17,6 +19,8 @@ import ( "github.com/coder/coder/v2/codersdk" ) +type HTTPResponseCode = int + func (a *agent) HandleReadFile(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() @@ -43,7 +47,7 @@ func (a *agent) HandleReadFile(rw http.ResponseWriter, r *http.Request) { } } -func (a *agent) streamFile(ctx context.Context, rw http.ResponseWriter, path string, offset, limit int64) (int, error) { +func (a *agent) streamFile(ctx context.Context, rw http.ResponseWriter, path string, offset, limit int64) (HTTPResponseCode, error) { if !filepath.IsAbs(path) { return http.StatusBadRequest, xerrors.Errorf("file path must be absolute: %q", path) } @@ -94,3 +98,70 @@ func (a *agent) streamFile(ctx context.Context, rw http.ResponseWriter, path str return 0, nil } + +func (a *agent) HandleWriteFile(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + query := r.URL.Query() + parser := httpapi.NewQueryParamParser().RequiredNotEmpty("path") + path := parser.String(query, "", "path") + parser.ErrorExcessParams(query) + if len(parser.Errors) > 0 { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Query parameters have invalid values.", + Validations: parser.Errors, + }) + return + } + + status, err := a.writeFile(ctx, r, path) + if err != nil { + httpapi.Write(ctx, rw, status, codersdk.Response{ + Message: err.Error(), + }) + return + } + + httpapi.Write(ctx, rw, http.StatusOK, codersdk.Response{ + Message: fmt.Sprintf("Successfully wrote to %q", path), + }) +} + +func (a *agent) writeFile(ctx context.Context, r *http.Request, path string) (HTTPResponseCode, error) { + if !filepath.IsAbs(path) { + return http.StatusBadRequest, xerrors.Errorf("file path must be absolute: %q", path) + } + + dir := filepath.Dir(path) + err := a.filesystem.MkdirAll(dir, 0o755) + if err != nil { + status := http.StatusInternalServerError + switch { + case errors.Is(err, os.ErrPermission): + status = http.StatusForbidden + case errors.Is(err, syscall.ENOTDIR): + status = http.StatusBadRequest + } + return status, err + } + + f, err := a.filesystem.Create(path) + if err != nil { + status := http.StatusInternalServerError + switch { + case errors.Is(err, os.ErrPermission): + status = http.StatusForbidden + case errors.Is(err, syscall.EISDIR): + status = http.StatusBadRequest + } + return status, err + } + defer f.Close() + + _, err = io.Copy(f, r.Body) + if err != nil && !errors.Is(err, io.EOF) && ctx.Err() == nil { + a.logger.Error(ctx, "workspace agent write file", slog.Error(err)) + } + + return 0, nil +} diff --git a/agent/files_test.go b/agent/files_test.go index 80aeadd518cf0..e443f27e73e2b 100644 --- a/agent/files_test.go +++ b/agent/files_test.go @@ -1,11 +1,14 @@ package agent_test import ( + "bytes" "context" "io" "net/http" "os" "path/filepath" + "runtime" + "syscall" "testing" "github.com/spf13/afero" @@ -38,6 +41,56 @@ func (fs *testFs) Open(name string) (afero.File, error) { return fs.Fs.Open(name) } +func (fs *testFs) Create(name string) (afero.File, error) { + if err := fs.intercept("create", name); err != nil { + return nil, err + } + // Unlike os, afero lets you create files where directories already exist and + // lets you nest them underneath files, somehow. + stat, err := fs.Fs.Stat(name) + if err == nil && stat.IsDir() { + return nil, &os.PathError{ + Op: "open", + Path: name, + Err: syscall.EISDIR, + } + } + stat, err = fs.Fs.Stat(filepath.Dir(name)) + if err == nil && !stat.IsDir() { + return nil, &os.PathError{ + Op: "open", + Path: name, + Err: syscall.ENOTDIR, + } + } + return fs.Fs.Create(name) +} + +func (fs *testFs) MkdirAll(name string, mode os.FileMode) error { + if err := fs.intercept("mkdirall", name); err != nil { + return err + } + // Unlike os, afero lets you create directories where files already exist and + // lets you nest them underneath files somehow. + stat, err := fs.Fs.Stat(filepath.Dir(name)) + if err == nil && !stat.IsDir() { + return &os.PathError{ + Op: "mkdir", + Path: name, + Err: syscall.ENOTDIR, + } + } + stat, err = fs.Fs.Stat(name) + if err == nil && !stat.IsDir() { + return &os.PathError{ + Op: "mkdir", + Path: name, + Err: syscall.ENOTDIR, + } + } + return fs.Fs.MkdirAll(name, mode) +} + func TestReadFile(t *testing.T) { t.Parallel() @@ -82,7 +135,7 @@ func TestReadFile(t *testing.T) { error: "\"path\" is required", }, { - name: "RelativePath", + name: "RelativePathDotSlash", path: "./relative", errCode: http.StatusBadRequest, error: "file path must be absolute", @@ -214,3 +267,112 @@ func TestReadFile(t *testing.T) { }) } } + +func TestWriteFile(t *testing.T) { + t.Parallel() + + tmpdir := os.TempDir() + noPermsFilePath := filepath.Join(tmpdir, "no-perms-file") + noPermsDirPath := filepath.Join(tmpdir, "no-perms-dir") + //nolint:dogsled + conn, _, _, fs, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, opts *agent.Options) { + opts.Filesystem = newTestFs(opts.Filesystem, func(call, file string) error { + if file == noPermsFilePath || file == noPermsDirPath { + return os.ErrPermission + } + return nil + }) + }) + + dirPath := filepath.Join(tmpdir, "directory") + err := fs.MkdirAll(dirPath, 0o755) + require.NoError(t, err) + + filePath := filepath.Join(tmpdir, "file") + err = afero.WriteFile(fs, filePath, []byte("content"), 0o644) + require.NoError(t, err) + + notDirErr := "not a directory" + if runtime.GOOS == "windows" { + notDirErr = "cannot find the path" + } + + tests := []struct { + name string + path string + bytes []byte + errCode int + error string + }{ + { + name: "NoPath", + path: "", + errCode: http.StatusBadRequest, + error: "\"path\" is required", + }, + { + name: "RelativePathDotSlash", + path: "./relative", + errCode: http.StatusBadRequest, + error: "file path must be absolute", + }, + { + name: "RelativePath", + path: "also-relative", + errCode: http.StatusBadRequest, + error: "file path must be absolute", + }, + { + name: "NonExistent", + path: filepath.Join(tmpdir, "/nested/does-not-exist"), + bytes: []byte("now it does exist"), + }, + { + name: "IsDir", + path: dirPath, + errCode: http.StatusBadRequest, + error: "is a directory", + }, + { + name: "IsNotDir", + path: filepath.Join(filePath, "file2"), + errCode: http.StatusBadRequest, + error: notDirErr, + }, + { + name: "NoPermissionsFile", + path: noPermsFilePath, + errCode: http.StatusForbidden, + error: "permission denied", + }, + { + name: "NoPermissionsDir", + path: filepath.Join(noPermsDirPath, "within-no-perm-dir"), + errCode: http.StatusForbidden, + error: "permission denied", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + reader := bytes.NewReader(tt.bytes) + err := conn.WriteFile(ctx, tt.path, reader) + if tt.errCode != 0 { + require.Error(t, err) + cerr := coderdtest.SDKError(t, err) + require.Contains(t, cerr.Error(), tt.error) + require.Equal(t, tt.errCode, cerr.StatusCode()) + } else { + require.NoError(t, err) + b, err := afero.ReadFile(fs, tt.path) + require.NoError(t, err) + require.Equal(t, tt.bytes, b) + } + }) + } +} diff --git a/codersdk/toolsdk/toolsdk.go b/codersdk/toolsdk/toolsdk.go index f63acae1c1137..46c296c0535aa 100644 --- a/codersdk/toolsdk/toolsdk.go +++ b/codersdk/toolsdk/toolsdk.go @@ -43,6 +43,7 @@ const ( ToolNameChatGPTSearch = "search" ToolNameChatGPTFetch = "fetch" ToolNameWorkspaceReadFile = "coder_workspace_read_file" + ToolNameWorkspaceWriteFile = "coder_workspace_write_file" ) func NewDeps(client *codersdk.Client, opts ...func(*Deps)) (Deps, error) { @@ -211,6 +212,7 @@ var All = []GenericTool{ ChatGPTSearch.Generic(), ChatGPTFetch.Generic(), WorkspaceReadFile.Generic(), + WorkspaceWriteFile.Generic(), } type ReportTaskArgs struct { @@ -1441,6 +1443,54 @@ var WorkspaceReadFile = Tool[WorkspaceReadFileArgs, WorkspaceReadFileResponse]{ }, } +type WorkspaceWriteFileArgs struct { + Workspace string `json:"workspace"` + Path string `json:"path"` + Content []byte `json:"content"` +} + +var WorkspaceWriteFile = Tool[WorkspaceWriteFileArgs, codersdk.Response]{ + Tool: aisdk.Tool{ + Name: ToolNameWorkspaceWriteFile, + Description: `Write a file in a workspace.`, + Schema: aisdk.Schema{ + Properties: map[string]any{ + "workspace": map[string]any{ + "type": "string", + "description": "The workspace name in the format [owner/]workspace[.agent]. If an owner is not specified, the authenticated user is used.", + }, + "path": map[string]any{ + "type": "string", + "description": "The absolute path of the file to write in the workspace.", + }, + "content": map[string]any{ + "type": "string", + "description": "The base64-encoded bytes to write to the file.", + }, + }, + Required: []string{"path", "workspace", "content"}, + }, + }, + UserClientOptional: true, + Handler: func(ctx context.Context, deps Deps, args WorkspaceWriteFileArgs) (codersdk.Response, error) { + conn, err := newAgentConn(ctx, deps.coderClient, args.Workspace) + if err != nil { + return codersdk.Response{}, err + } + defer conn.Close() + + reader := bytes.NewReader(args.Content) + err = conn.WriteFile(ctx, args.Path, reader) + if err != nil { + return codersdk.Response{}, err + } + + return codersdk.Response{ + Message: "File written successfully.", + }, nil + }, +} + // NormalizeWorkspaceInput converts workspace name input to standard format. // Handles the following input formats: // - workspace → workspace diff --git a/codersdk/toolsdk/toolsdk_test.go b/codersdk/toolsdk/toolsdk_test.go index ea37375af62bc..0030549f5eea2 100644 --- a/codersdk/toolsdk/toolsdk_test.go +++ b/codersdk/toolsdk/toolsdk_test.go @@ -555,6 +555,30 @@ func TestTools(t *testing.T) { }) } }) + + t.Run("WorkspaceWriteFile", func(t *testing.T) { + t.Parallel() + + client, workspace, agentToken := setupWorkspaceForAgent(t) + fs := afero.NewMemMapFs() + _ = agenttest.New(t, client.URL, agentToken, func(opts *agent.Options) { + opts.Filesystem = fs + }) + coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).Wait() + tb, err := toolsdk.NewDeps(client) + require.NoError(t, err) + + _, err = testTool(t, toolsdk.WorkspaceWriteFile, tb, toolsdk.WorkspaceWriteFileArgs{ + Workspace: workspace.Name, + Path: "/test/some/path", + Content: []byte("content"), + }) + require.NoError(t, err) + + b, err := afero.ReadFile(fs, "/test/some/path") + require.NoError(t, err) + require.Equal(t, []byte("content"), b) + }) } // TestedTools keeps track of which tools have been tested. diff --git a/codersdk/workspacesdk/agentconn.go b/codersdk/workspacesdk/agentconn.go index 7efdb06520ab0..0afb6f0c868a8 100644 --- a/codersdk/workspacesdk/agentconn.go +++ b/codersdk/workspacesdk/agentconn.go @@ -61,6 +61,7 @@ type AgentConn interface { ReconnectingPTY(ctx context.Context, id uuid.UUID, height uint16, width uint16, command string, initOpts ...AgentReconnectingPTYInitOption) (net.Conn, error) RecreateDevcontainer(ctx context.Context, devcontainerID string) (codersdk.Response, error) ReadFile(ctx context.Context, path string, offset, limit int64) (io.ReadCloser, string, error) + WriteFile(ctx context.Context, path string, reader io.Reader) error SSH(ctx context.Context) (*gonet.TCPConn, error) SSHClient(ctx context.Context) (*ssh.Client, error) SSHClientOnPort(ctx context.Context, port uint16) (*ssh.Client, error) @@ -501,6 +502,27 @@ func (c *agentConn) ReadFile(ctx context.Context, path string, offset, limit int return res.Body, mimeType, nil } +// WriteFile writes to a file in the workspace. +func (c *agentConn) WriteFile(ctx context.Context, path string, reader io.Reader) error { + ctx, span := tracing.StartSpan(ctx) + defer span.End() + + res, err := c.apiRequest(ctx, http.MethodPost, fmt.Sprintf("/api/v0/write-file?path=%s", path), reader) + if err != nil { + return xerrors.Errorf("do request: %w", err) + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return codersdk.ReadBodyAsError(res) + } + + var m codersdk.Response + if err := json.NewDecoder(res.Body).Decode(&m); err != nil { + return xerrors.Errorf("decode response body: %w", err) + } + return nil +} + // apiRequest makes a request to the workspace agent's HTTP API server. func (c *agentConn) apiRequest(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) { ctx, span := tracing.StartSpan(ctx) diff --git a/codersdk/workspacesdk/agentconnmock/agentconnmock.go b/codersdk/workspacesdk/agentconnmock/agentconnmock.go index 6f93fb6e85ce1..4956be0c26c2b 100644 --- a/codersdk/workspacesdk/agentconnmock/agentconnmock.go +++ b/codersdk/workspacesdk/agentconnmock/agentconnmock.go @@ -387,3 +387,17 @@ func (mr *MockAgentConnMockRecorder) WatchContainers(ctx, logger any) *gomock.Ca mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WatchContainers", reflect.TypeOf((*MockAgentConn)(nil).WatchContainers), ctx, logger) } + +// WriteFile mocks base method. +func (m *MockAgentConn) WriteFile(ctx context.Context, path string, reader io.Reader) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WriteFile", ctx, path, reader) + ret0, _ := ret[0].(error) + return ret0 +} + +// WriteFile indicates an expected call of WriteFile. +func (mr *MockAgentConnMockRecorder) WriteFile(ctx, path, reader any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteFile", reflect.TypeOf((*MockAgentConn)(nil).WriteFile), ctx, path, reader) +}