Skip to content

MAINT Parameters validation for utils.gen_batches #25864

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sklearn/tests/test_public_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def _check_function_param_validation(
"sklearn.model_selection.train_test_split",
"sklearn.random_projection.johnson_lindenstrauss_min_dim",
"sklearn.svm.l1_min_c",
"sklearn.utils.gen_batches",
]


Expand Down
14 changes: 8 additions & 6 deletions sklearn/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
)
from .. import get_config
from ._bunch import Bunch
from ._param_validation import validate_params, Interval


# Do not deprecate parallel_backend and register_parallel_backend as they are
Expand Down Expand Up @@ -725,6 +726,13 @@ def _chunk_generator(gen, chunksize):
return


@validate_params(
{
"n": [Interval(numbers.Integral, 1, None, closed="left")],
"batch_size": [Interval(numbers.Integral, 1, None, closed="left")],
"min_batch_size": [Interval(numbers.Integral, 0, None, closed="left")],
}
)
def gen_batches(n, batch_size, *, min_batch_size=0):
"""Generator to create slices containing `batch_size` elements from 0 to `n`.

Expand Down Expand Up @@ -762,12 +770,6 @@ def gen_batches(n, batch_size, *, min_batch_size=0):
>>> list(gen_batches(7, 3, min_batch_size=2))
[slice(0, 3, None), slice(3, 7, None)]
"""
if not isinstance(batch_size, numbers.Integral):
raise TypeError(
"gen_batches got batch_size=%s, must be an integer" % batch_size
)
if batch_size <= 0:
raise ValueError("gen_batches got batch_size=%s, must be positive" % batch_size)
start = 0
for _ in range(int(n // batch_size)):
end = start + batch_size
Expand Down
14 changes: 0 additions & 14 deletions sklearn/utils/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from sklearn.utils import check_random_state
from sklearn.utils import _determine_key_type
from sklearn.utils import deprecated
from sklearn.utils import gen_batches
from sklearn.utils import _get_column_indices
from sklearn.utils import resample
from sklearn.utils import safe_mask
Expand Down Expand Up @@ -56,19 +55,6 @@ def test_make_rng():
check_random_state("some invalid seed")


def test_gen_batches():
    """Check gen_batches output on valid input and its errors on bad batch_size."""
    # A valid call must split range(4) into two slices of two elements each.
    expected_slices = [slice(0, 2), slice(2, 4)]
    produced_slices = list(gen_batches(4, 2))
    assert_array_equal(produced_slices, expected_slices)

    # (batch_size, expected exception type, expected message pattern)
    invalid_cases = (
        (0, ValueError, "gen_batches got batch_size=0, must be positive"),
        (0.5, TypeError, "gen_batches got batch_size=0.5, must be an integer"),
    )
    for bad_batch_size, error_type, pattern in invalid_cases:
        with pytest.raises(error_type, match=pattern):
            # gen_batches is a generator: nothing runs until the first item
            # is requested, so advance it once to trigger the validation.
            next(gen_batches(4, bad_batch_size))


def test_deprecated():
# Test whether the deprecated decorator issues appropriate warnings
# Copied almost verbatim from https://docs.python.org/library/warnings.html
Expand Down