Skip to content

new tests for mean_shift algo #13179

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 25, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions sklearn/cluster/mean_shift_.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,8 @@ def mean_shift(X, bandwidth=None, seeds=None, bin_seeding=False,
if bandwidth is None:
bandwidth = estimate_bandwidth(X, n_jobs=n_jobs)
elif bandwidth <= 0:
raise ValueError("bandwidth needs to be greater than zero or None,\
got %f" % bandwidth)
raise ValueError("bandwidth needs to be greater than zero or None,"
" got %f" % bandwidth)
if seeds is None:
if bin_seeding:
seeds = get_bin_seeds(X, bandwidth, min_bin_freq)
Expand Down
34 changes: 23 additions & 11 deletions sklearn/cluster/tests/test_mean_shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

import numpy as np
import warnings
import pytest

from scipy import sparse

from sklearn.utils.testing import assert_equal
from sklearn.utils.testing import assert_array_equal
from sklearn.utils.testing import assert_array_almost_equal
from sklearn.utils.testing import assert_raise_message
Expand Down Expand Up @@ -36,23 +36,35 @@ def test_estimate_bandwidth_1sample():
# Test estimate_bandwidth when n_samples=1 and quantile<1, so that
# n_neighbors is set to 1.
bandwidth = estimate_bandwidth(X, n_samples=1, quantile=0.3)
assert_array_almost_equal(bandwidth, 0., decimal=5)
assert bandwidth == 0.


def test_mean_shift():
@pytest.mark.parametrize("bandwidth, cluster_all, expected, "
"first_cluster_label",
[(1.2, True, 3, 0), (1.2, False, 4, -1)])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Much clearer, thanks!

def test_mean_shift(bandwidth, cluster_all, expected, first_cluster_label):
# Test MeanShift algorithm
bandwidth = 1.2

ms = MeanShift(bandwidth=bandwidth)
ms = MeanShift(bandwidth=bandwidth, cluster_all=cluster_all)
labels = ms.fit(X).labels_
labels_unique = np.unique(labels)
n_clusters_ = len(labels_unique)
assert_equal(n_clusters_, n_clusters)
assert n_clusters_ == expected
assert labels_unique[0] == first_cluster_label

cluster_centers, labels = mean_shift(X, bandwidth=bandwidth)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removing this means we are not testing the mean_shift function directly anymore.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we are testing using
ms = MeanShift(bandwidth=bandwidth, cluster_all=cluster_all)
labels = ms.fit(X).labels_

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The testing of mean_shift should be independent of ms.fit. At the moment, ms.fit calls mean_shift, but we do not know how the code base will change.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@thomasjpfan do we need another test for mean_shift?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Leaving the original test here will sufficiently test mean_shift.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@thomasjpfan added test for mean_shift as well

labels_unique = np.unique(labels)
n_clusters_ = len(labels_unique)
assert_equal(n_clusters_, n_clusters)
cluster_centers, labels_mean_shift = mean_shift(X, cluster_all=cluster_all)
labels_mean_shift_unique = np.unique(labels_mean_shift)
n_clusters_mean_shift = len(labels_mean_shift_unique)
assert n_clusters_mean_shift == expected
assert labels_mean_shift_unique[0] == first_cluster_label


def test_mean_shift_negative_bandwidth():
bandwidth = -1
ms = MeanShift(bandwidth=bandwidth)
msg = (r"bandwidth needs to be greater than zero or None,"
r" got -1\.000000")
with pytest.raises(ValueError, match=msg):
ms.fit(X)


def test_estimate_bandwidth_with_sparse_matrix():
Expand Down