tree.export_text: add class_names argument to mirror plot_tree #20576

Closed
gregglind opened this issue Jul 20, 2021 · 1 comment · Fixed by #25387

Comments

@gregglind

Describe the workflow you want to enable

Decision trees should have similar / identical options for all export formats.

Describe your proposed solution

Include the class_names argument.

Describe alternatives you've considered, if relevant

Overhauling the whole tree plotting and export system, to make very nice output easy to achieve. Right now each format depends on force layouts and matplotlib engines, and the output varies a lot.

Additional context

https://scikit-learn.org/stable/auto_examples/tree/plot_iris_dtc.html <- shows a very nice graph.
https://stackoverflow.com/questions/25274673/is-it-possible-to-print-the-decision-tree-in-scikit-learn <- printing the decision tree layout

@mokwilliam
Contributor

mokwilliam commented Jan 5, 2023

Hello,
I'd like to make my first contribution!

I've done some research and tried a few things to implement what you're asking for.
I might be wrong, but this issue may be related to #19824.

The goal is to add a class_names argument to the signature of the export_text function so that the user can choose the class name(s). I am therefore not changing the display format of the tree.

Explanation

Here is the setup for my test:

from sklearn.datasets import load_iris
from sklearn import tree

# Load iris dataset
iris = load_iris()

# Create the decision tree classifier
clf_DT = tree.DecisionTreeClassifier(random_state=0)

# Fit the classifier on the iris dataset
clf_DT = clf_DT.fit(iris.data, iris.target)

As a reminder, the current signature of the function is as follows (https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/tree/_export.py#L922):

def export_text(
    decision_tree,
    *,
    feature_names=None,
    max_depth=10,
    spacing=3,
    decimals=2,
    show_weights=False,
)

As expected, passing the class_names argument currently raises the following error:

TypeError: export_text() got an unexpected keyword argument 'class_names'

Without it, we get:

feature_names = iris['feature_names']
r = tree.export_text(clf_DT, feature_names=feature_names, max_depth=2)
print(r)
|--- petal width (cm) <= 0.80
|   |--- class: 0
|--- petal width (cm) >  0.80
|   |--- petal width (cm) <= 1.75
|   |   |--- petal length (cm) <= 4.95
|   |   |   |--- truncated branch of depth 2
|   |   |--- petal length (cm) >  4.95
|   |   |   |--- truncated branch of depth 3
|   |--- petal width (cm) >  1.75
|   |   |--- petal length (cm) <= 4.85
|   |   |   |--- truncated branch of depth 2
|   |   |--- petal length (cm) >  4.85
|   |   |   |--- class: 2
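
Until export_text accepts class_names, one possible workaround is to post-process the report string. This is only a minimal sketch and not part of scikit-learn: rename_classes is a hypothetical helper, and it assumes each leaf line ends in "class: <label>" with a label that contains no spaces, as in the output above.

```python
import re


def rename_classes(report, names):
    """Replace `class: <label>` leaf labels in an export_text-style report.

    `names` maps the label as printed (a string such as "0") to a display
    name. Labels missing from `names` are left unchanged.
    """
    # Match the label at the end of a line, whatever precedes "class: ".
    pattern = re.compile(r"(class: )(\S+)$", flags=re.MULTILINE)

    def _swap(match):
        label = match.group(2)
        return match.group(1) + names.get(label, label)

    return pattern.sub(_swap, report)


# A shortened report, written out by hand so the example needs no fitted tree.
report = (
    "|--- petal width (cm) <= 0.80\n"
    "|   |--- class: 0\n"
    "|--- petal width (cm) >  0.80\n"
    "|   |--- class: 2\n"
)
renamed = rename_classes(report, {"0": "setosa", "2": "virginica"})
print(renamed)
```

This kind of string rewriting is fragile, which is part of why a real class_names argument is preferable.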

After changing the source code (see "Changes in the code" below), we get:

class_names0 = None
# Falls back to decision_tree.classes_

class_names1 = [0, "1"]
# Falls back to decision_tree.classes_
# Reason: len(class_names1) != decision_tree.n_classes_

class_names2 = ["class0", "class1", "class2"]
# Used as given
# Reason: len(class_names2) == decision_tree.n_classes_

c_names = [class_names0, class_names1, class_names2]

print("===> After change <===")
for class_name in c_names:
    r = tree.export_text(clf_DT, class_names=class_name, feature_names=feature_names, max_depth=2)
    print(r)
===> After change <===
|--- petal width (cm) <= 0.80
|   |--- class: 0
|--- petal width (cm) >  0.80
|   |--- petal width (cm) <= 1.75
|   |   |--- petal length (cm) <= 4.95
|   |   |   |--- truncated branch of depth 2
|   |   |--- petal length (cm) >  4.95
|   |   |   |--- truncated branch of depth 3
|   |--- petal width (cm) >  1.75
|   |   |--- petal length (cm) <= 4.85
|   |   |   |--- truncated branch of depth 2
|   |   |--- petal length (cm) >  4.85
|   |   |   |--- class: 2

|--- petal width (cm) <= 0.80
|   |--- class: 0
|--- petal width (cm) >  0.80
|   |--- petal width (cm) <= 1.75
|   |   |--- petal length (cm) <= 4.95
|   |   |   |--- truncated branch of depth 2
|   |   |--- petal length (cm) >  4.95
|   |   |   |--- truncated branch of depth 3
|   |--- petal width (cm) >  1.75
|   |   |--- petal length (cm) <= 4.85
|   |   |   |--- truncated branch of depth 2
|   |   |--- petal length (cm) >  4.85
|   |   |   |--- class: 2

|--- petal width (cm) <= 0.80
|   |--- class: class0
|--- petal width (cm) >  0.80
|   |--- petal width (cm) <= 1.75
|   |   |--- petal length (cm) <= 4.95
|   |   |   |--- truncated branch of depth 2
|   |   |--- petal length (cm) >  4.95
|   |   |   |--- truncated branch of depth 3
|   |--- petal width (cm) >  1.75
|   |   |--- petal length (cm) <= 4.85
|   |   |   |--- truncated branch of depth 2
|   |   |--- petal length (cm) >  4.85
|   |   |   |--- class: class2

I hope this explanation is clear. I can adjust the change if needed.
If it looks fine, I can open a PR.

Changes in the code

Before change

def export_text(
    decision_tree,
    *,
    feature_names=None,
    max_depth=10,
    spacing=3,
    decimals=2,
    show_weights=False,
):
    check_is_fitted(decision_tree)
    tree_ = decision_tree.tree_
    if is_classifier(decision_tree):
        class_names = decision_tree.classes_

After change

def export_text(
    decision_tree,
    *,
    feature_names=None,
    class_names=None,  # add class_names argument
    max_depth=10,
    spacing=3,
    decimals=2,
    show_weights=False,
):
    """
    ...
    class_names : list of str, default=None
        Names of each of the target classes in ascending numerical order.
        Only relevant for classification and not supported for multi-output.
    ...
    """
    check_is_fitted(decision_tree)
    tree_ = decision_tree.tree_
    if is_classifier(decision_tree):
        # Use class_names only when it has one entry per class of the tree;
        # otherwise fall back to decision_tree.classes_.
        if class_names is None or len(class_names) != len(decision_tree.classes_):
            class_names = decision_tree.classes_
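
The fallback rule above can be tried without a fitted tree. The sketch below is an illustration only: resolve_class_names is a hypothetical helper, not a scikit-learn function, and the strict flag is an assumption added here to show one alternative design, namely raising on a length mismatch instead of silently falling back.

```python
def resolve_class_names(class_names, fitted_classes, strict=False):
    """Pick the labels to print for each class.

    `fitted_classes` stands in for decision_tree.classes_.
    `strict` is a hypothetical flag, not part of the proposal above.
    """
    if class_names is None:
        return list(fitted_classes)
    if len(class_names) != len(fitted_classes):
        if strict:
            raise ValueError(
                f"class_names has {len(class_names)} entries, "
                f"but the tree has {len(fitted_classes)} classes"
            )
        # Proposed behaviour: ignore the mismatched list.
        return list(fitted_classes)
    return list(class_names)


fitted = [0, 1, 2]
print(resolve_class_names(None, fitted))
print(resolve_class_names([0, "1"], fitted))
print(resolve_class_names(["class0", "class1", "class2"], fitted))
```

The silent fallback keeps the call from failing, but a list of the wrong length goes unnoticed; strict=True makes that mistake visible at the cost of an exception.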
