diff --git a/doc/whats_new/v1.0.rst b/doc/whats_new/v1.0.rst
index ffb9646499982..a00523ec2223b 100644
--- a/doc/whats_new/v1.0.rst
+++ b/doc/whats_new/v1.0.rst
@@ -59,6 +59,16 @@ Changelog
 - |Fix| Fixes incorrect multiple data-conversion warnings when clustering
   boolean data. :pr:`19046` by :user:`Surya Prakash `.
 
+:mod:`sklearn.ensemble`
+.......................
+
+- |Fix| Computing the out-of-bag (OOB) score is no longer allowed in
+  :class:`ensemble.RandomForestClassifier` and
+  :class:`ensemble.ExtraTreesClassifier` with a multiclass-multioutput
+  target, since scikit-learn does not provide any metric supporting this
+  type of target. Additional private refactoring was performed.
+  :pr:`19162` by :user:`Guillaume Lemaitre `.
+
 :mod:`sklearn.feature_extraction`
 .................................
 
diff --git a/sklearn/ensemble/_forest.py b/sklearn/ensemble/_forest.py
index ff1e781f7e166..c97b5b9f12528 100644
--- a/sklearn/ensemble/_forest.py
+++ b/sklearn/ensemble/_forest.py
@@ -50,8 +50,9 @@ class calls the ``fit`` method of each sub-estimator on random samples
 from scipy.sparse import hstack as sparse_hstack
 from joblib import Parallel
 
+from ..base import is_classifier
 from ..base import ClassifierMixin, RegressorMixin, MultiOutputMixin
-from ..metrics import r2_score
+from ..metrics import accuracy_score, r2_score
 from ..preprocessing import OneHotEncoder
 from ..tree import (DecisionTreeClassifier, DecisionTreeRegressor,
                     ExtraTreeClassifier, ExtraTreeRegressor)
@@ -61,7 +62,7 @@ class calls the ``fit`` method of each sub-estimator on random samples
 from ._base import BaseEnsemble, _partition_estimators
 from ..utils.fixes import delayed
 from ..utils.fixes import _joblib_parallel_args
-from ..utils.multiclass import check_classification_targets
+from ..utils.multiclass import check_classification_targets, type_of_target
 from ..utils.validation import check_is_fitted, _check_sample_weight
 from ..utils.validation import _deprecate_positional_args
 
@@ -396,7 +397,19 @@ def fit(self, X, y, sample_weight=None):
             self.estimators_.extend(trees)
 
         if self.oob_score:
-            self._set_oob_score(X, y)
+            y_type = type_of_target(y)
+            if y_type in ("multiclass-multioutput", "unknown"):
+                # FIXME: we could consider supporting multiclass-multioutput
+                # if we introduce or reuse a constructor parameter (e.g.
+                # oob_score) allowing the user to pass a callable defining the
+                # scoring strategy on OOB samples.
+                raise ValueError(
+                    f"The type of target cannot be used to compute OOB "
+                    f"estimates. Got {y_type} while only the following are "
+                    f"supported: continuous, continuous-multioutput, binary, "
+                    f"multiclass, multilabel-indicator."
+                )
+            self._set_oob_score_and_attributes(X, y)
 
         # Decapsulate classes_ attributes
         if hasattr(self, "classes_") and self.n_outputs_ == 1:
@@ -406,9 +419,76 @@ def fit(self, X, y, sample_weight=None):
         return self
 
     @abstractmethod
-    def _set_oob_score(self, X, y):
+    def _set_oob_score_and_attributes(self, X, y):
+        """Compute and set the OOB score and attributes.
+
+        Parameters
+        ----------
+        X : array-like of shape (n_samples, n_features)
+            The data matrix.
+        y : ndarray of shape (n_samples, n_outputs)
+            The target matrix.
         """
-        Calculate out of bag predictions and score."""
+
+    def _compute_oob_predictions(self, X, y):
+        """Compute the OOB predictions.
+
+        Parameters
+        ----------
+        X : array-like of shape (n_samples, n_features)
+            The data matrix.
+        y : ndarray of shape (n_samples, n_outputs)
+            The target matrix.
+
+        Returns
+        -------
+        oob_pred : ndarray of shape (n_samples, n_classes, n_outputs) or \
+                (n_samples, 1, n_outputs)
+            The OOB predictions.
+        """
+        X = check_array(X, dtype=DTYPE, accept_sparse='csr')
+
+        n_samples = y.shape[0]
+        n_outputs = self.n_outputs_
+        if is_classifier(self) and hasattr(self, "n_classes_"):
+            # n_classes_ is an ndarray at this stage
+            # all the supported types of target will have the same number of
+            # classes in all outputs
+            oob_pred_shape = (n_samples, self.n_classes_[0], n_outputs)
+        else:
+            # for regression, n_classes_ does not exist and we create an empty
+            # axis to be consistent with the classification case and make
+            # the array operations compatible with the two settings
+            oob_pred_shape = (n_samples, 1, n_outputs)
+
+        oob_pred = np.zeros(shape=oob_pred_shape, dtype=np.float64)
+        n_oob_pred = np.zeros((n_samples, n_outputs), dtype=np.int64)
+
+        n_samples_bootstrap = _get_n_samples_bootstrap(
+            n_samples, self.max_samples,
+        )
+        for estimator in self.estimators_:
+            unsampled_indices = _generate_unsampled_indices(
+                estimator.random_state, n_samples, n_samples_bootstrap,
+            )
+
+            y_pred = self._get_oob_predictions(
+                estimator, X[unsampled_indices, :]
+            )
+            oob_pred[unsampled_indices, ...] += y_pred
+            n_oob_pred[unsampled_indices, :] += 1
+
+        for k in range(n_outputs):
+            if (n_oob_pred == 0).any():
+                warn(
+                    "Some inputs do not have OOB scores. This probably means "
+                    "too few trees were used to compute any reliable OOB "
+                    "estimates.", UserWarning
+                )
+                n_oob_pred[n_oob_pred == 0] = 1
+            oob_pred[..., k] /= n_oob_pred[..., [k]]
+
+        return oob_pred
 
     def _validate_y_class_weight(self, y):
         # Default implementation
@@ -507,53 +587,53 @@ def __init__(self,
             class_weight=class_weight,
             max_samples=max_samples)
 
-    def _set_oob_score(self, X, y):
-        """
-        Compute out-of-bag score."""
-        X = check_array(X, dtype=DTYPE, accept_sparse='csr')
-
-        n_classes_ = self.n_classes_
-        n_samples = y.shape[0]
-
-        oob_decision_function = []
-        oob_score = 0.0
-        predictions = [np.zeros((n_samples, n_classes_[k]))
-                       for k in range(self.n_outputs_)]
-
-        n_samples_bootstrap = _get_n_samples_bootstrap(
-            n_samples, self.max_samples
-        )
-
-        for estimator in self.estimators_:
-            unsampled_indices = _generate_unsampled_indices(
-                estimator.random_state, n_samples, n_samples_bootstrap)
-            p_estimator = estimator.predict_proba(X[unsampled_indices, :],
-                                                  check_input=False)
-
-            if self.n_outputs_ == 1:
-                p_estimator = [p_estimator]
-
-            for k in range(self.n_outputs_):
-                predictions[k][unsampled_indices, :] += p_estimator[k]
-
-        for k in range(self.n_outputs_):
-            if (predictions[k].sum(axis=1) == 0).any():
-                warn("Some inputs do not have OOB scores. "
-                     "This probably means too few trees were used "
-                     "to compute any reliable oob estimates.")
+    @staticmethod
+    def _get_oob_predictions(tree, X):
+        """Compute the OOB predictions for an individual tree.
 
-            decision = (predictions[k] /
-                        predictions[k].sum(axis=1)[:, np.newaxis])
-            oob_decision_function.append(decision)
-            oob_score += np.mean(y[:, k] ==
-                                 np.argmax(predictions[k], axis=1), axis=0)
+        Parameters
+        ----------
+        tree : DecisionTreeClassifier object
+            A single decision tree classifier.
+        X : ndarray of shape (n_samples, n_features)
+            The OOB samples.
 
-        if self.n_outputs_ == 1:
-            self.oob_decision_function_ = oob_decision_function[0]
+        Returns
+        -------
+        y_pred : ndarray of shape (n_samples, n_classes, n_outputs)
+            The OOB associated predictions.
+ """ + y_pred = tree.predict_proba(X, check_input=False) + y_pred = np.array(y_pred, copy=False) + if y_pred.ndim == 2: + # binary and multiclass + y_pred = y_pred[..., np.newaxis] else: - self.oob_decision_function_ = oob_decision_function + # Roll the first `n_outputs` axis to the last axis. We will reshape + # from a shape of (n_outputs, n_samples, n_classes) to a shape of + # (n_samples, n_classes, n_outputs). + y_pred = np.rollaxis(y_pred, axis=0, start=3) + return y_pred + + def _set_oob_score_and_attributes(self, X, y): + """Compute and set the OOB score and attributes. - self.oob_score_ = oob_score / self.n_outputs_ + Parameters + ---------- + X : array-like of shape (n_samples, n_features) + The data matrix. + y : ndarray of shape (n_samples, n_outputs) + The target matrix. + """ + self.oob_decision_function_ = super()._compute_oob_predictions(X, y) + if self.oob_decision_function_.shape[-1] == 1: + # drop the n_outputs axis if there is a single output + self.oob_decision_function_ = self.oob_decision_function_.squeeze( + axis=-1 + ) + self.oob_score_ = accuracy_score( + y, np.argmax(self.oob_decision_function_, axis=1) + ) def _validate_y_class_weight(self, y): check_classification_targets(y) @@ -664,8 +744,7 @@ def predict_proba(self, X): Returns ------- - p : ndarray of shape (n_samples, n_classes), or a list of n_outputs - such arrays if n_outputs > 1. + p : ndarray of shape (n_samples, n_classes), or a list of such arrays The class probabilities of the input samples. The order of the classes corresponds to that in the attribute :term:`classes_`. """ @@ -711,8 +790,7 @@ def predict_log_proba(self, X): Returns ------- - p : ndarray of shape (n_samples, n_classes), or a list of n_outputs - such arrays if n_outputs > 1. + p : ndarray of shape (n_samples, n_classes), or a list of such arrays The class probabilities of the input samples. The order of the classes corresponds to that in the attribute :term:`classes_`. """ @@ -803,52 +881,48 @@ def predict(self, X): return y_hat - def _set_oob_score(self, X, y): - """ - Compute out-of-bag scores.""" - X = check_array(X, dtype=DTYPE, accept_sparse='csr') - - n_samples = y.shape[0] - - predictions = np.zeros((n_samples, self.n_outputs_)) - n_predictions = np.zeros((n_samples, self.n_outputs_)) - - n_samples_bootstrap = _get_n_samples_bootstrap( - n_samples, self.max_samples - ) - - for estimator in self.estimators_: - unsampled_indices = _generate_unsampled_indices( - estimator.random_state, n_samples, n_samples_bootstrap) - p_estimator = estimator.predict( - X[unsampled_indices, :], check_input=False) - - if self.n_outputs_ == 1: - p_estimator = p_estimator[:, np.newaxis] - - predictions[unsampled_indices, :] += p_estimator - n_predictions[unsampled_indices, :] += 1 + @staticmethod + def _get_oob_predictions(tree, X): + """Compute the OOB predictions for an individual tree. - if (n_predictions == 0).any(): - warn("Some inputs do not have OOB scores. " - "This probably means too few trees were used " - "to compute any reliable oob estimates.") - n_predictions[n_predictions == 0] = 1 - - predictions /= n_predictions - self.oob_prediction_ = predictions - - if self.n_outputs_ == 1: - self.oob_prediction_ = \ - self.oob_prediction_.reshape((n_samples, )) + Parameters + ---------- + tree : DecisionTreeRegressor object + A single decision tree regressor. + X : ndarray of shape (n_samples, n_features) + The OOB samples. 
-        self.oob_score_ = 0.0
+        Returns
+        -------
+        y_pred : ndarray of shape (n_samples, 1, n_outputs)
+            The OOB associated predictions.
+        """
+        y_pred = tree.predict(X, check_input=False)
+        if y_pred.ndim == 1:
+            # single output regression
+            y_pred = y_pred[:, np.newaxis, np.newaxis]
+        else:
+            # multioutput regression
+            y_pred = y_pred[:, np.newaxis, :]
+        return y_pred
 
-        for k in range(self.n_outputs_):
-            self.oob_score_ += r2_score(y[:, k],
-                                        predictions[:, k])
+    def _set_oob_score_and_attributes(self, X, y):
+        """Compute and set the OOB score and attributes.
 
-        self.oob_score_ /= self.n_outputs_
+        Parameters
+        ----------
+        X : array-like of shape (n_samples, n_features)
+            The data matrix.
+        y : ndarray of shape (n_samples, n_outputs)
+            The target matrix.
+        """
+        self.oob_prediction_ = super()._compute_oob_predictions(X, y).squeeze(
+            axis=1
+        )
+        if self.oob_prediction_.shape[-1] == 1:
+            # drop the n_outputs axis if there is a single output
+            self.oob_prediction_ = self.oob_prediction_.squeeze(axis=-1)
+        self.oob_score_ = r2_score(y, self.oob_prediction_)
 
     def _compute_partial_dependence_recursion(self, grid, target_features):
         """Fast partial dependence computation.
@@ -881,6 +955,7 @@ def _compute_partial_dependence_recursion(self, grid, target_features):
 
         return averaged_predictions
 
+
 class RandomForestClassifier(ForestClassifier):
     """
     A random forest classifier.
@@ -999,8 +1074,7 @@ class RandomForestClassifier(ForestClassifier):
         whole dataset is used to build each tree.
 
     oob_score : bool, default=False
-        Whether to use out-of-bag samples to estimate
-        the generalization accuracy.
+        Whether to use out-of-bag samples to estimate the generalization score.
 
     n_jobs : int, default=None
         The number of jobs to run in parallel. :meth:`fit`, :meth:`predict`,
@@ -1107,7 +1181,8 @@ class labels (multi-output problem).
         Score of the training dataset obtained using an out-of-bag estimate.
         This attribute exists only when ``oob_score`` is True.
 
-    oob_decision_function_ : ndarray of shape (n_samples, n_classes)
+    oob_decision_function_ : ndarray of shape (n_samples, n_classes) or \
+            (n_samples, n_classes, n_outputs)
         Decision function computed with out-of-bag estimate on the training
         set. If n_estimators is small it might be possible that a data point
         was never left out during the bootstrap. In this case,
@@ -1322,8 +1397,7 @@ class RandomForestRegressor(ForestRegressor):
         whole dataset is used to build each tree.
 
     oob_score : bool, default=False
-        whether to use out-of-bag samples to estimate
-        the R^2 on unseen data.
+        Whether to use out-of-bag samples to estimate the generalization score.
 
     n_jobs : int, default=None
         The number of jobs to run in parallel. :meth:`fit`, :meth:`predict`,
@@ -1396,7 +1470,7 @@ class RandomForestRegressor(ForestRegressor):
         Score of the training dataset obtained using an out-of-bag estimate.
         This attribute exists only when ``oob_score`` is True.
 
-    oob_prediction_ : ndarray of shape (n_samples,)
+    oob_prediction_ : ndarray of shape (n_samples,) or (n_samples, n_outputs)
         Prediction computed with out-of-bag estimate on the training set.
         This attribute exists only when ``oob_score`` is True.
 
@@ -1605,8 +1679,7 @@ class ExtraTreesClassifier(ForestClassifier):
         whole dataset is used to build each tree.
 
     oob_score : bool, default=False
-        Whether to use out-of-bag samples to estimate
-        the generalization accuracy.
+        Whether to use out-of-bag samples to estimate the generalization score.
 
     n_jobs : int, default=None
        The number of jobs to run in parallel. :meth:`fit`, :meth:`predict`,
@@ -1717,7 +1790,8 @@ class labels (multi-output problem).
         Score of the training dataset obtained using an out-of-bag estimate.
         This attribute exists only when ``oob_score`` is True.
 
-    oob_decision_function_ : ndarray of shape (n_samples, n_classes)
+    oob_decision_function_ : ndarray of shape (n_samples, n_classes) or \
+            (n_samples, n_classes, n_outputs)
         Decision function computed with out-of-bag estimate on the training
         set. If n_estimators is small it might be possible that a data point
         was never left out during the bootstrap. In this case,
@@ -1924,7 +1998,7 @@ class ExtraTreesRegressor(ForestRegressor):
         whole dataset is used to build each tree.
 
     oob_score : bool, default=False
-        Whether to use out-of-bag samples to estimate the R^2 on unseen data.
+        Whether to use out-of-bag samples to estimate the generalization score.
 
     n_jobs : int, default=None
         The number of jobs to run in parallel. :meth:`fit`, :meth:`predict`,
@@ -2001,7 +2075,7 @@ class ExtraTreesRegressor(ForestRegressor):
         Score of the training dataset obtained using an out-of-bag estimate.
         This attribute exists only when ``oob_score`` is True.
 
-    oob_prediction_ : ndarray of shape (n_samples,)
+    oob_prediction_ : ndarray of shape (n_samples,) or (n_samples, n_outputs)
         Prediction computed with out-of-bag estimate on the training set.
         This attribute exists only when ``oob_score`` is True.
 
@@ -2290,7 +2364,7 @@ def __init__(self,
         self.min_impurity_split = min_impurity_split
         self.sparse_output = sparse_output
 
-    def _set_oob_score(self, X, y):
+    def _set_oob_score_and_attributes(self, X, y):
         raise NotImplementedError("OOB score not supported by tree embedding")
 
     def fit(self, X, y=None, sample_weight=None):
diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py
index 65350f4d602d9..2302ed169bf86 100644
--- a/sklearn/ensemble/tests/test_forest.py
+++ b/sklearn/ensemble/tests/test_forest.py
@@ -32,6 +32,7 @@
 from sklearn.utils._testing import assert_raises
 from sklearn.utils._testing import assert_warns
 from sklearn.utils._testing import assert_warns_message
+from sklearn.utils._testing import _convert_container
 from sklearn.utils._testing import ignore_warnings
 from sklearn.utils._testing import skip_if_no_parallel
 from sklearn.utils.fixes import parse_version
@@ -46,6 +47,7 @@
 from sklearn.ensemble import RandomForestClassifier
 from sklearn.ensemble import RandomForestRegressor
 from sklearn.ensemble import RandomTreesEmbedding
+from sklearn.model_selection import train_test_split
 from sklearn.model_selection import GridSearchCV
 from sklearn.svm import LinearSVC
 from sklearn.utils.validation import check_random_state
@@ -371,72 +373,156 @@ def test_unfitted_feature_importances(name):
     getattr(FOREST_ESTIMATORS[name](), 'feature_importances_')
 
 
-def check_oob_score(name, X, y, n_estimators=20):
-    # Check that oob prediction is a good estimation of the generalization
-    # error.
-
-    # Proper behavior
-    est = FOREST_ESTIMATORS[name](oob_score=True, random_state=0,
-                                  n_estimators=n_estimators, bootstrap=True)
-    n_samples = X.shape[0]
-    est.fit(X[:n_samples // 2, :], y[:n_samples // 2])
-    test_score = est.score(X[n_samples // 2:, :], y[n_samples // 2:])
-    oob_score = est.oob_score_
-
-    assert abs(test_score - oob_score) < 0.1 and oob_score > 0.7
-
-    # Check warning if not enough estimators
-    with np.errstate(divide="ignore", invalid="ignore"):
-        est = FOREST_ESTIMATORS[name](oob_score=True, random_state=0,
-                                      n_estimators=1, bootstrap=True)
-        assert_warns(UserWarning, est.fit, X, y)
+@pytest.mark.parametrize("ForestClassifier", FOREST_CLASSIFIERS.values())
+@pytest.mark.parametrize("X_type", ["array", "sparse_csr", "sparse_csc"])
+@pytest.mark.parametrize(
+    "X, y, lower_bound_accuracy",
+    [
+        (
+            *datasets.make_classification(
+                n_samples=300, n_classes=2, random_state=0
+            ),
+            0.9,
+        ),
+        (
+            *datasets.make_classification(
+                n_samples=1000, n_classes=3, n_informative=6, random_state=0
+            ),
+            0.65,
+        ),
+        (
+            iris.data, iris.target * 2 + 1, 0.65,
+        ),
+        (
+            *datasets.make_multilabel_classification(
+                n_samples=300, random_state=0
+            ),
+            0.18,
+        ),
+    ],
+)
+def test_forest_classifier_oob(
+    ForestClassifier, X, y, X_type, lower_bound_accuracy
+):
+    """Check that the OOB score is close to the score on a test set."""
+    X = _convert_container(X, constructor_name=X_type)
+    X_train, X_test, y_train, y_test = train_test_split(
+        X, y, test_size=0.5, random_state=0,
+    )
+    classifier = ForestClassifier(
+        n_estimators=40, bootstrap=True, oob_score=True, random_state=0,
+    )
+    assert not hasattr(classifier, "oob_score_")
+    assert not hasattr(classifier, "oob_decision_function_")
 
-@pytest.mark.parametrize('name', FOREST_CLASSIFIERS)
-def test_oob_score_classifiers(name):
-    check_oob_score(name, iris.data, iris.target)
+    classifier.fit(X_train, y_train)
+    test_score = classifier.score(X_test, y_test)
 
-    # csc matrix
-    check_oob_score(name, csc_matrix(iris.data), iris.target)
+    assert abs(test_score - classifier.oob_score_) <= 0.1
+    assert classifier.oob_score_ >= lower_bound_accuracy
 
-    # non-contiguous targets in classification
-    check_oob_score(name, iris.data, iris.target * 2 + 1)
+    assert hasattr(classifier, "oob_score_")
+    assert not hasattr(classifier, "oob_prediction_")
+    assert hasattr(classifier, "oob_decision_function_")
+    if y.ndim == 1:
+        expected_shape = (X_train.shape[0], len(set(y)))
+    else:
+        expected_shape = (X_train.shape[0], len(set(y[:, 0])), y.shape[1])
+    assert classifier.oob_decision_function_.shape == expected_shape
 
-@pytest.mark.parametrize('name', FOREST_REGRESSORS)
-def test_oob_score_regressors(name):
-    check_oob_score(name, X_reg, y_reg, 50)
 
-    # csc matrix
-    check_oob_score(name, csc_matrix(X_reg), y_reg, 50)
+@pytest.mark.parametrize("ForestRegressor", FOREST_REGRESSORS.values())
+@pytest.mark.parametrize("X_type", ["array", "sparse_csr", "sparse_csc"])
+@pytest.mark.parametrize(
+    "X, y, lower_bound_r2",
+    [
+        (
+            *datasets.make_regression(
+                n_samples=500, n_features=10, n_targets=1, random_state=0
+            ),
+            0.7,
+        ),
+        (
+            *datasets.make_regression(
+                n_samples=500, n_features=10, n_targets=2, random_state=0
+            ),
+            0.55,
+        ),
+    ],
+)
+def test_forest_regressor_oob(
+    ForestRegressor, X, y, X_type, lower_bound_r2
+):
+    """Check that forest-based regressors provide an OOB score close to the
+    score on a test set."""
+    X = _convert_container(X, constructor_name=X_type)
+    X_train, X_test, y_train, y_test = train_test_split(
+        X, y, test_size=0.5, random_state=0,
+    )
+    regressor = ForestRegressor(
+        n_estimators=50, bootstrap=True, oob_score=True, random_state=0,
+    )
+    assert not hasattr(regressor, "oob_score_")
+    assert not hasattr(regressor, "oob_prediction_")
 
-def check_oob_score_raise_error(name):
-    ForestEstimator = FOREST_ESTIMATORS[name]
+    regressor.fit(X_train, y_train)
+    test_score = regressor.score(X_test, y_test)
 
-    if name in FOREST_TRANSFORMERS:
-        for oob_score in [True, False]:
-            assert_raises(TypeError, ForestEstimator, oob_score=oob_score)
+    assert abs(test_score - regressor.oob_score_) <= 0.1
+    assert regressor.oob_score_ >= lower_bound_r2
 
-        assert_raises(NotImplementedError, ForestEstimator()._set_oob_score,
-                      X, y)
+    assert hasattr(regressor, "oob_score_")
+    assert hasattr(regressor, "oob_prediction_")
+    assert not hasattr(regressor, "oob_decision_function_")
+    if y.ndim == 1:
+        expected_shape = (X_train.shape[0],)
     else:
-        # Unfitted / no bootstrap / no oob_score
-        for oob_score, bootstrap in [(True, False), (False, True),
-                                     (False, False)]:
-            est = ForestEstimator(oob_score=oob_score, bootstrap=bootstrap,
-                                  random_state=0)
-            assert not hasattr(est, "oob_score_")
+        expected_shape = (X_train.shape[0], y.ndim)
+    assert regressor.oob_prediction_.shape == expected_shape
 
-        # No bootstrap
-        assert_raises(ValueError, ForestEstimator(oob_score=True,
-                                                  bootstrap=False).fit, X, y)
 
+@pytest.mark.parametrize(
+    "ForestEstimator", FOREST_CLASSIFIERS_REGRESSORS.values()
+)
+def test_forest_oob_warning(ForestEstimator):
+    """Check that a warning is raised when there are not enough estimators
+    and the OOB estimates will be inaccurate."""
+    estimator = ForestEstimator(
+        n_estimators=1, oob_score=True, bootstrap=True, random_state=0,
+    )
+    with pytest.warns(UserWarning, match="Some inputs do not have OOB scores"):
+        estimator.fit(iris.data, iris.target)
 
-@pytest.mark.parametrize('name', FOREST_ESTIMATORS)
-def test_oob_score_raise_error(name):
-    check_oob_score_raise_error(name)
+
+@pytest.mark.parametrize(
+    "ForestEstimator", FOREST_CLASSIFIERS_REGRESSORS.values()
+)
+@pytest.mark.parametrize(
+    "X, y, params, err_msg",
+    [
+        (iris.data, iris.target, {"oob_score": True, "bootstrap": False},
+         "Out of bag estimation only available if bootstrap=True"),
+        (iris.data, rng.randint(low=0, high=5, size=(iris.data.shape[0], 2)),
+         {"oob_score": True, "bootstrap": True},
+         "The type of target cannot be used to compute OOB estimates")
+    ]
+)
+def test_forest_oob_error(ForestEstimator, X, y, params, err_msg):
+    estimator = ForestEstimator(**params)
+    with pytest.raises(ValueError, match=err_msg):
+        estimator.fit(X, y)
+
+
+@pytest.mark.parametrize("oob_score", [True, False])
+def test_random_trees_embedding_raise_error_oob(oob_score):
+    with pytest.raises(TypeError, match="got an unexpected keyword argument"):
+        RandomTreesEmbedding(oob_score=oob_score)
+    with pytest.raises(NotImplementedError, match="OOB score not supported"):
+        RandomTreesEmbedding()._set_oob_score_and_attributes(X, y)
 
 
 def check_gridsearch(name):
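For context, here is a minimal usage sketch (not part of the patch) illustrating the behaviour the diff introduces: multilabel-indicator targets still get an OOB estimate, with ``oob_decision_function_`` gaining a trailing ``n_outputs`` axis, while multiclass-multioutput targets now raise the explicit error added in ``fit``. The dataset sizes and estimator settings below are illustrative assumptions, not values taken from the PR.

.. code-block:: python

    import numpy as np
    from sklearn.datasets import make_multilabel_classification
    from sklearn.ensemble import RandomForestClassifier

    # Multilabel-indicator target: OOB estimates remain supported.
    X, Y = make_multilabel_classification(n_samples=300, random_state=0)
    clf = RandomForestClassifier(n_estimators=40, oob_score=True, random_state=0)
    clf.fit(X, Y)
    # Shape is (n_samples, n_classes_per_output, n_outputs); oob_score_ is the
    # OOB accuracy computed with accuracy_score.
    print(clf.oob_decision_function_.shape, clf.oob_score_)

    # Multiclass-multioutput target: requesting the OOB score now raises a
    # ValueError instead of silently computing a meaningless score.
    rng = np.random.RandomState(0)
    y_mc_mo = rng.randint(0, 5, size=(X.shape[0], 2))
    try:
        RandomForestClassifier(oob_score=True, random_state=0).fit(X, y_mc_mo)
    except ValueError as exc:
        print(exc)  # "The type of target cannot be used to compute OOB estimates..."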