Skip to content

Refactor CalibratedClassifierCV #17803

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
250 changes: 135 additions & 115 deletions sklearn/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@
from scipy.special import expit
from scipy.special import xlogy
from scipy.optimize import fmin_bfgs
from .preprocessing import LabelEncoder

from .base import (BaseEstimator, ClassifierMixin, RegressorMixin, clone,
MetaEstimatorMixin)
from .preprocessing import label_binarize, LabelBinarizer
from .preprocessing import label_binarize, LabelEncoder
from .utils import check_array, indexable, column_or_1d
from .utils.multiclass import check_classification_targets
from .utils.validation import check_is_fitted, check_consistent_length
from .utils.validation import _check_sample_weight
from .pipeline import Pipeline
Expand Down Expand Up @@ -99,6 +99,14 @@ class CalibratedClassifierCV(BaseEstimator, ClassifierMixin,
split, which has been fitted on training folds and
calibrated on the testing fold.

n_features_in_ : int
The number of features in `X`. If `cv='prefit'`, number of features
in the data used to fit `base_estimator`.

label_encoder_ : LabelEncoder instance
`LabelEncoder` fitted on `y`. If `cv='prefit'`, `LabelEncoder`
fitted on `base_estimator.classes_`.

Examples
--------
>>> from sklearn.datasets import make_classification
Expand Down Expand Up @@ -178,6 +186,7 @@ def fit(self, X, y, sample_weight=None):
self : object
Returns an instance of self.
"""
check_classification_targets(y)
X, y = indexable(X, y)

self.calibrated_classifiers_ = []
Expand All @@ -189,26 +198,31 @@ def fit(self, X, y, sample_weight=None):
base_estimator = self.base_estimator

if self.cv == "prefit":
# Set `n_features_in_` attribute
# `classes_` and `n_features_in_` should be consistent with that
# of base_estimator
if isinstance(self.base_estimator, Pipeline):
check_is_fitted(self.base_estimator[-1])
else:
check_is_fitted(self.base_estimator)
with suppress(AttributeError):
self.n_features_in_ = base_estimator.n_features_in_
self.classes_ = self.base_estimator.classes_
self.label_encoder_ = LabelEncoder().fit(self.classes_)

calibrated_classifier = _CalibratedClassifier(
base_estimator, method=self.method)
calibrated_classifier.fit(X, y, sample_weight)
calibrated_classifier = _fit_calibrator(
base_estimator, self.label_encoder_, self.method, X, y,
sample_weight
)
self.calibrated_classifiers_.append(calibrated_classifier)
else:
X, y = self._validate_data(
X, y, accept_sparse=['csc', 'csr', 'coo'],
force_all_finite=False, allow_nd=True
)
le = LabelBinarizer().fit(y)
# Set attributes using all `y`
le = LabelEncoder().fit(y)
self.classes_ = le.classes_
self.label_encoder_ = le

# Check that each cross-validation fold can have at least one
# example per class
Expand Down Expand Up @@ -246,18 +260,19 @@ def fit(self, X, y, sample_weight=None):
else:
this_estimator.fit(X[train], y[train])

calibrated_classifier = _CalibratedClassifier(
this_estimator, method=self.method, classes=self.classes_)
sw = None if sample_weight is None else sample_weight[test]
calibrated_classifier.fit(X[test], y[test], sample_weight=sw)
calibrated_classifier = _fit_calibrator(
this_estimator, self.label_encoder_, self.method,
X[test], y[test], sw
)
self.calibrated_classifiers_.append(calibrated_classifier)

return self

def predict_proba(self, X):
"""Posterior probabilities of classification
"""Calibrated probabilities of classification

This function returns posterior probabilities of classification
This function returns calibrated probabilities of classification
according to each class on an array of test vectors X.

Parameters
Expand Down Expand Up @@ -311,145 +326,150 @@ def _more_tags(self):
}


class _CalibratedClassifier:
"""Probability calibration with isotonic regression or sigmoid.
def _get_predictions(clf_fitted, X, label_encoder_):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick renaming suggestion: _compute_predictions to emphasize that this is actually calling the pred_method (rather that just retrieving some precomputed predictions).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changing this in #17856 as it's easier..

"""Returns predictions for `X` and index of classes present in `X`.

It assumes that base_estimator has already been fit, and trains the
calibration on the input set of the fit function. Note that this class
should not be used as an estimator directly. Use CalibratedClassifierCV
with cv="prefit" instead.
For predicitons, `decision_function` method of the `clf_fitted` is used.
If this does not exist, `predict_proba` method used.

Parameters
----------
base_estimator : instance BaseEstimator
The classifier whose output decision function needs to be calibrated
to offer more accurate predict_proba outputs. No default value since
it has to be an already fitted estimator.
clf_fitted : Estimator instance
Fitted classifier instance.

method : 'sigmoid' | 'isotonic'
The method to use for calibration. Can be 'sigmoid' which
corresponds to Platt's method or 'isotonic' which is a
non-parametric approach based on isotonic regression.
X : array-like
Sample data used for the predictions.

classes : array-like, shape (n_classes,), optional
Contains unique classes used to fit the base estimator.
if None, then classes is extracted from the given target values
in fit().
label_encoder_ : LabelEncoder instance
LabelEncoder instance fitted on all the targets.

See also
--------
CalibratedClassifierCV
Returns
-------
df : array-like, shape (X.shape[0], len(clf_fitted.classes_))
The predictions. Note array is of shape (X.shape[0], 1) when there are
2 classes.

References
----------
.. [1] Obtaining calibrated probability estimates from decision trees
and naive Bayesian classifiers, B. Zadrozny & C. Elkan, ICML 2001
pos_class_indices : array-like, shape (n_classes,)
Indices of the classes present in `X`.
"""
if hasattr(clf_fitted, "decision_function"):
df = clf_fitted.decision_function(X)
if df.ndim == 1:
df = df[:, np.newaxis]
elif hasattr(clf_fitted, "predict_proba"):
df = clf_fitted.predict_proba(X)
if len(label_encoder_.classes_) == 2:
df = df[:, 1:]
else:
raise RuntimeError("'base_estimator' has no 'decision_function' or "
"'predict_proba' method.")

.. [2] Transforming Classifier Scores into Accurate Multiclass
Probability Estimates, B. Zadrozny & C. Elkan, (KDD 2002)
pos_class_indices = label_encoder_.transform(clf_fitted.classes_)

.. [3] Probabilistic Outputs for Support Vector Machines and Comparisons to
Regularized Likelihood Methods, J. Platt, (1999)
return df, pos_class_indices

.. [4] Predicting Good Probabilities with Supervised Learning,
A. Niculescu-Mizil & R. Caruana, ICML 2005
"""
@_deprecate_positional_args
def __init__(self, base_estimator, *, method='sigmoid', classes=None):
self.base_estimator = base_estimator
self.method = method
self.classes = classes

def _preproc(self, X):
n_classes = len(self.classes_)
if hasattr(self.base_estimator, "decision_function"):
df = self.base_estimator.decision_function(X)
if df.ndim == 1:
df = df[:, np.newaxis]
elif hasattr(self.base_estimator, "predict_proba"):
df = self.base_estimator.predict_proba(X)
if n_classes == 2:
df = df[:, 1:]
else:
raise RuntimeError('classifier has no decision_function or '
'predict_proba method.')

idx_pos_class = self.label_encoder_.\
transform(self.base_estimator.classes_)
def _fit_calibrator(clf_fitted, label_encoder_, method, X, y,
sample_weight=None):
"""Fit calibrator(s) and return a `_CalibratedClassiferPipeline`
instance.

return df, idx_pos_class
Output of the `decision_function` method of the `clf_fitted` is used for
calibration. If this method does not exist, `predict_proba` method is used.

def fit(self, X, y, sample_weight=None):
"""Calibrate the fitted model
Parameters
----------
clf_fitted : Estimator instance
Fitted classifier.

Parameters
----------
X : array-like, shape (n_samples, n_features)
Training data.
label_encoder_ : LabelEncoder instance
LabelEncoder instance fitted on all the targets.

y : array-like, shape (n_samples,)
Target values.
method : {'sigmoid', 'isotonic'}
The method to use for calibration.

sample_weight : array-like of shape (n_samples,), default=None
Sample weights. If None, then samples are equally weighted.
X : array-like
Sample data used to calibrate predictions.

Returns
-------
self : object
Returns an instance of self.
"""
y : ndarray, shape (n_samples,)
The targets.

sample_weight : ndarray, shape (n_samples,), default=None
Sample weights. If `None`, then samples are equally weighted.

self.label_encoder_ = LabelEncoder()
if self.classes is None:
self.label_encoder_.fit(y)
Returns
-------
pipeline : _CalibratedClassiferPipeline instance
"""
Y = label_binarize(y, classes=label_encoder_.classes_)
df, pos_class_indices = _get_predictions(clf_fitted, X, label_encoder_)

calibrated_classifiers = []
for class_idx, this_df in zip(pos_class_indices, df.T):
if method == 'isotonic':
calibrator = IsotonicRegression(out_of_bounds='clip')
elif method == 'sigmoid':
calibrator = _SigmoidCalibration()
else:
self.label_encoder_.fit(self.classes)
raise ValueError("'method' should be one of: 'sigmoid' or "
f"'isotonic'. Got {method}.")
calibrator.fit(this_df, Y[:, class_idx], sample_weight)
calibrated_classifiers.append(calibrator)

self.classes_ = self.label_encoder_.classes_
Y = label_binarize(y, classes=self.classes_)
pipeline = _CalibratedClassiferPipeline(
clf_fitted, calibrated_classifiers, label_encoder_
)
return pipeline

df, idx_pos_class = self._preproc(X)
self.calibrators_ = []

for k, this_df in zip(idx_pos_class, df.T):
if self.method == 'isotonic':
calibrator = IsotonicRegression(out_of_bounds='clip')
elif self.method == 'sigmoid':
calibrator = _SigmoidCalibration()
else:
raise ValueError('method should be "sigmoid" or '
Copy link
Member

@NicolasHug NicolasHug Jul 1, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I prefer validating the input here instead of having an if/elif with no else clause. In general it's better to validate the input right before it is used.

'"isotonic". Got %s.' % self.method)
calibrator.fit(this_df, Y[:, k], sample_weight)
self.calibrators_.append(calibrator)
class _CalibratedClassiferPipeline:
"""Pipeline-like chaining a fitted classifier and its fitted calibrators.

return self
Parameters
----------
clf_fitted : Estimator instance
Fitted classifier.

calibrators_fitted : List of fitted estimator instances
List of fitted calibrators (either 'IsotonicRegression' or
'_SigmoidCalibration'). The number of calibrators equals the number of
classes. However, if there are 2 classes, the list contains only one
fitted calibrator.
"""
def __init__(self, clf_fitted, calibrators_fitted, label_encoder_):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I would prefer to use the names:

  • clf_fitted => base_estimator
  • calibrators_fitted = > calibrators
  • label_encoder_ => label_encoder

and just mention in the docstring that all of them are expected to be fitted.

Also we could add a property named calibrators_ that would just return calibrators and label_encoder_ pointing to label_encoder to ensure backward compat with the previous attribute names, maybe with a deprecation warning.

While this is a private class, calibrated classifier instances are stored in a public attribute of the public CalibratedClassifierCV class, therefore could be considered semi-public API, at least from a duck-typing point of view.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @NicolasHug reminded me of keeping the attributes constant and it's updated #17856
I've also removed the attribute label_encoder.

Also we could add a property named calibrators_ that would just return calibrators

This may be tricky/confusing because for multiclass, there is more than one calibrator per classifier. e.g., for 3 classes each pair would be (1 classifier, 3 calibrators). I guess we could return a list of tuples, each tuple containing the calibrators for a 'pair' ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually I realised that _CalibratedClassifier has a calibrators attribute, which can be accessed via calibrated_classifiers_.calibrators

self.clf_fitted = clf_fitted
self.calibrators_fitted = calibrators_fitted
self.label_encoder_ = label_encoder_

def predict_proba(self, X):
"""Posterior probabilities of classification
"""Calculate calibrated probabilities.

This function returns posterior probabilities of classification
according to each class on an array of test vectors X.
Calculates classification calibrated probabilities
for each class, in a one-vs-all manner, for `X`.

Parameters
----------
X : array-like, shape (n_samples, n_features)
The samples.
The sample data.

Returns
-------
C : array, shape (n_samples, n_classes)
The predicted probas. Can be exact zeros.
proba : array, shape (n_samples, n_classes)
The predicted probabilities. Can be exact zeros.
"""
n_classes = len(self.classes_)
proba = np.zeros((X.shape[0], n_classes))
n_classes = len(self.label_encoder_.classes_)
df, pos_class_indices = _get_predictions(
self.clf_fitted, X, self.label_encoder_
)

df, idx_pos_class = self._preproc(X)

for k, this_df, calibrator in \
zip(idx_pos_class, df.T, self.calibrators_):
proba = np.zeros((X.shape[0], n_classes))
for class_idx, this_df, calibrator in \
zip(pos_class_indices, df.T, self.calibrators_fitted):
if n_classes == 2:
k += 1
proba[:, k] = calibrator.predict(this_df)
# When binary, proba of clf_fitted.classes_[1]
# output but `pos_class_indices` = 0
class_idx += 1
proba[:, class_idx] = calibrator.predict(this_df)

# Normalize the probabilities
if n_classes == 2:
Expand Down Expand Up @@ -649,7 +669,7 @@ def calibration_curve(y_true, y_prob, *, normalize=False, n_bins=5,
array([0. , 0.5, 1. ])
>>> prob_pred
array([0.2 , 0.525, 0.85 ])
"""
"""
y_true = column_or_1d(y_true)
y_prob = column_or_1d(y_prob)
check_consistent_length(y_true, y_prob)
Expand Down
6 changes: 3 additions & 3 deletions sklearn/tests/test_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
assert_raises, ignore_warnings)
from sklearn.exceptions import NotFittedError
from sklearn.datasets import make_classification, make_blobs
from sklearn.preprocessing import LabelBinarizer
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import KFold
from sklearn.naive_bayes import MultinomialNB
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
Expand Down Expand Up @@ -109,7 +109,7 @@ def test_calibration_default_estimator():
calib_clf = CalibratedClassifierCV(cv=2)
calib_clf.fit(X, y)

base_est = calib_clf.calibrated_classifiers_[0].base_estimator
base_est = calib_clf.calibrated_classifiers_[0].clf_fitted
assert isinstance(base_est, LinearSVC)


Expand Down Expand Up @@ -429,6 +429,6 @@ def test_calibration_attributes(clf, cv):
assert_array_equal(calib_clf.classes_, clf.classes_)
assert calib_clf.n_features_in_ == clf.n_features_in_
else:
classes = LabelBinarizer().fit(y).classes_
classes = LabelEncoder().fit(y).classes_
assert_array_equal(calib_clf.classes_, classes)
assert calib_clf.n_features_in_ == X.shape[1]