
ENH KMeans initialization account for sample weights #25752


Merged 97 commits on Mar 17, 2023
Commits (97)
2f7d59f  added support for sample_weights in init (glevv, Mar 3, 2023)
1d11434  added test (glevv, Mar 3, 2023)
3985354  lint (glevv, Mar 3, 2023)
52be2c0  added entry to changelog (glevv, Mar 3, 2023)
95e3bc1  lint (glevv, Mar 3, 2023)
58c856b  Update _kmeans.py (glevv, Mar 3, 2023)
237b6c1  fix bisect kmeans (glevv, Mar 3, 2023)
6fbe46a  changed tests (glevv, Mar 3, 2023)
ee3635f  add test for bisect (glevv, Mar 3, 2023)
4c54226  hotfix (glevv, Mar 3, 2023)
22baa0a  additions to changelog (glevv, Mar 3, 2023)
4d88c76  Update test_bisect_k_means.py (glevv, Mar 3, 2023)
9e25744  Update test_k_means.py (glevv, Mar 3, 2023)
c942ed7  init_size bugfix (glevv, Mar 3, 2023)
57b43ed  Update test_bisect_k_means.py (glevv, Mar 3, 2023)
1e98579  Update test_bisect_k_means.py (glevv, Mar 3, 2023)
c448b04  Merge branch 'main' into init_samp_w (glevv, Mar 3, 2023)
4637b82  more robust test (jeremiedbb, Mar 3, 2023)
3fa6367  fix docs (glevv, Mar 5, 2023)
ae882bb  hotfix (glevv, Mar 5, 2023)
26216d9  mention backward incompatibility (glevv, Mar 5, 2023)
1513863  Merge branch 'main' into init_samp_w (glevv, Mar 5, 2023)
e857558  Update _bisect_k_means.py (glevv, Mar 5, 2023)
5955aed  doc fix (glevv, Mar 5, 2023)
ca91c1f  Merge branch 'main' into init_samp_w (glevv, Mar 6, 2023)
8b08c44  Merge branch 'scikit-learn:main' into init_samp_w (glevv, Mar 7, 2023)
829c970  Merge branch 'main' into init_samp_w (glemaitre, Mar 8, 2023)
66c3aa3  change docstirng (glevv, Mar 8, 2023)
27cc2a0  use global seed in test (glevv, Mar 8, 2023)
61147fd  use global seed in bisect kmeans test (glevv, Mar 8, 2023)
15bd8f1  rename vars (glevv, Mar 8, 2023)
ce35203  rename vars (glevv, Mar 8, 2023)
9ae547e  lint (glevv, Mar 8, 2023)
ba80ef8  updated docs in fit method (glevv, Mar 8, 2023)
021a531  updated docs of fit and partial_fit methods (glevv, Mar 8, 2023)
af1783f  Merge branch 'main' into init_samp_w (glevv, Mar 9, 2023)
136a92d  added jeremiedbb kmeans_plus_plus modification (glevv, Mar 9, 2023)
c15e107  changed tests (glevv, Mar 9, 2023)
2c45c04  changed docs (glevv, Mar 9, 2023)
75f8048  changed tests (glevv, Mar 9, 2023)
62e85bc  Update test_k_means.py (glevv, Mar 9, 2023)
b76cfcc  updated changelog (glevv, Mar 9, 2023)
d49f1ba  lint (glevv, Mar 9, 2023)
2752135  lint (glevv, Mar 9, 2023)
15dddb9  Update _kmeans.py (glevv, Mar 9, 2023)
0d7a725  Update _kmeans.py (glevv, Mar 9, 2023)
eb8da0b  Update sklearn/cluster/_kmeans.py (glevv, Mar 11, 2023)
6c20527  Update sklearn/cluster/_kmeans.py (glevv, Mar 11, 2023)
c640cdc  Update _kmeans.py (glevv, Mar 11, 2023)
c36a605  Merge branch 'main' into init_samp_w (glevv, Mar 11, 2023)
b49cdf7  lint (glevv, Mar 11, 2023)
df99506  doc fix (glevv, Mar 12, 2023)
30d3b1a  doc fix (glevv, Mar 12, 2023)
c6d5fd8  doc fix (glevv, Mar 12, 2023)
fb6f9c8  Update test_spectral_embedding.py (glevv, Mar 12, 2023)
cf73e0e  revert (glevv, Mar 12, 2023)
165c0ad  doc update (glevv, Mar 14, 2023)
0647e36  update kmeans tests (glevv, Mar 14, 2023)
8e97cd0  fix spectral tests (glevv, Mar 14, 2023)
9650a45  Merge branch 'main' into init_samp_w (glevv, Mar 14, 2023)
31849bc  lint (glevv, Mar 14, 2023)
24c68f9  Update test_k_means.py (glevv, Mar 14, 2023)
1884711  Update test_k_means.py (glevv, Mar 14, 2023)
26054f3  lint (glevv, Mar 14, 2023)
d628630  added sample_weight to docstring (glevv, Mar 14, 2023)
b192bf1  updated tests (glevv, Mar 14, 2023)
996d3b0  typo (glevv, Mar 14, 2023)
79efebc  Update sklearn/cluster/_kmeans.py (glevv, Mar 15, 2023)
4ffbb91  changed docs (glevv, Mar 15, 2023)
aaa1acc  Update v1.3.rst (glevv, Mar 15, 2023)
89f80ae  Update test_k_means.py (glevv, Mar 15, 2023)
ea764c7  Update test_bisect_k_means.py (glevv, Mar 15, 2023)
a5e3489  Update test_k_means.py (glevv, Mar 15, 2023)
675cd17  Update test_bisect_k_means.py (glevv, Mar 15, 2023)
ebc67dc  Update test_bisect_k_means.py (glevv, Mar 15, 2023)
511050e  Update test_bisect_k_means.py (glevv, Mar 15, 2023)
d3c08e5  Update test_k_means.py (glevv, Mar 15, 2023)
ccf3604  Merge branch 'main' into init_samp_w (glevv, Mar 15, 2023)
061c6a0  Update test_k_means.py (glevv, Mar 15, 2023)
c2e3743  Update test_k_means.py (glevv, Mar 15, 2023)
c06a8b3  docs fix (glevv, Mar 15, 2023)
16aae7c  stability fix (glevv, Mar 15, 2023)
c8886d7  Update test_bisect_k_means.py (glevv, Mar 15, 2023)
8819dd5  Update test_bisect_k_means.py (glevv, Mar 15, 2023)
7c2dd2b  Update _kmeans.py (glevv, Mar 16, 2023)
0c8c021  Update v1.3.rst (glevv, Mar 16, 2023)
4d09437  Update test_k_means.py (glevv, Mar 16, 2023)
ae1f561  Update v1.3.rst (glevv, Mar 16, 2023)
2861d35  Merge branch 'scikit-learn:main' into init_samp_w (glevv, Mar 16, 2023)
8061460  revert unrelated (jeremiedbb, Mar 16, 2023)
a9b71e7  Update _kmeans.py (glevv, Mar 16, 2023)
3223902  Apply suggestions from code review (jeremiedbb, Mar 16, 2023)
af94b06  Apply suggestions from code review (jeremiedbb, Mar 16, 2023)
fc758da  Update test_bisect_k_means.py (glevv, Mar 16, 2023)
f1651d0  Update test_k_means.py (glevv, Mar 16, 2023)
924a357  Update sklearn/cluster/tests/test_k_means.py (jeremiedbb, Mar 16, 2023)
537a83a  Update sklearn/cluster/tests/test_k_means.py (jeremiedbb, Mar 16, 2023)
20 changes: 19 additions & 1 deletion doc/whats_new/v1.3.rst
@@ -34,6 +34,15 @@ random sampling procedures.
:class:`decomposition.MiniBatchNMF` which can produce different results than previous
versions. :pr:`25438` by :user:`Yotam Avidar-Constantini <yotamcons>`.

- |Enhancement| The `sample_weight` parameter is now used during centroid
initialization for :class:`cluster.KMeans`, :class:`cluster.BisectingKMeans`
and :class:`cluster.MiniBatchKMeans`.
This change breaks backward compatibility, since numbers generated
from the same random seed will be different.
:pr:`25752` by :user:`Gleb Levitski <glevv>`,
:user:`Jérémie du Boisberranger <jeremiedbb>`,
:user:`Guillaume Lemaitre <glemaitre>`.

Changes impacting all modules
-----------------------------

@@ -154,9 +163,18 @@ Changelog

- |API| The `sample_weight` parameter in `predict` for
:meth:`cluster.KMeans.predict` and :meth:`cluster.MiniBatchKMeans.predict`
is now deprecated and will be removed in v1.5.
is now deprecated and will be removed in v1.5.
:pr:`25251` by :user:`Gleb Levitski <glevv>`.

- |Enhancement| The `sample_weight` parameter is now used during centroid
initialization for :class:`cluster.KMeans`, :class:`cluster.BisectingKMeans`
and :class:`cluster.MiniBatchKMeans`.
This change breaks backward compatibility, since numbers generated
from the same random seed will be different.
:pr:`25752` by :user:`Gleb Levitski <glevv>`,
:user:`Jérémie du Boisberranger <jeremiedbb>`,
:user:`Guillaume Lemaitre <glemaitre>`.

:mod:`sklearn.datasets`
.......................

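Note (illustration, not part of the diff): the change described in the changelog entries above means that, for a fixed random_state, a weighted fit no longer starts from the same initial centers as an unweighted one. A minimal sketch with made-up data and weights (outputs omitted, since they depend on the seed and data):

import numpy as np
from sklearn.cluster import KMeans

X = np.array([[1, 2], [1, 4], [1, 0], [10, 2], [10, 4], [10, 0]], dtype=float)
# Heavily weight the last three points; initialization now samples
# candidate centers proportionally to these weights.
weights = np.array([1, 1, 1, 10, 10, 10], dtype=float)

km_weighted = KMeans(n_clusters=2, random_state=0, n_init=1).fit(X, sample_weight=weights)
km_uniform = KMeans(n_clusters=2, random_state=0, n_init=1).fit(X)

# Before this PR both runs drew identical initial centers for a given seed;
# now the weighted run consumes the random stream differently.
print(km_weighted.cluster_centers_)
print(km_uniform.cluster_centers_)
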
2 changes: 1 addition & 1 deletion sklearn/cluster/_bicluster.py
@@ -487,7 +487,7 @@ class SpectralBiclustering(BaseSpectral):
>>> clustering.row_labels_
array([1, 1, 1, 0, 0, 0], dtype=int32)
>>> clustering.column_labels_
array([0, 1], dtype=int32)
array([1, 0], dtype=int32)
>>> clustering
SpectralBiclustering(n_clusters=2, random_state=0)
"""
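
Note (not part of the diff): the doctest output above changes only because SpectralBiclustering runs k-means internally, and k-means initialization now consumes the random stream differently for the same seed; the recovered biclusters are the same up to a swap of the column labels 0 and 1.
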
26 changes: 16 additions & 10 deletions sklearn/cluster/_bisect_k_means.py
@@ -190,18 +190,18 @@ class BisectingKMeans(_BaseKMeans):
--------
>>> from sklearn.cluster import BisectingKMeans
>>> import numpy as np
>>> X = np.array([[1, 2], [1, 4], [1, 0],
... [10, 2], [10, 4], [10, 0],
... [10, 6], [10, 8], [10, 10]])
>>> X = np.array([[1, 1], [10, 1], [3, 1],
... [10, 0], [2, 1], [10, 2],
... [10, 8], [10, 9], [10, 10]])
>>> bisect_means = BisectingKMeans(n_clusters=3, random_state=0).fit(X)
>>> bisect_means.labels_
array([2, 2, 2, 0, 0, 0, 1, 1, 1], dtype=int32)
array([0, 2, 0, 2, 0, 2, 1, 1, 1], dtype=int32)
>>> bisect_means.predict([[0, 0], [12, 3]])
array([2, 0], dtype=int32)
array([0, 2], dtype=int32)
>>> bisect_means.cluster_centers_
array([[10., 2.],
[10., 8.],
[ 1., 2.]])
array([[ 2., 1.],
[10., 9.],
[10., 1.]])
"""

_parameter_constraints: dict = {
@@ -309,7 +309,12 @@ def _bisect(self, X, x_squared_norms, sample_weight, cluster_to_bisect):
# Repeating `n_init` times to obtain best clusters
for _ in range(self.n_init):
centers_init = self._init_centroids(
X, x_squared_norms, self.init, self._random_state, n_centroids=2
X,
x_squared_norms=x_squared_norms,
init=self.init,
random_state=self._random_state,
n_centroids=2,
sample_weight=sample_weight,
)

labels, inertia, centers, _ = self._kmeans_single(
@@ -361,7 +366,8 @@ def fit(self, X, y=None, sample_weight=None):

sample_weight : array-like of shape (n_samples,), default=None
The weights for each observation in X. If None, all observations
are assigned equal weight.
are assigned equal weight. `sample_weight` is not used during
initialization if `init` is a callable.

Returns
-------
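
Illustration (not part of the diff): each bisection step now forwards sample_weight to _init_centroids, so a weighted fit can seed each split differently. A minimal sketch with made-up weights:

import numpy as np
from sklearn.cluster import BisectingKMeans

X = np.array([[1, 1], [10, 1], [3, 1], [10, 0], [2, 1], [10, 2]], dtype=float)
weights = np.array([5, 1, 5, 1, 5, 1], dtype=float)

# sample_weight now influences which points seed each bisection.
bkm = BisectingKMeans(n_clusters=2, random_state=0).fit(X, sample_weight=weights)
print(bkm.cluster_centers_)
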
95 changes: 73 additions & 22 deletions sklearn/cluster/_kmeans.py
@@ -63,13 +63,20 @@
{
"X": ["array-like", "sparse matrix"],
"n_clusters": [Interval(Integral, 1, None, closed="left")],
"sample_weight": ["array-like", None],
"x_squared_norms": ["array-like", None],
"random_state": ["random_state"],
"n_local_trials": [Interval(Integral, 1, None, closed="left"), None],
}
)
def kmeans_plusplus(
X, n_clusters, *, x_squared_norms=None, random_state=None, n_local_trials=None
X,
n_clusters,
*,
sample_weight=None,
x_squared_norms=None,
random_state=None,
n_local_trials=None,
):
"""Init n_clusters seeds according to k-means++.

@@ -83,6 +90,13 @@ def kmeans_plusplus(
n_clusters : int
The number of centroids to initialize.

sample_weight : array-like of shape (n_samples,), default=None
The weights for each observation in `X`. If `None`, all observations
are assigned equal weight.

.. versionadded:: 1.3

x_squared_norms : array-like of shape (n_samples,), default=None
Squared Euclidean norm of each data point.

@@ -125,13 +139,14 @@
... [10, 2], [10, 4], [10, 0]])
>>> centers, indices = kmeans_plusplus(X, n_clusters=2, random_state=0)
>>> centers
array([[10, 4],
array([[10, 2],
[ 1, 0]])
>>> indices
array([4, 2])
array([3, 2])
"""
# Check data
check_array(X, accept_sparse="csr", dtype=[np.float64, np.float32])
sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)

if X.shape[0] < n_clusters:
raise ValueError(
@@ -154,13 +169,15 @@

# Call private k-means++
centers, indices = _kmeans_plusplus(
X, n_clusters, x_squared_norms, random_state, n_local_trials
X, n_clusters, x_squared_norms, sample_weight, random_state, n_local_trials
)

return centers, indices
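
Usage sketch based on the new signature above (not part of the diff; the returned centers depend on the seed and weights):

import numpy as np
from sklearn.cluster import kmeans_plusplus

X = np.array([[1, 2], [1, 4], [1, 0], [10, 2], [10, 4], [10, 0]], dtype=float)
weights = np.array([1, 1, 1, 3, 3, 3], dtype=float)

# The first center is drawn with probability proportional to sample_weight;
# later centers weight the squared-distance potential by sample_weight.
centers, indices = kmeans_plusplus(X, n_clusters=2, sample_weight=weights, random_state=0)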


def _kmeans_plusplus(X, n_clusters, x_squared_norms, random_state, n_local_trials=None):
def _kmeans_plusplus(
X, n_clusters, x_squared_norms, sample_weight, random_state, n_local_trials=None
):
"""Computational component for initialization of n_clusters by
k-means++. Prior validation of data is assumed.

@@ -172,6 +189,9 @@ def _kmeans_plusplus(X, n_clusters, x_squared_norms, random_state, n_local_trial
n_clusters : int
The number of seeds to choose.

sample_weight : ndarray of shape (n_samples,)
The weights for each observation in `X`.

x_squared_norms : ndarray of shape (n_samples,)
Squared Euclidean norm of each data point.

@@ -206,7 +226,7 @@ def _kmeans_plusplus(X, n_clusters, x_squared_norms, random_state, n_local_trial
n_local_trials = 2 + int(np.log(n_clusters))

# Pick first center randomly and track index of point
center_id = random_state.randint(n_samples)
center_id = random_state.choice(n_samples, p=sample_weight / sample_weight.sum())
indices = np.full(n_clusters, -1, dtype=int)
if sp.issparse(X):
centers[0] = X[center_id].toarray()
@@ -218,14 +238,16 @@ def _kmeans_plusplus(X, n_clusters, x_squared_norms, random_state, n_local_trial
closest_dist_sq = _euclidean_distances(
centers[0, np.newaxis], X, Y_norm_squared=x_squared_norms, squared=True
)
current_pot = closest_dist_sq.sum()
current_pot = closest_dist_sq @ sample_weight

# Pick the remaining n_clusters-1 points
for c in range(1, n_clusters):
# Choose center candidates by sampling with probability proportional
# to the squared distance to the closest existing center
rand_vals = random_state.uniform(size=n_local_trials) * current_pot
candidate_ids = np.searchsorted(stable_cumsum(closest_dist_sq), rand_vals)
candidate_ids = np.searchsorted(
stable_cumsum(sample_weight * closest_dist_sq), rand_vals
)
# XXX: numerical imprecision can result in a candidate_id out of range
np.clip(candidate_ids, None, closest_dist_sq.size - 1, out=candidate_ids)

@@ -236,7 +258,7 @@ def _kmeans_plusplus(X, n_clusters, x_squared_norms, random_state, n_local_trial

# update closest distances squared and potential for each candidate
np.minimum(closest_dist_sq, distance_to_candidates, out=distance_to_candidates)
candidates_pot = distance_to_candidates.sum(axis=1)
candidates_pot = distance_to_candidates @ sample_weight.reshape(-1, 1)

# Decide which candidate is the best
best_candidate = np.argmin(candidates_pot)
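
To make the weighting concrete, here is a self-contained NumPy sketch of the same seeding scheme (not the library code: it is simplified to dense data with no n_local_trials batching, and the helper name weighted_kmeanspp is made up):

import numpy as np

def weighted_kmeanspp(X, n_clusters, sample_weight, rng):
    # First center: drawn with probability proportional to sample_weight.
    p = sample_weight / sample_weight.sum()
    centers = [X[rng.choice(len(X), p=p)]]
    for _ in range(1, n_clusters):
        # Squared distance from each point to its closest chosen center.
        d2 = np.min([((X - c) ** 2).sum(axis=1) for c in centers], axis=0)
        # Later centers: probability proportional to sample_weight * d2,
        # mirroring current_pot = closest_dist_sq @ sample_weight above.
        p = sample_weight * d2
        centers.append(X[rng.choice(len(X), p=p / p.sum())])
    return np.array(centers)

rng = np.random.RandomState(0)
X = np.random.RandomState(1).rand(20, 2)
w = np.ones(20)
w[:5] = 10.0  # over-weight the first five points
print(weighted_kmeanspp(X, 3, w, rng))
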
@@ -323,7 +345,8 @@ def k_means(

sample_weight : array-like of shape (n_samples,), default=None
The weights for each observation in `X`. If `None`, all observations
are assigned equal weight.
are assigned equal weight. `sample_weight` is not used during
initialization if `init` is a callable or a user provided array.

init : {'k-means++', 'random'}, callable or array-like of shape \
(n_clusters, n_features), default='k-means++'
@@ -939,7 +962,14 @@ def _check_test_data(self, X):
return X

def _init_centroids(
self, X, x_squared_norms, init, random_state, init_size=None, n_centroids=None
self,
X,
x_squared_norms,
init,
random_state,
init_size=None,
n_centroids=None,
sample_weight=None,
):
"""Compute the initial centroids.

@@ -969,6 +999,11 @@ def _init_centroids(
If left to 'None' the number of centroids will be equal to
number of clusters to form (self.n_clusters)

sample_weight : ndarray of shape (n_samples,), default=None
The weights for each observation in X. If None, all observations
are assigned equal weight. `sample_weight` is not used during
initialization if `init` is a callable or a user provided array.

Returns
-------
centers : ndarray of shape (n_clusters, n_features)
@@ -981,16 +1016,23 @@ def _init_centroids(
X = X[init_indices]
x_squared_norms = x_squared_norms[init_indices]
n_samples = X.shape[0]
sample_weight = sample_weight[init_indices]

if isinstance(init, str) and init == "k-means++":
centers, _ = _kmeans_plusplus(
X,
n_clusters,
random_state=random_state,
x_squared_norms=x_squared_norms,
sample_weight=sample_weight,
)
elif isinstance(init, str) and init == "random":
seeds = random_state.permutation(n_samples)[:n_clusters]
seeds = random_state.choice(
n_samples,
size=n_clusters,
replace=False,
p=sample_weight / sample_weight.sum(),
)
centers = X[seeds]
elif _is_arraylike_not_scalar(self.init):
centers = init
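
The 'random' init change above is plain weighted sampling without replacement; a standalone equivalent (illustration only):

import numpy as np

rng = np.random.RandomState(0)
X = np.arange(10.0).reshape(5, 2)
sample_weight = np.array([1.0, 1.0, 1.0, 5.0, 5.0])
# Heavier points are more likely to be drawn as initial centers,
# and no point is drawn twice.
seeds = rng.choice(5, size=3, replace=False, p=sample_weight / sample_weight.sum())
centers = X[seeds]
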
@@ -1412,7 +1454,8 @@ def fit(self, X, y=None, sample_weight=None):

sample_weight : array-like of shape (n_samples,), default=None
The weights for each observation in X. If None, all observations
are assigned equal weight.
are assigned equal weight. `sample_weight` is not used during
initialization if `init` is a callable or a user provided array.

.. versionadded:: 0.20

@@ -1468,7 +1511,11 @@ def fit(self, X, y=None, sample_weight=None):
for i in range(self._n_init):
# Initialize centers
centers_init = self._init_centroids(
X, x_squared_norms=x_squared_norms, init=init, random_state=random_state
X,
x_squared_norms=x_squared_norms,
init=init,
random_state=random_state,
sample_weight=sample_weight,
)
if self.verbose:
print("Initialization complete")
@@ -1545,7 +1592,7 @@ def _mini_batch_step(
Squared euclidean norm of each data point.

sample_weight : ndarray of shape (n_samples,)
The weights for each observation in X.
The weights for each observation in `X`.

centers : ndarray of shape (n_clusters, n_features)
The cluster centers before the current iteration
@@ -1818,19 +1865,19 @@ class MiniBatchKMeans(_BaseKMeans):
>>> kmeans = kmeans.partial_fit(X[0:6,:])
>>> kmeans = kmeans.partial_fit(X[6:12,:])
>>> kmeans.cluster_centers_
array([[2. , 1. ],
[3.5, 4.5]])
array([[3.375, 3. ],
[0.75 , 0.5 ]])
>>> kmeans.predict([[0, 0], [4, 4]])
array([0, 1], dtype=int32)
array([1, 0], dtype=int32)
>>> # fit on the whole data
>>> kmeans = MiniBatchKMeans(n_clusters=2,
... random_state=0,
... batch_size=6,
... max_iter=10,
... n_init="auto").fit(X)
>>> kmeans.cluster_centers_
array([[3.97727273, 2.43181818],
[1.125 , 1.6 ]])
array([[3.55102041, 2.48979592],
[1.06896552, 1. ]])
>>> kmeans.predict([[0, 0], [4, 4]])
array([1, 0], dtype=int32)
"""
@@ -2015,7 +2062,8 @@ def fit(self, X, y=None, sample_weight=None):

sample_weight : array-like of shape (n_samples,), default=None
The weights for each observation in X. If None, all observations
are assigned equal weight.
are assigned equal weight. `sample_weight` is not used during
initialization if `init` is a callable or a user provided array.

.. versionadded:: 0.20

@@ -2070,6 +2118,7 @@ def fit(self, X, y=None, sample_weight=None):
init=init,
random_state=random_state,
init_size=self._init_size,
sample_weight=sample_weight,
)

# Compute inertia on a validation set.
@@ -2170,7 +2219,8 @@ def partial_fit(self, X, y=None, sample_weight=None):

sample_weight : array-like of shape (n_samples,), default=None
The weights for each observation in X. If None, all observations
are assigned equal weight.
are assigned equal weight. `sample_weight` is not used during
initialization if `init` is a callable or a user provided array.

Returns
-------
@@ -2220,6 +2270,7 @@ def partial_fit(self, X, y=None, sample_weight=None):
init=init,
random_state=self._random_state,
init_size=self._init_size,
sample_weight=sample_weight,
)

# Initialize counts
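
Illustration (not part of the diff): the first partial_fit call performs the weighted initialization shown above; subsequent calls only update the running centers and counts. A minimal sketch:

import numpy as np
from sklearn.cluster import MiniBatchKMeans

X = np.random.RandomState(0).rand(12, 2)
w = np.linspace(0.5, 2.0, 12)

mbk = MiniBatchKMeans(n_clusters=3, random_state=0, n_init="auto")
# Centroids are initialized from the first batch, using its sample_weight.
mbk = mbk.partial_fit(X[:6], sample_weight=w[:6])
mbk = mbk.partial_fit(X[6:], sample_weight=w[6:])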