SVC Sigmoid sometimes ROC AUC from predict_proba & decision_function are each other's inverse #31222
Comments
Thanks for the report and the reproducer. I agree that this is surprising behavior (a bug?) in the built-in implementation of Platt scaling in our vendored copy of libsvm. Note that you can instead use `CalibratedClassifierCV` with `method="sigmoid"` and `ensemble=False`, as in the code below:
import numpy as np
from sklearn.svm import SVC
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_auc_score, get_scorer
from sklearn.model_selection import train_test_split
from sklearn.calibration import CalibratedClassifierCV
n_samples = 100
n_features = 100
random_state = 123
rng = np.random.default_rng(random_state)
X = rng.normal(loc=0.0, scale=1.0, size=(n_samples, n_features))
y = rng.integers(0, 2, size=n_samples)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=random_state)
svc_params = {
    "kernel": "sigmoid",
    "random_state": random_state,
}
pipeline = Pipeline(
    [
        ("scaler", StandardScaler()),
        ("svc", SVC(**svc_params)),
    ]
)
pipeline_cal = CalibratedClassifierCV(pipeline, method="sigmoid", ensemble=False)
pipeline.fit(X_train, y_train)
pipeline_cal.fit(X_train, y_train)
y_proba = pipeline_cal.predict_proba(X_test)[:, 1]
y_dec = pipeline.decision_function(X_test)
roc_auc_proba = roc_auc_score(y_test, y_proba)
roc_auc_dec = roc_auc_score(y_test, y_dec)
auc_scorer = get_scorer("roc_auc")
scorer_auc = auc_scorer(pipeline, X_test, y_test)
print(f"AUC (roc_auc_score from predict_proba) = {roc_auc_proba:.4f}")
print(f"AUC (roc_auc_score from decision_function) = {roc_auc_dec:.4f}")
print(f"AUC (get_scorer) = {scorer_auc:.4f}")
We also noticed that the Platt-scaling implementation of libsvm does not support Since we have a |
Note that if you set
This is expected because the ensembling effect does not preserve the prediction ranking vs fitting a single However, So I would still be in favor of deprecating the |
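The ensembling remark above can be illustrated with a toy sketch. It assumes, as a simplification, that an ensemble averages the sigmoid-calibrated probabilities of models fitted on different folds; the decision values are made-up numbers, not taken from this issue. A single sigmoid keeps the ranking of its own decision values, but the average of two such outputs need not.

```python
import math


def sigmoid(z):
    """Logistic function, standing in for a fitted sigmoid calibrator."""
    return 1.0 / (1.0 + math.exp(-z))


def ranking(scores):
    """Sample indices ordered from lowest to highest score."""
    return sorted(range(len(scores)), key=lambda i: scores[i])


# Hypothetical decision values for three test samples from two models
# fitted on different folds (illustrative numbers only).
dec_fold_a = [0.2, 0.5, 1.0]
dec_fold_b = [1.5, -1.0, 0.9]

proba_a = [sigmoid(d) for d in dec_fold_a]
proba_b = [sigmoid(d) for d in dec_fold_b]
# Ensemble prediction: plain average of the per-fold probabilities.
proba_avg = [(a + b) / 2.0 for a, b in zip(proba_a, proba_b)]

rank_single_dec = ranking(dec_fold_a)    # ranking from one model's decision values
rank_single_proba = ranking(proba_a)     # same model after the monotone sigmoid
rank_ensemble = ranking(proba_avg)       # ranking after averaging two models

print("single model, decision values:", rank_single_dec)
print("single model, probabilities:  ", rank_single_proba)
print("two-model average:            ", rank_ensemble)
```

Since ROC AUC depends only on the ranking, a reordering like this is enough for the AUC from an averaged ensemble to differ from the AUC of a single model's `decision_function`.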
Thank you for taking a look at this.
From
So it seems like |
I was able to get the same results on my side as well, thank you @arhall0. I agree, this behaviour may be expected and possibly we should deprecate the |
Describe the bug
Uncertain if this is a bug or counter-intuitive expected behavior.
Under certain circumstances the ROC AUC calculated for `SVC` with the `sigmoid` kernel will not agree depending on if you use `predict_proba` or `decision_function`. In fact, they will be nearly `1-other_method_auc`.
This was noticed when comparing ROC AUC calculated using `roc_auc_score` with predictions from `predict_proba(X)[:, 1]` to using the scorer from `get_scorer('roc_auc')`, which appears to be calling `roc_auc_score` with scores from `decision_function`.
Steps/Code to Reproduce
Expected Results
The measures of ROC AUC agree
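This expectation rests on ROC AUC depending only on how samples are ranked: a strictly increasing map of the decision values leaves the AUC unchanged, while a strictly decreasing map yields its complement. The sketch below shows this in plain Python with made-up labels and scores; the helper names `roc_auc` and `platt` and the slope values are hypothetical, and it only illustrates the symptom, without claiming this is what happens inside libsvm.

```python
import math


def roc_auc(y_true, scores):
    """ROC AUC as the fraction of (positive, negative) pairs where the
    positive sample scores higher; tied pairs count as half."""
    pos = [s for y, s in zip(y_true, scores) if y == 1]
    neg = [s for y, s in zip(y_true, scores) if y == 0]
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


def platt(f, a, b):
    """Platt-style sigmoid 1 / (1 + exp(a * f + b)).
    It increases with f when a < 0 and decreases with f when a > 0."""
    return 1.0 / (1.0 + math.exp(a * f + b))


# Illustrative labels and decision values (not taken from the issue).
y_true = [0, 0, 1, 0, 1, 1, 0, 1]
decision = [-1.2, -0.4, -0.3, 0.1, 0.2, 0.8, 0.9, 1.5]

auc_dec = roc_auc(y_true, decision)
# Negative slope: probabilities rank samples like the decision values.
auc_neg_slope = roc_auc(y_true, [platt(f, -2.0, 0.1) for f in decision])
# Positive slope: the ranking is reversed, so the AUC is complemented.
auc_pos_slope = roc_auc(y_true, [platt(f, 2.0, 0.1) for f in decision])

print(f"AUC from decision values      = {auc_dec:.4f}")
print(f"AUC, sigmoid with slope a < 0 = {auc_neg_slope:.4f}")
print(f"AUC, sigmoid with slope a > 0 = {auc_pos_slope:.4f}")
```

An exact complement requires an exactly reversed ranking; the "nearly" in the report suggests the real mapping is not a perfectly monotone function of the decision values on every sample.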
Actual Results
Versions