1
1
package promoauth_test
2
2
3
3
import (
4
+ "context"
5
+ "fmt"
6
+ "io"
4
7
"net/http"
8
+ "net/http/httptest"
9
+ "strings"
5
10
"testing"
6
11
"time"
7
12
8
13
"github.com/prometheus/client_golang/prometheus"
14
+ "github.com/prometheus/client_golang/prometheus/promhttp"
9
15
ptestutil "github.com/prometheus/client_golang/prometheus/testutil"
16
+ io_prometheus_client "github.com/prometheus/client_model/go"
17
+ "github.com/stretchr/testify/assert"
10
18
"github.com/stretchr/testify/require"
19
+ "golang.org/x/exp/maps"
20
+ "golang.org/x/oauth2"
11
21
12
22
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
13
23
"github.com/coder/coder/v2/coderd/externalauth"
@@ -21,44 +31,58 @@ func TestInstrument(t *testing.T) {
21
31
ctx := testutil .Context (t , testutil .WaitShort )
22
32
idp := oidctest .NewFakeIDP (t , oidctest .WithServing ())
23
33
reg := prometheus .NewRegistry ()
24
- count := func () int {
25
- return ptestutil .CollectAndCount (reg , "coderd_oauth2_external_requests_total" )
34
+ t .Cleanup (func () {
35
+ if t .Failed () {
36
+ t .Log (registryDump (reg ))
37
+ }
38
+ })
39
+
40
+ const id = "test"
41
+ labels := prometheus.Labels {
42
+ "name" : id ,
43
+ "status_code" : "200" ,
44
+ }
45
+ const metricname = "coderd_oauth2_external_requests_total"
46
+ count := func (source string ) int {
47
+ labels ["source" ] = source
48
+ return counterValue (t , reg , "coderd_oauth2_external_requests_total" , labels )
26
49
}
27
50
28
51
factory := promoauth .NewFactory (reg )
29
- const id = "test"
52
+
30
53
cfg := externalauth.Config {
31
54
InstrumentedOAuth2Config : factory .New (id , idp .OIDCConfig (t , []string {})),
32
55
ID : "test" ,
33
56
ValidateURL : must (idp .IssuerURL ().Parse ("/oauth2/userinfo" )).String (),
34
57
}
35
58
36
59
// 0 Requests before we start
37
- require .Equal (t , count ( ), 0 )
60
+ require .Nil (t , metricValue ( t , reg , metricname , labels ), "no metrics at start" )
38
61
39
62
// Exchange should trigger a request
40
63
code := idp .CreateAuthCode (t , "foo" )
41
64
token , err := cfg .Exchange (ctx , code )
42
65
require .NoError (t , err )
43
- require .Equal (t , count (), 1 )
66
+ require .Equal (t , count ("Exchange" ), 1 )
44
67
45
68
// Force a refresh
46
69
token .Expiry = time .Now ().Add (time .Hour * - 1 )
47
70
src := cfg .TokenSource (ctx , token )
48
71
refreshed , err := src .Token ()
49
72
require .NoError (t , err )
50
73
require .NotEqual (t , token .AccessToken , refreshed .AccessToken , "token refreshed" )
51
- require .Equal (t , count (), 2 )
74
+ require .Equal (t , count ("TokenSource" ), 1 )
52
75
53
76
// Try a validate
54
77
valid , _ , err := cfg .ValidateToken (ctx , refreshed .AccessToken )
55
78
require .NoError (t , err )
56
79
require .True (t , valid )
57
- require .Equal (t , count (), 3 )
80
+ require .Equal (t , count ("ValidateToken" ), 1 )
58
81
59
82
// Verify the default client was not broken. This check is added because we
60
83
// extend the http.DefaultTransport. If a `.Clone()` is not done, this can be
61
84
// mis-used. It is cheap to run this quick check.
85
+ snapshot := registryDump (reg )
62
86
req , err := http .NewRequest (http .MethodGet ,
63
87
must (idp .IssuerURL ().Parse ("/.well-known/openid-configuration" )).String (), nil )
64
88
require .NoError (t , err )
@@ -68,7 +92,137 @@ func TestInstrument(t *testing.T) {
68
92
require .NoError (t , err )
69
93
_ = resp .Body .Close ()
70
94
71
- require .Equal (t , count (), 3 )
95
+ require .NoError (t , compare (reg , snapshot ), "no metric changes" )
96
+ }
97
+
98
+ func TestGithubRateLimits (t * testing.T ) {
99
+ t .Parallel ()
100
+
101
+ now := time .Now ()
102
+ cases := []struct {
103
+ Name string
104
+ NoHeaders bool
105
+ Omit []string
106
+ ExpectNoMetrics bool
107
+ Limit int
108
+ Remaining int
109
+ Used int
110
+ Reset time.Time
111
+
112
+ at time.Time
113
+ }{
114
+ {
115
+ Name : "NoHeaders" ,
116
+ NoHeaders : true ,
117
+ ExpectNoMetrics : true ,
118
+ },
119
+ {
120
+ Name : "ZeroHeaders" ,
121
+ ExpectNoMetrics : true ,
122
+ },
123
+ {
124
+ Name : "OverLimit" ,
125
+ Limit : 100 ,
126
+ Remaining : 0 ,
127
+ Used : 500 ,
128
+ Reset : now .Add (time .Hour ),
129
+ at : now ,
130
+ },
131
+ {
132
+ Name : "UnderLimit" ,
133
+ Limit : 100 ,
134
+ Remaining : 0 ,
135
+ Used : 500 ,
136
+ Reset : now .Add (time .Hour ),
137
+ at : now ,
138
+ },
139
+ {
140
+ Name : "Partial" ,
141
+ Omit : []string {"x-ratelimit-remaining" },
142
+ ExpectNoMetrics : true ,
143
+ Limit : 100 ,
144
+ Remaining : 0 ,
145
+ Used : 500 ,
146
+ Reset : now .Add (time .Hour ),
147
+ at : now ,
148
+ },
149
+ }
150
+
151
+ for _ , c := range cases {
152
+ c := c
153
+ t .Run (c .Name , func (t * testing.T ) {
154
+ t .Parallel ()
155
+
156
+ reg := prometheus .NewRegistry ()
157
+ idp := oidctest .NewFakeIDP (t , oidctest .WithMiddlewares (
158
+ func (next http.Handler ) http.Handler {
159
+ return http .HandlerFunc (func (rw http.ResponseWriter , r * http.Request ) {
160
+ if ! c .NoHeaders {
161
+ rw .Header ().Set ("x-ratelimit-limit" , fmt .Sprintf ("%d" , c .Limit ))
162
+ rw .Header ().Set ("x-ratelimit-remaining" , fmt .Sprintf ("%d" , c .Remaining ))
163
+ rw .Header ().Set ("x-ratelimit-used" , fmt .Sprintf ("%d" , c .Used ))
164
+ rw .Header ().Set ("x-ratelimit-resource" , "core" )
165
+ rw .Header ().Set ("x-ratelimit-reset" , fmt .Sprintf ("%d" , c .Reset .Unix ()))
166
+ for _ , omit := range c .Omit {
167
+ rw .Header ().Del (omit )
168
+ }
169
+ }
170
+
171
+ next .ServeHTTP (rw , r )
172
+ })
173
+ }))
174
+
175
+ factory := promoauth .NewFactory (reg )
176
+ if ! c .at .IsZero () {
177
+ factory .Now = func () time.Time {
178
+ return c .at
179
+ }
180
+ }
181
+
182
+ cfg := factory .NewGithub ("test" , idp .OIDCConfig (t , []string {}))
183
+
184
+ // Do a single oauth2 call
185
+ ctx := testutil .Context (t , testutil .WaitShort )
186
+ ctx = context .WithValue (ctx , oauth2 .HTTPClient , idp .HTTPClient (nil ))
187
+ _ , err := cfg .Exchange (ctx , idp .CreateAuthCode (t , "foo" ))
188
+ require .NoError (t , err )
189
+
190
+ // Verify
191
+ labels := prometheus.Labels {
192
+ "name" : "test" ,
193
+ "resource" : "core" ,
194
+ }
195
+ pass := true
196
+ if ! c .ExpectNoMetrics {
197
+ pass = pass && assert .Equal (t , gaugeValue (t , reg , "coderd_oauth2_external_requests_rate_limit_total" , labels ), c .Limit , "limit" )
198
+ pass = pass && assert .Equal (t , gaugeValue (t , reg , "coderd_oauth2_external_requests_rate_limit_remaining" , labels ), c .Remaining , "remaining" )
199
+ pass = pass && assert .Equal (t , gaugeValue (t , reg , "coderd_oauth2_external_requests_rate_limit_used" , labels ), c .Used , "used" )
200
+ if ! c .at .IsZero () {
201
+ until := c .Reset .Sub (c .at )
202
+ // Float accuracy is not great, so we allow a delta of 2
203
+ pass = pass && assert .InDelta (t , gaugeValue (t , reg , "coderd_oauth2_external_requests_rate_limit_reset_in_seconds" , labels ), int (until .Seconds ()), 2 , "reset in" )
204
+ }
205
+ } else {
206
+ pass = pass && assert .Nil (t , metricValue (t , reg , "coderd_oauth2_external_requests_rate_limit_total" , labels ), "not exists" )
207
+ }
208
+
209
+ // Helpful debugging
210
+ if ! pass {
211
+ t .Log (registryDump (reg ))
212
+ }
213
+ })
214
+ }
215
+ }
216
+
217
+ func registryDump (reg * prometheus.Registry ) string {
218
+ h := promhttp .HandlerFor (reg , promhttp.HandlerOpts {})
219
+ rec := httptest .NewRecorder ()
220
+ req , _ := http .NewRequest (http .MethodGet , "/" , nil )
221
+ h .ServeHTTP (rec , req )
222
+ resp := rec .Result ()
223
+ data , _ := io .ReadAll (resp .Body )
224
+ _ = resp .Body .Close ()
225
+ return string (data )
72
226
}
73
227
74
228
func must [V any ](v V , err error ) V {
@@ -77,3 +231,39 @@ func must[V any](v V, err error) V {
77
231
}
78
232
return v
79
233
}
234
+
235
+ func gaugeValue (t testing.TB , reg prometheus.Gatherer , metricName string , labels prometheus.Labels ) int {
236
+ labeled := metricValue (t , reg , metricName , labels )
237
+ require .NotNilf (t , labeled , "metric %q with labels %v not found" , metricName , labels )
238
+ return int (labeled .GetGauge ().GetValue ())
239
+ }
240
+
241
+ func counterValue (t testing.TB , reg prometheus.Gatherer , metricName string , labels prometheus.Labels ) int {
242
+ labeled := metricValue (t , reg , metricName , labels )
243
+ require .NotNilf (t , labeled , "metric %q with labels %v not found" , metricName , labels )
244
+ return int (labeled .GetCounter ().GetValue ())
245
+ }
246
+
247
+ func compare (reg prometheus.Gatherer , compare string ) error {
248
+ return ptestutil .GatherAndCompare (reg , strings .NewReader (compare ))
249
+ }
250
+
251
+ func metricValue (t testing.TB , reg prometheus.Gatherer , metricName string , labels prometheus.Labels ) * io_prometheus_client.Metric {
252
+ metrics , err := reg .Gather ()
253
+ require .NoError (t , err )
254
+
255
+ for _ , m := range metrics {
256
+ if m .GetName () == metricName {
257
+ for _ , labeled := range m .GetMetric () {
258
+ mLables := make (prometheus.Labels )
259
+ for _ , v := range labeled .GetLabel () {
260
+ mLables [v .GetName ()] = v .GetValue ()
261
+ }
262
+ if maps .Equal (mLables , labels ) {
263
+ return labeled
264
+ }
265
+ }
266
+ }
267
+ }
268
+ return nil
269
+ }
0 commit comments