ENH NearestNeighbors-like classes with metric="nan_euclidean" do not actually support NaN values #25330


Merged
merged 25 commits into main from vitaliset:nan_euclidean_nn_bug on Nov 5, 2024

Changes from all commits

Commits (25)
00b174a
fix nan_euclidean bug
vitaliset Jan 8, 2023
4475e07
black updates
vitaliset Jan 8, 2023
5de72ce
extra space and liting problems
vitaliset Jan 8, 2023
46ca5ba
changelog
vitaliset Jan 10, 2023
8eb1cb4
Merge branch 'main' into nan_euclidean_nn_bug
vitaliset Jan 10, 2023
ba191de
Merge branch 'main' into nan_euclidean_nn_bug
glemaitre Jan 14, 2023
1869bd1
Apply suggestions from code review
vitaliset Feb 3, 2023
d31e50e
apply suggestions from glemaitre
vitaliset Feb 4, 2023
90a79be
Merge branch 'main' into nan_euclidean_nn_bug
vitaliset Feb 4, 2023
cd852a9
linting
vitaliset Feb 4, 2023
ae17318
Merge branch 'nan_euclidean_nn_bug' of https://github.com/vitaliset/s…
vitaliset Feb 4, 2023
e6fce35
apply suggestions from glemaitre
vitaliset Feb 23, 2023
1ff99c8
linting problem - black
vitaliset Feb 23, 2023
5a28b6e
missing comma
vitaliset Feb 23, 2023
f958579
Merge remote-tracking branch 'origin/main' into pr/vitaliset/25330
glemaitre May 21, 2024
c714378
TST make sure to test what we intend to
glemaitre May 21, 2024
e2d2f5f
remove test
glemaitre May 21, 2024
c9752ff
Update typo on v1.6.rst
vitaliset May 21, 2024
2e24b03
Merge remote-tracking branch 'upstream/main' into nan_euclidean_nn_bug
adrinjalali Oct 31, 2024
faff4f7
changelog
adrinjalali Oct 31, 2024
8cbacad
changelog fix
adrinjalali Oct 31, 2024
2009b55
use tag
adrinjalali Oct 31, 2024
b549aaa
fix tags
adrinjalali Nov 1, 2024
18a79e4
Merge remote-tracking branch 'upstream/main' into nan_euclidean_nn_bug
adrinjalali Nov 1, 2024
c69e867
authorship
adrinjalali Nov 1, 2024
@@ -0,0 +1,6 @@
- :class:`neighbors.NearestNeighbors`, :class:`KNeighborsClassifier`,
:class:`KNeighborsRegressor`, :class:`RadiusNeighborsClassifier`,
:class:`RadiusNeighborsRegressor`, :class:`KNeighborsTransformer`,
:class:`RadiusNeighborsTransformer`, and :class:`LocalOutlierFactor`
now work with `metric="nan_euclidean"`, supporting `nan` inputs.
By :user:`Carlo Lemos <vitaliset>`, `Guillaume Lemaitre`_, and `Adrin Jalali`_
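As a quick illustration of the behaviour described in this changelog entry, here is a minimal usage sketch (the data mirrors the fixture used in the new tests of this PR; any scikit-learn build including this change is assumed):

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

# Training data with a missing value: before this change, such input was
# rejected during validation even though metric="nan_euclidean" can handle it.
X = np.array([[0.0, 1.0], [1.0, np.nan], [2.0, 3.0], [3.0, 5.0]])
y = np.array([0, 0, 1, 1])

clf = KNeighborsClassifier(n_neighbors=2, metric="nan_euclidean").fit(X, y)
print(clf.predict([[0.0, np.nan]]))  # queries containing NaN are accepted too
```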
10 changes: 6 additions & 4 deletions sklearn/metrics/pairwise.py
@@ -768,7 +768,7 @@ def pairwise_distances_argmin_min(
Valid values for metric are:

- from scikit-learn: ['cityblock', 'cosine', 'euclidean', 'l1', 'l2',
'manhattan']
'manhattan', 'nan_euclidean']

- from scipy.spatial.distance: ['braycurtis', 'canberra', 'chebyshev',
'correlation', 'dice', 'hamming', 'jaccard', 'kulsinski',
@@ -814,7 +814,8 @@ def pairwise_distances_argmin_min(
>>> distances
array([1., 1.])
"""
X, Y = check_pairwise_arrays(X, Y)
ensure_all_finite = "allow-nan" if metric == "nan_euclidean" else True
X, Y = check_pairwise_arrays(X, Y, ensure_all_finite=ensure_all_finite)

if axis == 0:
X, Y = Y, X
@@ -915,7 +916,7 @@ def pairwise_distances_argmin(X, Y, *, axis=1, metric="euclidean", metric_kwargs
Valid values for metric are:

- from scikit-learn: ['cityblock', 'cosine', 'euclidean', 'l1', 'l2',
'manhattan']
'manhattan', 'nan_euclidean']

- from scipy.spatial.distance: ['braycurtis', 'canberra', 'chebyshev',
'correlation', 'dice', 'hamming', 'jaccard', 'kulsinski',
@@ -954,7 +955,8 @@ def pairwise_distances_argmin(X, Y, *, axis=1, metric="euclidean", metric_kwargs
>>> pairwise_distances_argmin(X, Y)
array([0, 1])
"""
X, Y = check_pairwise_arrays(X, Y)
ensure_all_finite = "allow-nan" if metric == "nan_euclidean" else True
X, Y = check_pairwise_arrays(X, Y, ensure_all_finite=ensure_all_finite)

if axis == 0:
X, Y = Y, X
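The two hunks above only relax input validation; the metric itself already knows how to deal with missing values. A small sketch of calling one of the public functions after this change (values are illustrative):

```python
import numpy as np
from sklearn.metrics import pairwise_distances_argmin_min

X = np.array([[0.0, 1.0], [1.0, np.nan]])
Y = np.array([[0.0, 1.0], [2.0, 3.0]])

# check_pairwise_arrays is now called with ensure_all_finite="allow-nan",
# so the NaN entry reaches the nan_euclidean metric instead of raising.
argmin, dist = pairwise_distances_argmin_min(X, Y, metric="nan_euclidean")
print(argmin, dist)
```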
27 changes: 27 additions & 0 deletions sklearn/metrics/tests/test_pairwise.py
@@ -1582,6 +1582,33 @@ def test_numeric_pairwise_distances_datatypes(metric, global_dtype, y_is_x):
assert_allclose(dist, expected_dist)


@pytest.mark.parametrize(
"pairwise_distances_func",
[pairwise_distances, pairwise_distances_argmin, pairwise_distances_argmin_min],
)
def test_nan_euclidean_support(pairwise_distances_func):
"""Check that `nan_euclidean` is lenient with `nan` values."""

X = [[0, 1], [1, np.nan], [2, 3], [3, 5]]
output = pairwise_distances_func(X, X, metric="nan_euclidean")

assert not np.isnan(output).any()


def test_nan_euclidean_constant_input_argmin():
"""Check that the behavior of constant input is the same in the case of
full of nan vector and full of zero vector.
"""

X_nan = [[np.nan, np.nan], [np.nan, np.nan], [np.nan, np.nan]]
argmin_nan = pairwise_distances_argmin(X_nan, X_nan, metric="nan_euclidean")

X_const = [[0, 0], [0, 0], [0, 0]]
argmin_const = pairwise_distances_argmin(X_const, X_const, metric="nan_euclidean")

assert_allclose(argmin_nan, argmin_const)


@pytest.mark.parametrize(
"X,Y,expected_distance",
[
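To make the first test concrete, here is how the metric handles the missing coordinate in the first two rows of that fixture (a short worked example using the public `nan_euclidean_distances` helper):

```python
import numpy as np
from sklearn.metrics.pairwise import nan_euclidean_distances

X = np.array([[0.0, 1.0], [1.0, np.nan]])

# Only the first coordinate is present in both rows, so the squared
# difference (0 - 1)**2 = 1 is rescaled by n_features / n_present = 2 / 1,
# giving sqrt(2) ~= 1.414 off the diagonal (and 0 on the diagonal).
print(nan_euclidean_distances(X, X))
```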
44 changes: 40 additions & 4 deletions sklearn/neighbors/_base.py
@@ -25,6 +25,7 @@
from ..utils import (
check_array,
gen_even_slices,
get_tags,
)
from ..utils._param_validation import Interval, StrOptions, validate_params
from ..utils.fixes import parse_version, sp_base_version
@@ -471,10 +472,17 @@ def _check_algorithm_metric(self):
)

def _fit(self, X, y=None):
ensure_all_finite = "allow-nan" if get_tags(self).input_tags.allow_nan else True
if self.__sklearn_tags__().target_tags.required:
if not isinstance(X, (KDTree, BallTree, NeighborsBase)):
X, y = validate_data(
self, X, y, accept_sparse="csr", multi_output=True, order="C"
self,
X,
y,
accept_sparse="csr",
multi_output=True,
order="C",
ensure_all_finite=ensure_all_finite,
)

if is_classifier(self):
@@ -515,7 +523,13 @@ def _fit(self, X, y=None):

else:
if not isinstance(X, (KDTree, BallTree, NeighborsBase)):
X = validate_data(self, X, accept_sparse="csr", order="C")
X = validate_data(
self,
X,
ensure_all_finite=ensure_all_finite,
accept_sparse="csr",
order="C",
)

self._check_algorithm_metric()
if self.metric_params is None:
@@ -695,6 +709,7 @@ def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
# For cross-validation routines to split data correctly
tags.input_tags.pairwise = self.metric == "precomputed"
tags.input_tags.allow_nan = self.metric == "nan_euclidean"
return tags


@@ -806,6 +821,7 @@ class from an array representing our data set and ask who's
% type(n_neighbors)
)

ensure_all_finite = "allow-nan" if get_tags(self).input_tags.allow_nan else True
query_is_train = X is None
if query_is_train:
X = self._fit_X
@@ -816,7 +832,14 @@ class from an array representing our data set and ask who's
if self.metric == "precomputed":
X = _check_precomputed(X)
else:
X = validate_data(self, X, accept_sparse="csr", reset=False, order="C")
X = validate_data(
self,
X,
ensure_all_finite=ensure_all_finite,
accept_sparse="csr",
reset=False,
order="C",
)

n_samples_fit = self.n_samples_fit_
if n_neighbors > n_samples_fit:
@@ -1145,14 +1168,22 @@ class from an array representing our data set and ask who's
if sort_results and not return_distance:
raise ValueError("return_distance must be True if sort_results is True.")

ensure_all_finite = "allow-nan" if get_tags(self).input_tags.allow_nan else True
query_is_train = X is None
if query_is_train:
X = self._fit_X
else:
if self.metric == "precomputed":
X = _check_precomputed(X)
else:
X = validate_data(self, X, accept_sparse="csr", reset=False, order="C")
X = validate_data(
self,
X,
ensure_all_finite=ensure_all_finite,
accept_sparse="csr",
reset=False,
order="C",
)

if radius is None:
radius = self.radius
@@ -1363,3 +1394,8 @@ def radius_neighbors_graph(
A_indptr = np.concatenate((np.zeros(1, dtype=int), np.cumsum(n_neighbors)))

return csr_matrix((A_data, A_ind, A_indptr), shape=(n_queries, n_samples_fit))

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.input_tags.allow_nan = self.metric == "nan_euclidean"
return tags
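Taken together, the `_base.py` changes mean that both `fit` and the query methods validate with `ensure_all_finite="allow-nan"` whenever the `allow_nan` input tag is set. A minimal sketch of the resulting end-to-end behaviour (data reused from the tests):

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

X = np.array([[0.0, 1.0], [1.0, np.nan], [2.0, 3.0], [3.0, 5.0]])

# Both _fit and kneighbors now pass ensure_all_finite="allow-nan" to
# validate_data because the allow_nan tag is derived from the metric.
nn = NearestNeighbors(n_neighbors=2, metric="nan_euclidean").fit(X)
distances, indices = nn.kneighbors(X)
assert not np.isnan(distances).any()
```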
28 changes: 26 additions & 2 deletions sklearn/neighbors/_nearest_centroid.py
@@ -18,6 +18,7 @@
pairwise_distances_argmin,
)
from ..preprocessing import LabelEncoder
from ..utils import get_tags
from ..utils._available_if import available_if
from ..utils._param_validation import Interval, StrOptions
from ..utils.multiclass import check_classification_targets
@@ -172,7 +173,16 @@ def fit(self, X, y):
if self.metric == "manhattan":
X, y = validate_data(self, X, y, accept_sparse=["csc"])
else:
X, y = validate_data(self, X, y, accept_sparse=["csr", "csc"])
ensure_all_finite = (
"allow-nan" if get_tags(self).input_tags.allow_nan else True
)
X, y = validate_data(
self,
X,
y,
ensure_all_finite=ensure_all_finite,
accept_sparse=["csr", "csc"],
)
is_X_sparse = sp.issparse(X)
check_classification_targets(y)

@@ -283,7 +293,16 @@ def predict(self, X):
check_is_fitted(self)
if np.isclose(self.class_prior_, 1 / len(self.classes_)).all():
# `validate_data` is called here since we are not calling `super()`
X = validate_data(self, X, accept_sparse="csr", reset=False)
ensure_all_finite = (
"allow-nan" if get_tags(self).input_tags.allow_nan else True
)
X = validate_data(
self,
X,
ensure_all_finite=ensure_all_finite,
accept_sparse="csr",
reset=False,
)
return self.classes_[
pairwise_distances_argmin(X, self.centroids_, metric=self.metric)
]
@@ -332,3 +351,8 @@ def _check_euclidean_metric(self):
predict_log_proba = available_if(_check_euclidean_metric)(
DiscriminantAnalysisPredictionMixin.predict_log_proba
)

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.input_tags.allow_nan = self.metric == "nan_euclidean"
return tags
32 changes: 32 additions & 0 deletions sklearn/neighbors/tests/test_neighbors.py
@@ -2337,6 +2337,38 @@ def _weights(dist):
assert_allclose(est.predict([[0, 2.5]]), [6])


@pytest.mark.parametrize(
"Estimator, params",
[
(neighbors.KNeighborsClassifier, {"n_neighbors": 2}),
(neighbors.KNeighborsRegressor, {"n_neighbors": 2}),
(neighbors.RadiusNeighborsRegressor, {}),
(neighbors.RadiusNeighborsClassifier, {}),
(neighbors.KNeighborsTransformer, {"n_neighbors": 2}),
(neighbors.RadiusNeighborsTransformer, {"radius": 1.5}),
(neighbors.LocalOutlierFactor, {"n_neighbors": 1}),
],
)
def test_nan_euclidean_support(Estimator, params):
"""Check that the different neighbor estimators are lenient towards `nan`
values if using `metric="nan_euclidean"`.
"""

X = [[0, 1], [1, np.nan], [2, 3], [3, 5]]
y = [0, 0, 1, 1]

params.update({"metric": "nan_euclidean"})
estimator = Estimator().set_params(**params).fit(X, y)

for response_method in ("kneighbors", "predict", "transform", "fit_predict"):
if hasattr(estimator, response_method):
output = getattr(estimator, response_method)(X)
if hasattr(output, "toarray"):
assert not np.isnan(output.data).any()
else:
assert not np.isnan(output).any()


def test_predict_dataframe():
"""Check that KNN predict works with dataframes

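As a final illustration of what the new `test_nan_euclidean_support` exercises: the transformer estimators return sparse CSR graphs while the others return dense arrays, which is why the test inspects `.data` for sparse outputs. A minimal sketch with the same fixture (parameters are illustrative):

```python
import numpy as np
from sklearn.neighbors import KNeighborsTransformer, LocalOutlierFactor

X = [[0, 1], [1, np.nan], [2, 3], [3, 5]]

# KNeighborsTransformer returns a sparse CSR graph, so only its stored
# values (`.data`) are inspected for NaN.
graph = KNeighborsTransformer(n_neighbors=2, metric="nan_euclidean").fit_transform(X)
assert not np.isnan(graph.data).any()

# LocalOutlierFactor exposes fit_predict; the inlier/outlier labels are dense.
labels = LocalOutlierFactor(n_neighbors=1, metric="nan_euclidean").fit_predict(X)
assert not np.isnan(labels).any()
```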