From 00597594c3eae7432de0e0ebf6c3c83d84bf499d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9ophile=20Baranger?= Date: Mon, 27 Mar 2023 21:16:43 +0200 Subject: [PATCH 1/5] add parameters validation to datasets.make_blobs --- sklearn/datasets/_samples_generator.py | 12 ++++++++++++ sklearn/tests/test_public_functions.py | 1 + 2 files changed, 13 insertions(+) diff --git a/sklearn/datasets/_samples_generator.py b/sklearn/datasets/_samples_generator.py index 3913ad6eedfb4..e39bf5b0a1033 100644 --- a/sklearn/datasets/_samples_generator.py +++ b/sklearn/datasets/_samples_generator.py @@ -847,6 +847,18 @@ def make_moons(n_samples=100, *, shuffle=True, noise=None, random_state=None): return X, y +@validate_params( + { + "n_samples": [Interval(Integral, 1, None, closed="left"), "array-like"], + "n_features": [Interval(Integral, 1, None, closed="left")], + "centers": [Interval(Integral, 1, None, closed="left"), "array-like", None], + "cluster_std": [Interval(Real, 0, None, closed="left"), "array-like"], + "center_box": [tuple], + "shuffle": ["boolean"], + "random_state": ["random_state"], + "return_centers": ["boolean"], + } +) def make_blobs( n_samples=100, n_features=2, diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 5579ad0215ff1..e7e9dc7bdddf6 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -129,6 +129,7 @@ def _check_function_param_validation( "sklearn.datasets.fetch_olivetti_faces", "sklearn.datasets.load_svmlight_file", "sklearn.datasets.load_svmlight_files", + "sklearn.datasets.make_blobs", "sklearn.datasets.make_checkerboard", "sklearn.datasets.make_circles", "sklearn.datasets.make_classification", From 160806c516acf2fee0bdfb9d8638348adb9952ee Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 29 Mar 2023 10:51:15 +0200 Subject: [PATCH 2/5] Update sklearn/tests/test_public_functions.py --- sklearn/tests/test_public_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index b5fd664d3d734..85d9b083d57a0 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -130,7 +130,7 @@ def _check_function_param_validation( "sklearn.datasets.load_svmlight_file", "sklearn.datasets.load_svmlight_files", "sklearn.datasets.make_biclusters", - "sklearn.datasets.make_blobs", + "sklearn.datasets.make_blobs", "sklearn.datasets.make_checkerboard", "sklearn.datasets.make_circles", "sklearn.datasets.make_classification", From 6bcaf392117d08ff1279efbe93bd44e983dca37f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9ophile=20Baranger?= Date: Wed, 29 Mar 2023 12:08:54 +0200 Subject: [PATCH 3/5] update centers type to array-like in docstring --- sklearn/datasets/_samples_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/datasets/_samples_generator.py b/sklearn/datasets/_samples_generator.py index 619bbd8afcdfb..6f6176e9f6bbe 100644 --- a/sklearn/datasets/_samples_generator.py +++ b/sklearn/datasets/_samples_generator.py @@ -896,7 +896,7 @@ def make_blobs( n_features : int, default=2 The number of features for each sample. - centers : int or ndarray of shape (n_centers, n_features), default=None + centers : int or array-like of shape (n_centers, n_features), default=None The number of centers to generate, or the fixed center locations. If n_samples is an int and centers is None, 3 centers are generated. If n_samples is array-like, centers must be From f23cd9067a4938ea10a934de2b61b919036bab55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9ophile=20Baranger?= Date: Wed, 29 Mar 2023 18:08:57 +0200 Subject: [PATCH 4/5] remove try except to improve readability --- sklearn/datasets/_samples_generator.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/sklearn/datasets/_samples_generator.py b/sklearn/datasets/_samples_generator.py index 6f6176e9f6bbe..e160f035f168a 100644 --- a/sklearn/datasets/_samples_generator.py +++ b/sklearn/datasets/_samples_generator.py @@ -979,19 +979,17 @@ def make_blobs( centers = generator.uniform( center_box[0], center_box[1], size=(n_centers, n_features) ) - try: - assert len(centers) == n_centers - except TypeError as e: + if not isinstance(centers, Iterable): raise ValueError( "Parameter `centers` must be array-like. Got {!r} instead".format( centers ) - ) from e - except AssertionError as e: + ) + if len(centers) != n_centers: raise ValueError( "Length of `n_samples` not consistent with number of " f"centers. Got n_samples = {n_samples} and centers = {centers}" - ) from e + ) else: centers = check_array(centers) n_features = centers.shape[1] From a32c13b8c7a75358adccd20df350c69aeca2292a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9ophile=20Baranger?= Date: Wed, 29 Mar 2023 18:34:25 +0200 Subject: [PATCH 5/5] remove redundant else and unindent block --- sklearn/datasets/_samples_generator.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sklearn/datasets/_samples_generator.py b/sklearn/datasets/_samples_generator.py index e160f035f168a..de09a1023cce4 100644 --- a/sklearn/datasets/_samples_generator.py +++ b/sklearn/datasets/_samples_generator.py @@ -990,9 +990,8 @@ def make_blobs( "Length of `n_samples` not consistent with number of " f"centers. Got n_samples = {n_samples} and centers = {centers}" ) - else: - centers = check_array(centers) - n_features = centers.shape[1] + centers = check_array(centers) + n_features = centers.shape[1] # stds: if cluster_std is given as list, it must be consistent # with the n_centers