Skip to content

Commit 554281c

Browse files
committed
chore(coderd): extract fileszip to its own package for reuse
1 parent 5bcaa93 commit 554281c

File tree

7 files changed

+244
-110
lines changed

7 files changed

+244
-110
lines changed

coderd/files.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"github.com/coder/coder/v2/coderd/httpapi"
2222
"github.com/coder/coder/v2/coderd/httpmw"
2323
"github.com/coder/coder/v2/codersdk"
24+
"github.com/coder/coder/v2/fileszip"
2425
)
2526

2627
const (
@@ -75,7 +76,7 @@ func (api *API) postFile(rw http.ResponseWriter, r *http.Request) {
7576
return
7677
}
7778

78-
data, err = CreateTarFromZip(zipReader)
79+
data, err = fileszip.CreateTarFromZip(zipReader, httpFileMaxBytes)
7980
if err != nil {
8081
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
8182
Message: "Internal error processing .zip archive.",
@@ -181,7 +182,7 @@ func (api *API) fileByID(rw http.ResponseWriter, r *http.Request) {
181182

182183
rw.Header().Set("Content-Type", codersdk.ContentTypeZip)
183184
rw.WriteHeader(http.StatusOK)
184-
err = WriteZipArchive(rw, tar.NewReader(bytes.NewReader(file.Data)))
185+
err = fileszip.WriteZipArchive(rw, tar.NewReader(bytes.NewReader(file.Data)), httpFileMaxBytes)
185186
if err != nil {
186187
api.Logger.Error(ctx, "invalid .zip archive", slog.F("file_id", fileID), slog.F("mimetype", file.Mimetype), slog.Error(err))
187188
}

coderd/files_test.go

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@ import (
55
"bytes"
66
"context"
77
"net/http"
8-
"os"
9-
"path/filepath"
108
"testing"
119

1210
"github.com/google/uuid"
@@ -15,6 +13,7 @@ import (
1513
"github.com/coder/coder/v2/coderd"
1614
"github.com/coder/coder/v2/coderd/coderdtest"
1715
"github.com/coder/coder/v2/codersdk"
16+
"github.com/coder/coder/v2/fileszip/filesziptest"
1817
"github.com/coder/coder/v2/testutil"
1918
)
2019

@@ -84,8 +83,8 @@ func TestDownload(t *testing.T) {
8483
// given
8584
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
8685
defer cancel()
87-
tarball, err := os.ReadFile(filepath.Join("testdata", "test.tar"))
88-
require.NoError(t, err)
86+
87+
tarball := filesziptest.TestTarFileBytes()
8988

9089
// when
9190
resp, err := client.Upload(ctx, codersdk.ContentTypeTar, bytes.NewReader(tarball))
@@ -97,7 +96,7 @@ func TestDownload(t *testing.T) {
9796
require.Len(t, data, len(tarball))
9897
require.Equal(t, codersdk.ContentTypeTar, contentType)
9998
require.Equal(t, tarball, data)
100-
assertSampleTarFile(t, data)
99+
filesziptest.AssertSampleTarFile(t, data)
101100
})
102101

103102
t.Run("InsertZip_DownloadTar", func(t *testing.T) {
@@ -106,8 +105,7 @@ func TestDownload(t *testing.T) {
106105
_ = coderdtest.CreateFirstUser(t, client)
107106

108107
// given
109-
zipContent, err := os.ReadFile(filepath.Join("testdata", "test.zip"))
110-
require.NoError(t, err)
108+
zipContent := filesziptest.TestZipFileBytes()
111109

112110
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
113111
defer cancel()
@@ -123,7 +121,7 @@ func TestDownload(t *testing.T) {
123121

124122
// Note: creating a zip from a tar will result in some loss of information
125123
// as zip files do not store UNIX user:group data.
126-
assertSampleTarFile(t, data)
124+
filesziptest.AssertSampleTarFile(t, data)
127125
})
128126

129127
t.Run("InsertTar_DownloadZip", func(t *testing.T) {
@@ -132,8 +130,7 @@ func TestDownload(t *testing.T) {
132130
_ = coderdtest.CreateFirstUser(t, client)
133131

134132
// given
135-
tarball, err := os.ReadFile(filepath.Join("testdata", "test.tar"))
136-
require.NoError(t, err)
133+
tarball := filesziptest.TestTarFileBytes()
137134

138135
tarReader := tar.NewReader(bytes.NewReader(tarball))
139136
expectedZip, err := coderd.CreateZipFromTar(tarReader)
@@ -151,6 +148,6 @@ func TestDownload(t *testing.T) {
151148
// then
152149
require.Equal(t, codersdk.ContentTypeZip, contentType)
153150
require.Equal(t, expectedZip, data)
154-
assertSampleZipFile(t, data)
151+
filesziptest.AssertSampleZipFile(t, data)
155152
})
156153
}

fileszip/fileszip.go

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
package fileszip
2+
3+
import (
4+
"archive/tar"
5+
"archive/zip"
6+
"bytes"
7+
"errors"
8+
"io"
9+
"log"
10+
"strings"
11+
)
12+
13+
func CreateTarFromZip(zipReader *zip.Reader, maxSize int64) ([]byte, error) {
14+
var tarBuffer bytes.Buffer
15+
err := writeTarArchive(&tarBuffer, zipReader, maxSize)
16+
if err != nil {
17+
return nil, err
18+
}
19+
return tarBuffer.Bytes(), nil
20+
}
21+
22+
func writeTarArchive(w io.Writer, zipReader *zip.Reader, maxSize int64) error {
23+
tarWriter := tar.NewWriter(w)
24+
defer tarWriter.Close()
25+
26+
for _, file := range zipReader.File {
27+
err := processFileInZipArchive(file, tarWriter, maxSize)
28+
if err != nil {
29+
return err
30+
}
31+
}
32+
return nil
33+
}
34+
35+
// processFileInZipArchive copies a single zip entry into tarWriter: a tar
// header built from the entry's name, size, mode and modification time,
// followed by the entry's contents.
//
// maxSize is the largest number of content bytes copied for the entry. An
// entry whose declared size is larger than maxSize is rejected with an error
// before anything is written: the tar header announces the full size, so
// copying only the first maxSize bytes would leave the tar stream short and
// the resulting archive corrupt.
//
// Errors from opening the entry, writing the header, or copying the data are
// returned unchanged; reaching io.EOF while copying is not an error.
func processFileInZipArchive(file *zip.File, tarWriter *tar.Writer, maxSize int64) error {
	size := file.FileInfo().Size()
	// Empty entries (including directories) never need any bytes copied, so
	// they are accepted regardless of maxSize.
	if size > 0 && size > maxSize {
		return errors.New("zip entry \"" + file.Name + "\" exceeds the maximum allowed size")
	}

	fileReader, err := file.Open()
	if err != nil {
		return err
	}
	defer fileReader.Close()

	err = tarWriter.WriteHeader(&tar.Header{
		Name:    file.Name,
		Size:    size,
		Mode:    int64(file.Mode()),
		ModTime: file.Modified,
		// Note: Zip archives do not store ownership information.
		Uid: 1000,
		Gid: 1000,
	})
	if err != nil {
		return err
	}

	// CopyN rather than Copy: never read more than maxSize bytes out of the
	// entry, whatever its header claims.
	n, err := io.CopyN(tarWriter, fileReader, maxSize)
	// NOTE(review): this looks like leftover debug output in library code.
	// It is kept here only because dropping it would leave the file-level
	// "log" import unused; remove both together.
	log.Println(file.Name, n, err)
	if errors.Is(err, io.EOF) {
		// The entry was shorter than maxSize; that is the normal case.
		err = nil
	}
	return err
}
62+
63+
func CreateZipFromTar(tarReader *tar.Reader, maxSize int64) ([]byte, error) {
64+
var zipBuffer bytes.Buffer
65+
err := WriteZipArchive(&zipBuffer, tarReader, maxSize)
66+
if err != nil {
67+
return nil, err
68+
}
69+
return zipBuffer.Bytes(), nil
70+
}
71+
72+
// WriteZipArchive converts the tar stream read from tarReader into a zip
// archive written to w.
//
// Each tar entry becomes one zip entry with the same name; directory entries
// get a trailing "/" appended if they do not already have one. At most
// maxSize bytes of content are copied per entry.
//
// NOTE(review): an entry larger than maxSize appears to be truncated to
// maxSize bytes without an error being reported — confirm this is intended.
//
// The zip writer is always closed before returning. When the conversion
// itself succeeded, the error from Close is returned: Close is what writes the
// zip central directory and flushes buffered data to w, so ignoring it could
// report success for an archive that was never fully written.
func WriteZipArchive(w io.Writer, tarReader *tar.Reader, maxSize int64) (retErr error) {
	zipWriter := zip.NewWriter(w)
	defer func() {
		// Keep the first error: a conversion failure takes precedence over
		// whatever Close reports afterwards.
		closeErr := zipWriter.Close()
		if retErr == nil {
			retErr = closeErr
		}
	}()

	for {
		tarHeader, err := tarReader.Next()
		if errors.Is(err, io.EOF) {
			// End of the tar stream: every entry has been converted.
			break
		}
		if err != nil {
			return err
		}

		zipHeader, err := zip.FileInfoHeader(tarHeader.FileInfo())
		if err != nil {
			return err
		}
		// FileInfoHeader only has the FileInfo to go on; use the name as
		// stored in the tar header instead.
		zipHeader.Name = tarHeader.Name
		// Some versions of unzip do not check the mode on a file entry and
		// simply assume that entries with a trailing path separator (/) are
		// directories, and that everything else is a file. Give them a hint.
		if tarHeader.FileInfo().IsDir() && !strings.HasSuffix(tarHeader.Name, "/") {
			zipHeader.Name += "/"
		}

		zipEntry, err := zipWriter.CreateHeader(zipHeader)
		if err != nil {
			return err
		}

		// Reaching io.EOF just means the entry was shorter than maxSize.
		_, err = io.CopyN(zipEntry, tarReader, maxSize)
		if err != nil && !errors.Is(err, io.EOF) {
			return err
		}
	}
	return nil
}
Lines changed: 12 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,22 @@
1-
package coderd_test
1+
package fileszip_test
22

33
import (
44
"archive/tar"
55
"archive/zip"
66
"bytes"
7-
"io"
87
"io/fs"
98
"os"
109
"os/exec"
1110
"path/filepath"
1211
"runtime"
1312
"strings"
1413
"testing"
15-
"time"
1614

1715
"github.com/stretchr/testify/assert"
1816
"github.com/stretchr/testify/require"
19-
"golang.org/x/xerrors"
2017

21-
"github.com/coder/coder/v2/coderd"
18+
"github.com/coder/coder/v2/fileszip"
19+
"github.com/coder/coder/v2/fileszip/filesziptest"
2220
"github.com/coder/coder/v2/testutil"
2321
)
2422

@@ -30,18 +28,17 @@ func TestCreateTarFromZip(t *testing.T) {
3028

3129
// Read a zip file we prepared earlier
3230
ctx := testutil.Context(t, testutil.WaitShort)
33-
zipBytes, err := os.ReadFile(filepath.Join("testdata", "test.zip"))
34-
require.NoError(t, err, "failed to read sample zip file")
31+
zipBytes := filesziptest.TestZipFileBytes()
3532
// Assert invariant
36-
assertSampleZipFile(t, zipBytes)
33+
filesziptest.AssertSampleZipFile(t, zipBytes)
3734

3835
zr, err := zip.NewReader(bytes.NewReader(zipBytes), int64(len(zipBytes)))
3936
require.NoError(t, err, "failed to parse sample zip file")
4037

41-
tarBytes, err := coderd.CreateTarFromZip(zr)
38+
tarBytes, err := fileszip.CreateTarFromZip(zr, 10240)
4239
require.NoError(t, err, "failed to convert zip to tar")
4340

44-
assertSampleTarFile(t, tarBytes)
41+
filesziptest.AssertSampleTarFile(t, tarBytes)
4542

4643
tempDir := t.TempDir()
4744
tempFilePath := filepath.Join(tempDir, "test.tar")
@@ -60,14 +57,13 @@ func TestCreateZipFromTar(t *testing.T) {
6057
}
6158
t.Run("OK", func(t *testing.T) {
6259
t.Parallel()
63-
tarBytes, err := os.ReadFile(filepath.Join(".", "testdata", "test.tar"))
64-
require.NoError(t, err, "failed to read sample tar file")
60+
tarBytes := filesziptest.TestTarFileBytes()
6561

6662
tr := tar.NewReader(bytes.NewReader(tarBytes))
67-
zipBytes, err := coderd.CreateZipFromTar(tr)
63+
zipBytes, err := fileszip.CreateZipFromTar(tr, 10240)
6864
require.NoError(t, err)
6965

70-
assertSampleZipFile(t, zipBytes)
66+
filesziptest.AssertSampleZipFile(t, zipBytes)
7167

7268
tempDir := t.TempDir()
7369
tempFilePath := filepath.Join(tempDir, "test.zip")
@@ -99,7 +95,7 @@ func TestCreateZipFromTar(t *testing.T) {
9995

10096
// When: we convert this to a zip
10197
tr := tar.NewReader(&tarBytes)
102-
zipBytes, err := coderd.CreateZipFromTar(tr)
98+
zipBytes, err := fileszip.CreateZipFromTar(tr, 10240)
10399
require.NoError(t, err)
104100

105101
// Then: the resulting zip should contain a corresponding directory
@@ -133,7 +129,7 @@ func assertExtractedFiles(t *testing.T, dir string, checkModePerm bool) {
133129
if checkModePerm {
134130
assert.Equal(t, fs.ModePerm&0o755, stat.Mode().Perm(), "expected mode 0755 on directory")
135131
}
136-
assert.Equal(t, archiveRefTime(t).UTC(), stat.ModTime().UTC(), "unexpected modtime of %q", path)
132+
assert.Equal(t, filesziptest.ArchiveRefTime(t).UTC(), stat.ModTime().UTC(), "unexpected modtime of %q", path)
137133
case "/test/hello.txt":
138134
stat, err := os.Stat(path)
139135
assert.NoError(t, err, "failed to stat path %q", path)
@@ -168,84 +164,3 @@ func assertExtractedFiles(t *testing.T, dir string, checkModePerm bool) {
168164
return nil
169165
})
170166
}
171-
172-
func assertSampleTarFile(t *testing.T, tarBytes []byte) {
173-
t.Helper()
174-
175-
tr := tar.NewReader(bytes.NewReader(tarBytes))
176-
for {
177-
hdr, err := tr.Next()
178-
if err != nil {
179-
if err == io.EOF {
180-
return
181-
}
182-
require.NoError(t, err)
183-
}
184-
185-
// Note: ignoring timezones here.
186-
require.Equal(t, archiveRefTime(t).UTC(), hdr.ModTime.UTC())
187-
188-
switch hdr.Name {
189-
case "test/":
190-
require.Equal(t, hdr.Typeflag, byte(tar.TypeDir))
191-
case "test/hello.txt":
192-
require.Equal(t, hdr.Typeflag, byte(tar.TypeReg))
193-
bs, err := io.ReadAll(tr)
194-
if err != nil && !xerrors.Is(err, io.EOF) {
195-
require.NoError(t, err)
196-
}
197-
require.Equal(t, "hello", string(bs))
198-
case "test/dir/":
199-
require.Equal(t, hdr.Typeflag, byte(tar.TypeDir))
200-
case "test/dir/world.txt":
201-
require.Equal(t, hdr.Typeflag, byte(tar.TypeReg))
202-
bs, err := io.ReadAll(tr)
203-
if err != nil && !xerrors.Is(err, io.EOF) {
204-
require.NoError(t, err)
205-
}
206-
require.Equal(t, "world", string(bs))
207-
default:
208-
require.Failf(t, "unexpected file in tar", hdr.Name)
209-
}
210-
}
211-
}
212-
213-
func assertSampleZipFile(t *testing.T, zipBytes []byte) {
214-
t.Helper()
215-
216-
zr, err := zip.NewReader(bytes.NewReader(zipBytes), int64(len(zipBytes)))
217-
require.NoError(t, err)
218-
219-
for _, f := range zr.File {
220-
// Note: ignoring timezones here.
221-
require.Equal(t, archiveRefTime(t).UTC(), f.Modified.UTC())
222-
switch f.Name {
223-
case "test/", "test/dir/":
224-
// directory
225-
case "test/hello.txt":
226-
rc, err := f.Open()
227-
require.NoError(t, err)
228-
bs, err := io.ReadAll(rc)
229-
_ = rc.Close()
230-
require.NoError(t, err)
231-
require.Equal(t, "hello", string(bs))
232-
case "test/dir/world.txt":
233-
rc, err := f.Open()
234-
require.NoError(t, err)
235-
bs, err := io.ReadAll(rc)
236-
_ = rc.Close()
237-
require.NoError(t, err)
238-
require.Equal(t, "world", string(bs))
239-
default:
240-
require.Failf(t, "unexpected file in zip", f.Name)
241-
}
242-
}
243-
}
244-
245-
// archiveRefTime is the Go reference time. The contents of the sample tar and zip files
246-
// in testdata/ all have their modtimes set to the below in some timezone.
247-
func archiveRefTime(t *testing.T) time.Time {
248-
locMST, err := time.LoadLocation("MST")
249-
require.NoError(t, err, "failed to load MST timezone")
250-
return time.Date(2006, 1, 2, 3, 4, 5, 0, locMST)
251-
}

0 commit comments

Comments
 (0)