@@ -6,16 +6,22 @@ import (
6
6
"crypto/sha1"
7
7
"crypto/x509"
8
8
"encoding/base64"
9
+ "encoding/json"
9
10
"encoding/pem"
11
+ "fmt"
12
+ "io"
13
+ "net/http"
14
+ "net/url"
10
15
"strings"
11
16
"time"
12
17
13
- "github.com/coder/coder/coderd/httpmw"
14
-
15
18
"github.com/golang-jwt/jwt/v4"
16
19
"github.com/google/uuid"
17
20
"golang.org/x/oauth2"
21
+ "golang.org/x/oauth2/jws"
18
22
"golang.org/x/xerrors"
23
+
24
+ "github.com/coder/coder/coderd/httpmw"
19
25
)
20
26
21
27
// Config uses jwt assertions over client_secret for oauth2 authentication of
@@ -30,6 +36,7 @@ import (
30
36
type Config struct {
31
37
cfg httpmw.OAuth2Config
32
38
39
+ scopes []string
33
40
clientID string
34
41
tokenURL string
35
42
// ClientSecret is the private key of the PKI cert.
@@ -43,6 +50,7 @@ type Config struct {
43
50
type ConfigParams struct {
44
51
ClientID string
45
52
TokenURL string
53
+ Scopes []string
46
54
PemEncodedKey []byte
47
55
PemEncodedCert []byte
48
56
@@ -64,6 +72,10 @@ func NewOauth2PKIConfig(params ConfigParams) (*Config, error) {
64
72
if params .ClientID == "" {
65
73
return nil , xerrors .Errorf ("" )
66
74
}
75
+ if len (params .Scopes ) == 0 {
76
+ return nil , xerrors .Errorf ("scopes are required" )
77
+ }
78
+
67
79
rsaKey , err := decodeClientKey (params .PemEncodedKey )
68
80
if err != nil {
69
81
return nil , err
@@ -81,6 +93,7 @@ func NewOauth2PKIConfig(params ConfigParams) (*Config, error) {
81
93
return & Config {
82
94
clientID : params .ClientID ,
83
95
tokenURL : params .TokenURL ,
96
+ scopes : params .Scopes ,
84
97
cfg : params .Config ,
85
98
clientKey : rsaKey ,
86
99
x5t : base64 .StdEncoding .EncodeToString (hashed [:]),
@@ -136,6 +149,111 @@ func (ja *Config) jwtToken() (string, error) {
136
149
}
137
150
138
151
func (ja * Config ) TokenSource (ctx context.Context , token * oauth2.Token ) oauth2.TokenSource {
139
- // TODO: Hijack the http.Client to insert proper client auth assertions.
140
- return ja .cfg .TokenSource (ctx , token )
152
+ return oauth2 .ReuseTokenSource (token , & jwtTokenSource {
153
+ cfg : ja ,
154
+ ctx : ctx ,
155
+ refreshToken : token .RefreshToken ,
156
+ })
157
+ }
158
+
159
+ type jwtTokenSource struct {
160
+ cfg * Config
161
+ ctx context.Context
162
+ refreshToken string
163
+ }
164
+
165
+ // Token must be safe for concurrent use by multiple go routines
166
+ // Very similar to the RetrieveToken implementation by the oauth2 package.
167
+ // https://github.com/golang/oauth2/blob/master/internal/token.go#L212
168
+ // Oauth2 package keeps this code unexported or in an /internal package,
169
+ // so we have to copy the implementation :(
170
+ func (src * jwtTokenSource ) Token () (* oauth2.Token , error ) {
171
+ if src .refreshToken == "" {
172
+ return nil , xerrors .New ("oauth2: token expired and refresh token is not set" )
173
+ }
174
+ cli := http .DefaultClient
175
+ if v , ok := src .ctx .Value (oauth2 .HTTPClient ).(* http.Client ); ok {
176
+ cli = v
177
+ }
178
+
179
+ token , err := src .cfg .jwtToken ()
180
+ if err != nil {
181
+ return nil , xerrors .Errorf ("failed jwt assertion: %w" , err )
182
+ }
183
+
184
+ v := url.Values {
185
+ "client_assertion" : {token },
186
+ "client_assertion_type" : {"urn:ietf:params:oauth:client-assertion-type:jwt-bearer" },
187
+ "client_id" : {src .cfg .clientID },
188
+ "grant_type" : {"refresh_token" },
189
+ "scope" : {strings .Join (src .cfg .scopes , " " )},
190
+ "refresh_token" : {src .refreshToken },
191
+ }
192
+ // Using params based auth
193
+ resp , err := cli .PostForm (src .cfg .tokenURL , v )
194
+ if err != nil {
195
+ return nil , xerrors .Errorf ("oauth2: cannot get token: %w" , err )
196
+ }
197
+
198
+ defer resp .Body .Close ()
199
+ body , err := io .ReadAll (resp .Body )
200
+ if err != nil {
201
+ return nil , xerrors .Errorf ("oauth2: cannot fetch token reading response body: %w" , err )
202
+ }
203
+
204
+ var tokenRes struct {
205
+ oauth2.Token
206
+ // Extra fields returned by the refresh that are needed
207
+ IDToken string `json:"id_token"`
208
+ ExpiresIn int64 `json:"expires_in"` // relative seconds from now
209
+ // error fields
210
+ // https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
211
+ ErrorCode string `json:"error"`
212
+ ErrorDescription string `json:"error_description"`
213
+ ErrorURI string `json:"error_uri"`
214
+ }
215
+
216
+ unmarshalError := json .Unmarshal (body , & tokenRes )
217
+
218
+ if resp .StatusCode < 200 || resp .StatusCode > 299 {
219
+ // Return a standard oauth2 error. Attempt to read some error fields. The error fields
220
+ // can be encoded in a few places, so this does not catch all of them.
221
+ return nil , & oauth2.RetrieveError {
222
+ Response : resp ,
223
+ Body : body ,
224
+ // Best effort for error fields
225
+ ErrorCode : tokenRes .ErrorCode ,
226
+ ErrorDescription : tokenRes .ErrorDescription ,
227
+ ErrorURI : tokenRes .ErrorURI ,
228
+ }
229
+ }
230
+
231
+ if unmarshalError != nil {
232
+ return nil , fmt .Errorf ("oauth2: cannot unmarshal token: %v" , err )
233
+ }
234
+
235
+ newToken := & oauth2.Token {
236
+ AccessToken : tokenRes .AccessToken ,
237
+ TokenType : tokenRes .TokenType ,
238
+ RefreshToken : tokenRes .RefreshToken ,
239
+ }
240
+
241
+ if secs := tokenRes .ExpiresIn ; secs > 0 {
242
+ newToken .Expiry = time .Now ().Add (time .Duration (secs ) * time .Second )
243
+ }
244
+
245
+ // ID token is a JWT token. We can decode it to get the expiry.
246
+ // Not really sure what to do if the ExpiresIn and JWT expiry differ,
247
+ // but this one is attached in the JWT and guaranteed to be right for local
248
+ // validation. So use this one if found.
249
+ if v := tokenRes .IDToken ; v != "" {
250
+ // decode returned id token to get expiry
251
+ claimSet , err := jws .Decode (v )
252
+ if err != nil {
253
+ return nil , fmt .Errorf ("oauth2: error decoding JWT token: %v" , err )
254
+ }
255
+ newToken .Expiry = time .Unix (claimSet .Exp , 0 )
256
+ }
257
+
258
+ return newToken , nil
141
259
}
0 commit comments