-
-
Notifications
You must be signed in to change notification settings - Fork 26k
[MRG+1] Stronger tests for variable importances #5261
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
25dbb15
78974de
bcc6f1b
5f589fb
3575db6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,10 +10,14 @@ | |
|
||
import pickle | ||
from collections import defaultdict | ||
from itertools import combinations | ||
from itertools import product | ||
|
||
import numpy as np | ||
from scipy.sparse import csr_matrix, csc_matrix, coo_matrix | ||
from scipy.misc import comb | ||
from scipy.sparse import csr_matrix | ||
from scipy.sparse import csc_matrix | ||
from scipy.sparse import coo_matrix | ||
|
||
from sklearn.utils.testing import assert_almost_equal | ||
from sklearn.utils.testing import assert_array_almost_equal | ||
|
@@ -35,6 +39,7 @@ | |
from sklearn.ensemble import RandomTreesEmbedding | ||
from sklearn.grid_search import GridSearchCV | ||
from sklearn.svm import LinearSVC | ||
from sklearn.utils.fixes import bincount | ||
from sklearn.utils.validation import check_random_state | ||
|
||
from sklearn.tree.tree import SPARSE_SPLITTERS | ||
|
@@ -186,44 +191,146 @@ def test_probability(): | |
yield check_probability, name | ||
|
||
|
||
def check_importances(name, X, y): | ||
# Check variable importances. | ||
|
||
ForestClassifier = FOREST_CLASSIFIERS[name] | ||
for n_jobs in [1, 2]: | ||
clf = ForestClassifier(n_estimators=10, n_jobs=n_jobs) | ||
clf.fit(X, y) | ||
importances = clf.feature_importances_ | ||
n_important = np.sum(importances > 0.1) | ||
assert_equal(importances.shape[0], 10) | ||
assert_equal(n_important, 3) | ||
|
||
X_new = clf.transform(X, threshold="mean") | ||
assert_less(0 < X_new.shape[1], X.shape[1]) | ||
|
||
# Check with sample weights | ||
sample_weight = np.ones(y.shape) | ||
sample_weight[y == 1] *= 100 | ||
|
||
clf = ForestClassifier(n_estimators=50, n_jobs=n_jobs, random_state=0) | ||
clf.fit(X, y, sample_weight=sample_weight) | ||
importances = clf.feature_importances_ | ||
assert_true(np.all(importances >= 0.0)) | ||
def check_importances(X, y, name, criterion): | ||
ForestEstimator = FOREST_ESTIMATORS[name] | ||
|
||
clf = ForestClassifier(n_estimators=50, n_jobs=n_jobs, random_state=0) | ||
clf.fit(X, y, sample_weight=3 * sample_weight) | ||
importances_bis = clf.feature_importances_ | ||
assert_almost_equal(importances, importances_bis) | ||
est = ForestEstimator(n_estimators=20, criterion=criterion, | ||
random_state=0) | ||
est.fit(X, y) | ||
importances = est.feature_importances_ | ||
n_important = np.sum(importances > 0.1) | ||
assert_equal(importances.shape[0], 10) | ||
assert_equal(n_important, 3) | ||
|
||
X_new = est.transform(X, threshold="mean") | ||
assert_less(X_new.shape[1], X.shape[1]) | ||
|
||
# Check with parallel | ||
importances = est.feature_importances_ | ||
est.set_params(n_jobs=2) | ||
importances_parrallel = est.feature_importances_ | ||
assert_array_almost_equal(importances, importances_parrallel) | ||
|
||
# Check with sample weights | ||
sample_weight = check_random_state(0).randint(1, 10, len(X)) | ||
est = ForestEstimator(n_estimators=20, random_state=0, | ||
criterion=criterion) | ||
est.fit(X, y, sample_weight=sample_weight) | ||
importances = est.feature_importances_ | ||
assert_true(np.all(importances >= 0.0)) | ||
|
||
for scale in [0.5, 10, 100]: | ||
est = ForestEstimator(n_estimators=20, random_state=0, | ||
criterion=criterion) | ||
est.fit(X, y, sample_weight=scale * sample_weight) | ||
importances_bis = est.feature_importances_ | ||
assert_less(np.abs(importances - importances_bis).mean(), 0.001) | ||
|
||
|
||
def test_importances(): | ||
X, y = datasets.make_classification(n_samples=1000, n_features=10, | ||
X, y = datasets.make_classification(n_samples=500, n_features=10, | ||
n_informative=3, n_redundant=0, | ||
n_repeated=0, shuffle=False, | ||
random_state=0) | ||
|
||
for name in FOREST_CLASSIFIERS: | ||
yield check_importances, name, X, y | ||
for name, criterion in product(FOREST_CLASSIFIERS, ["gini", "entropy"]): | ||
yield check_importances, X, y, name, criterion | ||
|
||
for name, criterion in product(FOREST_REGRESSORS, ["mse", "friedman_mse"]): | ||
yield check_importances, X, y, name, criterion | ||
|
||
|
||
def test_importances_asymptotic(): | ||
# Check whether variable importances of totally randomized trees | ||
# converge towards their theoretical values (See Louppe et al, | ||
# Understanding variable importances in forests of randomized trees, 2013). | ||
|
||
def binomial(k, n): | ||
return 0 if k < 0 or k > n else comb(int(n), int(k), exact=True) | ||
|
||
def entropy(samples): | ||
n_samples = len(samples) | ||
entropy = 0. | ||
|
||
for count in bincount(samples): | ||
p = 1. * count / n_samples | ||
if p > 0: | ||
entropy -= p * np.log2(p) | ||
|
||
return entropy | ||
|
||
def mdi_importance(X_m, X, y): | ||
n_samples, n_features = X.shape | ||
|
||
features = list(range(n_features)) | ||
features.pop(X_m) | ||
values = [np.unique(X[:, i]) for i in range(n_features)] | ||
|
||
imp = 0. | ||
|
||
for k in range(n_features): | ||
# Weight of each B of size k | ||
coef = 1. / (binomial(k, n_features) * (n_features - k)) | ||
|
||
# For all B of size k | ||
for B in combinations(features, k): | ||
# For all values B=b | ||
for b in product(*[values[B[j]] for j in range(k)]): | ||
mask_b = np.ones(n_samples, dtype=np.bool) | ||
|
||
for j in range(k): | ||
mask_b &= X[:, B[j]] == b[j] | ||
|
||
X_, y_ = X[mask_b, :], y[mask_b] | ||
n_samples_b = len(X_) | ||
|
||
if n_samples_b > 0: | ||
children = [] | ||
|
||
for xi in values[X_m]: | ||
mask_xi = X_[:, X_m] == xi | ||
children.append(y_[mask_xi]) | ||
|
||
imp += (coef | ||
* (1. * n_samples_b / n_samples) # P(B=b) | ||
* (entropy(y_) - | ||
sum([entropy(c) * len(c) / n_samples_b | ||
for c in children]))) | ||
|
||
return imp | ||
|
||
data = np.array([[0, 0, 1, 0, 0, 1, 0, 1], | ||
[1, 0, 1, 1, 1, 0, 1, 2], | ||
[1, 0, 1, 1, 0, 1, 1, 3], | ||
[0, 1, 1, 1, 0, 1, 0, 4], | ||
[1, 1, 0, 1, 0, 1, 1, 5], | ||
[1, 1, 0, 1, 1, 1, 1, 6], | ||
[1, 0, 1, 0, 0, 1, 0, 7], | ||
[1, 1, 1, 1, 1, 1, 1, 8], | ||
[1, 1, 1, 1, 0, 1, 1, 9], | ||
[1, 1, 1, 0, 1, 1, 1, 0]]) | ||
|
||
X, y = np.array(data[:, :7], dtype=np.bool), data[:, 7] | ||
n_features = X.shape[1] | ||
|
||
# Compute true importances | ||
true_importances = np.zeros(n_features) | ||
|
||
for i in range(n_features): | ||
true_importances[i] = mdi_importance(i, X, y) | ||
|
||
# Estimate importances with totally randomized trees | ||
clf = ExtraTreesClassifier(n_estimators=500, | ||
max_features=1, | ||
criterion="entropy", | ||
random_state=0).fit(X, y) | ||
|
||
importances = sum(tree.tree_.compute_feature_importances(normalize=False) | ||
for tree in clf.estimators_) / clf.n_estimators | ||
|
||
# Check correctness | ||
assert_almost_equal(entropy(y), sum(importances)) | ||
assert_less(np.abs(true_importances - importances).mean(), 0.01) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do you use assert_less over assert_array_almost_equal? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I find it easier to control the quality of the approximation, rather than having all single importance values match up to some digit (but doing this, I only require their mean to do so). |
||
|
||
|
||
def check_unfitted_feature_importances(name): | ||
|
@@ -239,6 +346,7 @@ def test_unfitted_feature_importances(): | |
def check_oob_score(name, X, y, n_estimators=20): | ||
# Check that oob prediction is a good estimation of the generalization | ||
# error. | ||
|
||
# Proper behavior | ||
est = FOREST_ESTIMATORS[name](oob_score=True, random_state=0, | ||
n_estimators=n_estimators, bootstrap=True) | ||
|
@@ -583,7 +691,7 @@ def check_min_samples_leaf(name, X, y): | |
random_state=0) | ||
est.fit(X, y) | ||
out = est.estimators_[0].tree_.apply(X) | ||
node_counts = np.bincount(out) | ||
node_counts = bincount(out) | ||
# drop inner nodes | ||
leaf_count = node_counts[node_counts != 0] | ||
assert_greater(np.min(leaf_count), 4, | ||
|
@@ -617,7 +725,7 @@ def check_min_weight_fraction_leaf(name, X, y): | |
est.bootstrap = False | ||
est.fit(X, y, sample_weight=weights) | ||
out = est.estimators_[0].tree_.apply(X) | ||
node_weights = np.bincount(out, weights=weights) | ||
node_weights = bincount(out, weights=weights) | ||
# drop inner nodes | ||
leaf_weights = node_weights[node_weights != 0] | ||
assert_greater_equal( | ||
|
@@ -663,7 +771,7 @@ def check_sparse_input(name, X, X_sparse, y): | |
|
||
def test_sparse_input(): | ||
X, y = datasets.make_multilabel_classification(random_state=0, | ||
n_samples=40) | ||
n_samples=50) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not that I disagree, but what is the reason for this change? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was trying to make the tree construction more stable by automatically rescaling the sample weights. Nothing worked with much success (it just moved the issue to someplace else...), but this test failed a few times because of this. Using more samples made that test more stable. |
||
|
||
for name, sparse_matrix in product(FOREST_ESTIMATORS, | ||
(csr_matrix, csc_matrix, coo_matrix)): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -820,7 +820,7 @@ def test_sample_weight(): | |
X = iris.data | ||
y = iris.target | ||
|
||
duplicates = rng.randint(0, X.shape[0], 200) | ||
duplicates = rng.randint(0, X.shape[0], 100) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Likewise, here? Is there a functional difference? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This was just to make the test slightly faster. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK |
||
|
||
clf = DecisionTreeClassifier(random_state=1) | ||
clf.fit(X[duplicates], y[duplicates]) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what does that stand for?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
mean decrease impurity