diff --git a/sklearn/datasets/_base.py b/sklearn/datasets/_base.py index ed94d5ec1f0d7..70307ebf63c03 100644 --- a/sklearn/datasets/_base.py +++ b/sklearn/datasets/_base.py @@ -15,13 +15,14 @@ from os import environ, listdir, makedirs from os.path import expanduser, isdir, join, splitext from pathlib import Path +from numbers import Integral from ..preprocessing import scale from ..utils import Bunch from ..utils import check_random_state from ..utils import check_pandas_support from ..utils.fixes import _open_binary, _open_text, _read_text, _contents -from ..utils._param_validation import validate_params +from ..utils._param_validation import validate_params, Interval import numpy as np @@ -820,6 +821,13 @@ def load_breast_cancer(*, return_X_y=False, as_frame=False): ) +@validate_params( + { + "n_class": [Interval(Integral, 1, 10, closed="both")], + "return_X_y": ["boolean"], + "as_frame": ["boolean"], + } +) def load_digits(*, n_class=10, return_X_y=False, as_frame=False): """Load and return the digits dataset (classification). diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index d46ae07821ac2..159eaaa17ba07 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -131,6 +131,7 @@ def _check_function_param_validation( "sklearn.datasets.fetch_species_distributions", "sklearn.datasets.load_breast_cancer", "sklearn.datasets.load_diabetes", + "sklearn.datasets.load_digits", "sklearn.datasets.load_iris", "sklearn.datasets.load_svmlight_file", "sklearn.datasets.load_svmlight_files",