Skip to content

Commit a8a9633

Browse files
committed
refactor all old tests, delete old fake
1 parent f906fae commit a8a9633

File tree

6 files changed

+262
-416
lines changed

6 files changed

+262
-416
lines changed

coderd/coderdtest/coderdtest.go

Lines changed: 0 additions & 148 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ import (
3131
"time"
3232

3333
"cloud.google.com/go/compute/metadata"
34-
"github.com/coreos/go-oidc/v3/oidc"
3534
"github.com/fullsailor/pkcs7"
3635
"github.com/golang-jwt/jwt/v4"
3736
"github.com/google/uuid"
@@ -40,7 +39,6 @@ import (
4039
"github.com/spf13/afero"
4140
"github.com/stretchr/testify/assert"
4241
"github.com/stretchr/testify/require"
43-
"golang.org/x/oauth2"
4442
"golang.org/x/xerrors"
4543
"google.golang.org/api/idtoken"
4644
"google.golang.org/api/option"
@@ -1022,152 +1020,6 @@ func NewAWSInstanceIdentity(t *testing.T, instanceID string) (awsidentity.Certif
10221020
}
10231021
}
10241022

1025-
type OIDCConfig struct {
1026-
key *rsa.PrivateKey
1027-
issuer string
1028-
// These are optional
1029-
refreshToken string
1030-
oidcTokenExpires func() time.Time
1031-
tokenSource func() (*oauth2.Token, error)
1032-
}
1033-
1034-
func WithRefreshToken(token string) func(cfg *OIDCConfig) {
1035-
return func(cfg *OIDCConfig) {
1036-
cfg.refreshToken = token
1037-
}
1038-
}
1039-
1040-
func WithTokenExpires(expFunc func() time.Time) func(cfg *OIDCConfig) {
1041-
return func(cfg *OIDCConfig) {
1042-
cfg.oidcTokenExpires = expFunc
1043-
}
1044-
}
1045-
1046-
func WithTokenSource(src func() (*oauth2.Token, error)) func(cfg *OIDCConfig) {
1047-
return func(cfg *OIDCConfig) {
1048-
cfg.tokenSource = src
1049-
}
1050-
}
1051-
1052-
func NewOIDCConfig(t *testing.T, issuer string, opts ...func(cfg *OIDCConfig)) *OIDCConfig {
1053-
t.Helper()
1054-
1055-
block, _ := pem.Decode([]byte(testRSAPrivateKey))
1056-
pkey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
1057-
require.NoError(t, err)
1058-
1059-
if issuer == "" {
1060-
issuer = "https://coder.com"
1061-
}
1062-
1063-
cfg := &OIDCConfig{
1064-
key: pkey,
1065-
issuer: issuer,
1066-
}
1067-
for _, opt := range opts {
1068-
opt(cfg)
1069-
}
1070-
return cfg
1071-
}
1072-
1073-
func (*OIDCConfig) AuthCodeURL(state string, _ ...oauth2.AuthCodeOption) string {
1074-
return "/?state=" + url.QueryEscape(state)
1075-
}
1076-
1077-
type tokenSource struct {
1078-
src func() (*oauth2.Token, error)
1079-
}
1080-
1081-
func (s tokenSource) Token() (*oauth2.Token, error) {
1082-
return s.src()
1083-
}
1084-
1085-
func (cfg *OIDCConfig) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource {
1086-
if cfg.tokenSource == nil {
1087-
return nil
1088-
}
1089-
return tokenSource{
1090-
src: cfg.tokenSource,
1091-
}
1092-
}
1093-
1094-
func (cfg *OIDCConfig) Exchange(_ context.Context, code string, _ ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
1095-
token, err := base64.StdEncoding.DecodeString(code)
1096-
if err != nil {
1097-
return nil, xerrors.Errorf("decode code: %w", err)
1098-
}
1099-
1100-
var exp time.Time
1101-
if cfg.oidcTokenExpires != nil {
1102-
exp = cfg.oidcTokenExpires()
1103-
}
1104-
1105-
return (&oauth2.Token{
1106-
AccessToken: "token",
1107-
RefreshToken: cfg.refreshToken,
1108-
Expiry: exp,
1109-
}).WithExtra(map[string]interface{}{
1110-
"id_token": string(token),
1111-
}), nil
1112-
}
1113-
1114-
func (cfg *OIDCConfig) EncodeClaims(t *testing.T, claims jwt.MapClaims) string {
1115-
t.Helper()
1116-
1117-
if _, ok := claims["exp"]; !ok {
1118-
claims["exp"] = time.Now().Add(time.Hour).UnixMilli()
1119-
}
1120-
1121-
if _, ok := claims["iss"]; !ok {
1122-
claims["iss"] = cfg.issuer
1123-
}
1124-
1125-
if _, ok := claims["sub"]; !ok {
1126-
claims["sub"] = "testme"
1127-
}
1128-
1129-
signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(cfg.key)
1130-
require.NoError(t, err)
1131-
1132-
return base64.StdEncoding.EncodeToString([]byte(signed))
1133-
}
1134-
1135-
func (cfg *OIDCConfig) OIDCConfig(t *testing.T, userInfoClaims jwt.MapClaims, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig {
1136-
// By default, the provider can be empty.
1137-
// This means it won't support any endpoints!
1138-
provider := &oidc.Provider{}
1139-
if userInfoClaims != nil {
1140-
resp, err := json.Marshal(userInfoClaims)
1141-
require.NoError(t, err)
1142-
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1143-
w.WriteHeader(http.StatusOK)
1144-
_, _ = w.Write(resp)
1145-
}))
1146-
t.Cleanup(srv.Close)
1147-
cfg := &oidc.ProviderConfig{
1148-
UserInfoURL: srv.URL,
1149-
}
1150-
provider = cfg.NewProvider(context.Background())
1151-
}
1152-
newCFG := &coderd.OIDCConfig{
1153-
OAuth2Config: cfg,
1154-
Verifier: oidc.NewVerifier(cfg.issuer, &oidc.StaticKeySet{
1155-
PublicKeys: []crypto.PublicKey{cfg.key.Public()},
1156-
}, &oidc.Config{
1157-
SkipClientIDCheck: true,
1158-
}),
1159-
Provider: provider,
1160-
UsernameField: "preferred_username",
1161-
EmailField: "email",
1162-
AuthURLParams: map[string]string{"access_type": "offline"},
1163-
GroupField: "groups",
1164-
}
1165-
for _, opt := range opts {
1166-
opt(newCFG)
1167-
}
1168-
return newCFG
1169-
}
1170-
11711023
// NewAzureInstanceIdentity returns a metadata client and ID token validator for faking
11721024
// instance authentication for Azure.
11731025
func NewAzureInstanceIdentity(t *testing.T, instanceID string) (x509.VerifyOptions, *http.Client) {

coderd/coderdtest/oidctest/runner.go renamed to coderd/coderdtest/oidctest/helper.go

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,10 @@ func (h *LoginHelper) Login(t *testing.T, idTokenClaims jwt.MapClaims) (*codersd
4343
t.Helper()
4444
unauthenticatedClient := codersdk.New(h.owner.URL)
4545

46-
return h.fake.LoginClient(t, unauthenticatedClient, idTokenClaims)
46+
return h.fake.Login(t, unauthenticatedClient, idTokenClaims)
4747
}
4848

49-
// ForceRefresh forces the client to refresh its oauth token.
50-
func (h *LoginHelper) ForceRefresh(t *testing.T, db database.Store, user *codersdk.Client, idToken jwt.MapClaims) (authenticatedCall func(t *testing.T)) {
49+
func (h *LoginHelper) ExpireOauthToken(t *testing.T, db database.Store, user *codersdk.Client) (refreshToken string) {
5150
t.Helper()
5251

5352
//nolint:gocritic // Testing
@@ -67,10 +66,6 @@ func (h *LoginHelper) ForceRefresh(t *testing.T, db database.Store, user *coders
6766
})
6867
require.NoError(t, err, "get user link")
6968

70-
// Updates the claims that the IDP will return. By default, it always
71-
// uses the original claims for the original oauth token.
72-
h.fake.UpdateRefreshClaims(link.OAuthRefreshToken, idToken)
73-
7469
// Fetch the oauth link for the given user.
7570
_, err = db.UpdateUserLink(ctx, database.UpdateUserLinkParams{
7671
OAuthAccessToken: link.OAuthAccessToken,
@@ -80,15 +75,24 @@ func (h *LoginHelper) ForceRefresh(t *testing.T, db database.Store, user *coders
8075
LoginType: database.LoginTypeOIDC,
8176
})
8277
require.NoError(t, err, "expire user link")
78+
79+
return link.OAuthRefreshToken
80+
}
81+
82+
// ForceRefresh forces the client to refresh its oauth token.
83+
func (h *LoginHelper) ForceRefresh(t *testing.T, db database.Store, user *codersdk.Client, idToken jwt.MapClaims) {
84+
t.Helper()
85+
86+
refreshToken := h.ExpireOauthToken(t, db, user)
87+
// Updates the claims that the IDP will return. By default, it always
88+
// uses the original claims for the original oauth token.
89+
h.fake.UpdateRefreshClaims(refreshToken, idToken)
90+
8391
t.Cleanup(func() {
84-
require.True(t, h.fake.RefreshUsed(link.OAuthRefreshToken), "refresh token must be used, but has not. Did you forget to call the returned function from this call?")
92+
require.True(t, h.fake.RefreshUsed(refreshToken), "refresh token must be used, but has not. Did you forget to call the returned function from this call?")
8593
})
8694

87-
return func(t *testing.T) {
88-
t.Helper()
89-
90-
// Do any authenticated call to force the refresh
91-
_, err := user.User(testutil.Context(t, testutil.WaitShort), "me")
92-
require.NoError(t, err, "user must be able to be fetched")
93-
}
95+
// Do any authenticated call to force the refresh
96+
_, err := user.User(testutil.Context(t, testutil.WaitShort), "me")
97+
require.NoError(t, err, "user must be able to be fetched")
9498
}

0 commit comments

Comments
 (0)