Skip to content

Commit d9d4d74

Browse files
authored
test: add full OIDC fake IDP (coder#9317)
* test: implement fake OIDC provider with full functionality * Refactor existing tests
1 parent 0a213a6 commit d9d4d74

File tree

10 files changed

+1596
-626
lines changed

10 files changed

+1596
-626
lines changed

coderd/coderdtest/coderdtest.go

+1-165
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,13 @@ import (
3131
"time"
3232

3333
"cloud.google.com/go/compute/metadata"
34-
"github.com/coreos/go-oidc/v3/oidc"
3534
"github.com/fullsailor/pkcs7"
36-
"github.com/golang-jwt/jwt"
35+
"github.com/golang-jwt/jwt/v4"
3736
"github.com/google/uuid"
3837
"github.com/moby/moby/pkg/namesgenerator"
3938
"github.com/prometheus/client_golang/prometheus"
4039
"github.com/stretchr/testify/assert"
4140
"github.com/stretchr/testify/require"
42-
"golang.org/x/oauth2"
4341
"golang.org/x/xerrors"
4442
"google.golang.org/api/idtoken"
4543
"google.golang.org/api/option"
@@ -1020,152 +1018,6 @@ func NewAWSInstanceIdentity(t *testing.T, instanceID string) (awsidentity.Certif
10201018
}
10211019
}
10221020

1023-
type OIDCConfig struct {
1024-
key *rsa.PrivateKey
1025-
issuer string
1026-
// These are optional
1027-
refreshToken string
1028-
oidcTokenExpires func() time.Time
1029-
tokenSource func() (*oauth2.Token, error)
1030-
}
1031-
1032-
func WithRefreshToken(token string) func(cfg *OIDCConfig) {
1033-
return func(cfg *OIDCConfig) {
1034-
cfg.refreshToken = token
1035-
}
1036-
}
1037-
1038-
func WithTokenExpires(expFunc func() time.Time) func(cfg *OIDCConfig) {
1039-
return func(cfg *OIDCConfig) {
1040-
cfg.oidcTokenExpires = expFunc
1041-
}
1042-
}
1043-
1044-
func WithTokenSource(src func() (*oauth2.Token, error)) func(cfg *OIDCConfig) {
1045-
return func(cfg *OIDCConfig) {
1046-
cfg.tokenSource = src
1047-
}
1048-
}
1049-
1050-
func NewOIDCConfig(t *testing.T, issuer string, opts ...func(cfg *OIDCConfig)) *OIDCConfig {
1051-
t.Helper()
1052-
1053-
block, _ := pem.Decode([]byte(testRSAPrivateKey))
1054-
pkey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
1055-
require.NoError(t, err)
1056-
1057-
if issuer == "" {
1058-
issuer = "https://coder.com"
1059-
}
1060-
1061-
cfg := &OIDCConfig{
1062-
key: pkey,
1063-
issuer: issuer,
1064-
}
1065-
for _, opt := range opts {
1066-
opt(cfg)
1067-
}
1068-
return cfg
1069-
}
1070-
1071-
func (*OIDCConfig) AuthCodeURL(state string, _ ...oauth2.AuthCodeOption) string {
1072-
return "/?state=" + url.QueryEscape(state)
1073-
}
1074-
1075-
type tokenSource struct {
1076-
src func() (*oauth2.Token, error)
1077-
}
1078-
1079-
func (s tokenSource) Token() (*oauth2.Token, error) {
1080-
return s.src()
1081-
}
1082-
1083-
func (cfg *OIDCConfig) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource {
1084-
if cfg.tokenSource == nil {
1085-
return nil
1086-
}
1087-
return tokenSource{
1088-
src: cfg.tokenSource,
1089-
}
1090-
}
1091-
1092-
func (cfg *OIDCConfig) Exchange(_ context.Context, code string, _ ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
1093-
token, err := base64.StdEncoding.DecodeString(code)
1094-
if err != nil {
1095-
return nil, xerrors.Errorf("decode code: %w", err)
1096-
}
1097-
1098-
var exp time.Time
1099-
if cfg.oidcTokenExpires != nil {
1100-
exp = cfg.oidcTokenExpires()
1101-
}
1102-
1103-
return (&oauth2.Token{
1104-
AccessToken: "token",
1105-
RefreshToken: cfg.refreshToken,
1106-
Expiry: exp,
1107-
}).WithExtra(map[string]interface{}{
1108-
"id_token": string(token),
1109-
}), nil
1110-
}
1111-
1112-
func (cfg *OIDCConfig) EncodeClaims(t *testing.T, claims jwt.MapClaims) string {
1113-
t.Helper()
1114-
1115-
if _, ok := claims["exp"]; !ok {
1116-
claims["exp"] = time.Now().Add(time.Hour).UnixMilli()
1117-
}
1118-
1119-
if _, ok := claims["iss"]; !ok {
1120-
claims["iss"] = cfg.issuer
1121-
}
1122-
1123-
if _, ok := claims["sub"]; !ok {
1124-
claims["sub"] = "testme"
1125-
}
1126-
1127-
signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(cfg.key)
1128-
require.NoError(t, err)
1129-
1130-
return base64.StdEncoding.EncodeToString([]byte(signed))
1131-
}
1132-
1133-
func (cfg *OIDCConfig) OIDCConfig(t *testing.T, userInfoClaims jwt.MapClaims, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig {
1134-
// By default, the provider can be empty.
1135-
// This means it won't support any endpoints!
1136-
provider := &oidc.Provider{}
1137-
if userInfoClaims != nil {
1138-
resp, err := json.Marshal(userInfoClaims)
1139-
require.NoError(t, err)
1140-
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1141-
w.WriteHeader(http.StatusOK)
1142-
_, _ = w.Write(resp)
1143-
}))
1144-
t.Cleanup(srv.Close)
1145-
cfg := &oidc.ProviderConfig{
1146-
UserInfoURL: srv.URL,
1147-
}
1148-
provider = cfg.NewProvider(context.Background())
1149-
}
1150-
newCFG := &coderd.OIDCConfig{
1151-
OAuth2Config: cfg,
1152-
Verifier: oidc.NewVerifier(cfg.issuer, &oidc.StaticKeySet{
1153-
PublicKeys: []crypto.PublicKey{cfg.key.Public()},
1154-
}, &oidc.Config{
1155-
SkipClientIDCheck: true,
1156-
}),
1157-
Provider: provider,
1158-
UsernameField: "preferred_username",
1159-
EmailField: "email",
1160-
AuthURLParams: map[string]string{"access_type": "offline"},
1161-
GroupField: "groups",
1162-
}
1163-
for _, opt := range opts {
1164-
opt(newCFG)
1165-
}
1166-
return newCFG
1167-
}
1168-
11691021
// NewAzureInstanceIdentity returns a metadata client and ID token validator for faking
11701022
// instance authentication for Azure.
11711023
func NewAzureInstanceIdentity(t *testing.T, instanceID string) (x509.VerifyOptions, *http.Client) {
@@ -1254,22 +1106,6 @@ func SDKError(t *testing.T, err error) *codersdk.Error {
12541106
return cerr
12551107
}
12561108

1257-
const testRSAPrivateKey = `-----BEGIN RSA PRIVATE KEY-----
1258-
MIICXQIBAAKBgQDLets8+7M+iAQAqN/5BVyCIjhTQ4cmXulL+gm3v0oGMWzLupUS
1259-
v8KPA+Tp7dgC/DZPfMLaNH1obBBhJ9DhS6RdS3AS3kzeFrdu8zFHLWF53DUBhS92
1260-
5dCAEuJpDnNizdEhxTfoHrhuCmz8l2nt1pe5eUK2XWgd08Uc93h5ij098wIDAQAB
1261-
AoGAHLaZeWGLSaen6O/rqxg2laZ+jEFbMO7zvOTruiIkL/uJfrY1kw+8RLIn+1q0
1262-
wLcWcuEIHgKKL9IP/aXAtAoYh1FBvRPLkovF1NZB0Je/+CSGka6wvc3TGdvppZJe
1263-
rKNcUvuOYLxkmLy4g9zuY5qrxFyhtIn2qZzXEtLaVOHzPQECQQDvN0mSajpU7dTB
1264-
w4jwx7IRXGSSx65c+AsHSc1Rj++9qtPC6WsFgAfFN2CEmqhMbEUVGPv/aPjdyWk9
1265-
pyLE9xR/AkEA2cGwyIunijE5v2rlZAD7C4vRgdcMyCf3uuPcgzFtsR6ZhyQSgLZ8
1266-
YRPuvwm4cdPJMmO3YwBfxT6XGuSc2k8MjQJBAI0+b8prvpV2+DCQa8L/pjxp+VhR
1267-
Xrq2GozrHrgR7NRokTB88hwFRJFF6U9iogy9wOx8HA7qxEbwLZuhm/4AhbECQC2a
1268-
d8h4Ht09E+f3nhTEc87mODkl7WJZpHL6V2sORfeq/eIkds+H6CJ4hy5w/bSw8tjf
1269-
sz9Di8sGIaUbLZI2rd0CQQCzlVwEtRtoNCyMJTTrkgUuNufLP19RZ5FpyXxBO5/u
1270-
QastnN77KfUwdj3SJt44U/uh1jAIv4oSLBr8HYUkbnI8
1271-
-----END RSA PRIVATE KEY-----`
1272-
12731109
func DeploymentValues(t testing.TB) *codersdk.DeploymentValues {
12741110
var cfg codersdk.DeploymentValues
12751111
opts := cfg.Options()

coderd/coderdtest/oidctest/helper.go

+103
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
package oidctest
2+
3+
import (
4+
"net/http"
5+
"testing"
6+
"time"
7+
8+
"github.com/golang-jwt/jwt/v4"
9+
"github.com/stretchr/testify/require"
10+
11+
"github.com/coder/coder/v2/coderd/database"
12+
"github.com/coder/coder/v2/coderd/database/dbauthz"
13+
"github.com/coder/coder/v2/coderd/httpmw"
14+
"github.com/coder/coder/v2/codersdk"
15+
"github.com/coder/coder/v2/testutil"
16+
)
17+
18+
// LoginHelper helps with logging in a user and refreshing their oauth tokens.
19+
// It is mainly because refreshing oauth tokens is a bit tricky and requires
20+
// some database manipulation.
21+
type LoginHelper struct {
22+
fake *FakeIDP
23+
client *codersdk.Client
24+
}
25+
26+
func NewLoginHelper(client *codersdk.Client, fake *FakeIDP) *LoginHelper {
27+
if client == nil {
28+
panic("client must not be nil")
29+
}
30+
if fake == nil {
31+
panic("fake must not be nil")
32+
}
33+
return &LoginHelper{
34+
fake: fake,
35+
client: client,
36+
}
37+
}
38+
39+
// Login just helps by making an unauthenticated client and logging in with
40+
// the given claims. All Logins should be unauthenticated, so this is a
41+
// convenience method.
42+
func (h *LoginHelper) Login(t *testing.T, idTokenClaims jwt.MapClaims) (*codersdk.Client, *http.Response) {
43+
t.Helper()
44+
unauthenticatedClient := codersdk.New(h.client.URL)
45+
46+
return h.fake.Login(t, unauthenticatedClient, idTokenClaims)
47+
}
48+
49+
// ExpireOauthToken expires the oauth token for the given user.
50+
func (*LoginHelper) ExpireOauthToken(t *testing.T, db database.Store, user *codersdk.Client) database.UserLink {
51+
t.Helper()
52+
53+
//nolint:gocritic // Testing
54+
ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitMedium))
55+
56+
id, _, err := httpmw.SplitAPIToken(user.SessionToken())
57+
require.NoError(t, err)
58+
59+
// We need to get the OIDC link and update it in the database to force
60+
// it to be expired.
61+
key, err := db.GetAPIKeyByID(ctx, id)
62+
require.NoError(t, err, "get api key")
63+
64+
link, err := db.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{
65+
UserID: key.UserID,
66+
LoginType: database.LoginTypeOIDC,
67+
})
68+
require.NoError(t, err, "get user link")
69+
70+
// Expire the oauth link for the given user.
71+
updated, err := db.UpdateUserLink(ctx, database.UpdateUserLinkParams{
72+
OAuthAccessToken: link.OAuthAccessToken,
73+
OAuthRefreshToken: link.OAuthRefreshToken,
74+
OAuthExpiry: time.Now().Add(time.Hour * -1),
75+
UserID: link.UserID,
76+
LoginType: link.LoginType,
77+
})
78+
require.NoError(t, err, "expire user link")
79+
80+
return updated
81+
}
82+
83+
// ForceRefresh forces the client to refresh its oauth token. It does this by
84+
// expiring the oauth token, then doing an authenticated call. This will force
85+
// the API Key middleware to refresh the oauth token.
86+
//
87+
// A unit test assertion makes sure the refresh token is used.
88+
func (h *LoginHelper) ForceRefresh(t *testing.T, db database.Store, user *codersdk.Client, idToken jwt.MapClaims) {
89+
t.Helper()
90+
91+
link := h.ExpireOauthToken(t, db, user)
92+
// Updates the claims that the IDP will return. By default, it always
93+
// uses the original claims for the original oauth token.
94+
h.fake.UpdateRefreshClaims(link.OAuthRefreshToken, idToken)
95+
96+
t.Cleanup(func() {
97+
require.True(t, h.fake.RefreshUsed(link.OAuthRefreshToken), "refresh token must be used, but has not. Did you forget to call the returned function from this call?")
98+
})
99+
100+
// Do any authenticated call to force the refresh
101+
_, err := user.User(testutil.Context(t, testutil.WaitShort), "me")
102+
require.NoError(t, err, "user must be able to be fetched")
103+
}

0 commit comments

Comments
 (0)