Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions doc/whats_new/v1.1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,11 @@ Changelog
ndarray with `np.nan` when passed a `Float32` or `Float64` pandas extension
array with `pd.NA`. :pr:`21278` by `Thomas Fan`_.

- |Enhancement| Adds :term:`get_feature_names_out` to
:class:`neighbors.RadiusNeighborsTransformer`, :class:`neighbors.KNeighborsTransformer`
and :class:`neighbors.NeighborhoodComponentsAnalysis`. :pr:`22212` by
:user : `Meekail Zain <micky774>`.

- |Fix| :class:`neighbors.KernelDensity` now validates input parameters in `fit`
instead of `__init__`. :pr:`21430` by :user:`Desislava Vasileva <DessyVV>` and
:user:`Lucy Jimenez <LucyJimenez>`.
Expand Down
21 changes: 16 additions & 5 deletions sklearn/neighbors/_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from ._base import KNeighborsMixin, RadiusNeighborsMixin
from ._base import NeighborsBase
from ._unsupervised import NearestNeighbors
from ..base import TransformerMixin
from ..base import TransformerMixin, _ClassNamePrefixFeaturesOutMixin
from ..utils.validation import check_is_fitted


Expand Down Expand Up @@ -223,7 +223,9 @@ def radius_neighbors_graph(
return X.radius_neighbors_graph(query, radius, mode)


class KNeighborsTransformer(KNeighborsMixin, TransformerMixin, NeighborsBase):
class KNeighborsTransformer(
_ClassNamePrefixFeaturesOutMixin, KNeighborsMixin, TransformerMixin, NeighborsBase
):
"""Transform X into a (weighted) graph of k nearest neighbors.

The transformed data is a sparse graph as returned by kneighbors_graph.
Expand Down Expand Up @@ -389,7 +391,9 @@ def fit(self, X, y=None):
self : KNeighborsTransformer
The fitted k-nearest neighbors transformer.
"""
return self._fit(X)
self._fit(X)
self._n_features_out = self.n_samples_fit_
return self

def transform(self, X):
"""Compute the (weighted) graph of Neighbors for points in X.
Expand Down Expand Up @@ -445,7 +449,12 @@ def _more_tags(self):
}


class RadiusNeighborsTransformer(RadiusNeighborsMixin, TransformerMixin, NeighborsBase):
class RadiusNeighborsTransformer(
_ClassNamePrefixFeaturesOutMixin,
RadiusNeighborsMixin,
TransformerMixin,
NeighborsBase,
):
"""Transform X into a (weighted) graph of neighbors nearer than a radius.

The transformed data is a sparse graph as returned by
Expand Down Expand Up @@ -614,7 +623,9 @@ def fit(self, X, y=None):
self : RadiusNeighborsTransformer
The fitted radius neighbors transformer.
"""
return self._fit(X)
self._fit(X)
self._n_features_out = self.n_samples_fit_
return self

def transform(self, X):
"""Compute the (weighted) graph of Neighbors for points in X.
Expand Down
7 changes: 5 additions & 2 deletions sklearn/neighbors/_nca.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from scipy.optimize import minimize
from ..utils.extmath import softmax
from ..metrics import pairwise_distances
from ..base import BaseEstimator, TransformerMixin
from ..base import BaseEstimator, TransformerMixin, _ClassNamePrefixFeaturesOutMixin
from ..preprocessing import LabelEncoder
from ..decomposition import PCA
from ..utils.multiclass import check_classification_targets
Expand All @@ -24,7 +24,9 @@
from ..exceptions import ConvergenceWarning


class NeighborhoodComponentsAnalysis(TransformerMixin, BaseEstimator):
class NeighborhoodComponentsAnalysis(
_ClassNamePrefixFeaturesOutMixin, TransformerMixin, BaseEstimator
):
"""Neighborhood Components Analysis.

Neighborhood Component Analysis (NCA) is a machine learning algorithm for
Expand Down Expand Up @@ -249,6 +251,7 @@ def fit(self, X, y):

# Reshape the solution found by the optimizer
self.components_ = opt_result.x.reshape(-1, X.shape[1])
self._n_features_out = self.components_.shape[1]

# Stop timer
t_train = time.time() - t_train
Expand Down
22 changes: 22 additions & 0 deletions sklearn/neighbors/tests/test_graph.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import numpy as np
import pytest

from sklearn.metrics import euclidean_distances
from sklearn.neighbors import KNeighborsTransformer, RadiusNeighborsTransformer
from sklearn.neighbors._base import _is_sorted_by_data
from sklearn.utils._testing import assert_array_equal


def test_transformer_result():
Expand Down Expand Up @@ -77,3 +79,23 @@ def test_explicit_diagonal():
# Using transform on new data should not always have zero diagonal
X2t = nnt.transform(X2)
assert not _has_explicit_diagonal(X2t)


@pytest.mark.parametrize("Klass", [KNeighborsTransformer, RadiusNeighborsTransformer])
def test_graph_feature_names_out(Klass):
"""Check `get_feature_names_out` for transformers defined in `_graph.py`."""

n_samples_fit = 20
n_features = 10
rng = np.random.RandomState(42)
X = rng.randn(n_samples_fit, n_features)

est = Klass().fit(X)
names_out = est.get_feature_names_out()

class_name_lower = Klass.__name__.lower()
expected_names_out = np.array(
[f"{class_name_lower}{i}" for i in range(est.n_samples_fit_)],
dtype=object,
)
assert_array_equal(names_out, expected_names_out)
17 changes: 17 additions & 0 deletions sklearn/neighbors/tests/test_nca.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,3 +554,20 @@ def test_parameters_valid_types(param, value):
y = iris_target

nca.fit(X, y)


def test_nca_feature_names_out():
"""Check `get_feature_names_out` for `NeighborhoodComponentsAnalysis`."""

X = iris_data
y = iris_target

est = NeighborhoodComponentsAnalysis().fit(X, y)
names_out = est.get_feature_names_out()

class_name_lower = est.__class__.__name__.lower()
expected_names_out = np.array(
[f"{class_name_lower}{i}" for i in range(est.components_.shape[1])],
dtype=object,
)
assert_array_equal(names_out, expected_names_out)
1 change: 0 additions & 1 deletion sklearn/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,6 @@ def test_pandas_column_name_consistency(estimator):
"kernel_approximation",
"preprocessing",
"manifold",
"neighbors",
"neural_network",
]

Expand Down