Skip to content

Commit 991d38c

Browse files
feat: log long-lived connections acceptance (cherry-pick #17219) (#17495)
Cherry-picked feat: log long-lived connections acceptance (#17219) Closes #16904 Co-authored-by: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com>
1 parent 1d2af9c commit 991d38c

9 files changed

+351
-24
lines changed

Makefile

+7-1
Original file line numberDiff line numberDiff line change
@@ -581,7 +581,8 @@ GEN_FILES := \
581581
$(TAILNETTEST_MOCKS) \
582582
coderd/database/pubsub/psmock/psmock.go \
583583
agent/agentcontainers/acmock/acmock.go \
584-
agent/agentcontainers/dcspec/dcspec_gen.go
584+
agent/agentcontainers/dcspec/dcspec_gen.go \
585+
coderd/httpmw/loggermock/loggermock.go
585586

586587
# all gen targets should be added here and to gen/mark-fresh
587588
gen: gen/db gen/golden-files $(GEN_FILES)
@@ -630,6 +631,7 @@ gen/mark-fresh:
630631
coderd/database/pubsub/psmock/psmock.go \
631632
agent/agentcontainers/acmock/acmock.go \
632633
agent/agentcontainers/dcspec/dcspec_gen.go \
634+
coderd/httpmw/loggermock/loggermock.go \
633635
"
634636

635637
for file in $$files; do
@@ -669,6 +671,10 @@ agent/agentcontainers/acmock/acmock.go: agent/agentcontainers/containers.go
669671
go generate ./agent/agentcontainers/acmock/
670672
touch "$@"
671673

674+
coderd/httpmw/loggermock/loggermock.go: coderd/httpmw/logger.go
675+
go generate ./coderd/httpmw/loggermock/
676+
touch "$@"
677+
672678
agent/agentcontainers/dcspec/dcspec_gen.go: \
673679
node_modules/.installed \
674680
agent/agentcontainers/dcspec/devContainer.base.schema.json \

coderd/httpmw/logger.go

+74-23
Original file line numberDiff line numberDiff line change
@@ -35,42 +35,93 @@ func Logger(log slog.Logger) func(next http.Handler) http.Handler {
3535
slog.F("start", start),
3636
)
3737

38-
next.ServeHTTP(sw, r)
38+
logContext := NewRequestLogger(httplog, r.Method, start)
3939

40-
end := time.Now()
40+
ctx := WithRequestLogger(r.Context(), logContext)
41+
42+
next.ServeHTTP(sw, r.WithContext(ctx))
4143

4244
// Don't log successful health check requests.
4345
if r.URL.Path == "/api/v2" && sw.Status == http.StatusOK {
4446
return
4547
}
4648

47-
httplog = httplog.With(
48-
slog.F("took", end.Sub(start)),
49-
slog.F("status_code", sw.Status),
50-
slog.F("latency_ms", float64(end.Sub(start)/time.Millisecond)),
51-
)
52-
53-
// For status codes 400 and higher we
49+
// For status codes 500 and higher we
5450
// want to log the response body.
5551
if sw.Status >= http.StatusInternalServerError {
56-
httplog = httplog.With(
52+
logContext.WithFields(
5753
slog.F("response_body", string(sw.ResponseBody())),
5854
)
5955
}
6056

61-
// We should not log at level ERROR for 5xx status codes because 5xx
62-
// includes proxy errors etc. It also causes slogtest to fail
63-
// instantly without an error message by default.
64-
logLevelFn := httplog.Debug
65-
if sw.Status >= http.StatusInternalServerError {
66-
logLevelFn = httplog.Warn
67-
}
68-
69-
// We already capture most of this information in the span (minus
70-
// the response body which we don't want to capture anyways).
71-
tracing.RunWithoutSpan(r.Context(), func(ctx context.Context) {
72-
logLevelFn(ctx, r.Method)
73-
})
57+
logContext.WriteLog(r.Context(), sw.Status)
7458
})
7559
}
7660
}
61+
62+
type RequestLogger interface {
63+
WithFields(fields ...slog.Field)
64+
WriteLog(ctx context.Context, status int)
65+
}
66+
67+
type SlogRequestLogger struct {
68+
log slog.Logger
69+
written bool
70+
message string
71+
start time.Time
72+
}
73+
74+
var _ RequestLogger = &SlogRequestLogger{}
75+
76+
func NewRequestLogger(log slog.Logger, message string, start time.Time) RequestLogger {
77+
return &SlogRequestLogger{
78+
log: log,
79+
written: false,
80+
message: message,
81+
start: start,
82+
}
83+
}
84+
85+
func (c *SlogRequestLogger) WithFields(fields ...slog.Field) {
86+
c.log = c.log.With(fields...)
87+
}
88+
89+
func (c *SlogRequestLogger) WriteLog(ctx context.Context, status int) {
90+
if c.written {
91+
return
92+
}
93+
c.written = true
94+
end := time.Now()
95+
96+
logger := c.log.With(
97+
slog.F("took", end.Sub(c.start)),
98+
slog.F("status_code", status),
99+
slog.F("latency_ms", float64(end.Sub(c.start)/time.Millisecond)),
100+
)
101+
// We already capture most of this information in the span (minus
102+
// the response body which we don't want to capture anyways).
103+
tracing.RunWithoutSpan(ctx, func(ctx context.Context) {
104+
// We should not log at level ERROR for 5xx status codes because 5xx
105+
// includes proxy errors etc. It also causes slogtest to fail
106+
// instantly without an error message by default.
107+
if status >= http.StatusInternalServerError {
108+
logger.Warn(ctx, c.message)
109+
} else {
110+
logger.Debug(ctx, c.message)
111+
}
112+
})
113+
}
114+
115+
type logContextKey struct{}
116+
117+
func WithRequestLogger(ctx context.Context, rl RequestLogger) context.Context {
118+
return context.WithValue(ctx, logContextKey{}, rl)
119+
}
120+
121+
func RequestLoggerFromContext(ctx context.Context) RequestLogger {
122+
val := ctx.Value(logContextKey{})
123+
if logCtx, ok := val.(RequestLogger); ok {
124+
return logCtx
125+
}
126+
return nil
127+
}

coderd/httpmw/logger_internal_test.go

+174
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
package httpmw
2+
3+
import (
4+
"context"
5+
"net/http"
6+
"net/http/httptest"
7+
"sync"
8+
"testing"
9+
"time"
10+
11+
"github.com/stretchr/testify/assert"
12+
"github.com/stretchr/testify/require"
13+
14+
"cdr.dev/slog"
15+
"github.com/coder/coder/v2/coderd/tracing"
16+
"github.com/coder/coder/v2/testutil"
17+
"github.com/coder/websocket"
18+
)
19+
20+
func TestRequestLogger_WriteLog(t *testing.T) {
21+
t.Parallel()
22+
ctx := context.Background()
23+
24+
sink := &fakeSink{}
25+
logger := slog.Make(sink)
26+
logger = logger.Leveled(slog.LevelDebug)
27+
logCtx := NewRequestLogger(logger, "GET", time.Now())
28+
29+
// Add custom fields
30+
logCtx.WithFields(
31+
slog.F("custom_field", "custom_value"),
32+
)
33+
34+
// Write log for 200 status
35+
logCtx.WriteLog(ctx, http.StatusOK)
36+
37+
require.Len(t, sink.entries, 1, "log was written twice")
38+
39+
require.Equal(t, sink.entries[0].Message, "GET")
40+
41+
require.Equal(t, sink.entries[0].Fields[0].Value, "custom_value")
42+
43+
// Attempt to write again (should be skipped).
44+
logCtx.WriteLog(ctx, http.StatusInternalServerError)
45+
46+
require.Len(t, sink.entries, 1, "log was written twice")
47+
}
48+
49+
func TestLoggerMiddleware_SingleRequest(t *testing.T) {
50+
t.Parallel()
51+
52+
sink := &fakeSink{}
53+
logger := slog.Make(sink)
54+
logger = logger.Leveled(slog.LevelDebug)
55+
56+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
57+
defer cancel()
58+
59+
// Create a test handler to simulate an HTTP request
60+
testHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
61+
rw.WriteHeader(http.StatusOK)
62+
_, _ = rw.Write([]byte("OK"))
63+
})
64+
65+
// Wrap the test handler with the Logger middleware
66+
loggerMiddleware := Logger(logger)
67+
wrappedHandler := loggerMiddleware(testHandler)
68+
69+
// Create a test HTTP request
70+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "/test-path", nil)
71+
require.NoError(t, err, "failed to create request")
72+
73+
sw := &tracing.StatusWriter{ResponseWriter: httptest.NewRecorder()}
74+
75+
// Serve the request
76+
wrappedHandler.ServeHTTP(sw, req)
77+
78+
require.Len(t, sink.entries, 1, "log was written twice")
79+
80+
require.Equal(t, sink.entries[0].Message, "GET")
81+
82+
fieldsMap := make(map[string]interface{})
83+
for _, field := range sink.entries[0].Fields {
84+
fieldsMap[field.Name] = field.Value
85+
}
86+
87+
// Check that the log contains the expected fields
88+
requiredFields := []string{"host", "path", "proto", "remote_addr", "start", "took", "status_code", "latency_ms"}
89+
for _, field := range requiredFields {
90+
_, exists := fieldsMap[field]
91+
require.True(t, exists, "field %q is missing in log fields", field)
92+
}
93+
94+
require.Len(t, sink.entries[0].Fields, len(requiredFields), "log should contain only the required fields")
95+
96+
// Check value of the status code
97+
require.Equal(t, fieldsMap["status_code"], http.StatusOK)
98+
}
99+
100+
func TestLoggerMiddleware_WebSocket(t *testing.T) {
101+
t.Parallel()
102+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
103+
defer cancel()
104+
105+
sink := &fakeSink{
106+
newEntries: make(chan slog.SinkEntry, 2),
107+
}
108+
logger := slog.Make(sink)
109+
logger = logger.Leveled(slog.LevelDebug)
110+
done := make(chan struct{})
111+
wg := sync.WaitGroup{}
112+
// Create a test handler to simulate a WebSocket connection
113+
testHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
114+
conn, err := websocket.Accept(rw, r, nil)
115+
if !assert.NoError(t, err, "failed to accept websocket") {
116+
return
117+
}
118+
defer conn.Close(websocket.StatusGoingAway, "")
119+
120+
requestLgr := RequestLoggerFromContext(r.Context())
121+
requestLgr.WriteLog(r.Context(), http.StatusSwitchingProtocols)
122+
// Block so we can be sure the end of the middleware isn't being called.
123+
wg.Wait()
124+
})
125+
126+
// Wrap the test handler with the Logger middleware
127+
loggerMiddleware := Logger(logger)
128+
wrappedHandler := loggerMiddleware(testHandler)
129+
130+
// RequestLogger expects the ResponseWriter to be *tracing.StatusWriter
131+
customHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
132+
defer close(done)
133+
sw := &tracing.StatusWriter{ResponseWriter: rw}
134+
wrappedHandler.ServeHTTP(sw, r)
135+
})
136+
137+
srv := httptest.NewServer(customHandler)
138+
defer srv.Close()
139+
wg.Add(1)
140+
// nolint: bodyclose
141+
conn, _, err := websocket.Dial(ctx, srv.URL, nil)
142+
require.NoError(t, err, "failed to dial WebSocket")
143+
defer conn.Close(websocket.StatusNormalClosure, "")
144+
145+
// Wait for the log from within the handler
146+
newEntry := testutil.RequireRecvCtx(ctx, t, sink.newEntries)
147+
require.Equal(t, newEntry.Message, "GET")
148+
149+
// Signal the websocket handler to return (and read to handle the close frame)
150+
wg.Done()
151+
_, _, err = conn.Read(ctx)
152+
require.ErrorAs(t, err, &websocket.CloseError{}, "websocket read should fail with close error")
153+
154+
// Wait for the request to finish completely and verify we only logged once
155+
_ = testutil.RequireRecvCtx(ctx, t, done)
156+
require.Len(t, sink.entries, 1, "log was written twice")
157+
}
158+
159+
type fakeSink struct {
160+
entries []slog.SinkEntry
161+
newEntries chan slog.SinkEntry
162+
}
163+
164+
func (s *fakeSink) LogEntry(_ context.Context, e slog.SinkEntry) {
165+
s.entries = append(s.entries, e)
166+
if s.newEntries != nil {
167+
select {
168+
case s.newEntries <- e:
169+
default:
170+
}
171+
}
172+
}
173+
174+
func (*fakeSink) Sync() {}

coderd/httpmw/loggermock/loggermock.go

+70
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)