diff --git a/sklearn/ensemble/_forest.py b/sklearn/ensemble/_forest.py
index 5def6ac60816b..3ff5288a4aa1f 100644
--- a/sklearn/ensemble/_forest.py
+++ b/sklearn/ensemble/_forest.py
@@ -601,7 +601,6 @@ def _compute_oob_predictions(self, X, y):
                 n_samples,
                 n_samples_bootstrap,
             )
-
             y_pred = self._get_oob_predictions(estimator, X[unsampled_indices, :])
             oob_pred[unsampled_indices, ...] += y_pred
             n_oob_pred[unsampled_indices, :] += 1
@@ -681,6 +680,87 @@ def feature_importances_(self):
         all_importances = np.mean(all_importances, axis=0, dtype=np.float64)
         return all_importances / np.sum(all_importances)

+    def _compute_unbiased_feature_importance_and_oob_predictions_per_tree(
+        self, tree, X, y, method, n_samples
+    ):
+        n_samples_bootstrap = _get_n_samples_bootstrap(
+            n_samples,
+            self.max_samples,
+        )
+        oob_indices = _generate_unsampled_indices(
+            tree.random_state, n_samples, n_samples_bootstrap
+        )
+        X_test = X[oob_indices]
+        y_test = y[oob_indices]
+
+        oob_pred = np.zeros(
+            (n_samples, self.estimators_[0].tree_.max_n_classes, self.n_outputs_),
+            dtype=np.float64,
+        )
+        n_oob_pred = np.zeros((n_samples, self.n_outputs_), dtype=np.intp)
+
+        importances, y_pred = (
+            tree.compute_unbiased_feature_importance_and_oob_predictions(
+                X_test=X_test,
+                y_test=y_test,
+                method=method,
+            )
+        )
+        oob_pred[oob_indices, :, :] += y_pred
+        n_oob_pred[oob_indices, :] += 1
+        return (importances, oob_pred, n_oob_pred)
+
+    def _compute_unbiased_feature_importance_and_oob_predictions(
+        self, X, y, method="ufi"
+    ):  # method is either "ufi" or "mdi_oob"
+        check_is_fitted(self)
+        X = self._validate_X_predict(X)
+        y = np.asarray(y)
+        if y.ndim == 1:
+            y = y.reshape(-1, 1)
+
+        n_samples, n_features = X.shape
+        max_n_classes = self.estimators_[0].tree_.max_n_classes
+        results = Parallel(
+            n_jobs=self.n_jobs, prefer="threads", return_as="generator_unordered"
+        )(
+            delayed(
+                self._compute_unbiased_feature_importance_and_oob_predictions_per_tree
+            )(tree, X, y, method, n_samples)
+            for tree in self.estimators_
+            if tree.tree_.node_count > 1
+        )
+
+        importances = np.zeros(n_features, dtype=np.float64)
+        oob_pred = np.zeros(
+            (n_samples, max_n_classes, self.n_outputs_), dtype=np.float64
+        )
+        n_oob_pred = np.zeros((n_samples, self.n_outputs_), dtype=np.intp)
+
+        for importances_i, oob_pred_i, n_oob_pred_i in results:
+            oob_pred += oob_pred_i
+            n_oob_pred += n_oob_pred_i
+            importances += importances_i
+
+        importances /= self.n_estimators
+
+        for k in range(self.n_outputs_):
+            if (n_oob_pred == 0).any():
+                warn(
+                    (
+                        "Some inputs do not have OOB scores. This probably means "
+                        "too few trees were used to compute any reliable OOB "
+                        "estimates."
+                    ),
+                    UserWarning,
+                )
+                n_oob_pred[n_oob_pred == 0] = 1
+            oob_pred[..., k] /= n_oob_pred[..., [k]]
+
+        if not importances.any():
+            return np.zeros(self.n_features_in_, dtype=np.float64), oob_pred
+        return importances / importances.sum(), oob_pred
+
     def _get_estimators_indices(self):
         # Get drawn indices along both sample and feature axes
         for tree in self.estimators_:
@@ -814,18 +894,57 @@ def _set_oob_score_and_attributes(self, X, y, scoring_function=None):
         scoring_function : callable, default=None
             Scoring function for OOB score. Defaults to `accuracy_score`.
         """
-        self.oob_decision_function_ = super()._compute_oob_predictions(X, y)
-        if self.oob_decision_function_.shape[-1] == 1:
-            # drop the n_outputs axis if there is a single output
-            self.oob_decision_function_ = self.oob_decision_function_.squeeze(axis=-1)
         if scoring_function is None:
             scoring_function = accuracy_score

+        ufi_feature_importances, self.oob_decision_function_ = (
+            self._compute_unbiased_feature_importance_and_oob_predictions(
+                X, y, method="ufi"
+            )
+        )
+        mdi_oob_feature_importances, _ = (
+            self._compute_unbiased_feature_importance_and_oob_predictions(
+                X, y, method="mdi_oob"
+            )
+        )
+        if self.criterion == "gini":
+            self._ufi_feature_importances_ = ufi_feature_importances
+            self._mdi_oob_feature_importances_ = mdi_oob_feature_importances
+        elif self.criterion in ["log_loss", "entropy"]:
+            self._ufi_feature_importances_ = ufi_feature_importances
+            # mdi_oob does not support entropy yet
+
+        if self.oob_decision_function_.shape[-1] == 1:
+            # drop the n_outputs axis if there is a single output
+            self.oob_decision_function_ = self.oob_decision_function_.squeeze(axis=-1)
+
         self.oob_score_ = scoring_function(
             y, np.argmax(self.oob_decision_function_, axis=1)
         )

+    @property
+    def ufi_feature_importances_(self):
+        check_is_fitted(self)
+        if self.criterion in ["gini", "log_loss", "entropy"]:
+            return self._ufi_feature_importances_
+        else:
+            raise AttributeError(
+                "ufi feature importance only available for"
+                " classification with split criterion 'gini', 'log_loss' or 'entropy'."
+            )
+
+    @property
+    def mdi_oob_feature_importances_(self):
+        check_is_fitted(self)
+        if self.criterion != "gini":
+            raise AttributeError(
+                "mdi_oob feature importance only available for"
+                " classification with split criterion 'gini'"
+            )
+        else:
+            return self._mdi_oob_feature_importances_
+
     def _validate_y_class_weight(self, y):
         check_classification_targets(y)
""" - self.oob_prediction_ = super()._compute_oob_predictions(X, y).squeeze(axis=1) + if scoring_function is None: + scoring_function = r2_score + + ufi_feature_importances, self.oob_prediction_ = ( + self._compute_unbiased_feature_importance_and_oob_predictions( + X, y, method="ufi" + ) + ) + mdi_oob_feature_importances, _ = ( + self._compute_unbiased_feature_importance_and_oob_predictions( + X, y, method="mdi_oob" + ) + ) + if self.criterion == "squared_error": + self._ufi_feature_importances = ufi_feature_importances + self._mdi_oob_feature_importances = mdi_oob_feature_importances + if self.oob_prediction_.shape[-1] == 1: # drop the n_outputs axis if there is a single output self.oob_prediction_ = self.oob_prediction_.squeeze(axis=-1) - if scoring_function is None: - scoring_function = r2_score + # Drop the n_classes axis of size 1 in regression + self.oob_prediction_ = self.oob_prediction_.squeeze(axis=1) self.oob_score_ = scoring_function(y, self.oob_prediction_) + @property + def ufi_feature_importances_(self): + check_is_fitted(self) + if self.criterion != "squared_error": + raise AttributeError( + "Unbiased feature importance only available for" + " regression with split criterion MSE" + ) + else: + return self._ufi_feature_importances + + @property + def mdi_oob_feature_importances_(self): + check_is_fitted(self) + if self.criterion != "squared_error": + raise AttributeError( + "Unbiased feature importance only available for" + " regression with split criterion MSE" + ) + else: + return self._mdi_oob_feature_importances + def _compute_partial_dependence_recursion(self, grid, target_features): """Fast partial dependence computation. diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py index 5dec5c7ab90b2..c60730d5606c0 100644 --- a/sklearn/ensemble/tests/test_forest.py +++ b/sklearn/ensemble/tests/test_forest.py @@ -8,6 +8,7 @@ import itertools import math import pickle +import re from collections import defaultdict from functools import partial from itertools import combinations, product @@ -297,46 +298,97 @@ def test_probability(name): product(FOREST_REGRESSORS, ["squared_error", "friedman_mse", "absolute_error"]), ), ) -def test_importances(dtype, name, criterion): +@pytest.mark.parametrize( + "oob_score, importance_attribute_name", + [ + (False, "feature_importances_"), + (True, "ufi_feature_importances_"), + (True, "mdi_oob_feature_importances_"), + ], +) +def test_importances(dtype, name, criterion, oob_score, importance_attribute_name): tolerance = 0.01 if name in FOREST_REGRESSORS and criterion == "absolute_error": tolerance = 0.05 - # cast as dtype X = X_large.astype(dtype, copy=False) y = y_large.astype(dtype, copy=False) ForestEstimator = FOREST_ESTIMATORS[name] - est = ForestEstimator(n_estimators=10, criterion=criterion, random_state=0) + est = ForestEstimator( + n_estimators=10, + criterion=criterion, + oob_score=oob_score, + bootstrap=True, + random_state=0, + ) est.fit(X, y) - importances = est.feature_importances_ - - # The forest estimator can detect that only the first 3 features of the - # dataset are informative: - n_important = np.sum(importances > 0.1) - assert importances.shape[0] == 10 - assert n_important == 3 - assert np.all(importances[:3] > 0.1) - - # Check with parallel - importances = est.feature_importances_ - est.set_params(n_jobs=2) - importances_parallel = est.feature_importances_ - assert_array_almost_equal(importances, importances_parallel) - - # Check with sample weights - sample_weight = 
diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py
index 5dec5c7ab90b2..c60730d5606c0 100644
--- a/sklearn/ensemble/tests/test_forest.py
+++ b/sklearn/ensemble/tests/test_forest.py
@@ -8,6 +8,7 @@
 import itertools
 import math
 import pickle
+import re
 from collections import defaultdict
 from functools import partial
 from itertools import combinations, product
@@ -297,46 +298,97 @@ def test_probability(name):
         product(FOREST_REGRESSORS, ["squared_error", "friedman_mse", "absolute_error"]),
     ),
 )
-def test_importances(dtype, name, criterion):
+@pytest.mark.parametrize(
+    "oob_score, importance_attribute_name",
+    [
+        (False, "feature_importances_"),
+        (True, "ufi_feature_importances_"),
+        (True, "mdi_oob_feature_importances_"),
+    ],
+)
+def test_importances(dtype, name, criterion, oob_score, importance_attribute_name):
     tolerance = 0.01
     if name in FOREST_REGRESSORS and criterion == "absolute_error":
         tolerance = 0.05

-    # cast as dtype
     X = X_large.astype(dtype, copy=False)
     y = y_large.astype(dtype, copy=False)

     ForestEstimator = FOREST_ESTIMATORS[name]

-    est = ForestEstimator(n_estimators=10, criterion=criterion, random_state=0)
+    est = ForestEstimator(
+        n_estimators=10,
+        criterion=criterion,
+        oob_score=oob_score,
+        bootstrap=True,
+        random_state=0,
+    )
     est.fit(X, y)
-    importances = est.feature_importances_
-
-    # The forest estimator can detect that only the first 3 features of the
-    # dataset are informative:
-    n_important = np.sum(importances > 0.1)
-    assert importances.shape[0] == 10
-    assert n_important == 3
-    assert np.all(importances[:3] > 0.1)
-
-    # Check with parallel
-    importances = est.feature_importances_
-    est.set_params(n_jobs=2)
-    importances_parallel = est.feature_importances_
-    assert_array_almost_equal(importances, importances_parallel)
-
-    # Check with sample weights
-    sample_weight = check_random_state(0).randint(1, 10, len(X))
-    est = ForestEstimator(n_estimators=10, random_state=0, criterion=criterion)
-    est.fit(X, y, sample_weight=sample_weight)
-    importances = est.feature_importances_
-    assert np.all(importances >= 0.0)
-
-    for scale in [0.5, 100]:
-        est = ForestEstimator(n_estimators=10, random_state=0, criterion=criterion)
-        est.fit(X, y, sample_weight=scale * sample_weight)
-        importances_bis = est.feature_importances_
-        assert np.abs(importances - importances_bis).mean() < tolerance
+    if oob_score and name in FOREST_REGRESSORS and criterion != "squared_error":
+        with pytest.raises(
+            AttributeError,
+            match="Unbiased feature importance only available for"
+            " regression with split criterion 'squared_error'",
+        ):
+            getattr(est, importance_attribute_name)
+        return
+    if oob_score and name in FOREST_CLASSIFIERS:
+        if (
+            importance_attribute_name == "mdi_oob_feature_importances_"
+            and criterion != "gini"
+        ):
+            with pytest.raises(
+                AttributeError,
+                match="mdi_oob feature importance only available for"
+                " classification with split criterion 'gini'",
+            ):
+                getattr(est, importance_attribute_name)
+            return
+        if criterion not in ["gini", "log_loss", "entropy"]:
+            with pytest.raises(
+                AttributeError,
+                match="ufi feature importance only available for"
+                " classification with split criterion 'gini', 'log_loss' or"
+                " 'entropy'.",
+            ):
+                getattr(est, importance_attribute_name)
+            return
+
+    importances = getattr(est, importance_attribute_name)
+    # The forest estimator can detect that only the first 3 features of the
+    # dataset are informative:
+    n_important = np.sum(importances > 0.1)
+    assert importances.shape[0] == 10
+    assert n_important == 3
+    assert np.all(importances[:3] > 0.1)
+
+    # Check with parallel
+    importances = getattr(est, importance_attribute_name)
+    est.set_params(n_jobs=2)
+    importances_parallel = getattr(est, importance_attribute_name)
+    assert_array_almost_equal(importances, importances_parallel)
+
+    # Check with sample weights
+    sample_weight = check_random_state(0).randint(1, 10, len(X))
+    est = ForestEstimator(
+        n_estimators=10,
+        random_state=0,
+        oob_score=oob_score,
+        bootstrap=True,
+        criterion=criterion,
+    )
+    est.fit(X, y, sample_weight=sample_weight)
+    importances = getattr(est, importance_attribute_name)
+    if importance_attribute_name == "feature_importances_":
+        assert np.all(importances >= 0.0)
+
+    for scale in [0.5, 100]:
+        est = ForestEstimator(
+            n_estimators=10,
+            random_state=0,
+            oob_score=oob_score,
+            bootstrap=True,
+            criterion=criterion,
+        )
+        est.fit(X, y, sample_weight=scale * sample_weight)
+        importances_bis = getattr(est, importance_attribute_name)
+        assert np.abs(importances - importances_bis).mean() < tolerance


 def test_importances_asymptotic():
@@ -458,6 +510,22 @@ def test_unfitted_feature_importances(name):
         getattr(FOREST_ESTIMATORS[name](), "feature_importances_")


+@pytest.mark.parametrize("name", FOREST_ESTIMATORS)
+@pytest.mark.parametrize(
+    "unbiased_importance_attribute_name",
+    [
+        "ufi_feature_importances_",
+        "mdi_oob_feature_importances_",
+    ],
+)
+def test_non_OOB_unbiased_feature_importances(name, unbiased_importance_attribute_name):
+    clf = FOREST_ESTIMATORS[name]().fit(X_large, y_large)
+    assert not hasattr(clf, unbiased_importance_attribute_name)
+    assert not hasattr(clf, "oob_score_")
+    assert not hasattr(clf, "oob_decision_function_")
+
+
+# TODO before merge: implement unbiased importance for sparse data
 @pytest.mark.parametrize("ForestClassifier", FOREST_CLASSIFIERS.values())
 @pytest.mark.parametrize("X_type", ["array", "sparse_csr", "sparse_csc"])
 @pytest.mark.parametrize(
@@ -488,6 +556,8 @@ def test_unfitted_feature_importances(name):
 def test_forest_classifier_oob(
     ForestClassifier, X, y, X_type, lower_bound_accuracy, oob_score
 ):
     """Check that OOB score is close to score on a test set."""
+    if X_type != "array":
+        pytest.skip("unbiased feature importance is not implemented for sparse X")
     X = _convert_container(X, constructor_name=X_type)
     X_train, X_test, y_train, y_test = train_test_split(
@@ -550,6 +620,8 @@
 def test_forest_regressor_oob(ForestRegressor, X, y, X_type, lower_bound_r2, oob_score):
     """Check that forest-based regressor provide an OOB score close to the
     score on a test set."""
+    if X_type != "array":
+        pytest.skip("unbiased feature importance is not implemented for sparse X")
     X = _convert_container(X, constructor_name=X_type)
     X_train, X_test, y_train, y_test = train_test_split(
         X,
@@ -1161,21 +1233,36 @@ def test_1d_input(name):
 @pytest.mark.parametrize("name", FOREST_CLASSIFIERS)
-def test_class_weights(name):
+@pytest.mark.parametrize(
+    "oob_score, importance_attribute_name",
+    [
+        (False, "feature_importances_"),
+        (True, "ufi_feature_importances_"),
+        (True, "mdi_oob_feature_importances_"),
+    ],
+)
+def test_class_weights(name, oob_score, importance_attribute_name):
     # Check class_weights resemble sample_weights behavior.
     ForestClassifier = FOREST_CLASSIFIERS[name]

     # Iris is balanced, so no effect expected for using 'balanced' weights
-    clf1 = ForestClassifier(random_state=0)
+    clf1 = ForestClassifier(bootstrap=True, oob_score=oob_score, random_state=0)
     clf1.fit(iris.data, iris.target)
-    clf2 = ForestClassifier(class_weight="balanced", random_state=0)
+    clf2 = ForestClassifier(
+        bootstrap=True, oob_score=oob_score, class_weight="balanced", random_state=0
+    )
     clf2.fit(iris.data, iris.target)
-    assert_almost_equal(clf1.feature_importances_, clf2.feature_importances_)
+    assert_almost_equal(
+        getattr(clf1, importance_attribute_name),
+        getattr(clf2, importance_attribute_name),
+    )

     # Make a multi-output problem with three copies of Iris
     iris_multi = np.vstack((iris.target, iris.target, iris.target)).T
     # Create user-defined weights that should balance over the outputs
     clf3 = ForestClassifier(
+        bootstrap=True,
+        oob_score=False,
         class_weight=[
             {0: 2.0, 1: 2.0, 2: 1.0},
             {0: 2.0, 1: 1.0, 2: 2.0},
@@ -1184,9 +1271,14 @@ def test_class_weights(name):
         random_state=0,
     )
     clf3.fit(iris.data, iris_multi)
+    # We can't use oob_score=True on multiclass-multioutput,
+    # so we use the regular feature_importances_
     assert_almost_equal(clf2.feature_importances_, clf3.feature_importances_)
+
     # Check against multi-output "balanced" which should also have no effect
-    clf4 = ForestClassifier(class_weight="balanced", random_state=0)
+    clf4 = ForestClassifier(
+        bootstrap=True, oob_score=False, class_weight="balanced", random_state=0
+    )
     clf4.fit(iris.data, iris_multi)
     assert_almost_equal(clf3.feature_importances_, clf4.feature_importances_)
@@ -1194,18 +1286,28 @@ def test_class_weights(name):
     sample_weight = np.ones(iris.target.shape)
     sample_weight[iris.target == 1] *= 100
     class_weight = {0: 1.0, 1: 100.0, 2: 1.0}
-    clf1 = ForestClassifier(random_state=0)
+    clf1 = ForestClassifier(bootstrap=True, oob_score=oob_score, random_state=0)
     clf1.fit(iris.data, iris.target, sample_weight)
-    clf2 = ForestClassifier(class_weight=class_weight, random_state=0)
+    clf2 = ForestClassifier(
+        bootstrap=True, oob_score=oob_score, class_weight=class_weight, random_state=0
+    )
     clf2.fit(iris.data, iris.target)
-    assert_almost_equal(clf1.feature_importances_, clf2.feature_importances_)
+    assert_almost_equal(
+        getattr(clf1, importance_attribute_name),
+        getattr(clf2, importance_attribute_name),
+    )

     # Check that sample_weight and class_weight are multiplicative
-    clf1 = ForestClassifier(random_state=0)
+    clf1 = ForestClassifier(bootstrap=True, oob_score=oob_score, random_state=0)
     clf1.fit(iris.data, iris.target, sample_weight**2)
-    clf2 = ForestClassifier(class_weight=class_weight, random_state=0)
+    clf2 = ForestClassifier(
+        bootstrap=True, oob_score=oob_score, class_weight=class_weight, random_state=0
+    )
     clf2.fit(iris.data, iris.target, sample_weight)
-    assert_almost_equal(clf1.feature_importances_, clf2.feature_importances_)
+    assert_almost_equal(
+        getattr(clf1, importance_attribute_name),
+        getattr(clf2, importance_attribute_name),
+    )


 @pytest.mark.parametrize("name", FOREST_CLASSIFIERS)
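The two tests below check the normalization contract: `_compute_unbiased_feature_importance_and_oob_predictions` divides by `importances.sum()` unless every importance is zero (degenerate single-node trees), so the new attributes sum to one exactly like `feature_importances_`. The OOB split they rest on can be sketched in plain NumPy; this mirrors what `_generate_unsampled_indices` derives from each tree's `random_state` (illustration only, RNG details are mine):

    import numpy as np

    rng = np.random.RandomState(0)
    n_samples = 10
    # bootstrap draw used to fit one tree: sampled with replacement
    in_bag = rng.randint(0, n_samples, n_samples)
    # OOB samples are the complement; ~36.8% of rows on average
    oob = np.setdiff1d(np.arange(n_samples), in_bag)
    print(in_bag, oob)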
@@ -1518,6 +1620,25 @@ def test_forest_feature_importances_sum():
     assert math.isclose(1, clf.feature_importances_.sum(), abs_tol=1e-7)


+@pytest.mark.parametrize(
+    "unbiased_importance_attribute_name",
+    [
+        "ufi_feature_importances_",
+        "mdi_oob_feature_importances_",
+    ],
+)
+def test_forest_unbiased_feature_importances_sum(unbiased_importance_attribute_name):
+    X, y = make_classification(
+        n_samples=15, n_informative=3, random_state=1, n_classes=3
+    )
+    clf = RandomForestClassifier(
+        min_samples_leaf=5, random_state=42, n_estimators=200, oob_score=True
+    ).fit(X, y)
+    assert math.isclose(
+        1, getattr(clf, unbiased_importance_attribute_name).sum(), abs_tol=1e-7
+    )
+
+
 def test_forest_degenerate_feature_importances():
     # build a forest of single node trees. See #13636
     X = np.zeros((10, 10))
@@ -1526,6 +1647,291 @@ def test_forest_degenerate_feature_importances():
     assert_array_equal(gbr.feature_importances_, np.zeros(10, dtype=np.float64))


+@pytest.mark.parametrize(
+    "unbiased_importance_attribute_name",
+    [
+        "ufi_feature_importances_",
+        "mdi_oob_feature_importances_",
+    ],
+)
+def test_forest_degenerate_unbiased_feature_importances(
+    unbiased_importance_attribute_name,
+):
+    # build a forest of single node trees. See #13636
+    X = np.zeros((10, 10))
+    y = np.ones((10,))
+    with pytest.warns(
+        UserWarning,
+        match=re.escape(
+            "Some inputs do not have OOB scores. This probably means too few trees were"
+            " used to compute any reliable OOB estimates."
+        ),
+    ):
+        clf = RandomForestClassifier(n_estimators=10, oob_score=True).fit(X, y)
+    assert_array_equal(
+        getattr(clf, unbiased_importance_attribute_name), np.zeros(10, dtype=np.float64)
+    )
+
+
+@pytest.mark.parametrize("name", FOREST_CLASSIFIERS)
+@pytest.mark.parametrize(
+    "criterion, method", [("gini", "ufi"), ("gini", "mdi_oob"), ("log_loss", "ufi")]
+)
+def test_unbiased_feature_importance_on_train(
+    name, criterion, method, global_random_seed
+):
+    from sklearn.ensemble._forest import _generate_sample_indices
+
+    n_samples = 15
+    X, y = make_classification(
+        n_samples=n_samples,
+        n_informative=3,
+        random_state=global_random_seed,
+        n_classes=2,
+    )
+    clf = FOREST_ESTIMATORS[name](
+        n_estimators=1,
+        bootstrap=True,
+        random_state=global_random_seed,
+        criterion=criterion,
+    )
+    clf.fit(X, y)
+    method_on_train = 0
+    for tree_idx, tree in enumerate(clf.estimators_):
+        in_bag_indices = _generate_sample_indices(
+            clf.estimators_[tree_idx].random_state, n_samples, n_samples
+        )
+        X_in_bag = clf._validate_X_predict(X)[in_bag_indices]
+        y_in_bag = y.reshape(-1, 1)[in_bag_indices]
+        method_on_train_tree = (
+            tree.compute_unbiased_feature_importance_and_oob_predictions(
+                X_in_bag, y_in_bag, method
+            )[0]
+        )
+        method_on_train += method_on_train_tree / method_on_train_tree.sum()
+    method_on_train /= clf.n_estimators
+    method_on_train /= method_on_train.sum()
+    assert_allclose(clf.feature_importances_, method_on_train, rtol=0, atol=1e-12)
+
+
+@pytest.mark.parametrize("name", FOREST_CLASSIFIERS_REGRESSORS)
+def test_ufi_match_paper(name):
+    def paper_ufi(clf, X, y, is_classification):
+        """
+        Code from: Unbiased Measurement of Feature Importance in Tree-Based Methods
+        https://arxiv.org/pdf/1903.05179
+        https://github.com/ZhengzeZhou/unbiased-feature-importance/blob/master/UFI.py
+        """
+        from sklearn.ensemble._forest import _generate_sample_indices
+
+        feature_importance = np.array([0.0] * X.shape[1])
+        n_estimators = clf.n_estimators
+
+        n_samples = X.shape[0]
+        inbag_counts = np.zeros((n_samples, clf.n_estimators))
+        for tree_idx, tree in enumerate(clf.estimators_):
+            sample_idx = _generate_sample_indices(
+                tree.random_state, n_samples, n_samples
+            )
+            inbag_counts[:, tree_idx] = np.bincount(sample_idx, minlength=n_samples)
+
+        for tree_idx, tree in enumerate(clf.estimators_):
+            fi_tree = np.array([0.0] * X.shape[1])
+
+            n_nodes = tree.tree_.node_count
+
+            tree_X_inb = X.repeat((inbag_counts[:, tree_idx]).astype("int"), axis=0)
+            tree_y_inb = y.repeat((inbag_counts[:, tree_idx]).astype("int"), axis=0)
+            decision_path_inb = tree.decision_path(tree_X_inb).todense()
+
+            tree_X_oob = X[inbag_counts[:, tree_idx] == 0]
+            tree_y_oob = y[inbag_counts[:, tree_idx] == 0]
+            decision_path_oob = tree.decision_path(tree_X_oob).todense()
+
+            impurity = [0] * n_nodes
+
+            has_oob_samples_in_children = [True] * n_nodes
+
+            weighted_n_node_samples = (
+                np.array(np.sum(decision_path_inb, axis=0))[0] / tree_X_inb.shape[0]
+            )
+
+            for node_idx in range(n_nodes):
+                y_innode_oob = tree_y_oob[
+                    np.array(decision_path_oob[:, node_idx])
+                    .ravel()
+                    .nonzero()[0]
+                    .tolist()
+                ]
+                y_innode_inb = tree_y_inb[
+                    np.array(decision_path_inb[:, node_idx])
+                    .ravel()
+                    .nonzero()[0]
+                    .tolist()
+                ]
+
+                if len(y_innode_oob) == 0:
+                    if sum(tree.tree_.children_left == node_idx) > 0:
+                        parent_node = np.arange(n_nodes)[
+                            tree.tree_.children_left == node_idx
+                        ][0]
+                        has_oob_samples_in_children[parent_node] = False
+                    else:
+                        parent_node = np.arange(n_nodes)[
+                            tree.tree_.children_right == node_idx
+                        ][0]
+                        has_oob_samples_in_children[parent_node] = False
+
+                else:
+                    p_node_oob = float(sum(y_innode_oob)) / len(y_innode_oob)
+                    p_node_inb = float(sum(y_innode_inb)) / len(y_innode_inb)
+                    if is_classification:
+                        impurity[node_idx] = (
+                            1
+                            - p_node_oob * p_node_inb
+                            - (1 - p_node_oob) * (1 - p_node_inb)
+                        )
+                    else:
+                        impurity[node_idx] = np.sum(
+                            (y_innode_oob - np.mean(y_innode_inb)) ** 2
+                        ) / len(y_innode_oob)
+            for node_idx in range(n_nodes):
+                if (
+                    tree.tree_.children_left[node_idx] == -1
+                    or tree.tree_.children_right[node_idx] == -1
+                ):
+                    continue
+
+                feature_idx = tree.tree_.feature[node_idx]
+
+                node_left = tree.tree_.children_left[node_idx]
+                node_right = tree.tree_.children_right[node_idx]
+
+                if has_oob_samples_in_children[node_idx]:
+                    if is_classification:
+                        fi_tree[feature_idx] += (
+                            weighted_n_node_samples[node_idx] * impurity[node_idx]
+                            - weighted_n_node_samples[node_left] * impurity[node_left]
+                            - weighted_n_node_samples[node_right] * impurity[node_right]
+                        )
+                    else:
+                        impurity_train = tree.tree_.impurity
+                        fi_tree[feature_idx] += (
+                            weighted_n_node_samples[node_idx]
+                            * (impurity[node_idx] + impurity_train[node_idx])
+                            - weighted_n_node_samples[node_left]
+                            * (impurity[node_left] + impurity_train[node_left])
+                            - weighted_n_node_samples[node_right]
+                            * (impurity_train[node_right] + impurity[node_right])
+                        )
+            feature_importance += fi_tree
+        feature_importance /= n_estimators
+        return feature_importance / feature_importance.sum()
+
+    X, y = make_classification(
+        n_samples=15, n_informative=3, random_state=1, n_classes=2
+    )
+    is_classification = name in FOREST_CLASSIFIERS
+    est = FOREST_CLASSIFIERS_REGRESSORS[name](
+        n_estimators=10, oob_score=True, bootstrap=True, random_state=1
+    )
+    est.fit(X, y)
+    assert_almost_equal(
+        est.ufi_feature_importances_, paper_ufi(est, X, y, is_classification)
+    )
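For review, the quantity `paper_ufi` checks per node: the in-bag Gini impurity 1 - sum_c p_c^2 is replaced by the cross impurity 1 - sum_c p_c^inb * p_c^oob, so splits on noise features (where in-bag and OOB proportions disagree) contribute roughly zero in expectation. A small NumPy sketch of one weighted decrease, with toy numbers of my own:

    import numpy as np

    def gini_cross(p_inb, p_oob):
        # cross impurity between in-bag and OOB class proportions
        return 1.0 - np.sum(p_inb * p_oob)

    # weighted_n_node_samples / n_samples for parent and children
    w_parent, w_left, w_right = 1.0, 0.6, 0.4
    decrease = (
        w_parent * gini_cross(np.array([0.5, 0.5]), np.array([0.5, 0.5]))
        - w_left * gini_cross(np.array([0.9, 0.1]), np.array([0.85, 0.15]))
        - w_right * gini_cross(np.array([0.1, 0.9]), np.array([0.2, 0.8]))
    )
    print(decrease)  # ~0.264: positive for a genuine split, ~0 for a noise split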
+
+
+@pytest.mark.parametrize("name", FOREST_CLASSIFIERS_REGRESSORS)
+def test_mdi_oob_match_paper(name):
+    def paper_mdi_oob(clf, X, y, is_classification):
+        """
+        Code from: A Debiased MDI Feature Importance Measure for Random Forests
+        https://arxiv.org/pdf/1906.10845
+        https://github.com/shifwang/paper-debiased-feature-importance/blob/9c3e1eed860478ef02111ebda4a39255b4d4be74/simulations/02_comparison.ipynb
+        """
+        import copy
+
+        from sklearn.ensemble._forest import _generate_unsampled_indices
+        from sklearn.preprocessing import OneHotEncoder
+
+        n_samples, n_features = X.shape
+
+        # change X to np.float32
+        XX = X.copy().astype(np.float32)
+
+        # infer the shape of y
+        if len(y.shape) == 1:
+            yy = y[:, np.newaxis]
+        if is_classification:
+            yy = OneHotEncoder().fit_transform(yy)
+
+        out = np.zeros((n_features,))
+        for tree in clf.estimators_:
+            indices = _generate_unsampled_indices(
+                tree.random_state, n_samples, n_samples
+            )
+
+            decision_paths = np.array(
+                tree.tree_.decision_path(XX[indices, :]).todense()
+            )
+
+            # compute the impurity at each node
+            node_mean = (
+                decision_paths / (np.sum(decision_paths, 0)[np.newaxis, :])
+            ).T @ yy[indices, :]
+            tmp = copy.deepcopy(tree.tree_.value.squeeze(axis=1))
+
+            if is_classification:
+                node_previous_mean = tmp / np.sum(tmp, 1)[:, np.newaxis]
+            else:
+                node_previous_mean = tmp
+
+            # compute the impurity decrease at each node
+            node_sample_size = tree.tree_.weighted_n_node_samples
+            lc = copy.deepcopy(tree.tree_.children_left)
+            rc = copy.deepcopy(tree.tree_.children_right)
+            tmp = lc == -1
+            lc[tmp] = 0
+            rc[tmp] = 0
+            # decrease = node_impurity * node_sample_size - node_impurity[lc] *
+            # node_sample_size[lc] - node_impurity[rc] * node_sample_size[rc]
+            decrease = (
+                np.sum(
+                    (node_mean - node_mean[lc])
+                    * (node_previous_mean - node_previous_mean[lc]),
+                    1,
+                )
+                * node_sample_size[lc]
+                + np.sum(
+                    (node_mean - node_mean[rc])
+                    * (node_previous_mean - node_previous_mean[rc]),
+                    1,
+                )
+                * node_sample_size[rc]
+            )
+            feature = tree.tree_.feature
+            decrease[feature == -2] = np.nan
+            decrease[np.sum(decision_paths, 0) < 2] = np.nan
+            tmp = np.logical_not(np.isnan(decrease))
+            for i in range(len(tmp)):
+                if tmp[i]:
+                    out[feature[i]] += decrease[i]
+        out /= clf.n_estimators
+        return out / out.sum()
+
+    X, y = make_classification(
+        n_samples=15, n_informative=3, random_state=1, n_classes=2
+    )
+    is_classification = name in FOREST_CLASSIFIERS
+    est = FOREST_CLASSIFIERS_REGRESSORS[name](
+        n_estimators=10, oob_score=True, bootstrap=True, random_state=1
+    )
+    est.fit(X, y)
+    assert_almost_equal(
+        est.mdi_oob_feature_importances_, paper_mdi_oob(est, X, y, is_classification)
+    )
+
+
 @pytest.mark.parametrize("name", FOREST_CLASSIFIERS_REGRESSORS)
 def test_max_samples_bootstrap(name):
     # Check invalid `max_samples` values
diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py
index ec814f088d1d9..8327d852cfc05 100644
--- a/sklearn/tree/_classes.py
+++ b/sklearn/tree/_classes.py
@@ -690,6 +690,14 @@ def feature_importances_(self):

         return self.tree_.compute_feature_importances()

+    def compute_unbiased_feature_importance_and_oob_predictions(
+        self, X_test, y_test, method="ufi"
+    ):
+        check_is_fitted(self)
+        return self.tree_.compute_unbiased_feature_importance_and_oob_predictions(
+            X_test, y_test, self.criterion, method=method
+        )
+
     def __sklearn_tags__(self):
         tags = super().__sklearn_tags__()
         tags.input_tags.sparse = True
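The per-tree entry point added above can be exercised directly; this is how both the forest and the tests call it (a sketch assuming the branch is built; `_generate_unsampled_indices` and `_validate_X_predict` are the private helpers already used in the patch):

    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.ensemble._forest import _generate_unsampled_indices

    X, y = make_classification(n_samples=100, random_state=0)
    clf = RandomForestClassifier(n_estimators=5, bootstrap=True, random_state=0)
    clf.fit(X, y)
    tree = clf.estimators_[0]
    oob = _generate_unsampled_indices(tree.random_state, len(X), len(X))
    X_oob = clf._validate_X_predict(X)[oob]   # validated float32 array
    y_oob = y.reshape(-1, 1)[oob]             # 2D: one column per output
    imp, oob_pred = tree.compute_unbiased_feature_importance_and_oob_predictions(
        X_oob, y_oob, method="ufi"
    )
    print(imp.shape, oob_pred.shape)  # (n_features,), (n_oob, max_n_classes, n_outputs)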
diff --git a/sklearn/tree/_criterion.pxd b/sklearn/tree/_criterion.pxd
index 84d2e800d6a87..afe428b707ec0 100644
--- a/sklearn/tree/_criterion.pxd
+++ b/sklearn/tree/_criterion.pxd
@@ -43,6 +43,15 @@ cdef class Criterion:
         intp_t start,
         intp_t end
     ) except -1 nogil
+    cdef int init_oob(
+        self,
+        const float64_t[:, ::1] y,
+        const float64_t[:] sample_weight,
+        float64_t weighted_n_samples,
+        const intp_t[:] sample_indices,
+        intp_t start,
+        intp_t end
+    ) except -1 nogil
     cdef void init_sum_missing(self)
     cdef void init_missing(self, intp_t n_missing) noexcept nogil
     cdef int reset(self) except -1 nogil
@@ -98,6 +107,13 @@ cdef class ClassificationCriterion(Criterion):
     cdef float64_t[:, ::1] sum_right    # Same as above, but for the right side of the split
     cdef float64_t[:, ::1] sum_missing  # Same as above, but for missing values in X

+    # out-of-bag statistics used when computing the cross criterion
+    cdef float64_t[:, ::1] sum_total_oob
+    cdef float64_t[:, ::1] sum_left_oob
+    cdef float64_t[:, ::1] sum_right_oob
+    cdef float64_t[:, ::1] sum_missing_oob
+
+
 cdef class RegressionCriterion(Criterion):
     """Abstract regression criterion."""
diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx
index 9f3db83399569..e5ff959111369 100644
--- a/sklearn/tree/_criterion.pyx
+++ b/sklearn/tree/_criterion.pyx
@@ -62,7 +62,39 @@ cdef class Criterion:
         """
         pass

+    cdef int init_oob(
+        self,
+        const float64_t[:, ::1] y,
+        const float64_t[:] sample_weight,
+        float64_t weighted_n_samples,
+        const intp_t[:] sample_indices,
+        intp_t start,
+        intp_t end,
+    ) except -1 nogil:
+        """Placeholder for a method which will initialize the criterion on
+        out-of-bag samples.
+
+        Returns -1 in case of failure to allocate memory (and raise MemoryError)
+        or 0 otherwise.
+
+        Parameters
+        ----------
+        y : ndarray, dtype=float64_t
+            y is a buffer that can store values for n_outputs target variables
+            stored as a Cython memoryview.
+        sample_weight : ndarray, dtype=float64_t
+            The weight of each sample stored as a Cython memoryview.
+        weighted_n_samples : float64_t
+            The total weight of the samples being considered.
+        sample_indices : ndarray, dtype=intp_t
+            A mask on the samples. Indices of the samples in X and y we want to use,
+            where sample_indices[start:end] correspond to the samples in this node.
+        start : intp_t
+            The first sample to be used on this node.
+        end : intp_t
+            The last sample used on this node.
+        """
+        pass
+
     cdef void init_missing(self, intp_t n_missing) noexcept nogil:
         """Initialize sum_missing if there are missing values.
@@ -406,6 +438,52 @@ cdef class ClassificationCriterion(Criterion):
         self.reset()
         return 0

+    cdef int init_oob(
+        self,
+        const float64_t[:, ::1] y,
+        const float64_t[:] sample_weight,
+        float64_t weighted_n_samples,
+        const intp_t[:] sample_indices,
+        intp_t start,
+        intp_t end
+    ) except -1 nogil:
+        self.y = y
+        self.sample_weight = sample_weight
+        self.sample_indices = sample_indices
+        self.start = start
+        self.end = end
+        self.n_node_samples = end - start
+        self.weighted_n_samples = weighted_n_samples
+        self.weighted_n_node_samples = 0.0
+
+        cdef intp_t i
+        cdef intp_t p
+        cdef intp_t k
+        cdef intp_t c
+        cdef float64_t w = 1.0
+
+        for k in range(self.n_outputs):
+            memset(&self.sum_total_oob[k, 0], 0, self.n_classes[k] * sizeof(float64_t))
+
+        for p in range(start, end):
+            i = sample_indices[p]
+
+            # w is originally set to be 1.0, meaning that if no sample weights
+            # are given, the default weight of each sample is 1.0.
+            if sample_weight is not None:
+                w = sample_weight[i]
+
+            # Count weighted class frequency for each target
+            for k in range(self.n_outputs):
+                c = <intp_t> self.y[i, k]
+                self.sum_total_oob[k, c] += w
+
+            self.weighted_n_node_samples += w
+
+        # Reset to pos=start
+        self.reset()
+        return 0
+
     cdef void init_sum_missing(self):
         """Init sum_missing to hold sums for missing values."""
         self.sum_missing = np.zeros((self.n_outputs, self.max_n_classes), dtype=np.float64)
diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd
index 2cadca4564a87..a4f0a037f54c0 100644
--- a/sklearn/tree/_tree.pxd
+++ b/sklearn/tree/_tree.pxd
@@ -78,6 +78,10 @@ cdef class Tree:
     cpdef compute_node_depths(self)
     cpdef compute_feature_importances(self, normalize=*)
+    cdef void _compute_oob_node_values_and_predictions(self, object X_test, intp_t[:, ::1] y_test, float64_t[:, :, ::1] oob_pred, int32_t[::1] has_oob_sample, float64_t[:, :, ::1] oob_node_values, str method)
+    cpdef compute_unbiased_feature_importance_and_oob_predictions(self, object X_test, object y_test, criterion, method=*)
+    cdef float64_t mdi_oob_impurity_decrease(self, float64_t[:, :, ::1] oob_node_values, int node_idx, int left_idx, int right_idx, Node node)
+    cdef float64_t ufi_impurity_decrease(self, float64_t[:, :, ::1] oob_node_values, int node_idx, int left_idx, int right_idx, Node node, str criterion)

 # =============================================================================
 # Tree builder
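Before the Cython walk below, a NumPy equivalent of what `_compute_oob_node_values_and_predictions` stores for classification may help review: per-node OOB class proportions, obtained by routing the OOB samples through the fitted tree (sketch via the public `decision_path` API; assumptions mine):

    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.tree import DecisionTreeClassifier

    X, y = make_classification(n_samples=100, random_state=0)
    est = DecisionTreeClassifier(random_state=0).fit(X[:70], y[:70])
    X_oob, y_oob = X[70:], y[70:]

    paths = est.decision_path(X_oob).toarray()    # (n_oob, node_count) indicator
    counts = paths.sum(axis=0)                    # OOB samples reaching each node
    onehot = np.eye(est.n_classes_)[y_oob]        # (n_oob, n_classes)
    with np.errstate(invalid="ignore"):
        oob_node_values = paths.T @ onehot / counts[:, None]
    # nodes with counts == 0 come out NaN here; the Cython code leaves them at
    # zero and excludes them via has_oob_sample before computing any decrease
    print(oob_node_values[0])                     # root row = overall OOB proportions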
diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx
index 9d0b2854c3ba0..2daecc3337b6f 100644
--- a/sklearn/tree/_tree.pyx
+++ b/sklearn/tree/_tree.pyx
@@ -8,6 +8,7 @@ from libc.string cimport memcpy
 from libc.string cimport memset
 from libc.stdint cimport INTPTR_MAX
 from libc.math cimport isnan
+from libc.math cimport log
 from libcpp.vector cimport vector
 from libcpp.algorithm cimport pop_heap
 from libcpp.algorithm cimport push_heap
@@ -1274,6 +1275,203 @@ cdef class Tree:

         return np.asarray(importances)

+    cdef void _compute_oob_node_values_and_predictions(self, object X_test, intp_t[:, ::1] y_test, float64_t[:, :, ::1] oob_pred, int32_t[::1] has_oob_sample, float64_t[:, :, ::1] oob_node_values, str method):
+        if issparse(X_test):
+            raise NotImplementedError("does not support sparse X yet")
+        if not isinstance(X_test, np.ndarray):
+            raise ValueError("X should be in np.ndarray format, got %s" % type(X_test))
+        if X_test.dtype != DTYPE:
+            raise ValueError("X.dtype should be np.float32, got %s" % X_test.dtype)
+        cdef const float32_t[:, :] X_ndarray = X_test
+
+        cdef intp_t n_samples = X_test.shape[0]
+        cdef intp_t* n_classes = self.n_classes
+        cdef intp_t node_count = self.node_count
+        cdef intp_t n_outputs = self.n_outputs
+        cdef intp_t max_n_classes = self.max_n_classes
+        cdef int k, c, node_idx, sample_idx = 0
+        cdef int32_t[:, ::1] count_oob_values = np.zeros((node_count, n_outputs), dtype=np.int32)
+        cdef int node_value_idx = -1
+        # Python string comparison needs the GIL, so hoist it out of the nogil block
+        cdef bint is_ufi = (method == "ufi")
+
+        cdef Node* node
+
+        cdef int32_t[::1] y_leafs = np.zeros(n_samples, dtype=np.int32)
+
+        with nogil:
+            # pass the oob samples in the tree and count them per node
+            for sample_idx in range(n_samples):
+                # root node
+                node = self.nodes
+                node_idx = 0
+                has_oob_sample[node_idx] = 1
+                for k in range(n_outputs):
+                    if n_classes[k] > 1:
+                        for c in range(n_classes[k]):
+                            if y_test[k, sample_idx] == c:
+                                oob_node_values[node_idx, c, k] += 1.0
+                                # TODO use sample weight instead of 1
+                        count_oob_values[node_idx, k] += 1
+                    else:
+                        if is_ufi:
+                            node_value_idx = node_idx * self.value_stride + k * max_n_classes
+                            oob_node_values[node_idx, 0, k] += (y_test[k, sample_idx] - self.value[node_value_idx]) ** 2.0
+                        else:
+                            oob_node_values[node_idx, 0, k] += y_test[k, sample_idx]
+                        count_oob_values[node_idx, k] += 1
+                        # TODO use sample weight instead of 1
+                # child nodes
+                while node.left_child != _TREE_LEAF and node.right_child != _TREE_LEAF:
+                    if X_ndarray[sample_idx, node.feature] <= node.threshold:
+                        node_idx = node.left_child
+                    else:
+                        node_idx = node.right_child
+                    has_oob_sample[node_idx] = 1
+                    node = &self.nodes[node_idx]
+                    for k in range(n_outputs):
+                        if n_classes[k] > 1:
+                            for c in range(n_classes[k]):
+                                if y_test[k, sample_idx] == c:
+                                    oob_node_values[node_idx, c, k] += 1.0
+                                    # TODO use sample weight instead of 1
+                            count_oob_values[node_idx, k] += 1
+                        else:
+                            if is_ufi:
+                                node_value_idx = node_idx * self.value_stride + k * max_n_classes
+                                oob_node_values[node_idx, 0, k] += (y_test[k, sample_idx] - self.value[node_value_idx]) ** 2.0
+                            else:
+                                oob_node_values[node_idx, 0, k] += y_test[k, sample_idx]
+                            count_oob_values[node_idx, k] += 1
+                            # TODO use sample weight instead of 1
+                # store the id of the leaf where each sample ends up
+                y_leafs[sample_idx] = node_idx
+
+            # convert the counts to proportions
+            for node_idx in range(node_count):
+                for k in range(n_outputs):
+                    if count_oob_values[node_idx, k] > 0:
+                        for c in range(n_classes[k]):
+                            oob_node_values[node_idx, c, k] /= count_oob_values[node_idx, k]
+                # if leaf store the predictive proba
+                if self.nodes[node_idx].left_child == _TREE_LEAF and self.nodes[node_idx].right_child == _TREE_LEAF:
+                    for sample_idx in range(n_samples):
+                        if y_leafs[sample_idx] == node_idx:
+                            for k in range(n_outputs):
+                                for c in range(n_classes[k]):
+                                    node_value_idx = node_idx * self.value_stride + k * max_n_classes + c
+                                    oob_pred[sample_idx, c, k] = self.value[node_value_idx]
+
+    cpdef compute_unbiased_feature_importance_and_oob_predictions(self, object X_test, object y_test, criterion, method="ufi"):
+        cdef intp_t n_samples = X_test.shape[0]
+        cdef intp_t n_features = X_test.shape[1]
+        cdef intp_t n_outputs = self.n_outputs
+        cdef intp_t max_n_classes = self.max_n_classes
+        cdef intp_t node_count = self.node_count
+
+        cdef int32_t[::1] has_oob_sample = np.zeros(node_count, dtype=np.int32)
+        cdef float64_t[::1] importances = np.zeros((n_features,), dtype=np.float64)
+        cdef float64_t[:, :, ::1] oob_pred = np.zeros((n_samples, max_n_classes, n_outputs), dtype=np.float64)
+        cdef float64_t[:, :, ::1] oob_node_values = np.zeros((node_count, max_n_classes, n_outputs), dtype=np.float64)
+
+        cdef Node* nodes = self.nodes
+        cdef Node node = nodes[0]
+        cdef int node_idx = 0
+        cdef int left_idx, right_idx = -1
+
+        # transpose y_test so that it is indexed as [output, sample], the layout
+        # expected by _compute_oob_node_values_and_predictions
+        cdef intp_t[:, ::1] y_view = np.ascontiguousarray(np.asarray(y_test).T, dtype=np.intp)
+        self._compute_oob_node_values_and_predictions(X_test, y_view, oob_pred, has_oob_sample, oob_node_values, method)
+
+        for node_idx in range(self.node_count):
+            node = nodes[node_idx]
+            if (node.left_child != _TREE_LEAF) and (node.right_child != _TREE_LEAF):
+                left_idx = node.left_child
+                right_idx = node.right_child
+                if has_oob_sample[left_idx] and has_oob_sample[right_idx]:
+                    if method == "ufi":
+                        # supports criterion in ["gini", "log_loss", "entropy"] for
+                        # classification and criterion == "squared_error" for regression
+                        importances[node.feature] += self.ufi_impurity_decrease(oob_node_values, node_idx, left_idx, right_idx, node, criterion)
+                    elif method == "mdi_oob":
+                        # only supports criterion == "gini" for classification and
+                        # criterion == "squared_error" for regression
+                        importances[node.feature] += self.mdi_oob_impurity_decrease(oob_node_values, node_idx, left_idx, right_idx, node)
+
+        return np.asarray(importances), np.asarray(oob_pred)
+
+    cdef float64_t mdi_oob_impurity_decrease(self, float64_t[:, :, ::1] oob_node_values, int node_idx, int left_idx, int right_idx, Node node):
+        cdef float64_t importance = 0.0
+        cdef int node_value_idx, left_value_idx, right_value_idx = -1
+        cdef int k, c = 0
+        with nogil:
+            for k in range(self.n_outputs):
+                for c in range(self.n_classes[k]):
+                    node_value_idx = node_idx * self.value_stride + k * self.max_n_classes + c
+                    left_value_idx = left_idx * self.value_stride + k * self.max_n_classes + c
+                    right_value_idx = right_idx * self.value_stride + k * self.max_n_classes + c
+                    importance += (
+                        (self.value[node_value_idx] - self.value[left_value_idx])
+                        * (oob_node_values[node_idx, c, k] - oob_node_values[left_idx, c, k])
+                        * self.nodes[left_idx].weighted_n_node_samples
+                        +
+                        (self.value[node_value_idx] - self.value[right_value_idx])
+                        * (oob_node_values[node_idx, c, k] - oob_node_values[right_idx, c, k])
+                        * self.nodes[right_idx].weighted_n_node_samples
+                    )
+        return importance / self.n_outputs
+
+    cdef float64_t ufi_impurity_decrease(self, float64_t[:, :, ::1] oob_node_values, int node_idx, int left_idx, int right_idx, Node node, str criterion):
+        cdef float64_t importance = 0.0
+        cdef int node_value_idx, left_value_idx, right_value_idx = -1
+        cdef int k, c = 0
+        # Python string comparisons need the GIL, so hoist them out of the nogil block
+        cdef bint is_gini = (criterion == "gini")
+        cdef bint is_log_loss_or_entropy = (criterion == "log_loss" or criterion == "entropy")
+        with nogil:
+            for k in range(self.n_outputs):
+                if self.n_classes[k] > 1:  # Classification
+                    for c in range(self.n_classes[k]):
+                        node_value_idx = node_idx * self.value_stride + k * self.max_n_classes + c
+                        left_value_idx = left_idx * self.value_stride + k * self.max_n_classes + c
+                        right_value_idx = right_idx * self.value_stride + k * self.max_n_classes + c
+                        if is_gini:
+                            importance -= (
+                                self.value[node_value_idx] * oob_node_values[node_idx, c, k]
+                                * node.weighted_n_node_samples
+                                -
+                                self.value[left_value_idx] * oob_node_values[left_idx, c, k]
+                                * self.nodes[left_idx].weighted_n_node_samples
+                                -
+                                self.value[right_value_idx] * oob_node_values[right_idx, c, k]
+                                * self.nodes[right_idx].weighted_n_node_samples
+                            )
+                        elif is_log_loss_or_entropy:
+                            # Skip empty classes to avoid taking log(0)
+                            if oob_node_values[node_idx, c, k] > 0.0 and self.value[node_value_idx] > 0.0:
+                                importance -= (
+                                    (self.value[node_value_idx] * log(oob_node_values[node_idx, c, k])
+                                     + log(self.value[node_value_idx]) * oob_node_values[node_idx, c, k])
+                                    * node.weighted_n_node_samples
+                                ) / 2
+                            if oob_node_values[left_idx, c, k] > 0.0 and self.value[left_value_idx] > 0.0:
+                                importance += (
+                                    (self.value[left_value_idx] * log(oob_node_values[left_idx, c, k])
+                                     + log(self.value[left_value_idx]) * oob_node_values[left_idx, c, k])
+                                    * self.nodes[left_idx].weighted_n_node_samples
+                                ) / 2
+                            if oob_node_values[right_idx, c, k] > 0.0 and self.value[right_value_idx] > 0.0:
+                                importance += (
+                                    (self.value[right_value_idx] * log(oob_node_values[right_idx, c, k])
+                                     + log(self.value[right_value_idx]) * oob_node_values[right_idx, c, k])
+                                    * self.nodes[right_idx].weighted_n_node_samples
+                                ) / 2
+                else:  # Regression, only valid for criterion == "squared_error"
+                    importance += (
+                        (node.impurity + oob_node_values[node_idx, 0, k])
+                        * node.weighted_n_node_samples
+                        -
+                        (self.nodes[left_idx].impurity + oob_node_values[left_idx, 0, k])
+                        * self.nodes[left_idx].weighted_n_node_samples
+                        -
+                        (self.nodes[right_idx].impurity + oob_node_values[right_idx, 0, k])
+                        * self.nodes[right_idx].weighted_n_node_samples
+                    )
+        return importance / self.n_outputs
+
     cdef cnp.ndarray _get_value_ndarray(self):
         """Wraps value as a 3-d NumPy array.
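A closing reviewer note on the behavior this patch targets: on a pure-noise feature the classical MDI tends to be biased upward, while the OOB-corrected estimates should stay near zero (a sketch assuming this branch; numbers are not from the patch):

    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.ensemble import RandomForestClassifier

    rng = np.random.RandomState(0)
    X, y = make_classification(n_samples=300, n_features=5, n_informative=3, random_state=0)
    X = np.hstack([X, rng.uniform(size=(300, 1))])  # append a continuous noise feature
    clf = RandomForestClassifier(
        n_estimators=100, bootstrap=True, oob_score=True, random_state=0
    ).fit(X, y)
    print(clf.feature_importances_[-1])      # inflated by split-selection bias
    print(clf.ufi_feature_importances_[-1])  # expected near zero (individual entries
                                             # may even be slightly negative)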