Skip to content

chore: instrument github oauth2 limits #11532

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 10, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add and fixup unit tests for prom metrics
  • Loading branch information
Emyrk committed Jan 10, 2024
commit f9317a5c1ed1b45dcad0bd54cea41e95d30be4d7
9 changes: 9 additions & 0 deletions coderd/coderdtest/oidctest/idp.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ type FakeIDP struct {
// to test something like PKI auth vs a client_secret.
hookAuthenticateClient func(t testing.TB, req *http.Request) (url.Values, error)
serve bool
// optional middlewares
middlewares chi.Middlewares
}

func StatusError(code int, err error) error {
Expand Down Expand Up @@ -115,6 +117,12 @@ func WithAuthorizedRedirectURL(hook func(redirectURL string) error) func(*FakeID
}
}

func WithMiddlewares(mws ...func(http.Handler) http.Handler) func(*FakeIDP) {
return func(f *FakeIDP) {
f.middlewares = append(f.middlewares, mws...)
}
}

// WithRefresh is called when a refresh token is used. The email is
// the email of the user that is being refreshed assuming the claims are correct.
func WithRefresh(hook func(email string) error) func(*FakeIDP) {
Expand Down Expand Up @@ -570,6 +578,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
t.Helper()

mux := chi.NewMux()
mux.Use(f.middlewares...)
// This endpoint is required to initialize the OIDC provider.
// It is used to get the OIDC configuration.
mux.Get("/.well-known/openid-configuration", func(rw http.ResponseWriter, r *http.Request) {
Expand Down
2 changes: 2 additions & 0 deletions coderd/promoauth/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ func githubRateLimits(resp *http.Response, err error) (rateLimits, bool) {
// the limit is hit.

if len(p.errors) > 0 {
// If we are missing any headers, then do not try and guess
// what the rate limits are.
return limits, false
}
return limits, true
Expand Down
11 changes: 9 additions & 2 deletions coderd/promoauth/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ var _ OAuth2Config = (*Config)(nil)
// Primarily to avoid any prometheus errors registering duplicate metrics.
type Factory struct {
metrics *metrics
// optional replace now func
Now func() time.Time
}

// metrics is the reusable metrics for all oauth2 providers.
Expand Down Expand Up @@ -130,7 +132,7 @@ func NewFactory(registry prometheus.Registerer) *Factory {
}
}

func (f *Factory) New(name string, under OAuth2Config, opts ...func(cfg *Config)) *Config {
func (f *Factory) New(name string, under OAuth2Config) *Config {
return &Config{
name: name,
underlying: under,
Expand All @@ -140,6 +142,8 @@ func (f *Factory) New(name string, under OAuth2Config, opts ...func(cfg *Config)

// NewGithub returns a new instrumented oauth2 config for github. It tracks
// rate limits as well as just the external request counts.
//
//nolint:bodyclose
func (f *Factory) NewGithub(name string, under OAuth2Config) *Config {
cfg := f.New(name, under)
cfg.interceptors = append(cfg.interceptors, func(resp *http.Response, err error) {
Expand All @@ -155,7 +159,10 @@ func (f *Factory) NewGithub(name string, under OAuth2Config) *Config {
resetIn := float64(-1)
if !limits.Reset.IsZero() {
now := time.Now()
resetIn = float64(limits.Reset.Sub(now).Seconds())
if f.Now != nil {
now = f.Now()
}
resetIn = limits.Reset.Sub(now).Seconds()
if resetIn < 0 {
// If it just reset, just make it 0.
resetIn = 0
Expand Down
209 changes: 199 additions & 10 deletions coderd/promoauth/oauth2_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
package promoauth_test

import (
"context"
"fmt"
"io"
"net/http"
"net/url"
"net/http/httptest"
"strings"
"testing"
"time"

"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
ptestutil "github.com/prometheus/client_golang/prometheus/testutil"
io_prometheus_client "github.com/prometheus/client_model/go"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/maps"
"golang.org/x/oauth2"

"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
"github.com/coder/coder/v2/coderd/externalauth"
Expand All @@ -22,53 +31,197 @@ func TestInstrument(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitShort)
idp := oidctest.NewFakeIDP(t, oidctest.WithServing())
reg := prometheus.NewRegistry()
count := func() int {
return ptestutil.CollectAndCount(reg, "coderd_oauth2_external_requests_total")
t.Cleanup(func() {
if t.Failed() {
t.Log(registryDump(reg))
}
})

const id = "test"
labels := prometheus.Labels{
"name": id,
"status_code": "200",
}
const metricname = "coderd_oauth2_external_requests_total"
count := func(source string) int {
labels["source"] = source
return counterValue(t, reg, "coderd_oauth2_external_requests_total", labels)
}

factory := promoauth.NewFactory(reg)
const id = "test"

cfg := externalauth.Config{
InstrumentedOAuth2Config: factory.New(id, idp.OIDCConfig(t, []string{})),
ID: "test",
ValidateURL: must[*url.URL](t)(idp.IssuerURL().Parse("/oauth2/userinfo")).String(),
}

// 0 Requests before we start
require.Equal(t, count(), 0)
require.Nil(t, metricValue(t, reg, metricname, labels), "no metrics at start")

// Exchange should trigger a request
code := idp.CreateAuthCode(t, "foo")
token, err := cfg.Exchange(ctx, code)
require.NoError(t, err)
require.Equal(t, count(), 1)
require.Equal(t, count("Exchange"), 1)

// Force a refresh
token.Expiry = time.Now().Add(time.Hour * -1)
src := cfg.TokenSource(ctx, token)
refreshed, err := src.Token()
require.NoError(t, err)
require.NotEqual(t, token.AccessToken, refreshed.AccessToken, "token refreshed")
require.Equal(t, count(), 2)
require.Equal(t, count("TokenSource"), 1)

// Try a validate
valid, _, err := cfg.ValidateToken(ctx, refreshed.AccessToken)
require.NoError(t, err)
require.True(t, valid)
require.Equal(t, count(), 3)
require.Equal(t, count("ValidateToken"), 1)

// Verify the default client was not broken. This check is added because we
// extend the http.DefaultTransport. If a `.Clone()` is not done, this can be
// mis-used. It is cheap to run this quick check.
snapshot := registryDump(reg)
req, err := http.NewRequestWithContext(ctx, http.MethodGet,
must[*url.URL](t)(idp.IssuerURL().Parse("/.well-known/openid-configuration")).String(), nil)
must(idp.IssuerURL().Parse("/.well-known/openid-configuration")).String(), nil)
require.NoError(t, err)

resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
_ = resp.Body.Close()

require.Equal(t, count(), 3)
require.NoError(t, compare(reg, snapshot), "no metric changes")
}

func TestGithubRateLimits(t *testing.T) {
t.Parallel()

now := time.Now()
cases := []struct {
Name string
NoHeaders bool
Omit []string
ExpectNoMetrics bool
Limit int
Remaining int
Used int
Reset time.Time

at time.Time
}{
{
Name: "NoHeaders",
NoHeaders: true,
ExpectNoMetrics: true,
},
{
Name: "ZeroHeaders",
ExpectNoMetrics: true,
},
{
Name: "OverLimit",
Limit: 100,
Remaining: 0,
Used: 500,
Reset: now.Add(time.Hour),
at: now,
},
{
Name: "UnderLimit",
Limit: 100,
Remaining: 0,
Used: 500,
Reset: now.Add(time.Hour),
at: now,
},
{
Name: "Partial",
Omit: []string{"x-ratelimit-remaining"},
ExpectNoMetrics: true,
Limit: 100,
Remaining: 0,
Used: 500,
Reset: now.Add(time.Hour),
at: now,
},
}

for _, c := range cases {
c := c
t.Run(c.Name, func(t *testing.T) {
t.Parallel()

reg := prometheus.NewRegistry()
idp := oidctest.NewFakeIDP(t, oidctest.WithMiddlewares(
func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
if !c.NoHeaders {
rw.Header().Set("x-ratelimit-limit", fmt.Sprintf("%d", c.Limit))
rw.Header().Set("x-ratelimit-remaining", fmt.Sprintf("%d", c.Remaining))
rw.Header().Set("x-ratelimit-used", fmt.Sprintf("%d", c.Used))
rw.Header().Set("x-ratelimit-resource", "core")
rw.Header().Set("x-ratelimit-reset", fmt.Sprintf("%d", c.Reset.Unix()))
for _, omit := range c.Omit {
rw.Header().Del(omit)
}
}

next.ServeHTTP(rw, r)
})
}))

factory := promoauth.NewFactory(reg)
if !c.at.IsZero() {
factory.Now = func() time.Time {
return c.at
}
}

cfg := factory.NewGithub("test", idp.OIDCConfig(t, []string{}))

// Do a single oauth2 call
ctx := testutil.Context(t, testutil.WaitShort)
ctx = context.WithValue(ctx, oauth2.HTTPClient, idp.HTTPClient(nil))
_, err := cfg.Exchange(ctx, idp.CreateAuthCode(t, "foo"))
require.NoError(t, err)

// Verify
labels := prometheus.Labels{
"name": "test",
"resource": "core",
}
pass := true
if !c.ExpectNoMetrics {
pass = pass && assert.Equal(t, gaugeValue(t, reg, "coderd_oauth2_external_requests_rate_limit_total", labels), c.Limit, "limit")
pass = pass && assert.Equal(t, gaugeValue(t, reg, "coderd_oauth2_external_requests_rate_limit_remaining", labels), c.Remaining, "remaining")
pass = pass && assert.Equal(t, gaugeValue(t, reg, "coderd_oauth2_external_requests_rate_limit_used", labels), c.Used, "used")
if !c.at.IsZero() {
until := c.Reset.Sub(c.at)
// Float accuracy is not great, so we allow a delta of 2
pass = pass && assert.InDelta(t, gaugeValue(t, reg, "coderd_oauth2_external_requests_rate_limit_reset_in_seconds", labels), int(until.Seconds()), 2, "reset in")
}
} else {
pass = pass && assert.Nil(t, metricValue(t, reg, "coderd_oauth2_external_requests_rate_limit_total", labels), "not exists")
}

// Helpful debugging
if !pass {
t.Log(registryDump(reg))
}
})
}
}

func registryDump(reg *prometheus.Registry) string {
h := promhttp.HandlerFor(reg, promhttp.HandlerOpts{})
rec := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/", nil)
h.ServeHTTP(rec, req)
resp := rec.Result()
data, _ := io.ReadAll(resp.Body)
_ = resp.Body.Close()
return string(data)
}

func must[V any](t *testing.T) func(v V, err error) V {
Expand All @@ -78,3 +231,39 @@ func must[V any](t *testing.T) func(v V, err error) V {
return v
}
}

func gaugeValue(t testing.TB, reg prometheus.Gatherer, metricName string, labels prometheus.Labels) int {
labeled := metricValue(t, reg, metricName, labels)
require.NotNilf(t, labeled, "metric %q with labels %v not found", metricName, labels)
return int(labeled.GetGauge().GetValue())
}

func counterValue(t testing.TB, reg prometheus.Gatherer, metricName string, labels prometheus.Labels) int {
labeled := metricValue(t, reg, metricName, labels)
require.NotNilf(t, labeled, "metric %q with labels %v not found", metricName, labels)
return int(labeled.GetCounter().GetValue())
}

func compare(reg prometheus.Gatherer, compare string) error {
return ptestutil.GatherAndCompare(reg, strings.NewReader(compare))
}

func metricValue(t testing.TB, reg prometheus.Gatherer, metricName string, labels prometheus.Labels) *io_prometheus_client.Metric {
metrics, err := reg.Gather()
require.NoError(t, err)

for _, m := range metrics {
if m.GetName() == metricName {
for _, labeled := range m.GetMetric() {
mLables := make(prometheus.Labels)
for _, v := range labeled.GetLabel() {
mLables[v.GetName()] = v.GetValue()
}
if maps.Equal(mLables, labels) {
return labeled
}
}
}
}
return nil
}