[MRG+1] break the tie in Meanshift in case cluster intensities are the same #11901

Merged: 7 commits, Sep 5, 2018
6 changes: 6 additions & 0 deletions doc/whats_new/v0.20.rst
@@ -61,6 +61,7 @@ parameters, may produce different models from the previous version. This often
 occurs due to changes in the modelling logic (bug fixes or enhancements), or in
 random sampling procedures.
 
+- :class:`cluster.MeanShift` (bug fix)
 - :class:`decomposition.IncrementalPCA` in Python 2 (bug fix)
 - :class:`decomposition.SparsePCA` (bug fix)
 - :class:`ensemble.GradientBoostingClassifier` (bug fix affecting feature importances)
@@ -140,6 +141,11 @@ Support for Python 3.3 has been officially dropped.
   `n_iter_` attribute in the docstring of :class:`cluster.KMeans`.
   :issue:`11353` by :user:`Jeremie du Boisberranger <jeremiedbb>`.
 
+- |Fix| Fixed a bug in :func:`cluster.mean_shift` where the assigned labels
+  were not deterministic if there were multiple clusters with the same
+  intensities.
+  :issue:`11901` by :user:`Adrin Jalali <adrinjalali>`.
+
 - |API| Deprecate ``pooling_func`` unused parameter in
   :class:`cluster.AgglomerativeClustering`.
   :issue:`9875` by :user:`Kumar Ashutosh <thechargedneutron>`.
8 changes: 5 additions & 3 deletions sklearn/cluster/mean_shift_.py
@@ -215,8 +215,10 @@ def mean_shift(X, bandwidth=None, seeds=None, bin_seeding=False,
     # If the distance between two kernels is less than the bandwidth,
     # then we have to remove one because it is a duplicate. Remove the
     # one with fewer points.
+
     sorted_by_intensity = sorted(center_intensity_dict.items(),
-                                 key=lambda tup: tup[1], reverse=True)
+                                 key=lambda tup: (tup[1], tup[0]),
+                                 reverse=True)
     sorted_centers = np.array([tup[0] for tup in sorted_by_intensity])
     unique = np.ones(len(sorted_centers), dtype=np.bool)
     nbrs = NearestNeighbors(radius=bandwidth,
@@ -359,9 +361,9 @@ class MeanShift(BaseEstimator, ClusterMixin):
     ...               [4, 7], [3, 5], [3, 6]])
     >>> clustering = MeanShift(bandwidth=2).fit(X)
     >>> clustering.labels_
-    array([0, 0, 0, 1, 1, 1])
+    array([1, 1, 1, 0, 0, 0])
     >>> clustering.predict([[0, 0], [5, 5]])
-    array([0, 1])
+    array([1, 0])
     >>> clustering # doctest: +NORMALIZE_WHITESPACE
     MeanShift(bandwidth=2, bin_seeding=False, cluster_all=True, min_bin_freq=1,
               n_jobs=None, seeds=None)
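The effect of the new sort key in `mean_shift` can be illustrated in isolation. The sketch below is not the library code: `rank_centers` is a hypothetical helper, and the two centers are made-up values chosen to sit roughly at the means of the two groups in the docstring example. It assumes Python 3.7+, where dicts preserve insertion order, so with the old key a tie keeps whatever order the centers were inserted in, while the tuple key falls back to the center coordinates.

```python
def rank_centers(center_intensity, tie_break=True):
    """Order centers by descending intensity (hypothetical helper).

    With tie_break=True the key mirrors the changed line in the diff:
    ties in intensity are resolved by the center coordinates.
    """
    if tie_break:
        key = lambda tup: (tup[1], tup[0])
    else:
        key = lambda tup: tup[1]
    ranked = sorted(center_intensity.items(), key=key, reverse=True)
    return [center for center, _ in ranked]


# Same two centers with equal intensity, inserted in opposite orders
# (illustrative values, not taken from an actual MeanShift run).
first = {(1.33, 0.67): 3, (3.33, 6.0): 3}
second = {(3.33, 6.0): 3, (1.33, 0.67): 3}

# Old key: the sort is stable, so the result follows insertion order.
print(rank_centers(first, tie_break=False))
print(rank_centers(second, tie_break=False))

# New key: both insertion orders give the same ranking.
print(rank_centers(first))
print(rank_centers(second))
```

Because `reverse=True` applies to the whole tuple, the center with the larger coordinates ranks first on a tie, which is consistent with the docstring labels changing from `[0, 0, 0, 1, 1, 1]` to `[1, 1, 1, 0, 0, 0]`.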
12 changes: 12 additions & 0 deletions sklearn/cluster/tests/test_mean_shift.py
@@ -101,6 +101,18 @@ def test_unfitted():
     assert_false(hasattr(ms, "labels_"))
 
 
+def test_cluster_intensity_tie():
+    X = np.array([[1, 1], [2, 1], [1, 0],
+                  [4, 7], [3, 5], [3, 6]])
+    c1 = MeanShift(bandwidth=2).fit(X)
+
+    X = np.array([[4, 7], [3, 5], [3, 6],
+                  [1, 1], [2, 1], [1, 0]])
+    c2 = MeanShift(bandwidth=2).fit(X)
+    assert_array_equal(c1.labels_, [1, 1, 1, 0, 0, 0])
+    assert_array_equal(c2.labels_, [0, 0, 0, 1, 1, 1])
+
+
 def test_bin_seeds():
     # Test the bin seeding technique which can be used in the mean shift
     # algorithm