Skip to content

Commit cb118a9

Browse files
committed
added additional checks
1 parent 3d44727 commit cb118a9

File tree

1 file changed

+35
-23
lines changed

1 file changed

+35
-23
lines changed

coderd/httpmw/logger_internal_test.go

Lines changed: 35 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"testing"
99
"time"
1010

11+
"github.com/stretchr/testify/assert"
1112
"github.com/stretchr/testify/require"
1213

1314
"cdr.dev/slog"
@@ -35,9 +36,9 @@ func TestRequestLogger_WriteLog(t *testing.T) {
3536

3637
require.Len(t, sink.entries, 1, "log was written twice")
3738

38-
require.Equal(t, sink.entries[0].Message, "GET", "log message should be GET")
39+
require.Equal(t, sink.entries[0].Message, "GET")
3940

40-
require.Equal(t, sink.entries[0].Fields[0].Value, "custom_value", "custom_field should be custom_value")
41+
require.Equal(t, sink.entries[0].Fields[0].Value, "custom_value")
4142

4243
// Attempt to write again (should be skipped).
4344
logCtx.WriteLog(ctx, http.StatusInternalServerError)
@@ -67,9 +68,7 @@ func TestLoggerMiddleware_SingleRequest(t *testing.T) {
6768

6869
// Create a test HTTP request
6970
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "/test-path", nil)
70-
if err != nil {
71-
t.Fatalf("failed to create request: %v", err)
72-
}
71+
require.NoError(t, err, "failed to create request")
7372

7473
sw := &tracing.StatusWriter{ResponseWriter: httptest.NewRecorder()}
7574

@@ -78,7 +77,7 @@ func TestLoggerMiddleware_SingleRequest(t *testing.T) {
7877

7978
require.Len(t, sink.entries, 1, "log was written twice")
8079

81-
require.Equal(t, sink.entries[0].Message, "GET", "log message should be GET")
80+
require.Equal(t, sink.entries[0].Message, "GET")
8281

8382
fieldsMap := make(map[string]interface{})
8483
for _, field := range sink.entries[0].Fields {
@@ -95,33 +94,33 @@ func TestLoggerMiddleware_SingleRequest(t *testing.T) {
9594
require.Len(t, sink.entries[0].Fields, len(requiredFields), "log should contain only the required fields")
9695

9796
// Check value of the status code
98-
require.Equal(t, fieldsMap["status_code"], http.StatusOK, "status_code should be 200")
97+
require.Equal(t, fieldsMap["status_code"], http.StatusOK)
9998
}
10099

101100
func TestLoggerMiddleware_WebSocket(t *testing.T) {
102101
t.Parallel()
103102
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
104103
defer cancel()
105104

106-
sink := &fakeSink{}
105+
sink := &fakeSink{
106+
newEntries: make(chan slog.SinkEntry, 2),
107+
}
107108
logger := slog.Make(sink)
108109
logger = logger.Leveled(slog.LevelDebug)
110+
done := make(chan struct{})
109111
wg := sync.WaitGroup{}
110112
// Create a test handler to simulate a WebSocket connection
111113
testHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
112114
conn, err := websocket.Accept(rw, r, nil)
113-
if err != nil {
114-
t.Errorf("failed to accept websocket: %v", err)
115+
if !assert.NoError(t, err, "failed to accept websocket") {
115116
return
116117
}
117-
requestLgr := RequestLoggerFromContext(r.Context())
118-
requestLgr.WriteLog(r.Context(), http.StatusSwitchingProtocols)
119-
wg.Done()
120118
defer conn.Close(websocket.StatusNormalClosure, "")
121119

122-
// Send a couple of messages for testing
123-
_ = conn.Write(ctx, websocket.MessageText, []byte("ping"))
124-
_ = conn.Write(ctx, websocket.MessageText, []byte("pong"))
120+
requestLgr := RequestLoggerFromContext(r.Context())
121+
requestLgr.WriteLog(r.Context(), http.StatusSwitchingProtocols)
122+
// Block so we can be sure the end of the middleware isn't being called.
123+
wg.Wait()
125124
})
126125

127126
// Wrap the test handler with the Logger middleware
@@ -130,6 +129,7 @@ func TestLoggerMiddleware_WebSocket(t *testing.T) {
130129

131130
// RequestLogger expects the ResponseWriter to be *tracing.StatusWriter
132131
customHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
132+
defer close(done)
133133
sw := &tracing.StatusWriter{ResponseWriter: rw}
134134
wrappedHandler.ServeHTTP(sw, r)
135135
})
@@ -139,22 +139,34 @@ func TestLoggerMiddleware_WebSocket(t *testing.T) {
139139
wg.Add(1)
140140
// nolint: bodyclose
141141
conn, _, err := websocket.Dial(ctx, srv.URL, nil)
142-
if err != nil {
143-
t.Fatalf("failed to create WebSocket connection: %v", err)
144-
}
142+
require.NoError(t, err, "failed to dial WebSocket")
145143
defer conn.Close(websocket.StatusNormalClosure, "")
146-
wg.Wait()
147-
require.Len(t, sink.entries, 1, "log was written twice")
148144

149-
require.Equal(t, sink.entries[0].Message, "GET", "log message should be GET")
145+
// Wait for the log from within the handler
146+
newEntry := testutil.RequireRecvCtx(ctx, t, sink.newEntries)
147+
require.Equal(t, newEntry.Message, "GET")
148+
149+
// Signal the websocket handler to return
150+
wg.Done()
151+
152+
// Wait for the request to finish completely and verify we only logged once
153+
_ = testutil.RequireRecvCtx(ctx, t, done)
154+
require.Len(t, sink.entries, 1, "log was written twice")
150155
}
151156

152157
type fakeSink struct {
153-
entries []slog.SinkEntry
158+
entries []slog.SinkEntry
159+
newEntries chan slog.SinkEntry
154160
}
155161

156162
func (s *fakeSink) LogEntry(_ context.Context, e slog.SinkEntry) {
157163
s.entries = append(s.entries, e)
164+
if s.newEntries != nil {
165+
select {
166+
case s.newEntries <- e:
167+
default:
168+
}
169+
}
158170
}
159171

160172
func (*fakeSink) Sync() {}

0 commit comments

Comments
 (0)