From b4c32638c36e94e140bd269599501254015afe34 Mon Sep 17 00:00:00 2001 From: Pooja Subramaniam Date: Wed, 15 Mar 2023 16:56:25 +0100 Subject: [PATCH 1/2] remove simple validation statements and replace with validate_params decorator --- sklearn/tests/test_public_functions.py | 1 + sklearn/utils/__init__.py | 14 ++++++++------ 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index a9fff191b06a7..ad650e80f2716 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -170,6 +170,7 @@ def _check_function_param_validation( "sklearn.model_selection.train_test_split", "sklearn.random_projection.johnson_lindenstrauss_min_dim", "sklearn.svm.l1_min_c", + "sklearn.utils.gen_batches", ] diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 923c08d44c6f4..0b3119c0bfa03 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -39,6 +39,7 @@ ) from .. import get_config from ._bunch import Bunch +from ._param_validation import validate_params, Interval # Do not deprecate parallel_backend and register_parallel_backend as they are @@ -725,6 +726,13 @@ def _chunk_generator(gen, chunksize): return +@validate_params( + { + "n": [Interval(numbers.Integral, 1, None, closed="left")], + "batch_size": [Interval(numbers.Integral, 1, None, closed="left")], + "min_batch_size": [Interval(numbers.Integral, 0, None, closed="left")], + } +) def gen_batches(n, batch_size, *, min_batch_size=0): """Generator to create slices containing `batch_size` elements from 0 to `n`. @@ -762,12 +770,6 @@ def gen_batches(n, batch_size, *, min_batch_size=0): >>> list(gen_batches(7, 3, min_batch_size=2)) [slice(0, 3, None), slice(3, 7, None)] """ - if not isinstance(batch_size, numbers.Integral): - raise TypeError( - "gen_batches got batch_size=%s, must be an integer" % batch_size - ) - if batch_size <= 0: - raise ValueError("gen_batches got batch_size=%s, must be positive" % batch_size) start = 0 for _ in range(int(n // batch_size)): end = start + batch_size From adbebc1efc6f15f09b55de14b6c6b6fb2ecf4992 Mon Sep 17 00:00:00 2001 From: Pooja Subramaniam Date: Wed, 15 Mar 2023 18:53:08 +0100 Subject: [PATCH 2/2] removed outdated test test_gen_batches --- sklearn/utils/tests/test_utils.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index 848985f267c92..a000394bbee28 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -17,7 +17,6 @@ from sklearn.utils import check_random_state from sklearn.utils import _determine_key_type from sklearn.utils import deprecated -from sklearn.utils import gen_batches from sklearn.utils import _get_column_indices from sklearn.utils import resample from sklearn.utils import safe_mask @@ -56,19 +55,6 @@ def test_make_rng(): check_random_state("some invalid seed") -def test_gen_batches(): - # Make sure gen_batches errors on invalid batch_size - - assert_array_equal(list(gen_batches(4, 2)), [slice(0, 2, None), slice(2, 4, None)]) - msg_zero = "gen_batches got batch_size=0, must be positive" - with pytest.raises(ValueError, match=msg_zero): - next(gen_batches(4, 0)) - - msg_float = "gen_batches got batch_size=0.5, must be an integer" - with pytest.raises(TypeError, match=msg_float): - next(gen_batches(4, 0.5)) - - def test_deprecated(): # Test whether the deprecated decorator issues appropriate warnings # Copied almost verbatim from https://docs.python.org/library/warnings.html