Skip to content

Commit dfa2f4f

Browse files
committed
added WebSocket test to verify early logging
1 parent ca3c4d3 commit dfa2f4f

File tree

1 file changed

+52
-1
lines changed

1 file changed

+52
-1
lines changed

coderd/httpmw/logger_internal_test.go

+52-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@ import (
88
"time"
99

1010
"cdr.dev/slog"
11+
"github.com/coder/coder/v2/coderd/httpapi"
1112
"github.com/coder/coder/v2/coderd/tracing"
13+
"github.com/coder/coder/v2/testutil"
14+
"github.com/coder/websocket"
1215
)
1316

1417
func TestRequestLogger_WriteLog(t *testing.T) {
@@ -48,7 +51,7 @@ func TestRequestLogger_WriteLog(t *testing.T) {
4851
}
4952
}
5053

51-
func TestLoggerMiddleware(t *testing.T) {
54+
func TestLoggerMiddleware_SingleRequest(t *testing.T) {
5255
t.Parallel()
5356

5457
sink := &fakeSink{}
@@ -85,6 +88,54 @@ func TestLoggerMiddleware(t *testing.T) {
8588
}
8689
}
8790

91+
func TestLoggerMiddleware_WebSocket(t *testing.T) {
92+
t.Parallel()
93+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
94+
defer cancel()
95+
96+
sink := &fakeSink{}
97+
logger := slog.Make(sink)
98+
logger = logger.Leveled(slog.LevelDebug)
99+
100+
// Create a test handler to simulate a WebSocket connection
101+
testHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
102+
_, err := websocket.Accept(rw, r, nil)
103+
if err != nil {
104+
httpapi.Write(ctx, rw, http.StatusBadRequest, nil)
105+
return
106+
}
107+
time.Sleep(1000)
108+
})
109+
110+
// Wrap the test handler with the Logger middleware
111+
loggerMiddleware := Logger(logger)
112+
wrappedHandler := loggerMiddleware(testHandler)
113+
114+
// RequestLogger expects the ResponseWriter to be *tracing.StatusWriter
115+
customHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
116+
sw := &tracing.StatusWriter{ResponseWriter: rw}
117+
wrappedHandler.ServeHTTP(sw, r)
118+
})
119+
120+
// Create a test HTTP request
121+
srv := httptest.NewServer(customHandler)
122+
defer srv.Close()
123+
124+
conn, _, err := websocket.Dial(ctx, srv.URL, nil)
125+
if err != nil {
126+
t.Fatalf("failed to create WebSocket connection: %v", err)
127+
}
128+
defer conn.Close(websocket.StatusNormalClosure, "")
129+
130+
if len(sink.entries) != 1 {
131+
t.Fatalf("expected 1 log entry, got %d", len(sink.entries))
132+
}
133+
134+
if sink.entries[0].Message != "GET" {
135+
t.Errorf("expected log message to be 'GET', got '%s'", sink.entries[0].Message)
136+
}
137+
}
138+
88139
type fakeSink struct {
89140
entries []slog.SinkEntry
90141
}

0 commit comments

Comments
 (0)