1
1
package promoauth_test
2
2
3
3
import (
4
+ "context"
5
+ "fmt"
6
+ "io"
4
7
"net/http"
5
- "net/url"
8
+ "net/http/httptest"
9
+ "strings"
6
10
"testing"
7
11
"time"
8
12
9
13
"github.com/prometheus/client_golang/prometheus"
14
+ "github.com/prometheus/client_golang/prometheus/promhttp"
10
15
ptestutil "github.com/prometheus/client_golang/prometheus/testutil"
16
+ io_prometheus_client "github.com/prometheus/client_model/go"
17
+ "github.com/stretchr/testify/assert"
11
18
"github.com/stretchr/testify/require"
19
+ "golang.org/x/exp/maps"
20
+ "golang.org/x/oauth2"
12
21
13
22
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
14
23
"github.com/coder/coder/v2/coderd/externalauth"
@@ -22,53 +31,197 @@ func TestInstrument(t *testing.T) {
22
31
ctx := testutil .Context (t , testutil .WaitShort )
23
32
idp := oidctest .NewFakeIDP (t , oidctest .WithServing ())
24
33
reg := prometheus .NewRegistry ()
25
- count := func () int {
26
- return ptestutil .CollectAndCount (reg , "coderd_oauth2_external_requests_total" )
34
+ t .Cleanup (func () {
35
+ if t .Failed () {
36
+ t .Log (registryDump (reg ))
37
+ }
38
+ })
39
+
40
+ const id = "test"
41
+ labels := prometheus.Labels {
42
+ "name" : id ,
43
+ "status_code" : "200" ,
44
+ }
45
+ const metricname = "coderd_oauth2_external_requests_total"
46
+ count := func (source string ) int {
47
+ labels ["source" ] = source
48
+ return counterValue (t , reg , "coderd_oauth2_external_requests_total" , labels )
27
49
}
28
50
29
51
factory := promoauth .NewFactory (reg )
30
- const id = "test"
52
+
31
53
cfg := externalauth.Config {
32
54
InstrumentedOAuth2Config : factory .New (id , idp .OIDCConfig (t , []string {})),
33
55
ID : "test" ,
34
56
ValidateURL : must [* url.URL ](t )(idp .IssuerURL ().Parse ("/oauth2/userinfo" )).String (),
35
57
}
36
58
37
59
// 0 Requests before we start
38
- require .Equal (t , count ( ), 0 )
60
+ require .Nil (t , metricValue ( t , reg , metricname , labels ), "no metrics at start" )
39
61
40
62
// Exchange should trigger a request
41
63
code := idp .CreateAuthCode (t , "foo" )
42
64
token , err := cfg .Exchange (ctx , code )
43
65
require .NoError (t , err )
44
- require .Equal (t , count (), 1 )
66
+ require .Equal (t , count ("Exchange" ), 1 )
45
67
46
68
// Force a refresh
47
69
token .Expiry = time .Now ().Add (time .Hour * - 1 )
48
70
src := cfg .TokenSource (ctx , token )
49
71
refreshed , err := src .Token ()
50
72
require .NoError (t , err )
51
73
require .NotEqual (t , token .AccessToken , refreshed .AccessToken , "token refreshed" )
52
- require .Equal (t , count (), 2 )
74
+ require .Equal (t , count ("TokenSource" ), 1 )
53
75
54
76
// Try a validate
55
77
valid , _ , err := cfg .ValidateToken (ctx , refreshed .AccessToken )
56
78
require .NoError (t , err )
57
79
require .True (t , valid )
58
- require .Equal (t , count (), 3 )
80
+ require .Equal (t , count ("ValidateToken" ), 1 )
59
81
60
82
// Verify the default client was not broken. This check is added because we
61
83
// extend the http.DefaultTransport. If a `.Clone()` is not done, this can be
62
84
// mis-used. It is cheap to run this quick check.
85
+ snapshot := registryDump (reg )
63
86
req , err := http .NewRequestWithContext (ctx , http .MethodGet ,
64
- must [ * url. URL ]( t ) (idp .IssuerURL ().Parse ("/.well-known/openid-configuration" )).String (), nil )
87
+ must (idp .IssuerURL ().Parse ("/.well-known/openid-configuration" )).String (), nil )
65
88
require .NoError (t , err )
66
89
67
90
resp , err := http .DefaultClient .Do (req )
68
91
require .NoError (t , err )
69
92
_ = resp .Body .Close ()
70
93
71
- require .Equal (t , count (), 3 )
94
+ require .NoError (t , compare (reg , snapshot ), "no metric changes" )
95
+ }
96
+
97
+ func TestGithubRateLimits (t * testing.T ) {
98
+ t .Parallel ()
99
+
100
+ now := time .Now ()
101
+ cases := []struct {
102
+ Name string
103
+ NoHeaders bool
104
+ Omit []string
105
+ ExpectNoMetrics bool
106
+ Limit int
107
+ Remaining int
108
+ Used int
109
+ Reset time.Time
110
+
111
+ at time.Time
112
+ }{
113
+ {
114
+ Name : "NoHeaders" ,
115
+ NoHeaders : true ,
116
+ ExpectNoMetrics : true ,
117
+ },
118
+ {
119
+ Name : "ZeroHeaders" ,
120
+ ExpectNoMetrics : true ,
121
+ },
122
+ {
123
+ Name : "OverLimit" ,
124
+ Limit : 100 ,
125
+ Remaining : 0 ,
126
+ Used : 500 ,
127
+ Reset : now .Add (time .Hour ),
128
+ at : now ,
129
+ },
130
+ {
131
+ Name : "UnderLimit" ,
132
+ Limit : 100 ,
133
+ Remaining : 0 ,
134
+ Used : 500 ,
135
+ Reset : now .Add (time .Hour ),
136
+ at : now ,
137
+ },
138
+ {
139
+ Name : "Partial" ,
140
+ Omit : []string {"x-ratelimit-remaining" },
141
+ ExpectNoMetrics : true ,
142
+ Limit : 100 ,
143
+ Remaining : 0 ,
144
+ Used : 500 ,
145
+ Reset : now .Add (time .Hour ),
146
+ at : now ,
147
+ },
148
+ }
149
+
150
+ for _ , c := range cases {
151
+ c := c
152
+ t .Run (c .Name , func (t * testing.T ) {
153
+ t .Parallel ()
154
+
155
+ reg := prometheus .NewRegistry ()
156
+ idp := oidctest .NewFakeIDP (t , oidctest .WithMiddlewares (
157
+ func (next http.Handler ) http.Handler {
158
+ return http .HandlerFunc (func (rw http.ResponseWriter , r * http.Request ) {
159
+ if ! c .NoHeaders {
160
+ rw .Header ().Set ("x-ratelimit-limit" , fmt .Sprintf ("%d" , c .Limit ))
161
+ rw .Header ().Set ("x-ratelimit-remaining" , fmt .Sprintf ("%d" , c .Remaining ))
162
+ rw .Header ().Set ("x-ratelimit-used" , fmt .Sprintf ("%d" , c .Used ))
163
+ rw .Header ().Set ("x-ratelimit-resource" , "core" )
164
+ rw .Header ().Set ("x-ratelimit-reset" , fmt .Sprintf ("%d" , c .Reset .Unix ()))
165
+ for _ , omit := range c .Omit {
166
+ rw .Header ().Del (omit )
167
+ }
168
+ }
169
+
170
+ next .ServeHTTP (rw , r )
171
+ })
172
+ }))
173
+
174
+ factory := promoauth .NewFactory (reg )
175
+ if ! c .at .IsZero () {
176
+ factory .Now = func () time.Time {
177
+ return c .at
178
+ }
179
+ }
180
+
181
+ cfg := factory .NewGithub ("test" , idp .OIDCConfig (t , []string {}))
182
+
183
+ // Do a single oauth2 call
184
+ ctx := testutil .Context (t , testutil .WaitShort )
185
+ ctx = context .WithValue (ctx , oauth2 .HTTPClient , idp .HTTPClient (nil ))
186
+ _ , err := cfg .Exchange (ctx , idp .CreateAuthCode (t , "foo" ))
187
+ require .NoError (t , err )
188
+
189
+ // Verify
190
+ labels := prometheus.Labels {
191
+ "name" : "test" ,
192
+ "resource" : "core" ,
193
+ }
194
+ pass := true
195
+ if ! c .ExpectNoMetrics {
196
+ pass = pass && assert .Equal (t , gaugeValue (t , reg , "coderd_oauth2_external_requests_rate_limit_total" , labels ), c .Limit , "limit" )
197
+ pass = pass && assert .Equal (t , gaugeValue (t , reg , "coderd_oauth2_external_requests_rate_limit_remaining" , labels ), c .Remaining , "remaining" )
198
+ pass = pass && assert .Equal (t , gaugeValue (t , reg , "coderd_oauth2_external_requests_rate_limit_used" , labels ), c .Used , "used" )
199
+ if ! c .at .IsZero () {
200
+ until := c .Reset .Sub (c .at )
201
+ // Float accuracy is not great, so we allow a delta of 2
202
+ pass = pass && assert .InDelta (t , gaugeValue (t , reg , "coderd_oauth2_external_requests_rate_limit_reset_in_seconds" , labels ), int (until .Seconds ()), 2 , "reset in" )
203
+ }
204
+ } else {
205
+ pass = pass && assert .Nil (t , metricValue (t , reg , "coderd_oauth2_external_requests_rate_limit_total" , labels ), "not exists" )
206
+ }
207
+
208
+ // Helpful debugging
209
+ if ! pass {
210
+ t .Log (registryDump (reg ))
211
+ }
212
+ })
213
+ }
214
+ }
215
+
216
+ func registryDump (reg * prometheus.Registry ) string {
217
+ h := promhttp .HandlerFor (reg , promhttp.HandlerOpts {})
218
+ rec := httptest .NewRecorder ()
219
+ req , _ := http .NewRequest (http .MethodGet , "/" , nil )
220
+ h .ServeHTTP (rec , req )
221
+ resp := rec .Result ()
222
+ data , _ := io .ReadAll (resp .Body )
223
+ _ = resp .Body .Close ()
224
+ return string (data )
72
225
}
73
226
74
227
func must [V any ](t * testing.T ) func (v V , err error ) V {
@@ -78,3 +231,39 @@ func must[V any](t *testing.T) func(v V, err error) V {
78
231
return v
79
232
}
80
233
}
234
+
235
+ func gaugeValue (t testing.TB , reg prometheus.Gatherer , metricName string , labels prometheus.Labels ) int {
236
+ labeled := metricValue (t , reg , metricName , labels )
237
+ require .NotNilf (t , labeled , "metric %q with labels %v not found" , metricName , labels )
238
+ return int (labeled .GetGauge ().GetValue ())
239
+ }
240
+
241
+ func counterValue (t testing.TB , reg prometheus.Gatherer , metricName string , labels prometheus.Labels ) int {
242
+ labeled := metricValue (t , reg , metricName , labels )
243
+ require .NotNilf (t , labeled , "metric %q with labels %v not found" , metricName , labels )
244
+ return int (labeled .GetCounter ().GetValue ())
245
+ }
246
+
247
+ func compare (reg prometheus.Gatherer , compare string ) error {
248
+ return ptestutil .GatherAndCompare (reg , strings .NewReader (compare ))
249
+ }
250
+
251
+ func metricValue (t testing.TB , reg prometheus.Gatherer , metricName string , labels prometheus.Labels ) * io_prometheus_client.Metric {
252
+ metrics , err := reg .Gather ()
253
+ require .NoError (t , err )
254
+
255
+ for _ , m := range metrics {
256
+ if m .GetName () == metricName {
257
+ for _ , labeled := range m .GetMetric () {
258
+ mLables := make (prometheus.Labels )
259
+ for _ , v := range labeled .GetLabel () {
260
+ mLables [v .GetName ()] = v .GetValue ()
261
+ }
262
+ if maps .Equal (mLables , labels ) {
263
+ return labeled
264
+ }
265
+ }
266
+ }
267
+ }
268
+ return nil
269
+ }
0 commit comments