diff --git a/oidc/jwks_test.go b/oidc/jwks_test.go index 2cbf38b..7123dab 100644 --- a/oidc/jwks_test.go +++ b/oidc/jwks_test.go @@ -11,6 +11,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "net/http" "net/http/httptest" "strconv" @@ -151,11 +152,8 @@ func TestKeyVerifyContextCanceled(t *testing.T) { t.Fatal(err) } - ch := make(chan struct{}) - defer close(ch) - s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - <-ch + io.WriteString(w, "{}") })) defer s.Close() diff --git a/oidc/verify.go b/oidc/verify.go index 52b27b7..a8bf107 100644 --- a/oidc/verify.go +++ b/oidc/verify.go @@ -1,15 +1,11 @@ package oidc import ( - "bytes" "context" - "encoding/base64" "encoding/json" - "errors" "fmt" "io" "net/http" - "strings" "time" jose "github.com/go-jose/go-jose/v4" @@ -145,18 +141,6 @@ func (p *Provider) newVerifier(keySet KeySet, config *Config) *IDTokenVerifier { return NewVerifier(p.issuer, keySet, config) } -func parseJWT(p string) ([]byte, error) { - parts := strings.Split(p, ".") - if len(parts) < 2 { - return nil, fmt.Errorf("oidc: malformed jwt, expected 3 parts got %d", len(parts)) - } - payload, err := base64.RawURLEncoding.DecodeString(parts[1]) - if err != nil { - return nil, fmt.Errorf("oidc: malformed jwt payload: %v", err) - } - return payload, nil -} - func contains(sli []string, ele string) bool { for _, s := range sli { if s == ele { @@ -219,12 +203,49 @@ func resolveDistributedClaim(ctx context.Context, verifier *IDTokenVerifier, src // // token, err := verifier.Verify(ctx, rawIDToken) func (v *IDTokenVerifier) Verify(ctx context.Context, rawIDToken string) (*IDToken, error) { - // Throw out tokens with invalid claims before trying to verify the token. This lets - // us do cheap checks before possibly re-syncing keys. - payload, err := parseJWT(rawIDToken) + var supportedSigAlgs []jose.SignatureAlgorithm + for _, alg := range v.config.SupportedSigningAlgs { + supportedSigAlgs = append(supportedSigAlgs, jose.SignatureAlgorithm(alg)) + } + if len(supportedSigAlgs) == 0 { + // If no algorithms were specified by both the config and discovery, default + // to the one mandatory algorithm "RS256". + supportedSigAlgs = []jose.SignatureAlgorithm{jose.RS256} + } + if v.config.InsecureSkipSignatureCheck { + // "none" is a required value to even parse a JWT with the "none" algorithm + // using go-jose. + supportedSigAlgs = append(supportedSigAlgs, "none") + } + + // Parse and verify the signature first. This at least forces the user to have + // a valid, signed ID token before we do any other processing. + jws, err := jose.ParseSigned(rawIDToken, supportedSigAlgs) if err != nil { return nil, fmt.Errorf("oidc: malformed jwt: %v", err) } + switch len(jws.Signatures) { + case 0: + return nil, fmt.Errorf("oidc: id token not signed") + case 1: + default: + return nil, fmt.Errorf("oidc: multiple signatures on id token not supported") + } + sig := jws.Signatures[0] + + var payload []byte + if v.config.InsecureSkipSignatureCheck { + // Yolo mode. + payload = jws.UnsafePayloadWithoutVerification() + } else { + // The JWT is attached here for the happy path to avoid the verifier from + // having to parse the JWT twice. + ctx = context.WithValue(ctx, parsedJWTKey, jws) + payload, err = v.keySet.VerifySignature(ctx, rawIDToken) + if err != nil { + return nil, fmt.Errorf("failed to verify signature: %v", err) + } + } var token idToken if err := json.Unmarshal(payload, &token); err != nil { return nil, fmt.Errorf("oidc: failed to unmarshal claims: %v", err) @@ -254,6 +275,7 @@ func (v *IDTokenVerifier) Verify(ctx context.Context, rawIDToken string) (*IDTok AccessTokenHash: token.AtHash, claims: payload, distributedClaims: distributedClaims, + sigAlgorithm: sig.Header.Algorithm, } // Check issuer. @@ -306,45 +328,6 @@ func (v *IDTokenVerifier) Verify(ctx context.Context, rawIDToken string) (*IDTok } } - if v.config.InsecureSkipSignatureCheck { - return t, nil - } - - var supportedSigAlgs []jose.SignatureAlgorithm - for _, alg := range v.config.SupportedSigningAlgs { - supportedSigAlgs = append(supportedSigAlgs, jose.SignatureAlgorithm(alg)) - } - if len(supportedSigAlgs) == 0 { - // If no algorithms were specified by both the config and discovery, default - // to the one mandatory algorithm "RS256". - supportedSigAlgs = []jose.SignatureAlgorithm{jose.RS256} - } - jws, err := jose.ParseSigned(rawIDToken, supportedSigAlgs) - if err != nil { - return nil, fmt.Errorf("oidc: malformed jwt: %v", err) - } - - switch len(jws.Signatures) { - case 0: - return nil, fmt.Errorf("oidc: id token not signed") - case 1: - default: - return nil, fmt.Errorf("oidc: multiple signatures on id token not supported") - } - sig := jws.Signatures[0] - t.sigAlgorithm = sig.Header.Algorithm - - ctx = context.WithValue(ctx, parsedJWTKey, jws) - gotPayload, err := v.keySet.VerifySignature(ctx, rawIDToken) - if err != nil { - return nil, fmt.Errorf("failed to verify signature: %v", err) - } - - // Ensure that the payload returned by the square actually matches the payload parsed earlier. - if !bytes.Equal(gotPayload, payload) { - return nil, errors.New("oidc: internal error, payload parsed did not match previous payload") - } - return t, nil } diff --git a/oidc/verify_test.go b/oidc/verify_test.go index f2e2433..725735e 100644 --- a/oidc/verify_test.go +++ b/oidc/verify_test.go @@ -580,9 +580,14 @@ func (v verificationTest) runGetToken(t *testing.T) (*IDToken, error) { if v.signKey != nil { token = v.signKey.sign(t, []byte(v.idToken)) } else { - token = base64.RawURLEncoding.EncodeToString([]byte(`{alg: "none"}`)) + // "none" still uses a second "." character, but "...MUST use the empty octet + // sequence as its JWS Signature value." + // + // https://datatracker.ietf.org/doc/html/rfc7518#section-3.6 + token = base64.RawURLEncoding.EncodeToString([]byte(`{"alg": "none"}`)) token += "." token += base64.RawURLEncoding.EncodeToString([]byte(v.idToken)) + token += "." } ctx, cancel := context.WithCancel(context.Background())