diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py index f9b4b9fb242fe..75f497315ff01 100644 --- a/sklearn/metrics/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -1,15 +1,14 @@ import itertools import re import warnings -from collections import defaultdict -from math import floor, log10 +from functools import partial import numpy as np import pytest import threadpoolctl from scipy.spatial.distance import cdist -from sklearn.metrics import euclidean_distances +from sklearn.metrics import euclidean_distances, pairwise_distances from sklearn.metrics._pairwise_distances_reduction import ( ArgKmin, ArgKminClassMode, @@ -66,144 +65,194 @@ def _get_metric_params_list(metric: str, n_features: int, seed: int = 1): return [{}] -def assert_argkmin_results_equality(ref_dist, dist, ref_indices, indices, rtol=1e-7): - assert_array_equal( - ref_indices, - indices, - err_msg="Query vectors have different neighbors' indices", - ) - assert_allclose( - ref_dist, - dist, - err_msg="Query vectors have different neighbors' distances", - rtol=rtol, - ) - +def assert_same_distances_for_common_neighbors( + query_idx, + dist_row_a, + dist_row_b, + indices_row_a, + indices_row_b, + rtol, + atol, +): + """Check that the distances of common neighbors are equal up to tolerance. -def relative_rounding(scalar, n_significant_digits): - """Round a scalar to a number of significant digits relatively to its value.""" - if scalar == 0: - return 0.0 - magnitude = int(floor(log10(abs(scalar)))) + 1 - return round(scalar, n_significant_digits - magnitude) + This does not check if there are missing neighbors in either result set. + Missingness is handled by assert_no_missing_neighbors. + """ + # Compute a mapping from indices to distances for each result set and + # check that the computed neighbors with matching indices are within + # the expected distance tolerance. + indices_to_dist_a = dict(zip(indices_row_a, dist_row_a)) + indices_to_dist_b = dict(zip(indices_row_b, dist_row_b)) + + common_indices = set(indices_row_a).intersection(set(indices_row_b)) + for idx in common_indices: + dist_a = indices_to_dist_a[idx] + dist_b = indices_to_dist_b[idx] + try: + assert_allclose(dist_a, dist_b, rtol=rtol, atol=atol) + except AssertionError as e: + # Wrap exception to provide more context while also including + # the original exception with the computed absolute and + # relative differences. + raise AssertionError( + f"Query vector with index {query_idx} lead to different distances" + f" for common neighbor with index {idx}:" + f" dist_a={dist_a} vs dist_b={dist_b} (with atol={atol} and" + f" rtol={rtol})" + ) from e + + +def assert_no_missing_neighbors( + query_idx, + dist_row_a, + dist_row_b, + indices_row_a, + indices_row_b, + threshold, +): + """Compare the indices of neighbors in two results sets. + Any neighbor index with a distance below the precision threshold should + match one in the other result set. We ignore the last few neighbors beyond + the threshold as those can typically be missing due to rounding errors. -def test_relative_rounding(): - assert relative_rounding(0, 1) == 0.0 - assert relative_rounding(0, 10) == 0.0 - assert relative_rounding(0, 123456) == 0.0 + For radius queries, the threshold is just the radius minus the expected + precision level. - assert relative_rounding(123456789, 0) == 0 - assert relative_rounding(123456789, 2) == 120000000 - assert relative_rounding(123456789, 3) == 123000000 - assert relative_rounding(123456789, 10) == 123456789 - assert relative_rounding(123456789, 20) == 123456789 + For k-NN queries, it is the maxium distance to the k-th neighbor minus the + expected precision level. + """ + mask_a = dist_row_a < threshold + mask_b = dist_row_b < threshold + missing_from_b = np.setdiff1d(indices_row_a[mask_a], indices_row_b) + missing_from_a = np.setdiff1d(indices_row_b[mask_b], indices_row_a) + if len(missing_from_a) > 0 or len(missing_from_b) > 0: + raise AssertionError( + f"Query vector with index {query_idx} lead to mismatched result indices:\n" + f"neighors in b missing from a: {missing_from_a}\n" + f"neighors in a missing from b: {missing_from_b}\n" + f"dist_row_a={dist_row_a}\n" + f"dist_row_b={dist_row_b}\n" + f"indices_row_a={indices_row_a}\n" + f"indices_row_b={indices_row_b}\n" + ) - assert relative_rounding(1.23456789, 2) == 1.2 - assert relative_rounding(1.23456789, 3) == 1.23 - assert relative_rounding(1.23456789, 10) == 1.23456789 - assert relative_rounding(123.456789, 3) == 123.0 - assert relative_rounding(123.456789, 9) == 123.456789 - assert relative_rounding(123.456789, 10) == 123.456789 +def assert_compatible_argkmin_results( + neighbors_dists_a, + neighbors_dists_b, + neighbors_indices_a, + neighbors_indices_b, + rtol=1e-5, + atol=1e-6, +): + """Assert that argkmin results are valid up to rounding errors. + This function asserts that the results of argkmin queries are valid up to: + - rounding error tolerance on distance values; + - permutations of indices for distances values that differ up to the + expected precision level. -def assert_argkmin_results_quasi_equality( - ref_dist, - dist, - ref_indices, - indices, - rtol=1e-4, -): - """Assert that argkmin results are valid up to: - - relative tolerance on computed distance values - - permutations of indices for distances values that differ up to - a precision level + Furthermore, the distances must be sorted. - To be used for testing neighbors queries on float32 datasets: we - accept neighbors rank swaps only if they are caused by small - rounding errors on the distance computations. + To be used for testing neighbors queries on float32 datasets: we accept + neighbors rank swaps only if they are caused by small rounding errors on + the distance computations. """ is_sorted = lambda a: np.all(a[:-1] <= a[1:]) - n_significant_digits = -(int(floor(log10(abs(rtol)))) + 1) - assert ( - ref_dist.shape == dist.shape == ref_indices.shape == indices.shape - ), "Arrays of results have various shapes." + neighbors_dists_a.shape + == neighbors_dists_b.shape + == neighbors_indices_a.shape + == neighbors_indices_b.shape + ), "Arrays of results have incompatible shapes." - n_queries, n_neighbors = ref_dist.shape + n_queries, _ = neighbors_dists_a.shape # Asserting equality results one row at a time for query_idx in range(n_queries): - ref_dist_row = ref_dist[query_idx] - dist_row = dist[query_idx] - - assert is_sorted( - ref_dist_row - ), f"Reference distances aren't sorted on row {query_idx}" - assert is_sorted(dist_row), f"Distances aren't sorted on row {query_idx}" - - assert_allclose(ref_dist_row, dist_row, rtol=rtol) - - ref_indices_row = ref_indices[query_idx] - indices_row = indices[query_idx] - - # Grouping indices by distances using sets on a rounded distances up - # to a given number of decimals of significant digits derived from rtol. - reference_neighbors_groups = defaultdict(set) - effective_neighbors_groups = defaultdict(set) + dist_row_a = neighbors_dists_a[query_idx] + dist_row_b = neighbors_dists_b[query_idx] + indices_row_a = neighbors_indices_a[query_idx] + indices_row_b = neighbors_indices_b[query_idx] + + assert is_sorted(dist_row_a), f"Distances aren't sorted on row {query_idx}" + assert is_sorted(dist_row_b), f"Distances aren't sorted on row {query_idx}" + + assert_same_distances_for_common_neighbors( + query_idx, + dist_row_a, + dist_row_b, + indices_row_a, + indices_row_b, + rtol, + atol, + ) - for neighbor_rank in range(n_neighbors): - rounded_dist = relative_rounding( - ref_dist_row[neighbor_rank], - n_significant_digits=n_significant_digits, - ) - reference_neighbors_groups[rounded_dist].add(ref_indices_row[neighbor_rank]) - effective_neighbors_groups[rounded_dist].add(indices_row[neighbor_rank]) - - # Asserting equality of groups (sets) for each distance - msg = ( - f"Neighbors indices for query {query_idx} are not matching " - f"when rounding distances at {n_significant_digits} significant digits " - f"derived from rtol={rtol:.1e}" + # Check that any neighbor with distances below the rounding error + # threshold have matching indices. The threshold is the distance to the + # k-th neighbors minus the expected precision level: + # + # (1 - rtol) * dist_k - atol + # + # Where dist_k is defined as the maxium distance to the kth-neighbor + # among the two result sets. This way of defining the threshold is + # stricter than taking the minimum of the two. + threshold = (1 - rtol) * np.maximum( + np.max(dist_row_a), np.max(dist_row_b) + ) - atol + assert_no_missing_neighbors( + query_idx, + dist_row_a, + dist_row_b, + indices_row_a, + indices_row_b, + threshold, ) - for rounded_distance in reference_neighbors_groups.keys(): - assert ( - reference_neighbors_groups[rounded_distance] - == effective_neighbors_groups[rounded_distance] - ), msg -def assert_radius_neighbors_results_equality( - ref_dist, dist, ref_indices, indices, radius +def _non_trivial_radius( + *, + X=None, + Y=None, + metric=None, + precomputed_dists=None, + expected_n_neighbors=10, + n_subsampled_queries=10, + **metric_kwargs, ): - # We get arrays of arrays and we need to check for individual pairs - for i in range(ref_dist.shape[0]): - assert (ref_dist[i] <= radius).all() - assert_array_equal( - ref_indices[i], - indices[i], - err_msg=f"Query vector #{i} has different neighbors' indices", - ) - assert_allclose( - ref_dist[i], - dist[i], - err_msg=f"Query vector #{i} has different neighbors' distances", - rtol=1e-7, - ) + # Find a non-trivial radius using a small subsample of the pairwise + # distances between X and Y: we want to return around expected_n_neighbors + # on average. Yielding too many results would make the test slow (because + # checking the results is expensive for large result sets), yielding 0 most + # of the time would make the test useless. + if precomputed_dists is None and metric is None: + raise ValueError("Either metric or dists must be provided") + if precomputed_dists is None: + assert X is not None + assert Y is not None + sampled_dists = pairwise_distances(X, Y, metric=metric, **metric_kwargs) + else: + sampled_dists = precomputed_dists[:n_subsampled_queries].copy() + sampled_dists.sort(axis=1) + return sampled_dists[:, expected_n_neighbors].mean() -def assert_radius_neighbors_results_quasi_equality( - ref_dist, - dist, - ref_indices, - indices, +def assert_compatible_radius_results( + neighbors_dists_a, + neighbors_dists_b, + neighbors_indices_a, + neighbors_indices_b, radius, - rtol=1e-4, + check_sorted=True, + rtol=1e-5, + atol=1e-6, ): """Assert that radius neighborhood results are valid up to: - - relative tolerance on computed distance values + + - relative and absolute tolerance on computed distance values - permutations of indices for distances values that differ up to a precision level - missing or extra last elements if their distance is @@ -217,101 +266,92 @@ def assert_radius_neighbors_results_quasi_equality( """ is_sorted = lambda a: np.all(a[:-1] <= a[1:]) - n_significant_digits = -(int(floor(log10(abs(rtol)))) + 1) - assert ( - len(ref_dist) == len(dist) == len(ref_indices) == len(indices) - ), "Arrays of results have various lengths." + len(neighbors_dists_a) + == len(neighbors_dists_b) + == len(neighbors_indices_a) + == len(neighbors_indices_b) + ) - n_queries = len(ref_dist) + n_queries = len(neighbors_dists_a) # Asserting equality of results one vector at a time for query_idx in range(n_queries): - ref_dist_row = ref_dist[query_idx] - dist_row = dist[query_idx] - - assert is_sorted( - ref_dist_row - ), f"Reference distances aren't sorted on row {query_idx}" - assert is_sorted(dist_row), f"Distances aren't sorted on row {query_idx}" - - # Vectors' lengths might be different due to small - # numerical differences of distance w.r.t the `radius` threshold. - largest_row = ref_dist_row if len(ref_dist_row) > len(dist_row) else dist_row - - # For the longest distances vector, we check that last extra elements - # that aren't present in the other vector are all in: [radius ± rtol] - min_length = min(len(ref_dist_row), len(dist_row)) - last_extra_elements = largest_row[min_length:] - if last_extra_elements.size > 0: - assert np.all(radius - rtol <= last_extra_elements <= radius + rtol), ( - f"The last extra elements ({last_extra_elements}) aren't in [radius ±" - f" rtol]=[{radius} ± {rtol}]" + dist_row_a = neighbors_dists_a[query_idx] + dist_row_b = neighbors_dists_b[query_idx] + indices_row_a = neighbors_indices_a[query_idx] + indices_row_b = neighbors_indices_b[query_idx] + + if check_sorted: + assert is_sorted(dist_row_a), f"Distances aren't sorted on row {query_idx}" + assert is_sorted(dist_row_b), f"Distances aren't sorted on row {query_idx}" + + assert len(dist_row_a) == len(indices_row_a) + assert len(dist_row_b) == len(indices_row_b) + + # Check that all distances are within the requested radius + if len(dist_row_a) > 0: + max_dist_a = np.max(dist_row_a) + assert max_dist_a <= radius, ( + f"Largest returned distance {max_dist_a} not within requested" + f" radius {radius} on row {query_idx}" + ) + if len(dist_row_b) > 0: + max_dist_b = np.max(dist_row_b) + assert max_dist_b <= radius, ( + f"Largest returned distance {max_dist_b} not within requested" + f" radius {radius} on row {query_idx}" ) - # We truncate the neighbors results list on the smallest length to - # be able to compare them, ignoring the elements checked above. - ref_dist_row = ref_dist_row[:min_length] - dist_row = dist_row[:min_length] - - assert_allclose(ref_dist_row, dist_row, rtol=rtol) - - ref_indices_row = ref_indices[query_idx] - indices_row = indices[query_idx] - - # Grouping indices by distances using sets on a rounded distances up - # to a given number of significant digits derived from rtol. - reference_neighbors_groups = defaultdict(set) - effective_neighbors_groups = defaultdict(set) + assert_same_distances_for_common_neighbors( + query_idx, + dist_row_a, + dist_row_b, + indices_row_a, + indices_row_b, + rtol, + atol, + ) - for neighbor_rank in range(min_length): - rounded_dist = relative_rounding( - ref_dist_row[neighbor_rank], - n_significant_digits=n_significant_digits, - ) - reference_neighbors_groups[rounded_dist].add(ref_indices_row[neighbor_rank]) - effective_neighbors_groups[rounded_dist].add(indices_row[neighbor_rank]) - - # Asserting equality of groups (sets) for each distance - msg = ( - f"Neighbors indices for query {query_idx} are not matching " - f"when rounding distances at {n_significant_digits} significant digits " - f"derived from rtol={rtol:.1e}" + threshold = (1 - rtol) * radius - atol + assert_no_missing_neighbors( + query_idx, + dist_row_a, + dist_row_b, + indices_row_a, + indices_row_b, + threshold, ) - for rounded_distance in reference_neighbors_groups.keys(): - assert ( - reference_neighbors_groups[rounded_distance] - == effective_neighbors_groups[rounded_distance] - ), msg +FLOAT32_TOLS = { + "atol": 1e-7, + "rtol": 1e-5, +} +FLOAT64_TOLS = { + "atol": 1e-9, + "rtol": 1e-7, +} ASSERT_RESULT = { - # In the case of 64bit, we test for exact equality of the results rankings - # and standard tolerance levels for the computed distance values. - # - # XXX: Note that in the future we might be interested in using quasi equality - # checks also for float64 data (with a larger number of significant digits) - # as the tests could be unstable because of numerically tied distances on - # some datasets (e.g. uniform grids). - (ArgKmin, np.float64): assert_argkmin_results_equality, + (ArgKmin, np.float64): partial(assert_compatible_argkmin_results, **FLOAT64_TOLS), + (ArgKmin, np.float32): partial(assert_compatible_argkmin_results, **FLOAT32_TOLS), ( RadiusNeighbors, np.float64, - ): assert_radius_neighbors_results_equality, - # In the case of 32bit, indices can be permuted due to small difference - # in the computations of their associated distances, hence we test equality of - # results up to valid permutations. - (ArgKmin, np.float32): assert_argkmin_results_quasi_equality, + ): partial(assert_compatible_radius_results, **FLOAT64_TOLS), ( RadiusNeighbors, np.float32, - ): assert_radius_neighbors_results_quasi_equality, + ): partial(assert_compatible_radius_results, **FLOAT32_TOLS), } -def test_assert_argkmin_results_quasi_equality(): - rtol = 1e-7 - eps = 1e-7 +def test_assert_compatible_argkmin_results(): + atol = 1e-7 + rtol = 0.0 + tols = dict(atol=atol, rtol=rtol) + + eps = atol / 3 _1m = 1.0 - eps _1p = 1.0 + eps @@ -332,72 +372,128 @@ def test_assert_argkmin_results_quasi_equality(): ) # Sanity check: compare the reference results to themselves. - assert_argkmin_results_quasi_equality( + assert_compatible_argkmin_results( ref_dist, ref_dist, ref_indices, ref_indices, rtol ) - # Apply valid permutation on indices: the last 3 points are - # all very close to one another so we accept any permutation - # on their rankings. - assert_argkmin_results_quasi_equality( + # Apply valid permutation on indices: the last 3 points are all very close + # to one another so we accept any permutation on their rankings. + assert_compatible_argkmin_results( + np.array([[1.2, 2.5, _6_1m, 6.1, _6_1p]]), np.array([[1.2, 2.5, _6_1m, 6.1, _6_1p]]), - np.array([[1.2, 2.5, 6.1, 6.1, 6.1]]), np.array([[1, 2, 3, 4, 5]]), - np.array([[1, 2, 4, 5, 3]]), - rtol=rtol, + np.array([[1, 2, 5, 4, 3]]), + **tols, ) - # All points are have close distances so any ranking permutation + + # The last few indices do not necessarily have to match because of the rounding + # errors on the distances: there could be tied results at the boundary. + assert_compatible_argkmin_results( + np.array([[1.2, 2.5, 3.0, 6.1, _6_1p]]), + np.array([[1.2, 2.5, 3.0, _6_1m, 6.1]]), + np.array([[1, 2, 3, 4, 5]]), + np.array([[1, 2, 3, 6, 7]]), + **tols, + ) + + # All points have close distances so any ranking permutation # is valid for this query result. - assert_argkmin_results_quasi_equality( - np.array([[_1m, _1m, 1, _1p, _1p]]), - np.array([[_1m, _1m, 1, _1p, _1p]]), - np.array([[6, 7, 8, 9, 10]]), + assert_compatible_argkmin_results( + np.array([[_1m, 1, _1p, _1p, _1p]]), + np.array([[1, 1, 1, 1, _1p]]), + np.array([[7, 6, 8, 10, 9]]), np.array([[6, 9, 7, 8, 10]]), - rtol=rtol, + **tols, ) - # Apply invalid permutation on indices: permuting the ranks - # of the 2 nearest neighbors is invalid because the distance - # values are too different. - msg = "Neighbors indices for query 0 are not matching" + # They could also be nearly truncation of very large nearly tied result + # sets hence all indices can also be distinct in this case: + assert_compatible_argkmin_results( + np.array([[_1m, 1, _1p, _1p, _1p]]), + np.array([[_1m, 1, 1, 1, _1p]]), + np.array([[34, 30, 8, 12, 24]]), + np.array([[42, 1, 21, 13, 3]]), + **tols, + ) + + # Apply invalid permutation on indices: permuting the ranks of the 2 + # nearest neighbors is invalid because the distance values are too + # different. + msg = re.escape( + "Query vector with index 0 lead to different distances for common neighbor with" + " index 1: dist_a=1.2 vs dist_b=2.5" + ) with pytest.raises(AssertionError, match=msg): - assert_argkmin_results_quasi_equality( + assert_compatible_argkmin_results( np.array([[1.2, 2.5, _6_1m, 6.1, _6_1p]]), np.array([[1.2, 2.5, _6_1m, 6.1, _6_1p]]), np.array([[1, 2, 3, 4, 5]]), np.array([[2, 1, 3, 4, 5]]), - rtol=rtol, + **tols, ) - # Indices aren't properly sorted w.r.t their distances - msg = "Neighbors indices for query 0 are not matching" + # Detect missing indices within the expected precision level, even when the + # distances match exactly. + msg = re.escape( + "neighors in b missing from a: [12]\nneighors in a missing from b: [1]" + ) with pytest.raises(AssertionError, match=msg): - assert_argkmin_results_quasi_equality( + assert_compatible_argkmin_results( np.array([[1.2, 2.5, _6_1m, 6.1, _6_1p]]), np.array([[1.2, 2.5, _6_1m, 6.1, _6_1p]]), np.array([[1, 2, 3, 4, 5]]), - np.array([[2, 1, 4, 5, 3]]), - rtol=rtol, + np.array([[12, 2, 4, 11, 3]]), + **tols, + ) + + # Detect missing indices outside the expected precision level. + msg = re.escape( + "neighors in b missing from a: []\nneighors in a missing from b: [3]" + ) + with pytest.raises(AssertionError, match=msg): + assert_compatible_argkmin_results( + np.array([[_1m, 1.0, _6_1m, 6.1, _6_1p]]), + np.array([[1.0, 1.0, _6_1m, 6.1, 7]]), + np.array([[1, 2, 3, 4, 5]]), + np.array([[2, 1, 4, 5, 12]]), + **tols, + ) + + # Detect missing indices outside the expected precision level, in the other + # direction: + msg = re.escape( + "neighors in b missing from a: [5]\nneighors in a missing from b: []" + ) + with pytest.raises(AssertionError, match=msg): + assert_compatible_argkmin_results( + np.array([[_1m, 1.0, _6_1m, 6.1, 7]]), + np.array([[1.0, 1.0, _6_1m, 6.1, _6_1p]]), + np.array([[1, 2, 3, 4, 12]]), + np.array([[2, 1, 5, 3, 4]]), + **tols, ) # Distances aren't properly sorted msg = "Distances aren't sorted on row 0" with pytest.raises(AssertionError, match=msg): - assert_argkmin_results_quasi_equality( + assert_compatible_argkmin_results( np.array([[1.2, 2.5, _6_1m, 6.1, _6_1p]]), np.array([[2.5, 1.2, _6_1m, 6.1, _6_1p]]), np.array([[1, 2, 3, 4, 5]]), np.array([[2, 1, 4, 5, 3]]), - rtol=rtol, + **tols, ) -def test_assert_radius_neighbors_results_quasi_equality(): - rtol = 1e-7 - eps = 1e-7 +@pytest.mark.parametrize("check_sorted", [True, False]) +def test_assert_compatible_radius_results(check_sorted): + atol = 1e-7 + rtol = 0.0 + tols = dict(atol=atol, rtol=rtol) + + eps = atol / 3 _1m = 1.0 - eps _1p = 1.0 + eps - _6_1m = 6.1 - eps _6_1p = 6.1 + eps @@ -412,91 +508,143 @@ def test_assert_radius_neighbors_results_quasi_equality(): ] # Sanity check: compare the reference results to themselves. - assert_radius_neighbors_results_quasi_equality( + assert_compatible_radius_results( ref_dist, ref_dist, ref_indices, ref_indices, - radius=6.1, - rtol=rtol, + radius=7.0, + check_sorted=check_sorted, + **tols, ) # Apply valid permutation on indices - assert_radius_neighbors_results_quasi_equality( + assert_compatible_radius_results( np.array([np.array([1.2, 2.5, _6_1m, 6.1, _6_1p])]), np.array([np.array([1.2, 2.5, _6_1m, 6.1, _6_1p])]), np.array([np.array([1, 2, 3, 4, 5])]), np.array([np.array([1, 2, 4, 5, 3])]), - radius=6.1, - rtol=rtol, + radius=7.0, + check_sorted=check_sorted, + **tols, ) - assert_radius_neighbors_results_quasi_equality( + assert_compatible_radius_results( np.array([np.array([_1m, _1m, 1, _1p, _1p])]), np.array([np.array([_1m, _1m, 1, _1p, _1p])]), np.array([np.array([6, 7, 8, 9, 10])]), np.array([np.array([6, 9, 7, 8, 10])]), - radius=6.1, - rtol=rtol, + radius=7.0, + check_sorted=check_sorted, + **tols, ) # Apply invalid permutation on indices - msg = "Neighbors indices for query 0 are not matching" + msg = re.escape( + "Query vector with index 0 lead to different distances for common neighbor with" + " index 1: dist_a=1.2 vs dist_b=2.5" + ) with pytest.raises(AssertionError, match=msg): - assert_radius_neighbors_results_quasi_equality( + assert_compatible_radius_results( np.array([np.array([1.2, 2.5, _6_1m, 6.1, _6_1p])]), np.array([np.array([1.2, 2.5, _6_1m, 6.1, _6_1p])]), np.array([np.array([1, 2, 3, 4, 5])]), np.array([np.array([2, 1, 3, 4, 5])]), - radius=6.1, - rtol=rtol, + radius=7.0, + check_sorted=check_sorted, + **tols, ) - # Having extra last elements is valid if they are in: [radius ± rtol] - assert_radius_neighbors_results_quasi_equality( - np.array([np.array([1.2, 2.5, _6_1m, 6.1, _6_1p])]), + # Having extra last or missing elements is valid if they are in the + # tolerated rounding error range: [(1 - rtol) * radius - atol, radius] + assert_compatible_radius_results( + np.array([np.array([1.2, 2.5, _6_1m, 6.1, _6_1p, _6_1p])]), np.array([np.array([1.2, 2.5, _6_1m, 6.1])]), - np.array([np.array([1, 2, 3, 4, 5])]), - np.array([np.array([1, 2, 3, 4])]), - radius=6.1, - rtol=rtol, + np.array([np.array([1, 2, 3, 4, 5, 7])]), + np.array([np.array([1, 2, 3, 6])]), + radius=_6_1p, + check_sorted=check_sorted, + **tols, ) - # Having extra last elements is invalid if they are lesser than radius - rtol + # Any discrepancy outside the tolerated rounding error range is invalid and + # indicates a missing neighbor in one of the result sets. msg = re.escape( - "The last extra elements ([6.]) aren't in [radius ± rtol]=[6.1 ± 1e-07]" + "Query vector with index 0 lead to mismatched result indices:\nneighors in b" + " missing from a: []\nneighors in a missing from b: [3]" ) with pytest.raises(AssertionError, match=msg): - assert_radius_neighbors_results_quasi_equality( + assert_compatible_radius_results( np.array([np.array([1.2, 2.5, 6])]), np.array([np.array([1.2, 2.5])]), np.array([np.array([1, 2, 3])]), np.array([np.array([1, 2])]), radius=6.1, - rtol=rtol, + check_sorted=check_sorted, + **tols, + ) + msg = re.escape( + "Query vector with index 0 lead to mismatched result indices:\nneighors in b" + " missing from a: [4]\nneighors in a missing from b: [2]" + ) + with pytest.raises(AssertionError, match=msg): + assert_compatible_radius_results( + np.array([np.array([1.2, 2.1, 2.5])]), + np.array([np.array([1.2, 2, 2.5])]), + np.array([np.array([1, 2, 3])]), + np.array([np.array([1, 4, 3])]), + radius=6.1, + check_sorted=check_sorted, + **tols, ) - # Indices aren't properly sorted w.r.t their distances - msg = "Neighbors indices for query 0 are not matching" + # Radius upper bound is strictly checked + msg = re.escape( + "Largest returned distance 6.100000033333333 not within requested radius 6.1 on" + " row 0" + ) with pytest.raises(AssertionError, match=msg): - assert_radius_neighbors_results_quasi_equality( + assert_compatible_radius_results( np.array([np.array([1.2, 2.5, _6_1m, 6.1, _6_1p])]), + np.array([np.array([1.2, 2.5, _6_1m, 6.1, 6.1])]), + np.array([np.array([1, 2, 3, 4, 5])]), + np.array([np.array([2, 1, 4, 5, 3])]), + radius=6.1, + check_sorted=check_sorted, + **tols, + ) + with pytest.raises(AssertionError, match=msg): + assert_compatible_radius_results( + np.array([np.array([1.2, 2.5, _6_1m, 6.1, 6.1])]), np.array([np.array([1.2, 2.5, _6_1m, 6.1, _6_1p])]), np.array([np.array([1, 2, 3, 4, 5])]), np.array([np.array([2, 1, 4, 5, 3])]), radius=6.1, - rtol=rtol, + check_sorted=check_sorted, + **tols, ) - # Distances aren't properly sorted - msg = "Distances aren't sorted on row 0" - with pytest.raises(AssertionError, match=msg): - assert_radius_neighbors_results_quasi_equality( + if check_sorted: + # Distances aren't properly sorted + msg = "Distances aren't sorted on row 0" + with pytest.raises(AssertionError, match=msg): + assert_compatible_radius_results( + np.array([np.array([1.2, 2.5, _6_1m, 6.1, _6_1p])]), + np.array([np.array([2.5, 1.2, _6_1m, 6.1, _6_1p])]), + np.array([np.array([1, 2, 3, 4, 5])]), + np.array([np.array([2, 1, 4, 5, 3])]), + radius=_6_1p, + check_sorted=True, + **tols, + ) + else: + assert_compatible_radius_results( np.array([np.array([1.2, 2.5, _6_1m, 6.1, _6_1p])]), np.array([np.array([2.5, 1.2, _6_1m, 6.1, _6_1p])]), np.array([np.array([1, 2, 3, 4, 5])]), np.array([np.array([2, 1, 4, 5, 3])]), - radius=6.1, - rtol=rtol, + radius=_6_1p, + check_sorted=False, + **tols, ) @@ -963,22 +1111,18 @@ def test_radius_neighbors_classmode_factory_method_wrong_usages(): ) -@pytest.mark.parametrize( - "n_samples_X, n_samples_Y", [(100, 100), (500, 100), (100, 500)] -) @pytest.mark.parametrize("Dispatcher", [ArgKmin, RadiusNeighbors]) @pytest.mark.parametrize("dtype", [np.float64, np.float32]) def test_chunk_size_agnosticism( global_random_seed, Dispatcher, - n_samples_X, - n_samples_Y, dtype, n_features=100, ): """Check that results do not depend on the chunk size.""" rng = np.random.RandomState(global_random_seed) spread = 100 + n_samples_X, n_samples_Y = rng.choice([97, 100, 101, 500], size=2, replace=False) X = rng.rand(n_samples_X, n_features).astype(dtype) * spread Y = rng.rand(n_samples_Y, n_features).astype(dtype) * spread @@ -987,8 +1131,7 @@ def test_chunk_size_agnosticism( check_parameters = {} compute_parameters = {} else: - # Scaling the radius slightly with the numbers of dimensions - radius = 10 ** np.log(n_features) + radius = _non_trivial_radius(X=X, Y=Y, metric="euclidean") parameter = radius check_parameters = {"radius": radius} compute_parameters = {"sort_results": True} @@ -1018,21 +1161,17 @@ def test_chunk_size_agnosticism( ) -@pytest.mark.parametrize( - "n_samples_X, n_samples_Y", [(100, 100), (500, 100), (100, 500)] -) @pytest.mark.parametrize("Dispatcher", [ArgKmin, RadiusNeighbors]) @pytest.mark.parametrize("dtype", [np.float64, np.float32]) def test_n_threads_agnosticism( global_random_seed, Dispatcher, - n_samples_X, - n_samples_Y, dtype, n_features=100, ): """Check that results do not depend on the number of threads.""" rng = np.random.RandomState(global_random_seed) + n_samples_X, n_samples_Y = rng.choice([97, 100, 101, 500], size=2, replace=False) spread = 100 X = rng.rand(n_samples_X, n_features).astype(dtype) * spread Y = rng.rand(n_samples_Y, n_features).astype(dtype) * spread @@ -1042,8 +1181,7 @@ def test_n_threads_agnosticism( check_parameters = {} compute_parameters = {} else: - # Scaling the radius slightly with the numbers of dimensions - radius = 10 ** np.log(n_features) + radius = _non_trivial_radius(X=X, Y=Y, metric="euclidean") parameter = radius check_parameters = {"radius": radius} compute_parameters = {"sort_results": True} @@ -1104,8 +1242,9 @@ def test_format_agnosticism( check_parameters = {} compute_parameters = {} else: - # Scaling the radius slightly with the numbers of dimensions - radius = 10 ** np.log(n_features) + # Adjusting the radius to ensure that the expected results is neither + # trivially empty nor too large. + radius = _non_trivial_radius(X=X, Y=Y, metric="euclidean") parameter = radius check_parameters = {"radius": radius} compute_parameters = {"sort_results": True} @@ -1139,29 +1278,30 @@ def test_format_agnosticism( ) -@pytest.mark.parametrize( - "n_samples_X, n_samples_Y", [(100, 100), (100, 500), (500, 100)] -) -@pytest.mark.parametrize( - "metric", - ["euclidean", "minkowski", "manhattan", "infinity", "seuclidean", "haversine"], -) @pytest.mark.parametrize("Dispatcher", [ArgKmin, RadiusNeighbors]) -@pytest.mark.parametrize("dtype", [np.float64, np.float32]) def test_strategies_consistency( global_random_seed, + global_dtype, Dispatcher, - metric, - n_samples_X, - n_samples_Y, - dtype, n_features=10, ): """Check that the results do not depend on the strategy used.""" rng = np.random.RandomState(global_random_seed) + metric = rng.choice( + np.array( + [ + "euclidean", + "minkowski", + "manhattan", + "haversine", + ], + dtype=object, + ) + ) + n_samples_X, n_samples_Y = rng.choice([97, 100, 101, 500], size=2, replace=False) spread = 100 - X = rng.rand(n_samples_X, n_features).astype(dtype) * spread - Y = rng.rand(n_samples_Y, n_features).astype(dtype) * spread + X = rng.rand(n_samples_X, n_features).astype(global_dtype) * spread + Y = rng.rand(n_samples_Y, n_features).astype(global_dtype) * spread # Haversine distance only accepts 2D data if metric == "haversine": @@ -1173,8 +1313,7 @@ def test_strategies_consistency( check_parameters = {} compute_parameters = {} else: - # Scaling the radius slightly with the numbers of dimensions - radius = 10 ** np.log(n_features) + radius = _non_trivial_radius(X=X, Y=Y, metric=metric) parameter = radius check_parameters = {"radius": radius} compute_parameters = {"sort_results": True} @@ -1211,7 +1350,7 @@ def test_strategies_consistency( **compute_parameters, ) - ASSERT_RESULT[(Dispatcher, dtype)]( + ASSERT_RESULT[(Dispatcher, global_dtype)]( dist_par_X, dist_par_Y, indices_par_X, indices_par_Y, **check_parameters ) @@ -1219,34 +1358,25 @@ def test_strategies_consistency( # "Concrete Dispatchers"-specific tests -@pytest.mark.parametrize("n_features", [50, 500]) -@pytest.mark.parametrize("translation", [0, 1e6]) @pytest.mark.parametrize("metric", CDIST_PAIRWISE_DISTANCES_REDUCTION_COMMON_METRICS) @pytest.mark.parametrize("strategy", ("parallel_on_X", "parallel_on_Y")) @pytest.mark.parametrize("dtype", [np.float64, np.float32]) @pytest.mark.parametrize("csr_container", CSR_CONTAINERS) def test_pairwise_distances_argkmin( global_random_seed, - n_features, - translation, metric, strategy, dtype, csr_container, + n_queries=5, n_samples=100, k=10, ): - # TODO: can we easily fix this discrepancy? - edge_cases = [ - (np.float32, "chebyshev", 1000000.0), - (np.float32, "cityblock", 1000000.0), - ] - if (dtype, metric, translation) in edge_cases: - pytest.xfail("Numerical differences lead to small differences in results.") - rng = np.random.RandomState(global_random_seed) + n_features = rng.choice([50, 500]) + translation = rng.choice([0, 1e6]) spread = 1000 - X = translation + rng.rand(n_samples, n_features).astype(dtype) * spread + X = translation + rng.rand(n_queries, n_features).astype(dtype) * spread Y = translation + rng.rand(n_samples, n_features).astype(dtype) * spread X_csr = csr_container(X) @@ -1295,24 +1425,22 @@ def test_pairwise_distances_argkmin( ) -@pytest.mark.parametrize("n_features", [50, 500]) -@pytest.mark.parametrize("translation", [0, 1e6]) @pytest.mark.parametrize("metric", CDIST_PAIRWISE_DISTANCES_REDUCTION_COMMON_METRICS) @pytest.mark.parametrize("strategy", ("parallel_on_X", "parallel_on_Y")) @pytest.mark.parametrize("dtype", [np.float64, np.float32]) def test_pairwise_distances_radius_neighbors( global_random_seed, - n_features, - translation, metric, strategy, dtype, + n_queries=5, n_samples=100, ): rng = np.random.RandomState(global_random_seed) + n_features = rng.choice([50, 500]) + translation = rng.choice([0, 1e6]) spread = 1000 - radius = spread * np.log(n_features) - X = translation + rng.rand(n_samples, n_features).astype(dtype) * spread + X = translation + rng.rand(n_queries, n_features).astype(dtype) * spread Y = translation + rng.rand(n_samples, n_features).astype(dtype) * spread metric_kwargs = _get_metric_params_list( @@ -1326,6 +1454,8 @@ def test_pairwise_distances_radius_neighbors( else: dist_matrix = cdist(X, Y, metric=metric, **metric_kwargs) + radius = _non_trivial_radius(precomputed_dists=dist_matrix) + # Getting the neighbors for a given radius neigh_indices_ref = [] neigh_distances_ref = [] @@ -1410,21 +1540,18 @@ def test_memmap_backed_data( ) -@pytest.mark.parametrize("n_samples", [100, 1000]) -@pytest.mark.parametrize("n_features", [5, 10, 100]) -@pytest.mark.parametrize("num_threads", [1, 2, 8]) @pytest.mark.parametrize("dtype", [np.float64, np.float32]) @pytest.mark.parametrize("csr_container", CSR_CONTAINERS) def test_sqeuclidean_row_norms( global_random_seed, - n_samples, - n_features, - num_threads, dtype, csr_container, ): rng = np.random.RandomState(global_random_seed) spread = 100 + n_samples = rng.choice([97, 100, 101, 1000]) + n_features = rng.choice([5, 10, 100]) + num_threads = rng.choice([1, 2, 8]) X = rng.rand(n_samples, n_features).astype(dtype) * spread X_csr = csr_container(X) diff --git a/sklearn/neighbors/tests/test_neighbors.py b/sklearn/neighbors/tests/test_neighbors.py index 2d8fb8c69c599..ac312144ae968 100644 --- a/sklearn/neighbors/tests/test_neighbors.py +++ b/sklearn/neighbors/tests/test_neighbors.py @@ -18,10 +18,11 @@ from sklearn.metrics._dist_metrics import ( DistanceMetric, ) -from sklearn.metrics.pairwise import pairwise_distances +from sklearn.metrics.pairwise import PAIRWISE_BOOLEAN_FUNCTIONS, pairwise_distances from sklearn.metrics.tests.test_dist_metrics import BOOL_METRICS from sklearn.metrics.tests.test_pairwise_distances_reduction import ( - assert_radius_neighbors_results_equality, + assert_compatible_argkmin_results, + assert_compatible_radius_results, ) from sklearn.model_selection import cross_val_score, train_test_split from sklearn.neighbors import ( @@ -1712,8 +1713,15 @@ def test_neighbors_metrics( "metric", sorted(set(neighbors.VALID_METRICS["brute"]) - set(["precomputed"])) ) def test_kneighbors_brute_backend( - global_dtype, metric, n_samples=2000, n_features=30, n_query_pts=100, n_neighbors=5 + metric, + global_dtype, + global_random_seed, + n_samples=2000, + n_features=30, + n_query_pts=5, + n_neighbors=5, ): + rng = np.random.RandomState(global_random_seed) # Both backend for the 'brute' algorithm of kneighbors must give identical results. X_train = rng.rand(n_samples, n_features).astype(global_dtype, copy=False) X_test = rng.rand(n_query_pts, n_features).astype(global_dtype, copy=False) @@ -1724,6 +1732,10 @@ def test_kneighbors_brute_backend( X_train = np.ascontiguousarray(X_train[:, feature_sl]) X_test = np.ascontiguousarray(X_test[:, feature_sl]) + if metric in PAIRWISE_BOOLEAN_FUNCTIONS: + X_train = X_train > 0.5 + X_test = X_test > 0.5 + metric_params_list = _generate_test_params_for(metric, n_features) for metric_params in metric_params_list: @@ -1750,8 +1762,9 @@ def test_kneighbors_brute_backend( X_test, return_distance=True ) - assert_allclose(legacy_brute_dst, pdr_brute_dst) - assert_array_equal(legacy_brute_idx, pdr_brute_idx) + assert_compatible_argkmin_results( + legacy_brute_dst, pdr_brute_dst, legacy_brute_idx, pdr_brute_idx + ) def test_callable_metric(): @@ -2223,16 +2236,18 @@ def test_auto_algorithm(X, metric, metric_params, expected_algo): ) def test_radius_neighbors_brute_backend( metric, + global_random_seed, + global_dtype, n_samples=2000, n_features=30, - n_query_pts=100, - n_neighbors=5, + n_query_pts=5, radius=1.0, ): + rng = np.random.RandomState(global_random_seed) # Both backends for the 'brute' algorithm of radius_neighbors # must give identical results. - X_train = rng.rand(n_samples, n_features) - X_test = rng.rand(n_query_pts, n_features) + X_train = rng.rand(n_samples, n_features).astype(global_dtype, copy=False) + X_test = rng.rand(n_query_pts, n_features).astype(global_dtype, copy=False) # Haversine distance only accepts 2D data if metric == "haversine": @@ -2246,7 +2261,6 @@ def test_radius_neighbors_brute_backend( p = metric_params.pop("p", 2) neigh = neighbors.NearestNeighbors( - n_neighbors=n_neighbors, radius=radius, algorithm="brute", metric=metric, @@ -2267,12 +2281,13 @@ def test_radius_neighbors_brute_backend( X_test, return_distance=True ) - assert_radius_neighbors_results_equality( + assert_compatible_radius_results( legacy_brute_dst, pdr_brute_dst, legacy_brute_idx, pdr_brute_idx, radius=radius, + check_sorted=False, )