Skip to content

[MRG] Ensure delegated ducktyping in MetaEstimators #2019

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 23 additions & 8 deletions sklearn/feature_selection/rfe.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

"""Recursive feature elimination for feature ranking"""

from functools import wraps
import numpy as np
from ..utils import check_arrays, safe_sqr
from ..base import BaseEstimator
Expand Down Expand Up @@ -36,6 +37,7 @@ class RFE(BaseEstimator, MetaEstimatorMixin, SelectorMixin):
A supervised learning estimator with a `fit` method that updates a
`coef_` attribute that holds the fitted parameters. Important features
must correspond to high absolute values in the `coef_` array.
The estimator must also implement a `score` method.

For instance, this is the case for most supervised learning
algorithms such as Support Vector Classifiers and Generalized
Expand Down Expand Up @@ -169,7 +171,13 @@ def fit(self, X, y):

return self

def predict(self, X):
def _delegate_wrapper(self, delegate):
def wrapper(X, *args, **kwargs):
return delegate(self.transform(X), *args, **kwargs)
return wrapper

@property
def predict(self):
"""Reduce X to the selected features and then predict using the
underlying estimator.

Expand All @@ -183,9 +191,10 @@ def predict(self, X):
y : array of shape [n_samples]
The predicted target values.
"""
return self.estimator_.predict(self.transform(X))
return self._delegate_wrapper(self.estimator_.predict)

def score(self, X, y):
@property
def score(self):
"""Reduce X to the selected features and then return the score of the
underlying estimator.

Expand All @@ -197,16 +206,22 @@ def score(self, X, y):
y : array of shape [n_samples]
The target values.
"""
return self.estimator_.score(self.transform(X), y)
return self._delegate_wrapper(self.estimator_.score)

def _get_support_mask(self):
return self.support_

def decision_function(self, X):
return self.estimator_.decision_function(self.transform(X))
@property
def decision_function(self):
    """Reduce X to the selected features and then call
    ``decision_function`` on the underlying estimator.

    Parameters
    ----------
    X : array of shape [n_samples, n_features]
        The input samples.
    """
    return self._delegate_wrapper(self.estimator_.decision_function)

@property
def predict_proba(self):
    """Reduce X to the selected features and then call ``predict_proba``
    on the underlying estimator.

    Parameters
    ----------
    X : array of shape [n_samples, n_features]
        The input samples.
    """
    return self._delegate_wrapper(self.estimator_.predict_proba)

def predict_proba(self, X):
return self.estimator_.predict_proba(self.transform(X))
@property
def predict_log_proba(self):
    """Reduce X to the selected features and then call
    ``predict_log_proba`` on the underlying estimator.

    Parameters
    ----------
    X : array of shape [n_samples, n_features]
        The input samples.
    """
    return self._delegate_wrapper(self.estimator_.predict_log_proba)


class RFECV(RFE, MetaEstimatorMixin):
Expand Down
31 changes: 31 additions & 0 deletions sklearn/grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,20 +325,51 @@ def score(self, X, y=None):

@property
def predict(self):
    """Delegate ``predict`` to the best estimator found by the search."""
    best = self.best_estimator_
    return best.predict

@property
def predict_proba(self):
    """Delegate ``predict_proba`` to the best estimator found by the search."""
    best = self.best_estimator_
    return best.predict_proba

@property
def predict_log_proba(self):
    """Delegate ``predict_log_proba`` to the best estimator found by the search."""
    best = self.best_estimator_
    return best.predict_log_proba

@property
def decision_function(self):
    """Delegate ``decision_function`` to the best estimator found by the search."""
    best = self.best_estimator_
    return best.decision_function

@property
def transform(self):
    """Delegate ``transform`` to the best estimator found by the search."""
    best = self.best_estimator_
    return best.transform

@property
def inverse_transform(self):
    """Delegate ``inverse_transform`` to the best estimator found by the search."""
    best = self.best_estimator_
    return best.inverse_transform

def _check_estimator(self):
"""Check that estimator can be fitted and score can be computed."""
if (not hasattr(self.estimator, 'fit') or
not (hasattr(self.estimator, 'predict')
or hasattr(self.estimator, 'score'))):
raise TypeError("estimator should a be an estimator implementing"
" 'fit' and 'predict' or 'score' methods,"
" %s (type %s) was passed" %
(self.estimator, type(self.estimator)))
if (self.scoring is None and self.loss_func is None and self.score_func
is None):
if not hasattr(self.estimator, 'score'):
raise TypeError(
"If no scoring is specified, the estimator passed "
"should have a 'score' method. The estimator %s "
"does not." % self.estimator)

def _fit(self, X, y, parameter_iterable):
"""Actual fitting, performing the search over parameters."""

Expand Down
Loading