|
7 | 7 | import numpy as np
|
8 | 8 | import pytest
|
9 | 9 | from scipy import sparse as sp
|
| 10 | +from threadpoolctl import threadpool_info |
10 | 11 |
|
11 | 12 | from sklearn.base import clone
|
12 | 13 | from sklearn.cluster import KMeans, MiniBatchKMeans, k_means, kmeans_plusplus
|
@@ -744,7 +745,7 @@ def test_transform(Estimator, global_random_seed):
|
744 | 745 | # In particular, diagonal must be 0
|
745 | 746 | assert_array_equal(Xt.diagonal(), np.zeros(n_clusters))
|
746 | 747 |
|
747 |
| - # Transorfming X should return the pairwise distances between X and the |
| 748 | + # Transforming X should return the pairwise distances between X and the |
748 | 749 | # centers
|
749 | 750 | Xt = km.transform(X)
|
750 | 751 | assert_allclose(Xt, pairwise_distances(X, km.cluster_centers_))
|
@@ -794,6 +795,13 @@ def test_k_means_function(global_random_seed):
|
794 | 795 | ids=data_containers_ids,
|
795 | 796 | )
|
796 | 797 | @pytest.mark.parametrize("Estimator", [KMeans, MiniBatchKMeans])
|
| 798 | +@pytest.mark.skipif( |
| 799 | + not any(i for i in threadpool_info() if i["user_api"] == "blas"), |
| 800 | + reason=( |
| 801 | + "Fails for some global_random_seed on Atlas which cannot be detected by " |
| 802 | + "threadpoolctl." |
| 803 | + ), |
| 804 | +) |
797 | 805 | def test_float_precision(Estimator, input_data, global_random_seed):
|
798 | 806 | # Check that the results are the same for single and double precision.
|
799 | 807 | km = Estimator(n_init=1, random_state=global_random_seed)
|
@@ -822,10 +830,11 @@ def test_float_precision(Estimator, input_data, global_random_seed):
|
822 | 830 |
|
823 | 831 | # compare arrays with low precision since the difference between 32 and
|
824 | 832 | # 64 bit comes from an accumulation of rounding errors.
|
825 |
| - assert_allclose(inertia[np.float32], inertia[np.float64], rtol=1e-4) |
826 |
| - assert_allclose(Xt[np.float32], Xt[np.float64], atol=Xt[np.float64].max() * 1e-4) |
| 833 | + rtol = 1e-4 |
| 834 | + assert_allclose(inertia[np.float32], inertia[np.float64], rtol=rtol) |
| 835 | + assert_allclose(Xt[np.float32], Xt[np.float64], atol=Xt[np.float64].max() * rtol) |
827 | 836 | assert_allclose(
|
828 |
| - centers[np.float32], centers[np.float64], atol=centers[np.float64].max() * 1e-4 |
| 837 | + centers[np.float32], centers[np.float64], atol=centers[np.float64].max() * rtol |
829 | 838 | )
|
830 | 839 | assert_array_equal(labels[np.float32], labels[np.float64])
|
831 | 840 |
|
|
0 commit comments