diff --git a/agent/api.go b/agent/api.go index ca0760e130ffe..809a62bedf4b9 100644 --- a/agent/api.go +++ b/agent/api.go @@ -60,6 +60,7 @@ func (a *agent) apiHandler() http.Handler { r.Get("/api/v0/listening-ports", lp.handler) r.Get("/api/v0/netcheck", a.HandleNetcheck) r.Post("/api/v0/list-directory", a.HandleLS) + r.Get("/api/v0/read-file", a.HandleReadFile) r.Get("/debug/logs", a.HandleHTTPDebugLogs) r.Get("/debug/magicsock", a.HandleHTTPDebugMagicsock) r.Get("/debug/magicsock/debug-logging/{state}", a.HandleHTTPMagicsockDebugLoggingState) diff --git a/agent/files.go b/agent/files.go new file mode 100644 index 0000000000000..0f5db82058d00 --- /dev/null +++ b/agent/files.go @@ -0,0 +1,96 @@ +package agent + +import ( + "context" + "errors" + "io" + "mime" + "net/http" + "os" + "path/filepath" + "strconv" + + "golang.org/x/xerrors" + + "cdr.dev/slog" + "github.com/coder/coder/v2/coderd/httpapi" + "github.com/coder/coder/v2/codersdk" +) + +func (a *agent) HandleReadFile(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + query := r.URL.Query() + parser := httpapi.NewQueryParamParser().RequiredNotEmpty("path") + path := parser.String(query, "", "path") + offset := parser.PositiveInt64(query, 0, "offset") + limit := parser.PositiveInt64(query, 0, "limit") + parser.ErrorExcessParams(query) + if len(parser.Errors) > 0 { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Query parameters have invalid values.", + Validations: parser.Errors, + }) + return + } + + status, err := a.streamFile(ctx, rw, path, offset, limit) + if err != nil { + httpapi.Write(ctx, rw, status, codersdk.Response{ + Message: err.Error(), + }) + return + } +} + +func (a *agent) streamFile(ctx context.Context, rw http.ResponseWriter, path string, offset, limit int64) (int, error) { + if !filepath.IsAbs(path) { + return http.StatusBadRequest, xerrors.Errorf("file path must be absolute: %q", path) + } + + f, err := a.filesystem.Open(path) + if err != nil { + status := http.StatusInternalServerError + switch { + case errors.Is(err, os.ErrNotExist): + status = http.StatusNotFound + case errors.Is(err, os.ErrPermission): + status = http.StatusForbidden + } + return status, err + } + defer f.Close() + + stat, err := f.Stat() + if err != nil { + return http.StatusInternalServerError, err + } + + if stat.IsDir() { + return http.StatusBadRequest, xerrors.Errorf("open %s: not a file", path) + } + + size := stat.Size() + if limit == 0 { + limit = size + } + bytesRemaining := max(size-offset, 0) + bytesToRead := min(bytesRemaining, limit) + + // Relying on just the file name for the mime type for now. + mimeType := mime.TypeByExtension(filepath.Ext(path)) + if mimeType == "" { + mimeType = "application/octet-stream" + } + rw.Header().Set("Content-Type", mimeType) + rw.Header().Set("Content-Length", strconv.FormatInt(bytesToRead, 10)) + rw.WriteHeader(http.StatusOK) + + reader := io.NewSectionReader(f, offset, bytesToRead) + _, err = io.Copy(rw, reader) + if err != nil && !errors.Is(err, io.EOF) && ctx.Err() == nil { + a.logger.Error(ctx, "workspace agent read file", slog.Error(err)) + } + + return 0, nil +} diff --git a/agent/files_test.go b/agent/files_test.go new file mode 100644 index 0000000000000..80aeadd518cf0 --- /dev/null +++ b/agent/files_test.go @@ -0,0 +1,216 @@ +package agent_test + +import ( + "context" + "io" + "net/http" + "os" + "path/filepath" + "testing" + + "github.com/spf13/afero" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/agent" + "github.com/coder/coder/v2/agent/agenttest" + "github.com/coder/coder/v2/coderd/coderdtest" + "github.com/coder/coder/v2/codersdk/agentsdk" + "github.com/coder/coder/v2/testutil" +) + +type testFs struct { + afero.Fs + // intercept can return an error for testing when a call fails. + intercept func(call, file string) error +} + +func newTestFs(base afero.Fs, intercept func(call, file string) error) *testFs { + return &testFs{ + Fs: base, + intercept: intercept, + } +} + +func (fs *testFs) Open(name string) (afero.File, error) { + if err := fs.intercept("open", name); err != nil { + return nil, err + } + return fs.Fs.Open(name) +} + +func TestReadFile(t *testing.T) { + t.Parallel() + + tmpdir := os.TempDir() + noPermsFilePath := filepath.Join(tmpdir, "no-perms") + //nolint:dogsled + conn, _, _, fs, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, opts *agent.Options) { + opts.Filesystem = newTestFs(opts.Filesystem, func(call, file string) error { + if file == noPermsFilePath { + return os.ErrPermission + } + return nil + }) + }) + + dirPath := filepath.Join(tmpdir, "a-directory") + err := fs.MkdirAll(dirPath, 0o755) + require.NoError(t, err) + + filePath := filepath.Join(tmpdir, "file") + err = afero.WriteFile(fs, filePath, []byte("content"), 0o644) + require.NoError(t, err) + + imagePath := filepath.Join(tmpdir, "file.png") + err = afero.WriteFile(fs, imagePath, []byte("not really an image"), 0o644) + require.NoError(t, err) + + tests := []struct { + name string + path string + limit int64 + offset int64 + bytes []byte + mimeType string + errCode int + error string + }{ + { + name: "NoPath", + path: "", + errCode: http.StatusBadRequest, + error: "\"path\" is required", + }, + { + name: "RelativePath", + path: "./relative", + errCode: http.StatusBadRequest, + error: "file path must be absolute", + }, + { + name: "RelativePath", + path: "also-relative", + errCode: http.StatusBadRequest, + error: "file path must be absolute", + }, + { + name: "NegativeLimit", + path: filePath, + limit: -10, + errCode: http.StatusBadRequest, + error: "value is negative", + }, + { + name: "NegativeOffset", + path: filePath, + offset: -10, + errCode: http.StatusBadRequest, + error: "value is negative", + }, + { + name: "NonExistent", + path: filepath.Join(tmpdir, "does-not-exist"), + errCode: http.StatusNotFound, + error: "file does not exist", + }, + { + name: "IsDir", + path: dirPath, + errCode: http.StatusBadRequest, + error: "not a file", + }, + { + name: "NoPermissions", + path: noPermsFilePath, + errCode: http.StatusForbidden, + error: "permission denied", + }, + { + name: "Defaults", + path: filePath, + bytes: []byte("content"), + mimeType: "application/octet-stream", + }, + { + name: "Limit1", + path: filePath, + limit: 1, + bytes: []byte("c"), + mimeType: "application/octet-stream", + }, + { + name: "Offset1", + path: filePath, + offset: 1, + bytes: []byte("ontent"), + mimeType: "application/octet-stream", + }, + { + name: "Limit1Offset2", + path: filePath, + limit: 1, + offset: 2, + bytes: []byte("n"), + mimeType: "application/octet-stream", + }, + { + name: "Limit7Offset0", + path: filePath, + limit: 7, + offset: 0, + bytes: []byte("content"), + mimeType: "application/octet-stream", + }, + { + name: "Limit100", + path: filePath, + limit: 100, + bytes: []byte("content"), + mimeType: "application/octet-stream", + }, + { + name: "Offset7", + path: filePath, + offset: 7, + bytes: []byte{}, + mimeType: "application/octet-stream", + }, + { + name: "Offset100", + path: filePath, + offset: 100, + bytes: []byte{}, + mimeType: "application/octet-stream", + }, + { + name: "MimeTypePng", + path: imagePath, + bytes: []byte("not really an image"), + mimeType: "image/png", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + reader, mimeType, err := conn.ReadFile(ctx, tt.path, tt.offset, tt.limit) + if tt.errCode != 0 { + require.Error(t, err) + cerr := coderdtest.SDKError(t, err) + require.Contains(t, cerr.Error(), tt.error) + require.Equal(t, tt.errCode, cerr.StatusCode()) + } else { + require.NoError(t, err) + defer reader.Close() + bytes, err := io.ReadAll(reader) + require.NoError(t, err) + require.Equal(t, tt.bytes, bytes) + require.Equal(t, tt.mimeType, mimeType) + } + }) + } +} diff --git a/coderd/httpapi/queryparams.go b/coderd/httpapi/queryparams.go index e1bd983ea12a3..d30244eaf04cc 100644 --- a/coderd/httpapi/queryparams.go +++ b/coderd/httpapi/queryparams.go @@ -120,6 +120,27 @@ func (p *QueryParamParser) PositiveInt32(vals url.Values, def int32, queryParam return v } +// PositiveInt64 function checks if the given value is 64-bit and positive. +func (p *QueryParamParser) PositiveInt64(vals url.Values, def int64, queryParam string) int64 { + v, err := parseQueryParam(p, vals, func(v string) (int64, error) { + intValue, err := strconv.ParseInt(v, 10, 64) + if err != nil { + return 0, err + } + if intValue < 0 { + return 0, xerrors.Errorf("value is negative") + } + return intValue, nil + }, def, queryParam) + if err != nil { + p.Errors = append(p.Errors, codersdk.ValidationError{ + Field: queryParam, + Detail: fmt.Sprintf("Query param %q must be a valid 64-bit positive integer: %s", queryParam, err.Error()), + }) + } + return v +} + // NullableBoolean will return a null sql value if no input is provided. // SQLc still uses sql.NullBool rather than the generic type. So converting from // the generic type is required. diff --git a/codersdk/toolsdk/bash.go b/codersdk/toolsdk/bash.go index 037227337bfc9..7497363c2a54e 100644 --- a/codersdk/toolsdk/bash.go +++ b/codersdk/toolsdk/bash.go @@ -17,7 +17,6 @@ import ( "github.com/coder/coder/v2/cli/cliui" "github.com/coder/coder/v2/codersdk" - "github.com/coder/coder/v2/codersdk/workspacesdk" ) type WorkspaceBashArgs struct { @@ -94,42 +93,12 @@ Examples: ctx, cancel := context.WithTimeoutCause(ctx, 5*time.Minute, xerrors.New("MCP handler timeout after 5 min")) defer cancel() - // Normalize workspace input to handle various formats - workspaceName := NormalizeWorkspaceInput(args.Workspace) - - // Find workspace and agent - _, workspaceAgent, err := findWorkspaceAndAgent(ctx, deps.coderClient, workspaceName) - if err != nil { - return WorkspaceBashResult{}, xerrors.Errorf("failed to find workspace: %w", err) - } - - // Wait for agent to be ready - if err := cliui.Agent(ctx, io.Discard, workspaceAgent.ID, cliui.AgentOptions{ - FetchInterval: 0, - Fetch: deps.coderClient.WorkspaceAgent, - FetchLogs: deps.coderClient.WorkspaceAgentLogsAfter, - Wait: true, // Always wait for startup scripts - }); err != nil { - return WorkspaceBashResult{}, xerrors.Errorf("agent not ready: %w", err) - } - - // Create workspace SDK client for agent connection - wsClient := workspacesdk.New(deps.coderClient) - - // Dial agent - conn, err := wsClient.DialAgent(ctx, workspaceAgent.ID, &workspacesdk.DialAgentOptions{ - BlockEndpoints: false, - }) + conn, err := newAgentConn(ctx, deps.coderClient, args.Workspace) if err != nil { - return WorkspaceBashResult{}, xerrors.Errorf("failed to dial agent: %w", err) + return WorkspaceBashResult{}, err } defer conn.Close() - // Wait for connection to be reachable - if !conn.AwaitReachable(ctx) { - return WorkspaceBashResult{}, xerrors.New("agent connection not reachable") - } - // Create SSH client sshClient, err := conn.SSHClient(ctx) if err != nil { @@ -323,32 +292,6 @@ func namedWorkspace(ctx context.Context, client *codersdk.Client, identifier str return client.WorkspaceByOwnerAndName(ctx, owner, workspaceName, codersdk.WorkspaceOptions{}) } -// NormalizeWorkspaceInput converts workspace name input to standard format. -// Handles the following input formats: -// - workspace → workspace -// - workspace.agent → workspace.agent -// - owner/workspace → owner/workspace -// - owner--workspace → owner/workspace -// - owner/workspace.agent → owner/workspace.agent -// - owner--workspace.agent → owner/workspace.agent -// - agent.workspace.owner → owner/workspace.agent (Coder Connect format) -func NormalizeWorkspaceInput(input string) string { - // Handle the special Coder Connect format: agent.workspace.owner - // This format uses only dots and has exactly 3 parts - if strings.Count(input, ".") == 2 && !strings.Contains(input, "/") && !strings.Contains(input, "--") { - parts := strings.Split(input, ".") - if len(parts) == 3 { - // Convert agent.workspace.owner → owner/workspace.agent - return fmt.Sprintf("%s/%s.%s", parts[2], parts[1], parts[0]) - } - } - - // Convert -- separator to / separator for consistency - normalized := strings.ReplaceAll(input, "--", "/") - - return normalized -} - // executeCommandWithTimeout executes a command with timeout support func executeCommandWithTimeout(ctx context.Context, session *gossh.Session, command string) ([]byte, error) { // Set up pipes to capture output diff --git a/codersdk/toolsdk/bash_test.go b/codersdk/toolsdk/bash_test.go index caf54109688ea..da05a71ce3eda 100644 --- a/codersdk/toolsdk/bash_test.go +++ b/codersdk/toolsdk/bash_test.go @@ -99,63 +99,6 @@ func TestWorkspaceBash(t *testing.T) { }) } -func TestNormalizeWorkspaceInput(t *testing.T) { - t.Parallel() - if runtime.GOOS == "windows" { - t.Skip("Skipping on Windows: Workspace MCP bash tools rely on a Unix-like shell (bash) and POSIX/SSH semantics. Use Linux/macOS or WSL for these tests.") - } - - testCases := []struct { - name string - input string - expected string - }{ - { - name: "SimpleWorkspace", - input: "workspace", - expected: "workspace", - }, - { - name: "WorkspaceWithAgent", - input: "workspace.agent", - expected: "workspace.agent", - }, - { - name: "OwnerAndWorkspace", - input: "owner/workspace", - expected: "owner/workspace", - }, - { - name: "OwnerDashWorkspace", - input: "owner--workspace", - expected: "owner/workspace", - }, - { - name: "OwnerWorkspaceAgent", - input: "owner/workspace.agent", - expected: "owner/workspace.agent", - }, - { - name: "OwnerDashWorkspaceAgent", - input: "owner--workspace.agent", - expected: "owner/workspace.agent", - }, - { - name: "CoderConnectFormat", - input: "agent.workspace.owner", // Special Coder Connect reverse format - expected: "owner/workspace.agent", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - result := toolsdk.NormalizeWorkspaceInput(tc.input) - require.Equal(t, tc.expected, result, "Input %q should normalize to %q but got %q", tc.input, tc.expected, result) - }) - } -} - func TestAllToolsIncludesBash(t *testing.T) { t.Parallel() if runtime.GOOS == "windows" { diff --git a/codersdk/toolsdk/toolsdk.go b/codersdk/toolsdk/toolsdk.go index 7cb8cecb25234..f63acae1c1137 100644 --- a/codersdk/toolsdk/toolsdk.go +++ b/codersdk/toolsdk/toolsdk.go @@ -5,8 +5,10 @@ import ( "bytes" "context" "encoding/json" + "fmt" "io" "runtime/debug" + "strings" "github.com/google/uuid" "golang.org/x/xerrors" @@ -14,7 +16,9 @@ import ( "github.com/coder/aisdk-go" "github.com/coder/coder/v2/buildinfo" + "github.com/coder/coder/v2/cli/cliui" "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/codersdk/workspacesdk" ) // Tool name constants to avoid hardcoded strings @@ -38,6 +42,7 @@ const ( ToolNameWorkspaceBash = "coder_workspace_bash" ToolNameChatGPTSearch = "search" ToolNameChatGPTFetch = "fetch" + ToolNameWorkspaceReadFile = "coder_workspace_read_file" ) func NewDeps(client *codersdk.Client, opts ...func(*Deps)) (Deps, error) { @@ -205,6 +210,7 @@ var All = []GenericTool{ WorkspaceBash.Generic(), ChatGPTSearch.Generic(), ChatGPTFetch.Generic(), + WorkspaceReadFile.Generic(), } type ReportTaskArgs struct { @@ -1360,3 +1366,138 @@ type MinimalTemplate struct { ActiveVersionID uuid.UUID `json:"active_version_id"` ActiveUserCount int `json:"active_user_count"` } + +type WorkspaceReadFileArgs struct { + Workspace string `json:"workspace"` + Path string `json:"path"` + Offset int64 `json:"offset"` + Limit int64 `json:"limit"` +} + +type WorkspaceReadFileResponse struct { + // Content is the base64-encoded bytes from the file. + Content []byte `json:"content"` + MimeType string `json:"mimeType"` +} + +const maxFileLimit = 1 << 20 // 1MiB + +var WorkspaceReadFile = Tool[WorkspaceReadFileArgs, WorkspaceReadFileResponse]{ + Tool: aisdk.Tool{ + Name: ToolNameWorkspaceReadFile, + Description: `Read from a file in a workspace.`, + Schema: aisdk.Schema{ + Properties: map[string]any{ + "workspace": map[string]any{ + "type": "string", + "description": "The workspace name in the format [owner/]workspace[.agent]. If an owner is not specified, the authenticated user is used.", + }, + "path": map[string]any{ + "type": "string", + "description": "The absolute path of the file to read in the workspace.", + }, + "offset": map[string]any{ + "type": "integer", + "description": "A byte offset indicating where in the file to start reading. Defaults to zero. An empty string indicates the end of the file has been reached.", + }, + "limit": map[string]any{ + "type": "integer", + "description": "The number of bytes to read. Cannot exceed 1 MiB. Defaults to the full size of the file or 1 MiB, whichever is lower.", + }, + }, + Required: []string{"path", "workspace"}, + }, + }, + UserClientOptional: true, + Handler: func(ctx context.Context, deps Deps, args WorkspaceReadFileArgs) (WorkspaceReadFileResponse, error) { + conn, err := newAgentConn(ctx, deps.coderClient, args.Workspace) + if err != nil { + return WorkspaceReadFileResponse{}, err + } + defer conn.Close() + + // Ideally we could stream this all the way back, but it looks like the MCP + // interfaces only allow returning full responses which means the whole + // thing has to be read into memory. So, add a maximum limit to compensate. + limit := args.Limit + if limit == 0 { + limit = maxFileLimit + } else if limit > maxFileLimit { + return WorkspaceReadFileResponse{}, xerrors.Errorf("limit must be %d or less, got %d", maxFileLimit, limit) + } + + reader, mimeType, err := conn.ReadFile(ctx, args.Path, args.Offset, limit) + if err != nil { + return WorkspaceReadFileResponse{}, err + } + defer reader.Close() + + bs, err := io.ReadAll(reader) + if err != nil { + return WorkspaceReadFileResponse{}, xerrors.Errorf("read response body: %w", err) + } + + return WorkspaceReadFileResponse{Content: bs, MimeType: mimeType}, nil + }, +} + +// NormalizeWorkspaceInput converts workspace name input to standard format. +// Handles the following input formats: +// - workspace → workspace +// - workspace.agent → workspace.agent +// - owner/workspace → owner/workspace +// - owner--workspace → owner/workspace +// - owner/workspace.agent → owner/workspace.agent +// - owner--workspace.agent → owner/workspace.agent +// - agent.workspace.owner → owner/workspace.agent (Coder Connect format) +func NormalizeWorkspaceInput(input string) string { + // Handle the special Coder Connect format: agent.workspace.owner + // This format uses only dots and has exactly 3 parts + if strings.Count(input, ".") == 2 && !strings.Contains(input, "/") && !strings.Contains(input, "--") { + parts := strings.Split(input, ".") + if len(parts) == 3 { + // Convert agent.workspace.owner → owner/workspace.agent + return fmt.Sprintf("%s/%s.%s", parts[2], parts[1], parts[0]) + } + } + + // Convert -- separator to / separator for consistency + normalized := strings.ReplaceAll(input, "--", "/") + + return normalized +} + +// newAgentConn returns a connection to the agent specified by the workspace, +// which must be in the format [owner/]workspace[.agent]. +func newAgentConn(ctx context.Context, client *codersdk.Client, workspace string) (workspacesdk.AgentConn, error) { + workspaceName := NormalizeWorkspaceInput(workspace) + _, workspaceAgent, err := findWorkspaceAndAgent(ctx, client, workspaceName) + if err != nil { + return nil, xerrors.Errorf("failed to find workspace: %w", err) + } + + // Wait for agent to be ready. + if err := cliui.Agent(ctx, io.Discard, workspaceAgent.ID, cliui.AgentOptions{ + FetchInterval: 0, + Fetch: client.WorkspaceAgent, + FetchLogs: client.WorkspaceAgentLogsAfter, + Wait: true, // Always wait for startup scripts + }); err != nil { + return nil, xerrors.Errorf("agent not ready: %w", err) + } + + wsClient := workspacesdk.New(client) + + conn, err := wsClient.DialAgent(ctx, workspaceAgent.ID, &workspacesdk.DialAgentOptions{ + BlockEndpoints: false, + }) + if err != nil { + return nil, xerrors.Errorf("failed to dial agent: %w", err) + } + + if !conn.AwaitReachable(ctx) { + conn.Close() + return nil, xerrors.New("agent connection not reachable") + } + return conn, nil +} diff --git a/codersdk/toolsdk/toolsdk_test.go b/codersdk/toolsdk/toolsdk_test.go index fb321e90e7dee..fd128c198e126 100644 --- a/codersdk/toolsdk/toolsdk_test.go +++ b/codersdk/toolsdk/toolsdk_test.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "os" + "path/filepath" "runtime" "sort" "sync" @@ -11,12 +12,14 @@ import ( "time" "github.com/google/uuid" + "github.com/spf13/afero" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/goleak" "github.com/coder/aisdk-go" + "github.com/coder/coder/v2/agent" "github.com/coder/coder/v2/agent/agenttest" "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/database" @@ -450,6 +453,109 @@ func TestTools(t *testing.T) { require.Equal(t, 0, result.ExitCode) require.Equal(t, "owner format works", result.Output) }) + + t.Run("WorkspaceReadFile", func(t *testing.T) { + t.Parallel() + + client, workspace, agentToken := setupWorkspaceForAgent(t) + fs := afero.NewMemMapFs() + _ = agenttest.New(t, client.URL, agentToken, func(opts *agent.Options) { + opts.Filesystem = fs + }) + coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).Wait() + tb, err := toolsdk.NewDeps(client) + require.NoError(t, err) + + tmpdir := os.TempDir() + filePath := filepath.Join(tmpdir, "file") + err = afero.WriteFile(fs, filePath, []byte("content"), 0o644) + require.NoError(t, err) + + largeFilePath := filepath.Join(tmpdir, "large") + largeFile, err := fs.Create(largeFilePath) + require.NoError(t, err) + err = largeFile.Truncate(1 << 21) + require.NoError(t, err) + + imagePath := filepath.Join(tmpdir, "file.png") + err = afero.WriteFile(fs, imagePath, []byte("not really an image"), 0o644) + require.NoError(t, err) + + tests := []struct { + name string + path string + limit int64 + offset int64 + mimeType string + bytes []byte + length int + error string + }{ + { + name: "NonExistent", + path: filepath.Join(tmpdir, "does-not-exist"), + error: "file does not exist", + }, + { + name: "Exists", + path: filePath, + bytes: []byte("content"), + mimeType: "application/octet-stream", + }, + { + name: "Limit1Offset2", + path: filePath, + limit: 1, + offset: 2, + bytes: []byte("n"), + mimeType: "application/octet-stream", + }, + { + name: "DefaultMaxLimit", + path: largeFilePath, + length: 1 << 20, + mimeType: "application/octet-stream", + }, + { + name: "ExceedMaxLimit", + path: filePath, + limit: 1 << 21, + error: "limit must be 1048576 or less, got 2097152", + }, + { + name: "ImageMimeType", + path: imagePath, + bytes: []byte("not really an image"), + mimeType: "image/png", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + resp, err := testTool(t, toolsdk.WorkspaceReadFile, tb, toolsdk.WorkspaceReadFileArgs{ + Workspace: workspace.Name, + Path: tt.path, + Limit: tt.limit, + Offset: tt.offset, + }) + if tt.error != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tt.error) + } else { + require.NoError(t, err) + if tt.length != 0 { + require.Len(t, resp.Content, tt.length) + } + if tt.bytes != nil { + require.Equal(t, tt.bytes, resp.Content) + } + require.Equal(t, tt.mimeType, resp.MimeType) + } + }) + } + }) } // TestedTools keeps track of which tools have been tested. @@ -748,3 +854,57 @@ func TestReportTaskWithReporter(t *testing.T) { // Verify response require.Equal(t, "Thanks for reporting!", result.Message) } + +func TestNormalizeWorkspaceInput(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + input string + expected string + }{ + { + name: "SimpleWorkspace", + input: "workspace", + expected: "workspace", + }, + { + name: "WorkspaceWithAgent", + input: "workspace.agent", + expected: "workspace.agent", + }, + { + name: "OwnerAndWorkspace", + input: "owner/workspace", + expected: "owner/workspace", + }, + { + name: "OwnerDashWorkspace", + input: "owner--workspace", + expected: "owner/workspace", + }, + { + name: "OwnerWorkspaceAgent", + input: "owner/workspace.agent", + expected: "owner/workspace.agent", + }, + { + name: "OwnerDashWorkspaceAgent", + input: "owner--workspace.agent", + expected: "owner/workspace.agent", + }, + { + name: "CoderConnectFormat", + input: "agent.workspace.owner", // Special Coder Connect reverse format + expected: "owner/workspace.agent", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + result := toolsdk.NormalizeWorkspaceInput(tc.input) + require.Equal(t, tc.expected, result, "Input %q should normalize to %q but got %q", tc.input, tc.expected, result) + }) + } +} diff --git a/codersdk/workspacesdk/agentconn.go b/codersdk/workspacesdk/agentconn.go index bb929c9ba2a04..7efdb06520ab0 100644 --- a/codersdk/workspacesdk/agentconn.go +++ b/codersdk/workspacesdk/agentconn.go @@ -60,6 +60,7 @@ type AgentConn interface { PrometheusMetrics(ctx context.Context) ([]byte, error) ReconnectingPTY(ctx context.Context, id uuid.UUID, height uint16, width uint16, command string, initOpts ...AgentReconnectingPTYInitOption) (net.Conn, error) RecreateDevcontainer(ctx context.Context, devcontainerID string) (codersdk.Response, error) + ReadFile(ctx context.Context, path string, offset, limit int64) (io.ReadCloser, string, error) SSH(ctx context.Context) (*gonet.TCPConn, error) SSHClient(ctx context.Context) (*ssh.Client, error) SSHClientOnPort(ctx context.Context, port uint16) (*ssh.Client, error) @@ -476,6 +477,30 @@ func (c *agentConn) RecreateDevcontainer(ctx context.Context, devcontainerID str return m, nil } +// ReadFile reads from a file from the workspace, returning a file reader and +// the mime type. +func (c *agentConn) ReadFile(ctx context.Context, path string, offset, limit int64) (io.ReadCloser, string, error) { + ctx, span := tracing.StartSpan(ctx) + defer span.End() + + //nolint:bodyclose // we want to return the body so the caller can stream. + res, err := c.apiRequest(ctx, http.MethodGet, fmt.Sprintf("/api/v0/read-file?path=%s&offset=%d&limit=%d", path, offset, limit), nil) + if err != nil { + return nil, "", xerrors.Errorf("do request: %w", err) + } + if res.StatusCode != http.StatusOK { + // codersdk.ReadBodyAsError will close the body. + return nil, "", codersdk.ReadBodyAsError(res) + } + + mimeType := res.Header.Get("Content-Type") + if mimeType == "" { + mimeType = "application/octet-stream" + } + + return res.Body, mimeType, nil +} + // apiRequest makes a request to the workspace agent's HTTP API server. func (c *agentConn) apiRequest(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) { ctx, span := tracing.StartSpan(ctx) diff --git a/codersdk/workspacesdk/agentconnmock/agentconnmock.go b/codersdk/workspacesdk/agentconnmock/agentconnmock.go index eb55bb27938c0..6f93fb6e85ce1 100644 --- a/codersdk/workspacesdk/agentconnmock/agentconnmock.go +++ b/codersdk/workspacesdk/agentconnmock/agentconnmock.go @@ -232,6 +232,22 @@ func (mr *MockAgentConnMockRecorder) PrometheusMetrics(ctx any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PrometheusMetrics", reflect.TypeOf((*MockAgentConn)(nil).PrometheusMetrics), ctx) } +// ReadFile mocks base method. +func (m *MockAgentConn) ReadFile(ctx context.Context, path string, offset, limit int64) (io.ReadCloser, string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadFile", ctx, path, offset, limit) + ret0, _ := ret[0].(io.ReadCloser) + ret1, _ := ret[1].(string) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// ReadFile indicates an expected call of ReadFile. +func (mr *MockAgentConnMockRecorder) ReadFile(ctx, path, offset, limit any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadFile", reflect.TypeOf((*MockAgentConn)(nil).ReadFile), ctx, path, offset, limit) +} + // ReconnectingPTY mocks base method. func (m *MockAgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, height, width uint16, command string, initOpts ...workspacesdk.AgentReconnectingPTYInitOption) (net.Conn, error) { m.ctrl.T.Helper()