Skip to content

MNT Updated DistanceMetric API with new ABC/interface #26471

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jun 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions sklearn/cluster/_agglomerative.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from ..metrics.pairwise import paired_distances
from ..metrics.pairwise import _VALID_METRICS
from ..metrics import DistanceMetric
from ..metrics._dist_metrics import METRIC_MAPPING
from ..metrics._dist_metrics import METRIC_MAPPING64
from ..utils import check_array
from ..utils._fast_dict import IntFloatDict
from ..utils.graph import _fix_connected_components
Expand Down Expand Up @@ -543,7 +543,7 @@ def linkage_tree(
linkage == "single"
and affinity != "precomputed"
and not callable(affinity)
and affinity in METRIC_MAPPING
and affinity in METRIC_MAPPING64
):
# We need the fast cythonized metric from neighbors
dist_metric = DistanceMetric.get_metric(affinity)
Expand Down
4 changes: 2 additions & 2 deletions sklearn/cluster/_hdbscan/_linkage.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ cimport numpy as cnp
from libc.float cimport DBL_MAX

import numpy as np
from ...metrics._dist_metrics cimport DistanceMetric
from ...metrics._dist_metrics cimport DistanceMetric64
from ...cluster._hierarchical_fast cimport UnionFind
from ...cluster._hdbscan._tree cimport HIERARCHY_t
from ...cluster._hdbscan._tree import HIERARCHY_dtype
Expand Down Expand Up @@ -111,7 +111,7 @@ cpdef cnp.ndarray[MST_edge_t, ndim=1, mode='c'] mst_from_mutual_reachability(
cpdef cnp.ndarray[MST_edge_t, ndim=1, mode='c'] mst_from_data_matrix(
const float64_t[:, ::1] raw_data,
const float64_t[::1] core_distances,
DistanceMetric dist_metric,
DistanceMetric64 dist_metric,
float64_t alpha=1.0
):
"""Compute the Minimum Spanning Tree (MST) representation of the mutual-
Expand Down
8 changes: 4 additions & 4 deletions sklearn/cluster/_hierarchical_fast.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
cimport cython

from ..metrics._dist_metrics cimport DistanceMetric
from ..metrics._dist_metrics cimport DistanceMetric64
from ..utils._fast_dict cimport IntFloatDict
from ..utils._typedefs cimport float64_t, intp_t, uint8_t

Expand Down Expand Up @@ -427,7 +427,7 @@ def single_linkage_label(L):
# Implements MST-LINKAGE-CORE from https://arxiv.org/abs/1109.2378
def mst_linkage_core(
const float64_t [:, ::1] raw_data,
DistanceMetric dist_metric):
DistanceMetric64 dist_metric):
"""
Compute the necessary elements of a minimum spanning
tree for computation of single linkage clustering. This
Expand All @@ -444,8 +444,8 @@ def mst_linkage_core(
raw_data: array of shape (n_samples, n_features)
The array of feature data to be clustered. Must be C-aligned

dist_metric: DistanceMetric
A DistanceMetric object conforming to the API from
dist_metric: DistanceMetric64
A DistanceMetric64 object conforming to the API from
``sklearn.metrics._dist_metrics.pxd`` that will be
used to compute distances.

Expand Down
21 changes: 6 additions & 15 deletions sklearn/metrics/_dist_metrics.pxd.tp
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,8 @@
implementation_specific_values = [
# Values are the following ones:
#
# name_suffix, INPUT_DTYPE_t, INPUT_DTYPE
#
# On the first hand, an empty string is used for `name_suffix`
# for the float64 case as to still be able to expose the original
# float64 implementation under the same API, namely `DistanceMetric`.
#
# On the other hand, '32' is used for `name_suffix` for the float32
# case to remove ambiguity and use `DistanceMetric32`, which is not
# publicly exposed.
#
# The metric mapping is adapted accordingly to route to the correct
# implementations.
#
('', 'float64_t', 'np.float64'),
# name_suffix, INPUT_DTYPE_t, INPUT_DTYPE
('64', 'float64_t', 'np.float64'),
('32', 'float32_t', 'np.float32')
]

Expand All @@ -25,6 +13,9 @@ from libc.math cimport sqrt, exp

from ..utils._typedefs cimport float64_t, float32_t, int32_t, intp_t

# Common base extension type: the dtype-specialized DistanceMetric64 and
# DistanceMetric32 classes generated by the template loop below are declared
# as subclasses of it. It declares no attributes or methods of its own here.
cdef class DistanceMetric:
pass

{{for name_suffix, INPUT_DTYPE_t, INPUT_DTYPE in implementation_specific_values}}

######################################################################
Expand Down Expand Up @@ -68,7 +59,7 @@ cdef inline float64_t euclidean_rdist_to_dist{{name_suffix}}(const {{INPUT_DTYPE

######################################################################
# DistanceMetric{{name_suffix}} base class
cdef class DistanceMetric{{name_suffix}}:
cdef class DistanceMetric{{name_suffix}}(DistanceMetric):
# The following attributes are required for a few of the subclasses.
# we must define them here so that cython's limited polymorphism will work.
# Because we don't expect to instantiate a lot of these objects, the
Expand Down
47 changes: 31 additions & 16 deletions sklearn/metrics/_dist_metrics.pyx.tp
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,8 @@
implementation_specific_values = [
# Values are the following ones:
#
# name_suffix, INPUT_DTYPE_t, INPUT_DTYPE
#
# On the first hand, an empty string is used for `name_suffix`
# for the float64 case as to still be able to expose the original
# float64 implementation under the same API, namely `DistanceMetric`.
#
# On the other hand, '32' bit is used for `name_suffix` for the float32
# case to remove ambiguity and use `DistanceMetric32`, which is not
# publicly exposed.
#
# The metric mapping is adapted accordingly to route to the correct
# implementations.
#
('', 'float64_t', 'np.float64'),
# name_suffix, INPUT_DTYPE_t, INPUT_DTYPE
('64', 'float64_t', 'np.float64'),
('32', 'float32_t', 'np.float32')
]

Expand Down Expand Up @@ -73,9 +61,36 @@ def get_valid_metric_ids(L):
>>> sorted(L)
['cityblock', 'euclidean', 'l1', 'l2', 'manhattan']
"""
return [key for (key, val) in METRIC_MAPPING.items()
return [key for (key, val) in METRIC_MAPPING64.items()
if (val.__name__ in L) or (val in L)]

# Dtype-agnostic front-end: ``get_metric`` forwards to the float32 or float64
# specialized class depending on the requested dtype.
cdef class DistanceMetric:
    @classmethod
    def get_metric(cls, metric, dtype=np.float64, **kwargs):
        """Get the distance metric matching ``metric`` for the given ``dtype``.

        See the docstring of DistanceMetric for a list of available metrics.

        Parameters
        ----------
        metric : str or class name
            The distance metric to use.
        dtype : {np.float32, np.float64}, default=np.float64
            The dtype of the data on which the metric will be applied.
        **kwargs
            Additional arguments forwarded to the requested metric.

        Returns
        -------
        object
            Whatever ``DistanceMetric32.get_metric`` (for ``np.float32``) or
            ``DistanceMetric64.get_metric`` (for ``np.float64``) returns for
            ``metric`` and ``kwargs``.

        Raises
        ------
        ValueError
            If ``dtype`` is neither ``np.float32`` nor ``np.float64``.
        """
        # Dispatch with early returns; anything that falls through is an
        # unsupported dtype.
        if dtype == np.float32:
            return DistanceMetric32.get_metric(metric, **kwargs)
        if dtype == np.float64:
            return DistanceMetric64.get_metric(metric, **kwargs)

        raise ValueError(
            f"Unexpected dtype {dtype} provided. Please select a dtype from"
            " {np.float32, np.float64}"
        )

{{for name_suffix, INPUT_DTYPE_t, INPUT_DTYPE in implementation_specific_values}}

Expand Down Expand Up @@ -125,7 +140,7 @@ cdef {{INPUT_DTYPE_t}} INF{{name_suffix}} = np.inf

######################################################################
# Distance Metric Classes
cdef class DistanceMetric{{name_suffix}}:
cdef class DistanceMetric{{name_suffix}}(DistanceMetric):
"""DistanceMetric class

This class provides a uniform interface to fast distance metric
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,14 @@
implementation_specific_values = [
# Values are the following ones:
#
# name_suffix, DistanceMetric, INPUT_DTYPE_t, INPUT_DTYPE
#
# We use DistanceMetric for float64 for backward naming compatibility.
#
('64', 'DistanceMetric', 'float64_t'),
# name_suffix, DistanceMetric, INPUT_DTYPE_t
('64', 'DistanceMetric64', 'float64_t'),
('32', 'DistanceMetric32', 'float32_t')
]

}}
from ...utils._typedefs cimport float64_t, float32_t, int32_t, intp_t
from ...metrics._dist_metrics cimport DistanceMetric, DistanceMetric32
from ...metrics._dist_metrics cimport DistanceMetric64, DistanceMetric32, DistanceMetric

{{for name_suffix, DistanceMetric, INPUT_DTYPE_t in implementation_specific_values}}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,8 @@
implementation_specific_values = [
# Values are the following ones:
#
# name_suffix, DistanceMetric, INPUT_DTYPE_t, INPUT_DTYPE
#
# We use DistanceMetric for float64 for backward naming compatibility.
#
('64', 'DistanceMetric', 'float64_t', 'np.float64'),
# name_suffix, DistanceMetric, INPUT_DTYPE_t, INPUT_DTYPE
('64', 'DistanceMetric64', 'float64_t', 'np.float64'),
('32', 'DistanceMetric32', 'float32_t', 'np.float32')
]

Expand All @@ -17,7 +14,6 @@ import numpy as np
from cython cimport final

from ...utils._typedefs cimport float64_t, float32_t, intp_t
from ...metrics._dist_metrics cimport DistanceMetric

from scipy.sparse import issparse, csr_matrix

Expand Down Expand Up @@ -96,8 +92,9 @@ cdef class DatasetsPair{{name_suffix}}:
if metric_kwargs is not None:
metric_kwargs.pop("Y_norm_squared", None)
cdef:
{{DistanceMetric}} distance_metric = {{DistanceMetric}}.get_metric(
{{DistanceMetric}} distance_metric = DistanceMetric.get_metric(
metric,
{{INPUT_DTYPE}},
**(metric_kwargs or {})
)

Expand Down
4 changes: 2 additions & 2 deletions sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from scipy.sparse import isspmatrix_csr, issparse

from .._dist_metrics import BOOL_METRICS, METRIC_MAPPING
from .._dist_metrics import BOOL_METRICS, METRIC_MAPPING64

from ._base import _sqeuclidean_row_norms32, _sqeuclidean_row_norms64
from ._argkmin import (
Expand Down Expand Up @@ -76,7 +76,7 @@ def valid_metrics(cls) -> List[str]:
"hamming",
*BOOL_METRICS,
}
return sorted(({"sqeuclidean"} | set(METRIC_MAPPING.keys())) - excluded)
return sorted(({"sqeuclidean"} | set(METRIC_MAPPING64.keys())) - excluded)

@classmethod
def is_usable_for(cls, X, Y, metric) -> bool:
Expand Down
50 changes: 30 additions & 20 deletions sklearn/metrics/tests/test_dist_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

from sklearn.metrics._dist_metrics import (
BOOL_METRICS,
# Unexposed private DistanceMetric for 32 bit
DistanceMetric32,
DistanceMetric64,
)

from sklearn.utils import check_random_state
Expand Down Expand Up @@ -64,9 +64,6 @@ def dist_func(x1, x2, p):
)
@pytest.mark.parametrize("X, Y", [(X64, Y64), (X32, Y32), (X_mmap, Y_mmap)])
def test_cdist(metric_param_grid, X, Y):
DistanceMetricInterface = (
DistanceMetric if X.dtype == Y.dtype == np.float64 else DistanceMetric32
)
metric, param_grid = metric_param_grid
keys = param_grid.keys()
X_csr, Y_csr = sp.csr_matrix(X), sp.csr_matrix(Y)
Expand All @@ -83,7 +80,7 @@ def test_cdist(metric_param_grid, X, Y):

D_scipy_cdist = cdist(X, Y, metric, **kwargs)

dm = DistanceMetricInterface.get_metric(metric, **kwargs)
dm = DistanceMetric.get_metric(metric, X.dtype, **kwargs)

# DistanceMetric.pairwise must be consistent for all
# combinations of formats in {sparse, dense}.
Expand Down Expand Up @@ -141,9 +138,6 @@ def test_cdist_bool_metric(metric, X_bool, Y_bool):
)
@pytest.mark.parametrize("X", [X64, X32, X_mmap])
def test_pdist(metric_param_grid, X):
DistanceMetricInterface = (
DistanceMetric if X.dtype == np.float64 else DistanceMetric32
)
metric, param_grid = metric_param_grid
keys = param_grid.keys()
X_csr = sp.csr_matrix(X)
Expand All @@ -160,7 +154,7 @@ def test_pdist(metric_param_grid, X):

D_scipy_pdist = cdist(X, X, metric, **kwargs)

dm = DistanceMetricInterface.get_metric(metric, **kwargs)
dm = DistanceMetric.get_metric(metric, X.dtype, **kwargs)
D_sklearn = dm.pairwise(X)
assert D_sklearn.flags.c_contiguous
assert_allclose(D_sklearn, D_scipy_pdist, **rtol_dict)
Expand Down Expand Up @@ -189,8 +183,8 @@ def test_distance_metrics_dtype_consistency(metric_param_grid):

for vals in itertools.product(*param_grid.values()):
kwargs = dict(zip(keys, vals))
dm64 = DistanceMetric.get_metric(metric, **kwargs)
dm32 = DistanceMetric32.get_metric(metric, **kwargs)
dm64 = DistanceMetric.get_metric(metric, np.float64, **kwargs)
dm32 = DistanceMetric.get_metric(metric, np.float32, **kwargs)

D64 = dm64.pairwise(X64)
D32 = dm32.pairwise(X32)
Expand Down Expand Up @@ -230,9 +224,6 @@ def test_pdist_bool_metrics(metric, X_bool):
)
@pytest.mark.parametrize("X", [X64, X32])
def test_pickle(writable_kwargs, metric_param_grid, X):
DistanceMetricInterface = (
DistanceMetric if X.dtype == np.float64 else DistanceMetric32
)
metric, param_grid = metric_param_grid
keys = param_grid.keys()
for vals in itertools.product(*param_grid.values()):
Expand All @@ -242,7 +233,7 @@ def test_pickle(writable_kwargs, metric_param_grid, X):
if isinstance(val, np.ndarray):
val.setflags(write=writable_kwargs)
kwargs = dict(zip(keys, vals))
dm = DistanceMetricInterface.get_metric(metric, **kwargs)
dm = DistanceMetric.get_metric(metric, X.dtype, **kwargs)
D1 = dm.pairwise(X)
dm2 = pickle.loads(pickle.dumps(dm))
D2 = dm2.pairwise(X)
Expand All @@ -261,10 +252,6 @@ def test_pickle_bool_metrics(metric, X_bool):

@pytest.mark.parametrize("X, Y", [(X64, Y64), (X32, Y32), (X_mmap, Y_mmap)])
def test_haversine_metric(X, Y):
DistanceMetricInterface = (
DistanceMetric if X.dtype == np.float64 else DistanceMetric32
)

# The Haversine DistanceMetric only works on 2 features.
X = np.asarray(X[:, :2])
Y = np.asarray(Y[:, :2])
Expand All @@ -286,7 +273,7 @@ def haversine_slow(x1, x2):
for j, yj in enumerate(Y):
D_reference[i, j] = haversine_slow(xi, yj)

haversine = DistanceMetricInterface.get_metric("haversine")
haversine = DistanceMetric.get_metric("haversine", X.dtype)

D_sklearn = haversine.pairwise(X, Y)
assert_allclose(
Expand Down Expand Up @@ -389,3 +376,26 @@ def test_minkowski_metric_validate_weights_size():
)
with pytest.raises(ValueError, match=msg):
dm.pairwise(X64, Y64)


@pytest.mark.parametrize("metric, metric_kwargs", METRICS_DEFAULT_PARAMS)
@pytest.mark.parametrize("dtype", (np.float32, np.float64))
def test_get_metric_dtype(metric, metric_kwargs, dtype):
    """The generic factory must yield the same type as the specialized one."""
    dtype_to_cls = {np.float32: DistanceMetric32, np.float64: DistanceMetric64}
    specialized_cls = dtype_to_cls[dtype]

    # One parameter combination suffices for this sanity check, so only the
    # first candidate value of each parameter is kept.
    metric_kwargs = {key: values[0] for key, values in metric_kwargs.items()}

    via_generic = DistanceMetric.get_metric(metric, dtype, **metric_kwargs)
    via_specialized = specialized_cls.get_metric(metric, **metric_kwargs)

    assert type(via_generic) is type(via_specialized)


def test_get_metric_bad_dtype():
    """Requesting a metric for an unsupported dtype raises a ValueError."""
    expected_msg = r"Unexpected dtype .* provided. Please select a dtype from"
    with pytest.raises(ValueError, match=expected_msg):
        DistanceMetric.get_metric("manhattan", np.int32)
2 changes: 0 additions & 2 deletions sklearn/neighbors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from ._ball_tree import BallTree
from ._kd_tree import KDTree
from ._distance_metric import DistanceMetric
from ._graph import kneighbors_graph, radius_neighbors_graph
from ._graph import KNeighborsTransformer, RadiusNeighborsTransformer
from ._unsupervised import NearestNeighbors
Expand All @@ -20,7 +19,6 @@

__all__ = [
"BallTree",
"DistanceMetric",
"KDTree",
"KNeighborsClassifier",
"KNeighborsRegressor",
Expand Down
Loading