From c4d2e407cd6a7cdd2f684971d39d1a71762209e4 Mon Sep 17 00:00:00 2001 From: Shivachauhan17 Date: Wed, 15 Mar 2023 11:16:05 +0530 Subject: [PATCH 1/4] added parametr validation to sklearn.datasets.fetch_lfw_pair --- sklearn/datasets/_lfw.py | 11 +++++++++++ sklearn/tests/test_public_functions.py | 1 + 2 files changed, 12 insertions(+) diff --git a/sklearn/datasets/_lfw.py b/sklearn/datasets/_lfw.py index c6f1a5f9a90c8..31b72d1a4abcb 100644 --- a/sklearn/datasets/_lfw.py +++ b/sklearn/datasets/_lfw.py @@ -427,6 +427,17 @@ def _fetch_lfw_pairs( return pairs, target, np.array(["Different persons", "Same person"]) +@validate_params( + { + "subset": ["train", "test", "10_folds"], + "data_home": [str, None], + "funneled": ["boolean"], + "resize": [Interval(Real, 0, None, closed="neither"), None], + "color": ["boolean"], + "slice_": [tuple, Hidden(None)], + "download_if_missing": ["boolean"], + } +) def fetch_lfw_pairs( *, subset="train", diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index a9fff191b06a7..c05cbc31b2da5 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -112,6 +112,7 @@ def _check_function_param_validation( "sklearn.datasets.fetch_olivetti_faces", "sklearn.datasets.load_svmlight_file", "sklearn.datasets.load_svmlight_files", + "sklearn.datasets.fetch_lfw_pairs", "sklearn.datasets.make_circles", "sklearn.datasets.make_classification", "sklearn.datasets.make_friedman1", From d001451b9a1a52330dcce87c57f4a439ec727d71 Mon Sep 17 00:00:00 2001 From: Shiva chauhan <103742975+Shivachauhan17@users.noreply.github.com> Date: Thu, 16 Mar 2023 17:42:56 +0530 Subject: [PATCH 2/4] Update _lfw.py --- sklearn/datasets/_lfw.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/datasets/_lfw.py b/sklearn/datasets/_lfw.py index 31b72d1a4abcb..01f993fc52873 100644 --- a/sklearn/datasets/_lfw.py +++ b/sklearn/datasets/_lfw.py @@ -10,7 +10,7 @@ from os import listdir, makedirs, remove from os.path import join, exists, isdir -from ..utils._param_validation import validate_params, Interval, Hidden +from ..utils._param_validation import validate_params, Interval, Hidden, StrOptions from numbers import Integral, Real import logging @@ -429,7 +429,7 @@ def _fetch_lfw_pairs( @validate_params( { - "subset": ["train", "test", "10_folds"], + "subset": [StrOptions({"train", "test", "10_folds"})], "data_home": [str, None], "funneled": ["boolean"], "resize": [Interval(Real, 0, None, closed="neither"), None], From 196546b914327ef35e503e0234de22f38b7bdb42 Mon Sep 17 00:00:00 2001 From: Shiva chauhan <103742975+Shivachauhan17@users.noreply.github.com> Date: Thu, 16 Mar 2023 18:03:25 +0530 Subject: [PATCH 3/4] Update test_public_functions.py --- sklearn/tests/test_public_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index c05cbc31b2da5..5ffa7b468053a 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -109,10 +109,10 @@ def _check_function_param_validation( "sklearn.datasets.fetch_covtype", "sklearn.datasets.fetch_kddcup99", "sklearn.datasets.fetch_lfw_people", + "sklearn.datasets.fetch_lfw_pairs", "sklearn.datasets.fetch_olivetti_faces", "sklearn.datasets.load_svmlight_file", "sklearn.datasets.load_svmlight_files", - "sklearn.datasets.fetch_lfw_pairs", "sklearn.datasets.make_circles", "sklearn.datasets.make_classification", "sklearn.datasets.make_friedman1", From a671a81c695df3650554056ccca53c5704d4d358 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= <34657725+jeremiedbb@users.noreply.github.com> Date: Thu, 16 Mar 2023 14:56:07 +0100 Subject: [PATCH 4/4] alphabetical order --- sklearn/tests/test_public_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 5ffa7b468053a..4e22b5043c1de 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -108,8 +108,8 @@ def _check_function_param_validation( "sklearn.datasets.fetch_california_housing", "sklearn.datasets.fetch_covtype", "sklearn.datasets.fetch_kddcup99", - "sklearn.datasets.fetch_lfw_people", "sklearn.datasets.fetch_lfw_pairs", + "sklearn.datasets.fetch_lfw_people", "sklearn.datasets.fetch_olivetti_faces", "sklearn.datasets.load_svmlight_file", "sklearn.datasets.load_svmlight_files",