Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions agent/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ func (a *agent) apiHandler() http.Handler {
r.Get("/api/v0/listening-ports", lp.handler)
r.Get("/api/v0/netcheck", a.HandleNetcheck)
r.Post("/api/v0/list-directory", a.HandleLS)
r.Get("/api/v0/read-file", a.HandleReadFile)
r.Get("/debug/logs", a.HandleHTTPDebugLogs)
r.Get("/debug/magicsock", a.HandleHTTPDebugMagicsock)
r.Get("/debug/magicsock/debug-logging/{state}", a.HandleHTTPMagicsockDebugLoggingState)
Expand Down
96 changes: 96 additions & 0 deletions agent/files.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package agent

import (
"context"
"errors"
"io"
"mime"
"net/http"
"os"
"path/filepath"
"strconv"

"golang.org/x/xerrors"

"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/codersdk"
)

func (a *agent) HandleReadFile(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()

query := r.URL.Query()
parser := httpapi.NewQueryParamParser().RequiredNotEmpty("path")
path := parser.String(query, "", "path")
offset := parser.PositiveInt64(query, 0, "offset")
limit := parser.PositiveInt64(query, 0, "limit")
parser.ErrorExcessParams(query)
if len(parser.Errors) > 0 {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Query parameters have invalid values.",
Validations: parser.Errors,
})
return
}

status, err := a.streamFile(ctx, rw, path, offset, limit)
if err != nil {
httpapi.Write(ctx, rw, status, codersdk.Response{
Message: err.Error(),
})
return
}
}

func (a *agent) streamFile(ctx context.Context, rw http.ResponseWriter, path string, offset, limit int64) (int, error) {
if !filepath.IsAbs(path) {
return http.StatusBadRequest, xerrors.Errorf("file path must be absolute: %q", path)
}

f, err := a.filesystem.Open(path)
if err != nil {
status := http.StatusInternalServerError
switch {
case errors.Is(err, os.ErrNotExist):
status = http.StatusNotFound
case errors.Is(err, os.ErrPermission):
status = http.StatusForbidden
}
return status, err
}
defer f.Close()

stat, err := f.Stat()
if err != nil {
return http.StatusInternalServerError, err
}

if stat.IsDir() {
return http.StatusBadRequest, xerrors.Errorf("open %s: not a file", path)
}

size := stat.Size()
if limit == 0 {
limit = size
}
bytesRemaining := max(size-offset, 0)
bytesToRead := min(bytesRemaining, limit)

// Relying on just the file name for the mime type for now.
mimeType := mime.TypeByExtension(filepath.Ext(path))
if mimeType == "" {
mimeType = "application/octet-stream"
}
rw.Header().Set("Content-Type", mimeType)
rw.Header().Set("Content-Length", strconv.FormatInt(bytesToRead, 10))
rw.WriteHeader(http.StatusOK)

reader := io.NewSectionReader(f, offset, bytesToRead)
_, err = io.Copy(rw, reader)
if err != nil && !errors.Is(err, io.EOF) && ctx.Err() == nil {
a.logger.Error(ctx, "workspace agent read file", slog.Error(err))
}

return 0, nil
}
216 changes: 216 additions & 0 deletions agent/files_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
package agent_test

import (
"context"
"io"
"net/http"
"os"
"path/filepath"
"testing"

"github.com/spf13/afero"
"github.com/stretchr/testify/require"

"github.com/coder/coder/v2/agent"
"github.com/coder/coder/v2/agent/agenttest"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/testutil"
)

type testFs struct {
afero.Fs
// intercept can return an error for testing when a call fails.
intercept func(call, file string) error
}

func newTestFs(base afero.Fs, intercept func(call, file string) error) *testFs {
return &testFs{
Fs: base,
intercept: intercept,
}
}

func (fs *testFs) Open(name string) (afero.File, error) {
if err := fs.intercept("open", name); err != nil {
return nil, err
}
return fs.Fs.Open(name)
}

func TestReadFile(t *testing.T) {
t.Parallel()

tmpdir := os.TempDir()
noPermsFilePath := filepath.Join(tmpdir, "no-perms")
//nolint:dogsled
conn, _, _, fs, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, opts *agent.Options) {
opts.Filesystem = newTestFs(opts.Filesystem, func(call, file string) error {
if file == noPermsFilePath {
return os.ErrPermission
}
return nil
})
})

dirPath := filepath.Join(tmpdir, "a-directory")
err := fs.MkdirAll(dirPath, 0o755)
require.NoError(t, err)

filePath := filepath.Join(tmpdir, "file")
err = afero.WriteFile(fs, filePath, []byte("content"), 0o644)
require.NoError(t, err)

imagePath := filepath.Join(tmpdir, "file.png")
err = afero.WriteFile(fs, imagePath, []byte("not really an image"), 0o644)
require.NoError(t, err)

tests := []struct {
name string
path string
limit int64
offset int64
bytes []byte
mimeType string
errCode int
error string
}{
{
name: "NoPath",
path: "",
errCode: http.StatusBadRequest,
error: "\"path\" is required",
},
{
name: "RelativePath",
path: "./relative",
errCode: http.StatusBadRequest,
error: "file path must be absolute",
},
{
name: "RelativePath",
path: "also-relative",
errCode: http.StatusBadRequest,
error: "file path must be absolute",
},
{
name: "NegativeLimit",
path: filePath,
limit: -10,
errCode: http.StatusBadRequest,
error: "value is negative",
},
{
name: "NegativeOffset",
path: filePath,
offset: -10,
errCode: http.StatusBadRequest,
error: "value is negative",
},
{
name: "NonExistent",
path: filepath.Join(tmpdir, "does-not-exist"),
errCode: http.StatusNotFound,
error: "file does not exist",
},
{
name: "IsDir",
path: dirPath,
errCode: http.StatusBadRequest,
error: "not a file",
},
{
name: "NoPermissions",
path: noPermsFilePath,
errCode: http.StatusForbidden,
error: "permission denied",
},
{
name: "Defaults",
path: filePath,
bytes: []byte("content"),
mimeType: "application/octet-stream",
},
{
name: "Limit1",
path: filePath,
limit: 1,
bytes: []byte("c"),
mimeType: "application/octet-stream",
},
{
name: "Offset1",
path: filePath,
offset: 1,
bytes: []byte("ontent"),
mimeType: "application/octet-stream",
},
{
name: "Limit1Offset2",
path: filePath,
limit: 1,
offset: 2,
bytes: []byte("n"),
mimeType: "application/octet-stream",
},
{
name: "Limit7Offset0",
path: filePath,
limit: 7,
offset: 0,
bytes: []byte("content"),
mimeType: "application/octet-stream",
},
{
name: "Limit100",
path: filePath,
limit: 100,
bytes: []byte("content"),
mimeType: "application/octet-stream",
},
{
name: "Offset7",
path: filePath,
offset: 7,
bytes: []byte{},
mimeType: "application/octet-stream",
},
{
name: "Offset100",
path: filePath,
offset: 100,
bytes: []byte{},
mimeType: "application/octet-stream",
},
{
name: "MimeTypePng",
path: imagePath,
bytes: []byte("not really an image"),
mimeType: "image/png",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()

reader, mimeType, err := conn.ReadFile(ctx, tt.path, tt.offset, tt.limit)
if tt.errCode != 0 {
require.Error(t, err)
cerr := coderdtest.SDKError(t, err)
require.Contains(t, cerr.Error(), tt.error)
require.Equal(t, tt.errCode, cerr.StatusCode())
} else {
require.NoError(t, err)
defer reader.Close()
bytes, err := io.ReadAll(reader)
require.NoError(t, err)
require.Equal(t, tt.bytes, bytes)
require.Equal(t, tt.mimeType, mimeType)
}
})
}
}
21 changes: 21 additions & 0 deletions coderd/httpapi/queryparams.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,27 @@ func (p *QueryParamParser) PositiveInt32(vals url.Values, def int32, queryParam
return v
}

// PositiveInt64 function checks if the given value is 64-bit and positive.
func (p *QueryParamParser) PositiveInt64(vals url.Values, def int64, queryParam string) int64 {
v, err := parseQueryParam(p, vals, func(v string) (int64, error) {
intValue, err := strconv.ParseInt(v, 10, 64)
if err != nil {
return 0, err
}
if intValue < 0 {
return 0, xerrors.Errorf("value is negative")
}
return intValue, nil
}, def, queryParam)
if err != nil {
p.Errors = append(p.Errors, codersdk.ValidationError{
Field: queryParam,
Detail: fmt.Sprintf("Query param %q must be a valid 64-bit positive integer: %s", queryParam, err.Error()),
})
}
return v
}

// NullableBoolean will return a null sql value if no input is provided.
// SQLc still uses sql.NullBool rather than the generic type. So converting from
// the generic type is required.
Expand Down
Loading
Loading