@@ -8,14 +8,19 @@ import (
8
8
9
9
"github.com/go-chi/chi/v5"
10
10
"github.com/prometheus/client_golang/prometheus"
11
+ cm "github.com/prometheus/client_model/go"
12
+ "github.com/stretchr/testify/assert"
11
13
"github.com/stretchr/testify/require"
12
14
13
15
"github.com/coder/coder/v2/coderd/httpmw"
14
16
"github.com/coder/coder/v2/coderd/tracing"
17
+ "github.com/coder/coder/v2/testutil"
18
+ "github.com/coder/websocket"
15
19
)
16
20
17
21
func TestPrometheus (t * testing.T ) {
18
22
t .Parallel ()
23
+
19
24
t .Run ("All" , func (t * testing.T ) {
20
25
t .Parallel ()
21
26
req := httptest .NewRequest ("GET" , "/" , nil )
@@ -29,4 +34,90 @@ func TestPrometheus(t *testing.T) {
29
34
require .NoError (t , err )
30
35
require .Greater (t , len (metrics ), 0 )
31
36
})
37
+
38
+ t .Run ("Concurrent" , func (t * testing.T ) {
39
+ t .Parallel ()
40
+ ctx , cancel := context .WithTimeout (context .Background (), testutil .WaitShort )
41
+ defer cancel ()
42
+
43
+ reg := prometheus .NewRegistry ()
44
+ promMW := httpmw .Prometheus (reg )
45
+
46
+ // Create a test handler to simulate a WebSocket connection
47
+ testHandler := http .HandlerFunc (func (rw http.ResponseWriter , r * http.Request ) {
48
+ conn , err := websocket .Accept (rw , r , nil )
49
+ if ! assert .NoError (t , err , "failed to accept websocket" ) {
50
+ return
51
+ }
52
+ defer conn .Close (websocket .StatusGoingAway , "" )
53
+ })
54
+
55
+ wrappedHandler := promMW (testHandler )
56
+
57
+ r := chi .NewRouter ()
58
+ r .Use (tracing .StatusWriterMiddleware , promMW )
59
+ r .Get ("/api/v2/build/{build}/logs" , func (rw http.ResponseWriter , r * http.Request ) {
60
+ wrappedHandler .ServeHTTP (rw , r )
61
+ })
62
+
63
+ srv := httptest .NewServer (r )
64
+ defer srv .Close ()
65
+ // nolint: bodyclose
66
+ conn , _ , err := websocket .Dial (ctx , srv .URL + "/api/v2/build/1/logs" , nil )
67
+ require .NoError (t , err , "failed to dial WebSocket" )
68
+ defer conn .Close (websocket .StatusNormalClosure , "" )
69
+
70
+ metrics , err := reg .Gather ()
71
+ require .NoError (t , err )
72
+ require .Greater (t , len (metrics ), 0 )
73
+ metricLabels := getMetricLabels (metrics )
74
+
75
+ concurrentWebsockets , ok := metricLabels ["coderd_api_concurrent_websockets" ]
76
+ require .True (t , ok , "coderd_api_concurrent_websockets metric not found" )
77
+ require .Equal (t , "/api/v2/build/{build}/logs" , concurrentWebsockets ["path" ])
78
+ })
79
+
80
+ t .Run ("UserRoute" , func (t * testing.T ) {
81
+ t .Parallel ()
82
+ reg := prometheus .NewRegistry ()
83
+ promMW := httpmw .Prometheus (reg )
84
+
85
+ r := chi .NewRouter ()
86
+ r .With (promMW ).Get ("/api/v2/users/{user}" , func (w http.ResponseWriter , r * http.Request ) {})
87
+
88
+ req := httptest .NewRequest ("GET" , "/api/v2/users/john" , nil )
89
+
90
+ sw := & tracing.StatusWriter {ResponseWriter : httptest .NewRecorder ()}
91
+
92
+ r .ServeHTTP (sw , req )
93
+
94
+ metrics , err := reg .Gather ()
95
+ require .NoError (t , err )
96
+ require .Greater (t , len (metrics ), 0 )
97
+ metricLabels := getMetricLabels (metrics )
98
+
99
+ reqProcessed , ok := metricLabels ["coderd_api_requests_processed_total" ]
100
+ require .True (t , ok , "coderd_api_requests_processed_total metric not found" )
101
+ require .Equal (t , "/api/v2/users/{user}" , reqProcessed ["path" ])
102
+ require .Equal (t , "GET" , reqProcessed ["method" ])
103
+
104
+ concurrentRequests , ok := metricLabels ["coderd_api_concurrent_requests" ]
105
+ require .True (t , ok , "coderd_api_concurrent_requests metric not found" )
106
+ require .Equal (t , "/api/v2/users/{user}" , concurrentRequests ["path" ])
107
+ require .Equal (t , "GET" , concurrentRequests ["method" ])
108
+ })
109
+ }
110
+
111
+ func getMetricLabels (metrics []* cm.MetricFamily ) map [string ]map [string ]string {
112
+ metricLabels := map [string ]map [string ]string {}
113
+ for _ , metricFamily := range metrics {
114
+ metricName := metricFamily .GetName ()
115
+ metricLabels [metricName ] = map [string ]string {}
116
+ for _ , metric := range metricFamily .GetMetric () {
117
+ for _ , labelPair := range metric .GetLabel () {
118
+ metricLabels [metricName ][labelPair .GetName ()] = labelPair .GetValue ()
119
+ }
120
+ }
121
+ }
122
+ return metricLabels
32
123
}
0 commit comments