
MAINT validate_params for plot_tree (#25882) #27085


Closed
lendres opened this issue Aug 17, 2023 · 1 comment
Labels: Bug, Needs Triage


lendres commented Aug 17, 2023

Describe the bug

The March 20, 2023 commit "MAINT validate_params for plot_tree (#25882)" to the file
scikit-learn/sklearn/tree/_export.py
introduced parameter validation for the plot_tree function that does not agree with the docstring or the website documentation. The validation omits the bool option described in the help, which was previously permitted. Has bool been removed as a valid option, or is the parameter validation missing it?

  @validate_params(
      {
  ...
          "class_names": [list, None],
  ...
      }
  )

class_names : list of str or bool, default=None
    Names of each of the target classes in ascending numerical order.
    Only relevant for classification and not supported for multi-output.
    If ``True``, shows a symbolic representation of the class name.
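
To make the mismatch concrete, here is a minimal standalone sketch of a type-list check. It is not scikit-learn's implementation: `is_allowed`, `REPORTED_CONSTRAINT`, and `DOCSTRING_CONSTRAINT` are hypothetical names used only for illustration, and the second constraint is merely what the docstring would suggest, not a confirmed fix.

```python
def is_allowed(value, constraints):
    """Return True if value matches any entry in constraints.

    Each entry is either a type (checked with isinstance) or the
    literal None (matched only by the value None). This is a simplified
    stand-in for a parameter validator, not scikit-learn's code.
    """
    for constraint in constraints:
        if constraint is None:
            if value is None:
                return True
        elif isinstance(value, constraint):
            return True
    return False


# Constraint quoted in this report for class_names.
REPORTED_CONSTRAINT = [list, None]

# Hypothetical constraint that would also cover the documented bool option.
DOCSTRING_CONSTRAINT = [list, bool, None]

print(is_allowed(["a", "b"], REPORTED_CONSTRAINT))  # list is accepted
print(is_allowed(True, REPORTED_CONSTRAINT))        # bool is rejected
print(is_allowed(True, DOCSTRING_CONSTRAINT))       # bool is accepted
```

Under this simplified model, `True` fails the reported constraint simply because `bool` is not in the list, which matches the error shown under Actual Results.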

Steps/Code to Reproduce

from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt

SEED = 42

data = datasets.load_wine()
X = data.data
y = data.target

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=SEED)

dt = DecisionTreeClassifier(max_depth=4, random_state=SEED)
dt.fit(X_train, y_train)

features = data.feature_names
classes = data.target_names.tolist()


plot_tree(dt, feature_names=features, class_names=classes)
plt.show()

# Works in 1.2.2, error in 1.3.0
plot_tree(dt, feature_names=features, class_names=True)
plt.show()

Expected Results

No error is thrown.

Actual Results

 Traceback (most recent call last):

  File ~\Anaconda3\envs\py311\Lib\site-packages\spyder_kernels\py3compat.py:356 in compat_exec
    exec(code, globals, locals)

  File c:\temp\decisiontree.py:26
    plot_tree(dt, feature_names=features, class_names=True)

  File ~\Anaconda3\envs\py311\Lib\site-packages\sklearn\utils\_param_validation.py:201 in wrapper
    validate_parameter_constraints(

  File ~\Anaconda3\envs\py311\Lib\site-packages\sklearn\utils\_param_validation.py:95 in validate_parameter_constraints
    raise InvalidParameterError(

InvalidParameterError: The 'class_names' parameter of plot_tree must be an instance of 'list' or None. Got True instead.
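
Until a fixed release is installed, one possible workaround is to fall back to an explicit list of names when `class_names=True` is rejected. The sketch below is an assumption-laden illustration: `symbolic_class_names` and `plot_with_fallback` are hypothetical helpers, the `y[i]` label format is my assumption about what the symbolic representation looks like, and the broad `except` relies on `InvalidParameterError` being a subclass of `ValueError` and `TypeError`.

```python
def symbolic_class_names(n_classes):
    """Build placeholder labels such as 'y[0]', 'y[1]', ... (assumed format)."""
    return [f"y[{i}]" for i in range(n_classes)]


def plot_with_fallback(plot_fn, estimator, n_classes, **kwargs):
    """Call plot_fn with class_names=True, retrying with a list if rejected.

    plot_fn is expected to be something like sklearn.tree.plot_tree.
    Catching ValueError/TypeError is deliberately broad and could hide
    unrelated errors, so treat this as a stopgap only.
    """
    try:
        return plot_fn(estimator, class_names=True, **kwargs)
    except (TypeError, ValueError):
        names = symbolic_class_names(n_classes)
        return plot_fn(estimator, class_names=names, **kwargs)
```

With the reproduction script above, this might be called as `plot_with_fallback(plot_tree, dt, dt.n_classes_, feature_names=features)`.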

Versions

import sklearn; sklearn.show_versions()

System:
    python: 3.9.17 (main, Jul  5 2023, 21:22:06) [MSC v.1916 64 bit (AMD64)]
executable: C:\Users\lance.endres\Anaconda3\python.exe
   machine: Windows-10-10.0.19044-SP0

Python dependencies:
      sklearn: 1.2.2
          pip: 23.2.1
   setuptools: 68.0.0
        numpy: 1.21.5
        scipy: 1.10.1
       Cython: 0.29.32
       pandas: 1.5.3
   matplotlib: 3.7.1
       joblib: 1.2.0
threadpoolctl: 2.2.0

Built with OpenMP: True

threadpoolctl info:
       filepath: C:\Users\lance.endres\Anaconda3\Library\bin\mkl_rt.1.dll
         prefix: mkl_rt
       user_api: blas
   internal_api: mkl
        version: 2021.4-Product
    num_threads: 6
threading_layer: intel

       filepath: C:\Users\lance.endres\Anaconda3\vcomp140.dll
         prefix: vcomp
       user_api: openmp
   internal_api: openmp
        version: None
    num_threads: 12

       filepath: C:\Users\lance.endres\Anaconda3\Library\bin\libiomp5md.dll
         prefix: libiomp
       user_api: openmp
   internal_api: openmp
        version: None
    num_threads: 12

lendres added the Bug and Needs Triage labels Aug 17, 2023

@adrinjalali
Member

Already fixed in #26903.
