From c438f78996636864be57de2d62b0ddae2a8be43e Mon Sep 17 00:00:00 2001 From: maheshakya Date: Fri, 28 Mar 2014 02:58:15 +0530 Subject: [PATCH 01/11] Implemented SelectFromModel meta-transformer --- sklearn/feature_selection/__init__.py | 5 +- sklearn/feature_selection/from_model.py | 238 +++++++++++++++--- .../tests/test_from_model.py | 79 +++++- sklearn/utils/testing.py | 4 +- 4 files changed, 284 insertions(+), 42 deletions(-) diff --git a/sklearn/feature_selection/__init__.py b/sklearn/feature_selection/__init__.py index 0d222534f5e93..acb03f6f24a9e 100644 --- a/sklearn/feature_selection/__init__.py +++ b/sklearn/feature_selection/__init__.py @@ -20,6 +20,8 @@ from .rfe import RFE from .rfe import RFECV +from .from_model import SelectFromModel + __all__ = ['GenericUnivariateSelect', 'RFE', 'RFECV', @@ -32,4 +34,5 @@ 'chi2', 'f_classif', 'f_oneway', - 'f_regression'] + 'f_regression', + 'SelectFromModel'] diff --git a/sklearn/feature_selection/from_model.py b/sklearn/feature_selection/from_model.py index 9dc652d93e53b..89a1a293c1765 100644 --- a/sklearn/feature_selection/from_model.py +++ b/sklearn/feature_selection/from_model.py @@ -1,14 +1,77 @@ -# Authors: Gilles Louppe, Mathieu Blondel +# Authors: Gilles Louppe, Mathieu Blondel, Maheshakya Wijewardena # License: BSD 3 clause import numpy as np -from ..base import TransformerMixin +from ..base import (TransformerMixin, BaseEstimator, clone, SelectorMixin, + MetaEstimatorMixin) from ..externals import six -from ..utils import safe_mask, check_array + +from ..utils import safe_mask, check_array, deprecated from ..utils.validation import NotFittedError, check_is_fitted +def _get_feature_importances(estimator, X): + """Retrieve or aggregate feature importances from estimator""" + if hasattr(estimator, "feature_importances_"): + importances = estimator.feature_importances_ + + elif hasattr(estimator, "coef_"): + if estimator.coef_.ndim == 1: + importances = np.abs(estimator.coef_) + + else: + importances = np.sum(np.abs(estimator.coef_), axis=0) + + else: + raise ValueError("Missing `feature_importances_` or `coef_`" + " attribute, did you forget to set the " + "estimator's parameter to compute it?") + if len(importances) != X.shape[1]: + raise ValueError("X has different number of features than" + " during model fitting.") + + return importances + + +def _calculate_threshold(estimator, importances, threshold): + """Interpret the threshold value""" + + if threshold is None: + # determine default from estimator + if hasattr(estimator, "penalty") and estimator.penalty == "l1": + # the natural default threshold is 0 when l1 penalty was used + threshold = 1e-5 + else: + threshold = "mean" + + if isinstance(threshold, six.string_types): + if "*" in threshold: + scale, reference = threshold.split("*") + scale = float(scale.strip()) + reference = reference.strip() + + if reference == "median": + reference = np.median(importances) + elif reference == "mean": + reference = np.mean(importances) + else: + raise ValueError("Unknown reference: " + reference) + + threshold = scale * reference + + elif threshold == "median": + threshold = np.median(importances) + + elif threshold == "mean": + threshold = np.mean(importances) + + else: + threshold = float(threshold) + + return threshold + + class _LearntSelectorMixin(TransformerMixin): # Note because of the extra threshold parameter in transform, this does # not naturally extend from SelectorMixin @@ -18,6 +81,8 @@ class _LearntSelectorMixin(TransformerMixin): ``feature_importances_`` or ``coef_`` 
attribute to evaluate the relative importance of individual features for feature selection. """ + @deprecated('Support to use estimators as feature selectors will be ' + 'removed in version 0.17. Use SelectFromModel instead.') def transform(self, X, threshold=None): """Reduce X to its most important features. @@ -48,6 +113,28 @@ def transform(self, X, threshold=None): all_or_any=any) X = check_array(X, 'csc') + importances = _get_feature_importances(self, X) + + if threshold is None: + threshold = getattr(self, 'threshold', None) + threshold = _calculate_threshold(self, importances, threshold) + + # Selection + try: + mask = importances >= threshold + except TypeError: + # Fails in Python 3.x when threshold is str; + # result is array of True + raise ValueError("Invalid threshold: all features are discarded.") + + if np.any(mask): + mask = safe_mask(X, mask) + return X[:, mask] + else: + raise ValueError("Invalid threshold: all features are discarded.") + + @staticmethod + def _set_importances(estimator, X): # Retrieve importance vector if hasattr(self, "feature_importances_"): importances = self.feature_importances_ @@ -59,55 +146,132 @@ def transform(self, X, threshold=None): if self.coef_.ndim == 1: importances = np.abs(self.coef_) + + if hasattr(estimator, "feature_importances_"): + importances = estimator.feature_importances_ + if importances is None: + raise ValueError("Importance weights not computed. Please set" + " the compute_importances parameter before " + "fit.") + + elif hasattr(estimator, "coef_"): + if estimator.coef_.ndim == 1: + importances = np.abs(estimator.coef_) + else: - importances = np.sum(np.abs(self.coef_), axis=0) + importances = np.sum(np.abs(estimator.coef_), axis=0) if len(importances) != X.shape[1]: raise ValueError("X has different number of features than" " during model fitting.") + return importances + + @staticmethod + def _set_threshold(estimator, threshold): # Retrieve threshold if threshold is None: - if hasattr(self, "penalty") and self.penalty == "l1": + if hasattr(estimator, "penalty") and estimator.penalty == "l1": # the natural default threshold is 0 when l1 penalty was used - threshold = getattr(self, "threshold", 1e-5) + threshold = getattr(estimator, "threshold", 1e-5) else: - threshold = getattr(self, "threshold", "mean") + threshold = getattr(estimator, "threshold", "mean") + + return threshold + - if isinstance(threshold, six.string_types): - if "*" in threshold: - scale, reference = threshold.split("*") - scale = float(scale.strip()) - reference = reference.strip() +class SelectFromModel(BaseEstimator, SelectorMixin): + """Meta-transformer for selecting features based on importance + weights. - if reference == "median": - reference = np.median(importances) - elif reference == "mean": - reference = np.mean(importances) - else: - raise ValueError("Unknown reference: " + reference) + Parameters + ---------- + estimator : object or None(default=None) + The base estimator from which the transformer is built. + If None, then a value error is raised. - threshold = scale * reference + threshold : string, float or None, optional (default=None) + The threshold value to use for feature selection. Features whose + importance is greater or equal are kept while the others are + discarded. If "median" (resp. "mean"), then the threshold value is + the median (resp. the mean) of the feature importances. A scaling + factor (e.g., "1.25*mean") may also be used. If None and if + available, the object attribute ``threshold`` is used. 
Otherwise, + "mean" is used by default. - elif threshold == "median": - threshold = np.median(importances) + warm_start : bool, optional + When set to True, reuse the solution of the previous call to fit as + initialization, otherwise, just erase the previous solution. - elif threshold == "mean": - threshold = np.mean(importances) + Attributes + ---------- + `estimator_`: an estimator + The base estimator from which the transformer is built. - else: - threshold = float(threshold) + `scores_`: array, shape=(n_features,) + The importance of each feature according to the fit model. - # Selection - try: - mask = importances >= threshold - except TypeError: - # Fails in Python 3.x when threshold is str; - # result is array of True - raise ValueError("Invalid threshold: all features are discarded.") + `threshold_`: float + The threshold value used for feature selection. + """ - if np.any(mask): - mask = safe_mask(X, mask) - return X[:, mask] - else: - raise ValueError("Invalid threshold: all features are discarded.") + def __init__(self, estimator, threshold=None, warm_start=False): + self.estimator = estimator + self.threshold = threshold + self.warm_start = warm_start + + def _get_support_mask(self): + self.threshold_ = _calculate_threshold(self.estimator, self.scores_, + self.threshold) + return self.scores_ >= self.threshold_ + + def fit(self, X, y, **fit_params): + """Fit the SelectFromModel meta-transformer. + + Parameters + ---------- + X : array-like of shape = [n_samples, n_features] + The training input samples. + + y : array-like, shape = [n_samples] + The target values (integers that correspond to classes in + classification, real numbers in regression). + + **fit_params : Other estimator specific parameters + + Returns + ------- + self : object + Returns self. + """ + if not (self.warm_start and hasattr(self,"estimator_")): + self.estimator_ = clone(self.estimator) + + self.estimator_.fit(X, y, **fit_params) + self.scores_ = _get_feature_importances(self.estimator_, X) + return self + + def partial_fit(self, X, y, **fit_params): + """Fit the SelectFromModel meta-transformer only once. + + Parameters + ---------- + X : array-like of shape = [n_samples, n_features] + The training input samples. + + y : array-like, shape = [n_samples] + The target values (integers that correspond to classes in + classification, real numbers in regression). + + **fit_params : Other estimator specific parameters + + Returns + ------- + self : object + Returns self. 
+ """ + if not hasattr(self, "estimator_"): + self.estimator_ = clone(self.estimator) + self.estimator_.partial_fit(X, y, **fit_params) + self.scores_ = _get_feature_importances(self.estimator_, X) + return self diff --git a/sklearn/feature_selection/tests/test_from_model.py b/sklearn/feature_selection/tests/test_from_model.py index 8367ac65c56d3..1abbf15354426 100644 --- a/sklearn/feature_selection/tests/test_from_model.py +++ b/sklearn/feature_selection/tests/test_from_model.py @@ -5,13 +5,19 @@ from sklearn.utils.testing import assert_less from sklearn.utils.testing import assert_greater +from sklearn.utils.testing import assert_equal +from sklearn.utils.testing import assert_array_almost_equal +from sklearn.utils.testing import assert_almost_equal -from sklearn.datasets import load_iris +from sklearn import datasets from sklearn.linear_model import LogisticRegression from sklearn.linear_model import SGDClassifier from sklearn.svm import LinearSVC +from sklearn.feature_selection import SelectFromModel +from sklearn.ensemble import RandomForestClassifier +from sklearn.linear_model import PassiveAggressiveClassifier -iris = load_iris() +iris = datasets.load_iris() def test_transform_linear_model(): @@ -41,3 +47,72 @@ def test_invalid_input(): clf.fit(iris.data, iris.target) assert_raises(ValueError, clf.transform, iris.data, "gobbledigook") assert_raises(ValueError, clf.transform, iris.data, ".5 * gobbledigook") + + +def test_validate_estimator(): + est = RandomForestClassifier() + transformer = SelectFromModel(estimator=est) + transformer.fit(iris.data, iris.target) + assert_equal(transformer.estimator, est) + + +def test_feature_importances(): + X, y = datasets.make_classification(n_samples=1000, + n_features=10, + n_informative=3, + n_redundant=0, + n_repeated=0, + shuffle=False, + random_state=0) + + est = RandomForestClassifier(n_estimators=50, random_state=0) + transformer = SelectFromModel(estimator=est) + + transformer.fit(X, y) + assert_true(hasattr(transformer.estimator_, 'feature_importances_')) + + X_new = transformer.transform(X) + assert_less(X_new.shape[1], X.shape[1]) + + feature_mask = (transformer.estimator_.feature_importances_ > + transformer.estimator_.feature_importances_.mean()) + assert_array_almost_equal(X_new, X[:, feature_mask]) + + # Check with sample weights + sample_weight = np.ones(y.shape) + sample_weight[y == 1] *= 100 + + est = RandomForestClassifier(n_estimators=50, random_state=0) + transformer = SelectFromModel(estimator=est) + transformer.fit(X, y, sample_weight=sample_weight) + importances = transformer.estimator_.feature_importances_ + assert_less(importances[1], X.shape[1]) + + est = RandomForestClassifier(n_estimators=50, random_state=0) + transformer = SelectFromModel(estimator=est) + transformer.fit(X, y, sample_weight=3*sample_weight) + importances_bis = transformer.estimator_.feature_importances_ + assert_almost_equal(importances, importances_bis) + + +def test_partial_fit(): + est = PassiveAggressiveClassifier() + transformer = SelectFromModel(estimator=est) + transformer.partial_fit(iris.data, iris.target, + classes=np.unique(iris.target)) + id_1 = id(transformer.estimator_) + transformer.partial_fit(iris.data, iris.target, + classes=np.unique(iris.target)) + id_2 = id(transformer.estimator_) + assert_equal(id_1, id_2) + + +def test_warm_start(): + est = PassiveAggressiveClassifier() + transformer = SelectFromModel(estimator=est, + warm_start=True) + transformer.fit(iris.data, iris.target) + id_1 = id(transformer.estimator_) + 
transformer.fit(iris.data, iris.target) + id_2 = id(transformer.estimator_) + assert_equal(id_1, id_2) diff --git a/sklearn/utils/testing.py b/sklearn/utils/testing.py index 5d6fa2173bc2a..5c481419f1328 100644 --- a/sklearn/utils/testing.py +++ b/sklearn/utils/testing.py @@ -528,8 +528,8 @@ def uninstall_mldata_mock(): "OutputCodeClassifier", "OneVsRestClassifier", "RFE", "RFECV", "BaseEnsemble"] # estimators that there is no way to default-construct sensibly -OTHER = ["Pipeline", "FeatureUnion", "GridSearchCV", - "RandomizedSearchCV"] +OTHER = ["Pipeline", "FeatureUnion", "GridSearchCV", "RandomizedSearchCV", + "GridSearchCV", "SelectFromModel"] # some trange ones DONT_TEST = ['SparseCoder', 'EllipticEnvelope', 'DictVectorizer', From 8a28ea8169e55347173cedfd9266209317d7bcf5 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Mon, 27 Jul 2015 18:10:45 +0530 Subject: [PATCH 02/11] fix test failures --- sklearn/feature_selection/from_model.py | 3 ++- sklearn/utils/estimator_checks.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/sklearn/feature_selection/from_model.py b/sklearn/feature_selection/from_model.py index 89a1a293c1765..22c49c6ab363c 100644 --- a/sklearn/feature_selection/from_model.py +++ b/sklearn/feature_selection/from_model.py @@ -3,7 +3,8 @@ import numpy as np -from ..base import (TransformerMixin, BaseEstimator, clone, SelectorMixin, +from .base import SelectorMixin +from ..base import (TransformerMixin, BaseEstimator, clone, MetaEstimatorMixin) from ..externals import six diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 7732941213576..2832dc1974229 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -1446,7 +1446,7 @@ def fit(self, X, y): if name in ('FeatureUnion', 'Pipeline'): e = estimator([('clf', T())]) - elif name in ('GridSearchCV' 'RandomizedSearchCV'): + elif name in ('GridSearchCV', 'RandomizedSearchCV', 'SelectFromModel'): return else: From 459cb9ba3da4afef6373166dc83cd14acf1a0aff Mon Sep 17 00:00:00 2001 From: MechCoder Date: Tue, 8 Sep 2015 17:21:04 -0400 Subject: [PATCH 03/11] Remove warm start --- sklearn/feature_selection/from_model.py | 10 ++-------- sklearn/feature_selection/tests/test_from_model.py | 5 ++--- 2 files changed, 4 insertions(+), 11 deletions(-) diff --git a/sklearn/feature_selection/from_model.py b/sklearn/feature_selection/from_model.py index 22c49c6ab363c..dcae7c2be8602 100644 --- a/sklearn/feature_selection/from_model.py +++ b/sklearn/feature_selection/from_model.py @@ -200,10 +200,6 @@ class SelectFromModel(BaseEstimator, SelectorMixin): available, the object attribute ``threshold`` is used. Otherwise, "mean" is used by default. - warm_start : bool, optional - When set to True, reuse the solution of the previous call to fit as - initialization, otherwise, just erase the previous solution. - Attributes ---------- `estimator_`: an estimator @@ -216,10 +212,9 @@ class SelectFromModel(BaseEstimator, SelectorMixin): The threshold value used for feature selection. """ - def __init__(self, estimator, threshold=None, warm_start=False): + def __init__(self, estimator, threshold=None): self.estimator = estimator self.threshold = threshold - self.warm_start = warm_start def _get_support_mask(self): self.threshold_ = _calculate_threshold(self.estimator, self.scores_, @@ -245,9 +240,8 @@ def fit(self, X, y, **fit_params): self : object Returns self. 
""" - if not (self.warm_start and hasattr(self,"estimator_")): + if not hasattr(self, "estimator_"): self.estimator_ = clone(self.estimator) - self.estimator_.fit(X, y, **fit_params) self.scores_ = _get_feature_importances(self.estimator_, X) return self diff --git a/sklearn/feature_selection/tests/test_from_model.py b/sklearn/feature_selection/tests/test_from_model.py index 1abbf15354426..cb3b26105e8a6 100644 --- a/sklearn/feature_selection/tests/test_from_model.py +++ b/sklearn/feature_selection/tests/test_from_model.py @@ -108,9 +108,8 @@ def test_partial_fit(): def test_warm_start(): - est = PassiveAggressiveClassifier() - transformer = SelectFromModel(estimator=est, - warm_start=True) + est = PassiveAggressiveClassifier(warm_start=True) + transformer = SelectFromModel(estimator=est) transformer.fit(iris.data, iris.target) id_1 = id(transformer.estimator_) transformer.fit(iris.data, iris.target) From 2416e2aee694263eef0c69df1a8070701ac72356 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Wed, 9 Sep 2015 23:14:18 -0400 Subject: [PATCH 04/11] Catch filters instead of removing the tests --- sklearn/ensemble/tests/test_forest.py | 11 ++++++++--- sklearn/ensemble/tests/test_gradient_boosting.py | 16 +++++++++++----- sklearn/tree/tests/test_tree.py | 12 +++++++++--- 3 files changed, 28 insertions(+), 11 deletions(-) diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py index 1723b6131d5ea..449fc91a6a420 100644 --- a/sklearn/ensemble/tests/test_forest.py +++ b/sklearn/ensemble/tests/test_forest.py @@ -19,6 +19,7 @@ from scipy.sparse import csc_matrix from scipy.sparse import coo_matrix +from sklearn.utils import warnings from sklearn.utils.testing import assert_almost_equal from sklearn.utils.testing import assert_array_almost_equal from sklearn.utils.testing import assert_array_equal @@ -28,6 +29,7 @@ from sklearn.utils.testing import assert_greater_equal from sklearn.utils.testing import assert_raises from sklearn.utils.testing import assert_warns +from sklearn.utils.testing import clean_warning_registry from sklearn.utils.testing import ignore_warnings from sklearn import datasets @@ -194,15 +196,18 @@ def test_probability(): def check_importances(X, y, name, criterion): ForestEstimator = FOREST_ESTIMATORS[name] - est = ForestEstimator(n_estimators=20, criterion=criterion,random_state=0) + est = ForestEstimator(n_estimators=20, criterion=criterion, + random_state=0) est.fit(X, y) importances = est.feature_importances_ n_important = np.sum(importances > 0.1) assert_equal(importances.shape[0], 10) assert_equal(n_important, 3) - X_new = est.transform(X, threshold="mean") - assert_less(X_new.shape[1], X.shape[1]) + clean_warning_registry() + with warnings.catch_warnings(record=True) as record: + X_new = est.transform(X, threshold="mean") + assert_less(0 < X_new.shape[1], X.shape[1]) # Check with parallel importances = est.feature_importances_ diff --git a/sklearn/ensemble/tests/test_gradient_boosting.py b/sklearn/ensemble/tests/test_gradient_boosting.py index 024e7cf5e3975..c9ff843c27622 100644 --- a/sklearn/ensemble/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/tests/test_gradient_boosting.py @@ -16,7 +16,7 @@ from sklearn.ensemble import GradientBoostingRegressor from sklearn.ensemble.gradient_boosting import ZeroEstimator from sklearn.metrics import mean_squared_error -from sklearn.utils import check_random_state, tosequence +from sklearn.utils import check_random_state, tosequence, warnings from sklearn.utils.testing import assert_almost_equal 
from sklearn.utils.testing import assert_array_almost_equal from sklearn.utils.testing import assert_array_equal @@ -26,6 +26,8 @@ from sklearn.utils.testing import assert_raises from sklearn.utils.testing import assert_true from sklearn.utils.testing import assert_warns +from sklearn.utils.testing import clean_warning_registry +from sklearn.utils.testing import ignore_warnings from sklearn.utils.validation import DataConversionWarning from sklearn.utils.validation import NotFittedError @@ -295,12 +297,16 @@ def test_feature_importances(): presort=presort) clf.fit(X, y) assert_true(hasattr(clf, 'feature_importances_')) + clean_warning_registry() + with warnings.catch_warnings(record=True) as record: + X_new = clf.transform(X, threshold="mean") + assert_less(X_new.shape[1], X.shape[1]) - X_new = clf.transform(X, threshold="mean") - assert_less(X_new.shape[1], X.shape[1]) + X_new = clf.transform(X, threshold="mean") + assert_less(X_new.shape[1], X.shape[1]) - feature_mask = clf.feature_importances_ > clf.feature_importances_.mean() - assert_array_almost_equal(X_new, X[:, feature_mask]) + feature_mask = clf.feature_importances_ > clf.feature_importances_.mean() + assert_array_almost_equal(X_new, X[:, feature_mask]) def test_probability_log(): diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 9b553d4d3c51d..37787d89092c8 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -16,6 +16,7 @@ from sklearn.metrics import accuracy_score from sklearn.metrics import mean_squared_error +from sklearn.utils import warnings from sklearn.utils.testing import assert_array_equal from sklearn.utils.testing import assert_array_almost_equal from sklearn.utils.testing import assert_almost_equal @@ -26,7 +27,9 @@ from sklearn.utils.testing import assert_greater_equal from sklearn.utils.testing import assert_less from sklearn.utils.testing import assert_true +from sklearn.utils.testing import clean_warning_registry from sklearn.utils.testing import raises + from sklearn.utils.validation import check_random_state from sklearn.utils.validation import NotFittedError from sklearn.utils.testing import ignore_warnings @@ -377,9 +380,12 @@ def test_importances(): assert_equal(importances.shape[0], 10, "Failed with {0}".format(name)) assert_equal(n_important, 3, "Failed with {0}".format(name)) - X_new = clf.transform(X, threshold="mean") - assert_less(0, X_new.shape[1], "Failed with {0}".format(name)) - assert_less(X_new.shape[1], X.shape[1], "Failed with {0}".format(name)) + + clean_warning_registry() + with warnings.catch_warnings(record=True) as record: + X_new = clf.transform(X, threshold="mean") + assert_less(0, X_new.shape[1], "Failed with {0}".format(name)) + assert_less(X_new.shape[1], X.shape[1], "Failed with {0}".format(name)) # Check on iris that importances are the same for all builders clf = DecisionTreeClassifier(random_state=0) From 9cee0d95f89a4f63b9d7c79de7ed1503f163c316 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Fri, 11 Sep 2015 13:49:19 -0400 Subject: [PATCH 05/11] Added example to depict feature selction using SelectFromModel and Lasso --- .../feature_selection/select_from_model.py | 51 +++++++++++++++++++ sklearn/feature_selection/from_model.py | 21 ++++---- 2 files changed, 61 insertions(+), 11 deletions(-) create mode 100644 examples/feature_selection/select_from_model.py diff --git a/examples/feature_selection/select_from_model.py b/examples/feature_selection/select_from_model.py new file mode 100644 index 
0000000000000..84d24c09a6b71 --- /dev/null +++ b/examples/feature_selection/select_from_model.py @@ -0,0 +1,51 @@ +""" +=================================================== +Feature selection using SelectFromModel and LassoCV +=================================================== + +Use SelectFromModel meta-transformer along with Lasso to select the best +couple of features from the Boston dataset. +""" +# Author: Manoj Kumar +# License: BSD 3 clause + +print(__doc__) + +import matplotlib.pyplot as plt +import numpy as np + +from sklearn.datasets import load_boston +from sklearn.feature_selection import SelectFromModel +from sklearn.linear_model import LassoCV + +# Load the boston dataset. +boston = load_boston() +X, y = boston['data'], boston['target'] + +# We use the base estimator since the L1 norm promotes sparsity of features. +clf = LassoCV() + +# Set a minimum threshold of 0.25 +sfm = SelectFromModel(clf, threshold=0.25) +sfm.fit(X, y) +n_features = sfm.transform(X).shape[1] + +# Reset the threshold till the number of features equals two. +# Note that the attribute can be set directly instead of repeatedly +# fitting the metatransformer. +while n_features > 2: + sfm.threshold += 0.1 + X_transform = sfm.transform(X) + n_features = X_transform.shape[1] + +# Plot the selected two features from X. +plt.title( + "Features selected from Boston using SelectFromModel with " + "threshold %0.3f." % sfm.threshold) +feature1 = X_transform[:, 0] +feature2 = X_transform[:, 1] +plt.plot(feature1, feature2, 'r.') +plt.xlabel("Feature number 1") +plt.ylabel("Feature number 2") +plt.ylim([np.min(feature2), np.max(feature2)]) +plt.show() diff --git a/sklearn/feature_selection/from_model.py b/sklearn/feature_selection/from_model.py index dcae7c2be8602..a09179642e694 100644 --- a/sklearn/feature_selection/from_model.py +++ b/sklearn/feature_selection/from_model.py @@ -187,25 +187,24 @@ class SelectFromModel(BaseEstimator, SelectorMixin): Parameters ---------- - estimator : object or None(default=None) + estimator : object The base estimator from which the transformer is built. - If None, then a value error is raised. - threshold : string, float or None, optional (default=None) - The threshold value to use for feature selection. Features whose - importance is greater or equal are kept while the others are - discarded. If "median" (resp. "mean"), then the threshold value is - the median (resp. the mean) of the feature importances. A scaling - factor (e.g., "1.25*mean") may also be used. If None and if - available, the object attribute ``threshold`` is used. Otherwise, - "mean" is used by default. + threshold : string, float, optional + The threshold value to use for feature selection. Features whose + importance is greater or equal are kept while the others are + discarded. If "median" (resp. "mean"), then the threshold value is + the median (resp. the mean) of the feature importances. A scaling + factor (e.g., "1.25*mean") may also be used. If None and if + available, the object attribute ``threshold`` is used. Otherwise, + "mean" is used by default. Attributes ---------- `estimator_`: an estimator The base estimator from which the transformer is built. - `scores_`: array, shape=(n_features,) + `scores_`: array, shape(n_features,) The importance of each feature according to the fit model. 
`threshold_`: float From 3d41053d67c48a997374e76cdafca247b3a0f88b Mon Sep 17 00:00:00 2001 From: MechCoder Date: Fri, 11 Sep 2015 14:20:32 -0400 Subject: [PATCH 06/11] Minor doc changes and removed _set_threshold and _set_importances --- doc/whats_new.rst | 5 +++ sklearn/feature_selection/base.py | 2 +- sklearn/feature_selection/from_model.py | 56 +++---------------------- 3 files changed, 11 insertions(+), 52 deletions(-) diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 6f9f609174b39..180e0619fab5f 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -207,6 +207,11 @@ Enhancements the same. This allows gradient boosters to turn off presorting when building deep trees or using sparse data. By `Jacob Schreiber`_. + - Added :class:`feature_selection.SelectFromModel` meta-transformer which can + be used along with estimators that have `coef_` or `feature_importances_` + attribute to select important features of the input data. By + `Maheshakya Wijewardena`_, `Joel Nothman`_ and `Manoj Kumar`_. + Bug fixes ......... diff --git a/sklearn/feature_selection/base.py b/sklearn/feature_selection/base.py index e3ff0ed3bbebf..e8a0733a28637 100644 --- a/sklearn/feature_selection/base.py +++ b/sklearn/feature_selection/base.py @@ -81,7 +81,7 @@ def transform(self, X): return np.empty(0).reshape((X.shape[0], 0)) if len(mask) != X.shape[1]: raise ValueError("X has a different shape than during fitting.") - return check_array(X, accept_sparse='csr')[:, safe_mask(X, mask)] + return X[:, safe_mask(X, mask)] def inverse_transform(self, X): """ diff --git a/sklearn/feature_selection/from_model.py b/sklearn/feature_selection/from_model.py index a09179642e694..0e446b4c5eca6 100644 --- a/sklearn/feature_selection/from_model.py +++ b/sklearn/feature_selection/from_model.py @@ -83,7 +83,7 @@ class _LearntSelectorMixin(TransformerMixin): importance of individual features for feature selection. """ @deprecated('Support to use estimators as feature selectors will be ' - 'removed in version 0.17. Use SelectFromModel instead.') + 'removed in version 0.19. Use SelectFromModel instead.') def transform(self, X, threshold=None): """Reduce X to its most important features. @@ -134,52 +134,6 @@ def transform(self, X, threshold=None): else: raise ValueError("Invalid threshold: all features are discarded.") - @staticmethod - def _set_importances(estimator, X): - # Retrieve importance vector - if hasattr(self, "feature_importances_"): - importances = self.feature_importances_ - - elif hasattr(self, "coef_"): - if self.coef_ is None: - msg = "This model is not fitted yet. Please call fit() first" - raise NotFittedError(msg) - - if self.coef_.ndim == 1: - importances = np.abs(self.coef_) - - if hasattr(estimator, "feature_importances_"): - importances = estimator.feature_importances_ - if importances is None: - raise ValueError("Importance weights not computed. 
Please set" - " the compute_importances parameter before " - "fit.") - - elif hasattr(estimator, "coef_"): - if estimator.coef_.ndim == 1: - importances = np.abs(estimator.coef_) - - else: - importances = np.sum(np.abs(estimator.coef_), axis=0) - - if len(importances) != X.shape[1]: - raise ValueError("X has different number of features than" - " during model fitting.") - - return importances - - @staticmethod - def _set_threshold(estimator, threshold): - # Retrieve threshold - if threshold is None: - if hasattr(estimator, "penalty") and estimator.penalty == "l1": - # the natural default threshold is 0 when l1 penalty was used - threshold = getattr(estimator, "threshold", 1e-5) - else: - threshold = getattr(estimator, "threshold", "mean") - - return threshold - class SelectFromModel(BaseEstimator, SelectorMixin): """Meta-transformer for selecting features based on importance @@ -225,10 +179,10 @@ def fit(self, X, y, **fit_params): Parameters ---------- - X : array-like of shape = [n_samples, n_features] + X : array-like of shape (n_samples, n_features) The training input samples. - y : array-like, shape = [n_samples] + y : array-like, shape (n_samples,) The target values (integers that correspond to classes in classification, real numbers in regression). @@ -250,10 +204,10 @@ def partial_fit(self, X, y, **fit_params): Parameters ---------- - X : array-like of shape = [n_samples, n_features] + X : array-like of shape (n_samples, n_features) The training input samples. - y : array-like, shape = [n_samples] + y : array-like, shape (n_samples,) The target values (integers that correspond to classes in classification, real numbers in regression). From 10176d98977f25fe592432341f2aead5f57e3d7f Mon Sep 17 00:00:00 2001 From: MechCoder Date: Tue, 22 Sep 2015 15:48:47 -0400 Subject: [PATCH 07/11] Now a fitted estimator can be passed to SelectFromModel --- doc/whats_new.rst | 9 +++- .../feature_selection/select_from_model.py | 2 +- sklearn/feature_selection/from_model.py | 43 +++++++++++-------- .../tests/test_from_model.py | 13 ++++++ 4 files changed, 48 insertions(+), 19 deletions(-) diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 180e0619fab5f..fdbc12f887d44 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -261,7 +261,7 @@ API changes summary caused confusion in how the array elements should be interpreted as features or as samples. All data arrays are now expected to be explicitly shaped ``(n_samples, n_features)``. - By `Vighnesh Birodkar`_. + By `Vighnesh Birodkar`_ - :class:`lda.LDA` and :class:`qda.QDA` have been moved to :class:`discriminant_analysis.LinearDiscriminantAnalysis` and @@ -274,6 +274,13 @@ API changes summary fit method to the constructor in :class:`discriminant_analysis.QuadraticDiscriminantAnalysis`. + - Models inheriting from ``_LearntSelectorMixin`` will no longer support the + transform methods. (i.e, RandomForests, GradientBoosting, LogisticRegression, + DecisionTrees, SVMs and SGD related models). Wrap these models around the + metatransfomer :class:`feature_selection.SelectFromModel` to remove + features (according to `coefs_` or `feature_importances_`) + which are below a certain threshold value instead. + .. 
_changes_0_1_16: Version 0.16.1 diff --git a/examples/feature_selection/select_from_model.py b/examples/feature_selection/select_from_model.py index 84d24c09a6b71..972049b1b7aa7 100644 --- a/examples/feature_selection/select_from_model.py +++ b/examples/feature_selection/select_from_model.py @@ -22,7 +22,7 @@ boston = load_boston() X, y = boston['data'], boston['target'] -# We use the base estimator since the L1 norm promotes sparsity of features. +# We use the base estimator LassoCV since the L1 norm promotes sparsity of features. clf = LassoCV() # Set a minimum threshold of 0.25 diff --git a/sklearn/feature_selection/from_model.py b/sklearn/feature_selection/from_model.py index 0e446b4c5eca6..d03d3d302e905 100644 --- a/sklearn/feature_selection/from_model.py +++ b/sklearn/feature_selection/from_model.py @@ -5,14 +5,14 @@ from .base import SelectorMixin from ..base import (TransformerMixin, BaseEstimator, clone, - MetaEstimatorMixin) + MetaEstimatorMixin) from ..externals import six from ..utils import safe_mask, check_array, deprecated from ..utils.validation import NotFittedError, check_is_fitted -def _get_feature_importances(estimator, X): +def _get_feature_importances(estimator): """Retrieve or aggregate feature importances from estimator""" if hasattr(estimator, "feature_importances_"): importances = estimator.feature_importances_ @@ -25,12 +25,11 @@ def _get_feature_importances(estimator, X): importances = np.sum(np.abs(estimator.coef_), axis=0) else: - raise ValueError("Missing `feature_importances_` or `coef_`" - " attribute, did you forget to set the " - "estimator's parameter to compute it?") - if len(importances) != X.shape[1]: - raise ValueError("X has different number of features than" - " during model fitting.") + raise ValueError( + "The underlying estimator %s has no `coef_` or " + "`feature_importances_` attribute. Either pass a fitted estimator" + " to SelectFromModel or call fit before calling transform." + % estimator.__class__.__name__) return importances @@ -110,11 +109,14 @@ def transform(self, X, threshold=None): X_r : array of shape [n_samples, n_selected_features] The input samples with only the selected features. """ - check_is_fitted(self, ('coef_', 'feature_importances_'), + check_is_fitted(self, ('coef_', 'feature_importances_'), all_or_any=any) X = check_array(X, 'csc') - importances = _get_feature_importances(self, X) + importances = _get_feature_importances(self) + if len(importances) != X.shape[1]: + raise ValueError("X has different number of features than" + " during model fitting.") if threshold is None: threshold = getattr(self, 'threshold', None) @@ -143,6 +145,10 @@ class SelectFromModel(BaseEstimator, SelectorMixin): ---------- estimator : object The base estimator from which the transformer is built. + This can be both a fitted or a non-fitted estimator. + If it a fitted estimator, then transform can be called directly, + otherwise train the model using fit and then transform to do + feature selection. threshold : string, float, optional The threshold value to use for feature selection. Features whose @@ -157,9 +163,8 @@ class SelectFromModel(BaseEstimator, SelectorMixin): ---------- `estimator_`: an estimator The base estimator from which the transformer is built. - - `scores_`: array, shape(n_features,) - The importance of each feature according to the fit model. + This is stored only when a non-fitted estimator is passed to the + SelectFromModel. `threshold_`: float The threshold value used for feature selection. 
@@ -170,9 +175,15 @@ def __init__(self, estimator, threshold=None): self.threshold = threshold def _get_support_mask(self): - self.threshold_ = _calculate_threshold(self.estimator, self.scores_, + # SelectFromModel can directly call on transform. + if hasattr(self, "estimator_"): + estimator = self.estimator_ + else: + estimator = self.estimator + scores = _get_feature_importances(estimator) + self.threshold_ = _calculate_threshold(estimator, scores, self.threshold) - return self.scores_ >= self.threshold_ + return scores >= self.threshold_ def fit(self, X, y, **fit_params): """Fit the SelectFromModel meta-transformer. @@ -196,7 +207,6 @@ def fit(self, X, y, **fit_params): if not hasattr(self, "estimator_"): self.estimator_ = clone(self.estimator) self.estimator_.fit(X, y, **fit_params) - self.scores_ = _get_feature_importances(self.estimator_, X) return self def partial_fit(self, X, y, **fit_params): @@ -221,5 +231,4 @@ def partial_fit(self, X, y, **fit_params): if not hasattr(self, "estimator_"): self.estimator_ = clone(self.estimator) self.estimator_.partial_fit(X, y, **fit_params) - self.scores_ = _get_feature_importances(self.estimator_, X) return self diff --git a/sklearn/feature_selection/tests/test_from_model.py b/sklearn/feature_selection/tests/test_from_model.py index cb3b26105e8a6..de04e9ea877b3 100644 --- a/sklearn/feature_selection/tests/test_from_model.py +++ b/sklearn/feature_selection/tests/test_from_model.py @@ -7,6 +7,7 @@ from sklearn.utils.testing import assert_greater from sklearn.utils.testing import assert_equal from sklearn.utils.testing import assert_array_almost_equal +from sklearn.utils.testing import assert_array_equal from sklearn.utils.testing import assert_almost_equal from sklearn import datasets @@ -115,3 +116,15 @@ def test_warm_start(): transformer.fit(iris.data, iris.target) id_2 = id(transformer.estimator_) assert_equal(id_1, id_2) + + +def test_fitted_estimator(): + """Test that a fitted estimator can be passed to SelectFromModel.""" + clf = SGDClassifier(alpha=0.1, n_iter=10, shuffle=True, random_state=0) + model = SelectFromModel(clf) + model.fit(iris.data, iris.target) + X_transform = model.transform(iris.data) + + clf.fit(iris.data, iris.target) + model = SelectFromModel(clf) + assert_array_equal(model.transform(iris.data), X_transform) From 2ee718cc75b0874c623aeaa9db3d1161f7dcc518 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Tue, 22 Sep 2015 17:05:16 -0400 Subject: [PATCH 08/11] Add narrative docs and fix examples --- doc/modules/classes.rst | 1 + doc/modules/feature_selection.rst | 41 +++++++++++++++---- .../ensemble/plot_feature_transformation.py | 6 +-- .../ensemble/plot_random_forest_embedding.py | 5 ++- ...el.py => plot_select_from_model_boston.py} | 10 ++--- sklearn/feature_selection/from_model.py | 7 ++-- .../tests/test_from_model.py | 19 ++++++++- sklearn/utils/testing.py | 2 +- 8 files changed, 69 insertions(+), 22 deletions(-) rename examples/feature_selection/{select_from_model.py => plot_select_from_model_boston.py} (86%) diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst index a3f78fe167e8a..df7228596e424 100644 --- a/doc/modules/classes.rst +++ b/doc/modules/classes.rst @@ -463,6 +463,7 @@ From text feature_selection.SelectKBest feature_selection.SelectFpr feature_selection.SelectFdr + feature_selection.SelectFromModel feature_selection.SelectFwe feature_selection.RFE feature_selection.RFECV diff --git a/doc/modules/feature_selection.rst b/doc/modules/feature_selection.rst index edf2edf78ae96..e5df45ea4920f 100644 --- 
a/doc/modules/feature_selection.rst +++ b/doc/modules/feature_selection.rst @@ -131,6 +131,26 @@ number of features. elimination example with automatic tuning of the number of features selected with cross-validation. +.. _select_from_model: + +Feature selection using SelectFromModel +======================================= + +:class:`SelectFromModel` is a meta-transformer that can be used along with any +estimator that has a ``coef_`` or ``feature_importances_`` attribute after fitting. +It should be given a threshold parameter below which the features are considered +unimportant and removed. If one has no idea of the prior value of the threshold, +string inputs like "mean" or "median" or even values like "0.1*mean" can be given +which SelectFromModel parses internally. + +For examples on how it is to be used refer to the sections below. + +.. topic:: Examples + + * :ref:`example_feature_selection_select_from_model.py`: Selecting the two + most important features from the Boston dataset without knowing the + threshold beforehand. + .. _l1_feature_selection: @@ -145,19 +165,22 @@ Selecting non-zero coefficients :ref:`Linear models ` penalized with the L1 norm have sparse solutions: many of their estimated coefficients are zero. When the goal is to reduce the dimensionality of the data to use with another classifier, -they expose a ``transform`` method to select the non-zero coefficient. In -particular, sparse estimators useful for this purpose are the -:class:`linear_model.Lasso` for regression, and +they can be used along with :class:`feature_selection.SelectFromModel` +to select the non-zero coefficient. In particular, sparse estimators useful for +this purpose are the :class:`linear_model.Lasso` for regression, and of :class:`linear_model.LogisticRegression` and :class:`svm.LinearSVC` for classification:: >>> from sklearn.svm import LinearSVC >>> from sklearn.datasets import load_iris + >>> from sklearn.feature_selection import SelectFromModel >>> iris = load_iris() >>> X, y = iris.data, iris.target >>> X.shape (150, 4) - >>> X_new = LinearSVC(C=0.01, penalty="l1", dual=False).fit_transform(X, y) + >>> lsvc = LinearSVC(C=0.01, penalty="l1", dual=False).fit(X, y) + >>> model = SelectFromModel(lsvc) + >>> X_new = model.transform(X) >>> X_new.shape (150, 3) @@ -246,18 +269,22 @@ Tree-based feature selection Tree-based estimators (see the :mod:`sklearn.tree` module and forest of trees in the :mod:`sklearn.ensemble` module) can be used to compute feature importances, which in turn can be used to discard irrelevant -features:: +features (when coupled with the :class:`feature_selection.SelectFromModel` +meta-transformer):: >>> from sklearn.ensemble import ExtraTreesClassifier >>> from sklearn.datasets import load_iris + >>> from sklearn.feature_selection import SelectFromModel >>> iris = load_iris() >>> X, y = iris.data, iris.target >>> X.shape (150, 4) >>> clf = ExtraTreesClassifier() - >>> X_new = clf.fit(X, y).transform(X) + >>> clf = clf.fit(X, y) >>> clf.feature_importances_ # doctest: +SKIP array([ 0.04..., 0.05..., 0.4..., 0.4...]) + >>> model = SelectFromModel(clf) + >>> X_new = model.transform(X) >>> X_new.shape # doctest: +SKIP (150, 2) @@ -278,7 +305,7 @@ the actual learning. 
The recommended way to do this in scikit-learn is to use a :class:`sklearn.pipeline.Pipeline`:: clf = Pipeline([ - ('feature_selection', LinearSVC(penalty="l1")), + ('feature_selection', SelectFromModel(LinearSVC(penalty="l1"))), ('classification', RandomForestClassifier()) ]) clf.fit(X, y) diff --git a/examples/ensemble/plot_feature_transformation.py b/examples/ensemble/plot_feature_transformation.py index 6faf222cf9875..b8bdb6cf27ad8 100644 --- a/examples/ensemble/plot_feature_transformation.py +++ b/examples/ensemble/plot_feature_transformation.py @@ -34,6 +34,7 @@ from sklearn.linear_model import LogisticRegression from sklearn.ensemble import (RandomTreesEmbedding, RandomForestClassifier, GradientBoostingClassifier) +from sklearn.feature_selection import SelectFromModel from sklearn.preprocessing import OneHotEncoder from sklearn.cross_validation import train_test_split from sklearn.metrics import roc_curve @@ -53,12 +54,11 @@ rt = RandomTreesEmbedding(max_depth=3, n_estimators=n_estimator) rt_lm = LogisticRegression() rt.fit(X_train, y_train) -rt_lm.fit(rt.transform(X_train_lr), y_train_lr) +rt_lm.fit(SelectFromModel(rt).transform(X_train_lr), y_train_lr) -y_pred_rt = rt_lm.predict_proba(rt.transform(X_test))[:, 1] +y_pred_rt = rt_lm.predict_proba(SelectFromModel(rt).transform(X_test))[:, 1] fpr_rt_lm, tpr_rt_lm, _ = roc_curve(y_test, y_pred_rt) - # Supervised transformation based on random forests rf = RandomForestClassifier(max_depth=3, n_estimators=n_estimator) rf_enc = OneHotEncoder() diff --git a/examples/ensemble/plot_random_forest_embedding.py b/examples/ensemble/plot_random_forest_embedding.py index ba6329d72905b..9e38a18789245 100644 --- a/examples/ensemble/plot_random_forest_embedding.py +++ b/examples/ensemble/plot_random_forest_embedding.py @@ -30,6 +30,7 @@ from sklearn.datasets import make_circles from sklearn.ensemble import RandomTreesEmbedding, ExtraTreesClassifier from sklearn.decomposition import TruncatedSVD +from sklearn.feature_selection import SelectFromModel from sklearn.naive_bayes import BernoulliNB # make a synthetic dataset @@ -37,7 +38,9 @@ # use RandomTreesEmbedding to transform data hasher = RandomTreesEmbedding(n_estimators=10, random_state=0, max_depth=3) -X_transformed = hasher.fit_transform(X) +hasher.fit(X) +model = SelectFromModel(hasher) +X_transformed = model.transform(X) # Visualize result using PCA pca = TruncatedSVD(n_components=2) diff --git a/examples/feature_selection/select_from_model.py b/examples/feature_selection/plot_select_from_model_boston.py similarity index 86% rename from examples/feature_selection/select_from_model.py rename to examples/feature_selection/plot_select_from_model_boston.py index 972049b1b7aa7..17ef6d6bd0149 100644 --- a/examples/feature_selection/select_from_model.py +++ b/examples/feature_selection/plot_select_from_model_boston.py @@ -34,14 +34,14 @@ # Note that the attribute can be set directly instead of repeatedly # fitting the metatransformer. while n_features > 2: - sfm.threshold += 0.1 - X_transform = sfm.transform(X) - n_features = X_transform.shape[1] + sfm.threshold += 0.1 + X_transform = sfm.transform(X) + n_features = X_transform.shape[1] # Plot the selected two features from X. plt.title( - "Features selected from Boston using SelectFromModel with " - "threshold %0.3f." % sfm.threshold) + "Features selected from Boston using SelectFromModel with " + "threshold %0.3f." 
% sfm.threshold) feature1 = X_transform[:, 0] feature2 = X_transform[:, 1] plt.plot(feature1, feature2, 'r.') diff --git a/sklearn/feature_selection/from_model.py b/sklearn/feature_selection/from_model.py index d03d3d302e905..ac859b9292585 100644 --- a/sklearn/feature_selection/from_model.py +++ b/sklearn/feature_selection/from_model.py @@ -138,8 +138,7 @@ def transform(self, X, threshold=None): class SelectFromModel(BaseEstimator, SelectorMixin): - """Meta-transformer for selecting features based on importance - weights. + """Meta-transformer for selecting features based on importance weights. Parameters ---------- @@ -185,7 +184,7 @@ def _get_support_mask(self): self.threshold) return scores >= self.threshold_ - def fit(self, X, y, **fit_params): + def fit(self, X, y=None, **fit_params): """Fit the SelectFromModel meta-transformer. Parameters @@ -209,7 +208,7 @@ def fit(self, X, y, **fit_params): self.estimator_.fit(X, y, **fit_params) return self - def partial_fit(self, X, y, **fit_params): + def partial_fit(self, X, y=None, **fit_params): """Fit the SelectFromModel meta-transformer only once. Parameters diff --git a/sklearn/feature_selection/tests/test_from_model.py b/sklearn/feature_selection/tests/test_from_model.py index de04e9ea877b3..776b2af2c287a 100644 --- a/sklearn/feature_selection/tests/test_from_model.py +++ b/sklearn/feature_selection/tests/test_from_model.py @@ -3,12 +3,14 @@ from nose.tools import assert_raises, assert_true +from sklearn.utils import warnings from sklearn.utils.testing import assert_less from sklearn.utils.testing import assert_greater from sklearn.utils.testing import assert_equal from sklearn.utils.testing import assert_array_almost_equal from sklearn.utils.testing import assert_array_equal from sklearn.utils.testing import assert_almost_equal +from sklearn.utils.testing import clean_warning_registry from sklearn import datasets from sklearn.linear_model import LogisticRegression @@ -31,7 +33,9 @@ def test_transform_linear_model(): X = func(iris.data) clf.set_params(penalty="l1") clf.fit(X, iris.target) - X_new = clf.transform(X, thresh) + clean_warning_registry() + with warnings.catch_warnings(record=True) as record: + X_new = clf.transform(X, thresh) if isinstance(clf, SGDClassifier): assert_true(X_new.shape[1] <= X.shape[1]) else: @@ -128,3 +132,16 @@ def test_fitted_estimator(): clf.fit(iris.data, iris.target) model = SelectFromModel(clf) assert_array_equal(model.transform(iris.data), X_transform) + + +def test_threshold_string(): + est = RandomForestClassifier(n_estimators=50, random_state=0) + model = SelectFromModel(est, threshold="0.5*mean") + model.fit(iris.data, iris.target) + X_transform = model.transform(iris.data) + + # Calculate the threshold from the estimator directly. 
+ est.fit(iris.data, iris.target) + threshold = 0.5 * np.mean(est.feature_importances_) + model = SelectFromModel(est, threshold=threshold) + assert_array_equal(X_transform, model.transform(iris.data)) diff --git a/sklearn/utils/testing.py b/sklearn/utils/testing.py index 5c481419f1328..75f5b4c314c34 100644 --- a/sklearn/utils/testing.py +++ b/sklearn/utils/testing.py @@ -529,7 +529,7 @@ def uninstall_mldata_mock(): "RFECV", "BaseEnsemble"] # estimators that there is no way to default-construct sensibly OTHER = ["Pipeline", "FeatureUnion", "GridSearchCV", "RandomizedSearchCV", - "GridSearchCV", "SelectFromModel"] + "SelectFromModel"] # some trange ones DONT_TEST = ['SparseCoder', 'EllipticEnvelope', 'DictVectorizer', From acf5f160a8620d942dca205222924ca0ee84f92d Mon Sep 17 00:00:00 2001 From: MechCoder Date: Thu, 24 Sep 2015 13:06:09 -0400 Subject: [PATCH 09/11] Merge SelectFromModel and L1-selection examples Add test to check the threshold can be set without refitting --- doc/modules/feature_selection.rst | 23 ++++++++----------- sklearn/feature_selection/from_model.py | 9 ++++---- .../tests/test_from_model.py | 17 +++++++++++++- 3 files changed, 29 insertions(+), 20 deletions(-) diff --git a/doc/modules/feature_selection.rst b/doc/modules/feature_selection.rst index e5df45ea4920f..e6c684e251280 100644 --- a/doc/modules/feature_selection.rst +++ b/doc/modules/feature_selection.rst @@ -138,35 +138,29 @@ Feature selection using SelectFromModel :class:`SelectFromModel` is a meta-transformer that can be used along with any estimator that has a ``coef_`` or ``feature_importances_`` attribute after fitting. -It should be given a threshold parameter below which the features are considered +It should be given a ``threshold`` parameter below which the features are considered unimportant and removed. If one has no idea of the prior value of the threshold, -string inputs like "mean" or "median" or even values like "0.1*mean" can be given -which SelectFromModel parses internally. +string inputs like ``"mean"`` or ``"median"`` or even values like ``"0.1*mean"`` +can be given which :class:`SelectFromModel` parses internally. For examples on how it is to be used refer to the sections below. .. topic:: Examples - * :ref:`example_feature_selection_select_from_model.py`: Selecting the two + * :ref:`example_feature_selection_plot_select_from_model_boston.py`: Selecting the two most important features from the Boston dataset without knowing the threshold beforehand. - -.. _l1_feature_selection: - L1-based feature selection -========================== +-------------------------- .. currentmodule:: sklearn -Selecting non-zero coefficients ---------------------------------- - :ref:`Linear models ` penalized with the L1 norm have sparse solutions: many of their estimated coefficients are zero. When the goal is to reduce the dimensionality of the data to use with another classifier, they can be used along with :class:`feature_selection.SelectFromModel` -to select the non-zero coefficient. In particular, sparse estimators useful for +to select the non-zero coefficients. In particular, sparse estimators useful for this purpose are the :class:`linear_model.Lasso` for regression, and of :class:`linear_model.LogisticRegression` and :class:`svm.LinearSVC` for classification:: @@ -264,12 +258,12 @@ of features non zero. 
http://hal.inria.fr/hal-00354771/ Tree-based feature selection -============================ +---------------------------- Tree-based estimators (see the :mod:`sklearn.tree` module and forest of trees in the :mod:`sklearn.ensemble` module) can be used to compute feature importances, which in turn can be used to discard irrelevant -features (when coupled with the :class:`feature_selection.SelectFromModel` +features (when coupled with the :class:`sklearn.feature_selection.SelectFromModel` meta-transformer):: >>> from sklearn.ensemble import ExtraTreesClassifier @@ -311,6 +305,7 @@ to use a :class:`sklearn.pipeline.Pipeline`:: clf.fit(X, y) In this snippet we make use of a :class:`sklearn.svm.LinearSVC` +coupled with :class:`sklearn.feature_selection.SelectFromModel` to evaluate feature importances and select the most relevant features. Then, a :class:`sklearn.ensemble.RandomForestClassifier` is trained on the transformed output, i.e. using only relevant features. You can perform diff --git a/sklearn/feature_selection/from_model.py b/sklearn/feature_selection/from_model.py index ac859b9292585..8a532f2904679 100644 --- a/sklearn/feature_selection/from_model.py +++ b/sklearn/feature_selection/from_model.py @@ -145,14 +145,14 @@ class SelectFromModel(BaseEstimator, SelectorMixin): estimator : object The base estimator from which the transformer is built. This can be both a fitted or a non-fitted estimator. - If it a fitted estimator, then transform can be called directly, - otherwise train the model using fit and then transform to do + If it a fitted estimator, then ``transform`` can be called directly, + otherwise train the model using ``fit`` and then ``transform`` to do feature selection. threshold : string, float, optional The threshold value to use for feature selection. Features whose importance is greater or equal are kept while the others are - discarded. If "median" (resp. "mean"), then the threshold value is + discarded. If "median" (resp. "mean"), then the ``threshold`` value is the median (resp. the mean) of the feature importances. A scaling factor (e.g., "1.25*mean") may also be used. If None and if available, the object attribute ``threshold`` is used. Otherwise, @@ -163,12 +163,11 @@ class SelectFromModel(BaseEstimator, SelectorMixin): `estimator_`: an estimator The base estimator from which the transformer is built. This is stored only when a non-fitted estimator is passed to the - SelectFromModel. + ``SelectFromModel``. `threshold_`: float The threshold value used for feature selection. """ - def __init__(self, estimator, threshold=None): self.estimator = estimator self.threshold = threshold diff --git a/sklearn/feature_selection/tests/test_from_model.py b/sklearn/feature_selection/tests/test_from_model.py index 776b2af2c287a..c07270a36bc3a 100644 --- a/sklearn/feature_selection/tests/test_from_model.py +++ b/sklearn/feature_selection/tests/test_from_model.py @@ -123,7 +123,10 @@ def test_warm_start(): def test_fitted_estimator(): - """Test that a fitted estimator can be passed to SelectFromModel.""" + """Test that a fitted estimator can be passed to SelectFromModel. + + If this is done fit need not be used and transform can be used directly. 
+ """ clf = SGDClassifier(alpha=0.1, n_iter=10, shuffle=True, random_state=0) model = SelectFromModel(clf) model.fit(iris.data, iris.target) @@ -145,3 +148,15 @@ def test_threshold_string(): threshold = 0.5 * np.mean(est.feature_importances_) model = SelectFromModel(est, threshold=threshold) assert_array_equal(X_transform, model.transform(iris.data)) + + +def test_threshold_without_refitting(): + """Test that the threshold can be set without refitting the model.""" + clf = SGDClassifier(alpha=0.1, n_iter=10, shuffle=True, random_state=0) + model = SelectFromModel(clf, threshold=0.1) + model.fit(iris.data, iris.target) + X_transform = model.transform(iris.data) + + # Set a higher threshold to filter out more features. + model.threshold = 1.0 + assert_greater(X_transform.shape[1], model.transform(iris.data).shape[1]) From 5a0db1717a66ce39bae9b7c6f27e2e6e3f6c7647 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Thu, 8 Oct 2015 16:49:26 -0400 Subject: [PATCH 10/11] 1. Added parameter prefit to pass in a fitted estimator. 2. Use assert_warns instead of catch_warnings 3. Remove depracation warnings in common tests. --- doc/modules/feature_selection.rst | 4 +- .../ensemble/plot_feature_transformation.py | 5 ++- .../ensemble/plot_random_forest_embedding.py | 2 +- sklearn/ensemble/tests/test_forest.py | 10 ++--- .../ensemble/tests/test_gradient_boosting.py | 20 +++++----- sklearn/feature_selection/from_model.py | 39 +++++++++++++++---- .../tests/test_from_model.py | 20 +++++----- sklearn/tree/tests/test_tree.py | 13 +++---- sklearn/utils/estimator_checks.py | 18 +++++++-- 9 files changed, 80 insertions(+), 51 deletions(-) diff --git a/doc/modules/feature_selection.rst b/doc/modules/feature_selection.rst index e6c684e251280..6534c07c5db0b 100644 --- a/doc/modules/feature_selection.rst +++ b/doc/modules/feature_selection.rst @@ -173,7 +173,7 @@ for classification:: >>> X.shape (150, 4) >>> lsvc = LinearSVC(C=0.01, penalty="l1", dual=False).fit(X, y) - >>> model = SelectFromModel(lsvc) + >>> model = SelectFromModel(lsvc, prefit=True) >>> X_new = model.transform(X) >>> X_new.shape (150, 3) @@ -277,7 +277,7 @@ meta-transformer):: >>> clf = clf.fit(X, y) >>> clf.feature_importances_ # doctest: +SKIP array([ 0.04..., 0.05..., 0.4..., 0.4...]) - >>> model = SelectFromModel(clf) + >>> model = SelectFromModel(clf, prefit=True) >>> X_new = model.transform(X) >>> X_new.shape # doctest: +SKIP (150, 2) diff --git a/examples/ensemble/plot_feature_transformation.py b/examples/ensemble/plot_feature_transformation.py index b8bdb6cf27ad8..7b6941846ab54 100644 --- a/examples/ensemble/plot_feature_transformation.py +++ b/examples/ensemble/plot_feature_transformation.py @@ -54,9 +54,10 @@ rt = RandomTreesEmbedding(max_depth=3, n_estimators=n_estimator) rt_lm = LogisticRegression() rt.fit(X_train, y_train) -rt_lm.fit(SelectFromModel(rt).transform(X_train_lr), y_train_lr) +rt_lm.fit(SelectFromModel(rt, prefit=True).transform(X_train_lr), y_train_lr) -y_pred_rt = rt_lm.predict_proba(SelectFromModel(rt).transform(X_test))[:, 1] +y_pred_rt = rt_lm.predict_proba( + SelectFromModel(rt, prefit=True).transform(X_test))[:, 1] fpr_rt_lm, tpr_rt_lm, _ = roc_curve(y_test, y_pred_rt) # Supervised transformation based on random forests diff --git a/examples/ensemble/plot_random_forest_embedding.py b/examples/ensemble/plot_random_forest_embedding.py index 9e38a18789245..eef04ac3336c4 100644 --- a/examples/ensemble/plot_random_forest_embedding.py +++ b/examples/ensemble/plot_random_forest_embedding.py @@ -39,7 +39,7 @@ # use 
RandomTreesEmbedding to transform data
 hasher = RandomTreesEmbedding(n_estimators=10, random_state=0, max_depth=3)
 hasher.fit(X)
-model = SelectFromModel(hasher)
+model = SelectFromModel(hasher, prefit=True)
 X_transformed = model.transform(X)
 
 # Visualize result using PCA
diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py
index 449fc91a6a420..631c726f0381e 100644
--- a/sklearn/ensemble/tests/test_forest.py
+++ b/sklearn/ensemble/tests/test_forest.py
@@ -29,7 +29,6 @@
 from sklearn.utils.testing import assert_greater_equal
 from sklearn.utils.testing import assert_raises
 from sklearn.utils.testing import assert_warns
-from sklearn.utils.testing import clean_warning_registry
 from sklearn.utils.testing import ignore_warnings
 
 from sklearn import datasets
@@ -204,10 +203,12 @@ def check_importances(X, y, name, criterion):
     assert_equal(importances.shape[0], 10)
     assert_equal(n_important, 3)
 
-    clean_warning_registry()
-    with warnings.catch_warnings(record=True) as record:
-        X_new = est.transform(X, threshold="mean")
-        assert_less(0 < X_new.shape[1], X.shape[1])
+    # XXX: Remove this test in 0.19 after transform support to estimators
+    # is removed.
+    X_new = assert_warns(
+        DeprecationWarning, est.transform, X, threshold="mean")
+    assert_less(0, X_new.shape[1])
+    assert_less(X_new.shape[1], X.shape[1])
 
     # Check with parallel
     importances = est.feature_importances_
diff --git a/sklearn/ensemble/tests/test_gradient_boosting.py b/sklearn/ensemble/tests/test_gradient_boosting.py
index c9ff843c27622..4f2329be50358 100644
--- a/sklearn/ensemble/tests/test_gradient_boosting.py
+++ b/sklearn/ensemble/tests/test_gradient_boosting.py
@@ -16,7 +16,7 @@
 from sklearn.ensemble import GradientBoostingRegressor
 from sklearn.ensemble.gradient_boosting import ZeroEstimator
 from sklearn.metrics import mean_squared_error
-from sklearn.utils import check_random_state, tosequence, warnings
+from sklearn.utils import check_random_state, tosequence
 from sklearn.utils.testing import assert_almost_equal
 from sklearn.utils.testing import assert_array_almost_equal
 from sklearn.utils.testing import assert_array_equal
@@ -26,7 +26,6 @@
 from sklearn.utils.testing import assert_raises
 from sklearn.utils.testing import assert_true
 from sklearn.utils.testing import assert_warns
-from sklearn.utils.testing import clean_warning_registry
 from sklearn.utils.testing import ignore_warnings
 from sklearn.utils.validation import DataConversionWarning
 from sklearn.utils.validation import NotFittedError
@@ -297,16 +296,15 @@ def test_feature_importances():
                                          presort=presort)
         clf.fit(X, y)
         assert_true(hasattr(clf, 'feature_importances_'))
-        clean_warning_registry()
-        with warnings.catch_warnings(record=True) as record:
-            X_new = clf.transform(X, threshold="mean")
-            assert_less(X_new.shape[1], X.shape[1])
 
-        X_new = clf.transform(X, threshold="mean")
-        assert_less(X_new.shape[1], X.shape[1])
-
-        feature_mask = clf.feature_importances_ > clf.feature_importances_.mean()
-        assert_array_almost_equal(X_new, X[:, feature_mask])
+    # XXX: Remove this test in 0.19 after transform support to estimators
+    # is removed. 
+    X_new = assert_warns(
+        DeprecationWarning, clf.transform, X, threshold="mean")
+    assert_less(X_new.shape[1], X.shape[1])
+    feature_mask = (
+        clf.feature_importances_ > clf.feature_importances_.mean())
+    assert_array_almost_equal(X_new, X[:, feature_mask])
 
 
 def test_probability_log():
diff --git a/sklearn/feature_selection/from_model.py b/sklearn/feature_selection/from_model.py
index 8a532f2904679..7d18511e16bcd 100644
--- a/sklearn/feature_selection/from_model.py
+++ b/sklearn/feature_selection/from_model.py
@@ -66,6 +66,10 @@ def _calculate_threshold(estimator, importances, threshold):
         elif threshold == "mean":
             threshold = np.mean(importances)
 
+        else:
+            raise ValueError("Expected threshold='mean' or threshold='median' "
+                             "got %s" % threshold)
+
     else:
         threshold = float(threshold)
 
@@ -144,10 +148,8 @@ class SelectFromModel(BaseEstimator, SelectorMixin):
     ----------
     estimator : object
         The base estimator from which the transformer is built.
-        This can be both a fitted or a non-fitted estimator.
-        If it is a fitted estimator, then ``transform`` can be called directly,
-        otherwise train the model using ``fit`` and then ``transform`` to do
-        feature selection.
+        This can be either a fitted (if ``prefit`` is set to True)
+        or a non-fitted estimator.
 
     threshold : string, float, optional
         The threshold value to use for feature selection. Features whose
@@ -158,26 +160,39 @@ class SelectFromModel(BaseEstimator, SelectorMixin):
         available, the object attribute ``threshold`` is used. Otherwise,
         "mean" is used by default.
 
+    prefit : bool, default True
+        Whether a prefit model is expected to be passed into the constructor
+        directly or not. If True, ``transform`` must be called directly
+        and SelectFromModel cannot be used with ``cross_val_score``,
+        ``GridSearchCV`` and similar utilities that clone the estimator.
+        Otherwise train the model using ``fit`` and then ``transform`` to do
+        feature selection.
+
     Attributes
     ----------
     `estimator_`: an estimator
         The base estimator from which the transformer is built. This is
        stored only when a non-fitted estimator is passed to the
-        ``SelectFromModel``.
+        ``SelectFromModel``, i.e. when prefit is False.
 
     `threshold_`: float
         The threshold value used for feature selection.
     """
-    def __init__(self, estimator, threshold=None):
+    def __init__(self, estimator, threshold=None, prefit=False):
         self.estimator = estimator
         self.threshold = threshold
+        self.prefit = prefit
 
     def _get_support_mask(self):
         # SelectFromModel can directly call on transform.
-        if hasattr(self, "estimator_"):
+        if self.prefit:
+            estimator = self.estimator
+        elif hasattr(self, 'estimator_'):
             estimator = self.estimator_
         else:
-            estimator = self.estimator
+            raise ValueError(
+                'Either fit the model before transform or set "prefit=True"'
+                ' while passing the fitted estimator to the constructor.')
         scores = _get_feature_importances(estimator)
         self.threshold_ = _calculate_threshold(estimator, scores,
                                                self.threshold)
@@ -202,6 +217,10 @@ def fit(self, X, y=None, **fit_params):
         self : object
             Returns self.
         """
+        if self.prefit:
+            raise ValueError(
+                'Fitting will overwrite your already fitted model. Call '
+                'transform directly.')
         if not hasattr(self, "estimator_"):
             self.estimator_ = clone(self.estimator)
         self.estimator_.fit(X, y, **fit_params)
@@ -226,6 +245,10 @@ def partial_fit(self, X, y=None, **fit_params):
         self : object
             Returns self.
         """
+        if self.prefit:
+            raise ValueError(
+                'Fitting will overwrite your already fitted model. 
Call ' + 'transform directly.') if not hasattr(self, "estimator_"): self.estimator_ = clone(self.estimator) self.estimator_.partial_fit(X, y, **fit_params) diff --git a/sklearn/feature_selection/tests/test_from_model.py b/sklearn/feature_selection/tests/test_from_model.py index c07270a36bc3a..aaf4b2c89f06b 100644 --- a/sklearn/feature_selection/tests/test_from_model.py +++ b/sklearn/feature_selection/tests/test_from_model.py @@ -3,14 +3,13 @@ from nose.tools import assert_raises, assert_true -from sklearn.utils import warnings from sklearn.utils.testing import assert_less from sklearn.utils.testing import assert_greater from sklearn.utils.testing import assert_equal from sklearn.utils.testing import assert_array_almost_equal from sklearn.utils.testing import assert_array_equal from sklearn.utils.testing import assert_almost_equal -from sklearn.utils.testing import clean_warning_registry +from sklearn.utils.testing import assert_warns from sklearn import datasets from sklearn.linear_model import LogisticRegression @@ -33,9 +32,8 @@ def test_transform_linear_model(): X = func(iris.data) clf.set_params(penalty="l1") clf.fit(X, iris.target) - clean_warning_registry() - with warnings.catch_warnings(record=True) as record: - X_new = clf.transform(X, thresh) + X_new = assert_warns( + DeprecationWarning, clf.transform, X, thresh) if isinstance(clf, SGDClassifier): assert_true(X_new.shape[1] <= X.shape[1]) else: @@ -48,10 +46,10 @@ def test_transform_linear_model(): def test_invalid_input(): clf = SGDClassifier(alpha=0.1, n_iter=10, shuffle=True, random_state=None) - - clf.fit(iris.data, iris.target) - assert_raises(ValueError, clf.transform, iris.data, "gobbledigook") - assert_raises(ValueError, clf.transform, iris.data, ".5 * gobbledigook") + for threshold in ["gobbledigook", ".5 * gobbledigook"]: + model = SelectFromModel(clf, threshold=threshold) + model.fit(iris.data, iris.target) + assert_raises(ValueError, model.transform, iris.data) def test_validate_estimator(): @@ -133,7 +131,7 @@ def test_fitted_estimator(): X_transform = model.transform(iris.data) clf.fit(iris.data, iris.target) - model = SelectFromModel(clf) + model = SelectFromModel(clf, prefit=True) assert_array_equal(model.transform(iris.data), X_transform) @@ -146,7 +144,7 @@ def test_threshold_string(): # Calculate the threshold from the estimator directly. 
est.fit(iris.data, iris.target) threshold = 0.5 * np.mean(est.feature_importances_) - model = SelectFromModel(est, threshold=threshold) + model = SelectFromModel(est, threshold=threshold, prefit=True) assert_array_equal(X_transform, model.transform(iris.data)) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 37787d89092c8..0573fc0108e17 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -16,7 +16,6 @@ from sklearn.metrics import accuracy_score from sklearn.metrics import mean_squared_error -from sklearn.utils import warnings from sklearn.utils.testing import assert_array_equal from sklearn.utils.testing import assert_array_almost_equal from sklearn.utils.testing import assert_almost_equal @@ -27,7 +26,7 @@ from sklearn.utils.testing import assert_greater_equal from sklearn.utils.testing import assert_less from sklearn.utils.testing import assert_true -from sklearn.utils.testing import clean_warning_registry +from sklearn.utils.testing import assert_warns from sklearn.utils.testing import raises from sklearn.utils.validation import check_random_state @@ -380,12 +379,10 @@ def test_importances(): assert_equal(importances.shape[0], 10, "Failed with {0}".format(name)) assert_equal(n_important, 3, "Failed with {0}".format(name)) - - clean_warning_registry() - with warnings.catch_warnings(record=True) as record: - X_new = clf.transform(X, threshold="mean") - assert_less(0, X_new.shape[1], "Failed with {0}".format(name)) - assert_less(X_new.shape[1], X.shape[1], "Failed with {0}".format(name)) + X_new = assert_warns( + DeprecationWarning, clf.transform, X, threshold="mean") + assert_less(0, X_new.shape[1], "Failed with {0}".format(name)) + assert_less(X_new.shape[1], X.shape[1], "Failed with {0}".format(name)) # Check on iris that importances are the same for all builders clf = DecisionTreeClassifier(random_state=0) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 2832dc1974229..4af8506210fbd 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -61,6 +61,16 @@ 'RANSACRegressor', 'RadiusNeighborsRegressor', 'RandomForestRegressor', 'Ridge', 'RidgeCV'] +# Estimators with deprecated transform methods. Can be removed in 0.19 when +# _LearntSelectorMixin is removed. 
+DEPRECATED_TRANSFORM = [
+    "RandomForestClassifier", "RandomForestRegressor", "ExtraTreesClassifier",
+    "ExtraTreesRegressor", "RandomTreesEmbedding", "DecisionTreeClassifier",
+    "DecisionTreeRegressor", "ExtraTreeClassifier", "ExtraTreeRegressor",
+    "LinearSVC", "SGDClassifier", "SGDRegressor", "Perceptron",
+    "LogisticRegression", "LogisticRegressionCV",
+    "GradientBoostingClassifier", "GradientBoostingRegressor"]
+
 
 def _yield_non_meta_checks(name, Estimator):
     yield check_estimators_dtypes
@@ -168,8 +178,9 @@ def _yield_all_checks(name, Estimator):
         for check in _yield_regressor_checks(name, Estimator):
             yield check
     if issubclass(Estimator, TransformerMixin):
-        for check in _yield_transformer_checks(name, Estimator):
-            yield check
+        if name not in DEPRECATED_TRANSFORM:
+            for check in _yield_transformer_checks(name, Estimator):
+                yield check
     if issubclass(Estimator, ClusterMixin):
         for check in _yield_clustering_checks(name, Estimator):
             yield check
@@ -329,7 +340,8 @@ def check_dtype_object(name, Estimator):
     if hasattr(estimator, "predict"):
         estimator.predict(X)
 
-    if hasattr(estimator, "transform"):
+    if (hasattr(estimator, "transform") and
+            name not in DEPRECATED_TRANSFORM):
         estimator.transform(X)
 
     try:

From c805fbc79b76b5e96e748602699f209a06671f63 Mon Sep 17 00:00:00 2001
From: MechCoder
Date: Fri, 9 Oct 2015 19:35:03 -0400
Subject: [PATCH 11/11] Refactor tests

---
 doc/modules/feature_selection.rst | 10 +-
 doc/whats_new.rst | 2 +-
 sklearn/feature_selection/from_model.py | 25 ++--
 .../tests/test_from_model.py | 140 ++++++++++--------
 sklearn/utils/estimator_checks.py | 35 ++++-
 sklearn/utils/tests/test_estimator_checks.py | 4 +-
 6 files changed, 127 insertions(+), 89 deletions(-)

diff --git a/doc/modules/feature_selection.rst b/doc/modules/feature_selection.rst
index 6534c07c5db0b..88ff7d56d6865 100644
--- a/doc/modules/feature_selection.rst
+++ b/doc/modules/feature_selection.rst
@@ -138,10 +138,12 @@ Feature selection using SelectFromModel
 :class:`SelectFromModel` is a meta-transformer that can be used along with any
 estimator that has a ``coef_`` or ``feature_importances_`` attribute after fitting.
-It should be given a ``threshold`` parameter below which the features are considered
-unimportant and removed. If one has no idea of the prior value of the threshold,
-string inputs like ``"mean"`` or ``"median"`` or even values like ``"0.1*mean"``
-can be given which :class:`SelectFromModel` parses internally.
+The features are considered unimportant and removed if the corresponding
+``coef_`` or ``feature_importances_`` values are below the provided
+``threshold`` parameter. Apart from specifying the threshold numerically,
+there are built-in heuristics for finding a threshold using a string argument.
+Available heuristics are "mean", "median" and float multiples of these like
+"0.1*mean".
 
 For examples on how it is to be used refer to the sections below.
 
diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index fdbc12f887d44..d808cf6ad02f1 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -261,7 +261,7 @@ API changes summary
       caused confusion in how the array elements should be interpreted
       as features or as samples. All data arrays are now expected
       to be explicitly shaped ``(n_samples, n_features)``.
-     By `Vighnesh Birodkar`_
+     By `Vighnesh Birodkar`_. 

 - :class:`lda.LDA` and :class:`qda.QDA` have been moved to
   :class:`discriminant_analysis.LinearDiscriminantAnalysis` and
diff --git a/sklearn/feature_selection/from_model.py b/sklearn/feature_selection/from_model.py
index 7d18511e16bcd..81e35a8000adf 100644
--- a/sklearn/feature_selection/from_model.py
+++ b/sklearn/feature_selection/from_model.py
@@ -39,7 +39,9 @@ def _calculate_threshold(estimator, importances, threshold):
 
     if threshold is None:
         # determine default from estimator
-        if hasattr(estimator, "penalty") and estimator.penalty == "l1":
+        est_name = estimator.__class__.__name__
+        if ((hasattr(estimator, "penalty") and estimator.penalty == "l1") or
+                "Lasso" in est_name):
             # the natural default threshold is 0 when l1 penalty was used
             threshold = 1e-5
         else:
@@ -151,16 +153,17 @@ class SelectFromModel(BaseEstimator, SelectorMixin):
         This can be either a fitted (if ``prefit`` is set to True)
         or a non-fitted estimator.
 
-    threshold : string, float, optional
+    threshold : string, float, optional, default None
         The threshold value to use for feature selection. Features whose
         importance is greater or equal are kept while the others are
         discarded. If "median" (resp. "mean"), then the ``threshold`` value is
         the median (resp. the mean) of the feature importances. A scaling
-        factor (e.g., "1.25*mean") may also be used. If None and if
-        available, the object attribute ``threshold`` is used. Otherwise,
-        "mean" is used by default.
+        factor (e.g., "1.25*mean") may also be used. If None and if the
+        estimator has a parameter penalty set to l1, either explicitly
+        or implicitly (e.g., Lasso), the threshold used is 1e-5.
+        Otherwise, "mean" is used by default.
 
-    prefit : bool, default True
+    prefit : bool, default False
         Whether a prefit model is expected to be passed into the constructor
         directly or not. If True, ``transform`` must be called directly
@@ -218,9 +221,8 @@ def fit(self, X, y=None, **fit_params):
         Returns self.
         """
         if self.prefit:
-            raise ValueError(
-                'Fitting will overwrite your already fitted model. Call '
-                'transform directly.')
+            raise NotFittedError(
+                "Since 'prefit=True', call transform directly")
         if not hasattr(self, "estimator_"):
             self.estimator_ = clone(self.estimator)
         self.estimator_.fit(X, y, **fit_params)
@@ -246,9 +248,8 @@ def partial_fit(self, X, y=None, **fit_params):
         Returns self.
         """
         if self.prefit:
-            raise ValueError(
-                'Fitting will overwrite your already fitted model. 
Call ' - 'transform directly.') + raise NotFittedError( + "Since 'prefit=True', call transform directly") if not hasattr(self, "estimator_"): self.estimator_ = clone(self.estimator) self.estimator_.partial_fit(X, y, **fit_params) diff --git a/sklearn/feature_selection/tests/test_from_model.py b/sklearn/feature_selection/tests/test_from_model.py index aaf4b2c89f06b..f28426d515196 100644 --- a/sklearn/feature_selection/tests/test_from_model.py +++ b/sklearn/feature_selection/tests/test_from_model.py @@ -12,15 +12,15 @@ from sklearn.utils.testing import assert_warns from sklearn import datasets -from sklearn.linear_model import LogisticRegression -from sklearn.linear_model import SGDClassifier +from sklearn.linear_model import LogisticRegression, SGDClassifier, Lasso from sklearn.svm import LinearSVC from sklearn.feature_selection import SelectFromModel from sklearn.ensemble import RandomForestClassifier from sklearn.linear_model import PassiveAggressiveClassifier iris = datasets.load_iris() - +data, y = iris.data, iris.target +rng = np.random.RandomState(0) def test_transform_linear_model(): for clf in (LogisticRegression(C=0.1), @@ -29,9 +29,9 @@ def test_transform_linear_model(): random_state=0)): for thresh in (None, ".09*mean", "1e-5 * median"): for func in (np.array, sp.csr_matrix): - X = func(iris.data) + X = func(data) clf.set_params(penalty="l1") - clf.fit(X, iris.target) + clf.fit(X, y) X_new = assert_warns( DeprecationWarning, clf.transform, X, thresh) if isinstance(clf, SGDClassifier): @@ -39,47 +39,46 @@ def test_transform_linear_model(): else: assert_less(X_new.shape[1], X.shape[1]) clf.set_params(penalty="l2") - clf.fit(X_new, iris.target) + clf.fit(X_new, y) pred = clf.predict(X_new) - assert_greater(np.mean(pred == iris.target), 0.7) + assert_greater(np.mean(pred == y), 0.7) def test_invalid_input(): clf = SGDClassifier(alpha=0.1, n_iter=10, shuffle=True, random_state=None) for threshold in ["gobbledigook", ".5 * gobbledigook"]: model = SelectFromModel(clf, threshold=threshold) - model.fit(iris.data, iris.target) - assert_raises(ValueError, model.transform, iris.data) + model.fit(data, y) + assert_raises(ValueError, model.transform, data) -def test_validate_estimator(): +def test_input_estimator_unchanged(): + """ + Test that SelectFromModel fits on a clone of the estimator. 
+ """ est = RandomForestClassifier() transformer = SelectFromModel(estimator=est) - transformer.fit(iris.data, iris.target) - assert_equal(transformer.estimator, est) + transformer.fit(data, y) + assert_true(transformer.estimator is est) def test_feature_importances(): - X, y = datasets.make_classification(n_samples=1000, - n_features=10, - n_informative=3, - n_redundant=0, - n_repeated=0, - shuffle=False, - random_state=0) + X, y = datasets.make_classification( + n_samples=1000, n_features=10, n_informative=3, n_redundant=0, + n_repeated=0, shuffle=False, random_state=0) est = RandomForestClassifier(n_estimators=50, random_state=0) - transformer = SelectFromModel(estimator=est) + for threshold, func in zip(["mean", "median"], [np.mean, np.median]): + transformer = SelectFromModel(estimator=est, threshold=threshold) + transformer.fit(X, y) + assert_true(hasattr(transformer.estimator_, 'feature_importances_')) - transformer.fit(X, y) - assert_true(hasattr(transformer.estimator_, 'feature_importances_')) + X_new = transformer.transform(X) + assert_less(X_new.shape[1], X.shape[1]) + importances = transformer.estimator_.feature_importances_ - X_new = transformer.transform(X) - assert_less(X_new.shape[1], X.shape[1]) - - feature_mask = (transformer.estimator_.feature_importances_ > - transformer.estimator_.feature_importances_.mean()) - assert_array_almost_equal(X_new, X[:, feature_mask]) + feature_mask = np.abs(importances) > func(importances) + assert_array_almost_equal(X_new, X[:, feature_mask]) # Check with sample weights sample_weight = np.ones(y.shape) @@ -89,72 +88,89 @@ def test_feature_importances(): transformer = SelectFromModel(estimator=est) transformer.fit(X, y, sample_weight=sample_weight) importances = transformer.estimator_.feature_importances_ - assert_less(importances[1], X.shape[1]) - - est = RandomForestClassifier(n_estimators=50, random_state=0) - transformer = SelectFromModel(estimator=est) transformer.fit(X, y, sample_weight=3*sample_weight) importances_bis = transformer.estimator_.feature_importances_ assert_almost_equal(importances, importances_bis) + # For the Lasso and related models, the threshold defaults to 1e-5 + transformer = SelectFromModel(estimator=Lasso(alpha=0.1)) + transformer.fit(X, y) + X_new = transformer.transform(X) + mask = np.abs(transformer.estimator_.coef_) > 1e-5 + assert_array_equal(X_new, X[:, mask]) + def test_partial_fit(): - est = PassiveAggressiveClassifier() + est = PassiveAggressiveClassifier(random_state=0, shuffle=False) transformer = SelectFromModel(estimator=est) - transformer.partial_fit(iris.data, iris.target, - classes=np.unique(iris.target)) - id_1 = id(transformer.estimator_) - transformer.partial_fit(iris.data, iris.target, - classes=np.unique(iris.target)) - id_2 = id(transformer.estimator_) - assert_equal(id_1, id_2) + transformer.partial_fit(data, y, + classes=np.unique(y)) + old_model = transformer.estimator_ + transformer.partial_fit(data, y, + classes=np.unique(y)) + new_model = transformer.estimator_ + assert_true(old_model is new_model) + + X_transform = transformer.transform(data) + transformer.fit(np.vstack((data, data)), np.concatenate((y, y))) + assert_array_equal(X_transform, transformer.transform(data)) def test_warm_start(): - est = PassiveAggressiveClassifier(warm_start=True) + est = PassiveAggressiveClassifier(warm_start=True, random_state=0) transformer = SelectFromModel(estimator=est) - transformer.fit(iris.data, iris.target) - id_1 = id(transformer.estimator_) - transformer.fit(iris.data, iris.target) - id_2 = 
id(transformer.estimator_)
-    assert_equal(id_1, id_2)
+    transformer.fit(data, y)
+    old_model = transformer.estimator_
+    transformer.fit(data, y)
+    new_model = transformer.estimator_
+    assert_true(old_model is new_model)
 
 
-def test_fitted_estimator():
-    """Test that a fitted estimator can be passed to SelectFromModel.
-
-    If this is done, fit need not be called and transform can be used directly.
+def test_prefit():
+    """
+    Test all possible combinations of the prefit parameter.
     """
+    # Passing a fitted model with prefit=True and fitting an unfitted model
+    # with prefit=False should give the same results.
     clf = SGDClassifier(alpha=0.1, n_iter=10, shuffle=True, random_state=0)
     model = SelectFromModel(clf)
-    model.fit(iris.data, iris.target)
-    X_transform = model.transform(iris.data)
+    model.fit(data, y)
+    X_transform = model.transform(data)
+    clf.fit(data, y)
+    model = SelectFromModel(clf, prefit=True)
+    assert_array_equal(model.transform(data), X_transform)
+
+    # Check that the model is refit if prefit=False and a fitted model is
+    # passed.
+    model = SelectFromModel(clf, prefit=False)
+    model.fit(data, y)
+    assert_array_equal(model.transform(data), X_transform)
 
-    clf.fit(iris.data, iris.target)
+    # Check that calling fit with prefit=True raises a ValueError
     model = SelectFromModel(clf, prefit=True)
-    assert_array_equal(model.transform(iris.data), X_transform)
+    assert_raises(ValueError, model.fit, data, y)
 
 
 def test_threshold_string():
     est = RandomForestClassifier(n_estimators=50, random_state=0)
     model = SelectFromModel(est, threshold="0.5*mean")
-    model.fit(iris.data, iris.target)
-    X_transform = model.transform(iris.data)
+    model.fit(data, y)
+    X_transform = model.transform(data)
 
     # Calculate the threshold from the estimator directly.
-    est.fit(iris.data, iris.target)
+    est.fit(data, y)
     threshold = 0.5 * np.mean(est.feature_importances_)
-    model = SelectFromModel(est, threshold=threshold, prefit=True)
-    assert_array_equal(X_transform, model.transform(iris.data))
+    mask = est.feature_importances_ > threshold
+    assert_array_equal(X_transform, data[:, mask])
 
 
 def test_threshold_without_refitting():
     """Test that the threshold can be set without refitting the model."""
     clf = SGDClassifier(alpha=0.1, n_iter=10, shuffle=True, random_state=0)
     model = SelectFromModel(clf, threshold=0.1)
-    model.fit(iris.data, iris.target)
-    X_transform = model.transform(iris.data)
+    model.fit(data, y)
+    X_transform = model.transform(data)
 
     # Set a higher threshold to filter out more features.
     model.threshold = 1.0
-    assert_greater(X_transform.shape[1], model.transform(iris.data).shape[1])
+    assert_greater(X_transform.shape[1], model.transform(data).shape[1])
diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py
index 4af8506210fbd..b9a0ea56d817d 100644
--- a/sklearn/utils/estimator_checks.py
+++ b/sklearn/utils/estimator_checks.py
@@ -61,7 +61,7 @@
     'RANSACRegressor', 'RadiusNeighborsRegressor',
     'RandomForestRegressor', 'Ridge', 'RidgeCV']
 
-# Estimators with deprecated transform methods. Can be removed in 0.19 when
+# Estimators with deprecated transform methods. Should be removed in 0.19 when
 # _LearntSelectorMixin is removed. 
DEPRECATED_TRANSFORM = [ "RandomForestClassifier", "RandomForestRegressor", "ExtraTreesClassifier", @@ -593,7 +593,12 @@ def check_pipeline_consistency(name, Estimator): pipeline = make_pipeline(estimator) estimator.fit(X, y) pipeline.fit(X, y) - funcs = ["score", "fit_transform"] + + if name in DEPRECATED_TRANSFORM: + funcs = ["score"] + else: + funcs = ["score", "fit_transform"] + for func_name in funcs: func = getattr(estimator, func_name, None) if func is not None: @@ -614,8 +619,12 @@ def check_fit_score_takes_y(name, Estimator): estimator = Estimator() set_fast_parameters(estimator) set_random_state(estimator) - funcs = ["fit", "score", "partial_fit", "fit_predict", "fit_transform"] + if name in DEPRECATED_TRANSFORM: + funcs = ["fit", "score", "partial_fit", "fit_predict"] + else: + funcs = [ + "fit", "score", "partial_fit", "fit_predict", "fit_transform"] for func_name in funcs: func = getattr(estimator, func_name, None) if func is not None: @@ -636,6 +645,13 @@ def check_estimators_dtypes(name, Estimator): X_train_int_32 = X_train_32.astype(np.int32) y = X_train_int_64[:, 0] y = multioutput_estimator_convert_y_2d(name, y) + + if name in DEPRECATED_TRANSFORM: + methods = ["predict", "decision_function", "predict_proba"] + else: + methods = [ + "predict", "transform", "decision_function", "predict_proba"] + for X_train in [X_train_32, X_train_64, X_train_int_64, X_train_int_32]: with warnings.catch_warnings(record=True): estimator = Estimator() @@ -643,8 +659,7 @@ def check_estimators_dtypes(name, Estimator): set_random_state(estimator, 1) estimator.fit(X_train, y) - for method in ["predict", "transform", "decision_function", - "predict_proba"]: + for method in methods: if hasattr(estimator, method): getattr(estimator, method)(X_train) @@ -721,7 +736,8 @@ def check_estimators_nan_inf(name, Estimator): raise AssertionError(error_string_predict, Estimator) # transform - if hasattr(estimator, "transform"): + if (hasattr(estimator, "transform") and + name not in DEPRECATED_TRANSFORM): try: estimator.transform(X_train) except ValueError as e: @@ -738,8 +754,11 @@ def check_estimators_nan_inf(name, Estimator): def check_estimators_pickle(name, Estimator): """Test that we can pickle all estimators""" - check_methods = ["predict", "transform", "decision_function", - "predict_proba"] + if name in DEPRECATED_TRANSFORM: + check_methods = ["predict", "decision_function", "predict_proba"] + else: + check_methods = ["predict", "transform", "decision_function", + "predict_proba"] X, y = make_blobs(n_samples=30, centers=[[0, 0, 0], [1, 1, 1]], random_state=0, n_features=2, cluster_std=0.1) diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index b21bc6c1038d3..47c71759a859f 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -7,7 +7,7 @@ from sklearn.utils.testing import assert_raises_regex, assert_true from sklearn.utils.estimator_checks import check_estimator from sklearn.utils.estimator_checks import check_estimators_unfitted -from sklearn.linear_model import LogisticRegression +from sklearn.ensemble import AdaBoostClassifier from sklearn.utils.validation import check_X_y, check_array @@ -91,7 +91,7 @@ def test_check_estimator(): assert_true(msg in string_buffer.getvalue()) # doesn't error on actual estimator - check_estimator(LogisticRegression) + check_estimator(AdaBoostClassifier) def test_check_estimators_unfitted():
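
-- 
A usage illustration, not part of the patch series itself: the sketch below
exercises the ``SelectFromModel`` API exactly as the series leaves it
(``SelectFromModel(estimator, threshold=None, prefit=False)``, string
thresholds such as "0.5*mean", and a ``threshold`` attribute that can be
changed without refitting). It assumes a scikit-learn checkout with these
patches applied; the iris data and the random forest estimator are arbitrary
choices, not taken from the patches.

    from sklearn import datasets
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.feature_selection import SelectFromModel

    iris = datasets.load_iris()
    X, y = iris.data, iris.target

    # Default workflow (prefit=False): fit() clones the wrapped estimator
    # and trains the clone; transform() then keeps only the features whose
    # importance is at least 0.5 * mean(feature_importances_).
    est = RandomForestClassifier(n_estimators=50, random_state=0)
    model = SelectFromModel(est, threshold="0.5*mean")
    model.fit(X, y)
    X_reduced = model.transform(X)

    # Prefit workflow: wrap an estimator that is already fitted and call
    # transform() directly. Calling fit() in this mode raises NotFittedError.
    est.fit(X, y)
    prefit_model = SelectFromModel(est, prefit=True)
    X_reduced_prefit = prefit_model.transform(X)

    # The threshold can be raised later without refitting the model; the
    # support mask is recomputed on the next transform() call, so fewer
    # features survive.
    model.threshold = "1.5*mean"
    X_fewer = model.transform(X)

Because prefit=True stores a reference to the fitted estimator rather than a
clone, such a transformer cannot be combined with utilities that clone their
inputs, e.g. ``cross_val_score`` or ``GridSearchCV``, as the docstring above
notes.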