Skip to content
1 change: 1 addition & 0 deletions sklearn/tests/test_public_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ def _check_function_param_validation(
"sklearn.random_projection.johnson_lindenstrauss_min_dim",
"sklearn.svm.l1_min_c",
"sklearn.tree.export_text",
"sklearn.tree.plot_tree",
"sklearn.utils.gen_batches",
]

Expand Down
33 changes: 18 additions & 15 deletions sklearn/tree/_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import numpy as np

from ..utils.validation import check_is_fitted
from ..utils._param_validation import Interval, validate_params
from ..utils._param_validation import Interval, validate_params, StrOptions

from ..base import is_classifier

Expand Down Expand Up @@ -77,6 +77,23 @@ def __repr__(self):
SENTINEL = Sentinel()


@validate_params(
{
"decision_tree": [DecisionTreeClassifier, DecisionTreeRegressor],
"max_depth": [Interval(Integral, 0, None, closed="left"), None],
"feature_names": [list, None],
"class_names": [list, None],
"label": [StrOptions({"all", "root", "none"})],
"filled": ["boolean"],
"impurity": ["boolean"],
"node_ids": ["boolean"],
"proportion": ["boolean"],
"rounded": ["boolean"],
"precision": [Interval(Integral, 0, None, closed="left"), None],
"ax": "no_validation", # delegate validation to matplotlib
"fontsize": [Interval(Integral, 0, None, closed="left"), None],
}
)
def plot_tree(
decision_tree,
*,
Expand Down Expand Up @@ -601,20 +618,6 @@ def __init__(
)
self.fontsize = fontsize

# validate
if isinstance(precision, Integral):
if precision < 0:
raise ValueError(
"'precision' should be greater or equal to 0."
" Got {} instead.".format(precision)
)
else:
raise ValueError(
"'precision' should be an integer. Got {} instead.".format(
type(precision)
)
)

# The depth of each node for plotting with 'leaf' option
self.ranks = {"leaves": []}
# The colors to render each node with
Expand Down