diff --git a/coderd/fileszip.go b/archive/archive.go similarity index 69% rename from coderd/fileszip.go rename to archive/archive.go index 389e524746291..db78b8c700010 100644 --- a/coderd/fileszip.go +++ b/archive/archive.go @@ -1,4 +1,4 @@ -package coderd +package archive import ( "archive/tar" @@ -10,21 +10,22 @@ import ( "strings" ) -func CreateTarFromZip(zipReader *zip.Reader) ([]byte, error) { +// CreateTarFromZip converts the given zipReader to a tar archive. +func CreateTarFromZip(zipReader *zip.Reader, maxSize int64) ([]byte, error) { var tarBuffer bytes.Buffer - err := writeTarArchive(&tarBuffer, zipReader) + err := writeTarArchive(&tarBuffer, zipReader, maxSize) if err != nil { return nil, err } return tarBuffer.Bytes(), nil } -func writeTarArchive(w io.Writer, zipReader *zip.Reader) error { +func writeTarArchive(w io.Writer, zipReader *zip.Reader, maxSize int64) error { tarWriter := tar.NewWriter(w) defer tarWriter.Close() for _, file := range zipReader.File { - err := processFileInZipArchive(file, tarWriter) + err := processFileInZipArchive(file, tarWriter, maxSize) if err != nil { return err } @@ -32,7 +33,7 @@ func writeTarArchive(w io.Writer, zipReader *zip.Reader) error { return nil } -func processFileInZipArchive(file *zip.File, tarWriter *tar.Writer) error { +func processFileInZipArchive(file *zip.File, tarWriter *tar.Writer, maxSize int64) error { fileReader, err := file.Open() if err != nil { return err @@ -52,7 +53,7 @@ func processFileInZipArchive(file *zip.File, tarWriter *tar.Writer) error { return err } - n, err := io.CopyN(tarWriter, fileReader, httpFileMaxBytes) + n, err := io.CopyN(tarWriter, fileReader, maxSize) log.Println(file.Name, n, err) if errors.Is(err, io.EOF) { err = nil @@ -60,16 +61,18 @@ func processFileInZipArchive(file *zip.File, tarWriter *tar.Writer) error { return err } -func CreateZipFromTar(tarReader *tar.Reader) ([]byte, error) { +// CreateZipFromTar converts the given tarReader to a zip archive. +func CreateZipFromTar(tarReader *tar.Reader, maxSize int64) ([]byte, error) { var zipBuffer bytes.Buffer - err := WriteZipArchive(&zipBuffer, tarReader) + err := WriteZip(&zipBuffer, tarReader, maxSize) if err != nil { return nil, err } return zipBuffer.Bytes(), nil } -func WriteZipArchive(w io.Writer, tarReader *tar.Reader) error { +// WriteZip writes the given tarReader to w. +func WriteZip(w io.Writer, tarReader *tar.Reader, maxSize int64) error { zipWriter := zip.NewWriter(w) defer zipWriter.Close() @@ -100,7 +103,7 @@ func WriteZipArchive(w io.Writer, tarReader *tar.Reader) error { return err } - _, err = io.CopyN(zipEntry, tarReader, httpFileMaxBytes) + _, err = io.CopyN(zipEntry, tarReader, maxSize) if errors.Is(err, io.EOF) { err = nil } diff --git a/coderd/fileszip_test.go b/archive/archive_test.go similarity index 62% rename from coderd/fileszip_test.go rename to archive/archive_test.go index 1c3781c39d70b..c10d103622fa7 100644 --- a/coderd/fileszip_test.go +++ b/archive/archive_test.go @@ -1,10 +1,9 @@ -package coderd_test +package archive_test import ( "archive/tar" "archive/zip" "bytes" - "io" "io/fs" "os" "os/exec" @@ -12,13 +11,12 @@ import ( "runtime" "strings" "testing" - "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/xerrors" - "github.com/coder/coder/v2/coderd" + "github.com/coder/coder/v2/archive" + "github.com/coder/coder/v2/archive/archivetest" "github.com/coder/coder/v2/testutil" ) @@ -30,18 +28,17 @@ func TestCreateTarFromZip(t *testing.T) { // Read a zip file we prepared earlier ctx := testutil.Context(t, testutil.WaitShort) - zipBytes, err := os.ReadFile(filepath.Join("testdata", "test.zip")) - require.NoError(t, err, "failed to read sample zip file") + zipBytes := archivetest.TestZipFileBytes() // Assert invariant - assertSampleZipFile(t, zipBytes) + archivetest.AssertSampleZipFile(t, zipBytes) zr, err := zip.NewReader(bytes.NewReader(zipBytes), int64(len(zipBytes))) require.NoError(t, err, "failed to parse sample zip file") - tarBytes, err := coderd.CreateTarFromZip(zr) + tarBytes, err := archive.CreateTarFromZip(zr, int64(len(zipBytes))) require.NoError(t, err, "failed to convert zip to tar") - assertSampleTarFile(t, tarBytes) + archivetest.AssertSampleTarFile(t, tarBytes) tempDir := t.TempDir() tempFilePath := filepath.Join(tempDir, "test.tar") @@ -60,14 +57,13 @@ func TestCreateZipFromTar(t *testing.T) { } t.Run("OK", func(t *testing.T) { t.Parallel() - tarBytes, err := os.ReadFile(filepath.Join(".", "testdata", "test.tar")) - require.NoError(t, err, "failed to read sample tar file") + tarBytes := archivetest.TestTarFileBytes() tr := tar.NewReader(bytes.NewReader(tarBytes)) - zipBytes, err := coderd.CreateZipFromTar(tr) + zipBytes, err := archive.CreateZipFromTar(tr, int64(len(tarBytes))) require.NoError(t, err) - assertSampleZipFile(t, zipBytes) + archivetest.AssertSampleZipFile(t, zipBytes) tempDir := t.TempDir() tempFilePath := filepath.Join(tempDir, "test.zip") @@ -99,7 +95,7 @@ func TestCreateZipFromTar(t *testing.T) { // When: we convert this to a zip tr := tar.NewReader(&tarBytes) - zipBytes, err := coderd.CreateZipFromTar(tr) + zipBytes, err := archive.CreateZipFromTar(tr, int64(tarBytes.Len())) require.NoError(t, err) // Then: the resulting zip should contain a corresponding directory @@ -133,7 +129,7 @@ func assertExtractedFiles(t *testing.T, dir string, checkModePerm bool) { if checkModePerm { assert.Equal(t, fs.ModePerm&0o755, stat.Mode().Perm(), "expected mode 0755 on directory") } - assert.Equal(t, archiveRefTime(t).UTC(), stat.ModTime().UTC(), "unexpected modtime of %q", path) + assert.Equal(t, archivetest.ArchiveRefTime(t).UTC(), stat.ModTime().UTC(), "unexpected modtime of %q", path) case "/test/hello.txt": stat, err := os.Stat(path) assert.NoError(t, err, "failed to stat path %q", path) @@ -168,84 +164,3 @@ func assertExtractedFiles(t *testing.T, dir string, checkModePerm bool) { return nil }) } - -func assertSampleTarFile(t *testing.T, tarBytes []byte) { - t.Helper() - - tr := tar.NewReader(bytes.NewReader(tarBytes)) - for { - hdr, err := tr.Next() - if err != nil { - if err == io.EOF { - return - } - require.NoError(t, err) - } - - // Note: ignoring timezones here. - require.Equal(t, archiveRefTime(t).UTC(), hdr.ModTime.UTC()) - - switch hdr.Name { - case "test/": - require.Equal(t, hdr.Typeflag, byte(tar.TypeDir)) - case "test/hello.txt": - require.Equal(t, hdr.Typeflag, byte(tar.TypeReg)) - bs, err := io.ReadAll(tr) - if err != nil && !xerrors.Is(err, io.EOF) { - require.NoError(t, err) - } - require.Equal(t, "hello", string(bs)) - case "test/dir/": - require.Equal(t, hdr.Typeflag, byte(tar.TypeDir)) - case "test/dir/world.txt": - require.Equal(t, hdr.Typeflag, byte(tar.TypeReg)) - bs, err := io.ReadAll(tr) - if err != nil && !xerrors.Is(err, io.EOF) { - require.NoError(t, err) - } - require.Equal(t, "world", string(bs)) - default: - require.Failf(t, "unexpected file in tar", hdr.Name) - } - } -} - -func assertSampleZipFile(t *testing.T, zipBytes []byte) { - t.Helper() - - zr, err := zip.NewReader(bytes.NewReader(zipBytes), int64(len(zipBytes))) - require.NoError(t, err) - - for _, f := range zr.File { - // Note: ignoring timezones here. - require.Equal(t, archiveRefTime(t).UTC(), f.Modified.UTC()) - switch f.Name { - case "test/", "test/dir/": - // directory - case "test/hello.txt": - rc, err := f.Open() - require.NoError(t, err) - bs, err := io.ReadAll(rc) - _ = rc.Close() - require.NoError(t, err) - require.Equal(t, "hello", string(bs)) - case "test/dir/world.txt": - rc, err := f.Open() - require.NoError(t, err) - bs, err := io.ReadAll(rc) - _ = rc.Close() - require.NoError(t, err) - require.Equal(t, "world", string(bs)) - default: - require.Failf(t, "unexpected file in zip", f.Name) - } - } -} - -// archiveRefTime is the Go reference time. The contents of the sample tar and zip files -// in testdata/ all have their modtimes set to the below in some timezone. -func archiveRefTime(t *testing.T) time.Time { - locMST, err := time.LoadLocation("MST") - require.NoError(t, err, "failed to load MST timezone") - return time.Date(2006, 1, 2, 3, 4, 5, 0, locMST) -} diff --git a/archive/archivetest/archivetest.go b/archive/archivetest/archivetest.go new file mode 100644 index 0000000000000..2daa6fad4ae9b --- /dev/null +++ b/archive/archivetest/archivetest.go @@ -0,0 +1,113 @@ +package archivetest + +import ( + "archive/tar" + "archive/zip" + "bytes" + _ "embed" + "io" + "testing" + "time" + + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" +) + +//go:embed testdata/test.tar +var testTarFileBytes []byte + +//go:embed testdata/test.zip +var testZipFileBytes []byte + +// TestTarFileBytes returns the content of testdata/test.tar +func TestTarFileBytes() []byte { + return append([]byte{}, testTarFileBytes...) +} + +// TestZipFileBytes returns the content of testdata/test.zip +func TestZipFileBytes() []byte { + return append([]byte{}, testZipFileBytes...) +} + +// AssertSampleTarfile compares the content of tarBytes against testdata/test.tar. +func AssertSampleTarFile(t *testing.T, tarBytes []byte) { + t.Helper() + + tr := tar.NewReader(bytes.NewReader(tarBytes)) + for { + hdr, err := tr.Next() + if err != nil { + if err == io.EOF { + return + } + require.NoError(t, err) + } + + // Note: ignoring timezones here. + require.Equal(t, ArchiveRefTime(t).UTC(), hdr.ModTime.UTC()) + + switch hdr.Name { + case "test/": + require.Equal(t, hdr.Typeflag, byte(tar.TypeDir)) + case "test/hello.txt": + require.Equal(t, hdr.Typeflag, byte(tar.TypeReg)) + bs, err := io.ReadAll(tr) + if err != nil && !xerrors.Is(err, io.EOF) { + require.NoError(t, err) + } + require.Equal(t, "hello", string(bs)) + case "test/dir/": + require.Equal(t, hdr.Typeflag, byte(tar.TypeDir)) + case "test/dir/world.txt": + require.Equal(t, hdr.Typeflag, byte(tar.TypeReg)) + bs, err := io.ReadAll(tr) + if err != nil && !xerrors.Is(err, io.EOF) { + require.NoError(t, err) + } + require.Equal(t, "world", string(bs)) + default: + require.Failf(t, "unexpected file in tar", hdr.Name) + } + } +} + +// AssertSampleZipFile compares the content of zipBytes against testdata/test.zip. +func AssertSampleZipFile(t *testing.T, zipBytes []byte) { + t.Helper() + + zr, err := zip.NewReader(bytes.NewReader(zipBytes), int64(len(zipBytes))) + require.NoError(t, err) + + for _, f := range zr.File { + // Note: ignoring timezones here. + require.Equal(t, ArchiveRefTime(t).UTC(), f.Modified.UTC()) + switch f.Name { + case "test/", "test/dir/": + // directory + case "test/hello.txt": + rc, err := f.Open() + require.NoError(t, err) + bs, err := io.ReadAll(rc) + _ = rc.Close() + require.NoError(t, err) + require.Equal(t, "hello", string(bs)) + case "test/dir/world.txt": + rc, err := f.Open() + require.NoError(t, err) + bs, err := io.ReadAll(rc) + _ = rc.Close() + require.NoError(t, err) + require.Equal(t, "world", string(bs)) + default: + require.Failf(t, "unexpected file in zip", f.Name) + } + } +} + +// archiveRefTime is the Go reference time. The contents of the sample tar and zip files +// in testdata/ all have their modtimes set to the below in some timezone. +func ArchiveRefTime(t *testing.T) time.Time { + locMST, err := time.LoadLocation("MST") + require.NoError(t, err, "failed to load MST timezone") + return time.Date(2006, 1, 2, 3, 4, 5, 0, locMST) +} diff --git a/coderd/testdata/test.tar b/archive/archivetest/testdata/test.tar similarity index 100% rename from coderd/testdata/test.tar rename to archive/archivetest/testdata/test.tar diff --git a/coderd/testdata/test.zip b/archive/archivetest/testdata/test.zip similarity index 100% rename from coderd/testdata/test.zip rename to archive/archivetest/testdata/test.zip diff --git a/cli/templatepull_test.go b/cli/templatepull_test.go index da981f6ad658f..99f23d12923cd 100644 --- a/cli/templatepull_test.go +++ b/cli/templatepull_test.go @@ -13,6 +13,7 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/require" + "github.com/coder/coder/v2/archive" "github.com/coder/coder/v2/cli/clitest" "github.com/coder/coder/v2/coderd" "github.com/coder/coder/v2/coderd/coderdtest" @@ -95,7 +96,7 @@ func TestTemplatePull_Stdout(t *testing.T) { // Verify .zip format tarReader := tar.NewReader(bytes.NewReader(expected)) - expectedZip, err := coderd.CreateZipFromTar(tarReader) + expectedZip, err := archive.CreateZipFromTar(tarReader, coderd.HTTPFileMaxBytes) require.NoError(t, err) inv, root = clitest.New(t, "templates", "pull", "--zip", template.Name) diff --git a/coderd/files.go b/coderd/files.go index d16a3447a1d94..bf1885da1eee9 100644 --- a/coderd/files.go +++ b/coderd/files.go @@ -16,6 +16,7 @@ import ( "github.com/google/uuid" "cdr.dev/slog" + "github.com/coder/coder/v2/archive" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/httpapi" @@ -27,7 +28,7 @@ const ( tarMimeType = "application/x-tar" zipMimeType = "application/zip" - httpFileMaxBytes = 10 * (10 << 20) + HTTPFileMaxBytes = 10 * (10 << 20) ) // @Summary Upload file @@ -55,7 +56,7 @@ func (api *API) postFile(rw http.ResponseWriter, r *http.Request) { return } - r.Body = http.MaxBytesReader(rw, r.Body, httpFileMaxBytes) + r.Body = http.MaxBytesReader(rw, r.Body, HTTPFileMaxBytes) data, err := io.ReadAll(r.Body) if err != nil { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ @@ -75,7 +76,7 @@ func (api *API) postFile(rw http.ResponseWriter, r *http.Request) { return } - data, err = CreateTarFromZip(zipReader) + data, err = archive.CreateTarFromZip(zipReader, HTTPFileMaxBytes) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error processing .zip archive.", @@ -181,7 +182,7 @@ func (api *API) fileByID(rw http.ResponseWriter, r *http.Request) { rw.Header().Set("Content-Type", codersdk.ContentTypeZip) rw.WriteHeader(http.StatusOK) - err = WriteZipArchive(rw, tar.NewReader(bytes.NewReader(file.Data))) + err = archive.WriteZip(rw, tar.NewReader(bytes.NewReader(file.Data)), HTTPFileMaxBytes) if err != nil { api.Logger.Error(ctx, "invalid .zip archive", slog.F("file_id", fileID), slog.F("mimetype", file.Mimetype), slog.Error(err)) } diff --git a/coderd/files_test.go b/coderd/files_test.go index a5e6aab2498e1..f2dd788e3a6dd 100644 --- a/coderd/files_test.go +++ b/coderd/files_test.go @@ -5,14 +5,13 @@ import ( "bytes" "context" "net/http" - "os" - "path/filepath" "testing" "github.com/google/uuid" "github.com/stretchr/testify/require" - "github.com/coder/coder/v2/coderd" + "github.com/coder/coder/v2/archive" + "github.com/coder/coder/v2/archive/archivetest" "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/testutil" @@ -84,8 +83,8 @@ func TestDownload(t *testing.T) { // given ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - tarball, err := os.ReadFile(filepath.Join("testdata", "test.tar")) - require.NoError(t, err) + + tarball := archivetest.TestTarFileBytes() // when resp, err := client.Upload(ctx, codersdk.ContentTypeTar, bytes.NewReader(tarball)) @@ -97,7 +96,7 @@ func TestDownload(t *testing.T) { require.Len(t, data, len(tarball)) require.Equal(t, codersdk.ContentTypeTar, contentType) require.Equal(t, tarball, data) - assertSampleTarFile(t, data) + archivetest.AssertSampleTarFile(t, data) }) t.Run("InsertZip_DownloadTar", func(t *testing.T) { @@ -106,8 +105,7 @@ func TestDownload(t *testing.T) { _ = coderdtest.CreateFirstUser(t, client) // given - zipContent, err := os.ReadFile(filepath.Join("testdata", "test.zip")) - require.NoError(t, err) + zipContent := archivetest.TestZipFileBytes() ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() @@ -123,7 +121,7 @@ func TestDownload(t *testing.T) { // Note: creating a zip from a tar will result in some loss of information // as zip files do not store UNIX user:group data. - assertSampleTarFile(t, data) + archivetest.AssertSampleTarFile(t, data) }) t.Run("InsertTar_DownloadZip", func(t *testing.T) { @@ -132,11 +130,10 @@ func TestDownload(t *testing.T) { _ = coderdtest.CreateFirstUser(t, client) // given - tarball, err := os.ReadFile(filepath.Join("testdata", "test.tar")) - require.NoError(t, err) + tarball := archivetest.TestTarFileBytes() tarReader := tar.NewReader(bytes.NewReader(tarball)) - expectedZip, err := coderd.CreateZipFromTar(tarReader) + expectedZip, err := archive.CreateZipFromTar(tarReader, 10240) require.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) @@ -151,6 +148,6 @@ func TestDownload(t *testing.T) { // then require.Equal(t, codersdk.ContentTypeZip, contentType) require.Equal(t, expectedZip, data) - assertSampleZipFile(t, data) + archivetest.AssertSampleZipFile(t, data) }) }