Skip to content

Commit c7866e6

Browse files
ogriseljeremiedbb
andauthored
TST fix platform sensitive test: test_float_precision (#32035)
Co-authored-by: Jérémie du Boisberranger <jeremie@probabl.ai>
1 parent 96f48da commit c7866e6

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

sklearn/cluster/tests/test_k_means.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import numpy as np
88
import pytest
99
from scipy import sparse as sp
10+
from threadpoolctl import threadpool_info
1011

1112
from sklearn.base import clone
1213
from sklearn.cluster import KMeans, MiniBatchKMeans, k_means, kmeans_plusplus
@@ -744,7 +745,7 @@ def test_transform(Estimator, global_random_seed):
744745
# In particular, diagonal must be 0
745746
assert_array_equal(Xt.diagonal(), np.zeros(n_clusters))
746747

747-
# Transorfming X should return the pairwise distances between X and the
748+
# Transforming X should return the pairwise distances between X and the
748749
# centers
749750
Xt = km.transform(X)
750751
assert_allclose(Xt, pairwise_distances(X, km.cluster_centers_))
@@ -794,6 +795,13 @@ def test_k_means_function(global_random_seed):
794795
ids=data_containers_ids,
795796
)
796797
@pytest.mark.parametrize("Estimator", [KMeans, MiniBatchKMeans])
798+
@pytest.mark.skipif(
799+
not any(i for i in threadpool_info() if i["user_api"] == "blas"),
800+
reason=(
801+
"Fails for some global_random_seed on Atlas which cannot be detected by "
802+
"threadpoolctl."
803+
),
804+
)
797805
def test_float_precision(Estimator, input_data, global_random_seed):
798806
# Check that the results are the same for single and double precision.
799807
km = Estimator(n_init=1, random_state=global_random_seed)
@@ -822,10 +830,11 @@ def test_float_precision(Estimator, input_data, global_random_seed):
822830

823831
# compare arrays with low precision since the difference between 32 and
824832
# 64 bit comes from an accumulation of rounding errors.
825-
assert_allclose(inertia[np.float32], inertia[np.float64], rtol=1e-4)
826-
assert_allclose(Xt[np.float32], Xt[np.float64], atol=Xt[np.float64].max() * 1e-4)
833+
rtol = 1e-4
834+
assert_allclose(inertia[np.float32], inertia[np.float64], rtol=rtol)
835+
assert_allclose(Xt[np.float32], Xt[np.float64], atol=Xt[np.float64].max() * rtol)
827836
assert_allclose(
828-
centers[np.float32], centers[np.float64], atol=centers[np.float64].max() * 1e-4
837+
centers[np.float32], centers[np.float64], atol=centers[np.float64].max() * rtol
829838
)
830839
assert_array_equal(labels[np.float32], labels[np.float64])
831840

0 commit comments

Comments
 (0)