Skip to content

Commit 664895f

Browse files
committed
add unit test
1 parent c62e548 commit 664895f

File tree

3 files changed

+195
-17
lines changed

3 files changed

+195
-17
lines changed

cli/server.go

+8-3
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,6 @@ import (
3333
"sync/atomic"
3434
"time"
3535

36-
"github.com/coder/coder/coderd/oauthpki"
37-
3836
"github.com/coreos/go-oidc/v3/oidc"
3937
"github.com/coreos/go-systemd/daemon"
4038
embeddedpostgres "github.com/fergusstrange/embedded-postgres"
@@ -78,6 +76,7 @@ import (
7876
"github.com/coder/coder/coderd/gitsshkey"
7977
"github.com/coder/coder/coderd/httpapi"
8078
"github.com/coder/coder/coderd/httpmw"
79+
"github.com/coder/coder/coderd/oauthpki"
8180
"github.com/coder/coder/coderd/prometheusmetrics"
8281
"github.com/coder/coder/coderd/schedule"
8382
"github.com/coder/coder/coderd/telemetry"
@@ -1525,7 +1524,13 @@ func configureOIDCPKI(orig *oauth2.Config, keyFile string, certFile string) (*oa
15251524
}
15261525
}
15271526

1528-
return oauthpki.NewOauth2PKIConfig(orig, keyData, certData)
1527+
return oauthpki.NewOauth2PKIConfig(oauthpki.ConfigParams{
1528+
ClientID: orig.ClientID,
1529+
TokenURL: orig.Endpoint.TokenURL,
1530+
PemEncodedKey: keyData,
1531+
PemEncodedCert: certData,
1532+
Config: orig,
1533+
})
15291534
}
15301535

15311536
func configureCAPool(tlsClientCAFile string, tlsConfig *tls.Config) error {

coderd/oauthpki/oidcpki.go

+32-14
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ import (
1010
"strings"
1111
"time"
1212

13+
"github.com/coder/coder/coderd/httpmw"
14+
1315
"github.com/golang-jwt/jwt/v4"
1416
"github.com/google/uuid"
1517
"golang.org/x/oauth2"
@@ -26,8 +28,10 @@ import (
2628
//
2729
// https://datatracker.ietf.org/doc/html/rfc7523
2830
type Config struct {
29-
*oauth2.Config
31+
cfg httpmw.OAuth2Config
3032

33+
clientID string
34+
tokenURL string
3135
// ClientSecret is the private key of the PKI cert.
3236
// Azure AD only supports RS256 signing algorithm.
3337
clientKey *rsa.PrivateKey
@@ -36,6 +40,15 @@ type Config struct {
3640
x5t string
3741
}
3842

43+
type ConfigParams struct {
44+
ClientID string
45+
TokenURL string
46+
PemEncodedKey []byte
47+
PemEncodedCert []byte
48+
49+
Config httpmw.OAuth2Config
50+
}
51+
3952
// NewOauth2PKIConfig creates the oauth2 config for PKI based auth. It requires the certificate and it's private key.
4053
// The values should be passed in as PEM encoded values, which is the standard encoding for x509 certs saved to disk.
4154
// It should look like:
@@ -47,30 +60,35 @@ type Config struct {
4760
// -----BEGIN CERTIFICATE-----
4861
// ...
4962
// -----END CERTIFICATE-----
50-
func NewOauth2PKIConfig(config *oauth2.Config, pemEncodedKey []byte, pemEncodedCert []byte) (*Config, error) {
51-
rsaKey, err := decodeKeyCertificate(pemEncodedKey)
63+
func NewOauth2PKIConfig(params ConfigParams) (*Config, error) {
64+
if params.ClientID == "" {
65+
return nil, xerrors.Errorf("")
66+
}
67+
rsaKey, err := decodeClientKey(params.PemEncodedKey)
5268
if err != nil {
5369
return nil, err
5470
}
5571

5672
// Azure AD requires a certificate. The sha1 of the cert is used to identify the signer.
5773
// This is not required in the general specification.
58-
if strings.Contains(strings.ToLower(config.Endpoint.TokenURL), "microsoftonline") && len(pemEncodedCert) == 0 {
74+
if strings.Contains(strings.ToLower(params.TokenURL), "microsoftonline") && len(params.PemEncodedCert) == 0 {
5975
return nil, xerrors.Errorf("oidc client certificate is required and missing")
6076
}
6177

62-
block, _ := pem.Decode(pemEncodedCert)
78+
block, _ := pem.Decode(params.PemEncodedCert)
6379
hashed := sha1.Sum(block.Bytes)
6480

6581
return &Config{
66-
Config: config,
82+
clientID: params.ClientID,
83+
tokenURL: params.TokenURL,
84+
cfg: params.Config,
6785
clientKey: rsaKey,
6886
x5t: base64.StdEncoding.EncodeToString(hashed[:]),
6987
}, nil
7088
}
7189

72-
// decodeKeyCertificate decodes a PEM encoded PKI cert.
73-
func decodeKeyCertificate(pemEncoded []byte) (*rsa.PrivateKey, error) {
90+
// decodeClientKey decodes a PEM encoded rsa secret.
91+
func decodeClientKey(pemEncoded []byte) (*rsa.PrivateKey, error) {
7492
block, _ := pem.Decode(pemEncoded)
7593
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
7694
if err != nil {
@@ -81,16 +99,16 @@ func decodeKeyCertificate(pemEncoded []byte) (*rsa.PrivateKey, error) {
8199
}
82100

83101
func (ja *Config) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string {
84-
return ja.Config.AuthCodeURL(state, opts...)
102+
return ja.cfg.AuthCodeURL(state, opts...)
85103
}
86104

87105
// Exchange includes the client_assertion signed JWT.
88106
func (ja *Config) Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
89107
now := time.Now()
90108
token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
91-
"iss": ja.Config.ClientID,
92-
"sub": ja.Config.ClientID,
93-
"aud": ja.Config.Endpoint.TokenURL,
109+
"iss": ja.clientID,
110+
"sub": ja.clientID,
111+
"aud": ja.tokenURL,
94112
"exp": now.Add(time.Minute * 5).Unix(),
95113
"jti": uuid.New().String(),
96114
"nbf": now.Unix(),
@@ -107,9 +125,9 @@ func (ja *Config) Exchange(ctx context.Context, code string, opts ...oauth2.Auth
107125
oauth2.SetAuthURLParam("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"),
108126
oauth2.SetAuthURLParam("client_assertion", signed),
109127
)
110-
return ja.Config.Exchange(ctx, code, opts...)
128+
return ja.cfg.Exchange(ctx, code, opts...)
111129
}
112130

113131
func (ja *Config) TokenSource(ctx context.Context, token *oauth2.Token) oauth2.TokenSource {
114-
return ja.Config.TokenSource(ctx, token)
132+
return ja.cfg.TokenSource(ctx, token)
115133
}

coderd/oauthpki/okidcpki_test.go

+155
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
package oauthpki_test
2+
3+
import (
4+
"context"
5+
"encoding/base64"
6+
"io"
7+
"net/http"
8+
"net/url"
9+
"testing"
10+
11+
"github.com/coreos/go-oidc/v3/oidc"
12+
"github.com/golang-jwt/jwt"
13+
"github.com/stretchr/testify/assert"
14+
"github.com/stretchr/testify/require"
15+
"golang.org/x/oauth2"
16+
17+
"github.com/coder/coder/coderd/httpmw"
18+
"github.com/coder/coder/coderd/oauthpki"
19+
"github.com/coder/coder/testutil"
20+
)
21+
22+
const (
23+
testClientKey = `-----BEGIN RSA PRIVATE KEY-----
24+
MIIEpAIBAAKCAQEAnUryZEfn5kA8wuk9a7ogFuWbk3uPHEhioYuAg9m3/tIdqSqu
25+
ASpRzw8+1nORTf3ykWRRlhxZWnKimmkB0Ux5Yrz9TDVWDQbzEH3B8ibMlmaNcoN8
26+
wYVzeEpqCe3fJagnV0lh0sHB1Z+vhcJ/M2nEAdyfhIgQEbG6Xtl2+WcGqyMWUJpV
27+
g8+ebK+JkXELAGN1hg3DdV52gjodEjoe1/ibHz8y3NR7j2tOKix7iKOhccyFkD35
28+
xqSnfyZJK5yxIfmGiWdVOIGqc2rYpgvrXJLTOjLoeyDSNi+Q604T64ZxsqfuM4LX
29+
BakVG3EwHFXPBfsBKjUE9HYvXEXw3fJP9K6mIwIDAQABAoIBAQCb+aH7x0IylSir
30+
r1Z06RDBI9bunOwBA9aqkwdRuCg4zGsVQXljNnABgACz7837JQPRIUW2MU553otX
31+
yyE+RzNnsjkLxSgbqvSFOe+FDOx7iB5jm/euf4NNmZ0lU3iggurgJ6iVsgVgrQUF
32+
AyXX+d2gawLUDYjBwxgozkSodH2sXYSX+SWfSOXHsFzSa3tLtUMbAIflM0rlRXf7
33+
Z57M8mMomZUvmmojH+TnBQljJlU8lhrvOaDD4DT8qAtVHE3VluDBQ9/3E8OIjz+E
34+
EqUgWLgrdq1rIMhJbHN90NwLwWs+2PcRfdB6hqKPktLne2KZFOgVKlxPKOYByBq1
35+
PX/vJ/HBAoGBAMFmJ6nYqyUVl26ajlXmnXBjQ+iBLHo9lcUu84+rpqRf90Bsm5bd
36+
jMmYr3Yo3yXNiit3rvZzBfPElo+IVa1HpPtgOaa2AU5B3QzxWCNT0FNRQqMG2LcA
37+
CvB10pOdJEABQxr7d4eFRg2/KbF1fr0r0vqMEelwa5ejTg6ROD3DtadpAoGBANA0
38+
4EClniCwvd1IECy2oTuTDosXgmRKwRAcwgE34YXy1Y/L4X/ghFeCHi3ybrep0uwL
39+
ptJNK+0sqvPu6UhC356GfMqfuzOKNMkXybnPUbHrz5KTkN+QQMfPc73Veel2gpD3
40+
xNataEmHtxcOx0X1OnjwyZZpmMbrUY3Cackn+durAoGBAKYR5nU+jJfnloVvSlIR
41+
GZhsZN++LEc7ouQTkSoJp6r2jQZRPLmrvT1PUzwPlK6NdNwmhaMy2iWc5fySgZ+u
42+
KcmBs3+oQi7E9+ApThnn2rfwy1vagTWDX+FkC1KeWYZsjwcYcGd61dDwGgk8b3xZ
43+
qW1j4e2mj31CycBQiw7eg5ohAoGADvkOe3etlHpBXS12hFCp7afYruYE6YN6uNbo
44+
mL/VBxX8h7fIwrJ5sfVYiENb9PdQhMsdtxf3pbnFnX875Ydxn2vag5PTGZTB0QhV
45+
6HfhTyM/LTJRg9JS5kuj7i3w83ojT5uR20JjMo6A+zaD3CMTjmj6hkeXxg5cMg6e
46+
HuoyDLsCgYBcbboYMFT1cUSxBeMtPGt3CxxZUYnUQaRUeOcjqYYlFL+DCWhY7pxH
47+
EnLhwW/KzkDzOmwRmmNOMqD7UhR/ayxR+avRt6v5d5l8fVCuNexgs7kR9L5IQp9l
48+
YV2wsCoXBCcuPmio/te44U//BlzprEu0w1iHpb3ibmQg4y291R0TvQ==
49+
-----END RSA PRIVATE KEY-----`
50+
51+
testClientCert = `
52+
-----BEGIN CERTIFICATE-----
53+
MIIEOjCCAiKgAwIBAgIQMO50KnWsRbmrrthPQgyubjANBgkqhkiG9w0BAQsFADAY
54+
MRYwFAYDVQQDEw1Mb2NhbGhvc3RDZXJ0MB4XDTIzMDgxMDE2MjYxOFoXDTI1MDIx
55+
MDE2MjU0M1owFDESMBAGA1UEAxMJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF
56+
AAOCAQ8AMIIBCgKCAQEAnUryZEfn5kA8wuk9a7ogFuWbk3uPHEhioYuAg9m3/tId
57+
qSquASpRzw8+1nORTf3ykWRRlhxZWnKimmkB0Ux5Yrz9TDVWDQbzEH3B8ibMlmaN
58+
coN8wYVzeEpqCe3fJagnV0lh0sHB1Z+vhcJ/M2nEAdyfhIgQEbG6Xtl2+WcGqyMW
59+
UJpVg8+ebK+JkXELAGN1hg3DdV52gjodEjoe1/ibHz8y3NR7j2tOKix7iKOhccyF
60+
kD35xqSnfyZJK5yxIfmGiWdVOIGqc2rYpgvrXJLTOjLoeyDSNi+Q604T64Zxsqfu
61+
M4LXBakVG3EwHFXPBfsBKjUE9HYvXEXw3fJP9K6mIwIDAQABo4GDMIGAMA4GA1Ud
62+
DwEB/wQEAwIDuDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwHQYDVR0O
63+
BBYEFAYCdgydG3h2SNWF+BfAyJtNliJtMB8GA1UdIwQYMBaAFHR/aptP0RUNNFyf
64+
5uky527SECt1MA8GA1UdEQQIMAaHBH8AAAEwDQYJKoZIhvcNAQELBQADggIBAI6P
65+
ymG7l06JvJ3p6xgaMyOxgkpQl6WkY4LJHVEhfeDSoO3qsJc4PxUdSExJsT84weXb
66+
lF+tK6D/CPlvjmG720IlB5cSKJ71rWjwmaMWKxWKXyoZdDrHAp55+FNdXegUZF2o
67+
EF/ZM5CHaO8iHMkuWEv1OASHBQWC/o4spUN5HGQ9HepwLVvO/aX++LYfvfL9faKA
68+
IT+w9i8pJbfItFmfA8x2OEVZk8aEA0WtKdfsMwzGmZ1GSGa4UYcynxQGCMiB5h4L
69+
C/dpoJRbEzdGLuTZgV2SCaN3k5BrH4aaILI9tqZaq0gamN9Rd2yji3cGiduCeAAo
70+
RmVcl9fBliMLxylWEP5+B2JmCZEc8Lfm0TBNnjaG17KY40gzbfBYixBxBTYgsPua
71+
bfprtfksSG++zcsDbkC8CtPamtlNWtDAiFp4yQRkP79PlJO6qCdTrFWPukTMCMso
72+
25hjLvxj1fLy/jSMDEZu/oQ14TMCZSGHRjz4CPiaCfXqgqOtVOD+5+yWInwUGp/i
73+
Nb1vIq4ruEAbyCbdWKHbE0yT5AP7hm5ZNybpZ4/311AEBD2HKip/OqB05p99XcLw
74+
BIC4ODNvwCn6x00KZoqWz/MX2dEQ/HqWiWaDB/OSemfTVE3I94mzEWnqpF2cQpcT
75+
B1B7CpkMU55hPP+7nsofCszNrMDXT8Z5w2a3zLKM
76+
-----END CERTIFICATE-----
77+
`
78+
)
79+
80+
type exchangeAssert struct {
81+
httpmw.OAuth2Config
82+
assert func(ctx context.Context, code string, opts ...oauth2.AuthCodeOption)
83+
}
84+
85+
func (a *exchangeAssert) Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
86+
a.assert(ctx, code, opts...)
87+
return a.OAuth2Config.Exchange(ctx, code, opts...)
88+
}
89+
90+
type fakeRoundTripper struct {
91+
roundTrip func(req *http.Request) (*http.Response, error)
92+
}
93+
94+
func (f fakeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
95+
return f.roundTrip(req)
96+
}
97+
98+
// TestAzureADPKIOIDC ensures we do not break Azure AD compatibility.
99+
func TestAzureADPKIOIDC(t *testing.T) {
100+
oauthCfg := &oauth2.Config{
101+
ClientID: "random-client-id",
102+
Endpoint: oauth2.Endpoint{
103+
TokenURL: "https://login.microsoftonline.com/6a1e9139-13f2-4afb-8f46-036feac8bd79/v2.0/token",
104+
},
105+
}
106+
107+
pkiConfig, err := oauthpki.NewOauth2PKIConfig(oauthpki.ConfigParams{
108+
ClientID: oauthCfg.ClientID,
109+
TokenURL: oauthCfg.Endpoint.TokenURL,
110+
PemEncodedKey: []byte(testClientKey),
111+
PemEncodedCert: []byte(testClientCert),
112+
Config: oauthCfg,
113+
})
114+
require.NoError(t, err, "failed to create pki config")
115+
116+
ctx := testutil.Context(t, testutil.WaitMedium)
117+
ctx = oidc.ClientContext(ctx, &http.Client{
118+
Transport: &fakeRoundTripper{
119+
roundTrip: func(req *http.Request) (*http.Response, error) {
120+
resp := &http.Response{
121+
Status: "500 Internal Service Error",
122+
}
123+
// This is the easiest way to hijack the request and check
124+
// the params. The oauth2 package uses unexported types and
125+
// options, so we need to view the actual request created.
126+
data, err := io.ReadAll(req.Body)
127+
if !assert.NoError(t, err, "failed to read request body") {
128+
return resp, nil
129+
}
130+
vals, err := url.ParseQuery(string(data))
131+
if !assert.NoError(t, err, "failed to parse values") {
132+
return resp, nil
133+
}
134+
assert.Equal(t, "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", vals.Get("client_assertion_type"))
135+
136+
jwtToken := vals.Get("client_assertion")
137+
138+
// No need to actually verify the jwt is signed right.
139+
parsedToken, _, err := (&jwt.Parser{}).ParseUnverified(jwtToken, jwt.MapClaims{})
140+
if !assert.NoError(t, err, "failed to parse jwt token") {
141+
return resp, nil
142+
}
143+
144+
// Azure requirements
145+
assert.NotEmpty(t, parsedToken.Header["x5t"], "hashed cert missing")
146+
assert.Equal(t, "RS256", parsedToken.Header["alg"], "azure only accepts RS256")
147+
assert.Equal(t, "JWT", parsedToken.Header["typ"], "azure only accepts JWT")
148+
return resp, nil
149+
},
150+
},
151+
})
152+
token, err := pkiConfig.Exchange(ctx, base64.StdEncoding.EncodeToString([]byte("random-code")))
153+
// We hijack the request and return an error intentionally
154+
require.Error(t, err, "error expected")
155+
}

0 commit comments

Comments
 (0)