
SVC with sigmoid kernel: ROC AUC from predict_proba and decision_function are sometimes nearly each other's inverse #31222


Open
arhall0 opened this issue Apr 17, 2025 · 5 comments
Labels: Bug, Needs Investigation

Comments


arhall0 commented Apr 17, 2025

Describe the bug

I am uncertain whether this is a bug or counter-intuitive but expected behavior.

Under certain circumstances, the ROC AUC calculated for SVC with the sigmoid kernel will not agree depending on whether you use predict_proba or decision_function. In fact, each value is nearly 1 minus the other.

This was noticed when comparing the ROC AUC calculated with roc_auc_score on predictions from predict_proba(X)[:, 1] against the scorer from get_scorer('roc_auc'), which appears to call roc_auc_score with the scores from decision_function.

Steps/Code to Reproduce

import numpy as np
from sklearn.svm import SVC
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_auc_score, get_scorer
from sklearn.model_selection import train_test_split

n_samples = 100
n_features = 100
random_state = 123
rng = np.random.default_rng(random_state)

X = rng.normal(loc=0.0, scale=1.0, size=(n_samples, n_features))
y = rng.integers(0, 2, size=n_samples)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=random_state)

svc_params = {
    "kernel": "sigmoid",
    "probability": True,
    "random_state": random_state,
}
pipeline = Pipeline([
    ('scaler', StandardScaler()),
    ('svc', SVC(**svc_params))
])
pipeline.fit(X_train, y_train)
y_proba = pipeline.predict_proba(X_test)[:, 1]
y_dec = pipeline.decision_function(X_test)
roc_auc_proba = roc_auc_score(y_test, y_proba)
roc_auc_dec = roc_auc_score(y_test, y_dec)
auc_scorer = get_scorer('roc_auc')
scorer_auc = auc_scorer(pipeline, X_test, y_test)

print(f"AUC (roc_auc_score from predict_proba) = {roc_auc_proba:.4f}")
print(f"AUC (roc_auc_score from decision_function) = {roc_auc_dec:.4f}")
print(f"AUC (get_scorer) = {scorer_auc:.4f}")

Expected Results

The measures of ROC AUC agree

Actual Results

AUC (roc_auc_score from predict_proba) = 0.5833
AUC (roc_auc_score from decision_function) = 0.4295
AUC (get_scorer) = 0.4295
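
As a quick check of the "nearly each other's inverse" observation: negating the decision scores flips the ROC AUC exactly, so comparing that against the predict_proba AUC shows the two are close but not identical. This reuses y_test, y_proba and y_dec from the reproducer above:

# 1 - AUC(decision_function) equals AUC(-decision_function) exactly.
print(f"1 - AUC(decision_function) = {1 - roc_auc_score(y_test, y_dec):.4f}")
print(f"AUC(-decision_function)    = {roc_auc_score(y_test, -y_dec):.4f}")
# In this run, the probability-based AUC is close to, but not identical to, the above.
print(f"AUC(predict_proba)         = {roc_auc_score(y_test, y_proba):.4f}")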

Versions

System:
    python: 3.11.5

Python dependencies:
      sklearn: 1.7.dev0
          pip: 25.0.1
   setuptools: 65.5.0
        numpy: 1.26.4
        scipy: 1.15.2
       Cython: 3.0.12
       pandas: 2.2.3
   matplotlib: 3.10.1
       joblib: 1.2.0
threadpoolctl: 3.1.0

Built with OpenMP: True
arhall0 added the Bug and Needs Triage labels on Apr 17, 2025
ogrisel (Member) commented Apr 25, 2025

Thanks for the report and the reproducer. I agree that this is a surprising behavior (bug?) of the built-in Platt-scaling implementation in our vendored libsvm code.
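
For context, Platt scaling maps decision values f to probabilities through a fitted sigmoid p(f) = 1 / (1 + exp(A * f + B)). The following is only a rough, hypothetical sketch of that idea, not libsvm's actual implementation (which fits the sigmoid on internally cross-validated decision values and uses smoothed targets):

import numpy as np
from sklearn.linear_model import LogisticRegression

def platt_scale(dec_train, y_train, dec_test):
    """Map decision values to probabilities with a fitted sigmoid."""
    # A 1D logistic regression on the decision values fits a sigmoid
    # p = 1 / (1 + exp(-(w * f + b))) to the labels.
    lr = LogisticRegression()
    lr.fit(np.asarray(dec_train).reshape(-1, 1), y_train)
    return lr.predict_proba(np.asarray(dec_test).reshape(-1, 1))[:, 1]

A single monotone mapping like this would leave the ROC AUC unchanged (or invert it exactly if the fitted slope were negative), which is part of what makes the observed near-inversion surprising.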

Note that you can use CalibratedClassifierCV instead; this alternative does not suffer from this bug:

import numpy as np
from sklearn.svm import SVC
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_auc_score, get_scorer
from sklearn.model_selection import train_test_split
from sklearn.calibration import CalibratedClassifierCV

n_samples = 100
n_features = 100
random_state = 123
rng = np.random.default_rng(random_state)

X = rng.normal(loc=0.0, scale=1.0, size=(n_samples, n_features))
y = rng.integers(0, 2, size=n_samples)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=random_state)

svc_params = {
    "kernel": "sigmoid",
    "random_state": random_state,
}
pipeline = Pipeline(
    [
        ("scaler", StandardScaler()),
        ("svc", SVC(**svc_params)),
    ]
)
pipeline_cal = CalibratedClassifierCV(pipeline, method="sigmoid", ensemble=False)
pipeline.fit(X_train, y_train)
pipeline_cal.fit(X_train, y_train)

y_proba = pipeline_cal.predict_proba(X_test)[:, 1]
y_dec = pipeline.decision_function(X_test)
roc_auc_proba = roc_auc_score(y_test, y_proba)
roc_auc_dec = roc_auc_score(y_test, y_dec)
auc_scorer = get_scorer("roc_auc")
scorer_auc = auc_scorer(pipeline, X_test, y_test)

print(f"AUC (roc_auc_score from predict_proba) = {roc_auc_proba:.4f}")
print(f"AUC (roc_auc_score from decision_function) = {roc_auc_dec:.4f}")
print(f"AUC (get_scorer) = {scorer_auc:.4f}")
AUC (roc_auc_score from predict_proba) = 0.4295
AUC (roc_auc_score from decision_function) = 0.4295
AUC (get_scorer) = 0.4295

We also noticed that the Platt-scaling implementation of libsvm does not support sample_weight properly as documented in #16298.
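
As an illustration of the sample_weight point, CalibratedClassifierCV.fit accepts sample_weight directly and forwards it to base estimators that support it, such as SVC. A minimal, hypothetical sketch (reusing rng, X_train and y_train from the snippets above):

from sklearn.calibration import CalibratedClassifierCV
from sklearn.svm import SVC

# Hypothetical per-sample weights; SVC.fit accepts sample_weight, so the
# calibration procedure can take the weights into account as well
# (contrast with the probability=True limitation tracked in #16298).
sample_weight = rng.uniform(0.5, 2.0, size=y_train.shape[0])
svc_cal = CalibratedClassifierCV(SVC(kernel="sigmoid"), method="sigmoid", ensemble=False)
svc_cal.fit(X_train, y_train, sample_weight=sample_weight)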

Since we have a CalibratedClassifierCV that works as expected on both accounts, I would be in favor of deprecating the probability=True option of SVC and pointing our users to CalibratedClassifierCV instead.

ogrisel removed the Needs Triage label on Apr 25, 2025
ogrisel (Member) commented Apr 25, 2025

Note that if I set ensemble=True in the above code, I get:

AUC (roc_auc_score from predict_proba) = 0.5000
AUC (roc_auc_score from decision_function) = 0.4295
AUC (get_scorer) = 0.4295

This is expected because ensembling does not preserve the prediction ranking compared to fitting a single SVC model. Maybe SVC does something similar internally, in which case this is not a bug but rather the expected behavior.

However, SVC offers no control over whether to ensemble, nor over the cross-validation strategy that determines the size of the ensemble, while CalibratedClassifierCV does offer such control (in addition to correctly handling sample_weight).
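
For illustration, both knobs are explicit arguments of CalibratedClassifierCV. A hypothetical configuration, reusing the pipeline defined in the snippet above:

from sklearn.calibration import CalibratedClassifierCV
from sklearn.model_selection import StratifiedKFold

# ensemble=False: calibrate on cross-validated predictions, then refit a
# single model on all the training data.
single_model_cal = CalibratedClassifierCV(
    pipeline, method="sigmoid", cv=StratifiedKFold(n_splits=5), ensemble=False
)

# ensemble=True: keep one calibrated model per CV fold and average them;
# the cv argument controls the number of models in the ensemble.
ensembled_cal = CalibratedClassifierCV(
    pipeline, method="sigmoid", cv=StratifiedKFold(n_splits=5), ensemble=True
)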

So I would still be in favor of deprecating the probability=True option.

ogrisel (Member) commented Apr 25, 2025

cc @snath-xoc @antoinebaker

ogrisel added the Needs Investigation label on Apr 25, 2025
arhall0 (Author) commented Apr 25, 2025

Thank you for taking a look at this.

Maybe SVC does something similar internally, in which case this is not a bug but rather the expected behavior.

From the CalibratedClassifierCV documentation on the ensemble parameter (https://scikit-learn.org/stable/modules/generated/sklearn.calibration.CalibratedClassifierCV.html):

If False, cv is used to compute unbiased predictions, via cross_val_predict, which are then used for calibration. At prediction time, the classifier used is the estimator trained on all the data. Note that this method is also internally implemented in sklearn.svm estimators with the probabilities=True parameter.

So it seems that ensemble=False is what should be equivalent to probability=True according to the documentation.
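
A direct, hypothetical way to check that claimed equivalence (scaler omitted for brevity; reuses X_train, y_train, X_test and random_state from the reproducer above):

from scipy.stats import spearmanr
from sklearn.calibration import CalibratedClassifierCV
from sklearn.svm import SVC

# libsvm's built-in Platt scaling ...
svc_builtin = SVC(kernel="sigmoid", probability=True, random_state=random_state)
svc_builtin.fit(X_train, y_train)

# ... versus the documented "equivalent" CalibratedClassifierCV(..., ensemble=False).
svc_cal = CalibratedClassifierCV(
    SVC(kernel="sigmoid", random_state=random_state), method="sigmoid", ensemble=False
)
svc_cal.fit(X_train, y_train)

proba_builtin = svc_builtin.predict_proba(X_test)[:, 1]
proba_cal = svc_cal.predict_proba(X_test)[:, 1]

# If the two procedures were equivalent, the probabilities would rank the
# test samples identically and the rank correlation would be 1.
print(spearmanr(proba_builtin, proba_cal))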

snath-xoc (Contributor) commented Apr 25, 2025

I was able to reproduce the same results on my side, thank you @arhall0. I agree that this behaviour may be expected, and possibly we should deprecate the probability=True option. This would also avoid a lot of confusion when using sample weighting.
