From 822673a227e541dbacd608a064e05434175b0987 Mon Sep 17 00:00:00 2001
From: Charlie-XIAO
Date: Sat, 8 Apr 2023 09:14:48 +0800
Subject: [PATCH] MAINT Parameters validation for sklearn.datasets.fetch_rcv1

---
 sklearn/datasets/_rcv1.py              | 11 +++++++++++
 sklearn/tests/test_public_functions.py |  1 +
 2 files changed, 12 insertions(+)

diff --git a/sklearn/datasets/_rcv1.py b/sklearn/datasets/_rcv1.py
index cca30afefff34..0586f585e4221 100644
--- a/sklearn/datasets/_rcv1.py
+++ b/sklearn/datasets/_rcv1.py
@@ -26,6 +26,7 @@
 from ._svmlight_format_io import load_svmlight_files
 from ..utils import shuffle as shuffle_
 from ..utils import Bunch
+from ..utils._param_validation import validate_params, StrOptions
 
 
 # The original vectorized data can be found at:
@@ -76,6 +77,16 @@
 logger = logging.getLogger(__name__)
 
 
+@validate_params(
+    {
+        "data_home": [str, None],
+        "subset": [StrOptions({"train", "test", "all"})],
+        "download_if_missing": ["boolean"],
+        "random_state": ["random_state"],
+        "shuffle": ["boolean"],
+        "return_X_y": ["boolean"],
+    }
+)
 def fetch_rcv1(
     *,
     data_home=None,
diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py
index a4a9dbd9db739..3e31fbdd547a5 100644
--- a/sklearn/tests/test_public_functions.py
+++ b/sklearn/tests/test_public_functions.py
@@ -127,6 +127,7 @@ def _check_function_param_validation(
     "sklearn.datasets.fetch_lfw_pairs",
     "sklearn.datasets.fetch_lfw_people",
     "sklearn.datasets.fetch_olivetti_faces",
+    "sklearn.datasets.fetch_rcv1",
     "sklearn.datasets.load_svmlight_file",
     "sklearn.datasets.load_svmlight_files",
     "sklearn.datasets.make_biclusters",