diff --git a/sklearn/datasets/_samples_generator.py b/sklearn/datasets/_samples_generator.py index 7b078eda40357..de09a1023cce4 100644 --- a/sklearn/datasets/_samples_generator.py +++ b/sklearn/datasets/_samples_generator.py @@ -855,6 +855,18 @@ def make_moons(n_samples=100, *, shuffle=True, noise=None, random_state=None): return X, y +@validate_params( + { + "n_samples": [Interval(Integral, 1, None, closed="left"), "array-like"], + "n_features": [Interval(Integral, 1, None, closed="left")], + "centers": [Interval(Integral, 1, None, closed="left"), "array-like", None], + "cluster_std": [Interval(Real, 0, None, closed="left"), "array-like"], + "center_box": [tuple], + "shuffle": ["boolean"], + "random_state": ["random_state"], + "return_centers": ["boolean"], + } +) def make_blobs( n_samples=100, n_features=2, @@ -884,7 +896,7 @@ def make_blobs( n_features : int, default=2 The number of features for each sample. - centers : int or ndarray of shape (n_centers, n_features), default=None + centers : int or array-like of shape (n_centers, n_features), default=None The number of centers to generate, or the fixed center locations. If n_samples is an int and centers is None, 3 centers are generated. If n_samples is array-like, centers must be @@ -967,22 +979,19 @@ def make_blobs( centers = generator.uniform( center_box[0], center_box[1], size=(n_centers, n_features) ) - try: - assert len(centers) == n_centers - except TypeError as e: + if not isinstance(centers, Iterable): raise ValueError( "Parameter `centers` must be array-like. Got {!r} instead".format( centers ) - ) from e - except AssertionError as e: + ) + if len(centers) != n_centers: raise ValueError( "Length of `n_samples` not consistent with number of " f"centers. Got n_samples = {n_samples} and centers = {centers}" - ) from e - else: - centers = check_array(centers) - n_features = centers.shape[1] + ) + centers = check_array(centers) + n_features = centers.shape[1] # stds: if cluster_std is given as list, it must be consistent # with the n_centers diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index d960639cee7fd..85d9b083d57a0 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -130,6 +130,7 @@ def _check_function_param_validation( "sklearn.datasets.load_svmlight_file", "sklearn.datasets.load_svmlight_files", "sklearn.datasets.make_biclusters", + "sklearn.datasets.make_blobs", "sklearn.datasets.make_checkerboard", "sklearn.datasets.make_circles", "sklearn.datasets.make_classification",