[MRG] Tree apply #4065
@@ -0,0 +1,83 @@
"""
===================================================================
Decision Tree Feature Extraction
===================================================================

Obtaining features from decision trees.

A dataset can be transformed using a decision tree's apply() method
in two ways:

1) Reducing the number of classes to predict. By choosing max_leaf_nodes
   to be a value less than the total number of classes in a
   classification problem, one obtains a dataset with a reduced number
   of classes.

2) Creating a new sparse feature representation of the data. Each
   sample is transformed into a vector whose length is the number of
   leaves in the decision tree. Each leaf is assigned an index in this
   vector. If the sample falls into a given leaf, the value at that
   leaf's index is 1; otherwise it is 0, so exactly one entry per
   sample is 1. This sparse, high-dimensional representation may be
   useful for increasing the separability of the data.

Note that in the double bar graph below, which demonstrates the first
transformation, all setosas fall into the first leaf and none into the
second, while all versicolors and virginicas fall only into the second
leaf. This suggests that virginicas and versicolors are more similar
to each other than to setosas.
"""
print(__doc__)

import numpy as np
import matplotlib.pyplot as plt

from sklearn import tree
from sklearn.datasets import load_iris

max_leaves = 2
iris = load_iris()
clf = tree.DecisionTreeClassifier(max_leaf_nodes=max_leaves)
X = iris['data']
y = iris['target']

clf.fit(X, y)

# 1) Reduce the number of classes: each sample is mapped to one of the
# two leaves, so there are now only two classes instead of three.
y_reduced = clf.apply(X)

bar_width = .35
opacity = 0.4
index = np.arange(3)

leaf_class_colors = dict(zip(range(np.max(y_reduced)), ['r', 'b']))

# Count how many samples of each original class fall into each leaf.
# For this two-leaf tree, apply() returns leaf ids 1 and 2.
new_classes = []
for i in (1, 2):
    new_classes.append(np.array([np.sum(y[y_reduced == i] == 0),
                                 np.sum(y[y_reduced == i] == 1),
                                 np.sum(y[y_reduced == i] == 2)]))

for i in range(np.max(y_reduced)):
    plt.bar(index + i * bar_width, new_classes[i], bar_width,
            alpha=opacity, color=leaf_class_colors[i],
            label="Leaf " + str(i + 1))

plt.title("The assignment of each original class to new leaf index classes")
plt.xticks(index + bar_width, iris['target_names'])
plt.xlabel("Original class")
plt.ylabel("Number in each new leaf class")
plt.legend()

# 2) Create a new feature representation. A tree with a constrained
# number of leaves is not required here; we reuse the same classifier
# only for the convenience of demonstrating part 1.
X_trans = np.zeros((y_reduced.size, max_leaves))
for i in range(max_leaves):
    X_trans[:, i] = (y_reduced == i + 1)  # leaf indexing begins at 1

# For the graph to be built in the documentation, plt.show() must be
# called last.
plt.show()
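The dense indicator matrix above can equally be built as a genuinely sparse matrix, matching the sparse representation the example's docstring describes. A minimal sketch, reusing `y_reduced` and `max_leaves` from the example and assuming this two-leaf tree assigns leaf ids 1 and 2 (in general, leaf ids are tree node ids and are not contiguous, so a mapping from leaf id to column index would be needed):

```python
from scipy.sparse import csr_matrix

# One row per sample, one column per leaf, exactly one 1 per row.
n_samples = y_reduced.size
rows = np.arange(n_samples)
cols = y_reduced - 1  # shift leaf ids {1, 2} to columns {0, 1}
data = np.ones(n_samples)
X_trans_sparse = csr_matrix((data, (rows, cols)),
                            shape=(n_samples, max_leaves))
```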
@@ -1210,6 +1210,8 @@ def check_explicit_sparse_zeros(tree, max_depth=3,
    Xs = (X_test, X_sparse_test)
    for X1, X2 in product(Xs, Xs):
        assert_array_almost_equal(s.tree_.apply(X1), d.tree_.apply(X2))
        assert_array_almost_equal(s.apply(X1), d.apply(X2))
        assert_array_almost_equal(s.apply(X1), s.tree_.apply(X1))
Review comment: Please write instead an independent test to only check the correctness of `apply`.

Review comment: +1

Review comment: Wait, the test is there, but it should be removed here, right?
        assert_array_almost_equal(s.predict(X1), d.predict(X2))

    if tree in CLF_TREES:
@@ -1268,3 +1270,38 @@ def check_min_weight_leaf_split_level(name):

def test_min_weight_leaf_split_level():
    for name in ALL_TREES:
        yield check_min_weight_leaf_split_level, name
<<<<<<< HEAD

Review comment: Merge error here.
def check_public_apply(tree):
    # tree_.apply does not check that the data is float32, so we do the
    # conversion manually here.
    X_small_32 = X_small.astype(np.float32, copy=True)
    if tree in CLF_TREES:
        tree_class = CLF_TREES[tree]
        clf = tree_class()
        clf.fit(X_small_32, y_small)
    else:  # The tree is a regression tree.
        tree_class = REG_TREES[tree]
        clf = tree_class()
        clf.fit(X_small_32, y_small_reg)

    assert_array_equal(clf.apply(X_small_32), clf.tree_.apply(X_small_32))

    for sparse_matrix in (csr_matrix, csc_matrix, coo_matrix):
        X_small_sparse = sparse_matrix(X_small_32)
        assert_array_equal(clf.apply(X_small_sparse),
                           clf.tree_.apply(X_small_32))


def test_public_apply():
    """Test that Tree.apply matches Tree.tree_.apply for sparse and dense
    inputs."""
    for tree in ALL_TREES:
        yield check_public_apply, tree


def test_apply_valid():
    """Check that apply() raises an error if its preconditions are not
    met."""
    clf = DecisionTreeClassifier()
    X_sparse_small = csr_matrix(X_small)
    assert_raises(ValueError, clf.apply, X_sparse_small)
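For illustration, a standalone sketch of the behavior these tests pin down (not part of the PR; the iris dataset and variable names are assumptions made only for concreteness):

```python
import numpy as np
from scipy.sparse import csr_matrix
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(max_leaf_nodes=3).fit(iris.data, iris.target)

# apply() returns one leaf id per sample, identical for dense and
# sparse input, since the public method converts sparse input to CSR.
dense_leaves = clf.apply(iris.data)
sparse_leaves = clf.apply(csr_matrix(iris.data))
assert np.array_equal(dense_leaves, sparse_leaves)

# tree_.apply expects float32; the public method handles the conversion.
assert np.array_equal(dense_leaves,
                      clf.tree_.apply(iris.data.astype(np.float32)))
```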
@@ -415,6 +415,33 @@ def feature_importances_(self):
        return self.tree_.compute_feature_importances()

    def apply(self, X):
        """Return the index of the leaf that each sample ends up in.

        Parameters
        ----------
        X : array-like or sparse matrix, shape = [n_samples, n_features]
            The input samples. Internally, it will be converted to
            ``dtype=np.float32``, and a sparse matrix will be converted
            to a sparse ``csr_matrix``.

        Returns
        -------
        X_leaves : array_like, shape = [n_samples,]
            For each datapoint x in X, return the index of the leaf x
            ends up in.
        """
Review comment: For consistency, could you mimic the docstring of https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/ensemble/forest.py#L138 ?
        if self.tree_ is None:
            raise ValueError("Estimator not fitted, "
                             "call `fit` before `apply`.")

        X = check_array(X, dtype=DTYPE, accept_sparse="csr")
        if issparse(X) and (X.indices.dtype != np.int32 or
                            X.indptr.dtype != np.int32):
            raise ValueError("No support for np.int64 index based "
                             "sparse matrices")
        return self.tree_.apply(X)


# =============================================================================
# Public estimators
Review comment: Could you please reformat this to:
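For context, a sketch of how the new method composes with one-hot encoding to produce the leaf-indicator features described in the example above (not part of this PR; `OneHotEncoder` from `sklearn.preprocessing` and all variable names here are illustrative assumptions):

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.preprocessing import OneHotEncoder
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(max_leaf_nodes=5).fit(iris.data, iris.target)

# One leaf id per sample; one-hot encode the ids to obtain a sparse
# leaf-indicator matrix with exactly one 1 per row.
leaves = clf.apply(iris.data).reshape(-1, 1)
X_embedded = OneHotEncoder().fit_transform(leaves)  # sparse matrix
print(X_embedded.shape)  # (150, n_leaves)
```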