Skip to content

Commit a042936

Browse files
committed
Add coder_workspace_read_file MCP tool
1 parent c5282ea commit a042936

File tree

8 files changed

+530
-0
lines changed

8 files changed

+530
-0
lines changed

agent/api.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ func (a *agent) apiHandler() http.Handler {
6060
r.Get("/api/v0/listening-ports", lp.handler)
6161
r.Get("/api/v0/netcheck", a.HandleNetcheck)
6262
r.Post("/api/v0/list-directory", a.HandleLS)
63+
r.Get("/api/v0/read-file", a.HandleReadFile)
6364
r.Get("/debug/logs", a.HandleHTTPDebugLogs)
6465
r.Get("/debug/magicsock", a.HandleHTTPDebugMagicsock)
6566
r.Get("/debug/magicsock/debug-logging/{state}", a.HandleHTTPMagicsockDebugLoggingState)

agent/files.go

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
package agent
2+
3+
import (
4+
"context"
5+
"errors"
6+
"io"
7+
"mime"
8+
"net/http"
9+
"os"
10+
"path/filepath"
11+
"strconv"
12+
13+
"golang.org/x/xerrors"
14+
15+
"cdr.dev/slog"
16+
"github.com/coder/coder/v2/coderd/httpapi"
17+
"github.com/coder/coder/v2/codersdk"
18+
)
19+
20+
func (a *agent) HandleReadFile(rw http.ResponseWriter, r *http.Request) {
21+
ctx := r.Context()
22+
23+
query := r.URL.Query()
24+
parser := httpapi.NewQueryParamParser().RequiredNotEmpty("path")
25+
path := parser.String(query, "", "path")
26+
offset := parser.PositiveInt64(query, 0, "offset")
27+
limit := parser.PositiveInt64(query, 0, "limit")
28+
parser.ErrorExcessParams(query)
29+
if len(parser.Errors) > 0 {
30+
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
31+
Message: "Query parameters have invalid values.",
32+
Validations: parser.Errors,
33+
})
34+
return
35+
}
36+
37+
status, err := a.streamFile(ctx, rw, path, offset, limit)
38+
if err != nil {
39+
httpapi.Write(ctx, rw, status, codersdk.Response{
40+
Message: err.Error(),
41+
})
42+
return
43+
}
44+
}
45+
46+
func (a *agent) streamFile(ctx context.Context, rw http.ResponseWriter, path string, offset, limit int64) (int, error) {
47+
if !filepath.IsAbs(path) {
48+
return http.StatusBadRequest, xerrors.Errorf("file path must be absolute: %q", path)
49+
}
50+
51+
f, err := a.filesystem.Open(path)
52+
if err != nil {
53+
status := http.StatusInternalServerError
54+
switch {
55+
case errors.Is(err, os.ErrNotExist):
56+
status = http.StatusNotFound
57+
case errors.Is(err, os.ErrPermission):
58+
status = http.StatusForbidden
59+
}
60+
return status, err
61+
}
62+
defer f.Close()
63+
64+
stat, err := f.Stat()
65+
if err != nil {
66+
return http.StatusInternalServerError, err
67+
}
68+
69+
if stat.IsDir() {
70+
return http.StatusBadRequest, xerrors.Errorf("open %s: not a file", path)
71+
}
72+
73+
size := stat.Size()
74+
if limit == 0 {
75+
limit = size
76+
}
77+
bytesRemaining := max(size-offset, 0)
78+
bytesToRead := min(bytesRemaining, limit)
79+
80+
// Relying on just the file name for the mime type for now.
81+
mimeType := mime.TypeByExtension(filepath.Ext(path))
82+
if mimeType == "" {
83+
mimeType = "application/octet-stream"
84+
}
85+
rw.Header().Set("Content-Type", mimeType)
86+
rw.Header().Set("Content-Length", strconv.FormatInt(bytesToRead, 10))
87+
rw.WriteHeader(http.StatusOK)
88+
89+
reader := io.NewSectionReader(f, offset, bytesToRead)
90+
_, err = io.Copy(rw, reader)
91+
if err != nil && !errors.Is(err, io.EOF) && ctx.Err() == nil {
92+
a.logger.Error(ctx, "workspace agent read file", slog.Error(err))
93+
}
94+
95+
return 0, nil
96+
}

agent/files_test.go

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
package agent_test
2+
3+
import (
4+
"context"
5+
"net/http"
6+
"os"
7+
"path/filepath"
8+
"testing"
9+
10+
"github.com/spf13/afero"
11+
"github.com/stretchr/testify/require"
12+
13+
"github.com/coder/coder/v2/agent"
14+
"github.com/coder/coder/v2/agent/agenttest"
15+
"github.com/coder/coder/v2/coderd/coderdtest"
16+
"github.com/coder/coder/v2/codersdk/agentsdk"
17+
"github.com/coder/coder/v2/testutil"
18+
)
19+
20+
type testFs struct {
21+
afero.Fs
22+
// intercept can return an error for testing when a call fails.
23+
intercept func(call, file string) error
24+
}
25+
26+
func newTestFs(base afero.Fs, intercept func(call, file string) error) *testFs {
27+
return &testFs{
28+
Fs: base,
29+
intercept: intercept,
30+
}
31+
}
32+
33+
func (fs *testFs) Open(name string) (afero.File, error) {
34+
if err := fs.intercept("open", name); err != nil {
35+
return nil, err
36+
}
37+
return fs.Fs.Open(name)
38+
}
39+
40+
func TestReadFile(t *testing.T) {
41+
t.Parallel()
42+
43+
tmpdir := os.TempDir()
44+
noPermsFilePath := filepath.Join(tmpdir, "no-perms")
45+
//nolint:dogsled
46+
conn, _, _, fs, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, opts *agent.Options) {
47+
opts.Filesystem = newTestFs(opts.Filesystem, func(call, file string) error {
48+
if file == noPermsFilePath {
49+
return os.ErrPermission
50+
}
51+
return nil
52+
})
53+
})
54+
55+
dirPath := filepath.Join(tmpdir, "a-directory")
56+
err := fs.MkdirAll(dirPath, 0o755)
57+
require.NoError(t, err)
58+
59+
filePath := filepath.Join(tmpdir, "file")
60+
err = afero.WriteFile(fs, filePath, []byte("content"), 0o644)
61+
require.NoError(t, err)
62+
63+
imagePath := filepath.Join(tmpdir, "file.png")
64+
err = afero.WriteFile(fs, imagePath, []byte("not really an image"), 0o644)
65+
require.NoError(t, err)
66+
67+
tests := []struct {
68+
name string
69+
path string
70+
limit int64
71+
offset int64
72+
bytes []byte
73+
mimeType string
74+
errCode int
75+
error string
76+
}{
77+
{
78+
name: "NoPath",
79+
path: "",
80+
errCode: http.StatusBadRequest,
81+
error: "\"path\" is required",
82+
},
83+
{
84+
name: "RelativePath",
85+
path: "./relative",
86+
errCode: http.StatusBadRequest,
87+
error: "file path must be absolute",
88+
},
89+
{
90+
name: "RelativePath",
91+
path: "also-relative",
92+
errCode: http.StatusBadRequest,
93+
error: "file path must be absolute",
94+
},
95+
{
96+
name: "NegativeLimit",
97+
path: filePath,
98+
limit: -10,
99+
errCode: http.StatusBadRequest,
100+
error: "value is negative",
101+
},
102+
{
103+
name: "NegativeOffset",
104+
path: filePath,
105+
offset: -10,
106+
errCode: http.StatusBadRequest,
107+
error: "value is negative",
108+
},
109+
{
110+
name: "NonExistent",
111+
path: filepath.Join(tmpdir, "does-not-exist"),
112+
errCode: http.StatusNotFound,
113+
error: "file does not exist",
114+
},
115+
{
116+
name: "IsDir",
117+
path: dirPath,
118+
errCode: http.StatusBadRequest,
119+
error: "not a file",
120+
},
121+
{
122+
name: "NoPermissions",
123+
path: noPermsFilePath,
124+
errCode: http.StatusForbidden,
125+
error: "permission denied",
126+
},
127+
{
128+
name: "Defaults",
129+
path: filePath,
130+
bytes: []byte("content"),
131+
},
132+
{
133+
name: "Limit1",
134+
path: filePath,
135+
limit: 1,
136+
bytes: []byte("c"),
137+
},
138+
{
139+
name: "Offset1",
140+
path: filePath,
141+
offset: 1,
142+
bytes: []byte("ontent"),
143+
},
144+
{
145+
name: "Limit1Offset2",
146+
path: filePath,
147+
limit: 1,
148+
offset: 2,
149+
bytes: []byte("n"),
150+
},
151+
{
152+
name: "Limit7Offset0",
153+
path: filePath,
154+
limit: 7,
155+
offset: 0,
156+
bytes: []byte("content"),
157+
},
158+
{
159+
name: "Limit100",
160+
path: filePath,
161+
limit: 100,
162+
bytes: []byte("content"),
163+
},
164+
{
165+
name: "Offset7",
166+
path: filePath,
167+
offset: 7,
168+
bytes: []byte{},
169+
},
170+
{
171+
name: "Offset100",
172+
path: filePath,
173+
offset: 100,
174+
bytes: []byte{},
175+
},
176+
{
177+
name: "MimeTypePng",
178+
path: imagePath,
179+
bytes: []byte("not really an image"),
180+
mimeType: "image/png",
181+
},
182+
}
183+
184+
for _, tt := range tests {
185+
t.Run(tt.name, func(t *testing.T) {
186+
t.Parallel()
187+
188+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
189+
defer cancel()
190+
191+
b, mimeType, err := conn.ReadFile(ctx, tt.path, tt.offset, tt.limit)
192+
if tt.errCode != 0 {
193+
require.Error(t, err)
194+
cerr := coderdtest.SDKError(t, err)
195+
require.Contains(t, cerr.Error(), tt.error)
196+
require.Equal(t, tt.errCode, cerr.StatusCode())
197+
} else {
198+
require.NoError(t, err)
199+
require.Equal(t, tt.bytes, b)
200+
expectedMimeType := tt.mimeType
201+
if expectedMimeType == "" {
202+
expectedMimeType = "application/octet-stream"
203+
}
204+
require.Equal(t, expectedMimeType, mimeType)
205+
}
206+
})
207+
}
208+
}

coderd/httpapi/queryparams.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,27 @@ func (p *QueryParamParser) PositiveInt32(vals url.Values, def int32, queryParam
120120
return v
121121
}
122122

123+
// PositiveInt64 function checks if the given value is 64-bit and positive.
124+
func (p *QueryParamParser) PositiveInt64(vals url.Values, def int64, queryParam string) int64 {
125+
v, err := parseQueryParam(p, vals, func(v string) (int64, error) {
126+
intValue, err := strconv.ParseInt(v, 10, 64)
127+
if err != nil {
128+
return 0, err
129+
}
130+
if intValue < 0 {
131+
return 0, xerrors.Errorf("value is negative")
132+
}
133+
return intValue, nil
134+
}, def, queryParam)
135+
if err != nil {
136+
p.Errors = append(p.Errors, codersdk.ValidationError{
137+
Field: queryParam,
138+
Detail: fmt.Sprintf("Query param %q must be a valid 64-bit positive integer: %s", queryParam, err.Error()),
139+
})
140+
}
141+
return v
142+
}
143+
123144
// NullableBoolean will return a null sql value if no input is provided.
124145
// SQLc still uses sql.NullBool rather than the generic type. So converting from
125146
// the generic type is required.

0 commit comments

Comments
 (0)