diff --git a/sklearn/utils/sparsefuncs_fast.pyx b/sklearn/utils/sparsefuncs_fast.pyx index 9cbfcc1f7a3f6..4b975872c0e0a 100644 --- a/sklearn/utils/sparsefuncs_fast.pyx +++ b/sklearn/utils/sparsefuncs_fast.pyx @@ -12,8 +12,11 @@ from libc.math cimport fabs, sqrt cimport numpy as cnp import numpy as np from cython cimport floating +from cython.parallel cimport prange from numpy.math cimport isnan +from sklearn.utils._openmp_helpers import _openmp_effective_n_threads + cnp.import_array() ctypedef fused integral: @@ -27,13 +30,14 @@ def csr_row_norms(X): """Squared L2 norm of each row in CSR matrix X.""" if X.dtype not in [np.float32, np.float64]: X = X.astype(np.float64) - return _csr_row_norms(X.data, X.indices, X.indptr) + n_threads = _openmp_effective_n_threads() + return _sqeuclidean_row_norms_sparse(X.data, X.indptr, n_threads) -def _csr_row_norms( +def _sqeuclidean_row_norms_sparse( const floating[::1] X_data, - const integral[::1] X_indices, const integral[::1] X_indptr, + int n_threads, ): cdef: integral n_samples = X_indptr.shape[0] - 1 @@ -42,14 +46,13 @@ def _csr_row_norms( dtype = np.float32 if floating is float else np.float64 - cdef floating[::1] norms = np.zeros(n_samples, dtype=dtype) + cdef floating[::1] squared_row_norms = np.zeros(n_samples, dtype=dtype) - with nogil: - for i in range(n_samples): - for j in range(X_indptr[i], X_indptr[i + 1]): - norms[i] += X_data[j] * X_data[j] + for i in prange(n_samples, schedule='static', nogil=True, num_threads=n_threads): + for j in range(X_indptr[i], X_indptr[i + 1]): + squared_row_norms[i] += X_data[j] * X_data[j] - return np.asarray(norms) + return np.asarray(squared_row_norms) def csr_mean_variance_axis0(X, weights=None, return_sum_weights=False):