@@ -8,7 +8,10 @@ import (
8
8
"time"
9
9
10
10
"cdr.dev/slog"
11
+ "github.com/coder/coder/v2/coderd/httpapi"
11
12
"github.com/coder/coder/v2/coderd/tracing"
13
+ "github.com/coder/coder/v2/testutil"
14
+ "github.com/coder/websocket"
12
15
)
13
16
14
17
func TestRequestLogger_WriteLog (t * testing.T ) {
@@ -48,7 +51,7 @@ func TestRequestLogger_WriteLog(t *testing.T) {
48
51
}
49
52
}
50
53
51
- func TestLoggerMiddleware (t * testing.T ) {
54
+ func TestLoggerMiddleware_SingleRequest (t * testing.T ) {
52
55
t .Parallel ()
53
56
54
57
sink := & fakeSink {}
@@ -85,6 +88,54 @@ func TestLoggerMiddleware(t *testing.T) {
85
88
}
86
89
}
87
90
91
+ func TestLoggerMiddleware_WebSocket (t * testing.T ) {
92
+ t .Parallel ()
93
+ ctx , cancel := context .WithTimeout (context .Background (), testutil .WaitShort )
94
+ defer cancel ()
95
+
96
+ sink := & fakeSink {}
97
+ logger := slog .Make (sink )
98
+ logger = logger .Leveled (slog .LevelDebug )
99
+
100
+ // Create a test handler to simulate a WebSocket connection
101
+ testHandler := http .HandlerFunc (func (rw http.ResponseWriter , r * http.Request ) {
102
+ _ , err := websocket .Accept (rw , r , nil )
103
+ if err != nil {
104
+ httpapi .Write (ctx , rw , http .StatusBadRequest , nil )
105
+ return
106
+ }
107
+ time .Sleep (1000 )
108
+ })
109
+
110
+ // Wrap the test handler with the Logger middleware
111
+ loggerMiddleware := Logger (logger )
112
+ wrappedHandler := loggerMiddleware (testHandler )
113
+
114
+ // RequestLogger expects the ResponseWriter to be *tracing.StatusWriter
115
+ customHandler := http .HandlerFunc (func (rw http.ResponseWriter , r * http.Request ) {
116
+ sw := & tracing.StatusWriter {ResponseWriter : rw }
117
+ wrappedHandler .ServeHTTP (sw , r )
118
+ })
119
+
120
+ // Create a test HTTP request
121
+ srv := httptest .NewServer (customHandler )
122
+ defer srv .Close ()
123
+
124
+ conn , _ , err := websocket .Dial (ctx , srv .URL , nil )
125
+ if err != nil {
126
+ t .Fatalf ("failed to create WebSocket connection: %v" , err )
127
+ }
128
+ defer conn .Close (websocket .StatusNormalClosure , "" )
129
+
130
+ if len (sink .entries ) != 1 {
131
+ t .Fatalf ("expected 1 log entry, got %d" , len (sink .entries ))
132
+ }
133
+
134
+ if sink .entries [0 ].Message != "GET" {
135
+ t .Errorf ("expected log message to be 'GET', got '%s'" , sink .entries [0 ].Message )
136
+ }
137
+ }
138
+
88
139
type fakeSink struct {
89
140
entries []slog.SinkEntry
90
141
}
0 commit comments