Skip to content

fix(coderd): correctly handle tar dir entries with missing path separator #12479

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions coderd/fileszip.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"errors"
"io"
"log"
"strings"
)

func CreateTarFromZip(zipReader *zip.Reader) ([]byte, error) {
Expand Down Expand Up @@ -87,6 +88,12 @@ func WriteZipArchive(w io.Writer, tarReader *tar.Reader) error {
return err
}
zipHeader.Name = tarHeader.Name
// Some versions of unzip do not check the mode on a file entry and
// simply assume that entries with a trailing path separator (/) are
// directories, and that everything else is a file. Give them a hint.
if tarHeader.FileInfo().IsDir() && !strings.HasSuffix(tarHeader.Name, "/") {
zipHeader.Name += "/"
}

zipEntry, err := zipWriter.CreateHeader(zipHeader)
if err != nil {
Expand Down
66 changes: 52 additions & 14 deletions coderd/fileszip_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,26 +58,64 @@ func TestCreateZipFromTar(t *testing.T) {
if runtime.GOOS != "linux" {
t.Skip("skipping this test on non-Linux platform")
}
t.Run("OK", func(t *testing.T) {
t.Parallel()
tarBytes, err := os.ReadFile(filepath.Join(".", "testdata", "test.tar"))
require.NoError(t, err, "failed to read sample tar file")

tarBytes, err := os.ReadFile(filepath.Join("testdata", "test.tar"))
require.NoError(t, err, "failed to read sample tar file")
tr := tar.NewReader(bytes.NewReader(tarBytes))
zipBytes, err := coderd.CreateZipFromTar(tr)
require.NoError(t, err)

tr := tar.NewReader(bytes.NewReader(tarBytes))
zipBytes, err := coderd.CreateZipFromTar(tr)
require.NoError(t, err)
assertSampleZipFile(t, zipBytes)

assertSampleZipFile(t, zipBytes)
tempDir := t.TempDir()
tempFilePath := filepath.Join(tempDir, "test.zip")
err = os.WriteFile(tempFilePath, zipBytes, 0o600)
require.NoError(t, err, "failed to write converted zip file")

tempDir := t.TempDir()
tempFilePath := filepath.Join(tempDir, "test.zip")
err = os.WriteFile(tempFilePath, zipBytes, 0o600)
require.NoError(t, err, "failed to write converted zip file")
ctx := testutil.Context(t, testutil.WaitShort)
cmd := exec.CommandContext(ctx, "unzip", tempFilePath, "-d", tempDir)
require.NoError(t, cmd.Run(), "failed to extract converted zip file")

ctx := testutil.Context(t, testutil.WaitShort)
cmd := exec.CommandContext(ctx, "unzip", tempFilePath, "-d", tempDir)
require.NoError(t, cmd.Run(), "failed to extract converted zip file")
assertExtractedFiles(t, tempDir, false)
})

assertExtractedFiles(t, tempDir, false)
t.Run("MissingSlashInDirectoryHeader", func(t *testing.T) {
t.Parallel()

// Given: a tar archive containing a directory entry that has the directory
// mode bit set but the name is missing a trailing slash

var tarBytes bytes.Buffer
tw := tar.NewWriter(&tarBytes)
tw.WriteHeader(&tar.Header{
Name: "dir",
Typeflag: tar.TypeDir,
Size: 0,
})
require.NoError(t, tw.Flush())
require.NoError(t, tw.Close())

// When: we convert this to a zip
tr := tar.NewReader(&tarBytes)
zipBytes, err := coderd.CreateZipFromTar(tr)
require.NoError(t, err)

// Then: the resulting zip should contain a corresponding directory
zr, err := zip.NewReader(bytes.NewReader(zipBytes), int64(len(zipBytes)))
require.NoError(t, err)
for _, zf := range zr.File {
switch zf.Name {
case "dir":
require.Fail(t, "missing trailing slash in directory name")
case "dir/":
require.True(t, zf.Mode().IsDir(), "should be a directory")
default:
require.Fail(t, "unexpected file in archive")
}
}
})
}

// nolint:revive // this is a control flag but it's in a unit test
Expand Down
5 changes: 5 additions & 0 deletions provisionersdk/archive.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,11 @@ func Tar(w io.Writer, logger slog.Logger, directory string, limit int64) error {
}
// Use unix paths in the tar archive.
header.Name = filepath.ToSlash(rel)
// tar.FileInfoHeader() will do this, but filepath.Rel() calls filepath.Clean()
// which strips trailing path separators for directories.
if fileInfo.IsDir() {
header.Name += "/"
}
if err := tarWriter.WriteHeader(header); err != nil {
return err
}
Expand Down