From b510e4404007eea9b891a07ecda2069c4808abbf Mon Sep 17 00:00:00 2001 From: jeremiedbb Date: Thu, 9 Mar 2023 19:11:13 +0100 Subject: [PATCH 1/4] ensure disjoint interval constraints --- sklearn/cluster/_optics.py | 14 +- sklearn/cluster/tests/test_optics.py | 2 +- sklearn/decomposition/_pca.py | 2 +- sklearn/ensemble/_bagging.py | 6 +- sklearn/ensemble/_forest.py | 2 +- .../gradient_boosting.py | 2 +- sklearn/ensemble/_iforest.py | 2 +- sklearn/feature_extraction/image.py | 6 +- sklearn/feature_extraction/text.py | 6 +- sklearn/feature_selection/_rfe.py | 6 +- sklearn/feature_selection/_sequential.py | 2 +- sklearn/linear_model/_ransac.py | 2 +- sklearn/model_selection/_split.py | 4 +- sklearn/preprocessing/_encoders.py | 4 +- sklearn/tests/test_public_functions.py | 18 ++- sklearn/utils/_param_validation.py | 138 ++++-------------- sklearn/utils/estimator_checks.py | 20 ++- sklearn/utils/tests/test_param_validation.py | 52 +++---- 18 files changed, 114 insertions(+), 174 deletions(-) diff --git a/sklearn/cluster/_optics.py b/sklearn/cluster/_optics.py index fb8daa4db1226..0d12832f15b64 100755 --- a/sklearn/cluster/_optics.py +++ b/sklearn/cluster/_optics.py @@ -233,7 +233,7 @@ class OPTICS(ClusterMixin, BaseEstimator): _parameter_constraints: dict = { "min_samples": [ Interval(Integral, 2, None, closed="left"), - Interval(Real, 0, 1, closed="both"), + Interval("real_not_int", 0, 1, closed="both"), ], "max_eps": [Interval(Real, 0, None, closed="both")], "metric": [StrOptions(set(_VALID_METRICS) | {"precomputed"}), callable], @@ -245,7 +245,7 @@ class OPTICS(ClusterMixin, BaseEstimator): "predecessor_correction": ["boolean"], "min_cluster_size": [ Interval(Integral, 2, None, closed="left"), - Interval(Real, 0, 1, closed="right"), + Interval("real_not_int", 0, 1, closed="right"), None, ], "algorithm": [StrOptions({"auto", "brute", "ball_tree", "kd_tree"})], @@ -431,7 +431,7 @@ def _compute_core_distances_(X, neighbors, min_samples, working_memory): "X": [np.ndarray, "sparse matrix"], "min_samples": [ Interval(Integral, 2, None, closed="left"), - Interval(Real, 0, 1, closed="both"), + Interval("real_not_int", 0, 1, closed="both"), ], "max_eps": [Interval(Real, 0, None, closed="both")], "metric": [StrOptions(set(_VALID_METRICS) | {"precomputed"}), callable], @@ -723,12 +723,12 @@ def cluster_optics_dbscan(*, reachability, core_distances, ordering, eps): "predecessor": [np.ndarray], "ordering": [np.ndarray], "min_samples": [ - Interval(Integral, 1, None, closed="neither"), - Interval(Real, 0, 1, closed="both"), + Interval(Integral, 2, None, closed="left"), + Interval("real_not_int", 0, 1, closed="both"), ], "min_cluster_size": [ - Interval(Integral, 1, None, closed="neither"), - Interval(Real, 0, 1, closed="both"), + Interval(Integral, 2, None, closed="left"), + Interval("real_not_int", 0, 1, closed="both"), None, ], "xi": [Interval(Real, 0, 1, closed="both")], diff --git a/sklearn/cluster/tests/test_optics.py b/sklearn/cluster/tests/test_optics.py index 48ebd11751ef3..38c92f855b746 100644 --- a/sklearn/cluster/tests/test_optics.py +++ b/sklearn/cluster/tests/test_optics.py @@ -197,7 +197,7 @@ def test_minimum_number_of_sample_check(): # Compute OPTICS X = [[1, 1]] - clust = OPTICS(max_eps=5.0 * 0.3, min_samples=10, min_cluster_size=1) + clust = OPTICS(max_eps=5.0 * 0.3, min_samples=10, min_cluster_size=1.0) # Run the fit with pytest.raises(ValueError, match=msg): diff --git a/sklearn/decomposition/_pca.py b/sklearn/decomposition/_pca.py index 789b38c9cd455..3ee3ba4d36bba 100644 --- a/sklearn/decomposition/_pca.py +++ b/sklearn/decomposition/_pca.py @@ -363,7 +363,7 @@ class PCA(_BasePCA): _parameter_constraints: dict = { "n_components": [ Interval(Integral, 0, None, closed="left"), - Interval(Real, 0, 1, closed="neither"), + Interval("real_not_int", 0, 1, closed="neither"), StrOptions({"mle"}), None, ], diff --git a/sklearn/ensemble/_bagging.py b/sklearn/ensemble/_bagging.py index d10f89102ea82..a9a93d78152e6 100644 --- a/sklearn/ensemble/_bagging.py +++ b/sklearn/ensemble/_bagging.py @@ -8,7 +8,7 @@ import numbers import numpy as np from abc import ABCMeta, abstractmethod -from numbers import Integral, Real +from numbers import Integral from warnings import warn from functools import partial @@ -248,11 +248,11 @@ class BaseBagging(BaseEnsemble, metaclass=ABCMeta): "n_estimators": [Interval(Integral, 1, None, closed="left")], "max_samples": [ Interval(Integral, 1, None, closed="left"), - Interval(Real, 0, 1, closed="right"), + Interval("real_not_int", 0, 1, closed="right"), ], "max_features": [ Interval(Integral, 1, None, closed="left"), - Interval(Real, 0, 1, closed="right"), + Interval("real_not_int", 0, 1, closed="right"), ], "bootstrap": ["boolean"], "bootstrap_features": ["boolean"], diff --git a/sklearn/ensemble/_forest.py b/sklearn/ensemble/_forest.py index c60a7177133ec..09fb2e547468f 100644 --- a/sklearn/ensemble/_forest.py +++ b/sklearn/ensemble/_forest.py @@ -206,7 +206,7 @@ class BaseForest(MultiOutputMixin, BaseEnsemble, metaclass=ABCMeta): "warm_start": ["boolean"], "max_samples": [ None, - Interval(Real, 0.0, 1.0, closed="right"), + Interval("real_not_int", 0.0, 1.0, closed="right"), Interval(Integral, 1, None, closed="left"), ], } diff --git a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py index 31069fe14ee41..b5d254333724f 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py @@ -103,7 +103,7 @@ class BaseHistGradientBoosting(BaseEstimator, ABC): ], "n_iter_no_change": [Interval(Integral, 1, None, closed="left")], "validation_fraction": [ - Interval(Real, 0, 1, closed="neither"), + Interval("real_not_int", 0, 1, closed="neither"), Interval(Integral, 1, None, closed="left"), None, ], diff --git a/sklearn/ensemble/_iforest.py b/sklearn/ensemble/_iforest.py index 4e5422c50e614..1d581fce5db47 100644 --- a/sklearn/ensemble/_iforest.py +++ b/sklearn/ensemble/_iforest.py @@ -206,7 +206,7 @@ class IsolationForest(OutlierMixin, BaseBagging): "max_samples": [ StrOptions({"auto"}), Interval(Integral, 1, None, closed="left"), - Interval(Real, 0, 1, closed="right"), + Interval("real_not_int", 0, 1, closed="right"), ], "contamination": [ StrOptions({"auto"}), diff --git a/sklearn/feature_extraction/image.py b/sklearn/feature_extraction/image.py index ec7c9dfec90ea..b70078b364969 100644 --- a/sklearn/feature_extraction/image.py +++ b/sklearn/feature_extraction/image.py @@ -343,8 +343,8 @@ def _extract_patches(arr, patch_shape=8, extraction_step=1): "image": [np.ndarray], "patch_size": [tuple, list], "max_patches": [ - Interval(Real, left=0, right=1, closed="neither"), - Interval(Integral, left=1, right=None, closed="left"), + Interval("real_not_int", 0, 1, closed="neither"), + Interval(Integral, 1, None, closed="left"), None, ], "random_state": ["random_state"], @@ -542,7 +542,7 @@ class PatchExtractor(TransformerMixin, BaseEstimator): "patch_size": [tuple, None], "max_patches": [ None, - Interval(Real, 0, 1, closed="neither"), + Interval("real_not_int", 0, 1, closed="neither"), Interval(Integral, 1, None, closed="left"), ], "random_state": ["random_state"], diff --git a/sklearn/feature_extraction/text.py b/sklearn/feature_extraction/text.py index 0160bfeaa539f..8c3c5f4f7653b 100644 --- a/sklearn/feature_extraction/text.py +++ b/sklearn/feature_extraction/text.py @@ -15,7 +15,7 @@ from collections import defaultdict from collections.abc import Mapping from functools import partial -from numbers import Integral, Real +from numbers import Integral from operator import itemgetter import re import unicodedata @@ -1148,11 +1148,11 @@ class CountVectorizer(_VectorizerMixin, BaseEstimator): "ngram_range": [tuple], "analyzer": [StrOptions({"word", "char", "char_wb"}), callable], "max_df": [ - Interval(Real, 0, 1, closed="both"), + Interval("real_not_int", 0, 1, closed="both"), Interval(Integral, 1, None, closed="left"), ], "min_df": [ - Interval(Real, 0, 1, closed="both"), + Interval("real_not_int", 0, 1, closed="both"), Interval(Integral, 1, None, closed="left"), ], "max_features": [Interval(Integral, 1, None, closed="left"), None], diff --git a/sklearn/feature_selection/_rfe.py b/sklearn/feature_selection/_rfe.py index d105ba1ae3567..e99f06d4cac11 100644 --- a/sklearn/feature_selection/_rfe.py +++ b/sklearn/feature_selection/_rfe.py @@ -7,7 +7,7 @@ """Recursive feature elimination for feature ranking""" import numpy as np -from numbers import Integral, Real +from numbers import Integral from joblib import effective_n_jobs @@ -187,12 +187,12 @@ class RFE(SelectorMixin, MetaEstimatorMixin, BaseEstimator): "estimator": [HasMethods(["fit"])], "n_features_to_select": [ None, - Interval(Real, 0, 1, closed="right"), + Interval("real_not_int", 0, 1, closed="right"), Interval(Integral, 0, None, closed="neither"), ], "step": [ Interval(Integral, 0, None, closed="neither"), - Interval(Real, 0, 1, closed="neither"), + Interval("real_not_int", 0, 1, closed="neither"), ], "verbose": ["verbose"], "importance_getter": [str, callable], diff --git a/sklearn/feature_selection/_sequential.py b/sklearn/feature_selection/_sequential.py index 91ea7bdc719b9..89820c599280c 100644 --- a/sklearn/feature_selection/_sequential.py +++ b/sklearn/feature_selection/_sequential.py @@ -154,7 +154,7 @@ class SequentialFeatureSelector(SelectorMixin, MetaEstimatorMixin, BaseEstimator "estimator": [HasMethods(["fit"])], "n_features_to_select": [ StrOptions({"auto", "warn"}, deprecated={"warn"}), - Interval(Real, 0, 1, closed="right"), + Interval("real_not_int", 0, 1, closed="right"), Interval(Integral, 0, None, closed="neither"), Hidden(None), ], diff --git a/sklearn/linear_model/_ransac.py b/sklearn/linear_model/_ransac.py index ff58e067cde35..92a9539e47ae7 100644 --- a/sklearn/linear_model/_ransac.py +++ b/sklearn/linear_model/_ransac.py @@ -236,7 +236,7 @@ class RANSACRegressor( "estimator": [HasMethods(["fit", "score", "predict"]), None], "min_samples": [ Interval(Integral, 1, None, closed="left"), - Interval(Real, 0, 1, closed="both"), + Interval("real_not_int", 0, 1, closed="both"), None, ], "residual_threshold": [Interval(Real, 0, None, closed="left"), None], diff --git a/sklearn/model_selection/_split.py b/sklearn/model_selection/_split.py index 5c3854cfc9d3b..931db1a3ea9de 100644 --- a/sklearn/model_selection/_split.py +++ b/sklearn/model_selection/_split.py @@ -2464,12 +2464,12 @@ def check_cv(cv=5, y=None, *, classifier=False): @validate_params( { "test_size": [ - Interval(numbers.Real, 0, 1, closed="neither"), + Interval("real_not_int", 0, 1, closed="neither"), Interval(numbers.Integral, 1, None, closed="left"), None, ], "train_size": [ - Interval(numbers.Real, 0, 1, closed="neither"), + Interval("real_not_int", 0, 1, closed="neither"), Interval(numbers.Integral, 1, None, closed="left"), None, ], diff --git a/sklearn/preprocessing/_encoders.py b/sklearn/preprocessing/_encoders.py index 4c3e80771c35a..ef2fc39d1d5bc 100644 --- a/sklearn/preprocessing/_encoders.py +++ b/sklearn/preprocessing/_encoders.py @@ -3,7 +3,7 @@ # License: BSD 3 clause import numbers -from numbers import Integral, Real +from numbers import Integral import warnings import numpy as np @@ -493,7 +493,7 @@ class OneHotEncoder(_BaseEncoder): "max_categories": [Interval(Integral, 1, None, closed="left"), None], "min_frequency": [ Interval(Integral, 1, None, closed="left"), - Interval(Real, 0, 1, closed="neither"), + Interval("real_not_int", 0, 1, closed="neither"), None, ], "sparse": [Hidden(StrOptions({"deprecated"})), "boolean"], # deprecated diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 8d6d3e1e7f45e..f8a318cd759f6 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -1,5 +1,6 @@ from importlib import import_module from inspect import signature +from numbers import Integral, Real import pytest @@ -7,6 +8,7 @@ from sklearn.utils._param_validation import generate_valid_param from sklearn.utils._param_validation import make_constraint from sklearn.utils._param_validation import InvalidParameterError +from sklearn.utils._param_validation import Interval def _get_func_info(func_module): @@ -70,6 +72,20 @@ def _check_function_param_validation( # This parameter is not validated continue + # Mixing an interval of reals and an interval of integers must be avoided. + if any( + isinstance(constraint, Interval) and constraint.type == Integral + for constraint in constraints + ) and any( + isinstance(constraint, Interval) and constraint.type == Real + for constraint in constraints + ): + raise ValueError( + f"The constraint for parameter {param_name} of {func_name} can't have a" + " mix of intervals of Integral and Real types. Use the type" + " 'real_not_int' instead of Real." + ) + match = ( rf"The '{param_name}' parameter of {func_name} must be .* Got .* instead." ) @@ -85,7 +101,7 @@ def _check_function_param_validation( for constraint in constraints: try: - bad_value = generate_invalid_param_val(constraint, constraints) + bad_value = generate_invalid_param_val(constraint) except NotImplementedError: continue diff --git a/sklearn/utils/_param_validation.py b/sklearn/utils/_param_validation.py index 8d23f0b23b6eb..de4b5adfdaa06 100644 --- a/sklearn/utils/_param_validation.py +++ b/sklearn/utils/_param_validation.py @@ -476,6 +476,13 @@ def __str__(self): left_bound = "-inf" if self.left is None else self.left right_bound = "inf" if self.right is None else self.right right_bracket = "]" if self.closed in ("right", "both") else ")" + + # better repr if the bounds were given as integers + if not self.type == Integral and isinstance(self.left, Real): + left_bound = float(left_bound) + if not self.type == Integral and isinstance(self.right, Real): + right_bound = float(right_bound) + return ( f"{type_str} in the range " f"{left_bracket}{left_bound}, {right_bound}{right_bracket}" @@ -718,7 +725,7 @@ def __init__(self, constraint): self.constraint = constraint -def generate_invalid_param_val(constraint, constraints=None): +def generate_invalid_param_val(constraint): """Return a value that does not satisfy the constraint. Raises a NotImplementedError if there exists no invalid value for this constraint. @@ -730,10 +737,6 @@ def generate_invalid_param_val(constraint, constraints=None): constraint : _Constraint instance The constraint to generate a value for. - constraints : list of _Constraint instances or None, default=None - The list of all constraints for this parameter. If None, the list only - containing `constraint` is used. - Returns ------- val : object @@ -757,116 +760,31 @@ def generate_invalid_param_val(constraint, constraints=None): if isinstance(constraint, _CVObjects): return "not a cv object" - if not isinstance(constraint, Interval): - raise NotImplementedError - - # constraint is an interval - constraints = [constraint] if constraints is None else constraints - return _generate_invalid_param_val_interval(constraint, constraints) - - -def _generate_invalid_param_val_interval(interval, constraints): - """Return a value that does not satisfy an interval constraint. - - Generating an invalid value for an integer interval depends on the other constraints - since an int is a real, meaning that it can be valid for a real interval. - Assumes that there can be at most 2 interval constraints: one integer interval - and/or one real interval. - - This is only useful for testing purpose. - - Parameters - ---------- - interval : Interval instance - The interval to generate a value for. + if isinstance(constraint, Interval) and constraint.type is Integral: + if constraint.left is not None: + return constraint.left - 1 + if constraint.right is not None: + return constraint.right + 1 - constraints : list of _Constraint instances - The list of all constraints for this parameter. + # There's no integer outside (-inf, +inf) + raise NotImplementedError - Returns - ------- - val : object - A value that does not satisfy the interval constraint. - """ - if interval.type is Real: - # generate a non-integer value such that it can't be valid even if there's also - # an integer interval constraint. - if interval.left is None and interval.right is None: - if interval.closed in ("left", "neither"): - return np.inf - elif interval.closed in ("right", "neither"): - return -np.inf - else: - raise NotImplementedError + if isinstance(constraint, Interval) and constraint.type in (Real, "real_not_int"): + if constraint.left is not None: + return constraint.left - 1e-6 + if constraint.right is not None: + return constraint.right + 1e-6 - if interval.left is not None: - return np.floor(interval.left) - 0.5 - else: # right is not None - return np.ceil(interval.right) + 0.5 + # bounds are -inf, +inf + if constraint.closed in ("right", "neither"): + return -np.inf + if constraint.closed in ("left", "neither"): + return np.inf - else: # interval.type is Integral - if interval.left is None and interval.right is None: - raise NotImplementedError + # interval is [-inf, +inf] + return np.nan - # We need to check if there's also a real interval constraint to generate a - # value that is not valid for any of the 2 interval constraints. - real_intervals = [ - i for i in constraints if isinstance(i, Interval) and i.type is Real - ] - real_interval = real_intervals[0] if real_intervals else None - - if real_interval is None: - # Only the integer interval constraint -> easy - if interval.left is not None: - return interval.left - 1 - else: # interval.right is not None - return interval.right + 1 - - # There's also a real interval constraint. Try to find a value left to both or - # right to both or in between them. - - # redefine left and right bounds to be smallest and largest valid integers in - # both intervals. - int_left = interval.left - if int_left is not None and interval.closed in ("right", "neither"): - int_left = int_left + 1 - - int_right = interval.right - if int_right is not None and interval.closed in ("left", "neither"): - int_right = int_right - 1 - - real_left = real_interval.left - if real_interval.left is not None: - real_left = int(np.ceil(real_interval.left)) - if real_interval.closed in ("right", "neither"): - real_left = real_left + 1 - - real_right = real_interval.right - if real_interval.right is not None: - real_right = int(np.floor(real_interval.right)) - if real_interval.closed in ("left", "neither"): - real_right = real_right - 1 - - if int_left is not None and real_left is not None: - # there exists an int left to both intervals - return min(int_left, real_left) - 1 - - if int_right is not None and real_right is not None: - # there exists an int right to both intervals - return max(int_right, real_right) + 1 - - if int_left is not None: - if real_right is not None and int_left - real_right >= 2: - # there exists an int between the 2 intervals - return int_left - 1 - else: - raise NotImplementedError - else: # int_right is not None - if real_left is not None and real_left - int_right >= 2: - # there exists an int between the 2 intervals - return int_right + 1 - else: - raise NotImplementedError + raise NotImplementedError def generate_valid_param(constraint): diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 9eb666c68984d..9abfe2b73ea68 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -4,7 +4,7 @@ from copy import deepcopy from functools import partial, wraps from inspect import signature -from numbers import Real +from numbers import Real, Integral import numpy as np from scipy import sparse @@ -1434,7 +1434,7 @@ def check_fit2d_1sample(name, estimator_orig): # min_cluster_size cannot be less than the data size for OPTICS. if name == "OPTICS": - estimator.set_params(min_samples=1) + estimator.set_params(min_samples=1.0) # perplexity cannot be more than the number of samples for TSNE. if name == "TSNE": @@ -4155,6 +4155,20 @@ def check_param_validation(name, estimator_orig): # This parameter is not validated continue + # Mixing an interval of reals and an interval of integers must be avoided. + if any( + isinstance(constraint, Interval) and constraint.type == Integral + for constraint in constraints + ) and any( + isinstance(constraint, Interval) and constraint.type == Real + for constraint in constraints + ): + raise ValueError( + f"The constraint for parameter {param_name} of {name} can't have a mix" + " of intervals of Integral and Real types. Use the type 'real_not_int'" + " instead of Real." + ) + match = rf"The '{param_name}' parameter of {name} must be .* Got .* instead." err_msg = ( f"{name} does not raise an informative error message when the " @@ -4188,7 +4202,7 @@ def check_param_validation(name, estimator_orig): for constraint in constraints: try: - bad_value = generate_invalid_param_val(constraint, constraints) + bad_value = generate_invalid_param_val(constraint) except NotImplementedError: continue diff --git a/sklearn/utils/tests/test_param_validation.py b/sklearn/utils/tests/test_param_validation.py index ce8f9cdf939fd..afddb50382899 100644 --- a/sklearn/utils/tests/test_param_validation.py +++ b/sklearn/utils/tests/test_param_validation.py @@ -219,75 +219,75 @@ def test_generate_invalid_param_val(constraint): [ ( Interval(Integral, None, 3, closed="right"), - Interval(Real, -5, 5, closed="both"), + Interval("real_not_int", -5, 5, closed="both"), ), ( Interval(Integral, None, 3, closed="right"), - Interval(Real, -5, 5, closed="neither"), + Interval("real_not_int", -5, 5, closed="neither"), ), ( Interval(Integral, None, 3, closed="right"), - Interval(Real, 4, 5, closed="both"), + Interval("real_not_int", 4, 5, closed="both"), ), ( Interval(Integral, None, 3, closed="right"), - Interval(Real, 5, None, closed="left"), + Interval("real_not_int", 5, None, closed="left"), ), ( Interval(Integral, None, 3, closed="right"), - Interval(Real, 4, None, closed="neither"), + Interval("real_not_int", 4, None, closed="neither"), ), ( Interval(Integral, 3, None, closed="left"), - Interval(Real, -5, 5, closed="both"), + Interval("real_not_int", -5, 5, closed="both"), ), ( Interval(Integral, 3, None, closed="left"), - Interval(Real, -5, 5, closed="neither"), + Interval("real_not_int", -5, 5, closed="neither"), ), ( Interval(Integral, 3, None, closed="left"), - Interval(Real, 1, 2, closed="both"), + Interval("real_not_int", 1, 2, closed="both"), ), ( Interval(Integral, 3, None, closed="left"), - Interval(Real, None, -5, closed="left"), + Interval("real_not_int", None, -5, closed="left"), ), ( Interval(Integral, 3, None, closed="left"), - Interval(Real, None, -4, closed="neither"), + Interval("real_not_int", None, -4, closed="neither"), ), ( Interval(Integral, -5, 5, closed="both"), - Interval(Real, None, 1, closed="right"), + Interval("real_not_int", None, 1, closed="right"), ), ( Interval(Integral, -5, 5, closed="both"), - Interval(Real, 1, None, closed="left"), + Interval("real_not_int", 1, None, closed="left"), ), ( Interval(Integral, -5, 5, closed="both"), - Interval(Real, -10, -4, closed="neither"), + Interval("real_not_int", -10, -4, closed="neither"), ), ( Interval(Integral, -5, 5, closed="both"), - Interval(Real, -10, -4, closed="right"), + Interval("real_not_int", -10, -4, closed="right"), ), ( Interval(Integral, -5, 5, closed="neither"), - Interval(Real, 6, 10, closed="neither"), + Interval("real_not_int", 6, 10, closed="neither"), ), ( Interval(Integral, -5, 5, closed="neither"), - Interval(Real, 6, 10, closed="left"), + Interval("real_not_int", 6, 10, closed="left"), ), ( Interval(Integral, 2, None, closed="left"), - Interval(Real, 0, 1, closed="both"), + Interval("real_not_int", 0, 1, closed="both"), ), ( Interval(Integral, 1, None, closed="left"), - Interval(Real, 0, 1, closed="both"), + Interval("real_not_int", 0, 1, closed="both"), ), ], ) @@ -295,15 +295,11 @@ def test_generate_invalid_param_val_2_intervals(integer_interval, real_interval) """Check that the value generated for an interval constraint does not satisfy any of the interval constraints. """ - bad_value = generate_invalid_param_val( - real_interval, constraints=[real_interval, integer_interval] - ) + bad_value = generate_invalid_param_val(constraint=real_interval) assert not real_interval.is_satisfied_by(bad_value) assert not integer_interval.is_satisfied_by(bad_value) - bad_value = generate_invalid_param_val( - integer_interval, constraints=[real_interval, integer_interval] - ) + bad_value = generate_invalid_param_val(constraint=integer_interval) assert not real_interval.is_satisfied_by(bad_value) assert not integer_interval.is_satisfied_by(bad_value) @@ -318,11 +314,7 @@ def test_generate_invalid_param_val_2_intervals(integer_interval, real_interval) [_RandomStates()], [_SparseMatrices()], [_Booleans()], - [Interval(Real, None, None, closed="both")], - [ - Interval(Integral, 0, None, closed="left"), - Interval(Real, None, 0, closed="neither"), - ], + [Interval(Integral, None, None, closed="neither")], ], ) def test_generate_invalid_param_val_all_valid(constraints): @@ -330,7 +322,7 @@ def test_generate_invalid_param_val_all_valid(constraints): for the constraint. """ with pytest.raises(NotImplementedError): - generate_invalid_param_val(constraints[0], constraints=constraints) + generate_invalid_param_val(constraints[0]) @pytest.mark.parametrize( From 718523f905129dad9b26368e8f2ddf7a404d9f44 Mon Sep 17 00:00:00 2001 From: jeremiedbb Date: Tue, 14 Mar 2023 18:13:35 +0100 Subject: [PATCH 2/4] exp: create custom RealNotInt type --- sklearn/cluster/_optics.py | 11 ++--- sklearn/decomposition/_pca.py | 3 +- sklearn/ensemble/_bagging.py | 5 ++- sklearn/ensemble/_forest.py | 3 +- .../gradient_boosting.py | 3 +- sklearn/ensemble/_iforest.py | 3 +- sklearn/feature_extraction/image.py | 5 ++- sklearn/feature_extraction/text.py | 5 ++- sklearn/feature_selection/_rfe.py | 5 ++- sklearn/feature_selection/_sequential.py | 3 +- sklearn/linear_model/_ransac.py | 3 +- sklearn/model_selection/_split.py | 5 ++- sklearn/preprocessing/_encoders.py | 3 +- sklearn/tests/test_public_functions.py | 2 +- sklearn/tree/_classes.py | 7 ++-- sklearn/utils/_param_validation.py | 29 ++++++++----- sklearn/utils/estimator_checks.py | 2 +- sklearn/utils/tests/test_param_validation.py | 41 ++++++++++--------- 18 files changed, 80 insertions(+), 58 deletions(-) diff --git a/sklearn/cluster/_optics.py b/sklearn/cluster/_optics.py index 0d12832f15b64..0dd5fb7f7daad 100755 --- a/sklearn/cluster/_optics.py +++ b/sklearn/cluster/_optics.py @@ -20,6 +20,7 @@ from ..metrics.pairwise import _VALID_METRICS from ..utils import gen_batches, get_chunk_n_rows from ..utils._param_validation import Interval, HasMethods, StrOptions, validate_params +from ..utils._param_validation import RealNotInt from ..utils.validation import check_memory from ..neighbors import NearestNeighbors from ..base import BaseEstimator, ClusterMixin @@ -233,7 +234,7 @@ class OPTICS(ClusterMixin, BaseEstimator): _parameter_constraints: dict = { "min_samples": [ Interval(Integral, 2, None, closed="left"), - Interval("real_not_int", 0, 1, closed="both"), + Interval(RealNotInt, 0, 1, closed="both"), ], "max_eps": [Interval(Real, 0, None, closed="both")], "metric": [StrOptions(set(_VALID_METRICS) | {"precomputed"}), callable], @@ -245,7 +246,7 @@ class OPTICS(ClusterMixin, BaseEstimator): "predecessor_correction": ["boolean"], "min_cluster_size": [ Interval(Integral, 2, None, closed="left"), - Interval("real_not_int", 0, 1, closed="right"), + Interval(RealNotInt, 0, 1, closed="right"), None, ], "algorithm": [StrOptions({"auto", "brute", "ball_tree", "kd_tree"})], @@ -431,7 +432,7 @@ def _compute_core_distances_(X, neighbors, min_samples, working_memory): "X": [np.ndarray, "sparse matrix"], "min_samples": [ Interval(Integral, 2, None, closed="left"), - Interval("real_not_int", 0, 1, closed="both"), + Interval(RealNotInt, 0, 1, closed="both"), ], "max_eps": [Interval(Real, 0, None, closed="both")], "metric": [StrOptions(set(_VALID_METRICS) | {"precomputed"}), callable], @@ -724,11 +725,11 @@ def cluster_optics_dbscan(*, reachability, core_distances, ordering, eps): "ordering": [np.ndarray], "min_samples": [ Interval(Integral, 2, None, closed="left"), - Interval("real_not_int", 0, 1, closed="both"), + Interval(RealNotInt, 0, 1, closed="both"), ], "min_cluster_size": [ Interval(Integral, 2, None, closed="left"), - Interval("real_not_int", 0, 1, closed="both"), + Interval(RealNotInt, 0, 1, closed="both"), None, ], "xi": [Interval(Real, 0, 1, closed="both")], diff --git a/sklearn/decomposition/_pca.py b/sklearn/decomposition/_pca.py index 3ee3ba4d36bba..e8c302fc47129 100644 --- a/sklearn/decomposition/_pca.py +++ b/sklearn/decomposition/_pca.py @@ -27,6 +27,7 @@ from ..utils.extmath import stable_cumsum from ..utils.validation import check_is_fitted from ..utils._param_validation import Interval, StrOptions +from ..utils._param_validation import RealNotInt def _assess_dimension(spectrum, rank, n_samples): @@ -363,7 +364,7 @@ class PCA(_BasePCA): _parameter_constraints: dict = { "n_components": [ Interval(Integral, 0, None, closed="left"), - Interval("real_not_int", 0, 1, closed="neither"), + Interval(RealNotInt, 0, 1, closed="neither"), StrOptions({"mle"}), None, ], diff --git a/sklearn/ensemble/_bagging.py b/sklearn/ensemble/_bagging.py index a9a93d78152e6..c4fb966aacac3 100644 --- a/sklearn/ensemble/_bagging.py +++ b/sklearn/ensemble/_bagging.py @@ -22,6 +22,7 @@ from ..utils.multiclass import check_classification_targets from ..utils.random import sample_without_replacement from ..utils._param_validation import Interval, HasMethods, StrOptions +from ..utils._param_validation import RealNotInt from ..utils.validation import has_fit_parameter, check_is_fitted, _check_sample_weight from ..utils._tags import _safe_tags from ..utils.parallel import delayed, Parallel @@ -248,11 +249,11 @@ class BaseBagging(BaseEnsemble, metaclass=ABCMeta): "n_estimators": [Interval(Integral, 1, None, closed="left")], "max_samples": [ Interval(Integral, 1, None, closed="left"), - Interval("real_not_int", 0, 1, closed="right"), + Interval(RealNotInt, 0, 1, closed="right"), ], "max_features": [ Interval(Integral, 1, None, closed="left"), - Interval("real_not_int", 0, 1, closed="right"), + Interval(RealNotInt, 0, 1, closed="right"), ], "bootstrap": ["boolean"], "bootstrap_features": ["boolean"], diff --git a/sklearn/ensemble/_forest.py b/sklearn/ensemble/_forest.py index 09fb2e547468f..c7064f477d5cf 100644 --- a/sklearn/ensemble/_forest.py +++ b/sklearn/ensemble/_forest.py @@ -74,6 +74,7 @@ class calls the ``fit`` method of each sub-estimator on random samples ) from ..utils.validation import _num_samples from ..utils._param_validation import Interval, StrOptions +from ..utils._param_validation import RealNotInt __all__ = [ @@ -206,7 +207,7 @@ class BaseForest(MultiOutputMixin, BaseEnsemble, metaclass=ABCMeta): "warm_start": ["boolean"], "max_samples": [ None, - Interval("real_not_int", 0.0, 1.0, closed="right"), + Interval(RealNotInt, 0.0, 1.0, closed="right"), Interval(Integral, 1, None, closed="left"), ], } diff --git a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py index b5d254333724f..76b49c60226a7 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py @@ -27,6 +27,7 @@ _check_monotonic_cst, ) from ...utils._param_validation import Interval, StrOptions +from ...utils._param_validation import RealNotInt from ...utils._openmp_helpers import _openmp_effective_n_threads from ...utils.multiclass import check_classification_targets from ...metrics import check_scoring @@ -103,7 +104,7 @@ class BaseHistGradientBoosting(BaseEstimator, ABC): ], "n_iter_no_change": [Interval(Integral, 1, None, closed="left")], "validation_fraction": [ - Interval("real_not_int", 0, 1, closed="neither"), + Interval(RealNotInt, 0, 1, closed="neither"), Interval(Integral, 1, None, closed="left"), None, ], diff --git a/sklearn/ensemble/_iforest.py b/sklearn/ensemble/_iforest.py index 1d581fce5db47..4d6c1c3f0b7f9 100644 --- a/sklearn/ensemble/_iforest.py +++ b/sklearn/ensemble/_iforest.py @@ -17,6 +17,7 @@ get_chunk_n_rows, ) from ..utils._param_validation import Interval, StrOptions +from ..utils._param_validation import RealNotInt from ..utils.validation import check_is_fitted, _num_samples from ..base import OutlierMixin @@ -206,7 +207,7 @@ class IsolationForest(OutlierMixin, BaseBagging): "max_samples": [ StrOptions({"auto"}), Interval(Integral, 1, None, closed="left"), - Interval("real_not_int", 0, 1, closed="right"), + Interval(RealNotInt, 0, 1, closed="right"), ], "contamination": [ StrOptions({"auto"}), diff --git a/sklearn/feature_extraction/image.py b/sklearn/feature_extraction/image.py index b70078b364969..06cbe3b423276 100644 --- a/sklearn/feature_extraction/image.py +++ b/sklearn/feature_extraction/image.py @@ -18,6 +18,7 @@ from ..base import BaseEstimator, TransformerMixin from ..utils import check_array, check_random_state from ..utils._param_validation import Hidden, Interval, validate_params +from ..utils._param_validation import RealNotInt __all__ = [ "PatchExtractor", @@ -343,7 +344,7 @@ def _extract_patches(arr, patch_shape=8, extraction_step=1): "image": [np.ndarray], "patch_size": [tuple, list], "max_patches": [ - Interval("real_not_int", 0, 1, closed="neither"), + Interval(RealNotInt, 0, 1, closed="neither"), Interval(Integral, 1, None, closed="left"), None, ], @@ -542,7 +543,7 @@ class PatchExtractor(TransformerMixin, BaseEstimator): "patch_size": [tuple, None], "max_patches": [ None, - Interval("real_not_int", 0, 1, closed="neither"), + Interval(RealNotInt, 0, 1, closed="neither"), Interval(Integral, 1, None, closed="left"), ], "random_state": ["random_state"], diff --git a/sklearn/feature_extraction/text.py b/sklearn/feature_extraction/text.py index 8c3c5f4f7653b..9d56d2d1bf2dd 100644 --- a/sklearn/feature_extraction/text.py +++ b/sklearn/feature_extraction/text.py @@ -32,6 +32,7 @@ from ..utils import _IS_32BIT from ..exceptions import NotFittedError from ..utils._param_validation import StrOptions, Interval, HasMethods +from ..utils._param_validation import RealNotInt __all__ = [ @@ -1148,11 +1149,11 @@ class CountVectorizer(_VectorizerMixin, BaseEstimator): "ngram_range": [tuple], "analyzer": [StrOptions({"word", "char", "char_wb"}), callable], "max_df": [ - Interval("real_not_int", 0, 1, closed="both"), + Interval(RealNotInt, 0, 1, closed="both"), Interval(Integral, 1, None, closed="left"), ], "min_df": [ - Interval("real_not_int", 0, 1, closed="both"), + Interval(RealNotInt, 0, 1, closed="both"), Interval(Integral, 1, None, closed="left"), ], "max_features": [Interval(Integral, 1, None, closed="left"), None], diff --git a/sklearn/feature_selection/_rfe.py b/sklearn/feature_selection/_rfe.py index e99f06d4cac11..214ac9e0c30cf 100644 --- a/sklearn/feature_selection/_rfe.py +++ b/sklearn/feature_selection/_rfe.py @@ -14,6 +14,7 @@ from ..utils.metaestimators import available_if from ..utils.metaestimators import _safe_split from ..utils._param_validation import HasMethods, Interval +from ..utils._param_validation import RealNotInt from ..utils._tags import _safe_tags from ..utils.validation import check_is_fitted from ..utils.parallel import delayed, Parallel @@ -187,12 +188,12 @@ class RFE(SelectorMixin, MetaEstimatorMixin, BaseEstimator): "estimator": [HasMethods(["fit"])], "n_features_to_select": [ None, - Interval("real_not_int", 0, 1, closed="right"), + Interval(RealNotInt, 0, 1, closed="right"), Interval(Integral, 0, None, closed="neither"), ], "step": [ Interval(Integral, 0, None, closed="neither"), - Interval("real_not_int", 0, 1, closed="neither"), + Interval(RealNotInt, 0, 1, closed="neither"), ], "verbose": ["verbose"], "importance_getter": [str, callable], diff --git a/sklearn/feature_selection/_sequential.py b/sklearn/feature_selection/_sequential.py index 89820c599280c..e983c55de7d25 100644 --- a/sklearn/feature_selection/_sequential.py +++ b/sklearn/feature_selection/_sequential.py @@ -10,6 +10,7 @@ from ._base import SelectorMixin from ..base import BaseEstimator, MetaEstimatorMixin, clone from ..utils._param_validation import HasMethods, Hidden, Interval, StrOptions +from ..utils._param_validation import RealNotInt from ..utils._tags import _safe_tags from ..utils.validation import check_is_fitted from ..model_selection import cross_val_score @@ -154,7 +155,7 @@ class SequentialFeatureSelector(SelectorMixin, MetaEstimatorMixin, BaseEstimator "estimator": [HasMethods(["fit"])], "n_features_to_select": [ StrOptions({"auto", "warn"}, deprecated={"warn"}), - Interval("real_not_int", 0, 1, closed="right"), + Interval(RealNotInt, 0, 1, closed="right"), Interval(Integral, 0, None, closed="neither"), Hidden(None), ], diff --git a/sklearn/linear_model/_ransac.py b/sklearn/linear_model/_ransac.py index 92a9539e47ae7..472f2275ffa6b 100644 --- a/sklearn/linear_model/_ransac.py +++ b/sklearn/linear_model/_ransac.py @@ -15,6 +15,7 @@ from ._base import LinearRegression from ..utils.validation import has_fit_parameter from ..utils._param_validation import Interval, Options, StrOptions, HasMethods, Hidden +from ..utils._param_validation import RealNotInt from ..exceptions import ConvergenceWarning _EPSILON = np.spacing(1) @@ -236,7 +237,7 @@ class RANSACRegressor( "estimator": [HasMethods(["fit", "score", "predict"]), None], "min_samples": [ Interval(Integral, 1, None, closed="left"), - Interval("real_not_int", 0, 1, closed="both"), + Interval(RealNotInt, 0, 1, closed="both"), None, ], "residual_threshold": [Interval(Real, 0, None, closed="left"), None], diff --git a/sklearn/model_selection/_split.py b/sklearn/model_selection/_split.py index 931db1a3ea9de..dded962a350d8 100644 --- a/sklearn/model_selection/_split.py +++ b/sklearn/model_selection/_split.py @@ -29,6 +29,7 @@ from ..utils.validation import check_array from ..utils.multiclass import type_of_target from ..utils._param_validation import validate_params, Interval +from ..utils._param_validation import RealNotInt __all__ = [ "BaseCrossValidator", @@ -2464,12 +2465,12 @@ def check_cv(cv=5, y=None, *, classifier=False): @validate_params( { "test_size": [ - Interval("real_not_int", 0, 1, closed="neither"), + Interval(RealNotInt, 0, 1, closed="neither"), Interval(numbers.Integral, 1, None, closed="left"), None, ], "train_size": [ - Interval("real_not_int", 0, 1, closed="neither"), + Interval(RealNotInt, 0, 1, closed="neither"), Interval(numbers.Integral, 1, None, closed="left"), None, ], diff --git a/sklearn/preprocessing/_encoders.py b/sklearn/preprocessing/_encoders.py index ef2fc39d1d5bc..7df0584ba9d5b 100644 --- a/sklearn/preprocessing/_encoders.py +++ b/sklearn/preprocessing/_encoders.py @@ -14,6 +14,7 @@ from ..utils.validation import check_is_fitted from ..utils.validation import _check_feature_names_in from ..utils._param_validation import Interval, StrOptions, Hidden +from ..utils._param_validation import RealNotInt from ..utils._mask import _get_mask from ..utils._encode import _encode, _check_unknown, _unique, _get_counts @@ -493,7 +494,7 @@ class OneHotEncoder(_BaseEncoder): "max_categories": [Interval(Integral, 1, None, closed="left"), None], "min_frequency": [ Interval(Integral, 1, None, closed="left"), - Interval("real_not_int", 0, 1, closed="neither"), + Interval(RealNotInt, 0, 1, closed="neither"), None, ], "sparse": [Hidden(StrOptions({"deprecated"})), "boolean"], # deprecated diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index f8a318cd759f6..bfbb4888b3d17 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -83,7 +83,7 @@ def _check_function_param_validation( raise ValueError( f"The constraint for parameter {param_name} of {func_name} can't have a" " mix of intervals of Integral and Real types. Use the type" - " 'real_not_int' instead of Real." + " RealNotInt instead of Real." ) match = ( diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 6e01b8b49e594..e0e341d9a89f6 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -38,6 +38,7 @@ from ..utils.multiclass import check_classification_targets from ..utils.validation import check_is_fitted from ..utils._param_validation import Hidden, Interval, StrOptions +from ..utils._param_validation import RealNotInt from ._criterion import Criterion from ._splitter import Splitter @@ -99,16 +100,16 @@ class BaseDecisionTree(MultiOutputMixin, BaseEstimator, metaclass=ABCMeta): "max_depth": [Interval(Integral, 1, None, closed="left"), None], "min_samples_split": [ Interval(Integral, 2, None, closed="left"), - Interval("real_not_int", 0.0, 1.0, closed="right"), + Interval(RealNotInt, 0.0, 1.0, closed="right"), ], "min_samples_leaf": [ Interval(Integral, 1, None, closed="left"), - Interval("real_not_int", 0.0, 1.0, closed="neither"), + Interval(RealNotInt, 0.0, 1.0, closed="neither"), ], "min_weight_fraction_leaf": [Interval(Real, 0.0, 0.5, closed="both")], "max_features": [ Interval(Integral, 1, None, closed="left"), - Interval("real_not_int", 0.0, 1.0, closed="right"), + Interval(RealNotInt, 0.0, 1.0, closed="right"), StrOptions({"auto", "sqrt", "log2"}, deprecated={"auto"}), None, ], diff --git a/sklearn/utils/_param_validation.py b/sklearn/utils/_param_validation.py index de4b5adfdaa06..4370817094be8 100644 --- a/sklearn/utils/_param_validation.py +++ b/sklearn/utils/_param_validation.py @@ -207,6 +207,18 @@ def wrapper(*args, **kwargs): return decorator +class RealNotInt(Real): + """A type that represents reals that are not instances of int. + + Behaves like float, but also works with values extracted from numpy arrays. + isintance(1, RealNotInt) -> False + isinstance(1.0, RealNotInt) -> True + """ + + +RealNotInt.register(float) + + def _type_name(t): """Convert type into human readable string.""" module = t.__module__ @@ -364,10 +376,10 @@ class Interval(_Constraint): Parameters ---------- - type : {numbers.Integral, numbers.Real, "real_not_int"} + type : {numbers.Integral, numbers.Real, RealNotInt} The set of numbers in which to set the interval. - If "real_not_int", only reals that don't have the integer type + If RealNotInt, only reals that don't have the integer type are allowed. For example 1.0 is allowed but 1 is not. left : float or int or None @@ -405,9 +417,9 @@ def __init__(self, type, left, right, *, closed): self._check_params() def _check_params(self): - if self.type not in (Integral, Real, "real_not_int"): + if self.type not in (Integral, Real, RealNotInt): raise ValueError( - "type must be either numbers.Integral, numbers.Real or 'real_not_int'." + "type must be either numbers.Integral, numbers.Real or RealNotInt." f" Got {self.type} instead." ) @@ -459,13 +471,8 @@ def __contains__(self, val): return False return True - def _has_valid_type(self, val): - if self.type == "real_not_int": - return isinstance(val, Real) and not isinstance(val, Integral) - return isinstance(val, self.type) - def is_satisfied_by(self, val): - if not self._has_valid_type(val): + if not isinstance(val, self.type): return False return val in self @@ -769,7 +776,7 @@ def generate_invalid_param_val(constraint): # There's no integer outside (-inf, +inf) raise NotImplementedError - if isinstance(constraint, Interval) and constraint.type in (Real, "real_not_int"): + if isinstance(constraint, Interval) and constraint.type in (Real, RealNotInt): if constraint.left is not None: return constraint.left - 1e-6 if constraint.right is not None: diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 9abfe2b73ea68..30b35c31fe214 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -4165,7 +4165,7 @@ def check_param_validation(name, estimator_orig): ): raise ValueError( f"The constraint for parameter {param_name} of {name} can't have a mix" - " of intervals of Integral and Real types. Use the type 'real_not_int'" + " of intervals of Integral and Real types. Use the type RealNotInt" " instead of Real." ) diff --git a/sklearn/utils/tests/test_param_validation.py b/sklearn/utils/tests/test_param_validation.py index afddb50382899..e2386e20bd2a2 100644 --- a/sklearn/utils/tests/test_param_validation.py +++ b/sklearn/utils/tests/test_param_validation.py @@ -29,6 +29,7 @@ from sklearn.utils._param_validation import generate_valid_param from sklearn.utils._param_validation import validate_params from sklearn.utils._param_validation import InvalidParameterError +from sklearn.utils._param_validation import RealNotInt # Some helpers for the tests @@ -219,75 +220,75 @@ def test_generate_invalid_param_val(constraint): [ ( Interval(Integral, None, 3, closed="right"), - Interval("real_not_int", -5, 5, closed="both"), + Interval(RealNotInt, -5, 5, closed="both"), ), ( Interval(Integral, None, 3, closed="right"), - Interval("real_not_int", -5, 5, closed="neither"), + Interval(RealNotInt, -5, 5, closed="neither"), ), ( Interval(Integral, None, 3, closed="right"), - Interval("real_not_int", 4, 5, closed="both"), + Interval(RealNotInt, 4, 5, closed="both"), ), ( Interval(Integral, None, 3, closed="right"), - Interval("real_not_int", 5, None, closed="left"), + Interval(RealNotInt, 5, None, closed="left"), ), ( Interval(Integral, None, 3, closed="right"), - Interval("real_not_int", 4, None, closed="neither"), + Interval(RealNotInt, 4, None, closed="neither"), ), ( Interval(Integral, 3, None, closed="left"), - Interval("real_not_int", -5, 5, closed="both"), + Interval(RealNotInt, -5, 5, closed="both"), ), ( Interval(Integral, 3, None, closed="left"), - Interval("real_not_int", -5, 5, closed="neither"), + Interval(RealNotInt, -5, 5, closed="neither"), ), ( Interval(Integral, 3, None, closed="left"), - Interval("real_not_int", 1, 2, closed="both"), + Interval(RealNotInt, 1, 2, closed="both"), ), ( Interval(Integral, 3, None, closed="left"), - Interval("real_not_int", None, -5, closed="left"), + Interval(RealNotInt, None, -5, closed="left"), ), ( Interval(Integral, 3, None, closed="left"), - Interval("real_not_int", None, -4, closed="neither"), + Interval(RealNotInt, None, -4, closed="neither"), ), ( Interval(Integral, -5, 5, closed="both"), - Interval("real_not_int", None, 1, closed="right"), + Interval(RealNotInt, None, 1, closed="right"), ), ( Interval(Integral, -5, 5, closed="both"), - Interval("real_not_int", 1, None, closed="left"), + Interval(RealNotInt, 1, None, closed="left"), ), ( Interval(Integral, -5, 5, closed="both"), - Interval("real_not_int", -10, -4, closed="neither"), + Interval(RealNotInt, -10, -4, closed="neither"), ), ( Interval(Integral, -5, 5, closed="both"), - Interval("real_not_int", -10, -4, closed="right"), + Interval(RealNotInt, -10, -4, closed="right"), ), ( Interval(Integral, -5, 5, closed="neither"), - Interval("real_not_int", 6, 10, closed="neither"), + Interval(RealNotInt, 6, 10, closed="neither"), ), ( Interval(Integral, -5, 5, closed="neither"), - Interval("real_not_int", 6, 10, closed="left"), + Interval(RealNotInt, 6, 10, closed="left"), ), ( Interval(Integral, 2, None, closed="left"), - Interval("real_not_int", 0, 1, closed="both"), + Interval(RealNotInt, 0, 1, closed="both"), ), ( Interval(Integral, 1, None, closed="left"), - Interval("real_not_int", 0, 1, closed="both"), + Interval(RealNotInt, 0, 1, closed="both"), ), ], ) @@ -657,7 +658,7 @@ def fit(self, X=None, y=None): def test_interval_real_not_int(): - """Check for the type "real_not_int" in the Interval constraint.""" - constraint = Interval("real_not_int", 0, 1, closed="both") + """Check for the type RealNotInt in the Interval constraint.""" + constraint = Interval(RealNotInt, 0, 1, closed="both") assert constraint.is_satisfied_by(1.0) assert not constraint.is_satisfied_by(1) From 9a08b8ac8cd010a447ea1e997cf4ae46203b4ed9 Mon Sep 17 00:00:00 2001 From: jeremiedbb Date: Tue, 14 Mar 2023 18:16:52 +0100 Subject: [PATCH 3/4] cln test invalid param --- sklearn/utils/tests/test_param_validation.py | 22 ++++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/sklearn/utils/tests/test_param_validation.py b/sklearn/utils/tests/test_param_validation.py index e2386e20bd2a2..c70565e2674ce 100644 --- a/sklearn/utils/tests/test_param_validation.py +++ b/sklearn/utils/tests/test_param_validation.py @@ -306,24 +306,24 @@ def test_generate_invalid_param_val_2_intervals(integer_interval, real_interval) @pytest.mark.parametrize( - "constraints", + "constraint", [ - [_ArrayLikes()], - [_InstancesOf(list)], - [_Callables()], - [_NoneConstraint()], - [_RandomStates()], - [_SparseMatrices()], - [_Booleans()], - [Interval(Integral, None, None, closed="neither")], + _ArrayLikes(), + _InstancesOf(list), + _Callables(), + _NoneConstraint(), + _RandomStates(), + _SparseMatrices(), + _Booleans(), + Interval(Integral, None, None, closed="neither"), ], ) -def test_generate_invalid_param_val_all_valid(constraints): +def test_generate_invalid_param_val_all_valid(constraint): """Check that the function raises NotImplementedError when there's no invalid value for the constraint. """ with pytest.raises(NotImplementedError): - generate_invalid_param_val(constraints[0]) + generate_invalid_param_val(constraint) @pytest.mark.parametrize( From a6a09b0df126f7e0c1348f1b49765ffc424d342f Mon Sep 17 00:00:00 2001 From: jeremiedbb Date: Wed, 15 Mar 2023 16:38:42 +0100 Subject: [PATCH 4/4] add test for the RealNotInt type --- sklearn/utils/tests/test_param_validation.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/sklearn/utils/tests/test_param_validation.py b/sklearn/utils/tests/test_param_validation.py index c70565e2674ce..bfbbf2a1b29df 100644 --- a/sklearn/utils/tests/test_param_validation.py +++ b/sklearn/utils/tests/test_param_validation.py @@ -662,3 +662,11 @@ def test_interval_real_not_int(): constraint = Interval(RealNotInt, 0, 1, closed="both") assert constraint.is_satisfied_by(1.0) assert not constraint.is_satisfied_by(1) + + +def test_real_not_int(): + """Check for the RealNotInt type.""" + assert isinstance(1.0, RealNotInt) + assert not isinstance(1, RealNotInt) + assert isinstance(np.float64(1), RealNotInt) + assert not isinstance(np.int64(1), RealNotInt)