Skip to content

FEA Introduce PairwiseDistances, a generic back-end for pairwise_distances #25561

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
e4bf4e7
introduce PairwiseDistance
Vincent-Maladiere Feb 7, 2023
ddc464b
temporarily swap 'sqeuclidean' with 'euclidean' to fix tests
Vincent-Maladiere Feb 7, 2023
2a8b47a
Add nogil for benchmark purposes
Vincent-Maladiere Feb 8, 2023
fd4117b
update is_usable_for and docstrings
Vincent-Maladiere Feb 11, 2023
cfa7d8c
update whats_new
Vincent-Maladiere Feb 11, 2023
dd60dd5
Merge branch 'feature/PairwiseDistances' into feat/pairwise_distances…
Vincent-Maladiere Feb 11, 2023
5d9dcdb
remove chunksize and extend tests
Vincent-Maladiere Feb 14, 2023
8787747
Apply suggestions from code review
Vincent-Maladiere Feb 14, 2023
3fd55b9
add test_pairwise_distances_is_usable_for
Vincent-Maladiere Feb 14, 2023
e8dfef5
fix monkeypatch test and extend is_usable_for to single threaded manh…
Vincent-Maladiere Feb 14, 2023
77952d8
Merge branch 'main' into feat/pairwise_distances-pdr-backend
jjerphan Mar 8, 2023
9eed73f
DOC Add comments
jjerphan Mar 8, 2023
502725d
Convert sparse matrices to be CSR
jjerphan Mar 8, 2023
face5a9
Correct sqeucliden adaptation
jjerphan Mar 8, 2023
cdd2567
Correct condition for PairwiseDistances.is_usable_for
jjerphan Mar 9, 2023
32b4dad
TST Adapt test_pairwise_distances_is_usable_for
jjerphan Mar 9, 2023
9d6cd41
Use threadpool_limits over ThreadpoolController
jjerphan Mar 9, 2023
304d7d8
TST Increase atol for test_euclidean_distances_extreme_values
jjerphan Mar 9, 2023
99de991
DOC Add docstring for PairwiseDistances.is_usable_for
jjerphan Mar 9, 2023
40e4ff3
Merge branch 'main' into feat/pairwise_distances-pdr-backend
Vincent-Maladiere Jun 21, 2023
cb518ab
finalize merging with main by removing deprecated DTYPE and ITYPE
Vincent-Maladiere Jun 21, 2023
1754d18
Merge branch 'feature/PairwiseDistances' into feat/pairwise_distances…
Vincent-Maladiere Jun 21, 2023
8e94b7f
DOC Update link to best sphinx version for doc build (#26626)
lucyleeow Jun 21, 2023
53bdbe5
MNT add isort to ruff's rules (#26649)
adrinjalali Jun 21, 2023
e04160c
FIX use None as default in HDBSCAN (#26650)
glemaitre Jun 21, 2023
54600aa
Merge branch 'feature/PairwiseDistances' into feat/pairwise_distances…
Vincent-Maladiere Jun 22, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd
sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx
sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pxd
sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx
sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pxd
sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx
sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pxd
sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx

Expand Down
10 changes: 10 additions & 0 deletions doc/whats_new/v1.3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,16 @@ Changelog
- |Fix| :func:`metrics.manhattan_distances` now supports readonly sparse datasets.
:pr:`25432` by :user:`Julien Jerphanion <jjerphan>`.

- |Efficiency| :func:`pairwise.pairwise_distances`' performance has been improved
when providing dense datasets.
:pr:`25561` by :user:`Vincent Maladiere <Vincent-Maladiere>` and
:user:`Julien Jerphanion <jjerphan>`.

- |Feature| :func:`pairwise.pairwise_distances` now supports combination of
dense arrays and sparse CSR matrices datasets.
:pr:`25561` by :user:`Vincent Maladiere <Vincent-Maladiere>` and
:user:`Julien Jerphanion <jjerphan>`.

- |Fix| Fixed :func:`metrics.classification_report` so that empty input will return
`np.nan`. Previously, "macro avg" and `weighted avg` would return
e.g. `f1-score=np.nan` and `f1-score=0.0`, being inconsistent. Now, they
Expand Down
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ ignore =
sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx
sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pxd
sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx
sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pxd
sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx
sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pxd
sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx

Expand Down
6 changes: 6 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,12 @@ def check_package_status(package, min_version):
"include_np": True,
"extra_compile_args": ["-std=c++11"],
},
{
"sources": ["_pairwise_distances.pyx.tp", "_pairwise_distances.pxd.tp"],
"language": "c++",
"include_np": True,
"extra_compile_args": ["-std=c++11"],
},
{
"sources": ["_argkmin.pyx.tp", "_argkmin.pxd.tp"],
"language": "c++",
Expand Down
4 changes: 4 additions & 0 deletions sklearn/metrics/_pairwise_distances_reduction/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,18 @@
ArgKmin,
ArgKminClassMode,
BaseDistancesReductionDispatcher,
PairwiseDistances,
RadiusNeighbors,
sqeuclidean_row_norms,
)
from ._pairwise_distances import _precompute_metric_params

__all__ = [
"BaseDistancesReductionDispatcher",
"ArgKmin",
"PairwiseDistances",
"RadiusNeighbors",
"ArgKminClassMode",
"sqeuclidean_row_norms",
"_precompute_metric_params",
]
1 change: 1 addition & 0 deletions sklearn/metrics/_pairwise_distances_reduction/_base.pxd.tp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ cdef class BaseDistancesReduction{{name_suffix}}:
intp_t n_samples_X, X_n_samples_chunk, X_n_chunks, X_n_samples_last_chunk
intp_t n_samples_Y, Y_n_samples_chunk, Y_n_chunks, Y_n_samples_last_chunk

bint X_is_Y
bint execute_in_parallel_on_Y

@final
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ cdef class DatasetsPair{{name_suffix}}:
{{DistanceMetric}} distance_metric
intp_t n_features

readonly bint X_is_Y

cdef intp_t n_samples_X(self) noexcept nogil

cdef intp_t n_samples_Y(self) noexcept nogil
Expand Down
44 changes: 27 additions & 17 deletions sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import copy

{{py:

implementation_specific_values = [
Expand Down Expand Up @@ -84,19 +86,25 @@ cdef class DatasetsPair{{name_suffix}}:
datasets_pair: DatasetsPair{{name_suffix}}
The suited DatasetsPair{{name_suffix}} implementation.
"""
# Y_norm_squared might be propagated down to DatasetsPairs
# via metrics_kwargs when the Euclidean specialisations
# can't be used. To prevent Y_norm_squared to be passed
# X_norm_squared and Y_norm_squared might be propagated
# down to DatasetsPairs via metrics_kwargs when the Euclidean
# specialisations can't be used.
# To prevent X_norm_squared and Y_norm_squared to be passed
# down to DistanceMetrics (whose constructors would raise
# a RuntimeError), we pop it here.
# a RuntimeError), we pop them here.
if metric_kwargs is not None:
# Copying metric_kwargs not to pop "X_norm_squared"
# and "Y_norm_squared" where they are used
metric_kwargs = copy.copy(metric_kwargs)
metric_kwargs.pop("X_norm_squared", None)
metric_kwargs.pop("Y_norm_squared", None)
cdef:
{{DistanceMetric}} distance_metric = DistanceMetric.get_metric(
metric,
{{INPUT_DTYPE}},
**(metric_kwargs or {})
)
bint X_is_Y = X is Y

# Metric-specific checks that do not replace nor duplicate `check_array`.
distance_metric._validate_data(X)
Expand All @@ -106,15 +114,15 @@ cdef class DatasetsPair{{name_suffix}}:
Y_is_sparse = issparse(Y)

if not X_is_sparse and not Y_is_sparse:
return DenseDenseDatasetsPair{{name_suffix}}(X, Y, distance_metric)
return DenseDenseDatasetsPair{{name_suffix}}(X, Y, distance_metric, X_is_Y)

if X_is_sparse and Y_is_sparse:
return SparseSparseDatasetsPair{{name_suffix}}(X, Y, distance_metric)
return SparseSparseDatasetsPair{{name_suffix}}(X, Y, distance_metric, X_is_Y)

if X_is_sparse and not Y_is_sparse:
return SparseDenseDatasetsPair{{name_suffix}}(X, Y, distance_metric)
return SparseDenseDatasetsPair{{name_suffix}}(X, Y, distance_metric, X_is_Y)

return DenseSparseDatasetsPair{{name_suffix}}(X, Y, distance_metric)
return DenseSparseDatasetsPair{{name_suffix}}(X, Y, distance_metric, X_is_Y)

@classmethod
def unpack_csr_matrix(cls, X: csr_matrix):
Expand All @@ -124,8 +132,9 @@ cdef class DatasetsPair{{name_suffix}}:
X_indptr = np.asarray(X.indptr, dtype=np.int32)
return X_data, X_indices, X_indptr

def __init__(self, {{DistanceMetric}} distance_metric, intp_t n_features):
def __init__(self, {{DistanceMetric}} distance_metric, intp_t n_features, bint X_is_Y):
self.distance_metric = distance_metric
self.X_is_Y = X_is_Y
self.n_features = n_features

cdef intp_t n_samples_X(self) noexcept nogil:
Expand Down Expand Up @@ -173,8 +182,9 @@ cdef class DenseDenseDatasetsPair{{name_suffix}}(DatasetsPair{{name_suffix}}):
const {{INPUT_DTYPE_t}}[:, ::1] X,
const {{INPUT_DTYPE_t}}[:, ::1] Y,
{{DistanceMetric}} distance_metric,
bint X_is_Y,
):
super().__init__(distance_metric, n_features=X.shape[1])
super().__init__(distance_metric, n_features=X.shape[1], X_is_Y=X_is_Y)
# Arrays have already been checked
self.X = X
self.Y = Y
Expand Down Expand Up @@ -213,8 +223,8 @@ cdef class SparseSparseDatasetsPair{{name_suffix}}(DatasetsPair{{name_suffix}}):
between two vectors of (X, Y).
"""

def __init__(self, X, Y, {{DistanceMetric}} distance_metric):
super().__init__(distance_metric, n_features=X.shape[1])
def __init__(self, X, Y, {{DistanceMetric}} distance_metric, bint X_is_Y):
super().__init__(distance_metric, n_features=X.shape[1], X_is_Y=X_is_Y)

self.X_data, self.X_indices, self.X_indptr = self.unpack_csr_matrix(X)
self.Y_data, self.Y_indices, self.Y_indptr = self.unpack_csr_matrix(Y)
Expand Down Expand Up @@ -273,8 +283,8 @@ cdef class SparseDenseDatasetsPair{{name_suffix}}(DatasetsPair{{name_suffix}}):
between two vectors of (X, Y).
"""

def __init__(self, X, Y, {{DistanceMetric}} distance_metric):
super().__init__(distance_metric, n_features=X.shape[1])
def __init__(self, X, Y, {{DistanceMetric}} distance_metric, bint X_is_Y):
super().__init__(distance_metric, n_features=X.shape[1], X_is_Y=X_is_Y)

self.X_data, self.X_indices, self.X_indptr = self.unpack_csr_matrix(X)

Expand Down Expand Up @@ -371,10 +381,10 @@ cdef class DenseSparseDatasetsPair{{name_suffix}}(DatasetsPair{{name_suffix}}):
between two vectors of (X, Y).
"""

def __init__(self, X, Y, {{DistanceMetric}} distance_metric):
super().__init__(distance_metric, n_features=X.shape[1])
def __init__(self, X, Y, {{DistanceMetric}} distance_metric, bint X_is_Y):
super().__init__(distance_metric, n_features=X.shape[1], X_is_Y=X_is_Y)
# Swapping arguments on the constructor
self.datasets_pair = SparseDenseDatasetsPair{{name_suffix}}(Y, X, distance_metric)
self.datasets_pair = SparseDenseDatasetsPair{{name_suffix}}(Y, X, distance_metric, X_is_Y)

@final
cdef intp_t n_samples_X(self) noexcept nogil:
Expand Down
171 changes: 171 additions & 0 deletions sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from scipy.sparse import issparse, isspmatrix_csr

from ... import get_config
from ...utils._openmp_helpers import _openmp_effective_n_threads
from .._dist_metrics import BOOL_METRICS, METRIC_MAPPING64
from ._argkmin import (
ArgKmin32,
Expand All @@ -15,6 +16,10 @@
ArgKminClassMode64,
)
from ._base import _sqeuclidean_row_norms32, _sqeuclidean_row_norms64
from ._pairwise_distances import (
PairwiseDistances32,
PairwiseDistances64,
)
from ._radius_neighbors import (
RadiusNeighbors32,
RadiusNeighbors64,
Expand Down Expand Up @@ -148,6 +153,172 @@ def compute(
"""


class PairwiseDistances(BaseDistancesReductionDispatcher):
"""Compute the pairwise distances matrix for two sets of vectors.

The distance function `dist` depends on the values of the `metric`
and `metric_kwargs` parameters.

This class only computes the pairwise distances matrix without
applying any reduction on it. It shares most of the underlying
code infrastructure with reducing variants to leverage multi-thread
parallelism. However contrary to the reducing variants, no chunking
is applied to allow for contiguous write access to the final distance
array that is not expected to fit in the CPU cache in general.

This class is not meant to be instantiated, one should only use
its :meth:`compute` classmethod which handles allocation and
deallocation consistently.
"""

@classmethod
def is_usable_for(cls, X, Y, metric, metric_kwargs=None) -> bool:
"""Return True if the dispatcher can be used for the
given parameters.

Parameters
----------
X : {ndarray, sparse matrix} of shape (n_samples_X, n_features)
Input data.

Y : {ndarray, sparse matrix} of shape (n_samples_Y, n_features)
Input data.

metric : str, default='euclidean'
The distance metric to use.
For a list of available metrics, see the documentation of
:class:`~sklearn.metrics.DistanceMetric`.

metric_kwargs : dict, default=None
Keyword arguments to pass to specified metric function.

Returns
-------
True if the dispatcher can be used, else False.
"""
effective_n_threads = _openmp_effective_n_threads()

def is_euclidean(metric, metric_kwargs):
metric_kwargs = metric_kwargs or dict()
euclidean_metrics = [
"euclidean",
"sqeuclidean",
"l2",
]
# TODO: pass `p` as a standalone argument instead of a metric_kwargs.
return metric in euclidean_metrics or (
metric == "minkowski" and metric_kwargs.get("p", 2) == 2
)

Y = X if Y is None else Y

# We need to rely on `PairwiseDistances` for manhattan anyway because
# the implementation of manhattan distances on sparse data has been removed.
manhattan_metrics = ["cityblock", "l1", "manhattan"]

is_usable = super().is_usable_for(X, Y, metric) and (
(not is_euclidean(metric, metric_kwargs) and effective_n_threads != 1)
or metric in manhattan_metrics
)

return is_usable

@classmethod
def compute(
cls,
X,
Y,
metric="euclidean",
metric_kwargs=None,
strategy=None,
):
"""Return pairwise distances matrix for the given arguments.

Parameters
----------
X : ndarray or CSR matrix of shape (n_samples_X, n_features)
Input data.

Y : ndarray or CSR matrix of shape (n_samples_Y, n_features)
Input data.

metric : str, default='euclidean'
The distance metric to use.
For a list of available metrics, see the documentation of
:class:`~sklearn.metrics.DistanceMetric`.

metric_kwargs : dict, default=None
Keyword arguments to pass to specified metric function.

strategy : str, {'auto', 'parallel_on_X', 'parallel_on_Y'}, default=None
The strategy defining which dataset parallelization are made on.

For both strategies the computations happens with two nested loops,
respectively on rows of X and rows of Y.
Strategies differs on which loop (outer or inner) is made to run
in parallel with the Cython `prange` construct:

- 'parallel_on_X' dispatches rows of X uniformly on threads.
Each thread then iterates on all the rows of Y. This strategy is
embarrassingly parallel and comes with no datastructures
synchronisation.

- 'parallel_on_Y' dispatches rows of Y uniformly on threads.
Each thread processes all the rows of X in turn. This strategy is
a sequence of embarrassingly parallel subtasks (the inner loop on Y
chunks) with no intermediate datastructures synchronisation.

- 'auto' relies on a simple heuristic to choose between
'parallel_on_X' and 'parallel_on_Y': when `X.shape[0]` is large enough,
'parallel_on_X' is usually the most efficient strategy.
When `X.shape[0]` is small but `Y.shape[0]` is large, 'parallel_on_Y'
brings more opportunity for parallelism and is therefore more efficient.

- None (default) looks-up in scikit-learn configuration for
`pairwise_dist_parallel_strategy`, and use 'auto' if it is not set.

Returns
-------
pairwise_distances_matrix : ndarray of shape (n_samples_X, n_samples_Y)
The pairwise distances matrix.

Notes
-----
This public classmethod is responsible for introspecting the arguments
values to dispatch to the private dtype-specialized implementation of
:class:`PairwiseDistances`.

All temporarily allocated datastructures necessary for the concrete
implementations are therefore freed when this classmethod returns.

This allows entirely decoupling the API entirely from the
implementation details whilst maintaining RAII.
"""
Y = X if Y is None else Y
if X.dtype == Y.dtype == np.float64:
return PairwiseDistances64.compute(
X=X,
Y=Y,
metric=metric,
metric_kwargs=metric_kwargs,
strategy=strategy,
)

if X.dtype == Y.dtype == np.float32:
return PairwiseDistances32.compute(
X=X,
Y=Y,
metric=metric,
metric_kwargs=metric_kwargs,
strategy=strategy,
)

raise ValueError(
"Only float64 or float32 datasets pairs are supported, but "
f"got: X.dtype={X.dtype} and Y.dtype={Y.dtype}."
)


class ArgKmin(BaseDistancesReductionDispatcher):
"""Compute the argkmin of row vectors of X on the ones of Y.

Expand Down
Loading