FEA Introduce PairwiseDistances #23958

Closed
51 commits
0383b6d
Introduce PairwiseDistances
jjerphan Jun 22, 2022
e1bf6c9
Merge branch 'main' into maint/pairwise_distances-pdr-backend
jjerphan Jul 2, 2022
5889e6f
WIP
jjerphan Jul 4, 2022
82347ff
Merge branch 'main' into maint/pairwise_distances-pdr-backend
jjerphan Jul 11, 2022
9101daf
fixup! Introduce PairwiseDistances
jjerphan Jul 11, 2022
dff1aa2
Post-merge fix
jjerphan Jul 11, 2022
78f066c
Merge branch 'main' into feat/pairwise_distances-pdr-backend
jjerphan Jul 19, 2022
44d9453
Merge branch 'main' into feat/pairwise_distances-pdr-backend
jjerphan Jul 29, 2022
26f53a8
Merge branch 'main' into feat/pairwise_distances-pdr-backend
jjerphan Aug 1, 2022
726858b
Merge branch 'main' into feat/pairwise_distances-pdr-backend
jjerphan Aug 26, 2022
f573a59
fixup! Merge branch 'main' into feat/pairwise_distances-pdr-backend
jjerphan Aug 26, 2022
24fa994
Merge branch 'main' into feat/pairwise_distances-pdr-backend
jjerphan Sep 2, 2022
8e17871
Do not offset by X_start and Y_start
jjerphan Sep 2, 2022
1f1a3ce
Do not progated metric_kwargs unneedlessly
jjerphan Sep 2, 2022
800e6b3
Merge branch 'main' into feat/pairwise_distances-pdr-backend
jjerphan Sep 18, 2022
045a7b2
Use the proper vectors' indices
jjerphan Sep 19, 2022
2b70db8
Port PairwiseDistances to support float32 datasets
jjerphan Sep 19, 2022
f9431cc
Simplify indices
jjerphan Sep 19, 2022
24cfd04
Adapt implementations for a previous 'sqeuclidean' specification
jjerphan Sep 19, 2022
527a43c
TST Downcast distance matrix to float32
jjerphan Sep 19, 2022
d26fafb
Adapt instanciation
jjerphan Sep 19, 2022
81b74e1
Use PairwiseDistances as a back-end for haversine_distances
jjerphan Sep 19, 2022
0aa688e
Use PairwiseDistances as a back-end for manhattan_distances
jjerphan Sep 19, 2022
0006764
Use PairwiseDistances as a back-end for euclidean_distances
jjerphan Sep 19, 2022
05b5ae7
Remove duplicated line
jjerphan Sep 19, 2022
709cdf2
Merge branch 'scikit-learn:main' into feat/pairwise_distances-pdr-bac…
jjerphan Sep 19, 2022
53d6c07
Merge branch 'main' into feat/pairwise_distances-pdr-backend
jjerphan Sep 21, 2022
5442ddd
TST Remove checks on errors now that minkowski with sparse data is su…
jjerphan Sep 21, 2022
7087ad1
TST PairwiseDistances factory methods
jjerphan Sep 22, 2022
34aad28
Merge branch 'main' into feat/pairwise_distances-pdr-backend
jjerphan Sep 23, 2022
33560b2
DOC Remove comment regarding X_is_Y
jjerphan Sep 27, 2022
cfc145a
Preserve dtype for PairwiseDistances
jjerphan Sep 27, 2022
54ebaeb
Merge branch 'main' into feat/pairwise_distances-pdr-backend
jjerphan Sep 27, 2022
8604ec6
TST Remove TODO now that dtypes are preserved
jjerphan Sep 27, 2022
62a751a
DOC Remove and adapt some comments
jjerphan Sep 27, 2022
259edc1
MAINT Remove _sparse_manhattan
jjerphan Sep 27, 2022
50f032e
Use PairwiseDistance as a back-end for _euclidean_distances
jjerphan Sep 27, 2022
40d0a0b
MAINT Keep classmethods at the top
jjerphan Sep 27, 2022
9017aa7
Safely pack {X,Y}_squared_norms
jjerphan Sep 28, 2022
dd136a7
TST Adapt test_euclidean_distances_extreme_values
jjerphan Sep 28, 2022
91b3205
DOC Improve docstrings
jjerphan Sep 29, 2022
0abe560
Simplify manhattan_distances
jjerphan Sep 29, 2022
b5003f9
fixup! DOC Improve docstrings
jjerphan Sep 29, 2022
f38b010
fixup! TST Adapt test_euclidean_distances_extreme_values
jjerphan Sep 29, 2022
f2d8cbe
Apply review comments
jjerphan Oct 7, 2022
d0196f1
Merge branch 'main' into feat/pairwise_distances-pdr-backend
jjerphan Oct 7, 2022
3673479
Rework poping {X,Y}_norm_squared in DatasetsPair.get_for
jjerphan Oct 7, 2022
0c74856
Rework poping {X,Y}_norm_squared in EuclideanPairwiseDistances
jjerphan Oct 7, 2022
4fccb5f
DOC Add references to #24745
jjerphan Oct 24, 2022
e58de02
Merge branch 'main' into feat/pairwise_distances-pdr-backend
jjerphan Oct 28, 2022
eabe44d
Merge remote-tracking branch 'upstream/main' into feat/pairwise_dista…
jjerphan Nov 15, 2022
2 changes: 2 additions & 0 deletions .gitignore
Expand Up @@ -95,5 +95,7 @@ sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd
sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx
sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pxd
sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx
sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pxd
sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx
sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pxd
sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx
2 changes: 2 additions & 0 deletions setup.cfg
Expand Up @@ -85,6 +85,8 @@ ignore =
sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx
sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pxd
sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx
sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pxd
sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx
Comment on lines +88 to +89
Member Author

@jjerphan jjerphan Sep 19, 2022

Side-note: it would be nice not to have duplication with identical lines in .gitignore, yet I am not sure this is doable.

sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pxd
sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx

Expand Down
7 changes: 7 additions & 0 deletions setup.py
Expand Up @@ -90,6 +90,7 @@
"sklearn.metrics._pairwise_distances_reduction._datasets_pair",
"sklearn.metrics._pairwise_distances_reduction._middle_term_computer",
"sklearn.metrics._pairwise_distances_reduction._base",
"sklearn.metrics._pairwise_distances_reduction._pairwise_distances",
"sklearn.metrics._pairwise_distances_reduction._argkmin",
"sklearn.metrics._pairwise_distances_reduction._radius_neighbors",
"sklearn.metrics._pairwise_fast",
Expand Down Expand Up @@ -327,6 +328,12 @@ def check_package_status(package, min_version):
"include_np": True,
"extra_compile_args": ["-std=c++11"],
},
{
"sources": ["_pairwise_distances.pyx.tp", "_pairwise_distances.pxd.tp"],
"language": "c++",
"include_np": True,
"extra_compile_args": ["-std=c++11"],
},
{
"sources": ["_argkmin.pyx.tp", "_argkmin.pxd.tp"],
"language": "c++",
Expand Down
5 changes: 5 additions & 0 deletions sklearn/metrics/_pairwise_distances_reduction/__init__.py
Expand Up @@ -89,13 +89,18 @@
from ._dispatcher import (
BaseDistancesReductionDispatcher,
ArgKmin,
PairwiseDistances,
RadiusNeighbors,
sqeuclidean_row_norms,
)

from ._pairwise_distances import _precompute_metric_params

__all__ = [
"BaseDistancesReductionDispatcher",
"ArgKmin",
"PairwiseDistances",
"RadiusNeighbors",
"sqeuclidean_row_norms",
"_precompute_metric_params",
]
1 change: 1 addition & 0 deletions sklearn/metrics/_pairwise_distances_reduction/_base.pxd.tp
Expand Up @@ -50,6 +50,7 @@ cdef class BaseDistancesReduction{{name_suffix}}:
ITYPE_t n_samples_X, X_n_samples_chunk, X_n_chunks, X_n_samples_last_chunk
ITYPE_t n_samples_Y, Y_n_samples_chunk, Y_n_chunks, Y_n_samples_last_chunk

bint X_is_Y
bint execute_in_parallel_on_Y

@final
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ cdef class DatasetsPair{{name_suffix}}:
{{DistanceMetric}} distance_metric
ITYPE_t n_features

readonly bint X_is_Y

cdef ITYPE_t n_samples_X(self) nogil

cdef ITYPE_t n_samples_Y(self) nogil
Expand Down
44 changes: 27 additions & 17 deletions sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp
@@ -1,3 +1,5 @@
import copy

{{py:

implementation_specific_values = [
Expand Down Expand Up @@ -91,18 +93,24 @@ cdef class DatasetsPair{{name_suffix}}:
datasets_pair: DatasetsPair{{name_suffix}}
The suited DatasetsPair{{name_suffix}} implementation.
"""
# Y_norm_squared might be propagated down to DatasetsPairs
# via metrics_kwargs when the Euclidean specialisations
# can't be used. To prevent Y_norm_squared to be passed
# X_norm_squared and Y_norm_squared might be propagated
# down to DatasetsPairs via metrics_kwargs when the Euclidean
# specialisations can't be used.
# To prevent X_norm_squared and Y_norm_squared to be passed
# down to DistanceMetrics (whose constructors would raise
# a RuntimeError), we pop it here.
# a RuntimeError), we pop them here.
if metric_kwargs is not None:
# Copying metric_kwargs not to pop "X_norm_squared"
# and "Y_norm_squared" where they are used
metric_kwargs = copy.copy(metric_kwargs)
metric_kwargs.pop("X_norm_squared", None)
metric_kwargs.pop("Y_norm_squared", None)
cdef:
{{DistanceMetric}} distance_metric = {{DistanceMetric}}.get_metric(
metric,
**(metric_kwargs or {})
)
bint X_is_Y = X is Y

# Metric-specific checks that do not replace nor duplicate `check_array`.
distance_metric._validate_data(X)
Expand All @@ -112,15 +120,15 @@ cdef class DatasetsPair{{name_suffix}}:
Y_is_sparse = issparse(Y)

if not X_is_sparse and not Y_is_sparse:
return DenseDenseDatasetsPair{{name_suffix}}(X, Y, distance_metric)
return DenseDenseDatasetsPair{{name_suffix}}(X, Y, distance_metric, X_is_Y)

if X_is_sparse and Y_is_sparse:
return SparseSparseDatasetsPair{{name_suffix}}(X, Y, distance_metric)
return SparseSparseDatasetsPair{{name_suffix}}(X, Y, distance_metric, X_is_Y)

if X_is_sparse and not Y_is_sparse:
return SparseDenseDatasetsPair{{name_suffix}}(X, Y, distance_metric)
return SparseDenseDatasetsPair{{name_suffix}}(X, Y, distance_metric, X_is_Y)

return DenseSparseDatasetsPair{{name_suffix}}(X, Y, distance_metric)
return DenseSparseDatasetsPair{{name_suffix}}(X, Y, distance_metric, X_is_Y)

@classmethod
def unpack_csr_matrix(cls, X: csr_matrix):
Expand All @@ -130,8 +138,9 @@ cdef class DatasetsPair{{name_suffix}}:
X_indptr = np.asarray(X.indptr, dtype=SPARSE_INDEX_TYPE)
return X_data, X_indices, X_indptr

def __init__(self, {{DistanceMetric}} distance_metric, ITYPE_t n_features):
def __init__(self, {{DistanceMetric}} distance_metric, ITYPE_t n_features, bint X_is_Y):
self.distance_metric = distance_metric
self.X_is_Y = X_is_Y
self.n_features = n_features

cdef ITYPE_t n_samples_X(self) nogil:
Expand Down Expand Up @@ -179,8 +188,9 @@ cdef class DenseDenseDatasetsPair{{name_suffix}}(DatasetsPair{{name_suffix}}):
const {{INPUT_DTYPE_t}}[:, ::1] X,
const {{INPUT_DTYPE_t}}[:, ::1] Y,
{{DistanceMetric}} distance_metric,
bint X_is_Y,
):
super().__init__(distance_metric, n_features=X.shape[1])
super().__init__(distance_metric, n_features=X.shape[1], X_is_Y=X_is_Y)
# Arrays have already been checked
self.X = X
self.Y = Y
Expand Down Expand Up @@ -219,8 +229,8 @@ cdef class SparseSparseDatasetsPair{{name_suffix}}(DatasetsPair{{name_suffix}}):
between two vectors of (X, Y).
"""

def __init__(self, X, Y, {{DistanceMetric}} distance_metric):
super().__init__(distance_metric, n_features=X.shape[1])
def __init__(self, X, Y, {{DistanceMetric}} distance_metric, bint X_is_Y):
super().__init__(distance_metric, n_features=X.shape[1], X_is_Y=X_is_Y)

self.X_data, self.X_indices, self.X_indptr = self.unpack_csr_matrix(X)
self.Y_data, self.Y_indices, self.Y_indptr = self.unpack_csr_matrix(Y)
Expand Down Expand Up @@ -279,8 +289,8 @@ cdef class SparseDenseDatasetsPair{{name_suffix}}(DatasetsPair{{name_suffix}}):
between two vectors of (X, Y).
"""

def __init__(self, X, Y, {{DistanceMetric}} distance_metric):
super().__init__(distance_metric, n_features=X.shape[1])
def __init__(self, X, Y, {{DistanceMetric}} distance_metric, bint X_is_Y):
super().__init__(distance_metric, n_features=X.shape[1], X_is_Y=X_is_Y)

self.X_data, self.X_indices, self.X_indptr = self.unpack_csr_matrix(X)

Expand Down Expand Up @@ -377,10 +387,10 @@ cdef class DenseSparseDatasetsPair{{name_suffix}}(DatasetsPair{{name_suffix}}):
between two vectors of (X, Y).
"""

def __init__(self, X, Y, {{DistanceMetric}} distance_metric):
super().__init__(distance_metric, n_features=X.shape[1])
def __init__(self, X, Y, {{DistanceMetric}} distance_metric, bint X_is_Y):
super().__init__(distance_metric, n_features=X.shape[1], X_is_Y=X_is_Y)
# Swapping arguments on the constructor
self.datasets_pair = SparseDenseDatasetsPair{{name_suffix}}(Y, X, distance_metric)
self.datasets_pair = SparseDenseDatasetsPair{{name_suffix}}(Y, X, distance_metric, X_is_Y)

@final
cdef ITYPE_t n_samples_X(self) nogil:
Expand Down
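The copy-then-pop handling of `metric_kwargs` in `DatasetsPair.get_for` above can be illustrated in plain Python. This is a minimal sketch, not the PR's code: the helper name `strip_norm_kwargs` is hypothetical, and only the two keys named in the diff (`X_norm_squared`, `Y_norm_squared`) are assumed. It shows why the shallow copy matters: the caller's dictionary keeps the precomputed norms for code paths that still need them.

```python
import copy


def strip_norm_kwargs(metric_kwargs):
    """Return a copy of metric_kwargs without the precomputed squared norms.

    Hypothetical helper: it mirrors the copy-then-pop pattern shown in the
    diff, so that the caller's dictionary is left untouched.
    """
    if metric_kwargs is None:
        return {}
    # Shallow copy: popping from the copy does not affect the original dict.
    cleaned = copy.copy(metric_kwargs)
    cleaned.pop("X_norm_squared", None)
    cleaned.pop("Y_norm_squared", None)
    return cleaned


caller_kwargs = {"p": 2, "X_norm_squared": [1.0, 4.0]}
cleaned_kwargs = strip_norm_kwargs(caller_kwargs)
```

After the call, `cleaned_kwargs` only holds the arguments a metric constructor would accept, while `caller_kwargs` still contains `X_norm_squared`.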
138 changes: 134 additions & 4 deletions sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py
Expand Up @@ -16,6 +16,12 @@
ArgKmin64,
ArgKmin32,
)

from ._pairwise_distances import (
PairwiseDistances64,
PairwiseDistances32,
)

from ._radius_neighbors import (
RadiusNeighbors64,
RadiusNeighbors32,
Expand Down Expand Up @@ -168,6 +174,132 @@ def compute(
"""


class PairwiseDistances(BaseDistancesReductionDispatcher):
"""Compute the pairwise distances matrix for two sets of vectors.

The distance function `dist` depends on the values of the `metric`
and `metric_kwargs` parameters.

This class only computes the pairwise distances matrix without
applying any reduction on it. It shares most of the underlying
code infrastructure with reducing variants to leverage
cache-aware chunking and multi-thread parallelism.

This class is not meant to be instantiated; one should only use
its :meth:`compute` classmethod, which handles allocation and
deallocation consistently.
"""

@classmethod
def is_usable_for(cls, X, Y, metric) -> bool:
Y = X if Y is None else Y
return super().is_usable_for(X, Y, metric)

@classmethod
def compute(
cls,
X,
Y,
metric="euclidean",
chunk_size=None,
metric_kwargs=None,
strategy=None,
):
"""Return pairwise distances matrix for the given arguments.

Parameters
----------
X : ndarray or CSR matrix of shape (n_samples_X, n_features)
Input data.

Y : ndarray or CSR matrix of shape (n_samples_Y, n_features)
Input data.

metric : str, default='euclidean'
The distance metric to use.
For a list of available metrics, see the documentation of
:class:`~sklearn.metrics.DistanceMetric`.

chunk_size : int, default=None
The number of vectors per chunk. If None (default), looks up
`pairwise_dist_chunk_size` in the scikit-learn configuration,
and uses 256 if it is not set.

metric_kwargs : dict, default=None
Keyword arguments to pass to specified metric function.

strategy : str, {'auto', 'parallel_on_X', 'parallel_on_Y'}, default=None
The chunking strategy defining which dataset the parallelization is made on.

For both strategies, the computation happens with two nested loops,
respectively on chunks of X and chunks of Y.
Strategies differ on which loop (outer or inner) is made to run
in parallel with the Cython `prange` construct:

- 'parallel_on_X' dispatches chunks of X uniformly on threads.
Each thread then iterates on all the chunks of Y. This strategy is
embarrassingly parallel and comes with no datastructures
synchronisation.

- 'parallel_on_Y' dispatches chunks of Y uniformly on threads.
Each thread processes all the chunks of X in turn. This strategy is
a sequence of embarrassingly parallel subtasks (the inner loop on Y
chunks) with intermediate datastructures synchronisation at each
iteration of the sequential outer loop on X chunks.

- 'auto' relies on a simple heuristic to choose between
'parallel_on_X' and 'parallel_on_Y': when `X.shape[0]` is large enough,
'parallel_on_X' is usually the most efficient strategy.
When `X.shape[0]` is small but `Y.shape[0]` is large, 'parallel_on_Y'
brings more opportunity for parallelism and is therefore more efficient.

- None (default) looks up `pairwise_dist_parallel_strategy` in the
scikit-learn configuration, and uses 'auto' if it is not set.

Returns
-------
pairwise_distances_matrix : ndarray of shape (n_samples_X, n_samples_Y)
The pairwise distances matrix.

Notes
-----
This public classmethod is responsible for introspecting the argument
values to dispatch to the private dtype-specialized implementation of
:class:`PairwiseDistances`.

All temporarily allocated datastructures necessary for the concrete
implementation are therefore freed when this classmethod returns.

This allows decoupling the API entirely from the
implementation details whilst maintaining RAII.
"""
Y = X if Y is None else Y
if X.dtype == Y.dtype == np.float64:
return PairwiseDistances64.compute(
X=X,
Y=Y,
metric=metric,
chunk_size=chunk_size,
metric_kwargs=metric_kwargs,
strategy=strategy,
)

if X.dtype == Y.dtype == np.float32:
return PairwiseDistances32.compute(
X=X,
Y=Y,
metric=metric,
chunk_size=chunk_size,
metric_kwargs=metric_kwargs,
strategy=strategy,
)

raise ValueError(
"Only float64 or float32 datasets pairs are supported, but "
f"got: X.dtype={X.dtype} and Y.dtype={Y.dtype}."
)

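The dtype-based dispatch performed by `PairwiseDistances.compute` above can be sketched without NumPy. This is an illustration under stated assumptions, not the PR's implementation: the two backend functions and the `dispatch` helper are hypothetical stand-ins, and the inputs are `array.array` objects from the standard library, whose typecodes `"d"` and `"f"` correspond to float64 and float32 on common platforms.

```python
from array import array


# Hypothetical stand-ins for the dtype-specialized backends.
def compute_with_float64_backend(X, Y):
    return "PairwiseDistances64"


def compute_with_float32_backend(X, Y):
    return "PairwiseDistances32"


# Typecode "d" is a C double and "f" is a C float.
BACKENDS = {
    "d": compute_with_float64_backend,
    "f": compute_with_float32_backend,
}


def dispatch(X, Y=None):
    """Pick a backend only when X and Y share a supported element type."""
    Y = X if Y is None else Y
    if X.typecode == Y.typecode and X.typecode in BACKENDS:
        return BACKENDS[X.typecode](X, Y)
    raise ValueError(
        "Only float64 or float32 datasets pairs are supported, but "
        f"got: X.typecode={X.typecode!r} and Y.typecode={Y.typecode!r}."
    )


chosen_64 = dispatch(array("d", [0.0, 1.0]))
chosen_32 = dispatch(array("f", [0.0]), array("f", [1.0]))
```

As in the diff, a mixed pair (one float64 input and one float32 input) is rejected rather than silently converted.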

class ArgKmin(BaseDistancesReductionDispatcher):
"""Compute the argkmin of row vectors of X on the ones of Y.

Expand Down Expand Up @@ -243,7 +375,7 @@ def compute(
'parallel_on_X' and 'parallel_on_Y': when `X.shape[0]` is large enough,
'parallel_on_X' is usually the most efficient strategy.
When `X.shape[0]` is small but `Y.shape[0]` is large, 'parallel_on_Y'
brings more opportunity for parallelism and is therefore more efficient
brings more opportunity for parallelism and is therefore more efficient.

- None (default) looks-up in scikit-learn configuration for
`pairwise_dist_parallel_strategy`, and use 'auto' if it is not set.
Expand Down Expand Up @@ -382,9 +514,7 @@ def compute(
'parallel_on_X' and 'parallel_on_Y': when `X.shape[0]` is large enough,
'parallel_on_X' is usually the most efficient strategy.
When `X.shape[0]` is small but `Y.shape[0]` is large, 'parallel_on_Y'
brings more opportunity for parallelism and is therefore more efficient
despite the synchronization step at each iteration of the outer loop
on chunks of `X`.
brings more opportunity for parallelism and is therefore more efficient.

- None (default) looks-up in scikit-learn configuration for
`pairwise_dist_parallel_strategy`, and use 'auto' if it is not set.
Expand Down
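The 'parallel_on_X' strategy described in the docstrings above can be sketched in pure Python. This is only a structural illustration: the real implementation uses Cython `prange` on typed memoryviews, whereas this sketch uses a thread pool, Python lists, and a Manhattan distance chosen for simplicity. The names `chunk_bounds` and `pairwise_parallel_on_X` are hypothetical, and Python threads do not bring the same speed-up as `prange` because of the GIL.

```python
from concurrent.futures import ThreadPoolExecutor


def manhattan(x, y):
    """Manhattan distance between two vectors given as sequences."""
    return sum(abs(a - b) for a, b in zip(x, y))


def chunk_bounds(n_samples, chunk_size):
    """Return (start, end) index pairs covering range(n_samples)."""
    return [
        (start, min(start + chunk_size, n_samples))
        for start in range(0, n_samples, chunk_size)
    ]


def pairwise_parallel_on_X(X, Y, chunk_size=2, n_threads=2):
    """Fill the distance matrix with chunks of X dispatched on threads."""
    D = [[0.0] * len(Y) for _ in X]
    Y_chunks = chunk_bounds(len(Y), chunk_size)

    def process_X_chunk(bounds):
        X_start, X_end = bounds
        # Each worker iterates on all the chunks of Y for its chunk of X.
        for Y_start, Y_end in Y_chunks:
            for i in range(X_start, X_end):
                for j in range(Y_start, Y_end):
                    D[i][j] = float(manhattan(X[i], Y[j]))

    # Workers write to disjoint rows of D, so no synchronisation is needed.
    with ThreadPoolExecutor(max_workers=n_threads) as pool:
        list(pool.map(process_X_chunk, chunk_bounds(len(X), chunk_size)))
    return D


X = [[0, 0], [1, 1], [2, 0]]
Y = [[0, 1], [3, 3]]
D = pairwise_parallel_on_X(X, Y)
```

A 'parallel_on_Y' variant would instead keep the loop on chunks of X sequential and dispatch the inner loop on chunks of Y, which is why it pays off mostly when `X.shape[0]` is small and `Y.shape[0]` is large.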
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
{{py:

implementation_specific_values = [
# Values are the following ones:
#
# name_suffix, INPUT_DTYPE_t
#
#
('64', 'cnp.float64_t'),
('32', 'cnp.float32_t')
]

}}
cimport numpy as cnp

from ...utils._typedefs cimport DTYPE_t
{{for name_suffix, INPUT_DTYPE_t in implementation_specific_values}}

from ._base cimport BaseDistancesReduction{{name_suffix}}
from ._middle_term_computer cimport MiddleTermComputer{{name_suffix}}


cdef class PairwiseDistances{{name_suffix}}(BaseDistancesReduction{{name_suffix}}):
"""float{{name_suffix}} implementation of PairwiseDistances."""

cdef:
{{INPUT_DTYPE_t}}[:, ::1] pairwise_distances_matrix


cdef class EuclideanPairwiseDistances{{name_suffix}}(PairwiseDistances{{name_suffix}}):
"""EuclideanDistance-specialized float{{name_suffix}} implementation for PairwiseDistances."""
cdef:
MiddleTermComputer{{name_suffix}} middle_term_computer
const DTYPE_t[::1] X_norm_squared
const DTYPE_t[::1] Y_norm_squared

bint use_squared_distances

{{endfor}}
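The attributes declared on `EuclideanPairwiseDistances` above (`X_norm_squared`, `Y_norm_squared`, `middle_term_computer`, `use_squared_distances`) suggest the usual expansion ||x - y||² = ||x||² - 2 x·y + ||y||². The following is a minimal pure-Python sketch of that expansion, assuming the middle term is -2 x·y. The function name `euclidean_via_expansion` and its `squared` parameter are hypothetical, and the clipping of small negative values is an assumption about how rounding errors would be handled.

```python
import math


def sq_norm(v):
    """Squared Euclidean norm of a vector."""
    return sum(a * a for a in v)


def dot(x, y):
    """Dot product of two vectors of the same length."""
    return sum(a * b for a, b in zip(x, y))


def euclidean_via_expansion(
    X, Y, X_norm_squared=None, Y_norm_squared=None, squared=False
):
    """Pairwise Euclidean distances using ||x||^2 - 2 x.y + ||y||^2."""
    # Precomputed squared norms are reused when the caller provides them.
    if X_norm_squared is None:
        X_norm_squared = [sq_norm(x) for x in X]
    if Y_norm_squared is None:
        Y_norm_squared = [sq_norm(y) for y in Y]

    D = []
    for i, x in enumerate(X):
        row = []
        for j, y in enumerate(Y):
            d2 = X_norm_squared[i] - 2.0 * dot(x, y) + Y_norm_squared[j]
            # Rounding errors can make d2 slightly negative: clip at zero.
            d2 = max(d2, 0.0)
            row.append(d2 if squared else math.sqrt(d2))
        D.append(row)
    return D


X = [[0, 0], [3, 4]]
Y = [[0, 0], [6, 8]]
D = euclidean_via_expansion(X, Y)
D_squared = euclidean_via_expansion(X, Y, squared=True)
```

Passing `X_norm_squared` and `Y_norm_squared` avoids recomputing the norms, which is presumably why the dispatcher forwards them through `metric_kwargs` to the Euclidean specialisation.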