diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst
index a3f78fe167e8a..df7228596e424 100644
--- a/doc/modules/classes.rst
+++ b/doc/modules/classes.rst
@@ -463,6 +463,7 @@ From text
    feature_selection.SelectKBest
    feature_selection.SelectFpr
    feature_selection.SelectFdr
+   feature_selection.SelectFromModel
    feature_selection.SelectFwe
    feature_selection.RFE
    feature_selection.RFECV
diff --git a/doc/modules/feature_selection.rst b/doc/modules/feature_selection.rst
index edf2edf78ae96..88ff7d56d6865 100644
--- a/doc/modules/feature_selection.rst
+++ b/doc/modules/feature_selection.rst
@@ -131,33 +131,52 @@ number of features.
   elimination example with automatic tuning of the number of features
   selected with cross-validation.
 
+.. _select_from_model:
 
-.. _l1_feature_selection:
+Feature selection using SelectFromModel
+=======================================
+
+:class:`SelectFromModel` is a meta-transformer that can be used along with any
+estimator that has a ``coef_`` or ``feature_importances_`` attribute after fitting.
+Features are considered unimportant and removed if the corresponding
+``coef_`` or ``feature_importances_`` values fall below the provided
+``threshold`` parameter. Apart from specifying the threshold numerically,
+there are built-in heuristics for finding a threshold using a string argument.
+Available heuristics are "mean", "median" and float multiples of these, such as
+"0.1*mean".
+
+For examples of how to use it, refer to the sections below.
+
+.. topic:: Examples
+
+    * :ref:`example_feature_selection_plot_select_from_model_boston.py`: Selecting the two
+      most important features from the Boston dataset without knowing the
+      threshold beforehand.
 
 L1-based feature selection
-==========================
+--------------------------
 
 .. currentmodule:: sklearn
 
-Selecting non-zero coefficients
----------------------------------
-
 :ref:`Linear models <linear_model>` penalized with the L1 norm have
 sparse solutions: many of their estimated coefficients are zero. When the goal
 is to reduce the dimensionality of the data to use with another classifier,
-they expose a ``transform`` method to select the non-zero coefficient. In
-particular, sparse estimators useful for this purpose are the
-:class:`linear_model.Lasso` for regression, and
+they can be used along with :class:`feature_selection.SelectFromModel`
+to select the non-zero coefficients. In particular, sparse estimators useful for
+this purpose are the :class:`linear_model.Lasso` for regression, and
 :class:`linear_model.LogisticRegression` and :class:`svm.LinearSVC`
 for classification::
 
   >>> from sklearn.svm import LinearSVC
  >>> from sklearn.datasets import load_iris
+  >>> from sklearn.feature_selection import SelectFromModel
  >>> iris = load_iris()
  >>> X, y = iris.data, iris.target
  >>> X.shape
  (150, 4)
-  >>> X_new = LinearSVC(C=0.01, penalty="l1", dual=False).fit_transform(X, y)
+  >>> lsvc = LinearSVC(C=0.01, penalty="l1", dual=False).fit(X, y)
+  >>> model = SelectFromModel(lsvc, prefit=True)
+  >>> X_new = model.transform(X)
  >>> X_new.shape
  (150, 3)
 
@@ -241,23 +260,27 @@ of features non zero.
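(A minimal sketch of the string-threshold heuristics introduced in the ``SelectFromModel``
overview above; illustrative only and not part of the patch, the reported shape depends on
the fitted importances)::

    >>> from sklearn.datasets import load_iris
    >>> from sklearn.ensemble import ExtraTreesClassifier
    >>> from sklearn.feature_selection import SelectFromModel
    >>> iris = load_iris()
    >>> X, y = iris.data, iris.target
    >>> clf = ExtraTreesClassifier(random_state=0).fit(X, y)
    >>> # keep features whose importance is at least half the mean importance
    >>> SelectFromModel(clf, threshold="0.5*mean", prefit=True).transform(X).shape  # doctest: +SKIP
    (150, 2)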
 http://hal.inria.fr/hal-00354771/
 
 Tree-based feature selection
-============================
+----------------------------
 
 Tree-based estimators (see the :mod:`sklearn.tree` module and forest
 of trees in the :mod:`sklearn.ensemble` module) can be used to compute
 feature importances, which in turn can be used to discard irrelevant
-features::
+features (when coupled with the :class:`sklearn.feature_selection.SelectFromModel`
+meta-transformer)::
 
  >>> from sklearn.ensemble import ExtraTreesClassifier
  >>> from sklearn.datasets import load_iris
+  >>> from sklearn.feature_selection import SelectFromModel
  >>> iris = load_iris()
  >>> X, y = iris.data, iris.target
  >>> X.shape
  (150, 4)
  >>> clf = ExtraTreesClassifier()
-  >>> X_new = clf.fit(X, y).transform(X)
+  >>> clf = clf.fit(X, y)
  >>> clf.feature_importances_  # doctest: +SKIP
  array([ 0.04...,  0.05...,  0.4...,  0.4...])
+  >>> model = SelectFromModel(clf, prefit=True)
+  >>> X_new = model.transform(X)
  >>> X_new.shape               # doctest: +SKIP
  (150, 2)
 
@@ -278,12 +301,13 @@ the actual learning. The recommended way to do this in scikit-learn is
 to use a :class:`sklearn.pipeline.Pipeline`::
 
   clf = Pipeline([
-    ('feature_selection', LinearSVC(penalty="l1")),
+    ('feature_selection', SelectFromModel(LinearSVC(penalty="l1"))),
     ('classification', RandomForestClassifier())
   ])
   clf.fit(X, y)
 
 In this snippet we make use of a :class:`sklearn.svm.LinearSVC`
+coupled with :class:`sklearn.feature_selection.SelectFromModel`
 to evaluate feature importances and select the most relevant features.
 Then, a :class:`sklearn.ensemble.RandomForestClassifier` is trained on the
 transformed output, i.e. using only relevant features. You can perform
diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index 6f9f609174b39..d808cf6ad02f1 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -207,6 +207,11 @@ Enhancements
      the same. This allows gradient boosters to turn off presorting when building
      deep trees or using sparse data. By `Jacob Schreiber`_.
 
+   - Added the :class:`feature_selection.SelectFromModel` meta-transformer, which can
+     be used along with estimators that have a `coef_` or `feature_importances_`
+     attribute to select important features of the input data. By
+     `Maheshakya Wijewardena`_, `Joel Nothman`_ and `Manoj Kumar`_.
+
 Bug fixes
 .........
 
@@ -269,6 +274,13 @@ API changes summary
      fit method to the constructor in
      :class:`discriminant_analysis.QuadraticDiscriminantAnalysis`.
 
+   - Models inheriting from ``_LearntSelectorMixin`` will no longer support the
+     ``transform`` method (i.e., RandomForests, GradientBoosting, LogisticRegression,
+     DecisionTrees, SVMs and SGD-related models). Wrap these models with the
+     meta-transformer :class:`feature_selection.SelectFromModel` to remove
+     features (according to `coef_` or `feature_importances_`)
+     that fall below a certain threshold value instead.
+
 ..
_changes_0_1_16: Version 0.16.1 diff --git a/examples/ensemble/plot_feature_transformation.py b/examples/ensemble/plot_feature_transformation.py index 6faf222cf9875..7b6941846ab54 100644 --- a/examples/ensemble/plot_feature_transformation.py +++ b/examples/ensemble/plot_feature_transformation.py @@ -34,6 +34,7 @@ from sklearn.linear_model import LogisticRegression from sklearn.ensemble import (RandomTreesEmbedding, RandomForestClassifier, GradientBoostingClassifier) +from sklearn.feature_selection import SelectFromModel from sklearn.preprocessing import OneHotEncoder from sklearn.cross_validation import train_test_split from sklearn.metrics import roc_curve @@ -53,12 +54,12 @@ rt = RandomTreesEmbedding(max_depth=3, n_estimators=n_estimator) rt_lm = LogisticRegression() rt.fit(X_train, y_train) -rt_lm.fit(rt.transform(X_train_lr), y_train_lr) +rt_lm.fit(SelectFromModel(rt, prefit=True).transform(X_train_lr), y_train_lr) -y_pred_rt = rt_lm.predict_proba(rt.transform(X_test))[:, 1] +y_pred_rt = rt_lm.predict_proba( + SelectFromModel(rt, prefit=True).transform(X_test))[:, 1] fpr_rt_lm, tpr_rt_lm, _ = roc_curve(y_test, y_pred_rt) - # Supervised transformation based on random forests rf = RandomForestClassifier(max_depth=3, n_estimators=n_estimator) rf_enc = OneHotEncoder() diff --git a/examples/ensemble/plot_random_forest_embedding.py b/examples/ensemble/plot_random_forest_embedding.py index ba6329d72905b..eef04ac3336c4 100644 --- a/examples/ensemble/plot_random_forest_embedding.py +++ b/examples/ensemble/plot_random_forest_embedding.py @@ -30,6 +30,7 @@ from sklearn.datasets import make_circles from sklearn.ensemble import RandomTreesEmbedding, ExtraTreesClassifier from sklearn.decomposition import TruncatedSVD +from sklearn.feature_selection import SelectFromModel from sklearn.naive_bayes import BernoulliNB # make a synthetic dataset @@ -37,7 +38,9 @@ # use RandomTreesEmbedding to transform data hasher = RandomTreesEmbedding(n_estimators=10, random_state=0, max_depth=3) -X_transformed = hasher.fit_transform(X) +hasher.fit(X) +model = SelectFromModel(hasher, prefit=True) +X_transformed = model.transform(X) # Visualize result using PCA pca = TruncatedSVD(n_components=2) diff --git a/examples/feature_selection/plot_select_from_model_boston.py b/examples/feature_selection/plot_select_from_model_boston.py new file mode 100644 index 0000000000000..17ef6d6bd0149 --- /dev/null +++ b/examples/feature_selection/plot_select_from_model_boston.py @@ -0,0 +1,51 @@ +""" +=================================================== +Feature selection using SelectFromModel and LassoCV +=================================================== + +Use SelectFromModel meta-transformer along with Lasso to select the best +couple of features from the Boston dataset. +""" +# Author: Manoj Kumar +# License: BSD 3 clause + +print(__doc__) + +import matplotlib.pyplot as plt +import numpy as np + +from sklearn.datasets import load_boston +from sklearn.feature_selection import SelectFromModel +from sklearn.linear_model import LassoCV + +# Load the boston dataset. +boston = load_boston() +X, y = boston['data'], boston['target'] + +# We use the base estimator LassoCV since the L1 norm promotes sparsity of features. +clf = LassoCV() + +# Set a minimum threshold of 0.25 +sfm = SelectFromModel(clf, threshold=0.25) +sfm.fit(X, y) +n_features = sfm.transform(X).shape[1] + +# Reset the threshold till the number of features equals two. +# Note that the attribute can be set directly instead of repeatedly +# fitting the metatransformer. 
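# (Illustrative aside, not part of the patch: because ``threshold`` is a
# constructor parameter, ``sfm.set_params(threshold=sfm.threshold + 0.1)`` would
# work equally well in the loop below; neither route refits the underlying
# LassoCV, since the stored coefficients are only re-thresholded at
# ``transform`` time.)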
+while n_features > 2:
+    sfm.threshold += 0.1
+    X_transform = sfm.transform(X)
+    n_features = X_transform.shape[1]
+
+# Plot the selected two features from X.
+plt.title(
+    "Features selected from Boston using SelectFromModel with "
+    "threshold %0.3f." % sfm.threshold)
+feature1 = X_transform[:, 0]
+feature2 = X_transform[:, 1]
+plt.plot(feature1, feature2, 'r.')
+plt.xlabel("Feature number 1")
+plt.ylabel("Feature number 2")
+plt.ylim([np.min(feature2), np.max(feature2)])
+plt.show()
diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py
index 1723b6131d5ea..631c726f0381e 100644
--- a/sklearn/ensemble/tests/test_forest.py
+++ b/sklearn/ensemble/tests/test_forest.py
@@ -19,6 +19,7 @@
 from scipy.sparse import csc_matrix
 from scipy.sparse import coo_matrix
 
+from sklearn.utils import warnings
 from sklearn.utils.testing import assert_almost_equal
 from sklearn.utils.testing import assert_array_almost_equal
 from sklearn.utils.testing import assert_array_equal
@@ -194,15 +195,19 @@ def test_probability():
 
 def check_importances(X, y, name, criterion):
     ForestEstimator = FOREST_ESTIMATORS[name]
 
-    est = ForestEstimator(n_estimators=20, criterion=criterion,random_state=0)
+    est = ForestEstimator(n_estimators=20, criterion=criterion,
+                          random_state=0)
     est.fit(X, y)
     importances = est.feature_importances_
     n_important = np.sum(importances > 0.1)
     assert_equal(importances.shape[0], 10)
     assert_equal(n_important, 3)
 
-    X_new = est.transform(X, threshold="mean")
-    assert_less(X_new.shape[1], X.shape[1])
+    # XXX: Remove this test in 0.19, after transform support for estimators
+    # is removed.
+    X_new = assert_warns(
+        DeprecationWarning, est.transform, X, threshold="mean")
+    assert_less(0, X_new.shape[1])
+    assert_less(X_new.shape[1], X.shape[1])
 
     # Check with parallel
     importances = est.feature_importances_
diff --git a/sklearn/ensemble/tests/test_gradient_boosting.py b/sklearn/ensemble/tests/test_gradient_boosting.py
index 024e7cf5e3975..4f2329be50358 100644
--- a/sklearn/ensemble/tests/test_gradient_boosting.py
+++ b/sklearn/ensemble/tests/test_gradient_boosting.py
@@ -26,6 +26,7 @@
 from sklearn.utils.testing import assert_raises
 from sklearn.utils.testing import assert_true
 from sklearn.utils.testing import assert_warns
+from sklearn.utils.testing import ignore_warnings
 from sklearn.utils.validation import DataConversionWarning
 from sklearn.utils.validation import NotFittedError
 
@@ -296,10 +297,13 @@ def test_feature_importances():
     clf.fit(X, y)
     assert_true(hasattr(clf, 'feature_importances_'))
 
-    X_new = clf.transform(X, threshold="mean")
+    # XXX: Remove this test in 0.19, after transform support for estimators
+    # is removed.
+ X_new = assert_warns( + DeprecationWarning, clf.transform, X, threshold="mean") assert_less(X_new.shape[1], X.shape[1]) - - feature_mask = clf.feature_importances_ > clf.feature_importances_.mean() + feature_mask = ( + clf.feature_importances_ > clf.feature_importances_.mean()) assert_array_almost_equal(X_new, X[:, feature_mask]) diff --git a/sklearn/feature_selection/__init__.py b/sklearn/feature_selection/__init__.py index 0d222534f5e93..acb03f6f24a9e 100644 --- a/sklearn/feature_selection/__init__.py +++ b/sklearn/feature_selection/__init__.py @@ -20,6 +20,8 @@ from .rfe import RFE from .rfe import RFECV +from .from_model import SelectFromModel + __all__ = ['GenericUnivariateSelect', 'RFE', 'RFECV', @@ -32,4 +34,5 @@ 'chi2', 'f_classif', 'f_oneway', - 'f_regression'] + 'f_regression', + 'SelectFromModel'] diff --git a/sklearn/feature_selection/base.py b/sklearn/feature_selection/base.py index e3ff0ed3bbebf..e8a0733a28637 100644 --- a/sklearn/feature_selection/base.py +++ b/sklearn/feature_selection/base.py @@ -81,7 +81,7 @@ def transform(self, X): return np.empty(0).reshape((X.shape[0], 0)) if len(mask) != X.shape[1]: raise ValueError("X has a different shape than during fitting.") - return check_array(X, accept_sparse='csr')[:, safe_mask(X, mask)] + return X[:, safe_mask(X, mask)] def inverse_transform(self, X): """ diff --git a/sklearn/feature_selection/from_model.py b/sklearn/feature_selection/from_model.py index 9dc652d93e53b..81e35a8000adf 100644 --- a/sklearn/feature_selection/from_model.py +++ b/sklearn/feature_selection/from_model.py @@ -1,14 +1,83 @@ -# Authors: Gilles Louppe, Mathieu Blondel +# Authors: Gilles Louppe, Mathieu Blondel, Maheshakya Wijewardena # License: BSD 3 clause import numpy as np -from ..base import TransformerMixin +from .base import SelectorMixin +from ..base import (TransformerMixin, BaseEstimator, clone, + MetaEstimatorMixin) from ..externals import six -from ..utils import safe_mask, check_array + +from ..utils import safe_mask, check_array, deprecated from ..utils.validation import NotFittedError, check_is_fitted +def _get_feature_importances(estimator): + """Retrieve or aggregate feature importances from estimator""" + if hasattr(estimator, "feature_importances_"): + importances = estimator.feature_importances_ + + elif hasattr(estimator, "coef_"): + if estimator.coef_.ndim == 1: + importances = np.abs(estimator.coef_) + + else: + importances = np.sum(np.abs(estimator.coef_), axis=0) + + else: + raise ValueError( + "The underlying estimator %s has no `coef_` or " + "`feature_importances_` attribute. Either pass a fitted estimator" + " to SelectFromModel or call fit before calling transform." 
+ % estimator.__class__.__name__) + + return importances + + +def _calculate_threshold(estimator, importances, threshold): + """Interpret the threshold value""" + + if threshold is None: + # determine default from estimator + est_name = estimator.__class__.__name__ + if ((hasattr(estimator, "penalty") and estimator.penalty == "l1") or + "Lasso" in est_name): + # the natural default threshold is 0 when l1 penalty was used + threshold = 1e-5 + else: + threshold = "mean" + + if isinstance(threshold, six.string_types): + if "*" in threshold: + scale, reference = threshold.split("*") + scale = float(scale.strip()) + reference = reference.strip() + + if reference == "median": + reference = np.median(importances) + elif reference == "mean": + reference = np.mean(importances) + else: + raise ValueError("Unknown reference: " + reference) + + threshold = scale * reference + + elif threshold == "median": + threshold = np.median(importances) + + elif threshold == "mean": + threshold = np.mean(importances) + + else: + raise ValueError("Expected threshold='mean' or threshold='median' " + "got %s" % threshold) + + else: + threshold = float(threshold) + + return threshold + + class _LearntSelectorMixin(TransformerMixin): # Note because of the extra threshold parameter in transform, this does # not naturally extend from SelectorMixin @@ -18,6 +87,8 @@ class _LearntSelectorMixin(TransformerMixin): ``feature_importances_`` or ``coef_`` attribute to evaluate the relative importance of individual features for feature selection. """ + @deprecated('Support to use estimators as feature selectors will be ' + 'removed in version 0.19. Use SelectFromModel instead.') def transform(self, X, threshold=None): """Reduce X to its most important features. @@ -44,59 +115,18 @@ def transform(self, X, threshold=None): X_r : array of shape [n_samples, n_selected_features] The input samples with only the selected features. """ - check_is_fitted(self, ('coef_', 'feature_importances_'), + check_is_fitted(self, ('coef_', 'feature_importances_'), all_or_any=any) X = check_array(X, 'csc') - # Retrieve importance vector - if hasattr(self, "feature_importances_"): - importances = self.feature_importances_ - - elif hasattr(self, "coef_"): - if self.coef_ is None: - msg = "This model is not fitted yet. 
Please call fit() first"
-                raise NotFittedError(msg)
-
-            if self.coef_.ndim == 1:
-                importances = np.abs(self.coef_)
-            else:
-                importances = np.sum(np.abs(self.coef_), axis=0)
-
+        importances = _get_feature_importances(self)
         if len(importances) != X.shape[1]:
             raise ValueError("X has different number of features than"
                              " during model fitting.")
 
-        # Retrieve threshold
         if threshold is None:
-            if hasattr(self, "penalty") and self.penalty == "l1":
-                # the natural default threshold is 0 when l1 penalty was used
-                threshold = getattr(self, "threshold", 1e-5)
-            else:
-                threshold = getattr(self, "threshold", "mean")
-
-        if isinstance(threshold, six.string_types):
-            if "*" in threshold:
-                scale, reference = threshold.split("*")
-                scale = float(scale.strip())
-                reference = reference.strip()
-
-                if reference == "median":
-                    reference = np.median(importances)
-                elif reference == "mean":
-                    reference = np.mean(importances)
-                else:
-                    raise ValueError("Unknown reference: " + reference)
-
-                threshold = scale * reference
-
-            elif threshold == "median":
-                threshold = np.median(importances)
-
-            elif threshold == "mean":
-                threshold = np.mean(importances)
-
-            else:
-                threshold = float(threshold)
+            threshold = getattr(self, 'threshold', None)
+        threshold = _calculate_threshold(self, importances, threshold)
 
         # Selection
         try:
@@ -111,3 +141,116 @@
             return X[:, mask]
         else:
             raise ValueError("Invalid threshold: all features are discarded.")
+
+
+class SelectFromModel(BaseEstimator, SelectorMixin):
+    """Meta-transformer for selecting features based on importance weights.
+
+    Parameters
+    ----------
+    estimator : object
+        The base estimator from which the transformer is built.
+        This can be either a fitted estimator (if ``prefit`` is set to True)
+        or a non-fitted estimator.
+
+    threshold : string or float, optional (default=None)
+        The threshold value to use for feature selection. Features whose
+        importance is greater than or equal to the threshold are kept while
+        the others are discarded. If "median" (resp. "mean"), then the
+        ``threshold`` value is the median (resp. the mean) of the feature
+        importances. A scaling factor (e.g., "1.25*mean") may also be used.
+        If None and if the estimator has a penalty parameter set to l1,
+        either explicitly or implicitly (e.g., Lasso), the threshold used is
+        1e-5. Otherwise, "mean" is used by default.
+
+    prefit : bool, default False
+        Whether a prefit model is expected to be passed into the constructor
+        directly or not. If True, ``transform`` must be called directly
+        and SelectFromModel cannot be used with ``cross_val_score``,
+        ``GridSearchCV`` and similar utilities that clone the estimator.
+        Otherwise, train the model using ``fit`` and then ``transform`` to do
+        feature selection.
+
+    Attributes
+    ----------
+    `estimator_`: an estimator
+        The base estimator from which the transformer is built.
+        This is stored only when a non-fitted estimator is passed to
+        ``SelectFromModel``, i.e. when prefit is False.
+
+    `threshold_`: float
+        The threshold value used for feature selection.
+    """
+    def __init__(self, estimator, threshold=None, prefit=False):
+        self.estimator = estimator
+        self.threshold = threshold
+        self.prefit = prefit
+
+    def _get_support_mask(self):
+        # With a prefit estimator, ``transform`` can be called directly
+        # without a prior call to ``fit``.
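        # (Illustrative aside, not part of the patch: with ``prefit=True`` the
        # user-supplied, already-fitted estimator is read directly, e.g.
        #     SelectFromModel(fitted_clf, prefit=True).transform(X)
        # otherwise the clone fitted by ``fit``/``partial_fit`` and stored as
        # ``estimator_`` provides the importances, e.g.
        #     SelectFromModel(clf).fit(X, y).transform(X)
        # and an unfitted, non-prefit instance raises the ValueError below.)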
+ if self.prefit: + estimator = self.estimator + elif hasattr(self, 'estimator_'): + estimator = self.estimator_ + else: + raise ValueError( + 'Either fit the model before transform or set "prefit=True"' + ' while passing the fitted estimator to the constructor.') + scores = _get_feature_importances(estimator) + self.threshold_ = _calculate_threshold(estimator, scores, + self.threshold) + return scores >= self.threshold_ + + def fit(self, X, y=None, **fit_params): + """Fit the SelectFromModel meta-transformer. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features) + The training input samples. + + y : array-like, shape (n_samples,) + The target values (integers that correspond to classes in + classification, real numbers in regression). + + **fit_params : Other estimator specific parameters + + Returns + ------- + self : object + Returns self. + """ + if self.prefit: + raise NotFittedError( + "Since 'prefit=True', call transform directly") + if not hasattr(self, "estimator_"): + self.estimator_ = clone(self.estimator) + self.estimator_.fit(X, y, **fit_params) + return self + + def partial_fit(self, X, y=None, **fit_params): + """Fit the SelectFromModel meta-transformer only once. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features) + The training input samples. + + y : array-like, shape (n_samples,) + The target values (integers that correspond to classes in + classification, real numbers in regression). + + **fit_params : Other estimator specific parameters + + Returns + ------- + self : object + Returns self. + """ + if self.prefit: + raise NotFittedError( + "Since 'prefit=True', call transform directly") + if not hasattr(self, "estimator_"): + self.estimator_ = clone(self.estimator) + self.estimator_.partial_fit(X, y, **fit_params) + return self diff --git a/sklearn/feature_selection/tests/test_from_model.py b/sklearn/feature_selection/tests/test_from_model.py index 8367ac65c56d3..f28426d515196 100644 --- a/sklearn/feature_selection/tests/test_from_model.py +++ b/sklearn/feature_selection/tests/test_from_model.py @@ -5,14 +5,22 @@ from sklearn.utils.testing import assert_less from sklearn.utils.testing import assert_greater +from sklearn.utils.testing import assert_equal +from sklearn.utils.testing import assert_array_almost_equal +from sklearn.utils.testing import assert_array_equal +from sklearn.utils.testing import assert_almost_equal +from sklearn.utils.testing import assert_warns -from sklearn.datasets import load_iris -from sklearn.linear_model import LogisticRegression -from sklearn.linear_model import SGDClassifier +from sklearn import datasets +from sklearn.linear_model import LogisticRegression, SGDClassifier, Lasso from sklearn.svm import LinearSVC +from sklearn.feature_selection import SelectFromModel +from sklearn.ensemble import RandomForestClassifier +from sklearn.linear_model import PassiveAggressiveClassifier -iris = load_iris() - +iris = datasets.load_iris() +data, y = iris.data, iris.target +rng = np.random.RandomState(0) def test_transform_linear_model(): for clf in (LogisticRegression(C=0.1), @@ -21,23 +29,148 @@ def test_transform_linear_model(): random_state=0)): for thresh in (None, ".09*mean", "1e-5 * median"): for func in (np.array, sp.csr_matrix): - X = func(iris.data) + X = func(data) clf.set_params(penalty="l1") - clf.fit(X, iris.target) - X_new = clf.transform(X, thresh) + clf.fit(X, y) + X_new = assert_warns( + DeprecationWarning, clf.transform, X, thresh) if isinstance(clf, SGDClassifier): 
assert_true(X_new.shape[1] <= X.shape[1]) else: assert_less(X_new.shape[1], X.shape[1]) clf.set_params(penalty="l2") - clf.fit(X_new, iris.target) + clf.fit(X_new, y) pred = clf.predict(X_new) - assert_greater(np.mean(pred == iris.target), 0.7) + assert_greater(np.mean(pred == y), 0.7) def test_invalid_input(): clf = SGDClassifier(alpha=0.1, n_iter=10, shuffle=True, random_state=None) + for threshold in ["gobbledigook", ".5 * gobbledigook"]: + model = SelectFromModel(clf, threshold=threshold) + model.fit(data, y) + assert_raises(ValueError, model.transform, data) + + +def test_input_estimator_unchanged(): + """ + Test that SelectFromModel fits on a clone of the estimator. + """ + est = RandomForestClassifier() + transformer = SelectFromModel(estimator=est) + transformer.fit(data, y) + assert_true(transformer.estimator is est) + + +def test_feature_importances(): + X, y = datasets.make_classification( + n_samples=1000, n_features=10, n_informative=3, n_redundant=0, + n_repeated=0, shuffle=False, random_state=0) + + est = RandomForestClassifier(n_estimators=50, random_state=0) + for threshold, func in zip(["mean", "median"], [np.mean, np.median]): + transformer = SelectFromModel(estimator=est, threshold=threshold) + transformer.fit(X, y) + assert_true(hasattr(transformer.estimator_, 'feature_importances_')) + + X_new = transformer.transform(X) + assert_less(X_new.shape[1], X.shape[1]) + importances = transformer.estimator_.feature_importances_ + + feature_mask = np.abs(importances) > func(importances) + assert_array_almost_equal(X_new, X[:, feature_mask]) + + # Check with sample weights + sample_weight = np.ones(y.shape) + sample_weight[y == 1] *= 100 + + est = RandomForestClassifier(n_estimators=50, random_state=0) + transformer = SelectFromModel(estimator=est) + transformer.fit(X, y, sample_weight=sample_weight) + importances = transformer.estimator_.feature_importances_ + transformer.fit(X, y, sample_weight=3*sample_weight) + importances_bis = transformer.estimator_.feature_importances_ + assert_almost_equal(importances, importances_bis) + + # For the Lasso and related models, the threshold defaults to 1e-5 + transformer = SelectFromModel(estimator=Lasso(alpha=0.1)) + transformer.fit(X, y) + X_new = transformer.transform(X) + mask = np.abs(transformer.estimator_.coef_) > 1e-5 + assert_array_equal(X_new, X[:, mask]) + + +def test_partial_fit(): + est = PassiveAggressiveClassifier(random_state=0, shuffle=False) + transformer = SelectFromModel(estimator=est) + transformer.partial_fit(data, y, + classes=np.unique(y)) + old_model = transformer.estimator_ + transformer.partial_fit(data, y, + classes=np.unique(y)) + new_model = transformer.estimator_ + assert_true(old_model is new_model) + + X_transform = transformer.transform(data) + transformer.fit(np.vstack((data, data)), np.concatenate((y, y))) + assert_array_equal(X_transform, transformer.transform(data)) + + +def test_warm_start(): + est = PassiveAggressiveClassifier(warm_start=True, random_state=0) + transformer = SelectFromModel(estimator=est) + transformer.fit(data, y) + old_model = transformer.estimator_ + transformer.fit(data, y) + new_model = transformer.estimator_ + assert_true(old_model is new_model) + + +def test_prefit(): + """ + Test all possible combinations of the prefit parameter. + """ + # Passing a prefit parameter with the selected model + # and fitting a unfit model with prefit=False should give same results. 
+ clf = SGDClassifier(alpha=0.1, n_iter=10, shuffle=True, random_state=0) + model = SelectFromModel(clf) + model.fit(data, y) + X_transform = model.transform(data) + clf.fit(data, y) + model = SelectFromModel(clf, prefit=True) + assert_array_equal(model.transform(data), X_transform) + + # Check that the model is rewritten if prefit=False and a fitted model is + # passed + model = SelectFromModel(clf, prefit=False) + model.fit(data, y) + assert_array_equal(model.transform(data), X_transform) + + # Check that prefit=True and calling fit raises a ValueError + model = SelectFromModel(clf, prefit=True) + assert_raises(ValueError, model.fit, data, y) + + +def test_threshold_string(): + est = RandomForestClassifier(n_estimators=50, random_state=0) + model = SelectFromModel(est, threshold="0.5*mean") + model.fit(data, y) + X_transform = model.transform(data) + + # Calculate the threshold from the estimator directly. + est.fit(data, y) + threshold = 0.5 * np.mean(est.feature_importances_) + mask = est.feature_importances_ > threshold + assert_array_equal(X_transform, data[:, mask]) + + +def test_threshold_without_refitting(): + """Test that the threshold can be set without refitting the model.""" + clf = SGDClassifier(alpha=0.1, n_iter=10, shuffle=True, random_state=0) + model = SelectFromModel(clf, threshold=0.1) + model.fit(data, y) + X_transform = model.transform(data) - clf.fit(iris.data, iris.target) - assert_raises(ValueError, clf.transform, iris.data, "gobbledigook") - assert_raises(ValueError, clf.transform, iris.data, ".5 * gobbledigook") + # Set a higher threshold to filter out more features. + model.threshold = 1.0 + assert_greater(X_transform.shape[1], model.transform(data).shape[1]) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 9b553d4d3c51d..0573fc0108e17 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -26,7 +26,9 @@ from sklearn.utils.testing import assert_greater_equal from sklearn.utils.testing import assert_less from sklearn.utils.testing import assert_true +from sklearn.utils.testing import assert_warns from sklearn.utils.testing import raises + from sklearn.utils.validation import check_random_state from sklearn.utils.validation import NotFittedError from sklearn.utils.testing import ignore_warnings @@ -377,7 +379,8 @@ def test_importances(): assert_equal(importances.shape[0], 10, "Failed with {0}".format(name)) assert_equal(n_important, 3, "Failed with {0}".format(name)) - X_new = clf.transform(X, threshold="mean") + X_new = assert_warns( + DeprecationWarning, clf.transform, X, threshold="mean") assert_less(0, X_new.shape[1], "Failed with {0}".format(name)) assert_less(X_new.shape[1], X.shape[1], "Failed with {0}".format(name)) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 7732941213576..b9a0ea56d817d 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -61,6 +61,16 @@ 'RANSACRegressor', 'RadiusNeighborsRegressor', 'RandomForestRegressor', 'Ridge', 'RidgeCV'] +# Estimators with deprecated transform methods. Should be removed in 0.19 when +# _LearntSelectorMixin is removed. 
+DEPRECATED_TRANSFORM = [ + "RandomForestClassifier", "RandomForestRegressor", "ExtraTreesClassifier", + "ExtraTreesRegressor", "RandomTreesEmbedding", "DecisionTreeClassifier", + "DecisionTreeRegressor", "ExtraTreeClassifier", "ExtraTreeRegressor", + "LinearSVC", "SGDClassifier", "SGDRegressor", "Perceptron", + "LogisticRegression", "LogisticRegressionCV", + "GradientBoostingClassifier", "GradientBoostingRegressor"] + def _yield_non_meta_checks(name, Estimator): yield check_estimators_dtypes @@ -168,8 +178,9 @@ def _yield_all_checks(name, Estimator): for check in _yield_regressor_checks(name, Estimator): yield check if issubclass(Estimator, TransformerMixin): - for check in _yield_transformer_checks(name, Estimator): - yield check + if name not in DEPRECATED_TRANSFORM: + for check in _yield_transformer_checks(name, Estimator): + yield check if issubclass(Estimator, ClusterMixin): for check in _yield_clustering_checks(name, Estimator): yield check @@ -329,7 +340,8 @@ def check_dtype_object(name, Estimator): if hasattr(estimator, "predict"): estimator.predict(X) - if hasattr(estimator, "transform"): + if (hasattr(estimator, "transform") and + name not in DEPRECATED_TRANSFORM): estimator.transform(X) try: @@ -581,7 +593,12 @@ def check_pipeline_consistency(name, Estimator): pipeline = make_pipeline(estimator) estimator.fit(X, y) pipeline.fit(X, y) - funcs = ["score", "fit_transform"] + + if name in DEPRECATED_TRANSFORM: + funcs = ["score"] + else: + funcs = ["score", "fit_transform"] + for func_name in funcs: func = getattr(estimator, func_name, None) if func is not None: @@ -602,8 +619,12 @@ def check_fit_score_takes_y(name, Estimator): estimator = Estimator() set_fast_parameters(estimator) set_random_state(estimator) - funcs = ["fit", "score", "partial_fit", "fit_predict", "fit_transform"] + if name in DEPRECATED_TRANSFORM: + funcs = ["fit", "score", "partial_fit", "fit_predict"] + else: + funcs = [ + "fit", "score", "partial_fit", "fit_predict", "fit_transform"] for func_name in funcs: func = getattr(estimator, func_name, None) if func is not None: @@ -624,6 +645,13 @@ def check_estimators_dtypes(name, Estimator): X_train_int_32 = X_train_32.astype(np.int32) y = X_train_int_64[:, 0] y = multioutput_estimator_convert_y_2d(name, y) + + if name in DEPRECATED_TRANSFORM: + methods = ["predict", "decision_function", "predict_proba"] + else: + methods = [ + "predict", "transform", "decision_function", "predict_proba"] + for X_train in [X_train_32, X_train_64, X_train_int_64, X_train_int_32]: with warnings.catch_warnings(record=True): estimator = Estimator() @@ -631,8 +659,7 @@ def check_estimators_dtypes(name, Estimator): set_random_state(estimator, 1) estimator.fit(X_train, y) - for method in ["predict", "transform", "decision_function", - "predict_proba"]: + for method in methods: if hasattr(estimator, method): getattr(estimator, method)(X_train) @@ -709,7 +736,8 @@ def check_estimators_nan_inf(name, Estimator): raise AssertionError(error_string_predict, Estimator) # transform - if hasattr(estimator, "transform"): + if (hasattr(estimator, "transform") and + name not in DEPRECATED_TRANSFORM): try: estimator.transform(X_train) except ValueError as e: @@ -726,8 +754,11 @@ def check_estimators_nan_inf(name, Estimator): def check_estimators_pickle(name, Estimator): """Test that we can pickle all estimators""" - check_methods = ["predict", "transform", "decision_function", - "predict_proba"] + if name in DEPRECATED_TRANSFORM: + check_methods = ["predict", "decision_function", "predict_proba"] + 
else: + check_methods = ["predict", "transform", "decision_function", + "predict_proba"] X, y = make_blobs(n_samples=30, centers=[[0, 0, 0], [1, 1, 1]], random_state=0, n_features=2, cluster_std=0.1) @@ -1446,7 +1477,7 @@ def fit(self, X, y): if name in ('FeatureUnion', 'Pipeline'): e = estimator([('clf', T())]) - elif name in ('GridSearchCV' 'RandomizedSearchCV'): + elif name in ('GridSearchCV', 'RandomizedSearchCV', 'SelectFromModel'): return else: diff --git a/sklearn/utils/testing.py b/sklearn/utils/testing.py index 5d6fa2173bc2a..75f5b4c314c34 100644 --- a/sklearn/utils/testing.py +++ b/sklearn/utils/testing.py @@ -528,8 +528,8 @@ def uninstall_mldata_mock(): "OutputCodeClassifier", "OneVsRestClassifier", "RFE", "RFECV", "BaseEnsemble"] # estimators that there is no way to default-construct sensibly -OTHER = ["Pipeline", "FeatureUnion", "GridSearchCV", - "RandomizedSearchCV"] +OTHER = ["Pipeline", "FeatureUnion", "GridSearchCV", "RandomizedSearchCV", + "SelectFromModel"] # some trange ones DONT_TEST = ['SparseCoder', 'EllipticEnvelope', 'DictVectorizer', diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index b21bc6c1038d3..47c71759a859f 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -7,7 +7,7 @@ from sklearn.utils.testing import assert_raises_regex, assert_true from sklearn.utils.estimator_checks import check_estimator from sklearn.utils.estimator_checks import check_estimators_unfitted -from sklearn.linear_model import LogisticRegression +from sklearn.ensemble import AdaBoostClassifier from sklearn.utils.validation import check_X_y, check_array @@ -91,7 +91,7 @@ def test_check_estimator(): assert_true(msg in string_buffer.getvalue()) # doesn't error on actual estimator - check_estimator(LogisticRegression) + check_estimator(AdaBoostClassifier) def test_check_estimators_unfitted():
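As a closing illustration of the migration implied by the deprecation above (a minimal
sketch, not part of the patch; any estimator exposing ``coef_`` or
``feature_importances_`` works the same way)::

    from sklearn.datasets import load_iris
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.feature_selection import SelectFromModel

    iris = load_iris()
    X, y = iris.data, iris.target
    clf = RandomForestClassifier(n_estimators=50, random_state=0).fit(X, y)

    # Deprecated by this patch (emits a DeprecationWarning, removal planned for 0.19):
    #     X_reduced = clf.transform(X, threshold="mean")

    # Replacement: wrap the fitted estimator in SelectFromModel.
    X_reduced = SelectFromModel(clf, threshold="mean", prefit=True).transform(X)

    # Or let the meta-transformer fit a clone itself, so that it can be used
    # inside Pipeline, cross_val_score or GridSearchCV:
    X_reduced = SelectFromModel(
        RandomForestClassifier(n_estimators=50, random_state=0),
        threshold="mean").fit_transform(X, y)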