diff --git a/coderd/util/xio/limitwriter.go b/coderd/util/xio/limitwriter.go new file mode 100644 index 0000000000000..8357d5d97a5ca --- /dev/null +++ b/coderd/util/xio/limitwriter.go @@ -0,0 +1,43 @@ +package xio + +import ( + "io" + + "golang.org/x/xerrors" +) + +var ErrLimitReached = xerrors.Errorf("i/o limit reached") + +// LimitWriter will only write bytes to the underlying writer until the limit is reached. +type LimitWriter struct { + Limit int64 + N int64 + W io.Writer +} + +func NewLimitWriter(w io.Writer, n int64) *LimitWriter { + // If anyone tries this, just make a 0 writer. + if n < 0 { + n = 0 + } + return &LimitWriter{ + Limit: n, + N: 0, + W: w, + } +} + +func (l *LimitWriter) Write(p []byte) (int, error) { + if l.N >= l.Limit { + return 0, ErrLimitReached + } + + // Write 0 bytes if the limit is to be exceeded. + if int64(len(p)) > l.Limit-l.N { + return 0, ErrLimitReached + } + + n, err := l.W.Write(p) + l.N += int64(n) + return n, err +} diff --git a/coderd/util/xio/limitwriter_test.go b/coderd/util/xio/limitwriter_test.go new file mode 100644 index 0000000000000..52d6075fbb7f3 --- /dev/null +++ b/coderd/util/xio/limitwriter_test.go @@ -0,0 +1,141 @@ +package xio_test + +import ( + "bytes" + cryptorand "crypto/rand" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/coderd/util/xio" +) + +func TestLimitWriter(t *testing.T) { + t.Parallel() + + type writeCase struct { + N int + ExpN int + Err bool + } + + // testCases will do multiple writes to the same limit writer and check the output. + testCases := []struct { + Name string + L int64 + Writes []writeCase + N int + ExpN int + }{ + { + Name: "Empty", + L: 1000, + Writes: []writeCase{ + // A few empty writes + {N: 0, ExpN: 0}, {N: 0, ExpN: 0}, {N: 0, ExpN: 0}, + }, + }, + { + Name: "NotFull", + L: 1000, + Writes: []writeCase{ + {N: 250, ExpN: 250}, + {N: 250, ExpN: 250}, + {N: 250, ExpN: 250}, + }, + }, + { + Name: "Short", + L: 1000, + Writes: []writeCase{ + {N: 250, ExpN: 250}, + {N: 250, ExpN: 250}, + {N: 250, ExpN: 250}, + {N: 250, ExpN: 250}, + {N: 250, ExpN: 0, Err: true}, + }, + }, + { + Name: "Exact", + L: 1000, + Writes: []writeCase{ + { + N: 1000, + ExpN: 1000, + }, + { + N: 1000, + Err: true, + }, + }, + }, + { + Name: "Over", + L: 1000, + Writes: []writeCase{ + { + N: 5000, + ExpN: 0, + Err: true, + }, + { + N: 5000, + Err: true, + }, + { + N: 5000, + Err: true, + }, + }, + }, + { + Name: "Strange", + L: -1, + Writes: []writeCase{ + { + N: 5, + ExpN: 0, + Err: true, + }, + { + N: 0, + ExpN: 0, + Err: true, + }, + }, + }, + } + + for _, c := range testCases { + c := c + t.Run(c.Name, func(t *testing.T) { + t.Parallel() + + buf := bytes.NewBuffer([]byte{}) + allBuff := bytes.NewBuffer([]byte{}) + w := xio.NewLimitWriter(buf, c.L) + + for _, wc := range c.Writes { + data := make([]byte, wc.N) + + n, err := cryptorand.Read(data) + require.NoError(t, err, "crand read") + require.Equal(t, wc.N, n, "correct bytes read") + max := data[:wc.ExpN] + n, err = w.Write(data) + if wc.Err { + require.Error(t, err, "exp error") + } else { + require.NoError(t, err, "write") + } + + // Need to use this to compare across multiple writes. + // Each write appends to the expected output. + allBuff.Write(max) + + require.Equal(t, wc.ExpN, n, "correct bytes written") + require.Equal(t, allBuff.Bytes(), buf.Bytes(), "expected data") + } + }) + } +} diff --git a/provisionersdk/archive.go b/provisionersdk/archive.go index ec496b6f31592..4642c82777645 100644 --- a/provisionersdk/archive.go +++ b/provisionersdk/archive.go @@ -8,6 +8,8 @@ import ( "strings" "golang.org/x/xerrors" + + "github.com/coder/coder/coderd/util/xio" ) const ( @@ -32,8 +34,9 @@ func dirHasExt(dir string, ext string) (bool, error) { // Tar archives a Terraform directory. func Tar(w io.Writer, directory string, limit int64) error { + // The total bytes written must be under the limit, so use -1 + w = xio.NewLimitWriter(w, limit-1) tarWriter := tar.NewWriter(w) - totalSize := int64(0) const tfExt = ".tf" hasTf, err := dirHasExt(directory, tfExt) @@ -95,22 +98,26 @@ func Tar(w io.Writer, directory string, limit int64) error { if !fileInfo.Mode().IsRegular() { return nil } + data, err := os.Open(file) if err != nil { return err } defer data.Close() - wrote, err := io.Copy(tarWriter, data) + _, err = io.Copy(tarWriter, data) if err != nil { + if xerrors.Is(err, xio.ErrLimitReached) { + return xerrors.Errorf("Archive too big. Must be <= %d bytes", limit) + } return err } - totalSize += wrote - if limit != 0 && totalSize >= limit { - return xerrors.Errorf("Archive too big. Must be <= %d bytes", limit) - } + return data.Close() }) if err != nil { + if xerrors.Is(err, xio.ErrLimitReached) { + return xerrors.Errorf("Archive too big. Must be <= %d bytes", limit) + } return err } err = tarWriter.Flush() diff --git a/provisionersdk/archive_test.go b/provisionersdk/archive_test.go index 66fae25dd9832..947d9c48c30ab 100644 --- a/provisionersdk/archive_test.go +++ b/provisionersdk/archive_test.go @@ -15,6 +15,32 @@ import ( func TestTar(t *testing.T) { t.Parallel() + t.Run("HeaderBreakLimit", func(t *testing.T) { + t.Parallel() + dir := t.TempDir() + file, err := os.CreateTemp(dir, "*.tf") + require.NoError(t, err) + _ = file.Close() + // A header is 512 bytes + err = provisionersdk.Tar(io.Discard, dir, 100) + require.Error(t, err) + }) + t.Run("HeaderAndContent", func(t *testing.T) { + t.Parallel() + dir := t.TempDir() + file, err := os.CreateTemp(dir, "*.tf") + require.NoError(t, err) + _, _ = file.Write(make([]byte, 100)) + _ = file.Close() + // Pay + header is 1024 bytes (padding) + err = provisionersdk.Tar(io.Discard, dir, 1025) + require.NoError(t, err) + + // Limit is 1 byte too small (n == limit is a failure, must be under) + err = provisionersdk.Tar(io.Discard, dir, 1024) + require.Error(t, err) + }) + t.Run("NoTF", func(t *testing.T) { t.Parallel() dir := t.TempDir() @@ -97,7 +123,8 @@ func TestTar(t *testing.T) { } } archive := new(bytes.Buffer) - err := provisionersdk.Tar(archive, dir, 1024) + // Headers are chonky so raise the limit to something reasonable + err := provisionersdk.Tar(archive, dir, 1024<<2) require.NoError(t, err) dir = t.TempDir() err = provisionersdk.Untar(dir, archive)