diff --git a/sklearn/ensemble/__init__.py b/sklearn/ensemble/__init__.py index 5586a9e1e1fba..18ae153eed555 100644 --- a/sklearn/ensemble/__init__.py +++ b/sklearn/ensemble/__init__.py @@ -17,6 +17,7 @@ from .gradient_boosting import GradientBoostingClassifier from .gradient_boosting import GradientBoostingRegressor from .voting_classifier import VotingClassifier +from .stacking_classifier import StackingClassifier from . import bagging from . import forest @@ -32,4 +33,4 @@ "GradientBoostingRegressor", "AdaBoostClassifier", "AdaBoostRegressor", "VotingClassifier", "bagging", "forest", "gradient_boosting", - "partial_dependence", "weight_boosting"] + "partial_dependence", "weight_boosting", "StackingClassifier"] diff --git a/sklearn/ensemble/stacking_classifier.py b/sklearn/ensemble/stacking_classifier.py new file mode 100644 index 0000000000000..bb5f033f9cb15 --- /dev/null +++ b/sklearn/ensemble/stacking_classifier.py @@ -0,0 +1,319 @@ +import numpy as np +from ..base import BaseEstimator, ClassifierMixin +from ..base import clone, is_classifier +from ..model_selection._validation import cross_val_predict +from ..preprocessing import LabelEncoder +from ..externals.joblib import Parallel, delayed +from ..utils.metaestimators import if_delegate_has_method +from ..utils.validation import has_fit_parameter, check_is_fitted +from ..utils.multiclass import type_of_target +from ..externals import six + + +def _parallel_fit(clf, X, y, fit_params): + clf.fit(X, y, **fit_params) + return clf + + +class StackingClassifier(BaseEstimator, ClassifierMixin): + """ Stacking classifier for combining estimators + + The cross-validated predictions of estimators will be used as + inputs to a meta-estimator for making a final prediction + + Parameters + ---------- + estimators : list of (string, estimator) tuples + Invoking the ``fit`` method on the ``StackingClassifier`` will fit + clones of those original estimators that will be stored in the class + attribute `self.estimators_`. + + meta_estimator : estimator + The meta-estimator to combine the predictions of each individual + estimator + + method : string, optional, default='auto' + Specifies which method of the estimators will be called to generate + inputs to the meta_estimator. The default is 'auto' which is + `predict_proba` if it exists and `decision_function` otherwise + + cv : int, cross-validation generator or an iterable, optional + Determines the cross-validation splitting strategy to generate inputs + to the meta_estimator + Possible inputs for cv are: + - None, to use the default 3-fold cross validation, + - 1, do not perform cross-validation, + - integer>1, to specify the number of folds in a `(Stratified)KFold`, + - An object to be used as a cross-validation generator. + - An iterable yielding train, test splits. + + For integer/None inputs, if the estimator is a classifier and ``y`` is + either binary or multiclass, :class:`StratifiedKFold` is used. If y is + a different format, :class:`KFold` is used. + + Refer :ref:`User Guide ` for the various + cross-validation strategies that can be used here. + + verbose : integer + Controls the verbosity: the higher, the more messages. + + n_jobs : int, optional (default=1) + The number of jobs to run in parallel for ``fit``. + If -1, then the number of jobs is set to the number of cores. + + Attributes + ---------- + estimators_ : list of classifiers + The collection of fitted sub-estimators. + + meta_estimators_ : classifier + Fitted meta-estimator + + classes_ : array-like, shape = [n_predictions] + The classes labels. + + Examples + -------- + >>> import numpy as np + >>> from sklearn.linear_model import LogisticRegression + >>> from sklearn.naive_bayes import GaussianNB + >>> from sklearn.ensemble import RandomForestClassifier, StackingClassifier + >>> clf1 = LogisticRegression(random_state=1) + >>> clf2 = RandomForestClassifier(random_state=1) + >>> clfm = GaussianNB() + >>> X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]]) + >>> y = np.array([1, 1, 1, 2, 2, 2]) + >>> eclf1 = StackingClassifier(estimators=[ + ... ('lr', clf1), ('rf', clf2)], meta_estimator=clfm, cv=2) + >>> eclf1 = eclf1.fit(X, y) + >>> print(eclf1.predict(X)) + [1 1 1 2 2 2] + >>> + """ + + def __init__(self, estimators, meta_estimator, method='auto', + cv=None, n_jobs=1, verbose=0): + self.estimators = estimators + self.named_estimators = dict(estimators) + self.meta_estimator = meta_estimator + self.cv = cv + self.method = method + self.n_jobs = n_jobs + self.verbose = verbose + + def _method(self, estimator): + if self.method == 'auto': + if hasattr(estimator, 'predict_proba'): + method = 'predict_proba' + elif hasattr(estimator, 'decision_function'): + method = 'decision_function' + else: + raise AttributeError("Estimator %s has no method " + "`predict_proba` or `decision_function`. " + "Try specify a method instead of using " + "the default `auto`" % estimator) + else: + method = self.method + return method + + def fit(self, X, y, **kwargs): + """ Fit the estimators. + + Parameters + ---------- + X : {array-like, sparse matrix}, shape = [n_samples, n_features] + Training vectors, where n_samples is the number of samples and + n_features is the number of features. + + y : array-like, shape = [n_samples] + Target values. + + **kwargs: optional fit parameters + + + Returns + ------- + self : object + """ + if any(s in type_of_target(y) for s in ['multilabel', 'multioutput']): + raise NotImplementedError('Multilabel and multi-output' + ' classification is not supported.') + + if not self.estimators: + raise AttributeError('Invalid `estimators` attribute %s, ' + '`estimators` should be a list of (string, ' + 'estimator) tuples.' % type(self.estimators)) + + if not is_classifier(self.meta_estimator): + raise AttributeError('Invalid `meta_estimator` attribute ' + '%s, `meta_estimator` should be a classifier.' + % type(self.meta_estimator)) + for name, step in self.estimators: + if not hasattr(step, self._method(step)): + raise ValueError('Underlying estimator %s %s does ' + 'not support %s.' % (name, + type(step), + self._method(step))) + + for param in kwargs: + if not has_fit_parameter(self.meta_estimator, param): + raise ValueError('Underlying meta estimator %s ' + 'does not support `%s`.' % ( + type(self.meta_estimator), param)) + for name, step in self.estimators: + if not has_fit_parameter(step, param): + raise ValueError('Underlying estimator %s of type %s does ' + 'not support `%s`.' % (name, + type(step), + param)) + + self.le_ = LabelEncoder().fit(y) + self.classes_ = self.le_.classes_ + + transformed_y = self.le_.transform(y) + if self.cv == 1: # Do not cross-validation + # Parallel fit each estimator + + self.estimators_ = Parallel(n_jobs=self.n_jobs)( + delayed(_parallel_fit)(clone(clf), + X, transformed_y, kwargs) + for _, clf in self.estimators) + scores = self._est_predict(X) + else: + # Use the n_jobs of cross_val_predict + self.estimators_ = [] + scores = [] + for _, clf in self.estimators: + s1 = cross_val_predict(clf, X, y, cv=self.cv, + method=self._method(clf), + fit_params=kwargs, + verbose=self.verbose, + n_jobs=self.n_jobs) + s1 = self._form_meta_inputs(clf, s1) + scores.append(s1) + self.estimators_.append(clf.fit(X, y, **kwargs)) + scores = np.concatenate([s.reshape(-1, 1) if s.ndim == 1 else s + for s in scores], axis=1) + self.meta_estimator_ = clone(self.meta_estimator) + self.meta_estimator_.fit(scores, transformed_y, **kwargs) + return self + + def _form_meta_inputs(self, clf, predicted): + if self._method(clf) in ['predict_proba', 'predict_log_proba']: + # Remove first column to avoid multicollinearity since sum of + # probabilities equals 1 + predicted = predicted[:, 1:] + if self._method(clf) == 'predict_log_proba': + # Replace inf + predicted = np.clip(predicted, -1e30, 1e30) + return predicted + + def _est_predict(self, X): + """ Generate input to meta_estimator from predictions of estimators + """ + predicted = [] + for clf in self.estimators_: + s1 = getattr(clf, self._method(clf))(X) + s1 = self._form_meta_inputs(clf, s1) + predicted.append(s1) + return np.concatenate([s.reshape(-1, 1) if s.ndim == 1 else s + for s in predicted], axis=1) + + @if_delegate_has_method(delegate=('meta_estimator_', 'meta_estimator')) + def predict(self, X): + """Predict class labels for X. + + Only available if fitted and ``meta_estimator`` supports ``predict``. + + Parameters + ---------- + X : array-like, shape = [n_samples, n_features] + + Returns + ------- + C : array, shape = [n_samples] + Predicted class label per sample. + """ + check_is_fitted(self, ['estimators_', 'meta_estimator_']) + predicted = self.meta_estimator_.predict(self._est_predict(X)) + return self.le_.inverse_transform(predicted) + + @if_delegate_has_method(delegate=('meta_estimator_', 'meta_estimator')) + def predict_proba(self, X): + """ Return probability estimates for the test vector X. + + Only available if ``meta_estimator`` supports ``predict_proba``. + + Parameters + ---------- + X : array-like, shape = [n_samples, n_features] + + Returns + ------- + C : array-like, shape = [n_samples, n_classes] + Returns the probability of the samples for each class in + the model. The columns correspond to the classes in sorted + order, as they appear in the attribute `classes_`. + """ + check_is_fitted(self, ['estimators_', 'meta_estimator_']) + return self.meta_estimator_.predict_proba(self._est_predict(X)) + + @if_delegate_has_method(delegate=('best_estimator_', 'estimator')) + def predict_log_proba(self, X): + """ Return log-probability estimates for the test vector X. + + Only available if ``meta_estimator`` supports ``predict_log_proba``. + + Parameters + ---------- + X : array-like, shape = [n_samples, n_features] + + Returns + ------- + C : array-like, shape = [n_samples, n_classes] + Returns the log-probability of the samples for each class in + the model. The columns correspond to the classes in sorted + order, as they appear in the attribute `classes_`. + """ + check_is_fitted(self, ['estimators_', 'meta_estimator_']) + return self.meta_estimator_.predict_log_proba(self._est_predict(X)) + + @if_delegate_has_method(delegate=('best_estimator_', 'estimator')) + def decision_function(self, X): + """Predict confidence scores for samples. + + Only available if ``meta_estimator`` supports ``predict_log_proba``. + The confidence score for a sample is the signed distance of that + sample to the hyperplane. + + Parameters + ---------- + X : {array-like, sparse matrix}, shape = (n_samples, n_features) + Samples. + + Returns + ------- + array, shape=(n_samples,) if n_classes == 2 else (n_samples, n_classes) + Confidence scores per (sample, class) combination. In the binary + case, confidence score for self.classes_[1] where >0 means this + class would be predicted. + """ + check_is_fitted(self, ['estimators_', 'meta_estimator_']) + return self.meta_estimator_.decision_function(self._est_predict(X)) + + def get_params(self, deep=True): + """Return estimator parameter names for GridSearch support + """ + if not deep: + return super(StackingClassifier, self).get_params(deep=False) + else: + out = super(StackingClassifier, self).get_params(deep=False) + out.update(self.named_estimators.copy()) + out.update({'meta_estimator': self.meta_estimator}.copy()) + for name, step in six.iteritems(self.named_estimators): + for key, value in six.iteritems(step.get_params(deep=True)): + out['{0}__{1}'.format(name, key)] = value + for key, value in six.iteritems( + self.meta_estimator.get_params(deep=True)): + out['{0}__{1}'.format('meta_estimator', key)] = value + return out diff --git a/sklearn/ensemble/tests/test_stacking_classifier.py b/sklearn/ensemble/tests/test_stacking_classifier.py new file mode 100644 index 0000000000000..3b4e574ba042d --- /dev/null +++ b/sklearn/ensemble/tests/test_stacking_classifier.py @@ -0,0 +1,217 @@ +"""Testing for the StackingClassifier""" +from sklearn.utils.testing import assert_almost_equal, assert_array_equal +from sklearn.utils.testing import assert_equal, assert_array_almost_equal +from sklearn.utils.testing import assert_raise_message, assert_raises +from sklearn.exceptions import NotFittedError +from sklearn.linear_model import LogisticRegression +from sklearn.naive_bayes import GaussianNB +from sklearn.ensemble import RandomForestClassifier +from sklearn.ensemble import StackingClassifier +from sklearn.model_selection import GridSearchCV, cross_val_score +from sklearn import datasets +from sklearn.datasets import make_multilabel_classification +from sklearn.svm import SVC +from sklearn.multiclass import OneVsRestClassifier +import numpy as np +from sklearn.neighbors import KNeighborsClassifier + + +# Load the iris dataset and randomly permute it +iris = datasets.load_iris() +X, y = iris.data[:, 1:3], iris.target + + +def test_estimator_init(): + clf = LogisticRegression(random_state=1) + eclf = StackingClassifier(estimators=[], meta_estimator=clf) + msg = ("Invalid `estimators` attribute %s, `estimators` " + "should be a list of (string, estimator) tuples.") % (type([])) + assert_raises(AttributeError, eclf.fit, X, y) + + eclf = StackingClassifier(estimators=[('lr', clf)], meta_estimator=None) + msg = ("Invalid `meta_estimator` attribute %s, " + "`meta_estimator` should be a classifier.") % (type(None)) + assert_raise_message(AttributeError, msg, eclf.fit, X, y) + + clf_no_proba = SVC() + eclf = StackingClassifier(estimators=[('SVC', clf_no_proba)], + meta_estimator=clf, method='predict_proba') + msg = ("Underlying estimator SVC " + "does not support predict_proba.") + assert_raise_message(ValueError, msg, eclf.fit, X, y) + + +def test_notfitted(): + eclf = StackingClassifier(estimators=[('lr1', LogisticRegression()), + ('lr2', LogisticRegression())], + meta_estimator=LogisticRegression()) + msg = ("This StackingClassifier instance is not fitted yet. Call \'fit\'" + " with appropriate arguments before using this method.") + assert_raise_message(NotFittedError, msg, eclf.predict, X) + + +def test_multilabel(): + """Check if error is raised for multilabel classification.""" + X, y = make_multilabel_classification(n_classes=2, n_labels=1, + allow_unlabeled=False, + random_state=123) + clf = OneVsRestClassifier(SVC(kernel='linear')) + clf_meta = OneVsRestClassifier(SVC(kernel='linear')) + + eclf = StackingClassifier(estimators=[('ovr1', clf)], + meta_estimator=clf_meta, + method='decision_function') + + try: + eclf.fit(X, y) + except NotImplementedError: + return + + +def test_cv(): + """Test cross-validation option cv=1""" + clf1 = LogisticRegression(random_state=1) + clfm = GaussianNB() + eclf = StackingClassifier(estimators=[('lr', clf1)], cv=1, + meta_estimator=clfm).fit(X, y) + s = clf1.fit(X, y).predict_proba(X)[:, 1:] + assert_array_equal(clfm.fit(s, y).predict(s), eclf.predict(X)) + assert_array_equal(clfm.fit(s, y).predict_proba(s), + eclf.predict_proba(X)) + + +def test_gridsearch(): + """Check GridSearch support.""" + clf1 = LogisticRegression(random_state=1) + clf2 = RandomForestClassifier(random_state=1) + clf3 = GaussianNB() + clf_meta = LogisticRegression(random_state=1) + eclf = StackingClassifier(estimators=[ + ('lr', clf1), ('rf', clf2), ('nb', clf3)], + meta_estimator=clf_meta) + + params = {'lr__C': [1.0, 100.0], + 'method': ['predict', 'predict_proba', 'predict_log_proba'], + 'meta_estimator__C': [10.0, 20.0]} + + grid = GridSearchCV(estimator=eclf, param_grid=params, cv=5) + + grid.fit(iris.data, iris.target) + + +def test_parallel_predict(): + """Check parallel backend of StackingClassifier on iris.""" + clf1 = LogisticRegression(random_state=123) + clf2 = RandomForestClassifier(random_state=123) + clf3 = GaussianNB() + clfm = GaussianNB() + + eclf1 = StackingClassifier( + estimators=[('lr', clf1), ('rf', clf2), ('nb', clf3)], + meta_estimator=clfm, + n_jobs=1).fit(X, y) + eclf2 = StackingClassifier( + estimators=[('lr', clf1), ('rf', clf2), ('nb', clf3)], + meta_estimator=clfm, + n_jobs=2).fit(X, y) + assert_array_equal(eclf1.predict(X), eclf2.predict(X)) + assert_array_equal(eclf1.predict_proba(X), eclf2.predict_proba(X)) + + # cv=1 option + eclf1 = StackingClassifier( + estimators=[('lr', clf1), ('rf', clf2), ('nb', clf3)], + meta_estimator=clfm, + n_jobs=1, cv=1).fit(X, y) + eclf2 = StackingClassifier( + estimators=[('lr', clf1), ('rf', clf2), ('nb', clf3)], + meta_estimator=clfm, + n_jobs=2, cv=1).fit(X, y) + assert_array_equal(eclf1.predict(X), eclf2.predict(X)) + assert_array_equal(eclf1.predict_proba(X), eclf2.predict_proba(X)) + + +def test_sample_weight(): + """Tests sample_weight parameter of StackingClassifier""" + decimal = 10 + + clf1 = LogisticRegression(random_state=123) + clf2 = RandomForestClassifier(random_state=123) + clf3 = GaussianNB() + clfm = GaussianNB() + eclf1 = StackingClassifier( + estimators=[('lr', clf1), ('rf', clf2), ('nb', clf3)], + meta_estimator=clfm + ).fit(X, y, sample_weight=np.ones((len(y),))) + eclf2 = StackingClassifier( + estimators=[('lr', clf1), ('rf', clf2), ('nb', clf3)], + meta_estimator=clfm + ).fit(X, y) + assert_array_almost_equal(eclf1.predict(X), eclf2.predict(X), + decimal=decimal) + assert_array_almost_equal(eclf1.predict_proba(X), eclf2.predict_proba(X), + decimal=decimal) + + sample_weight = np.random.RandomState(123).uniform(size=(len(y),)) + clf4 = KNeighborsClassifier() + eclf3 = StackingClassifier( + estimators=[('lr', clf1), ('knn', clf4)], + meta_estimator=clfm) + msg = ("Underlying estimator knn of type " + " " + "does not support `sample_weight`.") + assert_raise_message(ValueError, msg, eclf3.fit, X, y, + sample_weight=sample_weight) + eclf4 = StackingClassifier( + estimators=[('lr', clf1), ('nb', clf3)], + meta_estimator=clf4) + msg = ("Underlying meta estimator " + " " + "does not support `sample_weight`.") + assert_raise_message(ValueError, msg, eclf4.fit, X, y, + sample_weight=sample_weight) + + +def test_classify_iris(): + """Check classification by majority label on dataset iris.""" + clf1 = LogisticRegression(random_state=123) + clf2 = RandomForestClassifier(random_state=123) + clf3 = GaussianNB() + clfm = GaussianNB() + eclf = StackingClassifier(estimators=[('lg', clf1), + ('rf', clf2), + ('nb', clf3)], + meta_estimator=clfm) + scores = cross_val_score(eclf, X, y, cv=5, scoring='accuracy') + assert_almost_equal(scores.mean(), 0.95, decimal=2) + + +def test_predict_on_toy_problem(): + """Manually check predicted class labels for toy dataset.""" + clf1 = LogisticRegression(random_state=123) + clf2 = RandomForestClassifier(random_state=123) + clf3 = GaussianNB() + + X = np.array([[-1.1, -1.5], + [-1.2, -1.4], + [-3.4, -2.2], + [1.1, 1.2], + [2.1, 1.4], + [3.1, 2.3]]) + + y = np.array([1, 1, 1, 2, 2, 2]) + + assert_equal(all(clf1.fit(X, y).predict(X)), all([1, 1, 1, 2, 2, 2])) + assert_equal(all(clf2.fit(X, y).predict(X)), all([1, 1, 1, 2, 2, 2])) + assert_equal(all(clf3.fit(X, y).predict(X)), all([1, 1, 1, 2, 2, 2])) + + eclf = StackingClassifier(estimators=[('rf', clf2), ('nb', clf3)], + meta_estimator=clf1) + assert_equal(all(eclf.fit(X, y).predict(X)), all([1, 1, 1, 2, 2, 2])) + + eclf = StackingClassifier(estimators=[('lr', clf1), ('nb', clf3)], + meta_estimator=clf2) + assert_equal(all(eclf.fit(X, y).predict(X)), all([1, 1, 1, 2, 2, 2])) + + eclf = StackingClassifier(estimators=[('lr', clf1), ('rf', clf2)], + meta_estimator=clf3) + assert_equal(all(eclf.fit(X, y).predict(X)), all([1, 1, 1, 2, 2, 2])) diff --git a/sklearn/utils/testing.py b/sklearn/utils/testing.py index 2eeead9711bbb..35c56a61fc192 100644 --- a/sklearn/utils/testing.py +++ b/sklearn/utils/testing.py @@ -551,7 +551,7 @@ def uninstall_mldata_mock(): 'ZeroEstimator', 'ScaledLogOddsEstimator', 'QuantileEstimator', 'MeanEstimator', 'LogOddsEstimator', 'PriorProbabilityEstimator', - '_SigmoidCalibration', 'VotingClassifier'] + '_SigmoidCalibration', 'VotingClassifier', 'StackingClassifier'] def all_estimators(include_meta_estimators=False,