Skip to content

Commit 24c4bdd

Browse files
committed
Add and fixup unit tests for prom metrics
1 parent 045b03a commit 24c4bdd

File tree

4 files changed

+219
-12
lines changed

4 files changed

+219
-12
lines changed

coderd/coderdtest/oidctest/idp.go

+9
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ type FakeIDP struct {
8585
// to test something like PKI auth vs a client_secret.
8686
hookAuthenticateClient func(t testing.TB, req *http.Request) (url.Values, error)
8787
serve bool
88+
// optional middlewares
89+
middlewares chi.Middlewares
8890
}
8991

9092
func StatusError(code int, err error) error {
@@ -115,6 +117,12 @@ func WithAuthorizedRedirectURL(hook func(redirectURL string) error) func(*FakeID
115117
}
116118
}
117119

120+
func WithMiddlewares(mws ...func(http.Handler) http.Handler) func(*FakeIDP) {
121+
return func(f *FakeIDP) {
122+
f.middlewares = append(f.middlewares, mws...)
123+
}
124+
}
125+
118126
// WithRefresh is called when a refresh token is used. The email is
119127
// the email of the user that is being refreshed assuming the claims are correct.
120128
func WithRefresh(hook func(email string) error) func(*FakeIDP) {
@@ -570,6 +578,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
570578
t.Helper()
571579

572580
mux := chi.NewMux()
581+
mux.Use(f.middlewares...)
573582
// This endpoint is required to initialize the OIDC provider.
574583
// It is used to get the OIDC configuration.
575584
mux.Get("/.well-known/openid-configuration", func(rw http.ResponseWriter, r *http.Request) {

coderd/promoauth/github.go

+2
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ func githubRateLimits(resp *http.Response, err error) (rateLimits, bool) {
6262
// the limit is hit.
6363

6464
if len(p.errors) > 0 {
65+
// If we are missing any headers, then do not try and guess
66+
// what the rate limits are.
6567
return limits, false
6668
}
6769
return limits, true

coderd/promoauth/oauth2.go

+9-2
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ var _ OAuth2Config = (*Config)(nil)
4747
// Primarily to avoid any prometheus errors registering duplicate metrics.
4848
type Factory struct {
4949
metrics *metrics
50+
// optional replace now func
51+
Now func() time.Time
5052
}
5153

5254
// metrics is the reusable metrics for all oauth2 providers.
@@ -130,7 +132,7 @@ func NewFactory(registry prometheus.Registerer) *Factory {
130132
}
131133
}
132134

133-
func (f *Factory) New(name string, under OAuth2Config, opts ...func(cfg *Config)) *Config {
135+
func (f *Factory) New(name string, under OAuth2Config) *Config {
134136
return &Config{
135137
name: name,
136138
underlying: under,
@@ -140,6 +142,8 @@ func (f *Factory) New(name string, under OAuth2Config, opts ...func(cfg *Config)
140142

141143
// NewGithub returns a new instrumented oauth2 config for github. It tracks
142144
// rate limits as well as just the external request counts.
145+
//
146+
//nolint:bodyclose
143147
func (f *Factory) NewGithub(name string, under OAuth2Config) *Config {
144148
cfg := f.New(name, under)
145149
cfg.interceptors = append(cfg.interceptors, func(resp *http.Response, err error) {
@@ -155,7 +159,10 @@ func (f *Factory) NewGithub(name string, under OAuth2Config) *Config {
155159
resetIn := float64(-1)
156160
if !limits.Reset.IsZero() {
157161
now := time.Now()
158-
resetIn = float64(limits.Reset.Sub(now).Seconds())
162+
if f.Now != nil {
163+
now = f.Now()
164+
}
165+
resetIn = limits.Reset.Sub(now).Seconds()
159166
if resetIn < 0 {
160167
// If it just reset, just make it 0.
161168
resetIn = 0

coderd/promoauth/oauth2_test.go

+199-10
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,23 @@
11
package promoauth_test
22

33
import (
4+
"context"
5+
"fmt"
6+
"io"
47
"net/http"
5-
"net/url"
8+
"net/http/httptest"
9+
"strings"
610
"testing"
711
"time"
812

913
"github.com/prometheus/client_golang/prometheus"
14+
"github.com/prometheus/client_golang/prometheus/promhttp"
1015
ptestutil "github.com/prometheus/client_golang/prometheus/testutil"
16+
io_prometheus_client "github.com/prometheus/client_model/go"
17+
"github.com/stretchr/testify/assert"
1118
"github.com/stretchr/testify/require"
19+
"golang.org/x/exp/maps"
20+
"golang.org/x/oauth2"
1221

1322
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
1423
"github.com/coder/coder/v2/coderd/externalauth"
@@ -22,53 +31,197 @@ func TestInstrument(t *testing.T) {
2231
ctx := testutil.Context(t, testutil.WaitShort)
2332
idp := oidctest.NewFakeIDP(t, oidctest.WithServing())
2433
reg := prometheus.NewRegistry()
25-
count := func() int {
26-
return ptestutil.CollectAndCount(reg, "coderd_oauth2_external_requests_total")
34+
t.Cleanup(func() {
35+
if t.Failed() {
36+
t.Log(registryDump(reg))
37+
}
38+
})
39+
40+
const id = "test"
41+
labels := prometheus.Labels{
42+
"name": id,
43+
"status_code": "200",
44+
}
45+
const metricname = "coderd_oauth2_external_requests_total"
46+
count := func(source string) int {
47+
labels["source"] = source
48+
return counterValue(t, reg, "coderd_oauth2_external_requests_total", labels)
2749
}
2850

2951
factory := promoauth.NewFactory(reg)
30-
const id = "test"
52+
3153
cfg := externalauth.Config{
3254
InstrumentedOAuth2Config: factory.New(id, idp.OIDCConfig(t, []string{})),
3355
ID: "test",
3456
ValidateURL: must[*url.URL](t)(idp.IssuerURL().Parse("/oauth2/userinfo")).String(),
3557
}
3658

3759
// 0 Requests before we start
38-
require.Equal(t, count(), 0)
60+
require.Nil(t, metricValue(t, reg, metricname, labels), "no metrics at start")
3961

4062
// Exchange should trigger a request
4163
code := idp.CreateAuthCode(t, "foo")
4264
token, err := cfg.Exchange(ctx, code)
4365
require.NoError(t, err)
44-
require.Equal(t, count(), 1)
66+
require.Equal(t, count("Exchange"), 1)
4567

4668
// Force a refresh
4769
token.Expiry = time.Now().Add(time.Hour * -1)
4870
src := cfg.TokenSource(ctx, token)
4971
refreshed, err := src.Token()
5072
require.NoError(t, err)
5173
require.NotEqual(t, token.AccessToken, refreshed.AccessToken, "token refreshed")
52-
require.Equal(t, count(), 2)
74+
require.Equal(t, count("TokenSource"), 1)
5375

5476
// Try a validate
5577
valid, _, err := cfg.ValidateToken(ctx, refreshed.AccessToken)
5678
require.NoError(t, err)
5779
require.True(t, valid)
58-
require.Equal(t, count(), 3)
80+
require.Equal(t, count("ValidateToken"), 1)
5981

6082
// Verify the default client was not broken. This check is added because we
6183
// extend the http.DefaultTransport. If a `.Clone()` is not done, this can be
6284
// mis-used. It is cheap to run this quick check.
85+
snapshot := registryDump(reg)
6386
req, err := http.NewRequestWithContext(ctx, http.MethodGet,
64-
must[*url.URL](t)(idp.IssuerURL().Parse("/.well-known/openid-configuration")).String(), nil)
87+
must(idp.IssuerURL().Parse("/.well-known/openid-configuration")).String(), nil)
6588
require.NoError(t, err)
6689

6790
resp, err := http.DefaultClient.Do(req)
6891
require.NoError(t, err)
6992
_ = resp.Body.Close()
7093

71-
require.Equal(t, count(), 3)
94+
require.NoError(t, compare(reg, snapshot), "no metric changes")
95+
}
96+
97+
func TestGithubRateLimits(t *testing.T) {
98+
t.Parallel()
99+
100+
now := time.Now()
101+
cases := []struct {
102+
Name string
103+
NoHeaders bool
104+
Omit []string
105+
ExpectNoMetrics bool
106+
Limit int
107+
Remaining int
108+
Used int
109+
Reset time.Time
110+
111+
at time.Time
112+
}{
113+
{
114+
Name: "NoHeaders",
115+
NoHeaders: true,
116+
ExpectNoMetrics: true,
117+
},
118+
{
119+
Name: "ZeroHeaders",
120+
ExpectNoMetrics: true,
121+
},
122+
{
123+
Name: "OverLimit",
124+
Limit: 100,
125+
Remaining: 0,
126+
Used: 500,
127+
Reset: now.Add(time.Hour),
128+
at: now,
129+
},
130+
{
131+
Name: "UnderLimit",
132+
Limit: 100,
133+
Remaining: 0,
134+
Used: 500,
135+
Reset: now.Add(time.Hour),
136+
at: now,
137+
},
138+
{
139+
Name: "Partial",
140+
Omit: []string{"x-ratelimit-remaining"},
141+
ExpectNoMetrics: true,
142+
Limit: 100,
143+
Remaining: 0,
144+
Used: 500,
145+
Reset: now.Add(time.Hour),
146+
at: now,
147+
},
148+
}
149+
150+
for _, c := range cases {
151+
c := c
152+
t.Run(c.Name, func(t *testing.T) {
153+
t.Parallel()
154+
155+
reg := prometheus.NewRegistry()
156+
idp := oidctest.NewFakeIDP(t, oidctest.WithMiddlewares(
157+
func(next http.Handler) http.Handler {
158+
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
159+
if !c.NoHeaders {
160+
rw.Header().Set("x-ratelimit-limit", fmt.Sprintf("%d", c.Limit))
161+
rw.Header().Set("x-ratelimit-remaining", fmt.Sprintf("%d", c.Remaining))
162+
rw.Header().Set("x-ratelimit-used", fmt.Sprintf("%d", c.Used))
163+
rw.Header().Set("x-ratelimit-resource", "core")
164+
rw.Header().Set("x-ratelimit-reset", fmt.Sprintf("%d", c.Reset.Unix()))
165+
for _, omit := range c.Omit {
166+
rw.Header().Del(omit)
167+
}
168+
}
169+
170+
next.ServeHTTP(rw, r)
171+
})
172+
}))
173+
174+
factory := promoauth.NewFactory(reg)
175+
if !c.at.IsZero() {
176+
factory.Now = func() time.Time {
177+
return c.at
178+
}
179+
}
180+
181+
cfg := factory.NewGithub("test", idp.OIDCConfig(t, []string{}))
182+
183+
// Do a single oauth2 call
184+
ctx := testutil.Context(t, testutil.WaitShort)
185+
ctx = context.WithValue(ctx, oauth2.HTTPClient, idp.HTTPClient(nil))
186+
_, err := cfg.Exchange(ctx, idp.CreateAuthCode(t, "foo"))
187+
require.NoError(t, err)
188+
189+
// Verify
190+
labels := prometheus.Labels{
191+
"name": "test",
192+
"resource": "core",
193+
}
194+
pass := true
195+
if !c.ExpectNoMetrics {
196+
pass = pass && assert.Equal(t, gaugeValue(t, reg, "coderd_oauth2_external_requests_rate_limit_total", labels), c.Limit, "limit")
197+
pass = pass && assert.Equal(t, gaugeValue(t, reg, "coderd_oauth2_external_requests_rate_limit_remaining", labels), c.Remaining, "remaining")
198+
pass = pass && assert.Equal(t, gaugeValue(t, reg, "coderd_oauth2_external_requests_rate_limit_used", labels), c.Used, "used")
199+
if !c.at.IsZero() {
200+
until := c.Reset.Sub(c.at)
201+
// Float accuracy is not great, so we allow a delta of 2
202+
pass = pass && assert.InDelta(t, gaugeValue(t, reg, "coderd_oauth2_external_requests_rate_limit_reset_in_seconds", labels), int(until.Seconds()), 2, "reset in")
203+
}
204+
} else {
205+
pass = pass && assert.Nil(t, metricValue(t, reg, "coderd_oauth2_external_requests_rate_limit_total", labels), "not exists")
206+
}
207+
208+
// Helpful debugging
209+
if !pass {
210+
t.Log(registryDump(reg))
211+
}
212+
})
213+
}
214+
}
215+
216+
func registryDump(reg *prometheus.Registry) string {
217+
h := promhttp.HandlerFor(reg, promhttp.HandlerOpts{})
218+
rec := httptest.NewRecorder()
219+
req, _ := http.NewRequest(http.MethodGet, "/", nil)
220+
h.ServeHTTP(rec, req)
221+
resp := rec.Result()
222+
data, _ := io.ReadAll(resp.Body)
223+
_ = resp.Body.Close()
224+
return string(data)
72225
}
73226

74227
func must[V any](t *testing.T) func(v V, err error) V {
@@ -78,3 +231,39 @@ func must[V any](t *testing.T) func(v V, err error) V {
78231
return v
79232
}
80233
}
234+
235+
func gaugeValue(t testing.TB, reg prometheus.Gatherer, metricName string, labels prometheus.Labels) int {
236+
labeled := metricValue(t, reg, metricName, labels)
237+
require.NotNilf(t, labeled, "metric %q with labels %v not found", metricName, labels)
238+
return int(labeled.GetGauge().GetValue())
239+
}
240+
241+
func counterValue(t testing.TB, reg prometheus.Gatherer, metricName string, labels prometheus.Labels) int {
242+
labeled := metricValue(t, reg, metricName, labels)
243+
require.NotNilf(t, labeled, "metric %q with labels %v not found", metricName, labels)
244+
return int(labeled.GetCounter().GetValue())
245+
}
246+
247+
func compare(reg prometheus.Gatherer, compare string) error {
248+
return ptestutil.GatherAndCompare(reg, strings.NewReader(compare))
249+
}
250+
251+
func metricValue(t testing.TB, reg prometheus.Gatherer, metricName string, labels prometheus.Labels) *io_prometheus_client.Metric {
252+
metrics, err := reg.Gather()
253+
require.NoError(t, err)
254+
255+
for _, m := range metrics {
256+
if m.GetName() == metricName {
257+
for _, labeled := range m.GetMetric() {
258+
mLables := make(prometheus.Labels)
259+
for _, v := range labeled.GetLabel() {
260+
mLables[v.GetName()] = v.GetValue()
261+
}
262+
if maps.Equal(mLables, labels) {
263+
return labeled
264+
}
265+
}
266+
}
267+
}
268+
return nil
269+
}

0 commit comments

Comments
 (0)