Skip to content

Commit f965fcc

Browse files
ArturoAmorQ, jjerphan, glemaitre, jeremiedbb
authored
ENH csr_row_norms optimization (#24426)
Co-authored-by: Julien Jerphanion <git@jjerphan.xyz>
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
Co-authored-by: Jérémie du Boisberranger <jeremiedbb@users.noreply.github.com>
1 parent 41a960c commit f965fcc

File tree

1 file changed: +17 −17 lines changed

1 file changed: +17 −17 lines changed

sklearn/utils/sparsefuncs_fast.pyx

Lines changed: 17 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -24,32 +24,32 @@ ctypedef cnp.float64_t DOUBLE
2424

2525

2626
def csr_row_norms(X):
27-
"""L2 norm of each row in CSR matrix X."""
27+
"""Squared L2 norm of each row in CSR matrix X."""
2828
if X.dtype not in [np.float32, np.float64]:
2929
X = X.astype(np.float64)
30-
return _csr_row_norms(X.data, X.shape, X.indices, X.indptr)
30+
return _csr_row_norms(X.data, X.indices, X.indptr)
3131

3232

33-
def _csr_row_norms(cnp.ndarray[floating, ndim=1, mode="c"] X_data,
34-
shape,
35-
cnp.ndarray[integral, ndim=1, mode="c"] X_indices,
36-
cnp.ndarray[integral, ndim=1, mode="c"] X_indptr):
33+
def _csr_row_norms(
34+
const floating[::1] X_data,
35+
const integral[::1] X_indices,
36+
const integral[::1] X_indptr,
37+
):
3738
cdef:
38-
unsigned long long n_samples = shape[0]
39-
unsigned long long i
40-
integral j
39+
integral n_samples = X_indptr.shape[0] - 1
40+
integral i, j
4141
double sum_
4242

43-
norms = np.empty(n_samples, dtype=X_data.dtype)
44-
cdef floating[::1] norms_view = norms
43+
dtype = np.float32 if floating is float else np.float64
4544

46-
for i in range(n_samples):
47-
sum_ = 0.0
48-
for j in range(X_indptr[i], X_indptr[i + 1]):
49-
sum_ += X_data[j] * X_data[j]
50-
norms_view[i] = sum_
45+
cdef floating[::1] norms = np.zeros(n_samples, dtype=dtype)
46+
47+
with nogil:
48+
for i in range(n_samples):
49+
for j in range(X_indptr[i], X_indptr[i + 1]):
50+
norms[i] += X_data[j] * X_data[j]
5151

52-
return norms
52+
return np.asarray(norms)
5353

5454

5555
def csr_mean_variance_axis0(X, weights=None, return_sum_weights=False):

0 commit comments

Comments
 (0)