From 8e941715a0f861a120ae543042f6fa40e02b04b4 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 19 Dec 2022 10:34:29 +0100 Subject: [PATCH 01/13] MAINT introduce kulczynski1 in place of kulsinski --- sklearn/cluster/_optics.py | 10 ++++++++-- sklearn/metrics/_dist_metrics.pyx.tp | 3 ++- sklearn/metrics/pairwise.py | 21 +++++++++++++++------ sklearn/neighbors/_base.py | 2 +- sklearn/neighbors/tests/test_ball_tree.py | 3 ++- 5 files changed, 28 insertions(+), 11 deletions(-) diff --git a/sklearn/cluster/_optics.py b/sklearn/cluster/_optics.py index 07c22bbdff691..5d40e00adf5f7 100755 --- a/sklearn/cluster/_optics.py +++ b/sklearn/cluster/_optics.py @@ -81,7 +81,7 @@ class OPTICS(ClusterMixin, BaseEstimator): 'manhattan'] - from scipy.spatial.distance: ['braycurtis', 'canberra', 'chebyshev', - 'correlation', 'dice', 'hamming', 'jaccard', 'kulsinski', + 'correlation', 'dice', 'hamming', 'jaccard', 'kulsinski', 'kulczynski1', 'mahalanobis', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean', 'sokalmichener', 'sokalsneath', 'sqeuclidean', 'yule'] @@ -90,6 +90,9 @@ class OPTICS(ClusterMixin, BaseEstimator): See the documentation for scipy.spatial.distance for details on these metrics. + .. note:: + `'kulsinski'` is deprecated from SciPy 1.8. Use `'kulczynski1'` instead. + p : float, default=2 Parameter for the Minkowski metric from :class:`~sklearn.metrics.pairwise_distances`. When p = 1, this is @@ -465,7 +468,7 @@ def compute_optics_graph( 'manhattan'] - from scipy.spatial.distance: ['braycurtis', 'canberra', 'chebyshev', - 'correlation', 'dice', 'hamming', 'jaccard', 'kulsinski', + 'correlation', 'dice', 'hamming', 'jaccard', 'kulsinski', 'kulczynski1', 'mahalanobis', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean', 'sokalmichener', 'sokalsneath', 'sqeuclidean', 'yule'] @@ -473,6 +476,9 @@ def compute_optics_graph( See the documentation for scipy.spatial.distance for details on these metrics. + .. note:: + `'kulsinski'` is deprecated from SciPy 1.8. Use `'kulczynski1'` instead. + p : int, default=2 Parameter for the Minkowski metric from :class:`~sklearn.metrics.pairwise_distances`. When p = 1, this is diff --git a/sklearn/metrics/_dist_metrics.pyx.tp b/sklearn/metrics/_dist_metrics.pyx.tp index 1e4a9429af03f..46a87581c0249 100644 --- a/sklearn/metrics/_dist_metrics.pyx.tp +++ b/sklearn/metrics/_dist_metrics.pyx.tp @@ -42,6 +42,7 @@ from ..utils._typedefs cimport DTYPE_t, ITYPE_t, DTYPECODE from ..utils._typedefs import DTYPE, ITYPE from ..utils._readonly_array_wrapper import ReadonlyArrayWrapper from ..utils import check_array +from ..utils.fixes import parse_version, sp_version cdef inline double fmax(double a, double b) nogil: return max(a, b) @@ -59,7 +60,7 @@ BOOL_METRICS = [ "matching", "jaccard", "dice", - "kulsinski", + "kulsinski" if sp_version < parse_version("1.8") else "kulczynski1", "rogerstanimoto", "russellrao", "sokalmichener", diff --git a/sklearn/metrics/pairwise.py b/sklearn/metrics/pairwise.py index 1ccff8ae8c8b7..d7b257805f6c6 100644 --- a/sklearn/metrics/pairwise.py +++ b/sklearn/metrics/pairwise.py @@ -635,7 +635,7 @@ def pairwise_distances_argmin_min( 'manhattan'] - from scipy.spatial.distance: ['braycurtis', 'canberra', 'chebyshev', - 'correlation', 'dice', 'hamming', 'jaccard', 'kulsinski', + 'correlation', 'dice', 'hamming', 'jaccard', 'kulsinski', 'kulczynski1', 'mahalanobis', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean', 'sokalmichener', 'sokalsneath', 'sqeuclidean', 'yule'] @@ -643,6 +643,9 @@ def pairwise_distances_argmin_min( See the documentation for scipy.spatial.distance for details on these metrics. + .. note:: + `'kulsinski'` is deprecated from SciPy 1.8. Use `'kulczynski1'` instead. + metric_kwargs : dict, default=None Keyword arguments to pass to specified metric function. @@ -752,7 +755,7 @@ def pairwise_distances_argmin(X, Y, *, axis=1, metric="euclidean", metric_kwargs 'manhattan'] - from scipy.spatial.distance: ['braycurtis', 'canberra', 'chebyshev', - 'correlation', 'dice', 'hamming', 'jaccard', 'kulsinski', + 'correlation', 'dice', 'hamming', 'jaccard', 'kulsinski', 'kulczynski1', 'mahalanobis', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean', 'sokalmichener', 'sokalsneath', 'sqeuclidean', 'yule'] @@ -760,6 +763,9 @@ def pairwise_distances_argmin(X, Y, *, axis=1, metric="euclidean", metric_kwargs See the documentation for scipy.spatial.distance for details on these metrics. + .. note:: + `'kulsinski'` is deprecated from SciPy 1.8. Use `'kulczynski1'` instead. + metric_kwargs : dict, default=None Keyword arguments to pass to specified metric function. @@ -1639,7 +1645,7 @@ def _pairwise_callable(X, Y, metric, force_all_finite=True, **kwds): "dice", "hamming", "jaccard", - "kulsinski", + "kulsinski" if sp_version < parse_version("1.8") else "kulczynski1", "mahalanobis", "matching", "minkowski", @@ -1902,12 +1908,15 @@ def pairwise_distances( ['nan_euclidean'] but it does not yet support sparse matrices. - From scipy.spatial.distance: ['braycurtis', 'canberra', 'chebyshev', - 'correlation', 'dice', 'hamming', 'jaccard', 'kulsinski', 'mahalanobis', - 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean', + 'correlation', 'dice', 'hamming', 'jaccard', 'kulsinski', 'kulczynski1', + 'mahalanobis', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean', 'sokalmichener', 'sokalsneath', 'sqeuclidean', 'yule'] See the documentation for scipy.spatial.distance for details on these metrics. These metrics do not support sparse matrix inputs. + .. note:: + `'kulsinski'` is deprecated from SciPy 1.8. Use `'kulczynski1'` instead. + Note that in the case of 'cityblock', 'cosine' and 'euclidean' (which are valid scipy.spatial.distance metrics), the scikit-learn implementation will be used, which is faster and has support for sparse matrices (except @@ -2043,7 +2052,7 @@ def pairwise_distances( PAIRWISE_BOOLEAN_FUNCTIONS = [ "dice", "jaccard", - "kulsinski", + "kulsinski" if sp_version < parse_version("1.8") else "kulczynski1", "matching", "rogerstanimoto", "russellrao", diff --git a/sklearn/neighbors/_base.py b/sklearn/neighbors/_base.py index 3b01824a3a73a..dc8c450f61c3d 100644 --- a/sklearn/neighbors/_base.py +++ b/sklearn/neighbors/_base.py @@ -57,7 +57,7 @@ "dice", "hamming", "jaccard", - "kulsinski", + "kulsinski" if sp_version < parse_version("1.8") else "kulczynski1", "mahalanobis", "matching", "minkowski", diff --git a/sklearn/neighbors/tests/test_ball_tree.py b/sklearn/neighbors/tests/test_ball_tree.py index d5046afd2da2a..371804c7d2c78 100644 --- a/sklearn/neighbors/tests/test_ball_tree.py +++ b/sklearn/neighbors/tests/test_ball_tree.py @@ -5,6 +5,7 @@ from numpy.testing import assert_array_almost_equal from sklearn.neighbors._ball_tree import BallTree from sklearn.utils import check_random_state +from sklearn.utils.fixes import parse_version, sp_version from sklearn.utils.validation import check_array from sklearn.utils._testing import _convert_container @@ -30,7 +31,7 @@ "matching", "jaccard", "dice", - "kulsinski", + "kulsinski" if sp_version < parse_version("1.8") else "kulczynski1", "rogerstanimoto", "russellrao", "sokalmichener", From 72bf8b43fade29fa88ae9d0d3c793f290d59bbcb Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 19 Dec 2022 10:35:12 +0100 Subject: [PATCH 02/13] [scipy-dev] trigger scipy-dev From ae64e3e9e83257d91561c35534d4339b3dfa0132 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 19 Dec 2022 11:06:49 +0100 Subject: [PATCH 03/13] FIX add alias for DistanceMetric --- sklearn/metrics/_dist_metrics.pyx.tp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sklearn/metrics/_dist_metrics.pyx.tp b/sklearn/metrics/_dist_metrics.pyx.tp index 46a87581c0249..29c33605fa615 100644 --- a/sklearn/metrics/_dist_metrics.pyx.tp +++ b/sklearn/metrics/_dist_metrics.pyx.tp @@ -106,6 +106,7 @@ METRIC_MAPPING{{name_suffix}} = { 'jaccard': JaccardDistance{{name_suffix}}, 'dice': DiceDistance{{name_suffix}}, 'kulsinski': KulsinskiDistance{{name_suffix}}, + 'kulczynski1': KulsinskiDistance{{name_suffix}}, 'rogerstanimoto': RogersTanimotoDistance{{name_suffix}}, 'russellrao': RussellRaoDistance{{name_suffix}}, 'sokalmichener': SokalMichenerDistance{{name_suffix}}, @@ -214,6 +215,7 @@ cdef class DistanceMetric{{name_suffix}}: "matching" MatchingDistance NNEQ / N "dice" DiceDistance NNEQ / (NTT + NNZ) "kulsinski" KulsinskiDistance (NNEQ + N - NTT) / (NNEQ + N) + "kulczynski1" KulsinskiDistance (NNEQ + N - NTT) / (NNEQ + N) "rogerstanimoto" RogersTanimotoDistance 2 * NNEQ / (N + NNEQ) "russellrao" RussellRaoDistance (N - NTT) / N "sokalmichener" SokalMichenerDistance 2 * NNEQ / (N + NNEQ) From 83d57286687ffa0800eca8938622719bd1867496 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 19 Dec 2022 11:06:59 +0100 Subject: [PATCH 04/13] [scipy-dev] trigger scipy-dev From 1ebe6a0431b829b87f8ceb2964e4443dead347aa Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 19 Dec 2022 12:21:30 +0100 Subject: [PATCH 05/13] Implement Kulczynski1Distance --- sklearn/metrics/_dist_metrics.pyx.tp | 100 ++++++++++++++++++++++++++- 1 file changed, 98 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/_dist_metrics.pyx.tp b/sklearn/metrics/_dist_metrics.pyx.tp index 29c33605fa615..9328f9579c113 100644 --- a/sklearn/metrics/_dist_metrics.pyx.tp +++ b/sklearn/metrics/_dist_metrics.pyx.tp @@ -106,7 +106,7 @@ METRIC_MAPPING{{name_suffix}} = { 'jaccard': JaccardDistance{{name_suffix}}, 'dice': DiceDistance{{name_suffix}}, 'kulsinski': KulsinskiDistance{{name_suffix}}, - 'kulczynski1': KulsinskiDistance{{name_suffix}}, + 'kulczynski1': Kulczynski1Distance{{name_suffix}}, 'rogerstanimoto': RogersTanimotoDistance{{name_suffix}}, 'russellrao': RussellRaoDistance{{name_suffix}}, 'sokalmichener': SokalMichenerDistance{{name_suffix}}, @@ -215,7 +215,7 @@ cdef class DistanceMetric{{name_suffix}}: "matching" MatchingDistance NNEQ / N "dice" DiceDistance NNEQ / (NTT + NNZ) "kulsinski" KulsinskiDistance (NNEQ + N - NTT) / (NNEQ + N) - "kulczynski1" KulsinskiDistance (NNEQ + N - NTT) / (NNEQ + N) + "kulczynski1" Kulczynski1Distance (NNEQ + N - NTT) / (NNEQ + N) "rogerstanimoto" RogersTanimotoDistance 2 * NNEQ / (N + NNEQ) "russellrao" RussellRaoDistance (N - NTT) / N "sokalmichener" SokalMichenerDistance 2 * NNEQ / (N + NNEQ) @@ -2315,6 +2315,102 @@ cdef class KulsinskiDistance{{name_suffix}}(DistanceMetric{{name_suffix}}): return (n_neq - n_tt + size) * 1.0 / (n_neq + size) + +#------------------------------------------------------------ +# Kulczynski Distance (boolean) +# D(x, y) = c_11 / (c_01 + c_10) +cdef class Kulczynski1Distance{{name_suffix}}(DistanceMetric{{name_suffix}}): + r"""Kulsinski Distance + + Kulczynski Distance is a dissimilarity measure for boolean-valued + vectors. All nonzero entries will be treated as True, zero entries will + be treated as False. + + D(x, y) = c_11 / (c_01 + c_10) + + """ + cdef inline DTYPE_t dist( + self, + const {{INPUT_DTYPE_t}}* x1, + const {{INPUT_DTYPE_t}}* x2, + ITYPE_t size, + ) nogil except -1: + cdef int tf1, tf2, c_11 = 0, c_01 = 0, c_10 = 0 + cdef cnp.intp_t j + for j in range(size): + tf1 = x1[j] != 0 + tf2 = x2[j] != 0 + if tf1 == 0 and tf2 == 0: + continue + elif tf1 == tf2: + c_11 += (tf1 and tf2) + elif tf1 == 0: + c_01 += tf2 + else: + c_10 += tf1 + return c_11 / (c_01 + c_10) + + cdef inline DTYPE_t dist_csr( + self, + const {{INPUT_DTYPE_t}}* x1_data, + const SPARSE_INDEX_TYPE_t[:] x1_indices, + const {{INPUT_DTYPE_t}}* x2_data, + const SPARSE_INDEX_TYPE_t[:] x2_indices, + const SPARSE_INDEX_TYPE_t x1_start, + const SPARSE_INDEX_TYPE_t x1_end, + const SPARSE_INDEX_TYPE_t x2_start, + const SPARSE_INDEX_TYPE_t x2_end, + const ITYPE_t size, + ) nogil except -1: + + cdef: + cnp.npy_intp ix1, ix2 + cnp.npy_intp i1 = x1_start + cnp.npy_intp i2 = x2_start + + ITYPE_t tf1, tf2, c_11 = 0, c_01 = 0, c_10 = 0 + + while i1 < x1_end and i2 < x2_end: + ix1 = x1_indices[i1] + ix2 = x2_indices[i2] + + tf1 = x1_data[i1] != 0 + tf2 = x2_data[i2] != 0 + + if ix1 == ix2: + if tf1 == 0 and tf2 == 0: + pass + elif tf1 == tf2: + c_11 += (tf1 and tf2) + elif tf1 == 0: + c_01 += tf2 + else: + c_10 += tf1 + i1 += 1 + i2 += 1 + elif ix1 < ix2: + # non-zero value in x1 but not in x2 + c_10 += tf1 + i1 += 1 + else: + # non-zero value in x2 but not in x1 + c_01 += tf2 + i2 += 1 + + if i1 == x1_end: + while i2 < x2_end: + tf2 = x2_data[i2] != 0 + c_01 += tf2 + i2 += 1 + else: + while i1 < x1_end: + tf1 = x1_data[i1] != 0 + c_10 += tf1 + i1 += 1 + + return c_11 / (c_10 + c_01) + + #------------------------------------------------------------ # Rogers-Tanimoto Distance (boolean) # D(x, y) = 2 * n_neq / (n + n_neq) From a24c41564d3014723baa8aa4e914366fa8410bf8 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 19 Dec 2022 12:32:03 +0100 Subject: [PATCH 06/13] add more details --- sklearn/cluster/_optics.py | 2 ++ sklearn/metrics/pairwise.py | 3 +++ 2 files changed, 5 insertions(+) diff --git a/sklearn/cluster/_optics.py b/sklearn/cluster/_optics.py index 5d40e00adf5f7..a506fcdbd6faa 100755 --- a/sklearn/cluster/_optics.py +++ b/sklearn/cluster/_optics.py @@ -92,6 +92,7 @@ class OPTICS(ClusterMixin, BaseEstimator): .. note:: `'kulsinski'` is deprecated from SciPy 1.8. Use `'kulczynski1'` instead. + Note that the two metrics are not identical. p : float, default=2 Parameter for the Minkowski metric from @@ -478,6 +479,7 @@ def compute_optics_graph( .. note:: `'kulsinski'` is deprecated from SciPy 1.8. Use `'kulczynski1'` instead. + Note that the two metrics are not identical. p : int, default=2 Parameter for the Minkowski metric from diff --git a/sklearn/metrics/pairwise.py b/sklearn/metrics/pairwise.py index d7b257805f6c6..beba81b4ac36e 100644 --- a/sklearn/metrics/pairwise.py +++ b/sklearn/metrics/pairwise.py @@ -645,6 +645,7 @@ def pairwise_distances_argmin_min( .. note:: `'kulsinski'` is deprecated from SciPy 1.8. Use `'kulczynski1'` instead. + Note that the two metrics are not identical. metric_kwargs : dict, default=None Keyword arguments to pass to specified metric function. @@ -765,6 +766,7 @@ def pairwise_distances_argmin(X, Y, *, axis=1, metric="euclidean", metric_kwargs .. note:: `'kulsinski'` is deprecated from SciPy 1.8. Use `'kulczynski1'` instead. + Note that the two metrics are not identical. metric_kwargs : dict, default=None Keyword arguments to pass to specified metric function. @@ -1916,6 +1918,7 @@ def pairwise_distances( .. note:: `'kulsinski'` is deprecated from SciPy 1.8. Use `'kulczynski1'` instead. + Note that the two metrics are not identical. Note that in the case of 'cityblock', 'cosine' and 'euclidean' (which are valid scipy.spatial.distance metrics), the scikit-learn implementation From f5d993c950521a5b2ea16f08836263d28047ded8 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 19 Dec 2022 12:32:17 +0100 Subject: [PATCH 07/13] [scipy-dev] trigger scipy-dev From 1025f7a8027b79b3545acac2439c8d8e22619fa2 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 19 Dec 2022 15:39:50 +0100 Subject: [PATCH 08/13] iter --- sklearn/metrics/_dist_metrics.pyx.tp | 7 ++++++- sklearn/metrics/pairwise.py | 14 ++++++++++++-- sklearn/neighbors/tests/test_ball_tree.py | 5 ++++- 3 files changed, 22 insertions(+), 4 deletions(-) diff --git a/sklearn/metrics/_dist_metrics.pyx.tp b/sklearn/metrics/_dist_metrics.pyx.tp index 9328f9579c113..010e37c4a3677 100644 --- a/sklearn/metrics/_dist_metrics.pyx.tp +++ b/sklearn/metrics/_dist_metrics.pyx.tp @@ -60,12 +60,17 @@ BOOL_METRICS = [ "matching", "jaccard", "dice", - "kulsinski" if sp_version < parse_version("1.8") else "kulczynski1", "rogerstanimoto", "russellrao", "sokalmichener", "sokalsneath", ] +if sp_version >= parse_version("1.8"): + # Introduced in SciPy 1.8 + BOOL_METRICS += ["kulczynski1"] +if sp_version < parse_version("1.10"): + # Deprecated in SciPy 1.8 and removed in SciPy 1.10 + BOOL_METRICS += ["kulsinski"] def get_valid_metric_ids(L): """Given an iterable of metric class names or class identifiers, diff --git a/sklearn/metrics/pairwise.py b/sklearn/metrics/pairwise.py index beba81b4ac36e..e33204e0fc821 100644 --- a/sklearn/metrics/pairwise.py +++ b/sklearn/metrics/pairwise.py @@ -1647,7 +1647,6 @@ def _pairwise_callable(X, Y, metric, force_all_finite=True, **kwds): "dice", "hamming", "jaccard", - "kulsinski" if sp_version < parse_version("1.8") else "kulczynski1", "mahalanobis", "matching", "minkowski", @@ -1662,6 +1661,12 @@ def _pairwise_callable(X, Y, metric, force_all_finite=True, **kwds): "nan_euclidean", "haversine", ] +if sp_version >= parse_version("1.8"): + # Introduced in SciPy 1.8 + _VALID_METRICS += ["kulczynski1"] +if sp_version < parse_version("1.10"): + # Deprecated in SciPy 1.8 and removed in SciPy 1.10 + _VALID_METRICS += ["kulsinski"] _NAN_METRICS = ["nan_euclidean"] @@ -2055,7 +2060,6 @@ def pairwise_distances( PAIRWISE_BOOLEAN_FUNCTIONS = [ "dice", "jaccard", - "kulsinski" if sp_version < parse_version("1.8") else "kulczynski1", "matching", "rogerstanimoto", "russellrao", @@ -2063,6 +2067,12 @@ def pairwise_distances( "sokalsneath", "yule", ] +if sp_version >= parse_version("1.8"): + # Introduced in SciPy 1.8 + PAIRWISE_BOOLEAN_FUNCTIONS += ["kulczynski1"] +if sp_version < parse_version("1.10"): + # Deprecated in SciPy 1.8 and removed in SciPy 1.10 + PAIRWISE_BOOLEAN_FUNCTIONS += ["kulsinski"] # Helper functions - distance PAIRWISE_KERNEL_FUNCTIONS = { diff --git a/sklearn/neighbors/tests/test_ball_tree.py b/sklearn/neighbors/tests/test_ball_tree.py index 371804c7d2c78..8c3ffcf5d032e 100644 --- a/sklearn/neighbors/tests/test_ball_tree.py +++ b/sklearn/neighbors/tests/test_ball_tree.py @@ -31,12 +31,15 @@ "matching", "jaccard", "dice", - "kulsinski" if sp_version < parse_version("1.8") else "kulczynski1", "rogerstanimoto", "russellrao", "sokalmichener", "sokalsneath", ] +if sp_version >= parse_version("1.8"): + BOOLEAN_METRICS += ["kulczynski1"] +if sp_version < parse_version("1.10"): + BOOLEAN_METRICS += ["kulsinski"] def brute_force_neighbors(X, Y, k, metric, **kwargs): From aba5a12b41150873051ec8b6bc98ca76f23ef6ab Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 19 Dec 2022 15:54:58 +0100 Subject: [PATCH 09/13] iter --- sklearn/metrics/tests/test_pairwise.py | 2 +- sklearn/neighbors/_ball_tree.pyx | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/tests/test_pairwise.py b/sklearn/metrics/tests/test_pairwise.py index 3624983c4c481..78fa5b4218fa7 100644 --- a/sklearn/metrics/tests/test_pairwise.py +++ b/sklearn/metrics/tests/test_pairwise.py @@ -225,7 +225,7 @@ def test_pairwise_boolean_distance(metric): with ignore_warnings(category=DataConversionWarning): for Z in [Y, None]: res = pairwise_distances(X, Z, metric=metric) - res[np.isnan(res)] = 0 + res = np.nan_to_num(res, nan=0, posinf=0, neginf=0) assert np.sum(res != 0) == 0 # non-boolean arrays are converted to boolean for boolean diff --git a/sklearn/neighbors/_ball_tree.pyx b/sklearn/neighbors/_ball_tree.pyx index 094a8826acfb9..3d9f02e08314b 100644 --- a/sklearn/neighbors/_ball_tree.pyx +++ b/sklearn/neighbors/_ball_tree.pyx @@ -11,7 +11,7 @@ VALID_METRICS = ['EuclideanDistance', 'SEuclideanDistance', 'MahalanobisDistance', 'HammingDistance', 'CanberraDistance', 'BrayCurtisDistance', 'JaccardDistance', 'MatchingDistance', - 'DiceDistance', 'KulsinskiDistance', + 'DiceDistance', 'KulsinskiDistance', 'Kulczynski1Distance', 'RogersTanimotoDistance', 'RussellRaoDistance', 'SokalMichenerDistance', 'SokalSneathDistance', 'PyFuncDistance', 'HaversineDistance'] From 3cc450a6b8eb4d2abd30662c17118f6aa673866c Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 19 Dec 2022 16:04:32 +0100 Subject: [PATCH 10/13] iter --- sklearn/metrics/_dist_metrics.pyx.tp | 2 +- sklearn/metrics/pairwise.py | 4 ++-- sklearn/neighbors/tests/test_ball_tree.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sklearn/metrics/_dist_metrics.pyx.tp b/sklearn/metrics/_dist_metrics.pyx.tp index 010e37c4a3677..0ed4f196474c7 100644 --- a/sklearn/metrics/_dist_metrics.pyx.tp +++ b/sklearn/metrics/_dist_metrics.pyx.tp @@ -68,7 +68,7 @@ BOOL_METRICS = [ if sp_version >= parse_version("1.8"): # Introduced in SciPy 1.8 BOOL_METRICS += ["kulczynski1"] -if sp_version < parse_version("1.10"): +if sp_version < parse_version("1.11"): # Deprecated in SciPy 1.8 and removed in SciPy 1.10 BOOL_METRICS += ["kulsinski"] diff --git a/sklearn/metrics/pairwise.py b/sklearn/metrics/pairwise.py index e33204e0fc821..3f18377211da4 100644 --- a/sklearn/metrics/pairwise.py +++ b/sklearn/metrics/pairwise.py @@ -1664,7 +1664,7 @@ def _pairwise_callable(X, Y, metric, force_all_finite=True, **kwds): if sp_version >= parse_version("1.8"): # Introduced in SciPy 1.8 _VALID_METRICS += ["kulczynski1"] -if sp_version < parse_version("1.10"): +if sp_version < parse_version("1.11"): # Deprecated in SciPy 1.8 and removed in SciPy 1.10 _VALID_METRICS += ["kulsinski"] @@ -2070,7 +2070,7 @@ def pairwise_distances( if sp_version >= parse_version("1.8"): # Introduced in SciPy 1.8 PAIRWISE_BOOLEAN_FUNCTIONS += ["kulczynski1"] -if sp_version < parse_version("1.10"): +if sp_version < parse_version("1.11"): # Deprecated in SciPy 1.8 and removed in SciPy 1.10 PAIRWISE_BOOLEAN_FUNCTIONS += ["kulsinski"] diff --git a/sklearn/neighbors/tests/test_ball_tree.py b/sklearn/neighbors/tests/test_ball_tree.py index 8c3ffcf5d032e..01df3859685db 100644 --- a/sklearn/neighbors/tests/test_ball_tree.py +++ b/sklearn/neighbors/tests/test_ball_tree.py @@ -38,7 +38,7 @@ ] if sp_version >= parse_version("1.8"): BOOLEAN_METRICS += ["kulczynski1"] -if sp_version < parse_version("1.10"): +if sp_version < parse_version("1.11"): BOOLEAN_METRICS += ["kulsinski"] From 8f9088040401c27564a785019f974d5be91b52f4 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 5 Jan 2023 17:05:23 +0100 Subject: [PATCH 11/13] remove support in BallTree --- doc/whats_new/v1.3.rst | 6 ++++++ sklearn/metrics/tests/test_pairwise.py | 2 +- sklearn/neighbors/_ball_tree.pyx | 2 +- sklearn/neighbors/tests/test_ball_tree.py | 5 ----- 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index 68a569acb14e5..1d31d94df7b6b 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -36,6 +36,12 @@ Changelog :pr:`123456` by :user:`Joe Bloggs `. where 123456 is the *pull request* number, not the issue number. +:mod:`sklearn.neighbors` +........................ +- |Fix| Remove support for `"KulsinskiDistance"` in :class:`neighbors.BallTree`. This + metric is not a proper metric and cannot be supported by the BallTree. + :pr:`25212` by :user:`Guillaume Lemaitre `. + :mod:`sklearn.pipeline` ....................... - |Feature| :class:`pipeline.FeatureUnion` can now use indexing notation (e.g. diff --git a/sklearn/metrics/tests/test_pairwise.py b/sklearn/metrics/tests/test_pairwise.py index 78fa5b4218fa7..a9800ea772feb 100644 --- a/sklearn/metrics/tests/test_pairwise.py +++ b/sklearn/metrics/tests/test_pairwise.py @@ -225,7 +225,7 @@ def test_pairwise_boolean_distance(metric): with ignore_warnings(category=DataConversionWarning): for Z in [Y, None]: res = pairwise_distances(X, Z, metric=metric) - res = np.nan_to_num(res, nan=0, posinf=0, neginf=0) + np.nan_to_num(res, nan=0, posinf=0, neginf=0, copy=False) assert np.sum(res != 0) == 0 # non-boolean arrays are converted to boolean for boolean diff --git a/sklearn/neighbors/_ball_tree.pyx b/sklearn/neighbors/_ball_tree.pyx index 3d9f02e08314b..6dcc6afa2127d 100644 --- a/sklearn/neighbors/_ball_tree.pyx +++ b/sklearn/neighbors/_ball_tree.pyx @@ -11,7 +11,7 @@ VALID_METRICS = ['EuclideanDistance', 'SEuclideanDistance', 'MahalanobisDistance', 'HammingDistance', 'CanberraDistance', 'BrayCurtisDistance', 'JaccardDistance', 'MatchingDistance', - 'DiceDistance', 'KulsinskiDistance', 'Kulczynski1Distance', + 'DiceDistance', 'RogersTanimotoDistance', 'RussellRaoDistance', 'SokalMichenerDistance', 'SokalSneathDistance', 'PyFuncDistance', 'HaversineDistance'] diff --git a/sklearn/neighbors/tests/test_ball_tree.py b/sklearn/neighbors/tests/test_ball_tree.py index 01df3859685db..8d665f799e9d8 100644 --- a/sklearn/neighbors/tests/test_ball_tree.py +++ b/sklearn/neighbors/tests/test_ball_tree.py @@ -5,7 +5,6 @@ from numpy.testing import assert_array_almost_equal from sklearn.neighbors._ball_tree import BallTree from sklearn.utils import check_random_state -from sklearn.utils.fixes import parse_version, sp_version from sklearn.utils.validation import check_array from sklearn.utils._testing import _convert_container @@ -36,10 +35,6 @@ "sokalmichener", "sokalsneath", ] -if sp_version >= parse_version("1.8"): - BOOLEAN_METRICS += ["kulczynski1"] -if sp_version < parse_version("1.11"): - BOOLEAN_METRICS += ["kulsinski"] def brute_force_neighbors(X, Y, k, metric, **kwargs): From 44b4cbe4f55dab85f94d4af8c3152e7dd2f61d4a Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 5 Jan 2023 17:08:05 +0100 Subject: [PATCH 12/13] doc glitch whats new --- doc/whats_new/v1.3.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index 0c36661a3022b..69e7c7a9331a6 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -142,6 +142,6 @@ Code and Documentation Contributors ----------------------------------- Thanks to everyone who has contributed to the maintenance and improvement of -the project since version 1.1, including: +the project since version 1.2, including: TODO: update at the time of the release. From 5cf81f1f145e0b5f89c08efdaedcdcd6b1f7b7cb Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 5 Jan 2023 17:08:51 +0100 Subject: [PATCH 13/13] Update sklearn/metrics/_dist_metrics.pyx.tp Co-authored-by: Thomas J. Fan --- sklearn/metrics/_dist_metrics.pyx.tp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/_dist_metrics.pyx.tp b/sklearn/metrics/_dist_metrics.pyx.tp index 0ed4f196474c7..dbfb241aa2799 100644 --- a/sklearn/metrics/_dist_metrics.pyx.tp +++ b/sklearn/metrics/_dist_metrics.pyx.tp @@ -220,7 +220,7 @@ cdef class DistanceMetric{{name_suffix}}: "matching" MatchingDistance NNEQ / N "dice" DiceDistance NNEQ / (NTT + NNZ) "kulsinski" KulsinskiDistance (NNEQ + N - NTT) / (NNEQ + N) - "kulczynski1" Kulczynski1Distance (NNEQ + N - NTT) / (NNEQ + N) + "kulczynski1" Kulczynski1Distance NTT / NNEQ "rogerstanimoto" RogersTanimotoDistance 2 * NNEQ / (N + NNEQ) "russellrao" RussellRaoDistance (N - NTT) / N "sokalmichener" SokalMichenerDistance 2 * NNEQ / (N + NNEQ)