From 5143a01096c1242f28bed6980914c9c1ff0136e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9ophile=20Baranger?= Date: Fri, 24 Mar 2023 22:55:19 +0100 Subject: [PATCH 1/7] add parameters validation for datasets.make_moons --- sklearn/datasets/_samples_generator.py | 8 ++++++++ sklearn/tests/test_public_functions.py | 1 + 2 files changed, 9 insertions(+) diff --git a/sklearn/datasets/_samples_generator.py b/sklearn/datasets/_samples_generator.py index 192f8e6759ddb..3c8a5b3062df9 100644 --- a/sklearn/datasets/_samples_generator.py +++ b/sklearn/datasets/_samples_generator.py @@ -772,6 +772,14 @@ def make_circles( return X, y +@validate_params( + { + "n_samples": [Interval(Integral, 1, None, closed="left"), tuple], + "shuffle": ["boolean"], + "noise": [Interval(Real, 0, None, closed="left"), None], + "random_state": ["random_state"], + } +) def make_moons(n_samples=100, *, shuffle=True, noise=None, random_state=None): """Make two interleaving half circles. diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 4b1934c378fbf..43d896db12403 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -133,6 +133,7 @@ def _check_function_param_validation( "sklearn.datasets.make_classification", "sklearn.datasets.make_friedman1", "sklearn.datasets.make_low_rank_matrix", + "sklearn.datasets.make_moons", "sklearn.datasets.make_multilabel_classification", "sklearn.datasets.make_regression", "sklearn.datasets.make_sparse_coded_signal", From 6c9f77848d9bf1cb7d6a08b2f087ed102c31fc4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9ophile=20Baranger?= Date: Fri, 24 Mar 2023 23:01:30 +0100 Subject: [PATCH 2/7] remove in-function parameter validation for datasets.make_moons --- sklearn/datasets/_samples_generator.py | 7 +------ sklearn/datasets/tests/test_samples_generator.py | 12 ------------ 2 files changed, 1 insertion(+), 18 deletions(-) diff --git a/sklearn/datasets/_samples_generator.py b/sklearn/datasets/_samples_generator.py index 3c8a5b3062df9..d3d1019e29339 100644 --- a/sklearn/datasets/_samples_generator.py +++ b/sklearn/datasets/_samples_generator.py @@ -819,12 +819,7 @@ def make_moons(n_samples=100, *, shuffle=True, noise=None, random_state=None): n_samples_out = n_samples // 2 n_samples_in = n_samples - n_samples_out else: - try: - n_samples_out, n_samples_in = n_samples - except ValueError as e: - raise ValueError( - "`n_samples` can be either an int or a two-element tuple." - ) from e + n_samples_out, n_samples_in = n_samples generator = check_random_state(random_state) diff --git a/sklearn/datasets/tests/test_samples_generator.py b/sklearn/datasets/tests/test_samples_generator.py index e722ed5c4f02a..88f7fd886feaf 100644 --- a/sklearn/datasets/tests/test_samples_generator.py +++ b/sklearn/datasets/tests/test_samples_generator.py @@ -636,18 +636,6 @@ def test_make_moons_unbalanced(): assert X.shape == (12, 2), "X shape mismatch" assert y.shape == (12,), "y shape mismatch" - with pytest.raises( - ValueError, - match=r"`n_samples` can be either an int " r"or a two-element tuple.", - ): - make_moons(n_samples=[1, 2, 3]) - - with pytest.raises( - ValueError, - match=r"`n_samples` can be either an int " r"or a two-element tuple.", - ): - make_moons(n_samples=(10,)) - def test_make_circles(): factor = 0.3 From 4148bc4cb682aa270bec14d96aa21f3f3f27f35e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9ophile=20Baranger?= Date: Mon, 27 Mar 2023 20:20:22 +0200 Subject: [PATCH 3/7] Revert "add parameters validation for datasets.make_moons" This reverts commit 5143a01096c1242f28bed6980914c9c1ff0136e1. --- sklearn/datasets/_samples_generator.py | 8 -------- sklearn/tests/test_public_functions.py | 1 - 2 files changed, 9 deletions(-) diff --git a/sklearn/datasets/_samples_generator.py b/sklearn/datasets/_samples_generator.py index d3d1019e29339..d07af8544b924 100644 --- a/sklearn/datasets/_samples_generator.py +++ b/sklearn/datasets/_samples_generator.py @@ -772,14 +772,6 @@ def make_circles( return X, y -@validate_params( - { - "n_samples": [Interval(Integral, 1, None, closed="left"), tuple], - "shuffle": ["boolean"], - "noise": [Interval(Real, 0, None, closed="left"), None], - "random_state": ["random_state"], - } -) def make_moons(n_samples=100, *, shuffle=True, noise=None, random_state=None): """Make two interleaving half circles. diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 43d896db12403..4b1934c378fbf 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -133,7 +133,6 @@ def _check_function_param_validation( "sklearn.datasets.make_classification", "sklearn.datasets.make_friedman1", "sklearn.datasets.make_low_rank_matrix", - "sklearn.datasets.make_moons", "sklearn.datasets.make_multilabel_classification", "sklearn.datasets.make_regression", "sklearn.datasets.make_sparse_coded_signal", From 4892064da4f7d8a1b9f921a1a7cae0dc85e2d9ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9ophile=20Baranger?= Date: Mon, 27 Mar 2023 20:20:57 +0200 Subject: [PATCH 4/7] Revert "remove in-function parameter validation for datasets.make_moons" This reverts commit 6c9f77848d9bf1cb7d6a08b2f087ed102c31fc4f. --- sklearn/datasets/_samples_generator.py | 7 ++++++- sklearn/datasets/tests/test_samples_generator.py | 12 ++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/sklearn/datasets/_samples_generator.py b/sklearn/datasets/_samples_generator.py index d07af8544b924..192f8e6759ddb 100644 --- a/sklearn/datasets/_samples_generator.py +++ b/sklearn/datasets/_samples_generator.py @@ -811,7 +811,12 @@ def make_moons(n_samples=100, *, shuffle=True, noise=None, random_state=None): n_samples_out = n_samples // 2 n_samples_in = n_samples - n_samples_out else: - n_samples_out, n_samples_in = n_samples + try: + n_samples_out, n_samples_in = n_samples + except ValueError as e: + raise ValueError( + "`n_samples` can be either an int or a two-element tuple." + ) from e generator = check_random_state(random_state) diff --git a/sklearn/datasets/tests/test_samples_generator.py b/sklearn/datasets/tests/test_samples_generator.py index 88f7fd886feaf..e722ed5c4f02a 100644 --- a/sklearn/datasets/tests/test_samples_generator.py +++ b/sklearn/datasets/tests/test_samples_generator.py @@ -636,6 +636,18 @@ def test_make_moons_unbalanced(): assert X.shape == (12, 2), "X shape mismatch" assert y.shape == (12,), "y shape mismatch" + with pytest.raises( + ValueError, + match=r"`n_samples` can be either an int " r"or a two-element tuple.", + ): + make_moons(n_samples=[1, 2, 3]) + + with pytest.raises( + ValueError, + match=r"`n_samples` can be either an int " r"or a two-element tuple.", + ): + make_moons(n_samples=(10,)) + def test_make_circles(): factor = 0.3 From b7edfec6dba09ea6683924bcb64ce6bdba6ea839 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9ophile=20Baranger?= Date: Mon, 27 Mar 2023 20:28:30 +0200 Subject: [PATCH 5/7] add parameters validation to datasets.make_moons --- sklearn/datasets/_samples_generator.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/sklearn/datasets/_samples_generator.py b/sklearn/datasets/_samples_generator.py index 192f8e6759ddb..3c8a5b3062df9 100644 --- a/sklearn/datasets/_samples_generator.py +++ b/sklearn/datasets/_samples_generator.py @@ -772,6 +772,14 @@ def make_circles( return X, y +@validate_params( + { + "n_samples": [Interval(Integral, 1, None, closed="left"), tuple], + "shuffle": ["boolean"], + "noise": [Interval(Real, 0, None, closed="left"), None], + "random_state": ["random_state"], + } +) def make_moons(n_samples=100, *, shuffle=True, noise=None, random_state=None): """Make two interleaving half circles. From 14346b57c7bed961fbe7bd5d3648f5b85c401c2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9ophile=20Baranger?= Date: Mon, 27 Mar 2023 20:30:02 +0200 Subject: [PATCH 6/7] remove obsolete datasets.make_moons unit test --- sklearn/datasets/tests/test_samples_generator.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/sklearn/datasets/tests/test_samples_generator.py b/sklearn/datasets/tests/test_samples_generator.py index e722ed5c4f02a..cd23fc5016672 100644 --- a/sklearn/datasets/tests/test_samples_generator.py +++ b/sklearn/datasets/tests/test_samples_generator.py @@ -636,12 +636,6 @@ def test_make_moons_unbalanced(): assert X.shape == (12, 2), "X shape mismatch" assert y.shape == (12,), "y shape mismatch" - with pytest.raises( - ValueError, - match=r"`n_samples` can be either an int " r"or a two-element tuple.", - ): - make_moons(n_samples=[1, 2, 3]) - with pytest.raises( ValueError, match=r"`n_samples` can be either an int " r"or a two-element tuple.", From 7e946b3345b1d83c1586b5a4e6ed43eb815c058c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9ophile=20Baranger?= Date: Mon, 27 Mar 2023 20:32:51 +0200 Subject: [PATCH 7/7] add datasets.make_moons to public functions tests --- sklearn/tests/test_public_functions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 4b1934c378fbf..43d896db12403 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -133,6 +133,7 @@ def _check_function_param_validation( "sklearn.datasets.make_classification", "sklearn.datasets.make_friedman1", "sklearn.datasets.make_low_rank_matrix", + "sklearn.datasets.make_moons", "sklearn.datasets.make_multilabel_classification", "sklearn.datasets.make_regression", "sklearn.datasets.make_sparse_coded_signal",