Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions doc/whats_new/v0.21.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,13 @@ Changelog
`drop` parameter was not reflected in `get_feature_names`. :pr:`13894`
by :user:`James Myatt <jamesmyatt>`.

:mod:`sklearn.tree`
................................

- |Fix| Fixed an issue with :func:`plot_tree` where it displayed
entropy calculations even for the `gini` criterion in DecisionTreeClassifiers.
:pr:`13947` by :user:`Frank Hoang <fhoang7>`.

:mod:`sklearn.utils.sparsefuncs`
................................

Expand Down Expand Up @@ -965,8 +972,8 @@ Baibak, daten-kieker, Denis Kataev, Didi Bar-Zev, Dillon Gardner, Dmitry Mottl,
Dmitry Vukolov, Dougal J. Sutherland, Dowon, drewmjohnston, Dror Atariah,
Edward J Brown, Ekaterina Krivich, Elizabeth Sander, Emmanuel Arias, Eric
Chang, Eric Larson, Erich Schubert, esvhd, Falak, Feda Curic, Federico Caselli,
Fibinse Xavier`, Finn O'Shea, Gabriel Marzinotto, Gabriel Vacaliuc, Gabriele
Calvo, Gael Varoquaux, GauravAhlawat, Giuseppe Vettigli, Greg Gandenberger,
Frank Hoang, Fibinse Xavier, Finn O'Shea, Gabriel Marzinotto, Gabriel Vacaliuc,
Gabriele Calvo, Gael Varoquaux, GauravAhlawat, Giuseppe Vettigli, Greg Gandenberger,
Guillaume Fournier, Guillaume Lemaitre, Gustavo De Mari Pereira, Hanmin Qin,
haroldfox, hhu-luqi, Hunter McGushion, Ian Sanders, JackLangerman, Jacopo
Notarstefano, jakirkham, James Bourbeau, Jan Koch, Jan S, janvanrijn, Jarrod
Expand Down
11 changes: 6 additions & 5 deletions sklearn/tree/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,16 +547,16 @@ def __init__(self, max_depth=None, feature_names=None,

self.arrow_args = dict(arrowstyle="<-")

def _make_tree(self, node_id, et, depth=0):
def _make_tree(self, node_id, et, criterion, depth=0):
# traverses _tree.Tree recursively, builds intermediate
# "_reingold_tilford.Tree" object
name = self.node_to_str(et, node_id, criterion='entropy')
name = self.node_to_str(et, node_id, criterion=criterion)
if (et.children_left[node_id] != _tree.TREE_LEAF
and (self.max_depth is None or depth <= self.max_depth)):
children = [self._make_tree(et.children_left[node_id], et,
depth=depth + 1),
criterion, depth=depth + 1),
self._make_tree(et.children_right[node_id], et,
depth=depth + 1)]
criterion, depth=depth + 1)]
else:
return Tree(name, node_id)
return Tree(name, node_id, *children)
Expand All @@ -568,7 +568,8 @@ def export(self, decision_tree, ax=None):
ax = plt.gca()
ax.clear()
ax.set_axis_off()
my_tree = self._make_tree(0, decision_tree.tree_)
my_tree = self._make_tree(0, decision_tree.tree_,
decision_tree.criterion)
draw_tree = buchheim(my_tree)

# important to make sure we're still
Expand Down
27 changes: 23 additions & 4 deletions sklearn/tree/tests/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,20 +399,39 @@ def test_export_text():
assert export_text(reg, decimals=1, show_weights=True) == expected_report


def test_plot_tree(pyplot):
def test_plot_tree_entropy(pyplot):
# mostly smoke tests
# Check correctness of export_graphviz
# Check correctness of plot_tree for criterion = entropy
clf = DecisionTreeClassifier(max_depth=3,
min_samples_split=2,
criterion="gini",
criterion="entropy",
random_state=2)
clf.fit(X, y)

# Test export code
feature_names = ['first feat', 'sepal_width']
nodes = plot_tree(clf, feature_names=feature_names)
assert len(nodes) == 3
assert nodes[0].get_text() == ("first feat <= 0.0\nentropy = 0.5\n"
assert nodes[0].get_text() == ("first feat <= 0.0\nentropy = 1.0\n"
"samples = 6\nvalue = [3, 3]")
assert nodes[1].get_text() == "entropy = 0.0\nsamples = 3\nvalue = [3, 0]"
assert nodes[2].get_text() == "entropy = 0.0\nsamples = 3\nvalue = [0, 3]"


def test_plot_tree_gini(pyplot):
    # Smoke test for plot_tree with criterion="gini": the impurity line of
    # every rendered node must be labelled "gini".
    tree_clf = DecisionTreeClassifier(criterion="gini",
                                      max_depth=3,
                                      min_samples_split=2,
                                      random_state=2)
    tree_clf.fit(X, y)

    # plot_tree returns the drawn nodes; compare each node's text against
    # the expected label, in order.
    nodes = plot_tree(tree_clf, feature_names=['first feat', 'sepal_width'])
    expected_texts = [
        ("first feat <= 0.0\ngini = 0.5\n"
         "samples = 6\nvalue = [3, 3]"),
        "gini = 0.0\nsamples = 3\nvalue = [3, 0]",
        "gini = 0.0\nsamples = 3\nvalue = [0, 3]",
    ]

    assert len(nodes) == 3
    for node, expected in zip(nodes, expected_texts):
        assert node.get_text() == expected