Skip to content

Commit 053fe6f

Browse files
authored
feat: add panic recovery middleware (#3687)
1 parent 3cf17d3 commit 053fe6f

13 files changed

+471
-40
lines changed

coderd/coderd.go

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
package coderd
22

33
import (
4-
"context"
54
"crypto/x509"
6-
"fmt"
75
"io"
86
"net/http"
97
"net/url"
@@ -125,11 +123,8 @@ func New(options *Options) *API {
125123
apiKeyMiddleware := httpmw.ExtractAPIKey(options.Database, oauthConfigs, false)
126124

127125
r.Use(
128-
func(next http.Handler) http.Handler {
129-
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
130-
next.ServeHTTP(middleware.NewWrapResponseWriter(w, r.ProtoMajor), r)
131-
})
132-
},
126+
httpmw.Recover(api.Logger),
127+
httpmw.Logger(api.Logger),
133128
httpmw.Prometheus(options.PrometheusRegistry),
134129
)
135130

@@ -159,7 +154,6 @@ func New(options *Options) *API {
159154
r.Use(
160155
// Specific routes can specify smaller limits.
161156
httpmw.RateLimitPerMinute(options.APIRateLimit),
162-
debugLogRequest(api.Logger),
163157
tracing.HTTPMW(api.TracerProvider, "coderd.http"),
164158
)
165159
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
@@ -438,15 +432,6 @@ func (api *API) Close() error {
438432
return api.workspaceAgentCache.Close()
439433
}
440434

441-
func debugLogRequest(log slog.Logger) func(http.Handler) http.Handler {
442-
return func(next http.Handler) http.Handler {
443-
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
444-
log.Debug(context.Background(), fmt.Sprintf("%s %s", r.Method, r.URL.Path))
445-
next.ServeHTTP(rw, r)
446-
})
447-
}
448-
}
449-
450435
func compressHandler(h http.Handler) http.Handler {
451436
cmp := middleware.NewCompressor(5,
452437
"text/*",

coderd/httpapi/httpapi.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,18 @@ func Forbidden(rw http.ResponseWriter) {
5959
})
6060
}
6161

62+
func InternalServerError(rw http.ResponseWriter, err error) {
63+
var details string
64+
if err != nil {
65+
details = err.Error()
66+
}
67+
68+
Write(rw, http.StatusInternalServerError, codersdk.Response{
69+
Message: "An internal server error occurred.",
70+
Detail: details,
71+
})
72+
}
73+
6274
// Write outputs a standardized format to an HTTP response body.
6375
func Write(rw http.ResponseWriter, status int, response interface{}) {
6476
buf := &bytes.Buffer{}

coderd/httpapi/httpapi_test.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,46 @@ import (
1010

1111
"github.com/stretchr/testify/assert"
1212
"github.com/stretchr/testify/require"
13+
"golang.org/x/xerrors"
1314

1415
"github.com/coder/coder/coderd/httpapi"
1516
"github.com/coder/coder/codersdk"
1617
)
1718

19+
func TestInternalServerError(t *testing.T) {
20+
t.Parallel()
21+
22+
t.Run("NoError", func(t *testing.T) {
23+
t.Parallel()
24+
w := httptest.NewRecorder()
25+
httpapi.InternalServerError(w, nil)
26+
27+
var resp codersdk.Response
28+
err := json.NewDecoder(w.Body).Decode(&resp)
29+
require.NoError(t, err)
30+
require.Equal(t, http.StatusInternalServerError, w.Code)
31+
require.NotEmpty(t, resp.Message)
32+
require.Empty(t, resp.Detail)
33+
})
34+
35+
t.Run("WithError", func(t *testing.T) {
36+
t.Parallel()
37+
var (
38+
w = httptest.NewRecorder()
39+
httpErr = xerrors.New("error!")
40+
)
41+
42+
httpapi.InternalServerError(w, httpErr)
43+
44+
var resp codersdk.Response
45+
err := json.NewDecoder(w.Body).Decode(&resp)
46+
require.NoError(t, err)
47+
require.Equal(t, http.StatusInternalServerError, w.Code)
48+
require.NotEmpty(t, resp.Message)
49+
require.Equal(t, httpErr.Error(), resp.Detail)
50+
})
51+
}
52+
1853
func TestWrite(t *testing.T) {
1954
t.Parallel()
2055
t.Run("NoErrors", func(t *testing.T) {

coderd/httpapi/request.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
package httpapi
2+
3+
import "net/http"
4+
5+
const (
6+
// XForwardedHostHeader is a header used by proxies to indicate the
7+
// original host of the request.
8+
XForwardedHostHeader = "X-Forwarded-Host"
9+
)
10+
11+
// RequestHost returns the name of the host from the request. It prioritizes
12+
// 'X-Forwarded-Host' over r.Host since most requests are being proxied.
13+
func RequestHost(r *http.Request) string {
14+
host := r.Header.Get(XForwardedHostHeader)
15+
if host != "" {
16+
return host
17+
}
18+
19+
return r.Host
20+
}
21+
22+
func IsWebsocketUpgrade(r *http.Request) bool {
23+
vs := r.Header.Values("Upgrade")
24+
for _, v := range vs {
25+
if v == "websocket" {
26+
return true
27+
}
28+
}
29+
return false
30+
}

coderd/httpapi/status_writer.go

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
package httpapi
2+
3+
import (
4+
"bufio"
5+
"net"
6+
"net/http"
7+
8+
"golang.org/x/xerrors"
9+
)
10+
11+
var _ http.ResponseWriter = (*StatusWriter)(nil)
12+
var _ http.Hijacker = (*StatusWriter)(nil)
13+
14+
// StatusWriter intercepts the status of the request and the response body up
15+
// to maxBodySize if Status >= 400. It is guaranteed to be the ResponseWriter
16+
// directly downstream from Middleware.
17+
type StatusWriter struct {
18+
http.ResponseWriter
19+
Status int
20+
Hijacked bool
21+
responseBody []byte
22+
23+
wroteHeader bool
24+
}
25+
26+
func (w *StatusWriter) WriteHeader(status int) {
27+
if !w.wroteHeader {
28+
w.Status = status
29+
w.wroteHeader = true
30+
}
31+
w.ResponseWriter.WriteHeader(status)
32+
}
33+
34+
func (w *StatusWriter) Write(b []byte) (int, error) {
35+
const maxBodySize = 4096
36+
37+
if !w.wroteHeader {
38+
w.Status = http.StatusOK
39+
w.wroteHeader = true
40+
}
41+
42+
if w.Status >= http.StatusBadRequest {
43+
// This is technically wrong as multiple calls to write
44+
// will simply overwrite w.ResponseBody but given that
45+
// we typically only write to the response body once
46+
// and this field is only used for logging I'm leaving
47+
// this as-is.
48+
w.responseBody = make([]byte, minInt(len(b), maxBodySize))
49+
copy(w.responseBody, b)
50+
}
51+
52+
return w.ResponseWriter.Write(b)
53+
}
54+
55+
func minInt(a, b int) int {
56+
if a < b {
57+
return a
58+
}
59+
return b
60+
}
61+
62+
func (w *StatusWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
63+
hijacker, ok := w.ResponseWriter.(http.Hijacker)
64+
if !ok {
65+
return nil, nil, xerrors.Errorf("%T is not a http.Hijacker", w.ResponseWriter)
66+
}
67+
w.Hijacked = true
68+
69+
return hijacker.Hijack()
70+
}
71+
72+
func (w *StatusWriter) ResponseBody() []byte {
73+
return w.responseBody
74+
}

coderd/httpapi/status_writer_test.go

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
package httpapi_test
2+
3+
import (
4+
"bufio"
5+
"crypto/rand"
6+
"net"
7+
"net/http"
8+
"net/http/httptest"
9+
"testing"
10+
11+
"github.com/stretchr/testify/require"
12+
"golang.org/x/xerrors"
13+
14+
"github.com/coder/coder/coderd/httpapi"
15+
)
16+
17+
func TestStatusWriter(t *testing.T) {
18+
t.Parallel()
19+
20+
t.Run("WriteHeader", func(t *testing.T) {
21+
t.Parallel()
22+
23+
var (
24+
rec = httptest.NewRecorder()
25+
w = &httpapi.StatusWriter{ResponseWriter: rec}
26+
)
27+
28+
w.WriteHeader(http.StatusOK)
29+
require.Equal(t, http.StatusOK, w.Status)
30+
// Validate that the code is written to the underlying Response.
31+
require.Equal(t, http.StatusOK, rec.Code)
32+
})
33+
34+
t.Run("WriteHeaderTwice", func(t *testing.T) {
35+
t.Parallel()
36+
37+
var (
38+
rec = httptest.NewRecorder()
39+
w = &httpapi.StatusWriter{ResponseWriter: rec}
40+
code = http.StatusNotFound
41+
)
42+
43+
w.WriteHeader(code)
44+
w.WriteHeader(http.StatusOK)
45+
// Validate that we only record the first status code.
46+
require.Equal(t, code, w.Status)
47+
// Validate that the code is written to the underlying Response.
48+
require.Equal(t, code, rec.Code)
49+
})
50+
51+
t.Run("WriteNoHeader", func(t *testing.T) {
52+
t.Parallel()
53+
var (
54+
rec = httptest.NewRecorder()
55+
w = &httpapi.StatusWriter{ResponseWriter: rec}
56+
body = []byte("hello")
57+
)
58+
59+
_, err := w.Write(body)
60+
require.NoError(t, err)
61+
62+
// Should set the status to OK.
63+
require.Equal(t, http.StatusOK, w.Status)
64+
// We don't record the body for codes <400.
65+
require.Equal(t, []byte(nil), w.ResponseBody())
66+
require.Equal(t, body, rec.Body.Bytes())
67+
})
68+
69+
t.Run("WriteAfterHeader", func(t *testing.T) {
70+
t.Parallel()
71+
var (
72+
rec = httptest.NewRecorder()
73+
w = &httpapi.StatusWriter{ResponseWriter: rec}
74+
body = []byte("hello")
75+
code = http.StatusInternalServerError
76+
)
77+
78+
w.WriteHeader(code)
79+
_, err := w.Write(body)
80+
require.NoError(t, err)
81+
82+
require.Equal(t, code, w.Status)
83+
require.Equal(t, body, w.ResponseBody())
84+
require.Equal(t, body, rec.Body.Bytes())
85+
})
86+
87+
t.Run("WriteMaxBody", func(t *testing.T) {
88+
t.Parallel()
89+
var (
90+
rec = httptest.NewRecorder()
91+
w = &httpapi.StatusWriter{ResponseWriter: rec}
92+
// 8kb body.
93+
body = make([]byte, 8<<10)
94+
code = http.StatusInternalServerError
95+
)
96+
97+
_, err := rand.Read(body)
98+
require.NoError(t, err)
99+
100+
w.WriteHeader(code)
101+
_, err = w.Write(body)
102+
require.NoError(t, err)
103+
104+
require.Equal(t, code, w.Status)
105+
require.Equal(t, body, rec.Body.Bytes())
106+
require.Equal(t, body[:4096], w.ResponseBody())
107+
})
108+
109+
t.Run("Hijack", func(t *testing.T) {
110+
t.Parallel()
111+
var (
112+
rec = httptest.NewRecorder()
113+
)
114+
115+
w := &httpapi.StatusWriter{ResponseWriter: hijacker{rec}}
116+
117+
_, _, err := w.Hijack()
118+
require.Error(t, err)
119+
require.Equal(t, "hijacked", err.Error())
120+
})
121+
}
122+
123+
type hijacker struct {
124+
http.ResponseWriter
125+
}
126+
127+
func (hijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
128+
return nil, nil, xerrors.New("hijacked")
129+
}

0 commit comments

Comments
 (0)