Skip to content

Commit 159bb7b

Browse files
committed
FIX properties of #1805 for BaseSearchCV and RFE
1 parent a2d3781 commit 159bb7b

File tree

4 files changed

+257
-206
lines changed

4 files changed

+257
-206
lines changed

sklearn/feature_selection/rfe.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
"""Recursive feature elimination for feature ranking"""
88

9+
from functools import wraps
910
import numpy as np
1011
from ..utils import check_arrays, safe_sqr
1112
from ..base import BaseEstimator
@@ -36,6 +37,7 @@ class RFE(BaseEstimator, MetaEstimatorMixin, SelectorMixin):
3637
A supervised learning estimator with a `fit` method that updates a
3738
`coef_` attribute that holds the fitted parameters. Important features
3839
must correspond to high absolute values in the `coef_` array.
40+
The estimator must also implement a `score` method.
3941
4042
For instance, this is the case for most supervised learning
4143
algorithms such as Support Vector Classifiers and Generalized
@@ -169,7 +171,13 @@ def fit(self, X, y):
169171

170172
return self
171173

172-
def predict(self, X):
174+
def _delegate_wrapper(self, delegate):
175+
def wrapper(X, *args, **kwargs):
176+
return delegate(self.transform(X), *args, **kwargs)
177+
return wrapper
178+
179+
@property
180+
def predict(self):
173181
"""Reduce X to the selected features and then predict using the
174182
underlying estimator.
175183
@@ -183,9 +191,10 @@ def predict(self, X):
183191
y : array of shape [n_samples]
184192
The predicted target values.
185193
"""
186-
return self.estimator_.predict(self.transform(X))
194+
return self._delegate_wrapper(self.estimator_.predict)
187195

188-
def score(self, X, y):
196+
@property
197+
def score(self):
189198
"""Reduce X to the selected features and then return the score of the
190199
underlying estimator.
191200
@@ -197,16 +206,22 @@ def score(self, X, y):
197206
y : array of shape [n_samples]
198207
The target values.
199208
"""
200-
return self.estimator_.score(self.transform(X), y)
209+
return self._delegate_wrapper(self.estimator_.score)
201210

202211
def _get_support_mask(self):
203212
return self.support_
204213

205-
def decision_function(self, X):
206-
return self.estimator_.decision_function(self.transform(X))
214+
@property
215+
def decision_function(self):
216+
return self._delegate_wrapper(self.estimator_.decision_function)
217+
218+
@property
219+
def predict_proba(self):
220+
return self._delegate_wrapper(self.estimator_.predict_proba)
207221

208-
def predict_proba(self, X):
209-
return self.estimator_.predict_proba(self.transform(X))
222+
@property
223+
def predict_log_proba(self):
224+
return self._delegate_wrapper(self.estimator_.predict_log_proba)
210225

211226

212227
class RFECV(RFE, MetaEstimatorMixin):

sklearn/grid_search.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,20 +325,51 @@ def score(self, X, y=None):
325325

326326
@property
327327
def predict(self):
328+
"""Call predict on the best estimator"""
328329
return self.best_estimator_.predict
329330

330331
@property
331332
def predict_proba(self):
333+
"""Call predict_proba on the best estimator"""
332334
return self.best_estimator_.predict_proba
333335

336+
@property
337+
def predict_log_proba(self):
338+
"""Call predict_log_proba on the best estimator"""
339+
return self.best_estimator_.predict_log_proba
340+
334341
@property
335342
def decision_function(self):
343+
"""Call decision_function on the best estimator"""
336344
return self.best_estimator_.decision_function
337345

338346
@property
339347
def transform(self):
348+
"""Call transform on the best estimator"""
340349
return self.best_estimator_.transform
341350

351+
@property
352+
def inverse_transform(self):
353+
"""Call inverse_transform on the best estimator"""
354+
return self.best_estimator_.inverse_transform
355+
356+
def _check_estimator(self):
357+
"""Check that estimator can be fitted and score can be computed."""
358+
if (not hasattr(self.estimator, 'fit') or
359+
not (hasattr(self.estimator, 'predict')
360+
or hasattr(self.estimator, 'score'))):
361+
raise TypeError("estimator should a be an estimator implementing"
362+
" 'fit' and 'predict' or 'score' methods,"
363+
" %s (type %s) was passed" %
364+
(self.estimator, type(self.estimator)))
365+
if (self.scoring is None and self.loss_func is None and self.score_func
366+
is None):
367+
if not hasattr(self.estimator, 'score'):
368+
raise TypeError(
369+
"If no scoring is specified, the estimator passed "
370+
"should have a 'score' method. The estimator %s "
371+
"does not." % self.estimator)
372+
342373
def _fit(self, X, y, parameter_iterable):
343374
"""Actual fitting, performing the search over parameters."""
344375

sklearn/tests/test_metaestimators.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from sklearn.utils.testing import assert_true, assert_false
1111
from sklearn.pipeline import Pipeline
1212
from sklearn.grid_search import GridSearchCV, RandomizedSearchCV
13-
from sklearn.feature_selection import RFECV
13+
from sklearn.feature_selection import RFE, RFECV
1414

1515

1616
class DelegatorData(object):
@@ -32,8 +32,10 @@ def __init__(self, name, construct, skip_methods=(),
3232
lambda est: RandomizedSearchCV(
3333
est, param_distributions={'param': [5]}, cv=2),
3434
skip_methods=['score']),
35+
DelegatorData('RFE', RFE,
36+
skip_methods=['transform', 'inverse_transform', 'score']),
3537
DelegatorData('RFECV', RFECV,
36-
skip_methods=['transform', 'inverse_transform']),
38+
skip_methods=['transform', 'inverse_transform', 'score']),
3739
]
3840

3941

@@ -43,7 +45,7 @@ def hides(method):
4345
@property
4446
def wrapper(obj):
4547
if obj.hidden_method == method.__name__:
46-
raise AttributeError
48+
raise AttributeError('%r is hidden' % obj.hidden_method)
4749
return functools.partial(method, obj)
4850
return wrapper
4951

@@ -86,6 +88,7 @@ def score(self, X, *args, **kwargs):
8688

8789
methods = [k for k in iterkeys(SubEstimator.__dict__)
8890
if not k.startswith('_') and not k.startswith('fit')]
91+
methods.sort()
8992

9093
for delegator_data in DELEGATING_METAESTIMATORS:
9194
delegate = SubEstimator()
@@ -98,6 +101,8 @@ def score(self, X, *args, **kwargs):
98101
assert_true(hasattr(delegator, method),
99102
msg="%s does not have method %r when its delegate does"
100103
% (delegator_data.name, method))
104+
# smoke test delegation
105+
getattr(delegator, method)(delegator_data.fit_args[0])
101106

102107
for method in methods:
103108
if method in delegator_data.skip_methods:

0 commit comments

Comments
 (0)