Skip to content

Commit f13df6f

Browse files
committed
Add and fixup unit tests for prom metrics
1 parent ec31915 commit f13df6f

File tree

4 files changed

+218
-10
lines changed

4 files changed

+218
-10
lines changed

coderd/coderdtest/oidctest/idp.go

+9
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ type FakeIDP struct {
8585
// to test something like PKI auth vs a client_secret.
8686
hookAuthenticateClient func(t testing.TB, req *http.Request) (url.Values, error)
8787
serve bool
88+
// optional middlewares
89+
middlewares chi.Middlewares
8890
}
8991

9092
func StatusError(code int, err error) error {
@@ -115,6 +117,12 @@ func WithAuthorizedRedirectURL(hook func(redirectURL string) error) func(*FakeID
115117
}
116118
}
117119

120+
func WithMiddlewares(mws ...func(http.Handler) http.Handler) func(*FakeIDP) {
121+
return func(f *FakeIDP) {
122+
f.middlewares = append(f.middlewares, mws...)
123+
}
124+
}
125+
118126
// WithRefresh is called when a refresh token is used. The email is
119127
// the email of the user that is being refreshed assuming the claims are correct.
120128
func WithRefresh(hook func(email string) error) func(*FakeIDP) {
@@ -570,6 +578,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
570578
t.Helper()
571579

572580
mux := chi.NewMux()
581+
mux.Use(f.middlewares...)
573582
// This endpoint is required to initialize the OIDC provider.
574583
// It is used to get the OIDC configuration.
575584
mux.Get("/.well-known/openid-configuration", func(rw http.ResponseWriter, r *http.Request) {

coderd/promoauth/github.go

+2
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ func githubRateLimits(resp *http.Response, err error) (rateLimits, bool) {
6262
// the limit is hit.
6363

6464
if len(p.errors) > 0 {
65+
// If we are missing any headers, then do not try and guess
66+
// what the rate limits are.
6567
return limits, false
6668
}
6769
return limits, true

coderd/promoauth/oauth2.go

+9-2
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ var _ OAuth2Config = (*Config)(nil)
3737
// Primarily to avoid any prometheus errors registering duplicate metrics.
3838
type Factory struct {
3939
metrics *metrics
40+
// optional replace now func
41+
Now func() time.Time
4042
}
4143

4244
// metrics is the reusable metrics for all oauth2 providers.
@@ -120,7 +122,7 @@ func NewFactory(registry prometheus.Registerer) *Factory {
120122
}
121123
}
122124

123-
func (f *Factory) New(name string, under OAuth2Config, opts ...func(cfg *Config)) *Config {
125+
func (f *Factory) New(name string, under OAuth2Config) *Config {
124126
return &Config{
125127
name: name,
126128
underlying: under,
@@ -130,6 +132,8 @@ func (f *Factory) New(name string, under OAuth2Config, opts ...func(cfg *Config)
130132

131133
// NewGithub returns a new instrumented oauth2 config for github. It tracks
132134
// rate limits as well as just the external request counts.
135+
//
136+
//nolint:bodyclose
133137
func (f *Factory) NewGithub(name string, under OAuth2Config) *Config {
134138
cfg := f.New(name, under)
135139
cfg.interceptors = append(cfg.interceptors, func(resp *http.Response, err error) {
@@ -145,7 +149,10 @@ func (f *Factory) NewGithub(name string, under OAuth2Config) *Config {
145149
resetIn := float64(-1)
146150
if !limits.Reset.IsZero() {
147151
now := time.Now()
148-
resetIn = float64(limits.Reset.Sub(now).Seconds())
152+
if f.Now != nil {
153+
now = f.Now()
154+
}
155+
resetIn = limits.Reset.Sub(now).Seconds()
149156
if resetIn < 0 {
150157
// If it just reset, just make it 0.
151158
resetIn = 0

coderd/promoauth/oauth2_test.go

+198-8
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,23 @@
11
package promoauth_test
22

33
import (
4+
"context"
5+
"fmt"
6+
"io"
47
"net/http"
8+
"net/http/httptest"
9+
"strings"
510
"testing"
611
"time"
712

813
"github.com/prometheus/client_golang/prometheus"
14+
"github.com/prometheus/client_golang/prometheus/promhttp"
915
ptestutil "github.com/prometheus/client_golang/prometheus/testutil"
16+
io_prometheus_client "github.com/prometheus/client_model/go"
17+
"github.com/stretchr/testify/assert"
1018
"github.com/stretchr/testify/require"
19+
"golang.org/x/exp/maps"
20+
"golang.org/x/oauth2"
1121

1222
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
1323
"github.com/coder/coder/v2/coderd/externalauth"
@@ -21,44 +31,58 @@ func TestInstrument(t *testing.T) {
2131
ctx := testutil.Context(t, testutil.WaitShort)
2232
idp := oidctest.NewFakeIDP(t, oidctest.WithServing())
2333
reg := prometheus.NewRegistry()
24-
count := func() int {
25-
return ptestutil.CollectAndCount(reg, "coderd_oauth2_external_requests_total")
34+
t.Cleanup(func() {
35+
if t.Failed() {
36+
t.Log(registryDump(reg))
37+
}
38+
})
39+
40+
const id = "test"
41+
labels := prometheus.Labels{
42+
"name": id,
43+
"status_code": "200",
44+
}
45+
const metricname = "coderd_oauth2_external_requests_total"
46+
count := func(source string) int {
47+
labels["source"] = source
48+
return counterValue(t, reg, "coderd_oauth2_external_requests_total", labels)
2649
}
2750

2851
factory := promoauth.NewFactory(reg)
29-
const id = "test"
52+
3053
cfg := externalauth.Config{
3154
InstrumentedOAuth2Config: factory.New(id, idp.OIDCConfig(t, []string{})),
3255
ID: "test",
3356
ValidateURL: must(idp.IssuerURL().Parse("/oauth2/userinfo")).String(),
3457
}
3558

3659
// 0 Requests before we start
37-
require.Equal(t, count(), 0)
60+
require.Nil(t, metricValue(t, reg, metricname, labels), "no metrics at start")
3861

3962
// Exchange should trigger a request
4063
code := idp.CreateAuthCode(t, "foo")
4164
token, err := cfg.Exchange(ctx, code)
4265
require.NoError(t, err)
43-
require.Equal(t, count(), 1)
66+
require.Equal(t, count("Exchange"), 1)
4467

4568
// Force a refresh
4669
token.Expiry = time.Now().Add(time.Hour * -1)
4770
src := cfg.TokenSource(ctx, token)
4871
refreshed, err := src.Token()
4972
require.NoError(t, err)
5073
require.NotEqual(t, token.AccessToken, refreshed.AccessToken, "token refreshed")
51-
require.Equal(t, count(), 2)
74+
require.Equal(t, count("TokenSource"), 1)
5275

5376
// Try a validate
5477
valid, _, err := cfg.ValidateToken(ctx, refreshed.AccessToken)
5578
require.NoError(t, err)
5679
require.True(t, valid)
57-
require.Equal(t, count(), 3)
80+
require.Equal(t, count("ValidateToken"), 1)
5881

5982
// Verify the default client was not broken. This check is added because we
6083
// extend the http.DefaultTransport. If a `.Clone()` is not done, this can be
6184
// mis-used. It is cheap to run this quick check.
85+
snapshot := registryDump(reg)
6286
req, err := http.NewRequest(http.MethodGet,
6387
must(idp.IssuerURL().Parse("/.well-known/openid-configuration")).String(), nil)
6488
require.NoError(t, err)
@@ -68,7 +92,137 @@ func TestInstrument(t *testing.T) {
6892
require.NoError(t, err)
6993
_ = resp.Body.Close()
7094

71-
require.Equal(t, count(), 3)
95+
require.NoError(t, compare(reg, snapshot), "no metric changes")
96+
}
97+
98+
func TestGithubRateLimits(t *testing.T) {
99+
t.Parallel()
100+
101+
now := time.Now()
102+
cases := []struct {
103+
Name string
104+
NoHeaders bool
105+
Omit []string
106+
ExpectNoMetrics bool
107+
Limit int
108+
Remaining int
109+
Used int
110+
Reset time.Time
111+
112+
at time.Time
113+
}{
114+
{
115+
Name: "NoHeaders",
116+
NoHeaders: true,
117+
ExpectNoMetrics: true,
118+
},
119+
{
120+
Name: "ZeroHeaders",
121+
ExpectNoMetrics: true,
122+
},
123+
{
124+
Name: "OverLimit",
125+
Limit: 100,
126+
Remaining: 0,
127+
Used: 500,
128+
Reset: now.Add(time.Hour),
129+
at: now,
130+
},
131+
{
132+
Name: "UnderLimit",
133+
Limit: 100,
134+
Remaining: 0,
135+
Used: 500,
136+
Reset: now.Add(time.Hour),
137+
at: now,
138+
},
139+
{
140+
Name: "Partial",
141+
Omit: []string{"x-ratelimit-remaining"},
142+
ExpectNoMetrics: true,
143+
Limit: 100,
144+
Remaining: 0,
145+
Used: 500,
146+
Reset: now.Add(time.Hour),
147+
at: now,
148+
},
149+
}
150+
151+
for _, c := range cases {
152+
c := c
153+
t.Run(c.Name, func(t *testing.T) {
154+
t.Parallel()
155+
156+
reg := prometheus.NewRegistry()
157+
idp := oidctest.NewFakeIDP(t, oidctest.WithMiddlewares(
158+
func(next http.Handler) http.Handler {
159+
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
160+
if !c.NoHeaders {
161+
rw.Header().Set("x-ratelimit-limit", fmt.Sprintf("%d", c.Limit))
162+
rw.Header().Set("x-ratelimit-remaining", fmt.Sprintf("%d", c.Remaining))
163+
rw.Header().Set("x-ratelimit-used", fmt.Sprintf("%d", c.Used))
164+
rw.Header().Set("x-ratelimit-resource", "core")
165+
rw.Header().Set("x-ratelimit-reset", fmt.Sprintf("%d", c.Reset.Unix()))
166+
for _, omit := range c.Omit {
167+
rw.Header().Del(omit)
168+
}
169+
}
170+
171+
next.ServeHTTP(rw, r)
172+
})
173+
}))
174+
175+
factory := promoauth.NewFactory(reg)
176+
if !c.at.IsZero() {
177+
factory.Now = func() time.Time {
178+
return c.at
179+
}
180+
}
181+
182+
cfg := factory.NewGithub("test", idp.OIDCConfig(t, []string{}))
183+
184+
// Do a single oauth2 call
185+
ctx := testutil.Context(t, testutil.WaitShort)
186+
ctx = context.WithValue(ctx, oauth2.HTTPClient, idp.HTTPClient(nil))
187+
_, err := cfg.Exchange(ctx, idp.CreateAuthCode(t, "foo"))
188+
require.NoError(t, err)
189+
190+
// Verify
191+
labels := prometheus.Labels{
192+
"name": "test",
193+
"resource": "core",
194+
}
195+
pass := true
196+
if !c.ExpectNoMetrics {
197+
pass = pass && assert.Equal(t, gaugeValue(t, reg, "coderd_oauth2_external_requests_rate_limit_total", labels), c.Limit, "limit")
198+
pass = pass && assert.Equal(t, gaugeValue(t, reg, "coderd_oauth2_external_requests_rate_limit_remaining", labels), c.Remaining, "remaining")
199+
pass = pass && assert.Equal(t, gaugeValue(t, reg, "coderd_oauth2_external_requests_rate_limit_used", labels), c.Used, "used")
200+
if !c.at.IsZero() {
201+
until := c.Reset.Sub(c.at)
202+
// Float accuracy is not great, so we allow a delta of 2
203+
pass = pass && assert.InDelta(t, gaugeValue(t, reg, "coderd_oauth2_external_requests_rate_limit_reset_in_seconds", labels), int(until.Seconds()), 2, "reset in")
204+
}
205+
} else {
206+
pass = pass && assert.Nil(t, metricValue(t, reg, "coderd_oauth2_external_requests_rate_limit_total", labels), "not exists")
207+
}
208+
209+
// Helpful debugging
210+
if !pass {
211+
t.Log(registryDump(reg))
212+
}
213+
})
214+
}
215+
}
216+
217+
func registryDump(reg *prometheus.Registry) string {
218+
h := promhttp.HandlerFor(reg, promhttp.HandlerOpts{})
219+
rec := httptest.NewRecorder()
220+
req, _ := http.NewRequest(http.MethodGet, "/", nil)
221+
h.ServeHTTP(rec, req)
222+
resp := rec.Result()
223+
data, _ := io.ReadAll(resp.Body)
224+
_ = resp.Body.Close()
225+
return string(data)
72226
}
73227

74228
func must[V any](v V, err error) V {
@@ -77,3 +231,39 @@ func must[V any](v V, err error) V {
77231
}
78232
return v
79233
}
234+
235+
func gaugeValue(t testing.TB, reg prometheus.Gatherer, metricName string, labels prometheus.Labels) int {
236+
labeled := metricValue(t, reg, metricName, labels)
237+
require.NotNilf(t, labeled, "metric %q with labels %v not found", metricName, labels)
238+
return int(labeled.GetGauge().GetValue())
239+
}
240+
241+
func counterValue(t testing.TB, reg prometheus.Gatherer, metricName string, labels prometheus.Labels) int {
242+
labeled := metricValue(t, reg, metricName, labels)
243+
require.NotNilf(t, labeled, "metric %q with labels %v not found", metricName, labels)
244+
return int(labeled.GetCounter().GetValue())
245+
}
246+
247+
func compare(reg prometheus.Gatherer, compare string) error {
248+
return ptestutil.GatherAndCompare(reg, strings.NewReader(compare))
249+
}
250+
251+
func metricValue(t testing.TB, reg prometheus.Gatherer, metricName string, labels prometheus.Labels) *io_prometheus_client.Metric {
252+
metrics, err := reg.Gather()
253+
require.NoError(t, err)
254+
255+
for _, m := range metrics {
256+
if m.GetName() == metricName {
257+
for _, labeled := range m.GetMetric() {
258+
mLables := make(prometheus.Labels)
259+
for _, v := range labeled.GetLabel() {
260+
mLables[v.GetName()] = v.GetValue()
261+
}
262+
if maps.Equal(mLables, labels) {
263+
return labeled
264+
}
265+
}
266+
}
267+
}
268+
return nil
269+
}

0 commit comments

Comments
 (0)