diff --git a/examples/plot_predictive_standard_deviation.py b/examples/plot_predictive_standard_deviation.py
new file mode 100644
index 0000000000000..f3341361be160
--- /dev/null
+++ b/examples/plot_predictive_standard_deviation.py
@@ -0,0 +1,120 @@
+"""
+==============================================================
+Comparison of predictive distributions of different regressors
+==============================================================
+
+A simple one-dimensional, noisy regression problem addressed by two different
+regressors:
+
+1. A Gaussian Process
+2. Bagging with extra-trees
+
+The regressors are fitted based on noisy observations where the magnitude of
+the noise at the different training points is constant and known. Plotted are
+both the mean and the pointwise 95% confidence interval of the predictions.
+The mean predictions are evaluated on noise-less test data using the mean
+squared error. The mean log probabilities of the noise-less test data are used
+to evaluate the predictive distributions (a normal distribution with the
+predicted mean and standard deviation) of the two regressors.
+
+The mean predictions of the Gaussian Process are slightly better than those of
+Bagging. The predictive distribution of Bagging, which also takes the
+predictive variance into account, is however better.
+"""
+print(__doc__)
+
+# Authors: Jan Hendrik Metzen
+#          Gilles Louppe
+# Licence: BSD 3 clause
+
+import matplotlib.pyplot as plt
+import numpy as np
+from scipy.stats import norm
+
+from sklearn.ensemble import BaggingRegressor
+from sklearn.tree import ExtraTreeRegressor
+from sklearn.gaussian_process import GaussianProcessRegressor
+from sklearn.gaussian_process.kernels import RBF, WhiteKernel
+from sklearn.metrics import mean_squared_error
+from sklearn.utils import check_random_state
+
+rng = check_random_state(1)
+
+
+# Observations and noise
+def f(x):
+    """The function to predict."""
+    return x * np.sin(x)
+
+X_train = rng.rand(50) * 10.0
+X_train = np.atleast_2d(X_train).T
+
+y_train = f(X_train).ravel()
+dy = np.ones_like(y_train)
+noise = rng.normal(0, dy)
+y_train += noise
+
+# Mesh the input space for evaluations of the real function, the prediction
+# and its standard deviation
+X_test = np.atleast_2d(np.linspace(0, 10, 1000)).T
+
+regressors = [
+    ("Gaussian Process", GaussianProcessRegressor(
+        alpha=(dy / y_train) ** 2,
+        kernel=1.0 * RBF() + WhiteKernel(),
+        n_restarts_optimizer=10,
+        random_state=rng), "b"),
+    ("Bagging", BaggingRegressor(
+        base_estimator=ExtraTreeRegressor(),
+        n_estimators=1000,
+        random_state=rng), "g")
+]
+
+# Plot the function and the observations
+fig = plt.figure()
+plt.plot(X_test, f(X_test), 'r', label=u'$f(x) = x\,\sin(x)$')
+plt.fill(np.concatenate([X_test, X_test[::-1]]),
+         np.concatenate([f(X_test) - 1.9600, (f(X_test) + 1.9600)[::-1]]),
+         alpha=.3, fc='r', ec='None')
+plt.plot(X_train.ravel(), y_train, 'ko', zorder=5, label=u'Observations')
+
+# Plot predictive distributions of GP and Bagging
+mse = {}
+log_pdf_loss = {}
+
+for name, regressor, color in regressors:
+    regressor.fit(X_train, y_train)
+
+    # Make the prediction on the meshed x-axis (ask for standard deviation
+    # as well)
+    y_pred, sigma = regressor.predict(X_test, return_std=True)
+
+    # Compute mean-squared error and log predictive loss
+    mse[name] = mean_squared_error(f(X_test), y_pred)
+    log_pdf_loss[name] = \
+        norm(y_pred, sigma).logpdf(f(X_test)).mean()
+
+    # Plot 95% confidence interval based on the predictive standard deviation
+    plt.plot(X_test, y_pred, color,
+             label=name)
+    plt.fill(np.concatenate([X_test, X_test[::-1]]),
+             np.concatenate([y_pred - 1.9600 * sigma,
+                             (y_pred + 1.9600 * sigma)[::-1]]),
+             alpha=.3, fc=color, ec='None')
+
+plt.xlabel('$x$')
+plt.ylabel('$f(x)$')
+plt.ylim(-10, 20)
+plt.legend(loc='upper left')
+
+print("Mean-squared error of predictors on 1000 equidistant noise-less test "
+      "datapoints:")
+for name, _, _ in regressors:
+    print("\t%s: %.2f" % (name, mse[name]))
+
+print("Mean log-probability of 1000 equidistant noise-less test datapoints\n"
+      "under the (normal) predictive distribution of the predictors, i.e.,\n"
+      "log N(y_true| y_pred_mean, y_pred_std) [higher is better]:")
+for name, _, _ in regressors:
+    print("\t%s: %.2f" % (name, log_pdf_loss[name]))
+
+plt.show()
diff --git a/sklearn/ensemble/bagging.py b/sklearn/ensemble/bagging.py
index f9849d33389fb..f83a93309d088 100644
--- a/sklearn/ensemble/bagging.py
+++ b/sklearn/ensemble/bagging.py
@@ -185,9 +185,8 @@ def _parallel_decision_function(estimators, estimators_features, X):
 
 def _parallel_predict_regression(estimators, estimators_features, X):
     """Private function used to compute predictions within a job."""
-    return sum(estimator.predict(X[:, features])
-               for estimator, features in zip(estimators,
-                                              estimators_features))
+    return [estimator.predict(X[:, features])
+            for estimator, features in zip(estimators, estimators_features)]
 
 
 class BaseBagging(with_metaclass(ABCMeta, BaseEnsemble)):
@@ -283,7 +282,7 @@ def _fit(self, X, y, max_samples, sample_weight=None):
         X, y = check_X_y(X, y, ['csr', 'csc'])
 
         # Remap output
-        n_samples, self.n_features_ = X.shape
+        self.n_samples_, self.n_features_ = X.shape
         y = self._validate_y(y)
 
         # Check parameters
@@ -291,9 +290,9 @@ def _fit(self, X, y, max_samples, sample_weight=None):
 
         # if max_samples is float:
         if not isinstance(max_samples, (numbers.Integral, np.integer)):
-            max_samples = int(self.max_samples * X.shape[0])
+            max_samples = int(self.max_samples * self.n_samples_)
 
-        if not (0 < max_samples <= X.shape[0]):
+        if not (0 < max_samples <= self.n_samples_):
             raise ValueError("max_samples must be in (0, n_samples]")
 
         if isinstance(self.max_features, (numbers.Integral, np.integer)):
@@ -876,11 +875,13 @@ def __init__(self,
             random_state=random_state,
             verbose=verbose)
 
-    def predict(self, X):
+    def predict(self, X, return_std=False):
         """Predict regression target for X.
 
         The predicted regression target of an input sample is computed as the
         mean predicted regression targets of the estimators in the ensemble.
+        Optionally, the standard deviation of the predictions of the
+        ensemble's estimators is computed in addition.
 
         Parameters
         ----------
@@ -888,15 +889,29 @@ def predict(self, X):
             The training input samples. Sparse matrices are accepted only if
             they are supported by the base estimator.
 
+        return_std : boolean, optional, default=False
+            When True, the sampling standard deviation of the predictions of
+            the ensemble is returned in addition to the predicted values.
+            Standard deviations are computed using bias-corrected
+            jackknife-after-bootstrap estimates, as described in arXiv:1311.4555.
+
         Returns
        -------
-        y : array of shape = [n_samples]
-            The predicted values.
+        y_mean : array of shape = [n_samples]
+            The mean of the predicted values.
+
+        y_std : array of shape = [n_samples], optional (if return_std == True)
+            The sampling standard deviation of the predicted values.
         """
+        # Checks
         check_is_fitted(self, "estimators_features_")
-        # Check data
         X = check_array(X, accept_sparse=['csr', 'csc'])
 
+        if return_std and not self.bootstrap:
+            raise ValueError("The sampling standard deviation of the "
+                             "predicted values can be computed only when "
+                             "bootstrap=True.")
+
         # Parallel loop
         n_jobs, n_estimators, starts = _partition_estimators(self.n_estimators,
                                                              self.n_jobs)
@@ -909,9 +924,33 @@ def predict(self, X):
             for i in range(n_jobs))
 
         # Reduce
-        y_hat = sum(all_y_hat) / self.n_estimators
+        all_y_hat = np.array(all_y_hat).reshape(self.n_estimators, -1)
+        y_mean = np.mean(all_y_hat, axis=0)
+
+        if not return_std:
+            return y_mean
+
+        else:
+            # Jackknife-after-bootstrap estimate of the sampling variance
+            var_J = np.zeros(X.shape[0])
+            N_bi = np.zeros((self.n_estimators, self.n_samples_))
+
+            for b, samples in enumerate(self.estimators_samples_):
+                N_bi[b, samples] += 1
+
+            out_of_bag = (N_bi == 0)
+
+            for i in range(self.n_samples_):
+                if np.any(out_of_bag[:, i]):
+                    delta_i = all_y_hat[out_of_bag[:, i]].mean(axis=0) - y_mean
+                    var_J += delta_i ** 2
+
+            var_J *= (self.n_samples_ - 1.) / self.n_samples_
+            correction = (self.n_samples_ * np.var(all_y_hat, axis=0) /
+                          self.n_estimators)
+            var_J[correction < var_J] -= correction[correction < var_J]
 
-        return y_hat
+        return y_mean, var_J ** 0.5
 
     def _validate_estimator(self):
         """Check the estimator and set the base_estimator_ attribute."""
diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py
index 9349c28b44339..588136a4b5526 100644
--- a/sklearn/ensemble/forest.py
+++ b/sklearn/ensemble/forest.py
@@ -251,7 +251,7 @@ def fit(self, X, y, sample_weight=None):
             X.sort_indices()
 
         # Remap output
-        n_samples, self.n_features_ = X.shape
+        self.n_samples_, self.n_features_ = X.shape
 
         y = np.atleast_1d(y)
         if y.ndim == 2 and y.shape[1] == 1:
@@ -549,10 +549,10 @@ def predict(self, X):
     def predict_proba(self, X):
         """Predict class probabilities for X.
 
-        The predicted class probabilities of an input sample is computed as
-        the mean predicted class probabilities of the trees in the forest. The
-        class probability of a single tree is the fraction of samples of the same
-        class in a leaf.
+        The predicted class probabilities of an input sample is computed as the
+        mean predicted class probabilities of the trees in the forest. The
+        class probability of a single tree is the fraction of samples of the
+        same class in a leaf.
 
         Parameters
         ----------
@@ -662,11 +662,13 @@ def __init__(self,
             verbose=verbose,
             warm_start=warm_start)
 
-    def predict(self, X):
+    def predict(self, X, return_std=False):
         """Predict regression target for X.
 
         The predicted regression target of an input sample is computed as the
         mean predicted regression targets of the trees in the forest.
+        Optionally, the standard deviation of the predictions of the
+        ensemble's estimators is computed in addition.
 
         Parameters
         ----------
@@ -675,14 +677,28 @@ def predict(self, X):
             ``dtype=np.float32`` and if a sparse matrix is provided
             to a sparse ``csr_matrix``.
 
+        return_std : boolean, optional, default=False
+            When True, the sampling standard deviation of the predictions of
+            the ensemble is returned in addition to the predicted values.
+            Standard deviations are computed using bias-corrected
+            jackknife-after-bootstrap estimates, as described in arXiv:1311.4555.
+
         Returns
         -------
-        y : array of shape = [n_samples] or [n_samples, n_outputs]
-            The predicted values.
+        y_mean : array of shape = [n_samples] or [n_samples, n_outputs]
+            The mean of the predicted values.
+
+        y_std : array of shape = [n_samples], optional (if return_std == True)
+            The sampling standard deviation of the predicted values.
         """
-        # Check data
+        # Checks
         X = self._validate_X_predict(X)
 
+        if return_std and not self.bootstrap:
+            raise ValueError("The sampling standard deviation of the "
+                             "predicted values can be computed only when "
+                             "bootstrap=True.")
+
         # Assign chunk of trees to jobs
         n_jobs, _, _ = _partition_estimators(self.n_estimators, self.n_jobs)
@@ -692,10 +708,35 @@ def predict(self, X):
             delayed(_parallel_helper)(e, 'predict', X, check_input=False)
             for e in self.estimators_)
 
-        # Reduce
-        y_hat = sum(all_y_hat) / len(self.estimators_)
+        all_y_hat = np.array(all_y_hat)
+        y_mean = np.mean(all_y_hat, axis=0)
+
+        if not return_std:
+            return y_mean
+
+        else:
+            # Jackknife-after-bootstrap estimate of the sampling variance
+            var_J = np.zeros(X.shape[0])
+            N_bi = np.zeros((self.n_estimators, self.n_samples_))
+
+            for b, estimator in enumerate(self.estimators_):
+                samples = _generate_sample_indices(estimator.random_state,
+                                                   self.n_samples_)
+                N_bi[b, samples] += 1
+
+            out_of_bag = (N_bi == 0)
+
+            for i in range(self.n_samples_):
+                if np.any(out_of_bag[:, i]):
+                    delta_i = all_y_hat[out_of_bag[:, i]].mean(axis=0) - y_mean
+                    var_J += delta_i ** 2
+
+            var_J *= (self.n_samples_ - 1.) / self.n_samples_
+            correction = (self.n_samples_ * np.var(all_y_hat, axis=0) /
+                          self.n_estimators)
+            var_J[correction < var_J] -= correction[correction < var_J]
 
-        return y_hat
+        return y_mean, var_J ** 0.5
 
     def _set_oob_score(self, X, y):
         """Compute out-of-bag scores"""
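
A minimal usage sketch of the API this patch adds, assuming the patched forest.py is installed; the toy data, estimator settings, and variable names below are illustrative only. return_std=True requires bootstrap=True (the default for random forests), otherwise the ValueError added above is raised.

import numpy as np
from sklearn.ensemble import RandomForestRegressor

rng = np.random.RandomState(0)
X = rng.uniform(0, 10, size=(200, 1))
y = X.ravel() * np.sin(X.ravel()) + rng.normal(scale=1.0, size=200)

# bootstrap=True (the default) is required for return_std=True
forest = RandomForestRegressor(n_estimators=500, bootstrap=True, random_state=0)
forest.fit(X, y)

X_new = np.linspace(0, 10, 5).reshape(-1, 1)
# return_std is only available with this patch applied
y_mean, y_std = forest.predict(X_new, return_std=True)
print(y_mean)
print(y_std)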
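
For reference, the variance computation that both predict implementations above perform can be written as a standalone NumPy sketch of the bias-corrected jackknife-after-bootstrap estimate cited in the docstrings (arXiv:1311.4555). The helper name and argument layout are illustrative, not part of the patch; membership_counts plays the role of N_bi, built from estimators_samples_ in bagging.py and from _generate_sample_indices in forest.py.

import numpy as np


def jackknife_after_bootstrap_std(all_y_hat, membership_counts):
    """Bias-corrected jackknife-after-bootstrap standard deviation.

    all_y_hat : array of shape (B, n_test)
        Prediction of each of the B bootstrapped estimators on the test points.
    membership_counts : array of shape (B, n_train)
        Number of times each training sample was drawn into each bootstrap
        sample (0 means the estimator never saw that sample).
    """
    B, n_train = membership_counts.shape
    y_mean = all_y_hat.mean(axis=0)
    out_of_bag = membership_counts == 0

    # Jackknife term: deviation of the "estimators that never saw training
    # point i" ensemble mean from the full ensemble mean, summed over i.
    var_j = np.zeros(all_y_hat.shape[1])
    for i in range(n_train):
        if np.any(out_of_bag[:, i]):
            delta_i = all_y_hat[out_of_bag[:, i]].mean(axis=0) - y_mean
            var_j += delta_i ** 2
    var_j *= (n_train - 1.0) / n_train

    # Monte-Carlo bias correction for the finite number of estimators B,
    # applied only where it keeps the variance estimate non-negative.
    correction = n_train * all_y_hat.var(axis=0) / B
    mask = correction < var_j
    var_j[mask] -= correction[mask]
    return np.sqrt(var_j)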