Skip to content

Commit bbecff2

Browse files
authored
feat: return better error if file size is too big to upload (#7775)
* feat: return better error if file size is too big to upload * Use a limit writer to capture actual tar size
1 parent e016c30 commit bbecff2

File tree

4 files changed

+225
-7
lines changed

4 files changed

+225
-7
lines changed

coderd/util/xio/limitwriter.go

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package xio
2+
3+
import (
4+
"io"
5+
6+
"golang.org/x/xerrors"
7+
)
8+
9+
var ErrLimitReached = xerrors.Errorf("i/o limit reached")
10+
11+
// LimitWriter will only write bytes to the underlying writer until the limit is reached.
12+
type LimitWriter struct {
13+
Limit int64
14+
N int64
15+
W io.Writer
16+
}
17+
18+
func NewLimitWriter(w io.Writer, n int64) *LimitWriter {
19+
// If anyone tries this, just make a 0 writer.
20+
if n < 0 {
21+
n = 0
22+
}
23+
return &LimitWriter{
24+
Limit: n,
25+
N: 0,
26+
W: w,
27+
}
28+
}
29+
30+
func (l *LimitWriter) Write(p []byte) (int, error) {
31+
if l.N >= l.Limit {
32+
return 0, ErrLimitReached
33+
}
34+
35+
// Write 0 bytes if the limit is to be exceeded.
36+
if int64(len(p)) > l.Limit-l.N {
37+
return 0, ErrLimitReached
38+
}
39+
40+
n, err := l.W.Write(p)
41+
l.N += int64(n)
42+
return n, err
43+
}

coderd/util/xio/limitwriter_test.go

+141
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
package xio_test
2+
3+
import (
4+
"bytes"
5+
cryptorand "crypto/rand"
6+
"testing"
7+
8+
"github.com/stretchr/testify/require"
9+
10+
"github.com/coder/coder/coderd/util/xio"
11+
)
12+
13+
func TestLimitWriter(t *testing.T) {
14+
t.Parallel()
15+
16+
type writeCase struct {
17+
N int
18+
ExpN int
19+
Err bool
20+
}
21+
22+
// testCases will do multiple writes to the same limit writer and check the output.
23+
testCases := []struct {
24+
Name string
25+
L int64
26+
Writes []writeCase
27+
N int
28+
ExpN int
29+
}{
30+
{
31+
Name: "Empty",
32+
L: 1000,
33+
Writes: []writeCase{
34+
// A few empty writes
35+
{N: 0, ExpN: 0}, {N: 0, ExpN: 0}, {N: 0, ExpN: 0},
36+
},
37+
},
38+
{
39+
Name: "NotFull",
40+
L: 1000,
41+
Writes: []writeCase{
42+
{N: 250, ExpN: 250},
43+
{N: 250, ExpN: 250},
44+
{N: 250, ExpN: 250},
45+
},
46+
},
47+
{
48+
Name: "Short",
49+
L: 1000,
50+
Writes: []writeCase{
51+
{N: 250, ExpN: 250},
52+
{N: 250, ExpN: 250},
53+
{N: 250, ExpN: 250},
54+
{N: 250, ExpN: 250},
55+
{N: 250, ExpN: 0, Err: true},
56+
},
57+
},
58+
{
59+
Name: "Exact",
60+
L: 1000,
61+
Writes: []writeCase{
62+
{
63+
N: 1000,
64+
ExpN: 1000,
65+
},
66+
{
67+
N: 1000,
68+
Err: true,
69+
},
70+
},
71+
},
72+
{
73+
Name: "Over",
74+
L: 1000,
75+
Writes: []writeCase{
76+
{
77+
N: 5000,
78+
ExpN: 0,
79+
Err: true,
80+
},
81+
{
82+
N: 5000,
83+
Err: true,
84+
},
85+
{
86+
N: 5000,
87+
Err: true,
88+
},
89+
},
90+
},
91+
{
92+
Name: "Strange",
93+
L: -1,
94+
Writes: []writeCase{
95+
{
96+
N: 5,
97+
ExpN: 0,
98+
Err: true,
99+
},
100+
{
101+
N: 0,
102+
ExpN: 0,
103+
Err: true,
104+
},
105+
},
106+
},
107+
}
108+
109+
for _, c := range testCases {
110+
c := c
111+
t.Run(c.Name, func(t *testing.T) {
112+
t.Parallel()
113+
114+
buf := bytes.NewBuffer([]byte{})
115+
allBuff := bytes.NewBuffer([]byte{})
116+
w := xio.NewLimitWriter(buf, c.L)
117+
118+
for _, wc := range c.Writes {
119+
data := make([]byte, wc.N)
120+
121+
n, err := cryptorand.Read(data)
122+
require.NoError(t, err, "crand read")
123+
require.Equal(t, wc.N, n, "correct bytes read")
124+
max := data[:wc.ExpN]
125+
n, err = w.Write(data)
126+
if wc.Err {
127+
require.Error(t, err, "exp error")
128+
} else {
129+
require.NoError(t, err, "write")
130+
}
131+
132+
// Need to use this to compare across multiple writes.
133+
// Each write appends to the expected output.
134+
allBuff.Write(max)
135+
136+
require.Equal(t, wc.ExpN, n, "correct bytes written")
137+
require.Equal(t, allBuff.Bytes(), buf.Bytes(), "expected data")
138+
}
139+
})
140+
}
141+
}

provisionersdk/archive.go

+13-6
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ import (
88
"strings"
99

1010
"golang.org/x/xerrors"
11+
12+
"github.com/coder/coder/coderd/util/xio"
1113
)
1214

1315
const (
@@ -32,8 +34,9 @@ func dirHasExt(dir string, ext string) (bool, error) {
3234

3335
// Tar archives a Terraform directory.
3436
func Tar(w io.Writer, directory string, limit int64) error {
37+
// The total bytes written must be under the limit, so use -1
38+
w = xio.NewLimitWriter(w, limit-1)
3539
tarWriter := tar.NewWriter(w)
36-
totalSize := int64(0)
3740

3841
const tfExt = ".tf"
3942
hasTf, err := dirHasExt(directory, tfExt)
@@ -95,22 +98,26 @@ func Tar(w io.Writer, directory string, limit int64) error {
9598
if !fileInfo.Mode().IsRegular() {
9699
return nil
97100
}
101+
98102
data, err := os.Open(file)
99103
if err != nil {
100104
return err
101105
}
102106
defer data.Close()
103-
wrote, err := io.Copy(tarWriter, data)
107+
_, err = io.Copy(tarWriter, data)
104108
if err != nil {
109+
if xerrors.Is(err, xio.ErrLimitReached) {
110+
return xerrors.Errorf("Archive too big. Must be <= %d bytes", limit)
111+
}
105112
return err
106113
}
107-
totalSize += wrote
108-
if limit != 0 && totalSize >= limit {
109-
return xerrors.Errorf("Archive too big. Must be <= %d bytes", limit)
110-
}
114+
111115
return data.Close()
112116
})
113117
if err != nil {
118+
if xerrors.Is(err, xio.ErrLimitReached) {
119+
return xerrors.Errorf("Archive too big. Must be <= %d bytes", limit)
120+
}
114121
return err
115122
}
116123
err = tarWriter.Flush()

provisionersdk/archive_test.go

+28-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,32 @@ import (
1515

1616
func TestTar(t *testing.T) {
1717
t.Parallel()
18+
t.Run("HeaderBreakLimit", func(t *testing.T) {
19+
t.Parallel()
20+
dir := t.TempDir()
21+
file, err := os.CreateTemp(dir, "*.tf")
22+
require.NoError(t, err)
23+
_ = file.Close()
24+
// A header is 512 bytes
25+
err = provisionersdk.Tar(io.Discard, dir, 100)
26+
require.Error(t, err)
27+
})
28+
t.Run("HeaderAndContent", func(t *testing.T) {
29+
t.Parallel()
30+
dir := t.TempDir()
31+
file, err := os.CreateTemp(dir, "*.tf")
32+
require.NoError(t, err)
33+
_, _ = file.Write(make([]byte, 100))
34+
_ = file.Close()
35+
// Pay + header is 1024 bytes (padding)
36+
err = provisionersdk.Tar(io.Discard, dir, 1025)
37+
require.NoError(t, err)
38+
39+
// Limit is 1 byte too small (n == limit is a failure, must be under)
40+
err = provisionersdk.Tar(io.Discard, dir, 1024)
41+
require.Error(t, err)
42+
})
43+
1844
t.Run("NoTF", func(t *testing.T) {
1945
t.Parallel()
2046
dir := t.TempDir()
@@ -97,7 +123,8 @@ func TestTar(t *testing.T) {
97123
}
98124
}
99125
archive := new(bytes.Buffer)
100-
err := provisionersdk.Tar(archive, dir, 1024)
126+
// Headers are chonky so raise the limit to something reasonable
127+
err := provisionersdk.Tar(archive, dir, 1024<<2)
101128
require.NoError(t, err)
102129
dir = t.TempDir()
103130
err = provisionersdk.Untar(dir, archive)

0 commit comments

Comments
 (0)