Skip to content

MAINT Parameter validation for tree.export_text #25867

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
bbda940
MAINT Parameter Validation for linear_model.orthogonal_mp
choudharynishu Mar 10, 2023
dee4ebc
Fixed tol parameter validation dictionary
choudharynishu Mar 10, 2023
0869598
Fixed n_nonzero_coefs parameter validation
choudharynishu Mar 10, 2023
946801b
Edited linear_model orthogonal file
choudharynishu Mar 10, 2023
de0c605
"merge main"
choudharynishu Mar 10, 2023
126c030
Edited tol parameter validation
choudharynishu Mar 10, 2023
1730480
Edited range for tol parameter in linear_model.orthogonal_mp
choudharynishu Mar 10, 2023
c960105
Added sklearn.linear_model.orthogonal_mp to test public functions list
choudharynishu Mar 10, 2023
df75106
Merge branch 'main' of github.com:scikit-learn/scikit-learn into para…
choudharynishu Mar 13, 2023
3e61c9a
validation for linear_model.orthogonal_mp changed X from ndarray to a…
choudharynishu Mar 13, 2023
ea7de7b
Merge branch 'main' into param_validation_linearmodel
choudharynishu Mar 14, 2023
3698c55
Merge branch 'main' into param_validation_linearmodel
choudharynishu Mar 14, 2023
cc70faa
Merge branch 'main' of github.com:scikit-learn/scikit-learn into para…
choudharynishu Mar 14, 2023
98ea492
Merge branch 'main' of github.com:scikit-learn/scikit-learn into para…
choudharynishu Mar 14, 2023
f868af8
Merge branch 'main' of github.com:scikit-learn/scikit-learn into para…
choudharynishu Mar 14, 2023
e2005ae
Removed outdated validation for 'tol' and 'n_nonzero_coefs' in linear…
choudharynishu Mar 14, 2023
6796ebf
Merge branch 'main' of github.com:scikit-learn/scikit-learn into para…
choudharynishu Mar 14, 2023
65aa5a5
Merge branch 'main' of github.com:scikit-learn/scikit-learn into para…
choudharynishu Mar 15, 2023
57ac873
added parameter validation for sklearn.tree.export_text & added to li…
choudharynishu Mar 15, 2023
463b338
Merge branch 'main' of github.com:scikit-learn/scikit-learn into para…
choudharynishu Mar 15, 2023
dc604ff
merging changes from upstream main
choudharynishu Mar 16, 2023
24ae386
Merge branch 'main' of github.com:scikit-learn/scikit-learn into para…
choudharynishu Mar 16, 2023
8a16441
Fixed incorrect ranges of max_depth, spacing, and decimal & Edited er…
choudharynishu Mar 16, 2023
0063139
Merge branch 'main' of github.com:scikit-learn/scikit-learn into para…
choudharynishu Mar 16, 2023
bb65f2f
Merge branch 'main' of github.com:scikit-learn/scikit-learn into para…
choudharynishu Mar 16, 2023
62c7300
Fixed max_depth, spacing, and decimals values
choudharynishu Mar 16, 2023
8ecf919
Removed outdates validation from export_text and corresponding test f…
choudharynishu Mar 16, 2023
17db14c
added class_names
choudharynishu Mar 16, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sklearn/tests/test_public_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ def _check_function_param_validation(
"sklearn.model_selection.train_test_split",
"sklearn.random_projection.johnson_lindenstrauss_min_dim",
"sklearn.svm.l1_min_c",
"sklearn.tree.export_text",
"sklearn.utils.gen_batches",
]

Expand Down
24 changes: 14 additions & 10 deletions sklearn/tree/_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
import numpy as np

from ..utils.validation import check_is_fitted
from ..utils._param_validation import Interval, validate_params

from ..base import is_classifier

from . import _criterion
from . import _tree
from ._reingold_tilford import buchheim, Tree
from . import DecisionTreeClassifier
from . import DecisionTreeClassifier, DecisionTreeRegressor


def _color_brew(n):
Expand Down Expand Up @@ -919,6 +921,17 @@ def compute_depth_(
return max(depths)


@validate_params(
{
"decision_tree": [DecisionTreeClassifier, DecisionTreeRegressor],
"feature_names": [list, None],
"class_names": [list, None],
"max_depth": [Interval(Integral, 0, None, closed="left"), None],
"spacing": [Interval(Integral, 1, None, closed="left"), None],
"decimals": [Interval(Integral, 0, None, closed="left"), None],
"show_weights": ["boolean"],
}
)
def export_text(
decision_tree,
*,
Expand Down Expand Up @@ -1011,21 +1024,12 @@ def export_text(
left_child_fmt = "{} {} > {}\n"
truncation_fmt = "{} {}\n"

if max_depth < 0:
raise ValueError("max_depth bust be >= 0, given %d" % max_depth)

if feature_names is not None and len(feature_names) != tree_.n_features:
raise ValueError(
"feature_names must contain %d elements, got %d"
% (tree_.n_features, len(feature_names))
)

if spacing <= 0:
raise ValueError("spacing must be > 0, given %d" % spacing)

if decimals < 0:
raise ValueError("decimals must be >= 0, given %d" % decimals)

if isinstance(decision_tree, DecisionTreeClassifier):
value_fmt = "{}{} weights: {}\n"
if not show_weights:
Expand Down
10 changes: 0 additions & 10 deletions sklearn/tree/tests/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,10 +350,6 @@ def test_precision():
def test_export_text_errors():
clf = DecisionTreeClassifier(max_depth=2, random_state=0)
clf.fit(X, y)

err_msg = "max_depth bust be >= 0, given -1"
with pytest.raises(ValueError, match=err_msg):
export_text(clf, max_depth=-1)
err_msg = "feature_names must contain 2 elements, got 1"
with pytest.raises(ValueError, match=err_msg):
export_text(clf, feature_names=["a"])
Expand All @@ -364,12 +360,6 @@ def test_export_text_errors():
)
with pytest.raises(ValueError, match=err_msg):
export_text(clf, class_names=["a"])
err_msg = "decimals must be >= 0, given -1"
with pytest.raises(ValueError, match=err_msg):
export_text(clf, decimals=-1)
err_msg = "spacing must be > 0, given 0"
with pytest.raises(ValueError, match=err_msg):
export_text(clf, spacing=0)


def test_export_text():
Expand Down