Skip to content

Commit 51c97c1

Browse files
committed
added tests for promethus middleware
1 parent 4698231 commit 51c97c1

File tree

1 file changed

+77
-56
lines changed

1 file changed

+77
-56
lines changed

coderd/httpmw/prometheus_internal_test.go

Lines changed: 77 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,29 @@ import (
88

99
"github.com/go-chi/chi/v5"
1010
"github.com/prometheus/client_golang/prometheus"
11+
dto "github.com/prometheus/client_model/go"
12+
"github.com/stretchr/testify/assert"
1113
"github.com/stretchr/testify/require"
1214

1315
"github.com/coder/coder/v2/coderd/tracing"
16+
"github.com/coder/coder/v2/testutil"
17+
"github.com/coder/websocket"
1418
)
1519

20+
func getMetricLabels(metrics []*dto.MetricFamily) map[string]map[string]string {
21+
metricLabels := map[string]map[string]string{}
22+
for _, metricFamily := range metrics {
23+
metricName := metricFamily.GetName()
24+
metricLabels[metricName] = map[string]string{}
25+
for _, metric := range metricFamily.GetMetric() {
26+
for _, labelPair := range metric.GetLabel() {
27+
metricLabels[metricName][labelPair.GetName()] = labelPair.GetValue()
28+
}
29+
}
30+
}
31+
return metricLabels
32+
}
33+
1634
func TestPrometheus(t *testing.T) {
1735
t.Parallel()
1836
t.Run("All", func(t *testing.T) {
@@ -30,71 +48,74 @@ func TestPrometheus(t *testing.T) {
3048
})
3149
}
3250

33-
func TestGetRoutePattern(t *testing.T) {
51+
func TestPrometheus_Concurrent(t *testing.T) {
3452
t.Parallel()
53+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
54+
defer cancel()
55+
3556
reg := prometheus.NewRegistry()
3657
promMW := Prometheus(reg)
37-
// Create a test router with some routes
58+
59+
// Create a test handler to simulate a WebSocket connection
60+
testHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
61+
conn, err := websocket.Accept(rw, r, nil)
62+
if !assert.NoError(t, err, "failed to accept websocket") {
63+
return
64+
}
65+
defer conn.Close(websocket.StatusGoingAway, "")
66+
})
67+
68+
wrappedHandler := promMW(testHandler)
69+
70+
r := chi.NewRouter()
71+
r.Use(tracing.StatusWriterMiddleware, promMW)
72+
r.Get("/api/v2/build/{build}/logs", func(rw http.ResponseWriter, r *http.Request) {
73+
wrappedHandler.ServeHTTP(rw, r)
74+
})
75+
76+
srv := httptest.NewServer(r)
77+
defer srv.Close()
78+
// nolint: bodyclose
79+
conn, _, err := websocket.Dial(ctx, srv.URL+"/api/v2/build/1/logs", nil)
80+
require.NoError(t, err, "failed to dial WebSocket")
81+
defer conn.Close(websocket.StatusNormalClosure, "")
82+
83+
metrics, err := reg.Gather()
84+
require.NoError(t, err)
85+
require.Greater(t, len(metrics), 0)
86+
metricLabels := getMetricLabels(metrics)
87+
88+
concurrentWebsockets, ok := metricLabels["coderd_api_concurrent_websockets"]
89+
require.True(t, ok, "coderd_api_concurrent_websockets metric not found")
90+
require.Equal(t, "/api/v2/build/{build}/logs", concurrentWebsockets["path"])
91+
}
92+
93+
func TestGetRoutePattern_UserRoute(t *testing.T) {
94+
t.Parallel()
95+
reg := prometheus.NewRegistry()
96+
promMW := Prometheus(reg)
97+
3898
r := chi.NewRouter()
39-
r.With(promMW).Get("/api/v2/workspaces/{workspace}", func(w http.ResponseWriter, r *http.Request) {})
4099
r.With(promMW).Get("/api/v2/users/{user}", func(w http.ResponseWriter, r *http.Request) {})
41-
r.With(promMW).Get("/static/*", func(w http.ResponseWriter, r *http.Request) {})
42-
43-
tests := []struct {
44-
name string
45-
method string
46-
path string
47-
expected string
48-
}{
49-
{
50-
name: "PatternAlreadyAvailable",
51-
method: "GET",
52-
path: "/api/v2/workspaces/test",
53-
expected: "/api/v2/workspaces/{workspace}",
54-
},
55-
{
56-
name: "UserRoute",
57-
method: "GET",
58-
path: "/api/v2/users/john",
59-
expected: "/api/v2/users/{user}",
60-
},
61-
{
62-
name: "StaticRoute",
63-
method: "GET",
64-
path: "/static/css/style.css",
65-
expected: "/static/*",
66-
},
67-
{
68-
name: "NoMatchingRoute",
69-
method: "GET",
70-
path: "/nonexistent",
71-
expected: "",
72-
},
73-
{
74-
name: "FrontendRoute",
75-
method: "GET",
76-
path: "/",
77-
expected: "",
78-
},
79-
}
80100

81-
for _, tt := range tests {
82-
tt := tt
83-
t.Run(tt.name, func(t *testing.T) {
84-
t.Parallel()
101+
req := httptest.NewRequest("GET", "/api/v2/users/john", nil)
85102

86-
req := httptest.NewRequest(tt.method, tt.path, nil)
103+
sw := &tracing.StatusWriter{ResponseWriter: httptest.NewRecorder()}
87104

88-
sw := &tracing.StatusWriter{ResponseWriter: httptest.NewRecorder()}
105+
r.ServeHTTP(sw, req)
89106

90-
r.ServeHTTP(sw, req)
107+
metrics, err := reg.Gather()
108+
require.NoError(t, err)
109+
require.Greater(t, len(metrics), 0)
110+
metricLabels := getMetricLabels(metrics)
91111

92-
metrics, err := reg.Gather()
93-
require.NoError(t, err)
94-
require.Greater(t, len(metrics), 0)
112+
reqProcessed, ok := metricLabels["coderd_api_requests_processed_total"]
113+
require.True(t, ok, "coderd_api_requests_processed_total metric not found")
114+
require.Equal(t, "/api/v2/users/{user}", reqProcessed["path"])
115+
require.Equal(t, "GET", reqProcessed["method"])
95116

96-
// Verify the result
97-
// require.Equal(t, tt.expected, pattern, "unexpected route pattern")
98-
})
99-
}
117+
concurrentRequests, ok := metricLabels["coderd_api_concurrent_requests"]
118+
require.True(t, ok, "coderd_api_concurrent_requests metric not found")
119+
require.Equal(t, "/api/v2/users/{user}", concurrentRequests["path"])
120+
require.Equal(t, "GET", concurrentRequests["method"])
100121
}

0 commit comments

Comments
 (0)