From 9e6386e296ec0f26a7ac82c00b0afab31b55df50 Mon Sep 17 00:00:00 2001
From: Jan Hendrik Metzen
Date: Sun, 7 Sep 2014 10:20:58 +0200
Subject: [PATCH 1/9] ENH Predictive standard deviation optionally returned in
 Random Forest and Bagging regressors

---
 sklearn/ensemble/bagging.py | 26 ++++++++++++++++++--------
 sklearn/ensemble/forest.py  | 21 +++++++++++++++------
 2 files changed, 33 insertions(+), 14 deletions(-)

diff --git a/sklearn/ensemble/bagging.py b/sklearn/ensemble/bagging.py
index 3bded083991fb..6e899a03e02a5 100644
--- a/sklearn/ensemble/bagging.py
+++ b/sklearn/ensemble/bagging.py
@@ -180,9 +180,8 @@ def _parallel_decision_function(estimators, estimators_features, X):
 
 def _parallel_predict_regression(estimators, estimators_features, X):
     """Private function used to compute predictions within a job."""
-    return sum(estimator.predict(X[:, features])
-               for estimator, features in zip(estimators,
-                                              estimators_features))
+    return [estimator.predict(X[:, features])
+            for estimator, features in zip(estimators, estimators_features)]
 
 
 class BaseBagging(with_metaclass(ABCMeta, BaseEnsemble)):
@@ -791,7 +790,7 @@ def __init__(self,
             random_state=random_state,
             verbose=verbose)
 
-    def predict(self, X):
+    def predict(self, X, with_std=False):
         """Predict regression target for X.
 
         The predicted regression target of an input sample is computed as the
@@ -803,10 +802,19 @@ def predict(self, X):
             The training input samples. Sparse matrices are accepted only if
             they are supported by the base estimator.
 
+        with_std : boolean, optional, default=False
+            A boolean specifying whether the standard deviation of the
+            predictions is evaluated or not.
+            Default assumes with_std = False and evaluates only the mean
+            prediction.
+
         Returns
         -------
         y : array of shape = [n_samples]
             The predicted values.
+
+        y_std : array of shape = [n_samples]
+            The standard deviation of the predicted values.
         """
         # Check data
         X = check_array(X, accept_sparse=['csr', 'csc', 'coo'])
@@ -820,12 +828,14 @@ def predict(self, X):
                 self.estimators_[starts[i]:starts[i + 1]],
                 self.estimators_features_[starts[i]:starts[i + 1]],
                 X)
-            for i in range(n_jobs))
+            for i in range(n_jobs))[0]
 
         # Reduce
-        y_hat = sum(all_y_hat) / self.n_estimators
-
-        return y_hat
+        y_hat = np.mean(all_y_hat, axis=0)
+        if with_std:
+            return y_hat, np.std(all_y_hat, axis=0)
+        else:
+            return y_hat
 
     def _validate_estimator(self):
         """Check the estimator and set the base_estimator_ attribute."""
diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py
index 16dd17d39892d..80dffe273a35a 100644
--- a/sklearn/ensemble/forest.py
+++ b/sklearn/ensemble/forest.py
@@ -38,7 +38,6 @@ class calls the ``fit`` method of each sub-estimator on random samples
 
 from __future__ import division
 
-from itertools import chain
 import numpy as np
 from warnings import warn
 from abc import ABCMeta, abstractmethod
@@ -524,7 +523,7 @@ def __init__(self,
             verbose=verbose,
             warm_start=warm_start)
 
-    def predict(self, X):
+    def predict(self, X, with_std=False):
         """Predict regression target for X.
 
         The predicted regression target of an input sample is computed as the
@@ -535,10 +534,19 @@ def predict(self, X):
         X : array-like of shape = [n_samples, n_features]
             The input samples.
 
+        with_std : boolean, optional, default=False
+            A boolean specifying whether the standard deviation of the
+            predictions is evaluated or not.
+            Default assumes with_std = False and evaluates only the mean
+            prediction.
+
         Returns
         -------
         y: array of shape = [n_samples] or [n_samples, n_outputs]
             The predicted values.
+
+        y_std : array of shape = [n_samples]
+            The standard deviation of the predicted values.
         """
         # Check data
         if getattr(X, "dtype", None) != DTYPE or X.ndim != 2:
@@ -554,10 +562,11 @@ def predict(self, X):
             delayed(_parallel_helper)(e, 'predict', X)
             for e in self.estimators_)
 
-        # Reduce
-        y_hat = sum(all_y_hat) / len(self.estimators_)
-
-        return y_hat
+        y_hat = np.mean(all_y_hat, axis=0)
+        if with_std:
+            return y_hat, np.std(all_y_hat, axis=0)
+        else:
+            return y_hat
 
     def _set_oob_score(self, X, y):
         n_samples = y.shape[0]
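A minimal usage sketch of the API this patch introduces, assuming a scikit-learn checkout with patch 1 applied (the with_std keyword is this pull request's proposal, not part of a released scikit-learn; the toy data below is invented for illustration):

    import numpy as np
    from sklearn.ensemble import RandomForestRegressor

    rng = np.random.RandomState(0)
    X = rng.uniform(0, 10, size=(50, 1))
    y = np.sin(X).ravel() + 0.1 * rng.randn(50)

    regr = RandomForestRegressor(n_estimators=100, random_state=0).fit(X, y)
    y_mean = regr.predict(X)  # default: mean prediction only, as before
    # With the patch, the per-sample spread of the individual trees'
    # predictions can be requested alongside the mean:
    y_mean, y_std = regr.predict(X, with_std=True)
    assert y_std.shape == y_mean.shape and np.all(y_std >= 0)

With with_std=False the return value of predict is unchanged, so existing callers are unaffected.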
From 4893e9daaa699c3134a9a02b18571205d53a8e58 Mon Sep 17 00:00:00 2001
From: Jan Hendrik Metzen
Date: Sun, 7 Sep 2014 10:22:43 +0200
Subject: [PATCH 2/9] REFACTOR Add option with_std to GaussianProcess.predict

This is consistent with the interface of RandomForestRegressor.predict.
The old way of requesting the predictive variance via eval_MSE is
deprecated.
---
 sklearn/gaussian_process/gaussian_process.py | 58 ++++++++++++--------
 1 file changed, 34 insertions(+), 24 deletions(-)

diff --git a/sklearn/gaussian_process/gaussian_process.py b/sklearn/gaussian_process/gaussian_process.py
index 07a334b14388f..e61c7225510d8 100644
--- a/sklearn/gaussian_process/gaussian_process.py
+++ b/sklearn/gaussian_process/gaussian_process.py
@@ -6,6 +6,8 @@
 
 from __future__ import print_function
 
+import warnings
+
 import numpy as np
 from scipy import linalg, optimize
 
@@ -374,7 +376,7 @@ def fit(self, X, y):
 
         return self
 
-    def predict(self, X, eval_MSE=False, batch_size=None):
+    def predict(self, X, eval_MSE=False, batch_size=None, with_std=False):
         """
         This function evaluates the Gaussian Process model at x.
 
@@ -396,6 +398,12 @@ def predict(self, X, eval_MSE=False, batch_size=None):
             Default is None so that all given points are evaluated at the same
             time.
 
+        with_std : boolean, optional, default=False
+            A boolean specifying whether the standard deviation of the
+            predictions is evaluated or not.
+            Default assumes with_std = False and evaluates only the mean
+            prediction.
+
        Returns
        -------
        y : array_like, shape (n_samples, ) or (n_samples, n_targets)
@@ -405,10 +413,20 @@ def predict(self, X, eval_MSE=False, batch_size=None):
             of shape (n_samples, n_targets) with the Best Linear Unbiased
             Prediction at x.
 
-        MSE : array_like, optional (if eval_MSE == True)
+        y_std : array_like, optional (if with_std == True)
             An array with shape (n_eval, ) or (n_eval, n_targets) as with y,
-            with the Mean Squared Error at x.
+            containing the standard deviation of the prediction at x. Only
+            returned when with_std is True. If the deprecated eval_MSE is True,
+            the variance (MSE) of the prediction is returned instead.
         """
+        if eval_MSE:
+            warnings.warn("The eval_MSE parameter is deprecated as of version "
+                          "0.16 and will be removed in 0.18. Use the parameter"
+                          " with_std instead, which returns the standard "
+                          "deviation of the prediction.", DeprecationWarning)
+            assert with_std is False, \
+                "with_std and eval_MSE cannot both be True at the same time."
+            with_std = True
 
         # Check input shapes
         X = check_array(X)
@@ -432,11 +450,6 @@ def predict(self, X, eval_MSE=False, batch_size=None):
         # Normalize input
         X = (X - self.X_mean) / self.X_std
 
-        # Initialize output
-        y = np.zeros(n_eval)
-        if eval_MSE:
-            MSE = np.zeros(n_eval)
-
         # Get pairwise componentwise L1-distances to the input training set
         dx = manhattan_distances(X, Y=self.X, sum_over_features=False)
         # Get regression function and correlation
@@ -453,7 +466,7 @@ def predict(self, X, eval_MSE=False, batch_size=None):
             y = y.ravel()
 
         # Mean Squared Error
-        if eval_MSE:
+        if with_std:
             C = self.C
             if C is None:
                 # Light storage mode (need to recompute C, F, Ft and G)
@@ -489,39 +502,36 @@ def predict(self, X, eval_MSE=False, batch_size=None):
                 if self.y_ndim_ == 1:
                     MSE = MSE.ravel()
 
-                return y, MSE
-
+                if eval_MSE:  # deprecated
+                    return y, MSE
+                else:
+                    return y, np.sqrt(MSE)
             else:
-
                 return y
-
         else:
             # Memory management
-
             if type(batch_size) is not int or batch_size <= 0:
                 raise Exception("batch_size must be a positive integer")
 
-            if eval_MSE:
-
-                y, MSE = np.zeros(n_eval), np.zeros(n_eval)
+            if with_std:
+                y, y_std = np.zeros(n_eval), np.zeros(n_eval)
                 for k in range(max(1, n_eval / batch_size)):
                     batch_from = k * batch_size
                     batch_to = min([(k + 1) * batch_size + 1, n_eval + 1])
-                    y[batch_from:batch_to], MSE[batch_from:batch_to] = \
+                    y[batch_from:batch_to], y_std[batch_from:batch_to] = \
                         self.predict(X[batch_from:batch_to],
-                                     eval_MSE=eval_MSE, batch_size=None)
-
-                return y, MSE
-
+                                     with_std=with_std, batch_size=None)
+                if eval_MSE:  # Deprecated
+                    return y, y_std ** 2
+                else:
+                    return y, y_std
             else:
                 y = np.zeros(n_eval)
                 for k in range(max(1, n_eval / batch_size)):
                     batch_from = k * batch_size
                     batch_to = min([(k + 1) * batch_size + 1, n_eval + 1])
                     y[batch_from:batch_to] = \
-                        self.predict(X[batch_from:batch_to],
-                                     eval_MSE=eval_MSE, batch_size=None)
+                        self.predict(X[batch_from:batch_to], batch_size=None)
 
             return y
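A sketch of how callers migrate from the deprecated eval_MSE to with_std, assuming patch 2 is applied (the training data is invented; the final check holds because with_std returns exactly the square root of the variance that eval_MSE used to return):

    import numpy as np
    from sklearn.gaussian_process import GaussianProcess

    X = np.atleast_2d([1., 3., 5., 6., 7., 8.]).T
    y = (X * np.sin(X)).ravel()
    gp = GaussianProcess(theta0=1e-2, thetaL=1e-4, thetaU=1e-1).fit(X, y)

    y_pred, sigma = gp.predict(X, with_std=True)  # new interface
    y_pred, mse = gp.predict(X, eval_MSE=True)    # old interface, now warns
    assert np.allclose(sigma ** 2, mse)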
From e99d16e2f72887f2439173538ebbf1c3ff46ed59 Mon Sep 17 00:00:00 2001
From: Jan Hendrik Metzen
Date: Sun, 7 Sep 2014 10:23:14 +0200
Subject: [PATCH 3/9] REFACTOR Tests and examples of GaussianProcess use
 with_std instead of eval_MSE

---
 ...ilistic_classification_after_regression.py |  3 +--
 .../gaussian_process/plot_gp_regression.py    | 18 ++++++++---------
 .../tests/test_gaussian_process.py            | 20 +++++++++----------
 3 files changed, 19 insertions(+), 22 deletions(-)

diff --git a/examples/gaussian_process/plot_gp_probabilistic_classification_after_regression.py b/examples/gaussian_process/plot_gp_probabilistic_classification_after_regression.py
index 3c9887aa66852..6915000dcea15 100644
--- a/examples/gaussian_process/plot_gp_probabilistic_classification_after_regression.py
+++ b/examples/gaussian_process/plot_gp_probabilistic_classification_after_regression.py
@@ -65,8 +65,7 @@ def g(x):
 xx = np.vstack([x1.reshape(x1.size),
                 x2.reshape(x2.size)]).T
 y_true = g(xx)
-y_pred, MSE = gp.predict(xx, eval_MSE=True)
-sigma = np.sqrt(MSE)
+y_pred, sigma = gp.predict(xx, with_std=True)
 y_true = y_true.reshape((res, res))
 y_pred = y_pred.reshape((res, res))
 sigma = sigma.reshape((res, res))
diff --git a/examples/gaussian_process/plot_gp_regression.py b/examples/gaussian_process/plot_gp_regression.py
index 33b78750d1fe4..ef822eeb142f4 100644
--- a/examples/gaussian_process/plot_gp_regression.py
+++ b/examples/gaussian_process/plot_gp_regression.py
@@ -52,7 +52,7 @@ def f(x):
 y = f(X).ravel()
 
 # Mesh the input space for evaluations of the real function, the prediction and
-# its MSE
+# its standard deviation
 x = np.atleast_2d(np.linspace(0, 10, 1000)).T
 
 # Instanciate a Gaussian Process model
@@ -62,12 +62,11 @@ def f(x):
 # Fit to data using Maximum Likelihood Estimation of the parameters
 gp.fit(X, y)
 
-# Make the prediction on the meshed x-axis (ask for MSE as well)
-y_pred, MSE = gp.predict(x, eval_MSE=True)
-sigma = np.sqrt(MSE)
+# Make the prediction on the meshed x-axis (ask for standard deviation as well)
+y_pred, sigma = gp.predict(x, with_std=True)
 
 # Plot the function, the prediction and the 95% confidence interval based on
-# the MSE
+# the standard deviation
 fig = pl.figure()
 pl.plot(x, f(x), 'r:', label=u'$f(x) = x\,\sin(x)$')
 pl.plot(X, y, 'r.', markersize=10, label=u'Observations')
@@ -93,7 +92,7 @@ def f(x):
 y += noise
 
 # Mesh the input space for evaluations of the real function, the prediction and
-# its MSE
+# its standard deviation
 x = np.atleast_2d(np.linspace(0, 10, 1000)).T
 
 # Instanciate a Gaussian Process model
@@ -105,12 +104,11 @@ def f(x):
 # Fit to data using Maximum Likelihood Estimation of the parameters
 gp.fit(X, y)
 
-# Make the prediction on the meshed x-axis (ask for MSE as well)
-y_pred, MSE = gp.predict(x, eval_MSE=True)
-sigma = np.sqrt(MSE)
+# Make the prediction on the meshed x-axis (ask for standard deviation as well)
+y_pred, sigma = gp.predict(x, with_std=True)
 
 # Plot the function, the prediction and the 95% confidence interval based on
-# the MSE
+# the standard deviation
 fig = pl.figure()
 pl.plot(x, f(x), 'r:', label=u'$f(x) = x\,\sin(x)$')
 pl.errorbar(X.ravel(), y, dy, fmt='r.', markersize=10, label=u'Observations')
diff --git a/sklearn/gaussian_process/tests/test_gaussian_process.py b/sklearn/gaussian_process/tests/test_gaussian_process.py
index 517fcb047aadf..3b6babbed6139 100644
--- a/sklearn/gaussian_process/tests/test_gaussian_process.py
+++ b/sklearn/gaussian_process/tests/test_gaussian_process.py
@@ -33,11 +33,11 @@ def test_1d(regr=regression.constant, corr=correlation.squared_exponential,
     gp = GaussianProcess(regr=regr, corr=corr, beta0=beta0,
                          theta0=1e-2, thetaL=1e-4, thetaU=1e-1,
                          random_start=random_start, verbose=False).fit(X, y)
-    y_pred, MSE = gp.predict(X, eval_MSE=True)
-    y2_pred, MSE2 = gp.predict(X2, eval_MSE=True)
+    y_pred, y_std = gp.predict(X, with_std=True)
+    _, y_std2 = gp.predict(X2, with_std=True)
 
-    assert_true(np.allclose(y_pred, y) and np.allclose(MSE, 0.)
-                and np.allclose(MSE2, 0., atol=10))
+    assert_true(np.allclose(y_pred, y) and np.allclose(y_std ** 2, 0.)
+                and np.allclose(y_std2 ** 2, 0., atol=10))
 
 
 def test_2d(regr=regression.constant, corr=correlation.squared_exponential,
@@ -67,12 +67,12 @@ def test_2d(regr=regression.constant, corr=correlation.squared_exponential,
                          thetaU=thetaU,
                          random_start=random_start, verbose=False)
     gp.fit(X, y)
-    y_pred, MSE = gp.predict(X, eval_MSE=True)
+    y_pred, y_std = gp.predict(X, with_std=True)
 
-    assert_true(np.allclose(y_pred, y) and np.allclose(MSE, 0.))
+    assert_true(np.allclose(y_pred, y) and np.allclose(y_std ** 2, 0.))
 
-    assert_true(np.all(gp.theta_ >= thetaL)) # Lower bounds of hyperparameters
-    assert_true(np.all(gp.theta_ <= thetaU)) # Upper bounds of hyperparameters
+    assert_true(np.all(gp.theta_ >= thetaL))  # Lower bounds of hyperparameters
+    assert_true(np.all(gp.theta_ <= thetaU))  # Upper bounds of hyperparameters
 
 
 def test_2d_2d(regr=regression.constant, corr=correlation.squared_exponential,
@@ -100,9 +100,9 @@ def test_2d_2d(regr=regression.constant, corr=correlation.squared_exponential,
                          thetaU=[1e-1] * 2,
                          random_start=random_start, verbose=False)
     gp.fit(X, y)
-    y_pred, MSE = gp.predict(X, eval_MSE=True)
+    y_pred, y_std = gp.predict(X, with_std=True)
 
-    assert_true(np.allclose(y_pred, y) and np.allclose(MSE, 0.))
+    assert_true(np.allclose(y_pred, y) and np.allclose(y_std ** 2, 0.))
 
 
 @raises(ValueError)
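A hypothetical extra regression test in the same style as the refactored tests above, checking that the two code paths stay consistent while eval_MSE remains supported (the test name and data are invented for illustration):

    import numpy as np
    from sklearn.gaussian_process import GaussianProcess

    def test_with_std_consistent_with_eval_mse():
        # Both interfaces must describe the same predictive distribution.
        X = np.atleast_2d([1., 3., 5., 6., 7., 8.]).T
        y = (X * np.sin(X)).ravel()
        gp = GaussianProcess(theta0=1e-2, thetaL=1e-4,
                             thetaU=1e-1).fit(X, y)
        _, y_std = gp.predict(X, with_std=True)
        _, mse = gp.predict(X, eval_MSE=True)  # deprecated path
        assert np.allclose(y_std ** 2, mse)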
From c3ea011b1f5ef1ac11b8cfb073062c12c4ebe4e1 Mon Sep 17 00:00:00 2001
From: Jan Hendrik Metzen
Date: Sun, 7 Sep 2014 10:24:21 +0200
Subject: [PATCH 4/9] ADD Example comparing the predictive distributions of
 different regressors

---
 .../plot_predictive_standard_deviation.py | 86 +++++++++++++++++++
 1 file changed, 86 insertions(+)
 create mode 100644 examples/plot_predictive_standard_deviation.py

diff --git a/examples/plot_predictive_standard_deviation.py b/examples/plot_predictive_standard_deviation.py
new file mode 100644
index 0000000000000..9234d75fc6315
--- /dev/null
+++ b/examples/plot_predictive_standard_deviation.py
@@ -0,0 +1,86 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+r"""
+==============================================================
+Comparison of predictive distributions of different regressors
+==============================================================
+
+A simple one-dimensional regression problem addressed by three different
+regressors:
+
+1. A Gaussian Process
+2. A Random Forest
+3. A Bagging-based Regressor
+
+The regressors are fitted based on noisy observations where the magnitude of
+the noise at the different training points varies and is known. Plotted are
+both the mean and the pointwise 95% confidence interval of the predictions.
+
+This example is based on the example gaussian_process/plot_gp_regression.py.
+
+"""
+print(__doc__)
+
+# Author: Jan Hendrik Metzen
+# Licence: BSD 3 clause
+
+import numpy as np
+from sklearn.gaussian_process import GaussianProcess
+from sklearn.ensemble import RandomForestRegressor, BaggingRegressor
+from matplotlib import pyplot as pl
+
+np.random.seed(1)
+
+
+def f(x):
+    """The function to predict."""
+    return x * np.sin(x)
+
+X = np.linspace(0.1, 9.9, 20)
+X = np.atleast_2d(X).T
+
+# Observations and noise
+y = f(X).ravel()
+dy = 0.5 + 1.0 * np.random.random(y.shape)
+noise = np.random.normal(0, dy)
+y += noise
+
+# Mesh the input space for evaluations of the real function, the prediction and
+# its standard deviation
+x = np.atleast_2d(np.linspace(0, 10, 1000)).T
+
+regrs = {"gaussian_process": GaussianProcess(corr='squared_exponential',
+                                             theta0=1e-1, thetaL=1e-3,
+                                             thetaU=1, nugget=(dy / y) ** 2,
+                                             random_start=100),
+         "random_forest": RandomForestRegressor(n_estimators=250),
+         "bagging": BaggingRegressor(n_estimators=250)}
+
+
+# Plot predictive distributions of different regressors
+fig = pl.figure()
+colors = {"gaussian_process": 'b', "random_forest": 'g', "bagging": 'c'}
+for name, regr in regrs.items():
+    regr.fit(X, y)
+
+    # Make the prediction on the meshed x-axis (ask for standard deviation
+    # as well)
+    y_pred, sigma = regr.predict(x, with_std=True)
+
+    # Plot 95% confidence interval based on the predictive standard deviation
+    pl.plot(x, y_pred, colors[name], label=name)
+    pl.fill(np.concatenate([x, x[::-1]]),
+            np.concatenate([y_pred - 1.9600 * sigma,
+                           (y_pred + 1.9600 * sigma)[::-1]]),
+            alpha=.3, fc=colors[name], ec='None')
+
+# Plot the function and the observations
+pl.plot(x, f(x), 'r:', label=u'$f(x) = x\,\sin(x)$')
+pl.errorbar(X.ravel(), y, dy, fmt='r.', markersize=10, label=u'Observations')
+pl.xlabel('$x$')
+pl.ylabel('$f(x)$')
+pl.ylim(-10, 20)
+pl.legend(loc='upper left')
+
+pl.show()
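For reference, the 95% band plotted by this example is a pointwise Gaussian interval; the factor 1.9600 is the two-sided 97.5% quantile of the standard normal. A self-contained numeric sketch (toy values, not taken from the example):

    import numpy as np
    from scipy.stats import norm

    y_pred = np.array([1.0, 2.0, 3.0])
    sigma = np.array([0.5, 0.2, 0.1])

    lower = y_pred - 1.9600 * sigma
    upper = y_pred + 1.9600 * sigma
    # 1.9600 is approximately norm.ppf(0.975), so [lower, upper] covers
    # ~95% of a N(y_pred, sigma**2) predictive distribution pointwise.
    assert np.isclose(norm.ppf(0.975), 1.9600, atol=1e-4)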
From bb625d5534659050d28f9127f8d957f4eb5fa278 Mon Sep 17 00:00:00 2001
From: Jan Hendrik Metzen
Date: Sun, 7 Sep 2014 17:26:04 +0200
Subject: [PATCH 5/9] DOC Improved documentation of with_std parameter of
 predict() method

---
 sklearn/ensemble/bagging.py                  | 17 ++++++++---------
 sklearn/ensemble/forest.py                   | 17 ++++++++---------
 sklearn/gaussian_process/gaussian_process.py |  6 ++----
 3 files changed, 18 insertions(+), 22 deletions(-)

diff --git a/sklearn/ensemble/bagging.py b/sklearn/ensemble/bagging.py
index 6e899a03e02a5..b23ff26d62504 100644
--- a/sklearn/ensemble/bagging.py
+++ b/sklearn/ensemble/bagging.py
@@ -795,6 +795,7 @@ def predict(self, X, with_std=False):
 
         The predicted regression target of an input sample is computed as the
         mean predicted regression targets of the estimators in the ensemble.
+        Optionally, the standard deviation is computed in addition.
 
         Parameters
         ----------
@@ -803,15 +804,13 @@ def predict(self, X, with_std=False):
             they are supported by the base estimator.
 
         with_std : boolean, optional, default=False
-            A boolean specifying whether the standard deviation of the
-            predictions is evaluated or not.
-            Default assumes with_std = False and evaluates only the mean
-            prediction.
+            When True, the standard deviation of predictions across the
+            ensemble is returned in addition to the mean.
 
         Returns
         -------
-        y : array of shape = [n_samples]
-            The predicted values.
+        y_mean : array of shape = [n_samples]
+            The mean of the predicted values.
 
         y_std : array of shape = [n_samples]
             The standard deviation of the predicted values.
@@ -831,11 +830,11 @@ def predict(self, X, with_std=False):
             for i in range(n_jobs))[0]
 
         # Reduce
-        y_hat = np.mean(all_y_hat, axis=0)
+        y_mean = np.mean(all_y_hat, axis=0)
         if with_std:
-            return y_hat, np.std(all_y_hat, axis=0)
+            return y_mean, np.std(all_y_hat, axis=0)
         else:
-            return y_hat
+            return y_mean
 
     def _validate_estimator(self):
         """Check the estimator and set the base_estimator_ attribute."""
diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py
index 80dffe273a35a..fb60c0b915468 100644
--- a/sklearn/ensemble/forest.py
+++ b/sklearn/ensemble/forest.py
@@ -528,6 +528,7 @@ def predict(self, X, with_std=False):
 
         The predicted regression target of an input sample is computed as the
         mean predicted regression targets of the trees in the forest.
+        Optionally, the standard deviation is computed in addition.
 
         Parameters
         ----------
@@ -535,15 +536,13 @@ def predict(self, X, with_std=False):
             The input samples.
 
         with_std : boolean, optional, default=False
-            A boolean specifying whether the standard deviation of the
-            predictions is evaluated or not.
-            Default assumes with_std = False and evaluates only the mean
-            prediction.
+            When True, the standard deviation of predictions across the
+            ensemble is returned in addition to the mean.
 
         Returns
         -------
-        y: array of shape = [n_samples] or [n_samples, n_outputs]
-            The predicted values.
+        y_mean: array of shape = [n_samples] or [n_samples, n_outputs]
+            The mean of the predicted values.
 
         y_std : array of shape = [n_samples]
             The standard deviation of the predicted values.
@@ -562,11 +561,11 @@ def predict(self, X, with_std=False):
             delayed(_parallel_helper)(e, 'predict', X)
             for e in self.estimators_)
 
-        y_hat = np.mean(all_y_hat, axis=0)
+        y_mean = np.mean(all_y_hat, axis=0)
         if with_std:
-            return y_hat, np.std(all_y_hat, axis=0)
+            return y_mean, np.std(all_y_hat, axis=0)
         else:
-            return y_hat
+            return y_mean
 
     def _set_oob_score(self, X, y):
         n_samples = y.shape[0]
diff --git a/sklearn/gaussian_process/gaussian_process.py b/sklearn/gaussian_process/gaussian_process.py
index e61c7225510d8..6461d6f040058 100644
--- a/sklearn/gaussian_process/gaussian_process.py
+++ b/sklearn/gaussian_process/gaussian_process.py
@@ -399,10 +399,8 @@ def predict(self, X, eval_MSE=False, batch_size=None, with_std=False):
             time.
 
         with_std : boolean, optional, default=False
-            A boolean specifying whether the standard deviation of the
-            predictions is evaluated or not.
-            Default assumes with_std = False and evaluates only the mean
-            prediction.
+            When True, the standard deviation of the predictions at the
+            query points is returned in addition to the mean.
 
         Returns
         -------
From 8c23d85daeb20d12de33cecc93b09c7047bd3d15 Mon Sep 17 00:00:00 2001
From: Jan Hendrik Metzen
Date: Sun, 7 Sep 2014 17:37:38 +0200
Subject: [PATCH 6/9] FIX Bug in BaggingRegressor using
 _parallel_predict_regression

---
 sklearn/ensemble/bagging.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/sklearn/ensemble/bagging.py b/sklearn/ensemble/bagging.py
index b23ff26d62504..144cbafd0148d 100644
--- a/sklearn/ensemble/bagging.py
+++ b/sklearn/ensemble/bagging.py
@@ -827,9 +827,10 @@ def predict(self, X, with_std=False):
                 self.estimators_[starts[i]:starts[i + 1]],
                 self.estimators_features_[starts[i]:starts[i + 1]],
                 X)
-            for i in range(n_jobs))[0]
+            for i in range(n_jobs))
 
         # Reduce
+        all_y_hat = np.array(all_y_hat).reshape(self.n_estimators, -1)
         y_mean = np.mean(all_y_hat, axis=0)
         if with_std:
             return y_mean, np.std(all_y_hat, axis=0)

From 1780535d6b3619047b76ae2a29fd9b29e4cc175a Mon Sep 17 00:00:00 2001
From: Jan Hendrik Metzen
Date: Mon, 8 Sep 2014 17:08:28 +0200
Subject: [PATCH 7/9] DOC More consistent documentation of optional
 return-value y_std of predict

---
 sklearn/ensemble/bagging.py | 2 +-
 sklearn/ensemble/forest.py  | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/sklearn/ensemble/bagging.py b/sklearn/ensemble/bagging.py
index 144cbafd0148d..98ec32a87f56c 100644
--- a/sklearn/ensemble/bagging.py
+++ b/sklearn/ensemble/bagging.py
@@ -812,7 +812,7 @@ def predict(self, X, with_std=False):
         y_mean : array of shape = [n_samples]
             The mean of the predicted values.
 
-        y_std : array of shape = [n_samples]
+        y_std : array of shape = [n_samples], optional (if with_std == True)
             The standard deviation of the predicted values.
         """
         # Check data
diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py
index fb60c0b915468..ef9ed7ffce095 100644
--- a/sklearn/ensemble/forest.py
+++ b/sklearn/ensemble/forest.py
@@ -544,7 +544,7 @@ def predict(self, X, with_std=False):
         y_mean: array of shape = [n_samples] or [n_samples, n_outputs]
             The mean of the predicted values.
 
-        y_std : array of shape = [n_samples]
+        y_std : array of shape = [n_samples], optional (if with_std == True)
             The standard deviation of the predicted values.
         """
         # Check data
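Context for the fix in patch 6: since patch 1, _parallel_predict_regression returns a list of per-estimator prediction arrays for each job, so Parallel yields a list of such lists; the [0] introduced in patch 1 silently discarded every job except the first. A numpy-only sketch of the corrected reduction, using toy shapes instead of real estimators:

    import numpy as np

    n_samples = 4
    # e.g. five estimators split over two jobs (three in job 0, two in job 1)
    job0 = [np.full(n_samples, k) for k in range(3)]
    job1 = [np.full(n_samples, k) for k in range(3, 5)]

    # Flatten the per-job lists into one (n_estimators, n_samples) array;
    # the buggy all_y_hat[0] kept only job 0's three estimators.
    stacked = np.array(job0 + job1).reshape(5, -1)
    y_mean = np.mean(stacked, axis=0)  # per-sample mean over all estimators
    y_std = np.std(stacked, axis=0)    # per-sample spread over all estimators
    assert y_mean.shape == (n_samples,) and np.allclose(y_mean, 2.0)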
 
         Returns
         -------
@@ -813,7 +814,7 @@ def predict(self, X, with_std=False):
             The mean of the predicted values.
 
         y_std : array of shape = [n_samples], optional (if with_std == True)
-            The standard deviation of the predicted values.
+            The standard deviation of the ensemble's predicted values.
         """
         # Check data
         X = check_array(X, accept_sparse=['csr', 'csc', 'coo'])
diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py
index ef9ed7ffce095..e44081682b6c0 100644
--- a/sklearn/ensemble/forest.py
+++ b/sklearn/ensemble/forest.py
@@ -528,7 +528,8 @@ def predict(self, X, with_std=False):
 
         The predicted regression target of an input sample is computed as the
         mean predicted regression targets of the trees in the forest.
-        Optionally, the standard deviation is computed in addition.
+        Optionally, the standard deviation of the predictions of the ensemble's
+        estimators is computed in addition.
 
         Parameters
         ----------
@@ -536,8 +537,8 @@ def predict(self, X, with_std=False):
             The input samples.
 
         with_std : boolean, optional, default=False
-            When True, the standard deviation of predictions across the
-            ensemble is returned in addition to the mean.
+            When True, the standard deviation of the predictions of the
+            ensemble's estimators is returned in addition to the mean.
 
         Returns

From 6043345202988b91d016dc925c785a98ee07e2d8 Mon Sep 17 00:00:00 2001
From: Jan Hendrik Metzen
Date: Sun, 14 Sep 2014 19:56:21 +0200
Subject: [PATCH 9/9] ENH Extending example plot_predictive_standard_deviation.py

---
 .../plot_predictive_standard_deviation.py | 59 +++++++++++++++----
 1 file changed, 48 insertions(+), 11 deletions(-)

diff --git a/examples/plot_predictive_standard_deviation.py b/examples/plot_predictive_standard_deviation.py
index 9234d75fc6315..5319c4a98e5c6 100644
--- a/examples/plot_predictive_standard_deviation.py
+++ b/examples/plot_predictive_standard_deviation.py
@@ -6,7 +6,7 @@
 Comparison of predictive distributions of different regressors
 ==============================================================
 
-A simple one-dimensional regression problem addressed by three different
+A simple one-dimensional, noisy regression problem addressed by three different
 regressors:
 
 1. A Gaussian Process
@@ -14,11 +14,14 @@
 3. A Bagging-based Regressor
 
 The regressors are fitted based on noisy observations where the magnitude of
-the noise at the different training points varies and is known. Plotted are
+the noise at the different training points is constant and known. Plotted are
 both the mean and the pointwise 95% confidence interval of the predictions.
+The mean predictions are evaluated on noise-less test data using the mean-
+squared-error. The mean log probabilities of the noise-less test data are used
+to evaluate the predictive distributions (a normal distribution with the
+predicted mean and standard deviation) of the three regressors.
 
 This example is based on the example gaussian_process/plot_gp_regression.py.
-
 """
 print(__doc__)
 
@@ -26,8 +29,10 @@
 # Licence: BSD 3 clause
 
 import numpy as np
+from scipy.stats import norm
 from sklearn.gaussian_process import GaussianProcess
 from sklearn.ensemble import RandomForestRegressor, BaggingRegressor
+from sklearn.metrics import mean_squared_error
 from matplotlib import pyplot as pl
 
 np.random.seed(1)
@@ -42,7 +47,7 @@ def f(x):
 
 # Observations and noise
 y = f(X).ravel()
-dy = 0.5 + 1.0 * np.random.random(y.shape)
+dy = np.ones_like(y)
 noise = np.random.normal(0, dy)
 y += noise
 
@@ -50,17 +55,26 @@ def f(x):
 # its standard deviation
 x = np.atleast_2d(np.linspace(0, 10, 1000)).T
 
-regrs = {"gaussian_process": GaussianProcess(corr='squared_exponential',
+regrs = {"Gaussian Process": GaussianProcess(corr='squared_exponential',
                                              theta0=1e-1, thetaL=1e-3,
                                              thetaU=1, nugget=(dy / y) ** 2,
                                              random_start=100),
-         "random_forest": RandomForestRegressor(n_estimators=250),
-         "bagging": BaggingRegressor(n_estimators=250)}
+         "Random Forest": RandomForestRegressor(n_estimators=250),
+         "Bagging": BaggingRegressor(n_estimators=250)}
 
 
 # Plot predictive distributions of different regressors
 fig = pl.figure()
-colors = {"gaussian_process": 'b', "random_forest": 'g', "bagging": 'c'}
+# Plot the function and the observations
+pl.plot(x, f(x), 'r', label=u'$f(x) = x\,\sin(x)$')
+pl.fill(np.concatenate([x, x[::-1]]),
+        np.concatenate([f(x) - 1.9600, (f(x) + 1.9600)[::-1]]),
+        alpha=.3, fc='r', ec='None')
+pl.plot(X.ravel(), y, 'ko', zorder=5, label=u'Observations')
+# Plot predictive distributions of GP and Bagging
+colors = {"Gaussian Process": 'b', "Bagging": 'g'}
+mse = {}
+log_pdf_loss = {}
 for name, regr in regrs.items():
     regr.fit(X, y)
 
@@ -68,6 +82,14 @@ def f(x):
     # as well)
     y_pred, sigma = regr.predict(x, with_std=True)
 
+    # Compute mean-squared error and log predictive loss
+    mse[name] = mean_squared_error(f(x).ravel(), y_pred)
+    log_pdf_loss[name] = \
+        norm(y_pred, sigma).logpdf(f(x).ravel()).mean()
+
+    if name == "Random Forest":  # Skip because RF is very similar to Bagging
+        continue
+
     # Plot 95% confidence interval based on the predictive standard deviation
     pl.plot(x, y_pred, colors[name], label=name)
     pl.fill(np.concatenate([x, x[::-1]]),
@@ -75,12 +97,27 @@ def f(x):
                            (y_pred + 1.9600 * sigma)[::-1]]),
             alpha=.3, fc=colors[name], ec='None')
 
-# Plot the function and the observations
-pl.plot(x, f(x), 'r:', label=u'$f(x) = x\,\sin(x)$')
-pl.errorbar(X.ravel(), y, dy, fmt='r.', markersize=10, label=u'Observations')
+
 pl.xlabel('$x$')
 pl.ylabel('$f(x)$')
 pl.ylim(-10, 20)
 pl.legend(loc='upper left')
 
+print("Mean-squared error of predictors on 1000 equidistant noise-less test "
+      "datapoints:\n\tRandom Forest: %.2f\n\tBagging: %.2f"
+      "\n\tGaussian Process: %.2f"
+      % (mse["Random Forest"], mse["Bagging"], mse["Gaussian Process"]))
+
+print("Mean log-probability of 1000 equidistant noise-less test datapoints "
+      "under the (normal) predictive distribution of the predictors, i.e., "
+      "log N(y_true| y_pred_mean, y_pred_std) [higher is better]:"
+      "\n\tRandom Forest: %.2f\n\tBagging: %.2f\n\tGaussian Process: %.2f"
+      % (log_pdf_loss["Random Forest"], log_pdf_loss["Bagging"],
+         log_pdf_loss["Gaussian Process"]))
+
+print("In summary, the mean predictions of the Gaussian Process are slightly "
+      "better than those of Random Forest and Bagging. The predictive "
+      "distributions (taking into account also the predictive variance) "
+      "of the Gaussian Process are considerably better.")
+
 pl.show()
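The evaluation added in this final patch scores each regressor's Gaussian predictive distribution N(y_pred, sigma**2) on noise-free targets through the mean log-density, next to the MSE of the mean prediction. A self-contained sketch with invented numbers:

    import numpy as np
    from scipy.stats import norm
    from sklearn.metrics import mean_squared_error

    y_true = np.array([0.0, 1.0, 2.0])
    y_pred = np.array([0.1, 0.9, 2.2])
    sigma = np.array([0.5, 0.4, 0.6])

    mse = mean_squared_error(y_true, y_pred)
    # Mean log-probability of the targets under the predictive
    # distribution; higher (less negative) is better.
    mean_log_prob = norm(y_pred, sigma).logpdf(y_true).mean()
    print("MSE: %.3f, mean log-probability: %.3f" % (mse, mean_log_prob))

Unlike the MSE, this score rewards a well-calibrated sigma: a regressor that is overconfident (sigma too small) or underconfident (sigma too large) is penalized even when its mean prediction is accurate.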