Skip to content

Commit 1cab25c

Browse files
authored
MNT use new threadpoolctl API (global threadpool controller) (#21206)
1 parent 10a5468 commit 1cab25c

File tree

3 files changed

+35
-3
lines changed

3 files changed

+35
-3
lines changed

sklearn/cluster/_kmeans.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@
1515

1616
import numpy as np
1717
import scipy.sparse as sp
18-
from threadpoolctl import threadpool_limits
19-
from threadpoolctl import threadpool_info
2018

2119
from ..base import BaseEstimator, ClusterMixin, TransformerMixin
2220
from ..metrics.pairwise import euclidean_distances
2321
from ..metrics.pairwise import _euclidean_distances
2422
from ..utils.extmath import row_norms, stable_cumsum
23+
from ..utils.fixes import threadpool_limits
24+
from ..utils.fixes import threadpool_info
2525
from ..utils.sparsefuncs_fast import assign_rows_csr
2626
from ..utils.sparsefuncs import mean_variance_axis
2727
from ..utils import check_array

sklearn/cluster/tests/test_k_means.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44

55
import numpy as np
66
from scipy import sparse as sp
7-
from threadpoolctl import threadpool_limits
87

98
import pytest
109

1110
from sklearn.utils._testing import assert_array_equal
1211
from sklearn.utils._testing import assert_allclose
1312
from sklearn.utils.fixes import _astype_copy_false
13+
from sklearn.utils.fixes import threadpool_limits
1414
from sklearn.base import clone
1515
from sklearn.exceptions import ConvergenceWarning
1616

sklearn/utils/fixes.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,13 @@
1313
from functools import update_wrapper
1414
import functools
1515

16+
import sklearn
1617
import numpy as np
1718
import scipy.sparse as sp
1819
import scipy
1920
import scipy.stats
2021
from scipy.sparse.linalg import lsqr as sparse_lsqr # noqa
22+
import threadpoolctl
2123
from .._config import config_context, get_config
2224
from ..externals._packaging.version import parse as parse_version
2325

@@ -271,3 +273,33 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis
271273
dtype=dtype,
272274
axis=axis,
273275
)
276+
277+
278+
# compatibility fix for threadpoolctl >= 3.0.0
279+
# since version 3 it's possible to setup a global threadpool controller to avoid
280+
# looping through all loaded shared libraries each time.
281+
# the global controller is created during the first call to threadpoolctl.
282+
def _get_threadpool_controller():
283+
if not hasattr(threadpoolctl, "ThreadpoolController"):
284+
return None
285+
286+
if not hasattr(sklearn, "_sklearn_threadpool_controller"):
287+
sklearn._sklearn_threadpool_controller = threadpoolctl.ThreadpoolController()
288+
289+
return sklearn._sklearn_threadpool_controller
290+
291+
292+
def threadpool_limits(limits=None, user_api=None):
293+
controller = _get_threadpool_controller()
294+
if controller is not None:
295+
return controller.limit(limits=limits, user_api=user_api)
296+
else:
297+
return threadpoolctl.threadpool_limits(limits=limits, user_api=user_api)
298+
299+
300+
def threadpool_info():
301+
controller = _get_threadpool_controller()
302+
if controller is not None:
303+
return controller.info()
304+
else:
305+
return threadpoolctl.threadpool_info()

0 commit comments

Comments
 (0)