Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions doc/whats_new/v1.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,13 @@ Fixed models
the Bayesian priors.
:pr:`21179` by :user:`Guillaume Lemaitre <glemaitre>`.

:mod:`sklearn.cluster`
......................

- |Fix| Fixed a bug in :class:`cluster.KMeans`, ensuring reproducibility and equivalence
between sparse and dense input. :pr:`21195`
by :user:`Jérémie du Boisberranger <jeremiedbb>`.

:mod:`sklearn.neighbors`
........................

Expand Down
13 changes: 13 additions & 0 deletions sklearn/cluster/_k_means_common.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -287,3 +287,16 @@ cdef void _center_shift(
for j in range(n_clusters):
center_shift[j] = _euclidean_dense_dense(
&centers_new[j, 0], &centers_old[j, 0], n_features, False)


def _is_same_clustering(int[::1] labels1, int[::1] labels2, n_clusters):
"""Check if two arrays of labels are the same up to a permutation of the labels"""
cdef int[::1] mapping = np.full(fill_value=-1, shape=(n_clusters,), dtype=np.int32)
cdef int i

for i in range(labels1.shape[0]):
if mapping[labels1[i]] == -1:
mapping[labels1[i]] = labels2[i]
elif mapping[labels1[i]] != labels2[i]:
return False
return True
14 changes: 10 additions & 4 deletions sklearn/cluster/_kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from ._k_means_common import CHUNK_SIZE
from ._k_means_common import _inertia_dense
from ._k_means_common import _inertia_sparse
from ._k_means_common import _is_same_clustering
from ._k_means_minibatch import _minibatch_update_dense
from ._k_means_minibatch import _minibatch_update_sparse
from ._k_means_lloyd import lloyd_iter_chunked_dense
Expand Down Expand Up @@ -1174,7 +1175,7 @@ def fit(self, X, y=None, sample_weight=None):
else:
kmeans_single = _kmeans_single_elkan

best_inertia = None
best_inertia, best_labels = None, None

for i in range(self._n_init):
# Initialize centers
Expand All @@ -1197,9 +1198,14 @@ def fit(self, X, y=None, sample_weight=None):
)

# determine if these results are the best so far
# allow small tolerance on the inertia to accommodate for
# non-deterministic rounding errors due to parallel computation
if best_inertia is None or inertia < best_inertia * (1 - 1e-6):
# we chose a new run if it has a better inertia and the clustering is
# different from the best so far (it's possible that the inertia is
# slightly better even if the clustering is the same with potentially
# permuted labels, due to rounding errors)
if best_inertia is None or (
inertia < best_inertia
and not _is_same_clustering(labels, best_labels, self.n_clusters)
):
best_labels = labels
best_centers = centers
best_inertia = inertia
Expand Down
17 changes: 17 additions & 0 deletions sklearn/cluster/tests/test_k_means.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from sklearn.cluster._k_means_common import _euclidean_sparse_dense_wrapper
from sklearn.cluster._k_means_common import _inertia_dense
from sklearn.cluster._k_means_common import _inertia_sparse
from sklearn.cluster._k_means_common import _is_same_clustering
from sklearn.datasets import make_blobs
from io import StringIO

Expand Down Expand Up @@ -1173,3 +1174,19 @@ def test_kmeans_plusplus_dataorder():
centers_fortran, _ = kmeans_plusplus(X_fortran, n_clusters, random_state=0)

assert_allclose(centers_c, centers_fortran)


def test_is_same_clustering():
# Sanity check for the _is_same_clustering utility function
labels1 = np.array([1, 0, 0, 1, 2, 0, 2, 1], dtype=np.int32)
assert _is_same_clustering(labels1, labels1, 3)

# these other labels represent the same clustering since we can retrive the first
# labels by simply renaming the labels: 0 -> 1, 1 -> 2, 2 -> 0.
labels2 = np.array([0, 2, 2, 0, 1, 2, 1, 0], dtype=np.int32)
assert _is_same_clustering(labels1, labels2, 3)

# these other labels do not represent the same clustering since not all ones are
# mapped to a same value
labels3 = np.array([1, 0, 0, 2, 2, 0, 2, 1], dtype=np.int32)
assert not _is_same_clustering(labels1, labels3, 3)