From 9e6db314a80e2b05cd43542dfeb8a4d32dab68cb Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Tue, 2 Jul 2019 17:38:57 -0400 Subject: [PATCH 1/4] fresh start lol --- doc/modules/compose.rst | 38 ++++- examples/cluster/plot_inductive_clustering.py | 2 +- examples/compose/plot_column_transformer.py | 4 +- .../plot_column_transformer_mixed_types.py | 54 ++++++- .../plot_feature_selection_pipeline.py | 9 +- sklearn/base.py | 100 +++++++++++++ sklearn/calibration.py | 6 +- sklearn/cluster/affinity_propagation_.py | 2 +- sklearn/cluster/bicluster.py | 2 +- sklearn/cluster/birch.py | 2 +- sklearn/cluster/hierarchical.py | 2 +- sklearn/cluster/mean_shift_.py | 2 +- sklearn/cluster/optics_.py | 2 +- sklearn/cluster/spectral.py | 2 +- sklearn/compose/_column_transformer.py | 17 ++- sklearn/compose/_target.py | 2 +- .../compose/tests/test_column_transformer.py | 27 +++- sklearn/compose/tests/test_target.py | 6 +- sklearn/covariance/elliptic_envelope.py | 2 +- sklearn/covariance/empirical_covariance_.py | 2 +- sklearn/covariance/graph_lasso_.py | 4 +- sklearn/covariance/robust_covariance.py | 2 +- sklearn/covariance/shrunk_covariance_.py | 6 +- sklearn/cross_decomposition/pls_.py | 4 +- sklearn/decomposition/dict_learning.py | 6 +- sklearn/decomposition/factor_analysis.py | 2 +- sklearn/decomposition/fastica_.py | 4 +- sklearn/decomposition/incremental_pca.py | 2 +- sklearn/decomposition/kernel_pca.py | 2 +- sklearn/decomposition/online_lda.py | 4 +- sklearn/decomposition/pca.py | 4 +- sklearn/decomposition/sparse_pca.py | 4 +- sklearn/decomposition/truncated_svd.py | 4 +- sklearn/dummy.py | 4 +- .../_hist_gradient_boosting/binning.py | 2 +- sklearn/ensemble/bagging.py | 2 +- sklearn/ensemble/forest.py | 11 +- sklearn/ensemble/gradient_boosting.py | 10 +- .../ensemble/tests/test_weight_boosting.py | 4 +- sklearn/ensemble/voting.py | 6 +- sklearn/ensemble/weight_boosting.py | 6 +- sklearn/feature_extraction/dict_vectorizer.py | 6 +- sklearn/feature_extraction/image.py | 2 +- sklearn/feature_extraction/text.py | 10 +- sklearn/feature_selection/base.py | 15 ++ sklearn/feature_selection/tests/test_base.py | 2 +- .../tests/test_from_model.py | 2 +- .../feature_selection/variance_threshold.py | 2 +- sklearn/impute/_base.py | 30 +++- sklearn/impute/_iterative.py | 4 +- sklearn/isotonic.py | 2 +- sklearn/kernel_approximation.py | 8 +- sklearn/linear_model/base.py | 2 +- sklearn/linear_model/bayes.py | 2 +- sklearn/linear_model/huber.py | 2 +- sklearn/linear_model/logistic.py | 5 +- sklearn/linear_model/ransac.py | 2 +- sklearn/linear_model/ridge.py | 12 +- sklearn/linear_model/stochastic_gradient.py | 2 +- sklearn/manifold/isomap.py | 4 +- sklearn/manifold/locally_linear.py | 4 +- sklearn/manifold/spectral_embedding_.py | 4 +- sklearn/manifold/t_sne.py | 4 +- sklearn/mixture/base.py | 2 +- sklearn/model_selection/tests/test_search.py | 2 +- sklearn/multioutput.py | 4 +- sklearn/naive_bayes.py | 4 +- sklearn/neighbors/base.py | 2 +- sklearn/neighbors/lof.py | 2 +- sklearn/neural_network/rbm.py | 4 +- sklearn/pipeline.py | 43 ++++-- sklearn/preprocessing/_discretization.py | 2 +- sklearn/preprocessing/_encoders.py | 14 +- .../preprocessing/_function_transformer.py | 2 +- sklearn/preprocessing/data.py | 48 +++--- sklearn/random_projection.py | 2 +- sklearn/svm/base.py | 2 +- sklearn/svm/classes.py | 4 +- sklearn/tests/test_base.py | 76 +++++++++- sklearn/tests/test_pipeline.py | 138 ++++++++++++++++++ sklearn/utils/estimator_checks.py | 9 ++ sklearn/utils/tests/test_estimator_checks.py | 22 +-- 
sklearn/utils/tests/test_validation.py | 2 +- 83 files changed, 685 insertions(+), 201 deletions(-) diff --git a/doc/modules/compose.rst b/doc/modules/compose.rst index 0ac33ce7a4d4a..80852283e3c7e 100644 --- a/doc/modules/compose.rst +++ b/doc/modules/compose.rst @@ -136,6 +136,32 @@ or by name:: >>> pipe['reduce_dim'] PCA() +To enable model inspection, `Pipeline` sets an ``input_features_`` attribute on +all pipeline steps during fitting. This allows the user to understand how +features are transformed as they pass through a pipeline:: + + >>> from sklearn.datasets import load_iris + >>> from sklearn.feature_selection import SelectKBest + >>> iris = load_iris() + >>> pipe = Pipeline(steps=[ + ... ('select', SelectKBest(k=2)), + ... ('clf', LogisticRegression())]) + >>> pipe.fit(iris.data, iris.target) + ... # doctest: +NORMALIZE_WHITESPACE, +ELLIPSIS + Pipeline(memory=None, + steps=[('select', SelectKBest(...)), ('clf', LogisticRegression(...))]) + >>> pipe.named_steps.clf.input_features_ + array(['x2', 'x3'], dtype='<U2') + + >>> pipe.get_feature_names(iris.feature_names) + >>> pipe.named_steps.select.input_features_ + ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)'] + >>> pipe.named_steps.clf.input_features_ + array(['petal length (cm)', 'petal width (cm)'], dtype='<U17') >>> from sklearn.feature_extraction.text import CountVectorizer >>> from sklearn.preprocessing import OneHotEncoder >>> column_trans = ColumnTransformer( - ... [('city_category', OneHotEncoder(dtype='int'),['city']), + ... [('categories', OneHotEncoder(dtype='int'),['city']), ... ('title_bow', CountVectorizer(), 'title')], ... remainder='drop') @@ -441,11 +467,11 @@ By default, the remaining rating columns are ignored (``remainder='drop'``):: ('title_bow', CountVectorizer(), 'title')]) >>> column_trans.get_feature_names() - ['city_category__x0_London', 'city_category__x0_Paris', 'city_category__x0_Sallisaw', - 'title_bow__bow', 'title_bow__feast', 'title_bow__grapes', 'title_bow__his', - 'title_bow__how', 'title_bow__last', 'title_bow__learned', 'title_bow__moveable', - 'title_bow__of', 'title_bow__the', 'title_bow__trick', 'title_bow__watson', - 'title_bow__wrath'] + ['categories__city_London', 'categories__city_Paris', + 'categories__city_Sallisaw', 'title_bow__bow', 'title_bow__feast', + 'title_bow__grapes', 'title_bow__his', 'title_bow__how', 'title_bow__last', + 'title_bow__learned', 'title_bow__moveable', 'title_bow__of', 'title_bow__the', + 'title_bow__trick', 'title_bow__watson', 'title_bow__wrath'] >>> column_trans.transform(X).toarray() array([[1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0], diff --git a/examples/cluster/plot_inductive_clustering.py b/examples/cluster/plot_inductive_clustering.py index c5a51db5ef577..a477ce37f1fb0 100644 --- a/examples/cluster/plot_inductive_clustering.py +++ b/examples/cluster/plot_inductive_clustering.py @@ -40,7 +40,7 @@ def __init__(self, clusterer, classifier): self.clusterer = clusterer self.classifier = classifier - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): self.clusterer_ = clone(self.clusterer) self.classifier_ = clone(self.classifier) y = self.clusterer_.fit_predict(X) diff --git a/examples/compose/plot_column_transformer.py b/examples/compose/plot_column_transformer.py index 181e3e9127b56..7dd74d98cdded 100644 --- a/examples/compose/plot_column_transformer.py +++ b/examples/compose/plot_column_transformer.py @@ -45,7 +45,7 @@ class TextStats(BaseEstimator, TransformerMixin): """Extract features from each document for
DictVectorizer""" - def fit(self, x, y=None): + def fit(self, X, y=None, feature_names_in=None): return self def transform(self, posts): @@ -60,7 +60,7 @@ class SubjectBodyExtractor(BaseEstimator, TransformerMixin): Takes a sequence of strings and produces a dict of sequences. Keys are `subject` and `body`. """ - def fit(self, x, y=None): + def fit(self, X, y=None, feature_names_in=None): return self def transform(self, posts): diff --git a/examples/compose/plot_column_transformer_mixed_types.py b/examples/compose/plot_column_transformer_mixed_types.py index 264ae7495296c..c95f9494916df 100644 --- a/examples/compose/plot_column_transformer_mixed_types.py +++ b/examples/compose/plot_column_transformer_mixed_types.py @@ -68,16 +68,60 @@ # Append classifier to preprocessing pipeline. # Now we have a full prediction pipeline. -clf = Pipeline(steps=[('preprocessor', preprocessor), - ('classifier', LogisticRegression())]) +pipeline = Pipeline(steps=[('preprocessor', preprocessor), + ('classifier', LogisticRegression())]) X = data.drop('survived', axis=1) y = data['survived'] X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2) -clf.fit(X_train, y_train) -print("model score: %.3f" % clf.score(X_test, y_test)) +pipeline.fit(X_train, y_train) +print("model score: %.3f" % pipeline.score(X_test, y_test)) + + +############################################################################### +# Inspecting the coefficient values of the classifier +############################################################################### +# The coefficients of the final classification step of the pipeline give an +# idea of how each feature impacts the likelihood of survival, assuming that +# the usual linear model assumptions hold (uncorrelated features, linear +# separability, homoscedastic errors...), which we do not verify in this +# example. +# +# To get error bars we perform cross-validation and compute the mean and +# standard deviation for each coefficient across CV splits. Because we use a +# standard scaler on the numerical features, the coefficient weights give us +# an idea of how much the log odds of surviving are impacted by a change in +# this dimension relative to the mean. Note that the categorical features +# here are overspecified, which makes them slightly harder to interpret +# because of the information redundancy. +# +# We can see that the linear model coefficients are in agreement with the +# historical reports: people in higher classes and therefore in the upper decks +# were the first to reach the lifeboats, and often, priority was given to women +# and children. +# +# Note that conditioned on the "pclass_x" one-hot features, the "fare" +# numerical feature does not seem to be significantly predictive. If we drop +# the "pclass" feature, then higher "fare" values would appear significantly +# correlated with a higher likelihood of survival, as the "fare" and "pclass" +# features have a strong statistical dependency.
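+# How the error bars below are obtained: ``cross_validate`` refits the whole
+# pipeline on 20 stratified 75%/25% splits (``return_estimator=True`` keeps
+# each fitted pipeline), and the per-split ``coef_`` vectors of the final
+# classifier are stacked so that their mean and standard deviation can be
+# plotted against the ``input_features_`` names set during fitting.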
+ +import matplotlib.pyplot as plt +from sklearn.model_selection import cross_validate +from sklearn.model_selection import StratifiedShuffleSplit + +cv = StratifiedShuffleSplit(n_splits=20, test_size=0.25, random_state=42) +cv_results = cross_validate(pipeline, X_train, y_train, cv=cv, + return_estimator=True) +cv_coefs = np.concatenate([cv_pipeline.named_steps["classifier"].coef_ + for cv_pipeline in cv_results["estimator"]]) +fig, ax = plt.subplots() +ax.barh(pipeline.named_steps["classifier"].input_features_, + cv_coefs.mean(axis=0), xerr=cv_coefs.std(axis=0)) +plt.tight_layout() +plt.show() ############################################################################### @@ -96,7 +140,7 @@ 'classifier__C': [0.1, 1.0, 10, 100], } -grid_search = GridSearchCV(clf, param_grid, cv=10) +grid_search = GridSearchCV(pipeline, param_grid, cv=10) grid_search.fit(X_train, y_train) print(("best logistic regression from grid search: %.3f" diff --git a/examples/feature_selection/plot_feature_selection_pipeline.py b/examples/feature_selection/plot_feature_selection_pipeline.py index 47d4fb82e46ee..5eb9dd57e233b 100644 --- a/examples/feature_selection/plot_feature_selection_pipeline.py +++ b/examples/feature_selection/plot_feature_selection_pipeline.py @@ -9,6 +9,7 @@ Using a sub-pipeline, the fitted coefficients can be mapped back into the original feature space. """ +import matplotlib.pyplot as plt from sklearn import svm from sklearn.datasets import samples_generator from sklearn.feature_selection import SelectKBest, f_regression @@ -20,7 +21,7 @@ # import some data to play with X, y = samples_generator.make_classification( - n_features=20, n_informative=3, n_redundant=0, n_classes=4, + n_features=20, n_informative=3, n_redundant=0, n_classes=2, n_clusters_per_class=2) X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42) @@ -36,5 +37,7 @@ y_pred = anova_svm.predict(X_test) print(classification_report(y_test, y_pred)) -coef = anova_svm[:-1].inverse_transform(anova_svm['linearsvc'].coef_) -print(coef) +# access and plot the coefficients of the fitted model +plt.barh((0, 1, 2), anova_svm[-1].coef_.ravel()) +plt.yticks((0, 1, 2), anova_svm[-1].input_features_) +plt.show() diff --git a/sklearn/base.py b/sklearn/base.py index fb0818efc8248..a2fd84f443d9d 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -6,6 +6,7 @@ import copy import warnings from collections import defaultdict + import platform import inspect import re @@ -13,8 +14,11 @@ import numpy as np from . import __version__ +from .exceptions import NotFittedError + from .utils import _IS_32BIT + _DEFAULT_TAGS = { 'non_deterministic': False, 'requires_positive_X': False, @@ -558,6 +562,49 @@ def fit_transform(self, X, y=None, **fit_params): # fit method of arity 2 (supervised transformation) return self.fit(X, y, **fit_params).transform(X) + def _get_feature_names(self, input_features=None): + """Get output feature names. + + Parameters + ---------- + input_features : list of string or None + String names of the input features. + + Returns + ------- + output_feature_names : list of string + Feature names for transformer output. + """ + # OneToOneMixin is higher in the class hierarchy + # because we put mixins on the wrong side + if hasattr(super(), 'get_feature_names'): + return super().get_feature_names(input_features) + # generate feature names from class name by default + # there would be much less guessing if we stored the number + # of output features. + # Ideally this would be done in each class.
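+ # For illustration (hypothetical cases, not exercised in this patch):
+ # a fitted KMeans(n_clusters=3) would yield ['kmeans0', 'kmeans1', 'kmeans2']
+ # below, a fitted PCA(n_components=2) would hit the n_components_ branch
+ # and yield ['pca0', 'pca1'], and an estimator matching none of the checks
+ # returns None.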
+ if hasattr(self, 'n_clusters'): + # this is before n_components_ + # because n_components_ means something else + # in agglomerative clustering + n_features = self.n_clusters + elif hasattr(self, '_max_components'): + # special case for LinearDiscriminantAnalysis + n_components = self.n_components or np.inf + n_features = min(self._max_components, n_components) + elif hasattr(self, 'n_components_'): + # n_components could be auto or None + # this is more likely to be an int + n_features = self.n_components_ + elif hasattr(self, 'n_components') and self.n_components is not None: + n_features = self.n_components + elif hasattr(self, 'components_'): + n_features = self.components_.shape[0] + else: + return None + return ["{}{}".format(type(self).__name__.lower(), i) + for i in range(n_features)] + class DensityMixin: """Mixin class for all density estimators in scikit-learn.""" @@ -603,10 +650,63 @@ def fit_predict(self, X, y=None): return self.fit(X).predict(X) +class OneToOneMixin(object): + """Provides feature names for simple transformers + + Assumes there's a 1-to-1 correspondence between input features + and output features. + """ + + @property + def feature_names_out_(self): + return self.feature_names_in_ + + +def _get_sub_estimators(est): + # Explicitly declare all fitted subestimators of existing meta-estimators + sub_ests = [] + # OHE is not really needed + sub_names = ['estimator_', 'base_estimator_', 'one_hot_encoder_', + 'best_estimator_', 'init_'] + for name in sub_names: + sub_est = getattr(est, name, None) + if sub_est is not None: + sub_ests.append(sub_est) + if hasattr(est, "estimators_"): + if hasattr(est.estimators_, 'shape'): + sub_ests.extend(est.estimators_.ravel()) + else: + sub_ests.extend(est.estimators_) + return sub_ests + + class MetaEstimatorMixin: _required_parameters = ["estimator"] """Mixin class for all meta estimators in scikit-learn.""" + def _get_feature_names(self, input_features=None): + """Ensure feature names are set on sub-estimators + + Parameters + ---------- + input_features : list of string or None + Input features to the meta-estimator. + """ + sub_ests = _get_sub_estimators(self) + for est in sub_ests: + est.input_features_ = input_features + if hasattr(est, "get_feature_names"): + # doing hasattr instead of a try-except on everything + # b/c catching AttributeError makes recursive code + # impossible to debug + try: + est.get_feature_names(input_features=input_features) + except TypeError: + # do we need this? + est.get_feature_names() + except NotFittedError: + pass + class MultiOutputMixin(object): """Mixin to mark estimators that support multioutput.""" diff --git a/sklearn/calibration.py b/sklearn/calibration.py index 7473d60ded96d..e12e2504b5df9 100644 --- a/sklearn/calibration.py +++ b/sklearn/calibration.py @@ -109,7 +109,7 @@ def __init__(self, base_estimator=None, method='sigmoid', cv=None): self.method = method self.cv = cv - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None, feature_names_in=None): """Fit the calibrated model Parameters @@ -312,7 +312,7 @@ def _preproc(self, X): return df, idx_pos_class - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None, feature_names_in=None): """Calibrate the fitted model Parameters @@ -474,7 +474,7 @@ class _SigmoidCalibration(BaseEstimator, RegressorMixin): b_ : float The intercept.
""" - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None, feature_names_in=None): """Fit the model using X, y as training data. Parameters diff --git a/sklearn/cluster/affinity_propagation_.py b/sklearn/cluster/affinity_propagation_.py index 487ade4012133..d102a7bafeccf 100644 --- a/sklearn/cluster/affinity_propagation_.py +++ b/sklearn/cluster/affinity_propagation_.py @@ -349,7 +349,7 @@ def __init__(self, damping=.5, max_iter=200, convergence_iter=15, def _pairwise(self): return self.affinity == "precomputed" - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Fit the clustering from features, or affinity matrix. Parameters diff --git a/sklearn/cluster/bicluster.py b/sklearn/cluster/bicluster.py index 559bd515411f0..4bb92c6f09ec2 100644 --- a/sklearn/cluster/bicluster.py +++ b/sklearn/cluster/bicluster.py @@ -107,7 +107,7 @@ def _check_parameters(self): " one of {1}.".format(self.svd_method, legal_svd_methods)) - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Creates a biclustering for X. Parameters diff --git a/sklearn/cluster/birch.py b/sklearn/cluster/birch.py index 27b5038bb67a3..f79f2939a5830 100644 --- a/sklearn/cluster/birch.py +++ b/sklearn/cluster/birch.py @@ -429,7 +429,7 @@ def __init__(self, threshold=0.5, branching_factor=50, n_clusters=3, self.compute_labels = compute_labels self.copy = copy - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """ Build a CF Tree for the input data. diff --git a/sklearn/cluster/hierarchical.py b/sklearn/cluster/hierarchical.py index 17aae14589511..b9db6f2a47cc5 100644 --- a/sklearn/cluster/hierarchical.py +++ b/sklearn/cluster/hierarchical.py @@ -774,7 +774,7 @@ def __init__(self, n_clusters=2, affinity="euclidean", def n_components_(self): return self.n_connected_components_ - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Fit the hierarchical clustering from features, or distance matrix. Parameters diff --git a/sklearn/cluster/mean_shift_.py b/sklearn/cluster/mean_shift_.py index 960ac28984721..be113a0f0295a 100644 --- a/sklearn/cluster/mean_shift_.py +++ b/sklearn/cluster/mean_shift_.py @@ -403,7 +403,7 @@ def __init__(self, bandwidth=None, seeds=None, bin_seeding=False, self.min_bin_freq = min_bin_freq self.n_jobs = n_jobs - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Perform clustering. Parameters diff --git a/sklearn/cluster/optics_.py b/sklearn/cluster/optics_.py index 4f7eb11ab2f72..b591a31956a34 100755 --- a/sklearn/cluster/optics_.py +++ b/sklearn/cluster/optics_.py @@ -212,7 +212,7 @@ def __init__(self, min_samples=5, max_eps=np.inf, metric='minkowski', p=2, self.predecessor_correction = predecessor_correction self.n_jobs = n_jobs - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Perform OPTICS clustering Extracts an ordered list of points and reachability distances, and diff --git a/sklearn/cluster/spectral.py b/sklearn/cluster/spectral.py index 57631ef66dab1..429dab416631d 100644 --- a/sklearn/cluster/spectral.py +++ b/sklearn/cluster/spectral.py @@ -445,7 +445,7 @@ def __init__(self, n_clusters=8, eigen_solver=None, n_components=None, self.kernel_params = kernel_params self.n_jobs = n_jobs - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Perform spectral clustering from features, or affinity matrix. 
Parameters diff --git a/sklearn/compose/_column_transformer.py b/sklearn/compose/_column_transformer.py index 7b863ad46649c..fcfad2118b675 100644 --- a/sklearn/compose/_column_transformer.py +++ b/sklearn/compose/_column_transformer.py @@ -322,7 +322,7 @@ def named_transformers_(self): return Bunch(**{name: trans for name, trans, _ in self.transformers_}) - def get_feature_names(self): + def _get_feature_names(self): """Get feature names from all transformers. Returns @@ -332,19 +332,20 @@ def get_feature_names(self): """ check_is_fitted(self, 'transformers_') feature_names = [] - for name, trans, _, _ in self._iter(fitted=True): + for name, trans, columns, _ in self._iter(fitted=True): if trans == 'drop': continue elif trans == 'passthrough': raise NotImplementedError( "get_feature_names is not yet supported when using " "a 'passthrough' transformer.") - elif not hasattr(trans, 'get_feature_names'): + elif not hasattr(trans, 'feature_names_out_'): raise AttributeError("Transformer %s (type %s) does not " - "provide get_feature_names." + "provide feature_names_out_." % (str(name), type(trans).__name__)) + more_names = trans.feature_names_out_ feature_names.extend([name + "__" + f for f in - trans.get_feature_names()]) + more_names]) return feature_names def _update_fitted_transformers(self, transformers): @@ -415,7 +416,7 @@ def _fit_transform(self, X, y, func, fitted=False): else: raise - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Fit all transformers using X. Parameters @@ -438,7 +439,7 @@ def fit(self, X, y=None): self.fit_transform(X, y=y) return self - def fit_transform(self, X, y=None): + def fit_transform(self, X, y=None, feature_names_in=None): """Fit all transformers, transform the data and concatenate results. Parameters @@ -485,7 +486,7 @@ def fit_transform(self, X, y=None): self._update_fitted_transformers(transformers) self._validate_output(Xs) - + self.feature_names_out_ = self._get_feature_names() return self._hstack(list(Xs)) def transform(self, X): diff --git a/sklearn/compose/_target.py b/sklearn/compose/_target.py index c1c3f4df4e95f..3dd51d27fc208 100644 --- a/sklearn/compose/_target.py +++ b/sklearn/compose/_target.py @@ -148,7 +148,7 @@ def _fit_transformer(self, y): " you are sure you want to proceed regardless" ", set 'check_inverse=False'", UserWarning) - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None, feature_names_in=None): """Fit the model according to the given training data. 
Parameters diff --git a/sklearn/compose/tests/test_column_transformer.py b/sklearn/compose/tests/test_column_transformer.py index f1abbdccbdb42..9c6783b8c43f0 100644 --- a/sklearn/compose/tests/test_column_transformer.py +++ b/sklearn/compose/tests/test_column_transformer.py @@ -18,10 +18,11 @@ from sklearn.exceptions import NotFittedError from sklearn.preprocessing import StandardScaler, Normalizer, OneHotEncoder from sklearn.feature_extraction import DictVectorizer +from sklearn.pipeline import make_pipeline class Trans(BaseEstimator): - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): return self def transform(self, X, y=None): @@ -35,7 +36,7 @@ def transform(self, X, y=None): class DoubleTrans(BaseEstimator): - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): return self def transform(self, X): @@ -43,7 +44,7 @@ def transform(self, X): class SparseMatrixTrans(BaseEstimator): - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): return self def transform(self, X, y=None): @@ -52,7 +53,7 @@ def transform(self, X, y=None): class TransNo2D(BaseEstimator): - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): return self def transform(self, X, y=None): @@ -61,7 +62,7 @@ def transform(self, X, y=None): class TransRaise(BaseEstimator): - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): raise ValueError("specific message") def transform(self, X, y=None): @@ -220,7 +221,7 @@ def test_column_transformer_dataframe(): class TransAssert(BaseEstimator): - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): return self def transform(self, X, y=None): @@ -496,7 +497,7 @@ def test_column_transformer_invalid_columns(remainder): def test_column_transformer_invalid_transformer(): class NoTrans(BaseEstimator): - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): return self def predict(self, X): @@ -639,6 +640,18 @@ def test_column_transformer_get_feature_names(): "Transformer trans (type Trans) does not provide " "get_feature_names", ct.get_feature_names) + # if some transformers support and some don't + ct = ColumnTransformer([('trans', Trans(), [0, 1]), + ('scale', StandardScaler(), [0])]) + ct.fit(X_array) + assert_raise_message(AttributeError, + "Transformer trans (type Trans) does not provide " + "get_feature_names", ct.get_feature_names) + + # inside a pipeline + make_pipeline(ct).fit(X_array) + + # working example X = np.array([[{'a': 1, 'b': 2}, {'a': 3, 'b': 4}], [{'c': 5}, {'c': 6}]], dtype=object).T diff --git a/sklearn/compose/tests/test_target.py b/sklearn/compose/tests/test_target.py index 456850701be95..ca8a7aab6af90 100644 --- a/sklearn/compose/tests/test_target.py +++ b/sklearn/compose/tests/test_target.py @@ -228,7 +228,7 @@ def func(y): class DummyCheckerArrayTransformer(BaseEstimator, TransformerMixin): - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): assert isinstance(X, np.ndarray) return self @@ -243,7 +243,7 @@ def inverse_transform(self, X): class DummyCheckerListRegressor(DummyRegressor): - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None, feature_names_in=None): assert isinstance(X, list) return super().fit(X, y, sample_weight) @@ -271,7 +271,7 @@ class DummyTransformer(BaseEstimator, TransformerMixin): def __init__(self, fit_counter=0): self.fit_counter = fit_counter - def fit(self, X, y=None): + def fit(self, X, y=None, 
feature_names_in=None): self.fit_counter += 1 return self diff --git a/sklearn/covariance/elliptic_envelope.py b/sklearn/covariance/elliptic_envelope.py index 517f9a32dc9af..ffcc81a009ebe 100644 --- a/sklearn/covariance/elliptic_envelope.py +++ b/sklearn/covariance/elliptic_envelope.py @@ -114,7 +114,7 @@ def __init__(self, store_precision=True, assume_centered=False, random_state=random_state) self.contamination = contamination - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Fit the EllipticEnvelope model. Parameters diff --git a/sklearn/covariance/empirical_covariance_.py b/sklearn/covariance/empirical_covariance_.py index 924f7edd7ffee..32cfda6d0e6a9 100644 --- a/sklearn/covariance/empirical_covariance_.py +++ b/sklearn/covariance/empirical_covariance_.py @@ -173,7 +173,7 @@ def get_precision(self): precision = linalg.pinvh(self.covariance_) return precision - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Fits the Maximum Likelihood Estimator covariance model according to the given training data and parameters. diff --git a/sklearn/covariance/graph_lasso_.py b/sklearn/covariance/graph_lasso_.py index e78950bd60421..5a192cdd6686f 100644 --- a/sklearn/covariance/graph_lasso_.py +++ b/sklearn/covariance/graph_lasso_.py @@ -368,7 +368,7 @@ def __init__(self, alpha=.01, mode='cd', tol=1e-4, enet_tol=1e-4, self.max_iter = max_iter self.verbose = verbose - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Fits the GraphicalLasso model to X. Parameters @@ -635,7 +635,7 @@ def __init__(self, alphas=4, n_refinements=4, cv=None, tol=1e-4, self.cv = cv self.n_jobs = n_jobs - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Fits the GraphicalLasso covariance model to X. Parameters diff --git a/sklearn/covariance/robust_covariance.py b/sklearn/covariance/robust_covariance.py index 173794e5340c2..ea62aa6e6e275 100644 --- a/sklearn/covariance/robust_covariance.py +++ b/sklearn/covariance/robust_covariance.py @@ -619,7 +619,7 @@ def __init__(self, store_precision=True, assume_centered=False, self.support_fraction = support_fraction self.random_state = random_state - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Fits a Minimum Covariance Determinant with the FastMCD algorithm. Parameters diff --git a/sklearn/covariance/shrunk_covariance_.py b/sklearn/covariance/shrunk_covariance_.py index 2fce7138121fe..860732971cdee 100644 --- a/sklearn/covariance/shrunk_covariance_.py +++ b/sklearn/covariance/shrunk_covariance_.py @@ -129,7 +129,7 @@ def __init__(self, store_precision=True, assume_centered=False, assume_centered=assume_centered) self.shrinkage = shrinkage - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """ Fits the shrunk covariance model according to the given training data and parameters. @@ -404,7 +404,7 @@ def __init__(self, store_precision=True, assume_centered=False, assume_centered=assume_centered) self.block_size = block_size - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """ Fits the Ledoit-Wolf shrunk covariance model according to the given training data and parameters. @@ -559,7 +559,7 @@ class OAS(EmpiricalCovariance): """ - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """ Fits the Oracle Approximating Shrinkage covariance model according to the given training data and parameters. 
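The hunks above and below make the same mechanical change everywhere: ``fit`` grows an optional ``feature_names_in`` argument. A minimal sketch of the convention a one-to-one transformer follows under this patch (``CenteringTransformer`` is a hypothetical example, not part of the patch; the attribute handling mirrors what ``SimpleImputer.fit`` does later in this patch):

import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin


class CenteringTransformer(BaseEstimator, TransformerMixin):
    """Hypothetical transformer illustrating the feature-name convention."""

    def fit(self, X, y=None, feature_names_in=None):
        X = np.asarray(X)
        self.mean_ = X.mean(axis=0)
        # remember the names this estimator was fitted with; fall back to
        # generated x0, x1, ... placeholders when none are provided
        if feature_names_in is None:
            feature_names_in = ['x%d' % i for i in range(X.shape[1])]
        self.feature_names_in_ = list(feature_names_in)
        # one output column per input column: names pass through unchanged
        self.feature_names_out_ = self.feature_names_in_
        return self

    def transform(self, X):
        return np.asarray(X) - self.mean_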
diff --git a/sklearn/cross_decomposition/pls_.py b/sklearn/cross_decomposition/pls_.py index 175a472e6d4fb..9affa87544431 100644 --- a/sklearn/cross_decomposition/pls_.py +++ b/sklearn/cross_decomposition/pls_.py @@ -441,7 +441,7 @@ def predict(self, X, copy=True): Ypred = np.dot(X, self.coef_) return Ypred + self.y_mean_ - def fit_transform(self, X, y=None): + def fit_transform(self, X, y=None, feature_names_in=None): """Learn and apply the dimension reduction on the train data. Parameters @@ -884,7 +884,7 @@ def transform(self, X, Y=None): return x_scores, y_scores return x_scores - def fit_transform(self, X, y=None): + def fit_transform(self, X, y=None, feature_names_in=None): """Learn and apply the dimension reduction on the train data. Parameters diff --git a/sklearn/decomposition/dict_learning.py b/sklearn/decomposition/dict_learning.py index a318c957fe232..9faf3f38e78c8 100644 --- a/sklearn/decomposition/dict_learning.py +++ b/sklearn/decomposition/dict_learning.py @@ -999,7 +999,7 @@ def __init__(self, dictionary, transform_algorithm='omp', positive_code) self.components_ = dictionary - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Do nothing and return the estimator unchanged This method is just there to implement the usual API and hence @@ -1168,7 +1168,7 @@ def __init__(self, n_components=None, alpha=1, max_iter=1000, tol=1e-8, self.random_state = random_state self.positive_dict = positive_dict - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Fit the model from data in X. Parameters @@ -1358,7 +1358,7 @@ def __init__(self, n_components=None, alpha=1, n_iter=1000, self.random_state = random_state self.positive_dict = positive_dict - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Fit the model from data in X. Parameters diff --git a/sklearn/decomposition/factor_analysis.py b/sklearn/decomposition/factor_analysis.py index f5b1834643c5d..c849bb5ca3102 100644 --- a/sklearn/decomposition/factor_analysis.py +++ b/sklearn/decomposition/factor_analysis.py @@ -150,7 +150,7 @@ def __init__(self, n_components=None, tol=1e-2, copy=True, max_iter=1000, self.iterated_power = iterated_power self.random_state = random_state - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Fit the FactorAnalysis model to X using EM Parameters diff --git a/sklearn/decomposition/fastica_.py b/sklearn/decomposition/fastica_.py index dd04e8e93a1c6..83c10a6c78ab4 100644 --- a/sklearn/decomposition/fastica_.py +++ b/sklearn/decomposition/fastica_.py @@ -516,7 +516,7 @@ def _fit(self, X, compute_sources=False): return sources - def fit_transform(self, X, y=None): + def fit_transform(self, X, y=None, feature_names_in=None): """Fit the model and recover the sources from X. Parameters @@ -533,7 +533,7 @@ def fit_transform(self, X, y=None): """ return self._fit(X, compute_sources=True) - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Fit the model to X. Parameters diff --git a/sklearn/decomposition/incremental_pca.py b/sklearn/decomposition/incremental_pca.py index c6d611dcd5fea..557f93d21098c 100644 --- a/sklearn/decomposition/incremental_pca.py +++ b/sklearn/decomposition/incremental_pca.py @@ -166,7 +166,7 @@ def __init__(self, n_components=None, whiten=False, copy=True, self.copy = copy self.batch_size = batch_size - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Fit the model with X, using minibatches of size batch_size. 
Parameters diff --git a/sklearn/decomposition/kernel_pca.py b/sklearn/decomposition/kernel_pca.py index 555bd619c5a62..f8bbb6d86cad4 100644 --- a/sklearn/decomposition/kernel_pca.py +++ b/sklearn/decomposition/kernel_pca.py @@ -257,7 +257,7 @@ def _fit_inverse_transform(self, X_transformed, X): self.dual_coef_ = linalg.solve(K, X, sym_pos=True, overwrite_a=True) self.X_transformed_fit_ = X_transformed - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Fit the model from data in X. Parameters diff --git a/sklearn/decomposition/online_lda.py b/sklearn/decomposition/online_lda.py index c1d482f0a46c6..372dcae19ca97 100644 --- a/sklearn/decomposition/online_lda.py +++ b/sklearn/decomposition/online_lda.py @@ -469,7 +469,7 @@ def _check_non_neg_array(self, X, whom): check_non_negative(X, whom) return X - def partial_fit(self, X, y=None): + def partial_fit(self, X, y=None, feature_names_in=None): """Online VB with Mini-Batch update. Parameters @@ -510,7 +510,7 @@ def partial_fit(self, X, y=None): return self - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Learn model for the data X with variational Bayes method. When `learning_method` is 'online', use mini-batch update. diff --git a/sklearn/decomposition/pca.py b/sklearn/decomposition/pca.py index ccde667d0d20d..60c893b53a84f 100644 --- a/sklearn/decomposition/pca.py +++ b/sklearn/decomposition/pca.py @@ -319,7 +319,7 @@ def __init__(self, n_components=None, copy=True, whiten=False, self.iterated_power = iterated_power self.random_state = random_state - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Fit the model with X. Parameters @@ -338,7 +338,7 @@ def fit(self, X, y=None): self._fit(X) return self - def fit_transform(self, X, y=None): + def fit_transform(self, X, y=None, feature_names_in=None): """Fit the model with X and apply the dimensionality reduction on X. Parameters diff --git a/sklearn/decomposition/sparse_pca.py b/sklearn/decomposition/sparse_pca.py index 238f6cc4ef403..4e0f3cfe62532 100644 --- a/sklearn/decomposition/sparse_pca.py +++ b/sklearn/decomposition/sparse_pca.py @@ -149,7 +149,7 @@ def __init__(self, n_components=None, alpha=1, ridge_alpha=0.01, self.random_state = random_state self.normalize_components = normalize_components - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Fit the model from data in X. Parameters @@ -347,7 +347,7 @@ def __init__(self, n_components=None, alpha=1, ridge_alpha=0.01, self.batch_size = batch_size self.shuffle = shuffle - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Fit the model from data in X. Parameters diff --git a/sklearn/decomposition/truncated_svd.py b/sklearn/decomposition/truncated_svd.py index ce79fba2fad1d..b59082b235289 100644 --- a/sklearn/decomposition/truncated_svd.py +++ b/sklearn/decomposition/truncated_svd.py @@ -122,7 +122,7 @@ def __init__(self, n_components=2, algorithm="randomized", n_iter=5, self.random_state = random_state self.tol = tol - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Fit LSI model on training data X. Parameters @@ -140,7 +140,7 @@ def fit(self, X, y=None): self.fit_transform(X) return self - def fit_transform(self, X, y=None): + def fit_transform(self, X, y=None, feature_names_in=None): """Fit LSI model to X and perform dimensionality reduction on X. 
Parameters diff --git a/sklearn/dummy.py b/sklearn/dummy.py index 98ecef6f6c459..ec89cee6f8ba2 100644 --- a/sklearn/dummy.py +++ b/sklearn/dummy.py @@ -83,7 +83,7 @@ def __init__(self, strategy="stratified", random_state=None, self.random_state = random_state self.constant = constant - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None, feature_names_in=None): """Fit the random classifier. Parameters @@ -399,7 +399,7 @@ def __init__(self, strategy="mean", constant=None, quantile=None): self.constant = constant self.quantile = quantile - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None, feature_names_in=None): """Fit the random regressor. Parameters diff --git a/sklearn/ensemble/_hist_gradient_boosting/binning.py b/sklearn/ensemble/_hist_gradient_boosting/binning.py index 34bd43cde4061..c0b05ad1dce2f 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/binning.py +++ b/sklearn/ensemble/_hist_gradient_boosting/binning.py @@ -104,7 +104,7 @@ def __init__(self, max_bins=256, subsample=int(2e5), random_state=None): self.subsample = subsample self.random_state = random_state - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Fit data X by computing the binning thresholds. Parameters diff --git a/sklearn/ensemble/bagging.py b/sklearn/ensemble/bagging.py index 15096afefa810..f1e2eab05b2dd 100644 --- a/sklearn/ensemble/bagging.py +++ b/sklearn/ensemble/bagging.py @@ -216,7 +216,7 @@ def __init__(self, self.random_state = random_state self.verbose = verbose - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None, feature_names_in=None): """Build a Bagging ensemble of estimators from the training set (X, y). diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py index 6050fd2773a5f..3eed777909f95 100644 --- a/sklearn/ensemble/forest.py +++ b/sklearn/ensemble/forest.py @@ -214,7 +214,7 @@ def decision_path(self, X): return sparse_hstack(indicators).tocsr(), n_nodes_ptr - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None, feature_names_in=None): """Build a forest of trees from the training set (X, y). Parameters @@ -2012,3 +2012,12 @@ def transform(self, X): """ check_is_fitted(self, 'one_hot_encoder_') return self.one_hot_encoder_.transform(self.apply(X)) + + def _get_feature_names(self, input_features=None): + """Feature names - not implemented yet. + + Parameters + ---------- + input_features : list of strings or None + """ + return None \ No newline at end of file diff --git a/sklearn/ensemble/gradient_boosting.py b/sklearn/ensemble/gradient_boosting.py index f6401136fcbc6..2e6f5c3c5aaec 100644 --- a/sklearn/ensemble/gradient_boosting.py +++ b/sklearn/ensemble/gradient_boosting.py @@ -81,7 +81,7 @@ def __init__(self, alpha=0.9): raise ValueError("`alpha` must be in (0, 1.0) but was %r" % alpha) self.alpha = alpha - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None, feature_names_in=None): """Fit the estimator. Parameters @@ -125,7 +125,7 @@ def predict(self, X): "0.21 and will be removed in version 0.23.") class MeanEstimator: """An estimator predicting the mean of the training targets.""" - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None, feature_names_in=None): """Fit the estimator. 
Parameters @@ -170,7 +170,7 @@ class LogOddsEstimator: """An estimator predicting the log odds ratio.""" scale = 1.0 - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None, feature_names_in=None): """Fit the estimator. Parameters @@ -229,7 +229,7 @@ class PriorProbabilityEstimator: """An estimator predicting the probability of each class in the training data. """ - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None, feature_names_in=None): """Fit the estimator. Parameters @@ -279,7 +279,7 @@ class ZeroEstimator: """ - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None, feature_names_in=None): """Fit the estimator. Parameters diff --git a/sklearn/ensemble/tests/test_weight_boosting.py b/sklearn/ensemble/tests/test_weight_boosting.py index 1cb1e9d1431cf..b5695f82374b1 100755 --- a/sklearn/ensemble/tests/test_weight_boosting.py +++ b/sklearn/ensemble/tests/test_weight_boosting.py @@ -320,7 +320,7 @@ def test_sparse_classification(): class CustomSVC(SVC): """SVC variant that records the nature of the training set.""" - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None, feature_names_in=None): """Modification on fit carries data type for later verification.""" super().fit(X, y, sample_weight=sample_weight) self.data_type_ = type(X) @@ -417,7 +417,7 @@ def test_sparse_regression(): class CustomSVR(SVR): """SVR variant that records the nature of the training set.""" - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None, feature_names_in=None): """Modification on fit carries data type for later verification.""" super().fit(X, y, sample_weight=sample_weight) self.data_type_ = type(X) diff --git a/sklearn/ensemble/voting.py b/sklearn/ensemble/voting.py index 0b01340d4f1af..13bc60fa0cf72 100644 --- a/sklearn/ensemble/voting.py +++ b/sklearn/ensemble/voting.py @@ -70,7 +70,7 @@ def _predict(self, X): return np.asarray([clf.predict(X) for clf in self.estimators_]).T @abstractmethod - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None, feature_names_in=None): """ common fit operations. """ @@ -243,7 +243,7 @@ def __init__(self, estimators, voting='hard', weights=None, n_jobs=None, self.n_jobs = n_jobs self.flatten_transform = flatten_transform - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None, feature_names_in=None): """ Fit the estimators. Parameters @@ -433,7 +433,7 @@ def __init__(self, estimators, weights=None, n_jobs=None): self.weights = weights self.n_jobs = n_jobs - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None, feature_names_in=None): """ Fit the estimators. Parameters diff --git a/sklearn/ensemble/weight_boosting.py b/sklearn/ensemble/weight_boosting.py index 3cb4baa0d9a0c..e3fbc1368e7c3 100644 --- a/sklearn/ensemble/weight_boosting.py +++ b/sklearn/ensemble/weight_boosting.py @@ -89,7 +89,7 @@ def _validate_data(self, X, y=None): y_numeric=is_regressor(self)) return ret - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None, feature_names_in=None): """Build a boosted classifier/regressor from the training set (X, y). Parameters @@ -398,7 +398,7 @@ def __init__(self, self.algorithm = algorithm - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None, feature_names_in=None): """Build a boosted classifier from the training set (X, y).
Parameters @@ -965,7 +965,7 @@ def __init__(self, self.loss = loss self.random_state = random_state - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None, feature_names_in=None): """Build a boosted regressor from the training set (X, y). Parameters diff --git a/sklearn/feature_extraction/dict_vectorizer.py b/sklearn/feature_extraction/dict_vectorizer.py index 4a2aa58189c93..190fa96e75199 100644 --- a/sklearn/feature_extraction/dict_vectorizer.py +++ b/sklearn/feature_extraction/dict_vectorizer.py @@ -98,7 +98,7 @@ def __init__(self, dtype=np.float64, separator="=", sparse=True, self.sparse = sparse self.sort = sort - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Learn a list of feature name -> indices mappings. Parameters @@ -208,7 +208,7 @@ def _transform(self, X, fitting): return result_matrix - def fit_transform(self, X, y=None): + def fit_transform(self, X, y=None, feature_names_in=None): """Learn a list of feature name -> indices mappings and transform X. Like fit(X) followed by transform(X), but does not require @@ -307,7 +307,7 @@ def transform(self, X): return Xa - def get_feature_names(self): + def _get_feature_names(self): """Returns a list of feature names, ordered by their indices. If one-of-K coding is applied to categorical features, this will diff --git a/sklearn/feature_extraction/image.py b/sklearn/feature_extraction/image.py index aa0f445fd3ee8..4e06aacaf3955 100644 --- a/sklearn/feature_extraction/image.py +++ b/sklearn/feature_extraction/image.py @@ -480,7 +480,7 @@ def __init__(self, patch_size=None, max_patches=None, random_state=None): self.max_patches = max_patches self.random_state = random_state - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Do nothing and return the estimator unchanged This method is just there to implement the usual API and hence diff --git a/sklearn/feature_extraction/text.py b/sklearn/feature_extraction/text.py index 8b60b59e0dd02..ca53510c26c3c 100644 --- a/sklearn/feature_extraction/text.py +++ b/sklearn/feature_extraction/text.py @@ -586,7 +586,7 @@ def __init__(self, input='content', encoding='utf-8', self.alternate_sign = alternate_sign self.dtype = dtype - def partial_fit(self, X, y=None): + def partial_fit(self, X, y=None, feature_names_in=None): """Does nothing: this transformer is stateless. This method is just there to mark the fact that this transformer @@ -599,7 +599,7 @@ def partial_fit(self, X, y=None): """ return self - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Does nothing: this transformer is stateless. Parameters @@ -648,7 +648,7 @@ def transform(self, X): X = normalize(X, norm=self.norm, copy=False) return X - def fit_transform(self, X, y=None): + def fit_transform(self, X, y=None, feature_names_in=None): """Transform a sequence of documents to a document-term matrix. 
Parameters @@ -1144,7 +1144,7 @@ def inverse_transform(self, X): return [inverse_vocabulary[X[i, :].nonzero()[1]].ravel() for i in range(n_samples)] - def get_feature_names(self): + def _get_feature_names(self): """Array mapping from feature integer indices to feature name""" if not hasattr(self, 'vocabulary_'): self._validate_vocabulary() @@ -1251,7 +1251,7 @@ def __init__(self, norm='l2', use_idf=True, smooth_idf=True, self.smooth_idf = smooth_idf self.sublinear_tf = sublinear_tf - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Learn the idf vector (global term weights) Parameters diff --git a/sklearn/feature_selection/base.py b/sklearn/feature_selection/base.py index 5add330188f78..1c3801a69c12f 100644 --- a/sklearn/feature_selection/base.py +++ b/sklearn/feature_selection/base.py @@ -119,3 +119,18 @@ def inverse_transform(self, X): Xt = np.zeros((X.shape[0], support.size), dtype=X.dtype) Xt[:, support] = X return Xt + + def _get_feature_names(self, input_features=None): + """Mask feature names according to selected features. + + Parameters + ---------- + input_features : list of string or None + Input features to select from. If None, they are generated as + x0, x1, ..., xn. + """ + mask = self.get_support() + if input_features is None: + input_features = ['x%d' % i + for i in range(mask.shape[0])] + return np.array(input_features)[mask] diff --git a/sklearn/feature_selection/tests/test_base.py b/sklearn/feature_selection/tests/test_base.py index f75f1789243fc..da09865977c45 100644 --- a/sklearn/feature_selection/tests/test_base.py +++ b/sklearn/feature_selection/tests/test_base.py @@ -14,7 +14,7 @@ class StepSelector(SelectorMixin, BaseEstimator): def __init__(self, step=2): self.step = step - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): X = check_array(X, 'csc') self.n_input_feats = X.shape[1] return self diff --git a/sklearn/feature_selection/tests/test_from_model.py b/sklearn/feature_selection/tests/test_from_model.py index 3c281c552c7d5..3bcfe451755d7 100644 --- a/sklearn/feature_selection/tests/test_from_model.py +++ b/sklearn/feature_selection/tests/test_from_model.py @@ -70,7 +70,7 @@ class FixedImportanceEstimator(BaseEstimator): def __init__(self, importances): self.importances = importances - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): self.feature_importances_ = np.array(self.importances) diff --git a/sklearn/feature_selection/variance_threshold.py b/sklearn/feature_selection/variance_threshold.py index 7d98de82c9711..bf1b4a085347f 100644 --- a/sklearn/feature_selection/variance_threshold.py +++ b/sklearn/feature_selection/variance_threshold.py @@ -45,7 +45,7 @@ class VarianceThreshold(BaseEstimator, SelectorMixin): def __init__(self, threshold=0.): self.threshold = threshold - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Learn empirical variances from X. Parameters diff --git a/sklearn/impute/_base.py b/sklearn/impute/_base.py index 51fb223860daf..0d4f586e571f8 100644 --- a/sklearn/impute/_base.py +++ b/sklearn/impute/_base.py @@ -216,7 +216,7 @@ def _validate_input(self, X): return X - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Fit the imputer on X.
Parameters @@ -230,6 +230,7 @@ def fit(self, X, y=None): self : SimpleImputer """ X = self._validate_input(X) + self.feature_names_in_ = feature_names_in # default fill_value is 0 for numerical input and "missing_value" # otherwise @@ -273,7 +274,9 @@ def fit(self, X, y=None): self.indicator_.fit(X) else: self.indicator_ = None - + invalid_mask = _get_mask(self.statistics_, np.nan) + self._valid_mask = np.logical_not(invalid_mask) + self.feature_names_out_ = self._get_feature_names(feature_names_in) return self def _sparse_fit(self, X, strategy, missing_values, fill_value): @@ -433,6 +436,25 @@ def transform(self, X): def _more_tags(self): return {'allow_nan': True} + def _get_feature_names(self, input_features=None): + """Get feature names for transformation. + + Parameters + ---------- + input_features : array-like of string + Input feature names. + + Returns + ------- + feature_names : array-like of string + Transformed feature names + """ + check_is_fitted(self, 'statistics_') + if input_features is None: + input_features = ['x%d' % i + for i in range(self.statistics_.shape[0])] + return np.array(input_features)[self._valid_mask] + class MissingIndicator(BaseEstimator, TransformerMixin): """Binary indicators for missing values. @@ -586,7 +608,7 @@ def _validate_input(self, X): return X - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Fit the transformer on X. Parameters @@ -652,7 +674,7 @@ def transform(self, X): return imputer_mask - def fit_transform(self, X, y=None): + def fit_transform(self, X, y=None, feature_names_in=None): """Generate missing values indicator for X. Parameters diff --git a/sklearn/impute/_iterative.py b/sklearn/impute/_iterative.py index 4c3fa4f2c1872..2930f611c1355 100644 --- a/sklearn/impute/_iterative.py +++ b/sklearn/impute/_iterative.py @@ -492,7 +492,7 @@ def _initial_imputation(self, X): return Xt, X_filled, mask_missing_values - def fit_transform(self, X, y=None): + def fit_transform(self, X, y=None, feature_names_in=None): """Fits the imputer on X and return the transformed X. Parameters @@ -661,7 +661,7 @@ def transform(self, X): Xt = np.hstack((Xt, X_trans_indicator)) return Xt - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Fits the imputer on X and return self. Parameters diff --git a/sklearn/isotonic.py b/sklearn/isotonic.py index 40beb3abcab73..32588043e9b1c 100644 --- a/sklearn/isotonic.py +++ b/sklearn/isotonic.py @@ -299,7 +299,7 @@ def _build_y(self, X, y, sample_weight, trim_duplicates=True): # prediction speed). return X, y - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None, feature_names_in=None): """Fit the model using X, y as training data. Parameters diff --git a/sklearn/kernel_approximation.py b/sklearn/kernel_approximation.py index 4f8fd96a10fa6..a1b5cb3b6d0ea 100644 --- a/sklearn/kernel_approximation.py +++ b/sklearn/kernel_approximation.py @@ -74,7 +74,7 @@ def __init__(self, gamma=1., n_components=100, random_state=None): self.n_components = n_components self.random_state = random_state - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Fit the model with X. Samples random projection according to n_features. @@ -180,7 +180,7 @@ def __init__(self, skewedness=1., n_components=100, random_state=None): self.n_components = n_components self.random_state = random_state - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Fit the model with X. 
Samples random projection according to n_features. @@ -304,7 +304,7 @@ def __init__(self, sample_steps=2, sample_interval=None): self.sample_steps = sample_steps self.sample_interval = sample_interval - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Set the parameters Parameters @@ -522,7 +522,7 @@ def __init__(self, kernel="rbf", gamma=None, coef0=None, degree=None, self.n_components = n_components self.random_state = random_state - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Fit estimator to data. Samples a subset of training points, computes kernel diff --git a/sklearn/linear_model/base.py b/sklearn/linear_model/base.py index 79be362d9b628..7aaa5a6526bb1 100644 --- a/sklearn/linear_model/base.py +++ b/sklearn/linear_model/base.py @@ -432,7 +432,7 @@ def __init__(self, fit_intercept=True, normalize=False, copy_X=True, self.copy_X = copy_X self.n_jobs = n_jobs - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None, feature_names_in=None): """ Fit linear model. diff --git a/sklearn/linear_model/bayes.py b/sklearn/linear_model/bayes.py index c4ae0a6437ada..0615afb710545 100644 --- a/sklearn/linear_model/bayes.py +++ b/sklearn/linear_model/bayes.py @@ -168,7 +168,7 @@ def __init__(self, n_iter=300, tol=1.e-3, alpha_1=1.e-6, alpha_2=1.e-6, self.copy_X = copy_X self.verbose = verbose - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None, feature_names_in=None): """Fit the model Parameters diff --git a/sklearn/linear_model/huber.py b/sklearn/linear_model/huber.py index 0a4b6e10f6f98..b64c8e3412d9f 100644 --- a/sklearn/linear_model/huber.py +++ b/sklearn/linear_model/huber.py @@ -231,7 +231,7 @@ def __init__(self, epsilon=1.35, max_iter=100, alpha=0.0001, self.fit_intercept = fit_intercept self.tol = tol - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None, feature_names_in=None): """Fit the model according to the given training data. Parameters diff --git a/sklearn/linear_model/logistic.py b/sklearn/linear_model/logistic.py index 47ecd8ade736f..efeacc75a0f56 100644 --- a/sklearn/linear_model/logistic.py +++ b/sklearn/linear_model/logistic.py @@ -1450,7 +1450,7 @@ def __init__(self, penalty='l2', dual=False, tol=1e-4, C=1.0, self.n_jobs = n_jobs self.l1_ratio = l1_ratio - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None, feature_names_in=None): """Fit the model according to the given training data. Parameters @@ -1477,6 +1477,7 @@ def fit(self, X, y, sample_weight=None): ----- The SAGA solver supports both float64 and float32 bit arrays. """ + self.feature_names_in_ = feature_names_in solver = _check_solver(self.solver, self.penalty, self.dual) if not isinstance(self.C, numbers.Number) or self.C < 0: @@ -1933,7 +1934,7 @@ def __init__(self, Cs=10, fit_intercept=True, cv=None, dual=False, self.random_state = random_state self.l1_ratios = l1_ratios - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None, feature_names_in=None): """Fit the model according to the given training data. 
Parameters diff --git a/sklearn/linear_model/ransac.py b/sklearn/linear_model/ransac.py index 7f4fb650b59e8..423a4a5b6835d 100644 --- a/sklearn/linear_model/ransac.py +++ b/sklearn/linear_model/ransac.py @@ -227,7 +227,7 @@ def __init__(self, base_estimator=None, min_samples=None, self.random_state = random_state self.loss = loss - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None, feature_names_in=None): """Fit estimator using RANSAC algorithm. Parameters diff --git a/sklearn/linear_model/ridge.py b/sklearn/linear_model/ridge.py index 45862d5f3cffb..e7c60443dac87 100644 --- a/sklearn/linear_model/ridge.py +++ b/sklearn/linear_model/ridge.py @@ -535,7 +535,7 @@ def __init__(self, alpha=1.0, fit_intercept=True, normalize=False, self.solver = solver self.random_state = random_state - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None, feature_names_in=None): # all other solvers work at both float precision levels _dtype = [np.float64, np.float32] @@ -725,7 +725,7 @@ def __init__(self, alpha=1.0, fit_intercept=True, normalize=False, max_iter=max_iter, tol=tol, solver=solver, random_state=random_state) - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None, feature_names_in=None): """Fit Ridge regression model Parameters @@ -877,7 +877,7 @@ def __init__(self, alpha=1.0, fit_intercept=True, normalize=False, random_state=random_state) self.class_weight = class_weight - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None, feature_names_in=None): """Fit Ridge regression model. Parameters @@ -1379,7 +1379,7 @@ def _solve_svd_design_matrix( G_inverse_diag = G_inverse_diag[:, np.newaxis] return G_inverse_diag, c - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None, feature_names_in=None): """Fit Ridge regression model Parameters @@ -1498,7 +1498,7 @@ def __init__(self, alphas=(0.1, 1.0, 10.0), self.gcv_mode = gcv_mode self.store_cv_values = store_cv_values - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None, feature_names_in=None): """Fit Ridge regression model Parameters @@ -1785,7 +1785,7 @@ def __init__(self, alphas=(0.1, 1.0, 10.0), fit_intercept=True, scoring=scoring, cv=cv, store_cv_values=store_cv_values) self.class_weight = class_weight - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None, feature_names_in=None): """Fit the ridge classifier. Parameters diff --git a/sklearn/linear_model/stochastic_gradient.py b/sklearn/linear_model/stochastic_gradient.py index 625bdb5bdc3f9..02f065da9ae75 100644 --- a/sklearn/linear_model/stochastic_gradient.py +++ b/sklearn/linear_model/stochastic_gradient.py @@ -1115,7 +1115,7 @@ def _partial_fit(self, X, y, alpha, C, loss, learning_rate, return self - def partial_fit(self, X, y, sample_weight=None): + def partial_fit(self, X, y, sample_weight=None, feature_names_in=None): """Perform one epoch of stochastic gradient descent on given samples. Internally, this method uses ``max_iter = 1``. 
Therefore, it is not diff --git a/sklearn/manifold/isomap.py b/sklearn/manifold/isomap.py index 88c979c0e1fdb..8056bebcdef41 100644 --- a/sklearn/manifold/isomap.py +++ b/sklearn/manifold/isomap.py @@ -161,7 +161,7 @@ def reconstruction_error(self): evals = self.kernel_pca_.lambdas_ return np.sqrt(np.sum(G_center ** 2) - np.sum(evals ** 2)) / G.shape[0] - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Compute the embedding vectors for data X Parameters @@ -180,7 +180,7 @@ def fit(self, X, y=None): self._fit_transform(X) return self - def fit_transform(self, X, y=None): + def fit_transform(self, X, y=None, feature_names_in=None): """Fit the model from data in X and transform X. Parameters diff --git a/sklearn/manifold/locally_linear.py b/sklearn/manifold/locally_linear.py index cf3c58486c27a..6d308e1b857d2 100644 --- a/sklearn/manifold/locally_linear.py +++ b/sklearn/manifold/locally_linear.py @@ -666,7 +666,7 @@ def _fit_transform(self, X): hessian_tol=self.hessian_tol, modified_tol=self.modified_tol, random_state=random_state, reg=self.reg, n_jobs=self.n_jobs) - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Compute the embedding vectors for data X Parameters @@ -683,7 +683,7 @@ def fit(self, X, y=None): self._fit_transform(X) return self - def fit_transform(self, X, y=None): + def fit_transform(self, X, y=None, feature_names_in=None): """Compute the embedding vectors for data X and transform X. Parameters diff --git a/sklearn/manifold/spectral_embedding_.py b/sklearn/manifold/spectral_embedding_.py index 42227db8a72ad..f7c8f3d4e2377 100644 --- a/sklearn/manifold/spectral_embedding_.py +++ b/sklearn/manifold/spectral_embedding_.py @@ -486,7 +486,7 @@ def _get_affinity_matrix(self, X, Y=None): self.affinity_matrix_ = self.affinity(X) return self.affinity_matrix_ - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Fit the model from data in X. Parameters @@ -526,7 +526,7 @@ def fit(self, X, y=None): random_state=random_state) return self - def fit_transform(self, X, y=None): + def fit_transform(self, X, y=None, feature_names_in=None): """Fit the model from data in X and transform X. Parameters diff --git a/sklearn/manifold/t_sne.py b/sklearn/manifold/t_sne.py index 987f3af05a941..b1204ef6b33b4 100644 --- a/sklearn/manifold/t_sne.py +++ b/sklearn/manifold/t_sne.py @@ -861,7 +861,7 @@ def _tsne(self, P, degrees_of_freedom, n_samples, X_embedded, return X_embedded - def fit_transform(self, X, y=None): + def fit_transform(self, X, y=None, feature_names_in=None): """Fit X into an embedded space and return that transformed output. @@ -882,7 +882,7 @@ def fit_transform(self, X, y=None): self.embedding_ = embedding return self.embedding_ - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Fit X into an embedded space. Parameters diff --git a/sklearn/mixture/base.py b/sklearn/mixture/base.py index 8920bef181226..962d13458c027 100644 --- a/sklearn/mixture/base.py +++ b/sklearn/mixture/base.py @@ -166,7 +166,7 @@ def _initialize(self, X, resp): """ pass - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Estimate model parameters with the EM algorithm. 
The method fits the model ``n_init`` times and sets the parameters with diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py index 90a837e7f49f1..08759ec9267ab 100644 --- a/sklearn/model_selection/tests/test_search.py +++ b/sklearn/model_selection/tests/test_search.py @@ -1400,7 +1400,7 @@ class FailingClassifier(BaseEstimator): def __init__(self, parameter=None): self.parameter = parameter - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): if self.parameter == FailingClassifier.FAILING_PARAMETER: raise ValueError("Failing classifier failed as required") diff --git a/sklearn/multioutput.py b/sklearn/multioutput.py index 4411919c1821f..c700fbb2964a0 100644 --- a/sklearn/multioutput.py +++ b/sklearn/multioutput.py @@ -121,7 +121,7 @@ def partial_fit(self, X, y, classes=None, sample_weight=None): sample_weight, first_time) for i in range(y.shape[1])) return self - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None, feature_names_in=None): """ Fit the model to data. Fit a separate model for each output variable. @@ -234,7 +234,7 @@ def __init__(self, estimator, n_jobs=None): super().__init__(estimator, n_jobs) @if_delegate_has_method('estimator') - def partial_fit(self, X, y, sample_weight=None): + def partial_fit(self, X, y, sample_weight=None, feature_names_in=None): """Incrementally fit the model to data. Fit a separate model for each output variable. diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index 7cd4404840798..93c9d0eed96ef 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -164,7 +164,7 @@ def __init__(self, priors=None, var_smoothing=1e-9): self.priors = priors self.var_smoothing = var_smoothing - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None, feature_names_in=None): """Fit Gaussian Naive Bayes according to X, y Parameters @@ -566,7 +566,7 @@ def partial_fit(self, X, y, classes=None, sample_weight=None): self._update_class_log_prior(class_prior=class_prior) return self - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None, feature_names_in=None): """Fit Naive Bayes classifier according to X, y Parameters diff --git a/sklearn/neighbors/base.py b/sklearn/neighbors/base.py index 44dcc326a489c..ce3313e27a256 100644 --- a/sklearn/neighbors/base.py +++ b/sklearn/neighbors/base.py @@ -920,7 +920,7 @@ def fit(self, X, y): class UnsupervisedMixin: - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Fit the model using X as training data Parameters diff --git a/sklearn/neighbors/lof.py b/sklearn/neighbors/lof.py index a58997502be91..02db17624b475 100644 --- a/sklearn/neighbors/lof.py +++ b/sklearn/neighbors/lof.py @@ -216,7 +216,7 @@ def _fit_predict(self, X, y=None): return self.fit(X)._predict() - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Fit the model using X as training data. Parameters diff --git a/sklearn/neural_network/rbm.py b/sklearn/neural_network/rbm.py index b2b6166d4d253..99aaa5881966b 100644 --- a/sklearn/neural_network/rbm.py +++ b/sklearn/neural_network/rbm.py @@ -216,7 +216,7 @@ def gibbs(self, v): return v_ - def partial_fit(self, X, y=None): + def partial_fit(self, X, y=None, feature_names_in=None): """Fit the model to the data X which should contain a partial segment of the data. 
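The hunks above all make the same mechanical change: ``fit`` (and ``partial_fit``) gain an optional, backwards-compatible ``feature_names_in=None`` keyword through which a caller (typically ``Pipeline``) can forward the column names of the training data alongside ``X``. A minimal sketch of the intended calling convention, assuming the estimator records the names on a ``feature_names_in_`` attribute the way ``LogisticRegression.fit`` does above (illustrative only, hence the skips)::

    >>> from sklearn.datasets import load_iris
    >>> from sklearn.linear_model import LogisticRegression
    >>> iris = load_iris()
    >>> clf = LogisticRegression(solver='lbfgs', multi_class='auto')
    >>> clf.fit(iris.data, iris.target,
    ...         feature_names_in=iris.feature_names)  # doctest: +SKIP
    LogisticRegression(...)
    >>> clf.feature_names_in_  # doctest: +SKIP
    ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']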
@@ -318,7 +318,7 @@ def score_samples(self, X): fe_ = self._free_energy(v_) return v.shape[1] * log_logistic(fe_ - fe) - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Fit the model to the data X. Parameters diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index c66e37761782d..8e9a7904e64b4 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -255,12 +255,20 @@ def _log_message(self, step_idx): # Estimator interface - def _fit(self, X, y=None, **fit_params): + def _fit(self, X, y=None, feature_names_in=None, **fit_params): # shallow copy of steps - this should really be steps_ self.steps = list(self.steps) self._validate_steps() # Setup the memory memory = check_memory(self.memory) + if hasattr(X, 'columns'): + if feature_names_in is not None and list(feature_names_in) != list(X.columns): + raise ValueError("feature_names_in inconsistent with " + "passed columns: {}, {}".format( + feature_names_in, X.columns)) + feature_names_in = X.columns + + self.feature_names_in_ = feature_names_in fit_transform_one_cached = memory.cache(_fit_transform_one) @@ -308,11 +316,13 @@ def _fit(self, X, y=None, **fit_params): cloned_transformer, X, y, None, message_clsname='Pipeline', message=self._log_message(step_idx), + feature_names_in=feature_names_in, **fit_params_steps[name]) # Replace the transformer of the step with the fitted # transformer. This is necessary when loading the transformer # from the cache. self.steps[step_idx] = (name, fitted_transformer) + feature_names_in = fitted_transformer.feature_names_out_ if self._final_estimator == 'passthrough': return X, {} return X, fit_params_steps[self.steps[-1][0]] @@ -344,10 +354,12 @@ def fit(self, X, y=None, **fit_params): This estimator """ Xt, fit_params = self._fit(X, y, **fit_params) + feature_names = self[-2].feature_names_out_ with _print_elapsed_time('Pipeline', self._log_message(len(self.steps) - 1)): if self._final_estimator != 'passthrough': - self._final_estimator.fit(Xt, y, **fit_params) + self._final_estimator.fit( + Xt, y, feature_names_in=feature_names, **fit_params) return self def fit_transform(self, X, y=None, **fit_params): @@ -379,14 +391,18 @@ def fit_transform(self, X, y=None, **fit_params): """ last_step = self._final_estimator Xt, fit_params = self._fit(X, y, **fit_params) + feature_names = self[-2].feature_names_out_ with _print_elapsed_time('Pipeline', self._log_message(len(self.steps) - 1)): if last_step == 'passthrough': return Xt if hasattr(last_step, 'fit_transform'): - return last_step.fit_transform(Xt, y, **fit_params) + return last_step.fit_transform(Xt, y, + feature_names_in=feature_names, + **fit_params) else: - return last_step.fit(Xt, y, **fit_params).transform(Xt) + return last_step.fit(Xt, y, feature_names_in=feature_names, + **fit_params).transform(Xt) @if_delegate_has_method(delegate='_final_estimator') def predict(self, X, **predict_params): @@ -618,6 +634,10 @@ def score(self, X, y=None, sample_weight=None): def classes_(self): return self.steps[-1][-1].classes_ + @property + def feature_names_out_(self): + return self.steps[-1][-1].feature_names_out_ + @property def _pairwise(self): # check if first estimator expects pairwise input @@ -713,6 +733,7 @@ def _fit_transform_one(transformer, weight, message_clsname='', message=None, + feature_names_in=None, **fit_params): """ Fits ``transformer`` to ``X`` and ``y``.
The transformed result is returned @@ -721,9 +742,11 @@ def _fit_transform_one(transformer, """ with _print_elapsed_time(message_clsname, message): if hasattr(transformer, 'fit_transform'): - res = transformer.fit_transform(X, y, **fit_params) + res = transformer.fit_transform( + X, y, feature_names_in=feature_names_in, **fit_params) else: - res = transformer.fit(X, y, **fit_params).transform(X) + res = transformer.fit( + X, y, feature_names_in=feature_names_in, **fit_params).transform(X) if weight is None: return res, transformer @@ -736,12 +759,14 @@ def _fit_one(transformer, weight, message_clsname='', message=None, + feature_names_in=None, **fit_params): """ Fits ``transformer`` to ``X`` and ``y``. """ with _print_elapsed_time(message_clsname, message): - return transformer.fit(X, y, **fit_params) + return transformer.fit( + X, y, feature_names_in=feature_names_in, **fit_params) class FeatureUnion(_BaseComposition, TransformerMixin): @@ -858,7 +883,7 @@ def _iter(self): for name, trans in self.transformer_list if trans is not None and trans != 'drop') - def get_feature_names(self): + def _get_feature_names(self): """Get feature names from all transformers. Returns @@ -876,7 +901,7 @@ def get_feature_names(self): trans.get_feature_names()]) return feature_names - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Fit all transformers using X. Parameters diff --git a/sklearn/preprocessing/_discretization.py b/sklearn/preprocessing/_discretization.py index b7ffd96032d2a..d4edcac9014f9 100644 --- a/sklearn/preprocessing/_discretization.py +++ b/sklearn/preprocessing/_discretization.py @@ -119,7 +119,7 @@ def __init__(self, n_bins=5, encode='onehot', strategy='quantile'): self.encode = encode self.strategy = strategy - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Fits the estimator. Parameters diff --git a/sklearn/preprocessing/_encoders.py b/sklearn/preprocessing/_encoders.py index 6a688072019bf..ed486edb58025 100644 --- a/sklearn/preprocessing/_encoders.py +++ b/sklearn/preprocessing/_encoders.py @@ -97,6 +97,8 @@ def _fit(self, X, handle_unknown='error'): " during fit".format(diff, i)) raise ValueError(msg) self.categories_.append(cats) + self.feature_names_out_ = self._get_feature_names( + self.feature_names_in_) def _transform(self, X, handle_unknown='error'): X_list, n_samples, n_features = self._check_X(X) @@ -322,7 +324,7 @@ def _compute_drop_idx(self): "'first', None or array of objects, got {}") raise ValueError(msg.format(type(self.drop))) - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Fit OneHotEncoder to X. Parameters @@ -334,12 +336,13 @@ def fit(self, X, y=None): ------- self """ + self.feature_names_in_ = feature_names_in self._validate_keywords() self._fit(X, handle_unknown=self.handle_unknown) self.drop_idx_ = self._compute_drop_idx() return self - def fit_transform(self, X, y=None): + def fit_transform(self, X, y=None, feature_names_in=None): """Fit OneHotEncoder to X, then transform X. Equivalent to fit(X).transform(X) but more convenient. @@ -355,7 +358,8 @@ def fit_transform(self, X, y=None): Transformed input. """ self._validate_keywords() - return super().fit_transform(X, y) + return super().fit_transform(X, y, + feature_names_in=feature_names_in) def transform(self, X): """Transform X using one-hot encoding. 
@@ -491,7 +495,7 @@ def inverse_transform(self, X): return X_tr - def get_feature_names(self, input_features=None): + def _get_feature_names(self, input_features=None): """Return feature names for output features. Parameters @@ -590,7 +594,7 @@ def __init__(self, categories='auto', dtype=np.float64): self.categories = categories self.dtype = dtype - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Fit the OrdinalEncoder to X. Parameters diff --git a/sklearn/preprocessing/_function_transformer.py b/sklearn/preprocessing/_function_transformer.py index a079612c045d6..3cf7cb5171e36 100644 --- a/sklearn/preprocessing/_function_transformer.py +++ b/sklearn/preprocessing/_function_transformer.py @@ -99,7 +99,7 @@ def _check_inverse_transform(self, X): " want to proceed regardless, set" " 'check_inverse=False'.", UserWarning) - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Fit transformer by checking X. If ``validate`` is ``True``, ``X`` will be checked. diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 823eedc8b7dd9..4db1e34a048fa 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -19,7 +19,7 @@ from scipy import optimize from scipy.special import boxcox -from ..base import BaseEstimator, TransformerMixin +from ..base import BaseEstimator, TransformerMixin, OneToOneMixin from ..utils import check_array from ..utils.extmath import row_norms from ..utils.extmath import _incremental_mean_and_var @@ -196,7 +196,7 @@ def scale(X, axis=0, with_mean=True, with_std=True, copy=True): return X -class MinMaxScaler(BaseEstimator, TransformerMixin): +class MinMaxScaler(BaseEstimator, TransformerMixin, OneToOneMixin): """Transforms features by scaling each feature to a given range. This estimator scales and translates each feature individually such @@ -311,7 +311,7 @@ def _reset(self): del self.data_max_ del self.data_range_ - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Compute the minimum and maximum to be used for later scaling. Parameters @@ -325,7 +325,7 @@ def fit(self, X, y=None): self._reset() return self.partial_fit(X, y) - def partial_fit(self, X, y=None): + def partial_fit(self, X, y=None, feature_names_in=None): """Online computation of min and max on X for later scaling. All of X is processed as a single batch. This is intended for cases when `fit` is not feasible due to very large number of `n_samples` @@ -488,7 +488,7 @@ def minmax_scale(X, feature_range=(0, 1), axis=0, copy=True): return X -class StandardScaler(BaseEstimator, TransformerMixin): +class StandardScaler(BaseEstimator, TransformerMixin, OneToOneMixin): """Standardize features by removing the mean and scaling to unit variance The standard score of a sample `x` is calculated as: @@ -622,7 +622,7 @@ def _reset(self): del self.mean_ del self.var_ - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Compute the mean and std to be used for later scaling. Parameters @@ -634,12 +634,12 @@ def fit(self, X, y=None): y Ignored """ - + self.feature_names_in_ = feature_names_in # Reset internal state before fitting self._reset() return self.partial_fit(X, y) - def partial_fit(self, X, y=None): + def partial_fit(self, X, y=None, feature_names_in=None): """Online computation of mean and std on X for later scaling. All of X is processed as a single batch. 
This is intended for cases when `fit` is not feasible due to very large number of `n_samples` @@ -816,7 +816,7 @@ def _more_tags(self): return {'allow_nan': True} -class MaxAbsScaler(BaseEstimator, TransformerMixin): +class MaxAbsScaler(BaseEstimator, TransformerMixin, OneToOneMixin): """Scale each feature by its maximum absolute value. This estimator scales and translates each feature individually such @@ -893,7 +893,7 @@ def _reset(self): del self.n_samples_seen_ del self.max_abs_ - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Compute the maximum absolute value to be used for later scaling. Parameters @@ -907,7 +907,7 @@ def fit(self, X, y=None): self._reset() return self.partial_fit(X, y) - def partial_fit(self, X, y=None): + def partial_fit(self, X, y=None, feature_names_in=None): """Online computation of max absolute value of X for later scaling. All of X is processed as a single batch. This is intended for cases when `fit` is not feasible due to very large number of `n_samples` @@ -1045,7 +1045,7 @@ def maxabs_scale(X, axis=0, copy=True): return X -class RobustScaler(BaseEstimator, TransformerMixin): +class RobustScaler(BaseEstimator, TransformerMixin, OneToOneMixin): """Scale features using statistics that are robust to outliers. This Scaler removes the median and scales the data according to @@ -1142,7 +1142,7 @@ def __init__(self, with_centering=True, with_scaling=True, self.quantile_range = quantile_range self.copy = copy - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Compute the median and quantiles to be used for scaling. Parameters @@ -1418,7 +1418,7 @@ def powers_(self): return np.vstack([np.bincount(c, minlength=self.n_input_features_) for c in combinations]) - def get_feature_names(self, input_features=None): + def _get_feature_names(self, input_features=None): """ Return feature names for output features @@ -1448,7 +1448,7 @@ def get_feature_names(self, input_features=None): feature_names.append(name) return feature_names - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """ Compute number of output features. @@ -1648,7 +1648,7 @@ def normalize(X, norm='l2', axis=1, copy=True, return_norm=False): return X -class Normalizer(BaseEstimator, TransformerMixin): +class Normalizer(BaseEstimator, TransformerMixin, OneToOneMixin): """Normalize samples individually to unit norm. Each sample (i.e. 
each row of the data matrix) with at least one @@ -1710,7 +1710,7 @@ def __init__(self, norm='l2', copy=True): self.norm = norm self.copy = copy - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Do nothing and return the estimator unchanged This method is just there to implement the usual API and hence @@ -1786,7 +1786,7 @@ def binarize(X, threshold=0.0, copy=True): return X -class Binarizer(BaseEstimator, TransformerMixin): +class Binarizer(BaseEstimator, TransformerMixin, OneToOneMixin): """Binarize data (set feature values to 0 or 1) according to a threshold Values greater than the threshold map to 1, while values less than @@ -1844,7 +1844,7 @@ def __init__(self, threshold=0.0, copy=True): self.threshold = threshold self.copy = copy - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Do nothing and return the estimator unchanged This method is just there to implement the usual API and hence @@ -2023,7 +2023,7 @@ def add_dummy_feature(X, value=1.0): return np.hstack((np.full((n_samples, 1), value), X)) -class QuantileTransformer(BaseEstimator, TransformerMixin): +class QuantileTransformer(BaseEstimator, TransformerMixin, OneToOneMixin): """Transform features using quantiles information. This method transforms the features to follow a uniform or a normal @@ -2198,7 +2198,7 @@ def _sparse_fit(self, X, random_state): np.nanpercentile(column_data, references)) self.quantiles_ = np.transpose(self.quantiles_) - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Compute the quantiles used for transforming. Parameters @@ -2567,7 +2567,7 @@ def quantile_transform(X, axis=0, n_quantiles=1000, " axis={}".format(axis)) -class PowerTransformer(BaseEstimator, TransformerMixin): +class PowerTransformer(BaseEstimator, TransformerMixin, OneToOneMixin): """Apply a power transform featurewise to make data more Gaussian-like. Power transforms are a family of parametric, monotonic transformations @@ -2653,7 +2653,7 @@ def __init__(self, method='yeo-johnson', standardize=True, copy=True): self.standardize = standardize self.copy = copy - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Estimate the optimal parameter lambda for each feature. The optimal lambda parameter for minimizing skewness is estimated on @@ -2673,7 +2673,7 @@ def fit(self, X, y=None): self._fit(X, y=y, force_transform=False) return self - def fit_transform(self, X, y=None): + def fit_transform(self, X, y=None, feature_names_in=None): return self._fit(X, y, force_transform=True) def _fit(self, X, y=None, force_transform=False): diff --git a/sklearn/random_projection.py b/sklearn/random_projection.py index f4fa2c608b842..e4c42a204d52c 100644 --- a/sklearn/random_projection.py +++ b/sklearn/random_projection.py @@ -323,7 +323,7 @@ def _make_random_matrix(self, n_components, n_features): """ - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): """Generate a sparse random projection matrix Parameters diff --git a/sklearn/svm/base.py b/sklearn/svm/base.py index 4a50ee479f030..8a2eb517e2aa1 100644 --- a/sklearn/svm/base.py +++ b/sklearn/svm/base.py @@ -102,7 +102,7 @@ def _pairwise(self): # Used by cross_val_score. return self.kernel == "precomputed" - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None, feature_names_in=None): """Fit the SVM model according to the given training data. 
Parameters diff --git a/sklearn/svm/classes.py b/sklearn/svm/classes.py index 067e1a6ef8d34..f8c5961bf4a1d 100644 --- a/sklearn/svm/classes.py +++ b/sklearn/svm/classes.py @@ -185,7 +185,7 @@ def __init__(self, penalty='l2', loss='squared_hinge', dual=True, tol=1e-4, self.penalty = penalty self.loss = loss - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None, feature_names_in=None): """Fit the model according to the given training data. Parameters @@ -369,7 +369,7 @@ def __init__(self, epsilon=0.0, tol=1e-4, C=1.0, self.dual = dual self.loss = loss - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None, feature_names_in=None): """Fit the model according to the given training data. Parameters diff --git a/sklearn/tests/test_base.py b/sklearn/tests/test_base.py index 032d9b232523f..c8a676f3e6328 100644 --- a/sklearn/tests/test_base.py +++ b/sklearn/tests/test_base.py @@ -312,7 +312,7 @@ def __init__(self, df=None, scalar_param=1): self.df = df self.scalar_param = scalar_param - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): pass def transform(self, X): @@ -484,6 +484,80 @@ def test_tag_inheritance(): diamond_tag_est._get_tags() +@ignore_warnings(category=(FutureWarning, DeprecationWarning)) +def test_sub_estimator_consistency(): + # check that _get_sub_estimators finds all fitted sub estimators + # if this breaks, you probably introduced a sub-estimator that's + # non-standard (not estimator_, base_estimator_ or estimators_) + from sklearn.utils.testing import all_estimators + from sklearn.base import (MetaEstimatorMixin, _get_sub_estimators, + ClassifierMixin, RegressorMixin) + + from sklearn.model_selection._search import BaseSearchCV + from sklearn.feature_selection.base import SelectorMixin + from sklearn.datasets import make_blobs + from sklearn.linear_model import Ridge, LogisticRegression + from sklearn.utils.estimator_checks import \ + multioutput_estimator_convert_y_2d + from collections.abc import Iterable + + def has_fitted_attr(est): + attrs = [(x, getattr(est, x, None)) + for x in dir(est) if x.endswith("_") + and not x.startswith("__")] + return len(attrs) + + def get_sub_estimators_brute(est): + # recurse through all attributes to get sub-estimators + attrs = [(x, getattr(est, x, None)) + for x in dir(est) if not x.startswith("_")] + + def _recurse_sub_ests(candidates): + sub_ests = [] + for a in candidates: + if hasattr(a, "set_params") and hasattr(a, "fit"): + sub_ests.append(a) + elif isinstance(a, Iterable) and not isinstance(a, str): + sub_ests.extend(_recurse_sub_ests(a)) + return sub_ests + ests = _recurse_sub_ests(attrs) + # we don't consider label processors child estimators + return set([e for e in ests if has_fitted_attr(e) + and e.__module__ != "sklearn.preprocessing.label"]) + + al = all_estimators() + mets = [x for x in al if issubclass(x[1], MetaEstimatorMixin)] + + X, y = make_blobs() + others = [] + + for name, Est in mets: + # instantiate and fit + try: + est = Est() + except TypeError: + if issubclass(Est, (ClassifierMixin, SelectorMixin)): + est = Est(LogisticRegression(solver='lbfgs', + multi_class='auto')) + elif issubclass(Est, RegressorMixin): + est = Est(Ridge()) + else: + others.append((name, Est)) + if est._get_tags()['_skip_test']: + continue + + y = multioutput_estimator_convert_y_2d(est, y) + est.fit(X, y) + # test recursive sub estimators are the same as result of + # _get_sub_estimators which uses a hard-coded list + assert (set(_get_sub_estimators(est)) == 
+ get_sub_estimators_brute(est)) + + for name, Est in others: + # only things we couldn't instantiate are the search CV + assert issubclass(Est, BaseSearchCV) + + # XXX: Remove in 0.23 def test_regressormixin_score_multioutput(): from sklearn.linear_model import LinearRegression diff --git a/sklearn/tests/test_pipeline.py b/sklearn/tests/test_pipeline.py index e02b5ef96b7b0..6f6bffcec22f5 100644 --- a/sklearn/tests/test_pipeline.py +++ b/sklearn/tests/test_pipeline.py @@ -22,17 +22,20 @@ from sklearn.utils.testing import assert_no_warnings from sklearn.base import clone, BaseEstimator +from sklearn.exceptions import NotFittedError from sklearn.pipeline import Pipeline, FeatureUnion, make_pipeline, make_union from sklearn.svm import SVC from sklearn.neighbors import LocalOutlierFactor from sklearn.linear_model import LogisticRegression, Lasso from sklearn.linear_model import LinearRegression +from sklearn.multiclass import OneVsRestClassifier from sklearn.cluster import KMeans from sklearn.feature_selection import SelectKBest, f_classif from sklearn.dummy import DummyRegressor from sklearn.decomposition import PCA, TruncatedSVD from sklearn.datasets import load_iris from sklearn.preprocessing import StandardScaler +from sklearn.impute import SimpleImputer from sklearn.feature_extraction.text import CountVectorizer @@ -1103,6 +1106,141 @@ def test_make_pipeline_memory(): shutil.rmtree(cachedir) +def test_set_input_features(): + pipe = Pipeline(steps=[ + ('imputer', SimpleImputer(strategy='median')), + ('scaler', StandardScaler()), + ('select', SelectKBest(k=2)), + ('clf', LogisticRegression())]) + assert_raises(NotFittedError, pipe.get_feature_names) + iris = load_iris() + pipe.fit(iris.data, iris.target) + xs = np.array(['x0', 'x1', 'x2', 'x3']) + assert_array_equal(pipe.input_features_, xs) + mask = pipe.named_steps.select.get_support() + assert_array_equal(pipe.named_steps.clf.input_features_, xs[mask]) + res = pipe.get_feature_names(iris.feature_names) + # LogisticRegression doesn't have get_feature_names + assert res is None + assert_array_equal(pipe.input_features_, iris.feature_names) + assert_array_equal(pipe.named_steps.clf.input_features_, + np.array(iris.feature_names)[mask]) + # check that empty get_feature_names() doesn't overwrite + res = pipe.get_feature_names() + assert res is None + assert_array_equal(pipe.input_features_, iris.feature_names) + assert_array_equal(pipe.named_steps.clf.input_features_, + np.array(iris.feature_names)[mask]) + pipe = Pipeline(steps=[ + ('scaler', StandardScaler()), + ('pca', PCA(n_components=3)), + ('select', SelectKBest(k=2)), + ('clf', LogisticRegression())]) + pipe.fit(iris.data, iris.target) + assert_array_equal(pipe.named_steps.clf.input_features_, ['pca0', 'pca1']) + # setting names doesn't change names after PCA + pipe.get_feature_names(iris.feature_names) + assert_array_equal(pipe.named_steps.select.input_features_, + ['pca0', 'pca1', 'pca2']) + + +def test_input_feature_names_pandas(): + pd = pytest.importorskip("pandas") + pipe = Pipeline(steps=[ + ('imputer', SimpleImputer(strategy='median')), + ('scaler', StandardScaler()), + ('select', SelectKBest(k=2)), + ('clf', LogisticRegression())]) + iris = load_iris() + df = pd.DataFrame(iris.data, columns=iris.feature_names) + pipe.fit(df, iris.target) + mask = pipe.named_steps.select.get_support() + assert_array_equal(pipe.named_steps.clf.input_features_, + np.array(iris.feature_names)[mask]) + + +def test_input_features_passthrough(): + pipe = Pipeline(steps=[ + ('imputer', 'passthrough'), 
+ ('scaler', StandardScaler()), + ('select', 'passthrough'), + ('clf', LogisticRegression())]) + iris = load_iris() + pipe.fit(iris.data, iris.target) + xs = ['x0', 'x1', 'x2', 'x3'] + assert_array_equal(pipe.named_steps.clf.input_features_, xs) + pipe.get_feature_names(iris.feature_names) + assert_array_equal(pipe.named_steps.clf.input_features_, + iris.feature_names) + + +def test_input_features_count_vectorizer(): + pipe = Pipeline(steps=[ + ('vect', CountVectorizer()), + ('clf', LogisticRegression())]) + y = ["pizza" in x for x in JUNK_FOOD_DOCS] + pipe.fit(JUNK_FOOD_DOCS, y) + assert_array_equal(pipe.named_steps.clf.input_features_, + ['beer', 'burger', 'coke', 'copyright', 'pizza', 'the']) + pipe.get_feature_names(["nonsense_is_ignored"]) + assert_array_equal(pipe.named_steps.clf.input_features_, + ['beer', 'burger', 'coke', 'copyright', 'pizza', 'the']) + + +def test_input_features_nested(): + pipe = Pipeline(steps=[ + ('inner_pipe', Pipeline(steps=[('select', SelectKBest(k=2)), + ('clf', LogisticRegression())]))]) + iris = load_iris() + pipe.fit(iris.data, iris.target) + xs = np.array(['x0', 'x1', 'x2', 'x3']) + assert_array_equal(pipe.input_features_, xs) + mask = pipe.named_steps.inner_pipe.named_steps.select.get_support() + assert_array_equal( + pipe.named_steps.inner_pipe.named_steps.clf.input_features_, xs[mask]) + pipe.get_feature_names(iris.feature_names) + assert_array_equal(pipe.input_features_, iris.feature_names) + assert_array_equal( + pipe.named_steps.inner_pipe.named_steps.clf.input_features_, + np.array(iris.feature_names)[mask]) + + +def test_input_features_meta_pipe(): + ovr = OneVsRestClassifier(Pipeline(steps=[('select', SelectKBest(k=2)), + ('clf', LogisticRegression())])) + pipe = Pipeline(steps=[('ovr', ovr)]) + iris = load_iris() + pipe.fit(iris.data, iris.target) + xs = np.array(['x0', 'x1', 'x2', 'x3']) + assert_array_equal(pipe.input_features_, xs) + # check 0ths estimator in OVR only + inner_pipe = pipe.named_steps.ovr.estimators_[0] + mask = inner_pipe.named_steps.select.get_support() + assert_array_equal(inner_pipe.named_steps.clf.input_features_, xs[mask]) + pipe.get_feature_names(iris.feature_names) + assert_array_equal(pipe.input_features_, iris.feature_names) + assert_array_equal(inner_pipe.input_features_, iris.feature_names) + assert_array_equal(inner_pipe.named_steps.clf.input_features_, + np.array(iris.feature_names)[mask]) + + +def test_input_features_meta(): + ovr = OneVsRestClassifier(LogisticRegression()) + pipe = Pipeline(steps=[('select', SelectKBest(k=2)), ('ovr', ovr)]) + iris = load_iris() + pipe.fit(iris.data, iris.target) + xs = np.array(['x0', 'x1', 'x2', 'x3']) + assert_array_equal(pipe.input_features_, xs) + # check 0ths estimator in OVR only + one_logreg = pipe.named_steps.ovr.estimators_[0] + mask = pipe.named_steps.select.get_support() + assert_array_equal(one_logreg.input_features_, xs[mask]) + pipe.get_feature_names(iris.feature_names) + assert_array_equal(pipe.input_features_, iris.feature_names) + assert_array_equal(one_logreg.input_features_, + np.array(iris.feature_names)[mask]) + + def test_pipeline_param_error(): clf = make_pipeline(LogisticRegression()) with pytest.raises(ValueError, match="Pipeline.fit does not accept " diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 249cb022f8e87..76e57d3d69ffa 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -1014,6 +1014,15 @@ def _check_transformer(name, transformer_orig, X, y): transformer_clone = 
clone(transformer) X_pred = transformer_clone.fit_transform(X, y=y_) + input_features = ['feature%d' % i for i in range(n_features)] + if hasattr(transformer_clone, 'get_feature_names'): + feature_names = transformer_clone.get_feature_names(input_features) + if feature_names is not None: + if isinstance(X_pred, tuple): + assert len(feature_names) == X_pred[0].shape[1] + else: + assert len(feature_names) == X_pred.shape[1] + if isinstance(X_pred, tuple): for x_pred in X_pred: assert x_pred.shape[0] == n_samples diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index 32ef303889be1..969e0569afe99 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -54,7 +54,7 @@ class ChangesDict(BaseEstimator): def __init__(self, key=0): self.key = key - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): X, y = check_X_y(X, y) return self @@ -68,7 +68,7 @@ class SetsWrongAttribute(BaseEstimator): def __init__(self, acceptable_key=0): self.acceptable_key = acceptable_key - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): self.wrong_attribute = 0 X, y = check_X_y(X, y) return self @@ -78,14 +78,14 @@ class ChangesWrongAttribute(BaseEstimator): def __init__(self, wrong_attribute=0): self.wrong_attribute = wrong_attribute - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): self.wrong_attribute = 1 X, y = check_X_y(X, y) return self class ChangesUnderscoreAttribute(BaseEstimator): - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): self._good_attribute = 1 X, y = check_X_y(X, y) return self @@ -103,7 +103,7 @@ def set_params(self, **kwargs): self.p = p return super().set_params(**kwargs) - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): X, y = check_X_y(X, y) return self @@ -120,7 +120,7 @@ def set_params(self, **kwargs): self.p = p return super().set_params(**kwargs) - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): X, y = check_X_y(X, y) return self @@ -139,7 +139,7 @@ def set_params(self, **kwargs): self.b = 'method2' return super().set_params(**kwargs) - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): X, y = check_X_y(X, y) return self @@ -175,7 +175,7 @@ def predict(self, X): class NoSampleWeightPandasSeriesType(BaseEstimator): - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None, feature_names_in=None): # Convert data X, y = check_X_y(X, y, accept_sparse=("csr", "csc"), @@ -216,7 +216,7 @@ def fit(self, X, y): class BadTransformerWithoutMixin(BaseEstimator): - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): X = check_array(X) return self @@ -263,11 +263,11 @@ def fit(self, X, y): class SparseTransformer(BaseEstimator): - def fit(self, X, y=None): + def fit(self, X, y=None, feature_names_in=None): self.X_shape_ = check_array(X).shape return self - def fit_transform(self, X, y=None): + def fit_transform(self, X, y=None, feature_names_in=None): return self.fit(X, y).transform(X) def transform(self, X): diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 0aa8eae22b1e2..5966476ed8f16 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -591,7 +591,7 @@ def test_has_fit_parameter(): class TestClassWithDeprecatedFitMethod: @deprecated("Deprecated for the 
purpose of testing has_fit_parameter") - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None, feature_names_in=None): pass assert has_fit_parameter(TestClassWithDeprecatedFitMethod, From 3ba62b0e3c7ef6df35cd9390c33250cd416cbf58 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Tue, 2 Jul 2019 17:40:55 -0400 Subject: [PATCH 2/4] remove subestimator mess --- sklearn/base.py | 41 --------------------- sklearn/tests/test_base.py | 74 -------------------------------------- 2 files changed, 115 deletions(-) diff --git a/sklearn/base.py b/sklearn/base.py index a2fd84f443d9d..9abf1cf6ac711 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -662,51 +662,10 @@ def feature_names_out_(self): return self.feature_names_in_ -def _get_sub_estimators(est): - # Explicitly declare all fitted subestimators of existing meta-estimators - sub_ests = [] - # OHE is not really needed - sub_names = ['estimator_', 'base_estimator_', 'one_hot_encoder_', - 'best_estimator_', 'init_'] - for name in sub_names: - sub_est = getattr(est, name, None) - if sub_est is not None: - sub_ests.append(sub_est) - if hasattr(est, "estimators_"): - if hasattr(est.estimators_, 'shape'): - sub_ests.extend(est.estimators_.ravel()) - else: - sub_ests.extend(est.estimators_) - return sub_ests - - class MetaEstimatorMixin: _required_parameters = ["estimator"] """Mixin class for all meta estimators in scikit-learn.""" - def _get_feature_names(self, input_features=None): - """Ensure feature names are set on sub-estimators - - Parameters - ---------- - input_features : list of string or None - Input features to the meta-estimator. - """ - sub_ests = _get_sub_estimators(self) - for est in sub_ests: - est.input_features_ = input_features - if hasattr(est, "get_feature_names"): - # doing hassattr instead of a try-except on everything - # b/c catching AttributeError makes recursive code - # impossible to debug - try: - est.get_feature_names(input_features=input_features) - except TypeError: - # do we need this? 
- est.get_feature_names() - except NotFittedError: - pass - class MultiOutputMixin(object): """Mixin to mark estimators that support multioutput.""" diff --git a/sklearn/tests/test_base.py b/sklearn/tests/test_base.py index c8a676f3e6328..03e621de0e506 100644 --- a/sklearn/tests/test_base.py +++ b/sklearn/tests/test_base.py @@ -484,80 +484,6 @@ def test_tag_inheritance(): diamond_tag_est._get_tags() -@ignore_warnings(category=(FutureWarning, DeprecationWarning)) -def test_sub_estimator_consistency(): - # check that _get_sub_estimators finds all fitted sub estimators - # if this breaks, you probably introduced a sub-estimator that's - # non-standard (not estimator_, base_estimator_ or estimators_) - from sklearn.utils.testing import all_estimators - from sklearn.base import (MetaEstimatorMixin, _get_sub_estimators, - ClassifierMixin, RegressorMixin) - - from sklearn.model_selection._search import BaseSearchCV - from sklearn.feature_selection.base import SelectorMixin - from sklearn.datasets import make_blobs - from sklearn.linear_model import Ridge, LogisticRegression - from sklearn.utils.estimator_checks import \ - multioutput_estimator_convert_y_2d - from collections.abc import Iterable - - def has_fitted_attr(est): - attrs = [(x, getattr(est, x, None)) - for x in dir(est) if x.endswith("_") - and not x.startswith("__")] - return len(attrs) - - def get_sub_estimators_brute(est): - # recurse through all attributes to get sub-estimators - attrs = [(x, getattr(est, x, None)) - for x in dir(est) if not x.startswith("_")] - - def _recurse_sub_ests(candidates): - sub_ests = [] - for a in candidates: - if hasattr(a, "set_params") and hasattr(a, "fit"): - sub_ests.append(a) - elif isinstance(a, Iterable) and not isinstance(a, str): - sub_ests.extend(_recurse_sub_ests(a)) - return sub_ests - ests = _recurse_sub_ests(attrs) - # we don't consider label processors child estimators - return set([e for e in ests if has_fitted_attr(e) - and e.__module__ != "sklearn.preprocessing.label"]) - - al = all_estimators() - mets = [x for x in al if issubclass(x[1], MetaEstimatorMixin)] - - X, y = make_blobs() - others = [] - - for name, Est in mets: - # instantiate and fit - try: - est = Est() - except TypeError: - if issubclass(Est, (ClassifierMixin, SelectorMixin)): - est = Est(LogisticRegression(solver='lbfgs', - multi_class='auto')) - elif issubclass(Est, RegressorMixin): - est = Est(Ridge()) - else: - others.append((name, Est)) - if est._get_tags()['_skip_test']: - continue - - y = multioutput_estimator_convert_y_2d(est, y) - est.fit(X, y) - # test recursive sub estimators are the same as result of - # _get_sub_estimators which uses a hard-coded list - assert (set(_get_sub_estimators(est)) == - get_sub_estimators_brute(est)) - - for name, Est in others: - # only things we couldn't instantiate are the search CV - assert issubclass(Est, BaseSearchCV) - - # XXX: Remove in 0.23 def test_regressormixin_score_multioutput(): from sklearn.linear_model import LinearRegression From 8d9605a94a591f077c2f1810330ba11747f26876 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Tue, 2 Jul 2019 17:48:25 -0400 Subject: [PATCH 3/4] rename input_features_ to feature_names_in_ --- .../plot_column_transformer_mixed_types.py | 2 +- .../plot_feature_selection_pipeline.py | 2 +- sklearn/tests/test_pipeline.py | 62 +++++++++---------- 3 files changed, 33 insertions(+), 33 deletions(-) diff --git a/examples/compose/plot_column_transformer_mixed_types.py b/examples/compose/plot_column_transformer_mixed_types.py index 
c95f9494916df..83455bdfe4b35 100644 --- a/examples/compose/plot_column_transformer_mixed_types.py +++ b/examples/compose/plot_column_transformer_mixed_types.py @@ -118,7 +118,7 @@ cv_coefs = np.concatenate([cv_pipeline.named_steps["classifier"].coef_ for cv_pipeline in cv_results["estimator"]]) fig, ax = plt.subplots() -ax.barh(pipeline.named_steps["classifier"].input_features_, +ax.barh(pipeline.named_steps["classifier"].feature_names_in_, cv_coefs.mean(axis=0), xerr=cv_coefs.std(axis=0)) plt.tight_layout() plt.show() diff --git a/examples/feature_selection/plot_feature_selection_pipeline.py b/examples/feature_selection/plot_feature_selection_pipeline.py index 5eb9dd57e233b..b3ef807589261 100644 --- a/examples/feature_selection/plot_feature_selection_pipeline.py +++ b/examples/feature_selection/plot_feature_selection_pipeline.py @@ -39,5 +39,5 @@ # access and plot the coefficients of the fitted model plt.barh((0, 1, 2), anova_svm[-1].coef_.ravel()) -plt.yticks((0, 1, 2), anova_svm[-1].input_features_) +plt.yticks((0, 1, 2), anova_svm[-1].feature_names_in_) plt.show() diff --git a/sklearn/tests/test_pipeline.py b/sklearn/tests/test_pipeline.py index 6f6bffcec22f5..2afdf090f3a07 100644 --- a/sklearn/tests/test_pipeline.py +++ b/sklearn/tests/test_pipeline.py @@ -1116,20 +1116,20 @@ def test_set_input_features(): iris = load_iris() pipe.fit(iris.data, iris.target) xs = np.array(['x0', 'x1', 'x2', 'x3']) - assert_array_equal(pipe.input_features_, xs) + assert_array_equal(pipe.feature_names_in_, xs) mask = pipe.named_steps.select.get_support() - assert_array_equal(pipe.named_steps.clf.input_features_, xs[mask]) + assert_array_equal(pipe.named_steps.clf.feature_names_in_, xs[mask]) res = pipe.get_feature_names(iris.feature_names) # LogisticRegression doesn't have get_feature_names assert res is None - assert_array_equal(pipe.input_features_, iris.feature_names) - assert_array_equal(pipe.named_steps.clf.input_features_, + assert_array_equal(pipe.feature_names_in_, iris.feature_names) + assert_array_equal(pipe.named_steps.clf.feature_names_in_, np.array(iris.feature_names)[mask]) # check that empty get_feature_names() doesn't overwrite res = pipe.get_feature_names() assert res is None - assert_array_equal(pipe.input_features_, iris.feature_names) - assert_array_equal(pipe.named_steps.clf.input_features_, + assert_array_equal(pipe.feature_names_in_, iris.feature_names) + assert_array_equal(pipe.named_steps.clf.feature_names_in_, np.array(iris.feature_names)[mask]) pipe = Pipeline(steps=[ ('scaler', StandardScaler()), @@ -1137,10 +1137,10 @@ def test_set_input_features(): ('select', SelectKBest(k=2)), ('clf', LogisticRegression())]) pipe.fit(iris.data, iris.target) - assert_array_equal(pipe.named_steps.clf.input_features_, ['pca0', 'pca1']) + assert_array_equal(pipe.named_steps.clf.feature_names_in_, ['pca0', 'pca1']) # setting names doesn't change names after PCA pipe.get_feature_names(iris.feature_names) - assert_array_equal(pipe.named_steps.select.input_features_, + assert_array_equal(pipe.named_steps.select.feature_names_in_, ['pca0', 'pca1', 'pca2']) @@ -1155,11 +1155,11 @@ def test_input_feature_names_pandas(): df = pd.DataFrame(iris.data, columns=iris.feature_names) pipe.fit(df, iris.target) mask = pipe.named_steps.select.get_support() - assert_array_equal(pipe.named_steps.clf.input_features_, + assert_array_equal(pipe.named_steps.clf.feature_names_in_, np.array(iris.feature_names)[mask]) -def test_input_features_passthrough(): +def test_feature_names_in_passthrough(): pipe = 
Pipeline(steps=[ ('imputer', 'passthrough'), ('scaler', StandardScaler()), @@ -1168,76 +1168,76 @@ def test_input_features_passthrough(): iris = load_iris() pipe.fit(iris.data, iris.target) xs = ['x0', 'x1', 'x2', 'x3'] - assert_array_equal(pipe.named_steps.clf.input_features_, xs) + assert_array_equal(pipe.named_steps.clf.feature_names_in_, xs) pipe.get_feature_names(iris.feature_names) - assert_array_equal(pipe.named_steps.clf.input_features_, + assert_array_equal(pipe.named_steps.clf.feature_names_in_, iris.feature_names) -def test_input_features_count_vectorizer(): +def test_feature_names_in_count_vectorizer(): pipe = Pipeline(steps=[ ('vect', CountVectorizer()), ('clf', LogisticRegression())]) y = ["pizza" in x for x in JUNK_FOOD_DOCS] pipe.fit(JUNK_FOOD_DOCS, y) - assert_array_equal(pipe.named_steps.clf.input_features_, + assert_array_equal(pipe.named_steps.clf.feature_names_in_, ['beer', 'burger', 'coke', 'copyright', 'pizza', 'the']) pipe.get_feature_names(["nonsense_is_ignored"]) - assert_array_equal(pipe.named_steps.clf.input_features_, + assert_array_equal(pipe.named_steps.clf.feature_names_in_, ['beer', 'burger', 'coke', 'copyright', 'pizza', 'the']) -def test_input_features_nested(): +def test_feature_names_in_nested(): pipe = Pipeline(steps=[ ('inner_pipe', Pipeline(steps=[('select', SelectKBest(k=2)), ('clf', LogisticRegression())]))]) iris = load_iris() pipe.fit(iris.data, iris.target) xs = np.array(['x0', 'x1', 'x2', 'x3']) - assert_array_equal(pipe.input_features_, xs) + assert_array_equal(pipe.feature_names_in_, xs) mask = pipe.named_steps.inner_pipe.named_steps.select.get_support() assert_array_equal( - pipe.named_steps.inner_pipe.named_steps.clf.input_features_, xs[mask]) + pipe.named_steps.inner_pipe.named_steps.clf.feature_names_in_, xs[mask]) pipe.get_feature_names(iris.feature_names) - assert_array_equal(pipe.input_features_, iris.feature_names) + assert_array_equal(pipe.feature_names_in_, iris.feature_names) assert_array_equal( - pipe.named_steps.inner_pipe.named_steps.clf.input_features_, + pipe.named_steps.inner_pipe.named_steps.clf.feature_names_in_, np.array(iris.feature_names)[mask]) -def test_input_features_meta_pipe(): +def test_feature_names_in_meta_pipe(): ovr = OneVsRestClassifier(Pipeline(steps=[('select', SelectKBest(k=2)), ('clf', LogisticRegression())])) pipe = Pipeline(steps=[('ovr', ovr)]) iris = load_iris() pipe.fit(iris.data, iris.target) xs = np.array(['x0', 'x1', 'x2', 'x3']) - assert_array_equal(pipe.input_features_, xs) + assert_array_equal(pipe.feature_names_in_, xs) # check 0ths estimator in OVR only inner_pipe = pipe.named_steps.ovr.estimators_[0] mask = inner_pipe.named_steps.select.get_support() - assert_array_equal(inner_pipe.named_steps.clf.input_features_, xs[mask]) + assert_array_equal(inner_pipe.named_steps.clf.feature_names_in_, xs[mask]) pipe.get_feature_names(iris.feature_names) - assert_array_equal(pipe.input_features_, iris.feature_names) - assert_array_equal(inner_pipe.input_features_, iris.feature_names) - assert_array_equal(inner_pipe.named_steps.clf.input_features_, + assert_array_equal(pipe.feature_names_in_, iris.feature_names) + assert_array_equal(inner_pipe.feature_names_in_, iris.feature_names) + assert_array_equal(inner_pipe.named_steps.clf.feature_names_in_, np.array(iris.feature_names)[mask]) -def test_input_features_meta(): +def test_feature_names_in_meta(): ovr = OneVsRestClassifier(LogisticRegression()) pipe = Pipeline(steps=[('select', SelectKBest(k=2)), ('ovr', ovr)]) iris = load_iris() pipe.fit(iris.data, 
iris.target) xs = np.array(['x0', 'x1', 'x2', 'x3']) - assert_array_equal(pipe.input_features_, xs) + assert_array_equal(pipe.feature_names_in_, xs) # check 0ths estimator in OVR only one_logreg = pipe.named_steps.ovr.estimators_[0] mask = pipe.named_steps.select.get_support() - assert_array_equal(one_logreg.input_features_, xs[mask]) + assert_array_equal(one_logreg.feature_names_in_, xs[mask]) pipe.get_feature_names(iris.feature_names) - assert_array_equal(pipe.input_features_, iris.feature_names) - assert_array_equal(one_logreg.input_features_, + assert_array_equal(pipe.feature_names_in_, iris.feature_names) + assert_array_equal(one_logreg.feature_names_in_, np.array(iris.feature_names)[mask]) From 32bf3eab04b922f3a2a91c4230d4730756c0dca1 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Wed, 3 Jul 2019 13:55:27 -0400 Subject: [PATCH 4/4] add feature_names_in to way more places in fit --- doc/developers/contributing.rst | 6 +++--- sklearn/cross_decomposition/pls_.py | 4 ++-- sklearn/discriminant_analysis.py | 4 ++-- .../gradient_boosting.py | 2 +- sklearn/ensemble/tests/test_bagging.py | 8 ++++---- .../ensemble/tests/test_gradient_boosting.py | 2 +- sklearn/ensemble/tests/test_weight_boosting.py | 2 +- sklearn/feature_selection/rfe.py | 2 +- sklearn/feature_selection/tests/test_rfe.py | 2 +- .../feature_selection/univariate_selection.py | 2 +- sklearn/gaussian_process/gpc.py | 4 ++-- sklearn/gaussian_process/gpr.py | 2 +- .../tests/test_partial_dependence.py | 2 +- sklearn/linear_model/base.py | 2 +- sklearn/linear_model/bayes.py | 2 +- sklearn/linear_model/coordinate_descent.py | 4 ++-- sklearn/linear_model/least_angle.py | 2 +- sklearn/linear_model/omp.py | 4 ++-- sklearn/linear_model/stochastic_gradient.py | 2 +- .../tests/test_passive_aggressive.py | 2 +- sklearn/linear_model/tests/test_perceptron.py | 2 +- sklearn/linear_model/theil_sen.py | 2 +- sklearn/metrics/tests/test_score_objects.py | 6 +++--- sklearn/model_selection/tests/test_search.py | 4 ++-- sklearn/multiclass.py | 8 ++++---- sklearn/multioutput.py | 6 +++--- sklearn/neighbors/base.py | 4 ++-- sklearn/neighbors/nca.py | 2 +- sklearn/neighbors/nearest_centroid.py | 2 +- .../neural_network/multilayer_perceptron.py | 4 ++-- sklearn/semi_supervised/label_propagation.py | 4 ++-- sklearn/tests/test_calibration.py | 2 +- sklearn/tests/test_pipeline.py | 8 ++++---- sklearn/utils/tests/test_estimator_checks.py | 18 +++++++++--------- sklearn/utils/tests/test_pprint.py | 2 +- sklearn/utils/tests/test_testing.py | 4 ++-- 36 files changed, 69 insertions(+), 69 deletions(-) diff --git a/doc/developers/contributing.rst b/doc/developers/contributing.rst index d44d372f1b7ca..8f55e129bc7ae 100644 --- a/doc/developers/contributing.rst +++ b/doc/developers/contributing.rst @@ -927,7 +927,7 @@ When the change is in a class, we validate and raise warning in ``fit``:: self.n_clusters = n_clusters self.k = k - def fit(self, X, y): + def fit(self, X, y, feature_names_in=None): if self.k != 'not_used': warnings.warn("'k' was renamed to n_clusters in version 0.13 and " "will be removed in 0.15.", DeprecationWarning) @@ -983,7 +983,7 @@ When the change is in a class, we validate and raise warning in ``fit``:: def __init__(self, n_clusters='warn'): self.n_clusters = n_clusters - def fit(self, X, y): + def fit(self, X, y, feature_names_in=None): if self.n_clusters == 'warn': warnings.warn("The default value of n_clusters will change from " "5 to 10 in 0.22.", FutureWarning) @@ -1339,7 +1339,7 @@ the correct interface more easily. ... 
... def __init__(self, demo_param='demo'): ... self.demo_param = demo_param ... - ... def fit(self, X, y): + ... def fit(self, X, y, feature_names_in=None): ... ... # Check that X and y have correct shape ... X, y = check_X_y(X, y) diff --git a/sklearn/cross_decomposition/pls_.py b/sklearn/cross_decomposition/pls_.py index 9affa87544431..57b568f292350 100644 --- a/sklearn/cross_decomposition/pls_.py +++ b/sklearn/cross_decomposition/pls_.py @@ -232,7 +232,7 @@ def __init__(self, n_components=2, scale=True, deflation_mode="regression", self.tol = tol self.copy = copy - def fit(self, X, Y): + def fit(self, X, Y, feature_names_in=None): """Fit model to data. Parameters @@ -809,7 +809,7 @@ def __init__(self, n_components=2, scale=True, copy=True): self.scale = scale self.copy = copy - def fit(self, X, Y): + def fit(self, X, Y, feature_names_in=None): """Fit model to data. Parameters diff --git a/sklearn/discriminant_analysis.py b/sklearn/discriminant_analysis.py index 7a0bdcf398aa2..6a28483b8d03d 100644 --- a/sklearn/discriminant_analysis.py +++ b/sklearn/discriminant_analysis.py @@ -407,7 +407,7 @@ def _solve_svd(self, X, y): self.coef_ = np.dot(coef, self.scalings_.T) self.intercept_ -= np.dot(self.xbar_, self.coef_.T) - def fit(self, X, y): + def fit(self, X, y, feature_names_in=None): """Fit LinearDiscriminantAnalysis model according to the given training data and parameters. @@ -636,7 +636,7 @@ def __init__(self, priors=None, reg_param=0., store_covariance=False, self.store_covariance = store_covariance self.tol = tol - def fit(self, X, y): + def fit(self, X, y, feature_names_in=None): """Fit the model according to the given training data and parameters. .. versionchanged:: 0.19 diff --git a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py index 00b5d08278162..ba81103ee102e 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py @@ -74,7 +74,7 @@ def _validate_parameters(self): raise ValueError('tol={} ' 'must not be smaller than 0.'.format(self.tol)) - def fit(self, X, y): + def fit(self, X, y, feature_names_in=None): """Fit the gradient boosting model.
Parameters diff --git a/sklearn/ensemble/tests/test_bagging.py b/sklearn/ensemble/tests/test_bagging.py index 345ee90f1fe49..27378dc580c83 100644 --- a/sklearn/ensemble/tests/test_bagging.py +++ b/sklearn/ensemble/tests/test_bagging.py @@ -80,7 +80,7 @@ def test_sparse_classification(): class CustomSVC(SVC): """SVC variant that records the nature of the training set""" - def fit(self, X, y): + def fit(self, X, y, feature_names_in=None): super().fit(X, y) self.data_type_ = type(X) return self @@ -166,7 +166,7 @@ def test_sparse_regression(): class CustomSVR(SVR): """SVC variant that records the nature of the training set""" - def fit(self, X, y): + def fit(self, X, y, feature_names_in=None): super().fit(X, y) self.data_type_ = type(X) return self @@ -218,7 +218,7 @@ def fit(self, X, y): class DummySizeEstimator(BaseEstimator): - def fit(self, X, y): + def fit(self, X, y, feature_names_in=None): self.training_size_ = X.shape[0] self.training_hash_ = joblib.hash(X) @@ -582,7 +582,7 @@ def test_bagging_with_pipeline(): class DummyZeroEstimator(BaseEstimator): - def fit(self, X, y): + def fit(self, X, y, feature_names_in=None): self.classes_ = np.unique(y) return self diff --git a/sklearn/ensemble/tests/test_gradient_boosting.py b/sklearn/ensemble/tests/test_gradient_boosting.py index 17e09f7f07156..70a3d54a6ca3f 100644 --- a/sklearn/ensemble/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/tests/test_gradient_boosting.py @@ -1335,7 +1335,7 @@ class _NoSampleWeightWrapper(BaseEstimator): def __init__(self, est): self.est = est - def fit(self, X, y): + def fit(self, X, y, feature_names_in=None): self.est.fit(X, y) def predict(self, X): diff --git a/sklearn/ensemble/tests/test_weight_boosting.py b/sklearn/ensemble/tests/test_weight_boosting.py index b5695f82374b1..a42d49547b8ee 100755 --- a/sklearn/ensemble/tests/test_weight_boosting.py +++ b/sklearn/ensemble/tests/test_weight_boosting.py @@ -470,7 +470,7 @@ def test_sample_weight_adaboost_regressor(): """ class DummyEstimator(BaseEstimator): - def fit(self, X, y): + def fit(self, X, y, feature_names_in=None): pass def predict(self, X): diff --git a/sklearn/feature_selection/rfe.py b/sklearn/feature_selection/rfe.py index 0010c5c5f8ca0..ba28f60cba76d 100644 --- a/sklearn/feature_selection/rfe.py +++ b/sklearn/feature_selection/rfe.py @@ -130,7 +130,7 @@ def _estimator_type(self): def classes_(self): return self.estimator_.classes_ - def fit(self, X, y): + def fit(self, X, y, feature_names_in=None): """Fit the RFE model and then the underlying estimator on the selected features. 
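For selectors such as ``RFE``, the new keyword is what lets the mask-based machinery narrow the incoming names down to the surviving features before they reach the wrapped estimator. A sketch of the intent, assuming the selector exposes the selected names as ``feature_names_out_`` in the same way the transformers above do (illustrative only)::

    >>> import numpy as np
    >>> from sklearn.datasets import load_iris
    >>> from sklearn.feature_selection import RFE
    >>> from sklearn.linear_model import LogisticRegression
    >>> iris = load_iris()
    >>> rfe = RFE(LogisticRegression(solver='lbfgs', multi_class='auto'),
    ...           n_features_to_select=2)
    >>> rfe.fit(iris.data, iris.target,
    ...         feature_names_in=iris.feature_names)  # doctest: +SKIP
    RFE(...)
    >>> rfe.feature_names_out_  # doctest: +SKIP
    array(['petal length (cm)', 'petal width (cm)'], dtype=object)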
diff --git a/sklearn/feature_selection/tests/test_rfe.py b/sklearn/feature_selection/tests/test_rfe.py
index 0ef1cb12efdba..073bbd8bdf06d 100644
--- a/sklearn/feature_selection/tests/test_rfe.py
+++ b/sklearn/feature_selection/tests/test_rfe.py
@@ -29,7 +29,7 @@ class MockClassifier:
 def __init__(self, foo_param=0):
 self.foo_param = foo_param
- def fit(self, X, Y):
- assert len(X) == len(Y)
+ def fit(self, X, y, feature_names_in=None):
+ assert len(X) == len(y)
 self.coef_ = np.ones(X.shape[1], dtype=np.float64)
 return self
diff --git a/sklearn/feature_selection/univariate_selection.py b/sklearn/feature_selection/univariate_selection.py
index 554cb3d392b29..8f43c1332b681 100644
--- a/sklearn/feature_selection/univariate_selection.py
+++ b/sklearn/feature_selection/univariate_selection.py
@@ -322,7 +322,7 @@ class _BaseFilter(BaseEstimator, SelectorMixin):
 def __init__(self, score_func):
 self.score_func = score_func
- def fit(self, X, y):
+ def fit(self, X, y, feature_names_in=None):
 """Run score function on (X, y) and get the appropriate features.
 Parameters
diff --git a/sklearn/gaussian_process/gpc.py b/sklearn/gaussian_process/gpc.py
index a5637b24fdc4d..879c41cb96200 100644
--- a/sklearn/gaussian_process/gpc.py
+++ b/sklearn/gaussian_process/gpc.py
@@ -156,7 +156,7 @@ def __init__(self, kernel=None, optimizer="fmin_l_bfgs_b",
 self.copy_X_train = copy_X_train
 self.random_state = random_state
- def fit(self, X, y):
+ def fit(self, X, y, feature_names_in=None):
 """Fit Gaussian process classification model
 Parameters
@@ -589,7 +589,7 @@ def __init__(self, kernel=None, optimizer="fmin_l_bfgs_b",
 self.multi_class = multi_class
 self.n_jobs = n_jobs
- def fit(self, X, y):
+ def fit(self, X, y, feature_names_in=None):
 """Fit Gaussian process classification model
 Parameters
diff --git a/sklearn/gaussian_process/gpr.py b/sklearn/gaussian_process/gpr.py
index 3876325a613ab..e98b95d7d5914 100644
--- a/sklearn/gaussian_process/gpr.py
+++ b/sklearn/gaussian_process/gpr.py
@@ -159,7 +159,7 @@ def __init__(self, kernel=None, alpha=1e-10,
 self.copy_X_train = copy_X_train
 self.random_state = random_state
- def fit(self, X, y):
+ def fit(self, X, y, feature_names_in=None):
 """Fit Gaussian process regression model.
Parameters diff --git a/sklearn/inspection/tests/test_partial_dependence.py b/sklearn/inspection/tests/test_partial_dependence.py index ce8d7af3117cd..0ae8ac8645d1b 100644 --- a/sklearn/inspection/tests/test_partial_dependence.py +++ b/sklearn/inspection/tests/test_partial_dependence.py @@ -273,7 +273,7 @@ def test_multiclass_multioutput(Estimator): class NoPredictProbaNoDecisionFunction(BaseEstimator, ClassifierMixin): - def fit(self, X, y): + def fit(self, X, y, feature_names_in=None): return self diff --git a/sklearn/linear_model/base.py b/sklearn/linear_model/base.py index 7aaa5a6526bb1..975126261b8a2 100644 --- a/sklearn/linear_model/base.py +++ b/sklearn/linear_model/base.py @@ -194,7 +194,7 @@ class LinearModel(BaseEstimator, metaclass=ABCMeta): """Base class for Linear Models""" @abstractmethod - def fit(self, X, y): + def fit(self, X, y, feature_names_in=None): """Fit model.""" def _decision_function(self, X): diff --git a/sklearn/linear_model/bayes.py b/sklearn/linear_model/bayes.py index 0615afb710545..7c9fe95c2bf26 100644 --- a/sklearn/linear_model/bayes.py +++ b/sklearn/linear_model/bayes.py @@ -504,7 +504,7 @@ def __init__(self, n_iter=300, tol=1.e-3, alpha_1=1.e-6, alpha_2=1.e-6, self.copy_X = copy_X self.verbose = verbose - def fit(self, X, y): + def fit(self, X, y, feature_names_in=None): """Fit the ARDRegression model according to the given training data and parameters. diff --git a/sklearn/linear_model/coordinate_descent.py b/sklearn/linear_model/coordinate_descent.py index 646839a0a3ae6..88e83b5060d19 100644 --- a/sklearn/linear_model/coordinate_descent.py +++ b/sklearn/linear_model/coordinate_descent.py @@ -1059,7 +1059,7 @@ def __init__(self, eps=1e-3, n_alphas=100, alphas=None, fit_intercept=True, self.random_state = random_state self.selection = selection - def fit(self, X, y): + def fit(self, X, y, feature_names_in=None): """Fit linear model with coordinate descent Fit is on grid of alphas and best alpha estimated by cross-validation. @@ -1730,7 +1730,7 @@ def __init__(self, alpha=1.0, l1_ratio=0.5, fit_intercept=True, self.random_state = random_state self.selection = selection - def fit(self, X, y): + def fit(self, X, y, feature_names_in=None): """Fit MultiTaskElasticNet model with coordinate descent Parameters diff --git a/sklearn/linear_model/least_angle.py b/sklearn/linear_model/least_angle.py index a13ae35cc7563..05509c6850db8 100644 --- a/sklearn/linear_model/least_angle.py +++ b/sklearn/linear_model/least_angle.py @@ -1349,7 +1349,7 @@ def __init__(self, fit_intercept=True, verbose=False, max_iter=500, n_nonzero_coefs=500, eps=eps, copy_X=copy_X, fit_path=True) - def fit(self, X, y): + def fit(self, X, y, feature_names_in=None): """Fit the model using X, y as training data. Parameters diff --git a/sklearn/linear_model/omp.py b/sklearn/linear_model/omp.py index 232dcded410a5..d077de2ac739a 100644 --- a/sklearn/linear_model/omp.py +++ b/sklearn/linear_model/omp.py @@ -624,7 +624,7 @@ def __init__(self, n_nonzero_coefs=None, tol=None, fit_intercept=True, self.normalize = normalize self.precompute = precompute - def fit(self, X, y): + def fit(self, X, y, feature_names_in=None): """Fit the model using X, y as training data. Parameters @@ -863,7 +863,7 @@ def __init__(self, copy=True, fit_intercept=True, normalize=True, self.n_jobs = n_jobs self.verbose = verbose - def fit(self, X, y): + def fit(self, X, y, feature_names_in=None): """Fit the model using X, y as training data. 
Parameters diff --git a/sklearn/linear_model/stochastic_gradient.py b/sklearn/linear_model/stochastic_gradient.py index 02f065da9ae75..830da348ef7c3 100644 --- a/sklearn/linear_model/stochastic_gradient.py +++ b/sklearn/linear_model/stochastic_gradient.py @@ -104,7 +104,7 @@ def set_params(self, *args, **kwargs): return self @abstractmethod - def fit(self, X, y): + def fit(self, X, y, feature_names_in=None): """Fit model.""" def _validate_params(self, set_max_iter=True, for_partial_fit=False): diff --git a/sklearn/linear_model/tests/test_passive_aggressive.py b/sklearn/linear_model/tests/test_passive_aggressive.py index 8e8bfdc8b9800..7f0fa7fd8d7cd 100644 --- a/sklearn/linear_model/tests/test_passive_aggressive.py +++ b/sklearn/linear_model/tests/test_passive_aggressive.py @@ -32,7 +32,7 @@ def __init__(self, C=1.0, epsilon=0.01, loss="hinge", self.fit_intercept = fit_intercept self.n_iter = n_iter - def fit(self, X, y): + def fit(self, X, y, feature_names_in=None): n_samples, n_features = X.shape self.w = np.zeros(n_features, dtype=np.float64) self.b = 0.0 diff --git a/sklearn/linear_model/tests/test_perceptron.py b/sklearn/linear_model/tests/test_perceptron.py index bce518b5f2e37..7b45f91b4c4c8 100644 --- a/sklearn/linear_model/tests/test_perceptron.py +++ b/sklearn/linear_model/tests/test_perceptron.py @@ -24,7 +24,7 @@ class MyPerceptron: def __init__(self, n_iter=1): self.n_iter = n_iter - def fit(self, X, y): + def fit(self, X, y, feature_names_in=None): n_samples, n_features = X.shape self.w = np.zeros(n_features, dtype=np.float64) self.b = 0.0 diff --git a/sklearn/linear_model/theil_sen.py b/sklearn/linear_model/theil_sen.py index 941c51196cc4a..147f5af66d176 100644 --- a/sklearn/linear_model/theil_sen.py +++ b/sklearn/linear_model/theil_sen.py @@ -343,7 +343,7 @@ def _check_subparams(self, n_samples, n_features): return n_subsamples, n_subpopulation - def fit(self, X, y): + def fit(self, X, y, feature_names_in=None): """Fit linear model. 
Parameters
diff --git a/sklearn/metrics/tests/test_score_objects.py b/sklearn/metrics/tests/test_score_objects.py
index f1b9120b06442..e579ed03a9314 100644
--- a/sklearn/metrics/tests/test_score_objects.py
+++ b/sklearn/metrics/tests/test_score_objects.py
@@ -116,13 +116,13 @@ class EstimatorWithoutFit:
 class EstimatorWithFit(BaseEstimator):
 """Dummy estimator to test scoring validators"""
- def fit(self, X, y):
+ def fit(self, X, y, feature_names_in=None):
 return self
 class EstimatorWithFitAndScore:
 """Dummy estimator to test scoring validators"""
- def fit(self, X, y):
+ def fit(self, X, y, feature_names_in=None):
 return self
 def score(self, X, y):
@@ -131,7 +131,7 @@ def score(self, X, y):
 class EstimatorWithFitAndPredict:
 """Dummy estimator to test scoring validators"""
- def fit(self, X, y):
+ def fit(self, X, y, feature_names_in=None):
 self.y = y
 return self
diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py
index 08759ec9267ab..11c089ee73737 100644
--- a/sklearn/model_selection/tests/test_search.py
+++ b/sklearn/model_selection/tests/test_search.py
@@ -74,7 +74,7 @@ class MockClassifier(object):
 def __init__(self, foo_param=0):
 self.foo_param = foo_param
- def fit(self, X, Y):
- assert len(X) == len(Y)
- self.classes_ = np.unique(Y)
+ def fit(self, X, y, feature_names_in=None):
+ assert len(X) == len(y)
+ self.classes_ = np.unique(y)
 return self
@@ -538,7 +538,7 @@ class BrokenClassifier(BaseEstimator):
 def __init__(self, parameter=None):
 self.parameter = parameter
- def fit(self, X, y):
+ def fit(self, X, y, feature_names_in=None):
 assert not hasattr(self, 'has_been_fit_')
 self.has_been_fit_ = True
diff --git a/sklearn/multiclass.py b/sklearn/multiclass.py
index 828c0257198b0..c1064ae795dc6 100644
--- a/sklearn/multiclass.py
+++ b/sklearn/multiclass.py
@@ -108,7 +108,7 @@ def _check_estimator(estimator):
 class _ConstantPredictor(BaseEstimator):
- def fit(self, X, y):
+ def fit(self, X, y, feature_names_in=None):
 self.y_ = y
 return self
@@ -180,7 +180,7 @@ def __init__(self, estimator, n_jobs=None):
 self.estimator = estimator
 self.n_jobs = n_jobs
- def fit(self, X, y):
+ def fit(self, X, y, feature_names_in=None):
 """Fit underlying estimators.
 Parameters
@@ -479,7 +479,7 @@ def __init__(self, estimator, n_jobs=None):
 self.estimator = estimator
 self.n_jobs = n_jobs
- def fit(self, X, y):
+ def fit(self, X, y, feature_names_in=None):
 """Fit underlying estimators.
 Parameters
@@ -706,7 +706,7 @@ def __init__(self, estimator, code_size=1.5, random_state=None,
 self.random_state = random_state
 self.n_jobs = n_jobs
- def fit(self, X, y):
+ def fit(self, X, y, feature_names_in=None):
 """Fit underlying estimators.
 Parameters
diff --git a/sklearn/multioutput.py b/sklearn/multioutput.py
index c700fbb2964a0..f9e6fe3a08936 100644
--- a/sklearn/multioutput.py
+++ b/sklearn/multioutput.py
@@ -395,7 +395,7 @@ def __init__(self, base_estimator, order=None, cv=None, random_state=None):
 self.random_state = random_state
 @abstractmethod
- def fit(self, X, Y):
+ def fit(self, X, y, feature_names_in=None):
 """Fit the model to data matrix X and targets Y.
 Parameters
@@ -568,7 +568,7 @@ class labels for each estimator in the chain.
 """
- def fit(self, X, Y):
+ def fit(self, X, y, feature_names_in=None):
 """Fit the model to data matrix X and targets Y.
 Parameters
@@ -718,7 +718,7 @@ class RegressorChain(_BaseChain, RegressorMixin, MetaEstimatorMixin):
 chaining.
 """
- def fit(self, X, Y):
+ def fit(self, X, y, feature_names_in=None):
 """Fit the model to data matrix X and targets Y.
Parameters diff --git a/sklearn/neighbors/base.py b/sklearn/neighbors/base.py index ce3313e27a256..95842b86474c9 100644 --- a/sklearn/neighbors/base.py +++ b/sklearn/neighbors/base.py @@ -857,7 +857,7 @@ def radius_neighbors_graph(self, X=None, radius=None, mode='connectivity'): class SupervisedFloatMixin: - def fit(self, X, y): + def fit(self, X, y, feature_names_in=None): """Fit the model using X as training data and y as target values Parameters @@ -877,7 +877,7 @@ def fit(self, X, y): class SupervisedIntegerMixin: - def fit(self, X, y): + def fit(self, X, y, feature_names_in=None): """Fit the model using X as training data and y as target values Parameters diff --git a/sklearn/neighbors/nca.py b/sklearn/neighbors/nca.py index 513e95c1ca565..be7478a217f7f 100644 --- a/sklearn/neighbors/nca.py +++ b/sklearn/neighbors/nca.py @@ -168,7 +168,7 @@ def __init__(self, n_components=None, init='auto', warm_start=False, self.verbose = verbose self.random_state = random_state - def fit(self, X, y): + def fit(self, X, y, feature_names_in=None): """Fit the model according to the given training data. Parameters diff --git a/sklearn/neighbors/nearest_centroid.py b/sklearn/neighbors/nearest_centroid.py index 4168599781ea0..01246e764f5d9 100644 --- a/sklearn/neighbors/nearest_centroid.py +++ b/sklearn/neighbors/nearest_centroid.py @@ -82,7 +82,7 @@ def __init__(self, metric='euclidean', shrink_threshold=None): self.metric = metric self.shrink_threshold = shrink_threshold - def fit(self, X, y): + def fit(self, X, y, feature_names_in=None): """ Fit the NearestCentroid model according to the given training data. diff --git a/sklearn/neural_network/multilayer_perceptron.py b/sklearn/neural_network/multilayer_perceptron.py index e5325ecda69f0..134495445baed 100644 --- a/sklearn/neural_network/multilayer_perceptron.py +++ b/sklearn/neural_network/multilayer_perceptron.py @@ -624,7 +624,7 @@ def _update_no_improvement_count(self, early_stopping, X_val, y_val): if self.loss_curve_[-1] < self.best_loss_: self.best_loss_ = self.loss_curve_[-1] - def fit(self, X, y): + def fit(self, X, y, feature_names_in=None): """Fit the model to data matrix X and target(s) y. Parameters @@ -991,7 +991,7 @@ def predict(self, X): return self._label_binarizer.inverse_transform(y_pred) - def fit(self, X, y): + def fit(self, X, y, feature_names_in=None): """Fit the model to data matrix X and target(s) y. Parameters diff --git a/sklearn/semi_supervised/label_propagation.py b/sklearn/semi_supervised/label_propagation.py index 4820af8cb2b69..f4c977ace4003 100644 --- a/sklearn/semi_supervised/label_propagation.py +++ b/sklearn/semi_supervised/label_propagation.py @@ -200,7 +200,7 @@ class labels probabilities /= normalizer return probabilities - def fit(self, X, y): + def fit(self, X, y, feature_names_in=None): """Fit a semi-supervised label propagation model based All the input data is provided matrix X (labeled and unlabeled) @@ -396,7 +396,7 @@ class distributions will exceed 1 (normalization may be desired). 
affinity_matrix /= normalizer[:, np.newaxis]
 return affinity_matrix
- def fit(self, X, y):
+ def fit(self, X, y, feature_names_in=None):
 return super().fit(X, y)
diff --git a/sklearn/tests/test_calibration.py b/sklearn/tests/test_calibration.py
index acac4c0471e0d..da97a005e3a2c 100644
--- a/sklearn/tests/test_calibration.py
+++ b/sklearn/tests/test_calibration.py
@@ -330,7 +330,7 @@ def test_calibration_accepts_ndarray(X):
 class MockTensorClassifier(BaseEstimator):
 """A toy estimator that accepts tensor inputs"""
- def fit(self, X, y):
+ def fit(self, X, y, feature_names_in=None):
 self.classes_ = np.unique(y)
 return self
diff --git a/sklearn/tests/test_pipeline.py b/sklearn/tests/test_pipeline.py
index 2afdf090f3a07..3b7fbd57b2384 100644
--- a/sklearn/tests/test_pipeline.py
+++ b/sklearn/tests/test_pipeline.py
@@ -60,7 +60,7 @@ def __init__(self, a=None, b=None):
 class NoTrans(NoFit):
- def fit(self, X, y):
+ def fit(self, X, y, feature_names_in=None):
 return self
 def get_params(self, deep=False):
@@ -95,7 +95,7 @@ class Mult(BaseEstimator):
 def __init__(self, mult=1):
 self.mult = mult
- def fit(self, X, y):
+ def fit(self, X, y, feature_names_in=None):
 return self
 def transform(self, X):
@@ -139,7 +139,7 @@ def score(self, X, y=None, sample_weight=None):
 class DummyTransf(Transf):
 """Transformer which store the column means"""
- def fit(self, X, y):
+ def fit(self, X, y, feature_names_in=None):
 self.means_ = np.mean(X, axis=0)
 # store timestamp to figure out whether the result of 'fit' has been
 # cached or not
@@ -150,7 +150,7 @@ def fit(self, X, y):
 class DummyEstimatorParams(BaseEstimator):
 """Mock classifier that takes params on predict"""
- def fit(self, X, y):
+ def fit(self, X, y, feature_names_in=None):
 return self
 def predict(self, X, got_attribute=False):
diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py
index 969e0569afe99..3ac00e23bf940 100644
--- a/sklearn/utils/tests/test_estimator_checks.py
+++ b/sklearn/utils/tests/test_estimator_checks.py
@@ -43,7 +43,7 @@ class CorrectNotFittedError(ValueError):
 class BaseBadClassifier(BaseEstimator, ClassifierMixin):
- def fit(self, X, y):
+ def fit(self, X, y, feature_names_in=None):
 return self
 def predict(self, X):
@@ -145,13 +145,13 @@ def fit(self, X, y=None, feature_names_in=None):
 class NoCheckinPredict(BaseBadClassifier):
- def fit(self, X, y):
+ def fit(self, X, y, feature_names_in=None):
 X, y = check_X_y(X, y)
 return self
 class NoSparseClassifier(BaseBadClassifier):
- def fit(self, X, y):
+ def fit(self, X, y, feature_names_in=None):
 X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'])
 if sp.issparse(X):
 raise ValueError("Nonsensical Error")
@@ -163,7 +163,7 @@ def predict(self, X):
 class CorrectNotFittedErrorClassifier(BaseBadClassifier):
- def fit(self, X, y):
+ def fit(self, X, y, feature_names_in=None):
 X, y = check_X_y(X, y)
 self.coef_ = np.ones(X.shape[1])
 return self
@@ -197,7 +197,7 @@ class BadBalancedWeightsClassifier(BaseBadClassifier):
 def __init__(self, class_weight=None):
 self.class_weight = class_weight
- def fit(self, X, y):
+ def fit(self, X, y, feature_names_in=None):
 from sklearn.preprocessing import LabelEncoder
 from sklearn.utils import compute_class_weight
@@ -226,7 +226,7 @@ def transform(self, X):
 class NotInvariantPredict(BaseEstimator):
- def fit(self, X, y):
+ def fit(self, X, y, feature_names_in=None):
 # Convert data
 X, y = check_X_y(X, y,
 accept_sparse=("csr", "csc"),
@@ -243,7 +243,7 @@ def predict(self, X):
 class LargeSparseNotSupportedClassifier(BaseEstimator):
- def fit(self, X, y):
+ def fit(self, X, y, feature_names_in=None):
 X, y = check_X_y(X, y,
 accept_sparse=("csr", "csc", "coo"),
 accept_large_sparse=True,
@@ -294,7 +294,7 @@ def _more_tags(self):
 class RequiresPositiveYRegressor(LinearRegression):
- def fit(self, X, y):
+ def fit(self, X, y, feature_names_in=None):
 X, y = check_X_y(X, y)
 if (y <= 0).any():
 raise ValueError('negative y values not supported!')
@@ -311,7 +311,7 @@ def test_check_fit_score_takes_y_works_on_deprecated_fit():
 class TestEstimatorWithDeprecatedFitMethod(BaseEstimator):
 @deprecated("Deprecated for the purpose of testing "
 "check_fit_score_takes_y")
- def fit(self, X, y):
+ def fit(self, X, y, feature_names_in=None):
 return self
 check_fit_score_takes_y("test", TestEstimatorWithDeprecatedFitMethod())
diff --git a/sklearn/utils/tests/test_pprint.py b/sklearn/utils/tests/test_pprint.py
index 8f3c13b1cf844..df8700e7a197a 100644
--- a/sklearn/utils/tests/test_pprint.py
+++ b/sklearn/utils/tests/test_pprint.py
@@ -37,7 +37,7 @@ def __init__(self, penalty='l2', dual=False, tol=1e-4, C=1.0,
 self.n_jobs = n_jobs
 self.l1_ratio = l1_ratio
- def fit(self, X, y):
+ def fit(self, X, y, feature_names_in=None):
 return self
diff --git a/sklearn/utils/tests/test_testing.py b/sklearn/utils/tests/test_testing.py
index aac689fb2dc80..01d6221c43fbf 100644
--- a/sklearn/utils/tests/test_testing.py
+++ b/sklearn/utils/tests/test_testing.py
@@ -412,7 +412,7 @@ def f_bad_sections(self, X, y):
 class MockEst:
 def __init__(self):
 """MockEstimator"""
- def fit(self, X, y):
+ def fit(self, X, y, feature_names_in=None):
 return X
 def predict(self, X):
@@ -470,7 +470,7 @@ def predict_proba(self, X):
 return X
 @deprecated('Testing deprecated function with wrong params')
- def fit(self, X, y):
+ def fit(self, X, y, feature_names_in=None):
 """Incorrect docstring but should not be tested"""
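# --- reviewer sketch (not part of the patch) --------------------------------
# Because the new argument defaults to None everywhere, existing
# two-argument calls keep working; only name-aware callers change. A quick
# sanity check against one of the estimators updated above (NearestCentroid).
# The feature names 'height'/'weight' are illustrative; what the estimator
# does with the names is handled elsewhere in this patch series, so this
# only exercises the widened signature.
import numpy as np
from sklearn.neighbors import NearestCentroid

X = np.array([[1., 2.], [3., 4.], [5., 6.], [7., 8.]])
y = np.array([0, 0, 1, 1])

clf = NearestCentroid()
clf.fit(X, y)                                         # old two-argument call
clf.fit(X, y, feature_names_in=['height', 'weight'])  # new name-aware call
# -----------------------------------------------------------------------------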