From cd6a1a6acf01d3035d2ffc62fe77059aafedeb98 Mon Sep 17 00:00:00 2001
From: Daniel Galvez
Date: Wed, 7 Jan 2015 13:03:10 -0500
Subject: [PATCH 1/6] Make apply method of trees public. Added test for
concistency with private method.
---
sklearn/tree/tests/test_tree.py | 2 ++
sklearn/tree/tree.py | 4 ++++
2 files changed, 6 insertions(+)
diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py
index bd08fcdeadd55..02b17667d4fbf 100644
--- a/sklearn/tree/tests/test_tree.py
+++ b/sklearn/tree/tests/test_tree.py
@@ -1137,6 +1137,8 @@ def check_explicit_sparse_zeros(tree, max_depth=3,
Xs = (X_test, X_sparse_test)
for X1, X2 in product(Xs, Xs):
assert_array_almost_equal(s.tree_.apply(X1), d.tree_.apply(X2))
+ assert_array_almost_equal(s.apply(X1), d.apply(X2))
+ assert_array_almost_equal(s.apply(X1), s.tree_.apply(X1))
assert_array_almost_equal(s.predict(X1), d.predict(X2))
if tree in CLF_TREES:
diff --git a/sklearn/tree/tree.py b/sklearn/tree/tree.py
index ebb194845d970..1607f7b01778c 100644
--- a/sklearn/tree/tree.py
+++ b/sklearn/tree/tree.py
@@ -371,6 +371,10 @@ def feature_importances_(self):
return self.tree_.compute_feature_importances()
+ def apply(self, X):
+ X = check_array(X, dtype= DTYPE, accept_sparse="csr")
+ return self.tree_.apply(X)
+
# =============================================================================
# Public estimators
From e8928c9bf9fd61441acc3f25f11975eab9a43182 Mon Sep 17 00:00:00 2001
From: Daniel Galvez
Date: Wed, 7 Jan 2015 13:30:57 -0500
Subject: [PATCH 2/6] Added docstring
---
sklearn/tree/tree.py | 12 ++++++++++++
1 file changed, 12 insertions(+)
diff --git a/sklearn/tree/tree.py b/sklearn/tree/tree.py
index 1607f7b01778c..1f4553c693e58 100644
--- a/sklearn/tree/tree.py
+++ b/sklearn/tree/tree.py
@@ -372,6 +372,18 @@ def feature_importances_(self):
return self.tree_.compute_feature_importances()
def apply(self, X):
+ """
+ Returns the index of the leaf that each sample is predicted as.
+
+ Parameters
+ ----------
+ X: array_like, shape = (n_samples, n_features)
+ Input Samples
+
+ Returns
+ -------
+ X_leaves: array_like, shape = (n_samples,)
+ """
X = check_array(X, dtype= DTYPE, accept_sparse="csr")
return self.tree_.apply(X)
From f2e9ec7ea8f97d53e2afcc52a8ef24ccf318365e Mon Sep 17 00:00:00 2001
From: Daniel Galvez
Date: Wed, 7 Jan 2015 21:51:36 -0500
Subject: [PATCH 3/6] Added example demonstrating tree.apply
---
examples/tree/plot_tree_feat.py | 84 +++++++++++++++++++++++++++++++++
1 file changed, 84 insertions(+)
create mode 100644 examples/tree/plot_tree_feat.py
diff --git a/examples/tree/plot_tree_feat.py b/examples/tree/plot_tree_feat.py
new file mode 100644
index 0000000000000..56d4694652378
--- /dev/null
+++ b/examples/tree/plot_tree_feat.py
@@ -0,0 +1,84 @@
+"""
+===================================================================
+Decision Tree Feature Extraction
+===================================================================
+
+Obtaining features from decision trees.
+
+A dataset can be transformed using a decision tree's apply() method
+in two ways:
+
+1) Reducing the number of classes to predict. By selecting max_leaf_nodes
+to be a value less than the total number of classes in a classification
+problem, one can obtain a dataset with a reduced number of classes.
+
+2) Creating a new sparse feature representation of the data.
+Each sample will be transformed to a vector of the size of the number of
+leafs in the decision tree. Each leaf is assigned an index in this vector.
+If the sample falls into a given leaf, the value at that leaf's index in the
+vector is 1; otherwise it is 0. Only one value in the array can be 1.
+This sparse, high-dimensional representation may be useful for increasing data
+separability.
+
+Note that in the below double bar graph demonstrating the first
+transformation, all
+setosas fall into the first leaf, and none into the second leaf. Similarly,
+all versicolors and virginicas fall only into the second leaf. This suggests
+that virginicas and versicolors are more similar to each other than to
+setosas.
+
+"""
+print(__doc__)
+
+import numpy as np
+import matplotlib.pyplot as plt
+
+from sklearn import tree
+from sklearn import ensemble
+from sklearn.datasets import load_iris
+
+max_leaves = 2
+iris = load_iris()
+clf = tree.DecisionTreeClassifier(max_leaf_nodes=max_leaves)
+X = iris['data']
+y = iris['target']
+
+clf.fit(X,y)
+#1
+y_reduced = clf.apply(X) #Now only two classes instead of three.
+
+bar_width = .35
+opacity = 0.4
+index = np.arange(3)
+
+leaf_class_colors = {}
+leaf_class_colors.update(zip(range(np.max(y_reduced)), ['r', 'b']))
+
+new_classes = []
+for i in (1,2):# xrange(iris['target_names'].size):
+ new_classes.append(np.array([np.sum(y[y_reduced == i] == 0), \
+ np.sum(y[y_reduced == i] == 1), \
+ np.sum(y[y_reduced == i] == 2) \
+ ]))
+
+for i in xrange(np.max(y_reduced)):
+ plt.bar(index + i * bar_width, new_classes[i], bar_width, alpha=opacity, \
+ color=leaf_class_colors[i], label="Leaf " +str(i + 1))
+
+plt.title("The assignment of each original class to new leaf index classes")
+plt.xticks(index + bar_width, iris['target_names'])
+plt.xlabel("Original class")
+plt.ylabel("Number in each new leaf class")
+plt.legend()
+
+
+#2
+# We don't need to use a decision tree with a constrained number of leaves,
+# but we do so here for the convenience of using the same classifier to
+# demonstrate part 1.
+X_trans = np.zeros((y_reduced.size, max_leaves))
+for i in xrange(max_leaves):
+ X_trans[:,i] = y_reduced == i + 1 #Add 1 because leaf indexing begins at 1
+
+#For the graph to be built in the documentation, plt.show() must be called last.
+plt.show()
From 3777046266be057e908796399e16c42a89a6c42e Mon Sep 17 00:00:00 2001
From: Daniel Galvez
Date: Wed, 7 Jan 2015 21:55:23 -0500
Subject: [PATCH 4/6] Added indentation to docstring
---
sklearn/tree/tree.py | 20 ++++++++++----------
1 file changed, 10 insertions(+), 10 deletions(-)
diff --git a/sklearn/tree/tree.py b/sklearn/tree/tree.py
index 1f4553c693e58..20245f27e5095 100644
--- a/sklearn/tree/tree.py
+++ b/sklearn/tree/tree.py
@@ -372,18 +372,18 @@ def feature_importances_(self):
return self.tree_.compute_feature_importances()
def apply(self, X):
- """
- Returns the index of the leaf that each sample is predicted as.
+ """
+ Returns the index of the leaf that each sample is predicted as.
- Parameters
- ----------
- X: array_like, shape = (n_samples, n_features)
- Input Samples
+ Parameters
+ ----------
+ X: array_like, shape = (n_samples, n_features)
+ Input Samples
- Returns
- -------
- X_leaves: array_like, shape = (n_samples,)
- """
+ Returns
+ -------
+ X_leaves: array_like, shape = (n_samples,)
+ """
X = check_array(X, dtype= DTYPE, accept_sparse="csr")
return self.tree_.apply(X)
From 5e7b51f46abdaec3d79084679389211953c8f8a9 Mon Sep 17 00:00:00 2001
From: Daniel Galvez
Date: Wed, 7 Jan 2015 22:09:37 -0500
Subject: [PATCH 5/6] Removed cruft
---
examples/tree/plot_tree_feat.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/examples/tree/plot_tree_feat.py b/examples/tree/plot_tree_feat.py
index 56d4694652378..31f133fb997e5 100644
--- a/examples/tree/plot_tree_feat.py
+++ b/examples/tree/plot_tree_feat.py
@@ -55,7 +55,7 @@
leaf_class_colors.update(zip(range(np.max(y_reduced)), ['r', 'b']))
new_classes = []
-for i in (1,2):# xrange(iris['target_names'].size):
+for i in (1,2):
new_classes.append(np.array([np.sum(y[y_reduced == i] == 0), \
np.sum(y[y_reduced == i] == 1), \
np.sum(y[y_reduced == i] == 2) \
From 41e2aef5ee1fd4325d98692828d9dc92bce2e99c Mon Sep 17 00:00:00 2001
From: Daniel Galvez
Date: Fri, 9 Jan 2015 23:36:15 -0500
Subject: [PATCH 6/6] Added tests of apply() for valid and invalid inputs.
Fixed style.
---
examples/tree/plot_tree_feat.py | 9 ++++-----
sklearn/tree/tests/test_tree.py | 34 ++++++++++++++++++++++++++++++++-
sklearn/tree/tree.py | 19 ++++++++++++++----
3 files changed, 52 insertions(+), 10 deletions(-)
diff --git a/examples/tree/plot_tree_feat.py b/examples/tree/plot_tree_feat.py
index 31f133fb997e5..39e9e48293381 100644
--- a/examples/tree/plot_tree_feat.py
+++ b/examples/tree/plot_tree_feat.py
@@ -56,13 +56,12 @@
new_classes = []
for i in (1,2):
- new_classes.append(np.array([np.sum(y[y_reduced == i] == 0), \
- np.sum(y[y_reduced == i] == 1), \
- np.sum(y[y_reduced == i] == 2) \
- ]))
+ new_classes.append(np.array([np.sum(y[y_reduced == i] == 0),
+ np.sum(y[y_reduced == i] == 1),
+ np.sum(y[y_reduced == i] == 2)]))
for i in xrange(np.max(y_reduced)):
- plt.bar(index + i * bar_width, new_classes[i], bar_width, alpha=opacity, \
+ plt.bar(index + i * bar_width, new_classes[i], bar_width, alpha=opacity,
color=leaf_class_colors[i], label="Leaf " +str(i + 1))
plt.title("The assignment of each original class to new leaf index classes")
diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py
index 02b17667d4fbf..b29249feab02a 100644
--- a/sklearn/tree/tests/test_tree.py
+++ b/sklearn/tree/tests/test_tree.py
@@ -1198,4 +1198,36 @@ def test_min_weight_leaf_split_level():
for name in ALL_TREES:
yield check_min_weight_leaf_split_level, name
-
+def check_public_apply(tree):
+ # tree_.apply does not check that data is of type float32, so we manually
+ # do it here.
+ X_small_32 = X_small.astype(np.float32, copy = True)
+ if tree in CLF_TREES.keys():
+ tree_class = CLF_TREES[tree]
+ clf = tree_class()
+ clf.fit(X_small_32, y_small)
+ else: #The tree is a regression a tree
+ tree_class = REG_TREES[tree]
+ clf = tree_class()
+ clf.fit(X_small_32, y_small_reg)
+
+ assert_array_equal(clf.apply(X_small_32), clf.tree_.apply(X_small_32))
+
+ for sparse_matrix in (csr_matrix, csc_matrix, coo_matrix):
+ X_small_sparse = sparse_matrix(X_small_32)
+ assert_array_equal(clf.apply(X_small_sparse), clf.tree_.apply(X_small_32))
+
+def test_public_apply():
+ """
+ Test that Tree.apply matches Tree.tree_.apply for sparse and dense inputs
+ """
+ for tree in ALL_TREES.iterkeys():
+ yield check_public_apply, tree
+
+def test_apply_valid():
+ """
+ Check that apply() raises error if preconditions not met.
+ """
+ clf = DecisionTreeClassifier()
+ X_sparse_small = csr_matrix(X_small)
+ assert_raises(ValueError, clf.apply, X_sparse_small)
diff --git a/sklearn/tree/tree.py b/sklearn/tree/tree.py
index 20245f27e5095..7e9b3070c484f 100644
--- a/sklearn/tree/tree.py
+++ b/sklearn/tree/tree.py
@@ -377,14 +377,25 @@ def apply(self, X):
Parameters
----------
- X: array_like, shape = (n_samples, n_features)
- Input Samples
+ X : array_like or sparse matrix, shape = [n_samples, n_features]
+ The input samples. Internally, it will be converted to
+ ``dtype=np.float32`` and if a sparse matrix is provided
+ to a sparse ``csr_matrix``.
Returns
-------
- X_leaves: array_like, shape = (n_samples,)
+ X_leaves : array_like, shape = [n_samples,]
+ For each datapoint x in X, return the index of the leaf x
+ ends up in.
"""
- X = check_array(X, dtype= DTYPE, accept_sparse="csr")
+ if self.tree_ is None:
+ raise ValueError("Estimator not fitted, "
+ "call `fit` before `apply`.")
+
+ X = check_array(X, dtype=DTYPE, accept_sparse="csr")
+ if issparse(X) and (X.indices.dtype != np.int32 or X.indptr.dtype != np.int32):
+ raise ValueError("No support for np.int64 index based "
+ "sparse matrices")
return self.tree_.apply(X)