From dccb357b1da38e918aa55ddb21fe9b5467f1c5c1 Mon Sep 17 00:00:00 2001 From: Shivachauhan17 Date: Sat, 4 Mar 2023 14:34:34 +0530 Subject: [PATCH 1/2] add parameter validation to sklearn.datasets.fetch_covtype --- sklearn/datasets/_covtype.py | 12 +++++++++++- sklearn/tests/test_public_functions.py | 1 + 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/sklearn/datasets/_covtype.py b/sklearn/datasets/_covtype.py index b43ea24141eed..60364fb5f408c 100644 --- a/sklearn/datasets/_covtype.py +++ b/sklearn/datasets/_covtype.py @@ -31,7 +31,7 @@ from ..utils import Bunch from ._base import _pkl_filepath from ..utils import check_random_state - +from ..utils._param_validation import validate_params # The original data can be found in: # https://archive.ics.uci.edu/ml/machine-learning-databases/covtype/covtype.data.gz @@ -62,6 +62,16 @@ TARGET_NAMES = ["Cover_Type"] +@validate_params( + { + "data_home": [str, None], + "download_if_missing": ["boolean"], + "random_state": ["random_state"], + "shuffle": ["boolean"], + "return_X_y": ["boolean"], + "as_frame": ["boolean"], + } +) def fetch_covtype( *, data_home=None, diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index cbe75a57a3705..301e82e6ae35f 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -108,6 +108,7 @@ def _check_function_param_validation( "sklearn.datasets.make_classification", "sklearn.datasets.make_friedman1", "sklearn.datasets.make_sparse_coded_signal", + "sklearn.datasets.fetch_covtype", "sklearn.decomposition.sparse_encode", "sklearn.feature_extraction.grid_to_graph", "sklearn.feature_extraction.img_to_graph", From ba57ece579c3eeb731ae00c2577f6bd908babb5a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= <34657725+jeremiedbb@users.noreply.github.com> Date: Mon, 6 Mar 2023 14:41:15 +0100 Subject: [PATCH 2/2] Update test_public_functions.py --- sklearn/tests/test_public_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 301e82e6ae35f..f0a43dc5612c9 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -104,11 +104,11 @@ def _check_function_param_validation( "sklearn.covariance.shrunk_covariance", "sklearn.datasets.dump_svmlight_file", "sklearn.datasets.fetch_california_housing", + "sklearn.datasets.fetch_covtype", "sklearn.datasets.fetch_kddcup99", "sklearn.datasets.make_classification", "sklearn.datasets.make_friedman1", "sklearn.datasets.make_sparse_coded_signal", - "sklearn.datasets.fetch_covtype", "sklearn.decomposition.sparse_encode", "sklearn.feature_extraction.grid_to_graph", "sklearn.feature_extraction.img_to_graph",