Describe the bug
The March 20, 2023 commit "MAINT validate_params for plot_tree (#25882)" to scikit-learn/sklearn/tree/_export.py introduced parameter validation for the plot_tree function that does not agree with the documentation in the docstring or on the website. The validation omits the bool option described in the docstring, even though that option was previously accepted. Has bool been removed as a valid option, or is the parameter validation simply missing it?
Parameter validation in _export.py:
@validate_params(
    {
        ...
        "class_names": [list, None],
        ...
    }
)
Docstring of plot_tree:
class_names : list of str or bool, default=None
    Names of each of the target classes in ascending numerical order.
    Only relevant for classification and not supported for multi-output.
    If ``True``, shows a symbolic representation of the class name.
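To make the mismatch concrete, here is a minimal pure-Python sketch. `is_valid` is a hypothetical, simplified stand-in for scikit-learn's private constraint checking, not its actual implementation: a value passes if it matches any entry in the constraint list, where a type means an `isinstance` check and `None` means the value must be `None`. With the declared `[list, None]` constraint, `True` is rejected; adding `bool` (an assumption about what a fix could look like) lets it through.

```python
def is_valid(value, constraints):
    """Hypothetical, simplified check: True if `value` matches any constraint.

    A type in `constraints` means "isinstance(value, that_type)";
    None means "value is None". This is NOT scikit-learn's implementation.
    """
    for constraint in constraints:
        if constraint is None:
            if value is None:
                return True
        elif isinstance(value, constraint):
            return True
    return False


# Constraint as declared for class_names in the decorator above.
current = [list, None]
# Hypothetical constraint that also admits the documented bool option.
with_bool = [list, bool, None]

print(is_valid(["class_0", "class_1"], current))  # True
print(is_valid(None, current))                     # True
print(is_valid(True, current))                     # False: the bool option is rejected
print(is_valid(True, with_bool))                   # True
```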
Steps/Code to Reproduce
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree
import matplotlib.pyplot as plt

SEED = 42
data = datasets.load_wine()
X = data.data
y = data.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=SEED)

dt = DecisionTreeClassifier(max_depth=4, random_state=SEED)
dt.fit(X_train, y_train)

features = data.feature_names
classes = data.target_names.tolist()
plot_tree(dt, feature_names=features, class_names=classes)
plt.show()

# Works in 1.2.2, error in 1.3.0
plot_tree(dt, feature_names=features, class_names=True)
plt.show()
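Until the validation accepts bool again, one possible workaround is to pass an explicit list of placeholder names, which satisfies the `[list, None]` constraint. The sketch below assumes that `class_names=True` labelled classes as `y[0]`, `y[1]`, and so on (indexing the fitted classes); that label format is an assumption about the 1.2.x output, and `symbolic_class_names` is a hypothetical helper, not part of scikit-learn.

```python
def symbolic_class_names(n_classes):
    """Hypothetical helper: placeholder labels "y[0]", "y[1]", ... for plot_tree.

    Assumes the symbolic format used by class_names=True in scikit-learn 1.2.x;
    adjust the format string if your version rendered it differently.
    """
    # Always returns a plain list, which the [list, None] constraint accepts.
    return [f"y[{i}]" for i in range(n_classes)]


print(symbolic_class_names(3))  # ['y[0]', 'y[1]', 'y[2]']
```

With the reproduction script, it would be called as `plot_tree(dt, feature_names=features, class_names=symbolic_class_names(dt.n_classes_))`.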
Expected Results
No error is thrown.
Actual Results
Traceback (most recent call last):
File ~\Anaconda3\envs\py311\Lib\site-packages\spyder_kernels\py3compat.py:356 in compat_exec
exec(code, globals, locals)
File c:\temp\decisiontree.py:26
plot_tree(dt, feature_names=features, class_names=True)
File ~\Anaconda3\envs\py311\Lib\site-packages\sklearn\utils\_param_validation.py:201 in wrapper
validate_parameter_constraints(
File ~\Anaconda3\envs\py311\Lib\site-packages\sklearn\utils\_param_validation.py:95 in validate_parameter_constraints
raise InvalidParameterError(
InvalidParameterError: The 'class_names' parameter of plot_tree must be an instance of 'list' or None. Got True instead.
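The wording of this error follows directly from the declared constraint. The sketch below is a simplified, hypothetical reconstruction of the message text, not scikit-learn's code: it renders each constraint as "an instance of '<type>'" or "None", joins them with "or", and for `[list, None]` reproduces the message shown in the traceback.

```python
def describe(constraint):
    """Render one constraint the way the error message phrases it."""
    if constraint is None:
        return "None"
    return f"an instance of {constraint.__name__!r}"


def constraint_error(param_name, caller_name, value, constraints):
    """Hypothetical reconstruction of the InvalidParameterError message text."""
    parts = [describe(c) for c in constraints]
    if len(parts) == 1:
        joined = parts[0]
    else:
        # Comma-separate all but the last constraint, then join with "or".
        joined = ", ".join(parts[:-1]) + " or " + parts[-1]
    return (
        f"The {param_name!r} parameter of {caller_name} must be {joined}. "
        f"Got {value!r} instead."
    )


print(constraint_error("class_names", "plot_tree", True, [list, None]))
```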
Versions