diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index 514b68c7d54d5..4274e65e5f826 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -101,6 +101,8 @@ Enhancements
      with ``n_jobs > 1`` used with a large grid of parameters on a small
      dataset. By `Vlad Niculae`_, `Olivier Grisel`_ and `Loic Esteve`_.
 
+   - Add multi-output support to :class:`ensemble.BaggingClassifier`
+     and :class:`ensemble.BaggingRegressor`. By `Arnaud Joly`_.
 
 Bug fixes
 .........
@@ -120,7 +122,7 @@ Bug fixes
    - Fixed bug in :class:`linear_model.LogisticRegressionCV` where
      `penalty` was ignored in the final fit. By `Manoj Kumar`_.
 
-   - Fixed bug in :class:`ensemble.forest.ForestClassifier` while computing 
+   - Fixed bug in :class:`ensemble.forest.ForestClassifier` while computing
      oob_score and X is a sparse.csc_matrix. By `Ankur Ankan`_.
 
 API changes summary
diff --git a/sklearn/ensemble/bagging.py b/sklearn/ensemble/bagging.py
index bec6e040577c2..be15e2518e4f1 100644
--- a/sklearn/ensemble/bagging.py
+++ b/sklearn/ensemble/bagging.py
@@ -1,6 +1,7 @@
 """Bagging meta-estimator."""
 
 # Author: Gilles Louppe
+#         Arnaud Joly
 # License: BSD 3 clause
 
 from __future__ import division
@@ -17,9 +18,11 @@
 from ..externals.six.moves import zip
 from ..metrics import r2_score, accuracy_score
 from ..tree import DecisionTreeClassifier, DecisionTreeRegressor
-from ..utils import check_random_state, check_X_y, check_array, column_or_1d
+from ..utils import check_random_state, check_X_y, check_array
 from ..utils.random import sample_without_replacement
-from ..utils.validation import has_fit_parameter, check_is_fitted
+from ..utils.validation import check_is_fitted
+from ..utils.validation import DataConversionWarning
+from ..utils.validation import has_fit_parameter
 from ..utils.fixes import bincount
 from ..utils.metaestimators import if_delegate_has_method
 
@@ -125,53 +128,86 @@ def _parallel_build_estimators(n_estimators, ensemble, X, y, sample_weight,
     return estimators, estimators_samples, estimators_features
 
 
-def _parallel_predict_proba(estimators, estimators_features, X, n_classes):
+def _parallel_predict_proba(estimators, estimators_features, X, n_outputs,
+                            n_classes):
     """Private function used to compute (proba-)predictions within a job."""
     n_samples = X.shape[0]
-    proba = np.zeros((n_samples, n_classes))
+
+    if n_outputs == 1:
+        n_classes = [n_classes]
+    proba = [np.zeros((n_samples, n_classes_k)) for n_classes_k in n_classes]
 
     for estimator, features in zip(estimators, estimators_features):
         if hasattr(estimator, "predict_proba"):
-            proba_estimator = estimator.predict_proba(X[:, features])
+            proba_est = estimator.predict_proba(X[:, features])
+            if n_outputs == 1:
+                proba_est = [proba_est]
+        else:
+            # We resort to voting
+            y_pred = estimator.predict(X[:, features])
+            if n_outputs == 1:
+                y_pred = y_pred.reshape((-1, 1))
 
-            if n_classes == len(estimator.classes_):
-                proba += proba_estimator
+            proba_est = []
+            for k, n_classes_k in enumerate(n_classes):
+                proba_est_k = np.zeros((n_samples, n_classes_k))
+                for c in range(n_classes_k):
+                    proba_est_k[:, c] = y_pred[:, k] == c
 
-            else:
-                proba[:, estimator.classes_] += \
-                    proba_estimator[:, range(len(estimator.classes_))]
+                proba_est.append(proba_est_k)
 
-        else:
-            # Resort to voting
-            predictions = estimator.predict(X[:, features])
+        estimator_classes_ = estimator.classes_
+        if n_outputs == 1:
+            estimator_classes_ = [estimator_classes_]
+
+        for k in range(n_outputs):
+            if n_classes[k] == len(estimator_classes_[k]):
+                proba[k] += proba_est[k]
 
-            for i in range(n_samples):
-                proba[i, predictions[i]] += 1
+            else:
+                proba[k][:, estimator_classes_[k]] += proba_est[k]
 
     return proba
 
 
-def _parallel_predict_log_proba(estimators, estimators_features, X, n_classes):
+def _parallel_predict_log_proba(estimators, estimators_features, X, n_outputs,
+                                n_classes):
     """Private function used to compute log probabilities within a job."""
     n_samples = X.shape[0]
-    log_proba = np.empty((n_samples, n_classes))
-    log_proba.fill(-np.inf)
-    all_classes = np.arange(n_classes, dtype=np.int)
+
+    if n_outputs == 1:
+        n_classes = [n_classes]
+
+    log_proba = []
+    for n_classes_k in n_classes:
+        log_proba_k = np.empty((n_samples, n_classes_k))
+        log_proba_k.fill(-np.inf)
+        log_proba.append(log_proba_k)
+
+    all_classes = [np.arange(n_classes_k, dtype=np.int)
+                   for n_classes_k in n_classes]
 
     for estimator, features in zip(estimators, estimators_features):
         log_proba_estimator = estimator.predict_log_proba(X[:, features])
 
-        if n_classes == len(estimator.classes_):
-            log_proba = np.logaddexp(log_proba, log_proba_estimator)
+        estimator_classes_ = estimator.classes_
+        if n_outputs == 1:
+            estimator_classes_ = [estimator_classes_]
+            log_proba_estimator = [log_proba_estimator]
 
-        else:
-            log_proba[:, estimator.classes_] = np.logaddexp(
-                log_proba[:, estimator.classes_],
-                log_proba_estimator[:, range(len(estimator.classes_))])
+        for k in range(n_outputs):
+            if n_classes[k] == len(estimator_classes_[k]):
+                log_proba[k] = np.logaddexp(log_proba[k],
+                                            log_proba_estimator[k])
 
-            missing = np.setdiff1d(all_classes, estimator.classes_)
-            log_proba[:, missing] = np.logaddexp(log_proba[:, missing],
-                                                 -np.inf)
+            else:
+                log_proba[k][:, estimator_classes_[k]] = np.logaddexp(
+                    log_proba[k][:, estimator_classes_[k]],
+                    log_proba_estimator[k])
+
+                missing = np.setdiff1d(all_classes[k], estimator_classes_[k])
+                log_proba[k][:, missing] = np.logaddexp(
+                    log_proba[k][:, missing],
+                    -np.inf)
 
     return log_proba
 
@@ -234,7 +270,7 @@ def fit(self, X, y, sample_weight=None):
             The training input samples. Sparse matrices are accepted only if
             they are supported by the base estimator.
 
-        y : array-like, shape = [n_samples]
+        y : array-like, shape = [n_samples] or [n_samples, n_outputs]
             The target values (class labels in classification, real numbers in
             regression).
 
@@ -251,10 +287,18 @@ def fit(self, X, y, sample_weight=None):
         random_state = check_random_state(self.random_state)
 
         # Convert data
-        X, y = check_X_y(X, y, ['csr', 'csc', 'coo'])
+        X, y = check_X_y(X, y, ['csr', 'csc', 'coo'], multi_output=True)
 
         # Remap output
         n_samples, self.n_features_ = X.shape
+
+        y = np.atleast_1d(y)
+        if y.ndim == 2 and y.shape[1] == 1:
+            warn("A column-vector y was passed when a 1d array was"
+                 " expected. Please change the shape of y to "
+                 "(n_samples,), for example using ravel().",
+                 DataConversionWarning, stacklevel=2)
+        self.n_outputs_ = 1 if y.ndim == 1 else y.shape[1]
         y = self._validate_y(y)
 
         # Check parameters
@@ -346,7 +390,21 @@ def _set_oob_score(self, X, y):
 
     def _validate_y(self, y):
         # Default implementation
-        return column_or_1d(y, warn=True)
+        return y
+
+    def _validate_X_predict(self, X):
+        check_is_fitted(self, ["estimators_", "n_features_", "n_outputs_"])
+
+        # Check data
+        X = check_array(X, accept_sparse=['csr', 'csc', 'coo'])
+
+        if self.n_features_ != X.shape[1]:
+            raise ValueError("Number of features of the model must "
+                             "match the input. Model n_features is {0} and "
+                             "input n_features is {1}."
+ "".format(self.n_features_, X.shape[1])) + + return X class BaggingClassifier(BaseBagging, ClassifierMixin): @@ -439,6 +497,12 @@ class BaggingClassifier(BaseBagging, ClassifierMixin): n_classes_ : int or list The number of classes. + n_features_ : int + The number of features. + + n_outputs_ : int + The number of outputs. + oob_score_ : float Score of the training dataset obtained using an out-of-bag estimate. @@ -500,44 +564,70 @@ def _set_oob_score(self, X, y): classes_ = self.classes_ n_samples = y.shape[0] - predictions = np.zeros((n_samples, n_classes_)) + if self.n_outputs_ == 1: + n_classes_ = [n_classes_] + classes_ = [classes_] + + predictions = [np.zeros((n_samples, n_classes_k)) + for n_classes_k in n_classes_] for estimator, samples, features in zip(self.estimators_, self.estimators_samples_, self.estimators_features_): mask = np.ones(n_samples, dtype=np.bool) mask[samples] = False - if hasattr(estimator, "predict_proba"): - predictions[mask, :] += estimator.predict_proba( - (X[mask, :])[:, features]) - + est_proba = estimator.predict_proba((X[mask, :])[:, features]) + if self.n_outputs_ == 1: + est_proba = [est_proba] + for k, est_proba_k in enumerate(est_proba): + predictions[k][mask] += est_proba_k else: - p = estimator.predict((X[mask, :])[:, features]) + est_pred = estimator.predict((X[mask, :])[:, features]) j = 0 - for i in range(n_samples): if mask[i]: - predictions[i, p[j]] += 1 + for k in range(self.n_outputs_): + predictions[k][i, est_pred[j]] += 1 j += 1 - if (predictions.sum(axis=1) == 0).any(): + if any((pred_k.sum(axis=1) == 0).any() for pred_k in predictions): warn("Some inputs do not have OOB scores. " "This probably means too few estimators were used " "to compute any reliable oob estimates.") - oob_decision_function = (predictions / - predictions.sum(axis=1)[:, np.newaxis]) - oob_score = accuracy_score(y, classes_.take(np.argmax(predictions, - axis=1))) + oob_decision_function = [pred_k / pred_k.sum(axis=1)[:, np.newaxis] + for pred_k in predictions] + y_oob = np.zeros((n_samples, self.n_outputs_)) + for k, oob_df_k in enumerate(oob_decision_function): + y_oob[:, k] = classes_[k].take(np.argmax(oob_df_k, axis=1), + axis=0) + if self.n_outputs_ == 1: + y_oob = y_oob.ravel() + oob_decision_function = oob_decision_function[0] + + oob_score = accuracy_score(y, y_oob) self.oob_decision_function_ = oob_decision_function self.oob_score_ = oob_score def _validate_y(self, y): - y = column_or_1d(y, warn=True) - self.classes_, y = np.unique(y, return_inverse=True) - self.n_classes_ = len(self.classes_) + y = np.copy(y) + if y.ndim == 1: + y = y.reshape((-1, 1)) + + self.classes_ = [] + self.n_classes_ = [] + + for k in range(self.n_outputs_): + classes_k, y[:, k] = np.unique(y[:, k], return_inverse=True) + self.classes_.append(classes_k) + self.n_classes_.append(classes_k.shape[0]) + + if self.n_outputs_ == 1: + y = y.ravel() + self.n_classes_ = self.n_classes_[0] + self.classes_ = self.classes_[0] return y @@ -556,12 +646,21 @@ def predict(self, X): Returns ------- - y : array of shape = [n_samples] + y : array of shape = [n_samples] or [n_samples, n_outputs] The predicted classes. 
""" - predicted_probabilitiy = self.predict_proba(X) - return self.classes_.take((np.argmax(predicted_probabilitiy, axis=1)), - axis=0) + proba = self.predict_proba(X) + if self.n_outputs_ == 1: + return self.classes_.take(np.argmax(proba, axis=1), axis=0) + + else: + n_samples = proba[0].shape[0] + y = np.zeros((n_samples, self.n_outputs_)) + for k in range(self.n_outputs_): + y[:, k] = self.classes_[k].take(np.argmax(proba[k], axis=1), + axis=0) + + return y def predict_proba(self, X): """Predict class probabilities for X. @@ -581,19 +680,13 @@ def predict_proba(self, X): Returns ------- - p : array of shape = [n_samples, n_classes] + p : array of shape = [n_samples, n_classes], or a list of n_outputs + such arrays if n_outputs > 1. The class probabilities of the input samples. The order of the classes corresponds to that in the attribute `classes_`. """ - check_is_fitted(self, "classes_") # Check data - X = check_array(X, accept_sparse=['csr', 'csc', 'coo']) - - if self.n_features_ != X.shape[1]: - raise ValueError("Number of features of the model must " - "match the input. Model n_features is {0} and " - "input n_features is {1}." - "".format(self.n_features_, X.shape[1])) + X = self._validate_X_predict(X) # Parallel loop n_jobs, n_estimators, starts = _partition_estimators(self.n_estimators, @@ -604,11 +697,21 @@ def predict_proba(self, X): self.estimators_[starts[i]:starts[i + 1]], self.estimators_features_[starts[i]:starts[i + 1]], X, + self.n_outputs_, self.n_classes_) for i in range(n_jobs)) # Reduce - proba = sum(all_proba) / self.n_estimators + proba = all_proba[0] + for j in range(1, len(all_proba)): + for k in range(self.n_outputs_): + proba[k] += all_proba[j][k] + + for k in range(self.n_outputs_): + proba[k] /= self.n_estimators + + if self.n_outputs_ == 1: + proba = proba[0] return proba @@ -627,20 +730,16 @@ def predict_log_proba(self, X): Returns ------- - p : array of shape = [n_samples, n_classes] - The class log-probabilities of the input samples. The order of the + p : array of shape = [n_samples, n_classes], or a list of n_outputs + such arrays if n_outputs > 1. + The class probabilities of the input samples. The order of the classes corresponds to that in the attribute `classes_`. """ - check_is_fitted(self, "classes_") + check_is_fitted(self, "base_estimator_") + if hasattr(self.base_estimator_, "predict_log_proba"): # Check data - X = check_array(X) - - if self.n_features_ != X.shape[1]: - raise ValueError("Number of features of the model must " - "match the input. 
Model n_features is {0} " - "and input n_features is {1} " - "".format(self.n_features_, X.shape[1])) + X = self._validate_X_predict(X) # Parallel loop n_jobs, n_estimators, starts = _partition_estimators( @@ -651,21 +750,29 @@ def predict_log_proba(self, X): self.estimators_[starts[i]:starts[i + 1]], self.estimators_features_[starts[i]:starts[i + 1]], X, + self.n_outputs_, self.n_classes_) for i in range(n_jobs)) # Reduce log_proba = all_log_proba[0] - for j in range(1, len(all_log_proba)): - log_proba = np.logaddexp(log_proba, all_log_proba[j]) + for k in range(self.n_outputs_): + log_proba[k] = np.logaddexp(log_proba[k], all_log_proba[j]) + + for k in range(self.n_outputs_): + log_proba[k] -= np.log(self.n_estimators) - log_proba -= np.log(self.n_estimators) + if self.n_outputs_ == 1: + log_proba = log_proba[0] return log_proba else: - return np.log(self.predict_proba(X)) + if self.n_outputs_ == 1: + return np.log(self.predict_proba(X)) + else: + return [np.log(proba_k) for proba_k in self.predict_proba(X)] @if_delegate_has_method(delegate='base_estimator') def decision_function(self, X): @@ -686,16 +793,11 @@ def decision_function(self, X): cases with ``k == 1``, otherwise ``k==n_classes``. """ - check_is_fitted(self, "classes_") + if self.n_outputs_ > 1: + raise NotImplementedError("Not implemented for multi-output data") # Check data - X = check_array(X) - - if self.n_features_ != X.shape[1]: - raise ValueError("Number of features of the model must " - "match the input. Model n_features is {1} and " - "input n_features is {2} " - "".format(self.n_features_, X.shape[1])) + X = self._validate_X_predict(X) # Parallel loop n_jobs, n_estimators, starts = _partition_estimators(self.n_estimators, @@ -795,6 +897,12 @@ class BaggingRegressor(BaseBagging, RegressorMixin): estimators_features_ : list of arrays The subset of drawn features for each base estimator. + n_features_ : int + The number of features. + + n_outputs_ : int + The number of outputs. + oob_score_ : float Score of the training dataset obtained using an out-of-bag estimate. @@ -860,12 +968,11 @@ def predict(self, X): Returns ------- - y : array of shape = [n_samples] + y : array of shape = [n_samples] or [n_samples, n_outputs] The predicted values. 
""" - check_is_fitted(self, "estimators_features_") # Check data - X = check_array(X, accept_sparse=['csr', 'csc', 'coo']) + X = self._validate_X_predict(X) # Parallel loop n_jobs, n_estimators, starts = _partition_estimators(self.n_estimators, @@ -891,16 +998,19 @@ def _validate_estimator(self): def _set_oob_score(self, X, y): n_samples = y.shape[0] - predictions = np.zeros((n_samples,)) - n_predictions = np.zeros((n_samples,)) + predictions = np.zeros((n_samples, self.n_outputs_)) + n_predictions = np.zeros((n_samples, 1)) for estimator, samples, features in zip(self.estimators_, self.estimators_samples_, self.estimators_features_): mask = np.ones(n_samples, dtype=np.bool) mask[samples] = False + est_pred = estimator.predict((X[mask, :])[:, features]) + if est_pred.ndim == 1: + est_pred = est_pred.reshape((-1, 1)) - predictions[mask] += estimator.predict((X[mask, :])[:, features]) + predictions[mask] += est_pred n_predictions[mask] += 1 if (n_predictions == 0).any(): @@ -911,5 +1021,8 @@ def _set_oob_score(self, X, y): predictions /= n_predictions + if self.n_outputs_ == 1: + predictions = predictions.ravel() + self.oob_prediction_ = predictions self.oob_score_ = r2_score(y, predictions) diff --git a/sklearn/ensemble/tests/test_bagging.py b/sklearn/ensemble/tests/test_bagging.py index a5f2a3d086190..49edf3cc1d620 100644 --- a/sklearn/ensemble/tests/test_bagging.py +++ b/sklearn/ensemble/tests/test_bagging.py @@ -9,6 +9,7 @@ from sklearn.base import BaseEstimator +from sklearn.utils.testing import assert_almost_equal from sklearn.utils.testing import assert_array_equal from sklearn.utils.testing import assert_array_almost_equal from sklearn.utils.testing import assert_equal @@ -30,6 +31,7 @@ from sklearn.pipeline import make_pipeline from sklearn.feature_selection import SelectKBest from sklearn.cross_validation import train_test_split +from sklearn.datasets import make_multilabel_classification from sklearn.datasets import load_boston, load_iris, make_hastie_10_2 from sklearn.utils import check_random_state @@ -303,7 +305,6 @@ def test_probability(): def test_oob_score_classification(): # Check that oob prediction is a good estimation of the generalization # error. - rng = check_random_state(0) X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=rng) @@ -313,7 +314,7 @@ def test_oob_score_classification(): n_estimators=100, bootstrap=True, oob_score=True, - random_state=rng).fit(X_train, y_train) + random_state=0).fit(X_train, y_train) test_score = clf.score(X_test, y_test) @@ -325,28 +326,45 @@ def test_oob_score_classification(): n_estimators=1, bootstrap=True, oob_score=True, - random_state=rng).fit, + random_state=0).fit, X_train, y_train) + # Check for multioutput / multilabel data + clf = BaggingClassifier(base_estimator=DecisionTreeClassifier(), + n_estimators=100, + bootstrap=True, + oob_score=True, + random_state=0).fit(X_train, y_train > 1) + + so_oob_df_ = clf.oob_decision_function_ + clf.fit(X_train, np.vstack([y_train, y_train]).T > 1) + for oob_df_k in clf.oob_decision_function_: + assert_almost_equal(oob_df_k, so_oob_df_) + def test_oob_score_regression(): # Check that oob prediction is a good estimation of the generalization # error. 
-    rng = check_random_state(0)
     X_train, X_test, y_train, y_test = train_test_split(boston.data,
                                                          boston.target,
-                                                         random_state=rng)
+                                                         random_state=0)
 
-    clf = BaggingRegressor(base_estimator=DecisionTreeRegressor(),
+    est = BaggingRegressor(base_estimator=DecisionTreeRegressor(),
                            n_estimators=50,
                            bootstrap=True,
                            oob_score=True,
-                           random_state=rng).fit(X_train, y_train)
+                           random_state=0).fit(X_train, y_train)
+
+    so_oob_prediction_ = est.oob_prediction_
+    test_score = est.score(X_test, y_test)
 
-    test_score = clf.score(X_test, y_test)
+    assert_less(abs(test_score - est.oob_score_), 0.1)
 
-    assert_less(abs(test_score - clf.oob_score_), 0.1)
+    # multioutput-oob
+    est.fit(X_train, np.vstack([y_train, y_train]).T)
+    assert_array_almost_equal(est.oob_prediction_,
+                              np.vstack(2 * [so_oob_prediction_]).T)
 
     # Test with few estimators
     assert_warns(UserWarning,
@@ -354,7 +372,7 @@ def test_oob_score_regression():
                                  n_estimators=1,
                                  bootstrap=True,
                                  oob_score=True,
-                                 random_state=rng).fit,
+                                 random_state=0).fit,
                 X_train,
                 y_train)
 
@@ -407,7 +425,8 @@ def test_error():
                  BaggingClassifier(base, max_features="foobar").fit, X, y)
 
     # Test support of decision_function
-    assert_false(hasattr(BaggingClassifier(base).fit(X, y), 'decision_function'))
+    assert_false(hasattr(BaggingClassifier(base).fit(X, y),
+                         'decision_function'))
 
 
 def test_parallel_classification():
@@ -619,7 +638,8 @@ def test_warm_start_equal_n_estimators():
     X_train += 1.
 
     assert_warns_message(UserWarning,
-                         "Warm-start fitting without increasing n_estimators does not",
+                         "Warm-start fitting without increasing "
+                         "n_estimators does not",
                         clf.fit, X_train, y_train)
 
     assert_array_equal(y_pred, clf.predict(X_test))
@@ -662,3 +682,27 @@ def test_oob_score_removed_on_warm_start():
     clf.fit(X, y)
 
     assert_raises(AttributeError, getattr, clf, "oob_score_")
+
+
+def test_multioutput():
+    X, y = make_multilabel_classification(n_samples=100, n_labels=1,
+                                          n_classes=5, random_state=0,
+                                          return_indicator=True)
+
+    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
+
+    # no bootstrap is used to get perfect score on training set
+    est = BaggingClassifier(random_state=0, bootstrap=False)
+    est.fit(X_train, y_train)
+
+    assert_almost_equal(est.score(X_train, y_train), 1.)
+
+    y_proba = est.predict_proba(X_test)
+    y_log_proba = est.predict_log_proba(X_test)
+    for p, log_p in zip(y_proba, y_log_proba):
+        assert_array_almost_equal(p, np.exp(log_p))
+
+    # no bootstrap is used to get perfect score on training set
+    est = BaggingRegressor(random_state=0, bootstrap=False)
+    est.fit(X_train, y_train)
+    assert_almost_equal(est.score(X_train, y_train), 1.)
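
Example usage (illustrative sketch, not part of the patch): assuming the patch above is applied, the multi-output API can be exercised as follows. The dataset parameters, estimator settings and variable names below are arbitrary; per the new docstrings, `predict` returns an array of shape [n_samples, n_outputs] and `predict_proba` returns one probability array per output when n_outputs > 1.

import numpy as np

from sklearn.cross_validation import train_test_split
from sklearn.datasets import make_multilabel_classification
from sklearn.ensemble import BaggingClassifier
from sklearn.tree import DecisionTreeClassifier

# Build a small multi-label problem; y has shape (n_samples, n_outputs).
X, y = make_multilabel_classification(n_samples=100, n_labels=1, n_classes=5,
                                      random_state=0, return_indicator=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# fit accepts the 2-D y directly once multi-output support is in place.
clf = BaggingClassifier(base_estimator=DecisionTreeClassifier(),
                        n_estimators=10, random_state=0)
clf.fit(X_train, y_train)

# predict returns an array of shape (n_samples, n_outputs) ...
y_pred = clf.predict(X_test)

# ... while predict_proba returns a list with one array per output.
all_proba = clf.predict_proba(X_test)
for k, proba_k in enumerate(all_proba):
    print(k, proba_k.shape)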