tree.export_text: add class_names argument to mirror plot_tree #20576
Hello, I've done some research and tried a few things in order to achieve what you ask for. The goal is to provide the `class_names` argument.

Explanation

Here's the initialization of my test:

```python
from sklearn.datasets import load_iris
from sklearn import tree

# Load the iris dataset
iris = load_iris()

# Create a decision tree classifier
clf_DT = tree.DecisionTreeClassifier(random_state=0)

# Build the decision tree classifier from the iris dataset
clf_DT = clf_DT.fit(iris.data, iris.target)
```

As a reminder, the signature of the function is as follows (https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/tree/_export.py#L922):

```python
def export_text(
    decision_tree,
    *,
    feature_names=None,
    max_depth=10,
    spacing=3,
    decimals=2,
    show_weights=False,
):
```

Of course, if I try to pass the `class_names` argument to the current function, the call fails, since that parameter does not exist yet.
Without it, we have:

```python
feature_names = iris['feature_names']
r = tree.export_text(clf_DT, feature_names=feature_names, max_depth=2)
print(r)
```
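As an aside, the failure mode of passing a keyword that a function's signature does not declare can be shown without scikit-learn at all; `export_like` below is a hypothetical stand-in with the same keyword-only style as `export_text`, not library code:

```python
def export_like(decision_tree, *, feature_names=None, max_depth=10):
    """Hypothetical stand-in with keyword-only arguments, like export_text."""
    return (feature_names, max_depth)

# A declared keyword works fine.
print(export_like("tree", feature_names=["a", "b"]))  # (['a', 'b'], 10)

# An undeclared keyword raises TypeError before the body even runs.
try:
    export_like("tree", class_names=["x", "y"])
except TypeError as exc:
    print(type(exc).__name__)  # TypeError
```

This is why the argument has to be added to the signature first: Python rejects unexpected keyword arguments at call time.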
After changing the source code (for details, see "Changes in the code" below), we have:

```python
# class_names is None: fall back to decision_tree.classes_
class_names0 = None

# Length mismatch (len(class_names1) != decision_tree.n_classes_):
# fall back to decision_tree.classes_
class_names1 = [0, "1"]

# Length matches decision_tree.n_classes_: use the provided names
class_names2 = ["class0", "class1", "class2"]

c_names = [class_names0, class_names1, class_names2]

print("===> After change <===")
for i, class_name in enumerate(c_names):
    r = tree.export_text(clf_DT, class_names=class_name,
                         feature_names=feature_names, max_depth=2)
    print(r)
```
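The fallback rule that the three scenarios above exercise can be sketched in isolation. `resolve_class_names` below is a hypothetical helper, not part of scikit-learn, assuming `classes_` stands in for the fitted classifier's `decision_tree.classes_`:

```python
def resolve_class_names(class_names, classes_):
    """Mimic the proposed fallback: use class_names only when its
    length matches the number of fitted classes."""
    if class_names is not None and len(class_names) == len(classes_):
        return list(class_names)
    # Default: keep the labels stored on the fitted estimator.
    return list(classes_)

classes_ = [0, 1, 2]  # what clf_DT.classes_ holds for iris

print(resolve_class_names(None, classes_))        # [0, 1, 2]
print(resolve_class_names([0, "1"], classes_))    # [0, 1, 2] (length mismatch)
print(resolve_class_names(["class0", "class1", "class2"], classes_))
# ['class0', 'class1', 'class2']
```

All three calls terminate with a list whose length equals the number of classes, so the downstream report code never sees a mismatched name list.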
I hope this message was clear enough to understand. I can make changes if needed to fine-tune.

Changes in the code

Before the change:

```python
def export_text(
    decision_tree,
    *,
    feature_names=None,
    max_depth=10,
    spacing=3,
    decimals=2,
    show_weights=False,
):
    check_is_fitted(decision_tree)
    tree_ = decision_tree.tree_

    if is_classifier(decision_tree):
        class_names = decision_tree.classes_
```

After the change:

```python
def export_text(
    decision_tree,
    *,
    feature_names=None,
    class_names=None,  # new class_names argument
    max_depth=10,
    spacing=3,
    decimals=2,
    show_weights=False,
):
    """
    ...
    class_names : list of str, default=None
        Names of each of the target classes in ascending numerical order.
        Only relevant for classification and not supported for multi-output.
    ...
    """
    check_is_fitted(decision_tree)
    tree_ = decision_tree.tree_

    if is_classifier(decision_tree):
        # The length of class_names must equal the number of classes in the
        # tree; otherwise fall back to decision_tree.classes_ (the default).
        if class_names is None or len(class_names) != len(decision_tree.classes_):
            class_names = decision_tree.classes_
```
Describe the workflow you want to enable
Decision trees should have similar / identical options for all export formats.
Describe your proposed solution
Include the class_names argument.

Describe alternatives you've considered, if relevant
Overhauling the whole tree plotting and export system, to make very nice output easy to achieve. Right now each format is dependent on force layouts, matplotlib engines, and varies a lot.

Additional context
https://scikit-learn.org/stable/auto_examples/tree/plot_iris_dtc.html <- shows a very nice graph.
https://stackoverflow.com/questions/25274673/is-it-possible-to-print-the-decision-tree-in-scikit-learn <- printing the decision tree layout