Skip to content

Commit 7346fbf

Browse files
committed
Custom token refresher to handle pki auth
1 parent ce39f2d commit 7346fbf

File tree

2 files changed

+123
-4
lines changed

2 files changed

+123
-4
lines changed

cli/server.go

+1
Original file line numberDiff line numberDiff line change
@@ -1531,6 +1531,7 @@ func configureOIDCPKI(orig *oauth2.Config, keyFile string, certFile string) (*oa
15311531
return oauthpki.NewOauth2PKIConfig(oauthpki.ConfigParams{
15321532
ClientID: orig.ClientID,
15331533
TokenURL: orig.Endpoint.TokenURL,
1534+
Scopes: orig.Scopes,
15341535
PemEncodedKey: keyData,
15351536
PemEncodedCert: certData,
15361537
Config: orig,

coderd/oauthpki/oidcpki.go

+122-4
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,22 @@ import (
66
"crypto/sha1"
77
"crypto/x509"
88
"encoding/base64"
9+
"encoding/json"
910
"encoding/pem"
11+
"fmt"
12+
"io"
13+
"net/http"
14+
"net/url"
1015
"strings"
1116
"time"
1217

13-
"github.com/coder/coder/coderd/httpmw"
14-
1518
"github.com/golang-jwt/jwt/v4"
1619
"github.com/google/uuid"
1720
"golang.org/x/oauth2"
21+
"golang.org/x/oauth2/jws"
1822
"golang.org/x/xerrors"
23+
24+
"github.com/coder/coder/coderd/httpmw"
1925
)
2026

2127
// Config uses jwt assertions over client_secret for oauth2 authentication of
@@ -30,6 +36,7 @@ import (
3036
type Config struct {
3137
cfg httpmw.OAuth2Config
3238

39+
scopes []string
3340
clientID string
3441
tokenURL string
3542
// ClientSecret is the private key of the PKI cert.
@@ -43,6 +50,7 @@ type Config struct {
4350
type ConfigParams struct {
4451
ClientID string
4552
TokenURL string
53+
Scopes []string
4654
PemEncodedKey []byte
4755
PemEncodedCert []byte
4856

@@ -64,6 +72,10 @@ func NewOauth2PKIConfig(params ConfigParams) (*Config, error) {
6472
if params.ClientID == "" {
6573
return nil, xerrors.Errorf("")
6674
}
75+
if len(params.Scopes) == 0 {
76+
return nil, xerrors.Errorf("scopes are required")
77+
}
78+
6779
rsaKey, err := decodeClientKey(params.PemEncodedKey)
6880
if err != nil {
6981
return nil, err
@@ -81,6 +93,7 @@ func NewOauth2PKIConfig(params ConfigParams) (*Config, error) {
8193
return &Config{
8294
clientID: params.ClientID,
8395
tokenURL: params.TokenURL,
96+
scopes: params.Scopes,
8497
cfg: params.Config,
8598
clientKey: rsaKey,
8699
x5t: base64.StdEncoding.EncodeToString(hashed[:]),
@@ -136,6 +149,111 @@ func (ja *Config) jwtToken() (string, error) {
136149
}
137150

138151
func (ja *Config) TokenSource(ctx context.Context, token *oauth2.Token) oauth2.TokenSource {
139-
// TODO: Hijack the http.Client to insert proper client auth assertions.
140-
return ja.cfg.TokenSource(ctx, token)
152+
return oauth2.ReuseTokenSource(token, &jwtTokenSource{
153+
cfg: ja,
154+
ctx: ctx,
155+
refreshToken: token.RefreshToken,
156+
})
157+
}
158+
159+
type jwtTokenSource struct {
160+
cfg *Config
161+
ctx context.Context
162+
refreshToken string
163+
}
164+
165+
// Token must be safe for concurrent use by multiple go routines
166+
// Very similar to the RetrieveToken implementation by the oauth2 package.
167+
// https://github.com/golang/oauth2/blob/master/internal/token.go#L212
168+
// Oauth2 package keeps this code unexported or in an /internal package,
169+
// so we have to copy the implementation :(
170+
func (src *jwtTokenSource) Token() (*oauth2.Token, error) {
171+
if src.refreshToken == "" {
172+
return nil, xerrors.New("oauth2: token expired and refresh token is not set")
173+
}
174+
cli := http.DefaultClient
175+
if v, ok := src.ctx.Value(oauth2.HTTPClient).(*http.Client); ok {
176+
cli = v
177+
}
178+
179+
token, err := src.cfg.jwtToken()
180+
if err != nil {
181+
return nil, xerrors.Errorf("failed jwt assertion: %w", err)
182+
}
183+
184+
v := url.Values{
185+
"client_assertion": {token},
186+
"client_assertion_type": {"urn:ietf:params:oauth:client-assertion-type:jwt-bearer"},
187+
"client_id": {src.cfg.clientID},
188+
"grant_type": {"refresh_token"},
189+
"scope": {strings.Join(src.cfg.scopes, " ")},
190+
"refresh_token": {src.refreshToken},
191+
}
192+
// Using params based auth
193+
resp, err := cli.PostForm(src.cfg.tokenURL, v)
194+
if err != nil {
195+
return nil, xerrors.Errorf("oauth2: cannot get token: %w", err)
196+
}
197+
198+
defer resp.Body.Close()
199+
body, err := io.ReadAll(resp.Body)
200+
if err != nil {
201+
return nil, xerrors.Errorf("oauth2: cannot fetch token reading response body: %w", err)
202+
}
203+
204+
var tokenRes struct {
205+
oauth2.Token
206+
// Extra fields returned by the refresh that are needed
207+
IDToken string `json:"id_token"`
208+
ExpiresIn int64 `json:"expires_in"` // relative seconds from now
209+
// error fields
210+
// https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
211+
ErrorCode string `json:"error"`
212+
ErrorDescription string `json:"error_description"`
213+
ErrorURI string `json:"error_uri"`
214+
}
215+
216+
unmarshalError := json.Unmarshal(body, &tokenRes)
217+
218+
if resp.StatusCode < 200 || resp.StatusCode > 299 {
219+
// Return a standard oauth2 error. Attempt to read some error fields. The error fields
220+
// can be encoded in a few places, so this does not catch all of them.
221+
return nil, &oauth2.RetrieveError{
222+
Response: resp,
223+
Body: body,
224+
// Best effort for error fields
225+
ErrorCode: tokenRes.ErrorCode,
226+
ErrorDescription: tokenRes.ErrorDescription,
227+
ErrorURI: tokenRes.ErrorURI,
228+
}
229+
}
230+
231+
if unmarshalError != nil {
232+
return nil, fmt.Errorf("oauth2: cannot unmarshal token: %v", err)
233+
}
234+
235+
newToken := &oauth2.Token{
236+
AccessToken: tokenRes.AccessToken,
237+
TokenType: tokenRes.TokenType,
238+
RefreshToken: tokenRes.RefreshToken,
239+
}
240+
241+
if secs := tokenRes.ExpiresIn; secs > 0 {
242+
newToken.Expiry = time.Now().Add(time.Duration(secs) * time.Second)
243+
}
244+
245+
// ID token is a JWT token. We can decode it to get the expiry.
246+
// Not really sure what to do if the ExpiresIn and JWT expiry differ,
247+
// but this one is attached in the JWT and guaranteed to be right for local
248+
// validation. So use this one if found.
249+
if v := tokenRes.IDToken; v != "" {
250+
// decode returned id token to get expiry
251+
claimSet, err := jws.Decode(v)
252+
if err != nil {
253+
return nil, fmt.Errorf("oauth2: error decoding JWT token: %v", err)
254+
}
255+
newToken.Expiry = time.Unix(claimSet.Exp, 0)
256+
}
257+
258+
return newToken, nil
141259
}

0 commit comments

Comments
 (0)