Skip to content

ENH Optimized CSR-CSR support for Euclidean specializations of PairwiseDistancesReductions #24556

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
24fe8a6
wip
Vincent-Maladiere Sep 30, 2022
0132e56
remove potential bug
Vincent-Maladiere Sep 30, 2022
c0ac7db
much fix such wow
Vincent-Maladiere Oct 11, 2022
27b9427
compiling !
Vincent-Maladiere Oct 12, 2022
2a14154
add sparse method for sq_euclidean_norm
Vincent-Maladiere Oct 16, 2022
4060271
remove print
Vincent-Maladiere Oct 16, 2022
c8e35fd
underkill the overkill
Vincent-Maladiere Oct 18, 2022
c9aed9d
baptism by fire
Vincent-Maladiere Oct 18, 2022
12acf71
Merge branch 'main' into euclidean_argkmin_sparse_sparse
Vincent-Maladiere Oct 19, 2022
c5bd316
clean it up, yo
Vincent-Maladiere Oct 19, 2022
4697f4e
Merge branch 'main' into euclidean_argkmin_sparse_sparse
Vincent-Maladiere Oct 19, 2022
50bbf47
worldwide renaming
Vincent-Maladiere Oct 19, 2022
e9cc5ad
add sparse test for test_sqeuclidean_row_norms
Vincent-Maladiere Oct 19, 2022
2ee517d
Remove previous Cython templates and sources
Vincent-Maladiere Oct 23, 2022
9df551d
Apply suggestions
Vincent-Maladiere Oct 23, 2022
76574eb
remove np.asarray from test
Vincent-Maladiere Oct 24, 2022
72ffecb
Merge branch 'main' into euclidean_argkmin_sparse_sparse
jjerphan Oct 26, 2022
07622be
update some doc
Vincent-Maladiere Oct 27, 2022
e0080f2
Merge branch 'main' into euclidean_argkmin_sparse_sparse
Vincent-Maladiere Oct 27, 2022
fb9a3c9
doc improvement
Vincent-Maladiere Oct 28, 2022
88e074d
Update sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp
Vincent-Maladiere Oct 28, 2022
d6bff55
extend test_pairwise_distances_argkmin, fail with dtype=float32 and t…
Vincent-Maladiere Oct 28, 2022
9f96865
branch SparseSparseMiddleTermComputer32 to _middle_term_sparse_sparse_64
Vincent-Maladiere Oct 28, 2022
2df86bc
branch sparse_sqeuclidean_row_norm32 to sparse_sqeuclidean_row_norm64…
Vincent-Maladiere Oct 28, 2022
8aa5383
Merge branch 'main' into euclidean_argkmin_sparse_sparse
Vincent-Maladiere Oct 28, 2022
8f7bfa5
Merge branch 'main' into euclidean_argkmin_sparse_sparse
Vincent-Maladiere Nov 4, 2022
c0675ac
remove unused variables
Vincent-Maladiere Nov 4, 2022
40a6afb
Merge branch 'main' into euclidean_argkmin_sparse_sparse
Vincent-Maladiere Nov 4, 2022
366548e
some suggestions from Maurice
Vincent-Maladiere Nov 4, 2022
8ee5a29
Apply suggestions from code review
Vincent-Maladiere Nov 7, 2022
8d814fa
Update sklearn/metrics/tests/test_pairwise_distances_reduction.py
Vincent-Maladiere Nov 8, 2022
47c6127
Merge branch 'main' into euclidean_argkmin_sparse_sparse
Vincent-Maladiere Nov 8, 2022
4fbb024
Merge branch 'main' into euclidean_argkmin_sparse_sparse
jjerphan Nov 16, 2022
6fa803f
Update sklearn/metrics/_pairwise_distances_reduction/_base.pyx.tp
Vincent-Maladiere Nov 16, 2022
e20ed2c
Merge branch 'main' into euclidean_argkmin_sparse_sparse
Vincent-Maladiere Nov 16, 2022
115813f
Merge branch 'main' into euclidean_argkmin_sparse_sparse
jjerphan Nov 17, 2022
c08a515
apply suggestions :)
Vincent-Maladiere Nov 18, 2022
e94d3dd
Merge branch 'main' into euclidean_argkmin_sparse_sparse
Vincent-Maladiere Nov 18, 2022
001fa5b
Fu... sion!
Vincent-Maladiere Nov 18, 2022
61b462e
Merge branch 'main' into euclidean_argkmin_sparse_sparse
Vincent-Maladiere Nov 18, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion doc/whats_new/v1.2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ Changes impacting all modules
- :func:`sklearn.manifold.trustworthiness`

:pr:`23604` and :pr:`23585` by :user:`Julien Jerphanion <jjerphan>`,
:user:`Olivier Grisel <ogrisel>`, and `Thomas Fan`_.
:user:`Olivier Grisel <ogrisel>`, and `Thomas Fan`_,
:pr:`24556` by :user:`Vincent Maladière <Vincent-Maladiere>`.

- |Fix| Systematically check the sha256 digest of dataset tarballs used in code
examples in the documentation.
Expand Down
6 changes: 3 additions & 3 deletions sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,10 @@ cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}):
"""
if (
metric in ("euclidean", "sqeuclidean")
and not (issparse(X) or issparse(Y))
and not (issparse(X) ^ issparse(Y)) # "^" is the XOR operator
):
# Specialized implementation of ArgKmin for the Euclidean distance.
# Specialized implementation of ArgKmin for the Euclidean distance
# for the dense-dense and sparse-sparse cases.
# This implementation computes the distances by chunk using
# a decomposition of the Squared Euclidean distance.
# This specialisation has an improved arithmetic intensity for both
Expand Down Expand Up @@ -492,7 +493,6 @@ cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}):
DTYPE_t * heaps_r_distances = self.heaps_r_distances_chunks[thread_num]
ITYPE_t * heaps_indices = self.heaps_indices_chunks[thread_num]


# Pushing the distances and their associated indices onto heaps
# which keep track of the argkmin.
for i in range(n_X):
Expand Down
4 changes: 2 additions & 2 deletions sklearn/metrics/_pairwise_distances_reduction/_base.pxd.tp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ cimport numpy as cnp

from cython cimport final

from ...utils._typedefs cimport ITYPE_t, DTYPE_t
from ...utils._typedefs cimport ITYPE_t, DTYPE_t, SPARSE_INDEX_TYPE_t

cnp.import_array()

Expand All @@ -12,7 +12,7 @@ from ._datasets_pair cimport DatasetsPair{{name_suffix}}


cpdef DTYPE_t[::1] _sqeuclidean_row_norms{{name_suffix}}(
const {{INPUT_DTYPE_t}}[:, ::1] X,
X,
ITYPE_t num_threads,
)

Expand Down
33 changes: 29 additions & 4 deletions sklearn/metrics/_pairwise_distances_reduction/_base.pyx.tp
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,12 @@ from ...utils._typedefs cimport ITYPE_t, DTYPE_t

import numpy as np

from scipy.sparse import issparse
from numbers import Integral
from sklearn import get_config
from sklearn.utils import check_scalar
from ...utils._openmp_helpers import _openmp_effective_n_threads
from ...utils._typedefs import DTYPE
from ...utils._typedefs import DTYPE, SPARSE_INDEX_TYPE

cnp.import_array()

Expand Down Expand Up @@ -102,16 +103,40 @@ cdef DTYPE_t[::1] _sqeuclidean_row_norms32_dense(
return squared_row_norms


cdef DTYPE_t[::1] _sqeuclidean_row_norms64_sparse(
    const DTYPE_t[:] X_data,
    const SPARSE_INDEX_TYPE_t[:] X_indptr,
    ITYPE_t num_threads,
):
    # Compute the squared Euclidean norm of each row of a sparse matrix
    # described by its `data` and `indptr` arrays.
    #
    # Parameters
    # ----------
    # X_data : stored (non-zero) values of the matrix.
    # X_indptr : row pointers; the values of row `i` are
    #     `X_data[X_indptr[i]:X_indptr[i + 1]]`.
    # num_threads : number of threads handed to `prange`.
    #
    # Returns
    # -------
    # A contiguous memoryview of length `X_indptr.shape[0] - 1` holding the
    # sum of squared stored values of each row. Rows without stored values
    # keep the initial value 0.
    #
    # The column indices are not needed here: the norm of a row only depends
    # on the values stored for that row.
    cdef:
        # Number of rows: `indptr` has one more entry than there are rows.
        ITYPE_t n = X_indptr.shape[0] - 1
        SPARSE_INDEX_TYPE_t X_i_ptr, idx = 0
        # Zero-initialised so that the in-place accumulation below is valid.
        DTYPE_t[::1] squared_row_norms = np.zeros(n, dtype=DTYPE)

    # Rows are distributed over threads; each iteration only writes to
    # `squared_row_norms[idx]`, its own output slot.
    for idx in prange(n, schedule='static', nogil=True, num_threads=num_threads):
        # Accumulate the squares of the values stored for row `idx`.
        for X_i_ptr in range(X_indptr[idx], X_indptr[idx+1]):
            squared_row_norms[idx] += X_data[X_i_ptr] * X_data[X_i_ptr]

    return squared_row_norms


{{for name_suffix, INPUT_DTYPE_t, INPUT_DTYPE in implementation_specific_values}}

from ._datasets_pair cimport DatasetsPair{{name_suffix}}


cpdef DTYPE_t[::1] _sqeuclidean_row_norms{{name_suffix}}(
const {{INPUT_DTYPE_t}}[:, ::1] X,
X,
ITYPE_t num_threads,
):
return _sqeuclidean_row_norms{{name_suffix}}_dense(X, num_threads)
if issparse(X):
# TODO: remove this conversion, which is a cast in the float32 case,
# by moving the squared row norms computation into MiddleTermComputer.
X_data = np.asarray(X.data, dtype=DTYPE)
X_indptr = np.asarray(X.indptr, dtype=SPARSE_INDEX_TYPE)
return _sqeuclidean_row_norms64_sparse(X_data, X_indptr, num_threads)
else:
return _sqeuclidean_row_norms{{name_suffix}}_dense(X, num_threads)


cdef class BaseDistancesReduction{{name_suffix}}:
Expand All @@ -131,7 +156,7 @@ cdef class BaseDistancesReduction{{name_suffix}}:
strategy=None,
):
cdef:
ITYPE_t n_samples_chunk, X_n_full_chunks, Y_n_full_chunks
ITYPE_t X_n_full_chunks, Y_n_full_chunks

if chunk_size is None:
chunk_size = get_config().get("pairwise_dist_chunk_size", 256)
Expand Down
9 changes: 4 additions & 5 deletions sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,7 @@

from .._dist_metrics import BOOL_METRICS, METRIC_MAPPING

from ._base import (
_sqeuclidean_row_norms64,
_sqeuclidean_row_norms32,
)
from ._base import _sqeuclidean_row_norms32, _sqeuclidean_row_norms64
from ._argkmin import (
ArgKmin64,
ArgKmin32,
Expand Down Expand Up @@ -133,8 +130,10 @@ def is_valid_sparse_matrix(X):
# See: https://github.com/scikit-learn/scikit-learn/pull/23585#issuecomment-1247996669 # noqa
# TODO: implement specialisation for (sq)euclidean on fused sparse-dense
# using sparse-dense routines for matrix-vector multiplications.
# Currently, only dense-dense and sparse-sparse are optimized for
# the Euclidean case.
fused_sparse_dense_euclidean_case_guard = not (
(is_valid_sparse_matrix(X) or is_valid_sparse_matrix(Y))
(is_valid_sparse_matrix(X) ^ is_valid_sparse_matrix(Y)) # "^" is XOR
and isinstance(metric, str)
and "euclidean" in metric
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,22 @@ cimport numpy as cnp

from libcpp.vector cimport vector

from ...utils._typedefs cimport DTYPE_t, ITYPE_t
from ...utils._typedefs cimport DTYPE_t, ITYPE_t, SPARSE_INDEX_TYPE_t


# Declaration of the sparse-sparse middle-term routine; the implementation
# is not part of this declaration file.
#
# NOTE(review): judging by the parameter names only, `X_data` / `X_indices` /
# `X_indptr` (and the `Y_*` counterparts) look like the three arrays of a
# CSR matrix, `[X_start, X_end)` and `[Y_start, Y_end)` look like row ranges
# delimiting the chunk to process, and `D` looks like the output buffer the
# result is written to -- confirm against the implementation.
#
# The function is declared `nogil`, so it can be called without holding
# the GIL.
cdef void _middle_term_sparse_sparse_64(
    const DTYPE_t[:] X_data,
    const SPARSE_INDEX_TYPE_t[:] X_indices,
    const SPARSE_INDEX_TYPE_t[:] X_indptr,
    ITYPE_t X_start,
    ITYPE_t X_end,
    const DTYPE_t[:] Y_data,
    const SPARSE_INDEX_TYPE_t[:] Y_indices,
    const SPARSE_INDEX_TYPE_t[:] Y_indptr,
    ITYPE_t Y_start,
    ITYPE_t Y_end,
    DTYPE_t * D,
) nogil


{{for name_suffix, upcast_to_float64, INPUT_DTYPE_t, INPUT_DTYPE in implementation_specific_values}}
Expand Down Expand Up @@ -133,4 +148,42 @@ cdef class DenseDenseMiddleTermComputer{{name_suffix}}(MiddleTermComputer{{name_
) nogil


# Declaration of the MiddleTermComputer subclass used when both X and Y are
# sparse. Only attributes and method signatures are declared here; the
# method bodies are not part of this declaration file.
cdef class SparseSparseMiddleTermComputer{{name_suffix}}(MiddleTermComputer{{name_suffix}}):
    cdef:
        # Read-only views describing X.
        # NOTE(review): presumably the `data`, `indices` and `indptr` arrays
        # of a CSR matrix -- confirm against the constructor.
        # The values are typed DTYPE_t for every `name_suffix` instantiation
        # of this template (not INPUT_DTYPE_t).
        const DTYPE_t[:] X_data
        const SPARSE_INDEX_TYPE_t[:] X_indices
        const SPARSE_INDEX_TYPE_t[:] X_indptr

        # Same three views for Y.
        const DTYPE_t[:] Y_data
        const SPARSE_INDEX_TYPE_t[:] Y_indices
        const SPARSE_INDEX_TYPE_t[:] Y_indptr

    # Hook taking the bounds of an X chunk and a Y chunk plus a thread number.
    # NOTE(review): by its name, run before distances are computed and
    # reduced on a chunk when parallelising on X -- confirm.
    cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks(
        self,
        ITYPE_t X_start,
        ITYPE_t X_end,
        ITYPE_t Y_start,
        ITYPE_t Y_end,
        ITYPE_t thread_num
    ) nogil

    # Counterpart of the hook above for the strategy parallelising on Y
    # (same signature).
    cdef void _parallel_on_Y_pre_compute_and_reduce_distances_on_chunks(
        self,
        ITYPE_t X_start,
        ITYPE_t X_end,
        ITYPE_t Y_start,
        ITYPE_t Y_end,
        ITYPE_t thread_num,
    ) nogil

    # Returns a raw pointer to DTYPE_t values for the given X and Y chunk
    # bounds.
    # NOTE(review): presumably the buffer of middle terms for the chunk pair,
    # owned by the thread `thread_num` -- confirm layout and ownership
    # against the implementation.
    cdef DTYPE_t * _compute_dist_middle_terms(
        self,
        ITYPE_t X_start,
        ITYPE_t X_end,
        ITYPE_t Y_start,
        ITYPE_t Y_end,
        ITYPE_t thread_num,
    ) nogil


{{endfor}}
Loading