From c5282ea985bf376d604ea4b1d0bc0ee88e760836 Mon Sep 17 00:00:00 2001 From: Asher Date: Thu, 28 Aug 2025 09:20:09 -0800 Subject: [PATCH 1/3] Extract connecting to agent from MCP This will be used in multiple tools. --- codersdk/toolsdk/bash.go | 61 +------------------------------ codersdk/toolsdk/bash_test.go | 57 ----------------------------- codersdk/toolsdk/toolsdk.go | 63 ++++++++++++++++++++++++++++++++ codersdk/toolsdk/toolsdk_test.go | 54 +++++++++++++++++++++++++++ 4 files changed, 119 insertions(+), 116 deletions(-) diff --git a/codersdk/toolsdk/bash.go b/codersdk/toolsdk/bash.go index 037227337bfc9..7497363c2a54e 100644 --- a/codersdk/toolsdk/bash.go +++ b/codersdk/toolsdk/bash.go @@ -17,7 +17,6 @@ import ( "github.com/coder/coder/v2/cli/cliui" "github.com/coder/coder/v2/codersdk" - "github.com/coder/coder/v2/codersdk/workspacesdk" ) type WorkspaceBashArgs struct { @@ -94,42 +93,12 @@ Examples: ctx, cancel := context.WithTimeoutCause(ctx, 5*time.Minute, xerrors.New("MCP handler timeout after 5 min")) defer cancel() - // Normalize workspace input to handle various formats - workspaceName := NormalizeWorkspaceInput(args.Workspace) - - // Find workspace and agent - _, workspaceAgent, err := findWorkspaceAndAgent(ctx, deps.coderClient, workspaceName) - if err != nil { - return WorkspaceBashResult{}, xerrors.Errorf("failed to find workspace: %w", err) - } - - // Wait for agent to be ready - if err := cliui.Agent(ctx, io.Discard, workspaceAgent.ID, cliui.AgentOptions{ - FetchInterval: 0, - Fetch: deps.coderClient.WorkspaceAgent, - FetchLogs: deps.coderClient.WorkspaceAgentLogsAfter, - Wait: true, // Always wait for startup scripts - }); err != nil { - return WorkspaceBashResult{}, xerrors.Errorf("agent not ready: %w", err) - } - - // Create workspace SDK client for agent connection - wsClient := workspacesdk.New(deps.coderClient) - - // Dial agent - conn, err := wsClient.DialAgent(ctx, workspaceAgent.ID, &workspacesdk.DialAgentOptions{ - BlockEndpoints: false, 
- }) + conn, err := newAgentConn(ctx, deps.coderClient, args.Workspace) if err != nil { - return WorkspaceBashResult{}, xerrors.Errorf("failed to dial agent: %w", err) + return WorkspaceBashResult{}, err } defer conn.Close() - // Wait for connection to be reachable - if !conn.AwaitReachable(ctx) { - return WorkspaceBashResult{}, xerrors.New("agent connection not reachable") - } - // Create SSH client sshClient, err := conn.SSHClient(ctx) if err != nil { @@ -323,32 +292,6 @@ func namedWorkspace(ctx context.Context, client *codersdk.Client, identifier str return client.WorkspaceByOwnerAndName(ctx, owner, workspaceName, codersdk.WorkspaceOptions{}) } -// NormalizeWorkspaceInput converts workspace name input to standard format. -// Handles the following input formats: -// - workspace → workspace -// - workspace.agent → workspace.agent -// - owner/workspace → owner/workspace -// - owner--workspace → owner/workspace -// - owner/workspace.agent → owner/workspace.agent -// - owner--workspace.agent → owner/workspace.agent -// - agent.workspace.owner → owner/workspace.agent (Coder Connect format) -func NormalizeWorkspaceInput(input string) string { - // Handle the special Coder Connect format: agent.workspace.owner - // This format uses only dots and has exactly 3 parts - if strings.Count(input, ".") == 2 && !strings.Contains(input, "/") && !strings.Contains(input, "--") { - parts := strings.Split(input, ".") - if len(parts) == 3 { - // Convert agent.workspace.owner → owner/workspace.agent - return fmt.Sprintf("%s/%s.%s", parts[2], parts[1], parts[0]) - } - } - - // Convert -- separator to / separator for consistency - normalized := strings.ReplaceAll(input, "--", "/") - - return normalized -} - // executeCommandWithTimeout executes a command with timeout support func executeCommandWithTimeout(ctx context.Context, session *gossh.Session, command string) ([]byte, error) { // Set up pipes to capture output diff --git a/codersdk/toolsdk/bash_test.go 
b/codersdk/toolsdk/bash_test.go index caf54109688ea..da05a71ce3eda 100644 --- a/codersdk/toolsdk/bash_test.go +++ b/codersdk/toolsdk/bash_test.go @@ -99,63 +99,6 @@ func TestWorkspaceBash(t *testing.T) { }) } -func TestNormalizeWorkspaceInput(t *testing.T) { - t.Parallel() - if runtime.GOOS == "windows" { - t.Skip("Skipping on Windows: Workspace MCP bash tools rely on a Unix-like shell (bash) and POSIX/SSH semantics. Use Linux/macOS or WSL for these tests.") - } - - testCases := []struct { - name string - input string - expected string - }{ - { - name: "SimpleWorkspace", - input: "workspace", - expected: "workspace", - }, - { - name: "WorkspaceWithAgent", - input: "workspace.agent", - expected: "workspace.agent", - }, - { - name: "OwnerAndWorkspace", - input: "owner/workspace", - expected: "owner/workspace", - }, - { - name: "OwnerDashWorkspace", - input: "owner--workspace", - expected: "owner/workspace", - }, - { - name: "OwnerWorkspaceAgent", - input: "owner/workspace.agent", - expected: "owner/workspace.agent", - }, - { - name: "OwnerDashWorkspaceAgent", - input: "owner--workspace.agent", - expected: "owner/workspace.agent", - }, - { - name: "CoderConnectFormat", - input: "agent.workspace.owner", // Special Coder Connect reverse format - expected: "owner/workspace.agent", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - result := toolsdk.NormalizeWorkspaceInput(tc.input) - require.Equal(t, tc.expected, result, "Input %q should normalize to %q but got %q", tc.input, tc.expected, result) - }) - } -} - func TestAllToolsIncludesBash(t *testing.T) { t.Parallel() if runtime.GOOS == "windows" { diff --git a/codersdk/toolsdk/toolsdk.go b/codersdk/toolsdk/toolsdk.go index 7cb8cecb25234..a5f0f5d66bcad 100644 --- a/codersdk/toolsdk/toolsdk.go +++ b/codersdk/toolsdk/toolsdk.go @@ -5,8 +5,10 @@ import ( "bytes" "context" "encoding/json" + "fmt" "io" "runtime/debug" + "strings" "github.com/google/uuid" "golang.org/x/xerrors" 
@@ -1360,3 +1362,64 @@ type MinimalTemplate struct { ActiveVersionID uuid.UUID `json:"active_version_id"` ActiveUserCount int `json:"active_user_count"` } + +// NormalizeWorkspaceInput converts workspace name input to standard format. +// Handles the following input formats: +// - workspace → workspace +// - workspace.agent → workspace.agent +// - owner/workspace → owner/workspace +// - owner--workspace → owner/workspace +// - owner/workspace.agent → owner/workspace.agent +// - owner--workspace.agent → owner/workspace.agent +// - agent.workspace.owner → owner/workspace.agent (Coder Connect format) +func NormalizeWorkspaceInput(input string) string { + // Handle the special Coder Connect format: agent.workspace.owner + // This format uses only dots and has exactly 3 parts + if strings.Count(input, ".") == 2 && !strings.Contains(input, "/") && !strings.Contains(input, "--") { + parts := strings.Split(input, ".") + if len(parts) == 3 { + // Convert agent.workspace.owner → owner/workspace.agent + return fmt.Sprintf("%s/%s.%s", parts[2], parts[1], parts[0]) + } + } + + // Convert -- separator to / separator for consistency + normalized := strings.ReplaceAll(input, "--", "/") + + return normalized +} + +// newAgentConn returns a connection to the agent specified by the workspace, +// which must be in the format [owner/]workspace[.agent]. +func newAgentConn(ctx context.Context, client *codersdk.Client, workspace string) (workspacesdk.AgentConn, error) { + workspaceName := NormalizeWorkspaceInput(workspace) + _, workspaceAgent, err := findWorkspaceAndAgent(ctx, client, workspaceName) + if err != nil { + return nil, xerrors.Errorf("failed to find workspace: %w", err) + } + + // Wait for agent to be ready. 
+ if err := cliui.Agent(ctx, io.Discard, workspaceAgent.ID, cliui.AgentOptions{ + FetchInterval: 0, + Fetch: client.WorkspaceAgent, + FetchLogs: client.WorkspaceAgentLogsAfter, + Wait: true, // Always wait for startup scripts + }); err != nil { + return nil, xerrors.Errorf("agent not ready: %w", err) + } + + wsClient := workspacesdk.New(client) + + conn, err := wsClient.DialAgent(ctx, workspaceAgent.ID, &workspacesdk.DialAgentOptions{ + BlockEndpoints: false, + }) + if err != nil { + return nil, xerrors.Errorf("failed to dial agent: %w", err) + } + + if !conn.AwaitReachable(ctx) { + conn.Close() + return nil, xerrors.New("agent connection not reachable") + } + return conn, nil +} diff --git a/codersdk/toolsdk/toolsdk_test.go b/codersdk/toolsdk/toolsdk_test.go index fb321e90e7dee..1e2e05a7406c2 100644 --- a/codersdk/toolsdk/toolsdk_test.go +++ b/codersdk/toolsdk/toolsdk_test.go @@ -748,3 +748,57 @@ func TestReportTaskWithReporter(t *testing.T) { // Verify response require.Equal(t, "Thanks for reporting!", result.Message) } + +func TestNormalizeWorkspaceInput(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + input string + expected string + }{ + { + name: "SimpleWorkspace", + input: "workspace", + expected: "workspace", + }, + { + name: "WorkspaceWithAgent", + input: "workspace.agent", + expected: "workspace.agent", + }, + { + name: "OwnerAndWorkspace", + input: "owner/workspace", + expected: "owner/workspace", + }, + { + name: "OwnerDashWorkspace", + input: "owner--workspace", + expected: "owner/workspace", + }, + { + name: "OwnerWorkspaceAgent", + input: "owner/workspace.agent", + expected: "owner/workspace.agent", + }, + { + name: "OwnerDashWorkspaceAgent", + input: "owner--workspace.agent", + expected: "owner/workspace.agent", + }, + { + name: "CoderConnectFormat", + input: "agent.workspace.owner", // Special Coder Connect reverse format + expected: "owner/workspace.agent", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, 
func(t *testing.T) { + t.Parallel() + result := toolsdk.NormalizeWorkspaceInput(tc.input) + require.Equal(t, tc.expected, result, "Input %q should normalize to %q but got %q", tc.input, tc.expected, result) + }) + } +} From a0429366373f9d207d9cd7e18a49d9d50716fa13 Mon Sep 17 00:00:00 2001 From: Asher Date: Tue, 26 Aug 2025 12:00:01 -0800 Subject: [PATCH 2/3] Add coder_workspace_read_file MCP tool --- agent/api.go | 1 + agent/files.go | 96 ++++++++ agent/files_test.go | 208 ++++++++++++++++++ coderd/httpapi/queryparams.go | 21 ++ codersdk/toolsdk/toolsdk.go | 60 +++++ codersdk/toolsdk/toolsdk_test.go | 88 ++++++++ codersdk/workspacesdk/agentconn.go | 40 ++++ .../agentconnmock/agentconnmock.go | 16 ++ 8 files changed, 530 insertions(+) create mode 100644 agent/files.go create mode 100644 agent/files_test.go diff --git a/agent/api.go b/agent/api.go index ca0760e130ffe..809a62bedf4b9 100644 --- a/agent/api.go +++ b/agent/api.go @@ -60,6 +60,7 @@ func (a *agent) apiHandler() http.Handler { r.Get("/api/v0/listening-ports", lp.handler) r.Get("/api/v0/netcheck", a.HandleNetcheck) r.Post("/api/v0/list-directory", a.HandleLS) + r.Get("/api/v0/read-file", a.HandleReadFile) r.Get("/debug/logs", a.HandleHTTPDebugLogs) r.Get("/debug/magicsock", a.HandleHTTPDebugMagicsock) r.Get("/debug/magicsock/debug-logging/{state}", a.HandleHTTPMagicsockDebugLoggingState) diff --git a/agent/files.go b/agent/files.go new file mode 100644 index 0000000000000..0f5db82058d00 --- /dev/null +++ b/agent/files.go @@ -0,0 +1,96 @@ +package agent + +import ( + "context" + "errors" + "io" + "mime" + "net/http" + "os" + "path/filepath" + "strconv" + + "golang.org/x/xerrors" + + "cdr.dev/slog" + "github.com/coder/coder/v2/coderd/httpapi" + "github.com/coder/coder/v2/codersdk" +) + +func (a *agent) HandleReadFile(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + query := r.URL.Query() + parser := httpapi.NewQueryParamParser().RequiredNotEmpty("path") + path := parser.String(query, "", 
"path") + offset := parser.PositiveInt64(query, 0, "offset") + limit := parser.PositiveInt64(query, 0, "limit") + parser.ErrorExcessParams(query) + if len(parser.Errors) > 0 { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Query parameters have invalid values.", + Validations: parser.Errors, + }) + return + } + + status, err := a.streamFile(ctx, rw, path, offset, limit) + if err != nil { + httpapi.Write(ctx, rw, status, codersdk.Response{ + Message: err.Error(), + }) + return + } +} + +func (a *agent) streamFile(ctx context.Context, rw http.ResponseWriter, path string, offset, limit int64) (int, error) { + if !filepath.IsAbs(path) { + return http.StatusBadRequest, xerrors.Errorf("file path must be absolute: %q", path) + } + + f, err := a.filesystem.Open(path) + if err != nil { + status := http.StatusInternalServerError + switch { + case errors.Is(err, os.ErrNotExist): + status = http.StatusNotFound + case errors.Is(err, os.ErrPermission): + status = http.StatusForbidden + } + return status, err + } + defer f.Close() + + stat, err := f.Stat() + if err != nil { + return http.StatusInternalServerError, err + } + + if stat.IsDir() { + return http.StatusBadRequest, xerrors.Errorf("open %s: not a file", path) + } + + size := stat.Size() + if limit == 0 { + limit = size + } + bytesRemaining := max(size-offset, 0) + bytesToRead := min(bytesRemaining, limit) + + // Relying on just the file name for the mime type for now. 
+ mimeType := mime.TypeByExtension(filepath.Ext(path)) + if mimeType == "" { + mimeType = "application/octet-stream" + } + rw.Header().Set("Content-Type", mimeType) + rw.Header().Set("Content-Length", strconv.FormatInt(bytesToRead, 10)) + rw.WriteHeader(http.StatusOK) + + reader := io.NewSectionReader(f, offset, bytesToRead) + _, err = io.Copy(rw, reader) + if err != nil && !errors.Is(err, io.EOF) && ctx.Err() == nil { + a.logger.Error(ctx, "workspace agent read file", slog.Error(err)) + } + + return 0, nil +} diff --git a/agent/files_test.go b/agent/files_test.go new file mode 100644 index 0000000000000..3731db2dccbdd --- /dev/null +++ b/agent/files_test.go @@ -0,0 +1,208 @@ +package agent_test + +import ( + "context" + "net/http" + "os" + "path/filepath" + "testing" + + "github.com/spf13/afero" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/agent" + "github.com/coder/coder/v2/agent/agenttest" + "github.com/coder/coder/v2/coderd/coderdtest" + "github.com/coder/coder/v2/codersdk/agentsdk" + "github.com/coder/coder/v2/testutil" +) + +type testFs struct { + afero.Fs + // intercept can return an error for testing when a call fails. 
+ intercept func(call, file string) error +} + +func newTestFs(base afero.Fs, intercept func(call, file string) error) *testFs { + return &testFs{ + Fs: base, + intercept: intercept, + } +} + +func (fs *testFs) Open(name string) (afero.File, error) { + if err := fs.intercept("open", name); err != nil { + return nil, err + } + return fs.Fs.Open(name) +} + +func TestReadFile(t *testing.T) { + t.Parallel() + + tmpdir := os.TempDir() + noPermsFilePath := filepath.Join(tmpdir, "no-perms") + //nolint:dogsled + conn, _, _, fs, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, opts *agent.Options) { + opts.Filesystem = newTestFs(opts.Filesystem, func(call, file string) error { + if file == noPermsFilePath { + return os.ErrPermission + } + return nil + }) + }) + + dirPath := filepath.Join(tmpdir, "a-directory") + err := fs.MkdirAll(dirPath, 0o755) + require.NoError(t, err) + + filePath := filepath.Join(tmpdir, "file") + err = afero.WriteFile(fs, filePath, []byte("content"), 0o644) + require.NoError(t, err) + + imagePath := filepath.Join(tmpdir, "file.png") + err = afero.WriteFile(fs, imagePath, []byte("not really an image"), 0o644) + require.NoError(t, err) + + tests := []struct { + name string + path string + limit int64 + offset int64 + bytes []byte + mimeType string + errCode int + error string + }{ + { + name: "NoPath", + path: "", + errCode: http.StatusBadRequest, + error: "\"path\" is required", + }, + { + name: "RelativePath", + path: "./relative", + errCode: http.StatusBadRequest, + error: "file path must be absolute", + }, + { + name: "RelativePath", + path: "also-relative", + errCode: http.StatusBadRequest, + error: "file path must be absolute", + }, + { + name: "NegativeLimit", + path: filePath, + limit: -10, + errCode: http.StatusBadRequest, + error: "value is negative", + }, + { + name: "NegativeOffset", + path: filePath, + offset: -10, + errCode: http.StatusBadRequest, + error: "value is negative", + }, + { + name: "NonExistent", + path: 
filepath.Join(tmpdir, "does-not-exist"), + errCode: http.StatusNotFound, + error: "file does not exist", + }, + { + name: "IsDir", + path: dirPath, + errCode: http.StatusBadRequest, + error: "not a file", + }, + { + name: "NoPermissions", + path: noPermsFilePath, + errCode: http.StatusForbidden, + error: "permission denied", + }, + { + name: "Defaults", + path: filePath, + bytes: []byte("content"), + }, + { + name: "Limit1", + path: filePath, + limit: 1, + bytes: []byte("c"), + }, + { + name: "Offset1", + path: filePath, + offset: 1, + bytes: []byte("ontent"), + }, + { + name: "Limit1Offset2", + path: filePath, + limit: 1, + offset: 2, + bytes: []byte("n"), + }, + { + name: "Limit7Offset0", + path: filePath, + limit: 7, + offset: 0, + bytes: []byte("content"), + }, + { + name: "Limit100", + path: filePath, + limit: 100, + bytes: []byte("content"), + }, + { + name: "Offset7", + path: filePath, + offset: 7, + bytes: []byte{}, + }, + { + name: "Offset100", + path: filePath, + offset: 100, + bytes: []byte{}, + }, + { + name: "MimeTypePng", + path: imagePath, + bytes: []byte("not really an image"), + mimeType: "image/png", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + b, mimeType, err := conn.ReadFile(ctx, tt.path, tt.offset, tt.limit) + if tt.errCode != 0 { + require.Error(t, err) + cerr := coderdtest.SDKError(t, err) + require.Contains(t, cerr.Error(), tt.error) + require.Equal(t, tt.errCode, cerr.StatusCode()) + } else { + require.NoError(t, err) + require.Equal(t, tt.bytes, b) + expectedMimeType := tt.mimeType + if expectedMimeType == "" { + expectedMimeType = "application/octet-stream" + } + require.Equal(t, expectedMimeType, mimeType) + } + }) + } +} diff --git a/coderd/httpapi/queryparams.go b/coderd/httpapi/queryparams.go index e1bd983ea12a3..d30244eaf04cc 100644 --- a/coderd/httpapi/queryparams.go +++ 
b/coderd/httpapi/queryparams.go @@ -120,6 +120,27 @@ func (p *QueryParamParser) PositiveInt32(vals url.Values, def int32, queryParam return v } +// PositiveInt64 parses the query param as a 64-bit integer and rejects negative values; zero is accepted. +func (p *QueryParamParser) PositiveInt64(vals url.Values, def int64, queryParam string) int64 { + v, err := parseQueryParam(p, vals, func(v string) (int64, error) { + intValue, err := strconv.ParseInt(v, 10, 64) + if err != nil { + return 0, err + } + if intValue < 0 { + return 0, xerrors.Errorf("value is negative") + } + return intValue, nil + }, def, queryParam) + if err != nil { + p.Errors = append(p.Errors, codersdk.ValidationError{ + Field: queryParam, + Detail: fmt.Sprintf("Query param %q must be a valid 64-bit positive integer: %s", queryParam, err.Error()), + }) + } + return v +} + // NullableBoolean will return a null sql value if no input is provided. // SQLc still uses sql.NullBool rather than the generic type. So converting from // the generic type is required. 
diff --git a/codersdk/toolsdk/toolsdk.go b/codersdk/toolsdk/toolsdk.go index a5f0f5d66bcad..78a824616a19e 100644 --- a/codersdk/toolsdk/toolsdk.go +++ b/codersdk/toolsdk/toolsdk.go @@ -16,7 +16,9 @@ import ( "github.com/coder/aisdk-go" "github.com/coder/coder/v2/buildinfo" + "github.com/coder/coder/v2/cli/cliui" "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/codersdk/workspacesdk" ) // Tool name constants to avoid hardcoded strings @@ -40,6 +42,7 @@ const ( ToolNameWorkspaceBash = "coder_workspace_bash" ToolNameChatGPTSearch = "search" ToolNameChatGPTFetch = "fetch" + ToolNameWorkspaceReadFile = "coder_workspace_read_file" ) func NewDeps(client *codersdk.Client, opts ...func(*Deps)) (Deps, error) { @@ -207,6 +210,7 @@ var All = []GenericTool{ WorkspaceBash.Generic(), ChatGPTSearch.Generic(), ChatGPTFetch.Generic(), + WorkspaceReadFile.Generic(), } type ReportTaskArgs struct { @@ -1363,6 +1367,62 @@ type MinimalTemplate struct { ActiveUserCount int `json:"active_user_count"` } +type WorkspaceReadFileArgs struct { + Workspace string `json:"workspace"` + Path string `json:"path"` + Offset int64 `json:"offset"` + Limit int64 `json:"limit"` +} + +type WorkspaceReadFileResponse struct { + // Content is the base64-encoded bytes from the file. + Content []byte `json:"content"` + MimeType string `json:"mimeType"` +} + +var WorkspaceReadFile = Tool[WorkspaceReadFileArgs, WorkspaceReadFileResponse]{ + Tool: aisdk.Tool{ + Name: ToolNameWorkspaceReadFile, + Description: `Read from a file in a workspace.`, + Schema: aisdk.Schema{ + Properties: map[string]any{ + "workspace": map[string]any{ + "type": "string", + "description": "The workspace name in the format [owner/]workspace[.agent]. 
If an owner is not specified, the authenticated user is used.", + }, + "path": map[string]any{ + "type": "string", + "description": "The absolute path of the file to read in the workspace.", + }, + "offset": map[string]any{ + "type": "integer", + "description": "A byte offset indicating where in the file to start reading. Defaults to zero. Empty content in the response indicates the end of the file has been reached.", + }, + "limit": map[string]any{ + "type": "integer", + "description": "The number of bytes to read. Cannot exceed 1 MiB. Defaults to the full size of the file or 1 MiB, whichever is lower.", + }, + }, + Required: []string{"path", "workspace"}, + }, + }, + UserClientOptional: true, + Handler: func(ctx context.Context, deps Deps, args WorkspaceReadFileArgs) (WorkspaceReadFileResponse, error) { + conn, err := newAgentConn(ctx, deps.coderClient, args.Workspace) + if err != nil { + return WorkspaceReadFileResponse{}, err + } + defer conn.Close() + + bytes, mimeType, err := conn.ReadFile(ctx, args.Path, args.Offset, args.Limit) + if err != nil { + return WorkspaceReadFileResponse{}, err + } + + return WorkspaceReadFileResponse{Content: bytes, MimeType: mimeType}, nil + }, +} + // NormalizeWorkspaceInput converts workspace name input to standard format. 
// Handles the following input formats: // - workspace → workspace diff --git a/codersdk/toolsdk/toolsdk_test.go b/codersdk/toolsdk/toolsdk_test.go index 1e2e05a7406c2..45ca8db8e20ff 100644 --- a/codersdk/toolsdk/toolsdk_test.go +++ b/codersdk/toolsdk/toolsdk_test.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "os" + "path/filepath" "runtime" "sort" "sync" @@ -11,12 +12,14 @@ import ( "time" "github.com/google/uuid" + "github.com/spf13/afero" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/goleak" "github.com/coder/aisdk-go" + "github.com/coder/coder/v2/agent" "github.com/coder/coder/v2/agent/agenttest" "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/database" @@ -450,6 +453,91 @@ func TestTools(t *testing.T) { require.Equal(t, 0, result.ExitCode) require.Equal(t, "owner format works", result.Output) }) + + t.Run("WorkspaceReadFile", func(t *testing.T) { + t.Parallel() + + client, workspace, agentToken := setupWorkspaceForAgent(t) + fs := afero.NewMemMapFs() + _ = agenttest.New(t, client.URL, agentToken, func(opts *agent.Options) { + opts.Filesystem = fs + }) + coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).Wait() + tb, err := toolsdk.NewDeps(client) + require.NoError(t, err) + + tmpdir := os.TempDir() + filePath := filepath.Join(tmpdir, "file") + err = afero.WriteFile(fs, filePath, []byte("content"), 0o644) + require.NoError(t, err) + + imagePath := filepath.Join(tmpdir, "file.png") + err = afero.WriteFile(fs, imagePath, []byte("not really an image"), 0o644) + require.NoError(t, err) + + tests := []struct { + name string + path string + limit int64 + offset int64 + mimeType string + bytes []byte + error string + }{ + { + name: "NonExistent", + path: filepath.Join(tmpdir, "does-not-exist"), + error: "file does not exist", + }, + { + name: "Exists", + path: filePath, + bytes: []byte("content"), + mimeType: "application/octet-stream", + }, + { + name: "Limit1Offset2", + path: 
filePath, + limit: 1, + offset: 2, + bytes: []byte("n"), + mimeType: "application/octet-stream", + }, + { + name: "MaxLimit", + path: filePath, + limit: 1 << 21, + error: "limit must be 1048576 or less, got 2097152", + }, + { + name: "ImageMimeType", + path: imagePath, + bytes: []byte("not really an image"), + mimeType: "image/png", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + resp, err := testTool(t, toolsdk.WorkspaceReadFile, tb, toolsdk.WorkspaceReadFileArgs{ + Workspace: workspace.Name, + Path: tt.path, + Limit: tt.limit, + Offset: tt.offset, + }) + if tt.error != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tt.error) + } else { + require.NoError(t, err) + require.Equal(t, tt.bytes, resp.Content) + require.Equal(t, tt.mimeType, resp.MimeType) + } + }) + } + }) } // TestedTools keeps track of which tools have been tested. diff --git a/codersdk/workspacesdk/agentconn.go b/codersdk/workspacesdk/agentconn.go index bb929c9ba2a04..e18080a1881ca 100644 --- a/codersdk/workspacesdk/agentconn.go +++ b/codersdk/workspacesdk/agentconn.go @@ -60,6 +60,7 @@ type AgentConn interface { PrometheusMetrics(ctx context.Context) ([]byte, error) ReconnectingPTY(ctx context.Context, id uuid.UUID, height uint16, width uint16, command string, initOpts ...AgentReconnectingPTYInitOption) (net.Conn, error) RecreateDevcontainer(ctx context.Context, devcontainerID string) (codersdk.Response, error) + ReadFile(ctx context.Context, path string, offset, limit int64) ([]byte, string, error) SSH(ctx context.Context) (*gonet.TCPConn, error) SSHClient(ctx context.Context) (*ssh.Client, error) SSHClientOnPort(ctx context.Context, port uint16) (*ssh.Client, error) @@ -476,6 +477,45 @@ func (c *agentConn) RecreateDevcontainer(ctx context.Context, devcontainerID str return m, nil } +const maxFileLimit = 1 << 20 // 1MiB + +// ReadFile reads from a file from the workspace, returning the file's +// (potentially partial) bytes and the 
mime type. +func (c *agentConn) ReadFile(ctx context.Context, path string, offset, limit int64) ([]byte, string, error) { + ctx, span := tracing.StartSpan(ctx) + defer span.End() + + // Ideally we could stream this all the way back, but it looks like the MCP + // interfaces only allow returning full responses which means the whole thing + // has to be read into memory. So, add a maximum to compensate. + if limit == 0 { + limit = maxFileLimit + } else if limit > maxFileLimit { + return nil, "", xerrors.Errorf("limit must be %d or less, got %d", maxFileLimit, limit) + } + + res, err := c.apiRequest(ctx, http.MethodGet, fmt.Sprintf("/api/v0/read-file?path=%s&offset=%d&limit=%d", path, offset, limit), nil) + if err != nil { + return nil, "", xerrors.Errorf("do request: %w", err) + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return nil, "", codersdk.ReadBodyAsError(res) + } + + bs, err := io.ReadAll(res.Body) + if err != nil { + return nil, "", xerrors.Errorf("read response body: %w", err) + } + + mimeType := res.Header.Get("Content-Type") + if mimeType == "" { + mimeType = "application/octet-stream" + } + + return bs, mimeType, nil +} + // apiRequest makes a request to the workspace agent's HTTP API server. func (c *agentConn) apiRequest(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) { ctx, span := tracing.StartSpan(ctx) diff --git a/codersdk/workspacesdk/agentconnmock/agentconnmock.go b/codersdk/workspacesdk/agentconnmock/agentconnmock.go index eb55bb27938c0..18b16e03c0d71 100644 --- a/codersdk/workspacesdk/agentconnmock/agentconnmock.go +++ b/codersdk/workspacesdk/agentconnmock/agentconnmock.go @@ -232,6 +232,22 @@ func (mr *MockAgentConnMockRecorder) PrometheusMetrics(ctx any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PrometheusMetrics", reflect.TypeOf((*MockAgentConn)(nil).PrometheusMetrics), ctx) } +// ReadFile mocks base method. 
+func (m *MockAgentConn) ReadFile(ctx context.Context, path string, offset, limit int64) ([]byte, string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadFile", ctx, path, offset, limit) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(string) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// ReadFile indicates an expected call of ReadFile. +func (mr *MockAgentConnMockRecorder) ReadFile(ctx, path, offset, limit any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadFile", reflect.TypeOf((*MockAgentConn)(nil).ReadFile), ctx, path, offset, limit) +} + // ReconnectingPTY mocks base method. func (m *MockAgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, height, width uint16, command string, initOpts ...workspacesdk.AgentReconnectingPTYInitOption) (net.Conn, error) { m.ctrl.T.Helper() From f9034236f38a5fcc1176964cee5240a354c7f9ce Mon Sep 17 00:00:00 2001 From: Asher Date: Tue, 9 Sep 2025 14:11:33 -0800 Subject: [PATCH 3/3] Move file read buffering to toolsdk --- agent/files_test.go | 88 ++++++++++--------- codersdk/toolsdk/toolsdk.go | 22 ++++- codersdk/toolsdk/toolsdk_test.go | 22 ++++- codersdk/workspacesdk/agentconn.go | 29 ++---- .../agentconnmock/agentconnmock.go | 4 +- 5 files changed, 97 insertions(+), 68 deletions(-) diff --git a/agent/files_test.go b/agent/files_test.go index 3731db2dccbdd..80aeadd518cf0 100644 --- a/agent/files_test.go +++ b/agent/files_test.go @@ -2,6 +2,7 @@ package agent_test import ( "context" + "io" "net/http" "os" "path/filepath" @@ -125,53 +126,61 @@ func TestReadFile(t *testing.T) { error: "permission denied", }, { - name: "Defaults", - path: filePath, - bytes: []byte("content"), + name: "Defaults", + path: filePath, + bytes: []byte("content"), + mimeType: "application/octet-stream", }, { - name: "Limit1", - path: filePath, - limit: 1, - bytes: []byte("c"), + name: "Limit1", + path: filePath, + limit: 1, + bytes: []byte("c"), + mimeType: 
"application/octet-stream", }, { - name: "Offset1", - path: filePath, - offset: 1, - bytes: []byte("ontent"), + name: "Offset1", + path: filePath, + offset: 1, + bytes: []byte("ontent"), + mimeType: "application/octet-stream", }, { - name: "Limit1Offset2", - path: filePath, - limit: 1, - offset: 2, - bytes: []byte("n"), + name: "Limit1Offset2", + path: filePath, + limit: 1, + offset: 2, + bytes: []byte("n"), + mimeType: "application/octet-stream", }, { - name: "Limit7Offset0", - path: filePath, - limit: 7, - offset: 0, - bytes: []byte("content"), + name: "Limit7Offset0", + path: filePath, + limit: 7, + offset: 0, + bytes: []byte("content"), + mimeType: "application/octet-stream", }, { - name: "Limit100", - path: filePath, - limit: 100, - bytes: []byte("content"), + name: "Limit100", + path: filePath, + limit: 100, + bytes: []byte("content"), + mimeType: "application/octet-stream", }, { - name: "Offset7", - path: filePath, - offset: 7, - bytes: []byte{}, + name: "Offset7", + path: filePath, + offset: 7, + bytes: []byte{}, + mimeType: "application/octet-stream", }, { - name: "Offset100", - path: filePath, - offset: 100, - bytes: []byte{}, + name: "Offset100", + path: filePath, + offset: 100, + bytes: []byte{}, + mimeType: "application/octet-stream", }, { name: "MimeTypePng", @@ -188,7 +197,7 @@ func TestReadFile(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - b, mimeType, err := conn.ReadFile(ctx, tt.path, tt.offset, tt.limit) + reader, mimeType, err := conn.ReadFile(ctx, tt.path, tt.offset, tt.limit) if tt.errCode != 0 { require.Error(t, err) cerr := coderdtest.SDKError(t, err) @@ -196,12 +205,11 @@ func TestReadFile(t *testing.T) { require.Equal(t, tt.errCode, cerr.StatusCode()) } else { require.NoError(t, err) - require.Equal(t, tt.bytes, b) - expectedMimeType := tt.mimeType - if expectedMimeType == "" { - expectedMimeType = "application/octet-stream" - } - require.Equal(t, expectedMimeType, mimeType) + 
defer reader.Close() + bytes, err := io.ReadAll(reader) + require.NoError(t, err) + require.Equal(t, tt.bytes, bytes) + require.Equal(t, tt.mimeType, mimeType) } }) } diff --git a/codersdk/toolsdk/toolsdk.go b/codersdk/toolsdk/toolsdk.go index 78a824616a19e..f63acae1c1137 100644 --- a/codersdk/toolsdk/toolsdk.go +++ b/codersdk/toolsdk/toolsdk.go @@ -1380,6 +1380,8 @@ type WorkspaceReadFileResponse struct { MimeType string `json:"mimeType"` } +const maxFileLimit = 1 << 20 // 1MiB + var WorkspaceReadFile = Tool[WorkspaceReadFileArgs, WorkspaceReadFileResponse]{ Tool: aisdk.Tool{ Name: ToolNameWorkspaceReadFile, @@ -1414,12 +1416,28 @@ var WorkspaceReadFile = Tool[WorkspaceReadFileArgs, WorkspaceReadFileResponse]{ } defer conn.Close() - bytes, mimeType, err := conn.ReadFile(ctx, args.Path, args.Offset, args.Limit) + // Ideally we could stream this all the way back, but it looks like the MCP + // interfaces only allow returning full responses which means the whole + // thing has to be read into memory. So, add a maximum limit to compensate. 
+ limit := args.Limit + if limit == 0 { + limit = maxFileLimit + } else if limit > maxFileLimit { + return WorkspaceReadFileResponse{}, xerrors.Errorf("limit must be %d or less, got %d", maxFileLimit, limit) + } + + reader, mimeType, err := conn.ReadFile(ctx, args.Path, args.Offset, limit) if err != nil { return WorkspaceReadFileResponse{}, err } + defer reader.Close() + + bs, err := io.ReadAll(reader) + if err != nil { + return WorkspaceReadFileResponse{}, xerrors.Errorf("read response body: %w", err) + } - return WorkspaceReadFileResponse{Content: bytes, MimeType: mimeType}, nil + return WorkspaceReadFileResponse{Content: bs, MimeType: mimeType}, nil }, } diff --git a/codersdk/toolsdk/toolsdk_test.go b/codersdk/toolsdk/toolsdk_test.go index 45ca8db8e20ff..fd128c198e126 100644 --- a/codersdk/toolsdk/toolsdk_test.go +++ b/codersdk/toolsdk/toolsdk_test.go @@ -471,6 +471,12 @@ func TestTools(t *testing.T) { err = afero.WriteFile(fs, filePath, []byte("content"), 0o644) require.NoError(t, err) + largeFilePath := filepath.Join(tmpdir, "large") + largeFile, err := fs.Create(largeFilePath) + require.NoError(t, err) + err = largeFile.Truncate(1 << 21) + require.NoError(t, err) + imagePath := filepath.Join(tmpdir, "file.png") err = afero.WriteFile(fs, imagePath, []byte("not really an image"), 0o644) require.NoError(t, err) @@ -482,6 +488,7 @@ func TestTools(t *testing.T) { offset int64 mimeType string bytes []byte + length int error string }{ { @@ -504,7 +511,13 @@ func TestTools(t *testing.T) { mimeType: "application/octet-stream", }, { - name: "MaxLimit", + name: "DefaultMaxLimit", + path: largeFilePath, + length: 1 << 20, + mimeType: "application/octet-stream", + }, + { + name: "ExceedMaxLimit", path: filePath, limit: 1 << 21, error: "limit must be 1048576 or less, got 2097152", @@ -532,7 +545,12 @@ func TestTools(t *testing.T) { require.Contains(t, err.Error(), tt.error) } else { require.NoError(t, err) - require.Equal(t, tt.bytes, resp.Content) + if tt.length != 0 { + 
require.Len(t, resp.Content, tt.length) + } + if tt.bytes != nil { + require.Equal(t, tt.bytes, resp.Content) + } require.Equal(t, tt.mimeType, resp.MimeType) } }) diff --git a/codersdk/workspacesdk/agentconn.go b/codersdk/workspacesdk/agentconn.go index e18080a1881ca..7efdb06520ab0 100644 --- a/codersdk/workspacesdk/agentconn.go +++ b/codersdk/workspacesdk/agentconn.go @@ -60,7 +60,7 @@ type AgentConn interface { PrometheusMetrics(ctx context.Context) ([]byte, error) ReconnectingPTY(ctx context.Context, id uuid.UUID, height uint16, width uint16, command string, initOpts ...AgentReconnectingPTYInitOption) (net.Conn, error) RecreateDevcontainer(ctx context.Context, devcontainerID string) (codersdk.Response, error) - ReadFile(ctx context.Context, path string, offset, limit int64) ([]byte, string, error) + ReadFile(ctx context.Context, path string, offset, limit int64) (io.ReadCloser, string, error) SSH(ctx context.Context) (*gonet.TCPConn, error) SSHClient(ctx context.Context) (*ssh.Client, error) SSHClientOnPort(ctx context.Context, port uint16) (*ssh.Client, error) @@ -477,43 +477,28 @@ func (c *agentConn) RecreateDevcontainer(ctx context.Context, devcontainerID str return m, nil } -const maxFileLimit = 1 << 20 // 1MiB - -// ReadFile reads from a file from the workspace, returning the file's -// (potentially partial) bytes and the mime type. -func (c *agentConn) ReadFile(ctx context.Context, path string, offset, limit int64) ([]byte, string, error) { +// ReadFile reads from a file from the workspace, returning a file reader and +// the mime type. +func (c *agentConn) ReadFile(ctx context.Context, path string, offset, limit int64) (io.ReadCloser, string, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() - // Ideally we could stream this all the way back, but it looks like the MCP - // interfaces only allow returning full responses which means the whole thing - // has to be read into memory. So, add a maximum to compensate. 
- if limit == 0 { - limit = maxFileLimit - } else if limit > maxFileLimit { - return nil, "", xerrors.Errorf("limit must be %d or less, got %d", maxFileLimit, limit) - } - + //nolint:bodyclose // we want to return the body so the caller can stream. res, err := c.apiRequest(ctx, http.MethodGet, fmt.Sprintf("/api/v0/read-file?path=%s&offset=%d&limit=%d", path, offset, limit), nil) if err != nil { return nil, "", xerrors.Errorf("do request: %w", err) } - defer res.Body.Close() if res.StatusCode != http.StatusOK { + // codersdk.ReadBodyAsError will close the body. return nil, "", codersdk.ReadBodyAsError(res) } - bs, err := io.ReadAll(res.Body) - if err != nil { - return nil, "", xerrors.Errorf("read response body: %w", err) - } - mimeType := res.Header.Get("Content-Type") if mimeType == "" { mimeType = "application/octet-stream" } - return bs, mimeType, nil + return res.Body, mimeType, nil } // apiRequest makes a request to the workspace agent's HTTP API server. diff --git a/codersdk/workspacesdk/agentconnmock/agentconnmock.go b/codersdk/workspacesdk/agentconnmock/agentconnmock.go index 18b16e03c0d71..6f93fb6e85ce1 100644 --- a/codersdk/workspacesdk/agentconnmock/agentconnmock.go +++ b/codersdk/workspacesdk/agentconnmock/agentconnmock.go @@ -233,10 +233,10 @@ func (mr *MockAgentConnMockRecorder) PrometheusMetrics(ctx any) *gomock.Call { } // ReadFile mocks base method. -func (m *MockAgentConn) ReadFile(ctx context.Context, path string, offset, limit int64) ([]byte, string, error) { +func (m *MockAgentConn) ReadFile(ctx context.Context, path string, offset, limit int64) (io.ReadCloser, string, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ReadFile", ctx, path, offset, limit) - ret0, _ := ret[0].([]byte) + ret0, _ := ret[0].(io.ReadCloser) ret1, _ := ret[1].(string) ret2, _ := ret[2].(error) return ret0, ret1, ret2