14 changes: 14 additions & 0 deletions doc/developers/index.rst
@@ -883,6 +883,20 @@ take arguments ``X, y``, even if y is not used. Similarly, for ``score`` to be
usable, the last step of the pipeline needs to have a ``score`` function that
accepts an optional ``y``.

Estimator types
---------------
Some common functionality depends on the kind of estimator passed.
For example, cross-validation in :class:`grid_search.GridSearchCV` and
:func:`cross_validation.cross_val_score` defaults to being stratified when used
on a classifier, but not otherwise. Similarly, scorers for average precision
that take a continuous prediction need to call ``decision_function`` for classifiers,
but ``predict`` for regressors. This distinction between classifiers and regressors
is implemented using the ``_estimator_type`` attribute, which takes a string value.
It should be ``"classifier"`` for classifiers, ``"regressor"`` for regressors,
and ``"clusterer"`` for clustering methods for everything to work as expected.
Inheriting from ``ClassifierMixin``, ``RegressorMixin`` or ``ClusterMixin``
will set the attribute automatically.
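
For example, a minimal (hypothetical) estimator can either inherit one of the
mixins or set the attribute explicitly; ``is_classifier`` then only inspects
the tag, not the class hierarchy::

    from sklearn.base import BaseEstimator, ClassifierMixin, is_classifier

    class ConstantClassifier(BaseEstimator, ClassifierMixin):
        # ClassifierMixin sets _estimator_type = "classifier"
        def fit(self, X, y):
            self.constant_ = y[0]
            return self

        def predict(self, X):
            return [self.constant_] * len(X)

    class ThirdPartyClassifier(object):
        # no scikit-learn base classes; setting the attribute is enough
        _estimator_type = "classifier"

    is_classifier(ConstantClassifier())    # True
    is_classifier(ThirdPartyClassifier())  # True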

Working notes
-------------

28 changes: 12 additions & 16 deletions sklearn/base.py
@@ -244,14 +244,14 @@ def set_params(self, **params):
if len(split) > 1:
# nested objects case
name, sub_name = split
if not name in valid_params:
if name not in valid_params:
raise ValueError('Invalid parameter %s for estimator %s' %
(name, self))
sub_object = valid_params[name]
sub_object.set_params(**{sub_name: value})
else:
# simple objects case
if not key in valid_params:
if key not in valid_params:
raise ValueError('Invalid parameter %s ' 'for estimator %s'
% (key, self.__class__.__name__))
setattr(self, key, value)
@@ -266,6 +266,7 @@ def __repr__(self):
###############################################################################
class ClassifierMixin(object):
"""Mixin class for all classifiers in scikit-learn."""
_estimator_type = "classifier"

def score(self, X, y, sample_weight=None):
"""Returns the mean accuracy on the given test data and labels.
@@ -298,6 +299,7 @@ def score(self, X, y, sample_weight=None):
###############################################################################
class RegressorMixin(object):
"""Mixin class for all regression estimators in scikit-learn."""
_estimator_type = "regressor"

def score(self, X, y, sample_weight=None):
"""Returns the coefficient of determination R^2 of the prediction.
@@ -331,6 +333,8 @@ def score(self, X, y, sample_weight=None):
###############################################################################
class ClusterMixin(object):
"""Mixin class for all cluster estimators in scikit-learn."""
_estimator_type = "clusterer"

def fit_predict(self, X, y=None):
"""Performs clustering on X and returns cluster labels.

@@ -443,20 +447,12 @@ class MetaEstimatorMixin(object):


###############################################################################
# XXX: Temporary solution to figure out if an estimator is a classifier

def _get_sub_estimator(estimator):
"""Returns the final estimator if there is any."""
if hasattr(estimator, 'estimator'):
# GridSearchCV and other CV-tuned estimators
return _get_sub_estimator(estimator.estimator)
if hasattr(estimator, 'steps'):
# Pipeline
return _get_sub_estimator(estimator.steps[-1][1])
return estimator


def is_classifier(estimator):
"""Returns True if the given estimator is (probably) a classifier."""
estimator = _get_sub_estimator(estimator)
return isinstance(estimator, ClassifierMixin)
return getattr(estimator, "_estimator_type", None) == "classifier"
Member:

Could we enforce that all estimators have a _estimator_type tag?

def is_classifier(estimator):
    """Returns True if the given estimator is a classifier."""
    if not hasattr(estimator, "_estimator_type"):
        raise ValueError("The given estimator instance does not have a _estimator_type tag.")

    return estimator._estimator_type.lower() == "classifier"

Member:

This would not work with user-defined estimators currently, but it would be helpful in framing a generic estimator test framework, as wished for in #3810.

Member Author:

But then this should be in the test framework, not the code. That way people who want to be strict can run their tests, while people who don't care can still run their sloppy but sklearn-compatible code.
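
(For illustration only, not something in this PR: a strict variant could live
with the estimator checks in the test framework, roughly like the sketch below,
while the library itself keeps the permissive getattr-based default.)

def check_estimator_type_tag(estimator):
    """Hypothetical strict check: fail loudly if the tag is missing or unknown."""
    allowed = {"classifier", "regressor", "clusterer"}
    tag = getattr(estimator, "_estimator_type", None)
    if tag not in allowed:
        raise AssertionError(
            "%r should set _estimator_type to one of %r, got %r"
            % (estimator, sorted(allowed), tag))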

Member:

That makes sense... Thanks for the comment :)


def is_regressor(estimator):
"""Returns True if the given estimator is (probably) a regressor."""
return getattr(estimator, "_estimator_type", None) == "regressor"
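
A quick usage sketch (not part of the diff; it relies on the _estimator_type
property added to grid search further down) showing that the duck-typed check
also composes through meta-estimators, which the old isinstance-based
_get_sub_estimator walk had to handle by recursing into the wrapped object:

from sklearn.base import is_classifier
from sklearn.grid_search import GridSearchCV
from sklearn.svm import SVC

search = GridSearchCV(SVC(), {"C": [0.1, 1.0]})
print(is_classifier(search))  # True: the property delegates to the wrapped SVC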
92 changes: 82 additions & 10 deletions sklearn/ensemble/gradient_boosting.py
@@ -35,7 +35,7 @@
from ..base import ClassifierMixin
from ..base import RegressorMixin
from ..utils import check_random_state, check_array, check_X_y, column_or_1d
from ..utils import check_consistent_length
from ..utils import check_consistent_length, deprecated
from ..utils.extmath import logsumexp
from ..utils.fixes import expit, bincount
from ..utils.stats import _weighted_percentile
@@ -438,7 +438,7 @@ class ClassificationLossFunction(six.with_metaclass(ABCMeta, LossFunction)):
def _score_to_proba(self, score):
"""Template method to convert scores to probabilities.

If the loss does not support probabilites raises AttributeError.
If the loss does not support probabilities, raises AttributeError.
"""
raise TypeError('%s does not support predict_proba' % type(self).__name__)

@@ -1044,9 +1044,10 @@ def _fit_stages(self, X, y, y_pred, sample_weight, random_state,
self.train_score_[i] = loss_(y[sample_mask],
y_pred[sample_mask],
sample_weight[sample_mask])
self.oob_improvement_[i] = (old_oob_score -
loss_(y[~sample_mask], y_pred[~sample_mask],
sample_weight[~sample_mask]))
self.oob_improvement_[i] = (
old_oob_score - loss_(y[~sample_mask],
y_pred[~sample_mask],
sample_weight[~sample_mask]))
else:
# no need to fancy index w/ no subsampling
self.train_score_[i] = loss_(y, y_pred, sample_weight)
@@ -1082,6 +1083,7 @@ def _decision_function(self, X):
predict_stages(self.estimators_, X, self.learning_rate, score)
return score

@deprecated(" and will be removed in 0.19")
Member:

There is also staged_decision_function.

(Ping @pprett )

Member Author:

I was wondering about that, but you are right, it should be removed.

Member Author:

Done.

def decision_function(self, X):
"""Compute the decision function of ``X``.

@@ -1104,7 +1106,7 @@ def decision_function(self, X):
return score.ravel()
return score

def staged_decision_function(self, X):
def _staged_decision_function(self, X):
"""Compute decision function of ``X`` for each iteration.

This method allows monitoring (i.e. determine error on testing set)
@@ -1129,6 +1131,30 @@ def staged_decision_function(self, X):
predict_stage(self.estimators_, i, X, self.learning_rate, score)
yield score.copy()

@deprecated(" and will be removed in 0.19")
def staged_decision_function(self, X):
"""Compute decision function of ``X`` for each iteration.

This method allows monitoring (i.e. determine error on testing set)
after each stage.

Parameters
----------
X : array-like of shape = [n_samples, n_features]
The input samples.

Returns
-------
score : generator of array, shape = [n_samples, k]
The decision function of the input samples. The order of the
classes corresponds to that in the attribute `classes_`.
Regression and binary classification are special cases with
``k == 1``, otherwise ``k==n_classes``.
"""
for dec in self._staged_decision_function(X):
# no yield from in Python2.X
yield dec

@property
def feature_importances_(self):
"""Return the feature importances (the higher, the more important the
@@ -1315,6 +1341,51 @@ def _validate_y(self, y):
self.n_classes_ = len(self.classes_)
return y

def decision_function(self, X):
"""Compute the decision function of ``X``.

Parameters
----------
X : array-like of shape = [n_samples, n_features]
The input samples.

Returns
-------
score : array, shape = [n_samples, n_classes] or [n_samples]
The decision function of the input samples. The order of the
classes corresponds to that in the attribute `classes_`.
Regression and binary classification produce an array of shape
[n_samples].
"""
X = check_array(X, dtype=DTYPE, order="C")
score = self._decision_function(X)
if score.shape[1] == 1:
return score.ravel()
return score

def staged_decision_function(self, X):
"""Compute decision function of ``X`` for each iteration.

This method allows monitoring (i.e. determine error on testing set)
after each stage.

Parameters
----------
X : array-like of shape = [n_samples, n_features]
The input samples.

Returns
-------
score : generator of array, shape = [n_samples, k]
The decision function of the input samples. The order of the
classes corresponds to that in the attribute `classes_`.
Regression and binary classification are special cases with
``k == 1``, otherwise ``k==n_classes``.
"""
for dec in self._staged_decision_function(X):
# no yield from in Python2.X
yield dec

def predict(self, X):
"""Predict class for X.

@@ -1348,7 +1419,7 @@ def staged_predict(self, X):
y : generator of array of shape = [n_samples]
The predicted value of the input samples.
"""
for score in self.staged_decision_function(X):
for score in self._staged_decision_function(X):
decisions = self.loss_._score_to_decision(score)
yield self.classes_.take(decisions, axis=0)

@@ -1419,7 +1490,7 @@ def staged_predict_proba(self, X):
The predicted value of the input samples.
"""
try:
for score in self.staged_decision_function(X):
for score in self._staged_decision_function(X):
yield self.loss_._score_to_proba(score)
except NotFittedError:
raise
@@ -1594,7 +1665,8 @@ def predict(self, X):
y : array of shape = [n_samples]
The predicted values.
"""
return self.decision_function(X).ravel()
X = check_array(X, dtype=DTYPE, order="C")
return self._decision_function(X).ravel()

def staged_predict(self, X):
"""Predict regression target at each stage for X.
@@ -1612,5 +1684,5 @@ def staged_predict(self, X):
y : generator of array of shape = [n_samples]
The predicted value of the input samples.
"""
for y in self.staged_decision_function(X):
for y in self._staged_decision_function(X):
yield y.ravel()
21 changes: 11 additions & 10 deletions sklearn/ensemble/tests/test_gradient_boosting.py
@@ -1,7 +1,7 @@
"""
Testing for the gradient boosting module (sklearn.ensemble.gradient_boosting).
"""

import warnings
import numpy as np

from sklearn import datasets
@@ -171,8 +171,9 @@ def test_boston():
for loss in ("ls", "lad", "huber"):
for subsample in (1.0, 0.5):
last_y_pred = None
for i, sample_weight in enumerate((None, np.ones(len(boston.target)),
2 * np.ones(len(boston.target)))):
for i, sample_weight in enumerate(
(None, np.ones(len(boston.target)),
2 * np.ones(len(boston.target)))):
clf = GradientBoostingRegressor(n_estimators=100, loss=loss,
max_depth=4, subsample=subsample,
min_samples_split=1,
@@ -343,6 +344,7 @@ def test_check_max_features():
max_features=-0.1)
assert_raises(ValueError, clf.fit, X, y)


def test_max_feature_regression():
# Test to make sure random state is set properly.
X, y = datasets.make_hastie_10_2(n_samples=12000, random_state=1)
@@ -455,7 +457,8 @@ def test_staged_functions_defensive():
if staged_func is None:
# regressor has no staged_predict_proba
continue
staged_result = list(staged_func(X))
with warnings.catch_warnings(record=True):
staged_result = list(staged_func(X))
staged_result[1][:] = 0
assert_true(np.all(staged_result[0] != 0))

@@ -843,7 +846,7 @@ def test_complete_classification():
k = 4

est = GradientBoostingClassifier(n_estimators=20, max_depth=None,
random_state=1, max_leaf_nodes=k+1)
random_state=1, max_leaf_nodes=k + 1)
est.fit(X, y)

tree = est.estimators_[0, 0].tree_
@@ -858,7 +861,7 @@ def test_complete_regression():
k = 4

est = GradientBoostingRegressor(n_estimators=20, max_depth=None,
random_state=1, max_leaf_nodes=k+1)
random_state=1, max_leaf_nodes=k + 1)
est.fit(boston.data, boston.target)

tree = est.estimators_[-1, 0].tree_
@@ -971,8 +974,7 @@ def test_non_uniform_weights_toy_edge_case_reg():
X = [[1, 0],
[1, 0],
[1, 0],
[0, 1],
]
[0, 1]]
y = [0, 0, 1, 0]
# ignore the first 2 training samples by setting their weight to 0
sample_weight = [0, 0, 1, 1]
@@ -1002,8 +1004,7 @@ def test_non_uniform_weights_toy_edge_case_clf():
X = [[1, 0],
[1, 0],
[1, 0],
[0, 1],
]
[0, 1]]
y = [0, 0, 1, 0]
# ignore the first 2 training samples by setting their weight to 0
sample_weight = [0, 0, 1, 1]
4 changes: 4 additions & 0 deletions sklearn/grid_search.py
@@ -331,6 +331,10 @@ def __init__(self, estimator, scoring=None,
self.pre_dispatch = pre_dispatch
self.error_score = error_score

@property
def _estimator_type(self):
return self.estimator._estimator_type

def score(self, X, y=None):
"""Returns the score on the given data, if the estimator has been refit

8 changes: 6 additions & 2 deletions sklearn/linear_model/base.py
@@ -24,7 +24,7 @@
from ..externals import six
from ..externals.joblib import Parallel, delayed
from ..base import BaseEstimator, ClassifierMixin, RegressorMixin
from ..utils import as_float_array, check_array, check_X_y
from ..utils import as_float_array, check_array, check_X_y, deprecated
from ..utils.extmath import safe_sparse_dot
from ..utils.sparsefuncs import mean_variance_axis, inplace_column_scale
from ..utils.fixes import sparse_lsqr
@@ -119,6 +119,7 @@ class LinearModel(six.with_metaclass(ABCMeta, BaseEstimator)):
def fit(self, X, y):
"""Fit model."""

@deprecated(" and will be removed in 0.19.")
def decision_function(self, X):
"""Decision function of the linear model.

@@ -132,6 +133,9 @@ def decision_function(self, X):
C : array, shape = (n_samples,)
Returns predicted values.
"""
return self._decision_function(X)

def _decision_function(self, X):
check_is_fitted(self, "coef_")

X = check_array(X, accept_sparse=['csr', 'csc', 'coo'])
@@ -151,7 +155,7 @@ def predict(self, X):
C : array, shape = (n_samples,)
Returns predicted values.
"""
return self.decision_function(X)
return self._decision_function(X)

_center_data = staticmethod(center_data)
