Skip to content

feat: log long-lived connections acceptance (cherry-pick #17219) #17495

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,8 @@ GEN_FILES := \
$(TAILNETTEST_MOCKS) \
coderd/database/pubsub/psmock/psmock.go \
agent/agentcontainers/acmock/acmock.go \
agent/agentcontainers/dcspec/dcspec_gen.go
agent/agentcontainers/dcspec/dcspec_gen.go \
coderd/httpmw/loggermock/loggermock.go

# all gen targets should be added here and to gen/mark-fresh
gen: gen/db gen/golden-files $(GEN_FILES)
Expand Down Expand Up @@ -630,6 +631,7 @@ gen/mark-fresh:
coderd/database/pubsub/psmock/psmock.go \
agent/agentcontainers/acmock/acmock.go \
agent/agentcontainers/dcspec/dcspec_gen.go \
coderd/httpmw/loggermock/loggermock.go \
"

for file in $$files; do
Expand Down Expand Up @@ -669,6 +671,10 @@ agent/agentcontainers/acmock/acmock.go: agent/agentcontainers/containers.go
go generate ./agent/agentcontainers/acmock/
touch "$@"

coderd/httpmw/loggermock/loggermock.go: coderd/httpmw/logger.go
go generate ./coderd/httpmw/loggermock/
touch "$@"

agent/agentcontainers/dcspec/dcspec_gen.go: \
node_modules/.installed \
agent/agentcontainers/dcspec/devContainer.base.schema.json \
Expand Down
97 changes: 74 additions & 23 deletions coderd/httpmw/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,42 +35,93 @@ func Logger(log slog.Logger) func(next http.Handler) http.Handler {
slog.F("start", start),
)

next.ServeHTTP(sw, r)
logContext := NewRequestLogger(httplog, r.Method, start)

end := time.Now()
ctx := WithRequestLogger(r.Context(), logContext)

next.ServeHTTP(sw, r.WithContext(ctx))

// Don't log successful health check requests.
if r.URL.Path == "/api/v2" && sw.Status == http.StatusOK {
return
}

httplog = httplog.With(
slog.F("took", end.Sub(start)),
slog.F("status_code", sw.Status),
slog.F("latency_ms", float64(end.Sub(start)/time.Millisecond)),
)

// For status codes 400 and higher we
// For status codes 500 and higher we
// want to log the response body.
if sw.Status >= http.StatusInternalServerError {
httplog = httplog.With(
logContext.WithFields(
slog.F("response_body", string(sw.ResponseBody())),
)
}

// We should not log at level ERROR for 5xx status codes because 5xx
// includes proxy errors etc. It also causes slogtest to fail
// instantly without an error message by default.
logLevelFn := httplog.Debug
if sw.Status >= http.StatusInternalServerError {
logLevelFn = httplog.Warn
}

// We already capture most of this information in the span (minus
// the response body which we don't want to capture anyways).
tracing.RunWithoutSpan(r.Context(), func(ctx context.Context) {
logLevelFn(ctx, r.Method)
})
logContext.WriteLog(r.Context(), sw.Status)
})
}
}

type RequestLogger interface {
WithFields(fields ...slog.Field)
WriteLog(ctx context.Context, status int)
}

type SlogRequestLogger struct {
log slog.Logger
written bool
message string
start time.Time
}

var _ RequestLogger = &SlogRequestLogger{}

func NewRequestLogger(log slog.Logger, message string, start time.Time) RequestLogger {
return &SlogRequestLogger{
log: log,
written: false,
message: message,
start: start,
}
}

func (c *SlogRequestLogger) WithFields(fields ...slog.Field) {
c.log = c.log.With(fields...)
}

func (c *SlogRequestLogger) WriteLog(ctx context.Context, status int) {
if c.written {
return
}
c.written = true
end := time.Now()

logger := c.log.With(
slog.F("took", end.Sub(c.start)),
slog.F("status_code", status),
slog.F("latency_ms", float64(end.Sub(c.start)/time.Millisecond)),
)
// We already capture most of this information in the span (minus
// the response body which we don't want to capture anyways).
tracing.RunWithoutSpan(ctx, func(ctx context.Context) {
// We should not log at level ERROR for 5xx status codes because 5xx
// includes proxy errors etc. It also causes slogtest to fail
// instantly without an error message by default.
if status >= http.StatusInternalServerError {
logger.Warn(ctx, c.message)
} else {
logger.Debug(ctx, c.message)
}
})
}

type logContextKey struct{}

func WithRequestLogger(ctx context.Context, rl RequestLogger) context.Context {
return context.WithValue(ctx, logContextKey{}, rl)
}

func RequestLoggerFromContext(ctx context.Context) RequestLogger {
val := ctx.Value(logContextKey{})
if logCtx, ok := val.(RequestLogger); ok {
return logCtx
}
return nil
}
174 changes: 174 additions & 0 deletions coderd/httpmw/logger_internal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
package httpmw

import (
"context"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/tracing"
"github.com/coder/coder/v2/testutil"
"github.com/coder/websocket"
)

func TestRequestLogger_WriteLog(t *testing.T) {
t.Parallel()
ctx := context.Background()

sink := &fakeSink{}
logger := slog.Make(sink)
logger = logger.Leveled(slog.LevelDebug)
logCtx := NewRequestLogger(logger, "GET", time.Now())

// Add custom fields
logCtx.WithFields(
slog.F("custom_field", "custom_value"),
)

// Write log for 200 status
logCtx.WriteLog(ctx, http.StatusOK)

require.Len(t, sink.entries, 1, "log was written twice")

require.Equal(t, sink.entries[0].Message, "GET")

require.Equal(t, sink.entries[0].Fields[0].Value, "custom_value")

// Attempt to write again (should be skipped).
logCtx.WriteLog(ctx, http.StatusInternalServerError)

require.Len(t, sink.entries, 1, "log was written twice")
}

func TestLoggerMiddleware_SingleRequest(t *testing.T) {
t.Parallel()

sink := &fakeSink{}
logger := slog.Make(sink)
logger = logger.Leveled(slog.LevelDebug)

ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()

// Create a test handler to simulate an HTTP request
testHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(http.StatusOK)
_, _ = rw.Write([]byte("OK"))
})

// Wrap the test handler with the Logger middleware
loggerMiddleware := Logger(logger)
wrappedHandler := loggerMiddleware(testHandler)

// Create a test HTTP request
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "/test-path", nil)
require.NoError(t, err, "failed to create request")

sw := &tracing.StatusWriter{ResponseWriter: httptest.NewRecorder()}

// Serve the request
wrappedHandler.ServeHTTP(sw, req)

require.Len(t, sink.entries, 1, "log was written twice")

require.Equal(t, sink.entries[0].Message, "GET")

fieldsMap := make(map[string]interface{})
for _, field := range sink.entries[0].Fields {
fieldsMap[field.Name] = field.Value
}

// Check that the log contains the expected fields
requiredFields := []string{"host", "path", "proto", "remote_addr", "start", "took", "status_code", "latency_ms"}
for _, field := range requiredFields {
_, exists := fieldsMap[field]
require.True(t, exists, "field %q is missing in log fields", field)
}

require.Len(t, sink.entries[0].Fields, len(requiredFields), "log should contain only the required fields")

// Check value of the status code
require.Equal(t, fieldsMap["status_code"], http.StatusOK)
}

func TestLoggerMiddleware_WebSocket(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()

sink := &fakeSink{
newEntries: make(chan slog.SinkEntry, 2),
}
logger := slog.Make(sink)
logger = logger.Leveled(slog.LevelDebug)
done := make(chan struct{})
wg := sync.WaitGroup{}
// Create a test handler to simulate a WebSocket connection
testHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
conn, err := websocket.Accept(rw, r, nil)
if !assert.NoError(t, err, "failed to accept websocket") {
return
}
defer conn.Close(websocket.StatusGoingAway, "")

requestLgr := RequestLoggerFromContext(r.Context())
requestLgr.WriteLog(r.Context(), http.StatusSwitchingProtocols)
// Block so we can be sure the end of the middleware isn't being called.
wg.Wait()
})

// Wrap the test handler with the Logger middleware
loggerMiddleware := Logger(logger)
wrappedHandler := loggerMiddleware(testHandler)

// RequestLogger expects the ResponseWriter to be *tracing.StatusWriter
customHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
defer close(done)
sw := &tracing.StatusWriter{ResponseWriter: rw}
wrappedHandler.ServeHTTP(sw, r)
})

srv := httptest.NewServer(customHandler)
defer srv.Close()
wg.Add(1)
// nolint: bodyclose
conn, _, err := websocket.Dial(ctx, srv.URL, nil)
require.NoError(t, err, "failed to dial WebSocket")
defer conn.Close(websocket.StatusNormalClosure, "")

// Wait for the log from within the handler
newEntry := testutil.RequireRecvCtx(ctx, t, sink.newEntries)
require.Equal(t, newEntry.Message, "GET")

// Signal the websocket handler to return (and read to handle the close frame)
wg.Done()
_, _, err = conn.Read(ctx)
require.ErrorAs(t, err, &websocket.CloseError{}, "websocket read should fail with close error")

// Wait for the request to finish completely and verify we only logged once
_ = testutil.RequireRecvCtx(ctx, t, done)
require.Len(t, sink.entries, 1, "log was written twice")
}

type fakeSink struct {
entries []slog.SinkEntry
newEntries chan slog.SinkEntry
}

func (s *fakeSink) LogEntry(_ context.Context, e slog.SinkEntry) {
s.entries = append(s.entries, e)
if s.newEntries != nil {
select {
case s.newEntries <- e:
default:
}
}
}

func (*fakeSink) Sync() {}
70 changes: 70 additions & 0 deletions coderd/httpmw/loggermock/loggermock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading