From 25dbb155a6134f0b8fecea6afe27969c4af078e6 Mon Sep 17 00:00:00 2001 From: Gilles Louppe Date: Sat, 12 Sep 2015 15:18:51 +0200 Subject: [PATCH 1/5] TEST: stronger tests for variable importances --- sklearn/ensemble/tests/test_forest.py | 168 +++++++++++++++++++++----- sklearn/tree/tests/test_tree.py | 2 +- 2 files changed, 139 insertions(+), 31 deletions(-) diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py index f2ec00951bf10..29ba2f9bae247 100644 --- a/sklearn/ensemble/tests/test_forest.py +++ b/sklearn/ensemble/tests/test_forest.py @@ -10,10 +10,14 @@ import pickle from collections import defaultdict +from itertools import combinations from itertools import product import numpy as np -from scipy.sparse import csr_matrix, csc_matrix, coo_matrix +from scipy.misc import comb +from scipy.sparse import csr_matrix +from scipy.sparse import csc_matrix +from scipy.sparse import coo_matrix from sklearn.utils.testing import assert_almost_equal from sklearn.utils.testing import assert_array_almost_equal @@ -186,33 +190,35 @@ def test_probability(): yield check_probability, name -def check_importances(name, X, y): - # Check variable importances. +def check_importances(X, y, name, criterion): + ForestEstimator = FOREST_ESTIMATORS[name] - ForestClassifier = FOREST_CLASSIFIERS[name] - for n_jobs in [1, 2]: - clf = ForestClassifier(n_estimators=10, n_jobs=n_jobs) - clf.fit(X, y) - importances = clf.feature_importances_ - n_important = np.sum(importances > 0.1) - assert_equal(importances.shape[0], 10) - assert_equal(n_important, 3) - - X_new = clf.transform(X, threshold="mean") - assert_less(0 < X_new.shape[1], X.shape[1]) - - # Check with sample weights - sample_weight = np.ones(y.shape) - sample_weight[y == 1] *= 100 - - clf = ForestClassifier(n_estimators=50, n_jobs=n_jobs, random_state=0) - clf.fit(X, y, sample_weight=sample_weight) - importances = clf.feature_importances_ - assert_true(np.all(importances >= 0.0)) - - clf = ForestClassifier(n_estimators=50, n_jobs=n_jobs, random_state=0) - clf.fit(X, y, sample_weight=3 * sample_weight) - importances_bis = clf.feature_importances_ + est = ForestEstimator(n_estimators=20, criterion=criterion, + random_state=0) + est.fit(X, y) + importances = est.feature_importances_ + n_important = np.sum(importances > 0.1) + assert_equal(importances.shape[0], 10) + assert_equal(n_important, 3) + + X_new = est.transform(X, threshold="mean") + assert_less(0 < X_new.shape[1], X.shape[1]) + + # Check with sample weights + sample_weight = np.ones(y.shape) + sample_weight[y == 1] *= 100 + + est = ForestEstimator(n_estimators=20, random_state=0, + criterion=criterion) + est.fit(X, y, sample_weight=sample_weight) + importances = est.feature_importances_ + assert_true(np.all(importances >= 0.0)) + + for scale in [3, 10, 1000, 100000]: + est = ForestEstimator(n_estimators=20, random_state=0, + criterion=criterion) + est.fit(X, y, sample_weight=scale * sample_weight) + importances_bis = est.feature_importances_ assert_almost_equal(importances, importances_bis) @@ -222,8 +228,109 @@ def test_importances(): n_repeated=0, shuffle=False, random_state=0) - for name in FOREST_CLASSIFIERS: - yield check_importances, name, X, y + for name, criterion in product(FOREST_CLASSIFIERS, ["gini", "entropy"]): + yield check_importances, X, y, name, criterion + + for name, criterion in product(FOREST_REGRESSORS, ["mse"]): + yield check_importances, X, y, name, criterion + + +def test_importances_asymptotic(): + # Check whether variable importances of totally randomized trees + # converge towards their theoretical values (See Louppe et al, + # Understanding variable importances in forests of randomized trees, 2013). + + def binomial(k, n): + if k < 0 or k > n: + return 0 + else: + return comb(int(n), int(k), exact=True) + + def entropy(samples): + e = 0. + n_samples = len(samples) + + for count in np.bincount(samples): + p = 1. * count / n_samples + if p > 0: + e -= p * np.log2(p) + + return e + + def mdi_importance(X_m, X, y): + n_samples, p = X.shape + + variables = range(p) + variables.pop(X_m) + imp = 0. + + values = [] + for i in range(p): + values.append(np.unique(X[:, i])) + + for k in range(p): + # Weight of each B of size k + coef = 1. / (binomial(k, p) * (p - k)) + + # For all B of size k + for B in combinations(variables, k): + # For all values B=b + for b in product(*[values[B[j]] for j in range(k)]): + mask_b = np.ones(n_samples, dtype=np.bool) + + for j in xrange(k): + mask_b &= X[:, B[j]] == b[j] + + X_, y_ = X[mask_b, :], y[mask_b] + n_samples_b = len(X_) + + if n_samples_b > 0: + children = [] + + for xi in values[X_m]: + mask_xi = X_[:, X_m] == xi + children.append(y_[mask_xi]) + + imp += (coef + * (1. * n_samples_b / n_samples) # P(B=b) + * (entropy(y_) - + sum([entropy(c) * len(c) / n_samples_b + for c in children]))) + + return imp + + data = np.array([[0, 0, 1, 0, 0, 1, 0, 1], + [1, 0, 1, 1, 1, 0, 1, 2], + [1, 0, 1, 1, 0, 1, 1, 3], + [0, 1, 1, 1, 0, 1, 0, 4], + [1, 1, 0, 1, 0, 1, 1, 5], + [1, 1, 0, 1, 1, 1, 1, 6], + [1, 0, 1, 0, 0, 1, 0, 7], + [1, 1, 1, 1, 1, 1, 1, 8], + [1, 1, 1, 1, 0, 1, 1, 9], + [1, 1, 1, 0, 1, 1, 1, 0]]) + + X, y = np.array(data[:, :7], dtype=np.bool), data[:, 7] + n_features = X.shape[1] + + # Compute true importances + true_importances = np.zeros(n_features) + + for i in range(n_features): + true_importances[i] = mdi_importance(i, X, y) + + # Estimate importances with totally randomized trees + clf = ExtraTreesClassifier(n_estimators=1000, + max_features=1, + criterion="entropy", + random_state=0).fit(X, y) + + importances = sum(tree.tree_.compute_feature_importances(normalize=False) + for tree in clf.estimators_) / clf.n_estimators + + # Check correctness + assert_almost_equal(entropy(y), sum(importances)) + assert_less(((true_importances - importances) ** 2).sum(), 0.0005) def check_unfitted_feature_importances(name): @@ -239,6 +346,7 @@ def test_unfitted_feature_importances(): def check_oob_score(name, X, y, n_estimators=20): # Check that oob prediction is a good estimation of the generalization # error. + # Proper behavior est = FOREST_ESTIMATORS[name](oob_score=True, random_state=0, n_estimators=n_estimators, bootstrap=True) @@ -663,7 +771,7 @@ def check_sparse_input(name, X, X_sparse, y): def test_sparse_input(): X, y = datasets.make_multilabel_classification(random_state=0, - n_samples=40) + n_samples=50) for name, sparse_matrix in product(FOREST_ESTIMATORS, (csr_matrix, csc_matrix, coo_matrix)): diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 12c150decbe2c..fa46c81a24c4a 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -820,7 +820,7 @@ def test_sample_weight(): X = iris.data y = iris.target - duplicates = rng.randint(0, X.shape[0], 200) + duplicates = rng.randint(0, X.shape[0], 100) clf = DecisionTreeClassifier(random_state=1) clf.fit(X[duplicates], y[duplicates]) From 78974def03ae670675c0338f2979f8a9d6d6a980 Mon Sep 17 00:00:00 2001 From: Gilles Louppe Date: Sat, 12 Sep 2015 15:40:36 +0200 Subject: [PATCH 2/5] TEST: use sklearn.fixes.bincount --- sklearn/ensemble/tests/test_forest.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py index 29ba2f9bae247..0e0da8dd8d539 100644 --- a/sklearn/ensemble/tests/test_forest.py +++ b/sklearn/ensemble/tests/test_forest.py @@ -39,6 +39,7 @@ from sklearn.ensemble import RandomTreesEmbedding from sklearn.grid_search import GridSearchCV from sklearn.svm import LinearSVC +from sklearn.utils.fixes import bincount from sklearn.utils.validation import check_random_state from sklearn.tree.tree import SPARSE_SPLITTERS @@ -250,7 +251,7 @@ def entropy(samples): e = 0. n_samples = len(samples) - for count in np.bincount(samples): + for count in bincount(samples): p = 1. * count / n_samples if p > 0: e -= p * np.log2(p) @@ -260,7 +261,7 @@ def entropy(samples): def mdi_importance(X_m, X, y): n_samples, p = X.shape - variables = range(p) + variables = list(range(p)) variables.pop(X_m) imp = 0. @@ -278,7 +279,7 @@ def mdi_importance(X_m, X, y): for b in product(*[values[B[j]] for j in range(k)]): mask_b = np.ones(n_samples, dtype=np.bool) - for j in xrange(k): + for j in range(k): mask_b &= X[:, B[j]] == b[j] X_, y_ = X[mask_b, :], y[mask_b] @@ -691,7 +692,7 @@ def check_min_samples_leaf(name, X, y): random_state=0) est.fit(X, y) out = est.estimators_[0].tree_.apply(X) - node_counts = np.bincount(out) + node_counts = bincount(out) # drop inner nodes leaf_count = node_counts[node_counts != 0] assert_greater(np.min(leaf_count), 4, @@ -725,7 +726,7 @@ def check_min_weight_fraction_leaf(name, X, y): est.bootstrap = False est.fit(X, y, sample_weight=weights) out = est.estimators_[0].tree_.apply(X) - node_weights = np.bincount(out, weights=weights) + node_weights = bincount(out, weights=weights) # drop inner nodes leaf_weights = node_weights[node_weights != 0] assert_greater_equal( From bcc6f1bf5715be0cfe1c4158f13b1c9f03194566 Mon Sep 17 00:00:00 2001 From: Gilles Louppe Date: Sat, 12 Sep 2015 17:32:33 +0200 Subject: [PATCH 3/5] TEST: take comments into account --- sklearn/ensemble/tests/test_forest.py | 35 +++++++++++---------------- 1 file changed, 14 insertions(+), 21 deletions(-) diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py index 0e0da8dd8d539..f5a37ded80307 100644 --- a/sklearn/ensemble/tests/test_forest.py +++ b/sklearn/ensemble/tests/test_forest.py @@ -206,21 +206,19 @@ def check_importances(X, y, name, criterion): assert_less(0 < X_new.shape[1], X.shape[1]) # Check with sample weights - sample_weight = np.ones(y.shape) - sample_weight[y == 1] *= 100 - + sample_weight = check_random_state(0).randint(1, 10, len(X)) est = ForestEstimator(n_estimators=20, random_state=0, criterion=criterion) est.fit(X, y, sample_weight=sample_weight) importances = est.feature_importances_ assert_true(np.all(importances >= 0.0)) - for scale in [3, 10, 1000, 100000]: + for scale in [10, 100, 1000]: est = ForestEstimator(n_estimators=20, random_state=0, criterion=criterion) est.fit(X, y, sample_weight=scale * sample_weight) importances_bis = est.feature_importances_ - assert_almost_equal(importances, importances_bis) + assert_less(np.abs(importances - importances_bis).mean(), 0.0001) def test_importances(): @@ -232,7 +230,7 @@ def test_importances(): for name, criterion in product(FOREST_CLASSIFIERS, ["gini", "entropy"]): yield check_importances, X, y, name, criterion - for name, criterion in product(FOREST_REGRESSORS, ["mse"]): + for name, criterion in product(FOREST_REGRESSORS, ["mse", "friedman_mse"]): yield check_importances, X, y, name, criterion @@ -242,10 +240,7 @@ def test_importances_asymptotic(): # Understanding variable importances in forests of randomized trees, 2013). def binomial(k, n): - if k < 0 or k > n: - return 0 - else: - return comb(int(n), int(k), exact=True) + return 0 if k < 0 or k > n else comb(int(n), int(k), exact=True) def entropy(samples): e = 0. @@ -259,22 +254,20 @@ def entropy(samples): return e def mdi_importance(X_m, X, y): - n_samples, p = X.shape + n_samples, n_features = X.shape - variables = list(range(p)) - variables.pop(X_m) - imp = 0. + features = list(range(n_features)) + features.pop(X_m) + values = [np.unique(X[:, i]) for i in range(n_features)] - values = [] - for i in range(p): - values.append(np.unique(X[:, i])) + imp = 0. - for k in range(p): + for k in range(n_features): # Weight of each B of size k - coef = 1. / (binomial(k, p) * (p - k)) + coef = 1. / (binomial(k, n_features) * (n_features - k)) # For all B of size k - for B in combinations(variables, k): + for B in combinations(features, k): # For all values B=b for b in product(*[values[B[j]] for j in range(k)]): mask_b = np.ones(n_samples, dtype=np.bool) @@ -331,7 +324,7 @@ def mdi_importance(X_m, X, y): # Check correctness assert_almost_equal(entropy(y), sum(importances)) - assert_less(((true_importances - importances) ** 2).sum(), 0.0005) + assert_less(np.abs(true_importances - importances).mean(), 0.01) def check_unfitted_feature_importances(name): From 5f589fbbb4725d4eaed2d77dbff5405494dc0ae8 Mon Sep 17 00:00:00 2001 From: Gilles Louppe Date: Sun, 13 Sep 2015 20:09:07 +0200 Subject: [PATCH 4/5] TEST: reduce test time, variable name, etc --- sklearn/ensemble/tests/test_forest.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py index f5a37ded80307..49e0f2cd4bb18 100644 --- a/sklearn/ensemble/tests/test_forest.py +++ b/sklearn/ensemble/tests/test_forest.py @@ -203,7 +203,7 @@ def check_importances(X, y, name, criterion): assert_equal(n_important, 3) X_new = est.transform(X, threshold="mean") - assert_less(0 < X_new.shape[1], X.shape[1]) + assert_less(X_new.shape[1], X.shape[1]) # Check with sample weights sample_weight = check_random_state(0).randint(1, 10, len(X)) @@ -213,12 +213,12 @@ def check_importances(X, y, name, criterion): importances = est.feature_importances_ assert_true(np.all(importances >= 0.0)) - for scale in [10, 100, 1000]: + for scale in [0.5, 10, 100]: est = ForestEstimator(n_estimators=20, random_state=0, criterion=criterion) est.fit(X, y, sample_weight=scale * sample_weight) importances_bis = est.feature_importances_ - assert_less(np.abs(importances - importances_bis).mean(), 0.0001) + assert_less(np.abs(importances - importances_bis).mean(), 0.001) def test_importances(): @@ -243,15 +243,15 @@ def binomial(k, n): return 0 if k < 0 or k > n else comb(int(n), int(k), exact=True) def entropy(samples): - e = 0. n_samples = len(samples) + entropy = 0. for count in bincount(samples): p = 1. * count / n_samples if p > 0: - e -= p * np.log2(p) + entropy -= p * np.log2(p) - return e + return entropy def mdi_importance(X_m, X, y): n_samples, n_features = X.shape @@ -314,7 +314,7 @@ def mdi_importance(X_m, X, y): true_importances[i] = mdi_importance(i, X, y) # Estimate importances with totally randomized trees - clf = ExtraTreesClassifier(n_estimators=1000, + clf = ExtraTreesClassifier(n_estimators=500, max_features=1, criterion="entropy", random_state=0).fit(X, y) From 3575db60a16626b08afa3f6440b85d38a99a06c1 Mon Sep 17 00:00:00 2001 From: Gilles Louppe Date: Mon, 14 Sep 2015 07:38:01 +0200 Subject: [PATCH 5/5] TEST: check parallel computation --- sklearn/ensemble/tests/test_forest.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py index 49e0f2cd4bb18..43e2e761d1994 100644 --- a/sklearn/ensemble/tests/test_forest.py +++ b/sklearn/ensemble/tests/test_forest.py @@ -205,6 +205,12 @@ def check_importances(X, y, name, criterion): X_new = est.transform(X, threshold="mean") assert_less(X_new.shape[1], X.shape[1]) + # Check with parallel + importances = est.feature_importances_ + est.set_params(n_jobs=2) + importances_parrallel = est.feature_importances_ + assert_array_almost_equal(importances, importances_parrallel) + # Check with sample weights sample_weight = check_random_state(0).randint(1, 10, len(X)) est = ForestEstimator(n_estimators=20, random_state=0, @@ -222,7 +228,7 @@ def check_importances(X, y, name, criterion): def test_importances(): - X, y = datasets.make_classification(n_samples=1000, n_features=10, + X, y = datasets.make_classification(n_samples=500, n_features=10, n_informative=3, n_redundant=0, n_repeated=0, shuffle=False, random_state=0)