SLEP006: CalibratedClassifierCV #24126
sklearn/calibration.py

```diff
@@ -8,7 +8,6 @@
 # License: BSD 3 clause

 import warnings
-from inspect import signature
 from functools import partial

 from math import log
```
```diff
@@ -49,6 +48,7 @@
 from .model_selection import check_cv, cross_val_predict
 from .metrics._base import _check_pos_label_consistency
 from .metrics._plot.base import _get_response
+from .utils.metadata_routing import MetadataRouter, MethodMapping, process_routing


 class CalibratedClassifierCV(ClassifierMixin, MetaEstimatorMixin, BaseEstimator):
```
```diff
@@ -259,6 +259,31 @@ def __init__(
         self.ensemble = ensemble
         self.base_estimator = base_estimator

+    def _get_estimator(self):
+        """Resolve which estimator to return (default is LinearSVC)"""
+        # TODO(1.4): Remove when base_estimator is removed
+        if self.base_estimator != "deprecated":
+            if self.estimator is not None:
+                raise ValueError(
+                    "Both `base_estimator` and `estimator` are set. Only set "
+                    "`estimator` since `base_estimator` is deprecated."
+                )
+            warnings.warn(
+                "`base_estimator` was renamed to `estimator` in version 1.2 and "
+                "will be removed in 1.4.",
+                FutureWarning,
+            )
+            estimator = self.base_estimator
+        else:
+            estimator = self.estimator
+
+        if estimator is None:
+            # we want all classifiers that don't expose a random_state
+            # to be deterministic (and we don't want to expose this one).
+            estimator = LinearSVC(random_state=0).set_fit_request(sample_weight=True)
+
+        return estimator
+
     def fit(self, X, y, sample_weight=None, **fit_params):
         """Fit the calibrated model.
```
```diff
@@ -290,26 +315,7 @@ def fit(self, X, y, sample_weight=None, **fit_params):
         for sample_aligned_params in fit_params.values():
             check_consistent_length(y, sample_aligned_params)

-        # TODO(1.4): Remove when base_estimator is removed
-        if self.base_estimator != "deprecated":
-            if self.estimator is not None:
-                raise ValueError(
-                    "Both `base_estimator` and `estimator` are set. Only set "
-                    "`estimator` since `base_estimator` is deprecated."
-                )
-            warnings.warn(
-                "`base_estimator` was renamed to `estimator` in version 1.2 and "
-                "will be removed in 1.4.",
-                FutureWarning,
-            )
-            estimator = self.base_estimator
-        else:
-            estimator = self.estimator
-
-        if estimator is None:
-            # we want all classifiers that don't expose a random_state
-            # to be deterministic (and we don't want to expose this one).
-            estimator = LinearSVC(random_state=0)
+        estimator = self._get_estimator()

         self.calibrated_classifiers_ = []
         if self.cv == "prefit":
```
```diff
@@ -336,20 +342,12 @@ def fit(self, X, y, sample_weight=None, **fit_params):
         self.classes_ = label_encoder_.classes_
         n_classes = len(self.classes_)

-        # sample_weight checks
-        fit_parameters = signature(estimator.fit).parameters
-        supports_sw = "sample_weight" in fit_parameters
-        if sample_weight is not None and not supports_sw:
-            estimator_name = type(estimator).__name__
-            warnings.warn(
-                f"Since {estimator_name} does not appear to accept sample_weight, "
-                "sample weights will only be used for the calibration itself. This "
-                "can be caused by a limitation of the current scikit-learn API. "
-                "See the following issue for more details: "
-                "https://github.com/scikit-learn/scikit-learn/issues/21134. Be "
-                "warned that the result of the calibration is likely to be "
-                "incorrect."
-            )
+        routed_params = process_routing(
+            obj=self,
+            method="fit",
+            sample_weight=sample_weight,
+            other_params=fit_params,
+        )

         # Check that each cross-validation fold can have at least one
         # example per class
```

**Review comment on lines -339 to -352:** note: metadata routing removes the need for this warning. The user will get the right warnings / errors if the metadata is not requested properly.
```diff
@@ -380,20 +378,14 @@ def fit(self, X, y, sample_weight=None, **fit_params):
                     test=test,
                     method=self.method,
                     classes=self.classes_,
-                    supports_sw=supports_sw,
-                    sample_weight=sample_weight,
-                    **fit_params,
+                    fit_params=routed_params.estimator.fit,
                 )
-                for train, test in cv.split(X, y)
+                for train, test in cv.split(X, y, **routed_params.splitter.split)
             )
         else:
             this_estimator = clone(estimator)
             _, method_name = _get_prediction_method(this_estimator)
-            fit_params = (
-                {"sample_weight": sample_weight}
-                if sample_weight is not None and supports_sw
-                else None
-            )
             pred_method = partial(
                 cross_val_predict,
                 estimator=this_estimator,
```

**Review comment on `supports_sw=supports_sw`:** note: we don't need this parameter since routing will know what to route and what not.
```diff
@@ -402,16 +394,13 @@ def fit(self, X, y, sample_weight=None, **fit_params):
                 cv=cv,
                 method=method_name,
                 n_jobs=self.n_jobs,
-                fit_params=fit_params,
+                fit_params=routed_params.estimator.fit,
             )
             predictions = _compute_predictions(
                 pred_method, method_name, X, n_classes
             )

-            if sample_weight is not None and supports_sw:
-                this_estimator.fit(X, y, sample_weight=sample_weight)
-            else:
-                this_estimator.fit(X, y)
+            this_estimator.fit(X, y, **routed_params.estimator.fit)
             # Note: Here we don't pass on fit_params because the supported
             # calibrators don't support fit_params anyway
             calibrated_classifier = _fit_calibrator(
```
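For readers unfamiliar with the container used above: `process_routing` validates the metadata passed to `fit` and returns an object addressable as `routed_params.<child>.<method>`, where the child names (`estimator`, `splitter`) match those registered in `get_metadata_routing` below. A quick way to inspect that mapping, sketched with the routing API as later released (which may differ in detail from this branch):

```python
import sklearn
from sklearn.calibration import CalibratedClassifierCV
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GroupKFold

sklearn.set_config(enable_metadata_routing=True)

clf = CalibratedClassifierCV(
    LogisticRegression().set_fit_request(sample_weight=True),
    cv=GroupKFold(n_splits=2),
)
# Prints, per child ("estimator", "splitter"), the caller->callee method
# mapping and the metadata each child requests (sample_weight, groups).
print(clf.get_metadata_routing())
```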
```diff
@@ -478,6 +467,37 @@ def predict(self, X):
         check_is_fitted(self)
         return self.classes_[np.argmax(self.predict_proba(X), axis=1)]

+    def get_metadata_routing(self):
+        """Get metadata routing of this object.
+
+        Please check :ref:`User Guide <metadata_routing>` on how the routing
+        mechanism works.
+
+        Returns
+        -------
+        routing : MetadataRouter
+            A :class:`~utils.metadata_routing.MetadataRouter` encapsulating
+            routing information.
+        """
+        router = (
+            MetadataRouter(owner=self.__class__.__name__)
+            .add_self(self)
+            .add(
+                estimator=self._get_estimator(),
+                method_mapping=MethodMapping().add(callee="fit", caller="fit"),
+            )
+            .add(
+                splitter=self.cv,
+                method_mapping=MethodMapping().add(callee="split", caller="fit"),
+            )
+            # the fit method already accepts everything, therefore we don't
+            # specify parameters. The value passed to ``child`` needs to be the
+            # same as what's passed to ``add`` above, in this case
+            # `"estimator"`.
+            .warn_on(child="estimator", method="fit", params=None)
+        )
+        return router
+
     def _more_tags(self):
         return {
             "_xfail_checks": {
```

**Review comment on `.add_self(self)`:** note: self is added since this CCV is both a consumer and a router. One can do weighted CCV but unweighted fit for the underlying estimator.
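A sketch of the scenario that comment describes, again assuming routing is available and enabled as in released scikit-learn versions: the weights are consumed by `CalibratedClassifierCV` itself (the `self` added to the router) for the calibrators, while the inner estimator is deliberately fit unweighted.

```python
import numpy as np
import sklearn
from sklearn.calibration import CalibratedClassifierCV
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

sklearn.set_config(enable_metadata_routing=True)

X, y = make_classification(n_samples=100, random_state=0)
sample_weight = np.ones(len(y))

# `sample_weight=False` marks the inner fit as intentionally unweighted:
# routing neither forwards the weights nor raises for unrequested metadata,
# while CalibratedClassifierCV still uses the weights for the calibration.
inner = LogisticRegression().set_fit_request(sample_weight=False)
CalibratedClassifierCV(inner).fit(X, y, sample_weight=sample_weight)
```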
```diff
@@ -496,11 +516,10 @@ def _fit_classifier_calibrator_pair(
     y,
     train,
     test,
-    supports_sw,
     method,
     classes,
     sample_weight=None,
-    **fit_params,
+    fit_params=None,
 ):
     """Fit a classifier/calibration pair on a given train/test split.
```
||
|
@@ -525,9 +544,6 @@ def _fit_classifier_calibrator_pair( | |
test : ndarray, shape (n_test_indices,) | ||
Indices of the testing subset. | ||
|
||
supports_sw : bool | ||
Whether or not the `estimator` supports sample weights. | ||
|
||
method : {'sigmoid', 'isotonic'} | ||
Method to use for calibration. | ||
|
||
|
```diff
@@ -537,7 +553,7 @@ def _fit_classifier_calibrator_pair(
     sample_weight : array-like, default=None
         Sample weights for `X`.

-    **fit_params : dict
+    fit_params : dict, default=None
         Parameters to pass to the `fit` method of the underlying
         classifier.
```
```diff
@@ -549,11 +565,7 @@ def _fit_classifier_calibrator_pair(
     X_train, y_train = _safe_indexing(X, train), _safe_indexing(y, train)
     X_test, y_test = _safe_indexing(X, test), _safe_indexing(y, test)

-    if sample_weight is not None and supports_sw:
-        sw_train = _safe_indexing(sample_weight, train)
-        estimator.fit(X_train, y_train, sample_weight=sw_train, **fit_params_train)
-    else:
-        estimator.fit(X_train, y_train, **fit_params_train)
+    estimator.fit(X_train, y_train, **fit_params_train)

     n_classes = len(classes)
     pred_method, method_name = _get_prediction_method(estimator)
```
sklearn/tests/test_calibration.py

```diff
@@ -19,7 +19,7 @@
 from sklearn.exceptions import NotFittedError
 from sklearn.datasets import make_classification, make_blobs, load_iris
 from sklearn.preprocessing import LabelEncoder
-from sklearn.model_selection import KFold, cross_val_predict
+from sklearn.model_selection import GroupKFold, KFold, cross_val_predict
 from sklearn.naive_bayes import MultinomialNB
 from sklearn.ensemble import (
     RandomForestClassifier,
```
```diff
@@ -50,6 +50,10 @@
 N_SAMPLES = 200


+def _weighted(estimator):
+    return estimator.set_fit_request(sample_weight=True)
+
+
 @pytest.fixture(scope="module")
 def data():
     X, y = make_classification(n_samples=N_SAMPLES, n_features=6, random_state=42)
```

**Review comment on lines +53 to +54:** note: only …
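For context, an illustration that is not part of the diff: `set_fit_request` only records a routing request on the estimator; it does not change how `fit` itself behaves (assumes routing is enabled as in released versions).

```python
import sklearn
from sklearn.linear_model import LogisticRegression

sklearn.set_config(enable_metadata_routing=True)

est = LogisticRegression().set_fit_request(sample_weight=True)
# The request is now visible in the estimator's routing metadata; `fit`
# accepts and uses sample_weight exactly as before.
print(est.get_metadata_routing())
```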
```diff
@@ -83,7 +87,9 @@ def test_calibration(data, method, ensemble):
         (X_train, X_test),
         (sparse.csr_matrix(X_train), sparse.csr_matrix(X_test)),
     ]:
-        cal_clf = CalibratedClassifierCV(clf, method=method, cv=5, ensemble=ensemble)
+        cal_clf = CalibratedClassifierCV(
+            _weighted(clf), method=method, cv=5, ensemble=ensemble
+        )
         # Note that this fit overwrites the fit on the entire training
         # set
         cal_clf.fit(this_X_train, y_train, sample_weight=sw_train)
```
```diff
@@ -175,7 +181,7 @@ def test_sample_weight(data, method, ensemble):
     X_train, y_train, sw_train = X[:n_samples], y[:n_samples], sample_weight[:n_samples]
     X_test = X[n_samples:]

-    estimator = LinearSVC(random_state=42)
+    estimator = _weighted(LinearSVC(random_state=42))
     calibrated_clf = CalibratedClassifierCV(estimator, method=method, ensemble=ensemble)
     calibrated_clf.fit(X_train, y_train, sample_weight=sw_train)
     probs_with_sw = calibrated_clf.predict_proba(X_test)
```
```diff
@@ -909,7 +915,7 @@ def test_calibrated_classifier_cv_double_sample_weights_equivalence(method, ense
     y_twice[::2] = y
     y_twice[1::2] = y

-    estimator = LogisticRegression()
+    estimator = _weighted(LogisticRegression())
     calibrated_clf_without_weights = CalibratedClassifierCV(
         estimator,
         method=method,
```
```diff
@@ -951,7 +957,9 @@ def test_calibration_with_fit_params(fit_params_type, data):
         "b": _convert_container(y, fit_params_type),
     }

-    clf = CheckingClassifier(expected_fit_params=["a", "b"])
+    clf = CheckingClassifier(expected_fit_params=["a", "b"]).set_fit_request(
+        a=True, b=True
+    )
     pc_clf = CalibratedClassifierCV(clf)

     pc_clf.fit(X, y, **fit_params)
```
```diff
@@ -969,34 +977,12 @@ def test_calibration_with_sample_weight_base_estimator(sample_weight, data):
     estimator.
     """
     X, y = data
-    clf = CheckingClassifier(expected_sample_weight=True)
+    clf = _weighted(CheckingClassifier(expected_sample_weight=True))
     pc_clf = CalibratedClassifierCV(clf)

     pc_clf.fit(X, y, sample_weight=sample_weight)


-def test_calibration_without_sample_weight_base_estimator(data):
-    """Check that even if the estimator doesn't support
-    sample_weight, fitting with sample_weight still works.
-
-    There should be a warning, since the sample_weight is not passed
-    on to the estimator.
-    """
-    X, y = data
-    sample_weight = np.ones_like(y)
-
-    class ClfWithoutSampleWeight(CheckingClassifier):
-        def fit(self, X, y, **fit_params):
-            assert "sample_weight" not in fit_params
-            return super().fit(X, y, **fit_params)
-
-    clf = ClfWithoutSampleWeight()
-    pc_clf = CalibratedClassifierCV(clf)
-
-    with pytest.warns(UserWarning):
-        pc_clf.fit(X, y, sample_weight=sample_weight)
-
-
 def test_calibration_with_fit_params_inconsistent_length(data):
     """fit_params having different length than data should raise the
     correct error message.
```
```diff
@@ -1029,7 +1015,7 @@ def test_calibrated_classifier_cv_zeros_sample_weights_equivalence(method, ensem
     sample_weight = np.zeros_like(y)
     sample_weight[::2] = 1

-    estimator = LogisticRegression()
+    estimator = _weighted(LogisticRegression())
     calibrated_clf_without_weights = CalibratedClassifierCV(
         estimator,
         method=method,
```
```diff
@@ -1077,3 +1063,13 @@ def test_calibrated_classifier_deprecation_base_estimator(data):
     warn_msg = "`base_estimator` was renamed to `estimator`"
     with pytest.warns(FutureWarning, match=warn_msg):
         calibrated_classifier.fit(*data)
+
+
+def test_calibration_groupkfold(data):
+    # Check that groups are routed to the splitter
+    X, y = data
+    groups = np.array([0, 1] * (len(y) // 2))  # assumes len(y) is even
+    cv = GroupKFold(n_splits=2)
+    calib_clf = CalibratedClassifierCV(cv=cv)
+    # check that fitting does not raise an error
+    calib_clf.fit(X, y, groups=groups)
```
**Review comment on `_get_estimator`:** note to other reviewers: this is only a refactoring. Used in `fit` and `get_metadata_routing`.