Skip to content

SLEP006: CalibratedClassifierCV #24126

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
dc70216
Make CalibratedClassifierCV work with SLEP006
BenjaminBossan Aug 3, 2022
0d62945
Fix black complaints
BenjaminBossan Aug 5, 2022
3599e16
Merge branch 'sample-props' into slep006/calibratedclassifiercv
BenjaminBossan Aug 8, 2022
55488ca
Address reviewer comments by Adrin
BenjaminBossan Aug 8, 2022
f81389a
Address reviewer comment: warns_on
BenjaminBossan Aug 9, 2022
d7680f1
Address reviewer comment: recording None
BenjaminBossan Aug 10, 2022
aceb4b4
__copy__ has no arguments
BenjaminBossan Aug 10, 2022
d024a18
For expected errors, test each arg seperately
BenjaminBossan Aug 10, 2022
f9976ee
Fit the metaestimator if not done by method
BenjaminBossan Aug 10, 2022
4bf98e9
Merge branch 'sample-props' into slep006/calibratedclassifiercv
BenjaminBossan Aug 10, 2022
a353671
Ignore type checking on CheckingClassifier
BenjaminBossan Aug 15, 2022
d75b368
Add routing of groups to splitter
BenjaminBossan Aug 15, 2022
c3a1d67
Black formatting
BenjaminBossan Aug 15, 2022
ce3dabd
Fix typo
BenjaminBossan Aug 16, 2022
ad1b062
Reviewer request: Change checking of routed data
BenjaminBossan Aug 16, 2022
fdd7a89
Empty-Commit
BenjaminBossan Aug 16, 2022
b8d459d
CI empty-commit
thomasjpfan Aug 16, 2022
62867d9
Yet another empty commit to try if CI works
BenjaminBossan Aug 16, 2022
57b292b
New empty commit to try if CI works now
BenjaminBossan Aug 17, 2022
98f9904
Empty commit after creating CircleCI acct
BenjaminBossan Aug 17, 2022
0a54eb8
Don't explicitly check groups values in test
BenjaminBossan Aug 17, 2022
8ae3475
trigger ci
adrinjalali Aug 17, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 71 additions & 59 deletions sklearn/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
# License: BSD 3 clause

import warnings
from inspect import signature
from functools import partial

from math import log
Expand Down Expand Up @@ -49,6 +48,7 @@
from .model_selection import check_cv, cross_val_predict
from .metrics._base import _check_pos_label_consistency
from .metrics._plot.base import _get_response
from .utils.metadata_routing import MetadataRouter, MethodMapping, process_routing


class CalibratedClassifierCV(ClassifierMixin, MetaEstimatorMixin, BaseEstimator):
Expand Down Expand Up @@ -259,6 +259,31 @@ def __init__(
self.ensemble = ensemble
self.base_estimator = base_estimator

def _get_estimator(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note to other reviewers: this is only a refactoring. Used in fit and get_metadata_routing.

"""Resolve which estimator to return (default is LinearSVC)"""
# TODO(1.4): Remove when base_estimator is removed
if self.base_estimator != "deprecated":
if self.estimator is not None:
raise ValueError(
"Both `base_estimator` and `estimator` are set. Only set "
"`estimator` since `base_estimator` is deprecated."
)
warnings.warn(
"`base_estimator` was renamed to `estimator` in version 1.2 and "
"will be removed in 1.4.",
FutureWarning,
)
estimator = self.base_estimator
else:
estimator = self.estimator

if estimator is None:
# we want all classifiers that don't expose a random_state
# to be deterministic (and we don't want to expose this one).
estimator = LinearSVC(random_state=0).set_fit_request(sample_weight=True)

return estimator

def fit(self, X, y, sample_weight=None, **fit_params):
"""Fit the calibrated model.

Expand Down Expand Up @@ -290,26 +315,7 @@ def fit(self, X, y, sample_weight=None, **fit_params):
for sample_aligned_params in fit_params.values():
check_consistent_length(y, sample_aligned_params)

# TODO(1.4): Remove when base_estimator is removed
if self.base_estimator != "deprecated":
if self.estimator is not None:
raise ValueError(
"Both `base_estimator` and `estimator` are set. Only set "
"`estimator` since `base_estimator` is deprecated."
)
warnings.warn(
"`base_estimator` was renamed to `estimator` in version 1.2 and "
"will be removed in 1.4.",
FutureWarning,
)
estimator = self.base_estimator
else:
estimator = self.estimator

if estimator is None:
# we want all classifiers that don't expose a random_state
# to be deterministic (and we don't want to expose this one).
estimator = LinearSVC(random_state=0)
estimator = self._get_estimator()

self.calibrated_classifiers_ = []
if self.cv == "prefit":
Expand All @@ -336,20 +342,12 @@ def fit(self, X, y, sample_weight=None, **fit_params):
self.classes_ = label_encoder_.classes_
n_classes = len(self.classes_)

# sample_weight checks
fit_parameters = signature(estimator.fit).parameters
supports_sw = "sample_weight" in fit_parameters
if sample_weight is not None and not supports_sw:
estimator_name = type(estimator).__name__
warnings.warn(
f"Since {estimator_name} does not appear to accept sample_weight, "
"sample weights will only be used for the calibration itself. This "
"can be caused by a limitation of the current scikit-learn API. "
"See the following issue for more details: "
"https://github.com/scikit-learn/scikit-learn/issues/21134. Be "
"warned that the result of the calibration is likely to be "
"incorrect."
)
Comment on lines -339 to -352
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: metadata routing removes the need for this warning. The user will get the right warnings / errors if the metadata is not requested properly.

routed_params = process_routing(
obj=self,
method="fit",
sample_weight=sample_weight,
other_params=fit_params,
)

# Check that each cross-validation fold can have at least one
# example per class
Expand Down Expand Up @@ -380,20 +378,14 @@ def fit(self, X, y, sample_weight=None, **fit_params):
test=test,
method=self.method,
classes=self.classes_,
supports_sw=supports_sw,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: we don't need this parameter since routing will know what to route and what not.

sample_weight=sample_weight,
**fit_params,
fit_params=routed_params.estimator.fit,
)
for train, test in cv.split(X, y)
for train, test in cv.split(X, y, **routed_params.splitter.split)
)
else:
this_estimator = clone(estimator)
_, method_name = _get_prediction_method(this_estimator)
fit_params = (
{"sample_weight": sample_weight}
if sample_weight is not None and supports_sw
else None
)
pred_method = partial(
cross_val_predict,
estimator=this_estimator,
Expand All @@ -402,16 +394,13 @@ def fit(self, X, y, sample_weight=None, **fit_params):
cv=cv,
method=method_name,
n_jobs=self.n_jobs,
fit_params=fit_params,
fit_params=routed_params.estimator.fit,
)
predictions = _compute_predictions(
pred_method, method_name, X, n_classes
)

if sample_weight is not None and supports_sw:
this_estimator.fit(X, y, sample_weight=sample_weight)
else:
this_estimator.fit(X, y)
this_estimator.fit(X, y, **routed_params.estimator.fit)
# Note: Here we don't pass on fit_params because the supported
# calibrators don't support fit_params anyway
calibrated_classifier = _fit_calibrator(
Expand Down Expand Up @@ -478,6 +467,37 @@ def predict(self, X):
check_is_fitted(self)
return self.classes_[np.argmax(self.predict_proba(X), axis=1)]

def get_metadata_routing(self):
"""Get metadata routing of this object.

Please check :ref:`User Guide <metadata_routing>` on how the routing
mechanism works.

Returns
-------
routing : MetadataRouter
A :class:`~utils.metadata_routing.MetadataRouter` encapsulating
routing information.
"""
router = (
MetadataRouter(owner=self.__class__.__name__)
.add_self(self)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: self is added since this CCV is both a consumer and a router. One can do weighted CCV but unweighted fit for the underlying estimator.

.add(
estimator=self._get_estimator(),
method_mapping=MethodMapping().add(callee="fit", caller="fit"),
)
.add(
splitter=self.cv,
method_mapping=MethodMapping().add(callee="split", caller="fit"),
)
# the fit method already accepts everything, therefore we don't
# specify parameters. The value passed to ``child`` needs to be the
# same as what's passed to ``add`` above, in this case
# `"estimator"`.
.warn_on(child="estimator", method="fit", params=None)
)
return router

def _more_tags(self):
return {
"_xfail_checks": {
Expand All @@ -496,11 +516,10 @@ def _fit_classifier_calibrator_pair(
y,
train,
test,
supports_sw,
method,
classes,
sample_weight=None,
**fit_params,
fit_params=None,
):
"""Fit a classifier/calibration pair on a given train/test split.

Expand All @@ -525,9 +544,6 @@ def _fit_classifier_calibrator_pair(
test : ndarray, shape (n_test_indices,)
Indices of the testing subset.

supports_sw : bool
Whether or not the `estimator` supports sample weights.

method : {'sigmoid', 'isotonic'}
Method to use for calibration.

Expand All @@ -537,7 +553,7 @@ def _fit_classifier_calibrator_pair(
sample_weight : array-like, default=None
Sample weights for `X`.

**fit_params : dict
fit_params : dict, default=None
Parameters to pass to the `fit` method of the underlying
classifier.

Expand All @@ -549,11 +565,7 @@ def _fit_classifier_calibrator_pair(
X_train, y_train = _safe_indexing(X, train), _safe_indexing(y, train)
X_test, y_test = _safe_indexing(X, test), _safe_indexing(y, test)

if sample_weight is not None and supports_sw:
sw_train = _safe_indexing(sample_weight, train)
estimator.fit(X_train, y_train, sample_weight=sw_train, **fit_params_train)
else:
estimator.fit(X_train, y_train, **fit_params_train)
estimator.fit(X_train, y_train, **fit_params_train)

n_classes = len(classes)
pred_method, method_name = _get_prediction_method(estimator)
Expand Down
54 changes: 25 additions & 29 deletions sklearn/tests/test_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from sklearn.exceptions import NotFittedError
from sklearn.datasets import make_classification, make_blobs, load_iris
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import KFold, cross_val_predict
from sklearn.model_selection import GroupKFold, KFold, cross_val_predict
from sklearn.naive_bayes import MultinomialNB
from sklearn.ensemble import (
RandomForestClassifier,
Expand Down Expand Up @@ -50,6 +50,10 @@
N_SAMPLES = 200


def _weighted(estimator):
return estimator.set_fit_request(sample_weight=True)
Comment on lines +53 to +54
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: only fit can be weighted in CCV, hence only requesting sample_weight for fit.



@pytest.fixture(scope="module")
def data():
X, y = make_classification(n_samples=N_SAMPLES, n_features=6, random_state=42)
Expand Down Expand Up @@ -83,7 +87,9 @@ def test_calibration(data, method, ensemble):
(X_train, X_test),
(sparse.csr_matrix(X_train), sparse.csr_matrix(X_test)),
]:
cal_clf = CalibratedClassifierCV(clf, method=method, cv=5, ensemble=ensemble)
cal_clf = CalibratedClassifierCV(
_weighted(clf), method=method, cv=5, ensemble=ensemble
)
# Note that this fit overwrites the fit on the entire training
# set
cal_clf.fit(this_X_train, y_train, sample_weight=sw_train)
Expand Down Expand Up @@ -175,7 +181,7 @@ def test_sample_weight(data, method, ensemble):
X_train, y_train, sw_train = X[:n_samples], y[:n_samples], sample_weight[:n_samples]
X_test = X[n_samples:]

estimator = LinearSVC(random_state=42)
estimator = _weighted(LinearSVC(random_state=42))
calibrated_clf = CalibratedClassifierCV(estimator, method=method, ensemble=ensemble)
calibrated_clf.fit(X_train, y_train, sample_weight=sw_train)
probs_with_sw = calibrated_clf.predict_proba(X_test)
Expand Down Expand Up @@ -909,7 +915,7 @@ def test_calibrated_classifier_cv_double_sample_weights_equivalence(method, ense
y_twice[::2] = y
y_twice[1::2] = y

estimator = LogisticRegression()
estimator = _weighted(LogisticRegression())
calibrated_clf_without_weights = CalibratedClassifierCV(
estimator,
method=method,
Expand Down Expand Up @@ -951,7 +957,9 @@ def test_calibration_with_fit_params(fit_params_type, data):
"b": _convert_container(y, fit_params_type),
}

clf = CheckingClassifier(expected_fit_params=["a", "b"])
clf = CheckingClassifier(expected_fit_params=["a", "b"]).set_fit_request(
a=True, b=True
)
pc_clf = CalibratedClassifierCV(clf)

pc_clf.fit(X, y, **fit_params)
Expand All @@ -969,34 +977,12 @@ def test_calibration_with_sample_weight_base_estimator(sample_weight, data):
estimator.
"""
X, y = data
clf = CheckingClassifier(expected_sample_weight=True)
clf = _weighted(CheckingClassifier(expected_sample_weight=True))
pc_clf = CalibratedClassifierCV(clf)

pc_clf.fit(X, y, sample_weight=sample_weight)


def test_calibration_without_sample_weight_base_estimator(data):
"""Check that even if the estimator doesn't support
sample_weight, fitting with sample_weight still works.

There should be a warning, since the sample_weight is not passed
on to the estimator.
"""
X, y = data
sample_weight = np.ones_like(y)

class ClfWithoutSampleWeight(CheckingClassifier):
def fit(self, X, y, **fit_params):
assert "sample_weight" not in fit_params
return super().fit(X, y, **fit_params)

clf = ClfWithoutSampleWeight()
pc_clf = CalibratedClassifierCV(clf)

with pytest.warns(UserWarning):
pc_clf.fit(X, y, sample_weight=sample_weight)


def test_calibration_with_fit_params_inconsistent_length(data):
"""fit_params having different length than data should raise the
correct error message.
Expand Down Expand Up @@ -1029,7 +1015,7 @@ def test_calibrated_classifier_cv_zeros_sample_weights_equivalence(method, ensem
sample_weight = np.zeros_like(y)
sample_weight[::2] = 1

estimator = LogisticRegression()
estimator = _weighted(LogisticRegression())
calibrated_clf_without_weights = CalibratedClassifierCV(
estimator,
method=method,
Expand Down Expand Up @@ -1077,3 +1063,13 @@ def test_calibrated_classifier_deprecation_base_estimator(data):
warn_msg = "`base_estimator` was renamed to `estimator`"
with pytest.warns(FutureWarning, match=warn_msg):
calibrated_classifier.fit(*data)


def test_calibration_groupkfold(data):
# Check that groups are routed to the splitter
X, y = data
groups = np.array([0, 1] * (len(y) // 2)) # assumes len(y) is even
cv = GroupKFold(n_splits=2)
calib_clf = CalibratedClassifierCV(cv=cv)
# check that fitting does not raise an error
calib_clf.fit(X, y, groups=groups)
20 changes: 16 additions & 4 deletions sklearn/tests/test_metadata_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def assert_request_is_empty(metadata_request, exclude=None):
"""Check if a metadata request dict is empty.

One can exclude a method or a list of methods from the check using the
``exclude`` perameter.
``exclude`` parameter.
"""
if isinstance(metadata_request, MetadataRouter):
for _, route_mapping in metadata_request:
Expand Down Expand Up @@ -70,10 +70,22 @@ def assert_request_equal(request, dictionary):
assert not len(getattr(request, method).requests)


def record_metadata(obj, method, **kwargs):
"""Utility function to store passed metadata to a method."""
def record_metadata(obj, method, record_default=True, **kwargs):
"""Utility function to store passed metadata to a method.

If record_default is False, kwargs whose values are "default" are skipped.
This is so that checks on keyword arguments whose default was not changed
are skipped.

"""
if not hasattr(obj, "_records"):
setattr(obj, "_records", dict())
obj._records = {}
if not record_default:
kwargs = {
key: val
for key, val in kwargs.items()
if not isinstance(val, str) or (val != "default")
}
obj._records[method] = kwargs


Expand Down
Loading