From 4d756dcdbaf210e630ef91339f9863915d2c9a5a Mon Sep 17 00:00:00 2001 From: Jan Hendrik Metzen Date: Sun, 7 Sep 2014 10:20:58 +0200 Subject: [PATCH 1/6] ENH Predictive standard deviation optionally returned in Random Forest and Bagging regressors REFACTOR Add option with_std to GaussianProcess.predict This is consistent with the interface of RandomForestRegressor.predict The old way of requesting the predictive variance via eval_MSE is deprecated REFACTOR Tests and examples of GaussianProcess use with_std instead of eval_MSE ADD Example comparing the predictive distributions of different regressors DOC Improved documentation of with_std parameter of predict() method FIX Bug in BaggingRegressor using _parallel_predict_regression DOC More consistent documentation of optional return-value y_std of predict DOC Updated doc of predict() of BaggingRegressor and RandomForestRegressor ENH Extending example plot_predictive_standard_deviation.py --- .../plot_predictive_standard_deviation.py | 123 ++++++++++++++++++ sklearn/ensemble/bagging.py | 29 +++-- sklearn/ensemble/forest.py | 24 +++- 3 files changed, 160 insertions(+), 16 deletions(-) create mode 100644 examples/plot_predictive_standard_deviation.py diff --git a/examples/plot_predictive_standard_deviation.py b/examples/plot_predictive_standard_deviation.py new file mode 100644 index 0000000000000..5319c4a98e5c6 --- /dev/null +++ b/examples/plot_predictive_standard_deviation.py @@ -0,0 +1,123 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +r""" +============================================================== +Comparison of predictive distributions of different regressors +============================================================== + +A simple one-dimensional, noisy regression problem adressed by three different +regressors: + +1. A Gaussian Process +2. A Random Forest +3. A Bagging-based Regressor + +The regressors are fitted based on noisy observations where the magnitude of +the noise at the different training point is constant and known. Plotted are +both the mean and the pointwise 95% confidence interval of the predictions. +The mean predictions are evaluated on noise-less test data using the mean- +squared-error. The mean log probabilities of the noise-less test data are used +to evaluate the predictive distributions (a normal distribution with the +predicted mean and standard deviation) of the three regressors. + +This example is based on the example gaussian_process/plot_gp_regression.py. +""" +print(__doc__) + +# Author: Jan Hendrik Metzen +# Licence: BSD 3 clause + +import numpy as np +from scipy.stats import norm +from sklearn.gaussian_process import GaussianProcess +from sklearn.ensemble import RandomForestRegressor, BaggingRegressor +from sklearn.metrics import mean_squared_error +from matplotlib import pyplot as pl + +np.random.seed(1) + + +def f(x): + """The function to predict.""" + return x * np.sin(x) + +X = np.linspace(0.1, 9.9, 20) +X = np.atleast_2d(X).T + +# Observations and noise +y = f(X).ravel() +dy = np.ones_like(y) +noise = np.random.normal(0, dy) +y += noise + +# Mesh the input space for evaluations of the real function, the prediction and +# its standard deviation +x = np.atleast_2d(np.linspace(0, 10, 1000)).T + +regrs = {"Gaussian Process": GaussianProcess(corr='squared_exponential', + theta0=1e-1, thetaL=1e-3, + thetaU=1, nugget=(dy / y) ** 2, + random_start=100), + "Random Forest": RandomForestRegressor(n_estimators=250), + "Bagging": BaggingRegressor(n_estimators=250)} + + +# Plot predictive distributions of different regressors +fig = pl.figure() +# Plot the function and the observations +pl.plot(x, f(x), 'r', label=u'$f(x) = x\,\sin(x)$') +pl.fill(np.concatenate([x, x[::-1]]), + np.concatenate([f(x) - 1.9600, (f(x) + 1.9600)[::-1]]), + alpha=.3, fc='r', ec='None') +pl.plot(X.ravel(), y, 'ko', zorder=5, label=u'Observations') +# Plot predictive distibutions of GP and Bagging +colors = {"Gaussian Process": 'b', "Bagging": 'g'} +mse = {} +log_pdf_loss = {} +for name, regr in regrs.items(): + regr.fit(X, y) + + # Make the prediction on the meshed x-axis (ask for standard deviation + # as well) + y_pred, sigma = regr.predict(x, with_std=True) + + # Compute mean-squared error and log predictive loss + mse[name] = mean_squared_error(f(x), y_pred) + log_pdf_loss[name] = \ + norm(y_pred, sigma).logpdf(f(x)).mean() + + if name == "Random Forest": # Skip because RF is very similar to Bagging + continue + + # Plot 95% confidence interval based on the predictive standard deviation + pl.plot(x, y_pred, colors[name], label=name) + pl.fill(np.concatenate([x, x[::-1]]), + np.concatenate([y_pred - 1.9600 * sigma, + (y_pred + 1.9600 * sigma)[::-1]]), + alpha=.3, fc=colors[name], ec='None') + + +pl.xlabel('$x$') +pl.ylabel('$f(x)$') +pl.ylim(-10, 20) +pl.legend(loc='upper left') + +print "Mean-squared error of predictors on 1000 equidistant noise-less test " \ + "datapoints:\n\tRandom Forest: %.2f\n\tBagging: %.2f" \ + "\n\tGaussian Process: %.2f" \ + % (mse["Random Forest"], mse["Bagging"], mse["Gaussian Process"]) + +print "Mean log-probability of 1000 equidistant noise-less test datapoints " \ + "under the (normal) predictive distribution of the predictors, i.e., " \ + "log N(y_true| y_pred_mean, y_pred_std) [less is better]:"\ + "\n\tRandom Forest: %.2f\n\tBagging: %.2f\n\tGaussian Process: %.2f" \ + % (log_pdf_loss["Random Forest"], log_pdf_loss["Bagging"], + log_pdf_loss["Gaussian Process"]) + +print "In summary, the mean predictions of the Gaussian Process are slightly "\ + "better than those of Random Forest and Bagging. The predictive " \ + "distributions (taking into account also the predictive variance) " \ + "of the Gaussian Process are considerably better." + +pl.show() diff --git a/sklearn/ensemble/bagging.py b/sklearn/ensemble/bagging.py index f9849d33389fb..842f2bd609e45 100644 --- a/sklearn/ensemble/bagging.py +++ b/sklearn/ensemble/bagging.py @@ -185,9 +185,8 @@ def _parallel_decision_function(estimators, estimators_features, X): def _parallel_predict_regression(estimators, estimators_features, X): """Private function used to compute predictions within a job.""" - return sum(estimator.predict(X[:, features]) - for estimator, features in zip(estimators, - estimators_features)) + return [estimator.predict(X[:, features]) + for estimator, features in zip(estimators, estimators_features)] class BaseBagging(with_metaclass(ABCMeta, BaseEnsemble)): @@ -876,11 +875,13 @@ def __init__(self, random_state=random_state, verbose=verbose) - def predict(self, X): + def predict(self, X, with_std=False): """Predict regression target for X. The predicted regression target of an input sample is computed as the mean predicted regression targets of the estimators in the ensemble. + Optionally, the standard deviation of the predictions of the ensemble's + estimators is computed in addition. Parameters ---------- @@ -888,10 +889,17 @@ def predict(self, X): The training input samples. Sparse matrices are accepted only if they are supported by the base estimator. + with_std : boolean, optional, default=False + When True, the standard deviation of the predictions of the + ensemble's estimators is returned in addition to the mean. + Returns ------- - y : array of shape = [n_samples] - The predicted values. + y_mean : array of shape = [n_samples] + The mean of the predicted values. + + y_std : array of shape = [n_samples], optional (if with_std == True) + The standard deviation of the ensemble's predicted values. """ check_is_fitted(self, "estimators_features_") # Check data @@ -909,9 +917,12 @@ def predict(self, X): for i in range(n_jobs)) # Reduce - y_hat = sum(all_y_hat) / self.n_estimators - - return y_hat + all_y_hat = np.array(all_y_hat).reshape(self.n_estimators, -1) + y_mean = np.mean(all_y_hat, axis=0) + if with_std: + return y_mean, np.std(all_y_hat, axis=0) + else: + return y_mean def _validate_estimator(self): """Check the estimator and set the base_estimator_ attribute.""" diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py index 9349c28b44339..cb157115c953f 100644 --- a/sklearn/ensemble/forest.py +++ b/sklearn/ensemble/forest.py @@ -662,11 +662,13 @@ def __init__(self, verbose=verbose, warm_start=warm_start) - def predict(self, X): + def predict(self, X, with_std=False): """Predict regression target for X. The predicted regression target of an input sample is computed as the mean predicted regression targets of the trees in the forest. + Optionally, the standard deviation of the predictions of the ensemble's + estimators is computed in addition. Parameters ---------- @@ -675,10 +677,17 @@ def predict(self, X): ``dtype=np.float32`` and if a sparse matrix is provided to a sparse ``csr_matrix``. + with_std : boolean, optional, default=False + When True, the standard deviation of the predictions of the + ensemble's estimators is returned in addition to the mean. + Returns ------- - y : array of shape = [n_samples] or [n_samples, n_outputs] - The predicted values. + y_mean: array of shape = [n_samples] or [n_samples, n_outputs] + The mean of the predicted values. + + y_std : array of shape = [n_samples], optional (if with_std == True) + The standard deviation of the predicted values. """ # Check data X = self._validate_X_predict(X) @@ -692,10 +701,11 @@ def predict(self, X): delayed(_parallel_helper)(e, 'predict', X, check_input=False) for e in self.estimators_) - # Reduce - y_hat = sum(all_y_hat) / len(self.estimators_) - - return y_hat + y_mean = np.mean(all_y_hat, axis=0) + if with_std: + return y_mean, np.std(all_y_hat, axis=0) + else: + return y_mean def _set_oob_score(self, X, y): """Compute out-of-bag scores""" From 714dca0c7840234ec21f87b0a549966e8f800bf7 Mon Sep 17 00:00:00 2001 From: Gilles Louppe Date: Thu, 22 Oct 2015 16:22:33 +0200 Subject: [PATCH 2/6] FIX: rename with_std to return_std --- .../plot_predictive_standard_deviation.py | 43 +++++++++---------- sklearn/ensemble/bagging.py | 8 ++-- sklearn/ensemble/forest.py | 8 ++-- 3 files changed, 28 insertions(+), 31 deletions(-) diff --git a/examples/plot_predictive_standard_deviation.py b/examples/plot_predictive_standard_deviation.py index 5319c4a98e5c6..5940132730de9 100644 --- a/examples/plot_predictive_standard_deviation.py +++ b/examples/plot_predictive_standard_deviation.py @@ -30,7 +30,7 @@ import numpy as np from scipy.stats import norm -from sklearn.gaussian_process import GaussianProcess +from sklearn.gaussian_process import GaussianProcessRegressor from sklearn.ensemble import RandomForestRegressor, BaggingRegressor from sklearn.metrics import mean_squared_error from matplotlib import pyplot as pl @@ -55,10 +55,7 @@ def f(x): # its standard deviation x = np.atleast_2d(np.linspace(0, 10, 1000)).T -regrs = {"Gaussian Process": GaussianProcess(corr='squared_exponential', - theta0=1e-1, thetaL=1e-3, - thetaU=1, nugget=(dy / y) ** 2, - random_start=100), +regrs = {"Gaussian Process": GaussianProcessRegressor(alpha=(dy / y) ** 2), "Random Forest": RandomForestRegressor(n_estimators=250), "Bagging": BaggingRegressor(n_estimators=250)} @@ -80,7 +77,7 @@ def f(x): # Make the prediction on the meshed x-axis (ask for standard deviation # as well) - y_pred, sigma = regr.predict(x, with_std=True) + y_pred, sigma = regr.predict(x, return_std=True) # Compute mean-squared error and log predictive loss mse[name] = mean_squared_error(f(x), y_pred) @@ -94,7 +91,7 @@ def f(x): pl.plot(x, y_pred, colors[name], label=name) pl.fill(np.concatenate([x, x[::-1]]), np.concatenate([y_pred - 1.9600 * sigma, - (y_pred + 1.9600 * sigma)[::-1]]), + (y_pred + 1.9600 * sigma)[::-1]]), alpha=.3, fc=colors[name], ec='None') @@ -103,21 +100,21 @@ def f(x): pl.ylim(-10, 20) pl.legend(loc='upper left') -print "Mean-squared error of predictors on 1000 equidistant noise-less test " \ - "datapoints:\n\tRandom Forest: %.2f\n\tBagging: %.2f" \ - "\n\tGaussian Process: %.2f" \ - % (mse["Random Forest"], mse["Bagging"], mse["Gaussian Process"]) - -print "Mean log-probability of 1000 equidistant noise-less test datapoints " \ - "under the (normal) predictive distribution of the predictors, i.e., " \ - "log N(y_true| y_pred_mean, y_pred_std) [less is better]:"\ - "\n\tRandom Forest: %.2f\n\tBagging: %.2f\n\tGaussian Process: %.2f" \ - % (log_pdf_loss["Random Forest"], log_pdf_loss["Bagging"], - log_pdf_loss["Gaussian Process"]) - -print "In summary, the mean predictions of the Gaussian Process are slightly "\ - "better than those of Random Forest and Bagging. The predictive " \ - "distributions (taking into account also the predictive variance) " \ - "of the Gaussian Process are considerably better." +print("Mean-squared error of predictors on 1000 equidistant noise-less test " + "datapoints:\n\tRandom Forest: %.2f\n\tBagging: %.2f" + "\n\tGaussian Process: %.2f" + % (mse["Random Forest"], mse["Bagging"], mse["Gaussian Process"])) + +print("Mean log-probability of 1000 equidistant noise-less test datapoints " + "under the (normal) predictive distribution of the predictors, i.e., " + "log N(y_true| y_pred_mean, y_pred_std) [less is better]:" + "\n\tRandom Forest: %.2f\n\tBagging: %.2f\n\tGaussian Process: %.2f" + % (log_pdf_loss["Random Forest"], log_pdf_loss["Bagging"], + log_pdf_loss["Gaussian Process"])) + +print("In summary, the mean predictions of the Gaussian Process are slightly " + "better than those of Random Forest and Bagging. The predictive " + "distributions (taking into account also the predictive variance) " + "of the Gaussian Process are considerably better.") pl.show() diff --git a/sklearn/ensemble/bagging.py b/sklearn/ensemble/bagging.py index 842f2bd609e45..38147e78320d2 100644 --- a/sklearn/ensemble/bagging.py +++ b/sklearn/ensemble/bagging.py @@ -875,7 +875,7 @@ def __init__(self, random_state=random_state, verbose=verbose) - def predict(self, X, with_std=False): + def predict(self, X, return_std=False): """Predict regression target for X. The predicted regression target of an input sample is computed as the @@ -889,7 +889,7 @@ def predict(self, X, with_std=False): The training input samples. Sparse matrices are accepted only if they are supported by the base estimator. - with_std : boolean, optional, default=False + return_std : boolean, optional, default=False When True, the standard deviation of the predictions of the ensemble's estimators is returned in addition to the mean. @@ -898,7 +898,7 @@ def predict(self, X, with_std=False): y_mean : array of shape = [n_samples] The mean of the predicted values. - y_std : array of shape = [n_samples], optional (if with_std == True) + y_std : array of shape = [n_samples], optional (if return_std == True) The standard deviation of the ensemble's predicted values. """ check_is_fitted(self, "estimators_features_") @@ -919,7 +919,7 @@ def predict(self, X, with_std=False): # Reduce all_y_hat = np.array(all_y_hat).reshape(self.n_estimators, -1) y_mean = np.mean(all_y_hat, axis=0) - if with_std: + if return_std: return y_mean, np.std(all_y_hat, axis=0) else: return y_mean diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py index cb157115c953f..7131b250ba2b5 100644 --- a/sklearn/ensemble/forest.py +++ b/sklearn/ensemble/forest.py @@ -662,7 +662,7 @@ def __init__(self, verbose=verbose, warm_start=warm_start) - def predict(self, X, with_std=False): + def predict(self, X, return_std=False): """Predict regression target for X. The predicted regression target of an input sample is computed as the @@ -677,7 +677,7 @@ def predict(self, X, with_std=False): ``dtype=np.float32`` and if a sparse matrix is provided to a sparse ``csr_matrix``. - with_std : boolean, optional, default=False + return_std : boolean, optional, default=False When True, the standard deviation of the predictions of the ensemble's estimators is returned in addition to the mean. @@ -686,7 +686,7 @@ def predict(self, X, with_std=False): y_mean: array of shape = [n_samples] or [n_samples, n_outputs] The mean of the predicted values. - y_std : array of shape = [n_samples], optional (if with_std == True) + y_std : array of shape = [n_samples], optional (if return_std == True) The standard deviation of the predicted values. """ # Check data @@ -702,7 +702,7 @@ def predict(self, X, with_std=False): for e in self.estimators_) y_mean = np.mean(all_y_hat, axis=0) - if with_std: + if return_std: return y_mean, np.std(all_y_hat, axis=0) else: return y_mean From de829ed90b79c57aaa6b0bbefde33671e6f9c2ac Mon Sep 17 00:00:00 2001 From: Gilles Louppe Date: Thu, 22 Oct 2015 19:00:57 +0200 Subject: [PATCH 3/6] FIX: optimize kernel parameters --- examples/plot_predictive_standard_deviation.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/examples/plot_predictive_standard_deviation.py b/examples/plot_predictive_standard_deviation.py index 5940132730de9..40162e4ccc89a 100644 --- a/examples/plot_predictive_standard_deviation.py +++ b/examples/plot_predictive_standard_deviation.py @@ -31,6 +31,7 @@ import numpy as np from scipy.stats import norm from sklearn.gaussian_process import GaussianProcessRegressor +from sklearn.gaussian_process.kernels import RBF, WhiteKernel from sklearn.ensemble import RandomForestRegressor, BaggingRegressor from sklearn.metrics import mean_squared_error from matplotlib import pyplot as pl @@ -55,10 +56,13 @@ def f(x): # its standard deviation x = np.atleast_2d(np.linspace(0, 10, 1000)).T -regrs = {"Gaussian Process": GaussianProcessRegressor(alpha=(dy / y) ** 2), - "Random Forest": RandomForestRegressor(n_estimators=250), - "Bagging": BaggingRegressor(n_estimators=250)} - +regrs = { + "Gaussian Process": GaussianProcessRegressor( + alpha=(dy / y) ** 2, kernel=1.0 * RBF() + WhiteKernel(), + n_restarts_optimizer=100), + "Random Forest": RandomForestRegressor(n_estimators=250), + "Bagging": BaggingRegressor(n_estimators=250) +} # Plot predictive distributions of different regressors fig = pl.figure() From 0061ae7725c610f490ce0b9492c4a503e1823739 Mon Sep 17 00:00:00 2001 From: Gilles Louppe Date: Mon, 26 Oct 2015 09:02:04 +0100 Subject: [PATCH 4/6] DOC: cleanup example --- .../plot_predictive_standard_deviation.py | 128 ++++++++---------- 1 file changed, 60 insertions(+), 68 deletions(-) diff --git a/examples/plot_predictive_standard_deviation.py b/examples/plot_predictive_standard_deviation.py index 40162e4ccc89a..93d5839ff5357 100644 --- a/examples/plot_predictive_standard_deviation.py +++ b/examples/plot_predictive_standard_deviation.py @@ -1,124 +1,116 @@ -#!/usr/bin/python -# -*- coding: utf-8 -*- - -r""" +""" ============================================================== Comparison of predictive distributions of different regressors ============================================================== -A simple one-dimensional, noisy regression problem adressed by three different +A simple one-dimensional, noisy regression problem addressed by two different regressors: -1. A Gaussian Process -2. A Random Forest -3. A Bagging-based Regressor +1. A Random Forest +2. A Gaussian Process The regressors are fitted based on noisy observations where the magnitude of the noise at the different training point is constant and known. Plotted are both the mean and the pointwise 95% confidence interval of the predictions. -The mean predictions are evaluated on noise-less test data using the mean- -squared-error. The mean log probabilities of the noise-less test data are used +The mean predictions are evaluated on noise-less test data using the mean +squared error. The mean log probabilities of the noise-less test data are used to evaluate the predictive distributions (a normal distribution with the predicted mean and standard deviation) of the three regressors. -This example is based on the example gaussian_process/plot_gp_regression.py. +The mean predictions of the Gaussian Process are slightly better than those of +Random Forest. The predictive distribution (taking into account +also the predictive variance) of the Random Forest is however slightly +more likely for this example. """ print(__doc__) -# Author: Jan Hendrik Metzen +# Authors: Jan Hendrik Metzen +# Gilles Louppe # Licence: BSD 3 clause +import matplotlib.pyplot as plt import numpy as np from scipy.stats import norm + +from sklearn.ensemble import RandomForestRegressor from sklearn.gaussian_process import GaussianProcessRegressor from sklearn.gaussian_process.kernels import RBF, WhiteKernel -from sklearn.ensemble import RandomForestRegressor, BaggingRegressor from sklearn.metrics import mean_squared_error -from matplotlib import pyplot as pl - -np.random.seed(1) +from sklearn.utils import check_random_state +rng = check_random_state(1) +# Observations and noise def f(x): """The function to predict.""" return x * np.sin(x) -X = np.linspace(0.1, 9.9, 20) -X = np.atleast_2d(X).T +X_train = rng.rand(20) * 10.0 +X_train = np.atleast_2d(X_train).T -# Observations and noise -y = f(X).ravel() -dy = np.ones_like(y) -noise = np.random.normal(0, dy) -y += noise +y_train = f(X_train).ravel() +dy = np.ones_like(y_train) +noise = rng.normal(0, dy) +y_train += noise # Mesh the input space for evaluations of the real function, the prediction and # its standard deviation -x = np.atleast_2d(np.linspace(0, 10, 1000)).T +X_test = np.atleast_2d(np.linspace(0, 10, 1000)).T -regrs = { +regressors = { "Gaussian Process": GaussianProcessRegressor( - alpha=(dy / y) ** 2, kernel=1.0 * RBF() + WhiteKernel(), - n_restarts_optimizer=100), - "Random Forest": RandomForestRegressor(n_estimators=250), - "Bagging": BaggingRegressor(n_estimators=250) + alpha=(dy / y_train) ** 2, + kernel=1.0 * RBF() + WhiteKernel(), + n_restarts_optimizer=100, random_state=rng), + "Random Forest": RandomForestRegressor(n_estimators=500, random_state=rng), } -# Plot predictive distributions of different regressors -fig = pl.figure() # Plot the function and the observations -pl.plot(x, f(x), 'r', label=u'$f(x) = x\,\sin(x)$') -pl.fill(np.concatenate([x, x[::-1]]), - np.concatenate([f(x) - 1.9600, (f(x) + 1.9600)[::-1]]), - alpha=.3, fc='r', ec='None') -pl.plot(X.ravel(), y, 'ko', zorder=5, label=u'Observations') +fig = plt.figure() +plt.plot(X_test, f(X_test), 'r', label=u'$f(x) = x\,\sin(x)$') +plt.fill(np.concatenate([X_test, X_test[::-1]]), + np.concatenate([f(X_test) - 1.9600, (f(X_test) + 1.9600)[::-1]]), + alpha=.3, fc='r', ec='None') +plt.plot(X_train.ravel(), y_train, 'ko', zorder=5, label=u'Observations') + # Plot predictive distibutions of GP and Bagging -colors = {"Gaussian Process": 'b', "Bagging": 'g'} +colors = {"Gaussian Process": 'b', "Random Forest": 'g'} mse = {} log_pdf_loss = {} -for name, regr in regrs.items(): - regr.fit(X, y) + +for name, regr in regressors.items(): + regr.fit(X_train, y_train) # Make the prediction on the meshed x-axis (ask for standard deviation # as well) - y_pred, sigma = regr.predict(x, return_std=True) + y_pred, sigma = regr.predict(X_test, return_std=True) # Compute mean-squared error and log predictive loss - mse[name] = mean_squared_error(f(x), y_pred) + mse[name] = mean_squared_error(f(X_test), y_pred) log_pdf_loss[name] = \ - norm(y_pred, sigma).logpdf(f(x)).mean() - - if name == "Random Forest": # Skip because RF is very similar to Bagging - continue + norm(y_pred, sigma).logpdf(f(X_test)).mean() # Plot 95% confidence interval based on the predictive standard deviation - pl.plot(x, y_pred, colors[name], label=name) - pl.fill(np.concatenate([x, x[::-1]]), - np.concatenate([y_pred - 1.9600 * sigma, - (y_pred + 1.9600 * sigma)[::-1]]), - alpha=.3, fc=colors[name], ec='None') + plt.plot(X_test, y_pred, colors[name], label=name) + plt.fill(np.concatenate([X_test, X_test[::-1]]), + np.concatenate([y_pred - 1.9600 * sigma, + (y_pred + 1.9600 * sigma)[::-1]]), + alpha=.3, fc=colors[name], ec='None') - -pl.xlabel('$x$') -pl.ylabel('$f(x)$') -pl.ylim(-10, 20) -pl.legend(loc='upper left') +plt.xlabel('$x$') +plt.ylabel('$f(x)$') +plt.ylim(-10, 20) +plt.legend(loc='upper left') print("Mean-squared error of predictors on 1000 equidistant noise-less test " - "datapoints:\n\tRandom Forest: %.2f\n\tBagging: %.2f" - "\n\tGaussian Process: %.2f" - % (mse["Random Forest"], mse["Bagging"], mse["Gaussian Process"])) + "datapoints:\n\tRandom Forest: %.2f\n\tGaussian Process: %.2f\n" + % (mse["Random Forest"], mse["Gaussian Process"])) -print("Mean log-probability of 1000 equidistant noise-less test datapoints " - "under the (normal) predictive distribution of the predictors, i.e., " +print("Mean log-probability of 1000 equidistant noise-less test datapoints\n" + "under the (normal) predictive distribution of the predictors, i.e.,\n" "log N(y_true| y_pred_mean, y_pred_std) [less is better]:" - "\n\tRandom Forest: %.2f\n\tBagging: %.2f\n\tGaussian Process: %.2f" - % (log_pdf_loss["Random Forest"], log_pdf_loss["Bagging"], + "\n\tRandom Forest: %.2f\n\tGaussian Process: %.2f\n" + % (log_pdf_loss["Random Forest"], log_pdf_loss["Gaussian Process"])) -print("In summary, the mean predictions of the Gaussian Process are slightly " - "better than those of Random Forest and Bagging. The predictive " - "distributions (taking into account also the predictive variance) " - "of the Gaussian Process are considerably better.") - -pl.show() +plt.show() From 2e93e4d18b2e40e60b5b2c94fa10d09556544be4 Mon Sep 17 00:00:00 2001 From: Gilles Louppe Date: Mon, 26 Oct 2015 16:01:19 +0100 Subject: [PATCH 5/6] ENH: implement infinitesimal jacknife estimate for bagging std Conflicts: sklearn/ensemble/bagging.py --- .../plot_predictive_standard_deviation.py | 44 ++++++++++--------- sklearn/ensemble/bagging.py | 37 +++++++++++----- sklearn/ensemble/forest.py | 8 ++-- 3 files changed, 55 insertions(+), 34 deletions(-) diff --git a/examples/plot_predictive_standard_deviation.py b/examples/plot_predictive_standard_deviation.py index 93d5839ff5357..6583c81121faa 100644 --- a/examples/plot_predictive_standard_deviation.py +++ b/examples/plot_predictive_standard_deviation.py @@ -6,8 +6,8 @@ A simple one-dimensional, noisy regression problem addressed by two different regressors: -1. A Random Forest -2. A Gaussian Process +1. A Gaussian Process +2. Bagging with extra-trees The regressors are fitted based on noisy observations where the magnitude of the noise at the different training point is constant and known. Plotted are @@ -19,7 +19,7 @@ The mean predictions of the Gaussian Process are slightly better than those of Random Forest. The predictive distribution (taking into account -also the predictive variance) of the Random Forest is however slightly +also the predictive variance) of Bagging is however more likely for this example. """ print(__doc__) @@ -32,7 +32,8 @@ import numpy as np from scipy.stats import norm -from sklearn.ensemble import RandomForestRegressor +from sklearn.ensemble import BaggingRegressor +from sklearn.tree import ExtraTreeRegressor from sklearn.gaussian_process import GaussianProcessRegressor from sklearn.gaussian_process.kernels import RBF, WhiteKernel from sklearn.metrics import mean_squared_error @@ -45,7 +46,7 @@ def f(x): """The function to predict.""" return x * np.sin(x) -X_train = rng.rand(20) * 10.0 +X_train = rng.rand(50) * 10.0 X_train = np.atleast_2d(X_train).T y_train = f(X_train).ravel() @@ -57,13 +58,17 @@ def f(x): # its standard deviation X_test = np.atleast_2d(np.linspace(0, 10, 1000)).T -regressors = { - "Gaussian Process": GaussianProcessRegressor( +regressors = [ + ("Gaussian Process", GaussianProcessRegressor( alpha=(dy / y_train) ** 2, kernel=1.0 * RBF() + WhiteKernel(), - n_restarts_optimizer=100, random_state=rng), - "Random Forest": RandomForestRegressor(n_estimators=500, random_state=rng), -} + n_restarts_optimizer=10, + random_state=rng), "b"), + ("Bagging", BaggingRegressor( + base_estimator=ExtraTreeRegressor(), + n_estimators=250, + random_state=rng), "g") +] # Plot the function and the observations fig = plt.figure() @@ -74,16 +79,15 @@ def f(x): plt.plot(X_train.ravel(), y_train, 'ko', zorder=5, label=u'Observations') # Plot predictive distibutions of GP and Bagging -colors = {"Gaussian Process": 'b', "Random Forest": 'g'} mse = {} log_pdf_loss = {} -for name, regr in regressors.items(): - regr.fit(X_train, y_train) +for name, regressor, color in regressors: + regressor.fit(X_train, y_train) # Make the prediction on the meshed x-axis (ask for standard deviation # as well) - y_pred, sigma = regr.predict(X_test, return_std=True) + y_pred, sigma = regressor.predict(X_test, return_std=True) # Compute mean-squared error and log predictive loss mse[name] = mean_squared_error(f(X_test), y_pred) @@ -91,11 +95,11 @@ def f(x): norm(y_pred, sigma).logpdf(f(X_test)).mean() # Plot 95% confidence interval based on the predictive standard deviation - plt.plot(X_test, y_pred, colors[name], label=name) + plt.plot(X_test, y_pred, color, label=name) plt.fill(np.concatenate([X_test, X_test[::-1]]), np.concatenate([y_pred - 1.9600 * sigma, (y_pred + 1.9600 * sigma)[::-1]]), - alpha=.3, fc=colors[name], ec='None') + alpha=.3, fc=color, ec='None') plt.xlabel('$x$') plt.ylabel('$f(x)$') @@ -103,14 +107,14 @@ def f(x): plt.legend(loc='upper left') print("Mean-squared error of predictors on 1000 equidistant noise-less test " - "datapoints:\n\tRandom Forest: %.2f\n\tGaussian Process: %.2f\n" - % (mse["Random Forest"], mse["Gaussian Process"])) + "datapoints:\n\tBagging: %.2f\n\tGaussian Process: %.2f\n" + % (mse["Bagging"], mse["Gaussian Process"])) print("Mean log-probability of 1000 equidistant noise-less test datapoints\n" "under the (normal) predictive distribution of the predictors, i.e.,\n" "log N(y_true| y_pred_mean, y_pred_std) [less is better]:" - "\n\tRandom Forest: %.2f\n\tGaussian Process: %.2f\n" - % (log_pdf_loss["Random Forest"], + "\n\tBagging: %.2f\n\tGaussian Process: %.2f\n" + % (log_pdf_loss["Bagging"], log_pdf_loss["Gaussian Process"])) plt.show() diff --git a/sklearn/ensemble/bagging.py b/sklearn/ensemble/bagging.py index 38147e78320d2..670d872746eb0 100644 --- a/sklearn/ensemble/bagging.py +++ b/sklearn/ensemble/bagging.py @@ -282,7 +282,7 @@ def _fit(self, X, y, max_samples, sample_weight=None): X, y = check_X_y(X, y, ['csr', 'csc']) # Remap output - n_samples, self.n_features_ = X.shape + self.n_samples_, self.n_features_ = X.shape y = self._validate_y(y) # Check parameters @@ -290,9 +290,9 @@ def _fit(self, X, y, max_samples, sample_weight=None): # if max_samples is float: if not isinstance(max_samples, (numbers.Integral, np.integer)): - max_samples = int(self.max_samples * X.shape[0]) + max_samples = int(self.max_samples * self.n_samples_) - if not (0 < max_samples <= X.shape[0]): + if not (0 < max_samples <= self.n_samples_): raise ValueError("max_samples must be in (0, n_samples]") if isinstance(self.max_features, (numbers.Integral, np.integer)): @@ -890,8 +890,8 @@ def predict(self, X, return_std=False): they are supported by the base estimator. return_std : boolean, optional, default=False - When True, the standard deviation of the predictions of the - ensemble's estimators is returned in addition to the mean. + When True, the sampling standard deviation of the predictions of + the ensemble is returned in addition to the predicted values. Returns ------- @@ -899,10 +899,10 @@ def predict(self, X, return_std=False): The mean of the predicted values. y_std : array of shape = [n_samples], optional (if return_std == True) - The standard deviation of the ensemble's predicted values. + The sampling standard deviation of the predicted values. """ + # Checks check_is_fitted(self, "estimators_features_") - # Check data X = check_array(X, accept_sparse=['csr', 'csc']) # Parallel loop @@ -919,11 +919,28 @@ def predict(self, X, return_std=False): # Reduce all_y_hat = np.array(all_y_hat).reshape(self.n_estimators, -1) y_mean = np.mean(all_y_hat, axis=0) - if return_std: - return y_mean, np.std(all_y_hat, axis=0) - else: + + if not return_std: return y_mean + else: + # Infinitesimal jacknife (IJ) estimate of the sampling variance + + # TODO: check correctness + # TODO: bias correction (compare with R code) + + var_IJ = np.zeros(len(X)) + N_bi = np.zeros((self.n_estimators, self.n_samples_)) + + for b, samples in enumerate(self.estimators_samples_): + N_bi[b, samples] += 1 + + var_IJ = np.dot((N_bi - np.mean(N_bi, axis=0)).T, + all_y_hat - y_mean) + var_IJ = (var_IJ ** 2).sum(axis=0) / self.n_estimators ** 2 + + return y_mean, var_IJ ** 0.5 + def _validate_estimator(self): """Check the estimator and set the base_estimator_ attribute.""" super(BaggingRegressor, self)._validate_estimator( diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py index 7131b250ba2b5..9e28f46bf008a 100644 --- a/sklearn/ensemble/forest.py +++ b/sklearn/ensemble/forest.py @@ -678,8 +678,8 @@ def predict(self, X, return_std=False): to a sparse ``csr_matrix``. return_std : boolean, optional, default=False - When True, the standard deviation of the predictions of the - ensemble's estimators is returned in addition to the mean. + When True, the sampling standard deviation of the predictions of + the ensemble is returned in addition to the predicted values. Returns ------- @@ -687,9 +687,9 @@ def predict(self, X, return_std=False): The mean of the predicted values. y_std : array of shape = [n_samples], optional (if return_std == True) - The standard deviation of the predicted values. + The sampling standard deviation of the predicted values. """ - # Check data + # Checks X = self._validate_X_predict(X) # Assign chunk of trees to jobs From 523bbf9c5afff59b92c6bdb26ccd599ba82c2cc3 Mon Sep 17 00:00:00 2001 From: Gilles Louppe Date: Tue, 27 Oct 2015 09:21:00 +0100 Subject: [PATCH 6/6] ENH: implement Jacknife estimate for sampling variance --- .../plot_predictive_standard_deviation.py | 20 ++++---- sklearn/ensemble/bagging.py | 31 ++++++++---- sklearn/ensemble/forest.py | 47 +++++++++++++++---- 3 files changed, 70 insertions(+), 28 deletions(-) diff --git a/examples/plot_predictive_standard_deviation.py b/examples/plot_predictive_standard_deviation.py index 6583c81121faa..f3341361be160 100644 --- a/examples/plot_predictive_standard_deviation.py +++ b/examples/plot_predictive_standard_deviation.py @@ -15,12 +15,11 @@ The mean predictions are evaluated on noise-less test data using the mean squared error. The mean log probabilities of the noise-less test data are used to evaluate the predictive distributions (a normal distribution with the -predicted mean and standard deviation) of the three regressors. +predicted mean and standard deviation) of the two regressors. The mean predictions of the Gaussian Process are slightly better than those of Random Forest. The predictive distribution (taking into account -also the predictive variance) of Bagging is however -more likely for this example. +also the predictive variance) of Bagging is however better. """ print(__doc__) @@ -41,6 +40,7 @@ rng = check_random_state(1) + # Observations and noise def f(x): """The function to predict.""" @@ -66,7 +66,7 @@ def f(x): random_state=rng), "b"), ("Bagging", BaggingRegressor( base_estimator=ExtraTreeRegressor(), - n_estimators=250, + n_estimators=1000, random_state=rng), "g") ] @@ -107,14 +107,14 @@ def f(x): plt.legend(loc='upper left') print("Mean-squared error of predictors on 1000 equidistant noise-less test " - "datapoints:\n\tBagging: %.2f\n\tGaussian Process: %.2f\n" - % (mse["Bagging"], mse["Gaussian Process"])) + "datapoints:") +for name, _, _ in regressors: + print("\t%s: %.2f" % (name, mse[name])) print("Mean log-probability of 1000 equidistant noise-less test datapoints\n" "under the (normal) predictive distribution of the predictors, i.e.,\n" - "log N(y_true| y_pred_mean, y_pred_std) [less is better]:" - "\n\tBagging: %.2f\n\tGaussian Process: %.2f\n" - % (log_pdf_loss["Bagging"], - log_pdf_loss["Gaussian Process"])) + "log N(y_true| y_pred_mean, y_pred_std) [less is better]:") +for name, _, _ in regressors: + print("\t%s: %.2f" % (name, log_pdf_loss[name])) plt.show() diff --git a/sklearn/ensemble/bagging.py b/sklearn/ensemble/bagging.py index 670d872746eb0..f83a93309d088 100644 --- a/sklearn/ensemble/bagging.py +++ b/sklearn/ensemble/bagging.py @@ -892,6 +892,8 @@ def predict(self, X, return_std=False): return_std : boolean, optional, default=False When True, the sampling standard deviation of the predictions of the ensemble is returned in addition to the predicted values. + Standard deviations are computed using bias-corrected + Jacknife-after-bootstrap estimates, as decribed in arXiv:1311.4555. Returns ------- @@ -905,6 +907,11 @@ def predict(self, X, return_std=False): check_is_fitted(self, "estimators_features_") X = check_array(X, accept_sparse=['csr', 'csc']) + if return_std and not self.bootstrap: + raise ValueError("The sampling standard deviation of the " + "predicted values can be computed only when " + "bootstrap=True.") + # Parallel loop n_jobs, n_estimators, starts = _partition_estimators(self.n_estimators, self.n_jobs) @@ -924,22 +931,26 @@ def predict(self, X, return_std=False): return y_mean else: - # Infinitesimal jacknife (IJ) estimate of the sampling variance - - # TODO: check correctness - # TODO: bias correction (compare with R code) - - var_IJ = np.zeros(len(X)) + # Jacknife-after-bootstrap estimate of the sampling variance + var_J = np.zeros(len(X)) N_bi = np.zeros((self.n_estimators, self.n_samples_)) for b, samples in enumerate(self.estimators_samples_): N_bi[b, samples] += 1 - var_IJ = np.dot((N_bi - np.mean(N_bi, axis=0)).T, - all_y_hat - y_mean) - var_IJ = (var_IJ ** 2).sum(axis=0) / self.n_estimators ** 2 + out_of_bag = (N_bi == 0) + + for i in range(self.n_samples_): + if np.any(out_of_bag[:, i]): + delta_i = all_y_hat[out_of_bag[:, i]].mean(axis=0) - y_mean + var_J += delta_i ** 2 + + var_J *= (self.n_samples_ - 1.) / self.n_samples_ + correction = (self.n_samples_ * np.var(all_y_hat, axis=0) / + self.n_estimators) + var_J[correction < var_J] -= correction[correction < var_J] - return y_mean, var_IJ ** 0.5 + return y_mean, var_J ** 0.5 def _validate_estimator(self): """Check the estimator and set the base_estimator_ attribute.""" diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py index 9e28f46bf008a..588136a4b5526 100644 --- a/sklearn/ensemble/forest.py +++ b/sklearn/ensemble/forest.py @@ -251,7 +251,7 @@ def fit(self, X, y, sample_weight=None): X.sort_indices() # Remap output - n_samples, self.n_features_ = X.shape + self.n_samples_, self.n_features_ = X.shape y = np.atleast_1d(y) if y.ndim == 2 and y.shape[1] == 1: @@ -549,10 +549,10 @@ def predict(self, X): def predict_proba(self, X): """Predict class probabilities for X. - The predicted class probabilities of an input sample is computed as - the mean predicted class probabilities of the trees in the forest. The - class probability of a single tree is the fraction of samples of the same - class in a leaf. + The predicted class probabilities of an input sample is computed as the + mean predicted class probabilities of the trees in the forest. The + class probability of a single tree is the fraction of samples of the + same class in a leaf. Parameters ---------- @@ -680,6 +680,8 @@ def predict(self, X, return_std=False): return_std : boolean, optional, default=False When True, the sampling standard deviation of the predictions of the ensemble is returned in addition to the predicted values. + Standard deviations are computed using bias-corrected + Jacknife-after-bootstrap estimates, as decribed in arXiv:1311.4555. Returns ------- @@ -692,6 +694,11 @@ def predict(self, X, return_std=False): # Checks X = self._validate_X_predict(X) + if return_std and not self.bootstrap: + raise ValueError("The sampling standard deviation of the " + "predicted values can be computed only when " + "bootstrap=True.") + # Assign chunk of trees to jobs n_jobs, _, _ = _partition_estimators(self.n_estimators, self.n_jobs) @@ -701,12 +708,36 @@ def predict(self, X, return_std=False): delayed(_parallel_helper)(e, 'predict', X, check_input=False) for e in self.estimators_) + all_y_hat = np.array(all_y_hat) y_mean = np.mean(all_y_hat, axis=0) - if return_std: - return y_mean, np.std(all_y_hat, axis=0) - else: + + if not return_std: return y_mean + else: + # Jacknife-after-bootstrap estimate of the sampling variance + var_J = np.zeros(len(X)) + N_bi = np.zeros((self.n_estimators, self.n_samples_)) + + for b, estimator in enumerate(self.estimators_): + samples = _generate_sample_indices(estimator.random_state, + self.n_samples_) + N_bi[b, samples] += 1 + + out_of_bag = (N_bi == 0) + + for i in range(self.n_samples_): + if np.any(out_of_bag[:, i]): + delta_i = all_y_hat[out_of_bag[:, i]].mean(axis=0) - y_mean + var_J += delta_i ** 2 + + var_J *= (self.n_samples_ - 1.) / self.n_samples_ + correction = (self.n_samples_ * np.var(all_y_hat, axis=0) / + self.n_estimators) + var_J[correction < var_J] -= correction[correction < var_J] + + return y_mean, var_J ** 0.5 + def _set_oob_score(self, X, y): """Compute out-of-bag scores""" X = check_array(X, dtype=DTYPE, accept_sparse='csr')