From 3fcb4c29b10b397ec5e26b164b62f96493d3cbcd Mon Sep 17 00:00:00 2001 From: pm155 Date: Mon, 23 Jan 2023 14:59:43 -0600 Subject: [PATCH 1/3] MAINT Parameters validation for sklearn.datasets.fetch_kddcup99 --- sklearn/datasets/_kddcup99.py | 14 +++++++++++++- sklearn/tests/test_public_functions.py | 1 + 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/sklearn/datasets/_kddcup99.py b/sklearn/datasets/_kddcup99.py index 59cc27a6877a0..5f70d9b060c82 100644 --- a/sklearn/datasets/_kddcup99.py +++ b/sklearn/datasets/_kddcup99.py @@ -45,7 +45,19 @@ logger = logging.getLogger(__name__) - + @validate_params( + { + "subset": [StrOptions({"SA", "SF", "http","smtp"}),None], + "data_home":[str, None], + "shuffle":["boolean"], + "random_state":["random_state"], + "percent10":["boolean"], + "download_if_missing":["boolean"], + "return_X_y":["boolean"], + "as_frame":["boolean"], + + } + ) def fetch_kddcup99( *, subset=None, diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index acc44ce60c755..cc4fa9b4e0c79 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -102,6 +102,7 @@ def _check_function_param_validation( "sklearn.covariance.empirical_covariance", "sklearn.covariance.shrunk_covariance", "sklearn.datasets.fetch_california_housing", + "sklearn.datasets.fetch_kddcup99", "sklearn.datasets.make_sparse_coded_signal", "sklearn.decomposition.sparse_encode", "sklearn.feature_extraction.grid_to_graph", From 227cb9c74018e9f099d89cc72e2d83053e9a460a Mon Sep 17 00:00:00 2001 From: pm155 Date: Tue, 24 Jan 2023 15:51:02 -0600 Subject: [PATCH 2/3] MAINT Parameters validation for sklearn.datasets.fetch_kddcup99 --- sklearn/datasets/_kddcup99.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sklearn/datasets/_kddcup99.py b/sklearn/datasets/_kddcup99.py index 5f70d9b060c82..de1b30dcf8c13 100644 --- a/sklearn/datasets/_kddcup99.py +++ b/sklearn/datasets/_kddcup99.py @@ -22,6 +22,8 @@ from . import get_data_home from ._base import RemoteFileMetadata from ._base import load_descr +from ..utils._param_validation import StrOptions +from ..utils._param_validation import validate_params from ..utils import Bunch from ..utils import check_random_state from ..utils import shuffle as shuffle_method @@ -45,7 +47,7 @@ logger = logging.getLogger(__name__) - @validate_params( +@validate_params( { "subset": [StrOptions({"SA", "SF", "http","smtp"}),None], "data_home":[str, None], @@ -57,7 +59,7 @@ "as_frame":["boolean"], } - ) +) def fetch_kddcup99( *, subset=None, From 9610895708bbce397849b689ae2f32f49be75ecf Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 2 Feb 2023 11:12:48 +0100 Subject: [PATCH 3/3] MAINT black compliance --- sklearn/datasets/_kddcup99.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/sklearn/datasets/_kddcup99.py b/sklearn/datasets/_kddcup99.py index de1b30dcf8c13..3a49d612af850 100644 --- a/sklearn/datasets/_kddcup99.py +++ b/sklearn/datasets/_kddcup99.py @@ -22,8 +22,7 @@ from . import get_data_home from ._base import RemoteFileMetadata from ._base import load_descr -from ..utils._param_validation import StrOptions -from ..utils._param_validation import validate_params +from ..utils._param_validation import StrOptions, validate_params from ..utils import Bunch from ..utils import check_random_state from ..utils import shuffle as shuffle_method @@ -47,18 +46,18 @@ logger = logging.getLogger(__name__) -@validate_params( - { - "subset": [StrOptions({"SA", "SF", "http","smtp"}),None], - "data_home":[str, None], - "shuffle":["boolean"], - "random_state":["random_state"], - "percent10":["boolean"], - "download_if_missing":["boolean"], - "return_X_y":["boolean"], - "as_frame":["boolean"], - - } + +@validate_params( + { + "subset": [StrOptions({"SA", "SF", "http", "smtp"}), None], + "data_home": [str, None], + "shuffle": ["boolean"], + "random_state": ["random_state"], + "percent10": ["boolean"], + "download_if_missing": ["boolean"], + "return_X_y": ["boolean"], + "as_frame": ["boolean"], + } ) def fetch_kddcup99( *,