
ENH Adds feature names out to decomposition module #21334


Merged
16 changes: 16 additions & 0 deletions doc/whats_new/v1.1.rst
@@ -58,6 +58,22 @@ Changelog
- |Fix| :class:`decomposition.FastICA` now validates input parameters in `fit` instead of `__init__`.
:pr:`21432` by :user:`Hannah Bohle <hhnnhh>` and :user:`Maren Westermann <marenwestermann>`.

- |API| Adds :term:`get_feature_names_out` to all transformers in the
:mod:`~sklearn.decomposition` module:
:class:`~sklearn.decomposition.DictionaryLearning`,
:class:`~sklearn.decomposition.FactorAnalysis`,
:class:`~sklearn.decomposition.FastICA`,
:class:`~sklearn.decomposition.IncrementalPCA`,
:class:`~sklearn.decomposition.KernelPCA`,
:class:`~sklearn.decomposition.LatentDirichletAllocation`,
:class:`~sklearn.decomposition.MiniBatchDictionaryLearning`,
:class:`~sklearn.decomposition.MiniBatchSparsePCA`,
:class:`~sklearn.decomposition.NMF`,
:class:`~sklearn.decomposition.PCA`,
:class:`~sklearn.decomposition.SparsePCA`,
and :class:`~sklearn.decomposition.TruncatedSVD`. :pr:`21334` by
`Thomas Fan`_.

:mod:`sklearn.impute`
.....................

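The changelog entry above describes the new `get_feature_names_out` support. A minimal usage sketch, assuming a scikit-learn build that already includes this PR; the iris data is used purely for illustration:

```python
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA

X, _ = load_iris(return_X_y=True)
pca = PCA(n_components=3).fit(X)

# Names are the lowercased class name followed by the output column index.
print(pca.get_feature_names_out())
# Expected: ['pca0' 'pca1' 'pca2']
```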
27 changes: 27 additions & 0 deletions sklearn/base.py
@@ -24,6 +24,8 @@
from .utils.validation import _check_y
from .utils.validation import _num_features
from .utils.validation import _check_feature_names_in
from .utils.validation import _generate_get_feature_names_out
from .utils.validation import check_is_fitted
from .utils._estimator_html_repr import estimator_html_repr
from .utils.validation import _get_feature_names

@@ -879,6 +881,31 @@ def get_feature_names_out(self, input_features=None):
return _check_feature_names_in(self, input_features)


class _ClassNamePrefixFeaturesOutMixin:
"""Mixin class for transformers that generate their own names by prefixing.

Assumes that `_n_features_out` is defined for the estimator.
"""

def get_feature_names_out(self, input_features=None):
"""Get output feature names for transformation.

Parameters
----------
input_features : array-like of str or None, default=None
Only used to validate feature names with the names seen in :meth:`fit`.

Returns
-------
feature_names_out : ndarray of str objects
Transformed feature names.
"""
check_is_fitted(self, "_n_features_out")
return _generate_get_feature_names_out(
self, self._n_features_out, input_features=input_features
)


class DensityMixin:
"""Mixin class for all density estimators in scikit-learn."""

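The new mixin above only requires an estimator to expose `_n_features_out` once fitted; `get_feature_names_out` then derives names from the lowercased class name. A minimal sketch of that contract — `MyReducer` is a hypothetical toy transformer, and `_ClassNamePrefixFeaturesOutMixin` is private API that third-party code should not rely on:

```python
import numpy as np

from sklearn.base import (
    BaseEstimator,
    TransformerMixin,
    _ClassNamePrefixFeaturesOutMixin,  # private: shown for illustration only
)


class MyReducer(_ClassNamePrefixFeaturesOutMixin, TransformerMixin, BaseEstimator):
    """Toy transformer that keeps the first two input columns."""

    def fit(self, X, y=None):
        # `_n_features_out` must resolve once the estimator is fitted.
        self.components_ = np.eye(2, X.shape[1])
        return self

    def transform(self, X):
        return X[:, :2]

    @property
    def _n_features_out(self):
        return self.components_.shape[0]


reducer = MyReducer().fit(np.random.rand(5, 4))
print(reducer.get_feature_names_out())
# Expected: ['myreducer0' 'myreducer1']
```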
11 changes: 9 additions & 2 deletions sklearn/decomposition/_base.py
@@ -11,12 +11,14 @@
import numpy as np
from scipy import linalg

from ..base import BaseEstimator, TransformerMixin
from ..base import BaseEstimator, TransformerMixin, _ClassNamePrefixFeaturesOutMixin
from ..utils.validation import check_is_fitted
from abc import ABCMeta, abstractmethod


class _BasePCA(TransformerMixin, BaseEstimator, metaclass=ABCMeta):
class _BasePCA(
_ClassNamePrefixFeaturesOutMixin, TransformerMixin, BaseEstimator, metaclass=ABCMeta
):
"""Base class for PCA methods.

Warning: This class should not be used directly.
@@ -154,3 +156,8 @@ def inverse_transform(self, X):
)
else:
return np.dot(X, self.components_) + self.mean_

@property
def _n_features_out(self):
"""Number of transformed output features."""
return self.components_.shape[0]
19 changes: 17 additions & 2 deletions sklearn/decomposition/_dict_learning.py
@@ -14,7 +14,7 @@
from scipy import linalg
from joblib import Parallel, effective_n_jobs

from ..base import BaseEstimator, TransformerMixin
from ..base import BaseEstimator, TransformerMixin, _ClassNamePrefixFeaturesOutMixin
from ..utils import deprecated
from ..utils import check_array, check_random_state, gen_even_slices, gen_batches
from ..utils.extmath import randomized_svd, row_norms, svd_flip
@@ -1014,7 +1014,7 @@ def dict_learning_online(
return dictionary


class _BaseSparseCoding(TransformerMixin):
class _BaseSparseCoding(_ClassNamePrefixFeaturesOutMixin, TransformerMixin):
"""Base class from SparseCoder and DictionaryLearning algorithms."""

def __init__(
@@ -1315,6 +1315,11 @@ def n_features_in_(self):
"""Number of features seen during `fit`."""
return self.dictionary.shape[1]

@property
def _n_features_out(self):
"""Number of transformed output features."""
return self.n_components_


class DictionaryLearning(_BaseSparseCoding, BaseEstimator):
"""Dictionary learning.
@@ -1587,6 +1592,11 @@ def fit(self, X, y=None):
self.error_ = E
return self

@property
def _n_features_out(self):
"""Number of transformed output features."""
return self.components_.shape[0]


class MiniBatchDictionaryLearning(_BaseSparseCoding, BaseEstimator):
"""Mini-batch dictionary learning.
@@ -1926,3 +1936,8 @@ def partial_fit(self, X, y=None, iter_offset=None):
self.inner_stats_ = (A, B)
self.iter_offset_ = iter_offset + 1
return self

@property
def _n_features_out(self):
"""Number of transformed output features."""
return self.components_.shape[0]
9 changes: 7 additions & 2 deletions sklearn/decomposition/_factor_analysis.py
@@ -25,14 +25,14 @@
from scipy import linalg


from ..base import BaseEstimator, TransformerMixin
from ..base import BaseEstimator, TransformerMixin, _ClassNamePrefixFeaturesOutMixin
from ..utils import check_random_state
from ..utils.extmath import fast_logdet, randomized_svd, squared_norm
from ..utils.validation import check_is_fitted
from ..exceptions import ConvergenceWarning


class FactorAnalysis(TransformerMixin, BaseEstimator):
class FactorAnalysis(_ClassNamePrefixFeaturesOutMixin, TransformerMixin, BaseEstimator):
"""Factor Analysis (FA).

A simple linear generative model with Gaussian latent variables.
@@ -426,6 +426,11 @@ def _rotate(self, components, n_components=None, tol=1e-6):
else:
raise ValueError("'method' must be in %s, not %s" % (implemented, method))

@property
def _n_features_out(self):
"""Number of transformed output features."""
return self.components_.shape[0]


def _ortho_rotation(components, method="varimax", tol=1e-6, max_iter=100):
"""Return rotated components."""
9 changes: 7 additions & 2 deletions sklearn/decomposition/_fastica.py
@@ -14,7 +14,7 @@
import numpy as np
from scipy import linalg

from ..base import BaseEstimator, TransformerMixin
from ..base import BaseEstimator, TransformerMixin, _ClassNamePrefixFeaturesOutMixin
from ..exceptions import ConvergenceWarning

from ..utils import check_array, as_float_array, check_random_state
@@ -319,7 +319,7 @@ def my_g(x):
return None, est._unmixing, sources


class FastICA(TransformerMixin, BaseEstimator):
class FastICA(_ClassNamePrefixFeaturesOutMixin, TransformerMixin, BaseEstimator):
"""FastICA: a fast algorithm for Independent Component Analysis.

The implementation is based on [1]_.
@@ -689,3 +689,8 @@ def inverse_transform(self, X, copy=True):
X += self.mean_

return X

@property
def _n_features_out(self):
"""Number of transformed output features."""
return self.components_.shape[0]
14 changes: 11 additions & 3 deletions sklearn/decomposition/_kernel_pca.py
@@ -10,15 +10,18 @@

from ..utils._arpack import _init_arpack_v0
from ..utils.extmath import svd_flip, _randomized_eigsh
from ..utils.validation import check_is_fitted, _check_psd_eigenvalues
from ..utils.validation import (
check_is_fitted,
_check_psd_eigenvalues,
)
from ..utils.deprecation import deprecated
from ..exceptions import NotFittedError
from ..base import BaseEstimator, TransformerMixin
from ..base import BaseEstimator, TransformerMixin, _ClassNamePrefixFeaturesOutMixin
from ..preprocessing import KernelCenterer
from ..metrics.pairwise import pairwise_kernels


class KernelPCA(TransformerMixin, BaseEstimator):
class KernelPCA(_ClassNamePrefixFeaturesOutMixin, TransformerMixin, BaseEstimator):
"""Kernel Principal component analysis (KPCA).

Non-linear dimensionality reduction through the use of kernels (see
@@ -546,3 +549,8 @@ def _more_tags(self):
"preserves_dtype": [np.float64, np.float32],
"pairwise": self.kernel == "precomputed",
}

@property
def _n_features_out(self):
"""Number of transformed output features."""
return self.eigenvalues_.shape[0]
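Unlike the other transformers touched in this PR, `KernelPCA` stores no `components_` attribute, so its `_n_features_out` is taken from `eigenvalues_` instead. A short sketch of the resulting behaviour, again assuming a build that includes this PR:

```python
from sklearn.datasets import make_blobs
from sklearn.decomposition import KernelPCA

X, _ = make_blobs(n_samples=50, n_features=4, random_state=0)
kpca = KernelPCA(n_components=3, kernel="rbf").fit(X)

# eigenvalues_ has shape (n_components,) and drives the number of names.
print(kpca.eigenvalues_.shape)       # (3,)
print(kpca.get_feature_names_out())  # ['kernelpca0' 'kernelpca1' 'kernelpca2']
```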
11 changes: 9 additions & 2 deletions sklearn/decomposition/_lda.py
@@ -16,7 +16,7 @@
from scipy.special import gammaln, logsumexp
from joblib import Parallel, effective_n_jobs

from ..base import BaseEstimator, TransformerMixin
from ..base import BaseEstimator, TransformerMixin, _ClassNamePrefixFeaturesOutMixin
from ..utils import check_random_state, gen_batches, gen_even_slices
from ..utils.validation import check_non_negative
from ..utils.validation import check_is_fitted
@@ -138,7 +138,9 @@ def _update_doc_distribution(
return (doc_topic_distr, suff_stats)


class LatentDirichletAllocation(TransformerMixin, BaseEstimator):
class LatentDirichletAllocation(
_ClassNamePrefixFeaturesOutMixin, TransformerMixin, BaseEstimator
):
"""Latent Dirichlet Allocation with online variational Bayes algorithm.

The implementation is based on [1]_ and [2]_.
@@ -887,3 +889,8 @@ def perplexity(self, X, sub_sampling=False):
X, reset_n_features=True, whom="LatentDirichletAllocation.perplexity"
)
return self._perplexity_precomp_distr(X, sub_sampling=sub_sampling)

@property
def _n_features_out(self):
"""Number of transformed output features."""
return self.components_.shape[0]
14 changes: 11 additions & 3 deletions sklearn/decomposition/_nmf.py
@@ -15,11 +15,14 @@

from ._cdnmf_fast import _update_cdnmf_fast
from .._config import config_context
from ..base import BaseEstimator, TransformerMixin
from ..base import BaseEstimator, TransformerMixin, _ClassNamePrefixFeaturesOutMixin
from ..exceptions import ConvergenceWarning
from ..utils import check_random_state, check_array
from ..utils.extmath import randomized_svd, safe_sparse_dot, squared_norm
from ..utils.validation import check_is_fitted, check_non_negative
from ..utils.validation import (
check_is_fitted,
check_non_negative,
)

EPSILON = np.finfo(np.float32).eps

@@ -1109,7 +1112,7 @@ def non_negative_factorization(
return W, H, n_iter


class NMF(TransformerMixin, BaseEstimator):
class NMF(_ClassNamePrefixFeaturesOutMixin, TransformerMixin, BaseEstimator):
"""Non-Negative Matrix Factorization (NMF).

Find two non-negative matrices (W, H) whose product approximates the non-
@@ -1708,3 +1711,8 @@ def inverse_transform(self, W):
"""
check_is_fitted(self)
return np.dot(W, self.components_)

@property
def _n_features_out(self):
"""Number of transformed output features."""
return self.components_.shape[0]
9 changes: 7 additions & 2 deletions sklearn/decomposition/_sparse_pca.py
@@ -7,11 +7,11 @@
from ..utils import check_random_state
from ..utils.validation import check_is_fitted
from ..linear_model import ridge_regression
from ..base import BaseEstimator, TransformerMixin
from ..base import BaseEstimator, TransformerMixin, _ClassNamePrefixFeaturesOutMixin
from ._dict_learning import dict_learning, dict_learning_online


class SparsePCA(TransformerMixin, BaseEstimator):
class SparsePCA(_ClassNamePrefixFeaturesOutMixin, TransformerMixin, BaseEstimator):
"""Sparse Principal Components Analysis (SparsePCA).

Finds the set of sparse components that can optimally reconstruct
@@ -236,6 +236,11 @@ def transform(self, X):

return U

@property
def _n_features_out(self):
"""Number of transformed output features."""
return self.components_.shape[0]


class MiniBatchSparsePCA(SparsePCA):
"""Mini-batch Sparse Principal Components Analysis.
9 changes: 7 additions & 2 deletions sklearn/decomposition/_truncated_svd.py
@@ -10,7 +10,7 @@
import scipy.sparse as sp
from scipy.sparse.linalg import svds

from ..base import BaseEstimator, TransformerMixin
from ..base import BaseEstimator, TransformerMixin, _ClassNamePrefixFeaturesOutMixin
from ..utils import check_array, check_random_state
from ..utils._arpack import _init_arpack_v0
from ..utils.extmath import randomized_svd, safe_sparse_dot, svd_flip
@@ -21,7 +21,7 @@
__all__ = ["TruncatedSVD"]


class TruncatedSVD(TransformerMixin, BaseEstimator):
class TruncatedSVD(_ClassNamePrefixFeaturesOutMixin, TransformerMixin, BaseEstimator):
"""Dimensionality reduction using truncated SVD (aka LSA).

This transformer performs linear dimensionality reduction by means of
@@ -273,3 +273,8 @@ def inverse_transform(self, X):

def _more_tags(self):
return {"preserves_dtype": [np.float64, np.float32]}

@property
def _n_features_out(self):
"""Number of transformed output features."""
return self.components_.shape[0]
18 changes: 18 additions & 0 deletions sklearn/decomposition/tests/test_dict_learning.py
@@ -664,3 +664,21 @@ def test_warning_default_transform_alpha(Estimator):
dl = Estimator(alpha=0.1)
with pytest.warns(FutureWarning, match="default transform_alpha"):
dl.fit_transform(X)


@pytest.mark.parametrize(
"estimator",
[SparseCoder(X.T), DictionaryLearning(), MiniBatchDictionaryLearning()],
ids=lambda x: x.__class__.__name__,
)
def test_get_feature_names_out(estimator):
"""Check feature names for dict learning estimators."""
estimator.fit(X)
n_components = X.shape[1]

feature_names_out = estimator.get_feature_names_out()
estimator_name = estimator.__class__.__name__.lower()
assert_array_equal(
feature_names_out,
[f"{estimator_name}{i}" for i in range(n_components)],
)
9 changes: 9 additions & 0 deletions sklearn/decomposition/tests/test_incremental_pca.py
@@ -5,6 +5,7 @@
from sklearn.utils._testing import assert_almost_equal
from sklearn.utils._testing import assert_array_almost_equal
from sklearn.utils._testing import assert_allclose_dense_sparse
from numpy.testing import assert_array_equal

from sklearn import datasets
from sklearn.decomposition import PCA, IncrementalPCA
@@ -427,3 +428,11 @@ def test_incremental_pca_fit_overflow_error():
pca.fit(A)

np.testing.assert_allclose(ipca.singular_values_, pca.singular_values_)


def test_incremental_pca_feature_names_out():
"""Check feature names out for IncrementalPCA."""
ipca = IncrementalPCA(n_components=2).fit(iris.data)

names = ipca.get_feature_names_out()
assert_array_equal([f"incrementalpca{i}" for i in range(2)], names)
9 changes: 9 additions & 0 deletions sklearn/decomposition/tests/test_kernel_pca.py
@@ -559,3 +559,12 @@ def test_kernel_pca_alphas_deprecated():
msg = r"Attribute `alphas_` was deprecated in version 1\.0"
with pytest.warns(FutureWarning, match=msg):
kp.alphas_


def test_kernel_pca_feature_names_out():
"""Check feature names out for KernelPCA."""
X, *_ = make_blobs(n_samples=100, n_features=4, random_state=0)
kpca = KernelPCA(n_components=2).fit(X)

names = kpca.get_feature_names_out()
assert_array_equal([f"kernelpca{i}" for i in range(2)], names)