CalibratedClassifierCV does not handle well sample_weight when ensemble=False #20610

Closed
JulienB-78 opened this issue Jul 26, 2021 · 7 comments
@JulienB-78
Contributor

JulienB-78 commented Jul 26, 2021

CalibratedClassifierCV does not handle well sample_weight with ensemble=False

In the fit method, sample_weight is not passed to cross_val_predict to generate the prediction scores (https://github.com/scikit-learn/scikit-learn/blob/2beed5584/sklearn/calibration.py#L325) whereas it is passed to fit when the classifier is refitted on the entire dataset (https://github.com/scikit-learn/scikit-learn/blob/2beed5584/sklearn/calibration.py#L328).

This makes the calibration fail, because it breaks the assumption that the classifiers built in each CV split of cross_val_predict behave similarly to the one trained on the whole dataset at the end.

To correct the bug, I suggest passing sample_weight to cross_val_predict through the fit_params dictionary:

pred_method = partial(
    cross_val_predict, estimator=this_estimator, X=X, y=y,
    cv=cv, method=method_name, n_jobs=self.n_jobs,
    fit_params={"sample_weight": sample_weight}
)
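
For readers who want to see what forwarding the weights into each split amounts to, here is a minimal by-hand sketch of out-of-fold predictions in which every fold is fitted with its own slice of sample_weight. It is not scikit-learn's internal code: the helper name oof_predict_proba, the toy dataset, and the use of LogisticRegression are all assumptions made for illustration.

```python
import numpy as np
from sklearn.base import clone
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold


def oof_predict_proba(estimator, X, y, sample_weight=None, n_splits=5):
    """Hypothetical helper: out-of-fold predict_proba with per-fold weights."""
    classes = np.unique(y)
    proba = np.zeros((len(y), len(classes)))
    splitter = StratifiedKFold(n_splits=n_splits)
    for train_idx, test_idx in splitter.split(X, y):
        fold_estimator = clone(estimator)
        if sample_weight is None:
            fold_estimator.fit(X[train_idx], y[train_idx])
        else:
            # Only the training rows of the weights are used for this fold.
            fold_estimator.fit(
                X[train_idx], y[train_idx],
                sample_weight=sample_weight[train_idx],
            )
        proba[test_idx] = fold_estimator.predict_proba(X[test_idx])
    return proba


# Toy imbalanced data (assumption): class 0 is rarer and gets weight 2.
X, y = make_classification(n_samples=600, weights=[0.3, 0.7], random_state=0)
sample_weight = np.where(y == 0, 2.0, 1.0)

unweighted = oof_predict_proba(LogisticRegression(), X, y)
weighted = oof_predict_proba(LogisticRegression(), X, y, sample_weight)

print("mean P(y=1), unweighted folds:", unweighted[:, 1].mean())
print("mean P(y=1), weighted folds:  ", weighted[:, 1].mean())
```

The two sets of scores differ, which is the point: a calibrator fitted on scores from unweighted folds is later applied to a final estimator that was trained with weights.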

Example to reproduce the issue:

import pandas as pd
import numpy as np
from matplotlib import pyplot as plt

from sklearn.datasets import make_hastie_10_2

from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.calibration import CalibratedClassifierCV

X, y = make_hastie_10_2(50000)
y[y == -1] = 0

X_0 = X[y == 0, :]
X_1 = X[y == 1, :]

y_0 = y[y == 0]
y_1 = y[y == 1]

# Discard half of the samples with y == 0
X = np.vstack([X_0[:int(len(y_1) / 2), :], X_1])
y = np.hstack([y_0[:int(len(y_1) / 2)], y_1])

# Compute weights to 'unbalance' the dataset (class 0 gets weight 2, class 1 gets weight 1)
weight = (y==0) + 1

X_train, X_test, y_train, y_test, weight_train, weight_test = train_test_split(X, y, weight)

calib = CalibratedClassifierCV(RandomForestClassifier(n_estimators=5, max_depth=3), ensemble=False, n_jobs=-1)
calib.fit(X_train, y_train, sample_weight=weight_train)

pred = calib.predict_proba(X_test)[:, 1]

# Check calibration in a way which takes into account that both classes have equal importance despite
# class 0 being less frequent

df_target_pred = pd.DataFrame([y_test, pred]).transpose()
df_target_pred.columns = ["target", "pred"]

hist_0 = np.histogram(df_target_pred.loc[df_target_pred.target == 0, 'pred'], bins=np.linspace(0, 1, 6), density=True)
hist_1 = np.histogram(df_target_pred.loc[df_target_pred.target == 1,  'pred'], bins=np.linspace(0, 1, 6), density=True)

plt.bar(hist_0[1][:-1], hist_0[0], align='edge', label='0', alpha=0.5, width=0.2)
plt.bar(hist_1[1][:-1], hist_1[0], align='edge', label='1', alpha=0.5, width=0.2)
plt.ylabel('Prediction histograms')
plt.xlabel('Predicted score')
plt.legend()
ax2 = plt.gca().twinx()
ax2.plot([0, 1], [0, 1], "k:", label="Perfectly calibrated")
ax2.plot((hist_0[1][:-1] + hist_0[1][1:]) / 2, hist_1[0] / (hist_0[0] + hist_1[0] + 1e-10))
ax2.grid(False)
ax2.set_ylabel('Fraction of positives')
plt.show()
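
As a rough illustration of why the mismatch matters in the script above: the final estimator is fitted with weights, so it sees the two classes as about equally important, while the unweighted cross-validation fits see about two positives for every negative. The counts below are hypothetical round numbers, not values taken from the run.

```python
# Hypothetical class counts mimicking the script:
# class 0 has half as many samples as class 1.
n_0, n_1 = 500, 1000
w_0, w_1 = 2, 1  # same weighting as weight = (y == 0) + 1

# Fraction of positives seen by a fit without sample_weight.
unweighted_positive_fraction = n_1 / (n_0 + n_1)

# Fraction of positives seen by a fit with sample_weight.
weighted_positive_fraction = (w_1 * n_1) / (w_0 * n_0 + w_1 * n_1)

print(f"unweighted: {unweighted_positive_fraction:.3f}")  # 0.667
print(f"weighted:   {weighted_positive_fraction:.3f}")    # 0.500
```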

Versions

System:
python: 3.7.10 (default, Feb 26 2021, 13:06:18) [MSC v.1916 64 bit (AMD64)]
executable: C:\HOMEWARE\Anaconda3-Windows-x86_64\envs\python37\python.exe
machine: Windows-10-10.0.18362-SP0

Python dependencies:
pip: 21.1.3
setuptools: 52.0.0.post20210125
sklearn: 0.24.2
numpy: 1.20.2
scipy: 1.6.2
Cython: None
pandas: 1.2.5
matplotlib: 3.3.4
joblib: 1.0.1
threadpoolctl: 2.1.0

Built with OpenMP: True

@glemaitre
Member

Yep, I assume that we should not support this case, since we currently cannot pass any sample_weight during the cross-validation.

@glemaitre added the Bug label and removed the Bug: triage label on Jul 27, 2021
@JulienB-78
Contributor Author

It seems to me that it is possible to pass sample_weight to the classifier fitting performed inside cross_val_predict by using the parameter fit_params={"sample_weight": sample_weight}.

I replaced the content of the else branch starting at https://github.com/scikit-learn/scikit-learn/blob/2beed5584/sklearn/calibration.py#L319 with the following code:

this_estimator = clone(base_estimator)
method_name = _get_prediction_method(this_estimator).__name__

if sample_weight is not None and supports_sw:
    # Forward the weights to every cross-validation fit
    # as well as to the final fit on the whole dataset.
    pred_method = partial(
        cross_val_predict, estimator=this_estimator, X=X, y=y,
        cv=cv, method=method_name, n_jobs=self.n_jobs,
        fit_params={"sample_weight": sample_weight}
    )
    this_estimator.fit(X, y, sample_weight)
else:
    # No weights given, or the estimator cannot use them.
    pred_method = partial(
        cross_val_predict, estimator=this_estimator, X=X, y=y,
        cv=cv, method=method_name, n_jobs=self.n_jobs
    )
    this_estimator.fit(X, y)
predictions = _compute_predictions(pred_method, X, n_classes)
calibrated_classifier = _fit_calibrator(
    this_estimator, predictions, y, self.classes_, self.method,
    sample_weight
)
self.calibrated_classifiers_.append(calibrated_classifier)

and it does correct the problem.

Do you think it could have any negative side effect?

@glemaitre
Member

Ah right, we can pass it through fit_params.

> Do you think it could have any negative side effect?

I don't see any. As you mentioned, this is what we do in the ensemble branch.

@glemaitre
Member

Do you wish to make a PR that implements the fix you propose, together with a test that checks that the behaviour is fine?
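
One possible shape for such a test, offered only as a sketch: wrap a classifier so that it records whether each fit call received sample_weight, then fit CalibratedClassifierCV with ensemble=False. The class RecordingLogisticRegression and the toy data are hypothetical helpers, not part of scikit-learn, and the sketch assumes the default sequential execution (no n_jobs) so that all clones log to the same in-process list.

```python
import numpy as np
from sklearn.calibration import CalibratedClassifierCV
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression


class RecordingLogisticRegression(LogisticRegression):
    """Hypothetical helper: logs whether each fit call got sample_weight."""

    # Shared at class level so that clones log to the same list.
    fit_calls = []

    def fit(self, X, y, sample_weight=None):
        type(self).fit_calls.append(sample_weight is not None)
        return super().fit(X, y, sample_weight=sample_weight)


X, y = make_classification(n_samples=200, random_state=0)
sample_weight = np.where(y == 0, 2.0, 1.0)

RecordingLogisticRegression.fit_calls.clear()
calib = CalibratedClassifierCV(
    RecordingLogisticRegression(), cv=3, ensemble=False
)
calib.fit(X, y, sample_weight=sample_weight)

# One entry per fit call: three cross-validation fits plus the final fit.
# With the behaviour reported in this issue, the cross-validation entries
# would be False; with the proposed fix, every entry is expected to be True.
print(RecordingLogisticRegression.fit_calls)
```

In an actual test, the final print would presumably become an assertion such as assert all(RecordingLogisticRegression.fit_calls).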

@JulienB-78
Contributor Author

Yes, I have started reading the contributing guidelines in order to do so.

@adrinjalali
Member

@BenjaminBossan I think this is fixed with your PR in #24126 (or was fixed before). Could you please confirm?

@JulienB-78
Contributor Author

JulienB-78 commented Aug 18, 2022 via email
