adding adaptive learning rate for minibatch k-means #30051
Conversation
The diff is indeed relatively small; however, the paper is quite recent and the improvements are marginal. So I'll let @ogrisel, @lorentzenchr, @jeremiedbb, and @GaelVaroquaux weigh in here.
I have similar feelings. Unfortunately, arxiv.org has seemed unresponsive for me since yesterday, so I cannot check the benchmark results from the paper. @BenJourdan could you please add results for full-batch k-means to your plots? I am wondering whether this allows MB-k-means to reach the same scores as full-batch k-means on those problems.
force-pushed from 35c5b19 to 158897a
Thanks for the update. So from those experiments, it appears that the new lr scheme can empirically help MBKMeans close the (smallish) gap with full-batch KMeans in terms of clustering quality while keeping favorable runtimes for datasets with many data points (e.g. MNIST size or larger). But since the method was recently published, this PR does not technically meet our inclusion criteria, although we could be less strict in cases where this is an incremental improvement of an existing method implemented in scikit-learn. I will mention this PR at our next monthly meeting.
What was the verdict @ogrisel?
There were no clear objections to include this, and I think a few of us are in favor of including it.
@ogrisel @adrinjalali what happens next? Should I start updating the branch?
@BenJourdan seems like it.
force-pushed from 65e49cb to 320a4b4
Should I keep updating the branch until (if lol) someone gets assigned? Not sure what the convention is.
Hi @ogrisel, @adrinjalali, just checking in: since there were no objections during the meeting and there was some support for inclusion, would it make sense to remove the Needs Decision label and move toward review/approval? Please let me know what you'd recommend as the next step. Thanks again for your time and feedback so far!
I've removed "need decision" here. I think we can move forward with this.
Great! Let us know if we should do anything on our end.
@antoinebaker would you mind having a look at this PR for a review?
Thanks for the PR @BenJourdan! Here is a first round of review.
sklearn/cluster/_kmeans.py
Outdated
adaptive_lr : bool, default=False
    If True, use the adaptive learning rate described in this \
    `paper <https://arxiv.org/abs/2304.00419>`_.
    This can be more effective than the standard learning rate \
    when the input is dense.
Suggested change:
- adaptive_lr : bool, default=False
-     If True, use the adaptive learning rate described in this \
-     `paper <https://arxiv.org/abs/2304.00419>`_.
-     This can be more effective than the standard learning rate \
-     when the input is dense.
+ adaptive_lr : bool, default=False
+     If True, use the adaptive learning rate described in [1]_ which can be
+     more effective than the standard learning rate when the input is dense.
I would also suggest adding a References section below Notes:
References
----------
.. [1] :arxiv:`Ben Jourdan and Gregory Schwartzman (2024).
"Mini-Batch Kernel k-means." <2410.05902>`
Sure. I'll cite Gregory's original paper instead:
.. [1] :arxiv:`Gregory Schwartzman (2023).
"Mini-batch k-means terminates within O(d/ɛ) iterations" <2304.00419>`
n_threads : int
    The number of threads to be used by openmp.
"""
cdef:
    int n_samples = X.shape[0]
    int n_clusters = centers_old.shape[0]
    int cluster_idx

    floating b=0.0
Not sure about the naming, but something more explicit than b:

Suggested change:
- floating b=0.0
+ floating wsum_batch = 0.0
Sure. I'll also change wsum inside the update_center_* functions to wsum_cluster to indicate it's talking about the weight of the points (in the batch) that were assigned to cluster cluster_idx.
if adaptive_lr:
    """
    perform the minibatch update for the current cluster using
    C_new = C_old * (1 - alpha) + alpha * cm(B_j)

    where alpha = sqrt(b_j/b) is the learning rate from https://arxiv.org/abs/2304.00419,
    b is the weight of the batch, b_j is the weight of the batch w.r.t. the current cluster,
    and cm(B_j) is the center of mass of the batch w.r.t. the current cluster.
    """
    weight_sums[cluster_idx] += wsum
    alpha = sqrt(wsum / b)
I would introduce a learning rate, here called lr, to distinguish it from alpha = 1 / weight_sums[cluster_idx] used when adaptive_lr=False.
Suggested change:
- if adaptive_lr:
-     """
-     perform the minibatch update for the current cluster using
-     C_new = C_old * (1 - alpha) + alpha * cm(B_j)
-     where alpha = sqrt(b_j/b) is the learning rate from https://arxiv.org/abs/2304.00419,
-     b is the weight of the batch, b_j is the weight of the batch w.r.t. the current cluster,
-     and cm(B_j) is the center of mass of the batch w.r.t. the current cluster.
-     """
-     weight_sums[cluster_idx] += wsum
-     alpha = sqrt(wsum / b)
+ # Update center C_new = (1 - lr) * C_old + lr * C_X where
+ # C_X = weighted mean of the observations X assigned to the cluster
+ # lr = learning rate
+ if adaptive_lr:
+     # learning rate lr = sqrt(wsum/wsum_batch) suggested in
+     # https://arxiv.org/abs/2304.00419
I agree. I'll generalize the trick currently used by both methods to avoid explicitly computing the means of the points in the current batch. I'll add the following comment along with the changes:
# We want to compute the new center with the update formula
# C_{i+1} = C^{i}_j*(1-alpha) + alpha*CM(B_j^i).
# where:
# - C_j^i is the center representing the j-th cluster at the i-th
# iteration
# - B_j^i is the batch of samples assigned to the j-th cluster at
# the i-th iteration
# - CM(B_j^i) is the (weighted) mean of the samples assigned to
# cluster j in iteration i
# - alpha is the learning rate
# In the non-adaptive case, alpha = wsum_cluster/(wsum_cluster+old_weight)
# where:
# - wsum_cluster is the weight of the points assigned to the cluster in the
# current batch
# - old_weight is the weight of all points assigned to cluster j
# in previous iterations.
# This is equivalent to computing a weighted average of everything
# assigned to cluster j so far.
# In the adaptive case (see https://arxiv.org/abs/2304.00419),
# alpha = sqrt(wsum_cluster/wsum_batch) where wsum_batch is the weight of
# the batch. This is similar to an exponential moving average but with
# an adaptive decay rate.
# For the sake of efficiency, we don't compute the update explicitly.
# Instead, we skip computing the mean of the batch and
# compute the update by scaling the old center, adding the weighted
# sum of the batch, and then scaling again.
# Let Sigma(B_j^i) be the weighted sum of the points assigned to
# cluster j in the current batch.
# Therefore (Sigma(B_j^i) = wsum_cluster * CM(B_j^i)).
# We can rewrite the update formula as:
# C_{i+1} = C^{i}_j*(1-alpha) + (alpha/wsum_cluster)*Sigma(B_j^i)
# = (alpha/wsum_cluster)[C^{i}_j*(1-alpha)(wsum_cluster/alpha) + Sigma(B_j^i)]
# In the adaptive case, nothing simplifies so we just use the formula
# as is.
# In the non-adaptive case, things simplify and we have
# - (1-alpha)*(wsum_cluster/alpha)
# = (old_weight/(wsum_cluster+old_weight))*(wsum_cluster+old_weight) = old_weight
# - (alpha/wsum_cluster) = 1/(wsum_cluster+old_weight)
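For intuition, here is a minimal NumPy sketch (not the PR's Cython code; all variable names are illustrative) checking that the rescaling trick described in the comment reproduces the explicit convex-combination update for the adaptive learning rate:

import numpy as np

rng = np.random.default_rng(0)
center_old = rng.normal(size=5)             # C_j^i
batch = rng.normal(size=(8, 5))             # samples of the batch assigned to cluster j
weights = rng.uniform(0.5, 2.0, size=8)     # their sample weights

wsum_cluster = weights.sum()                # weight assigned to cluster j in this batch
wsum_batch = 3.0 * wsum_cluster             # pretend total weight of the batch (b in the diff)
alpha = np.sqrt(wsum_cluster / wsum_batch)  # adaptive learning rate

# Explicit update: C_j^{i+1} = (1 - alpha) * C_j^i + alpha * CM(B_j^i)
cm = np.average(batch, axis=0, weights=weights)
explicit = (1 - alpha) * center_old + alpha * cm

# Trick from the comment: scale the old center, add the weighted sum, rescale.
sigma = (weights[:, None] * batch).sum(axis=0)  # Sigma(B_j^i) = wsum_cluster * CM(B_j^i)
trick = (alpha / wsum_cluster) * ((1 - alpha) * (wsum_cluster / alpha) * center_old + sigma)

assert np.allclose(explicit, trick)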
else:
    # Undo the previous count-based scaling for this cluster center
Suggested change:
- else:
-     # Undo the previous count-based scaling for this cluster center
+ else:
+     # learning rate lr = wsum / (weight_sums + wsum)
+     # Undo the previous count-based scaling for this cluster center
Same as comment on generalizing efficient updates.
alpha = 1 / weight_sums[cluster_idx]
for feature_idx in range(n_features):
    centers_new[cluster_idx, feature_idx] *= alpha
centers_new[cluster_idx, feature_idx] = centers_old[cluster_idx, feature_idx] * (1 - alpha) * (wsum / alpha)
Suggested change:
- centers_new[cluster_idx, feature_idx] = centers_old[cluster_idx, feature_idx] * (1 - alpha) * (wsum / alpha)
+ centers_new[cluster_idx, feature_idx] = centers_old[cluster_idx, feature_idx] * (1 - lr) * (wsum / lr)
I will abstract the scaling factor (1-alpha) * (wsum/alpha) to old_scaling_factor.
@@ -150,17 +182,19 @@ def _minibatch_update_sparse(
    int n_samples = X.shape[0]
    int n_clusters = centers_old.shape[0]
    int cluster_idx

    floating b=0.0
Same as above.

Suggested change:
- floating b=0.0
+ floating wsum_batch = 0.0
Noted.
if adaptive_lr:
    """
    perform the minibatch update for the current cluster using
    C_new = C_old * (1 - alpha) + alpha * cm(B_j)

    where alpha = sqrt(b_j/b) is the learning rate from https://arxiv.org/abs/2304.00419,
    b is the weight of the batch, b_j is the weight of the batch w.r.t. the current cluster,
    and cm(B_j) is the center of mass of the batch w.r.t. the current cluster.
    """
As above.

I feel the code could be simplified by first defining a learning rate:

if adaptive_lr:
    lr = sqrt(wsum / wsum_batch)
else:
    lr = wsum / (weight_sums[cluster_idx] + wsum)

and then doing the common updates:

for feature_idx in range(n_features):
    centers_new[cluster_idx, feature_idx] = (1 - lr) * centers_old[cluster_idx, feature_idx]
for k in range(n_indices):
    sample_idx = indices[k]
    for feature_idx in range(n_features):
        weight_idx = sample_weight[sample_idx] / wsum
        centers_new[cluster_idx, feature_idx] += lr * weight_idx * X[sample_idx, feature_idx]
Thanks for the feedback. I'll add most of your suggestions as they are.
Introducing the learning rate to avoid duplicating code is a good idea. However, it's a bit messy since we need to do an optimization that avoids explicitly computing the means of each batch for a given center. I'll have a go at redrafting those parts.
Thanks!
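To make the discussion concrete, here is a rough pure-Python/NumPy sketch of the reviewer's suggestion, not the actual Cython implementation; the function and variable names are made up for illustration. It computes lr once per cluster and then applies the same update in both the adaptive and non-adaptive cases:

import numpy as np

def minibatch_update_dense_sketch(X, sample_weight, labels, centers_old,
                                  weight_sums, adaptive_lr=False):
    # X: (n_samples, n_features) batch; labels: cluster assignment of each sample
    n_clusters = centers_old.shape[0]
    centers_new = centers_old.copy()
    wsum_batch = sample_weight.sum()  # total weight of the batch

    for cluster_idx in range(n_clusters):
        mask = labels == cluster_idx
        if not np.any(mask):
            continue  # no samples assigned to this cluster in the batch
        wsum_cluster = sample_weight[mask].sum()

        if adaptive_lr:
            # adaptive rate from https://arxiv.org/abs/2304.00419
            lr = np.sqrt(wsum_cluster / wsum_batch)
        else:
            # classic count-based rate of MiniBatchKMeans
            lr = wsum_cluster / (weight_sums[cluster_idx] + wsum_cluster)

        weight_sums[cluster_idx] += wsum_cluster
        cluster_mean = np.average(X[mask], axis=0, weights=sample_weight[mask])
        centers_new[cluster_idx] = (1 - lr) * centers_old[cluster_idx] + lr * cluster_mean

    return centers_new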
force-pushed from 0bf05ba to 0094c8e
force-pushed from 0094c8e to 13fe041
Reference Issues/PRs
None
What does this implement/fix? Explain your changes.
This PR implements a recent learning rate for minibatch k-means which can be superior to the default learning rate. We implement this with the flag adaptive_lr, which defaults to False. Details can be found in this paper, which appeared at ICLR 2023. Extensive experiments can be found in this manuscript (ignore the kernel k-means results). We also added a benchmark that produces the following plot, which shows the adaptive learning rate is the same or better than the default on dense datasets.
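For reference, a hypothetical usage sketch; adaptive_lr is the flag proposed in this PR, so this assumes a scikit-learn build that includes this branch rather than a released version:

from sklearn.cluster import MiniBatchKMeans
from sklearn.datasets import make_blobs

X, _ = make_blobs(n_samples=10_000, centers=50, random_state=0)

default_mbkm = MiniBatchKMeans(n_clusters=50, random_state=0).fit(X)
adaptive_mbkm = MiniBatchKMeans(n_clusters=50, random_state=0, adaptive_lr=True).fit(X)

# Compare final inertia (lower is better)
print(default_mbkm.inertia_, adaptive_mbkm.inertia_)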
Any other comments?
This is a reasonably small code change. We add a flag to the MiniBatchKMeans constructor and to the _k_means_minibatch.pyx Cython file. The learning rate implementation is straightforward. In the benchmarks, the adaptive learning rate appears to take a few more iterations to converge, often resulting in better solutions. When we removed early stopping, we observed that the running time is about the same.
This should be a cleaner version of #30045 (I made a mess since I'm still pretty new to git).