
TST Relax test_minibatch_sensible_reassign to avoid CI failures with single global random seed #29278


Merged
merged 11 commits into from
Jun 20, 2024
9 changes: 6 additions & 3 deletions sklearn/cluster/tests/test_k_means.py
@@ -437,21 +437,24 @@ def test_minibatch_sensible_reassign(global_random_seed):
n_clusters=20, batch_size=10, random_state=global_random_seed, init="random"
).fit(zeroed_X)
# there should not be too many exact zero cluster centers
- assert km.cluster_centers_.any(axis=1).sum() > 10
+ num_non_zero_clusters = km.cluster_centers_.any(axis=1).sum()
+ assert num_non_zero_clusters > 9, f"{num_non_zero_clusters=} is too small"
@lesteve (Member, Author), Jun 18, 2024

I added an assertion message (i.e. the second argument in the assert, not sure if there is a more exact name) because pytest assertion rewriting seems to be a bit broken here (it needs a bit of investigation as to why). If this ever fails again, at least the message will tell us how far we are from the threshold.
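
As a rough illustration of the assertion message described above, here is a minimal standalone sketch. The helper names `check_enough_non_zero` and `failure_message` are hypothetical and not part of the PR; only the `assert ..., f"{num_non_zero_clusters=} is too small"` pattern comes from the diff. The `=` specifier inside the f-string (Python 3.8+) renders both the variable name and its value, so a failure reports how far the count was from the threshold even without pytest's assertion rewriting.

```python
def check_enough_non_zero(num_non_zero_clusters, threshold=9):
    """Hypothetical helper mirroring the assertion pattern used in the diff."""
    # The second argument of `assert` is the message attached to the
    # AssertionError; `{name=}` expands to "name=<repr of value>".
    assert num_non_zero_clusters > threshold, (
        f"{num_non_zero_clusters=} is too small"
    )


def failure_message(value):
    """Return the assertion message for `value`, or None if the check passes."""
    try:
        check_enough_non_zero(value)
    except AssertionError as exc:
        return str(exc)
    return None


print(failure_message(12))
print(failure_message(7))
```

Run as a plain script, the first call passes the check and the second produces a message that includes the offending count.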


# do the same with batch-size > X.shape[0] (regression test)
km = MiniBatchKMeans(
n_clusters=20, batch_size=200, random_state=global_random_seed, init="random"
).fit(zeroed_X)
# there should not be too many exact zero cluster centers
- assert km.cluster_centers_.any(axis=1).sum() > 10
+ num_non_zero_clusters = km.cluster_centers_.any(axis=1).sum()
+ assert num_non_zero_clusters > 9, f"{num_non_zero_clusters=} is too small"

# do the same with partial_fit API
km = MiniBatchKMeans(n_clusters=20, random_state=global_random_seed, init="random")
for i in range(100):
km.partial_fit(zeroed_X)
# there should not be too many exact zero cluster centers
- assert km.cluster_centers_.any(axis=1).sum() > 10
+ num_non_zero_clusters = km.cluster_centers_.any(axis=1).sum()
+ assert num_non_zero_clusters > 9, f"{num_non_zero_clusters=} is too small"
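
For readers unfamiliar with the counting idiom in the assertions above, the following is a dependency-free sketch of what `cluster_centers_.any(axis=1).sum()` computes: the number of rows (cluster centers) that have at least one non-zero coordinate. The function `count_non_zero_rows` and the toy `centers` data are illustrative assumptions, not code or data from the PR, and the real test operates on a NumPy array rather than nested lists.

```python
def count_non_zero_rows(centers):
    """Count rows with at least one non-zero entry.

    Pure-Python stand-in (an assumption for illustration) for the NumPy
    expression `centers.any(axis=1).sum()` used in the test.
    """
    return sum(any(value != 0 for value in row) for row in centers)


# Toy data: two all-zero centers and two centers with non-zero coordinates.
centers = [
    [0.0, 0.0],
    [0.5, 0.0],
    [0.0, 0.0],
    [1.0, -2.0],
]

num_non_zero_clusters = count_non_zero_rows(centers)
print(f"{num_non_zero_clusters=}")
```

In the test itself, this count is compared against the relaxed threshold (`> 9` out of 20 clusters), so that a few centers remaining exactly zero for some random seeds no longer causes a failure.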


@pytest.mark.parametrize(