Skip to content

Commit 6b9a3e4

Browse files
committed
add test for oidc jwt
1 parent 4028995 commit 6b9a3e4

File tree

5 files changed

+133
-7
lines changed

5 files changed

+133
-7
lines changed

coderd/coderdtest/coderdtest.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ type Options struct {
161161
WorkspaceUsageTrackerTick chan time.Time
162162
NotificationsEnqueuer notifications.Enqueuer
163163
APIKeyEncryptionCache cryptokeys.EncryptionKeycache
164+
OIDCConvertKeyCache cryptokeys.SigningKeycache
164165
Clock quartz.Clock
165166
}
166167

@@ -538,6 +539,7 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can
538539
OneTimePasscodeValidityPeriod: options.OneTimePasscodeValidityPeriod,
539540
Clock: options.Clock,
540541
AppEncryptionKeyCache: options.APIKeyEncryptionCache,
542+
OIDCConvertKeyCache: options.OIDCConvertKeyCache,
541543
}
542544
}
543545

coderd/jwtutils/jwe.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ func Decrypt(ctx context.Context, d DecryptKeyProvider, token string, claims Cla
106106

107107
kid := object.Header.KeyID
108108
if kid == "" {
109-
return xerrors.Errorf("expected %q header to be a string", keyIDHeaderKey)
109+
return ErrMissingKeyID
110110
}
111111

112112
key, err := d.DecryptingKey(ctx, kid)

coderd/jwtutils/jws.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ import (
1010
"golang.org/x/xerrors"
1111
)
1212

13+
var ErrMissingKeyID = xerrors.New("missing key ID")
14+
1315
const (
1416
keyIDHeaderKey = "kid"
1517
)
@@ -126,7 +128,7 @@ func Verify(ctx context.Context, v VerifyKeyProvider, token string, claims Claim
126128

127129
kid := signature.Header.KeyID
128130
if kid == "" {
129-
return xerrors.Errorf("expected %q header to be a string", keyIDHeaderKey)
131+
return ErrMissingKeyID
130132
}
131133

132134
key, err := v.VerifyingKey(ctx, kid)

coderd/userauth.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ const (
5252
)
5353

5454
type OAuthConvertStateClaims struct {
55-
jwt.Claims
55+
jwtutils.RegisteredClaims
5656

5757
UserID uuid.UUID `json:"user_id"`
5858
State string `json:"state"`
@@ -61,7 +61,7 @@ type OAuthConvertStateClaims struct {
6161
}
6262

6363
func (o *OAuthConvertStateClaims) Validate(e jwt.Expected) error {
64-
return o.Claims.Validate(e)
64+
return o.RegisteredClaims.Validate(e)
6565
}
6666

6767
// postConvertLoginType replies with an oauth state token capable of converting
@@ -156,7 +156,7 @@ func (api *API) postConvertLoginType(rw http.ResponseWriter, r *http.Request) {
156156
// Eg: Developers with more than 1 deployment.
157157
now := time.Now()
158158
claims := &OAuthConvertStateClaims{
159-
Claims: jwt.Claims{
159+
RegisteredClaims: jwtutils.RegisteredClaims{
160160
Issuer: api.DeploymentID,
161161
Subject: stateString,
162162
Audience: []string{user.ID.String()},
@@ -1682,7 +1682,7 @@ func (api *API) convertUserToOauth(ctx context.Context, r *http.Request, db data
16821682
var claims OAuthConvertStateClaims
16831683

16841684
err = jwtutils.Verify(ctx, api.OIDCConvertKeyCache, jwtCookie.Value, &claims)
1685-
if xerrors.Is(err, cryptokeys.ErrKeyNotFound) || xerrors.Is(err, cryptokeys.ErrKeyInvalid) || xerrors.Is(err, jose.ErrCryptoFailure) {
1685+
if xerrors.Is(err, cryptokeys.ErrKeyNotFound) || xerrors.Is(err, cryptokeys.ErrKeyInvalid) || xerrors.Is(err, jose.ErrCryptoFailure) || xerrors.Is(err, jwtutils.ErrMissingKeyID) {
16861686
// These errors are probably because the user is mixing 2 coder deployments.
16871687
return database.User{}, idpsync.HTTPError{
16881688
Code: http.StatusBadRequest,

coderd/userauth_test.go

Lines changed: 123 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ package coderd_test
33
import (
44
"context"
55
"crypto"
6+
"crypto/rand"
7+
"encoding/json"
68
"fmt"
79
"io"
810
"net/http"
@@ -13,6 +15,7 @@ import (
1315
"time"
1416

1517
"github.com/coreos/go-oidc/v3/oidc"
18+
"github.com/go-jose/go-jose/v4"
1619
"github.com/golang-jwt/jwt/v4"
1720
"github.com/google/go-github/v43/github"
1821
"github.com/google/uuid"
@@ -27,10 +30,12 @@ import (
2730
"github.com/coder/coder/v2/coderd/audit"
2831
"github.com/coder/coder/v2/coderd/coderdtest"
2932
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
33+
"github.com/coder/coder/v2/coderd/cryptokeys"
3034
"github.com/coder/coder/v2/coderd/database"
3135
"github.com/coder/coder/v2/coderd/database/dbauthz"
3236
"github.com/coder/coder/v2/coderd/database/dbgen"
3337
"github.com/coder/coder/v2/coderd/database/dbtestutil"
38+
"github.com/coder/coder/v2/coderd/jwtutils"
3439
"github.com/coder/coder/v2/coderd/notifications"
3540
"github.com/coder/coder/v2/coderd/promoauth"
3641
"github.com/coder/coder/v2/codersdk"
@@ -1316,22 +1321,25 @@ func TestUserOIDC(t *testing.T) {
13161321

13171322
owner := coderdtest.CreateFirstUser(t, client)
13181323
user, userData := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
1324+
require.Equal(t, codersdk.LoginTypePassword, userData.LoginType)
13191325

13201326
claims := jwt.MapClaims{
13211327
"email": userData.Email,
13221328
}
13231329
var err error
13241330
user.HTTPClient.Jar, err = cookiejar.New(nil)
13251331
require.NoError(t, err)
1332+
user.HTTPClient.Transport = http.DefaultTransport.(*http.Transport).Clone()
13261333

13271334
ctx := testutil.Context(t, testutil.WaitShort)
1335+
13281336
convertResponse, err := user.ConvertLoginType(ctx, codersdk.ConvertLoginRequest{
13291337
ToType: codersdk.LoginTypeOIDC,
13301338
Password: "SomeSecurePassword!",
13311339
})
13321340
require.NoError(t, err)
13331341

1334-
fake.LoginWithClient(t, user, claims, func(r *http.Request) {
1342+
_, _ = fake.LoginWithClient(t, user, claims, func(r *http.Request) {
13351343
r.URL.RawQuery = url.Values{
13361344
"oidc_merge_state": {convertResponse.StateString},
13371345
}.Encode()
@@ -1341,6 +1349,99 @@ func TestUserOIDC(t *testing.T) {
13411349
r.AddCookie(cookie)
13421350
}
13431351
})
1352+
1353+
info, err := client.User(ctx, userData.ID.String())
1354+
require.NoError(t, err)
1355+
require.Equal(t, codersdk.LoginTypeOIDC, info.LoginType)
1356+
})
1357+
1358+
t.Run("BadJWT", func(t *testing.T) {
1359+
t.Parallel()
1360+
1361+
var (
1362+
ctx = testutil.Context(t, testutil.WaitMedium)
1363+
logger = slogtest.Make(t, nil)
1364+
)
1365+
1366+
auditor := audit.NewMock()
1367+
fake := oidctest.NewFakeIDP(t,
1368+
oidctest.WithRefresh(func(_ string) error {
1369+
return xerrors.New("refreshing token should never occur")
1370+
}),
1371+
oidctest.WithServing(),
1372+
)
1373+
cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) {
1374+
cfg.AllowSignups = true
1375+
})
1376+
1377+
db, ps := dbtestutil.NewDB(t)
1378+
fetcher := &cryptokeys.DBFetcher{
1379+
DB: db,
1380+
}
1381+
1382+
kc, err := cryptokeys.NewSigningCache(ctx, logger, fetcher, codersdk.CryptoKeyFeatureOIDCConvert)
1383+
require.NoError(t, err)
1384+
1385+
client := coderdtest.New(t, &coderdtest.Options{
1386+
Auditor: auditor,
1387+
OIDCConfig: cfg,
1388+
Database: db,
1389+
Pubsub: ps,
1390+
OIDCConvertKeyCache: kc,
1391+
})
1392+
1393+
owner := coderdtest.CreateFirstUser(t, client)
1394+
user, userData := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
1395+
1396+
claims := jwt.MapClaims{
1397+
"email": userData.Email,
1398+
}
1399+
user.HTTPClient.Jar, err = cookiejar.New(nil)
1400+
require.NoError(t, err)
1401+
user.HTTPClient.Transport = http.DefaultTransport.(*http.Transport).Clone()
1402+
1403+
convertResponse, err := user.ConvertLoginType(ctx, codersdk.ConvertLoginRequest{
1404+
ToType: codersdk.LoginTypeOIDC,
1405+
Password: "SomeSecurePassword!",
1406+
})
1407+
require.NoError(t, err)
1408+
1409+
// Update the cookie to use a bad signing key. We're asserting the behavior of the scenario
1410+
// where a JWT gets minted on an old version of Coder but gets verified on a new version.
1411+
_, resp := fake.AttemptLogin(t, user, claims, func(r *http.Request) {
1412+
r.URL.RawQuery = url.Values{
1413+
"oidc_merge_state": {convertResponse.StateString},
1414+
}.Encode()
1415+
r.Header.Set(codersdk.SessionTokenHeader, user.SessionToken())
1416+
1417+
cookies := user.HTTPClient.Jar.Cookies(user.URL)
1418+
for i, cookie := range cookies {
1419+
if cookie.Name != coderd.OAuthConvertCookieValue {
1420+
continue
1421+
}
1422+
1423+
jwt := cookie.Value
1424+
var claims coderd.OAuthConvertStateClaims
1425+
err := jwtutils.Verify(ctx, kc, jwt, &claims)
1426+
require.NoError(t, err)
1427+
badJWT := generateBadJWT(t, claims)
1428+
cookie.Value = badJWT
1429+
cookies[i] = cookie
1430+
}
1431+
1432+
user.HTTPClient.Jar.SetCookies(user.URL, cookies)
1433+
1434+
for _, cookie := range cookies {
1435+
fmt.Printf("cookie: %+v\n", cookie)
1436+
r.AddCookie(cookie)
1437+
}
1438+
})
1439+
defer resp.Body.Close()
1440+
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
1441+
var respErr codersdk.Response
1442+
err = json.NewDecoder(resp.Body).Decode(&respErr)
1443+
require.NoError(t, err)
1444+
require.Contains(t, respErr.Message, "Using an invalid jwt to authorize this action.")
13441445
})
13451446

13461447
t.Run("AlternateUsername", func(t *testing.T) {
@@ -2022,3 +2123,24 @@ func inflateClaims(t testing.TB, seed jwt.MapClaims, size int) jwt.MapClaims {
20222123
seed["random_data"] = junk
20232124
return seed
20242125
}
2126+
2127+
// generateBadJWT generates a JWT with a random key. It's intended to emulate the old-style JWT's we generated.
2128+
func generateBadJWT(t *testing.T, claims interface{}) string {
2129+
t.Helper()
2130+
2131+
var buf [64]byte
2132+
_, err := rand.Read(buf[:])
2133+
require.NoError(t, err)
2134+
signer, err := jose.NewSigner(jose.SigningKey{
2135+
Algorithm: jose.HS512,
2136+
Key: buf[:],
2137+
}, nil)
2138+
require.NoError(t, err)
2139+
payload, err := json.Marshal(claims)
2140+
require.NoError(t, err)
2141+
signed, err := signer.Sign(payload)
2142+
require.NoError(t, err)
2143+
compact, err := signed.CompactSerialize()
2144+
require.NoError(t, err)
2145+
return compact
2146+
}

0 commit comments

Comments
 (0)