Skip to content

MAINT Improve the _middle_term_sparse_sparse_{32, 64} routines #25449

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,6 @@ import numpy as np
from scipy.sparse import issparse, csr_matrix
from ...utils._typedefs import DTYPE, SPARSE_INDEX_TYPE

# TODO: If possible optimize this routine to efficiently treat cases where
# `n_samples_X << n_samples_Y` met in practise when X_test consists of a
# few samples, and thus when there's a single chunk of X whose number of
# samples is less than the default chunk size.

# TODO: compare this routine with the similar ones in SciPy, especially
# `csr_matmat` which might implement a better algorithm.
# See: https://github.com/scipy/scipy/blob/e58292e066ba2cb2f3d1e0563ca9314ff1f4f311/scipy/sparse/sparsetools/csr.h#L603-L669 # noqa
cdef void _middle_term_sparse_sparse_64(
const DTYPE_t[:] X_data,
const SPARSE_INDEX_TYPE_t[:] X_indices,
Expand All @@ -66,17 +58,17 @@ cdef void _middle_term_sparse_sparse_64(
ITYPE_t i, j, k
ITYPE_t n_X = X_end - X_start
ITYPE_t n_Y = Y_end - Y_start
ITYPE_t X_i_col_idx, X_i_ptr, Y_j_col_idx, Y_j_ptr
ITYPE_t x_col, x_ptr, y_col, y_ptr

for i in range(n_X):
for X_i_ptr in range(X_indptr[X_start+i], X_indptr[X_start+i+1]):
X_i_col_idx = X_indices[X_i_ptr]
for x_ptr in range(X_indptr[X_start+i], X_indptr[X_start+i+1]):
x_col = X_indices[x_ptr]
for j in range(n_Y):
k = i * n_Y + j
for Y_j_ptr in range(Y_indptr[Y_start+j], Y_indptr[Y_start+j+1]):
Y_j_col_idx = Y_indices[Y_j_ptr]
if X_i_col_idx == Y_j_col_idx:
D[k] += -2 * X_data[X_i_ptr] * Y_data[Y_j_ptr]
for y_ptr in range(Y_indptr[Y_start+j], Y_indptr[Y_start+j+1]):
y_col = Y_indices[y_ptr]
if x_col == y_col:
D[k] += -2 * X_data[x_ptr] * Y_data[y_ptr]


{{for name_suffix, upcast_to_float64, INPUT_DTYPE_t, INPUT_DTYPE in implementation_specific_values}}
Expand Down