diff --git a/examples/gaussian_process/plot_gp_probabilistic_classification_after_regression.py b/examples/gaussian_process/plot_gp_probabilistic_classification_after_regression.py
index 3c9887aa66852..6915000dcea15 100644
--- a/examples/gaussian_process/plot_gp_probabilistic_classification_after_regression.py
+++ b/examples/gaussian_process/plot_gp_probabilistic_classification_after_regression.py
@@ -65,8 +65,7 @@ def g(x):
 xx = np.vstack([x1.reshape(x1.size), x2.reshape(x2.size)]).T
 y_true = g(xx)
-y_pred, MSE = gp.predict(xx, eval_MSE=True)
-sigma = np.sqrt(MSE)
+y_pred, sigma = gp.predict(xx, with_std=True)
 y_true = y_true.reshape((res, res))
 y_pred = y_pred.reshape((res, res))
 sigma = sigma.reshape((res, res))
diff --git a/examples/gaussian_process/plot_gp_regression.py b/examples/gaussian_process/plot_gp_regression.py
index 33b78750d1fe4..ef822eeb142f4 100644
--- a/examples/gaussian_process/plot_gp_regression.py
+++ b/examples/gaussian_process/plot_gp_regression.py
@@ -52,7 +52,7 @@ def f(x):
 y = f(X).ravel()
 
 # Mesh the input space for evaluations of the real function, the prediction and
-# its MSE
+# its standard deviation
 x = np.atleast_2d(np.linspace(0, 10, 1000)).T
 
 # Instanciate a Gaussian Process model
@@ -62,12 +62,11 @@ def f(x):
 # Fit to data using Maximum Likelihood Estimation of the parameters
 gp.fit(X, y)
 
-# Make the prediction on the meshed x-axis (ask for MSE as well)
-y_pred, MSE = gp.predict(x, eval_MSE=True)
-sigma = np.sqrt(MSE)
+# Make the prediction on the meshed x-axis (ask for standard deviation as well)
+y_pred, sigma = gp.predict(x, with_std=True)
 
 # Plot the function, the prediction and the 95% confidence interval based on
-# the MSE
+# the standard deviation
 fig = pl.figure()
 pl.plot(x, f(x), 'r:', label=u'$f(x) = x\,\sin(x)$')
 pl.plot(X, y, 'r.', markersize=10, label=u'Observations')
@@ -93,7 +92,7 @@ def f(x):
 y += noise
 
 # Mesh the input space for evaluations of the real function, the prediction and
-# its MSE
+# its standard deviation
 x = np.atleast_2d(np.linspace(0, 10, 1000)).T
 
 # Instanciate a Gaussian Process model
@@ -105,12 +104,11 @@ def f(x):
 # Fit to data using Maximum Likelihood Estimation of the parameters
 gp.fit(X, y)
 
-# Make the prediction on the meshed x-axis (ask for MSE as well)
-y_pred, MSE = gp.predict(x, eval_MSE=True)
-sigma = np.sqrt(MSE)
+# Make the prediction on the meshed x-axis (ask for standard deviation as well)
+y_pred, sigma = gp.predict(x, with_std=True)
 
 # Plot the function, the prediction and the 95% confidence interval based on
-# the MSE
+# the standard deviation
 fig = pl.figure()
 pl.plot(x, f(x), 'r:', label=u'$f(x) = x\,\sin(x)$')
 pl.errorbar(X.ravel(), y, dy, fmt='r.', markersize=10, label=u'Observations')
diff --git a/examples/plot_predictive_standard_deviation.py b/examples/plot_predictive_standard_deviation.py
new file mode 100644
index 0000000000000..5319c4a98e5c6
--- /dev/null
+++ b/examples/plot_predictive_standard_deviation.py
@@ -0,0 +1,123 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+r"""
+==============================================================
+Comparison of predictive distributions of different regressors
+==============================================================
+
+A simple one-dimensional, noisy regression problem addressed by three
+different regressors:
+
+1. A Gaussian Process
+2. A Random Forest
+3. A Bagging-based Regressor
+
+The regressors are fitted based on noisy observations where the magnitude of
+the noise at the different training points is constant and known. Plotted are
+both the mean and the pointwise 95% confidence interval of the predictions.
+The mean predictions are evaluated on noise-less test data using the mean
+squared error. The mean log probabilities of the noise-less test data are used
+to evaluate the predictive distributions (a normal distribution with the
+predicted mean and standard deviation) of the three regressors.
+
+This example is based on the example gaussian_process/plot_gp_regression.py.
+"""
+print(__doc__)
+
+# Author: Jan Hendrik Metzen
+# Licence: BSD 3 clause
+
+import numpy as np
+from scipy.stats import norm
+from sklearn.gaussian_process import GaussianProcess
+from sklearn.ensemble import RandomForestRegressor, BaggingRegressor
+from sklearn.metrics import mean_squared_error
+from matplotlib import pyplot as pl
+
+np.random.seed(1)
+
+
+def f(x):
+    """The function to predict."""
+    return x * np.sin(x)
+
+X = np.linspace(0.1, 9.9, 20)
+X = np.atleast_2d(X).T
+
+# Observations and noise
+y = f(X).ravel()
+dy = np.ones_like(y)
+noise = np.random.normal(0, dy)
+y += noise
+
+# Mesh the input space for evaluations of the real function, the prediction
+# and its standard deviation
+x = np.atleast_2d(np.linspace(0, 10, 1000)).T
+
+regrs = {"Gaussian Process": GaussianProcess(corr='squared_exponential',
+                                             theta0=1e-1, thetaL=1e-3,
+                                             thetaU=1, nugget=(dy / y) ** 2,
+                                             random_start=100),
+         "Random Forest": RandomForestRegressor(n_estimators=250),
+         "Bagging": BaggingRegressor(n_estimators=250)}
+
+
+# Plot predictive distributions of different regressors
+fig = pl.figure()
+# Plot the function and the observations
+pl.plot(x, f(x), 'r', label=u'$f(x) = x\,\sin(x)$')
+pl.fill(np.concatenate([x, x[::-1]]),
+        np.concatenate([f(x) - 1.9600, (f(x) + 1.9600)[::-1]]),
+        alpha=.3, fc='r', ec='None')
+pl.plot(X.ravel(), y, 'ko', zorder=5, label=u'Observations')
+# Plot predictive distributions of GP and Bagging
+colors = {"Gaussian Process": 'b', "Bagging": 'g'}
+mse = {}
+log_pdf = {}
+for name, regr in regrs.items():
+    regr.fit(X, y)
+
+    # Make the prediction on the meshed x-axis (ask for standard deviation
+    # as well)
+    y_pred, sigma = regr.predict(x, with_std=True)
+
+    # Compute mean squared error and mean log predictive density
+    mse[name] = mean_squared_error(f(x), y_pred)
+    log_pdf[name] = norm(y_pred, sigma).logpdf(f(x)).mean()
+
+    if name == "Random Forest":  # Skip because RF is very similar to Bagging
+        continue
+
+    # Plot 95% confidence interval based on the predictive standard deviation
+    pl.plot(x, y_pred, colors[name], label=name)
+    pl.fill(np.concatenate([x, x[::-1]]),
+            np.concatenate([y_pred - 1.9600 * sigma,
+                            (y_pred + 1.9600 * sigma)[::-1]]),
+            alpha=.3, fc=colors[name], ec='None')
+
+
+pl.xlabel('$x$')
+pl.ylabel('$f(x)$')
+pl.ylim(-10, 20)
+pl.legend(loc='upper left')
+
+print("Mean squared error of predictors on 1000 equidistant noise-less test "
+      "datapoints:\n\tRandom Forest: %.2f\n\tBagging: %.2f"
+      "\n\tGaussian Process: %.2f"
+      % (mse["Random Forest"], mse["Bagging"], mse["Gaussian Process"]))
+
+print("Mean log-probability of 1000 equidistant noise-less test datapoints "
+      "under the (normal) predictive distribution of the predictors, i.e., "
+      "log N(y_true | y_pred_mean, y_pred_std) [higher is better]:"
+      "\n\tRandom Forest: %.2f\n\tBagging: %.2f\n\tGaussian Process: %.2f"
+      % (log_pdf["Random Forest"], log_pdf["Bagging"],
+         log_pdf["Gaussian Process"]))
+
+print("In summary, the mean predictions of the Gaussian Process are slightly "
+      "better than those of Random Forest and Bagging. The predictive "
+      "distributions (taking into account also the predictive variance) "
+      "of the Gaussian Process are considerably better.")
+
+pl.show()
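Note: the example scores each predictive distribution by the mean log predictive density of the noise-free targets; since this is a log-probability (not a negated loss), higher values are better, which is what the "[higher is better]" tag in the printed output reflects. A minimal sketch of the metric in isolation, using hypothetical `y_true`, `y_mean`, and `y_std` arrays:

    import numpy as np
    from scipy.stats import norm

    # Hypothetical predictive means/stds and noise-free targets
    y_true = np.array([0.5, 1.2, -0.3])
    y_mean = np.array([0.4, 1.0, 0.1])
    y_std = np.array([0.3, 0.5, 0.4])

    # Mean log predictive density of y_true under N(y_mean, y_std);
    # larger (less negative) values indicate a better predictive distribution.
    mlpd = norm(y_mean, y_std).logpdf(y_true).mean()
    print(mlpd)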
diff --git a/sklearn/ensemble/bagging.py b/sklearn/ensemble/bagging.py
index 3bded083991fb..45a8997c79728 100644
--- a/sklearn/ensemble/bagging.py
+++ b/sklearn/ensemble/bagging.py
@@ -180,9 +180,8 @@ def _parallel_decision_function(estimators, estimators_features, X):
 
 def _parallel_predict_regression(estimators, estimators_features, X):
     """Private function used to compute predictions within a job."""
-    return sum(estimator.predict(X[:, features])
-               for estimator, features in zip(estimators,
-                                              estimators_features))
+    return [estimator.predict(X[:, features])
+            for estimator, features in zip(estimators, estimators_features)]
 
 
 class BaseBagging(with_metaclass(ABCMeta, BaseEnsemble)):
@@ -791,11 +790,13 @@ def __init__(self,
             random_state=random_state,
             verbose=verbose)
 
-    def predict(self, X):
+    def predict(self, X, with_std=False):
         """Predict regression target for X.
 
         The predicted regression target of an input sample is computed as the
         mean predicted regression targets of the estimators in the ensemble.
+        Optionally, the standard deviation of the predictions of the
+        ensemble's estimators is computed in addition.
 
         Parameters
         ----------
@@ -803,10 +804,17 @@ def predict(self, X):
             The training input samples. Sparse matrices are accepted only if
             they are supported by the base estimator.
 
+        with_std : boolean, optional, default=False
+            When True, the standard deviation of the predictions of the
+            ensemble's estimators is returned in addition to the mean.
+
         Returns
        -------
-        y : array of shape = [n_samples]
-            The predicted values.
+        y_mean : array of shape = [n_samples]
+            The mean of the predicted values.
+
+        y_std : array of shape = [n_samples], optional (if with_std == True)
+            The standard deviation of the ensemble's predicted values.
         """
         # Check data
         X = check_array(X, accept_sparse=['csr', 'csc', 'coo'])
@@ -823,9 +831,12 @@ def predict(self, X):
             for i in range(n_jobs))
 
         # Reduce
-        y_hat = sum(all_y_hat) / self.n_estimators
-
-        return y_hat
+        # Each job returns a list of per-estimator predictions; flatten the
+        # job lists before stacking (jobs may hold different numbers of
+        # estimators, so np.array on the nested list could be ragged).
+        all_y_hat = np.array([y_hat for chunk in all_y_hat
+                              for y_hat in chunk])
+        y_mean = np.mean(all_y_hat, axis=0)
+        if with_std:
+            return y_mean, np.std(all_y_hat, axis=0)
+        else:
+            return y_mean
 
     def _validate_estimator(self):
         """Check the estimator and set the base_estimator_ attribute."""
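For context, a minimal sketch of what the new reduce step computes, reconstructed outside the estimator. `X` and `y` are hypothetical toy data, and the sketch assumes the default `max_features=1.0` so that every ensemble member sees all features (otherwise `estimators_features_` would have to be applied first):

    import numpy as np
    from sklearn.ensemble import BaggingRegressor

    rng = np.random.RandomState(0)
    X = rng.uniform(0, 10, (50, 1))
    y = np.ravel(X * np.sin(X)) + rng.normal(0, 1, 50)

    regr = BaggingRegressor(n_estimators=100).fit(X, y)

    # Stack the per-estimator predictions and aggregate across the ensemble
    # axis -- the same mean/std pair that predict(X, with_std=True) returns.
    all_y_hat = np.array([est.predict(X) for est in regr.estimators_])
    y_mean, y_std = all_y_hat.mean(axis=0), all_y_hat.std(axis=0)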
diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py
index 16dd17d39892d..e44081682b6c0 100644
--- a/sklearn/ensemble/forest.py
+++ b/sklearn/ensemble/forest.py
@@ -38,7 +38,6 @@ class calls the ``fit`` method of each sub-estimator on random samples
 
 from __future__ import division
 
-from itertools import chain
 import numpy as np
 from warnings import warn
 from abc import ABCMeta, abstractmethod
@@ -524,21 +523,30 @@ def __init__(self,
             verbose=verbose,
             warm_start=warm_start)
 
-    def predict(self, X):
+    def predict(self, X, with_std=False):
         """Predict regression target for X.
 
         The predicted regression target of an input sample is computed as the
         mean predicted regression targets of the trees in the forest.
+        Optionally, the standard deviation of the predictions of the
+        ensemble's estimators is computed in addition.
 
         Parameters
         ----------
         X : array-like of shape = [n_samples, n_features]
             The input samples.
 
+        with_std : boolean, optional, default=False
+            When True, the standard deviation of the predictions of the
+            ensemble's estimators is returned in addition to the mean.
+
         Returns
         -------
-        y : array of shape = [n_samples] or [n_samples, n_outputs]
-            The predicted values.
+        y_mean : array of shape = [n_samples] or [n_samples, n_outputs]
+            The mean of the predicted values.
+
+        y_std : array of shape = [n_samples] or [n_samples, n_outputs]
+            The standard deviation of the predicted values across the
+            ensemble. Only returned when with_std is True.
         """
         # Check data
         if getattr(X, "dtype", None) != DTYPE or X.ndim != 2:
@@ -554,10 +562,11 @@ def predict(self, X):
             delayed(_parallel_helper)(e, 'predict', X)
             for e in self.estimators_)
 
-        # Reduce
-        y_hat = sum(all_y_hat) / len(self.estimators_)
-
-        return y_hat
+        # Reduce: aggregate the per-tree predictions
+        y_mean = np.mean(all_y_hat, axis=0)
+        if with_std:
+            return y_mean, np.std(all_y_hat, axis=0)
+        else:
+            return y_mean
 
     def _set_oob_score(self, X, y):
         n_samples = y.shape[0]
diff --git a/sklearn/gaussian_process/gaussian_process.py b/sklearn/gaussian_process/gaussian_process.py
index 07a334b14388f..6461d6f040058 100644
--- a/sklearn/gaussian_process/gaussian_process.py
+++ b/sklearn/gaussian_process/gaussian_process.py
@@ -6,6 +6,8 @@
 
 from __future__ import print_function
 
+import warnings
+
 import numpy as np
 from scipy import linalg, optimize
 
@@ -374,7 +376,7 @@ def fit(self, X, y):
 
         return self
 
-    def predict(self, X, eval_MSE=False, batch_size=None):
+    def predict(self, X, eval_MSE=False, batch_size=None, with_std=False):
         """
         This function evaluates the Gaussian Process model at x.
 
@@ -396,6 +398,10 @@ def predict(self, X, eval_MSE=False, batch_size=None):
             Default is None so that all given points are evaluated at the same
             time.
 
+        with_std : boolean, optional, default=False
+            When True, the standard deviation of the predictive distribution
+            at the query points is returned in addition to the mean.
+
         Returns
         -------
         y : array_like, shape (n_samples, ) or (n_samples, n_targets)
@@ -405,10 +411,20 @@ def predict(self, X, eval_MSE=False, batch_size=None):
             of shape (n_samples, n_targets) with the Best Linear Unbiased
             Prediction at x.
 
-        MSE : array_like, optional (if eval_MSE == True)
+        y_std : array_like, optional (if with_std == True)
             An array with shape (n_eval, ) or (n_eval, n_targets) as with y,
-            with the Mean Squared Error at x.
+            containing the standard deviation of the prediction at x. Only
+            returned when with_std is True. If the deprecated eval_MSE is
+            True, the variance (MSE) of the prediction is returned instead.
         """
+        if eval_MSE:
+            warnings.warn("The eval_MSE parameter is deprecated as of version "
+                          "0.16 and will be removed in 0.18. Use the parameter"
+                          " with_std instead, which returns the standard "
+                          "deviation of the prediction.", DeprecationWarning)
+            if with_std:
+                raise ValueError("with_std and eval_MSE cannot both be True "
+                                 "at the same time.")
+            with_std = True
 
         # Check input shapes
         X = check_array(X)
@@ -432,11 +448,6 @@ def predict(self, X, eval_MSE=False, batch_size=None):
         # Normalize input
         X = (X - self.X_mean) / self.X_std
 
-        # Initialize output
-        y = np.zeros(n_eval)
-        if eval_MSE:
-            MSE = np.zeros(n_eval)
-
         # Get pairwise componentwise L1-distances to the input training set
         dx = manhattan_distances(X, Y=self.X, sum_over_features=False)
         # Get regression function and correlation
@@ -453,7 +464,7 @@ def predict(self, X, eval_MSE=False, batch_size=None):
                 y = y.ravel()
 
             # Mean Squared Error
-            if eval_MSE:
+            if with_std:
                 C = self.C
                 if C is None:
                     # Light storage mode (need to recompute C, F, Ft and G)
@@ -489,39 +500,36 @@ def predict(self, X, eval_MSE=False, batch_size=None):
                 if self.y_ndim_ == 1:
                     MSE = MSE.ravel()
 
-                return y, MSE
-
+                if eval_MSE:  # deprecated
+                    return y, MSE
+                else:
+                    return y, np.sqrt(MSE)
             else:
                 return y
 
         else:
             # Memory management
             if type(batch_size) is not int or batch_size <= 0:
                 raise Exception("batch_size must be a positive integer")
 
-            if eval_MSE:
-
-                y, MSE = np.zeros(n_eval), np.zeros(n_eval)
+            if with_std:
+                y, y_std = np.zeros(n_eval), np.zeros(n_eval)
                 for k in range(max(1, n_eval / batch_size)):
                     batch_from = k * batch_size
                     batch_to = min([(k + 1) * batch_size + 1, n_eval + 1])
-                    y[batch_from:batch_to], MSE[batch_from:batch_to] = \
+                    y[batch_from:batch_to], y_std[batch_from:batch_to] = \
                         self.predict(X[batch_from:batch_to],
-                                     eval_MSE=eval_MSE, batch_size=None)
-
-                return y, MSE
-
+                                     with_std=True, batch_size=None)
+                if eval_MSE:  # deprecated
+                    return y, y_std ** 2
+                else:
+                    return y, y_std
             else:
                 y = np.zeros(n_eval)
                 for k in range(max(1, n_eval / batch_size)):
                     batch_from = k * batch_size
                     batch_to = min([(k + 1) * batch_size + 1, n_eval + 1])
                     y[batch_from:batch_to] = \
-                        self.predict(X[batch_from:batch_to],
-                                     eval_MSE=eval_MSE, batch_size=None)
+                        self.predict(X[batch_from:batch_to], batch_size=None)
 
                 return y
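The deprecation path in `predict` amounts to the following equivalence. A short sketch assuming an already fitted `GaussianProcess` (the training data here is a hypothetical stand-in):

    import numpy as np
    from sklearn.gaussian_process import GaussianProcess

    X_train = np.atleast_2d(np.linspace(0.1, 9.9, 20)).T
    y_train = np.ravel(X_train * np.sin(X_train))
    X_test = np.atleast_2d(np.linspace(0, 10, 100)).T

    gp = GaussianProcess(theta0=1e-1, thetaL=1e-3, thetaU=1)
    gp.fit(X_train, y_train)

    # Deprecated (warns as of 0.16): ask for the MSE and take the square root
    y_pred, MSE = gp.predict(X_test, eval_MSE=True)
    sigma_old = np.sqrt(MSE)

    # New: ask for the standard deviation directly
    y_pred, sigma = gp.predict(X_test, with_std=True)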
diff --git a/sklearn/gaussian_process/tests/test_gaussian_process.py b/sklearn/gaussian_process/tests/test_gaussian_process.py
index 517fcb047aadf..3b6babbed6139 100644
--- a/sklearn/gaussian_process/tests/test_gaussian_process.py
+++ b/sklearn/gaussian_process/tests/test_gaussian_process.py
@@ -33,11 +33,11 @@ def test_1d(regr=regression.constant, corr=correlation.squared_exponential,
     gp = GaussianProcess(regr=regr, corr=corr, beta0=beta0,
                          theta0=1e-2, thetaL=1e-4, thetaU=1e-1,
                          random_start=random_start, verbose=False).fit(X, y)
-    y_pred, MSE = gp.predict(X, eval_MSE=True)
-    y2_pred, MSE2 = gp.predict(X2, eval_MSE=True)
+    y_pred, y_std = gp.predict(X, with_std=True)
+    _, y_std2 = gp.predict(X2, with_std=True)
 
-    assert_true(np.allclose(y_pred, y) and np.allclose(MSE, 0.)
-                and np.allclose(MSE2, 0., atol=10))
+    assert_true(np.allclose(y_pred, y) and np.allclose(y_std ** 2, 0.)
+                and np.allclose(y_std2 ** 2, 0., atol=10))
 
 
 def test_2d(regr=regression.constant, corr=correlation.squared_exponential,
@@ -67,12 +67,12 @@ def test_2d(regr=regression.constant, corr=correlation.squared_exponential,
                          thetaU=thetaU,
                          random_start=random_start, verbose=False)
     gp.fit(X, y)
-    y_pred, MSE = gp.predict(X, eval_MSE=True)
+    y_pred, y_std = gp.predict(X, with_std=True)
 
-    assert_true(np.allclose(y_pred, y) and np.allclose(MSE, 0.))
+    assert_true(np.allclose(y_pred, y) and np.allclose(y_std ** 2, 0.))
 
-    assert_true(np.all(gp.theta_ >= thetaL)) # Lower bounds of hyperparameters
-    assert_true(np.all(gp.theta_ <= thetaU)) # Upper bounds of hyperparameters
+    assert_true(np.all(gp.theta_ >= thetaL))  # Lower bounds of hyperparameters
+    assert_true(np.all(gp.theta_ <= thetaU))  # Upper bounds of hyperparameters
 
 
 def test_2d_2d(regr=regression.constant, corr=correlation.squared_exponential,
@@ -100,9 +100,9 @@ def test_2d_2d(regr=regression.constant, corr=correlation.squared_exponential,
                          thetaU=[1e-1] * 2,
                          random_start=random_start, verbose=False)
     gp.fit(X, y)
-    y_pred, MSE = gp.predict(X, eval_MSE=True)
+    y_pred, y_std = gp.predict(X, with_std=True)
 
-    assert_true(np.allclose(y_pred, y) and np.allclose(MSE, 0.))
+    assert_true(np.allclose(y_pred, y) and np.allclose(y_std ** 2, 0.))
 
 
 @raises(ValueError)