Skip to content

feat: add azure oidc PKI auth instead of client secret #9054

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Aug 14, 2023
Merged
64 changes: 55 additions & 9 deletions cli/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ import (
"github.com/coder/coder/coderd/gitsshkey"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/coderd/oauthpki"
"github.com/coder/coder/coderd/prometheusmetrics"
"github.com/coder/coder/coderd/schedule"
"github.com/coder/coder/coderd/telemetry"
Expand Down Expand Up @@ -551,7 +552,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
}
}

if cfg.OIDC.ClientSecret != "" {
if cfg.OIDC.ClientKeyFile != "" || cfg.OIDC.ClientSecret != "" {
if cfg.OIDC.ClientID == "" {
return xerrors.Errorf("OIDC client ID be set!")
}
Expand All @@ -578,15 +579,33 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
if slice.Contains(cfg.OIDC.Scopes, "groups") && cfg.OIDC.GroupField == "" {
cfg.OIDC.GroupField = "groups"
}
oauthCfg := &oauth2.Config{
ClientID: cfg.OIDC.ClientID.String(),
ClientSecret: cfg.OIDC.ClientSecret.String(),
RedirectURL: redirectURL.String(),
Endpoint: oidcProvider.Endpoint(),
Scopes: cfg.OIDC.Scopes,
}

var useCfg httpmw.OAuth2Config = oauthCfg
if cfg.OIDC.ClientKeyFile != "" {
// PKI authentication is done in the params. If a
// counter example is found, we can add a config option to
// change this.
oauthCfg.Endpoint.AuthStyle = oauth2.AuthStyleInParams
if cfg.OIDC.ClientSecret != "" {
return xerrors.Errorf("cannot specify both oidc client secret and oidc client key file")
}

pkiCfg, err := configureOIDCPKI(oauthCfg, cfg.OIDC.ClientKeyFile.Value(), cfg.OIDC.ClientCertFile.Value())
if err != nil {
return xerrors.Errorf("configure oauth pki authentication: %w", err)
}
useCfg = pkiCfg
}
options.OIDCConfig = &coderd.OIDCConfig{
OAuth2Config: &oauth2.Config{
ClientID: cfg.OIDC.ClientID.String(),
ClientSecret: cfg.OIDC.ClientSecret.String(),
RedirectURL: redirectURL.String(),
Endpoint: oidcProvider.Endpoint(),
Scopes: cfg.OIDC.Scopes,
},
Provider: oidcProvider,
OAuth2Config: useCfg,
Provider: oidcProvider,
Verifier: oidcProvider.Verifier(&oidc.Config{
ClientID: cfg.OIDC.ClientID.String(),
}),
Expand Down Expand Up @@ -1494,6 +1513,33 @@ func configureTLS(tlsMinVersion, tlsClientAuth string, tlsCertFiles, tlsKeyFiles
return tlsConfig, nil
}

func configureOIDCPKI(orig *oauth2.Config, keyFile string, certFile string) (*oauthpki.Config, error) {
// Read the files
keyData, err := os.ReadFile(keyFile)
if err != nil {
return nil, xerrors.Errorf("read oidc client key file: %w", err)
}

var certData []byte
// According to the spec, this is not required. So do not require it on the initial loading
// of the PKI config.
if certFile != "" {
certData, err = os.ReadFile(certFile)
if err != nil {
return nil, xerrors.Errorf("read oidc client cert file: %w", err)
}
}

return oauthpki.NewOauth2PKIConfig(oauthpki.ConfigParams{
ClientID: orig.ClientID,
TokenURL: orig.Endpoint.TokenURL,
Scopes: orig.Scopes,
PemEncodedKey: keyData,
PemEncodedCert: certData,
Config: orig,
})
}

func configureCAPool(tlsClientCAFile string, tlsConfig *tls.Config) error {
if tlsClientCAFile != "" {
caPool := x509.NewCertPool()
Expand Down
265 changes: 265 additions & 0 deletions coderd/oauthpki/oidcpki.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,265 @@
package oauthpki

import (
"context"
"crypto/rsa"
"crypto/sha1"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"

"github.com/golang-jwt/jwt/v4"
"github.com/google/uuid"
"golang.org/x/oauth2"
"golang.org/x/oauth2/jws"
"golang.org/x/xerrors"

"github.com/coder/coder/coderd/httpmw"
)

// Config uses jwt assertions over client_secret for oauth2 authentication of
// the application. This implementation was made specifically for Azure AD.
//
// https://learn.microsoft.com/en-us/azure/active-directory/develop/certificate-credentials
//
// However this does mostly follow the standard. We can generalize this as we
// include support for more IDPs.
//
// https://datatracker.ietf.org/doc/html/rfc7523
type Config struct {
cfg httpmw.OAuth2Config

// These values should match those provided in the oauth2.Config.
// Because the inner config is an interface, we need to duplicate these
// values here.
scopes []string
clientID string
tokenURL string

// ClientSecret is the private key of the PKI cert.
// Azure AD only supports RS256 signing algorithm.
clientKey *rsa.PrivateKey
// Base64url-encoded SHA-1 thumbprint of the X.509 certificate's DER encoding.
// This is specific to Azure AD
x5t string
}

type ConfigParams struct {
ClientID string
TokenURL string
Scopes []string
PemEncodedKey []byte
PemEncodedCert []byte

Config httpmw.OAuth2Config
}

// NewOauth2PKIConfig creates the oauth2 config for PKI based auth. It requires the certificate and it's private key.
// The values should be passed in as PEM encoded values, which is the standard encoding for x509 certs saved to disk.
// It should look like:
//
// -----BEGIN RSA PRIVATE KEY----
// ...
// -----END RSA PRIVATE KEY-----
//
// -----BEGIN CERTIFICATE-----
// ...
// -----END CERTIFICATE-----
func NewOauth2PKIConfig(params ConfigParams) (*Config, error) {
if params.ClientID == "" {
return nil, xerrors.Errorf("")
}
if len(params.Scopes) == 0 {
return nil, xerrors.Errorf("scopes are required")
}

rsaKey, err := decodeClientKey(params.PemEncodedKey)
if err != nil {
return nil, err
}

// Azure AD requires a certificate. The sha1 of the cert is used to identify the signer.
// This is not required in the general specification.
if strings.Contains(strings.ToLower(params.TokenURL), "microsoftonline") && len(params.PemEncodedCert) == 0 {
return nil, xerrors.Errorf("oidc client certificate is required and missing")
}

block, _ := pem.Decode(params.PemEncodedCert)
hashed := sha1.Sum(block.Bytes)

return &Config{
clientID: params.ClientID,
tokenURL: params.TokenURL,
scopes: params.Scopes,
cfg: params.Config,
clientKey: rsaKey,
x5t: base64.StdEncoding.EncodeToString(hashed[:]),
}, nil
}

// decodeClientKey decodes a PEM encoded rsa secret.
func decodeClientKey(pemEncoded []byte) (*rsa.PrivateKey, error) {
block, _ := pem.Decode(pemEncoded)
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return nil, xerrors.Errorf("failed to parse private key: %w", err)
}

return key, nil
}

func (ja *Config) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string {
return ja.cfg.AuthCodeURL(state, opts...)
}

// Exchange includes the client_assertion signed JWT.
func (ja *Config) Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
signed, err := ja.jwtToken()
if err != nil {
return nil, xerrors.Errorf("failed jwt assertion: %w", err)
}
opts = append(opts,
oauth2.SetAuthURLParam("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"),
oauth2.SetAuthURLParam("client_assertion", signed),
)
return ja.cfg.Exchange(ctx, code, opts...)
}

func (ja *Config) jwtToken() (string, error) {
now := time.Now()
token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
"iss": ja.clientID,
"sub": ja.clientID,
"aud": ja.tokenURL,
// 5-10 minutes is recommended in the Azure docs.
// So we'll use 5 minutes.
"exp": now.Add(time.Minute * 5).Unix(),
"jti": uuid.New().String(),
"nbf": now.Unix(),
"iat": now.Unix(),
})
token.Header["x5t"] = ja.x5t

signed, err := token.SignedString(ja.clientKey)
if err != nil {
return "", xerrors.Errorf("sign jwt assertion: %w", err)
}
return signed, nil
}

func (ja *Config) TokenSource(ctx context.Context, token *oauth2.Token) oauth2.TokenSource {
return oauth2.ReuseTokenSource(token, &jwtTokenSource{
cfg: ja,
ctx: ctx,
refreshToken: token.RefreshToken,
})
}

type jwtTokenSource struct {
cfg *Config
ctx context.Context
refreshToken string
}

// Token must be safe for concurrent use by multiple go routines
// Very similar to the RetrieveToken implementation by the oauth2 package.
// https://github.com/golang/oauth2/blob/master/internal/token.go#L212
// Oauth2 package keeps this code unexported or in an /internal package,
// so we have to copy the implementation :(
func (src *jwtTokenSource) Token() (*oauth2.Token, error) {
if src.refreshToken == "" {
return nil, xerrors.New("oauth2: token expired and refresh token is not set")
}
cli := http.DefaultClient
if v, ok := src.ctx.Value(oauth2.HTTPClient).(*http.Client); ok {
cli = v
}

token, err := src.cfg.jwtToken()
if err != nil {
return nil, xerrors.Errorf("failed jwt assertion: %w", err)
}

v := url.Values{
"client_assertion": {token},
"client_assertion_type": {"urn:ietf:params:oauth:client-assertion-type:jwt-bearer"},
"client_id": {src.cfg.clientID},
"grant_type": {"refresh_token"},
"scope": {strings.Join(src.cfg.scopes, " ")},
"refresh_token": {src.refreshToken},
}
// Using params based auth
resp, err := cli.PostForm(src.cfg.tokenURL, v)
if err != nil {
return nil, xerrors.Errorf("oauth2: cannot get token: %w", err)
}

defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, xerrors.Errorf("oauth2: cannot fetch token reading response body: %w", err)
}

var tokenRes struct {
oauth2.Token
// Extra fields returned by the refresh that are needed
IDToken string `json:"id_token"`
ExpiresIn int64 `json:"expires_in"` // relative seconds from now
// error fields
// https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
ErrorCode string `json:"error"`
ErrorDescription string `json:"error_description"`
ErrorURI string `json:"error_uri"`
}

unmarshalError := json.Unmarshal(body, &tokenRes)

if resp.StatusCode < 200 || resp.StatusCode > 299 {
// Return a standard oauth2 error. Attempt to read some error fields. The error fields
// can be encoded in a few places, so this does not catch all of them.
return nil, &oauth2.RetrieveError{
Response: resp,
Body: body,
// Best effort for error fields
ErrorCode: tokenRes.ErrorCode,
ErrorDescription: tokenRes.ErrorDescription,
ErrorURI: tokenRes.ErrorURI,
}
}

if unmarshalError != nil {
return nil, fmt.Errorf("oauth2: cannot unmarshal token: %v", err)
}

newToken := &oauth2.Token{
AccessToken: tokenRes.AccessToken,
TokenType: tokenRes.TokenType,
RefreshToken: tokenRes.RefreshToken,
}

if secs := tokenRes.ExpiresIn; secs > 0 {
newToken.Expiry = time.Now().Add(time.Duration(secs) * time.Second)
}

// ID token is a JWT token. We can decode it to get the expiry.
// Not really sure what to do if the ExpiresIn and JWT expiry differ,
// but this one is attached in the JWT and guaranteed to be right for local
// validation. So use this one if found.
if v := tokenRes.IDToken; v != "" {
// decode returned id token to get expiry
claimSet, err := jws.Decode(v)
if err != nil {
return nil, fmt.Errorf("oauth2: error decoding JWT token: %v", err)
}
newToken.Expiry = time.Unix(claimSet.Exp, 0)
}

return newToken, nil
}
Loading