diff --git a/sklearn/cluster/mean_shift_.py b/sklearn/cluster/mean_shift_.py index ce5dac8b5a318..7e93e715b7585 100644 --- a/sklearn/cluster/mean_shift_.py +++ b/sklearn/cluster/mean_shift_.py @@ -183,8 +183,8 @@ def mean_shift(X, bandwidth=None, seeds=None, bin_seeding=False, if bandwidth is None: bandwidth = estimate_bandwidth(X, n_jobs=n_jobs) elif bandwidth <= 0: - raise ValueError("bandwidth needs to be greater than zero or None,\ - got %f" % bandwidth) + raise ValueError("bandwidth needs to be greater than zero or None," + " got %f" % bandwidth) if seeds is None: if bin_seeding: seeds = get_bin_seeds(X, bandwidth, min_bin_freq) diff --git a/sklearn/cluster/tests/test_mean_shift.py b/sklearn/cluster/tests/test_mean_shift.py index 233106518f86e..6ea5cb8bda1d3 100644 --- a/sklearn/cluster/tests/test_mean_shift.py +++ b/sklearn/cluster/tests/test_mean_shift.py @@ -5,10 +5,10 @@ import numpy as np import warnings +import pytest from scipy import sparse -from sklearn.utils.testing import assert_equal from sklearn.utils.testing import assert_array_equal from sklearn.utils.testing import assert_array_almost_equal from sklearn.utils.testing import assert_raise_message @@ -36,23 +36,35 @@ def test_estimate_bandwidth_1sample(): # Test estimate_bandwidth when n_samples=1 and quantile<1, so that # n_neighbors is set to 1. bandwidth = estimate_bandwidth(X, n_samples=1, quantile=0.3) - assert_array_almost_equal(bandwidth, 0., decimal=5) + assert bandwidth == 0. -def test_mean_shift(): +@pytest.mark.parametrize("bandwidth, cluster_all, expected, " + "first_cluster_label", + [(1.2, True, 3, 0), (1.2, False, 4, -1)]) +def test_mean_shift(bandwidth, cluster_all, expected, first_cluster_label): # Test MeanShift algorithm - bandwidth = 1.2 - - ms = MeanShift(bandwidth=bandwidth) + ms = MeanShift(bandwidth=bandwidth, cluster_all=cluster_all) labels = ms.fit(X).labels_ labels_unique = np.unique(labels) n_clusters_ = len(labels_unique) - assert_equal(n_clusters_, n_clusters) + assert n_clusters_ == expected + assert labels_unique[0] == first_cluster_label - cluster_centers, labels = mean_shift(X, bandwidth=bandwidth) - labels_unique = np.unique(labels) - n_clusters_ = len(labels_unique) - assert_equal(n_clusters_, n_clusters) + cluster_centers, labels_mean_shift = mean_shift(X, cluster_all=cluster_all) + labels_mean_shift_unique = np.unique(labels_mean_shift) + n_clusters_mean_shift = len(labels_mean_shift_unique) + assert n_clusters_mean_shift == expected + assert labels_mean_shift_unique[0] == first_cluster_label + + +def test_mean_shift_negative_bandwidth(): + bandwidth = -1 + ms = MeanShift(bandwidth=bandwidth) + msg = (r"bandwidth needs to be greater than zero or None," + r" got -1\.000000") + with pytest.raises(ValueError, match=msg): + ms.fit(X) def test_estimate_bandwidth_with_sparse_matrix():