Skip to content

Commit 106027f

Browse files
jeremiedbbadrinjalali
authored andcommitted
ENH Move threadpoolctl outside of iteration loop in KMeans (scikit-learn#17235)
1 parent 201060f commit 106027f

File tree

4 files changed

+39
-60
lines changed

4 files changed

+39
-60
lines changed

doc/whats_new/v0.23.rst

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ Changelog
1919
provided by the user were modified in place. :pr:`17204` by
2020
:user:`Jeremie du Boisberranger <jeremiedbb>`.
2121

22-
- |Efficiency| :class:`cluster.KMeans` cannot spawn idle threads any more for
23-
very small datasets. :pr:`17210` by
24-
:user:`Jeremie du Boisberranger <jeremiedbb>`.
22+
- |Efficiency| :class:`cluster.KMeans` efficiency has been improved for very
23+
small datasets. In particular it cannot spawn idle threads any more.
24+
:pr:`17210` and :pr:`17235` by :user:`Jeremie du Boisberranger <jeremiedbb>`.
2525

2626
Miscellaneous
2727
.............

sklearn/cluster/_k_means_elkan.pyx

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
import numpy as np
1212
cimport numpy as np
13-
from threadpoolctl import threadpool_limits
1413
cimport cython
1514
from cython cimport floating
1615
from cython.parallel import prange, parallel
@@ -30,18 +29,6 @@ from ._k_means_fast cimport _center_shift
3029
np.import_array()
3130

3231

33-
# Threadpoolctl wrappers to limit the number of threads in second level of
34-
# nested parallelism (i.e. BLAS) to avoid oversubsciption.
35-
def elkan_iter_chunked_dense(*args, **kwargs):
36-
with threadpool_limits(limits=1, user_api="blas"):
37-
_elkan_iter_chunked_dense(*args, **kwargs)
38-
39-
40-
def elkan_iter_chunked_sparse(*args, **kwargs):
41-
with threadpool_limits(limits=1, user_api="blas"):
42-
_elkan_iter_chunked_sparse(*args, **kwargs)
43-
44-
4532
def init_bounds_dense(
4633
np.ndarray[floating, ndim=2, mode='c'] X, # IN
4734
floating[:, ::1] centers, # IN
@@ -193,7 +180,7 @@ def init_bounds_sparse(
193180
upper_bounds[i] = min_dist
194181

195182

196-
def _elkan_iter_chunked_dense(
183+
def elkan_iter_chunked_dense(
197184
np.ndarray[floating, ndim=2, mode='c'] X, # IN
198185
floating[::1] sample_weight, # IN
199186
floating[:, ::1] centers_old, # IN
@@ -421,7 +408,7 @@ cdef void _update_chunk_dense(
421408
centers_new[label * n_features + k] += X[i * n_features + k] * sample_weight[i]
422409

423410

424-
def _elkan_iter_chunked_sparse(
411+
def elkan_iter_chunked_sparse(
425412
X, # IN
426413
floating[::1] sample_weight, # IN
427414
floating[:, ::1] centers_old, # IN

sklearn/cluster/_k_means_lloyd.pyx

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
import numpy as np
1010
cimport numpy as np
11-
from threadpoolctl import threadpool_limits
1211
from cython cimport floating
1312
from cython.parallel import prange, parallel
1413
from libc.stdlib cimport malloc, calloc, free
@@ -26,19 +25,7 @@ from ._k_means_fast cimport _average_centers, _center_shift
2625
np.import_array()
2726

2827

29-
# Threadpoolctl wrappers to limit the number of threads in second level of
30-
# nested parallelism (i.e. BLAS) to avoid oversubsciption.
31-
def lloyd_iter_chunked_dense(*args, **kwargs):
32-
with threadpool_limits(limits=1, user_api="blas"):
33-
_lloyd_iter_chunked_dense(*args, **kwargs)
34-
35-
36-
def lloyd_iter_chunked_sparse(*args, **kwargs):
37-
with threadpool_limits(limits=1, user_api="blas"):
38-
_lloyd_iter_chunked_sparse(*args, **kwargs)
39-
40-
41-
def _lloyd_iter_chunked_dense(
28+
def lloyd_iter_chunked_dense(
4229
np.ndarray[floating, ndim=2, mode='c'] X, # IN
4330
floating[::1] sample_weight, # IN
4431
floating[::1] x_squared_norms, # IN
@@ -230,7 +217,7 @@ cdef void _update_chunk_dense(
230217
centers_new[label * n_features + k] += X[i * n_features + k] * sample_weight[i]
231218

232219

233-
def _lloyd_iter_chunked_sparse(
220+
def lloyd_iter_chunked_sparse(
234221
X, # IN
235222
floating[::1] sample_weight, # IN
236223
floating[::1] x_squared_norms, # IN

sklearn/cluster/_kmeans.py

Lines changed: 32 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import numpy as np
1717
import scipy.sparse as sp
18+
from threadpoolctl import threadpool_limits
1819

1920
from ..base import BaseEstimator, ClusterMixin, TransformerMixin
2021
from ..metrics.pairwise import euclidean_distances
@@ -431,15 +432,16 @@ def _kmeans_single_elkan(X, sample_weight, n_clusters, max_iter=300,
431432
labels, upper_bounds, lower_bounds)
432433

433434
for i in range(max_iter):
434-
elkan_iter(X, sample_weight, centers, centers_new, weight_in_clusters,
435-
center_half_distances, distance_next_center, upper_bounds,
436-
lower_bounds, labels, center_shift, n_threads)
435+
elkan_iter(X, sample_weight, centers, centers_new,
436+
weight_in_clusters, center_half_distances,
437+
distance_next_center, upper_bounds, lower_bounds,
438+
labels, center_shift, n_threads)
437439

438440
# compute new pairwise distances between centers and closest other
439441
# center of each center for next iterations
440442
center_half_distances = euclidean_distances(centers_new) / 2
441-
distance_next_center = np.partition(np.asarray(center_half_distances),
442-
kth=1, axis=0)[1]
443+
distance_next_center = np.partition(
444+
np.asarray(center_half_distances), kth=1, axis=0)[1]
443445

444446
if verbose:
445447
inertia = _inertia(X, sample_weight, centers, labels)
@@ -458,9 +460,9 @@ def _kmeans_single_elkan(X, sample_weight, n_clusters, max_iter=300,
458460
if center_shift_tot > 0:
459461
# rerun E-step so that predicted labels match cluster centers
460462
elkan_iter(X, sample_weight, centers, centers, weight_in_clusters,
461-
center_half_distances, distance_next_center, upper_bounds,
462-
lower_bounds, labels, center_shift, n_threads,
463-
update_centers=False)
463+
center_half_distances, distance_next_center,
464+
upper_bounds, lower_bounds, labels, center_shift,
465+
n_threads, update_centers=False)
464466

465467
inertia = _inertia(X, sample_weight, centers, labels)
466468

@@ -564,29 +566,32 @@ def _kmeans_single_lloyd(X, sample_weight, n_clusters, max_iter=300,
564566
lloyd_iter = lloyd_iter_chunked_dense
565567
_inertia = _inertia_dense
566568

567-
for i in range(max_iter):
568-
lloyd_iter(X, sample_weight, x_squared_norms, centers, centers_new,
569-
weight_in_clusters, labels, center_shift, n_threads)
570-
571-
if verbose:
572-
inertia = _inertia(X, sample_weight, centers, labels)
573-
print("Iteration {0}, inertia {1}" .format(i, inertia))
569+
# Threadpoolctl context to limit the number of threads in second level of
570+
# nested parallelism (i.e. BLAS) to avoid oversubsciption.
571+
with threadpool_limits(limits=1, user_api="blas"):
572+
for i in range(max_iter):
573+
lloyd_iter(X, sample_weight, x_squared_norms, centers, centers_new,
574+
weight_in_clusters, labels, center_shift, n_threads)
574575

575-
center_shift_tot = (center_shift**2).sum()
576-
if center_shift_tot <= tol:
577576
if verbose:
578-
print("Converged at iteration {0}: "
579-
"center shift {1} within tolerance {2}"
580-
.format(i, center_shift_tot, tol))
581-
break
577+
inertia = _inertia(X, sample_weight, centers, labels)
578+
print("Iteration {0}, inertia {1}" .format(i, inertia))
579+
580+
center_shift_tot = (center_shift**2).sum()
581+
if center_shift_tot <= tol:
582+
if verbose:
583+
print("Converged at iteration {0}: "
584+
"center shift {1} within tolerance {2}"
585+
.format(i, center_shift_tot, tol))
586+
break
582587

583-
centers, centers_new = centers_new, centers
588+
centers, centers_new = centers_new, centers
584589

585-
if center_shift_tot > 0:
586-
# rerun E-step so that predicted labels match cluster centers
587-
lloyd_iter(X, sample_weight, x_squared_norms, centers, centers,
588-
weight_in_clusters, labels, center_shift, n_threads,
589-
update_centers=False)
590+
if center_shift_tot > 0:
591+
# rerun E-step so that predicted labels match cluster centers
592+
lloyd_iter(X, sample_weight, x_squared_norms, centers, centers,
593+
weight_in_clusters, labels, center_shift, n_threads,
594+
update_centers=False)
590595

591596
inertia = _inertia(X, sample_weight, centers, labels)
592597

0 commit comments

Comments
 (0)