Skip to content

feat: add azure oidc PKI auth instead of client secret #9054

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Aug 14, 2023
Merged
Prev Previous commit
Next Next commit
Custom token refresher to handle pki auth
  • Loading branch information
Emyrk committed Aug 11, 2023
commit 7346fbfc7aba7084dde048eac241bb6039b8ba96
1 change: 1 addition & 0 deletions cli/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1531,6 +1531,7 @@ func configureOIDCPKI(orig *oauth2.Config, keyFile string, certFile string) (*oa
return oauthpki.NewOauth2PKIConfig(oauthpki.ConfigParams{
ClientID: orig.ClientID,
TokenURL: orig.Endpoint.TokenURL,
Scopes: orig.Scopes,
PemEncodedKey: keyData,
PemEncodedCert: certData,
Config: orig,
Expand Down
126 changes: 122 additions & 4 deletions coderd/oauthpki/oidcpki.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,22 @@ import (
"crypto/sha1"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"

"github.com/coder/coder/coderd/httpmw"

"github.com/golang-jwt/jwt/v4"
"github.com/google/uuid"
"golang.org/x/oauth2"
"golang.org/x/oauth2/jws"
"golang.org/x/xerrors"

"github.com/coder/coder/coderd/httpmw"
)

// Config uses jwt assertions over client_secret for oauth2 authentication of
Expand All @@ -30,6 +36,7 @@ import (
type Config struct {
cfg httpmw.OAuth2Config

scopes []string
clientID string
tokenURL string
// ClientSecret is the private key of the PKI cert.
Expand All @@ -43,6 +50,7 @@ type Config struct {
type ConfigParams struct {
ClientID string
TokenURL string
Scopes []string
PemEncodedKey []byte
PemEncodedCert []byte

Expand All @@ -64,6 +72,10 @@ func NewOauth2PKIConfig(params ConfigParams) (*Config, error) {
if params.ClientID == "" {
return nil, xerrors.Errorf("")
}
if len(params.Scopes) == 0 {
return nil, xerrors.Errorf("scopes are required")
}

rsaKey, err := decodeClientKey(params.PemEncodedKey)
if err != nil {
return nil, err
Expand All @@ -81,6 +93,7 @@ func NewOauth2PKIConfig(params ConfigParams) (*Config, error) {
return &Config{
clientID: params.ClientID,
tokenURL: params.TokenURL,
scopes: params.Scopes,
cfg: params.Config,
clientKey: rsaKey,
x5t: base64.StdEncoding.EncodeToString(hashed[:]),
Expand Down Expand Up @@ -136,6 +149,111 @@ func (ja *Config) jwtToken() (string, error) {
}

func (ja *Config) TokenSource(ctx context.Context, token *oauth2.Token) oauth2.TokenSource {
// TODO: Hijack the http.Client to insert proper client auth assertions.
return ja.cfg.TokenSource(ctx, token)
return oauth2.ReuseTokenSource(token, &jwtTokenSource{
cfg: ja,
ctx: ctx,
refreshToken: token.RefreshToken,
})
}

type jwtTokenSource struct {
cfg *Config
ctx context.Context
refreshToken string
}

// Token must be safe for concurrent use by multiple go routines
// Very similar to the RetrieveToken implementation by the oauth2 package.
// https://github.com/golang/oauth2/blob/master/internal/token.go#L212
// Oauth2 package keeps this code unexported or in an /internal package,
// so we have to copy the implementation :(
func (src *jwtTokenSource) Token() (*oauth2.Token, error) {
if src.refreshToken == "" {
return nil, xerrors.New("oauth2: token expired and refresh token is not set")
}
cli := http.DefaultClient
if v, ok := src.ctx.Value(oauth2.HTTPClient).(*http.Client); ok {
cli = v
}

token, err := src.cfg.jwtToken()
if err != nil {
return nil, xerrors.Errorf("failed jwt assertion: %w", err)
}

v := url.Values{
"client_assertion": {token},
"client_assertion_type": {"urn:ietf:params:oauth:client-assertion-type:jwt-bearer"},
"client_id": {src.cfg.clientID},
"grant_type": {"refresh_token"},
"scope": {strings.Join(src.cfg.scopes, " ")},
"refresh_token": {src.refreshToken},
}
// Using params based auth
resp, err := cli.PostForm(src.cfg.tokenURL, v)
if err != nil {
return nil, xerrors.Errorf("oauth2: cannot get token: %w", err)
}

defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, xerrors.Errorf("oauth2: cannot fetch token reading response body: %w", err)
}

var tokenRes struct {
oauth2.Token
// Extra fields returned by the refresh that are needed
IDToken string `json:"id_token"`
ExpiresIn int64 `json:"expires_in"` // relative seconds from now
// error fields
// https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
ErrorCode string `json:"error"`
ErrorDescription string `json:"error_description"`
ErrorURI string `json:"error_uri"`
}

unmarshalError := json.Unmarshal(body, &tokenRes)

if resp.StatusCode < 200 || resp.StatusCode > 299 {
// Return a standard oauth2 error. Attempt to read some error fields. The error fields
// can be encoded in a few places, so this does not catch all of them.
return nil, &oauth2.RetrieveError{
Response: resp,
Body: body,
// Best effort for error fields
ErrorCode: tokenRes.ErrorCode,
ErrorDescription: tokenRes.ErrorDescription,
ErrorURI: tokenRes.ErrorURI,
}
}

if unmarshalError != nil {
return nil, fmt.Errorf("oauth2: cannot unmarshal token: %v", err)
}

newToken := &oauth2.Token{
AccessToken: tokenRes.AccessToken,
TokenType: tokenRes.TokenType,
RefreshToken: tokenRes.RefreshToken,
}

if secs := tokenRes.ExpiresIn; secs > 0 {
newToken.Expiry = time.Now().Add(time.Duration(secs) * time.Second)
}

// ID token is a JWT token. We can decode it to get the expiry.
// Not really sure what to do if the ExpiresIn and JWT expiry differ,
// but this one is attached in the JWT and guaranteed to be right for local
// validation. So use this one if found.
if v := tokenRes.IDToken; v != "" {
// decode returned id token to get expiry
claimSet, err := jws.Decode(v)
if err != nil {
return nil, fmt.Errorf("oauth2: error decoding JWT token: %v", err)
}
newToken.Expiry = time.Unix(claimSet.Exp, 0)
}

return newToken, nil
}