Skip to content

[MRG+2?] Metaestimator delegation #3982

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Dec 27, 2014
18 changes: 13 additions & 5 deletions sklearn/feature_selection/rfe.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@

import numpy as np
from ..utils import check_X_y, safe_sqr
from ..utils.metaestimators import if_delegate_has_method
from ..base import BaseEstimator
from ..base import MetaEstimatorMixin
from ..base import clone
from ..base import is_classifier
from ..cross_validation import _check_cv as check_cv
from ..cross_validation import _safe_split, _score
from .base import SelectorMixin
from ..metrics.scorer import check_scoring
from .base import SelectorMixin


class RFE(BaseEstimator, MetaEstimatorMixin, SelectorMixin):
Expand Down Expand Up @@ -126,7 +127,7 @@ def fit(self, X, y):
n_features_to_select = self.n_features_to_select

if 0.0 < self.step < 1.0:
step = int(max(1, self.step * n_features))
step = int(max(1, self.step * n_features))
else:
step = int(self.step)
if step <= 0:
Expand Down Expand Up @@ -170,6 +171,7 @@ def fit(self, X, y):

return self

@if_delegate_has_method(delegate='estimator')
def predict(self, X):
"""Reduce X to the selected features and then predict using the
underlying estimator.
Expand All @@ -186,6 +188,7 @@ def predict(self, X):
"""
return self.estimator_.predict(self.transform(X))

@if_delegate_has_method(delegate='estimator')
def score(self, X, y):
"""Reduce X to the selected features and then return the score of the
underlying estimator.
Expand All @@ -203,15 +206,21 @@ def score(self, X, y):
def _get_support_mask(self):
    # Boolean mask over the original features: True for the features
    # retained by the recursive elimination (set during ``fit``).
    mask = self.support_
    return mask

@if_delegate_has_method(delegate='estimator')
def decision_function(self, X):
    """Reduce X to the selected features, then delegate decision_function
    to the fitted underlying estimator."""
    reduced = self.transform(X)
    return self.estimator_.decision_function(reduced)

@if_delegate_has_method(delegate='estimator')
def predict_proba(self, X):
    """Reduce X to the selected features, then delegate predict_proba
    to the fitted underlying estimator."""
    reduced = self.transform(X)
    return self.estimator_.predict_proba(reduced)

@if_delegate_has_method(delegate='estimator')
def predict_log_proba(self, X):
    """Reduce X to the selected features, then delegate predict_log_proba
    to the fitted underlying estimator."""
    reduced = self.transform(X)
    return self.estimator_.predict_log_proba(reduced)


class RFECV(RFE, MetaEstimatorMixin):
"""Feature ranking with recursive feature elimination and cross-validated
"""Feature ranking with recursive feature elimination and cross-validated
selection of the best number of features.

Parameters
Expand Down Expand Up @@ -254,7 +263,7 @@ class RFECV(RFE, MetaEstimatorMixin):
----------
n_features_ : int
The number of selected features with cross-validation.

support_ : array of shape [n_features]
The mask of selected features.

Expand Down Expand Up @@ -385,4 +394,3 @@ def fit(self, X, y):
# here, the scores are normalized by len(cv)
self.grid_scores_ = scores / len(cv)
return self

105 changes: 93 additions & 12 deletions sklearn/grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from .externals import six
from .utils import check_random_state
from .utils.validation import _num_samples, indexable
from .utils.metaestimators import if_delegate_has_method
from .metrics.scorer import check_scoring


Expand Down Expand Up @@ -342,21 +343,101 @@ def score(self, X, y=None):
ChangedBehaviorWarning)
return self.scorer_(self.best_estimator_, X, y)

@property
def predict(self):
return self.best_estimator_.predict
@if_delegate_has_method(delegate='estimator')
def predict(self, X):
"""Call predict on the estimator with the best found parameters.

@property
def predict_proba(self):
return self.best_estimator_.predict_proba
Only available if ``refit=True`` and the underlying estimator supports
``predict``.

@property
def decision_function(self):
return self.best_estimator_.decision_function
Parameters
-----------
X : indexable, length n_samples
Must fulfill the input assumptions of the
underlying estimator.

"""
return self.best_estimator_.predict(X)

@if_delegate_has_method(delegate='estimator')
def predict_proba(self, X):
    """Call predict_proba on the estimator with the best found parameters.

    Only available if ``refit=True`` and the underlying estimator supports
    ``predict_proba``.

    Parameters
    ----------
    X : indexable, length n_samples
        Must fulfill the input assumptions of the
        underlying estimator.

    """
    # Delegate to the refitted best estimator.
    best = self.best_estimator_
    return best.predict_proba(X)

@if_delegate_has_method(delegate='estimator')
def predict_log_proba(self, X):
    """Call predict_log_proba on the estimator with the best found parameters.

    Only available if ``refit=True`` and the underlying estimator supports
    ``predict_log_proba``.

    Parameters
    ----------
    X : indexable, length n_samples
        Must fulfill the input assumptions of the
        underlying estimator.

    """
    # Delegate to the refitted best estimator.
    best = self.best_estimator_
    return best.predict_log_proba(X)

@if_delegate_has_method(delegate='estimator')
def decision_function(self, X):
"""Call decision_function on the estimator with the best found parameters.

@property
def transform(self):
return self.best_estimator_.transform
Only available if ``refit=True`` and the underlying estimator supports
``decision_function``.

Parameters
-----------
X : indexable, length n_samples
Must fulfill the input assumptions of the
underlying estimator.

"""
return self.best_estimator_.decision_function(X)

@if_delegate_has_method(delegate='estimator')
def transform(self, X):
    """Call transform on the estimator with the best found parameters.

    Only available if the underlying estimator supports ``transform`` and
    ``refit=True``.

    Parameters
    ----------
    X : indexable, length n_samples
        Must fulfill the input assumptions of the
        underlying estimator.

    """
    # Delegate to the refitted best estimator.
    best = self.best_estimator_
    return best.transform(X)

@if_delegate_has_method(delegate='estimator')
def inverse_transform(self, Xt):
    """Call inverse_transform on the estimator with the best found parameters.

    Only available if the underlying estimator implements
    ``inverse_transform`` and ``refit=True``.

    Parameters
    ----------
    Xt : indexable, length n_samples
        Must fulfill the input assumptions of the
        underlying estimator.

    """
    # BUG FIX: this previously delegated to ``transform`` instead of
    # ``inverse_transform``, so the inverse mapping was never applied.
    return self.best_estimator_.inverse_transform(Xt)

def _fit(self, X, y, parameter_iterable):
"""Actual fitting, performing the search over parameters."""
Expand Down
15 changes: 12 additions & 3 deletions sklearn/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,12 @@
from .externals.joblib import Parallel, delayed
from .externals import six
from .utils import tosequence
from .utils.metaestimators import if_delegate_has_method
from .externals.six import iteritems

__all__ = ['Pipeline', 'FeatureUnion']


# One round of beers on me if someone finds out why the backslash
# is needed in the Attributes section so as not to upset sphinx.

class Pipeline(BaseEstimator):
"""Pipeline of transforms with a final estimator.

Expand Down Expand Up @@ -106,6 +104,10 @@ def get_params(self, deep=True):
out['%s__%s' % (name, key)] = value
return out

@property
def _final_estimator(self):
    # The estimator object of the last (name, estimator) pair in steps.
    name, estimator = self.steps[-1]
    return estimator

# Estimator interface

def _pre_transform(self, X, y=None, **fit_params):
Expand Down Expand Up @@ -140,6 +142,7 @@ def fit_transform(self, X, y=None, **fit_params):
else:
return self.steps[-1][-1].fit(Xt, y, **fit_params).transform(Xt)

@if_delegate_has_method(delegate='_final_estimator')
def predict(self, X):
"""Applies transforms to the data, and the predict method of the
final estimator. Valid only if the final estimator implements
Expand All @@ -149,6 +152,7 @@ def predict(self, X):
Xt = transform.transform(Xt)
return self.steps[-1][-1].predict(Xt)

@if_delegate_has_method(delegate='_final_estimator')
def predict_proba(self, X):
"""Applies transforms to the data, and the predict_proba method of the
final estimator. Valid only if the final estimator implements
Expand All @@ -158,6 +162,7 @@ def predict_proba(self, X):
Xt = transform.transform(Xt)
return self.steps[-1][-1].predict_proba(Xt)

@if_delegate_has_method(delegate='_final_estimator')
def decision_function(self, X):
"""Applies transforms to the data, and the decision_function method of
the final estimator. Valid only if the final estimator implements
Expand All @@ -167,12 +172,14 @@ def decision_function(self, X):
Xt = transform.transform(Xt)
return self.steps[-1][-1].decision_function(Xt)

@if_delegate_has_method(delegate='_final_estimator')
def predict_log_proba(self, X):
    """Apply the transforms of all but the last step to the data, then
    call ``predict_log_proba`` of the final estimator. Valid only if the
    final estimator implements ``predict_log_proba``."""
    Xt = X
    for _, step in self.steps[:-1]:
        Xt = step.transform(Xt)
    return self.steps[-1][-1].predict_log_proba(Xt)

@if_delegate_has_method(delegate='_final_estimator')
def transform(self, X):
"""Applies transforms to the data, and the transform method of the
final estimator. Valid only if the final estimator implements
Expand All @@ -182,6 +189,7 @@ def transform(self, X):
Xt = transform.transform(Xt)
return Xt

@if_delegate_has_method(delegate='_final_estimator')
def inverse_transform(self, X):
    """Apply ``inverse_transform`` of each step in reverse order. Valid
    only if every step implements ``inverse_transform``."""
    # A 1-D input is promoted to a single-sample 2-D array.
    Xt = X[None, :] if X.ndim == 1 else X
    for _, step in reversed(self.steps):
        Xt = step.inverse_transform(Xt)
    return Xt

@if_delegate_has_method(delegate='_final_estimator')
def score(self, X, y=None):
"""Applies transforms to the data, and the score method of the
final estimator. Valid only if the final estimator implements
Expand Down
Loading