diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 2fa93fdfb6adf..7477d614da4c7 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -191,6 +191,7 @@ def _check_function_param_validation( "sklearn.model_selection.train_test_split", "sklearn.random_projection.johnson_lindenstrauss_min_dim", "sklearn.svm.l1_min_c", + "sklearn.tree.export_text", "sklearn.utils.gen_batches", ] diff --git a/sklearn/tree/_export.py b/sklearn/tree/_export.py index 3e65c4a2b0dc5..c7bc0dd04c08f 100644 --- a/sklearn/tree/_export.py +++ b/sklearn/tree/_export.py @@ -17,12 +17,14 @@ import numpy as np from ..utils.validation import check_is_fitted +from ..utils._param_validation import Interval, validate_params + from ..base import is_classifier from . import _criterion from . import _tree from ._reingold_tilford import buchheim, Tree -from . import DecisionTreeClassifier +from . import DecisionTreeClassifier, DecisionTreeRegressor def _color_brew(n): @@ -919,6 +921,17 @@ def compute_depth_( return max(depths) +@validate_params( + { + "decision_tree": [DecisionTreeClassifier, DecisionTreeRegressor], + "feature_names": [list, None], + "class_names": [list, None], + "max_depth": [Interval(Integral, 0, None, closed="left"), None], + "spacing": [Interval(Integral, 1, None, closed="left"), None], + "decimals": [Interval(Integral, 0, None, closed="left"), None], + "show_weights": ["boolean"], + } +) def export_text( decision_tree, *, @@ -1011,21 +1024,12 @@ def export_text( left_child_fmt = "{} {} > {}\n" truncation_fmt = "{} {}\n" - if max_depth < 0: - raise ValueError("max_depth bust be >= 0, given %d" % max_depth) - if feature_names is not None and len(feature_names) != tree_.n_features: raise ValueError( "feature_names must contain %d elements, got %d" % (tree_.n_features, len(feature_names)) ) - if spacing <= 0: - raise ValueError("spacing must be > 0, given %d" % spacing) - - if decimals < 0: - raise ValueError("decimals must be >= 0, given %d" % decimals) - if isinstance(decision_tree, DecisionTreeClassifier): value_fmt = "{}{} weights: {}\n" if not show_weights: diff --git a/sklearn/tree/tests/test_export.py b/sklearn/tree/tests/test_export.py index 8865cb724a02a..5b4d581951cac 100644 --- a/sklearn/tree/tests/test_export.py +++ b/sklearn/tree/tests/test_export.py @@ -350,10 +350,6 @@ def test_precision(): def test_export_text_errors(): clf = DecisionTreeClassifier(max_depth=2, random_state=0) clf.fit(X, y) - - err_msg = "max_depth bust be >= 0, given -1" - with pytest.raises(ValueError, match=err_msg): - export_text(clf, max_depth=-1) err_msg = "feature_names must contain 2 elements, got 1" with pytest.raises(ValueError, match=err_msg): export_text(clf, feature_names=["a"]) @@ -364,12 +360,6 @@ def test_export_text_errors(): ) with pytest.raises(ValueError, match=err_msg): export_text(clf, class_names=["a"]) - err_msg = "decimals must be >= 0, given -1" - with pytest.raises(ValueError, match=err_msg): - export_text(clf, decimals=-1) - err_msg = "spacing must be > 0, given 0" - with pytest.raises(ValueError, match=err_msg): - export_text(clf, spacing=0) def test_export_text():