Skip to content

MAINT Handle deprecation of sokalmichener metric #30004

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion sklearn/metrics/_dist_metrics.pyx.tp
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,24 @@ BOOL_METRICS = [
"dice",
"rogerstanimoto",
"russellrao",
"sokalmichener",
"sokalsneath",
]
DEPRECATED_METRICS = []
if sp_base_version < parse_version("1.17"):
# Deprecated in SciPy 1.15 and removed in SciPy 1.17
BOOL_METRICS += ["sokalmichener"]
if sp_base_version >= parse_version("1.15"):
DEPRECATED_METRICS.append("sokalmichener")
if sp_base_version < parse_version("1.11"):
# Deprecated in SciPy 1.9 and removed in SciPy 1.11
BOOL_METRICS += ["kulsinski"]
if sp_base_version >= parse_version("1.9"):
DEPRECATED_METRICS.append("kulsinski")
if sp_base_version < parse_version("1.9"):
# Deprecated in SciPy 1.0 and removed in SciPy 1.9
BOOL_METRICS += ["matching"]
if sp_base_version >= parse_version("1.0"):
DEPRECATED_METRICS.append("matching")

def get_valid_metric_ids(L):
"""Given an iterable of metric class names or class identifiers,
Expand Down
8 changes: 6 additions & 2 deletions sklearn/metrics/pairwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,14 +693,16 @@ def _argmin_reduce(dist, start):
"rogerstanimoto",
"russellrao",
"seuclidean",
"sokalmichener",
"sokalsneath",
"sqeuclidean",
"yule",
"wminkowski",
"nan_euclidean",
"haversine",
]
if sp_base_version < parse_version("1.17"): # pragma: no cover
# Deprecated in SciPy 1.15 and removed in SciPy 1.17
_VALID_METRICS += ["sokalmichener"]
if sp_base_version < parse_version("1.11"): # pragma: no cover
# Deprecated in SciPy 1.9 and removed in SciPy 1.11
_VALID_METRICS += ["kulsinski"]
Expand Down Expand Up @@ -2482,10 +2484,12 @@ def pairwise_distances(
"jaccard",
"rogerstanimoto",
"russellrao",
"sokalmichener",
"sokalsneath",
"yule",
]
if sp_base_version < parse_version("1.17"):
# Deprecated in SciPy 1.15 and removed in SciPy 1.17
PAIRWISE_BOOLEAN_FUNCTIONS += ["sokalmichener"]
if sp_base_version < parse_version("1.11"):
# Deprecated in SciPy 1.9 and removed in SciPy 1.11
PAIRWISE_BOOLEAN_FUNCTIONS += ["kulsinski"]
Expand Down
28 changes: 25 additions & 3 deletions sklearn/metrics/tests/test_dist_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,16 @@
from sklearn.metrics import DistanceMetric
from sklearn.metrics._dist_metrics import (
BOOL_METRICS,
DEPRECATED_METRICS,
DistanceMetric32,
DistanceMetric64,
)
from sklearn.utils import check_random_state
from sklearn.utils._testing import assert_allclose, create_memmap_backed_data
from sklearn.utils._testing import (
assert_allclose,
create_memmap_backed_data,
ignore_warnings,
)
from sklearn.utils.fixes import CSR_CONTAINERS, parse_version, sp_version


Expand Down Expand Up @@ -112,7 +117,15 @@ def test_cdist(metric_param_grid, X, Y, csr_container):
)
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
def test_cdist_bool_metric(metric, X_bool, Y_bool, csr_container):
D_scipy_cdist = cdist(X_bool, Y_bool, metric)
if metric in DEPRECATED_METRICS:
with ignore_warnings(category=DeprecationWarning):
# Some metrics can be deprecated depending on the scipy version.
# But if they are present, we still want to test wether
# scikit-learn gives the same result, whether or not they are
# deprecated.
D_scipy_cdist = cdist(X_bool, Y_bool, metric)
else:
D_scipy_cdist = cdist(X_bool, Y_bool, metric)

dm = DistanceMetric.get_metric(metric)
D_sklearn = dm.pairwise(X_bool, Y_bool)
Expand Down Expand Up @@ -219,7 +232,16 @@ def test_distance_metrics_dtype_consistency(metric_param_grid):
@pytest.mark.parametrize("X_bool", [X_bool, X_bool_mmap])
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
def test_pdist_bool_metrics(metric, X_bool, csr_container):
D_scipy_pdist = cdist(X_bool, X_bool, metric)
if metric in DEPRECATED_METRICS:
with ignore_warnings(category=DeprecationWarning):
# Some metrics can be deprecated depending on the scipy version.
# But if they are present, we still want to test wether
# scikit-learn gives the same result, whether or not they are
# deprecated.
D_scipy_pdist = cdist(X_bool, X_bool, metric)
else:
D_scipy_pdist = cdist(X_bool, X_bool, metric)

dm = DistanceMetric.get_metric(metric)
D_sklearn = dm.pairwise(X_bool)
assert_allclose(D_sklearn, D_scipy_pdist)
Expand Down
3 changes: 3 additions & 0 deletions sklearn/metrics/tests/test_pairwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,9 @@ def test_pairwise_distances_for_sparse_data(
pairwise_distances(X, Y_sparse, metric="minkowski")


# Some scipy metrics are deprecated (depending on the scipy version) but we
# still want to test them.
@ignore_warnings(category=DeprecationWarning)
@pytest.mark.parametrize("metric", PAIRWISE_BOOLEAN_FUNCTIONS)
def test_pairwise_boolean_distance(metric):
# test that we convert to boolean arrays for boolean distances
Expand Down
4 changes: 3 additions & 1 deletion sklearn/neighbors/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,13 @@
"rogerstanimoto",
"russellrao",
"seuclidean",
"sokalmichener",
"sokalsneath",
"sqeuclidean",
"yule",
]
if sp_base_version < parse_version("1.17"):
# Deprecated in SciPy 1.15 and removed in SciPy 1.17
SCIPY_METRICS += ["sokalmichener"]
if sp_base_version < parse_version("1.11"):
# Deprecated in SciPy 1.9 and removed in SciPy 1.11
SCIPY_METRICS += ["kulsinski"]
Expand Down
8 changes: 8 additions & 0 deletions sklearn/neighbors/tests/test_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1721,6 +1721,10 @@ def test_neighbors_metrics(
assert_array_equal(ball_tree_idx, kd_tree_idx)


# TODO: Remove ignore_warnings when minimum supported SciPy version is 1.17
# Some scipy metrics are deprecated (depending on the scipy version) but we
# still want to test them.
@ignore_warnings(category=DeprecationWarning)
@pytest.mark.parametrize(
"metric", sorted(set(neighbors.VALID_METRICS["brute"]) - set(["precomputed"]))
)
Expand Down Expand Up @@ -2243,6 +2247,10 @@ def test_auto_algorithm(X, metric, metric_params, expected_algo):
assert model._fit_method == expected_algo


# TODO: Remove ignore_warnings when minimum supported SciPy version is 1.17
# Some scipy metrics are deprecated (depending on the scipy version) but we
# still want to test them.
@ignore_warnings(category=DeprecationWarning)
@pytest.mark.parametrize(
"metric", sorted(set(neighbors.VALID_METRICS["brute"]) - set(["precomputed"]))
)
Expand Down