
adding adaptive learning rate for minibatch k-means #30051


Open · wants to merge 1 commit into base: main

Conversation

BenJourdan

Reference Issues/PRs

None

What does this implement/fix? Explain your changes.

This PR implements a recently proposed adaptive learning rate for mini-batch k-means that can outperform the default learning rate. It is exposed through a new adaptive_lr flag that defaults to False.

Details can be found in this paper, which appeared at ICLR 2023. Extensive experiments can be found in this manuscript (ignore the kernel k-means results). We also added a benchmark that produces the following plot, which shows the adaptive learning rate matches or beats the default on dense datasets.

[image: benchmark plot]

Any other comments?

This is a reasonably small code change. We add a flag to the MiniBatchKMeans constructor and to the _k_means_minibatch.pyx Cython file. The learning rate implementation is straightforward. In the benchmarks, the adaptive learning rate appears to take a few more iterations to converge, often resulting in better solutions. When we removed early stopping, we observed that the running time is about the same.
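
For illustration, a minimal usage sketch of the proposed flag (assuming the API lands as described in this PR; adaptive_lr is not part of released scikit-learn):

```python
from sklearn.cluster import MiniBatchKMeans
from sklearn.datasets import make_blobs

# Dense toy data; the adaptive rate is expected to help most on dense inputs.
X, _ = make_blobs(n_samples=10_000, centers=20, n_features=50, random_state=0)

# adaptive_lr is the flag proposed in this PR (default False keeps current behavior).
mbkm = MiniBatchKMeans(n_clusters=20, adaptive_lr=True, random_state=0)
mbkm.fit(X)
print(mbkm.inertia_)
```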

This should be a cleaner version of #30045 (I made a mess since I'm still pretty new to git).


github-actions bot commented Oct 12, 2024

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 13fe041.

@adrinjalali
Member

The diff is indeed relatively small; however, the paper is quite recent and the improvements are marginal, so I'll let @ogrisel, @lorentzenchr, @jeremiedbb, and @GaelVaroquaux weigh in here.

@adrinjalali adrinjalali added the Needs Decision label Oct 14, 2024
@ogrisel
Member

ogrisel commented Oct 14, 2024

I have similar feelings. Unfortunately, arxiv.org has been unresponsive for me since yesterday, so I cannot check the benchmark results from the paper.

@BenJourdan could you please add results for full-batch k-means to your plots? I am wondering if this can allow MB-k-means to reach the same scores as full-batch k-means on those problems.

@BenJourdan
Author

BenJourdan commented Oct 14, 2024

Here are the results with full-batch k-means added:
results_default_params

If you adjust the early-stopping tolerance tol, this also affects runtime/performance. Comparing tol values between the mini-batch and full-batch methods isn't exactly apples to apples, but I imagine it's what users would reach for first if they are worried about runtime. max_no_improvement will also have an effect.

This was with tol=1e-1 for all the algorithms:
results_tol_1e-1

This was with tol=1e-2:
results_tol_1e-2

tol=1e-3:
results_tol_1e-3

tol=1e-4:
results_tol_1e-4

I can add more experiments varying max_no_improvement if that helps.
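
For anyone reproducing this, a sketch of how such a tol sweep might be run (the dataset and parameters below are placeholders, not the exact benchmark script from this PR):

```python
import time

from sklearn.cluster import KMeans, MiniBatchKMeans
from sklearn.datasets import fetch_openml

# Placeholder dense dataset; the PR benchmark uses its own collection of datasets.
X, _ = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)

for tol in (1e-1, 1e-2, 1e-3, 1e-4):
    for name, est in [
        ("full-batch", KMeans(n_clusters=10, tol=tol, random_state=0)),
        ("mini-batch", MiniBatchKMeans(n_clusters=10, tol=tol, random_state=0)),
    ]:
        t0 = time.perf_counter()
        est.fit(X)
        print(f"tol={tol:g} {name}: inertia={est.inertia_:.3e} "
              f"time={time.perf_counter() - t0:.1f}s")
```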

@BenJourdan BenJourdan force-pushed the feature_mbkm_adaptive_lr branch from 35c5b19 to 158897a on October 14, 2024 15:47
@ogrisel
Member

ogrisel commented Oct 17, 2024

Thanks for the update. So from those experiments, it appears that the new lr scheme can empirically help MBKMeans close the (smallish) gap with full-batch KMeans in terms of clustering quality while keeping favorable runtimes for datasets with many data points (e.g. MNIST size or larger).

But since the method was recently published, this PR does not technically meet our inclusion criteria, although we could be less strict in cases where this is an incremental improvement of an existing method implemented in scikit-learn.

I will mention this PR at our next monthly meeting.

@BenJourdan
Author

What was the verdict @ogrisel?

@adrinjalali
Member

There were no clear objections to including this, and I think a few of us are in favor of it.

@BenJourdan
Author

@ogrisel @adrinjalali what happens next? Should I start updating the branch?

@adrinjalali
Member

@BenJourdan seems like it.

@BenJourdan BenJourdan force-pushed the feature_mbkm_adaptive_lr branch from 65e49cb to 320a4b4 on November 6, 2024 16:06
@BenJourdan
Author

Should I keep updating the branch until (if lol) someone gets assigned? Not sure what the convention is.

@gregoryschwartzman

Hi @ogrisel, @adrinjalali,

Just checking in—since there were no objections during the meeting and some support for inclusion, would it make sense to remove the Needs Decision label and move toward review/approval?

Please let me know what you'd recommend as the next step. Thanks again for your time and feedback so far!

@adrinjalali adrinjalali removed the Needs Decision label Apr 23, 2025
@adrinjalali
Member

I've removed the "Needs Decision" label here. I think we can move forward with this.

@gregoryschwartzman

Great! Let us know if we should do anything on our end.

@adrinjalali
Member

@antoinebaker would you mind having a look at this PR for a review?

Contributor

@antoinebaker antoinebaker left a comment


Thanks for the PR @BenJourdan! Here is a first round of review.

Comment on lines 1801 to 1805
adaptive_lr : bool, default=False
    If True, use the adaptive learning rate described in this \
    `paper <https://arxiv.org/abs/2304.00419>`_.
    This can be more effective than the standard learning rate \
    when the input is dense.
Contributor


Suggested change
adaptive_lr : bool, default=False
    If True, use the adaptive learning rate described in this \
    `paper <https://arxiv.org/abs/2304.00419>`_.
    This can be more effective than the standard learning rate \
    when the input is dense.
adaptive_lr : bool, default=False
    If True, use the adaptive learning rate described in [1]_ which can be
    more effective than the standard learning rate when the input is dense.

Contributor


Adding a References section below Notes:

    References
    ----------

    .. [1] :arxiv:`Ben Jourdan and Gregory Schwartzman (2024). 
      "Mini-Batch Kernel k-means." <2410.05902>`

Author


Sure. I'll cite Gregory's original paper instead:

    .. [1] :arxiv:`Gregory Schwartzman (2023). 
      "Mini-batch k-means terminates within O(d/ɛ) iterations" <2304.00419>`

n_threads : int
The number of threads to be used by openmp.
"""
cdef:
int n_samples = X.shape[0]
int n_clusters = centers_old.shape[0]
int cluster_idx

floating b=0.0
Contributor


Not sure about the naming, but something more explicit than b:

Suggested change
floating b=0.0
floating wsum_batch = 0.0

Author


Sure. I'll also change wsum inside the update_center_* functions to wsum_cluster to indicate that it refers to the weight of the points (in the batch) that were assigned to cluster cluster_idx.

Comment on lines 97 to 107
if adaptive_lr:
    """
    perform the minibatch update for the current cluster using
    C_new = C_old *(1-alpha) + alpha*cm(B_j)

    where alpha = sqrt(b_j/b) is the learning rate from https://arxiv.org/abs/2304.00419,
    b is the weight of the batch, b_j is the weight of the batch w.r.t. the current cluster,
    and cm(B_j) is the center of mass of the batch w.r.t. the current cluster.
    """
    weight_sums[cluster_idx] += wsum
    alpha = sqrt(wsum/b)
Contributor


I would introduce a learning rate, here called lr to distinguish it from
alpha = 1 / weight_sums[cluster_idx] used when adaptive_lr=False.

Suggested change
if adaptive_lr:
"""
perform the minibatch update for the current cluster using
C_new = C_old *(1-alpha) + alpha*cm(B_j)
where alpha = sqrt(b_j/b) is the learning rate from https://arxiv.org/abs/2304.00419,
b is the weight of the batch, b_j is the weight of the batch w.r.t. the current cluster,
and cm(B_j) is the center of mass of the batch w.r.t. the current cluster.
"""
weight_sums[cluster_idx] += wsum
alpha = sqrt(wsum/b)
# Update center C_new = (1 - lr) * C_old + lr * C_X where
# C_X = weighted mean of the observations X assigned to the cluster
# lr = learning rate
if adaptive_lr:
# learning rate lr = sqrt(wsum/wsum_batch) suggested in
# https://arxiv.org/abs/2304.00419

Author


I agree. I'll generalize the trick currently used by both methods to avoid explicitly computing the means of the points in the current batch. I'll add the following comment along with the changes:

        # We want to compute the new center with the update formula
        # C^{i+1}_j = C^{i}_j*(1-alpha) + alpha*CM(B_j^i).

        # where: 
        # - C_j^i is the center representing the j-th cluster at the i-th 
        #   iteration
        # - B_j^i is the batch of samples assigned to the j-th cluster at 
        #   the i-th iteration
        # - CM(B_j^i) is the (weighted) mean of the samples assigned to 
        #   cluster j in iteration i
        # - alpha is the learning rate

        # In the non-adaptive case, alpha = wsum_cluster/(wsum_cluster+old_weight)
        # where:
        # - wsum_cluster is the weight of the points assigned to the cluster in the 
        #   current batch
        # - old_weight is the weight of all points assigned to cluster j 
        #   in previous iterations.
        # This is equivalent to computing a weighted average of everything 
        # assigned to cluster j so far.

        # In the adaptive case (see https://arxiv.org/abs/2304.00419), 
        # alpha = sqrt(wsum_cluster/wsum_batch) where wsum_batch is the weight of 
        # the batch. This is similar to an exponential moving average but with
        # an adaptive decay rate.

        # For the sake of efficiency, we don't compute the update explicitly.
        # Instead, we skip computing the mean of the batch and instead
        # compute the update by scaling the old center, adding the weighted
        # sum of the batch, and then scaling again.

        # Let Sigma(B_j^i) be the weighted sum of the points assigned to 
        # cluster j in the current batch. 
        # Therefore (Sigma(B_j^i) = wsum_cluster * CM(B_j^i)).

        # We can rewrite the update formula as:
        # C^{i+1}_j = C^{i}_j*(1-alpha) + (alpha/wsum_cluster)*Sigma(B_j^i)
        #           = (alpha/wsum_cluster)[C^{i}_j*(1-alpha)(wsum_cluster/alpha) + Sigma(B_j^i)]

        # In the adaptive case, nothing simplifies so we just use the formula
        # as is.
        # In the non-adaptive case, things simplify and we have 
        # - (1-alpha)*(wsum_cluster/alpha) 
        #   = (old_weight/(wsum_cluster+old_weight))*(wsum_cluster+old_weight) = old_weight
        # - (alpha/wsum_cluster) = 1/(wsum_cluster+old_weight)
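
To make the two schemes concrete, here is a small NumPy sketch of a single-cluster update (illustrative only, not the Cython implementation; the helper name and signature are made up, but the formulas follow the comment above):

```python
import numpy as np

def center_update(center_old, batch_points, batch_weights, cluster_mask,
                  old_weight, wsum_batch, adaptive_lr=False):
    """One mini-batch update of a single center (dense, pure-Python sketch)."""
    w = batch_weights[cluster_mask]
    wsum_cluster = w.sum()
    if wsum_cluster == 0:
        return center_old, old_weight  # no points assigned, keep the center
    # CM(B_j^i): weighted mean of the batch points assigned to this cluster.
    cm = np.average(batch_points[cluster_mask], axis=0, weights=w)
    if adaptive_lr:
        # Adaptive rate from https://arxiv.org/abs/2304.00419.
        alpha = np.sqrt(wsum_cluster / wsum_batch)
    else:
        # Standard rate: running weighted average of everything seen so far.
        alpha = wsum_cluster / (wsum_cluster + old_weight)
    center_new = (1 - alpha) * center_old + alpha * cm
    return center_new, old_weight + wsum_cluster
```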

Comment on lines 116 to 117
else:
# Undo the previous count-based scaling for this cluster center
Contributor


Suggested change
else:
# Undo the previous count-based scaling for this cluster center
else:
# learning rate lr = wsum / (weight_sums + wsum)
# Undo the previous count-based scaling for this cluster center

Author


Same as comment on generalizing efficient updates.

alpha = 1 / weight_sums[cluster_idx]
for feature_idx in range(n_features):
centers_new[cluster_idx, feature_idx] *= alpha
centers_new[cluster_idx, feature_idx] = centers_old[cluster_idx, feature_idx]* (1-alpha) * (wsum/alpha)
Contributor


Suggested change
centers_new[cluster_idx, feature_idx] = centers_old[cluster_idx, feature_idx]* (1-alpha) * (wsum/alpha)
centers_new[cluster_idx, feature_idx] = centers_old[cluster_idx, feature_idx] * (1 - lr) * (wsum / lr)

Author


I will abstract the scaling factor (1-alpha) * (wsum/alpha) to old_scaling_factor

@@ -150,17 +182,19 @@ def _minibatch_update_sparse(
int n_samples = X.shape[0]
int n_clusters = centers_old.shape[0]
int cluster_idx

floating b=0.0
Contributor


Same as above

Suggested change
floating b=0.0
floating wsum_batch = 0.0

Author


Noted.

Comment on lines 234 to 242
if adaptive_lr:
    """
    perform the minibatch update for the current cluster using
    C_new = C_old *(1-alpha) + alpha*cm(B_j)

    where alpha = sqrt(b_j/b) is the learning rate from https://arxiv.org/abs/2304.00419,
    b is the weight of the batch, b_j is the weight of the batch w.r.t. the current cluster,
    and cm(B_j) is the center of mass of the batch w.r.t. the current cluster.
    """
Contributor


As above.

@antoinebaker
Contributor

I feel the code could be simplified by first defining a learning rate:

if adaptive_lr:
    lr = sqrt(wsum / wsum_batch)
else:
    lr = wsum / (weight_sums[cluster_idx] + wsum)

and then doing the common updates:

for feature_idx in range(n_features):
    centers_new[cluster_idx, feature_idx] = (1 - lr) * centers_old[cluster_idx, feature_idx]
for k in range(n_indices):
    sample_idx = indices[k]
    for feature_idx in range(n_features):
        weight_idx =  sample_weight[sample_idx] / wsum
        centers_new[cluster_idx, feature_idx] += lr * weight_idx * X[sample_idx, feature_idx]

Author

@BenJourdan BenJourdan left a comment


Thanks for the feedback. I'll apply most of your suggestions as they are.

Introducing the learning rate to avoid duplicating code is a good idea. However, it's a bit messy since we need an optimization that avoids explicitly computing the mean of each batch for a given center. I'll have a go at redrafting those parts.

Thanks!
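
For anyone following along, a quick NumPy check (made-up toy values, not the PR's code) that the rescaling trick discussed above reproduces the explicit mean-based update for the non-adaptive rate:

```python
import numpy as np

rng = np.random.default_rng(0)
X_batch = rng.normal(size=(8, 3))    # points assigned to one cluster in this batch
w = rng.uniform(0.5, 2.0, size=8)    # their sample weights
center_old = rng.normal(size=3)
old_weight = 5.0                     # weight accumulated in previous iterations

wsum_cluster = w.sum()
alpha = wsum_cluster / (wsum_cluster + old_weight)

# Explicit update: C_new = (1 - alpha) * C_old + alpha * CM(B_j)
cm = np.average(X_batch, axis=0, weights=w)
explicit = (1 - alpha) * center_old + alpha * cm

# Rescaling trick: scale the old center, add the weighted sum, then rescale.
trick = center_old * old_weight                    # (1 - alpha) * (wsum_cluster / alpha) = old_weight
trick = trick + (w[:, None] * X_batch).sum(axis=0)
trick = trick / (wsum_cluster + old_weight)        # multiply by alpha / wsum_cluster

np.testing.assert_allclose(explicit, trick)
```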

@BenJourdan BenJourdan force-pushed the feature_mbkm_adaptive_lr branch 4 times, most recently from 0bf05ba to 0094c8e on May 1, 2025 15:41
@BenJourdan BenJourdan force-pushed the feature_mbkm_adaptive_lr branch from 0094c8e to 13fe041 on May 1, 2025 15:46