|
13 | 13 | from functools import update_wrapper
|
14 | 14 | import functools
|
15 | 15 |
|
| 16 | +import sklearn |
16 | 17 | import numpy as np
|
17 | 18 | import scipy.sparse as sp
|
18 | 19 | import scipy
|
19 | 20 | import scipy.stats
|
20 | 21 | from scipy.sparse.linalg import lsqr as sparse_lsqr # noqa
|
| 22 | +import threadpoolctl |
21 | 23 | from .._config import config_context, get_config
|
22 | 24 | from ..externals._packaging.version import parse as parse_version
|
23 | 25 |
|
@@ -271,3 +273,33 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis
|
271 | 273 | dtype=dtype,
|
272 | 274 | axis=axis,
|
273 | 275 | )
|
| 276 | + |
| 277 | + |
| 278 | +# compatibility fix for threadpoolctl >= 3.0.0 |
| 279 | +# since version 3 it's possible to setup a global threadpool controller to avoid |
| 280 | +# looping through all loaded shared libraries each time. |
| 281 | +# the global controller is created during the first call to threadpoolctl. |
| 282 | +def _get_threadpool_controller(): |
| 283 | + if not hasattr(threadpoolctl, "ThreadpoolController"): |
| 284 | + return None |
| 285 | + |
| 286 | + if not hasattr(sklearn, "_sklearn_threadpool_controller"): |
| 287 | + sklearn._sklearn_threadpool_controller = threadpoolctl.ThreadpoolController() |
| 288 | + |
| 289 | + return sklearn._sklearn_threadpool_controller |
| 290 | + |
| 291 | + |
| 292 | +def threadpool_limits(limits=None, user_api=None): |
| 293 | + controller = _get_threadpool_controller() |
| 294 | + if controller is not None: |
| 295 | + return controller.limit(limits=limits, user_api=user_api) |
| 296 | + else: |
| 297 | + return threadpoolctl.threadpool_limits(limits=limits, user_api=user_api) |
| 298 | + |
| 299 | + |
| 300 | +def threadpool_info(): |
| 301 | + controller = _get_threadpool_controller() |
| 302 | + if controller is not None: |
| 303 | + return controller.info() |
| 304 | + else: |
| 305 | + return threadpoolctl.threadpool_info() |
0 commit comments