From 75ae8a304ef0f37c26b6cee0bc0932d68891fe6b Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 11 Jul 2019 18:41:25 +0200 Subject: [PATCH 01/12] TST run test for meta-estimator having estimators keyword --- sklearn/tests/test_common.py | 21 ++++++++++++++++++++- sklearn/utils/estimator_checks.py | 20 +++++++++++++++++++- 2 files changed, 39 insertions(+), 2 deletions(-) diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 51f71f2f7919b..fcce7be4c4af8 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -23,9 +23,13 @@ from sklearn.base import RegressorMixin from sklearn.cluster.bicluster import BiclusterMixin +from sklearn.discriminant_analysis import LinearDiscriminantAnalysis from sklearn.linear_model.base import LinearClassifierMixin +from sklearn.linear_model import LinearRegression +from sklearn.linear_model import LogisticRegression from sklearn.linear_model import Ridge -from sklearn.discriminant_analysis import LinearDiscriminantAnalysis +from sklearn.tree import DecisionTreeClassifier +from sklearn.tree import DecisionTreeRegressor from sklearn.utils import IS_PYPY from sklearn.utils.estimator_checks import ( _yield_all_checks, @@ -68,6 +72,21 @@ def _tested_estimators(): estimator = Estimator(Ridge()) else: estimator = Estimator(LinearDiscriminantAnalysis()) + elif "estimators" in required_parameters: + if issubclass(Estimator, RegressorMixin): + estimator = Estimator( + estimators=[ + ('lr', LinearRegression()), + ('tree', DecisionTreeRegressor(random_state=0)) + ] + ) + else: + estimator = Estimator( + estimators=[ + ('lr', LogisticRegression(random_state=0)), + ('tree', DecisionTreeClassifier(random_state=0)) + ] + ) else: warnings.warn("Can't instantiate estimator {} which requires " "parameters {}".format(name, diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 249cb022f8e87..11d93acdb5ce2 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -28,8 +28,11 @@ from .testing import create_memmap_backed_data from . import is_scalar_nan from ..discriminant_analysis import LinearDiscriminantAnalysis +from ..linear_model import LinearRegression +from ..linear_model import LogisticRegression from ..linear_model import Ridge - +from ..tree import DecisionTreeClassifier +from ..tree import DecisionTreeRegressor from ..base import (clone, ClusterMixin, is_classifier, is_regressor, _DEFAULT_TAGS, RegressorMixin, is_outlier_detector) @@ -2165,6 +2168,21 @@ def check_parameters_default_constructible(name, Estimator): estimator = Estimator(Ridge()) else: estimator = Estimator(LinearDiscriminantAnalysis()) + elif "estimators" in required_parameters: + if issubclass(Estimator, RegressorMixin): + estimator = Estimator( + estimators=[ + ('lr', LinearRegression()), + ('tree', DecisionTreeRegressor(random_state=0)) + ] + ) + else: + estimator = Estimator( + estimators=[ + ('lr', LogisticRegression(random_state=0)), + ('tree', DecisionTreeClassifier(random_state=0)) + ] + ) else: raise SkipTest("Can't instantiate estimator {} which" " requires parameters {}".format( From f674ce8b186d0e4ff142582ac12d25fc4ef86c5d Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 11 Jul 2019 19:27:06 +0200 Subject: [PATCH 02/12] TST check the target when classifying --- sklearn/ensemble/voting.py | 5 ++++- sklearn/utils/estimator_checks.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/sklearn/ensemble/voting.py b/sklearn/ensemble/voting.py index 0b01340d4f1af..cd4dc2f22bf79 100644 --- a/sklearn/ensemble/voting.py +++ b/sklearn/ensemble/voting.py @@ -13,9 +13,10 @@ # # License: BSD 3 clause -import numpy as np from abc import abstractmethod +import numpy as np + from joblib import Parallel, delayed from ..base import ClassifierMixin @@ -25,6 +26,7 @@ from ..preprocessing import LabelEncoder from ..utils.validation import check_is_fitted from ..utils.metaestimators import _BaseComposition +from ..utils.multiclass import check_classification_targets from ..utils import Bunch @@ -264,6 +266,7 @@ def fit(self, X, y, sample_weight=None): ------- self : object """ + check_classification_targets(y) if isinstance(y, np.ndarray) and len(y.shape) > 1 and y.shape[1] > 1: raise NotImplementedError('Multilabel and multi-output' ' classification is not supported.') diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 11d93acdb5ce2..b30c487a48710 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -399,6 +399,9 @@ def set_checking_parameters(estimator): if name == 'OneHotEncoder': estimator.set_params(handle_unknown='ignore') + if name == 'VotingClassifier': + estimator.set_params(voting='soft') + class NotAnArray: """An object that is convertible to an array @@ -2220,7 +2223,7 @@ def param_filter(p): # true for mixins return params = estimator.get_params() - if required_parameters == ["estimator"]: + if required_parameters or (["estimator"], ["estimators"]): # they can need a non-default argument init_params = init_params[1:] From 82a5680897ad54db5c6f9c76de4bbf3360df62b4 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 11 Jul 2019 19:52:02 +0200 Subject: [PATCH 03/12] raise warning for 2d array in regression --- sklearn/ensemble/voting.py | 14 ++++++++++++-- sklearn/utils/estimator_checks.py | 2 ++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/sklearn/ensemble/voting.py b/sklearn/ensemble/voting.py index cd4dc2f22bf79..9f38f35fae18e 100644 --- a/sklearn/ensemble/voting.py +++ b/sklearn/ensemble/voting.py @@ -24,10 +24,11 @@ from ..base import TransformerMixin from ..base import clone from ..preprocessing import LabelEncoder +from ..utils import Bunch from ..utils.validation import check_is_fitted from ..utils.metaestimators import _BaseComposition from ..utils.multiclass import check_classification_targets -from ..utils import Bunch +from ..utils.validation import column_or_1d def _parallel_fit_estimator(estimator, X, y, sample_weight=None): @@ -69,7 +70,15 @@ def _weights_not_none(self): def _predict(self, X): """Collect results from clf.predict calls. """ - return np.asarray([clf.predict(X) for clf in self.estimators_]).T + predictions = [est.predict(X) for est in self.estimators_] + # the shape of the predictions might be inconsistent depending of the + # underlying estimator + if len(set([pred.ndim for pred in predictions])) != 1: + for pred_idx, _ in enumerate(predictions): + if predictions[pred_idx].ndim == 1: + predictions[pred_idx] = \ + predictions[pred_idx][:, np.newaxis] + return np.asarray(predictions).T @abstractmethod def fit(self, X, y, sample_weight=None): @@ -457,6 +466,7 @@ def fit(self, X, y, sample_weight=None): ------- self : object """ + y = column_or_1d(y, warn=True) return super().fit(X, y, sample_weight) def predict(self, X): diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index b30c487a48710..16a1546fb92a8 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -1728,6 +1728,8 @@ def check_supervised_y_2d(name, estimator_orig): assert len(w) > 0, msg assert "DataConversionWarning('A column-vector y" \ " was passed when a 1d array was expected" in msg + else: + print(estimator.__class__.__name__) assert_allclose(y_pred.ravel(), y_pred_2d.ravel()) From 835a5c7294a5a7cdfc7f0ba9a67fcfd54fd2b53c Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 12 Jul 2019 10:58:05 +0200 Subject: [PATCH 04/12] always convert to 2D array before average --- sklearn/ensemble/voting.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/sklearn/ensemble/voting.py b/sklearn/ensemble/voting.py index 9f38f35fae18e..b07ed7d633b86 100644 --- a/sklearn/ensemble/voting.py +++ b/sklearn/ensemble/voting.py @@ -70,14 +70,14 @@ def _weights_not_none(self): def _predict(self, X): """Collect results from clf.predict calls. """ - predictions = [est.predict(X) for est in self.estimators_] - # the shape of the predictions might be inconsistent depending of the - # underlying estimator - if len(set([pred.ndim for pred in predictions])) != 1: - for pred_idx, _ in enumerate(predictions): - if predictions[pred_idx].ndim == 1: - predictions[pred_idx] = \ - predictions[pred_idx][:, np.newaxis] + predictions = [] + for est in self.estimators_: + preds = est.predict(X) + # make sure that the predictions a 2D array to be able to + # concatenate them + if preds.ndim == 1: + preds.reshape(-1, 1) + predictions.append(preds) return np.asarray(predictions).T @abstractmethod From d6ed3cd3bfb33d1eb46235c8f193027447c9ac87 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 12 Jul 2019 10:58:14 +0200 Subject: [PATCH 05/12] DOC add whats new --- doc/whats_new/v0.22.rst | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/doc/whats_new/v0.22.rst b/doc/whats_new/v0.22.rst index d123855c1ece8..3d7b4e041caae 100644 --- a/doc/whats_new/v0.22.rst +++ b/doc/whats_new/v0.22.rst @@ -84,6 +84,13 @@ Changelog preserve the class balance of the original training set. :pr:`14194` by :user:`Johann Faouzi `. +- |Fix| Enable to run :func:`utils.check_estimator` on both + :class:`ensemble.VotingClassifier` and :class:`ensemble.VotingRegressor`. + It leads to solve issues regarding shape consistency during `predict` which + was failing with the underlying estimators were not outputting consistent + array dimensions. + :pr:`14305` by :user:`Guillaume Lemaitre `. + :mod:`sklearn.linear_model` ........................... From c0eabba08362f94177b11c8399c8d31533db175b Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 12 Jul 2019 16:28:54 +0200 Subject: [PATCH 06/12] iter --- sklearn/ensemble/voting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/ensemble/voting.py b/sklearn/ensemble/voting.py index b07ed7d633b86..aa4e26f59f7df 100644 --- a/sklearn/ensemble/voting.py +++ b/sklearn/ensemble/voting.py @@ -76,9 +76,9 @@ def _predict(self, X): # make sure that the predictions a 2D array to be able to # concatenate them if preds.ndim == 1: - preds.reshape(-1, 1) + preds = preds.reshape(-1, 1) predictions.append(preds) - return np.asarray(predictions).T + return np.concatenate(predictions, axis=1) @abstractmethod def fit(self, X, y, sample_weight=None): From c94f490e2dcf6e853f16520394855a5289907cd6 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 12 Jul 2019 16:48:55 +0200 Subject: [PATCH 07/12] iter --- sklearn/ensemble/voting.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/sklearn/ensemble/voting.py b/sklearn/ensemble/voting.py index aa4e26f59f7df..8ec9d02da3f1b 100644 --- a/sklearn/ensemble/voting.py +++ b/sklearn/ensemble/voting.py @@ -70,15 +70,7 @@ def _weights_not_none(self): def _predict(self, X): """Collect results from clf.predict calls. """ - predictions = [] - for est in self.estimators_: - preds = est.predict(X) - # make sure that the predictions a 2D array to be able to - # concatenate them - if preds.ndim == 1: - preds = preds.reshape(-1, 1) - predictions.append(preds) - return np.concatenate(predictions, axis=1) + return np.asarray([est.predict(X) for est in self.estimators_]).T @abstractmethod def fit(self, X, y, sample_weight=None): From fc5a55fce936a920bf677da978e33f22ea002105 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 26 Jul 2019 15:33:42 +0200 Subject: [PATCH 08/12] address comments Nicolas --- doc/whats_new/v0.22.rst | 4 ++-- sklearn/utils/estimator_checks.py | 5 ++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/doc/whats_new/v0.22.rst b/doc/whats_new/v0.22.rst index 3d7b4e041caae..097daeaabeeb1 100644 --- a/doc/whats_new/v0.22.rst +++ b/doc/whats_new/v0.22.rst @@ -84,10 +84,10 @@ Changelog preserve the class balance of the original training set. :pr:`14194` by :user:`Johann Faouzi `. -- |Fix| Enable to run :func:`utils.check_estimator` on both +- |Fix| Enable to run :func:`utils.estimator_checks.check_estimator` on both :class:`ensemble.VotingClassifier` and :class:`ensemble.VotingRegressor`. It leads to solve issues regarding shape consistency during `predict` which - was failing with the underlying estimators were not outputting consistent + was failing when the underlying estimators were not outputting consistent array dimensions. :pr:`14305` by :user:`Guillaume Lemaitre `. diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 16a1546fb92a8..943b66ec11b40 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -399,6 +399,7 @@ def set_checking_parameters(estimator): if name == 'OneHotEncoder': estimator.set_params(handle_unknown='ignore') + # set voting='soft' to be able to use predict_proba if name == 'VotingClassifier': estimator.set_params(voting='soft') @@ -1728,8 +1729,6 @@ def check_supervised_y_2d(name, estimator_orig): assert len(w) > 0, msg assert "DataConversionWarning('A column-vector y" \ " was passed when a 1d array was expected" in msg - else: - print(estimator.__class__.__name__) assert_allclose(y_pred.ravel(), y_pred_2d.ravel()) @@ -2225,7 +2224,7 @@ def param_filter(p): # true for mixins return params = estimator.get_params() - if required_parameters or (["estimator"], ["estimators"]): + if required_parameters in (["estimator"], ["estimators"]): # they can need a non-default argument init_params = init_params[1:] From 10c056486870c2c148c7f4065d8126c49ff5dfeb Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 26 Jul 2019 17:59:16 +0200 Subject: [PATCH 09/12] iter --- doc/whats_new/v0.22.rst | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/doc/whats_new/v0.22.rst b/doc/whats_new/v0.22.rst index f66c4f22d06e9..8d1e34b84638d 100644 --- a/doc/whats_new/v0.22.rst +++ b/doc/whats_new/v0.22.rst @@ -107,11 +107,12 @@ Changelog preserve the class balance of the original training set. :pr:`14194` by :user:`Johann Faouzi `. -- |Fix| Enable to run :func:`utils.estimator_checks.check_estimator` on both - :class:`ensemble.VotingClassifier` and :class:`ensemble.VotingRegressor`. - It leads to solve issues regarding shape consistency during `predict` which - was failing when the underlying estimators were not outputting consistent - array dimensions. +- |Fix| Enable to run by default + :func:`utils.estimator_checks.check_estimator` on both + :class:`ensemble.VotingClassifier` and :class:`ensemble.VotingRegressor`. It + leads to solve issues regarding shape consistency during `predict` which was + failing when the underlying estimators were not outputting consistent array + dimensions. :pr:`14305` by :user:`Guillaume Lemaitre `. - |Fix| :class:`ensemble.AdaBoostClassifier` computes probabilities based on From 61450d239ee9801c4c7205bacd4d6747237b51f0 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 29 Jul 2019 10:57:37 +0200 Subject: [PATCH 10/12] remove parameter for VotingClassifier --- sklearn/utils/estimator_checks.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index b0bfd2efad075..9a757382846d2 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -399,10 +399,6 @@ def set_checking_parameters(estimator): if name == 'OneHotEncoder': estimator.set_params(handle_unknown='ignore') - # set voting='soft' to be able to use predict_proba - if name == 'VotingClassifier': - estimator.set_params(voting='soft') - class NotAnArray: """An object that is convertible to an array From 22ebc905b823ace7f7dafae93b3cf02c401f5bde Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 29 Jul 2019 11:56:17 +0200 Subject: [PATCH 11/12] move tests into ensemble voting --- doc/whats_new/v0.22.rst | 5 +++-- sklearn/ensemble/tests/test_voting.py | 21 +++++++++++++++++++++ sklearn/tests/test_common.py | 17 ++--------------- sklearn/utils/estimator_checks.py | 17 +---------------- 4 files changed, 27 insertions(+), 33 deletions(-) diff --git a/doc/whats_new/v0.22.rst b/doc/whats_new/v0.22.rst index 3d9cb860a89bf..450ec8aab0dad 100644 --- a/doc/whats_new/v0.22.rst +++ b/doc/whats_new/v0.22.rst @@ -115,12 +115,13 @@ Changelog preserve the class balance of the original training set. :pr:`14194` by :user:`Johann Faouzi `. -- |Fix| Enable to run by default +- |Fix| Run by default :func:`utils.estimator_checks.check_estimator` on both :class:`ensemble.VotingClassifier` and :class:`ensemble.VotingRegressor`. It leads to solve issues regarding shape consistency during `predict` which was failing when the underlying estimators were not outputting consistent array - dimensions. + dimensions. Note that it should be replaced by refactoring the common tests + in the future. :pr:`14305` by :user:`Guillaume Lemaitre `. - |Efficiency| :func:`ensemble.MissingIndicator.fit_transform` the diff --git a/sklearn/ensemble/tests/test_voting.py b/sklearn/ensemble/tests/test_voting.py index 5cd971934abf2..52c47129572e2 100644 --- a/sklearn/ensemble/tests/test_voting.py +++ b/sklearn/ensemble/tests/test_voting.py @@ -6,6 +6,8 @@ from sklearn.utils.testing import assert_almost_equal, assert_array_equal from sklearn.utils.testing import assert_array_almost_equal from sklearn.utils.testing import assert_raise_message +from sklearn.utils.estimator_checks import check_estimator +from sklearn.utils.estimator_checks import check_no_attributes_set_in_init from sklearn.exceptions import NotFittedError from sklearn.linear_model import LinearRegression from sklearn.linear_model import LogisticRegression @@ -13,6 +15,8 @@ from sklearn.ensemble import RandomForestClassifier from sklearn.ensemble import RandomForestRegressor from sklearn.ensemble import VotingClassifier, VotingRegressor +from sklearn.tree import DecisionTreeClassifier +from sklearn.tree import DecisionTreeRegressor from sklearn.model_selection import GridSearchCV from sklearn import datasets from sklearn.model_selection import cross_val_score, train_test_split @@ -508,3 +512,20 @@ def test_none_estimator_with_weights(X, y, voter, drop): voter.fit(X, y, sample_weight=np.ones(y.shape)) y_pred = voter.predict(X) assert y_pred.shape == y.shape + + +@pytest.mark.parametrize( + "estimator", + [VotingRegressor( + estimators=[('lr', LinearRegression()), + ('tree', DecisionTreeRegressor(random_state=0))]), + VotingClassifier( + estimators=[('lr', LogisticRegression(random_state=0)), + ('tree', DecisionTreeClassifier(random_state=0))])], + ids=['VotingRegressor', 'VotingClassifier'] +) +def test_check_estimators_voting_estimator(estimator): + # FIXME: to be removed when meta-estimators can be specified themselves + # their testing parameters (for required parameters). + check_estimator(estimator) + check_no_attributes_set_in_init(estimator.__class__.__name__, estimator) diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index fcce7be4c4af8..97d21ab1a7392 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -24,6 +24,8 @@ from sklearn.cluster.bicluster import BiclusterMixin from sklearn.discriminant_analysis import LinearDiscriminantAnalysis +from sklearn.ensemble import VotingClassifier +from sklearn.ensemble import VotingRegressor from sklearn.linear_model.base import LinearClassifierMixin from sklearn.linear_model import LinearRegression from sklearn.linear_model import LogisticRegression @@ -72,21 +74,6 @@ def _tested_estimators(): estimator = Estimator(Ridge()) else: estimator = Estimator(LinearDiscriminantAnalysis()) - elif "estimators" in required_parameters: - if issubclass(Estimator, RegressorMixin): - estimator = Estimator( - estimators=[ - ('lr', LinearRegression()), - ('tree', DecisionTreeRegressor(random_state=0)) - ] - ) - else: - estimator = Estimator( - estimators=[ - ('lr', LogisticRegression(random_state=0)), - ('tree', DecisionTreeClassifier(random_state=0)) - ] - ) else: warnings.warn("Can't instantiate estimator {} which requires " "parameters {}".format(name, diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 9a757382846d2..161a1844b2032 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -2168,21 +2168,6 @@ def check_parameters_default_constructible(name, Estimator): estimator = Estimator(Ridge()) else: estimator = Estimator(LinearDiscriminantAnalysis()) - elif "estimators" in required_parameters: - if issubclass(Estimator, RegressorMixin): - estimator = Estimator( - estimators=[ - ('lr', LinearRegression()), - ('tree', DecisionTreeRegressor(random_state=0)) - ] - ) - else: - estimator = Estimator( - estimators=[ - ('lr', LogisticRegression(random_state=0)), - ('tree', DecisionTreeClassifier(random_state=0)) - ] - ) else: raise SkipTest("Can't instantiate estimator {} which" " requires parameters {}".format( @@ -2220,7 +2205,7 @@ def param_filter(p): # true for mixins return params = estimator.get_params() - if required_parameters in (["estimator"], ["estimators"]): + if required_parameters == ["estimator"]: # they can need a non-default argument init_params = init_params[1:] From bba8825746d862ef8a1d6985e324e9de851d1109 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 29 Jul 2019 11:57:36 +0200 Subject: [PATCH 12/12] PEP8 --- sklearn/tests/test_common.py | 6 ------ sklearn/utils/estimator_checks.py | 4 ---- 2 files changed, 10 deletions(-) diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 97d21ab1a7392..abfc84b00f2fd 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -24,14 +24,8 @@ from sklearn.cluster.bicluster import BiclusterMixin from sklearn.discriminant_analysis import LinearDiscriminantAnalysis -from sklearn.ensemble import VotingClassifier -from sklearn.ensemble import VotingRegressor from sklearn.linear_model.base import LinearClassifierMixin -from sklearn.linear_model import LinearRegression -from sklearn.linear_model import LogisticRegression from sklearn.linear_model import Ridge -from sklearn.tree import DecisionTreeClassifier -from sklearn.tree import DecisionTreeRegressor from sklearn.utils import IS_PYPY from sklearn.utils.estimator_checks import ( _yield_all_checks, diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 161a1844b2032..c8a82bc8e623f 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -28,11 +28,7 @@ from .testing import create_memmap_backed_data from . import is_scalar_nan from ..discriminant_analysis import LinearDiscriminantAnalysis -from ..linear_model import LinearRegression -from ..linear_model import LogisticRegression from ..linear_model import Ridge -from ..tree import DecisionTreeClassifier -from ..tree import DecisionTreeRegressor from ..base import (clone, ClusterMixin, is_classifier, is_regressor, _DEFAULT_TAGS, RegressorMixin, is_outlier_detector)