From d89c2150133b82a28cd8416a270414ea901cdef6 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Thu, 19 Mar 2015 16:26:36 -0400 Subject: [PATCH 1/3] Add tags to classifiers and regressors to identify them as such. --- doc/developers/index.rst | 13 +++ sklearn/base.py | 28 +++--- sklearn/ensemble/gradient_boosting.py | 92 +++++++++++++++++-- .../ensemble/tests/test_gradient_boosting.py | 21 +++-- sklearn/grid_search.py | 4 + sklearn/linear_model/base.py | 8 +- sklearn/linear_model/coordinate_descent.py | 19 +++- sklearn/linear_model/stochastic_gradient.py | 4 +- sklearn/metrics/scorer.py | 26 +++--- sklearn/metrics/tests/test_score_objects.py | 9 +- sklearn/multiclass.py | 14 +-- sklearn/pipeline.py | 4 + sklearn/svm/base.py | 33 ++++++- sklearn/tests/test_common.py | 2 + sklearn/tests/test_multiclass.py | 19 +++- sklearn/utils/estimator_checks.py | 26 ++++++ 16 files changed, 261 insertions(+), 61 deletions(-) diff --git a/doc/developers/index.rst b/doc/developers/index.rst index ff8049cfd3b06..199d93598d420 100644 --- a/doc/developers/index.rst +++ b/doc/developers/index.rst @@ -883,6 +883,19 @@ take arguments ``X, y``, even if y is not used. Similarly, for ``score`` to be usable, the last step of the pipeline needs to have a ``score`` function that accepts an optional ``y``. +Estimator types +--------------- +Some common functionality depends on the kind of estimator passed. +For example, cross-validation in :class:`grid_search.GridSearchCV` and +:func:`cross_validation.cross_val_score` defaults to being stratified when used +on a classifier, but not otherwise. Similarly, scorers for average precision +that take a continuous prediction need to call ``decision_function`` for classifiers, +but ``predict`` for regressors. This distinction between classifiers and regressors +is implemented using the ``_estimator_type`` attribute, which takes a string value. +It should be ``"classifier"`` for classifiers and ``"regressor"`` for regressors, +to work as expected. Inheriting from ``ClassifierMixin`` or ``RegressorMixin`` will +set the attribute automatically. + Working notes ------------- diff --git a/sklearn/base.py b/sklearn/base.py index c77e5fb969506..805e93cd0fcd2 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -244,14 +244,14 @@ def set_params(self, **params): if len(split) > 1: # nested objects case name, sub_name = split - if not name in valid_params: + if name not in valid_params: raise ValueError('Invalid parameter %s for estimator %s' % (name, self)) sub_object = valid_params[name] sub_object.set_params(**{sub_name: value}) else: # simple objects case - if not key in valid_params: + if key not in valid_params: raise ValueError('Invalid parameter %s ' 'for estimator %s' % (key, self.__class__.__name__)) setattr(self, key, value) @@ -266,6 +266,7 @@ def __repr__(self): ############################################################################### class ClassifierMixin(object): """Mixin class for all classifiers in scikit-learn.""" + _estimator_type = "classifier" def score(self, X, y, sample_weight=None): """Returns the mean accuracy on the given test data and labels. @@ -298,6 +299,7 @@ def score(self, X, y, sample_weight=None): ############################################################################### class RegressorMixin(object): """Mixin class for all regression estimators in scikit-learn.""" + _estimator_type = "regressor" def score(self, X, y, sample_weight=None): """Returns the coefficient of determination R^2 of the prediction. 
@@ -331,6 +333,8 @@ def score(self, X, y, sample_weight=None): ############################################################################### class ClusterMixin(object): """Mixin class for all cluster estimators in scikit-learn.""" + _estimator_type = "clusterer" + def fit_predict(self, X, y=None): """Performs clustering on X and returns cluster labels. @@ -443,20 +447,12 @@ class MetaEstimatorMixin(object): ############################################################################### -# XXX: Temporary solution to figure out if an estimator is a classifier - -def _get_sub_estimator(estimator): - """Returns the final estimator if there is any.""" - if hasattr(estimator, 'estimator'): - # GridSearchCV and other CV-tuned estimators - return _get_sub_estimator(estimator.estimator) - if hasattr(estimator, 'steps'): - # Pipeline - return _get_sub_estimator(estimator.steps[-1][1]) - return estimator - def is_classifier(estimator): """Returns True if the given estimator is (probably) a classifier.""" - estimator = _get_sub_estimator(estimator) - return isinstance(estimator, ClassifierMixin) + return getattr(estimator, "_estimator_type", None) == "classifier" + + +def is_regressor(estimator): + """Returns True if the given estimator is (probably) a regressor.""" + return getattr(estimator, "_estimator_type", None) == "regressor" diff --git a/sklearn/ensemble/gradient_boosting.py b/sklearn/ensemble/gradient_boosting.py index b41499b2dfa6f..032bcb9a277f6 100644 --- a/sklearn/ensemble/gradient_boosting.py +++ b/sklearn/ensemble/gradient_boosting.py @@ -35,7 +35,7 @@ from ..base import ClassifierMixin from ..base import RegressorMixin from ..utils import check_random_state, check_array, check_X_y, column_or_1d -from ..utils import check_consistent_length +from ..utils import check_consistent_length, deprecated from ..utils.extmath import logsumexp from ..utils.fixes import expit, bincount from ..utils.stats import _weighted_percentile @@ -438,7 +438,7 @@ class ClassificationLossFunction(six.with_metaclass(ABCMeta, LossFunction)): def _score_to_proba(self, score): """Template method to convert scores to probabilities. - If the loss does not support probabilites raises AttributeError. + the does not support probabilites raises AttributeError. """ raise TypeError('%s does not support predict_proba' % type(self).__name__) @@ -1044,9 +1044,10 @@ def _fit_stages(self, X, y, y_pred, sample_weight, random_state, self.train_score_[i] = loss_(y[sample_mask], y_pred[sample_mask], sample_weight[sample_mask]) - self.oob_improvement_[i] = (old_oob_score - - loss_(y[~sample_mask], y_pred[~sample_mask], - sample_weight[~sample_mask])) + self.oob_improvement_[i] = ( + old_oob_score - loss_(y[~sample_mask], + y_pred[~sample_mask], + sample_weight[~sample_mask])) else: # no need to fancy index w/ no subsampling self.train_score_[i] = loss_(y, y_pred, sample_weight) @@ -1082,6 +1083,7 @@ def _decision_function(self, X): predict_stages(self.estimators_, X, self.learning_rate, score) return score + @deprecated(" and will be removed in 0.19") def decision_function(self, X): """Compute the decision function of ``X``. @@ -1104,7 +1106,7 @@ def decision_function(self, X): return score.ravel() return score - def staged_decision_function(self, X): + def _staged_decision_function(self, X): """Compute decision function of ``X`` for each iteration. This method allows monitoring (i.e. 
determine error on testing set) @@ -1129,6 +1131,30 @@ def staged_decision_function(self, X): predict_stage(self.estimators_, i, X, self.learning_rate, score) yield score.copy() + @deprecated(" and will be removed in 0.19") + def staged_decision_function(self, X): + """Compute decision function of ``X`` for each iteration. + + This method allows monitoring (i.e. determine error on testing set) + after each stage. + + Parameters + ---------- + X : array-like of shape = [n_samples, n_features] + The input samples. + + Returns + ------- + score : generator of array, shape = [n_samples, k] + The decision function of the input samples. The order of the + classes corresponds to that in the attribute `classes_`. + Regression and binary classification are special cases with + ``k == 1``, otherwise ``k==n_classes``. + """ + for dec in self._staged_decision_function(X): + # no yield from in Python2.X + yield dec + @property def feature_importances_(self): """Return the feature importances (the higher, the more important the @@ -1315,6 +1341,51 @@ def _validate_y(self, y): self.n_classes_ = len(self.classes_) return y + def decision_function(self, X): + """Compute the decision function of ``X``. + + Parameters + ---------- + X : array-like of shape = [n_samples, n_features] + The input samples. + + Returns + ------- + score : array, shape = [n_samples, n_classes] or [n_samples] + The decision function of the input samples. The order of the + classes corresponds to that in the attribute `classes_`. + Regression and binary classification produce an array of shape + [n_samples]. + """ + X = check_array(X, dtype=DTYPE, order="C") + score = self._decision_function(X) + if score.shape[1] == 1: + return score.ravel() + return score + + def staged_decision_function(self, X): + """Compute decision function of ``X`` for each iteration. + + This method allows monitoring (i.e. determine error on testing set) + after each stage. + + Parameters + ---------- + X : array-like of shape = [n_samples, n_features] + The input samples. + + Returns + ------- + score : generator of array, shape = [n_samples, k] + The decision function of the input samples. The order of the + classes corresponds to that in the attribute `classes_`. + Regression and binary classification are special cases with + ``k == 1``, otherwise ``k==n_classes``. + """ + for dec in self._staged_decision_function(X): + # no yield from in Python2.X + yield dec + def predict(self, X): """Predict class for X. @@ -1348,7 +1419,7 @@ def staged_predict(self, X): y : generator of array of shape = [n_samples] The predicted value of the input samples. """ - for score in self.staged_decision_function(X): + for score in self._staged_decision_function(X): decisions = self.loss_._score_to_decision(score) yield self.classes_.take(decisions, axis=0) @@ -1419,7 +1490,7 @@ def staged_predict_proba(self, X): The predicted value of the input samples. """ try: - for score in self.staged_decision_function(X): + for score in self._staged_decision_function(X): yield self.loss_._score_to_proba(score) except NotFittedError: raise @@ -1594,7 +1665,8 @@ def predict(self, X): y : array of shape = [n_samples] The predicted values. """ - return self.decision_function(X).ravel() + X = check_array(X, dtype=DTYPE, order="C") + return self._decision_function(X).ravel() def staged_predict(self, X): """Predict regression target at each stage for X. @@ -1612,5 +1684,5 @@ def staged_predict(self, X): y : generator of array of shape = [n_samples] The predicted value of the input samples. 
""" - for y in self.staged_decision_function(X): + for y in self._staged_decision_function(X): yield y.ravel() diff --git a/sklearn/ensemble/tests/test_gradient_boosting.py b/sklearn/ensemble/tests/test_gradient_boosting.py index 6ddb5cb34c03a..3f7f7f23f6566 100644 --- a/sklearn/ensemble/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/tests/test_gradient_boosting.py @@ -1,7 +1,7 @@ """ Testing for the gradient boosting module (sklearn.ensemble.gradient_boosting). """ - +import warnings import numpy as np from sklearn import datasets @@ -171,8 +171,9 @@ def test_boston(): for loss in ("ls", "lad", "huber"): for subsample in (1.0, 0.5): last_y_pred = None - for i, sample_weight in enumerate((None, np.ones(len(boston.target)), - 2 * np.ones(len(boston.target)))): + for i, sample_weight in enumerate( + (None, np.ones(len(boston.target)), + 2 * np.ones(len(boston.target)))): clf = GradientBoostingRegressor(n_estimators=100, loss=loss, max_depth=4, subsample=subsample, min_samples_split=1, @@ -343,6 +344,7 @@ def test_check_max_features(): max_features=-0.1) assert_raises(ValueError, clf.fit, X, y) + def test_max_feature_regression(): # Test to make sure random state is set properly. X, y = datasets.make_hastie_10_2(n_samples=12000, random_state=1) @@ -455,7 +457,8 @@ def test_staged_functions_defensive(): if staged_func is None: # regressor has no staged_predict_proba continue - staged_result = list(staged_func(X)) + with warnings.catch_warnings(record=True): + staged_result = list(staged_func(X)) staged_result[1][:] = 0 assert_true(np.all(staged_result[0] != 0)) @@ -843,7 +846,7 @@ def test_complete_classification(): k = 4 est = GradientBoostingClassifier(n_estimators=20, max_depth=None, - random_state=1, max_leaf_nodes=k+1) + random_state=1, max_leaf_nodes=k + 1) est.fit(X, y) tree = est.estimators_[0, 0].tree_ @@ -858,7 +861,7 @@ def test_complete_regression(): k = 4 est = GradientBoostingRegressor(n_estimators=20, max_depth=None, - random_state=1, max_leaf_nodes=k+1) + random_state=1, max_leaf_nodes=k + 1) est.fit(boston.data, boston.target) tree = est.estimators_[-1, 0].tree_ @@ -971,8 +974,7 @@ def test_non_uniform_weights_toy_edge_case_reg(): X = [[1, 0], [1, 0], [1, 0], - [0, 1], - ] + [0, 1]] y = [0, 0, 1, 0] # ignore the first 2 training samples by setting their weight to 0 sample_weight = [0, 0, 1, 1] @@ -1002,8 +1004,7 @@ def test_non_uniform_weights_toy_edge_case_clf(): X = [[1, 0], [1, 0], [1, 0], - [0, 1], - ] + [0, 1]] y = [0, 0, 1, 0] # ignore the first 2 training samples by setting their weight to 0 sample_weight = [0, 0, 1, 1] diff --git a/sklearn/grid_search.py b/sklearn/grid_search.py index c0e428f4bc8d0..4ccae7f8819f0 100644 --- a/sklearn/grid_search.py +++ b/sklearn/grid_search.py @@ -331,6 +331,10 @@ def __init__(self, estimator, scoring=None, self.pre_dispatch = pre_dispatch self.error_score = error_score + @property + def _estimator_type(self): + return self.estimator._estimator_type + def score(self, X, y=None): """Returns the score on the given data, if the estimator has been refit diff --git a/sklearn/linear_model/base.py b/sklearn/linear_model/base.py index 295e76c584135..c955e0631819e 100644 --- a/sklearn/linear_model/base.py +++ b/sklearn/linear_model/base.py @@ -24,7 +24,7 @@ from ..externals import six from ..externals.joblib import Parallel, delayed from ..base import BaseEstimator, ClassifierMixin, RegressorMixin -from ..utils import as_float_array, check_array, check_X_y +from ..utils import as_float_array, check_array, check_X_y, deprecated from 
..utils.extmath import safe_sparse_dot from ..utils.sparsefuncs import mean_variance_axis, inplace_column_scale from ..utils.fixes import sparse_lsqr @@ -119,6 +119,7 @@ class LinearModel(six.with_metaclass(ABCMeta, BaseEstimator)): def fit(self, X, y): """Fit model.""" + @deprecated(" and will be removed in 0.19.") def decision_function(self, X): """Decision function of the linear model. @@ -132,6 +133,9 @@ def decision_function(self, X): C : array, shape = (n_samples,) Returns predicted values. """ + return self._decision_function(X) + + def _decision_function(self, X): check_is_fitted(self, "coef_") X = check_array(X, accept_sparse=['csr', 'csc', 'coo']) @@ -151,7 +155,7 @@ def predict(self, X): C : array, shape = (n_samples,) Returns predicted values. """ - return self.decision_function(X) + return self._decision_function(X) _center_data = staticmethod(center_data) diff --git a/sklearn/linear_model/coordinate_descent.py b/sklearn/linear_model/coordinate_descent.py index a94209064cb24..e4a17867cf9fe 100644 --- a/sklearn/linear_model/coordinate_descent.py +++ b/sklearn/linear_model/coordinate_descent.py @@ -15,7 +15,7 @@ from .base import LinearModel, _pre_fit from ..base import RegressorMixin from .base import center_data, sparse_center_data -from ..utils import check_array, check_X_y +from ..utils import check_array, check_X_y, deprecated from ..utils.validation import check_random_state from ..cross_validation import _check_cv as check_cv from ..externals.joblib import Parallel, delayed @@ -689,6 +689,7 @@ def sparse_coef_(self): """ sparse representation of the fitted coef """ return sparse.csr_matrix(self.coef_) + @deprecated(" and will be removed in 0.19") def decision_function(self, X): """Decision function of the linear model @@ -696,6 +697,20 @@ def decision_function(self, X): ---------- X : numpy array or scipy.sparse matrix of shape (n_samples, n_features) + Returns + ------- + T : array, shape = (n_samples,) + The predicted decision function + """ + return self._decision_function(X) + + def _decision_function(self, X): + """Decision function of the linear model + + Parameters + ---------- + X : numpy array or scipy.sparse matrix of shape (n_samples, n_features) + Returns ------- T : array, shape = (n_samples,) @@ -706,7 +721,7 @@ def decision_function(self, X): return np.ravel(safe_sparse_dot(self.coef_, X.T, dense_output=True) + self.intercept_) else: - return super(ElasticNet, self).decision_function(X) + return super(ElasticNet, self)._decision_function(X) ############################################################################### diff --git a/sklearn/linear_model/stochastic_gradient.py b/sklearn/linear_model/stochastic_gradient.py index 3031174a8c83a..6eb9c47648070 100644 --- a/sklearn/linear_model/stochastic_gradient.py +++ b/sklearn/linear_model/stochastic_gradient.py @@ -14,7 +14,8 @@ from .base import LinearClassifierMixin, SparseCoefMixin from ..base import BaseEstimator, RegressorMixin from ..feature_selection.from_model import _LearntSelectorMixin -from ..utils import (check_array, check_random_state, check_X_y) +from ..utils import (check_array, check_random_state, check_X_y, + deprecated) from ..utils.extmath import safe_sparse_dot from ..utils.multiclass import _check_partial_fit_first_call from ..utils.validation import check_is_fitted @@ -974,6 +975,7 @@ def fit(self, X, y, coef_init=None, intercept_init=None, intercept_init=intercept_init, sample_weight=sample_weight) + @deprecated(" and will be removed in 0.19.") def decision_function(self, X): 
"""Predict using the linear model diff --git a/sklearn/metrics/scorer.py b/sklearn/metrics/scorer.py index 22d796fec1f5e..5dd91786f0eb1 100644 --- a/sklearn/metrics/scorer.py +++ b/sklearn/metrics/scorer.py @@ -30,6 +30,7 @@ from .cluster import adjusted_rand_score from ..utils.multiclass import type_of_target from ..externals import six +from ..base import is_regressor class _BaseScorer(six.with_metaclass(ABCMeta, object)): @@ -157,20 +158,23 @@ def __call__(self, clf, X, y, sample_weight=None): if y_type not in ("binary", "multilabel-indicator"): raise ValueError("{0} format is not supported".format(y_type)) - try: - y_pred = clf.decision_function(X) + if is_regressor(clf): + y_pred = clf.predict(X) + else: + try: + y_pred = clf.decision_function(X) - # For multi-output multi-class estimator - if isinstance(y_pred, list): - y_pred = np.vstack(p for p in y_pred).T + # For multi-output multi-class estimator + if isinstance(y_pred, list): + y_pred = np.vstack(p for p in y_pred).T - except (NotImplementedError, AttributeError): - y_pred = clf.predict_proba(X) + except (NotImplementedError, AttributeError): + y_pred = clf.predict_proba(X) - if y_type == "binary": - y_pred = y_pred[:, 1] - elif isinstance(y_pred, list): - y_pred = np.vstack([p[:, -1] for p in y_pred]).T + if y_type == "binary": + y_pred = y_pred[:, 1] + elif isinstance(y_pred, list): + y_pred = np.vstack([p[:, -1] for p in y_pred]).T if sample_weight is not None: return self._sign * self._score_func(y, y_pred, diff --git a/sklearn/metrics/tests/test_score_objects.py b/sklearn/metrics/tests/test_score_objects.py index 5a2523bf1dcf7..b7662ecf11b9d 100644 --- a/sklearn/metrics/tests/test_score_objects.py +++ b/sklearn/metrics/tests/test_score_objects.py @@ -22,7 +22,7 @@ from sklearn.cluster import KMeans from sklearn.dummy import DummyRegressor from sklearn.linear_model import Ridge, LogisticRegression -from sklearn.tree import DecisionTreeClassifier +from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor from sklearn.datasets import make_blobs from sklearn.datasets import make_classification from sklearn.datasets import make_multilabel_classification @@ -219,6 +219,13 @@ def test_thresholded_scorers(): score2 = roc_auc_score(y_test, clf.predict_proba(X_test)[:, 1]) assert_almost_equal(score1, score2) + # test with a regressor (no decision_function) + reg = DecisionTreeRegressor() + reg.fit(X_train, y_train) + score1 = get_scorer('roc_auc')(reg, X_test, y_test) + score2 = roc_auc_score(y_test, reg.predict(X_test)) + assert_almost_equal(score1, score2) + # Test that an exception is raised on more than two classes X, y = make_blobs(random_state=0, centers=3) X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) diff --git a/sklearn/multiclass.py b/sklearn/multiclass.py index 3cfd2a4962ac2..b92f2baab45de 100644 --- a/sklearn/multiclass.py +++ b/sklearn/multiclass.py @@ -39,7 +39,7 @@ import scipy.sparse as sp from .base import BaseEstimator, ClassifierMixin, clone, is_classifier -from .base import MetaEstimatorMixin +from .base import MetaEstimatorMixin, is_regressor from .preprocessing import LabelBinarizer from .metrics.pairwise import euclidean_distances from .utils import check_random_state @@ -77,6 +77,8 @@ def _fit_binary(estimator, X, y, classes=None): def _predict_binary(estimator, X): """Make predictions using a single binary estimator.""" + if is_regressor(estimator): + return estimator.predict(X) try: score = np.ravel(estimator.decision_function(X)) except (AttributeError, 
NotImplementedError): @@ -276,11 +278,11 @@ def fit(self, X, y): # In cases where individual estimators are very fast to train setting # n_jobs > 1 in can results in slower performance due to the overhead # of spawning threads. See joblib issue #112. - self.estimators_ = Parallel(n_jobs=self.n_jobs)(delayed(_fit_binary) - (self.estimator, X, column, - classes=["not %s" % self.label_binarizer_.classes_[i], - self.label_binarizer_.classes_[i]]) - for i, column in enumerate(columns)) + self.estimators_ = Parallel(n_jobs=self.n_jobs)(delayed(_fit_binary)( + self.estimator, X, column, classes=[ + "not %s" % self.label_binarizer_.classes_[i], + self.label_binarizer_.classes_[i]]) + for i, column in enumerate(columns)) return self diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index 86d0e624cb54b..e89cbfc7af5ab 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -94,6 +94,10 @@ def __init__(self, steps): "'%s' (type %s) doesn't)" % (estimator, type(estimator))) + @property + def _estimator_type(self): + return self.steps[-1][1]._estimator_type + def get_params(self, deep=True): if not deep: return super(Pipeline, self).get_params(deep=False) diff --git a/sklearn/svm/base.py b/sklearn/svm/base.py index f4f5f54452835..7ba9dbb9034f2 100644 --- a/sklearn/svm/base.py +++ b/sklearn/svm/base.py @@ -10,7 +10,7 @@ from ..base import BaseEstimator, ClassifierMixin from ..preprocessing import LabelEncoder from ..utils import check_array, check_random_state, column_or_1d -from ..utils import ConvergenceWarning, compute_class_weight +from ..utils import ConvergenceWarning, compute_class_weight, deprecated from ..utils.extmath import safe_sparse_dot from ..utils.validation import check_is_fitted from ..externals import six @@ -348,6 +348,7 @@ def _compute_kernel(self, X): X = np.asarray(kernel, dtype=np.float64, order='C') return X + @deprecated(" and will be removed in 0.19") def decision_function(self, X): """Distance of the samples X to the separating hyperplane. @@ -357,6 +358,21 @@ def decision_function(self, X): For kernel="precomputed", the expected shape of X is [n_samples_test, n_samples_train]. + Returns + ------- + X : array-like, shape = [n_samples, n_class * (n_class-1) / 2] + Returns the decision function of the sample for each class + in the model. + """ + return self._decision_function(X) + + def _decision_function(self, X): + """Distance of the samples X to the separating hyperplane. + + Parameters + ---------- + X : array-like, shape = [n_samples, n_features] + Returns ------- X : array-like, shape = [n_samples, n_class * (n_class-1) / 2] @@ -481,6 +497,21 @@ def _validate_targets(self, y): return np.asarray(y, dtype=np.float64, order='C') + def decision_function(self, X): + """Distance of the samples X to the separating hyperplane. + + Parameters + ---------- + X : array-like, shape = [n_samples, n_features] + + Returns + ------- + X : array-like, shape = [n_samples, n_class * (n_class-1) / 2] + Returns the decision function of the sample for each class + in the model. + """ + return self._decision_function(X) + def predict(self, X): """Perform classification on samples in X. 
diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 3c25714857d37..329d7b71892dc 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -61,6 +61,7 @@ check_transformer_n_iter, check_fit_score_takes_y, check_non_transformer_estimators_n_iter, + check_regressors_no_decision_function, check_pipeline_consistency, CROSS_DECOMPOSITION) @@ -190,6 +191,7 @@ def test_regressors(): yield check_regressors_train, name, Regressor yield check_regressor_data_not_an_array, name, Regressor yield check_estimators_partial_fit_n_features, name, Regressor + yield check_regressors_no_decision_function, name, Regressor # Test that estimators can be pickled, and once pickled # give the same answer as before. yield check_regressors_pickle, name, Regressor diff --git a/sklearn/tests/test_multiclass.py b/sklearn/tests/test_multiclass.py index b783355007d7c..bfe65b4bc066e 100644 --- a/sklearn/tests/test_multiclass.py +++ b/sklearn/tests/test_multiclass.py @@ -31,7 +31,7 @@ from sklearn.naive_bayes import MultinomialNB from sklearn.linear_model import (LinearRegression, Lasso, ElasticNet, Ridge, Perceptron, LogisticRegression) -from sklearn.tree import DecisionTreeClassifier +from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor from sklearn.grid_search import GridSearchCV from sklearn.pipeline import Pipeline from sklearn import svm @@ -79,6 +79,23 @@ def test_ovr_fit_predict(): assert_greater(np.mean(iris.target == pred), 0.65) +def test_ovr_ovo_regressor(): + # test that ovr and ovo work on regressors which don't have a decision_function + ovr = OneVsRestClassifier(DecisionTreeRegressor()) + pred = ovr.fit(iris.data, iris.target).predict(iris.data) + assert_equal(len(ovr.estimators_), n_classes) + assert_array_equal(np.unique(pred), [0, 1, 2]) + # we are doing something sensible + assert_greater(np.mean(pred == iris.target), .9) + + ovr = OneVsOneClassifier(DecisionTreeRegressor()) + pred = ovr.fit(iris.data, iris.target).predict(iris.data) + assert_equal(len(ovr.estimators_), n_classes * (n_classes - 1) / 2) + assert_array_equal(np.unique(pred), [0, 1, 2]) + # we are doing something sensible + assert_greater(np.mean(pred == iris.target), .9) + + def test_ovr_fit_predict_sparse(): for sparse in [sp.csr_matrix, sp.csc_matrix, sp.coo_matrix, sp.dok_matrix, sp.lil_matrix]: diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 44dfe375417c8..3e46d7491d91f 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -19,6 +19,7 @@ from sklearn.utils.testing import assert_true from sklearn.utils.testing import assert_array_equal from sklearn.utils.testing import assert_array_almost_equal +from sklearn.utils.testing import assert_warns_message from sklearn.utils.testing import META_ESTIMATORS from sklearn.utils.testing import set_random_state from sklearn.utils.testing import assert_greater @@ -843,6 +844,31 @@ def check_regressors_pickle(name, Regressor): assert_array_almost_equal(pickled_y_pred, y_pred) +@ignore_warnings +def check_regressors_no_decision_function(name, Regressor): + # checks whether regressors have decision_function or predict_proba + rng = np.random.RandomState(0) + X = rng.normal(size=(10, 4)) + y = multioutput_estimator_convert_y_2d(name, X[:, 0]) + regressor = Regressor() + + set_fast_parameters(regressor) + if hasattr(regressor, "n_components"): + # FIXME CCA, PLS is not robust to rank 1 effects + regressor.n_components = 1 + + regressor.fit(X, y) + funcs = 
["decision_function", "predict_proba", "predict_log_proba"] + for func_name in funcs: + func = getattr(regressor, func_name, None) + if func is None: + # doesn't have function + continue + # has function. Should raise deprecation warning + msg = func_name + assert_warns_message(DeprecationWarning, msg, func, X) + + def check_class_weight_classifiers(name, Classifier): for n_centers in [2, 3]: # create a very noisy dataset From 188fb11968e6ee90e71d68d69a0c79e1ec1510df Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Wed, 1 Apr 2015 10:24:58 -0400 Subject: [PATCH 2/3] COSMIT use consistent shape description in docstring. --- sklearn/linear_model/coordinate_descent.py | 52 +++++++++++----------- sklearn/svm/base.py | 28 ++++++------ 2 files changed, 40 insertions(+), 40 deletions(-) diff --git a/sklearn/linear_model/coordinate_descent.py b/sklearn/linear_model/coordinate_descent.py index e4a17867cf9fe..886e3e7da8fac 100644 --- a/sklearn/linear_model/coordinate_descent.py +++ b/sklearn/linear_model/coordinate_descent.py @@ -41,7 +41,7 @@ def _alpha_grid(X, y, Xy=None, l1_ratio=1.0, fit_intercept=True, Training data. Pass directly as Fortran-contiguous data to avoid unnecessary memory duplication - y : ndarray, shape = (n_samples,) + y : ndarray, shape (n_samples,) Target values Xy : array-like, optional @@ -139,7 +139,7 @@ def lasso_path(X, y, eps=1e-3, n_alphas=100, alphas=None, unnecessary memory duplication. If ``y`` is mono-output then ``X`` can be sparse. - y : ndarray, shape = (n_samples,), or (n_samples, n_outputs) + y : ndarray, shape (n_samples,), or (n_samples, n_outputs) Target values eps : float, optional @@ -281,7 +281,7 @@ def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None, unnecessary memory duplication. If ``y`` is mono-output then ``X`` can be sparse. - y : ndarray, shape = (n_samples,) or (n_samples, n_outputs) + y : ndarray, shape (n_samples,) or (n_samples, n_outputs) Target values l1_ratio : float, optional @@ -547,14 +547,14 @@ class ElasticNet(LinearModel, RegressorMixin): Attributes ---------- - coef_ : array, shape = (n_features,) | (n_targets, n_features) + coef_ : array, shape (n_features,) | (n_targets, n_features) parameter vector (w in the cost function formula) - sparse_coef_ : scipy.sparse matrix, shape = (n_features, 1) | \ + sparse_coef_ : scipy.sparse matrix, shape (n_features, 1) | \ (n_targets, n_features) ``sparse_coef_`` is a readonly property derived from ``coef_`` - intercept_ : float | array, shape = (n_targets,) + intercept_ : float | array, shape (n_targets,) independent term in decision function. 
n_iter_ : array-like, shape (n_targets,) @@ -601,7 +601,7 @@ def fit(self, X, y): X : ndarray or scipy.sparse matrix, (n_samples, n_features) Data - y : ndarray, shape = (n_samples,) or (n_samples, n_targets) + y : ndarray, shape (n_samples,) or (n_samples, n_targets) Target Notes @@ -699,7 +699,7 @@ def decision_function(self, X): Returns ------- - T : array, shape = (n_samples,) + T : array, shape (n_samples,) The predicted decision function """ return self._decision_function(X) @@ -713,7 +713,7 @@ def _decision_function(self, X): Returns ------- - T : array, shape = (n_samples,) + T : array, shape (n_samples,) The predicted decision function """ check_is_fitted(self, 'n_iter_') @@ -794,14 +794,14 @@ class Lasso(ElasticNet): Attributes ---------- - coef_ : array, shape = (n_features,) | (n_targets, n_features) + coef_ : array, shape (n_features,) | (n_targets, n_features) parameter vector (w in the cost function formula) - sparse_coef_ : scipy.sparse matrix, shape = (n_features, 1) | \ + sparse_coef_ : scipy.sparse matrix, shape (n_features, 1) | \ (n_targets, n_features) ``sparse_coef_`` is a readonly property derived from ``coef_`` - intercept_ : float | array, shape = (n_targets,) + intercept_ : float | array, shape (n_targets,) independent term in decision function. n_iter_ : int | array-like, shape (n_targets,) @@ -1231,16 +1231,16 @@ class LassoCV(LinearModelCV, RegressorMixin): alpha_ : float The amount of penalization chosen by cross validation - coef_ : array, shape = (n_features,) | (n_targets, n_features) + coef_ : array, shape (n_features,) | (n_targets, n_features) parameter vector (w in the cost function formula) - intercept_ : float | array, shape = (n_targets,) + intercept_ : float | array, shape (n_targets,) independent term in decision function. - mse_path_ : array, shape = (n_alphas, n_folds) + mse_path_ : array, shape (n_alphas, n_folds) mean square error for the test set on each fold, varying alpha - alphas_ : numpy array, shape = (n_alphas,) + alphas_ : numpy array, shape (n_alphas,) The grid of alphas used for fitting dual_gap_ : ndarray, shape () @@ -1372,17 +1372,17 @@ class ElasticNetCV(LinearModelCV, RegressorMixin): The compromise between l1 and l2 penalization chosen by cross validation - coef_ : array, shape = (n_features,) | (n_targets, n_features) + coef_ : array, shape (n_features,) | (n_targets, n_features) Parameter vector (w in the cost function formula), - intercept_ : float | array, shape = (n_targets, n_features) + intercept_ : float | array, shape (n_targets, n_features) Independent term in the decision function. - mse_path_ : array, shape = (n_l1_ratio, n_alpha, n_folds) + mse_path_ : array, shape (n_l1_ratio, n_alpha, n_folds) Mean square error for the test set on each fold, varying l1_ratio and alpha. - alphas_ : numpy array, shape = (n_alphas,) or (n_l1_ratio, n_alphas) + alphas_ : numpy array, shape (n_alphas,) or (n_l1_ratio, n_alphas) The grid of alphas used for fitting, for each l1_ratio. n_iter_ : int @@ -1512,10 +1512,10 @@ class MultiTaskElasticNet(Lasso): Attributes ---------- - intercept_ : array, shape = (n_tasks,) + intercept_ : array, shape (n_tasks,) Independent term in decision function. - coef_ : array, shape = (n_tasks, n_features) + coef_ : array, shape (n_tasks, n_features) Parameter vector (W in the cost function formula). 
If a 1D y is \ passed in at fit (non multi-task usage), ``coef_`` is then a 1D array @@ -1569,9 +1569,9 @@ def fit(self, X, y): Parameters ----------- - X : ndarray, shape = (n_samples, n_features) + X : ndarray, shape (n_samples, n_features) Data - y : ndarray, shape = (n_samples, n_tasks) + y : ndarray, shape (n_samples, n_tasks) Target Notes @@ -1689,10 +1689,10 @@ class MultiTaskLasso(MultiTaskElasticNet): Attributes ---------- - coef_ : array, shape = (n_tasks, n_features) + coef_ : array, shape (n_tasks, n_features) parameter vector (W in the cost function formula) - intercept_ : array, shape = (n_tasks,) + intercept_ : array, shape (n_tasks,) independent term in decision function. n_iter_ : int diff --git a/sklearn/svm/base.py b/sklearn/svm/base.py index 7ba9dbb9034f2..e43fcb17a9846 100644 --- a/sklearn/svm/base.py +++ b/sklearn/svm/base.py @@ -354,13 +354,13 @@ def decision_function(self, X): Parameters ---------- - X : array-like, shape = [n_samples, n_features] + X : array-like, shape (n_samples, n_features) For kernel="precomputed", the expected shape of X is [n_samples_test, n_samples_train]. Returns ------- - X : array-like, shape = [n_samples, n_class * (n_class-1) / 2] + X : array-like, shape (n_samples, n_class * (n_class-1) / 2) Returns the decision function of the sample for each class in the model. """ @@ -371,11 +371,11 @@ def _decision_function(self, X): Parameters ---------- - X : array-like, shape = [n_samples, n_features] + X : array-like, shape (n_samples, n_features) Returns ------- - X : array-like, shape = [n_samples, n_class * (n_class-1) / 2] + X : array-like, shape (n_samples, n_class * (n_class-1) / 2) Returns the decision function of the sample for each class in the model. """ @@ -502,11 +502,11 @@ def decision_function(self, X): Parameters ---------- - X : array-like, shape = [n_samples, n_features] + X : array-like, shape (n_samples, n_features) Returns ------- - X : array-like, shape = [n_samples, n_class * (n_class-1) / 2] + X : array-like, shape (n_samples, n_class * (n_class-1) / 2) Returns the decision function of the sample for each class in the model. """ @@ -519,13 +519,13 @@ def predict(self, X): Parameters ---------- - X : {array-like, sparse matrix}, shape = [n_samples, n_features] + X : {array-like, sparse matrix}, shape (n_samples, n_features) For kernel="precomputed", the expected shape of X is [n_samples_test, n_samples_train] Returns ------- - y_pred : array, shape = [n_samples] + y_pred : array, shape (n_samples,) Class labels for samples in X. """ y = super(BaseSVC, self).predict(X) @@ -552,13 +552,13 @@ def predict_proba(self): Parameters ---------- - X : array-like, shape = [n_samples, n_features] + X : array-like, shape (n_samples, n_features) For kernel="precomputed", the expected shape of X is [n_samples_test, n_samples_train] Returns ------- - T : array-like, shape = [n_samples, n_classes] + T : array-like, shape (n_samples, n_classes) Returns the probability of the sample for each class in the model. The columns correspond to the classes in sorted order, as they appear in the attribute `classes_`. 
@@ -588,13 +588,13 @@ def predict_log_proba(self):
 
         Parameters
         ----------
-        X : array-like, shape = [n_samples, n_features]
+        X : array-like, shape (n_samples, n_features)
             For kernel="precomputed", the expected shape of X is
             [n_samples_test, n_samples_train]
 
         Returns
         -------
-        T : array-like, shape = [n_samples, n_classes]
+        T : array-like, shape (n_samples, n_classes)
             Returns the log-probabilities of the sample for each class in
             the model. The columns correspond to the classes in sorted
             order, as they appear in the attribute `classes_`.
@@ -739,11 +739,11 @@ def _fit_liblinear(X, y, C, fit_intercept, intercept_scaling, class_weight,
 
     Parameters
     ----------
-    X : {array-like, sparse matrix}, shape = [n_samples, n_features]
+    X : {array-like, sparse matrix}, shape (n_samples, n_features)
         Training vector, where n_samples in the number of samples and
         n_features is the number of features.
 
-    y : array-like, shape = [n_samples]
+    y : array-like, shape (n_samples,)
        Target vector relative to X
 
     C : float

From acb21bb79edd5396860f49ba3833d7dd4972e877 Mon Sep 17 00:00:00 2001
From: Andreas Mueller
Date: Thu, 2 Apr 2015 11:15:10 -0400
Subject: [PATCH 3/3] DOC adding clusterer tag to dev docs.

---
 doc/developers/index.rst | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/doc/developers/index.rst b/doc/developers/index.rst
index 199d93598d420..cd04c1a48baed 100644
--- a/doc/developers/index.rst
+++ b/doc/developers/index.rst
@@ -892,9 +892,10 @@ on a classifier, but not otherwise. Similarly, scorers for average precision
 that take a continuous prediction need to call ``decision_function`` for classifiers,
 but ``predict`` for regressors. This distinction between classifiers and regressors
 is implemented using the ``_estimator_type`` attribute, which takes a string value.
-It should be ``"classifier"`` for classifiers and ``"regressor"`` for regressors,
-to work as expected. Inheriting from ``ClassifierMixin`` or ``RegressorMixin`` will
-set the attribute automatically.
+It should be ``"classifier"`` for classifiers, ``"regressor"`` for
+regressors, and ``"clusterer"`` for clustering methods, to work as expected.
+Inheriting from ``ClassifierMixin``, ``RegressorMixin`` or ``ClusterMixin``
+will set the attribute automatically.
 
 Working notes
 -------------
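
The convention this series introduces can be summarised in a small,
self-contained sketch. This is not the library code verbatim: the toy
``Pipeline`` below only forwards the tag and omits everything else a real
pipeline does, and ``MyRegressor`` is a made-up estimator::

    class ClassifierMixin(object):
        """Mixin tagging an estimator as a classifier."""
        _estimator_type = "classifier"

    class RegressorMixin(object):
        """Mixin tagging an estimator as a regressor."""
        _estimator_type = "regressor"

    def is_classifier(estimator):
        """True if the estimator is tagged as a classifier."""
        return getattr(estimator, "_estimator_type", None) == "classifier"

    def is_regressor(estimator):
        """True if the estimator is tagged as a regressor."""
        return getattr(estimator, "_estimator_type", None) == "regressor"

    class Pipeline(object):
        """Toy stand-in showing how meta-estimators forward the tag."""
        def __init__(self, steps):
            self.steps = steps

        @property
        def _estimator_type(self):
            # Delegate to the final step, as the patch does for Pipeline
            # (and, via ``self.estimator``, for GridSearchCV).
            return self.steps[-1][1]._estimator_type

    class MyRegressor(RegressorMixin):
        pass

    pipe = Pipeline([("est", MyRegressor())])
    print(is_classifier(pipe))  # False
    print(is_regressor(pipe))   # True

Because the check is a plain attribute lookup rather than an ``isinstance``
test, it also works for wrappers and for third-party estimators that set the
attribute themselves.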
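
The scorer change reads the same way: a thresholded scorer such as
``roc_auc`` now asks a regressor for ``predict`` and only tries
``decision_function`` (falling back to ``predict_proba``) on classifiers.
A minimal sketch of that dispatch, assuming a binary problem and using a
hypothetical helper name (``_continuous_score`` is not part of
scikit-learn)::

    import numpy as np

    def _continuous_score(estimator, X):
        """Continuous prediction used by ranking metrics (binary case)."""
        if getattr(estimator, "_estimator_type", None) == "regressor":
            # decision_function is deprecated on regressors, so use predict.
            return estimator.predict(X)
        try:
            return np.ravel(estimator.decision_function(X))
        except (NotImplementedError, AttributeError):
            # Probability of the positive class.
            return estimator.predict_proba(X)[:, 1]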