Skip to content

Commit 5d37630

Browse files
committed
add some more tests
1 parent 5a1a196 commit 5d37630

File tree

3 files changed

+152
-8
lines changed

3 files changed

+152
-8
lines changed

coderd/httpapi/status_writer.go

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ import (
44
"bufio"
55
"net"
66
"net/http"
7+
8+
"golang.org/x/xerrors"
79
)
810

911
var _ http.ResponseWriter = (*StatusWriter)(nil)
@@ -17,32 +19,38 @@ type StatusWriter struct {
1719
Status int
1820
Hijacked bool
1921
ResponseBody []byte
22+
23+
wroteHeader bool
2024
}
2125

2226
func (w *StatusWriter) WriteHeader(status int) {
23-
w.Status = status
24-
w.ResponseWriter.WriteHeader(status)
27+
if !w.wroteHeader {
28+
w.Status = status
29+
w.wroteHeader = true
30+
w.ResponseWriter.WriteHeader(status)
31+
}
2532
}
2633

2734
func (w *StatusWriter) Write(b []byte) (int, error) {
2835
const maxBodySize = 4096
2936

30-
if w.Status == 0 {
37+
if !w.wroteHeader {
3138
w.Status = http.StatusOK
3239
}
3340

3441
if w.Status >= http.StatusBadRequest {
35-
// Instantiate the recorded response body to be at most
36-
// maxBodySize length.
42+
// This is technically wrong as multiple calls to write
43+
// will simply overwrite w.ResponseBody but given that
44+
// we typically only write to the response body once
45+
// and this field is only used for logging I'm leaving
46+
// this as-is.
3747
w.ResponseBody = make([]byte, minInt(len(b), maxBodySize))
3848
copy(w.ResponseBody, b)
3949
}
4050

4151
return w.ResponseWriter.Write(b)
4252
}
4353

44-
// minInt returns the smaller of a or b. This is helpful because math.Min only
45-
// works with float64s.
4654
func minInt(a, b int) int {
4755
if a < b {
4856
return a
@@ -52,5 +60,10 @@ func minInt(a, b int) int {
5260

5361
func (w *StatusWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
5462
w.Hijacked = true
55-
return w.ResponseWriter.(http.Hijacker).Hijack()
63+
hijacker, ok := w.ResponseWriter.(http.Hijacker)
64+
if !ok {
65+
return nil, nil, xerrors.Errorf("%T is not a http.Hijacker", w.ResponseWriter)
66+
}
67+
68+
return hijacker.Hijack()
5669
}

coderd/httpapi/status_writer_test.go

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
package httpapi_test
2+
3+
import (
4+
"bufio"
5+
"crypto/rand"
6+
"net"
7+
"net/http"
8+
"net/http/httptest"
9+
"testing"
10+
11+
"github.com/stretchr/testify/require"
12+
"golang.org/x/xerrors"
13+
14+
"github.com/coder/coder/coderd/httpapi"
15+
)
16+
17+
func TestStatusWriter(t *testing.T) {
18+
t.Parallel()
19+
20+
t.Run("WriteHeader", func(t *testing.T) {
21+
t.Parallel()
22+
23+
var (
24+
rec = httptest.NewRecorder()
25+
w = &httpapi.StatusWriter{ResponseWriter: rec}
26+
)
27+
28+
w.WriteHeader(http.StatusOK)
29+
require.Equal(t, http.StatusOK, w.Status)
30+
// Validate that the code is written to the underlying Response.
31+
require.Equal(t, http.StatusOK, rec.Code)
32+
})
33+
34+
t.Run("WriteHeaderTwice", func(t *testing.T) {
35+
t.Parallel()
36+
37+
var (
38+
rec = httptest.NewRecorder()
39+
w = &httpapi.StatusWriter{ResponseWriter: rec}
40+
code = http.StatusNotFound
41+
)
42+
43+
w.WriteHeader(code)
44+
w.WriteHeader(http.StatusOK)
45+
// Validate that we only record the first status code.
46+
require.Equal(t, code, w.Status)
47+
// Validate that the code is written to the underlying Response.
48+
require.Equal(t, code, rec.Code)
49+
})
50+
51+
t.Run("WriteNoHeader", func(t *testing.T) {
52+
t.Parallel()
53+
var (
54+
rec = httptest.NewRecorder()
55+
w = &httpapi.StatusWriter{ResponseWriter: rec}
56+
body = []byte("hello")
57+
)
58+
59+
_, err := w.Write(body)
60+
require.NoError(t, err)
61+
62+
// Should set the status to OK.
63+
require.Equal(t, http.StatusOK, w.Status)
64+
// We don't record the body for codes <400.
65+
require.Equal(t, []byte(nil), w.ResponseBody)
66+
require.Equal(t, body, rec.Body.Bytes())
67+
})
68+
69+
t.Run("WriteAfterHeader", func(t *testing.T) {
70+
t.Parallel()
71+
var (
72+
rec = httptest.NewRecorder()
73+
w = &httpapi.StatusWriter{ResponseWriter: rec}
74+
body = []byte("hello")
75+
code = http.StatusInternalServerError
76+
)
77+
78+
w.WriteHeader(code)
79+
_, err := w.Write(body)
80+
require.NoError(t, err)
81+
82+
require.Equal(t, code, w.Status)
83+
require.Equal(t, body, w.ResponseBody)
84+
require.Equal(t, body, rec.Body.Bytes())
85+
})
86+
87+
t.Run("WriteMaxBody", func(t *testing.T) {
88+
t.Parallel()
89+
var (
90+
rec = httptest.NewRecorder()
91+
w = &httpapi.StatusWriter{ResponseWriter: rec}
92+
// 8kb body.
93+
body = make([]byte, 8<<10)
94+
code = http.StatusInternalServerError
95+
)
96+
97+
_, err := rand.Read(body)
98+
require.NoError(t, err)
99+
100+
w.WriteHeader(code)
101+
_, err = w.Write(body)
102+
require.NoError(t, err)
103+
104+
require.Equal(t, code, w.Status)
105+
require.Equal(t, body, rec.Body.Bytes())
106+
require.Equal(t, body[:4096], w.ResponseBody)
107+
})
108+
109+
t.Run("Hijack", func(t *testing.T) {
110+
t.Parallel()
111+
var (
112+
rec = httptest.NewRecorder()
113+
)
114+
115+
w := &httpapi.StatusWriter{ResponseWriter: hijacker{rec}}
116+
117+
_, _, err := w.Hijack()
118+
require.Error(t, err)
119+
require.Equal(t, "hijacked", err.Error())
120+
})
121+
}
122+
123+
type hijacker struct {
124+
http.ResponseWriter
125+
}
126+
127+
func (hijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
128+
return nil, nil, xerrors.New("hijacked")
129+
}

coderd/httpmw/recover_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ import (
1313
)
1414

1515
func TestRecover(t *testing.T) {
16+
t.Parallel()
17+
1618
handler := func(isPanic, hijack bool) http.Handler {
1719
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1820
if isPanic {

0 commit comments

Comments
 (0)