From 6ff2be0fd79afd5c233cfcdd67271727f6f269f0 Mon Sep 17 00:00:00 2001 From: Jan Hendrik Metzen Date: Wed, 18 Feb 2015 12:04:07 +0100 Subject: [PATCH 1/8] ENH Add probability calibration based on isotonic regr. and Platt's sigmoid fit + calibration-curve CalibratedClassifierCV allows to calibrate the predicted probabilities of base classifiers based on a cross-validation scheme and either Platt's sigmoid fit or isotonic regression. This can be used to compensate for an under-confident or over-confident classifier. It allows also to turn the decision scores of a non-probabilistic classifier into valid probabilities. The function calibration_curve allows to evaluate how well calibrated the probabilties returned by a classifier are. Ideally, the curve should be close to diagonal. --- sklearn/__init__.py | 2 +- sklearn/calibration.py | 525 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 526 insertions(+), 1 deletion(-) create mode 100644 sklearn/calibration.py diff --git a/sklearn/__init__.py b/sklearn/__init__.py index 6ad049e1f637c..f3462531d8987 100644 --- a/sklearn/__init__.py +++ b/sklearn/__init__.py @@ -56,7 +56,7 @@ from .base import clone __check_build # avoid flakes unused variable error - __all__ = ['cluster', 'covariance', 'cross_decomposition', + __all__ = ['calibration', 'cluster', 'covariance', 'cross_decomposition', 'cross_validation', 'datasets', 'decomposition', 'dummy', 'ensemble', 'externals', 'feature_extraction', 'feature_selection', 'gaussian_process', 'grid_search', 'hmm', diff --git a/sklearn/calibration.py b/sklearn/calibration.py new file mode 100644 index 0000000000000..e4750cbb58d51 --- /dev/null +++ b/sklearn/calibration.py @@ -0,0 +1,525 @@ +"""Calibration of predicted probabilities.""" + +# Author: Alexandre Gramfort +# Balazs Kegl +# Jan Hendrik Metzen +# Mathieu Blondel +# +# License: BSD 3 clause + +from __future__ import division +import inspect + +from math import log +import numpy as np + +from scipy.optimize import fmin_bfgs + +from .base import BaseEstimator, ClassifierMixin, RegressorMixin, clone +from .preprocessing import LabelBinarizer +from .utils import check_X_y, check_array, indexable, column_or_1d +from .utils.validation import check_is_fitted +from .isotonic import IsotonicRegression +from .naive_bayes import GaussianNB +from .cross_validation import _check_cv +from .metrics.classification import _check_binary_probabilistic_predictions + + +class CalibratedClassifierCV(BaseEstimator, ClassifierMixin): + """Probability calibration with isotonic regression or sigmoid. + + With this class, the base_estimator is fit on the train set of the + cross-validation generator and the test set is used for calibration. + The probabilities for each of the folds are then averaged + for prediction. In case that cv="prefit" is passed to __init__, + it is it is assumed that base_estimator has been + fitted already and all data is used for calibration. Note that + data for fitting the classifier and for calibrating it must be disjpint. + + Parameters + ---------- + base_estimator : instance BaseEstimator + The classifier whose output decision function needs to be calibrated + to offer more accurate predict_proba outputs. If cv=prefit, the + classifier must have been fit already on data. + + method : 'sigmoid' | 'isotonic' + The method to use for calibration. Can be 'sigmoid' which + corresponds to Platt's method or 'isotonic' which is a + non-parameteric approach. It is not advised to use isotonic calibration + with too few calibration samples (<<1000) since it tends to overfit. + Use sigmoids (Platt's calibration) in this case. + + cv : integer or cross-validation generator or "prefit", optional + If an integer is passed, it is the number of folds (default 3). + Specific cross-validation objects can be passed, see + sklearn.cross_validation module for the list of possible objects. + If "prefit" is passed, it is assumed that base_estimator has been + fitted already and all data is used for calibration. + + Attributes + ---------- + classes_ : array, shape (n_classes) + The class labels. + + calibrated_classifiers_: list (len() equal to cv or 1 if cv == "prefit") + The list of calibrated classifiers, one for each crossvalidation fold, + which has been fitted on all but the validation fold and calibrated + on the validation fold. + + References + ---------- + .. [1] Obtaining calibrated probability estimates from decision trees + and naive Bayesian classifiers, B. Zadrozny & C. Elkan, ICML 2001 + + .. [2] Transforming Classifier Scores into Accurate Multiclass + Probability Estimates, B. Zadrozny & C. Elkan, (KDD 2002) + + .. [3] Probabilistic Outputs for Support Vector Machines and Comparisons to + Regularized Likelihood Methods, J. Platt, (1999) + + .. [4] Predicting Good Probabilities with Supervised Learning, + A. Niculescu-Mizil & R. Caruana, ICML 2005 + """ + def __init__(self, base_estimator=GaussianNB(), method='sigmoid', cv=3): + self.base_estimator = base_estimator + self.method = method + self.cv = cv + + def fit(self, X, y, sample_weight=None): + """Fit the calibrated model + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + Training data. + + y : array-like, shape (n_samples,) + Target values. + + sample_weight : array-like, shape = [n_samples] or None + Sample weights. If None, then samples are equally weighted. + + Returns + ------- + self : object + Returns an instance of self. + """ + X, y = check_X_y(X, y, accept_sparse=['csc', 'csr', 'coo']) + X, y = indexable(X, y) + lb = LabelBinarizer().fit(y) + self.classes_ = lb.classes_ + + # Check that we each cross-validation fold can have at least one + # example per class + n_folds = self.cv if isinstance(self.cv, int) \ + else self.cv.n_folds if hasattr(self.cv, "n_folds") else None + if n_folds and \ + np.any([np.sum(y==class_) < n_folds for class_ in self.classes_]): + raise ValueError("Requesting %d-fold cross-validation but provided" + " less than %d examples for at least one class." + % (n_folds, n_folds)) + + self.calibrated_classifiers_ = [] + if self.cv == "prefit": + calibrated_classifier = _CalibratedClassifier(self.base_estimator, + method=self.method) + if sample_weight is not None: + calibrated_classifier.fit(X, y, sample_weight) + else: + calibrated_classifier.fit(X, y) + self.calibrated_classifiers_.append(calibrated_classifier) + else: + cv = _check_cv(self.cv, X, y, classifier=True) + for train, test in cv: + this_estimator = clone(self.base_estimator) + if sample_weight is not None and \ + "sample_weight" in inspect.getargspec( + this_estimator.fit)[0]: + this_estimator.fit(X[train], y[train], + sample_weight[train]) + else: + this_estimator.fit(X[train], y[train]) + + calibrated_classifier = \ + _CalibratedClassifier(this_estimator, method=self.method) + if sample_weight is not None: + calibrated_classifier.fit(X[test], y[test], + sample_weight[test]) + else: + calibrated_classifier.fit(X[test], y[test]) + self.calibrated_classifiers_.append(calibrated_classifier) + + return self + + def predict_proba(self, X): + """Posterior probabilities of classification + + This function returns posterior probabilities of classification + according to each class on an array of test vectors X. + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + The samples. + + Returns + ------- + C : array, shape (n_samples, n_classes) + The predicted probas. + """ + check_is_fitted(self, ["classes_", "calibrated_classifiers_"]) + X = check_array(X, accept_sparse=['csc', 'csr', 'coo']) + # Compute the arithmetic mean of the predictions of the calibrated + # classfiers + mean_proba = np.zeros((X.shape[0], len(self.classes_))) + for calibrated_classifier in self.calibrated_classifiers_: + proba = calibrated_classifier.predict_proba(X) + mean_proba += proba + + mean_proba /= len(self.calibrated_classifiers_) + + return mean_proba + + def predict(self, X): + """Predict the target of new samples. Can be different from the + prediction of the uncalibrated classifier. + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + The samples. + + Returns + ------- + C : array, shape (n_samples,) + The predicted class. + """ + check_is_fitted(self, ["classes_", "calibrated_classifiers_"]) + return self.classes_[np.argmax(self.predict_proba(X), axis=1)] + + +class _CalibratedClassifier(object): + """Probability calibration with isotonic regression or sigmoid. + + It assumes that base_estimator has already been fit, and trains the + calibration on the input set of the fit function. Note that this class + should not be used as an estimator directly. Use CalibratedClassifierCV + with cv="prefit" instead. + + Parameters + ---------- + base_estimator : instance BaseEstimator + The classifier whose output decision function needs to be calibrated + to offer more accurate predict_proba outputs. No default value since + it has to be an already fitted estimator. + + method : 'sigmoid' | 'isotonic' + The method to use for calibration. Can be 'sigmoid' which + corresponds to Platt's method or 'isotonic' which is a + non-parameteric approach based on isotonic regression. + + References + ---------- + .. [1] Obtaining calibrated probability estimates from decision trees + and naive Bayesian classifiers, B. Zadrozny & C. Elkan, ICML 2001 + + .. [2] Transforming Classifier Scores into Accurate Multiclass + Probability Estimates, B. Zadrozny & C. Elkan, (KDD 2002) + + .. [3] Probabilistic Outputs for Support Vector Machines and Comparisons to + Regularized Likelihood Methods, J. Platt, (1999) + + .. [4] Predicting Good Probabilities with Supervised Learning, + A. Niculescu-Mizil & R. Caruana, ICML 2005 + """ + def __init__(self, base_estimator, method='sigmoid'): + self.base_estimator = base_estimator + self.method = method + + def _preproc(self, X): + n_classes = len(self.classes_) + if hasattr(self.base_estimator, "decision_function"): + df = self.base_estimator.decision_function(X) + if df.ndim == 1: + df = df[:, np.newaxis] + elif hasattr(self.base_estimator, "predict_proba"): + df = self.base_estimator.predict_proba(X) + if n_classes == 2: + df = df[:, 1:] + else: + raise RuntimeError('classifier has no decision_function or ' + 'predict_proba method.') + + idx_pos_class = np.arange(df.shape[1]) + + return df, idx_pos_class + + def fit(self, X, y, sample_weight=None): + """Calibrate the fitted model + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + Training data. + + y : array-like, shape (n_samples,) + Target values. + + sample_weight : array-like, shape = [n_samples] or None + Sample weights. If None, then samples are equally weighted. + + Returns + ------- + self : object + Returns an instance of self. + """ + lb = LabelBinarizer() + Y = lb.fit_transform(y) + self.classes_ = lb.classes_ + + df, idx_pos_class = self._preproc(X) + self.calibrators_ = [] + + for k, this_df in zip(idx_pos_class, df.T): + if self.method == 'isotonic': + calibrator = IsotonicRegression(out_of_bounds='clip') + # XXX: isotonic regression cannot deal correctly with + # situations in which multiple inputs are identical but + # have different outputs. Since this is not untypical + # when calibrating, we add some small random jitter to + # the inputs. + this_df = \ + this_df + np.random.normal(0, 1e-10, this_df.shape[0]) + elif self.method == 'sigmoid': + calibrator = _SigmoidCalibration() + else: + raise ValueError('method should be "sigmoid" or ' + '"isotonic". Got %s.' % self.method) + calibrator.fit(this_df, Y[:, k], sample_weight) + self.calibrators_.append(calibrator) + + return self + + def predict_proba(self, X): + """Posterior probabilities of classification + + This function returns posterior probabilities of classification + according to each class on an array of test vectors X. + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + The samples. + + Returns + ------- + C : array, shape (n_samples, n_classes) + The predicted probas. Can be exact zeros. + """ + n_classes = len(self.classes_) + proba = np.zeros((X.shape[0], n_classes)) + + df, idx_pos_class = self._preproc(X) + + for k, this_df, calibrator in \ + zip(idx_pos_class, df.T, self.calibrators_): + if n_classes == 2: + k += 1 + proba[:, k] = calibrator.predict(this_df) + + # Normalize the probabilities + if n_classes == 2: + proba[:, 0] = 1. - proba[:, 1] + else: + proba /= np.sum(proba, axis=1)[:, np.newaxis] + + # XXX : for some reason all probas can be 0 + proba[np.isnan(proba)] = 1. / n_classes + + # Deal with cases where the predicted probability minimally exceeds 1.0 + proba[(1.0 < proba) & (proba <= 1.0 + 1e-5)] = 1.0 + + return proba + + +def sigmoid_calibration(df, y, sample_weight=None): + """Probability Calibration with sigmoid method (Platt 2000) + + Parameters + ---------- + df : ndarray, shape (n_samples,) + The decision function or predict proba for the samples. + + y : ndarray, shape (n_samples,) + The targets. + + sample_weight : array-like, shape = [n_samples] or None + Sample weights. If None, then samples are equally weighted. + + Returns + ------- + a : float + The slope. + + b : float + The intercept. + + References + ---------- + Platt, "Probabilistic Outputs for Support Vector Machines" + """ + df = column_or_1d(df) + y = column_or_1d(y) + + F = df # F follows Platt's notations + tiny = np.finfo(np.float).tiny # to avoid division by 0 warning + + # Bayesian priors (see Platt end of section 2.2) + prior0 = float(np.sum(y <= 0)) + prior1 = y.shape[0] - prior0 + T = np.zeros(y.shape) + T[y > 0] = (prior1 + 1.) / (prior1 + 2.) + T[y <= 0] = 1. / (prior0 + 2.) + T1 = 1. - T + + def objective(AB): + # From Platt (beginning of Section 2.2) + E = np.exp(AB[0] * F + AB[1]) + P = 1. / (1. + E) + l = -(T * np.log(P + tiny) + T1 * np.log(1. - P + tiny)) + if sample_weight is not None: + return (sample_weight * l).sum() + else: + return l.sum() + + def grad(AB): + # gradient of the objective function + E = np.exp(AB[0] * F + AB[1]) + P = 1. / (1. + E) + TEP_minus_T1P = P * (T * E - T1) + if sample_weight is not None: + TEP_minus_T1P *= sample_weight + dA = np.dot(TEP_minus_T1P, F) + dB = np.sum(TEP_minus_T1P) + return np.array([dA, dB]) + + AB0 = np.array([0., log((prior0 + 1.) / (prior1 + 1.))]) + AB_ = fmin_bfgs(objective, AB0, fprime=grad, disp=False) + return AB_[0], AB_[1] + + +class _SigmoidCalibration(BaseEstimator, RegressorMixin): + """Sigmoid regression model. + + Attributes + ---------- + `a_` : float + The slope. + + `b_` : float + The intercept. + """ + def fit(self, X, y, sample_weight=None): + """Fit the model using X, y as training data. + + Parameters + ---------- + X : array-like, shape (n_samples,) + Training data. + + y : array-like, shape (n_samples,) + Training target. + + sample_weight : array-like, shape = [n_samples] or None + Sample weights. If None, then samples are equally weighted. + + Returns + ------- + self : object + Returns an instance of self. + """ + X = column_or_1d(X) + y = column_or_1d(y) + X, y = indexable(X, y) + + if len(X.shape) != 1: + raise ValueError("X should be a 1d array") + + self.a_, self.b_ = sigmoid_calibration(X, y, sample_weight) + return self + + def predict(self, T): + """Predict new data by linear interpolation. + + Parameters + ---------- + T : array-like, shape (n_samples,) + Data to predict from. + + Returns + ------- + `T_` : array, shape (n_samples,) + The predicted data. + """ + T = column_or_1d(T) + return 1. / (1. + np.exp(self.a_ * T + self.b_)) + + +def calibration_curve(y_true, y_prob, normalize=False, n_bins=5): + """Compute true and predicted probabilities for a calibration curve. + + Parameters + ---------- + y_true : array, shape (n_samples,) + True targets. + + y_prob : array, shape (n_samples,) + Probabilities of the positive class. + + normalize : bool, optional, default=False + Whether y_prob needs to be normalized into the bin [0, 1], i.e. is not + a proper probability. If True, the smallest value in y_prob is mapped + onto 0 and the largest one onto 1. + + n_bins : int + Number of bins. A bigger number requires more data. + + Returns + ------- + prob_true : array, shape (n_bins,) + The true probability in each bin (fraction of positives). + + prob_pred : array, shape (n_bins,) + The mean predicted probability in each bin. + + References + ---------- + Alexandru Niculescu-Mizil and Rich Caruana (2005) Predicting Good + Probabilities With Supervised Learning, in Proceedings of the 22nd + International Conference on Machine Learning (ICML). + See section 4 (Qualitative Analysis of Predictions). + """ + y_true = column_or_1d(y_true) + y_prob = column_or_1d(y_prob) + + if normalize: # Normalize predicted values into interval [0, 1] + y_prob = (y_prob - y_prob.min()) / (y_prob.max() - y_prob.min()) + elif y_prob.min() < 0 or y_prob.max() > 1: + raise ValueError("y_prob has values outside [0, 1] and normalize is " + "set to False.") + + y_true = _check_binary_probabilistic_predictions(y_true, y_prob) + + bins = np.linspace(0., 1. + 1e-8, n_bins + 1) + binids = np.digitize(y_prob, bins) - 1 + + bin_sums = np.bincount(binids, weights=y_prob, minlength=len(bins)) + bin_true = np.bincount(binids, weights=y_true, minlength=len(bins)) + bin_total = np.bincount(binids, minlength=len(bins)) + + nonzero = bin_total != 0 + prob_true = (bin_true[nonzero] / bin_total[nonzero]) + prob_pred = (bin_sums[nonzero] / bin_total[nonzero]) + + return prob_true, prob_pred From 93614bffa7015b10e8cf25a19ce4d57f7f025718 Mon Sep 17 00:00:00 2001 From: Jan Hendrik Metzen Date: Wed, 18 Feb 2015 12:07:44 +0100 Subject: [PATCH 2/8] ENH Brier-score loss metric for classifiers The Brier score allows to evaluate the quality of the predicted probabilities of a classifier. It is defined as the mean squared difference between (1) the predicted probability assigned to the possible outcomes for an item and (2) the actual outcome. --- sklearn/metrics/__init__.py | 3 + sklearn/metrics/classification.py | 95 +++++++++++++++++++- sklearn/metrics/tests/test_classification.py | 19 ++++ 3 files changed, 116 insertions(+), 1 deletion(-) diff --git a/sklearn/metrics/__init__.py b/sklearn/metrics/__init__.py index 3cd4d2afe3308..c455f9e8c4d81 100644 --- a/sklearn/metrics/__init__.py +++ b/sklearn/metrics/__init__.py @@ -3,6 +3,7 @@ and pairwise metrics and distance computations. """ + from .ranking import auc from .ranking import average_precision_score from .ranking import coverage_error @@ -25,6 +26,7 @@ from .classification import precision_score from .classification import recall_score from .classification import zero_one_loss +from .classification import brier_score_loss from . import cluster from .cluster import adjusted_mutual_info_score @@ -103,4 +105,5 @@ 'silhouette_score', 'v_measure_score', 'zero_one_loss', + 'brier_score_loss', ] diff --git a/sklearn/metrics/classification.py b/sklearn/metrics/classification.py index 99a7e509fe69e..441af936d53d0 100644 --- a/sklearn/metrics/classification.py +++ b/sklearn/metrics/classification.py @@ -28,12 +28,13 @@ from scipy.sparse import csr_matrix from scipy.spatial.distance import hamming as sp_hamming -from ..preprocessing import LabelBinarizer +from ..preprocessing import LabelBinarizer, label_binarize from ..preprocessing import LabelEncoder from ..utils import check_array from ..utils import check_consistent_length from ..preprocessing import MultiLabelBinarizer from ..utils import column_or_1d +from ..utils import check_consistent_length from ..utils.multiclass import unique_labels from ..utils.multiclass import type_of_target from ..utils.validation import _num_samples @@ -1515,3 +1516,95 @@ def hinge_loss(y_true, pred_decision, labels=None, sample_weight=None): # The hinge_loss doesn't penalize good enough predictions. losses[losses <= 0] = 0 return np.average(losses, weights=sample_weight) + + +def _check_binary_probabilistic_predictions(y_true, y_prob): + """Check that y_true is binary and y_prob contains valid probabilities""" + check_consistent_length(y_true, y_prob) + + labels = np.unique(y_true) + + if len(labels) != 2: + raise ValueError("Only binary classification is supported. " + "Provided labels %s." % labels) + + if y_prob.max() > 1: + raise ValueError("y_prob contains values greater than 1.") + + if y_prob.min() < 0: + raise ValueError("y_prob contains values less than 0.") + + return label_binarize(y_true, labels)[:, 0] + + +def brier_score_loss(y_true, y_prob, sample_weight=None, pos_label=None): + """Compute the Brier score. + + The smaller the Brier score, the better, hence the naming with "loss". + + Across all items in a set N predictions, the Brier score measures the + mean squared difference between (1) the predicted probability assigned + to the possible outcomes for item i, and (2) the actual outcome. + Therefore, the lower the Brier score is for a set of predictions, the + better the predictions are calibrated. Note that the Brier score always + takes on a value between zero and one, since this is the largest + possible difference between a predicted probability (which must be + between zero and one) and the actual outcome (which can take on values + of only 0 and 1). + + The Brier score is appropriate for binary and categorical outcomes that + can be structured as true or false, but is inappropriate for ordinal + variables which can take on three or more values (this is because the + Brier score assumes that all possible outcomes are equivalently + "distant" from one another). Which label is considered to be the positive + label is controlled via the parameter pos_label, which defaults to 1. + + + Parameters + ---------- + y_true : array, shape (n_samples,) + True targets. + + y_prob : array, shape (n_samples,) + Probabilities of the positive class. + + sample_weight : array-like of shape = [n_samples], optional + Sample weights. + + pos_label : int (default: None) + Label of the positive class. If None, the maximum label is used as + positive class + + Returns + ------- + score : float + Brier score + + Examples + -------- + >>> import numpy as np + >>> from sklearn.metrics import brier_score_loss + >>> y_true = np.array([0, 1, 1, 0]) + >>> y_true_categorical = np.array(["spam", "ham", "ham", "spam"]) + >>> y_prob = np.array([0.1, 0.9, 0.8, 0.3]) + >>> brier_score_loss(y_true, y_prob) # doctest: +ELLIPSIS + 0.037... + >>> brier_score_loss(y_true, 1-y_prob, pos_label=0) # doctest: +ELLIPSIS + 0.037... + >>> brier_score_loss(y_true_categorical, y_prob, \ + pos_label="ham") # doctest: +ELLIPSIS + 0.037... + >>> brier_score_loss(y_true, np.array(y_prob) > 0.5) + 0.0 + + References + ---------- + http://en.wikipedia.org/wiki/Brier_score + """ + y_true = column_or_1d(y_true) + y_prob = column_or_1d(y_prob) + if pos_label is None: + pos_label = y_true.max() + y_true = np.array(y_true == pos_label, int) + y_true = _check_binary_probabilistic_predictions(y_true, y_prob) + return np.average((y_true - y_prob) ** 2, weights=sample_weight) diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py index fe48585fddb16..d658169bd20ef 100644 --- a/sklearn/metrics/tests/test_classification.py +++ b/sklearn/metrics/tests/test_classification.py @@ -1,6 +1,7 @@ from __future__ import division, print_function import numpy as np +from scipy import linalg from functools import partial from itertools import product import warnings @@ -40,6 +41,7 @@ from sklearn.metrics import precision_score from sklearn.metrics import recall_score from sklearn.metrics import zero_one_loss +from sklearn.metrics import brier_score_loss from sklearn.metrics.classification import _check_targets @@ -1252,3 +1254,20 @@ def test_log_loss(): y_pred = [[0.2, 0.7], [0.6, 0.5], [0.4, 0.1], [0.7, 0.2]] loss = log_loss(y_true, y_pred) assert_almost_equal(loss, 1.0383217, decimal=6) + + +def test_brier_score_loss(): + """Check brier_score_loss function""" + y_true = np.array([0, 1, 1, 0, 1, 1]) + y_pred = np.array([0.1, 0.8, 0.9, 0.3, 1., 0.95]) + true_score = linalg.norm(y_true - y_pred) ** 2 / len(y_true) + + assert_almost_equal(brier_score_loss(y_true, y_true), 0.0) + assert_almost_equal(brier_score_loss(y_true, y_pred), true_score) + assert_almost_equal(brier_score_loss(1. + y_true, y_pred), + true_score) + assert_almost_equal(brier_score_loss(2 * y_true - 1, y_pred), + true_score) + assert_raises(ValueError, brier_score_loss, y_true, y_pred[1:]) + assert_raises(ValueError, brier_score_loss, y_true, y_pred + 1.) + assert_raises(ValueError, brier_score_loss, y_true, y_pred - 1.) From 527f6e8cd08db7aa3d9aa520f5276fb8322552c5 Mon Sep 17 00:00:00 2001 From: Jan Hendrik Metzen Date: Wed, 18 Feb 2015 12:08:20 +0100 Subject: [PATCH 3/8] TST Tests for the calibration module --- sklearn/tests/test_calibration.py | 197 ++++++++++++++++++++++++++++++ sklearn/utils/estimator_checks.py | 5 +- sklearn/utils/testing.py | 3 +- 3 files changed, 202 insertions(+), 3 deletions(-) create mode 100644 sklearn/tests/test_calibration.py diff --git a/sklearn/tests/test_calibration.py b/sklearn/tests/test_calibration.py new file mode 100644 index 0000000000000..797fb0bc60674 --- /dev/null +++ b/sklearn/tests/test_calibration.py @@ -0,0 +1,197 @@ +# Authors: Alexandre Gramfort +# License: BSD 3 clause + +import numpy as np +from scipy import sparse + +from sklearn.utils.testing import (assert_array_almost_equal, assert_equal, + assert_greater, assert_almost_equal, + assert_greater_equal) +from sklearn.datasets import make_classification, make_blobs +from sklearn.naive_bayes import MultinomialNB +from sklearn.ensemble import RandomForestClassifier +from sklearn.svm import LinearSVC +from sklearn.metrics import brier_score_loss, log_loss +from sklearn.calibration import CalibratedClassifierCV +from sklearn.calibration import sigmoid_calibration, _SigmoidCalibration +from sklearn.calibration import calibration_curve + + +def test_calibration(): + """Test calibration objects with isotonic and sigmoid""" + n_samples = 100 + X, y = make_classification(n_samples=2 * n_samples, n_features=6, + random_state=42) + sample_weight = np.random.RandomState(seed=42).uniform(size=y.size) + + X -= X.min() # MultinomialNB only allows positive X + + # split train and test + X_train, y_train, sw_train = \ + X[:n_samples], y[:n_samples], sample_weight[:n_samples] + X_test, y_test, sw_test = \ + X[n_samples:], y[n_samples:], sample_weight[n_samples:] + + # Naive-Bayes + clf = MultinomialNB() + clf.fit(X_train, y_train, sw_train) + prob_pos_clf = clf.predict_proba(X_test)[:, 1] + + # Naive Bayes with calibration + for this_X_train, this_X_test in [(X_train, X_test), + (sparse.csr_matrix(X_train), + sparse.csr_matrix(X_test))]: + for method in ['isotonic', 'sigmoid']: + pc_clf = CalibratedClassifierCV(clf, method=method, cv=2) + # Note that this fit overwrites the fit on the entire training + # set + pc_clf.fit(this_X_train, y_train, sample_weight=sw_train) + prob_pos_pc_clf = pc_clf.predict_proba(this_X_test)[:, 1] + + # Check that brier score has improved after calibration + assert_greater(brier_score_loss(y_test, prob_pos_clf), + brier_score_loss(y_test, prob_pos_pc_clf)) + + # Check invariance against relabeling [0, 1] -> [1, 2] + pc_clf.fit(this_X_train, y_train + 1, sample_weight=sw_train) + prob_pos_pc_clf_relabeled = pc_clf.predict_proba(this_X_test)[:, 1] + assert_array_almost_equal(prob_pos_pc_clf, + prob_pos_pc_clf_relabeled) + + # Check invariance against relabeling [0, 1] -> [-1, 1] + pc_clf.fit(this_X_train, 2 * y_train - 1, sample_weight=sw_train) + prob_pos_pc_clf_relabeled = pc_clf.predict_proba(this_X_test)[:, 1] + assert_array_almost_equal(prob_pos_pc_clf, + prob_pos_pc_clf_relabeled) + + # Check invariance against relabeling [0, 1] -> [1, 0] + pc_clf.fit(this_X_train, (y_train + 1) % 2, + sample_weight=sw_train) + prob_pos_pc_clf_relabeled = \ + pc_clf.predict_proba(this_X_test)[:, 1] + if method == "sigmoid": + assert_array_almost_equal(prob_pos_pc_clf, + 1 - prob_pos_pc_clf_relabeled) + else: + # Isotonic calibration is not invariant against relabeling + # but should improve in both cases + assert_greater(brier_score_loss(y_test, prob_pos_clf), + brier_score_loss((y_test + 1) % 2, + prob_pos_pc_clf_relabeled)) + + +def test_calibration_multiclass(): + """Test calibration for multiclass """ + # test multi-class setting with classifier that implements + # only decision function + clf = LinearSVC() + X, y_idx = make_blobs(n_samples=100, n_features=2, random_state=42, + centers=3, cluster_std=3.0) + + # Use categorical labels to check that CalibratedClassifierCV supports + # them correctly + target_names = np.array(['a', 'b', 'c']) + y = target_names[y_idx] + + X_train, y_train = X[::2], y[::2] + X_test, y_test = X[1::2], y[1::2] + + clf.fit(X_train, y_train) + for method in ['isotonic', 'sigmoid']: + cal_clf = CalibratedClassifierCV(clf, method=method, cv=2) + cal_clf.fit(X_train, y_train) + probas = cal_clf.predict_proba(X_test) + assert_array_almost_equal(np.sum(probas, axis=1), np.ones(len(X_test))) + + # Check that log-loss of calibrated classifier is smaller than + # log-loss of naively turned OvR decision function to probabilities + # via softmax + def softmax(y_pred): + e = np.exp(-y_pred) + return e / e.sum(axis=1).reshape(-1, 1) + uncalibrated_log_loss = \ + log_loss(y_test, softmax(clf.decision_function(X_test))) + calibrated_log_loss = log_loss(y_test, probas) + assert_greater_equal(uncalibrated_log_loss, calibrated_log_loss) + + # Test that calibration of a multiclass classifier decreases log-loss + # for RandomForestClassifier + X, y = make_blobs(n_samples=100, n_features=2, random_state=42, + cluster_std=3.0) + X_train, y_train = X[::2], y[::2] + X_test, y_test = X[1::2], y[1::2] + + clf = RandomForestClassifier(n_estimators=10, random_state=42) + clf.fit(X_train, y_train) + clf_probs = clf.predict_proba(X_test) + loss = log_loss(y_test, clf_probs) + + for method in ['isotonic', 'sigmoid']: + cal_clf = CalibratedClassifierCV(clf, method=method, cv=3) + cal_clf.fit(X_train, y_train) + cal_clf_probs = cal_clf.predict_proba(X_test) + cal_loss = log_loss(y_test, cal_clf_probs) + assert_greater(loss, cal_loss) + + +def test_calibration_prefit(): + """Test calibration for prefitted classifiers""" + n_samples = 50 + X, y = make_classification(n_samples=3 * n_samples, n_features=6, + random_state=42) + sample_weight = np.random.RandomState(seed=42).uniform(size=y.size) + + X -= X.min() # MultinomialNB only allows positive X + + # split train and test + X_train, y_train, sw_train = \ + X[:n_samples], y[:n_samples], sample_weight[:n_samples] + X_calib, y_calib, sw_calib = \ + X[n_samples:2*n_samples], y[n_samples:2*n_samples], \ + sample_weight[n_samples:2*n_samples] + X_test, y_test, sw_test = \ + X[2*n_samples:], y[2*n_samples:], sample_weight[2*n_samples:] + + # Naive-Bayes + clf = MultinomialNB() + clf.fit(X_train, y_train, sw_train) + prob_pos_clf = clf.predict_proba(X_test)[:, 1] + + # Naive Bayes with calibration + for this_X_calib, this_X_test in [(X_calib, X_test), + (sparse.csr_matrix(X_calib), + sparse.csr_matrix(X_test))]: + for method in ['isotonic', 'sigmoid']: + pc_clf = CalibratedClassifierCV(clf, method=method, cv="prefit") + pc_clf.fit(this_X_calib, y_calib, sample_weight=sw_calib) + prob_pos_pc_clf = pc_clf.predict_proba(this_X_test)[:, 1] + + assert_greater(brier_score_loss(y_test, prob_pos_clf), + brier_score_loss(y_test, prob_pos_pc_clf)) + + +def test_sigmoid_calibration(): + """Test calibration values with Platt sigmoid model""" + exF = np.array([5, -4, 1.0]) + exY = np.array([1, -1, -1]) + # computed from my python port of the C++ code in LibSVM + AB_lin_libsvm = np.array([-0.20261354391187855, 0.65236314980010512]) + assert_array_almost_equal(AB_lin_libsvm, sigmoid_calibration(exF, exY), 3) + lin_prob = 1. / (1. + np.exp(AB_lin_libsvm[0] * exF + AB_lin_libsvm[1])) + sk_prob = _SigmoidCalibration().fit(exF, exY).predict(exF) + assert_array_almost_equal(lin_prob, sk_prob, 6) + + +def test_calibration_curve(): + """Check calibration_curve function""" + y_true = np.array([0, 0, 0, 1, 1, 1]) + y_pred = np.array([0., 0.1, 0.2, 0.8, 0.9, 1.]) + prob_true, prob_pred = calibration_curve(y_true, y_pred, n_bins=2) + prob_true_unnormalized, prob_pred_unnormalized = \ + calibration_curve(y_true, y_pred * 2, n_bins=2, normalize=True) + assert_equal(len(prob_true), len(prob_pred)) + assert_equal(len(prob_true), 2) + assert_almost_equal(prob_true, [0, 1]) + assert_almost_equal(prob_pred, [0.1, 0.9]) + assert_almost_equal(prob_true, prob_true_unnormalized) + assert_almost_equal(prob_pred, prob_pred_unnormalized) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 2ca110a17d6c0..edfd6dd4d2757 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -277,7 +277,7 @@ def check_fit_score_takes_y(name, Estimator): # in fit and score so they can be used in pipelines rnd = np.random.RandomState(0) X = rnd.uniform(size=(10, 3)) - y = (X[:, 0] * 4).astype(np.int) + y = np.arange(10) % 3 y = multioutput_estimator_convert_y_2d(name, y) estimator = Estimator() set_fast_parameters(estimator) @@ -294,7 +294,7 @@ def check_fit_score_takes_y(name, Estimator): def check_estimators_dtypes(name, Estimator): rnd = np.random.RandomState(0) - X_train_32 = 4 * rnd.uniform(size=(10, 3)).astype(np.float32) + X_train_32 = 3 * rnd.uniform(size=(20, 5)).astype(np.float32) X_train_64 = X_train_32.astype(np.float64) X_train_int_64 = X_train_32.astype(np.int64) X_train_int_32 = X_train_32.astype(np.int32) @@ -634,6 +634,7 @@ def check_classifiers_input_shapes(name, Classifier): # raised with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always", DataConversionWarning) + warnings.simplefilter("ignore", RuntimeWarning) classifier.fit(X, y[:, np.newaxis]) msg = "expected 1 DataConversionWarning, got: %s" % ( ", ".join([str(w_x) for w_x in w])) diff --git a/sklearn/utils/testing.py b/sklearn/utils/testing.py index 704b67399abb4..a0ae8fb01b63d 100644 --- a/sklearn/utils/testing.py +++ b/sklearn/utils/testing.py @@ -509,7 +509,8 @@ def uninstall_mldata_mock(): # exclude them in another way 'ZeroEstimator', 'ScaledLogOddsEstimator', 'QuantileEstimator', 'MeanEstimator', - 'LogOddsEstimator', 'PriorProbabilityEstimator'] + 'LogOddsEstimator', 'PriorProbabilityEstimator', + '_SigmoidCalibration'] def all_estimators(include_meta_estimators=False, From 2022f17384a777199dbbf92c04a643217be31987 Mon Sep 17 00:00:00 2001 From: Jan Hendrik Metzen Date: Wed, 18 Feb 2015 12:10:51 +0100 Subject: [PATCH 4/8] DOC Examples for the calibration of predicted probabilities and calibration-curves --- examples/calibration/README.txt | 6 + examples/calibration/plot_calibration.py | 116 ++++++++++++ .../calibration/plot_calibration_curve.py | 134 ++++++++++++++ .../plot_calibration_multiclass.py | 168 ++++++++++++++++++ .../calibration/plot_compare_calibration.py | 124 +++++++++++++ 5 files changed, 548 insertions(+) create mode 100644 examples/calibration/README.txt create mode 100644 examples/calibration/plot_calibration.py create mode 100644 examples/calibration/plot_calibration_curve.py create mode 100644 examples/calibration/plot_calibration_multiclass.py create mode 100644 examples/calibration/plot_compare_calibration.py diff --git a/examples/calibration/README.txt b/examples/calibration/README.txt new file mode 100644 index 0000000000000..5e4a31b966b50 --- /dev/null +++ b/examples/calibration/README.txt @@ -0,0 +1,6 @@ +.. _calibration_examples: + +Calibration +----------------------- + +Examples illustrating the calibration of predicted probabilities of classifiers. diff --git a/examples/calibration/plot_calibration.py b/examples/calibration/plot_calibration.py new file mode 100644 index 0000000000000..2267f02dd0022 --- /dev/null +++ b/examples/calibration/plot_calibration.py @@ -0,0 +1,116 @@ +""" +====================================== +Probability calibration of classifiers +====================================== + +When performing classification you often want to predict not only +the class label, but also the associated probability. This probability +gives you some kind of confidence on the prediction. However, not all +classifiers provide well-calibrated probabilities, some being over-confident +while others being under-confident. Thus, a separate calibration of predicted +probabilities is often desirable as a postprocessing. This example illustrates +two different methods for this calibration and evaluates the quality of the +returned probabilities using Brier's score +(see http://en.wikipedia.org/wiki/Brier_score). + +Compared are the estimated probability using a Gaussian naive Bayes classifier +without calibration, with a sigmoid calibration, and with a non-parametric +isotonic calibration. One can observe that only the non-parametric model is able +to provide a probability calibration that returns probabilities close to the +expected 0.5 for most of the samples belonging to the middle cluster with +heterogeneous labels. This results in a significantly improved Brier score. +""" +print(__doc__) + +# Author: Mathieu Blondel +# Alexandre Gramfort +# Balazs Kegl +# Jan Hendrik Metzen +# License: BSD Style. + +import numpy as np +import matplotlib.pyplot as plt +from matplotlib import cm + +from sklearn.datasets import make_blobs +from sklearn.naive_bayes import GaussianNB +from sklearn.metrics import brier_score_loss +from sklearn.calibration import CalibratedClassifierCV +from sklearn.cross_validation import train_test_split + + +n_samples = 50000 +n_bins = 3 # use 3 bins for calibration_curve as we have 3 clusters here + +# Generate 3 blobs with 2 classes where the second blob contains +# half positive samples and half negative samples. Probability in this +# blob is therefore 0.5. +centers = [(-5, -5), (0, 0), (5, 5)] +X, y = make_blobs(n_samples=n_samples, n_features=2, cluster_std=1.0, + centers=centers, shuffle=False, random_state=42) + +y[:n_samples // 2] = 0 +y[n_samples // 2:] = 1 +sample_weight = np.random.RandomState(42).rand(y.shape[0]) + +# split train, test for calibration +X_train, X_test, y_train, y_test, sw_train, sw_test = \ + train_test_split(X, y, sample_weight, test_size=0.9, random_state=42) + +# Gaussian Naive-Bayes with no calibration +clf = GaussianNB() +clf.fit(X_train, y_train) # GaussianNB itself does not support sample-weights +prob_pos_clf = clf.predict_proba(X_test)[:, 1] + +# Gaussian Naive-Bayes with isotonic calibration +clf_isotonic = CalibratedClassifierCV(clf, cv=2, method='isotonic') +clf_isotonic.fit(X_train, y_train, sw_train) +prob_pos_isotonic = clf_isotonic.predict_proba(X_test)[:, 1] + +# Gaussian Naive-Bayes with sigmoid calibration +clf_sigmoid = CalibratedClassifierCV(clf, cv=2, method='sigmoid') +clf_sigmoid.fit(X_train, y_train, sw_train) +prob_pos_sigmoid = clf_sigmoid.predict_proba(X_test)[:, 1] + +print("Brier scores: (the smaller the better)") + +clf_score = brier_score_loss(y_test, prob_pos_clf, sw_test) +print("No calibration: %1.3f" % clf_score) + +clf_isotonic_score = brier_score_loss(y_test, prob_pos_isotonic, sw_test) +print("With isotonic calibration: %1.3f" % clf_isotonic_score) + +clf_sigmoid_score = brier_score_loss(y_test, prob_pos_sigmoid, sw_test) +print("With sigmoid calibration: %1.3f" % clf_sigmoid_score) + +############################################################################### +# Plot the data and the predicted probabilities +plt.figure() +y_unique = np.unique(y) +colors = cm.rainbow(np.linspace(0.0, 1.0, y_unique.size)) +for this_y, color in zip(y_unique, colors): + this_X = X_train[y_train == this_y] + this_sw = sw_train[y_train == this_y] + plt.scatter(this_X[:, 0], this_X[:, 1], s=this_sw * 50, c=color, alpha=0.5, + label="Class %s" % this_y) +plt.legend(loc="best") +plt.title("Data") + +plt.figure() +order = np.lexsort((prob_pos_clf, )) +plt.plot(prob_pos_clf[order], 'r', label='No calibration (%1.3f)' % clf_score) +plt.plot(prob_pos_isotonic[order], 'g', linewidth=3, + label='Isotonic calibration (%1.3f)' % clf_isotonic_score) +plt.plot(prob_pos_sigmoid[order], 'b', linewidth=3, + label='Sigmoid calibration (%1.3f)' % clf_sigmoid_score) +plt.plot(np.linspace(0, y_test.size, 51)[1::2], + y_test[order].reshape(25, -1).mean(1), + 'k', linewidth=3, label=r'Empirical') +plt.ylim([-0.05, 1.05]) +plt.xlabel("Instances sorted according to predicted probability " + "(uncalibrated GNB)") +plt.ylabel("P(y=1)") +plt.legend(loc="upper left") +plt.title("Gaussian naive Bayes probabilities") + +plt.show() diff --git a/examples/calibration/plot_calibration_curve.py b/examples/calibration/plot_calibration_curve.py new file mode 100644 index 0000000000000..42dc8473e6c30 --- /dev/null +++ b/examples/calibration/plot_calibration_curve.py @@ -0,0 +1,134 @@ +""" +============================== +Probability Calibration curves +============================== + +When performing classification one often wants to predict not only the class +label, but also the associated probability. This probability gives some +kind of confidence on the prediction. This example demonstrates how to display +how well calibrated the predicted probabilities are and how to calibrate an +uncalibrated classifier. + +The experiment is performed on an artificial dataset for binary classification +with 100.000 samples (1.000 of them are used for model fitting) with 20 +features. Of the 20 features, only 2 are informative and 10 are redundant. The +first figure shows the estimated probabilities obtained with logistic +regression, Gaussian naive Bayes, and Gaussian naive Bayes with both isotonic +calibration and sigmoid calibration. The calibration performance is evaluated +with Brier score, reported in the legend (the smaller the better). One can +observe here that logistic regression is well calibrated while raw Gaussian +naive Bayes performs very badly. This is because of the redundant features +which violate the assumption of feature-independence and result in an overly +confident classifier, which is indicated by the typical transposed-sigmoid +curve. + +Calibration of the probabilities of Gaussian naive Bayes with isotonic +regression can fix this issue as can be seen from the nearly diagonal +calibration curve. Sigmoid calibration also improves the brier score slightly, +albeit not as strongly as the non-parametric isotonic regression. This can be +attributed to the fact that we have plenty of calibration data such that the +greater flexibility of the non-parametric model can be exploited. + +The second figure shows the calibration curve of a linear support-vector +classifier (LinearSVC). LinearSVC shows the opposite behavior as Gaussian +naive Bayes: the calibration curve has a sigmoid curve, which is typical for +an under-confident classifier. In the case of LinearSVC, this is caused by the +margin property of the hinge loss, which lets the model focus on hard samples +that are close to the decision boundary (the support vectors). + +Both kinds of calibration can fix this issue and yield nearly identical +results. This shows that sigmoid calibration can deal with situations where +the calibration curve of the base classifier is sigmoid (e.g., for LinearSVC) +but not where it is transposed-sigmoid (e.g., Gaussian naive Bayes). +""" +print(__doc__) + +# Author: Alexandre Gramfort +# Jan Hendrik Metzen +# License: BSD Style. + +import matplotlib.pyplot as plt + +from sklearn import datasets +from sklearn.naive_bayes import GaussianNB +from sklearn.svm import LinearSVC +from sklearn.linear_model import LogisticRegression +from sklearn.metrics import (brier_score_loss, precision_score, recall_score, + f1_score) +from sklearn.calibration import CalibratedClassifierCV, calibration_curve +from sklearn.cross_validation import train_test_split + + +# Create dataset of classification task with many redundant and few +# informative features +X, y = datasets.make_classification(n_samples=100000, n_features=20, + n_informative=2, n_redundant=10, + random_state=42) + +X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.99, + random_state=42) + + +def plot_calibration_curve(est, name, fig_index): + """Plot calibration curve for est w/o and with calibration. """ + # Calibrated with isotonic calibration + isotonic = CalibratedClassifierCV(est, cv=2, method='isotonic') + + # Calibrated with sigmoid calibration + sigmoid = CalibratedClassifierCV(est, cv=2, method='sigmoid') + + # Logistic regression with no calibration as baseline + lr = LogisticRegression(C=1., solver='lbfgs') + + fig = plt.figure(fig_index, figsize=(10, 10)) + ax1 = plt.subplot2grid((3, 1), (0, 0), rowspan=2) + ax2 = plt.subplot2grid((3, 1), (2, 0)) + + ax1.plot([0, 1], [0, 1], "k:", label="Perfectly calibrated") + for clf, name in [(lr, 'Logistic'), + (est, name), + (isotonic, name + ' + Isotonic'), + (sigmoid, name + ' + Sigmoid')]: + clf.fit(X_train, y_train) + y_pred = clf.predict(X_test) + if hasattr(clf, "predict_proba"): + prob_pos = clf.predict_proba(X_test)[:, 1] + else: # use decision function + prob_pos = clf.decision_function(X_test) + prob_pos = \ + (prob_pos - prob_pos.min()) / (prob_pos.max() - prob_pos.min()) + + clf_score = brier_score_loss(y_test, prob_pos, pos_label=y.max()) + print("%s:" % name) + print("\tBrier: %1.3f" % (clf_score)) + print("\tPrecision: %1.3f" % precision_score(y_test, y_pred)) + print("\tRecall: %1.3f" % recall_score(y_test, y_pred)) + print("\tF1: %1.3f\n" % f1_score(y_test, y_pred)) + + fraction_of_positives, mean_predicted_value = \ + calibration_curve(y_test, prob_pos, n_bins=10) + + ax1.plot(mean_predicted_value, fraction_of_positives, "s-", + label="%s (%1.3f)" % (name, clf_score)) + + ax2.hist(prob_pos, range=(0, 1), bins=10, label=name, + histtype="step", lw=2) + + ax1.set_ylabel("Fraction of positives") + ax1.set_ylim([-0.05, 1.05]) + ax1.legend(loc="lower right") + ax1.set_title('Calibration plots (reliability curve)') + + ax2.set_xlabel("Mean predicted value") + ax2.set_ylabel("Count") + ax2.legend(loc="upper center", ncol=2) + + plt.tight_layout() + +# Plot calibration cuve for Gaussian Naive Bayes +plot_calibration_curve(GaussianNB(), "Naive Bayes", 1) + +# Plot calibration cuve for Linear SVC +plot_calibration_curve(LinearSVC(), "SVC", 2) + +plt.show() diff --git a/examples/calibration/plot_calibration_multiclass.py b/examples/calibration/plot_calibration_multiclass.py new file mode 100644 index 0000000000000..7843b84fa35cb --- /dev/null +++ b/examples/calibration/plot_calibration_multiclass.py @@ -0,0 +1,168 @@ +""" +================================================== +Probability Calibration for 3-class classification +================================================== + +This example illustrates how sigmoid calibration changes predicted +probabilities for a 3-class classification problem. Illustrated is the +standard 2-simplex, where the three corners correspond to the three classes. +Arrows point from the probability vectors predicted by an uncalibrated +classifier to the probability vectors predicted by the same classifier after +sigmoid calibration on a hold-out validation set. Colors indicate the true +class of an instance (red: class 1, green: class 2, blue: class 3). + +The base classifier is a random forest classifier with 25 base estimators +(trees). If this classifier is trained on all 800 training datapoints, it is +overly confident in its predictions and thus incurs a large log-loss. +Calibrating an identical classifier, which was trained on 600 datapoints, with +method='sigmoid' on the remaining 200 datapoints reduces the confidence of the +predictions, i.e., moves the probability vectors from the edges of the simplex +towards the center. This calibration results in a lower log-loss. Note that an +alternative would have been to increase the number of base estimators which +would have resulted in a similar decrease in log-loss. +""" +print(__doc__) + +# Author: Jan Hendrik Metzen +# License: BSD Style. + + +import matplotlib.pyplot as plt + +import numpy as np + +from sklearn.datasets import make_blobs +from sklearn.ensemble import RandomForestClassifier +from sklearn.calibration import CalibratedClassifierCV +from sklearn.metrics import log_loss + +np.random.seed(0) + +# Generate data +X, y = make_blobs(n_samples=1000, n_features=2, random_state=42, + cluster_std=5.0) +X_train, y_train = X[:600], y[:600] +X_valid, y_valid = X[600:800], y[600:800] +X_train_valid, y_train_valid = X[:800], y[:800] +X_test, y_test = X[800:], y[800:] + +# Train uncalibrated random forest classifier on whole train and validation +# data and evaluate on test data +clf = RandomForestClassifier(n_estimators=25) +clf.fit(X_train_valid, y_train_valid) +clf_probs = clf.predict_proba(X_test) +score = log_loss(y_test, clf_probs) + +# Train random forest classifier, calibrate on validation data and evaluate +# on test data +clf = RandomForestClassifier(n_estimators=25) +clf.fit(X_train, y_train) +clf_probs = clf.predict_proba(X_test) +sig_clf = CalibratedClassifierCV(clf, method="sigmoid", cv="prefit") +sig_clf.fit(X_valid, y_valid) +sig_clf_probs = sig_clf.predict_proba(X_test) +sig_score = log_loss(y_test, sig_clf_probs) + +# Plot changes in predicted probabilities via arrows +plt.figure(0) +colors = ["r", "g", "b"] +for i in range(clf_probs.shape[0]): + plt.arrow(clf_probs[i, 0], clf_probs[i, 1], + sig_clf_probs[i, 0] - clf_probs[i, 0], + sig_clf_probs[i, 1] - clf_probs[i, 1], + color=colors[y_test[i]], head_width=1e-2) + +# Plot perfect predictions +plt.plot([1.0], [0.0], 'ro', ms=20, label="Class 1") +plt.plot([0.0], [1.0], 'go', ms=20, label="Class 2") +plt.plot([0.0], [0.0], 'bo', ms=20, label="Class 3") + +# Plot boundaries of unit simplex +plt.plot([0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], 'k', label="Simplex") + +# Annotate points on the simplex +plt.annotate(r'($\frac{1}{3}$, $\frac{1}{3}$, $\frac{1}{3}$)', + xy=(1.0/3, 1.0/3), xytext=(1.0/3, .23), xycoords='data', + arrowprops=dict(facecolor='black', shrink=0.05), + horizontalalignment='center', verticalalignment='center') +plt.plot([1.0/3], [1.0/3], 'ko', ms=5) +plt.annotate(r'($\frac{1}{2}$, $0$, $\frac{1}{2}$)', + xy=(.5, .0), xytext=(.5, .1), xycoords='data', + arrowprops=dict(facecolor='black', shrink=0.05), + horizontalalignment='center', verticalalignment='center') +plt.annotate(r'($0$, $\frac{1}{2}$, $\frac{1}{2}$)', + xy=(.0, .5), xytext=(.1, .5), xycoords='data', + arrowprops=dict(facecolor='black', shrink=0.05), + horizontalalignment='center', verticalalignment='center') +plt.annotate(r'($\frac{1}{2}$, $\frac{1}{2}$, $0$)', + xy=(.5, .5), xytext=(.6, .6), xycoords='data', + arrowprops=dict(facecolor='black', shrink=0.05), + horizontalalignment='center', verticalalignment='center') +plt.annotate(r'($0$, $0$, $1$)', + xy=(0, 0), xytext=(.1, .1), xycoords='data', + arrowprops=dict(facecolor='black', shrink=0.05), + horizontalalignment='center', verticalalignment='center') +plt.annotate(r'($1$, $0$, $0$)', + xy=(1, 0), xytext=(1, .1), xycoords='data', + arrowprops=dict(facecolor='black', shrink=0.05), + horizontalalignment='center', verticalalignment='center') +plt.annotate(r'($0$, $1$, $0$)', + xy=(0, 1), xytext=(.1, 1), xycoords='data', + arrowprops=dict(facecolor='black', shrink=0.05), + horizontalalignment='center', verticalalignment='center') +# Add grid +plt.grid("off") +for x in [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]: + plt.plot([0, x], [x, 0], 'k', alpha=0.2) + plt.plot([0, 0 + (1-x)/2], [x, x + (1-x)/2], 'k', alpha=0.2) + plt.plot([x, x + (1-x)/2], [0, 0 + (1-x)/2], 'k', alpha=0.2) + +plt.title("Change of predicted probabilities after sigmoid calibration") +plt.xlabel("Probability class 1") +plt.ylabel("Probability class 2") +plt.xlim(-0.05, 1.05) +plt.ylim(-0.05, 1.05) +plt.legend(loc="best") + +print("Log-loss of") +print(" * uncalibrated classifier trained on 800 datapoints: %.3f " + % score) +print(" * classifier trained on 600 datapoints and calibrated on " + "200 datapoint: %.3f" % sig_score) + +# Illustrate calibrator +plt.figure(1) +# generate grid over 2-simplex +p1d = np.linspace(0, 1, 20) +p0, p1 = np.meshgrid(p1d, p1d) +p2 = 1 - p0 - p1 +p = np.c_[p0.ravel(), p1.ravel(), p2.ravel()] +p = p[p[:, 2] >= 0] + +calibrated_classifier = sig_clf.calibrated_classifiers_[0] +prediction = np.vstack([calibrator.predict(this_p) + for calibrator, this_p in + zip(calibrated_classifier.calibrators_, p.T)]).T +prediction /= prediction.sum(axis=1)[:, None] + +# Ploit modifications of calibrator +for i in range(prediction.shape[0]): + plt.arrow(p[i, 0], p[i, 1], + prediction[i, 0] - p[i, 0], prediction[i, 1] - p[i, 1], + head_width=1e-2, color=colors[np.argmax(p[i])]) +# Plot boundaries of unit simplex +plt.plot([0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], 'k', label="Simplex") + +plt.grid("off") +for x in [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]: + plt.plot([0, x], [x, 0], 'k', alpha=0.2) + plt.plot([0, 0 + (1-x)/2], [x, x + (1-x)/2], 'k', alpha=0.2) + plt.plot([x, x + (1-x)/2], [0, 0 + (1-x)/2], 'k', alpha=0.2) + +plt.title("Illustration of sigmoid calibrator") +plt.xlabel("Probability class 1") +plt.ylabel("Probability class 2") +plt.xlim(-0.05, 1.05) +plt.ylim(-0.05, 1.05) + +plt.show() diff --git a/examples/calibration/plot_compare_calibration.py b/examples/calibration/plot_compare_calibration.py new file mode 100644 index 0000000000000..dc40ff108b8c5 --- /dev/null +++ b/examples/calibration/plot_compare_calibration.py @@ -0,0 +1,124 @@ +""" +======================================= +Comparison of Calibration of Classifiers +======================================= + +Well calibrated classifiers are probabilistic classifiers for which the output +of the predict_proba method can be directly interpreted as a confidence level. +For instance a well calibrated (binary) classifier should classify the samples +such that among the samples to which it gave a predict_proba value close to +0.8, approx. 80% actually belong to the positive class. + +LogisticRegression returns well calibrated predictions as it directly +optimizes log-loss. In contrast, the other methods return biased probilities, +with different biases per method: + +* GaussianNaiveBayes tends to push probabilties to 0 or 1 (note the counts in + the histograms). This is mainly because it makes the assumption that features + are conditionally independent given the class, which is not the case in this + dataset which contains 2 redundant features. + +* RandomForestClassifier shows the opposite behavior: the histograms show + peaks at approx. 0.2 and 0.9 probability, while probabilities close to 0 or 1 + are very rare. An explanation for this is given by Niculescu-Mizil and Caruana + [1]: "Methods such as bagging and random forests that average predictions from + a base set of models can have difficulty making predictions near 0 and 1 + because variance in the underlying base models will bias predictions that + should be near zero or one away from these values. Because predictions are + restricted to the interval [0,1], errors caused by variance tend to be one- + sided near zero and one. For example, if a model should predict p = 0 for a + case, the only way bagging can achieve this is if all bagged trees predict + zero. If we add noise to the trees that bagging is averaging over, this noise + will cause some trees to predict values larger than 0 for this case, thus + moving the average prediction of the bagged ensemble away from 0. We observe + this effect most strongly with random forests because the base-level trees + trained with random forests have relatively high variance due to feature + subseting." As a result, the calibration curve shows a characteristic sigmoid + shape, indicating that the classifier could trust its "intuition" more and + return probabilties closer to 0 or 1 typically. + +* Support Vector Classification (SVC) shows an even more sigmoid curve as + the RandomForestClassifier, which is typical for maximum-margin methods + (compare Niculescu-Mizil and Caruana [1]), which focus on hard samples + that are close to the decision boundary (the support vectors). + +.. topic:: References: + + .. [1] Predicting Good Probabilities with Supervised Learning, + A. Niculescu-Mizil & R. Caruana, ICML 2005 +""" +print(__doc__) + +# Author: Jan Hendrik Metzen +# License: BSD Style. + +import numpy as np +np.random.seed(0) + +import matplotlib.pyplot as plt + +from sklearn import datasets +from sklearn.svm import SVC +from sklearn.naive_bayes import GaussianNB +from sklearn.linear_model import LogisticRegression +from sklearn.ensemble import RandomForestClassifier +from sklearn.svm import LinearSVC +from sklearn.metrics import brier_score_loss +from sklearn.calibration import calibration_curve + +X, y = datasets.make_classification(n_samples=100000, n_features=20, + n_informative=2, n_redundant=2) + +train_samples = 100 # Samples used for training the models + +X_train = X[:train_samples] +X_test = X[train_samples:] +y_train = y[:train_samples] +y_test = y[train_samples:] + +# Create classifiers +lr = LogisticRegression() +gnb = GaussianNB() +svc = LinearSVC(C=1.0) +rfc = RandomForestClassifier(n_estimators=100) + + +############################################################################### +# Plot calibration plots + +plt.figure(figsize=(10, 10)) +ax1 = plt.subplot2grid((3, 1), (0, 0), rowspan=2) +ax2 = plt.subplot2grid((3, 1), (2, 0)) + +ax1.plot([0, 1], [0, 1], "k:", label="Perfectly calibrated") +for clf, name in [(lr, 'Logistic'), + (gnb, 'Naive Bayes'), + (svc, 'Support Vector Classification'), + (rfc, 'Random Forest')]: + clf.fit(X_train, y_train) + if hasattr(clf, "predict_proba"): + prob_pos = clf.predict_proba(X_test)[:, 1] + else: # use decision function + prob_pos = clf.decision_function(X_test) + prob_pos = \ + (prob_pos - prob_pos.min()) / (prob_pos.max() - prob_pos.min()) + fraction_of_positives, mean_predicted_value = \ + calibration_curve(y_test, prob_pos, n_bins=10) + + ax1.plot(mean_predicted_value, fraction_of_positives, "s-", + label="%s" % (name, )) + + ax2.hist(prob_pos, range=(0, 1), bins=10, label=name, + histtype="step", lw=2) + +ax1.set_ylabel("Fraction of positives") +ax1.set_ylim([-0.05, 1.05]) +ax1.legend(loc="lower right") +ax1.set_title('Calibration plots (reliability curve)') + +ax2.set_xlabel("Mean predicted value") +ax2.set_ylabel("Count") +ax2.legend(loc="upper center", ncol=2) + +plt.tight_layout() +plt.show() From 1a2a9ae60cf3b38c560d56aff234b715e9469e76 Mon Sep 17 00:00:00 2001 From: Jan Hendrik Metzen Date: Wed, 18 Feb 2015 12:11:52 +0100 Subject: [PATCH 5/8] DOC Narrative doc for the calibration module --- doc/modules/calibration.rst | 195 ++++++++++++++++++++++++++++++++++++ doc/modules/classes.rst | 29 ++++++ doc/supervised_learning.rst | 1 + doc/whats_new.rst | 4 + 4 files changed, 229 insertions(+) create mode 100644 doc/modules/calibration.rst diff --git a/doc/modules/calibration.rst b/doc/modules/calibration.rst new file mode 100644 index 0000000000000..8590d6b1b33db --- /dev/null +++ b/doc/modules/calibration.rst @@ -0,0 +1,195 @@ +.. _calibration: + +======================= +Probability calibration +======================= + +.. currentmodule:: sklearn.calibration + + +When performing classification you often want not only to predict the class +label, but also obtain a probability of the respective label. This probability +gives you some kind of confidence on the prediction. Some models can give you +poor estimates of the class probabilities and some even do not not support +probability prediction. The calibration module allows you to better calibrate +the probabilities of a given model, or to add support for probability +prediction. + +Well calibrated classifiers are probabilistic classifiers for which the output +of the predict_proba method can be directly interpreted as a confidence level. +For instance, a well calibrated (binary) classifier should classify the samples +such that among the samples to which it gave a predict_proba value close to 0.8, +approximately 80% actually belong to the positive class. The following plot compares +how well the probabilistic predictions of different classifiers are calibrated: + +.. figure:: ../auto_examples/calibration/images/plot_compare_calibration_001.png + :target: ../auto_examples/calibration/plot_compare_calibration.html + :align: center + +.. currentmodule:: sklearn.linear_model +:class:`LogisticRegression` returns well calibrated predictions by default as it directly +optimizes log-loss. In contrast, the other methods return biased probabilities; +with different biases per method: + + * .. currentmodule:: sklearn.naive_bayes + :class:`GaussianNB` tends to push probabilties to 0 or 1 (note the + counts in the histograms). This is mainly because it makes the assumption + that features are conditionally independent given the class, which is not + the case in this dataset which contains 2 redundant features. + + * .. currentmodule:: sklearn.ensemble + :class:`RandomForestClassifier` shows the opposite behavior: the histograms + show peaks at approximately 0.2 and 0.9 probability, while probabilities close to + 0 or 1 are very rare. An explanation for this is given by Niculescu-Mizil + and Caruana [4]: "Methods such as bagging and random forests that average + predictions from a base set of models can have difficulty making predictions + near 0 and 1 because variance in the underlying base models will bias + predictions that should be near zero or one away from these values. Because + predictions are restricted to the interval [0,1], errors caused by variance + tend to be one-sided near zero and one. For example, if a model should + predict p = 0 for a case, the only way bagging can achieve this is if all + bagged trees predict zero. If we add noise to the trees that bagging is + averaging over, this noise will cause some trees to predict values larger + than 0 for this case, thus moving the average prediction of the bagged + ensemble away from 0. We observe this effect most strongly with random + forests because the base-level trees trained with random forests have + relatively high variance due to feature subseting." As a result, the + calibration curve shows a characteristic sigmoid shape, indicating that the + classifier could trust its "intuition" more and return probabilties closer + to 0 or 1 typically. + + * .. currentmodule:: sklearn.svm + Linear Support Vector Classification (:class:`LinearSVC`) shows an even more sigmoid curve + as the RandomForestClassifier, which is typical for maximum-margin methods + (compare Niculescu-Mizil and Caruana [4]), which focus on hard samples + that are close to the decision boundary (the support vectors). + +.. currentmodule:: sklearn.calibration +Two approaches for performing calibration of probabilistic predictions are +provided: a parametric approach based on Platt's sigmoid model and a +non-parametric approach based on isotonic regression (:mod:`sklearn.isotonic`). +Probability calibration should be done on new data not used for model fitting. +The class :class:`CalibratedClassifierCV` uses a cross-validation generator and +estimates for each split the model parameter on the train samples and the +calibration of the test samples. The probabilities predicted for the +folds are then averaged. Already fitted classifiers can be calibrated by +:class:`CalibratedClassifierCV` via the paramter cv="prefit". In this case, +the user has to take care manually that data for model fitting and calibration +are disjoint. + +The following images demonstrate the benefit of probability calibration. +The first image present a dataset with 2 classes and 3 blobs of +data. The blob in the middle contains random samples of each class. +The probability for the samples in this blob should be 0.5. + +.. figure:: ../auto_examples/calibration/images/plot_calibration_001.png + :target: ../auto_examples/calibration/plot_calibration.html + :align: center + +The following image shows on the data above the estimated probability +using a Gaussian naive Bayes classifier without calibration, +with a sigmoid calibration and with a non-parametric isotonic +calibration. One can observe that the non-parametric model +provides the most accurate probability estimates for samples +in the middle, i.e., 0.5. + +.. figure:: ../auto_examples/calibration/images/plot_calibration_002.png + :target: ../auto_examples/calibration/plot_calibration.html + :align: center + +.. currentmodule:: sklearn.metrics +The following experiment is performed on an artificial dataset for binary +classification with 100.000 samples (1.000 of them are used for model fitting) +with 20 features. Of the 20 features, only 2 are informative and 10 are +redundant. The figure shows the estimated probabilities obtained with +logistic regression, a linear support-vector classifier (SVC), and linear SVC with +both isotonic calibration and sigmoid calibration. The calibration performance +is evaluated with Brier score :func:`brier_score_loss`, reported in the legend +(the smaller the better). + +.. figure:: ../auto_examples/calibration/images/plot_calibration_curve_002.png + :target: ../auto_examples/calibration/plot_calibration_curve.html + :align: center + +One can observe here that logistic regression is well calibrated as its curve is +nearly diagonal. Linear SVC's calibration curve has a sigmoid curve, which is +typical for an under-confident classifier. In the case of LinearSVC, this is +caused by the margin property of the hinge loss, which lets the model focus on +hard samples that are close to the decision boundary (the support vectors). Both +kinds of calibration can fix this issue and yield nearly identical results. +The next figure shows the calibration curve of Gaussian naive Bayes on +the same data, with both kinds of calibration and also without calibration. + +.. figure:: ../auto_examples/calibration/images/plot_calibration_curve_001.png + :target: ../auto_examples/calibration/plot_calibration_curve.html + :align: center + +One can see that Gaussian naive Bayes performs very badly but does so in an +other way than linear SVC: While linear SVC exhibited a sigmoid calibration +curve, Gaussian naive Bayes' calibration curve has a transposed-sigmoid shape. +This is typical for an over-confident classifier. In this case, the classifier's +overconfidence is caused by the redundant features which violate the naive Bayes +assumption of feature-independence. + +Calibration of the probabilities of Gaussian naive Bayes with isotonic +regression can fix this issue as can be seen from the nearly diagonal +calibration curve. Sigmoid calibration also improves the brier score slightly, +albeit not as strongly as the non-parametric isotonic calibration. This is an +intrinsic limitation of sigmoid calibration, whose parametric form assumes a +sigmoid rather than a transposed-sigmoid curve. The non-parametric isotonic +calibration model, however, makes no such strong assumptions and can deal with +either shape, provided that there is sufficient calibration data. In general, +sigmoid calibration is preferable if the calibration curve is sigmoid and when +there is few calibration data while isotonic calibration is preferable for non- +sigmoid calibration curves and in situations where many additional data can be +used for calibration. + +.. currentmodule:: sklearn.calibration +:class:`CalibratedClassifierCV` can also deal with classification tasks that +involve more than two classes if the base estimator can do so. In this case, +the classifier is calibrated first for each class separately in an one-vs-rest +fashion. When predicting probabilities for unseen data, the calibrated +probabilities for each class are predicted separately. As those probabilities +do not necessarily sum to one, a postprocessing is performed to normalize them. + +The next image illustrates how sigmoid calibration changes predicted +probabilities for a 3-class classification problem. Illustrated is the standard +2-simplex, where the three corners correspond to the three classes. Arrows point +from the probability vectors predicted by an uncalibrated classifier to the +probability vectors predicted by the same classifier after sigmoid calibration +on a hold-out validation set. Colors indicate the true class of an instance +(red: class 1, green: class 2, blue: class 3). + +.. figure:: ../auto_examples/calibration/images/plot_calibration_multiclass_000.png + :target: ../auto_examples/calibration/plot_calibration_multiclass.html + :align: center + +The base classifier is a random forest classifier with 25 base estimators +(trees). If this classifier is trained on all 800 training datapoints, it is +overly confident in its predictions and thus incurs a large log-loss. +Calibrating an identical classifier, which was trained on 600 datapoints, with +method='sigmoid' on the remaining 200 datapoints reduces the confidence of the +predictions, i.e., moves the probability vectors from the edges of the simplex +towards the center: + +.. figure:: ../auto_examples/calibration/images/plot_calibration_multiclass_001.png + :target: ../auto_examples/calibration/plot_calibration_multiclass.html + :align: center + +This calibration results in a lower log-loss. Note that an alternative would +have been to increase the number of base estimators which would have resulted in +a similar decrease in log-loss. + +.. topic:: References: + + .. [1] Obtaining calibrated probability estimates from decision trees + and naive Bayesian classifiers, B. Zadrozny & C. Elkan, ICML 2001 + + .. [2] Transforming Classifier Scores into Accurate Multiclass + Probability Estimates, B. Zadrozny & C. Elkan, (KDD 2002) + + .. [3] Probabilistic Outputs for Support Vector Machines and Comparisons to + Regularized Likelihood Methods, J. Platt, (1999) + + .. [4] Predicting Good Probabilities with Supervised Learning, + A. Niculescu-Mizil & R. Caruana, ICML 2005 diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst index 3df61164ab870..eefe6d6a6f3d4 100644 --- a/doc/modules/classes.rst +++ b/doc/modules/classes.rst @@ -768,6 +768,7 @@ details. metrics.accuracy_score metrics.auc metrics.average_precision_score + metrics.brier_score_loss metrics.classification_report metrics.confusion_matrix metrics.f1_score @@ -784,6 +785,7 @@ details. metrics.roc_auc_score metrics.roc_curve metrics.zero_one_loss + metrics.brier_score_loss Regression metrics ------------------ @@ -1009,6 +1011,33 @@ See the :ref:`metrics` section of the user guide for further details. neural_network.BernoulliRBM +.. _calibration_ref: + +:mod:`sklearn.calibration`: Probability Calibration +=================================================== + +.. automodule:: sklearn.calibration + :no-members: + :no-inherited-members: + +**User guide:** See the :ref:`calibration` section for further details. + +.. currentmodule:: sklearn + +.. autosummary:: + :toctree: generated/ + :template: class.rst + + calibration.CalibratedClassifierCV + + +.. autosummary:: + :toctree: generated/ + :template: function.rst + + calibration.calibration_curve + + .. _cross_decomposition_ref: :mod:`sklearn.cross_decomposition`: Cross decomposition diff --git a/doc/supervised_learning.rst b/doc/supervised_learning.rst index 4fc42bd1e4a1c..cde197287bee9 100644 --- a/doc/supervised_learning.rst +++ b/doc/supervised_learning.rst @@ -22,3 +22,4 @@ Supervised learning modules/feature_selection.rst modules/label_propagation.rst modules/isotonic.rst + modules/calibration.rst diff --git a/doc/whats_new.rst b/doc/whats_new.rst index c49769567f960..9e454fb84d2a7 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -75,6 +75,10 @@ New features for fixed user-provided cross-validation folds. By `untom `_. + - Added :class:`calibration.CalibratedClassifierCV`, an approach for + calibrating the predicted probabilities of a classifier. + By `Alexandre Gramfort`_ and `Jan Hendrik Metzen`_. + Enhancements ............ From 458669bc3f55a7358a23ef2f37f71aa4d637a797 Mon Sep 17 00:00:00 2001 From: Alexandre Gramfort Date: Wed, 18 Feb 2015 21:04:26 +0100 Subject: [PATCH 6/8] COSMIT : pep8 + 2 spaces --- doc/modules/calibration.rst | 12 ++++++------ sklearn/calibration.py | 4 ++-- sklearn/tests/test_calibration.py | 10 ++++------ 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/doc/modules/calibration.rst b/doc/modules/calibration.rst index 8590d6b1b33db..f215a8628dbc1 100644 --- a/doc/modules/calibration.rst +++ b/doc/modules/calibration.rst @@ -8,8 +8,8 @@ Probability calibration When performing classification you often want not only to predict the class -label, but also obtain a probability of the respective label. This probability -gives you some kind of confidence on the prediction. Some models can give you +label, but also obtain a probability of the respective label. This probability +gives you some kind of confidence on the prediction. Some models can give you poor estimates of the class probabilities and some even do not not support probability prediction. The calibration module allows you to better calibrate the probabilities of a given model, or to add support for probability @@ -34,8 +34,8 @@ with different biases per method: * .. currentmodule:: sklearn.naive_bayes :class:`GaussianNB` tends to push probabilties to 0 or 1 (note the counts in the histograms). This is mainly because it makes the assumption - that features are conditionally independent given the class, which is not - the case in this dataset which contains 2 redundant features. + that features are conditionally independent given the class, which is not + the case in this dataset which contains 2 redundant features. * .. currentmodule:: sklearn.ensemble :class:`RandomForestClassifier` shows the opposite behavior: the histograms @@ -60,8 +60,8 @@ with different biases per method: * .. currentmodule:: sklearn.svm Linear Support Vector Classification (:class:`LinearSVC`) shows an even more sigmoid curve - as the RandomForestClassifier, which is typical for maximum-margin methods - (compare Niculescu-Mizil and Caruana [4]), which focus on hard samples + as the RandomForestClassifier, which is typical for maximum-margin methods + (compare Niculescu-Mizil and Caruana [4]), which focus on hard samples that are close to the decision boundary (the support vectors). .. currentmodule:: sklearn.calibration diff --git a/sklearn/calibration.py b/sklearn/calibration.py index e4750cbb58d51..af338ddda9b51 100644 --- a/sklearn/calibration.py +++ b/sklearn/calibration.py @@ -115,7 +115,7 @@ def fit(self, X, y, sample_weight=None): n_folds = self.cv if isinstance(self.cv, int) \ else self.cv.n_folds if hasattr(self.cv, "n_folds") else None if n_folds and \ - np.any([np.sum(y==class_) < n_folds for class_ in self.classes_]): + np.any([np.sum(y == class_) < n_folds for class_ in self.classes_]): raise ValueError("Requesting %d-fold cross-validation but provided" " less than %d examples for at least one class." % (n_folds, n_folds)) @@ -338,7 +338,7 @@ def predict_proba(self, X): proba[np.isnan(proba)] = 1. / n_classes # Deal with cases where the predicted probability minimally exceeds 1.0 - proba[(1.0 < proba) & (proba <= 1.0 + 1e-5)] = 1.0 + proba[(1.0 < proba) & (proba <= 1.0 + 1e-5)] = 1.0 return proba diff --git a/sklearn/tests/test_calibration.py b/sklearn/tests/test_calibration.py index 797fb0bc60674..e9d85ecd851ae 100644 --- a/sklearn/tests/test_calibration.py +++ b/sklearn/tests/test_calibration.py @@ -29,8 +29,7 @@ def test_calibration(): # split train and test X_train, y_train, sw_train = \ X[:n_samples], y[:n_samples], sample_weight[:n_samples] - X_test, y_test, sw_test = \ - X[n_samples:], y[n_samples:], sample_weight[n_samples:] + X_test, y_test = X[n_samples:], y[n_samples:] # Naive-Bayes clf = MultinomialNB() @@ -147,10 +146,9 @@ def test_calibration_prefit(): X_train, y_train, sw_train = \ X[:n_samples], y[:n_samples], sample_weight[:n_samples] X_calib, y_calib, sw_calib = \ - X[n_samples:2*n_samples], y[n_samples:2*n_samples], \ - sample_weight[n_samples:2*n_samples] - X_test, y_test, sw_test = \ - X[2*n_samples:], y[2*n_samples:], sample_weight[2*n_samples:] + X[n_samples:2 * n_samples], y[n_samples:2 * n_samples], \ + sample_weight[n_samples:2 * n_samples] + X_test, y_test = X[2 * n_samples:], y[2 * n_samples:] # Naive-Bayes clf = MultinomialNB() From db131325389439a7b6c0c82b84327311429e3f7c Mon Sep 17 00:00:00 2001 From: Alexandre Gramfort Date: Wed, 18 Feb 2015 21:16:38 +0100 Subject: [PATCH 7/8] TST improve coverage of calibration.py --- sklearn/calibration.py | 3 -- sklearn/tests/test_calibration.py | 51 +++++++++++++++++++++++++++---- 2 files changed, 45 insertions(+), 9 deletions(-) diff --git a/sklearn/calibration.py b/sklearn/calibration.py index af338ddda9b51..d6a8e290c42c7 100644 --- a/sklearn/calibration.py +++ b/sklearn/calibration.py @@ -443,9 +443,6 @@ def fit(self, X, y, sample_weight=None): y = column_or_1d(y) X, y = indexable(X, y) - if len(X.shape) != 1: - raise ValueError("X should be a 1d array") - self.a_, self.b_ = sigmoid_calibration(X, y, sample_weight) return self diff --git a/sklearn/tests/test_calibration.py b/sklearn/tests/test_calibration.py index e9d85ecd851ae..98e3d63fc70e5 100644 --- a/sklearn/tests/test_calibration.py +++ b/sklearn/tests/test_calibration.py @@ -6,11 +6,14 @@ from sklearn.utils.testing import (assert_array_almost_equal, assert_equal, assert_greater, assert_almost_equal, - assert_greater_equal) + assert_greater_equal, + assert_array_equal, + assert_raises) from sklearn.datasets import make_classification, make_blobs from sklearn.naive_bayes import MultinomialNB -from sklearn.ensemble import RandomForestClassifier +from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor from sklearn.svm import LinearSVC +from sklearn.linear_model import Ridge from sklearn.metrics import brier_score_loss, log_loss from sklearn.calibration import CalibratedClassifierCV from sklearn.calibration import sigmoid_calibration, _SigmoidCalibration @@ -36,6 +39,9 @@ def test_calibration(): clf.fit(X_train, y_train, sw_train) prob_pos_clf = clf.predict_proba(X_test)[:, 1] + pc_clf = CalibratedClassifierCV(clf, cv=y.size + 1) + assert_raises(ValueError, pc_clf.fit, X, y) + # Naive Bayes with calibration for this_X_train, this_X_test in [(X_train, X_test), (sparse.csr_matrix(X_train), @@ -78,6 +84,23 @@ def test_calibration(): brier_score_loss((y_test + 1) % 2, prob_pos_pc_clf_relabeled)) + # check that calibration can also deal with regressors that have + # a decision_function + clf_base_regressor = CalibratedClassifierCV(Ridge(), method="sigmoid") + clf_base_regressor.fit(X_train, y_train) + clf_base_regressor.predict(X_test) + + # Check failure cases: + # only "isotonic" and "sigmoid" should be accepted as methods + clf_invalid_method = CalibratedClassifierCV(clf, method="foo") + assert_raises(ValueError, clf_invalid_method.fit, X_train, y_train) + + # base-estimators should provide either decision_function or + # predict_proba (most regressors, for instance, should fail) + clf_base_regressor = \ + CalibratedClassifierCV(RandomForestRegressor(), method="sigmoid") + assert_raises(RuntimeError, clf_base_regressor.fit, X_train, y_train) + def test_calibration_multiclass(): """Test calibration for multiclass """ @@ -161,11 +184,17 @@ def test_calibration_prefit(): sparse.csr_matrix(X_test))]: for method in ['isotonic', 'sigmoid']: pc_clf = CalibratedClassifierCV(clf, method=method, cv="prefit") - pc_clf.fit(this_X_calib, y_calib, sample_weight=sw_calib) - prob_pos_pc_clf = pc_clf.predict_proba(this_X_test)[:, 1] - assert_greater(brier_score_loss(y_test, prob_pos_clf), - brier_score_loss(y_test, prob_pos_pc_clf)) + for sw in [sw_calib, None]: + pc_clf.fit(this_X_calib, y_calib, sample_weight=sw) + y_prob = pc_clf.predict_proba(this_X_test) + y_pred = pc_clf.predict(this_X_test) + prob_pos_pc_clf = y_prob[:, 1] + assert_array_equal(y_pred, + np.array([0, 1])[np.argmax(y_prob, axis=1)]) + + assert_greater(brier_score_loss(y_test, prob_pos_clf), + brier_score_loss(y_test, prob_pos_pc_clf)) def test_sigmoid_calibration(): @@ -179,6 +208,11 @@ def test_sigmoid_calibration(): sk_prob = _SigmoidCalibration().fit(exF, exY).predict(exF) assert_array_almost_equal(lin_prob, sk_prob, 6) + # check that _SigmoidCalibration().fit only accepts 1d array or 2d column + # arrays + assert_raises(ValueError, _SigmoidCalibration().fit, + np.vstack((exF, exF)), exY) + def test_calibration_curve(): """Check calibration_curve function""" @@ -193,3 +227,8 @@ def test_calibration_curve(): assert_almost_equal(prob_pred, [0.1, 0.9]) assert_almost_equal(prob_true, prob_true_unnormalized) assert_almost_equal(prob_pred, prob_pred_unnormalized) + + # probabilities outside [0, 1] should not be accepted when normalize + # is set to False + assert_raises(ValueError, calibration_curve, [1.1], [-0.1], + normalize=False) From 9474f09ff95bdc9d3536d5c1bb890b7feee68201 Mon Sep 17 00:00:00 2001 From: Jan Hendrik Metzen Date: Fri, 20 Feb 2015 18:11:00 +0100 Subject: [PATCH 8/8] TST Adding brier_score_loss to test_common.py --- sklearn/metrics/tests/test_common.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py index 34e30e0c7640d..f443e776ad14d 100644 --- a/sklearn/metrics/tests/test_common.py +++ b/sklearn/metrics/tests/test_common.py @@ -25,6 +25,7 @@ from sklearn.metrics import accuracy_score from sklearn.metrics import average_precision_score +from sklearn.metrics import brier_score_loss from sklearn.metrics import confusion_matrix from sklearn.metrics import coverage_error from sklearn.metrics import explained_variance_score @@ -148,6 +149,8 @@ "hinge_loss": hinge_loss, + "brier_score_loss": brier_score_loss, + "roc_auc_score": roc_auc_score, "weighted_roc_auc": partial(roc_auc_score, average="weighted"), "samples_roc_auc": partial(roc_auc_score, average="samples"), @@ -197,6 +200,7 @@ "macro_roc_auc", "samples_roc_auc", "coverage_error", + "brier_score_loss" ] # Metrics with an "average" argument @@ -211,7 +215,9 @@ # Metrics with a "pos_label" argument METRICS_WITH_POS_LABEL = [ - "roc_curve", "hinge_loss", + "roc_curve", + + "brier_score_loss", "precision_score", "recall_score", "f1_score", "f2_score", "f0.5_score", @@ -554,9 +560,15 @@ def test_invariance_string_vs_numbers_labels(): "invariance test".format(name)) for name, metric in THRESHOLDED_METRICS.items(): - if name in ("log_loss", "hinge_loss", "unnormalized_log_loss"): + if name in ("log_loss", "hinge_loss", "unnormalized_log_loss", + "brier_score_loss"): + # Ugly, but handle case with a pos_label and label + metric_str = metric + if name in METRICS_WITH_POS_LABEL: + metric_str = partial(metric_str, pos_label=pos_label_str) + measure_with_number = metric(y1, y2) - measure_with_str = metric(y1_str, y2) + measure_with_str = metric_str(y1_str, y2) assert_array_equal(measure_with_number, measure_with_str, err_msg="{0} failed string vs number " "invariance test".format(name))