@@ -10,6 +10,16 @@ import (
10
10
"golang.org/x/oauth2"
11
11
)
12
12
13
+ type Oauth2Source string
14
+
15
+ const (
16
+ SourceValidateToken Oauth2Source = "ValidateToken"
17
+ SourceExchange Oauth2Source = "Exchange"
18
+ SourceTokenSource Oauth2Source = "TokenSource"
19
+ SourceAppInstallations Oauth2Source = "AppInstallations"
20
+ SourceAuthorizeDevice Oauth2Source = "AuthorizeDevice"
21
+ )
22
+
13
23
// OAuth2Config exposes a subset of *oauth2.Config functions for easier testing.
14
24
// *oauth2.Config should be used instead of implementing this in production.
15
25
type OAuth2Config interface {
@@ -27,7 +37,7 @@ type InstrumentedOAuth2Config interface {
27
37
28
38
// Do is provided as a convenience method to make a request with the oauth2 client.
29
39
// It mirrors `http.Client.Do`.
30
- Do (ctx context.Context , source string , req * http.Request ) (* http.Response , error )
40
+ Do (ctx context.Context , source Oauth2Source , req * http.Request ) (* http.Response , error )
31
41
}
32
42
33
43
var _ OAuth2Config = (* Config )(nil )
@@ -79,7 +89,7 @@ type Config struct {
79
89
metrics * metrics
80
90
}
81
91
82
- func (c * Config ) Do (ctx context.Context , source string , req * http.Request ) (* http.Response , error ) {
92
+ func (c * Config ) Do (ctx context.Context , source Oauth2Source , req * http.Request ) (* http.Response , error ) {
83
93
cli := c .oauthHTTPClient (ctx , source )
84
94
return cli .Do (req )
85
95
}
@@ -90,11 +100,11 @@ func (c *Config) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string
90
100
}
91
101
92
102
func (c * Config ) Exchange (ctx context.Context , code string , opts ... oauth2.AuthCodeOption ) (* oauth2.Token , error ) {
93
- return c .underlying .Exchange (c .wrapClient (ctx , "Exchange" ), code , opts ... )
103
+ return c .underlying .Exchange (c .wrapClient (ctx , SourceExchange ), code , opts ... )
94
104
}
95
105
96
106
func (c * Config ) TokenSource (ctx context.Context , token * oauth2.Token ) oauth2.TokenSource {
97
- return c .underlying .TokenSource (c .wrapClient (ctx , "TokenSource" ), token )
107
+ return c .underlying .TokenSource (c .wrapClient (ctx , SourceTokenSource ), token )
98
108
}
99
109
100
110
// wrapClient is the only way we can accurately instrument the oauth2 client.
@@ -104,12 +114,12 @@ func (c *Config) TokenSource(ctx context.Context, token *oauth2.Token) oauth2.To
104
114
// For example, the 'TokenSource' method will return a token
105
115
// source that will make a network request when the 'Token' method is called on
106
116
// it if the token is expired.
107
- func (c * Config ) wrapClient (ctx context.Context , source string ) context.Context {
117
+ func (c * Config ) wrapClient (ctx context.Context , source Oauth2Source ) context.Context {
108
118
return context .WithValue (ctx , oauth2 .HTTPClient , c .oauthHTTPClient (ctx , source ))
109
119
}
110
120
111
121
// oauthHTTPClient returns an http client that will instrument every request made.
112
- func (c * Config ) oauthHTTPClient (ctx context.Context , source string ) * http.Client {
122
+ func (c * Config ) oauthHTTPClient (ctx context.Context , source Oauth2Source ) * http.Client {
113
123
cli := & http.Client {}
114
124
115
125
// Check if the context has a http client already.
@@ -124,13 +134,13 @@ func (c *Config) oauthHTTPClient(ctx context.Context, source string) *http.Clien
124
134
125
135
type instrumentedTripper struct {
126
136
c * Config
127
- source string
137
+ source Oauth2Source
128
138
underlying http.RoundTripper
129
139
}
130
140
131
141
// newInstrumentedTripper intercepts a http request, and increments the
132
142
// externalRequestCount metric.
133
- func newInstrumentedTripper (c * Config , source string , under http.RoundTripper ) * instrumentedTripper {
143
+ func newInstrumentedTripper (c * Config , source Oauth2Source , under http.RoundTripper ) * instrumentedTripper {
134
144
if under == nil {
135
145
under = http .DefaultTransport
136
146
}
@@ -156,7 +166,7 @@ func (i *instrumentedTripper) RoundTrip(r *http.Request) (*http.Response, error)
156
166
}
157
167
i .c .metrics .externalRequestCount .With (prometheus.Labels {
158
168
"name" : i .c .name ,
159
- "source" : i .source ,
169
+ "source" : string ( i .source ) ,
160
170
"status_code" : fmt .Sprintf ("%d" , statusCode ),
161
171
}).Inc ()
162
172
return resp , err
0 commit comments