Skip to content

Commit dd7d6f7

Browse files
thomasjpfanjnothman
authored andcommitted
FIX Clip distances below 0 (#15683)
1 parent 729a4bd commit dd7d6f7

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

sklearn/metrics/pairwise.py

+2
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,8 @@ def nan_euclidean_distances(X, Y=None, squared=False,
406406
distances -= np.dot(XX, missing_Y.T)
407407
distances -= np.dot(missing_X, YY.T)
408408

409+
np.clip(distances, 0, None, out=distances)
410+
409411
if X is Y:
410412
# Ensure that distances between vectors and themselves are set to 0.0.
411413
# This may not be the case due to floating point rounding errors.

sklearn/metrics/tests/test_pairwise.py

+17
Original file line numberDiff line numberDiff line change
@@ -871,6 +871,23 @@ def test_nan_euclidean_distances_not_trival(missing_value):
871871
assert_allclose(D6, D7)
872872

873873

874+
@pytest.mark.parametrize("missing_value", [np.nan, -1])
875+
def test_nan_euclidean_distances_one_feature_match_positive(missing_value):
876+
# First feature is the only feature that is non-nan and in both
877+
# samples. The result of `nan_euclidean_distances` with squared=True
878+
# should be non-negative. The non-squared version should all be close to 0.
879+
X = np.array([[-122.27, 648., missing_value, 37.85],
880+
[-122.27, missing_value, 2.34701493, missing_value]])
881+
882+
dist_squared = nan_euclidean_distances(X, missing_values=missing_value,
883+
squared=True)
884+
assert np.all(dist_squared >= 0)
885+
886+
dist = nan_euclidean_distances(X, missing_values=missing_value,
887+
squared=False)
888+
assert_allclose(dist, 0.0)
889+
890+
874891
def test_cosine_distances():
875892
# Check the pairwise Cosine distances computation
876893
rng = np.random.RandomState(1337)

0 commit comments

Comments
 (0)