8
8
"testing"
9
9
"time"
10
10
11
+ "github.com/stretchr/testify/assert"
11
12
"github.com/stretchr/testify/require"
12
13
13
14
"cdr.dev/slog"
@@ -35,9 +36,9 @@ func TestRequestLogger_WriteLog(t *testing.T) {
35
36
36
37
require .Len (t , sink .entries , 1 , "log was written twice" )
37
38
38
- require .Equal (t , sink .entries [0 ].Message , "GET" , "log message should be GET" )
39
+ require .Equal (t , sink .entries [0 ].Message , "GET" )
39
40
40
- require .Equal (t , sink .entries [0 ].Fields [0 ].Value , "custom_value" , "custom_field should be custom_value" )
41
+ require .Equal (t , sink .entries [0 ].Fields [0 ].Value , "custom_value" )
41
42
42
43
// Attempt to write again (should be skipped).
43
44
logCtx .WriteLog (ctx , http .StatusInternalServerError )
@@ -67,9 +68,7 @@ func TestLoggerMiddleware_SingleRequest(t *testing.T) {
67
68
68
69
// Create a test HTTP request
69
70
req , err := http .NewRequestWithContext (ctx , http .MethodGet , "/test-path" , nil )
70
- if err != nil {
71
- t .Fatalf ("failed to create request: %v" , err )
72
- }
71
+ require .NoError (t , err , "failed to create request" )
73
72
74
73
sw := & tracing.StatusWriter {ResponseWriter : httptest .NewRecorder ()}
75
74
@@ -78,7 +77,7 @@ func TestLoggerMiddleware_SingleRequest(t *testing.T) {
78
77
79
78
require .Len (t , sink .entries , 1 , "log was written twice" )
80
79
81
- require .Equal (t , sink .entries [0 ].Message , "GET" , "log message should be GET" )
80
+ require .Equal (t , sink .entries [0 ].Message , "GET" )
82
81
83
82
fieldsMap := make (map [string ]interface {})
84
83
for _ , field := range sink .entries [0 ].Fields {
@@ -95,33 +94,33 @@ func TestLoggerMiddleware_SingleRequest(t *testing.T) {
95
94
require .Len (t , sink .entries [0 ].Fields , len (requiredFields ), "log should contain only the required fields" )
96
95
97
96
// Check value of the status code
98
- require .Equal (t , fieldsMap ["status_code" ], http .StatusOK , "status_code should be 200" )
97
+ require .Equal (t , fieldsMap ["status_code" ], http .StatusOK )
99
98
}
100
99
101
100
func TestLoggerMiddleware_WebSocket (t * testing.T ) {
102
101
t .Parallel ()
103
102
ctx , cancel := context .WithTimeout (context .Background (), testutil .WaitShort )
104
103
defer cancel ()
105
104
106
- sink := & fakeSink {}
105
+ sink := & fakeSink {
106
+ newEntries : make (chan slog.SinkEntry , 2 ),
107
+ }
107
108
logger := slog .Make (sink )
108
109
logger = logger .Leveled (slog .LevelDebug )
110
+ done := make (chan struct {})
109
111
wg := sync.WaitGroup {}
110
112
// Create a test handler to simulate a WebSocket connection
111
113
testHandler := http .HandlerFunc (func (rw http.ResponseWriter , r * http.Request ) {
112
114
conn , err := websocket .Accept (rw , r , nil )
113
- if err != nil {
114
- t .Errorf ("failed to accept websocket: %v" , err )
115
+ if ! assert .NoError (t , err , "failed to accept websocket" ) {
115
116
return
116
117
}
117
- requestLgr := RequestLoggerFromContext (r .Context ())
118
- requestLgr .WriteLog (r .Context (), http .StatusSwitchingProtocols )
119
- wg .Done ()
120
118
defer conn .Close (websocket .StatusNormalClosure , "" )
121
119
122
- // Send a couple of messages for testing
123
- _ = conn .Write (ctx , websocket .MessageText , []byte ("ping" ))
124
- _ = conn .Write (ctx , websocket .MessageText , []byte ("pong" ))
120
+ requestLgr := RequestLoggerFromContext (r .Context ())
121
+ requestLgr .WriteLog (r .Context (), http .StatusSwitchingProtocols )
122
+ // Block so we can be sure the end of the middleware isn't being called.
123
+ wg .Wait ()
125
124
})
126
125
127
126
// Wrap the test handler with the Logger middleware
@@ -130,6 +129,7 @@ func TestLoggerMiddleware_WebSocket(t *testing.T) {
130
129
131
130
// RequestLogger expects the ResponseWriter to be *tracing.StatusWriter
132
131
customHandler := http .HandlerFunc (func (rw http.ResponseWriter , r * http.Request ) {
132
+ defer close (done )
133
133
sw := & tracing.StatusWriter {ResponseWriter : rw }
134
134
wrappedHandler .ServeHTTP (sw , r )
135
135
})
@@ -139,22 +139,34 @@ func TestLoggerMiddleware_WebSocket(t *testing.T) {
139
139
wg .Add (1 )
140
140
// nolint: bodyclose
141
141
conn , _ , err := websocket .Dial (ctx , srv .URL , nil )
142
- if err != nil {
143
- t .Fatalf ("failed to create WebSocket connection: %v" , err )
144
- }
142
+ require .NoError (t , err , "failed to dial WebSocket" )
145
143
defer conn .Close (websocket .StatusNormalClosure , "" )
146
- wg .Wait ()
147
- require .Len (t , sink .entries , 1 , "log was written twice" )
148
144
149
- require .Equal (t , sink .entries [0 ].Message , "GET" , "log message should be GET" )
145
+ // Wait for the log from within the handler
146
+ newEntry := testutil .RequireRecvCtx (ctx , t , sink .newEntries )
147
+ require .Equal (t , newEntry .Message , "GET" )
148
+
149
+ // Signal the websocket handler to return
150
+ wg .Done ()
151
+
152
+ // Wait for the request to finish completely and verify we only logged once
153
+ _ = testutil .RequireRecvCtx (ctx , t , done )
154
+ require .Len (t , sink .entries , 1 , "log was written twice" )
150
155
}
151
156
152
157
type fakeSink struct {
153
- entries []slog.SinkEntry
158
+ entries []slog.SinkEntry
159
+ newEntries chan slog.SinkEntry
154
160
}
155
161
156
162
func (s * fakeSink ) LogEntry (_ context.Context , e slog.SinkEntry ) {
157
163
s .entries = append (s .entries , e )
164
+ if s .newEntries != nil {
165
+ select {
166
+ case s .newEntries <- e :
167
+ default :
168
+ }
169
+ }
158
170
}
159
171
160
172
func (* fakeSink ) Sync () {}
0 commit comments