MRG add n_features_out_ attribute #14241

Closed · wants to merge 32 commits

Commits
c47227d
start on n_features_out_
amueller Jul 2, 2019
06b4a08
make sure common tests for transformers respect pairwise
amueller Jul 2, 2019
b13b57e
fix number of features in quantile transformer
amueller Jul 2, 2019
ac8d243
only check n_features_out_ if it's not None?
amueller Jul 2, 2019
822dae6
provide setter for n_features_out_
amueller Jul 30, 2019
211ebd5
Merge branch 'master' into n_features_out
amueller Jul 30, 2019
2933b8b
Merge branch 'master' into n_features_out
amueller Jul 30, 2019
42e5017
typo
amueller Jul 30, 2019
124e325
more typos
amueller Jul 30, 2019
041dfff
fix some input validation
amueller Jul 30, 2019
c1d47a1
move feature selection n_features_out_ to mixin
amueller Jul 30, 2019
7735d18
remove linear discriminant analysis special case
amueller Jul 30, 2019
aef2283
remove special case for clustering
amueller Jul 30, 2019
6a56572
I have no idea how this passed?!
amueller Jul 30, 2019
5c267ce
remove scaler special case, fix in imputation
amueller Jul 30, 2019
9a2e80c
removed the last bit of magic
amueller Jul 30, 2019
375c130
pep8
amueller Jul 30, 2019
0c2bb8a
add n_features_out_ to voting classifier
amueller Jul 30, 2019
9ccf0ac
add n_features_out_ to estimator in testing
amueller Jul 31, 2019
186a6d2
Merge branch 'master' into n_features_out
amueller Aug 1, 2019
6043a5c
check that n_components is integer
amueller Aug 1, 2019
a39ab07
pep8
amueller Aug 1, 2019
8039d8e
Merge branch 'master' into n_features_out
amueller Sep 9, 2019
344d01e
explicitly set n_features_out_ in clustering
amueller Sep 9, 2019
df82f64
add n_features_out_ to knnimputer
amueller Sep 9, 2019
816b677
Merge branch 'master' into n_features_out
amueller Sep 24, 2019
dcba760
add n_features_out_ to neighbors transformers
amueller Sep 24, 2019
d3b02e4
add n_features_out_ to feature union, add test
amueller Sep 24, 2019
e02d118
add n_features_out_ to ColumnTransformer and DictVectorizer
amueller Sep 24, 2019
c8db102
add n_features_out_ to stacking regressor and stacking classifier
amueller Sep 25, 2019
4b539af
add _last_non_passthrough_estimator
amueller Sep 25, 2019
43e9c96
add n_features_out_ to pipeline
amueller Sep 25, 2019
28 changes: 28 additions & 0 deletions sklearn/base.py
@@ -6,6 +6,7 @@
import copy
import warnings
from collections import defaultdict
import numbers
import platform
import inspect
import re
@@ -555,6 +556,33 @@ def fit_transform(self, X, y=None, **fit_params):
            # fit method of arity 2 (supervised transformation)
            return self.fit(X, y, **fit_params).transform(X)

    @property
    def n_features_out_(self):
        return self._n_features_out

Member: I can't say I like this magic determination. I'd rather it be done by specialised mixins for decomposition and feature selection.

Member (author): Sure, I could do that. I have to check how many of these are actually in the decomposition module.

Member (author): I'm not sure whether I prefer having these be mixins or base classes. It seems unlikely that you would want to mix them, and base classes make the code shorter.

    @n_features_out_.setter
    def n_features_out_(self, val):
        self._n_features_out = val


class ComponentsMixin:
    @property
    def n_features_out_(self):
        if hasattr(self, 'n_components_'):
Member: Should we consider deprecating n_components_, given the availability of n_features_out_?

Member (author): I would consider it ;)

            # n_components could be auto or None
            # this is more likely to be an int
Member: We can also include the isinstance(..., numbers.Integral) check here to be sure.

            n_features = self.n_components_
        elif hasattr(self, 'components_'):
            n_features = self.components_.shape[0]
        elif (hasattr(self, 'n_components')
                and isinstance(self.n_components, numbers.Integral)):
            n_features = self.n_components
        else:
            raise AttributeError(
                "{} has no attribute 'n_features_out_'".format(
                    type(self).__name__))
        return n_features
Member: Should there be a "default" value here? Maybe None?

Member (author): Here? This is for ComponentsMixin. If it doesn't have components, it probably shouldn't have the ComponentsMixin.

In general, the route I'm taking here is to require the user to set it to None and to error otherwise in the tests; that allowed me to test that it's actually implemented. What we want to enforce for third-party estimators is a slightly different discussion. This PR currently adds new requirements in check_estimator.



class DensityMixin:
"""Mixin class for all density estimators in scikit-learn."""
2 changes: 2 additions & 0 deletions sklearn/cluster/birch.py
@@ -493,6 +493,8 @@ def _fit(self, X):
        self.subcluster_centers_ = centroids

        self._global_clustering(X)
        self.n_features_out_ = self.n_clusters

        return self

    def _get_leaves(self):
1 change: 1 addition & 0 deletions sklearn/cluster/hierarchical.py
@@ -1036,6 +1036,7 @@ def fit(self, X, y=None, **params):
"""
X = check_array(X, accept_sparse=['csr', 'csc', 'coo'],
ensure_min_features=2, estimator=self)
self.n_features_out_ = self.n_clusters
return AgglomerativeClustering.fit(self, X.T, **params)

@property
3 changes: 3 additions & 0 deletions sklearn/cluster/k_means_.py
@@ -839,6 +839,7 @@ def fit(self, X, y=None, sample_weight=None):
"""
random_state = check_random_state(self.random_state)

self.n_features_out_ = self.n_clusters
n_init = self.n_init
if n_init <= 0:
raise ValueError("Invalid number of initializations."
@@ -1626,6 +1627,7 @@ def fit(self, X, y=None, sample_weight=None):
        if self.compute_labels:
            self.labels_, self.inertia_ = \
                self._labels_inertia_minibatch(X, sample_weight)
        self.n_features_out_ = self.n_clusters

        return self

@@ -1725,6 +1727,7 @@ def partial_fit(self, X, y=None, sample_weight=None):
        if self.compute_labels:
            self.labels_, self.inertia_ = _labels_inertia(
                X, sample_weight, x_squared_norms, self.cluster_centers_)
        self.n_features_out_ = self.n_clusters

        return self

12 changes: 12 additions & 0 deletions sklearn/compose/_column_transformer.py
@@ -360,6 +360,18 @@ def get_feature_names(self):
                                  trans.get_feature_names()])
        return feature_names

    @property
    def n_features_out_(self):
        n_features_out = 0
        for name, trans, column, _ in self._iter(fitted=True):
            if trans == 'drop':
                continue
            elif trans == 'passthrough':
                n_features_out += len(column)
            else:
                n_features_out += trans.n_features_out_
        return n_features_out

    def _update_fitted_transformers(self, transformers):
        # transformers are fitted; excludes 'drop' cases
        fitted_transformers = iter(transformers)
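
As a usage sketch (hypothetical until this PR lands), the property sums the output widths of the fitted transformers and counts passthrough columns via len(column):

import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder, StandardScaler

X = pd.DataFrame({'city': ['a', 'b', 'a'],
                  'x': [1., 2., 3.],
                  'y': [0., 1., 0.]})
ct = ColumnTransformer([
    ('ohe', OneHotEncoder(), ['city']),    # 2 one-hot columns
    ('scale', StandardScaler(), ['x']),    # 1 scaled column
    ('keep', 'passthrough', ['y']),        # counted as len(['y']) == 1
]).fit(X)
# 2 + 1 + 1 == 4 output features (n_features_out_ exists only with this PR)
assert ct.n_features_out_ == ct.transform(X).shape[1] == 4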
3 changes: 3 additions & 0 deletions sklearn/compose/tests/test_column_transformer.py
@@ -662,10 +662,12 @@ def test_column_transformer_get_feature_names():
        [('col' + str(i), DictVectorizer(), i) for i in range(2)])
    ct.fit(X)
    assert ct.get_feature_names() == ['col0__a', 'col0__b', 'col1__c']
    assert ct.n_features_out_ == len(ct.get_feature_names())

    # passthrough transformers not supported
    ct = ColumnTransformer([('trans', 'passthrough', [0, 1])])
    ct.fit(X)
    assert ct.n_features_out_ == 2
    assert_raise_message(
        NotImplementedError, 'get_feature_names is not yet supported',
        ct.get_feature_names)
@@ -682,6 +684,7 @@ def test_column_transformer_get_feature_names():
        [('col0', DictVectorizer(), 0), ('col1', 'drop', 1)])
    ct.fit(X)
    assert ct.get_feature_names() == ['col0__a', 'col0__b']
    assert ct.n_features_out_ == len(ct.get_feature_names())


def test_column_transformer_special_strings():
7 changes: 4 additions & 3 deletions sklearn/decomposition/base.py
@@ -11,13 +11,14 @@
import numpy as np
from scipy import linalg

from ..base import BaseEstimator, TransformerMixin
from ..base import BaseEstimator, TransformerMixin, ComponentsMixin
from ..utils import check_array
from ..utils.validation import check_is_fitted
from abc import ABCMeta, abstractmethod


class _BasePCA(TransformerMixin, BaseEstimator, metaclass=ABCMeta):
class _BasePCA(ComponentsMixin, TransformerMixin,
               BaseEstimator, metaclass=ABCMeta):
    """Base class for PCA methods.

    Warning: This class should not be used directly.
@@ -154,6 +155,6 @@ def inverse_transform(self, X):
        """
        if self.whiten:
            return np.dot(X, np.sqrt(self.explained_variance_[:, np.newaxis]) *
                self.components_) + self.mean_
                          self.components_) + self.mean_
        else:
            return np.dot(X, self.components_) + self.mean_
4 changes: 2 additions & 2 deletions sklearn/decomposition/dict_learning.py
@@ -13,7 +13,7 @@
from scipy import linalg
from joblib import Parallel, delayed, effective_n_jobs

from ..base import BaseEstimator, TransformerMixin
from ..base import BaseEstimator, TransformerMixin, ComponentsMixin
from ..utils import (check_array, check_random_state, gen_even_slices,
                     gen_batches)
from ..utils.extmath import randomized_svd, row_norms
@@ -875,7 +875,7 @@ def dict_learning_online(X, n_components=2, alpha=1, n_iter=100,
    return dictionary.T


class SparseCodingMixin(TransformerMixin):
class SparseCodingMixin(ComponentsMixin, TransformerMixin):
    """Sparse coding mixin"""

    def _set_sparse_coding_params(self, n_components,
4 changes: 2 additions & 2 deletions sklearn/decomposition/factor_analysis.py
@@ -25,14 +25,14 @@
from scipy import linalg


from ..base import BaseEstimator, TransformerMixin
from ..base import BaseEstimator, ComponentsMixin, TransformerMixin
from ..utils import check_array, check_random_state
from ..utils.extmath import fast_logdet, randomized_svd, squared_norm
from ..utils.validation import check_is_fitted
from ..exceptions import ConvergenceWarning


class FactorAnalysis(TransformerMixin, BaseEstimator):
class FactorAnalysis(ComponentsMixin, TransformerMixin, BaseEstimator):
"""Factor Analysis (FA)

A simple linear generative model with Gaussian latent variables.
4 changes: 2 additions & 2 deletions sklearn/decomposition/fastica_.py
@@ -14,7 +14,7 @@
import numpy as np
from scipy import linalg

from ..base import BaseEstimator, TransformerMixin
from ..base import BaseEstimator, TransformerMixin, ComponentsMixin
from ..exceptions import ConvergenceWarning

from ..utils import check_array, as_float_array, check_random_state
@@ -380,7 +380,7 @@ def g(x, fun_args):
        return None, W, S


class FastICA(TransformerMixin, BaseEstimator):
class FastICA(ComponentsMixin, TransformerMixin, BaseEstimator):
"""FastICA: a fast algorithm for Independent Component Analysis.

Read more in the :ref:`User Guide <ICA>`.
4 changes: 2 additions & 2 deletions sklearn/decomposition/kernel_pca.py
@@ -11,12 +11,12 @@
from ..utils.extmath import svd_flip
from ..utils.validation import check_is_fitted, check_array
from ..exceptions import NotFittedError
from ..base import BaseEstimator, TransformerMixin
from ..base import BaseEstimator, TransformerMixin, ComponentsMixin
from ..preprocessing import KernelCenterer
from ..metrics.pairwise import pairwise_kernels


class KernelPCA(TransformerMixin, BaseEstimator):
class KernelPCA(ComponentsMixin, TransformerMixin, BaseEstimator):
"""Kernel Principal component analysis (KPCA)

Non-linear dimensionality reduction through the use of kernels (see
4 changes: 2 additions & 2 deletions sklearn/decomposition/nmf.py
@@ -14,7 +14,7 @@
import numpy as np
import scipy.sparse as sp

from ..base import BaseEstimator, TransformerMixin
from ..base import BaseEstimator, TransformerMixin, ComponentsMixin
from ..utils import check_random_state, check_array
from ..utils.extmath import randomized_svd, safe_sparse_dot, squared_norm
from ..utils.validation import check_is_fitted, check_non_negative
@@ -1070,7 +1070,7 @@ def non_negative_factorization(X, W=None, H=None, n_components=None,
    return W, H, n_iter


class NMF(TransformerMixin, BaseEstimator):
class NMF(ComponentsMixin, TransformerMixin, BaseEstimator):
r"""Non-Negative Matrix Factorization (NMF)

Find two non-negative matrices (W, H) whose product approximates the non-
5 changes: 3 additions & 2 deletions sklearn/decomposition/online_lda.py
@@ -16,7 +16,7 @@
from scipy.special import gammaln
from joblib import Parallel, delayed, effective_n_jobs

from ..base import BaseEstimator, TransformerMixin
from ..base import BaseEstimator, TransformerMixin, ComponentsMixin
from ..utils import (check_random_state, check_array,
                     gen_batches, gen_even_slices)
from ..utils.fixes import logsumexp
@@ -132,7 +132,8 @@ def _update_doc_distribution(X, exp_topic_word_distr, doc_topic_prior,
    return (doc_topic_distr, suff_stats)


class LatentDirichletAllocation(TransformerMixin, BaseEstimator):
class LatentDirichletAllocation(ComponentsMixin, TransformerMixin,
                                BaseEstimator):
    """Latent Dirichlet Allocation with online variational Bayes algorithm

    .. versionadded:: 0.17
4 changes: 2 additions & 2 deletions sklearn/decomposition/sparse_pca.py
@@ -9,7 +9,7 @@
from ..utils import check_random_state, check_array
from ..utils.validation import check_is_fitted
from ..linear_model import ridge_regression
from ..base import BaseEstimator, TransformerMixin
from ..base import BaseEstimator, TransformerMixin, ComponentsMixin
from .dict_learning import dict_learning, dict_learning_online


@@ -29,7 +29,7 @@ def _check_normalize_components(normalize_components, estimator_name):
            )


class SparsePCA(TransformerMixin, BaseEstimator):
class SparsePCA(ComponentsMixin, TransformerMixin, BaseEstimator):
"""Sparse Principal Components Analysis (SparsePCA)

Finds the set of sparse components that can optimally reconstruct
4 changes: 2 additions & 2 deletions sklearn/decomposition/truncated_svd.py
@@ -10,15 +10,15 @@
import scipy.sparse as sp
from scipy.sparse.linalg import svds

from ..base import BaseEstimator, TransformerMixin
from ..base import BaseEstimator, TransformerMixin, ComponentsMixin
from ..utils import check_array, check_random_state
from ..utils.extmath import randomized_svd, safe_sparse_dot, svd_flip
from ..utils.sparsefuncs import mean_variance_axis

__all__ = ["TruncatedSVD"]


class TruncatedSVD(TransformerMixin, BaseEstimator):
class TruncatedSVD(ComponentsMixin, TransformerMixin, BaseEstimator):
"""Dimensionality reduction using truncated SVD (aka LSA).

This transformer performs linear dimensionality reduction by means of
5 changes: 5 additions & 0 deletions sklearn/discriminant_analysis.py
@@ -552,6 +552,11 @@ def predict_log_proba(self, X):
"""
return np.log(self.predict_proba(X))

@property
def n_features_out_(self):
n_components = self.n_components or np.inf
return min(self._max_components, n_components)


class QuadraticDiscriminantAnalysis(ClassifierMixin, BaseEstimator):
"""Quadratic Discriminant Analysis
11 changes: 9 additions & 2 deletions sklearn/ensemble/_stacking.py
@@ -458,7 +458,12 @@ def fit(self, X, y, sample_weight=None):
        check_classification_targets(y)
        self._le = LabelEncoder().fit(y)
        self.classes_ = self._le.classes_
        return super().fit(X, self._le.transform(y), sample_weight)
        super().fit(X, self._le.transform(y), sample_weight)
        if len(self.classes_) == 2:
            self.n_features_out_ = len(self.estimators_)
        else:
            self.n_features_out_ = len(self.estimators_) * len(self.classes_)
        return self

    @if_delegate_has_method(delegate='final_estimator_')
    def predict(self, X, **predict_params):
@@ -691,7 +696,9 @@ def fit(self, X, y, sample_weight=None):
        self : object
        """
        y = column_or_1d(y, warn=True)
        return super().fit(X, y, sample_weight)
        super().fit(X, y, sample_weight)
        self.n_features_out_ = len(self.estimators_)
        return self

    def transform(self, X):
        """Return the predictions for X for each estimator.
4 changes: 3 additions & 1 deletion sklearn/ensemble/forest.py
@@ -2196,7 +2196,9 @@ def fit_transform(self, X, y=None, sample_weight=None):
        super().fit(X, y, sample_weight=sample_weight)

        self.one_hot_encoder_ = OneHotEncoder(sparse=self.sparse_output)
        return self.one_hot_encoder_.fit_transform(self.apply(X))
        res = self.one_hot_encoder_.fit_transform(self.apply(X))
        self.n_features_out_ = res.shape[1]
        return res

    def transform(self, X):
        """Transform dataset.
5 changes: 4 additions & 1 deletion sklearn/ensemble/voting.py
@@ -269,7 +269,9 @@ def fit(self, X, y, sample_weight=None):
        self.le_ = LabelEncoder().fit(y)
        self.classes_ = self.le_.classes_
        transformed_y = self.le_.transform(y)

        self.n_features_out_ = len(self.estimators)
        if self.voting == 'soft':
            self.n_features_out_ *= len(self.classes_)
        return super().fit(X, transformed_y, sample_weight)

    def predict(self, X):
@@ -449,6 +451,7 @@ def fit(self, X, y, sample_weight=None):
        self : object
        """
        y = column_or_1d(y, warn=True)
        self.n_features_out_ = len(self.estimators)
        return super().fit(X, y, sample_weight)

    def predict(self, X):
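
A matching sketch for voting (hypothetical until this PR lands): soft voting concatenates per-class probabilities, so the width is n_estimators * n_classes, while hard voting emits one predicted-label column per estimator:

from sklearn.datasets import load_iris
from sklearn.ensemble import VotingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)  # 3 classes
soft = VotingClassifier([('lr', LogisticRegression(max_iter=1000)),
                         ('dt', DecisionTreeClassifier())],
                        voting='soft').fit(X, y)
# soft voting: 2 estimators * 3 classes == 6 transformed columns
assert soft.transform(X).shape[1] == 6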
2 changes: 2 additions & 0 deletions sklearn/feature_extraction/dict_vectorizer.py
@@ -128,6 +128,7 @@ def fit(self, X, y=None):
        vocab = {f: i for i, f in enumerate(feature_names)}

        self.feature_names_ = feature_names
        self.n_features_out_ = len(self.feature_names_)
        self.vocabulary_ = vocab

        return self
@@ -205,6 +206,7 @@ def _transform(self, X, fitting):
        if fitting:
            self.feature_names_ = feature_names
            self.vocabulary_ = vocab
            self.n_features_out_ = len(self.feature_names_)

        return result_matrix
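
A minimal usage sketch (assuming the lines above): the output width is simply the number of distinct feature names seen during fitting:

from sklearn.feature_extraction import DictVectorizer

v = DictVectorizer(sparse=False)
Xt = v.fit_transform([{'a': 1, 'b': 2}, {'b': 3, 'c': 4}])
# features 'a', 'b', 'c' -> 3 columns, matching len(v.feature_names_)
assert Xt.shape[1] == len(v.feature_names_) == 3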

4 changes: 4 additions & 0 deletions sklearn/feature_selection/base.py
@@ -46,6 +46,10 @@ def get_support(self, indices=False):
        mask = self._get_support_mask()
        return mask if not indices else np.where(mask)[0]

    @property
    def n_features_out_(self):
        return self.get_support().sum()

    @abstractmethod
    def _get_support_mask(self):
        """