[MRG+1] label binarizer not used consistently in CalibratedClassifierCV #7799

Merged 13 commits on Nov 8, 2016
6 changes: 6 additions & 0 deletions doc/whats_new.rst
@@ -138,6 +138,12 @@ Bug fixes
``partial_fit`` was less than the total number of classes in the
data. :issue:`7786` by `Srivatsan Ramesh`_

- Fixes issue in :class:`calibration.CalibratedClassifierCV` where
the predicted probabilities for a sample did not sum to 1, and
``CalibratedClassifierCV`` now handles the case where the training set
has fewer classes than the full data. :issue:`7799` by
`Srivatsan Ramesh`_


API changes summary
-------------------
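A minimal sketch reproducing the probability-sum bug described in the changelog entry above; it mirrors the ``test_calibration_prob_sum`` test added below, with a ``random_state`` added here (an assumption, for determinism):

import numpy as np

from sklearn.calibration import CalibratedClassifierCV
from sklearn.datasets import make_classification
from sklearn.model_selection import LeaveOneOut
from sklearn.svm import LinearSVC

# Before this fix, rows of predict_proba could sum to values other
# than 1 when calibrating with LeaveOneOut splits (issue #7796).
X, y = make_classification(n_samples=10, n_features=5, n_classes=2,
                           random_state=42)
clf = CalibratedClassifierCV(LinearSVC(C=1.0), method="sigmoid",
                             cv=LeaveOneOut())
clf.fit(X, y)
probs = clf.predict_proba(X)
assert np.allclose(probs.sum(axis=1), 1.0)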
43 changes: 30 additions & 13 deletions sklearn/calibration.py
@@ -14,9 +14,10 @@
import numpy as np

from scipy.optimize import fmin_bfgs
from sklearn.preprocessing import LabelEncoder

from .base import BaseEstimator, ClassifierMixin, RegressorMixin, clone
from .preprocessing import LabelBinarizer
from .preprocessing import label_binarize, LabelBinarizer
from .utils import check_X_y, check_array, indexable, column_or_1d
from .utils.validation import check_is_fitted
from .utils.fixes import signature
@@ -50,7 +51,8 @@ class CalibratedClassifierCV(BaseEstimator, ClassifierMixin):
The method to use for calibration. Can be 'sigmoid' which
corresponds to Platt's method or 'isotonic' which is a
non-parametric approach. It is not advised to use isotonic calibration
with too few calibration samples ``(<<1000)`` since it tends to overfit.
with too few calibration samples ``(<<1000)`` since it tends to
overfit.
Use sigmoids (Platt's calibration) in this case.

cv : integer, cross-validation generator, iterable or "prefit", optional
@@ -63,8 +65,8 @@ class CalibratedClassifierCV(BaseEstimator, ClassifierMixin):
- An iterable yielding train/test splits.

For integer/None inputs, if ``y`` is binary or multiclass,
:class:`sklearn.model_selection.StratifiedKFold` is used. If ``y``
is neither binary nor multiclass, :class:`sklearn.model_selection.KFold`
:class:`sklearn.model_selection.StratifiedKFold` is used. If ``y`` is
neither binary nor multiclass, :class:`sklearn.model_selection.KFold`
is used.

Refer :ref:`User Guide <cross_validation>` for the various
@@ -124,15 +126,16 @@ def fit(self, X, y, sample_weight=None):
X, y = check_X_y(X, y, accept_sparse=['csc', 'csr', 'coo'],
force_all_finite=False)
X, y = indexable(X, y)
lb = LabelBinarizer().fit(y)
self.classes_ = lb.classes_
le = LabelBinarizer().fit(y)
self.classes_ = le.classes_

# Check that each cross-validation fold can have at least one
# example per class
n_folds = self.cv if isinstance(self.cv, int) \
else self.cv.n_folds if hasattr(self.cv, "n_folds") else None
if n_folds and \
np.any([np.sum(y == class_) < n_folds for class_ in self.classes_]):
np.any([np.sum(y == class_) < n_folds for class_ in
self.classes_]):
raise ValueError("Requesting %d-fold cross-validation but provided"
" less than %d examples for at least one class."
% (n_folds, n_folds))
@@ -175,7 +178,8 @@ def fit(self, X, y, sample_weight=None):
this_estimator.fit(X[train], y[train])

calibrated_classifier = _CalibratedClassifier(
this_estimator, method=self.method)
this_estimator, method=self.method,
classes=self.classes_)
if sample_weight is not None:
calibrated_classifier.fit(X[test], y[test],
sample_weight[test])
@@ -253,6 +257,11 @@ class _CalibratedClassifier(object):
corresponds to Platt's method or 'isotonic' which is a
non-parametric approach based on isotonic regression.

classes : array-like, shape (n_classes,), optional
Contains unique classes used to fit the base estimator.
If None, the classes are extracted from the given target values
in fit().

References
----------
.. [1] Obtaining calibrated probability estimates from decision trees
@@ -267,9 +276,10 @@ class _CalibratedClassifier(object):
.. [4] Predicting Good Probabilities with Supervised Learning,
A. Niculescu-Mizil & R. Caruana, ICML 2005
"""
def __init__(self, base_estimator, method='sigmoid'):
def __init__(self, base_estimator, method='sigmoid', classes=None):
self.base_estimator = base_estimator
self.method = method
self.classes = classes

def _preproc(self, X):
n_classes = len(self.classes_)
@@ -285,7 +295,8 @@ def _preproc(self, X):
raise RuntimeError('classifier has no decision_function or '
'predict_proba method.')

idx_pos_class = np.arange(df.shape[1])
idx_pos_class = self.label_encoder_.\
transform(self.base_estimator.classes_)

return df, idx_pos_class
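For context, ``_preproc`` now relies on a ``LabelEncoder`` fitted on the full class set to map the classes the base estimator actually saw to their column indices among all classes. A small standalone sketch of that mapping, with illustrative labels:

from sklearn.preprocessing import LabelEncoder

# Fit on every class present in the full data set.
le = LabelEncoder().fit(["eggs", "ham", "spam"])
print(le.classes_)                    # ['eggs' 'ham' 'spam']

# A given cross-validation fold may have trained the base estimator on
# only two classes; transform() recovers their positions among all
# classes, so the calibrator columns stay aligned.
print(le.transform(["ham", "spam"]))  # [1 2]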

@@ -308,9 +319,15 @@ def fit(self, X, y, sample_weight=None):
self : object
Returns an instance of self.
"""
lb = LabelBinarizer()
Y = lb.fit_transform(y)
self.classes_ = lb.classes_

self.label_encoder_ = LabelEncoder()
if self.classes is None:
self.label_encoder_.fit(y)
else:
self.label_encoder_.fit(self.classes)

self.classes_ = self.label_encoder_.classes_
Y = label_binarize(y, self.classes_)

df, idx_pos_class = self._preproc(X)
self.calibrators_ = []
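Likewise, ``label_binarize`` with an explicit ``classes`` argument emits one indicator column per known class, even for a class absent from the current fold's targets. A quick sketch:

from sklearn.preprocessing import label_binarize

# Targets that never include class 2 still produce a three-column
# indicator matrix; the class-2 column is all zeros.
Y = label_binarize([0, 1, 1, 0], classes=[0, 1, 2])
print(Y)
# [[1 0 0]
#  [0 1 0]
#  [0 1 0]
#  [1 0 0]]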
43 changes: 36 additions & 7 deletions sklearn/tests/test_calibration.py
@@ -1,8 +1,10 @@
# Authors: Alexandre Gramfort <alexandre.gramfort@telecom-paristech.fr>
# License: BSD 3 clause

from __future__ import division
import numpy as np
from scipy import sparse
from sklearn.model_selection import LeaveOneOut

from sklearn.utils.testing import (assert_array_almost_equal, assert_equal,
assert_greater, assert_almost_equal,
@@ -14,7 +16,6 @@
from sklearn.naive_bayes import MultinomialNB
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.svm import LinearSVC
from sklearn.linear_model import Ridge
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import Imputer
from sklearn.metrics import brier_score_loss, log_loss
@@ -87,12 +88,6 @@ def test_calibration():
brier_score_loss((y_test + 1) % 2,
prob_pos_pc_clf_relabeled))

# check that calibration can also deal with regressors that have
# a decision_function
clf_base_regressor = CalibratedClassifierCV(Ridge())
clf_base_regressor.fit(X_train, y_train)
clf_base_regressor.predict(X_test)

# Check failure cases:
# only "isotonic" and "sigmoid" should be accepted as methods
clf_invalid_method = CalibratedClassifierCV(clf, method="foo")
@@ -159,6 +154,7 @@ def test_calibration_multiclass():
def softmax(y_pred):
e = np.exp(-y_pred)
return e / e.sum(axis=1).reshape(-1, 1)

uncalibrated_log_loss = \
log_loss(y_test, softmax(clf.decision_function(X_test)))
calibrated_log_loss = log_loss(y_test, probas)
@@ -275,3 +271,36 @@ def test_calibration_nan_imputer():
clf_c = CalibratedClassifierCV(clf, cv=2, method='isotonic')
clf_c.fit(X, y)
clf_c.predict(X)


def test_calibration_prob_sum():
# Test that sum of probabilities is 1. A non-regression test for
# issue #7796
num_classes = 2
X, y = make_classification(n_samples=10, n_features=5,
n_classes=num_classes)
clf = LinearSVC(C=1.0)
clf_prob = CalibratedClassifierCV(clf, method="sigmoid", cv=LeaveOneOut())
clf_prob.fit(X, y)

probs = clf_prob.predict_proba(X)
assert_array_almost_equal(probs.sum(axis=1), np.ones(probs.shape[0]))


def test_calibration_less_classes():
# Test that calibration works when the train set of a train/test split
# does not contain all the classes.
# Since this test uses LOO, each iteration's train set is missing
# exactly one class label.
X = np.random.randn(10, 5)
y = np.arange(10)
clf = LinearSVC(C=1.0)
cal_clf = CalibratedClassifierCV(clf, method="sigmoid", cv=LeaveOneOut())
cal_clf.fit(X, y)

for i, calibrated_classifier in \
enumerate(cal_clf.calibrated_classifiers_):
proba = calibrated_classifier.predict_proba(X)
assert_array_equal(proba[:, i], np.zeros(len(y)))
assert_equal(np.all(np.hstack([proba[:, :i],
proba[:, i + 1:]])), True)