Skip to content

[MRG] FIX row_norms return dtype same as input #12423

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions sklearn/utils/sparsefuncs_fast.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -25,35 +25,36 @@ ctypedef fused integral:

ctypedef np.float64_t DOUBLE


def csr_row_norms(X):
"""L2 norm of each row in CSR matrix X."""
if X.dtype not in [np.float32, np.float64]:
X = X.astype(np.float64)
return _csr_row_norms(X.data, X.shape, X.indices, X.indptr)

norms = np.zeros(X.shape[0], dtype=X.data.dtype)
_csr_row_norms(X.data, X.shape, X.indices, X.indptr, norms)

return norms


def _csr_row_norms(np.ndarray[floating, ndim=1, mode="c"] X_data,
shape,
np.ndarray[integral, ndim=1, mode="c"] X_indices,
np.ndarray[integral, ndim=1, mode="c"] X_indptr):
np.ndarray[integral, ndim=1, mode="c"] X_indptr,
floating[::1] norms):
cdef:
unsigned long long n_samples = shape[0]
unsigned long long n_features = shape[1]
np.ndarray[DOUBLE, ndim=1, mode="c"] norms

np.npy_intp i, j

unsigned long long i
integral j
double sum_

norms = np.zeros(n_samples, dtype=np.float64)

for i in range(n_samples):
sum_ = 0.0
for j in range(X_indptr[i], X_indptr[i + 1]):
sum_ += X_data[j] * X_data[j]
norms[i] = sum_

return norms


def csr_mean_variance_axis0(X):
"""Compute mean and variance along axis 0 on a CSR matrix
Expand Down
14 changes: 13 additions & 1 deletion sklearn/utils/tests/test_sparsefuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
count_nonzero, csc_median_axis_0)
from sklearn.utils.sparsefuncs_fast import (assign_rows_csr,
inplace_csr_row_normalize_l1,
inplace_csr_row_normalize_l2)
inplace_csr_row_normalize_l2,
csr_row_norms)
from sklearn.utils.testing import assert_raises
from sklearn.utils.testing import assert_allclose

Expand Down Expand Up @@ -512,3 +513,14 @@ def test_inplace_normalize():
if inplace_csr_row_normalize is inplace_csr_row_normalize_l2:
X_csr.data **= 2
assert_array_almost_equal(np.abs(X_csr).sum(axis=1), ones)


@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_csr_row_norms(dtype):
# checks that csr_row_norms returns the same output as
# scipy.sparse.linalg.norm, and that the dype is the same as X.
X = sp.random(100, 10, format='csr', dtype=dtype)
scipy_norms = sp.linalg.norm(X, axis=1)**2
norms = csr_row_norms(X)
assert norms.dtype.type is dtype
assert_array_almost_equal(norms, scipy_norms)